jwt.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. package v1
  2. import (
  3. "crypto/rand"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "os"
  8. "strings"
  9. "time"
  10. jwt "github.com/appleboy/gin-jwt/v2"
  11. "github.com/crowdsecurity/crowdsec/pkg/database"
  12. "github.com/crowdsecurity/crowdsec/pkg/database/ent"
  13. "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine"
  14. "github.com/crowdsecurity/crowdsec/pkg/models"
  15. "github.com/crowdsecurity/crowdsec/pkg/types"
  16. "github.com/gin-gonic/gin"
  17. "github.com/go-openapi/strfmt"
  18. log "github.com/sirupsen/logrus"
  19. "golang.org/x/crypto/bcrypt"
  20. )
  21. var identityKey = "id"
  22. type JWT struct {
  23. Middleware *jwt.GinJWTMiddleware
  24. DbClient *database.Client
  25. TlsAuth *TLSAuth
  26. }
  27. func PayloadFunc(data interface{}) jwt.MapClaims {
  28. if value, ok := data.(*models.WatcherAuthRequest); ok {
  29. return jwt.MapClaims{
  30. identityKey: &value.MachineID,
  31. }
  32. }
  33. return jwt.MapClaims{}
  34. }
  35. func IdentityHandler(c *gin.Context) interface{} {
  36. claims := jwt.ExtractClaims(c)
  37. machineId := claims[identityKey].(string)
  38. return &models.WatcherAuthRequest{
  39. MachineID: &machineId,
  40. }
  41. }
  42. type authInput struct {
  43. machineID string
  44. clientMachine *ent.Machine
  45. scenariosInput []string
  46. }
  47. func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
  48. ret := authInput{}
  49. if j.TlsAuth == nil {
  50. c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
  51. c.Abort()
  52. return nil, errors.New("TLS auth is not configured")
  53. }
  54. validCert, extractedCN, err := j.TlsAuth.ValidateCert(c)
  55. if err != nil {
  56. log.Error(err)
  57. c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
  58. c.Abort()
  59. return nil, fmt.Errorf("while trying to validate client cert: %w", err)
  60. }
  61. if !validCert {
  62. c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
  63. c.Abort()
  64. return nil, fmt.Errorf("failed cert authentication")
  65. }
  66. ret.machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP())
  67. ret.clientMachine, err = j.DbClient.Ent.Machine.Query().
  68. Where(machine.MachineId(ret.machineID)).
  69. First(j.DbClient.CTX)
  70. if ent.IsNotFound(err) {
  71. //Machine was not found, let's create it
  72. log.Infof("machine %s not found, create it", ret.machineID)
  73. //let's use an apikey as the password, doesn't matter in this case (generatePassword is only available in cscli)
  74. pwd, err := GenerateAPIKey(dummyAPIKeySize)
  75. if err != nil {
  76. log.WithFields(log.Fields{
  77. "ip": c.ClientIP(),
  78. "cn": extractedCN,
  79. }).Errorf("error generating password: %s", err)
  80. return nil, fmt.Errorf("error generating password")
  81. }
  82. password := strfmt.Password(pwd)
  83. ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType)
  84. if err != nil {
  85. return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err)
  86. }
  87. } else if err != nil {
  88. return nil, fmt.Errorf("while selecting machine entry for %s: %w", ret.machineID, err)
  89. } else {
  90. if ret.clientMachine.AuthType != types.TlsAuthType {
  91. return nil, fmt.Errorf("machine %s attempted to auth with TLS cert but it is configured to use %s", ret.machineID, ret.clientMachine.AuthType)
  92. }
  93. ret.machineID = ret.clientMachine.MachineId
  94. }
  95. loginInput := struct {
  96. Scenarios []string `json:"scenarios"`
  97. }{
  98. Scenarios: []string{},
  99. }
  100. err = c.ShouldBindJSON(&loginInput)
  101. if err != nil {
  102. return nil, fmt.Errorf("missing scenarios list in login request for TLS auth: %w", err)
  103. }
  104. ret.scenariosInput = loginInput.Scenarios
  105. return &ret, nil
  106. }
  107. func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
  108. var loginInput models.WatcherAuthRequest
  109. var err error
  110. ret := authInput{}
  111. if err = c.ShouldBindJSON(&loginInput); err != nil {
  112. return nil, fmt.Errorf("missing: %w", err)
  113. }
  114. if err = loginInput.Validate(strfmt.Default); err != nil {
  115. return nil, err
  116. }
  117. ret.machineID = *loginInput.MachineID
  118. password := *loginInput.Password
  119. ret.scenariosInput = loginInput.Scenarios
  120. ret.clientMachine, err = j.DbClient.Ent.Machine.Query().
  121. Where(machine.MachineId(ret.machineID)).
  122. First(j.DbClient.CTX)
  123. if err != nil {
  124. log.Infof("Error machine login for %s : %+v ", ret.machineID, err)
  125. return nil, err
  126. }
  127. if ret.clientMachine == nil {
  128. log.Errorf("Nothing for '%s'", ret.machineID)
  129. return nil, jwt.ErrFailedAuthentication
  130. }
  131. if ret.clientMachine.AuthType != types.PasswordAuthType {
  132. return nil, fmt.Errorf("machine %s attempted to auth with password but it is configured to use %s", ret.machineID, ret.clientMachine.AuthType)
  133. }
  134. if !ret.clientMachine.IsValidated {
  135. return nil, fmt.Errorf("machine %s not validated", ret.machineID)
  136. }
  137. if err := bcrypt.CompareHashAndPassword([]byte(ret.clientMachine.Password), []byte(password)); err != nil {
  138. return nil, jwt.ErrFailedAuthentication
  139. }
  140. return &ret, nil
  141. }
  142. func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
  143. var err error
  144. var auth *authInput
  145. if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
  146. auth, err = j.authTLS(c)
  147. if err != nil {
  148. return nil, err
  149. }
  150. } else {
  151. auth, err = j.authPlain(c)
  152. if err != nil {
  153. return nil, err
  154. }
  155. }
  156. var scenarios string
  157. if len(auth.scenariosInput) > 0 {
  158. for _, scenario := range auth.scenariosInput {
  159. if scenarios == "" {
  160. scenarios = scenario
  161. } else {
  162. scenarios += "," + scenario
  163. }
  164. }
  165. err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID)
  166. if err != nil {
  167. log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err)
  168. return nil, jwt.ErrFailedAuthentication
  169. }
  170. }
  171. if auth.clientMachine.IpAddress == "" {
  172. err = j.DbClient.UpdateMachineIP(c.ClientIP(), auth.clientMachine.ID)
  173. if err != nil {
  174. log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err)
  175. return nil, jwt.ErrFailedAuthentication
  176. }
  177. }
  178. if auth.clientMachine.IpAddress != c.ClientIP() && auth.clientMachine.IpAddress != "" {
  179. log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, c.ClientIP(), auth.clientMachine.IpAddress)
  180. err = j.DbClient.UpdateMachineIP(c.ClientIP(), auth.clientMachine.ID)
  181. if err != nil {
  182. log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err)
  183. return nil, jwt.ErrFailedAuthentication
  184. }
  185. }
  186. useragent := strings.Split(c.Request.UserAgent(), "/")
  187. if len(useragent) != 2 {
  188. log.Warningf("bad user agent '%s' from '%s'", c.Request.UserAgent(), c.ClientIP())
  189. return nil, jwt.ErrFailedAuthentication
  190. }
  191. if err := j.DbClient.UpdateMachineVersion(useragent[1], auth.clientMachine.ID); err != nil {
  192. log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err)
  193. log.Errorf("bad user agent from : %s", c.ClientIP())
  194. return nil, jwt.ErrFailedAuthentication
  195. }
  196. return &models.WatcherAuthRequest{
  197. MachineID: &auth.machineID,
  198. }, nil
  199. }
  200. func Authorizator(data interface{}, c *gin.Context) bool {
  201. return true
  202. }
  203. func Unauthorized(c *gin.Context, code int, message string) {
  204. c.JSON(code, gin.H{
  205. "code": code,
  206. "message": message,
  207. })
  208. }
  209. func randomSecret() ([]byte, error) {
  210. size := 64
  211. secret := make([]byte, size)
  212. n, err := rand.Read(secret)
  213. if err != nil {
  214. return nil, errors.New("unable to generate a new random seed for JWT generation")
  215. }
  216. if n != size {
  217. return nil, errors.New("not enough entropy at random seed generation for JWT generation")
  218. }
  219. return secret, nil
  220. }
  221. func NewJWT(dbClient *database.Client) (*JWT, error) {
  222. // Get secret from environment variable "SECRET"
  223. var (
  224. secret []byte
  225. err error
  226. )
  227. // Please be aware that brute force HS256 is possible.
  228. // PLEASE choose a STRONG secret
  229. secretString := os.Getenv("CS_LAPI_SECRET")
  230. secret = []byte(secretString)
  231. switch l := len(secret); {
  232. case l == 0:
  233. secret, err = randomSecret()
  234. if err != nil {
  235. return &JWT{}, err
  236. }
  237. case l < 64:
  238. return &JWT{}, errors.New("CS_LAPI_SECRET not strong enough")
  239. }
  240. jwtMiddleware := &JWT{
  241. DbClient: dbClient,
  242. TlsAuth: &TLSAuth{},
  243. }
  244. ret, err := jwt.New(&jwt.GinJWTMiddleware{
  245. Realm: "Crowdsec API local",
  246. Key: secret,
  247. Timeout: time.Hour,
  248. MaxRefresh: time.Hour,
  249. IdentityKey: identityKey,
  250. PayloadFunc: PayloadFunc,
  251. IdentityHandler: IdentityHandler,
  252. Authenticator: jwtMiddleware.Authenticator,
  253. Authorizator: Authorizator,
  254. Unauthorized: Unauthorized,
  255. TokenLookup: "header: Authorization, query: token, cookie: jwt",
  256. TokenHeadName: "Bearer",
  257. TimeFunc: time.Now,
  258. })
  259. if err != nil {
  260. return &JWT{}, err
  261. }
  262. errInit := ret.MiddlewareInit()
  263. if errInit != nil {
  264. return &JWT{}, fmt.Errorf("authMiddleware.MiddlewareInit() Error:" + errInit.Error())
  265. }
  266. jwtMiddleware.Middleware = ret
  267. return jwtMiddleware, nil
  268. }