diff --git a/pkg/apiserver/jwt_test.go b/pkg/apiserver/jwt_test.go index ebca91252..e5c3529cc 100644 --- a/pkg/apiserver/jwt_test.go +++ b/pkg/apiserver/jwt_test.go @@ -55,7 +55,7 @@ func TestLogin(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, "{\"code\":401,\"message\":\"input format error\"}", w.Body.String()) + assert.Equal(t, "{\"code\":401,\"message\":\"validation failure list:\\npassword in body is required\"}", w.Body.String()) //Validate machine err = ValidateMachine("test", config.API.Server.DbConfig) diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index bbd33c544..22c171c63 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -2,6 +2,7 @@ package v1 import ( "crypto/rand" + "errors" "fmt" "net/http" "os" @@ -16,7 +17,6 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" ) @@ -46,142 +46,176 @@ func IdentityHandler(c *gin.Context) interface{} { } } -func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { - var loginInput models.WatcherAuthRequest - var scenarios string - var err error - var scenariosInput []string - var clientMachine *ent.Machine - var machineID string - if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { - if j.TlsAuth == nil { - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return nil, errors.New("TLS auth is not configured") - } - validCert, extractedCN, err := j.TlsAuth.ValidateCert(c) - if err != nil { - log.Error(err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return nil, errors.Wrap(err, "while trying to validate client cert") - } - if !validCert { - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return nil, fmt.Errorf("failed cert authentication") - } - machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) - clientMachine, err = j.DbClient.Ent.Machine.Query(). - Where(machine.MachineId(machineID)). - First(j.DbClient.CTX) - if ent.IsNotFound(err) { - //Machine was not found, let's create it - log.Printf("machine %s not found, create it", machineID) - //let's use an apikey as the password, doesn't matter in this case (generatePassword is only available in cscli) - pwd, err := GenerateAPIKey(dummyAPIKeySize) - if err != nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "cn": extractedCN, - }).Errorf("error generating password: %s", err) - return nil, fmt.Errorf("error generating password") - } - password := strfmt.Password(pwd) - clientMachine, err = j.DbClient.CreateMachine(&machineID, &password, "", true, true, types.TlsAuthType) - if err != nil { - return "", errors.Wrapf(err, "while creating machine entry for %s", machineID) - } - } else if err != nil { - return "", errors.Wrapf(err, "while selecting machine entry for %s", machineID) - } else { - if clientMachine.AuthType != types.TlsAuthType { - return "", errors.Errorf("machine %s attempted to auth with TLS cert but it is configured to use %s", machineID, clientMachine.AuthType) - } - machineID = clientMachine.MachineId - loginInput := struct { - Scenarios []string `json:"scenarios"` - }{ - Scenarios: []string{}, - } - err := c.ShouldBindJSON(&loginInput) - if err != nil { - return "", errors.Wrap(err, "missing scenarios list in login request for TLS auth") - } - scenariosInput = loginInput.Scenarios - } +type authInput struct { + machineID string + clientMachine *ent.Machine + scenariosInput []string +} - } else { - //normal auth - if err := c.ShouldBindJSON(&loginInput); err != nil { - return "", errors.Wrap(err, "missing") - } - if err := loginInput.Validate(strfmt.Default); err != nil { - return "", errors.New("input format error") - } - machineID = *loginInput.MachineID - password := *loginInput.Password - scenariosInput = loginInput.Scenarios - clientMachine, err = j.DbClient.Ent.Machine.Query(). - Where(machine.MachineId(machineID)). - First(j.DbClient.CTX) - if err != nil { - log.Printf("Error machine login for %s : %+v ", machineID, err) - return nil, err - } +func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { + ret := authInput{} - if clientMachine == nil { - log.Errorf("Nothing for '%s'", machineID) - return nil, jwt.ErrFailedAuthentication - } - - if clientMachine.AuthType != types.PasswordAuthType { - return nil, errors.Errorf("machine %s attempted to auth with password but it is configured to use %s", machineID, clientMachine.AuthType) - } - - if !clientMachine.IsValidated { - return nil, fmt.Errorf("machine %s not validated", machineID) - } - - if err = bcrypt.CompareHashAndPassword([]byte(clientMachine.Password), []byte(password)); err != nil { - return nil, jwt.ErrFailedAuthentication - } - - //end of normal auth + if j.TlsAuth == nil { + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() + return nil, errors.New("TLS auth is not configured") } - if len(scenariosInput) > 0 { - for _, scenario := range scenariosInput { + validCert, extractedCN, err := j.TlsAuth.ValidateCert(c) + if err != nil { + log.Error(err) + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() + return nil, fmt.Errorf("while trying to validate client cert: %w", err) + } + + if !validCert { + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() + return nil, fmt.Errorf("failed cert authentication") + } + + ret.machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) + ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). + Where(machine.MachineId(ret.machineID)). + First(j.DbClient.CTX) + if ent.IsNotFound(err) { + //Machine was not found, let's create it + log.Infof("machine %s not found, create it", ret.machineID) + //let's use an apikey as the password, doesn't matter in this case (generatePassword is only available in cscli) + pwd, err := GenerateAPIKey(dummyAPIKeySize) + if err != nil { + log.WithFields(log.Fields{ + "ip": c.ClientIP(), + "cn": extractedCN, + }).Errorf("error generating password: %s", err) + return nil, fmt.Errorf("error generating password") + } + password := strfmt.Password(pwd) + ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType) + if err != nil { + return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err) + } + } else if err != nil { + return nil, fmt.Errorf("while selecting machine entry for %s: %w", ret.machineID, err) + } else { + if ret.clientMachine.AuthType != types.TlsAuthType { + return nil, fmt.Errorf("machine %s attempted to auth with TLS cert but it is configured to use %s", ret.machineID, ret.clientMachine.AuthType) + } + ret.machineID = ret.clientMachine.MachineId + } + + loginInput := struct { + Scenarios []string `json:"scenarios"` + }{ + Scenarios: []string{}, + } + err = c.ShouldBindJSON(&loginInput) + if err != nil { + return nil, fmt.Errorf("missing scenarios list in login request for TLS auth: %w", err) + } + ret.scenariosInput = loginInput.Scenarios + + return &ret, nil +} + + + +func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { + var loginInput models.WatcherAuthRequest + var err error + + ret := authInput{} + + if err = c.ShouldBindJSON(&loginInput); err != nil { + return nil, fmt.Errorf("missing: %w", err) + } + if err = loginInput.Validate(strfmt.Default); err != nil { + return nil, err + } + ret.machineID = *loginInput.MachineID + password := *loginInput.Password + ret.scenariosInput = loginInput.Scenarios + + ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). + Where(machine.MachineId(ret.machineID)). + First(j.DbClient.CTX) + if err != nil { + log.Infof("Error machine login for %s : %+v ", ret.machineID, err) + return nil, err + } + + if ret.clientMachine == nil { + log.Errorf("Nothing for '%s'", ret.machineID) + return nil, jwt.ErrFailedAuthentication + } + + if ret.clientMachine.AuthType != types.PasswordAuthType { + return nil, fmt.Errorf("machine %s attempted to auth with password but it is configured to use %s", ret.machineID, ret.clientMachine.AuthType) + } + + if !ret.clientMachine.IsValidated { + return nil, fmt.Errorf("machine %s not validated", ret.machineID) + } + + if err := bcrypt.CompareHashAndPassword([]byte(ret.clientMachine.Password), []byte(password)); err != nil { + return nil, jwt.ErrFailedAuthentication + } + + return &ret, nil +} + + +func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { + var err error + var auth *authInput + + if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { + auth, err = j.authTLS(c) + if err != nil { + return nil, err + } + } else { + auth, err = j.authPlain(c) + if err != nil { + return nil, err + } + } + + var scenarios string + + if len(auth.scenariosInput) > 0 { + for _, scenario := range auth.scenariosInput { if scenarios == "" { scenarios = scenario } else { scenarios += "," + scenario } } - err = j.DbClient.UpdateMachineScenarios(scenarios, clientMachine.ID) + err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID) if err != nil { - log.Errorf("Failed to update scenarios list for '%s': %s\n", machineID, err) + log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err) return nil, jwt.ErrFailedAuthentication } } - if clientMachine.IpAddress == "" { - err = j.DbClient.UpdateMachineIP(c.ClientIP(), clientMachine.ID) + if auth.clientMachine.IpAddress == "" { + err = j.DbClient.UpdateMachineIP(c.ClientIP(), auth.clientMachine.ID) if err != nil { - log.Errorf("Failed to update ip address for '%s': %s\n", machineID, err) + log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err) return nil, jwt.ErrFailedAuthentication } } - if clientMachine.IpAddress != c.ClientIP() && clientMachine.IpAddress != "" { - log.Warningf("new IP address detected for machine '%s': %s (old: %s)", clientMachine.MachineId, c.ClientIP(), clientMachine.IpAddress) - err = j.DbClient.UpdateMachineIP(c.ClientIP(), clientMachine.ID) + if auth.clientMachine.IpAddress != c.ClientIP() && auth.clientMachine.IpAddress != "" { + log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, c.ClientIP(), auth.clientMachine.IpAddress) + err = j.DbClient.UpdateMachineIP(c.ClientIP(), auth.clientMachine.ID) if err != nil { - log.Errorf("Failed to update ip address for '%s': %s\n", clientMachine.MachineId, err) + log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err) return nil, jwt.ErrFailedAuthentication } } @@ -192,13 +226,13 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { return nil, jwt.ErrFailedAuthentication } - if err := j.DbClient.UpdateMachineVersion(useragent[1], clientMachine.ID); err != nil { - log.Errorf("unable to update machine '%s' version '%s': %s", clientMachine.MachineId, useragent[1], err) + if err := j.DbClient.UpdateMachineVersion(useragent[1], auth.clientMachine.ID); err != nil { + log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err) log.Errorf("bad user agent from : %s", c.ClientIP()) return nil, jwt.ErrFailedAuthentication } return &models.WatcherAuthRequest{ - MachineID: &machineID, + MachineID: &auth.machineID, }, nil }