Browse Source

add missing scenarios in first login when authenticating with TLS (#2454)

* refact jwt:Authenticator
* include scenarios in first login request for machines with tlsAuth
* log.Printf -> log.Infof
* errors.Wrap -> fmt.Errorf
* don't override validation error
* fix test
mmetc 1 year ago
parent
commit
0ecb6eefee
2 changed files with 142 additions and 108 deletions
  1. 1 1
      pkg/apiserver/jwt_test.go
  2. 141 107
      pkg/apiserver/middlewares/v1/jwt.go

+ 1 - 1
pkg/apiserver/jwt_test.go

@@ -55,7 +55,7 @@ func TestLogin(t *testing.T) {
 	router.ServeHTTP(w, req)
 	router.ServeHTTP(w, req)
 
 
 	assert.Equal(t, 401, w.Code)
 	assert.Equal(t, 401, w.Code)
-	assert.Equal(t, "{\"code\":401,\"message\":\"input format error\"}", w.Body.String())
+	assert.Equal(t, "{\"code\":401,\"message\":\"validation failure list:\\npassword in body is required\"}", w.Body.String())
 
 
 	//Validate machine
 	//Validate machine
 	err = ValidateMachine("test", config.API.Server.DbConfig)
 	err = ValidateMachine("test", config.API.Server.DbConfig)

+ 141 - 107
pkg/apiserver/middlewares/v1/jwt.go

@@ -2,6 +2,7 @@ package v1
 
 
 import (
 import (
 	"crypto/rand"
 	"crypto/rand"
+	"errors"
 	"fmt"
 	"fmt"
 	"net/http"
 	"net/http"
 	"os"
 	"os"
@@ -16,7 +17,6 @@ import (
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 	"github.com/crowdsecurity/crowdsec/pkg/types"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 	"github.com/go-openapi/strfmt"
 	"github.com/go-openapi/strfmt"
-	"github.com/pkg/errors"
 	log "github.com/sirupsen/logrus"
 	log "github.com/sirupsen/logrus"
 	"golang.org/x/crypto/bcrypt"
 	"golang.org/x/crypto/bcrypt"
 )
 )
@@ -46,142 +46,176 @@ func IdentityHandler(c *gin.Context) interface{} {
 	}
 	}
 }
 }
 
 
-func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
-	var loginInput models.WatcherAuthRequest
-	var scenarios string
-	var err error
-	var scenariosInput []string
-	var clientMachine *ent.Machine
-	var machineID string
 
 
-	if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
-		if j.TlsAuth == nil {
-			c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
-			c.Abort()
-			return nil, errors.New("TLS auth is not configured")
+
+type authInput struct {
+	machineID string
+	clientMachine *ent.Machine
+	scenariosInput []string
+}
+
+
+
+func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
+	ret := authInput{}
+
+	if j.TlsAuth == nil {
+		c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
+		c.Abort()
+		return nil, errors.New("TLS auth is not configured")
+	}
+
+	validCert, extractedCN, err := j.TlsAuth.ValidateCert(c)
+	if err != nil {
+		log.Error(err)
+		c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
+		c.Abort()
+		return nil, fmt.Errorf("while trying to validate client cert: %w", err)
+	}
+
+	if !validCert {
+		c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
+		c.Abort()
+		return nil, fmt.Errorf("failed cert authentication")
+	}
+
+	ret.machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP())
+	ret.clientMachine, err = j.DbClient.Ent.Machine.Query().
+		Where(machine.MachineId(ret.machineID)).
+		First(j.DbClient.CTX)
+	if ent.IsNotFound(err) {
+		//Machine was not found, let's create it
+		log.Infof("machine %s not found, create it", ret.machineID)
+		//let's use an apikey as the password, doesn't matter in this case (generatePassword is only available in cscli)
+		pwd, err := GenerateAPIKey(dummyAPIKeySize)
+		if err != nil {
+			log.WithFields(log.Fields{
+				"ip": c.ClientIP(),
+				"cn": extractedCN,
+			}).Errorf("error generating password: %s", err)
+			return nil, fmt.Errorf("error generating password")
 		}
 		}
-		validCert, extractedCN, err := j.TlsAuth.ValidateCert(c)
+		password := strfmt.Password(pwd)
+		ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType)
 		if err != nil {
 		if err != nil {
-			log.Error(err)
-			c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
-			c.Abort()
-			return nil, errors.Wrap(err, "while trying to validate client cert")
+			return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err)
 		}
 		}
