From d6dc3a507ed56a794bb9f9c8a4fb3321eccb1c7b Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sun, 21 Mar 2021 19:15:47 +0100 Subject: [PATCH] extend virtual folders support to all storage backends Fixes #241 --- cmd/portable.go | 14 +- common/actions.go | 7 +- common/actions_test.go | 14 +- common/common_test.go | 197 +-- common/connection.go | 394 +++--- common/connection_test.go | 1131 ++-------------- common/protocol_test.go | 2415 ++++++++++++++++++++++++++++++++++ common/transfer.go | 8 +- common/transfer_test.go | 28 +- dataprovider/bolt.go | 90 +- dataprovider/dataprovider.go | 318 +++-- dataprovider/memory.go | 83 +- dataprovider/mysql.go | 4 +- dataprovider/pgsql.go | 4 +- dataprovider/sqlcommon.go | 126 +- dataprovider/sqlite.go | 6 +- dataprovider/sqlqueries.go | 22 +- dataprovider/user.go | 488 ++++--- docs/ssh-commands.md | 2 +- ftpd/cryptfs_test.go | 3 +- ftpd/ftpd_test.go | 127 +- ftpd/handler.go | 152 +-- ftpd/internal_test.go | 32 +- ftpd/server.go | 25 +- ftpd/transfer.go | 2 +- httpd/api_admin.go | 6 +- httpd/api_folder.go | 19 +- httpd/api_maintenance.go | 10 +- httpd/api_quota.go | 27 +- httpd/api_user.go | 51 +- httpd/api_utils.go | 6 +- httpd/auth_utils.go | 4 +- httpd/httpd.go | 4 +- httpd/httpd_test.go | 594 +++++++-- httpd/internal_test.go | 56 +- httpd/middleware.go | 2 +- httpd/schema/openapi.yaml | 45 +- httpd/web.go | 85 +- httpdtest/httpdtest.go | 221 ++-- service/service_portable.go | 15 +- sftpd/cryptfs_test.go | 2 +- sftpd/handler.go | 170 ++- sftpd/internal_test.go | 401 +++--- sftpd/scp.go | 226 ++-- sftpd/server.go | 67 +- sftpd/sftpd_test.go | 801 +++++++---- sftpd/ssh_cmd.go | 187 +-- sftpd/subsystem.go | 8 +- sftpd/transfer.go | 4 +- templates/folder.html | 17 +- templates/fsconfig.html | 359 +++++ templates/user.html | 356 +---- utils/utils.go | 20 +- vfs/azblobfs.go | 33 +- vfs/azblobfs_disabled.go | 2 +- vfs/cryptfs.go | 10 +- vfs/filesystem.go | 104 ++ vfs/folder.go | 106 ++ vfs/gcsfs.go | 32 +- vfs/gcsfs_disabled.go | 2 +- vfs/osfs.go | 173 +-- vfs/s3fs.go | 29 +- vfs/s3fs_disabled.go | 2 +- vfs/sftpfs.go | 31 +- vfs/vfs.go | 22 +- webdavd/file.go | 4 +- webdavd/handler.go | 128 +- webdavd/internal_test.go | 285 +++- webdavd/server.go | 31 +- webdavd/webdavd_test.go | 116 +- 70 files changed, 6825 insertions(+), 3740 deletions(-) create mode 100644 common/protocol_test.go create mode 100644 templates/fsconfig.html create mode 100644 vfs/filesystem.go diff --git a/cmd/portable.go b/cmd/portable.go index 73443ad1..a20bf8fc 100644 --- a/cmd/portable.go +++ b/cmd/portable.go @@ -84,9 +84,9 @@ $ sftpgo portable Please take a look at the usage below to customize the serving parameters`, Run: func(cmd *cobra.Command, args []string) { portableDir := directoryToServe - fsProvider := dataprovider.FilesystemProvider(portableFsProvider) + fsProvider := vfs.FilesystemProvider(portableFsProvider) if !filepath.IsAbs(portableDir) { - if fsProvider == dataprovider.LocalFilesystemProvider { + if fsProvider == vfs.LocalFilesystemProvider { portableDir, _ = filepath.Abs(portableDir) } else { portableDir = os.TempDir() @@ -95,7 +95,7 @@ Please take a look at the usage below to customize the serving parameters`, permissions := make(map[string][]string) permissions["/"] = portablePermissions portableGCSCredentials := "" - if fsProvider == dataprovider.GCSFilesystemProvider && portableGCSCredentialsFile != "" { + if fsProvider == vfs.GCSFilesystemProvider && portableGCSCredentialsFile != "" { contents, err := getFileContents(portableGCSCredentialsFile) if err != nil { fmt.Printf("Unable to get GCS credentials: %v\n", err) @@ -105,7 +105,7 @@ Please take a look at the usage below to customize the serving parameters`, portableGCSAutoCredentials = 0 } portableSFTPPrivateKey := "" - if fsProvider == dataprovider.SFTPFilesystemProvider && portableSFTPPrivateKeyPath != "" { + if fsProvider == vfs.SFTPFilesystemProvider && portableSFTPPrivateKeyPath != "" { contents, err := getFileContents(portableSFTPPrivateKeyPath) if err != nil { fmt.Printf("Unable to get SFTP private key: %v\n", err) @@ -149,8 +149,8 @@ Please take a look at the usage below to customize the serving parameters`, Permissions: permissions, HomeDir: portableDir, Status: 1, - FsConfig: dataprovider.Filesystem{ - Provider: dataprovider.FilesystemProvider(portableFsProvider), + FsConfig: vfs.Filesystem{ + Provider: vfs.FilesystemProvider(portableFsProvider), S3Config: vfs.S3FsConfig{ Bucket: portableS3Bucket, Region: portableS3Region, @@ -257,7 +257,7 @@ multicast DNS`) advertised via multicast DNS, this flag allows to put username/password inside the advertised TXT record`) - portableCmd.Flags().IntVarP(&portableFsProvider, "fs-provider", "f", int(dataprovider.LocalFilesystemProvider), `0 => local filesystem + portableCmd.Flags().IntVarP(&portableFsProvider, "fs-provider", "f", int(vfs.LocalFilesystemProvider), `0 => local filesystem 1 => AWS S3 compatible 2 => Google Cloud Storage 3 => Azure Blob Storage diff --git a/common/actions.go b/common/actions.go index c9e5bd1e..d1090625 100644 --- a/common/actions.go +++ b/common/actions.go @@ -18,6 +18,7 @@ import ( "github.com/drakkan/sftpgo/httpclient" "github.com/drakkan/sftpgo/logger" "github.com/drakkan/sftpgo/utils" + "github.com/drakkan/sftpgo/vfs" ) var ( @@ -79,12 +80,12 @@ func newActionNotification( var bucket, endpoint string status := 1 - if user.FsConfig.Provider == dataprovider.S3FilesystemProvider { + if user.FsConfig.Provider == vfs.S3FilesystemProvider { bucket = user.FsConfig.S3Config.Bucket endpoint = user.FsConfig.S3Config.Endpoint - } else if user.FsConfig.Provider == dataprovider.GCSFilesystemProvider { + } else if user.FsConfig.Provider == vfs.GCSFilesystemProvider { bucket = user.FsConfig.GCSConfig.Bucket - } else if user.FsConfig.Provider == dataprovider.AzureBlobFilesystemProvider { + } else if user.FsConfig.Provider == vfs.AzureBlobFilesystemProvider { bucket = user.FsConfig.AzBlobConfig.Container if user.FsConfig.AzBlobConfig.SASURL != "" { endpoint = user.FsConfig.AzBlobConfig.SASURL diff --git a/common/actions_test.go b/common/actions_test.go index 231f6dd8..8a46c94f 100644 --- a/common/actions_test.go +++ b/common/actions_test.go @@ -19,7 +19,7 @@ func TestNewActionNotification(t *testing.T) { user := &dataprovider.User{ Username: "username", } - user.FsConfig.Provider = dataprovider.LocalFilesystemProvider + user.FsConfig.Provider = vfs.LocalFilesystemProvider user.FsConfig.S3Config = vfs.S3FsConfig{ Bucket: "s3bucket", Endpoint: "endpoint", @@ -38,19 +38,19 @@ func TestNewActionNotification(t *testing.T) { assert.Equal(t, 0, len(a.Endpoint)) assert.Equal(t, 0, a.Status) - user.FsConfig.Provider = dataprovider.S3FilesystemProvider + user.FsConfig.Provider = vfs.S3FilesystemProvider a = newActionNotification(user, operationDownload, "path", "target", "", ProtocolSSH, 123, nil) assert.Equal(t, "s3bucket", a.Bucket) assert.Equal(t, "endpoint", a.Endpoint) assert.Equal(t, 1, a.Status) - user.FsConfig.Provider = dataprovider.GCSFilesystemProvider + user.FsConfig.Provider = vfs.GCSFilesystemProvider a = newActionNotification(user, operationDownload, "path", "target", "", ProtocolSCP, 123, ErrQuotaExceeded) assert.Equal(t, "gcsbucket", a.Bucket) assert.Equal(t, 0, len(a.Endpoint)) assert.Equal(t, 2, a.Status) - user.FsConfig.Provider = dataprovider.AzureBlobFilesystemProvider + user.FsConfig.Provider = vfs.AzureBlobFilesystemProvider a = newActionNotification(user, operationDownload, "path", "target", "", ProtocolSCP, 123, nil) assert.Equal(t, "azcontainer", a.Bucket) assert.Equal(t, "azsasurl", a.Endpoint) @@ -179,15 +179,15 @@ func TestPreDeleteAction(t *testing.T) { } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} - fs := vfs.NewOsFs("id", homeDir, nil) - c := NewBaseConnection("id", ProtocolSFTP, user, fs) + fs := vfs.NewOsFs("id", homeDir, "") + c := NewBaseConnection("id", ProtocolSFTP, user) testfile := filepath.Join(user.HomeDir, "testfile") err = os.WriteFile(testfile, []byte("test"), os.ModePerm) assert.NoError(t, err) info, err := os.Stat(testfile) assert.NoError(t, err) - err = c.RemoveFile(testfile, "testfile", info) + err = c.RemoveFile(fs, testfile, "testfile", info) assert.NoError(t, err) assert.FileExists(t, testfile) diff --git a/common/common_test.go b/common/common_test.go index f5dfdefb..f0cf4f65 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -3,7 +3,6 @@ package common import ( "fmt" "net" - "net/http" "os" "os/exec" "path/filepath" @@ -13,15 +12,11 @@ import ( "testing" "time" - "github.com/rs/zerolog" - "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/drakkan/sftpgo/dataprovider" - "github.com/drakkan/sftpgo/httpclient" "github.com/drakkan/sftpgo/kms" - "github.com/drakkan/sftpgo/logger" "github.com/drakkan/sftpgo/utils" "github.com/drakkan/sftpgo/vfs" ) @@ -29,29 +24,22 @@ import ( const ( logSenderTest = "common_test" httpAddr = "127.0.0.1:9999" - httpProxyAddr = "127.0.0.1:7777" configDir = ".." osWindows = "windows" userTestUsername = "common_test_username" - userTestPwd = "common_test_pwd" ) -type providerConf struct { - Config dataprovider.Config `json:"data_provider" mapstructure:"data_provider"` -} - type fakeConnection struct { *BaseConnection command string } func (c *fakeConnection) AddUser(user dataprovider.User) error { - fs, err := user.GetFilesystem(c.GetID()) + _, err := user.GetFilesystem(c.GetID()) if err != nil { return err } c.BaseConnection.User = user - c.BaseConnection.Fs = fs return nil } @@ -84,110 +72,6 @@ func (c *customNetConn) Close() error { return c.Conn.Close() } -func TestMain(m *testing.M) { - logfilePath := "common_test.log" - logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel) - - viper.SetEnvPrefix("sftpgo") - replacer := strings.NewReplacer(".", "__") - viper.SetEnvKeyReplacer(replacer) - viper.SetConfigName("sftpgo") - viper.AutomaticEnv() - viper.AllowEmptyEnv(true) - - driver, err := initializeDataprovider(-1) - if err != nil { - logger.WarnToConsole("error initializing data provider: %v", err) - os.Exit(1) - } - logger.InfoToConsole("Starting COMMON tests, provider: %v", driver) - err = Initialize(Configuration{}) - if err != nil { - logger.WarnToConsole("error initializing common: %v", err) - os.Exit(1) - } - httpConfig := httpclient.Config{ - Timeout: 5, - } - httpConfig.Initialize(configDir) //nolint:errcheck - - go func() { - // start a test HTTP server to receive action notifications - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "OK\n") - }) - http.HandleFunc("/404", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - fmt.Fprintf(w, "Not found\n") - }) - if err := http.ListenAndServe(httpAddr, nil); err != nil { - logger.ErrorToConsole("could not start HTTP notification server: %v", err) - os.Exit(1) - } - }() - - go func() { - Config.ProxyProtocol = 2 - listener, err := net.Listen("tcp", httpProxyAddr) - if err != nil { - logger.ErrorToConsole("error creating listener for proxy protocol server: %v", err) - os.Exit(1) - } - proxyListener, err := Config.GetProxyListener(listener) - if err != nil { - logger.ErrorToConsole("error creating proxy protocol listener: %v", err) - os.Exit(1) - } - Config.ProxyProtocol = 0 - - s := &http.Server{} - if err := s.Serve(proxyListener); err != nil { - logger.ErrorToConsole("could not start HTTP proxy protocol server: %v", err) - os.Exit(1) - } - }() - - waitTCPListening(httpAddr) - waitTCPListening(httpProxyAddr) - exitCode := m.Run() - os.Remove(logfilePath) //nolint:errcheck - os.Exit(exitCode) -} - -func waitTCPListening(address string) { - for { - conn, err := net.Dial("tcp", address) - if err != nil { - logger.WarnToConsole("tcp server %v not listening: %v", address, err) - time.Sleep(100 * time.Millisecond) - continue - } - logger.InfoToConsole("tcp server %v now listening", address) - conn.Close() - break - } -} - -func initializeDataprovider(trackQuota int) (string, error) { - configDir := ".." - viper.AddConfigPath(configDir) - if err := viper.ReadInConfig(); err != nil { - return "", err - } - var cfg providerConf - if err := viper.Unmarshal(&cfg); err != nil { - return "", err - } - if trackQuota >= 0 && trackQuota <= 2 { - cfg.Config.TrackQuota = trackQuota - } - return cfg.Config.Driver, dataprovider.Initialize(cfg.Config, configDir, true) -} - -func closeDataprovider() error { - return dataprovider.Close() -} - func TestSSHConnections(t *testing.T) { conn1, conn2 := net.Pipe() now := time.Now() @@ -286,7 +170,7 @@ func TestMaxConnections(t *testing.T) { Config.MaxTotalConnections = 1 assert.True(t, Connections.IsNewConnectionAllowed()) - c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}, nil) + c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}) fakeConn := &fakeConnection{ BaseConnection: c, } @@ -324,7 +208,7 @@ func TestIdleConnections(t *testing.T) { user := dataprovider.User{ Username: username, } - c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, user, nil) + c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, user) c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano() fakeConn := &fakeConnection{ BaseConnection: c, @@ -336,7 +220,7 @@ func TestIdleConnections(t *testing.T) { Connections.AddSSHConnection(sshConn1) Connections.Add(fakeConn) assert.Equal(t, Connections.GetActiveSessions(username), 1) - c = NewBaseConnection(sshConn2.id+"_1", ProtocolSSH, user, nil) + c = NewBaseConnection(sshConn2.id+"_1", ProtocolSSH, user) fakeConn = &fakeConnection{ BaseConnection: c, } @@ -344,7 +228,7 @@ func TestIdleConnections(t *testing.T) { Connections.Add(fakeConn) assert.Equal(t, Connections.GetActiveSessions(username), 2) - cFTP := NewBaseConnection("id2", ProtocolFTP, dataprovider.User{}, nil) + cFTP := NewBaseConnection("id2", ProtocolFTP, dataprovider.User{}) cFTP.lastActivity = time.Now().UnixNano() fakeConn = &fakeConnection{ BaseConnection: cFTP, @@ -383,7 +267,7 @@ func TestIdleConnections(t *testing.T) { } func TestCloseConnection(t *testing.T) { - c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}, nil) + c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}) fakeConn := &fakeConnection{ BaseConnection: c, } @@ -399,7 +283,7 @@ func TestCloseConnection(t *testing.T) { } func TestSwapConnection(t *testing.T) { - c := NewBaseConnection("id", ProtocolFTP, dataprovider.User{}, nil) + c := NewBaseConnection("id", ProtocolFTP, dataprovider.User{}) fakeConn := &fakeConnection{ BaseConnection: c, } @@ -409,7 +293,7 @@ func TestSwapConnection(t *testing.T) { } c = NewBaseConnection("id", ProtocolFTP, dataprovider.User{ Username: userTestUsername, - }, nil) + }) fakeConn = &fakeConnection{ BaseConnection: c, } @@ -443,8 +327,8 @@ func TestConnectionStatus(t *testing.T) { user := dataprovider.User{ Username: username, } - fs := vfs.NewOsFs("", os.TempDir(), nil) - c1 := NewBaseConnection("id1", ProtocolSFTP, user, fs) + fs := vfs.NewOsFs("", os.TempDir(), "") + c1 := NewBaseConnection("id1", ProtocolSFTP, user) fakeConn1 := &fakeConnection{ BaseConnection: c1, } @@ -452,12 +336,12 @@ func TestConnectionStatus(t *testing.T) { t1.BytesReceived = 123 t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs) t2.BytesSent = 456 - c2 := NewBaseConnection("id2", ProtocolSSH, user, nil) + c2 := NewBaseConnection("id2", ProtocolSSH, user) fakeConn2 := &fakeConnection{ BaseConnection: c2, command: "md5sum", } - c3 := NewBaseConnection("id3", ProtocolWebDAV, user, nil) + c3 := NewBaseConnection("id3", ProtocolWebDAV, user) fakeConn3 := &fakeConnection{ BaseConnection: c3, command: "PROPFIND", @@ -565,15 +449,6 @@ func TestProxyProtocolVersion(t *testing.T) { assert.Error(t, err) } -func TestProxyProtocol(t *testing.T) { - httpClient := httpclient.GetHTTPClient() - resp, err := httpClient.Get(fmt.Sprintf("http://%v", httpProxyAddr)) - if assert.NoError(t, err) { - defer resp.Body.Close() - assert.Equal(t, http.StatusBadRequest, resp.StatusCode) - } -} - func TestPostConnectHook(t *testing.T) { Config.PostConnectHook = "" @@ -614,7 +489,7 @@ func TestPostConnectHook(t *testing.T) { func TestCryptoConvertFileInfo(t *testing.T) { name := "name" - fs, err := vfs.NewCryptFs("connID1", os.TempDir(), vfs.CryptFsConfig{Passphrase: kms.NewPlainSecret("secret")}) + fs, err := vfs.NewCryptFs("connID1", os.TempDir(), "", vfs.CryptFsConfig{Passphrase: kms.NewPlainSecret("secret")}) require.NoError(t, err) cryptFs := fs.(*vfs.CryptFs) info := vfs.NewFileInfo(name, true, 48, time.Now(), false) @@ -649,4 +524,50 @@ func TestFolderCopy(t *testing.T) { require.Equal(t, folder.UsedQuotaSize, folderCopy.UsedQuotaSize) require.Equal(t, folder.UsedQuotaFiles, folderCopy.UsedQuotaFiles) require.Equal(t, folder.LastQuotaUpdate, folderCopy.LastQuotaUpdate) + + folder.FsConfig = vfs.Filesystem{ + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret("crypto secret"), + }, + } + folderCopy = folder.GetACopy() + folder.FsConfig.CryptConfig.Passphrase = kms.NewEmptySecret() + require.Len(t, folderCopy.Users, 1) + require.True(t, utils.IsStringInSlice("user3", folderCopy.Users)) + require.Equal(t, int64(2), folderCopy.ID) + require.Equal(t, folder.Name, folderCopy.Name) + require.Equal(t, folder.MappedPath, folderCopy.MappedPath) + require.Equal(t, folder.UsedQuotaSize, folderCopy.UsedQuotaSize) + require.Equal(t, folder.UsedQuotaFiles, folderCopy.UsedQuotaFiles) + require.Equal(t, folder.LastQuotaUpdate, folderCopy.LastQuotaUpdate) + require.Equal(t, "crypto secret", folderCopy.FsConfig.CryptConfig.Passphrase.GetPayload()) +} + +func TestCachedFs(t *testing.T) { + user := dataprovider.User{ + HomeDir: filepath.Clean(os.TempDir()), + } + conn := NewBaseConnection("id", ProtocolSFTP, user) + // changing the user should not affect the connection + user.HomeDir = filepath.Join(os.TempDir(), "temp") + err := os.Mkdir(user.HomeDir, os.ModePerm) + assert.NoError(t, err) + fs, err := user.GetFilesystem("") + assert.NoError(t, err) + p, err := fs.ResolvePath("/") + assert.NoError(t, err) + assert.Equal(t, user.GetHomeDir(), p) + + _, p, err = conn.GetFsAndResolvedPath("/") + assert.NoError(t, err) + assert.Equal(t, filepath.Clean(os.TempDir()), p) + user.FsConfig.Provider = vfs.S3FilesystemProvider + _, err = user.GetFilesystem("") + assert.Error(t, err) + conn.User.FsConfig.Provider = vfs.S3FilesystemProvider + _, p, err = conn.GetFsAndResolvedPath("/") + assert.NoError(t, err) + assert.Equal(t, filepath.Clean(os.TempDir()), p) + err = os.Remove(user.HomeDir) + assert.NoError(t, err) } diff --git a/common/connection.go b/common/connection.go index 87063589..add73d4a 100644 --- a/common/connection.go +++ b/common/connection.go @@ -30,14 +30,13 @@ type BaseConnection struct { // start time for this connection startTime time.Time protocol string - Fs vfs.Fs sync.RWMutex transferID uint64 activeTransfers []ActiveTransfer } // NewBaseConnection returns a new BaseConnection -func NewBaseConnection(id, protocol string, user dataprovider.User, fs vfs.Fs) *BaseConnection { +func NewBaseConnection(id, protocol string, user dataprovider.User) *BaseConnection { connID := id if utils.IsStringInSlice(protocol, supportedProtocols) { connID = fmt.Sprintf("%v_%v", protocol, id) @@ -47,7 +46,6 @@ func NewBaseConnection(id, protocol string, user dataprovider.User, fs vfs.Fs) * User: user, startTime: time.Now(), protocol: protocol, - Fs: fs, lastActivity: time.Now().UnixNano(), transferID: 0, } @@ -103,10 +101,7 @@ func (c *BaseConnection) GetLastActivity() time.Time { // CloseFS closes the underlying fs func (c *BaseConnection) CloseFS() error { - if c.Fs != nil { - return c.Fs.Close() - } - return nil + return c.User.CloseFs() } // AddTransfer associates a new transfer to this connection @@ -207,21 +202,25 @@ func (c *BaseConnection) truncateOpenHandle(fsPath string, size int64) (int64, e return 0, errNoTransfer } -// ListDir reads the directory named by fsPath and returns a list of directory entries -func (c *BaseConnection) ListDir(fsPath, virtualPath string) ([]os.FileInfo, error) { +// ListDir reads the directory matching virtualPath and returns a list of directory entries +func (c *BaseConnection) ListDir(virtualPath string) ([]os.FileInfo, error) { if !c.User.HasPerm(dataprovider.PermListItems, virtualPath) { return nil, c.GetPermissionDeniedError() } - files, err := c.Fs.ReadDir(fsPath) + fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath) + if err != nil { + return nil, err + } + files, err := fs.ReadDir(fsPath) if err != nil { c.Log(logger.LevelWarn, "error listing directory: %+v", err) - return nil, c.GetFsError(err) + return nil, c.GetFsError(fs, err) } return c.User.AddVirtualDirs(files, virtualPath), nil } // CreateDir creates a new directory at the specified fsPath -func (c *BaseConnection) CreateDir(fsPath, virtualPath string) error { +func (c *BaseConnection) CreateDir(virtualPath string) error { if !c.User.HasPerm(dataprovider.PermCreateDirs, path.Dir(virtualPath)) { return c.GetPermissionDeniedError() } @@ -229,42 +228,47 @@ func (c *BaseConnection) CreateDir(fsPath, virtualPath string) error { c.Log(logger.LevelWarn, "mkdir not allowed %#v is a virtual folder", virtualPath) return c.GetPermissionDeniedError() } - if err := c.Fs.Mkdir(fsPath); err != nil { - c.Log(logger.LevelWarn, "error creating dir: %#v error: %+v", fsPath, err) - return c.GetFsError(err) + fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath) + if err != nil { + return err } - vfs.SetPathPermissions(c.Fs, fsPath, c.User.GetUID(), c.User.GetGID()) + if err := fs.Mkdir(fsPath); err != nil { + c.Log(logger.LevelWarn, "error creating dir: %#v error: %+v", fsPath, err) + return c.GetFsError(fs, err) + } + vfs.SetPathPermissions(fs, fsPath, c.User.GetUID(), c.User.GetGID()) logger.CommandLog(mkdirLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1) return nil } // IsRemoveFileAllowed returns an error if removing this file is not allowed -func (c *BaseConnection) IsRemoveFileAllowed(fsPath, virtualPath string) error { +func (c *BaseConnection) IsRemoveFileAllowed(virtualPath string) error { if !c.User.HasPerm(dataprovider.PermDelete, path.Dir(virtualPath)) { return c.GetPermissionDeniedError() } if !c.User.IsFileAllowed(virtualPath) { - c.Log(logger.LevelDebug, "removing file %#v is not allowed", fsPath) + c.Log(logger.LevelDebug, "removing file %#v is not allowed", virtualPath) return c.GetPermissionDeniedError() } return nil } // RemoveFile removes a file at the specified fsPath -func (c *BaseConnection) RemoveFile(fsPath, virtualPath string, info os.FileInfo) error { - if err := c.IsRemoveFileAllowed(fsPath, virtualPath); err != nil { +func (c *BaseConnection) RemoveFile(fs vfs.Fs, fsPath, virtualPath string, info os.FileInfo) error { + if err := c.IsRemoveFileAllowed(virtualPath); err != nil { return err } + size := info.Size() action := newActionNotification(&c.User, operationPreDelete, fsPath, "", "", c.protocol, size, nil) actionErr := actionHandler.Handle(action) if actionErr == nil { c.Log(logger.LevelDebug, "remove for path %#v handled by pre-delete action", fsPath) } else { - if err := c.Fs.Remove(fsPath, false); err != nil { + if err := fs.Remove(fsPath, false); err != nil { c.Log(logger.LevelWarn, "failed to remove a file/symlink %#v: %+v", fsPath, err) - return c.GetFsError(err) + return c.GetFsError(fs, err) } } @@ -288,8 +292,8 @@ func (c *BaseConnection) RemoveFile(fsPath, virtualPath string, info os.FileInfo } // IsRemoveDirAllowed returns an error if removing this directory is not allowed -func (c *BaseConnection) IsRemoveDirAllowed(fsPath, virtualPath string) error { - if c.Fs.GetRelativePath(fsPath) == "/" { +func (c *BaseConnection) IsRemoveDirAllowed(fs vfs.Fs, fsPath, virtualPath string) error { + if fs.GetRelativePath(fsPath) == "/" { c.Log(logger.LevelWarn, "removing root dir is not allowed") return c.GetPermissionDeniedError() } @@ -312,54 +316,57 @@ func (c *BaseConnection) IsRemoveDirAllowed(fsPath, virtualPath string) error { } // RemoveDir removes a directory at the specified fsPath -func (c *BaseConnection) RemoveDir(fsPath, virtualPath string) error { - if err := c.IsRemoveDirAllowed(fsPath, virtualPath); err != nil { +func (c *BaseConnection) RemoveDir(virtualPath string) error { + fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath) + if err != nil { + return err + } + if err := c.IsRemoveDirAllowed(fs, fsPath, virtualPath); err != nil { return err } var fi os.FileInfo - var err error - if fi, err = c.Fs.Lstat(fsPath); err != nil { + if fi, err = fs.Lstat(fsPath); err != nil { // see #149 - if c.Fs.IsNotExist(err) && c.Fs.HasVirtualFolders() { + if fs.IsNotExist(err) && fs.HasVirtualFolders() { return nil } c.Log(logger.LevelWarn, "failed to remove a dir %#v: stat error: %+v", fsPath, err) - return c.GetFsError(err) + return c.GetFsError(fs, err) } if !fi.IsDir() || fi.Mode()&os.ModeSymlink != 0 { c.Log(logger.LevelDebug, "cannot remove %#v is not a directory", fsPath) return c.GetGenericError(nil) } - if err := c.Fs.Remove(fsPath, true); err != nil { + if err := fs.Remove(fsPath, true); err != nil { c.Log(logger.LevelWarn, "failed to remove directory %#v: %+v", fsPath, err) - return c.GetFsError(err) + return c.GetFsError(fs, err) } logger.CommandLog(rmdirLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1) return nil } -// Rename renames (moves) fsSourcePath to fsTargetPath -func (c *BaseConnection) Rename(fsSourcePath, fsTargetPath, virtualSourcePath, virtualTargetPath string) error { - if c.User.IsMappedPath(fsSourcePath) { - c.Log(logger.LevelWarn, "renaming a directory mapped as virtual folder is not allowed: %#v", fsSourcePath) - return c.GetPermissionDeniedError() - } - if c.User.IsMappedPath(fsTargetPath) { - c.Log(logger.LevelWarn, "renaming to a directory mapped as virtual folder is not allowed: %#v", fsTargetPath) - return c.GetPermissionDeniedError() - } - srcInfo, err := c.Fs.Lstat(fsSourcePath) +// Rename renames (moves) virtualSourcePath to virtualTargetPath +func (c *BaseConnection) Rename(virtualSourcePath, virtualTargetPath string) error { + fsSrc, fsSourcePath, err := c.GetFsAndResolvedPath(virtualSourcePath) if err != nil { - return c.GetFsError(err) + return err } - if !c.isRenamePermitted(fsSourcePath, virtualSourcePath, virtualTargetPath, srcInfo) { + fsDst, fsTargetPath, err := c.GetFsAndResolvedPath(virtualTargetPath) + if err != nil { + return err + } + srcInfo, err := fsSrc.Lstat(fsSourcePath) + if err != nil { + return c.GetFsError(fsSrc, err) + } + if !c.isRenamePermitted(fsSrc, fsDst, fsSourcePath, fsTargetPath, virtualSourcePath, virtualTargetPath, srcInfo) { return c.GetPermissionDeniedError() } initialSize := int64(-1) - if dstInfo, err := c.Fs.Lstat(fsTargetPath); err == nil { + if dstInfo, err := fsDst.Lstat(fsTargetPath); err == nil { if dstInfo.IsDir() { c.Log(logger.LevelWarn, "attempted to rename %#v overwriting an existing directory %#v", fsSourcePath, fsTargetPath) @@ -370,8 +377,8 @@ func (c *BaseConnection) Rename(fsSourcePath, fsTargetPath, virtualSourcePath, v initialSize = dstInfo.Size() } if !c.User.HasPerm(dataprovider.PermOverwrite, path.Dir(virtualTargetPath)) { - c.Log(logger.LevelDebug, "renaming is not allowed, %#v -> %#v. Target exists but the user "+ - "has no overwrite permission", virtualSourcePath, virtualTargetPath) + c.Log(logger.LevelDebug, "renaming %#v -> %#v is not allowed. Target exists but the user %#v"+ + "has no overwrite permission", virtualSourcePath, virtualTargetPath, c.User.Username) return c.GetPermissionDeniedError() } } @@ -381,22 +388,21 @@ func (c *BaseConnection) Rename(fsSourcePath, fsTargetPath, virtualSourcePath, v virtualSourcePath) return c.GetOpUnsupportedError() } - if err = c.checkRecursiveRenameDirPermissions(fsSourcePath, fsTargetPath); err != nil { + if err = c.checkRecursiveRenameDirPermissions(fsSrc, fsDst, fsSourcePath, fsTargetPath); err != nil { c.Log(logger.LevelDebug, "error checking recursive permissions before renaming %#v: %+v", fsSourcePath, err) - return c.GetFsError(err) + return err } } - if !c.hasSpaceForRename(virtualSourcePath, virtualTargetPath, initialSize, fsSourcePath) { + if !c.hasSpaceForRename(fsSrc, virtualSourcePath, virtualTargetPath, initialSize, fsSourcePath) { c.Log(logger.LevelInfo, "denying cross rename due to space limit") return c.GetGenericError(ErrQuotaExceeded) } - if err := c.Fs.Rename(fsSourcePath, fsTargetPath); err != nil { + if err := fsSrc.Rename(fsSourcePath, fsTargetPath); err != nil { c.Log(logger.LevelWarn, "failed to rename %#v -> %#v: %+v", fsSourcePath, fsTargetPath, err) - return c.GetFsError(err) - } - if dataprovider.GetQuotaTracking() > 0 { - c.updateQuotaAfterRename(virtualSourcePath, virtualTargetPath, fsTargetPath, initialSize) //nolint:errcheck + return c.GetFsError(fsSrc, err) } + vfs.SetPathPermissions(fsDst, fsTargetPath, c.User.GetUID(), c.User.GetGID()) + c.updateQuotaAfterRename(fsDst, virtualSourcePath, virtualTargetPath, fsTargetPath, initialSize) //nolint:errcheck logger.CommandLog(renameLogSender, fsSourcePath, fsTargetPath, c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1) action := newActionNotification(&c.User, operationRename, fsSourcePath, fsTargetPath, "", c.protocol, 0, nil) @@ -407,41 +413,42 @@ func (c *BaseConnection) Rename(fsSourcePath, fsTargetPath, virtualSourcePath, v } // CreateSymlink creates fsTargetPath as a symbolic link to fsSourcePath -func (c *BaseConnection) CreateSymlink(fsSourcePath, fsTargetPath, virtualSourcePath, virtualTargetPath string) error { - if c.Fs.GetRelativePath(fsSourcePath) == "/" { +func (c *BaseConnection) CreateSymlink(virtualSourcePath, virtualTargetPath string) error { + if c.isCrossFoldersRequest(virtualSourcePath, virtualTargetPath) { + c.Log(logger.LevelWarn, "cross folder symlink is not supported, src: %v dst: %v", virtualSourcePath, virtualTargetPath) + return c.GetOpUnsupportedError() + } + // we cannot have a cross folder request here so only one fs is enough + fs, fsSourcePath, err := c.GetFsAndResolvedPath(virtualSourcePath) + if err != nil { + return err + } + fsTargetPath, err := fs.ResolvePath(virtualTargetPath) + if err != nil { + return c.GetFsError(fs, err) + } + if fs.GetRelativePath(fsSourcePath) == "/" { c.Log(logger.LevelWarn, "symlinking root dir is not allowed") return c.GetPermissionDeniedError() } - if c.User.IsVirtualFolder(virtualTargetPath) { - c.Log(logger.LevelWarn, "symlinking a virtual folder is not allowed") + if fs.GetRelativePath(fsTargetPath) == "/" { + c.Log(logger.LevelWarn, "symlinking to root dir is not allowed") return c.GetPermissionDeniedError() } if !c.User.HasPerm(dataprovider.PermCreateSymlinks, path.Dir(virtualTargetPath)) { return c.GetPermissionDeniedError() } - if c.isCrossFoldersRequest(virtualSourcePath, virtualTargetPath) { - c.Log(logger.LevelWarn, "cross folder symlink is not supported, src: %v dst: %v", virtualSourcePath, virtualTargetPath) - return c.GetOpUnsupportedError() - } - if c.User.IsMappedPath(fsSourcePath) { - c.Log(logger.LevelWarn, "symlinking a directory mapped as virtual folder is not allowed: %#v", fsSourcePath) - return c.GetPermissionDeniedError() - } - if c.User.IsMappedPath(fsTargetPath) { - c.Log(logger.LevelWarn, "symlinking to a directory mapped as virtual folder is not allowed: %#v", fsTargetPath) - return c.GetPermissionDeniedError() - } - if err := c.Fs.Symlink(fsSourcePath, fsTargetPath); err != nil { + if err := fs.Symlink(fsSourcePath, fsTargetPath); err != nil { c.Log(logger.LevelWarn, "failed to create symlink %#v -> %#v: %+v", fsSourcePath, fsTargetPath, err) - return c.GetFsError(err) + return c.GetFsError(fs, err) } logger.CommandLog(symlinkLogSender, fsSourcePath, fsTargetPath, c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1) return nil } -func (c *BaseConnection) getPathForSetStatPerms(fsPath, virtualPath string) string { +func (c *BaseConnection) getPathForSetStatPerms(fs vfs.Fs, fsPath, virtualPath string) string { pathForPerms := virtualPath - if fi, err := c.Fs.Lstat(fsPath); err == nil { + if fi, err := fs.Lstat(fsPath); err == nil { if fi.IsDir() { pathForPerms = path.Dir(virtualPath) } @@ -450,74 +457,89 @@ func (c *BaseConnection) getPathForSetStatPerms(fsPath, virtualPath string) stri } // DoStat execute a Stat if mode = 0, Lstat if mode = 1 -func (c *BaseConnection) DoStat(fsPath string, mode int) (os.FileInfo, error) { +func (c *BaseConnection) DoStat(virtualPath string, mode int) (os.FileInfo, error) { + // for some vfs we don't create intermediary folders so we cannot simply check + // if virtualPath is a virtual folder + vfolders := c.User.GetVirtualFoldersInPath(path.Dir(virtualPath)) + if _, ok := vfolders[virtualPath]; ok { + return vfs.NewFileInfo(virtualPath, true, 0, time.Now(), false), nil + } + var info os.FileInfo - var err error + + fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath) + if err != nil { + return info, err + } + if mode == 1 { - info, err = c.Fs.Lstat(c.getRealFsPath(fsPath)) + info, err = fs.Lstat(c.getRealFsPath(fsPath)) } else { - info, err = c.Fs.Stat(c.getRealFsPath(fsPath)) + info, err = fs.Stat(c.getRealFsPath(fsPath)) } - if err == nil && vfs.IsCryptOsFs(c.Fs) { - info = c.Fs.(*vfs.CryptFs).ConvertFileInfo(info) + if err != nil { + return info, c.GetFsError(fs, err) } - return info, err + if vfs.IsCryptOsFs(fs) { + info = fs.(*vfs.CryptFs).ConvertFileInfo(info) + } + return info, nil } -func (c *BaseConnection) ignoreSetStat() bool { +func (c *BaseConnection) ignoreSetStat(fs vfs.Fs) bool { if Config.SetstatMode == 1 { return true } - if Config.SetstatMode == 2 && !vfs.IsLocalOrSFTPFs(c.Fs) { + if Config.SetstatMode == 2 && !vfs.IsLocalOrSFTPFs(fs) && !vfs.IsCryptOsFs(fs) { return true } return false } -func (c *BaseConnection) handleChmod(fsPath, pathForPerms string, attributes *StatAttributes) error { +func (c *BaseConnection) handleChmod(fs vfs.Fs, fsPath, pathForPerms string, attributes *StatAttributes) error { if !c.User.HasPerm(dataprovider.PermChmod, pathForPerms) { return c.GetPermissionDeniedError() } - if c.ignoreSetStat() { + if c.ignoreSetStat(fs) { return nil } - if err := c.Fs.Chmod(c.getRealFsPath(fsPath), attributes.Mode); err != nil { + if err := fs.Chmod(c.getRealFsPath(fsPath), attributes.Mode); err != nil { c.Log(logger.LevelWarn, "failed to chmod path %#v, mode: %v, err: %+v", fsPath, attributes.Mode.String(), err) - return c.GetFsError(err) + return c.GetFsError(fs, err) } logger.CommandLog(chmodLogSender, fsPath, "", c.User.Username, attributes.Mode.String(), c.ID, c.protocol, -1, -1, "", "", "", -1) return nil } -func (c *BaseConnection) handleChown(fsPath, pathForPerms string, attributes *StatAttributes) error { +func (c *BaseConnection) handleChown(fs vfs.Fs, fsPath, pathForPerms string, attributes *StatAttributes) error { if !c.User.HasPerm(dataprovider.PermChown, pathForPerms) { return c.GetPermissionDeniedError() } - if c.ignoreSetStat() { + if c.ignoreSetStat(fs) { return nil } - if err := c.Fs.Chown(c.getRealFsPath(fsPath), attributes.UID, attributes.GID); err != nil { + if err := fs.Chown(c.getRealFsPath(fsPath), attributes.UID, attributes.GID); err != nil { c.Log(logger.LevelWarn, "failed to chown path %#v, uid: %v, gid: %v, err: %+v", fsPath, attributes.UID, attributes.GID, err) - return c.GetFsError(err) + return c.GetFsError(fs, err) } logger.CommandLog(chownLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, attributes.UID, attributes.GID, "", "", "", -1) return nil } -func (c *BaseConnection) handleChtimes(fsPath, pathForPerms string, attributes *StatAttributes) error { +func (c *BaseConnection) handleChtimes(fs vfs.Fs, fsPath, pathForPerms string, attributes *StatAttributes) error { if !c.User.HasPerm(dataprovider.PermChtimes, pathForPerms) { return c.GetPermissionDeniedError() } - if c.ignoreSetStat() { + if c.ignoreSetStat(fs) { return nil } - if err := c.Fs.Chtimes(c.getRealFsPath(fsPath), attributes.Atime, attributes.Mtime); err != nil { + if err := fs.Chtimes(c.getRealFsPath(fsPath), attributes.Atime, attributes.Mtime); err != nil { c.Log(logger.LevelWarn, "failed to chtimes for path %#v, access time: %v, modification time: %v, err: %+v", fsPath, attributes.Atime, attributes.Mtime, err) - return c.GetFsError(err) + return c.GetFsError(fs, err) } accessTimeString := attributes.Atime.Format(chtimesFormat) modificationTimeString := attributes.Mtime.Format(chtimesFormat) @@ -527,19 +549,23 @@ func (c *BaseConnection) handleChtimes(fsPath, pathForPerms string, attributes * } // SetStat set StatAttributes for the specified fsPath -func (c *BaseConnection) SetStat(fsPath, virtualPath string, attributes *StatAttributes) error { - pathForPerms := c.getPathForSetStatPerms(fsPath, virtualPath) +func (c *BaseConnection) SetStat(virtualPath string, attributes *StatAttributes) error { + fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath) + if err != nil { + return err + } + pathForPerms := c.getPathForSetStatPerms(fs, fsPath, virtualPath) if attributes.Flags&StatAttrPerms != 0 { - return c.handleChmod(fsPath, pathForPerms, attributes) + return c.handleChmod(fs, fsPath, pathForPerms, attributes) } if attributes.Flags&StatAttrUIDGID != 0 { - return c.handleChown(fsPath, pathForPerms, attributes) + return c.handleChown(fs, fsPath, pathForPerms, attributes) } if attributes.Flags&StatAttrTimes != 0 { - return c.handleChtimes(fsPath, pathForPerms, attributes) + return c.handleChtimes(fs, fsPath, pathForPerms, attributes) } if attributes.Flags&StatAttrSize != 0 { @@ -547,9 +573,9 @@ func (c *BaseConnection) SetStat(fsPath, virtualPath string, attributes *StatAtt return c.GetPermissionDeniedError() } - if err := c.truncateFile(fsPath, virtualPath, attributes.Size); err != nil { + if err := c.truncateFile(fs, fsPath, virtualPath, attributes.Size); err != nil { c.Log(logger.LevelWarn, "failed to truncate path %#v, size: %v, err: %+v", fsPath, attributes.Size, err) - return c.GetFsError(err) + return c.GetFsError(fs, err) } logger.CommandLog(truncateLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", attributes.Size) } @@ -557,7 +583,7 @@ func (c *BaseConnection) SetStat(fsPath, virtualPath string, attributes *StatAtt return nil } -func (c *BaseConnection) truncateFile(fsPath, virtualPath string, size int64) error { +func (c *BaseConnection) truncateFile(fs vfs.Fs, fsPath, virtualPath string, size int64) error { // check first if we have an open transfer for the given path and try to truncate the file already opened // if we found no transfer we truncate by path. var initialSize int64 @@ -566,14 +592,14 @@ func (c *BaseConnection) truncateFile(fsPath, virtualPath string, size int64) er if err == errNoTransfer { c.Log(logger.LevelDebug, "file path %#v not found in active transfers, execute trucate by path", fsPath) var info os.FileInfo - info, err = c.Fs.Stat(fsPath) + info, err = fs.Stat(fsPath) if err != nil { return err } initialSize = info.Size() - err = c.Fs.Truncate(fsPath, size) + err = fs.Truncate(fsPath, size) } - if err == nil && vfs.IsLocalOrSFTPFs(c.Fs) { + if err == nil && vfs.IsLocalOrSFTPFs(fs) { sizeDiff := initialSize - size vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath)) if err == nil { @@ -588,20 +614,20 @@ func (c *BaseConnection) truncateFile(fsPath, virtualPath string, size int64) er return err } -func (c *BaseConnection) checkRecursiveRenameDirPermissions(sourcePath, targetPath string) error { +func (c *BaseConnection) checkRecursiveRenameDirPermissions(fsSrc, fsDst vfs.Fs, sourcePath, targetPath string) error { dstPerms := []string{ dataprovider.PermCreateDirs, dataprovider.PermUpload, dataprovider.PermCreateSymlinks, } - err := c.Fs.Walk(sourcePath, func(walkedPath string, info os.FileInfo, err error) error { + err := fsSrc.Walk(sourcePath, func(walkedPath string, info os.FileInfo, err error) error { if err != nil { - return err + return c.GetFsError(fsSrc, err) } dstPath := strings.Replace(walkedPath, sourcePath, targetPath, 1) - virtualSrcPath := c.Fs.GetRelativePath(walkedPath) - virtualDstPath := c.Fs.GetRelativePath(dstPath) + virtualSrcPath := fsSrc.GetRelativePath(walkedPath) + virtualDstPath := fsDst.GetRelativePath(dstPath) // walk scans the directory tree in order, checking the parent directory permissions we are sure that all contents // inside the parent path was checked. If the current dir has no subdirs with defined permissions inside it // and it has all the possible permissions we can stop scanning @@ -615,10 +641,10 @@ func (c *BaseConnection) checkRecursiveRenameDirPermissions(sourcePath, targetPa return ErrSkipPermissionsCheck } } - if !c.isRenamePermitted(walkedPath, virtualSrcPath, virtualDstPath, info) { + if !c.isRenamePermitted(fsSrc, fsDst, walkedPath, dstPath, virtualSrcPath, virtualDstPath, info) { c.Log(logger.LevelInfo, "rename %#v -> %#v is not allowed, virtual destination path: %#v", walkedPath, dstPath, virtualDstPath) - return os.ErrPermission + return c.GetPermissionDeniedError() } return nil }) @@ -628,22 +654,7 @@ func (c *BaseConnection) checkRecursiveRenameDirPermissions(sourcePath, targetPa return err } -func (c *BaseConnection) isRenamePermitted(fsSourcePath, virtualSourcePath, virtualTargetPath string, fi os.FileInfo) bool { - if c.Fs.GetRelativePath(fsSourcePath) == "/" { - c.Log(logger.LevelWarn, "renaming root dir is not allowed") - return false - } - if c.User.IsVirtualFolder(virtualSourcePath) || c.User.IsVirtualFolder(virtualTargetPath) { - c.Log(logger.LevelWarn, "renaming a virtual folder is not allowed") - return false - } - if !c.User.IsFileAllowed(virtualSourcePath) || !c.User.IsFileAllowed(virtualTargetPath) { - if fi != nil && fi.Mode().IsRegular() { - c.Log(logger.LevelDebug, "renaming file is not allowed, source: %#v target: %#v", - virtualSourcePath, virtualTargetPath) - return false - } - } +func (c *BaseConnection) hasRenamePerms(virtualSourcePath, virtualTargetPath string, fi os.FileInfo) bool { if c.User.HasPerm(dataprovider.PermRename, path.Dir(virtualSourcePath)) && c.User.HasPerm(dataprovider.PermRename, path.Dir(virtualTargetPath)) { return true @@ -661,7 +672,39 @@ func (c *BaseConnection) isRenamePermitted(fsSourcePath, virtualSourcePath, virt return c.User.HasPerm(dataprovider.PermUpload, path.Dir(virtualTargetPath)) } -func (c *BaseConnection) hasSpaceForRename(virtualSourcePath, virtualTargetPath string, initialSize int64, +func (c *BaseConnection) isRenamePermitted(fsSrc, fsDst vfs.Fs, fsSourcePath, fsTargetPath, virtualSourcePath, virtualTargetPath string, fi os.FileInfo) bool { + if !c.isLocalOrSameFolderRename(virtualSourcePath, virtualTargetPath) { + c.Log(logger.LevelInfo, "rename %#v->%#v is not allowed: the paths must be local or on the same virtual folder", + virtualSourcePath, virtualTargetPath) + return false + } + if c.User.IsMappedPath(fsSourcePath) && vfs.IsLocalOrCryptoFs(fsSrc) { + c.Log(logger.LevelWarn, "renaming a directory mapped as virtual folder is not allowed: %#v", fsSourcePath) + return false + } + if c.User.IsMappedPath(fsTargetPath) && vfs.IsLocalOrCryptoFs(fsDst) { + c.Log(logger.LevelWarn, "renaming to a directory mapped as virtual folder is not allowed: %#v", fsTargetPath) + return false + } + if fsSrc.GetRelativePath(fsSourcePath) == "/" { + c.Log(logger.LevelWarn, "renaming root dir is not allowed") + return false + } + if c.User.IsVirtualFolder(virtualSourcePath) || c.User.IsVirtualFolder(virtualTargetPath) { + c.Log(logger.LevelWarn, "renaming a virtual folder is not allowed") + return false + } + if !c.User.IsFileAllowed(virtualSourcePath) || !c.User.IsFileAllowed(virtualTargetPath) { + if fi != nil && fi.Mode().IsRegular() { + c.Log(logger.LevelDebug, "renaming file is not allowed, source: %#v target: %#v", + virtualSourcePath, virtualTargetPath) + return false + } + } + return c.hasRenamePerms(virtualSourcePath, virtualTargetPath, fi) +} + +func (c *BaseConnection) hasSpaceForRename(fs vfs.Fs, virtualSourcePath, virtualTargetPath string, initialSize int64, fsSourcePath string) bool { if dataprovider.GetQuotaTracking() == 0 { return true @@ -674,7 +717,7 @@ func (c *BaseConnection) hasSpaceForRename(virtualSourcePath, virtualTargetPath } if errSrc == nil && errDst == nil { // rename between virtual folders - if sourceFolder.MappedPath == dstFolder.MappedPath { + if sourceFolder.Name == dstFolder.Name { // rename inside the same virtual folder return true } @@ -684,16 +727,16 @@ func (c *BaseConnection) hasSpaceForRename(virtualSourcePath, virtualTargetPath return true } quotaResult := c.HasSpace(true, false, virtualTargetPath) - return c.hasSpaceForCrossRename(quotaResult, initialSize, fsSourcePath) + return c.hasSpaceForCrossRename(fs, quotaResult, initialSize, fsSourcePath) } // hasSpaceForCrossRename checks the quota after a rename between different folders -func (c *BaseConnection) hasSpaceForCrossRename(quotaResult vfs.QuotaCheckResult, initialSize int64, sourcePath string) bool { +func (c *BaseConnection) hasSpaceForCrossRename(fs vfs.Fs, quotaResult vfs.QuotaCheckResult, initialSize int64, sourcePath string) bool { if !quotaResult.HasSpace && initialSize == -1 { // we are over quota and this is not a file replace return false } - fi, err := c.Fs.Lstat(sourcePath) + fi, err := fs.Lstat(sourcePath) if err != nil { c.Log(logger.LevelWarn, "cross rename denied, stat error for path %#v: %v", sourcePath, err) return false @@ -708,7 +751,7 @@ func (c *BaseConnection) hasSpaceForCrossRename(quotaResult vfs.QuotaCheckResult filesDiff = 0 } } else if fi.IsDir() { - filesDiff, sizeDiff, err = c.Fs.GetDirSize(sourcePath) + filesDiff, sizeDiff, err = fs.GetDirSize(sourcePath) if err != nil { c.Log(logger.LevelWarn, "cross rename denied, error getting size for directory %#v: %v", sourcePath, err) return false @@ -745,11 +788,11 @@ func (c *BaseConnection) hasSpaceForCrossRename(quotaResult vfs.QuotaCheckResult // GetMaxWriteSize returns the allowed size for an upload or an error // if no enough size is available for a resume/append -func (c *BaseConnection) GetMaxWriteSize(quotaResult vfs.QuotaCheckResult, isResume bool, fileSize int64) (int64, error) { +func (c *BaseConnection) GetMaxWriteSize(quotaResult vfs.QuotaCheckResult, isResume bool, fileSize int64, isUploadResumeSupported bool) (int64, error) { maxWriteSize := quotaResult.GetRemainingSize() if isResume { - if !c.Fs.IsUploadResumeSupported() { + if !isUploadResumeSupported { return 0, c.GetOpUnsupportedError() } if c.User.Filters.MaxUploadFileSize > 0 && c.User.Filters.MaxUploadFileSize <= fileSize { @@ -823,6 +866,40 @@ func (c *BaseConnection) HasSpace(checkFiles, getUsage bool, requestPath string) return result } +// returns true if this is a rename on the same fs or local virtual folders +func (c *BaseConnection) isLocalOrSameFolderRename(virtualSourcePath, virtualTargetPath string) bool { + sourceFolder, errSrc := c.User.GetVirtualFolderForPath(virtualSourcePath) + dstFolder, errDst := c.User.GetVirtualFolderForPath(virtualTargetPath) + if errSrc != nil && errDst != nil { + return true + } + if errSrc == nil && errDst == nil { + if sourceFolder.Name == dstFolder.Name { + return true + } + // we have different folders, only local fs is supported + if sourceFolder.FsConfig.Provider == vfs.LocalFilesystemProvider && + dstFolder.FsConfig.Provider == vfs.LocalFilesystemProvider { + return true + } + return false + } + if c.User.FsConfig.Provider != vfs.LocalFilesystemProvider { + return false + } + if errSrc == nil { + if sourceFolder.FsConfig.Provider == vfs.LocalFilesystemProvider { + return true + } + } + if errDst == nil { + if dstFolder.FsConfig.Provider == vfs.LocalFilesystemProvider { + return true + } + } + return false +} + func (c *BaseConnection) isCrossFoldersRequest(virtualSourcePath, virtualTargetPath string) bool { sourceFolder, errSrc := c.User.GetVirtualFolderForPath(virtualSourcePath) dstFolder, errDst := c.User.GetVirtualFolderForPath(virtualTargetPath) @@ -830,14 +907,14 @@ func (c *BaseConnection) isCrossFoldersRequest(virtualSourcePath, virtualTargetP return false } if errSrc == nil && errDst == nil { - return sourceFolder.MappedPath != dstFolder.MappedPath + return sourceFolder.Name != dstFolder.Name } return true } func (c *BaseConnection) updateQuotaMoveBetweenVFolders(sourceFolder, dstFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) { - if sourceFolder.MappedPath == dstFolder.MappedPath { + if sourceFolder.Name == dstFolder.Name { // both files are inside the same virtual folder if initialSize != -1 { dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, -numFiles, -initialSize, false) //nolint:errcheck @@ -897,7 +974,10 @@ func (c *BaseConnection) updateQuotaMoveToVFolder(dstFolder *vfs.VirtualFolder, } } -func (c *BaseConnection) updateQuotaAfterRename(virtualSourcePath, virtualTargetPath, targetPath string, initialSize int64) error { +func (c *BaseConnection) updateQuotaAfterRename(fs vfs.Fs, virtualSourcePath, virtualTargetPath, targetPath string, initialSize int64) error { + if dataprovider.GetQuotaTracking() == 0 { + return nil + } // we don't allow to overwrite an existing directory so targetPath can be: // - a new file, a symlink is as a new file here // - a file overwriting an existing one @@ -908,7 +988,8 @@ func (c *BaseConnection) updateQuotaAfterRename(virtualSourcePath, virtualTarget if errSrc != nil && errDst != nil { // both files are contained inside the user home dir if initialSize != -1 { - // we cannot have a directory here + // we cannot have a directory here, we are overwriting an existing file + // we need to subtract the size of the overwritten file from the user quota dataprovider.UpdateUserQuota(&c.User, -1, -initialSize, false) //nolint:errcheck } return nil @@ -916,9 +997,9 @@ func (c *BaseConnection) updateQuotaAfterRename(virtualSourcePath, virtualTarget filesSize := int64(0) numFiles := 1 - if fi, err := c.Fs.Stat(targetPath); err == nil { + if fi, err := fs.Stat(targetPath); err == nil { if fi.Mode().IsDir() { - numFiles, filesSize, err = c.Fs.GetDirSize(targetPath) + numFiles, filesSize, err = fs.GetDirSize(targetPath) if err != nil { c.Log(logger.LevelWarn, "failed to update quota after rename, error scanning moved folder %#v: %v", targetPath, err) @@ -995,15 +1076,30 @@ func (c *BaseConnection) GetGenericError(err error) error { } // GetFsError converts a filesystem error to a protocol error -func (c *BaseConnection) GetFsError(err error) error { - if c.Fs.IsNotExist(err) { +func (c *BaseConnection) GetFsError(fs vfs.Fs, err error) error { + if fs.IsNotExist(err) { return c.GetNotExistError() - } else if c.Fs.IsPermission(err) { + } else if fs.IsPermission(err) { return c.GetPermissionDeniedError() - } else if c.Fs.IsNotSupported(err) { + } else if fs.IsNotSupported(err) { return c.GetOpUnsupportedError() } else if err != nil { return c.GetGenericError(err) } return nil } + +// GetFsAndResolvedPath returns the fs and the fs path matching virtualPath +func (c *BaseConnection) GetFsAndResolvedPath(virtualPath string) (vfs.Fs, string, error) { + fs, err := c.User.GetFilesystemForPath(virtualPath, c.ID) + if err != nil { + return nil, "", err + } + + fsPath, err := fs.ResolvePath(virtualPath) + if err != nil { + return nil, "", c.GetFsError(fs, err) + } + + return fs, fsPath, nil +} diff --git a/common/connection_test.go b/common/connection_test.go index 9a06e48d..e46a02e4 100644 --- a/common/connection_test.go +++ b/common/connection_test.go @@ -8,13 +8,10 @@ import ( "testing" "time" - "github.com/minio/sio" "github.com/pkg/sftp" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/drakkan/sftpgo/dataprovider" - "github.com/drakkan/sftpgo/kms" "github.com/drakkan/sftpgo/vfs" ) @@ -40,896 +37,110 @@ func (fs MockOsFs) IsUploadResumeSupported() bool { func newMockOsFs(hasVirtualFolders bool, connectionID, rootDir string) vfs.Fs { return &MockOsFs{ - Fs: vfs.NewOsFs(connectionID, rootDir, nil), + Fs: vfs.NewOsFs(connectionID, rootDir, ""), hasVirtualFolders: hasVirtualFolders, } } -func TestListDir(t *testing.T) { +func TestRemoveErrors(t *testing.T) { + mappedPath := filepath.Join(os.TempDir(), "map") + homePath := filepath.Join(os.TempDir(), "home") + user := dataprovider.User{ - Username: userTestUsername, - HomeDir: filepath.Join(os.TempDir(), "home"), - } - mappedPath := filepath.Join(os.TempDir(), "vdir") - user.Permissions = make(map[string][]string) - user.Permissions["/"] = []string{dataprovider.PermUpload} - user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: mappedPath, + Username: "remove_errors_user", + HomeDir: homePath, + VirtualFolders: []vfs.VirtualFolder{ + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: filepath.Base(mappedPath), + MappedPath: mappedPath, + }, + VirtualPath: "/virtualpath", + }, }, - VirtualPath: "/vdir", - }) - err := os.Mkdir(user.GetHomeDir(), os.ModePerm) - assert.NoError(t, err) - fs, err := user.GetFilesystem("") - assert.NoError(t, err) - c := NewBaseConnection("", ProtocolSFTP, user, fs) - _, err = c.ListDir(user.GetHomeDir(), "/") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) } - c.User.Permissions["/"] = []string{dataprovider.PermAny} - files, err := c.ListDir(user.GetHomeDir(), "/") - if assert.NoError(t, err) { - vdirFound := false - for _, f := range files { - if f.Name() == "vdir" { - vdirFound = true - break - } - } - assert.True(t, vdirFound) - } - _, err = c.ListDir(mappedPath, "/vdir") - assert.Error(t, err) - - err = os.RemoveAll(user.GetHomeDir()) - assert.NoError(t, err) -} - -func TestCreateDir(t *testing.T) { - user := dataprovider.User{ - Username: userTestUsername, - HomeDir: filepath.Join(os.TempDir(), "home"), - } - mappedPath := filepath.Join(os.TempDir(), "vdir") user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} - user.Permissions["/sub"] = []string{dataprovider.PermListItems} - user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: mappedPath, - }, - VirtualPath: "/vdir", - }) - err := os.Mkdir(user.GetHomeDir(), os.ModePerm) - assert.NoError(t, err) - fs, err := user.GetFilesystem("") - assert.NoError(t, err) - c := NewBaseConnection("", ProtocolSFTP, user, fs) - err = c.CreateDir("", "/sub/dir") + fs := vfs.NewOsFs("", os.TempDir(), "") + conn := NewBaseConnection("", ProtocolFTP, user) + err := conn.IsRemoveDirAllowed(fs, mappedPath, "/virtualpath1") if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) + assert.Contains(t, err.Error(), "permission denied") } - err = c.CreateDir("", "/vdir") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - err = c.CreateDir(filepath.Join(mappedPath, "adir"), "/vdir/adir") + err = conn.RemoveFile(fs, filepath.Join(homePath, "missing_file"), "/missing_file", + vfs.NewFileInfo("info", false, 100, time.Now(), false)) assert.Error(t, err) - err = c.CreateDir(filepath.Join(user.GetHomeDir(), "dir"), "/dir") - assert.NoError(t, err) - - err = os.RemoveAll(user.GetHomeDir()) - assert.NoError(t, err) } -func TestRemoveFile(t *testing.T) { - user := dataprovider.User{ - Username: userTestUsername, - HomeDir: filepath.Join(os.TempDir(), "home"), - } - mappedPath := filepath.Join(os.TempDir(), "vdir") - user.Permissions = make(map[string][]string) - user.Permissions["/"] = []string{dataprovider.PermAny} - user.Permissions["/sub"] = []string{dataprovider.PermListItems} - user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: mappedPath, - }, - VirtualPath: "/vdir", - QuotaFiles: -1, - QuotaSize: -1, - }) - user.Filters.FileExtensions = []dataprovider.ExtensionsFilter{ - { - Path: "/p", - AllowedExtensions: []string{}, - DeniedExtensions: []string{".zip"}, - }, - } - err := os.Mkdir(user.GetHomeDir(), os.ModePerm) - assert.NoError(t, err) - err = os.Mkdir(mappedPath, os.ModePerm) - assert.NoError(t, err) - fs, err := user.GetFilesystem("") - assert.NoError(t, err) - c := NewBaseConnection("", ProtocolSFTP, user, fs) - err = c.RemoveFile("", "/sub/file", nil) - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - err = c.RemoveFile("", "/p/file.zip", nil) - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - testFile := filepath.Join(mappedPath, "afile") - err = os.WriteFile(testFile, []byte("test data"), os.ModePerm) - assert.NoError(t, err) - info, err := os.Stat(testFile) - assert.NoError(t, err) - err = c.RemoveFile(filepath.Join(user.GetHomeDir(), "missing"), "/missing", info) - assert.Error(t, err) - err = c.RemoveFile(testFile, "/vdir/afile", info) - assert.NoError(t, err) - - err = os.RemoveAll(mappedPath) - assert.NoError(t, err) - err = os.RemoveAll(user.GetHomeDir()) - assert.NoError(t, err) -} - -func TestRemoveDir(t *testing.T) { - user := dataprovider.User{ - Username: userTestUsername, - HomeDir: filepath.Join(os.TempDir(), "home"), - } - mappedPath := filepath.Join(os.TempDir(), "vdir") - user.Permissions = make(map[string][]string) - user.Permissions["/"] = []string{dataprovider.PermAny} - user.Permissions["/sub"] = []string{dataprovider.PermListItems} - user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: mappedPath, - }, - VirtualPath: "/adir/vdir", - }) - err := os.Mkdir(user.GetHomeDir(), os.ModePerm) - assert.NoError(t, err) - err = os.Mkdir(mappedPath, os.ModePerm) - assert.NoError(t, err) - fs, err := user.GetFilesystem("") - assert.NoError(t, err) - c := NewBaseConnection("", ProtocolSFTP, user, fs) - err = c.RemoveDir(user.GetHomeDir(), "/") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - err = c.RemoveDir(mappedPath, "/adir/vdir") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - err = c.RemoveDir(mappedPath, "/adir") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetOpUnsupportedError().Error()) - } - err = c.RemoveDir(mappedPath, "/adir/dir") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - err = c.RemoveDir(filepath.Join(user.GetHomeDir(), "/sub/dir"), "/sub/dir") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - testDir := filepath.Join(user.GetHomeDir(), "testDir") - err = c.RemoveDir(testDir, "testDir") - assert.Error(t, err) - err = os.WriteFile(testDir, []byte("data"), os.ModePerm) - assert.NoError(t, err) - err = c.RemoveDir(testDir, "testDir") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetGenericError(err).Error()) - } - err = os.Remove(testDir) - assert.NoError(t, err) - testDirSub := filepath.Join(testDir, "sub") - err = os.MkdirAll(testDirSub, os.ModePerm) - assert.NoError(t, err) - err = c.RemoveDir(testDir, "/testDir") - assert.Error(t, err) - err = os.RemoveAll(testDirSub) - assert.NoError(t, err) - err = c.RemoveDir(testDir, "/testDir") - assert.NoError(t, err) - - err = c.RemoveDir(testDir, "/testDir") - assert.Error(t, err) - - fs = newMockOsFs(true, "", user.GetHomeDir()) - c.Fs = fs - err = c.RemoveDir(testDir, "/testDir") - assert.NoError(t, err) - - err = os.RemoveAll(mappedPath) - assert.NoError(t, err) - err = os.RemoveAll(user.GetHomeDir()) - assert.NoError(t, err) -} - -func TestRename(t *testing.T) { - user := dataprovider.User{ - Username: userTestUsername, - HomeDir: filepath.Join(os.TempDir(), "home"), - QuotaSize: 10485760, - } - mappedPath1 := filepath.Join(os.TempDir(), "vdir1") - mappedPath2 := filepath.Join(os.TempDir(), "vdir2") - user.Permissions = make(map[string][]string) - user.Permissions["/"] = []string{dataprovider.PermAny} - user.Permissions["/sub"] = []string{dataprovider.PermListItems} - user.Permissions["/sub1"] = []string{dataprovider.PermRename} - user.Permissions["/dir"] = []string{dataprovider.PermListItems} - user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: mappedPath1, - }, - VirtualPath: "/vdir1/sub", - QuotaFiles: -1, - QuotaSize: -1, - }) - user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: mappedPath2, - }, - VirtualPath: "/vdir2", - QuotaFiles: -1, - QuotaSize: -1, - }) - err := os.MkdirAll(filepath.Join(user.GetHomeDir(), "sub"), os.ModePerm) - assert.NoError(t, err) - err = os.MkdirAll(filepath.Join(user.GetHomeDir(), "dir", "sub"), os.ModePerm) - assert.NoError(t, err) - err = os.Mkdir(mappedPath1, os.ModePerm) - assert.NoError(t, err) - err = os.Mkdir(mappedPath2, os.ModePerm) - assert.NoError(t, err) - fs, err := user.GetFilesystem("") - assert.NoError(t, err) - c := NewBaseConnection("", ProtocolSFTP, user, fs) - err = c.Rename(mappedPath1, "", "", "") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - err = c.Rename("", mappedPath2, "", "") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - err = c.Rename("missing", "", "", "") - assert.Error(t, err) - testFile := filepath.Join(user.GetHomeDir(), "file") - err = os.WriteFile(testFile, []byte("data"), os.ModePerm) - assert.NoError(t, err) - testSubFile := filepath.Join(user.GetHomeDir(), "sub", "file") - err = os.WriteFile(testSubFile, []byte("data"), os.ModePerm) - assert.NoError(t, err) - err = c.Rename(testSubFile, filepath.Join(user.GetHomeDir(), "file"), "/sub/file", "/file") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - err = c.Rename(testFile, filepath.Join(user.GetHomeDir(), "sub"), "/file", "/sub") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetOpUnsupportedError().Error()) - } - err = c.Rename(testSubFile, testFile, "/file", "/sub1/file") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - err = c.Rename(filepath.Join(user.GetHomeDir(), "sub"), filepath.Join(user.GetHomeDir(), "adir"), "/vdir1", "/adir") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetOpUnsupportedError().Error()) - } - err = c.Rename(filepath.Join(user.GetHomeDir(), "dir"), filepath.Join(user.GetHomeDir(), "adir"), "/dir", "/adir") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - err = os.MkdirAll(filepath.Join(user.GetHomeDir(), "testdir"), os.ModePerm) - assert.NoError(t, err) - err = c.Rename(filepath.Join(user.GetHomeDir(), "testdir"), filepath.Join(user.GetHomeDir(), "tdir", "sub"), "/testdir", "/tdir/sub") - assert.Error(t, err) - err = os.Remove(testSubFile) - assert.NoError(t, err) - err = c.Rename(filepath.Join(user.GetHomeDir(), "sub"), filepath.Join(user.GetHomeDir(), "adir"), "/sub", "/adir") - assert.NoError(t, err) - err = os.MkdirAll(filepath.Join(user.GetHomeDir(), "adir"), os.ModePerm) - assert.NoError(t, err) - err = os.WriteFile(filepath.Join(user.GetHomeDir(), "adir", "file"), []byte("data"), os.ModePerm) - assert.NoError(t, err) - err = c.Rename(filepath.Join(user.GetHomeDir(), "adir", "file"), filepath.Join(user.GetHomeDir(), "file"), "/adir/file", "/file") - assert.NoError(t, err) - // rename between virtual folder this should fail since the virtual folder is not found inside the data provider - // and so the remaining space cannot be computed - err = c.Rename(filepath.Join(user.GetHomeDir(), "adir"), filepath.Join(user.GetHomeDir(), "another"), "/vdir1/sub/a", "/vdir2/b") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetGenericError(err).Error()) - } - - err = os.RemoveAll(mappedPath1) - assert.NoError(t, err) - err = os.RemoveAll(mappedPath2) - assert.NoError(t, err) - err = os.RemoveAll(user.GetHomeDir()) - assert.NoError(t, err) -} - -func TestCreateSymlink(t *testing.T) { - user := dataprovider.User{ - Username: userTestUsername, - HomeDir: filepath.Join(os.TempDir(), "home"), - } - mappedPath1 := filepath.Join(os.TempDir(), "vdir1") - mappedPath2 := filepath.Join(os.TempDir(), "vdir2") - user.Permissions = make(map[string][]string) - user.Permissions["/"] = []string{dataprovider.PermAny} - user.Permissions["/sub"] = []string{dataprovider.PermListItems} - user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: mappedPath1, - }, - VirtualPath: "/vdir1", - QuotaFiles: -1, - QuotaSize: -1, - }) - user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: mappedPath2, - }, - VirtualPath: "/vdir2", - QuotaFiles: -1, - QuotaSize: -1, - }) - err := os.Mkdir(user.GetHomeDir(), os.ModePerm) - assert.NoError(t, err) - err = os.Mkdir(mappedPath1, os.ModePerm) - assert.NoError(t, err) - err = os.Mkdir(mappedPath2, os.ModePerm) - assert.NoError(t, err) - fs, err := user.GetFilesystem("") - assert.NoError(t, err) - c := NewBaseConnection("", ProtocolSFTP, user, fs) - err = c.CreateSymlink(user.GetHomeDir(), mappedPath1, "/", "/vdir1") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - err = c.CreateSymlink(filepath.Join(user.GetHomeDir(), "a"), mappedPath1, "/a", "/vdir1") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - err = c.CreateSymlink(filepath.Join(user.GetHomeDir(), "b"), mappedPath1, "/b", "/sub/b") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - err = c.CreateSymlink(filepath.Join(user.GetHomeDir(), "b"), mappedPath1, "/vdir1/b", "/vdir2/b") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetOpUnsupportedError().Error()) - } - err = c.CreateSymlink(mappedPath1, filepath.Join(mappedPath1, "b"), "/vdir1/a", "/vdir1/b") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - err = c.CreateSymlink(filepath.Join(mappedPath1, "b"), mappedPath1, "/vdir1/a", "/vdir1/b") - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - - err = os.Mkdir(filepath.Join(user.GetHomeDir(), "b"), os.ModePerm) - assert.NoError(t, err) - err = c.CreateSymlink(filepath.Join(user.GetHomeDir(), "b"), filepath.Join(user.GetHomeDir(), "c"), "/b", "/c") - assert.NoError(t, err) - err = c.CreateSymlink(filepath.Join(user.GetHomeDir(), "b"), filepath.Join(user.GetHomeDir(), "c"), "/b", "/c") - assert.Error(t, err) - - err = os.RemoveAll(mappedPath1) - assert.NoError(t, err) - err = os.RemoveAll(mappedPath2) - assert.NoError(t, err) - err = os.RemoveAll(user.GetHomeDir()) - assert.NoError(t, err) -} - -func TestDoStat(t *testing.T) { - testFile := filepath.Join(os.TempDir(), "afile.txt") - fs := vfs.NewOsFs("123", os.TempDir(), nil) - u := dataprovider.User{ - Username: "user", - HomeDir: os.TempDir(), - } - u.Permissions = make(map[string][]string) - u.Permissions["/"] = []string{dataprovider.PermAny} - err := os.WriteFile(testFile, []byte("data"), os.ModePerm) - require.NoError(t, err) - err = os.Symlink(testFile, testFile+".sym") - require.NoError(t, err) - conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, u, fs) - infoStat, err := conn.DoStat(testFile+".sym", 0) - if assert.NoError(t, err) { - assert.Equal(t, int64(4), infoStat.Size()) - } - infoLstat, err := conn.DoStat(testFile+".sym", 1) - if assert.NoError(t, err) { - assert.NotEqual(t, int64(4), infoLstat.Size()) - } - assert.False(t, os.SameFile(infoStat, infoLstat)) - - fs, err = vfs.NewCryptFs(fs.ConnectionID(), os.TempDir(), vfs.CryptFsConfig{ - Passphrase: kms.NewPlainSecret("payload"), - }) - assert.NoError(t, err) - conn = NewBaseConnection(fs.ConnectionID(), ProtocolFTP, u, fs) - dataSize := int64(32768) - data := make([]byte, dataSize) - err = os.WriteFile(testFile, data, os.ModePerm) - assert.NoError(t, err) - infoStat, err = conn.DoStat(testFile, 0) - assert.NoError(t, err) - assert.Less(t, infoStat.Size(), dataSize) - encSize, err := sio.EncryptedSize(uint64(infoStat.Size())) - assert.NoError(t, err) - assert.Equal(t, int64(encSize)+33, dataSize) - - err = os.Remove(testFile) - assert.NoError(t, err) - err = os.Remove(testFile + ".sym") - assert.NoError(t, err) - assert.Len(t, conn.GetTransfers(), 0) -} - -func TestSetStat(t *testing.T) { +func TestSetStatMode(t *testing.T) { oldSetStatMode := Config.SetstatMode Config.SetstatMode = 1 + + fakePath := "fake path" user := dataprovider.User{ - Username: userTestUsername, - HomeDir: filepath.Join(os.TempDir(), "home"), + HomeDir: os.TempDir(), } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} - user.Permissions["/dir1"] = []string{dataprovider.PermChmod} - user.Permissions["/dir2"] = []string{dataprovider.PermChown} - user.Permissions["/dir3"] = []string{dataprovider.PermChtimes} - dir1 := filepath.Join(user.GetHomeDir(), "dir1") - dir2 := filepath.Join(user.GetHomeDir(), "dir2") - dir3 := filepath.Join(user.GetHomeDir(), "dir3") - err := os.Mkdir(user.GetHomeDir(), os.ModePerm) + fs := newMockOsFs(true, "", user.GetHomeDir()) + conn := NewBaseConnection("", ProtocolWebDAV, user) + err := conn.handleChmod(fs, fakePath, fakePath, nil) assert.NoError(t, err) - err = os.Mkdir(dir1, os.ModePerm) + err = conn.handleChown(fs, fakePath, fakePath, nil) assert.NoError(t, err) - err = os.Mkdir(dir2, os.ModePerm) - assert.NoError(t, err) - err = os.Mkdir(dir3, os.ModePerm) - assert.NoError(t, err) - - fs, err := user.GetFilesystem("") - assert.NoError(t, err) - c := NewBaseConnection("", ProtocolSFTP, user, fs) - err = c.SetStat(user.GetHomeDir(), "/", &StatAttributes{}) - assert.NoError(t, err) - - err = c.SetStat(dir2, "/dir1/file", &StatAttributes{ - Mode: os.ModePerm, - Flags: StatAttrPerms, - }) - assert.NoError(t, err) - err = c.SetStat(dir1, "/dir2/file", &StatAttributes{ - UID: os.Getuid(), - GID: os.Getgid(), - Flags: StatAttrUIDGID, - }) - assert.NoError(t, err) - err = c.SetStat(dir1, "/dir3/file", &StatAttributes{ - Atime: time.Now(), - Mtime: time.Now(), - Flags: StatAttrTimes, - }) + err = conn.handleChtimes(fs, fakePath, fakePath, nil) assert.NoError(t, err) Config.SetstatMode = 2 - assert.False(t, c.ignoreSetStat()) - c1 := NewBaseConnection("", ProtocolSFTP, user, newMockOsFs(false, fs.ConnectionID(), user.GetHomeDir())) - assert.True(t, c1.ignoreSetStat()) + err = conn.handleChmod(fs, fakePath, fakePath, nil) + assert.NoError(t, err) Config.SetstatMode = oldSetStatMode - // chmod - err = c.SetStat(dir1, "/dir1/file", &StatAttributes{ - Mode: os.ModePerm, - Flags: StatAttrPerms, - }) - assert.NoError(t, err) - err = c.SetStat(dir2, "/dir2/file", &StatAttributes{ - Mode: os.ModePerm, - Flags: StatAttrPerms, - }) - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - err = c.SetStat(filepath.Join(user.GetHomeDir(), "missing"), "/missing", &StatAttributes{ - Mode: os.ModePerm, - Flags: StatAttrPerms, - }) - assert.Error(t, err) - // chown +} + +func TestRecursiveRenameWalkError(t *testing.T) { + fs := vfs.NewOsFs("", os.TempDir(), "") + conn := NewBaseConnection("", ProtocolWebDAV, dataprovider.User{}) + err := conn.checkRecursiveRenameDirPermissions(fs, fs, "/source", "/target") + assert.ErrorIs(t, err, os.ErrNotExist) +} + +func TestCrossRenameFsErrors(t *testing.T) { + fs := vfs.NewOsFs("", os.TempDir(), "") + conn := NewBaseConnection("", ProtocolWebDAV, dataprovider.User{}) + res := conn.hasSpaceForCrossRename(fs, vfs.QuotaCheckResult{}, 1, "missingsource") + assert.False(t, res) if runtime.GOOS != osWindows { - err = c.SetStat(dir1, "/dir2/file", &StatAttributes{ - UID: os.Getuid(), - GID: os.Getgid(), - Flags: StatAttrUIDGID, - }) + dirPath := filepath.Join(os.TempDir(), "d") + err := os.Mkdir(dirPath, os.ModePerm) + assert.NoError(t, err) + err = os.Chmod(dirPath, 0001) + assert.NoError(t, err) + + res = conn.hasSpaceForCrossRename(fs, vfs.QuotaCheckResult{}, 1, dirPath) + assert.False(t, res) + + err = os.Chmod(dirPath, os.ModePerm) + assert.NoError(t, err) + err = os.Remove(dirPath) assert.NoError(t, err) } - - err = c.SetStat(dir1, "/dir3/file", &StatAttributes{ - UID: os.Getuid(), - GID: os.Getgid(), - Flags: StatAttrUIDGID, - }) - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - - err = c.SetStat(filepath.Join(user.GetHomeDir(), "missing"), "/missing", &StatAttributes{ - UID: os.Getuid(), - GID: os.Getgid(), - Flags: StatAttrUIDGID, - }) - assert.Error(t, err) - // chtimes - err = c.SetStat(dir1, "/dir3/file", &StatAttributes{ - Atime: time.Now(), - Mtime: time.Now(), - Flags: StatAttrTimes, - }) - assert.NoError(t, err) - err = c.SetStat(dir1, "/dir1/file", &StatAttributes{ - Atime: time.Now(), - Mtime: time.Now(), - Flags: StatAttrTimes, - }) - if assert.Error(t, err) { - assert.EqualError(t, err, c.GetPermissionDeniedError().Error()) - } - err = c.SetStat(filepath.Join(user.GetHomeDir(), "missing"), "/missing", &StatAttributes{ - Atime: time.Now(), - Mtime: time.Now(), - Flags: StatAttrTimes, - }) - assert.Error(t, err) - // truncate - err = c.SetStat(filepath.Join(user.GetHomeDir(), "/missing/missing"), "/missing/missing", &StatAttributes{ - Size: 1, - Flags: StatAttrSize, - }) - assert.Error(t, err) - err = c.SetStat(filepath.Join(dir3, "afile.txt"), "/dir3/afile.txt", &StatAttributes{ - Size: 1, - Flags: StatAttrSize, - }) - assert.Error(t, err) - - filePath := filepath.Join(user.GetHomeDir(), "afile.txt") - err = os.WriteFile(filePath, []byte("hello"), os.ModePerm) - assert.NoError(t, err) - err = c.SetStat(filePath, "/afile.txt", &StatAttributes{ - Flags: StatAttrSize, - Size: 1, - }) - assert.NoError(t, err) - fi, err := os.Stat(filePath) - if assert.NoError(t, err) { - assert.Equal(t, int64(1), fi.Size()) - } - - vDir := filepath.Join(os.TempDir(), "vdir") - err = os.MkdirAll(vDir, os.ModePerm) - assert.NoError(t, err) - c.User.VirtualFolders = nil - c.User.VirtualFolders = append(c.User.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: vDir, - }, - VirtualPath: "/vpath", - QuotaSize: -1, - QuotaFiles: -1, - }) - - filePath = filepath.Join(vDir, "afile.txt") - err = os.WriteFile(filePath, []byte("hello"), os.ModePerm) - assert.NoError(t, err) - err = c.SetStat(filePath, "/vpath/afile.txt", &StatAttributes{ - Flags: StatAttrSize, - Size: 1, - }) - assert.NoError(t, err) - fi, err = os.Stat(filePath) - if assert.NoError(t, err) { - assert.Equal(t, int64(1), fi.Size()) - } - - err = os.RemoveAll(user.GetHomeDir()) - assert.NoError(t, err) - err = os.RemoveAll(vDir) - assert.NoError(t, err) } -func TestSpaceForCrossRename(t *testing.T) { - permissions := make(map[string][]string) - permissions["/"] = []string{dataprovider.PermAny} - user := dataprovider.User{ - Username: userTestUsername, - Permissions: permissions, - HomeDir: filepath.Clean(os.TempDir()), - } - fs, err := user.GetFilesystem("123") - assert.NoError(t, err) - conn := NewBaseConnection("", ProtocolSFTP, user, fs) - quotaResult := vfs.QuotaCheckResult{ - HasSpace: true, - } - assert.False(t, conn.hasSpaceForCrossRename(quotaResult, -1, filepath.Join(os.TempDir(), "a missing file"))) - if runtime.GOOS != osWindows { - testDir := filepath.Join(os.TempDir(), "dir") - err = os.MkdirAll(testDir, os.ModePerm) - assert.NoError(t, err) - err = os.WriteFile(filepath.Join(testDir, "afile.txt"), []byte("content"), os.ModePerm) - assert.NoError(t, err) - err = os.Chmod(testDir, 0001) - assert.NoError(t, err) - assert.False(t, conn.hasSpaceForCrossRename(quotaResult, -1, testDir)) - err = os.Chmod(testDir, os.ModePerm) - assert.NoError(t, err) - err = os.RemoveAll(testDir) - assert.NoError(t, err) - } - - testFile := filepath.Join(os.TempDir(), "afile.txt") - err = os.WriteFile(testFile, []byte("test data"), os.ModePerm) - assert.NoError(t, err) - quotaResult = vfs.QuotaCheckResult{ - HasSpace: false, - QuotaSize: 0, - } - assert.True(t, conn.hasSpaceForCrossRename(quotaResult, 123, testFile)) - - quotaResult = vfs.QuotaCheckResult{ - HasSpace: false, - QuotaSize: 124, - UsedSize: 125, - } - assert.False(t, conn.hasSpaceForCrossRename(quotaResult, 8, testFile)) - - quotaResult = vfs.QuotaCheckResult{ - HasSpace: false, - QuotaSize: 124, - UsedSize: 124, - } - assert.True(t, conn.hasSpaceForCrossRename(quotaResult, 123, testFile)) - - quotaResult = vfs.QuotaCheckResult{ - HasSpace: true, - QuotaSize: 10, - UsedSize: 1, - } - assert.True(t, conn.hasSpaceForCrossRename(quotaResult, -1, testFile)) - - quotaResult = vfs.QuotaCheckResult{ - HasSpace: true, - QuotaSize: 7, - UsedSize: 0, - } - assert.False(t, conn.hasSpaceForCrossRename(quotaResult, -1, testFile)) - - err = os.Remove(testFile) - assert.NoError(t, err) - - testDir := filepath.Join(os.TempDir(), "testDir") - err = os.MkdirAll(testDir, os.ModePerm) - assert.NoError(t, err) - err = os.WriteFile(filepath.Join(testDir, "1"), []byte("1"), os.ModePerm) - assert.NoError(t, err) - err = os.WriteFile(filepath.Join(testDir, "2"), []byte("2"), os.ModePerm) - assert.NoError(t, err) - quotaResult = vfs.QuotaCheckResult{ - HasSpace: true, - QuotaFiles: 2, - UsedFiles: 1, - } - assert.False(t, conn.hasSpaceForCrossRename(quotaResult, -1, testDir)) - - quotaResult = vfs.QuotaCheckResult{ - HasSpace: true, - QuotaFiles: 2, - UsedFiles: 0, - } - assert.True(t, conn.hasSpaceForCrossRename(quotaResult, -1, testDir)) - - err = os.RemoveAll(testDir) - assert.NoError(t, err) -} - -func TestRenamePermission(t *testing.T) { - permissions := make(map[string][]string) - permissions["/"] = []string{dataprovider.PermAny} - permissions["/dir1"] = []string{dataprovider.PermRename} - permissions["/dir2"] = []string{dataprovider.PermUpload} - permissions["/dir3"] = []string{dataprovider.PermDelete} - permissions["/dir4"] = []string{dataprovider.PermListItems} - permissions["/dir5"] = []string{dataprovider.PermCreateDirs, dataprovider.PermUpload} - permissions["/dir6"] = []string{dataprovider.PermCreateDirs, dataprovider.PermUpload, - dataprovider.PermListItems, dataprovider.PermCreateSymlinks} - permissions["/dir7"] = []string{dataprovider.PermAny} - permissions["/dir8"] = []string{dataprovider.PermAny} - - user := dataprovider.User{ - Username: userTestUsername, - Permissions: permissions, - HomeDir: os.TempDir(), - } - fs, err := user.GetFilesystem("123") - assert.NoError(t, err) - conn := NewBaseConnection("", ProtocolSFTP, user, fs) - request := sftp.NewRequest("Rename", "/testfile") - request.Target = "/dir1/testfile" - // rename is granted on Source and Target - assert.True(t, conn.isRenamePermitted("", request.Filepath, request.Target, nil)) - request.Target = "/dir4/testfile" - // rename is not granted on Target - assert.False(t, conn.isRenamePermitted("", request.Filepath, request.Target, nil)) - request = sftp.NewRequest("Rename", "/dir1/testfile") - request.Target = "/dir2/testfile" //nolint:goconst - // rename is granted on Source but not on Target - assert.False(t, conn.isRenamePermitted("", request.Filepath, request.Target, nil)) - request = sftp.NewRequest("Rename", "/dir4/testfile") - request.Target = "/dir1/testfile" - // rename is granted on Target but not on Source - assert.False(t, conn.isRenamePermitted("", request.Filepath, request.Target, nil)) - request = sftp.NewRequest("Rename", "/dir4/testfile") - request.Target = "/testfile" - // rename is granted on Target but not on Source - assert.False(t, conn.isRenamePermitted("", request.Filepath, request.Target, nil)) - request = sftp.NewRequest("Rename", "/dir3/testfile") - request.Target = "/dir2/testfile" - // delete is granted on Source and Upload on Target, the target is a file this is enough - assert.True(t, conn.isRenamePermitted("", request.Filepath, request.Target, nil)) - request = sftp.NewRequest("Rename", "/dir2/testfile") - request.Target = "/dir3/testfile" - assert.False(t, conn.isRenamePermitted("", request.Filepath, request.Target, nil)) - tmpDir := filepath.Join(os.TempDir(), "dir") - tmpDirLink := filepath.Join(os.TempDir(), "link") - err = os.Mkdir(tmpDir, os.ModePerm) - assert.NoError(t, err) - err = os.Symlink(tmpDir, tmpDirLink) - assert.NoError(t, err) - request.Filepath = "/dir" - request.Target = "/dir2/dir" - // the source is a dir and the target has no createDirs perm - info, err := os.Lstat(tmpDir) - if assert.NoError(t, err) { - assert.False(t, conn.isRenamePermitted(tmpDir, request.Filepath, request.Target, info)) - conn.User.Permissions["/dir2"] = []string{dataprovider.PermUpload, dataprovider.PermCreateDirs} - // the source is a dir and the target has createDirs perm - assert.True(t, conn.isRenamePermitted(tmpDir, request.Filepath, request.Target, info)) - - request = sftp.NewRequest("Rename", "/testfile") - request.Target = "/dir5/testfile" - // the source is a dir and the target has createDirs and upload perm - assert.True(t, conn.isRenamePermitted(tmpDir, request.Filepath, request.Target, info)) - } - info, err = os.Lstat(tmpDirLink) - if assert.NoError(t, err) { - assert.True(t, info.Mode()&os.ModeSymlink != 0) - // the source is a symlink and the target has createDirs and upload perm - assert.False(t, conn.isRenamePermitted(tmpDir, request.Filepath, request.Target, info)) - } - err = os.RemoveAll(tmpDir) - assert.NoError(t, err) - err = os.Remove(tmpDirLink) - assert.NoError(t, err) - conn.User.VirtualFolders = append(conn.User.VirtualFolders, vfs.VirtualFolder{ +func TestRenameVirtualFolders(t *testing.T) { + vdir := "/avdir" + u := dataprovider.User{} + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: os.TempDir(), + Name: "name", + MappedPath: "mappedPath", }, - VirtualPath: "/dir1", + VirtualPath: vdir, }) - request = sftp.NewRequest("Rename", "/dir1") - request.Target = "/dir2/testfile" - // renaming a virtual folder is not allowed - assert.False(t, conn.isRenamePermitted("", request.Filepath, request.Target, nil)) - err = conn.checkRecursiveRenameDirPermissions("invalid", "invalid") - assert.Error(t, err) - dir3 := filepath.Join(conn.User.HomeDir, "dir3") - dir6 := filepath.Join(conn.User.HomeDir, "dir6") - err = os.MkdirAll(filepath.Join(dir3, "subdir"), os.ModePerm) - assert.NoError(t, err) - err = os.WriteFile(filepath.Join(dir3, "subdir", "testfile"), []byte("test"), os.ModePerm) - assert.NoError(t, err) - err = conn.checkRecursiveRenameDirPermissions(dir3, dir6) - assert.NoError(t, err) - err = os.RemoveAll(dir3) - assert.NoError(t, err) - - dir7 := filepath.Join(conn.User.HomeDir, "dir7") - dir8 := filepath.Join(conn.User.HomeDir, "dir8") - err = os.MkdirAll(filepath.Join(dir8, "subdir"), os.ModePerm) - assert.NoError(t, err) - err = os.WriteFile(filepath.Join(dir8, "subdir", "testfile"), []byte("test"), os.ModePerm) - assert.NoError(t, err) - err = conn.checkRecursiveRenameDirPermissions(dir8, dir7) - assert.NoError(t, err) - err = os.RemoveAll(dir8) - assert.NoError(t, err) - - assert.False(t, conn.isRenamePermitted(user.GetHomeDir(), "", "", nil)) - - conn.User.Filters.FileExtensions = []dataprovider.ExtensionsFilter{ - { - Path: "/p", - AllowedExtensions: []string{}, - DeniedExtensions: []string{".zip"}, - }, - } - testFile := filepath.Join(user.HomeDir, "testfile") - err = os.WriteFile(testFile, []byte("data"), os.ModePerm) - assert.NoError(t, err) - info, err = os.Stat(testFile) - assert.NoError(t, err) - assert.False(t, conn.isRenamePermitted(dir7, "/file", "/p/file.zip", info)) - err = os.Remove(testFile) - assert.NoError(t, err) -} - -func TestHasSpaceForRename(t *testing.T) { - err := closeDataprovider() - assert.NoError(t, err) - _, err = initializeDataprovider(0) - assert.NoError(t, err) - - user := dataprovider.User{ - Username: userTestUsername, - HomeDir: filepath.Join(os.TempDir(), "home"), - } - mappedPath := filepath.Join(os.TempDir(), "vdir") - user.Permissions = make(map[string][]string) - user.Permissions["/"] = []string{dataprovider.PermAny} - user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: mappedPath, - }, - VirtualPath: "/vdir1", - }) - user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: mappedPath, - }, - VirtualPath: "/vdir2", - QuotaSize: -1, - QuotaFiles: -1, - }) - fs, err := user.GetFilesystem("id") - assert.NoError(t, err) - c := NewBaseConnection("", ProtocolSFTP, user, fs) - // with quota tracking disabled hasSpaceForRename will always return true - assert.True(t, c.hasSpaceForRename("", "", 0, "")) - quotaResult := c.HasSpace(true, false, "") - assert.True(t, quotaResult.HasSpace) - - err = closeDataprovider() - assert.NoError(t, err) - _, err = initializeDataprovider(-1) - assert.NoError(t, err) - - // rename inside the same mapped path - assert.True(t, c.hasSpaceForRename("/vdir1/file", "/vdir2/file", 0, filepath.Join(mappedPath, "file"))) - // rename between user root dir and a virtual folder included in user quota - assert.True(t, c.hasSpaceForRename("/file", "/vdir2/file", 0, filepath.Join(mappedPath, "file"))) - - assert.True(t, c.isCrossFoldersRequest("/file", "/vdir2/file")) + fs := vfs.NewOsFs("", os.TempDir(), "") + conn := NewBaseConnection("", ProtocolFTP, u) + res := conn.isRenamePermitted(fs, fs, "source", "target", vdir, "vdirtarget", nil) + assert.False(t, res) } func TestUpdateQuotaAfterRename(t *testing.T) { @@ -962,7 +173,7 @@ func TestUpdateQuotaAfterRename(t *testing.T) { assert.NoError(t, err) fs, err := user.GetFilesystem("id") assert.NoError(t, err) - c := NewBaseConnection("", ProtocolSFTP, user, fs) + c := NewBaseConnection("", ProtocolSFTP, user) request := sftp.NewRequest("Rename", "/testfile") if runtime.GOOS != osWindows { request.Filepath = "/dir" @@ -972,7 +183,7 @@ func TestUpdateQuotaAfterRename(t *testing.T) { assert.NoError(t, err) err = os.Chmod(testDirPath, 0001) assert.NoError(t, err) - err = c.updateQuotaAfterRename(request.Filepath, request.Target, testDirPath, 0) + err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, testDirPath, 0) assert.Error(t, err) err = os.Chmod(testDirPath, os.ModePerm) assert.NoError(t, err) @@ -980,23 +191,23 @@ func TestUpdateQuotaAfterRename(t *testing.T) { testFile1 := "/testfile1" request.Target = testFile1 request.Filepath = path.Join("/vdir", "file") - err = c.updateQuotaAfterRename(request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 0) + err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 0) assert.Error(t, err) err = os.WriteFile(filepath.Join(mappedPath, "file"), []byte("test content"), os.ModePerm) assert.NoError(t, err) request.Filepath = testFile1 request.Target = path.Join("/vdir", "file") - err = c.updateQuotaAfterRename(request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 12) + err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 12) assert.NoError(t, err) err = os.WriteFile(filepath.Join(user.GetHomeDir(), "testfile1"), []byte("test content"), os.ModePerm) assert.NoError(t, err) request.Target = testFile1 request.Filepath = path.Join("/vdir", "file") - err = c.updateQuotaAfterRename(request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 12) + err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 12) assert.NoError(t, err) request.Target = path.Join("/vdir1", "file") request.Filepath = path.Join("/vdir", "file") - err = c.updateQuotaAfterRename(request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 12) + err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 12) assert.NoError(t, err) err = os.RemoveAll(mappedPath) @@ -1005,168 +216,12 @@ func TestUpdateQuotaAfterRename(t *testing.T) { assert.NoError(t, err) } -func TestHasSpace(t *testing.T) { - user := dataprovider.User{ - Username: userTestUsername, - HomeDir: filepath.Join(os.TempDir(), "home"), - Password: userTestPwd, - } - mappedPath := filepath.Join(os.TempDir(), "vdir") - folderName := "testFolder" - user.Permissions = make(map[string][]string) - user.Permissions["/"] = []string{dataprovider.PermAny} - user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: mappedPath, - Name: folderName, - }, - VirtualPath: "/vdir", - QuotaFiles: -1, - QuotaSize: -1, - }) - fs, err := user.GetFilesystem("id") - assert.NoError(t, err) - c := NewBaseConnection("", ProtocolSFTP, user, fs) - quotaResult := c.HasSpace(true, false, "/") - assert.True(t, quotaResult.HasSpace) - - user.VirtualFolders[0].QuotaFiles = 0 - user.VirtualFolders[0].QuotaSize = 0 - err = dataprovider.AddUser(&user) - assert.NoError(t, err) - user, err = dataprovider.UserExists(user.Username) - assert.NoError(t, err) - c.User = user - quotaResult = c.HasSpace(true, false, "/vdir/file") - assert.True(t, quotaResult.HasSpace) - - user.VirtualFolders[0].QuotaFiles = 10 - user.VirtualFolders[0].QuotaSize = 1048576 - err = dataprovider.UpdateUser(&user) - assert.NoError(t, err) - c.User = user - quotaResult = c.HasSpace(true, false, "/vdir/file1") - assert.True(t, quotaResult.HasSpace) - - quotaResult = c.HasSpace(true, false, "/file") - assert.True(t, quotaResult.HasSpace) - - folder, err := dataprovider.GetFolderByName(folderName) - assert.NoError(t, err) - err = dataprovider.UpdateVirtualFolderQuota(&folder, 10, 1048576, true) - assert.NoError(t, err) - quotaResult = c.HasSpace(true, false, "/vdir/file1") - assert.False(t, quotaResult.HasSpace) - - err = dataprovider.DeleteUser(user.Username) - assert.NoError(t, err) - - err = dataprovider.DeleteFolder(folder.Name) - assert.NoError(t, err) -} - -func TestUpdateQuotaMoveVFolders(t *testing.T) { - user := dataprovider.User{ - Username: userTestUsername, - HomeDir: filepath.Join(os.TempDir(), "home"), - Password: userTestPwd, - QuotaFiles: 100, - } - folderName1 := "testFolder1" - folderName2 := "testFolder2" - mappedPath1 := filepath.Join(os.TempDir(), "vdir1") - mappedPath2 := filepath.Join(os.TempDir(), "vdir2") - user.Permissions = make(map[string][]string) - user.Permissions["/"] = []string{dataprovider.PermAny} - user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: mappedPath1, - Name: folderName1, - }, - VirtualPath: "/vdir1", - QuotaFiles: -1, - QuotaSize: -1, - }) - user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: mappedPath2, - Name: folderName2, - }, - VirtualPath: "/vdir2", - QuotaFiles: -1, - QuotaSize: -1, - }) - err := dataprovider.AddUser(&user) - assert.NoError(t, err) - user, err = dataprovider.UserExists(user.Username) - assert.NoError(t, err) - folder1, err := dataprovider.GetFolderByName(folderName1) - assert.NoError(t, err) - folder2, err := dataprovider.GetFolderByName(folderName2) - assert.NoError(t, err) - err = dataprovider.UpdateVirtualFolderQuota(&folder1, 1, 100, true) - assert.NoError(t, err) - err = dataprovider.UpdateVirtualFolderQuota(&folder2, 2, 150, true) - assert.NoError(t, err) - fs, err := user.GetFilesystem("id") - assert.NoError(t, err) - c := NewBaseConnection("", ProtocolSFTP, user, fs) - c.updateQuotaMoveBetweenVFolders(&user.VirtualFolders[0], &user.VirtualFolders[1], -1, 100, 1) - folder1, err = dataprovider.GetFolderByName(folderName1) - assert.NoError(t, err) - assert.Equal(t, 0, folder1.UsedQuotaFiles) - assert.Equal(t, int64(0), folder1.UsedQuotaSize) - folder2, err = dataprovider.GetFolderByName(folderName2) - assert.NoError(t, err) - assert.Equal(t, 3, folder2.UsedQuotaFiles) - assert.Equal(t, int64(250), folder2.UsedQuotaSize) - - c.updateQuotaMoveBetweenVFolders(&user.VirtualFolders[1], &user.VirtualFolders[0], 10, 100, 1) - folder1, err = dataprovider.GetFolderByName(folderName1) - assert.NoError(t, err) - assert.Equal(t, 0, folder1.UsedQuotaFiles) - assert.Equal(t, int64(90), folder1.UsedQuotaSize) - folder2, err = dataprovider.GetFolderByName(folderName2) - assert.NoError(t, err) - assert.Equal(t, 2, folder2.UsedQuotaFiles) - assert.Equal(t, int64(150), folder2.UsedQuotaSize) - - err = dataprovider.UpdateUserQuota(&user, 1, 100, true) - assert.NoError(t, err) - c.updateQuotaMoveFromVFolder(&user.VirtualFolders[1], -1, 50, 1) - folder2, err = dataprovider.GetFolderByName(folderName2) - assert.NoError(t, err) - assert.Equal(t, 1, folder2.UsedQuotaFiles) - assert.Equal(t, int64(100), folder2.UsedQuotaSize) - user, err = dataprovider.UserExists(user.Username) - assert.NoError(t, err) - assert.Equal(t, 1, user.UsedQuotaFiles) - assert.Equal(t, int64(100), user.UsedQuotaSize) - - c.updateQuotaMoveToVFolder(&user.VirtualFolders[1], -1, 100, 1) - folder2, err = dataprovider.GetFolderByName(folderName2) - assert.NoError(t, err) - assert.Equal(t, 2, folder2.UsedQuotaFiles) - assert.Equal(t, int64(200), folder2.UsedQuotaSize) - user, err = dataprovider.UserExists(user.Username) - assert.NoError(t, err) - assert.Equal(t, 1, user.UsedQuotaFiles) - assert.Equal(t, int64(100), user.UsedQuotaSize) - - err = dataprovider.DeleteUser(user.Username) - assert.NoError(t, err) - err = dataprovider.DeleteFolder(folder1.Name) - assert.NoError(t, err) - err = dataprovider.DeleteFolder(folder2.Name) - assert.NoError(t, err) -} - func TestErrorsMapping(t *testing.T) { - fs := vfs.NewOsFs("", os.TempDir(), nil) - conn := NewBaseConnection("", ProtocolSFTP, dataprovider.User{}, fs) + fs := vfs.NewOsFs("", os.TempDir(), "") + conn := NewBaseConnection("", ProtocolSFTP, dataprovider.User{HomeDir: os.TempDir()}) for _, protocol := range supportedProtocols { conn.SetProtocol(protocol) - err := conn.GetFsError(os.ErrNotExist) + err := conn.GetFsError(fs, os.ErrNotExist) if protocol == ProtocolSFTP { assert.EqualError(t, err, sftp.ErrSSHFxNoSuchFile.Error()) } else if protocol == ProtocolWebDAV || protocol == ProtocolFTP { @@ -1174,37 +229,37 @@ func TestErrorsMapping(t *testing.T) { } else { assert.EqualError(t, err, ErrNotExist.Error()) } - err = conn.GetFsError(os.ErrPermission) + err = conn.GetFsError(fs, os.ErrPermission) if protocol == ProtocolSFTP { assert.EqualError(t, err, sftp.ErrSSHFxPermissionDenied.Error()) } else { assert.EqualError(t, err, ErrPermissionDenied.Error()) } - err = conn.GetFsError(os.ErrClosed) + err = conn.GetFsError(fs, os.ErrClosed) if protocol == ProtocolSFTP { assert.EqualError(t, err, sftp.ErrSSHFxFailure.Error()) } else { assert.EqualError(t, err, ErrGenericFailure.Error()) } - err = conn.GetFsError(ErrPermissionDenied) + err = conn.GetFsError(fs, ErrPermissionDenied) if protocol == ProtocolSFTP { assert.EqualError(t, err, sftp.ErrSSHFxFailure.Error()) } else { assert.EqualError(t, err, ErrPermissionDenied.Error()) } - err = conn.GetFsError(vfs.ErrVfsUnsupported) + err = conn.GetFsError(fs, vfs.ErrVfsUnsupported) if protocol == ProtocolSFTP { assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error()) } else { assert.EqualError(t, err, ErrOpUnsupported.Error()) } - err = conn.GetFsError(vfs.ErrStorageSizeUnavailable) + err = conn.GetFsError(fs, vfs.ErrStorageSizeUnavailable) if protocol == ProtocolSFTP { assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error()) } else { assert.EqualError(t, err, vfs.ErrStorageSizeUnavailable.Error()) } - err = conn.GetFsError(nil) + err = conn.GetFsError(fs, nil) assert.NoError(t, err) err = conn.GetOpUnsupportedError() if protocol == ProtocolSFTP { @@ -1225,42 +280,42 @@ func TestMaxWriteSize(t *testing.T) { } fs, err := user.GetFilesystem("123") assert.NoError(t, err) - conn := NewBaseConnection("", ProtocolFTP, user, fs) + conn := NewBaseConnection("", ProtocolFTP, user) quotaResult := vfs.QuotaCheckResult{ HasSpace: true, } - size, err := conn.GetMaxWriteSize(quotaResult, false, 0) + size, err := conn.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) assert.NoError(t, err) assert.Equal(t, int64(0), size) conn.User.Filters.MaxUploadFileSize = 100 - size, err = conn.GetMaxWriteSize(quotaResult, false, 0) + size, err = conn.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) assert.NoError(t, err) assert.Equal(t, int64(100), size) quotaResult.QuotaSize = 1000 - size, err = conn.GetMaxWriteSize(quotaResult, false, 50) + size, err = conn.GetMaxWriteSize(quotaResult, false, 50, fs.IsUploadResumeSupported()) assert.NoError(t, err) assert.Equal(t, int64(100), size) quotaResult.QuotaSize = 1000 quotaResult.UsedSize = 990 - size, err = conn.GetMaxWriteSize(quotaResult, false, 50) + size, err = conn.GetMaxWriteSize(quotaResult, false, 50, fs.IsUploadResumeSupported()) assert.NoError(t, err) assert.Equal(t, int64(60), size) quotaResult.QuotaSize = 0 quotaResult.UsedSize = 0 - size, err = conn.GetMaxWriteSize(quotaResult, true, 100) + size, err = conn.GetMaxWriteSize(quotaResult, true, 100, fs.IsUploadResumeSupported()) assert.EqualError(t, err, ErrQuotaExceeded.Error()) assert.Equal(t, int64(0), size) - size, err = conn.GetMaxWriteSize(quotaResult, true, 10) + size, err = conn.GetMaxWriteSize(quotaResult, true, 10, fs.IsUploadResumeSupported()) assert.NoError(t, err) assert.Equal(t, int64(90), size) - conn.Fs = newMockOsFs(true, fs.ConnectionID(), user.GetHomeDir()) - size, err = conn.GetMaxWriteSize(quotaResult, true, 100) + fs = newMockOsFs(true, fs.ConnectionID(), user.GetHomeDir()) + size, err = conn.GetMaxWriteSize(quotaResult, true, 100, fs.IsUploadResumeSupported()) assert.EqualError(t, err, ErrOpUnsupported.Error()) assert.Equal(t, int64(0), size) } diff --git a/common/protocol_test.go b/common/protocol_test.go new file mode 100644 index 00000000..61abf73b --- /dev/null +++ b/common/protocol_test.go @@ -0,0 +1,2415 @@ +package common_test + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" + "math" + "net" + "net/http" + "os" + "path" + "path/filepath" + "runtime" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + "github.com/pkg/sftp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "github.com/rs/zerolog" + + "github.com/drakkan/sftpgo/common" + "github.com/drakkan/sftpgo/config" + "github.com/drakkan/sftpgo/dataprovider" + "github.com/drakkan/sftpgo/httpclient" + "github.com/drakkan/sftpgo/httpdtest" + "github.com/drakkan/sftpgo/kms" + "github.com/drakkan/sftpgo/logger" + "github.com/drakkan/sftpgo/vfs" +) + +const ( + configDir = ".." + httpAddr = "127.0.0.1:9999" + httpProxyAddr = "127.0.0.1:7777" + sftpServerAddr = "127.0.0.1:4022" + defaultUsername = "test_common_sftp" + defaultPassword = "test_password" + defaultSFTPUsername = "test_common_sftpfs_user" + osWindows = "windows" + testFileName = "test_file_common_sftp.dat" + testDir = "test_dir_common" +) + +var ( + allPerms = []string{dataprovider.PermAny} + homeBasePath string + testFileContent = []byte("test data") +) + +func TestMain(m *testing.M) { + homeBasePath = os.TempDir() + logFilePath := filepath.Join(configDir, "common_test.log") + logger.InitLogger(logFilePath, 5, 1, 28, false, zerolog.DebugLevel) + + err := config.LoadConfig(configDir, "") + if err != nil { + logger.ErrorToConsole("error loading configuration: %v", err) + os.Exit(1) + } + providerConf := config.GetProviderConf() + logger.InfoToConsole("Starting COMMON tests, provider: %v", providerConf.Driver) + + err = common.Initialize(config.GetCommonConfig()) + if err != nil { + logger.WarnToConsole("error initializing common: %v", err) + os.Exit(1) + } + + err = dataprovider.Initialize(providerConf, configDir, true) + if err != nil { + logger.ErrorToConsole("error initializing data provider: %v", err) + os.Exit(1) + } + + httpConfig := config.GetHTTPConfig() + httpConfig.Timeout = 5 + httpConfig.RetryMax = 0 + httpConfig.Initialize(configDir) //nolint:errcheck + kmsConfig := config.GetKMSConfig() + err = kmsConfig.Initialize() + if err != nil { + logger.ErrorToConsole("error initializing kms: %v", err) + os.Exit(1) + } + + sftpdConf := config.GetSFTPDConfig() + sftpdConf.Bindings[0].Port = 4022 + + httpdConf := config.GetHTTPDConfig() + httpdConf.Bindings[0].Port = 4080 + httpdtest.SetBaseURL("http://127.0.0.1:4080") + + go func() { + if err := sftpdConf.Initialize(configDir); err != nil { + logger.ErrorToConsole("could not start SFTP server: %v", err) + os.Exit(1) + } + }() + + go func() { + if err := httpdConf.Initialize(configDir); err != nil { + logger.ErrorToConsole("could not start HTTP server: %v", err) + os.Exit(1) + } + }() + + go func() { + // start a test HTTP server to receive action notifications + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "OK\n") + }) + http.HandleFunc("/404", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, "Not found\n") + }) + if err := http.ListenAndServe(httpAddr, nil); err != nil { + logger.ErrorToConsole("could not start HTTP notification server: %v", err) + os.Exit(1) + } + }() + + go func() { + common.Config.ProxyProtocol = 2 + listener, err := net.Listen("tcp", httpProxyAddr) + if err != nil { + logger.ErrorToConsole("error creating listener for proxy protocol server: %v", err) + os.Exit(1) + } + proxyListener, err := common.Config.GetProxyListener(listener) + if err != nil { + logger.ErrorToConsole("error creating proxy protocol listener: %v", err) + os.Exit(1) + } + common.Config.ProxyProtocol = 0 + + s := &http.Server{} + if err := s.Serve(proxyListener); err != nil { + logger.ErrorToConsole("could not start HTTP proxy protocol server: %v", err) + os.Exit(1) + } + }() + + waitTCPListening(httpAddr) + waitTCPListening(httpProxyAddr) + + waitTCPListening(sftpdConf.Bindings[0].GetAddress()) + waitTCPListening(httpdConf.Bindings[0].GetAddress()) + + exitCode := m.Run() + os.Remove(logFilePath) + os.Exit(exitCode) +} + +func TestBaseConnection(t *testing.T) { + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + _, err = client.ReadDir(testDir) + assert.ErrorIs(t, err, os.ErrNotExist) + err = client.RemoveDirectory(testDir) + assert.ErrorIs(t, err, os.ErrNotExist) + err = client.Mkdir(testDir) + assert.NoError(t, err) + err = client.Mkdir(testDir) + assert.Error(t, err) + info, err := client.Stat(testDir) + if assert.NoError(t, err) { + assert.True(t, info.IsDir()) + } + err = client.RemoveDirectory(testDir) + assert.NoError(t, err) + err = client.Remove(testFileName) + assert.ErrorIs(t, err, os.ErrNotExist) + f, err := client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + linkName := testFileName + ".link" + err = client.Symlink(testFileName, linkName) + assert.NoError(t, err) + err = client.Symlink(testFileName, testFileName) + assert.Error(t, err) + info, err = client.Stat(testFileName) + if assert.NoError(t, err) { + assert.Equal(t, int64(len(testFileContent)), info.Size()) + assert.False(t, info.IsDir()) + } + info, err = client.Lstat(linkName) + if assert.NoError(t, err) { + assert.NotEqual(t, int64(7), info.Size()) + assert.True(t, info.Mode()&os.ModeSymlink != 0) + assert.False(t, info.IsDir()) + } + err = client.RemoveDirectory(linkName) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Failure") + } + err = client.Remove(testFileName) + assert.NoError(t, err) + err = client.Remove(linkName) + assert.NoError(t, err) + err = client.Rename(testFileName, "test") + assert.ErrorIs(t, err, os.ErrNotExist) + f, err = client.Create(testFileName) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + err = client.Rename(testFileName, testFileName+"1") + assert.NoError(t, err) + err = client.Remove(testFileName + "1") + assert.NoError(t, err) + err = client.RemoveDirectory("missing") + assert.Error(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSetStat(t *testing.T) { + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + f, err := client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + acmodTime := time.Now().Add(36 * time.Hour) + err = client.Chtimes(testFileName, acmodTime, acmodTime) + assert.NoError(t, err) + newFi, err := client.Lstat(testFileName) + assert.NoError(t, err) + diff := math.Abs(newFi.ModTime().Sub(acmodTime).Seconds()) + assert.LessOrEqual(t, diff, float64(1)) + if runtime.GOOS != osWindows { + err = client.Chown(testFileName, os.Getuid(), os.Getgid()) + assert.NoError(t, err) + } + newPerm := os.FileMode(0666) + err = client.Chmod(testFileName, newPerm) + assert.NoError(t, err) + newFi, err = client.Lstat(testFileName) + if assert.NoError(t, err) { + assert.Equal(t, newPerm, newFi.Mode().Perm()) + } + err = client.Truncate(testFileName, 2) + assert.NoError(t, err) + info, err := client.Stat(testFileName) + if assert.NoError(t, err) { + assert.Equal(t, int64(2), info.Size()) + } + err = client.Remove(testFileName) + assert.NoError(t, err) + + err = client.Truncate(testFileName, 0) + assert.ErrorIs(t, err, os.ErrNotExist) + err = client.Chtimes(testFileName, acmodTime, acmodTime) + assert.ErrorIs(t, err, os.ErrNotExist) + if runtime.GOOS != osWindows { + err = client.Chown(testFileName, os.Getuid(), os.Getgid()) + assert.ErrorIs(t, err, os.ErrNotExist) + } + err = client.Chmod(testFileName, newPerm) + assert.ErrorIs(t, err, os.ErrNotExist) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestPermissionErrors(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + u := getTestSFTPUser() + subDir := "/sub" + u.Permissions[subDir] = nil + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + err = client.MkdirAll(path.Join(subDir, subDir)) + assert.NoError(t, err) + f, err := client.Create(path.Join(subDir, subDir, testFileName)) + if assert.NoError(t, err) { + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + } + } + client, err = getSftpClient(sftpUser) + if assert.NoError(t, err) { + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + _, err = client.ReadDir(subDir) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Mkdir(path.Join(subDir, subDir)) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.RemoveDirectory(path.Join(subDir, subDir)) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Symlink("test", path.Join(subDir, subDir)) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Chmod(path.Join(subDir, subDir), os.ModePerm) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Chown(path.Join(subDir, subDir), os.Getuid(), os.Getgid()) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Chtimes(path.Join(subDir, subDir), time.Now(), time.Now()) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Truncate(path.Join(subDir, subDir), 0) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Remove(path.Join(subDir, subDir, testFileName)) + assert.ErrorIs(t, err, os.ErrPermission) + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestFileNotAllowedErrors(t *testing.T) { + deniedDir := "/denied" + u := getTestUser() + u.Filters.FilePatterns = []dataprovider.PatternsFilter{ + { + Path: deniedDir, + DeniedPatterns: []string{"*.txt"}, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + testFile := filepath.Join(u.GetHomeDir(), deniedDir, "file.txt") + err = os.MkdirAll(filepath.Join(u.GetHomeDir(), deniedDir), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(testFile, testFileContent, os.ModePerm) + assert.NoError(t, err) + err = client.Remove(path.Join(deniedDir, "file.txt")) + // the sftp client will try to remove the path as directory after receiving + // a permission denied error, so we get a generic failure here + assert.Error(t, err) + err = client.Rename(path.Join(deniedDir, "file.txt"), path.Join(deniedDir, "file1.txt")) + assert.ErrorIs(t, err, os.ErrPermission) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestTruncateQuotaLimits(t *testing.T) { + u := getTestUser() + u.QuotaSize = 20 + mappedPath1 := filepath.Join(os.TempDir(), "mapped1") + folderName1 := filepath.Base(mappedPath1) + vdirPath1 := "/vmapped1" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + }, + VirtualPath: vdirPath1, + QuotaFiles: 10, + }) + mappedPath2 := filepath.Join(os.TempDir(), "mapped2") + folderName2 := filepath.Base(mappedPath2) + vdirPath2 := "/vmapped2" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + }, + VirtualPath: vdirPath2, + QuotaFiles: -1, + QuotaSize: -1, + }) + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + u.QuotaSize = 20 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + f, err := client.OpenFile(testFileName, os.O_WRONLY) + if assert.NoError(t, err) { + n, err := f.Write(testFileContent) + assert.NoError(t, err) + assert.Equal(t, len(testFileContent), n) + err = f.Truncate(2) + assert.NoError(t, err) + expectedQuotaFiles := 0 + expectedQuotaSize := int64(2) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + _, err = f.Seek(expectedQuotaSize, io.SeekStart) + assert.NoError(t, err) + n, err = f.Write(testFileContent) + assert.NoError(t, err) + assert.Equal(t, len(testFileContent), n) + err = f.Truncate(5) + assert.NoError(t, err) + expectedQuotaSize = int64(5) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + _, err = f.Seek(expectedQuotaSize, io.SeekStart) + assert.NoError(t, err) + n, err = f.Write(testFileContent) + assert.NoError(t, err) + assert.Equal(t, len(testFileContent), n) + err = f.Close() + assert.NoError(t, err) + expectedQuotaFiles = 1 + expectedQuotaSize = int64(5) + int64(len(testFileContent)) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + } + // now truncate by path + err = client.Truncate(testFileName, 5) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(5), user.UsedQuotaSize) + // now open an existing file without truncate it, quota should not change + f, err = client.OpenFile(testFileName, os.O_WRONLY) + if assert.NoError(t, err) { + err = f.Close() + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(5), user.UsedQuotaSize) + } + // open the file truncating it + f, err = client.OpenFile(testFileName, os.O_WRONLY|os.O_TRUNC) + if assert.NoError(t, err) { + err = f.Close() + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) + } + // now test max write size + f, err = client.OpenFile(testFileName, os.O_WRONLY) + if assert.NoError(t, err) { + n, err := f.Write(testFileContent) + assert.NoError(t, err) + assert.Equal(t, len(testFileContent), n) + err = f.Truncate(11) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(11), user.UsedQuotaSize) + _, err = f.Seek(int64(11), io.SeekStart) + assert.NoError(t, err) + n, err = f.Write(testFileContent) + assert.NoError(t, err) + assert.Equal(t, len(testFileContent), n) + err = f.Truncate(5) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(5), user.UsedQuotaSize) + _, err = f.Seek(int64(5), io.SeekStart) + assert.NoError(t, err) + n, err = f.Write(testFileContent) + assert.NoError(t, err) + assert.Equal(t, len(testFileContent), n) + err = f.Truncate(12) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(12), user.UsedQuotaSize) + _, err = f.Seek(int64(12), io.SeekStart) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) + } + err = f.Close() + assert.Error(t, err) + // the file is deleted + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) + } + + if user.Username == defaultUsername { + // basic test inside a virtual folder + vfileName1 := path.Join(vdirPath1, testFileName) + f, err = client.OpenFile(vfileName1, os.O_WRONLY) + if assert.NoError(t, err) { + n, err := f.Write(testFileContent) + assert.NoError(t, err) + assert.Equal(t, len(testFileContent), n) + err = f.Truncate(2) + assert.NoError(t, err) + expectedQuotaFiles := 0 + expectedQuotaSize := int64(2) + fold, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize) + assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles) + err = f.Close() + assert.NoError(t, err) + expectedQuotaFiles = 1 + fold, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize) + assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles) + } + err = client.Truncate(vfileName1, 1) + assert.NoError(t, err) + fold, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(1), fold.UsedQuotaSize) + assert.Equal(t, 1, fold.UsedQuotaFiles) + // now test on vdirPath2, the folder quota is included in the user's quota + vfileName2 := path.Join(vdirPath2, testFileName) + f, err = client.OpenFile(vfileName2, os.O_WRONLY) + if assert.NoError(t, err) { + n, err := f.Write(testFileContent) + assert.NoError(t, err) + assert.Equal(t, len(testFileContent), n) + err = f.Truncate(3) + assert.NoError(t, err) + expectedQuotaFiles := 0 + expectedQuotaSize := int64(3) + fold, _, err := httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize) + assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles) + err = f.Close() + assert.NoError(t, err) + expectedQuotaFiles = 1 + fold, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize) + assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + } + + // cleanup + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + user.UsedQuotaFiles = 0 + user.UsedQuotaSize = 0 + _, err = httpdtest.UpdateQuotaUsage(user, "reset", http.StatusOK) + assert.NoError(t, err) + user.QuotaSize = 0 + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + } + } + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestVirtualFoldersQuotaRenameOverwrite(t *testing.T) { + testFileSize := int64(131072) + testFileSize1 := int64(65537) + testFileName1 := "test_file1.dat" //nolint:goconst + u := getTestUser() + u.QuotaFiles = 0 + u.QuotaSize = 0 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + folderName1 := filepath.Base(mappedPath1) + vdirPath1 := "/vdir1" //nolint:goconst + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + folderName2 := filepath.Base(mappedPath2) + vdirPath2 := "/vdir2" //nolint:goconst + mappedPath3 := filepath.Join(os.TempDir(), "vdir3") + folderName3 := filepath.Base(mappedPath3) + vdirPath3 := "/vdir3" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + }, + VirtualPath: vdirPath1, + QuotaFiles: 2, + QuotaSize: 0, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: mappedPath2, + Name: folderName2, + }, + VirtualPath: vdirPath2, + QuotaFiles: 0, + QuotaSize: testFileSize + testFileSize1 + 1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName3, + MappedPath: mappedPath3, + }, + VirtualPath: vdirPath3, + QuotaFiles: 2, + QuotaSize: testFileSize * 2, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + err = writeSFTPFile(path.Join(vdirPath1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(testFileName1, testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath3, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath3, testFileName+"1"), testFileSize, client) + assert.NoError(t, err) + err = client.Rename(testFileName, path.Join(vdirPath1, testFileName+".rename")) + assert.Error(t, err) + // we overwrite an existing file and we have unlimited size + err = client.Rename(testFileName, path.Join(vdirPath1, testFileName)) + assert.NoError(t, err) + // we have no space and we try to overwrite a bigger file with a smaller one, this should succeed + err = client.Rename(testFileName1, path.Join(vdirPath2, testFileName)) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + // we have no space and we try to overwrite a smaller file with a bigger one, this should fail + err = client.Rename(testFileName, path.Join(vdirPath2, testFileName1)) + assert.Error(t, err) + fi, err := client.Stat(path.Join(vdirPath1, testFileName1)) + if assert.NoError(t, err) { + assert.Equal(t, testFileSize1, fi.Size()) + } + // we are overquota inside vdir3 size 2/2 and size 262144/262144 + err = client.Rename(path.Join(vdirPath1, testFileName1), path.Join(vdirPath3, testFileName1+".rename")) + assert.Error(t, err) + // we overwrite an existing file and we have enough size + err = client.Rename(path.Join(vdirPath1, testFileName1), path.Join(vdirPath3, testFileName)) + assert.NoError(t, err) + testFileName2 := "test_file2.dat" + err = writeSFTPFile(testFileName2, testFileSize+testFileSize1, client) + assert.NoError(t, err) + // we overwrite an existing file and we haven't enough size + err = client.Rename(testFileName2, path.Join(vdirPath3, testFileName)) + assert.Error(t, err) + // now remove a file from vdir3, create a dir with 2 files and try to rename it in vdir3 + // this will fail since the rename will result in 3 files inside vdir3 and quota limits only + // allow 2 total files there + err = client.Remove(path.Join(vdirPath3, testFileName+"1")) + assert.NoError(t, err) + aDir := "a dir" + err = client.Mkdir(aDir) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(aDir, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(aDir, testFileName1+"1"), testFileSize1, client) + assert.NoError(t, err) + err = client.Rename(aDir, path.Join(vdirPath3, aDir)) + assert.Error(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName3}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath3) + assert.NoError(t, err) +} + +func TestQuotaRenameOverwrite(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 100 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + testFileSize := int64(131072) + testFileSize1 := int64(65537) + testFileName1 := "test_file1.dat" + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(testFileName1, testFileSize1, client) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + err = client.Rename(testFileName, testFileName1) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize, user.UsedQuotaSize) + err = client.Remove(testFileName1) + assert.NoError(t, err) + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(testFileName1, testFileSize1, client) + assert.NoError(t, err) + err = client.Rename(testFileName1, testFileName) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1, user.UsedQuotaSize) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestVirtualFoldersQuotaValues(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 100 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + vdirPath1 := "/vdir1" + folderName1 := filepath.Base(mappedPath1) + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + vdirPath2 := "/vdir2" + folderName2 := filepath.Base(mappedPath2) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + }, + VirtualPath: vdirPath1, + // quota is included in the user's one + QuotaFiles: -1, + QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + }, + VirtualPath: vdirPath2, + // quota is unlimited and excluded from user's one + QuotaFiles: 0, + QuotaSize: 0, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + testFileSize := int64(131072) + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + // we copy the same file two times to test quota update on file overwrite + err = writeSFTPFile(path.Join(vdirPath1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, testFileName), testFileSize, client) + assert.NoError(t, err) + expectedQuotaFiles := 2 + expectedQuotaSize := testFileSize * 2 + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + + err = client.Remove(path.Join(vdirPath1, testFileName)) + assert.NoError(t, err) + err = client.Remove(path.Join(vdirPath2, testFileName)) + assert.NoError(t, err) + + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 100 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + vdirPath1 := "/vdir1" + folderName1 := filepath.Base(mappedPath1) + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + vdirPath2 := "/vdir2" + folderName2 := filepath.Base(mappedPath2) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + }, + VirtualPath: vdirPath1, + // quota is included in the user's one + QuotaFiles: -1, + QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + }, + VirtualPath: vdirPath2, + // quota is unlimited and excluded from user's one + QuotaFiles: 0, + QuotaSize: 0, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + testFileName1 := "test_file1.dat" + testFileSize := int64(131072) + testFileSize1 := int64(65535) + dir1 := "dir1" //nolint:goconst + dir2 := "dir2" //nolint:goconst + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath1, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath1, dir2)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir2)) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, dir2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + // initial files: + // - vdir1/dir1/testFileName + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName + // - vdir2/dir2/testFileName1 + // + // rename a file inside vdir1 it is included inside user quota, so we have: + // - vdir1/dir1/testFileName.rename + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName + // - vdir2/dir2/testFileName1 + err = client.Rename(path.Join(vdirPath1, dir1, testFileName), path.Join(vdirPath1, dir1, testFileName+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + // rename a file inside vdir2, it isn't included inside user quota, so we have: + // - vdir1/dir1/testFileName.rename + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName.rename + // - vdir2/dir2/testFileName1 + err = client.Rename(path.Join(vdirPath2, dir1, testFileName), path.Join(vdirPath2, dir1, testFileName+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + // rename a file inside vdir2 overwriting an existing, we now have: + // - vdir1/dir1/testFileName.rename + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName.rename (initial testFileName1) + err = client.Rename(path.Join(vdirPath2, dir2, testFileName1), path.Join(vdirPath2, dir1, testFileName+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + // rename a file inside vdir1 overwriting an existing, we now have: + // - vdir1/dir1/testFileName.rename (initial testFileName1) + // - vdir2/dir1/testFileName.rename (initial testFileName1) + err = client.Rename(path.Join(vdirPath1, dir2, testFileName1), path.Join(vdirPath1, dir1, testFileName+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + // rename a directory inside the same virtual folder, quota should not change + err = client.RemoveDirectory(path.Join(vdirPath1, dir2)) + assert.NoError(t, err) + err = client.RemoveDirectory(path.Join(vdirPath2, dir2)) + assert.NoError(t, err) + err = client.Rename(path.Join(vdirPath1, dir1), path.Join(vdirPath1, dir2)) + assert.NoError(t, err) + err = client.Rename(path.Join(vdirPath2, dir1), path.Join(vdirPath2, dir2)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestQuotaRenameBetweenVirtualFolder(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 100 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + folderName1 := filepath.Base(mappedPath1) + vdirPath1 := "/vdir1" + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + folderName2 := filepath.Base(mappedPath2) + vdirPath2 := "/vdir2" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + }, + VirtualPath: vdirPath1, + // quota is included in the user's one + QuotaFiles: -1, + QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + }, + VirtualPath: vdirPath2, + // quota is unlimited and excluded from user's one + QuotaFiles: 0, + QuotaSize: 0, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + testFileName1 := "test_file1.dat" + testFileSize := int64(131072) + testFileSize1 := int64(65535) + dir1 := "dir1" + dir2 := "dir2" + err = client.Mkdir(path.Join(vdirPath1, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath1, dir2)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir2)) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, dir2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + // initial files: + // - vdir1/dir1/testFileName + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName + // - vdir2/dir2/testFileName1 + // + // rename a file from vdir1 to vdir2, vdir1 is included inside user quota, so we have: + // - vdir1/dir1/testFileName + // - vdir2/dir1/testFileName + // - vdir2/dir2/testFileName1 + // - vdir2/dir1/testFileName1.rename + err = client.Rename(path.Join(vdirPath1, dir2, testFileName1), path.Join(vdirPath2, dir1, testFileName1+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize, user.UsedQuotaSize) + f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize+testFileSize1+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 3, f.UsedQuotaFiles) + // rename a file from vdir2 to vdir1, vdir2 is not included inside user quota, so we have: + // - vdir1/dir1/testFileName + // - vdir1/dir2/testFileName.rename + // - vdir2/dir2/testFileName1 + // - vdir2/dir1/testFileName1.rename + err = client.Rename(path.Join(vdirPath2, dir1, testFileName), path.Join(vdirPath1, dir2, testFileName+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize*2, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize*2, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1*2, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + // rename a file from vdir1 to vdir2 overwriting an existing file, vdir1 is included inside user quota, so we have: + // - vdir1/dir2/testFileName.rename + // - vdir2/dir2/testFileName1 (is the initial testFileName) + // - vdir2/dir1/testFileName1.rename + err = client.Rename(path.Join(vdirPath1, dir1, testFileName), path.Join(vdirPath2, dir2, testFileName1)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1+testFileSize, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + // rename a file from vdir2 to vdir1 overwriting an existing file, vdir2 is not included inside user quota, so we have: + // - vdir1/dir2/testFileName.rename (is the initial testFileName1) + // - vdir2/dir2/testFileName1 (is the initial testFileName) + err = client.Rename(path.Join(vdirPath2, dir1, testFileName1+".rename"), path.Join(vdirPath1, dir2, testFileName+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + + err = writeSFTPFile(path.Join(vdirPath1, dir2, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir2, testFileName), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir2, testFileName+"1.dupl"), testFileSize1, client) + assert.NoError(t, err) + err = client.RemoveDirectory(path.Join(vdirPath1, dir1)) + assert.NoError(t, err) + err = client.RemoveDirectory(path.Join(vdirPath2, dir1)) + assert.NoError(t, err) + // - vdir1/dir2/testFileName.rename (initial testFileName1) + // - vdir1/dir2/testFileName + // - vdir2/dir2/testFileName1 (initial testFileName) + // - vdir2/dir2/testFileName (initial testFileName1) + // - vdir2/dir2/testFileName1.dupl + // rename directories between the two virtual folders + err = client.Rename(path.Join(vdirPath2, dir2), path.Join(vdirPath1, dir1)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 5, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1*3+testFileSize*2, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1*3+testFileSize*2, f.UsedQuotaSize) + assert.Equal(t, 5, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + // now move on vpath2 + err = client.Rename(path.Join(vdirPath1, dir2), path.Join(vdirPath2, dir1)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 3, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1*2+testFileSize, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1*2+testFileSize, f.UsedQuotaSize) + assert.Equal(t, 3, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestQuotaRenameFromVirtualFolder(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 100 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + folderName1 := filepath.Base(mappedPath1) + vdirPath1 := "/vdir1" + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + folderName2 := filepath.Base(mappedPath2) + vdirPath2 := "/vdir2" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + }, + VirtualPath: vdirPath1, + // quota is included in the user's one + QuotaFiles: -1, + QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + }, + VirtualPath: vdirPath2, + // quota is unlimited and excluded from user's one + QuotaFiles: 0, + QuotaSize: 0, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + testFileName1 := "test_file1.dat" + testFileSize := int64(131072) + testFileSize1 := int64(65535) + dir1 := "dir1" + dir2 := "dir2" + err = client.Mkdir(path.Join(vdirPath1, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath1, dir2)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir2)) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, dir2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + // initial files: + // - vdir1/dir1/testFileName + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName + // - vdir2/dir2/testFileName1 + // + // rename a file from vdir1 to the user home dir, vdir1 is included in user quota so we have: + // - testFileName + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName + // - vdir2/dir2/testFileName1 + err = client.Rename(path.Join(vdirPath1, dir1, testFileName), path.Join(testFileName)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + // rename a file from vdir2 to the user home dir, vdir2 is not included in user quota so we have: + // - testFileName + // - testFileName1 + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName + err = client.Rename(path.Join(vdirPath2, dir2, testFileName1), path.Join(testFileName1)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 3, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + // rename a file from vdir1 to the user home dir overwriting an existing file, vdir1 is included in user quota so we have: + // - testFileName (initial testFileName1) + // - testFileName1 + // - vdir2/dir1/testFileName + err = client.Rename(path.Join(vdirPath1, dir2, testFileName1), path.Join(testFileName)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + // rename a file from vdir2 to the user home dir overwriting an existing file, vdir2 is not included in user quota so we have: + // - testFileName (initial testFileName1) + // - testFileName1 (initial testFileName) + err = client.Rename(path.Join(vdirPath2, dir1, testFileName), path.Join(testFileName1)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + // dir rename + err = writeSFTPFile(path.Join(vdirPath1, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, dir1, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir1, testFileName1), testFileSize1, client) + assert.NoError(t, err) + // - testFileName (initial testFileName1) + // - testFileName1 (initial testFileName) + // - vdir1/dir1/testFileName + // - vdir1/dir1/testFileName1 + // - dir1/testFileName + // - dir1/testFileName1 + err = client.Rename(path.Join(vdirPath2, dir1), dir1) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 6, user.UsedQuotaFiles) + assert.Equal(t, testFileSize*3+testFileSize1*3, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + // - testFileName (initial testFileName1) + // - testFileName1 (initial testFileName) + // - dir2/testFileName + // - dir2/testFileName1 + // - dir1/testFileName + // - dir1/testFileName1 + err = client.Rename(path.Join(vdirPath1, dir1), dir2) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 6, user.UsedQuotaFiles) + assert.Equal(t, testFileSize*3+testFileSize1*3, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestQuotaRenameToVirtualFolder(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 100 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + folderName1 := filepath.Base(mappedPath1) + vdirPath1 := "/vdir1" + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + folderName2 := filepath.Base(mappedPath2) + vdirPath2 := "/vdir2" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + }, + VirtualPath: vdirPath1, + // quota is included in the user's one + QuotaFiles: -1, + QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + }, + VirtualPath: vdirPath2, + // quota is unlimited and excluded from user's one + QuotaFiles: 0, + QuotaSize: 0, + }) + u.Permissions[vdirPath1] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, + dataprovider.PermOverwrite, dataprovider.PermDelete, dataprovider.PermCreateSymlinks, dataprovider.PermCreateDirs} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + testFileName1 := "test_file1.dat" + testFileSize := int64(131072) + testFileSize1 := int64(65535) + dir1 := "dir1" + dir2 := "dir2" + err = client.Mkdir(path.Join(vdirPath1, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath1, dir2)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir2)) + assert.NoError(t, err) + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(testFileName1, testFileSize1, client) + assert.NoError(t, err) + // initial files: + // - testFileName + // - testFileName1 + // + // rename a file from user home dir to vdir1, vdir1 is included in user quota so we have: + // - testFileName + // - /vdir1/dir1/testFileName1 + err = client.Rename(testFileName1, path.Join(vdirPath1, dir1, testFileName1)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + // rename a file from user home dir to vdir2, vdir2 is not included in user quota so we have: + // - /vdir2/dir1/testFileName + // - /vdir1/dir1/testFileName1 + err = client.Rename(testFileName, path.Join(vdirPath2, dir1, testFileName)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + // upload two new files to the user home dir so we have: + // - testFileName + // - testFileName1 + // - /vdir1/dir1/testFileName1 + // - /vdir2/dir1/testFileName + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(testFileName1, testFileSize1, client) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 3, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1+testFileSize1, user.UsedQuotaSize) + // rename a file from user home dir to vdir1 overwriting an existing file, vdir1 is included in user quota so we have: + // - testFileName1 + // - /vdir1/dir1/testFileName1 (initial testFileName) + // - /vdir2/dir1/testFileName + err = client.Rename(testFileName, path.Join(vdirPath1, dir1, testFileName1)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + // rename a file from user home dir to vdir2 overwriting an existing file, vdir2 is not included in user quota so we have: + // - /vdir1/dir1/testFileName1 (initial testFileName) + // - /vdir2/dir1/testFileName (initial testFileName1) + err = client.Rename(testFileName1, path.Join(vdirPath2, dir1, testFileName)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + + err = client.Mkdir(dir1) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(dir1, testFileName1), testFileSize1, client) + assert.NoError(t, err) + // - /dir1/testFileName + // - /dir1/testFileName1 + // - /vdir1/dir1/testFileName1 (initial testFileName) + // - /vdir2/dir1/testFileName (initial testFileName1) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 3, user.UsedQuotaFiles) + assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + // - /vdir1/adir/testFileName + // - /vdir1/adir/testFileName1 + // - /vdir1/dir1/testFileName1 (initial testFileName) + // - /vdir2/dir1/testFileName (initial testFileName1) + err = client.Rename(dir1, path.Join(vdirPath1, "adir")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 3, user.UsedQuotaFiles) + assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize*2+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 3, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + err = client.Mkdir(dir1) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(dir1, testFileName1), testFileSize1, client) + assert.NoError(t, err) + // - /vdir1/adir/testFileName + // - /vdir1/adir/testFileName1 + // - /vdir1/dir1/testFileName1 (initial testFileName) + // - /vdir2/dir1/testFileName (initial testFileName1) + // - /vdir2/adir/testFileName + // - /vdir2/adir/testFileName1 + err = client.Rename(dir1, path.Join(vdirPath2, "adir")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 3, user.UsedQuotaFiles) + assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize*2+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 3, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1*2+testFileSize, f.UsedQuotaSize) + assert.Equal(t, 3, f.UsedQuotaFiles) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestVirtualFoldersLink(t *testing.T) { + u := getTestUser() + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + folderName1 := filepath.Base(mappedPath1) + vdirPath1 := "/vdir1" + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + folderName2 := filepath.Base(mappedPath2) + vdirPath2 := "/vdir2" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + }, + VirtualPath: vdirPath1, + // quota is included in the user's one + QuotaFiles: -1, + QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + }, + VirtualPath: vdirPath2, + // quota is unlimited and excluded from user's one + QuotaFiles: 0, + QuotaSize: 0, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + testFileSize := int64(131072) + testDir := "adir" + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, testFileName), testFileSize, client) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath1, testDir)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, testDir)) + assert.NoError(t, err) + err = client.Symlink(testFileName, testFileName+".link") + assert.NoError(t, err) + err = client.Symlink(path.Join(vdirPath1, testFileName), path.Join(vdirPath1, testFileName+".link")) + assert.NoError(t, err) + err = client.Symlink(path.Join(vdirPath1, testFileName), path.Join(vdirPath1, testDir, testFileName+".link")) + assert.NoError(t, err) + err = client.Symlink(path.Join(vdirPath2, testFileName), path.Join(vdirPath2, testFileName+".link")) + assert.NoError(t, err) + err = client.Symlink(path.Join(vdirPath2, testFileName), path.Join(vdirPath2, testDir, testFileName+".link")) + assert.NoError(t, err) + err = client.Symlink(testFileName, path.Join(vdirPath1, testFileName+".link1")) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Operation Unsupported") + } + err = client.Symlink(testFileName, path.Join(vdirPath1, testDir, testFileName+".link1")) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Operation Unsupported") + } + err = client.Symlink(testFileName, path.Join(vdirPath2, testFileName+".link1")) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Operation Unsupported") + } + err = client.Symlink(testFileName, path.Join(vdirPath2, testDir, testFileName+".link1")) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Operation Unsupported") + } + err = client.Symlink(path.Join(vdirPath1, testFileName), testFileName+".link1") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Operation Unsupported") + } + err = client.Symlink(path.Join(vdirPath2, testFileName), testFileName+".link1") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Operation Unsupported") + } + err = client.Symlink(path.Join(vdirPath1, testFileName), path.Join(vdirPath2, testDir, testFileName+".link1")) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Operation Unsupported") + } + err = client.Symlink(path.Join(vdirPath2, testFileName), path.Join(vdirPath1, testFileName+".link1")) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Operation Unsupported") + } + err = client.Symlink("/", "/roolink") + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Symlink(testFileName, "/") + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Symlink(testFileName, vdirPath1) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Operation Unsupported") + } + err = client.Symlink(vdirPath1, testFileName+".link2") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Operation Unsupported") + } + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestDirs(t *testing.T) { + u := getTestUser() + mappedPath := filepath.Join(os.TempDir(), "vdir") + folderName := filepath.Base(mappedPath) + vdirPath := "/path/vdir" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + }, + VirtualPath: vdirPath, + }) + u.Permissions["/subdir"] = []string{dataprovider.PermDownload, dataprovider.PermUpload, + dataprovider.PermDelete, dataprovider.PermCreateDirs, dataprovider.PermRename, dataprovider.PermListItems} + + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + info, err := client.ReadDir("/") + if assert.NoError(t, err) { + assert.Len(t, info, 1) + assert.Equal(t, "path", info[0].Name()) + } + fi, err := client.Stat(path.Dir(vdirPath)) + if assert.NoError(t, err) { + assert.True(t, fi.IsDir()) + } + err = client.RemoveDirectory("/") + assert.ErrorIs(t, err, os.ErrPermission) + err = client.RemoveDirectory(vdirPath) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.RemoveDirectory(path.Dir(vdirPath)) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Operation Unsupported") + } + err = client.Mkdir(vdirPath) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Mkdir("adir") + assert.NoError(t, err) + err = client.Rename("/adir", path.Dir(vdirPath)) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Operation Unsupported") + } + err = client.MkdirAll("/subdir/adir") + assert.NoError(t, err) + err = client.Rename("adir", "subdir/adir") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Operation Unsupported") + } + err = writeSFTPFile("/subdir/afile.bin", 64, client) + assert.NoError(t, err) + err = writeSFTPFile("/afile.bin", 32, client) + assert.NoError(t, err) + err = client.Rename("afile.bin", "subdir/afile.bin") + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename("afile.bin", "subdir/afile1.bin") + assert.NoError(t, err) + err = client.Rename(path.Dir(vdirPath), "renamed_vdir") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Operation Unsupported") + } + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestCryptFsStat(t *testing.T) { + user, _, err := httpdtest.AddUser(getCryptFsUser(), http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + testFileSize := int64(4096) + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + info, err := client.Stat(testFileName) + if assert.NoError(t, err) { + assert.Equal(t, testFileSize, info.Size()) + } + info, err = os.Stat(filepath.Join(user.HomeDir, testFileName)) + if assert.NoError(t, err) { + assert.Greater(t, info.Size(), testFileSize) + } + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestFsPermissionErrors(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + user, _, err := httpdtest.AddUser(getCryptFsUser(), http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + testDir := "tDir" + err = client.Mkdir(testDir) + assert.NoError(t, err) + err = os.Chmod(user.GetHomeDir(), 0111) + assert.NoError(t, err) + + err = client.RemoveDirectory(testDir) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(testDir, testDir+"1") + assert.ErrorIs(t, err, os.ErrPermission) + + err = os.Chmod(user.GetHomeDir(), os.ModePerm) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestResolvePathError(t *testing.T) { + u := getTestUser() + u.HomeDir = "relative_path" + conn := common.NewBaseConnection("", common.ProtocolFTP, u) + testPath := "apath" + _, err := conn.ListDir(testPath) + assert.Error(t, err) + err = conn.CreateDir(testPath) + assert.Error(t, err) + err = conn.RemoveDir(testPath) + assert.Error(t, err) + err = conn.Rename(testPath, testPath+"1") + assert.Error(t, err) + err = conn.CreateSymlink(testPath, testPath+".sym") + assert.Error(t, err) + _, err = conn.DoStat(testPath, 0) + assert.Error(t, err) + err = conn.SetStat(testPath, &common.StatAttributes{ + Atime: time.Now(), + Mtime: time.Now(), + }) + assert.Error(t, err) + + u = getTestUser() + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: "relative_mapped_path", + }, + VirtualPath: "/vpath", + }) + err = os.MkdirAll(u.HomeDir, os.ModePerm) + assert.NoError(t, err) + conn.User = u + err = conn.Rename(testPath, "/vpath/subpath") + assert.Error(t, err) + + outHomePath := filepath.Join(os.TempDir(), testFileName) + err = os.WriteFile(outHomePath, testFileContent, os.ModePerm) + assert.NoError(t, err) + err = os.Symlink(outHomePath, filepath.Join(u.HomeDir, testFileName+".link")) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(u.HomeDir, testFileName), testFileContent, os.ModePerm) + assert.NoError(t, err) + err = conn.CreateSymlink(testFileName, testFileName+".link") + assert.Error(t, err) + + err = os.RemoveAll(u.GetHomeDir()) + assert.NoError(t, err) + err = os.Remove(outHomePath) + assert.NoError(t, err) +} + +func TestQuotaTrackDisabled(t *testing.T) { + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + providerConf.TrackQuota = 0 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + err = writeSFTPFile(testFileName, 32, client) + assert.NoError(t, err) + err = client.Rename(testFileName, testFileName+"1") + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + +func TestGetQuotaError(t *testing.T) { + u := getTestUser() + mappedPath := filepath.Join(os.TempDir(), "vdir") + folderName := filepath.Base(mappedPath) + vdirPath := "/vpath" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + }, + VirtualPath: vdirPath, + QuotaSize: 0, + QuotaFiles: 10, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + err = writeSFTPFile(testFileName, 32, client) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + + err = client.Rename(testFileName, path.Join(vdirPath, testFileName)) + assert.Error(t, err) + + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestRenameDir(t *testing.T) { + u := getTestUser() + testDir := "/dir-to-rename" + u.Permissions[testDir] = []string{dataprovider.PermListItems, dataprovider.PermUpload} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + err = client.Mkdir(testDir) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(testDir, testFileName), 32, client) + assert.NoError(t, err) + err = client.Rename(testDir, testDir+"_rename") + assert.ErrorIs(t, err, os.ErrPermission) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestRenameSymlink(t *testing.T) { + u := getTestUser() + testDir := "/dir-no-create-links" + otherDir := "otherdir" + u.Permissions[testDir] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermDelete, + dataprovider.PermCreateDirs} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + err = client.Mkdir(otherDir) + assert.NoError(t, err) + err = client.Symlink(otherDir, otherDir+".link") + assert.NoError(t, err) + err = client.Rename(otherDir+".link", path.Join(testDir, "symlink")) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(otherDir+".link", "allowed_link") + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestNonLocalCrossRename(t *testing.T) { + baseUser, resp, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err, string(resp)) + u := getTestUser() + u.HomeDir += "_folders" + u.Username += "_folders" + mappedPathSFTP := filepath.Join(os.TempDir(), "sftp") + folderNameSFTP := filepath.Base(mappedPathSFTP) + vdirSFTPPath := "/vdir/sftp" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameSFTP, + FsConfig: vfs.Filesystem{ + Provider: vfs.SFTPFilesystemProvider, + SFTPConfig: vfs.SFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: baseUser.Username, + Password: kms.NewPlainSecret(defaultPassword), + }, + }, + }, + VirtualPath: vdirSFTPPath, + }) + mappedPathCrypt := filepath.Join(os.TempDir(), "crypt") + folderNameCrypt := filepath.Base(mappedPathCrypt) + vdirCryptPath := "/vdir/crypt" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameCrypt, + FsConfig: vfs.Filesystem{ + Provider: vfs.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret(defaultPassword), + }, + }, + MappedPath: mappedPathCrypt, + }, + VirtualPath: vdirCryptPath, + }) + user, resp, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err, string(resp)) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + err = writeSFTPFile(testFileName, 4096, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirSFTPPath, testFileName), 8192, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirCryptPath, testFileName), 16384, client) + assert.NoError(t, err) + err = client.Rename(path.Join(vdirSFTPPath, testFileName), path.Join(vdirCryptPath, testFileName+".rename")) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(path.Join(vdirCryptPath, testFileName), path.Join(vdirSFTPPath, testFileName+".rename")) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(testFileName, path.Join(vdirCryptPath, testFileName+".rename")) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(testFileName, path.Join(vdirSFTPPath, testFileName+".rename")) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(path.Join(vdirSFTPPath, testFileName), testFileName+".rename") + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(path.Join(vdirCryptPath, testFileName), testFileName+".rename") + assert.ErrorIs(t, err, os.ErrPermission) + // rename on local fs or on the same folder must work + err = client.Rename(testFileName, testFileName+".rename") + assert.NoError(t, err) + err = client.Rename(path.Join(vdirSFTPPath, testFileName), path.Join(vdirSFTPPath, testFileName+"_rename")) + assert.NoError(t, err) + err = client.Rename(path.Join(vdirCryptPath, testFileName), path.Join(vdirCryptPath, testFileName+"_rename")) + assert.NoError(t, err) + // renaming a virtual folder is not allowed + err = client.Rename(vdirSFTPPath, vdirSFTPPath+"_rename") + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(vdirCryptPath, vdirCryptPath+"_rename") + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(vdirCryptPath, path.Join(vdirCryptPath, "rename")) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Mkdir(path.Join(vdirCryptPath, "subcryptdir")) + assert.NoError(t, err) + err = client.Rename(path.Join(vdirCryptPath, "subcryptdir"), vdirCryptPath) + assert.ErrorIs(t, err, os.ErrPermission) + // renaming root folder is not allowed + err = client.Rename("/", "new_name") + assert.ErrorIs(t, err, os.ErrPermission) + // renaming a path to a virtual folder is not allowed + err = client.Rename("/vdir", "new_vdir") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Operation Unsupported") + } + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameSFTP}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(baseUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(baseUser.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathCrypt) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathSFTP) + assert.NoError(t, err) +} + +func TestNonLocalCrossRenameNonLocalBaseUser(t *testing.T) { + baseUser, resp, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err, string(resp)) + u := getTestSFTPUser() + mappedPathLocal := filepath.Join(os.TempDir(), "local") + folderNameLocal := filepath.Base(mappedPathLocal) + vdirLocalPath := "/vdir/local" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameLocal, + MappedPath: mappedPathLocal, + }, + VirtualPath: vdirLocalPath, + }) + mappedPathCrypt := filepath.Join(os.TempDir(), "crypt") + folderNameCrypt := filepath.Base(mappedPathCrypt) + vdirCryptPath := "/vdir/crypt" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameCrypt, + FsConfig: vfs.Filesystem{ + Provider: vfs.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret(defaultPassword), + }, + }, + MappedPath: mappedPathCrypt, + }, + VirtualPath: vdirCryptPath, + }) + user, resp, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err, string(resp)) + client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + err = writeSFTPFile(testFileName, 4096, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirLocalPath, testFileName), 8192, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirCryptPath, testFileName), 16384, client) + assert.NoError(t, err) + err = client.Rename(path.Join(vdirLocalPath, testFileName), path.Join(vdirCryptPath, testFileName+".rename")) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(path.Join(vdirCryptPath, testFileName), path.Join(vdirLocalPath, testFileName+".rename")) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(testFileName, path.Join(vdirCryptPath, testFileName+".rename")) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(testFileName, path.Join(vdirLocalPath, testFileName+".rename")) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(path.Join(vdirLocalPath, testFileName), testFileName+".rename") + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(path.Join(vdirCryptPath, testFileName), testFileName+".rename") + assert.ErrorIs(t, err, os.ErrPermission) + // rename on local fs or on the same folder must work + err = client.Rename(testFileName, testFileName+".rename") + assert.NoError(t, err) + err = client.Rename(path.Join(vdirLocalPath, testFileName), path.Join(vdirLocalPath, testFileName+"_rename")) + assert.NoError(t, err) + err = client.Rename(path.Join(vdirCryptPath, testFileName), path.Join(vdirCryptPath, testFileName+"_rename")) + assert.NoError(t, err) + // renaming a virtual folder is not allowed + err = client.Rename(vdirLocalPath, vdirLocalPath+"_rename") + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(vdirCryptPath, vdirCryptPath+"_rename") + assert.ErrorIs(t, err, os.ErrPermission) + // renaming root folder is not allowed + err = client.Rename("/", "new_name") + assert.ErrorIs(t, err, os.ErrPermission) + // renaming a path to a virtual folder is not allowed + err = client.Rename("/vdir", "new_vdir") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "Operation Unsupported") + } + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameLocal}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(baseUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(baseUser.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathCrypt) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathLocal) + assert.NoError(t, err) +} + +func TestProxyProtocol(t *testing.T) { + httpClient := httpclient.GetHTTPClient() + resp, err := httpClient.Get(fmt.Sprintf("http://%v", httpProxyAddr)) + if assert.NoError(t, err) { + defer resp.Body.Close() + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + } +} + +func TestSetProtocol(t *testing.T) { + conn := common.NewBaseConnection("id", "sshd_exec", dataprovider.User{HomeDir: os.TempDir()}) + conn.SetProtocol(common.ProtocolSCP) + require.Equal(t, "SCP_id", conn.GetID()) +} + +func TestGetFsError(t *testing.T) { + u := getTestUser() + u.FsConfig.Provider = vfs.GCSFilesystemProvider + u.FsConfig.GCSConfig.Bucket = "test" + u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials") + conn := common.NewBaseConnection("", common.ProtocolFTP, u) + _, _, err := conn.GetFsAndResolvedPath("/vpath") + assert.Error(t, err) +} + +func waitTCPListening(address string) { + for { + conn, err := net.Dial("tcp", address) + if err != nil { + logger.WarnToConsole("tcp server %v not listening: %v", address, err) + time.Sleep(100 * time.Millisecond) + continue + } + logger.InfoToConsole("tcp server %v now listening", address) + conn.Close() + break + } +} + +func checkBasicSFTP(client *sftp.Client) error { + _, err := client.Getwd() + if err != nil { + return err + } + _, err = client.ReadDir(".") + return err +} + +func getSftpClient(user dataprovider.User) (*sftp.Client, error) { + var sftpClient *sftp.Client + config := &ssh.ClientConfig{ + User: user.Username, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + if user.Password != "" { + config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)} + } else { + config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)} + } + + conn, err := ssh.Dial("tcp", sftpServerAddr, config) + if err != nil { + return sftpClient, err + } + sftpClient, err = sftp.NewClient(conn) + return sftpClient, err +} + +func getTestUser() dataprovider.User { + user := dataprovider.User{ + Username: defaultUsername, + Password: defaultPassword, + HomeDir: filepath.Join(homeBasePath, defaultUsername), + Status: 1, + ExpirationDate: 0, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = allPerms + return user +} + +func getTestSFTPUser() dataprovider.User { + u := getTestUser() + u.Username = defaultSFTPUsername + u.FsConfig.Provider = vfs.SFTPFilesystemProvider + u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr + u.FsConfig.SFTPConfig.Username = defaultUsername + u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword) + return u +} + +func getCryptFsUser() dataprovider.User { + u := getTestUser() + u.FsConfig.Provider = vfs.CryptedFilesystemProvider + u.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(defaultPassword) + return u +} + +func writeSFTPFile(name string, size int64, client *sftp.Client) error { + content := make([]byte, size) + _, err := rand.Read(content) + if err != nil { + return err + } + f, err := client.Create(name) + if err != nil { + return err + } + _, err = io.Copy(f, bytes.NewBuffer(content)) + if err != nil { + f.Close() + return err + } + err = f.Close() + if err != nil { + return err + } + info, err := client.Stat(name) + if err != nil { + return err + } + if info.Size() != size { + return fmt.Errorf("file size mismatch, wanted %v, actual %v", size, info.Size()) + } + return nil +} diff --git a/common/transfer.go b/common/transfer.go index 19f50906..4dc31b9b 100644 --- a/common/transfer.go +++ b/common/transfer.go @@ -180,7 +180,7 @@ func (t *BaseTransfer) getUploadFileSize() (int64, error) { fileSize = info.Size() } if vfs.IsCryptOsFs(t.Fs) && t.ErrTransfer != nil { - errDelete := t.Connection.Fs.Remove(t.fsPath, false) + errDelete := t.Fs.Remove(t.fsPath, false) if errDelete != nil { t.Connection.Log(logger.LevelWarn, "error removing partial crypto file %#v: %v", t.fsPath, errDelete) } @@ -204,7 +204,7 @@ func (t *BaseTransfer) Close() error { metrics.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.transferType, t.ErrTransfer) if t.ErrTransfer == ErrQuotaExceeded && t.File != nil { // if quota is exceeded we try to remove the partial file for uploads to local filesystem - err = t.Connection.Fs.Remove(t.File.Name(), false) + err = t.Fs.Remove(t.File.Name(), false) if err == nil { numFiles-- atomic.StoreInt64(&t.BytesReceived, 0) @@ -214,11 +214,11 @@ func (t *BaseTransfer) Close() error { t.File.Name(), err) } else if t.transferType == TransferUpload && t.File != nil && t.File.Name() != t.fsPath { if t.ErrTransfer == nil || Config.UploadMode == UploadModeAtomicWithResume { - err = t.Connection.Fs.Rename(t.File.Name(), t.fsPath) + err = t.Fs.Rename(t.File.Name(), t.fsPath) t.Connection.Log(logger.LevelDebug, "atomic upload completed, rename: %#v -> %#v, error: %v", t.File.Name(), t.fsPath, err) } else { - err = t.Connection.Fs.Remove(t.File.Name(), false) + err = t.Fs.Remove(t.File.Name(), false) t.Connection.Log(logger.LevelWarn, "atomic upload completed with error: \"%v\", delete temporary file: %#v, "+ "deletion error: %v", t.ErrTransfer, t.File.Name(), err) if err == nil { diff --git a/common/transfer_test.go b/common/transfer_test.go index 13dcdc39..d4ecde51 100644 --- a/common/transfer_test.go +++ b/common/transfer_test.go @@ -16,12 +16,12 @@ import ( ) func TestTransferUpdateQuota(t *testing.T) { - conn := NewBaseConnection("", ProtocolSFTP, dataprovider.User{}, nil) + conn := NewBaseConnection("", ProtocolSFTP, dataprovider.User{}) transfer := BaseTransfer{ Connection: conn, transferType: TransferUpload, BytesReceived: 123, - Fs: vfs.NewOsFs("", os.TempDir(), nil), + Fs: vfs.NewOsFs("", os.TempDir(), ""), } errFake := errors.New("fake error") transfer.TransferError(errFake) @@ -54,14 +54,14 @@ func TestTransferThrottling(t *testing.T) { UploadBandwidth: 50, DownloadBandwidth: 40, } - fs := vfs.NewOsFs("", os.TempDir(), nil) + fs := vfs.NewOsFs("", os.TempDir(), "") testFileSize := int64(131072) wantedUploadElapsed := 1000 * (testFileSize / 1024) / u.UploadBandwidth wantedDownloadElapsed := 1000 * (testFileSize / 1024) / u.DownloadBandwidth // some tolerance wantedUploadElapsed -= wantedDownloadElapsed / 10 wantedDownloadElapsed -= wantedDownloadElapsed / 10 - conn := NewBaseConnection("id", ProtocolSCP, u, nil) + conn := NewBaseConnection("id", ProtocolSCP, u) transfer := NewBaseTransfer(nil, conn, nil, "", "", TransferUpload, 0, 0, 0, true, fs) transfer.BytesReceived = testFileSize transfer.Connection.UpdateLastActivity() @@ -86,7 +86,7 @@ func TestTransferThrottling(t *testing.T) { func TestRealPath(t *testing.T) { testFile := filepath.Join(os.TempDir(), "afile.txt") - fs := vfs.NewOsFs("123", os.TempDir(), nil) + fs := vfs.NewOsFs("123", os.TempDir(), "") u := dataprovider.User{ Username: "user", HomeDir: os.TempDir(), @@ -95,7 +95,7 @@ func TestRealPath(t *testing.T) { u.Permissions["/"] = []string{dataprovider.PermAny} file, err := os.Create(testFile) require.NoError(t, err) - conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, u, fs) + conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, u) transfer := NewBaseTransfer(file, conn, nil, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs) rPath := transfer.GetRealFsPath(testFile) assert.Equal(t, testFile, rPath) @@ -117,7 +117,7 @@ func TestRealPath(t *testing.T) { func TestTruncate(t *testing.T) { testFile := filepath.Join(os.TempDir(), "transfer_test_file") - fs := vfs.NewOsFs("123", os.TempDir(), nil) + fs := vfs.NewOsFs("123", os.TempDir(), "") u := dataprovider.User{ Username: "user", HomeDir: os.TempDir(), @@ -130,10 +130,10 @@ func TestTruncate(t *testing.T) { } _, err = file.Write([]byte("hello")) assert.NoError(t, err) - conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, u, fs) + conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, u) transfer := NewBaseTransfer(file, conn, nil, testFile, "/transfer_test_file", TransferUpload, 0, 5, 100, false, fs) - err = conn.SetStat(testFile, "/transfer_test_file", &StatAttributes{ + err = conn.SetStat("/transfer_test_file", &StatAttributes{ Size: 2, Flags: StatAttrSize, }) @@ -150,7 +150,7 @@ func TestTruncate(t *testing.T) { transfer = NewBaseTransfer(file, conn, nil, testFile, "/transfer_test_file", TransferUpload, 0, 0, 100, true, fs) // file.Stat will fail on a closed file - err = conn.SetStat(testFile, "/transfer_test_file", &StatAttributes{ + err = conn.SetStat("/transfer_test_file", &StatAttributes{ Size: 2, Flags: StatAttrSize, }) @@ -181,7 +181,7 @@ func TestTransferErrors(t *testing.T) { isCancelled = true } testFile := filepath.Join(os.TempDir(), "transfer_test_file") - fs := vfs.NewOsFs("id", os.TempDir(), nil) + fs := vfs.NewOsFs("id", os.TempDir(), "") u := dataprovider.User{ Username: "test", HomeDir: os.TempDir(), @@ -192,7 +192,7 @@ func TestTransferErrors(t *testing.T) { if !assert.NoError(t, err) { assert.FailNow(t, "unable to open test file") } - conn := NewBaseConnection("id", ProtocolSFTP, u, fs) + conn := NewBaseConnection("id", ProtocolSFTP, u) transfer := NewBaseTransfer(file, conn, nil, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs) assert.Nil(t, transfer.cancelFn) assert.Equal(t, testFile, transfer.GetFsPath()) @@ -255,13 +255,13 @@ func TestTransferErrors(t *testing.T) { func TestRemovePartialCryptoFile(t *testing.T) { testFile := filepath.Join(os.TempDir(), "transfer_test_file") - fs, err := vfs.NewCryptFs("id", os.TempDir(), vfs.CryptFsConfig{Passphrase: kms.NewPlainSecret("secret")}) + fs, err := vfs.NewCryptFs("id", os.TempDir(), "", vfs.CryptFsConfig{Passphrase: kms.NewPlainSecret("secret")}) require.NoError(t, err) u := dataprovider.User{ Username: "test", HomeDir: os.TempDir(), } - conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, u, fs) + conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, u) transfer := NewBaseTransfer(nil, conn, nil, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs) transfer.ErrTransfer = errors.New("test error") _, err = transfer.getUploadFileSize() diff --git a/dataprovider/bolt.go b/dataprovider/bolt.go index 7b20cea3..f73f861a 100644 --- a/dataprovider/bolt.go +++ b/dataprovider/bolt.go @@ -44,7 +44,7 @@ func initializeBoltProvider(basePath string) error { logSender = fmt.Sprintf("dataprovider_%v", BoltDataProviderName) dbPath := config.Name if !utils.IsFileInputValid(dbPath) { - return fmt.Errorf("Invalid database path: %#v", dbPath) + return fmt.Errorf("invalid database path: %#v", dbPath) } if !filepath.IsAbs(dbPath) { dbPath = filepath.Join(basePath, dbPath) @@ -119,7 +119,7 @@ func (p *BoltProvider) validateUserAndTLSCert(username, protocol string, tlsCert func (p *BoltProvider) validateUserAndPass(username, password, ip, protocol string) (User, error) { var user User if password == "" { - return user, errors.New("Credentials cannot be null or empty") + return user, errors.New("credentials cannot be null or empty") } user, err := p.userExists(username) if err != nil { @@ -142,7 +142,7 @@ func (p *BoltProvider) validateAdminAndPass(username, password, ip string) (Admi func (p *BoltProvider) validateUserAndPubKey(username string, pubKey []byte) (User, string, error) { var user User if len(pubKey) == 0 { - return user, "", errors.New("Credentials cannot be null or empty") + return user, "", errors.New("credentials cannot be null or empty") } user, err := p.userExists(username) if err != nil { @@ -435,8 +435,8 @@ func (p *BoltProvider) addUser(user *User) error { user.UsedQuotaSize = 0 user.UsedQuotaFiles = 0 user.LastLogin = 0 - for _, folder := range user.VirtualFolders { - err = addUserToFolderMapping(folder, user, folderBucket) + for idx := range user.VirtualFolders { + err = addUserToFolderMapping(&user.VirtualFolders[idx].BaseVirtualFolder, user, folderBucket) if err != nil { return err } @@ -472,14 +472,14 @@ func (p *BoltProvider) updateUser(user *User) error { if err != nil { return err } - for _, folder := range oldUser.VirtualFolders { - err = removeUserFromFolderMapping(folder, &oldUser, folderBucket) + for idx := range oldUser.VirtualFolders { + err = removeUserFromFolderMapping(&oldUser.VirtualFolders[idx], &oldUser, folderBucket) if err != nil { return err } } - for _, folder := range user.VirtualFolders { - err = addUserToFolderMapping(folder, user, folderBucket) + for idx := range user.VirtualFolders { + err = addUserToFolderMapping(&user.VirtualFolders[idx].BaseVirtualFolder, user, folderBucket) if err != nil { return err } @@ -508,8 +508,8 @@ func (p *BoltProvider) deleteUser(user *User) error { if err != nil { return err } - for _, folder := range user.VirtualFolders { - err = removeUserFromFolderMapping(folder, user, folderBucket) + for idx := range user.VirtualFolders { + err = removeUserFromFolderMapping(&user.VirtualFolders[idx], user, folderBucket) if err != nil { return err } @@ -649,6 +649,7 @@ func (p *BoltProvider) getFolders(limit, offset int, order string) ([]vfs.BaseVi if err != nil { return err } + folder.HideConfidentialData() folders = append(folders, folder) if len(folders) >= limit { break @@ -665,6 +666,7 @@ func (p *BoltProvider) getFolders(limit, offset int, order string) ([]vfs.BaseVi if err != nil { return err } + folder.HideConfidentialData() folders = append(folders, folder) if len(folders) >= limit { break @@ -703,8 +705,7 @@ func (p *BoltProvider) addFolder(folder *vfs.BaseVirtualFolder) error { return fmt.Errorf("folder %v already exists", folder.Name) } folder.Users = nil - _, err = addFolderInternal(*folder, bucket) - return err + return addFolderInternal(*folder, bucket) }) } @@ -867,7 +868,7 @@ func (p *BoltProvider) migrateDatabase() error { boltDatabaseVersion) return nil } - return fmt.Errorf("Database version not handled: %v", version) + return fmt.Errorf("database version not handled: %v", version) } } @@ -893,17 +894,14 @@ func joinUserAndFolders(u []byte, foldersBucket *bolt.Bucket) (User, error) { } if len(user.VirtualFolders) > 0 { var folders []vfs.VirtualFolder - for _, folder := range user.VirtualFolders { + for idx := range user.VirtualFolders { + folder := &user.VirtualFolders[idx] baseFolder, err := folderExistsInternal(folder.Name, foldersBucket) if err != nil { continue } - folder.MappedPath = baseFolder.MappedPath - folder.UsedQuotaFiles = baseFolder.UsedQuotaFiles - folder.UsedQuotaSize = baseFolder.UsedQuotaSize - folder.LastQuotaUpdate = baseFolder.LastQuotaUpdate - folder.ID = baseFolder.ID - folders = append(folders, folder) + folder.BaseVirtualFolder = baseFolder + folders = append(folders, *folder) } user.VirtualFolders = folders } @@ -922,50 +920,50 @@ func folderExistsInternal(name string, bucket *bolt.Bucket) (vfs.BaseVirtualFold return folder, err } -func addFolderInternal(folder vfs.BaseVirtualFolder, bucket *bolt.Bucket) (vfs.BaseVirtualFolder, error) { +func addFolderInternal(folder vfs.BaseVirtualFolder, bucket *bolt.Bucket) error { id, err := bucket.NextSequence() if err != nil { - return folder, err + return err } folder.ID = int64(id) buf, err := json.Marshal(folder) if err != nil { - return folder, err + return err } - err = bucket.Put([]byte(folder.Name), buf) - return folder, err + return bucket.Put([]byte(folder.Name), buf) } -func addUserToFolderMapping(folder vfs.VirtualFolder, user *User, bucket *bolt.Bucket) error { - var baseFolder vfs.BaseVirtualFolder - var err error - if f := bucket.Get([]byte(folder.Name)); f == nil { +func addUserToFolderMapping(baseFolder *vfs.BaseVirtualFolder, user *User, bucket *bolt.Bucket) error { + f := bucket.Get([]byte(baseFolder.Name)) + if f == nil { // folder does not exists, try to create - folder.LastQuotaUpdate = 0 - folder.UsedQuotaFiles = 0 - folder.UsedQuotaSize = 0 - baseFolder, err = addFolderInternal(folder.BaseVirtualFolder, bucket) - } else { - err = json.Unmarshal(f, &baseFolder) + baseFolder.LastQuotaUpdate = 0 + baseFolder.UsedQuotaFiles = 0 + baseFolder.UsedQuotaSize = 0 + baseFolder.Users = []string{user.Username} + return addFolderInternal(*baseFolder, bucket) } + var oldFolder vfs.BaseVirtualFolder + err := json.Unmarshal(f, &oldFolder) if err != nil { return err } + baseFolder.ID = oldFolder.ID + baseFolder.LastQuotaUpdate = oldFolder.LastQuotaUpdate + baseFolder.UsedQuotaFiles = oldFolder.UsedQuotaFiles + baseFolder.UsedQuotaSize = oldFolder.UsedQuotaSize + baseFolder.Users = oldFolder.Users if !utils.IsStringInSlice(user.Username, baseFolder.Users) { baseFolder.Users = append(baseFolder.Users, user.Username) - buf, err := json.Marshal(baseFolder) - if err != nil { - return err - } - err = bucket.Put([]byte(folder.Name), buf) - if err != nil { - return err - } } - return err + buf, err := json.Marshal(baseFolder) + if err != nil { + return err + } + return bucket.Put([]byte(baseFolder.Name), buf) } -func removeUserFromFolderMapping(folder vfs.VirtualFolder, user *User, bucket *bolt.Bucket) error { +func removeUserFromFolderMapping(folder *vfs.VirtualFolder, user *User, bucket *bolt.Bucket) error { var f []byte if f = bucket.Get([]byte(folder.Name)); f == nil { // the folder does not exists so there is no associated user diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go index 6f9d86ec..31fe525d 100644 --- a/dataprovider/dataprovider.go +++ b/dataprovider/dataprovider.go @@ -105,7 +105,7 @@ var ( // ValidProtocols defines all the valid protcols ValidProtocols = []string{"SSH", "FTP", "DAV"} // ErrNoInitRequired defines the error returned by InitProvider if no inizialization/update is required - ErrNoInitRequired = errors.New("The data provider is up to date") + ErrNoInitRequired = errors.New("the data provider is up to date") // ErrInvalidCredentials defines the error to return if the supplied credentials are invalid ErrInvalidCredentials = errors.New("invalid credentials") validTLSUsernames = []string{string(TLSUsernameNone), string(TLSUsernameCN)} @@ -410,6 +410,11 @@ type Provider interface { revertDatabase(targetVersion int) error } +type fsValidatorHelper interface { + GetGCSCredentialsFilePath() string + GetEncrytionAdditionalData() string +} + // Initialize the data provider. // An error is returned if the configured driver is invalid or if the data provider cannot be initialized func Initialize(cnf Config, basePath string, checkAdmins bool) error { @@ -421,6 +426,7 @@ func Initialize(cnf Config, basePath string, checkAdmins bool) error { } else { credentialsDirPath = filepath.Join(basePath, config.CredentialsPath) } + vfs.SetCredentialsDirPath(credentialsDirPath) if err = validateHooks(); err != nil { return err @@ -498,7 +504,7 @@ func validateSQLTablesPrefix() error { if len(config.SQLTablesPrefix) > 0 { for _, char := range config.SQLTablesPrefix { if !strings.Contains(sqlPrefixValidChars, strings.ToLower(string(char))) { - return errors.New("Invalid sql_tables_prefix only chars in range 'a..z', 'A..Z' and '_' are allowed") + return errors.New("invalid sql_tables_prefix only chars in range 'a..z', 'A..Z' and '_' are allowed") } } sqlTableUsers = config.SQLTablesPrefix + sqlTableUsers @@ -583,7 +589,7 @@ func CheckCachedUserCredentials(user *CachedUser, password, loginMethod, protoco } if loginMethod == LoginMethodTLSCertificate { if !user.User.IsLoginMethodAllowed(LoginMethodTLSCertificate, nil) { - return fmt.Errorf("Certificate login method is not allowed for user %#v", user.User.Username) + return fmt.Errorf("certificate login method is not allowed for user %#v", user.User.Username) } return nil } @@ -628,7 +634,7 @@ func CheckCompositeCredentials(username, password, ip, loginMethod, protocol str return user, loginMethod, err } if loginMethod == LoginMethodTLSCertificate && !user.IsLoginMethodAllowed(LoginMethodTLSCertificate, nil) { - return user, loginMethod, fmt.Errorf("Certificate login method is not allowed for user %#v", user.Username) + return user, loginMethod, fmt.Errorf("certificate login method is not allowed for user %#v", user.Username) } if loginMethod == LoginMethodTLSCertificateAndPwd { if config.ExternalAuthHook != "" && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&1 != 0) { @@ -876,8 +882,14 @@ func AddFolder(folder *vfs.BaseVirtualFolder) error { } // UpdateFolder updates the specified virtual folder -func UpdateFolder(folder *vfs.BaseVirtualFolder) error { - return provider.updateFolder(folder) +func UpdateFolder(folder *vfs.BaseVirtualFolder, users []string) error { + err := provider.updateFolder(folder) + if err == nil { + for _, user := range users { + RemoveCachedWebDAVUser(user) + } + } + return err } // DeleteFolder deletes an existing folder. @@ -886,7 +898,13 @@ func DeleteFolder(folderName string) error { if err != nil { return err } - return provider.deleteFolder(&folder) + err = provider.deleteFolder(&folder) + if err == nil { + for _, user := range folder.Users { + RemoveCachedWebDAVUser(user) + } + } + return err } // GetFolderByName returns the folder with the specified name if any @@ -981,41 +999,45 @@ func buildUserHomeDir(user *User) { if user.HomeDir == "" { if config.UsersBaseDir != "" { user.HomeDir = filepath.Join(config.UsersBaseDir, user.Username) - } else if user.FsConfig.Provider == SFTPFilesystemProvider { + } else if user.FsConfig.Provider == vfs.SFTPFilesystemProvider { user.HomeDir = filepath.Join(os.TempDir(), user.Username) } } } -func isVirtualDirOverlapped(dir1, dir2 string) bool { +func isVirtualDirOverlapped(dir1, dir2 string, fullCheck bool) bool { if dir1 == dir2 { return true } - if len(dir1) > len(dir2) { - if strings.HasPrefix(dir1, dir2+"/") { - return true + if fullCheck { + if len(dir1) > len(dir2) { + if strings.HasPrefix(dir1, dir2+"/") { + return true + } } - } - if len(dir2) > len(dir1) { - if strings.HasPrefix(dir2, dir1+"/") { - return true + if len(dir2) > len(dir1) { + if strings.HasPrefix(dir2, dir1+"/") { + return true + } } } return false } -func isMappedDirOverlapped(dir1, dir2 string) bool { +func isMappedDirOverlapped(dir1, dir2 string, fullCheck bool) bool { if dir1 == dir2 { return true } - if len(dir1) > len(dir2) { - if strings.HasPrefix(dir1, dir2+string(os.PathSeparator)) { - return true + if fullCheck { + if len(dir1) > len(dir2) { + if strings.HasPrefix(dir1, dir2+string(os.PathSeparator)) { + return true + } } - } - if len(dir2) > len(dir1) { - if strings.HasPrefix(dir2, dir1+string(os.PathSeparator)) { - return true + if len(dir2) > len(dir1) { + if strings.HasPrefix(dir2, dir1+string(os.PathSeparator)) { + return true + } } } return false @@ -1046,19 +1068,33 @@ func getVirtualFolderIfInvalid(folder *vfs.BaseVirtualFolder) *vfs.BaseVirtualFo if folder.Name == "" { return folder } + if folder.FsConfig.Provider != vfs.LocalFilesystemProvider { + return folder + } if f, err := GetFolderByName(folder.Name); err == nil { return &f } return folder } +func hasSFTPLoopForFolder(user *User, folder *vfs.BaseVirtualFolder) bool { + if folder.FsConfig.Provider == vfs.SFTPFilesystemProvider { + // FIXME: this could be inaccurate, it is not easy to check the endpoint too + if folder.FsConfig.SFTPConfig.Username == user.Username { + return true + } + } + return false +} + func validateUserVirtualFolders(user *User) error { - if len(user.VirtualFolders) == 0 || user.FsConfig.Provider != LocalFilesystemProvider { + if len(user.VirtualFolders) == 0 { user.VirtualFolders = []vfs.VirtualFolder{} return nil } var virtualFolders []vfs.VirtualFolder - mappedPaths := make(map[string]string) + mappedPaths := make(map[string]bool) + virtualPaths := make(map[string]bool) for _, v := range user.VirtualFolders { cleanedVPath := filepath.ToSlash(path.Clean(v.VirtualPath)) if !path.IsAbs(cleanedVPath) || cleanedVPath == "/" { @@ -1071,34 +1107,37 @@ func validateUserVirtualFolders(user *User) error { if err := ValidateFolder(folder); err != nil { return err } - cleanedMPath := folder.MappedPath - if isMappedDirOverlapped(cleanedMPath, user.GetHomeDir()) { - return &ValidationError{err: fmt.Sprintf("invalid mapped folder %#v cannot be inside or contain the user home dir %#v", - folder.MappedPath, user.GetHomeDir())} + if hasSFTPLoopForFolder(user, folder) { + return &ValidationError{err: fmt.Sprintf("SFTP folder %#v could point to the same SFTPGo account, this is not allowed", + folder.Name)} } + cleanedMPath := folder.MappedPath + if folder.IsLocalOrLocalCrypted() { + if isMappedDirOverlapped(cleanedMPath, user.GetHomeDir(), true) { + return &ValidationError{err: fmt.Sprintf("invalid mapped folder %#v cannot be inside or contain the user home dir %#v", + folder.MappedPath, user.GetHomeDir())} + } + for mPath := range mappedPaths { + if folder.IsLocalOrLocalCrypted() && isMappedDirOverlapped(mPath, cleanedMPath, false) { + return &ValidationError{err: fmt.Sprintf("invalid mapped folder %#v overlaps with mapped folder %#v", + v.MappedPath, mPath)} + } + } + mappedPaths[cleanedMPath] = true + } + for vPath := range virtualPaths { + if isVirtualDirOverlapped(vPath, cleanedVPath, false) { + return &ValidationError{err: fmt.Sprintf("invalid virtual folder %#v overlaps with virtual folder %#v", + v.VirtualPath, vPath)} + } + } + virtualPaths[cleanedVPath] = true virtualFolders = append(virtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: *folder, VirtualPath: cleanedVPath, QuotaSize: v.QuotaSize, QuotaFiles: v.QuotaFiles, }) - for k, virtual := range mappedPaths { - if GetQuotaTracking() > 0 { - if isMappedDirOverlapped(k, cleanedMPath) { - return &ValidationError{err: fmt.Sprintf("invalid mapped folder %#v overlaps with mapped folder %#v", - v.MappedPath, k)} - } - } else { - if k == cleanedMPath { - return &ValidationError{err: fmt.Sprintf("duplicated mapped folder %#v", v.MappedPath)} - } - } - if isVirtualDirOverlapped(virtual, cleanedVPath) { - return &ValidationError{err: fmt.Sprintf("invalid virtual folder %#v overlaps with virtual folder %#v", - v.VirtualPath, virtual)} - } - } - mappedPaths[cleanedMPath] = cleanedVPath } user.VirtualFolders = virtualFolders return nil @@ -1297,35 +1336,35 @@ func validateFilters(user *User) error { return validateFileFilters(user) } -func saveGCSCredentials(user *User) error { - if user.FsConfig.Provider != GCSFilesystemProvider { +func saveGCSCredentials(fsConfig *vfs.Filesystem, helper fsValidatorHelper) error { + if fsConfig.Provider != vfs.GCSFilesystemProvider { return nil } - if user.FsConfig.GCSConfig.Credentials.GetPayload() == "" { + if fsConfig.GCSConfig.Credentials.GetPayload() == "" { return nil } if config.PreferDatabaseCredentials { - if user.FsConfig.GCSConfig.Credentials.IsPlain() { - user.FsConfig.GCSConfig.Credentials.SetAdditionalData(user.Username) - err := user.FsConfig.GCSConfig.Credentials.Encrypt() + if fsConfig.GCSConfig.Credentials.IsPlain() { + fsConfig.GCSConfig.Credentials.SetAdditionalData(helper.GetEncrytionAdditionalData()) + err := fsConfig.GCSConfig.Credentials.Encrypt() if err != nil { return err } } return nil } - if user.FsConfig.GCSConfig.Credentials.IsPlain() { - user.FsConfig.GCSConfig.Credentials.SetAdditionalData(user.Username) - err := user.FsConfig.GCSConfig.Credentials.Encrypt() + if fsConfig.GCSConfig.Credentials.IsPlain() { + fsConfig.GCSConfig.Credentials.SetAdditionalData(helper.GetEncrytionAdditionalData()) + err := fsConfig.GCSConfig.Credentials.Encrypt() if err != nil { return &ValidationError{err: fmt.Sprintf("could not encrypt GCS credentials: %v", err)} } } - creds, err := json.Marshal(user.FsConfig.GCSConfig.Credentials) + creds, err := json.Marshal(fsConfig.GCSConfig.Credentials) if err != nil { return &ValidationError{err: fmt.Sprintf("could not marshal GCS credentials: %v", err)} } - credentialsFilePath := user.getGCSCredentialsFilePath() + credentialsFilePath := helper.GetGCSCredentialsFilePath() err = os.MkdirAll(filepath.Dir(credentialsFilePath), 0700) if err != nil { return &ValidationError{err: fmt.Sprintf("could not create GCS credentials dir: %v", err)} @@ -1334,75 +1373,75 @@ func saveGCSCredentials(user *User) error { if err != nil { return &ValidationError{err: fmt.Sprintf("could not save GCS credentials: %v", err)} } - user.FsConfig.GCSConfig.Credentials = kms.NewEmptySecret() + fsConfig.GCSConfig.Credentials = kms.NewEmptySecret() return nil } -func validateFilesystemConfig(user *User) error { - if user.FsConfig.Provider == S3FilesystemProvider { - if err := user.FsConfig.S3Config.Validate(); err != nil { +func validateFilesystemConfig(fsConfig *vfs.Filesystem, helper fsValidatorHelper) error { + if fsConfig.Provider == vfs.S3FilesystemProvider { + if err := fsConfig.S3Config.Validate(); err != nil { return &ValidationError{err: fmt.Sprintf("could not validate s3config: %v", err)} } - if err := user.FsConfig.S3Config.EncryptCredentials(user.Username); err != nil { + if err := fsConfig.S3Config.EncryptCredentials(helper.GetEncrytionAdditionalData()); err != nil { return &ValidationError{err: fmt.Sprintf("could not encrypt s3 access secret: %v", err)} } - user.FsConfig.GCSConfig = vfs.GCSFsConfig{} - user.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{} - user.FsConfig.CryptConfig = vfs.CryptFsConfig{} - user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{} + fsConfig.GCSConfig = vfs.GCSFsConfig{} + fsConfig.AzBlobConfig = vfs.AzBlobFsConfig{} + fsConfig.CryptConfig = vfs.CryptFsConfig{} + fsConfig.SFTPConfig = vfs.SFTPFsConfig{} return nil - } else if user.FsConfig.Provider == GCSFilesystemProvider { - if err := user.FsConfig.GCSConfig.Validate(user.getGCSCredentialsFilePath()); err != nil { + } else if fsConfig.Provider == vfs.GCSFilesystemProvider { + if err := fsConfig.GCSConfig.Validate(helper.GetGCSCredentialsFilePath()); err != nil { return &ValidationError{err: fmt.Sprintf("could not validate GCS config: %v", err)} } - user.FsConfig.S3Config = vfs.S3FsConfig{} - user.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{} - user.FsConfig.CryptConfig = vfs.CryptFsConfig{} - user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{} + fsConfig.S3Config = vfs.S3FsConfig{} + fsConfig.AzBlobConfig = vfs.AzBlobFsConfig{} + fsConfig.CryptConfig = vfs.CryptFsConfig{} + fsConfig.SFTPConfig = vfs.SFTPFsConfig{} return nil - } else if user.FsConfig.Provider == AzureBlobFilesystemProvider { - if err := user.FsConfig.AzBlobConfig.Validate(); err != nil { + } else if fsConfig.Provider == vfs.AzureBlobFilesystemProvider { + if err := fsConfig.AzBlobConfig.Validate(); err != nil { return &ValidationError{err: fmt.Sprintf("could not validate Azure Blob config: %v", err)} } - if err := user.FsConfig.AzBlobConfig.EncryptCredentials(user.Username); err != nil { + if err := fsConfig.AzBlobConfig.EncryptCredentials(helper.GetEncrytionAdditionalData()); err != nil { return &ValidationError{err: fmt.Sprintf("could not encrypt Azure blob account key: %v", err)} } - user.FsConfig.S3Config = vfs.S3FsConfig{} - user.FsConfig.GCSConfig = vfs.GCSFsConfig{} - user.FsConfig.CryptConfig = vfs.CryptFsConfig{} - user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{} + fsConfig.S3Config = vfs.S3FsConfig{} + fsConfig.GCSConfig = vfs.GCSFsConfig{} + fsConfig.CryptConfig = vfs.CryptFsConfig{} + fsConfig.SFTPConfig = vfs.SFTPFsConfig{} return nil - } else if user.FsConfig.Provider == CryptedFilesystemProvider { - if err := user.FsConfig.CryptConfig.Validate(); err != nil { + } else if fsConfig.Provider == vfs.CryptedFilesystemProvider { + if err := fsConfig.CryptConfig.Validate(); err != nil { return &ValidationError{err: fmt.Sprintf("could not validate Crypt fs config: %v", err)} } - if err := user.FsConfig.CryptConfig.EncryptCredentials(user.Username); err != nil { + if err := fsConfig.CryptConfig.EncryptCredentials(helper.GetEncrytionAdditionalData()); err != nil { return &ValidationError{err: fmt.Sprintf("could not encrypt Crypt fs passphrase: %v", err)} } - user.FsConfig.S3Config = vfs.S3FsConfig{} - user.FsConfig.GCSConfig = vfs.GCSFsConfig{} - user.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{} - user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{} + fsConfig.S3Config = vfs.S3FsConfig{} + fsConfig.GCSConfig = vfs.GCSFsConfig{} + fsConfig.AzBlobConfig = vfs.AzBlobFsConfig{} + fsConfig.SFTPConfig = vfs.SFTPFsConfig{} return nil - } else if user.FsConfig.Provider == SFTPFilesystemProvider { - if err := user.FsConfig.SFTPConfig.Validate(); err != nil { + } else if fsConfig.Provider == vfs.SFTPFilesystemProvider { + if err := fsConfig.SFTPConfig.Validate(); err != nil { return &ValidationError{err: fmt.Sprintf("could not validate SFTP fs config: %v", err)} } - if err := user.FsConfig.SFTPConfig.EncryptCredentials(user.Username); err != nil { + if err := fsConfig.SFTPConfig.EncryptCredentials(helper.GetEncrytionAdditionalData()); err != nil { return &ValidationError{err: fmt.Sprintf("could not encrypt SFTP fs credentials: %v", err)} } - user.FsConfig.S3Config = vfs.S3FsConfig{} - user.FsConfig.GCSConfig = vfs.GCSFsConfig{} - user.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{} - user.FsConfig.CryptConfig = vfs.CryptFsConfig{} + fsConfig.S3Config = vfs.S3FsConfig{} + fsConfig.GCSConfig = vfs.GCSFsConfig{} + fsConfig.AzBlobConfig = vfs.AzBlobFsConfig{} + fsConfig.CryptConfig = vfs.CryptFsConfig{} return nil } - user.FsConfig.Provider = LocalFilesystemProvider - user.FsConfig.S3Config = vfs.S3FsConfig{} - user.FsConfig.GCSConfig = vfs.GCSFsConfig{} - user.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{} - user.FsConfig.CryptConfig = vfs.CryptFsConfig{} - user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{} + fsConfig.Provider = vfs.LocalFilesystemProvider + fsConfig.S3Config = vfs.S3FsConfig{} + fsConfig.GCSConfig = vfs.GCSFsConfig{} + fsConfig.AzBlobConfig = vfs.AzBlobFsConfig{} + fsConfig.CryptConfig = vfs.CryptFsConfig{} + fsConfig.SFTPConfig = vfs.SFTPFsConfig{} return nil } @@ -1447,11 +1486,23 @@ func ValidateFolder(folder *vfs.BaseVirtualFolder) error { return &ValidationError{err: fmt.Sprintf("folder name %#v is not valid, the following characters are allowed: a-zA-Z0-9-_.~", folder.Name)} } - cleanedMPath := filepath.Clean(folder.MappedPath) - if !filepath.IsAbs(cleanedMPath) { - return &ValidationError{err: fmt.Sprintf("invalid folder mapped path %#v", folder.MappedPath)} + if folder.FsConfig.Provider == vfs.LocalFilesystemProvider || folder.FsConfig.Provider == vfs.CryptedFilesystemProvider || + folder.MappedPath != "" { + cleanedMPath := filepath.Clean(folder.MappedPath) + if !filepath.IsAbs(cleanedMPath) { + return &ValidationError{err: fmt.Sprintf("invalid folder mapped path %#v", folder.MappedPath)} + } + folder.MappedPath = cleanedMPath + } + if folder.HasRedactedSecret() { + return errors.New("cannot save a folder with a redacted secret") + } + if err := validateFilesystemConfig(&folder.FsConfig, folder); err != nil { + return err + } + if err := saveGCSCredentials(&folder.FsConfig, folder); err != nil { + return err } - folder.MappedPath = cleanedMPath return nil } @@ -1466,7 +1517,10 @@ func ValidateUser(user *User) error { if err := validatePermissions(user); err != nil { return err } - if err := validateFilesystemConfig(user); err != nil { + if user.hasRedactedSecret() { + return errors.New("cannot save a user with a redacted secret") + } + if err := validateFilesystemConfig(&user.FsConfig, user); err != nil { return err } if err := validateUserVirtualFolders(user); err != nil { @@ -1484,7 +1538,7 @@ func ValidateUser(user *User) error { if err := validateFilters(user); err != nil { return err } - if err := saveGCSCredentials(user); err != nil { + if err := saveGCSCredentials(&user.FsConfig, user); err != nil { return err } return nil @@ -1555,12 +1609,12 @@ func checkUserAndPass(user *User, password, ip, protocol string) (User, error) { return *user, err } if user.Password == "" { - return *user, errors.New("Credentials cannot be null or empty") + return *user, errors.New("credentials cannot be null or empty") } hookResponse, err := executeCheckPasswordHook(user.Username, password, ip, protocol) if err != nil { providerLog(logger.LevelDebug, "error executing check password hook: %v", err) - return *user, errors.New("Unable to check credentials") + return *user, errors.New("unable to check credentials") } switch hookResponse.Status { case -1: @@ -1664,7 +1718,10 @@ func comparePbkdf2PasswordAndHash(password, hashedPassword string) (bool, error) } func addCredentialsToUser(user *User) error { - if user.FsConfig.Provider != GCSFilesystemProvider { + if err := addFolderCredentialsToUser(user); err != nil { + return err + } + if user.FsConfig.Provider != vfs.GCSFilesystemProvider { return nil } if user.FsConfig.GCSConfig.AutomaticCredentials > 0 { @@ -1676,13 +1733,38 @@ func addCredentialsToUser(user *User) error { return nil } - cred, err := os.ReadFile(user.getGCSCredentialsFilePath()) + cred, err := os.ReadFile(user.GetGCSCredentialsFilePath()) if err != nil { return err } return json.Unmarshal(cred, &user.FsConfig.GCSConfig.Credentials) } +func addFolderCredentialsToUser(user *User) error { + for idx := range user.VirtualFolders { + f := &user.VirtualFolders[idx] + if f.FsConfig.Provider != vfs.GCSFilesystemProvider { + continue + } + if f.FsConfig.GCSConfig.AutomaticCredentials > 0 { + continue + } + // Don't read from file if credentials have already been set + if f.FsConfig.GCSConfig.Credentials.IsValid() { + continue + } + cred, err := os.ReadFile(f.GetGCSCredentialsFilePath()) + if err != nil { + return err + } + err = json.Unmarshal(cred, f.FsConfig.GCSConfig.Credentials) + if err != nil { + return err + } + } + return nil +} + func getSSLMode() string { if config.Driver == PGSQLDataProviderName { if config.SSLMode == 0 { @@ -2014,7 +2096,9 @@ func executeCheckPasswordHook(username, password, ip, protocol string) (checkPas return response, nil } + startTime := time.Now() out, err := getPasswordHookResponse(username, password, ip, protocol) + providerLog(logger.LevelDebug, "check password hook executed, error: %v, elapsed: %v", err, time.Since(startTime)) if err != nil { return response, err } @@ -2078,10 +2162,12 @@ func executePreLoginHook(username, loginMethod, ip, protocol string) (User, erro if err != nil { return u, err } + startTime := time.Now() out, err := getPreLoginHookResponse(loginMethod, ip, protocol, userAsJSON) if err != nil { - return u, fmt.Errorf("Pre-login hook error: %v", err) + return u, fmt.Errorf("pre-login hook error: %v, elapsed %v", err, time.Since(startTime)) } + providerLog(logger.LevelDebug, "pre-login hook completed, elapsed: %v", time.Since(startTime)) if strings.TrimSpace(string(out)) == "" { providerLog(logger.LevelDebug, "empty response from pre-login hook, no modification requested for user %#v id: %v", username, u.ID) @@ -2098,7 +2184,7 @@ func executePreLoginHook(username, loginMethod, ip, protocol string) (User, erro userLastLogin := u.LastLogin err = json.Unmarshal(out, &u) if err != nil { - return u, fmt.Errorf("Invalid pre-login hook response %#v, error: %v", string(out), err) + return u, fmt.Errorf("invalid pre-login hook response %#v, error: %v", string(out), err) } u.ID = userID u.UsedQuotaSize = userUsedQuotaSize @@ -2214,9 +2300,11 @@ func getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip, return result, err } defer resp.Body.Close() + providerLog(logger.LevelDebug, "external auth hook executed, response code: %v", resp.StatusCode) if resp.StatusCode != http.StatusOK { return result, fmt.Errorf("wrong external auth http status code: %v, expected 200", resp.StatusCode) } + return io.ReadAll(resp.Body) } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) @@ -2250,13 +2338,15 @@ func doExternalAuth(username, password string, pubKey []byte, keyboardInteractiv return user, err } } + startTime := time.Now() out, err := getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip, protocol, cert) if err != nil { - return user, fmt.Errorf("External auth error: %v", err) + return user, fmt.Errorf("external auth error: %v, elapsed: %v", err, time.Since(startTime)) } + providerLog(logger.LevelDebug, "external auth completed, elapsed: %v", time.Since(startTime)) err = json.Unmarshal(out, &user) if err != nil { - return user, fmt.Errorf("Invalid external auth response: %v", err) + return user, fmt.Errorf("invalid external auth response: %v", err) } if user.Username == "" { return user, ErrInvalidCredentials diff --git a/dataprovider/memory.go b/dataprovider/memory.go index 42d722b8..c587f4a2 100644 --- a/dataprovider/memory.go +++ b/dataprovider/memory.go @@ -105,7 +105,7 @@ func (p *MemoryProvider) validateUserAndTLSCert(username, protocol string, tlsCe func (p *MemoryProvider) validateUserAndPass(username, password, ip, protocol string) (User, error) { var user User if password == "" { - return user, errors.New("Credentials cannot be null or empty") + return user, errors.New("credentials cannot be null or empty") } user, err := p.userExists(username) if err != nil { @@ -118,7 +118,7 @@ func (p *MemoryProvider) validateUserAndPass(username, password, ip, protocol st func (p *MemoryProvider) validateUserAndPubKey(username string, pubKey []byte) (User, string, error) { var user User if len(pubKey) == 0 { - return user, "", errors.New("Credentials cannot be null or empty") + return user, "", errors.New("credentials cannot be null or empty") } user, err := p.userExists(username) if err != nil { @@ -548,15 +548,12 @@ func (p *MemoryProvider) getUsedFolderQuota(name string) (int, int64, error) { func (p *MemoryProvider) joinVirtualFoldersFields(user *User) []vfs.VirtualFolder { var folders []vfs.VirtualFolder - for _, folder := range user.VirtualFolders { - f, err := p.addOrGetFolderInternal(folder.Name, folder.MappedPath, user.Username) + for idx := range user.VirtualFolders { + folder := &user.VirtualFolders[idx] + f, err := p.addOrUpdateFolderInternal(&folder.BaseVirtualFolder, user.Username, 0, 0, 0) if err == nil { - folder.UsedQuotaFiles = f.UsedQuotaFiles - folder.UsedQuotaSize = f.UsedQuotaSize - folder.LastQuotaUpdate = f.LastQuotaUpdate - folder.ID = f.ID - folder.MappedPath = f.MappedPath - folders = append(folders, folder) + folder.BaseVirtualFolder = f + folders = append(folders, *folder) } } return folders @@ -584,24 +581,29 @@ func (p *MemoryProvider) updateFoldersMappingInternal(folder vfs.BaseVirtualFold } } -func (p *MemoryProvider) addOrGetFolderInternal(folderName, folderMappedPath, username string) (vfs.BaseVirtualFolder, error) { - folder, err := p.folderExistsInternal(folderName) - if _, ok := err.(*RecordNotFoundError); ok { - folder := vfs.BaseVirtualFolder{ - ID: p.getNextFolderID(), - Name: folderName, - MappedPath: folderMappedPath, - UsedQuotaSize: 0, - UsedQuotaFiles: 0, - LastQuotaUpdate: 0, - Users: []string{username}, +func (p *MemoryProvider) addOrUpdateFolderInternal(baseFolder *vfs.BaseVirtualFolder, username string, usedQuotaSize int64, + usedQuotaFiles int, lastQuotaUpdate int64) (vfs.BaseVirtualFolder, error) { + folder, err := p.folderExistsInternal(baseFolder.Name) + if err == nil { + // exists + folder.MappedPath = baseFolder.MappedPath + folder.Description = baseFolder.Description + folder.FsConfig = baseFolder.FsConfig.GetACopy() + if !utils.IsStringInSlice(username, folder.Users) { + folder.Users = append(folder.Users, username) } p.updateFoldersMappingInternal(folder) return folder, nil } - if err == nil && !utils.IsStringInSlice(username, folder.Users) { - folder.Users = append(folder.Users, username) + if _, ok := err.(*RecordNotFoundError); ok { + folder = baseFolder.GetACopy() + folder.ID = p.getNextFolderID() + folder.UsedQuotaSize = usedQuotaSize + folder.UsedQuotaFiles = usedQuotaFiles + folder.LastQuotaUpdate = lastQuotaUpdate + folder.Users = []string{username} p.updateFoldersMappingInternal(folder) + return folder, nil } return folder, err } @@ -631,7 +633,9 @@ func (p *MemoryProvider) getFolders(limit, offset int, order string) ([]vfs.Base if itNum <= offset { continue } - folder := p.dbHandle.vfolders[name] + f := p.dbHandle.vfolders[name] + folder := f.GetACopy() + folder.HideConfidentialData() folders = append(folders, folder) if len(folders) >= limit { break @@ -644,7 +648,9 @@ func (p *MemoryProvider) getFolders(limit, offset int, order string) ([]vfs.Base continue } name := p.dbHandle.vfoldersNames[i] - folder := p.dbHandle.vfolders[name] + f := p.dbHandle.vfolders[name] + folder := f.GetACopy() + folder.HideConfidentialData() folders = append(folders, folder) if len(folders) >= limit { break @@ -660,7 +666,11 @@ func (p *MemoryProvider) getFolderByName(name string) (vfs.BaseVirtualFolder, er if p.dbHandle.isClosed { return vfs.BaseVirtualFolder{}, errMemoryProviderClosed } - return p.folderExistsInternal(name) + folder, err := p.folderExistsInternal(name) + if err != nil { + return vfs.BaseVirtualFolder{}, err + } + return folder.GetACopy(), nil } func (p *MemoryProvider) addFolder(folder *vfs.BaseVirtualFolder) error { @@ -708,6 +718,22 @@ func (p *MemoryProvider) updateFolder(folder *vfs.BaseVirtualFolder) error { folder.UsedQuotaSize = f.UsedQuotaSize folder.Users = f.Users p.dbHandle.vfolders[folder.Name] = folder.GetACopy() + // now update the related users + for _, username := range folder.Users { + user, err := p.userExistsInternal(username) + if err == nil { + var folders []vfs.VirtualFolder + for idx := range user.VirtualFolders { + userFolder := &user.VirtualFolders[idx] + if folder.Name == userFolder.Name { + userFolder.BaseVirtualFolder = folder.GetACopy() + } + folders = append(folders, *userFolder) + } + user.VirtualFolders = folders + p.dbHandle.users[user.Username] = user + } + } return nil } @@ -726,9 +752,10 @@ func (p *MemoryProvider) deleteFolder(folder *vfs.BaseVirtualFolder) error { user, err := p.userExistsInternal(username) if err == nil { var folders []vfs.VirtualFolder - for _, userFolder := range user.VirtualFolders { + for idx := range user.VirtualFolders { + userFolder := &user.VirtualFolders[idx] if folder.Name != userFolder.Name { - folders = append(folders, userFolder) + folders = append(folders, *userFolder) } } user.VirtualFolders = folders diff --git a/dataprovider/mysql.go b/dataprovider/mysql.go index 9d4cee10..3e0f58c9 100644 --- a/dataprovider/mysql.go +++ b/dataprovider/mysql.go @@ -257,7 +257,7 @@ func (p *MySQLProvider) migrateDatabase() error { sqlDatabaseVersion) return nil } - return fmt.Errorf("Database version not handled: %v", version) + return fmt.Errorf("database version not handled: %v", version) } } @@ -274,7 +274,7 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error { case 9: return downgradeMySQLDatabaseFromV9(p.dbHandle) default: - return fmt.Errorf("Database version not handled: %v", dbVersion.Version) + return fmt.Errorf("database version not handled: %v", dbVersion.Version) } } diff --git a/dataprovider/pgsql.go b/dataprovider/pgsql.go index b79c5d78..7cbd36eb 100644 --- a/dataprovider/pgsql.go +++ b/dataprovider/pgsql.go @@ -265,7 +265,7 @@ func (p *PGSQLProvider) migrateDatabase() error { sqlDatabaseVersion) return nil } - return fmt.Errorf("Database version not handled: %v", version) + return fmt.Errorf("database version not handled: %v", version) } } @@ -282,7 +282,7 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error { case 9: return downgradePGSQLDatabaseFromV9(p.dbHandle) default: - return fmt.Errorf("Database version not handled: %v", dbVersion.Version) + return fmt.Errorf("database version not handled: %v", dbVersion.Version) } } diff --git a/dataprovider/sqlcommon.go b/dataprovider/sqlcommon.go index c63516c9..468614b7 100644 --- a/dataprovider/sqlcommon.go +++ b/dataprovider/sqlcommon.go @@ -217,7 +217,7 @@ func sqlCommonGetUserByUsername(username string, dbHandle sqlQuerier) (User, err func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHandle *sql.DB) (User, error) { var user User if password == "" { - return user, errors.New("Credentials cannot be null or empty") + return user, errors.New("credentials cannot be null or empty") } user, err := sqlCommonGetUserByUsername(username, dbHandle) if err != nil { @@ -243,7 +243,7 @@ func sqlCommonValidateUserAndTLSCertificate(username, protocol string, tlsCert * func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, dbHandle *sql.DB) (User, string, error) { var user User if len(pubKey) == 0 { - return user, "", errors.New("Credentials cannot be null or empty") + return user, "", errors.New("credentials cannot be null or empty") } user, err := sqlCommonGetUserByUsername(username, dbHandle) if err != nil { @@ -587,7 +587,7 @@ func getUserFromDbRow(row sqlScanner) (User, error) { } } if fsConfig.Valid { - var fs Filesystem + var fs vfs.Filesystem err = json.Unmarshal([]byte(fsConfig.String), &fs) if err == nil { user.FsConfig = fs @@ -603,7 +603,20 @@ func getUserFromDbRow(row sqlScanner) (User, error) { return user, err } -func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) { +func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQuerier) error { + var folderName string + q := checkFolderNameQuery() + stmt, err := dbHandle.PrepareContext(ctx, q) + if err != nil { + providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err) + return err + } + defer stmt.Close() + row := stmt.QueryRowContext(ctx, name) + return row.Scan(&folderName) +} + +func sqlCommonGetFolder(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) { var folder vfs.BaseVirtualFolder q := getFolderByNameQuery() stmt, err := dbHandle.PrepareContext(ctx, q) @@ -613,9 +626,9 @@ func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQu } defer stmt.Close() row := stmt.QueryRowContext(ctx, name) - var mappedPath, description sql.NullString + var mappedPath, description, fsConfig sql.NullString err = row.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, &folder.LastQuotaUpdate, - &folder.Name, &description) + &folder.Name, &description, &fsConfig) if err == sql.ErrNoRows { return folder, &RecordNotFoundError{err: err.Error()} } @@ -625,11 +638,18 @@ func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQu if description.Valid { folder.Description = description.String } + if fsConfig.Valid { + var fs vfs.Filesystem + err = json.Unmarshal([]byte(fsConfig.String), &fs) + if err == nil { + folder.FsConfig = fs + } + } return folder, err } func sqlCommonGetFolderByName(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) { - folder, err := sqlCommonCheckFolderExists(ctx, name, dbHandle) + folder, err := sqlCommonGetFolder(ctx, name, dbHandle) if err != nil { return folder, err } @@ -643,23 +663,30 @@ func sqlCommonGetFolderByName(ctx context.Context, name string, dbHandle sqlQuer return folders[0], nil } -func sqlCommonAddOrGetFolder(ctx context.Context, baseFolder vfs.BaseVirtualFolder, usedQuotaSize int64, usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) { - folder, err := sqlCommonCheckFolderExists(ctx, baseFolder.Name, dbHandle) - if _, ok := err.(*RecordNotFoundError); ok { - f := &vfs.BaseVirtualFolder{ - Name: baseFolder.Name, - MappedPath: baseFolder.MappedPath, - UsedQuotaSize: usedQuotaSize, - UsedQuotaFiles: usedQuotaFiles, - LastQuotaUpdate: lastQuotaUpdate, - } - err = sqlCommonAddFolder(f, dbHandle) +func sqlCommonAddOrUpdateFolder(ctx context.Context, baseFolder *vfs.BaseVirtualFolder, usedQuotaSize int64, + usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) { + var folder vfs.BaseVirtualFolder + // FIXME: we could use an UPSERT here, this SELECT could be racy + err := sqlCommonCheckFolderExists(ctx, baseFolder.Name, dbHandle) + switch err { + case nil: + err = sqlCommonUpdateFolder(baseFolder, dbHandle) if err != nil { return folder, err } - return sqlCommonCheckFolderExists(ctx, baseFolder.Name, dbHandle) + case sql.ErrNoRows: + baseFolder.UsedQuotaFiles = usedQuotaFiles + baseFolder.UsedQuotaSize = usedQuotaSize + baseFolder.LastQuotaUpdate = lastQuotaUpdate + err = sqlCommonAddFolder(baseFolder, dbHandle) + if err != nil { + return folder, err + } + default: + return folder, err } - return folder, err + + return sqlCommonGetFolder(ctx, baseFolder.Name, dbHandle) } func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error { @@ -667,6 +694,10 @@ func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) erro if err != nil { return err } + fsConfig, err := json.Marshal(folder.FsConfig) + if err != nil { + return err + } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getAddFolderQuery() @@ -677,15 +708,19 @@ func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) erro } defer stmt.Close() _, err = stmt.ExecContext(ctx, folder.MappedPath, folder.UsedQuotaSize, folder.UsedQuotaFiles, - folder.LastQuotaUpdate, folder.Name, folder.Description) + folder.LastQuotaUpdate, folder.Name, folder.Description, string(fsConfig)) return err } -func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle *sql.DB) error { +func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error { err := ValidateFolder(folder) if err != nil { return err } + fsConfig, err := json.Marshal(folder.FsConfig) + if err != nil { + return err + } ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getUpdateFolderQuery() @@ -695,7 +730,7 @@ func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle *sql.DB) erro return err } defer stmt.Close() - _, err = stmt.ExecContext(ctx, folder.MappedPath, folder.Description, folder.Name) + _, err = stmt.ExecContext(ctx, folder.MappedPath, folder.Description, string(fsConfig), folder.Name) return err } @@ -731,9 +766,9 @@ func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) defer rows.Close() for rows.Next() { var folder vfs.BaseVirtualFolder - var mappedPath, description sql.NullString + var mappedPath, description, fsConfig sql.NullString err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, - &folder.LastQuotaUpdate, &folder.Name, &description) + &folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig) if err != nil { return folders, err } @@ -743,6 +778,13 @@ func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) if description.Valid { folder.Description = description.String } + if fsConfig.Valid { + var fs vfs.Filesystem + err = json.Unmarshal([]byte(fsConfig.String), &fs) + if err == nil { + folder.FsConfig = fs + } + } folders = append(folders, folder) } err = rows.Err() @@ -771,9 +813,9 @@ func sqlCommonGetFolders(limit, offset int, order string, dbHandle sqlQuerier) ( defer rows.Close() for rows.Next() { var folder vfs.BaseVirtualFolder - var mappedPath, description sql.NullString + var mappedPath, description, fsConfig sql.NullString err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, - &folder.LastQuotaUpdate, &folder.Name, &description) + &folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig) if err != nil { return folders, err } @@ -783,6 +825,14 @@ func sqlCommonGetFolders(limit, offset int, order string, dbHandle sqlQuerier) ( if description.Valid { folder.Description = description.String } + if fsConfig.Valid { + var fs vfs.Filesystem + err = json.Unmarshal([]byte(fsConfig.String), &fs) + if err == nil { + folder.FsConfig = fs + } + } + folder.HideConfidentialData() folders = append(folders, folder) } @@ -805,7 +855,7 @@ func sqlCommonClearFolderMapping(ctx context.Context, user *User, dbHandle sqlQu return err } -func sqlCommonAddFolderMapping(ctx context.Context, user *User, folder vfs.VirtualFolder, dbHandle sqlQuerier) error { +func sqlCommonAddFolderMapping(ctx context.Context, user *User, folder *vfs.VirtualFolder, dbHandle sqlQuerier) error { q := getAddFolderMappingQuery() stmt, err := dbHandle.PrepareContext(ctx, q) if err != nil { @@ -822,8 +872,9 @@ func generateVirtualFoldersMapping(ctx context.Context, user *User, dbHandle sql if err != nil { return err } - for _, vfolder := range user.VirtualFolders { - f, err := sqlCommonAddOrGetFolder(ctx, vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle) + for idx := range user.VirtualFolders { + vfolder := &user.VirtualFolders[idx] + f, err := sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle) if err != nil { return err } @@ -870,15 +921,26 @@ func getUsersWithVirtualFolders(users []User, dbHandle sqlQuerier) ([]User, erro for rows.Next() { var folder vfs.VirtualFolder var userID int64 - var mappedPath sql.NullString + var mappedPath, fsConfig, description sql.NullString err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, - &folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &userID) + &folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &userID, &fsConfig, + &description) if err != nil { return users, err } if mappedPath.Valid { folder.MappedPath = mappedPath.String } + if description.Valid { + folder.Description = description.String + } + if fsConfig.Valid { + var fs vfs.Filesystem + err = json.Unmarshal([]byte(fsConfig.String), &fs) + if err == nil { + folder.FsConfig = fs + } + } usersVirtualFolders[userID] = append(usersVirtualFolders[userID], folder) } err = rows.Err() diff --git a/dataprovider/sqlite.go b/dataprovider/sqlite.go index 066a3d71..2eed8ca4 100644 --- a/dataprovider/sqlite.go +++ b/dataprovider/sqlite.go @@ -95,7 +95,7 @@ func initializeSQLiteProvider(basePath string) error { if config.ConnectionString == "" { dbPath := config.Name if !utils.IsFileInputValid(dbPath) { - return fmt.Errorf("Invalid database path: %#v", dbPath) + return fmt.Errorf("invalid database path: %#v", dbPath) } if !filepath.IsAbs(dbPath) { dbPath = filepath.Join(basePath, dbPath) @@ -278,7 +278,7 @@ func (p *SQLiteProvider) migrateDatabase() error { sqlDatabaseVersion) return nil } - return fmt.Errorf("Database version not handled: %v", version) + return fmt.Errorf("database version not handled: %v", version) } } @@ -295,7 +295,7 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error { case 9: return downgradeSQLiteDatabaseFromV9(p.dbHandle) default: - return fmt.Errorf("Database version not handled: %v", dbVersion.Version) + return fmt.Errorf("database version not handled: %v", dbVersion.Version) } } diff --git a/dataprovider/sqlqueries.go b/dataprovider/sqlqueries.go index 86348c04..30686242 100644 --- a/dataprovider/sqlqueries.go +++ b/dataprovider/sqlqueries.go @@ -12,7 +12,7 @@ const ( selectUserFields = "id,username,password,public_keys,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions,used_quota_size," + "used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth,expiration_date,last_login,status,filters,filesystem," + "additional_info,description" - selectFolderFields = "id,path,used_quota_size,used_quota_files,last_quota_update,name,description" + selectFolderFields = "id,path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem" selectAdminFields = "id,username,password,status,email,permissions,filters,additional_info,description" ) @@ -119,15 +119,19 @@ func getFolderByNameQuery() string { return fmt.Sprintf(`SELECT %v FROM %v WHERE name = %v`, selectFolderFields, sqlTableFolders, sqlPlaceholders[0]) } +func checkFolderNameQuery() string { + return fmt.Sprintf(`SELECT name FROM %v WHERE name = %v`, sqlTableFolders, sqlPlaceholders[0]) +} + func getAddFolderQuery() string { - return fmt.Sprintf(`INSERT INTO %v (path,used_quota_size,used_quota_files,last_quota_update,name,description) VALUES (%v,%v,%v,%v,%v,%v)`, - sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], - sqlPlaceholders[5]) + return fmt.Sprintf(`INSERT INTO %v (path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem) + VALUES (%v,%v,%v,%v,%v,%v,%v)`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], + sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6]) } func getUpdateFolderQuery() string { - return fmt.Sprintf(`UPDATE %v SET path=%v,description=%v WHERE name = %v`, sqlTableFolders, sqlPlaceholders[0], - sqlPlaceholders[1], sqlPlaceholders[2]) + return fmt.Sprintf(`UPDATE %v SET path=%v,description=%v,filesystem=%v WHERE name = %v`, sqlTableFolders, sqlPlaceholders[0], + sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) } func getDeleteFolderQuery() string { @@ -177,9 +181,9 @@ func getRelatedFoldersForUsersQuery(users []User) string { if sb.Len() > 0 { sb.WriteString(")") } - return fmt.Sprintf(`SELECT f.id,f.name,f.path,f.used_quota_size,f.used_quota_files,f.last_quota_update,fm.virtual_path,fm.quota_size,fm.quota_files,fm.user_id - FROM %v f INNER JOIN %v fm ON f.id = fm.folder_id WHERE fm.user_id IN %v ORDER BY fm.user_id`, sqlTableFolders, - sqlTableFoldersMapping, sb.String()) + return fmt.Sprintf(`SELECT f.id,f.name,f.path,f.used_quota_size,f.used_quota_files,f.last_quota_update,fm.virtual_path, + fm.quota_size,fm.quota_files,fm.user_id,f.filesystem,f.description FROM %v f INNER JOIN %v fm ON f.id = fm.folder_id WHERE + fm.user_id IN %v ORDER BY fm.user_id`, sqlTableFolders, sqlTableFoldersMapping, sb.String()) } func getRelatedUsersForFoldersQuery(folders []vfs.BaseVirtualFolder) string { diff --git a/dataprovider/user.go b/dataprovider/user.go index d4ccced5..31771629 100644 --- a/dataprovider/user.go +++ b/dataprovider/user.go @@ -161,29 +161,6 @@ type UserFilters struct { TLSUsername TLSUsername `json:"tls_username,omitempty"` } -// FilesystemProvider defines the supported storages -type FilesystemProvider int - -// supported values for FilesystemProvider -const ( - LocalFilesystemProvider FilesystemProvider = iota // Local - S3FilesystemProvider // AWS S3 compatible - GCSFilesystemProvider // Google Cloud Storage - AzureBlobFilesystemProvider // Azure Blob Storage - CryptedFilesystemProvider // Local encrypted - SFTPFilesystemProvider // SFTP -) - -// Filesystem defines cloud storage filesystem details -type Filesystem struct { - Provider FilesystemProvider `json:"provider"` - S3Config vfs.S3FsConfig `json:"s3config,omitempty"` - GCSConfig vfs.GCSFsConfig `json:"gcsconfig,omitempty"` - AzBlobConfig vfs.AzBlobFsConfig `json:"azblobconfig,omitempty"` - CryptConfig vfs.CryptFsConfig `json:"cryptconfig,omitempty"` - SFTPConfig vfs.SFTPFsConfig `json:"sftpconfig,omitempty"` -} - // User defines a SFTPGo user type User struct { // Database unique identifier @@ -203,8 +180,7 @@ type User struct { PublicKeys []string `json:"public_keys,omitempty"` // The user cannot upload or download files outside this directory. Must be an absolute path HomeDir string `json:"home_dir"` - // Mapping between virtual paths and filesystem paths outside the home directory. - // Supported for local filesystem only + // Mapping between virtual paths and virtual folders VirtualFolders []vfs.VirtualFolder `json:"virtual_folders,omitempty"` // If sftpgo runs as root system user then the created files and directories will be assigned to this system UID UID int `json:"uid"` @@ -233,49 +209,148 @@ type User struct { // Additional restrictions Filters UserFilters `json:"filters"` // Filesystem configuration details - FsConfig Filesystem `json:"filesystem"` + FsConfig vfs.Filesystem `json:"filesystem"` // optional description, for example full name Description string `json:"description,omitempty"` // free form text field for external systems AdditionalInfo string `json:"additional_info,omitempty"` + // we store the filesystem here using the base path as key. + fsCache map[string]vfs.Fs `json:"-"` } -// GetFilesystem returns the filesystem for this user -func (u *User) GetFilesystem(connectionID string) (vfs.Fs, error) { - switch u.FsConfig.Provider { - case S3FilesystemProvider: - return vfs.NewS3Fs(connectionID, u.GetHomeDir(), u.FsConfig.S3Config) - case GCSFilesystemProvider: - config := u.FsConfig.GCSConfig - config.CredentialFile = u.getGCSCredentialsFilePath() - return vfs.NewGCSFs(connectionID, u.GetHomeDir(), config) - case AzureBlobFilesystemProvider: - return vfs.NewAzBlobFs(connectionID, u.GetHomeDir(), u.FsConfig.AzBlobConfig) - case CryptedFilesystemProvider: - return vfs.NewCryptFs(connectionID, u.GetHomeDir(), u.FsConfig.CryptConfig) - case SFTPFilesystemProvider: - return vfs.NewSFTPFs(connectionID, u.FsConfig.SFTPConfig) - default: - return vfs.NewOsFs(connectionID, u.GetHomeDir(), u.VirtualFolders), nil +// GetFilesystem returns the base filesystem for this user +func (u *User) GetFilesystem(connectionID string) (fs vfs.Fs, err error) { + fs, err = u.getRootFs(connectionID) + if err != nil { + return fs, err } + u.fsCache = make(map[string]vfs.Fs) + u.fsCache["/"] = fs + return fs, err +} + +func (u *User) getRootFs(connectionID string) (fs vfs.Fs, err error) { + switch u.FsConfig.Provider { + case vfs.S3FilesystemProvider: + return vfs.NewS3Fs(connectionID, u.GetHomeDir(), "", u.FsConfig.S3Config) + case vfs.GCSFilesystemProvider: + config := u.FsConfig.GCSConfig + config.CredentialFile = u.GetGCSCredentialsFilePath() + return vfs.NewGCSFs(connectionID, u.GetHomeDir(), "", config) + case vfs.AzureBlobFilesystemProvider: + return vfs.NewAzBlobFs(connectionID, u.GetHomeDir(), "", u.FsConfig.AzBlobConfig) + case vfs.CryptedFilesystemProvider: + return vfs.NewCryptFs(connectionID, u.GetHomeDir(), "", u.FsConfig.CryptConfig) + case vfs.SFTPFilesystemProvider: + return vfs.NewSFTPFs(connectionID, "", u.FsConfig.SFTPConfig) + default: + return vfs.NewOsFs(connectionID, u.GetHomeDir(), ""), nil + } +} + +// CheckFsRoot check the root directory for the main fs and the virtual folders. +// It returns an error if the main filesystem cannot be created +func (u *User) CheckFsRoot(connectionID string) error { + fs, err := u.GetFilesystemForPath("/", connectionID) + if err != nil { + logger.Warn(logSender, connectionID, "could not create main filesystem for user %#v err: %v", u.Username, err) + return err + } + fs.CheckRootPath(u.Username, u.GetUID(), u.GetGID()) + for idx := range u.VirtualFolders { + v := &u.VirtualFolders[idx] + fs, err = u.GetFilesystemForPath(v.VirtualPath, connectionID) + if err == nil { + fs.CheckRootPath(u.Username, u.GetUID(), u.GetGID()) + } + // now check intermediary folders + fs, err = u.GetFilesystemForPath(path.Dir(v.VirtualPath), connectionID) + if err == nil && !fs.HasVirtualFolders() { + fsPath, err := fs.ResolvePath(v.VirtualPath) + if err != nil { + continue + } + err = fs.MkdirAll(fsPath, u.GetUID(), u.GetGID()) + logger.Debug(logSender, connectionID, "create intermediary dir to %#v, path %#v, err: %v", + v.VirtualPath, fsPath, err) + } + } + return nil } // HideConfidentialData hides user confidential data func (u *User) HideConfidentialData() { u.Password = "" switch u.FsConfig.Provider { - case S3FilesystemProvider: + case vfs.S3FilesystemProvider: u.FsConfig.S3Config.AccessSecret.Hide() - case GCSFilesystemProvider: + case vfs.GCSFilesystemProvider: u.FsConfig.GCSConfig.Credentials.Hide() - case AzureBlobFilesystemProvider: + case vfs.AzureBlobFilesystemProvider: u.FsConfig.AzBlobConfig.AccountKey.Hide() - case CryptedFilesystemProvider: + case vfs.CryptedFilesystemProvider: u.FsConfig.CryptConfig.Passphrase.Hide() - case SFTPFilesystemProvider: + case vfs.SFTPFilesystemProvider: u.FsConfig.SFTPConfig.Password.Hide() u.FsConfig.SFTPConfig.PrivateKey.Hide() } + for idx := range u.VirtualFolders { + folder := &u.VirtualFolders[idx] + folder.HideConfidentialData() + } +} + +func (u *User) hasRedactedSecret() bool { + switch u.FsConfig.Provider { + case vfs.S3FilesystemProvider: + if u.FsConfig.S3Config.AccessSecret.IsRedacted() { + return true + } + case vfs.GCSFilesystemProvider: + if u.FsConfig.GCSConfig.Credentials.IsRedacted() { + return true + } + case vfs.AzureBlobFilesystemProvider: + if u.FsConfig.AzBlobConfig.AccountKey.IsRedacted() { + return true + } + case vfs.CryptedFilesystemProvider: + if u.FsConfig.CryptConfig.Passphrase.IsRedacted() { + return true + } + case vfs.SFTPFilesystemProvider: + if u.FsConfig.SFTPConfig.Password.IsRedacted() { + return true + } + if u.FsConfig.SFTPConfig.PrivateKey.IsRedacted() { + return true + } + } + + for idx := range u.VirtualFolders { + folder := &u.VirtualFolders[idx] + if folder.HasRedactedSecret() { + return true + } + } + + return false +} + +// CloseFs closes the underlying filesystems +func (u *User) CloseFs() error { + if u.fsCache == nil { + return nil + } + + var err error + for _, fs := range u.fsCache { + errClose := fs.Close() + if err == nil { + err = errClose + } + } + return err } // IsPasswordHashed returns true if the password is hashed @@ -300,41 +375,10 @@ func (u *User) SetEmptySecrets() { u.FsConfig.CryptConfig.Passphrase = kms.NewEmptySecret() u.FsConfig.SFTPConfig.Password = kms.NewEmptySecret() u.FsConfig.SFTPConfig.PrivateKey = kms.NewEmptySecret() -} - -// DecryptSecrets tries to decrypts kms secrets -func (u *User) DecryptSecrets() error { - switch u.FsConfig.Provider { - case S3FilesystemProvider: - if u.FsConfig.S3Config.AccessSecret.IsEncrypted() { - return u.FsConfig.S3Config.AccessSecret.Decrypt() - } - case GCSFilesystemProvider: - if u.FsConfig.GCSConfig.Credentials.IsEncrypted() { - return u.FsConfig.GCSConfig.Credentials.Decrypt() - } - case AzureBlobFilesystemProvider: - if u.FsConfig.AzBlobConfig.AccountKey.IsEncrypted() { - return u.FsConfig.AzBlobConfig.AccountKey.Decrypt() - } - case CryptedFilesystemProvider: - if u.FsConfig.CryptConfig.Passphrase.IsEncrypted() { - return u.FsConfig.CryptConfig.Passphrase.Decrypt() - } - case SFTPFilesystemProvider: - if u.FsConfig.SFTPConfig.Password.IsEncrypted() { - if err := u.FsConfig.SFTPConfig.Password.Decrypt(); err != nil { - return err - } - } - if u.FsConfig.SFTPConfig.PrivateKey.IsEncrypted() { - if err := u.FsConfig.SFTPConfig.PrivateKey.Decrypt(); err != nil { - return err - } - } + for idx := range u.VirtualFolders { + folder := &u.VirtualFolders[idx] + folder.FsConfig.SetEmptySecretsIfNil() } - - return nil } // GetPermissionsForPath returns the permissions for the given path. @@ -349,13 +393,13 @@ func (u *User) GetPermissionsForPath(p string) []string { // fallback permissions permissions = perms } - dirsForPath := utils.GetDirsForSFTPPath(p) + dirsForPath := utils.GetDirsForVirtualPath(p) // dirsForPath contains all the dirs for a given path in reverse order // for example if the path is: /1/2/3/4 it contains: // [ "/1/2/3/4", "/1/2/3", "/1/2", "/1", "/" ] // so the first match is the one we are interested to - for _, val := range dirsForPath { - if perms, ok := u.Permissions[val]; ok { + for idx := range dirsForPath { + if perms, ok := u.Permissions[dirsForPath[idx]]; ok { permissions = perms break } @@ -363,44 +407,120 @@ func (u *User) GetPermissionsForPath(p string) []string { return permissions } -// GetVirtualFolderForPath returns the virtual folder containing the specified sftp path. +// GetFilesystemForPath returns the filesystem for the given path +func (u *User) GetFilesystemForPath(virtualPath, connectionID string) (vfs.Fs, error) { + if u.fsCache == nil { + u.fsCache = make(map[string]vfs.Fs) + } + if virtualPath != "" && virtualPath != "/" && len(u.VirtualFolders) > 0 { + folder, err := u.GetVirtualFolderForPath(virtualPath) + if err == nil { + if fs, ok := u.fsCache[folder.VirtualPath]; ok { + return fs, nil + } + fs, err := folder.GetFilesystem(connectionID) + if err == nil { + u.fsCache[folder.VirtualPath] = fs + } + return fs, err + } + } + + if val, ok := u.fsCache["/"]; ok { + return val, nil + } + + return u.GetFilesystem(connectionID) +} + +// GetVirtualFolderForPath returns the virtual folder containing the specified virtual path. // If the path is not inside a virtual folder an error is returned -func (u *User) GetVirtualFolderForPath(sftpPath string) (vfs.VirtualFolder, error) { +func (u *User) GetVirtualFolderForPath(virtualPath string) (vfs.VirtualFolder, error) { var folder vfs.VirtualFolder - if len(u.VirtualFolders) == 0 || u.FsConfig.Provider != LocalFilesystemProvider { + if len(u.VirtualFolders) == 0 { return folder, errNoMatchingVirtualFolder } - dirsForPath := utils.GetDirsForSFTPPath(sftpPath) - for _, val := range dirsForPath { - for _, v := range u.VirtualFolders { - if v.VirtualPath == val { - return v, nil + dirsForPath := utils.GetDirsForVirtualPath(virtualPath) + for index := range dirsForPath { + for idx := range u.VirtualFolders { + v := &u.VirtualFolders[idx] + if v.VirtualPath == dirsForPath[index] { + return *v, nil } } } return folder, errNoMatchingVirtualFolder } +// ScanQuota scans the user home dir and virtual folders, included in its quota, +// and returns the number of files and their size +func (u *User) ScanQuota() (int, int64, error) { + fs, err := u.getRootFs("") + if err != nil { + return 0, 0, err + } + defer fs.Close() + numFiles, size, err := fs.ScanRootDirContents() + if err != nil { + return numFiles, size, err + } + for idx := range u.VirtualFolders { + v := &u.VirtualFolders[idx] + if !v.IsIncludedInUserQuota() { + continue + } + num, s, err := v.ScanQuota() + if err != nil { + return numFiles, size, err + } + numFiles += num + size += s + } + + return numFiles, size, nil +} + +// GetVirtualFoldersInPath returns the virtual folders inside virtualPath including +// any parents +func (u *User) GetVirtualFoldersInPath(virtualPath string) map[string]bool { + result := make(map[string]bool) + + for idx := range u.VirtualFolders { + v := &u.VirtualFolders[idx] + dirsForPath := utils.GetDirsForVirtualPath(v.VirtualPath) + for index := range dirsForPath { + d := dirsForPath[index] + if d == "/" { + continue + } + if path.Dir(d) == virtualPath { + result[d] = true + } + } + } + + return result +} + // AddVirtualDirs adds virtual folders, if defined, to the given files list -func (u *User) AddVirtualDirs(list []os.FileInfo, sftpPath string) []os.FileInfo { +func (u *User) AddVirtualDirs(list []os.FileInfo, virtualPath string) []os.FileInfo { if len(u.VirtualFolders) == 0 { return list } - for _, v := range u.VirtualFolders { - if path.Dir(v.VirtualPath) == sftpPath { - fi := vfs.NewFileInfo(v.VirtualPath, true, 0, time.Now(), false) - found := false - for index, f := range list { - if f.Name() == fi.Name() { - list[index] = fi - found = true - break - } - } - if !found { - list = append(list, fi) + + for dir := range u.GetVirtualFoldersInPath(virtualPath) { + fi := vfs.NewFileInfo(dir, true, 0, time.Now(), false) + found := false + for index := range list { + if list[index].Name() == fi.Name() { + list[index] = fi + found = true + break } } + if !found { + list = append(list, fi) + } } return list } @@ -408,7 +528,8 @@ func (u *User) AddVirtualDirs(list []os.FileInfo, sftpPath string) []os.FileInfo // IsMappedPath returns true if the specified filesystem path has a virtual folder mapping. // The filesystem path must be cleaned before calling this method func (u *User) IsMappedPath(fsPath string) bool { - for _, v := range u.VirtualFolders { + for idx := range u.VirtualFolders { + v := &u.VirtualFolders[idx] if fsPath == v.MappedPath { return true } @@ -416,10 +537,11 @@ func (u *User) IsMappedPath(fsPath string) bool { return false } -// IsVirtualFolder returns true if the specified sftp path is a virtual folder -func (u *User) IsVirtualFolder(sftpPath string) bool { - for _, v := range u.VirtualFolders { - if sftpPath == v.VirtualPath { +// IsVirtualFolder returns true if the specified virtual path is a virtual folder +func (u *User) IsVirtualFolder(virtualPath string) bool { + for idx := range u.VirtualFolders { + v := &u.VirtualFolders[idx] + if virtualPath == v.VirtualPath { return true } } @@ -427,14 +549,15 @@ func (u *User) IsVirtualFolder(sftpPath string) bool { } // HasVirtualFoldersInside returns true if there are virtual folders inside the -// specified SFTP path. We assume that path are cleaned -func (u *User) HasVirtualFoldersInside(sftpPath string) bool { - if sftpPath == "/" && len(u.VirtualFolders) > 0 { +// specified virtual path. We assume that path are cleaned +func (u *User) HasVirtualFoldersInside(virtualPath string) bool { + if virtualPath == "/" && len(u.VirtualFolders) > 0 { return true } - for _, v := range u.VirtualFolders { - if len(v.VirtualPath) > len(sftpPath) { - if strings.HasPrefix(v.VirtualPath, sftpPath+"/") { + for idx := range u.VirtualFolders { + v := &u.VirtualFolders[idx] + if len(v.VirtualPath) > len(virtualPath) { + if strings.HasPrefix(v.VirtualPath, virtualPath+"/") { return true } } @@ -442,32 +565,14 @@ func (u *User) HasVirtualFoldersInside(sftpPath string) bool { return false } -// HasPermissionsInside returns true if the specified sftpPath has no permissions itself and +// HasPermissionsInside returns true if the specified virtualPath has no permissions itself and // no subdirs with defined permissions -func (u *User) HasPermissionsInside(sftpPath string) bool { +func (u *User) HasPermissionsInside(virtualPath string) bool { for dir := range u.Permissions { - if dir == sftpPath { + if dir == virtualPath { return true - } else if len(dir) > len(sftpPath) { - if strings.HasPrefix(dir, sftpPath+"/") { - return true - } - } - } - return false -} - -// HasOverlappedMappedPaths returns true if this user has virtual folders with overlapped mapped paths -func (u *User) HasOverlappedMappedPaths() bool { - if len(u.VirtualFolders) <= 1 { - return false - } - for _, v1 := range u.VirtualFolders { - for _, v2 := range u.VirtualFolders { - if v1.VirtualPath == v2.VirtualPath { - continue - } - if isMappedDirOverlapped(v1.MappedPath, v2.MappedPath) { + } else if len(dir) > len(virtualPath) { + if strings.HasPrefix(dir, virtualPath+"/") { return true } } @@ -585,7 +690,7 @@ func (u *User) isFileExtensionAllowed(virtualPath string) bool { if len(u.Filters.FileExtensions) == 0 { return true } - dirsForPath := utils.GetDirsForSFTPPath(path.Dir(virtualPath)) + dirsForPath := utils.GetDirsForVirtualPath(path.Dir(virtualPath)) var filter ExtensionsFilter for _, dir := range dirsForPath { for _, f := range u.Filters.FileExtensions { @@ -619,7 +724,7 @@ func (u *User) isFilePatternAllowed(virtualPath string) bool { if len(u.Filters.FilePatterns) == 0 { return true } - dirsForPath := utils.GetDirsForSFTPPath(path.Dir(virtualPath)) + dirsForPath := utils.GetDirsForVirtualPath(path.Dir(virtualPath)) var filter PatternsFilter for _, dir := range dirsForPath { for _, f := range u.Filters.FilePatterns { @@ -797,15 +902,15 @@ func (u *User) GetInfoString() string { result += fmt.Sprintf("Last login: %v ", t.Format("2006-01-02 15:04:05")) // YYYY-MM-DD HH:MM:SS } switch u.FsConfig.Provider { - case S3FilesystemProvider: + case vfs.S3FilesystemProvider: result += "Storage: S3 " - case GCSFilesystemProvider: + case vfs.GCSFilesystemProvider: result += "Storage: GCS " - case AzureBlobFilesystemProvider: + case vfs.AzureBlobFilesystemProvider: result += "Storage: Azure " - case CryptedFilesystemProvider: + case vfs.CryptedFilesystemProvider: result += "Storage: Encrypted " - case SFTPFilesystemProvider: + case vfs.SFTPFilesystemProvider: result += "Storage: SFTP " } if len(u.PublicKeys) > 0 { @@ -850,23 +955,10 @@ func (u *User) GetDeniedIPAsString() string { // SetEmptySecretsIfNil sets the secrets to empty if nil func (u *User) SetEmptySecretsIfNil() { - if u.FsConfig.S3Config.AccessSecret == nil { - u.FsConfig.S3Config.AccessSecret = kms.NewEmptySecret() - } - if u.FsConfig.GCSConfig.Credentials == nil { - u.FsConfig.GCSConfig.Credentials = kms.NewEmptySecret() - } - if u.FsConfig.AzBlobConfig.AccountKey == nil { - u.FsConfig.AzBlobConfig.AccountKey = kms.NewEmptySecret() - } - if u.FsConfig.CryptConfig.Passphrase == nil { - u.FsConfig.CryptConfig.Passphrase = kms.NewEmptySecret() - } - if u.FsConfig.SFTPConfig.Password == nil { - u.FsConfig.SFTPConfig.Password = kms.NewEmptySecret() - } - if u.FsConfig.SFTPConfig.PrivateKey == nil { - u.FsConfig.SFTPConfig.PrivateKey = kms.NewEmptySecret() + u.FsConfig.SetEmptySecretsIfNil() + for idx := range u.VirtualFolders { + vfolder := &u.VirtualFolders[idx] + vfolder.FsConfig.SetEmptySecretsIfNil() } } @@ -874,8 +966,11 @@ func (u *User) getACopy() User { u.SetEmptySecretsIfNil() pubKeys := make([]string, len(u.PublicKeys)) copy(pubKeys, u.PublicKeys) - virtualFolders := make([]vfs.VirtualFolder, len(u.VirtualFolders)) - copy(virtualFolders, u.VirtualFolders) + virtualFolders := make([]vfs.VirtualFolder, 0, len(u.VirtualFolders)) + for idx := range u.VirtualFolders { + vfolder := u.VirtualFolders[idx].GetACopy() + virtualFolders = append(virtualFolders, vfolder) + } permissions := make(map[string][]string) for k, v := range u.Permissions { perms := make([]string, len(v)) @@ -897,55 +992,6 @@ func (u *User) getACopy() User { copy(filters.FilePatterns, u.Filters.FilePatterns) filters.DeniedProtocols = make([]string, len(u.Filters.DeniedProtocols)) copy(filters.DeniedProtocols, u.Filters.DeniedProtocols) - fsConfig := Filesystem{ - Provider: u.FsConfig.Provider, - S3Config: vfs.S3FsConfig{ - Bucket: u.FsConfig.S3Config.Bucket, - Region: u.FsConfig.S3Config.Region, - AccessKey: u.FsConfig.S3Config.AccessKey, - AccessSecret: u.FsConfig.S3Config.AccessSecret.Clone(), - Endpoint: u.FsConfig.S3Config.Endpoint, - StorageClass: u.FsConfig.S3Config.StorageClass, - KeyPrefix: u.FsConfig.S3Config.KeyPrefix, - UploadPartSize: u.FsConfig.S3Config.UploadPartSize, - UploadConcurrency: u.FsConfig.S3Config.UploadConcurrency, - }, - GCSConfig: vfs.GCSFsConfig{ - Bucket: u.FsConfig.GCSConfig.Bucket, - CredentialFile: u.FsConfig.GCSConfig.CredentialFile, - Credentials: u.FsConfig.GCSConfig.Credentials.Clone(), - AutomaticCredentials: u.FsConfig.GCSConfig.AutomaticCredentials, - StorageClass: u.FsConfig.GCSConfig.StorageClass, - KeyPrefix: u.FsConfig.GCSConfig.KeyPrefix, - }, - AzBlobConfig: vfs.AzBlobFsConfig{ - Container: u.FsConfig.AzBlobConfig.Container, - AccountName: u.FsConfig.AzBlobConfig.AccountName, - AccountKey: u.FsConfig.AzBlobConfig.AccountKey.Clone(), - Endpoint: u.FsConfig.AzBlobConfig.Endpoint, - SASURL: u.FsConfig.AzBlobConfig.SASURL, - KeyPrefix: u.FsConfig.AzBlobConfig.KeyPrefix, - UploadPartSize: u.FsConfig.AzBlobConfig.UploadPartSize, - UploadConcurrency: u.FsConfig.AzBlobConfig.UploadConcurrency, - UseEmulator: u.FsConfig.AzBlobConfig.UseEmulator, - AccessTier: u.FsConfig.AzBlobConfig.AccessTier, - }, - CryptConfig: vfs.CryptFsConfig{ - Passphrase: u.FsConfig.CryptConfig.Passphrase.Clone(), - }, - SFTPConfig: vfs.SFTPFsConfig{ - Endpoint: u.FsConfig.SFTPConfig.Endpoint, - Username: u.FsConfig.SFTPConfig.Username, - Password: u.FsConfig.SFTPConfig.Password.Clone(), - PrivateKey: u.FsConfig.SFTPConfig.PrivateKey.Clone(), - Prefix: u.FsConfig.SFTPConfig.Prefix, - DisableCouncurrentReads: u.FsConfig.SFTPConfig.DisableCouncurrentReads, - }, - } - if len(u.FsConfig.SFTPConfig.Fingerprints) > 0 { - fsConfig.SFTPConfig.Fingerprints = make([]string, len(u.FsConfig.SFTPConfig.Fingerprints)) - copy(fsConfig.SFTPConfig.Fingerprints, u.FsConfig.SFTPConfig.Fingerprints) - } return User{ ID: u.ID, @@ -969,7 +1015,7 @@ func (u *User) getACopy() User { ExpirationDate: u.ExpirationDate, LastLogin: u.LastLogin, Filters: filters, - FsConfig: fsConfig, + FsConfig: u.FsConfig.GetACopy(), AdditionalInfo: u.AdditionalInfo, Description: u.Description, } @@ -986,6 +1032,12 @@ func (u *User) getNotificationFieldsAsSlice(action string) []string { } } -func (u *User) getGCSCredentialsFilePath() string { +// GetEncrytionAdditionalData returns the additional data to use for AEAD +func (u *User) GetEncrytionAdditionalData() string { + return u.Username +} + +// GetGCSCredentialsFilePath returns the path for GCS credentials +func (u *User) GetGCSCredentialsFilePath() string { return filepath.Join(credentialsDirPath, fmt.Sprintf("%v_gcs_credentials.json", u.Username)) } diff --git a/docs/ssh-commands.md b/docs/ssh-commands.md index 1dffc964..9529974e 100644 --- a/docs/ssh-commands.md +++ b/docs/ssh-commands.md @@ -36,7 +36,7 @@ SFTPGo supports the following built-in SSH commands: - `md5sum`, `sha1sum`, `sha256sum`, `sha384sum`, `sha512sum`. Useful to check message digests for uploaded files. - `cd`, `pwd`. Some SFTP clients do not support the SFTP SSH_FXP_REALPATH packet type, so they use `cd` and `pwd` SSH commands to get the initial directory. Currently `cd` does nothing and `pwd` always returns the `/` path. These commands will work with any storage backend but keep in mind that to calculate the hash we need to read the whole file, for remote backends this means downloading the file, for the encrypted backend this means decrypting the file. - `sftpgo-copy`. This is a built-in copy implementation. It allows server side copy for files and directories. The first argument is the source file/directory and the second one is the destination file/directory, for example `sftpgo-copy `. The command will fail if the destination exists. Copy for directories spanning virtual folders is not supported. Only local filesystem is supported: recursive copy for Cloud Storage filesystems requires a new request for every file in any case, so a real server side copy is not possible. -- `sftpgo-remove`. This is a built-in remove implementation. It allows to remove single files and to recursively remove directories. The first argument is the file/directory to remove, for example `sftpgo-remove `. Only local filesystem is supported: recursive remove for Cloud Storage filesystems requires a new request for every file in any case, so a server side remove is not possible. +- `sftpgo-remove`. This is a built-in remove implementation. It allows to remove single files and to recursively remove directories. The first argument is the file/directory to remove, for example `sftpgo-remove `. Only local and encrypted filesystems are supported: recursive remove for Cloud Storage filesystems requires a new request for every file in any case, so a server side remove is not possible. The following SSH commands are enabled by default: diff --git a/ftpd/cryptfs_test.go b/ftpd/cryptfs_test.go index f760139b..7324cb61 100644 --- a/ftpd/cryptfs_test.go +++ b/ftpd/cryptfs_test.go @@ -15,6 +15,7 @@ import ( "github.com/drakkan/sftpgo/dataprovider" "github.com/drakkan/sftpgo/httpdtest" "github.com/drakkan/sftpgo/kms" + "github.com/drakkan/sftpgo/vfs" ) func TestBasicFTPHandlingCryptFs(t *testing.T) { @@ -214,7 +215,7 @@ func TestResumeCryptFs(t *testing.T) { func getTestUserWithCryptFs() dataprovider.User { user := getTestUser() - user.FsConfig.Provider = dataprovider.CryptedFilesystemProvider + user.FsConfig.Provider = vfs.CryptedFilesystemProvider user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("testPassphrase") return user } diff --git a/ftpd/ftpd_test.go b/ftpd/ftpd_test.go index 005961bc..60cb7a62 100644 --- a/ftpd/ftpd_test.go +++ b/ftpd/ftpd_test.go @@ -245,6 +245,10 @@ func TestMain(m *testing.M) { if err != nil { logger.ErrorToConsole("error creating banner file: %v", err) } + // we run the test cases with UploadMode atomic and resume support. The non atomic code path + // simply does not execute some code so if it works in atomic mode will + // work in non atomic mode too + os.Setenv("SFTPGO_COMMON__UPLOAD_MODE", "2") err = config.LoadConfig(configDir, "") if err != nil { logger.ErrorToConsole("error loading configuration: %v", err) @@ -254,10 +258,6 @@ func TestMain(m *testing.M) { logger.InfoToConsole("Starting FTPD tests, provider: %v", providerConf.Driver) commonConf := config.GetCommonConfig() - // we run the test cases with UploadMode atomic and resume support. The non atomic code path - // simply does not execute some code so if it works in atomic mode will - // work in non atomic mode too - commonConf.UploadMode = 2 homeBasePath = os.TempDir() if runtime.GOOS != osWindows { commonConf.Actions.ExecuteOn = []string{"download", "upload", "rename", "delete"} @@ -1274,7 +1274,7 @@ func TestLoginWithIPilters(t *testing.T) { func TestLoginWithDatabaseCredentials(t *testing.T) { u := getTestUser() - u.FsConfig.Provider = dataprovider.GCSFilesystemProvider + u.FsConfig.Provider = vfs.GCSFilesystemProvider u.FsConfig.GCSConfig.Bucket = "test" u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(`{ "type": "service_account" }`) @@ -1323,7 +1323,7 @@ func TestLoginWithDatabaseCredentials(t *testing.T) { func TestLoginInvalidFs(t *testing.T) { u := getTestUser() - u.FsConfig.Provider = dataprovider.GCSFilesystemProvider + u.FsConfig.Provider = vfs.GCSFilesystemProvider u.FsConfig.GCSConfig.Bucket = "test" u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials") user, _, err := httpdtest.AddUser(u, http.StatusCreated) @@ -2153,7 +2153,7 @@ func TestClientCertificateAuth(t *testing.T) { // TLS username is not enabled, mutual TLS should fail _, err = getFTPClient(user, true, tlsConfig) if assert.Error(t, err) { - assert.Contains(t, err.Error(), "Login method password is not allowed") + assert.Contains(t, err.Error(), "login method password is not allowed") } user.Filters.TLSUsername = dataprovider.TLSUsernameCN @@ -2186,7 +2186,7 @@ func TestClientCertificateAuth(t *testing.T) { assert.NoError(t, err) _, err = getFTPClient(user, true, tlsConfig) if assert.Error(t, err) { - assert.Contains(t, err.Error(), "Login method TLSCertificate+password is not allowed") + assert.Contains(t, err.Error(), "login method TLSCertificate+password is not allowed") } // disable FTP protocol @@ -2196,7 +2196,7 @@ func TestClientCertificateAuth(t *testing.T) { assert.NoError(t, err) _, err = getFTPClient(user, true, tlsConfig) if assert.Error(t, err) { - assert.Contains(t, err.Error(), "Protocol FTP is not allowed") + assert.Contains(t, err.Error(), "protocol FTP is not allowed") } _, err = httpdtest.RemoveUser(user, http.StatusOK) @@ -2238,7 +2238,7 @@ func TestClientCertificateAndPwdAuth(t *testing.T) { _, err = getFTPClient(user, true, nil) if assert.Error(t, err) { - assert.Contains(t, err.Error(), "Login method password is not allowed") + assert.Contains(t, err.Error(), "login method password is not allowed") } user.Password = defaultPassword + "1" _, err = getFTPClient(user, true, tlsConfig) @@ -2405,6 +2405,111 @@ func TestPreLoginHookWithClientCert(t *testing.T) { assert.NoError(t, err) } +func TestNestedVirtualFolders(t *testing.T) { + u := getTestUser() + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + mappedPathCrypt := filepath.Join(os.TempDir(), "crypt") + folderNameCrypt := filepath.Base(mappedPathCrypt) + vdirCryptPath := "/vdir/crypt" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameCrypt, + FsConfig: vfs.Filesystem{ + Provider: vfs.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret(defaultPassword), + }, + }, + MappedPath: mappedPathCrypt, + }, + VirtualPath: vdirCryptPath, + }) + mappedPath := filepath.Join(os.TempDir(), "local") + folderName := filepath.Base(mappedPath) + vdirPath := "/vdir/local" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + }, + VirtualPath: vdirPath, + }) + mappedPathNested := filepath.Join(os.TempDir(), "nested") + folderNameNested := filepath.Base(mappedPathNested) + vdirNestedPath := "/vdir/crypt/nested" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameNested, + MappedPath: mappedPathNested, + }, + VirtualPath: vdirNestedPath, + QuotaFiles: -1, + QuotaSize: -1, + }) + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(sftpUser, false, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, path.Join("/vdir", testFileName), testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(path.Join("/vdir", testFileName), localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, path.Join(vdirPath, testFileName), testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(path.Join(vdirPath, testFileName), localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, path.Join(vdirCryptPath, testFileName), testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(path.Join(vdirCryptPath, testFileName), localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, path.Join(vdirNestedPath, testFileName), testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(path.Join(vdirNestedPath, testFileName), localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + + err = client.Quit() + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameNested}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathCrypt) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathNested) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats()) == 0 }, 1*time.Second, 50*time.Millisecond) +} + func checkBasicFTP(client *ftp.ServerConn) error { _, err := client.CurrentDir() if err != nil { @@ -2562,7 +2667,7 @@ func getTestUser() dataprovider.User { func getTestSFTPUser() dataprovider.User { u := getTestUser() u.Username = u.Username + "_sftp" - u.FsConfig.Provider = dataprovider.SFTPFilesystemProvider + u.FsConfig.Provider = vfs.SFTPFilesystemProvider u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr u.FsConfig.SFTPConfig.Username = defaultUsername u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword) diff --git a/ftpd/handler.go b/ftpd/handler.go index e6b1e490..d6a696e5 100644 --- a/ftpd/handler.go +++ b/ftpd/handler.go @@ -17,7 +17,7 @@ import ( ) var ( - errNotImplemented = errors.New("Not implemented") + errNotImplemented = errors.New("not implemented") errCOMBNotSupported = errors.New("COMB is not supported for this filesystem") ) @@ -63,11 +63,7 @@ func (c *Connection) Create(name string) (afero.File, error) { func (c *Connection) Mkdir(name string, perm os.FileMode) error { c.UpdateLastActivity() - p, err := c.Fs.ResolvePath(name) - if err != nil { - return c.GetFsError(err) - } - return c.CreateDir(p, name) + return c.CreateDir(name) } // MkdirAll is not implemented, we don't need it @@ -90,22 +86,22 @@ func (c *Connection) OpenFile(name string, flag int, perm os.FileMode) (afero.Fi func (c *Connection) Remove(name string) error { c.UpdateLastActivity() - p, err := c.Fs.ResolvePath(name) + fs, p, err := c.GetFsAndResolvedPath(name) if err != nil { - return c.GetFsError(err) + return err } var fi os.FileInfo - if fi, err = c.Fs.Lstat(p); err != nil { + if fi, err = fs.Lstat(p); err != nil { c.Log(logger.LevelWarn, "failed to remove a file %#v: stat error: %+v", p, err) - return c.GetFsError(err) + return c.GetFsError(fs, err) } if fi.IsDir() && fi.Mode()&os.ModeSymlink == 0 { c.Log(logger.LevelDebug, "cannot remove %#v is not a file/symlink", p) return c.GetGenericError(nil) } - return c.RemoveFile(p, name, fi) + return c.RemoveFile(fs, p, name, fi) } // RemoveAll is not implemented, we don't need it @@ -117,20 +113,10 @@ func (c *Connection) RemoveAll(path string) error { func (c *Connection) Rename(oldname, newname string) error { c.UpdateLastActivity() - p, err := c.Fs.ResolvePath(oldname) - if err != nil { - return c.GetFsError(err) - } - t, err := c.Fs.ResolvePath(newname) - if err != nil { - return c.GetFsError(err) - } - - if err = c.BaseConnection.Rename(p, t, oldname, newname); err != nil { + if err := c.BaseConnection.Rename(oldname, newname); err != nil { return err } - vfs.SetPathPermissions(c.Fs, t, c.User.GetUID(), c.User.GetGID()) return nil } @@ -143,14 +129,10 @@ func (c *Connection) Stat(name string) (os.FileInfo, error) { return nil, c.GetPermissionDeniedError() } - p, err := c.Fs.ResolvePath(name) + fi, err := c.DoStat(name, 0) if err != nil { - return nil, c.GetFsError(err) - } - fi, err := c.DoStat(p, 0) - if err != nil { - c.Log(logger.LevelDebug, "error running stat on path %#v: %+v", p, err) - return nil, c.GetFsError(err) + c.Log(logger.LevelDebug, "error running stat on path %#v: %+v", name, err) + return nil, err } return fi, nil } @@ -182,31 +164,23 @@ func (c *Connection) Chown(name string, uid, gid int) error { func (c *Connection) Chmod(name string, mode os.FileMode) error { c.UpdateLastActivity() - p, err := c.Fs.ResolvePath(name) - if err != nil { - return c.GetFsError(err) - } attrs := common.StatAttributes{ Flags: common.StatAttrPerms, Mode: mode, } - return c.SetStat(p, name, &attrs) + return c.SetStat(name, &attrs) } // Chtimes changes the access and modification times of the named file func (c *Connection) Chtimes(name string, atime time.Time, mtime time.Time) error { c.UpdateLastActivity() - p, err := c.Fs.ResolvePath(name) - if err != nil { - return c.GetFsError(err) - } attrs := common.StatAttributes{ Flags: common.StatAttrTimes, Atime: atime, Mtime: mtime, } - return c.SetStat(p, name, &attrs) + return c.SetStat(name, &attrs) } // GetAvailableSpace implements ClientDriverExtensionAvailableSpace interface @@ -224,14 +198,14 @@ func (c *Connection) GetAvailableSpace(dirName string) (int64, error) { return c.User.Filters.MaxUploadFileSize, nil } - p, err := c.Fs.ResolvePath(dirName) + fs, p, err := c.GetFsAndResolvedPath(dirName) if err != nil { - return 0, c.GetFsError(err) + return 0, err } - statVFS, err := c.Fs.GetAvailableDiskSize(p) + statVFS, err := fs.GetAvailableDiskSize(p) if err != nil { - return 0, c.GetFsError(err) + return 0, c.GetFsError(fs, err) } return int64(statVFS.FreeSpace()), nil } @@ -281,61 +255,43 @@ func (c *Connection) AllocateSpace(size int) error { func (c *Connection) RemoveDir(name string) error { c.UpdateLastActivity() - p, err := c.Fs.ResolvePath(name) - if err != nil { - return c.GetFsError(err) - } - - return c.BaseConnection.RemoveDir(p, name) + return c.BaseConnection.RemoveDir(name) } // Symlink implements ClientDriverExtensionSymlink func (c *Connection) Symlink(oldname, newname string) error { c.UpdateLastActivity() - p, err := c.Fs.ResolvePath(oldname) - if err != nil { - return c.GetFsError(err) - } - t, err := c.Fs.ResolvePath(newname) - if err != nil { - return c.GetFsError(err) - } - - return c.BaseConnection.CreateSymlink(p, t, oldname, newname) + return c.BaseConnection.CreateSymlink(oldname, newname) } // ReadDir implements ClientDriverExtensionFilelist func (c *Connection) ReadDir(name string) ([]os.FileInfo, error) { c.UpdateLastActivity() - p, err := c.Fs.ResolvePath(name) - if err != nil { - return nil, c.GetFsError(err) - } - return c.ListDir(p, name) + return c.ListDir(name) } // GetHandle implements ClientDriverExtentionFileTransfer func (c *Connection) GetHandle(name string, flags int, offset int64) (ftpserver.FileTransfer, error) { c.UpdateLastActivity() - p, err := c.Fs.ResolvePath(name) + fs, p, err := c.GetFsAndResolvedPath(name) if err != nil { - return nil, c.GetFsError(err) + return nil, err } - if c.GetCommand() == "COMB" && !vfs.IsLocalOsFs(c.Fs) { + if c.GetCommand() == "COMB" && !vfs.IsLocalOsFs(fs) { return nil, errCOMBNotSupported } if flags&os.O_WRONLY != 0 { - return c.uploadFile(p, name, flags) + return c.uploadFile(fs, p, name, flags) } - return c.downloadFile(p, name, offset) + return c.downloadFile(fs, p, name, offset) } -func (c *Connection) downloadFile(fsPath, ftpPath string, offset int64) (ftpserver.FileTransfer, error) { +func (c *Connection) downloadFile(fs vfs.Fs, fsPath, ftpPath string, offset int64) (ftpserver.FileTransfer, error) { if !c.User.HasPerm(dataprovider.PermDownload, path.Dir(ftpPath)) { return nil, c.GetPermissionDeniedError() } @@ -345,41 +301,41 @@ func (c *Connection) downloadFile(fsPath, ftpPath string, offset int64) (ftpserv return nil, c.GetPermissionDeniedError() } - file, r, cancelFn, err := c.Fs.Open(fsPath, offset) + file, r, cancelFn, err := fs.Open(fsPath, offset) if err != nil { c.Log(logger.LevelWarn, "could not open file %#v for reading: %+v", fsPath, err) - return nil, c.GetFsError(err) + return nil, c.GetFsError(fs, err) } baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, ftpPath, common.TransferDownload, - 0, 0, 0, false, c.Fs) + 0, 0, 0, false, fs) t := newTransfer(baseTransfer, nil, r, offset) return t, nil } -func (c *Connection) uploadFile(fsPath, ftpPath string, flags int) (ftpserver.FileTransfer, error) { +func (c *Connection) uploadFile(fs vfs.Fs, fsPath, ftpPath string, flags int) (ftpserver.FileTransfer, error) { if !c.User.IsFileAllowed(ftpPath) { c.Log(logger.LevelWarn, "writing file %#v is not allowed", ftpPath) return nil, c.GetPermissionDeniedError() } filePath := fsPath - if common.Config.IsAtomicUploadEnabled() && c.Fs.IsAtomicUploadSupported() { - filePath = c.Fs.GetAtomicUploadPath(fsPath) + if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { + filePath = fs.GetAtomicUploadPath(fsPath) } - stat, statErr := c.Fs.Lstat(fsPath) - if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || c.Fs.IsNotExist(statErr) { + stat, statErr := fs.Lstat(fsPath) + if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) { if !c.User.HasPerm(dataprovider.PermUpload, path.Dir(ftpPath)) { return nil, c.GetPermissionDeniedError() } - return c.handleFTPUploadToNewFile(fsPath, filePath, ftpPath) + return c.handleFTPUploadToNewFile(fs, fsPath, filePath, ftpPath) } if statErr != nil { c.Log(logger.LevelError, "error performing file stat %#v: %+v", fsPath, statErr) - return nil, c.GetFsError(statErr) + return nil, c.GetFsError(fs, statErr) } // This happen if we upload a file that has the same name of an existing directory @@ -392,34 +348,34 @@ func (c *Connection) uploadFile(fsPath, ftpPath string, flags int) (ftpserver.Fi return nil, c.GetPermissionDeniedError() } - return c.handleFTPUploadToExistingFile(flags, fsPath, filePath, stat.Size(), ftpPath) + return c.handleFTPUploadToExistingFile(fs, flags, fsPath, filePath, stat.Size(), ftpPath) } -func (c *Connection) handleFTPUploadToNewFile(resolvedPath, filePath, requestPath string) (ftpserver.FileTransfer, error) { +func (c *Connection) handleFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, requestPath string) (ftpserver.FileTransfer, error) { quotaResult := c.HasSpace(true, false, requestPath) if !quotaResult.HasSpace { c.Log(logger.LevelInfo, "denying file write due to quota limits") return nil, common.ErrQuotaExceeded } - file, w, cancelFn, err := c.Fs.Create(filePath, 0) + file, w, cancelFn, err := fs.Create(filePath, 0) if err != nil { c.Log(logger.LevelWarn, "error creating file %#v: %+v", resolvedPath, err) - return nil, c.GetFsError(err) + return nil, c.GetFsError(fs, err) } - vfs.SetPathPermissions(c.Fs, filePath, c.User.GetUID(), c.User.GetGID()) + vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) // we can get an error only for resume - maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0) + maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath, - common.TransferUpload, 0, 0, maxWriteSize, true, c.Fs) + common.TransferUpload, 0, 0, maxWriteSize, true, fs) t := newTransfer(baseTransfer, w, nil, 0) return t, nil } -func (c *Connection) handleFTPUploadToExistingFile(flags int, resolvedPath, filePath string, fileSize int64, +func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolvedPath, filePath string, fileSize int64, requestPath string) (ftpserver.FileTransfer, error) { var err error quotaResult := c.HasSpace(false, false, requestPath) @@ -436,25 +392,25 @@ func (c *Connection) handleFTPUploadToExistingFile(flags int, resolvedPath, file isResume := flags&os.O_TRUNC == 0 // if there is a size limit remaining size cannot be 0 here, since quotaResult.HasSpace // will return false in this case and we deny the upload before - maxWriteSize, err := c.GetMaxWriteSize(quotaResult, isResume, fileSize) + maxWriteSize, err := c.GetMaxWriteSize(quotaResult, isResume, fileSize, fs.IsUploadResumeSupported()) if err != nil { c.Log(logger.LevelDebug, "unable to get max write size: %v", err) return nil, err } - if common.Config.IsAtomicUploadEnabled() && c.Fs.IsAtomicUploadSupported() { - err = c.Fs.Rename(resolvedPath, filePath) + if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { + err = fs.Rename(resolvedPath, filePath) if err != nil { c.Log(logger.LevelWarn, "error renaming existing file for atomic upload, source: %#v, dest: %#v, err: %+v", resolvedPath, filePath, err) - return nil, c.GetFsError(err) + return nil, c.GetFsError(fs, err) } } - file, w, cancelFn, err := c.Fs.Create(filePath, flags) + file, w, cancelFn, err := fs.Create(filePath, flags) if err != nil { c.Log(logger.LevelWarn, "error opening existing file, flags: %v, source: %#v, err: %+v", flags, filePath, err) - return nil, c.GetFsError(err) + return nil, c.GetFsError(fs, err) } initialSize := int64(0) @@ -462,12 +418,12 @@ func (c *Connection) handleFTPUploadToExistingFile(flags int, resolvedPath, file c.Log(logger.LevelDebug, "upload resume requested, file path: %#v initial size: %v", filePath, fileSize) minWriteOffset = fileSize initialSize = fileSize - if vfs.IsSFTPFs(c.Fs) { + if vfs.IsSFTPFs(fs) { // we need this since we don't allow resume with wrong offset, we should fix this in pkg/sftp file.Seek(initialSize, io.SeekStart) //nolint:errcheck // for sftp seek cannot file, it simply set the offset } } else { - if vfs.IsLocalOrSFTPFs(c.Fs) { + if vfs.IsLocalOrSFTPFs(fs) { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) if err == nil { dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck @@ -482,10 +438,10 @@ func (c *Connection) handleFTPUploadToExistingFile(flags int, resolvedPath, file } } - vfs.SetPathPermissions(c.Fs, filePath, c.User.GetUID(), c.User.GetGID()) + vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath, - common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, c.Fs) + common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, fs) t := newTransfer(baseTransfer, w, nil, 0) return t, nil diff --git a/ftpd/internal_test.go b/ftpd/internal_test.go index fb97f664..6495705c 100644 --- a/ftpd/internal_test.go +++ b/ftpd/internal_test.go @@ -351,7 +351,7 @@ func (fs MockOsFs) Rename(source, target string) error { func newMockOsFs(err, statErr error, atomicUpload bool, connectionID, rootDir string) vfs.Fs { return &MockOsFs{ - Fs: vfs.NewOsFs(connectionID, rootDir, nil), + Fs: vfs.NewOsFs(connectionID, rootDir, ""), err: err, statErr: statErr, isAtomicUploadSupported: atomicUpload, @@ -492,7 +492,7 @@ func TestClientVersion(t *testing.T) { connID := fmt.Sprintf("2_%v", mockCC.ID()) user := dataprovider.User{} connection := &Connection{ - BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, nil), + BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user), clientContext: mockCC, } common.Connections.Add(connection) @@ -509,7 +509,7 @@ func TestDriverMethodsNotImplemented(t *testing.T) { connID := fmt.Sprintf("2_%v", mockCC.ID()) user := dataprovider.User{} connection := &Connection{ - BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, nil), + BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user), clientContext: mockCC, } _, err := connection.Create("") @@ -533,9 +533,8 @@ func TestResolvePathErrors(t *testing.T) { user.Permissions["/"] = []string{dataprovider.PermAny} mockCC := mockFTPClientContext{} connID := fmt.Sprintf("%v", mockCC.ID()) - fs := vfs.NewOsFs(connID, user.HomeDir, nil) connection := &Connection{ - BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, fs), + BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user), clientContext: mockCC, } err := connection.Mkdir("", os.ModePerm) @@ -596,9 +595,9 @@ func TestUploadFileStatError(t *testing.T) { user.Permissions["/"] = []string{dataprovider.PermAny} mockCC := mockFTPClientContext{} connID := fmt.Sprintf("%v", mockCC.ID()) - fs := vfs.NewOsFs(connID, user.HomeDir, nil) + fs := vfs.NewOsFs(connID, user.HomeDir, "") connection := &Connection{ - BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, fs), + BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user), clientContext: mockCC, } testFile := filepath.Join(user.HomeDir, "test", "testfile") @@ -608,7 +607,7 @@ func TestUploadFileStatError(t *testing.T) { assert.NoError(t, err) err = os.Chmod(filepath.Dir(testFile), 0001) assert.NoError(t, err) - _, err = connection.uploadFile(testFile, "test", 0) + _, err = connection.uploadFile(fs, testFile, "test", 0) assert.Error(t, err) err = os.Chmod(filepath.Dir(testFile), os.ModePerm) assert.NoError(t, err) @@ -625,9 +624,8 @@ func TestAVBLErrors(t *testing.T) { user.Permissions["/"] = []string{dataprovider.PermAny} mockCC := mockFTPClientContext{} connID := fmt.Sprintf("%v", mockCC.ID()) - fs := newMockOsFs(nil, nil, false, connID, user.GetHomeDir()) connection := &Connection{ - BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, fs), + BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user), clientContext: mockCC, } _, err := connection.GetAvailableSpace("/") @@ -648,12 +646,12 @@ func TestUploadOverwriteErrors(t *testing.T) { connID := fmt.Sprintf("%v", mockCC.ID()) fs := newMockOsFs(nil, nil, false, connID, user.GetHomeDir()) connection := &Connection{ - BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, fs), + BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user), clientContext: mockCC, } flags := 0 flags |= os.O_APPEND - _, err := connection.handleFTPUploadToExistingFile(flags, "", "", 0, "") + _, err := connection.handleFTPUploadToExistingFile(fs, flags, "", "", 0, "") if assert.Error(t, err) { assert.EqualError(t, err, common.ErrOpUnsupported.Error()) } @@ -665,7 +663,7 @@ func TestUploadOverwriteErrors(t *testing.T) { flags = 0 flags |= os.O_CREATE flags |= os.O_TRUNC - tr, err := connection.handleFTPUploadToExistingFile(flags, f.Name(), f.Name(), 123, f.Name()) + tr, err := connection.handleFTPUploadToExistingFile(fs, flags, f.Name(), f.Name(), 123, f.Name()) if assert.NoError(t, err) { transfer := tr.(*transfer) transfers := connection.GetTransfers() @@ -680,11 +678,11 @@ func TestUploadOverwriteErrors(t *testing.T) { err = os.Remove(f.Name()) assert.NoError(t, err) - _, err = connection.handleFTPUploadToExistingFile(os.O_TRUNC, filepath.Join(os.TempDir(), "sub", "file"), + _, err = connection.handleFTPUploadToExistingFile(fs, os.O_TRUNC, filepath.Join(os.TempDir(), "sub", "file"), filepath.Join(os.TempDir(), "sub", "file1"), 0, "/sub/file1") assert.Error(t, err) - connection.Fs = vfs.NewOsFs(connID, user.GetHomeDir(), nil) - _, err = connection.handleFTPUploadToExistingFile(0, "missing1", "missing2", 0, "missing") + fs = vfs.NewOsFs(connID, user.GetHomeDir(), "") + _, err = connection.handleFTPUploadToExistingFile(fs, 0, "missing1", "missing2", 0, "missing") assert.Error(t, err) } @@ -702,7 +700,7 @@ func TestTransferErrors(t *testing.T) { connID := fmt.Sprintf("%v", mockCC.ID()) fs := newMockOsFs(nil, nil, false, connID, user.GetHomeDir()) connection := &Connection{ - BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, fs), + BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user), clientContext: mockCC, } baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, file.Name(), testfile, common.TransferDownload, diff --git a/ftpd/server.go b/ftpd/server.go index 19bceff5..27140a48 100644 --- a/ftpd/server.go +++ b/ftpd/server.go @@ -145,7 +145,7 @@ func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) { connID := fmt.Sprintf("%v_%v", s.ID, cc.ID()) user := dataprovider.User{} connection := &Connection{ - BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, nil), + BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user), clientContext: cc, } common.Connections.Add(connection) @@ -180,7 +180,6 @@ func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string) if err != nil { return nil, err } - connection.Fs.CheckRootPath(connection.GetUsername(), user.GetUID(), user.GetGID()) connection.Log(logger.LevelInfo, "User id: %d, logged in with FTP, username: %#v, home_dir: %#v remote addr: %#v", user.ID, user.Username, user.HomeDir, ipAddr) dataprovider.UpdateLastLogin(&user) //nolint:errcheck @@ -219,7 +218,6 @@ func (s *Server) VerifyConnection(cc ftpserver.ClientContext, user string, tlsCo if err != nil { return nil, err } - connection.Fs.CheckRootPath(connection.GetUsername(), dbUser.GetUID(), dbUser.GetGID()) connection.Log(logger.LevelInfo, "User id: %d, logged in with FTP using a TLS certificate, username: %#v, home_dir: %#v remote addr: %#v", dbUser.ID, dbUser.Username, dbUser.HomeDir, ipAddr) dataprovider.UpdateLastLogin(&dbUser) //nolint:errcheck @@ -295,11 +293,11 @@ func (s *Server) validateUser(user dataprovider.User, cc ftpserver.ClientContext } if utils.IsStringInSlice(common.ProtocolFTP, user.Filters.DeniedProtocols) { logger.Debug(logSender, connectionID, "cannot login user %#v, protocol FTP is not allowed", user.Username) - return nil, fmt.Errorf("Protocol FTP is not allowed for user %#v", user.Username) + return nil, fmt.Errorf("protocol FTP is not allowed for user %#v", user.Username) } if !user.IsLoginMethodAllowed(loginMethod, nil) { logger.Debug(logSender, connectionID, "cannot login user %#v, %v login method is not allowed", user.Username, loginMethod) - return nil, fmt.Errorf("Login method %v is not allowed for user %#v", loginMethod, user.Username) + return nil, fmt.Errorf("login method %v is not allowed for user %#v", loginMethod, user.Username) } if user.MaxSessions > 0 { activeSessions := common.Connections.GetActiveSessions(user.Username) @@ -309,27 +307,26 @@ func (s *Server) validateUser(user dataprovider.User, cc ftpserver.ClientContext return nil, fmt.Errorf("too many open sessions: %v", activeSessions) } } - if dataprovider.GetQuotaTracking() > 0 && user.HasOverlappedMappedPaths() { - logger.Debug(logSender, connectionID, "cannot login user %#v, overlapping mapped folders are allowed only with quota tracking disabled", - user.Username) - return nil, errors.New("overlapping mapped folders are allowed only with quota tracking disabled") - } remoteAddr := cc.RemoteAddr().String() if !user.IsLoginFromAddrAllowed(remoteAddr) { logger.Debug(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v", user.Username, remoteAddr) - return nil, fmt.Errorf("Login for user %#v is not allowed from this address: %v", user.Username, remoteAddr) + return nil, fmt.Errorf("login for user %#v is not allowed from this address: %v", user.Username, remoteAddr) } - fs, err := user.GetFilesystem(connectionID) + err := user.CheckFsRoot(connectionID) if err != nil { + errClose := user.CloseFs() + logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose) return nil, err } connection := &Connection{ - BaseConnection: common.NewBaseConnection(fmt.Sprintf("%v_%v", s.ID, cc.ID()), common.ProtocolFTP, user, fs), + BaseConnection: common.NewBaseConnection(fmt.Sprintf("%v_%v", s.ID, cc.ID()), common.ProtocolFTP, user), clientContext: cc, } err = common.Connections.Swap(connection) if err != nil { - return nil, errors.New("Internal authentication error") + err = user.CloseFs() + logger.Warn(logSender, connectionID, "unable to swap connection, close fs error: %v", err) + return nil, errors.New("internal authentication error") } return connection, nil } diff --git a/ftpd/transfer.go b/ftpd/transfer.go index 833d3367..5ffcf680 100644 --- a/ftpd/transfer.go +++ b/ftpd/transfer.go @@ -102,7 +102,7 @@ func (t *transfer) Close() error { if errBaseClose != nil { err = errBaseClose } - return t.Connection.GetFsError(err) + return t.Connection.GetFsError(t.Fs, err) } func (t *transfer) closeIO() error { diff --git a/httpd/api_admin.go b/httpd/api_admin.go index 2e3ef2cc..56e1c206 100644 --- a/httpd/api_admin.go +++ b/httpd/api_admin.go @@ -89,11 +89,11 @@ func updateAdmin(w http.ResponseWriter, r *http.Request) { } if username == claims.Username { if claims.isCriticalPermRemoved(admin.Permissions) { - sendAPIResponse(w, r, errors.New("You cannot remove these permissions to yourself"), "", http.StatusBadRequest) + sendAPIResponse(w, r, errors.New("you cannot remove these permissions to yourself"), "", http.StatusBadRequest) return } if admin.Status == 0 { - sendAPIResponse(w, r, errors.New("You cannot disable yourself"), "", http.StatusBadRequest) + sendAPIResponse(w, r, errors.New("you cannot disable yourself"), "", http.StatusBadRequest) return } } @@ -114,7 +114,7 @@ func deleteAdmin(w http.ResponseWriter, r *http.Request) { return } if username == claims.Username { - sendAPIResponse(w, r, errors.New("You cannot delete yourself"), "", http.StatusBadRequest) + sendAPIResponse(w, r, errors.New("you cannot delete yourself"), "", http.StatusBadRequest) return } diff --git a/httpd/api_folder.go b/httpd/api_folder.go index a7282256..749f312b 100644 --- a/httpd/api_folder.go +++ b/httpd/api_folder.go @@ -50,7 +50,20 @@ func updateFolder(w http.ResponseWriter, r *http.Request) { sendAPIResponse(w, r, err, "", getRespStatus(err)) return } + users := folder.Users folderID := folder.ID + currentS3AccessSecret := folder.FsConfig.S3Config.AccessSecret + currentAzAccountKey := folder.FsConfig.AzBlobConfig.AccountKey + currentGCSCredentials := folder.FsConfig.GCSConfig.Credentials + currentCryptoPassphrase := folder.FsConfig.CryptConfig.Passphrase + currentSFTPPassword := folder.FsConfig.SFTPConfig.Password + currentSFTPKey := folder.FsConfig.SFTPConfig.PrivateKey + + folder.FsConfig.S3Config = vfs.S3FsConfig{} + folder.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{} + folder.FsConfig.GCSConfig = vfs.GCSFsConfig{} + folder.FsConfig.CryptConfig = vfs.CryptFsConfig{} + folder.FsConfig.SFTPConfig = vfs.SFTPFsConfig{} err = render.DecodeJSON(r.Body, &folder) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) @@ -58,7 +71,10 @@ func updateFolder(w http.ResponseWriter, r *http.Request) { } folder.ID = folderID folder.Name = name - err = dataprovider.UpdateFolder(&folder) + folder.FsConfig.SetEmptySecretsIfNil() + updateEncryptedSecrets(&folder.FsConfig, currentS3AccessSecret, currentAzAccountKey, currentGCSCredentials, + currentCryptoPassphrase, currentSFTPPassword, currentSFTPKey) + err = dataprovider.UpdateFolder(&folder, users) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return @@ -72,6 +88,7 @@ func renderFolder(w http.ResponseWriter, r *http.Request, name string, status in sendAPIResponse(w, r, err, "", getRespStatus(err)) return } + folder.HideConfidentialData() if status != http.StatusOK { ctx := context.WithValue(r.Context(), render.StatusCtxKey, status) render.JSON(w, r.WithContext(ctx), folder) diff --git a/httpd/api_maintenance.go b/httpd/api_maintenance.go index 2f771a7d..52e93ae7 100644 --- a/httpd/api_maintenance.go +++ b/httpd/api_maintenance.go @@ -21,13 +21,13 @@ import ( func validateBackupFile(outputFile string) (string, error) { if outputFile == "" { - return "", errors.New("Invalid or missing output-file") + return "", errors.New("invalid or missing output-file") } if filepath.IsAbs(outputFile) { - return "", fmt.Errorf("Invalid output-file %#v: it must be a relative path", outputFile) + return "", fmt.Errorf("invalid output-file %#v: it must be a relative path", outputFile) } if strings.Contains(outputFile, "..") { - return "", fmt.Errorf("Invalid output-file %#v", outputFile) + return "", fmt.Errorf("invalid output-file %#v", outputFile) } outputFile = filepath.Join(backupsPath, outputFile) return outputFile, nil @@ -122,7 +122,7 @@ func loadData(w http.ResponseWriter, r *http.Request) { return } if !filepath.IsAbs(inputFile) { - sendAPIResponse(w, r, fmt.Errorf("Invalid input_file %#v: it must be an absolute path", inputFile), "", http.StatusBadRequest) + sendAPIResponse(w, r, fmt.Errorf("invalid input_file %#v: it must be an absolute path", inputFile), "", http.StatusBadRequest) return } fi, err := os.Stat(inputFile) @@ -207,7 +207,7 @@ func RestoreFolders(folders []vfs.BaseVirtualFolder, inputFile string, mode, sca continue } folder.ID = f.ID - err = dataprovider.UpdateFolder(&folder) + err = dataprovider.UpdateFolder(&folder, f.Users) logger.Debug(logSender, "", "restoring existing folder: %+v, dump file: %#v, error: %v", folder, inputFile, err) } else { folder.Users = nil diff --git a/httpd/api_quota.go b/httpd/api_quota.go index 468d4977..ee306f13 100644 --- a/httpd/api_quota.go +++ b/httpd/api_quota.go @@ -34,7 +34,7 @@ func updateUserQuotaUsage(w http.ResponseWriter, r *http.Request) { return } if u.UsedQuotaFiles < 0 || u.UsedQuotaSize < 0 { - sendAPIResponse(w, r, errors.New("Invalid used quota parameters, negative values are not allowed"), + sendAPIResponse(w, r, errors.New("invalid used quota parameters, negative values are not allowed"), "", http.StatusBadRequest) return } @@ -75,7 +75,7 @@ func updateVFolderQuotaUsage(w http.ResponseWriter, r *http.Request) { return } if f.UsedQuotaFiles < 0 || f.UsedQuotaSize < 0 { - sendAPIResponse(w, r, errors.New("Invalid used quota parameters, negative values are not allowed"), + sendAPIResponse(w, r, errors.New("invalid used quota parameters, negative values are not allowed"), "", http.StatusBadRequest) return } @@ -154,28 +154,25 @@ func startVFolderQuotaScan(w http.ResponseWriter, r *http.Request) { func doQuotaScan(user dataprovider.User) error { defer common.QuotaScans.RemoveUserQuotaScan(user.Username) - fs, err := user.GetFilesystem("") + numFiles, size, err := user.ScanQuota() if err != nil { - logger.Warn(logSender, "", "unable scan quota for user %#v error creating filesystem: %v", user.Username, err) - return err - } - defer fs.Close() - numFiles, size, err := fs.ScanRootDirContents() - if err != nil { - logger.Warn(logSender, "", "error scanning user home dir %#v: %v", user.Username, err) + logger.Warn(logSender, "", "error scanning user quota %#v: %v", user.Username, err) return err } err = dataprovider.UpdateUserQuota(&user, numFiles, size, true) - logger.Debug(logSender, "", "user home dir scanned, user: %#v, error: %v", user.Username, err) + logger.Debug(logSender, "", "user quota scanned, user: %#v, error: %v", user.Username, err) return err } func doFolderQuotaScan(folder vfs.BaseVirtualFolder) error { defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name) - fs := vfs.NewOsFs("", "", nil).(*vfs.OsFs) - numFiles, size, err := fs.GetDirSize(folder.MappedPath) + f := vfs.VirtualFolder{ + BaseVirtualFolder: folder, + VirtualPath: "/", + } + numFiles, size, err := f.ScanQuota() if err != nil { - logger.Warn(logSender, "", "error scanning folder %#v: %v", folder.MappedPath, err) + logger.Warn(logSender, "", "error scanning folder %#v: %v", folder.Name, err) return err } err = dataprovider.UpdateVirtualFolderQuota(&folder, numFiles, size, true) @@ -188,7 +185,7 @@ func getQuotaUpdateMode(r *http.Request) (string, error) { if _, ok := r.URL.Query()["mode"]; ok { mode = r.URL.Query().Get("mode") if mode != quotaUpdateModeReset && mode != quotaUpdateModeAdd { - return "", errors.New("Invalid mode") + return "", errors.New("invalid mode") } } return mode, nil diff --git a/httpd/api_user.go b/httpd/api_user.go index b5467ea6..99e5edfe 100644 --- a/httpd/api_user.go +++ b/httpd/api_user.go @@ -59,27 +59,27 @@ func addUser(w http.ResponseWriter, r *http.Request) { } user.SetEmptySecretsIfNil() switch user.FsConfig.Provider { - case dataprovider.S3FilesystemProvider: + case vfs.S3FilesystemProvider: if user.FsConfig.S3Config.AccessSecret.IsRedacted() { sendAPIResponse(w, r, errors.New("invalid access_secret"), "", http.StatusBadRequest) return } - case dataprovider.GCSFilesystemProvider: + case vfs.GCSFilesystemProvider: if user.FsConfig.GCSConfig.Credentials.IsRedacted() { sendAPIResponse(w, r, errors.New("invalid credentials"), "", http.StatusBadRequest) return } - case dataprovider.AzureBlobFilesystemProvider: + case vfs.AzureBlobFilesystemProvider: if user.FsConfig.AzBlobConfig.AccountKey.IsRedacted() { sendAPIResponse(w, r, errors.New("invalid account_key"), "", http.StatusBadRequest) return } - case dataprovider.CryptedFilesystemProvider: + case vfs.CryptedFilesystemProvider: if user.FsConfig.CryptConfig.Passphrase.IsRedacted() { sendAPIResponse(w, r, errors.New("invalid passphrase"), "", http.StatusBadRequest) return } - case dataprovider.SFTPFilesystemProvider: + case vfs.SFTPFilesystemProvider: if user.FsConfig.SFTPConfig.Password.IsRedacted() { sendAPIResponse(w, r, errors.New("invalid SFTP password"), "", http.StatusBadRequest) return @@ -131,6 +131,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) { user.FsConfig.GCSConfig = vfs.GCSFsConfig{} user.FsConfig.CryptConfig = vfs.CryptFsConfig{} user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{} + user.VirtualFolders = nil err = render.DecodeJSON(r.Body, &user) if err != nil { sendAPIResponse(w, r, err, "", http.StatusBadRequest) @@ -143,7 +144,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) { if len(user.Permissions) == 0 { user.Permissions = currentPermissions } - updateEncryptedSecrets(&user, currentS3AccessSecret, currentAzAccountKey, currentGCSCredentials, currentCryptoPassphrase, + updateEncryptedSecrets(&user.FsConfig, currentS3AccessSecret, currentAzAccountKey, currentGCSCredentials, currentCryptoPassphrase, currentSFTPPassword, currentSFTPKey) err = dataprovider.UpdateUser(&user) if err != nil { @@ -175,32 +176,32 @@ func disconnectUser(username string) { } } -func updateEncryptedSecrets(user *dataprovider.User, currentS3AccessSecret, currentAzAccountKey, +func updateEncryptedSecrets(fsConfig *vfs.Filesystem, currentS3AccessSecret, currentAzAccountKey, currentGCSCredentials, currentCryptoPassphrase, currentSFTPPassword, currentSFTPKey *kms.Secret) { // we use the new access secret if plain or empty, otherwise the old value - switch user.FsConfig.Provider { - case dataprovider.S3FilesystemProvider: - if user.FsConfig.S3Config.AccessSecret.IsNotPlainAndNotEmpty() { - user.FsConfig.S3Config.AccessSecret = currentS3AccessSecret + switch fsConfig.Provider { + case vfs.S3FilesystemProvider: + if fsConfig.S3Config.AccessSecret.IsNotPlainAndNotEmpty() { + fsConfig.S3Config.AccessSecret = currentS3AccessSecret } - case dataprovider.AzureBlobFilesystemProvider: - if user.FsConfig.AzBlobConfig.AccountKey.IsNotPlainAndNotEmpty() { - user.FsConfig.AzBlobConfig.AccountKey = currentAzAccountKey + case vfs.AzureBlobFilesystemProvider: + if fsConfig.AzBlobConfig.AccountKey.IsNotPlainAndNotEmpty() { + fsConfig.AzBlobConfig.AccountKey = currentAzAccountKey } - case dataprovider.GCSFilesystemProvider: - if user.FsConfig.GCSConfig.Credentials.IsNotPlainAndNotEmpty() { - user.FsConfig.GCSConfig.Credentials = currentGCSCredentials + case vfs.GCSFilesystemProvider: + if fsConfig.GCSConfig.Credentials.IsNotPlainAndNotEmpty() { + fsConfig.GCSConfig.Credentials = currentGCSCredentials } - case dataprovider.CryptedFilesystemProvider: - if user.FsConfig.CryptConfig.Passphrase.IsNotPlainAndNotEmpty() { - user.FsConfig.CryptConfig.Passphrase = currentCryptoPassphrase + case vfs.CryptedFilesystemProvider: + if fsConfig.CryptConfig.Passphrase.IsNotPlainAndNotEmpty() { + fsConfig.CryptConfig.Passphrase = currentCryptoPassphrase } - case dataprovider.SFTPFilesystemProvider: - if user.FsConfig.SFTPConfig.Password.IsNotPlainAndNotEmpty() { - user.FsConfig.SFTPConfig.Password = currentSFTPPassword + case vfs.SFTPFilesystemProvider: + if fsConfig.SFTPConfig.Password.IsNotPlainAndNotEmpty() { + fsConfig.SFTPConfig.Password = currentSFTPPassword } - if user.FsConfig.SFTPConfig.PrivateKey.IsNotPlainAndNotEmpty() { - user.FsConfig.SFTPConfig.PrivateKey = currentSFTPKey + if fsConfig.SFTPConfig.PrivateKey.IsNotPlainAndNotEmpty() { + fsConfig.SFTPConfig.PrivateKey = currentSFTPKey } } } diff --git a/httpd/api_utils.go b/httpd/api_utils.go index 074c4ad3..640e35bb 100644 --- a/httpd/api_utils.go +++ b/httpd/api_utils.go @@ -63,7 +63,7 @@ func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string, if _, ok := r.URL.Query()["limit"]; ok { limit, err = strconv.Atoi(r.URL.Query().Get("limit")) if err != nil { - err = errors.New("Invalid limit") + err = errors.New("invalid limit") sendAPIResponse(w, r, err, "", http.StatusBadRequest) return limit, offset, order, err } @@ -74,7 +74,7 @@ func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string, if _, ok := r.URL.Query()["offset"]; ok { offset, err = strconv.Atoi(r.URL.Query().Get("offset")) if err != nil { - err = errors.New("Invalid offset") + err = errors.New("invalid offset") sendAPIResponse(w, r, err, "", http.StatusBadRequest) return limit, offset, order, err } @@ -82,7 +82,7 @@ func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string, if _, ok := r.URL.Query()["order"]; ok { order = r.URL.Query().Get("order") if order != dataprovider.OrderASC && order != dataprovider.OrderDESC { - err = errors.New("Invalid order") + err = errors.New("invalid order") sendAPIResponse(w, r, err, "", http.StatusBadRequest) return limit, offset, order, err } diff --git a/httpd/auth_utils.go b/httpd/auth_utils.go index 61adf772..57ba0532 100644 --- a/httpd/auth_utils.go +++ b/httpd/auth_utils.go @@ -212,12 +212,12 @@ func verifyCSRFToken(tokenString string) error { token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString) if err != nil || token == nil { logger.Debug(logSender, "", "error validating CSRF: %v", err) - return fmt.Errorf("Unable to verify form token: %v", err) + return fmt.Errorf("unable to verify form token: %v", err) } if !utils.IsStringInSlice(tokenAudienceCSRF, token.Audience()) { logger.Debug(logSender, "", "error validating CSRF token audience") - return errors.New("The form token is not valid") + return errors.New("the form token is not valid") } return nil diff --git a/httpd/httpd.go b/httpd/httpd.go index 5fc5ea0a..b6ebf7a6 100644 --- a/httpd/httpd.go +++ b/httpd/httpd.go @@ -190,10 +190,10 @@ func (c *Conf) Initialize(configDir string) error { templatesPath := getConfigPath(c.TemplatesPath, configDir) enableWebAdmin := staticFilesPath != "" || templatesPath != "" if backupsPath == "" { - return fmt.Errorf("Required directory is invalid, backup path %#v", backupsPath) + return fmt.Errorf("required directory is invalid, backup path %#v", backupsPath) } if enableWebAdmin && (staticFilesPath == "" || templatesPath == "") { - return fmt.Errorf("Required directory is invalid, static file path: %#v template path: %#v", + return fmt.Errorf("required directory is invalid, static file path: %#v template path: %#v", staticFilesPath, templatesPath) } certificateFile := getConfigPath(c.CertificateFile, configDir) diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index ac88774d..3c41e91a 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -577,7 +577,7 @@ func TestAddUserInvalidFilters(t *testing.T) { func TestAddUserInvalidFsConfig(t *testing.T) { u := getTestUser() - u.FsConfig.Provider = dataprovider.S3FilesystemProvider + u.FsConfig.Provider = vfs.S3FilesystemProvider u.FsConfig.S3Config.Bucket = "" _, _, err := httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) @@ -586,8 +586,8 @@ func TestAddUserInvalidFsConfig(t *testing.T) { err = os.MkdirAll(credentialsPath, 0700) assert.NoError(t, err) u.FsConfig.S3Config.Bucket = "testbucket" - u.FsConfig.S3Config.Region = "eu-west-1" - u.FsConfig.S3Config.AccessKey = "access-key" + u.FsConfig.S3Config.Region = "eu-west-1" //nolint:goconst + u.FsConfig.S3Config.AccessKey = "access-key" //nolint:goconst u.FsConfig.S3Config.AccessSecret = kms.NewSecret(kms.SecretStatusRedacted, "access-secret", "", "") u.FsConfig.S3Config.Endpoint = "http://127.0.0.1:9000/path?a=b" u.FsConfig.S3Config.StorageClass = "Standard" //nolint:goconst @@ -609,7 +609,7 @@ func TestAddUserInvalidFsConfig(t *testing.T) { _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) u = getTestUser() - u.FsConfig.Provider = dataprovider.GCSFilesystemProvider + u.FsConfig.Provider = vfs.GCSFilesystemProvider u.FsConfig.GCSConfig.Bucket = "" _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) @@ -632,7 +632,7 @@ func TestAddUserInvalidFsConfig(t *testing.T) { assert.NoError(t, err) u = getTestUser() - u.FsConfig.Provider = dataprovider.AzureBlobFilesystemProvider + u.FsConfig.Provider = vfs.AzureBlobFilesystemProvider u.FsConfig.AzBlobConfig.SASURL = "http://foo\x7f.com/" _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) @@ -659,14 +659,14 @@ func TestAddUserInvalidFsConfig(t *testing.T) { assert.NoError(t, err) u = getTestUser() - u.FsConfig.Provider = dataprovider.CryptedFilesystemProvider + u.FsConfig.Provider = vfs.CryptedFilesystemProvider _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) u.FsConfig.CryptConfig.Passphrase = kms.NewSecret(kms.SecretStatusRedacted, "akey", "", "") _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) u = getTestUser() - u.FsConfig.Provider = dataprovider.SFTPFilesystemProvider + u.FsConfig.Provider = vfs.SFTPFilesystemProvider _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) u.FsConfig.SFTPConfig.Password = kms.NewSecret(kms.SecretStatusRedacted, "randompkey", "", "") @@ -678,6 +678,52 @@ func TestAddUserInvalidFsConfig(t *testing.T) { assert.NoError(t, err) } +func TestUserRedactedPassword(t *testing.T) { + u := getTestUser() + u.FsConfig.Provider = vfs.S3FilesystemProvider + u.FsConfig.S3Config.Bucket = "b" + u.FsConfig.S3Config.Region = "eu-west-1" + u.FsConfig.S3Config.AccessKey = "access-key" + u.FsConfig.S3Config.AccessSecret = kms.NewSecret(kms.SecretStatusRedacted, "access-secret", "", "") + u.FsConfig.S3Config.Endpoint = "http://127.0.0.1:9000/path?k=m" + u.FsConfig.S3Config.StorageClass = "Standard" + _, resp, err := httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err, string(resp)) + assert.Contains(t, string(resp), "invalid access_secret") + err = dataprovider.AddUser(&u) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "cannot save a user with a redacted secret") + } + u.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("secret") + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + folderName := "folderName" + vfolder := vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: filepath.Join(os.TempDir(), "crypted"), + FsConfig: vfs.Filesystem{ + Provider: vfs.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewSecret(kms.SecretStatusRedacted, "crypted-secret", "", ""), + }, + }, + }, + VirtualPath: "/avpath", + } + + user.Password = defaultPassword + user.VirtualFolders = append(user.VirtualFolders, vfolder) + err = dataprovider.UpdateUser(&user) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "cannot save a user with a redacted secret") + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + func TestAddUserInvalidVirtualFolders(t *testing.T) { u := getTestUser() folderName := "fname" @@ -764,7 +810,7 @@ func TestAddUserInvalidVirtualFolders(t *testing.T) { }) _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) assert.NoError(t, err) - u.VirtualFolders = nil + /*u.VirtualFolders = nil u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ MappedPath: filepath.Join(os.TempDir(), "mapped_dir", "subdir"), @@ -831,7 +877,7 @@ func TestAddUserInvalidVirtualFolders(t *testing.T) { VirtualPath: "/vdir1/subdir", // invalid, contained inside /vdir1 }) _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) - assert.NoError(t, err) + assert.NoError(t, err)*/ u.VirtualFolders = nil u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ @@ -902,6 +948,32 @@ func TestAddUserInvalidVirtualFolders(t *testing.T) { assert.NoError(t, err) } +func TestSFTPVirtualFolderSelf(t *testing.T) { + // an sftp virtual folder cannot use the same sftp account, it will generate an infinite loop + // at login + u := getTestUser() + mappedPathSFTP := filepath.Join(os.TempDir(), "sftp") + folderNameSFTP := filepath.Base(mappedPathSFTP) + vdirSFTPPath := "/vdir/sftp" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameSFTP, + FsConfig: vfs.Filesystem{ + Provider: vfs.SFTPFilesystemProvider, + SFTPConfig: vfs.SFTPFsConfig{ + Endpoint: "127.0.0.1:2022", + Username: defaultUsername, + Password: kms.NewPlainSecret(defaultPassword), + }, + }, + }, + VirtualPath: vdirSFTPPath, + }) + _, resp, err := httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err, string(resp)) + assert.Contains(t, string(resp), "could point to the same SFTPGo account") +} + func TestUserPublicKey(t *testing.T) { u := getTestUser() u.Password = "" @@ -1110,10 +1182,11 @@ func TestUserFolderMapping(t *testing.T) { u1 := getTestUser() u1.VirtualFolders = append(u1.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ - Name: folderName1, - MappedPath: mappedPath1, - UsedQuotaFiles: 2, - UsedQuotaSize: 123, + Name: folderName1, + MappedPath: mappedPath1, + UsedQuotaFiles: 2, + UsedQuotaSize: 123, + LastQuotaUpdate: 456, }, VirtualPath: "/vdir", QuotaSize: -1, @@ -1128,6 +1201,10 @@ func TestUserFolderMapping(t *testing.T) { assert.Contains(t, folder.Users, user1.Username) assert.Equal(t, 0, folder.UsedQuotaFiles) assert.Equal(t, int64(0), folder.UsedQuotaSize) + assert.Equal(t, int64(0), folder.LastQuotaUpdate) + assert.Equal(t, 0, user1.VirtualFolders[0].UsedQuotaFiles) + assert.Equal(t, int64(0), user1.VirtualFolders[0].UsedQuotaSize) + assert.Equal(t, int64(0), user1.VirtualFolders[0].LastQuotaUpdate) u2 := getTestUser() u2.Username = defaultUsername + "2" @@ -1239,23 +1316,53 @@ func TestUserFolderMapping(t *testing.T) { func TestUserS3Config(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) - user.FsConfig.Provider = dataprovider.S3FilesystemProvider + user.FsConfig.Provider = vfs.S3FilesystemProvider user.FsConfig.S3Config.Bucket = "test" //nolint:goconst user.FsConfig.S3Config.Region = "us-east-1" //nolint:goconst user.FsConfig.S3Config.AccessKey = "Server-Access-Key" user.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("Server-Access-Secret") user.FsConfig.S3Config.Endpoint = "http://127.0.0.1:9000" user.FsConfig.S3Config.UploadPartSize = 8 + folderName := "vfolderName" + user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: filepath.Join(os.TempDir(), "folderName"), + FsConfig: vfs.Filesystem{ + Provider: vfs.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret("Crypted-Secret"), + }, + }, + }, + VirtualPath: "/folderPath", + }) user, body, err := httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err, string(body)) assert.Equal(t, kms.SecretStatusSecretBox, user.FsConfig.S3Config.AccessSecret.GetStatus()) assert.NotEmpty(t, user.FsConfig.S3Config.AccessSecret.GetPayload()) assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetAdditionalData()) assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetKey()) + if assert.Len(t, user.VirtualFolders, 1) { + folder := user.VirtualFolders[0] + assert.Equal(t, kms.SecretStatusSecretBox, folder.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.NotEmpty(t, folder.FsConfig.CryptConfig.Passphrase.GetPayload()) + assert.Empty(t, folder.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.Empty(t, folder.FsConfig.CryptConfig.Passphrase.GetKey()) + } _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) + folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, kms.SecretStatusSecretBox, folder.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.NotEmpty(t, folder.FsConfig.CryptConfig.Passphrase.GetPayload()) + assert.Empty(t, folder.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.Empty(t, folder.FsConfig.CryptConfig.Passphrase.GetKey()) + _, err = httpdtest.RemoveFolder(folder, http.StatusOK) + assert.NoError(t, err) user.Password = defaultPassword user.ID = 0 + user.VirtualFolders = nil secret := kms.NewSecret(kms.SecretStatusSecretBox, "Server-Access-Secret", "", "") user.FsConfig.S3Config.AccessSecret = secret _, _, err = httpdtest.AddUser(user, http.StatusCreated) @@ -1268,7 +1375,7 @@ func TestUserS3Config(t *testing.T) { assert.NotEmpty(t, initialSecretPayload) assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetAdditionalData()) assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetKey()) - user.FsConfig.Provider = dataprovider.S3FilesystemProvider + user.FsConfig.Provider = vfs.S3FilesystemProvider user.FsConfig.S3Config.Bucket = "test-bucket" user.FsConfig.S3Config.Region = "us-east-1" //nolint:goconst user.FsConfig.S3Config.AccessKey = "Server-Access-Key1" @@ -1282,7 +1389,7 @@ func TestUserS3Config(t *testing.T) { assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetAdditionalData()) assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetKey()) // test user without access key and access secret (shared config state) - user.FsConfig.Provider = dataprovider.S3FilesystemProvider + user.FsConfig.Provider = vfs.S3FilesystemProvider user.FsConfig.S3Config.Bucket = "testbucket" user.FsConfig.S3Config.Region = "us-east-1" user.FsConfig.S3Config.AccessKey = "" @@ -1313,7 +1420,7 @@ func TestUserGCSConfig(t *testing.T) { assert.NoError(t, err) err = os.MkdirAll(credentialsPath, 0700) assert.NoError(t, err) - user.FsConfig.Provider = dataprovider.GCSFilesystemProvider + user.FsConfig.Provider = vfs.GCSFilesystemProvider user.FsConfig.GCSConfig.Bucket = "test" user.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("fake credentials") //nolint:goconst user, bb, err := httpdtest.UpdateUser(user, http.StatusOK, "") @@ -1360,7 +1467,7 @@ func TestUserGCSConfig(t *testing.T) { assert.NoError(t, err) assert.NoFileExists(t, credentialFile) user.FsConfig.GCSConfig = vfs.GCSFsConfig{} - user.FsConfig.Provider = dataprovider.S3FilesystemProvider + user.FsConfig.Provider = vfs.S3FilesystemProvider user.FsConfig.S3Config.Bucket = "test1" user.FsConfig.S3Config.Region = "us-east-1" user.FsConfig.S3Config.AccessKey = "Server-Access-Key1" @@ -1370,7 +1477,7 @@ func TestUserGCSConfig(t *testing.T) { user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) user.FsConfig.S3Config = vfs.S3FsConfig{} - user.FsConfig.Provider = dataprovider.GCSFilesystemProvider + user.FsConfig.Provider = vfs.GCSFilesystemProvider user.FsConfig.GCSConfig.Bucket = "test1" user.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("fake credentials") user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") @@ -1383,7 +1490,7 @@ func TestUserGCSConfig(t *testing.T) { func TestUserAzureBlobConfig(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) - user.FsConfig.Provider = dataprovider.AzureBlobFilesystemProvider + user.FsConfig.Provider = vfs.AzureBlobFilesystemProvider user.FsConfig.AzBlobConfig.Container = "test" user.FsConfig.AzBlobConfig.AccountName = "Server-Account-Name" user.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("Server-Account-Key") @@ -1422,7 +1529,7 @@ func TestUserAzureBlobConfig(t *testing.T) { assert.NotEmpty(t, initialPayload) assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData()) assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetKey()) - user.FsConfig.Provider = dataprovider.AzureBlobFilesystemProvider + user.FsConfig.Provider = vfs.AzureBlobFilesystemProvider user.FsConfig.AzBlobConfig.Container = "test-container" user.FsConfig.AzBlobConfig.Endpoint = "http://localhost:9001" user.FsConfig.AzBlobConfig.KeyPrefix = "somedir/subdir" @@ -1435,7 +1542,7 @@ func TestUserAzureBlobConfig(t *testing.T) { assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData()) assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetKey()) // test user without access key and access secret (sas) - user.FsConfig.Provider = dataprovider.AzureBlobFilesystemProvider + user.FsConfig.Provider = vfs.AzureBlobFilesystemProvider user.FsConfig.AzBlobConfig.SASURL = "https://myaccount.blob.core.windows.net/pictures/profile.jpg?sv=2012-02-12&st=2009-02-09&se=2009-02-10&sr=c&sp=r&si=YWJjZGVmZw%3d%3d&sig=dD80ihBh5jfNpymO5Hg1IdiJIEvHcJpCMiCMnN%2fRnbI%3d" user.FsConfig.AzBlobConfig.KeyPrefix = "somedir/subdir" user.FsConfig.AzBlobConfig.AccountName = "" @@ -1460,7 +1567,7 @@ func TestUserAzureBlobConfig(t *testing.T) { func TestUserCryptFs(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) - user.FsConfig.Provider = dataprovider.CryptedFilesystemProvider + user.FsConfig.Provider = vfs.CryptedFilesystemProvider user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("crypt passphrase") user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err) @@ -1495,7 +1602,7 @@ func TestUserCryptFs(t *testing.T) { assert.NotEmpty(t, initialPayload) assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetKey()) - user.FsConfig.Provider = dataprovider.CryptedFilesystemProvider + user.FsConfig.Provider = vfs.CryptedFilesystemProvider user.FsConfig.CryptConfig.Passphrase.SetKey("pass") user, bb, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err, string(bb)) @@ -1512,7 +1619,7 @@ func TestUserCryptFs(t *testing.T) { func TestUserSFTPFs(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) - user.FsConfig.Provider = dataprovider.SFTPFilesystemProvider + user.FsConfig.Provider = vfs.SFTPFilesystemProvider user.FsConfig.SFTPConfig.Endpoint = "127.0.0.1" // missing port user.FsConfig.SFTPConfig.Username = "sftp_user" user.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("sftp_pwd") @@ -1579,7 +1686,7 @@ func TestUserSFTPFs(t *testing.T) { assert.NotEmpty(t, initialPkeyPayload) assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData()) assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetKey()) - user.FsConfig.Provider = dataprovider.SFTPFilesystemProvider + user.FsConfig.Provider = vfs.SFTPFilesystemProvider user.FsConfig.SFTPConfig.PrivateKey.SetKey("k") user, bb, err = httpdtest.UpdateUser(user, http.StatusOK, "") assert.NoError(t, err, string(bb)) @@ -1607,7 +1714,7 @@ func TestUserHiddenFields(t *testing.T) { usernames := []string{"user1", "user2", "user3", "user4", "user5"} u1 := getTestUser() u1.Username = usernames[0] - u1.FsConfig.Provider = dataprovider.S3FilesystemProvider + u1.FsConfig.Provider = vfs.S3FilesystemProvider u1.FsConfig.S3Config.Bucket = "test" u1.FsConfig.S3Config.Region = "us-east-1" u1.FsConfig.S3Config.AccessKey = "S3-Access-Key" @@ -1617,7 +1724,7 @@ func TestUserHiddenFields(t *testing.T) { u2 := getTestUser() u2.Username = usernames[1] - u2.FsConfig.Provider = dataprovider.GCSFilesystemProvider + u2.FsConfig.Provider = vfs.GCSFilesystemProvider u2.FsConfig.GCSConfig.Bucket = "test" u2.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("fake credentials") user2, _, err := httpdtest.AddUser(u2, http.StatusCreated) @@ -1625,7 +1732,7 @@ func TestUserHiddenFields(t *testing.T) { u3 := getTestUser() u3.Username = usernames[2] - u3.FsConfig.Provider = dataprovider.AzureBlobFilesystemProvider + u3.FsConfig.Provider = vfs.AzureBlobFilesystemProvider u3.FsConfig.AzBlobConfig.Container = "test" u3.FsConfig.AzBlobConfig.AccountName = "Server-Account-Name" u3.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("Server-Account-Key") @@ -1634,14 +1741,14 @@ func TestUserHiddenFields(t *testing.T) { u4 := getTestUser() u4.Username = usernames[3] - u4.FsConfig.Provider = dataprovider.CryptedFilesystemProvider + u4.FsConfig.Provider = vfs.CryptedFilesystemProvider u4.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("test passphrase") user4, _, err := httpdtest.AddUser(u4, http.StatusCreated) assert.NoError(t, err) u5 := getTestUser() u5.Username = usernames[4] - u5.FsConfig.Provider = dataprovider.SFTPFilesystemProvider + u5.FsConfig.Provider = vfs.SFTPFilesystemProvider u5.FsConfig.SFTPConfig.Endpoint = "127.0.0.1:2022" u5.FsConfig.SFTPConfig.Username = "sftp_user" u5.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("apassword") @@ -2102,7 +2209,7 @@ func TestEmbeddedFolders(t *testing.T) { u.Username = u.Username + "1" u.VirtualFolders[0].MappedPath = "" user1, _, err := httpdtest.AddUser(u, http.StatusCreated) - assert.EqualError(t, err, "Virtual folders mismatch") + assert.EqualError(t, err, "mapped path mismatch") if assert.Len(t, user1.VirtualFolders, 1) { assert.Equal(t, mappedPath, user1.VirtualFolders[0].MappedPath) assert.Equal(t, u.VirtualFolders[0].VirtualPath, user1.VirtualFolders[0].VirtualPath) @@ -2111,7 +2218,7 @@ func TestEmbeddedFolders(t *testing.T) { } user1.VirtualFolders = u.VirtualFolders user1, _, err = httpdtest.UpdateUser(user1, http.StatusOK, "") - assert.EqualError(t, err, "Virtual folders mismatch") + assert.EqualError(t, err, "mapped path mismatch") if assert.Len(t, user1.VirtualFolders, 1) { assert.Equal(t, mappedPath, user1.VirtualFolders[0].MappedPath) assert.Equal(t, u.VirtualFolders[0].VirtualPath, user1.VirtualFolders[0].VirtualPath) @@ -2137,6 +2244,153 @@ func TestEmbeddedFolders(t *testing.T) { assert.NoError(t, err) } +func TestEmbeddedFoldersUpdate(t *testing.T) { + u := getTestUser() + mappedPath := filepath.Join(os.TempDir(), "mapped_path") + name := filepath.Base(mappedPath) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: name, + MappedPath: mappedPath, + UsedQuotaFiles: 1000, + UsedQuotaSize: 8192, + LastQuotaUpdate: 123, + }, + VirtualPath: "/vdir", + QuotaSize: 4096, + QuotaFiles: 1, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + folder, _, err := httpdtest.GetFolderByName(name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, mappedPath, folder.MappedPath) + assert.Equal(t, 0, folder.UsedQuotaFiles) + assert.Equal(t, int64(0), folder.UsedQuotaSize) + assert.Equal(t, int64(0), folder.LastQuotaUpdate) + assert.Empty(t, folder.Description) + assert.Equal(t, vfs.LocalFilesystemProvider, folder.FsConfig.Provider) + assert.Len(t, folder.Users, 1) + assert.Contains(t, folder.Users, user.Username) + // update a field on the folder + description := "updatedDesc" + folder.MappedPath = mappedPath + "_update" + folder.Description = description + folder, _, err = httpdtest.UpdateFolder(folder, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, mappedPath+"_update", folder.MappedPath) + assert.Equal(t, 0, folder.UsedQuotaFiles) + assert.Equal(t, int64(0), folder.UsedQuotaSize) + assert.Equal(t, int64(0), folder.LastQuotaUpdate) + assert.Equal(t, description, folder.Description) + assert.Equal(t, vfs.LocalFilesystemProvider, folder.FsConfig.Provider) + // check that the user gets the changes + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + userFolder := user.VirtualFolders[0].BaseVirtualFolder + assert.Equal(t, mappedPath+"_update", folder.MappedPath) + assert.Equal(t, 0, userFolder.UsedQuotaFiles) + assert.Equal(t, int64(0), userFolder.UsedQuotaSize) + assert.Equal(t, int64(0), userFolder.LastQuotaUpdate) + assert.Equal(t, description, userFolder.Description) + assert.Equal(t, vfs.LocalFilesystemProvider, userFolder.FsConfig.Provider) + // now update the folder embedding it inside the user + user.VirtualFolders = []vfs.VirtualFolder{ + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: name, + MappedPath: "", + UsedQuotaFiles: 1000, + UsedQuotaSize: 8192, + LastQuotaUpdate: 123, + FsConfig: vfs.Filesystem{ + Provider: vfs.S3FilesystemProvider, + S3Config: vfs.S3FsConfig{ + Bucket: "test", + Region: "us-east-1", + AccessKey: "akey", + AccessSecret: kms.NewPlainSecret("asecret"), + Endpoint: "http://127.0.1.1:9090", + }, + }, + }, + VirtualPath: "/vdir1", + QuotaSize: 4096, + QuotaFiles: 1, + }, + } + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + userFolder = user.VirtualFolders[0].BaseVirtualFolder + assert.Equal(t, 0, userFolder.UsedQuotaFiles) + assert.Equal(t, int64(0), userFolder.UsedQuotaSize) + assert.Equal(t, int64(0), userFolder.LastQuotaUpdate) + assert.Empty(t, userFolder.Description) + assert.Equal(t, vfs.S3FilesystemProvider, userFolder.FsConfig.Provider) + assert.Equal(t, "test", userFolder.FsConfig.S3Config.Bucket) + assert.Equal(t, "us-east-1", userFolder.FsConfig.S3Config.Region) + assert.Equal(t, "http://127.0.1.1:9090", userFolder.FsConfig.S3Config.Endpoint) + assert.Equal(t, kms.SecretStatusSecretBox, userFolder.FsConfig.S3Config.AccessSecret.GetStatus()) + assert.NotEmpty(t, userFolder.FsConfig.S3Config.AccessSecret.GetPayload()) + assert.Empty(t, userFolder.FsConfig.S3Config.AccessSecret.GetKey()) + assert.Empty(t, userFolder.FsConfig.S3Config.AccessSecret.GetAdditionalData()) + // confirm the changes + folder, _, err = httpdtest.GetFolderByName(name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, folder.UsedQuotaFiles) + assert.Equal(t, int64(0), folder.UsedQuotaSize) + assert.Equal(t, int64(0), folder.LastQuotaUpdate) + assert.Empty(t, folder.Description) + assert.Equal(t, vfs.S3FilesystemProvider, folder.FsConfig.Provider) + assert.Equal(t, "test", folder.FsConfig.S3Config.Bucket) + assert.Equal(t, "us-east-1", folder.FsConfig.S3Config.Region) + assert.Equal(t, "http://127.0.1.1:9090", folder.FsConfig.S3Config.Endpoint) + assert.Equal(t, kms.SecretStatusSecretBox, folder.FsConfig.S3Config.AccessSecret.GetStatus()) + assert.NotEmpty(t, folder.FsConfig.S3Config.AccessSecret.GetPayload()) + assert.Empty(t, folder.FsConfig.S3Config.AccessSecret.GetKey()) + assert.Empty(t, folder.FsConfig.S3Config.AccessSecret.GetAdditionalData()) + // now update folder usage limits and check that a folder update will not change them + folder.UsedQuotaFiles = 100 + folder.UsedQuotaSize = 32768 + _, err = httpdtest.UpdateFolderQuotaUsage(folder, "reset", http.StatusOK) + assert.NoError(t, err) + folder, _, err = httpdtest.GetFolderByName(name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 100, folder.UsedQuotaFiles) + assert.Equal(t, int64(32768), folder.UsedQuotaSize) + assert.Greater(t, folder.LastQuotaUpdate, int64(0)) + assert.Equal(t, vfs.S3FilesystemProvider, folder.FsConfig.Provider) + assert.Equal(t, "test", folder.FsConfig.S3Config.Bucket) + assert.Equal(t, "us-east-1", folder.FsConfig.S3Config.Region) + assert.Equal(t, "http://127.0.1.1:9090", folder.FsConfig.S3Config.Endpoint) + assert.Equal(t, kms.SecretStatusSecretBox, folder.FsConfig.S3Config.AccessSecret.GetStatus()) + assert.NotEmpty(t, folder.FsConfig.S3Config.AccessSecret.GetPayload()) + assert.Empty(t, folder.FsConfig.S3Config.AccessSecret.GetKey()) + assert.Empty(t, folder.FsConfig.S3Config.AccessSecret.GetAdditionalData()) + + user.VirtualFolders[0].FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("updated secret") + user, resp, err := httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err, string(resp)) + userFolder = user.VirtualFolders[0].BaseVirtualFolder + assert.Equal(t, 100, userFolder.UsedQuotaFiles) + assert.Equal(t, int64(32768), userFolder.UsedQuotaSize) + assert.Greater(t, userFolder.LastQuotaUpdate, int64(0)) + assert.Empty(t, userFolder.Description) + assert.Equal(t, vfs.S3FilesystemProvider, userFolder.FsConfig.Provider) + assert.Equal(t, "test", userFolder.FsConfig.S3Config.Bucket) + assert.Equal(t, "us-east-1", userFolder.FsConfig.S3Config.Region) + assert.Equal(t, "http://127.0.1.1:9090", userFolder.FsConfig.S3Config.Endpoint) + assert.Equal(t, kms.SecretStatusSecretBox, userFolder.FsConfig.S3Config.AccessSecret.GetStatus()) + assert.NotEmpty(t, userFolder.FsConfig.S3Config.AccessSecret.GetPayload()) + assert.Empty(t, userFolder.FsConfig.S3Config.AccessSecret.GetKey()) + assert.Empty(t, userFolder.FsConfig.S3Config.AccessSecret.GetAdditionalData()) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: name}, http.StatusOK) + assert.NoError(t, err) +} + func TestUpdateFolderQuotaUsage(t *testing.T) { f := vfs.BaseVirtualFolder{ Name: "vdir", @@ -2201,7 +2455,7 @@ func TestCloseActiveConnection(t *testing.T) { _, err := httpdtest.CloseConnection("non_existent_id", http.StatusNotFound) assert.NoError(t, err) user := getTestUser() - c := common.NewBaseConnection("connID", common.ProtocolSFTP, user, nil) + c := common.NewBaseConnection("connID", common.ProtocolSFTP, user) fakeConn := &fakeConnection{ BaseConnection: c, } @@ -2214,12 +2468,12 @@ func TestCloseActiveConnection(t *testing.T) { func TestCloseConnectionAfterUserUpdateDelete(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) - c := common.NewBaseConnection("connID", common.ProtocolFTP, user, nil) + c := common.NewBaseConnection("connID", common.ProtocolFTP, user) fakeConn := &fakeConnection{ BaseConnection: c, } common.Connections.Add(fakeConn) - c1 := common.NewBaseConnection("connID1", common.ProtocolSFTP, user, nil) + c1 := common.NewBaseConnection("connID1", common.ProtocolSFTP, user) fakeConn1 := &fakeConnection{ BaseConnection: c1, } @@ -2312,7 +2566,7 @@ func TestUserBaseDir(t *testing.T) { u.HomeDir = "" user, _, err := httpdtest.AddUser(u, http.StatusCreated) if assert.Error(t, err) { - assert.EqualError(t, err, "HomeDir mismatch") + assert.EqualError(t, err, "home dir mismatch") } assert.Equal(t, filepath.Join(providerConf.UsersBaseDir, u.Username), user.HomeDir) _, err = httpdtest.RemoveUser(user, http.StatusOK) @@ -2461,17 +2715,27 @@ func TestFolders(t *testing.T) { Name: "name", MappedPath: "relative path", Users: []string{"1", "2", "3"}, + FsConfig: vfs.Filesystem{ + Provider: vfs.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret("asecret"), + }, + }, } _, _, err := httpdtest.AddFolder(folder, http.StatusBadRequest) assert.NoError(t, err) folder.MappedPath = filepath.Clean(os.TempDir()) folder1, resp, err := httpdtest.AddFolder(folder, http.StatusCreated) - assert.EqualError(t, err, "folder users mismatch", string(resp)) + assert.NoError(t, err, string(resp)) assert.Equal(t, folder.Name, folder1.Name) assert.Equal(t, folder.MappedPath, folder1.MappedPath) assert.Equal(t, 0, folder1.UsedQuotaFiles) assert.Equal(t, int64(0), folder1.UsedQuotaSize) assert.Equal(t, int64(0), folder1.LastQuotaUpdate) + assert.Equal(t, kms.SecretStatusSecretBox, folder1.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.NotEmpty(t, folder1.FsConfig.CryptConfig.Passphrase.GetPayload()) + assert.Empty(t, folder1.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.Empty(t, folder1.FsConfig.CryptConfig.Passphrase.GetKey()) assert.Len(t, folder1.Users, 0) // adding a duplicate folder must fail _, _, err = httpdtest.AddFolder(folder, http.StatusCreated) @@ -2482,7 +2746,7 @@ func TestFolders(t *testing.T) { folder.UsedQuotaSize = 345 folder.LastQuotaUpdate = 10 folder2, _, err := httpdtest.AddFolder(folder, http.StatusCreated) - assert.EqualError(t, err, "folder users mismatch", string(resp)) + assert.NoError(t, err, string(resp)) assert.Equal(t, 1, folder2.UsedQuotaFiles) assert.Equal(t, int64(345), folder2.UsedQuotaSize) assert.Equal(t, int64(10), folder2.LastQuotaUpdate) @@ -2491,6 +2755,19 @@ func TestFolders(t *testing.T) { assert.NoError(t, err) numResults := len(folders) assert.GreaterOrEqual(t, numResults, 2) + found := false + for _, f := range folders { + if f.Name == folder1.Name { + found = true + assert.Equal(t, folder1.MappedPath, f.MappedPath) + assert.Equal(t, kms.SecretStatusSecretBox, f.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.NotEmpty(t, f.FsConfig.CryptConfig.Passphrase.GetPayload()) + assert.Empty(t, f.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.Empty(t, f.FsConfig.CryptConfig.Passphrase.GetKey()) + assert.Len(t, f.Users, 0) + } + } + assert.True(t, found) folders, _, err = httpdtest.GetFolders(0, 1, http.StatusOK) assert.NoError(t, err) assert.Len(t, folders, numResults-1) @@ -2501,6 +2778,11 @@ func TestFolders(t *testing.T) { assert.NoError(t, err) assert.Equal(t, folder1.Name, f.Name) assert.Equal(t, folder1.MappedPath, f.MappedPath) + assert.Equal(t, kms.SecretStatusSecretBox, f.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.NotEmpty(t, f.FsConfig.CryptConfig.Passphrase.GetPayload()) + assert.Empty(t, f.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.Empty(t, f.FsConfig.CryptConfig.Passphrase.GetKey()) + assert.Len(t, f.Users, 0) f, _, err = httpdtest.GetFolderByName(folder2.Name, http.StatusOK) assert.NoError(t, err) assert.Equal(t, folder2.Name, f.Name) @@ -2516,8 +2798,8 @@ func TestFolders(t *testing.T) { assert.NoError(t, err) folder1.MappedPath = filepath.Join(os.TempDir(), "updated") folder1.Description = "updated folder description" - f, _, err = httpdtest.UpdateFolder(folder1, http.StatusOK) - assert.NoError(t, err) + f, resp, err = httpdtest.UpdateFolder(folder1, http.StatusOK) + assert.NoError(t, err, string(resp)) assert.Equal(t, folder1.MappedPath, f.MappedPath) assert.Equal(t, folder1.Description, f.Description) @@ -2535,14 +2817,14 @@ func TestDumpdata(t *testing.T) { providerConf := config.GetProviderConf() err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) - _, _, err = httpdtest.Dumpdata("", "", "", http.StatusBadRequest) - assert.NoError(t, err) + _, rawResp, err := httpdtest.Dumpdata("", "", "", http.StatusBadRequest) + assert.NoError(t, err, string(rawResp)) _, _, err = httpdtest.Dumpdata(filepath.Join(backupsPath, "backup.json"), "", "", http.StatusBadRequest) assert.NoError(t, err) - _, _, err = httpdtest.Dumpdata("../backup.json", "", "", http.StatusBadRequest) - assert.NoError(t, err) - _, _, err = httpdtest.Dumpdata("backup.json", "", "0", http.StatusOK) - assert.NoError(t, err) + _, rawResp, err = httpdtest.Dumpdata("../backup.json", "", "", http.StatusBadRequest) + assert.NoError(t, err, string(rawResp)) + _, rawResp, err = httpdtest.Dumpdata("backup.json", "", "0", http.StatusOK) + assert.NoError(t, err, string(rawResp)) response, _, err := httpdtest.Dumpdata("", "1", "0", http.StatusOK) assert.NoError(t, err) _, ok := response["admins"] @@ -2553,8 +2835,8 @@ func TestDumpdata(t *testing.T) { assert.True(t, ok) _, ok = response["version"] assert.True(t, ok) - _, _, err = httpdtest.Dumpdata("backup.json", "", "1", http.StatusOK) - assert.NoError(t, err) + _, rawResp, err = httpdtest.Dumpdata("backup.json", "", "1", http.StatusOK) + assert.NoError(t, err, string(rawResp)) err = os.Remove(filepath.Join(backupsPath, "backup.json")) assert.NoError(t, err) if runtime.GOOS != "windows" { @@ -2863,7 +3145,7 @@ func TestLoaddataMode(t *testing.T) { assert.Equal(t, int64(789), folder.LastQuotaUpdate) assert.Len(t, folder.Users, 0) - c := common.NewBaseConnection("connID", common.ProtocolFTP, user, nil) + c := common.NewBaseConnection("connID", common.ProtocolFTP, user) fakeConn := &fakeConnection{ BaseConnection: c, } @@ -4080,7 +4362,7 @@ func TestWebLoginMock(t *testing.T) { req.RemoteAddr = "10.9.9.8:1234" rr = executeRequest(req) checkResponseCode(t, http.StatusOK, rr) - assert.Contains(t, rr.Body.String(), "Unable to verify form token") + assert.Contains(t, rr.Body.String(), "unable to verify form token") req, _ = http.NewRequest(http.MethodGet, webLoginPath, nil) rr = executeRequest(req) @@ -4135,7 +4417,7 @@ func TestWebAdminPwdChange(t *testing.T) { setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) - assert.Contains(t, rr.Body.String(), "Unable to verify form token") + assert.Contains(t, rr.Body.String(), "unable to verify form token") form.Set(csrfFormToken, csrfToken) req, _ = http.NewRequest(http.MethodPost, webChangeAdminPwdPath, bytes.NewBuffer([]byte(form.Encode()))) @@ -4268,7 +4550,7 @@ func TestWebAdminBasicMock(t *testing.T) { setJWTCookieForReq(req, token) rr := executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) - assert.Contains(t, rr.Body.String(), "Unable to verify form token") + assert.Contains(t, rr.Body.String(), "unable to verify form token") form.Set(csrfFormToken, csrfToken) form.Set("status", "a") @@ -4322,7 +4604,7 @@ func TestWebAdminBasicMock(t *testing.T) { setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) - assert.Contains(t, rr.Body.String(), "Unable to verify form token") + assert.Contains(t, rr.Body.String(), "unable to verify form token") form.Set(csrfFormToken, csrfToken) form.Set("email", "not-an-email") @@ -4378,7 +4660,7 @@ func TestWebAdminBasicMock(t *testing.T) { setCSRFHeaderForReq(req, csrfToken) rr = executeRequest(req) checkResponseCode(t, http.StatusBadRequest, rr) - assert.Contains(t, rr.Body.String(), "You cannot delete yourself") + assert.Contains(t, rr.Body.String(), "you cannot delete yourself") req, _ = http.NewRequest(http.MethodDelete, path.Join(webAdminPath, defaultTokenAuthUser), nil) setJWTCookieForReq(req, token) @@ -4506,7 +4788,7 @@ func TestWebMaintenanceMock(t *testing.T) { req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) - assert.Contains(t, rr.Body.String(), "Unable to verify form token") + assert.Contains(t, rr.Body.String(), "unable to verify form token") form.Set(csrfFormToken, csrfToken) b, contentType, _ = getMultipartFormData(form, "", "") @@ -4760,7 +5042,7 @@ func TestWebUserAddMock(t *testing.T) { req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) - assert.Contains(t, rr.Body.String(), "Unable to verify form token") + assert.Contains(t, rr.Body.String(), "unable to verify form token") form.Set(csrfFormToken, csrfToken) b, contentType, _ = getMultipartFormData(form, "", "") @@ -4915,7 +5197,7 @@ func TestWebUserUpdateMock(t *testing.T) { req.Header.Set("Content-Type", contentType) rr = executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) - assert.Contains(t, rr.Body.String(), "Unable to verify form token") + assert.Contains(t, rr.Body.String(), "unable to verify form token") form.Set(csrfFormToken, csrfToken) b, contentType, _ = getMultipartFormData(form, "", "") @@ -5111,7 +5393,7 @@ func TestUserTemplateWithFoldersMock(t *testing.T) { req.Header.Set("Content-Type", contentType) rr := executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) - require.Contains(t, rr.Body.String(), "Unable to verify form token") + require.Contains(t, rr.Body.String(), "unable to verify form token") form.Set(csrfFormToken, csrfToken) b, contentType, _ = getMultipartFormData(form, "", "") @@ -5164,7 +5446,7 @@ func TestUserTemplateMock(t *testing.T) { token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) user := getTestUser() - user.FsConfig.Provider = dataprovider.S3FilesystemProvider + user.FsConfig.Provider = vfs.S3FilesystemProvider user.FsConfig.S3Config.Bucket = "test" user.FsConfig.S3Config.Region = "eu-central-1" user.FsConfig.S3Config.AccessKey = "%username%" @@ -5252,9 +5534,9 @@ func TestUserTemplateMock(t *testing.T) { user1 := dump.Users[0] user2 := dump.Users[1] require.Equal(t, "user1", user1.Username) - require.Equal(t, dataprovider.S3FilesystemProvider, user1.FsConfig.Provider) + require.Equal(t, vfs.S3FilesystemProvider, user1.FsConfig.Provider) require.Equal(t, "user2", user2.Username) - require.Equal(t, dataprovider.S3FilesystemProvider, user2.FsConfig.Provider) + require.Equal(t, vfs.S3FilesystemProvider, user2.FsConfig.Provider) require.Len(t, user2.PublicKeys, 1) require.Equal(t, filepath.Join(os.TempDir(), user1.Username), user1.HomeDir) require.Equal(t, filepath.Join(os.TempDir(), user2.Username), user2.HomeDir) @@ -5290,10 +5572,9 @@ func TestFolderTemplateMock(t *testing.T) { req.Header.Set("Content-Type", contentType) rr := executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) - assert.Contains(t, rr.Body.String(), "Unable to verify form token") + assert.Contains(t, rr.Body.String(), "unable to verify form token") form.Set(csrfFormToken, csrfToken) - req, _ = http.NewRequest(http.MethodPost, webTemplateFolder+"?param=p%C3%AO%GG", bytes.NewBuffer([]byte(form.Encode()))) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) @@ -5301,6 +5582,9 @@ func TestFolderTemplateMock(t *testing.T) { checkResponseCode(t, http.StatusBadRequest, rr) assert.Contains(t, rr.Body.String(), "Error parsing folders fields") + folder1 := "folder1" + folder2 := "folder2" + folder3 := "folder3" req, _ = http.NewRequest(http.MethodPost, webTemplateFolder, bytes.NewBuffer([]byte(form.Encode()))) setJWTCookieForReq(req, token) req.Header.Set("Content-Type", contentType) @@ -5313,16 +5597,63 @@ func TestFolderTemplateMock(t *testing.T) { require.Len(t, dump.Users, 0) require.Len(t, dump.Admins, 0) require.Len(t, dump.Folders, 3) - require.Equal(t, "folder1", dump.Folders[0].Name) + require.Equal(t, folder1, dump.Folders[0].Name) require.Equal(t, "desc folder folder1", dump.Folders[0].Description) require.True(t, strings.HasSuffix(dump.Folders[0].MappedPath, "folder1mappedfolder1path")) - require.Equal(t, "folder2", dump.Folders[1].Name) + require.Equal(t, folder2, dump.Folders[1].Name) require.Equal(t, "desc folder folder2", dump.Folders[1].Description) require.True(t, strings.HasSuffix(dump.Folders[1].MappedPath, "folder2mappedfolder2path")) - require.Equal(t, "folder3", dump.Folders[2].Name) + require.Equal(t, folder3, dump.Folders[2].Name) require.Equal(t, "desc folder folder3", dump.Folders[2].Description) require.True(t, strings.HasSuffix(dump.Folders[2].MappedPath, "folder3mappedfolder3path")) + form.Set("fs_provider", "1") + form.Set("s3_bucket", "bucket") + form.Set("s3_region", "us-east-1") + form.Set("s3_access_key", "%name%") + form.Set("s3_access_secret", "pwd%name%") + form.Set("s3_key_prefix", "base/%name%") + + req, _ = http.NewRequest(http.MethodPost, webTemplateFolder, bytes.NewBuffer([]byte(form.Encode()))) + setJWTCookieForReq(req, token) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "Error parsing folders fields") + + form.Set("s3_upload_part_size", "5") + form.Set("s3_upload_concurrency", "4") + req, _ = http.NewRequest(http.MethodPost, webTemplateFolder, bytes.NewBuffer([]byte(form.Encode()))) + setJWTCookieForReq(req, token) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + dump = dataprovider.BackupData{} + err = json.Unmarshal(rr.Body.Bytes(), &dump) + require.NoError(t, err) + require.Len(t, dump.Users, 0) + require.Len(t, dump.Admins, 0) + require.Len(t, dump.Folders, 3) + require.Equal(t, folder1, dump.Folders[0].Name) + require.Equal(t, folder1, dump.Folders[0].FsConfig.S3Config.AccessKey) + err = dump.Folders[0].FsConfig.S3Config.AccessSecret.Decrypt() + require.NoError(t, err) + require.Equal(t, "pwd"+folder1, dump.Folders[0].FsConfig.S3Config.AccessSecret.GetPayload()) + require.Equal(t, "base/"+folder1+"/", dump.Folders[0].FsConfig.S3Config.KeyPrefix) + require.Equal(t, folder2, dump.Folders[1].Name) + require.Equal(t, folder2, dump.Folders[1].FsConfig.S3Config.AccessKey) + err = dump.Folders[1].FsConfig.S3Config.AccessSecret.Decrypt() + require.NoError(t, err) + require.Equal(t, "pwd"+folder2, dump.Folders[1].FsConfig.S3Config.AccessSecret.GetPayload()) + require.Equal(t, "base/"+folder2+"/", dump.Folders[1].FsConfig.S3Config.KeyPrefix) + require.Equal(t, folder3, dump.Folders[2].Name) + require.Equal(t, folder3, dump.Folders[2].FsConfig.S3Config.AccessKey) + err = dump.Folders[2].FsConfig.S3Config.AccessSecret.Decrypt() + require.NoError(t, err) + require.Equal(t, "pwd"+folder3, dump.Folders[2].FsConfig.S3Config.AccessSecret.GetPayload()) + require.Equal(t, "base/"+folder3+"/", dump.Folders[2].FsConfig.S3Config.KeyPrefix) + form.Set("folders", "\n\n\n") req, _ = http.NewRequest(http.MethodPost, webTemplateFolder, bytes.NewBuffer([]byte(form.Encode()))) setJWTCookieForReq(req, token) @@ -5356,7 +5687,7 @@ func TestWebUserS3Mock(t *testing.T) { checkResponseCode(t, http.StatusCreated, rr) err = render.DecodeJSON(rr.Body, &user) assert.NoError(t, err) - user.FsConfig.Provider = dataprovider.S3FilesystemProvider + user.FsConfig.Provider = vfs.S3FilesystemProvider user.FsConfig.S3Config.Bucket = "test" user.FsConfig.S3Config.Region = "eu-west-1" user.FsConfig.S3Config.AccessKey = "access-key" @@ -5507,7 +5838,7 @@ func TestWebUserGCSMock(t *testing.T) { credentialsFilePath := filepath.Join(os.TempDir(), "gcs.json") err = createTestFile(credentialsFilePath, 0) assert.NoError(t, err) - user.FsConfig.Provider = dataprovider.GCSFilesystemProvider + user.FsConfig.Provider = vfs.GCSFilesystemProvider user.FsConfig.GCSConfig.Bucket = "test" user.FsConfig.GCSConfig.KeyPrefix = "somedir/subdir/" user.FsConfig.GCSConfig.StorageClass = "standard" @@ -5605,7 +5936,7 @@ func TestWebUserAzureBlobMock(t *testing.T) { checkResponseCode(t, http.StatusCreated, rr) err = render.DecodeJSON(rr.Body, &user) assert.NoError(t, err) - user.FsConfig.Provider = dataprovider.AzureBlobFilesystemProvider + user.FsConfig.Provider = vfs.AzureBlobFilesystemProvider user.FsConfig.AzBlobConfig.Container = "container" user.FsConfig.AzBlobConfig.AccountName = "aname" user.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("access-skey") @@ -5728,7 +6059,7 @@ func TestWebUserCryptMock(t *testing.T) { checkResponseCode(t, http.StatusCreated, rr) err = render.DecodeJSON(rr.Body, &user) assert.NoError(t, err) - user.FsConfig.Provider = dataprovider.CryptedFilesystemProvider + user.FsConfig.Provider = vfs.CryptedFilesystemProvider user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("crypted passphrase") form := make(url.Values) form.Set(csrfFormToken, csrfToken) @@ -5820,7 +6151,7 @@ func TestWebUserSFTPFsMock(t *testing.T) { checkResponseCode(t, http.StatusCreated, rr) err = render.DecodeJSON(rr.Body, &user) assert.NoError(t, err) - user.FsConfig.Provider = dataprovider.SFTPFilesystemProvider + user.FsConfig.Provider = vfs.SFTPFilesystemProvider user.FsConfig.SFTPConfig.Endpoint = "127.0.0.1:22" user.FsConfig.SFTPConfig.Username = "sftpuser" user.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("pwd") @@ -5944,7 +6275,7 @@ func TestAddWebFoldersMock(t *testing.T) { req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr := executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) - assert.Contains(t, rr.Body.String(), "Unable to verify form token") + assert.Contains(t, rr.Body.String(), "unable to verify form token") form.Set(csrfFormToken, csrfToken) req, err = http.NewRequest(http.MethodPost, webFolderPath, strings.NewReader(form.Encode())) @@ -5992,6 +6323,121 @@ func TestAddWebFoldersMock(t *testing.T) { checkResponseCode(t, http.StatusOK, rr) } +func TestS3WebFolderMock(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFToken() + assert.NoError(t, err) + mappedPath := filepath.Clean(os.TempDir()) + folderName := filepath.Base(mappedPath) + folderDesc := "a simple desc" + S3Bucket := "test" + S3Region := "eu-west-1" + S3AccessKey := "access-key" + S3AccessSecret := kms.NewPlainSecret("folder-access-secret") + S3Endpoint := "http://127.0.0.1:9000/path?b=c" + S3StorageClass := "Standard" + S3KeyPrefix := "somedir/subdir/" + S3UploadPartSize := 5 + S3UploadConcurrency := 4 + form := make(url.Values) + form.Set("mapped_path", mappedPath) + form.Set("name", folderName) + form.Set("description", folderDesc) + form.Set("fs_provider", "1") + form.Set("s3_bucket", S3Bucket) + form.Set("s3_region", S3Region) + form.Set("s3_access_key", S3AccessKey) + form.Set("s3_access_secret", S3AccessSecret.GetPayload()) + form.Set("s3_storage_class", S3StorageClass) + form.Set("s3_endpoint", S3Endpoint) + form.Set("s3_key_prefix", S3KeyPrefix) + form.Set("s3_upload_part_size", strconv.Itoa(S3UploadPartSize)) + form.Set("s3_upload_concurrency", "a") + form.Set(csrfFormToken, csrfToken) + req, err := http.NewRequest(http.MethodPost, webFolderPath, strings.NewReader(form.Encode())) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + form.Set("s3_upload_concurrency", strconv.Itoa(S3UploadConcurrency)) + req, err = http.NewRequest(http.MethodPost, webFolderPath, strings.NewReader(form.Encode())) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + + var folder vfs.BaseVirtualFolder + req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = render.DecodeJSON(rr.Body, &folder) + assert.NoError(t, err) + assert.Equal(t, mappedPath, folder.MappedPath) + assert.Equal(t, folderName, folder.Name) + assert.Equal(t, folderDesc, folder.Description) + assert.Equal(t, vfs.S3FilesystemProvider, folder.FsConfig.Provider) + assert.Equal(t, S3Bucket, folder.FsConfig.S3Config.Bucket) + assert.Equal(t, S3Region, folder.FsConfig.S3Config.Region) + assert.Equal(t, S3AccessKey, folder.FsConfig.S3Config.AccessKey) + assert.NotEmpty(t, folder.FsConfig.S3Config.AccessSecret.GetPayload()) + assert.Equal(t, S3Endpoint, folder.FsConfig.S3Config.Endpoint) + assert.Equal(t, S3StorageClass, folder.FsConfig.S3Config.StorageClass) + assert.Equal(t, S3KeyPrefix, folder.FsConfig.S3Config.KeyPrefix) + assert.Equal(t, S3UploadConcurrency, folder.FsConfig.S3Config.UploadConcurrency) + assert.Equal(t, int64(S3UploadPartSize), folder.FsConfig.S3Config.UploadPartSize) + // update + S3UploadConcurrency = 10 + form.Set("s3_upload_concurrency", "b") + req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), strings.NewReader(form.Encode())) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + form.Set("s3_upload_concurrency", strconv.Itoa(S3UploadConcurrency)) + req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), strings.NewReader(form.Encode())) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + + folder = vfs.BaseVirtualFolder{} + req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = render.DecodeJSON(rr.Body, &folder) + assert.NoError(t, err) + assert.Equal(t, mappedPath, folder.MappedPath) + assert.Equal(t, folderName, folder.Name) + assert.Equal(t, folderDesc, folder.Description) + assert.Equal(t, vfs.S3FilesystemProvider, folder.FsConfig.Provider) + assert.Equal(t, S3Bucket, folder.FsConfig.S3Config.Bucket) + assert.Equal(t, S3Region, folder.FsConfig.S3Config.Region) + assert.Equal(t, S3AccessKey, folder.FsConfig.S3Config.AccessKey) + assert.NotEmpty(t, folder.FsConfig.S3Config.AccessSecret.GetPayload()) + assert.Equal(t, S3Endpoint, folder.FsConfig.S3Config.Endpoint) + assert.Equal(t, S3StorageClass, folder.FsConfig.S3Config.StorageClass) + assert.Equal(t, S3KeyPrefix, folder.FsConfig.S3Config.KeyPrefix) + assert.Equal(t, S3UploadConcurrency, folder.FsConfig.S3Config.UploadConcurrency) + assert.Equal(t, int64(S3UploadPartSize), folder.FsConfig.S3Config.UploadPartSize) + + // cleanup + req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + func TestUpdateWebFolderMock(t *testing.T) { webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) assert.NoError(t, err) @@ -6020,7 +6466,7 @@ func TestUpdateWebFolderMock(t *testing.T) { req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr := executeRequest(req) checkResponseCode(t, http.StatusForbidden, rr) - assert.Contains(t, rr.Body.String(), "Unable to verify form token") + assert.Contains(t, rr.Body.String(), "unable to verify form token") form.Set(csrfFormToken, csrfToken) req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), strings.NewReader(form.Encode())) diff --git a/httpd/internal_test.go b/httpd/internal_test.go index dfbd2523..bfa8455a 100644 --- a/httpd/internal_test.go +++ b/httpd/internal_test.go @@ -299,7 +299,7 @@ func TestGCSWebInvalidFormFile(t *testing.T) { req.Header.Set("Content-Type", "application/x-www-form-urlencoded") err := req.ParseForm() assert.NoError(t, err) - _, err = getFsConfigFromUserPostFields(req) + _, err = getFsConfigFromPostFields(req) assert.EqualError(t, err, http.ErrNotMultipart.Error()) } @@ -373,7 +373,7 @@ func TestCSRFToken(t *testing.T) { // invalid token err := verifyCSRFToken("token") if assert.Error(t, err) { - assert.Contains(t, err.Error(), "Unable to verify form token") + assert.Contains(t, err.Error(), "unable to verify form token") } // bad audience claims := make(map[string]interface{}) @@ -403,7 +403,7 @@ func TestCSRFToken(t *testing.T) { rr = httptest.NewRecorder() fn.ServeHTTP(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) - assert.Contains(t, rr.Body.String(), "The token is not valid") + assert.Contains(t, rr.Body.String(), "the token is not valid") csrfTokenAuth = jwtauth.New("PS256", utils.GenerateRandomBytes(32), nil) tokenString = createCSRFToken() @@ -682,8 +682,8 @@ func TestQuotaScanInvalidFs(t *testing.T) { user := dataprovider.User{ Username: "test", HomeDir: os.TempDir(), - FsConfig: dataprovider.Filesystem{ - Provider: dataprovider.S3FilesystemProvider, + FsConfig: vfs.Filesystem{ + Provider: vfs.S3FilesystemProvider, }, } common.QuotaScans.AddUserQuotaScan(user.Username) @@ -754,6 +754,44 @@ func TestVerifyTLSConnection(t *testing.T) { certMgr = oldCertMgr } +func TestGetFolderFromTemplate(t *testing.T) { + folder := vfs.BaseVirtualFolder{ + MappedPath: "Folder%name%", + Description: "Folder %name% desc", + } + folderName := "folderTemplate" + folderTemplate := getFolderFromTemplate(folder, folderName) + require.Equal(t, folderName, folderTemplate.Name) + require.Equal(t, fmt.Sprintf("Folder%v", folderName), folderTemplate.MappedPath) + require.Equal(t, fmt.Sprintf("Folder %v desc", folderName), folderTemplate.Description) + + folder.FsConfig.Provider = vfs.CryptedFilesystemProvider + folder.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("%name%") + folderTemplate = getFolderFromTemplate(folder, folderName) + require.Equal(t, folderName, folderTemplate.FsConfig.CryptConfig.Passphrase.GetPayload()) + + folder.FsConfig.Provider = vfs.GCSFilesystemProvider + folder.FsConfig.GCSConfig.KeyPrefix = "prefix%name%/" + folderTemplate = getFolderFromTemplate(folder, folderName) + require.Equal(t, fmt.Sprintf("prefix%v/", folderName), folderTemplate.FsConfig.GCSConfig.KeyPrefix) + + folder.FsConfig.Provider = vfs.AzureBlobFilesystemProvider + folder.FsConfig.AzBlobConfig.KeyPrefix = "a%name%" + folder.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("pwd%name%") + folderTemplate = getFolderFromTemplate(folder, folderName) + require.Equal(t, "a"+folderName, folderTemplate.FsConfig.AzBlobConfig.KeyPrefix) + require.Equal(t, "pwd"+folderName, folderTemplate.FsConfig.AzBlobConfig.AccountKey.GetPayload()) + + folder.FsConfig.Provider = vfs.SFTPFilesystemProvider + folder.FsConfig.SFTPConfig.Prefix = "%name%" + folder.FsConfig.SFTPConfig.Username = "sftp_%name%" + folder.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("sftp%name%") + folderTemplate = getFolderFromTemplate(folder, folderName) + require.Equal(t, folderName, folderTemplate.FsConfig.SFTPConfig.Prefix) + require.Equal(t, "sftp_"+folderName, folderTemplate.FsConfig.SFTPConfig.Username) + require.Equal(t, "sftp"+folderName, folderTemplate.FsConfig.SFTPConfig.Password.GetPayload()) +} + func TestGetUserFromTemplate(t *testing.T) { user := dataprovider.User{ Status: 1, @@ -775,24 +813,24 @@ func TestGetUserFromTemplate(t *testing.T) { require.Len(t, userTemplate.VirtualFolders, 1) require.Equal(t, "Folder"+username, userTemplate.VirtualFolders[0].Name) - user.FsConfig.Provider = dataprovider.CryptedFilesystemProvider + user.FsConfig.Provider = vfs.CryptedFilesystemProvider user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("%password%") userTemplate = getUserFromTemplate(user, templateFields) require.Equal(t, password, userTemplate.FsConfig.CryptConfig.Passphrase.GetPayload()) - user.FsConfig.Provider = dataprovider.GCSFilesystemProvider + user.FsConfig.Provider = vfs.GCSFilesystemProvider user.FsConfig.GCSConfig.KeyPrefix = "%username%%password%" userTemplate = getUserFromTemplate(user, templateFields) require.Equal(t, username+password, userTemplate.FsConfig.GCSConfig.KeyPrefix) - user.FsConfig.Provider = dataprovider.AzureBlobFilesystemProvider + user.FsConfig.Provider = vfs.AzureBlobFilesystemProvider user.FsConfig.AzBlobConfig.KeyPrefix = "a%username%" user.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("pwd%password%%username%") userTemplate = getUserFromTemplate(user, templateFields) require.Equal(t, "a"+username, userTemplate.FsConfig.AzBlobConfig.KeyPrefix) require.Equal(t, "pwd"+password+username, userTemplate.FsConfig.AzBlobConfig.AccountKey.GetPayload()) - user.FsConfig.Provider = dataprovider.SFTPFilesystemProvider + user.FsConfig.Provider = vfs.SFTPFilesystemProvider user.FsConfig.SFTPConfig.Prefix = "%username%" user.FsConfig.SFTPConfig.Username = "sftp_%username%" user.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("sftp%password%") diff --git a/httpd/middleware.go b/httpd/middleware.go index 7395886d..cb8fec5c 100644 --- a/httpd/middleware.go +++ b/httpd/middleware.go @@ -134,7 +134,7 @@ func verifyCSRFHeader(next http.Handler) http.Handler { if !utils.IsStringInSlice(tokenAudienceCSRF, token.Audience()) { logger.Debug(logSender, "", "error validating CSRF header audience") - sendAPIResponse(w, r, errors.New("The token is not valid"), "", http.StatusForbidden) + sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden) return } diff --git a/httpd/schema/openapi.yaml b/httpd/schema/openapi.yaml index 87a43891..a4bf2c58 100644 --- a/httpd/schema/openapi.yaml +++ b/httpd/schema/openapi.yaml @@ -12,7 +12,7 @@ tags: info: title: SFTPGo description: SFTPGo REST API - version: 2.0.3 + version: 2.0.4 contact: url: 'https://github.com/drakkan/sftpgo' license: @@ -483,7 +483,7 @@ paths: tags: - quota summary: Start folder quota scan - description: Starts a new quota scan for the given user. A quota scan update the number of files and their total size for the specified folder + description: Starts a new quota scan for the given folder. A quota scan update the number of files and their total size for the specified folder operationId: start_folder_quota_scan requestBody: required: true @@ -1112,7 +1112,7 @@ paths: - 0 - 1 description: | - output_data: + output data: * `0` or any other value != 1, the backup will be saved to a file on the server, `output_file` is required * `1` the backup will be returned as response body - in: query @@ -1336,6 +1336,19 @@ components: - manage_system - manage_defender - view_defender + description: | + Admin permissions: + * `*` - all permissions are granted + * `add_users` - add new users is allowed + * `edit_users` - change existing users is allowed + * `del_users` - remove users is allowed + * `view_users` - list users is allowed + * `view_conns` - list active connections is allowed + * `close_conns` - close active connections is allowed + * `view_status` - view the server status is allowed + * `manage_admins` - manage other admins is allowed + * `manage_defender` - remove ip from the dynamic blocklist is allowed + * `view_defender` - list the dynamic blocklist is allowed LoginMethods: type: string enum: @@ -1347,13 +1360,25 @@ components: - TLSCertificate - TLSCertificate+password description: | - To enable multi-step authentication you have to allow only multi-step login methods + Available login methods. To enable multi-step authentication you have to allow only multi-step login methods + * `publickey` + * `password` + * `keyboard-interactive` + * `publickey+password` - multi-step auth: public key and password + * `publickey+keyboard-interactive` - multi-step auth: public key and keyboard interactive + * `TLSCertificate` + * `TLSCertificate+password` - multi-step auth: TLS client certificate and password SupportedProtocols: type: string enum: - SSH - FTP - DAV + description: | + Protocols: + * `SSH` - includes both SFTP and SSH commands + * `FTP` - plain FTP and FTPES/FTPS + * `DAV` - WebDAV over HTTP/HTTPS PatternsFilter: type: object properties: @@ -1654,7 +1679,9 @@ components: items: type: string description: list of usernames associated with this virtual folder - description: defines the path for the virtual folder and the used quota limits. The same folder can be shared among multiple users and each user can have different quota limits or a different virtual path. + filesystem: + $ref: '#/components/schemas/FilesystemConfig' + description: Defines the filesystem for the virtual folder and the used quota limits. The same folder can be shared among multiple users and each user can have different quota limits or a different virtual path. VirtualFolder: allOf: - $ref: '#/components/schemas/BaseVirtualFolder' @@ -1837,6 +1864,10 @@ components: enum: - upload - download + description: | + Operations: + * `upload` + * `download` path: type: string description: file path for the upload/download @@ -1899,9 +1930,9 @@ components: FolderQuotaScan: type: object properties: - mapped_path: + name: type: string - description: path with an active scan + description: folder name with an active scan start_time: type: integer format: int64 diff --git a/httpd/web.go b/httpd/web.go index 9a8f02b5..c28b502b 100644 --- a/httpd/web.go +++ b/httpd/web.go @@ -41,6 +41,7 @@ const ( const ( templateBase = "base.html" + templateFsConfig = "fsconfig.html" templateUsers = "users.html" templateUser = "user.html" templateAdmins = "admins.html" @@ -196,6 +197,7 @@ func loadTemplates(templatesPath string) { } userPaths := []string{ filepath.Join(templatesPath, templateBase), + filepath.Join(templatesPath, templateFsConfig), filepath.Join(templatesPath, templateUser), } adminsPaths := []string{ @@ -224,6 +226,7 @@ func loadTemplates(templatesPath string) { } folderPath := []string{ filepath.Join(templatesPath, templateBase), + filepath.Join(templatesPath, templateFsConfig), filepath.Join(templatesPath, templateFolder), } statusPath := []string{ @@ -392,6 +395,7 @@ func renderUserPage(w http.ResponseWriter, r *http.Request, user *dataprovider.U if user.Password != "" && user.IsPasswordHashed() && mode == userPageModeUpdate { user.Password = redactedSecret } + user.FsConfig.RedactedSecret = redactedSecret data := userPage{ basePage: getBasePageData(title, currentURL, r), Mode: mode, @@ -401,7 +405,6 @@ func renderUserPage(w http.ResponseWriter, r *http.Request, user *dataprovider.U ValidLoginMethods: dataprovider.ValidLoginMethods, ValidProtocols: dataprovider.ValidProtocols, RootDirPerms: user.GetPermissionsForPath("/"), - RedactedSecret: redactedSecret, } renderTemplate(w, templateUser, data) } @@ -419,6 +422,9 @@ func renderFolderPage(w http.ResponseWriter, r *http.Request, folder vfs.BaseVir title = "Folder template" currentURL = webTemplateFolder } + folder.FsConfig.RedactedSecret = redactedSecret + folder.FsConfig.SetEmptySecretsIfNil() + data := folderPage{ basePage: getBasePageData(title, currentURL, r), Error: error, @@ -753,35 +759,35 @@ func getAzureConfig(r *http.Request) (vfs.AzBlobFsConfig, error) { return config, err } -func getFsConfigFromUserPostFields(r *http.Request) (dataprovider.Filesystem, error) { - var fs dataprovider.Filesystem +func getFsConfigFromPostFields(r *http.Request) (vfs.Filesystem, error) { + var fs vfs.Filesystem provider, err := strconv.Atoi(r.Form.Get("fs_provider")) if err != nil { - provider = int(dataprovider.LocalFilesystemProvider) + provider = int(vfs.LocalFilesystemProvider) } - fs.Provider = dataprovider.FilesystemProvider(provider) + fs.Provider = vfs.FilesystemProvider(provider) switch fs.Provider { - case dataprovider.S3FilesystemProvider: + case vfs.S3FilesystemProvider: config, err := getS3Config(r) if err != nil { return fs, err } fs.S3Config = config - case dataprovider.AzureBlobFilesystemProvider: + case vfs.AzureBlobFilesystemProvider: config, err := getAzureConfig(r) if err != nil { return fs, err } fs.AzBlobConfig = config - case dataprovider.GCSFilesystemProvider: + case vfs.GCSFilesystemProvider: config, err := getGCSConfig(r) if err != nil { return fs, err } fs.GCSConfig = config - case dataprovider.CryptedFilesystemProvider: + case vfs.CryptedFilesystemProvider: fs.CryptConfig.Passphrase = getSecretFromFormField(r, "crypt_passphrase") - case dataprovider.SFTPFilesystemProvider: + case vfs.SFTPFilesystemProvider: fs.SFTPConfig = getSFTPConfig(r) } return fs, nil @@ -822,6 +828,18 @@ func getFolderFromTemplate(folder vfs.BaseVirtualFolder, name string) vfs.BaseVi folder.MappedPath = replacePlaceholders(folder.MappedPath, replacements) folder.Description = replacePlaceholders(folder.Description, replacements) + switch folder.FsConfig.Provider { + case vfs.CryptedFilesystemProvider: + folder.FsConfig.CryptConfig = getCryptFsFromTemplate(folder.FsConfig.CryptConfig, replacements) + case vfs.S3FilesystemProvider: + folder.FsConfig.S3Config = getS3FsFromTemplate(folder.FsConfig.S3Config, replacements) + case vfs.GCSFilesystemProvider: + folder.FsConfig.GCSConfig = getGCSFsFromTemplate(folder.FsConfig.GCSConfig, replacements) + case vfs.AzureBlobFilesystemProvider: + folder.FsConfig.AzBlobConfig = getAzBlobFsFromTemplate(folder.FsConfig.AzBlobConfig, replacements) + case vfs.SFTPFilesystemProvider: + folder.FsConfig.SFTPConfig = getSFTPFsFromTemplate(folder.FsConfig.SFTPConfig, replacements) + } return folder } @@ -895,15 +913,15 @@ func getUserFromTemplate(user dataprovider.User, template userTemplateFields) da user.AdditionalInfo = replacePlaceholders(user.AdditionalInfo, replacements) switch user.FsConfig.Provider { - case dataprovider.CryptedFilesystemProvider: + case vfs.CryptedFilesystemProvider: user.FsConfig.CryptConfig = getCryptFsFromTemplate(user.FsConfig.CryptConfig, replacements) - case dataprovider.S3FilesystemProvider: + case vfs.S3FilesystemProvider: user.FsConfig.S3Config = getS3FsFromTemplate(user.FsConfig.S3Config, replacements) - case dataprovider.GCSFilesystemProvider: + case vfs.GCSFilesystemProvider: user.FsConfig.GCSConfig = getGCSFsFromTemplate(user.FsConfig.GCSConfig, replacements) - case dataprovider.AzureBlobFilesystemProvider: + case vfs.AzureBlobFilesystemProvider: user.FsConfig.AzBlobConfig = getAzBlobFsFromTemplate(user.FsConfig.AzBlobConfig, replacements) - case dataprovider.SFTPFilesystemProvider: + case vfs.SFTPFilesystemProvider: user.FsConfig.SFTPConfig = getSFTPFsFromTemplate(user.FsConfig.SFTPConfig, replacements) } @@ -959,7 +977,7 @@ func getUserFromPostFields(r *http.Request) (dataprovider.User, error) { } expirationDateMillis = utils.GetTimeAsMsSinceEpoch(expirationDate) } - fsConfig, err := getFsConfigFromUserPostFields(r) + fsConfig, err := getFsConfigFromPostFields(r) if err != nil { return user, err } @@ -1257,6 +1275,12 @@ func handleWebTemplateFolderPost(w http.ResponseWriter, r *http.Request) { templateFolder.MappedPath = r.Form.Get("mapped_path") templateFolder.Description = r.Form.Get("description") + fsConfig, err := getFsConfigFromPostFields(r) + if err != nil { + renderMessagePage(w, r, "Error parsing folders fields", "", http.StatusBadRequest, err, "") + return + } + templateFolder.FsConfig = fsConfig var dump dataprovider.BackupData dump.Version = dataprovider.DumpVersion @@ -1415,7 +1439,7 @@ func handleWebUpdateUserPost(w http.ResponseWriter, r *http.Request) { if updatedUser.Password == redactedSecret { updatedUser.Password = user.Password } - updateEncryptedSecrets(&updatedUser, user.FsConfig.S3Config.AccessSecret, user.FsConfig.AzBlobConfig.AccountKey, + updateEncryptedSecrets(&updatedUser.FsConfig, user.FsConfig.S3Config.AccessSecret, user.FsConfig.AzBlobConfig.AccountKey, user.FsConfig.GCSConfig.Credentials, user.FsConfig.CryptConfig.Passphrase, user.FsConfig.SFTPConfig.Password, user.FsConfig.SFTPConfig.PrivateKey) @@ -1466,6 +1490,12 @@ func handleWebAddFolderPost(w http.ResponseWriter, r *http.Request) { folder.MappedPath = r.Form.Get("mapped_path") folder.Name = r.Form.Get("name") folder.Description = r.Form.Get("description") + fsConfig, err := getFsConfigFromPostFields(r) + if err != nil { + renderFolderPage(w, r, folder, folderPageModeAdd, err.Error()) + return + } + folder.FsConfig = fsConfig err = dataprovider.AddFolder(&folder) if err == nil { @@ -1508,9 +1538,24 @@ func handleWebUpdateFolderPost(w http.ResponseWriter, r *http.Request) { renderForbiddenPage(w, r, err.Error()) return } - folder.MappedPath = r.Form.Get("mapped_path") - folder.Description = r.Form.Get("description") - err = dataprovider.UpdateFolder(&folder) + fsConfig, err := getFsConfigFromPostFields(r) + if err != nil { + renderFolderPage(w, r, folder, folderPageModeUpdate, err.Error()) + return + } + updatedFolder := &vfs.BaseVirtualFolder{ + MappedPath: r.Form.Get("mapped_path"), + Description: r.Form.Get("description"), + } + updatedFolder.ID = folder.ID + updatedFolder.Name = folder.Name + updatedFolder.FsConfig = fsConfig + updatedFolder.FsConfig.SetEmptySecretsIfNil() + updateEncryptedSecrets(&updatedFolder.FsConfig, folder.FsConfig.S3Config.AccessSecret, folder.FsConfig.AzBlobConfig.AccountKey, + folder.FsConfig.GCSConfig.Credentials, folder.FsConfig.CryptConfig.Passphrase, folder.FsConfig.SFTPConfig.Password, + folder.FsConfig.SFTPConfig.PrivateKey) + + err = dataprovider.UpdateFolder(updatedFolder, folder.Users) if err != nil { renderFolderPage(w, r, folder, folderPageModeUpdate, err.Error()) return diff --git a/httpdtest/httpdtest.go b/httpdtest/httpdtest.go index e7fe2fe0..3bc91a87 100644 --- a/httpdtest/httpdtest.go +++ b/httpdtest/httpdtest.go @@ -10,7 +10,6 @@ import ( "net/http" "net/url" "path" - "path/filepath" "strconv" "strings" @@ -835,32 +834,18 @@ func checkFolder(expected *vfs.BaseVirtualFolder, actual *vfs.BaseVirtualFolder) if expected.MappedPath != actual.MappedPath { return errors.New("mapped path mismatch") } - if expected.LastQuotaUpdate != actual.LastQuotaUpdate { - return errors.New("last quota update mismatch") - } - if expected.UsedQuotaSize != actual.UsedQuotaSize { - return errors.New("used quota size mismatch") - } - if expected.UsedQuotaFiles != actual.UsedQuotaFiles { - return errors.New("used quota files mismatch") - } if expected.Description != actual.Description { - return errors.New("Description mismatch") + return errors.New("description mismatch") } - if len(expected.Users) != len(actual.Users) { - return errors.New("folder users mismatch") - } - for _, u := range actual.Users { - if !utils.IsStringInSlice(u, expected.Users) { - return errors.New("folder users mismatch") - } + if err := compareFsConfig(&expected.FsConfig, &actual.FsConfig); err != nil { + return err } return nil } func checkAdmin(expected *dataprovider.Admin, actual *dataprovider.Admin) error { if actual.Password != "" { - return errors.New("Admin password must not be visible") + return errors.New("admin password must not be visible") } if expected.ID <= 0 { if actual.ID <= 0 { @@ -875,19 +860,19 @@ func checkAdmin(expected *dataprovider.Admin, actual *dataprovider.Admin) error return err } if len(expected.Permissions) != len(actual.Permissions) { - return errors.New("Permissions mismatch") + return errors.New("permissions mismatch") } for _, p := range expected.Permissions { if !utils.IsStringInSlice(p, actual.Permissions) { - return errors.New("Permissions content mismatch") + return errors.New("permissions content mismatch") } } if len(expected.Filters.AllowList) != len(actual.Filters.AllowList) { - return errors.New("AllowList mismatch") + return errors.New("allow list mismatch") } for _, v := range expected.Filters.AllowList { if !utils.IsStringInSlice(v, actual.Filters.AllowList) { - return errors.New("AllowList content mismatch") + return errors.New("allow list content mismatch") } } @@ -896,26 +881,26 @@ func checkAdmin(expected *dataprovider.Admin, actual *dataprovider.Admin) error func compareAdminEqualFields(expected *dataprovider.Admin, actual *dataprovider.Admin) error { if expected.Username != actual.Username { - return errors.New("Username mismatch") + return errors.New("sername mismatch") } if expected.Email != actual.Email { - return errors.New("Email mismatch") + return errors.New("email mismatch") } if expected.Status != actual.Status { - return errors.New("Status mismatch") + return errors.New("status mismatch") } if expected.Description != actual.Description { - return errors.New("Description mismatch") + return errors.New("description mismatch") } if expected.AdditionalInfo != actual.AdditionalInfo { - return errors.New("AdditionalInfo mismatch") + return errors.New("additional info mismatch") } return nil } func checkUser(expected *dataprovider.User, actual *dataprovider.User) error { if actual.Password != "" { - return errors.New("User password must not be visible") + return errors.New("user password must not be visible") } if expected.ID <= 0 { if actual.ID <= 0 { @@ -927,23 +912,23 @@ func checkUser(expected *dataprovider.User, actual *dataprovider.User) error { } } if len(expected.Permissions) != len(actual.Permissions) { - return errors.New("Permissions mismatch") + return errors.New("permissions mismatch") } for dir, perms := range expected.Permissions { if actualPerms, ok := actual.Permissions[dir]; ok { for _, v := range actualPerms { if !utils.IsStringInSlice(v, perms) { - return errors.New("Permissions contents mismatch") + return errors.New("permissions contents mismatch") } } } else { - return errors.New("Permissions directories mismatch") + return errors.New("permissions directories mismatch") } } if err := compareUserFilters(expected, actual); err != nil { return err } - if err := compareUserFsConfig(expected, actual); err != nil { + if err := compareFsConfig(&expected.FsConfig, &actual.FsConfig); err != nil { return err } if err := compareUserVirtualFolders(expected, actual); err != nil { @@ -954,27 +939,35 @@ func checkUser(expected *dataprovider.User, actual *dataprovider.User) error { func compareUserVirtualFolders(expected *dataprovider.User, actual *dataprovider.User) error { if len(actual.VirtualFolders) != len(expected.VirtualFolders) { - return errors.New("Virtual folders mismatch") + return errors.New("virtual folders len mismatch") } for _, v := range actual.VirtualFolders { found := false for _, v1 := range expected.VirtualFolders { - if path.Clean(v.VirtualPath) == path.Clean(v1.VirtualPath) && - filepath.Clean(v.MappedPath) == filepath.Clean(v1.MappedPath) { + if path.Clean(v.VirtualPath) == path.Clean(v1.VirtualPath) { + if err := checkFolder(&v1.BaseVirtualFolder, &v.BaseVirtualFolder); err != nil { + return err + } + if v.QuotaSize != v1.QuotaSize { + return errors.New("vfolder quota size mismatch") + } + if (v.QuotaFiles) != (v1.QuotaFiles) { + return errors.New("vfolder quota files mismatch") + } found = true break } } if !found { - return errors.New("Virtual folders mismatch") + return errors.New("virtual folders mismatch") } } return nil } -func compareUserFsConfig(expected *dataprovider.User, actual *dataprovider.User) error { - if expected.FsConfig.Provider != actual.FsConfig.Provider { - return errors.New("Fs provider mismatch") +func compareFsConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { + if expected.Provider != actual.Provider { + return errors.New("fs provider mismatch") } if err := compareS3Config(expected, actual); err != nil { return err @@ -985,7 +978,7 @@ func compareUserFsConfig(expected *dataprovider.User, actual *dataprovider.User) if err := compareAzBlobConfig(expected, actual); err != nil { return err } - if err := checkEncryptedSecret(expected.FsConfig.CryptConfig.Passphrase, actual.FsConfig.CryptConfig.Passphrase); err != nil { + if err := checkEncryptedSecret(expected.CryptConfig.Passphrase, actual.CryptConfig.Passphrase); err != nil { return err } if err := compareSFTPFsConfig(expected, actual); err != nil { @@ -994,118 +987,118 @@ func compareUserFsConfig(expected *dataprovider.User, actual *dataprovider.User) return nil } -func compareS3Config(expected *dataprovider.User, actual *dataprovider.User) error { - if expected.FsConfig.S3Config.Bucket != actual.FsConfig.S3Config.Bucket { - return errors.New("S3 bucket mismatch") +func compareS3Config(expected *vfs.Filesystem, actual *vfs.Filesystem) error { + if expected.S3Config.Bucket != actual.S3Config.Bucket { + return errors.New("fs S3 bucket mismatch") } - if expected.FsConfig.S3Config.Region != actual.FsConfig.S3Config.Region { - return errors.New("S3 region mismatch") + if expected.S3Config.Region != actual.S3Config.Region { + return errors.New("fs S3 region mismatch") } - if expected.FsConfig.S3Config.AccessKey != actual.FsConfig.S3Config.AccessKey { - return errors.New("S3 access key mismatch") + if expected.S3Config.AccessKey != actual.S3Config.AccessKey { + return errors.New("fs S3 access key mismatch") } - if err := checkEncryptedSecret(expected.FsConfig.S3Config.AccessSecret, actual.FsConfig.S3Config.AccessSecret); err != nil { - return fmt.Errorf("S3 access secret mismatch: %v", err) + if err := checkEncryptedSecret(expected.S3Config.AccessSecret, actual.S3Config.AccessSecret); err != nil { + return fmt.Errorf("fs S3 access secret mismatch: %v", err) } - if expected.FsConfig.S3Config.Endpoint != actual.FsConfig.S3Config.Endpoint { - return errors.New("S3 endpoint mismatch") + if expected.S3Config.Endpoint != actual.S3Config.Endpoint { + return errors.New("fs S3 endpoint mismatch") } - if expected.FsConfig.S3Config.StorageClass != actual.FsConfig.S3Config.StorageClass { - return errors.New("S3 storage class mismatch") + if expected.S3Config.StorageClass != actual.S3Config.StorageClass { + return errors.New("fs S3 storage class mismatch") } - if expected.FsConfig.S3Config.UploadPartSize != actual.FsConfig.S3Config.UploadPartSize { - return errors.New("S3 upload part size mismatch") + if expected.S3Config.UploadPartSize != actual.S3Config.UploadPartSize { + return errors.New("fs S3 upload part size mismatch") } - if expected.FsConfig.S3Config.UploadConcurrency != actual.FsConfig.S3Config.UploadConcurrency { - return errors.New("S3 upload concurrency mismatch") + if expected.S3Config.UploadConcurrency != actual.S3Config.UploadConcurrency { + return errors.New("fs S3 upload concurrency mismatch") } - if expected.FsConfig.S3Config.KeyPrefix != actual.FsConfig.S3Config.KeyPrefix && - expected.FsConfig.S3Config.KeyPrefix+"/" != actual.FsConfig.S3Config.KeyPrefix { - return errors.New("S3 key prefix mismatch") + if expected.S3Config.KeyPrefix != actual.S3Config.KeyPrefix && + expected.S3Config.KeyPrefix+"/" != actual.S3Config.KeyPrefix { + return errors.New("fs S3 key prefix mismatch") } return nil } -func compareGCSConfig(expected *dataprovider.User, actual *dataprovider.User) error { - if expected.FsConfig.GCSConfig.Bucket != actual.FsConfig.GCSConfig.Bucket { +func compareGCSConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { + if expected.GCSConfig.Bucket != actual.GCSConfig.Bucket { return errors.New("GCS bucket mismatch") } - if expected.FsConfig.GCSConfig.StorageClass != actual.FsConfig.GCSConfig.StorageClass { + if expected.GCSConfig.StorageClass != actual.GCSConfig.StorageClass { return errors.New("GCS storage class mismatch") } - if expected.FsConfig.GCSConfig.KeyPrefix != actual.FsConfig.GCSConfig.KeyPrefix && - expected.FsConfig.GCSConfig.KeyPrefix+"/" != actual.FsConfig.GCSConfig.KeyPrefix { + if expected.GCSConfig.KeyPrefix != actual.GCSConfig.KeyPrefix && + expected.GCSConfig.KeyPrefix+"/" != actual.GCSConfig.KeyPrefix { return errors.New("GCS key prefix mismatch") } - if expected.FsConfig.GCSConfig.AutomaticCredentials != actual.FsConfig.GCSConfig.AutomaticCredentials { + if expected.GCSConfig.AutomaticCredentials != actual.GCSConfig.AutomaticCredentials { return errors.New("GCS automatic credentials mismatch") } return nil } -func compareSFTPFsConfig(expected *dataprovider.User, actual *dataprovider.User) error { - if expected.FsConfig.SFTPConfig.Endpoint != actual.FsConfig.SFTPConfig.Endpoint { +func compareSFTPFsConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { + if expected.SFTPConfig.Endpoint != actual.SFTPConfig.Endpoint { return errors.New("SFTPFs endpoint mismatch") } - if expected.FsConfig.SFTPConfig.Username != actual.FsConfig.SFTPConfig.Username { + if expected.SFTPConfig.Username != actual.SFTPConfig.Username { return errors.New("SFTPFs username mismatch") } - if expected.FsConfig.SFTPConfig.DisableCouncurrentReads != actual.FsConfig.SFTPConfig.DisableCouncurrentReads { + if expected.SFTPConfig.DisableCouncurrentReads != actual.SFTPConfig.DisableCouncurrentReads { return errors.New("SFTPFs disable_concurrent_reads mismatch") } - if err := checkEncryptedSecret(expected.FsConfig.SFTPConfig.Password, actual.FsConfig.SFTPConfig.Password); err != nil { + if err := checkEncryptedSecret(expected.SFTPConfig.Password, actual.SFTPConfig.Password); err != nil { return fmt.Errorf("SFTPFs password mismatch: %v", err) } - if err := checkEncryptedSecret(expected.FsConfig.SFTPConfig.PrivateKey, actual.FsConfig.SFTPConfig.PrivateKey); err != nil { + if err := checkEncryptedSecret(expected.SFTPConfig.PrivateKey, actual.SFTPConfig.PrivateKey); err != nil { return fmt.Errorf("SFTPFs private key mismatch: %v", err) } - if expected.FsConfig.SFTPConfig.Prefix != actual.FsConfig.SFTPConfig.Prefix { - if expected.FsConfig.SFTPConfig.Prefix != "" && actual.FsConfig.SFTPConfig.Prefix != "/" { + if expected.SFTPConfig.Prefix != actual.SFTPConfig.Prefix { + if expected.SFTPConfig.Prefix != "" && actual.SFTPConfig.Prefix != "/" { return errors.New("SFTPFs prefix mismatch") } } - if len(expected.FsConfig.SFTPConfig.Fingerprints) != len(actual.FsConfig.SFTPConfig.Fingerprints) { + if len(expected.SFTPConfig.Fingerprints) != len(actual.SFTPConfig.Fingerprints) { return errors.New("SFTPFs fingerprints mismatch") } - for _, value := range actual.FsConfig.SFTPConfig.Fingerprints { - if !utils.IsStringInSlice(value, expected.FsConfig.SFTPConfig.Fingerprints) { + for _, value := range actual.SFTPConfig.Fingerprints { + if !utils.IsStringInSlice(value, expected.SFTPConfig.Fingerprints) { return errors.New("SFTPFs fingerprints mismatch") } } return nil } -func compareAzBlobConfig(expected *dataprovider.User, actual *dataprovider.User) error { - if expected.FsConfig.AzBlobConfig.Container != actual.FsConfig.AzBlobConfig.Container { - return errors.New("Azure Blob container mismatch") +func compareAzBlobConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { + if expected.AzBlobConfig.Container != actual.AzBlobConfig.Container { + return errors.New("azure Blob container mismatch") } - if expected.FsConfig.AzBlobConfig.AccountName != actual.FsConfig.AzBlobConfig.AccountName { - return errors.New("Azure Blob account name mismatch") + if expected.AzBlobConfig.AccountName != actual.AzBlobConfig.AccountName { + return errors.New("azure Blob account name mismatch") } - if err := checkEncryptedSecret(expected.FsConfig.AzBlobConfig.AccountKey, actual.FsConfig.AzBlobConfig.AccountKey); err != nil { - return fmt.Errorf("Azure Blob account key mismatch: %v", err) + if err := checkEncryptedSecret(expected.AzBlobConfig.AccountKey, actual.AzBlobConfig.AccountKey); err != nil { + return fmt.Errorf("azure Blob account key mismatch: %v", err) } - if expected.FsConfig.AzBlobConfig.Endpoint != actual.FsConfig.AzBlobConfig.Endpoint { - return errors.New("Azure Blob endpoint mismatch") + if expected.AzBlobConfig.Endpoint != actual.AzBlobConfig.Endpoint { + return errors.New("azure Blob endpoint mismatch") } - if expected.FsConfig.AzBlobConfig.SASURL != actual.FsConfig.AzBlobConfig.SASURL { - return errors.New("Azure Blob SASL URL mismatch") + if expected.AzBlobConfig.SASURL != actual.AzBlobConfig.SASURL { + return errors.New("azure Blob SASL URL mismatch") } - if expected.FsConfig.AzBlobConfig.UploadPartSize != actual.FsConfig.AzBlobConfig.UploadPartSize { - return errors.New("Azure Blob upload part size mismatch") + if expected.AzBlobConfig.UploadPartSize != actual.AzBlobConfig.UploadPartSize { + return errors.New("azure Blob upload part size mismatch") } - if expected.FsConfig.AzBlobConfig.UploadConcurrency != actual.FsConfig.AzBlobConfig.UploadConcurrency { - return errors.New("Azure Blob upload concurrency mismatch") + if expected.AzBlobConfig.UploadConcurrency != actual.AzBlobConfig.UploadConcurrency { + return errors.New("azure Blob upload concurrency mismatch") } - if expected.FsConfig.AzBlobConfig.KeyPrefix != actual.FsConfig.AzBlobConfig.KeyPrefix && - expected.FsConfig.AzBlobConfig.KeyPrefix+"/" != actual.FsConfig.AzBlobConfig.KeyPrefix { - return errors.New("Azure Blob key prefix mismatch") + if expected.AzBlobConfig.KeyPrefix != actual.AzBlobConfig.KeyPrefix && + expected.AzBlobConfig.KeyPrefix+"/" != actual.AzBlobConfig.KeyPrefix { + return errors.New("azure Blob key prefix mismatch") } - if expected.FsConfig.AzBlobConfig.UseEmulator != actual.FsConfig.AzBlobConfig.UseEmulator { - return errors.New("Azure Blob use emulator mismatch") + if expected.AzBlobConfig.UseEmulator != actual.AzBlobConfig.UseEmulator { + return errors.New("azure Blob use emulator mismatch") } - if expected.FsConfig.AzBlobConfig.AccessTier != actual.FsConfig.AzBlobConfig.AccessTier { - return errors.New("Azure Blob access tier mismatch") + if expected.AzBlobConfig.AccessTier != actual.AzBlobConfig.AccessTier { + return errors.New("azure Blob access tier mismatch") } return nil } @@ -1154,22 +1147,22 @@ func checkEncryptedSecret(expected, actual *kms.Secret) error { func compareUserFilterSubStructs(expected *dataprovider.User, actual *dataprovider.User) error { for _, IPMask := range expected.Filters.AllowedIP { if !utils.IsStringInSlice(IPMask, actual.Filters.AllowedIP) { - return errors.New("AllowedIP contents mismatch") + return errors.New("allowed IP contents mismatch") } } for _, IPMask := range expected.Filters.DeniedIP { if !utils.IsStringInSlice(IPMask, actual.Filters.DeniedIP) { - return errors.New("DeniedIP contents mismatch") + return errors.New("denied IP contents mismatch") } } for _, method := range expected.Filters.DeniedLoginMethods { if !utils.IsStringInSlice(method, actual.Filters.DeniedLoginMethods) { - return errors.New("Denied login methods contents mismatch") + return errors.New("denied login methods contents mismatch") } } for _, protocol := range expected.Filters.DeniedProtocols { if !utils.IsStringInSlice(protocol, actual.Filters.DeniedProtocols) { - return errors.New("Denied protocols contents mismatch") + return errors.New("denied protocols contents mismatch") } } return nil @@ -1177,19 +1170,19 @@ func compareUserFilterSubStructs(expected *dataprovider.User, actual *dataprovid func compareUserFilters(expected *dataprovider.User, actual *dataprovider.User) error { if len(expected.Filters.AllowedIP) != len(actual.Filters.AllowedIP) { - return errors.New("AllowedIP mismatch") + return errors.New("allowed IP mismatch") } if len(expected.Filters.DeniedIP) != len(actual.Filters.DeniedIP) { - return errors.New("DeniedIP mismatch") + return errors.New("denied IP mismatch") } if len(expected.Filters.DeniedLoginMethods) != len(actual.Filters.DeniedLoginMethods) { - return errors.New("Denied login methods mismatch") + return errors.New("denied login methods mismatch") } if len(expected.Filters.DeniedProtocols) != len(actual.Filters.DeniedProtocols) { - return errors.New("Denied protocols mismatch") + return errors.New("denied protocols mismatch") } if expected.Filters.MaxUploadFileSize != actual.Filters.MaxUploadFileSize { - return errors.New("Max upload file size mismatch") + return errors.New("max upload file size mismatch") } if expected.Filters.TLSUsername != actual.Filters.TLSUsername { return errors.New("TLSUsername mismatch") @@ -1261,10 +1254,10 @@ func compareUserFileExtensionsFilters(expected *dataprovider.User, actual *datap func compareEqualsUserFields(expected *dataprovider.User, actual *dataprovider.User) error { if expected.Username != actual.Username { - return errors.New("Username mismatch") + return errors.New("username mismatch") } if expected.HomeDir != actual.HomeDir { - return errors.New("HomeDir mismatch") + return errors.New("home dir mismatch") } if expected.UID != actual.UID { return errors.New("UID mismatch") @@ -1282,7 +1275,7 @@ func compareEqualsUserFields(expected *dataprovider.User, actual *dataprovider.U return errors.New("QuotaFiles mismatch") } if len(expected.Permissions) != len(actual.Permissions) { - return errors.New("Permissions mismatch") + return errors.New("permissions mismatch") } if expected.UploadBandwidth != actual.UploadBandwidth { return errors.New("UploadBandwidth mismatch") @@ -1291,7 +1284,7 @@ func compareEqualsUserFields(expected *dataprovider.User, actual *dataprovider.U return errors.New("DownloadBandwidth mismatch") } if expected.Status != actual.Status { - return errors.New("Status mismatch") + return errors.New("status mismatch") } if expected.ExpirationDate != actual.ExpirationDate { return errors.New("ExpirationDate mismatch") @@ -1300,7 +1293,7 @@ func compareEqualsUserFields(expected *dataprovider.User, actual *dataprovider.U return errors.New("AdditionalInfo mismatch") } if expected.Description != actual.Description { - return errors.New("Description mismatch") + return errors.New("description mismatch") } return nil } diff --git a/service/service_portable.go b/service/service_portable.go index 8ac41256..d79bcef9 100644 --- a/service/service_portable.go +++ b/service/service_portable.go @@ -21,6 +21,7 @@ import ( "github.com/drakkan/sftpgo/sftpd" "github.com/drakkan/sftpgo/utils" "github.com/drakkan/sftpgo/version" + "github.com/drakkan/sftpgo/vfs" "github.com/drakkan/sftpgo/webdavd" ) @@ -229,9 +230,9 @@ func (s *Service) advertiseServices(advertiseService, advertiseCredentials bool) func (s *Service) getPortableDirToServe() string { var dirToServe string - if s.PortableUser.FsConfig.Provider == dataprovider.S3FilesystemProvider { + if s.PortableUser.FsConfig.Provider == vfs.S3FilesystemProvider { dirToServe = s.PortableUser.FsConfig.S3Config.KeyPrefix - } else if s.PortableUser.FsConfig.Provider == dataprovider.GCSFilesystemProvider { + } else if s.PortableUser.FsConfig.Provider == vfs.GCSFilesystemProvider { dirToServe = s.PortableUser.FsConfig.GCSConfig.KeyPrefix } else { dirToServe = s.PortableUser.HomeDir @@ -263,31 +264,31 @@ func (s *Service) configurePortableUser() string { func (s *Service) configurePortableSecrets() { // we created the user before to initialize the KMS so we need to create the secret here switch s.PortableUser.FsConfig.Provider { - case dataprovider.S3FilesystemProvider: + case vfs.S3FilesystemProvider: payload := s.PortableUser.FsConfig.S3Config.AccessSecret.GetPayload() s.PortableUser.FsConfig.S3Config.AccessSecret = kms.NewEmptySecret() if payload != "" { s.PortableUser.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret(payload) } - case dataprovider.GCSFilesystemProvider: + case vfs.GCSFilesystemProvider: payload := s.PortableUser.FsConfig.GCSConfig.Credentials.GetPayload() s.PortableUser.FsConfig.GCSConfig.Credentials = kms.NewEmptySecret() if payload != "" { s.PortableUser.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(payload) } - case dataprovider.AzureBlobFilesystemProvider: + case vfs.AzureBlobFilesystemProvider: payload := s.PortableUser.FsConfig.AzBlobConfig.AccountKey.GetPayload() s.PortableUser.FsConfig.AzBlobConfig.AccountKey = kms.NewEmptySecret() if payload != "" { s.PortableUser.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret(payload) } - case dataprovider.CryptedFilesystemProvider: + case vfs.CryptedFilesystemProvider: payload := s.PortableUser.FsConfig.CryptConfig.Passphrase.GetPayload() s.PortableUser.FsConfig.CryptConfig.Passphrase = kms.NewEmptySecret() if payload != "" { s.PortableUser.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(payload) } - case dataprovider.SFTPFilesystemProvider: + case vfs.SFTPFilesystemProvider: payload := s.PortableUser.FsConfig.SFTPConfig.Password.GetPayload() s.PortableUser.FsConfig.SFTPConfig.Password = kms.NewEmptySecret() if payload != "" { diff --git a/sftpd/cryptfs_test.go b/sftpd/cryptfs_test.go index 0a5b21f4..82976f83 100644 --- a/sftpd/cryptfs_test.go +++ b/sftpd/cryptfs_test.go @@ -478,7 +478,7 @@ func getEncryptedFileSize(size int64) (int64, error) { func getTestUserWithCryptFs(usePubKey bool) dataprovider.User { u := getTestUser(usePubKey) - u.FsConfig.Provider = dataprovider.CryptedFilesystemProvider + u.FsConfig.Provider = vfs.CryptedFilesystemProvider u.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(testPassphrase) return u } diff --git a/sftpd/handler.go b/sftpd/handler.go index 2b620fdb..0cc6578d 100644 --- a/sftpd/handler.go +++ b/sftpd/handler.go @@ -54,19 +54,19 @@ func (c *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) { return nil, sftp.ErrSSHFxPermissionDenied } - p, err := c.Fs.ResolvePath(request.Filepath) + fs, p, err := c.GetFsAndResolvedPath(request.Filepath) if err != nil { - return nil, c.GetFsError(err) + return nil, err } - file, r, cancelFn, err := c.Fs.Open(p, 0) + file, r, cancelFn, err := fs.Open(p, 0) if err != nil { c.Log(logger.LevelWarn, "could not open file %#v for reading: %+v", p, err) - return nil, c.GetFsError(err) + return nil, c.GetFsError(fs, err) } baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, request.Filepath, common.TransferDownload, - 0, 0, 0, false, c.Fs) + 0, 0, 0, false, fs) t := newTransfer(baseTransfer, nil, r, nil) return t, nil @@ -90,18 +90,18 @@ func (c *Connection) handleFilewrite(request *sftp.Request) (sftp.WriterAtReader return nil, sftp.ErrSSHFxPermissionDenied } - p, err := c.Fs.ResolvePath(request.Filepath) + fs, p, err := c.GetFsAndResolvedPath(request.Filepath) if err != nil { - return nil, c.GetFsError(err) + return nil, err } filePath := p - if common.Config.IsAtomicUploadEnabled() && c.Fs.IsAtomicUploadSupported() { - filePath = c.Fs.GetAtomicUploadPath(p) + if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { + filePath = fs.GetAtomicUploadPath(p) } var errForRead error - if !vfs.IsLocalOrSFTPFs(c.Fs) && request.Pflags().Read { + if !vfs.IsLocalOrSFTPFs(fs) && request.Pflags().Read { // read and write mode is only supported for local filesystem errForRead = sftp.ErrSSHFxOpUnsupported } @@ -112,17 +112,17 @@ func (c *Connection) handleFilewrite(request *sftp.Request) (sftp.WriterAtReader errForRead = os.ErrPermission } - stat, statErr := c.Fs.Lstat(p) - if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || c.Fs.IsNotExist(statErr) { + stat, statErr := fs.Lstat(p) + if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) { if !c.User.HasPerm(dataprovider.PermUpload, path.Dir(request.Filepath)) { return nil, sftp.ErrSSHFxPermissionDenied } - return c.handleSFTPUploadToNewFile(p, filePath, request.Filepath, errForRead) + return c.handleSFTPUploadToNewFile(fs, p, filePath, request.Filepath, errForRead) } if statErr != nil { c.Log(logger.LevelError, "error performing file stat %#v: %+v", p, statErr) - return nil, c.GetFsError(statErr) + return nil, c.GetFsError(fs, statErr) } // This happen if we upload a file that has the same name of an existing directory @@ -135,7 +135,7 @@ func (c *Connection) handleFilewrite(request *sftp.Request) (sftp.WriterAtReader return nil, sftp.ErrSSHFxPermissionDenied } - return c.handleSFTPUploadToExistingFile(request.Pflags(), p, filePath, stat.Size(), request.Filepath, errForRead) + return c.handleSFTPUploadToExistingFile(fs, request.Pflags(), p, filePath, stat.Size(), request.Filepath, errForRead) } // Filecmd hander for basic SFTP system calls related to files, but not anything to do with reading @@ -143,37 +143,29 @@ func (c *Connection) handleFilewrite(request *sftp.Request) (sftp.WriterAtReader func (c *Connection) Filecmd(request *sftp.Request) error { c.UpdateLastActivity() - p, err := c.Fs.ResolvePath(request.Filepath) - if err != nil { - return c.GetFsError(err) - } - target, err := c.getSFTPCmdTargetPath(request.Target) - if err != nil { - return c.GetFsError(err) - } - - c.Log(logger.LevelDebug, "new cmd, method: %v, sourcePath: %#v, targetPath: %#v", request.Method, p, target) + c.Log(logger.LevelDebug, "new cmd, method: %v, sourcePath: %#v, targetPath: %#v", request.Method, + request.Filepath, request.Target) switch request.Method { case "Setstat": - return c.handleSFTPSetstat(p, request) + return c.handleSFTPSetstat(request) case "Rename": - if err = c.Rename(p, target, request.Filepath, request.Target); err != nil { + if err := c.Rename(request.Filepath, request.Target); err != nil { return err } case "Rmdir": - return c.RemoveDir(p, request.Filepath) + return c.RemoveDir(request.Filepath) case "Mkdir": - err = c.CreateDir(p, request.Filepath) + err := c.CreateDir(request.Filepath) if err != nil { return err } case "Symlink": - if err = c.CreateSymlink(p, target, request.Filepath, request.Target); err != nil { + if err := c.CreateSymlink(request.Filepath, request.Target); err != nil { return err } case "Remove": - return c.handleSFTPRemove(p, request) + return c.handleSFTPRemove(request) default: return sftp.ErrSSHFxOpUnsupported } @@ -185,14 +177,10 @@ func (c *Connection) Filecmd(request *sftp.Request) error { // a directory as well as perform file/folder stat calls. func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) { c.UpdateLastActivity() - p, err := c.Fs.ResolvePath(request.Filepath) - if err != nil { - return nil, c.GetFsError(err) - } switch request.Method { case "List": - files, err := c.ListDir(p, request.Filepath) + files, err := c.ListDir(request.Filepath) if err != nil { return nil, err } @@ -202,10 +190,10 @@ func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) { return nil, sftp.ErrSSHFxPermissionDenied } - s, err := c.DoStat(p, 0) + s, err := c.DoStat(request.Filepath, 0) if err != nil { - c.Log(logger.LevelDebug, "error running stat on path %#v: %+v", p, err) - return nil, c.GetFsError(err) + c.Log(logger.LevelDebug, "error running stat on path %#v: %+v", request.Filepath, err) + return nil, err } return listerAt([]os.FileInfo{s}), nil @@ -214,10 +202,15 @@ func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) { return nil, sftp.ErrSSHFxPermissionDenied } - s, err := c.Fs.Readlink(p) + fs, p, err := c.GetFsAndResolvedPath(request.Filepath) + if err != nil { + return nil, err + } + + s, err := fs.Readlink(p) if err != nil { c.Log(logger.LevelDebug, "error running readlink on path %#v: %+v", p, err) - return nil, c.GetFsError(err) + return nil, c.GetFsError(fs, err) } if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(s)) { @@ -235,19 +228,14 @@ func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) { func (c *Connection) Lstat(request *sftp.Request) (sftp.ListerAt, error) { c.UpdateLastActivity() - p, err := c.Fs.ResolvePath(request.Filepath) - if err != nil { - return nil, c.GetFsError(err) - } - if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(request.Filepath)) { return nil, sftp.ErrSSHFxPermissionDenied } - s, err := c.DoStat(p, 1) + s, err := c.DoStat(request.Filepath, 1) if err != nil { - c.Log(logger.LevelDebug, "error running lstat on path %#v: %+v", p, err) - return nil, c.GetFsError(err) + c.Log(logger.LevelDebug, "error running lstat on path %#v: %+v", request.Filepath, err) + return nil, err } return listerAt([]os.FileInfo{s}), nil @@ -263,43 +251,29 @@ func (c *Connection) StatVFS(r *sftp.Request) (*sftp.StatVFS, error) { // not the limit for a single file upload quotaResult := c.HasSpace(true, true, path.Join(r.Filepath, "fakefile.txt")) - p, err := c.Fs.ResolvePath(r.Filepath) + fs, p, err := c.GetFsAndResolvedPath(r.Filepath) if err != nil { - return nil, c.GetFsError(err) + return nil, err } if !quotaResult.HasSpace { - return c.getStatVFSFromQuotaResult(p, quotaResult), nil + return c.getStatVFSFromQuotaResult(fs, p, quotaResult), nil } if quotaResult.QuotaSize == 0 && quotaResult.QuotaFiles == 0 { // no quota restrictions - statvfs, err := c.Fs.GetAvailableDiskSize(p) + statvfs, err := fs.GetAvailableDiskSize(p) if err == vfs.ErrStorageSizeUnavailable { - return c.getStatVFSFromQuotaResult(p, quotaResult), nil + return c.getStatVFSFromQuotaResult(fs, p, quotaResult), nil } return statvfs, err } // there is free space but some limits are configured - return c.getStatVFSFromQuotaResult(p, quotaResult), nil + return c.getStatVFSFromQuotaResult(fs, p, quotaResult), nil } -func (c *Connection) getSFTPCmdTargetPath(requestTarget string) (string, error) { - var target string - // If a target is provided in this request validate that it is going to the correct - // location for the server. If it is not, return an error - if requestTarget != "" { - var err error - target, err = c.Fs.ResolvePath(requestTarget) - if err != nil { - return target, err - } - } - return target, nil -} - -func (c *Connection) handleSFTPSetstat(filePath string, request *sftp.Request) error { +func (c *Connection) handleSFTPSetstat(request *sftp.Request) error { attrs := common.StatAttributes{ Flags: 0, } @@ -322,50 +296,54 @@ func (c *Connection) handleSFTPSetstat(filePath string, request *sftp.Request) e attrs.Size = int64(request.Attributes().Size) } - return c.SetStat(filePath, request.Filepath, &attrs) + return c.SetStat(request.Filepath, &attrs) } -func (c *Connection) handleSFTPRemove(filePath string, request *sftp.Request) error { +func (c *Connection) handleSFTPRemove(request *sftp.Request) error { + fs, fsPath, err := c.GetFsAndResolvedPath(request.Filepath) + if err != nil { + return err + } + var fi os.FileInfo - var err error - if fi, err = c.Fs.Lstat(filePath); err != nil { - c.Log(logger.LevelWarn, "failed to remove a file %#v: stat error: %+v", filePath, err) - return c.GetFsError(err) + if fi, err = fs.Lstat(fsPath); err != nil { + c.Log(logger.LevelDebug, "failed to remove a file %#v: stat error: %+v", fsPath, err) + return c.GetFsError(fs, err) } if fi.IsDir() && fi.Mode()&os.ModeSymlink == 0 { - c.Log(logger.LevelDebug, "cannot remove %#v is not a file/symlink", filePath) + c.Log(logger.LevelDebug, "cannot remove %#v is not a file/symlink", fsPath) return sftp.ErrSSHFxFailure } - return c.RemoveFile(filePath, request.Filepath, fi) + return c.RemoveFile(fs, fsPath, request.Filepath, fi) } -func (c *Connection) handleSFTPUploadToNewFile(resolvedPath, filePath, requestPath string, errForRead error) (sftp.WriterAtReaderAt, error) { +func (c *Connection) handleSFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, requestPath string, errForRead error) (sftp.WriterAtReaderAt, error) { quotaResult := c.HasSpace(true, false, requestPath) if !quotaResult.HasSpace { c.Log(logger.LevelInfo, "denying file write due to quota limits") return nil, sftp.ErrSSHFxFailure } - file, w, cancelFn, err := c.Fs.Create(filePath, 0) + file, w, cancelFn, err := fs.Create(filePath, 0) if err != nil { c.Log(logger.LevelWarn, "error creating file %#v: %+v", resolvedPath, err) - return nil, c.GetFsError(err) + return nil, c.GetFsError(fs, err) } - vfs.SetPathPermissions(c.Fs, filePath, c.User.GetUID(), c.User.GetGID()) + vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) // we can get an error only for resume - maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0) + maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath, - common.TransferUpload, 0, 0, maxWriteSize, true, c.Fs) + common.TransferUpload, 0, 0, maxWriteSize, true, fs) t := newTransfer(baseTransfer, w, nil, errForRead) return t, nil } -func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, resolvedPath, filePath string, +func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileOpenFlags, resolvedPath, filePath string, fileSize int64, requestPath string, errForRead error) (sftp.WriterAtReaderAt, error) { var err error quotaResult := c.HasSpace(false, false, requestPath) @@ -382,25 +360,25 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r // if there is a size limit the remaining size cannot be 0 here, since quotaResult.HasSpace // will return false in this case and we deny the upload before. // For Cloud FS GetMaxWriteSize will return unsupported operation - maxWriteSize, err := c.GetMaxWriteSize(quotaResult, isResume, fileSize) + maxWriteSize, err := c.GetMaxWriteSize(quotaResult, isResume, fileSize, fs.IsUploadResumeSupported()) if err != nil { c.Log(logger.LevelDebug, "unable to get max write size: %v", err) return nil, err } - if common.Config.IsAtomicUploadEnabled() && c.Fs.IsAtomicUploadSupported() { - err = c.Fs.Rename(resolvedPath, filePath) + if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { + err = fs.Rename(resolvedPath, filePath) if err != nil { c.Log(logger.LevelWarn, "error renaming existing file for atomic upload, source: %#v, dest: %#v, err: %+v", resolvedPath, filePath, err) - return nil, c.GetFsError(err) + return nil, c.GetFsError(fs, err) } } - file, w, cancelFn, err := c.Fs.Create(filePath, osFlags) + file, w, cancelFn, err := fs.Create(filePath, osFlags) if err != nil { c.Log(logger.LevelWarn, "error opening existing file, flags: %v, source: %#v, err: %+v", pflags, filePath, err) - return nil, c.GetFsError(err) + return nil, c.GetFsError(fs, err) } initialSize := int64(0) @@ -409,7 +387,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r minWriteOffset = fileSize initialSize = fileSize } else { - if vfs.IsLocalOrSFTPFs(c.Fs) && isTruncate { + if vfs.IsLocalOrSFTPFs(fs) && isTruncate { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) if err == nil { dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck @@ -424,10 +402,10 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r } } - vfs.SetPathPermissions(c.Fs, filePath, c.User.GetUID(), c.User.GetGID()) + vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath, - common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, c.Fs) + common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, fs) t := newTransfer(baseTransfer, w, nil, errForRead) return t, nil @@ -438,9 +416,9 @@ func (c *Connection) Disconnect() error { return c.channel.Close() } -func (c *Connection) getStatVFSFromQuotaResult(name string, quotaResult vfs.QuotaCheckResult) *sftp.StatVFS { +func (c *Connection) getStatVFSFromQuotaResult(fs vfs.Fs, name string, quotaResult vfs.QuotaCheckResult) *sftp.StatVFS { if quotaResult.QuotaSize == 0 || quotaResult.QuotaFiles == 0 { - s, err := c.Fs.GetAvailableDiskSize(name) + s, err := fs.GetAvailableDiskSize(name) if err == nil { if quotaResult.QuotaSize == 0 { quotaResult.QuotaSize = int64(s.TotalSpace()) diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index b25aa704..b6f10091 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -20,6 +20,7 @@ import ( "github.com/drakkan/sftpgo/common" "github.com/drakkan/sftpgo/dataprovider" + "github.com/drakkan/sftpgo/kms" "github.com/drakkan/sftpgo/utils" "github.com/drakkan/sftpgo/vfs" ) @@ -124,7 +125,7 @@ func (fs MockOsFs) Rename(source, target string) error { func newMockOsFs(err, statErr error, atomicUpload bool, connectionID, rootDir string) vfs.Fs { return &MockOsFs{ - Fs: vfs.NewOsFs(connectionID, rootDir, nil), + Fs: vfs.NewOsFs(connectionID, rootDir, ""), err: err, statErr: statErr, isAtomicUploadSupported: atomicUpload, @@ -156,15 +157,15 @@ func TestUploadResumeInvalidOffset(t *testing.T) { user := dataprovider.User{ Username: "testuser", } - fs := vfs.NewOsFs("", os.TempDir(), nil) - conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs) + fs := vfs.NewOsFs("", os.TempDir(), "") + conn := common.NewBaseConnection("", common.ProtocolSFTP, user) baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), testfile, common.TransferUpload, 10, 0, 0, false, fs) transfer := newTransfer(baseTransfer, nil, nil, nil) _, err = transfer.WriteAt([]byte("test"), 0) assert.Error(t, err, "upload with invalid offset must fail") if assert.Error(t, transfer.ErrTransfer) { assert.EqualError(t, err, transfer.ErrTransfer.Error()) - assert.Contains(t, transfer.ErrTransfer.Error(), "Invalid write offset") + assert.Contains(t, transfer.ErrTransfer.Error(), "invalid write offset") } err = transfer.Close() @@ -184,8 +185,8 @@ func TestReadWriteErrors(t *testing.T) { user := dataprovider.User{ Username: "testuser", } - fs := vfs.NewOsFs("", os.TempDir(), nil) - conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs) + fs := vfs.NewOsFs("", os.TempDir(), "") + conn := common.NewBaseConnection("", common.ProtocolSFTP, user) baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs) transfer := newTransfer(baseTransfer, nil, nil, nil) err = file.Close() @@ -233,8 +234,7 @@ func TestReadWriteErrors(t *testing.T) { } func TestUnsupportedListOP(t *testing.T) { - fs := vfs.NewOsFs("", os.TempDir(), nil) - conn := common.NewBaseConnection("", common.ProtocolSFTP, dataprovider.User{}, fs) + conn := common.NewBaseConnection("", common.ProtocolSFTP, dataprovider.User{}) sftpConn := Connection{ BaseConnection: conn, } @@ -254,8 +254,8 @@ func TestTransferCancelFn(t *testing.T) { user := dataprovider.User{ Username: "testuser", } - fs := vfs.NewOsFs("", os.TempDir(), nil) - conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs) + fs := vfs.NewOsFs("", os.TempDir(), "") + conn := common.NewBaseConnection("", common.ProtocolSFTP, user) baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs) transfer := newTransfer(baseTransfer, nil, nil, nil) @@ -274,77 +274,37 @@ func TestTransferCancelFn(t *testing.T) { assert.NoError(t, err) } -func TestMockFsErrors(t *testing.T) { - errFake := errors.New("fake error") - fs := newMockOsFs(errFake, errFake, false, "123", os.TempDir()) - u := dataprovider.User{} - u.Username = "test_username" - u.Permissions = make(map[string][]string) - u.Permissions["/"] = []string{dataprovider.PermAny} - u.HomeDir = os.TempDir() - c := Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, u, fs), - } - testfile := filepath.Join(u.HomeDir, "testfile") - request := sftp.NewRequest("Remove", testfile) - err := os.WriteFile(testfile, []byte("test"), os.ModePerm) - assert.NoError(t, err) - _, err = c.Filewrite(request) - assert.EqualError(t, err, sftp.ErrSSHFxFailure.Error()) - - var flags sftp.FileOpenFlags - flags.Write = true - flags.Trunc = false - flags.Append = true - _, err = c.handleSFTPUploadToExistingFile(flags, testfile, testfile, 0, "/testfile", nil) - assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error()) - - fs = newMockOsFs(errFake, nil, false, "123", os.TempDir()) - c.BaseConnection.Fs = fs - err = c.handleSFTPRemove(testfile, request) - assert.EqualError(t, err, sftp.ErrSSHFxFailure.Error()) - - request = sftp.NewRequest("Rename", filepath.Base(testfile)) - request.Target = filepath.Base(testfile) + "1" - err = c.Rename(testfile, testfile+"1", request.Filepath, request.Target) - assert.EqualError(t, err, sftp.ErrSSHFxFailure.Error()) - - err = os.Remove(testfile) - assert.NoError(t, err) -} - func TestUploadFiles(t *testing.T) { - oldUploadMode := common.Config.UploadMode common.Config.UploadMode = common.UploadModeAtomic - fs := vfs.NewOsFs("123", os.TempDir(), nil) + fs := vfs.NewOsFs("123", os.TempDir(), "") u := dataprovider.User{} c := Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, u, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, u), } var flags sftp.FileOpenFlags flags.Write = true flags.Trunc = true - _, err := c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0, "/missing_path", nil) + _, err := c.handleSFTPUploadToExistingFile(fs, flags, "missing_path", "other_missing_path", 0, "/missing_path", nil) assert.Error(t, err, "upload to existing file must fail if one or both paths are invalid") common.Config.UploadMode = common.UploadModeStandard - _, err = c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0, "/missing_path", nil) + _, err = c.handleSFTPUploadToExistingFile(fs, flags, "missing_path", "other_missing_path", 0, "/missing_path", nil) assert.Error(t, err, "upload to existing file must fail if one or both paths are invalid") missingFile := "missing/relative/file.txt" if runtime.GOOS == osWindows { missingFile = "missing\\relative\\file.txt" } - _, err = c.handleSFTPUploadToNewFile(".", missingFile, "/missing", nil) + _, err = c.handleSFTPUploadToNewFile(fs, ".", missingFile, "/missing", nil) assert.Error(t, err, "upload new file in missing path must fail") - c.BaseConnection.Fs = newMockOsFs(nil, nil, false, "123", os.TempDir()) + fs = newMockOsFs(nil, nil, false, "123", os.TempDir()) f, err := os.CreateTemp("", "temp") assert.NoError(t, err) err = f.Close() assert.NoError(t, err) - tr, err := c.handleSFTPUploadToExistingFile(flags, f.Name(), f.Name(), 123, f.Name(), nil) + tr, err := c.handleSFTPUploadToExistingFile(fs, flags, f.Name(), f.Name(), 123, f.Name(), nil) if assert.NoError(t, err) { transfer := tr.(*transfer) transfers := c.GetTransfers() @@ -358,7 +318,7 @@ func TestUploadFiles(t *testing.T) { } err = os.Remove(f.Name()) assert.NoError(t, err) - common.Config.UploadMode = oldUploadMode + common.Config.UploadMode = common.UploadModeAtomicWithResume } func TestWithInvalidHome(t *testing.T) { @@ -371,9 +331,9 @@ func TestWithInvalidHome(t *testing.T) { fs, err := u.GetFilesystem("123") assert.NoError(t, err) c := Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, u, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, u), } - _, err = c.Fs.ResolvePath("../upper_path") + _, err = fs.ResolvePath("../upper_path") assert.Error(t, err, "tested path is not a home subdir") _, err = c.StatVFS(&sftp.Request{ Method: "StatVFS", @@ -382,25 +342,6 @@ func TestWithInvalidHome(t *testing.T) { assert.Error(t, err) } -func TestSFTPCmdTargetPath(t *testing.T) { - u := dataprovider.User{} - if runtime.GOOS == osWindows { - u.HomeDir = "C:\\invalid_home" - } else { - u.HomeDir = "/invalid_home" - } - u.Username = "testuser" - u.Permissions = make(map[string][]string) - u.Permissions["/"] = []string{dataprovider.PermAny} - fs, err := u.GetFilesystem("123") - assert.NoError(t, err) - connection := Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, u, fs), - } - _, err = connection.getSFTPCmdTargetPath("invalid_path") - assert.True(t, os.IsNotExist(err)) -} - func TestSFTPGetUsedQuota(t *testing.T) { u := dataprovider.User{} u.HomeDir = "home_rel_path" @@ -410,7 +351,7 @@ func TestSFTPGetUsedQuota(t *testing.T) { u.Permissions = make(map[string][]string) u.Permissions["/"] = []string{dataprovider.PermAny} connection := Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, u, nil), + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, u), } quotaResult := connection.HasSpace(false, false, "/") assert.False(t, quotaResult.HasSpace) @@ -506,10 +447,8 @@ func TestSSHCommandErrors(t *testing.T) { user := dataprovider.User{} user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} - fs, err := user.GetFilesystem("123") - assert.NoError(t, err) connection := Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSSH, user), channel: &mockSSHChannel, } cmd := sshCommand{ @@ -517,7 +456,7 @@ func TestSSHCommandErrors(t *testing.T) { connection: &connection, args: []string{}, } - err = cmd.handle() + err := cmd.handle() assert.Error(t, err, "ssh command must fail, we are sending a fake error") cmd = sshCommand{ @@ -539,9 +478,8 @@ func TestSSHCommandErrors(t *testing.T) { cmd.connection.User.HomeDir = filepath.Clean(os.TempDir()) cmd.connection.User.QuotaFiles = 1 cmd.connection.User.UsedQuotaFiles = 2 - fs, err = cmd.connection.User.GetFilesystem("123") + fs, err := cmd.connection.User.GetFilesystem("123") assert.NoError(t, err) - cmd.connection.Fs = fs err = cmd.handle() assert.EqualError(t, err, common.ErrQuotaExceeded.Error()) @@ -580,10 +518,6 @@ func TestSSHCommandErrors(t *testing.T) { err = cmd.executeSystemCommand(command) assert.Error(t, err, "command must fail, pipe was already assigned") - fs, err = user.GetFilesystem("123") - assert.NoError(t, err) - connection.Fs = fs - cmd = sshCommand{ command: "sftpgo-remove", connection: &connection, @@ -601,22 +535,12 @@ func TestSSHCommandErrors(t *testing.T) { assert.Error(t, err, "ssh command must fail, we are requesting an invalid path") cmd.connection.User.HomeDir = filepath.Clean(os.TempDir()) - fs, err = cmd.connection.User.GetFilesystem("123") - assert.NoError(t, err) - cmd.connection.Fs = fs - _, _, err = cmd.resolveCopyPaths(".", "../adir") - assert.Error(t, err) cmd = sshCommand{ command: "sftpgo-copy", connection: &connection, args: []string{"src", "dst"}, } - cmd.connection.User.Permissions = make(map[string][]string) - cmd.connection.User.Permissions["/"] = []string{dataprovider.PermDownload} - src, dst, err := cmd.getCopyPaths() - assert.NoError(t, err) - assert.False(t, cmd.hasCopyPermissions(src, dst, nil)) cmd.connection.User.Permissions = make(map[string][]string) cmd.connection.User.Permissions["/"] = []string{dataprovider.PermAny} @@ -629,7 +553,7 @@ func TestSSHCommandErrors(t *testing.T) { assert.NoError(t, err) err = os.Chmod(aDir, 0001) assert.NoError(t, err) - err = cmd.checkCopyDestination(tmpFile) + err = cmd.checkCopyDestination(fs, tmpFile) assert.Error(t, err) err = os.Chmod(aDir, os.ModePerm) assert.NoError(t, err) @@ -661,10 +585,8 @@ func TestCommandsWithExtensionsFilter(t *testing.T) { }, } - fs, err := user.GetFilesystem("123") - assert.NoError(t, err) connection := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSSH, user), channel: &mockSSHChannel, } cmd := sshCommand{ @@ -672,7 +594,7 @@ func TestCommandsWithExtensionsFilter(t *testing.T) { connection: connection, args: []string{"subdir/test.png"}, } - err = cmd.handleHashCommands() + err := cmd.handleHashCommands() assert.EqualError(t, err, common.ErrPermissionDenied.Error()) cmd = sshCommand{ @@ -715,28 +637,17 @@ func TestSSHCommandsRemoteFs(t *testing.T) { Buffer: bytes.NewBuffer(buf), StdErrBuffer: bytes.NewBuffer(stdErrBuf), } - server, client := net.Pipe() - defer func() { - err := server.Close() - assert.NoError(t, err) - }() - defer func() { - err := client.Close() - assert.NoError(t, err) - }() user := dataprovider.User{} - user.FsConfig = dataprovider.Filesystem{ - Provider: dataprovider.S3FilesystemProvider, + user.FsConfig = vfs.Filesystem{ + Provider: vfs.S3FilesystemProvider, S3Config: vfs.S3FsConfig{ Bucket: "s3bucket", Endpoint: "endpoint", Region: "eu-west-1", }, } - fs, err := user.GetFilesystem("123") - assert.NoError(t, err) connection := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user), channel: &mockSSHChannel, } cmd := sshCommand{ @@ -755,15 +666,73 @@ func TestSSHCommandsRemoteFs(t *testing.T) { connection: connection, args: []string{}, } - err = cmd.handeSFTPGoCopy() + err = cmd.handleSFTPGoCopy() assert.Error(t, err) cmd = sshCommand{ command: "sftpgo-remove", connection: connection, args: []string{}, } - err = cmd.handeSFTPGoRemove() + err = cmd.handleSFTPGoRemove() assert.Error(t, err) + // the user has no permissions + assert.False(t, cmd.hasCopyPermissions("", "", nil)) +} + +func TestSSHCmdGetFsErrors(t *testing.T) { + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + } + user := dataprovider.User{ + HomeDir: "relative path", + } + user.Permissions = map[string][]string{} + user.Permissions["/"] = []string{dataprovider.PermAny} + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user), + channel: &mockSSHChannel, + } + cmd := sshCommand{ + command: "sftpgo-remove", + connection: connection, + args: []string{"path"}, + } + err := cmd.handleSFTPGoRemove() + assert.Error(t, err) + + cmd = sshCommand{ + command: "sftpgo-copy", + connection: connection, + args: []string{"path1", "path2"}, + } + _, _, _, _, _, _, err = cmd.getFsAndCopyPaths() //nolint:dogsled + assert.Error(t, err) + user = dataprovider.User{} + user.HomeDir = filepath.Join(os.TempDir(), "home") + user.VirtualFolders = append(connection.User.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: "relative", + }, + VirtualPath: "/vpath", + }) + connection.User = user + + err = os.MkdirAll(user.GetHomeDir(), os.ModePerm) + assert.NoError(t, err) + + cmd = sshCommand{ + command: "sftpgo-copy", + connection: connection, + args: []string{"path1", "/vpath/path2"}, + } + _, _, _, _, _, _, err = cmd.getFsAndCopyPaths() //nolint:dogsled + assert.Error(t, err) + + err = os.Remove(user.GetHomeDir()) + assert.NoError(t, err) } func TestGitVirtualFolders(t *testing.T) { @@ -773,10 +742,8 @@ func TestGitVirtualFolders(t *testing.T) { Permissions: permissions, HomeDir: os.TempDir(), } - fs, err := user.GetFilesystem("123") - assert.NoError(t, err) conn := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user), } cmd := sshCommand{ command: "git-receive-pack", @@ -789,7 +756,7 @@ func TestGitVirtualFolders(t *testing.T) { }, VirtualPath: "/vdir", }) - _, err = cmd.getSystemCommand() + _, err := cmd.getSystemCommand() assert.NoError(t, err) cmd.args = []string{"/"} _, err = cmd.getSystemCommand() @@ -821,10 +788,8 @@ func TestRsyncOptions(t *testing.T) { Permissions: permissions, HomeDir: os.TempDir(), } - fs, err := user.GetFilesystem("123") - assert.NoError(t, err) conn := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user), } sshCmd := sshCommand{ command: "rsync", @@ -839,11 +804,9 @@ func TestRsyncOptions(t *testing.T) { permissions["/"] = []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermCreateDirs, dataprovider.PermListItems, dataprovider.PermOverwrite, dataprovider.PermDelete, dataprovider.PermRename} user.Permissions = permissions - fs, err = user.GetFilesystem("123") - assert.NoError(t, err) conn = &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user), } sshCmd = sshCommand{ command: "rsync", @@ -875,14 +838,14 @@ func TestSystemCommandSizeForPath(t *testing.T) { fs, err := user.GetFilesystem("123") assert.NoError(t, err) conn := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user), } sshCmd := sshCommand{ command: "rsync", connection: conn, args: []string{"--server", "-vlogDtprze.iLsfxC", ".", "/"}, } - _, _, err = sshCmd.getSizeForPath("missing path") + _, _, err = sshCmd.getSizeForPath(fs, "missing path") assert.NoError(t, err) testDir := filepath.Join(os.TempDir(), "dir") err = os.MkdirAll(testDir, os.ModePerm) @@ -892,18 +855,18 @@ func TestSystemCommandSizeForPath(t *testing.T) { assert.NoError(t, err) err = os.Symlink(testFile, testFile+".link") assert.NoError(t, err) - numFiles, size, err := sshCmd.getSizeForPath(testFile + ".link") + numFiles, size, err := sshCmd.getSizeForPath(fs, testFile+".link") assert.NoError(t, err) assert.Equal(t, 0, numFiles) assert.Equal(t, int64(0), size) - numFiles, size, err = sshCmd.getSizeForPath(testFile) + numFiles, size, err = sshCmd.getSizeForPath(fs, testFile) assert.NoError(t, err) assert.Equal(t, 1, numFiles) assert.Equal(t, int64(12), size) if runtime.GOOS != osWindows { err = os.Chmod(testDir, 0001) assert.NoError(t, err) - _, _, err = sshCmd.getSizeForPath(testFile) + _, _, err = sshCmd.getSizeForPath(fs, testFile) assert.Error(t, err) err = os.Chmod(testDir, os.ModePerm) assert.NoError(t, err) @@ -923,15 +886,6 @@ func TestSystemCommandErrors(t *testing.T) { ReadError: nil, WriteError: writeErr, } - server, client := net.Pipe() - defer func() { - err := server.Close() - assert.NoError(t, err) - }() - defer func() { - err := client.Close() - assert.NoError(t, err) - }() permissions := make(map[string][]string) permissions["/"] = []string{dataprovider.PermAny} homeDir := filepath.Join(os.TempDir(), "adir") @@ -946,7 +900,7 @@ func TestSystemCommandErrors(t *testing.T) { fs, err := user.GetFilesystem("123") assert.NoError(t, err) connection := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user), channel: &mockSSHChannel, } var sshCmd sshCommand @@ -1015,6 +969,48 @@ func TestSystemCommandErrors(t *testing.T) { assert.NoError(t, err) } +func TestCommandGetFsError(t *testing.T) { + user := dataprovider.User{ + FsConfig: vfs.Filesystem{ + Provider: vfs.CryptedFilesystemProvider, + }, + } + conn := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user), + } + sshCmd := sshCommand{ + command: "rsync", + connection: conn, + args: []string{"--server", "-vlogDtprze.iLsfxC", ".", "/"}, + } + _, err := sshCmd.getSystemCommand() + assert.Error(t, err) + + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + } + conn = &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, user), + channel: &mockSSHChannel, + } + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: conn, + args: []string{"-t", "/tmp"}, + }, + } + + err = scpCommand.handleRecursiveUpload() + assert.Error(t, err) + err = scpCommand.handleDownload("") + assert.Error(t, err) +} + func TestGetConnectionInfo(t *testing.T) { c := common.ConnectionStatus{ Username: "test_user", @@ -1088,10 +1084,9 @@ func TestSCPUploadError(t *testing.T) { Permissions: make(map[string][]string), } user.Permissions["/"] = []string{dataprovider.PermAny} - fs := vfs.NewOsFs("", user.HomeDir, nil) connection := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user), channel: &mockSSHChannel, } scpCommand := scpCommand{ @@ -1129,10 +1124,11 @@ func TestSCPInvalidEndDir(t *testing.T) { Buffer: bytes.NewBuffer([]byte("E\n")), StdErrBuffer: bytes.NewBuffer(stdErrBuf), } - fs := vfs.NewOsFs("", os.TempDir(), nil) connection := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, dataprovider.User{}, fs), - channel: &mockSSHChannel, + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, dataprovider.User{ + HomeDir: os.TempDir(), + }), + channel: &mockSSHChannel, } scpCommand := scpCommand{ sshCommand: sshCommand{ @@ -1153,10 +1149,12 @@ func TestSCPParseUploadMessage(t *testing.T) { StdErrBuffer: bytes.NewBuffer(stdErrBuf), ReadError: nil, } - fs := vfs.NewOsFs("", os.TempDir(), nil) + fs := vfs.NewOsFs("", os.TempDir(), "") connection := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, dataprovider.User{}, fs), - channel: &mockSSHChannel, + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, dataprovider.User{ + HomeDir: os.TempDir(), + }), + channel: &mockSSHChannel, } scpCommand := scpCommand{ sshCommand: sshCommand{ @@ -1165,16 +1163,16 @@ func TestSCPParseUploadMessage(t *testing.T) { args: []string{"-t", "/tmp"}, }, } - _, _, err := scpCommand.parseUploadMessage("invalid") + _, _, err := scpCommand.parseUploadMessage(fs, "invalid") assert.Error(t, err, "parsing invalid upload message must fail") - _, _, err = scpCommand.parseUploadMessage("D0755 0") + _, _, err = scpCommand.parseUploadMessage(fs, "D0755 0") assert.Error(t, err, "parsing incomplete upload message must fail") - _, _, err = scpCommand.parseUploadMessage("D0755 invalidsize testdir") + _, _, err = scpCommand.parseUploadMessage(fs, "D0755 invalidsize testdir") assert.Error(t, err, "parsing upload message with invalid size must fail") - _, _, err = scpCommand.parseUploadMessage("D0755 0 ") + _, _, err = scpCommand.parseUploadMessage(fs, "D0755 0 ") assert.Error(t, err, "parsing upload message with invalid name must fail") } @@ -1190,7 +1188,7 @@ func TestSCPProtocolMessages(t *testing.T) { WriteError: writeErr, } connection := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, nil), + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}), channel: &mockSSHChannel, } scpCommand := scpCommand{ @@ -1251,7 +1249,7 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) { WriteError: writeErr, } connection := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, nil), + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}), channel: &mockSSHChannel, } scpCommand := scpCommand{ @@ -1325,7 +1323,7 @@ func TestSCPCommandHandleErrors(t *testing.T) { assert.NoError(t, err) }() connection := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, nil), + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}), channel: &mockSSHChannel, } scpCommand := scpCommand{ @@ -1344,7 +1342,6 @@ func TestSCPCommandHandleErrors(t *testing.T) { func TestSCPErrorsMockFs(t *testing.T) { errFake := errors.New("fake error") - fs := newMockOsFs(errFake, errFake, false, "1234", os.TempDir()) u := dataprovider.User{} u.Username = "test" u.Permissions = make(map[string][]string) @@ -1367,7 +1364,7 @@ func TestSCPErrorsMockFs(t *testing.T) { }() connection := &Connection{ channel: &mockSSHChannel, - BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, u, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, u), } scpCommand := scpCommand{ sshCommand: sshCommand{ @@ -1376,22 +1373,12 @@ func TestSCPErrorsMockFs(t *testing.T) { args: []string{"-r", "-t", "/tmp"}, }, } - err := scpCommand.handleUpload("test", 0) - assert.EqualError(t, err, errFake.Error()) - testfile := filepath.Join(u.HomeDir, "testfile") - err = os.WriteFile(testfile, []byte("test"), os.ModePerm) + err := os.WriteFile(testfile, []byte("test"), os.ModePerm) assert.NoError(t, err) - stat, err := os.Stat(u.HomeDir) - assert.NoError(t, err) - err = scpCommand.handleRecursiveDownload(u.HomeDir, stat) - assert.EqualError(t, err, errFake.Error()) - scpCommand.sshCommand.connection.Fs = newMockOsFs(errFake, nil, true, "123", os.TempDir()) - err = scpCommand.handleUpload(filepath.Base(testfile), 0) - assert.EqualError(t, err, errFake.Error()) - - err = scpCommand.handleUploadFile(testfile, testfile, 0, false, 4, "/testfile") + fs := newMockOsFs(errFake, nil, true, "123", os.TempDir()) + err = scpCommand.handleUploadFile(fs, testfile, testfile, 0, false, 4, "/testfile") assert.NoError(t, err) err = os.Remove(testfile) assert.NoError(t, err) @@ -1417,10 +1404,12 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) { err := client.Close() assert.NoError(t, err) }() - fs := vfs.NewOsFs("123", os.TempDir(), nil) + fs := vfs.NewOsFs("123", os.TempDir(), "") connection := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, fs), - channel: &mockSSHChannel, + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{ + HomeDir: os.TempDir(), + }), + channel: &mockSSHChannel, } scpCommand := scpCommand{ sshCommand: sshCommand{ @@ -1434,7 +1423,7 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) { assert.NoError(t, err) stat, err := os.Stat(path) assert.NoError(t, err) - err = scpCommand.handleRecursiveDownload("invalid_dir", stat) + err = scpCommand.handleRecursiveDownload(fs, "invalid_dir", "invalid_dir", stat) assert.EqualError(t, err, writeErr.Error()) mockSSHChannel = MockChannel{ @@ -1444,7 +1433,7 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) { WriteError: nil, } scpCommand.connection.channel = &mockSSHChannel - err = scpCommand.handleRecursiveDownload("invalid_dir", stat) + err = scpCommand.handleRecursiveDownload(fs, "invalid_dir", "invalid_dir", stat) assert.Error(t, err, "recursive upload download must fail for a non existing dir") err = os.Remove(path) @@ -1463,7 +1452,7 @@ func TestSCPRecursiveUploadErrors(t *testing.T) { WriteError: writeErr, } connection := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, nil), + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}), channel: &mockSSHChannel, } scpCommand := scpCommand{ @@ -1504,7 +1493,7 @@ func TestSCPCreateDirs(t *testing.T) { fs, err := u.GetFilesystem("123") assert.NoError(t, err) connection := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, u, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, u), channel: &mockSSHChannel, } scpCommand := scpCommand{ @@ -1514,7 +1503,7 @@ func TestSCPCreateDirs(t *testing.T) { args: []string{"-r", "-t", "/tmp"}, }, } - err = scpCommand.handleCreateDir("invalid_dir") + err = scpCommand.handleCreateDir(fs, "invalid_dir") assert.Error(t, err, "create invalid dir must fail") } @@ -1536,10 +1525,10 @@ func TestSCPDownloadFileData(t *testing.T) { ReadError: nil, WriteError: writeErr, } + fs := vfs.NewOsFs("", os.TempDir(), "") connection := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, - vfs.NewOsFs("", os.TempDir(), nil)), - channel: &mockSSHChannelReadErr, + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{HomeDir: os.TempDir()}), + channel: &mockSSHChannelReadErr, } scpCommand := scpCommand{ sshCommand: sshCommand{ @@ -1552,19 +1541,19 @@ func TestSCPDownloadFileData(t *testing.T) { assert.NoError(t, err) stat, err := os.Stat(testfile) assert.NoError(t, err) - err = scpCommand.sendDownloadFileData(testfile, stat, nil) + err = scpCommand.sendDownloadFileData(fs, testfile, stat, nil) assert.EqualError(t, err, readErr.Error()) scpCommand.connection.channel = &mockSSHChannelWriteErr - err = scpCommand.sendDownloadFileData(testfile, stat, nil) + err = scpCommand.sendDownloadFileData(fs, testfile, stat, nil) assert.EqualError(t, err, writeErr.Error()) scpCommand.args = []string{"-r", "-p", "-f", "/tmp"} - err = scpCommand.sendDownloadFileData(testfile, stat, nil) + err = scpCommand.sendDownloadFileData(fs, testfile, stat, nil) assert.EqualError(t, err, writeErr.Error()) scpCommand.connection.channel = &mockSSHChannelReadErr - err = scpCommand.sendDownloadFileData(testfile, stat, nil) + err = scpCommand.sendDownloadFileData(fs, testfile, stat, nil) assert.EqualError(t, err, readErr.Error()) err = os.Remove(testfile) @@ -1586,9 +1575,9 @@ func TestSCPUploadFiledata(t *testing.T) { user := dataprovider.User{ Username: "testuser", } - fs := vfs.NewOsFs("", os.TempDir(), nil) + fs := vfs.NewOsFs("", os.TempDir(), "") connection := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, user, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, user), channel: &mockSSHChannel, } scpCommand := scpCommand{ @@ -1670,15 +1659,14 @@ func TestSCPUploadFiledata(t *testing.T) { } func TestUploadError(t *testing.T) { - oldUploadMode := common.Config.UploadMode common.Config.UploadMode = common.UploadModeAtomic user := dataprovider.User{ Username: "testuser", } - fs := vfs.NewOsFs("", os.TempDir(), nil) + fs := vfs.NewOsFs("", os.TempDir(), "") connection := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, user, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, user), } testfile := "testfile" @@ -1703,19 +1691,26 @@ func TestUploadError(t *testing.T) { assert.NoFileExists(t, testfile) assert.NoFileExists(t, fileTempName) - common.Config.UploadMode = oldUploadMode + common.Config.UploadMode = common.UploadModeAtomicWithResume } func TestTransferFailingReader(t *testing.T) { user := dataprovider.User{ Username: "testuser", + HomeDir: os.TempDir(), + FsConfig: vfs.Filesystem{ + Provider: vfs.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret("crypt secret"), + }, + }, } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} fs := newMockOsFs(nil, nil, true, "", os.TempDir()) connection := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user), } request := sftp.NewRequest("Open", "afile.txt") @@ -1874,7 +1869,7 @@ func TestRecursiveCopyErrors(t *testing.T) { fs, err := user.GetFilesystem("123") assert.NoError(t, err) conn := &Connection{ - BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user), } sshCmd := sshCommand{ command: "sftpgo-copy", @@ -1882,7 +1877,7 @@ func TestRecursiveCopyErrors(t *testing.T) { args: []string{"adir", "another"}, } // try to copy a missing directory - err = sshCmd.checkRecursiveCopyPermissions("adir", "another", "/another") + err = sshCmd.checkRecursiveCopyPermissions(fs, fs, "adir", "another", "/another") assert.Error(t, err) } @@ -1893,10 +1888,10 @@ func TestSFTPSubSystem(t *testing.T) { Permissions: permissions, HomeDir: os.TempDir(), } - user.FsConfig.Provider = dataprovider.AzureBlobFilesystemProvider + user.FsConfig.Provider = vfs.AzureBlobFilesystemProvider err := ServeSubSystemConnection(user, "connID", nil, nil) assert.Error(t, err) - user.FsConfig.Provider = dataprovider.LocalFilesystemProvider + user.FsConfig.Provider = vfs.LocalFilesystemProvider buf := make([]byte, 0, 4096) stdErrBuf := make([]byte, 0, 4096) @@ -1923,7 +1918,7 @@ func TestRecoverer(t *testing.T) { c.AcceptInboundConnection(nil, nil) connID := "connectionID" connection := &Connection{ - BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, dataprovider.User{}, nil), + BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, dataprovider.User{}), } c.handleSftpConnection(nil, connection) sshCmd := sshCommand{ diff --git a/sftpd/scp.go b/sftpd/scp.go index d0c2b692..5909c13e 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -76,36 +76,51 @@ func (c *scpCommand) handleRecursiveUpload() error { numDirs := 0 destPath := c.getDestPath() for { + fs, err := c.connection.User.GetFilesystemForPath(destPath, c.connection.ID) + if err != nil { + c.connection.Log(logger.LevelWarn, "error uploading file %#v: %+v", destPath, err) + c.sendErrorMessage(nil, fmt.Errorf("unable to get fs for path %#v", destPath)) + return err + } command, err := c.getNextUploadProtocolMessage() if err != nil { if errors.Is(err, io.EOF) { return nil } + c.sendErrorMessage(fs, err) return err } if strings.HasPrefix(command, "E") { numDirs-- c.connection.Log(logger.LevelDebug, "received end dir command, num dirs: %v", numDirs) if numDirs < 0 { - return errors.New("unacceptable end dir command") + err = errors.New("unacceptable end dir command") + c.sendErrorMessage(nil, err) + return err } // the destination dir is now the parent directory destPath = path.Join(destPath, "..") } else { - sizeToRead, name, err := c.parseUploadMessage(command) + sizeToRead, name, err := c.parseUploadMessage(fs, command) if err != nil { return err } if strings.HasPrefix(command, "D") { numDirs++ destPath = path.Join(destPath, name) - err = c.handleCreateDir(destPath) + fs, err = c.connection.User.GetFilesystemForPath(destPath, c.connection.ID) + if err != nil { + c.connection.Log(logger.LevelWarn, "error uploading file %#v: %+v", destPath, err) + c.sendErrorMessage(nil, fmt.Errorf("unable to get fs for path %#v", destPath)) + return err + } + err = c.handleCreateDir(fs, destPath) if err != nil { return err } c.connection.Log(logger.LevelDebug, "received start dir command, num dirs: %v destPath: %#v", numDirs, destPath) } else if strings.HasPrefix(command, "C") { - err = c.handleUpload(c.getFileUploadDestPath(destPath, name), sizeToRead) + err = c.handleUpload(c.getFileUploadDestPath(fs, destPath, name), sizeToRead) if err != nil { return err } @@ -118,21 +133,27 @@ func (c *scpCommand) handleRecursiveUpload() error { } } -func (c *scpCommand) handleCreateDir(dirPath string) error { +func (c *scpCommand) handleCreateDir(fs vfs.Fs, dirPath string) error { c.connection.UpdateLastActivity() - p, err := c.connection.Fs.ResolvePath(dirPath) + + p, err := fs.ResolvePath(dirPath) if err != nil { c.connection.Log(logger.LevelWarn, "error creating dir: %#v, invalid file path, err: %v", dirPath, err) - c.sendErrorMessage(err) + c.sendErrorMessage(fs, err) return err } if !c.connection.User.HasPerm(dataprovider.PermCreateDirs, path.Dir(dirPath)) { c.connection.Log(logger.LevelWarn, "error creating dir: %#v, permission denied", dirPath) - c.sendErrorMessage(common.ErrPermissionDenied) + c.sendErrorMessage(fs, common.ErrPermissionDenied) return common.ErrPermissionDenied } - err = c.createDir(p) + info, err := c.connection.DoStat(dirPath, 1) + if err == nil && info.IsDir() { + return nil + } + + err = c.createDir(fs, p) if err != nil { return err } @@ -156,14 +177,14 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *transfer) err for { n, err := c.connection.channel.Read(buf) if err != nil { - c.sendErrorMessage(err) + c.sendErrorMessage(transfer.Fs, err) transfer.TransferError(err) transfer.Close() return err } _, err = transfer.WriteAt(buf[:n], sizeToRead-remaining) if err != nil { - c.sendErrorMessage(err) + c.sendErrorMessage(transfer.Fs, err) transfer.Close() return err } @@ -184,33 +205,33 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *transfer) err } err = transfer.Close() if err != nil { - c.sendErrorMessage(err) + c.sendErrorMessage(transfer.Fs, err) return err } return nil } -func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error { +func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error { quotaResult := c.connection.HasSpace(isNewFile, false, requestPath) if !quotaResult.HasSpace { err := fmt.Errorf("denying file write due to quota limits") c.connection.Log(logger.LevelWarn, "error uploading file: %#v, err: %v", filePath, err) - c.sendErrorMessage(err) + c.sendErrorMessage(fs, err) return err } - maxWriteSize, _ := c.connection.GetMaxWriteSize(quotaResult, false, fileSize) + maxWriteSize, _ := c.connection.GetMaxWriteSize(quotaResult, false, fileSize, fs.IsUploadResumeSupported()) - file, w, cancelFn, err := c.connection.Fs.Create(filePath, 0) + file, w, cancelFn, err := fs.Create(filePath, 0) if err != nil { c.connection.Log(logger.LevelError, "error creating file %#v: %v", resolvedPath, err) - c.sendErrorMessage(err) + c.sendErrorMessage(fs, err) return err } initialSize := int64(0) if !isNewFile { - if vfs.IsLocalOrSFTPFs(c.connection.Fs) { + if vfs.IsLocalOrSFTPFs(fs) { vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath)) if err == nil { dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck @@ -228,10 +249,10 @@ func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead } } - vfs.SetPathPermissions(c.connection.Fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID()) + vfs.SetPathPermissions(fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, resolvedPath, requestPath, - common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, c.connection.Fs) + common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, fs) t := newTransfer(baseTransfer, w, nil, nil) return c.getUploadFileData(sizeToRead, t) @@ -240,67 +261,66 @@ func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error { c.connection.UpdateLastActivity() - var err error + fs, p, err := c.connection.GetFsAndResolvedPath(uploadFilePath) + if err != nil { + c.connection.Log(logger.LevelWarn, "error uploading file: %#v, err: %v", uploadFilePath, err) + c.sendErrorMessage(nil, err) + return err + } if !c.connection.User.IsFileAllowed(uploadFilePath) { c.connection.Log(logger.LevelWarn, "writing file %#v is not allowed", uploadFilePath) - c.sendErrorMessage(common.ErrPermissionDenied) + c.sendErrorMessage(fs, common.ErrPermissionDenied) return common.ErrPermissionDenied } - p, err := c.connection.Fs.ResolvePath(uploadFilePath) - if err != nil { - c.connection.Log(logger.LevelWarn, "error uploading file: %#v, err: %v", uploadFilePath, err) - c.sendErrorMessage(err) - return err - } filePath := p - if common.Config.IsAtomicUploadEnabled() && c.connection.Fs.IsAtomicUploadSupported() { - filePath = c.connection.Fs.GetAtomicUploadPath(p) + if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { + filePath = fs.GetAtomicUploadPath(p) } - stat, statErr := c.connection.Fs.Lstat(p) - if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || c.connection.Fs.IsNotExist(statErr) { + stat, statErr := fs.Lstat(p) + if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) { if !c.connection.User.HasPerm(dataprovider.PermUpload, path.Dir(uploadFilePath)) { c.connection.Log(logger.LevelWarn, "cannot upload file: %#v, permission denied", uploadFilePath) - c.sendErrorMessage(common.ErrPermissionDenied) + c.sendErrorMessage(fs, common.ErrPermissionDenied) return common.ErrPermissionDenied } - return c.handleUploadFile(p, filePath, sizeToRead, true, 0, uploadFilePath) + return c.handleUploadFile(fs, p, filePath, sizeToRead, true, 0, uploadFilePath) } if statErr != nil { c.connection.Log(logger.LevelError, "error performing file stat %#v: %v", p, statErr) - c.sendErrorMessage(statErr) + c.sendErrorMessage(fs, statErr) return statErr } if stat.IsDir() { c.connection.Log(logger.LevelWarn, "attempted to open a directory for writing to: %#v", p) - err = fmt.Errorf("Attempted to open a directory for writing: %#v", p) - c.sendErrorMessage(err) + err = fmt.Errorf("attempted to open a directory for writing: %#v", p) + c.sendErrorMessage(fs, err) return err } if !c.connection.User.HasPerm(dataprovider.PermOverwrite, uploadFilePath) { c.connection.Log(logger.LevelWarn, "cannot overwrite file: %#v, permission denied", uploadFilePath) - c.sendErrorMessage(common.ErrPermissionDenied) + c.sendErrorMessage(fs, common.ErrPermissionDenied) return common.ErrPermissionDenied } - if common.Config.IsAtomicUploadEnabled() && c.connection.Fs.IsAtomicUploadSupported() { - err = c.connection.Fs.Rename(p, filePath) + if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { + err = fs.Rename(p, filePath) if err != nil { c.connection.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %#v, dest: %#v, err: %v", p, filePath, err) - c.sendErrorMessage(err) + c.sendErrorMessage(fs, err) return err } } - return c.handleUploadFile(p, filePath, sizeToRead, false, stat.Size(), uploadFilePath) + return c.handleUploadFile(fs, p, filePath, sizeToRead, false, stat.Size(), uploadFilePath) } -func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileInfo) error { +func (c *scpCommand) sendDownloadProtocolMessages(virtualDirPath string, stat os.FileInfo) error { var err error if c.sendFileTime() { modTime := stat.ModTime().UnixNano() / 1000000000 @@ -315,12 +335,9 @@ func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileIn } } - dirName := filepath.Base(dirPath) - for _, v := range c.connection.User.VirtualFolders { - if v.MappedPath == dirPath { - dirName = path.Base(v.VirtualPath) - break - } + dirName := path.Base(virtualDirPath) + if dirName == "/" || dirName == "." { + dirName = c.connection.User.Username } fileMode := fmt.Sprintf("D%v 0 %v\n", getFileModeAsString(stat.Mode(), stat.IsDir()), dirName) @@ -334,23 +351,23 @@ func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileIn // We send first all the files in the root directory and then the directories. // For each directory we recursively call this method again -func (c *scpCommand) handleRecursiveDownload(dirPath string, stat os.FileInfo) error { +func (c *scpCommand) handleRecursiveDownload(fs vfs.Fs, dirPath, virtualPath string, stat os.FileInfo) error { var err error if c.isRecursive() { - c.connection.Log(logger.LevelDebug, "recursive download, dir path: %#v", dirPath) - err = c.sendDownloadProtocolMessages(dirPath, stat) + c.connection.Log(logger.LevelDebug, "recursive download, dir path %#v virtual path %#v", dirPath, virtualPath) + err = c.sendDownloadProtocolMessages(virtualPath, stat) if err != nil { return err } - files, err := c.connection.Fs.ReadDir(dirPath) - files = c.connection.User.AddVirtualDirs(files, c.connection.Fs.GetRelativePath(dirPath)) + files, err := fs.ReadDir(dirPath) if err != nil { - c.sendErrorMessage(err) + c.sendErrorMessage(fs, err) return err } + files = c.connection.User.AddVirtualDirs(files, fs.GetRelativePath(dirPath)) var dirs []string for _, file := range files { - filePath := c.connection.Fs.GetRelativePath(c.connection.Fs.Join(dirPath, file.Name())) + filePath := fs.GetRelativePath(fs.Join(dirPath, file.Name())) if file.Mode().IsRegular() || file.Mode()&os.ModeSymlink != 0 { err = c.handleDownload(filePath) if err != nil { @@ -361,7 +378,7 @@ func (c *scpCommand) handleRecursiveDownload(dirPath string, stat os.FileInfo) e } } if err != nil { - c.sendErrorMessage(err) + c.sendErrorMessage(fs, err) return err } for _, dir := range dirs { @@ -371,25 +388,21 @@ func (c *scpCommand) handleRecursiveDownload(dirPath string, stat os.FileInfo) e } } if err != nil { - c.sendErrorMessage(err) + c.sendErrorMessage(fs, err) return err } err = c.sendProtocolMessage("E\n") if err != nil { return err } - err = c.readConfirmationMessage() - if err != nil { - return err - } - return err + return c.readConfirmationMessage() } - err = fmt.Errorf("Unable to send directory for non recursive copy") - c.sendErrorMessage(err) + err = fmt.Errorf("unable to send directory for non recursive copy") + c.sendErrorMessage(nil, err) return err } -func (c *scpCommand) sendDownloadFileData(filePath string, stat os.FileInfo, transfer *transfer) error { +func (c *scpCommand) sendDownloadFileData(fs vfs.Fs, filePath string, stat os.FileInfo, transfer *transfer) error { var err error if c.sendFileTime() { modTime := stat.ModTime().UnixNano() / 1000000000 @@ -403,8 +416,8 @@ func (c *scpCommand) sendDownloadFileData(filePath string, stat os.FileInfo, tra return err } } - if vfs.IsCryptOsFs(c.connection.Fs) { - stat = c.connection.Fs.(*vfs.CryptFs).ConvertFileInfo(stat) + if vfs.IsCryptOsFs(fs) { + stat = fs.(*vfs.CryptFs).ConvertFileInfo(stat) } fileSize := stat.Size() @@ -435,7 +448,7 @@ func (c *scpCommand) sendDownloadFileData(filePath string, stat os.FileInfo, tra } } if err != io.EOF { - c.sendErrorMessage(err) + c.sendErrorMessage(fs, err) return err } err = c.sendConfirmationMessage() @@ -450,55 +463,62 @@ func (c *scpCommand) handleDownload(filePath string) error { c.connection.UpdateLastActivity() var err error - p, err := c.connection.Fs.ResolvePath(filePath) + fs, err := c.connection.User.GetFilesystemForPath(filePath, c.connection.ID) if err != nil { - err := fmt.Errorf("Invalid file path") + c.connection.Log(logger.LevelWarn, "error downloading file %#v: %+v", filePath, err) + c.sendErrorMessage(nil, fmt.Errorf("unable to get fs for path %#v", filePath)) + return err + } + + p, err := fs.ResolvePath(filePath) + if err != nil { + err := fmt.Errorf("invalid file path %#v", filePath) c.connection.Log(logger.LevelWarn, "error downloading file: %#v, invalid file path", filePath) - c.sendErrorMessage(err) + c.sendErrorMessage(fs, err) return err } var stat os.FileInfo - if stat, err = c.connection.Fs.Stat(p); err != nil { + if stat, err = fs.Stat(p); err != nil { c.connection.Log(logger.LevelWarn, "error downloading file: %#v->%#v, err: %v", filePath, p, err) - c.sendErrorMessage(err) + c.sendErrorMessage(fs, err) return err } if stat.IsDir() { if !c.connection.User.HasPerm(dataprovider.PermDownload, filePath) { c.connection.Log(logger.LevelWarn, "error downloading dir: %#v, permission denied", filePath) - c.sendErrorMessage(common.ErrPermissionDenied) + c.sendErrorMessage(fs, common.ErrPermissionDenied) return common.ErrPermissionDenied } - err = c.handleRecursiveDownload(p, stat) + err = c.handleRecursiveDownload(fs, p, filePath, stat) return err } if !c.connection.User.HasPerm(dataprovider.PermDownload, path.Dir(filePath)) { c.connection.Log(logger.LevelWarn, "error downloading dir: %#v, permission denied", filePath) - c.sendErrorMessage(common.ErrPermissionDenied) + c.sendErrorMessage(fs, common.ErrPermissionDenied) return common.ErrPermissionDenied } if !c.connection.User.IsFileAllowed(filePath) { c.connection.Log(logger.LevelWarn, "reading file %#v is not allowed", filePath) - c.sendErrorMessage(common.ErrPermissionDenied) + c.sendErrorMessage(fs, common.ErrPermissionDenied) return common.ErrPermissionDenied } - file, r, cancelFn, err := c.connection.Fs.Open(p, 0) + file, r, cancelFn, err := fs.Open(p, 0) if err != nil { c.connection.Log(logger.LevelError, "could not open file %#v for reading: %v", p, err) - c.sendErrorMessage(err) + c.sendErrorMessage(fs, err) return err } baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, p, filePath, - common.TransferDownload, 0, 0, 0, false, c.connection.Fs) + common.TransferDownload, 0, 0, 0, false, fs) t := newTransfer(baseTransfer, nil, r, nil) - err = c.sendDownloadFileData(p, stat, t) + err = c.sendDownloadFileData(fs, p, stat, t) // we need to call Close anyway and return close error if any and // if we have no previous error if err == nil { @@ -578,9 +598,13 @@ func (c *scpCommand) readProtocolMessage() (string, error) { // send an error message and close the channel //nolint:errcheck // we don't check write errors here, we have to close the channel anyway -func (c *scpCommand) sendErrorMessage(err error) { +func (c *scpCommand) sendErrorMessage(fs vfs.Fs, err error) { c.connection.channel.Write(errMsg) - c.connection.channel.Write([]byte(c.connection.GetFsError(err).Error())) + if fs != nil { + c.connection.channel.Write([]byte(c.connection.GetFsError(fs, err).Error())) + } else { + c.connection.channel.Write([]byte(err.Error())) + } c.connection.channel.Write(newLine) c.connection.channel.Close() } @@ -625,22 +649,14 @@ func (c *scpCommand) getNextUploadProtocolMessage() (string, error) { return command, err } -func (c *scpCommand) createDir(dirPath string) error { - var err error - var isDir bool - isDir, err = vfs.IsDirectory(c.connection.Fs, dirPath) - if err == nil && isDir { - // if this is a virtual dir the resolved path will exist, we don't need a specific check - // TODO: remember to check if it's okay when we'll add virtual folders support to cloud backends - c.connection.Log(logger.LevelDebug, "directory %#v already exists", dirPath) - return nil - } - if err = c.connection.Fs.Mkdir(dirPath); err != nil { +func (c *scpCommand) createDir(fs vfs.Fs, dirPath string) error { + err := fs.Mkdir(dirPath) + if err != nil { c.connection.Log(logger.LevelError, "error creating dir %#v: %v", dirPath, err) - c.sendErrorMessage(err) + c.sendErrorMessage(fs, err) return err } - vfs.SetPathPermissions(c.connection.Fs, dirPath, c.connection.User.GetUID(), c.connection.User.GetGID()) + vfs.SetPathPermissions(fs, dirPath, c.connection.User.GetUID(), c.connection.User.GetGID()) return err } @@ -649,7 +665,7 @@ func (c *scpCommand) createDir(dirPath string) error { // or: // C0644 6 testfile // and returns file size and file/directory name -func (c *scpCommand) parseUploadMessage(command string) (int64, string, error) { +func (c *scpCommand) parseUploadMessage(fs vfs.Fs, command string) (int64, string, error) { var size int64 var name string var err error @@ -657,7 +673,7 @@ func (c *scpCommand) parseUploadMessage(command string) (int64, string, error) { err = fmt.Errorf("unknown or invalid upload message: %v args: %v user: %v", command, c.args, c.connection.User.Username) c.connection.Log(logger.LevelWarn, "error: %v", err) - c.sendErrorMessage(err) + c.sendErrorMessage(fs, err) return size, name, err } parts := strings.SplitN(command, " ", 3) @@ -665,26 +681,26 @@ func (c *scpCommand) parseUploadMessage(command string) (int64, string, error) { size, err = strconv.ParseInt(parts[1], 10, 64) if err != nil { c.connection.Log(logger.LevelWarn, "error getting size from upload message: %v", err) - c.sendErrorMessage(err) + c.sendErrorMessage(fs, err) return size, name, err } name = parts[2] if name == "" { err = fmt.Errorf("error getting name from upload message, cannot be empty") c.connection.Log(logger.LevelWarn, "error: %v", err) - c.sendErrorMessage(err) + c.sendErrorMessage(fs, err) return size, name, err } } else { err = fmt.Errorf("unable to split upload message: %#v", command) c.connection.Log(logger.LevelWarn, "error: %v", err) - c.sendErrorMessage(err) + c.sendErrorMessage(fs, err) return size, name, err } return size, name, err } -func (c *scpCommand) getFileUploadDestPath(scpDestPath, fileName string) string { +func (c *scpCommand) getFileUploadDestPath(fs vfs.Fs, scpDestPath, fileName string) string { if !c.isRecursive() { // if the upload is not recursive and the destination path does not end with "/" // then scpDestPath is the wanted filename, for example: @@ -695,8 +711,8 @@ func (c *scpCommand) getFileUploadDestPath(scpDestPath, fileName string) string // but if scpDestPath is an existing directory then we put the uploaded file // inside that directory this is as scp command works, for example: // scp fileName.txt user@127.0.0.1:/existing_dir - if p, err := c.connection.Fs.ResolvePath(scpDestPath); err == nil { - if stat, err := c.connection.Fs.Stat(p); err == nil { + if p, err := fs.ResolvePath(scpDestPath); err == nil { + if stat, err := fs.Stat(p); err == nil { if stat.IsDir() { return path.Join(scpDestPath, fileName) } diff --git a/sftpd/server.go b/sftpd/server.go index 48becfe6..b04694d9 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/hex" "encoding/json" - "errors" "fmt" "io" "net" @@ -400,10 +399,14 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve loginType := sconn.Permissions.Extensions["sftpgo_login_method"] connectionID := hex.EncodeToString(sconn.SessionID()) - if err = checkRootPath(&user, connectionID); err != nil { + if err = user.CheckFsRoot(connectionID); err != nil { + errClose := user.CloseFs() + logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose) return } + defer user.CloseFs() //nolint:errcheck + logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v", user.ID, loginType, user.Username, user.HomeDir, ipAddr) @@ -445,34 +448,24 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve switch req.Type { case "subsystem": if string(req.Payload[4:]) == "sftp" { - fs, err := user.GetFilesystem(connID) - if err == nil { - ok = true - connection := Connection{ - BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, user, fs), - ClientVersion: string(sconn.ClientVersion()), - RemoteAddr: conn.RemoteAddr(), - channel: channel, - } - go c.handleSftpConnection(channel, &connection) - } else { - logger.Debug(logSender, connID, "unable to create filesystem: %v", err) - } - } - case "exec": - // protocol will be set later inside processSSHCommand it could be SSH or SCP - fs, err := user.GetFilesystem(connID) - if err == nil { + ok = true connection := Connection{ - BaseConnection: common.NewBaseConnection(connID, "sshd_exec", user, fs), + BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, user), ClientVersion: string(sconn.ClientVersion()), RemoteAddr: conn.RemoteAddr(), channel: channel, } - ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands) - } else { - logger.Debug(sshCommandLogSender, connID, "unable to create filesystem: %v", err) + go c.handleSftpConnection(channel, &connection) } + case "exec": + // protocol will be set later inside processSSHCommand it could be SSH or SCP + connection := Connection{ + BaseConnection: common.NewBaseConnection(connID, "sshd_exec", user), + ClientVersion: string(sconn.ClientVersion()), + RemoteAddr: conn.RemoteAddr(), + channel: channel, + } + ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands) } if req.WantReply { req.Reply(ok, nil) //nolint:errcheck @@ -541,21 +534,6 @@ func checkAuthError(ip string, err error) { } } -func checkRootPath(user *dataprovider.User, connectionID string) error { - if user.FsConfig.Provider != dataprovider.SFTPFilesystemProvider { - // for sftp fs check root path does nothing so don't open a useless SFTP connection - fs, err := user.GetFilesystem(connectionID) - if err != nil { - logger.Warn(logSender, "", "could not create filesystem for user %#v err: %v", user.Username, err) - return err - } - - fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID()) - fs.Close() - } - return nil -} - func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.ConnMetadata) (*ssh.Permissions, error) { connectionID := "" if conn != nil { @@ -568,7 +546,7 @@ func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh. } if utils.IsStringInSlice(common.ProtocolSSH, user.Filters.DeniedProtocols) { logger.Debug(logSender, connectionID, "cannot login user %#v, protocol SSH is not allowed", user.Username) - return nil, fmt.Errorf("Protocol SSH is not allowed for user %#v", user.Username) + return nil, fmt.Errorf("protocol SSH is not allowed for user %#v", user.Username) } if user.MaxSessions > 0 { activeSessions := common.Connections.GetActiveSessions(user.Username) @@ -580,17 +558,12 @@ func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh. } if !user.IsLoginMethodAllowed(loginMethod, conn.PartialSuccessMethods()) { logger.Debug(logSender, connectionID, "cannot login user %#v, login method %#v is not allowed", user.Username, loginMethod) - return nil, fmt.Errorf("Login method %#v is not allowed for user %#v", loginMethod, user.Username) - } - if dataprovider.GetQuotaTracking() > 0 && user.HasOverlappedMappedPaths() { - logger.Debug(logSender, connectionID, "cannot login user %#v, overlapping mapped folders are allowed only with quota tracking disabled", - user.Username) - return nil, errors.New("overlapping mapped folders are allowed only with quota tracking disabled") + return nil, fmt.Errorf("login method %#v is not allowed for user %#v", loginMethod, user.Username) } remoteAddr := conn.RemoteAddr().String() if !user.IsLoginFromAddrAllowed(remoteAddr) { logger.Debug(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v", user.Username, remoteAddr) - return nil, fmt.Errorf("Login for user %#v is not allowed from this address: %v", user.Username, remoteAddr) + return nil, fmt.Errorf("login for user %#v is not allowed from this address: %v", user.Username, remoteAddr) } json, err := json.Marshal(user) diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index dd0e540b..1c379371 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -146,6 +146,7 @@ func TestMain(m *testing.M) { if err != nil { logger.ErrorToConsole("error creating login banner: %v", err) } + os.Setenv("SFTPGO_COMMON__UPLOAD_MODE", "2") err = config.LoadConfig(configDir, "") if err != nil { logger.ErrorToConsole("error loading configuration: %v", err) @@ -155,10 +156,6 @@ func TestMain(m *testing.M) { logger.InfoToConsole("Starting SFTPD tests, provider: %v", providerConf.Driver) commonConf := config.GetCommonConfig() - // we run the test cases with UploadMode atomic and resume support. The non atomic code path - // simply does not execute some code so if it works in atomic mode will - // work in non atomic mode too - commonConf.UploadMode = 2 homeBasePath = os.TempDir() checkSystemCommands() var scriptArgs string @@ -1130,7 +1127,7 @@ func TestChtimes(t *testing.T) { defer client.Close() testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) - testDir := "test" + testDir := "test" //nolint:goconst err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) @@ -1687,7 +1684,7 @@ func TestLoginUserExpiration(t *testing.T) { func TestLoginWithDatabaseCredentials(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) - u.FsConfig.Provider = dataprovider.GCSFilesystemProvider + u.FsConfig.Provider = vfs.GCSFilesystemProvider u.FsConfig.GCSConfig.Bucket = "testbucket" u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(`{ "type": "service_account" }`) @@ -1736,7 +1733,7 @@ func TestLoginWithDatabaseCredentials(t *testing.T) { func TestLoginInvalidFs(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) - u.FsConfig.Provider = dataprovider.GCSFilesystemProvider + u.FsConfig.Provider = vfs.GCSFilesystemProvider u.FsConfig.GCSConfig.Bucket = "test" u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials") user, _, err := httpdtest.AddUser(u, http.StatusCreated) @@ -2569,6 +2566,29 @@ func TestMaxSessions(t *testing.T) { assert.NoError(t, err) } +func TestSupportedExtensions(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + v, ok := client.HasExtension("statvfs@openssh.com") + assert.Equal(t, "2", v) + assert.True(t, ok) + _, ok = client.HasExtension("hardlink@openssh.com") + assert.False(t, ok) + _, ok = client.HasExtension("posix-rename@openssh.com") + assert.False(t, ok) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + func TestQuotaFileReplace(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) @@ -3344,6 +3364,196 @@ func TestVirtualFoldersQuotaLimit(t *testing.T) { assert.NoError(t, err) } +func TestNestedVirtualFolders(t *testing.T) { + usePubKey := false + baseUser, resp, err := httpdtest.AddUser(getTestUser(false), http.StatusCreated) + assert.NoError(t, err, string(resp)) + u := getTestSFTPUser(usePubKey) + u.QuotaFiles = 1000 + mappedPathCrypt := filepath.Join(os.TempDir(), "crypt") + folderNameCrypt := filepath.Base(mappedPathCrypt) + vdirCryptPath := "/vdir/crypt" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameCrypt, + FsConfig: vfs.Filesystem{ + Provider: vfs.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret(defaultPassword), + }, + }, + MappedPath: mappedPathCrypt, + }, + VirtualPath: vdirCryptPath, + QuotaFiles: 100, + }) + mappedPath := filepath.Join(os.TempDir(), "local") + folderName := filepath.Base(mappedPath) + vdirPath := "/vdir/local" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + }, + VirtualPath: vdirPath, + QuotaFiles: -1, + QuotaSize: -1, + }) + mappedPathNested := filepath.Join(os.TempDir(), "nested") + folderNameNested := filepath.Base(mappedPathNested) + vdirNestedPath := "/vdir/crypt/nested" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameNested, + MappedPath: mappedPathNested, + }, + VirtualPath: vdirNestedPath, + QuotaFiles: -1, + QuotaSize: -1, + }) + user, resp, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err, string(resp)) + client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer client.Close() + expectedQuotaSize := int64(0) + expectedQuotaFiles := 0 + fileSize := int64(32765) + err = writeSFTPFile(testFileName, fileSize, client) + assert.NoError(t, err) + expectedQuotaSize += fileSize + expectedQuotaFiles++ + fileSize = 38764 + err = writeSFTPFile(path.Join("/vdir", testFileName), fileSize, client) + assert.NoError(t, err) + expectedQuotaSize += fileSize + expectedQuotaFiles++ + fileSize = 18769 + err = writeSFTPFile(path.Join(vdirPath, testFileName), fileSize, client) + assert.NoError(t, err) + expectedQuotaSize += fileSize + expectedQuotaFiles++ + fileSize = 27658 + err = writeSFTPFile(path.Join(vdirNestedPath, testFileName), fileSize, client) + assert.NoError(t, err) + expectedQuotaSize += fileSize + expectedQuotaFiles++ + fileSize = 39765 + err = writeSFTPFile(path.Join(vdirCryptPath, testFileName), fileSize, client) + assert.NoError(t, err) + + userGet, _, err := httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, userGet.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, userGet.UsedQuotaSize) + + folderGet, _, err := httpdtest.GetFolderByName(folderNameCrypt, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, folderGet.UsedQuotaSize, fileSize) + assert.Equal(t, 1, folderGet.UsedQuotaFiles) + + folderGet, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(18769), folderGet.UsedQuotaSize) + assert.Equal(t, 1, folderGet.UsedQuotaFiles) + + folderGet, _, err = httpdtest.GetFolderByName(folderNameNested, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(27658), folderGet.UsedQuotaSize) + assert.Equal(t, 1, folderGet.UsedQuotaFiles) + + files, err := client.ReadDir("/") + if assert.NoError(t, err) { + assert.Len(t, files, 2) + } + info, err := client.Stat("vdir") + if assert.NoError(t, err) { + assert.True(t, info.IsDir()) + } + files, err = client.ReadDir("/vdir") + if assert.NoError(t, err) { + assert.Len(t, files, 3) + } + files, err = client.ReadDir(vdirCryptPath) + if assert.NoError(t, err) { + assert.Len(t, files, 2) + } + info, err = client.Stat(vdirNestedPath) + if assert.NoError(t, err) { + assert.True(t, info.IsDir()) + } + // finally add some files directly using os method and then check quota + fName := "testfile" + fileSize = 123456 + err = createTestFile(filepath.Join(baseUser.HomeDir, fName), fileSize) + assert.NoError(t, err) + expectedQuotaSize += fileSize + expectedQuotaFiles++ + fileSize = 8765 + err = createTestFile(filepath.Join(mappedPath, fName), fileSize) + assert.NoError(t, err) + expectedQuotaSize += fileSize + expectedQuotaFiles++ + fileSize = 98751 + err = createTestFile(filepath.Join(mappedPathNested, fName), fileSize) + assert.NoError(t, err) + expectedQuotaSize += fileSize + expectedQuotaFiles++ + err = createTestFile(filepath.Join(mappedPathCrypt, fName), fileSize) + assert.NoError(t, err) + _, err = httpdtest.StartQuotaScan(user, http.StatusAccepted) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + scans, _, err := httpdtest.GetQuotaScans(http.StatusOK) + if err == nil { + return len(scans) == 0 + } + return false + }, 1*time.Second, 50*time.Millisecond) + + userGet, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, userGet.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, userGet.UsedQuotaSize) + + // the crypt folder is not included within user quota so we need to do a separate scan + _, err = httpdtest.StartFolderQuotaScan(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusAccepted) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + scans, _, err := httpdtest.GetFoldersQuotaScans(http.StatusOK) + if err == nil { + return len(scans) == 0 + } + return false + }, 1*time.Second, 50*time.Millisecond) + folderGet, _, err = httpdtest.GetFolderByName(folderNameCrypt, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, folderGet.UsedQuotaSize, int64(39765+98751)) + assert.Equal(t, 2, folderGet.UsedQuotaFiles) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameNested}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(baseUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(baseUser.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathCrypt) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathNested) + assert.NoError(t, err) +} + func TestTruncateQuotaLimits(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) @@ -4733,208 +4943,6 @@ func TestVirtualFoldersLink(t *testing.T) { assert.NoError(t, err) } -func TestOverlappedMappedFolders(t *testing.T) { - err := dataprovider.Close() - assert.NoError(t, err) - err = config.LoadConfig(configDir, "") - assert.NoError(t, err) - providerConf := config.GetProviderConf() - providerConf.TrackQuota = 0 - err = dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) - - usePubKey := false - u := getTestUser(usePubKey) - subDir := "subdir" - mappedPath1 := filepath.Join(os.TempDir(), "vdir1") - folderName1 := filepath.Base(mappedPath1) - vdirPath1 := "/vdir1" - mappedPath2 := filepath.Join(os.TempDir(), "vdir1", subDir) - folderName2 := filepath.Base(mappedPath2) - vdirPath2 := "/vdir2" - u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - Name: folderName1, - MappedPath: mappedPath1, - }, - VirtualPath: vdirPath1, - }) - u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - Name: folderName2, - MappedPath: mappedPath2, - }, - VirtualPath: vdirPath2, - }) - err = os.MkdirAll(mappedPath1, os.ModePerm) - assert.NoError(t, err) - err = os.MkdirAll(mappedPath2, os.ModePerm) - assert.NoError(t, err) - user, _, err := httpdtest.AddUser(u, http.StatusCreated) - assert.NoError(t, err) - client, err := getSftpClient(user, usePubKey) - if assert.NoError(t, err) { - defer client.Close() - err = checkBasicSFTP(client) - assert.NoError(t, err) - testFileSize := int64(131072) - testFilePath := filepath.Join(homeBasePath, testFileName) - err = createTestFile(testFilePath, testFileSize) - assert.NoError(t, err) - err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) - assert.NoError(t, err) - err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client) - assert.NoError(t, err) - err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName), testFileSize, client) - assert.NoError(t, err) - fi, err := client.Stat(path.Join(vdirPath1, subDir, testFileName)) - if assert.NoError(t, err) { - assert.Equal(t, testFileSize, fi.Size()) - } - err = client.Rename(path.Join(vdirPath1, subDir, testFileName), path.Join(vdirPath2, testFileName+"1")) - assert.NoError(t, err) - err = client.Rename(path.Join(vdirPath2, testFileName+"1"), path.Join(vdirPath1, subDir, testFileName)) - assert.NoError(t, err) - err = client.Rename(path.Join(vdirPath1, subDir), path.Join(vdirPath2, subDir)) - assert.Error(t, err) - err = client.Mkdir(subDir) - assert.NoError(t, err) - err = client.Rename(subDir, path.Join(vdirPath1, subDir)) - assert.Error(t, err) - err = client.RemoveDirectory(path.Join(vdirPath1, subDir)) - assert.Error(t, err) - err = client.Symlink(path.Join(vdirPath1, subDir), path.Join(vdirPath1, "adir")) - assert.Error(t, err) - err = client.Mkdir(path.Join(vdirPath1, subDir+"1")) - assert.NoError(t, err) - err = client.Symlink(path.Join(vdirPath1, subDir+"1"), path.Join(vdirPath1, subDir)) - assert.Error(t, err) - err = os.Remove(testFilePath) - assert.NoError(t, err) - _, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath1, subDir)), user, usePubKey) - assert.Error(t, err) - } - - err = dataprovider.Close() - assert.NoError(t, err) - err = config.LoadConfig(configDir, "") - assert.NoError(t, err) - providerConf = config.GetProviderConf() - err = dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) - - if providerConf.Driver != dataprovider.MemoryDataProviderName { - client, err = getSftpClient(user, usePubKey) - if !assert.Error(t, err) { - client.Close() - } - - _, err = httpdtest.RemoveUser(user, http.StatusOK) - assert.NoError(t, err) - _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) - assert.NoError(t, err) - _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) - assert.NoError(t, err) - } - - _, _, err = httpdtest.AddUser(u, http.StatusCreated) - assert.Error(t, err) - - err = os.RemoveAll(user.GetHomeDir()) - assert.NoError(t, err) - err = os.RemoveAll(mappedPath1) - assert.NoError(t, err) - err = os.RemoveAll(mappedPath2) - assert.NoError(t, err) -} - -func TestResolveOverlappedMappedPaths(t *testing.T) { - u := getTestUser(false) - mappedPath1 := filepath.Join(os.TempDir(), "mapped1", "subdir") - folderName1 := filepath.Base(mappedPath1) - vdirPath1 := "/vdir1/subdir" - mappedPath2 := filepath.Join(os.TempDir(), "mapped2") - folderName2 := filepath.Base(mappedPath2) - vdirPath2 := "/vdir2/subdir" - mappedPath3 := filepath.Join(os.TempDir(), "mapped1") - folderName3 := filepath.Base(mappedPath3) - vdirPath3 := "/vdir3" - mappedPath4 := filepath.Join(os.TempDir(), "mapped1", "subdir", "vdir4") - folderName4 := filepath.Base(mappedPath4) - vdirPath4 := "/vdir4" - u.VirtualFolders = []vfs.VirtualFolder{ - { - BaseVirtualFolder: vfs.BaseVirtualFolder{ - Name: folderName1, - MappedPath: mappedPath1, - }, - VirtualPath: vdirPath1, - }, - { - BaseVirtualFolder: vfs.BaseVirtualFolder{ - Name: folderName2, - MappedPath: mappedPath2, - }, - VirtualPath: vdirPath2, - }, - { - BaseVirtualFolder: vfs.BaseVirtualFolder{ - Name: folderName3, - MappedPath: mappedPath3, - }, - VirtualPath: vdirPath3, - }, - { - BaseVirtualFolder: vfs.BaseVirtualFolder{ - Name: folderName4, - MappedPath: mappedPath4, - }, - VirtualPath: vdirPath4, - }, - } - err := os.MkdirAll(u.GetHomeDir(), os.ModePerm) - assert.NoError(t, err) - err = os.MkdirAll(mappedPath1, os.ModePerm) - assert.NoError(t, err) - err = os.MkdirAll(mappedPath2, os.ModePerm) - assert.NoError(t, err) - err = os.MkdirAll(mappedPath3, os.ModePerm) - assert.NoError(t, err) - err = os.MkdirAll(mappedPath4, os.ModePerm) - assert.NoError(t, err) - - fs := vfs.NewOsFs("", u.GetHomeDir(), u.VirtualFolders) - p, err := fs.ResolvePath("/vdir1") - assert.NoError(t, err) - assert.Equal(t, filepath.Join(u.GetHomeDir(), "vdir1"), p) - p, err = fs.ResolvePath("/vdir1/subdir") - assert.NoError(t, err) - assert.Equal(t, mappedPath1, p) - p, err = fs.ResolvePath("/vdir3/subdir/vdir4/file.txt") - assert.NoError(t, err) - assert.Equal(t, filepath.Join(mappedPath4, "file.txt"), p) - p, err = fs.ResolvePath("/vdir4/file.txt") - assert.NoError(t, err) - assert.Equal(t, filepath.Join(mappedPath4, "file.txt"), p) - assert.Equal(t, filepath.Join(mappedPath3, "subdir", "vdir4", "file.txt"), p) - assert.Equal(t, filepath.Join(mappedPath1, "vdir4", "file.txt"), p) - p, err = fs.ResolvePath("/vdir3/subdir/vdir4/../file.txt") - assert.NoError(t, err) - assert.Equal(t, filepath.Join(mappedPath3, "subdir", "file.txt"), p) - assert.Equal(t, filepath.Join(mappedPath1, "file.txt"), p) - - err = os.RemoveAll(u.GetHomeDir()) - assert.NoError(t, err) - err = os.RemoveAll(mappedPath4) - assert.NoError(t, err) - err = os.RemoveAll(mappedPath1) - assert.NoError(t, err) - err = os.RemoveAll(mappedPath3) - assert.NoError(t, err) - err = os.RemoveAll(mappedPath2) - assert.NoError(t, err) -} - func TestVirtualFolderQuotaScan(t *testing.T) { mappedPath := filepath.Join(os.TempDir(), "mapped_dir") folderName := filepath.Base(mappedPath) @@ -5161,19 +5169,28 @@ func TestOpenError(t *testing.T) { assert.Error(t, err, "file download must fail if we have no filesystem read permissions") err = sftpUploadFile(localDownloadPath, testFileName, testFileSize, client) assert.Error(t, err, "upload must fail if we have no filesystem write permissions") - err = client.Mkdir("test") + testDir := "test" + err = client.Mkdir(testDir) + assert.NoError(t, err) + err = createTestFile(filepath.Join(user.GetHomeDir(), testDir, testFileName), testFileSize) assert.NoError(t, err) err = os.Chmod(user.GetHomeDir(), 0000) assert.NoError(t, err) _, err = client.Lstat(testFileName) assert.Error(t, err, "file stat must fail if we have no filesystem read permissions") + err = sftpUploadFile(localDownloadPath, path.Join(testDir, testFileName), testFileSize, client) + assert.ErrorIs(t, err, os.ErrPermission) + _, err = client.ReadLink(testFileName) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Remove(testFileName) + assert.ErrorIs(t, err, os.ErrPermission) err = os.Chmod(user.GetHomeDir(), os.ModePerm) assert.NoError(t, err) - err = os.Chmod(filepath.Join(user.GetHomeDir(), "test"), 0000) + err = os.Chmod(filepath.Join(user.GetHomeDir(), testDir), 0000) assert.NoError(t, err) - err = client.Rename(testFileName, path.Join("test", testFileName)) + err = client.Rename(testFileName, path.Join(testDir, testFileName)) assert.True(t, os.IsPermission(err)) - err = os.Chmod(filepath.Join(user.GetHomeDir(), "test"), os.ModePerm) + err = os.Chmod(filepath.Join(user.GetHomeDir(), testDir), os.ModePerm) assert.NoError(t, err) err = os.Remove(localDownloadPath) assert.NoError(t, err) @@ -5941,23 +5958,23 @@ func TestRootDirCommands(t *testing.T) { func TestRelativePaths(t *testing.T) { user := getTestUser(true) var path, rel string - filesystems := []vfs.Fs{vfs.NewOsFs("", user.GetHomeDir(), user.VirtualFolders)} + filesystems := []vfs.Fs{vfs.NewOsFs("", user.GetHomeDir(), "")} keyPrefix := strings.TrimPrefix(user.GetHomeDir(), "/") + "/" s3config := vfs.S3FsConfig{ KeyPrefix: keyPrefix, } - s3fs, _ := vfs.NewS3Fs("", user.GetHomeDir(), s3config) + s3fs, _ := vfs.NewS3Fs("", user.GetHomeDir(), "", s3config) gcsConfig := vfs.GCSFsConfig{ KeyPrefix: keyPrefix, } - gcsfs, _ := vfs.NewGCSFs("", user.GetHomeDir(), gcsConfig) + gcsfs, _ := vfs.NewGCSFs("", user.GetHomeDir(), "", gcsConfig) sftpconfig := vfs.SFTPFsConfig{ Endpoint: sftpServerAddr, Username: defaultUsername, Password: kms.NewPlainSecret(defaultPassword), Prefix: keyPrefix, } - sftpfs, _ := vfs.NewSFTPFs("", sftpconfig) + sftpfs, _ := vfs.NewSFTPFs("", "", sftpconfig) if runtime.GOOS != osWindows { filesystems = append(filesystems, s3fs, gcsfs, sftpfs) } @@ -6000,7 +6017,7 @@ func TestResolvePaths(t *testing.T) { user := getTestUser(true) var path, resolved string var err error - filesystems := []vfs.Fs{vfs.NewOsFs("", user.GetHomeDir(), user.VirtualFolders)} + filesystems := []vfs.Fs{vfs.NewOsFs("", user.GetHomeDir(), "")} keyPrefix := strings.TrimPrefix(user.GetHomeDir(), "/") + "/" s3config := vfs.S3FsConfig{ KeyPrefix: keyPrefix, @@ -6009,12 +6026,12 @@ func TestResolvePaths(t *testing.T) { } err = os.MkdirAll(user.GetHomeDir(), os.ModePerm) assert.NoError(t, err) - s3fs, err := vfs.NewS3Fs("", user.GetHomeDir(), s3config) + s3fs, err := vfs.NewS3Fs("", user.GetHomeDir(), "", s3config) assert.NoError(t, err) gcsConfig := vfs.GCSFsConfig{ KeyPrefix: keyPrefix, } - gcsfs, _ := vfs.NewGCSFs("", user.GetHomeDir(), gcsConfig) + gcsfs, _ := vfs.NewGCSFs("", user.GetHomeDir(), "", gcsConfig) if runtime.GOOS != osWindows { filesystems = append(filesystems, s3fs, gcsfs) } @@ -6049,7 +6066,7 @@ func TestResolvePaths(t *testing.T) { func TestVirtualRelativePaths(t *testing.T) { user := getTestUser(true) - mappedPath := filepath.Join(os.TempDir(), "vdir") + mappedPath := filepath.Join(os.TempDir(), "mdir") vdirPath := "/vdir" //nolint:goconst user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ BaseVirtualFolder: vfs.BaseVirtualFolder{ @@ -6059,47 +6076,23 @@ func TestVirtualRelativePaths(t *testing.T) { }) err := os.MkdirAll(mappedPath, os.ModePerm) assert.NoError(t, err) - fs := vfs.NewOsFs("", user.GetHomeDir(), user.VirtualFolders) - rel := fs.GetRelativePath(mappedPath) + fsRoot := vfs.NewOsFs("", user.GetHomeDir(), "") + fsVdir := vfs.NewOsFs("", mappedPath, vdirPath) + rel := fsVdir.GetRelativePath(mappedPath) assert.Equal(t, vdirPath, rel) - rel = fs.GetRelativePath(filepath.Join(mappedPath, "..")) + rel = fsRoot.GetRelativePath(filepath.Join(mappedPath, "..")) assert.Equal(t, "/", rel) // path outside home and virtual dir - rel = fs.GetRelativePath(filepath.Join(mappedPath, "../vdir1")) + rel = fsRoot.GetRelativePath(filepath.Join(mappedPath, "../vdir1")) assert.Equal(t, "/", rel) - rel = fs.GetRelativePath(filepath.Join(mappedPath, "../vdir/file.txt")) + rel = fsVdir.GetRelativePath(filepath.Join(mappedPath, "../vdir1")) + assert.Equal(t, "/vdir", rel) + rel = fsVdir.GetRelativePath(filepath.Join(mappedPath, "file.txt")) assert.Equal(t, "/vdir/file.txt", rel) - rel = fs.GetRelativePath(filepath.Join(user.HomeDir, "vdir1/file.txt")) + rel = fsRoot.GetRelativePath(filepath.Join(user.HomeDir, "vdir1/file.txt")) assert.Equal(t, "/vdir1/file.txt", rel) } -func TestResolveVirtualPaths(t *testing.T) { - user := getTestUser(true) - mappedPath := filepath.Join(os.TempDir(), "vdir") - vdirPath := "/vdir" - user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: mappedPath, - }, - VirtualPath: vdirPath, - }) - err := os.MkdirAll(mappedPath, os.ModePerm) - assert.NoError(t, err) - osFs := vfs.NewOsFs("", user.GetHomeDir(), user.VirtualFolders).(*vfs.OsFs) - b, f := osFs.GetFsPaths("/vdir/a.txt") - assert.Equal(t, mappedPath, b) - assert.Equal(t, filepath.Join(mappedPath, "a.txt"), f) - b, f = osFs.GetFsPaths("/vdir/sub with space & spécial chars/a.txt") - assert.Equal(t, mappedPath, b) - assert.Equal(t, filepath.Join(mappedPath, "sub with space & spécial chars/a.txt"), f) - b, f = osFs.GetFsPaths("/vdir/../a.txt") - assert.Equal(t, user.GetHomeDir(), b) - assert.Equal(t, filepath.Join(user.GetHomeDir(), "a.txt"), f) - b, f = osFs.GetFsPaths("/vdir1/a.txt") - assert.Equal(t, user.GetHomeDir(), b) - assert.Equal(t, filepath.Join(user.GetHomeDir(), "/vdir1/a.txt"), f) -} - func TestUserPerms(t *testing.T) { user := getTestUser(true) user.Permissions = make(map[string][]string) @@ -6504,7 +6497,7 @@ func TestStatVFS(t *testing.T) { func TestStatVFSCloudBackend(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) - u.FsConfig.Provider = dataprovider.AzureBlobFilesystemProvider + u.FsConfig.Provider = vfs.AzureBlobFilesystemProvider u.FsConfig.AzBlobConfig.SASURL = "https://myaccount.blob.core.windows.net/sasurl" user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) @@ -7050,6 +7043,32 @@ func TestSSHCopyQuotaLimits(t *testing.T) { assert.NoError(t, err) } +func TestSSHCopyRemoveNonLocalFs(t *testing.T) { + usePubKey := true + localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(sftpUser, usePubKey) + if assert.NoError(t, err) { + defer client.Close() + testDir := "test" + err = client.Mkdir(testDir) + assert.NoError(t, err) + _, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", testDir, testDir+"_copy"), sftpUser, usePubKey) + assert.Error(t, err) + _, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testDir), sftpUser, usePubKey) + assert.Error(t, err) + } + + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + func TestSSHRemove(t *testing.T) { usePubKey := false u := getTestUser(usePubKey) @@ -7188,6 +7207,117 @@ func TestSSHRemove(t *testing.T) { assert.NoError(t, err) } +func TestSSHRemoveCryptFs(t *testing.T) { + usePubKey := false + u := getTestUserWithCryptFs(usePubKey) + u.QuotaFiles = 100 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + folderName1 := filepath.Base(mappedPath1) + vdirPath1 := "/vdir1/sub" + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + folderName2 := filepath.Base(mappedPath2) + vdirPath2 := "/vdir2/sub" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + }, + VirtualPath: vdirPath1, + QuotaFiles: -1, + QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + FsConfig: vfs.Filesystem{ + Provider: vfs.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret(defaultPassword), + }, + }, + }, + VirtualPath: vdirPath2, + QuotaFiles: 100, + QuotaSize: 0, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer client.Close() + testDir := "tdir" + err = client.Mkdir(testDir) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath1, testDir)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, testDir)) + assert.NoError(t, err) + testFileSize := int64(32768) + testFileSize1 := int64(65536) + testFileName1 := "test_file1.dat" + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(testFileName1, testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(testDir, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(testDir, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, testDir, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, testDir, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, testDir, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, testDir, testFileName1), testFileSize1, client) + assert.NoError(t, err) + _, err = runSSHCommand("sftpgo-remove /vdir2", user, usePubKey) + assert.Error(t, err) + out, err := runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testFileName), user, usePubKey) + if assert.NoError(t, err) { + assert.Equal(t, "OK\n", string(out)) + _, err := client.Stat(testFileName) + assert.Error(t, err) + } + out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testDir), user, usePubKey) + if assert.NoError(t, err) { + assert.Equal(t, "OK\n", string(out)) + } + out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath1, testDir)), user, usePubKey) + if assert.NoError(t, err) { + assert.Equal(t, "OK\n", string(out)) + } + out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir, testFileName)), user, usePubKey) + if assert.NoError(t, err) { + assert.Equal(t, "OK\n", string(out)) + } + err = writeSFTPFile(path.Join(vdirPath2, testDir, testFileName), testFileSize, client) + assert.NoError(t, err) + out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir)), user, usePubKey) + if assert.NoError(t, err) { + assert.Equal(t, "OK\n", string(out)) + } + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Greater(t, user.UsedQuotaSize, testFileSize1) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + func TestBasicGitCommands(t *testing.T) { if len(gitPath) == 0 || len(sshPath) == 0 || runtime.GOOS == osWindows { t.Skip("git and/or ssh command not found or OS is windows, unable to execute this test") @@ -7667,6 +7797,136 @@ func TestSCPVirtualFolders(t *testing.T) { assert.NoError(t, err) } +func TestSCPNestedFolders(t *testing.T) { + if len(scpPath) == 0 { + t.Skip("scp command not found, unable to execute this test") + } + baseUser, resp, err := httpdtest.AddUser(getTestUser(false), http.StatusCreated) + assert.NoError(t, err, string(resp)) + usePubKey := true + u := getTestUser(usePubKey) + u.HomeDir += "_folders" + u.Username += "_folders" + mappedPathSFTP := filepath.Join(os.TempDir(), "sftp") + folderNameSFTP := filepath.Base(mappedPathSFTP) + vdirSFTPPath := "/vdir/sftp" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameSFTP, + FsConfig: vfs.Filesystem{ + Provider: vfs.SFTPFilesystemProvider, + SFTPConfig: vfs.SFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: baseUser.Username, + Password: kms.NewPlainSecret(defaultPassword), + }, + }, + }, + VirtualPath: vdirSFTPPath, + }) + mappedPathCrypt := filepath.Join(os.TempDir(), "crypt") + folderNameCrypt := filepath.Base(mappedPathCrypt) + vdirCryptPath := "/vdir/crypt" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameCrypt, + FsConfig: vfs.Filesystem{ + Provider: vfs.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret(defaultPassword), + }, + }, + MappedPath: mappedPathCrypt, + }, + VirtualPath: vdirCryptPath, + }) + + user, resp, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err, string(resp)) + baseDirDownPath := filepath.Join(os.TempDir(), "basedir-down") + err = os.Mkdir(baseDirDownPath, os.ModePerm) + assert.NoError(t, err) + baseDir := filepath.Join(os.TempDir(), "basedir") + err = os.Mkdir(baseDir, os.ModePerm) + assert.NoError(t, err) + err = os.MkdirAll(filepath.Join(baseDir, vdirSFTPPath), os.ModePerm) + assert.NoError(t, err) + err = os.MkdirAll(filepath.Join(baseDir, vdirCryptPath), os.ModePerm) + assert.NoError(t, err) + err = createTestFile(filepath.Join(baseDir, vdirSFTPPath, testFileName), 32768) + assert.NoError(t, err) + err = createTestFile(filepath.Join(baseDir, vdirCryptPath, testFileName), 65535) + assert.NoError(t, err) + err = createTestFile(filepath.Join(baseDir, "vdir", testFileName), 65536) + assert.NoError(t, err) + + remoteRootPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") + err = scpUpload(filepath.Join(baseDir, "vdir"), remoteRootPath, true, false) + assert.NoError(t, err) + + client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer client.Close() + info, err := client.Stat(path.Join(vdirCryptPath, testFileName)) + assert.NoError(t, err) + assert.Equal(t, int64(65535), info.Size()) + info, err = client.Stat(path.Join(vdirSFTPPath, testFileName)) + assert.NoError(t, err) + assert.Equal(t, int64(32768), info.Size()) + info, err = client.Stat(path.Join("/vdir", testFileName)) + assert.NoError(t, err) + assert.Equal(t, int64(65536), info.Size()) + } + + err = scpDownload(baseDirDownPath, remoteRootPath, true, true) + assert.NoError(t, err) + + assert.FileExists(t, filepath.Join(baseDirDownPath, user.Username, "vdir", testFileName)) + assert.FileExists(t, filepath.Join(baseDirDownPath, user.Username, vdirCryptPath, testFileName)) + assert.FileExists(t, filepath.Join(baseDirDownPath, user.Username, vdirSFTPPath, testFileName)) + + if runtime.GOOS != osWindows { + err = os.Chmod(filepath.Join(baseUser.GetHomeDir(), testFileName), 0001) + assert.NoError(t, err) + err = scpDownload(baseDirDownPath, remoteRootPath, true, true) + assert.Error(t, err) + err = os.Chmod(filepath.Join(baseUser.GetHomeDir(), testFileName), os.ModePerm) + assert.NoError(t, err) + } + + // now change the password for the base user, so SFTP folder will not work + baseUser.Password = defaultPassword + "_mod" + _, _, err = httpdtest.UpdateUser(baseUser, http.StatusOK, "") + assert.NoError(t, err) + + err = scpUpload(filepath.Join(baseDir, "vdir"), remoteRootPath, true, false) + assert.Error(t, err) + + err = scpDownload(baseDirDownPath, remoteRootPath, true, true) + assert.Error(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameSFTP}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(baseUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(baseUser.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathCrypt) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathSFTP) + assert.NoError(t, err) + err = os.RemoveAll(baseDir) + assert.NoError(t, err) + err = os.RemoveAll(baseDirDownPath) + assert.NoError(t, err) +} + func TestSCPVirtualFoldersQuota(t *testing.T) { if len(scpPath) == 0 { t.Skip("scp command not found, unable to execute this test") @@ -8194,7 +8454,7 @@ func getTestUser(usePubKey bool) dataprovider.User { func getTestSFTPUser(usePubKey bool) dataprovider.User { u := getTestUser(usePubKey) u.Username = defaultSFTPUsername - u.FsConfig.Provider = dataprovider.SFTPFilesystemProvider + u.FsConfig.Provider = vfs.SFTPFilesystemProvider u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr u.FsConfig.SFTPConfig.Username = defaultUsername u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword) @@ -8377,6 +8637,35 @@ func checkBasicSFTP(client *sftp.Client) error { return err } +func writeSFTPFile(name string, size int64, client *sftp.Client) error { + content := make([]byte, size) + _, err := rand.Read(content) + if err != nil { + return err + } + f, err := client.Create(name) + if err != nil { + return err + } + _, err = io.Copy(f, bytes.NewBuffer(content)) + if err != nil { + f.Close() + return err + } + err = f.Close() + if err != nil { + return err + } + info, err := client.Stat(name) + if err != nil { + return err + } + if info.Size() != size { + return fmt.Errorf("file size mismatch, wanted %v, actual %v", size, info.Size()) + } + return nil +} + func sftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) error { srcFile, err := os.Open(localSourcePath) if err != nil { diff --git a/sftpd/ssh_cmd.go b/sftpd/ssh_cmd.go index 6cb5cd0a..ec94e41f 100644 --- a/sftpd/ssh_cmd.go +++ b/sftpd/ssh_cmd.go @@ -47,6 +47,7 @@ type systemCommand struct { cmd *exec.Cmd fsPath string quotaCheckPath string + fs vfs.Fs } func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommands []string) bool { @@ -82,7 +83,8 @@ func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommand connection.Log(logger.LevelInfo, "ssh command not enabled/supported: %#v", name) } } - connection.Fs.Close() + err := connection.CloseFS() + connection.Log(logger.LevelDebug, "unable to unmarsh ssh command, close fs, err: %v", err) return false } @@ -112,45 +114,42 @@ func (c *sshCommand) handle() (err error) { c.connection.channel.Write([]byte("/\n")) //nolint:errcheck c.sendExitStatus(nil) } else if c.command == "sftpgo-copy" { - return c.handeSFTPGoCopy() + return c.handleSFTPGoCopy() } else if c.command == "sftpgo-remove" { - return c.handeSFTPGoRemove() + return c.handleSFTPGoRemove() } return } -func (c *sshCommand) handeSFTPGoCopy() error { - if !vfs.IsLocalOsFs(c.connection.Fs) { +func (c *sshCommand) handleSFTPGoCopy() error { + fsSrc, fsDst, sshSourcePath, sshDestPath, fsSourcePath, fsDestPath, err := c.getFsAndCopyPaths() + if err != nil { + return c.sendErrorResponse(err) + } + if !c.isLocalCopy(sshSourcePath, sshDestPath) { return c.sendErrorResponse(errUnsupportedConfig) } - sshSourcePath, sshDestPath, err := c.getCopyPaths() - if err != nil { - return c.sendErrorResponse(err) - } - fsSourcePath, fsDestPath, err := c.resolveCopyPaths(sshSourcePath, sshDestPath) - if err != nil { - return c.sendErrorResponse(err) - } - if err := c.checkCopyDestination(fsDestPath); err != nil { - return c.sendErrorResponse(err) + + if err := c.checkCopyDestination(fsDst, fsDestPath); err != nil { + return c.sendErrorResponse(c.connection.GetFsError(fsDst, err)) } c.connection.Log(logger.LevelDebug, "requested copy %#v -> %#v sftp paths %#v -> %#v", fsSourcePath, fsDestPath, sshSourcePath, sshDestPath) - fi, err := c.connection.Fs.Lstat(fsSourcePath) + fi, err := fsSrc.Lstat(fsSourcePath) if err != nil { - return c.sendErrorResponse(err) + return c.sendErrorResponse(c.connection.GetFsError(fsSrc, err)) } - if err := c.checkCopyPermissions(fsSourcePath, fsDestPath, sshSourcePath, sshDestPath, fi); err != nil { + if err := c.checkCopyPermissions(fsSrc, fsDst, fsSourcePath, fsDestPath, sshSourcePath, sshDestPath, fi); err != nil { return c.sendErrorResponse(err) } filesNum := 0 filesSize := int64(0) if fi.IsDir() { - filesNum, filesSize, err = c.connection.Fs.GetDirSize(fsSourcePath) + filesNum, filesSize, err = fsSrc.GetDirSize(fsSourcePath) if err != nil { - return c.sendErrorResponse(err) + return c.sendErrorResponse(c.connection.GetFsError(fsSrc, err)) } if c.connection.User.HasVirtualFoldersInside(sshSourcePath) { err := errors.New("unsupported copy source: the source directory contains virtual folders") @@ -177,7 +176,7 @@ func (c *sshCommand) handeSFTPGoCopy() error { c.connection.Log(logger.LevelDebug, "start copy %#v -> %#v", fsSourcePath, fsDestPath) err = fscopy.Copy(fsSourcePath, fsDestPath) if err != nil { - return c.sendErrorResponse(err) + return c.sendErrorResponse(c.connection.GetFsError(fsSrc, err)) } c.updateQuota(sshDestPath, filesNum, filesSize) c.connection.channel.Write([]byte("OK\n")) //nolint:errcheck @@ -185,10 +184,7 @@ func (c *sshCommand) handeSFTPGoCopy() error { return nil } -func (c *sshCommand) handeSFTPGoRemove() error { - if !vfs.IsLocalOsFs(c.connection.Fs) { - return c.sendErrorResponse(errUnsupportedConfig) - } +func (c *sshCommand) handleSFTPGoRemove() error { sshDestPath, err := c.getRemovePath() if err != nil { return c.sendErrorResponse(err) @@ -196,20 +192,23 @@ func (c *sshCommand) handeSFTPGoRemove() error { if !c.connection.User.HasPerm(dataprovider.PermDelete, path.Dir(sshDestPath)) { return c.sendErrorResponse(common.ErrPermissionDenied) } - fsDestPath, err := c.connection.Fs.ResolvePath(sshDestPath) + fs, fsDestPath, err := c.connection.GetFsAndResolvedPath(sshDestPath) if err != nil { return c.sendErrorResponse(err) } - fi, err := c.connection.Fs.Lstat(fsDestPath) + if !vfs.IsLocalOrCryptoFs(fs) { + return c.sendErrorResponse(errUnsupportedConfig) + } + fi, err := fs.Lstat(fsDestPath) if err != nil { - return c.sendErrorResponse(err) + return c.sendErrorResponse(c.connection.GetFsError(fs, err)) } filesNum := 0 filesSize := int64(0) if fi.IsDir() { - filesNum, filesSize, err = c.connection.Fs.GetDirSize(fsDestPath) + filesNum, filesSize, err = fs.GetDirSize(fsDestPath) if err != nil { - return c.sendErrorResponse(err) + return c.sendErrorResponse(c.connection.GetFsError(fs, err)) } if sshDestPath == "/" { err := errors.New("removing root dir is not allowed") @@ -223,10 +222,6 @@ func (c *sshCommand) handeSFTPGoRemove() error { err := errors.New("unsupported remove source: this directory is a virtual folder") return c.sendErrorResponse(err) } - if c.connection.User.IsMappedPath(fsDestPath) { - err := errors.New("removing a directory mapped as virtual folder is not allowed") - return c.sendErrorResponse(err) - } } else if fi.Mode().IsRegular() { filesNum = 1 filesSize = fi.Size() @@ -284,18 +279,18 @@ func (c *sshCommand) handleHashCommands() error { sshPath := c.getDestPath() if !c.connection.User.IsFileAllowed(sshPath) { c.connection.Log(logger.LevelInfo, "hash not allowed for file %#v", sshPath) - return c.sendErrorResponse(common.ErrPermissionDenied) + return c.sendErrorResponse(c.connection.GetPermissionDeniedError()) } - fsPath, err := c.connection.Fs.ResolvePath(sshPath) + fs, fsPath, err := c.connection.GetFsAndResolvedPath(sshPath) if err != nil { return c.sendErrorResponse(err) } if !c.connection.User.HasPerm(dataprovider.PermListItems, sshPath) { - return c.sendErrorResponse(common.ErrPermissionDenied) + return c.sendErrorResponse(c.connection.GetPermissionDeniedError()) } - hash, err := c.computeHashForFile(h, fsPath) + hash, err := c.computeHashForFile(fs, h, fsPath) if err != nil { - return c.sendErrorResponse(err) + return c.sendErrorResponse(c.connection.GetFsError(fs, err)) } response = fmt.Sprintf("%v %v\n", hash, sshPath) } @@ -305,10 +300,10 @@ func (c *sshCommand) handleHashCommands() error { } func (c *sshCommand) executeSystemCommand(command systemCommand) error { - if !vfs.IsLocalOsFs(c.connection.Fs) { + sshDestPath := c.getDestPath() + if !c.isLocalPath(sshDestPath) { return c.sendErrorResponse(errUnsupportedConfig) } - sshDestPath := c.getDestPath() quotaResult := c.connection.HasSpace(true, false, command.quotaCheckPath) if !quotaResult.HasSpace { return c.sendErrorResponse(common.ErrQuotaExceeded) @@ -316,10 +311,10 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { perms := []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermCreateDirs, dataprovider.PermListItems, dataprovider.PermOverwrite, dataprovider.PermDelete} if !c.connection.User.HasPerms(perms, sshDestPath) { - return c.sendErrorResponse(common.ErrPermissionDenied) + return c.sendErrorResponse(c.connection.GetPermissionDeniedError()) } - initialFiles, initialSize, err := c.getSizeForPath(command.fsPath) + initialFiles, initialSize, err := c.getSizeForPath(command.fs, command.fsPath) if err != nil { return c.sendErrorResponse(err) } @@ -356,7 +351,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { go func() { defer stdin.Close() baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, sshDestPath, - common.TransferUpload, 0, 0, remainingQuotaSize, false, c.connection.Fs) + common.TransferUpload, 0, 0, remainingQuotaSize, false, command.fs) transfer := newTransfer(baseTransfer, nil, nil, nil) w, e := transfer.copyFromReaderToWriter(stdin, c.connection.channel) @@ -369,7 +364,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { go func() { baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, sshDestPath, - common.TransferDownload, 0, 0, 0, false, c.connection.Fs) + common.TransferDownload, 0, 0, 0, false, command.fs) transfer := newTransfer(baseTransfer, nil, nil, nil) w, e := transfer.copyFromReaderToWriter(c.connection.channel, stdout) @@ -383,7 +378,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { go func() { baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, sshDestPath, - common.TransferDownload, 0, 0, 0, false, c.connection.Fs) + common.TransferDownload, 0, 0, 0, false, command.fs) transfer := newTransfer(baseTransfer, nil, nil, nil) w, e := transfer.copyFromReaderToWriter(c.connection.channel.(ssh.Channel).Stderr(), stderr) @@ -399,14 +394,14 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { err = command.cmd.Wait() c.sendExitStatus(err) - numFiles, dirSize, errSize := c.getSizeForPath(command.fsPath) + numFiles, dirSize, errSize := c.getSizeForPath(command.fs, command.fsPath) if errSize == nil { c.updateQuota(sshDestPath, numFiles-initialFiles, dirSize-initialSize) } c.connection.Log(logger.LevelDebug, "command %#v finished for path %#v, initial files %v initial size %v "+ "current files %v current size %v size err: %v", c.connection.command, command.fsPath, initialFiles, initialSize, numFiles, dirSize, errSize) - return err + return c.connection.GetFsError(command.fs, err) } func (c *sshCommand) isSystemCommandAllowed() error { @@ -450,21 +445,26 @@ func (c *sshCommand) isSystemCommandAllowed() error { func (c *sshCommand) getSystemCommand() (systemCommand, error) { command := systemCommand{ cmd: nil, + fs: nil, fsPath: "", quotaCheckPath: "", } args := make([]string, len(c.args)) copy(args, c.args) var fsPath, quotaPath string + sshPath := c.getDestPath() + fs, err := c.connection.User.GetFilesystemForPath(sshPath, c.connection.ID) + if err != nil { + return command, err + } if len(c.args) > 0 { var err error - sshPath := c.getDestPath() - fsPath, err = c.connection.Fs.ResolvePath(sshPath) + fsPath, err = fs.ResolvePath(sshPath) if err != nil { - return command, err + return command, c.connection.GetFsError(fs, err) } quotaPath = sshPath - fi, err := c.connection.Fs.Stat(fsPath) + fi, err := fs.Stat(fsPath) if err == nil && fi.IsDir() { // if the target is an existing dir the command will write inside this dir // so we need to check the quota for this directory and not its parent dir @@ -506,6 +506,7 @@ func (c *sshCommand) getSystemCommand() (systemCommand, error) { command.cmd = cmd command.fsPath = fsPath command.quotaCheckPath = quotaPath + command.fs = fs return command, nil } @@ -535,7 +536,7 @@ func cleanCommandPath(name string) string { return result } -func (c *sshCommand) getCopyPaths() (string, string, error) { +func (c *sshCommand) getFsAndCopyPaths() (vfs.Fs, vfs.Fs, string, string, string, string, error) { sshSourcePath := strings.TrimSuffix(c.getSourcePath(), "/") sshDestPath := c.getDestPath() if strings.HasSuffix(sshDestPath, "/") { @@ -543,9 +544,17 @@ func (c *sshCommand) getCopyPaths() (string, string, error) { } if sshSourcePath == "" || sshDestPath == "" || len(c.args) != 2 { err := errors.New("usage sftpgo-copy ") - return "", "", err + return nil, nil, "", "", "", "", err } - return sshSourcePath, sshDestPath, nil + fsSrc, fsSourcePath, err := c.connection.GetFsAndResolvedPath(sshSourcePath) + if err != nil { + return nil, nil, "", "", "", "", err + } + fsDst, fsDestPath, err := c.connection.GetFsAndResolvedPath(sshDestPath) + if err != nil { + return nil, nil, "", "", "", "", err + } + return fsSrc, fsDst, sshSourcePath, sshDestPath, fsSourcePath, fsDestPath, nil } func (c *sshCommand) hasCopyPermissions(sshSourcePath, sshDestPath string, srcInfo os.FileInfo) bool { @@ -561,7 +570,7 @@ func (c *sshCommand) hasCopyPermissions(sshSourcePath, sshDestPath string, srcIn } // fsSourcePath must be a directory -func (c *sshCommand) checkRecursiveCopyPermissions(fsSourcePath, fsDestPath, sshDestPath string) error { +func (c *sshCommand) checkRecursiveCopyPermissions(fsSrc vfs.Fs, fsDst vfs.Fs, fsSourcePath, fsDestPath, sshDestPath string) error { if !c.connection.User.HasPerm(dataprovider.PermCreateDirs, path.Dir(sshDestPath)) { return common.ErrPermissionDenied } @@ -571,13 +580,13 @@ func (c *sshCommand) checkRecursiveCopyPermissions(fsSourcePath, fsDestPath, ssh dataprovider.PermUpload, } - err := c.connection.Fs.Walk(fsSourcePath, func(walkedPath string, info os.FileInfo, err error) error { + err := fsSrc.Walk(fsSourcePath, func(walkedPath string, info os.FileInfo, err error) error { if err != nil { - return err + return c.connection.GetFsError(fsSrc, err) } fsDstSubPath := strings.Replace(walkedPath, fsSourcePath, fsDestPath, 1) - sshSrcSubPath := c.connection.Fs.GetRelativePath(walkedPath) - sshDstSubPath := c.connection.Fs.GetRelativePath(fsDstSubPath) + sshSrcSubPath := fsSrc.GetRelativePath(walkedPath) + sshDstSubPath := fsDst.GetRelativePath(fsDstSubPath) // If the current dir has no subdirs with defined permissions inside it // and it has all the possible permissions we can stop scanning if !c.connection.User.HasPermissionsInside(path.Dir(sshSrcSubPath)) && @@ -598,12 +607,12 @@ func (c *sshCommand) checkRecursiveCopyPermissions(fsSourcePath, fsDestPath, ssh return err } -func (c *sshCommand) checkCopyPermissions(fsSourcePath, fsDestPath, sshSourcePath, sshDestPath string, info os.FileInfo) error { +func (c *sshCommand) checkCopyPermissions(fsSrc vfs.Fs, fsDst vfs.Fs, fsSourcePath, fsDestPath, sshSourcePath, sshDestPath string, info os.FileInfo) error { if info.IsDir() { - return c.checkRecursiveCopyPermissions(fsSourcePath, fsDestPath, sshDestPath) + return c.checkRecursiveCopyPermissions(fsSrc, fsDst, fsSourcePath, fsDestPath, sshDestPath) } if !c.hasCopyPermissions(sshSourcePath, sshDestPath, info) { - return common.ErrPermissionDenied + return c.connection.GetPermissionDeniedError() } return nil } @@ -620,24 +629,28 @@ func (c *sshCommand) getRemovePath() (string, error) { return sshDestPath, nil } -func (c *sshCommand) resolveCopyPaths(sshSourcePath, sshDestPath string) (string, string, error) { - fsSourcePath, err := c.connection.Fs.ResolvePath(sshSourcePath) +func (c *sshCommand) isLocalPath(virtualPath string) bool { + folder, err := c.connection.User.GetVirtualFolderForPath(virtualPath) if err != nil { - return "", "", err + return c.connection.User.FsConfig.Provider == vfs.LocalFilesystemProvider } - fsDestPath, err := c.connection.Fs.ResolvePath(sshDestPath) - if err != nil { - return "", "", err - } - return fsSourcePath, fsDestPath, nil + return folder.FsConfig.Provider == vfs.LocalFilesystemProvider } -func (c *sshCommand) checkCopyDestination(fsDestPath string) error { - _, err := c.connection.Fs.Lstat(fsDestPath) +func (c *sshCommand) isLocalCopy(virtualSourcePath, virtualTargetPath string) bool { + if !c.isLocalPath(virtualSourcePath) { + return false + } + + return c.isLocalPath(virtualTargetPath) +} + +func (c *sshCommand) checkCopyDestination(fs vfs.Fs, fsDestPath string) error { + _, err := fs.Lstat(fsDestPath) if err == nil { err := errors.New("invalid copy destination: cannot overwrite an existing file or directory") return err - } else if !c.connection.Fs.IsNotExist(err) { + } else if !fs.IsNotExist(err) { return err } return nil @@ -667,18 +680,18 @@ func (c *sshCommand) checkCopyQuota(numFiles int, filesSize int64, requestPath s return nil } -func (c *sshCommand) getSizeForPath(name string) (int, int64, error) { +func (c *sshCommand) getSizeForPath(fs vfs.Fs, name string) (int, int64, error) { if dataprovider.GetQuotaTracking() > 0 { - fi, err := c.connection.Fs.Lstat(name) + fi, err := fs.Lstat(name) if err != nil { - if c.connection.Fs.IsNotExist(err) { + if fs.IsNotExist(err) { return 0, 0, nil } c.connection.Log(logger.LevelDebug, "unable to stat %#v error: %v", name, err) return 0, 0, err } if fi.IsDir() { - files, size, err := c.connection.Fs.GetDirSize(name) + files, size, err := fs.GetDirSize(name) if err != nil { c.connection.Log(logger.LevelDebug, "unable to get size for dir %#v error: %v", name, err) } @@ -691,7 +704,7 @@ func (c *sshCommand) getSizeForPath(name string) (int, int64, error) { } func (c *sshCommand) sendErrorResponse(err error) error { - errorString := fmt.Sprintf("%v: %v %v\n", c.command, c.getDestPath(), c.connection.GetFsError(err)) + errorString := fmt.Sprintf("%v: %v %v\n", c.command, c.getDestPath(), err) c.connection.channel.Write([]byte(errorString)) //nolint:errcheck c.sendExitStatus(err) return err @@ -722,15 +735,15 @@ func (c *sshCommand) sendExitStatus(err error) { // for scp we notify single uploads/downloads if c.command != scpCmdName { metrics.SSHCommandCompleted(err) - if len(cmdPath) > 0 { - p, e := c.connection.Fs.ResolvePath(cmdPath) - if e == nil { + if cmdPath != "" { + _, p, errFs := c.connection.GetFsAndResolvedPath(cmdPath) + if errFs == nil { cmdPath = p } } - if len(targetPath) > 0 { - p, e := c.connection.Fs.ResolvePath(targetPath) - if e == nil { + if targetPath != "" { + _, p, errFs := c.connection.GetFsAndResolvedPath(targetPath) + if errFs == nil { targetPath = p } } @@ -738,9 +751,9 @@ func (c *sshCommand) sendExitStatus(err error) { } } -func (c *sshCommand) computeHashForFile(hasher hash.Hash, path string) (string, error) { +func (c *sshCommand) computeHashForFile(fs vfs.Fs, hasher hash.Hash, path string) (string, error) { hash := "" - f, r, _, err := c.connection.Fs.Open(path, 0) + f, r, _, err := fs.Open(path, 0) if err != nil { return hash, err } diff --git a/sftpd/subsystem.go b/sftpd/subsystem.go index c14e007f..ae226fd4 100644 --- a/sftpd/subsystem.go +++ b/sftpd/subsystem.go @@ -8,6 +8,7 @@ import ( "github.com/drakkan/sftpgo/common" "github.com/drakkan/sftpgo/dataprovider" + "github.com/drakkan/sftpgo/logger" ) type subsystemChannel struct { @@ -36,15 +37,16 @@ func newSubsystemChannel(reader io.Reader, writer io.Writer) *subsystemChannel { // ServeSubSystemConnection handles a connection as SSH subsystem func ServeSubSystemConnection(user *dataprovider.User, connectionID string, reader io.Reader, writer io.Writer) error { - fs, err := user.GetFilesystem(connectionID) + err := user.CheckFsRoot(connectionID) if err != nil { + errClose := user.CloseFs() + logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose) return err } - fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID()) dataprovider.UpdateLastLogin(user) //nolint:errcheck connection := &Connection{ - BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolSFTP, *user, fs), + BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolSFTP, *user), ClientVersion: "", RemoteAddr: &net.IPAddr{}, channel: newSubsystemChannel(reader, writer), diff --git a/sftpd/transfer.go b/sftpd/transfer.go index 49c66941..fcc3159c 100644 --- a/sftpd/transfer.go +++ b/sftpd/transfer.go @@ -110,7 +110,7 @@ func (t *transfer) ReadAt(p []byte, off int64) (n int, err error) { func (t *transfer) WriteAt(p []byte, off int64) (n int, err error) { t.Connection.UpdateLastActivity() if off < t.MinWriteOffset { - err := fmt.Errorf("Invalid write offset: %v minimum valid value: %v", off, t.MinWriteOffset) + err := fmt.Errorf("invalid write offset: %v minimum valid value: %v", off, t.MinWriteOffset) t.TransferError(err) return 0, err } @@ -143,7 +143,7 @@ func (t *transfer) Close() error { if errBaseClose != nil { err = errBaseClose } - return t.Connection.GetFsError(err) + return t.Connection.GetFsError(t.Fs, err) } func (t *transfer) closeIO() error { diff --git a/templates/folder.html b/templates/folder.html index bbfff954..3892ee01 100644 --- a/templates/folder.html +++ b/templates/folder.html @@ -63,13 +63,28 @@
+ value="{{.Folder.MappedPath}}" maxlength="512" aria-describedby="mappedPathHelpBlock"> + + Required for local providers. For Cloud providers, if set, it will store temporary files +
+ {{template "fshtml" .Folder.FsConfig}} + +{{end}} + +{{define "extra_js"}} + {{end}} \ No newline at end of file diff --git a/templates/fsconfig.html b/templates/fsconfig.html new file mode 100644 index 00000000..bc04851c --- /dev/null +++ b/templates/fsconfig.html @@ -0,0 +1,359 @@ +{{define "fshtml"}} +
+ +
+ +
+
+ +
+ +
+ +
+
+ +
+ +
+
+ +
+ +
+ +
+
+ +
+ +
+
+ +
+ +
+ +
+
+ +
+ +
+
+ +
+ +
+ + + The buffer size for multipart uploads. Zero means the default (5 MB). Minimum is 5 + +
+
+ +
+ + + How many parts are uploaded in parallel. Zero means the default (2) + +
+
+ +
+ +
+ + + Similar to a chroot for local filesystem. Cannot start with "/". Example: "somedir/subdir/". + +
+
+ +
+ +
+ +
+
+ +
+ +
+ + + Add or update credentials from a JSON file + +
+
+ +
+ +
+
+ +
+
+ + +
+
+ +
+ +
+ + + Similar to a chroot for local filesystem. Cannot start with "/". Example: "somedir/subdir/". + +
+
+ +
+ +
+ +
+
+ +
+ +
+
+ +
+ +
+ +
+
+ +
+ +
+ +
+
+
+ +
+ +
+
+ +
+ +
+ +
+
+ +
+ +
+ + + The buffer size for multipart uploads. Zero means the default (4 MB) + +
+
+ +
+ + + How many parts are uploaded in parallel. Zero means the default (2) + +
+
+ +
+ +
+ + + Similar to a chroot for local filesystem. Cannot start with "/". Example: "somedir/subdir/". + +
+
+ +
+
+ + +
+
+ +
+ +
+ +
+
+ +
+ +
+ + + Endpoint as host:port, port is always required + +
+
+ +
+ +
+
+ +
+ +
+ +
+
+ +
+ +
+ +
+
+ +
+ +
+ + + SHA256 fingerprints to validate when connecting to the external SFTP server, one per line. If + empty any host key will be accepted: this is a security risk! + +
+
+ +
+ +
+ + + Similar to a chroot for local filesystem. Example: "/somedir/subdir". + +
+
+ +
+
+ + +
+
+{{end}} + +{{define "fsjs"}} + function onFilesystemChanged(val){ + if (val == '1'){ + $('.form-group.row.gcs').hide(); + $('.form-group.gcs').hide(); + $('.form-group.row.azblob').hide(); + $('.form-group.azblob').hide(); + $('.form-group.crypt').hide(); + $('.form-group.sftp').hide(); + $('.form-group.row.s3').show(); + } else if (val == '2'){ + $('.form-group.row.gcs').show(); + $('.form-group.gcs').show(); + $('.form-group.row.azblob').hide(); + $('.form-group.azblob').hide(); + $('.form-group.crypt').hide(); + $('.form-group.row.s3').hide(); + $('.form-group.sftp').hide(); + } else if (val == '3'){ + $('.form-group.row.azblob').show(); + $('.form-group.azblob').show(); + $('.form-group.row.gcs').hide(); + $('.form-group.gcs').hide(); + $('.form-group.crypt').hide(); + $('.form-group.row.s3').hide(); + $('.form-group.sftp').hide(); + } else if (val == '4'){ + $('.form-group.row.gcs').hide(); + $('.form-group.gcs').hide(); + $('.form-group.row.s3').hide(); + $('.form-group.row.azblob').hide(); + $('.form-group.azblob').hide(); + $('.form-group.crypt').show(); + $('.form-group.sftp').hide(); + } else if (val == '5'){ + $('.form-group.row.gcs').hide(); + $('.form-group.gcs').hide(); + $('.form-group.row.s3').hide(); + $('.form-group.row.azblob').hide(); + $('.form-group.azblob').hide(); + $('.form-group.crypt').hide(); + $('.form-group.sftp').show(); + } else { + $('.form-group.row.gcs').hide(); + $('.form-group.gcs').hide(); + $('.form-group.row.s3').hide(); + $('.form-group.row.azblob').hide(); + $('.form-group.azblob').hide(); + $('.form-group.crypt').hide(); + $('.form-group.sftp').hide(); + } + } +{{end}} \ No newline at end of file diff --git a/templates/user.html b/templates/user.html index 5d8ec15b..5af9449c 100644 --- a/templates/user.html +++ b/templates/user.html @@ -360,309 +360,7 @@ -
- -
- -
-
- -
- -
- -
-
- -
- -
-
- -
- -
- -
-
- -
- -
-
- -
- -
- -
-
- -
- -
-
- -
- -
- - - The buffer size for multipart uploads. Zero means the default (5 MB). Minimum is 5 - -
-
- -
- - - How many parts are uploaded in parallel. Zero means the default (2) - -
-
- -
- -
- - - Similar to a chroot for local filesystem. Cannot start with "/". Example: "somedir/subdir/". - -
-
- -
- -
- -
-
- -
- -
- - - Add or update credentials from a JSON file - -
-
- -
- -
-
- -
-
- - -
-
- -
- -
- - - Similar to a chroot for local filesystem. Cannot start with "/". Example: "somedir/subdir/". - -
-
- -
- -
- -
-
- -
- -
-
- -
- -
- -
-
- -
- -
- -
-
-
- -
- -
-
- -
- -
- -
-
- -
- -
- - - The buffer size for multipart uploads. Zero means the default (4 MB) - -
-
- -
- - - How many parts are uploaded in parallel. Zero means the default (2) - -
-
- -
- -
- - - Similar to a chroot for local filesystem. Cannot start with "/". Example: "somedir/subdir/". - -
-
- -
-
- - -
-
- -
- -
- -
-
- -
- -
- - - Endpoint as host:port, port is always required - -
-
- -
- -
-
- -
-
- - -
-
- -
- -
- -
-
- -
- -
- -
-
- -
- -
- - - SHA256 fingerprints to validate when connecting to the external SFTP server, one per line. If - empty any host key will be accepted: this is a security risk! - -
-
- -
- -
- - - Similar to a chroot for local filesystem. Example: "/somedir/subdir". - -
-
+ {{template "fshtml" .User.FsConfig}}
@@ -737,56 +435,6 @@ }); - function onFilesystemChanged(val){ - if (val == '1'){ - $('.form-group.row.gcs').hide(); - $('.form-group.gcs').hide(); - $('.form-group.row.azblob').hide(); - $('.form-group.azblob').hide(); - $('.form-group.crypt').hide(); - $('.form-group.sftp').hide(); - $('.form-group.row.s3').show(); - } else if (val == '2'){ - $('.form-group.row.gcs').show(); - $('.form-group.gcs').show(); - $('.form-group.row.azblob').hide(); - $('.form-group.azblob').hide(); - $('.form-group.crypt').hide(); - $('.form-group.row.s3').hide(); - $('.form-group.sftp').hide(); - } else if (val == '3'){ - $('.form-group.row.azblob').show(); - $('.form-group.azblob').show(); - $('.form-group.row.gcs').hide(); - $('.form-group.gcs').hide(); - $('.form-group.crypt').hide(); - $('.form-group.row.s3').hide(); - $('.form-group.sftp').hide(); - } else if (val == '4'){ - $('.form-group.row.gcs').hide(); - $('.form-group.gcs').hide(); - $('.form-group.row.s3').hide(); - $('.form-group.row.azblob').hide(); - $('.form-group.azblob').hide(); - $('.form-group.crypt').show(); - $('.form-group.sftp').hide(); - } else if (val == '5'){ - $('.form-group.row.gcs').hide(); - $('.form-group.gcs').hide(); - $('.form-group.row.s3').hide(); - $('.form-group.row.azblob').hide(); - $('.form-group.azblob').hide(); - $('.form-group.crypt').hide(); - $('.form-group.sftp').show(); - } else { - $('.form-group.row.gcs').hide(); - $('.form-group.gcs').hide(); - $('.form-group.row.s3').hide(); - $('.form-group.row.azblob').hide(); - $('.form-group.azblob').hide(); - $('.form-group.crypt').hide(); - $('.form-group.sftp').hide(); - } - } + {{template "fsjs"}} {{end}} diff --git a/utils/utils.go b/utils/utils.go index 5a8d8c24..e8eced67 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -312,18 +312,24 @@ func GenerateEd25519Keys(file string) error { return os.WriteFile(file+".pub", ssh.MarshalAuthorizedKey(pub), 0600) } -// GetDirsForSFTPPath returns all the directory for the given path in reverse order +// GetDirsForVirtualPath returns all the directory for the given path in reverse order // for example if the path is: /1/2/3/4 it returns: // [ "/1/2/3/4", "/1/2/3", "/1/2", "/1", "/" ] -func GetDirsForSFTPPath(p string) []string { - sftpPath := CleanPath(p) - dirsForPath := []string{sftpPath} +func GetDirsForVirtualPath(virtualPath string) []string { + if virtualPath == "." { + virtualPath = "/" + } else { + if !path.IsAbs(virtualPath) { + virtualPath = CleanPath(virtualPath) + } + } + dirsForPath := []string{virtualPath} for { - if sftpPath == "/" { + if virtualPath == "/" { break } - sftpPath = path.Dir(sftpPath) - dirsForPath = append(dirsForPath, sftpPath) + virtualPath = path.Dir(virtualPath) + dirsForPath = append(dirsForPath, virtualPath) } return dirsForPath } diff --git a/vfs/azblobfs.go b/vfs/azblobfs.go index af18547f..a1c455c6 100644 --- a/vfs/azblobfs.go +++ b/vfs/azblobfs.go @@ -36,8 +36,10 @@ var maxTryTimeout = time.Hour * 24 * 365 // AzureBlobFs is a Fs implementation for Azure Blob storage. type AzureBlobFs struct { - connectionID string - localTempDir string + connectionID string + localTempDir string + // if not empty this fs is mouted as virtual folder in the specified path + mountPath string config *AzBlobFsConfig svc *azblob.ServiceURL containerURL azblob.ContainerURL @@ -50,10 +52,14 @@ func init() { } // NewAzBlobFs returns an AzBlobFs object that allows to interact with Azure Blob storage -func NewAzBlobFs(connectionID, localTempDir string, config AzBlobFsConfig) (Fs, error) { +func NewAzBlobFs(connectionID, localTempDir, mountPath string, config AzBlobFsConfig) (Fs, error) { + if localTempDir == "" { + localTempDir = filepath.Clean(os.TempDir()) + } fs := &AzureBlobFs{ connectionID: connectionID, localTempDir: localTempDir, + mountPath: mountPath, config: &config, ctxTimeout: 30 * time.Second, ctxLongTimeout: 300 * time.Second, @@ -89,7 +95,7 @@ func NewAzBlobFs(connectionID, localTempDir string, config AzBlobFsConfig) (Fs, parts := azblob.NewBlobURLParts(*u) if parts.ContainerName != "" { if fs.config.Container != "" && fs.config.Container != parts.ContainerName { - return fs, fmt.Errorf("Container name in SAS URL %#v and container provided %#v do not match", + return fs, fmt.Errorf("container name in SAS URL %#v and container provided %#v do not match", parts.ContainerName, fs.config.Container) } fs.svc = nil @@ -282,7 +288,7 @@ func (fs *AzureBlobFs) Rename(source, target string) error { return err } if hasContents { - return fmt.Errorf("Cannot rename non empty directory: %#v", source) + return fmt.Errorf("cannot rename non empty directory: %#v", source) } } dstBlobURL := fs.containerURL.NewBlobURL(target) @@ -318,7 +324,7 @@ func (fs *AzureBlobFs) Rename(source, target string) error { } } if copyStatus != azblob.CopyStatusSuccess { - err := fmt.Errorf("Copy failed with status: %s", copyStatus) + err := fmt.Errorf("copy failed with status: %s", copyStatus) metrics.AZCopyObjectCompleted(err) return err } @@ -334,7 +340,7 @@ func (fs *AzureBlobFs) Remove(name string, isDir bool) error { return err } if hasContents { - return fmt.Errorf("Cannot remove non empty directory: %#v", name) + return fmt.Errorf("cannot remove non empty directory: %#v", name) } } blobBlockURL := fs.containerURL.NewBlockBlobURL(name) @@ -359,6 +365,11 @@ func (fs *AzureBlobFs) Mkdir(name string) error { return w.Close() } +// MkdirAll does nothing, we don't have folder +func (*AzureBlobFs) MkdirAll(name string, uid int, gid int) error { + return nil +} + // Symlink creates source as a symbolic link to target. func (*AzureBlobFs) Symlink(source, target string) error { return ErrVfsUnsupported @@ -528,7 +539,7 @@ func (*AzureBlobFs) IsNotSupported(err error) bool { // CheckRootPath creates the specified local root directory if it does not exists func (fs *AzureBlobFs) CheckRootPath(username string, uid int, gid int) bool { // we need a local directory for temporary files - osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, nil) + osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "") return osFs.CheckRootPath(username, uid, gid) } @@ -607,6 +618,9 @@ func (fs *AzureBlobFs) GetRelativePath(name string) string { } rel = path.Clean("/" + strings.TrimPrefix(rel, "/"+fs.config.KeyPrefix)) } + if fs.mountPath != "" { + rel = path.Join(fs.mountPath, rel) + } return rel } @@ -675,6 +689,9 @@ func (*AzureBlobFs) HasVirtualFolders() bool { // ResolvePath returns the matching filesystem path for the specified sftp path func (fs *AzureBlobFs) ResolvePath(virtualPath string) (string, error) { + if fs.mountPath != "" { + virtualPath = strings.TrimPrefix(virtualPath, fs.mountPath) + } if !path.IsAbs(virtualPath) { virtualPath = path.Clean("/" + virtualPath) } diff --git a/vfs/azblobfs_disabled.go b/vfs/azblobfs_disabled.go index e262c481..7a4869e0 100644 --- a/vfs/azblobfs_disabled.go +++ b/vfs/azblobfs_disabled.go @@ -13,6 +13,6 @@ func init() { } // NewAzBlobFs returns an error, Azure Blob storage is disabled -func NewAzBlobFs(connectionID, localTempDir string, config AzBlobFsConfig) (Fs, error) { +func NewAzBlobFs(connectionID, localTempDir, mountPath string, config AzBlobFsConfig) (Fs, error) { return nil, errors.New("Azure Blob Storage disabled at build time") } diff --git a/vfs/cryptfs.go b/vfs/cryptfs.go index 3cef672c..79e95c01 100644 --- a/vfs/cryptfs.go +++ b/vfs/cryptfs.go @@ -31,7 +31,7 @@ type CryptFs struct { } // NewCryptFs returns a CryptFs object -func NewCryptFs(connectionID, rootDir string, config CryptFsConfig) (Fs, error) { +func NewCryptFs(connectionID, rootDir, mountPath string, config CryptFsConfig) (Fs, error) { if err := config.Validate(); err != nil { return nil, err } @@ -42,10 +42,10 @@ func NewCryptFs(connectionID, rootDir string, config CryptFsConfig) (Fs, error) } fs := &CryptFs{ OsFs: &OsFs{ - name: cryptFsName, - connectionID: connectionID, - rootDir: rootDir, - virtualFolders: nil, + name: cryptFsName, + connectionID: connectionID, + rootDir: rootDir, + mountPath: mountPath, }, masterKey: []byte(config.Passphrase.GetPayload()), } diff --git a/vfs/filesystem.go b/vfs/filesystem.go new file mode 100644 index 00000000..edceb129 --- /dev/null +++ b/vfs/filesystem.go @@ -0,0 +1,104 @@ +package vfs + +import "github.com/drakkan/sftpgo/kms" + +// FilesystemProvider defines the supported storage filesystems +type FilesystemProvider int + +// supported values for FilesystemProvider +const ( + LocalFilesystemProvider FilesystemProvider = iota // Local + S3FilesystemProvider // AWS S3 compatible + GCSFilesystemProvider // Google Cloud Storage + AzureBlobFilesystemProvider // Azure Blob Storage + CryptedFilesystemProvider // Local encrypted + SFTPFilesystemProvider // SFTP +) + +// Filesystem defines cloud storage filesystem details +type Filesystem struct { + RedactedSecret string `json:"-"` + Provider FilesystemProvider `json:"provider"` + S3Config S3FsConfig `json:"s3config,omitempty"` + GCSConfig GCSFsConfig `json:"gcsconfig,omitempty"` + AzBlobConfig AzBlobFsConfig `json:"azblobconfig,omitempty"` + CryptConfig CryptFsConfig `json:"cryptconfig,omitempty"` + SFTPConfig SFTPFsConfig `json:"sftpconfig,omitempty"` +} + +// SetEmptySecretsIfNil sets the secrets to empty if nil +func (f *Filesystem) SetEmptySecretsIfNil() { + if f.S3Config.AccessSecret == nil { + f.S3Config.AccessSecret = kms.NewEmptySecret() + } + if f.GCSConfig.Credentials == nil { + f.GCSConfig.Credentials = kms.NewEmptySecret() + } + if f.AzBlobConfig.AccountKey == nil { + f.AzBlobConfig.AccountKey = kms.NewEmptySecret() + } + if f.CryptConfig.Passphrase == nil { + f.CryptConfig.Passphrase = kms.NewEmptySecret() + } + if f.SFTPConfig.Password == nil { + f.SFTPConfig.Password = kms.NewEmptySecret() + } + if f.SFTPConfig.PrivateKey == nil { + f.SFTPConfig.PrivateKey = kms.NewEmptySecret() + } +} + +// GetACopy returns a copy +func (f *Filesystem) GetACopy() Filesystem { + f.SetEmptySecretsIfNil() + fs := Filesystem{ + Provider: f.Provider, + S3Config: S3FsConfig{ + Bucket: f.S3Config.Bucket, + Region: f.S3Config.Region, + AccessKey: f.S3Config.AccessKey, + AccessSecret: f.S3Config.AccessSecret.Clone(), + Endpoint: f.S3Config.Endpoint, + StorageClass: f.S3Config.StorageClass, + KeyPrefix: f.S3Config.KeyPrefix, + UploadPartSize: f.S3Config.UploadPartSize, + UploadConcurrency: f.S3Config.UploadConcurrency, + }, + GCSConfig: GCSFsConfig{ + Bucket: f.GCSConfig.Bucket, + CredentialFile: f.GCSConfig.CredentialFile, + Credentials: f.GCSConfig.Credentials.Clone(), + AutomaticCredentials: f.GCSConfig.AutomaticCredentials, + StorageClass: f.GCSConfig.StorageClass, + KeyPrefix: f.GCSConfig.KeyPrefix, + }, + AzBlobConfig: AzBlobFsConfig{ + Container: f.AzBlobConfig.Container, + AccountName: f.AzBlobConfig.AccountName, + AccountKey: f.AzBlobConfig.AccountKey.Clone(), + Endpoint: f.AzBlobConfig.Endpoint, + SASURL: f.AzBlobConfig.SASURL, + KeyPrefix: f.AzBlobConfig.KeyPrefix, + UploadPartSize: f.AzBlobConfig.UploadPartSize, + UploadConcurrency: f.AzBlobConfig.UploadConcurrency, + UseEmulator: f.AzBlobConfig.UseEmulator, + AccessTier: f.AzBlobConfig.AccessTier, + }, + CryptConfig: CryptFsConfig{ + Passphrase: f.CryptConfig.Passphrase.Clone(), + }, + SFTPConfig: SFTPFsConfig{ + Endpoint: f.SFTPConfig.Endpoint, + Username: f.SFTPConfig.Username, + Password: f.SFTPConfig.Password.Clone(), + PrivateKey: f.SFTPConfig.PrivateKey.Clone(), + Prefix: f.SFTPConfig.Prefix, + DisableCouncurrentReads: f.SFTPConfig.DisableCouncurrentReads, + }, + } + if len(f.SFTPConfig.Fingerprints) > 0 { + fs.SFTPConfig.Fingerprints = make([]string, len(f.SFTPConfig.Fingerprints)) + copy(fs.SFTPConfig.Fingerprints, f.SFTPConfig.Fingerprints) + } + return fs +} diff --git a/vfs/folder.go b/vfs/folder.go index cf0b04f9..113211a4 100644 --- a/vfs/folder.go +++ b/vfs/folder.go @@ -2,6 +2,7 @@ package vfs import ( "fmt" + "path/filepath" "strconv" "strings" @@ -23,6 +24,18 @@ type BaseVirtualFolder struct { LastQuotaUpdate int64 `json:"last_quota_update"` // list of usernames associated with this virtual folder Users []string `json:"users,omitempty"` + // Filesystem configuration details + FsConfig Filesystem `json:"filesystem"` +} + +// GetEncrytionAdditionalData returns the additional data to use for AEAD +func (v *BaseVirtualFolder) GetEncrytionAdditionalData() string { + return fmt.Sprintf("folder_%v", v.Name) +} + +// GetGCSCredentialsFilePath returns the path for GCS credentials +func (v *BaseVirtualFolder) GetGCSCredentialsFilePath() string { + return filepath.Join(credentialsDirPath, "folders", fmt.Sprintf("%v_gcs_credentials.json", v.Name)) } // GetACopy returns a copy @@ -38,6 +51,7 @@ func (v *BaseVirtualFolder) GetACopy() BaseVirtualFolder { UsedQuotaFiles: v.UsedQuotaFiles, LastQuotaUpdate: v.LastQuotaUpdate, Users: users, + FsConfig: v.FsConfig.GetACopy(), } } @@ -60,6 +74,58 @@ func (v *BaseVirtualFolder) GetQuotaSummary() string { return result } +// IsLocalOrLocalCrypted returns true if the folder provider is local or local encrypted +func (v *BaseVirtualFolder) IsLocalOrLocalCrypted() bool { + return v.FsConfig.Provider == LocalFilesystemProvider || v.FsConfig.Provider == CryptedFilesystemProvider +} + +// HideConfidentialData hides folder confidential data +func (v *BaseVirtualFolder) HideConfidentialData() { + switch v.FsConfig.Provider { + case S3FilesystemProvider: + v.FsConfig.S3Config.AccessSecret.Hide() + case GCSFilesystemProvider: + v.FsConfig.GCSConfig.Credentials.Hide() + case AzureBlobFilesystemProvider: + v.FsConfig.AzBlobConfig.AccountKey.Hide() + case CryptedFilesystemProvider: + v.FsConfig.CryptConfig.Passphrase.Hide() + case SFTPFilesystemProvider: + v.FsConfig.SFTPConfig.Password.Hide() + v.FsConfig.SFTPConfig.PrivateKey.Hide() + } +} + +// HasRedactedSecret returns true if the folder has a redacted secret +func (v *BaseVirtualFolder) HasRedactedSecret() bool { + switch v.FsConfig.Provider { + case S3FilesystemProvider: + if v.FsConfig.S3Config.AccessSecret.IsRedacted() { + return true + } + case GCSFilesystemProvider: + if v.FsConfig.GCSConfig.Credentials.IsRedacted() { + return true + } + case AzureBlobFilesystemProvider: + if v.FsConfig.AzBlobConfig.AccountKey.IsRedacted() { + return true + } + case CryptedFilesystemProvider: + if v.FsConfig.CryptConfig.Passphrase.IsRedacted() { + return true + } + case SFTPFilesystemProvider: + if v.FsConfig.SFTPConfig.Password.IsRedacted() { + return true + } + if v.FsConfig.SFTPConfig.PrivateKey.IsRedacted() { + return true + } + } + return false +} + // VirtualFolder defines a mapping between an SFTPGo exposed virtual path and a // filesystem path outside the user home directory. // The specified paths must be absolute and the virtual path cannot be "/", @@ -75,6 +141,36 @@ type VirtualFolder struct { QuotaFiles int `json:"quota_files"` } +func (v *VirtualFolder) GetFilesystem(connectionID string) (Fs, error) { + switch v.FsConfig.Provider { + case S3FilesystemProvider: + return NewS3Fs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.S3Config) + case GCSFilesystemProvider: + config := v.FsConfig.GCSConfig + config.CredentialFile = v.GetGCSCredentialsFilePath() + return NewGCSFs(connectionID, v.MappedPath, v.VirtualPath, config) + case AzureBlobFilesystemProvider: + return NewAzBlobFs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.AzBlobConfig) + case CryptedFilesystemProvider: + return NewCryptFs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.CryptConfig) + case SFTPFilesystemProvider: + return NewSFTPFs(connectionID, v.VirtualPath, v.FsConfig.SFTPConfig) + default: + return NewOsFs(connectionID, v.MappedPath, v.VirtualPath), nil + } +} + +// ScanQuota scans the folder and returns the number of files and their size +func (v *VirtualFolder) ScanQuota() (int, int64, error) { + fs, err := v.GetFilesystem("") + if err != nil { + return 0, 0, err + } + defer fs.Close() + + return fs.ScanRootDirContents() +} + // IsIncludedInUserQuota returns true if the virtual folder is included in user quota func (v *VirtualFolder) IsIncludedInUserQuota() bool { return v.QuotaFiles == -1 && v.QuotaSize == -1 @@ -87,3 +183,13 @@ func (v *VirtualFolder) HasNoQuotaRestrictions(checkFiles bool) bool { } return false } + +// GetACopy returns a copy +func (v *VirtualFolder) GetACopy() VirtualFolder { + return VirtualFolder{ + BaseVirtualFolder: v.BaseVirtualFolder.GetACopy(), + VirtualPath: v.VirtualPath, + QuotaSize: v.QuotaSize, + QuotaFiles: v.QuotaFiles, + } +} diff --git a/vfs/gcsfs.go b/vfs/gcsfs.go index 598c3f8f..b1a5781c 100644 --- a/vfs/gcsfs.go +++ b/vfs/gcsfs.go @@ -35,8 +35,10 @@ var ( // GCSFs is a Fs implementation for Google Cloud Storage. type GCSFs struct { - connectionID string - localTempDir string + connectionID string + localTempDir string + // if not empty this fs is mouted as virtual folder in the specified path + mountPath string config *GCSFsConfig svc *storage.Client ctxTimeout time.Duration @@ -48,11 +50,16 @@ func init() { } // NewGCSFs returns an GCSFs object that allows to interact with Google Cloud Storage -func NewGCSFs(connectionID, localTempDir string, config GCSFsConfig) (Fs, error) { +func NewGCSFs(connectionID, localTempDir, mountPath string, config GCSFsConfig) (Fs, error) { + if localTempDir == "" { + localTempDir = filepath.Clean(os.TempDir()) + } + var err error fs := &GCSFs{ connectionID: connectionID, localTempDir: localTempDir, + mountPath: mountPath, config: &config, ctxTimeout: 30 * time.Second, ctxLongTimeout: 300 * time.Second, @@ -152,7 +159,7 @@ func (fs *GCSFs) Open(name string, offset int64) (File, *pipeat.PipeReaderAt, fu ctx, cancelFn := context.WithCancel(context.Background()) objectReader, err := obj.NewRangeReader(ctx, offset, -1) if err == nil && offset > 0 && objectReader.Attrs.ContentEncoding == "gzip" { - err = fmt.Errorf("Range request is not possible for gzip content encoding, requested offset %v", offset) + err = fmt.Errorf("range request is not possible for gzip content encoding, requested offset %v", offset) objectReader.Close() } if err != nil { @@ -230,7 +237,7 @@ func (fs *GCSFs) Rename(source, target string) error { return err } if hasContents { - return fmt.Errorf("Cannot rename non empty directory: %#v", source) + return fmt.Errorf("cannot rename non empty directory: %#v", source) } } src := fs.svc.Bucket(fs.config.Bucket).Object(source) @@ -266,7 +273,7 @@ func (fs *GCSFs) Remove(name string, isDir bool) error { return err } if hasContents { - return fmt.Errorf("Cannot remove non empty directory: %#v", name) + return fmt.Errorf("cannot remove non empty directory: %#v", name) } } ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) @@ -290,6 +297,11 @@ func (fs *GCSFs) Mkdir(name string) error { return w.Close() } +// MkdirAll does nothing, we don't have folder +func (*GCSFs) MkdirAll(name string, uid int, gid int) error { + return nil +} + // Symlink creates source as a symbolic link to target. func (*GCSFs) Symlink(source, target string) error { return ErrVfsUnsupported @@ -441,7 +453,7 @@ func (*GCSFs) IsNotSupported(err error) bool { // CheckRootPath creates the specified local root directory if it does not exists func (fs *GCSFs) CheckRootPath(username string, uid int, gid int) bool { // we need a local directory for temporary files - osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, nil) + osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "") return osFs.CheckRootPath(username, uid, gid) } @@ -510,6 +522,9 @@ func (fs *GCSFs) GetRelativePath(name string) string { } rel = path.Clean("/" + strings.TrimPrefix(rel, "/"+fs.config.KeyPrefix)) } + if fs.mountPath != "" { + rel = path.Join(fs.mountPath, rel) + } return rel } @@ -578,6 +593,9 @@ func (GCSFs) HasVirtualFolders() bool { // ResolvePath returns the matching filesystem path for the specified virtual path func (fs *GCSFs) ResolvePath(virtualPath string) (string, error) { + if fs.mountPath != "" { + virtualPath = strings.TrimPrefix(virtualPath, fs.mountPath) + } if !path.IsAbs(virtualPath) { virtualPath = path.Clean("/" + virtualPath) } diff --git a/vfs/gcsfs_disabled.go b/vfs/gcsfs_disabled.go index e90232e4..8a6daf51 100644 --- a/vfs/gcsfs_disabled.go +++ b/vfs/gcsfs_disabled.go @@ -13,6 +13,6 @@ func init() { } // NewGCSFs returns an error, GCS is disabled -func NewGCSFs(connectionID, localTempDir string, config GCSFsConfig) (Fs, error) { +func NewGCSFs(connectionID, localTempDir, mountPath string, config GCSFsConfig) (Fs, error) { return nil, errors.New("Google Cloud Storage disabled at build time") } diff --git a/vfs/osfs.go b/vfs/osfs.go index 56e4b0db..2af7ba69 100644 --- a/vfs/osfs.go +++ b/vfs/osfs.go @@ -15,7 +15,6 @@ import ( "github.com/rs/xid" "github.com/drakkan/sftpgo/logger" - "github.com/drakkan/sftpgo/utils" ) const ( @@ -25,19 +24,20 @@ const ( // OsFs is a Fs implementation that uses functions provided by the os package. type OsFs struct { - name string - connectionID string - rootDir string - virtualFolders []VirtualFolder + name string + connectionID string + rootDir string + // if not empty this fs is mouted as virtual folder in the specified path + mountPath string } // NewOsFs returns an OsFs object that allows to interact with local Os filesystem -func NewOsFs(connectionID, rootDir string, virtualFolders []VirtualFolder) Fs { +func NewOsFs(connectionID, rootDir, mountPath string) Fs { return &OsFs{ - name: osFsName, - connectionID: connectionID, - rootDir: rootDir, - virtualFolders: virtualFolders, + name: osFsName, + connectionID: connectionID, + rootDir: rootDir, + mountPath: mountPath, } } @@ -53,32 +53,12 @@ func (fs *OsFs) ConnectionID() string { // Stat returns a FileInfo describing the named file func (fs *OsFs) Stat(name string) (os.FileInfo, error) { - fi, err := os.Stat(name) - if err != nil { - return fi, err - } - for _, v := range fs.virtualFolders { - if v.MappedPath == name { - info := NewFileInfo(v.VirtualPath, true, fi.Size(), fi.ModTime(), false) - return info, nil - } - } - return fi, err + return os.Stat(name) } // Lstat returns a FileInfo describing the named file func (fs *OsFs) Lstat(name string) (os.FileInfo, error) { - fi, err := os.Lstat(name) - if err != nil { - return fi, err - } - for _, v := range fs.virtualFolders { - if v.MappedPath == name { - info := NewFileInfo(v.VirtualPath, true, fi.Size(), fi.ModTime(), false) - return info, nil - } - } - return fi, err + return os.Lstat(name) } // Open opens the named file for reading @@ -114,6 +94,13 @@ func (*OsFs) Mkdir(name string) error { return os.Mkdir(name, os.ModePerm) } +// MkdirAll creates a directory named path, along with any necessary parents, +// and returns nil, or else returns an error. +// If path is already a directory, MkdirAll does nothing and returns nil. +func (fs *OsFs) MkdirAll(name string, uid int, gid int) error { + return fs.createMissingDirs(name, uid, gid) +} + // Symlink creates source as a symbolic link to target. func (*OsFs) Symlink(source, target string) error { return os.Symlink(source, target) @@ -205,45 +192,13 @@ func (fs *OsFs) CheckRootPath(username string, uid int, gid int) bool { SetPathPermissions(fs, fs.rootDir, uid, gid) } } - // create any missing dirs to the defined virtual dirs - for _, v := range fs.virtualFolders { - p := filepath.Clean(filepath.Join(fs.rootDir, v.VirtualPath)) - err = fs.createMissingDirs(p, uid, gid) - if err != nil { - return false - } - if _, err = fs.Stat(v.MappedPath); fs.IsNotExist(err) { - err = os.MkdirAll(v.MappedPath, os.ModePerm) - fsLog(fs, logger.LevelDebug, "virtual directory %#v for user %#v does not exist, try to create, mkdir error: %v", - v.MappedPath, username, err) - if err == nil { - SetPathPermissions(fs, fs.rootDir, uid, gid) - } - } - } - return (err == nil) + return err == nil } // ScanRootDirContents returns the number of files contained in the root // directory and their size func (fs *OsFs) ScanRootDirContents() (int, int64, error) { - numFiles, size, err := fs.GetDirSize(fs.rootDir) - for _, v := range fs.virtualFolders { - if !v.IsIncludedInUserQuota() { - continue - } - num, s, err := fs.GetDirSize(v.MappedPath) - if err != nil { - if fs.IsNotExist(err) { - fsLog(fs, logger.LevelWarn, "unable to scan contents for non-existent mapped path: %#v", v.MappedPath) - continue - } - return numFiles, size, err - } - numFiles += num - size += s - } - return numFiles, size, err + return fs.GetDirSize(fs.rootDir) } // GetAtomicUploadPath returns the path to use for an atomic upload @@ -254,18 +209,13 @@ func (*OsFs) GetAtomicUploadPath(name string) string { } // GetRelativePath returns the path for a file relative to the user's home dir. -// This is the path as seen by SFTP users +// This is the path as seen by SFTPGo users func (fs *OsFs) GetRelativePath(name string) string { - basePath := fs.rootDir virtualPath := "/" - for _, v := range fs.virtualFolders { - if strings.HasPrefix(name, v.MappedPath+string(os.PathSeparator)) || - filepath.Clean(name) == v.MappedPath { - basePath = v.MappedPath - virtualPath = v.VirtualPath - } + if fs.mountPath != "" { + virtualPath = fs.mountPath } - rel, err := filepath.Rel(basePath, filepath.Clean(name)) + rel, err := filepath.Rel(fs.rootDir, filepath.Clean(name)) if err != nil { return "" } @@ -287,28 +237,31 @@ func (*OsFs) Join(elem ...string) string { } // ResolvePath returns the matching filesystem path for the specified sftp path -func (fs *OsFs) ResolvePath(sftpPath string) (string, error) { +func (fs *OsFs) ResolvePath(virtualPath string) (string, error) { if !filepath.IsAbs(fs.rootDir) { - return "", fmt.Errorf("Invalid root path: %v", fs.rootDir) + return "", fmt.Errorf("invalid root path: %v", fs.rootDir) } - basePath, r := fs.GetFsPaths(sftpPath) + if fs.mountPath != "" { + virtualPath = strings.TrimPrefix(virtualPath, fs.mountPath) + } + r := filepath.Clean(filepath.Join(fs.rootDir, virtualPath)) p, err := filepath.EvalSymlinks(r) if err != nil && !os.IsNotExist(err) { return "", err } else if os.IsNotExist(err) { // The requested path doesn't exist, so at this point we need to iterate up the // path chain until we hit a directory that _does_ exist and can be validated. - _, err = fs.findFirstExistingDir(r, basePath) + _, err = fs.findFirstExistingDir(r) if err != nil { fsLog(fs, logger.LevelWarn, "error resolving non-existent path %#v", err) } return r, err } - err = fs.isSubDir(p, basePath) + err = fs.isSubDir(p) if err != nil { fsLog(fs, logger.LevelWarn, "Invalid path resolution, dir %#v original path %#v resolved %#v err: %v", - p, sftpPath, r, err) + p, virtualPath, r, err) } return r, err } @@ -339,43 +292,9 @@ func (*OsFs) HasVirtualFolders() bool { return false } -// GetFsPaths returns the base path and filesystem path for the given sftpPath. -// base path is the root dir or matching the virtual folder dir for the sftpPath. -// file path is the filesystem path matching the sftpPath -func (fs *OsFs) GetFsPaths(sftpPath string) (string, string) { - basePath := fs.rootDir - virtualPath, mappedPath := fs.getMappedFolderForPath(sftpPath) - if mappedPath != "" { - basePath = mappedPath - sftpPath = strings.TrimPrefix(utils.CleanPath(sftpPath), virtualPath) - } - r := filepath.Clean(filepath.Join(basePath, sftpPath)) - return basePath, r -} - -// returns the path for the mapped folders or an empty string -func (fs *OsFs) getMappedFolderForPath(p string) (virtualPath, mappedPath string) { - if len(fs.virtualFolders) == 0 { - return - } - dirsForPath := utils.GetDirsForSFTPPath(p) - // dirsForPath contains all the dirs for a given path in reverse order - // for example if the path is: /1/2/3/4 it contains: - // [ "/1/2/3/4", "/1/2/3", "/1/2", "/1", "/" ] - // so the first match is the one we are interested to - for _, val := range dirsForPath { - for _, v := range fs.virtualFolders { - if val == v.VirtualPath { - return v.VirtualPath, v.MappedPath - } - } - } - return -} - -func (fs *OsFs) findNonexistentDirs(path, rootPath string) ([]string, error) { +func (fs *OsFs) findNonexistentDirs(filePath string) ([]string, error) { results := []string{} - cleanPath := filepath.Clean(path) + cleanPath := filepath.Clean(filePath) parent := filepath.Dir(cleanPath) _, err := os.Stat(parent) @@ -391,15 +310,15 @@ func (fs *OsFs) findNonexistentDirs(path, rootPath string) ([]string, error) { if err != nil { return results, err } - err = fs.isSubDir(p, rootPath) + err = fs.isSubDir(p) if err != nil { fsLog(fs, logger.LevelWarn, "error finding non existing dir: %v", err) } return results, err } -func (fs *OsFs) findFirstExistingDir(path, rootPath string) (string, error) { - results, err := fs.findNonexistentDirs(path, rootPath) +func (fs *OsFs) findFirstExistingDir(path string) (string, error) { + results, err := fs.findNonexistentDirs(path) if err != nil { fsLog(fs, logger.LevelWarn, "unable to find non existent dirs: %v", err) return "", err @@ -409,7 +328,7 @@ func (fs *OsFs) findFirstExistingDir(path, rootPath string) (string, error) { lastMissingDir := results[len(results)-1] parent = filepath.Dir(lastMissingDir) } else { - parent = rootPath + parent = fs.rootDir } p, err := filepath.EvalSymlinks(parent) if err != nil { @@ -422,15 +341,15 @@ func (fs *OsFs) findFirstExistingDir(path, rootPath string) (string, error) { if !fileInfo.IsDir() { return "", fmt.Errorf("resolved path is not a dir: %#v", p) } - err = fs.isSubDir(p, rootPath) + err = fs.isSubDir(p) return p, err } -func (fs *OsFs) isSubDir(sub, rootPath string) error { - // rootPath must exist and it is already a validated absolute path - parent, err := filepath.EvalSymlinks(rootPath) +func (fs *OsFs) isSubDir(sub string) error { + // fs.rootDir must exist and it is already a validated absolute path + parent, err := filepath.EvalSymlinks(fs.rootDir) if err != nil { - fsLog(fs, logger.LevelWarn, "invalid root path %#v: %v", rootPath, err) + fsLog(fs, logger.LevelWarn, "invalid root path %#v: %v", fs.rootDir, err) return err } if parent == sub { @@ -448,7 +367,7 @@ func (fs *OsFs) isSubDir(sub, rootPath string) error { } func (fs *OsFs) createMissingDirs(filePath string, uid, gid int) error { - dirsToCreate, err := fs.findNonexistentDirs(filePath, fs.rootDir) + dirsToCreate, err := fs.findNonexistentDirs(filePath) if err != nil { return err } diff --git a/vfs/s3fs.go b/vfs/s3fs.go index e7557e86..93912f74 100644 --- a/vfs/s3fs.go +++ b/vfs/s3fs.go @@ -30,8 +30,10 @@ import ( // S3Fs is a Fs implementation for AWS S3 compatible object storages type S3Fs struct { - connectionID string - localTempDir string + connectionID string + localTempDir string + // if not empty this fs is mouted as virtual folder in the specified path + mountPath string config *S3FsConfig svc *s3.S3 ctxTimeout time.Duration @@ -44,10 +46,14 @@ func init() { // NewS3Fs returns an S3Fs object that allows to interact with an s3 compatible // object storage -func NewS3Fs(connectionID, localTempDir string, config S3FsConfig) (Fs, error) { +func NewS3Fs(connectionID, localTempDir, mountPath string, config S3FsConfig) (Fs, error) { + if localTempDir == "" { + localTempDir = filepath.Clean(os.TempDir()) + } fs := &S3Fs{ connectionID: connectionID, localTempDir: localTempDir, + mountPath: mountPath, config: &config, ctxTimeout: 30 * time.Second, ctxLongTimeout: 300 * time.Second, @@ -249,7 +255,7 @@ func (fs *S3Fs) Rename(source, target string) error { return err } if hasContents { - return fmt.Errorf("Cannot rename non empty directory: %#v", source) + return fmt.Errorf("cannot rename non empty directory: %#v", source) } if !strings.HasSuffix(copySource, "/") { copySource += "/" @@ -288,7 +294,7 @@ func (fs *S3Fs) Remove(name string, isDir bool) error { return err } if hasContents { - return fmt.Errorf("Cannot remove non empty directory: %#v", name) + return fmt.Errorf("cannot remove non empty directory: %#v", name) } if !strings.HasSuffix(name, "/") { name += "/" @@ -320,6 +326,11 @@ func (fs *S3Fs) Mkdir(name string) error { return w.Close() } +// MkdirAll does nothing, we don't have folder +func (*S3Fs) MkdirAll(name string, uid int, gid int) error { + return nil +} + // Symlink creates source as a symbolic link to target. func (*S3Fs) Symlink(source, target string) error { return ErrVfsUnsupported @@ -465,7 +476,7 @@ func (*S3Fs) IsNotSupported(err error) bool { // CheckRootPath creates the specified local root directory if it does not exists func (fs *S3Fs) CheckRootPath(username string, uid int, gid int) bool { // we need a local directory for temporary files - osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, nil) + osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "") return osFs.CheckRootPath(username, uid, gid) } @@ -522,6 +533,9 @@ func (fs *S3Fs) GetRelativePath(name string) string { } rel = path.Clean("/" + strings.TrimPrefix(rel, "/"+fs.config.KeyPrefix)) } + if fs.mountPath != "" { + rel = path.Join(fs.mountPath, rel) + } return rel } @@ -574,6 +588,9 @@ func (*S3Fs) HasVirtualFolders() bool { // ResolvePath returns the matching filesystem path for the specified virtual path func (fs *S3Fs) ResolvePath(virtualPath string) (string, error) { + if fs.mountPath != "" { + virtualPath = strings.TrimPrefix(virtualPath, fs.mountPath) + } if !path.IsAbs(virtualPath) { virtualPath = path.Clean("/" + virtualPath) } diff --git a/vfs/s3fs_disabled.go b/vfs/s3fs_disabled.go index 743b0d1f..3fafee1d 100644 --- a/vfs/s3fs_disabled.go +++ b/vfs/s3fs_disabled.go @@ -13,6 +13,6 @@ func init() { } // NewS3Fs returns an error, S3 is disabled -func NewS3Fs(connectionID, localTempDir string, config S3FsConfig) (Fs, error) { +func NewS3Fs(connectionID, localTempDir, mountPath string, config S3FsConfig) (Fs, error) { return nil, errors.New("S3 disabled at build time") } diff --git a/vfs/sftpfs.go b/vfs/sftpfs.go index 128d3b76..97d8bd7f 100644 --- a/vfs/sftpfs.go +++ b/vfs/sftpfs.go @@ -110,14 +110,16 @@ func (c *SFTPFsConfig) EncryptCredentials(additionalData string) error { type SFTPFs struct { sync.Mutex connectionID string - config *SFTPFsConfig - sshClient *ssh.Client - sftpClient *sftp.Client - err chan error + // if not empty this fs is mouted as virtual folder in the specified path + mountPath string + config *SFTPFsConfig + sshClient *ssh.Client + sftpClient *sftp.Client + err chan error } // NewSFTPFs returns an SFTPFa object that allows to interact with an SFTP server -func NewSFTPFs(connectionID string, config SFTPFsConfig) (Fs, error) { +func NewSFTPFs(connectionID, mountPath string, config SFTPFsConfig) (Fs, error) { if err := config.Validate(); err != nil { return nil, err } @@ -133,6 +135,7 @@ func NewSFTPFs(connectionID string, config SFTPFsConfig) (Fs, error) { } sftpFs := &SFTPFs{ connectionID: connectionID, + mountPath: mountPath, config: &config, err: make(chan error, 1), } @@ -226,6 +229,16 @@ func (fs *SFTPFs) Mkdir(name string) error { return fs.sftpClient.Mkdir(name) } +// MkdirAll creates a directory named path, along with any necessary parents, +// and returns nil, or else returns an error. +// If path is already a directory, MkdirAll does nothing and returns nil. +func (fs *SFTPFs) MkdirAll(name string, uid int, gid int) error { + if err := fs.checkConnection(); err != nil { + return err + } + return fs.sftpClient.MkdirAll(name) +} + // Symlink creates source as a symbolic link to target. func (fs *SFTPFs) Symlink(source, target string) error { if err := fs.checkConnection(); err != nil { @@ -358,6 +371,9 @@ func (fs *SFTPFs) GetRelativePath(name string) string { } rel = path.Clean("/" + strings.TrimPrefix(rel, fs.config.Prefix)) } + if fs.mountPath != "" { + rel = path.Join(fs.mountPath, rel) + } return rel } @@ -393,6 +409,9 @@ func (*SFTPFs) HasVirtualFolders() bool { // ResolvePath returns the matching filesystem path for the specified virtual path func (fs *SFTPFs) ResolvePath(virtualPath string) (string, error) { + if fs.mountPath != "" { + virtualPath = strings.TrimPrefix(virtualPath, fs.mountPath) + } if !path.IsAbs(virtualPath) { virtualPath = path.Clean("/" + virtualPath) } @@ -559,7 +578,7 @@ func (fs *SFTPFs) createConnection() error { return nil } } - return fmt.Errorf("Invalid fingerprint %#v", fp) + return fmt.Errorf("invalid fingerprint %#v", fp) } fsLog(fs, logger.LevelWarn, "login without host key validation, please provide at least a fingerprint!") return nil diff --git a/vfs/vfs.go b/vfs/vfs.go index 9ebba568..2660682f 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -27,8 +27,21 @@ var ( validAzAccessTier = []string{"", "Archive", "Hot", "Cool"} // ErrStorageSizeUnavailable is returned if the storage backend does not support getting the size ErrStorageSizeUnavailable = errors.New("unable to get available size for this storage backend") + // ErrVfsUnsupported defines the error for an unsupported VFS operation + ErrVfsUnsupported = errors.New("not supported") + credentialsDirPath string ) +// SetCredentialsDirPath sets the credentials dir path +func SetCredentialsDirPath(credentialsPath string) { + credentialsDirPath = credentialsPath +} + +// GetCredentialsDirPath returns the credentials dir path +func GetCredentialsDirPath() string { + return credentialsDirPath +} + // Fs defines the interface for filesystem backends type Fs interface { Name() string @@ -40,6 +53,7 @@ type Fs interface { Rename(source, target string) error Remove(name string, isDir bool) error Mkdir(name string) error + MkdirAll(name string, uid int, gid int) error Symlink(source, target string) error Chown(name string, uid int, gid int) error Chmod(name string, mode os.FileMode) error @@ -79,9 +93,6 @@ type File interface { Truncate(size int64) error } -// ErrVfsUnsupported defines the error for an unsupported VFS operation -var ErrVfsUnsupported = errors.New("Not supported") - // QuotaCheckResult defines the result for a quota check type QuotaCheckResult struct { HasSpace bool @@ -438,6 +449,11 @@ func IsLocalOrSFTPFs(fs Fs) bool { return IsLocalOsFs(fs) || IsSFTPFs(fs) } +// IsLocalOrCryptoFs returns true if fs is local or local encrypted +func IsLocalOrCryptoFs(fs Fs) bool { + return IsLocalOsFs(fs) || IsCryptOsFs(fs) +} + // SetPathPermissions calls fs.Chown. // It does nothing for local filesystem on windows func SetPathPermissions(fs Fs, path string, uid int, gid int) { diff --git a/webdavd/file.go b/webdavd/file.go index 41c1f843..f0f4c930 100644 --- a/webdavd/file.go +++ b/webdavd/file.go @@ -84,7 +84,7 @@ func (f *webDavFile) Readdir(count int) ([]os.FileInfo, error) { if !f.Connection.User.HasPerm(dataprovider.PermListItems, f.GetVirtualPath()) { return nil, f.Connection.GetPermissionDeniedError() } - fileInfos, err := f.Connection.ListDir(f.GetFsPath(), f.GetVirtualPath()) + fileInfos, err := f.Connection.ListDir(f.GetVirtualPath()) if err != nil { return nil, err } @@ -299,7 +299,7 @@ func (f *webDavFile) Close() error { } else { f.Connection.RemoveTransfer(f.BaseTransfer) } - return f.Connection.GetFsError(err) + return f.Connection.GetFsError(f.Fs, err) } func (f *webDavFile) closeIO() error { diff --git a/webdavd/handler.go b/webdavd/handler.go index 7d18a54d..f8995bf9 100644 --- a/webdavd/handler.go +++ b/webdavd/handler.go @@ -57,11 +57,7 @@ func (c *Connection) Mkdir(ctx context.Context, name string, perm os.FileMode) e c.UpdateLastActivity() name = utils.CleanPath(name) - p, err := c.Fs.ResolvePath(name) - if err != nil { - return c.GetFsError(err) - } - return c.CreateDir(p, name) + return c.CreateDir(name) } // Rename renames a file or a directory @@ -71,20 +67,10 @@ func (c *Connection) Rename(ctx context.Context, oldName, newName string) error oldName = utils.CleanPath(oldName) newName = utils.CleanPath(newName) - p, err := c.Fs.ResolvePath(oldName) - if err != nil { - return c.GetFsError(err) - } - t, err := c.Fs.ResolvePath(newName) - if err != nil { - return c.GetFsError(err) - } - - if err = c.BaseConnection.Rename(p, t, oldName, newName); err != nil { + if err := c.BaseConnection.Rename(oldName, newName); err != nil { return err } - vfs.SetPathPermissions(c.Fs, t, c.User.GetUID(), c.User.GetGID()) return nil } @@ -98,14 +84,10 @@ func (c *Connection) Stat(ctx context.Context, name string) (os.FileInfo, error) return nil, c.GetPermissionDeniedError() } - p, err := c.Fs.ResolvePath(name) + fi, err := c.DoStat(name, 0) if err != nil { - return nil, c.GetFsError(err) - } - fi, err := c.DoStat(p, 0) - if err != nil { - c.Log(logger.LevelDebug, "error running stat on path %#v: %+v", p, err) - return nil, c.GetFsError(err) + c.Log(logger.LevelDebug, "error running stat on path %#v: %+v", name, err) + return nil, err } return fi, err } @@ -116,21 +98,21 @@ func (c *Connection) RemoveAll(ctx context.Context, name string) error { c.UpdateLastActivity() name = utils.CleanPath(name) - p, err := c.Fs.ResolvePath(name) + fs, p, err := c.GetFsAndResolvedPath(name) if err != nil { - return c.GetFsError(err) + return err } var fi os.FileInfo - if fi, err = c.Fs.Lstat(p); err != nil { - c.Log(logger.LevelWarn, "failed to remove a file %#v: stat error: %+v", p, err) - return c.GetFsError(err) + if fi, err = fs.Lstat(p); err != nil { + c.Log(logger.LevelDebug, "failed to remove a file %#v: stat error: %+v", p, err) + return c.GetFsError(fs, err) } if fi.IsDir() && fi.Mode()&os.ModeSymlink == 0 { - return c.removeDirTree(p, name) + return c.removeDirTree(fs, p, name) } - return c.RemoveFile(p, name, fi) + return c.RemoveFile(fs, p, name, fi) } // OpenFile opens the named file with specified flag. @@ -139,19 +121,19 @@ func (c *Connection) OpenFile(ctx context.Context, name string, flag int, perm o c.UpdateLastActivity() name = utils.CleanPath(name) - p, err := c.Fs.ResolvePath(name) + fs, p, err := c.GetFsAndResolvedPath(name) if err != nil { - return nil, c.GetFsError(err) + return nil, err } if flag == os.O_RDONLY || c.request.Method == "PROPPATCH" { // Download, Stat, Readdir or simply open/close - return c.getFile(p, name) + return c.getFile(fs, p, name) } - return c.putFile(p, name) + return c.putFile(fs, p, name) } -func (c *Connection) getFile(fsPath, virtualPath string) (webdav.File, error) { +func (c *Connection) getFile(fs vfs.Fs, fsPath, virtualPath string) (webdav.File, error) { var err error var file vfs.File var r *pipeat.PipeReaderAt @@ -159,42 +141,42 @@ func (c *Connection) getFile(fsPath, virtualPath string) (webdav.File, error) { // for cloud fs we open the file when we receive the first read to avoid to download the first part of // the file if it was opened only to do a stat or a readdir and so it is not a real download - if vfs.IsLocalOrSFTPFs(c.Fs) { - file, r, cancelFn, err = c.Fs.Open(fsPath, 0) + if vfs.IsLocalOrSFTPFs(fs) { + file, r, cancelFn, err = fs.Open(fsPath, 0) if err != nil { c.Log(logger.LevelWarn, "could not open file %#v for reading: %+v", fsPath, err) - return nil, c.GetFsError(err) + return nil, c.GetFsError(fs, err) } } baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, virtualPath, common.TransferDownload, - 0, 0, 0, false, c.Fs) + 0, 0, 0, false, fs) return newWebDavFile(baseTransfer, nil, r), nil } -func (c *Connection) putFile(fsPath, virtualPath string) (webdav.File, error) { +func (c *Connection) putFile(fs vfs.Fs, fsPath, virtualPath string) (webdav.File, error) { if !c.User.IsFileAllowed(virtualPath) { c.Log(logger.LevelWarn, "writing file %#v is not allowed", virtualPath) return nil, c.GetPermissionDeniedError() } filePath := fsPath - if common.Config.IsAtomicUploadEnabled() && c.Fs.IsAtomicUploadSupported() { - filePath = c.Fs.GetAtomicUploadPath(fsPath) + if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { + filePath = fs.GetAtomicUploadPath(fsPath) } - stat, statErr := c.Fs.Lstat(fsPath) - if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || c.Fs.IsNotExist(statErr) { + stat, statErr := fs.Lstat(fsPath) + if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) { if !c.User.HasPerm(dataprovider.PermUpload, path.Dir(virtualPath)) { return nil, c.GetPermissionDeniedError() } - return c.handleUploadToNewFile(fsPath, filePath, virtualPath) + return c.handleUploadToNewFile(fs, fsPath, filePath, virtualPath) } if statErr != nil { c.Log(logger.LevelError, "error performing file stat %#v: %+v", fsPath, statErr) - return nil, c.GetFsError(statErr) + return nil, c.GetFsError(fs, statErr) } // This happen if we upload a file that has the same name of an existing directory @@ -207,33 +189,33 @@ func (c *Connection) putFile(fsPath, virtualPath string) (webdav.File, error) { return nil, c.GetPermissionDeniedError() } - return c.handleUploadToExistingFile(fsPath, filePath, stat.Size(), virtualPath) + return c.handleUploadToExistingFile(fs, fsPath, filePath, stat.Size(), virtualPath) } -func (c *Connection) handleUploadToNewFile(resolvedPath, filePath, requestPath string) (webdav.File, error) { +func (c *Connection) handleUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, requestPath string) (webdav.File, error) { quotaResult := c.HasSpace(true, false, requestPath) if !quotaResult.HasSpace { c.Log(logger.LevelInfo, "denying file write due to quota limits") return nil, common.ErrQuotaExceeded } - file, w, cancelFn, err := c.Fs.Create(filePath, 0) + file, w, cancelFn, err := fs.Create(filePath, 0) if err != nil { c.Log(logger.LevelWarn, "error creating file %#v: %+v", resolvedPath, err) - return nil, c.GetFsError(err) + return nil, c.GetFsError(fs, err) } - vfs.SetPathPermissions(c.Fs, filePath, c.User.GetUID(), c.User.GetGID()) + vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) // we can get an error only for resume - maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0) + maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath, - common.TransferUpload, 0, 0, maxWriteSize, true, c.Fs) + common.TransferUpload, 0, 0, maxWriteSize, true, fs) return newWebDavFile(baseTransfer, w, nil), nil } -func (c *Connection) handleUploadToExistingFile(resolvedPath, filePath string, fileSize int64, +func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePath string, fileSize int64, requestPath string) (webdav.File, error) { var err error quotaResult := c.HasSpace(false, false, requestPath) @@ -244,24 +226,24 @@ func (c *Connection) handleUploadToExistingFile(resolvedPath, filePath string, f // if there is a size limit remaining size cannot be 0 here, since quotaResult.HasSpace // will return false in this case and we deny the upload before - maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, fileSize) + maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, fileSize, fs.IsUploadResumeSupported()) - if common.Config.IsAtomicUploadEnabled() && c.Fs.IsAtomicUploadSupported() { - err = c.Fs.Rename(resolvedPath, filePath) + if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { + err = fs.Rename(resolvedPath, filePath) if err != nil { c.Log(logger.LevelWarn, "error renaming existing file for atomic upload, source: %#v, dest: %#v, err: %+v", resolvedPath, filePath, err) - return nil, c.GetFsError(err) + return nil, c.GetFsError(fs, err) } } - file, w, cancelFn, err := c.Fs.Create(filePath, 0) + file, w, cancelFn, err := fs.Create(filePath, 0) if err != nil { c.Log(logger.LevelWarn, "error creating file %#v: %+v", resolvedPath, err) - return nil, c.GetFsError(err) + return nil, c.GetFsError(fs, err) } initialSize := int64(0) - if vfs.IsLocalOrSFTPFs(c.Fs) { + if vfs.IsLocalOrSFTPFs(fs) { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) if err == nil { dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck @@ -275,10 +257,10 @@ func (c *Connection) handleUploadToExistingFile(resolvedPath, filePath string, f initialSize = fileSize } - vfs.SetPathPermissions(c.Fs, filePath, c.User.GetUID(), c.User.GetGID()) + vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath, - common.TransferUpload, 0, initialSize, maxWriteSize, false, c.Fs) + common.TransferUpload, 0, initialSize, maxWriteSize, false, fs) return newWebDavFile(baseTransfer, w, nil), nil } @@ -289,22 +271,22 @@ type objectMapping struct { info os.FileInfo } -func (c *Connection) removeDirTree(fsPath, virtualPath string) error { +func (c *Connection) removeDirTree(fs vfs.Fs, fsPath, virtualPath string) error { var dirsToRemove []objectMapping var filesToRemove []objectMapping - err := c.Fs.Walk(fsPath, func(walkedPath string, info os.FileInfo, err error) error { + err := fs.Walk(fsPath, func(walkedPath string, info os.FileInfo, err error) error { if err != nil { return err } obj := objectMapping{ fsPath: walkedPath, - virtualPath: c.Fs.GetRelativePath(walkedPath), + virtualPath: fs.GetRelativePath(walkedPath), info: info, } if info.IsDir() { - err = c.IsRemoveDirAllowed(obj.fsPath, obj.virtualPath) + err = c.IsRemoveDirAllowed(fs, obj.fsPath, obj.virtualPath) isDuplicated := false for _, d := range dirsToRemove { if d.fsPath == obj.fsPath { @@ -316,7 +298,7 @@ func (c *Connection) removeDirTree(fsPath, virtualPath string) error { dirsToRemove = append(dirsToRemove, obj) } } else { - err = c.IsRemoveFileAllowed(obj.fsPath, obj.virtualPath) + err = c.IsRemoveFileAllowed(obj.virtualPath) filesToRemove = append(filesToRemove, obj) } if err != nil { @@ -333,7 +315,7 @@ func (c *Connection) removeDirTree(fsPath, virtualPath string) error { } for _, fileObj := range filesToRemove { - err = c.RemoveFile(fileObj.fsPath, fileObj.virtualPath, fileObj.info) + err = c.RemoveFile(fs, fileObj.fsPath, fileObj.virtualPath, fileObj.info) if err != nil { c.Log(logger.LevelDebug, "unable to remove dir tree, error removing file %#v->%#v: %v", fileObj.virtualPath, fileObj.fsPath, err) @@ -341,8 +323,8 @@ func (c *Connection) removeDirTree(fsPath, virtualPath string) error { } } - for _, dirObj := range c.orderDirsToRemove(dirsToRemove) { - err = c.RemoveDir(dirObj.fsPath, dirObj.virtualPath) + for _, dirObj := range c.orderDirsToRemove(fs, dirsToRemove) { + err = c.RemoveDir(dirObj.virtualPath) if err != nil { c.Log(logger.LevelDebug, "unable to remove dir tree, error removing directory %#v->%#v: %v", dirObj.virtualPath, dirObj.fsPath, err) @@ -354,12 +336,12 @@ func (c *Connection) removeDirTree(fsPath, virtualPath string) error { } // order directories so that the empty ones will be at slice start -func (c *Connection) orderDirsToRemove(dirsToRemove []objectMapping) []objectMapping { +func (c *Connection) orderDirsToRemove(fs vfs.Fs, dirsToRemove []objectMapping) []objectMapping { orderedDirs := make([]objectMapping, 0, len(dirsToRemove)) removedDirs := make([]string, 0, len(dirsToRemove)) pathSeparator := "/" - if vfs.IsLocalOsFs(c.Fs) { + if vfs.IsLocalOsFs(fs) { pathSeparator = string(os.PathSeparator) } diff --git a/webdavd/internal_test.go b/webdavd/internal_test.go index d002d69f..ae9afe1d 100644 --- a/webdavd/internal_test.go +++ b/webdavd/internal_test.go @@ -22,6 +22,7 @@ import ( "github.com/drakkan/sftpgo/common" "github.com/drakkan/sftpgo/dataprovider" + "github.com/drakkan/sftpgo/kms" "github.com/drakkan/sftpgo/vfs" ) @@ -321,7 +322,7 @@ func (fs *MockOsFs) GetMimeType(name string) (string, error) { func newMockOsFs(err error, atomicUpload bool, connectionID, rootDir string, reader *pipeat.PipeReaderAt) vfs.Fs { return &MockOsFs{ - Fs: vfs.NewOsFs(connectionID, rootDir, nil), + Fs: vfs.NewOsFs(connectionID, rootDir, ""), err: err, isAtomicUploadSupported: atomicUpload, reader: reader, @@ -330,14 +331,14 @@ func newMockOsFs(err error, atomicUpload bool, connectionID, rootDir string, rea func TestOrderDirsToRemove(t *testing.T) { user := dataprovider.User{} - fs := vfs.NewOsFs("id", os.TempDir(), nil) + fs := vfs.NewOsFs("id", os.TempDir(), "") connection := &Connection{ - BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs), + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user), request: nil, } dirsToRemove := []objectMapping{} - orderedDirs := connection.orderDirsToRemove(dirsToRemove) + orderedDirs := connection.orderDirsToRemove(fs, dirsToRemove) assert.Equal(t, len(dirsToRemove), len(orderedDirs)) dirsToRemove = []objectMapping{ @@ -346,7 +347,7 @@ func TestOrderDirsToRemove(t *testing.T) { virtualPath: "", }, } - orderedDirs = connection.orderDirsToRemove(dirsToRemove) + orderedDirs = connection.orderDirsToRemove(fs, dirsToRemove) assert.Equal(t, len(dirsToRemove), len(orderedDirs)) dirsToRemove = []objectMapping{ @@ -368,7 +369,7 @@ func TestOrderDirsToRemove(t *testing.T) { }, } - orderedDirs = connection.orderDirsToRemove(dirsToRemove) + orderedDirs = connection.orderDirsToRemove(fs, dirsToRemove) if assert.Equal(t, len(dirsToRemove), len(orderedDirs)) { assert.Equal(t, "dir12", orderedDirs[0].fsPath) assert.Equal(t, filepath.Join("dir1", "a", "b"), orderedDirs[1].fsPath) @@ -403,30 +404,6 @@ func TestUserInvalidParams(t *testing.T) { assert.EqualError(t, err, fmt.Sprintf("cannot login user with invalid home dir: %#v", u.HomeDir)) } - u.HomeDir = filepath.Clean(os.TempDir()) - subDir := "subdir" - mappedPath1 := filepath.Join(os.TempDir(), "vdir1") - vdirPath1 := "/vdir1" - mappedPath2 := filepath.Join(os.TempDir(), "vdir1", subDir) - vdirPath2 := "/vdir2" - u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: mappedPath1, - }, - VirtualPath: vdirPath1, - }) - u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - MappedPath: mappedPath2, - }, - VirtualPath: vdirPath2, - }) - - _, err = server.validateUser(u, req, dataprovider.LoginMethodPassword) - if assert.Error(t, err) { - assert.EqualError(t, err, "overlapping mapped folders are allowed only with quota tracking disabled") - } - req.TLS = &tls.ConnectionState{} writeLog(req, nil) } @@ -478,9 +455,9 @@ func TestResolvePathErrors(t *testing.T) { } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} - fs := vfs.NewOsFs("connID", user.HomeDir, nil) + fs := vfs.NewOsFs("connID", user.HomeDir, "") connection := &Connection{ - BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs), + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user), } err := connection.Mkdir(ctx, "", os.ModePerm) @@ -509,8 +486,9 @@ func TestResolvePathErrors(t *testing.T) { } if runtime.GOOS != "windows" { - connection.User.HomeDir = filepath.Clean(os.TempDir()) - connection.Fs = vfs.NewOsFs("connID", connection.User.HomeDir, nil) + user.HomeDir = filepath.Clean(os.TempDir()) + connection.User = user + fs := vfs.NewOsFs("connID", connection.User.HomeDir, "") subDir := "sub" testTxtFile := "file.txt" err = os.MkdirAll(filepath.Join(os.TempDir(), subDir, subDir), os.ModePerm) @@ -519,11 +497,13 @@ func TestResolvePathErrors(t *testing.T) { assert.NoError(t, err) err = os.Chmod(filepath.Join(os.TempDir(), subDir, subDir), 0001) assert.NoError(t, err) + err = os.WriteFile(filepath.Join(os.TempDir(), testTxtFile), []byte("test content"), os.ModePerm) + assert.NoError(t, err) err = connection.Rename(ctx, testTxtFile, path.Join(subDir, subDir, testTxtFile)) if assert.Error(t, err) { assert.EqualError(t, err, common.ErrPermissionDenied.Error()) } - _, err = connection.putFile(filepath.Join(connection.User.HomeDir, subDir, subDir, testTxtFile), + _, err = connection.putFile(fs, filepath.Join(connection.User.HomeDir, subDir, subDir, testTxtFile), path.Join(subDir, subDir, testTxtFile)) if assert.Error(t, err) { assert.EqualError(t, err, common.ErrPermissionDenied.Error()) @@ -532,6 +512,8 @@ func TestResolvePathErrors(t *testing.T) { assert.NoError(t, err) err = os.RemoveAll(filepath.Join(os.TempDir(), subDir)) assert.NoError(t, err) + err = os.Remove(filepath.Join(os.TempDir(), testTxtFile)) + assert.NoError(t, err) } } @@ -542,9 +524,9 @@ func TestFileAccessErrors(t *testing.T) { } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} - fs := vfs.NewOsFs("connID", user.HomeDir, nil) + fs := vfs.NewOsFs("connID", user.HomeDir, "") connection := &Connection{ - BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs), + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user), } missingPath := "missing path" fsMissingPath := filepath.Join(user.HomeDir, missingPath) @@ -552,26 +534,26 @@ func TestFileAccessErrors(t *testing.T) { if assert.Error(t, err) { assert.EqualError(t, err, os.ErrNotExist.Error()) } - _, err = connection.getFile(fsMissingPath, missingPath) + _, err = connection.getFile(fs, fsMissingPath, missingPath) if assert.Error(t, err) { assert.EqualError(t, err, os.ErrNotExist.Error()) } - _, err = connection.getFile(fsMissingPath, missingPath) + _, err = connection.getFile(fs, fsMissingPath, missingPath) if assert.Error(t, err) { assert.EqualError(t, err, os.ErrNotExist.Error()) } p := filepath.Join(user.HomeDir, "adir", missingPath) - _, err = connection.handleUploadToNewFile(p, p, path.Join("adir", missingPath)) + _, err = connection.handleUploadToNewFile(fs, p, p, path.Join("adir", missingPath)) if assert.Error(t, err) { assert.EqualError(t, err, os.ErrNotExist.Error()) } - _, err = connection.handleUploadToExistingFile(p, p, 0, path.Join("adir", missingPath)) + _, err = connection.handleUploadToExistingFile(fs, p, p, 0, path.Join("adir", missingPath)) if assert.Error(t, err) { assert.EqualError(t, err, os.ErrNotExist.Error()) } - connection.Fs = newMockOsFs(nil, false, fs.ConnectionID(), user.HomeDir, nil) - _, err = connection.handleUploadToExistingFile(p, p, 0, path.Join("adir", missingPath)) + fs = newMockOsFs(nil, false, fs.ConnectionID(), user.HomeDir, nil) + _, err = connection.handleUploadToExistingFile(fs, p, p, 0, path.Join("adir", missingPath)) if assert.Error(t, err) { assert.EqualError(t, err, os.ErrNotExist.Error()) } @@ -580,7 +562,7 @@ func TestFileAccessErrors(t *testing.T) { assert.NoError(t, err) err = f.Close() assert.NoError(t, err) - davFile, err := connection.handleUploadToExistingFile(f.Name(), f.Name(), 123, f.Name()) + davFile, err := connection.handleUploadToExistingFile(fs, f.Name(), f.Name(), 123, f.Name()) if assert.NoError(t, err) { transfer := davFile.(*webDavFile) transfers := connection.GetTransfers() @@ -603,46 +585,46 @@ func TestRemoveDirTree(t *testing.T) { } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} - fs := vfs.NewOsFs("connID", user.HomeDir, nil) + fs := vfs.NewOsFs("connID", user.HomeDir, "") connection := &Connection{ - BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs), + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user), } vpath := path.Join("adir", "missing") p := filepath.Join(user.HomeDir, "adir", "missing") - err := connection.removeDirTree(p, vpath) + err := connection.removeDirTree(fs, p, vpath) if assert.Error(t, err) { assert.True(t, os.IsNotExist(err)) } - connection.Fs = newMockOsFs(nil, false, "mockID", user.HomeDir, nil) - err = connection.removeDirTree(p, vpath) + fs = newMockOsFs(nil, false, "mockID", user.HomeDir, nil) + err = connection.removeDirTree(fs, p, vpath) if assert.Error(t, err) { - assert.True(t, os.IsNotExist(err)) + assert.True(t, os.IsNotExist(err), "unexpected error: %v", err) } errFake := errors.New("fake err") - connection.Fs = newMockOsFs(errFake, false, "mockID", user.HomeDir, nil) - err = connection.removeDirTree(p, vpath) + fs = newMockOsFs(errFake, false, "mockID", user.HomeDir, nil) + err = connection.removeDirTree(fs, p, vpath) if assert.Error(t, err) { assert.EqualError(t, err, errFake.Error()) } - connection.Fs = newMockOsFs(errWalkDir, true, "mockID", user.HomeDir, nil) - err = connection.removeDirTree(p, vpath) + fs = newMockOsFs(errWalkDir, true, "mockID", user.HomeDir, nil) + err = connection.removeDirTree(fs, p, vpath) if assert.Error(t, err) { - assert.True(t, os.IsNotExist(err)) + assert.True(t, os.IsPermission(err), "unexpected error: %v", err) } - connection.Fs = newMockOsFs(errWalkFile, false, "mockID", user.HomeDir, nil) - err = connection.removeDirTree(p, vpath) + fs = newMockOsFs(errWalkFile, false, "mockID", user.HomeDir, nil) + err = connection.removeDirTree(fs, p, vpath) if assert.Error(t, err) { assert.EqualError(t, err, errWalkFile.Error()) } connection.User.Permissions["/"] = []string{dataprovider.PermListItems} - connection.Fs = newMockOsFs(nil, false, "mockID", user.HomeDir, nil) - err = connection.removeDirTree(p, vpath) + fs = newMockOsFs(nil, false, "mockID", user.HomeDir, nil) + err = connection.removeDirTree(fs, p, vpath) if assert.Error(t, err) { assert.EqualError(t, err, common.ErrPermissionDenied.Error()) } @@ -654,9 +636,9 @@ func TestContentType(t *testing.T) { } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} - fs := vfs.NewOsFs("connID", user.HomeDir, nil) + fs := vfs.NewOsFs("connID", user.HomeDir, "") connection := &Connection{ - BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs), + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user), } testFilePath := filepath.Join(user.HomeDir, testFile) ctx := context.Background() @@ -679,7 +661,7 @@ func TestContentType(t *testing.T) { assert.NoError(t, err) davFile = newWebDavFile(baseTransfer, nil, nil) - davFile.Fs = vfs.NewOsFs("id", user.HomeDir, nil) + davFile.Fs = vfs.NewOsFs("id", user.HomeDir, "") fi, err = davFile.Stat() if assert.NoError(t, err) { ctype, err := fi.(*webDavFileInfo).ContentType(ctx) @@ -703,9 +685,9 @@ func TestTransferReadWriteErrors(t *testing.T) { } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} - fs := vfs.NewOsFs("connID", user.HomeDir, nil) + fs := vfs.NewOsFs("connID", user.HomeDir, "") connection := &Connection{ - BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs), + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user), } testFilePath := filepath.Join(user.HomeDir, testFile) baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile, @@ -796,9 +778,9 @@ func TestTransferSeek(t *testing.T) { } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} - fs := vfs.NewOsFs("connID", user.HomeDir, nil) + fs := vfs.NewOsFs("connID", user.HomeDir, "") connection := &Connection{ - BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs), + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user), } testFilePath := filepath.Join(user.HomeDir, testFile) testFileContents := []byte("content") @@ -991,6 +973,118 @@ func TestBasicUsersCache(t *testing.T) { assert.NoError(t, err) } +func TestCachedUserWithFolders(t *testing.T) { + username := "webdav_internal_folder_test" + password := "dav_pwd" + folderName := "test_folder" + u := dataprovider.User{ + Username: username, + Password: password, + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + ExpirationDate: 0, + } + u.Permissions = make(map[string][]string) + u.Permissions["/"] = []string{dataprovider.PermAny} + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: filepath.Join(os.TempDir(), folderName), + }, + VirtualPath: "/vpath", + }) + err := dataprovider.AddUser(&u) + assert.NoError(t, err) + user, err := dataprovider.UserExists(u.Username) + assert.NoError(t, err) + + c := &Configuration{ + Bindings: []Binding{ + { + Port: 9000, + }, + }, + Cache: Cache{ + Users: UsersCacheConfig{ + MaxSize: 50, + ExpirationTime: 1, + }, + }, + } + server := webDavServer{ + config: c, + binding: c.Bindings[0], + } + + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user.Username), nil) + assert.NoError(t, err) + + ipAddr := "127.0.0.1" + + _, _, _, _, err = server.authenticate(req, ipAddr) //nolint:dogsled + assert.Error(t, err) + + now := time.Now() + req.SetBasicAuth(username, password) + _, isCached, _, loginMethod, err := server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) + // now the user should be cached + var cachedUser *dataprovider.CachedUser + result, ok := dataprovider.GetCachedWebDAVUser(username) + if assert.True(t, ok) { + cachedUser = result.(*dataprovider.CachedUser) + assert.False(t, cachedUser.IsExpired()) + assert.True(t, cachedUser.Expiration.After(now.Add(time.Duration(c.Cache.Users.ExpirationTime)*time.Minute))) + // authenticate must return the cached user now + authUser, isCached, _, _, err := server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.True(t, isCached) + assert.Equal(t, cachedUser.User, authUser) + } + + folder, err := dataprovider.GetFolderByName(folderName) + assert.NoError(t, err) + // updating a used folder should invalidate the cache + err = dataprovider.UpdateFolder(&folder, folder.Users) + assert.NoError(t, err) + + _, isCached, _, loginMethod, err = server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) + result, ok = dataprovider.GetCachedWebDAVUser(username) + if assert.True(t, ok) { + cachedUser = result.(*dataprovider.CachedUser) + assert.False(t, cachedUser.IsExpired()) + } + + err = dataprovider.DeleteFolder(folderName) + assert.NoError(t, err) + // removing a used folder should invalidate the cache + _, isCached, _, loginMethod, err = server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) + result, ok = dataprovider.GetCachedWebDAVUser(username) + if assert.True(t, ok) { + cachedUser = result.(*dataprovider.CachedUser) + assert.False(t, cachedUser.IsExpired()) + } + + err = dataprovider.DeleteUser(user.Username) + assert.NoError(t, err) + _, ok = dataprovider.GetCachedWebDAVUser(username) + assert.False(t, ok) + + err = os.RemoveAll(u.GetHomeDir()) + assert.NoError(t, err) + + err = os.RemoveAll(folder.MappedPath) + assert.NoError(t, err) +} + func TestUsersCacheSizeAndExpiration(t *testing.T) { username := "webdav_internal_test" password := "pwd" @@ -1188,6 +1282,65 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) { assert.NoError(t, err) } +func TestUserCacheIsolation(t *testing.T) { + username := "webdav_internal_cache_test" + password := "dav_pwd" + u := dataprovider.User{ + Username: username, + Password: password, + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + ExpirationDate: 0, + } + u.Permissions = make(map[string][]string) + u.Permissions["/"] = []string{dataprovider.PermAny} + err := dataprovider.AddUser(&u) + assert.NoError(t, err) + user, err := dataprovider.UserExists(u.Username) + assert.NoError(t, err) + cachedUser := &dataprovider.CachedUser{ + User: user, + Expiration: time.Now().Add(24 * time.Hour), + Password: password, + LockSystem: webdav.NewMemLS(), + } + cachedUser.User.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("test secret") + err = cachedUser.User.FsConfig.S3Config.AccessSecret.Encrypt() + assert.NoError(t, err) + + dataprovider.CacheWebDAVUser(cachedUser, 10) + result, ok := dataprovider.GetCachedWebDAVUser(username) + + if assert.True(t, ok) { + cachedUser := result.(*dataprovider.CachedUser).User + _, err = cachedUser.GetFilesystem("") + assert.NoError(t, err) + // the filesystem is now cached + } + result, ok = dataprovider.GetCachedWebDAVUser(username) + if assert.True(t, ok) { + cachedUser := result.(*dataprovider.CachedUser).User + assert.True(t, cachedUser.FsConfig.S3Config.AccessSecret.IsEncrypted()) + err = cachedUser.FsConfig.S3Config.AccessSecret.Decrypt() + assert.NoError(t, err) + cachedUser.FsConfig.Provider = vfs.S3FilesystemProvider + _, err = cachedUser.GetFilesystem("") + assert.Error(t, err, "we don't have to get the previously cached filesystem!") + } + result, ok = dataprovider.GetCachedWebDAVUser(username) + if assert.True(t, ok) { + cachedUser := result.(*dataprovider.CachedUser).User + assert.Equal(t, vfs.LocalFilesystemProvider, cachedUser.FsConfig.Provider) + // FIXME: should we really allow to modify the cached users concurrently????? + assert.False(t, cachedUser.FsConfig.S3Config.AccessSecret.IsEncrypted()) + } + + err = dataprovider.DeleteUser(username) + assert.NoError(t, err) + _, ok = dataprovider.GetCachedWebDAVUser(username) + assert.False(t, ok) +} + func TestRecoverer(t *testing.T) { c := &Configuration{ Bindings: []Binding{ diff --git a/webdavd/server.go b/webdavd/server.go index 933baba5..cacddb7e 100644 --- a/webdavd/server.go +++ b/webdavd/server.go @@ -160,7 +160,7 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden) return } - user, _, lockSystem, loginMethod, err := s.authenticate(r, ipAddr) + user, isCached, lockSystem, loginMethod, err := s.authenticate(r, ipAddr) if err != nil { w.Header().Set("WWW-Authenticate", "Basic realm=\"SFTPGo WebDAV\"") http.Error(w, fmt.Sprintf("Authentication error: %v", err), http.StatusUnauthorized) @@ -174,8 +174,14 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - fs, err := user.GetFilesystem(connectionID) + if !isCached { + err = user.CheckFsRoot(connectionID) + } else { + _, err = user.GetFilesystem(connectionID) + } if err != nil { + errClose := user.CloseFs() + logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose) updateLoginMetrics(&user, ipAddr, loginMethod, err) http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -187,7 +193,7 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx = context.WithValue(ctx, requestStartKey, time.Now()) connection := &Connection{ - BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolWebDAV, user, fs), + BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolWebDAV, user), request: r, } common.Connections.Add(connection) @@ -273,14 +279,6 @@ func (s *webDavServer) authenticate(r *http.Request, ip string) (dataprovider.Us cachedUser.Expiration = time.Now().Add(time.Duration(s.config.Cache.Users.ExpirationTime) * time.Minute) } dataprovider.CacheWebDAVUser(cachedUser, s.config.Cache.Users.MaxSize) - if user.FsConfig.Provider != dataprovider.SFTPFilesystemProvider { - // for sftp fs check root path does nothing so don't open a useless SFTP connection - tempFs, err := user.GetFilesystem("temp") - if err == nil { - tempFs.CheckRootPath(user.Username, user.UID, user.GID) - tempFs.Close() - } - } return user, false, lockSystem, loginMethod, nil } @@ -295,11 +293,11 @@ func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, lo } if utils.IsStringInSlice(common.ProtocolWebDAV, user.Filters.DeniedProtocols) { logger.Debug(logSender, connectionID, "cannot login user %#v, protocol DAV is not allowed", user.Username) - return connID, fmt.Errorf("Protocol DAV is not allowed for user %#v", user.Username) + return connID, fmt.Errorf("protocol DAV is not allowed for user %#v", user.Username) } if !user.IsLoginMethodAllowed(loginMethod, nil) { logger.Debug(logSender, connectionID, "cannot login user %#v, %v login method is not allowed", user.Username, loginMethod) - return connID, fmt.Errorf("Login method %v is not allowed for user %#v", loginMethod, user.Username) + return connID, fmt.Errorf("login method %v is not allowed for user %#v", loginMethod, user.Username) } if user.MaxSessions > 0 { activeSessions := common.Connections.GetActiveSessions(user.Username) @@ -309,14 +307,9 @@ func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, lo return connID, fmt.Errorf("too many open sessions: %v", activeSessions) } } - if dataprovider.GetQuotaTracking() > 0 && user.HasOverlappedMappedPaths() { - logger.Debug(logSender, connectionID, "cannot login user %#v, overlapping mapped folders are allowed only with quota tracking disabled", - user.Username) - return connID, errors.New("overlapping mapped folders are allowed only with quota tracking disabled") - } if !user.IsLoginFromAddrAllowed(r.RemoteAddr) { logger.Debug(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v", user.Username, r.RemoteAddr) - return connID, fmt.Errorf("Login for user %#v is not allowed from this address: %v", user.Username, r.RemoteAddr) + return connID, fmt.Errorf("login for user %#v is not allowed from this address: %v", user.Username, r.RemoteAddr) } return connID, nil } diff --git a/webdavd/webdavd_test.go b/webdavd/webdavd_test.go index a6ba3972..be852a33 100644 --- a/webdavd/webdavd_test.go +++ b/webdavd/webdavd_test.go @@ -876,9 +876,9 @@ func TestMaxConnections(t *testing.T) { client := getWebDavClient(user, true, nil) assert.NoError(t, checkBasicFunc(client)) // now add a fake connection - fs := vfs.NewOsFs("id", os.TempDir(), nil) + fs := vfs.NewOsFs("id", os.TempDir(), "") connection := &webdavd.Connection{ - BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs), + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user), } common.Connections.Add(connection) assert.Error(t, checkBasicFunc(client)) @@ -900,9 +900,9 @@ func TestMaxSessions(t *testing.T) { client := getWebDavClient(user, false, nil) assert.NoError(t, checkBasicFunc(client)) // now add a fake connection - fs := vfs.NewOsFs("id", os.TempDir(), nil) + fs := vfs.NewOsFs("id", os.TempDir(), "") connection := &webdavd.Connection{ - BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs), + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user), } common.Connections.Add(connection) assert.Error(t, checkBasicFunc(client)) @@ -1308,7 +1308,7 @@ func TestClientClose(t *testing.T) { func TestLoginWithDatabaseCredentials(t *testing.T) { u := getTestUser() - u.FsConfig.Provider = dataprovider.GCSFilesystemProvider + u.FsConfig.Provider = vfs.GCSFilesystemProvider u.FsConfig.GCSConfig.Bucket = "test" u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(`{ "type": "service_account" }`) @@ -1356,7 +1356,7 @@ func TestLoginWithDatabaseCredentials(t *testing.T) { func TestLoginInvalidFs(t *testing.T) { u := getTestUser() - u.FsConfig.Provider = dataprovider.GCSFilesystemProvider + u.FsConfig.Provider = vfs.GCSFilesystemProvider u.FsConfig.GCSConfig.Bucket = "test" u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials") user, _, err := httpdtest.AddUser(u, http.StatusCreated) @@ -2006,6 +2006,106 @@ func TestPreLoginHookWithClientCert(t *testing.T) { assert.NoError(t, err) } +func TestNestedVirtualFolders(t *testing.T) { + u := getTestUser() + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + mappedPathCrypt := filepath.Join(os.TempDir(), "crypt") + folderNameCrypt := filepath.Base(mappedPathCrypt) + vdirCryptPath := "/vdir/crypt" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameCrypt, + FsConfig: vfs.Filesystem{ + Provider: vfs.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret(defaultPassword), + }, + }, + MappedPath: mappedPathCrypt, + }, + VirtualPath: vdirCryptPath, + }) + mappedPath := filepath.Join(os.TempDir(), "local") + folderName := filepath.Base(mappedPath) + vdirPath := "/vdir/local" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + }, + VirtualPath: vdirPath, + }) + mappedPathNested := filepath.Join(os.TempDir(), "nested") + folderNameNested := filepath.Base(mappedPathNested) + vdirNestedPath := "/vdir/crypt/nested" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameNested, + MappedPath: mappedPathNested, + }, + VirtualPath: vdirNestedPath, + QuotaFiles: -1, + QuotaSize: -1, + }) + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + client := getWebDavClient(sftpUser, true, nil) + assert.NoError(t, checkBasicFunc(client)) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + err = uploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = downloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + err = uploadFile(testFilePath, path.Join("/vdir", testFileName), testFileSize, client) + assert.NoError(t, err) + err = downloadFile(path.Join("/vdir", testFileName), localDownloadPath, testFileSize, client) + assert.NoError(t, err) + err = uploadFile(testFilePath, path.Join(vdirPath, testFileName), testFileSize, client) + assert.NoError(t, err) + err = downloadFile(path.Join(vdirPath, testFileName), localDownloadPath, testFileSize, client) + assert.NoError(t, err) + err = uploadFile(testFilePath, path.Join(vdirCryptPath, testFileName), testFileSize, client) + assert.NoError(t, err) + err = downloadFile(path.Join(vdirCryptPath, testFileName), localDownloadPath, testFileSize, client) + assert.NoError(t, err) + err = uploadFile(testFilePath, path.Join(vdirNestedPath, testFileName), testFileSize, client) + assert.NoError(t, err) + err = downloadFile(path.Join(vdirNestedPath, testFileName), localDownloadPath, testFileSize, client) + assert.NoError(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameNested}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathCrypt) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathNested) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + assert.Len(t, common.Connections.GetStats(), 0) +} + func checkBasicFunc(client *gowebdav.Client) error { err := client.Connect() if err != nil { @@ -2125,7 +2225,7 @@ func getTestUser() dataprovider.User { func getTestSFTPUser() dataprovider.User { u := getTestUser() u.Username = u.Username + "_sftp" - u.FsConfig.Provider = dataprovider.SFTPFilesystemProvider + u.FsConfig.Provider = vfs.SFTPFilesystemProvider u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr u.FsConfig.SFTPConfig.Username = defaultUsername u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword) @@ -2134,7 +2234,7 @@ func getTestSFTPUser() dataprovider.User { func getTestUserWithCryptFs() dataprovider.User { user := getTestUser() - user.FsConfig.Provider = dataprovider.CryptedFilesystemProvider + user.FsConfig.Provider = vfs.CryptedFilesystemProvider user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("testPassphrase") return user }