Преглед изворног кода

fix(config): actually load threshold config (#696)

* fix(config): actually load threshold config

Signed-off-by: Xe Iaso <me@xeiaso.net>

* chore: spelling

Signed-off-by: Xe Iaso <me@xeiaso.net>

* test(lib): fix test failures

Signed-off-by: Xe Iaso <me@xeiaso.net>

---------

Signed-off-by: Xe Iaso <me@xeiaso.net>
Xe Iaso пре 3 дана
родитељ
комит
7aa732c700

+ 2 - 0
.github/actions/spelling/excludes.txt

@@ -83,7 +83,9 @@
 ^\Q.github/FUNDING.yml\E$
 ^\Q.github/workflows/spelling.yml\E$
 ^data/crawlers/
+^docs/blog/tags\.yml$
 ^docs/manifest/.*$
 ^docs/static/\.nojekyll$
+^lib/policy/config/testdata/bad/unparseable\.json$
 ignore$
 robots.txt

+ 1 - 2
.github/actions/spelling/expect.txt

@@ -44,7 +44,6 @@ chall
 challengemozilla
 checkpath
 checkresult
-chen
 chibi
 cidranger
 ckie
@@ -61,7 +60,6 @@ DDOS
 Debian
 debrpm
 decaymap
-decompiling
 Diffbot
 discordapp
 discordbot
@@ -300,6 +298,7 @@ xess
 xff
 XForwarded
 XNG
+XOB
 XReal
 yae
 YAMLTo

+ 17 - 20
lib/anubis_test.go

@@ -24,12 +24,16 @@ func init() {
 	internal.InitSlog("debug")
 }
 
-func loadPolicies(t *testing.T, fname string) *policy.ParsedConfig {
+func loadPolicies(t *testing.T, fname string, difficulty int) *policy.ParsedConfig {
 	t.Helper()
 
 	ctx := thothmock.WithMockThoth(t)
 
-	anubisPolicy, err := LoadPoliciesOrDefault(ctx, fname, anubis.DefaultDifficulty)
+	if fname == "" {
+		fname = "./testdata/test_config.yaml"
+	}
+
+	anubisPolicy, err := LoadPoliciesOrDefault(ctx, fname, difficulty)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -176,8 +180,7 @@ func TestLoadPolicies(t *testing.T) {
 
 // Regression test for CVE-2025-24369
 func TestCVE2025_24369(t *testing.T) {
-	pol := loadPolicies(t, "")
-	pol.DefaultDifficulty = 4
+	pol := loadPolicies(t, "", anubis.DefaultDifficulty)
 
 	srv := spawnAnubis(t, Options{
 		Next:   http.NewServeMux(),
@@ -200,8 +203,7 @@ func TestCVE2025_24369(t *testing.T) {
 }
 
 func TestCookieCustomExpiration(t *testing.T) {
-	pol := loadPolicies(t, "")
-	pol.DefaultDifficulty = 0
+	pol := loadPolicies(t, "", 0)
 	ckieExpiration := 10 * time.Minute
 
 	srv := spawnAnubis(t, Options{
@@ -250,8 +252,7 @@ func TestCookieCustomExpiration(t *testing.T) {
 }
 
 func TestCookieSettings(t *testing.T) {
-	pol := loadPolicies(t, "")
-	pol.DefaultDifficulty = 0
+	pol := loadPolicies(t, "", 0)
 
 	srv := spawnAnubis(t, Options{
 		Next:   http.NewServeMux(),
@@ -316,10 +317,7 @@ func TestCheckDefaultDifficultyMatchesPolicy(t *testing.T) {
 
 	for i := 1; i < 10; i++ {
 		t.Run(fmt.Sprint(i), func(t *testing.T) {
-			anubisPolicy, err := LoadPoliciesOrDefault(t.Context(), "", i)
-			if err != nil {
-				t.Fatal(err)
-			}
+			anubisPolicy := loadPolicies(t, "", i)
 
 			s, err := New(Options{
 				Next:           h,
@@ -337,11 +335,13 @@ func TestCheckDefaultDifficultyMatchesPolicy(t *testing.T) {
 
 			req.Header.Add("X-Real-Ip", "127.0.0.1")
 
-			_, bot, err := s.check(req)
+			cr, bot, err := s.check(req)
 			if err != nil {
 				t.Fatal(err)
 			}
 
+			t.Log(cr.Name)
+
 			if bot.Challenge.Difficulty != i {
 				t.Errorf("Challenge.Difficulty is wrong, wanted %d, got: %d", i, bot.Challenge.Difficulty)
 			}
@@ -389,8 +389,7 @@ func TestBasePrefix(t *testing.T) {
 			// Reset the global BasePrefix before each test
 			anubis.BasePrefix = ""
 
-			pol := loadPolicies(t, "")
-			pol.DefaultDifficulty = 4
+			pol := loadPolicies(t, "", 4)
 
 			srv := spawnAnubis(t, Options{
 				Next:       h,
@@ -518,8 +517,7 @@ func TestCustomStatusCodes(t *testing.T) {
 		"DENY":      403,
 	}
 
-	pol := loadPolicies(t, "./testdata/aggressive_403.yaml")
-	pol.DefaultDifficulty = 4
+	pol := loadPolicies(t, "./testdata/aggressive_403.yaml", 4)
 
 	srv := spawnAnubis(t, Options{
 		Next:   h,
@@ -553,7 +551,7 @@ func TestCustomStatusCodes(t *testing.T) {
 func TestCloudflareWorkersRule(t *testing.T) {
 	for _, variant := range []string{"cel", "header"} {
 		t.Run(variant, func(t *testing.T) {
-			pol := loadPolicies(t, "./testdata/cloudflare-workers-"+variant+".yaml")
+			pol := loadPolicies(t, "./testdata/cloudflare-workers-"+variant+".yaml", 0)
 
 			h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 				fmt.Fprintln(w, "OK")
@@ -609,8 +607,7 @@ func TestCloudflareWorkersRule(t *testing.T) {
 }
 
 func TestRuleChange(t *testing.T) {
-	pol := loadPolicies(t, "testdata/rule_change.yaml")
-	pol.DefaultDifficulty = 0
+	pol := loadPolicies(t, "testdata/rule_change.yaml", 0)
 	ckieExpiration := 10 * time.Minute
 
 	srv := spawnAnubis(t, Options{

+ 1 - 1
lib/config_test.go

@@ -26,7 +26,7 @@ func TestBadConfigs(t *testing.T) {
 	for _, st := range finfos {
 		st := st
 		t.Run(st.Name(), func(t *testing.T) {
-			if _, err := LoadPoliciesOrDefault(t.Context(), filepath.Join("policy", "config", "testdata", "good", st.Name()), anubis.DefaultDifficulty); err == nil {
+			if _, err := LoadPoliciesOrDefault(t.Context(), filepath.Join("policy", "config", "testdata", "bad", st.Name()), anubis.DefaultDifficulty); err == nil {
 				t.Fatal(err)
 			} else {
 				t.Log(err)

+ 55 - 0
lib/policy/config/asn_test.go

@@ -0,0 +1,55 @@
+package config
+
+import (
+	"errors"
+	"fmt"
+	"testing"
+)
+
+func TestASNsValid(t *testing.T) {
+	for _, tt := range []struct {
+		name  string
+		input *ASNs
+		err   error
+	}{
+		{
+			name: "basic valid",
+			input: &ASNs{
+				Match: []uint32{13335}, // Cloudflare
+			},
+		},
+		{
+			name: "private ASN",
+			input: &ASNs{
+				Match: []uint32{64513, 4206942069}, // 16 and 32 bit private ASN
+			},
+			err: ErrPrivateASN,
+		},
+	} {
+		t.Run(tt.name, func(t *testing.T) {
+			if err := tt.input.Valid(); !errors.Is(err, tt.err) {
+				t.Logf("want: %v", tt.err)
+				t.Logf("got:  %v", err)
+				t.Error("got wrong validation error")
+			}
+		})
+	}
+}
+
+func TestIsPrivateASN(t *testing.T) {
+	for _, tt := range []struct {
+		input  uint32
+		output bool
+	}{
+		{13335, false},     // Cloudflare
+		{64513, true},      // 16 bit private ASN
+		{4206942069, true}, // 32 bit private ASN
+	} {
+		t.Run(fmt.Sprint(tt.input, "->", tt.output), func(t *testing.T) {
+			result := isPrivateASN(tt.input)
+			if result != tt.output {
+				t.Errorf("wanted isPrivateASN(%d) == %v, got: %v", tt.input, tt.output, result)
+			}
+		})
+	}
+}

+ 5 - 6
lib/policy/config/config.go

@@ -326,7 +326,7 @@ type fileConfig struct {
 	Bots        []BotOrImport `json:"bots"`
 	DNSBL       bool          `json:"dnsbl"`
 	StatusCodes StatusCodes   `json:"status_codes"`
-	Thresholds  []Threshold   `json:"threshold"`
+	Thresholds  []Threshold   `json:"thresholds"`
 }
 
 func (c *fileConfig) Valid() error {
@@ -346,10 +346,6 @@ func (c *fileConfig) Valid() error {
 		errs = append(errs, err)
 	}
 
-	if len(c.Thresholds) == 0 {
-		errs = append(errs, ErrNoThresholdRulesDefined)
-	}
-
 	for i, t := range c.Thresholds {
 		if err := t.Valid(); err != nil {
 			errs = append(errs, fmt.Errorf("threshold %d: %w", i, err))
@@ -369,7 +365,6 @@ func Load(fin io.Reader, fname string) (*Config, error) {
 			Challenge: http.StatusOK,
 			Deny:      http.StatusOK,
 		},
-		Thresholds: DefaultThresholds,
 	}
 
 	if err := yaml.NewYAMLToJSONDecoder(fin).Decode(&c); err != nil {
@@ -407,6 +402,10 @@ func Load(fin io.Reader, fname string) (*Config, error) {
 		}
 	}
 
+	if len(c.Thresholds) == 0 {
+		c.Thresholds = DefaultThresholds
+	}
+
 	for _, t := range c.Thresholds {
 		if err := t.Valid(); err != nil {
 			validationErrs = append(validationErrs, err)

+ 1 - 1
lib/policy/config/geoip.go

@@ -8,7 +8,7 @@ import (
 )
 
 var (
-	countryCodeRegexp = regexp.MustCompile(`^\w{2}$`)
+	countryCodeRegexp = regexp.MustCompile(`^[a-zA-Z]{2}$`)
 
 	ErrNotCountryCode = errors.New("config.Bot: invalid country code")
 )

+ 36 - 0
lib/policy/config/geoip_test.go

@@ -0,0 +1,36 @@
+package config
+
+import (
+	"errors"
+	"testing"
+)
+
+func TestGeoIPValid(t *testing.T) {
+	for _, tt := range []struct {
+		name  string
+		input *GeoIP
+		err   error
+	}{
+		{
+			name: "basic valid",
+			input: &GeoIP{
+				Countries: []string{"CA"},
+			},
+		},
+		{
+			name: "invalid country",
+			input: &GeoIP{
+				Countries: []string{"XOB"},
+			},
+			err: ErrNotCountryCode,
+		},
+	} {
+		t.Run(tt.name, func(t *testing.T) {
+			if err := tt.input.Valid(); !errors.Is(err, tt.err) {
+				t.Logf("want: %v", tt.err)
+				t.Logf("got:  %v", err)
+				t.Error("got wrong validation error")
+			}
+		})
+	}
+}

+ 11 - 0
lib/policy/config/testdata/bad/threshold-challenge-without-challenge.yaml

@@ -0,0 +1,11 @@
+bots:
+  - name: simple-weight-adjust
+    action: WEIGH
+    user_agent_regex: Mozilla
+    weight:
+      adjust: 5
+
+thresholds:
+  - name: extreme-suspicion
+    expression: "true"
+    action: WEIGH

+ 15 - 0
lib/policy/config/testdata/bad/thresholds.yaml

@@ -0,0 +1,15 @@
+bots:
+  - name: simple-weight-adjust
+    action: WEIGH
+    user_agent_regex: Mozilla
+    weight:
+      adjust: 5
+
+thresholds:
+  - name: extreme-suspicion
+    expression: "true"
+    action: WEIGH
+    challenge:
+      algorithm: fast
+      difficulty: 4
+      report_as: 4

+ 19 - 0
lib/policy/config/threshold_test.go

@@ -3,6 +3,8 @@ package config
 import (
 	"errors"
 	"fmt"
+	"os"
+	"path/filepath"
 	"testing"
 )
 
@@ -90,3 +92,20 @@ func TestDefaultThresholdsValid(t *testing.T) {
 		})
 	}
 }
+
+func TestLoadActuallyLoadsThresholds(t *testing.T) {
+	fin, err := os.Open(filepath.Join(".", "testdata", "good", "thresholds.yaml"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer fin.Close()
+
+	c, err := Load(fin, fin.Name())
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if len(c.Thresholds) != 4 {
+		t.Errorf("wanted 4 thresholds, got %d thresholds", len(c.Thresholds))
+	}
+}

+ 38 - 0
lib/testdata/test_config.yaml

@@ -0,0 +1,38 @@
+bots:
+  - import: (data)/bots/_deny-pathological.yaml
+  - import: (data)/bots/aggressive-brazilian-scrapers.yaml
+  - import: (data)/meta/ai-block-aggressive.yaml
+  - import: (data)/crawlers/_allow-good.yaml
+  - import: (data)/clients/x-firefox-ai.yaml
+  - import: (data)/common/keep-internet-working.yaml
+  - name: countries-with-aggressive-scrapers
+    action: WEIGH
+    geoip:
+      countries:
+        - BR
+        - CN
+    weight:
+      adjust: 10
+  - name: aggressive-asns-without-functional-abuse-contact
+    action: WEIGH
+    asns:
+      match:
+        - 13335 # Cloudflare
+        - 136907 # Huawei Cloud
+        - 45102 # Alibaba Cloud
+    weight:
+      adjust: 10
+  - name: generic-browser
+    user_agent_regex: >-
+      Mozilla|Opera
+    action: WEIGH
+    weight:
+      adjust: 10
+
+dnsbl: false
+
+status_codes:
+  CHALLENGE: 200
+  DENY: 200
+
+thresholds: []