jwt.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. package v1
  2. import (
  3. "crypto/rand"
  4. "fmt"
  5. "os"
  6. "strings"
  7. "time"
  8. jwt "github.com/appleboy/gin-jwt/v2"
  9. "github.com/crowdsecurity/crowdsec/pkg/database"
  10. "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine"
  11. "github.com/crowdsecurity/crowdsec/pkg/models"
  12. "github.com/gin-gonic/gin"
  13. "github.com/go-openapi/strfmt"
  14. "github.com/pkg/errors"
  15. log "github.com/sirupsen/logrus"
  16. "golang.org/x/crypto/bcrypt"
  17. )
  // identityKey is the name of the JWT claim that carries the machine ID.
// PayloadFunc writes it, IdentityHandler reads it back, and NewJWT passes it
// to the middleware as its IdentityKey.
var (
	identityKey = "id"
)
  19. type JWT struct {
  20. Middleware *jwt.GinJWTMiddleware
  21. DbClient *database.Client
  22. }
  23. func PayloadFunc(data interface{}) jwt.MapClaims {
  24. if value, ok := data.(*models.WatcherAuthRequest); ok {
  25. return jwt.MapClaims{
  26. identityKey: &value.MachineID,
  27. }
  28. }
  29. return jwt.MapClaims{}
  30. }
  31. func IdentityHandler(c *gin.Context) interface{} {
  32. claims := jwt.ExtractClaims(c)
  33. machineId := claims[identityKey].(string)
  34. return &models.WatcherAuthRequest{
  35. MachineID: &machineId,
  36. }
  37. }
  // Authenticator is the gin-jwt login callback registered by NewJWT.