-		if !validCert {
-			c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
-			c.Abort()
-			return nil, fmt.Errorf("failed cert authentication")
+	} else if err != nil {
+		return nil, fmt.Errorf("while selecting machine entry for %s: %w", ret.machineID, err)
+	} else {
+		if ret.clientMachine.AuthType != types.TlsAuthType {
+			return nil, fmt.Errorf("machine %s attempted to auth with TLS cert but it is configured to use %s", ret.machineID, ret.clientMachine.AuthType)
 		}
 		}
+		ret.machineID = ret.clientMachine.MachineId
+	}
 
 
-		machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP())
-		clientMachine, err = j.DbClient.Ent.Machine.Query().
-			Where(machine.MachineId(machineID)).
-			First(j.DbClient.CTX)
-		if ent.IsNotFound(err) {
-			//Machine was not found, let's create it
-			log.Printf("machine %s not found, create it", machineID)
-			//let's use an apikey as the password, doesn't matter in this case (generatePassword is only available in cscli)
-			pwd, err := GenerateAPIKey(dummyAPIKeySize)
-			if err != nil {
-				log.WithFields(log.Fields{
-					"ip": c.ClientIP(),
-					"cn": extractedCN,
-				}).Errorf("error generating password: %s", err)
-				return nil, fmt.Errorf("error generating password")
-			}
-			password := strfmt.Password(pwd)
-			clientMachine, err = j.DbClient.CreateMachine(&machineID, &password, "", true, true, types.TlsAuthType)
-			if err != nil {
-				return "", errors.Wrapf(err, "while creating machine entry for %s", machineID)
-			}
-		} else if err != nil {
-			return "", errors.Wrapf(err, "while selecting machine entry for %s", machineID)
-		} else {
-			if clientMachine.AuthType != types.TlsAuthType {
-				return "", errors.Errorf("machine %s attempted to auth with TLS cert but it is configured to use %s", machineID, clientMachine.AuthType)
-			}
-			machineID = clientMachine.MachineId
-			loginInput := struct {
-				Scenarios []string `json:"scenarios"`
-			}{
-				Scenarios: []string{},
-			}
-			err := c.ShouldBindJSON(&loginInput)
-			if err != nil {
-				return "", errors.Wrap(err, "missing scenarios list in login request for TLS auth")
-			}
-			scenariosInput = loginInput.Scenarios
-		}
+	loginInput := struct {
+		Scenarios []string `json:"scenarios"`
+	}{
+		Scenarios: []string{},
+	}
+	err = c.ShouldBindJSON(&loginInput)
+	if err != nil {
+		return nil, fmt.Errorf("missing scenarios list in login request for TLS auth: %w", err)
+	}
+	ret.scenariosInput = loginInput.Scenarios
 
 
-	} else {
-		//normal auth
+	return &ret, nil
+}
 
 
-		if err := c.ShouldBindJSON(&loginInput); err != nil {
-			return "", errors.Wrap(err, "missing")
-		}
-		if err := loginInput.Validate(strfmt.Default); err != nil {
-			return "", errors.New("input format error")
-		}
-		machineID = *loginInput.MachineID
-		password := *loginInput.Password
-		scenariosInput = loginInput.Scenarios
 
 
-		clientMachine, err = j.DbClient.Ent.Machine.Query().
-			Where(machine.MachineId(machineID)).
-			First(j.DbClient.CTX)
-		if err != nil {
-			log.Printf("Error machine login for %s : %+v ", machineID, err)
-			return nil, err
-		}
 
 
-		if clientMachine == nil {
-			log.Errorf("Nothing for '%s'", machineID)
-			return nil, jwt.ErrFailedAuthentication
-		}
+func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
+	var loginInput models.WatcherAuthRequest
+	var err error
 
 
-		if clientMachine.AuthType != types.PasswordAuthType {
-			return nil, errors.Errorf("machine %s attempted to auth with password but it is configured to use %s", machineID, clientMachine.AuthType)
-		}
+	ret := authInput{}
 
 
-		if !clientMachine.IsValidated {
-			return nil, fmt.Errorf("machine %s not validated", machineID)
-		}
+	if err = c.ShouldBindJSON(&loginInput); err != nil {
+		return nil, fmt.Errorf("missing: %w", err)
+	}
+	if err = loginInput.Validate(strfmt.Default); err != nil {
+		return nil, err
+	}
+	ret.machineID = *loginInput.MachineID
+	password := *loginInput.Password
+	ret.scenariosInput = loginInput.Scenarios
 
 
-		if err = bcrypt.CompareHashAndPassword([]byte(clientMachine.Password), []byte(password)); err != nil {
-			return nil, jwt.ErrFailedAuthentication
-		}
+	ret.clientMachine, err = j.DbClient.Ent.Machine.Query().
+		Where(machine.MachineId(ret.machineID)).
+		First(j.DbClient.CTX)
+	if err != nil {
+		log.Infof("Error machine login for %s : %+v ", ret.machineID, err)
+		return nil, err
+	}
 
 
-		//end of normal auth
+	if ret.clientMachine == nil {
+		log.Errorf("Nothing for '%s'", ret.machineID)
+		return nil, jwt.ErrFailedAuthentication
+	}
+
+	if ret.clientMachine.AuthType != types.PasswordAuthType {
+		return nil, fmt.Errorf("machine %s attempted to auth with password but it is configured to use %s", ret.machineID, ret.clientMachine.AuthType)
+	}
+
+	if !ret.clientMachine.IsValidated {
+		return nil, fmt.Errorf("machine %s not validated", ret.machineID)
+	}
+
+	if err := bcrypt.CompareHashAndPassword([]byte(ret.clientMachine.Password), []byte(password)); err != nil {
+		return nil, jwt.ErrFailedAuthentication
 	}
 	}
 
 
