refact pkg/apiserver (auth helpers) (#2856)
This commit is contained in:
parent
e34af358d7
commit
4bf640c6e8
6 changed files with 50 additions and 45 deletions
|
@ -9,7 +9,6 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
jwt "github.com/appleboy/gin-jwt/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-openapi/strfmt"
|
||||
"github.com/google/uuid"
|
||||
|
@ -143,9 +142,7 @@ func normalizeScope(scope string) string {
|
|||
func (c *Controller) CreateAlert(gctx *gin.Context) {
|
||||
var input models.AddAlertsRequest
|
||||
|
||||
claims := jwt.ExtractClaims(gctx)
|
||||
// TBD: use defined rather than hardcoded key to find back owner
|
||||
machineID := claims["id"].(string)
|
||||
machineID, _ := getMachineIDFromContext(gctx)
|
||||
|
||||
if err := gctx.ShouldBindJSON(&input); err != nil {
|
||||
gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
|
||||
|
|
|
@ -3,14 +3,11 @@ package v1
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
jwt "github.com/appleboy/gin-jwt/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (c *Controller) HeartBeat(gctx *gin.Context) {
|
||||
claims := jwt.ExtractClaims(gctx)
|
||||
// TBD: use defined rather than hardcoded key to find back owner
|
||||
machineID := claims["id"].(string)
|
||||
machineID, _ := getMachineIDFromContext(gctx)
|
||||
|
||||
if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil {
|
||||
c.HandleDBErrors(gctx, err)
|
||||
|
|
|
@ -3,7 +3,6 @@ package v1
|
|||
import (
|
||||
"time"
|
||||
|
||||
jwt "github.com/appleboy/gin-jwt/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
@ -66,32 +65,29 @@ var LapiResponseTime = prometheus.NewHistogramVec(
|
|||
[]string{"endpoint", "method"})
|
||||
|
||||
func PrometheusBouncersHasEmptyDecision(c *gin.Context) {
|
||||
name, ok := c.Get("BOUNCER_NAME")
|
||||
if ok {
|
||||
bouncer, _ := getBouncerFromContext(c)
|
||||
if bouncer != nil {
|
||||
LapiNilDecisions.With(prometheus.Labels{
|
||||
"bouncer": name.(string)}).Inc()
|
||||
"bouncer": bouncer.Name}).Inc()
|
||||
}
|
||||
}
|
||||
|
||||
func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) {
|
||||
name, ok := c.Get("BOUNCER_NAME")
|
||||
if ok {
|
||||
bouncer, _ := getBouncerFromContext(c)
|
||||
if bouncer != nil {
|
||||
LapiNonNilDecisions.With(prometheus.Labels{
|
||||
"bouncer": name.(string)}).Inc()
|
||||
"bouncer": bouncer.Name}).Inc()
|
||||
}
|
||||
}
|
||||
|
||||
func PrometheusMachinesMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
claims := jwt.ExtractClaims(c)
|
||||
if claims != nil {
|
||||
if rawID, ok := claims["id"]; ok {
|
||||
machineID := rawID.(string)
|
||||
LapiMachineHits.With(prometheus.Labels{
|
||||
"machine": machineID,
|
||||
"route": c.Request.URL.Path,
|
||||
"method": c.Request.Method}).Inc()
|
||||
}
|
||||
machineID, _ := getMachineIDFromContext(c)
|
||||
if machineID != "" {
|
||||
LapiMachineHits.With(prometheus.Labels{
|
||||
"machine": machineID,
|
||||
"route": c.Request.URL.Path,
|
||||
"method": c.Request.Method}).Inc()
|
||||
}
|
||||
|
||||
c.Next()
|
||||
|
@ -100,10 +96,10 @@ func PrometheusMachinesMiddleware() gin.HandlerFunc {
|
|||
|
||||
func PrometheusBouncersMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
name, ok := c.Get("BOUNCER_NAME")
|
||||
if ok {
|
||||
bouncer, _ := getBouncerFromContext(c)
|
||||
if bouncer != nil {
|
||||
LapiBouncerHits.With(prometheus.Labels{
|
||||
"bouncer": name.(string),
|
||||
"bouncer": bouncer.Name,
|
||||
"route": c.Request.URL.Path,
|
||||
"method": c.Request.Method}).Inc()
|
||||
}
|
||||
|
|
|
@ -1,30 +1,50 @@
|
|||
package v1
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
jwt "github.com/appleboy/gin-jwt/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
|
||||
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
|
||||
)
|
||||
|
||||
const bouncerContextKey = "bouncer_info"
|
||||
|
||||
func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) {
|
||||
bouncerInterface, exist := ctx.Get(bouncerContextKey)
|
||||
bouncerInterface, exist := ctx.Get(middlewares.BouncerContextKey)
|
||||
if !exist {
|
||||
return nil, fmt.Errorf("bouncer not found")
|
||||
return nil, errors.New("bouncer not found")
|
||||
}
|
||||
|
||||
bouncerInfo, ok := bouncerInterface.(*ent.Bouncer)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("bouncer not found")
|
||||
return nil, errors.New("bouncer not found")
|
||||
}
|
||||
|
||||
return bouncerInfo, nil
|
||||
}
|
||||
|
||||
func getMachineIDFromContext(ctx *gin.Context) (string, error) {
|
||||
claims := jwt.ExtractClaims(ctx)
|
||||
if claims == nil {
|
||||
return "", errors.New("failed to extract claims")
|
||||
}
|
||||
|
||||
rawID, ok := claims[middlewares.MachineIDKey]
|
||||
if !ok {
|
||||
return "", errors.New("MachineID not found in claims")
|
||||
}
|
||||
|
||||
id, ok := rawID.(string)
|
||||
if !ok {
|
||||
// should never happen
|
||||
return "", errors.New("failed to cast machineID to string")
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (c *Controller) AbortRemoteIf(option bool) gin.HandlerFunc {
|
||||
return func(gctx *gin.Context) {
|
||||
incomingIP := gctx.ClientIP()
|
||||
|
|
|
@ -18,9 +18,9 @@ import (
|
|||
|
||||
const (
|
||||
APIKeyHeader = "X-Api-Key"
|
||||
bouncerContextKey = "bouncer_info"
|
||||
// max allowed by bcrypt 72 = 54 bytes in base64
|
||||
BouncerContextKey = "bouncer_info"
|
||||
dummyAPIKeySize = 54
|
||||
// max allowed by bcrypt 72 = 54 bytes in base64
|
||||
)
|
||||
|
||||
type APIKey struct {
|
||||
|
@ -159,11 +159,6 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
|
|||
"name": bouncer.Name,
|
||||
})
|
||||
|
||||
// maybe we want to store the whole bouncer object in the context instead, this would avoid another db query
|
||||
// in StreamDecision
|
||||
c.Set("BOUNCER_NAME", bouncer.Name)
|
||||
c.Set("BOUNCER_HASHED_KEY", bouncer.APIKey)
|
||||
|
||||
if bouncer.IPAddress == "" {
|
||||
if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil {
|
||||
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
|
||||
|
@ -203,7 +198,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
|
|||
}
|
||||
}
|
||||
|
||||
c.Set(bouncerContextKey, bouncer)
|
||||
c.Set(BouncerContextKey, bouncer)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ import (
|
|||
"github.com/crowdsecurity/crowdsec/pkg/types"
|
||||
)
|
||||
|
||||
var identityKey = "id"
|
||||
const MachineIDKey = "id"
|
||||
|
||||
type JWT struct {
|
||||
Middleware *jwt.GinJWTMiddleware
|
||||
|
@ -33,7 +33,7 @@ type JWT struct {
|
|||
func PayloadFunc(data interface{}) jwt.MapClaims {
|
||||
if value, ok := data.(*models.WatcherAuthRequest); ok {
|
||||
return jwt.MapClaims{
|
||||
identityKey: &value.MachineID,
|
||||
MachineIDKey: &value.MachineID,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -42,7 +42,7 @@ func PayloadFunc(data interface{}) jwt.MapClaims {
|
|||
|
||||
func IdentityHandler(c *gin.Context) interface{} {
|
||||
claims := jwt.ExtractClaims(c)
|
||||
machineID := claims[identityKey].(string)
|
||||
machineID := claims[MachineIDKey].(string)
|
||||
|
||||
return &models.WatcherAuthRequest{
|
||||
MachineID: &machineID,
|
||||
|
@ -307,7 +307,7 @@ func NewJWT(dbClient *database.Client) (*JWT, error) {
|
|||
Key: secret,
|
||||
Timeout: time.Hour,
|
||||
MaxRefresh: time.Hour,
|
||||
IdentityKey: identityKey,
|
||||
IdentityKey: MachineIDKey,
|
||||
PayloadFunc: PayloadFunc,
|
||||
IdentityHandler: IdentityHandler,
|
||||
Authenticator: jwtMiddleware.Authenticator,
|
||||
|
|
Loading…
Reference in a new issue