Explorar o código

Merge 78ad88a245c3d05ffadf23f00f5ae2e22bb4bbe2 into 3b79c8e09fc9d3056e978006d7989e0e1f70c6bc

Keith Carichner Jr hai 3 meses
pai
achega
f35ac00c6a
Modificáronse 3 ficheiros con 357 adicións e 43 borrados
  1. 8 2
      docs/configuration.md
  2. 339 41
      internal/glance/widget-dns-stats.go
  3. 10 0
      internal/glance/widget-utils.go

+ 8 - 2
docs/configuration.md

@@ -1786,8 +1786,14 @@ Only required when using AdGuard Home. The username used to log into the admin d
 ##### `password`
 Only required when using AdGuard Home. The password used to log into the admin dashboard. Can be specified from an environment variable using the syntax `${VARIABLE_NAME}`.
 
-##### `token`
-Only required when using Pi-hole. The API token which can be found in `Settings -> API -> Show API token`. Can be specified from an environment variable using the syntax `${VARIABLE_NAME}`.
+##### `token` (Deprecated)
+Only required when using Pi-hole major version 5 or earlier. The API token which can be found in `Settings -> API -> Show API token`. Can be specified from an environment variable using the syntax `${VARIABLE_NAME}`.
+
+##### `app-password`
+Only required when using Pi-hole. The App Password can be found in `Settings -> Web Interface / API -> Configure app password`.
+
+##### `pihole-version`
+Only required if using an older version of Pi-hole (major version 5 or earlier).
 
 ##### `hide-graph`
 Whether to hide the graph showing the number of queries over time.

+ 339 - 41
internal/glance/widget-dns-stats.go

