diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 2e3b6ae2..048c5e30 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -5,7 +5,7 @@ on: tags: 'v*' env: - GO_VERSION: 1.19.2 + GO_VERSION: 1.19.3 jobs: prepare-sources-with-deps: diff --git a/go.mod b/go.mod index dba0e44f..1f45d5e7 100644 --- a/go.mod +++ b/go.mod @@ -52,7 +52,7 @@ require ( github.com/rs/xid v1.4.0 github.com/rs/zerolog v1.28.0 github.com/sftpgo/sdk v0.1.2 - github.com/shirou/gopsutil/v3 v3.22.9 + github.com/shirou/gopsutil/v3 v3.22.10 github.com/spf13/afero v1.9.2 github.com/spf13/cobra v1.6.1 github.com/spf13/viper v1.13.0 diff --git a/go.sum b/go.sum index c1bfbc82..71ada88e 100644 --- a/go.sum +++ b/go.sum @@ -1459,8 +1459,8 @@ github.com/seccomp/libseccomp-golang v0.9.2-0.20210429002308-3879420cc921/go.mod github.com/secsy/goftp v0.0.0-20200609142545-aa2de14babf4 h1:PT+ElG/UUFMfqy5HrxJxNzj3QBOf7dZwupeVC+mG1Lo= github.com/sftpgo/sdk v0.1.2 h1:j4V63RuVcYfJAOWV0zRUofa1PlQvKU2ujly0lB7quVA= github.com/sftpgo/sdk v0.1.2/go.mod h1:PTp1TfXa+95wHw9yuZu7BA3vmzLqbRkz3gBmMNnwFQg= -github.com/shirou/gopsutil/v3 v3.22.9 h1:yibtJhIVEMcdw+tCTbOPiF1VcsuDeTE4utJ8Dm4c5eA= -github.com/shirou/gopsutil/v3 v3.22.9/go.mod h1:bBYl1kjgEJpWpxeHmLI+dVHWtyAwfcmSBLDsp2TNT8A= +github.com/shirou/gopsutil/v3 v3.22.10 h1:4KMHdfBRYXGF9skjDWiL4RA2N+E8dRdodU/bOZpPoVg= +github.com/shirou/gopsutil/v3 v3.22.10/go.mod h1:QNza6r4YQoydyCfo6rH0blGfKahgibh4dQmV5xdFkQk= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= diff --git a/internal/common/connection.go b/internal/common/connection.go index 52824b54..24a3db07 100644 --- a/internal/common/connection.go +++ b/internal/common/connection.go @@ -309,7 +309,7 @@ func (c *BaseConnection) ListDir(virtualPath string) ([]os.FileInfo, error) { // CheckParentDirs tries to create the specified directory and any missing parent dirs func (c *BaseConnection) CheckParentDirs(virtualPath string) error { - fs, err := c.User.GetFilesystemForPath(virtualPath, "") + fs, err := c.User.GetFilesystemForPath(virtualPath, c.GetID()) if err != nil { return err } @@ -321,7 +321,7 @@ func (c *BaseConnection) CheckParentDirs(virtualPath string) error { } dirs := util.GetDirsForVirtualPath(virtualPath) for idx := len(dirs) - 1; idx >= 0; idx-- { - fs, err = c.User.GetFilesystemForPath(dirs[idx], "") + fs, err = c.User.GetFilesystemForPath(dirs[idx], c.GetID()) if err != nil { return err } @@ -1509,6 +1509,7 @@ func (c *BaseConnection) GetGenericError(err error) error { err == ErrQuotaExceeded || err == vfs.ErrStorageSizeUnavailable || err == ErrShuttingDown { return err } + c.Log(logger.LevelError, "generic error: %+v", err) return ErrGenericFailure } } @@ -1536,7 +1537,7 @@ func (c *BaseConnection) GetFsAndResolvedPath(virtualPath string) (vfs.Fs, strin // will not be listed return nil, "", c.GetPermissionDeniedError() } - return nil, "", err + return nil, "", c.GetGenericError(err) } if isShuttingDown.Load() { diff --git a/internal/common/protocol_test.go b/internal/common/protocol_test.go index a0534b84..00a46141 100644 --- a/internal/common/protocol_test.go +++ b/internal/common/protocol_test.go @@ -6319,14 +6319,10 @@ func TestSFTPLoopError(t *testing.T) { conn = common.NewBaseConnection("", common.ProtocolSFTP, "", "", user1) _, _, err = conn.GetFsAndResolvedPath(user1.VirtualFolders[0].VirtualPath) - if assert.Error(t, err) { - assert.Contains(t, err.Error(), "SFTP loop") - } + assert.Error(t, err) conn = common.NewBaseConnection("", common.ProtocolFTP, "", "", user1) _, _, err = conn.GetFsAndResolvedPath(user1.VirtualFolders[0].VirtualPath) - if assert.Error(t, err) { - assert.Contains(t, err.Error(), "SFTP loop") - } + assert.Error(t, err) _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) assert.NoError(t, err) diff --git a/internal/dataprovider/admin.go b/internal/dataprovider/admin.go index a9660d32..5ce58fcb 100644 --- a/internal/dataprovider/admin.go +++ b/internal/dataprovider/admin.go @@ -15,14 +15,13 @@ package dataprovider import ( - "crypto/sha256" - "encoding/base64" "encoding/json" "errors" "fmt" "net" "os" "sort" + "strconv" "strings" "github.com/alexedwards/argon2id" @@ -548,12 +547,9 @@ func (a *Admin) CanManageMFA() bool { } // GetSignature returns a signature for this admin. -// It could change after an update +// It will change after an update func (a *Admin) GetSignature() string { - data := []byte(a.Username) - data = append(data, []byte(a.Password)...) - signature := sha256.Sum256(data) - return base64.StdEncoding.EncodeToString(signature[:]) + return strconv.FormatInt(a.UpdatedAt, 10) } func (a *Admin) getACopy() Admin { diff --git a/internal/dataprovider/user.go b/internal/dataprovider/user.go index a95a4ddc..1a12e0e0 100644 --- a/internal/dataprovider/user.go +++ b/internal/dataprovider/user.go @@ -15,8 +15,6 @@ package dataprovider import ( - "crypto/sha256" - "encoding/base64" "encoding/json" "errors" "fmt" @@ -24,9 +22,11 @@ import ( "net" "os" "path" + "strconv" "strings" "time" + "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/drakkan/sftpgo/v2/internal/kms" @@ -600,7 +600,7 @@ func (u *User) GetVirtualFolderForPath(virtualPath string) (vfs.VirtualFolder, e // CheckMetadataConsistency checks the consistency between the metadata stored // in the configured metadata plugin and the filesystem func (u *User) CheckMetadataConsistency() error { - fs, err := u.getRootFs("") + fs, err := u.getRootFs(xid.New().String()) if err != nil { return err } @@ -621,7 +621,7 @@ func (u *User) CheckMetadataConsistency() error { // ScanQuota scans the user home dir and virtual folders, included in its quota, // and returns the number of files and their size func (u *User) ScanQuota() (int, int64, error) { - fs, err := u.getRootFs("") + fs, err := u.getRootFs(xid.New().String()) if err != nil { return 0, 0, err } @@ -1131,12 +1131,9 @@ func (u *User) MustSetSecondFactorForProtocol(protocol string) bool { } // GetSignature returns a signature for this admin. -// It could change after an update +// It will change after an update func (u *User) GetSignature() string { - data := []byte(fmt.Sprintf("%v_%v_%v", u.Username, u.Status, u.ExpirationDate)) - data = append(data, []byte(u.Password)...) - signature := sha256.Sum256(data) - return base64.StdEncoding.EncodeToString(signature[:]) + return strconv.FormatInt(u.UpdatedAt, 10) } // GetBandwidthForIP returns the upload and download bandwidth for the specified IP diff --git a/internal/httpd/httpd_test.go b/internal/httpd/httpd_test.go index 11bbef7c..10201f55 100644 --- a/internal/httpd/httpd_test.go +++ b/internal/httpd/httpd_test.go @@ -4516,7 +4516,7 @@ func TestUserSFTPFs(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) user.FsConfig.Provider = sdk.SFTPFilesystemProvider - user.FsConfig.SFTPConfig.Endpoint = "127.0.0.1" // missing port + user.FsConfig.SFTPConfig.Endpoint = "[::1]:22:22" // invalid endpoint user.FsConfig.SFTPConfig.Username = "sftp_user" user.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("sftp_pwd") user.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(sftpPrivateKey) @@ -4527,6 +4527,13 @@ func TestUserSFTPFs(t *testing.T) { assert.NoError(t, err) assert.Contains(t, string(resp), "invalid endpoint") + user.FsConfig.SFTPConfig.Endpoint = "127.0.0.1" + _, _, err = httpdtest.UpdateUser(user, http.StatusBadRequest, "") + assert.Error(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, "127.0.0.1:22", user.FsConfig.SFTPConfig.Endpoint) + user.FsConfig.SFTPConfig.Endpoint = "127.0.0.1:2022" user.FsConfig.SFTPConfig.DisableCouncurrentReads = true user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") diff --git a/internal/sftpd/handler.go b/internal/sftpd/handler.go index 73be378c..60e076a7 100644 --- a/internal/sftpd/handler.go +++ b/internal/sftpd/handler.go @@ -182,9 +182,6 @@ func (c *Connection) handleFilewrite(request *sftp.Request) (sftp.WriterAtReader func (c *Connection) Filecmd(request *sftp.Request) error { c.UpdateLastActivity() - c.Log(logger.LevelDebug, "new cmd, method: %v, sourcePath: %#v, targetPath: %#v", request.Method, - request.Filepath, request.Target) - switch request.Method { case "Setstat": return c.handleSFTPSetstat(request) diff --git a/internal/sftpd/server.go b/internal/sftpd/server.go index 28b7ad54..e4a65c10 100644 --- a/internal/sftpd/server.go +++ b/internal/sftpd/server.go @@ -527,14 +527,12 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve loginType := sconn.Permissions.Extensions["sftpgo_login_method"] connectionID := hex.EncodeToString(sconn.SessionID()) + defer user.CloseFs() //nolint:errcheck if err = user.CheckFsRoot(connectionID); err != nil { - errClose := user.CloseFs() - logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose) + logger.Warn(logSender, connectionID, "unable to check fs root: %v", err) return } - defer user.CloseFs() //nolint:errcheck - logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID, "User %#v logged in with %#v, from ip %#v, client version %#v", user.Username, loginType, ipAddr, string(sconn.ClientVersion())) diff --git a/internal/sftpd/sftpd_test.go b/internal/sftpd/sftpd_test.go index a8531ddb..7d41ae97 100644 --- a/internal/sftpd/sftpd_test.go +++ b/internal/sftpd/sftpd_test.go @@ -5279,9 +5279,7 @@ func TestSFTPLoopVirtualFolders(t *testing.T) { defer client.Close() assert.NoError(t, checkBasicSFTP(client)) _, err = client.ReadDir("/vdir") - if assert.Error(t, err) { - assert.Contains(t, err.Error(), "SFTP loop") - } + assert.Error(t, err) } // now make user2 a local account with an SFTP virtual folder to user1. // So we have: @@ -5316,9 +5314,7 @@ func TestSFTPLoopVirtualFolders(t *testing.T) { defer client.Close() assert.NoError(t, checkBasicSFTP(client)) _, err = client.ReadDir("/vdir") - if assert.Error(t, err) { - assert.Contains(t, err.Error(), "SFTP loop") - } + assert.Error(t, err) } _, err = httpdtest.RemoveUser(user1, http.StatusOK) @@ -10080,7 +10076,7 @@ func TestSCPNestedFolders(t *testing.T) { // now change the password for the base user, so SFTP folder will not work baseUser.Password = defaultPassword + "_mod" - _, _, err = httpdtest.UpdateUser(baseUser, http.StatusOK, "") + _, _, err = httpdtest.UpdateUser(baseUser, http.StatusOK, "1") assert.NoError(t, err) err = scpUpload(filepath.Join(baseDir, "vdir"), remoteRootPath, true, false) diff --git a/internal/vfs/azblobfs.go b/internal/vfs/azblobfs.go index 2e994e74..e17d2e8e 100644 --- a/internal/vfs/azblobfs.go +++ b/internal/vfs/azblobfs.go @@ -566,6 +566,9 @@ func (fs *AzureBlobFs) ScanRootDirContents() (int, int64, error) { } numFiles++ size += blobSize + if numFiles%1000 == 0 { + fsLog(fs, logger.LevelDebug, "root dir scan in progress, files: %d, size: %d", numFiles, size) + } } } } diff --git a/internal/vfs/folder.go b/internal/vfs/folder.go index a6d055d9..c1a70978 100644 --- a/internal/vfs/folder.go +++ b/internal/vfs/folder.go @@ -20,6 +20,7 @@ import ( "strconv" "strings" + "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/drakkan/sftpgo/v2/internal/util" @@ -213,7 +214,7 @@ func (v *VirtualFolder) GetFilesystem(connectionID string, forbiddenSelfUsers [] // CheckMetadataConsistency checks the consistency between the metadata stored // in the configured metadata plugin and the filesystem func (v *VirtualFolder) CheckMetadataConsistency() error { - fs, err := v.GetFilesystem("", nil) + fs, err := v.GetFilesystem(xid.New().String(), nil) if err != nil { return err } @@ -227,7 +228,7 @@ func (v *VirtualFolder) ScanQuota() (int, int64, error) { if v.hasPathPlaceholder() { return 0, 0, errors.New("cannot scan quota: this folder has a path placeholder") } - fs, err := v.GetFilesystem("", nil) + fs, err := v.GetFilesystem(xid.New().String(), nil) if err != nil { return 0, 0, err } diff --git a/internal/vfs/gcsfs.go b/internal/vfs/gcsfs.go index 8ba82c89..fe4f6c4c 100644 --- a/internal/vfs/gcsfs.go +++ b/internal/vfs/gcsfs.go @@ -534,6 +534,9 @@ func (fs *GCSFs) ScanRootDirContents() (int, int64, error) { } numFiles++ size += attrs.Size + if numFiles%1000 == 0 { + fsLog(fs, logger.LevelDebug, "root dir scan in progress, files: %d, size: %d", numFiles, size) + } } objects = nil diff --git a/internal/vfs/osfs.go b/internal/vfs/osfs.go index a1ea3700..a5d1705e 100644 --- a/internal/vfs/osfs.go +++ b/internal/vfs/osfs.go @@ -382,6 +382,9 @@ func (fs *OsFs) GetDirSize(dirname string) (int, int64, error) { if info != nil && info.Mode().IsRegular() { size += info.Size() numFiles++ + if numFiles%1000 == 0 { + fsLog(fs, logger.LevelDebug, "dirname %q scan in progress, files: %d, size: %d", dirname, numFiles, size) + } } return err }) diff --git a/internal/vfs/s3fs.go b/internal/vfs/s3fs.go index 240f952b..7f172a6f 100644 --- a/internal/vfs/s3fs.go +++ b/internal/vfs/s3fs.go @@ -600,6 +600,9 @@ func (fs *S3Fs) ScanRootDirContents() (int, int64, error) { } numFiles++ size += fileObject.Size + if numFiles%1000 == 0 { + fsLog(fs, logger.LevelDebug, "root dir scan in progress, files: %d, size: %d", numFiles, size) + } } } diff --git a/internal/vfs/sftpfs.go b/internal/vfs/sftpfs.go index 38555583..617dfe5b 100644 --- a/internal/vfs/sftpfs.go +++ b/internal/vfs/sftpfs.go @@ -16,8 +16,10 @@ package vfs import ( "bufio" + "bytes" "errors" "fmt" + "hash/fnv" "io" "io/fs" "net" @@ -25,12 +27,15 @@ import ( "os" "path" "path/filepath" + "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/eikenb/pipeat" "github.com/pkg/sftp" + "github.com/robfig/cron/v3" "github.com/rs/xid" "github.com/sftpgo/sdk" "golang.org/x/crypto/ssh" @@ -43,11 +48,16 @@ import ( const ( // sftpFsName is the name for the SFTP Fs implementation - sftpFsName = "sftpfs" + sftpFsName = "sftpfs" + logSenderSFTPCache = "sftpCache" + maxSessionsPerConnection = 5 ) -// ErrSFTPLoop defines the error to return if an SFTP loop is detected -var ErrSFTPLoop = errors.New("SFTP loop or nested local SFTP folders detected") +var ( + // ErrSFTPLoop defines the error to return if an SFTP loop is detected + ErrSFTPLoop = errors.New("SFTP loop or nested local SFTP folders detected") + sftpConnsCache = newSFTPConnectionCache() +) // SFTPFsConfig defines the configuration for SFTP based filesystem type SFTPFsConfig struct { @@ -145,6 +155,9 @@ func (c *SFTPFsConfig) validate() error { if c.Endpoint == "" { return errors.New("endpoint cannot be empty") } + if !strings.Contains(c.Endpoint, ":") { + c.Endpoint += ":22" + } _, _, err := net.SplitHostPort(c.Endpoint) if err != nil { return fmt.Errorf("invalid endpoint: %v", err) @@ -220,17 +233,36 @@ func (c *SFTPFsConfig) ValidateAndEncryptCredentials(additionalData string) erro return nil } +// getUniqueID returns an hash of the settings used to connect to the SFTP server +func (c *SFTPFsConfig) getUniqueID(partition int) uint64 { + h := fnv.New64a() + var b bytes.Buffer + + b.WriteString(c.Endpoint) + b.WriteString(c.Username) + b.WriteString(strings.Join(c.Fingerprints, "")) + b.WriteString(strconv.FormatBool(c.DisableCouncurrentReads)) + b.WriteString(strconv.FormatInt(c.BufferSize, 10)) + b.WriteString(c.Password.GetPayload()) + b.WriteString(c.PrivateKey.GetPayload()) + b.WriteString(c.KeyPassphrase.GetPayload()) + if allowSelfConnections != 0 { + b.WriteString(strings.Join(c.forbiddenSelfUsernames, "")) + } + b.WriteString(strconv.Itoa(partition)) + + h.Write(b.Bytes()) + return h.Sum64() +} + // SFTPFs is a Fs implementation for SFTP backends type SFTPFs struct { - sync.Mutex connectionID string // if not empty this fs is mouted as virtual folder in the specified path mountPath string localTempDir string config *SFTPFsConfig - sshClient *ssh.Client - sftpClient *sftp.Client - err chan error + conn *sftpConnection } // NewSFTPFs returns an SFTPFs object that allows to interact with an SFTP server @@ -266,15 +298,18 @@ func NewSFTPFs(connectionID, mountPath, localTempDir string, forbiddenSelfUserna mountPath: getMountPath(mountPath), localTempDir: localTempDir, config: &config, - err: make(chan error, 1), + conn: sftpConnsCache.Get(&config, connectionID), } err := sftpFs.createConnection() + if err != nil { + sftpFs.Close() //nolint:errcheck + } return sftpFs, err } // Name returns the name for the Fs implementation func (fs *SFTPFs) Name() string { - return fmt.Sprintf("%v %#v", sftpFsName, fs.config.Endpoint) + return fmt.Sprintf(`%s %q@%q`, sftpFsName, fs.config.Username, fs.config.Endpoint) } // ConnectionID returns the connection ID associated to this Fs implementation @@ -284,26 +319,29 @@ func (fs *SFTPFs) ConnectionID() string { // Stat returns a FileInfo describing the named file func (fs *SFTPFs) Stat(name string) (os.FileInfo, error) { - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return nil, err } - return fs.sftpClient.Stat(name) + return client.Stat(name) } // Lstat returns a FileInfo describing the named file func (fs *SFTPFs) Lstat(name string) (os.FileInfo, error) { - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return nil, err } - return fs.sftpClient.Lstat(name) + return client.Lstat(name) } // Open opens the named file for reading func (fs *SFTPFs) Open(name string, offset int64) (File, *pipeat.PipeReaderAt, func(), error) { - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return nil, nil, nil, err } - f, err := fs.sftpClient.Open(name) + f, err := client.Open(name) if err != nil { return nil, nil, nil, err } @@ -337,21 +375,21 @@ func (fs *SFTPFs) Open(name string, offset int64) (File, *pipeat.PipeReaderAt, f // Create creates or opens the named file for writing func (fs *SFTPFs) Create(name string, flag int) (File, *PipeWriter, func(), error) { - err := fs.checkConnection() + client, err := fs.conn.getClient() if err != nil { return nil, nil, nil, err } if fs.config.BufferSize == 0 { var f File if flag == 0 { - f, err = fs.sftpClient.Create(name) + f, err = client.Create(name) } else { - f, err = fs.sftpClient.OpenFile(name, flag) + f, err = client.OpenFile(name, flag) } return f, nil, nil, err } // buffering is enabled - f, err := fs.sftpClient.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC) + f, err := client.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC) if err != nil { return nil, nil, nil, err } @@ -393,48 +431,53 @@ func (fs *SFTPFs) Rename(source, target string) error { if source == target { return nil } - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return err } - if _, ok := fs.sftpClient.HasExtension("posix-rename@openssh.com"); ok { - return fs.sftpClient.PosixRename(source, target) + if _, ok := client.HasExtension("posix-rename@openssh.com"); ok { + return client.PosixRename(source, target) } - return fs.sftpClient.Rename(source, target) + return client.Rename(source, target) } // Remove removes the named file or (empty) directory. func (fs *SFTPFs) Remove(name string, isDir bool) error { - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return err } if isDir { - return fs.sftpClient.RemoveDirectory(name) + return client.RemoveDirectory(name) } - return fs.sftpClient.Remove(name) + return client.Remove(name) } // Mkdir creates a new directory with the specified name and default permissions func (fs *SFTPFs) Mkdir(name string) error { - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return err } - return fs.sftpClient.Mkdir(name) + return client.Mkdir(name) } // Symlink creates source as a symbolic link to target. func (fs *SFTPFs) Symlink(source, target string) error { - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return err } - return fs.sftpClient.Symlink(source, target) + return client.Symlink(source, target) } // Readlink returns the destination of the named symbolic link func (fs *SFTPFs) Readlink(name string) (string, error) { - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return "", err } - resolved, err := fs.sftpClient.ReadLink(name) + resolved, err := client.ReadLink(name) if err != nil { return resolved, err } @@ -448,43 +491,48 @@ func (fs *SFTPFs) Readlink(name string) (string, error) { // Chown changes the numeric uid and gid of the named file. func (fs *SFTPFs) Chown(name string, uid int, gid int) error { - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return err } - return fs.sftpClient.Chown(name, uid, gid) + return client.Chown(name, uid, gid) } // Chmod changes the mode of the named file to mode. func (fs *SFTPFs) Chmod(name string, mode os.FileMode) error { - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return err } - return fs.sftpClient.Chmod(name, mode) + return client.Chmod(name, mode) } // Chtimes changes the access and modification times of the named file. func (fs *SFTPFs) Chtimes(name string, atime, mtime time.Time, isUploading bool) error { - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return err } - return fs.sftpClient.Chtimes(name, atime, mtime) + return client.Chtimes(name, atime, mtime) } // Truncate changes the size of the named file. func (fs *SFTPFs) Truncate(name string, size int64) error { - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return err } - return fs.sftpClient.Truncate(name, size) + return client.Truncate(name, size) } // ReadDir reads the directory named by dirname and returns // a list of directory entries. func (fs *SFTPFs) ReadDir(dirname string) ([]os.FileInfo, error) { - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return nil, err } - return fs.sftpClient.ReadDir(dirname) + return client.ReadDir(dirname) } // IsUploadResumeSupported returns true if resuming uploads is supported. @@ -528,11 +576,12 @@ func (fs *SFTPFs) CheckRootPath(username string, uid int, gid int) bool { if fs.config.Prefix == "/" { return true } - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return false } - if err := fs.sftpClient.MkdirAll(fs.config.Prefix); err != nil { - fsLog(fs, logger.LevelDebug, "error creating root directory %#v for user %#v: %v", fs.config.Prefix, username, err) + if err := client.MkdirAll(fs.config.Prefix); err != nil { + fsLog(fs, logger.LevelDebug, "error creating root directory %q for user %q: %v", fs.config.Prefix, username, err) return false } return true @@ -581,10 +630,11 @@ func (fs *SFTPFs) GetRelativePath(name string) string { // Walk walks the file tree rooted at root, calling walkFn for each file or // directory in the tree, including root func (fs *SFTPFs) Walk(root string, walkFn filepath.WalkFunc) error { - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return err } - walker := fs.sftpClient.Walk(root) + walker := client.Walk(root) for walker.Step() { err := walker.Err() if err != nil { @@ -620,9 +670,6 @@ func (fs *SFTPFs) ResolvePath(virtualPath string) (string, error) { if fs.config.Prefix != "/" && fsPath != "/" { // we need to check if this path is a symlink outside the given prefix // or a file/dir inside a dir symlinked outside the prefix - if err := fs.checkConnection(); err != nil { - return "", err - } var validatedPath string var err error validatedPath, err = fs.getRealPath(fsPath) @@ -657,10 +704,11 @@ func (fs *SFTPFs) ResolvePath(virtualPath string) (string, error) { // RealPath implements the FsRealPather interface func (fs *SFTPFs) RealPath(p string) (string, error) { - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return "", err } - resolved, err := fs.sftpClient.RealPath(p) + resolved, err := client.RealPath(p) if err != nil { return "", err } @@ -676,16 +724,20 @@ func (fs *SFTPFs) RealPath(p string) (string, error) { // getRealPath returns the real remote path trying to resolve symbolic links if any func (fs *SFTPFs) getRealPath(name string) (string, error) { + client, err := fs.conn.getClient() + if err != nil { + return "", err + } linksWalked := 0 for { - info, err := fs.sftpClient.Lstat(name) + info, err := client.Lstat(name) if err != nil { return name, err } if info.Mode()&os.ModeSymlink == 0 { return name, nil } - resolvedLink, err := fs.sftpClient.ReadLink(name) + resolvedLink, err := client.ReadLink(name) if err != nil { return name, fmt.Errorf("unable to resolve link to %q: %w", name, err) } @@ -723,12 +775,13 @@ func (fs *SFTPFs) isSubDir(name string) error { func (fs *SFTPFs) GetDirSize(dirname string) (int, int64, error) { numFiles := 0 size := int64(0) - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return numFiles, size, err } isDir, err := isDirectory(fs, dirname) if err == nil && isDir { - walker := fs.sftpClient.Walk(dirname) + walker := client.Walk(dirname) for walker.Step() { err := walker.Err() if err != nil { @@ -737,6 +790,9 @@ func (fs *SFTPFs) GetDirSize(dirname string) (int, int64, error) { if walker.Stat().Mode().IsRegular() { size += walker.Stat().Size() numFiles++ + if numFiles%1000 == 0 { + fsLog(fs, logger.LevelDebug, "dirname %q scan in progress, files: %d, size: %d", dirname, numFiles, size) + } } } } @@ -745,10 +801,11 @@ func (fs *SFTPFs) GetDirSize(dirname string) (int, int64, error) { // GetMimeType returns the content type func (fs *SFTPFs) GetMimeType(name string) (string, error) { - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return "", err } - f, err := fs.sftpClient.OpenFile(name, os.O_RDONLY) + f, err := client.OpenFile(name, os.O_RDONLY) if err != nil { return "", err } @@ -766,31 +823,20 @@ func (fs *SFTPFs) GetMimeType(name string) (string, error) { // GetAvailableDiskSize returns the available size for the specified path func (fs *SFTPFs) GetAvailableDiskSize(dirName string) (*sftp.StatVFS, error) { - if err := fs.checkConnection(); err != nil { + client, err := fs.conn.getClient() + if err != nil { return nil, err } - if _, ok := fs.sftpClient.HasExtension("statvfs@openssh.com"); !ok { + if _, ok := client.HasExtension("statvfs@openssh.com"); !ok { return nil, ErrStorageSizeUnavailable } - return fs.sftpClient.StatVFS(dirName) + return client.StatVFS(dirName) } // Close the connection func (fs *SFTPFs) Close() error { - fs.Lock() - defer fs.Unlock() - - var sftpErr, sshErr error - if fs.sftpClient != nil { - sftpErr = fs.sftpClient.Close() - } - if fs.sshClient != nil { - sshErr = fs.sshClient.Close() - } - if sftpErr != nil { - return sftpErr - } - return sshErr + fs.conn.RemoveSession(fs.connectionID) + return nil } func (fs *SFTPFs) copy(dst io.Writer, src io.Reader) (written int64, err error) { @@ -825,64 +871,98 @@ func (fs *SFTPFs) copy(dst io.Writer, src io.Reader) (written int64, err error) return written, err } -func (fs *SFTPFs) checkConnection() error { - err := fs.closed() - if err == nil { - return nil +func (fs *SFTPFs) createConnection() error { + err := fs.conn.OpenConnection() + if err != nil { + fsLog(fs, logger.LevelError, "error opening connection: %v", err) + return err } - return fs.createConnection() + return nil } -func (fs *SFTPFs) createConnection() error { - fs.Lock() - defer fs.Unlock() +type sftpConnection struct { + config *SFTPFsConfig + logSender string + sshClient *ssh.Client + sftpClient *sftp.Client + mu sync.RWMutex + isConnected bool + sessions map[string]bool + lastActivity time.Time +} - var err error +func newSFTPConnection(config *SFTPFsConfig, sessionID string) *sftpConnection { + c := &sftpConnection{ + config: config, + logSender: fmt.Sprintf(`%s "%s@%s"`, sftpFsName, config.Username, config.Endpoint), + isConnected: false, + sessions: map[string]bool{}, + lastActivity: time.Now().UTC(), + } + c.sessions[sessionID] = true + return c +} + +func (c *sftpConnection) OpenConnection() error { + c.mu.Lock() + defer c.mu.Unlock() + + return c.openConnNoLock() +} + +func (c *sftpConnection) openConnNoLock() error { + if c.isConnected { + logger.Debug(c.logSender, "", "reusing connection") + return nil + } + + logger.Debug(c.logSender, "", "try to open a new connection") clientConfig := &ssh.ClientConfig{ - User: fs.config.Username, + User: c.config.Username, HostKeyCallback: func(_ string, _ net.Addr, key ssh.PublicKey) error { fp := ssh.FingerprintSHA256(key) if util.Contains(sftpFingerprints, fp) { if allowSelfConnections == 0 { - fsLog(fs, logger.LevelError, "SFTP self connections not allowed") + logger.Log(logger.LevelError, c.logSender, "", "SFTP self connections not allowed") return ErrSFTPLoop } - if util.Contains(fs.config.forbiddenSelfUsernames, fs.config.Username) { - fsLog(fs, logger.LevelError, "SFTP loop or nested local SFTP folders detected, mount path %q, username %q, forbidden usernames: %+v", - fs.mountPath, fs.config.Username, fs.config.forbiddenSelfUsernames) + if util.Contains(c.config.forbiddenSelfUsernames, c.config.Username) { + logger.Log(logger.LevelError, c.logSender, "", + "SFTP loop or nested local SFTP folders detected, username %q, forbidden usernames: %+v", + c.config.Username, c.config.forbiddenSelfUsernames) return ErrSFTPLoop } } - if len(fs.config.Fingerprints) > 0 { - for _, provided := range fs.config.Fingerprints { + if len(c.config.Fingerprints) > 0 { + for _, provided := range c.config.Fingerprints { if provided == fp { return nil } } - return fmt.Errorf("invalid fingerprint %#v", fp) + return fmt.Errorf("invalid fingerprint %q", fp) } - fsLog(fs, logger.LevelWarn, "login without host key validation, please provide at least a fingerprint!") + logger.Log(logger.LevelWarn, c.logSender, "", "login without host key validation, please provide at least a fingerprint!") return nil }, Timeout: 10 * time.Second, ClientVersion: fmt.Sprintf("SSH-2.0-SFTPGo_%v", version.Get().Version), } - if fs.config.PrivateKey.GetPayload() != "" { + if c.config.PrivateKey.GetPayload() != "" { var signer ssh.Signer - if fs.config.KeyPassphrase.GetPayload() != "" { - signer, err = ssh.ParsePrivateKeyWithPassphrase([]byte(fs.config.PrivateKey.GetPayload()), - []byte(fs.config.KeyPassphrase.GetPayload())) + var err error + if c.config.KeyPassphrase.GetPayload() != "" { + signer, err = ssh.ParsePrivateKeyWithPassphrase([]byte(c.config.PrivateKey.GetPayload()), + []byte(c.config.KeyPassphrase.GetPayload())) } else { - signer, err = ssh.ParsePrivateKey([]byte(fs.config.PrivateKey.GetPayload())) + signer, err = ssh.ParsePrivateKey([]byte(c.config.PrivateKey.GetPayload())) } if err != nil { - fs.err <- err return fmt.Errorf("sftpfs: unable to parse the private key: %w", err) } clientConfig.Auth = append(clientConfig.Auth, ssh.PublicKeys(signer)) } - if fs.config.Password.GetPayload() != "" { - clientConfig.Auth = append(clientConfig.Auth, ssh.Password(fs.config.Password.GetPayload())) + if c.config.Password.GetPayload() != "" { + clientConfig.Auth = append(clientConfig.Auth, ssh.Password(c.config.Password.GetPayload())) } // add more ciphers, KEXs and MACs, they are negotiated according to the order clientConfig.Ciphers = []string{"aes128-gcm@openssh.com", "aes256-gcm@openssh.com", "chacha20-poly1305@openssh.com", @@ -895,52 +975,225 @@ func (fs *SFTPFs) createConnection() error { clientConfig.MACs = []string{"hmac-sha2-256-etm@openssh.com", "hmac-sha2-256", "hmac-sha2-512-etm@openssh.com", "hmac-sha2-512", "hmac-sha1", "hmac-sha1-96"} - fs.sshClient, err = ssh.Dial("tcp", fs.config.Endpoint, clientConfig) + sshClient, err := ssh.Dial("tcp", c.config.Endpoint, clientConfig) if err != nil { - fsLog(fs, logger.LevelError, "unable to connect: %v", err) - fs.err <- err - return err + return fmt.Errorf("sftpfs: unable to connect: %w", err) } - fs.sftpClient, err = sftp.NewClient(fs.sshClient) + sftpClient, err := sftp.NewClient(sshClient, c.getClientOptions()...) if err != nil { - fsLog(fs, logger.LevelError, "unable to create SFTP client: %v", err) - fs.sshClient.Close() - fs.err <- err - return err + sshClient.Close() + return fmt.Errorf("sftpfs: unable to create SFTP client: %w", err) } - if fs.config.DisableCouncurrentReads { - fsLog(fs, logger.LevelDebug, "disabling concurrent reads") - opt := sftp.UseConcurrentReads(false) - opt(fs.sftpClient) //nolint:errcheck - } - if fs.config.BufferSize > 0 { - fsLog(fs, logger.LevelDebug, "enabling concurrent writes") - opt := sftp.UseConcurrentWrites(true) - opt(fs.sftpClient) //nolint:errcheck - } - go fs.wait() + c.sshClient = sshClient + c.sftpClient = sftpClient + c.isConnected = true + go c.Wait() return nil } -func (fs *SFTPFs) wait() { +func (c *sftpConnection) getClientOptions() []sftp.ClientOption { + var options []sftp.ClientOption + if c.config.DisableCouncurrentReads { + options = append(options, sftp.UseConcurrentReads(false)) + logger.Debug(c.logSender, "", "disabling concurrent reads") + } + if c.config.BufferSize > 0 { + options = append(options, sftp.UseConcurrentWrites(true)) + logger.Debug(c.logSender, "", "enabling concurrent writes") + } + return options +} + +func (c *sftpConnection) getClient() (*sftp.Client, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.isConnected { + return c.sftpClient, nil + } + err := c.openConnNoLock() + return c.sftpClient, err +} + +func (c *sftpConnection) Wait() { + waitEnd := make(chan struct{}) + ticker := time.NewTicker(30 * time.Second) + + go func() { + var watchdogInProgress atomic.Bool + for { + select { + case <-ticker.C: + if watchdogInProgress.Load() { + logger.Error(c.logSender, "", "watchdog still in progress, closing hanging connection") + ticker.Stop() + c.sshClient.Close() + return + } + go func() { + watchdogInProgress.Store(true) + defer watchdogInProgress.Store(false) + + _, err := c.sftpClient.Getwd() + if err != nil { + logger.Error(c.logSender, "", "watchdog error: %v", err) + } + }() + case <-waitEnd: + logger.Debug(c.logSender, "", "quitting watchdog") + ticker.Stop() + return + } + } + }() + // we wait on the sftp client otherwise if the channel is closed but not the connection // we don't detect the event. - fs.err <- fs.sftpClient.Wait() - fsLog(fs, logger.LevelDebug, "sftp channel closed") + err := c.sftpClient.Wait() + logger.Log(logger.LevelDebug, c.logSender, "", "sftp channel closed: %v", err) + close(waitEnd) - fs.Lock() - defer fs.Unlock() + c.mu.Lock() + defer c.mu.Unlock() - if fs.sshClient != nil { - fs.sshClient.Close() + c.isConnected = false + if c.sshClient != nil { + c.sshClient.Close() } } -func (fs *SFTPFs) closed() error { - select { - case err := <-fs.err: - return err - default: - return nil +func (c *sftpConnection) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + logger.Debug(c.logSender, "", "closing connection") + var sftpErr, sshErr error + if c.sftpClient != nil { + sftpErr = c.sftpClient.Close() + } + if c.sshClient != nil { + sshErr = c.sshClient.Close() + } + if sftpErr != nil { + return sftpErr + } + c.isConnected = false + return sshErr +} + +func (c *sftpConnection) AddSession(sessionID string) { + c.mu.Lock() + defer c.mu.Unlock() + + c.sessions[sessionID] = true + logger.Debug(c.logSender, "", "added session %s, active sessions: %d", sessionID, len(c.sessions)) +} + +func (c *sftpConnection) RemoveSession(sessionID string) { + c.mu.Lock() + defer c.mu.Unlock() + + delete(c.sessions, sessionID) + logger.Debug(c.logSender, "", "removed session %s, active sessions: %d", sessionID, len(c.sessions)) + if len(c.sessions) == 0 { + c.lastActivity = time.Now().UTC() } } + +func (c *sftpConnection) ActiveSessions() int { + c.mu.RLock() + defer c.mu.RUnlock() + + return len(c.sessions) +} + +func (c *sftpConnection) GetLastActivity() time.Time { + c.mu.RLock() + defer c.mu.RUnlock() + + if len(c.sessions) > 0 { + return time.Now().UTC() + } + logger.Debug(c.logSender, "", "last activity %s", c.lastActivity) + return c.lastActivity +} + +type sftpConnectionsCache struct { + scheduler *cron.Cron + sync.RWMutex + items map[uint64]*sftpConnection +} + +func newSFTPConnectionCache() *sftpConnectionsCache { + c := &sftpConnectionsCache{ + scheduler: cron.New(), + items: make(map[uint64]*sftpConnection), + } + _, err := c.scheduler.AddFunc("@every 1m", c.Cleanup) + util.PanicOnError(err) + c.scheduler.Start() + return c +} + +func (c *sftpConnectionsCache) Get(config *SFTPFsConfig, sessionID string) *sftpConnection { + partition := 0 + key := config.getUniqueID(partition) + + c.Lock() + defer c.Unlock() + + var oldKey uint64 + for { + if val, ok := c.items[key]; ok { + activeSessions := val.ActiveSessions() + if activeSessions < maxSessionsPerConnection || key == oldKey { + logger.Debug(logSenderSFTPCache, "", + "reusing connection for session ID %q, key: %d, active sessions %d, active connections: %d", + sessionID, key, activeSessions+1, len(c.items)) + val.AddSession(sessionID) + return val + } + partition++ + oldKey = key + key = config.getUniqueID(partition) + logger.Debug(logSenderSFTPCache, "", + "connection full, generated new key for partition: %d, active sessions: %d, key: %d, old key: %d", + partition, activeSessions, oldKey, key) + } else { + conn := newSFTPConnection(config, sessionID) + c.items[key] = conn + logger.Debug(logSenderSFTPCache, "", + "adding new connection for session ID %q, partition: %d, key: %d, active connections: %d", + sessionID, partition, key, len(c.items)) + return conn + } + } +} + +func (c *sftpConnectionsCache) Remove(key uint64) { + c.Lock() + defer c.Unlock() + + if conn, ok := c.items[key]; ok { + delete(c.items, key) + logger.Debug(logSenderSFTPCache, "", "removed connection with key %d, active connections: %d", key, len(c.items)) + + defer conn.Close() + } +} + +func (c *sftpConnectionsCache) Cleanup() { + c.RLock() + + for k, conn := range c.items { + if val := conn.GetLastActivity(); val.Before(time.Now().Add(-30 * time.Second)) { + logger.Debug(conn.logSender, "", "removing inactive connection, last activity %s", val) + + defer func(key uint64) { + c.Remove(key) + }(k) + } + } + + c.RUnlock() +} diff --git a/internal/webdavd/webdavd_test.go b/internal/webdavd/webdavd_test.go index 785080c1..c7a42473 100644 --- a/internal/webdavd/webdavd_test.go +++ b/internal/webdavd/webdavd_test.go @@ -922,12 +922,13 @@ func TestPropPatch(t *testing.T) { assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) - assert.Len(t, common.Connections.GetStats(), 0) } _, err = httpdtest.RemoveUser(localUser, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats()) == 0 }, + 1*time.Second, 100*time.Millisecond) } func TestLoginInvalidPwd(t *testing.T) { @@ -1898,6 +1899,11 @@ func TestClientClose(t *testing.T) { common.Connections.Close(stat.ConnectionID) } wg.Wait() + // for the sftp user a stat is done after the failed upload and + // this triggers a new connection + for _, stat := range common.Connections.GetStats() { + common.Connections.Close(stat.ConnectionID) + } assert.Eventually(t, func() bool { return len(common.Connections.GetStats()) == 0 }, 1*time.Second, 100*time.Millisecond)