diff --git a/internal/dataprovider/session.go b/internal/dataprovider/session.go index b9139ad1..e2c35860 100644 --- a/internal/dataprovider/session.go +++ b/internal/dataprovider/session.go @@ -28,6 +28,7 @@ const ( SessionTypeOIDCToken SessionTypeResetCode SessionTypeOAuth2Auth + SessionTypeInvalidToken ) // Session defines a shared session persisted in the data provider @@ -42,7 +43,7 @@ func (s *Session) validate() error { if s.Key == "" { return errors.New("unable to save a session with an empty key") } - if s.Type < SessionTypeOIDCAuth || s.Type > SessionTypeOAuth2Auth { + if s.Type < SessionTypeOIDCAuth || s.Type > SessionTypeInvalidToken { return fmt.Errorf("invalid session type: %v", s.Type) } return nil diff --git a/internal/httpd/auth_utils.go b/internal/httpd/auth_utils.go index 0cf1dcec..53f86c8f 100644 --- a/internal/httpd/auth_utils.go +++ b/internal/httpd/auth_utils.go @@ -331,7 +331,7 @@ func isTokenInvalidated(r *http.Request) bool { token := fn(r) if token != "" { isTokenFound = true - if _, ok := invalidatedJWTTokens.Load(token); ok { + if invalidatedJWTTokens.Get(token) { return true } } @@ -343,11 +343,11 @@ func isTokenInvalidated(r *http.Request) bool { func invalidateToken(r *http.Request) { tokenString := jwtauth.TokenFromHeader(r) if tokenString != "" { - invalidatedJWTTokens.Store(tokenString, time.Now().Add(tokenDuration).UTC()) + invalidatedJWTTokens.Add(tokenString, time.Now().Add(tokenDuration).UTC()) } tokenString = jwtauth.TokenFromCookie(r) if tokenString != "" { - invalidatedJWTTokens.Store(tokenString, time.Now().Add(tokenDuration).UTC()) + invalidatedJWTTokens.Add(tokenString, time.Now().Add(tokenDuration).UTC()) } } diff --git a/internal/httpd/httpd.go b/internal/httpd/httpd.go index 95608bf7..0fb75b5f 100644 --- a/internal/httpd/httpd.go +++ b/internal/httpd/httpd.go @@ -28,7 +28,6 @@ import ( "path/filepath" "runtime" "strings" - "sync" "time" "github.com/go-chi/chi/v5" @@ -196,7 +195,7 @@ var ( certMgr *common.CertManager cleanupTicker *time.Ticker cleanupDone chan bool - invalidatedJWTTokens sync.Map + invalidatedJWTTokens tokenManager csrfTokenAuth *jwtauth.JWTAuth webRootPath string webBasePath string @@ -923,6 +922,7 @@ func (c *Conf) Initialize(configDir string, isShared int) error { } logger.Info(logSender, "", "initializing HTTP server with config %+v", c.getRedacted()) configurationDir = configDir + invalidatedJWTTokens = newTokenManager(isShared) resetCodesMgr = newResetCodeManager(isShared) oidcMgr = newOIDCManager(isShared) oauth2Mgr = newOAuth2Manager(isShared) @@ -1185,7 +1185,7 @@ func startCleanupTicker(duration time.Duration) { return case <-cleanupTicker.C: counter++ - cleanupExpiredJWTTokens() + invalidatedJWTTokens.Cleanup() resetCodesMgr.Cleanup() if counter%2 == 0 { oidcMgr.cleanup() @@ -1204,16 +1204,6 @@ func stopCleanupTicker() { } } -func cleanupExpiredJWTTokens() { - invalidatedJWTTokens.Range(func(key, value any) bool { - exp, ok := value.(time.Time) - if !ok || exp.Before(time.Now().UTC()) { - invalidatedJWTTokens.Delete(key) - } - return true - }) -} - func getSigningKey(signingPassphrase string) []byte { if signingPassphrase != "" { sk := sha256.Sum256([]byte(signingPassphrase)) diff --git a/internal/httpd/internal_test.go b/internal/httpd/internal_test.go index 8835fe8a..ee568079 100644 --- a/internal/httpd/internal_test.go +++ b/internal/httpd/internal_test.go @@ -1991,13 +1991,38 @@ func TestJWTTokenCleanup(t *testing.T) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", token)) - invalidatedJWTTokens.Store(token, time.Now().Add(-tokenDuration).UTC()) + invalidatedJWTTokens.Add(token, time.Now().Add(-tokenDuration).UTC()) require.True(t, isTokenInvalidated(req)) startCleanupTicker(100 * time.Millisecond) assert.Eventually(t, func() bool { return !isTokenInvalidated(req) }, 1*time.Second, 200*time.Millisecond) stopCleanupTicker() } +func TestDbTokenManager(t *testing.T) { + if !isSharedProviderSupported() { + t.Skip("this test it is not available with this provider") + } + mgr := newTokenManager(1) + dbTokenManager := mgr.(*dbTokenManager) + testToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsiV2ViQWRtaW4iLCI6OjEiXSwiZXhwIjoxNjk4NjYwMDM4LCJqdGkiOiJja3ZuazVrYjF1aHUzZXRmZmhyZyIsIm5iZiI6MTY5ODY1ODgwOCwicGVybWlzc2lvbnMiOlsiKiJdLCJzdWIiOiIxNjk3ODIwNDM3NTMyIiwidXNlcm5hbWUiOiJhZG1pbiJ9.LXuFFksvnSuzHqHat6r70yR0jEulNRju7m7SaWrOfy8; csrftoken=mP0C7DqjwpAXsptO2gGCaYBkYw3oNMWB" + key := dbTokenManager.getKey(testToken) + require.Len(t, key, 64) + dbTokenManager.Add(testToken, time.Now().Add(-tokenDuration).UTC()) + isInvalidated := dbTokenManager.Get(testToken) + assert.True(t, isInvalidated) + dbTokenManager.Cleanup() + isInvalidated = dbTokenManager.Get(testToken) + assert.False(t, isInvalidated) + dbTokenManager.Add(testToken, time.Now().Add(tokenDuration).UTC()) + isInvalidated = dbTokenManager.Get(testToken) + assert.True(t, isInvalidated) + dbTokenManager.Cleanup() + isInvalidated = dbTokenManager.Get(testToken) + assert.True(t, isInvalidated) + err := dataprovider.DeleteSharedSession(key) + assert.NoError(t, err) +} + func TestAllowedProxyUnixDomainSocket(t *testing.T) { b := Binding{ Address: filepath.Join(os.TempDir(), "sock"), diff --git a/internal/httpd/oauth2.go b/internal/httpd/oauth2.go index d5d0e85c..f1c54e08 100644 --- a/internal/httpd/oauth2.go +++ b/internal/httpd/oauth2.go @@ -105,7 +105,6 @@ func (o *memoryOAuth2Manager) getPendingAuth(state string) (oauth2PendingAuth, e } func (o *memoryOAuth2Manager) cleanup() { - logger.Debug(logSender, "", "oauth2 manager cleanup") o.mu.Lock() defer o.mu.Unlock() @@ -165,6 +164,5 @@ func (o *dbOAuth2Manager) decodePendingAuthData(data any) (oauth2PendingAuth, er } func (o *dbOAuth2Manager) cleanup() { - logger.Debug(logSender, "", "oauth2 manager cleanup") dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOAuth2Auth, time.Now()) //nolint:errcheck } diff --git a/internal/httpd/oidcmanager.go b/internal/httpd/oidcmanager.go index 710e5f13..6749e8f2 100644 --- a/internal/httpd/oidcmanager.go +++ b/internal/httpd/oidcmanager.go @@ -124,7 +124,6 @@ func (o *memoryOIDCManager) updateTokenUsage(token oidcToken) { } func (o *memoryOIDCManager) cleanup() { - logger.Debug(logSender, "", "oidc manager cleanup") o.cleanupAuthRequests() o.cleanupTokens() } @@ -238,7 +237,6 @@ func (o *dbOIDCManager) decodeTokenData(data any) (oidcToken, error) { } func (o *dbOIDCManager) cleanup() { - logger.Debug(logSender, "", "oidc manager cleanup") dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOIDCAuth, time.Now()) //nolint:errcheck dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOIDCToken, time.Now()) //nolint:errcheck } diff --git a/internal/httpd/token.go b/internal/httpd/token.go new file mode 100644 index 00000000..45f7233d --- /dev/null +++ b/internal/httpd/token.go @@ -0,0 +1,95 @@ +// Copyright (C) 2019-2023 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package httpd + +import ( + "crypto/sha256" + "encoding/hex" + "sync" + "time" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +func newTokenManager(isShared int) tokenManager { + if isShared == 1 { + logger.Info(logSender, "", "using provider token manager") + return &dbTokenManager{} + } + logger.Info(logSender, "", "using memory token manager") + return &memoryTokenManager{} +} + +type tokenManager interface { + Add(token string, expiresAt time.Time) + Get(token string) bool + Cleanup() +} + +type memoryTokenManager struct { + invalidatedJWTTokens sync.Map +} + +func (m *memoryTokenManager) Add(token string, expiresAt time.Time) { + m.invalidatedJWTTokens.Store(token, expiresAt) +} + +func (m *memoryTokenManager) Get(token string) bool { + _, ok := m.invalidatedJWTTokens.Load(token) + return ok +} + +func (m *memoryTokenManager) Cleanup() { + m.invalidatedJWTTokens.Range(func(key, value any) bool { + exp, ok := value.(time.Time) + if !ok || exp.Before(time.Now().UTC()) { + m.invalidatedJWTTokens.Delete(key) + } + return true + }) +} + +type dbTokenManager struct{} + +func (m *dbTokenManager) getKey(token string) string { + digest := sha256.Sum256([]byte(token)) + return hex.EncodeToString(digest[:]) +} + +func (m *dbTokenManager) Add(token string, expiresAt time.Time) { + key := m.getKey(token) + data := map[string]string{ + "jwt": token, + } + session := dataprovider.Session{ + Key: key, + Data: data, + Type: dataprovider.SessionTypeInvalidToken, + Timestamp: util.GetTimeAsMsSinceEpoch(expiresAt), + } + dataprovider.AddSharedSession(session) //nolint:errcheck +} + +func (m *dbTokenManager) Get(token string) bool { + key := m.getKey(token) + _, err := dataprovider.GetSharedSession(key) + return err == nil +} + +func (m *dbTokenManager) Cleanup() { + dataprovider.CleanupSharedSessions(dataprovider.SessionTypeInvalidToken, time.Now()) //nolint:errcheck +}