Selaa lähdekoodia

allow IPs in defender safe list to exceed max per-host connections

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 1 vuosi sitten
vanhempi
commit
799fdd7098

+ 3 - 3
go.mod

@@ -79,7 +79,7 @@ require (
 )
 
 require (
-	cloud.google.com/go v0.112.0 // indirect
+	cloud.google.com/go v0.112.1 // indirect
 	cloud.google.com/go/compute v1.24.0 // indirect
 	cloud.google.com/go/compute/metadata v0.2.3 // indirect
 	cloud.google.com/go/iam v1.1.6 // indirect
@@ -117,7 +117,7 @@ require (
 	github.com/golang/protobuf v1.5.3 // indirect
 	github.com/google/s2a-go v0.1.7 // indirect
 	github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
-	github.com/googleapis/gax-go/v2 v2.12.1 // indirect
+	github.com/googleapis/gax-go/v2 v2.12.2 // indirect
 	github.com/hashicorp/errwrap v1.1.0 // indirect
 	github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
 	github.com/hashicorp/go-multierror v1.1.1 // indirect
@@ -134,7 +134,7 @@ require (
 	github.com/lestrrat-go/httprc v1.0.4 // indirect
 	github.com/lestrrat-go/iter v1.0.2 // indirect
 	github.com/lestrrat-go/option v1.0.1 // indirect
-	github.com/lufia/plan9stats v0.0.0-20231016141302-07b5767bb0ed // indirect
+	github.com/lufia/plan9stats v0.0.0-20240226150601-1dcf7310316a // indirect
 	github.com/magiconair/properties v1.8.7 // indirect
 	github.com/mattn/go-colorable v0.1.13 // indirect
 	github.com/mattn/go-isatty v0.0.20 // indirect

+ 8 - 8
go.sum

@@ -1,6 +1,6 @@
 cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
-cloud.google.com/go v0.112.0 h1:tpFCD7hpHFlQ8yPwT3x+QeXqc2T6+n6T+hmABHfDUSM=
-cloud.google.com/go v0.112.0/go.mod h1:3jEEVwZ/MHU4djK5t5RHuKOA/GbLddgTdVubX1qnPD4=
+cloud.google.com/go v0.112.1 h1:uJSeirPke5UNZHIb4SxfZklVSiWWVqW4oXlETwZziwM=
+cloud.google.com/go v0.112.1/go.mod h1:+Vbu+Y1UU+I1rjmzeMOb/8RfkKJK2Gyxi1X6jJCZLo4=
 cloud.google.com/go/compute v1.24.0 h1:phWcR2eWzRJaL/kOiJwfFsPs4BaKq1j6vnpZrc1YlVg=
 cloud.google.com/go/compute v1.24.0/go.mod h1:kw1/T+h/+tK2LJK0wiPPx1intgdAM3j/g3hFDlscY40=
 cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
@@ -212,8 +212,8 @@ github.com/google/wire v0.5.0 h1:I7ELFeVBr3yfPIcc8+MWvrjk+3VjbcSzoXm3JVa+jD8=
 github.com/google/wire v0.5.0/go.mod h1:ngWDr9Qvq3yZA10YrxfyGELY/AFWGVpy9c1LTRi1EoU=
 github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs=
 github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
-github.com/googleapis/gax-go/v2 v2.12.1 h1:9F8GV9r9ztXyAi00gsMQHNoF51xPZm8uj1dpYt2ZETM=
-github.com/googleapis/gax-go/v2 v2.12.1/go.mod h1:61M8vcyyXR2kqKFxKrfA22jaA8JGF7Dc8App1U3H6jc=
+github.com/googleapis/gax-go/v2 v2.12.2 h1:mhN09QQW1jEWeMF74zGR81R30z4VJzjZsfkUhuHF+DA=
+github.com/googleapis/gax-go/v2 v2.12.2/go.mod h1:61M8vcyyXR2kqKFxKrfA22jaA8JGF7Dc8App1U3H6jc=
 github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
 github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
 github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
@@ -275,8 +275,8 @@ github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
 github.com/lithammer/shortuuid/v3 v3.0.7 h1:trX0KTHy4Pbwo/6ia8fscyHoGA+mf1jWbPJVuvyJQQ8=
 github.com/lithammer/shortuuid/v3 v3.0.7/go.mod h1:vMk8ke37EmiewwolSO1NLW8vP4ZaKlRuDIi8tWWmAts=
 github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
-github.com/lufia/plan9stats v0.0.0-20231016141302-07b5767bb0ed h1:036IscGBfJsFIgJQzlui7nK1Ncm0tp2ktmPj8xO4N/0=
-github.com/lufia/plan9stats v0.0.0-20231016141302-07b5767bb0ed/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
+github.com/lufia/plan9stats v0.0.0-20240226150601-1dcf7310316a h1:3Bm7EwfUQUvhNeKIkUct/gl9eod1TcXuj8stxvi/GoI=
+github.com/lufia/plan9stats v0.0.0-20240226150601-1dcf7310316a/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
 github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
 github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
 github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
@@ -419,8 +419,8 @@ go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
 go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
 go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI=
 go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco=
-go.opentelemetry.io/otel/sdk v1.21.0 h1:FTt8qirL1EysG6sTQRZ5TokkU8d0ugCj8htOgThZXQ8=
-go.opentelemetry.io/otel/sdk v1.21.0/go.mod h1:Nna6Yv7PWTdgJHVRD9hIYywQBRx7pbox6nwBnZIxl/E=
+go.opentelemetry.io/otel/sdk v1.22.0 h1:6coWHw9xw7EfClIC/+O31R8IY3/+EiRFHevmHafB2Gw=
+go.opentelemetry.io/otel/sdk v1.22.0/go.mod h1:iu7luyVGYovrRpe2fmj3CVKouQNdTOkxtLzPvPz1DOc=
 go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI=
 go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
 go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8=

+ 11 - 7
internal/common/common.go

@@ -383,13 +383,14 @@ func GetDefenderScore(ip string) (int, error) {
 	return Config.defender.GetScore(ip)
 }
 
-// AddDefenderEvent adds the specified defender event for the given IP
-func AddDefenderEvent(ip, protocol string, event HostEvent) {
+// AddDefenderEvent adds the specified defender event for the given IP.
+// Returns true if the IP is in the defender's safe list.
+func AddDefenderEvent(ip, protocol string, event HostEvent) bool {
 	if Config.defender == nil {
-		return
+		return false
 	}
 
-	Config.defender.AddEvent(ip, protocol, event)
+	return Config.defender.AddEvent(ip, protocol, event)
 }
 
 func startPeriodicChecks(duration time.Duration, isShared int) {
@@ -1191,9 +1192,12 @@ func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr, protocol string)
 
 	if Config.MaxPerHostConnections > 0 {
 		if total := conns.clients.getTotalFrom(ipAddr); total > Config.MaxPerHostConnections {
-			logger.Info(logSender, "", "active connections from %s %d/%d", ipAddr, total, Config.MaxPerHostConnections)
-			AddDefenderEvent(ipAddr, protocol, HostEventLimitExceeded)
-			return ErrConnectionDenied
+			if !AddDefenderEvent(ipAddr, protocol, HostEventLimitExceeded) {
+				logger.Warn(logSender, "", "connection denied, active connections from IP %q: %d/%d",
+					ipAddr, total, Config.MaxPerHostConnections)
+				return ErrConnectionDenied
+			}
+			logger.Info(logSender, "", "active connections from safe IP %q: %d", ipAddr, total)
 		}
 	}
 

+ 35 - 2
internal/common/common_test.go

@@ -668,9 +668,26 @@ func TestConnectionRoles(t *testing.T) {
 }
 
 func TestMaxConnectionPerHost(t *testing.T) {
-	oldValue := Config.MaxPerHostConnections
+	defender, err := newInMemoryDefender(&DefenderConfig{
+		Enabled:            true,
+		Driver:             DefenderDriverMemory,
+		BanTime:            30,
+		BanTimeIncrement:   50,
+		Threshold:          15,
+		ScoreInvalid:       2,
+		ScoreValid:         1,
+		ScoreLimitExceeded: 3,
+		ObservationTime:    30,
+		EntriesSoftLimit:   100,
+		EntriesHardLimit:   150,
+	})
+	require.NoError(t, err)
+
+	oldMaxPerHostConn := Config.MaxPerHostConnections
+	oldDefender := Config.defender
 
 	Config.MaxPerHostConnections = 2
+	Config.defender = defender
 
 	ipAddr := "192.168.9.9"
 	Connections.AddClientConnection(ipAddr)
@@ -682,14 +699,30 @@ func TestMaxConnectionPerHost(t *testing.T) {
 	Connections.AddClientConnection(ipAddr)
 	assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolFTP))
 	assert.Equal(t, int32(3), Connections.GetClientConnections())
+	// Add the IP to the defender safe list
+	entry := dataprovider.IPListEntry{
+		IPOrNet: ipAddr,
+		Type:    dataprovider.IPListTypeDefender,
+		Mode:    dataprovider.ListModeAllow,
+	}
+	err = dataprovider.AddIPListEntry(&entry, "", "", "")
+	assert.NoError(t, err)
 
+	Connections.AddClientConnection(ipAddr)
+	assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))
+
+	err = dataprovider.DeleteIPListEntry(entry.IPOrNet, dataprovider.IPListTypeDefender, "", "", "")
+	assert.NoError(t, err)
+
+	Connections.RemoveClientConnection(ipAddr)
 	Connections.RemoveClientConnection(ipAddr)
 	Connections.RemoveClientConnection(ipAddr)
 	Connections.RemoveClientConnection(ipAddr)
 
 	assert.Equal(t, int32(0), Connections.GetClientConnections())
 
-	Config.MaxPerHostConnections = oldValue
+	Config.MaxPerHostConnections = oldMaxPerHostConn
+	Config.defender = oldDefender
 }
 
 func TestIdleConnections(t *testing.T) {

+ 1 - 1
internal/common/defender.go

@@ -47,7 +47,7 @@ var (
 type Defender interface {
 	GetHosts() ([]dataprovider.DefenderEntry, error)
 	GetHost(ip string) (dataprovider.DefenderEntry, error)
-	AddEvent(ip, protocol string, event HostEvent)
+	AddEvent(ip, protocol string, event HostEvent) bool
 	IsBanned(ip, protocol string) bool
 	IsSafe(ip, protocol string) bool
 	GetBanTime(ip string) (*time.Time, error)

+ 6 - 4
internal/common/defenderdb.go

@@ -88,17 +88,18 @@ func (d *dbDefender) DeleteHost(ip string) bool {
 }
 
 // AddEvent adds an event for the given IP.
-// This method must be called for clients not yet banned
-func (d *dbDefender) AddEvent(ip, protocol string, event HostEvent) {
+// This method must be called for clients not yet banned.
+// Returns true if the IP is in the defender's safe list.
+func (d *dbDefender) AddEvent(ip, protocol string, event HostEvent) bool {
 	if d.IsSafe(ip, protocol) {
-		return
+		return true
 	}
 
 	score := d.baseDefender.getScore(event)
 
 	host, err := dataprovider.AddDefenderEvent(ip, score, d.getStartObservationTime())
 	if err != nil {
-		return
+		return false
 	}
 	d.baseDefender.logEvent(ip, protocol, event, host.Score)
 	if host.Score > d.config.Threshold {
@@ -118,6 +119,7 @@ func (d *dbDefender) AddEvent(ip, protocol string, event HostEvent) {
 	if err == nil {
 		d.cleanup()
 	}
+	return false
 }
 
 // GetBanTime returns the ban time for the given IP or nil if the IP is not banned

+ 6 - 4
internal/common/defendermem.go

@@ -170,10 +170,11 @@ func (d *memoryDefender) DeleteHost(ip string) bool {
 }
 
 // AddEvent adds an event for the given IP.
-// This method must be called for clients not yet banned
-func (d *memoryDefender) AddEvent(ip, protocol string, event HostEvent) {
+// This method must be called for clients not yet banned.
+// Returns true if the IP is in the defender's safe list.
+func (d *memoryDefender) AddEvent(ip, protocol string, event HostEvent) bool {
 	if d.IsSafe(ip, protocol) {
-		return
+		return true
 	}
 
 	d.Lock()
@@ -182,7 +183,7 @@ func (d *memoryDefender) AddEvent(ip, protocol string, event HostEvent) {
 	// ignore events for already banned hosts
 	if v, ok := d.banned[ip]; ok {
 		if v.After(time.Now()) {
-			return
+			return false
 		}
 		delete(d.banned, ip)
 	}
@@ -231,6 +232,7 @@ func (d *memoryDefender) AddEvent(ip, protocol string, event HostEvent) {
 		}
 		d.cleanupHosts()
 	}
+	return false
 }
 
 func (d *memoryDefender) countBanned() int {

+ 8 - 0
internal/httpd/api_utils.go

@@ -146,6 +146,14 @@ func getURLParam(r *http.Request, key string) string {
 	return unescaped
 }
 
+func getURLPath(r *http.Request) string {
+	rctx := chi.RouteContext(r.Context())
+	if rctx != nil && rctx.RoutePath != "" {
+		return rctx.RoutePath
+	}
+	return r.URL.Path
+}
+
 func getCommaSeparatedQueryParam(r *http.Request, key string) []string {
 	var result []string
 

+ 2 - 14
internal/httpd/server.go

@@ -1220,25 +1220,13 @@ func (s *httpdServer) redirectToWebPath(w http.ResponseWriter, r *http.Request,
 // The StripSlashes causes infinite redirects at the root path if used with http.FileServer.
 // We also don't strip paths with more than one trailing slash, see #1434
 func (s *httpdServer) mustStripSlash(r *http.Request) bool {
-	var urlPath string
-	rctx := chi.RouteContext(r.Context())
-	if rctx != nil && rctx.RoutePath != "" {
-		urlPath = rctx.RoutePath
-	} else {
-		urlPath = r.URL.Path
-	}
+	urlPath := getURLPath(r)
 	return !strings.HasSuffix(urlPath, "//") && !strings.HasPrefix(urlPath, webOpenAPIPath) &&
 		!strings.HasPrefix(urlPath, webStaticFilesPath) && !strings.HasPrefix(urlPath, acmeChallengeURI)
 }
 
 func (s *httpdServer) mustCheckPath(r *http.Request) bool {
-	var urlPath string
-	rctx := chi.RouteContext(r.Context())
-	if rctx != nil && rctx.RoutePath != "" {
-		urlPath = rctx.RoutePath
-	} else {
-		urlPath = r.URL.Path
-	}
+	urlPath := getURLPath(r)
 	return !strings.HasPrefix(urlPath, webStaticFilesPath) && !strings.HasPrefix(urlPath, acmeChallengeURI)
 }