From 82949524746f5eef739e2b642d937330d4b9d4d1 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Fri, 14 Jun 2024 18:09:32 +0200 Subject: [PATCH] WebUIs: refactor CSRF Signed-off-by: Nicola Murino --- internal/httpd/api_admin.go | 2 +- internal/httpd/api_configs.go | 4 +- internal/httpd/api_http_user.go | 2 +- internal/httpd/api_mfa.go | 3 +- internal/httpd/api_shares.go | 12 +- internal/httpd/auth_utils.go | 162 +++++++++-- internal/httpd/httpd.go | 4 - internal/httpd/httpd_test.go | 491 +++++++++++++++++++++++++------- internal/httpd/internal_test.go | 403 +++++++++++++++++++++----- internal/httpd/middleware.go | 34 +-- internal/httpd/oidc.go | 2 + internal/httpd/oidc_test.go | 6 +- internal/httpd/server.go | 290 ++++++++++--------- internal/httpd/webadmin.go | 143 +++++----- internal/httpd/webclient.go | 86 +++--- 15 files changed, 1150 insertions(+), 494 deletions(-) diff --git a/internal/httpd/api_admin.go b/internal/httpd/api_admin.go index 96a16dec..646863e2 100644 --- a/internal/httpd/api_admin.go +++ b/internal/httpd/api_admin.go @@ -297,7 +297,7 @@ func changeAdminPassword(w http.ResponseWriter, r *http.Request) { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } - invalidateToken(r) + invalidateToken(r, false) sendAPIResponse(w, r, err, "Password updated", http.StatusOK) } diff --git a/internal/httpd/api_configs.go b/internal/httpd/api_configs.go index 399c7ad2..1a208aac 100644 --- a/internal/httpd/api_configs.go +++ b/internal/httpd/api_configs.go @@ -85,7 +85,7 @@ type oauth2TokenRequest struct { BaseRedirectURL string `json:"base_redirect_url"` } -func handleSMTPOAuth2TokenRequestPost(w http.ResponseWriter, r *http.Request) { +func (s *httpdServer) handleSMTPOAuth2TokenRequestPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) var req oauth2TokenRequest @@ -115,7 +115,7 @@ func handleSMTPOAuth2TokenRequestPost(w http.ResponseWriter, r *http.Request) { clientSecret.SetAdditionalData(xid.New().String()) pendingAuth := newOAuth2PendingAuth(req.Provider, cfg.RedirectURL, cfg.ClientID, clientSecret) oauth2Mgr.addPendingAuth(pendingAuth) - stateToken := createOAuth2Token(pendingAuth.State, util.GetIPFromRemoteAddress(r.RemoteAddr)) + stateToken := createOAuth2Token(s.csrfTokenAuth, pendingAuth.State, util.GetIPFromRemoteAddress(r.RemoteAddr)) if stateToken == "" { sendAPIResponse(w, r, nil, "unable to create state token", http.StatusInternalServerError) return diff --git a/internal/httpd/api_http_user.go b/internal/httpd/api_http_user.go index 26bc0063..b3292383 100644 --- a/internal/httpd/api_http_user.go +++ b/internal/httpd/api_http_user.go @@ -531,7 +531,7 @@ func changeUserPassword(w http.ResponseWriter, r *http.Request) { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } - invalidateToken(r) + invalidateToken(r, false) sendAPIResponse(w, r, err, "Password updated", http.StatusOK) } diff --git a/internal/httpd/api_mfa.go b/internal/httpd/api_mfa.go index 0b7e282e..3c4966cc 100644 --- a/internal/httpd/api_mfa.go +++ b/internal/httpd/api_mfa.go @@ -138,8 +138,7 @@ func saveTOTPConfig(w http.ResponseWriter, r *http.Request) { if claims.MustSetTwoFactorAuth { // force logout defer func() { - c := jwtTokenClaims{} - c.removeCookie(w, r, baseURL) + removeCookie(w, r, baseURL) }() } diff --git a/internal/httpd/api_shares.go b/internal/httpd/api_shares.go index 73f85299..f167aee2 100644 --- a/internal/httpd/api_shares.go +++ b/internal/httpd/api_shares.go @@ -441,13 +441,11 @@ func (s *httpdServer) checkWebClientShareCredentials(w http.ResponseWriter, r *h doRedirect() return errInvalidToken } - if tokenValidationMode != tokenValidationNoIPMatch { - ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if !util.Contains(token.Audience(), ipAddr) { - logger.Debug(logSender, "", "token for share %q is not valid for the ip address %q", share.ShareID, ipAddr) - doRedirect() - return errInvalidToken - } + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + if err := validateIPForToken(token, ipAddr); err != nil { + logger.Debug(logSender, "", "token for share %q is not valid for the ip address %q", share.ShareID, ipAddr) + doRedirect() + return err } ctx := jwtauth.NewContext(r.Context(), token, nil) claims, err := getTokenClaims(r.WithContext(ctx)) diff --git a/internal/httpd/auth_utils.go b/internal/httpd/auth_utils.go index 0ee3dabb..08d74a8a 100644 --- a/internal/httpd/auth_utils.go +++ b/internal/httpd/auth_utils.go @@ -41,6 +41,7 @@ const ( tokenAudienceAPIUser tokenAudience = "APIUser" tokenAudienceCSRF tokenAudience = "CSRF" tokenAudienceOAuth2 tokenAudience = "OAuth2" + tokenAudienceWebLogin tokenAudience = "WebLogin" ) type tokenValidation = int @@ -60,6 +61,7 @@ const ( claimMustSetSecondFactorKey = "2fa_required" claimRequiredTwoFactorProtocols = "2fa_protos" claimHideUserPageSection = "hus" + claimRef = "ref" basicRealm = "Basic realm=\"SFTPGo\"" jwtCookieKey = "jwt" ) @@ -69,7 +71,7 @@ var ( shareTokenDuration = 2 * time.Hour // csrf token duration is greater than normal token duration to reduce issues // with the login form - csrfTokenDuration = 6 * time.Hour + csrfTokenDuration = 4 * time.Hour tokenRefreshThreshold = 10 * time.Minute tokenValidationMode = tokenValidationFull ) @@ -86,6 +88,8 @@ type jwtTokenClaims struct { MustChangePassword bool RequiredTwoFactorProtocols []string HideUserPageSections int + JwtID string + Ref string } func (c *jwtTokenClaims) hasUserAudience() bool { @@ -103,6 +107,12 @@ func (c *jwtTokenClaims) asMap() map[string]any { claims[claimUsernameKey] = c.Username claims[claimPermissionsKey] = c.Permissions + if c.JwtID != "" { + claims[jwt.JwtIDKey] = c.JwtID + } + if c.Ref != "" { + claims[claimRef] = c.Ref + } if c.Role != "" { claims[claimRole] = c.Role } @@ -169,6 +179,7 @@ func (c *jwtTokenClaims) Decode(token map[string]any) { c.Permissions = nil c.Username = c.decodeString(token[claimUsernameKey]) c.Signature = c.decodeString(token[jwt.SubjectKey]) + c.JwtID = c.decodeString(token[jwt.JwtIDKey]) audience := token[jwt.AudienceKey] switch v := audience.(type) { @@ -176,6 +187,10 @@ func (c *jwtTokenClaims) Decode(token map[string]any) { c.Audience = v } + if val, ok := token[claimRef]; ok { + c.Ref = c.decodeString(val) + } + if val, ok := token[claimAPIKey]; ok { c.APIKeyID = c.decodeString(val) } @@ -236,9 +251,15 @@ func (c *jwtTokenClaims) createToken(tokenAuth *jwtauth.JWTAuth, audience tokenA claims := c.asMap() now := time.Now().UTC() - claims[jwt.JwtIDKey] = xid.New().String() + if _, ok := claims[jwt.JwtIDKey]; !ok { + claims[jwt.JwtIDKey] = xid.New().String() + } claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second) - claims[jwt.ExpirationKey] = now.Add(tokenDuration) + if audience == tokenAudienceWebLogin { + claims[jwt.ExpirationKey] = now.Add(csrfTokenDuration) + } else { + claims[jwt.ExpirationKey] = now.Add(tokenDuration) + } claims[jwt.AudienceKey] = []string{audience, ip} return tokenAuth.Encode(claims) @@ -274,21 +295,25 @@ func (c *jwtTokenClaims) createAndSetCookie(w http.ResponseWriter, r *http.Reque if audience == tokenAudienceWebShare { duration = shareTokenDuration } + setCookie(w, r, basePath, resp["access_token"].(string), duration) + + return nil +} + +func setCookie(w http.ResponseWriter, r *http.Request, cookiePath, cookieValue string, duration time.Duration) { http.SetCookie(w, &http.Cookie{ Name: jwtCookieKey, - Value: resp["access_token"].(string), - Path: basePath, + Value: cookieValue, + Path: cookiePath, Expires: time.Now().Add(duration), MaxAge: int(duration / time.Second), HttpOnly: true, Secure: isTLS(r), SameSite: http.SameSiteStrictMode, }) - - return nil } -func (c *jwtTokenClaims) removeCookie(w http.ResponseWriter, r *http.Request, cookiePath string) { +func removeCookie(w http.ResponseWriter, r *http.Request, cookiePath string) { http.SetCookie(w, &http.Cookie{ Name: jwtCookieKey, Value: "", @@ -300,10 +325,10 @@ func (c *jwtTokenClaims) removeCookie(w http.ResponseWriter, r *http.Request, co SameSite: http.SameSiteStrictMode, }) w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`) - invalidateToken(r) + invalidateToken(r, false) } -func tokenFromContext(r *http.Request) string { +func oidcTokenFromContext(r *http.Request) string { if token, ok := r.Context().Value(oidcGeneratedToken).(string); ok { return token } @@ -324,7 +349,7 @@ func isTokenInvalidated(r *http.Request) bool { var findTokenFns []func(r *http.Request) string findTokenFns = append(findTokenFns, jwtauth.TokenFromHeader) findTokenFns = append(findTokenFns, jwtauth.TokenFromCookie) - findTokenFns = append(findTokenFns, tokenFromContext) + findTokenFns = append(findTokenFns, oidcTokenFromContext) isTokenFound := false for _, fn := range findTokenFns { @@ -340,14 +365,18 @@ func isTokenInvalidated(r *http.Request) bool { return !isTokenFound } -func invalidateToken(r *http.Request) { +func invalidateToken(r *http.Request, isLoginToken bool) { + duration := tokenDuration + if isLoginToken { + duration = csrfTokenDuration + } tokenString := jwtauth.TokenFromHeader(r) if tokenString != "" { - invalidatedJWTTokens.Add(tokenString, time.Now().Add(tokenDuration).UTC()) + invalidatedJWTTokens.Add(tokenString, time.Now().Add(duration).UTC()) } tokenString = jwtauth.TokenFromCookie(r) if tokenString != "" { - invalidatedJWTTokens.Add(tokenString, time.Now().Add(tokenDuration).UTC()) + invalidatedJWTTokens.Add(tokenString, time.Now().Add(duration).UTC()) } } @@ -380,7 +409,22 @@ func getAdminFromToken(r *http.Request) *dataprovider.Admin { return admin } -func createCSRFToken(ip string) string { +func createLoginCookie(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwtauth.JWTAuth, tokenID, basePath, ip string, +) { + c := jwtTokenClaims{ + JwtID: tokenID, + } + resp, err := c.createTokenResponse(csrfTokenAuth, tokenAudienceWebLogin, ip) + if err != nil { + return + } + setCookie(w, r, basePath, resp["access_token"].(string), csrfTokenDuration) +} + +func createCSRFToken(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwtauth.JWTAuth, tokenID, + basePath string, +) string { + ip := util.GetIPFromRemoteAddress(r.RemoteAddr) claims := make(map[string]any) now := time.Now().UTC() @@ -388,7 +432,16 @@ func createCSRFToken(ip string) string { claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second) claims[jwt.ExpirationKey] = now.Add(csrfTokenDuration) claims[jwt.AudienceKey] = []string{tokenAudienceCSRF, ip} - + if tokenID != "" { + createLoginCookie(w, r, csrfTokenAuth, tokenID, basePath, ip) + claims[claimRef] = tokenID + } else { + if c, err := getTokenClaims(r); err == nil { + claims[claimRef] = c.JwtID + } else { + logger.Error(logSender, "", "unable to add reference to CSRF token: %v", err) + } + } _, tokenString, err := csrfTokenAuth.Encode(claims) if err != nil { logger.Debug(logSender, "", "unable to create CSRF token: %v", err) @@ -397,7 +450,8 @@ func createCSRFToken(ip string) string { return tokenString } -func verifyCSRFToken(tokenString, ip string) error { +func verifyCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAuth) error { + tokenString := r.Form.Get(csrfFormToken) token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString) if err != nil || token == nil { logger.Debug(logSender, "", "error validating CSRF token %q: %v", tokenString, err) @@ -409,17 +463,60 @@ func verifyCSRFToken(tokenString, ip string) error { return errors.New("the form token is not valid") } - if tokenValidationMode != tokenValidationNoIPMatch { - if !util.Contains(token.Audience(), ip) { - logger.Debug(logSender, "", "error validating CSRF token IP audience") - return errors.New("the form token is not valid") - } + if err := validateIPForToken(token, util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil { + logger.Debug(logSender, "", "error validating CSRF token IP audience") + return errors.New("the form token is not valid") + } + claims, err := getTokenClaims(r) + if err != nil { + logger.Debug(logSender, "", "error getting token claims for CSRF validation: %v", err) + return err + } + ref, ok := token.Get(claimRef) + if !ok { + logger.Debug(logSender, "", "error validating CSRF token, missing reference") + return errors.New("the form token is not valid") + } + if claims.JwtID == "" || claims.JwtID != ref.(string) { + logger.Debug(logSender, "", "error validating CSRF reference, id %q, reference %q", claims.JwtID, ref) + return errors.New("unexpected form token") } return nil } -func createOAuth2Token(state, ip string) string { +func verifyLoginCookie(r *http.Request) error { + token, _, err := jwtauth.FromContext(r.Context()) + if err != nil || token == nil { + logger.Debug(logSender, "", "error getting login token: %v", err) + return errInvalidToken + } + if isTokenInvalidated(r) { + logger.Debug(logSender, "", "the login token has been invalidated") + return errInvalidToken + } + if !util.Contains(token.Audience(), tokenAudienceWebLogin) { + logger.Debug(logSender, "", "the token with id %q is not valid for audience %q", token.JwtID(), tokenAudienceWebLogin) + return errInvalidToken + } + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + if err := validateIPForToken(token, ipAddr); err != nil { + return err + } + return nil +} + +func verifyLoginCookieAndCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAuth) error { + if err := verifyLoginCookie(r); err != nil { + return err + } + if err := verifyCSRFToken(r, csrfTokenAuth); err != nil { + return err + } + return nil +} + +func createOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, state, ip string) string { claims := make(map[string]any) now := time.Now().UTC() @@ -436,7 +533,7 @@ func createOAuth2Token(state, ip string) string { return tokenString } -func verifyOAuth2Token(tokenString, ip string) (string, error) { +func verifyOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, tokenString, ip string) (string, error) { token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString) if err != nil || token == nil { logger.Debug(logSender, "", "error validating OAuth2 token %q: %v", tokenString, err) @@ -451,11 +548,9 @@ func verifyOAuth2Token(tokenString, ip string) (string, error) { return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState) } - if tokenValidationMode != tokenValidationNoIPMatch { - if !util.Contains(token.Audience(), ip) { - logger.Debug(logSender, "", "error validating OAuth2 token IP audience") - return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState) - } + if err := validateIPForToken(token, ip); err != nil { + logger.Debug(logSender, "", "error validating OAuth2 token IP audience") + return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState) } if val, ok := token.Get(jwt.JwtIDKey); ok { if state, ok := val.(string); ok { @@ -465,3 +560,12 @@ func verifyOAuth2Token(tokenString, ip string) (string, error) { logger.Debug(logSender, "", "jti not found in OAuth2 token") return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState) } + +func validateIPForToken(token jwt.Token, ip string) error { + if tokenValidationMode != tokenValidationNoIPMatch { + if !util.Contains(token.Audience(), ip) { + return errInvalidToken + } + } + return nil +} diff --git a/internal/httpd/httpd.go b/internal/httpd/httpd.go index 06410c33..64677acc 100644 --- a/internal/httpd/httpd.go +++ b/internal/httpd/httpd.go @@ -31,8 +31,6 @@ import ( "time" "github.com/go-chi/chi/v5" - "github.com/go-chi/jwtauth/v5" - "github.com/lestrrat-go/jwx/v2/jwa" "github.com/drakkan/sftpgo/v2/internal/acme" "github.com/drakkan/sftpgo/v2/internal/common" @@ -196,7 +194,6 @@ var ( cleanupTicker *time.Ticker cleanupDone chan bool invalidatedJWTTokens tokenManager - csrfTokenAuth *jwtauth.JWTAuth webRootPath string webBasePath string webBaseAdminPath string @@ -967,7 +964,6 @@ func (c *Conf) Initialize(configDir string, isShared int) error { c.SigningPassphrase = passphrase } - csrfTokenAuth = jwtauth.New(jwa.HS256.String(), getSigningKey(c.SigningPassphrase), nil) hideSupportLink = c.HideSupportLink exitChannel := make(chan error, 1) diff --git a/internal/httpd/httpd_test.go b/internal/httpd/httpd_test.go index 17578e4e..f77fd1f9 100644 --- a/internal/httpd/httpd_test.go +++ b/internal/httpd/httpd_test.go @@ -3289,32 +3289,44 @@ func TestLoginRedirectNext(t *testing.T) { checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), fmt.Sprintf("action=%q", redirectURI)) // now login the user and check the redirect - csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) + loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) assert.NoError(t, err) + assert.NotEmpty(t, loginCookie) form := getLoginForm(defaultUsername, defaultPassword, csrfToken) req, err = http.NewRequest(http.MethodPost, redirectURI, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.RequestURI = redirectURI + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) assert.Equal(t, uri, rr.Header().Get("Location")) // unsafe URI + loginCookie, csrfToken, err = getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + assert.NotEmpty(t, loginCookie) + form = getLoginForm(defaultUsername, defaultPassword, csrfToken) unsafeURI := webClientLoginPath + "?next=" + url.QueryEscape("http://example.net") req, err = http.NewRequest(http.MethodPost, unsafeURI, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.RequestURI = unsafeURI + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) assert.Equal(t, webClientFilesPath, rr.Header().Get("Location")) + loginCookie, csrfToken, err = getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + assert.NotEmpty(t, loginCookie) + form = getLoginForm(defaultUsername, defaultPassword, csrfToken) unsupportedURI := webClientLoginPath + "?next=" + url.QueryEscape(webClientProfilePath) req, err = http.NewRequest(http.MethodPost, unsupportedURI, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.RequestURI = unsupportedURI + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) @@ -3397,7 +3409,7 @@ func TestMustChangePasswordRequirement(t *testing.T) { rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) - csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webChangeClientPwdPath, webToken) assert.NoError(t, err) form := make(url.Values) form.Set(csrfFormToken, csrfToken) @@ -3682,9 +3694,10 @@ func TestAdminMustChangePasswordRequirement(t *testing.T) { setJWTCookieForReq(req, webToken) rr := executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) - - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + // The change password page should be accessible, we get the CSRF from it. + csrfToken, err := getCSRFTokenFromInternalPageMock(webChangeAdminPwdPath, webToken) assert.NoError(t, err) + form := make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("current_password", defaultTokenAuthPass) @@ -6822,10 +6835,10 @@ func TestNamingRules(t *testing.T) { return } - csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath) - assert.NoError(t, err) token, err = getJWTWebClientTokenFromTestServer(user.Username, defaultPassword) assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, token) + assert.NoError(t, err) form := make(url.Values) form.Set(csrfFormToken, csrfToken) req, err := http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) @@ -6836,6 +6849,8 @@ func TestNamingRules(t *testing.T) { checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidUser) // test user reset password. Setting the new password will fail because the username is not valid + loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) + assert.NoError(t, err) form = make(url.Values) form.Set("username", user.Username) form.Set(csrfFormToken, csrfToken) @@ -6843,6 +6858,7 @@ func TestNamingRules(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) @@ -6855,6 +6871,7 @@ func TestNamingRules(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) @@ -6896,7 +6913,7 @@ func TestNamingRules(t *testing.T) { token, err = getJWTWebTokenFromTestServer(admin.Username, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err = getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminProfilePath, token) assert.NoError(t, err) form = make(url.Values) form.Set(csrfFormToken, csrfToken) @@ -6927,6 +6944,8 @@ func TestNamingRules(t *testing.T) { checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "the following characters are allowed") // test admin reset password + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) form = make(url.Values) form.Set("username", admin.Username) form.Set(csrfFormToken, csrfToken) @@ -6934,10 +6953,13 @@ func TestNamingRules(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.GreaterOrEqual(t, len(lastResetCode), 20) + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) form = make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("code", lastResetCode) @@ -6946,6 +6968,7 @@ func TestNamingRules(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) @@ -7097,12 +7120,13 @@ func TestSaveErrors(t *testing.T) { assert.NoError(t, err) assert.Contains(t, string(resp), "the following characters are allowed") - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form := getLoginForm(a.Username, a.Password, csrfToken) req, err := http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr := executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) @@ -7110,6 +7134,8 @@ func TestSaveErrors(t *testing.T) { cookie, err := getCookieFromResponse(rr) assert.NoError(t, err) + csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) form = make(url.Values) form.Set("recovery_code", recCode) form.Set(csrfFormToken, csrfToken) @@ -7122,12 +7148,13 @@ func TestSaveErrors(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nError500Message) - csrfToken, err = getCSRFToken(httpBaseURL + webClientLoginPath) + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form = getLoginForm(u.Username, u.Password, csrfToken) req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) @@ -7135,6 +7162,8 @@ func TestSaveErrors(t *testing.T) { cookie, err = getCookieFromResponse(rr) assert.NoError(t, err) + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) form = make(url.Values) form.Set("recovery_code", recCode) form.Set(csrfFormToken, csrfToken) @@ -7277,7 +7306,7 @@ func TestProviderErrors(t *testing.T) { rr = executeRequest(req) checkResponseCode(t, http.StatusInternalServerError, rr) // password reset errors - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form := make(url.Values) form.Set("username", "username") @@ -7285,6 +7314,7 @@ func TestProviderErrors(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) @@ -9271,12 +9301,13 @@ func TestAdminTwoFactorLogin(t *testing.T) { rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form := getLoginForm(altAdminUsername, altAdminPassword, csrfToken) req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) @@ -9331,6 +9362,8 @@ func TestAdminTwoFactorLogin(t *testing.T) { assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) form.Set(csrfFormToken, csrfToken) form.Set("passcode", "invalid_passcode") req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) @@ -9369,10 +9402,13 @@ func TestAdminTwoFactorLogin(t *testing.T) { rr = executeRequest(req) assert.Equal(t, http.StatusNotFound, rr.Code) // get a new cookie and login using a recovery code + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) @@ -9393,6 +9429,8 @@ func TestAdminTwoFactorLogin(t *testing.T) { assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) form.Set(csrfFormToken, csrfToken) form.Set("recovery_code", "") req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) @@ -9443,16 +9481,21 @@ func TestAdminTwoFactorLogin(t *testing.T) { } assert.True(t, found) // the same recovery code cannot be reused + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location")) cookie, err = getCookieFromResponse(rr) assert.NoError(t, err) + csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) form = make(url.Values) form.Set("recovery_code", recoveryCode) form.Set(csrfFormToken, csrfToken) @@ -9475,12 +9518,23 @@ func TestAdminTwoFactorLogin(t *testing.T) { assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location")) cookie, err = getCookieFromResponse(rr) @@ -9508,6 +9562,8 @@ func TestAdminTwoFactorLogin(t *testing.T) { checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "two-factor authentication is not enabled") + csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) form = make(url.Values) form.Set("recovery_code", recoveryCode) form.Set(csrfFormToken, csrfToken) @@ -9781,7 +9837,7 @@ func TestSMTPConfig(t *testing.T) { tokenHeader := "X-CSRF-TOKEN" webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webConfigsPath, webToken) assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, smtpTestURL, bytes.NewBuffer([]byte("{"))) assert.NoError(t, err) @@ -9865,7 +9921,7 @@ func TestOAuth2TokenRequest(t *testing.T) { tokenHeader := "X-CSRF-TOKEN" webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webConfigsPath, webToken) assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, webOAuth2TokenPath, bytes.NewBuffer([]byte("{"))) assert.NoError(t, err) @@ -10006,14 +10062,24 @@ func TestWebUserTwoFactorLogin(t *testing.T) { rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) - csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath) + loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) assert.NoError(t, err) form := getLoginForm(defaultUsername, defaultPassword, csrfToken) + // CSRF verification fails if there is no cookie req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) cookie, err := getCookieFromResponse(rr) @@ -10030,6 +10096,13 @@ func TestWebUserTwoFactorLogin(t *testing.T) { assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) + // invalid IP address + req, err = http.NewRequest(http.MethodGet, webClientTwoFactorPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = "6.7.8.9:4567" + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) req, err = http.NewRequest(http.MethodGet, webClientTwoFactorRecoveryPath, nil) assert.NoError(t, err) @@ -10065,6 +10138,8 @@ func TestWebUserTwoFactorLogin(t *testing.T) { assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorPath, cookie) + assert.NoError(t, err) form.Set(csrfFormToken, csrfToken) form.Set("passcode", "invalid_user_passcode") req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) @@ -10104,10 +10179,13 @@ func TestWebUserTwoFactorLogin(t *testing.T) { rr = executeRequest(req) assert.Equal(t, http.StatusNotFound, rr.Code) // get a new cookie and login using a recovery code + loginCookie, csrfToken, err = getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) + assert.NoError(t, err) form = getLoginForm(defaultUsername, defaultPassword, csrfToken) req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) @@ -10127,6 +10205,8 @@ func TestWebUserTwoFactorLogin(t *testing.T) { assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) form.Set(csrfFormToken, csrfToken) form.Set("recovery_code", "") req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) @@ -10192,16 +10272,22 @@ func TestWebUserTwoFactorLogin(t *testing.T) { } assert.True(t, found) // the same recovery code cannot be reused + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) form = getLoginForm(defaultUsername, defaultPassword, csrfToken) req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) cookie, err = getCookieFromResponse(rr) assert.NoError(t, err) + + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) form = make(url.Values) form.Set("recovery_code", recoveryCode) form.Set(csrfFormToken, csrfToken) @@ -10224,10 +10310,13 @@ func TestWebUserTwoFactorLogin(t *testing.T) { assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) form = getLoginForm(defaultUsername, defaultPassword, csrfToken) req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) @@ -10249,11 +10338,12 @@ func TestWebUserTwoFactorLogin(t *testing.T) { checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "two-factor authentication is not enabled") + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) form = make(url.Values) form.Set("recovery_code", recoveryCode) form.Set("passcode", passcode) form.Set(csrfFormToken, csrfToken) - req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) setJWTCookieForReq(req, cookie) @@ -10325,7 +10415,7 @@ func TestWebUserTwoFactoryLoginRedirect(t *testing.T) { rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath) + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form := getLoginForm(defaultUsername, defaultPassword, csrfToken) uri := webClientFilesPath + "?path=%2F" @@ -10335,6 +10425,7 @@ func TestWebUserTwoFactoryLoginRedirect(t *testing.T) { assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.RequestURI = loginURI + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) @@ -10342,20 +10433,29 @@ func TestWebUserTwoFactoryLoginRedirect(t *testing.T) { cookie, err := getCookieFromResponse(rr) assert.NoError(t, err) // test unsafe redirects + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = getLoginForm(defaultUsername, defaultPassword, csrfToken) externalURI := webClientLoginPath + "?next=" + url.QueryEscape("https://example.com") req, err = http.NewRequest(http.MethodPost, externalURI, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.RequestURI = externalURI + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) + + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = getLoginForm(defaultUsername, defaultPassword, csrfToken) internalURI := webClientLoginPath + "?next=" + url.QueryEscape(webClientMFAPath) req, err = http.NewRequest(http.MethodPost, internalURI, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr req.RequestURI = internalURI + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) @@ -10369,6 +10469,8 @@ func TestWebUserTwoFactoryLoginRedirect(t *testing.T) { checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), fmt.Sprintf("action=%q", expectedURI)) // login with the passcode + csrfToken, err = getCSRFTokenFromInternalPageMock(expectedURI, cookie) + assert.NoError(t, err) passcode, err := generateTOTPPasscode(key.Secret()) assert.NoError(t, err) form = make(url.Values) @@ -10800,18 +10902,22 @@ func TestMFAInvalidSecret(t *testing.T) { checkResponseCode(t, http.StatusInternalServerError, rr) assert.Contains(t, rr.Body.String(), "Unable to decrypt recovery codes") - csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath) + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form := getLoginForm(defaultUsername, defaultPassword, csrfToken) req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) cookie, err := getCookieFromResponse(rr) assert.NoError(t, err) + + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorPath, cookie) + assert.NoError(t, err) form = make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("passcode", "123456") @@ -10823,6 +10929,8 @@ func TestMFAInvalidSecret(t *testing.T) { rr = executeRequest(req) assert.Equal(t, http.StatusInternalServerError, rr.Code) + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) form = make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("recovery_code", "RC-123456") @@ -10868,18 +10976,22 @@ func TestMFAInvalidSecret(t *testing.T) { err = dataprovider.UpdateAdmin(&admin, "", "", "") assert.NoError(t, err) - csrfToken, err = getCSRFToken(httpBaseURL + webLoginPath) + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location")) cookie, err = getCookieFromResponse(rr) assert.NoError(t, err) + + csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) form = make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("passcode", "123456") @@ -10891,6 +11003,8 @@ func TestMFAInvalidSecret(t *testing.T) { rr = executeRequest(req) assert.Equal(t, http.StatusInternalServerError, rr.Code) + csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) form = make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("recovery_code", "RC-123456") @@ -12744,8 +12858,6 @@ func TestWebClientLoginMock(t *testing.T) { assert.NoError(t, err) webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath) - assert.NoError(t, err) // a web token is not valid for API or WebAdmin usage req, _ := http.NewRequest(http.MethodGet, serverStatusPath, nil) setBearerForReq(req, webToken) @@ -12807,6 +12919,8 @@ func TestWebClientLoginMock(t *testing.T) { assert.NoError(t, err) apiUserToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) + assert.NoError(t, err) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) @@ -13125,7 +13239,7 @@ func TestMaxSessions(t *testing.T) { checkResponseCode(t, http.StatusTooManyRequests, rr) assert.Contains(t, rr.Body.String(), "too many open sessions") // web client requests - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) assert.NoError(t, err) form := make(url.Values) form.Set(csrfFormToken, csrfToken) @@ -13177,6 +13291,8 @@ func TestMaxSessions(t *testing.T) { err = smtpCfg.Initialize(configDir, true) assert.NoError(t, err) + loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) + assert.NoError(t, err) form = make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("username", user.Username) @@ -13184,10 +13300,14 @@ func TestMaxSessions(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.GreaterOrEqual(t, len(lastResetCode), 20) + + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) form = make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("password", defaultPassword) @@ -13196,6 +13316,7 @@ func TestMaxSessions(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) @@ -13222,8 +13343,6 @@ func TestWebConfigsMock(t *testing.T) { assert.NoError(t, err) webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath) - assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, webConfigsPath, nil) assert.NoError(t, err) @@ -13239,6 +13358,8 @@ func TestWebConfigsMock(t *testing.T) { rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) // parse form error + csrfToken, err := getCSRFTokenFromInternalPageMock(webConfigsPath, webToken) + assert.NoError(t, err) form.Set(csrfFormToken, csrfToken) req, err = http.NewRequest(http.MethodPost, webConfigsPath+"?p=p%C3%AO%GH", bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) @@ -13483,7 +13604,7 @@ func TestSFTPLoopError(t *testing.T) { err = smtpCfg.Initialize(configDir, true) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form := make(url.Values) form.Set(csrfFormToken, csrfToken) @@ -13492,10 +13613,14 @@ func TestSFTPLoopError(t *testing.T) { req, err := http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr := executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.GreaterOrEqual(t, len(lastResetCode), 20) + + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) form = make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("password", defaultPassword) @@ -13504,6 +13629,7 @@ func TestSFTPLoopError(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) @@ -13563,8 +13689,6 @@ func TestWebClientChangePwd(t *testing.T) { assert.NoError(t, err) webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath) - assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, webChangeClientPwdPath, nil) assert.NoError(t, err) @@ -13587,6 +13711,8 @@ func TestWebClientChangePwd(t *testing.T) { checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + csrfToken, err := getCSRFTokenFromInternalPageMock(webChangeClientPwdPath, webToken) + assert.NoError(t, err) form.Set(csrfFormToken, csrfToken) req, _ = http.NewRequest(http.MethodPost, webChangeClientPwdPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr @@ -13641,6 +13767,9 @@ func TestWebClientChangePwd(t *testing.T) { webToken, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword+"1") assert.NoError(t, err) + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) form.Set("current_password", defaultPassword+"1") form.Set("new_password1", defaultPassword) form.Set("new_password2", defaultPassword) @@ -14085,8 +14214,6 @@ func TestShareMaxExpiration(t *testing.T) { assert.NoError(t, err) webClientToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath) - assert.NoError(t, err) s := dataprovider.Share{ Name: "test share", @@ -14140,6 +14267,9 @@ func TestShareMaxExpiration(t *testing.T) { rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "share must expire before") + + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientSharePath, webClientToken) + assert.NoError(t, err) form := make(url.Values) form.Set("name", s.Name) form.Set("scope", strconv.Itoa(int(s.Scope))) @@ -14252,12 +14382,13 @@ func TestWebClientShareCredentials(t *testing.T) { checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) // set the CSRF token - csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath) + loginCookie, csrfToken, err := getCSRFTokenMock(loginURI, defaultRemoteAddr) assert.NoError(t, err) form.Set(csrfFormToken, csrfToken) req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) @@ -14301,30 +14432,42 @@ func TestWebClientShareCredentials(t *testing.T) { rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) // try to login with invalid credentials + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) form.Set("share_password", "") req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) // login with the next param set + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) form.Set("share_password", defaultPassword) nextURI := path.Join(webClientPubSharesPath, shareReadID, "browse") loginURI = path.Join(webClientPubSharesPath, shareReadID, fmt.Sprintf("login?next=%s", url.QueryEscape(nextURI))) req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) assert.Equal(t, nextURI, rr.Header().Get("Location")) // try to login to a missing share + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) loginURI = path.Join(webClientPubSharesPath, "missing", "login") req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) @@ -15963,7 +16106,7 @@ func TestWebClientExistenceCheck(t *testing.T) { webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, webClientExistPath, nil) @@ -16322,7 +16465,7 @@ func TestWebGetFiles(t *testing.T) { assert.NoError(t, err) assert.Len(t, dirEntries, 1) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) assert.NoError(t, err) form := make(url.Values) form.Set("files", fmt.Sprintf(`["%s","%s","%s"]`, testFileName, testDir, testFileName+extensions[2])) @@ -16626,7 +16769,8 @@ func TestRenameDifferentResource(t *testing.T) { assert.NoError(t, err) webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) + assert.NoError(t, err) getStatusResponse := func(taskID string) int { req, _ := http.NewRequest(http.MethodGet, webClientTasksPath+"/"+url.PathEscape(taskID), nil) @@ -17380,7 +17524,7 @@ func TestWebClientTasksAPI(t *testing.T) { webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) assert.NoError(t, err) webToken1, err := getJWTWebClientTokenFromTestServer(user1.Username, defaultPassword) assert.NoError(t, err) @@ -18437,7 +18581,7 @@ func TestCompressionErrorMock(t *testing.T) { webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) assert.NoError(t, err) form := make(url.Values) @@ -18687,12 +18831,13 @@ func TestWebAdminSetupMock(t *testing.T) { checkResponseCode(t, http.StatusFound, rr) assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) - csrfToken, err := getCSRFToken(httpBaseURL + webAdminSetupPath) + loginCookie, csrfToken, err := getCSRFTokenMock(webAdminSetupPath, defaultRemoteAddr) assert.NoError(t, err) form := make(url.Values) req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) @@ -18700,6 +18845,7 @@ func TestWebAdminSetupMock(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) @@ -18708,6 +18854,7 @@ func TestWebAdminSetupMock(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) @@ -18716,6 +18863,7 @@ func TestWebAdminSetupMock(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) @@ -18734,6 +18882,7 @@ func TestWebAdminSetupMock(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) @@ -18747,6 +18896,7 @@ func TestWebAdminSetupMock(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) @@ -18898,12 +19048,13 @@ func TestWebAdminLoginMock(t *testing.T) { rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) // now try using wrong password form := getLoginForm(defaultTokenAuthUser, "wrong pwd", csrfToken) req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) @@ -18912,6 +19063,7 @@ func TestWebAdminLoginMock(t *testing.T) { form = getLoginForm("wrong username", defaultTokenAuthPass, csrfToken) req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) @@ -18926,10 +19078,12 @@ func TestWebAdminLoginMock(t *testing.T) { assert.NoError(t, err) rAddr := "127.1.1.1:1234" - csrfToken, err = getCSRFTokenMock(webLoginPath, rAddr) + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, rAddr) assert.NoError(t, err) + assert.NotEmpty(t, loginCookie) form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.RemoteAddr = rAddr rr = executeRequest(req) @@ -18937,20 +19091,24 @@ func TestWebAdminLoginMock(t *testing.T) { assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) rAddr = "10.9.9.9:1234" - csrfToken, err = getCSRFTokenMock(webLoginPath, rAddr) + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, rAddr) assert.NoError(t, err) + assert.NotEmpty(t, loginCookie) form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.RemoteAddr = rAddr + setLoginCookie(req, loginCookie) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) rAddr = "127.0.1.1:4567" - csrfToken, err = getCSRFTokenMock(webLoginPath, rAddr) + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, rAddr) assert.NoError(t, err) + assert.NotEmpty(t, loginCookie) form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.RemoteAddr = rAddr req.Header.Set("X-Forwarded-For", "10.9.9.9") @@ -18999,10 +19157,10 @@ func TestWebUserShare(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath) - assert.NoError(t, err) token, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, token) + assert.NoError(t, err) userAPItoken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) @@ -19240,10 +19398,10 @@ func TestWebUserShareNoPasswordDisabled(t *testing.T) { user.Filters.DefaultSharesExpiration = 30 user, _, err = httpdtest.UpdateUser(u, http.StatusOK, "") assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath) - assert.NoError(t, err) token, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientSharePath, token) + assert.NoError(t, err) userAPItoken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) @@ -19324,14 +19482,75 @@ func TestWebUserShareNoPasswordDisabled(t *testing.T) { assert.NoError(t, err) } +func TestInvalidCSRF(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + for _, loginURL := range []string{webClientLoginPath, webLoginPath} { + // try using an invalid CSRF token + loginCookie1, csrfToken1, err := getCSRFTokenMock(loginURL, defaultRemoteAddr) + assert.NoError(t, err) + assert.NotEmpty(t, loginCookie1) + assert.NotEmpty(t, csrfToken1) + loginCookie2, csrfToken2, err := getCSRFTokenMock(loginURL, defaultRemoteAddr) + assert.NoError(t, err) + assert.NotEmpty(t, loginCookie2) + assert.NotEmpty(t, csrfToken2) + rAddr := "1.2.3.4" + loginCookie3, csrfToken3, err := getCSRFTokenMock(loginURL, rAddr) + assert.NoError(t, err) + assert.NotEmpty(t, loginCookie3) + assert.NotEmpty(t, csrfToken3) + + form := getLoginForm(defaultUsername, defaultPassword, csrfToken1) + req, err := http.NewRequest(http.MethodPost, loginURL, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.RequestURI = loginURL + setLoginCookie(req, loginCookie2) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + // use a CSRF token as login cookie (invalid audience) + form = getLoginForm(defaultUsername, defaultPassword, csrfToken1) + req, err = http.NewRequest(http.MethodPost, loginURL, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.RequestURI = loginURL + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", csrfToken1)) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + // invalid IP + form = getLoginForm(defaultUsername, defaultPassword, csrfToken3) + req, err = http.NewRequest(http.MethodPost, loginURL, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.RequestURI = loginURL + setLoginCookie(req, loginCookie3) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + } + + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + func TestWebUserProfile(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath) - assert.NoError(t, err) token, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, token) + assert.NoError(t, err) email := "user@user.com" description := "User" @@ -19407,6 +19626,8 @@ func TestWebUserProfile(t *testing.T) { assert.NoError(t, err) token, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientProfilePath, token) + assert.NoError(t, err) form.Set("allow_api_key_auth", "0") form.Set(csrfFormToken, csrfToken) @@ -19431,9 +19652,12 @@ func TestWebUserProfile(t *testing.T) { assert.NoError(t, err) token, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientProfilePath, token) + assert.NoError(t, err) form.Set("public_keys[0][public_key]", testPubKey) form.Set("public_keys[1][public_key]", testPubKey1) form.Set("tls_certs[0][tls_cert]", "") + form.Set(csrfFormToken, csrfToken) req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -19454,8 +19678,11 @@ func TestWebUserProfile(t *testing.T) { assert.NoError(t, err) token, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientProfilePath, token) + assert.NoError(t, err) form.Set("email", "newemail@user.com") form.Set("description", "new description") + form.Set(csrfFormToken, csrfToken) req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -19477,6 +19704,9 @@ func TestWebUserProfile(t *testing.T) { assert.NoError(t, err) token, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) + csrfToken, err = getCSRFTokenFromInternalPageMock(webChangeClientPwdPath, token) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -19507,7 +19737,7 @@ func TestWebAdminProfile(t *testing.T) { assert.NoError(t, err) token, err := getJWTWebTokenFromTestServer(admin.Username, altAdminPassword) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminProfilePath, token) assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, webAdminProfilePath, nil) assert.NoError(t, err) @@ -19584,7 +19814,7 @@ func TestWebAdminPwdChange(t *testing.T) { token, err := getJWTWebTokenFromTestServer(admin.Username, altAdminPassword) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webChangeAdminPwdPath, token) assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, webChangeAdminPwdPath, nil) assert.NoError(t, err) @@ -20053,7 +20283,7 @@ func TestBasicWebUsersMock(t *testing.T) { setJWTCookieForReq(req, webToken) rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) assert.NoError(t, err) form := make(url.Values) form.Set("username", user.Username) @@ -20115,7 +20345,7 @@ func TestWebAdminBasicMock(t *testing.T) { admin := getTestAdmin() admin.Username = altAdminUsername admin.Password = altAdminPassword - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminPath, token) assert.NoError(t, err) form := make(url.Values) form.Set("username", admin.Username) @@ -20389,7 +20619,7 @@ func TestWebAdminGroupsMock(t *testing.T) { admin := getTestAdmin() admin.Username = altAdminUsername admin.Password = altAdminPassword - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminPath, token) assert.NoError(t, err) form := make(url.Values) form.Set(csrfFormToken, csrfToken) @@ -20527,7 +20757,7 @@ func TestAdminUpdateSelfMock(t *testing.T) { assert.NoError(t, err) token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminPath, token) assert.NoError(t, err) form := make(url.Values) form.Set("username", admin.Username) @@ -20585,9 +20815,8 @@ func TestWebMaintenanceMock(t *testing.T) { setJWTCookieForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webMaintenancePath, token) assert.NoError(t, err) - form := make(url.Values) form.Set("mode", "a") b, contentType, _ := getMultipartFormData(form, "", "") @@ -20706,7 +20935,7 @@ func TestWebUserAddMock(t *testing.T) { assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) assert.NoError(t, err) group1 := getTestGroup() group1.Name += "_1" @@ -21161,8 +21390,6 @@ func TestWebUserUpdateMock(t *testing.T) { assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) - assert.NoError(t, err) user := getTestUser() user.Filters.BandwidthLimits = []sdk.BandwidthLimit{ { @@ -21201,6 +21428,8 @@ func TestWebUserUpdateMock(t *testing.T) { checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), "Invalid token") + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, userToken) + assert.NoError(t, err) req, err = http.NewRequest(http.MethodPost, webClientTOTPSavePath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) setJWTCookieForReq(req, userToken) @@ -21285,6 +21514,8 @@ func TestWebUserUpdateMock(t *testing.T) { checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + csrfToken, err = getCSRFTokenFromInternalPageMock(webUserPath, webToken) + assert.NoError(t, err) form.Set(csrfFormToken, csrfToken) b, contentType, _ = getMultipartFormData(form, "", "") req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) @@ -21443,7 +21674,7 @@ func TestUserTemplateWithFoldersMock(t *testing.T) { token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webTemplateFolder, token) assert.NoError(t, err) user := getTestUser() form := make(url.Values) @@ -21539,7 +21770,7 @@ func TestUserTemplateWithFoldersMock(t *testing.T) { func TestUserSaveFromTemplateMock(t *testing.T) { token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webTemplateFolder, token) assert.NoError(t, err) user1 := "u1" user2 := "u2" @@ -21612,6 +21843,8 @@ func TestUserSaveFromTemplateMock(t *testing.T) { func TestUserTemplateMock(t *testing.T) { token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webTemplateFolder, token) + assert.NoError(t, err) user := getTestUser() user.FsConfig.Provider = sdk.S3FilesystemProvider user.FsConfig.S3Config.Bucket = "test" @@ -21622,8 +21855,6 @@ func TestUserTemplateMock(t *testing.T) { user.FsConfig.S3Config.UploadConcurrency = 4 user.FsConfig.S3Config.DownloadPartSize = 6 user.FsConfig.S3Config.DownloadConcurrency = 3 - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) - assert.NoError(t, err) form := make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("username", user.Username) @@ -21768,7 +21999,7 @@ func TestUserTemplateMock(t *testing.T) { func TestUserPlaceholders(t *testing.T) { token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, token) assert.NoError(t, err) u := getTestUser() u.HomeDir = filepath.Join(os.TempDir(), "%username%_%password%") @@ -21841,7 +22072,7 @@ func TestUserPlaceholders(t *testing.T) { func TestFolderPlaceholders(t *testing.T) { token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webFolderPath, token) assert.NoError(t, err) folderName := "folderName" form := make(url.Values) @@ -21885,7 +22116,7 @@ func TestFolderSaveFromTemplateMock(t *testing.T) { folder2 := "f2" token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webTemplateFolder, token) assert.NoError(t, err) form := make(url.Values) form.Set("name", "name") @@ -21936,7 +22167,7 @@ func TestFolderTemplateMock(t *testing.T) { mappedPath := filepath.Join(os.TempDir(), "%name%mapped%name%path") token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webTemplateFolder, token) assert.NoError(t, err) form := make(url.Values) form.Set("name", folderName) @@ -22082,7 +22313,7 @@ func TestWebUserS3Mock(t *testing.T) { assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) assert.NoError(t, err) user := getTestUser() userAsJSON := getUserAsJSON(t, user) @@ -22320,7 +22551,7 @@ func TestWebUserGCSMock(t *testing.T) { assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) assert.NoError(t, err) user := getTestUser() userAsJSON := getUserAsJSON(t, user) @@ -22448,7 +22679,7 @@ func TestWebUserHTTPFsMock(t *testing.T) { assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) assert.NoError(t, err) user := getTestUser() userAsJSON := getUserAsJSON(t, user) @@ -22575,7 +22806,7 @@ func TestWebUserAzureBlobMock(t *testing.T) { assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) assert.NoError(t, err) user := getTestUser() userAsJSON := getUserAsJSON(t, user) @@ -22772,7 +23003,7 @@ func TestWebUserCryptMock(t *testing.T) { assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) assert.NoError(t, err) user := getTestUser() userAsJSON := getUserAsJSON(t, user) @@ -22879,7 +23110,7 @@ func TestWebUserSFTPFsMock(t *testing.T) { assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) assert.NoError(t, err) user := getTestUser() userAsJSON := getUserAsJSON(t, user) @@ -23035,7 +23266,7 @@ func TestWebUserRole(t *testing.T) { assert.NoError(t, err) webToken, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) assert.NoError(t, err) user := getTestUser() form := make(url.Values) @@ -23098,7 +23329,7 @@ func TestWebEventAction(t *testing.T) { assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminEventActionPath, webToken) assert.NoError(t, err) action := dataprovider.BaseEventAction{ ID: 81, @@ -23643,7 +23874,7 @@ func TestWebEventAction(t *testing.T) { func TestWebEventRule(t *testing.T) { webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminEventRulePath, webToken) assert.NoError(t, err) a := dataprovider.BaseEventAction{ Name: "web_action", @@ -23961,7 +24192,7 @@ func TestWebEventRule(t *testing.T) { func TestWebIPListEntries(t *testing.T) { webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, webIPListPath+"/mode", nil) @@ -24147,7 +24378,7 @@ func TestWebIPListEntries(t *testing.T) { func TestWebRole(t *testing.T) { webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminRolePath, webToken) assert.NoError(t, err) role := getTestRole() form := make(url.Values) @@ -24267,7 +24498,7 @@ func TestNameParamSingleSlash(t *testing.T) { assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webGroupPath, webToken) assert.NoError(t, err) group := getTestGroup() group.Name = "/" @@ -24330,7 +24561,7 @@ func TestAddWebGroup(t *testing.T) { assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webGroupPath, webToken) assert.NoError(t, err) group := getTestGroup() group.UserSettings = dataprovider.GroupUserSettings{ @@ -24516,7 +24747,7 @@ func TestAddWebFoldersMock(t *testing.T) { assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webFolderPath, webToken) assert.NoError(t, err) mappedPath := filepath.Clean(os.TempDir()) folderName := filepath.Base(mappedPath) @@ -24594,7 +24825,7 @@ func TestHTTPFsWebFolderMock(t *testing.T) { assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webFolderPath, webToken) assert.NoError(t, err) mappedPath := filepath.Clean(os.TempDir()) folderName := filepath.Base(mappedPath) @@ -24689,7 +24920,7 @@ func TestS3WebFolderMock(t *testing.T) { assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webFolderPath, webToken) assert.NoError(t, err) mappedPath := filepath.Clean(os.TempDir()) folderName := filepath.Base(mappedPath) @@ -24834,7 +25065,7 @@ func TestUpdateWebGroupMock(t *testing.T) { assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webGroupPath, webToken) assert.NoError(t, err) group, _, err := httpdtest.AddGroup(getTestGroup(), http.StatusCreated) assert.NoError(t, err) @@ -24939,7 +25170,7 @@ func TestUpdateWebFolderMock(t *testing.T) { assert.NoError(t, err) apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webFolderPath, webToken) assert.NoError(t, err) folderName := "vfolderupdate" folderDesc := "updated desc" @@ -25156,7 +25387,7 @@ func TestAdminForgotPassword(t *testing.T) { rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) assert.NoError(t, err) form := make(url.Values) @@ -25165,6 +25396,7 @@ func TestAdminForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusForbidden, rr.Code) @@ -25173,6 +25405,7 @@ func TestAdminForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) @@ -25183,6 +25416,7 @@ func TestAdminForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) @@ -25192,6 +25426,7 @@ func TestAdminForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusForbidden, rr.Code) @@ -25200,6 +25435,7 @@ func TestAdminForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) @@ -25210,6 +25446,7 @@ func TestAdminForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) @@ -25219,14 +25456,19 @@ func TestAdminForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) form.Set("username", altAdminUsername) req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) @@ -25248,6 +25490,7 @@ func TestAdminForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) @@ -25262,6 +25505,7 @@ func TestAdminForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) @@ -25316,22 +25560,25 @@ func TestUserForgotPassword(t *testing.T) { rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) + loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form := make(url.Values) form.Set("username", "") // no csrf token req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusForbidden, rr.Code) // empty username - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) - assert.NoError(t, err) form.Set(csrfFormToken, csrfToken) req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) @@ -25341,6 +25588,7 @@ func TestUserForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) @@ -25353,11 +25601,12 @@ func TestUserForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) assert.GreaterOrEqual(t, len(lastResetCode), 20) - // no csrf token + // no login token form = make(url.Values) req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) @@ -25372,6 +25621,7 @@ func TestUserForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) @@ -25382,6 +25632,7 @@ func TestUserForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) @@ -25392,6 +25643,7 @@ func TestUserForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) @@ -25401,10 +25653,13 @@ func TestUserForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) form = make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("username", user.Username) @@ -25412,6 +25667,7 @@ func TestUserForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusFound, rr.Code) @@ -25444,6 +25700,7 @@ func TestUserForgotPassword(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = executeRequest(req) assert.Equal(t, http.StatusOK, rr.Code) @@ -25640,7 +25897,7 @@ func TestAPIForgotPassword(t *testing.T) { func TestProviderClosedMock(t *testing.T) { token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webConfigsPath, token) assert.NoError(t, err) // create a role admin role, resp, err := httpdtest.AddRole(getTestRole(), http.StatusCreated) @@ -25849,7 +26106,7 @@ func TestWebConnectionsMock(t *testing.T) { checkResponseCode(t, http.StatusForbidden, rr) assert.Contains(t, rr.Body.String(), "Invalid token") - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, token) assert.NoError(t, err) req, _ = http.NewRequest(http.MethodDelete, path.Join(webConnectionsPath, "id"), nil) setJWTCookieForReq(req, token) @@ -26124,29 +26381,53 @@ func getUserAsJSON(t *testing.T, user dataprovider.User) []byte { return json } -func getCSRFTokenMock(loginURLPath, remoteAddr string) (string, error) { - req, err := http.NewRequest(http.MethodGet, loginURLPath, nil) +func getCSRFTokenFromInternalPageMock(urlPath, token string) (string, error) { + req, err := http.NewRequest(http.MethodGet, urlPath, nil) if err != nil { return "", err } + req.RequestURI = urlPath + setJWTCookieForReq(req, token) + rr := executeRequest(req) + if rr.Code != http.StatusOK { + return "", fmt.Errorf("unexpected status code: %d", rr.Code) + } + return getCSRFTokenFromBody(rr.Body) +} + +func getCSRFTokenMock(loginURLPath, remoteAddr string) (string, string, error) { + req, err := http.NewRequest(http.MethodGet, loginURLPath, nil) + if err != nil { + return "", "", err + } req.RemoteAddr = remoteAddr rr := executeRequest(req) - return getCSRFTokenFromBody(bytes.NewBuffer(rr.Body.Bytes())) + cookie := rr.Header().Get("Set-Cookie") + if cookie == "" { + return "", "", errors.New("unable to get login cookie") + } + token, err := getCSRFTokenFromBody(bytes.NewBuffer(rr.Body.Bytes())) + return cookie, token, err } -func getCSRFToken(url string) (string, error) { +func getCSRFToken(url string) (string, string, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - return "", err + return "", "", err } resp, err := httpclient.GetHTTPClient().Do(req) if err != nil { - return "", err + return "", "", err + } + cookie := resp.Header.Get("Set-Cookie") + if cookie == "" { + return "", "", errors.New("no login cookie") } defer resp.Body.Close() - return getCSRFTokenFromBody(resp.Body) + token, err := getCSRFTokenFromBody(resp.Body) + return cookie, token, err } func getCSRFTokenFromBody(body io.Reader) (string, error) { @@ -26182,6 +26463,10 @@ func getCSRFTokenFromBody(body io.Reader) (string, error) { f(doc) + if csrfToken == "" { + return "", errors.New("CSRF token not found") + } + return csrfToken, nil } @@ -26208,6 +26493,10 @@ func setAPIKeyForReq(req *http.Request, apiKey, username string) { req.Header.Set("X-SFTPGO-API-KEY", apiKey) } +func setLoginCookie(req *http.Request, cookie string) { + req.Header.Set("Cookie", cookie) +} + func setJWTCookieForReq(req *http.Request, jwtToken string) { req.RemoteAddr = defaultRemoteAddr req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", jwtToken)) @@ -26251,13 +26540,14 @@ func getJWTAPIUserTokenFromTestServer(username, password string) (string, error) } func getJWTWebToken(username, password string) (string, error) { - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + loginCookie, csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) if err != nil { return "", err } form := getLoginForm(username, password, csrfToken) req, _ := http.NewRequest(http.MethodPost, httpBaseURL+webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") client := &http.Client{ Timeout: 10 * time.Second, @@ -26290,13 +26580,14 @@ func getCookieFromResponse(rr *httptest.ResponseRecorder) (string, error) { } func getJWTWebClientTokenFromTestServerWithAddr(username, password, remoteAddr string) (string, error) { - csrfToken, err := getCSRFTokenMock(webClientLoginPath, remoteAddr) + loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, remoteAddr) if err != nil { return "", err } form := getLoginForm(username, password, csrfToken) req, _ := http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = remoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr := executeRequest(req) if rr.Code != http.StatusFound { @@ -26306,13 +26597,14 @@ func getJWTWebClientTokenFromTestServerWithAddr(username, password, remoteAddr s } func getJWTWebClientTokenFromTestServer(username, password string) (string, error) { - csrfToken, err := getCSRFToken(httpBaseURL + webClientLoginPath) + loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) if err != nil { return "", err } form := getLoginForm(username, password, csrfToken) req, _ := http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Cookie", loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr := executeRequest(req) if rr.Code != http.StatusFound { @@ -26322,13 +26614,14 @@ func getJWTWebClientTokenFromTestServer(username, password string) (string, erro } func getJWTWebTokenFromTestServer(username, password string) (string, error) { - csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath) + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) if err != nil { return "", err } form := getLoginForm(username, password, csrfToken) req, _ := http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr := executeRequest(req) if rr.Code != http.StatusFound { diff --git a/internal/httpd/internal_test.go b/internal/httpd/internal_test.go index ea48eb9c..18988730 100644 --- a/internal/httpd/internal_test.go +++ b/internal/httpd/internal_test.go @@ -412,6 +412,35 @@ func TestGCSWebInvalidFormFile(t *testing.T) { assert.EqualError(t, err, http.ErrNotMultipart.Error()) } +func TestVerifyCSRFToken(t *testing.T) { + server := httpdServer{} + server.initializeRouter() + req, err := http.NewRequest(http.MethodPost, webAdminEventActionPath, nil) + require.NoError(t, err) + req = req.WithContext(context.WithValue(req.Context(), jwtauth.ErrorCtxKey, fs.ErrPermission)) + + rr := httptest.NewRecorder() + tokenString := createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseAdminPath) + assert.NotEmpty(t, tokenString) + + token, err := server.csrfTokenAuth.Decode(tokenString) + require.NoError(t, err) + _, ok := token.Get(claimRef) + assert.False(t, ok) + + req.Form = url.Values{} + req.Form.Set(csrfFormToken, tokenString) + err = verifyCSRFToken(req, server.csrfTokenAuth) + assert.ErrorIs(t, err, fs.ErrPermission) + + req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, nil) + require.NoError(t, err) + req.Form = url.Values{} + req.Form.Set(csrfFormToken, tokenString) + err = verifyCSRFToken(req, server.csrfTokenAuth) + assert.ErrorContains(t, err, "the form token is not valid") +} + func TestInvalidToken(t *testing.T) { server := httpdServer{} server.initializeRouter() @@ -923,13 +952,24 @@ func TestUpdateWebAdminInvalidClaims(t *testing.T) { token, err := c.createTokenResponse(server.tokenAuth, tokenAudienceWebAdmin, "") assert.NoError(t, err) + req, err := http.NewRequest(http.MethodGet, webAdminPath, nil) + assert.NoError(t, err) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + parsedToken, err := jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie) + assert.NoError(t, err) + ctx := req.Context() + ctx = jwtauth.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + form := make(url.Values) - form.Set(csrfFormToken, createCSRFToken("")) + form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseAdminPath)) form.Set("status", "1") form.Set("default_users_expiration", "30") - req, _ := http.NewRequest(http.MethodPost, path.Join(webAdminPath, "admin"), bytes.NewBuffer([]byte(form.Encode()))) + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminPath, "admin"), bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) rctx := chi.NewRouteContext() rctx.URLParams.Add("username", "admin") + req = req.WithContext(ctx) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) @@ -1028,7 +1068,7 @@ func TestOAuth2Redirect(t *testing.T) { assert.Contains(t, rr.Body.String(), util.I18nOAuth2ErrorTitle) ip := "127.1.1.4" - tokenString := createOAuth2Token(xid.New().String(), ip) + tokenString := createOAuth2Token(server.csrfTokenAuth, xid.New().String(), ip) rr = httptest.NewRecorder() req, err = http.NewRequest(http.MethodGet, webOAuth2RedirectPath+"?state="+tokenString, nil) //nolint:goconst assert.NoError(t, err) @@ -1039,8 +1079,10 @@ func TestOAuth2Redirect(t *testing.T) { } func TestOAuth2Token(t *testing.T) { + server := httpdServer{} + server.initializeRouter() // invalid token - _, err := verifyOAuth2Token("token", "") + _, err := verifyOAuth2Token(server.csrfTokenAuth, "token", "") if assert.Error(t, err) { assert.Contains(t, err.Error(), "unable to verify OAuth2 state") } @@ -1053,22 +1095,22 @@ func TestOAuth2Token(t *testing.T) { claims[jwt.ExpirationKey] = now.Add(tokenDuration) claims[jwt.AudienceKey] = []string{tokenAudienceAPI} - _, tokenString, err := csrfTokenAuth.Encode(claims) + _, tokenString, err := server.csrfTokenAuth.Encode(claims) assert.NoError(t, err) - _, err = verifyOAuth2Token(tokenString, "") + _, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "") if assert.Error(t, err) { assert.Contains(t, err.Error(), "invalid OAuth2 state") } // bad IP - tokenString = createOAuth2Token("state", "127.1.1.1") - _, err = verifyOAuth2Token(tokenString, "127.1.1.2") + tokenString = createOAuth2Token(server.csrfTokenAuth, "state", "127.1.1.1") + _, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "127.1.1.2") if assert.Error(t, err) { assert.Contains(t, err.Error(), "invalid OAuth2 state") } // ok state := xid.New().String() - tokenString = createOAuth2Token(state, "127.1.1.3") - s, err := verifyOAuth2Token(tokenString, "127.1.1.3") + tokenString = createOAuth2Token(server.csrfTokenAuth, state, "127.1.1.3") + s, err := verifyOAuth2Token(server.csrfTokenAuth, tokenString, "127.1.1.3") assert.NoError(t, err) assert.Equal(t, state, s) // no jti @@ -1077,19 +1119,17 @@ func TestOAuth2Token(t *testing.T) { claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second) claims[jwt.ExpirationKey] = now.Add(tokenDuration) claims[jwt.AudienceKey] = []string{tokenAudienceOAuth2, "127.1.1.4"} - _, tokenString, err = csrfTokenAuth.Encode(claims) + _, tokenString, err = server.csrfTokenAuth.Encode(claims) assert.NoError(t, err) - _, err = verifyOAuth2Token(tokenString, "127.1.1.4") + _, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "127.1.1.4") if assert.Error(t, err) { assert.Contains(t, err.Error(), "invalid OAuth2 state") } // encode error - csrfTokenAuth = jwtauth.New("HT256", util.GenerateRandomBytes(32), nil) - tokenString = createOAuth2Token(xid.New().String(), "") + server.csrfTokenAuth = jwtauth.New("HT256", util.GenerateRandomBytes(32), nil) + tokenString = createOAuth2Token(server.csrfTokenAuth, xid.New().String(), "") assert.Empty(t, tokenString) - server := httpdServer{} - server.initializeRouter() rr := httptest.NewRecorder() testReq := make(map[string]any) testReq["base_redirect_url"] = "http://localhost:8082" @@ -1097,16 +1137,17 @@ func TestOAuth2Token(t *testing.T) { assert.NoError(t, err) req, err := http.NewRequest(http.MethodPost, webOAuth2TokenPath, bytes.NewBuffer(asJSON)) assert.NoError(t, err) - handleSMTPOAuth2TokenRequestPost(rr, req) + server.handleSMTPOAuth2TokenRequestPost(rr, req) assert.Equal(t, http.StatusInternalServerError, rr.Code) assert.Contains(t, rr.Body.String(), "unable to create state token") - - csrfTokenAuth = jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil) } func TestCSRFToken(t *testing.T) { + server := httpdServer{} + server.initializeRouter() // invalid token - err := verifyCSRFToken("token", "") + req := &http.Request{} + err := verifyCSRFToken(req, server.csrfTokenAuth) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unable to verify form token") } @@ -1119,16 +1160,23 @@ func TestCSRFToken(t *testing.T) { claims[jwt.ExpirationKey] = now.Add(tokenDuration) claims[jwt.AudienceKey] = []string{tokenAudienceAPI} - _, tokenString, err := csrfTokenAuth.Encode(claims) + _, tokenString, err := server.csrfTokenAuth.Encode(claims) assert.NoError(t, err) - err = verifyCSRFToken(tokenString, "") + values := url.Values{} + values.Set(csrfFormToken, tokenString) + req.Form = values + err = verifyCSRFToken(req, server.csrfTokenAuth) if assert.Error(t, err) { assert.Contains(t, err.Error(), "form token is not valid") } // bad IP - tokenString = createCSRFToken("127.1.1.1") - err = verifyCSRFToken(tokenString, "127.1.1.2") + req.RemoteAddr = "127.1.1.1" + tokenString = createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath) + values.Set(csrfFormToken, tokenString) + req.Form = values + req.RemoteAddr = "127.1.1.2" + err = verifyCSRFToken(req, server.csrfTokenAuth) if assert.Error(t, err) { assert.Contains(t, err.Error(), "form token is not valid") } @@ -1137,8 +1185,9 @@ func TestCSRFToken(t *testing.T) { claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second) claims[jwt.ExpirationKey] = now.Add(tokenDuration) claims[jwt.AudienceKey] = []string{tokenAudienceAPI} - _, tokenString, err = csrfTokenAuth.Encode(claims) + _, tokenString, err = server.csrfTokenAuth.Encode(claims) assert.NoError(t, err) + assert.NotEmpty(t, tokenString) r := GetHTTPRouter(Binding{ Address: "", @@ -1148,9 +1197,9 @@ func TestCSRFToken(t *testing.T) { EnableRESTAPI: true, RenderOpenAPI: true, }) - fn := verifyCSRFHeader(r) + fn := server.verifyCSRFHeader(r) rr := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, path.Join(userPath, "username"), nil) + req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, "username"), nil) fn.ServeHTTP(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), "Invalid token") @@ -1163,18 +1212,20 @@ func TestCSRFToken(t *testing.T) { assert.Contains(t, rr.Body.String(), "the token is not valid") // invalid IP - tokenString = createCSRFToken("172.16.1.2") + tokenString = createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath) req.Header.Set(csrfHeaderToken, tokenString) + req.RemoteAddr = "172.16.1.2" rr = httptest.NewRecorder() fn.ServeHTTP(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), "the token is not valid") - csrfTokenAuth = jwtauth.New("PS256", util.GenerateRandomBytes(32), nil) - tokenString = createCSRFToken("") + csrfTokenAuth := jwtauth.New("PS256", util.GenerateRandomBytes(32), nil) + tokenString = createCSRFToken(httptest.NewRecorder(), req, csrfTokenAuth, "", webBaseAdminPath) assert.Empty(t, tokenString) - - csrfTokenAuth = jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil) + rr = httptest.NewRecorder() + createLoginCookie(rr, req, csrfTokenAuth, "", webBaseAdminPath, req.RemoteAddr) + assert.Empty(t, rr.Header().Get("Set-Cookie")) } func TestCreateShareCookieError(t *testing.T) { @@ -1205,19 +1256,38 @@ func TestCreateShareCookieError(t *testing.T) { assert.NoError(t, err) server := httpdServer{ - tokenAuth: jwtauth.New("TS256", util.GenerateRandomBytes(32), nil), + tokenAuth: jwtauth.New("TS256", util.GenerateRandomBytes(32), nil), + csrfTokenAuth: jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil), } + + c := jwtTokenClaims{ + JwtID: xid.New().String(), + } + resp, err := c.createTokenResponse(server.csrfTokenAuth, tokenAudienceWebLogin, "127.0.0.1") + assert.NoError(t, err) + parsedToken, err := jwtauth.VerifyToken(server.csrfTokenAuth, resp["access_token"].(string)) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, share.ShareID, "login"), nil) + assert.NoError(t, err) + req.RemoteAddr = "127.0.0.1:4567" + ctx := req.Context() + ctx = jwtauth.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + form := make(url.Values) form.Set("share_password", pwd) - form.Set(csrfFormToken, createCSRFToken("127.0.0.1")) + form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseClientPath)) rctx := chi.NewRouteContext() rctx.URLParams.Add("id", share.ShareID) rr := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, share.ShareID, "login"), + req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, share.ShareID, "login"), bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = "127.0.0.1:2345" + req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", resp["access_token"])) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req = req.WithContext(ctx) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) server.handleClientShareLoginPost(rr, req) assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) @@ -1229,7 +1299,8 @@ func TestCreateShareCookieError(t *testing.T) { func TestCreateTokenError(t *testing.T) { server := httpdServer{ - tokenAuth: jwtauth.New("PS256", util.GenerateRandomBytes(32), nil), + tokenAuth: jwtauth.New("PS256", util.GenerateRandomBytes(32), nil), + csrfTokenAuth: jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil), } rr := httptest.NewRecorder() admin := dataprovider.Admin{ @@ -1253,14 +1324,36 @@ func TestCreateTokenError(t *testing.T) { server.generateAndSendUserToken(rr, req, "", user) assert.Equal(t, http.StatusInternalServerError, rr.Code) + c := jwtTokenClaims{ + JwtID: xid.New().String(), + } + token, err := c.createTokenResponse(server.csrfTokenAuth, tokenAudienceWebLogin, "") + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) + assert.NoError(t, err) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + parsedToken, err := jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie) + assert.NoError(t, err) + ctx := req.Context() + ctx = jwtauth.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + rr = httptest.NewRecorder() form := make(url.Values) form.Set("username", admin.Username) form.Set("password", admin.Password) - form.Set(csrfFormToken, createCSRFToken("127.0.0.1")) + form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, xid.New().String(), webBaseAdminPath)) + cookie := rr.Header().Get("Set-Cookie") + assert.NotEmpty(t, cookie) req, _ = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) - req.RemoteAddr = "127.0.0.1:1234" + req.Header.Set("Cookie", cookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie) + assert.NoError(t, err) + ctx = req.Context() + ctx = jwtauth.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) server.handleWebAdminLoginPost(rr, req) assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) // req with no content type @@ -1287,7 +1380,7 @@ func TestCreateTokenError(t *testing.T) { req, _ = http.NewRequest(http.MethodGet, webAdminLoginPath+"?a=a%C3%A2%G3", nil) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - _, err := getAdminFromPostFields(req) + _, err = getAdminFromPostFields(req) assert.Error(t, err) req, _ = http.NewRequest(http.MethodPost, webAdminEventActionPath+"?a=a%C3%A2%GG", nil) @@ -1421,13 +1514,21 @@ func TestCreateTokenError(t *testing.T) { err = dataprovider.AddUser(&user, "", "", "") assert.NoError(t, err) + req, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil) + assert.NoError(t, err) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie) + assert.NoError(t, err) + ctx = req.Context() + ctx = jwtauth.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + rr = httptest.NewRecorder() form = make(url.Values) form.Set("username", user.Username) form.Set("password", "clientpwd") - form.Set(csrfFormToken, createCSRFToken("127.0.0.1")) + form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseClientPath)) req, _ = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) - req.RemoteAddr = "127.0.0.1:4567" req.Header.Set("Content-Type", "application/x-www-form-urlencoded") server.handleWebClientLoginPost(rr, req) assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) @@ -1616,6 +1717,7 @@ func TestCookieExpiration(t *testing.T) { claims = make(map[string]any) claims[claimUsernameKey] = admin.Username claims[claimPermissionsKey] = admin.Permissions + claims[jwt.JwtIDKey] = xid.New().String() claims[jwt.SubjectKey] = admin.GetSignature() claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute) claims[jwt.AudienceKey] = []string{tokenAudienceAPI} @@ -1648,9 +1750,11 @@ func TestCookieExpiration(t *testing.T) { admin, err = dataprovider.AdminExists(admin.Username) assert.NoError(t, err) + tokenID := xid.New().String() claims = make(map[string]any) claims[claimUsernameKey] = admin.Username claims[claimPermissionsKey] = admin.Permissions + claims[jwt.JwtIDKey] = tokenID claims[jwt.SubjectKey] = admin.GetSignature() claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute) claims[jwt.AudienceKey] = []string{tokenAudienceAPI} @@ -1669,6 +1773,11 @@ func TestCookieExpiration(t *testing.T) { server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.True(t, strings.HasPrefix(cookie, "jwt=")) + req.Header.Set("Cookie", cookie) + token, err = jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie) + if assert.NoError(t, err) { + assert.Equal(t, tokenID, token.JwtID()) + } err = dataprovider.DeleteAdmin(admin.Username, "", "", "") assert.NoError(t, err) @@ -1689,6 +1798,7 @@ func TestCookieExpiration(t *testing.T) { claims = make(map[string]any) claims[claimUsernameKey] = user.Username claims[claimPermissionsKey] = user.Filters.WebClient + claims[jwt.JwtIDKey] = tokenID claims[jwt.SubjectKey] = user.GetSignature() claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute) claims[jwt.AudienceKey] = []string{tokenAudienceWebClient} @@ -1721,6 +1831,7 @@ func TestCookieExpiration(t *testing.T) { claims = make(map[string]any) claims[claimUsernameKey] = user.Username claims[claimPermissionsKey] = user.Filters.WebClient + claims[jwt.JwtIDKey] = tokenID claims[jwt.SubjectKey] = user.GetSignature() claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute) claims[jwt.AudienceKey] = []string{tokenAudienceWebClient} @@ -1740,6 +1851,35 @@ func TestCookieExpiration(t *testing.T) { server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.NotEmpty(t, cookie) + req.Header.Set("Cookie", cookie) + token, err = jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie) + if assert.NoError(t, err) { + assert.Equal(t, tokenID, token.JwtID()) + } + + // test a disabled user + user.Status = 0 + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + user, err = dataprovider.UserExists(user.Username, "") + assert.NoError(t, err) + + claims = make(map[string]any) + claims[claimUsernameKey] = user.Username + claims[claimPermissionsKey] = user.Filters.WebClient + claims[jwt.JwtIDKey] = tokenID + claims[jwt.SubjectKey] = user.GetSignature() + claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute) + claims[jwt.AudienceKey] = []string{tokenAudienceWebClient} + token, _, err = server.tokenAuth.Encode(claims) + assert.NoError(t, err) + + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + ctx = jwtauth.NewContext(req.Context(), token, nil) + server.checkCookieExpiration(rr, req.WithContext(ctx)) + cookie = rr.Header().Get("Set-Cookie") + assert.Empty(t, cookie) err = dataprovider.DeleteUser(user.Username, "", "", "") assert.NoError(t, err) @@ -2104,34 +2244,95 @@ func TestProxyHeaders(t *testing.T) { testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) + req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) + assert.NoError(t, err) + req.RemoteAddr = testIP + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + cookie := rr.Header().Get("Set-Cookie") + assert.NotEmpty(t, cookie) + req.Header.Set("Cookie", cookie) + parsedToken, err := jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie) + assert.NoError(t, err) + ctx := req.Context() + ctx = jwtauth.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + form := make(url.Values) form.Set("username", username) form.Set("password", password) - form.Set(csrfFormToken, createCSRFToken(testIP)) + form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)) req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = testIP + req.Header.Set("Cookie", cookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) - form.Set(csrfFormToken, createCSRFToken(validForwardedFor)) + req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) + assert.NoError(t, err) + req.RemoteAddr = validForwardedFor + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + loginCookie := rr.Header().Get("Set-Cookie") + assert.NotEmpty(t, loginCookie) + req.Header.Set("Cookie", loginCookie) + parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie) + assert.NoError(t, err) + ctx = req.Context() + ctx = jwtauth.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + + form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)) req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = testIP + req.Header.Set("Cookie", loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("X-Forwarded-For", validForwardedFor) rr = httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String()) - cookie := rr.Header().Get("Set-Cookie") + cookie = rr.Header().Get("Set-Cookie") assert.NotContains(t, cookie, "Secure") + // The login cookie is invalidated after a successful login, the same request will fail req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = testIP + req.Header.Set("Cookie", loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("X-Forwarded-For", validForwardedFor) + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) + assert.NoError(t, err) + req.RemoteAddr = validForwardedFor + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + loginCookie = rr.Header().Get("Set-Cookie") + assert.NotEmpty(t, loginCookie) + req.Header.Set("Cookie", loginCookie) + parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie) + assert.NoError(t, err) + ctx = req.Context() + ctx = jwtauth.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + + form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)) + req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = testIP + req.Header.Set("Cookie", loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("X-Forwarded-For", validForwardedFor) req.Header.Set(xForwardedProto, "https") @@ -2141,9 +2342,26 @@ func TestProxyHeaders(t *testing.T) { cookie = rr.Header().Get("Set-Cookie") assert.Contains(t, cookie, "Secure") + req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) + assert.NoError(t, err) + req.RemoteAddr = validForwardedFor + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + loginCookie = rr.Header().Get("Set-Cookie") + assert.NotEmpty(t, loginCookie) + req.Header.Set("Cookie", loginCookie) + parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie) + assert.NoError(t, err) + ctx = req.Context() + ctx = jwtauth.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + + form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)) req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = testIP + req.Header.Set("Cookie", loginCookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("X-Forwarded-For", validForwardedFor) req.Header.Set(xForwardedProto, "http") @@ -2715,10 +2933,22 @@ func TestInvalidClaims(t *testing.T) { } token, err := c.createTokenResponse(server.tokenAuth, tokenAudienceWebClient, "") assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, webClientProfilePath, nil) + assert.NoError(t, err) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + parsedToken, err := jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie) + assert.NoError(t, err) + ctx := req.Context() + ctx = jwtauth.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + form := make(url.Values) - form.Set(csrfFormToken, createCSRFToken("")) + form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseClientPath)) form.Set("public_keys", "") - req, _ := http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + req, err = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req = req.WithContext(ctx) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) server.handleWebClientProfilePost(rr, req) @@ -2735,14 +2965,27 @@ func TestInvalidClaims(t *testing.T) { } token, err = c.createTokenResponse(server.tokenAuth, tokenAudienceWebAdmin, "") assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, webAdminProfilePath, nil) + assert.NoError(t, err) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + parsedToken, err = jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie) + assert.NoError(t, err) + ctx = req.Context() + ctx = jwtauth.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + form = make(url.Values) - form.Set(csrfFormToken, createCSRFToken("")) + form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseAdminPath)) form.Set("allow_api_key_auth", "") - req, _ = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + req, err = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req = req.WithContext(ctx) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) server.handleWebAdminProfilePost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) } func TestTLSReq(t *testing.T) { @@ -3041,24 +3284,31 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) { } server.initializeRouter() - rr := httptest.NewRecorder() - r, err := http.NewRequest(http.MethodGet, webAdminSetupPath, nil) - assert.NoError(t, err) - server.router.ServeHTTP(rr, r) - assert.Equal(t, http.StatusOK, rr.Code) - for _, webURL := range []string{"/", webBasePath, webBaseAdminPath, webAdminLoginPath, webClientLoginPath} { - rr = httptest.NewRecorder() - r, err = http.NewRequest(http.MethodGet, webURL, nil) + rr := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, webURL, nil) assert.NoError(t, err) server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) } + rr := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, webAdminSetupPath, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusOK, rr.Code) + cookie := rr.Header().Get("Set-Cookie") + r.Header.Set("Cookie", cookie) + parsedToken, err := jwtauth.VerifyRequest(server.csrfTokenAuth, r, jwtauth.TokenFromCookie) + assert.NoError(t, err) + ctx := r.Context() + ctx = jwtauth.NewContext(ctx, parsedToken, err) + r = r.WithContext(ctx) + form := make(url.Values) - csrfToken := createCSRFToken("") - form.Set("_form_token", csrfToken) + csrfToken := createCSRFToken(rr, r, server.csrfTokenAuth, "", webBaseAdminPath) + form.Set(csrfFormToken, csrfToken) form.Set("install_code", installationCode+"5") form.Set("username", defaultAdminUsername) form.Set("password", "password") @@ -3066,6 +3316,8 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) { rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) + r = r.WithContext(ctx) + r.Header.Set("Cookie", cookie) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) @@ -3077,6 +3329,8 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) { rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) + r = r.WithContext(ctx) + r.Header.Set("Cookie", cookie) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) @@ -3098,12 +3352,6 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) { return "5678" }) - rr = httptest.NewRecorder() - r, err = http.NewRequest(http.MethodGet, webAdminSetupPath, nil) - assert.NoError(t, err) - server.router.ServeHTTP(rr, r) - assert.Equal(t, http.StatusOK, rr.Code) - for _, webURL := range []string{"/", webBasePath, webBaseAdminPath, webAdminLoginPath, webClientLoginPath} { rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodGet, webURL, nil) @@ -3113,9 +3361,22 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) { assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) } + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webAdminSetupPath, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusOK, rr.Code) + cookie = rr.Header().Get("Set-Cookie") + r.Header.Set("Cookie", cookie) + parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, r, jwtauth.TokenFromCookie) + assert.NoError(t, err) + ctx = r.Context() + ctx = jwtauth.NewContext(ctx, parsedToken, err) + r = r.WithContext(ctx) + form = make(url.Values) - csrfToken = createCSRFToken("") - form.Set("_form_token", csrfToken) + csrfToken = createCSRFToken(rr, r, server.csrfTokenAuth, "", webBaseAdminPath) + form.Set(csrfFormToken, csrfToken) form.Set("install_code", installationCode) form.Set("username", defaultAdminUsername) form.Set("password", "password") @@ -3123,6 +3384,8 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) { rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) + r = r.WithContext(ctx) + r.Header.Set("Cookie", cookie) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusOK, rr.Code) @@ -3134,6 +3397,8 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) { rr = httptest.NewRecorder() r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) + r = r.WithContext(ctx) + r.Header.Set("Cookie", cookie) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") server.router.ServeHTTP(rr, r) assert.Equal(t, http.StatusFound, rr.Code) @@ -3199,6 +3464,7 @@ func TestDecodeToken(t *testing.T) { claimNodeID: nodeID, claimMustChangePasswordKey: false, claimMustSetSecondFactorKey: true, + claimRef: "ref", } c := jwtTokenClaims{} c.Decode(token) @@ -3206,6 +3472,11 @@ func TestDecodeToken(t *testing.T) { assert.Equal(t, nodeID, c.NodeID) assert.False(t, c.MustChangePassword) assert.True(t, c.MustSetTwoFactorAuth) + assert.Equal(t, "ref", c.Ref) + + asMap := c.asMap() + asMap[claimMustChangePasswordKey] = false + assert.Equal(t, token, asMap) token[claimMustChangePasswordKey] = 10 c = jwtTokenClaims{} diff --git a/internal/httpd/middleware.go b/internal/httpd/middleware.go index 02eb438b..38fa0508 100644 --- a/internal/httpd/middleware.go +++ b/internal/httpd/middleware.go @@ -95,13 +95,11 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi doRedirect("Your token audience is not valid", nil) return errInvalidToken } - if tokenValidationMode != tokenValidationNoIPMatch { - ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if !util.Contains(token.Audience(), ipAddr) { - logger.Debug(logSender, "", "the token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr) - doRedirect("Your token is not valid", nil) - return errInvalidToken - } + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + if err := validateIPForToken(token, ipAddr); err != nil { + logger.Debug(logSender, "", "the token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr) + doRedirect("Your token is not valid", nil) + return err } return nil } @@ -123,10 +121,16 @@ func (s *httpdServer) validateJWTPartialToken(w http.ResponseWriter, r *http.Req return errInvalidToken } if !util.Contains(token.Audience(), audience) { - logger.Debug(logSender, "", "the token is not valid for audience %q", audience) + logger.Debug(logSender, "", "the partial token with id %q is not valid for audience %q", token.JwtID(), audience) notFoundFunc(w, r, nil) return errInvalidToken } + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + if err := validateIPForToken(token, ipAddr); err != nil { + logger.Debug(logSender, "", "the partial token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr) + notFoundFunc(w, r, nil) + return err + } return nil } @@ -324,10 +328,10 @@ func (s *httpdServer) checkPerm(perm string) func(next http.Handler) http.Handle } } -func verifyCSRFHeader(next http.Handler) http.Handler { +func (s *httpdServer) verifyCSRFHeader(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tokenString := r.Header.Get(csrfHeaderToken) - token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString) + token, err := jwtauth.VerifyToken(s.csrfTokenAuth, tokenString) if err != nil || token == nil { logger.Debug(logSender, "", "error validating CSRF header: %v", err) sendAPIResponse(w, r, err, "Invalid token", http.StatusForbidden) @@ -340,12 +344,10 @@ func verifyCSRFHeader(next http.Handler) http.Handler { return } - if tokenValidationMode != tokenValidationNoIPMatch { - if !util.Contains(token.Audience(), util.GetIPFromRemoteAddress(r.RemoteAddr)) { - logger.Debug(logSender, "", "error validating CSRF header IP audience") - sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden) - return - } + if err := validateIPForToken(token, util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil { + logger.Debug(logSender, "", "error validating CSRF header IP audience") + sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden) + return } next.ServeHTTP(w, r) diff --git a/internal/httpd/oidc.go b/internal/httpd/oidc.go index 22ba3c7c..3b127f6e 100644 --- a/internal/httpd/oidc.go +++ b/internal/httpd/oidc.go @@ -541,6 +541,7 @@ func (s *httpdServer) oidcTokenAuthenticator(audience tokenAudience) func(next h return } jwtTokenClaims := jwtTokenClaims{ + JwtID: token.Cookie, Username: token.Username, Permissions: token.Permissions, Role: token.TokenRole, @@ -594,6 +595,7 @@ func (s *httpdServer) handleOIDCRedirect(w http.ResponseWriter, r *http.Request) authReq, err := oidcMgr.getPendingAuth(state) if err != nil { logger.Debug(logSender, "", "oidc authentication state did not match") + oidcMgr.removePendingAuth(state) s.renderClientMessagePage(w, r, util.I18nInvalidAuthReqTitle, http.StatusBadRequest, util.NewI18nError(err, util.I18nInvalidAuth), "") return diff --git a/internal/httpd/oidc_test.go b/internal/httpd/oidc_test.go index be04d2ab..efd9a6b2 100644 --- a/internal/httpd/oidc_test.go +++ b/internal/httpd/oidc_test.go @@ -33,7 +33,6 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/go-chi/jwtauth/v5" - "github.com/lestrrat-go/jwx/v2/jwa" "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/stretchr/testify/assert" @@ -1584,12 +1583,9 @@ func TestOIDCWithLoginFormsDisabled(t *testing.T) { tokenCookie = k } // we should be able to create admins without setting a password - if csrfTokenAuth == nil { - csrfTokenAuth = jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil) - } adminUsername := "testAdmin" form := make(url.Values) - form.Set(csrfFormToken, createCSRFToken("")) + form.Set(csrfFormToken, createCSRFToken(rr, r, server.csrfTokenAuth, tokenCookie, webBaseAdminPath)) form.Set("username", adminUsername) form.Set("password", "") form.Set("status", "1") diff --git a/internal/httpd/server.go b/internal/httpd/server.go index 52f4fe96..e8fe32e2 100644 --- a/internal/httpd/server.go +++ b/internal/httpd/server.go @@ -68,6 +68,7 @@ type httpdServer struct { isShared int router *chi.Mux tokenAuth *jwtauth.JWTAuth + csrfTokenAuth *jwtauth.JWTAuth signingPassphrase string cors CorsConfig } @@ -164,13 +165,13 @@ func (s *httpdServer) refreshCookie(next http.Handler) http.Handler { }) } -func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { +func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := loginPage{ commonBasePage: getCommonBasePage(r), Title: util.I18nLoginTitle, CurrentURL: webClientLoginPath, Error: err, - CSRFToken: createCSRFToken(ip), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath), Branding: s.binding.Branding.WebClient, FormDisabled: s.binding.isWebClientLoginFormDisabled(), CheckRedirect: true, @@ -193,8 +194,7 @@ func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, r *http.Reque func (s *httpdServer) handleWebClientLogout(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) - c := jwtTokenClaims{} - c.removeCookie(w, r, webBaseClientPath) + removeCookie(w, r, webBaseClientPath) s.logoutOIDCUser(w, r) http.Redirect(w, r, webClientLoginPath, http.StatusFound) @@ -206,7 +206,7 @@ func (s *httpdServer) handleWebClientChangePwdPost(w http.ResponseWriter, r *htt s.renderClientChangePasswordPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -226,7 +226,7 @@ func (s *httpdServer) handleClientWebLogin(w http.ResponseWriter, r *http.Reques return } msg := getFlashMessage(w, r) - s.renderClientLoginPage(w, r, msg.getI18nError(), util.GetIPFromRemoteAddress(r.RemoteAddr)) + s.renderClientLoginPage(w, r, msg.getI18nError()) } func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) { @@ -234,7 +234,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := r.ParseForm(); err != nil { - s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) + s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } protocol := common.ProtocolHTTP @@ -244,20 +244,19 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials) s.renderClientLoginPage(w, r, - util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) + util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, ipAddr, err) - s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr) - return + s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) } if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil { updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, ipAddr, err) - s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message), ipAddr) + s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message)) return } @@ -265,13 +264,13 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re if err != nil { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err) s.renderClientLoginPage(w, r, - util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) + util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String()) if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err) - s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message), ipAddr) + s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message)) return } @@ -280,7 +279,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re if err != nil { logger.Warn(logSender, connectionID, "unable to check fs root: %v", err) updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure) - s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorFsGeneric), ipAddr) + s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorFsGeneric)) return } s.loginUser(w, r, &user, connectionID, ipAddr, false, s.renderClientLoginPage) @@ -292,10 +291,10 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) err := r.ParseForm() if err != nil { - s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) + s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -304,12 +303,12 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r _, user, err := handleResetPassword(r, strings.TrimSpace(r.Form.Get("code")), newPassword, confirmPassword, false) if err != nil { - s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric), ipAddr) + s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric)) return } connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String()) if err := checkHTTPClientUser(user, r, connectionID, true); err != nil { - s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorDirList403), ipAddr) + s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorDirList403)) return } @@ -317,7 +316,7 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r err = user.CheckFsRoot(connectionID) if err != nil { logger.Warn(logSender, connectionID, "unable to check fs root: %v", err) - s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorLoginAfterReset), ipAddr) + s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorLoginAfterReset)) return } s.loginUser(w, r, user, connectionID, ipAddr, false, s.renderClientResetPwdPage) @@ -332,18 +331,18 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := r.ParseForm(); err != nil { - s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) + s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } username := claims.Username recoveryCode := strings.TrimSpace(r.Form.Get("recovery_code")) if username == "" || recoveryCode == "" { s.renderClientTwoFactorRecoveryPage(w, r, - util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) + util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { - s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr) + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { + s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } user, userMerged, err := dataprovider.GetUserVariants(username, "") @@ -352,12 +351,12 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck } s.renderClientTwoFactorRecoveryPage(w, r, - util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) + util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } if !userMerged.Filters.TOTPConfig.Enabled || !util.Contains(userMerged.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) { s.renderClientTwoFactorPage(w, r, util.NewI18nError( - util.NewValidationError("two factory authentication is not enabled"), util.I18n2FADisabled), ipAddr) + util.NewValidationError("two factory authentication is not enabled"), util.I18n2FADisabled)) return } for idx, code := range user.Filters.RecoveryCodes { @@ -368,7 +367,7 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter if code.Secret.GetPayload() == recoveryCode { if code.Used { s.renderClientTwoFactorRecoveryPage(w, r, - util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) + util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } user.Filters.RecoveryCodes[idx].Used = true @@ -386,7 +385,7 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter } handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck s.renderClientTwoFactorRecoveryPage(w, r, - util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) + util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) } func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *http.Request) { @@ -398,7 +397,7 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := r.ParseForm(); err != nil { - s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) + s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } username := claims.Username @@ -407,25 +406,25 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials) s.renderClientTwoFactorPage(w, r, - util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) + util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, ipAddr, err) - s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr) + s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } user, err := dataprovider.GetUserWithGroupSettings(username, "") if err != nil { updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, dataprovider.LoginMethodPassword, ipAddr, err) - s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials), ipAddr) + s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials)) return } if !user.Filters.TOTPConfig.Enabled || !util.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure) - s.renderClientTwoFactorPage(w, r, util.NewI18nError(common.ErrInternalFailure, util.I18n2FADisabled), ipAddr) + s.renderClientTwoFactorPage(w, r, util.NewI18nError(common.ErrInternalFailure, util.I18n2FADisabled)) return } err = user.Filters.TOTPConfig.Secret.Decrypt() @@ -439,7 +438,7 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt if !match || err != nil { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, dataprovider.ErrInvalidCredentials) s.renderClientTwoFactorPage(w, r, - util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) + util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } connectionID := fmt.Sprintf("%s_%s", getProtocolFromRequest(r), xid.New().String()) @@ -456,18 +455,17 @@ func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter, } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := r.ParseForm(); err != nil { - s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) + s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } username := claims.Username recoveryCode := strings.TrimSpace(r.Form.Get("recovery_code")) if username == "" || recoveryCode == "" { - s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), - ipAddr) + s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { - s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr) + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { + s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } admin, err := dataprovider.AdminExists(username) @@ -475,12 +473,11 @@ func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter, if errors.Is(err, util.ErrNotFound) { handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck } - s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), - ipAddr) + s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } if !admin.Filters.TOTPConfig.Enabled { - s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(util.NewValidationError("two factory authentication is not enabled"), util.I18n2FADisabled), ipAddr) + s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(util.NewValidationError("two factory authentication is not enabled"), util.I18n2FADisabled)) return } for idx, code := range admin.Filters.RecoveryCodes { @@ -491,7 +488,7 @@ func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter, if code.Secret.GetPayload() == recoveryCode { if code.Used { s.renderTwoFactorRecoveryPage(w, r, - util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), ipAddr) + util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } admin.Filters.RecoveryCodes[idx].Used = true @@ -506,8 +503,7 @@ func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter, } } handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck - s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), - ipAddr) + s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) } func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http.Request) { @@ -519,19 +515,18 @@ func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := r.ParseForm(); err != nil { - s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) + s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } username := claims.Username passcode := strings.TrimSpace(r.Form.Get("passcode")) if username == "" || passcode == "" { - s.renderTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), - ipAddr) + s.renderTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { err = handleDefenderEventLoginFailed(ipAddr, err) - s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr) + s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } admin, err := dataprovider.AdminExists(username) @@ -539,11 +534,11 @@ func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http if errors.Is(err, util.ErrNotFound) { handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck } - s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials), ipAddr) + s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials)) return } if !admin.Filters.TOTPConfig.Enabled { - s.renderTwoFactorPage(w, r, util.NewI18nError(common.ErrInternalFailure, util.I18n2FADisabled), ipAddr) + s.renderTwoFactorPage(w, r, util.NewI18nError(common.ErrInternalFailure, util.I18n2FADisabled)) return } err = admin.Filters.TOTPConfig.Secret.Decrypt() @@ -555,8 +550,7 @@ func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http admin.Filters.TOTPConfig.Secret.GetPayload()) if !match || err != nil { handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck - s.renderTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), - ipAddr) + s.renderTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } s.loginAdmin(w, r, &admin, true, s.renderTwoFactorPage, ipAddr) @@ -567,37 +561,35 @@ func (s *httpdServer) handleWebAdminLoginPost(w http.ResponseWriter, r *http.Req ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := r.ParseForm(); err != nil { - s.renderAdminLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) + s.renderAdminLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } username := strings.TrimSpace(r.Form.Get("username")) password := strings.TrimSpace(r.Form.Get("password")) if username == "" || password == "" { - s.renderAdminLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), - ipAddr) + s.renderAdminLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { - s.renderAdminLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr) + if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { + s.renderAdminLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } admin, err := dataprovider.CheckAdminAndPass(username, password, ipAddr) if err != nil { handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck - s.renderAdminLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), - ipAddr) + s.renderAdminLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } s.loginAdmin(w, r, &admin, false, s.renderAdminLoginPage, ipAddr) } -func (s *httpdServer) renderAdminLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { +func (s *httpdServer) renderAdminLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := loginPage{ commonBasePage: getCommonBasePage(r), Title: util.I18nLoginTitle, CurrentURL: webAdminLoginPath, Error: err, - CSRFToken: createCSRFToken(ip), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath), Branding: s.binding.Branding.WebAdmin, FormDisabled: s.binding.isWebAdminLoginFormDisabled(), CheckRedirect: false, @@ -622,13 +614,12 @@ func (s *httpdServer) handleWebAdminLogin(w http.ResponseWriter, r *http.Request return } msg := getFlashMessage(w, r) - s.renderAdminLoginPage(w, r, msg.getI18nError(), util.GetIPFromRemoteAddress(r.RemoteAddr)) + s.renderAdminLoginPage(w, r, msg.getI18nError()) } func (s *httpdServer) handleWebAdminLogout(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - c := jwtTokenClaims{} - c.removeCookie(w, r, webBaseAdminPath) + removeCookie(w, r, webBaseAdminPath) s.logoutOIDCUser(w, r) http.Redirect(w, r, webAdminLoginPath, http.StatusFound) @@ -641,7 +632,7 @@ func (s *httpdServer) handleWebAdminChangePwdPost(w http.ResponseWriter, r *http s.renderChangePasswordPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -660,10 +651,10 @@ func (s *httpdServer) handleWebAdminPasswordResetPost(w http.ResponseWriter, r * ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) err := r.ParseForm() if err != nil { - s.renderResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) + s.renderResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -672,7 +663,7 @@ func (s *httpdServer) handleWebAdminPasswordResetPost(w http.ResponseWriter, r * admin, _, err := handleResetPassword(r, strings.TrimSpace(r.Form.Get("code")), newPassword, confirmPassword, true) if err != nil { - s.renderResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric), ipAddr) + s.renderResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric)) return } @@ -688,10 +679,10 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) err := r.ParseForm() if err != nil { - s.renderAdminSetupPage(w, r, "", ipAddr, util.NewI18nError(err, util.I18nErrorInvalidForm)) + s.renderAdminSetupPage(w, r, "", util.NewI18nError(err, util.I18nErrorInvalidForm)) return } - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -700,7 +691,7 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req confirmPassword := strings.TrimSpace(r.Form.Get("confirm_password")) installCode := strings.TrimSpace(r.Form.Get("install_code")) if installationCode != "" && installCode != resolveInstallationCode() { - s.renderAdminSetupPage(w, r, username, ipAddr, + s.renderAdminSetupPage(w, r, username, util.NewI18nError( util.NewValidationError(fmt.Sprintf("%v mismatch", installationCodeHint)), util.I18nErrorSetupInstallCode), @@ -708,17 +699,17 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req return } if username == "" { - s.renderAdminSetupPage(w, r, username, ipAddr, + s.renderAdminSetupPage(w, r, username, util.NewI18nError(util.NewValidationError("please set a username"), util.I18nError500Message)) return } if password == "" { - s.renderAdminSetupPage(w, r, username, ipAddr, + s.renderAdminSetupPage(w, r, username, util.NewI18nError(util.NewValidationError("please set a password"), util.I18nError500Message)) return } if password != confirmPassword { - s.renderAdminSetupPage(w, r, username, ipAddr, + s.renderAdminSetupPage(w, r, username, util.NewI18nError(errors.New("the two password fields do not match"), util.I18nErrorChangePwdNoMatch)) return } @@ -730,7 +721,7 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req } err = dataprovider.AddAdmin(&admin, username, ipAddr, "") if err != nil { - s.renderAdminSetupPage(w, r, username, ipAddr, util.NewI18nError(err, util.I18nError500Message)) + s.renderAdminSetupPage(w, r, username, util.NewI18nError(err, util.I18nError500Message)) return } s.loginAdmin(w, r, &admin, false, nil, ipAddr) @@ -738,7 +729,7 @@ func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Req func (s *httpdServer) loginUser( w http.ResponseWriter, r *http.Request, user *dataprovider.User, connectionID, ipAddr string, - isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string), + isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError), ) { c := jwtTokenClaims{ Username: user.Username, @@ -760,12 +751,10 @@ func (s *httpdServer) loginUser( if err != nil { logger.Warn(logSender, connectionID, "unable to set user login cookie %v", err) updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure) - errorFunc(w, r, util.NewI18nError(err, util.I18nError500Message), ipAddr) + errorFunc(w, r, util.NewI18nError(err, util.I18nError500Message)) return } - if isSecondFactorAuth { - invalidateToken(r) - } + invalidateToken(r, !isSecondFactorAuth) if audience == tokenAudienceWebClientPartial { redirectPath := webClientTwoFactorPath if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) { @@ -785,7 +774,7 @@ func (s *httpdServer) loginUser( func (s *httpdServer) loginAdmin( w http.ResponseWriter, r *http.Request, admin *dataprovider.Admin, - isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string), + isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError), ipAddr string, ) { c := jwtTokenClaims{ @@ -807,15 +796,13 @@ func (s *httpdServer) loginAdmin( if err != nil { logger.Warn(logSender, "", "unable to set admin login cookie %v", err) if errorFunc == nil { - s.renderAdminSetupPage(w, r, admin.Username, ipAddr, util.NewI18nError(err, util.I18nError500Message)) + s.renderAdminSetupPage(w, r, admin.Username, util.NewI18nError(err, util.I18nError500Message)) return } - errorFunc(w, r, util.NewI18nError(err, util.I18nError500Message), ipAddr) + errorFunc(w, r, util.NewI18nError(err, util.I18nError500Message)) return } - if isSecondFactorAuth { - invalidateToken(r) - } + invalidateToken(r, !isSecondFactorAuth) if audience == tokenAudienceWebAdminPartial { http.Redirect(w, r, webAdminTwoFactorPath, http.StatusFound) return @@ -831,7 +818,7 @@ func (s *httpdServer) loginAdmin( func (s *httpdServer) logout(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) - invalidateToken(r) + invalidateToken(r, false) sendAPIResponse(w, r, nil, "Your token has been invalidated", http.StatusOK) } @@ -1022,13 +1009,13 @@ func (s *httpdServer) checkCookieExpiration(w http.ResponseWriter, r *http.Reque return } if util.Contains(token.Audience(), tokenAudienceWebClient) { - s.refreshClientToken(w, r, tokenClaims) + s.refreshClientToken(w, r, &tokenClaims) } else { - s.refreshAdminToken(w, r, tokenClaims) + s.refreshAdminToken(w, r, &tokenClaims) } } -func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request, tokenClaims jwtTokenClaims) { +func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwtTokenClaims) { user, err := dataprovider.GetUserWithGroupSettings(tokenClaims.Username, "") if err != nil { return @@ -1037,6 +1024,10 @@ func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request, logger.Debug(logSender, "", "signature mismatch for user %q, unable to refresh cookie", user.Username) return } + if err := user.CheckLoginConditions(); err != nil { + logger.Debug(logSender, "", "unable to refresh cookie for user %q: %v", user.Username, err) + return + } if err := checkHTTPClientUser(&user, r, xid.New().String(), true); err != nil { logger.Debug(logSender, "", "unable to refresh cookie for user %q: %v", user.Username, err) return @@ -1048,22 +1039,18 @@ func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request, tokenClaims.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebClient, util.GetIPFromRemoteAddress(r.RemoteAddr)) //nolint:errcheck } -func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request, tokenClaims jwtTokenClaims) { +func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwtTokenClaims) { admin, err := dataprovider.AdminExists(tokenClaims.Username) if err != nil { return } - if admin.Status != 1 { - logger.Debug(logSender, "", "admin %q is disabled, unable to refresh cookie", admin.Username) - return - } if admin.GetSignature() != tokenClaims.Signature { logger.Debug(logSender, "", "signature mismatch for admin %q, unable to refresh cookie", admin.Username) return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if !admin.CanLoginFromIP(ipAddr) { - logger.Debug(logSender, "", "admin %q cannot login from %v, unable to refresh cookie", admin.Username, r.RemoteAddr) + if err := admin.CanLogin(ipAddr); err != nil { + logger.Debug(logSender, "", "unable to refresh cookie for admin %q, err: %v", admin.Username, err) return } tokenClaims.Permissions = admin.Permissions @@ -1236,6 +1223,7 @@ func (s *httpdServer) mustCheckPath(r *http.Request) bool { func (s *httpdServer) initializeRouter() { var hasHTTPSRedirect bool s.tokenAuth = jwtauth.New(jwa.HS256.String(), getSigningKey(s.signingPassphrase), nil) + s.csrfTokenAuth = jwtauth.New(jwa.HS256.String(), getSigningKey(s.signingPassphrase), nil) s.router = chi.NewRouter() s.router.Use(middleware.RequestID) @@ -1537,11 +1525,14 @@ func (s *httpdServer) setupWebClientRoutes() { s.router.Get(webClientOIDCLoginPath, s.handleWebClientOIDCLogin) } if !s.binding.isWebClientLoginFormDisabled() { - s.router.Post(webClientLoginPath, s.handleWebClientLoginPost) + s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + Post(webClientLoginPath, s.handleWebClientLoginPost) s.router.Get(webClientForgotPwdPath, s.handleWebClientForgotPwd) - s.router.Post(webClientForgotPwdPath, s.handleWebClientForgotPwdPost) + s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + Post(webClientForgotPwdPath, s.handleWebClientForgotPwdPost) s.router.Get(webClientResetPwdPath, s.handleWebClientPasswordReset) - s.router.Post(webClientResetPwdPath, s.handleWebClientPasswordResetPost) + s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + Post(webClientResetPwdPath, s.handleWebClientPasswordResetPost) s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)). Get(webClientTwoFactorPath, s.handleWebClientTwoFactor) @@ -1557,7 +1548,8 @@ func (s *httpdServer) setupWebClientRoutes() { } // share routes available to external users s.router.Get(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginGet) - s.router.Post(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginPost) + s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + Post(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginPost) s.router.Get(webClientPubSharesPath+"/{id}", s.downloadFromShare) s.router.Post(webClientPubSharesPath+"/{id}/partial", s.handleClientSharePartialDownload) s.router.Get(webClientPubSharesPath+"/{id}/browse", s.handleShareGetFiles) @@ -1574,32 +1566,32 @@ func (s *httpdServer) setupWebClientRoutes() { if s.binding.OIDC.isEnabled() { router.Use(s.oidcTokenAuthenticator(tokenAudienceWebClient)) } - router.Use(jwtauth.Verify(s.tokenAuth, tokenFromContext, jwtauth.TokenFromCookie)) + router.Use(jwtauth.Verify(s.tokenAuth, oidcTokenFromContext, jwtauth.TokenFromCookie)) router.Use(jwtAuthenticatorWebClient) router.Get(webClientLogoutPath, s.handleWebClientLogout) router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientFilesPath, s.handleClientGetFiles) router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientViewPDFPath, s.handleClientViewPDF) router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientGetPDFPath, s.handleClientGetPDF) - router.With(s.checkAuthRequirements, s.refreshCookie, verifyCSRFHeader).Get(webClientFilePath, getUserFile) - router.With(s.checkAuthRequirements, s.refreshCookie, verifyCSRFHeader).Get(webClientTasksPath+"/{id}", + router.With(s.checkAuthRequirements, s.refreshCookie, s.verifyCSRFHeader).Get(webClientFilePath, getUserFile) + router.With(s.checkAuthRequirements, s.refreshCookie, s.verifyCSRFHeader).Get(webClientTasksPath+"/{id}", getWebTask) - router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader). + router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader). Post(webClientFilePath, uploadUserFile) - router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader). + router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader). Post(webClientExistPath, s.handleClientCheckExist) router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientEditFilePath, s.handleClientEditFile) - router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader). + router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader). Delete(webClientFilesPath, deleteUserFile) router.With(s.checkAuthRequirements, compressor.Handler, s.refreshCookie). Get(webClientDirsPath, s.handleClientGetDirContents) - router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader). + router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader). Post(webClientDirsPath, createUserDir) - router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader). + router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader). Delete(webClientDirsPath, taskDeleteDir) - router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader). + router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader). Post(webClientFileActionsPath+"/move", taskRenameFsEntry) - router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), verifyCSRFHeader). + router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader). Post(webClientFileActionsPath+"/copy", taskCopyFsEntry) router.With(s.checkAuthRequirements, s.refreshCookie). Post(webClientDownloadZipPath, s.handleWebClientDownloadZip) @@ -1615,15 +1607,15 @@ func (s *httpdServer) setupWebClientRoutes() { Get(webClientMFAPath, s.handleWebClientMFA) router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.refreshCookie). Get(webClientMFAPath+"/qrcode", getQRCode) - router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), verifyCSRFHeader). + router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader). Post(webClientTOTPGeneratePath, generateTOTPSecret) - router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), verifyCSRFHeader). + router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader). Post(webClientTOTPValidatePath, validateTOTPPasscode) - router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), verifyCSRFHeader). + router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader). Post(webClientTOTPSavePath, saveTOTPConfig) - router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), verifyCSRFHeader, s.refreshCookie). + router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader, s.refreshCookie). Get(webClientRecoveryCodesPath, getRecoveryCodes) - router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), verifyCSRFHeader). + router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader). Post(webClientRecoveryCodesPath, generateRecoveryCodes) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), compressor.Handler, s.refreshCookie). Get(webClientSharesPath+jsonAPISuffix, getAllShares) @@ -1637,7 +1629,7 @@ func (s *httpdServer) setupWebClientRoutes() { Get(webClientSharePath+"/{id}", s.handleClientUpdateShareGet) router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)). Post(webClientSharePath+"/{id}", s.handleClientUpdateSharePost) - router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), verifyCSRFHeader). + router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), s.verifyCSRFHeader). Delete(webClientSharePath+"/{id}", deleteShare) }) } @@ -1655,9 +1647,11 @@ func (s *httpdServer) setupWebAdminRoutes() { } s.router.Get(webOAuth2RedirectPath, s.handleOAuth2TokenRedirect) s.router.Get(webAdminSetupPath, s.handleWebAdminSetupGet) - s.router.Post(webAdminSetupPath, s.handleWebAdminSetupPost) + s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + Post(webAdminSetupPath, s.handleWebAdminSetupPost) if !s.binding.isWebAdminLoginFormDisabled() { - s.router.Post(webAdminLoginPath, s.handleWebAdminLoginPost) + s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + Post(webAdminLoginPath, s.handleWebAdminLoginPost) s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)). Get(webAdminTwoFactorPath, s.handleWebAdminTwoFactor) @@ -1671,16 +1665,18 @@ func (s *httpdServer) setupWebAdminRoutes() { s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)). Post(webAdminTwoFactorRecoveryPath, s.handleWebAdminTwoFactorRecoveryPost) s.router.Get(webAdminForgotPwdPath, s.handleWebAdminForgotPwd) - s.router.Post(webAdminForgotPwdPath, s.handleWebAdminForgotPwdPost) + s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + Post(webAdminForgotPwdPath, s.handleWebAdminForgotPwdPost) s.router.Get(webAdminResetPwdPath, s.handleWebAdminPasswordReset) - s.router.Post(webAdminResetPwdPath, s.handleWebAdminPasswordResetPost) + s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + Post(webAdminResetPwdPath, s.handleWebAdminPasswordResetPost) } s.router.Group(func(router chi.Router) { if s.binding.OIDC.isEnabled() { router.Use(s.oidcTokenAuthenticator(tokenAudienceWebAdmin)) } - router.Use(jwtauth.Verify(s.tokenAuth, tokenFromContext, jwtauth.TokenFromCookie)) + router.Use(jwtauth.Verify(s.tokenAuth, oidcTokenFromContext, jwtauth.TokenFromCookie)) router.Use(jwtAuthenticatorWebAdmin) router.Get(webLogoutPath, s.handleWebAdminLogout) @@ -1692,12 +1688,12 @@ func (s *httpdServer) setupWebAdminRoutes() { router.With(s.refreshCookie, s.requireBuiltinLogin).Get(webAdminMFAPath, s.handleWebAdminMFA) router.With(s.refreshCookie, s.requireBuiltinLogin).Get(webAdminMFAPath+"/qrcode", getQRCode) - router.With(verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPGeneratePath, generateTOTPSecret) - router.With(verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPValidatePath, validateTOTPPasscode) - router.With(verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPSavePath, saveTOTPConfig) - router.With(verifyCSRFHeader, s.requireBuiltinLogin, s.refreshCookie).Get(webAdminRecoveryCodesPath, + router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPGeneratePath, generateTOTPSecret) + router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPValidatePath, validateTOTPPasscode) + router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPSavePath, saveTOTPConfig) + router.With(s.verifyCSRFHeader, s.requireBuiltinLogin, s.refreshCookie).Get(webAdminRecoveryCodesPath, getRecoveryCodes) - router.With(verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminRecoveryCodesPath, generateRecoveryCodes) + router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminRecoveryCodesPath, generateRecoveryCodes) router.Group(func(router chi.Router) { router.Use(s.checkAuthRequirements) @@ -1724,7 +1720,7 @@ func (s *httpdServer) setupWebAdminRoutes() { Get(webGroupPath+"/{name}", s.handleWebUpdateGroupGet) router.With(s.checkPerm(dataprovider.PermAdminManageGroups)).Post(webGroupPath+"/{name}", s.handleWebUpdateGroupPost) - router.With(s.checkPerm(dataprovider.PermAdminManageGroups), verifyCSRFHeader). + router.With(s.checkPerm(dataprovider.PermAdminManageGroups), s.verifyCSRFHeader). Delete(webGroupPath+"/{name}", deleteGroup) router.With(s.checkPerm(dataprovider.PermAdminViewConnections), s.refreshCookie). Get(webConnectionsPath, s.handleWebGetConnections) @@ -1750,25 +1746,25 @@ func (s *httpdServer) setupWebAdminRoutes() { router.With(s.checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath, s.handleWebAddAdminPost) router.With(s.checkPerm(dataprovider.PermAdminManageAdmins)).Post(webAdminPath+"/{username}", s.handleWebUpdateAdminPost) - router.With(s.checkPerm(dataprovider.PermAdminManageAdmins), verifyCSRFHeader). + router.With(s.checkPerm(dataprovider.PermAdminManageAdmins), s.verifyCSRFHeader). Delete(webAdminPath+"/{username}", deleteAdmin) - router.With(s.checkPerm(dataprovider.PermAdminDisableMFA), verifyCSRFHeader). + router.With(s.checkPerm(dataprovider.PermAdminDisableMFA), s.verifyCSRFHeader). Put(webAdminPath+"/{username}/2fa/disable", disableAdmin2FA) - router.With(s.checkPerm(dataprovider.PermAdminCloseConnections), verifyCSRFHeader). + router.With(s.checkPerm(dataprovider.PermAdminCloseConnections), s.verifyCSRFHeader). Delete(webConnectionsPath+"/{connectionID}", handleCloseConnection) router.With(s.checkPerm(dataprovider.PermAdminManageFolders), s.refreshCookie). Get(webFolderPath+"/{name}", s.handleWebUpdateFolderGet) router.With(s.checkPerm(dataprovider.PermAdminManageFolders)).Post(webFolderPath+"/{name}", s.handleWebUpdateFolderPost) - router.With(s.checkPerm(dataprovider.PermAdminManageFolders), verifyCSRFHeader). + router.With(s.checkPerm(dataprovider.PermAdminManageFolders), s.verifyCSRFHeader). Delete(webFolderPath+"/{name}", deleteFolder) - router.With(s.checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader). + router.With(s.checkPerm(dataprovider.PermAdminQuotaScans), s.verifyCSRFHeader). Post(webScanVFolderPath+"/{name}", startFolderQuotaScan) - router.With(s.checkPerm(dataprovider.PermAdminDeleteUsers), verifyCSRFHeader). + router.With(s.checkPerm(dataprovider.PermAdminDeleteUsers), s.verifyCSRFHeader). Delete(webUserPath+"/{username}", deleteUser) - router.With(s.checkPerm(dataprovider.PermAdminDisableMFA), verifyCSRFHeader). + router.With(s.checkPerm(dataprovider.PermAdminDisableMFA), s.verifyCSRFHeader). Put(webUserPath+"/{username}/2fa/disable", disableUser2FA) - router.With(s.checkPerm(dataprovider.PermAdminQuotaScans), verifyCSRFHeader). + router.With(s.checkPerm(dataprovider.PermAdminQuotaScans), s.verifyCSRFHeader). Post(webQuotaScanPath+"/{username}", startUserQuotaScan) router.With(s.checkPerm(dataprovider.PermAdminManageSystem)).Get(webMaintenancePath, s.handleWebMaintenance) router.With(s.checkPerm(dataprovider.PermAdminManageSystem)).Get(webBackupPath, dumpData) @@ -1795,7 +1791,7 @@ func (s *httpdServer) setupWebAdminRoutes() { Get(webAdminEventActionPath+"/{name}", s.handleWebUpdateEventActionGet) router.With(s.checkPerm(dataprovider.PermAdminManageEventRules)).Post(webAdminEventActionPath+"/{name}", s.handleWebUpdateEventActionPost) - router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), verifyCSRFHeader). + router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), s.verifyCSRFHeader). Delete(webAdminEventActionPath+"/{name}", deleteEventAction) router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), compressor.Handler, s.refreshCookie). Get(webAdminEventRulesPath+jsonAPISuffix, getAllRules) @@ -1809,9 +1805,9 @@ func (s *httpdServer) setupWebAdminRoutes() { Get(webAdminEventRulePath+"/{name}", s.handleWebUpdateEventRuleGet) router.With(s.checkPerm(dataprovider.PermAdminManageEventRules)).Post(webAdminEventRulePath+"/{name}", s.handleWebUpdateEventRulePost) - router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), verifyCSRFHeader). + router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), s.verifyCSRFHeader). Delete(webAdminEventRulePath+"/{name}", deleteEventRule) - router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), verifyCSRFHeader). + router.With(s.checkPerm(dataprovider.PermAdminManageEventRules), s.verifyCSRFHeader). Post(webAdminEventRulePath+"/run/{name}", runOnDemandRule) router.With(s.checkPerm(dataprovider.PermAdminManageRoles), s.refreshCookie). Get(webAdminRolesPath, s.handleWebGetRoles) @@ -1824,7 +1820,7 @@ func (s *httpdServer) setupWebAdminRoutes() { Get(webAdminRolePath+"/{name}", s.handleWebUpdateRoleGet) router.With(s.checkPerm(dataprovider.PermAdminManageRoles)).Post(webAdminRolePath+"/{name}", s.handleWebUpdateRolePost) - router.With(s.checkPerm(dataprovider.PermAdminManageRoles), verifyCSRFHeader). + router.With(s.checkPerm(dataprovider.PermAdminManageRoles), s.verifyCSRFHeader). Delete(webAdminRolePath+"/{name}", deleteRole) router.With(s.checkPerm(dataprovider.PermAdminViewEvents), s.refreshCookie).Get(webEventsPath, s.handleWebGetEvents) @@ -1845,14 +1841,14 @@ func (s *httpdServer) setupWebAdminRoutes() { s.handleWebUpdateIPListEntryGet) router.With(s.checkPerm(dataprovider.PermAdminManageIPLists)).Post(webIPListPath+"/{type}/{ipornet}", s.handleWebUpdateIPListEntryPost) - router.With(s.checkPerm(dataprovider.PermAdminManageIPLists), verifyCSRFHeader). + router.With(s.checkPerm(dataprovider.PermAdminManageIPLists), s.verifyCSRFHeader). Delete(webIPListPath+"/{type}/{ipornet}", deleteIPListEntry) router.With(s.checkPerm(dataprovider.PermAdminManageSystem), s.refreshCookie).Get(webConfigsPath, s.handleWebConfigs) router.With(s.checkPerm(dataprovider.PermAdminManageSystem)).Post(webConfigsPath, s.handleWebConfigsPost) - router.With(s.checkPerm(dataprovider.PermAdminManageSystem), verifyCSRFHeader, s.refreshCookie). + router.With(s.checkPerm(dataprovider.PermAdminManageSystem), s.verifyCSRFHeader, s.refreshCookie). Post(webConfigsPath+"/smtp/test", testSMTPConfig) - router.With(s.checkPerm(dataprovider.PermAdminManageSystem), verifyCSRFHeader, s.refreshCookie). - Post(webOAuth2TokenPath, handleSMTPOAuth2TokenRequestPost) + router.With(s.checkPerm(dataprovider.PermAdminManageSystem), s.verifyCSRFHeader, s.refreshCookie). + Post(webOAuth2TokenPath, s.handleSMTPOAuth2TokenRequestPost) }) }) } diff --git a/internal/httpd/webadmin.go b/internal/httpd/webadmin.go index f8ffa63f..a3ee7269 100644 --- a/internal/httpd/webadmin.go +++ b/internal/httpd/webadmin.go @@ -31,6 +31,7 @@ import ( "time" "github.com/go-chi/render" + "github.com/rs/xid" "github.com/sftpgo/sdk" sdkkms "github.com/sftpgo/sdk/kms" @@ -612,10 +613,10 @@ func isServerManagerResource(currentURL string) bool { currentURL == webConfigsPath } -func (s *httpdServer) getBasePageData(title, currentURL string, r *http.Request) basePage { +func (s *httpdServer) getBasePageData(title, currentURL string, w http.ResponseWriter, r *http.Request) basePage { var csrfToken string if currentURL != "" { - csrfToken = createCSRFToken(util.GetIPFromRemoteAddress(r.RemoteAddr)) + csrfToken = createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath) } return basePage{ commonBasePage: getCommonBasePage(r), @@ -675,7 +676,7 @@ func (s *httpdServer) renderMessagePageWithString(w http.ResponseWriter, r *http err error, message, text string, ) { data := messagePage{ - basePage: s.getBasePageData(title, "", r), + basePage: s.getBasePageData(title, "", w, r), Error: getI18nError(err), Success: message, Text: text, @@ -710,12 +711,12 @@ func (s *httpdServer) renderNotFoundPage(w http.ResponseWriter, r *http.Request, util.NewI18nError(err, util.I18nError404Message), "") } -func (s *httpdServer) renderForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { +func (s *httpdServer) renderForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := forgotPwdPage{ commonBasePage: getCommonBasePage(r), CurrentURL: webAdminForgotPwdPath, Error: err, - CSRFToken: createCSRFToken(ip), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath), LoginURL: webAdminLoginPath, Title: util.I18nForgotPwdTitle, Branding: s.binding.Branding.WebAdmin, @@ -723,12 +724,12 @@ func (s *httpdServer) renderForgotPwdPage(w http.ResponseWriter, r *http.Request renderAdminTemplate(w, templateForgotPassword, data) } -func (s *httpdServer) renderResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { +func (s *httpdServer) renderResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := resetPwdPage{ commonBasePage: getCommonBasePage(r), CurrentURL: webAdminResetPwdPath, Error: err, - CSRFToken: createCSRFToken(ip), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath), LoginURL: webAdminLoginPath, Title: util.I18nResetPwdTitle, Branding: s.binding.Branding.WebAdmin, @@ -736,26 +737,26 @@ func (s *httpdServer) renderResetPwdPage(w http.ResponseWriter, r *http.Request, renderAdminTemplate(w, templateResetPassword, data) } -func (s *httpdServer) renderTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { +func (s *httpdServer) renderTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := twoFactorPage{ commonBasePage: getCommonBasePage(r), Title: pageTwoFactorTitle, CurrentURL: webAdminTwoFactorPath, Error: err, - CSRFToken: createCSRFToken(ip), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath), RecoveryURL: webAdminTwoFactorRecoveryPath, Branding: s.binding.Branding.WebAdmin, } renderAdminTemplate(w, templateTwoFactor, data) } -func (s *httpdServer) renderTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { +func (s *httpdServer) renderTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := twoFactorPage{ commonBasePage: getCommonBasePage(r), Title: pageTwoFactorRecoveryTitle, CurrentURL: webAdminTwoFactorRecoveryPath, Error: err, - CSRFToken: createCSRFToken(ip), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath), Branding: s.binding.Branding.WebAdmin, } renderAdminTemplate(w, templateTwoFactorRecovery, data) @@ -763,7 +764,7 @@ func (s *httpdServer) renderTwoFactorRecoveryPage(w http.ResponseWriter, r *http func (s *httpdServer) renderMFAPage(w http.ResponseWriter, r *http.Request) { data := mfaPage{ - basePage: s.getBasePageData(pageMFATitle, webAdminMFAPath, r), + basePage: s.getBasePageData(pageMFATitle, webAdminMFAPath, w, r), TOTPConfigs: mfa.GetAvailableTOTPConfigNames(), GenerateTOTPURL: webAdminTOTPGeneratePath, ValidateTOTPURL: webAdminTOTPValidatePath, @@ -782,7 +783,7 @@ func (s *httpdServer) renderMFAPage(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) renderProfilePage(w http.ResponseWriter, r *http.Request, err error) { data := profilePage{ - basePage: s.getBasePageData(util.I18nProfileTitle, webAdminProfilePath, r), + basePage: s.getBasePageData(util.I18nProfileTitle, webAdminProfilePath, w, r), Error: getI18nError(err), } admin, err := dataprovider.AdminExists(data.LoggedUser.Username) @@ -799,7 +800,7 @@ func (s *httpdServer) renderProfilePage(w http.ResponseWriter, r *http.Request, func (s *httpdServer) renderChangePasswordPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := changePasswordPage{ - basePage: s.getBasePageData(util.I18nChangePwdTitle, webChangeAdminPwdPath, r), + basePage: s.getBasePageData(util.I18nChangePwdTitle, webChangeAdminPwdPath, w, r), Error: err, } @@ -808,7 +809,7 @@ func (s *httpdServer) renderChangePasswordPage(w http.ResponseWriter, r *http.Re func (s *httpdServer) renderMaintenancePage(w http.ResponseWriter, r *http.Request, err error) { data := maintenancePage{ - basePage: s.getBasePageData(util.I18nMaintenanceTitle, webMaintenancePath, r), + basePage: s.getBasePageData(util.I18nMaintenanceTitle, webMaintenancePath, w, r), BackupPath: webBackupPath, RestorePath: webRestorePath, Error: getI18nError(err), @@ -830,7 +831,7 @@ func (s *httpdServer) renderConfigsPage(w http.ResponseWriter, r *http.Request, configs.ACME.HTTP01Challenge.Port = 80 } data := configsPage{ - basePage: s.getBasePageData(util.I18nConfigsTitle, webConfigsPath, r), + basePage: s.getBasePageData(util.I18nConfigsTitle, webConfigsPath, w, r), Configs: configs, ConfigSection: section, RedactedSecret: redactedSecret, @@ -842,12 +843,12 @@ func (s *httpdServer) renderConfigsPage(w http.ResponseWriter, r *http.Request, renderAdminTemplate(w, templateConfigs, data) } -func (s *httpdServer) renderAdminSetupPage(w http.ResponseWriter, r *http.Request, username, ip string, err *util.I18nError) { +func (s *httpdServer) renderAdminSetupPage(w http.ResponseWriter, r *http.Request, username string, err *util.I18nError) { data := setupPage{ commonBasePage: getCommonBasePage(r), Title: util.I18nSetupTitle, CurrentURL: webAdminSetupPath, - CSRFToken: createCSRFToken(ip), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath), Username: username, HasInstallationCode: installationCode != "", InstallationCodeHint: installationCodeHint, @@ -876,7 +877,7 @@ func (s *httpdServer) renderAddUpdateAdminPage(w http.ResponseWriter, r *http.Re title = util.I18nUpdateAdminTitle } data := adminPage{ - basePage: s.getBasePageData(title, currentURL, r), + basePage: s.getBasePageData(title, currentURL, w, r), Admin: admin, Groups: groups, Roles: roles, @@ -917,7 +918,7 @@ func (s *httpdServer) renderUserPage(w http.ResponseWriter, r *http.Request, use } } user.FsConfig.RedactedSecret = redactedSecret - basePage := s.getBasePageData(title, currentURL, r) + basePage := s.getBasePageData(title, currentURL, w, r) if (mode == userPageModeAdd || mode == userPageModeTemplate) && len(user.Groups) == 0 && admin != nil { for _, group := range admin.Groups { user.Groups = append(user.Groups, sdk.GroupMapping{ @@ -982,7 +983,7 @@ func (s *httpdServer) renderIPListPage(w http.ResponseWriter, r *http.Request, e currentURL = fmt.Sprintf("%s/%d/%s", webIPListPath, entry.Type, url.PathEscape(entry.IPOrNet)) } data := ipListPage{ - basePage: s.getBasePageData(title, currentURL, r), + basePage: s.getBasePageData(title, currentURL, w, r), Error: getI18nError(err), Entry: &entry, Mode: mode, @@ -1003,7 +1004,7 @@ func (s *httpdServer) renderRolePage(w http.ResponseWriter, r *http.Request, rol currentURL = fmt.Sprintf("%s/%s", webAdminRolePath, url.PathEscape(role.Name)) } data := rolePage{ - basePage: s.getBasePageData(title, currentURL, r), + basePage: s.getBasePageData(title, currentURL, w, r), Error: getI18nError(err), Role: &role, Mode: mode, @@ -1033,7 +1034,7 @@ func (s *httpdServer) renderGroupPage(w http.ResponseWriter, r *http.Request, gr group.UserSettings.FsConfig.SetEmptySecretsIfNil() data := groupPage{ - basePage: s.getBasePageData(title, currentURL, r), + basePage: s.getBasePageData(title, currentURL, w, r), Error: getI18nError(err), Group: &group, Mode: mode, @@ -1078,7 +1079,7 @@ func (s *httpdServer) renderEventActionPage(w http.ResponseWriter, r *http.Reque } data := eventActionPage{ - basePage: s.getBasePageData(title, currentURL, r), + basePage: s.getBasePageData(title, currentURL, w, r), Action: action, ActionTypes: dataprovider.EventActionTypes, FsActions: dataprovider.FsActionTypes, @@ -1108,7 +1109,7 @@ func (s *httpdServer) renderEventRulePage(w http.ResponseWriter, r *http.Request } data := eventRulePage{ - basePage: s.getBasePageData(title, currentURL, r), + basePage: s.getBasePageData(title, currentURL, w, r), Rule: rule, TriggerTypes: dataprovider.EventTriggerTypes, Actions: actions, @@ -1142,7 +1143,7 @@ func (s *httpdServer) renderFolderPage(w http.ResponseWriter, r *http.Request, f folder.FsConfig.SetEmptySecretsIfNil() data := folderPage{ - basePage: s.getBasePageData(title, currentURL, r), + basePage: s.getBasePageData(title, currentURL, w, r), Error: getI18nError(err), Folder: folder, Mode: mode, @@ -2764,25 +2765,24 @@ func (s *httpdServer) handleWebAdminForgotPwd(w http.ResponseWriter, r *http.Req s.renderNotFoundPage(w, r, errors.New("this page does not exist")) return } - s.renderForgotPwdPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) + s.renderForgotPwdPage(w, r, nil) } func (s *httpdServer) handleWebAdminForgotPwdPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) err := r.ParseForm() if err != nil { - s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) + s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } err = handleForgotPassword(r, r.Form.Get("username"), true) if err != nil { - s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric), ipAddr) + s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric)) return } http.Redirect(w, r, webAdminResetPwdPath, http.StatusFound) @@ -2794,17 +2794,17 @@ func (s *httpdServer) handleWebAdminPasswordReset(w http.ResponseWriter, r *http s.renderNotFoundPage(w, r, errors.New("this page does not exist")) return } - s.renderResetPwdPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) + s.renderResetPwdPage(w, r, nil) } func (s *httpdServer) handleWebAdminTwoFactor(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - s.renderTwoFactorPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) + s.renderTwoFactorPage(w, r, nil) } func (s *httpdServer) handleWebAdminTwoFactorRecovery(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - s.renderTwoFactorRecoveryPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) + s.renderTwoFactorRecoveryPage(w, r, nil) } func (s *httpdServer) handleWebAdminMFA(w http.ResponseWriter, r *http.Request) { @@ -2830,7 +2830,7 @@ func (s *httpdServer) handleWebAdminProfilePost(w http.ResponseWriter, r *http.R return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -2875,7 +2875,7 @@ func (s *httpdServer) handleWebRestore(w http.ResponseWriter, r *http.Request) { defer r.MultipartForm.RemoveAll() //nolint:errcheck ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -2936,7 +2936,7 @@ func getAllAdmins(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleGetWebAdmins(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - data := s.getBasePageData(util.I18nAdminsTitle, webAdminsPath, r) + data := s.getBasePageData(util.I18nAdminsTitle, webAdminsPath, w, r) renderAdminTemplate(w, templateAdmins, data) } @@ -2946,7 +2946,7 @@ func (s *httpdServer) handleWebAdminSetupGet(w http.ResponseWriter, r *http.Requ http.Redirect(w, r, webAdminLoginPath, http.StatusFound) return } - s.renderAdminSetupPage(w, r, "", util.GetIPFromRemoteAddress(r.RemoteAddr), nil) + s.renderAdminSetupPage(w, r, "", nil) } func (s *httpdServer) handleWebAddAdminGet(w http.ResponseWriter, r *http.Request) { @@ -2987,7 +2987,7 @@ func (s *httpdServer) handleWebAddAdminPost(w http.ResponseWriter, r *http.Reque admin.Password = util.GenerateUniqueID() } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -3018,7 +3018,7 @@ func (s *httpdServer) handleWebUpdateAdminPost(w http.ResponseWriter, r *http.Re return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -3071,7 +3071,7 @@ func (s *httpdServer) handleWebUpdateAdminPost(w http.ResponseWriter, r *http.Re func (s *httpdServer) handleWebDefenderPage(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) data := defenderHostsPage{ - basePage: s.getBasePageData(util.I18nDefenderTitle, webDefenderPath, r), + basePage: s.getBasePageData(util.I18nDefenderTitle, webDefenderPath, w, r), DefenderHostsURL: webDefenderHostsPath, } @@ -3105,7 +3105,7 @@ func (s *httpdServer) handleGetWebUsers(w http.ResponseWriter, r *http.Request) s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return } - data := s.getBasePageData(util.I18nUsersTitle, webUsersPath, r) + data := s.getBasePageData(util.I18nUsersTitle, webUsersPath, w, r) renderAdminTemplate(w, templateUsers, data) } @@ -3144,7 +3144,7 @@ func (s *httpdServer) handleWebTemplateFolderPost(w http.ResponseWriter, r *http defer r.MultipartForm.RemoveAll() //nolint:errcheck ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -3244,7 +3244,7 @@ func (s *httpdServer) handleWebTemplateUserPost(w http.ResponseWriter, r *http.R return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -3341,7 +3341,7 @@ func (s *httpdServer) handleWebAddUserPost(w http.ResponseWriter, r *http.Reques return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -3387,7 +3387,7 @@ func (s *httpdServer) handleWebUpdateUserPost(w http.ResponseWriter, r *http.Req return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -3425,7 +3425,7 @@ func (s *httpdServer) handleWebUpdateUserPost(w http.ResponseWriter, r *http.Req func (s *httpdServer) handleWebGetStatus(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) data := statusPage{ - basePage: s.getBasePageData(util.I18nStatusTitle, webStatusPath, r), + basePage: s.getBasePageData(util.I18nStatusTitle, webStatusPath, w, r), Status: getServicesStatus(), } renderAdminTemplate(w, templateStatus, data) @@ -3439,7 +3439,7 @@ func (s *httpdServer) handleWebGetConnections(w http.ResponseWriter, r *http.Req return } - data := s.getBasePageData(util.I18nSessionsTitle, webConnectionsPath, r) + data := s.getBasePageData(util.I18nSessionsTitle, webConnectionsPath, w, r) renderAdminTemplate(w, templateConnections, data) } @@ -3464,7 +3464,7 @@ func (s *httpdServer) handleWebAddFolderPost(w http.ResponseWriter, r *http.Requ defer r.MultipartForm.RemoveAll() //nolint:errcheck ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -3525,7 +3525,7 @@ func (s *httpdServer) handleWebUpdateFolderPost(w http.ResponseWriter, r *http.R defer r.MultipartForm.RemoveAll() //nolint:errcheck ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -3588,7 +3588,7 @@ func getAllFolders(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebGetFolders(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - data := s.getBasePageData(util.I18nFoldersTitle, webFoldersPath, r) + data := s.getBasePageData(util.I18nFoldersTitle, webFoldersPath, w, r) renderAdminTemplate(w, templateFolders, data) } @@ -3626,7 +3626,7 @@ func getAllGroups(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebGetGroups(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - data := s.getBasePageData(util.I18nGroupsTitle, webGroupsPath, r) + data := s.getBasePageData(util.I18nGroupsTitle, webGroupsPath, w, r) renderAdminTemplate(w, templateGroups, data) } @@ -3648,7 +3648,7 @@ func (s *httpdServer) handleWebAddGroupPost(w http.ResponseWriter, r *http.Reque return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -3695,7 +3695,7 @@ func (s *httpdServer) handleWebUpdateGroupPost(w http.ResponseWriter, r *http.Re return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -3748,7 +3748,7 @@ func getAllActions(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebGetEventActions(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - data := s.getBasePageData(util.I18nActionsTitle, webAdminEventActionsPath, r) + data := s.getBasePageData(util.I18nActionsTitle, webAdminEventActionsPath, w, r) renderAdminTemplate(w, templateEventActions, data) } @@ -3773,7 +3773,7 @@ func (s *httpdServer) handleWebAddEventActionPost(w http.ResponseWriter, r *http return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -3819,7 +3819,7 @@ func (s *httpdServer) handleWebUpdateEventActionPost(w http.ResponseWriter, r *h return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -3858,7 +3858,7 @@ func getAllRules(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebGetEventRules(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - data := s.getBasePageData(util.I18nRulesTitle, webAdminEventRulesPath, r) + data := s.getBasePageData(util.I18nRulesTitle, webAdminEventRulesPath, w, r) renderAdminTemplate(w, templateEventRules, data) } @@ -3884,7 +3884,7 @@ func (s *httpdServer) handleWebAddEventRulePost(w http.ResponseWriter, r *http.R return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - err = verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr) + err = verifyCSRFToken(r, s.csrfTokenAuth) if err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return @@ -3931,7 +3931,7 @@ func (s *httpdServer) handleWebUpdateEventRulePost(w http.ResponseWriter, r *htt return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -3978,7 +3978,7 @@ func getAllRoles(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebGetRoles(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - data := s.getBasePageData(util.I18nRolesTitle, webAdminRolesPath, r) + data := s.getBasePageData(util.I18nRolesTitle, webAdminRolesPath, w, r) renderAdminTemplate(w, templateRoles, data) } @@ -4001,7 +4001,7 @@ func (s *httpdServer) handleWebAddRolePost(w http.ResponseWriter, r *http.Reques return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -4047,7 +4047,7 @@ func (s *httpdServer) handleWebUpdateRolePost(w http.ResponseWriter, r *http.Req return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -4065,7 +4065,7 @@ func (s *httpdServer) handleWebGetEvents(w http.ResponseWriter, r *http.Request) r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) data := eventsPage{ - basePage: s.getBasePageData(util.I18nEventsTitle, webEventsPath, r), + basePage: s.getBasePageData(util.I18nEventsTitle, webEventsPath, w, r), FsEventsSearchURL: webEventsFsSearchPath, ProviderEventsSearchURL: webEventsProviderSearchPath, LogEventsSearchURL: webEventsLogSearchPath, @@ -4077,7 +4077,7 @@ func (s *httpdServer) handleWebIPListsPage(w http.ResponseWriter, r *http.Reques r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) rtlStatus, rtlProtocols := common.Config.GetRateLimitersStatus() data := ipListsPage{ - basePage: s.getBasePageData(util.I18nIPListsTitle, webIPListsPath, r), + basePage: s.getBasePageData(util.I18nIPListsTitle, webIPListsPath, w, r), RateLimitersStatus: rtlStatus, RateLimitersProtocols: strings.Join(rtlProtocols, ", "), IsAllowListEnabled: common.Config.IsAllowListEnabled(), @@ -4115,7 +4115,7 @@ func (s *httpdServer) handleWebAddIPListEntryPost(w http.ResponseWriter, r *http return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -4170,7 +4170,7 @@ func (s *httpdServer) handleWebUpdateIPListEntryPost(w http.ResponseWriter, r *h return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -4212,7 +4212,7 @@ func (s *httpdServer) handleWebConfigsPost(w http.ResponseWriter, r *http.Reques return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -4262,20 +4262,21 @@ func (s *httpdServer) handleOAuth2TokenRedirect(w http.ResponseWriter, r *http.R stateToken := r.URL.Query().Get("state") - state, err := verifyOAuth2Token(stateToken, util.GetIPFromRemoteAddress(r.RemoteAddr)) + state, err := verifyOAuth2Token(s.csrfTokenAuth, stateToken, util.GetIPFromRemoteAddress(r.RemoteAddr)) if err != nil { s.renderMessagePage(w, r, util.I18nOAuth2ErrorTitle, http.StatusBadRequest, err, "") return } - defer oauth2Mgr.removePendingAuth(state) - pendingAuth, err := oauth2Mgr.getPendingAuth(state) if err != nil { + oauth2Mgr.removePendingAuth(state) s.renderMessagePage(w, r, util.I18nOAuth2ErrorTitle, http.StatusInternalServerError, util.NewI18nError(err, util.I18nOAuth2ErrorValidateState), "") return } + oauth2Mgr.removePendingAuth(state) + oauth2Config := smtp.OAuth2Config{ Provider: pendingAuth.Provider, ClientID: pendingAuth.ClientID, diff --git a/internal/httpd/webclient.go b/internal/httpd/webclient.go index 10fd2c49..88596dc3 100644 --- a/internal/httpd/webclient.go +++ b/internal/httpd/webclient.go @@ -523,10 +523,10 @@ func loadClientTemplates(templatesPath string) { clientTemplates[templateShareDownload] = shareDownloadTmpl } -func (s *httpdServer) getBaseClientPageData(title, currentURL string, r *http.Request) baseClientPage { +func (s *httpdServer) getBaseClientPageData(title, currentURL string, w http.ResponseWriter, r *http.Request) baseClientPage { var csrfToken string if currentURL != "" { - csrfToken = createCSRFToken(util.GetIPFromRemoteAddress(r.RemoteAddr)) + csrfToken = createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath) } data := baseClientPage{ @@ -552,12 +552,12 @@ func (s *httpdServer) getBaseClientPageData(title, currentURL string, r *http.Re return data } -func (s *httpdServer) renderClientForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { +func (s *httpdServer) renderClientForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := forgotPwdPage{ commonBasePage: getCommonBasePage(r), CurrentURL: webClientForgotPwdPath, Error: err, - CSRFToken: createCSRFToken(ip), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath), LoginURL: webClientLoginPath, Title: util.I18nForgotPwdTitle, Branding: s.binding.Branding.WebClient, @@ -565,12 +565,12 @@ func (s *httpdServer) renderClientForgotPwdPage(w http.ResponseWriter, r *http.R renderClientTemplate(w, templateForgotPassword, data) } -func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { +func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := resetPwdPage{ commonBasePage: getCommonBasePage(r), CurrentURL: webClientResetPwdPath, Error: err, - CSRFToken: createCSRFToken(ip), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath), LoginURL: webClientLoginPath, Title: util.I18nResetPwdTitle, Branding: s.binding.Branding.WebClient, @@ -578,13 +578,13 @@ func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, r *http.Re renderClientTemplate(w, templateResetPassword, data) } -func (s *httpdServer) renderShareLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { +func (s *httpdServer) renderShareLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := shareLoginPage{ commonBasePage: getCommonBasePage(r), Title: util.I18nShareLoginTitle, CurrentURL: r.RequestURI, Error: err, - CSRFToken: createCSRFToken(ip), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath), Branding: s.binding.Branding.WebClient, } renderClientTemplate(w, templateShareLogin, data) @@ -599,7 +599,7 @@ func renderClientTemplate(w http.ResponseWriter, tmplName string, data any) { func (s *httpdServer) renderClientMessagePage(w http.ResponseWriter, r *http.Request, title string, statusCode int, err error, message string) { data := clientMessagePage{ - baseClientPage: s.getBaseClientPageData(title, "", r), + baseClientPage: s.getBaseClientPageData(title, "", w, r), Error: getI18nError(err), Success: message, } @@ -627,13 +627,13 @@ func (s *httpdServer) renderClientNotFoundPage(w http.ResponseWriter, r *http.Re util.NewI18nError(err, util.I18nError404Message), "") } -func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { +func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := twoFactorPage{ commonBasePage: getCommonBasePage(r), Title: pageTwoFactorTitle, CurrentURL: webClientTwoFactorPath, Error: err, - CSRFToken: createCSRFToken(ip), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath), RecoveryURL: webClientTwoFactorRecoveryPath, Branding: s.binding.Branding.WebClient, } @@ -643,13 +643,13 @@ func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, r *http.R renderClientTemplate(w, templateTwoFactor, data) } -func (s *httpdServer) renderClientTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError, ip string) { +func (s *httpdServer) renderClientTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := twoFactorPage{ commonBasePage: getCommonBasePage(r), Title: pageTwoFactorRecoveryTitle, CurrentURL: webClientTwoFactorRecoveryPath, Error: err, - CSRFToken: createCSRFToken(ip), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath), Branding: s.binding.Branding.WebClient, } renderClientTemplate(w, templateTwoFactorRecovery, data) @@ -657,7 +657,7 @@ func (s *httpdServer) renderClientTwoFactorRecoveryPage(w http.ResponseWriter, r func (s *httpdServer) renderClientMFAPage(w http.ResponseWriter, r *http.Request) { data := clientMFAPage{ - baseClientPage: s.getBaseClientPageData(util.I18n2FATitle, webClientMFAPath, r), + baseClientPage: s.getBaseClientPageData(util.I18n2FATitle, webClientMFAPath, w, r), TOTPConfigs: mfa.GetAvailableTOTPConfigNames(), GenerateTOTPURL: webClientTOTPGeneratePath, ValidateTOTPURL: webClientTOTPValidatePath, @@ -681,7 +681,7 @@ func (s *httpdServer) renderEditFilePage(w http.ResponseWriter, r *http.Request, title = util.I18nEditFileTitle } data := editFilePage{ - baseClientPage: s.getBaseClientPageData(title, webClientEditFilePath, r), + baseClientPage: s.getBaseClientPageData(title, webClientEditFilePath, w, r), Path: fileName, Name: path.Base(fileName), CurrentDir: path.Dir(fileName), @@ -702,7 +702,7 @@ func (s *httpdServer) renderAddUpdateSharePage(w http.ResponseWriter, r *http.Re title = util.I18nShareUpdateTitle } data := clientSharePage{ - baseClientPage: s.getBaseClientPageData(title, currentURL, r), + baseClientPage: s.getBaseClientPageData(title, currentURL, w, r), Share: share, Error: err, IsAdd: isAdd, @@ -736,7 +736,7 @@ func (s *httpdServer) renderSharedFilesPage(w http.ResponseWriter, r *http.Reque err *util.I18nError, share dataprovider.Share, ) { currentURL := path.Join(webClientPubSharesPath, share.ShareID, "browse") - baseData := s.getBaseClientPageData(util.I18nSharedFilesTitle, currentURL, r) + baseData := s.getBaseClientPageData(util.I18nSharedFilesTitle, currentURL, w, r) baseData.FilesURL = currentURL baseSharePath := path.Join(webClientPubSharesPath, share.ShareID) @@ -768,7 +768,7 @@ func (s *httpdServer) renderSharedFilesPage(w http.ResponseWriter, r *http.Reque func (s *httpdServer) renderShareDownloadPage(w http.ResponseWriter, r *http.Request, downloadLink string) { data := shareDownloadPage{ - baseClientPage: s.getBaseClientPageData(util.I18nShareDownloadTitle, "", r), + baseClientPage: s.getBaseClientPageData(util.I18nShareDownloadTitle, "", w, r), DownloadLink: downloadLink, } renderClientTemplate(w, templateShareDownload, data) @@ -777,7 +777,7 @@ func (s *httpdServer) renderShareDownloadPage(w http.ResponseWriter, r *http.Req func (s *httpdServer) renderUploadToSharePage(w http.ResponseWriter, r *http.Request, share dataprovider.Share) { currentURL := path.Join(webClientPubSharesPath, share.ShareID, "upload") data := shareUploadPage{ - baseClientPage: s.getBaseClientPageData(util.I18nShareUploadTitle, currentURL, r), + baseClientPage: s.getBaseClientPageData(util.I18nShareUploadTitle, currentURL, w, r), Share: &share, UploadBasePath: path.Join(webClientPubSharesPath, share.ShareID), } @@ -787,7 +787,7 @@ func (s *httpdServer) renderUploadToSharePage(w http.ResponseWriter, r *http.Req func (s *httpdServer) renderFilesPage(w http.ResponseWriter, r *http.Request, dirName string, err *util.I18nError, user *dataprovider.User) { data := filesPage{ - baseClientPage: s.getBaseClientPageData(util.I18nFilesTitle, webClientFilesPath, r), + baseClientPage: s.getBaseClientPageData(util.I18nFilesTitle, webClientFilesPath, w, r), Error: err, CurrentDir: url.QueryEscape(dirName), DownloadURL: webClientDownloadZipPath, @@ -813,7 +813,7 @@ func (s *httpdServer) renderFilesPage(w http.ResponseWriter, r *http.Request, di func (s *httpdServer) renderClientProfilePage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := clientProfilePage{ - baseClientPage: s.getBaseClientPageData(util.I18nProfileTitle, webClientProfilePath, r), + baseClientPage: s.getBaseClientPageData(util.I18nProfileTitle, webClientProfilePath, w, r), Error: err, } user, userMerged, errUser := dataprovider.GetUserVariants(data.LoggedUser.Username, "") @@ -832,7 +832,7 @@ func (s *httpdServer) renderClientProfilePage(w http.ResponseWriter, r *http.Req func (s *httpdServer) renderClientChangePasswordPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { data := changeClientPasswordPage{ - baseClientPage: s.getBaseClientPageData(util.I18nChangePwdTitle, webChangeClientPwdPath, r), + baseClientPage: s.getBaseClientPageData(util.I18nChangePwdTitle, webChangeClientPwdPath, w, r), Error: err, } @@ -850,8 +850,7 @@ func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http. s.renderClientBadRequestPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } - ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -1440,7 +1439,7 @@ func (s *httpdServer) handleClientAddSharePost(w http.ResponseWriter, r *http.Re return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -1508,7 +1507,7 @@ func (s *httpdServer) handleClientUpdateSharePost(w http.ResponseWriter, r *http return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -1579,7 +1578,7 @@ func (s *httpdServer) handleClientGetShares(w http.ResponseWriter, r *http.Reque r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) data := clientSharesPage{ - baseClientPage: s.getBaseClientPageData(util.I18nSharesTitle, webClientSharesPath, r), + baseClientPage: s.getBaseClientPageData(util.I18nSharesTitle, webClientSharesPath, w, r), BasePublicSharesURL: webClientPubSharesPath, } renderClientTemplate(w, templateClientShares, data) @@ -1603,7 +1602,7 @@ func (s *httpdServer) handleWebClientProfilePost(w http.ResponseWriter, r *http. return } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } @@ -1662,12 +1661,12 @@ func (s *httpdServer) handleWebClientMFA(w http.ResponseWriter, r *http.Request) func (s *httpdServer) handleWebClientTwoFactor(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - s.renderClientTwoFactorPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) + s.renderClientTwoFactorPage(w, r, nil) } func (s *httpdServer) handleWebClientTwoFactorRecovery(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - s.renderClientTwoFactorRecoveryPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) + s.renderClientTwoFactorRecoveryPage(w, r, nil) } func getShareFromPostFields(r *http.Request) (*dataprovider.Share, error) { @@ -1719,26 +1718,25 @@ func (s *httpdServer) handleWebClientForgotPwd(w http.ResponseWriter, r *http.Re s.renderClientNotFoundPage(w, r, errors.New("this page does not exist")) return } - s.renderClientForgotPwdPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) + s.renderClientForgotPwdPage(w, r, nil) } func (s *httpdServer) handleWebClientForgotPwdPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) err := r.ParseForm() if err != nil { - s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) + s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { + if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } username := strings.TrimSpace(r.Form.Get("username")) err = handleForgotPassword(r, username, false) if err != nil { - s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric), ipAddr) + s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric)) return } http.Redirect(w, r, webClientResetPwdPath, http.StatusFound) @@ -1750,7 +1748,7 @@ func (s *httpdServer) handleWebClientPasswordReset(w http.ResponseWriter, r *htt s.renderClientNotFoundPage(w, r, errors.New("this page does not exist")) return } - s.renderClientResetPwdPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) + s.renderClientResetPwdPage(w, r, nil) } func (s *httpdServer) handleClientViewPDF(w http.ResponseWriter, r *http.Request) { @@ -1853,30 +1851,30 @@ func (s *httpdServer) ensurePDF(w http.ResponseWriter, r *http.Request, name str func (s *httpdServer) handleClientShareLoginGet(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) - s.renderShareLoginPage(w, r, nil, util.GetIPFromRemoteAddress(r.RemoteAddr)) + s.renderShareLoginPage(w, r, nil) } func (s *httpdServer) handleClientShareLoginPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := r.ParseForm(); err != nil { - s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm), ipAddr) + s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) return } - if err := verifyCSRFToken(r.Form.Get(csrfFormToken), ipAddr); err != nil { - s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF), ipAddr) + if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { + s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } + invalidateToken(r, true) shareID := getURLParam(r, "id") share, err := dataprovider.ShareExists(shareID, "") if err != nil { - s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials), ipAddr) + s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials)) return } match, err := share.CheckCredentials(strings.TrimSpace(r.Form.Get("share_password"))) if !match || err != nil { - s.renderShareLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials), - ipAddr) + s.renderShareLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) return } c := jwtTokenClaims{ @@ -1884,7 +1882,7 @@ func (s *httpdServer) handleClientShareLoginPost(w http.ResponseWriter, r *http. } err = c.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebShare, ipAddr) if err != nil { - s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nError500Message), ipAddr) + s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nError500Message)) return } next := path.Clean(r.URL.Query().Get("next"))