sftpgo/internal/httpd/auth_utils.go
Nicola Murino d94f80c8da
replace utils.Contains with slices.Contains
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-07-24 18:27:13 +02:00

572 lines
16 KiB
Go

// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package httpd
import (
"errors"
"fmt"
"net/http"
"slices"
"time"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/rs/xid"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
)
// tokenAudience identifies which part of the service a token was issued for.
// createToken stores it in the standard "aud" claim, followed by the client
// IP address.
type tokenAudience = string

// Known token audiences.
const (
	tokenAudienceWebAdmin         tokenAudience = "WebAdmin"
	tokenAudienceWebClient        tokenAudience = "WebClient"
	tokenAudienceWebShare         tokenAudience = "WebShare"
	tokenAudienceWebAdminPartial  tokenAudience = "WebAdminPartial"
	tokenAudienceWebClientPartial tokenAudience = "WebClientPartial"
	tokenAudienceAPI              tokenAudience = "API"
	tokenAudienceAPIUser          tokenAudience = "APIUser"
	tokenAudienceCSRF             tokenAudience = "CSRF"
	tokenAudienceOAuth2           tokenAudience = "OAuth2"
	tokenAudienceWebLogin         tokenAudience = "WebLogin"
)
// tokenValidation selects how strictly issued tokens are validated.
type tokenValidation = int

const (
	// tokenValidationFull requires the client IP to be listed in the token
	// audience (see validateIPForToken).
	//
	// The type is spelled out explicitly: with a bare "= iota" the first
	// constant would be untyped while the following ones inherit the
	// tokenValidation type.
	tokenValidationFull tokenValidation = iota
	// tokenValidationNoIPMatch skips the client IP check.
	tokenValidationNoIPMatch
)
// Keys of the private claims stored in SFTPGo JWT tokens.
const (
	claimUsernameKey                = "username"
	claimPermissionsKey             = "permissions"
	claimRole                       = "role"
	claimAPIKey                     = "api_key"
	claimNodeID                     = "node_id"
	claimMustChangePasswordKey      = "chpwd"
	claimMustSetSecondFactorKey     = "2fa_required"
	claimRequiredTwoFactorProtocols = "2fa_protos"
	claimHideUserPageSection        = "hus"
	claimRef                        = "ref"
)