//
// It binds the JSON body into a models.WatcherAuthRequest, validates it, and
// loads the machine whose MachineId matches. The bcrypt password hash is
// verified before the machine's validation state is reported. This way an
// unauthenticated caller cannot learn which machines exist or are still
// pending validation.
//
// Once the credentials are accepted, it checks that the User-Agent has the
// form "name/version" and then records:
//   - the submitted scenarios (comma-joined),
//   - the client IP (with a warning when it changed from a stored one),
//   - the version taken from the User-Agent.
//
// On success it returns a *models.WatcherAuthRequest that carries only the
// machine ID; PayloadFunc turns it into token claims. Lookup, credential,
// User-Agent and database-update failures all return
// jwt.ErrFailedAuthentication. A correct password for a machine that is not
// yet validated returns a descriptive error.
func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
	var loginInput models.WatcherAuthRequest

	if err := c.ShouldBindJSON(&loginInput); err != nil {
		return nil, errors.Wrap(err, "missing")
	}
	if err := loginInput.Validate(strfmt.Default); err != nil {
		return nil, errors.New("input format error")
	}

	machineID := *loginInput.MachineID
	password := *loginInput.Password
	clientIP := c.ClientIP()

	// The local is not named "machine" so the ent predicate package of that
	// name is not shadowed.
	machineEnt, err := j.DbClient.Ent.Machine.Query().
		Where(machine.MachineId(machineID)).
		First(j.DbClient.CTX)
	if err != nil {
		// Log the real cause but answer with the generic authentication
		// error, so neither database details nor "not found" reach the client.
		log.Printf("Error machine login for %s : %+v ", machineID, err)
		return nil, jwt.ErrFailedAuthentication
	}
	if machineEnt == nil {
		log.Errorf("Nothing for '%s'", machineID)
		return nil, jwt.ErrFailedAuthentication
	}

	if err = bcrypt.CompareHashAndPassword([]byte(machineEnt.Password), []byte(password)); err != nil {
		return nil, jwt.ErrFailedAuthentication
	}
	if !machineEnt.IsValidated {
		return nil, fmt.Errorf("machine %s not validated", machineID)
	}

	// Check the User-Agent before any write, so a rejected login leaves the
	// stored machine record untouched.
	userAgent := c.Request.UserAgent()
	uaParts := strings.Split(userAgent, "/")
	if len(uaParts) != 2 {
		log.Warningf("bad user agent '%s' from '%s'", userAgent, clientIP)
		return nil, jwt.ErrFailedAuthentication
	}
	version := uaParts[1]

	if len(loginInput.Scenarios) > 0 {
		scenarios := strings.Join(loginInput.Scenarios, ",")
		if err = j.DbClient.UpdateMachineScenarios(scenarios, machineEnt.ID); err != nil {
			log.Errorf("Failed to update scenarios list for '%s': %s", machineID, err)
			return nil, jwt.ErrFailedAuthentication
		}
	}

	if machineEnt.IpAddress != clientIP {
		// An empty stored address means this is the first recorded login, so
		// no warning is emitted.
		if machineEnt.IpAddress != "" {
			log.Warningf("new IP address detected for machine '%s': %s (old: %s)", machineID, clientIP, machineEnt.IpAddress)
		}
		if err = j.DbClient.UpdateMachineIP(clientIP, machineEnt.ID); err != nil {
			log.Errorf("Failed to update ip address for '%s': %s", machineID, err)
			return nil, jwt.ErrFailedAuthentication
		}
	}

	if err = j.DbClient.UpdateMachineVersion(version, machineEnt.ID); err != nil {
		log.Errorf("unable to update machine '%s' version '%s': %s", machineID, version, err)
		log.Errorf("bad user agent from : %s", clientIP)
		return nil, jwt.ErrFailedAuthentication
	}

	return &models.WatcherAuthRequest{
		MachineID: &machineID,
	}, nil
}
  111. func Authorizator(data interface{}, c *gin.Context) bool {
  112. return true
  113. }
  114. func Unauthorized(c *gin.Context, code int, message string) {
  115. c.JSON(code, gin.H{
  116. "code": code,
  117. "message": message,
  118. })
  119. }
  120. func randomSecret() ([]byte, error) {
  121. size := 64
  122. secret := make([]byte, size)
  123. n, err := rand.Read(secret)
  124. if err != nil {
  125. return nil, errors.New("unable to generate a new random seed for JWT generation")
  126. }
  127. if n != size {
  128. return nil, errors.New("not enough entropy at random seed generation for JWT generation")
  129. }
  130. return secret, nil
  131. }
  132. func NewJWT(dbClient *database.Client) (*JWT, error) {
  133. // Get secret from environment variable "SECRET"
  134. var (
  135. secret []byte
  136. err error
  137. )
  138. // Please be aware that brute force HS256 is possible.
  139. // PLEASE choose a STRONG secret
  140. secretString := os.Getenv("CS_LAPI_SECRET")
  141. secret = []byte(secretString)
  142. switch l := len(secret); {
  143. case l == 0:
  144. secret, err = randomSecret()
  145. if err != nil {
  146. return &JWT{}, err
  147. }
  148. case l < 64:
  149. return &JWT{}, errors.New("CS_LAPI_SECRET not strong enough")
  150. }
  151. jwtMiddleware := &JWT{
  152. DbClient: dbClient,
  153. }
  154. ret, err := jwt.New(&jwt.GinJWTMiddleware{
  155. Realm: "Crowdsec API local",
  156. Key: secret,
  157. Timeout: time.Hour,
  158. MaxRefresh: time.Hour,
  159. IdentityKey: identityKey,
  160. PayloadFunc: PayloadFunc,
  161. IdentityHandler: IdentityHandler,
  162. Authenticator: jwtMiddleware.Authenticator,
  163. Authorizator: Authorizator,
  164. Unauthorized: Unauthorized,
  165. TokenLookup: "header: Authorization, query: token, cookie: jwt",
  166. TokenHeadName: "Bearer",
  167. TimeFunc: time.Now,
  168. })
  169. errInit := ret.MiddlewareInit()
  170. if errInit != nil {
  171. return &JWT{}, fmt.Errorf("authMiddleware.MiddlewareInit() Error:" + errInit.Error())
  172. }
  173. if err != nil {
  174. return &JWT{}, err
  175. }
  176. return &JWT{Middleware: ret}, nil
  177. }