From 9ad750da54b34519d8a2b58e83e447578ef2c8f8 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sat, 27 Mar 2021 19:10:27 +0100 Subject: [PATCH] WebDAV: try to preserve the lock fs as much as possible --- dataprovider/cacheduser.go | 142 +++++++++++++++++++++++++++++++++++ dataprovider/dataprovider.go | 81 +++++--------------- dataprovider/user.go | 51 ++++++++----- docs/external-auth.md | 6 +- httpd/httpd_test.go | 68 ----------------- kms/kms.go | 32 ++++++++ vfs/azblobfs.go | 8 +- vfs/cryptfs.go | 6 +- vfs/filesystem.go | 21 ++++++ vfs/gcsfs.go | 8 +- vfs/s3fs.go | 7 +- vfs/sftpfs.go | 44 ++++++++++- vfs/vfs.go | 103 +++++++++++++++++++++++++ webdavd/internal_test.go | 71 ++++++++++-------- webdavd/server.go | 9 ++- webdavd/webdavd.go | 2 + webdavd/webdavd_test.go | 6 ++ 17 files changed, 456 insertions(+), 209 deletions(-) create mode 100644 dataprovider/cacheduser.go diff --git a/dataprovider/cacheduser.go b/dataprovider/cacheduser.go new file mode 100644 index 00000000..cddba4de --- /dev/null +++ b/dataprovider/cacheduser.go @@ -0,0 +1,142 @@ +package dataprovider + +import ( + "sync" + "time" + + "golang.org/x/net/webdav" + + "github.com/drakkan/sftpgo/utils" +) + +var ( + webDAVUsersCache *usersCache +) + +func init() { + webDAVUsersCache = &usersCache{ + users: map[string]CachedUser{}, + } +} + +// InitializeWebDAVUserCache initializes the cache for webdav users +func InitializeWebDAVUserCache(maxSize int) { + webDAVUsersCache = &usersCache{ + users: map[string]CachedUser{}, + maxSize: maxSize, + } +} + +// CachedUser adds fields useful for caching to a SFTPGo user +type CachedUser struct { + User User + Expiration time.Time + Password string + LockSystem webdav.LockSystem +} + +// IsExpired returns true if the cached user is expired +func (c *CachedUser) IsExpired() bool { + if c.Expiration.IsZero() { + return false + } + return c.Expiration.Before(time.Now()) +} + +type usersCache struct { + sync.RWMutex + users map[string]CachedUser + maxSize int +} + +func (cache *usersCache) updateLastLogin(username string) { + cache.Lock() + defer cache.Unlock() + + if cachedUser, ok := cache.users[username]; ok { + cachedUser.User.LastLogin = utils.GetTimeAsMsSinceEpoch(time.Now()) + cache.users[username] = cachedUser + } +} + +// swapWebDAVUser updates an existing cached user with the specified one +// preserving the lock fs if possible +func (cache *usersCache) swap(user *User) { + cache.Lock() + defer cache.Unlock() + + if cachedUser, ok := cache.users[user.Username]; ok { + if cachedUser.User.Password != user.Password { + // the password changed, the cached user is no longer valid + delete(cache.users, user.Username) + return + } + if cachedUser.User.isFsEqual(user) { + // the updated user has the same fs as the cached one, we can preserve the lock filesystem + cachedUser.User = *user + cache.users[user.Username] = cachedUser + } else { + // filesystem changed, the cached user is no longer valid + delete(cache.users, user.Username) + } + } +} + +func (cache *usersCache) add(cachedUser *CachedUser) { + cache.Lock() + defer cache.Unlock() + + if cache.maxSize > 0 && len(cache.users) >= cache.maxSize { + var userToRemove string + var expirationTime time.Time + + for k, v := range cache.users { + if userToRemove == "" { + userToRemove = k + expirationTime = v.Expiration + continue + } + expireTime := v.Expiration + if !expireTime.IsZero() && expireTime.Before(expirationTime) { + userToRemove = k + expirationTime = expireTime + } + } + + delete(cache.users, userToRemove) + } + + if cachedUser.User.Username != "" { + cache.users[cachedUser.User.Username] = *cachedUser + } +} + +func (cache *usersCache) remove(username string) { + cache.Lock() + defer cache.Unlock() + + delete(cache.users, username) +} + +func (cache *usersCache) get(username string) (*CachedUser, bool) { + cache.RLock() + defer cache.RUnlock() + + cachedUser, ok := cache.users[username] + return &cachedUser, ok +} + +// CacheWebDAVUser add a user to the WebDAV cache +func CacheWebDAVUser(cachedUser *CachedUser) { + webDAVUsersCache.add(cachedUser) +} + +// GetCachedWebDAVUser returns a previously cached WebDAV user +func GetCachedWebDAVUser(username string) (*CachedUser, bool) { + return webDAVUsersCache.get(username) +} + +// RemoveCachedWebDAVUser removes a cached WebDAV user +func RemoveCachedWebDAVUser(username string) { + webDAVUsersCache.remove(username) +} diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go index 500622b7..09b66204 100644 --- a/dataprovider/dataprovider.go +++ b/dataprovider/dataprovider.go @@ -111,7 +111,6 @@ var ( // ErrInvalidCredentials defines the error to return if the supplied credentials are invalid ErrInvalidCredentials = errors.New("invalid credentials") validTLSUsernames = []string{string(TLSUsernameNone), string(TLSUsernameCN)} - webDAVUsersCache sync.Map config Config provider Provider sqlPlaceholders []string @@ -750,7 +749,7 @@ func UpdateLastLogin(user *User) error { if diff < 0 || diff > lastLoginMinDelay { err := provider.updateLastLogin(user.Username) if err == nil { - updateWebDavCachedUserLastLogin(user.Username) + webDAVUsersCache.updateLastLogin(user.Username) } return err } @@ -841,7 +840,7 @@ func AddUser(user *User) error { func UpdateUser(user *User) error { err := provider.updateUser(user) if err == nil { - RemoveCachedWebDAVUser(user.Username) + webDAVUsersCache.swap(user) executeAction(operationUpdate, user) } return err @@ -2190,6 +2189,9 @@ func executePreLoginHook(username, loginMethod, ip, protocol string) (User, erro err = provider.addUser(&u) } else { err = provider.updateUser(&u) + if err == nil { + webDAVUsersCache.swap(&u) + } } if err != nil { return u, err @@ -2328,6 +2330,15 @@ func getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip, return cmd.Output() } +func updateUserFromExtAuthResponse(user *User, password, pkey string) { + if password != "" { + user.Password = password + } + if pkey != "" && !utils.IsStringPrefixInSlice(pkey, user.PublicKeys) { + user.PublicKeys = append(user.PublicKeys, pkey) + } +} + func doExternalAuth(username, password string, pubKey []byte, keyboardInteractive, ip, protocol string, tlsCert *x509.Certificate) (User, error) { var user User @@ -2358,15 +2369,11 @@ func doExternalAuth(username, password string, pubKey []byte, keyboardInteractiv if err != nil { return user, fmt.Errorf("invalid external auth response: %v", err) } + // an empty username means authentication failure if user.Username == "" { return user, ErrInvalidCredentials } - if password != "" { - user.Password = password - } - if pkey != "" && !utils.IsStringPrefixInSlice(pkey, user.PublicKeys) { - user.PublicKeys = append(user.PublicKeys, pkey) - } + updateUserFromExtAuthResponse(&user, password, pkey) // some users want to map multiple login usernames with a single SFTPGo account // for example an SFTP user logins using "user1" or "user2" and the external auth // returns "user" in both cases, so we use the username returned from @@ -2381,6 +2388,9 @@ func doExternalAuth(username, password string, pubKey []byte, keyboardInteractiv user.LastQuotaUpdate = u.LastQuotaUpdate user.LastLogin = u.LastLogin err = provider.updateUser(&user) + if err == nil { + webDAVUsersCache.swap(&user) + } return user, err } err = provider.addUser(&user) @@ -2485,56 +2495,3 @@ func executeAction(operation string, user *User) { } }() } - -func updateWebDavCachedUserLastLogin(username string) { - result, ok := webDAVUsersCache.Load(username) - if ok { - cachedUser := result.(*CachedUser) - cachedUser.User.LastLogin = utils.GetTimeAsMsSinceEpoch(time.Now()) - webDAVUsersCache.Store(cachedUser.User.Username, cachedUser) - } -} - -// CacheWebDAVUser add a user to the WebDAV cache -func CacheWebDAVUser(cachedUser *CachedUser, maxSize int) { - if maxSize > 0 { - var cacheSize int - var userToRemove string - var expirationTime time.Time - - webDAVUsersCache.Range(func(k, v interface{}) bool { - cacheSize++ - if userToRemove == "" { - userToRemove = k.(string) - expirationTime = v.(*CachedUser).Expiration - return true - } - expireTime := v.(*CachedUser).Expiration - if !expireTime.IsZero() && expireTime.Before(expirationTime) { - userToRemove = k.(string) - expirationTime = expireTime - } - return true - }) - - if cacheSize >= maxSize { - RemoveCachedWebDAVUser(userToRemove) - } - } - - if cachedUser.User.Username != "" { - webDAVUsersCache.Store(cachedUser.User.Username, cachedUser) - } -} - -// GetCachedWebDAVUser returns a previously cached WebDAV user -func GetCachedWebDAVUser(username string) (interface{}, bool) { - return webDAVUsersCache.Load(username) -} - -// RemoveCachedWebDAVUser removes a cached WebDAV user -func RemoveCachedWebDAVUser(username string) { - if username != "" { - webDAVUsersCache.Delete(username) - } -} diff --git a/dataprovider/user.go b/dataprovider/user.go index 122149a8..6c7f3b4c 100644 --- a/dataprovider/user.go +++ b/dataprovider/user.go @@ -13,8 +13,6 @@ import ( "strings" "time" - "golang.org/x/net/webdav" - "github.com/drakkan/sftpgo/kms" "github.com/drakkan/sftpgo/logger" "github.com/drakkan/sftpgo/utils" @@ -75,22 +73,6 @@ var ( errNoMatchingVirtualFolder = errors.New("no matching virtual folder found") ) -// CachedUser adds fields useful for caching to a SFTPGo user -type CachedUser struct { - User User - Expiration time.Time - Password string - LockSystem webdav.LockSystem -} - -// IsExpired returns true if the cached user is expired -func (c *CachedUser) IsExpired() bool { - if c.Expiration.IsZero() { - return false - } - return c.Expiration.Before(time.Now()) -} - // ExtensionsFilter defines filters based on file extensions. // These restrictions do not apply to files listing for performance reasons, so // a denied file cannot be downloaded/overwritten/renamed but will still be @@ -279,6 +261,39 @@ func (u *User) CheckFsRoot(connectionID string) error { return nil } +// isFsEqual returns true if the fs has the same configuration +func (u *User) isFsEqual(other *User) bool { + if u.FsConfig.Provider == vfs.LocalFilesystemProvider && u.GetHomeDir() != other.GetHomeDir() { + return false + } + if !u.FsConfig.IsEqual(&other.FsConfig) { + return false + } + if len(u.VirtualFolders) != len(other.VirtualFolders) { + return false + } + for idx := range u.VirtualFolders { + f := &u.VirtualFolders[idx] + found := false + for idx1 := range other.VirtualFolders { + f1 := &other.VirtualFolders[idx1] + if f.VirtualPath == f1.VirtualPath { + found = true + if f.FsConfig.Provider == vfs.LocalFilesystemProvider && f.MappedPath != f1.MappedPath { + return false + } + if !f.FsConfig.IsEqual(&f1.FsConfig) { + return false + } + } + } + if !found { + return false + } + } + return true +} + // hideConfidentialData hides user confidential data func (u *User) hideConfidentialData() { u.Password = "" diff --git a/docs/external-auth.md b/docs/external-auth.md index 99be473b..f8774dc8 100644 --- a/docs/external-auth.md +++ b/docs/external-auth.md @@ -18,7 +18,7 @@ The program can inspect the SFTPGo user, if it exists, using the `SFTPGO_AUTHD_U The program must write, on its standard output: - a valid SFTPGo user serialized as JSON if the authentication succeeds. The user will be added/updated within the defined data provider -- an empty string, or no response at all, if authentication succeeds and the existing SFTPGo user does not need to be updated +- an empty string, or no response at all, if authentication succeeds and the existing SFTPGo user does not need to be updated. Please note that in versions 2.0.x and earlier an empty response was interpreted as an authentication error - a user with an empty username if the authentication fails If the hook is an HTTP URL then it will be invoked as HTTP POST. The request body will contain a JSON serialized struct with the following fields: @@ -35,9 +35,9 @@ If the hook is an HTTP URL then it will be invoked as HTTP POST. The request bod If authentication succeeds the HTTP response code must be 200 and the response body can be: - a valid SFTPGo user serialized as JSON. The user will be added/updated within the defined data provider -- empty, the existing SFTPGo user does not need to be updated +- empty, the existing SFTPGo user does not need to be updated. Please note that in versions 2.0.x and earlier an empty response was interpreted as an authentication error -If the authentication fails the HTTP response code must be != 200. +If the authentication fails the HTTP response code must be != 200 or the returned SFTPGo user must have an empty username. Actions defined for users added/updated will not be executed in this case and an already logged in user with the same username will not be disconnected. diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index 6e85abbe..11a88f32 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -824,74 +824,6 @@ func TestAddUserInvalidVirtualFolders(t *testing.T) { }) _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) - /*u.VirtualFolders = nil - u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: filepath.Join(os.TempDir(), "mapped_dir", "subdir"), - Name: folderName + "2", - }, - VirtualPath: "/vdir1", - }) - u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: filepath.Join(os.TempDir(), "mapped_dir"), // invalid, contains mapped_dir/subdir - Name: folderName, - }, - VirtualPath: "/vdir2", - }) - _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) - assert.NoError(t, err) - u.VirtualFolders = nil - u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: filepath.Join(os.TempDir(), "mapped_dir"), - Name: folderName, - }, - VirtualPath: "/vdir1", - }) - u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: filepath.Join(os.TempDir(), "mapped_dir", "subdir"), // invalid, contained in mapped_dir - Name: folderName + "3", - }, - VirtualPath: "/vdir2", - }) - _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) - assert.NoError(t, err) - u.VirtualFolders = nil - u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: filepath.Join(os.TempDir(), "mapped_dir1"), - Name: folderName + "1", - }, - VirtualPath: "/vdir1/subdir", - }) - u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: filepath.Join(os.TempDir(), "mapped_dir2"), - Name: folderName + "2", - }, - VirtualPath: "/vdir1/../vdir1", // invalid, overlaps with /vdir1/subdir - }) - _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) - assert.NoError(t, err) - u.VirtualFolders = nil - u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: filepath.Join(os.TempDir(), "mapped_dir1"), - Name: folderName + "1", - }, - VirtualPath: "/vdir1/", - }) - u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: filepath.Join(os.TempDir(), "mapped_dir2"), - Name: folderName + "2", - }, - VirtualPath: "/vdir1/subdir", // invalid, contained inside /vdir1 - }) - _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) - assert.NoError(t, err)*/ u.VirtualFolders = nil u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ diff --git a/kms/kms.go b/kms/kms.go index 399aee5a..93b0405e 100644 --- a/kms/kms.go +++ b/kms/kms.go @@ -197,6 +197,26 @@ func (s *Secret) UnmarshalJSON(data []byte) error { return nil } +// IsEqual returns true if all the secrets fields are equal +func (s *Secret) IsEqual(other *Secret) bool { + if s.GetStatus() != other.GetStatus() { + return false + } + if s.GetPayload() != other.GetPayload() { + return false + } + if s.GetKey() != other.GetKey() { + return false + } + if s.GetAdditionalData() != other.GetAdditionalData() { + return false + } + if s.GetMode() != other.GetMode() { + return false + } + return true +} + // Clone returns a copy of the secret object func (s *Secret) Clone() *Secret { s.RLock() @@ -414,3 +434,15 @@ func (s *Secret) Decrypt() error { return s.provider.Decrypt() } + +// TryDecrypt decrypts a Secret object if encrypted. +// It returns a nil error if the object is not encrypted +func (s *Secret) TryDecrypt() error { + s.Lock() + defer s.Unlock() + + if s.provider.IsEncrypted() { + return s.provider.Decrypt() + } + return nil +} diff --git a/vfs/azblobfs.go b/vfs/azblobfs.go index a1c455c6..f8be3fa7 100644 --- a/vfs/azblobfs.go +++ b/vfs/azblobfs.go @@ -67,11 +67,9 @@ func NewAzBlobFs(connectionID, localTempDir, mountPath string, config AzBlobFsCo if err := fs.config.Validate(); err != nil { return fs, err } - if fs.config.AccountKey.IsEncrypted() { - err := fs.config.AccountKey.Decrypt() - if err != nil { - return fs, err - } + + if err := fs.config.AccountKey.TryDecrypt(); err != nil { + return fs, err } fs.setConfigDefaults() diff --git a/vfs/cryptfs.go b/vfs/cryptfs.go index 79e95c01..fb5bde1b 100644 --- a/vfs/cryptfs.go +++ b/vfs/cryptfs.go @@ -35,10 +35,8 @@ func NewCryptFs(connectionID, rootDir, mountPath string, config CryptFsConfig) ( if err := config.Validate(); err != nil { return nil, err } - if config.Passphrase.IsEncrypted() { - if err := config.Passphrase.Decrypt(); err != nil { - return nil, err - } + if err := config.Passphrase.TryDecrypt(); err != nil { + return nil, err } fs := &CryptFs{ OsFs: &OsFs{ diff --git a/vfs/filesystem.go b/vfs/filesystem.go index 9bbdc670..cda9bcbb 100644 --- a/vfs/filesystem.go +++ b/vfs/filesystem.go @@ -72,6 +72,27 @@ func (f *Filesystem) SetNilSecretsIfEmpty() { } } +// IsEqual returns true if the fs is equal to other +func (f *Filesystem) IsEqual(other *Filesystem) bool { + if f.Provider != other.Provider { + return false + } + switch f.Provider { + case S3FilesystemProvider: + return f.S3Config.isEqual(&other.S3Config) + case GCSFilesystemProvider: + return f.GCSConfig.isEqual(&other.GCSConfig) + case AzureBlobFilesystemProvider: + return f.AzBlobConfig.isEqual(&other.AzBlobConfig) + case CryptedFilesystemProvider: + return f.CryptConfig.isEqual(&other.CryptConfig) + case SFTPFilesystemProvider: + return f.SFTPConfig.isEqual(&other.SFTPConfig) + default: + return true + } +} + // GetACopy returns a copy func (f *Filesystem) GetACopy() Filesystem { f.SetEmptySecretsIfNil() diff --git a/vfs/gcsfs.go b/vfs/gcsfs.go index b1a5781c..c4288546 100644 --- a/vfs/gcsfs.go +++ b/vfs/gcsfs.go @@ -71,11 +71,9 @@ func NewGCSFs(connectionID, localTempDir, mountPath string, config GCSFsConfig) if fs.config.AutomaticCredentials > 0 { fs.svc, err = storage.NewClient(ctx) } else if !fs.config.Credentials.IsEmpty() { - if fs.config.Credentials.IsEncrypted() { - err = fs.config.Credentials.Decrypt() - if err != nil { - return fs, err - } + err = fs.config.Credentials.TryDecrypt() + if err != nil { + return fs, err } fs.svc, err = storage.NewClient(ctx, option.WithCredentialsJSON([]byte(fs.config.Credentials.GetPayload()))) } else { diff --git a/vfs/s3fs.go b/vfs/s3fs.go index 93912f74..53701dad 100644 --- a/vfs/s3fs.go +++ b/vfs/s3fs.go @@ -68,11 +68,8 @@ func NewS3Fs(connectionID, localTempDir, mountPath string, config S3FsConfig) (F } if !fs.config.AccessSecret.IsEmpty() { - if fs.config.AccessSecret.IsEncrypted() { - err := fs.config.AccessSecret.Decrypt() - if err != nil { - return fs, err - } + if err := fs.config.AccessSecret.TryDecrypt(); err != nil { + return fs, err } awsConfig.Credentials = credentials.NewStaticCredentials(fs.config.AccessKey, fs.config.AccessSecret.GetPayload(), "") } diff --git a/vfs/sftpfs.go b/vfs/sftpfs.go index 97d8bd7f..c75a8331 100644 --- a/vfs/sftpfs.go +++ b/vfs/sftpfs.go @@ -44,6 +44,35 @@ type SFTPFsConfig struct { DisableCouncurrentReads bool `json:"disable_concurrent_reads,omitempty"` } +func (c *SFTPFsConfig) isEqual(other *SFTPFsConfig) bool { + if c.Endpoint != other.Endpoint { + return false + } + if c.Username != other.Username { + return false + } + if c.Prefix != other.Prefix { + return false + } + if c.DisableCouncurrentReads != other.DisableCouncurrentReads { + return false + } + if len(c.Fingerprints) != len(other.Fingerprints) { + return false + } + for _, fp := range c.Fingerprints { + if !utils.IsStringInSlice(fp, other.Fingerprints) { + return false + } + } + c.setEmptyCredentialsIfNil() + other.setEmptyCredentialsIfNil() + if !c.Password.IsEqual(other.Password) { + return false + } + return c.PrivateKey.IsEqual(other.PrivateKey) +} + func (c *SFTPFsConfig) setEmptyCredentialsIfNil() { if c.Password == nil { c.Password = kms.NewEmptySecret() @@ -123,13 +152,13 @@ func NewSFTPFs(connectionID, mountPath string, config SFTPFsConfig) (Fs, error) if err := config.Validate(); err != nil { return nil, err } - if !config.Password.IsEmpty() && config.Password.IsEncrypted() { - if err := config.Password.Decrypt(); err != nil { + if !config.Password.IsEmpty() { + if err := config.Password.TryDecrypt(); err != nil { return nil, err } } - if !config.PrivateKey.IsEmpty() && config.PrivateKey.IsEncrypted() { - if err := config.PrivateKey.Decrypt(); err != nil { + if !config.PrivateKey.IsEmpty() { + if err := config.PrivateKey.TryDecrypt(); err != nil { return nil, err } } @@ -339,6 +368,13 @@ func (*SFTPFs) IsNotSupported(err error) bool { // CheckRootPath creates the specified local root directory if it does not exists func (fs *SFTPFs) CheckRootPath(username string, uid int, gid int) bool { + if fs.config.Prefix == "/" { + return true + } + if err := fs.MkdirAll(fs.config.Prefix, uid, gid); err != nil { + fsLog(fs, logger.LevelDebug, "error creating root directory %#v for user %#v: %v", fs.config.Prefix, username, err) + return false + } return true } diff --git a/vfs/vfs.go b/vfs/vfs.go index 2660682f..39b56f6e 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -147,6 +147,40 @@ type S3FsConfig struct { UploadConcurrency int `json:"upload_concurrency,omitempty"` } +func (c *S3FsConfig) isEqual(other *S3FsConfig) bool { + if c.Bucket != other.Bucket { + return false + } + if c.KeyPrefix != other.KeyPrefix { + return false + } + if c.Region != other.Region { + return false + } + if c.AccessKey != other.AccessKey { + return false + } + if c.Endpoint != other.Endpoint { + return false + } + if c.StorageClass != other.StorageClass { + return false + } + if c.UploadPartSize != other.UploadPartSize { + return false + } + if c.UploadConcurrency != other.UploadConcurrency { + return false + } + if c.AccessSecret == nil { + c.AccessSecret = kms.NewEmptySecret() + } + if other.AccessSecret == nil { + other.AccessSecret = kms.NewEmptySecret() + } + return c.AccessSecret.IsEqual(other.AccessSecret) +} + func (c *S3FsConfig) checkCredentials() error { if c.AccessKey == "" && !c.AccessSecret.IsEmpty() { return errors.New("access_key cannot be empty with access_secret not empty") @@ -224,6 +258,28 @@ type GCSFsConfig struct { StorageClass string `json:"storage_class,omitempty"` } +func (c *GCSFsConfig) isEqual(other *GCSFsConfig) bool { + if c.Bucket != other.Bucket { + return false + } + if c.KeyPrefix != other.KeyPrefix { + return false + } + if c.AutomaticCredentials != other.AutomaticCredentials { + return false + } + if c.StorageClass != other.StorageClass { + return false + } + if c.Credentials == nil { + c.Credentials = kms.NewEmptySecret() + } + if other.Credentials == nil { + other.Credentials = kms.NewEmptySecret() + } + return c.Credentials.IsEqual(other.Credentials) +} + // Validate returns an error if the configuration is not valid func (c *GCSFsConfig) Validate(credentialsFilePath string) error { if c.Credentials == nil { @@ -293,6 +349,43 @@ type AzBlobFsConfig struct { AccessTier string `json:"access_tier,omitempty"` } +func (c *AzBlobFsConfig) isEqual(other *AzBlobFsConfig) bool { + if c.Container != other.Container { + return false + } + if c.AccountName != other.AccountName { + return false + } + if c.Endpoint != other.Endpoint { + return false + } + if c.SASURL != other.SASURL { + return false + } + if c.KeyPrefix != other.KeyPrefix { + return false + } + if c.UploadPartSize != other.UploadPartSize { + return false + } + if c.UploadConcurrency != other.UploadConcurrency { + return false + } + if c.UseEmulator != other.UseEmulator { + return false + } + if c.AccessTier != other.AccessTier { + return false + } + if c.AccountKey == nil { + c.AccountKey = kms.NewEmptySecret() + } + if other.AccountKey == nil { + other.AccountKey = kms.NewEmptySecret() + } + return c.AccountKey.IsEqual(other.AccountKey) +} + // EncryptCredentials encrypts access secret if it is in plain text func (c *AzBlobFsConfig) EncryptCredentials(additionalData string) error { if c.AccountKey.IsPlain() { @@ -355,6 +448,16 @@ type CryptFsConfig struct { Passphrase *kms.Secret `json:"passphrase,omitempty"` } +func (c *CryptFsConfig) isEqual(other *CryptFsConfig) bool { + if c.Passphrase == nil { + c.Passphrase = kms.NewEmptySecret() + } + if other.Passphrase == nil { + other.Passphrase = kms.NewEmptySecret() + } + return c.Passphrase.IsEqual(other.Passphrase) +} + // EncryptCredentials encrypts access secret if it is in plain text func (c *CryptFsConfig) EncryptCredentials(additionalData string) error { if c.Passphrase.IsPlain() { diff --git a/webdavd/internal_test.go b/webdavd/internal_test.go index 434b467d..569625c2 100644 --- a/webdavd/internal_test.go +++ b/webdavd/internal_test.go @@ -895,6 +895,7 @@ func TestBasicUsersCache(t *testing.T) { }, }, } + dataprovider.InitializeWebDAVUserCache(c.Cache.Users.MaxSize) server := webDavServer{ config: c, binding: c.Bindings[0], @@ -915,10 +916,8 @@ func TestBasicUsersCache(t *testing.T) { assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) // now the user should be cached - var cachedUser *dataprovider.CachedUser - result, ok := dataprovider.GetCachedWebDAVUser(username) + cachedUser, ok := dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { - cachedUser = result.(*dataprovider.CachedUser) assert.False(t, cachedUser.IsExpired()) assert.True(t, cachedUser.Expiration.After(now.Add(time.Duration(c.Cache.Users.ExpirationTime)*time.Minute))) // authenticate must return the cached user now @@ -935,10 +934,9 @@ func TestBasicUsersCache(t *testing.T) { // force cached user expiration cachedUser.Expiration = now - dataprovider.CacheWebDAVUser(cachedUser, c.Cache.Users.MaxSize) - result, ok = dataprovider.GetCachedWebDAVUser(username) + dataprovider.CacheWebDAVUser(cachedUser) + cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { - cachedUser = result.(*dataprovider.CachedUser) assert.True(t, cachedUser.IsExpired()) } // now authenticate should get the user from the data provider and update the cache @@ -946,12 +944,24 @@ func TestBasicUsersCache(t *testing.T) { assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) - result, ok = dataprovider.GetCachedWebDAVUser(username) + cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { - cachedUser = result.(*dataprovider.CachedUser) assert.False(t, cachedUser.IsExpired()) } - // cache is invalidated after a user modification + // cache is not invalidated after a user modification if the fs does not change + err = dataprovider.UpdateUser(&user) + assert.NoError(t, err) + _, ok = dataprovider.GetCachedWebDAVUser(username) + assert.True(t, ok) + folderName := "testFolder" + user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: filepath.Join(os.TempDir(), "mapped"), + }, + VirtualPath: "/vdir", + }) + err = dataprovider.UpdateUser(&user) assert.NoError(t, err) _, ok = dataprovider.GetCachedWebDAVUser(username) @@ -969,6 +979,9 @@ func TestBasicUsersCache(t *testing.T) { _, ok = dataprovider.GetCachedWebDAVUser(username) assert.False(t, ok) + err = dataprovider.DeleteFolder(folderName) + assert.NoError(t, err) + err = os.RemoveAll(u.GetHomeDir()) assert.NoError(t, err) } @@ -1011,6 +1024,7 @@ func TestCachedUserWithFolders(t *testing.T) { }, }, } + dataprovider.InitializeWebDAVUserCache(c.Cache.Users.MaxSize) server := webDavServer{ config: c, binding: c.Bindings[0], @@ -1031,10 +1045,8 @@ func TestCachedUserWithFolders(t *testing.T) { assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) // now the user should be cached - var cachedUser *dataprovider.CachedUser - result, ok := dataprovider.GetCachedWebDAVUser(username) + cachedUser, ok := dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { - cachedUser = result.(*dataprovider.CachedUser) assert.False(t, cachedUser.IsExpired()) assert.True(t, cachedUser.Expiration.After(now.Add(time.Duration(c.Cache.Users.ExpirationTime)*time.Minute))) // authenticate must return the cached user now @@ -1054,9 +1066,8 @@ func TestCachedUserWithFolders(t *testing.T) { assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) - result, ok = dataprovider.GetCachedWebDAVUser(username) + cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { - cachedUser = result.(*dataprovider.CachedUser) assert.False(t, cachedUser.IsExpired()) } @@ -1067,9 +1078,8 @@ func TestCachedUserWithFolders(t *testing.T) { assert.NoError(t, err) assert.False(t, isCached) assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) - result, ok = dataprovider.GetCachedWebDAVUser(username) + cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { - cachedUser = result.(*dataprovider.CachedUser) assert.False(t, cachedUser.IsExpired()) } @@ -1133,6 +1143,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) { }, }, } + dataprovider.InitializeWebDAVUserCache(c.Cache.Users.MaxSize) server := webDavServer{ config: c, binding: c.Bindings[0], @@ -1240,6 +1251,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) { assert.True(t, ok) // now remove user1 after an update + user1.HomeDir += "_mod" err = dataprovider.UpdateUser(&user1) assert.NoError(t, err) _, ok = dataprovider.GetCachedWebDAVUser(user1.Username) @@ -1283,6 +1295,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) { } func TestUserCacheIsolation(t *testing.T) { + dataprovider.InitializeWebDAVUserCache(10) username := "webdav_internal_cache_test" password := "dav_pwd" u := dataprovider.User{ @@ -1307,31 +1320,27 @@ func TestUserCacheIsolation(t *testing.T) { cachedUser.User.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("test secret") err = cachedUser.User.FsConfig.S3Config.AccessSecret.Encrypt() assert.NoError(t, err) - - dataprovider.CacheWebDAVUser(cachedUser, 10) - result, ok := dataprovider.GetCachedWebDAVUser(username) + dataprovider.CacheWebDAVUser(cachedUser) + cachedUser, ok := dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { - cachedUser := result.(*dataprovider.CachedUser).User - _, err = cachedUser.GetFilesystem("") + _, err = cachedUser.User.GetFilesystem("") assert.NoError(t, err) // the filesystem is now cached } - result, ok = dataprovider.GetCachedWebDAVUser(username) + cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { - cachedUser := result.(*dataprovider.CachedUser).User - assert.True(t, cachedUser.FsConfig.S3Config.AccessSecret.IsEncrypted()) - err = cachedUser.FsConfig.S3Config.AccessSecret.Decrypt() + assert.True(t, cachedUser.User.FsConfig.S3Config.AccessSecret.IsEncrypted()) + err = cachedUser.User.FsConfig.S3Config.AccessSecret.Decrypt() assert.NoError(t, err) - cachedUser.FsConfig.Provider = vfs.S3FilesystemProvider - _, err = cachedUser.GetFilesystem("") + cachedUser.User.FsConfig.Provider = vfs.S3FilesystemProvider + _, err = cachedUser.User.GetFilesystem("") assert.Error(t, err, "we don't have to get the previously cached filesystem!") } - result, ok = dataprovider.GetCachedWebDAVUser(username) + cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) if assert.True(t, ok) { - cachedUser := result.(*dataprovider.CachedUser).User - assert.Equal(t, vfs.LocalFilesystemProvider, cachedUser.FsConfig.Provider) - assert.False(t, cachedUser.FsConfig.S3Config.AccessSecret.IsEncrypted()) + assert.Equal(t, vfs.LocalFilesystemProvider, cachedUser.User.FsConfig.Provider) + assert.False(t, cachedUser.User.FsConfig.S3Config.AccessSecret.IsEncrypted()) } err = dataprovider.DeleteUser(username) diff --git a/webdavd/server.go b/webdavd/server.go index 4e03b513..b28e91b2 100644 --- a/webdavd/server.go +++ b/webdavd/server.go @@ -171,6 +171,8 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { connectionID, err := s.validateUser(&user, r, loginMethod) if err != nil { + // remove the cached user, we have not yet validated its filesystem + dataprovider.RemoveCachedWebDAVUser(user.Username) updateLoginMetrics(&user, ipAddr, loginMethod, err) http.Error(w, err.Error(), http.StatusForbidden) return @@ -246,9 +248,8 @@ func (s *webDavServer) authenticate(r *http.Request, ip string) (dataprovider.Us if !ok { return user, false, nil, loginMethod, err401 } - result, ok := dataprovider.GetCachedWebDAVUser(username) + cachedUser, ok := dataprovider.GetCachedWebDAVUser(username) if ok { - cachedUser := result.(*dataprovider.CachedUser) if cachedUser.IsExpired() { dataprovider.RemoveCachedWebDAVUser(username) } else { @@ -272,7 +273,7 @@ func (s *webDavServer) authenticate(r *http.Request, ip string) (dataprovider.Us return user, false, nil, loginMethod, err } lockSystem := webdav.NewMemLS() - cachedUser := &dataprovider.CachedUser{ + cachedUser = &dataprovider.CachedUser{ User: user, Password: password, LockSystem: lockSystem, @@ -280,7 +281,7 @@ func (s *webDavServer) authenticate(r *http.Request, ip string) (dataprovider.Us if s.config.Cache.Users.ExpirationTime > 0 { cachedUser.Expiration = time.Now().Add(time.Duration(s.config.Cache.Users.ExpirationTime) * time.Minute) } - dataprovider.CacheWebDAVUser(cachedUser, s.config.Cache.Users.MaxSize) + dataprovider.CacheWebDAVUser(cachedUser) return user, false, lockSystem, loginMethod, nil } diff --git a/webdavd/webdavd.go b/webdavd/webdavd.go index e7d5ce1f..53858429 100644 --- a/webdavd/webdavd.go +++ b/webdavd/webdavd.go @@ -8,6 +8,7 @@ import ( "github.com/go-chi/chi/v5/middleware" "github.com/drakkan/sftpgo/common" + "github.com/drakkan/sftpgo/dataprovider" "github.com/drakkan/sftpgo/logger" "github.com/drakkan/sftpgo/utils" ) @@ -178,6 +179,7 @@ func (c *Configuration) Initialize(configDir string) error { certMgr = mgr } compressor := middleware.NewCompressor(5, "text/*") + dataprovider.InitializeWebDAVUserCache(c.Cache.Users.MaxSize) serviceStatus = ServiceStatus{ Bindings: nil, diff --git a/webdavd/webdavd_test.go b/webdavd/webdavd_test.go index 0dc8953a..668eb708 100644 --- a/webdavd/webdavd_test.go +++ b/webdavd/webdavd_test.go @@ -808,11 +808,15 @@ func TestPreLoginHook(t *testing.T) { err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, true), os.ModePerm) assert.NoError(t, err) // update the user to remove it from the cache + user.FsConfig.Provider = vfs.CryptedFilesystemProvider + user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(defaultPassword) user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) client = getWebDavClient(user, true, nil) assert.Error(t, checkBasicFunc(client)) // update the user to remove it from the cache + user.FsConfig.Provider = vfs.LocalFilesystemProvider + user.FsConfig.CryptConfig.Passphrase = nil user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) user.Status = 0 @@ -2037,11 +2041,13 @@ func TestPreLoginHookWithClientCert(t *testing.T) { err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, true), os.ModePerm) assert.NoError(t, err) // update the user to remove it from the cache + user.Password = defaultPassword user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) client = getWebDavClient(user, true, tlsConfig) assert.Error(t, checkBasicFunc(client)) // update the user to remove it from the cache + user.Password = defaultPassword user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) user.Status = 0