Browse Source

httpd: add database based token manager

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 1 year ago
parent
commit
50cae4ee7d

+ 2 - 1
internal/dataprovider/session.go

@@ -28,6 +28,7 @@ const (
 	SessionTypeOIDCToken
 	SessionTypeResetCode
 	SessionTypeOAuth2Auth
+	SessionTypeInvalidToken
 )
 
 // Session defines a shared session persisted in the data provider
@@ -42,7 +43,7 @@ func (s *Session) validate() error {
 	if s.Key == "" {
 		return errors.New("unable to save a session with an empty key")
 	}
-	if s.Type < SessionTypeOIDCAuth || s.Type > SessionTypeOAuth2Auth {
+	if s.Type < SessionTypeOIDCAuth || s.Type > SessionTypeInvalidToken {
 		return fmt.Errorf("invalid session type: %v", s.Type)
 	}
 	return nil

+ 3 - 3
internal/httpd/auth_utils.go

@@ -331,7 +331,7 @@ func isTokenInvalidated(r *http.Request) bool {
 		token := fn(r)
 		if token != "" {
 			isTokenFound = true
-			if _, ok := invalidatedJWTTokens.Load(token); ok {
+			if invalidatedJWTTokens.Get(token) {
 				return true
 			}
 		}
@@ -343,11 +343,11 @@ func isTokenInvalidated(r *http.Request) bool {
 func invalidateToken(r *http.Request) {
 	tokenString := jwtauth.TokenFromHeader(r)
 	if tokenString != "" {
-		invalidatedJWTTokens.Store(tokenString, time.Now().Add(tokenDuration).UTC())
+		invalidatedJWTTokens.Add(tokenString, time.Now().Add(tokenDuration).UTC())
 	}
 	tokenString = jwtauth.TokenFromCookie(r)
 	if tokenString != "" {
-		invalidatedJWTTokens.Store(tokenString, time.Now().Add(tokenDuration).UTC())
+		invalidatedJWTTokens.Add(tokenString, time.Now().Add(tokenDuration).UTC())
 	}
 }
 

+ 3 - 13
internal/httpd/httpd.go

@@ -28,7 +28,6 @@ import (
 	"path/filepath"
 	"runtime"
 	"strings"
-	"sync"
 	"time"
 
 	"github.com/go-chi/chi/v5"
@@ -196,7 +195,7 @@ var (
 	certMgr                        *common.CertManager
 	cleanupTicker                  *time.Ticker
 	cleanupDone                    chan bool
-	invalidatedJWTTokens           sync.Map
+	invalidatedJWTTokens           tokenManager
 	csrfTokenAuth                  *jwtauth.JWTAuth
 	webRootPath                    string
 	webBasePath                    string
@@ -921,6 +920,7 @@ func (c *Conf) Initialize(configDir string, isShared int) error {
 	}
 	logger.Info(logSender, "", "initializing HTTP server with config %+v", c.getRedacted())
 	configurationDir = configDir
+	invalidatedJWTTokens = newTokenManager(isShared)
 	resetCodesMgr = newResetCodeManager(isShared)
 	oidcMgr = newOIDCManager(isShared)
 	oauth2Mgr = newOAuth2Manager(isShared)
@@ -1183,7 +1183,7 @@ func startCleanupTicker(duration time.Duration) {
 				return
 			case <-cleanupTicker.C:
 				counter++
-				cleanupExpiredJWTTokens()
+				invalidatedJWTTokens.Cleanup()
 				resetCodesMgr.Cleanup()
 				if counter%2 == 0 {
 					oidcMgr.cleanup()
@@ -1202,16 +1202,6 @@ func stopCleanupTicker() {
 	}
 }
 
-func cleanupExpiredJWTTokens() {
-	invalidatedJWTTokens.Range(func(key, value any) bool {
-		exp, ok := value.(time.Time)
-		if !ok || exp.Before(time.Now().UTC()) {
-			invalidatedJWTTokens.Delete(key)
-		}
-		return true
-	})
-}
-
 func getSigningKey(signingPassphrase string) []byte {
 	if signingPassphrase != "" {
 		sk := sha256.Sum256([]byte(signingPassphrase))

+ 26 - 1
internal/httpd/internal_test.go

@@ -1991,13 +1991,38 @@ func TestJWTTokenCleanup(t *testing.T) {
 
 	req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", token))
 
-	invalidatedJWTTokens.Store(token, time.Now().Add(-tokenDuration).UTC())
+	invalidatedJWTTokens.Add(token, time.Now().Add(-tokenDuration).UTC())
 	require.True(t, isTokenInvalidated(req))
 	startCleanupTicker(100 * time.Millisecond)
 	assert.Eventually(t, func() bool { return !isTokenInvalidated(req) }, 1*time.Second, 200*time.Millisecond)
 	stopCleanupTicker()
 }
 
+func TestDbTokenManager(t *testing.T) {
+	if !isSharedProviderSupported() {
+		t.Skip("this test it is not available with this provider")
+	}
+	mgr := newTokenManager(1)
+	dbTokenManager := mgr.(*dbTokenManager)
+	testToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsiV2ViQWRtaW4iLCI6OjEiXSwiZXhwIjoxNjk4NjYwMDM4LCJqdGkiOiJja3ZuazVrYjF1aHUzZXRmZmhyZyIsIm5iZiI6MTY5ODY1ODgwOCwicGVybWlzc2lvbnMiOlsiKiJdLCJzdWIiOiIxNjk3ODIwNDM3NTMyIiwidXNlcm5hbWUiOiJhZG1pbiJ9.LXuFFksvnSuzHqHat6r70yR0jEulNRju7m7SaWrOfy8; csrftoken=mP0C7DqjwpAXsptO2gGCaYBkYw3oNMWB"
+	key := dbTokenManager.getKey(testToken)
+	require.Len(t, key, 64)
+	dbTokenManager.Add(testToken, time.Now().Add(-tokenDuration).UTC())
+	isInvalidated := dbTokenManager.Get(testToken)
+	assert.True(t, isInvalidated)
+	dbTokenManager.Cleanup()
+	isInvalidated = dbTokenManager.Get(testToken)
+	assert.False(t, isInvalidated)
+	dbTokenManager.Add(testToken, time.Now().Add(tokenDuration).UTC())
+	isInvalidated = dbTokenManager.Get(testToken)
+	assert.True(t, isInvalidated)
+	dbTokenManager.Cleanup()
+	isInvalidated = dbTokenManager.Get(testToken)
+	assert.True(t, isInvalidated)
+	err := dataprovider.DeleteSharedSession(key)
+	assert.NoError(t, err)
+}
+
 func TestAllowedProxyUnixDomainSocket(t *testing.T) {
 	b := Binding{
 		Address:      filepath.Join(os.TempDir(), "sock"),

+ 0 - 2
internal/httpd/oauth2.go

@@ -105,7 +105,6 @@ func (o *memoryOAuth2Manager) getPendingAuth(state string) (oauth2PendingAuth, e
 }
 
 func (o *memoryOAuth2Manager) cleanup() {
-	logger.Debug(logSender, "", "oauth2 manager cleanup")
 	o.mu.Lock()
 	defer o.mu.Unlock()
 
@@ -165,6 +164,5 @@ func (o *dbOAuth2Manager) decodePendingAuthData(data any) (oauth2PendingAuth, er
 }
 
 func (o *dbOAuth2Manager) cleanup() {
-	logger.Debug(logSender, "", "oauth2 manager cleanup")
 	dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOAuth2Auth, time.Now()) //nolint:errcheck
 }

+ 0 - 2
internal/httpd/oidcmanager.go

@@ -124,7 +124,6 @@ func (o *memoryOIDCManager) updateTokenUsage(token oidcToken) {
 }
 
 func (o *memoryOIDCManager) cleanup() {
-	logger.Debug(logSender, "", "oidc manager cleanup")
 	o.cleanupAuthRequests()
 	o.cleanupTokens()
 }
@@ -238,7 +237,6 @@ func (o *dbOIDCManager) decodeTokenData(data any) (oidcToken, error) {
 }
 
 func (o *dbOIDCManager) cleanup() {
-	logger.Debug(logSender, "", "oidc manager cleanup")
 	dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOIDCAuth, time.Now())  //nolint:errcheck
 	dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOIDCToken, time.Now()) //nolint:errcheck
 }

+ 95 - 0
internal/httpd/token.go

@@ -0,0 +1,95 @@
+// Copyright (C) 2019-2023 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+package httpd
+
+import (
+	"crypto/sha256"
+	"encoding/hex"
+	"sync"
+	"time"
+
+	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
+	"github.com/drakkan/sftpgo/v2/internal/logger"
+	"github.com/drakkan/sftpgo/v2/internal/util"
+)
+
+func newTokenManager(isShared int) tokenManager {
+	if isShared == 1 {
+		logger.Info(logSender, "", "using provider token manager")
+		return &dbTokenManager{}
+	}
+	logger.Info(logSender, "", "using memory token manager")
+	return &memoryTokenManager{}
+}
+
+type tokenManager interface {
+	Add(token string, expiresAt time.Time)
+	Get(token string) bool
+	Cleanup()
+}
+
+type memoryTokenManager struct {
+	invalidatedJWTTokens sync.Map
+}
+
+func (m *memoryTokenManager) Add(token string, expiresAt time.Time) {
+	m.invalidatedJWTTokens.Store(token, expiresAt)
+}
+
+func (m *memoryTokenManager) Get(token string) bool {
+	_, ok := m.invalidatedJWTTokens.Load(token)
+	return ok
+}
+
+func (m *memoryTokenManager) Cleanup() {
+	m.invalidatedJWTTokens.Range(func(key, value any) bool {
+		exp, ok := value.(time.Time)
+		if !ok || exp.Before(time.Now().UTC()) {
+			m.invalidatedJWTTokens.Delete(key)
+		}
+		return true
+	})
+}
+
+type dbTokenManager struct{}
+
+func (m *dbTokenManager) getKey(token string) string {
+	digest := sha256.Sum256([]byte(token))
+	return hex.EncodeToString(digest[:])
+}
+
+func (m *dbTokenManager) Add(token string, expiresAt time.Time) {
+	key := m.getKey(token)
+	data := map[string]string{
+		"jwt": token,
+	}
+	session := dataprovider.Session{
+		Key:       key,
+		Data:      data,
+		Type:      dataprovider.SessionTypeInvalidToken,
+		Timestamp: util.GetTimeAsMsSinceEpoch(expiresAt),
+	}
+	dataprovider.AddSharedSession(session) //nolint:errcheck
+}
+
+func (m *dbTokenManager) Get(token string) bool {
+	key := m.getKey(token)
+	_, err := dataprovider.GetSharedSession(key)
+	return err == nil
+}
+
+func (m *dbTokenManager) Cleanup() {
+	dataprovider.CleanupSharedSessions(dataprovider.SessionTypeInvalidToken, time.Now()) //nolint:errcheck
+}