Просмотр исходного кода

Add trusted IPs which have admin API access (#1352)

* Add trusted IPs which have admin API access
Shivam Sandbhor 3 лет назад
Родитель
Сommit
023ac9e138

+ 4 - 0
cmd/crowdsec-cli/config.go

@@ -369,6 +369,10 @@ func NewConfigCmd() *cobra.Command {
 								fmt.Printf("  - Key File  : %s\n", csConfig.API.Server.TLS.KeyFilePath)
 							}
 						}
+						fmt.Printf("  - Trusted IPs: \n")
+						for _, ip := range csConfig.API.Server.TrustedIPs {
+							fmt.Printf("      - %s\n", ip)
+						}
 						if csConfig.API.Server.OnlineClient != nil && csConfig.API.Server.OnlineClient.Credentials != nil {
 							fmt.Printf("Central API:\n")
 							fmt.Printf("  - URL                     : %s\n", csConfig.API.Server.OnlineClient.Credentials.URL)

+ 3 - 0
config/config.yaml

@@ -45,6 +45,9 @@ api:
     console_path: /etc/crowdsec/console.yaml
     online_client: # Central API credentials (to push signals and receive bad IPs)
       credentials_path: /etc/crowdsec/online_api_credentials.yaml
+    trusted_ips: # IP ranges, or IPs which can have admin API access
+      - 127.0.0.1
+      - ::1
 #    tls:
 #      cert_file: /etc/crowdsec/ssl/cert.pem
 #      key_file: /etc/crowdsec/ssl/key.pem

+ 3 - 0
docker/config.yaml

@@ -41,6 +41,9 @@ api:
     log_level: info
     listen_uri: 0.0.0.0:8080
     profiles_path: /etc/crowdsec/profiles.yaml
+    trusted_ips: # IP ranges, or IPs which can have admin API access
+      - 127.0.0.1
+      - ::1
     online_client: # Central API credentials (to push signals and receive bad IPs)
       #credentials_path: /etc/crowdsec/online_api_credentials.yaml
 #    tls:

+ 72 - 0
pkg/apiserver/alerts_test.go

@@ -556,3 +556,75 @@ func TestDeleteAlert(t *testing.T) {
 	assert.Equal(t, 200, w.Code)
 	assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String())
 }
+
+func TestDeleteAlertTrustedIPS(t *testing.T) {
+	cfg := LoadTestConfig()
+	// IPv6 mocking doesn't seem to work.
+	// cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24", "::"}
+	cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24"}
+	cfg.API.Server.ListenURI = "::8080"
+	server, err := NewServer(cfg.API.Server)
+	if err != nil {
+		log.Fatal(err.Error())
+	}
+	err = server.InitController()
+	if err != nil {
+		log.Fatal(err.Error())
+	}
+	router, err := server.Router()
+	if err != nil {
+		log.Fatal(err.Error())
+	}
+	loginResp, err := LoginToTestAPI(router)
+	if err != nil {
+		log.Fatal(err.Error())
+	}
+
+	insertAlert := func() {
+		alertContentBytes, err := ioutil.ReadFile("./tests/alert_sample.json")
+		if err != nil {
+			log.Fatal(err)
+		}
+		alertContent := string(alertContentBytes)
+		w := httptest.NewRecorder()
+		req, _ := http.NewRequest("POST", "/v1/alerts", strings.NewReader(alertContent))
+		AddAuthHeaders(req, loginResp)
+		router.ServeHTTP(w, req)
+	}
+
+	assertAlertDeleteFailedFromIP := func(ip string) {
+		w := httptest.NewRecorder()
+		req, _ := http.NewRequest("DELETE", "/v1/alerts", strings.NewReader(""))
+
+		AddAuthHeaders(req, loginResp)
+		req.RemoteAddr = ip + ":1234"
+		router.ServeHTTP(w, req)
+		assert.Equal(t, 403, w.Code)
+		assert.Contains(t, w.Body.String(), fmt.Sprintf(`{"message":"access forbidden from this IP (%s)"}`, ip))
+	}
+
+	assertAlertDeletedFromIP := func(ip string) {
+		w := httptest.NewRecorder()
+		req, _ := http.NewRequest("DELETE", "/v1/alerts", strings.NewReader(""))
+		AddAuthHeaders(req, loginResp)
+		req.RemoteAddr = ip + ":1234"
+		router.ServeHTTP(w, req)
+		assert.Equal(t, 200, w.Code)
+		assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String())
+	}
+
+	insertAlert()
+	assertAlertDeleteFailedFromIP("4.3.2.1")
+	assertAlertDeletedFromIP("1.2.3.4")
+
+	insertAlert()
+	assertAlertDeletedFromIP("1.2.4.0")
+	insertAlert()
+	assertAlertDeletedFromIP("1.2.4.1")
+	insertAlert()
+	assertAlertDeletedFromIP("1.2.4.255")
+
+	insertAlert()
+	assertAlertDeletedFromIP("127.0.0.1")
+
+}

+ 5 - 0
pkg/apiserver/apiserver.go

@@ -216,6 +216,11 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
 		apiClient = nil
 		controller.CAPIChan = nil
 	}
