瀏覽代碼

Pushing latest fixes to handle the two different JSON responses.

Keith Carichner Jr 4 月之前
父節點
當前提交
e643a44b59
共有 1 個文件被更改,包括 144 次插入65 次删除
  1. 144 65
      internal/glance/widget-dns-stats.go

+ 144 - 65
internal/glance/widget-dns-stats.go

@@ -7,7 +7,6 @@ import (
 	"errors"
 	"html/template"
 	"io"
-	"log/slog"
 	"net/http"
 	"os"
 	"sort"
@@ -15,9 +14,6 @@ import (
 	"time"
 )
 
-// Global HTTP client for reuse
-var httpClient = &http.Client{}
-
 var dnsStatsWidgetTemplate = mustParseTemplate("dns-stats.html", "widget-base.html")
 
 type dnsStatsWidget struct {
@@ -235,7 +231,8 @@ func fetchAdguardStats(instanceURL string, allowInsecure bool, username, passwor
 	return stats, nil
 }
 
-type piholeStatsResponse struct {
+// Legacy Pi-hole stats response (before v6)
+type legacyPiholeStatsResponse struct {
 	TotalQueries      int                     `json:"dns_queries_today"`
 	QueriesSeries     piholeQueriesSeries     `json:"domains_over_time"`
 	BlockedQueries    int                     `json:"ads_blocked_today"`
@@ -245,6 +242,24 @@ type piholeStatsResponse struct {
 	DomainsBlocked    int                     `json:"domains_being_blocked"`
 }
 
+// Pi-hole v6+ response format
+type piholeStatsResponse struct {
+	Queries struct {
+		Total          int     `json:"total"`
+		Blocked        int     `json:"blocked"`
+		PercentBlocked float64 `json:"percent_blocked"`
+	} `json:"queries"`
+	Gravity struct {
+		DomainsBlocked int `json:"domains_being_blocked"`
+	} `json:"gravity"`
+	//Note we do not need the full structure. We extract the values needed
+	//Adding dummy fields to allow easier json parsing.
+	QueriesSeries piholeQueriesSeries `json:"domains_over_time"` // Will always be empty
+	BlockedSeries map[int64]int       `json:"ads_over_time"`     // Will always be empty.
+}
+
+type piholeTopDomainsResponse map[string]int
+
 // If the user has query logging disabled it's possible for domains_over_time to be returned as an
 // empty array rather than a map which will prevent unmashalling the rest of the data so we use
 // custom unmarshal behavior to fallback to an empty map.
@@ -284,7 +299,14 @@ func (p *piholeTopBlockedDomains) UnmarshalJSON(data []byte) error {
 }
 
 // piholeGetSID retrieves a new SID from Pi-hole using the app password.
-func piholeGetSID(instanceURL, appPassword string) (string, error) {
+func piholeGetSID(instanceURL, appPassword string, allowInsecure bool) (string, error) {
+	var client requestDoer
+	if !allowInsecure {
+		client = defaultHTTPClient
+	} else {
+		client = defaultInsecureHTTPClient
+	}
+
 	requestURL := strings.TrimRight(instanceURL, "/") + "/api/auth"
 	requestBody := []byte(`{"password":"` + appPassword + `"}`)
 
@@ -294,7 +316,7 @@ func piholeGetSID(instanceURL, appPassword string) (string, error) {
 	}
 	request.Header.Set("Content-Type", "application/json")
 
-	response, err := httpClient.Do(request)
+	response, err := client.Do(request)
 	if err != nil {
 		return "", errors.New("failed to send authentication request: " + err.Error())
 	}
@@ -326,31 +348,9 @@ func piholeGetSID(instanceURL, appPassword string) (string, error) {
 	return jsonResponse.Session.SID, nil
 }
 
-func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGraph bool, version, appPassword string) (*dnsStats, error) {
-	var requestURL string
-
-	// Handle Pi-hole v6 authentication
-	if version == "" || version == "6" {
-		if appPassword == "" {
-			return nil, errors.New("missing app password")
-		}
-		// If SID env var is not set, get a new SID
-		if os.Getenv("SID") == "" {
-			sid, err := piholeGetSID(instanceURL, appPassword)
-			os.Setenv("SID", sid)
-			if err != nil {
-				return nil, err
-			}
-		}
-		sid := os.Getenv("SID")
-
-		requestURL = strings.TrimRight(instanceURL, "/") + "/api/stats/summary?sid=" + sid
-	} else {
-		if token == "" {
-			return nil, errors.New("missing API token")
-		}
-		requestURL = strings.TrimRight(instanceURL, "/") + "/admin/api.php?summaryRaw&topItems&overTimeData10mins&auth=" + token
-	}
+// fetchPiholeTopDomains fetches the top blocked domains for Pi-hole v6+.
+func fetchPiholeTopDomains(instanceURL string, sid string, allowInsecure bool) (piholeTopDomainsResponse, error) {
+	requestURL := strings.TrimRight(instanceURL, "/") + "/api/stats/top_domains?blocked=true&sid=" + sid
 
 	request, err := http.NewRequest("GET", requestURL, nil)
 	if err != nil {
@@ -364,25 +364,51 @@ func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGr
 		client = defaultInsecureHTTPClient
 	}
 
-	responseJson, err := decodeJsonFromRequest[piholeStatsResponse](client, request)
-	if err != nil {
-		return nil, err
-	}
+	return decodeJsonFromRequest[piholeTopDomainsResponse](client, request)
+}
+
+// Helper functions to process the responses
+func parsePiholeStats(r *piholeStatsResponse, topDomains piholeTopDomainsResponse) *dnsStats {
 
 	stats := &dnsStats{
-		TotalQueries:   responseJson.TotalQueries,
-		BlockedQueries: responseJson.BlockedQueries,
-		BlockedPercent: int(responseJson.BlockedPercentage),
-		DomainsBlocked: responseJson.DomainsBlocked,
+		TotalQueries:   r.Queries.Total,
+		BlockedQueries: r.Queries.Blocked,
+		BlockedPercent: int(r.Queries.PercentBlocked),
+		DomainsBlocked: r.Gravity.DomainsBlocked,
+	}
+
+	if len(topDomains) > 0 {
+		domains := make([]dnsStatsBlockedDomain, 0, len(topDomains))
+		for domain, count := range topDomains {
+			domains = append(domains, dnsStatsBlockedDomain{
+				Domain:         domain,
+				PercentBlocked: int(float64(count) / float64(r.Queries.Blocked) * 100), // Calculate percentage here
+			})
+		}
+
+		sort.Slice(domains, func(a, b int) bool {
+			return domains[a].PercentBlocked > domains[b].PercentBlocked
+		})
+		stats.TopBlockedDomains = domains[:min(len(domains), 5)]
 	}
 
-	if len(responseJson.TopBlockedDomains) > 0 {
-		domains := make([]dnsStatsBlockedDomain, 0, len(responseJson.TopBlockedDomains))
+	return stats
+}
+func parsePiholeStatsLegacy(r *legacyPiholeStatsResponse, noGraph bool) *dnsStats {
+
+	stats := &dnsStats{
+		TotalQueries:   r.TotalQueries,
+		BlockedQueries: r.BlockedQueries,
+		BlockedPercent: int(r.BlockedPercentage),
+		DomainsBlocked: r.DomainsBlocked,
+	}
+	if len(r.TopBlockedDomains) > 0 {
+		domains := make([]dnsStatsBlockedDomain, 0, len(r.TopBlockedDomains))
 
-		for domain, count := range responseJson.TopBlockedDomains {
+		for domain, count := range r.TopBlockedDomains {
 			domains = append(domains, dnsStatsBlockedDomain{
 				Domain:         domain,
-				PercentBlocked: int(float64(count) / float64(responseJson.BlockedQueries) * 100),
+				PercentBlocked: int(float64(count) / float64(r.BlockedQueries) * 100),
 			})
 		}
 
@@ -392,59 +418,112 @@ func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGr
 
 		stats.TopBlockedDomains = domains[:min(len(domains), 5)]
 	}
-
 	if noGraph {
-		return stats, nil
+		return stats
 	}
 
 	// Pihole _should_ return data for the last 24 hours in a 10 minute interval, 6*24 = 144
-	if len(responseJson.QueriesSeries) != 144 || len(responseJson.BlockedSeries) != 144 {
-		slog.Warn(
-			"DNS stats for pihole: did not get expected 144 data points",
-			"len(queries)", len(responseJson.QueriesSeries),
-			"len(blocked)", len(responseJson.BlockedSeries),
-		)
-		return stats, nil
+	if len(r.QueriesSeries) != 144 || len(r.BlockedSeries) != 144 {
+		return stats
 	}
 
 	var lowestTimestamp int64 = 0
-
-	for timestamp := range responseJson.QueriesSeries {
+	for timestamp := range r.QueriesSeries {
 		if lowestTimestamp == 0 || timestamp < lowestTimestamp {
 			lowestTimestamp = timestamp
 		}
 	}
-
 	maxQueriesInSeries := 0
 
 	for i := 0; i < 8; i++ {
 		queries := 0
 		blocked := 0
-
 		for j := 0; j < 18; j++ {
 			index := lowestTimestamp + int64(i*10800+j*600)
-
-			queries += responseJson.QueriesSeries[index]
-			blocked += responseJson.BlockedSeries[index]
+			queries += r.QueriesSeries[index]
+			blocked += r.BlockedSeries[index]
 		}
-
 		if queries > maxQueriesInSeries {
 			maxQueriesInSeries = queries
 		}
-
 		stats.Series[i] = dnsStatsSeries{
 			Queries: queries,
 			Blocked: blocked,
 		}
-
 		if queries > 0 {
 			stats.Series[i].PercentBlocked = int(float64(blocked) / float64(queries) * 100)
 		}
 	}
-
 	for i := 0; i < 8; i++ {
 		stats.Series[i].PercentTotal = int(float64(stats.Series[i].Queries) / float64(maxQueriesInSeries) * 100)
 	}
+	return stats
+}
 
-	return stats, nil
+func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGraph bool, version, appPassword string) (*dnsStats, error) {
+	var requestURL string
+	var sid string
+
+	// Handle Pi-hole v6 authentication
+	if version == "" || version == "6" {
+		if appPassword == "" {
+			return nil, errors.New("missing app password")
+		}
+		// If SID env var is not set, get a new SID
+		if os.Getenv("SID") == "" {
+			sid, err := piholeGetSID(instanceURL, appPassword, allowInsecure)
+			if err != nil {
+				return nil, err
+			}
+			os.Setenv("SID", sid)
+
+		}
+		sid := os.Getenv("SID")
+		requestURL = strings.TrimRight(instanceURL, "/") + "/api/stats/summary?sid=" + sid
+
+	} else {
+		if token == "" {
+			return nil, errors.New("missing API token")
+		}
+		requestURL = strings.TrimRight(instanceURL, "/") + "/admin/api.php?summaryRaw&topItems&overTimeData10mins&auth=" + token
+
+	}
+
+	request, err := http.NewRequest("GET", requestURL, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	var client requestDoer
+	if !allowInsecure {
+		client = defaultHTTPClient
+	} else {
+		client = defaultInsecureHTTPClient
+	}
+
+	var responseJson interface{}
+
+	if version == "" || version == "6" {
+		responseJson, err = decodeJsonFromRequest[piholeStatsResponse](client, request)
+
+	} else {
+		responseJson, err = decodeJsonFromRequest[legacyPiholeStatsResponse](client, request)
+	}
+	if err != nil {
+		return nil, err
+	}
+
+	switch r := responseJson.(type) {
+	case *piholeStatsResponse:
+		// Fetch top domains separately for v6
+		topDomains, err := fetchPiholeTopDomains(instanceURL, sid, allowInsecure)
+		if err != nil {
+			return nil, err
+		}
+		return parsePiholeStats(r, topDomains), nil
+	case *legacyPiholeStatsResponse:
+		return parsePiholeStatsLegacy(r, noGraph), nil
+	default:
+		return nil, errors.New("unexpected response type")
+	}
 }