@@ -1,12 +1,15 @@
 package glance
 
 import (
+	"bytes"
 	"context"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"html/template"
-	"log/slog"
+	"io"
 	"net/http"
+	"os"
 	"sort"
 	"strings"
 	"time"
@@ -27,6 +30,8 @@ type dnsStatsWidget struct {
 	AllowInsecure  bool   `yaml:"allow-insecure"`
 	URL            string `yaml:"url"`
 	Token          string `yaml:"token"`
+	AppPassword    string `yaml:"app-password"`
+	PiHoleVersion  string `yaml:"pihole-version"`
 	Username       string `yaml:"username"`
 	Password       string `yaml:"password"`
 }
@@ -62,7 +67,7 @@ func (widget *dnsStatsWidget) update(ctx context.Context) {
 	if widget.Service == "adguard" {
 		stats, err = fetchAdguardStats(widget.URL, widget.AllowInsecure, widget.Username, widget.Password, widget.HideGraph)
 	} else {
-		stats, err = fetchPiholeStats(widget.URL, widget.AllowInsecure, widget.Token, widget.HideGraph)
+		stats, err = fetchPiholeStats(widget.URL, widget.AllowInsecure, widget.Token, widget.HideGraph, widget.PiHoleVersion, widget.AppPassword)
 	}
 
 	if !widget.canContinueUpdateAfterHandlingErr(err) {
@@ -227,7 +232,8 @@ func fetchAdguardStats(instanceURL string, allowInsecure bool, username, passwor
 	return stats, nil
 }
 
-type piholeStatsResponse struct {
+// Legacy Pi-hole stats response (before v6)
+type legacyPiholeStatsResponse struct {
 	TotalQueries      int                     `json:"dns_queries_today"`
 	QueriesSeries     piholeQueriesSeries     `json:"domains_over_time"`
 	BlockedQueries    int                     `json:"ads_blocked_today"`
@@ -237,6 +243,34 @@ type piholeStatsResponse struct {
 	DomainsBlocked    int                     `json:"domains_being_blocked"`
 }
 
+// Pi-hole v6+ response format
+type piholeStatsResponse struct {
+	Queries struct {
+		Total          int     `json:"total"`
+		Blocked        int     `json:"blocked"`
+		PercentBlocked float64 `json:"percent_blocked"`
+	} `json:"queries"`
+	Gravity struct {
+		DomainsBlocked int `json:"domains_being_blocked"`
+	} `json:"gravity"`
+	//Note we do not need the full structure. We extract the values needed
+	//Adding dummy fields to allow easier json parsing.
+	QueriesSeries piholeQueriesSeries `json:"domains_over_time"` // Will always be empty
+	BlockedSeries map[int64]int       `json:"ads_over_time"`     // Will always be empty.
+}
+
+type piholeTopDomainsResponse struct {
+	Domains			[]Domains	`json:"domains"`
+	TotalQueries 	int			`json:"total_queries"`
+	BlockedQueries 	int			`json:"blocked_queries"`
+	Took			float64		`json:"took"`
+}
+
+type Domains struct {
+	Domain	string	`json:"domain"`
+	Count 	int		`json:"count"`
+}
+
 // If the user has query logging disabled it's possible for domains_over_time to be returned as an
 // empty array rather than a map which will prevent unmashalling the rest of the data so we use
 // custom unmarshal behavior to fallback to an empty map.
@@ -275,18 +309,65 @@ func (p *piholeTopBlockedDomains) UnmarshalJSON(data []byte) error {
 	return nil
 }
 
-func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGraph bool) (*dnsStats, error) {
-	if token == "" {
-		return nil, errors.New("missing API token")
+// piholeGetSID retrieves a new SID from Pi-hole using the app password.
+func piholeGetSID(instanceURL, appPassword string, allowInsecure bool) (string, error) {
+	var client requestDoer
+	if !allowInsecure {
+		client = defaultHTTPClient
+	} else {
+		client = defaultInsecureHTTPClient
+	}
+
+	requestURL := strings.TrimRight(instanceURL, "/") + "/api/auth"
+	requestBody := []byte(`{"password":"` + appPassword + `"}`)
+
+	request, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(requestBody))
+	if err != nil {
+		return "", errors.New("failed to create authentication request: " + err.Error())
 	}
+	request.Header.Set("Content-Type", "application/json")
 
-	requestURL := strings.TrimRight(instanceURL, "/") +
-		"/admin/api.php?summaryRaw&topItems&overTimeData10mins&auth=" + token
+	response, err := client.Do(request)
+	if err != nil {
+		return "", errors.New("failed to send authentication request: " + err.Error())
+	}
+	defer response.Body.Close()
+
+	if response.StatusCode != http.StatusOK {
+		return "", errors.New("authentication failed, received status: " + response.Status)
+	}
+
+	body, err := io.ReadAll(response.Body)
+	if err != nil {
+		return "", errors.New("failed to read authentication response: " + err.Error())
+	}
+
+	var jsonResponse struct {
+		Session struct {
+			SID string `json:"sid"`
+		} `json:"session"`
+	}
+
+	if err := json.Unmarshal(body, &jsonResponse); err != nil {
+		return "", errors.New("failed to parse authentication response: " + err.Error())
+	}
+
+	if jsonResponse.Session.SID == "" {
+		return "", errors.New("authentication response did not contain a valid SID")
+	}
+
+	return jsonResponse.Session.SID, nil
+}
+
+// checkPiholeSID checks if the SID is valid by checking HTTP response status code from /api/auth.
+func checkPiholeSID(instanceURL string, sid string, allowInsecure bool) error {
+	requestURL := strings.TrimRight(instanceURL, "/") + "/api/auth"
 
 	request, err := http.NewRequest("GET", requestURL, nil)
 	if err != nil {
-		return nil, err
+		return err
 	}
+	request.Header.Set("x-ftl-sid", sid)
 
 	var client requestDoer
 	if !allowInsecure {
@@ -295,25 +376,163 @@ func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGr
 		client = defaultInsecureHTTPClient
 	}
 
-	responseJson, err := decodeJsonFromRequest[piholeStatsResponse](client, request)
+	response, err := client.Do(request)
 	if err != nil {
-		return nil, err
+		return err
+	}
+	defer response.Body.Close()
+
+	if response.StatusCode != http.StatusOK {
+		return errors.New("SID is invalid, received status: " + response.Status)
+	}
+
+	return nil
+}
+
+// fetchPiholeTopDomains fetches the top blocked domains for Pi-hole v6+.
+func fetchPiholeTopDomains(instanceURL string, sid string, allowInsecure bool) (piholeTopDomainsResponse, error) {
+	requestURL := strings.TrimRight(instanceURL, "/") + "/api/stats/top_domains?blocked=true"
+
+	request, err := http.NewRequest("GET", requestURL, nil)
+	if err != nil {
+		return piholeTopDomainsResponse{}, err
+	}
+	request.Header.Set("x-ftl-sid", sid)
+
+	var client requestDoer
+	if !allowInsecure {
+		client = defaultHTTPClient
+	} else {
+		client = defaultInsecureHTTPClient
+	}
+
+	return decodeJsonFromRequest[piholeTopDomainsResponse](client, request)
+}
+
+// fetchPiholeSeries fetches the series data for Pi-hole v6+ (QueriesSeries and BlockedSeries).
+func fetchPiholeSeries(instanceURL string, sid string, allowInsecure bool) (piholeQueriesSeries, map[int64]int, error) {
+	requestURL := strings.TrimRight(instanceURL, "/") + "/api/history"
+
+	request, err := http.NewRequest("GET", requestURL, nil)
+	if err != nil {
+		return nil, nil, err
 	}
+	request.Header.Set("x-ftl-sid", sid)
+
+	var client requestDoer
+	if !allowInsecure {
+		client = defaultHTTPClient
+	} else {
+		client = defaultInsecureHTTPClient
+	}
+
+	// Define the correct struct to match the API response
+	var responseJson struct {
+		History []struct {
+			Timestamp int64 `json:"timestamp"`
+			Total     int   `json:"total"`
+			Blocked   int   `json:"blocked"`
+		} `json:"history"`
+	}
+
+	err = decodeJsonInto(client, request, &responseJson)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	queriesSeries := make(piholeQueriesSeries)
+	blockedSeries := make(map[int64]int)
+
+	// Populate the series data from history array
+	for _, entry := range responseJson.History {
+		queriesSeries[entry.Timestamp] = entry.Total
+		blockedSeries[entry.Timestamp] = entry.Blocked
+	}
+
+	return queriesSeries, blockedSeries, nil
+}
+
+// Helper functions to process the responses
+func parsePiholeStats(r piholeStatsResponse, topDomains piholeTopDomainsResponse, noGraph bool) *dnsStats {
 
 	stats := &dnsStats{
-		TotalQueries:   responseJson.TotalQueries,
-		BlockedQueries: responseJson.BlockedQueries,
-		BlockedPercent: int(responseJson.BlockedPercentage),
-		DomainsBlocked: responseJson.DomainsBlocked,
+		TotalQueries:   r.Queries.Total,
+		BlockedQueries: r.Queries.Blocked,
+		BlockedPercent: int(r.Queries.PercentBlocked),
+		DomainsBlocked: r.Gravity.DomainsBlocked,
+	}
+
+	if len(topDomains.Domains) > 0 {
+		domains := make([]dnsStatsBlockedDomain, 0, len(topDomains.Domains))
+		for _, d := range topDomains.Domains {
+			domains = append(domains, dnsStatsBlockedDomain{
+				Domain:         d.Domain,
+				PercentBlocked: int(float64(d.Count) / float64(r.Queries.Blocked) * 100),
+			})
+		}
+
+		sort.Slice(domains, func(a, b int) bool {
+			return domains[a].PercentBlocked > domains[b].PercentBlocked
+		})
+		stats.TopBlockedDomains = domains[:min(len(domains), 5)]
+	}
+	if noGraph {
+		return stats
 	}
 
-	if len(responseJson.TopBlockedDomains) > 0 {
-		domains := make([]dnsStatsBlockedDomain, 0, len(responseJson.TopBlockedDomains))
+	// Pihole _should_ return data for the last 24 hours
+	if len(r.QueriesSeries) != 145 || len(r.BlockedSeries) != 145 {
+		return stats
+	}
+
+
+	var lowestTimestamp int64 = 0
+	for timestamp := range r.QueriesSeries {
+		if lowestTimestamp == 0 || timestamp < lowestTimestamp {
+			lowestTimestamp = timestamp
+		}
+	}
+	maxQueriesInSeries := 0
 
-		for domain, count := range responseJson.TopBlockedDomains {
+	for i := 0; i < 8; i++ {
+		queries := 0
+		blocked := 0
+		for j := 0; j < 18; j++ {
+			index := lowestTimestamp + int64(i*10800+j*600)
+			queries += r.QueriesSeries[index]
+			blocked += r.BlockedSeries[index]
+		}
+		if queries > maxQueriesInSeries {
+			maxQueriesInSeries = queries
+		}
+		stats.Series[i] = dnsStatsSeries{
+			Queries: queries,
+			Blocked: blocked,
+		}
+		if queries > 0 {
+			stats.Series[i].PercentBlocked = int(float64(blocked) / float64(queries) * 100)
+		}
+	}
+	for i := 0; i < 8; i++ {
+		stats.Series[i].PercentTotal = int(float64(stats.Series[i].Queries) / float64(maxQueriesInSeries) * 100)
+	}
+	return stats
+}
+func parsePiholeStatsLegacy(r legacyPiholeStatsResponse, noGraph bool) *dnsStats {
+
+	stats := &dnsStats{
+		TotalQueries:   r.TotalQueries,
+		BlockedQueries: r.BlockedQueries,
+		BlockedPercent: int(r.BlockedPercentage),
+		DomainsBlocked: r.DomainsBlocked,
+	}
+	if len(r.TopBlockedDomains) > 0 {
+		domains := make([]dnsStatsBlockedDomain, 0, len(r.TopBlockedDomains))
+
+		for domain, count := range r.TopBlockedDomains {
 			domains = append(domains, dnsStatsBlockedDomain{
 				Domain:         domain,
-				PercentBlocked: int(float64(count) / float64(responseJson.BlockedQueries) * 100),
+				PercentBlocked: int(float64(count) / float64(r.BlockedQueries) * 100),
 			})
 		}
 
@@ -323,59 +542,138 @@ func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGr
 
 		stats.TopBlockedDomains = domains[:min(len(domains), 5)]
 	}
-
 	if noGraph {
-		return stats, nil
+		return stats
 	}
 
 	// Pihole _should_ return data for the last 24 hours in a 10 minute interval, 6*24 = 144
-	if len(responseJson.QueriesSeries) != 144 || len(responseJson.BlockedSeries) != 144 {
-		slog.Warn(
-			"DNS stats for pihole: did not get expected 144 data points",
-			"len(queries)", len(responseJson.QueriesSeries),
-			"len(blocked)", len(responseJson.BlockedSeries),
-		)
-		return stats, nil
+	if len(r.QueriesSeries) != 144 || len(r.BlockedSeries) != 144 {
+		return stats
 	}
 
 	var lowestTimestamp int64 = 0
-
-	for timestamp := range responseJson.QueriesSeries {
+	for timestamp := range r.QueriesSeries {
 		if lowestTimestamp == 0 || timestamp < lowestTimestamp {
 			lowestTimestamp = timestamp
 		}
 	}
-
 	maxQueriesInSeries := 0
 
 	for i := 0; i < 8; i++ {
 		queries := 0
 		blocked := 0
-
 		for j := 0; j < 18; j++ {
 			index := lowestTimestamp + int64(i*10800+j*600)
-
-			queries += responseJson.QueriesSeries[index]
-			blocked += responseJson.BlockedSeries[index]
+			queries += r.QueriesSeries[index]
+			blocked += r.BlockedSeries[index]
 		}
-
 		if queries > maxQueriesInSeries {
 			maxQueriesInSeries = queries
 		}
-
 		stats.Series[i] = dnsStatsSeries{
 			Queries: queries,
 			Blocked: blocked,
 		}
-
 		if queries > 0 {
 			stats.Series[i].PercentBlocked = int(float64(blocked) / float64(queries) * 100)
 		}
 	}
-
 	for i := 0; i < 8; i++ {
 		stats.Series[i].PercentTotal = int(float64(stats.Series[i].Queries) / float64(maxQueriesInSeries) * 100)
 	}
-
-	return stats, nil
+	return stats
 }
+
+func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGraph bool, version, appPassword string) (*dnsStats, error) {
+	instanceURL = strings.TrimRight(instanceURL, "/")
+	var requestURL string
+	var sid string
+	isV6 := version == "" || version == "6"
+
+	if isV6 {
+		if appPassword == "" {
+			return nil, errors.New("missing app password")
+		}
+
+		sid = os.Getenv("SID")
+		if sid == "" {
+			newSid, err := piholeGetSID(instanceURL, appPassword, allowInsecure)
+			if err != nil {
+				return nil, fmt.Errorf("failed to get SID: %w", err)
+			}
+			sid = newSid
+			os.Setenv("SID", sid)
+		} else {
+			err := checkPiholeSID(instanceURL, sid, allowInsecure)
+			if err != nil {
+				newSid, err := piholeGetSID(instanceURL, appPassword, allowInsecure)
+				if err != nil {
+					return nil, fmt.Errorf("failed to get SID after invalid check: %w", err)
+				}
+				sid = newSid
+				os.Setenv("SID", sid)
+			}
+		}
+
+		requestURL = instanceURL + "/api/stats/summary"
+	} else {
+		if token == "" {
+			return nil, errors.New("missing API token")
+		}
+		requestURL = instanceURL + "/admin/api.php?summaryRaw&topItems&overTimeData10mins&auth=" + token
+	}
+
+	request, err := http.NewRequest("GET", requestURL, nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create HTTP request: %w", err)
+	}
+
+	if isV6 {
+		request.Header.Set("x-ftl-sid", sid)
+	}
+
+	var client requestDoer
+	if !allowInsecure {
+		client = defaultHTTPClient
+	} else {
+		client = defaultInsecureHTTPClient
+	}
+
+	var responseJson interface{}
+	if isV6 {
+		responseJson, err = decodeJsonFromRequest[piholeStatsResponse](client, request)
+	} else {
+		responseJson, err = decodeJsonFromRequest[legacyPiholeStatsResponse](client, request)
+	}
+
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode JSON response: %w", err)
+	}
+
+	switch r := responseJson.(type) {
+	case piholeStatsResponse:
+		// Fetch top domains separately for v6+
+		topDomains, err := fetchPiholeTopDomains(instanceURL, sid, allowInsecure)
+		if err != nil {
+			return nil, fmt.Errorf("failed to fetch top domains: %w", err)
+		}
+
+		// Fetch series data separately for v6+
+		queriesSeries, blockedSeries, err := fetchPiholeSeries(instanceURL, sid, allowInsecure)
+		if err != nil {
+			return nil, fmt.Errorf("failed to fetch queries series: %w", err)
+		}
+
+		// Merge series data
+		r.QueriesSeries = queriesSeries
+		r.BlockedSeries = blockedSeries
+		
+		return parsePiholeStats(r, topDomains, noGraph), nil
+
+	case legacyPiholeStatsResponse:
+		return parsePiholeStatsLegacy(r, noGraph), nil
+
+	default:
+		return nil, errors.New("unexpected response type")
+	}
+}

+ 10 - 0
internal/glance/widget-utils.go

@@ -82,6 +82,16 @@ func decodeJsonFromRequest[T any](client requestDoer, request *http.Request) (T,
 	return result, nil
 }
 
+func decodeJsonInto[T any](client requestDoer, request *http.Request, out *T) error {
+	result, err := decodeJsonFromRequest[T](client, request)
+	if err != nil {
+		return err
+	}
+
+	*out = result
+	return nil
+}
+
 func decodeJsonFromRequestTask[T any](client requestDoer) func(*http.Request) (T, error) {
 	return func(request *http.Request) (T, error) {
 		return decodeJsonFromRequest[T](client, request)