+	if trustedIPs, err := config.GetTrustedIPs(); err == nil {
+		controller.TrustedIPs = trustedIPs
+	} else {
+		return &APIServer{}, err
+	}
 
 	return &APIServer{
 		URL:            config.ListenURI,

+ 4 - 1
pkg/apiserver/controllers/controller.go

@@ -2,6 +2,7 @@ package controllers
 
 import (
 	"context"
+	"net"
 	"net/http"
 
 	"github.com/alexliesenfeld/health"
@@ -23,6 +24,7 @@ type Controller struct {
 	PluginChannel chan csplugin.ProfileAlert
 	Log           *log.Logger
 	ConsoleConfig *csconfig.ConsoleConfig
+	TrustedIPs    []net.IPNet
 }
 
 func (c *Controller) Init() error {
@@ -53,7 +55,8 @@ func serveHealth() http.HandlerFunc {
 }
 
 func (c *Controller) NewV1() error {
-	handlerV1, err := v1.New(c.DBClient, c.Ectx, c.Profiles, c.CAPIChan, c.PluginChannel, *c.ConsoleConfig)
+
+	handlerV1, err := v1.New(c.DBClient, c.Ectx, c.Profiles, c.CAPIChan, c.PluginChannel, *c.ConsoleConfig, c.TrustedIPs)
 	if err != nil {
 		return err
 	}

+ 14 - 3
pkg/apiserver/controllers/v1/alerts.go

@@ -3,6 +3,7 @@ package v1
 import (
 	"encoding/json"
 	"fmt"
+	"net"
 	"net/http"
 	"strconv"
 	"time"
@@ -236,9 +237,9 @@ func (c *Controller) FindAlertByID(gctx *gin.Context) {
 
 // DeleteAlerts : delete alerts from database based on the specified filter
 func (c *Controller) DeleteAlerts(gctx *gin.Context) {
-
-	if gctx.ClientIP() != "127.0.0.1" && gctx.ClientIP() != "::1" {
-		gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", gctx.ClientIP())})
+	incomingIP := gctx.ClientIP()
+	if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) {
+		gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)})
 		return
 	}
 	var err error
@@ -252,3 +253,13 @@ func (c *Controller) DeleteAlerts(gctx *gin.Context) {
 	}
 	gctx.JSON(http.StatusOK, deleteAlertsResp)
 }
+
+func networksContainIP(networks []net.IPNet, ip string) bool {
+	parsedIP := net.ParseIP(ip)
+	for _, network := range networks {
+		if network.Contains(parsedIP) {
+			return true
+		}
+	}
+	return false
+}

+ 4 - 2
pkg/apiserver/controllers/v1/controller.go

@@ -2,6 +2,7 @@ package v1
 
 import (
 	"context"
+	"net"
 
 	middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
 	"github.com/crowdsecurity/crowdsec/pkg/csconfig"
@@ -19,9 +20,10 @@ type Controller struct {
 	CAPIChan      chan []*models.Alert
 	PluginChannel chan csplugin.ProfileAlert
 	ConsoleConfig csconfig.ConsoleConfig
+	TrustedIPs    []net.IPNet
 }
 
-func New(dbClient *database.Client, ctx context.Context, profiles []*csconfig.ProfileCfg, capiChan chan []*models.Alert, pluginChannel chan csplugin.ProfileAlert, consoleConfig csconfig.ConsoleConfig) (*Controller, error) {
+func New(dbClient *database.Client, ctx context.Context, profiles []*csconfig.ProfileCfg, capiChan chan []*models.Alert, pluginChannel chan csplugin.ProfileAlert, consoleConfig csconfig.ConsoleConfig, trustedIPs []net.IPNet) (*Controller, error) {
 	var err error
 	v1 := &Controller{
 		Ectx:          ctx,
@@ -31,11 +33,11 @@ func New(dbClient *database.Client, ctx context.Context, profiles []*csconfig.Pr
 		CAPIChan:      capiChan,
 		PluginChannel: pluginChannel,
 		ConsoleConfig: consoleConfig,
+		TrustedIPs:    trustedIPs,
 	}
 	v1.Middlewares, err = middlewares.NewMiddlewares(dbClient)
 	if err != nil {
 		return v1, err
 	}
-
 	return v1, nil
 }

+ 26 - 0
pkg/csconfig/api.go

@@ -3,6 +3,7 @@ package csconfig
 import (
 	"fmt"
 	"io/ioutil"
+	"net"
 	"strings"
 
 	"github.com/crowdsecurity/crowdsec/pkg/apiclient"
@@ -76,6 +77,30 @@ func (l *LocalApiClientCfg) Load() error {
 	return nil
 }
 
+func (lapiCfg *LocalApiServerCfg) GetTrustedIPs() ([]net.IPNet, error) {
+	trustedIPs := make([]net.IPNet, 0)
+	for _, ip := range lapiCfg.TrustedIPs {
+		cidr := toValidCIDR(ip)
+		_, ipNet, err := net.ParseCIDR(cidr)
+		if err != nil {
+			return nil, err
+		}
+		trustedIPs = append(trustedIPs, *ipNet)
+	}
+	return trustedIPs, nil
+}
+
+func toValidCIDR(ip string) string {
+	if strings.Contains(ip, "/") {
+		return ip
+	}
+
+	if strings.Contains(ip, ":") {
+		return ip + "/128"
+	}
+	return ip + "/32"
+}
+
 /*local api service configuration*/
 type LocalApiServerCfg struct {
 	ListenURI              string              `yaml:"listen_uri,omitempty"` //127.0.0.1:8080
@@ -95,6 +120,7 @@ type LocalApiServerCfg struct {
 	LogMaxSize             int                 `yaml:"-"`
 	LogMaxAge              int                 `yaml:"-"`
 	LogMaxFiles            int                 `yaml:"-"`
+	TrustedIPs             []string            `yaml:"trusted_ips,omitempty"`
 }
 
 type TLSCfg struct {