Sfoglia il codice sorgente

Pushing latest changes.

Keith Carichner Jr 4 mesi fa
parent
commit
2002ed1c9c
1 ha cambiato i file con 91 aggiunte e 21 eliminazioni
  1. 91 21
      internal/glance/widget-dns-stats.go

+ 91 - 21
internal/glance/widget-dns-stats.go

@@ -5,6 +5,7 @@ import (
 	"context"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"html/template"
 	"io"
 	"net/http"
@@ -348,6 +349,35 @@ func piholeGetSID(instanceURL, appPassword string, allowInsecure bool) (string,
 	return jsonResponse.Session.SID, nil
 }
 
+// checkPiholeSID checks if the SID is valid by checking HTTP response status code from /api/auth.
+func checkPiholeSID(instanceURL string, appPassword, sid string, allowInsecure bool) error {
+	requestURL := strings.TrimRight(instanceURL, "/") + "/api/auth?sid=" + sid
+
+	request, err := http.NewRequest("GET", requestURL, nil)
+	if err != nil {
+		return err
+	}
+
+	var client requestDoer
+	if !allowInsecure {
+		client = defaultHTTPClient
+	} else {
+		client = defaultInsecureHTTPClient
+	}
+
+	response, err := client.Do(request)
+	if err != nil {
+		return err
+	}
+	defer response.Body.Close()
+
+	if response.StatusCode != http.StatusOK {
+		return errors.New("SID is invalid, received status: " + response.Status)
+	}
+
+	return nil
+}
+
 // fetchPiholeTopDomains fetches the top blocked domains for Pi-hole v6+.
 func fetchPiholeTopDomains(instanceURL string, sid string, allowInsecure bool) (piholeTopDomainsResponse, error) {
 	requestURL := strings.TrimRight(instanceURL, "/") + "/api/stats/top_domains?blocked=true&sid=" + sid
@@ -367,6 +397,35 @@ func fetchPiholeTopDomains(instanceURL string, sid string, allowInsecure bool) (
 	return decodeJsonFromRequest[piholeTopDomainsResponse](client, request)
 }
 
+// fetchPiholeSeries fetches the series data for Pi-hole v6+ (QueriesSeries and BlockedSeries).
+func fetchPiholeSeries(instanceURL string, sid string, allowInsecure bool) (piholeQueriesSeries, map[int64]int, error) {
+	requestURL := strings.TrimRight(instanceURL, "/") + "/api/stats/over_time_data?sid=" + sid
+
+	request, err := http.NewRequest("GET", requestURL, nil)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	var client requestDoer
+	if !allowInsecure {
+		client = defaultHTTPClient
+	} else {
+		client = defaultInsecureHTTPClient
+	}
+
+	var responseJson struct {
+		QueriesSeries piholeQueriesSeries `json:"queries_over_time"`
+		BlockedSeries map[int64]int       `json:"blocked_over_time"`
+	}
+
+	err = decodeJsonFromRequest[&responseJson](client, request)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	return responseJson.QueriesSeries, responseJson.BlockedSeries, nil
+}
+
 // Helper functions to process the responses
 func parsePiholeStats(r *piholeStatsResponse, topDomains piholeTopDomainsResponse) *dnsStats {
 
@@ -461,64 +520,75 @@ func parsePiholeStatsLegacy(r *legacyPiholeStatsResponse, noGraph bool) *dnsStat
 }
 
 func fetchPiholeStats(instanceURL string, allowInsecure bool, token string, noGraph bool, version, appPassword string) (*dnsStats, error) {
+	instanceURL = strings.TrimRight(instanceURL, "/")
 	var requestURL string
 	var sid string
+	isV6 := version == "" || version == "6"
 
-	// Handle Pi-hole v6 authentication
-	if version == "" || version == "6" {
+	if isV6 {
 		if appPassword == "" {
 			return nil, errors.New("missing app password")
 		}
-		// If SID env var is not set, get a new SID
-		if os.Getenv("SID") == "" {
-			sid, err := piholeGetSID(instanceURL, appPassword, allowInsecure)
+
+		sid = os.Getenv("SID")
+		// Only get a new SID if it's not set or is invalid
+		if sid == "" {
+			newSid, err := piholeGetSID(instanceURL, appPassword, allowInsecure)
 			if err != nil {
-				return nil, err
+				return nil, fmt.Errorf("failed to get SID: %w", err) // Use %w for wrapping
 			}
+			sid = newSid
 			os.Setenv("SID", sid)
-
+		} else {
+			// Check existing SID validity.  Only get a new one if the check fails.
+			err := checkPiholeSID(instanceURL, appPassword, sid, allowInsecure)
+			if err != nil {
+				newSid, err := piholeGetSID(instanceURL, appPassword, allowInsecure)
+				if err != nil {
+					return nil, fmt.Errorf("failed to get SID after invalid SID check: %w", err)
+				}
+				sid = newSid
+				os.Setenv("SID", sid)
+			}
 		}
-		sid := os.Getenv("SID")
-		requestURL = strings.TrimRight(instanceURL, "/") + "/api/stats/summary?sid=" + sid
+
+		requestURL = instanceURL + "/api/stats/summary?sid=" + sid
 
 	} else {
 		if token == "" {
 			return nil, errors.New("missing API token")
 		}
-		requestURL = strings.TrimRight(instanceURL, "/") + "/admin/api.php?summaryRaw&topItems&overTimeData10mins&auth=" + token
-
+		requestURL = instanceURL + "/admin/api.php?summaryRaw&topItems&overTimeData10mins&auth=" + token
 	}
 
 	request, err := http.NewRequest("GET", requestURL, nil)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("failed to create HTTP request: %w", err)
 	}
 
 	var client requestDoer
-	if !allowInsecure {
-		client = defaultHTTPClient
-	} else {
+	client = defaultHTTPClient
+	if allowInsecure {
 		client = defaultInsecureHTTPClient
 	}
 
 	var responseJson interface{}
-
-	if version == "" || version == "6" {
+	if isV6 {
 		responseJson, err = decodeJsonFromRequest[piholeStatsResponse](client, request)
-
 	} else {
 		responseJson, err = decodeJsonFromRequest[legacyPiholeStatsResponse](client, request)
 	}
+
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("failed to decode JSON response: %w", err)
 	}
 
 	switch r := responseJson.(type) {
 	case *piholeStatsResponse:
-		// Fetch top domains separately for v6
+		// Fetch top domains separately for v6+.
 		topDomains, err := fetchPiholeTopDomains(instanceURL, sid, allowInsecure)
 		if err != nil {
-			return nil, err
+			return nil, fmt.Errorf("failed to fetch top domains: %w", err)
 		}
 		return parsePiholeStats(r, topDomains), nil
 	case *legacyPiholeStatsResponse: