WebUIs: refactor CSRF

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino 2024-06-14 18:09:32 +02:00
parent 7fb5b1b996
commit 8294952474
No known key found for this signature in database
GPG key ID: 935D2952DEC4EECF
15 changed files with 1150 additions and 494 deletions

View file

@ -297,7 +297,7 @@ func changeAdminPassword(w http.ResponseWriter, r *http.Request) {
sendAPIResponse(w, r, err, "", getRespStatus(err)) sendAPIResponse(w, r, err, "", getRespStatus(err))
return return
} }
invalidateToken(r) invalidateToken(r, false)
sendAPIResponse(w, r, err, "Password updated", http.StatusOK) sendAPIResponse(w, r, err, "Password updated", http.StatusOK)
} }

View file

@ -85,7 +85,7 @@ type oauth2TokenRequest struct {
BaseRedirectURL string `json:"base_redirect_url"` BaseRedirectURL string `json:"base_redirect_url"`
} }
func handleSMTPOAuth2TokenRequestPost(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleSMTPOAuth2TokenRequestPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
var req oauth2TokenRequest var req oauth2TokenRequest
@ -115,7 +115,7 @@ func handleSMTPOAuth2TokenRequestPost(w http.ResponseWriter, r *http.Request) {
clientSecret.SetAdditionalData(xid.New().String()) clientSecret.SetAdditionalData(xid.New().String())
pendingAuth := newOAuth2PendingAuth(req.Provider, cfg.RedirectURL, cfg.ClientID, clientSecret) pendingAuth := newOAuth2PendingAuth(req.Provider, cfg.RedirectURL, cfg.ClientID, clientSecret)
oauth2Mgr.addPendingAuth(pendingAuth) oauth2Mgr.addPendingAuth(pendingAuth)
stateToken := createOAuth2Token(pendingAuth.State, util.GetIPFromRemoteAddress(r.RemoteAddr)) stateToken := createOAuth2Token(s.csrfTokenAuth, pendingAuth.State, util.GetIPFromRemoteAddress(r.RemoteAddr))
if stateToken == "" { if stateToken == "" {
sendAPIResponse(w, r, nil, "unable to create state token", http.StatusInternalServerError) sendAPIResponse(w, r, nil, "unable to create state token", http.StatusInternalServerError)
return return

View file

@ -531,7 +531,7 @@ func changeUserPassword(w http.ResponseWriter, r *http.Request) {
sendAPIResponse(w, r, err, "", getRespStatus(err)) sendAPIResponse(w, r, err, "", getRespStatus(err))
return return
} }
invalidateToken(r) invalidateToken(r, false)
sendAPIResponse(w, r, err, "Password updated", http.StatusOK) sendAPIResponse(w, r, err, "Password updated", http.StatusOK)
} }

View file

@ -138,8 +138,7 @@ func saveTOTPConfig(w http.ResponseWriter, r *http.Request) {
if claims.MustSetTwoFactorAuth { if claims.MustSetTwoFactorAuth {
// force logout // force logout
defer func() { defer func() {
c := jwtTokenClaims{} removeCookie(w, r, baseURL)
c.removeCookie(w, r, baseURL)
}() }()
} }

View file

@ -441,13 +441,11 @@ func (s *httpdServer) checkWebClientShareCredentials(w http.ResponseWriter, r *h
doRedirect() doRedirect()
return errInvalidToken return errInvalidToken
} }
if tokenValidationMode != tokenValidationNoIPMatch { ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := validateIPForToken(token, ipAddr); err != nil {
if !util.Contains(token.Audience(), ipAddr) { logger.Debug(logSender, "", "token for share %q is not valid for the ip address %q", share.ShareID, ipAddr)
logger.Debug(logSender, "", "token for share %q is not valid for the ip address %q", share.ShareID, ipAddr) doRedirect()
doRedirect() return err
return errInvalidToken
}
} }
ctx := jwtauth.NewContext(r.Context(), token, nil) ctx := jwtauth.NewContext(r.Context(), token, nil)
claims, err := getTokenClaims(r.WithContext(ctx)) claims, err := getTokenClaims(r.WithContext(ctx))

View file

@ -41,6 +41,7 @@ const (
tokenAudienceAPIUser tokenAudience = "APIUser" tokenAudienceAPIUser tokenAudience = "APIUser"
tokenAudienceCSRF tokenAudience = "CSRF" tokenAudienceCSRF tokenAudience = "CSRF"
tokenAudienceOAuth2 tokenAudience = "OAuth2" tokenAudienceOAuth2 tokenAudience = "OAuth2"
tokenAudienceWebLogin tokenAudience = "WebLogin"
) )
type tokenValidation = int type tokenValidation = int
@ -60,6 +61,7 @@ const (
claimMustSetSecondFactorKey = "2fa_required" claimMustSetSecondFactorKey = "2fa_required"
claimRequiredTwoFactorProtocols = "2fa_protos" claimRequiredTwoFactorProtocols = "2fa_protos"
claimHideUserPageSection = "hus" claimHideUserPageSection = "hus"
claimRef = "ref"
basicRealm = "Basic realm=\"SFTPGo\"" basicRealm = "Basic realm=\"SFTPGo\""
jwtCookieKey = "jwt" jwtCookieKey = "jwt"
) )
@ -69,7 +71,7 @@ var (
shareTokenDuration = 2 * time.Hour shareTokenDuration = 2 * time.Hour
// csrf token duration is greater than normal token duration to reduce issues // csrf token duration is greater than normal token duration to reduce issues
// with the login form // with the login form
csrfTokenDuration = 6 * time.Hour csrfTokenDuration = 4 * time.Hour
tokenRefreshThreshold = 10 * time.Minute tokenRefreshThreshold = 10 * time.Minute
tokenValidationMode = tokenValidationFull tokenValidationMode = tokenValidationFull
) )
@ -86,6 +88,8 @@ type jwtTokenClaims struct {
MustChangePassword bool MustChangePassword bool
RequiredTwoFactorProtocols []string RequiredTwoFactorProtocols []string
HideUserPageSections int HideUserPageSections int
JwtID string
Ref string
} }
func (c *jwtTokenClaims) hasUserAudience() bool { func (c *jwtTokenClaims) hasUserAudience() bool {
@ -103,6 +107,12 @@ func (c *jwtTokenClaims) asMap() map[string]any {
claims[claimUsernameKey] = c.Username claims[claimUsernameKey] = c.Username
claims[claimPermissionsKey] = c.Permissions claims[claimPermissionsKey] = c.Permissions
if c.JwtID != "" {
claims[jwt.JwtIDKey] = c.JwtID
}
if c.Ref != "" {
claims[claimRef] = c.Ref
}
if c.Role != "" { if c.Role != "" {
claims[claimRole] = c.Role claims[claimRole] = c.Role
} }
@ -169,6 +179,7 @@ func (c *jwtTokenClaims) Decode(token map[string]any) {
c.Permissions = nil c.Permissions = nil
c.Username = c.decodeString(token[claimUsernameKey]) c.Username = c.decodeString(token[claimUsernameKey])
c.Signature = c.decodeString(token[jwt.SubjectKey]) c.Signature = c.decodeString(token[jwt.SubjectKey])
c.JwtID = c.decodeString(token[jwt.JwtIDKey])
audience := token[jwt.AudienceKey] audience := token[jwt.AudienceKey]
switch v := audience.(type) { switch v := audience.(type) {
@ -176,6 +187,10 @@ func (c *jwtTokenClaims) Decode(token map[string]any) {
c.Audience = v c.Audience = v
} }
if val, ok := token[claimRef]; ok {
c.Ref = c.decodeString(val)
}
if val, ok := token[claimAPIKey]; ok { if val, ok := token[claimAPIKey]; ok {
c.APIKeyID = c.decodeString(val) c.APIKeyID = c.decodeString(val)
} }
@ -236,9 +251,15 @@ func (c *jwtTokenClaims) createToken(tokenAuth *jwtauth.JWTAuth, audience tokenA
claims := c.asMap() claims := c.asMap()
now := time.Now().UTC() now := time.Now().UTC()
claims[jwt.JwtIDKey] = xid.New().String() if _, ok := claims[jwt.JwtIDKey]; !ok {
claims[jwt.JwtIDKey] = xid.New().String()
}
claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second) claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second)
claims[jwt.ExpirationKey] = now.Add(tokenDuration) if audience == tokenAudienceWebLogin {
claims[jwt.ExpirationKey] = now.Add(csrfTokenDuration)
} else {
claims[jwt.ExpirationKey] = now.Add(tokenDuration)
}
claims[jwt.AudienceKey] = []string{audience, ip} claims[jwt.AudienceKey] = []string{audience, ip}
return tokenAuth.Encode(claims) return tokenAuth.Encode(claims)
@ -274,21 +295,25 @@ func (c *jwtTokenClaims) createAndSetCookie(w http.ResponseWriter, r *http.Reque
if audience == tokenAudienceWebShare { if audience == tokenAudienceWebShare {
duration = shareTokenDuration duration = shareTokenDuration
} }
setCookie(w, r, basePath, resp["access_token"].(string), duration)
return nil
}
func setCookie(w http.ResponseWriter, r *http.Request, cookiePath, cookieValue string, duration time.Duration) {
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: jwtCookieKey, Name: jwtCookieKey,
Value: resp["access_token"].(string), Value: cookieValue,
Path: basePath, Path: cookiePath,
Expires: time.Now().Add(duration), Expires: time.Now().Add(duration),
MaxAge: int(duration / time.Second), MaxAge: int(duration / time.Second),
HttpOnly: true, HttpOnly: true,
Secure: isTLS(r), Secure: isTLS(r),
SameSite: http.SameSiteStrictMode, SameSite: http.SameSiteStrictMode,
}) })
return nil
} }
func (c *jwtTokenClaims) removeCookie(w http.ResponseWriter, r *http.Request, cookiePath string) { func removeCookie(w http.ResponseWriter, r *http.Request, cookiePath string) {
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: jwtCookieKey, Name: jwtCookieKey,
Value: "", Value: "",
@ -300,10 +325,10 @@ func (c *jwtTokenClaims) removeCookie(w http.ResponseWriter, r *http.Request, co
SameSite: http.SameSiteStrictMode, SameSite: http.SameSiteStrictMode,
}) })
w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`) w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`)
invalidateToken(r) invalidateToken(r, false)
} }
func tokenFromContext(r *http.Request) string { func oidcTokenFromContext(r *http.Request) string {
if token, ok := r.Context().Value(oidcGeneratedToken).(string); ok { if token, ok := r.Context().Value(oidcGeneratedToken).(string); ok {
return token return token
} }
@ -324,7 +349,7 @@ func isTokenInvalidated(r *http.Request) bool {
var findTokenFns []func(r *http.Request) string var findTokenFns []func(r *http.Request) string
findTokenFns = append(findTokenFns, jwtauth.TokenFromHeader) findTokenFns = append(findTokenFns, jwtauth.TokenFromHeader)
findTokenFns = append(findTokenFns, jwtauth.TokenFromCookie) findTokenFns = append(findTokenFns, jwtauth.TokenFromCookie)
findTokenFns = append(findTokenFns, tokenFromContext) findTokenFns = append(findTokenFns, oidcTokenFromContext)
isTokenFound := false isTokenFound := false
for _, fn := range findTokenFns { for _, fn := range findTokenFns {
@ -340,14 +365,18 @@ func isTokenInvalidated(r *http.Request) bool {
return !isTokenFound return !isTokenFound
} }
func invalidateToken(r *http.Request) { func invalidateToken(r *http.Request, isLoginToken bool) {
duration := tokenDuration
if isLoginToken {
duration = csrfTokenDuration
}
tokenString := jwtauth.TokenFromHeader(r) tokenString := jwtauth.TokenFromHeader(r)
if tokenString != "" { if tokenString != "" {
invalidatedJWTTokens.Add(tokenString, time.Now().Add(tokenDuration).UTC()) invalidatedJWTTokens.Add(tokenString, time.Now().Add(duration).UTC())
} }
tokenString = jwtauth.TokenFromCookie(r) tokenString = jwtauth.TokenFromCookie(r)
if tokenString != "" { if tokenString != "" {
invalidatedJWTTokens.Add(tokenString, time.Now().Add(tokenDuration).UTC()) invalidatedJWTTokens.Add(tokenString, time.Now().Add(duration).UTC())
} }
} }
@ -380,7 +409,22 @@ func getAdminFromToken(r *http.Request) *dataprovider.Admin {
return admin return admin
} }
func createCSRFToken(ip string) string { func createLoginCookie(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwtauth.JWTAuth, tokenID, basePath, ip string,
) {
c := jwtTokenClaims{
JwtID: tokenID,
}
resp, err := c.createTokenResponse(csrfTokenAuth, tokenAudienceWebLogin, ip)
if err != nil {
return
}
setCookie(w, r, basePath, resp["access_token"].(string), csrfTokenDuration)
}
func createCSRFToken(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwtauth.JWTAuth, tokenID,
basePath string,
) string {
ip := util.GetIPFromRemoteAddress(r.RemoteAddr)
claims := make(map[string]any) claims := make(map[string]any)
now := time.Now().UTC() now := time.Now().UTC()
@ -388,7 +432,16 @@ func createCSRFToken(ip string) string {
claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second) claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second)
claims[jwt.ExpirationKey] = now.Add(csrfTokenDuration) claims[jwt.ExpirationKey] = now.Add(csrfTokenDuration)
claims[jwt.AudienceKey] = []string{tokenAudienceCSRF, ip} claims[jwt.AudienceKey] = []string{tokenAudienceCSRF, ip}
if tokenID != "" {
createLoginCookie(w, r, csrfTokenAuth, tokenID, basePath, ip)
claims[claimRef] = tokenID
} else {
if c, err := getTokenClaims(r); err == nil {
claims[claimRef] = c.JwtID
} else {
logger.Error(logSender, "", "unable to add reference to CSRF token: %v", err)
}
}
_, tokenString, err := csrfTokenAuth.Encode(claims) _, tokenString, err := csrfTokenAuth.Encode(claims)
if err != nil { if err != nil {
logger.Debug(logSender, "", "unable to create CSRF token: %v", err) logger.Debug(logSender, "", "unable to create CSRF token: %v", err)
@ -397,7 +450,8 @@ func createCSRFToken(ip string) string {
return tokenString return tokenString
} }
func verifyCSRFToken(tokenString, ip string) error { func verifyCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAuth) error {
tokenString := r.Form.Get(csrfFormToken)
token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString) token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString)
if err != nil || token == nil { if err != nil || token == nil {
logger.Debug(logSender, "", "error validating CSRF token %q: %v", tokenString, err) logger.Debug(logSender, "", "error validating CSRF token %q: %v", tokenString, err)
@ -409,17 +463,60 @@ func verifyCSRFToken(tokenString, ip string) error {
return errors.New("the form token is not valid") return errors.New("the form token is not valid")
} }
if tokenValidationMode != tokenValidationNoIPMatch { if err := validateIPForToken(token, util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil {
if !util.Contains(token.Audience(), ip) { logger.Debug(logSender, "", "error validating CSRF token IP audience")
logger.Debug(logSender, "", "error validating CSRF token IP audience") return errors.New("the form token is not valid")
return errors.New("the form token is not valid") }
} claims, err := getTokenClaims(r)
if err != nil {
logger.Debug(logSender, "", "error getting token claims for CSRF validation: %v", err)
return err
}
ref, ok := token.Get(claimRef)
if !ok {
logger.Debug(logSender, "", "error validating CSRF token, missing reference")
return errors.New("the form token is not valid")
}
if claims.JwtID == "" || claims.JwtID != ref.(string) {
logger.Debug(logSender, "", "error validating CSRF reference, id %q, reference %q", claims.JwtID, ref)
return errors.New("unexpected form token")
} }
return nil return nil
} }
func createOAuth2Token(state, ip string) string { func verifyLoginCookie(r *http.Request) error {
token, _, err := jwtauth.FromContext(r.Context())
if err != nil || token == nil {
logger.Debug(logSender, "", "error getting login token: %v", err)
return errInvalidToken
}
if isTokenInvalidated(r) {
logger.Debug(logSender, "", "the login token has been invalidated")
return errInvalidToken
}
if !util.Contains(token.Audience(), tokenAudienceWebLogin) {
logger.Debug(logSender, "", "the token with id %q is not valid for audience %q", token.JwtID(), tokenAudienceWebLogin)
return errInvalidToken
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := validateIPForToken(token, ipAddr); err != nil {
return err
}
return nil
}
func verifyLoginCookieAndCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAuth) error {
if err := verifyLoginCookie(r); err != nil {
return err
}
if err := verifyCSRFToken(r, csrfTokenAuth); err != nil {
return err
}
return nil
}
func createOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, state, ip string) string {
claims := make(map[string]any) claims := make(map[string]any)
now := time.Now().UTC() now := time.Now().UTC()
@ -436,7 +533,7 @@ func createOAuth2Token(state, ip string) string {
return tokenString return tokenString
} }
func verifyOAuth2Token(tokenString, ip string) (string, error) { func verifyOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, tokenString, ip string) (string, error) {
token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString) token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString)
if err != nil || token == nil { if err != nil || token == nil {
logger.Debug(logSender, "", "error validating OAuth2 token %q: %v", tokenString, err) logger.Debug(logSender, "", "error validating OAuth2 token %q: %v", tokenString, err)
@ -451,11 +548,9 @@ func verifyOAuth2Token(tokenString, ip string) (string, error) {
return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState) return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState)
} }
if tokenValidationMode != tokenValidationNoIPMatch { if err := validateIPForToken(token, ip); err != nil {
if !util.Contains(token.Audience(), ip) { logger.Debug(logSender, "", "error validating OAuth2 token IP audience")
logger.Debug(logSender, "", "error validating OAuth2 token IP audience") return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState)
return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState)
}
} }
if val, ok := token.Get(jwt.JwtIDKey); ok { if val, ok := token.Get(jwt.JwtIDKey); ok {
if state, ok := val.(string); ok { if state, ok := val.(string); ok {
@ -465,3 +560,12 @@ func verifyOAuth2Token(tokenString, ip string) (string, error) {
logger.Debug(logSender, "", "jti not found in OAuth2 token") logger.Debug(logSender, "", "jti not found in OAuth2 token")
return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState) return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState)
} }
func validateIPForToken(token jwt.Token, ip string) error {
if tokenValidationMode != tokenValidationNoIPMatch {
if !util.Contains(token.Audience(), ip) {
return errInvalidToken
}
}
return nil
}

View file

@ -31,8 +31,6 @@ import (
"time" "time"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/drakkan/sftpgo/v2/internal/acme" "github.com/drakkan/sftpgo/v2/internal/acme"
"github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/common"
@ -196,7 +194,6 @@ var (
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
cleanupDone chan bool cleanupDone chan bool
invalidatedJWTTokens tokenManager invalidatedJWTTokens tokenManager
csrfTokenAuth *jwtauth.JWTAuth
webRootPath string webRootPath string
webBasePath string webBasePath string
webBaseAdminPath string webBaseAdminPath string
@ -967,7 +964,6 @@ func (c *Conf) Initialize(configDir string, isShared int) error {
c.SigningPassphrase = passphrase c.SigningPassphrase = passphrase
} }
csrfTokenAuth = jwtauth.New(jwa.HS256.String(), getSigningKey(c.SigningPassphrase), nil)
hideSupportLink = c.HideSupportLink hideSupportLink = c.HideSupportLink
exitChannel := make(chan error, 1) exitChannel := make(chan error, 1)

File diff suppressed because it is too large Load diff

View file

@ -412,6 +412,35 @@ func TestGCSWebInvalidFormFile(t *testing.T) {
assert.EqualError(t, err, http.ErrNotMultipart.Error()) assert.EqualError(t, err, http.ErrNotMultipart.Error())
} }
func TestVerifyCSRFToken(t *testing.T) {
server := httpdServer{}
server.initializeRouter()
req, err := http.NewRequest(http.MethodPost, webAdminEventActionPath, nil)
require.NoError(t, err)
req = req.WithContext(context.WithValue(req.Context(), jwtauth.ErrorCtxKey, fs.ErrPermission))
rr := httptest.NewRecorder()
tokenString := createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseAdminPath)
assert.NotEmpty(t, tokenString)
token, err := server.csrfTokenAuth.Decode(tokenString)
require.NoError(t, err)
_, ok := token.Get(claimRef)
assert.False(t, ok)
req.Form = url.Values{}
req.Form.Set(csrfFormToken, tokenString)
err = verifyCSRFToken(req, server.csrfTokenAuth)
assert.ErrorIs(t, err, fs.ErrPermission)
req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, nil)
require.NoError(t, err)
req.Form = url.Values{}
req.Form.Set(csrfFormToken, tokenString)
err = verifyCSRFToken(req, server.csrfTokenAuth)
assert.ErrorContains(t, err, "the form token is not valid")
}
func TestInvalidToken(t *testing.T) { func TestInvalidToken(t *testing.T) {
server := httpdServer{} server := httpdServer{}
server.initializeRouter() server.initializeRouter()
@ -923,13 +952,24 @@ func TestUpdateWebAdminInvalidClaims(t *testing.T) {
token, err := c.createTokenResponse(server.tokenAuth, tokenAudienceWebAdmin, "") token, err := c.createTokenResponse(server.tokenAuth, tokenAudienceWebAdmin, "")
assert.NoError(t, err) assert.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, webAdminPath, nil)
assert.NoError(t, err)
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"]))
parsedToken, err := jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx := req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
form := make(url.Values) form := make(url.Values)
form.Set(csrfFormToken, createCSRFToken("")) form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseAdminPath))
form.Set("status", "1") form.Set("status", "1")
form.Set("default_users_expiration", "30") form.Set("default_users_expiration", "30")
req, _ := http.NewRequest(http.MethodPost, path.Join(webAdminPath, "admin"), bytes.NewBuffer([]byte(form.Encode()))) req, err = http.NewRequest(http.MethodPost, path.Join(webAdminPath, "admin"), bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
rctx := chi.NewRouteContext() rctx := chi.NewRouteContext()
rctx.URLParams.Add("username", "admin") rctx.URLParams.Add("username", "admin")
req = req.WithContext(ctx)
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"]))
@ -1028,7 +1068,7 @@ func TestOAuth2Redirect(t *testing.T) {
assert.Contains(t, rr.Body.String(), util.I18nOAuth2ErrorTitle) assert.Contains(t, rr.Body.String(), util.I18nOAuth2ErrorTitle)
ip := "127.1.1.4" ip := "127.1.1.4"
tokenString := createOAuth2Token(xid.New().String(), ip) tokenString := createOAuth2Token(server.csrfTokenAuth, xid.New().String(), ip)
rr = httptest.NewRecorder() rr = httptest.NewRecorder()
req, err = http.NewRequest(http.MethodGet, webOAuth2RedirectPath+"?state="+tokenString, nil) //nolint:goconst req, err = http.NewRequest(http.MethodGet, webOAuth2RedirectPath+"?state="+tokenString, nil) //nolint:goconst
assert.NoError(t, err) assert.NoError(t, err)
@ -1039,8 +1079,10 @@ func TestOAuth2Redirect(t *testing.T) {
} }
func TestOAuth2Token(t *testing.T) { func TestOAuth2Token(t *testing.T) {
server := httpdServer{}
server.initializeRouter()
// invalid token // invalid token
_, err := verifyOAuth2Token("token", "") _, err := verifyOAuth2Token(server.csrfTokenAuth, "token", "")
if assert.Error(t, err) { if assert.Error(t, err) {
assert.Contains(t, err.Error(), "unable to verify OAuth2 state") assert.Contains(t, err.Error(), "unable to verify OAuth2 state")
} }
@ -1053,22 +1095,22 @@ func TestOAuth2Token(t *testing.T) {
claims[jwt.ExpirationKey] = now.Add(tokenDuration) claims[jwt.ExpirationKey] = now.Add(tokenDuration)
claims[jwt.AudienceKey] = []string{tokenAudienceAPI} claims[jwt.AudienceKey] = []string{tokenAudienceAPI}
_, tokenString, err := csrfTokenAuth.Encode(claims) _, tokenString, err := server.csrfTokenAuth.Encode(claims)
assert.NoError(t, err) assert.NoError(t, err)
_, err = verifyOAuth2Token(tokenString, "") _, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "")
if assert.Error(t, err) { if assert.Error(t, err) {
assert.Contains(t, err.Error(), "invalid OAuth2 state") assert.Contains(t, err.Error(), "invalid OAuth2 state")
} }
// bad IP // bad IP
tokenString = createOAuth2Token("state", "127.1.1.1") tokenString = createOAuth2Token(server.csrfTokenAuth, "state", "127.1.1.1")
_, err = verifyOAuth2Token(tokenString, "127.1.1.2") _, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "127.1.1.2")
if assert.Error(t, err) { if assert.Error(t, err) {
assert.Contains(t, err.Error(), "invalid OAuth2 state") assert.Contains(t, err.Error(), "invalid OAuth2 state")
} }
// ok // ok
state := xid.New().String() state := xid.New().String()
tokenString = createOAuth2Token(state, "127.1.1.3") tokenString = createOAuth2Token(server.csrfTokenAuth, state, "127.1.1.3")
s, err := verifyOAuth2Token(tokenString, "127.1.1.3") s, err := verifyOAuth2Token(server.csrfTokenAuth, tokenString, "127.1.1.3")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, state, s) assert.Equal(t, state, s)
// no jti // no jti
@ -1077,19 +1119,17 @@ func TestOAuth2Token(t *testing.T) {
claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second) claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second)
claims[jwt.ExpirationKey] = now.Add(tokenDuration) claims[jwt.ExpirationKey] = now.Add(tokenDuration)
claims[jwt.AudienceKey] = []string{tokenAudienceOAuth2, "127.1.1.4"} claims[jwt.AudienceKey] = []string{tokenAudienceOAuth2, "127.1.1.4"}
_, tokenString, err = csrfTokenAuth.Encode(claims) _, tokenString, err = server.csrfTokenAuth.Encode(claims)
assert.NoError(t, err) assert.NoError(t, err)
_, err = verifyOAuth2Token(tokenString, "127.1.1.4") _, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "127.1.1.4")
if assert.Error(t, err) { if assert.Error(t, err) {
assert.Contains(t, err.Error(), "invalid OAuth2 state") assert.Contains(t, err.Error(), "invalid OAuth2 state")
} }
// encode error // encode error
csrfTokenAuth = jwtauth.New("HT256", util.GenerateRandomBytes(32), nil) server.csrfTokenAuth = jwtauth.New("HT256", util.GenerateRandomBytes(32), nil)
tokenString = createOAuth2Token(xid.New().String(), "") tokenString = createOAuth2Token(server.csrfTokenAuth, xid.New().String(), "")
assert.Empty(t, tokenString) assert.Empty(t, tokenString)
server := httpdServer{}
server.initializeRouter()
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
testReq := make(map[string]any) testReq := make(map[string]any)
testReq["base_redirect_url"] = "http://localhost:8082" testReq["base_redirect_url"] = "http://localhost:8082"
@ -1097,16 +1137,17 @@ func TestOAuth2Token(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
req, err := http.NewRequest(http.MethodPost, webOAuth2TokenPath, bytes.NewBuffer(asJSON)) req, err := http.NewRequest(http.MethodPost, webOAuth2TokenPath, bytes.NewBuffer(asJSON))
assert.NoError(t, err) assert.NoError(t, err)
handleSMTPOAuth2TokenRequestPost(rr, req) server.handleSMTPOAuth2TokenRequestPost(rr, req)
assert.Equal(t, http.StatusInternalServerError, rr.Code) assert.Equal(t, http.StatusInternalServerError, rr.Code)
assert.Contains(t, rr.Body.String(), "unable to create state token") assert.Contains(t, rr.Body.String(), "unable to create state token")
csrfTokenAuth = jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil)
} }
func TestCSRFToken(t *testing.T) { func TestCSRFToken(t *testing.T) {
server := httpdServer{}
server.initializeRouter()
// invalid token // invalid token
err := verifyCSRFToken("token", "") req := &http.Request{}
err := verifyCSRFToken(req, server.csrfTokenAuth)
if assert.Error(t, err) { if assert.Error(t, err) {
assert.Contains(t, err.Error(), "unable to verify form token") assert.Contains(t, err.Error(), "unable to verify form token")
} }
@ -1119,16 +1160,23 @@ func TestCSRFToken(t *testing.T) {
claims[jwt.ExpirationKey] = now.Add(tokenDuration) claims[jwt.ExpirationKey] = now.Add(tokenDuration)
claims[jwt.AudienceKey] = []string{tokenAudienceAPI} claims[jwt.AudienceKey] = []string{tokenAudienceAPI}
_, tokenString, err := csrfTokenAuth.Encode(claims) _, tokenString, err := server.csrfTokenAuth.Encode(claims)
assert.NoError(t, err) assert.NoError(t, err)
err = verifyCSRFToken(tokenString, "") values := url.Values{}
values.Set(csrfFormToken, tokenString)
req.Form = values
err = verifyCSRFToken(req, server.csrfTokenAuth)
if assert.Error(t, err) { if assert.Error(t, err) {
assert.Contains(t, err.Error(), "form token is not valid") assert.Contains(t, err.Error(), "form token is not valid")
} }
// bad IP // bad IP
tokenString = createCSRFToken("127.1.1.1") req.RemoteAddr = "127.1.1.1"
err = verifyCSRFToken(tokenString, "127.1.1.2") tokenString = createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)
values.Set(csrfFormToken, tokenString)
req.Form = values
req.RemoteAddr = "127.1.1.2"
err = verifyCSRFToken(req, server.csrfTokenAuth)
if assert.Error(t, err) { if assert.Error(t, err) {
assert.Contains(t, err.Error(), "form token is not valid") assert.Contains(t, err.Error(), "form token is not valid")
} }
@ -1137,8 +1185,9 @@ func TestCSRFToken(t *testing.T) {
claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second) claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second)
claims[jwt.ExpirationKey] = now.Add(tokenDuration) claims[jwt.ExpirationKey] = now.Add(tokenDuration)
claims[jwt.AudienceKey] = []string{tokenAudienceAPI} claims[jwt.AudienceKey] = []string{tokenAudienceAPI}
_, tokenString, err = csrfTokenAuth.Encode(claims) _, tokenString, err = server.csrfTokenAuth.Encode(claims)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEmpty(t, tokenString)
r := GetHTTPRouter(Binding{ r := GetHTTPRouter(Binding{
Address: "", Address: "",
@ -1148,9 +1197,9 @@ func TestCSRFToken(t *testing.T) {
EnableRESTAPI: true, EnableRESTAPI: true,
RenderOpenAPI: true, RenderOpenAPI: true,
}) })
fn := verifyCSRFHeader(r) fn := server.verifyCSRFHeader(r)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodDelete, path.Join(userPath, "username"), nil) req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, "username"), nil)
fn.ServeHTTP(rr, req) fn.ServeHTTP(rr, req)
assert.Equal(t, http.StatusForbidden, rr.Code) assert.Equal(t, http.StatusForbidden, rr.Code)
assert.Contains(t, rr.Body.String(), "Invalid token") assert.Contains(t, rr.Body.String(), "Invalid token")
@ -1163,18 +1212,20 @@ func TestCSRFToken(t *testing.T) {
assert.Contains(t, rr.Body.String(), "the token is not valid") assert.Contains(t, rr.Body.String(), "the token is not valid")
// invalid IP // invalid IP
tokenString = createCSRFToken("172.16.1.2") tokenString = createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)
req.Header.Set(csrfHeaderToken, tokenString) req.Header.Set(csrfHeaderToken, tokenString)
req.RemoteAddr = "172.16.1.2"
rr = httptest.NewRecorder() rr = httptest.NewRecorder()
fn.ServeHTTP(rr, req) fn.ServeHTTP(rr, req)
assert.Equal(t, http.StatusForbidden, rr.Code) assert.Equal(t, http.StatusForbidden, rr.Code)
assert.Contains(t, rr.Body.String(), "the token is not valid") assert.Contains(t, rr.Body.String(), "the token is not valid")
csrfTokenAuth = jwtauth.New("PS256", util.GenerateRandomBytes(32), nil) csrfTokenAuth := jwtauth.New("PS256", util.GenerateRandomBytes(32), nil)
tokenString = createCSRFToken("") tokenString = createCSRFToken(httptest.NewRecorder(), req, csrfTokenAuth, "", webBaseAdminPath)
assert.Empty(t, tokenString) assert.Empty(t, tokenString)
rr = httptest.NewRecorder()
csrfTokenAuth = jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil) createLoginCookie(rr, req, csrfTokenAuth, "", webBaseAdminPath, req.RemoteAddr)
assert.Empty(t, rr.Header().Get("Set-Cookie"))
} }
func TestCreateShareCookieError(t *testing.T) { func TestCreateShareCookieError(t *testing.T) {
@ -1205,19 +1256,38 @@ func TestCreateShareCookieError(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
server := httpdServer{ server := httpdServer{
tokenAuth: jwtauth.New("TS256", util.GenerateRandomBytes(32), nil), tokenAuth: jwtauth.New("TS256", util.GenerateRandomBytes(32), nil),
csrfTokenAuth: jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil),
} }
c := jwtTokenClaims{
JwtID: xid.New().String(),
}
resp, err := c.createTokenResponse(server.csrfTokenAuth, tokenAudienceWebLogin, "127.0.0.1")
assert.NoError(t, err)
parsedToken, err := jwtauth.VerifyToken(server.csrfTokenAuth, resp["access_token"].(string))
assert.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, share.ShareID, "login"), nil)
assert.NoError(t, err)
req.RemoteAddr = "127.0.0.1:4567"
ctx := req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
form := make(url.Values) form := make(url.Values)
form.Set("share_password", pwd) form.Set("share_password", pwd)
form.Set(csrfFormToken, createCSRFToken("127.0.0.1")) form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseClientPath))
rctx := chi.NewRouteContext() rctx := chi.NewRouteContext()
rctx.URLParams.Add("id", share.ShareID) rctx.URLParams.Add("id", share.ShareID)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, share.ShareID, "login"), req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, share.ShareID, "login"),
bytes.NewBuffer([]byte(form.Encode()))) bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err) assert.NoError(t, err)
req.RemoteAddr = "127.0.0.1:2345" req.RemoteAddr = "127.0.0.1:2345"
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", resp["access_token"]))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req = req.WithContext(ctx)
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
server.handleClientShareLoginPost(rr, req) server.handleClientShareLoginPost(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
@ -1229,7 +1299,8 @@ func TestCreateShareCookieError(t *testing.T) {
func TestCreateTokenError(t *testing.T) { func TestCreateTokenError(t *testing.T) {
server := httpdServer{ server := httpdServer{
tokenAuth: jwtauth.New("PS256", util.GenerateRandomBytes(32), nil), tokenAuth: jwtauth.New("PS256", util.GenerateRandomBytes(32), nil),
csrfTokenAuth: jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil),
} }
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
admin := dataprovider.Admin{ admin := dataprovider.Admin{
@ -1253,14 +1324,36 @@ func TestCreateTokenError(t *testing.T) {
server.generateAndSendUserToken(rr, req, "", user) server.generateAndSendUserToken(rr, req, "", user)
assert.Equal(t, http.StatusInternalServerError, rr.Code) assert.Equal(t, http.StatusInternalServerError, rr.Code)
c := jwtTokenClaims{
JwtID: xid.New().String(),
}
token, err := c.createTokenResponse(server.csrfTokenAuth, tokenAudienceWebLogin, "")
assert.NoError(t, err)
req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
assert.NoError(t, err)
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"]))
parsedToken, err := jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx := req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
rr = httptest.NewRecorder() rr = httptest.NewRecorder()
form := make(url.Values) form := make(url.Values)
form.Set("username", admin.Username) form.Set("username", admin.Username)
form.Set("password", admin.Password) form.Set("password", admin.Password)
form.Set(csrfFormToken, createCSRFToken("127.0.0.1")) form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, xid.New().String(), webBaseAdminPath))
cookie := rr.Header().Get("Set-Cookie")
assert.NotEmpty(t, cookie)
req, _ = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req, _ = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode())))
req.RemoteAddr = "127.0.0.1:1234" req.Header.Set("Cookie", cookie)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx = req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
server.handleWebAdminLoginPost(rr, req) server.handleWebAdminLoginPost(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
// req with no content type // req with no content type
@ -1287,7 +1380,7 @@ func TestCreateTokenError(t *testing.T) {
req, _ = http.NewRequest(http.MethodGet, webAdminLoginPath+"?a=a%C3%A2%G3", nil) req, _ = http.NewRequest(http.MethodGet, webAdminLoginPath+"?a=a%C3%A2%G3", nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
_, err := getAdminFromPostFields(req) _, err = getAdminFromPostFields(req)
assert.Error(t, err) assert.Error(t, err)
req, _ = http.NewRequest(http.MethodPost, webAdminEventActionPath+"?a=a%C3%A2%GG", nil) req, _ = http.NewRequest(http.MethodPost, webAdminEventActionPath+"?a=a%C3%A2%GG", nil)
@ -1421,13 +1514,21 @@ func TestCreateTokenError(t *testing.T) {
err = dataprovider.AddUser(&user, "", "", "") err = dataprovider.AddUser(&user, "", "", "")
assert.NoError(t, err) assert.NoError(t, err)
req, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil)
assert.NoError(t, err)
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"]))
parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx = req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
rr = httptest.NewRecorder() rr = httptest.NewRecorder()
form = make(url.Values) form = make(url.Values)
form.Set("username", user.Username) form.Set("username", user.Username)
form.Set("password", "clientpwd") form.Set("password", "clientpwd")
form.Set(csrfFormToken, createCSRFToken("127.0.0.1")) form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseClientPath))
req, _ = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req, _ = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode())))
req.RemoteAddr = "127.0.0.1:4567"
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
server.handleWebClientLoginPost(rr, req) server.handleWebClientLoginPost(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
@ -1616,6 +1717,7 @@ func TestCookieExpiration(t *testing.T) {
claims = make(map[string]any) claims = make(map[string]any)
claims[claimUsernameKey] = admin.Username claims[claimUsernameKey] = admin.Username
claims[claimPermissionsKey] = admin.Permissions claims[claimPermissionsKey] = admin.Permissions
claims[jwt.JwtIDKey] = xid.New().String()
claims[jwt.SubjectKey] = admin.GetSignature() claims[jwt.SubjectKey] = admin.GetSignature()
claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute) claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute)
claims[jwt.AudienceKey] = []string{tokenAudienceAPI} claims[jwt.AudienceKey] = []string{tokenAudienceAPI}
@ -1648,9 +1750,11 @@ func TestCookieExpiration(t *testing.T) {
admin, err = dataprovider.AdminExists(admin.Username) admin, err = dataprovider.AdminExists(admin.Username)
assert.NoError(t, err) assert.NoError(t, err)
tokenID := xid.New().String()
claims = make(map[string]any) claims = make(map[string]any)
claims[claimUsernameKey] = admin.Username claims[claimUsernameKey] = admin.Username
claims[claimPermissionsKey] = admin.Permissions claims[claimPermissionsKey] = admin.Permissions
claims[jwt.JwtIDKey] = tokenID
claims[jwt.SubjectKey] = admin.GetSignature() claims[jwt.SubjectKey] = admin.GetSignature()
claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute) claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute)
claims[jwt.AudienceKey] = []string{tokenAudienceAPI} claims[jwt.AudienceKey] = []string{tokenAudienceAPI}
@ -1669,6 +1773,11 @@ func TestCookieExpiration(t *testing.T) {
server.checkCookieExpiration(rr, req.WithContext(ctx)) server.checkCookieExpiration(rr, req.WithContext(ctx))
cookie = rr.Header().Get("Set-Cookie") cookie = rr.Header().Get("Set-Cookie")
assert.True(t, strings.HasPrefix(cookie, "jwt=")) assert.True(t, strings.HasPrefix(cookie, "jwt="))
req.Header.Set("Cookie", cookie)
token, err = jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie)
if assert.NoError(t, err) {
assert.Equal(t, tokenID, token.JwtID())
}
err = dataprovider.DeleteAdmin(admin.Username, "", "", "") err = dataprovider.DeleteAdmin(admin.Username, "", "", "")
assert.NoError(t, err) assert.NoError(t, err)
@ -1689,6 +1798,7 @@ func TestCookieExpiration(t *testing.T) {
claims = make(map[string]any) claims = make(map[string]any)
claims[claimUsernameKey] = user.Username claims[claimUsernameKey] = user.Username
claims[claimPermissionsKey] = user.Filters.WebClient claims[claimPermissionsKey] = user.Filters.WebClient
claims[jwt.JwtIDKey] = tokenID
claims[jwt.SubjectKey] = user.GetSignature() claims[jwt.SubjectKey] = user.GetSignature()
claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute) claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute)
claims[jwt.AudienceKey] = []string{tokenAudienceWebClient} claims[jwt.AudienceKey] = []string{tokenAudienceWebClient}
@ -1721,6 +1831,7 @@ func TestCookieExpiration(t *testing.T) {
claims = make(map[string]any) claims = make(map[string]any)
claims[claimUsernameKey] = user.Username claims[claimUsernameKey] = user.Username
claims[claimPermissionsKey] = user.Filters.WebClient claims[claimPermissionsKey] = user.Filters.WebClient
claims[jwt.JwtIDKey] = tokenID
claims[jwt.SubjectKey] = user.GetSignature() claims[jwt.SubjectKey] = user.GetSignature()
claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute) claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute)
claims[jwt.AudienceKey] = []string{tokenAudienceWebClient} claims[jwt.AudienceKey] = []string{tokenAudienceWebClient}
@ -1740,6 +1851,35 @@ func TestCookieExpiration(t *testing.T) {
server.checkCookieExpiration(rr, req.WithContext(ctx)) server.checkCookieExpiration(rr, req.WithContext(ctx))
cookie = rr.Header().Get("Set-Cookie") cookie = rr.Header().Get("Set-Cookie")
assert.NotEmpty(t, cookie) assert.NotEmpty(t, cookie)
req.Header.Set("Cookie", cookie)
token, err = jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie)
if assert.NoError(t, err) {
assert.Equal(t, tokenID, token.JwtID())
}
// test a disabled user
user.Status = 0
err = dataprovider.UpdateUser(&user, "", "", "")
assert.NoError(t, err)
user, err = dataprovider.UserExists(user.Username, "")
assert.NoError(t, err)
claims = make(map[string]any)
claims[claimUsernameKey] = user.Username
claims[claimPermissionsKey] = user.Filters.WebClient
claims[jwt.JwtIDKey] = tokenID
claims[jwt.SubjectKey] = user.GetSignature()
claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute)
claims[jwt.AudienceKey] = []string{tokenAudienceWebClient}
token, _, err = server.tokenAuth.Encode(claims)
assert.NoError(t, err)
rr = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
ctx = jwtauth.NewContext(req.Context(), token, nil)
server.checkCookieExpiration(rr, req.WithContext(ctx))
cookie = rr.Header().Get("Set-Cookie")
assert.Empty(t, cookie)
err = dataprovider.DeleteUser(user.Username, "", "", "") err = dataprovider.DeleteUser(user.Username, "", "", "")
assert.NoError(t, err) assert.NoError(t, err)
@ -2104,34 +2244,95 @@ func TestProxyHeaders(t *testing.T) {
testServer.Config.Handler.ServeHTTP(rr, req) testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code) assert.Equal(t, http.StatusOK, rr.Code)
req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
assert.NoError(t, err)
req.RemoteAddr = testIP
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
cookie := rr.Header().Get("Set-Cookie")
assert.NotEmpty(t, cookie)
req.Header.Set("Cookie", cookie)
parsedToken, err := jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx := req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
form := make(url.Values) form := make(url.Values)
form.Set("username", username) form.Set("username", username)
form.Set("password", password) form.Set("password", password)
form.Set(csrfFormToken, createCSRFToken(testIP)) form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath))
req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err) assert.NoError(t, err)
req.RemoteAddr = testIP req.RemoteAddr = testIP
req.Header.Set("Cookie", cookie)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rr = httptest.NewRecorder() rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req) testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials)
form.Set(csrfFormToken, createCSRFToken(validForwardedFor)) req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
assert.NoError(t, err)
req.RemoteAddr = validForwardedFor
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
loginCookie := rr.Header().Get("Set-Cookie")
assert.NotEmpty(t, loginCookie)
req.Header.Set("Cookie", loginCookie)
parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx = req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath))
req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err) assert.NoError(t, err)
req.RemoteAddr = testIP req.RemoteAddr = testIP
req.Header.Set("Cookie", loginCookie)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("X-Forwarded-For", validForwardedFor) req.Header.Set("X-Forwarded-For", validForwardedFor)
rr = httptest.NewRecorder() rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req) testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String()) assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String())
cookie := rr.Header().Get("Set-Cookie") cookie = rr.Header().Get("Set-Cookie")
assert.NotContains(t, cookie, "Secure") assert.NotContains(t, cookie, "Secure")
// The login cookie is invalidated after a successful login, the same request will fail
req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err) assert.NoError(t, err)
req.RemoteAddr = testIP req.RemoteAddr = testIP
req.Header.Set("Cookie", loginCookie)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("X-Forwarded-For", validForwardedFor)
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
assert.NoError(t, err)
req.RemoteAddr = validForwardedFor
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
loginCookie = rr.Header().Get("Set-Cookie")
assert.NotEmpty(t, loginCookie)
req.Header.Set("Cookie", loginCookie)
parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx = req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath))
req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
req.RemoteAddr = testIP
req.Header.Set("Cookie", loginCookie)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("X-Forwarded-For", validForwardedFor) req.Header.Set("X-Forwarded-For", validForwardedFor)
req.Header.Set(xForwardedProto, "https") req.Header.Set(xForwardedProto, "https")
@ -2141,9 +2342,26 @@ func TestProxyHeaders(t *testing.T) {
cookie = rr.Header().Get("Set-Cookie") cookie = rr.Header().Get("Set-Cookie")
assert.Contains(t, cookie, "Secure") assert.Contains(t, cookie, "Secure")
req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
assert.NoError(t, err)
req.RemoteAddr = validForwardedFor
rr = httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String())
loginCookie = rr.Header().Get("Set-Cookie")
assert.NotEmpty(t, loginCookie)
req.Header.Set("Cookie", loginCookie)
parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx = req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath))
req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err) assert.NoError(t, err)
req.RemoteAddr = testIP req.RemoteAddr = testIP
req.Header.Set("Cookie", loginCookie)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("X-Forwarded-For", validForwardedFor) req.Header.Set("X-Forwarded-For", validForwardedFor)
req.Header.Set(xForwardedProto, "http") req.Header.Set(xForwardedProto, "http")
@ -2715,10 +2933,22 @@ func TestInvalidClaims(t *testing.T) {
} }
token, err := c.createTokenResponse(server.tokenAuth, tokenAudienceWebClient, "") token, err := c.createTokenResponse(server.tokenAuth, tokenAudienceWebClient, "")
assert.NoError(t, err) assert.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, webClientProfilePath, nil)
assert.NoError(t, err)
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"]))
parsedToken, err := jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx := req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
form := make(url.Values) form := make(url.Values)
form.Set(csrfFormToken, createCSRFToken("")) form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseClientPath))
form.Set("public_keys", "") form.Set("public_keys", "")
req, _ := http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) req, err = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"]))
server.handleWebClientProfilePost(rr, req) server.handleWebClientProfilePost(rr, req)
@ -2735,14 +2965,27 @@ func TestInvalidClaims(t *testing.T) {
} }
token, err = c.createTokenResponse(server.tokenAuth, tokenAudienceWebAdmin, "") token, err = c.createTokenResponse(server.tokenAuth, tokenAudienceWebAdmin, "")
assert.NoError(t, err) assert.NoError(t, err)
req, err = http.NewRequest(http.MethodGet, webAdminProfilePath, nil)
assert.NoError(t, err)
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"]))
parsedToken, err = jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx = req.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
req = req.WithContext(ctx)
form = make(url.Values) form = make(url.Values)
form.Set(csrfFormToken, createCSRFToken("")) form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseAdminPath))
form.Set("allow_api_key_auth", "") form.Set("allow_api_key_auth", "")
req, _ = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode()))) req, err = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err)
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"]))
server.handleWebAdminProfilePost(rr, req) server.handleWebAdminProfilePost(rr, req)
assert.Equal(t, http.StatusForbidden, rr.Code) assert.Equal(t, http.StatusForbidden, rr.Code)
assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken)
} }
func TestTLSReq(t *testing.T) { func TestTLSReq(t *testing.T) {
@ -3041,24 +3284,31 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) {
} }
server.initializeRouter() server.initializeRouter()
rr := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, webAdminSetupPath, nil)
assert.NoError(t, err)
server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusOK, rr.Code)
for _, webURL := range []string{"/", webBasePath, webBaseAdminPath, webAdminLoginPath, webClientLoginPath} { for _, webURL := range []string{"/", webBasePath, webBaseAdminPath, webAdminLoginPath, webClientLoginPath} {
rr = httptest.NewRecorder() rr := httptest.NewRecorder()
r, err = http.NewRequest(http.MethodGet, webURL, nil) r, err := http.NewRequest(http.MethodGet, webURL, nil)
assert.NoError(t, err) assert.NoError(t, err)
server.router.ServeHTTP(rr, r) server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, http.StatusFound, rr.Code)
assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location"))
} }
rr := httptest.NewRecorder()
r, err := http.NewRequest(http.MethodGet, webAdminSetupPath, nil)
assert.NoError(t, err)
server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusOK, rr.Code)
cookie := rr.Header().Get("Set-Cookie")
r.Header.Set("Cookie", cookie)
parsedToken, err := jwtauth.VerifyRequest(server.csrfTokenAuth, r, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx := r.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
r = r.WithContext(ctx)
form := make(url.Values) form := make(url.Values)
csrfToken := createCSRFToken("") csrfToken := createCSRFToken(rr, r, server.csrfTokenAuth, "", webBaseAdminPath)
form.Set("_form_token", csrfToken) form.Set(csrfFormToken, csrfToken)
form.Set("install_code", installationCode+"5") form.Set("install_code", installationCode+"5")
form.Set("username", defaultAdminUsername) form.Set("username", defaultAdminUsername)
form.Set("password", "password") form.Set("password", "password")
@ -3066,6 +3316,8 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) {
rr = httptest.NewRecorder() rr = httptest.NewRecorder()
r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err) assert.NoError(t, err)
r = r.WithContext(ctx)
r.Header.Set("Cookie", cookie)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded") r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
server.router.ServeHTTP(rr, r) server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusOK, rr.Code) assert.Equal(t, http.StatusOK, rr.Code)
@ -3077,6 +3329,8 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) {
rr = httptest.NewRecorder() rr = httptest.NewRecorder()
r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err) assert.NoError(t, err)
r = r.WithContext(ctx)
r.Header.Set("Cookie", cookie)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded") r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
server.router.ServeHTTP(rr, r) server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, http.StatusFound, rr.Code)
@ -3098,12 +3352,6 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) {
return "5678" return "5678"
}) })
rr = httptest.NewRecorder()
r, err = http.NewRequest(http.MethodGet, webAdminSetupPath, nil)
assert.NoError(t, err)
server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusOK, rr.Code)
for _, webURL := range []string{"/", webBasePath, webBaseAdminPath, webAdminLoginPath, webClientLoginPath} { for _, webURL := range []string{"/", webBasePath, webBaseAdminPath, webAdminLoginPath, webClientLoginPath} {
rr = httptest.NewRecorder() rr = httptest.NewRecorder()
r, err = http.NewRequest(http.MethodGet, webURL, nil) r, err = http.NewRequest(http.MethodGet, webURL, nil)
@ -3113,9 +3361,22 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) {
assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location"))
} }
rr = httptest.NewRecorder()
r, err = http.NewRequest(http.MethodGet, webAdminSetupPath, nil)
assert.NoError(t, err)
server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusOK, rr.Code)
cookie = rr.Header().Get("Set-Cookie")
r.Header.Set("Cookie", cookie)
parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, r, jwtauth.TokenFromCookie)
assert.NoError(t, err)
ctx = r.Context()
ctx = jwtauth.NewContext(ctx, parsedToken, err)
r = r.WithContext(ctx)
form = make(url.Values) form = make(url.Values)
csrfToken = createCSRFToken("") csrfToken = createCSRFToken(rr, r, server.csrfTokenAuth, "", webBaseAdminPath)
form.Set("_form_token", csrfToken) form.Set(csrfFormToken, csrfToken)
form.Set("install_code", installationCode) form.Set("install_code", installationCode)
form.Set("username", defaultAdminUsername) form.Set("username", defaultAdminUsername)
form.Set("password", "password") form.Set("password", "password")
@ -3123,6 +3384,8 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) {
rr = httptest.NewRecorder() rr = httptest.NewRecorder()
r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err) assert.NoError(t, err)
r = r.WithContext(ctx)
r.Header.Set("Cookie", cookie)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded") r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
server.router.ServeHTTP(rr, r) server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusOK, rr.Code) assert.Equal(t, http.StatusOK, rr.Code)
@ -3134,6 +3397,8 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) {
rr = httptest.NewRecorder() rr = httptest.NewRecorder()
r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode())))
assert.NoError(t, err) assert.NoError(t, err)
r = r.WithContext(ctx)
r.Header.Set("Cookie", cookie)
r.Header.Set("Content-Type", "application/x-www-form-urlencoded") r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
server.router.ServeHTTP(rr, r) server.router.ServeHTTP(rr, r)
assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, http.StatusFound, rr.Code)
@ -3199,6 +3464,7 @@ func TestDecodeToken(t *testing.T) {
claimNodeID: nodeID, claimNodeID: nodeID,
claimMustChangePasswordKey: false, claimMustChangePasswordKey: false,
claimMustSetSecondFactorKey: true, claimMustSetSecondFactorKey: true,
claimRef: "ref",
} }
c := jwtTokenClaims{} c := jwtTokenClaims{}
c.Decode(token) c.Decode(token)
@ -3206,6 +3472,11 @@ func TestDecodeToken(t *testing.T) {
assert.Equal(t, nodeID, c.NodeID) assert.Equal(t, nodeID, c.NodeID)
assert.False(t, c.MustChangePassword) assert.False(t, c.MustChangePassword)
assert.True(t, c.MustSetTwoFactorAuth) assert.True(t, c.MustSetTwoFactorAuth)
assert.Equal(t, "ref", c.Ref)
asMap := c.asMap()
asMap[claimMustChangePasswordKey] = false
assert.Equal(t, token, asMap)
token[claimMustChangePasswordKey] = 10 token[claimMustChangePasswordKey] = 10
c = jwtTokenClaims{} c = jwtTokenClaims{}

View file

@ -95,13 +95,11 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi
doRedirect("Your token audience is not valid", nil) doRedirect("Your token audience is not valid", nil)
return errInvalidToken return errInvalidToken
} }
if tokenValidationMode != tokenValidationNoIPMatch { ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := validateIPForToken(token, ipAddr); err != nil {
if !util.Contains(token.Audience(), ipAddr) { logger.Debug(logSender, "", "the token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr)
logger.Debug(logSender, "", "the token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr) doRedirect("Your token is not valid", nil)
doRedirect("Your token is not valid", nil) return err
return errInvalidToken
}
} }
return nil return nil
} }
@ -123,10 +121,16 @@ func (s *httpdServer) validateJWTPartialToken(w http.ResponseWriter, r *http.Req
return errInvalidToken return errInvalidToken
} }
if !util.Contains(token.Audience(), audience) { if !util.Contains(token.Audience(), audience) {
logger.Debug(logSender, "", "the token is not valid for audience %q", audience) logger.Debug(logSender, "", "the partial token with id %q is not valid for audience %q", token.JwtID(), audience)
notFoundFunc(w, r, nil) notFoundFunc(w, r, nil)
return errInvalidToken return errInvalidToken
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := validateIPForToken(token, ipAddr); err != nil {
logger.Debug(logSender, "", "the partial token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr)
notFoundFunc(w, r, nil)
return err
}
return nil return nil
} }
@ -324,10 +328,10 @@ func (s *httpdServer) checkPerm(perm string) func(next http.Handler) http.Handle
} }
} }
func verifyCSRFHeader(next http.Handler) http.Handler { func (s *httpdServer) verifyCSRFHeader(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenString := r.Header.Get(csrfHeaderToken) tokenString := r.Header.Get(csrfHeaderToken)
token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString) token, err := jwtauth.VerifyToken(s.csrfTokenAuth, tokenString)
if err != nil || token == nil { if err != nil || token == nil {
logger.Debug(logSender, "", "error validating CSRF header: %v", err) logger.Debug(logSender, "", "error validating CSRF header: %v", err)
sendAPIResponse(w, r, err, "Invalid token", http.StatusForbidden) sendAPIResponse(w, r, err, "Invalid token", http.StatusForbidden)
@ -340,12 +344,10 @@ func verifyCSRFHeader(next http.Handler) http.Handler {
return return
} }
if tokenValidationMode != tokenValidationNoIPMatch { if err := validateIPForToken(token, util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil {
if !util.Contains(token.Audience(), util.GetIPFromRemoteAddress(r.RemoteAddr)) { logger.Debug(logSender, "", "error validating CSRF header IP audience")
logger.Debug(logSender, "", "error validating CSRF header IP audience") sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden)
sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden) return
return
}
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)

View file

@ -541,6 +541,7 @@ func (s *httpdServer) oidcTokenAuthenticator(audience tokenAudience) func(next h
return return
} }
jwtTokenClaims := jwtTokenClaims{ jwtTokenClaims := jwtTokenClaims{
JwtID: token.Cookie,
Username: token.Username, Username: token.Username,
Permissions: token.Permissions, Permissions: token.Permissions,
Role: token.TokenRole, Role: token.TokenRole,
@ -594,6 +595,7 @@ func (s *httpdServer) handleOIDCRedirect(w http.ResponseWriter, r *http.Request)
authReq, err := oidcMgr.getPendingAuth(state) authReq, err := oidcMgr.getPendingAuth(state)
if err != nil { if err != nil {
logger.Debug(logSender, "", "oidc authentication state did not match") logger.Debug(logSender, "", "oidc authentication state did not match")
oidcMgr.removePendingAuth(state)
s.renderClientMessagePage(w, r, util.I18nInvalidAuthReqTitle, http.StatusBadRequest, s.renderClientMessagePage(w, r, util.I18nInvalidAuthReqTitle, http.StatusBadRequest,
util.NewI18nError(err, util.I18nInvalidAuth), "") util.NewI18nError(err, util.I18nInvalidAuth), "")
return return

View file

@ -33,7 +33,6 @@ import (
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/go-chi/jwtauth/v5" "github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/rs/xid" "github.com/rs/xid"
"github.com/sftpgo/sdk" "github.com/sftpgo/sdk"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -1584,12 +1583,9 @@ func TestOIDCWithLoginFormsDisabled(t *testing.T) {
tokenCookie = k tokenCookie = k
} }
// we should be able to create admins without setting a password // we should be able to create admins without setting a password
if csrfTokenAuth == nil {
csrfTokenAuth = jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil)
}
adminUsername := "testAdmin" adminUsername := "testAdmin"
form := make(url.Values) form := make(url.Values)
form.Set(csrfFormToken, createCSRFToken("")) form.Set(csrfFormToken, createCSRFToken(rr, r, server.csrfTokenAuth, tokenCookie, webBaseAdminPath))
form.Set("username", adminUsername) form.Set("username", adminUsername)
form.Set("password", "") form.Set("password", "")
form.Set("status", "1") form.Set("status", "1")

View file

@ -68,6 +68,7 @@ type httpdServer struct {
isShared int isShared int
router *chi.Mux router *chi.Mux
tokenAuth *jwtauth.JWTAuth tokenAuth *jwtauth.JWTAuth
csrfTokenAuth *jwtauth.JWTAuth
signingPassphrase string signingPassphrase string
cors CorsConfig cors CorsConfig
} }
@ -164,13 +165,13 @@ func (s *httpdServer) refreshCookie(next http.Handler) http.Handler {
}) })
} }
func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := loginPage{ data := loginPage{
commonBasePage: getCommonBasePage(r), commonBasePage: getCommonBasePage(r),
Title: util.I18nLoginTitle, Title: util.I18nLoginTitle,
CurrentURL: webClientLoginPath, CurrentURL: webClientLoginPath,
Error: err, Error: err,
CSRFToken: createCSRFToken(ip), CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath),
Branding: s.binding.Branding.WebClient, Branding: s.binding.Branding.WebClient,
FormDisabled: s.binding.isWebClientLoginFormDisabled(), FormDisabled: s.binding.isWebClientLoginFormDisabled(),
CheckRedirect: true, CheckRedirect: true,
@ -193,8 +194,7 @@ func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, r *http.Reque
func (s *httpdServer) handleWebClientLogout(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebClientLogout(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
c := jwtTokenClaims{} removeCookie(w, r, webBaseClientPath)
c.removeCookie(w, r, webBaseClientPath)
s.logoutOIDCUser(w, r) s.logoutOIDCUser(w, r)
http.Redirect(w, r, webClientLoginPath, http.StatusFound) http.Redirect(w, r, webClientLoginPath, http.StatusFound)
@ -206,7 +206,7 @@ func (s *httpdServer) handleWebClientChangePwdPost(w http.ResponseWriter, r *htt
s.renderClientChangePasswordPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) s.renderClientChangePasswordPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return return
} }
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -226,7 +226,7 @@ func (s *httpdServer) handleClientWebLogin(w http.ResponseWriter, r *http.Reques
return return
} }
msg := getFlashMessage(w, r) msg := getFlashMessage(w, r)
s.renderClientLoginPage(w, r, msg.getI18nError(), util.GetIPFromRemoteAddress(r.RemoteAddr)) s.renderClientLoginPage(w, r, msg.getI18nError())
} }
func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) {
@ -234,7 +234,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return return
} }
protocol := common.ProtocolHTTP protocol := common.ProtocolHTTP
@ -244,20 +244,19 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials) dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials)
s.renderClientLoginPage(w, r, s.renderClientLoginPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return return
} }
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
dataprovider.LoginMethodPassword, ipAddr, err) dataprovider.LoginMethodPassword, ipAddr, err)
s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr) s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return
} }
if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil { if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil {
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
dataprovider.LoginMethodPassword, ipAddr, err) dataprovider.LoginMethodPassword, ipAddr, err)
s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message), ipAddr) s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message))
return return
} }
@ -265,13 +264,13 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
if err != nil { if err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err) updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
s.renderClientLoginPage(w, r, s.renderClientLoginPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return return
} }
connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String()) connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil { if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err) updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err)
s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message), ipAddr) s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message))
return return
} }
@ -280,7 +279,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
if err != nil { if err != nil {
logger.Warn(logSender, connectionID, "unable to check fs root: %v", err) logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure) updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorFsGeneric), ipAddr) s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorFsGeneric))
return return
} }
s.loginUser(w, r, &user, connectionID, ipAddr, false, s.renderClientLoginPage) s.loginUser(w, r, &user, connectionID, ipAddr, false, s.renderClientLoginPage)
@ -292,10 +291,10 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return return
} }
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -304,12 +303,12 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
_, user, err := handleResetPassword(r, strings.TrimSpace(r.Form.Get("code")), _, user, err := handleResetPassword(r, strings.TrimSpace(r.Form.Get("code")),
newPassword, confirmPassword, false) newPassword, confirmPassword, false)
if err != nil { if err != nil {
s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric), ipAddr) s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric))
return return
} }
connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String()) connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String())
if err := checkHTTPClientUser(user, r, connectionID, true); err != nil { if err := checkHTTPClientUser(user, r, connectionID, true); err != nil {
s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorDirList403), ipAddr) s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorDirList403))
return return
} }
@ -317,7 +316,7 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
err = user.CheckFsRoot(connectionID) err = user.CheckFsRoot(connectionID)
if err != nil { if err != nil {
logger.Warn(logSender, connectionID, "unable to check fs root: %v", err) logger.Warn(logSender, connectionID, "unable to check fs root: %v", err)
s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorLoginAfterReset), ipAddr) s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorLoginAfterReset))
return return
} }
s.loginUser(w, r, user, connectionID, ipAddr, false, s.renderClientResetPwdPage) s.loginUser(w, r, user, connectionID, ipAddr, false, s.renderClientResetPwdPage)
@ -332,18 +331,18 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return return
} }
username := claims.Username username := claims.Username
recoveryCode := strings.TrimSpace(r.Form.Get("recovery_code")) recoveryCode := strings.TrimSpace(r.Form.Get("recovery_code"))
if username == "" || recoveryCode == "" { if username == "" || recoveryCode == "" {
s.renderClientTwoFactorRecoveryPage(w, r, s.renderClientTwoFactorRecoveryPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return return
} }
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr) s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
user, userMerged, err := dataprovider.GetUserVariants(username, "") user, userMerged, err := dataprovider.GetUserVariants(username, "")
@ -352,12 +351,12 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
} }
s.renderClientTwoFactorRecoveryPage(w, r, s.renderClientTwoFactorRecoveryPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return return
} }
if !userMerged.Filters.TOTPConfig.Enabled || !util.Contains(userMerged.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) { if !userMerged.Filters.TOTPConfig.Enabled || !util.Contains(userMerged.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) {
s.renderClientTwoFactorPage(w, r, util.NewI18nError( s.renderClientTwoFactorPage(w, r, util.NewI18nError(
util.NewValidationError("two factory authentication is not enabled"), util.I18n2FADisabled), ipAddr) util.NewValidationError("two factory authentication is not enabled"), util.I18n2FADisabled))
return return
} }
for idx, code := range user.Filters.RecoveryCodes { for idx, code := range user.Filters.RecoveryCodes {
@ -368,7 +367,7 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
if code.Secret.GetPayload() == recoveryCode { if code.Secret.GetPayload() == recoveryCode {
if code.Used { if code.Used {
s.renderClientTwoFactorRecoveryPage(w, r, s.renderClientTwoFactorRecoveryPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return return
} }
user.Filters.RecoveryCodes[idx].Used = true user.Filters.RecoveryCodes[idx].Used = true
@ -386,7 +385,7 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter
} }
handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck
s.renderClientTwoFactorRecoveryPage(w, r, s.renderClientTwoFactorRecoveryPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
} }
func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *http.Request) {
@ -398,7 +397,7 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return return
} }
username := claims.Username username := claims.Username
@ -407,25 +406,25 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials) dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials)
s.renderClientTwoFactorPage(w, r, s.renderClientTwoFactorPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return return
} }
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
dataprovider.LoginMethodPassword, ipAddr, err) dataprovider.LoginMethodPassword, ipAddr, err)
s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr) s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
user, err := dataprovider.GetUserWithGroupSettings(username, "") user, err := dataprovider.GetUserWithGroupSettings(username, "")
if err != nil { if err != nil {
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
dataprovider.LoginMethodPassword, ipAddr, err) dataprovider.LoginMethodPassword, ipAddr, err)
s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials), ipAddr) s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials))
return return
} }
if !user.Filters.TOTPConfig.Enabled || !util.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) { if !user.Filters.TOTPConfig.Enabled || !util.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure) updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
s.renderClientTwoFactorPage(w, r, util.NewI18nError(common.ErrInternalFailure, util.I18n2FADisabled), ipAddr) s.renderClientTwoFactorPage(w, r, util.NewI18nError(common.ErrInternalFailure, util.I18n2FADisabled))
return return
} }
err = user.Filters.TOTPConfig.Secret.Decrypt() err = user.Filters.TOTPConfig.Secret.Decrypt()
@ -439,7 +438,7 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt
if !match || err != nil { if !match || err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, dataprovider.ErrInvalidCredentials) updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, dataprovider.ErrInvalidCredentials)
s.renderClientTwoFactorPage(w, r, s.renderClientTwoFactorPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return return
} }
connectionID := fmt.Sprintf("%s_%s", getProtocolFromRequest(r), xid.New().String()) connectionID := fmt.Sprintf("%s_%s", getProtocolFromRequest(r), xid.New().String())
@ -456,18 +455,17 @@ func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter,
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return return
} }
username := claims.Username username := claims.Username
recoveryCode := strings.TrimSpace(r.Form.Get("recovery_code")) recoveryCode := strings.TrimSpace(r.Form.Get("recovery_code"))
if username == "" || recoveryCode == "" { if username == "" || recoveryCode == "" {
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
ipAddr)
return return
} }
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr) s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
admin, err := dataprovider.AdminExists(username) admin, err := dataprovider.AdminExists(username)
@ -475,12 +473,11 @@ func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter,
if errors.Is(err, util.ErrNotFound) { if errors.Is(err, util.ErrNotFound) {
handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
} }
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
ipAddr)
return return
} }
if !admin.Filters.TOTPConfig.Enabled { if !admin.Filters.TOTPConfig.Enabled {
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(util.NewValidationError("two factory authentication is not enabled"), util.I18n2FADisabled), ipAddr) s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(util.NewValidationError("two factory authentication is not enabled"), util.I18n2FADisabled))
return return
} }
for idx, code := range admin.Filters.RecoveryCodes { for idx, code := range admin.Filters.RecoveryCodes {
@ -491,7 +488,7 @@ func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter,
if code.Secret.GetPayload() == recoveryCode { if code.Secret.GetPayload() == recoveryCode {
if code.Used { if code.Used {
s.renderTwoFactorRecoveryPage(w, r, s.renderTwoFactorRecoveryPage(w, r,
util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
return return
} }
admin.Filters.RecoveryCodes[idx].Used = true admin.Filters.RecoveryCodes[idx].Used = true
@ -506,8 +503,7 @@ func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter,
} }
} }
handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck
s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
ipAddr)
} }
func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http.Request) {
@ -519,19 +515,18 @@ func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return return
} }
username := claims.Username username := claims.Username
passcode := strings.TrimSpace(r.Form.Get("passcode")) passcode := strings.TrimSpace(r.Form.Get("passcode"))
if username == "" || passcode == "" { if username == "" || passcode == "" {
s.renderTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), s.renderTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
ipAddr)
return return
} }
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
err = handleDefenderEventLoginFailed(ipAddr, err) err = handleDefenderEventLoginFailed(ipAddr, err)
s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr) s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
admin, err := dataprovider.AdminExists(username) admin, err := dataprovider.AdminExists(username)
@ -539,11 +534,11 @@ func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http
if errors.Is(err, util.ErrNotFound) { if errors.Is(err, util.ErrNotFound) {
handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
} }
s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials), ipAddr) s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials))
return return
} }
if !admin.Filters.TOTPConfig.Enabled { if !admin.Filters.TOTPConfig.Enabled {
s.renderTwoFactorPage(w, r, util.NewI18nError(common.ErrInternalFailure, util.I18n2FADisabled), ipAddr) s.renderTwoFactorPage(w, r, util.NewI18nError(common.ErrInternalFailure, util.I18n2FADisabled))
return return
} }
err = admin.Filters.TOTPConfig.Secret.Decrypt() err = admin.Filters.TOTPConfig.Secret.Decrypt()
@ -555,8 +550,7 @@ func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http
admin.Filters.TOTPConfig.Secret.GetPayload()) admin.Filters.TOTPConfig.Secret.GetPayload())
if !match || err != nil { if !match || err != nil {
handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck
s.renderTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), s.renderTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
ipAddr)
return return
} }
s.loginAdmin(w, r, &admin, true, s.renderTwoFactorPage, ipAddr) s.loginAdmin(w, r, &admin, true, s.renderTwoFactorPage, ipAddr)
@ -567,37 +561,35 @@ func (s *httpdServer) handleWebAdminLoginPost(w http.ResponseWriter, r *http.Req
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
s.renderAdminLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) s.renderAdminLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return return
} }
username := strings.TrimSpace(r.Form.Get("username")) username := strings.TrimSpace(r.Form.Get("username"))
password := strings.TrimSpace(r.Form.Get("password")) password := strings.TrimSpace(r.Form.Get("password"))
if username == "" || password == "" { if username == "" || password == "" {
s.renderAdminLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), s.renderAdminLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
ipAddr)
return return
} }
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderAdminLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr) s.renderAdminLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
admin, err := dataprovider.CheckAdminAndPass(username, password, ipAddr) admin, err := dataprovider.CheckAdminAndPass(username, password, ipAddr)
if err != nil { if err != nil {
handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
s.renderAdminLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), s.renderAdminLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
ipAddr)
return return
} }
s.loginAdmin(w, r, &admin, false, s.renderAdminLoginPage, ipAddr) s.loginAdmin(w, r, &admin, false, s.renderAdminLoginPage, ipAddr)
} }
func (s *httpdServer) renderAdminLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { func (s *httpdServer) renderAdminLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := loginPage{ data := loginPage{
commonBasePage: getCommonBasePage(r), commonBasePage: getCommonBasePage(r),
Title: util.I18nLoginTitle, Title: util.I18nLoginTitle,
CurrentURL: webAdminLoginPath, CurrentURL: webAdminLoginPath,
Error: err, Error: err,
CSRFToken: createCSRFToken(ip), CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath),
Branding: s.binding.Branding.WebAdmin, Branding: s.binding.Branding.WebAdmin,
FormDisabled: s.binding.isWebAdminLoginFormDisabled(), FormDisabled: s.binding.isWebAdminLoginFormDisabled(),
CheckRedirect: false, CheckRedirect: false,
@ -622,13 +614,12 @@ func (s *httpdServer) handleWebAdminLogin(w http.ResponseWriter, r *http.Request
return return
} }
msg := getFlashMessage(w, r) msg := getFlashMessage(w, r)
s.renderAdminLoginPage(w, r, msg.getI18nError(), util.GetIPFromRemoteAddress(r.RemoteAddr)) s.renderAdminLoginPage(w, r, msg.getI18nError())
} }
func (s *httpdServer) handleWebAdminLogout(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebAdminLogout(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
c := jwtTokenClaims{} removeCookie(w, r, webBaseAdminPath)
c.removeCookie(w, r, webBaseAdminPath)
s.logoutOIDCUser(w, r) s.logoutOIDCUser(w, r)
http.Redirect(w, r, webAdminLoginPath, http.StatusFound) http.Redirect(w, r, webAdminLoginPath, http.StatusFound)
@ -641,7 +632,7 @@ func (s *httpdServer) handleWebAdminChangePwdPost(w http.ResponseWriter, r *http
s.renderChangePasswordPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) s.renderChangePasswordPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return return
} }
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -660,10 +651,10 @@ func (s *httpdServer) handleWebAdminPasswordResetPost(w http.ResponseWriter, r *
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
s.renderResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) s.renderResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return return
} }
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -672,7 +663,7 @@ func (s *httpdServer) handleWebAdminPasswordResetPost(w http.ResponseWriter, r *
admin, _, err := handleResetPassword(r, strings.TrimSpace(r.Form.Get("code")), admin, _, err := handleResetPassword(r, strings.TrimSpace(r.Form.Get("code")),
newPassword, confirmPassword, true) newPassword, confirmPassword, true)
if err != nil { if err != nil {
s.renderResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric), ipAddr) s.renderResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric))
return return
} }
@ -688,10 +679,10 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
s.renderAdminSetupPage(w, r, "", ipAddr, util.NewI18nError(err, util.I18nErrorInvalidForm)) s.renderAdminSetupPage(w, r, "", util.NewI18nError(err, util.I18nErrorInvalidForm))
return return
} }
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -700,7 +691,7 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req
confirmPassword := strings.TrimSpace(r.Form.Get("confirm_password")) confirmPassword := strings.TrimSpace(r.Form.Get("confirm_password"))
installCode := strings.TrimSpace(r.Form.Get("install_code")) installCode := strings.TrimSpace(r.Form.Get("install_code"))
if installationCode != "" && installCode != resolveInstallationCode() { if installationCode != "" && installCode != resolveInstallationCode() {
s.renderAdminSetupPage(w, r, username, ipAddr, s.renderAdminSetupPage(w, r, username,
util.NewI18nError( util.NewI18nError(
util.NewValidationError(fmt.Sprintf("%v mismatch", installationCodeHint)), util.NewValidationError(fmt.Sprintf("%v mismatch", installationCodeHint)),
util.I18nErrorSetupInstallCode), util.I18nErrorSetupInstallCode),
@ -708,17 +699,17 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req
return return
} }
if username == "" { if username == "" {
s.renderAdminSetupPage(w, r, username, ipAddr, s.renderAdminSetupPage(w, r, username,
util.NewI18nError(util.NewValidationError("please set a username"), util.I18nError500Message)) util.NewI18nError(util.NewValidationError("please set a username"), util.I18nError500Message))
return return
} }
if password == "" { if password == "" {
s.renderAdminSetupPage(w, r, username, ipAddr, s.renderAdminSetupPage(w, r, username,
util.NewI18nError(util.NewValidationError("please set a password"), util.I18nError500Message)) util.NewI18nError(util.NewValidationError("please set a password"), util.I18nError500Message))
return return
} }
if password != confirmPassword { if password != confirmPassword {
s.renderAdminSetupPage(w, r, username, ipAddr, s.renderAdminSetupPage(w, r, username,
util.NewI18nError(errors.New("the two password fields do not match"), util.I18nErrorChangePwdNoMatch)) util.NewI18nError(errors.New("the two password fields do not match"), util.I18nErrorChangePwdNoMatch))
return return
} }
@ -730,7 +721,7 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req
} }
err = dataprovider.AddAdmin(&admin, username, ipAddr, "") err = dataprovider.AddAdmin(&admin, username, ipAddr, "")
if err != nil { if err != nil {
s.renderAdminSetupPage(w, r, username, ipAddr, util.NewI18nError(err, util.I18nError500Message)) s.renderAdminSetupPage(w, r, username, util.NewI18nError(err, util.I18nError500Message))
return return
} }
s.loginAdmin(w, r, &admin, false, nil, ipAddr) s.loginAdmin(w, r, &admin, false, nil, ipAddr)
@ -738,7 +729,7 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req
func (s *httpdServer) loginUser( func (s *httpdServer) loginUser(
w http.ResponseWriter, r *http.Request, user *dataprovider.User, connectionID, ipAddr string, w http.ResponseWriter, r *http.Request, user *dataprovider.User, connectionID, ipAddr string,
isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string), isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError),
) { ) {
c := jwtTokenClaims{ c := jwtTokenClaims{
Username: user.Username, Username: user.Username,
@ -760,12 +751,10 @@ func (s *httpdServer) loginUser(
if err != nil { if err != nil {
logger.Warn(logSender, connectionID, "unable to set user login cookie %v", err) logger.Warn(logSender, connectionID, "unable to set user login cookie %v", err)
updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure) updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure)
errorFunc(w, r, util.NewI18nError(err, util.I18nError500Message), ipAddr) errorFunc(w, r, util.NewI18nError(err, util.I18nError500Message))
return return
} }
if isSecondFactorAuth { invalidateToken(r, !isSecondFactorAuth)
invalidateToken(r)
}
if audience == tokenAudienceWebClientPartial { if audience == tokenAudienceWebClientPartial {
redirectPath := webClientTwoFactorPath redirectPath := webClientTwoFactorPath
if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) { if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) {
@ -785,7 +774,7 @@ func (s *httpdServer) loginUser(
func (s *httpdServer) loginAdmin( func (s *httpdServer) loginAdmin(
w http.ResponseWriter, r *http.Request, admin *dataprovider.Admin, w http.ResponseWriter, r *http.Request, admin *dataprovider.Admin,
isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string), isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError),
ipAddr string, ipAddr string,
) { ) {
c := jwtTokenClaims{ c := jwtTokenClaims{
@ -807,15 +796,13 @@ func (s *httpdServer) loginAdmin(
if err != nil { if err != nil {
logger.Warn(logSender, "", "unable to set admin login cookie %v", err) logger.Warn(logSender, "", "unable to set admin login cookie %v", err)
if errorFunc == nil { if errorFunc == nil {
s.renderAdminSetupPage(w, r, admin.Username, ipAddr, util.NewI18nError(err, util.I18nError500Message)) s.renderAdminSetupPage(w, r, admin.Username, util.NewI18nError(err, util.I18nError500Message))
return return
} }
errorFunc(w, r, util.NewI18nError(err, util.I18nError500Message), ipAddr) errorFunc(w, r, util.NewI18nError(err, util.I18nError500Message))
return return
} }
if isSecondFactorAuth { invalidateToken(r, !isSecondFactorAuth)
invalidateToken(r)
}
if audience == tokenAudienceWebAdminPartial { if audience == tokenAudienceWebAdminPartial {
http.Redirect(w, r, webAdminTwoFactorPath, http.StatusFound) http.Redirect(w, r, webAdminTwoFactorPath, http.StatusFound)
return return
@ -831,7 +818,7 @@ func (s *httpdServer) loginAdmin(
func (s *httpdServer) logout(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) logout(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
invalidateToken(r) invalidateToken(r, false)
sendAPIResponse(w, r, nil, "Your token has been invalidated", http.StatusOK) sendAPIResponse(w, r, nil, "Your token has been invalidated", http.StatusOK)
} }
@ -1022,13 +1009,13 @@ func (s *httpdServer) checkCookieExpiration(w http.ResponseWriter, r *http.Reque
return return
} }
if util.Contains(token.Audience(), tokenAudienceWebClient) { if util.Contains(token.Audience(), tokenAudienceWebClient) {
s.refreshClientToken(w, r, tokenClaims) s.refreshClientToken(w, r, &tokenClaims)
} else { } else {
s.refreshAdminToken(w, r, tokenClaims) s.refreshAdminToken(w, r, &tokenClaims)
} }
} }
func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request, tokenClaims jwtTokenClaims) { func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwtTokenClaims) {
user, err := dataprovider.GetUserWithGroupSettings(tokenClaims.Username, "") user, err := dataprovider.GetUserWithGroupSettings(tokenClaims.Username, "")
if err != nil { if err != nil {
return return
@ -1037,6 +1024,10 @@ func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request,
logger.Debug(logSender, "", "signature mismatch for user %q, unable to refresh cookie", user.Username) logger.Debug(logSender, "", "signature mismatch for user %q, unable to refresh cookie", user.Username)
return return
} }
if err := user.CheckLoginConditions(); err != nil {
logger.Debug(logSender, "", "unable to refresh cookie for user %q: %v", user.Username, err)
return
}
if err := checkHTTPClientUser(&user, r, xid.New().String(), true); err != nil { if err := checkHTTPClientUser(&user, r, xid.New().String(), true); err != nil {
logger.Debug(logSender, "", "unable to refresh cookie for user %q: %v", user.Username, err) logger.Debug(logSender, "", "unable to refresh cookie for user %q: %v", user.Username, err)
return return
@ -1048,22 +1039,18 @@ func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request,
tokenClaims.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebClient, util.GetIPFromRemoteAddress(r.RemoteAddr)) //nolint:errcheck tokenClaims.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebClient, util.GetIPFromRemoteAddress(r.RemoteAddr)) //nolint:errcheck
} }
func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request, tokenClaims jwtTokenClaims) { func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwtTokenClaims) {
admin, err := dataprovider.AdminExists(tokenClaims.Username) admin, err := dataprovider.AdminExists(tokenClaims.Username)
if err != nil { if err != nil {
return return
} }
if admin.Status != 1 {
logger.Debug(logSender, "", "admin %q is disabled, unable to refresh cookie", admin.Username)
return
}
if admin.GetSignature() != tokenClaims.Signature { if admin.GetSignature() != tokenClaims.Signature {
logger.Debug(logSender, "", "signature mismatch for admin %q, unable to refresh cookie", admin.Username) logger.Debug(logSender, "", "signature mismatch for admin %q, unable to refresh cookie", admin.Username)
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if !admin.CanLoginFromIP(ipAddr) { if err := admin.CanLogin(ipAddr); err != nil {
logger.Debug(logSender, "", "admin %q cannot login from %v, unable to refresh cookie", admin.Username, r.RemoteAddr) logger.Debug(logSender, "", "unable to refresh cookie for admin %q, err: %v", admin.Username, err)
return return
} }
tokenClaims.Permissions = admin.Permissions tokenClaims.Permissions = admin.Permissions
@ -1236,6 +1223,7 @@ func (s *httpdServer) mustCheckPath(r *http.Request) bool {
func (s *httpdServer) initializeRouter() { func (s *httpdServer) initializeRouter() {
var hasHTTPSRedirect bool var hasHTTPSRedirect bool
s.tokenAuth = jwtauth.New(jwa.HS256.String(), getSigningKey(s.signingPassphrase), nil) s.tokenAuth = jwtauth.New(jwa.HS256.String(), getSigningKey(s.signingPassphrase), nil)
s.csrfTokenAuth = jwtauth.New(jwa.HS256.String(), getSigningKey(s.signingPassphrase), nil)
s.router = chi.NewRouter() s.router = chi.NewRouter()
s.router.Use(middleware.RequestID) s.router.Use(middleware.RequestID)
@ -1537,11 +1525,14 @@ func (s *httpdServer) setupWebClientRoutes() {
s.router.Get(webClientOIDCLoginPath, s.handleWebClientOIDCLogin) s.router.Get(webClientOIDCLoginPath, s.handleWebClientOIDCLogin)
} }
if !s.binding.isWebClientLoginFormDisabled() { if !s.binding.isWebClientLoginFormDisabled() {
s.router.Post(webClientLoginPath, s.handleWebClientLoginPost) s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Post(webClientLoginPath, s.handleWebClientLoginPost)
s.router.Get(webClientForgotPwdPath, s.handleWebClientForgotPwd) s.router.Get(webClientForgotPwdPath, s.handleWebClientForgotPwd)
s.router.Post(webClientForgotPwdPath, s.handleWebClientForgotPwdPost) s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Post(webClientForgotPwdPath, s.handleWebClientForgotPwdPost)
s.router.Get(webClientResetPwdPath, s.handleWebClientPasswordReset) s.router.Get(webClientResetPwdPath, s.handleWebClientPasswordReset)
s.router.Post(webClientResetPwdPath, s.handleWebClientPasswordResetPost) s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Post(webClientResetPwdPath, s.handleWebClientPasswordResetPost)
s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie), s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie),
s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)). s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)).
Get(webClientTwoFactorPath, s.handleWebClientTwoFactor) Get(webClientTwoFactorPath, s.handleWebClientTwoFactor)
@ -1557,7 +1548,8 @@ func (s *httpdServer) setupWebClientRoutes() {
} }
// share routes available to external users // share routes available to external users
s.router.Get(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginGet) s.router.Get(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginGet)
s.router.Post(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginPost) s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Post(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginPost)
s.router.Get(webClientPubSharesPath+"/{id}", s.downloadFromShare) s.router.Get(webClientPubSharesPath+"/{id}", s.downloadFromShare)
s.router.Post(webClientPubSharesPath+"/{id}/partial", s.handleClientSharePartialDownload) s.router.Post(webClientPubSharesPath+"/{id}/partial", s.handleClientSharePartialDownload)
s.router.Get(webClientPubSharesPath+"/{id}/browse", s.handleShareGetFiles) s.router.Get(webClientPubSharesPath+"/{id}/browse", s.handleShareGetFiles)
@ -1574,32 +1566,32 @@ func (s *httpdServer) setupWebClientRoutes() {
if s.binding.OIDC.isEnabled() { if s.binding.OIDC.isEnabled() {
router.Use(s.oidcTokenAuthenticator(tokenAudienceWebClient)) router.Use(s.oidcTokenAuthenticator(tokenAudienceWebClient))
} }
router.Use(jwtauth.Verify(s.tokenAuth, tokenFromContext, jwtauth.TokenFromCookie)) router.Use(jwtauth.Verify(s.tokenAuth, oidcTokenFromContext, jwtauth.TokenFromCookie))
router.Use(jwtAuthenticatorWebClient) router.Use(jwtAuthenticatorWebClient)
router.Get(webClientLogoutPath, s.handleWebClientLogout) router.Get(webClientLogoutPath, s.handleWebClientLogout)
router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientFilesPath, s.handleClientGetFiles) router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientFilesPath, s.handleClientGetFiles)
router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientViewPDFPath, s.handleClientViewPDF) router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientViewPDFPath, s.handleClientViewPDF)
router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientGetPDFPath, s.handleClientGetPDF) router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientGetPDFPath, s.handleClientGetPDF)
router.With(s.checkAuthRequirements, s.refreshCookie, verifyCSRFHeader).Get(webClientFilePath, getUserFile) router.With(s.checkAuthRequirements, s.refreshCookie, s.verifyCSRFHeader).Get(webClientFilePath, getUserFile)
router.With(s.checkAuthRequirements, s.refreshCookie, verifyCSRFHeader).Get(webClientTasksPath+"/{id}", router.With(s.checkAuthRequirements, s.refreshCookie, s.verifyCSRFHeader).Get(webClientTasksPath+"/{id}",
getWebTask) getWebTask)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader). router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
Post(webClientFilePath, uploadUserFile) Post(webClientFilePath, uploadUserFile)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader). router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
Post(webClientExistPath, s.handleClientCheckExist) Post(webClientExistPath, s.handleClientCheckExist)
router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientEditFilePath, s.handleClientEditFile) router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientEditFilePath, s.handleClientEditFile)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader). router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
Delete(webClientFilesPath, deleteUserFile) Delete(webClientFilesPath, deleteUserFile)
router.With(s.checkAuthRequirements, compressor.Handler, s.refreshCookie). router.With(s.checkAuthRequirements, compressor.Handler, s.refreshCookie).
Get(webClientDirsPath, s.handleClientGetDirContents) Get(webClientDirsPath, s.handleClientGetDirContents)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader). router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
Post(webClientDirsPath, createUserDir) Post(webClientDirsPath, createUserDir)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader). router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
Delete(webClientDirsPath, taskDeleteDir) Delete(webClientDirsPath, taskDeleteDir)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader). router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
Post(webClientFileActionsPath+"/move", taskRenameFsEntry) Post(webClientFileActionsPath+"/move", taskRenameFsEntry)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader). router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
Post(webClientFileActionsPath+"/copy", taskCopyFsEntry) Post(webClientFileActionsPath+"/copy", taskCopyFsEntry)
router.With(s.checkAuthRequirements, s.refreshCookie). router.With(s.checkAuthRequirements, s.refreshCookie).
Post(webClientDownloadZipPath, s.handleWebClientDownloadZip) Post(webClientDownloadZipPath, s.handleWebClientDownloadZip)
@ -1615,15 +1607,15 @@ func (s *httpdServer) setupWebClientRoutes() {
Get(webClientMFAPath, s.handleWebClientMFA) Get(webClientMFAPath, s.handleWebClientMFA)
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.refreshCookie). router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.refreshCookie).
Get(webClientMFAPath+"/qrcode", getQRCode) Get(webClientMFAPath+"/qrcode", getQRCode)
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), verifyCSRFHeader). router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader).
Post(webClientTOTPGeneratePath, generateTOTPSecret) Post(webClientTOTPGeneratePath, generateTOTPSecret)
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), verifyCSRFHeader). router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader).
Post(webClientTOTPValidatePath, validateTOTPPasscode) Post(webClientTOTPValidatePath, validateTOTPPasscode)
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), verifyCSRFHeader). router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader).
Post(webClientTOTPSavePath, saveTOTPConfig) Post(webClientTOTPSavePath, saveTOTPConfig)
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), verifyCSRFHeader, s.refreshCookie). router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader, s.refreshCookie).
Get(webClientRecoveryCodesPath, getRecoveryCodes) Get(webClientRecoveryCodesPath, getRecoveryCodes)
router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), verifyCSRFHeader). router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader).
Post(webClientRecoveryCodesPath, generateRecoveryCodes) Post(webClientRecoveryCodesPath, generateRecoveryCodes)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), compressor.Handler, s.refreshCookie). router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), compressor.Handler, s.refreshCookie).
Get(webClientSharesPath+jsonAPISuffix, getAllShares) Get(webClientSharesPath+jsonAPISuffix, getAllShares)
@ -1637,7 +1629,7 @@ func (s *httpdServer) setupWebClientRoutes() {
Get(webClientSharePath+"/{id}", s.handleClientUpdateShareGet) Get(webClientSharePath+"/{id}", s.handleClientUpdateShareGet)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)). router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)).
Post(webClientSharePath+"/{id}", s.handleClientUpdateSharePost) Post(webClientSharePath+"/{id}", s.handleClientUpdateSharePost)
router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), verifyCSRFHeader). router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), s.verifyCSRFHeader).
Delete(webClientSharePath+"/{id}", deleteShare) Delete(webClientSharePath+"/{id}", deleteShare)
}) })
} }
@ -1655,9 +1647,11 @@ func (s *httpdServer) setupWebAdminRoutes() {
} }
s.router.Get(webOAuth2RedirectPath, s.handleOAuth2TokenRedirect) s.router.Get(webOAuth2RedirectPath, s.handleOAuth2TokenRedirect)
s.router.Get(webAdminSetupPath, s.handleWebAdminSetupGet) s.router.Get(webAdminSetupPath, s.handleWebAdminSetupGet)
s.router.Post(webAdminSetupPath, s.handleWebAdminSetupPost) s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Post(webAdminSetupPath, s.handleWebAdminSetupPost)
if !s.binding.isWebAdminLoginFormDisabled() { if !s.binding.isWebAdminLoginFormDisabled() {
s.router.Post(webAdminLoginPath, s.handleWebAdminLoginPost) s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Post(webAdminLoginPath, s.handleWebAdminLoginPost)
s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie), s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie),
s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)). s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)).
Get(webAdminTwoFactorPath, s.handleWebAdminTwoFactor) Get(webAdminTwoFactorPath, s.handleWebAdminTwoFactor)
@ -1671,16 +1665,18 @@ func (s *httpdServer) setupWebAdminRoutes() {
s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)). s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)).
Post(webAdminTwoFactorRecoveryPath, s.handleWebAdminTwoFactorRecoveryPost) Post(webAdminTwoFactorRecoveryPath, s.handleWebAdminTwoFactorRecoveryPost)
s.router.Get(webAdminForgotPwdPath, s.handleWebAdminForgotPwd) s.router.Get(webAdminForgotPwdPath, s.handleWebAdminForgotPwd)
s.router.Post(webAdminForgotPwdPath, s.handleWebAdminForgotPwdPost) s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Post(webAdminForgotPwdPath, s.handleWebAdminForgotPwdPost)
s.router.Get(webAdminResetPwdPath, s.handleWebAdminPasswordReset) s.router.Get(webAdminResetPwdPath, s.handleWebAdminPasswordReset)
s.router.Post(webAdminResetPwdPath, s.handleWebAdminPasswordResetPost) s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)).
Post(webAdminResetPwdPath, s.handleWebAdminPasswordResetPost)
} }
s.router.Group(func(router chi.Router) { s.router.Group(func(router chi.Router) {
if s.binding.OIDC.isEnabled() { if s.binding.OIDC.isEnabled() {
router.Use(s.oidcTokenAuthenticator(tokenAudienceWebAdmin)) router.Use(s.oidcTokenAuthenticator(tokenAudienceWebAdmin))
} }
router.Use(jwtauth.Verify(s.tokenAuth, tokenFromContext, jwtauth.TokenFromCookie)) router.Use(jwtauth.Verify(s.tokenAuth, oidcTokenFromContext, jwtauth.TokenFromCookie))
router.Use(jwtAuthenticatorWebAdmin) router.Use(jwtAuthenticatorWebAdmin)
router.Get(webLogoutPath, s.handleWebAdminLogout) router.Get(webLogoutPath, s.handleWebAdminLogout)
@ -1692,12 +1688,12 @@ func (s *httpdServer) setupWebAdminRoutes() {
router.With(s.refreshCookie, s.requireBuiltinLogin).Get(webAdminMFAPath, s.handleWebAdminMFA) router.With(s.refreshCookie, s.requireBuiltinLogin).Get(webAdminMFAPath, s.handleWebAdminMFA)
router.With(s.refreshCookie, s.requireBuiltinLogin).Get(webAdminMFAPath+"/qrcode", getQRCode) router.With(s.refreshCookie, s.requireBuiltinLogin).Get(webAdminMFAPath+"/qrcode", getQRCode)
router.With(verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPGeneratePath, generateTOTPSecret) router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPGeneratePath, generateTOTPSecret)
router.With(verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPValidatePath, validateTOTPPasscode) router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPValidatePath, validateTOTPPasscode)
router.With(verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPSavePath, saveTOTPConfig) router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPSavePath, saveTOTPConfig)
router.With(verifyCSRFHeader, s.requireBuiltinLogin, s.refreshCookie).Get(webAdminRecoveryCodesPath, router.With(s.verifyCSRFHeader, s.requireBuiltinLogin, s.refreshCookie).Get(webAdminRecoveryCodesPath,
getRecoveryCodes) getRecoveryCodes)
router.With(verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminRecoveryCodesPath, generateRecoveryCodes) router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminRecoveryCodesPath, generateRecoveryCodes)
router.Group(func(router chi.Router) { router.Group(func(router chi.Router) {
router.Use(s.checkAuthRequirements) router.Use(s.checkAuthRequirements)
@ -1724,7 +1720,7 @@ func (s *httpdServer) setupWebAdminRoutes() {
Get(webGroupPath+"/{name}", s.handleWebUpdateGroupGet) Get(webGroupPath+"/{name}", s.handleWebUpdateGroupGet)
router.With(s.checkPerm(dataprovider.PermAdminManageGroups)).Post(webGroupPath+"/{name}", router.With(s.checkPerm(dataprovider.PermAdminManageGroups)).Post(webGroupPath+"/{name}",
s.handleWebUpdateGroupPost) s.handleWebUpdateGroupPost)
router.With(s.checkPerm(dataprovider.PermAdminManageGroups), verifyCSRFHeader). router.With(s.checkPerm(dataprovider.PermAdminManageGroups), s.verifyCSRFHeader).
Delete(webGroupPath+"/{name}", deleteGroup) Delete(webGroupPath+"/{name}", deleteGroup)
router.With(s.checkPerm(dataprovider.PermAdminViewConnections), s.refreshCookie). router.With(s.checkPerm(dataprovider.PermAdminViewConnections), s.refreshCookie).
Get(webConnectionsPath, s.handleWebGetConnections) Get(webConnectionsPath, s.handleWebGetConnections)
@ -1750,25 +1746,25 @@ func (s *httpdServer) setupWebAdminRoutes() {
router.With(s.checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath, s.handleWebAddAdminPost) router.With(s.checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath, s.handleWebAddAdminPost)
router.With(s.checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath+"/{username}", router.With(s.checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath+"/{username}",
s.handleWebUpdateAdminPost) s.handleWebUpdateAdminPost)
router.With(s.checkPerm(dataprovider.PermAdminManageAdmins), verifyCSRFHeader). router.With(s.checkPerm(dataprovider.PermAdminManageAdmins), s.verifyCSRFHeader).
Delete(webAdminPath+"/{username}", deleteAdmin) Delete(webAdminPath+"/{username}", deleteAdmin)
router.With(s.checkPerm(dataprovider.PermAdminDisableMFA), verifyCSRFHeader). router.With(s.checkPerm(dataprovider.PermAdminDisableMFA), s.verifyCSRFHeader).
Put(webAdminPath+"/{username}/2fa/disable", disableAdmin2FA) Put(webAdminPath+"/{username}/2fa/disable", disableAdmin2FA)
router.With(s.checkPerm(dataprovider.PermAdminCloseConnections), verifyCSRFHeader). router.With(s.checkPerm(dataprovider.PermAdminCloseConnections), s.verifyCSRFHeader).
Delete(webConnectionsPath+"/{connectionID}", handleCloseConnection) Delete(webConnectionsPath+"/{connectionID}", handleCloseConnection)
router.With(s.checkPerm(dataprovider.PermAdminManageFolders), s.refreshCookie). router.With(s.checkPerm(dataprovider.PermAdminManageFolders), s.refreshCookie).
Get(webFolderPath+"/{name}", s.handleWebUpdateFolderGet) Get(webFolderPath+"/{name}", s.handleWebUpdateFolderGet)
router.With(s.checkPerm(dataprovider.PermAdminManageFolders)).Post(webFolderPath+"/{name}", router.With(s.checkPerm(dataprovider.PermAdminManageFolders)).Post(webFolderPath+"/{name}",
s.handleWebUpdateFolderPost) s.handleWebUpdateFolderPost)
router.With(s.checkPerm(dataprovider.PermAdminManageFolders), verifyCSRFHeader). router.With(s.checkPerm(dataprovider.PermAdminManageFolders), s.verifyCSRFHeader).
Delete(webFolderPath+"/{name}", deleteFolder) Delete(webFolderPath+"/{name}", deleteFolder)
router.With(s.checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader). router.With(s.checkPerm(dataprovider.PermAdminQuotaScans), s.verifyCSRFHeader).
Post(webScanVFolderPath+"/{name}", startFolderQuotaScan) Post(webScanVFolderPath+"/{name}", startFolderQuotaScan)
router.With(s.checkPerm(dataprovider.PermAdminDeleteUsers), verifyCSRFHeader). router.With(s.checkPerm(dataprovider.PermAdminDeleteUsers), s.verifyCSRFHeader).
Delete(webUserPath+"/{username}", deleteUser) Delete(webUserPath+"/{username}", deleteUser)
router.With(s.checkPerm(dataprovider.PermAdminDisableMFA), verifyCSRFHeader). router.With(s.checkPerm(dataprovider.PermAdminDisableMFA), s.verifyCSRFHeader).
Put(webUserPath+"/{username}/2fa/disable", disableUser2FA) Put(webUserPath+"/{username}/2fa/disable", disableUser2FA)
router.With(s.checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader). router.With(s.checkPerm(dataprovider.PermAdminQuotaScans), s.verifyCSRFHeader).
Post(webQuotaScanPath+"/{username}", startUserQuotaScan) Post(webQuotaScanPath+"/{username}", startUserQuotaScan)
router.With(s.checkPerm(dataprovider.PermAdminManageSystem)).Get(webMaintenancePath, s.handleWebMaintenance) router.With(s.checkPerm(dataprovider.PermAdminManageSystem)).Get(webMaintenancePath, s.handleWebMaintenance)
router.With(s.checkPerm(dataprovider.PermAdminManageSystem)).Get(webBackupPath, dumpData) router.With(s.checkPerm(dataprovider.PermAdminManageSystem)).Get(webBackupPath, dumpData)
@ -1795,7 +1791,7 @@ func (s *httpdServer) setupWebAdminRoutes() {
Get(webAdminEventActionPath+"/{name}", s.handleWebUpdateEventActionGet) Get(webAdminEventActionPath+"/{name}", s.handleWebUpdateEventActionGet)
router.With(s.checkPerm(dataprovider.PermAdminManageEventRules)).Post(webAdminEventActionPath+"/{name}", router.With(s.checkPerm(dataprovider.PermAdminManageEventRules)).Post(webAdminEventActionPath+"/{name}",
s.handleWebUpdateEventActionPost) s.handleWebUpdateEventActionPost)
router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), verifyCSRFHeader). router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), s.verifyCSRFHeader).
Delete(webAdminEventActionPath+"/{name}", deleteEventAction) Delete(webAdminEventActionPath+"/{name}", deleteEventAction)
router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), compressor.Handler, s.refreshCookie). router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), compressor.Handler, s.refreshCookie).
Get(webAdminEventRulesPath+jsonAPISuffix, getAllRules) Get(webAdminEventRulesPath+jsonAPISuffix, getAllRules)
@ -1809,9 +1805,9 @@ func (s *httpdServer) setupWebAdminRoutes() {
Get(webAdminEventRulePath+"/{name}", s.handleWebUpdateEventRuleGet) Get(webAdminEventRulePath+"/{name}", s.handleWebUpdateEventRuleGet)
router.With(s.checkPerm(dataprovider.PermAdminManageEventRules)).Post(webAdminEventRulePath+"/{name}", router.With(s.checkPerm(dataprovider.PermAdminManageEventRules)).Post(webAdminEventRulePath+"/{name}",
s.handleWebUpdateEventRulePost) s.handleWebUpdateEventRulePost)
router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), verifyCSRFHeader). router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), s.verifyCSRFHeader).
Delete(webAdminEventRulePath+"/{name}", deleteEventRule) Delete(webAdminEventRulePath+"/{name}", deleteEventRule)
router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), verifyCSRFHeader). router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), s.verifyCSRFHeader).
Post(webAdminEventRulePath+"/run/{name}", runOnDemandRule) Post(webAdminEventRulePath+"/run/{name}", runOnDemandRule)
router.With(s.checkPerm(dataprovider.PermAdminManageRoles), s.refreshCookie). router.With(s.checkPerm(dataprovider.PermAdminManageRoles), s.refreshCookie).
Get(webAdminRolesPath, s.handleWebGetRoles) Get(webAdminRolesPath, s.handleWebGetRoles)
@ -1824,7 +1820,7 @@ func (s *httpdServer) setupWebAdminRoutes() {
Get(webAdminRolePath+"/{name}", s.handleWebUpdateRoleGet) Get(webAdminRolePath+"/{name}", s.handleWebUpdateRoleGet)
router.With(s.checkPerm(dataprovider.PermAdminManageRoles)).Post(webAdminRolePath+"/{name}", router.With(s.checkPerm(dataprovider.PermAdminManageRoles)).Post(webAdminRolePath+"/{name}",
s.handleWebUpdateRolePost) s.handleWebUpdateRolePost)
router.With(s.checkPerm(dataprovider.PermAdminManageRoles), verifyCSRFHeader). router.With(s.checkPerm(dataprovider.PermAdminManageRoles), s.verifyCSRFHeader).
Delete(webAdminRolePath+"/{name}", deleteRole) Delete(webAdminRolePath+"/{name}", deleteRole)
router.With(s.checkPerm(dataprovider.PermAdminViewEvents), s.refreshCookie).Get(webEventsPath, router.With(s.checkPerm(dataprovider.PermAdminViewEvents), s.refreshCookie).Get(webEventsPath,
s.handleWebGetEvents) s.handleWebGetEvents)
@ -1845,14 +1841,14 @@ func (s *httpdServer) setupWebAdminRoutes() {
s.handleWebUpdateIPListEntryGet) s.handleWebUpdateIPListEntryGet)
router.With(s.checkPerm(dataprovider.PermAdminManageIPLists)).Post(webIPListPath+"/{type}/{ipornet}", router.With(s.checkPerm(dataprovider.PermAdminManageIPLists)).Post(webIPListPath+"/{type}/{ipornet}",
s.handleWebUpdateIPListEntryPost) s.handleWebUpdateIPListEntryPost)
router.With(s.checkPerm(dataprovider.PermAdminManageIPLists), verifyCSRFHeader). router.With(s.checkPerm(dataprovider.PermAdminManageIPLists), s.verifyCSRFHeader).
Delete(webIPListPath+"/{type}/{ipornet}", deleteIPListEntry) Delete(webIPListPath+"/{type}/{ipornet}", deleteIPListEntry)
router.With(s.checkPerm(dataprovider.PermAdminManageSystem), s.refreshCookie).Get(webConfigsPath, s.handleWebConfigs) router.With(s.checkPerm(dataprovider.PermAdminManageSystem), s.refreshCookie).Get(webConfigsPath, s.handleWebConfigs)
router.With(s.checkPerm(dataprovider.PermAdminManageSystem)).Post(webConfigsPath, s.handleWebConfigsPost) router.With(s.checkPerm(dataprovider.PermAdminManageSystem)).Post(webConfigsPath, s.handleWebConfigsPost)
router.With(s.checkPerm(dataprovider.PermAdminManageSystem), verifyCSRFHeader, s.refreshCookie). router.With(s.checkPerm(dataprovider.PermAdminManageSystem), s.verifyCSRFHeader, s.refreshCookie).
Post(webConfigsPath+"/smtp/test", testSMTPConfig) Post(webConfigsPath+"/smtp/test", testSMTPConfig)
router.With(s.checkPerm(dataprovider.PermAdminManageSystem), verifyCSRFHeader, s.refreshCookie). router.With(s.checkPerm(dataprovider.PermAdminManageSystem), s.verifyCSRFHeader, s.refreshCookie).
Post(webOAuth2TokenPath, handleSMTPOAuth2TokenRequestPost) Post(webOAuth2TokenPath, s.handleSMTPOAuth2TokenRequestPost)
}) })
}) })
} }

View file

@ -31,6 +31,7 @@ import (
"time" "time"
"github.com/go-chi/render" "github.com/go-chi/render"
"github.com/rs/xid"
"github.com/sftpgo/sdk" "github.com/sftpgo/sdk"
sdkkms "github.com/sftpgo/sdk/kms" sdkkms "github.com/sftpgo/sdk/kms"
@ -612,10 +613,10 @@ func isServerManagerResource(currentURL string) bool {
currentURL == webConfigsPath currentURL == webConfigsPath
} }
func (s *httpdServer) getBasePageData(title, currentURL string, r *http.Request) basePage { func (s *httpdServer) getBasePageData(title, currentURL string, w http.ResponseWriter, r *http.Request) basePage {
var csrfToken string var csrfToken string
if currentURL != "" { if currentURL != "" {
csrfToken = createCSRFToken(util.GetIPFromRemoteAddress(r.RemoteAddr)) csrfToken = createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath)
} }
return basePage{ return basePage{
commonBasePage: getCommonBasePage(r), commonBasePage: getCommonBasePage(r),
@ -675,7 +676,7 @@ func (s *httpdServer) renderMessagePageWithString(w http.ResponseWriter, r *http
err error, message, text string, err error, message, text string,
) { ) {
data := messagePage{ data := messagePage{
basePage: s.getBasePageData(title, "", r), basePage: s.getBasePageData(title, "", w, r),
Error: getI18nError(err), Error: getI18nError(err),
Success: message, Success: message,
Text: text, Text: text,
@ -710,12 +711,12 @@ func (s *httpdServer) renderNotFoundPage(w http.ResponseWriter, r *http.Request,
util.NewI18nError(err, util.I18nError404Message), "") util.NewI18nError(err, util.I18nError404Message), "")
} }
func (s *httpdServer) renderForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { func (s *httpdServer) renderForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := forgotPwdPage{ data := forgotPwdPage{
commonBasePage: getCommonBasePage(r), commonBasePage: getCommonBasePage(r),
CurrentURL: webAdminForgotPwdPath, CurrentURL: webAdminForgotPwdPath,
Error: err, Error: err,
CSRFToken: createCSRFToken(ip), CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath),
LoginURL: webAdminLoginPath, LoginURL: webAdminLoginPath,
Title: util.I18nForgotPwdTitle, Title: util.I18nForgotPwdTitle,
Branding: s.binding.Branding.WebAdmin, Branding: s.binding.Branding.WebAdmin,
@ -723,12 +724,12 @@ func (s *httpdServer) renderForgotPwdPage(w http.ResponseWriter, r *http.Request
renderAdminTemplate(w, templateForgotPassword, data) renderAdminTemplate(w, templateForgotPassword, data)
} }
func (s *httpdServer) renderResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { func (s *httpdServer) renderResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := resetPwdPage{ data := resetPwdPage{
commonBasePage: getCommonBasePage(r), commonBasePage: getCommonBasePage(r),
CurrentURL: webAdminResetPwdPath, CurrentURL: webAdminResetPwdPath,
Error: err, Error: err,
CSRFToken: createCSRFToken(ip), CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath),
LoginURL: webAdminLoginPath, LoginURL: webAdminLoginPath,
Title: util.I18nResetPwdTitle, Title: util.I18nResetPwdTitle,
Branding: s.binding.Branding.WebAdmin, Branding: s.binding.Branding.WebAdmin,
@ -736,26 +737,26 @@ func (s *httpdServer) renderResetPwdPage(w http.ResponseWriter, r *http.Request,
renderAdminTemplate(w, templateResetPassword, data) renderAdminTemplate(w, templateResetPassword, data)
} }
func (s *httpdServer) renderTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { func (s *httpdServer) renderTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := twoFactorPage{ data := twoFactorPage{
commonBasePage: getCommonBasePage(r), commonBasePage: getCommonBasePage(r),
Title: pageTwoFactorTitle, Title: pageTwoFactorTitle,
CurrentURL: webAdminTwoFactorPath, CurrentURL: webAdminTwoFactorPath,
Error: err, Error: err,
CSRFToken: createCSRFToken(ip), CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath),
RecoveryURL: webAdminTwoFactorRecoveryPath, RecoveryURL: webAdminTwoFactorRecoveryPath,
Branding: s.binding.Branding.WebAdmin, Branding: s.binding.Branding.WebAdmin,
} }
renderAdminTemplate(w, templateTwoFactor, data) renderAdminTemplate(w, templateTwoFactor, data)
} }
func (s *httpdServer) renderTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { func (s *httpdServer) renderTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := twoFactorPage{ data := twoFactorPage{
commonBasePage: getCommonBasePage(r), commonBasePage: getCommonBasePage(r),
Title: pageTwoFactorRecoveryTitle, Title: pageTwoFactorRecoveryTitle,
CurrentURL: webAdminTwoFactorRecoveryPath, CurrentURL: webAdminTwoFactorRecoveryPath,
Error: err, Error: err,
CSRFToken: createCSRFToken(ip), CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath),
Branding: s.binding.Branding.WebAdmin, Branding: s.binding.Branding.WebAdmin,
} }
renderAdminTemplate(w, templateTwoFactorRecovery, data) renderAdminTemplate(w, templateTwoFactorRecovery, data)
@ -763,7 +764,7 @@ func (s *httpdServer) renderTwoFactorRecoveryPage(w http.ResponseWriter, r *http
func (s *httpdServer) renderMFAPage(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) renderMFAPage(w http.ResponseWriter, r *http.Request) {
data := mfaPage{ data := mfaPage{
basePage: s.getBasePageData(pageMFATitle, webAdminMFAPath, r), basePage: s.getBasePageData(pageMFATitle, webAdminMFAPath, w, r),
TOTPConfigs: mfa.GetAvailableTOTPConfigNames(), TOTPConfigs: mfa.GetAvailableTOTPConfigNames(),
GenerateTOTPURL: webAdminTOTPGeneratePath, GenerateTOTPURL: webAdminTOTPGeneratePath,
ValidateTOTPURL: webAdminTOTPValidatePath, ValidateTOTPURL: webAdminTOTPValidatePath,
@ -782,7 +783,7 @@ func (s *httpdServer) renderMFAPage(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) renderProfilePage(w http.ResponseWriter, r *http.Request, err error) { func (s *httpdServer) renderProfilePage(w http.ResponseWriter, r *http.Request, err error) {
data := profilePage{ data := profilePage{
basePage: s.getBasePageData(util.I18nProfileTitle, webAdminProfilePath, r), basePage: s.getBasePageData(util.I18nProfileTitle, webAdminProfilePath, w, r),
Error: getI18nError(err), Error: getI18nError(err),
} }
admin, err := dataprovider.AdminExists(data.LoggedUser.Username) admin, err := dataprovider.AdminExists(data.LoggedUser.Username)
@ -799,7 +800,7 @@ func (s *httpdServer) renderProfilePage(w http.ResponseWriter, r *http.Request,
func (s *httpdServer) renderChangePasswordPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { func (s *httpdServer) renderChangePasswordPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := changePasswordPage{ data := changePasswordPage{
basePage: s.getBasePageData(util.I18nChangePwdTitle, webChangeAdminPwdPath, r), basePage: s.getBasePageData(util.I18nChangePwdTitle, webChangeAdminPwdPath, w, r),
Error: err, Error: err,
} }
@ -808,7 +809,7 @@ func (s *httpdServer) renderChangePasswordPage(w http.ResponseWriter, r *http.Re
func (s *httpdServer) renderMaintenancePage(w http.ResponseWriter, r *http.Request, err error) { func (s *httpdServer) renderMaintenancePage(w http.ResponseWriter, r *http.Request, err error) {
data := maintenancePage{ data := maintenancePage{
basePage: s.getBasePageData(util.I18nMaintenanceTitle, webMaintenancePath, r), basePage: s.getBasePageData(util.I18nMaintenanceTitle, webMaintenancePath, w, r),
BackupPath: webBackupPath, BackupPath: webBackupPath,
RestorePath: webRestorePath, RestorePath: webRestorePath,
Error: getI18nError(err), Error: getI18nError(err),
@ -830,7 +831,7 @@ func (s *httpdServer) renderConfigsPage(w http.ResponseWriter, r *http.Request,
configs.ACME.HTTP01Challenge.Port = 80 configs.ACME.HTTP01Challenge.Port = 80
} }
data := configsPage{ data := configsPage{
basePage: s.getBasePageData(util.I18nConfigsTitle, webConfigsPath, r), basePage: s.getBasePageData(util.I18nConfigsTitle, webConfigsPath, w, r),
Configs: configs, Configs: configs,
ConfigSection: section, ConfigSection: section,
RedactedSecret: redactedSecret, RedactedSecret: redactedSecret,
@ -842,12 +843,12 @@ func (s *httpdServer) renderConfigsPage(w http.ResponseWriter, r *http.Request,
renderAdminTemplate(w, templateConfigs, data) renderAdminTemplate(w, templateConfigs, data)
} }
func (s *httpdServer) renderAdminSetupPage(w http.ResponseWriter, r *http.Request, username, ip string, err *util.I18nError) { func (s *httpdServer) renderAdminSetupPage(w http.ResponseWriter, r *http.Request, username string, err *util.I18nError) {
data := setupPage{ data := setupPage{
commonBasePage: getCommonBasePage(r), commonBasePage: getCommonBasePage(r),
Title: util.I18nSetupTitle, Title: util.I18nSetupTitle,
CurrentURL: webAdminSetupPath, CurrentURL: webAdminSetupPath,
CSRFToken: createCSRFToken(ip), CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath),
Username: username, Username: username,
HasInstallationCode: installationCode != "", HasInstallationCode: installationCode != "",
InstallationCodeHint: installationCodeHint, InstallationCodeHint: installationCodeHint,
@ -876,7 +877,7 @@ func (s *httpdServer) renderAddUpdateAdminPage(w http.ResponseWriter, r *http.Re
title = util.I18nUpdateAdminTitle title = util.I18nUpdateAdminTitle
} }
data := adminPage{ data := adminPage{
basePage: s.getBasePageData(title, currentURL, r), basePage: s.getBasePageData(title, currentURL, w, r),
Admin: admin, Admin: admin,
Groups: groups, Groups: groups,
Roles: roles, Roles: roles,
@ -917,7 +918,7 @@ func (s *httpdServer) renderUserPage(w http.ResponseWriter, r *http.Request, use
} }
} }
user.FsConfig.RedactedSecret = redactedSecret user.FsConfig.RedactedSecret = redactedSecret
basePage := s.getBasePageData(title, currentURL, r) basePage := s.getBasePageData(title, currentURL, w, r)
if (mode == userPageModeAdd || mode == userPageModeTemplate) && len(user.Groups) == 0 && admin != nil { if (mode == userPageModeAdd || mode == userPageModeTemplate) && len(user.Groups) == 0 && admin != nil {
for _, group := range admin.Groups { for _, group := range admin.Groups {
user.Groups = append(user.Groups, sdk.GroupMapping{ user.Groups = append(user.Groups, sdk.GroupMapping{
@ -982,7 +983,7 @@ func (s *httpdServer) renderIPListPage(w http.ResponseWriter, r *http.Request, e
currentURL = fmt.Sprintf("%s/%d/%s", webIPListPath, entry.Type, url.PathEscape(entry.IPOrNet)) currentURL = fmt.Sprintf("%s/%d/%s", webIPListPath, entry.Type, url.PathEscape(entry.IPOrNet))
} }
data := ipListPage{ data := ipListPage{
basePage: s.getBasePageData(title, currentURL, r), basePage: s.getBasePageData(title, currentURL, w, r),
Error: getI18nError(err), Error: getI18nError(err),
Entry: &entry, Entry: &entry,
Mode: mode, Mode: mode,
@ -1003,7 +1004,7 @@ func (s *httpdServer) renderRolePage(w http.ResponseWriter, r *http.Request, rol
currentURL = fmt.Sprintf("%s/%s", webAdminRolePath, url.PathEscape(role.Name)) currentURL = fmt.Sprintf("%s/%s", webAdminRolePath, url.PathEscape(role.Name))
} }
data := rolePage{ data := rolePage{
basePage: s.getBasePageData(title, currentURL, r), basePage: s.getBasePageData(title, currentURL, w, r),
Error: getI18nError(err), Error: getI18nError(err),
Role: &role, Role: &role,
Mode: mode, Mode: mode,
@ -1033,7 +1034,7 @@ func (s *httpdServer) renderGroupPage(w http.ResponseWriter, r *http.Request, gr
group.UserSettings.FsConfig.SetEmptySecretsIfNil() group.UserSettings.FsConfig.SetEmptySecretsIfNil()
data := groupPage{ data := groupPage{
basePage: s.getBasePageData(title, currentURL, r), basePage: s.getBasePageData(title, currentURL, w, r),
Error: getI18nError(err), Error: getI18nError(err),
Group: &group, Group: &group,
Mode: mode, Mode: mode,
@ -1078,7 +1079,7 @@ func (s *httpdServer) renderEventActionPage(w http.ResponseWriter, r *http.Reque
} }
data := eventActionPage{ data := eventActionPage{
basePage: s.getBasePageData(title, currentURL, r), basePage: s.getBasePageData(title, currentURL, w, r),
Action: action, Action: action,
ActionTypes: dataprovider.EventActionTypes, ActionTypes: dataprovider.EventActionTypes,
FsActions: dataprovider.FsActionTypes, FsActions: dataprovider.FsActionTypes,
@ -1108,7 +1109,7 @@ func (s *httpdServer) renderEventRulePage(w http.ResponseWriter, r *http.Request
} }
data := eventRulePage{ data := eventRulePage{
basePage: s.getBasePageData(title, currentURL, r), basePage: s.getBasePageData(title, currentURL, w, r),
Rule: rule, Rule: rule,
TriggerTypes: dataprovider.EventTriggerTypes, TriggerTypes: dataprovider.EventTriggerTypes,
Actions: actions, Actions: actions,
@ -1142,7 +1143,7 @@ func (s *httpdServer) renderFolderPage(w http.ResponseWriter, r *http.Request, f
folder.FsConfig.SetEmptySecretsIfNil() folder.FsConfig.SetEmptySecretsIfNil()
data := folderPage{ data := folderPage{
basePage: s.getBasePageData(title, currentURL, r), basePage: s.getBasePageData(title, currentURL, w, r),
Error: getI18nError(err), Error: getI18nError(err),
Folder: folder, Folder: folder,
Mode: mode, Mode: mode,
@ -2764,25 +2765,24 @@ func (s *httpdServer) handleWebAdminForgotPwd(w http.ResponseWriter, r *http.Req
s.renderNotFoundPage(w, r, errors.New("this page does not exist")) s.renderNotFoundPage(w, r, errors.New("this page does not exist"))
return return
} }
s.renderForgotPwdPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) s.renderForgotPwdPage(w, r, nil)
} }
func (s *httpdServer) handleWebAdminForgotPwdPost(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebAdminForgotPwdPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return return
} }
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
err = handleForgotPassword(r, r.Form.Get("username"), true) err = handleForgotPassword(r, r.Form.Get("username"), true)
if err != nil { if err != nil {
s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric), ipAddr) s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric))
return return
} }
http.Redirect(w, r, webAdminResetPwdPath, http.StatusFound) http.Redirect(w, r, webAdminResetPwdPath, http.StatusFound)
@ -2794,17 +2794,17 @@ func (s *httpdServer) handleWebAdminPasswordReset(w http.ResponseWriter, r *http
s.renderNotFoundPage(w, r, errors.New("this page does not exist")) s.renderNotFoundPage(w, r, errors.New("this page does not exist"))
return return
} }
s.renderResetPwdPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) s.renderResetPwdPage(w, r, nil)
} }
func (s *httpdServer) handleWebAdminTwoFactor(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebAdminTwoFactor(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
s.renderTwoFactorPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) s.renderTwoFactorPage(w, r, nil)
} }
func (s *httpdServer) handleWebAdminTwoFactorRecovery(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebAdminTwoFactorRecovery(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
s.renderTwoFactorRecoveryPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) s.renderTwoFactorRecoveryPage(w, r, nil)
} }
func (s *httpdServer) handleWebAdminMFA(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebAdminMFA(w http.ResponseWriter, r *http.Request) {
@ -2830,7 +2830,7 @@ func (s *httpdServer) handleWebAdminProfilePost(w http.ResponseWriter, r *http.R
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -2875,7 +2875,7 @@ func (s *httpdServer) handleWebRestore(w http.ResponseWriter, r *http.Request) {
defer r.MultipartForm.RemoveAll() //nolint:errcheck defer r.MultipartForm.RemoveAll() //nolint:errcheck
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -2936,7 +2936,7 @@ func getAllAdmins(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleGetWebAdmins(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleGetWebAdmins(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := s.getBasePageData(util.I18nAdminsTitle, webAdminsPath, r) data := s.getBasePageData(util.I18nAdminsTitle, webAdminsPath, w, r)
renderAdminTemplate(w, templateAdmins, data) renderAdminTemplate(w, templateAdmins, data)
} }
@ -2946,7 +2946,7 @@ func (s *httpdServer) handleWebAdminSetupGet(w http.ResponseWriter, r *http.Requ
http.Redirect(w, r, webAdminLoginPath, http.StatusFound) http.Redirect(w, r, webAdminLoginPath, http.StatusFound)
return return
} }
s.renderAdminSetupPage(w, r, "", util.GetIPFromRemoteAddress(r.RemoteAddr), nil) s.renderAdminSetupPage(w, r, "", nil)
} }
func (s *httpdServer) handleWebAddAdminGet(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebAddAdminGet(w http.ResponseWriter, r *http.Request) {
@ -2987,7 +2987,7 @@ func (s *httpdServer) handleWebAddAdminPost(w http.ResponseWriter, r *http.Reque
admin.Password = util.GenerateUniqueID() admin.Password = util.GenerateUniqueID()
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -3018,7 +3018,7 @@ func (s *httpdServer) handleWebUpdateAdminPost(w http.ResponseWriter, r *http.Re
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -3071,7 +3071,7 @@ func (s *httpdServer) handleWebUpdateAdminPost(w http.ResponseWriter, r *http.Re
func (s *httpdServer) handleWebDefenderPage(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebDefenderPage(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := defenderHostsPage{ data := defenderHostsPage{
basePage: s.getBasePageData(util.I18nDefenderTitle, webDefenderPath, r), basePage: s.getBasePageData(util.I18nDefenderTitle, webDefenderPath, w, r),
DefenderHostsURL: webDefenderHostsPath, DefenderHostsURL: webDefenderHostsPath,
} }
@ -3105,7 +3105,7 @@ func (s *httpdServer) handleGetWebUsers(w http.ResponseWriter, r *http.Request)
s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
return return
} }
data := s.getBasePageData(util.I18nUsersTitle, webUsersPath, r) data := s.getBasePageData(util.I18nUsersTitle, webUsersPath, w, r)
renderAdminTemplate(w, templateUsers, data) renderAdminTemplate(w, templateUsers, data)
} }
@ -3144,7 +3144,7 @@ func (s *httpdServer) handleWebTemplateFolderPost(w http.ResponseWriter, r *http
defer r.MultipartForm.RemoveAll() //nolint:errcheck defer r.MultipartForm.RemoveAll() //nolint:errcheck
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -3244,7 +3244,7 @@ func (s *httpdServer) handleWebTemplateUserPost(w http.ResponseWriter, r *http.R
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -3341,7 +3341,7 @@ func (s *httpdServer) handleWebAddUserPost(w http.ResponseWriter, r *http.Reques
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -3387,7 +3387,7 @@ func (s *httpdServer) handleWebUpdateUserPost(w http.ResponseWriter, r *http.Req
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -3425,7 +3425,7 @@ func (s *httpdServer) handleWebUpdateUserPost(w http.ResponseWriter, r *http.Req
func (s *httpdServer) handleWebGetStatus(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebGetStatus(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := statusPage{ data := statusPage{
basePage: s.getBasePageData(util.I18nStatusTitle, webStatusPath, r), basePage: s.getBasePageData(util.I18nStatusTitle, webStatusPath, w, r),
Status: getServicesStatus(), Status: getServicesStatus(),
} }
renderAdminTemplate(w, templateStatus, data) renderAdminTemplate(w, templateStatus, data)
@ -3439,7 +3439,7 @@ func (s *httpdServer) handleWebGetConnections(w http.ResponseWriter, r *http.Req
return return
} }
data := s.getBasePageData(util.I18nSessionsTitle, webConnectionsPath, r) data := s.getBasePageData(util.I18nSessionsTitle, webConnectionsPath, w, r)
renderAdminTemplate(w, templateConnections, data) renderAdminTemplate(w, templateConnections, data)
} }
@ -3464,7 +3464,7 @@ func (s *httpdServer) handleWebAddFolderPost(w http.ResponseWriter, r *http.Requ
defer r.MultipartForm.RemoveAll() //nolint:errcheck defer r.MultipartForm.RemoveAll() //nolint:errcheck
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -3525,7 +3525,7 @@ func (s *httpdServer) handleWebUpdateFolderPost(w http.ResponseWriter, r *http.R
defer r.MultipartForm.RemoveAll() //nolint:errcheck defer r.MultipartForm.RemoveAll() //nolint:errcheck
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -3588,7 +3588,7 @@ func getAllFolders(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleWebGetFolders(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebGetFolders(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := s.getBasePageData(util.I18nFoldersTitle, webFoldersPath, r) data := s.getBasePageData(util.I18nFoldersTitle, webFoldersPath, w, r)
renderAdminTemplate(w, templateFolders, data) renderAdminTemplate(w, templateFolders, data)
} }
@ -3626,7 +3626,7 @@ func getAllGroups(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleWebGetGroups(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebGetGroups(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := s.getBasePageData(util.I18nGroupsTitle, webGroupsPath, r) data := s.getBasePageData(util.I18nGroupsTitle, webGroupsPath, w, r)
renderAdminTemplate(w, templateGroups, data) renderAdminTemplate(w, templateGroups, data)
} }
@ -3648,7 +3648,7 @@ func (s *httpdServer) handleWebAddGroupPost(w http.ResponseWriter, r *http.Reque
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -3695,7 +3695,7 @@ func (s *httpdServer) handleWebUpdateGroupPost(w http.ResponseWriter, r *http.Re
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -3748,7 +3748,7 @@ func getAllActions(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleWebGetEventActions(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebGetEventActions(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := s.getBasePageData(util.I18nActionsTitle, webAdminEventActionsPath, r) data := s.getBasePageData(util.I18nActionsTitle, webAdminEventActionsPath, w, r)
renderAdminTemplate(w, templateEventActions, data) renderAdminTemplate(w, templateEventActions, data)
} }
@ -3773,7 +3773,7 @@ func (s *httpdServer) handleWebAddEventActionPost(w http.ResponseWriter, r *http
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -3819,7 +3819,7 @@ func (s *httpdServer) handleWebUpdateEventActionPost(w http.ResponseWriter, r *h
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -3858,7 +3858,7 @@ func getAllRules(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleWebGetEventRules(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebGetEventRules(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := s.getBasePageData(util.I18nRulesTitle, webAdminEventRulesPath, r) data := s.getBasePageData(util.I18nRulesTitle, webAdminEventRulesPath, w, r)
renderAdminTemplate(w, templateEventRules, data) renderAdminTemplate(w, templateEventRules, data)
} }
@ -3884,7 +3884,7 @@ func (s *httpdServer) handleWebAddEventRulePost(w http.ResponseWriter, r *http.R
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
err = verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr) err = verifyCSRFToken(r, s.csrfTokenAuth)
if err != nil { if err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
@ -3931,7 +3931,7 @@ func (s *httpdServer) handleWebUpdateEventRulePost(w http.ResponseWriter, r *htt
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -3978,7 +3978,7 @@ func getAllRoles(w http.ResponseWriter, r *http.Request) {
func (s *httpdServer) handleWebGetRoles(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebGetRoles(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := s.getBasePageData(util.I18nRolesTitle, webAdminRolesPath, r) data := s.getBasePageData(util.I18nRolesTitle, webAdminRolesPath, w, r)
renderAdminTemplate(w, templateRoles, data) renderAdminTemplate(w, templateRoles, data)
} }
@ -4001,7 +4001,7 @@ func (s *httpdServer) handleWebAddRolePost(w http.ResponseWriter, r *http.Reques
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -4047,7 +4047,7 @@ func (s *httpdServer) handleWebUpdateRolePost(w http.ResponseWriter, r *http.Req
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -4065,7 +4065,7 @@ func (s *httpdServer) handleWebGetEvents(w http.ResponseWriter, r *http.Request)
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := eventsPage{ data := eventsPage{
basePage: s.getBasePageData(util.I18nEventsTitle, webEventsPath, r), basePage: s.getBasePageData(util.I18nEventsTitle, webEventsPath, w, r),
FsEventsSearchURL: webEventsFsSearchPath, FsEventsSearchURL: webEventsFsSearchPath,
ProviderEventsSearchURL: webEventsProviderSearchPath, ProviderEventsSearchURL: webEventsProviderSearchPath,
LogEventsSearchURL: webEventsLogSearchPath, LogEventsSearchURL: webEventsLogSearchPath,
@ -4077,7 +4077,7 @@ func (s *httpdServer) handleWebIPListsPage(w http.ResponseWriter, r *http.Reques
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
rtlStatus, rtlProtocols := common.Config.GetRateLimitersStatus() rtlStatus, rtlProtocols := common.Config.GetRateLimitersStatus()
data := ipListsPage{ data := ipListsPage{
basePage: s.getBasePageData(util.I18nIPListsTitle, webIPListsPath, r), basePage: s.getBasePageData(util.I18nIPListsTitle, webIPListsPath, w, r),
RateLimitersStatus: rtlStatus, RateLimitersStatus: rtlStatus,
RateLimitersProtocols: strings.Join(rtlProtocols, ", "), RateLimitersProtocols: strings.Join(rtlProtocols, ", "),
IsAllowListEnabled: common.Config.IsAllowListEnabled(), IsAllowListEnabled: common.Config.IsAllowListEnabled(),
@ -4115,7 +4115,7 @@ func (s *httpdServer) handleWebAddIPListEntryPost(w http.ResponseWriter, r *http
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -4170,7 +4170,7 @@ func (s *httpdServer) handleWebUpdateIPListEntryPost(w http.ResponseWriter, r *h
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -4212,7 +4212,7 @@ func (s *httpdServer) handleWebConfigsPost(w http.ResponseWriter, r *http.Reques
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -4262,20 +4262,21 @@ func (s *httpdServer) handleOAuth2TokenRedirect(w http.ResponseWriter, r *http.R
stateToken := r.URL.Query().Get("state") stateToken := r.URL.Query().Get("state")
state, err := verifyOAuth2Token(stateToken, util.GetIPFromRemoteAddress(r.RemoteAddr)) state, err := verifyOAuth2Token(s.csrfTokenAuth, stateToken, util.GetIPFromRemoteAddress(r.RemoteAddr))
if err != nil { if err != nil {
s.renderMessagePage(w, r, util.I18nOAuth2ErrorTitle, http.StatusBadRequest, err, "") s.renderMessagePage(w, r, util.I18nOAuth2ErrorTitle, http.StatusBadRequest, err, "")
return return
} }
defer oauth2Mgr.removePendingAuth(state)
pendingAuth, err := oauth2Mgr.getPendingAuth(state) pendingAuth, err := oauth2Mgr.getPendingAuth(state)
if err != nil { if err != nil {
oauth2Mgr.removePendingAuth(state)
s.renderMessagePage(w, r, util.I18nOAuth2ErrorTitle, http.StatusInternalServerError, s.renderMessagePage(w, r, util.I18nOAuth2ErrorTitle, http.StatusInternalServerError,
util.NewI18nError(err, util.I18nOAuth2ErrorValidateState), "") util.NewI18nError(err, util.I18nOAuth2ErrorValidateState), "")
return return
} }
oauth2Mgr.removePendingAuth(state)
oauth2Config := smtp.OAuth2Config{ oauth2Config := smtp.OAuth2Config{
Provider: pendingAuth.Provider, Provider: pendingAuth.Provider,
ClientID: pendingAuth.ClientID, ClientID: pendingAuth.ClientID,

View file

@ -523,10 +523,10 @@ func loadClientTemplates(templatesPath string) {
clientTemplates[templateShareDownload] = shareDownloadTmpl clientTemplates[templateShareDownload] = shareDownloadTmpl
} }
func (s *httpdServer) getBaseClientPageData(title, currentURL string, r *http.Request) baseClientPage { func (s *httpdServer) getBaseClientPageData(title, currentURL string, w http.ResponseWriter, r *http.Request) baseClientPage {
var csrfToken string var csrfToken string
if currentURL != "" { if currentURL != "" {
csrfToken = createCSRFToken(util.GetIPFromRemoteAddress(r.RemoteAddr)) csrfToken = createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath)
} }
data := baseClientPage{ data := baseClientPage{
@ -552,12 +552,12 @@ func (s *httpdServer) getBaseClientPageData(title, currentURL string, r *http.Re
return data return data
} }
func (s *httpdServer) renderClientForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { func (s *httpdServer) renderClientForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := forgotPwdPage{ data := forgotPwdPage{
commonBasePage: getCommonBasePage(r), commonBasePage: getCommonBasePage(r),
CurrentURL: webClientForgotPwdPath, CurrentURL: webClientForgotPwdPath,
Error: err, Error: err,
CSRFToken: createCSRFToken(ip), CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath),
LoginURL: webClientLoginPath, LoginURL: webClientLoginPath,
Title: util.I18nForgotPwdTitle, Title: util.I18nForgotPwdTitle,
Branding: s.binding.Branding.WebClient, Branding: s.binding.Branding.WebClient,
@ -565,12 +565,12 @@ func (s *httpdServer) renderClientForgotPwdPage(w http.ResponseWriter, r *http.R
renderClientTemplate(w, templateForgotPassword, data) renderClientTemplate(w, templateForgotPassword, data)
} }
func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := resetPwdPage{ data := resetPwdPage{
commonBasePage: getCommonBasePage(r), commonBasePage: getCommonBasePage(r),
CurrentURL: webClientResetPwdPath, CurrentURL: webClientResetPwdPath,
Error: err, Error: err,
CSRFToken: createCSRFToken(ip), CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath),
LoginURL: webClientLoginPath, LoginURL: webClientLoginPath,
Title: util.I18nResetPwdTitle, Title: util.I18nResetPwdTitle,
Branding: s.binding.Branding.WebClient, Branding: s.binding.Branding.WebClient,
@ -578,13 +578,13 @@ func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, r *http.Re
renderClientTemplate(w, templateResetPassword, data) renderClientTemplate(w, templateResetPassword, data)
} }
func (s *httpdServer) renderShareLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { func (s *httpdServer) renderShareLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := shareLoginPage{ data := shareLoginPage{
commonBasePage: getCommonBasePage(r), commonBasePage: getCommonBasePage(r),
Title: util.I18nShareLoginTitle, Title: util.I18nShareLoginTitle,
CurrentURL: r.RequestURI, CurrentURL: r.RequestURI,
Error: err, Error: err,
CSRFToken: createCSRFToken(ip), CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath),
Branding: s.binding.Branding.WebClient, Branding: s.binding.Branding.WebClient,
} }
renderClientTemplate(w, templateShareLogin, data) renderClientTemplate(w, templateShareLogin, data)
@ -599,7 +599,7 @@ func renderClientTemplate(w http.ResponseWriter, tmplName string, data any) {
func (s *httpdServer) renderClientMessagePage(w http.ResponseWriter, r *http.Request, title string, statusCode int, err error, message string) { func (s *httpdServer) renderClientMessagePage(w http.ResponseWriter, r *http.Request, title string, statusCode int, err error, message string) {
data := clientMessagePage{ data := clientMessagePage{
baseClientPage: s.getBaseClientPageData(title, "", r), baseClientPage: s.getBaseClientPageData(title, "", w, r),
Error: getI18nError(err), Error: getI18nError(err),
Success: message, Success: message,
} }
@ -627,13 +627,13 @@ func (s *httpdServer) renderClientNotFoundPage(w http.ResponseWriter, r *http.Re
util.NewI18nError(err, util.I18nError404Message), "") util.NewI18nError(err, util.I18nError404Message), "")
} }
func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := twoFactorPage{ data := twoFactorPage{
commonBasePage: getCommonBasePage(r), commonBasePage: getCommonBasePage(r),
Title: pageTwoFactorTitle, Title: pageTwoFactorTitle,
CurrentURL: webClientTwoFactorPath, CurrentURL: webClientTwoFactorPath,
Error: err, Error: err,
CSRFToken: createCSRFToken(ip), CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath),
RecoveryURL: webClientTwoFactorRecoveryPath, RecoveryURL: webClientTwoFactorRecoveryPath,
Branding: s.binding.Branding.WebClient, Branding: s.binding.Branding.WebClient,
} }
@ -643,13 +643,13 @@ func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, r *http.R
renderClientTemplate(w, templateTwoFactor, data) renderClientTemplate(w, templateTwoFactor, data)
} }
func (s *httpdServer) renderClientTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { func (s *httpdServer) renderClientTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := twoFactorPage{ data := twoFactorPage{
commonBasePage: getCommonBasePage(r), commonBasePage: getCommonBasePage(r),
Title: pageTwoFactorRecoveryTitle, Title: pageTwoFactorRecoveryTitle,
CurrentURL: webClientTwoFactorRecoveryPath, CurrentURL: webClientTwoFactorRecoveryPath,
Error: err, Error: err,
CSRFToken: createCSRFToken(ip), CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath),
Branding: s.binding.Branding.WebClient, Branding: s.binding.Branding.WebClient,
} }
renderClientTemplate(w, templateTwoFactorRecovery, data) renderClientTemplate(w, templateTwoFactorRecovery, data)
@ -657,7 +657,7 @@ func (s *httpdServer) renderClientTwoFactorRecoveryPage(w http.ResponseWriter, r
func (s *httpdServer) renderClientMFAPage(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) renderClientMFAPage(w http.ResponseWriter, r *http.Request) {
data := clientMFAPage{ data := clientMFAPage{
baseClientPage: s.getBaseClientPageData(util.I18n2FATitle, webClientMFAPath, r), baseClientPage: s.getBaseClientPageData(util.I18n2FATitle, webClientMFAPath, w, r),
TOTPConfigs: mfa.GetAvailableTOTPConfigNames(), TOTPConfigs: mfa.GetAvailableTOTPConfigNames(),
GenerateTOTPURL: webClientTOTPGeneratePath, GenerateTOTPURL: webClientTOTPGeneratePath,
ValidateTOTPURL: webClientTOTPValidatePath, ValidateTOTPURL: webClientTOTPValidatePath,
@ -681,7 +681,7 @@ func (s *httpdServer) renderEditFilePage(w http.ResponseWriter, r *http.Request,
title = util.I18nEditFileTitle title = util.I18nEditFileTitle
} }
data := editFilePage{ data := editFilePage{
baseClientPage: s.getBaseClientPageData(title, webClientEditFilePath, r), baseClientPage: s.getBaseClientPageData(title, webClientEditFilePath, w, r),
Path: fileName, Path: fileName,
Name: path.Base(fileName), Name: path.Base(fileName),
CurrentDir: path.Dir(fileName), CurrentDir: path.Dir(fileName),
@ -702,7 +702,7 @@ func (s *httpdServer) renderAddUpdateSharePage(w http.ResponseWriter, r *http.Re
title = util.I18nShareUpdateTitle title = util.I18nShareUpdateTitle
} }
data := clientSharePage{ data := clientSharePage{
baseClientPage: s.getBaseClientPageData(title, currentURL, r), baseClientPage: s.getBaseClientPageData(title, currentURL, w, r),
Share: share, Share: share,
Error: err, Error: err,
IsAdd: isAdd, IsAdd: isAdd,
@ -736,7 +736,7 @@ func (s *httpdServer) renderSharedFilesPage(w http.ResponseWriter, r *http.Reque
err *util.I18nError, share dataprovider.Share, err *util.I18nError, share dataprovider.Share,
) { ) {
currentURL := path.Join(webClientPubSharesPath, share.ShareID, "browse") currentURL := path.Join(webClientPubSharesPath, share.ShareID, "browse")
baseData := s.getBaseClientPageData(util.I18nSharedFilesTitle, currentURL, r) baseData := s.getBaseClientPageData(util.I18nSharedFilesTitle, currentURL, w, r)
baseData.FilesURL = currentURL baseData.FilesURL = currentURL
baseSharePath := path.Join(webClientPubSharesPath, share.ShareID) baseSharePath := path.Join(webClientPubSharesPath, share.ShareID)
@ -768,7 +768,7 @@ func (s *httpdServer) renderSharedFilesPage(w http.ResponseWriter, r *http.Reque
func (s *httpdServer) renderShareDownloadPage(w http.ResponseWriter, r *http.Request, downloadLink string) { func (s *httpdServer) renderShareDownloadPage(w http.ResponseWriter, r *http.Request, downloadLink string) {
data := shareDownloadPage{ data := shareDownloadPage{
baseClientPage: s.getBaseClientPageData(util.I18nShareDownloadTitle, "", r), baseClientPage: s.getBaseClientPageData(util.I18nShareDownloadTitle, "", w, r),
DownloadLink: downloadLink, DownloadLink: downloadLink,
} }
renderClientTemplate(w, templateShareDownload, data) renderClientTemplate(w, templateShareDownload, data)
@ -777,7 +777,7 @@ func (s *httpdServer) renderShareDownloadPage(w http.ResponseWriter, r *http.Req
func (s *httpdServer) renderUploadToSharePage(w http.ResponseWriter, r *http.Request, share dataprovider.Share) { func (s *httpdServer) renderUploadToSharePage(w http.ResponseWriter, r *http.Request, share dataprovider.Share) {
currentURL := path.Join(webClientPubSharesPath, share.ShareID, "upload") currentURL := path.Join(webClientPubSharesPath, share.ShareID, "upload")
data := shareUploadPage{ data := shareUploadPage{
baseClientPage: s.getBaseClientPageData(util.I18nShareUploadTitle, currentURL, r), baseClientPage: s.getBaseClientPageData(util.I18nShareUploadTitle, currentURL, w, r),
Share: &share, Share: &share,
UploadBasePath: path.Join(webClientPubSharesPath, share.ShareID), UploadBasePath: path.Join(webClientPubSharesPath, share.ShareID),
} }
@ -787,7 +787,7 @@ func (s *httpdServer) renderUploadToSharePage(w http.ResponseWriter, r *http.Req
func (s *httpdServer) renderFilesPage(w http.ResponseWriter, r *http.Request, dirName string, func (s *httpdServer) renderFilesPage(w http.ResponseWriter, r *http.Request, dirName string,
err *util.I18nError, user *dataprovider.User) { err *util.I18nError, user *dataprovider.User) {
data := filesPage{ data := filesPage{
baseClientPage: s.getBaseClientPageData(util.I18nFilesTitle, webClientFilesPath, r), baseClientPage: s.getBaseClientPageData(util.I18nFilesTitle, webClientFilesPath, w, r),
Error: err, Error: err,
CurrentDir: url.QueryEscape(dirName), CurrentDir: url.QueryEscape(dirName),
DownloadURL: webClientDownloadZipPath, DownloadURL: webClientDownloadZipPath,
@ -813,7 +813,7 @@ func (s *httpdServer) renderFilesPage(w http.ResponseWriter, r *http.Request, di
func (s *httpdServer) renderClientProfilePage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { func (s *httpdServer) renderClientProfilePage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := clientProfilePage{ data := clientProfilePage{
baseClientPage: s.getBaseClientPageData(util.I18nProfileTitle, webClientProfilePath, r), baseClientPage: s.getBaseClientPageData(util.I18nProfileTitle, webClientProfilePath, w, r),
Error: err, Error: err,
} }
user, userMerged, errUser := dataprovider.GetUserVariants(data.LoggedUser.Username, "") user, userMerged, errUser := dataprovider.GetUserVariants(data.LoggedUser.Username, "")
@ -832,7 +832,7 @@ func (s *httpdServer) renderClientProfilePage(w http.ResponseWriter, r *http.Req
func (s *httpdServer) renderClientChangePasswordPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { func (s *httpdServer) renderClientChangePasswordPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
data := changeClientPasswordPage{ data := changeClientPasswordPage{
baseClientPage: s.getBaseClientPageData(util.I18nChangePwdTitle, webChangeClientPwdPath, r), baseClientPage: s.getBaseClientPageData(util.I18nChangePwdTitle, webChangeClientPwdPath, w, r),
Error: err, Error: err,
} }
@ -850,8 +850,7 @@ func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http.
s.renderClientBadRequestPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) s.renderClientBadRequestPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -1440,7 +1439,7 @@ func (s *httpdServer) handleClientAddSharePost(w http.ResponseWriter, r *http.Re
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -1508,7 +1507,7 @@ func (s *httpdServer) handleClientUpdateSharePost(w http.ResponseWriter, r *http
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -1579,7 +1578,7 @@ func (s *httpdServer) handleClientGetShares(w http.ResponseWriter, r *http.Reque
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
data := clientSharesPage{ data := clientSharesPage{
baseClientPage: s.getBaseClientPageData(util.I18nSharesTitle, webClientSharesPath, r), baseClientPage: s.getBaseClientPageData(util.I18nSharesTitle, webClientSharesPath, w, r),
BasePublicSharesURL: webClientPubSharesPath, BasePublicSharesURL: webClientPubSharesPath,
} }
renderClientTemplate(w, templateClientShares, data) renderClientTemplate(w, templateClientShares, data)
@ -1603,7 +1602,7 @@ func (s *httpdServer) handleWebClientProfilePost(w http.ResponseWriter, r *http.
return return
} }
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
@ -1662,12 +1661,12 @@ func (s *httpdServer) handleWebClientMFA(w http.ResponseWriter, r *http.Request)
func (s *httpdServer) handleWebClientTwoFactor(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebClientTwoFactor(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
s.renderClientTwoFactorPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) s.renderClientTwoFactorPage(w, r, nil)
} }
func (s *httpdServer) handleWebClientTwoFactorRecovery(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebClientTwoFactorRecovery(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
s.renderClientTwoFactorRecoveryPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) s.renderClientTwoFactorRecoveryPage(w, r, nil)
} }
func getShareFromPostFields(r *http.Request) (*dataprovider.Share, error) { func getShareFromPostFields(r *http.Request) (*dataprovider.Share, error) {
@ -1719,26 +1718,25 @@ func (s *httpdServer) handleWebClientForgotPwd(w http.ResponseWriter, r *http.Re
s.renderClientNotFoundPage(w, r, errors.New("this page does not exist")) s.renderClientNotFoundPage(w, r, errors.New("this page does not exist"))
return return
} }
s.renderClientForgotPwdPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) s.renderClientForgotPwdPage(w, r, nil)
} }
func (s *httpdServer) handleWebClientForgotPwdPost(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebClientForgotPwdPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return return
} }
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
username := strings.TrimSpace(r.Form.Get("username")) username := strings.TrimSpace(r.Form.Get("username"))
err = handleForgotPassword(r, username, false) err = handleForgotPassword(r, username, false)
if err != nil { if err != nil {
s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric), ipAddr) s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric))
return return
} }
http.Redirect(w, r, webClientResetPwdPath, http.StatusFound) http.Redirect(w, r, webClientResetPwdPath, http.StatusFound)
@ -1750,7 +1748,7 @@ func (s *httpdServer) handleWebClientPasswordReset(w http.ResponseWriter, r *htt
s.renderClientNotFoundPage(w, r, errors.New("this page does not exist")) s.renderClientNotFoundPage(w, r, errors.New("this page does not exist"))
return return
} }
s.renderClientResetPwdPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) s.renderClientResetPwdPage(w, r, nil)
} }
func (s *httpdServer) handleClientViewPDF(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleClientViewPDF(w http.ResponseWriter, r *http.Request) {
@ -1853,30 +1851,30 @@ func (s *httpdServer) ensurePDF(w http.ResponseWriter, r *http.Request, name str
func (s *httpdServer) handleClientShareLoginGet(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleClientShareLoginGet(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
s.renderShareLoginPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) s.renderShareLoginPage(w, r, nil)
} }
func (s *httpdServer) handleClientShareLoginPost(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleClientShareLoginPost(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
return return
} }
if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr) s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
return return
} }
invalidateToken(r, true)
shareID := getURLParam(r, "id") shareID := getURLParam(r, "id")
share, err := dataprovider.ShareExists(shareID, "") share, err := dataprovider.ShareExists(shareID, "")
if err != nil { if err != nil {
s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials), ipAddr) s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials))
return return
} }
match, err := share.CheckCredentials(strings.TrimSpace(r.Form.Get("share_password"))) match, err := share.CheckCredentials(strings.TrimSpace(r.Form.Get("share_password")))
if !match || err != nil { if !match || err != nil {
s.renderShareLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), s.renderShareLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
ipAddr)
return return
} }
c := jwtTokenClaims{ c := jwtTokenClaims{
@ -1884,7 +1882,7 @@ func (s *httpdServer) handleClientShareLoginPost(w http.ResponseWriter, r *http.
} }
err = c.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebShare, ipAddr) err = c.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebShare, ipAddr)
if err != nil { if err != nil {
s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nError500Message), ipAddr) s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nError500Message))
return return
} }
next := path.Clean(r.URL.Query().Get("next")) next := path.Clean(r.URL.Query().Get("next"))