// HTTP authentication helpers: the basic auth realm and the name of the
// cookie carrying the JWT.
const (
	basicRealm   = `Basic realm="SFTPGo"`
	jwtCookieKey = "jwt"
)
var (
	// tokenDuration is the lifetime of tokens created by createToken for
	// every audience except WebLogin.
	tokenDuration = 20 * time.Minute
	// shareTokenDuration is the cookie lifetime used for the WebShare
	// audience in createAndSetCookie.
	shareTokenDuration = 2 * time.Hour
	// csrfTokenDuration is the lifetime of CSRF and WebLogin tokens. It is
	// longer than tokenDuration to reduce issues with the login form.
	csrfTokenDuration = 4 * time.Hour
	// tokenRefreshThreshold is not used in this part of the file; presumably
	// it is the remaining lifetime below which a token gets refreshed — TODO
	// confirm against its callers.
	tokenRefreshThreshold = 10 * time.Minute
	// tokenValidationMode controls whether validateIPForToken enforces the
	// client IP check.
	tokenValidationMode = tokenValidationFull
)
// jwtTokenClaims holds the claims SFTPGo writes into, and reads back from,
// its JWT tokens. asMap and Decode translate between this struct and the raw
// claims map.
type jwtTokenClaims struct {
	// Username is stored in the "username" claim.
	Username string
	// Permissions is stored in the "permissions" claim.
	Permissions []string
	// Role is stored in the "role" claim, omitted when empty.
	Role string
	// Signature is stored in the standard "sub" (subject) claim.
	Signature string
	// Audience is decoded from the standard "aud" claim; tokens built by
	// createToken carry the audience followed by the client IP.
	Audience []string
	// APIKeyID is stored in the "api_key" claim, omitted when empty.
	APIKeyID string
	// NodeID is stored in the "node_id" claim, omitted when empty.
	NodeID string
	// MustSetTwoFactorAuth is stored in the "2fa_required" claim, only when true.
	MustSetTwoFactorAuth bool
	// MustChangePassword is stored in the "chpwd" claim, only when true.
	MustChangePassword bool
	// RequiredTwoFactorProtocols is stored in the "2fa_protos" claim, omitted when empty.
	RequiredTwoFactorProtocols []string
	// HideUserPageSections is stored in the "hus" claim, only when positive.
	HideUserPageSections int
	// JwtID is the standard "jti" claim; createToken generates one when empty.
	JwtID string
	// Ref is stored in the "ref" claim, omitted when empty.
	Ref string
}
// hasUserAudience reports whether the token audience includes
// tokenAudienceWebClient or tokenAudienceAPIUser.
func (c *jwtTokenClaims) hasUserAudience() bool {
	for _, aud := range c.Audience {
		switch aud {
		case tokenAudienceWebClient, tokenAudienceAPIUser:
			return true
		}
	}
	return false
}
// asMap converts the claims into the raw map that createToken hands to the
// JWT encoder. Username, permissions and the subject (Signature) are always
// present. The remaining claims are included only when they are set:
// non-empty strings and slices, true booleans, and a positive
// HideUserPageSections.
func (c *jwtTokenClaims) asMap() map[string]any {
	claims := map[string]any{
		claimUsernameKey:    c.Username,
		claimPermissionsKey: c.Permissions,
		jwt.SubjectKey:      c.Signature,
	}
	putIfNotEmpty := func(key, value string) {
		if value != "" {
			claims[key] = value
		}
	}
	putIfNotEmpty(jwt.JwtIDKey, c.JwtID)
	putIfNotEmpty(claimRef, c.Ref)
	putIfNotEmpty(claimRole, c.Role)
	putIfNotEmpty(claimAPIKey, c.APIKeyID)
	putIfNotEmpty(claimNodeID, c.NodeID)

	if c.MustChangePassword {
		claims[claimMustChangePasswordKey] = true
	}
	if c.MustSetTwoFactorAuth {
		claims[claimMustSetSecondFactorKey] = true
	}
	if len(c.RequiredTwoFactorProtocols) != 0 {
		claims[claimRequiredTwoFactorProtocols] = c.RequiredTwoFactorProtocols
	}
	if c.HideUserPageSections > 0 {
		claims[claimHideUserPageSection] = c.HideUserPageSections
	}
	return claims
}
// decodeSliceString converts a claim value into a []string. A []string is
// returned as is; for a []any only the string elements are kept (in order),
// and the result is non-nil even if no element matches. Any other value
// yields nil.
func (c *jwtTokenClaims) decodeSliceString(val any) []string {
	if strs, ok := val.([]string); ok {
		return strs
	}
	items, ok := val.([]any)
	if !ok {
		return nil
	}
	out := make([]string, 0, len(items))
	for _, item := range items {
		if s, isString := item.(string); isString {
			out = append(out, s)
		}
	}
	return out
}
// decodeBoolean returns val when it holds a bool and false otherwise.
func (c *jwtTokenClaims) decodeBoolean(val any) bool {
	b, _ := val.(bool)
	return b
}
// decodeString returns val when it holds a string and "" otherwise.
func (c *jwtTokenClaims) decodeString(val any) string {
	s, _ := val.(string)
	return s
}
// Decode fills c from a raw claims map, such as the one returned by
// jwtauth.FromContext.
//
// Username, Signature (the "sub" claim), JwtID and Permissions are always
// overwritten. They fall back to their zero value when the claim is missing
// or has an unexpected type. The other claims are optional: when absent, the
// corresponding field is left untouched. Audience and HideUserPageSections
// are only updated when the claim has a supported type.
func (c *jwtTokenClaims) Decode(token map[string]any) {
	c.Username = c.decodeString(token[claimUsernameKey])
	c.Signature = c.decodeString(token[jwt.SubjectKey])
	c.JwtID = c.decodeString(token[jwt.JwtIDKey])
	c.Permissions = c.decodeSliceString(token[claimPermissionsKey])

	if audience, ok := token[jwt.AudienceKey].([]string); ok {
		c.Audience = audience
	}
	if val, ok := token[claimRef]; ok {
		c.Ref = c.decodeString(val)
	}
	if val, ok := token[claimAPIKey]; ok {
		c.APIKeyID = c.decodeString(val)
	}
	if val, ok := token[claimNodeID]; ok {
		c.NodeID = c.decodeString(val)
	}
	if val, ok := token[claimRole]; ok {
		c.Role = c.decodeString(val)
	}
	if val, ok := token[claimMustChangePasswordKey]; ok {
		c.MustChangePassword = c.decodeBoolean(val)
	}
	if val, ok := token[claimMustSetSecondFactorKey]; ok {
		c.MustSetTwoFactorAuth = c.decodeBoolean(val)
	}
	if val, ok := token[claimRequiredTwoFactorProtocols]; ok {
		c.RequiredTwoFactorProtocols = c.decodeSliceString(val)
	}
	if val, ok := token[claimHideUserPageSection]; ok {
		switch v := val.(type) {
		case float64:
			// Numeric claims parsed from a serialized token are JSON numbers,
			// which typically decode to float64.
			c.HideUserPageSections = int(v)
		case int:
			// asMap stores this claim as an int, so claim maps built
			// in-process carry an int rather than a float64.
			c.HideUserPageSections = v
		}
	}
}
// isCriticalPermRemoved reports whether switching from the permissions in the
// token to the given permissions would revoke the ability to manage admins.
// That ability is granted by either dataprovider.PermAdminManageAdmins or
// dataprovider.PermAdminAny.
func (c *jwtTokenClaims) isCriticalPermRemoved(permissions []string) bool {
	stillManagesAdmins := slices.Contains(permissions, dataprovider.PermAdminAny) ||
		slices.Contains(permissions, dataprovider.PermAdminManageAdmins)
	if stillManagesAdmins {
		return false
	}
	// The new permissions cannot manage admins: it is critical only if the
	// current ones could.
	return slices.Contains(c.Permissions, dataprovider.PermAdminAny) ||
		slices.Contains(c.Permissions, dataprovider.PermAdminManageAdmins)
}
// hasPerm reports whether the token grants perm. dataprovider.PermAdminAny
// grants every permission.
func (c *jwtTokenClaims) hasPerm(perm string) bool {
	return slices.ContainsFunc(c.Permissions, func(p string) bool {
		return p == dataprovider.PermAdminAny || p == perm
	})
}
// createToken signs a token carrying the claims in c for the given audience
// and binds it to ip: both go into the "aud" claim.
//
// A random jti is generated when c has none. The token is valid starting 30
// seconds in the past (presumably to tolerate small clock differences). It
// expires after csrfTokenDuration for the WebLogin audience and after
// tokenDuration otherwise.
func (c *jwtTokenClaims) createToken(tokenAuth *jwtauth.JWTAuth, audience tokenAudience, ip string) (jwt.Token, string, error) {
	claims := c.asMap()
	now := time.Now().UTC()

	lifetime := tokenDuration
	if audience == tokenAudienceWebLogin {
		lifetime = csrfTokenDuration
	}
	if _, hasID := claims[jwt.JwtIDKey]; !hasID {
		claims[jwt.JwtIDKey] = xid.New().String()
	}
	claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second)
	claims[jwt.ExpirationKey] = now.Add(lifetime)
	claims[jwt.AudienceKey] = []string{audience, ip}

	return tokenAuth.Encode(claims)
}
// createTokenResponse creates a token via createToken and returns it as a
// response body with the signed token under "access_token" and its expiration,
// formatted as RFC 3339, under "expires_at".
func (c *jwtTokenClaims) createTokenResponse(tokenAuth *jwtauth.JWTAuth, audience tokenAudience, ip string) (map[string]any, error) {
	token, tokenString, err := c.createToken(tokenAuth, audience, ip)
	if err != nil {
		return nil, err
	}
	return map[string]any{
		"access_token": tokenString,
		"expires_at":   token.Expiration().Format(time.RFC3339),
	}, nil
}
func (c *jwtTokenClaims) createAndSetCookie(w http.ResponseWriter, r *http.Request, tokenAuth *jwtauth.JWTAuth,
audience tokenAudience, ip string,
) error {
resp, err := c.createTokenResponse(tokenAuth, audience, ip)
if err != nil {
return err
}
var basePath string
if audience == tokenAudienceWebAdmin || audience == tokenAudienceWebAdminPartial {
basePath = webBaseAdminPath
} else {
basePath = webBaseClientPath
}
duration := tokenDuration
if audience == tokenAudienceWebShare {
duration = shareTokenDuration
}
setCookie(w, r, basePath, resp["access_token"].(string), duration)
return nil
}
// setCookie writes the JWT cookie scoped to cookiePath. It expires after
// duration, is HttpOnly with SameSite=Strict, and is marked Secure when
// isTLS reports an HTTPS request.
func setCookie(w http.ResponseWriter, r *http.Request, cookiePath, cookieValue string, duration time.Duration) {
	cookie := &http.Cookie{
		Name:     jwtCookieKey,
		Value:    cookieValue,
		Path:     cookiePath,
		Expires:  time.Now().Add(duration),
		MaxAge:   int(duration / time.Second),
		HttpOnly: true,
		Secure:   isTLS(r),
		SameSite: http.SameSiteStrictMode,
	}
	http.SetCookie(w, cookie)
}
// removeCookie clears the JWT cookie scoped to cookiePath by sending an
// already expired, empty cookie. It asks caches not to store the Set-Cookie
// header, then invalidates the tokens carried by the request (as
// non-login tokens).
func removeCookie(w http.ResponseWriter, r *http.Request, cookiePath string) {
	expired := &http.Cookie{
		Name:     jwtCookieKey,
		Value:    "",
		Path:     cookiePath,
		Expires:  time.Unix(0, 0),
		MaxAge:   -1,
		HttpOnly: true,
		Secure:   isTLS(r),
		SameSite: http.SameSiteStrictMode,
	}
	http.SetCookie(w, expired)
	w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`)
	invalidateToken(r, false)
}
// oidcTokenFromContext returns the token stored in the request context under
// oidcGeneratedToken, or "" when there is none or it is not a string.
func oidcTokenFromContext(r *http.Request) string {
	token, _ := r.Context().Value(oidcGeneratedToken).(string)
	return token
}
// isTLS reports whether r arrived over TLS, or whether the protocol stored in
// the request context under forwardedProtoKey is "https". That value is
// presumably set from proxy headers elsewhere — confirm.
func isTLS(r *http.Request) bool {
	if r.TLS != nil {
		return true
	}
	proto, _ := r.Context().Value(forwardedProtoKey).(string)
	return proto == "https"
}
func isTokenInvalidated(r *http.Request) bool {
var findTokenFns []func(r *http.Request) string
findTokenFns = append(findTokenFns, jwtauth.TokenFromHeader)
findTokenFns = append(findTokenFns, jwtauth.TokenFromCookie)
findTokenFns = append(findTokenFns, oidcTokenFromContext)
isTokenFound := false
for _, fn := range findTokenFns {
token := fn(r)
if token != "" {
isTokenFound = true
if invalidatedJWTTokens.Get(token) {
return true
}
}
}
return !isTokenFound
}
// invalidateToken adds the tokens found in the Authorization header and in the
// JWT cookie to invalidatedJWTTokens. The entries are kept for
// csrfTokenDuration when isLoginToken is true and for tokenDuration
// otherwise, matching the lifetimes createToken assigns.
func invalidateToken(r *http.Request, isLoginToken bool) {
	lifetime := tokenDuration
	if isLoginToken {
		lifetime = csrfTokenDuration
	}
	for _, tokenString := range []string{jwtauth.TokenFromHeader(r), jwtauth.TokenFromCookie(r)} {
		if tokenString == "" {
			continue
		}
		invalidatedJWTTokens.Add(tokenString, time.Now().Add(lifetime).UTC())
	}
}
func getUserFromToken(r *http.Request) *dataprovider.User {
user := &dataprovider.User{}
_, claims, err := jwtauth.FromContext(r.Context())
if err != nil {
return user
}
tokenClaims := jwtTokenClaims{}
tokenClaims.Decode(claims)
user.Username = tokenClaims.Username
user.Filters.WebClient = tokenClaims.Permissions
user.Role = tokenClaims.Role
return user
}
func getAdminFromToken(r *http.Request) *dataprovider.Admin {
admin := &dataprovider.Admin{}
_, claims, err := jwtauth.FromContext(r.Context())
if err != nil {
return admin
}
tokenClaims := jwtTokenClaims{}
tokenClaims.Decode(claims)
admin.Username = tokenClaims.Username
admin.Permissions = tokenClaims.Permissions
admin.Filters.Preferences.HideUserPageSections = tokenClaims.HideUserPageSections
admin.Role = tokenClaims.Role
return admin
}
// createLoginCookie signs a WebLogin token whose jti is tokenID, binds it to
// ip, and stores it in the JWT cookie scoped to basePath. Both the token and
// the cookie last csrfTokenDuration.
//
// This is best effort: the function has no error return. A failure is logged
// instead of being silently dropped, since a missing login cookie makes the
// later verifyLoginCookie check fail.
func createLoginCookie(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwtauth.JWTAuth, tokenID, basePath, ip string,
) {
	c := jwtTokenClaims{
		JwtID: tokenID,
	}
	resp, err := c.createTokenResponse(csrfTokenAuth, tokenAudienceWebLogin, ip)
	if err != nil {
		logger.Debug(logSender, "", "unable to create login cookie: %v", err)
		return
	}
	setCookie(w, r, basePath, resp["access_token"].(string), csrfTokenDuration)
}
// createCSRFToken returns a signed CSRF token bound to the client IP and
// valid for csrfTokenDuration.
//
// The "ref" claim links the token to a session. When tokenID is non-empty, a
// login cookie with that ID is set and referenced. Otherwise the jti of the
// token returned by getTokenClaims for the current request is referenced;
// if that lookup fails, the error is logged and no reference is added.
// Returns "" if the token cannot be encoded.
func createCSRFToken(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwtauth.JWTAuth, tokenID,
	basePath string,
) string {
	ip := util.GetIPFromRemoteAddress(r.RemoteAddr)
	now := time.Now().UTC()
	claims := map[string]any{
		jwt.JwtIDKey:      xid.New().String(),
		jwt.NotBeforeKey:  now.Add(-30 * time.Second),
		jwt.ExpirationKey: now.Add(csrfTokenDuration),
		jwt.AudienceKey:   []string{tokenAudienceCSRF, ip},
	}

	if tokenID != "" {
		createLoginCookie(w, r, csrfTokenAuth, tokenID, basePath, ip)
		claims[claimRef] = tokenID
	} else if current, err := getTokenClaims(r); err == nil {
		claims[claimRef] = current.JwtID
	} else {
		logger.Error(logSender, "", "unable to add reference to CSRF token: %v", err)
	}

	_, tokenString, err := csrfTokenAuth.Encode(claims)
	if err != nil {
		logger.Debug(logSender, "", "unable to create CSRF token: %v", err)
		return ""
	}
	return tokenString
}
// verifyCSRFToken validates the CSRF token submitted in the csrfFormToken form
// field. r.Form must already be parsed by the caller; an unparsed form yields
// an empty token, which then fails verification.
//
// The token must be signed with csrfTokenAuth, carry the CSRF audience, and
// be bound to the client IP (subject to tokenValidationMode). Its "ref" claim
// must also equal the jti of the token returned by getTokenClaims for r.
func verifyCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAuth) error {
	tokenString := r.Form.Get(csrfFormToken)
	token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString)
	if err != nil || token == nil {
		logger.Debug(logSender, "", "error validating CSRF token %q: %v", tokenString, err)
		return fmt.Errorf("unable to verify form token: %v", err)
	}

	if !slices.Contains(token.Audience(), tokenAudienceCSRF) {
		logger.Debug(logSender, "", "error validating CSRF token audience")
		return errors.New("the form token is not valid")
	}

	if err := validateIPForToken(token, util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil {
		logger.Debug(logSender, "", "error validating CSRF token IP audience")
		return errors.New("the form token is not valid")
	}

	claims, err := getTokenClaims(r)
	if err != nil {
		logger.Debug(logSender, "", "error getting token claims for CSRF validation: %v", err)
		return err
	}

	refValue, ok := token.Get(claimRef)
	if !ok {
		logger.Debug(logSender, "", "error validating CSRF token, missing reference")
		return errors.New("the form token is not valid")
	}
	// Use a checked assertion: a non-string "ref" claim must reject the token
	// instead of panicking the handler.
	ref, ok := refValue.(string)
	if !ok {
		logger.Debug(logSender, "", "error validating CSRF token, invalid reference type %T", refValue)
		return errors.New("the form token is not valid")
	}
	if claims.JwtID == "" || claims.JwtID != ref {
		logger.Debug(logSender, "", "error validating CSRF reference, id %q, reference %q", claims.JwtID, ref)
		return errors.New("unexpected form token")
	}
	return nil
}
// verifyLoginCookie checks the token that jwtauth stored in the request
// context. It must exist, must not be invalidated (see isTokenInvalidated),
// must target the WebLogin audience, and must be bound to the client IP
// (subject to tokenValidationMode). Every failure returns errInvalidToken.
func verifyLoginCookie(r *http.Request) error {
	reject := func(format string, args ...any) error {
		logger.Debug(logSender, "", format, args...)
		return errInvalidToken
	}

	token, _, err := jwtauth.FromContext(r.Context())
	if err != nil || token == nil {
		return reject("error getting login token: %v", err)
	}
	if isTokenInvalidated(r) {
		return reject("the login token has been invalidated")
	}
	if !slices.Contains(token.Audience(), tokenAudienceWebLogin) {
		return reject("the token with id %q is not valid for audience %q", token.JwtID(), tokenAudienceWebLogin)
	}
	return validateIPForToken(token, util.GetIPFromRemoteAddress(r.RemoteAddr))
}
// verifyLoginCookieAndCSRFToken runs verifyLoginCookie and then
// verifyCSRFToken. It returns the first error, so the CSRF token is not
// checked when the login cookie is invalid.
func verifyLoginCookieAndCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAuth) error {
	if err := verifyLoginCookie(r); err != nil {
		return err
	}
	return verifyCSRFToken(r, csrfTokenAuth)
}
// createOAuth2Token signs a token that carries the OAuth2 state as its jti.
// The token is bound to ip and valid for 3 minutes. Returns "" if encoding
// fails.
func createOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, state, ip string) string {
	now := time.Now().UTC()
	_, tokenString, err := csrfTokenAuth.Encode(map[string]any{
		jwt.JwtIDKey:      state,
		jwt.NotBeforeKey:  now.Add(-30 * time.Second),
		jwt.ExpirationKey: now.Add(3 * time.Minute),
		jwt.AudienceKey:   []string{tokenAudienceOAuth2, ip},
	})
	if err != nil {
		logger.Debug(logSender, "", "unable to create OAuth2 token: %v", err)
		return ""
	}
	return tokenString
}
// verifyOAuth2Token validates a token created by createOAuth2Token and
// returns the OAuth2 state stored in its jti.
//
// A signature or parsing failure yields an I18nOAuth2ErrorVerifyState error.
// A wrong audience, an IP mismatch, or a missing or non-string jti yields an
// I18nOAuth2InvalidState error.
func verifyOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, tokenString, ip string) (string, error) {
	invalidState := func() error {
		return util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState)
	}

	token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString)
	if err != nil || token == nil {
		logger.Debug(logSender, "", "error validating OAuth2 token %q: %v", tokenString, err)
		return "", util.NewI18nError(
			fmt.Errorf("unable to verify OAuth2 state: %v", err),
			util.I18nOAuth2ErrorVerifyState,
		)
	}

	switch {
	case !slices.Contains(token.Audience(), tokenAudienceOAuth2):
		logger.Debug(logSender, "", "error validating OAuth2 token audience")
		return "", invalidState()
	case validateIPForToken(token, ip) != nil:
		logger.Debug(logSender, "", "error validating OAuth2 token IP audience")
		return "", invalidState()
	}

	rawState, found := token.Get(jwt.JwtIDKey)
	if state, isString := rawState.(string); found && isString {
		return state, nil
	}
	logger.Debug(logSender, "", "jti not found in OAuth2 token")
	return "", invalidState()
}
// validateIPForToken returns errInvalidToken unless ip appears in the token
// audience. The check is skipped when tokenValidationMode is
// tokenValidationNoIPMatch.
func validateIPForToken(token jwt.Token, ip string) error {
	if tokenValidationMode == tokenValidationNoIPMatch {
		return nil
	}
	if slices.Contains(token.Audience(), ip) {
		return nil
	}
	return errInvalidToken
}