From 796ea1dde9e2bf12bdffa4ba3ffdeaa4eef83de5 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Thu, 19 May 2022 19:49:51 +0200 Subject: [PATCH] allow to store temporary sessions within the data provider so we can persist password reset codes, OIDC auth sessions and tokens. These features will also work in multi-node setups without sicky sessions now Signed-off-by: Nicola Murino --- README.md | 1 + common/actions.go | 8 +- common/common.go | 4 +- common/common_test.go | 6 +- common/connection.go | 6 +- common/connection_test.go | 2 +- common/dataretention.go | 4 +- common/protocol_test.go | 2 +- common/ratelimiter.go | 2 +- config/config_test.go | 22 ++-- dataprovider/actions.go | 4 +- dataprovider/admin.go | 10 +- dataprovider/bolt.go | 40 ++++-- dataprovider/dataprovider.go | 83 +++++++++---- dataprovider/memory.go | 26 +++- dataprovider/mysql.go | 75 ++++++++++-- dataprovider/pgsql.go | 75 ++++++++++-- dataprovider/session.go | 34 ++++++ dataprovider/sqlcommon.go | 163 +++++++++++++++++++++---- dataprovider/sqlite.go | 76 ++++++++++-- dataprovider/sqlqueries.go | 32 +++++ dataprovider/user.go | 48 ++++---- docs/full-configuration.md | 2 +- ftpd/ftpd_test.go | 2 +- ftpd/server.go | 2 +- go.mod | 60 ++++----- go.sum | 155 +++++++++++------------ httpd/api_mfa.go | 2 +- httpd/api_shares.go | 4 +- httpd/api_utils.go | 22 ++-- httpd/auth_utils.go | 34 +++--- httpd/httpd.go | 13 +- httpd/httpd_test.go | 143 +++++++++++----------- httpd/internal_test.go | 80 ++++++++++-- httpd/middleware.go | 14 +-- httpd/oidc.go | 171 ++++---------------------- httpd/oidc_test.go | 204 +++++++++++++++++++++++++------ httpd/oidcmanager.go | 230 +++++++++++++++++++++++++++++++++++ httpd/resetcode.go | 105 ++++++++++++++-- httpd/server.go | 10 +- httpd/webadmin.go | 10 +- httpd/webclient.go | 10 +- httpdtest/httpdtest.go | 48 ++++---- init/sftpgo.service | 1 + logger/hclog_adapter.go | 14 +-- logger/logger.go | 34 +++--- logger/request_logger.go | 8 +- mfa/mfa.go | 2 +- mfa/totp.go | 2 +- plugin/kms.go | 4 +- plugin/notifier.go | 6 +- service/service.go | 3 +- service/service_portable.go | 2 +- sftpd/internal_test.go | 6 +- sftpd/mocks/middleware.go | 14 +-- sftpd/scp.go | 4 +- sftpd/server.go | 22 ++-- sftpd/sftpd_test.go | 16 +-- sftpd/ssh_cmd.go | 10 +- smtp/smtp.go | 4 +- util/util.go | 11 -- vfs/fileinfo.go | 2 +- vfs/s3fs.go | 2 +- vfs/sftpfs.go | 6 +- vfs/vfs.go | 4 +- webdavd/handler.go | 4 +- webdavd/server.go | 4 +- webdavd/webdavd_test.go | 2 +- 68 files changed, 1501 insertions(+), 730 deletions(-) create mode 100644 dataprovider/session.go create mode 100644 httpd/oidcmanager.go diff --git a/README.md b/README.md index 310f372d..9d6f3f94 100644 --- a/README.md +++ b/README.md @@ -103,6 +103,7 @@ SFTPGo is also available on [AWS Marketplace](https://aws.amazon.com/marketplace +On macOS you can install from the Homebrew [Formula](https://formulae.brew.sh/formula/sftpgo). On FreeBSD you can install from the [SFTPGo port](https://www.freshports.org/ftp/sftpgo). On DragonFlyBSD you can install SFTPGo from [DPorts](https://github.com/DragonFlyBSD/DPorts/tree/master/ftp/sftpgo). diff --git a/common/actions.go b/common/actions.go index a275eaa5..c1aa8976 100644 --- a/common/actions.go +++ b/common/actions.go @@ -66,7 +66,7 @@ func handleUnconfiguredPreAction(operation string) error { func ExecutePreAction(conn *BaseConnection, operation, filePath, virtualPath string, fileSize int64, openFlags int) error { var event *notifier.FsEvent hasNotifiersPlugin := plugin.Handler.HasNotifiers() - hasHook := util.IsStringInSlice(operation, Config.Actions.ExecuteOn) + hasHook := util.Contains(Config.Actions.ExecuteOn, operation) if !hasHook && !hasNotifiersPlugin { return handleUnconfiguredPreAction(operation) } @@ -86,7 +86,7 @@ func ExecuteActionNotification(conn *BaseConnection, operation, filePath, virtua fileSize int64, err error, ) { hasNotifiersPlugin := plugin.Handler.HasNotifiers() - hasHook := util.IsStringInSlice(operation, Config.Actions.ExecuteOn) + hasHook := util.Contains(Config.Actions.ExecuteOn, operation) if !hasHook && !hasNotifiersPlugin { return } @@ -97,7 +97,7 @@ func ExecuteActionNotification(conn *BaseConnection, operation, filePath, virtua } if hasHook { - if util.IsStringInSlice(operation, Config.Actions.ExecuteSync) { + if util.Contains(Config.Actions.ExecuteSync, operation) { actionHandler.Handle(notification) //nolint:errcheck return } @@ -168,7 +168,7 @@ func newActionNotification( type defaultActionHandler struct{} func (h *defaultActionHandler) Handle(event *notifier.FsEvent) error { - if !util.IsStringInSlice(event.Action, Config.Actions.ExecuteOn) { + if !util.Contains(Config.Actions.ExecuteOn, event.Action) { return errUnconfiguredAction } diff --git a/common/common.go b/common/common.go index 5556907b..ff35057d 100644 --- a/common/common.go +++ b/common/common.go @@ -163,7 +163,7 @@ func Initialize(c Configuration, isShared int) error { } } if c.DefenderConfig.Enabled { - if !util.IsStringInSlice(c.DefenderConfig.Driver, supportedDefenderDrivers) { + if !util.Contains(supportedDefenderDrivers, c.DefenderConfig.Driver) { return fmt.Errorf("unsupported defender driver %#v", c.DefenderConfig.Driver) } var defender Defender @@ -635,7 +635,7 @@ func (c *Configuration) checkPostDisconnectHook(remoteAddr, protocol, username, if c.PostDisconnectHook == "" { return } - if !util.IsStringInSlice(protocol, disconnHookProtocols) { + if !util.Contains(disconnHookProtocols, protocol) { return } go c.executePostDisconnectHook(remoteAddr, protocol, username, connID, connectionTime) diff --git a/common/common_test.go b/common/common_test.go index 8d975fc9..efc3f8f9 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -885,8 +885,8 @@ func TestFolderCopy(t *testing.T) { folder.ID = 2 folder.Users = []string{"user3"} require.Len(t, folderCopy.Users, 2) - require.True(t, util.IsStringInSlice("user1", folderCopy.Users)) - require.True(t, util.IsStringInSlice("user2", folderCopy.Users)) + require.True(t, util.Contains(folderCopy.Users, "user1")) + require.True(t, util.Contains(folderCopy.Users, "user2")) require.Equal(t, int64(1), folderCopy.ID) require.Equal(t, folder.Name, folderCopy.Name) require.Equal(t, folder.MappedPath, folderCopy.MappedPath) @@ -902,7 +902,7 @@ func TestFolderCopy(t *testing.T) { folderCopy = folder.GetACopy() folder.FsConfig.CryptConfig.Passphrase = kms.NewEmptySecret() require.Len(t, folderCopy.Users, 1) - require.True(t, util.IsStringInSlice("user3", folderCopy.Users)) + require.True(t, util.Contains(folderCopy.Users, "user3")) require.Equal(t, int64(2), folderCopy.ID) require.Equal(t, folder.Name, folderCopy.Name) require.Equal(t, folder.MappedPath, folderCopy.MappedPath) diff --git a/common/connection.go b/common/connection.go index 4d9d23a4..562ed06e 100644 --- a/common/connection.go +++ b/common/connection.go @@ -44,7 +44,7 @@ type BaseConnection struct { // NewBaseConnection returns a new BaseConnection func NewBaseConnection(id, protocol, localAddr, remoteAddr string, user dataprovider.User) *BaseConnection { connID := id - if util.IsStringInSlice(protocol, supportedProtocols) { + if util.Contains(supportedProtocols, protocol) { connID = fmt.Sprintf("%s_%s", protocol, id) } user.UploadBandwidth, user.DownloadBandwidth = user.GetBandwidthForIP(util.GetIPFromRemoteAddress(remoteAddr), connID) @@ -61,7 +61,7 @@ func NewBaseConnection(id, protocol, localAddr, remoteAddr string, user dataprov } // Log outputs a log entry to the configured logger -func (c *BaseConnection) Log(level logger.LogLevel, format string, v ...interface{}) { +func (c *BaseConnection) Log(level logger.LogLevel, format string, v ...any) { logger.Log(level, c.protocol, c.ID, format, v...) } @@ -98,7 +98,7 @@ func (c *BaseConnection) GetRemoteIP() string { // SetProtocol sets the protocol for this connection func (c *BaseConnection) SetProtocol(protocol string) { c.protocol = protocol - if util.IsStringInSlice(c.protocol, supportedProtocols) { + if util.Contains(supportedProtocols, c.protocol) { c.ID = fmt.Sprintf("%v_%v", c.protocol, c.ID) } } diff --git a/common/connection_test.go b/common/connection_test.go index f14a8c23..74fa3659 100644 --- a/common/connection_test.go +++ b/common/connection_test.go @@ -293,7 +293,7 @@ func TestErrorsMapping(t *testing.T) { err := conn.GetFsError(fs, os.ErrNotExist) if protocol == ProtocolSFTP { assert.ErrorIs(t, err, sftp.ErrSSHFxNoSuchFile) - } else if util.IsStringInSlice(protocol, osErrorsProtocols) { + } else if util.Contains(osErrorsProtocols, protocol) { assert.EqualError(t, err, os.ErrNotExist.Error()) } else { assert.EqualError(t, err, ErrNotExist.Error()) diff --git a/common/dataretention.go b/common/dataretention.go index 9c9f1aa8..ddfa1c5c 100644 --- a/common/dataretention.go +++ b/common/dataretention.go @@ -364,7 +364,7 @@ func (c *RetentionCheck) sendNotifications(elapsed time.Duration, err error) { func (c *RetentionCheck) sendEmailNotification(elapsed time.Duration, errCheck error) error { body := new(bytes.Buffer) - data := make(map[string]interface{}) + data := make(map[string]any) data["Results"] = c.results totalDeletedFiles := 0 totalDeletedSize := int64(0) @@ -399,7 +399,7 @@ func (c *RetentionCheck) sendEmailNotification(elapsed time.Duration, errCheck e } func (c *RetentionCheck) sendHookNotification(elapsed time.Duration, errCheck error) error { - data := make(map[string]interface{}) + data := make(map[string]any) totalDeletedFiles := 0 totalDeletedSize := int64(0) for _, result := range c.results { diff --git a/common/protocol_test.go b/common/protocol_test.go index 5b6930bd..4873f7ea 100644 --- a/common/protocol_test.go +++ b/common/protocol_test.go @@ -125,7 +125,7 @@ func TestMain(m *testing.M) { }() go func() { - if err := httpdConf.Initialize(configDir); err != nil { + if err := httpdConf.Initialize(configDir, 0); err != nil { logger.ErrorToConsole("could not start HTTP server: %v", err) os.Exit(1) } diff --git a/common/ratelimiter.go b/common/ratelimiter.go index d72266e6..1f68b4f0 100644 --- a/common/ratelimiter.go +++ b/common/ratelimiter.go @@ -83,7 +83,7 @@ func (r *RateLimiterConfig) validate() error { } r.Protocols = util.RemoveDuplicates(r.Protocols) for _, protocol := range r.Protocols { - if !util.IsStringInSlice(protocol, rateLimiterProtocolValues) { + if !util.Contains(rateLimiterProtocolValues, protocol) { return fmt.Errorf("invalid protocol %#v", protocol) } } diff --git a/config/config_test.go b/config/config_test.go index 3b962f80..b2253ff7 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -294,7 +294,7 @@ func TestDefenderProviderDriver(t *testing.T) { commonConfig := config.GetCommonConfig() commonConfig.DefenderConfig.Enabled = true commonConfig.DefenderConfig.Driver = common.DefenderDriverProvider - c := make(map[string]interface{}) + c := make(map[string]any) c["common"] = commonConfig c["data_provider"] = providerConf jsonConf, err := json.Marshal(c) @@ -524,8 +524,8 @@ func TestPluginsFromEnv(t *testing.T) { pluginConf := pluginsConf[0] require.Equal(t, "notifier", pluginConf.Type) require.Len(t, pluginConf.NotifierOptions.FsEvents, 2) - require.True(t, util.IsStringInSlice("upload", pluginConf.NotifierOptions.FsEvents)) - require.True(t, util.IsStringInSlice("download", pluginConf.NotifierOptions.FsEvents)) + require.True(t, util.Contains(pluginConf.NotifierOptions.FsEvents, "upload")) + require.True(t, util.Contains(pluginConf.NotifierOptions.FsEvents, "download")) require.Len(t, pluginConf.NotifierOptions.ProviderEvents, 2) require.Equal(t, "add", pluginConf.NotifierOptions.ProviderEvents[0]) require.Equal(t, "update", pluginConf.NotifierOptions.ProviderEvents[1]) @@ -563,8 +563,8 @@ func TestPluginsFromEnv(t *testing.T) { pluginConf = pluginsConf[0] require.Equal(t, "notifier", pluginConf.Type) require.Len(t, pluginConf.NotifierOptions.FsEvents, 2) - require.True(t, util.IsStringInSlice("upload", pluginConf.NotifierOptions.FsEvents)) - require.True(t, util.IsStringInSlice("download", pluginConf.NotifierOptions.FsEvents)) + require.True(t, util.Contains(pluginConf.NotifierOptions.FsEvents, "upload")) + require.True(t, util.Contains(pluginConf.NotifierOptions.FsEvents, "download")) require.Len(t, pluginConf.NotifierOptions.ProviderEvents, 2) require.Equal(t, "add", pluginConf.NotifierOptions.ProviderEvents[0]) require.Equal(t, "update", pluginConf.NotifierOptions.ProviderEvents[1]) @@ -624,8 +624,8 @@ func TestRateLimitersFromEnv(t *testing.T) { require.Equal(t, 2, limiters[0].Type) protocols := limiters[0].Protocols require.Len(t, protocols, 2) - require.True(t, util.IsStringInSlice(common.ProtocolFTP, protocols)) - require.True(t, util.IsStringInSlice(common.ProtocolSSH, protocols)) + require.True(t, util.Contains(protocols, common.ProtocolFTP)) + require.True(t, util.Contains(protocols, common.ProtocolSSH)) require.True(t, limiters[0].GenerateDefenderEvents) require.Equal(t, 50, limiters[0].EntriesSoftLimit) require.Equal(t, 100, limiters[0].EntriesHardLimit) @@ -641,10 +641,10 @@ func TestRateLimitersFromEnv(t *testing.T) { require.Equal(t, 2, limiters[1].Type) protocols = limiters[1].Protocols require.Len(t, protocols, 4) - require.True(t, util.IsStringInSlice(common.ProtocolFTP, protocols)) - require.True(t, util.IsStringInSlice(common.ProtocolSSH, protocols)) - require.True(t, util.IsStringInSlice(common.ProtocolWebDAV, protocols)) - require.True(t, util.IsStringInSlice(common.ProtocolHTTP, protocols)) + require.True(t, util.Contains(protocols, common.ProtocolFTP)) + require.True(t, util.Contains(protocols, common.ProtocolSSH)) + require.True(t, util.Contains(protocols, common.ProtocolWebDAV)) + require.True(t, util.Contains(protocols, common.ProtocolHTTP)) require.False(t, limiters[1].GenerateDefenderEvents) require.Equal(t, 100, limiters[1].EntriesSoftLimit) require.Equal(t, 150, limiters[1].EntriesHardLimit) diff --git a/dataprovider/actions.go b/dataprovider/actions.go index a14ea5dd..04c08d37 100644 --- a/dataprovider/actions.go +++ b/dataprovider/actions.go @@ -49,8 +49,8 @@ func executeAction(operation, executor, ip, objectType, objectName string, objec if config.Actions.Hook == "" { return } - if !util.IsStringInSlice(operation, config.Actions.ExecuteOn) || - !util.IsStringInSlice(objectType, config.Actions.ExecuteFor) { + if !util.Contains(config.Actions.ExecuteOn, operation) || + !util.Contains(config.Actions.ExecuteFor, objectType) { return } diff --git a/dataprovider/admin.go b/dataprovider/admin.go index 4bbc030b..b82aebb9 100644 --- a/dataprovider/admin.go +++ b/dataprovider/admin.go @@ -68,7 +68,7 @@ func (c *AdminTOTPConfig) validate(username string) error { if c.ConfigName == "" { return util.NewValidationError("totp: config name is mandatory") } - if !util.IsStringInSlice(c.ConfigName, mfa.GetAvailableTOTPConfigNames()) { + if !util.Contains(mfa.GetAvailableTOTPConfigNames(), c.ConfigName) { return util.NewValidationError(fmt.Sprintf("totp: config name %#v not found", c.ConfigName)) } if c.Secret.IsEmpty() { @@ -182,11 +182,11 @@ func (a *Admin) validatePermissions() error { if len(a.Permissions) == 0 { return util.NewValidationError("please grant some permissions to this admin") } - if util.IsStringInSlice(PermAdminAny, a.Permissions) { + if util.Contains(a.Permissions, PermAdminAny) { a.Permissions = []string{PermAdminAny} } for _, perm := range a.Permissions { - if !util.IsStringInSlice(perm, validAdminPerms) { + if !util.Contains(validAdminPerms, perm) { return util.NewValidationError(fmt.Sprintf("invalid permission: %#v", perm)) } } @@ -345,10 +345,10 @@ func (a *Admin) SetNilSecretsIfEmpty() { // HasPermission returns true if the admin has the specified permission func (a *Admin) HasPermission(perm string) bool { - if util.IsStringInSlice(PermAdminAny, a.Permissions) { + if util.Contains(a.Permissions, PermAdminAny) { return true } - return util.IsStringInSlice(perm, a.Permissions) + return util.Contains(a.Permissions, perm) } // GetPermissionsAsString returns permission as string diff --git a/dataprovider/bolt.go b/dataprovider/bolt.go index b756a927..9cc48e12 100644 --- a/dataprovider/bolt.go +++ b/dataprovider/bolt.go @@ -20,7 +20,7 @@ import ( ) const ( - boltDatabaseVersion = 18 + boltDatabaseVersion = 19 ) var ( @@ -1916,6 +1916,22 @@ func (p *BoltProvider) getActiveTransfers(from time.Time) ([]ActiveTransfer, err return nil, ErrNotImplemented } +func (p *BoltProvider) addSharedSession(session Session) error { + return ErrNotImplemented +} + +func (p *BoltProvider) deleteSharedSession(key string) error { + return ErrNotImplemented +} + +func (p *BoltProvider) getSharedSession(key string) (Session, error) { + return Session{}, ErrNotImplemented +} + +func (p *BoltProvider) cleanupSharedSessions(sessionType SessionType, before int64) error { + return ErrNotImplemented +} + func (p *BoltProvider) close() error { return p.dbHandle.Close() } @@ -1943,13 +1959,13 @@ func (p *BoltProvider) migrateDatabase() error { providerLog(logger.LevelError, "%v", err) logger.ErrorToConsole("%v", err) return err - case version == 15, version == 16, version == 17: - logger.InfoToConsole(fmt.Sprintf("updating database version: %v -> 18", version)) - providerLog(logger.LevelInfo, "updating database version: %v -> 18", version) + case version == 15, version == 16, version == 17, version == 18: + logger.InfoToConsole(fmt.Sprintf("updating database version: %v -> 19", version)) + providerLog(logger.LevelInfo, "updating database version: %v -> 19", version) if err = importGCSCredentials(); err != nil { return err } - return updateBoltDatabaseVersion(p.dbHandle, 18) + return updateBoltDatabaseVersion(p.dbHandle, 19) default: if version > boltDatabaseVersion { providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version, @@ -1971,7 +1987,7 @@ func (p *BoltProvider) revertDatabase(targetVersion int) error { return errors.New("current version match target version, nothing to do") } switch dbVersion.Version { - case 16, 17, 18: + case 16, 17, 18, 19: return updateBoltDatabaseVersion(p.dbHandle, 15) default: return fmt.Errorf("database version not handled: %v", dbVersion.Version) @@ -2081,7 +2097,7 @@ func (p *BoltProvider) addUserToGroupMapping(username, groupname string, bucket if err != nil { return err } - if !util.IsStringInSlice(username, group.Users) { + if !util.Contains(group.Users, username) { group.Users = append(group.Users, username) buf, err := json.Marshal(group) if err != nil { @@ -2102,7 +2118,7 @@ func (p *BoltProvider) removeUserFromGroupMapping(username, groupname string, bu if err != nil { return err } - if util.IsStringInSlice(username, group.Users) { + if util.Contains(group.Users, username) { var users []string for _, u := range group.Users { if u != username { @@ -2145,10 +2161,10 @@ func (p *BoltProvider) addRelationToFolderMapping(baseFolder *vfs.BaseVirtualFol baseFolder.UsedQuotaSize = oldFolder.UsedQuotaSize baseFolder.Users = oldFolder.Users baseFolder.Groups = oldFolder.Groups - if user != nil && !util.IsStringInSlice(user.Username, baseFolder.Users) { + if user != nil && !util.Contains(baseFolder.Users, user.Username) { baseFolder.Users = append(baseFolder.Users, user.Username) } - if group != nil && !util.IsStringInSlice(group.Name, baseFolder.Groups) { + if group != nil && !util.Contains(baseFolder.Groups, group.Name) { baseFolder.Groups = append(baseFolder.Groups, group.Name) } buf, err := json.Marshal(baseFolder) @@ -2172,7 +2188,7 @@ func (p *BoltProvider) removeRelationFromFolderMapping(folder vfs.VirtualFolder, return err } found := false - if username != "" && util.IsStringInSlice(username, baseFolder.Users) { + if username != "" && util.Contains(baseFolder.Users, username) { found = true var newUserMapping []string for _, u := range baseFolder.Users { @@ -2182,7 +2198,7 @@ func (p *BoltProvider) removeRelationFromFolderMapping(folder vfs.VirtualFolder, } baseFolder.Users = newUserMapping } - if groupname != "" && util.IsStringInSlice(groupname, baseFolder.Groups) { + if groupname != "" && util.Contains(baseFolder.Groups, groupname) { found = true var newGroupMapping []string for _, g := range baseFolder.Groups { diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go index e5233a62..649d93fb 100644 --- a/dataprovider/dataprovider.go +++ b/dataprovider/dataprovider.go @@ -167,6 +167,7 @@ var ( sqlTableGroups = "groups" sqlTableUsersGroupsMapping = "users_groups_mapping" sqlTableGroupsFoldersMapping = "groups_folders_mapping" + sqlTableSharedSessions = "shared_sessions" sqlTableSchemaVersion = "schema_version" argon2Params *argon2id.Params lastLoginMinDelay = 10 * time.Minute @@ -409,7 +410,7 @@ type Config struct { // GetShared returns the provider share mode func (c *Config) GetShared() int { - if !util.IsStringInSlice(c.Driver, sharedProviders) { + if !util.Contains(sharedProviders, c.Driver) { return 0 } return c.IsShared @@ -686,6 +687,10 @@ type Provider interface { removeActiveTransfer(transferID int64, connectionID string) error cleanupActiveTransfers(before time.Time) error getActiveTransfers(from time.Time) ([]ActiveTransfer, error) + addSharedSession(session Session) error + deleteSharedSession(key string) error + getSharedSession(key string) (Session, error) + cleanupSharedSessions(sessionType SessionType, before int64) error checkAvailability() error close() error reloadConfig() error @@ -1138,7 +1143,7 @@ func CheckKeyboardInteractiveAuth(username, authHook string, client ssh.Keyboard // after a successful authentication with an external identity provider. // If a pre-login hook is defined it will be executed so the SFTPGo user // can be created if it does not exist -func GetUserAfterIDPAuth(username, ip, protocol string, oidcTokenFields *map[string]interface{}) (User, error) { +func GetUserAfterIDPAuth(username, ip, protocol string, oidcTokenFields *map[string]any) (User, error) { var user User var err error if config.PreLoginHook != "" { @@ -1643,6 +1648,42 @@ func GetActiveTransfers(from time.Time) ([]ActiveTransfer, error) { return provider.getActiveTransfers(from) } +// AddSharedSession stores a new session within the data provider +func AddSharedSession(session Session) error { + err := provider.addSharedSession(session) + if err != nil { + providerLog(logger.LevelError, "unable to add shared session, key %#v, type: %v, err: %v", + session.Key, session.Type, err) + } + return err +} + +// DeleteSharedSession deletes the session with the specified key +func DeleteSharedSession(key string) error { + err := provider.deleteSharedSession(key) + if err != nil { + providerLog(logger.LevelError, "unable to add shared session, key %#v, err: %v", key, err) + } + return err +} + +// GetSharedSession retrieves the session with the specified key +func GetSharedSession(key string) (Session, error) { + return provider.getSharedSession(key) +} + +// CleanupSharedSessions removes the shared session with the specified type and +// before the specified time +func CleanupSharedSessions(sessionType SessionType, before time.Time) error { + err := provider.cleanupSharedSessions(sessionType, util.GetTimeAsMsSinceEpoch(before)) + if err == nil { + providerLog(logger.LevelDebug, "deleted shared sessions before: %v, type: %v", before, sessionType) + } else { + providerLog(logger.LevelError, "error deleting shared session before %v, type %v: %v", before, sessionType, err) + } + return err +} + // ReloadConfig reloads provider configuration. // Currently only implemented for memory provider, allows to reload the users // from the configured file, if defined @@ -2047,7 +2088,7 @@ func validateUserTOTPConfig(c *UserTOTPConfig, username string) error { if c.ConfigName == "" { return util.NewValidationError("totp: config name is mandatory") } - if !util.IsStringInSlice(c.ConfigName, mfa.GetAvailableTOTPConfigNames()) { + if !util.Contains(mfa.GetAvailableTOTPConfigNames(), c.ConfigName) { return util.NewValidationError(fmt.Sprintf("totp: config name %#v not found", c.ConfigName)) } if c.Secret.IsEmpty() { @@ -2063,7 +2104,7 @@ func validateUserTOTPConfig(c *UserTOTPConfig, username string) error { return util.NewValidationError("totp: specify at least one protocol") } for _, protocol := range c.Protocols { - if !util.IsStringInSlice(protocol, MFAProtocols) { + if !util.Contains(MFAProtocols, protocol) { return util.NewValidationError(fmt.Sprintf("totp: invalid protocol %#v", protocol)) } } @@ -2096,7 +2137,7 @@ func validateUserPermissions(permsToCheck map[string][]string) (map[string][]str return permissions, util.NewValidationError("invalid permissions") } for _, p := range perms { - if !util.IsStringInSlice(p, ValidPerms) { + if !util.Contains(ValidPerms, p) { return permissions, util.NewValidationError(fmt.Sprintf("invalid permission: %#v", p)) } } @@ -2110,7 +2151,7 @@ func validateUserPermissions(permsToCheck map[string][]string) (map[string][]str if dir != cleanedDir && cleanedDir == "/" { return permissions, util.NewValidationError(fmt.Sprintf("cannot set permissions for invalid subdirectory: %#v is an alias for \"/\"", dir)) } - if util.IsStringInSlice(PermAny, perms) { + if util.Contains(perms, PermAny) { permissions[cleanedDir] = []string{PermAny} } else { permissions[cleanedDir] = util.RemoveDuplicates(perms) @@ -2166,7 +2207,7 @@ func validateFiltersPatternExtensions(baseFilters *sdk.BaseUserFilters) error { if !path.IsAbs(cleanedPath) { return util.NewValidationError(fmt.Sprintf("invalid path %#v for file patterns filter", f.Path)) } - if util.IsStringInSlice(cleanedPath, filteredPaths) { + if util.Contains(filteredPaths, cleanedPath) { return util.NewValidationError(fmt.Sprintf("duplicate file patterns filter for path %#v", f.Path)) } if len(f.AllowedPatterns) == 0 && len(f.DeniedPatterns) == 0 { @@ -2296,13 +2337,13 @@ func validateFilterProtocols(filters *sdk.BaseUserFilters) error { return util.NewValidationError("invalid denied_protocols") } for _, p := range filters.DeniedProtocols { - if !util.IsStringInSlice(p, ValidProtocols) { + if !util.Contains(ValidProtocols, p) { return util.NewValidationError(fmt.Sprintf("invalid denied protocol %#v", p)) } } for _, p := range filters.TwoFactorAuthProtocols { - if !util.IsStringInSlice(p, MFAProtocols) { + if !util.Contains(MFAProtocols, p) { return util.NewValidationError(fmt.Sprintf("invalid two factor protocol %#v", p)) } } @@ -2324,7 +2365,7 @@ func validateBaseFilters(filters *sdk.BaseUserFilters) error { return util.NewValidationError("invalid denied_login_methods") } for _, loginMethod := range filters.DeniedLoginMethods { - if !util.IsStringInSlice(loginMethod, ValidLoginMethods) { + if !util.Contains(ValidLoginMethods, loginMethod) { return util.NewValidationError(fmt.Sprintf("invalid login method: %#v", loginMethod)) } } @@ -2332,12 +2373,12 @@ func validateBaseFilters(filters *sdk.BaseUserFilters) error { return err } if filters.TLSUsername != "" { - if !util.IsStringInSlice(string(filters.TLSUsername), validTLSUsernames) { + if !util.Contains(validTLSUsernames, string(filters.TLSUsername)) { return util.NewValidationError(fmt.Sprintf("invalid TLS username: %#v", filters.TLSUsername)) } } for _, opts := range filters.WebClient { - if !util.IsStringInSlice(opts, sdk.WebClientOptions) { + if !util.Contains(sdk.WebClientOptions, opts) { return util.NewValidationError(fmt.Sprintf("invalid web client options %#v", opts)) } } @@ -2481,7 +2522,7 @@ func ValidateUser(user *User) error { if !user.HasExternalAuth() { user.Filters.ExternalAuthCacheTime = 0 } - if user.Filters.TOTPConfig.Enabled && util.IsStringInSlice(sdk.WebClientMFADisabled, user.Filters.WebClient) { + if user.Filters.TOTPConfig.Enabled && util.Contains(user.Filters.WebClient, sdk.WebClientMFADisabled) { return util.NewValidationError("two-factor authentication cannot be disabled for a user with an active configuration") } return nil @@ -2622,7 +2663,7 @@ func checkUserPasscode(user *User, password, protocol string) (string, error) { if user.Filters.TOTPConfig.Enabled { switch protocol { case protocolFTP: - if util.IsStringInSlice(protocol, user.Filters.TOTPConfig.Protocols) { + if util.Contains(user.Filters.TOTPConfig.Protocols, protocol) { // the TOTP passcode has six digits pwdLen := len(password) if pwdLen < 7 { @@ -2810,7 +2851,7 @@ func doBuiltinKeyboardInteractiveAuth(user *User, client ssh.KeyboardInteractive if err != nil { return 0, err } - if !user.Filters.TOTPConfig.Enabled || !util.IsStringInSlice(protocolSSH, user.Filters.TOTPConfig.Protocols) { + if !user.Filters.TOTPConfig.Enabled || !util.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH) { return 1, nil } err = user.Filters.TOTPConfig.Secret.TryDecrypt() @@ -2934,7 +2975,7 @@ func getKeyboardInteractiveAnswers(client ssh.KeyboardInteractiveChallenge, resp } if len(answers) == 1 && response.CheckPwd > 0 { if response.CheckPwd == 2 { - if !user.Filters.TOTPConfig.Enabled || !util.IsStringInSlice(protocolSSH, user.Filters.TOTPConfig.Protocols) { + if !user.Filters.TOTPConfig.Enabled || !util.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH) { providerLog(logger.LevelInfo, "keyboard interactive auth error: unable to check TOTP passcode, TOTP is not enabled for user %#v", user.Username) return answers, errors.New("TOTP not enabled for SSH protocol") @@ -3190,7 +3231,7 @@ func getPreLoginHookResponse(loginMethod, ip, protocol string, userAsJSON []byte return cmd.Output() } -func executePreLoginHook(username, loginMethod, ip, protocol string, oidcTokenFields *map[string]interface{}) (User, error) { +func executePreLoginHook(username, loginMethod, ip, protocol string, oidcTokenFields *map[string]any) (User, error) { u, mergedUser, userAsJSON, err := getUserAndJSONForHook(username, oidcTokenFields) if err != nil { return u, err @@ -3340,7 +3381,7 @@ func getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip, } if strings.HasPrefix(config.ExternalAuthHook, "http") { var result []byte - authRequest := make(map[string]interface{}) + authRequest := make(map[string]any) authRequest["username"] = username authRequest["ip"] = ip authRequest["password"] = password @@ -3552,7 +3593,7 @@ func doPluginAuth(username, password string, pubKey []byte, ip, protocol string, return provider.userExists(user.Username) } -func getUserForHook(username string, oidcTokenFields *map[string]interface{}) (User, User, error) { +func getUserForHook(username string, oidcTokenFields *map[string]any) (User, User, error) { u, err := provider.userExists(username) if err != nil { if _, ok := err.(*util.RecordNotFoundError); !ok { @@ -3575,7 +3616,7 @@ func getUserForHook(username string, oidcTokenFields *map[string]interface{}) (U return u, mergedUser, err } -func getUserAndJSONForHook(username string, oidcTokenFields *map[string]interface{}) (User, User, []byte, error) { +func getUserAndJSONForHook(username string, oidcTokenFields *map[string]any) (User, User, []byte, error) { u, mergedUser, err := getUserForHook(username, oidcTokenFields) if err != nil { return u, mergedUser, nil, err @@ -3689,6 +3730,6 @@ func getConfigPath(name, configDir string) string { return name } -func providerLog(level logger.LogLevel, format string, v ...interface{}) { +func providerLog(level logger.LogLevel, format string, v ...any) { logger.Log(level, logSender, "", format, v...) } diff --git a/dataprovider/memory.go b/dataprovider/memory.go index 80c8373d..3410e296 100644 --- a/dataprovider/memory.go +++ b/dataprovider/memory.go @@ -988,7 +988,7 @@ func (p *MemoryProvider) addUserFromGroupMapping(username, groupname string) err if err != nil { return err } - if !util.IsStringInSlice(username, g.Users) { + if !util.Contains(g.Users, username) { g.Users = append(g.Users, username) p.dbHandle.groups[groupname] = g } @@ -1000,7 +1000,7 @@ func (p *MemoryProvider) removeUserFromGroupMapping(username, groupname string) if err != nil { return err } - if util.IsStringInSlice(username, g.Users) { + if util.Contains(g.Users, username) { var users []string for _, u := range g.Users { if u != username { @@ -1069,7 +1069,7 @@ func (p *MemoryProvider) removeRelationFromFolderMapping(folderName, username, g func (p *MemoryProvider) updateFoldersMappingInternal(folder vfs.BaseVirtualFolder) { p.dbHandle.vfolders[folder.Name] = folder - if !util.IsStringInSlice(folder.Name, p.dbHandle.vfoldersNames) { + if !util.Contains(p.dbHandle.vfoldersNames, folder.Name) { p.dbHandle.vfoldersNames = append(p.dbHandle.vfoldersNames, folder.Name) sort.Strings(p.dbHandle.vfoldersNames) } @@ -1084,10 +1084,10 @@ func (p *MemoryProvider) addOrUpdateFolderInternal(baseFolder *vfs.BaseVirtualFo folder.MappedPath = baseFolder.MappedPath folder.Description = baseFolder.Description folder.FsConfig = baseFolder.FsConfig.GetACopy() - if username != "" && !util.IsStringInSlice(username, folder.Users) { + if username != "" && !util.Contains(folder.Users, username) { folder.Users = append(folder.Users, username) } - if groupname != "" && !util.IsStringInSlice(groupname, folder.Groups) { + if groupname != "" && !util.Contains(folder.Groups, groupname) { folder.Groups = append(folder.Groups, groupname) } p.updateFoldersMappingInternal(folder) @@ -1752,6 +1752,22 @@ func (p *MemoryProvider) getActiveTransfers(from time.Time) ([]ActiveTransfer, e return nil, ErrNotImplemented } +func (p *MemoryProvider) addSharedSession(session Session) error { + return ErrNotImplemented +} + +func (p *MemoryProvider) deleteSharedSession(key string) error { + return ErrNotImplemented +} + +func (p *MemoryProvider) getSharedSession(key string) (Session, error) { + return Session{}, ErrNotImplemented +} + +func (p *MemoryProvider) cleanupSharedSessions(sessionType SessionType, before int64) error { + return ErrNotImplemented +} + func (p *MemoryProvider) getNextID() int64 { nextID := int64(1) for _, v := range p.dbHandle.users { diff --git a/dataprovider/mysql.go b/dataprovider/mysql.go index f9abcadc..5521ddb1 100644 --- a/dataprovider/mysql.go +++ b/dataprovider/mysql.go @@ -35,6 +35,7 @@ const ( "DROP TABLE IF EXISTS `{{defender_events}}` CASCADE;" + "DROP TABLE IF EXISTS `{{defender_hosts}}` CASCADE;" + "DROP TABLE IF EXISTS `{{active_transfers}}` CASCADE;" + + "DROP TABLE IF EXISTS `{{shared_sessions}}` CASCADE;" + "DROP TABLE IF EXISTS `{{schema_version}}` CASCADE;" mysqlInitialSQL = "CREATE TABLE `{{schema_version}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `version` integer NOT NULL);" + "CREATE TABLE `{{admins}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " + @@ -152,6 +153,11 @@ const ( "FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" + "ALTER TABLE `{{folders_mapping}}` ADD CONSTRAINT `{{prefix}}folders_mapping_folder_id_fk_folders_id` " + "FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;" + mysqlV19SQL = "CREATE TABLE `{{shared_sessions}}` (`key` varchar(128) NOT NULL PRIMARY KEY, " + + "`data` longtext NOT NULL, `type` integer NOT NULL, `timestamp` bigint NOT NULL);" + + "CREATE INDEX `{{prefix}}shared_sessions_type_idx` ON `{{shared_sessions}}` (`type`);" + + "CREATE INDEX `{{prefix}}shared_sessions_timestamp_idx` ON `{{shared_sessions}}` (`timestamp`);" + mysqlV19DownSQL = "DROP TABLE `{{shared_sessions}}` CASCADE;" ) // MySQLProvider defines the auth provider for MySQL/MariaDB database @@ -520,6 +526,22 @@ func (p *MySQLProvider) getActiveTransfers(from time.Time) ([]ActiveTransfer, er return sqlCommonGetActiveTransfers(from, p.dbHandle) } +func (p *MySQLProvider) addSharedSession(session Session) error { + return sqlCommonAddSession(session, p.dbHandle) +} + +func (p *MySQLProvider) deleteSharedSession(key string) error { + return sqlCommonDeleteSession(key, p.dbHandle) +} + +func (p *MySQLProvider) getSharedSession(key string) (Session, error) { + return sqlCommonGetSession(key, p.dbHandle) +} + +func (p *MySQLProvider) cleanupSharedSessions(sessionType SessionType, before int64) error { + return sqlCommonCleanupSessions(sessionType, before, p.dbHandle) +} + func (p *MySQLProvider) close() error { return p.dbHandle.Close() } @@ -550,10 +572,10 @@ func (p *MySQLProvider) initializeDatabase() error { initialSQL = strings.ReplaceAll(initialSQL, "{{defender_hosts}}", sqlTableDefenderHosts) initialSQL = strings.ReplaceAll(initialSQL, "{{prefix}}", config.SQLTablesPrefix) - return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(initialSQL, ";"), 15) + return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(initialSQL, ";"), 15, true) } -func (p *MySQLProvider) migrateDatabase() error { +func (p *MySQLProvider) migrateDatabase() error { //nolint:dupl dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true) if err != nil { return err @@ -574,6 +596,8 @@ func (p *MySQLProvider) migrateDatabase() error { return updateMySQLDatabaseFromV16(p.dbHandle) case version == 17: return updateMySQLDatabaseFromV17(p.dbHandle) + case version == 18: + return updateMySQLDatabaseFromV18(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version, @@ -602,6 +626,8 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error { return downgradeMySQLDatabaseFromV17(p.dbHandle) case 18: return downgradeMySQLDatabaseFromV18(p.dbHandle) + case 19: + return downgradeMySQLDatabaseFromV19(p.dbHandle) default: return fmt.Errorf("database version not handled: %v", dbVersion.Version) } @@ -609,7 +635,7 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error { func (p *MySQLProvider) resetDatabase() error { sql := sqlReplaceAll(mysqlResetSQL) - return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(sql, ";"), 0) + return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(sql, ";"), 0, false) } func updateMySQLDatabaseFromV15(dbHandle *sql.DB) error { @@ -627,7 +653,14 @@ func updateMySQLDatabaseFromV16(dbHandle *sql.DB) error { } func updateMySQLDatabaseFromV17(dbHandle *sql.DB) error { - return updateMySQLDatabaseFrom17To18(dbHandle) + if err := updateMySQLDatabaseFrom17To18(dbHandle); err != nil { + return err + } + return updateMySQLDatabaseFromV18(dbHandle) +} + +func updateMySQLDatabaseFromV18(dbHandle *sql.DB) error { + return updateMySQLDatabaseFrom18To19(dbHandle) } func downgradeMySQLDatabaseFromV16(dbHandle *sql.DB) error { @@ -648,13 +681,20 @@ func downgradeMySQLDatabaseFromV18(dbHandle *sql.DB) error { return downgradeMySQLDatabaseFromV17(dbHandle) } +func downgradeMySQLDatabaseFromV19(dbHandle *sql.DB) error { + if err := downgradeMySQLDatabaseFrom19To18(dbHandle); err != nil { + return err + } + return downgradeMySQLDatabaseFromV18(dbHandle) +} + func updateMySQLDatabaseFrom15To16(dbHandle *sql.DB) error { logger.InfoToConsole("updating database version: 15 -> 16") providerLog(logger.LevelInfo, "updating database version: 15 -> 16") sql := strings.ReplaceAll(mysqlV16SQL, "{{users}}", sqlTableUsers) sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers) sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) - return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 16) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 16, true) } func updateMySQLDatabaseFrom16To17(dbHandle *sql.DB) error { @@ -668,7 +708,7 @@ func updateMySQLDatabaseFrom16To17(dbHandle *sql.DB) error { sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping) sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping) sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) - return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 17) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 17, true) } func updateMySQLDatabaseFrom17To18(dbHandle *sql.DB) error { @@ -677,7 +717,15 @@ func updateMySQLDatabaseFrom17To18(dbHandle *sql.DB) error { if err := importGCSCredentials(); err != nil { return err } - return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 18) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 18, true) +} + +func updateMySQLDatabaseFrom18To19(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database version: 18 -> 19") + providerLog(logger.LevelInfo, "updating database version: 18 -> 19") + sql := strings.ReplaceAll(mysqlV19SQL, "{{shared_sessions}}", sqlTableSharedSessions) + sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 19, true) } func downgradeMySQLDatabaseFrom16To15(dbHandle *sql.DB) error { @@ -685,7 +733,7 @@ func downgradeMySQLDatabaseFrom16To15(dbHandle *sql.DB) error { providerLog(logger.LevelInfo, "downgrading database version: 16 -> 15") sql := strings.ReplaceAll(mysqlV16DownSQL, "{{users}}", sqlTableUsers) sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers) - return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 15) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 15, false) } func downgradeMySQLDatabaseFrom17To16(dbHandle *sql.DB) error { @@ -699,11 +747,18 @@ func downgradeMySQLDatabaseFrom17To16(dbHandle *sql.DB) error { sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping) sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping) sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) - return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 16) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 16, false) } func downgradeMySQLDatabaseFrom18To17(dbHandle *sql.DB) error { logger.InfoToConsole("downgrading database version: 18 -> 17") providerLog(logger.LevelInfo, "downgrading database version: 18 -> 17") - return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 17) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 17, false) +} + +func downgradeMySQLDatabaseFrom19To18(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database version: 19 -> 18") + providerLog(logger.LevelInfo, "downgrading database version: 19 -> 18") + sql := strings.ReplaceAll(mysqlV19DownSQL, "{{shared_sessions}}", sqlTableSharedSessions) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 18, false) } diff --git a/dataprovider/pgsql.go b/dataprovider/pgsql.go index 026e5c0c..997933c6 100644 --- a/dataprovider/pgsql.go +++ b/dataprovider/pgsql.go @@ -34,6 +34,7 @@ DROP TABLE IF EXISTS "{{groups}}" CASCADE; DROP TABLE IF EXISTS "{{defender_events}}" CASCADE; DROP TABLE IF EXISTS "{{defender_hosts}}" CASCADE; DROP TABLE IF EXISTS "{{active_transfers}}" CASCADE; +DROP TABLE IF EXISTS "{{shared_sessions}}" CASCADE; DROP TABLE IF EXISTS "{{schema_version}}" CASCADE; ` pgsqlInitial = `CREATE TABLE "{{schema_version}}" ("id" serial NOT NULL PRIMARY KEY, "version" integer NOT NULL); @@ -158,6 +159,11 @@ ALTER TABLE "{{folders_mapping}}" ADD CONSTRAINT "{{prefix}}unique_mapping" UNIQ CREATE INDEX "{{prefix}}folders_mapping_folder_id_idx" ON "{{folders_mapping}}" ("folder_id"); CREATE INDEX "{{prefix}}folders_mapping_user_id_idx" ON "{{folders_mapping}}" ("user_id"); ` + pgsqlV19SQL = `CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL PRIMARY KEY, +"data" text NOT NULL, "type" integer NOT NULL, "timestamp" bigint NOT NULL); +CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type"); +CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp");` + pgsqlV19DownSQL = `DROP TABLE "{{shared_sessions}}" CASCADE;` ) // PGSQLProvider defines the auth provider for PostgreSQL database @@ -489,6 +495,22 @@ func (p *PGSQLProvider) getActiveTransfers(from time.Time) ([]ActiveTransfer, er return sqlCommonGetActiveTransfers(from, p.dbHandle) } +func (p *PGSQLProvider) addSharedSession(session Session) error { + return sqlCommonAddSession(session, p.dbHandle) +} + +func (p *PGSQLProvider) deleteSharedSession(key string) error { + return sqlCommonDeleteSession(key, p.dbHandle) +} + +func (p *PGSQLProvider) getSharedSession(key string) (Session, error) { + return sqlCommonGetSession(key, p.dbHandle) +} + +func (p *PGSQLProvider) cleanupSharedSessions(sessionType SessionType, before int64) error { + return sqlCommonCleanupSessions(sessionType, before, p.dbHandle) +} + func (p *PGSQLProvider) close() error { return p.dbHandle.Close() } @@ -525,10 +547,10 @@ func (p *PGSQLProvider) initializeDatabase() error { initialSQL = strings.ReplaceAll(initialSQL, "DEFERRABLE INITIALLY DEFERRED", "") } - return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{initialSQL}, 15) + return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{initialSQL}, 15, true) } -func (p *PGSQLProvider) migrateDatabase() error { +func (p *PGSQLProvider) migrateDatabase() error { //nolint:dupl dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true) if err != nil { return err @@ -549,6 +571,8 @@ func (p *PGSQLProvider) migrateDatabase() error { return updatePGSQLDatabaseFromV16(p.dbHandle) case version == 17: return updatePGSQLDatabaseFromV17(p.dbHandle) + case version == 18: + return updatePGSQLDatabaseFromV18(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version, @@ -577,6 +601,8 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error { return downgradePGSQLDatabaseFromV17(p.dbHandle) case 18: return downgradePGSQLDatabaseFromV18(p.dbHandle) + case 19: + return downgradePGSQLDatabaseFromV19(p.dbHandle) default: return fmt.Errorf("database version not handled: %v", dbVersion.Version) } @@ -584,7 +610,7 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error { func (p *PGSQLProvider) resetDatabase() error { sql := sqlReplaceAll(pgsqlResetSQL) - return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0) + return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0, false) } func updatePGSQLDatabaseFromV15(dbHandle *sql.DB) error { @@ -602,7 +628,14 @@ func updatePGSQLDatabaseFromV16(dbHandle *sql.DB) error { } func updatePGSQLDatabaseFromV17(dbHandle *sql.DB) error { - return updatePGSQLDatabaseFrom17To18(dbHandle) + if err := updatePGSQLDatabaseFrom17To18(dbHandle); err != nil { + return err + } + return updatePGSQLDatabaseFromV18(dbHandle) +} + +func updatePGSQLDatabaseFromV18(dbHandle *sql.DB) error { + return updatePGSQLDatabaseFrom18To19(dbHandle) } func downgradePGSQLDatabaseFromV16(dbHandle *sql.DB) error { @@ -623,6 +656,13 @@ func downgradePGSQLDatabaseFromV18(dbHandle *sql.DB) error { return downgradePGSQLDatabaseFromV17(dbHandle) } +func downgradePGSQLDatabaseFromV19(dbHandle *sql.DB) error { + if err := downgradePGSQLDatabaseFrom19To18(dbHandle); err != nil { + return err + } + return downgradePGSQLDatabaseFromV18(dbHandle) +} + func updatePGSQLDatabaseFrom15To16(dbHandle *sql.DB) error { logger.InfoToConsole("updating database version: 15 -> 16") providerLog(logger.LevelInfo, "updating database version: 15 -> 16") @@ -645,7 +685,7 @@ func updatePGSQLDatabaseFrom15To16(dbHandle *sql.DB) error { } return sqlCommonUpdateDatabaseVersion(ctx, dbHandle, 16) } - return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16, true) } func updatePGSQLDatabaseFrom16To17(dbHandle *sql.DB) error { @@ -664,7 +704,7 @@ func updatePGSQLDatabaseFrom16To17(dbHandle *sql.DB) error { sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping) sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping) sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) - return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 17) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 17, true) } func updatePGSQLDatabaseFrom17To18(dbHandle *sql.DB) error { @@ -673,7 +713,15 @@ func updatePGSQLDatabaseFrom17To18(dbHandle *sql.DB) error { if err := importGCSCredentials(); err != nil { return err } - return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 18) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 18, true) +} + +func updatePGSQLDatabaseFrom18To19(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database version: 18 -> 19") + providerLog(logger.LevelInfo, "updating database version: 18 -> 19") + sql := strings.ReplaceAll(pgsqlV19SQL, "{{shared_sessions}}", sqlTableSharedSessions) + sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 19, true) } func downgradePGSQLDatabaseFrom16To15(dbHandle *sql.DB) error { @@ -681,7 +729,7 @@ func downgradePGSQLDatabaseFrom16To15(dbHandle *sql.DB) error { providerLog(logger.LevelInfo, "downgrading database version: 16 -> 15") sql := strings.ReplaceAll(pgsqlV16DownSQL, "{{users}}", sqlTableUsers) sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers) - return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 15) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 15, false) } func downgradePGSQLDatabaseFrom17To16(dbHandle *sql.DB) error { @@ -700,11 +748,18 @@ func downgradePGSQLDatabaseFrom17To16(dbHandle *sql.DB) error { sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping) sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping) sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) - return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16, false) } func downgradePGSQLDatabaseFrom18To17(dbHandle *sql.DB) error { logger.InfoToConsole("downgrading database version: 18 -> 17") providerLog(logger.LevelInfo, "downgrading database version: 18 -> 17") - return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 17) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 17, false) +} + +func downgradePGSQLDatabaseFrom19To18(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database version: 19 -> 18") + providerLog(logger.LevelInfo, "downgrading database version: 19 -> 18") + sql := strings.ReplaceAll(pgsqlV19DownSQL, "{{shared_sessions}}", sqlTableSharedSessions) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 18, false) } diff --git a/dataprovider/session.go b/dataprovider/session.go new file mode 100644 index 00000000..05da0984 --- /dev/null +++ b/dataprovider/session.go @@ -0,0 +1,34 @@ +package dataprovider + +import ( + "errors" + "fmt" +) + +// SessionType defines the supported session types +type SessionType int + +// Supported session types +const ( + SessionTypeOIDCAuth SessionType = iota + 1 + SessionTypeOIDCToken + SessionTypeResetCode +) + +// Session defines a shared session persisted in the data provider +type Session struct { + Key string + Data any + Type SessionType + Timestamp int64 +} + +func (s *Session) validate() error { + if s.Key == "" { + return errors.New("unable to save a session with an empty key") + } + if s.Type < SessionTypeOIDCAuth || s.Type > SessionTypeResetCode { + return fmt.Errorf("invalid session type: %v", s.Type) + } + return nil +} diff --git a/dataprovider/sqlcommon.go b/dataprovider/sqlcommon.go index 9c2b270a..865e34fe 100644 --- a/dataprovider/sqlcommon.go +++ b/dataprovider/sqlcommon.go @@ -20,7 +20,7 @@ import ( ) const ( - sqlDatabaseVersion = 18 + sqlDatabaseVersion = 19 defaultSQLQueryTimeout = 10 * time.Second longSQLQueryTimeout = 60 * time.Second ) @@ -37,7 +37,7 @@ type sqlQuerier interface { } type sqlScanner interface { - Scan(dest ...interface{}) error + Scan(dest ...any) error } func sqlReplaceAll(sql string) string { @@ -203,8 +203,11 @@ func sqlCommonDeleteShare(share Share, dbHandle *sql.DB) error { return err } defer stmt.Close() - _, err = stmt.ExecContext(ctx, share.ShareID) - return err + res, err := stmt.ExecContext(ctx, share.ShareID) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) } func sqlCommonGetShares(limit, offset int, order, username string, dbHandle sqlQuerier) ([]Share, error) { @@ -352,8 +355,11 @@ func sqlCommonDeleteAPIKey(apiKey APIKey, dbHandle *sql.DB) error { return err } defer stmt.Close() - _, err = stmt.ExecContext(ctx, apiKey.KeyID) - return err + res, err := stmt.ExecContext(ctx, apiKey.KeyID) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) } func sqlCommonGetAPIKeys(limit, offset int, order string, dbHandle sqlQuerier) ([]APIKey, error) { @@ -532,8 +538,11 @@ func sqlCommonDeleteAdmin(admin Admin, dbHandle *sql.DB) error { return err } defer stmt.Close() - _, err = stmt.ExecContext(ctx, admin.Username) - return err + res, err := stmt.ExecContext(ctx, admin.Username) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) } func sqlCommonGetAdmins(limit, offset int, order string, dbHandle sqlQuerier) ([]Admin, error) { @@ -667,7 +676,7 @@ func sqlCommonGetUsersInGroups(names []string, dbHandle sqlQuerier) ([]string, e } defer stmt.Close() - args := make([]interface{}, 0, len(names)) + args := make([]any, 0, len(names)) for _, name := range names { args = append(args, name) } @@ -705,7 +714,7 @@ func sqlCommonGetGroupsWithNames(names []string, dbHandle sqlQuerier) ([]Group, } defer stmt.Close() - args := make([]interface{}, 0, len(names)) + args := make([]any, 0, len(names)) for _, name := range names { args = append(args, name) } @@ -849,8 +858,11 @@ func sqlCommonDeleteGroup(group Group, dbHandle *sql.DB) error { return err } defer stmt.Close() - _, err = stmt.ExecContext(ctx, group.Name) - return err + res, err := stmt.ExecContext(ctx, group.Name) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) } func sqlCommonGetUserByUsername(username string, dbHandle sqlQuerier) (User, error) { @@ -1206,8 +1218,11 @@ func sqlCommonDeleteUser(user User, dbHandle *sql.DB) error { return err } defer stmt.Close() - _, err = stmt.ExecContext(ctx, user.ID) - return err + res, err := stmt.ExecContext(ctx, user.ID) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) } func sqlCommonDumpUsers(dbHandle sqlQuerier) ([]User, error) { @@ -1389,7 +1404,7 @@ func sqlCommonGetUsersRangeForQuotaCheck(usernames []string, dbHandle sqlQuerier } defer stmt.Close() - queryArgs := make([]interface{}, 0, len(usernames)) + queryArgs := make([]any, 0, len(usernames)) for idx := range usernames { queryArgs = append(queryArgs, usernames[idx]) } @@ -1730,11 +1745,12 @@ func sqlCommonDeleteDefenderHost(ip string, dbHandle sqlQuerier) error { return err } defer stmt.Close() - _, err = stmt.ExecContext(ctx, ip) + res, err := stmt.ExecContext(ctx, ip) if err != nil { providerLog(logger.LevelError, "unable to delete defender host %#v: %v", ip, err) + return err } - return err + return sqlCommonRequireRowAffected(res) } func sqlCommonAddDefenderHostAndEvent(ip string, score int, dbHandle *sql.DB) error { @@ -2160,8 +2176,11 @@ func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) er return err } defer stmt.Close() - _, err = stmt.ExecContext(ctx, folder.ID) - return err + res, err := stmt.ExecContext(ctx, folder.ID) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) } func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) { @@ -2911,6 +2930,86 @@ func sqlCommonGetAPIKeyRelatedIDs(apiKey *APIKey) (sql.NullInt64, sql.NullInt64, return userID, adminID, nil } +func sqlCommonAddSession(session Session, dbHandle *sql.DB) error { + if err := session.validate(); err != nil { + return err + } + data, err := json.Marshal(session.Data) + if err != nil { + return err + } + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getAddSessionQuery() + + stmt, err := dbHandle.PrepareContext(ctx, q) + if err != nil { + providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err) + return err + } + defer stmt.Close() + + _, err = stmt.ExecContext(ctx, session.Key, data, session.Type, session.Timestamp) + return err +} + +func sqlCommonGetSession(key string, dbHandle sqlQuerier) (Session, error) { + var session Session + + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getSessionQuery() + stmt, err := dbHandle.PrepareContext(ctx, q) + if err != nil { + providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err) + return session, err + } + defer stmt.Close() + + var data []byte // type hint, some driver will use string instead of []byte if the type is any + err = stmt.QueryRowContext(ctx, key).Scan(&session.Key, &data, &session.Type, &session.Timestamp) + if err != nil { + return session, err + } + session.Data = data + return session, nil +} + +func sqlCommonDeleteSession(key string, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDeleteSessionQuery() + stmt, err := dbHandle.PrepareContext(ctx, q) + if err != nil { + providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err) + return err + } + defer stmt.Close() + res, err := stmt.ExecContext(ctx, key) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonCleanupSessions(sessionType SessionType, before int64, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getCleanupSessionsQuery() + stmt, err := dbHandle.PrepareContext(ctx, q) + if err != nil { + providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err) + return err + } + defer stmt.Close() + _, err = stmt.ExecContext(ctx, sessionType, before) + return err +} + func sqlCommonGetDatabaseVersion(dbHandle sqlQuerier, showInitWarn bool) (schemaVersion, error) { var result schemaVersion ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) @@ -2931,6 +3030,16 @@ func sqlCommonGetDatabaseVersion(dbHandle sqlQuerier, showInitWarn bool) (schema return result, err } +func sqlCommonRequireRowAffected(res sql.Result) error { + // MariaDB/MySQL returns 0 rows affected for updates that don't change anything + // so we don't check rows affected for updates + affected, err := res.RowsAffected() + if err == nil && affected == 0 { + return util.NewRecordNotFoundError(sql.ErrNoRows.Error()) + } + return nil +} + func sqlCommonUpdateDatabaseVersion(ctx context.Context, dbHandle sqlQuerier, version int) error { q := getUpdateDBVersionQuery() stmt, err := dbHandle.PrepareContext(ctx, q) @@ -2943,8 +3052,8 @@ func sqlCommonUpdateDatabaseVersion(ctx context.Context, dbHandle sqlQuerier, ve return err } -func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, newVersion int) error { - if err := sqlAquireLock(dbHandle); err != nil { +func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, newVersion int, isUp bool) error { + if err := sqlAcquireLock(dbHandle); err != nil { return err } defer sqlReleaseLock(dbHandle) @@ -2954,10 +3063,12 @@ func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, n if newVersion > 0 { currentVersion, err := sqlCommonGetDatabaseVersion(dbHandle, false) - if err == nil && currentVersion.Version >= newVersion { - providerLog(logger.LevelInfo, "current schema version: %v, requested: %v, did you execute simultaneous migrations?", - currentVersion.Version, newVersion) - return nil + if err == nil { + if (isUp && currentVersion.Version >= newVersion) || (!isUp && currentVersion.Version <= newVersion) { + providerLog(logger.LevelInfo, "current schema version: %v, requested: %v, did you execute simultaneous migrations?", + currentVersion.Version, newVersion) + return nil + } } } @@ -2978,7 +3089,7 @@ func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, n }) } -func sqlAquireLock(dbHandle *sql.DB) error { +func sqlAcquireLock(dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) defer cancel() diff --git a/dataprovider/sqlite.go b/dataprovider/sqlite.go index d944cbeb..9c811e38 100644 --- a/dataprovider/sqlite.go +++ b/dataprovider/sqlite.go @@ -36,6 +36,7 @@ DROP TABLE IF EXISTS "{{groups}}"; DROP TABLE IF EXISTS "{{defender_events}}"; DROP TABLE IF EXISTS "{{defender_hosts}}"; DROP TABLE IF EXISTS "{{active_transfers}}"; +DROP TABLE IF EXISTS "{{shared_sessions}}"; DROP TABLE IF EXISTS "{{schema_version}}"; ` sqliteInitialSQL = `CREATE TABLE "{{schema_version}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "version" integer NOT NULL); @@ -149,6 +150,12 @@ ALTER TABLE "new__folders_mapping" RENAME TO "{{folders_mapping}}"; CREATE INDEX "{{prefix}}folders_mapping_folder_id_idx" ON "{{folders_mapping}}" ("folder_id"); CREATE INDEX "{{prefix}}folders_mapping_user_id_idx" ON "{{folders_mapping}}" ("user_id"); ` + sqliteV19SQL = `CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL PRIMARY KEY, "data" text NOT NULL, +"type" integer NOT NULL, "timestamp" bigint NOT NULL); +CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type"); +CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp"); + ` + sqliteV19DownSQL = `DROP TABLE "{{shared_sessions}}";` ) // SQLiteProvider defines the auth provider for SQLite database @@ -466,6 +473,22 @@ func (p *SQLiteProvider) getActiveTransfers(from time.Time) ([]ActiveTransfer, e return sqlCommonGetActiveTransfers(from, p.dbHandle) } +func (p *SQLiteProvider) addSharedSession(session Session) error { + return sqlCommonAddSession(session, p.dbHandle) +} + +func (p *SQLiteProvider) deleteSharedSession(key string) error { + return sqlCommonDeleteSession(key, p.dbHandle) +} + +func (p *SQLiteProvider) getSharedSession(key string) (Session, error) { + return sqlCommonGetSession(key, p.dbHandle) +} + +func (p *SQLiteProvider) cleanupSharedSessions(sessionType SessionType, before int64) error { + return sqlCommonCleanupSessions(sessionType, before, p.dbHandle) +} + func (p *SQLiteProvider) close() error { return p.dbHandle.Close() } @@ -496,10 +519,10 @@ func (p *SQLiteProvider) initializeDatabase() error { initialSQL = strings.ReplaceAll(initialSQL, "{{defender_hosts}}", sqlTableDefenderHosts) initialSQL = strings.ReplaceAll(initialSQL, "{{prefix}}", config.SQLTablesPrefix) - return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{initialSQL}, 15) + return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{initialSQL}, 15, true) } -func (p *SQLiteProvider) migrateDatabase() error { +func (p *SQLiteProvider) migrateDatabase() error { //nolint:dupl dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true) if err != nil { return err @@ -520,6 +543,8 @@ func (p *SQLiteProvider) migrateDatabase() error { return updateSQLiteDatabaseFromV16(p.dbHandle) case version == 17: return updateSQLiteDatabaseFromV17(p.dbHandle) + case version == 18: + return updateSQLiteDatabaseFromV18(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version, @@ -548,6 +573,8 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error { return downgradeSQLiteDatabaseFromV17(p.dbHandle) case 18: return downgradeSQLiteDatabaseFromV18(p.dbHandle) + case 19: + return downgradeSQLiteDatabaseFromV19(p.dbHandle) default: return fmt.Errorf("database version not handled: %v", dbVersion.Version) } @@ -555,7 +582,7 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error { func (p *SQLiteProvider) resetDatabase() error { sql := sqlReplaceAll(sqliteResetSQL) - return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0) + return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0, false) } func updateSQLiteDatabaseFromV15(dbHandle *sql.DB) error { @@ -573,7 +600,14 @@ func updateSQLiteDatabaseFromV16(dbHandle *sql.DB) error { } func updateSQLiteDatabaseFromV17(dbHandle *sql.DB) error { - return updateSQLiteDatabaseFrom17To18(dbHandle) + if err := updateSQLiteDatabaseFrom17To18(dbHandle); err != nil { + return err + } + return updateSQLiteDatabaseFromV18(dbHandle) +} + +func updateSQLiteDatabaseFromV18(dbHandle *sql.DB) error { + return updateSQLiteDatabaseFrom18To19(dbHandle) } func downgradeSQLiteDatabaseFromV16(dbHandle *sql.DB) error { @@ -594,13 +628,20 @@ func downgradeSQLiteDatabaseFromV18(dbHandle *sql.DB) error { return downgradeSQLiteDatabaseFromV17(dbHandle) } +func downgradeSQLiteDatabaseFromV19(dbHandle *sql.DB) error { + if err := downgradeSQLiteDatabaseFrom19To18(dbHandle); err != nil { + return err + } + return downgradeSQLiteDatabaseFromV18(dbHandle) +} + func updateSQLiteDatabaseFrom15To16(dbHandle *sql.DB) error { logger.InfoToConsole("updating database version: 15 -> 16") providerLog(logger.LevelInfo, "updating database version: 15 -> 16") sql := strings.ReplaceAll(sqliteV16SQL, "{{users}}", sqlTableUsers) sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers) sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) - return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16, true) } func updateSQLiteDatabaseFrom16To17(dbHandle *sql.DB) error { @@ -617,7 +658,7 @@ func updateSQLiteDatabaseFrom16To17(dbHandle *sql.DB) error { sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping) sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping) sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) - if err := sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 17); err != nil { + if err := sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 17, true); err != nil { return err } return setPragmaFK(dbHandle, "ON") @@ -629,7 +670,15 @@ func updateSQLiteDatabaseFrom17To18(dbHandle *sql.DB) error { if err := importGCSCredentials(); err != nil { return err } - return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 18) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 18, true) +} + +func updateSQLiteDatabaseFrom18To19(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database version: 18 -> 19") + providerLog(logger.LevelInfo, "updating database version: 18 -> 19") + sql := strings.ReplaceAll(sqliteV19SQL, "{{shared_sessions}}", sqlTableSharedSessions) + sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 19, true) } func downgradeSQLiteDatabaseFrom16To15(dbHandle *sql.DB) error { @@ -637,7 +686,7 @@ func downgradeSQLiteDatabaseFrom16To15(dbHandle *sql.DB) error { providerLog(logger.LevelInfo, "downgrading database version: 16 -> 15") sql := strings.ReplaceAll(sqliteV16DownSQL, "{{users}}", sqlTableUsers) sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers) - return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 15) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 15, false) } func downgradeSQLiteDatabaseFrom17To16(dbHandle *sql.DB) error { @@ -654,7 +703,7 @@ func downgradeSQLiteDatabaseFrom17To16(dbHandle *sql.DB) error { sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping) sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping) sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) - if err := sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16); err != nil { + if err := sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16, false); err != nil { return err } return setPragmaFK(dbHandle, "ON") @@ -663,7 +712,14 @@ func downgradeSQLiteDatabaseFrom17To16(dbHandle *sql.DB) error { func downgradeSQLiteDatabaseFrom18To17(dbHandle *sql.DB) error { logger.InfoToConsole("downgrading database version: 18 -> 17") providerLog(logger.LevelInfo, "downgrading database version: 18 -> 17") - return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 17) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 17, false) +} + +func downgradeSQLiteDatabaseFrom19To18(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database version: 19 -> 18") + providerLog(logger.LevelInfo, "downgrading database version: 19 -> 18") + sql := strings.ReplaceAll(sqliteV19DownSQL, "{{shared_sessions}}", sqlTableSharedSessions) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 18, false) } func setPragmaFK(dbHandle *sql.DB, value string) error { diff --git a/dataprovider/sqlqueries.go b/dataprovider/sqlqueries.go index 007a3509..cb2b3f4d 100644 --- a/dataprovider/sqlqueries.go +++ b/dataprovider/sqlqueries.go @@ -33,6 +33,38 @@ func getSQLPlaceholders() []string { return placeholders } +func getAddSessionQuery() string { + if config.Driver == MySQLDataProviderName { + return fmt.Sprintf("INSERT INTO %s (`key`,`data`,`type`,`timestamp`) VALUES (%s,%s,%s,%s) "+ + "ON DUPLICATE KEY UPDATE `data`=VALUES(`data`), `timestamp`=VALUES(`timestamp`)", + sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) + } + return fmt.Sprintf(`INSERT INTO %s (key,data,type,timestamp) VALUES (%s,%s,%s,%s) ON CONFLICT(key) DO UPDATE SET data= + EXCLUDED.data, timestamp=EXCLUDED.timestamp`, + sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) +} + +func getDeleteSessionQuery() string { + if config.Driver == MySQLDataProviderName { + return fmt.Sprintf("DELETE FROM %s WHERE `key` = %s", sqlTableSharedSessions, sqlPlaceholders[0]) + } + return fmt.Sprintf(`DELETE FROM %s WHERE key = %s`, sqlTableSharedSessions, sqlPlaceholders[0]) +} + +func getSessionQuery() string { + if config.Driver == MySQLDataProviderName { + return fmt.Sprintf("SELECT `key`,`data`,`type`,`timestamp` FROM %s WHERE `key` = %s", sqlTableSharedSessions, + sqlPlaceholders[0]) + } + return fmt.Sprintf(`SELECT key,data,type,timestamp FROM %s WHERE key = %s`, sqlTableSharedSessions, + sqlPlaceholders[0]) +} + +func getCleanupSessionsQuery() string { + return fmt.Sprintf(`DELETE from %s WHERE type = %s AND timestamp < %s`, + sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1]) +} + func getAddDefenderHostQuery() string { if config.Driver == MySQLDataProviderName { return fmt.Sprintf("INSERT INTO %v (`ip`,`updated_at`,`ban_time`) VALUES (%v,%v,0) ON DUPLICATE KEY UPDATE `updated_at`=VALUES(`updated_at`)", diff --git a/dataprovider/user.go b/dataprovider/user.go index b3bbb6c7..88b55c4f 100644 --- a/dataprovider/user.go +++ b/dataprovider/user.go @@ -756,20 +756,20 @@ func (u *User) HasPermissionsInside(virtualPath string) bool { // HasPerm returns true if the user has the given permission or any permission func (u *User) HasPerm(permission, path string) bool { perms := u.GetPermissionsForPath(path) - if util.IsStringInSlice(PermAny, perms) { + if util.Contains(perms, PermAny) { return true } - return util.IsStringInSlice(permission, perms) + return util.Contains(perms, permission) } // HasAnyPerm returns true if the user has at least one of the given permissions func (u *User) HasAnyPerm(permissions []string, path string) bool { perms := u.GetPermissionsForPath(path) - if util.IsStringInSlice(PermAny, perms) { + if util.Contains(perms, PermAny) { return true } for _, permission := range permissions { - if util.IsStringInSlice(permission, perms) { + if util.Contains(perms, permission) { return true } } @@ -779,11 +779,11 @@ func (u *User) HasAnyPerm(permissions []string, path string) bool { // HasPerms returns true if the user has all the given permissions func (u *User) HasPerms(permissions []string, path string) bool { perms := u.GetPermissionsForPath(path) - if util.IsStringInSlice(PermAny, perms) { + if util.Contains(perms, PermAny) { return true } for _, permission := range permissions { - if !util.IsStringInSlice(permission, perms) { + if !util.Contains(perms, permission) { return false } } @@ -850,11 +850,11 @@ func (u *User) IsLoginMethodAllowed(loginMethod, protocol string, partialSuccess } } } - if util.IsStringInSlice(loginMethod, u.Filters.DeniedLoginMethods) { + if util.Contains(u.Filters.DeniedLoginMethods, loginMethod) { return false } if protocol == protocolSSH && loginMethod == LoginMethodPassword { - if util.IsStringInSlice(SSHLoginMethodPassword, u.Filters.DeniedLoginMethods) { + if util.Contains(u.Filters.DeniedLoginMethods, SSHLoginMethodPassword) { return false } } @@ -896,7 +896,7 @@ func (u *User) IsPartialAuth(loginMethod string) bool { method == SSHLoginMethodPassword { continue } - if !util.IsStringInSlice(method, SSHMultiStepsLoginMethods) { + if !util.Contains(SSHMultiStepsLoginMethods, method) { return false } } @@ -910,7 +910,7 @@ func (u *User) GetAllowedLoginMethods() []string { if method == SSHLoginMethodPassword { continue } - if !util.IsStringInSlice(method, u.Filters.DeniedLoginMethods) { + if !util.Contains(u.Filters.DeniedLoginMethods, method) { allowedMethods = append(allowedMethods, method) } } @@ -968,7 +968,7 @@ func (u *User) IsFileAllowed(virtualPath string) (bool, int) { // CanManageMFA returns true if the user can add a multi-factor authentication configuration func (u *User) CanManageMFA() bool { - if util.IsStringInSlice(sdk.WebClientMFADisabled, u.Filters.WebClient) { + if util.Contains(u.Filters.WebClient, sdk.WebClientMFADisabled) { return false } return len(mfa.GetAvailableTOTPConfigs()) > 0 @@ -987,39 +987,39 @@ func (u *User) isExternalAuthCached() bool { // CanManageShares returns true if the user can add, update and list shares func (u *User) CanManageShares() bool { - return !util.IsStringInSlice(sdk.WebClientSharesDisabled, u.Filters.WebClient) + return !util.Contains(u.Filters.WebClient, sdk.WebClientSharesDisabled) } // CanResetPassword returns true if this user is allowed to reset its password func (u *User) CanResetPassword() bool { - return !util.IsStringInSlice(sdk.WebClientPasswordResetDisabled, u.Filters.WebClient) + return !util.Contains(u.Filters.WebClient, sdk.WebClientPasswordResetDisabled) } // CanChangePassword returns true if this user is allowed to change its password func (u *User) CanChangePassword() bool { - return !util.IsStringInSlice(sdk.WebClientPasswordChangeDisabled, u.Filters.WebClient) + return !util.Contains(u.Filters.WebClient, sdk.WebClientPasswordChangeDisabled) } // CanChangeAPIKeyAuth returns true if this user is allowed to enable/disable API key authentication func (u *User) CanChangeAPIKeyAuth() bool { - return !util.IsStringInSlice(sdk.WebClientAPIKeyAuthChangeDisabled, u.Filters.WebClient) + return !util.Contains(u.Filters.WebClient, sdk.WebClientAPIKeyAuthChangeDisabled) } // CanChangeInfo returns true if this user is allowed to change its info such as email and description func (u *User) CanChangeInfo() bool { - return !util.IsStringInSlice(sdk.WebClientInfoChangeDisabled, u.Filters.WebClient) + return !util.Contains(u.Filters.WebClient, sdk.WebClientInfoChangeDisabled) } // CanManagePublicKeys returns true if this user is allowed to manage public keys // from the web client. Used in web client UI func (u *User) CanManagePublicKeys() bool { - return !util.IsStringInSlice(sdk.WebClientPubKeyChangeDisabled, u.Filters.WebClient) + return !util.Contains(u.Filters.WebClient, sdk.WebClientPubKeyChangeDisabled) } // CanAddFilesFromWeb returns true if the client can add files from the web UI. // The specified target is the directory where the files must be uploaded func (u *User) CanAddFilesFromWeb(target string) bool { - if util.IsStringInSlice(sdk.WebClientWriteDisabled, u.Filters.WebClient) { + if util.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) { return false } return u.HasPerm(PermUpload, target) || u.HasPerm(PermOverwrite, target) @@ -1028,7 +1028,7 @@ func (u *User) CanAddFilesFromWeb(target string) bool { // CanAddDirsFromWeb returns true if the client can add directories from the web UI. // The specified target is the directory where the new directory must be created func (u *User) CanAddDirsFromWeb(target string) bool { - if util.IsStringInSlice(sdk.WebClientWriteDisabled, u.Filters.WebClient) { + if util.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) { return false } return u.HasPerm(PermCreateDirs, target) @@ -1037,7 +1037,7 @@ func (u *User) CanAddDirsFromWeb(target string) bool { // CanRenameFromWeb returns true if the client can rename objects from the web UI. // The specified src and dest are the source and target directories for the rename. func (u *User) CanRenameFromWeb(src, dest string) bool { - if util.IsStringInSlice(sdk.WebClientWriteDisabled, u.Filters.WebClient) { + if util.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) { return false } return u.HasAnyPerm(permsRenameAny, src) && u.HasAnyPerm(permsRenameAny, dest) @@ -1046,7 +1046,7 @@ func (u *User) CanRenameFromWeb(src, dest string) bool { // CanDeleteFromWeb returns true if the client can delete objects from the web UI. // The specified target is the parent directory for the object to delete func (u *User) CanDeleteFromWeb(target string) bool { - if util.IsStringInSlice(sdk.WebClientWriteDisabled, u.Filters.WebClient) { + if util.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) { return false } return u.HasAnyPerm(permsDeleteAny, target) @@ -1059,7 +1059,7 @@ func (u *User) MustSetSecondFactor() bool { return true } for _, p := range u.Filters.TwoFactorAuthProtocols { - if !util.IsStringInSlice(p, u.Filters.TOTPConfig.Protocols) { + if !util.Contains(u.Filters.TOTPConfig.Protocols, p) { return true } } @@ -1070,11 +1070,11 @@ func (u *User) MustSetSecondFactor() bool { // MustSetSecondFactorForProtocol returns true if the user must set a second factor authentication // for the specified protocol func (u *User) MustSetSecondFactorForProtocol(protocol string) bool { - if util.IsStringInSlice(protocol, u.Filters.TwoFactorAuthProtocols) { + if util.Contains(u.Filters.TwoFactorAuthProtocols, protocol) { if !u.Filters.TOTPConfig.Enabled { return true } - if !util.IsStringInSlice(protocol, u.Filters.TOTPConfig.Protocols) { + if !util.Contains(u.Filters.TOTPConfig.Protocols, protocol) { return true } } diff --git a/docs/full-configuration.md b/docs/full-configuration.md index 991eff15..5e14b5a5 100644 --- a/docs/full-configuration.md +++ b/docs/full-configuration.md @@ -226,7 +226,7 @@ The configuration file contains the following sections: - `update_mode`, integer. Defines how the database will be initialized/updated. 0 means automatically. 1 means manually using the initprovider sub-command. - `create_default_admin`, boolean. Before you can use SFTPGo you need to create an admin account. If you open the admin web UI, a setup screen will guide you in creating the first admin account. You can automatically create the first admin account by enabling this setting and setting the environment variables `SFTPGO_DEFAULT_ADMIN_USERNAME` and `SFTPGO_DEFAULT_ADMIN_PASSWORD`. You can also create the first admin by loading initial data. This setting has no effect if an admin account is already found within the data provider. Default `false`. - `naming_rules`, integer. Naming rules for usernames and folder names. `0` means no rules. `1` means you can use any UTF-8 character. The names are used in URIs for REST API and Web admin. If not set only unreserved URI characters are allowed: ALPHA / DIGIT / "-" / "." / "_" / "~". `2` means names are converted to lowercase before saving/matching and so case insensitive matching is possible. `3` means trimming trailing and leading white spaces before saving/matching. Rules can be combined, for example `3` means both converting to lowercase and allowing any UTF-8 character. Enabling these options for existing installations could be backward incompatible, some users could be unable to login, for example existing users with mixed cases in their usernames. You have to ensure that all existing users respect the defined rules. Default: `0`. - - `is_shared`, integer. If the data provider is shared across multiple SFTPGo instances, set this parameter to `1`. `MySQL`, `PostgreSQL` and `CockroachDB` can be shared, this setting is ignored for other data providers. For shared data providers, active transfers are persisted in the database and thus quota checks between ongoing transfers will work cross multiple instances. Default: `0`. + - `is_shared`, integer. If the data provider is shared across multiple SFTPGo instances, set this parameter to `1`. `MySQL`, `PostgreSQL` and `CockroachDB` can be shared, this setting is ignored for other data providers. For shared data providers, active transfers are persisted in the database and thus quota checks between ongoing transfers will work cross multiple instances. Password reset requests and OIDC tokens/states are also persisted in the database if the provider is shared. The database table `shared_sessions` is used only to store temporary sessions. In performance critical installations, you might consider using a database-specific optimization, for example you might use an `UNLOGGED` table for PostgreSQL. This optimization in only required in very limited use cases. Default: `0`. - `backups_path`, string. Path to the backup directory. This can be an absolute path or a path relative to the config dir. We don't allow backups in arbitrary paths for security reasons. - `auto_backup`, struct. Defines the configuration for automatic data provider backups. Example: hour `0` and day_of_week `*` means a backup every day at midnight. The backup file name is in the format `backup__.json`, files with the same name will be overwritten. Note, this process will only backup provider data (users, folders, shares, admins, api keys) and will not backup the configuration file and users files. - `enabled`, boolean. Set to `true` to enable automatic backups. Default: `true`. diff --git a/ftpd/ftpd_test.go b/ftpd/ftpd_test.go index 35b4542c..554fa8e4 100644 --- a/ftpd/ftpd_test.go +++ b/ftpd/ftpd_test.go @@ -377,7 +377,7 @@ func TestMain(m *testing.M) { }() go func() { - if err := httpdConf.Initialize(configDir); err != nil { + if err := httpdConf.Initialize(configDir, 0); err != nil { logger.ErrorToConsole("could not start HTTP server: %v", err) os.Exit(1) } diff --git a/ftpd/server.go b/ftpd/server.go index ca4aa4b7..2cff4bad 100644 --- a/ftpd/server.go +++ b/ftpd/server.go @@ -326,7 +326,7 @@ func (s *Server) validateUser(user dataprovider.User, cc ftpserver.ClientContext user.Username, user.HomeDir) return nil, fmt.Errorf("cannot login user with invalid home dir: %#v", user.HomeDir) } - if util.IsStringInSlice(common.ProtocolFTP, user.Filters.DeniedProtocols) { + if util.Contains(user.Filters.DeniedProtocols, common.ProtocolFTP) { logger.Info(logSender, connectionID, "cannot login user %#v, protocol FTP is not allowed", user.Username) return nil, fmt.Errorf("protocol FTP is not allowed for user %#v", user.Username) } diff --git a/go.mod b/go.mod index 3d6874f2..1696879a 100644 --- a/go.mod +++ b/go.mod @@ -3,26 +3,26 @@ module github.com/drakkan/sftpgo/v2 go 1.18 require ( - cloud.google.com/go/storage v1.22.0 + cloud.google.com/go/storage v1.22.1 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.0.0 github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v0.4.1 github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962 github.com/alexedwards/argon2id v0.0.0-20211130144151-3585854a6387 - github.com/aws/aws-sdk-go-v2 v1.16.3 - github.com/aws/aws-sdk-go-v2/config v1.15.5 - github.com/aws/aws-sdk-go-v2/credentials v1.12.0 - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.4 - github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.10 - github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.4 - github.com/aws/aws-sdk-go-v2/service/s3 v1.26.9 - github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.7 - github.com/aws/aws-sdk-go-v2/service/sts v1.16.4 - github.com/cockroachdb/cockroach-go/v2 v2.2.8 + github.com/aws/aws-sdk-go-v2 v1.16.4 + github.com/aws/aws-sdk-go-v2/config v1.15.7 + github.com/aws/aws-sdk-go-v2/credentials v1.12.2 + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.5 + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.12 + github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.5 + github.com/aws/aws-sdk-go-v2/service/s3 v1.26.10 + github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.8 + github.com/aws/aws-sdk-go-v2/service/sts v1.16.6 + github.com/cockroachdb/cockroach-go/v2 v2.2.10 github.com/coreos/go-oidc/v3 v3.2.0 github.com/eikenb/pipeat v0.0.0-20210730190139-06b3e6902001 - github.com/fclairamb/ftpserverlib v0.18.0 + github.com/fclairamb/ftpserverlib v0.18.1-0.20220515214847-f96d31ec626e github.com/fclairamb/go-log v0.3.0 - github.com/go-chi/chi/v5 v5.0.8-0.20220103230436-7dbe9a0bd10f + github.com/go-chi/chi/v5 v5.0.8-0.20220512131524-9e71a0d4b3d6 github.com/go-chi/jwtauth/v5 v5.0.2 github.com/go-chi/render v1.0.1 github.com/go-sql-driver/mysql v1.6.0 @@ -36,7 +36,7 @@ require ( github.com/jlaffaye/ftp v0.0.0-20201112195030-9aae4d151126 github.com/klauspost/compress v1.15.4 github.com/lestrrat-go/jwx v1.2.24 - github.com/lib/pq v1.10.5 + github.com/lib/pq v1.10.6 github.com/lithammer/shortuuid/v3 v3.0.7 github.com/mattn/go-sqlite3 v1.14.13 github.com/mhale/smtpd v0.8.0 @@ -49,7 +49,7 @@ require ( github.com/robfig/cron/v3 v3.0.1 github.com/rs/cors v1.8.2 github.com/rs/xid v1.4.0 - github.com/rs/zerolog v1.26.2-0.20220312163309-e9344a8c507b + github.com/rs/zerolog v1.26.2-0.20220505171737-a4ec5e4cdd4b github.com/sftpgo/sdk v0.1.1-0.20220425123921-2f843a49e012 github.com/shirou/gopsutil/v3 v3.22.4 github.com/spf13/afero v1.8.2 @@ -64,12 +64,12 @@ require ( go.etcd.io/bbolt v1.3.6 go.uber.org/automaxprocs v1.5.1 gocloud.dev v0.25.0 - golang.org/x/crypto v0.0.0-20220511200225-c6db032c6c88 + golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898 golang.org/x/net v0.0.0-20220513224357-95641704303c golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5 - golang.org/x/sys v0.0.0-20220513210249-45d2b4557a2a + golang.org/x/sys v0.0.0-20220519141025-dcacdad47464 golang.org/x/time v0.0.0-20220411224347-583f2d630306 - google.golang.org/api v0.79.0 + google.golang.org/api v0.80.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) @@ -79,15 +79,15 @@ require ( cloud.google.com/go/iam v0.3.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.0.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.1 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.10 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.4 // indirect - github.com/aws/aws-sdk-go-v2/internal/ini v1.3.11 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.1 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.3.12 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.2 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.1 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.5 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.4 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.4 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.11.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.5 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.11.5 // indirect github.com/aws/smithy-go v1.11.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/boombuler/barcode v1.0.1 // indirect @@ -119,7 +119,7 @@ require ( github.com/lestrrat-go/httpcc v1.0.1 // indirect github.com/lestrrat-go/iter v1.0.2 // indirect github.com/lestrrat-go/option v1.0.0 // indirect - github.com/lufia/plan9stats v0.0.0-20220326011226-f1430873d8db // indirect + github.com/lufia/plan9stats v0.0.0-20220517141722-cf486979b281 // indirect github.com/magiconair/properties v1.8.6 // indirect github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect @@ -150,9 +150,9 @@ require ( golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect golang.org/x/text v0.3.7 // indirect golang.org/x/tools v0.1.10 // indirect - golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f // indirect + golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20220505152158-f39f71e6c8f3 // indirect + google.golang.org/genproto v0.0.0-20220519153652-3a47de7e79bd // indirect google.golang.org/grpc v1.46.2 // indirect google.golang.org/protobuf v1.28.0 // indirect gopkg.in/ini.v1 v1.66.4 // indirect @@ -163,6 +163,6 @@ require ( replace ( github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9 - golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20220514091251-ad79d832b8dc + golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20220519062025-309756691f42 golang.org/x/net => github.com/drakkan/net v0.0.0-20220514085754-d827943a3fff ) diff --git a/go.sum b/go.sum index 9fb9746c..22e60a58 100644 --- a/go.sum +++ b/go.sum @@ -71,8 +71,9 @@ cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RX cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo= cloud.google.com/go/storage v1.21.0/go.mod h1:XmRlxkgPjlBONznT2dDUU/5XlpU2OjMnKuqnZI01LAA= -cloud.google.com/go/storage v1.22.0 h1:NUV0NNp9nkBuW66BFRLuMgldN60C57ET3dhbwLIYio8= cloud.google.com/go/storage v1.22.0/go.mod h1:GbaLEoMqbVm6sx3Z0R++gSiBlgMv6yUi2q1DeGFKQgE= +cloud.google.com/go/storage v1.22.1 h1:F6IlQJZrZM++apn9V5/VfS3gbTUYg98PS3EMQAzqtfg= +cloud.google.com/go/storage v1.22.1/go.mod h1:S8N1cAStu7BOeFfE8KAQzmyyLkK8p/vmRq6kuBTW58Y= cloud.google.com/go/trace v1.0.0/go.mod h1:4iErSByzxkyHWzzlAj63/Gmjz0NH1ASqhJguHpGcr6A= cloud.google.com/go/trace v1.2.0/go.mod h1:Wc8y/uYyOhPy12KEnXG9XGrvfMz5F5SrYecQlbW1rwM= contrib.go.opencensus.io/exporter/aws v0.0.0-20200617204711-c478e41e60e9/go.mod h1:uu1P0UCM/6RbsMrgPa98ll8ZcHM858i/AD06a9aLRCA= @@ -136,62 +137,62 @@ github.com/aws/aws-sdk-go v1.15.27/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZo github.com/aws/aws-sdk-go v1.37.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= github.com/aws/aws-sdk-go v1.43.31/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo= github.com/aws/aws-sdk-go-v2 v1.16.2/go.mod h1:ytwTPBG6fXTZLxxeeCCWj2/EMYp/xDUgX+OET6TLNNU= -github.com/aws/aws-sdk-go-v2 v1.16.3 h1:0W1TSJ7O6OzwuEvIXAtJGvOeQ0SGAhcpxPN2/NK5EhM= -github.com/aws/aws-sdk-go-v2 v1.16.3/go.mod h1:ytwTPBG6fXTZLxxeeCCWj2/EMYp/xDUgX+OET6TLNNU= +github.com/aws/aws-sdk-go-v2 v1.16.4 h1:swQTEQUyJF/UkEA94/Ga55miiKFoXmm/Zd67XHgmjSg= +github.com/aws/aws-sdk-go-v2 v1.16.4/go.mod h1:ytwTPBG6fXTZLxxeeCCWj2/EMYp/xDUgX+OET6TLNNU= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.1 h1:SdK4Ppk5IzLs64ZMvr6MrSficMtjY2oS0WOORXTlxwU= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.1/go.mod h1:n8Bs1ElDD2wJ9kCRTczA83gYbBmjSwZp3umc6zF4EeM= github.com/aws/aws-sdk-go-v2/config v1.15.3/go.mod h1:9YL3v07Xc/ohTsxFXzan9ZpFpdTOFl4X65BAKYaz8jg= -github.com/aws/aws-sdk-go-v2/config v1.15.5 h1:P+xwhr6kabhxDTXTVH9YoHkqjLJ0wVVpIUHtFNr2hjU= -github.com/aws/aws-sdk-go-v2/config v1.15.5/go.mod h1:ZijHHh0xd/A+ZY53az0qzC5tT46kt4JVCePf2NX9Lk4= +github.com/aws/aws-sdk-go-v2/config v1.15.7 h1:PrzhYjDpWnGSpjedmEapldQKPW4x8cCNzUI8XOho1CM= +github.com/aws/aws-sdk-go-v2/config v1.15.7/go.mod h1:exERlvqx1OoUHrxQpMgrmfSW0H6B1+r3xziZD3bBXRg= github.com/aws/aws-sdk-go-v2/credentials v1.11.2/go.mod h1:j8YsY9TXTm31k4eFhspiQicfXPLZ0gYXA50i4gxPE8g= -github.com/aws/aws-sdk-go-v2/credentials v1.12.0 h1:4R/NqlcRFSkR0wxOhgHi+agGpbEr5qMCjn7VqUIJY+E= -github.com/aws/aws-sdk-go-v2/credentials v1.12.0/go.mod h1:9YWk7VW+eyKsoIL6/CljkTrNVWBSK9pkqOPUuijid4A= +github.com/aws/aws-sdk-go-v2/credentials v1.12.2 h1:tX4EHQFU4+O9at5QjnwIKb/Qgv7MbgbUNtqTRF0Vu2M= +github.com/aws/aws-sdk-go-v2/credentials v1.12.2/go.mod h1:/XWqDVuzclEKvzileqtD7/t+wIhOogv//6JFlKEe0Wc= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3/go.mod h1:uk1vhHHERfSVCUnqSqz8O48LBYDSC+k6brng09jcMOk= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.4 h1:FP8gquGeGHHdfY6G5llaMQDF+HAf20VKc8opRwmjf04= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.4/go.mod h1:u/s5/Z+ohUQOPXl00m2yJVyioWDECsbpXTQlaqSlufc= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.5 h1:YPxclBeE07HsLQE8vtjC8T2emcTjM9nzqsnDi2fv5UM= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.5/go.mod h1:WAPnuhG5IQ/i6DETFl5NmX3kKqCzw7aau9NHAGcm4QE= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.3/go.mod h1:0dHuD2HZZSiwfJSy1FO5bX1hQ1TxVV1QXXjpn3XUE44= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.10 h1:JL7cY85hyjlgfA29MMyAlItX+JYIH9XsxgMBS7jtlqA= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.10/go.mod h1:p+ul5bLZSDRRXCZ/vePvfmZBH9akozXBJA5oMshWa5U= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.12 h1:Gd+McyLAdshV3ZaQXt7Vd8dtLMZgcAmn5Y/mXDEO9L8= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.12/go.mod h1:8pCb6S1pHhY5PulX37wdb2dqXHkM4B3ij6Z1gAOdDtE= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.9/go.mod h1:AnVH5pvai0pAF4lXRq0bmhbes1u9R8wTE+g+183bZNM= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.10 h1:uFWgo6mGJI1n17nbcvSc6fxVuR3xLNqvXt12JCnEcT8= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.10/go.mod h1:F+EZtuIwjlv35kRJPyBGcsA4f7bnSoz15zOQ2lJq1Z4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.11 h1:gsqHplNh1DaQunEKZISK56wlpbCg0yKxNVvGWCFuF1k= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.11/go.mod h1:tmUB6jakq5DFNcXsXOA/ZQ7/C8VnSKYkx58OI7Fh79g= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.3/go.mod h1:ssOhaLpRlh88H3UmEcsBoVKq309quMvm3Ds8e9d4eJM= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.4 h1:cnsvEKSoHN4oAN7spMMr0zhEW2MHnhAVpmqQg8E6UcM= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.4/go.mod h1:8glyUqVIM4AmeenIsPo0oVh3+NUwnsQml2OFupfQW+0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.5 h1:PLFj+M2PgIDHG//hw3T0O0KLI4itVtAjtxrZx4AHPLg= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.5/go.mod h1:fV1AaS2gFc1tM0RCb015FJ0pvWVUfJZANzjwoO4YakM= github.com/aws/aws-sdk-go-v2/internal/ini v1.3.10/go.mod h1:8DcYQcz0+ZJaSxANlHIsbbi6S+zMwjwdDqwW3r9AzaE= -github.com/aws/aws-sdk-go-v2/internal/ini v1.3.11 h1:6cZRymlLEIlDTEB0+5+An6Zj1CKt6rSE69tOmFeu1nk= -github.com/aws/aws-sdk-go-v2/internal/ini v1.3.11/go.mod h1:0MR+sS1b/yxsfAPvAESrw8NfwUoxMinDyw6EYR9BS2U= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.1 h1:C21IDZCm9Yu5xqjb3fKmxDoYvJXtw1DNlOmLZEIlY1M= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.1/go.mod h1:l/BbcfqDCT3hePawhy4ZRtewjtdkl6GWtd9/U+1penQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.12 h1:j0VqrjtgsY1Bx27tD0ysay36/K4kFMWRp9K3ieO9nLU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.12/go.mod h1:00c7+ALdPh4YeEUPXJzyU0Yy01nPGOq2+9rUaz05z9g= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.2 h1:1fs9WkbFcMawQjxEI0B5L0SqvBhJZebxWM6Z3x/qHWY= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.2/go.mod h1:0jDVeWUFPbI3sOfsXXAsIdiawXcn7VBLx/IlFVTRP64= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.1 h1:T4pFel53bkHjL2mMo+4DKE6r6AuoZnM0fg7k1/ratr4= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.1/go.mod h1:GeUru+8VzrTXV/83XyMJ80KpH8xO89VPoUileyNQ+tc= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.3/go.mod h1:Seb8KNmD6kVTjwRjVEgOT5hPin6sq+v4C2ycJQDwuH8= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.5 h1:9LSZqt4v1JiehyZTrQnRFf2mY/awmyYNNY/b7zqtduU= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.5/go.mod h1:S8TVP66AAkMMdYYCNZGvrdEq9YRm+qLXjio4FqRnrEE= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.6 h1:9mvDAsMiN+07wcfGM+hJ1J3dOKZ2YOpDiPZ6ufRJcgw= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.6/go.mod h1:Eus+Z2iBIEfhOvhSdMTcscNOMy6n3X9/BJV0Zgax98w= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.3/go.mod h1:wlY6SVjuwvh3TVRpTqdy4I1JpBFLX4UGeKZdWntaocw= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.4 h1:b16QW0XWl0jWjLABFc1A+uh145Oqv+xDcObNk0iQgUk= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.4/go.mod h1:uKkN7qmSIsNJVyMtxNQoCEYMvFEXbOg9fwCJPdfp2u8= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.5 h1:gRW1ZisKc93EWEORNJRvy/ZydF3o6xLSveJHdi1Oa0U= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.5/go.mod h1:ZbkttHXaVn3bBo/wpJbQGiiIWR90eTBUVBrEHUEQlho= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.3/go.mod h1:Bm/v2IaN6rZ+Op7zX+bOUMdL4fsrYZiD0dsjLhNKwZc= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.4 h1:RE/DlZLYrz1OOmq8F28IXHLksuuvlpzUbvJ+SESCZBI= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.4/go.mod h1:oudbsSdDtazNj47z1ut1n37re9hDsKpk2ZI3v7KSxq0= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.5 h1:DyPYkrH4R2zn+Pdu6hM3VTuPsQYAE6x2WB24X85Sgw0= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.5/go.mod h1:XtL92YWo0Yq80iN3AgYRERJqohg4TozrqRlxYhHGJ7g= github.com/aws/aws-sdk-go-v2/service/kms v1.16.3/go.mod h1:QuiHPBqlOFCi4LqdSskYYAWpQlx3PKmohy+rE2F+o5g= -github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.4 h1:qmHavnjRtgdH54nyG4iEk6ZCde9m2S++32INurhaNTk= -github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.4/go.mod h1:CloMDruFIVZJ8qv2OsY5ENIqzg5c0eeTciVVW3KHdvE= +github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.5 h1:CvRAsgxd1BN5l961+xXfS0mEhhyJTMxqdoWpZQIJZt4= +github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.5/go.mod h1:tnwCNkQvihXdRZ8Fyita7EJ0IeY46DcJWhgcWaquT+o= github.com/aws/aws-sdk-go-v2/service/s3 v1.26.3/go.mod h1:g1qvDuRsJY+XghsV6zg00Z4KJ7DtFFCx8fJD2a491Ak= -github.com/aws/aws-sdk-go-v2/service/s3 v1.26.9 h1:LCQKnopq2t4oQS3VKivlYTzAHCTJZZoQICM9fny7KHY= -github.com/aws/aws-sdk-go-v2/service/s3 v1.26.9/go.mod h1:iMYipLPXlWpBJ0KFX7QJHZ84rBydHBY8as2aQICTPWk= +github.com/aws/aws-sdk-go-v2/service/s3 v1.26.10 h1:GWdLZK0r1AK5sKb8rhB9bEXqXCK8WNuyv4TBAD6ZviQ= +github.com/aws/aws-sdk-go-v2/service/s3 v1.26.10/go.mod h1:+O7qJxF8nLorAhuIVhYTHse6okjHJJm4EwhhzvpnkT0= github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.4/go.mod h1:PJc8s+lxyU8rrre0/4a0pn2wgwiDvOEzoOjcJUBr67o= -github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.7 h1:9PpUYZ6D8J9p3kP7IW4iob1x1kbD5tMhKuRWzT/aQ6o= -github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.7/go.mod h1:Z+i6uqZgCOBXhNoEGoRm/ZaLsaJA9rGUAmkVKM/3+g4= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.8 h1:diuSSZTFZpbsLA5CA0wo7Nw04c5kX5Kae37v3f6CsIA= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.8/go.mod h1:Jt1lSw1fYlQ60lqrZ9ViN2LMGizbWTWbkStm4rbuYuE= github.com/aws/aws-sdk-go-v2/service/sns v1.17.4/go.mod h1:kElt+uCcXxcqFyc+bQqZPFD9DME/eC6oHBXvFzQ9Bcw= github.com/aws/aws-sdk-go-v2/service/sqs v1.18.3/go.mod h1:skmQo0UPvsjsuYYSYMVmrPc1HWCbHUJyrCEp+ZaLzqM= github.com/aws/aws-sdk-go-v2/service/ssm v1.24.1/go.mod h1:NR/xoKjdbRJ+qx0pMR4mI+N/H1I1ynHwXnO6FowXJc0= github.com/aws/aws-sdk-go-v2/service/sso v1.11.3/go.mod h1:7UQ/e69kU7LDPtY40OyoHYgRmgfGM4mgsLYtcObdveU= -github.com/aws/aws-sdk-go-v2/service/sso v1.11.4 h1:Uw5wBybFQ1UeA9ts0Y07gbv0ncZnIAyw858tDW0NP2o= -github.com/aws/aws-sdk-go-v2/service/sso v1.11.4/go.mod h1:cPDwJwsP4Kff9mldCXAmddjJL6JGQqtA3Mzer2zyr88= +github.com/aws/aws-sdk-go-v2/service/sso v1.11.5 h1:TfJ/zuOYvHnxkvohSwAF3Ppn9KT/SrGZuOZHTPy8Guw= +github.com/aws/aws-sdk-go-v2/service/sso v1.11.5/go.mod h1:TFVe6Rr2joVLsYQ1ABACXgOC6lXip/qpX2x5jWg/A9w= github.com/aws/aws-sdk-go-v2/service/sts v1.16.3/go.mod h1:bfBj0iVmsUyUg4weDB4NxktD9rDGeKSVWnjTnwbx9b8= -github.com/aws/aws-sdk-go-v2/service/sts v1.16.4 h1:+xtV90n3abQmgzk1pS++FdxZTrPEDgQng6e4/56WR2A= -github.com/aws/aws-sdk-go-v2/service/sts v1.16.4/go.mod h1:lfSYenAXtavyX2A1LsViglqlG9eEFYxNryTZS5rn3QE= +github.com/aws/aws-sdk-go-v2/service/sts v1.16.6 h1:aYToU0/iazkMY67/BYLt3r6/LT/mUtarLAF5mGof1Kg= +github.com/aws/aws-sdk-go-v2/service/sts v1.16.6/go.mod h1:rP1rEOKAGZoXp4iGDxSXFvODAtXpm34Egf0lL0eshaQ= github.com/aws/smithy-go v1.11.2 h1:eG/N+CcUMAvsdffgMvjMKwfyDzIkjM6pfxMJ8Mzc6mE= github.com/aws/smithy-go v1.11.2/go.mod h1:3xHYmszWVx2c0kIwQeEVf9uSm4fYZt67FBJnwub1bgM= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= @@ -224,8 +225,8 @@ github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -github.com/cockroachdb/cockroach-go/v2 v2.2.8 h1:IrQpwOXQza67nSSezygYjl4GQtQnE+rDrU2yK6MmNFA= -github.com/cockroachdb/cockroach-go/v2 v2.2.8/go.mod h1:q4ZRgO6CQpwNyEvEwSxwNrOSVchsmzrBnAv3HuZ3Abc= +github.com/cockroachdb/cockroach-go/v2 v2.2.10 h1:O7Hl8m0rs/oJNBmRr14ED3Q3+AmugMK9DtJwRDHZ2DA= +github.com/cockroachdb/cockroach-go/v2 v2.2.10/go.mod h1:xZ2VHjUEb/cySv0scXBx7YsBnHtLHkR1+w/w73b5i3M= github.com/coreos/go-oidc/v3 v3.2.0 h1:2eR2MGR7thBXSQ2YbODlF0fcmgtliLCfr9iX6RW11fc= github.com/coreos/go-oidc/v3 v3.2.0/go.mod h1:rEJ/idjfUyfkBit1eI1fvyr+64/g9dcKpAm8MJMesvo= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -250,8 +251,8 @@ github.com/dimchansky/utfbom v1.1.0/go.mod h1:rO41eb7gLfo8SF1jd9F8HplJm1Fewwi4mQ github.com/dimchansky/utfbom v1.1.1/go.mod h1:SxdoEBH5qIqFocHMyGOXVAybYJdr71b1Q/j0mACtrfE= github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= -github.com/drakkan/crypto v0.0.0-20220514091251-ad79d832b8dc h1:RZGYMm4+aQbbGUecIKYQk4d+tYBePbB23v0QDrNXf4w= -github.com/drakkan/crypto v0.0.0-20220514091251-ad79d832b8dc/go.mod h1:SiM6ypd8Xu1xldObYtbDztuUU7xUzMnUULfphXFZmro= +github.com/drakkan/crypto v0.0.0-20220519062025-309756691f42 h1:AS9tPudMbdwJhnFgwE8BnRfjIyOe8buOUUiJE+qi8MY= +github.com/drakkan/crypto v0.0.0-20220519062025-309756691f42/go.mod h1:SiM6ypd8Xu1xldObYtbDztuUU7xUzMnUULfphXFZmro= github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9 h1:LPH1dEblAOO/LoG7yHPMtBLXhQmjaga91/DDjWk9jWA= github.com/drakkan/ftp v0.0.0-20201114075148-9b9adce499a9/go.mod h1:2lmrmq866uF2tnje75wQHzmPXhmSWUt7Gyx2vgK1RCU= github.com/drakkan/net v0.0.0-20220514085754-d827943a3fff h1:en4qoYF7ceYxP1OkTSdbAO87JA9ruqdvKLV2KHviiNM= @@ -271,8 +272,8 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7 github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= -github.com/fclairamb/ftpserverlib v0.18.0 h1:q/uz7jVFMoGEMswnA+nbaKEC5mzxXJOmhPE/Q3r7VZI= -github.com/fclairamb/ftpserverlib v0.18.0/go.mod h1:QhLRiCajhPG/2WwGgcsAqmlaYXX8KziNXtSe1BlRH+k= +github.com/fclairamb/ftpserverlib v0.18.1-0.20220515214847-f96d31ec626e h1:D7/to1KmKRTTRQyExulywEVYKhB+/WOW3gqiKimrbXg= +github.com/fclairamb/ftpserverlib v0.18.1-0.20220515214847-f96d31ec626e/go.mod h1:Ff6D1Ofy7/ezi7C30NPEgazzp/AQqyp0T8D7k+Tv2ls= github.com/fclairamb/go-log v0.3.0 h1:oSC7Zjt0FZIYC5xXahUUycKGkypSdr2srFPLsp7CLd0= github.com/fclairamb/go-log v0.3.0/go.mod h1:XG61EiPlAXnPDN8SA4N3zeA+GyBJmVOCCo12WORx/gA= github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= @@ -286,8 +287,8 @@ github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= github.com/gin-gonic/gin v1.7.3/go.mod h1:jD2toBW3GZUr5UMcdrwQA10I7RuaFOl/SGeDjXkfUtY= github.com/go-chi/chi/v5 v5.0.4/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= -github.com/go-chi/chi/v5 v5.0.8-0.20220103230436-7dbe9a0bd10f h1:6kLofhLkWj7lgCc+mvcVLnwhTzQYgL/yW/Y0e/JYwjg= -github.com/go-chi/chi/v5 v5.0.8-0.20220103230436-7dbe9a0bd10f/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-chi/chi/v5 v5.0.8-0.20220512131524-9e71a0d4b3d6 h1:+fT7oFUOersdx+u7uIxOjabDVGxg+qqNV6kRdAXIvaQ= +github.com/go-chi/chi/v5 v5.0.8-0.20220512131524-9e71a0d4b3d6/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-chi/jwtauth/v5 v5.0.2 h1:CSKtr+b6Jnfy5T27sMaiBPxaVE/bjnjS3ramFQ0526w= github.com/go-chi/jwtauth/v5 v5.0.2/go.mod h1:TeA7vmPe3uYThvHw8O8W13HOOpOd4MTgToxL41gZyjs= github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8= @@ -299,8 +300,8 @@ github.com/go-ini/ini v1.25.4/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3I github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= -github.com/go-kit/log v0.2.0 h1:7i2K3eKTos3Vc0enKCfnVcgHh2olr/MyfboYq7cAcFw= github.com/go-kit/log v0.2.0/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0= +github.com/go-kit/log v0.2.1 h1:MRVx0/zhvdseW+Gza6N9rVzU/IVzaeE1SFI4raAhmBU= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= @@ -313,7 +314,6 @@ github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTM github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= -github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= @@ -328,7 +328,6 @@ github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGF github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= -github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang-jwt/jwt v3.2.1+incompatible h1:73Z+4BJcrTC+KczS6WvTPvRGOp1WmfEP4Q1lOd9Z/+c= @@ -468,13 +467,12 @@ github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgO github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= -github.com/jackc/pgconn v1.4.0/go.mod h1:Y2O3ZDF0q4mMacyWV3AstPJpeHXWGEetiFttmq5lahk= -github.com/jackc/pgconn v1.5.0/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= -github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= github.com/jackc/pgconn v1.11.0/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgconn v1.12.0/go.mod h1:ZkhRC59Llhrq3oSfrikvwQ5NaxYExr6twkdkMLaKono= +github.com/jackc/pgconn v1.12.1/go.mod h1:ZkhRC59Llhrq3oSfrikvwQ5NaxYExr6twkdkMLaKono= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= @@ -485,45 +483,37 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= -github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= -github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgproto3/v2 v2.3.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= -github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4WpC0= -github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= -github.com/jackc/pgtype v1.3.1-0.20200606141011-f6355165a91c/go.mod h1:cvk9Bgu/VzJ9/lxTO5R5sf80p0DiucVtN7ZxvaC4GmQ= -github.com/jackc/pgtype v1.6.2/go.mod h1:JCULISAZBFGrHaOXIIFiyfzW5VY0GRitRr8NeJsrdig= github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= github.com/jackc/pgtype v1.10.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= +github.com/jackc/pgtype v1.11.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= -github.com/jackc/pgx/v4 v4.5.0/go.mod h1:EpAKPLdnTorwmPUUsqrPxy5fphV18j9q3wrfRXgo+kA= -github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9/go.mod h1:t3/cdRQl6fOLDxqtlyhe9UWgfIi9R8+8v8GKV5TRA/o= -github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904/go.mod h1:ZDaNWkt9sW1JMiNn0kdYBaLelIhw7Pg4qd+Vk6tw7Hg= -github.com/jackc/pgx/v4 v4.10.1/go.mod h1:QlrWebbs3kqEZPHCTGyxecvzG6tvIsYu+A5b1raylkA= github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= github.com/jackc/pgx/v4 v4.15.0/go.mod h1:D/zyOyXiaM1TmVWnOM18p0xdDtdakRBa0RsVGI3U3bw= +github.com/jackc/pgx/v4 v4.16.0/go.mod h1:N0A9sFdWzkw/Jy1lwoiB64F2+ugFZi987zRxcPez/wI= +github.com/jackc/pgx/v4 v4.16.1/go.mod h1:SIhx0D5hoADaiXZVyv+3gSm3LCIIINTVO0PficsvWGQ= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.1.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.1.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jhump/protoreflect v1.6.0 h1:h5jfMVslIg6l29nsMs0D8Wj17RDVdNYti0vDN/PZZoE= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= -github.com/jmoiron/sqlx v1.3.1/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ= +github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= @@ -579,21 +569,18 @@ github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmt github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.10.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.4/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lib/pq v1.10.5 h1:J+gdV2cUmX7ZqL2B0lFcW0m+egaHC2V3lpO8nWxyYiQ= -github.com/lib/pq v1.10.5/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.10.6 h1:jbk+ZieJ0D7EVGJYpL9QTz7/YW6UHbmdnZWYyK5cdBs= +github.com/lib/pq v1.10.6/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lithammer/shortuuid/v3 v3.0.7 h1:trX0KTHy4Pbwo/6ia8fscyHoGA+mf1jWbPJVuvyJQQ8= github.com/lithammer/shortuuid/v3 v3.0.7/go.mod h1:vMk8ke37EmiewwolSO1NLW8vP4ZaKlRuDIi8tWWmAts= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= -github.com/lufia/plan9stats v0.0.0-20220326011226-f1430873d8db h1:m2s7Fwo4OwmcheIWUc/Nw9/MZ0eFtP3to0ovTpqOiCQ= -github.com/lufia/plan9stats v0.0.0-20220326011226-f1430873d8db/go.mod h1:VgrrWVwBO2+6XKn8ypT3WUqvoxCa8R2M5to2tRzGovI= +github.com/lufia/plan9stats v0.0.0-20220517141722-cf486979b281 h1:aczX6NMOtt6L4YT0fQvKkDK6LZEtdOso9sUH89V1+P0= +github.com/lufia/plan9stats v0.0.0-20220517141722-cf486979b281/go.mod h1:lc+czkgO/8F7puNki5jk8QyujbfK1LOT7Wl0ON2hxyk= github.com/magiconair/properties v1.8.6 h1:5ibWZ6iY0NctNGWo87LalDlEZ6R41TqbbDamhfG/Qzo= github.com/magiconair/properties v1.8.6/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= -github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= @@ -604,7 +591,6 @@ github.com/mattn/go-ieproxy v0.0.3/go.mod h1:6ZpRmhBaYuBX1U2za+9rC9iCGLsSp2tftel github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= @@ -707,8 +693,8 @@ github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY= github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= -github.com/rs/zerolog v1.26.2-0.20220312163309-e9344a8c507b h1:72Plc168SB6g5i9cOEPaCuMK01bKNyniHnCpqPnX0Cg= -github.com/rs/zerolog v1.26.2-0.20220312163309-e9344a8c507b/go.mod h1:7frBqO0oezxmnO7GF86FY++uy8I0Tk/If5ni1G9Qc0U= +github.com/rs/zerolog v1.26.2-0.20220505171737-a4ec5e4cdd4b h1:wKjeedusHurN46dp/9kF0JLBh3YO54lu5juBX1oqJWE= +github.com/rs/zerolog v1.26.2-0.20220505171737-a4ec5e4cdd4b/go.mod h1:7frBqO0oezxmnO7GF86FY++uy8I0Tk/If5ni1G9Qc0U= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= @@ -718,8 +704,8 @@ github.com/sftpgo/sdk v0.1.1-0.20220425123921-2f843a49e012/go.mod h1:m5J7DH8unhD github.com/shirou/gopsutil/v3 v3.22.4 h1:srAQaiX6jX/cYL6q29aE0m8lOskT9CurZ9N61YR3yoI= github.com/shirou/gopsutil/v3 v3.22.4/go.mod h1:D01hZJ4pVHPpCTZ3m3T2+wDF2YAGfd+H4ifUguaQzHM= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= -github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= @@ -896,7 +882,6 @@ golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -965,8 +950,9 @@ golang.org/x/sys v0.0.0-20220330033206-e17cdc41300f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220502124256-b6088ccd6cba/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220513210249-45d2b4557a2a h1:N2T1jUrTQE9Re6TFF5PhvEHXHCguynGhKjWVsIUt5cY= golang.org/x/sys v0.0.0-20220513210249-45d2b4557a2a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220519141025-dcacdad47464 h1:MpIuURY70f0iKp/oooEFtB2oENcHITo/z1b6u41pKCw= +golang.org/x/sys v0.0.0-20220519141025-dcacdad47464/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -1057,8 +1043,9 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f h1:GGU+dLjvlC3qDwqYgL6UgRmHXhOOgns0bZu2Ty5mm6U= golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df h1:5Pf6pFKu98ODmgnpvkJ3kFUOQGGLIzLIkbzUHp47618= +golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= @@ -1104,8 +1091,8 @@ google.golang.org/api v0.74.0/go.mod h1:ZpfMZOVRMywNyvJFeqL9HRWBgAuRfSjJFpe9QtRR google.golang.org/api v0.75.0/go.mod h1:pU9QmyHLnzlpar1Mjt4IbapUCy8J+6HD6GeELN69ljA= google.golang.org/api v0.77.0/go.mod h1:pU9QmyHLnzlpar1Mjt4IbapUCy8J+6HD6GeELN69ljA= google.golang.org/api v0.78.0/go.mod h1:1Sg78yoMLOhlQTeF+ARBoytAcH1NNyyl390YMy6rKmw= -google.golang.org/api v0.79.0 h1:vaOcm0WdXvhGkci9a0+CcQVZqSRjN8ksSBlWv99f8Pg= -google.golang.org/api v0.79.0/go.mod h1:xY3nI94gbvBrE0J6NHXhxOmW97HG7Khjkku6AFB3Hyg= +google.golang.org/api v0.80.0 h1:IQWaGVCYnsm4MO3hh+WtSXMzMzuyFx/fuR8qkN3A0Qo= +google.golang.org/api v0.80.0/go.mod h1:xY3nI94gbvBrE0J6NHXhxOmW97HG7Khjkku6AFB3Hyg= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -1206,8 +1193,10 @@ google.golang.org/genproto v0.0.0-20220414192740-2d67ff6cf2b4/go.mod h1:8w6bsBMX google.golang.org/genproto v0.0.0-20220421151946-72621c1f0bd3/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo= google.golang.org/genproto v0.0.0-20220429170224-98d788798c3e/go.mod h1:8w6bsBMX6yCPbAVTeqQHvzxW0EIFigd5lZyahWgyfDo= google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21/go.mod h1:RAyBrSAP7Fh3Nc84ghnVLDPuV51xc9agzmm4Ph6i0Q4= -google.golang.org/genproto v0.0.0-20220505152158-f39f71e6c8f3 h1:q1kiSVscqoDeqTF27eQ2NnLLDmqF0I373qQNXYMy0fo= google.golang.org/genproto v0.0.0-20220505152158-f39f71e6c8f3/go.mod h1:RAyBrSAP7Fh3Nc84ghnVLDPuV51xc9agzmm4Ph6i0Q4= +google.golang.org/genproto v0.0.0-20220518221133-4f43b3371335/go.mod h1:RAyBrSAP7Fh3Nc84ghnVLDPuV51xc9agzmm4Ph6i0Q4= +google.golang.org/genproto v0.0.0-20220519153652-3a47de7e79bd h1:e0TwkXOdbnH/1x5rc5MZ/VYyiZ4v+RdVfrGMqEwT68I= +google.golang.org/genproto v0.0.0-20220519153652-3a47de7e79bd/go.mod h1:RAyBrSAP7Fh3Nc84ghnVLDPuV51xc9agzmm4Ph6i0Q4= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -1283,9 +1272,9 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99 h1:dbuHpmKjkDzSOMKAWl10QNlgaZUd3V1q99xc81tt2Kc= gopkg.in/yaml.v3 v3.0.0-20220512140231-539c8e751b99/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/postgres v1.0.8/go.mod h1:4eOzrI1MUfm6ObJU/UcmbXyiHSs8jSwH95G5P5dxcAg= -gorm.io/gorm v1.20.12/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= -gorm.io/gorm v1.21.4/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= +gorm.io/driver/postgres v1.3.5/go.mod h1:EGCWefLFQSVFrHGy4J8EtiHCWX5Q8t0yz2Jt9aKkGzU= +gorm.io/gorm v1.23.4/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= +gorm.io/gorm v1.23.5/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/httpd/api_mfa.go b/httpd/api_mfa.go index d46edff0..ac749991 100644 --- a/httpd/api_mfa.go +++ b/httpd/api_mfa.go @@ -221,7 +221,7 @@ func saveUserTOTPConfig(username string, r *http.Request, recoveryCodes []datapr return util.NewValidationError("two-factor authentication must be enabled") } for _, p := range user.Filters.TwoFactorAuthProtocols { - if !util.IsStringInSlice(p, user.Filters.TOTPConfig.Protocols) { + if !util.Contains(user.Filters.TOTPConfig.Protocols, p) { return util.NewValidationError(fmt.Sprintf("totp: the following protocols are required: %#v", strings.Join(user.Filters.TwoFactorAuthProtocols, ", "))) } diff --git a/httpd/api_shares.go b/httpd/api_shares.go index 8f648354..3b8ff27b 100644 --- a/httpd/api_shares.go +++ b/httpd/api_shares.go @@ -78,7 +78,7 @@ func addShare(w http.ResponseWriter, r *http.Request) { share.Name = share.ShareID } if share.Password == "" { - if util.IsStringInSlice(sdk.WebClientShareNoPasswordDisabled, claims.Permissions) { + if util.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) { sendAPIResponse(w, r, nil, "You are not authorized to share files/folders without a password", http.StatusForbidden) return @@ -121,7 +121,7 @@ func updateShare(w http.ResponseWriter, r *http.Request) { share.Password = oldPassword } if share.Password == "" { - if util.IsStringInSlice(sdk.WebClientShareNoPasswordDisabled, claims.Permissions) { + if util.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) { sendAPIResponse(w, r, nil, "You are not authorized to share files/folders without a password", http.StatusForbidden) return diff --git a/httpd/api_utils.go b/httpd/api_utils.go index 4e716056..b86155df 100644 --- a/httpd/api_utils.go +++ b/httpd/api_utils.go @@ -189,12 +189,12 @@ func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string, } func renderAPIDirContents(w http.ResponseWriter, r *http.Request, contents []os.FileInfo, omitNonRegularFiles bool) { - results := make([]map[string]interface{}, 0, len(contents)) + results := make([]map[string]any, 0, len(contents)) for _, info := range contents { if omitNonRegularFiles && !info.Mode().IsDir() && !info.Mode().IsRegular() { continue } - res := make(map[string]interface{}) + res := make(map[string]any) res["name"] = info.Name() if info.Mode().IsRegular() { res["size"] = info.Size() @@ -508,7 +508,7 @@ func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err err } func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string, checkSessions bool) error { - if util.IsStringInSlice(common.ProtocolHTTP, user.Filters.DeniedProtocols) { + if util.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) { logger.Info(logSender, connectionID, "cannot login user %#v, protocol HTTP is not allowed", user.Username) return fmt.Errorf("protocol HTTP is not allowed for user %#v", user.Username) } @@ -581,8 +581,7 @@ func handleForgotPassword(r *http.Request, username string, isAdmin bool) error } logger.Debug(logSender, middleware.GetReqID(r.Context()), "reset code sent via email to %#v, email: %#v, is admin? %v, elapsed: %v", username, email, isAdmin, time.Since(startTime)) - resetCodes.Store(c.Code, c) - return nil + return resetCodesMgr.Add(c) } func handleResetPassword(r *http.Request, code, newPassword string, isAdmin bool) ( @@ -598,11 +597,10 @@ func handleResetPassword(r *http.Request, code, newPassword string, isAdmin bool if code == "" { return &admin, &user, util.NewValidationError("please set a confirmation code") } - c, ok := resetCodes.Load(code) - if !ok { + resetCode, err := resetCodesMgr.Get(code) + if err != nil { return &admin, &user, util.NewValidationError("confirmation code not found") } - resetCode := c.(*resetCode) if resetCode.IsAdmin != isAdmin { return &admin, &user, util.NewValidationError("invalid confirmation code") } @@ -616,8 +614,8 @@ func handleResetPassword(r *http.Request, code, newPassword string, isAdmin bool if err != nil { return &admin, &user, util.NewGenericError(fmt.Sprintf("unable to set the new password: %v", err)) } - resetCodes.Delete(code) - return &admin, &user, nil + err = resetCodesMgr.Delete(code) + return &admin, &user, err } user, err = dataprovider.GetUserWithGroupSettings(resetCode.Username) if err != nil { @@ -631,7 +629,7 @@ func handleResetPassword(r *http.Request, code, newPassword string, isAdmin bool err = dataprovider.UpdateUserPassword(user.Username, newPassword, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr)) if err == nil { - resetCodes.Delete(code) + err = resetCodesMgr.Delete(code) } return &admin, &user, err } @@ -640,7 +638,7 @@ func isUserAllowedToResetPassword(r *http.Request, user *dataprovider.User) bool if !user.CanResetPassword() { return false } - if util.IsStringInSlice(common.ProtocolHTTP, user.Filters.DeniedProtocols) { + if util.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) { return false } if !user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolHTTP, nil) { diff --git a/httpd/auth_utils.go b/httpd/auth_utils.go index 0b0f6fb0..c70a9c69 100644 --- a/httpd/auth_utils.go +++ b/httpd/auth_utils.go @@ -65,8 +65,8 @@ func (c *jwtTokenClaims) hasUserAudience() bool { return false } -func (c *jwtTokenClaims) asMap() map[string]interface{} { - claims := make(map[string]interface{}) +func (c *jwtTokenClaims) asMap() map[string]any { + claims := make(map[string]any) claims[claimUsernameKey] = c.Username claims[claimPermissionsKey] = c.Permissions @@ -80,7 +80,7 @@ func (c *jwtTokenClaims) asMap() map[string]interface{} { return claims } -func (c *jwtTokenClaims) Decode(token map[string]interface{}) { +func (c *jwtTokenClaims) Decode(token map[string]any) { c.Permissions = nil username := token[claimUsernameKey] @@ -112,7 +112,7 @@ func (c *jwtTokenClaims) Decode(token map[string]interface{}) { permissions := token[claimPermissionsKey] switch v := permissions.(type) { - case []interface{}: + case []any: for _, elem := range v { switch elemValue := elem.(type) { case string: @@ -129,7 +129,7 @@ func (c *jwtTokenClaims) Decode(token map[string]interface{}) { secondFactorProtocols := token[claimRequiredTwoFactorProtocols] switch v := secondFactorProtocols.(type) { - case []interface{}: + case []any: for _, elem := range v { switch elemValue := elem.(type) { case string: @@ -140,24 +140,24 @@ func (c *jwtTokenClaims) Decode(token map[string]interface{}) { } func (c *jwtTokenClaims) isCriticalPermRemoved(permissions []string) bool { - if util.IsStringInSlice(dataprovider.PermAdminAny, permissions) { + if util.Contains(permissions, dataprovider.PermAdminAny) { return false } - if (util.IsStringInSlice(dataprovider.PermAdminManageAdmins, c.Permissions) || - util.IsStringInSlice(dataprovider.PermAdminAny, c.Permissions)) && - !util.IsStringInSlice(dataprovider.PermAdminManageAdmins, permissions) && - !util.IsStringInSlice(dataprovider.PermAdminAny, permissions) { + if (util.Contains(c.Permissions, dataprovider.PermAdminManageAdmins) || + util.Contains(c.Permissions, dataprovider.PermAdminAny)) && + !util.Contains(permissions, dataprovider.PermAdminManageAdmins) && + !util.Contains(permissions, dataprovider.PermAdminAny) { return true } return false } func (c *jwtTokenClaims) hasPerm(perm string) bool { - if util.IsStringInSlice(dataprovider.PermAdminAny, c.Permissions) { + if util.Contains(c.Permissions, dataprovider.PermAdminAny) { return true } - return util.IsStringInSlice(perm, c.Permissions) + return util.Contains(c.Permissions, perm) } func (c *jwtTokenClaims) createToken(tokenAuth *jwtauth.JWTAuth, audience tokenAudience, ip string) (jwt.Token, string, error) { @@ -172,13 +172,13 @@ func (c *jwtTokenClaims) createToken(tokenAuth *jwtauth.JWTAuth, audience tokenA return tokenAuth.Encode(claims) } -func (c *jwtTokenClaims) createTokenResponse(tokenAuth *jwtauth.JWTAuth, audience tokenAudience, ip string) (map[string]interface{}, error) { +func (c *jwtTokenClaims) createTokenResponse(tokenAuth *jwtauth.JWTAuth, audience tokenAudience, ip string) (map[string]any, error) { token, tokenString, err := c.createToken(tokenAuth, audience, ip) if err != nil { return nil, err } - response := make(map[string]interface{}) + response := make(map[string]any) response["access_token"] = tokenString response["expires_at"] = token.Expiration().Format(time.RFC3339) @@ -301,7 +301,7 @@ func getAdminFromToken(r *http.Request) *dataprovider.Admin { } func createCSRFToken(ip string) string { - claims := make(map[string]interface{}) + claims := make(map[string]any) now := time.Now().UTC() claims[jwt.JwtIDKey] = xid.New().String() @@ -324,12 +324,12 @@ func verifyCSRFToken(tokenString, ip string) error { return fmt.Errorf("unable to verify form token: %v", err) } - if !util.IsStringInSlice(tokenAudienceCSRF, token.Audience()) { + if !util.Contains(token.Audience(), tokenAudienceCSRF) { logger.Debug(logSender, "", "error validating CSRF token audience") return errors.New("the form token is not valid") } - if !util.IsStringInSlice(ip, token.Audience()) { + if !util.Contains(token.Audience(), ip) { logger.Debug(logSender, "", "error validating CSRF token IP audience") return errors.New("the form token is not valid") } diff --git a/httpd/httpd.go b/httpd/httpd.go index 6235e2f5..f578d0ed 100644 --- a/httpd/httpd.go +++ b/httpd/httpd.go @@ -643,8 +643,10 @@ func (c *Conf) getRedacted() Conf { } // Initialize configures and starts the HTTP server -func (c *Conf) Initialize(configDir string) error { +func (c *Conf) Initialize(configDir string, isShared int) error { logger.Info(logSender, "", "initializing HTTP server with config %+v", c.getRedacted()) + resetCodesMgr = newResetCodeManager(isShared) + oidcMgr = newOIDCManager(isShared) staticFilesPath := util.FindSharedDataPath(c.StaticFilesPath, configDir) templatesPath := util.FindSharedDataPath(c.TemplatesPath, configDir) openAPIPath := util.FindSharedDataPath(c.OpenAPIPath, configDir) @@ -862,13 +864,18 @@ func startCleanupTicker(duration time.Duration) { cleanupDone = make(chan bool) go func() { + counter := int64(0) for { select { case <-cleanupDone: return case <-cleanupTicker.C: + counter++ cleanupExpiredJWTTokens() - cleanupExpiredResetCodes() + resetCodesMgr.Cleanup() + if counter%2 == 0 { + oidcMgr.cleanup() + } } } }() @@ -883,7 +890,7 @@ func stopCleanupTicker() { } func cleanupExpiredJWTTokens() { - invalidatedJWTTokens.Range(func(key, value interface{}) bool { + invalidatedJWTTokens.Range(func(key, value any) bool { exp, ok := value.(time.Time) if !ok || exp.Before(time.Now().UTC()) { invalidatedJWTTokens.Delete(key) diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index 6051d414..4f22e9ed 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -373,7 +373,7 @@ func TestMain(m *testing.M) { sftpdConf.HostKeys = []string{hostKeyPath} go func() { - if err := httpdConf.Initialize(configDir); err != nil { + if err := httpdConf.Initialize(configDir, 0); err != nil { logger.ErrorToConsole("could not start HTTP server: %v", err) os.Exit(1) } @@ -412,7 +412,7 @@ func TestMain(m *testing.M) { httpdConf.Bindings = append(httpdConf.Bindings, httpd.Binding{}) go func() { - if err := httpdConf.Initialize(configDir); err != nil { + if err := httpdConf.Initialize(configDir, 0); err != nil { logger.ErrorToConsole("could not start HTTPS server: %v", err) os.Exit(1) } @@ -437,6 +437,7 @@ func TestMain(m *testing.M) { } func TestInitialization(t *testing.T) { + isShared := 0 err := config.LoadConfig(configDir, "") assert.NoError(t, err) invalidFile := "invalid file" @@ -445,12 +446,12 @@ func TestInitialization(t *testing.T) { defaultStaticPath := httpdConf.StaticFilesPath httpdConf.CertificateFile = invalidFile httpdConf.CertificateKeyFile = invalidFile - err = httpdConf.Initialize(configDir) + err = httpdConf.Initialize(configDir, isShared) assert.Error(t, err) httpdConf.CertificateFile = "" httpdConf.CertificateKeyFile = "" httpdConf.TemplatesPath = "." - err = httpdConf.Initialize(configDir) + err = httpdConf.Initialize(configDir, isShared) assert.Error(t, err) httpdConf = config.GetHTTPDConfig() httpdConf.TemplatesPath = defaultTemplatesPath @@ -458,22 +459,22 @@ func TestInitialization(t *testing.T) { httpdConf.CertificateKeyFile = invalidFile httpdConf.StaticFilesPath = "" httpdConf.TemplatesPath = "" - err = httpdConf.Initialize(configDir) + err = httpdConf.Initialize(configDir, isShared) assert.Error(t, err) httpdConf.StaticFilesPath = defaultStaticPath httpdConf.TemplatesPath = defaultTemplatesPath httpdConf.CertificateFile = filepath.Join(os.TempDir(), "test.crt") httpdConf.CertificateKeyFile = filepath.Join(os.TempDir(), "test.key") httpdConf.CACertificates = append(httpdConf.CACertificates, invalidFile) - err = httpdConf.Initialize(configDir) + err = httpdConf.Initialize(configDir, isShared) assert.Error(t, err) httpdConf.CACertificates = nil httpdConf.CARevocationLists = append(httpdConf.CARevocationLists, invalidFile) - err = httpdConf.Initialize(configDir) + err = httpdConf.Initialize(configDir, isShared) assert.Error(t, err) httpdConf.CARevocationLists = nil httpdConf.Bindings[0].ProxyAllowed = []string{"invalid ip/network"} - err = httpdConf.Initialize(configDir) + err = httpdConf.Initialize(configDir, isShared) if assert.Error(t, err) { assert.Contains(t, err.Error(), "is not a valid IP range") } @@ -483,7 +484,7 @@ func TestInitialization(t *testing.T) { httpdConf.Bindings[0].Port = 8081 httpdConf.Bindings[0].EnableHTTPS = true httpdConf.Bindings[0].ClientAuthType = 1 - err = httpdConf.Initialize(configDir) + err = httpdConf.Initialize(configDir, 0) assert.Error(t, err) httpdConf.Bindings[0].OIDC = httpd.OIDC{ @@ -491,12 +492,12 @@ func TestInitialization(t *testing.T) { ClientSecret: "secret", ConfigURL: "http://127.0.0.1:11111", } - err = httpdConf.Initialize(configDir) + err = httpdConf.Initialize(configDir, 0) if assert.Error(t, err) { assert.Contains(t, err.Error(), "oidc") } httpdConf.Bindings[0].OIDC.UsernameField = "preferred_username" - err = httpdConf.Initialize(configDir) + err = httpdConf.Initialize(configDir, isShared) if assert.Error(t, err) { assert.Contains(t, err.Error(), "oidc") } @@ -516,7 +517,7 @@ func TestBasicUserHandling(t *testing.T) { user.AdditionalInfo = "some free text" user.Filters.TLSUsername = sdk.TLSUsernameCN user.Email = "user@example.net" - user.OIDCCustomFields = &map[string]interface{}{ + user.OIDCCustomFields = &map[string]any{ "field1": "value1", } user.Filters.WebClient = append(user.Filters.WebClient, sdk.WebClientPubKeyChangeDisabled, @@ -1301,7 +1302,7 @@ func TestHTTPUserAuthentication(t *testing.T) { c.CloseIdleConnections() assert.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) - responseHolder := make(map[string]interface{}) + responseHolder := make(map[string]any) err = render.DecodeJSON(resp.Body, &responseHolder) assert.NoError(t, err) userToken := responseHolder["access_token"].(string) @@ -1356,7 +1357,7 @@ func TestHTTPUserAuthentication(t *testing.T) { resp, err = httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) - responseHolder = make(map[string]interface{}) + responseHolder = make(map[string]any) err = render.DecodeJSON(resp.Body, &responseHolder) assert.NoError(t, err) adminToken := responseHolder["access_token"].(string) @@ -1571,7 +1572,7 @@ func TestTwoFactorRequirements(t *testing.T) { resp, err := httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) - responseHolder := make(map[string]interface{}) + responseHolder := make(map[string]any) err = render.DecodeJSON(resp.Body, &responseHolder) assert.NoError(t, err) userToken := responseHolder["access_token"].(string) @@ -1620,7 +1621,7 @@ func TestLoginUserAPITOTP(t *testing.T) { user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) // two factor auth cannot be disabled - config := make(map[string]interface{}) + config := make(map[string]any) config["enabled"] = false asJSON, err = json.Marshal(config) assert.NoError(t, err) @@ -1667,7 +1668,7 @@ func TestLoginUserAPITOTP(t *testing.T) { resp, err = httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) - responseHolder := make(map[string]interface{}) + responseHolder := make(map[string]any) err = render.DecodeJSON(resp.Body, &responseHolder) assert.NoError(t, err) userToken := responseHolder["access_token"].(string) @@ -1743,7 +1744,7 @@ func TestLoginAdminAPITOTP(t *testing.T) { resp, err = httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) - responseHolder := make(map[string]interface{}) + responseHolder := make(map[string]any) err = render.DecodeJSON(resp.Body, &responseHolder) assert.NoError(t, err) adminToken := responseHolder["access_token"].(string) @@ -1774,7 +1775,7 @@ func TestHTTPStreamZipError(t *testing.T) { resp, err := httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) - responseHolder := make(map[string]interface{}) + responseHolder := make(map[string]any) err = render.DecodeJSON(resp.Body, &responseHolder) assert.NoError(t, err) userToken := responseHolder["access_token"].(string) @@ -1968,7 +1969,7 @@ func TestAdminInvalidCredentials(t *testing.T) { resp, err = httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - responseHolder := make(map[string]interface{}) + responseHolder := make(map[string]any) err = render.DecodeJSON(resp.Body, &responseHolder) assert.NoError(t, err) err = resp.Body.Close() @@ -1979,7 +1980,7 @@ func TestAdminInvalidCredentials(t *testing.T) { resp, err = httpclient.GetHTTPClient().Do(req) assert.NoError(t, err) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - responseHolder = make(map[string]interface{}) + responseHolder = make(map[string]any) err = render.DecodeJSON(resp.Body, &responseHolder) assert.NoError(t, err) err = resp.Body.Close() @@ -2467,7 +2468,7 @@ func TestMetadataAPI(t *testing.T) { setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - var resp []interface{} + var resp []any err = json.Unmarshal(rr.Body.Bytes(), &resp) assert.NoError(t, err) assert.Len(t, resp, 0) @@ -2484,7 +2485,7 @@ func TestMetadataAPI(t *testing.T) { setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - var resp []interface{} + var resp []any err = json.Unmarshal(rr.Body.Bytes(), &resp) assert.NoError(t, err) return len(resp) == 0 @@ -2744,7 +2745,7 @@ func TestUpdateUserEmptyPassword(t *testing.T) { assert.NotEmpty(t, dbUser.Password) assert.True(t, dbUser.IsPasswordHashed()) // now update the user and set an empty password - customUser := make(map[string]interface{}) + customUser := make(map[string]any) customUser["password"] = "" asJSON, err := json.Marshal(customUser) assert.NoError(t, err) @@ -5775,7 +5776,7 @@ func TestBasicUserHandlingMock(t *testing.T) { assert.Equal(t, user.MaxSessions, updatedUser.MaxSessions) assert.Equal(t, user.UploadBandwidth, updatedUser.UploadBandwidth) assert.Equal(t, 1, len(updatedUser.Permissions["/"])) - assert.True(t, util.IsStringInSlice(dataprovider.PermAny, updatedUser.Permissions["/"])) + assert.True(t, util.Contains(updatedUser.Permissions["/"], dataprovider.PermAny)) req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+user.Username, nil) setBearerForReq(req, token) rr = executeRequest(req) @@ -6859,7 +6860,7 @@ func TestSearchEvents(t *testing.T) { setBearerForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - events := make([]map[string]interface{}, 0) + events := make([]map[string]any, 0) err = json.Unmarshal(rr.Body.Bytes(), &events) assert.NoError(t, err) if assert.Len(t, events, 1) { @@ -6889,7 +6890,7 @@ func TestSearchEvents(t *testing.T) { setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - events = make([]map[string]interface{}, 0) + events = make([]map[string]any, 0) err = json.Unmarshal(rr.Body.Bytes(), &events) assert.NoError(t, err) if assert.Len(t, events, 1) { @@ -7467,7 +7468,7 @@ func TestWebAPIChangeUserProfileMock(t *testing.T) { email := "userapi@example.com" description := "user API description" - profileReq := make(map[string]interface{}) + profileReq := make(map[string]any) profileReq["allow_api_key_auth"] = true profileReq["email"] = email profileReq["description"] = description @@ -7480,7 +7481,7 @@ func TestWebAPIChangeUserProfileMock(t *testing.T) { rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - profileReq = make(map[string]interface{}) + profileReq = make(map[string]any) req, err = http.NewRequest(http.MethodGet, userProfilePath, nil) assert.NoError(t, err) setBearerForReq(req, token) @@ -7491,9 +7492,9 @@ func TestWebAPIChangeUserProfileMock(t *testing.T) { assert.Equal(t, email, profileReq["email"].(string)) assert.Equal(t, description, profileReq["description"].(string)) assert.True(t, profileReq["allow_api_key_auth"].(bool)) - assert.Len(t, profileReq["public_keys"].([]interface{}), 2) + assert.Len(t, profileReq["public_keys"].([]any), 2) // set an invalid email - profileReq = make(map[string]interface{}) + profileReq = make(map[string]any) profileReq["email"] = "notavalidemail" asJSON, err = json.Marshal(profileReq) assert.NoError(t, err) @@ -7504,7 +7505,7 @@ func TestWebAPIChangeUserProfileMock(t *testing.T) { checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "Validation error: email") // set an invalid public key - profileReq = make(map[string]interface{}) + profileReq = make(map[string]any) profileReq["public_keys"] = []string{"not a public key"} asJSON, err = json.Marshal(profileReq) assert.NoError(t, err) @@ -7524,7 +7525,7 @@ func TestWebAPIChangeUserProfileMock(t *testing.T) { token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) - profileReq = make(map[string]interface{}) + profileReq = make(map[string]any) profileReq["allow_api_key_auth"] = false profileReq["email"] = email profileReq["description"] = description + "_mod" @@ -7538,7 +7539,7 @@ func TestWebAPIChangeUserProfileMock(t *testing.T) { checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), "Profile updated") // check that api key auth and public keys were not changed - profileReq = make(map[string]interface{}) + profileReq = make(map[string]any) req, err = http.NewRequest(http.MethodGet, userProfilePath, nil) assert.NoError(t, err) setBearerForReq(req, token) @@ -7549,7 +7550,7 @@ func TestWebAPIChangeUserProfileMock(t *testing.T) { assert.Equal(t, email, profileReq["email"].(string)) assert.Equal(t, description+"_mod", profileReq["description"].(string)) assert.True(t, profileReq["allow_api_key_auth"].(bool)) - assert.Len(t, profileReq["public_keys"].([]interface{}), 2) + assert.Len(t, profileReq["public_keys"].([]any), 2) user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled, sdk.WebClientInfoChangeDisabled} user.Description = description + "_mod" @@ -7558,7 +7559,7 @@ func TestWebAPIChangeUserProfileMock(t *testing.T) { token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) assert.NoError(t, err) - profileReq = make(map[string]interface{}) + profileReq = make(map[string]any) profileReq["allow_api_key_auth"] = false profileReq["email"] = "newemail@apiuser.com" profileReq["description"] = description @@ -7570,7 +7571,7 @@ func TestWebAPIChangeUserProfileMock(t *testing.T) { setBearerForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - profileReq = make(map[string]interface{}) + profileReq = make(map[string]any) req, err = http.NewRequest(http.MethodGet, userProfilePath, nil) assert.NoError(t, err) setBearerForReq(req, token) @@ -7581,7 +7582,7 @@ func TestWebAPIChangeUserProfileMock(t *testing.T) { assert.Equal(t, email, profileReq["email"].(string)) assert.Equal(t, description+"_mod", profileReq["description"].(string)) assert.True(t, profileReq["allow_api_key_auth"].(bool)) - assert.Len(t, profileReq["public_keys"].([]interface{}), 1) + assert.Len(t, profileReq["public_keys"].([]any), 1) // finally disable all profile permissions user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled, sdk.WebClientInfoChangeDisabled, sdk.WebClientPubKeyChangeDisabled} @@ -7765,7 +7766,7 @@ func TestWebAPIChangeAdminProfileMock(t *testing.T) { email := "adminapi@example.com" description := "admin API description" - profileReq := make(map[string]interface{}) + profileReq := make(map[string]any) profileReq["allow_api_key_auth"] = true profileReq["email"] = email profileReq["description"] = description @@ -7778,7 +7779,7 @@ func TestWebAPIChangeAdminProfileMock(t *testing.T) { checkResponseCode(t, http.StatusOK, rr) assert.Contains(t, rr.Body.String(), "Profile updated") - profileReq = make(map[string]interface{}) + profileReq = make(map[string]any) req, err = http.NewRequest(http.MethodGet, adminProfilePath, nil) assert.NoError(t, err) setBearerForReq(req, token) @@ -8187,7 +8188,7 @@ func TestUpdateUserMock(t *testing.T) { for dir, perms := range permissions { if actualPerms, ok := updatedUser.Permissions[dir]; ok { for _, v := range actualPerms { - assert.True(t, util.IsStringInSlice(v, perms)) + assert.True(t, util.Contains(perms, v)) } } else { assert.Fail(t, "Permissions directories mismatch") @@ -8346,7 +8347,7 @@ func TestUserPermissionsMock(t *testing.T) { err = render.DecodeJSON(rr.Body, &updatedUser) assert.NoError(t, err) if val, ok := updatedUser.Permissions["/otherdir"]; ok { - assert.True(t, util.IsStringInSlice(dataprovider.PermListItems, val)) + assert.True(t, util.Contains(val, dataprovider.PermListItems)) assert.Equal(t, 1, len(val)) } else { assert.Fail(t, "expected dir not found in permissions") @@ -10015,7 +10016,7 @@ func TestShareUsage(t *testing.T) { checkResponseCode(t, http.StatusNotFound, rr) share.ExpiresAt = 0 - jsonReq := make(map[string]interface{}) + jsonReq := make(map[string]any) jsonReq["name"] = share.Name jsonReq["scope"] = share.Scope jsonReq["paths"] = share.Paths @@ -10722,7 +10723,7 @@ func TestBrowseShares(t *testing.T) { assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - contents := make([]map[string]interface{}, 0) + contents := make([]map[string]any, 0) err = json.Unmarshal(rr.Body.Bytes(), &contents) assert.NoError(t, err) assert.Len(t, contents, 2) @@ -10731,7 +10732,7 @@ func TestBrowseShares(t *testing.T) { assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - contents = make([]map[string]interface{}, 0) + contents = make([]map[string]any, 0) err = json.Unmarshal(rr.Body.Bytes(), &contents) assert.NoError(t, err) assert.Len(t, contents, 2) @@ -10740,7 +10741,7 @@ func TestBrowseShares(t *testing.T) { assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - contents = make([]map[string]interface{}, 0) + contents = make([]map[string]any, 0) err = json.Unmarshal(rr.Body.Bytes(), &contents) assert.NoError(t, err) assert.Len(t, contents, 1) @@ -10955,7 +10956,7 @@ func TestBrowseShares(t *testing.T) { assert.NoError(t, err) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - contents = make([]map[string]interface{}, 0) + contents = make([]map[string]any, 0) err = json.Unmarshal(rr.Body.Bytes(), &contents) assert.NoError(t, err) assert.Len(t, contents, 1) @@ -11437,7 +11438,7 @@ func TestUserAPIKey(t *testing.T) { setAPIKeyForReq(req, apiKey.Key, "") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - var dirEntries []map[string]interface{} + var dirEntries []map[string]any err = json.Unmarshal(rr.Body.Bytes(), &dirEntries) assert.NoError(t, err) assert.Len(t, dirEntries, 1) @@ -11668,7 +11669,7 @@ func TestWebGetFiles(t *testing.T) { setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - var dirEntries []map[string]interface{} + var dirEntries []map[string]any err = json.Unmarshal(rr.Body.Bytes(), &dirEntries) assert.NoError(t, err) assert.Len(t, dirEntries, 1) @@ -11918,7 +11919,7 @@ func TestWebDirsAPI(t *testing.T) { setBearerForReq(req, webAPIToken) rr := executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - var contents []map[string]interface{} + var contents []map[string]any err = json.NewDecoder(rr.Body).Decode(&contents) assert.NoError(t, err) assert.Len(t, contents, 0) @@ -12197,7 +12198,7 @@ func TestWebFilesAPI(t *testing.T) { setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - var contents []map[string]interface{} + var contents []map[string]any err = json.NewDecoder(rr.Body).Decode(&contents) assert.NoError(t, err) assert.Len(t, contents, 2) @@ -12398,7 +12399,7 @@ func TestStartDirectory(t *testing.T) { setBearerForReq(req, webAPIToken) rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - var contents []map[string]interface{} + var contents []map[string]any err = json.NewDecoder(rr.Body).Decode(&contents) assert.NoError(t, err) if assert.Len(t, contents, 1) { @@ -14628,7 +14629,7 @@ func TestAPIKeyOnDeleteCascade(t *testing.T) { setAPIKeyForReq(req, apiKey.Key, "") rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - var contents []map[string]interface{} + var contents []map[string]any err = json.NewDecoder(rr.Body).Decode(&contents) assert.NoError(t, err) assert.Len(t, contents, 0) @@ -15570,10 +15571,10 @@ func TestWebUserAddMock(t *testing.T) { assert.False(t, newUser.Filters.AllowAPIKeyAuth) assert.Equal(t, user.Email, newUser.Email) assert.Equal(t, "/start/dir", newUser.Filters.StartDirectory) - assert.True(t, util.IsStringInSlice(testPubKey, newUser.PublicKeys)) + assert.True(t, util.Contains(newUser.PublicKeys, testPubKey)) if val, ok := newUser.Permissions["/subdir"]; ok { - assert.True(t, util.IsStringInSlice(dataprovider.PermListItems, val)) - assert.True(t, util.IsStringInSlice(dataprovider.PermDownload, val)) + assert.True(t, util.Contains(val, dataprovider.PermListItems)) + assert.True(t, util.Contains(val, dataprovider.PermDownload)) } else { assert.Fail(t, "user permissions must contain /somedir", "actual: %v", newUser.Permissions) } @@ -15592,20 +15593,20 @@ func TestWebUserAddMock(t *testing.T) { case "/dir1": assert.Len(t, filter.DeniedPatterns, 1) assert.Len(t, filter.AllowedPatterns, 1) - assert.True(t, util.IsStringInSlice("*.png", filter.AllowedPatterns)) - assert.True(t, util.IsStringInSlice("*.zip", filter.DeniedPatterns)) + assert.True(t, util.Contains(filter.AllowedPatterns, "*.png")) + assert.True(t, util.Contains(filter.DeniedPatterns, "*.zip")) assert.Equal(t, sdk.DenyPolicyDefault, filter.DenyPolicy) case "/dir2": assert.Len(t, filter.DeniedPatterns, 1) assert.Len(t, filter.AllowedPatterns, 2) - assert.True(t, util.IsStringInSlice("*.jpg", filter.AllowedPatterns)) - assert.True(t, util.IsStringInSlice("*.png", filter.AllowedPatterns)) - assert.True(t, util.IsStringInSlice("*.mkv", filter.DeniedPatterns)) + assert.True(t, util.Contains(filter.AllowedPatterns, "*.jpg")) + assert.True(t, util.Contains(filter.AllowedPatterns, "*.png")) + assert.True(t, util.Contains(filter.DeniedPatterns, "*.mkv")) assert.Equal(t, sdk.DenyPolicyHide, filter.DenyPolicy) case "/dir3": assert.Len(t, filter.DeniedPatterns, 1) assert.Len(t, filter.AllowedPatterns, 0) - assert.True(t, util.IsStringInSlice("*.rar", filter.DeniedPatterns)) + assert.True(t, util.Contains(filter.DeniedPatterns, "*.rar")) assert.Equal(t, sdk.DenyPolicyDefault, filter.DenyPolicy) } } @@ -15845,16 +15846,16 @@ func TestWebUserUpdateMock(t *testing.T) { assert.Equal(t, int64(0), updateUser.UploadDataTransfer) assert.Equal(t, int64(0), updateUser.Filters.ExternalAuthCacheTime) if val, ok := updateUser.Permissions["/otherdir"]; ok { - assert.True(t, util.IsStringInSlice(dataprovider.PermListItems, val)) - assert.True(t, util.IsStringInSlice(dataprovider.PermUpload, val)) + assert.True(t, util.Contains(val, dataprovider.PermListItems)) + assert.True(t, util.Contains(val, dataprovider.PermUpload)) } else { assert.Fail(t, "user permissions must contains /otherdir", "actual: %v", updateUser.Permissions) } - assert.True(t, util.IsStringInSlice("192.168.1.3/32", updateUser.Filters.AllowedIP)) - assert.True(t, util.IsStringInSlice("10.0.0.2/32", updateUser.Filters.DeniedIP)) - assert.True(t, util.IsStringInSlice(dataprovider.SSHLoginMethodKeyboardInteractive, updateUser.Filters.DeniedLoginMethods)) - assert.True(t, util.IsStringInSlice(common.ProtocolFTP, updateUser.Filters.DeniedProtocols)) - assert.True(t, util.IsStringInSlice("*.zip", updateUser.Filters.FilePatterns[0].DeniedPatterns)) + assert.True(t, util.Contains(updateUser.Filters.AllowedIP, "192.168.1.3/32")) + assert.True(t, util.Contains(updateUser.Filters.DeniedIP, "10.0.0.2/32")) + assert.True(t, util.Contains(updateUser.Filters.DeniedLoginMethods, dataprovider.SSHLoginMethodKeyboardInteractive)) + assert.True(t, util.Contains(updateUser.Filters.DeniedProtocols, common.ProtocolFTP)) + assert.True(t, util.Contains(updateUser.Filters.FilePatterns[0].DeniedPatterns, "*.zip")) assert.Len(t, updateUser.Filters.BandwidthLimits, 0) req, err = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) assert.NoError(t, err) @@ -18864,7 +18865,7 @@ func getJWTAPITokenFromTestServer(username, password string) (string, error) { if rr.Code != http.StatusOK { return "", fmt.Errorf("unexpected status code %v", rr.Code) } - responseHolder := make(map[string]interface{}) + responseHolder := make(map[string]any) err := render.DecodeJSON(rr.Body, &responseHolder) if err != nil { return "", err @@ -18879,7 +18880,7 @@ func getJWTAPIUserTokenFromTestServer(username, password string) (string, error) if rr.Code != http.StatusOK { return "", fmt.Errorf("unexpected status code %v", rr.Code) } - responseHolder := make(map[string]interface{}) + responseHolder := make(map[string]any) err := render.DecodeJSON(rr.Body, &responseHolder) if err != nil { return "", err diff --git a/httpd/internal_test.go b/httpd/internal_test.go index db0a7519..b8c97a94 100644 --- a/httpd/internal_test.go +++ b/httpd/internal_test.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "database/sql" "encoding/json" "errors" "fmt" @@ -728,7 +729,7 @@ func TestCSRFToken(t *testing.T) { assert.Contains(t, err.Error(), "unable to verify form token") } // bad audience - claims := make(map[string]interface{}) + claims := make(map[string]any) now := time.Now().UTC() claims[jwt.JwtIDKey] = xid.New().String() @@ -1008,7 +1009,7 @@ func TestAPIKeyAuthForbidden(t *testing.T) { func TestJWTTokenValidation(t *testing.T) { tokenAuth := jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil) - claims := make(map[string]interface{}) + claims := make(map[string]any) claims["username"] = defaultAdminUsername claims[jwt.ExpirationKey] = time.Now().UTC().Add(-1 * time.Hour) token, _, err := tokenAuth.Encode(claims) @@ -1103,7 +1104,7 @@ func TestUpdateContextFromCookie(t *testing.T) { tokenAuth: jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil), } req, _ := http.NewRequest(http.MethodGet, tokenPath, nil) - claims := make(map[string]interface{}) + claims := make(map[string]any) claims["a"] = "b" token, _, err := server.tokenAuth.Encode(claims) assert.NoError(t, err) @@ -1125,7 +1126,7 @@ func TestCookieExpiration(t *testing.T) { assert.Empty(t, cookie) req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) - claims := make(map[string]interface{}) + claims := make(map[string]any) claims["a"] = "b" token, _, err := server.tokenAuth.Encode(claims) assert.NoError(t, err) @@ -1139,7 +1140,7 @@ func TestCookieExpiration(t *testing.T) { Password: "password", Permissions: []string{dataprovider.PermAdminAny}, } - claims = make(map[string]interface{}) + claims = make(map[string]any) claims[claimUsernameKey] = admin.Username claims[claimPermissionsKey] = admin.Permissions claims[jwt.SubjectKey] = admin.GetSignature() @@ -1174,7 +1175,7 @@ func TestCookieExpiration(t *testing.T) { admin, err = dataprovider.AdminExists(admin.Username) assert.NoError(t, err) - claims = make(map[string]interface{}) + claims = make(map[string]any) claims[claimUsernameKey] = admin.Username claims[claimPermissionsKey] = admin.Permissions claims[jwt.SubjectKey] = admin.GetSignature() @@ -1212,7 +1213,7 @@ func TestCookieExpiration(t *testing.T) { user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{"*"} - claims = make(map[string]interface{}) + claims = make(map[string]any) claims[claimUsernameKey] = user.Username claims[claimPermissionsKey] = user.Filters.WebClient claims[jwt.SubjectKey] = user.GetSignature() @@ -1244,7 +1245,7 @@ func TestCookieExpiration(t *testing.T) { user, err = dataprovider.UserExists(user.Username) assert.NoError(t, err) - claims = make(map[string]interface{}) + claims = make(map[string]any) claims[claimUsernameKey] = user.Username claims[claimPermissionsKey] = user.Filters.WebClient claims[jwt.SubjectKey] = user.GetSignature() @@ -1502,7 +1503,7 @@ func TestJWTTokenCleanup(t *testing.T) { Password: "password", Permissions: []string{dataprovider.PermAdminAny}, } - claims := make(map[string]interface{}) + claims := make(map[string]any) claims[claimUsernameKey] = admin.Username claims[claimPermissionsKey] = admin.Permissions claims[jwt.SubjectKey] = admin.GetSignature() @@ -2234,10 +2235,11 @@ func TestLoginLinks(t *testing.T) { func TestResetCodesCleanup(t *testing.T) { resetCode := newResetCode(util.GenerateUniqueID(), false) resetCode.ExpiresAt = time.Now().Add(-1 * time.Minute).UTC() - resetCodes.Store(resetCode.Code, resetCode) - cleanupExpiredResetCodes() - _, ok := resetCodes.Load(resetCode.Code) - assert.False(t, ok) + err := resetCodesMgr.Add(resetCode) + assert.NoError(t, err) + resetCodesMgr.Cleanup() + _, err = resetCodesMgr.Get(resetCode.Code) + assert.Error(t, err) } func TestUserCanResetPassword(t *testing.T) { @@ -2556,3 +2558,55 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) { installationCode = "" SetInstallationCodeResolver(nil) } + +func TestDbResetCodeManager(t *testing.T) { + if !isSharedProviderSupported() { + t.Skip("this test it is not available with this provider") + } + mgr := newResetCodeManager(1) + resetCode := newResetCode("admin", true) + err := mgr.Add(resetCode) + assert.NoError(t, err) + codeGet, err := mgr.Get(resetCode.Code) + assert.NoError(t, err) + assert.Equal(t, resetCode, codeGet) + err = mgr.Delete(resetCode.Code) + assert.NoError(t, err) + err = mgr.Delete(resetCode.Code) + if assert.Error(t, err) { + _, ok := err.(*util.RecordNotFoundError) + assert.True(t, ok) + } + _, err = mgr.Get(resetCode.Code) + assert.ErrorIs(t, err, sql.ErrNoRows) + // add an expired reset code + resetCode = newResetCode("user", false) + resetCode.ExpiresAt = time.Now().Add(-24 * time.Hour) + err = mgr.Add(resetCode) + assert.NoError(t, err) + _, err = mgr.Get(resetCode.Code) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "reset code expired") + } + mgr.Cleanup() + _, err = mgr.Get(resetCode.Code) + assert.ErrorIs(t, err, sql.ErrNoRows) + + dbMgr, ok := mgr.(*dbResetCodeManager) + if assert.True(t, ok) { + _, err = dbMgr.decodeData("astring") + assert.Error(t, err) + } +} + +func isSharedProviderSupported() bool { + // SQLite shares the implementation with other SQL-based provider but it makes no sense + // to use it outside test cases + switch dataprovider.GetProviderStatus().Driver { + case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName, + dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName: + return true + default: + return false + } +} diff --git a/httpd/middleware.go b/httpd/middleware.go index bef53977..a9882e05 100644 --- a/httpd/middleware.go +++ b/httpd/middleware.go @@ -71,13 +71,13 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi if err := checkPartialAuth(w, r, audience, token.Audience()); err != nil { return err } - if !util.IsStringInSlice(audience, token.Audience()) { + if !util.Contains(token.Audience(), audience) { logger.Debug(logSender, "", "the token is not valid for audience %#v", audience) doRedirect("Your token audience is not valid", nil) return errInvalidToken } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) - if !util.IsStringInSlice(ipAddr, token.Audience()) { + if !util.Contains(token.Audience(), ipAddr) { logger.Debug(logSender, "", "the token with id %#v is not valid for the ip address %#v", token.JwtID(), ipAddr) doRedirect("Your token is not valid", nil) return errInvalidToken @@ -101,7 +101,7 @@ func (s *httpdServer) validateJWTPartialToken(w http.ResponseWriter, r *http.Req notFoundFunc(w, r, nil) return errInvalidToken } - if !util.IsStringInSlice(audience, token.Audience()) { + if !util.Contains(token.Audience(), audience) { logger.Debug(logSender, "", "the token is not valid for audience %#v", audience) notFoundFunc(w, r, nil) return errInvalidToken @@ -277,13 +277,13 @@ func verifyCSRFHeader(next http.Handler) http.Handler { return } - if !util.IsStringInSlice(tokenAudienceCSRF, token.Audience()) { + if !util.Contains(token.Audience(), tokenAudienceCSRF) { logger.Debug(logSender, "", "error validating CSRF header token audience") sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden) return } - if !util.IsStringInSlice(util.GetIPFromRemoteAddress(r.RemoteAddr), token.Audience()) { + if !util.Contains(token.Audience(), util.GetIPFromRemoteAddress(r.RemoteAddr)) { logger.Debug(logSender, "", "error validating CSRF header IP audience") sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden) return @@ -464,11 +464,11 @@ func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAu } func checkPartialAuth(w http.ResponseWriter, r *http.Request, audience string, tokenAudience []string) error { - if audience == tokenAudienceWebAdmin && util.IsStringInSlice(tokenAudienceWebAdminPartial, tokenAudience) { + if audience == tokenAudienceWebAdmin && util.Contains(tokenAudience, tokenAudienceWebAdminPartial) { http.Redirect(w, r, webAdminTwoFactorPath, http.StatusFound) return errInvalidToken } - if audience == tokenAudienceWebClient && util.IsStringInSlice(tokenAudienceWebClientPartial, tokenAudience) { + if audience == tokenAudienceWebClient && util.Contains(tokenAudience, tokenAudienceWebClientPartial) { http.Redirect(w, r, webClientTwoFactorPath, http.StatusFound) return errInvalidToken } diff --git a/httpd/oidc.go b/httpd/oidc.go index 4b6dbf20..eec7bd35 100644 --- a/httpd/oidc.go +++ b/httpd/oidc.go @@ -7,7 +7,6 @@ import ( "net/http" "net/url" "strings" - "sync" "time" "github.com/coreos/go-oidc/v3/oidc" @@ -31,17 +30,8 @@ const ( var ( oidcTokenKey = &contextKey{"OIDC token key"} oidcGeneratedToken = &contextKey{"OIDC generated token"} - oidcMgr *oidcManager ) -func init() { - oidcMgr = &oidcManager{ - pendingAuths: make(map[string]oidcPendingAuth), - tokens: make(map[string]oidcToken), - lastCleanup: time.Now(), - } -} - // OAuth2Config defines an interface for OAuth2 methods, so we can mock them type OAuth2Config interface { AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string @@ -119,7 +109,7 @@ func (o *OIDC) initialize() error { if err != nil { return fmt.Errorf("oidc: unable to initialize provider for URL %#v: %w", o.ConfigURL, err) } - claims := make(map[string]interface{}) + claims := make(map[string]any) // we cannot get an error here because the response body was already parsed as JSON // on provider creation provider.Claims(&claims) //nolint:errcheck @@ -146,10 +136,10 @@ func (o *OIDC) initialize() error { } type oidcPendingAuth struct { - State string - Nonce string - Audience tokenAudience - IssueAt int64 + State string `json:"state"` + Nonce string `json:"nonce"` + Audience tokenAudience `json:"audience"` + IssuedAt int64 `json:"issued_at"` } func newOIDCPendingAuth(audience tokenAudience) oidcPendingAuth { @@ -157,27 +147,27 @@ func newOIDCPendingAuth(audience tokenAudience) oidcPendingAuth { State: xid.New().String(), Nonce: xid.New().String(), Audience: audience, - IssueAt: util.GetTimeAsMsSinceEpoch(time.Now()), + IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now()), } } type oidcToken struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type,omitempty"` - RefreshToken string `json:"refresh_token,omitempty"` - ExpiresAt int64 `json:"expires_at,omitempty"` - SessionID string `json:"session_id"` - IDToken string `json:"id_token"` - Nonce string `json:"nonce"` - Username string `json:"username"` - Permissions []string `json:"permissions"` - Role interface{} `json:"role"` - CustomFields *map[string]interface{} `json:"custom_fields,omitempty"` - Cookie string `json:"cookie"` - UsedAt int64 `json:"used_at"` + AccessToken string `json:"access_token"` + TokenType string `json:"token_type,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresAt int64 `json:"expires_at,omitempty"` + SessionID string `json:"session_id"` + IDToken string `json:"id_token"` + Nonce string `json:"nonce"` + Username string `json:"username"` + Permissions []string `json:"permissions"` + Role any `json:"role"` + CustomFields *map[string]any `json:"custom_fields,omitempty"` + Cookie string `json:"cookie"` + UsedAt int64 `json:"used_at"` } -func (t *oidcToken) parseClaims(claims map[string]interface{}, usernameField, roleField string, customFields []string) error { +func (t *oidcToken) parseClaims(claims map[string]any, usernameField, roleField string, customFields []string) error { getClaimsFields := func() []string { keys := make([]string, 0, len(claims)) for k := range claims { @@ -203,7 +193,7 @@ func (t *oidcToken) parseClaims(claims map[string]interface{}, usernameField, ro for _, field := range customFields { if val, ok := claims[field]; ok { if t.CustomFields == nil { - customFields := make(map[string]interface{}) + customFields := make(map[string]any) t.CustomFields = &customFields } logger.Debug(logSender, "", "custom field %#v found in token claims", field) @@ -224,7 +214,7 @@ func (t *oidcToken) isAdmin() bool { switch v := t.Role.(type) { case string: return v == "admin" - case []interface{}: + case []any: for _, s := range v { if val, ok := s.(string); ok && val == "admin" { return true @@ -288,7 +278,7 @@ func (t *oidcToken) refresh(config OAuth2Config, verifier OIDCTokenVerifier) err logger.Debug(logSender, "", "unable to verify refreshed id token for cookie %#v: nonce mismatch", t.Cookie) return errors.New("the refreshed token nonce mismatch") } - claims := make(map[string]interface{}) + claims := make(map[string]any) err = idToken.Claims(&claims) if err != nil { logger.Debug(logSender, "", "unable to get refreshed id token claims for cookie %#v: %v", t.Cookie, err) @@ -348,119 +338,6 @@ func (t *oidcToken) getUser(r *http.Request) error { return nil } -type oidcManager struct { - authMutex sync.RWMutex - pendingAuths map[string]oidcPendingAuth - tokenMutex sync.RWMutex - tokens map[string]oidcToken - lastCleanup time.Time -} - -func (o *oidcManager) addPendingAuth(pendingAuth oidcPendingAuth) { - o.authMutex.Lock() - o.pendingAuths[pendingAuth.State] = pendingAuth - o.authMutex.Unlock() - - o.checkCleanup() -} - -func (o *oidcManager) removePendingAuth(key string) { - o.authMutex.Lock() - defer o.authMutex.Unlock() - - delete(o.pendingAuths, key) -} - -func (o *oidcManager) getPendingAuth(state string) (oidcPendingAuth, error) { - o.authMutex.RLock() - defer o.authMutex.RUnlock() - - authReq, ok := o.pendingAuths[state] - if !ok { - return oidcPendingAuth{}, errors.New("oidc: no auth request found for the specified state") - } - diff := util.GetTimeAsMsSinceEpoch(time.Now()) - authReq.IssueAt - if diff > authStateValidity { - return oidcPendingAuth{}, errors.New("oidc: auth request is too old") - } - return authReq, nil -} - -func (o *oidcManager) addToken(token oidcToken) { - o.tokenMutex.Lock() - token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now()) - o.tokens[token.Cookie] = token - o.tokenMutex.Unlock() - - o.checkCleanup() -} - -func (o *oidcManager) getToken(cookie string) (oidcToken, error) { - o.tokenMutex.RLock() - defer o.tokenMutex.RUnlock() - - token, ok := o.tokens[cookie] - if !ok { - return oidcToken{}, errors.New("oidc: no token found for the specified session") - } - return token, nil -} - -func (o *oidcManager) removeToken(cookie string) { - o.tokenMutex.Lock() - defer o.tokenMutex.Unlock() - - delete(o.tokens, cookie) -} - -func (o *oidcManager) updateTokenUsage(token oidcToken) { - diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt - if diff > tokenUpdateInterval { - o.addToken(token) - } -} - -func (o *oidcManager) checkCleanup() { - o.authMutex.RLock() - needCleanup := o.lastCleanup.Add(20 * time.Minute).Before(time.Now()) - o.authMutex.RUnlock() - - if needCleanup { - o.authMutex.Lock() - o.lastCleanup = time.Now() - o.authMutex.Unlock() - - o.cleanupAuthRequests() - o.cleanupTokens() - } -} - -func (o *oidcManager) cleanupAuthRequests() { - o.authMutex.Lock() - defer o.authMutex.Unlock() - - for k, auth := range o.pendingAuths { - diff := util.GetTimeAsMsSinceEpoch(time.Now()) - auth.IssueAt - // remove old pending auth requests - if diff < 0 || diff > authStateValidity { - delete(o.pendingAuths, k) - } - } -} - -func (o *oidcManager) cleanupTokens() { - o.tokenMutex.Lock() - defer o.tokenMutex.Unlock() - - for k, token := range o.tokens { - diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt - // remove tokens unused from more than tokenDeleteInterval - if diff > tokenDeleteInterval { - delete(o.tokens, k) - } - } -} - func (s *httpdServer) validateOIDCToken(w http.ResponseWriter, r *http.Request, isAdmin bool) (oidcToken, error) { doRedirect := func() { removeOIDCCookie(w, r) @@ -614,7 +491,7 @@ func (s *httpdServer) handleOIDCRedirect(w http.ResponseWriter, r *http.Request) return } - claims := make(map[string]interface{}) + claims := make(map[string]any) err = idToken.Claims(&claims) if err != nil { logger.Debug(logSender, "", "unable to get oidc token claims: %v", err) diff --git a/httpd/oidc_test.go b/httpd/oidc_test.go index 5cabaccf..49832887 100644 --- a/httpd/oidc_test.go +++ b/httpd/oidc_test.go @@ -103,6 +103,8 @@ func TestOIDCInitialization(t *testing.T) { } func TestOIDCLoginLogout(t *testing.T) { + oidcMgr, ok := oidcMgr.(*memoryOIDCManager) + require.True(t, ok) server := getTestOIDCServer() err := server.binding.OIDC.initialize() assert.NoError(t, err) @@ -119,7 +121,7 @@ func TestOIDCLoginLogout(t *testing.T) { State: xid.New().String(), Nonce: xid.New().String(), Audience: tokenAudienceWebClient, - IssueAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)), + IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)), } oidcMgr.addPendingAuth(expiredAuthReq) rr = httptest.NewRecorder() @@ -209,7 +211,7 @@ func TestOIDCLoginLogout(t *testing.T) { AccessToken: "123", Expiry: time.Now().Add(5 * time.Minute), } - token = token.WithExtra(map[string]interface{}{ + token = token.WithExtra(map[string]any{ "id_token": "id_token_val", }) server.binding.OIDC.oauth2Config = &mockOAuth2Config{ @@ -502,6 +504,8 @@ func TestOIDCLoginLogout(t *testing.T) { } func TestOIDCRefreshToken(t *testing.T) { + oidcMgr, ok := oidcMgr.(*memoryOIDCManager) + require.True(t, ok) token := oidcToken{ Cookie: xid.New().String(), AccessToken: xid.New().String(), @@ -542,7 +546,7 @@ func TestOIDCRefreshToken(t *testing.T) { if assert.Error(t, err) { assert.Contains(t, err.Error(), "the refreshed token has no id token") } - newToken = newToken.WithExtra(map[string]interface{}{ + newToken = newToken.WithExtra(map[string]any{ "id_token": "id_token_val", }) newToken.Expiry = time.Time{} @@ -557,7 +561,7 @@ func TestOIDCRefreshToken(t *testing.T) { err = token.refresh(&config, &verifier) assert.ErrorIs(t, err, common.ErrGenericFailure) - newToken = newToken.WithExtra(map[string]interface{}{ + newToken = newToken.WithExtra(map[string]any{ "id_token": "id_token_val", }) newToken.Expiry = time.Now().Add(5 * time.Minute) @@ -597,6 +601,8 @@ func TestOIDCRefreshToken(t *testing.T) { } func TestValidateOIDCToken(t *testing.T) { + oidcMgr, ok := oidcMgr.(*memoryOIDCManager) + require.True(t, ok) server := getTestOIDCServer() err := server.binding.OIDC.initialize() assert.NoError(t, err) @@ -787,7 +793,9 @@ func TestOIDCToken(t *testing.T) { assert.NoError(t, err) } -func TestOIDCManager(t *testing.T) { +func TestMemoryOIDCManager(t *testing.T) { + oidcMgr, ok := oidcMgr.(*memoryOIDCManager) + require.True(t, ok) require.Len(t, oidcMgr.pendingAuths, 0) authReq := newOIDCPendingAuth(tokenAudienceWebAdmin) oidcMgr.addPendingAuth(authReq) @@ -796,19 +804,15 @@ func TestOIDCManager(t *testing.T) { assert.NoError(t, err) oidcMgr.removePendingAuth(authReq.State) require.Len(t, oidcMgr.pendingAuths, 0) - authReq.IssueAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-61 * time.Second)) + authReq.IssuedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-61 * time.Second)) oidcMgr.addPendingAuth(authReq) require.Len(t, oidcMgr.pendingAuths, 1) _, err = oidcMgr.getPendingAuth(authReq.State) if assert.Error(t, err) { assert.Contains(t, err.Error(), "too old") } - oidcMgr.checkCleanup() - require.Len(t, oidcMgr.pendingAuths, 1) - oidcMgr.lastCleanup = time.Now().Add(-1 * time.Hour) - oidcMgr.checkCleanup() + oidcMgr.cleanup() require.Len(t, oidcMgr.pendingAuths, 0) - assert.True(t, oidcMgr.lastCleanup.After(time.Now().Add(-10*time.Second))) token := oidcToken{ AccessToken: xid.New().String(), @@ -826,6 +830,7 @@ func TestOIDCManager(t *testing.T) { assert.Error(t, err) storedToken, err := oidcMgr.getToken(token.Cookie) assert.NoError(t, err) + token.UsedAt = 0 // ensure we don't modify the stored token assert.Greater(t, storedToken.UsedAt, int64(0)) token.UsedAt = storedToken.UsedAt assert.Equal(t, token, storedToken) @@ -848,6 +853,12 @@ func TestOIDCManager(t *testing.T) { assert.Greater(t, storedToken.UsedAt, usedAt) token.UsedAt = storedToken.UsedAt assert.Equal(t, token, storedToken) + storedToken.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now()) - tokenDeleteInterval - 1 + oidcMgr.tokens[token.Cookie] = storedToken + storedToken, err = oidcMgr.getToken(token.Cookie) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "token is too old") + } oidcMgr.removeToken(xid.New().String()) require.Len(t, oidcMgr.tokens, 1) oidcMgr.removeToken(token.Cookie) @@ -859,8 +870,8 @@ func TestOIDCManager(t *testing.T) { newToken := oidcToken{ Cookie: xid.New().String(), } - oidcMgr.lastCleanup = time.Now().Add(-1 * time.Hour) oidcMgr.addToken(newToken) + oidcMgr.cleanup() require.Len(t, oidcMgr.tokens, 1) _, err = oidcMgr.getToken(token.Cookie) assert.Error(t, err) @@ -874,6 +885,8 @@ func TestOIDCPreLoginHook(t *testing.T) { if runtime.GOOS == osWindows { t.Skip("this test is not available on Windows") } + oidcMgr, ok := oidcMgr.(*memoryOIDCManager) + require.True(t, ok) username := "test_oidc_user_prelogin" u := dataprovider.User{ BaseUser: sdk.BaseUser{ @@ -902,7 +915,7 @@ func TestOIDCPreLoginHook(t *testing.T) { server.initializeRouter() _, err = dataprovider.UserExists(username) - _, ok := err.(*util.RecordNotFoundError) + _, ok = err.(*util.RecordNotFoundError) assert.True(t, ok) // now login with OIDC authReq := newOIDCPendingAuth(tokenAudienceWebClient) @@ -911,7 +924,7 @@ func TestOIDCPreLoginHook(t *testing.T) { AccessToken: "1234", Expiry: time.Now().Add(5 * time.Minute), } - token = token.WithExtra(map[string]interface{}{ + token = token.WithExtra(map[string]any{ "id_token": "id_token_val", }) server.binding.OIDC.oauth2Config = &mockOAuth2Config{ @@ -979,6 +992,143 @@ func TestOIDCPreLoginHook(t *testing.T) { assert.NoError(t, err) } +func TestOIDCIsAdmin(t *testing.T) { + type test struct { + input any + want bool + } + + emptySlice := make([]any, 0) + + tests := []test{ + {input: "admin", want: true}, + {input: append(emptySlice, "admin"), want: true}, + {input: append(emptySlice, "user", "admin"), want: true}, + {input: "user", want: false}, + {input: emptySlice, want: false}, + {input: append(emptySlice, 1), want: false}, + {input: 1, want: false}, + {input: nil, want: false}, + } + for _, tc := range tests { + token := oidcToken{ + Role: tc.input, + } + assert.Equal(t, tc.want, token.isAdmin(), "%v should return %t", tc.input, tc.want) + } +} + +func TestDbOIDCManager(t *testing.T) { + if !isSharedProviderSupported() { + t.Skip("this test it is not available with this provider") + } + mgr := newOIDCManager(1) + pendingAuth := newOIDCPendingAuth(tokenAudienceWebAdmin) + mgr.addPendingAuth(pendingAuth) + authReq, err := mgr.getPendingAuth(pendingAuth.State) + assert.NoError(t, err) + assert.Equal(t, pendingAuth, authReq) + pendingAuth.IssuedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour)) + mgr.addPendingAuth(pendingAuth) + _, err = mgr.getPendingAuth(pendingAuth.State) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "auth request is too old") + } + mgr.removePendingAuth(pendingAuth.State) + _, err = mgr.getPendingAuth(pendingAuth.State) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to get the auth request for the specified state") + } + mgr.addPendingAuth(pendingAuth) + _, err = mgr.getPendingAuth(pendingAuth.State) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "auth request is too old") + } + mgr.cleanup() + _, err = mgr.getPendingAuth(pendingAuth.State) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to get the auth request for the specified state") + } + + token := oidcToken{ + Cookie: xid.New().String(), + AccessToken: xid.New().String(), + TokenType: "Bearer", + RefreshToken: xid.New().String(), + ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)), + SessionID: xid.New().String(), + IDToken: xid.New().String(), + Nonce: xid.New().String(), + Username: xid.New().String(), + Permissions: []string{dataprovider.PermAdminAny}, + Role: "admin", + } + mgr.addToken(token) + tokenGet, err := mgr.getToken(token.Cookie) + assert.NoError(t, err) + assert.Greater(t, tokenGet.UsedAt, int64(0)) + token.UsedAt = tokenGet.UsedAt + assert.Equal(t, token, tokenGet) + time.Sleep(100 * time.Millisecond) + mgr.updateTokenUsage(token) + // no change + tokenGet, err = mgr.getToken(token.Cookie) + assert.NoError(t, err) + assert.Equal(t, token.UsedAt, tokenGet.UsedAt) + tokenGet.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour)) + tokenGet.RefreshToken = xid.New().String() + mgr.updateTokenUsage(tokenGet) + tokenGet, err = mgr.getToken(token.Cookie) + assert.NoError(t, err) + assert.NotEmpty(t, tokenGet.RefreshToken) + assert.NotEqual(t, token.RefreshToken, tokenGet.RefreshToken) + assert.Greater(t, tokenGet.UsedAt, token.UsedAt) + mgr.removeToken(token.Cookie) + tokenGet, err = mgr.getToken(token.Cookie) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to get the token for the specified session") + } + // add an expired token + token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour)) + session := dataprovider.Session{ + Key: token.Cookie, + Data: token, + Type: dataprovider.SessionTypeOIDCToken, + Timestamp: token.UsedAt + tokenDeleteInterval, + } + err = dataprovider.AddSharedSession(session) + assert.NoError(t, err) + _, err = mgr.getToken(token.Cookie) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "token is too old") + } + mgr.cleanup() + _, err = mgr.getToken(token.Cookie) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to get the token for the specified session") + } + // adding a session without a key should fail + session.Key = "" + err = dataprovider.AddSharedSession(session) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to save a session with an empty key") + } + session.Key = xid.New().String() + session.Type = 1000 + err = dataprovider.AddSharedSession(session) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid session type") + } + + dbMgr, ok := mgr.(*dbOIDCManager) + if assert.True(t, ok) { + _, err = dbMgr.decodePendingAuthData(2) + assert.Error(t, err) + _, err = dbMgr.decodeTokenData(true) + assert.Error(t, err) + } +} + func getTestOIDCServer() *httpdServer { return &httpdServer{ binding: Binding{ @@ -1009,29 +1159,3 @@ func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []by } return content } - -func TestOIDCIsAdmin(t *testing.T) { - type test struct { - input interface{} - want bool - } - - emptySlice := make([]interface{}, 0) - - tests := []test{ - {input: "admin", want: true}, - {input: append(emptySlice, "admin"), want: true}, - {input: append(emptySlice, "user", "admin"), want: true}, - {input: "user", want: false}, - {input: emptySlice, want: false}, - {input: append(emptySlice, 1), want: false}, - {input: 1, want: false}, - {input: nil, want: false}, - } - for _, tc := range tests { - token := oidcToken{ - Role: tc.input, - } - assert.Equal(t, tc.want, token.isAdmin(), "%v should return %t", tc.input, tc.want) - } -} diff --git a/httpd/oidcmanager.go b/httpd/oidcmanager.go new file mode 100644 index 00000000..7148d6f3 --- /dev/null +++ b/httpd/oidcmanager.go @@ -0,0 +1,230 @@ +package httpd + +import ( + "encoding/json" + "errors" + "sync" + "time" + + "github.com/drakkan/sftpgo/v2/dataprovider" + "github.com/drakkan/sftpgo/v2/logger" + "github.com/drakkan/sftpgo/v2/util" +) + +var ( + oidcMgr oidcManager +) + +func newOIDCManager(isShared int) oidcManager { + if isShared == 1 { + logger.Info(logSender, "", "using provider OIDC manager") + return &dbOIDCManager{} + } + logger.Info(logSender, "", "using memory OIDC manager") + return &memoryOIDCManager{ + pendingAuths: make(map[string]oidcPendingAuth), + tokens: make(map[string]oidcToken), + } +} + +type oidcManager interface { + addPendingAuth(pendingAuth oidcPendingAuth) + removePendingAuth(state string) + getPendingAuth(state string) (oidcPendingAuth, error) + addToken(token oidcToken) + getToken(cookie string) (oidcToken, error) + removeToken(cookie string) + updateTokenUsage(token oidcToken) + cleanup() +} + +type memoryOIDCManager struct { + authMutex sync.RWMutex + pendingAuths map[string]oidcPendingAuth + tokenMutex sync.RWMutex + tokens map[string]oidcToken +} + +func (o *memoryOIDCManager) addPendingAuth(pendingAuth oidcPendingAuth) { + o.authMutex.Lock() + o.pendingAuths[pendingAuth.State] = pendingAuth + o.authMutex.Unlock() +} + +func (o *memoryOIDCManager) removePendingAuth(state string) { + o.authMutex.Lock() + defer o.authMutex.Unlock() + + delete(o.pendingAuths, state) +} + +func (o *memoryOIDCManager) getPendingAuth(state string) (oidcPendingAuth, error) { + o.authMutex.RLock() + defer o.authMutex.RUnlock() + + authReq, ok := o.pendingAuths[state] + if !ok { + return oidcPendingAuth{}, errors.New("oidc: no auth request found for the specified state") + } + diff := util.GetTimeAsMsSinceEpoch(time.Now()) - authReq.IssuedAt + if diff > authStateValidity { + return oidcPendingAuth{}, errors.New("oidc: auth request is too old") + } + return authReq, nil +} + +func (o *memoryOIDCManager) addToken(token oidcToken) { + o.tokenMutex.Lock() + token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + o.tokens[token.Cookie] = token + o.tokenMutex.Unlock() +} + +func (o *memoryOIDCManager) getToken(cookie string) (oidcToken, error) { + o.tokenMutex.RLock() + defer o.tokenMutex.RUnlock() + + token, ok := o.tokens[cookie] + if !ok { + return oidcToken{}, errors.New("oidc: no token found for the specified session") + } + diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt + if diff > tokenDeleteInterval { + return oidcToken{}, errors.New("oidc: token is too old") + } + return token, nil +} + +func (o *memoryOIDCManager) removeToken(cookie string) { + o.tokenMutex.Lock() + defer o.tokenMutex.Unlock() + + delete(o.tokens, cookie) +} + +func (o *memoryOIDCManager) updateTokenUsage(token oidcToken) { + diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt + if diff > tokenUpdateInterval { + o.addToken(token) + } +} + +func (o *memoryOIDCManager) cleanup() { + logger.Debug(logSender, "", "oidc manager cleanup") + o.cleanupAuthRequests() + o.cleanupTokens() +} + +func (o *memoryOIDCManager) cleanupAuthRequests() { + o.authMutex.Lock() + defer o.authMutex.Unlock() + + for k, auth := range o.pendingAuths { + diff := util.GetTimeAsMsSinceEpoch(time.Now()) - auth.IssuedAt + // remove old pending auth requests + if diff < 0 || diff > authStateValidity { + delete(o.pendingAuths, k) + } + } +} + +func (o *memoryOIDCManager) cleanupTokens() { + o.tokenMutex.Lock() + defer o.tokenMutex.Unlock() + + for k, token := range o.tokens { + diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt + // remove tokens unused from more than tokenDeleteInterval + if diff > tokenDeleteInterval { + delete(o.tokens, k) + } + } +} + +type dbOIDCManager struct{} + +func (o *dbOIDCManager) addPendingAuth(pendingAuth oidcPendingAuth) { + session := dataprovider.Session{ + Key: pendingAuth.State, + Data: pendingAuth, + Type: dataprovider.SessionTypeOIDCAuth, + Timestamp: pendingAuth.IssuedAt + authStateValidity, + } + dataprovider.AddSharedSession(session) //nolint:errcheck +} + +func (o *dbOIDCManager) removePendingAuth(state string) { + dataprovider.DeleteSharedSession(state) //nolint:errcheck +} + +func (o *dbOIDCManager) getPendingAuth(state string) (oidcPendingAuth, error) { + session, err := dataprovider.GetSharedSession(state) + if err != nil { + return oidcPendingAuth{}, errors.New("oidc: unable to get the auth request for the specified state") + } + if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) { + // expired + return oidcPendingAuth{}, errors.New("oidc: auth request is too old") + } + return o.decodePendingAuthData(session.Data) +} + +func (o *dbOIDCManager) decodePendingAuthData(data any) (oidcPendingAuth, error) { + if val, ok := data.([]byte); ok { + authReq := oidcPendingAuth{} + err := json.Unmarshal(val, &authReq) + return authReq, err + } + logger.Error(logSender, "", "invalid oidc auth request data type %T", data) + return oidcPendingAuth{}, errors.New("oidc: invalid auth request data") +} + +func (o *dbOIDCManager) addToken(token oidcToken) { + token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + session := dataprovider.Session{ + Key: token.Cookie, + Data: token, + Type: dataprovider.SessionTypeOIDCToken, + Timestamp: token.UsedAt + tokenDeleteInterval, + } + dataprovider.AddSharedSession(session) //nolint:errcheck +} + +func (o *dbOIDCManager) removeToken(cookie string) { + dataprovider.DeleteSharedSession(cookie) //nolint:errcheck +} + +func (o *dbOIDCManager) updateTokenUsage(token oidcToken) { + diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt + if diff > tokenUpdateInterval { + o.addToken(token) + } +} + +func (o *dbOIDCManager) getToken(cookie string) (oidcToken, error) { + session, err := dataprovider.GetSharedSession(cookie) + if err != nil { + return oidcToken{}, errors.New("oidc: unable to get the token for the specified session") + } + if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) { + // expired + return oidcToken{}, errors.New("oidc: token is too old") + } + return o.decodeTokenData(session.Data) +} + +func (o *dbOIDCManager) decodeTokenData(data any) (oidcToken, error) { + if val, ok := data.([]byte); ok { + token := oidcToken{} + err := json.Unmarshal(val, &token) + return token, err + } + logger.Error(logSender, "", "invalid oidc token data type %T", data) + return oidcToken{}, errors.New("oidc: invalid token data") +} + +func (o *dbOIDCManager) cleanup() { + logger.Debug(logSender, "", "oidc manager cleanup") + dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOIDCAuth, time.Now()) //nolint:errcheck + dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOIDCToken, time.Now()) //nolint:errcheck +} diff --git a/httpd/resetcode.go b/httpd/resetcode.go index a94846e9..d91edaa1 100644 --- a/httpd/resetcode.go +++ b/httpd/resetcode.go @@ -1,26 +1,41 @@ package httpd import ( + "encoding/json" "sync" "time" + "github.com/drakkan/sftpgo/v2/dataprovider" + "github.com/drakkan/sftpgo/v2/logger" "github.com/drakkan/sftpgo/v2/util" ) var ( resetCodeLifespan = 10 * time.Minute - resetCodes sync.Map + resetCodesMgr resetCodeManager ) -type resetCode struct { - Code string - Username string - IsAdmin bool - ExpiresAt time.Time +type resetCodeManager interface { + Add(code *resetCode) error + Get(code string) (*resetCode, error) + Delete(code string) error + Cleanup() } -func (c *resetCode) isExpired() bool { - return c.ExpiresAt.Before(time.Now().UTC()) +func newResetCodeManager(isShared int) resetCodeManager { + if isShared == 1 { + logger.Info(logSender, "", "using provider reset code manager") + return &dbResetCodeManager{} + } + logger.Info(logSender, "", "using memory reset code manager") + return &memoryResetCodeManager{} +} + +type resetCode struct { + Code string `json:"code"` + Username string `json:"username"` + IsAdmin bool `json:"is_admin"` + ExpiresAt time.Time `json:"expires_at"` } func newResetCode(username string, isAdmin bool) *resetCode { @@ -32,12 +47,80 @@ func newResetCode(username string, isAdmin bool) *resetCode { } } -func cleanupExpiredResetCodes() { - resetCodes.Range(func(key, value interface{}) bool { +func (c *resetCode) isExpired() bool { + return c.ExpiresAt.Before(time.Now().UTC()) +} + +type memoryResetCodeManager struct { + resetCodes sync.Map +} + +func (m *memoryResetCodeManager) Add(code *resetCode) error { + m.resetCodes.Store(code.Code, code) + return nil +} + +func (m *memoryResetCodeManager) Get(code string) (*resetCode, error) { + c, ok := m.resetCodes.Load(code) + if !ok { + return nil, util.NewRecordNotFoundError("reset code not found") + } + return c.(*resetCode), nil +} + +func (m *memoryResetCodeManager) Delete(code string) error { + m.resetCodes.Delete(code) + return nil +} + +func (m *memoryResetCodeManager) Cleanup() { + m.resetCodes.Range(func(key, value any) bool { c, ok := value.(*resetCode) if !ok || c.isExpired() { - resetCodes.Delete(key) + m.resetCodes.Delete(key) } return true }) } + +type dbResetCodeManager struct{} + +func (m *dbResetCodeManager) Add(code *resetCode) error { + session := dataprovider.Session{ + Key: code.Code, + Data: code, + Type: dataprovider.SessionTypeResetCode, + Timestamp: util.GetTimeAsMsSinceEpoch(code.ExpiresAt), + } + return dataprovider.AddSharedSession(session) +} + +func (m *dbResetCodeManager) Get(code string) (*resetCode, error) { + session, err := dataprovider.GetSharedSession(code) + if err != nil { + return nil, err + } + if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) { + // expired + return nil, util.NewRecordNotFoundError("reset code expired") + } + return m.decodeData(session.Data) +} + +func (m *dbResetCodeManager) decodeData(data any) (*resetCode, error) { + if val, ok := data.([]byte); ok { + c := &resetCode{} + err := json.Unmarshal(val, c) + return c, err + } + logger.Error(logSender, "", "invalid reset code data type %T", data) + return nil, util.NewRecordNotFoundError("invalid reset code") +} + +func (m *dbResetCodeManager) Delete(code string) error { + return dataprovider.DeleteSharedSession(code) +} + +func (m *dbResetCodeManager) Cleanup() { + dataprovider.CleanupSharedSessions(dataprovider.SessionTypeResetCode, time.Now()) //nolint:errcheck +} diff --git a/httpd/server.go b/httpd/server.go index 035a9287..3749a7c9 100644 --- a/httpd/server.go +++ b/httpd/server.go @@ -307,7 +307,7 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter s.renderClientTwoFactorRecoveryPage(w, "Invalid credentials", ipAddr) return } - if !userMerged.Filters.TOTPConfig.Enabled || !util.IsStringInSlice(common.ProtocolHTTP, userMerged.Filters.TOTPConfig.Protocols) { + if !userMerged.Filters.TOTPConfig.Enabled || !util.Contains(userMerged.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) { s.renderClientTwoFactorPage(w, "Two factory authentication is not enabled", ipAddr) return } @@ -364,7 +364,7 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt s.renderClientTwoFactorPage(w, "Invalid credentials", ipAddr) return } - if !user.Filters.TOTPConfig.Enabled || !util.IsStringInSlice(common.ProtocolHTTP, user.Filters.TOTPConfig.Protocols) { + if !user.Filters.TOTPConfig.Enabled || !util.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) { s.renderClientTwoFactorPage(w, "Two factory authentication is not enabled", ipAddr) return } @@ -659,7 +659,7 @@ func (s *httpdServer) loginUser( } audience := tokenAudienceWebClient - if user.Filters.TOTPConfig.Enabled && util.IsStringInSlice(common.ProtocolHTTP, user.Filters.TOTPConfig.Protocols) && + if user.Filters.TOTPConfig.Enabled && util.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) && user.CanManageMFA() && !isSecondFactorAuth { audience = tokenAudienceWebClientPartial } @@ -764,7 +764,7 @@ func (s *httpdServer) getUserToken(w http.ResponseWriter, r *http.Request) { return } - if user.Filters.TOTPConfig.Enabled && util.IsStringInSlice(common.ProtocolHTTP, user.Filters.TOTPConfig.Protocols) { + if user.Filters.TOTPConfig.Enabled && util.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) { passcode := r.Header.Get(otpHeaderCode) if passcode == "" { logger.Debug(logSender, "", "TOTP enabled for user %#v and not passcode provided, authentication refused", user.Username) @@ -902,7 +902,7 @@ func (s *httpdServer) checkCookieExpiration(w http.ResponseWriter, r *http.Reque if time.Until(token.Expiration()) > tokenRefreshThreshold { return } - if util.IsStringInSlice(tokenAudienceWebClient, token.Audience()) { + if util.Contains(token.Audience(), tokenAudienceWebClient) { s.refreshClientToken(w, r, tokenClaims) } else { s.refreshAdminToken(w, r, tokenClaims) diff --git a/httpd/webadmin.go b/httpd/webadmin.go index c51f59cc..b10de826 100644 --- a/httpd/webadmin.go +++ b/httpd/webadmin.go @@ -486,14 +486,14 @@ func (s *httpdServer) getBasePageData(title, currentURL string, r *http.Request) Version: version.GetAsString(), LoggedAdmin: getAdminFromToken(r), HasDefender: common.Config.DefenderConfig.Enabled, - HasMFA: len(mfa.GetAvailableTOTPConfigNames()) > 0, + HasMFA: len(mfa.GetAvailableTOTPConfigs()) > 0, HasExternalLogin: isLoggedInWithOIDC(r), CSRFToken: csrfToken, Branding: s.binding.Branding.WebAdmin, } } -func renderAdminTemplate(w http.ResponseWriter, tmplName string, data interface{}) { +func renderAdminTemplate(w http.ResponseWriter, tmplName string, data any) { err := adminTemplates[tmplName].ExecuteTemplate(w, tmplName, data) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -1109,13 +1109,13 @@ func getFiltersFromUserPostFields(r *http.Request) (sdk.BaseUserFilters, error) filters.TLSUsername = sdk.TLSUsername(r.Form.Get("tls_username")) filters.WebClient = r.Form["web_client_options"] hooks := r.Form["hooks"] - if util.IsStringInSlice("external_auth_disabled", hooks) { + if util.Contains(hooks, "external_auth_disabled") { filters.Hooks.ExternalAuthDisabled = true } - if util.IsStringInSlice("pre_login_disabled", hooks) { + if util.Contains(hooks, "pre_login_disabled") { filters.Hooks.PreLoginDisabled = true } - if util.IsStringInSlice("check_password_disabled", hooks) { + if util.Contains(hooks, "check_password_disabled") { filters.Hooks.CheckPasswordDisabled = true } filters.DisableFsChecks = len(r.Form.Get("disable_fs_checks")) > 0 diff --git a/httpd/webclient.go b/httpd/webclient.go index 77aa86c3..ec1f97f2 100644 --- a/httpd/webclient.go +++ b/httpd/webclient.go @@ -383,7 +383,7 @@ func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, error, ip renderClientTemplate(w, templateResetPassword, data) } -func renderClientTemplate(w http.ResponseWriter, tmplName string, data interface{}) { +func renderClientTemplate(w http.ResponseWriter, tmplName string, data any) { err := clientTemplates[tmplName].ExecuteTemplate(w, tmplName, data) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -812,7 +812,7 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http. if len(s.binding.WebClientIntegrations) > 0 { extension := path.Ext(info.Name()) for idx := range s.binding.WebClientIntegrations { - if util.IsStringInSlice(extension, s.binding.WebClientIntegrations[idx].FileExtensions) { + if util.Contains(s.binding.WebClientIntegrations[idx].FileExtensions, extension) { res["ext_url"] = s.binding.WebClientIntegrations[idx].URL res["ext_link"] = fmt.Sprintf("%v?path=%v&_=%v", webClientFilePath, url.QueryEscape(path.Join(name, info.Name())), time.Now().UTC().Unix()) @@ -957,7 +957,7 @@ func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Reques return } - s.renderEditFilePage(w, r, name, b.String(), util.IsStringInSlice(sdk.WebClientWriteDisabled, user.Filters.WebClient)) + s.renderEditFilePage(w, r, name, b.String(), util.Contains(user.Filters.WebClient, sdk.WebClientWriteDisabled)) } func (s *httpdServer) handleClientAddShareGet(w http.ResponseWriter, r *http.Request) { @@ -1027,7 +1027,7 @@ func (s *httpdServer) handleClientAddSharePost(w http.ResponseWriter, r *http.Re share.LastUseAt = 0 share.Username = claims.Username if share.Password == "" { - if util.IsStringInSlice(sdk.WebClientShareNoPasswordDisabled, claims.Permissions) { + if util.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) { s.renderClientForbiddenPage(w, r, "You are not authorized to share files/folders without a password") return } @@ -1072,7 +1072,7 @@ func (s *httpdServer) handleClientUpdateSharePost(w http.ResponseWriter, r *http updatedShare.Password = share.Password } if updatedShare.Password == "" { - if util.IsStringInSlice(sdk.WebClientShareNoPasswordDisabled, claims.Permissions) { + if util.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) { s.renderClientForbiddenPage(w, r, "You are not authorized to share files/folders without a password") return } diff --git a/httpdtest/httpdtest.go b/httpdtest/httpdtest.go index 3d808481..deb5e540 100644 --- a/httpdtest/httpdtest.go +++ b/httpdtest/httpdtest.go @@ -94,7 +94,7 @@ func buildURLRelativeToBase(paths ...string) string { } // GetToken tries to return a JWT token -func GetToken(username, password string) (string, map[string]interface{}, error) { +func GetToken(username, password string) (string, map[string]any, error) { req, err := http.NewRequest(http.MethodGet, buildURLRelativeToBase(tokenPath), nil) if err != nil { return "", nil, err @@ -110,7 +110,7 @@ func GetToken(username, password string) (string, map[string]interface{}, error) if err != nil { return "", nil, err } - responseHolder := make(map[string]interface{}) + responseHolder := make(map[string]any) err = render.DecodeJSON(resp.Body, &responseHolder) if err != nil { return "", nil, err @@ -985,8 +985,8 @@ func RemoveDefenderHostByIP(ip string, expectedStatusCode int) ([]byte, error) { } // GetBanTime returns the ban time for the given IP address -func GetBanTime(ip string, expectedStatusCode int) (map[string]interface{}, []byte, error) { - var response map[string]interface{} +func GetBanTime(ip string, expectedStatusCode int) (map[string]any, []byte, error) { + var response map[string]any var body []byte url, err := url.Parse(buildURLRelativeToBase(defenderBanTime)) if err != nil { @@ -1010,8 +1010,8 @@ func GetBanTime(ip string, expectedStatusCode int) (map[string]interface{}, []by } // GetScore returns the score for the given IP address -func GetScore(ip string, expectedStatusCode int) (map[string]interface{}, []byte, error) { - var response map[string]interface{} +func GetScore(ip string, expectedStatusCode int) (map[string]any, []byte, error) { + var response map[string]any var body []byte url, err := url.Parse(buildURLRelativeToBase(defenderScore)) if err != nil { @@ -1050,8 +1050,8 @@ func UnbanIP(ip string, expectedStatusCode int) error { // Dumpdata requests a backup to outputFile. // outputFile is relative to the configured backups_path -func Dumpdata(outputFile, outputData, indent string, expectedStatusCode int) (map[string]interface{}, []byte, error) { - var response map[string]interface{} +func Dumpdata(outputFile, outputData, indent string, expectedStatusCode int) (map[string]any, []byte, error) { + var response map[string]any var body []byte url, err := url.Parse(buildURLRelativeToBase(dumpDataPath)) if err != nil { @@ -1083,8 +1083,8 @@ func Dumpdata(outputFile, outputData, indent string, expectedStatusCode int) (ma } // Loaddata restores a backup. -func Loaddata(inputFile, scanQuota, mode string, expectedStatusCode int) (map[string]interface{}, []byte, error) { - var response map[string]interface{} +func Loaddata(inputFile, scanQuota, mode string, expectedStatusCode int) (map[string]any, []byte, error) { + var response map[string]any var body []byte url, err := url.Parse(buildURLRelativeToBase(loadDataPath)) if err != nil { @@ -1114,8 +1114,8 @@ func Loaddata(inputFile, scanQuota, mode string, expectedStatusCode int) (map[st } // LoaddataFromPostBody restores a backup -func LoaddataFromPostBody(data []byte, scanQuota, mode string, expectedStatusCode int) (map[string]interface{}, []byte, error) { - var response map[string]interface{} +func LoaddataFromPostBody(data []byte, scanQuota, mode string, expectedStatusCode int) (map[string]any, []byte, error) { + var response map[string]any var body []byte url, err := url.Parse(buildURLRelativeToBase(loadDataPath)) if err != nil { @@ -1265,7 +1265,7 @@ func checkAdmin(expected, actual *dataprovider.Admin) error { return errors.New("permissions mismatch") } for _, p := range expected.Permissions { - if !util.IsStringInSlice(p, actual.Permissions) { + if !util.Contains(actual.Permissions, p) { return errors.New("permissions content mismatch") } } @@ -1276,7 +1276,7 @@ func checkAdmin(expected, actual *dataprovider.Admin) error { return errors.New("allow_api_key_auth mismatch") } for _, v := range expected.Filters.AllowList { - if !util.IsStringInSlice(v, actual.Filters.AllowList) { + if !util.Contains(actual.Filters.AllowList, v) { return errors.New("allow list content mismatch") } } @@ -1350,7 +1350,7 @@ func compareUserPermissions(expected map[string][]string, actual map[string][]st for dir, perms := range expected { if actualPerms, ok := actual[dir]; ok { for _, v := range actualPerms { - if !util.IsStringInSlice(v, perms) { + if !util.Contains(perms, v) { return errors.New("permissions contents mismatch") } } @@ -1530,7 +1530,7 @@ func compareSFTPFsConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error return errors.New("SFTPFs fingerprints mismatch") } for _, value := range actual.SFTPConfig.Fingerprints { - if !util.IsStringInSlice(value, expected.SFTPConfig.Fingerprints) { + if !util.Contains(expected.SFTPConfig.Fingerprints, value) { return errors.New("SFTPFs fingerprints mismatch") } } @@ -1621,27 +1621,27 @@ func checkEncryptedSecret(expected, actual *kms.Secret) error { func compareUserFilterSubStructs(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error { for _, IPMask := range expected.AllowedIP { - if !util.IsStringInSlice(IPMask, actual.AllowedIP) { + if !util.Contains(actual.AllowedIP, IPMask) { return errors.New("allowed IP contents mismatch") } } for _, IPMask := range expected.DeniedIP { - if !util.IsStringInSlice(IPMask, actual.DeniedIP) { + if !util.Contains(actual.DeniedIP, IPMask) { return errors.New("denied IP contents mismatch") } } for _, method := range expected.DeniedLoginMethods { - if !util.IsStringInSlice(method, actual.DeniedLoginMethods) { + if !util.Contains(actual.DeniedLoginMethods, method) { return errors.New("denied login methods contents mismatch") } } for _, protocol := range expected.DeniedProtocols { - if !util.IsStringInSlice(protocol, actual.DeniedProtocols) { + if !util.Contains(actual.DeniedProtocols, protocol) { return errors.New("denied protocols contents mismatch") } } for _, options := range expected.WebClient { - if !util.IsStringInSlice(options, actual.WebClient) { + if !util.Contains(actual.WebClient, options) { return errors.New("web client options contents mismatch") } } @@ -1712,7 +1712,7 @@ func checkFilterMatch(expected []string, actual []string) bool { return false } for _, e := range expected { - if !util.IsStringInSlice(strings.ToLower(e), actual) { + if !util.Contains(actual, strings.ToLower(e)) { return false } } @@ -1734,7 +1734,7 @@ func compareUserDataTransferLimitFilters(expected sdk.BaseUserFilters, actual sd return errors.New("data transfer limit total_data_transfer mismatch") } for _, source := range actual.DataTransferLimits[idx].Sources { - if !util.IsStringInSlice(source, l.Sources) { + if !util.Contains(l.Sources, source) { return errors.New("data transfer limit source mismatch") } } @@ -1759,7 +1759,7 @@ func compareUserBandwidthLimitFilters(expected sdk.BaseUserFilters, actual sdk.B return errors.New("bandwidth filters sources mismatch") } for _, source := range actual.BandwidthLimits[idx].Sources { - if !util.IsStringInSlice(source, l.Sources) { + if !util.Contains(l.Sources, source) { return errors.New("bandwidth filters source mismatch") } } diff --git a/init/sftpgo.service b/init/sftpgo.service index 085c2405..fe6c3845 100644 --- a/init/sftpgo.service +++ b/init/sftpgo.service @@ -22,6 +22,7 @@ PrivateDevices=yes DevicePolicy=closed ProtectSystem=true RestrictAddressFamilies=AF_INET AF_INET6 AF_UNIX +AmbientCapabilities=CAP_NET_BIND_SERVICE [Install] WantedBy=multi-user.target diff --git a/logger/hclog_adapter.go b/logger/hclog_adapter.go index 4dc15208..ef615edb 100644 --- a/logger/hclog_adapter.go +++ b/logger/hclog_adapter.go @@ -14,7 +14,7 @@ type HCLogAdapter struct { } // Log emits a message and key/value pairs at a provided log level -func (l *HCLogAdapter) Log(level hclog.Level, msg string, args ...interface{}) { +func (l *HCLogAdapter) Log(level hclog.Level, msg string, args ...any) { var ev *zerolog.Event switch level { case hclog.Info: @@ -32,32 +32,32 @@ func (l *HCLogAdapter) Log(level hclog.Level, msg string, args ...interface{}) { } // Trace emits a message and key/value pairs at the TRACE level -func (l *HCLogAdapter) Trace(msg string, args ...interface{}) { +func (l *HCLogAdapter) Trace(msg string, args ...any) { l.Log(hclog.Debug, msg, args...) } // Debug emits a message and key/value pairs at the DEBUG level -func (l *HCLogAdapter) Debug(msg string, args ...interface{}) { +func (l *HCLogAdapter) Debug(msg string, args ...any) { l.Log(hclog.Debug, msg, args...) } // Info emits a message and key/value pairs at the INFO level -func (l *HCLogAdapter) Info(msg string, args ...interface{}) { +func (l *HCLogAdapter) Info(msg string, args ...any) { l.Log(hclog.Info, msg, args...) } // Warn emits a message and key/value pairs at the WARN level -func (l *HCLogAdapter) Warn(msg string, args ...interface{}) { +func (l *HCLogAdapter) Warn(msg string, args ...any) { l.Log(hclog.Warn, msg, args...) } // Error emits a message and key/value pairs at the ERROR level -func (l *HCLogAdapter) Error(msg string, args ...interface{}) { +func (l *HCLogAdapter) Error(msg string, args ...any) { l.Log(hclog.Error, msg, args...) } // With creates a sub-logger -func (l *HCLogAdapter) With(args ...interface{}) hclog.Logger { +func (l *HCLogAdapter) With(args ...any) hclog.Logger { return &HCLogAdapter{Logger: l.Logger.With(args...)} } diff --git a/logger/logger.go b/logger/logger.go index ac72d0cd..7f19e0c0 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -127,7 +127,7 @@ func SetLogTime(utc bool) { } // Log logs at the specified level for the specified sender -func Log(level LogLevel, sender string, connectionID string, format string, v ...interface{}) { +func Log(level LogLevel, sender string, connectionID string, format string, v ...any) { var ev *zerolog.Event switch level { case LevelDebug: @@ -147,42 +147,42 @@ func Log(level LogLevel, sender string, connectionID string, format string, v .. } // Debug logs at debug level for the specified sender -func Debug(sender, connectionID, format string, v ...interface{}) { +func Debug(sender, connectionID, format string, v ...any) { Log(LevelDebug, sender, connectionID, format, v...) } // Info logs at info level for the specified sender -func Info(sender, connectionID, format string, v ...interface{}) { +func Info(sender, connectionID, format string, v ...any) { Log(LevelInfo, sender, connectionID, format, v...) } // Warn logs at warn level for the specified sender -func Warn(sender, connectionID, format string, v ...interface{}) { +func Warn(sender, connectionID, format string, v ...any) { Log(LevelWarn, sender, connectionID, format, v...) } // Error logs at error level for the specified sender -func Error(sender, connectionID, format string, v ...interface{}) { +func Error(sender, connectionID, format string, v ...any) { Log(LevelError, sender, connectionID, format, v...) } // DebugToConsole logs at debug level to stdout -func DebugToConsole(format string, v ...interface{}) { +func DebugToConsole(format string, v ...any) { consoleLogger.Debug().Msg(fmt.Sprintf(format, v...)) } // InfoToConsole logs at info level to stdout -func InfoToConsole(format string, v ...interface{}) { +func InfoToConsole(format string, v ...any) { consoleLogger.Info().Msg(fmt.Sprintf(format, v...)) } // WarnToConsole logs at info level to stdout -func WarnToConsole(format string, v ...interface{}) { +func WarnToConsole(format string, v ...any) { consoleLogger.Warn().Msg(fmt.Sprintf(format, v...)) } // ErrorToConsole logs at error level to stdout -func ErrorToConsole(format string, v ...interface{}) { +func ErrorToConsole(format string, v ...any) { consoleLogger.Error().Msg(fmt.Sprintf(format, v...)) } @@ -274,10 +274,10 @@ func (l *StdLoggerWrapper) Write(p []byte) (n int, err error) { // LeveledLogger is a logger that accepts a message string and a variadic number of key-value pairs type LeveledLogger struct { Sender string - additionalKeyVals []interface{} + additionalKeyVals []any } -func addKeysAndValues(ev *zerolog.Event, keysAndValues ...interface{}) { +func addKeysAndValues(ev *zerolog.Event, keysAndValues ...any) { kvLen := len(keysAndValues) if kvLen%2 != 0 { extra := keysAndValues[kvLen-1] @@ -292,7 +292,7 @@ func addKeysAndValues(ev *zerolog.Event, keysAndValues ...interface{}) { } // Error logs at error level for the specified sender -func (l *LeveledLogger) Error(msg string, keysAndValues ...interface{}) { +func (l *LeveledLogger) Error(msg string, keysAndValues ...any) { ev := logger.Error() ev.Timestamp().Str("sender", l.Sender) if len(l.additionalKeyVals) > 0 { @@ -303,7 +303,7 @@ func (l *LeveledLogger) Error(msg string, keysAndValues ...interface{}) { } // Info logs at info level for the specified sender -func (l *LeveledLogger) Info(msg string, keysAndValues ...interface{}) { +func (l *LeveledLogger) Info(msg string, keysAndValues ...any) { ev := logger.Info() ev.Timestamp().Str("sender", l.Sender) if len(l.additionalKeyVals) > 0 { @@ -314,7 +314,7 @@ func (l *LeveledLogger) Info(msg string, keysAndValues ...interface{}) { } // Debug logs at debug level for the specified sender -func (l *LeveledLogger) Debug(msg string, keysAndValues ...interface{}) { +func (l *LeveledLogger) Debug(msg string, keysAndValues ...any) { ev := logger.Debug() ev.Timestamp().Str("sender", l.Sender) if len(l.additionalKeyVals) > 0 { @@ -325,7 +325,7 @@ func (l *LeveledLogger) Debug(msg string, keysAndValues ...interface{}) { } // Warn logs at warn level for the specified sender -func (l *LeveledLogger) Warn(msg string, keysAndValues ...interface{}) { +func (l *LeveledLogger) Warn(msg string, keysAndValues ...any) { ev := logger.Warn() ev.Timestamp().Str("sender", l.Sender) if len(l.additionalKeyVals) > 0 { @@ -336,12 +336,12 @@ func (l *LeveledLogger) Warn(msg string, keysAndValues ...interface{}) { } // Panic logs the panic at error level for the specified sender -func (l *LeveledLogger) Panic(msg string, keysAndValues ...interface{}) { +func (l *LeveledLogger) Panic(msg string, keysAndValues ...any) { l.Error(msg, keysAndValues...) } // With returns a LeveledLogger with additional context specific keyvals -func (l *LeveledLogger) With(keysAndValues ...interface{}) ftpserverlog.Logger { +func (l *LeveledLogger) With(keysAndValues ...any) ftpserverlog.Logger { return &LeveledLogger{ Sender: l.Sender, additionalKeyVals: append(l.additionalKeyVals, keysAndValues...), diff --git a/logger/request_logger.go b/logger/request_logger.go index bcfcdf7b..c78788f1 100644 --- a/logger/request_logger.go +++ b/logger/request_logger.go @@ -24,7 +24,7 @@ type StructuredLoggerEntry struct { // The zerolog logger Logger *zerolog.Logger // fields to write in the log - fields map[string]interface{} + fields map[string]any } // NewStructuredLogger returns a chi.middleware.RequestLogger using our StructuredLogger. @@ -40,7 +40,7 @@ func (l *StructuredLogger) NewLogEntry(r *http.Request) middleware.LogEntry { scheme = "https" } - fields := map[string]interface{}{ + fields := map[string]any{ "local_addr": getLocalAddress(r), "remote_addr": r.RemoteAddr, "proto": r.Proto, @@ -57,7 +57,7 @@ func (l *StructuredLogger) NewLogEntry(r *http.Request) middleware.LogEntry { } // Write logs a new entry at the end of the HTTP request -func (l *StructuredLoggerEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) { +func (l *StructuredLoggerEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra any) { metric.HTTPRequestServed(status) l.Logger.Info(). Timestamp(). @@ -70,7 +70,7 @@ func (l *StructuredLoggerEntry) Write(status, bytes int, header http.Header, ela } // Panic logs panics -func (l *StructuredLoggerEntry) Panic(v interface{}, stack []byte) { +func (l *StructuredLoggerEntry) Panic(v any, stack []byte) { l.Logger.Error(). Timestamp(). Str("sender", "httpd"). diff --git a/mfa/mfa.go b/mfa/mfa.go index ca47bfaf..0bd89506 100644 --- a/mfa/mfa.go +++ b/mfa/mfa.go @@ -53,7 +53,7 @@ func (c *Config) Initialize() error { return nil } -// GetAvailableTOTPConfigs returns the available TOTP config names +// GetAvailableTOTPConfigs returns the available TOTP configs func GetAvailableTOTPConfigs() []*TOTPConfig { return totpConfigs } diff --git a/mfa/totp.go b/mfa/totp.go index f040ef5d..ff33a47a 100644 --- a/mfa/totp.go +++ b/mfa/totp.go @@ -96,7 +96,7 @@ func (c *TOTPConfig) generate(username string, qrCodeWidth, qrCodeHeight int) (s } func cleanupUsedPasscodes() { - usedPasscodes.Range(func(key, value interface{}) bool { + usedPasscodes.Range(func(key, value any) bool { exp, ok := value.(time.Time) if !ok || exp.Before(time.Now().UTC()) { usedPasscodes.Delete(key) diff --git a/plugin/kms.go b/plugin/kms.go index d1f9d2ff..5fe96a6b 100644 --- a/plugin/kms.go +++ b/plugin/kms.go @@ -29,10 +29,10 @@ type KMSConfig struct { } func (c *KMSConfig) validate() error { - if !util.IsStringInSlice(c.Scheme, validKMSSchemes) { + if !util.Contains(validKMSSchemes, c.Scheme) { return fmt.Errorf("invalid kms scheme: %v", c.Scheme) } - if !util.IsStringInSlice(c.EncryptedStatus, validKMSEncryptedStatuses) { + if !util.Contains(validKMSEncryptedStatuses, c.EncryptedStatus) { return fmt.Errorf("invalid kms encrypted status: %v", c.EncryptedStatus) } return nil diff --git a/plugin/notifier.go b/plugin/notifier.go index c9922de8..cd23d2fa 100644 --- a/plugin/notifier.go +++ b/plugin/notifier.go @@ -181,7 +181,7 @@ func (p *notifierPlugin) canQueueEvent(timestamp int64) bool { } func (p *notifierPlugin) notifyFsAction(event *notifier.FsEvent) { - if !util.IsStringInSlice(event.Action, p.config.NotifierOptions.FsEvents) { + if !util.Contains(p.config.NotifierOptions.FsEvents, event.Action) { return } @@ -191,8 +191,8 @@ func (p *notifierPlugin) notifyFsAction(event *notifier.FsEvent) { } func (p *notifierPlugin) notifyProviderAction(event *notifier.ProviderEvent, object Renderer) { - if !util.IsStringInSlice(event.Action, p.config.NotifierOptions.ProviderEvents) || - !util.IsStringInSlice(event.ObjectType, p.config.NotifierOptions.ProviderObjects) { + if !util.Contains(p.config.NotifierOptions.ProviderEvents, event.Action) || + !util.Contains(p.config.NotifierOptions.ProviderObjects, event.ObjectType) { return } diff --git a/service/service.go b/service/service.go index 22d648b1..cbe06d6d 100644 --- a/service/service.go +++ b/service/service.go @@ -186,7 +186,8 @@ func (s *Service) startServices() { if httpdConf.ShouldBind() { go func() { - if err := httpdConf.Initialize(s.ConfigDir); err != nil { + providerConf := config.GetProviderConf() + if err := httpdConf.Initialize(s.ConfigDir, providerConf.GetShared()); err != nil { logger.Error(logSender, "", "could not start HTTP server: %v", err) logger.ErrorToConsole("could not start HTTP server: %v", err) s.Error = err diff --git a/service/service_portable.go b/service/service_portable.go index d87c0fee..80120aa6 100644 --- a/service/service_portable.go +++ b/service/service_portable.go @@ -67,7 +67,7 @@ func (s *Service) StartPortableMode(sftpdPort, ftpPort, webdavPort int, enabledS // dynamic ports starts from 49152 sftpdConf.Bindings[0].Port = 49152 + rand.Intn(15000) } - if util.IsStringInSlice("*", enabledSSHCommands) { + if util.Contains(enabledSSHCommands, "*") { sftpdConf.EnabledSSHCommands = sftpd.GetSupportedSSHCommands() } else { sftpdConf.EnabledSSHCommands = enabledSSHCommands diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index 6278f0fb..60fe472a 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -392,7 +392,7 @@ func TestSupportedSSHCommands(t *testing.T) { assert.Equal(t, len(supportedSSHCommands), len(cmds)) for _, c := range cmds { - assert.True(t, util.IsStringInSlice(c, supportedSSHCommands)) + assert.True(t, util.Contains(supportedSSHCommands, c)) } } @@ -845,7 +845,7 @@ func TestRsyncOptions(t *testing.T) { } cmd, err := sshCmd.getSystemCommand() assert.NoError(t, err) - assert.True(t, util.IsStringInSlice("--safe-links", cmd.cmd.Args), + assert.True(t, util.Contains(cmd.cmd.Args, "--safe-links"), "--safe-links must be added if the user has the create symlinks permission") permissions["/"] = []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermCreateDirs, @@ -862,7 +862,7 @@ func TestRsyncOptions(t *testing.T) { } cmd, err = sshCmd.getSystemCommand() assert.NoError(t, err) - assert.True(t, util.IsStringInSlice("--munge-links", cmd.cmd.Args), + assert.True(t, util.Contains(cmd.cmd.Args, "--munge-links"), "--munge-links must be added if the user has the create symlinks permission") sshCmd.connection.User.VirtualFolders = append(sshCmd.connection.User.VirtualFolders, vfs.VirtualFolder{ diff --git a/sftpd/mocks/middleware.go b/sftpd/mocks/middleware.go index 4a5a5428..e0251436 100644 --- a/sftpd/mocks/middleware.go +++ b/sftpd/mocks/middleware.go @@ -44,7 +44,7 @@ func (m *MockMiddleware) Filecmd(arg0 *sftp.Request) error { } // Filecmd indicates an expected call of Filecmd. -func (mr *MockMiddlewareMockRecorder) Filecmd(arg0 interface{}) *gomock.Call { +func (mr *MockMiddlewareMockRecorder) Filecmd(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Filecmd", reflect.TypeOf((*MockMiddleware)(nil).Filecmd), arg0) } @@ -59,7 +59,7 @@ func (m *MockMiddleware) Filelist(arg0 *sftp.Request) (sftp.ListerAt, error) { } // Filelist indicates an expected call of Filelist. -func (mr *MockMiddlewareMockRecorder) Filelist(arg0 interface{}) *gomock.Call { +func (mr *MockMiddlewareMockRecorder) Filelist(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Filelist", reflect.TypeOf((*MockMiddleware)(nil).Filelist), arg0) } @@ -74,7 +74,7 @@ func (m *MockMiddleware) Fileread(arg0 *sftp.Request) (io.ReaderAt, error) { } // Fileread indicates an expected call of Fileread. -func (mr *MockMiddlewareMockRecorder) Fileread(arg0 interface{}) *gomock.Call { +func (mr *MockMiddlewareMockRecorder) Fileread(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fileread", reflect.TypeOf((*MockMiddleware)(nil).Fileread), arg0) } @@ -89,7 +89,7 @@ func (m *MockMiddleware) Filewrite(arg0 *sftp.Request) (io.WriterAt, error) { } // Filewrite indicates an expected call of Filewrite. -func (mr *MockMiddlewareMockRecorder) Filewrite(arg0 interface{}) *gomock.Call { +func (mr *MockMiddlewareMockRecorder) Filewrite(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Filewrite", reflect.TypeOf((*MockMiddleware)(nil).Filewrite), arg0) } @@ -104,7 +104,7 @@ func (m *MockMiddleware) Lstat(arg0 *sftp.Request) (sftp.ListerAt, error) { } // Lstat indicates an expected call of Lstat. -func (mr *MockMiddlewareMockRecorder) Lstat(arg0 interface{}) *gomock.Call { +func (mr *MockMiddlewareMockRecorder) Lstat(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Lstat", reflect.TypeOf((*MockMiddleware)(nil).Lstat), arg0) } @@ -119,7 +119,7 @@ func (m *MockMiddleware) OpenFile(arg0 *sftp.Request) (sftp.WriterAtReaderAt, er } // OpenFile indicates an expected call of OpenFile. -func (mr *MockMiddlewareMockRecorder) OpenFile(arg0 interface{}) *gomock.Call { +func (mr *MockMiddlewareMockRecorder) OpenFile(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenFile", reflect.TypeOf((*MockMiddleware)(nil).OpenFile), arg0) } @@ -134,7 +134,7 @@ func (m *MockMiddleware) StatVFS(arg0 *sftp.Request) (*sftp.StatVFS, error) { } // StatVFS indicates an expected call of StatVFS. -func (mr *MockMiddlewareMockRecorder) StatVFS(arg0 interface{}) *gomock.Call { +func (mr *MockMiddlewareMockRecorder) StatVFS(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StatVFS", reflect.TypeOf((*MockMiddleware)(nil).StatVFS), arg0) } diff --git a/sftpd/scp.go b/sftpd/scp.go index 63fde638..d1a2acce 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -560,11 +560,11 @@ func (c *scpCommand) getCommandType() string { } func (c *scpCommand) sendFileTime() bool { - return util.IsStringInSlice("-p", c.args) + return util.Contains(c.args, "-p") } func (c *scpCommand) isRecursive() bool { - return util.IsStringInSlice("-r", c.args) + return util.Contains(c.args, "-r") } // read the SCP confirmation message and the optional text message diff --git a/sftpd/server.go b/sftpd/server.go index a93e8fe3..5060ed0d 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -251,13 +251,13 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig { func (c *Configuration) updateSupportedAuthentications() { serviceStatus.Authentications = util.RemoveDuplicates(serviceStatus.Authentications) - if util.IsStringInSlice(dataprovider.LoginMethodPassword, serviceStatus.Authentications) && - util.IsStringInSlice(dataprovider.SSHLoginMethodPublicKey, serviceStatus.Authentications) { + if util.Contains(serviceStatus.Authentications, dataprovider.LoginMethodPassword) && + util.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) { serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyAndPassword) } - if util.IsStringInSlice(dataprovider.SSHLoginMethodKeyboardInteractive, serviceStatus.Authentications) && - util.IsStringInSlice(dataprovider.SSHLoginMethodPublicKey, serviceStatus.Authentications) { + if util.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyboardInteractive) && + util.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) { serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyAndKeyboardInt) } } @@ -367,7 +367,7 @@ func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig) c.HostKeyAlgorithms = util.RemoveDuplicates(c.HostKeyAlgorithms) } for _, hostKeyAlgo := range c.HostKeyAlgorithms { - if !util.IsStringInSlice(hostKeyAlgo, supportedHostKeyAlgos) { + if !util.Contains(supportedHostKeyAlgos, hostKeyAlgo) { return fmt.Errorf("unsupported host key algorithm %#v", hostKeyAlgo) } } @@ -376,7 +376,7 @@ func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig) if len(c.KexAlgorithms) > 0 { c.KexAlgorithms = util.RemoveDuplicates(c.KexAlgorithms) for _, kex := range c.KexAlgorithms { - if !util.IsStringInSlice(kex, supportedKexAlgos) { + if !util.Contains(supportedKexAlgos, kex) { return fmt.Errorf("unsupported key-exchange algorithm %#v", kex) } } @@ -385,7 +385,7 @@ func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig) if len(c.Ciphers) > 0 { c.Ciphers = util.RemoveDuplicates(c.Ciphers) for _, cipher := range c.Ciphers { - if !util.IsStringInSlice(cipher, supportedCiphers) { + if !util.Contains(supportedCiphers, cipher) { return fmt.Errorf("unsupported cipher %#v", cipher) } } @@ -394,7 +394,7 @@ func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig) if len(c.MACs) > 0 { c.MACs = util.RemoveDuplicates(c.MACs) for _, mac := range c.MACs { - if !util.IsStringInSlice(mac, supportedMACs) { + if !util.Contains(supportedMACs, mac) { return fmt.Errorf("unsupported MAC algorithm %#v", mac) } } @@ -676,7 +676,7 @@ func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh. user.Username, user.HomeDir) return nil, fmt.Errorf("cannot login user with invalid home dir: %#v", user.HomeDir) } - if util.IsStringInSlice(common.ProtocolSSH, user.Filters.DeniedProtocols) { + if util.Contains(user.Filters.DeniedProtocols, common.ProtocolSSH) { logger.Info(logSender, connectionID, "cannot login user %#v, protocol SSH is not allowed", user.Username) return nil, fmt.Errorf("protocol SSH is not allowed for user %#v", user.Username) } @@ -721,13 +721,13 @@ func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh. } func (c *Configuration) checkSSHCommands() { - if util.IsStringInSlice("*", c.EnabledSSHCommands) { + if util.Contains(c.EnabledSSHCommands, "*") { c.EnabledSSHCommands = GetSupportedSSHCommands() return } sshCommands := []string{} for _, command := range c.EnabledSSHCommands { - if util.IsStringInSlice(command, supportedSSHCommands) { + if util.Contains(supportedSSHCommands, command) { sshCommands = append(sshCommands, command) } else { logger.Warn(logSender, "", "unsupported ssh command: %#v ignored", command) diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index 91803e14..f276e504 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -252,7 +252,7 @@ func TestMain(m *testing.M) { }() go func() { - if err := httpdConf.Initialize(configDir); err != nil { + if err := httpdConf.Initialize(configDir, 0); err != nil { logger.ErrorToConsole("could not start HTTP server: %v", err) os.Exit(1) } @@ -7988,8 +7988,8 @@ func TestUserAllowedLoginMethods(t *testing.T) { allowedMethods = user.GetAllowedLoginMethods() assert.Equal(t, 4, len(allowedMethods)) - assert.True(t, util.IsStringInSlice(dataprovider.SSHLoginMethodKeyAndKeyboardInt, allowedMethods)) - assert.True(t, util.IsStringInSlice(dataprovider.SSHLoginMethodKeyAndPassword, allowedMethods)) + assert.True(t, util.Contains(allowedMethods, dataprovider.SSHLoginMethodKeyAndKeyboardInt)) + assert.True(t, util.Contains(allowedMethods, dataprovider.SSHLoginMethodKeyAndPassword)) } func TestUserPartialAuth(t *testing.T) { @@ -8040,11 +8040,11 @@ func TestUserGetNextAuthMethods(t *testing.T) { methods = user.GetNextAuthMethods([]string{dataprovider.SSHLoginMethodPublicKey}, true) assert.Equal(t, 2, len(methods)) - assert.True(t, util.IsStringInSlice(dataprovider.LoginMethodPassword, methods)) - assert.True(t, util.IsStringInSlice(dataprovider.SSHLoginMethodKeyboardInteractive, methods)) + assert.True(t, util.Contains(methods, dataprovider.LoginMethodPassword)) + assert.True(t, util.Contains(methods, dataprovider.SSHLoginMethodKeyboardInteractive)) methods = user.GetNextAuthMethods([]string{dataprovider.SSHLoginMethodPublicKey}, false) assert.Equal(t, 1, len(methods)) - assert.True(t, util.IsStringInSlice(dataprovider.SSHLoginMethodKeyboardInteractive, methods)) + assert.True(t, util.Contains(methods, dataprovider.SSHLoginMethodKeyboardInteractive)) user.Filters.DeniedLoginMethods = []string{ dataprovider.LoginMethodPassword, @@ -8054,7 +8054,7 @@ func TestUserGetNextAuthMethods(t *testing.T) { } methods = user.GetNextAuthMethods([]string{dataprovider.SSHLoginMethodPublicKey}, true) assert.Equal(t, 1, len(methods)) - assert.True(t, util.IsStringInSlice(dataprovider.LoginMethodPassword, methods)) + assert.True(t, util.Contains(methods, dataprovider.LoginMethodPassword)) user.Filters.DeniedLoginMethods = []string{ dataprovider.LoginMethodPassword, @@ -8064,7 +8064,7 @@ func TestUserGetNextAuthMethods(t *testing.T) { } methods = user.GetNextAuthMethods([]string{dataprovider.SSHLoginMethodPublicKey}, true) assert.Equal(t, 1, len(methods)) - assert.True(t, util.IsStringInSlice(dataprovider.SSHLoginMethodKeyboardInteractive, methods)) + assert.True(t, util.Contains(methods, dataprovider.SSHLoginMethodKeyboardInteractive)) } func TestUserIsLoginMethodAllowed(t *testing.T) { diff --git a/sftpd/ssh_cmd.go b/sftpd/ssh_cmd.go index 9d2146a6..0ebe6fab 100644 --- a/sftpd/ssh_cmd.go +++ b/sftpd/ssh_cmd.go @@ -76,7 +76,7 @@ func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommand name, args, err := parseCommandPayload(msg.Command) connection.Log(logger.LevelDebug, "new ssh command: %#v args: %v num args: %v user: %v, error: %v", name, args, len(args), connection.User.Username, err) - if err == nil && util.IsStringInSlice(name, enabledSSHCommands) { + if err == nil && util.Contains(enabledSSHCommands, name) { connection.command = msg.Command if name == scpCmdName && len(args) >= 2 { connection.SetProtocol(common.ProtocolSCP) @@ -122,9 +122,9 @@ func (c *sshCommand) handle() (err error) { defer common.Connections.Remove(c.connection.GetID()) c.connection.UpdateLastActivity() - if util.IsStringInSlice(c.command, sshHashCommands) { + if util.Contains(sshHashCommands, c.command) { return c.handleHashCommands() - } else if util.IsStringInSlice(c.command, systemCommands) { + } else if util.Contains(systemCommands, c.command) { command, err := c.getSystemCommand() if err != nil { return c.sendErrorResponse(err) @@ -507,11 +507,11 @@ func (c *sshCommand) getSystemCommand() (systemCommand, error) { // If the user cannot create symlinks we add the option --munge-links, if it is not // already set. This should make symlinks unusable (but manually recoverable) if c.connection.User.HasPerm(dataprovider.PermCreateSymlinks, c.getDestPath()) { - if !util.IsStringInSlice("--safe-links", args) { + if !util.Contains(args, "--safe-links") { args = append([]string{"--safe-links"}, args...) } } else { - if !util.IsStringInSlice("--munge-links", args) { + if !util.Contains(args, "--munge-links") { args = append([]string{"--munge-links"}, args...) } } diff --git a/smtp/smtp.go b/smtp/smtp.go index 0f26325d..c7b6dc86 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -153,7 +153,7 @@ func loadTemplates(templatesPath string) { } // RenderRetentionReportTemplate executes the retention report template -func RenderRetentionReportTemplate(buf *bytes.Buffer, data interface{}) error { +func RenderRetentionReportTemplate(buf *bytes.Buffer, data any) error { if smtpServer == nil { return errors.New("smtp: not configured") } @@ -161,7 +161,7 @@ func RenderRetentionReportTemplate(buf *bytes.Buffer, data interface{}) error { } // RenderPasswordResetTemplate executes the password reset template -func RenderPasswordResetTemplate(buf *bytes.Buffer, data interface{}) error { +func RenderPasswordResetTemplate(buf *bytes.Buffer, data any) error { if smtpServer == nil { return errors.New("smtp: not configured") } diff --git a/util/util.go b/util/util.go index a5e7831f..4ed28c8f 100644 --- a/util/util.go +++ b/util/util.go @@ -56,17 +56,6 @@ func Contains[T comparable](elems []T, v T) bool { return false } -// IsStringInSlice searches a string in a slice and returns true if the string is found -// TODO: replace with Contains above -func IsStringInSlice(obj string, list []string) bool { - for i := 0; i < len(list); i++ { - if list[i] == obj { - return true - } - } - return false -} - // IsStringPrefixInSlice searches a string prefix in a slice and returns true // if a matching prefix is found func IsStringPrefixInSlice(obj string, list []string) bool { diff --git a/vfs/fileinfo.go b/vfs/fileinfo.go index bc90fce5..11c3c37b 100644 --- a/vfs/fileinfo.go +++ b/vfs/fileinfo.go @@ -64,6 +64,6 @@ func (fi *FileInfo) SetMode(mode os.FileMode) { } // Sys provides the underlying data source (can return nil) -func (fi *FileInfo) Sys() interface{} { +func (fi *FileInfo) Sys() any { return nil } diff --git a/vfs/s3fs.go b/vfs/s3fs.go index 578e4a82..e2049464 100644 --- a/vfs/s3fs.go +++ b/vfs/s3fs.go @@ -95,7 +95,7 @@ func NewS3Fs(connectionID, localTempDir, mountPath string, s3Config S3FsConfig) credentials.NewStaticCredentialsProvider(fs.config.AccessKey, fs.config.AccessSecret.GetPayload(), "")) } if fs.config.Endpoint != "" { - endpointResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) { + endpointResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...any) (aws.Endpoint, error) { return aws.Endpoint{ URL: fs.config.Endpoint, HostnameImmutable: fs.config.ForcePathStyle, diff --git a/vfs/sftpfs.go b/vfs/sftpfs.go index 99368ab8..c50788eb 100644 --- a/vfs/sftpfs.go +++ b/vfs/sftpfs.go @@ -73,7 +73,7 @@ func (c *SFTPFsConfig) isEqual(other *SFTPFsConfig) bool { return false } for _, fp := range c.Fingerprints { - if !util.IsStringInSlice(fp, other.Fingerprints) { + if !util.Contains(other.Fingerprints, fp) { return false } } @@ -756,8 +756,8 @@ func (fs *SFTPFs) createConnection() error { User: fs.config.Username, HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { fp := ssh.FingerprintSHA256(key) - if util.IsStringInSlice(fp, sftpFingerprints) { - if util.IsStringInSlice(fs.config.Username, fs.config.forbiddenSelfUsernames) { + if util.Contains(sftpFingerprints, fp) { + if util.Contains(fs.config.forbiddenSelfUsernames, fs.config.Username) { fsLog(fs, logger.LevelError, "SFTP loop or nested local SFTP folders detected, mount path %#v, username %#v, forbidden usernames: %+v", fs.mountPath, fs.config.Username, fs.config.forbiddenSelfUsernames) return ErrSFTPLoop diff --git a/vfs/vfs.go b/vfs/vfs.go index c60a8bcf..dd16bb21 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -549,7 +549,7 @@ func (c *AzBlobFsConfig) validate() error { if err := c.checkPartSizeAndConcurrency(); err != nil { return err } - if !util.IsStringInSlice(c.AccessTier, validAzAccessTier) { + if !util.Contains(validAzAccessTier, c.AccessTier) { return fmt.Errorf("invalid access tier %#v, valid values: \"''%v\"", c.AccessTier, strings.Join(validAzAccessTier, ", ")) } return nil @@ -829,6 +829,6 @@ func getMountPath(mountPath string) string { return mountPath } -func fsLog(fs Fs, level logger.LogLevel, format string, v ...interface{}) { +func fsLog(fs Fs, level logger.LogLevel, format string, v ...any) { logger.Log(level, fs.Name(), fs.ConnectionID(), format, v...) } diff --git a/webdavd/handler.go b/webdavd/handler.go index 25d54d71..8a73ad56 100644 --- a/webdavd/handler.go +++ b/webdavd/handler.go @@ -359,7 +359,7 @@ func (c *Connection) orderDirsToRemove(fs vfs.Fs, dirsToRemove []objectMapping) for len(orderedDirs) < len(dirsToRemove) { for idx, d := range dirsToRemove { - if util.IsStringInSlice(d.fsPath, removedDirs) { + if util.Contains(removedDirs, d.fsPath) { continue } isEmpty := true @@ -367,7 +367,7 @@ func (c *Connection) orderDirsToRemove(fs vfs.Fs, dirsToRemove []objectMapping) if idx == idx1 { continue } - if util.IsStringInSlice(d1.fsPath, removedDirs) { + if util.Contains(removedDirs, d1.fsPath) { continue } if strings.HasPrefix(d1.fsPath, d.fsPath+pathSeparator) { diff --git a/webdavd/server.go b/webdavd/server.go index 9beccc61..a727250b 100644 --- a/webdavd/server.go +++ b/webdavd/server.go @@ -308,7 +308,7 @@ func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, lo user.Username, user.HomeDir) return connID, fmt.Errorf("cannot login user with invalid home dir: %#v", user.HomeDir) } - if util.IsStringInSlice(common.ProtocolWebDAV, user.Filters.DeniedProtocols) { + if util.Contains(user.Filters.DeniedProtocols, common.ProtocolWebDAV) { logger.Info(logSender, connectionID, "cannot login user %#v, protocol DAV is not allowed", user.Username) return connID, fmt.Errorf("protocol DAV is not allowed for user %#v", user.Username) } @@ -348,7 +348,7 @@ func writeLog(r *http.Request, status int, err error) { if r.TLS != nil { scheme = "https" } - fields := map[string]interface{}{ + fields := map[string]any{ "remote_addr": r.RemoteAddr, "proto": r.Proto, "method": r.Method, diff --git a/webdavd/webdavd_test.go b/webdavd/webdavd_test.go index 9c273fe3..13288f7d 100644 --- a/webdavd/webdavd_test.go +++ b/webdavd/webdavd_test.go @@ -386,7 +386,7 @@ func TestMain(m *testing.M) { }() go func() { - if err := httpdConf.Initialize(configDir); err != nil { + if err := httpdConf.Initialize(configDir, 0); err != nil { logger.ErrorToConsole("could not start HTTP server: %v", err) os.Exit(1) }