-	if len(scenariosInput) > 0 {
-		for _, scenario := range scenariosInput {
+	return &ret, nil
+}
+
+
+func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
+	var err error
+	var auth *authInput
+
+	if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
+		auth, err = j.authTLS(c)
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		auth, err = j.authPlain(c)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	var scenarios string
+
+	if len(auth.scenariosInput) > 0 {
+		for _, scenario := range auth.scenariosInput {
 			if scenarios == "" {
 			if scenarios == "" {
 				scenarios = scenario
 				scenarios = scenario
 			} else {
 			} else {
 				scenarios += "," + scenario
 				scenarios += "," + scenario
 			}
 			}
 		}
 		}
-		err = j.DbClient.UpdateMachineScenarios(scenarios, clientMachine.ID)
+		err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID)
 		if err != nil {
 		if err != nil {
-			log.Errorf("Failed to update scenarios list for '%s': %s\n", machineID, err)
+			log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err)
 			return nil, jwt.ErrFailedAuthentication
 			return nil, jwt.ErrFailedAuthentication
 		}
 		}
 	}
 	}
 
 
-	if clientMachine.IpAddress == "" {
-		err = j.DbClient.UpdateMachineIP(c.ClientIP(), clientMachine.ID)
+	if auth.clientMachine.IpAddress == "" {
+		err = j.DbClient.UpdateMachineIP(c.ClientIP(), auth.clientMachine.ID)
 		if err != nil {
 		if err != nil {
-			log.Errorf("Failed to update ip address for '%s': %s\n", machineID, err)
+			log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err)
 			return nil, jwt.ErrFailedAuthentication
 			return nil, jwt.ErrFailedAuthentication
 		}
 		}
 	}
 	}
 
 
-	if clientMachine.IpAddress != c.ClientIP() && clientMachine.IpAddress != "" {
-		log.Warningf("new IP address detected for machine '%s': %s (old: %s)", clientMachine.MachineId, c.ClientIP(), clientMachine.IpAddress)
-		err = j.DbClient.UpdateMachineIP(c.ClientIP(), clientMachine.ID)
+	if auth.clientMachine.IpAddress != c.ClientIP() && auth.clientMachine.IpAddress != "" {
+		log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, c.ClientIP(), auth.clientMachine.IpAddress)
+		err = j.DbClient.UpdateMachineIP(c.ClientIP(), auth.clientMachine.ID)
 		if err != nil {
 		if err != nil {
-			log.Errorf("Failed to update ip address for '%s': %s\n", clientMachine.MachineId, err)
+			log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err)
 			return nil, jwt.ErrFailedAuthentication
 			return nil, jwt.ErrFailedAuthentication
 		}
 		}
 	}
 	}
@@ -192,13 +226,13 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
 		return nil, jwt.ErrFailedAuthentication
 		return nil, jwt.ErrFailedAuthentication
 	}
 	}
 
 
-	if err := j.DbClient.UpdateMachineVersion(useragent[1], clientMachine.ID); err != nil {
-		log.Errorf("unable to update machine '%s' version '%s': %s", clientMachine.MachineId, useragent[1], err)
+	if err := j.DbClient.UpdateMachineVersion(useragent[1], auth.clientMachine.ID); err != nil {
+		log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err)
 		log.Errorf("bad user agent from : %s", c.ClientIP())
 		log.Errorf("bad user agent from : %s", c.ClientIP())
 		return nil, jwt.ErrFailedAuthentication
 		return nil, jwt.ErrFailedAuthentication
 	}
 	}
 	return &models.WatcherAuthRequest{
 	return &models.WatcherAuthRequest{
-		MachineID: &machineID,
+		MachineID: &auth.machineID,
 	}, nil
 	}, nil
 
 
 }
 }