extend virtual folders support to all storage backends

Fixes #241
This commit is contained in:
Nicola Murino 2021-03-21 19:15:47 +01:00
parent 0286da2356
commit d6dc3a507e
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
70 changed files with 6825 additions and 3740 deletions

View file

@ -84,9 +84,9 @@ $ sftpgo portable
Please take a look at the usage below to customize the serving parameters`,
Run: func(cmd *cobra.Command, args []string) {
portableDir := directoryToServe
fsProvider := dataprovider.FilesystemProvider(portableFsProvider)
fsProvider := vfs.FilesystemProvider(portableFsProvider)
if !filepath.IsAbs(portableDir) {
if fsProvider == dataprovider.LocalFilesystemProvider {
if fsProvider == vfs.LocalFilesystemProvider {
portableDir, _ = filepath.Abs(portableDir)
} else {
portableDir = os.TempDir()
@ -95,7 +95,7 @@ Please take a look at the usage below to customize the serving parameters`,
permissions := make(map[string][]string)
permissions["/"] = portablePermissions
portableGCSCredentials := ""
if fsProvider == dataprovider.GCSFilesystemProvider && portableGCSCredentialsFile != "" {
if fsProvider == vfs.GCSFilesystemProvider && portableGCSCredentialsFile != "" {
contents, err := getFileContents(portableGCSCredentialsFile)
if err != nil {
fmt.Printf("Unable to get GCS credentials: %v\n", err)
@ -105,7 +105,7 @@ Please take a look at the usage below to customize the serving parameters`,
portableGCSAutoCredentials = 0
}
portableSFTPPrivateKey := ""
if fsProvider == dataprovider.SFTPFilesystemProvider && portableSFTPPrivateKeyPath != "" {
if fsProvider == vfs.SFTPFilesystemProvider && portableSFTPPrivateKeyPath != "" {
contents, err := getFileContents(portableSFTPPrivateKeyPath)
if err != nil {
fmt.Printf("Unable to get SFTP private key: %v\n", err)
@ -149,8 +149,8 @@ Please take a look at the usage below to customize the serving parameters`,
Permissions: permissions,
HomeDir: portableDir,
Status: 1,
FsConfig: dataprovider.Filesystem{
Provider: dataprovider.FilesystemProvider(portableFsProvider),
FsConfig: vfs.Filesystem{
Provider: vfs.FilesystemProvider(portableFsProvider),
S3Config: vfs.S3FsConfig{
Bucket: portableS3Bucket,
Region: portableS3Region,
@ -257,7 +257,7 @@ multicast DNS`)
advertised via multicast DNS, this
flag allows to put username/password
inside the advertised TXT record`)
portableCmd.Flags().IntVarP(&portableFsProvider, "fs-provider", "f", int(dataprovider.LocalFilesystemProvider), `0 => local filesystem
portableCmd.Flags().IntVarP(&portableFsProvider, "fs-provider", "f", int(vfs.LocalFilesystemProvider), `0 => local filesystem
1 => AWS S3 compatible
2 => Google Cloud Storage
3 => Azure Blob Storage

View file

@ -18,6 +18,7 @@ import (
"github.com/drakkan/sftpgo/httpclient"
"github.com/drakkan/sftpgo/logger"
"github.com/drakkan/sftpgo/utils"
"github.com/drakkan/sftpgo/vfs"
)
var (
@ -79,12 +80,12 @@ func newActionNotification(
var bucket, endpoint string
status := 1
if user.FsConfig.Provider == dataprovider.S3FilesystemProvider {
if user.FsConfig.Provider == vfs.S3FilesystemProvider {
bucket = user.FsConfig.S3Config.Bucket
endpoint = user.FsConfig.S3Config.Endpoint
} else if user.FsConfig.Provider == dataprovider.GCSFilesystemProvider {
} else if user.FsConfig.Provider == vfs.GCSFilesystemProvider {
bucket = user.FsConfig.GCSConfig.Bucket
} else if user.FsConfig.Provider == dataprovider.AzureBlobFilesystemProvider {
} else if user.FsConfig.Provider == vfs.AzureBlobFilesystemProvider {
bucket = user.FsConfig.AzBlobConfig.Container
if user.FsConfig.AzBlobConfig.SASURL != "" {
endpoint = user.FsConfig.AzBlobConfig.SASURL

View file

@ -19,7 +19,7 @@ func TestNewActionNotification(t *testing.T) {
user := &dataprovider.User{
Username: "username",
}
user.FsConfig.Provider = dataprovider.LocalFilesystemProvider
user.FsConfig.Provider = vfs.LocalFilesystemProvider
user.FsConfig.S3Config = vfs.S3FsConfig{
Bucket: "s3bucket",
Endpoint: "endpoint",
@ -38,19 +38,19 @@ func TestNewActionNotification(t *testing.T) {
assert.Equal(t, 0, len(a.Endpoint))
assert.Equal(t, 0, a.Status)
user.FsConfig.Provider = dataprovider.S3FilesystemProvider
user.FsConfig.Provider = vfs.S3FilesystemProvider
a = newActionNotification(user, operationDownload, "path", "target", "", ProtocolSSH, 123, nil)
assert.Equal(t, "s3bucket", a.Bucket)
assert.Equal(t, "endpoint", a.Endpoint)
assert.Equal(t, 1, a.Status)
user.FsConfig.Provider = dataprovider.GCSFilesystemProvider
user.FsConfig.Provider = vfs.GCSFilesystemProvider
a = newActionNotification(user, operationDownload, "path", "target", "", ProtocolSCP, 123, ErrQuotaExceeded)
assert.Equal(t, "gcsbucket", a.Bucket)
assert.Equal(t, 0, len(a.Endpoint))
assert.Equal(t, 2, a.Status)
user.FsConfig.Provider = dataprovider.AzureBlobFilesystemProvider
user.FsConfig.Provider = vfs.AzureBlobFilesystemProvider
a = newActionNotification(user, operationDownload, "path", "target", "", ProtocolSCP, 123, nil)
assert.Equal(t, "azcontainer", a.Bucket)
assert.Equal(t, "azsasurl", a.Endpoint)
@ -179,15 +179,15 @@ func TestPreDeleteAction(t *testing.T) {
}
user.Permissions = make(map[string][]string)
user.Permissions["/"] = []string{dataprovider.PermAny}
fs := vfs.NewOsFs("id", homeDir, nil)
c := NewBaseConnection("id", ProtocolSFTP, user, fs)
fs := vfs.NewOsFs("id", homeDir, "")
c := NewBaseConnection("id", ProtocolSFTP, user)
testfile := filepath.Join(user.HomeDir, "testfile")
err = os.WriteFile(testfile, []byte("test"), os.ModePerm)
assert.NoError(t, err)
info, err := os.Stat(testfile)
assert.NoError(t, err)
err = c.RemoveFile(testfile, "testfile", info)
err = c.RemoveFile(fs, testfile, "testfile", info)
assert.NoError(t, err)
assert.FileExists(t, testfile)

View file

@ -3,7 +3,6 @@ package common
import (
"fmt"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
@ -13,15 +12,11 @@ import (
"testing"
"time"
"github.com/rs/zerolog"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/drakkan/sftpgo/dataprovider"
"github.com/drakkan/sftpgo/httpclient"
"github.com/drakkan/sftpgo/kms"
"github.com/drakkan/sftpgo/logger"
"github.com/drakkan/sftpgo/utils"
"github.com/drakkan/sftpgo/vfs"
)
@ -29,29 +24,22 @@ import (
const (
logSenderTest = "common_test"
httpAddr = "127.0.0.1:9999"
httpProxyAddr = "127.0.0.1:7777"
configDir = ".."
osWindows = "windows"
userTestUsername = "common_test_username"
userTestPwd = "common_test_pwd"
)
type providerConf struct {
Config dataprovider.Config `json:"data_provider" mapstructure:"data_provider"`
}
type fakeConnection struct {
*BaseConnection
command string
}
func (c *fakeConnection) AddUser(user dataprovider.User) error {
fs, err := user.GetFilesystem(c.GetID())
_, err := user.GetFilesystem(c.GetID())
if err != nil {
return err
}
c.BaseConnection.User = user
c.BaseConnection.Fs = fs
return nil
}
@ -84,110 +72,6 @@ func (c *customNetConn) Close() error {
return c.Conn.Close()
}
func TestMain(m *testing.M) {
logfilePath := "common_test.log"
logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel)
viper.SetEnvPrefix("sftpgo")
replacer := strings.NewReplacer(".", "__")
viper.SetEnvKeyReplacer(replacer)
viper.SetConfigName("sftpgo")
viper.AutomaticEnv()
viper.AllowEmptyEnv(true)
driver, err := initializeDataprovider(-1)
if err != nil {
logger.WarnToConsole("error initializing data provider: %v", err)
os.Exit(1)
}
logger.InfoToConsole("Starting COMMON tests, provider: %v", driver)
err = Initialize(Configuration{})
if err != nil {
logger.WarnToConsole("error initializing common: %v", err)
os.Exit(1)
}
httpConfig := httpclient.Config{
Timeout: 5,
}
httpConfig.Initialize(configDir) //nolint:errcheck
go func() {
// start a test HTTP server to receive action notifications
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "OK\n")
})
http.HandleFunc("/404", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, "Not found\n")
})
if err := http.ListenAndServe(httpAddr, nil); err != nil {
logger.ErrorToConsole("could not start HTTP notification server: %v", err)
os.Exit(1)
}
}()
go func() {
Config.ProxyProtocol = 2
listener, err := net.Listen("tcp", httpProxyAddr)
if err != nil {
logger.ErrorToConsole("error creating listener for proxy protocol server: %v", err)
os.Exit(1)
}
proxyListener, err := Config.GetProxyListener(listener)
if err != nil {
logger.ErrorToConsole("error creating proxy protocol listener: %v", err)
os.Exit(1)
}
Config.ProxyProtocol = 0
s := &http.Server{}
if err := s.Serve(proxyListener); err != nil {
logger.ErrorToConsole("could not start HTTP proxy protocol server: %v", err)
os.Exit(1)
}
}()
waitTCPListening(httpAddr)
waitTCPListening(httpProxyAddr)
exitCode := m.Run()
os.Remove(logfilePath) //nolint:errcheck
os.Exit(exitCode)
}
func waitTCPListening(address string) {
for {
conn, err := net.Dial("tcp", address)
if err != nil {
logger.WarnToConsole("tcp server %v not listening: %v", address, err)
time.Sleep(100 * time.Millisecond)
continue
}
logger.InfoToConsole("tcp server %v now listening", address)
conn.Close()
break
}
}
func initializeDataprovider(trackQuota int) (string, error) {
configDir := ".."
viper.AddConfigPath(configDir)
if err := viper.ReadInConfig(); err != nil {
return "", err
}
var cfg providerConf
if err := viper.Unmarshal(&cfg); err != nil {
return "", err
}
if trackQuota >= 0 && trackQuota <= 2 {
cfg.Config.TrackQuota = trackQuota
}
return cfg.Config.Driver, dataprovider.Initialize(cfg.Config, configDir, true)
}
func closeDataprovider() error {
return dataprovider.Close()
}
func TestSSHConnections(t *testing.T) {
conn1, conn2 := net.Pipe()
now := time.Now()
@ -286,7 +170,7 @@ func TestMaxConnections(t *testing.T) {
Config.MaxTotalConnections = 1
assert.True(t, Connections.IsNewConnectionAllowed())
c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}, nil)
c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{})
fakeConn := &fakeConnection{
BaseConnection: c,
}
@ -324,7 +208,7 @@ func TestIdleConnections(t *testing.T) {
user := dataprovider.User{
Username: username,
}
c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, user, nil)
c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, user)
c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
fakeConn := &fakeConnection{
BaseConnection: c,
@ -336,7 +220,7 @@ func TestIdleConnections(t *testing.T) {
Connections.AddSSHConnection(sshConn1)
Connections.Add(fakeConn)
assert.Equal(t, Connections.GetActiveSessions(username), 1)
c = NewBaseConnection(sshConn2.id+"_1", ProtocolSSH, user, nil)
c = NewBaseConnection(sshConn2.id+"_1", ProtocolSSH, user)
fakeConn = &fakeConnection{
BaseConnection: c,
}
@ -344,7 +228,7 @@ func TestIdleConnections(t *testing.T) {
Connections.Add(fakeConn)
assert.Equal(t, Connections.GetActiveSessions(username), 2)
cFTP := NewBaseConnection("id2", ProtocolFTP, dataprovider.User{}, nil)
cFTP := NewBaseConnection("id2", ProtocolFTP, dataprovider.User{})
cFTP.lastActivity = time.Now().UnixNano()
fakeConn = &fakeConnection{
BaseConnection: cFTP,
@ -383,7 +267,7 @@ func TestIdleConnections(t *testing.T) {
}
func TestCloseConnection(t *testing.T) {
c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}, nil)
c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{})
fakeConn := &fakeConnection{
BaseConnection: c,
}
@ -399,7 +283,7 @@ func TestCloseConnection(t *testing.T) {
}
func TestSwapConnection(t *testing.T) {
c := NewBaseConnection("id", ProtocolFTP, dataprovider.User{}, nil)
c := NewBaseConnection("id", ProtocolFTP, dataprovider.User{})
fakeConn := &fakeConnection{
BaseConnection: c,
}
@ -409,7 +293,7 @@ func TestSwapConnection(t *testing.T) {
}
c = NewBaseConnection("id", ProtocolFTP, dataprovider.User{
Username: userTestUsername,
}, nil)
})
fakeConn = &fakeConnection{
BaseConnection: c,
}
@ -443,8 +327,8 @@ func TestConnectionStatus(t *testing.T) {
user := dataprovider.User{
Username: username,
}
fs := vfs.NewOsFs("", os.TempDir(), nil)
c1 := NewBaseConnection("id1", ProtocolSFTP, user, fs)
fs := vfs.NewOsFs("", os.TempDir(), "")
c1 := NewBaseConnection("id1", ProtocolSFTP, user)
fakeConn1 := &fakeConnection{
BaseConnection: c1,
}
@ -452,12 +336,12 @@ func TestConnectionStatus(t *testing.T) {
t1.BytesReceived = 123
t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs)
t2.BytesSent = 456
c2 := NewBaseConnection("id2", ProtocolSSH, user, nil)
c2 := NewBaseConnection("id2", ProtocolSSH, user)
fakeConn2 := &fakeConnection{
BaseConnection: c2,
command: "md5sum",
}
c3 := NewBaseConnection("id3", ProtocolWebDAV, user, nil)
c3 := NewBaseConnection("id3", ProtocolWebDAV, user)
fakeConn3 := &fakeConnection{
BaseConnection: c3,
command: "PROPFIND",
@ -565,15 +449,6 @@ func TestProxyProtocolVersion(t *testing.T) {
assert.Error(t, err)
}
func TestProxyProtocol(t *testing.T) {
httpClient := httpclient.GetHTTPClient()
resp, err := httpClient.Get(fmt.Sprintf("http://%v", httpProxyAddr))
if assert.NoError(t, err) {
defer resp.Body.Close()
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
}
}
func TestPostConnectHook(t *testing.T) {
Config.PostConnectHook = ""
@ -614,7 +489,7 @@ func TestPostConnectHook(t *testing.T) {
func TestCryptoConvertFileInfo(t *testing.T) {
name := "name"
fs, err := vfs.NewCryptFs("connID1", os.TempDir(), vfs.CryptFsConfig{Passphrase: kms.NewPlainSecret("secret")})
fs, err := vfs.NewCryptFs("connID1", os.TempDir(), "", vfs.CryptFsConfig{Passphrase: kms.NewPlainSecret("secret")})
require.NoError(t, err)
cryptFs := fs.(*vfs.CryptFs)
info := vfs.NewFileInfo(name, true, 48, time.Now(), false)
@ -649,4 +524,50 @@ func TestFolderCopy(t *testing.T) {
require.Equal(t, folder.UsedQuotaSize, folderCopy.UsedQuotaSize)
require.Equal(t, folder.UsedQuotaFiles, folderCopy.UsedQuotaFiles)
require.Equal(t, folder.LastQuotaUpdate, folderCopy.LastQuotaUpdate)
folder.FsConfig = vfs.Filesystem{
CryptConfig: vfs.CryptFsConfig{
Passphrase: kms.NewPlainSecret("crypto secret"),
},
}
folderCopy = folder.GetACopy()
folder.FsConfig.CryptConfig.Passphrase = kms.NewEmptySecret()
require.Len(t, folderCopy.Users, 1)
require.True(t, utils.IsStringInSlice("user3", folderCopy.Users))
require.Equal(t, int64(2), folderCopy.ID)
require.Equal(t, folder.Name, folderCopy.Name)
require.Equal(t, folder.MappedPath, folderCopy.MappedPath)
require.Equal(t, folder.UsedQuotaSize, folderCopy.UsedQuotaSize)
require.Equal(t, folder.UsedQuotaFiles, folderCopy.UsedQuotaFiles)
require.Equal(t, folder.LastQuotaUpdate, folderCopy.LastQuotaUpdate)
require.Equal(t, "crypto secret", folderCopy.FsConfig.CryptConfig.Passphrase.GetPayload())
}
func TestCachedFs(t *testing.T) {
user := dataprovider.User{
HomeDir: filepath.Clean(os.TempDir()),
}
conn := NewBaseConnection("id", ProtocolSFTP, user)
// changing the user should not affect the connection
user.HomeDir = filepath.Join(os.TempDir(), "temp")
err := os.Mkdir(user.HomeDir, os.ModePerm)
assert.NoError(t, err)
fs, err := user.GetFilesystem("")
assert.NoError(t, err)
p, err := fs.ResolvePath("/")
assert.NoError(t, err)
assert.Equal(t, user.GetHomeDir(), p)
_, p, err = conn.GetFsAndResolvedPath("/")
assert.NoError(t, err)
assert.Equal(t, filepath.Clean(os.TempDir()), p)
user.FsConfig.Provider = vfs.S3FilesystemProvider
_, err = user.GetFilesystem("")
assert.Error(t, err)
conn.User.FsConfig.Provider = vfs.S3FilesystemProvider
_, p, err = conn.GetFsAndResolvedPath("/")
assert.NoError(t, err)
assert.Equal(t, filepath.Clean(os.TempDir()), p)
err = os.Remove(user.HomeDir)
assert.NoError(t, err)
}

View file

@ -30,14 +30,13 @@ type BaseConnection struct {
// start time for this connection
startTime time.Time
protocol string
Fs vfs.Fs
sync.RWMutex
transferID uint64
activeTransfers []ActiveTransfer
}
// NewBaseConnection returns a new BaseConnection
func NewBaseConnection(id, protocol string, user dataprovider.User, fs vfs.Fs) *BaseConnection {
func NewBaseConnection(id, protocol string, user dataprovider.User) *BaseConnection {
connID := id
if utils.IsStringInSlice(protocol, supportedProtocols) {
connID = fmt.Sprintf("%v_%v", protocol, id)
@ -47,7 +46,6 @@ func NewBaseConnection(id, protocol string, user dataprovider.User, fs vfs.Fs) *
User: user,
startTime: time.Now(),
protocol: protocol,
Fs: fs,
lastActivity: time.Now().UnixNano(),
transferID: 0,
}
@ -103,10 +101,7 @@ func (c *BaseConnection) GetLastActivity() time.Time {
// CloseFS closes the underlying fs
func (c *BaseConnection) CloseFS() error {
if c.Fs != nil {
return c.Fs.Close()
}
return nil
return c.User.CloseFs()
}
// AddTransfer associates a new transfer to this connection
@ -207,21 +202,25 @@ func (c *BaseConnection) truncateOpenHandle(fsPath string, size int64) (int64, e
return 0, errNoTransfer
}
// ListDir reads the directory named by fsPath and returns a list of directory entries
func (c *BaseConnection) ListDir(fsPath, virtualPath string) ([]os.FileInfo, error) {
// ListDir reads the directory matching virtualPath and returns a list of directory entries
func (c *BaseConnection) ListDir(virtualPath string) ([]os.FileInfo, error) {
if !c.User.HasPerm(dataprovider.PermListItems, virtualPath) {
return nil, c.GetPermissionDeniedError()
}
files, err := c.Fs.ReadDir(fsPath)
fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath)
if err != nil {
return nil, err
}
files, err := fs.ReadDir(fsPath)
if err != nil {
c.Log(logger.LevelWarn, "error listing directory: %+v", err)
return nil, c.GetFsError(err)
return nil, c.GetFsError(fs, err)
}
return c.User.AddVirtualDirs(files, virtualPath), nil
}
// CreateDir creates a new directory at the specified fsPath
func (c *BaseConnection) CreateDir(fsPath, virtualPath string) error {
func (c *BaseConnection) CreateDir(virtualPath string) error {
if !c.User.HasPerm(dataprovider.PermCreateDirs, path.Dir(virtualPath)) {
return c.GetPermissionDeniedError()
}
@ -229,42 +228,47 @@ func (c *BaseConnection) CreateDir(fsPath, virtualPath string) error {
c.Log(logger.LevelWarn, "mkdir not allowed %#v is a virtual folder", virtualPath)
return c.GetPermissionDeniedError()
}
if err := c.Fs.Mkdir(fsPath); err != nil {
c.Log(logger.LevelWarn, "error creating dir: %#v error: %+v", fsPath, err)
return c.GetFsError(err)
fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath)
if err != nil {
return err
}
vfs.SetPathPermissions(c.Fs, fsPath, c.User.GetUID(), c.User.GetGID())
if err := fs.Mkdir(fsPath); err != nil {
c.Log(logger.LevelWarn, "error creating dir: %#v error: %+v", fsPath, err)
return c.GetFsError(fs, err)
}
vfs.SetPathPermissions(fs, fsPath, c.User.GetUID(), c.User.GetGID())
logger.CommandLog(mkdirLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1)
return nil
}
// IsRemoveFileAllowed returns an error if removing this file is not allowed
func (c *BaseConnection) IsRemoveFileAllowed(fsPath, virtualPath string) error {
func (c *BaseConnection) IsRemoveFileAllowed(virtualPath string) error {
if !c.User.HasPerm(dataprovider.PermDelete, path.Dir(virtualPath)) {
return c.GetPermissionDeniedError()
}
if !c.User.IsFileAllowed(virtualPath) {
c.Log(logger.LevelDebug, "removing file %#v is not allowed", fsPath)
c.Log(logger.LevelDebug, "removing file %#v is not allowed", virtualPath)
return c.GetPermissionDeniedError()
}
return nil
}
// RemoveFile removes a file at the specified fsPath
func (c *BaseConnection) RemoveFile(fsPath, virtualPath string, info os.FileInfo) error {
if err := c.IsRemoveFileAllowed(fsPath, virtualPath); err != nil {
func (c *BaseConnection) RemoveFile(fs vfs.Fs, fsPath, virtualPath string, info os.FileInfo) error {
if err := c.IsRemoveFileAllowed(virtualPath); err != nil {
return err
}
size := info.Size()
action := newActionNotification(&c.User, operationPreDelete, fsPath, "", "", c.protocol, size, nil)
actionErr := actionHandler.Handle(action)
if actionErr == nil {
c.Log(logger.LevelDebug, "remove for path %#v handled by pre-delete action", fsPath)
} else {
if err := c.Fs.Remove(fsPath, false); err != nil {
if err := fs.Remove(fsPath, false); err != nil {
c.Log(logger.LevelWarn, "failed to remove a file/symlink %#v: %+v", fsPath, err)
return c.GetFsError(err)
return c.GetFsError(fs, err)
}
}
@ -288,8 +292,8 @@ func (c *BaseConnection) RemoveFile(fsPath, virtualPath string, info os.FileInfo
}
// IsRemoveDirAllowed returns an error if removing this directory is not allowed
func (c *BaseConnection) IsRemoveDirAllowed(fsPath, virtualPath string) error {
if c.Fs.GetRelativePath(fsPath) == "/" {
func (c *BaseConnection) IsRemoveDirAllowed(fs vfs.Fs, fsPath, virtualPath string) error {
if fs.GetRelativePath(fsPath) == "/" {
c.Log(logger.LevelWarn, "removing root dir is not allowed")
return c.GetPermissionDeniedError()
}
@ -312,54 +316,57 @@ func (c *BaseConnection) IsRemoveDirAllowed(fsPath, virtualPath string) error {
}
// RemoveDir removes a directory at the specified fsPath
func (c *BaseConnection) RemoveDir(fsPath, virtualPath string) error {
if err := c.IsRemoveDirAllowed(fsPath, virtualPath); err != nil {
func (c *BaseConnection) RemoveDir(virtualPath string) error {
fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath)
if err != nil {
return err
}
if err := c.IsRemoveDirAllowed(fs, fsPath, virtualPath); err != nil {
return err
}
var fi os.FileInfo
var err error
if fi, err = c.Fs.Lstat(fsPath); err != nil {
if fi, err = fs.Lstat(fsPath); err != nil {
// see #149
if c.Fs.IsNotExist(err) && c.Fs.HasVirtualFolders() {
if fs.IsNotExist(err) && fs.HasVirtualFolders() {
return nil
}
c.Log(logger.LevelWarn, "failed to remove a dir %#v: stat error: %+v", fsPath, err)
return c.GetFsError(err)
return c.GetFsError(fs, err)
}
if !fi.IsDir() || fi.Mode()&os.ModeSymlink != 0 {
c.Log(logger.LevelDebug, "cannot remove %#v is not a directory", fsPath)
return c.GetGenericError(nil)
}
if err := c.Fs.Remove(fsPath, true); err != nil {
if err := fs.Remove(fsPath, true); err != nil {
c.Log(logger.LevelWarn, "failed to remove directory %#v: %+v", fsPath, err)
return c.GetFsError(err)
return c.GetFsError(fs, err)
}
logger.CommandLog(rmdirLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1)
return nil
}
// Rename renames (moves) fsSourcePath to fsTargetPath
func (c *BaseConnection) Rename(fsSourcePath, fsTargetPath, virtualSourcePath, virtualTargetPath string) error {
if c.User.IsMappedPath(fsSourcePath) {
c.Log(logger.LevelWarn, "renaming a directory mapped as virtual folder is not allowed: %#v", fsSourcePath)
return c.GetPermissionDeniedError()
}
if c.User.IsMappedPath(fsTargetPath) {
c.Log(logger.LevelWarn, "renaming to a directory mapped as virtual folder is not allowed: %#v", fsTargetPath)
return c.GetPermissionDeniedError()
}
srcInfo, err := c.Fs.Lstat(fsSourcePath)
// Rename renames (moves) virtualSourcePath to virtualTargetPath
func (c *BaseConnection) Rename(virtualSourcePath, virtualTargetPath string) error {
fsSrc, fsSourcePath, err := c.GetFsAndResolvedPath(virtualSourcePath)
if err != nil {
return c.GetFsError(err)
return err
}
if !c.isRenamePermitted(fsSourcePath, virtualSourcePath, virtualTargetPath, srcInfo) {
fsDst, fsTargetPath, err := c.GetFsAndResolvedPath(virtualTargetPath)
if err != nil {
return err
}
srcInfo, err := fsSrc.Lstat(fsSourcePath)
if err != nil {
return c.GetFsError(fsSrc, err)
}
if !c.isRenamePermitted(fsSrc, fsDst, fsSourcePath, fsTargetPath, virtualSourcePath, virtualTargetPath, srcInfo) {
return c.GetPermissionDeniedError()
}
initialSize := int64(-1)
if dstInfo, err := c.Fs.Lstat(fsTargetPath); err == nil {
if dstInfo, err := fsDst.Lstat(fsTargetPath); err == nil {
if dstInfo.IsDir() {
c.Log(logger.LevelWarn, "attempted to rename %#v overwriting an existing directory %#v",
fsSourcePath, fsTargetPath)
@ -370,8 +377,8 @@ func (c *BaseConnection) Rename(fsSourcePath, fsTargetPath, virtualSourcePath, v
initialSize = dstInfo.Size()
}
if !c.User.HasPerm(dataprovider.PermOverwrite, path.Dir(virtualTargetPath)) {
c.Log(logger.LevelDebug, "renaming is not allowed, %#v -> %#v. Target exists but the user "+
"has no overwrite permission", virtualSourcePath, virtualTargetPath)
c.Log(logger.LevelDebug, "renaming %#v -> %#v is not allowed. Target exists but the user %#v"+
"has no overwrite permission", virtualSourcePath, virtualTargetPath, c.User.Username)
return c.GetPermissionDeniedError()
}
}
@ -381,22 +388,21 @@ func (c *BaseConnection) Rename(fsSourcePath, fsTargetPath, virtualSourcePath, v
virtualSourcePath)
return c.GetOpUnsupportedError()
}
if err = c.checkRecursiveRenameDirPermissions(fsSourcePath, fsTargetPath); err != nil {
if err = c.checkRecursiveRenameDirPermissions(fsSrc, fsDst, fsSourcePath, fsTargetPath); err != nil {
c.Log(logger.LevelDebug, "error checking recursive permissions before renaming %#v: %+v", fsSourcePath, err)
return c.GetFsError(err)
return err
}
}
if !c.hasSpaceForRename(virtualSourcePath, virtualTargetPath, initialSize, fsSourcePath) {
if !c.hasSpaceForRename(fsSrc, virtualSourcePath, virtualTargetPath, initialSize, fsSourcePath) {
c.Log(logger.LevelInfo, "denying cross rename due to space limit")
return c.GetGenericError(ErrQuotaExceeded)
}
if err := c.Fs.Rename(fsSourcePath, fsTargetPath); err != nil {
if err := fsSrc.Rename(fsSourcePath, fsTargetPath); err != nil {
c.Log(logger.LevelWarn, "failed to rename %#v -> %#v: %+v", fsSourcePath, fsTargetPath, err)
return c.GetFsError(err)
}
if dataprovider.GetQuotaTracking() > 0 {
c.updateQuotaAfterRename(virtualSourcePath, virtualTargetPath, fsTargetPath, initialSize) //nolint:errcheck
return c.GetFsError(fsSrc, err)
}
vfs.SetPathPermissions(fsDst, fsTargetPath, c.User.GetUID(), c.User.GetGID())
c.updateQuotaAfterRename(fsDst, virtualSourcePath, virtualTargetPath, fsTargetPath, initialSize) //nolint:errcheck
logger.CommandLog(renameLogSender, fsSourcePath, fsTargetPath, c.User.Username, "", c.ID, c.protocol, -1, -1,
"", "", "", -1)
action := newActionNotification(&c.User, operationRename, fsSourcePath, fsTargetPath, "", c.protocol, 0, nil)
@ -407,41 +413,42 @@ func (c *BaseConnection) Rename(fsSourcePath, fsTargetPath, virtualSourcePath, v
}
// CreateSymlink creates fsTargetPath as a symbolic link to fsSourcePath
func (c *BaseConnection) CreateSymlink(fsSourcePath, fsTargetPath, virtualSourcePath, virtualTargetPath string) error {
if c.Fs.GetRelativePath(fsSourcePath) == "/" {
func (c *BaseConnection) CreateSymlink(virtualSourcePath, virtualTargetPath string) error {
if c.isCrossFoldersRequest(virtualSourcePath, virtualTargetPath) {
c.Log(logger.LevelWarn, "cross folder symlink is not supported, src: %v dst: %v", virtualSourcePath, virtualTargetPath)
return c.GetOpUnsupportedError()
}
// we cannot have a cross folder request here so only one fs is enough
fs, fsSourcePath, err := c.GetFsAndResolvedPath(virtualSourcePath)
if err != nil {
return err
}
fsTargetPath, err := fs.ResolvePath(virtualTargetPath)
if err != nil {
return c.GetFsError(fs, err)
}
if fs.GetRelativePath(fsSourcePath) == "/" {
c.Log(logger.LevelWarn, "symlinking root dir is not allowed")
return c.GetPermissionDeniedError()
}
if c.User.IsVirtualFolder(virtualTargetPath) {
c.Log(logger.LevelWarn, "symlinking a virtual folder is not allowed")
if fs.GetRelativePath(fsTargetPath) == "/" {
c.Log(logger.LevelWarn, "symlinking to root dir is not allowed")
return c.GetPermissionDeniedError()
}
if !c.User.HasPerm(dataprovider.PermCreateSymlinks, path.Dir(virtualTargetPath)) {
return c.GetPermissionDeniedError()
}
if c.isCrossFoldersRequest(virtualSourcePath, virtualTargetPath) {
c.Log(logger.LevelWarn, "cross folder symlink is not supported, src: %v dst: %v", virtualSourcePath, virtualTargetPath)
return c.GetOpUnsupportedError()
}
if c.User.IsMappedPath(fsSourcePath) {
c.Log(logger.LevelWarn, "symlinking a directory mapped as virtual folder is not allowed: %#v", fsSourcePath)
return c.GetPermissionDeniedError()
}
if c.User.IsMappedPath(fsTargetPath) {
c.Log(logger.LevelWarn, "symlinking to a directory mapped as virtual folder is not allowed: %#v", fsTargetPath)
return c.GetPermissionDeniedError()
}
if err := c.Fs.Symlink(fsSourcePath, fsTargetPath); err != nil {
if err := fs.Symlink(fsSourcePath, fsTargetPath); err != nil {
c.Log(logger.LevelWarn, "failed to create symlink %#v -> %#v: %+v", fsSourcePath, fsTargetPath, err)
return c.GetFsError(err)
return c.GetFsError(fs, err)
}
logger.CommandLog(symlinkLogSender, fsSourcePath, fsTargetPath, c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1)
return nil
}
func (c *BaseConnection) getPathForSetStatPerms(fsPath, virtualPath string) string {
func (c *BaseConnection) getPathForSetStatPerms(fs vfs.Fs, fsPath, virtualPath string) string {
pathForPerms := virtualPath
if fi, err := c.Fs.Lstat(fsPath); err == nil {
if fi, err := fs.Lstat(fsPath); err == nil {
if fi.IsDir() {
pathForPerms = path.Dir(virtualPath)
}
@ -450,74 +457,89 @@ func (c *BaseConnection) getPathForSetStatPerms(fsPath, virtualPath string) stri
}
// DoStat execute a Stat if mode = 0, Lstat if mode = 1
func (c *BaseConnection) DoStat(fsPath string, mode int) (os.FileInfo, error) {
func (c *BaseConnection) DoStat(virtualPath string, mode int) (os.FileInfo, error) {
// for some vfs we don't create intermediary folders so we cannot simply check
// if virtualPath is a virtual folder
vfolders := c.User.GetVirtualFoldersInPath(path.Dir(virtualPath))
if _, ok := vfolders[virtualPath]; ok {
return vfs.NewFileInfo(virtualPath, true, 0, time.Now(), false), nil
}
var info os.FileInfo
var err error
fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath)
if err != nil {
return info, err
}
if mode == 1 {
info, err = c.Fs.Lstat(c.getRealFsPath(fsPath))
info, err = fs.Lstat(c.getRealFsPath(fsPath))
} else {
info, err = c.Fs.Stat(c.getRealFsPath(fsPath))
info, err = fs.Stat(c.getRealFsPath(fsPath))
}
if err == nil && vfs.IsCryptOsFs(c.Fs) {
info = c.Fs.(*vfs.CryptFs).ConvertFileInfo(info)
if err != nil {
return info, c.GetFsError(fs, err)
}
return info, err
if vfs.IsCryptOsFs(fs) {
info = fs.(*vfs.CryptFs).ConvertFileInfo(info)
}
return info, nil
}
func (c *BaseConnection) ignoreSetStat() bool {
func (c *BaseConnection) ignoreSetStat(fs vfs.Fs) bool {
if Config.SetstatMode == 1 {
return true
}
if Config.SetstatMode == 2 && !vfs.IsLocalOrSFTPFs(c.Fs) {
if Config.SetstatMode == 2 && !vfs.IsLocalOrSFTPFs(fs) && !vfs.IsCryptOsFs(fs) {
return true
}
return false
}
func (c *BaseConnection) handleChmod(fsPath, pathForPerms string, attributes *StatAttributes) error {
func (c *BaseConnection) handleChmod(fs vfs.Fs, fsPath, pathForPerms string, attributes *StatAttributes) error {
if !c.User.HasPerm(dataprovider.PermChmod, pathForPerms) {
return c.GetPermissionDeniedError()
}
if c.ignoreSetStat() {
if c.ignoreSetStat(fs) {
return nil
}
if err := c.Fs.Chmod(c.getRealFsPath(fsPath), attributes.Mode); err != nil {
if err := fs.Chmod(c.getRealFsPath(fsPath), attributes.Mode); err != nil {
c.Log(logger.LevelWarn, "failed to chmod path %#v, mode: %v, err: %+v", fsPath, attributes.Mode.String(), err)
return c.GetFsError(err)
return c.GetFsError(fs, err)
}
logger.CommandLog(chmodLogSender, fsPath, "", c.User.Username, attributes.Mode.String(), c.ID, c.protocol,
-1, -1, "", "", "", -1)
return nil
}
func (c *BaseConnection) handleChown(fsPath, pathForPerms string, attributes *StatAttributes) error {
func (c *BaseConnection) handleChown(fs vfs.Fs, fsPath, pathForPerms string, attributes *StatAttributes) error {
if !c.User.HasPerm(dataprovider.PermChown, pathForPerms) {
return c.GetPermissionDeniedError()
}
if c.ignoreSetStat() {
if c.ignoreSetStat(fs) {
return nil
}
if err := c.Fs.Chown(c.getRealFsPath(fsPath), attributes.UID, attributes.GID); err != nil {
if err := fs.Chown(c.getRealFsPath(fsPath), attributes.UID, attributes.GID); err != nil {
c.Log(logger.LevelWarn, "failed to chown path %#v, uid: %v, gid: %v, err: %+v", fsPath, attributes.UID,
attributes.GID, err)
return c.GetFsError(err)
return c.GetFsError(fs, err)
}
logger.CommandLog(chownLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, attributes.UID, attributes.GID,
"", "", "", -1)
return nil
}
func (c *BaseConnection) handleChtimes(fsPath, pathForPerms string, attributes *StatAttributes) error {
func (c *BaseConnection) handleChtimes(fs vfs.Fs, fsPath, pathForPerms string, attributes *StatAttributes) error {
if !c.User.HasPerm(dataprovider.PermChtimes, pathForPerms) {
return c.GetPermissionDeniedError()
}
if c.ignoreSetStat() {
if c.ignoreSetStat(fs) {
return nil
}
if err := c.Fs.Chtimes(c.getRealFsPath(fsPath), attributes.Atime, attributes.Mtime); err != nil {
if err := fs.Chtimes(c.getRealFsPath(fsPath), attributes.Atime, attributes.Mtime); err != nil {
c.Log(logger.LevelWarn, "failed to chtimes for path %#v, access time: %v, modification time: %v, err: %+v",
fsPath, attributes.Atime, attributes.Mtime, err)
return c.GetFsError(err)
return c.GetFsError(fs, err)
}
accessTimeString := attributes.Atime.Format(chtimesFormat)
modificationTimeString := attributes.Mtime.Format(chtimesFormat)
@ -527,19 +549,23 @@ func (c *BaseConnection) handleChtimes(fsPath, pathForPerms string, attributes *
}
// SetStat set StatAttributes for the specified fsPath
func (c *BaseConnection) SetStat(fsPath, virtualPath string, attributes *StatAttributes) error {
pathForPerms := c.getPathForSetStatPerms(fsPath, virtualPath)
func (c *BaseConnection) SetStat(virtualPath string, attributes *StatAttributes) error {
fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath)
if err != nil {
return err
}
pathForPerms := c.getPathForSetStatPerms(fs, fsPath, virtualPath)
if attributes.Flags&StatAttrPerms != 0 {
return c.handleChmod(fsPath, pathForPerms, attributes)
return c.handleChmod(fs, fsPath, pathForPerms, attributes)
}
if attributes.Flags&StatAttrUIDGID != 0 {
return c.handleChown(fsPath, pathForPerms, attributes)
return c.handleChown(fs, fsPath, pathForPerms, attributes)
}
if attributes.Flags&StatAttrTimes != 0 {
return c.handleChtimes(fsPath, pathForPerms, attributes)
return c.handleChtimes(fs, fsPath, pathForPerms, attributes)
}
if attributes.Flags&StatAttrSize != 0 {
@ -547,9 +573,9 @@ func (c *BaseConnection) SetStat(fsPath, virtualPath string, attributes *StatAtt
return c.GetPermissionDeniedError()
}
if err := c.truncateFile(fsPath, virtualPath, attributes.Size); err != nil {
if err := c.truncateFile(fs, fsPath, virtualPath, attributes.Size); err != nil {
c.Log(logger.LevelWarn, "failed to truncate path %#v, size: %v, err: %+v", fsPath, attributes.Size, err)
return c.GetFsError(err)
return c.GetFsError(fs, err)
}
logger.CommandLog(truncateLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", attributes.Size)
}
@ -557,7 +583,7 @@ func (c *BaseConnection) SetStat(fsPath, virtualPath string, attributes *StatAtt
return nil
}
func (c *BaseConnection) truncateFile(fsPath, virtualPath string, size int64) error {
func (c *BaseConnection) truncateFile(fs vfs.Fs, fsPath, virtualPath string, size int64) error {
// check first if we have an open transfer for the given path and try to truncate the file already opened
// if we found no transfer we truncate by path.
var initialSize int64
@ -566,14 +592,14 @@ func (c *BaseConnection) truncateFile(fsPath, virtualPath string, size int64) er
if err == errNoTransfer {
c.Log(logger.LevelDebug, "file path %#v not found in active transfers, execute trucate by path", fsPath)
var info os.FileInfo
info, err = c.Fs.Stat(fsPath)
info, err = fs.Stat(fsPath)
if err != nil {
return err
}
initialSize = info.Size()
err = c.Fs.Truncate(fsPath, size)
err = fs.Truncate(fsPath, size)
}
if err == nil && vfs.IsLocalOrSFTPFs(c.Fs) {
if err == nil && vfs.IsLocalOrSFTPFs(fs) {
sizeDiff := initialSize - size
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
if err == nil {
@ -588,20 +614,20 @@ func (c *BaseConnection) truncateFile(fsPath, virtualPath string, size int64) er
return err
}
func (c *BaseConnection) checkRecursiveRenameDirPermissions(sourcePath, targetPath string) error {
func (c *BaseConnection) checkRecursiveRenameDirPermissions(fsSrc, fsDst vfs.Fs, sourcePath, targetPath string) error {
dstPerms := []string{
dataprovider.PermCreateDirs,
dataprovider.PermUpload,
dataprovider.PermCreateSymlinks,
}
err := c.Fs.Walk(sourcePath, func(walkedPath string, info os.FileInfo, err error) error {
err := fsSrc.Walk(sourcePath, func(walkedPath string, info os.FileInfo, err error) error {
if err != nil {
return err
return c.GetFsError(fsSrc, err)
}
dstPath := strings.Replace(walkedPath, sourcePath, targetPath, 1)
virtualSrcPath := c.Fs.GetRelativePath(walkedPath)
virtualDstPath := c.Fs.GetRelativePath(dstPath)
virtualSrcPath := fsSrc.GetRelativePath(walkedPath)
virtualDstPath := fsDst.GetRelativePath(dstPath)
// walk scans the directory tree in order, checking the parent directory permissions we are sure that all contents
// inside the parent path was checked. If the current dir has no subdirs with defined permissions inside it
// and it has all the possible permissions we can stop scanning
@ -615,10 +641,10 @@ func (c *BaseConnection) checkRecursiveRenameDirPermissions(sourcePath, targetPa
return ErrSkipPermissionsCheck
}
}
if !c.isRenamePermitted(walkedPath, virtualSrcPath, virtualDstPath, info) {
if !c.isRenamePermitted(fsSrc, fsDst, walkedPath, dstPath, virtualSrcPath, virtualDstPath, info) {
c.Log(logger.LevelInfo, "rename %#v -> %#v is not allowed, virtual destination path: %#v",
walkedPath, dstPath, virtualDstPath)
return os.ErrPermission
return c.GetPermissionDeniedError()
}
return nil
})
@ -628,22 +654,7 @@ func (c *BaseConnection) checkRecursiveRenameDirPermissions(sourcePath, targetPa
return err
}
func (c *BaseConnection) isRenamePermitted(fsSourcePath, virtualSourcePath, virtualTargetPath string, fi os.FileInfo) bool {
if c.Fs.GetRelativePath(fsSourcePath) == "/" {
c.Log(logger.LevelWarn, "renaming root dir is not allowed")
return false
}
if c.User.IsVirtualFolder(virtualSourcePath) || c.User.IsVirtualFolder(virtualTargetPath) {
c.Log(logger.LevelWarn, "renaming a virtual folder is not allowed")
return false
}
if !c.User.IsFileAllowed(virtualSourcePath) || !c.User.IsFileAllowed(virtualTargetPath) {
if fi != nil && fi.Mode().IsRegular() {
c.Log(logger.LevelDebug, "renaming file is not allowed, source: %#v target: %#v",
virtualSourcePath, virtualTargetPath)
return false
}
}
func (c *BaseConnection) hasRenamePerms(virtualSourcePath, virtualTargetPath string, fi os.FileInfo) bool {
if c.User.HasPerm(dataprovider.PermRename, path.Dir(virtualSourcePath)) &&
c.User.HasPerm(dataprovider.PermRename, path.Dir(virtualTargetPath)) {
return true
@ -661,7 +672,39 @@ func (c *BaseConnection) isRenamePermitted(fsSourcePath, virtualSourcePath, virt
return c.User.HasPerm(dataprovider.PermUpload, path.Dir(virtualTargetPath))
}
func (c *BaseConnection) hasSpaceForRename(virtualSourcePath, virtualTargetPath string, initialSize int64,
func (c *BaseConnection) isRenamePermitted(fsSrc, fsDst vfs.Fs, fsSourcePath, fsTargetPath, virtualSourcePath, virtualTargetPath string, fi os.FileInfo) bool {
if !c.isLocalOrSameFolderRename(virtualSourcePath, virtualTargetPath) {
c.Log(logger.LevelInfo, "rename %#v->%#v is not allowed: the paths must be local or on the same virtual folder",
virtualSourcePath, virtualTargetPath)
return false
}
if c.User.IsMappedPath(fsSourcePath) && vfs.IsLocalOrCryptoFs(fsSrc) {
c.Log(logger.LevelWarn, "renaming a directory mapped as virtual folder is not allowed: %#v", fsSourcePath)
return false
}
if c.User.IsMappedPath(fsTargetPath) && vfs.IsLocalOrCryptoFs(fsDst) {
c.Log(logger.LevelWarn, "renaming to a directory mapped as virtual folder is not allowed: %#v", fsTargetPath)
return false
}
if fsSrc.GetRelativePath(fsSourcePath) == "/" {
c.Log(logger.LevelWarn, "renaming root dir is not allowed")
return false
}
if c.User.IsVirtualFolder(virtualSourcePath) || c.User.IsVirtualFolder(virtualTargetPath) {
c.Log(logger.LevelWarn, "renaming a virtual folder is not allowed")
return false
}
if !c.User.IsFileAllowed(virtualSourcePath) || !c.User.IsFileAllowed(virtualTargetPath) {
if fi != nil && fi.Mode().IsRegular() {
c.Log(logger.LevelDebug, "renaming file is not allowed, source: %#v target: %#v",
virtualSourcePath, virtualTargetPath)
return false
}
}
return c.hasRenamePerms(virtualSourcePath, virtualTargetPath, fi)
}
func (c *BaseConnection) hasSpaceForRename(fs vfs.Fs, virtualSourcePath, virtualTargetPath string, initialSize int64,
fsSourcePath string) bool {
if dataprovider.GetQuotaTracking() == 0 {
return true
@ -674,7 +717,7 @@ func (c *BaseConnection) hasSpaceForRename(virtualSourcePath, virtualTargetPath
}
if errSrc == nil && errDst == nil {
// rename between virtual folders
if sourceFolder.MappedPath == dstFolder.MappedPath {
if sourceFolder.Name == dstFolder.Name {
// rename inside the same virtual folder
return true
}
@ -684,16 +727,16 @@ func (c *BaseConnection) hasSpaceForRename(virtualSourcePath, virtualTargetPath
return true
}
quotaResult := c.HasSpace(true, false, virtualTargetPath)
return c.hasSpaceForCrossRename(quotaResult, initialSize, fsSourcePath)
return c.hasSpaceForCrossRename(fs, quotaResult, initialSize, fsSourcePath)
}
// hasSpaceForCrossRename checks the quota after a rename between different folders
func (c *BaseConnection) hasSpaceForCrossRename(quotaResult vfs.QuotaCheckResult, initialSize int64, sourcePath string) bool {
func (c *BaseConnection) hasSpaceForCrossRename(fs vfs.Fs, quotaResult vfs.QuotaCheckResult, initialSize int64, sourcePath string) bool {
if !quotaResult.HasSpace && initialSize == -1 {
// we are over quota and this is not a file replace
return false
}
fi, err := c.Fs.Lstat(sourcePath)
fi, err := fs.Lstat(sourcePath)
if err != nil {
c.Log(logger.LevelWarn, "cross rename denied, stat error for path %#v: %v", sourcePath, err)
return false
@ -708,7 +751,7 @@ func (c *BaseConnection) hasSpaceForCrossRename(quotaResult vfs.QuotaCheckResult
filesDiff = 0
}
} else if fi.IsDir() {
filesDiff, sizeDiff, err = c.Fs.GetDirSize(sourcePath)
filesDiff, sizeDiff, err = fs.GetDirSize(sourcePath)
if err != nil {
c.Log(logger.LevelWarn, "cross rename denied, error getting size for directory %#v: %v", sourcePath, err)
return false
@ -745,11 +788,11 @@ func (c *BaseConnection) hasSpaceForCrossRename(quotaResult vfs.QuotaCheckResult
// GetMaxWriteSize returns the allowed size for an upload or an error
// if no enough size is available for a resume/append
func (c *BaseConnection) GetMaxWriteSize(quotaResult vfs.QuotaCheckResult, isResume bool, fileSize int64) (int64, error) {
func (c *BaseConnection) GetMaxWriteSize(quotaResult vfs.QuotaCheckResult, isResume bool, fileSize int64, isUploadResumeSupported bool) (int64, error) {
maxWriteSize := quotaResult.GetRemainingSize()
if isResume {
if !c.Fs.IsUploadResumeSupported() {
if !isUploadResumeSupported {
return 0, c.GetOpUnsupportedError()
}
if c.User.Filters.MaxUploadFileSize > 0 && c.User.Filters.MaxUploadFileSize <= fileSize {
@ -823,6 +866,40 @@ func (c *BaseConnection) HasSpace(checkFiles, getUsage bool, requestPath string)
return result
}
// returns true if this is a rename on the same fs or local virtual folders
func (c *BaseConnection) isLocalOrSameFolderRename(virtualSourcePath, virtualTargetPath string) bool {
sourceFolder, errSrc := c.User.GetVirtualFolderForPath(virtualSourcePath)
dstFolder, errDst := c.User.GetVirtualFolderForPath(virtualTargetPath)
if errSrc != nil && errDst != nil {
return true
}
if errSrc == nil && errDst == nil {
if sourceFolder.Name == dstFolder.Name {
return true
}
// we have different folders, only local fs is supported
if sourceFolder.FsConfig.Provider == vfs.LocalFilesystemProvider &&
dstFolder.FsConfig.Provider == vfs.LocalFilesystemProvider {
return true
}
return false
}
if c.User.FsConfig.Provider != vfs.LocalFilesystemProvider {
return false
}
if errSrc == nil {
if sourceFolder.FsConfig.Provider == vfs.LocalFilesystemProvider {
return true
}
}
if errDst == nil {
if dstFolder.FsConfig.Provider == vfs.LocalFilesystemProvider {
return true
}
}
return false
}
func (c *BaseConnection) isCrossFoldersRequest(virtualSourcePath, virtualTargetPath string) bool {
sourceFolder, errSrc := c.User.GetVirtualFolderForPath(virtualSourcePath)
dstFolder, errDst := c.User.GetVirtualFolderForPath(virtualTargetPath)
@ -830,14 +907,14 @@ func (c *BaseConnection) isCrossFoldersRequest(virtualSourcePath, virtualTargetP
return false
}
if errSrc == nil && errDst == nil {
return sourceFolder.MappedPath != dstFolder.MappedPath
return sourceFolder.Name != dstFolder.Name
}
return true
}
func (c *BaseConnection) updateQuotaMoveBetweenVFolders(sourceFolder, dstFolder *vfs.VirtualFolder, initialSize,
filesSize int64, numFiles int) {
if sourceFolder.MappedPath == dstFolder.MappedPath {
if sourceFolder.Name == dstFolder.Name {
// both files are inside the same virtual folder
if initialSize != -1 {
dataprovider.UpdateVirtualFolderQuota(&dstFolder.BaseVirtualFolder, -numFiles, -initialSize, false) //nolint:errcheck
@ -897,7 +974,10 @@ func (c *BaseConnection) updateQuotaMoveToVFolder(dstFolder *vfs.VirtualFolder,
}
}
func (c *BaseConnection) updateQuotaAfterRename(virtualSourcePath, virtualTargetPath, targetPath string, initialSize int64) error {
func (c *BaseConnection) updateQuotaAfterRename(fs vfs.Fs, virtualSourcePath, virtualTargetPath, targetPath string, initialSize int64) error {
if dataprovider.GetQuotaTracking() == 0 {
return nil
}
// we don't allow to overwrite an existing directory so targetPath can be:
// - a new file, a symlink is as a new file here
// - a file overwriting an existing one
@ -908,7 +988,8 @@ func (c *BaseConnection) updateQuotaAfterRename(virtualSourcePath, virtualTarget
if errSrc != nil && errDst != nil {
// both files are contained inside the user home dir
if initialSize != -1 {
// we cannot have a directory here
// we cannot have a directory here, we are overwriting an existing file
// we need to subtract the size of the overwritten file from the user quota
dataprovider.UpdateUserQuota(&c.User, -1, -initialSize, false) //nolint:errcheck
}
return nil
@ -916,9 +997,9 @@ func (c *BaseConnection) updateQuotaAfterRename(virtualSourcePath, virtualTarget
filesSize := int64(0)
numFiles := 1
if fi, err := c.Fs.Stat(targetPath); err == nil {
if fi, err := fs.Stat(targetPath); err == nil {
if fi.Mode().IsDir() {
numFiles, filesSize, err = c.Fs.GetDirSize(targetPath)
numFiles, filesSize, err = fs.GetDirSize(targetPath)
if err != nil {
c.Log(logger.LevelWarn, "failed to update quota after rename, error scanning moved folder %#v: %v",
targetPath, err)
@ -995,15 +1076,30 @@ func (c *BaseConnection) GetGenericError(err error) error {
}
// GetFsError converts a filesystem error to a protocol error
func (c *BaseConnection) GetFsError(err error) error {
if c.Fs.IsNotExist(err) {
func (c *BaseConnection) GetFsError(fs vfs.Fs, err error) error {
if fs.IsNotExist(err) {
return c.GetNotExistError()
} else if c.Fs.IsPermission(err) {
} else if fs.IsPermission(err) {
return c.GetPermissionDeniedError()
} else if c.Fs.IsNotSupported(err) {
} else if fs.IsNotSupported(err) {
return c.GetOpUnsupportedError()
} else if err != nil {
return c.GetGenericError(err)
}
return nil
}
// GetFsAndResolvedPath returns the fs and the fs path matching virtualPath
func (c *BaseConnection) GetFsAndResolvedPath(virtualPath string) (vfs.Fs, string, error) {
fs, err := c.User.GetFilesystemForPath(virtualPath, c.ID)
if err != nil {
return nil, "", err
}
fsPath, err := fs.ResolvePath(virtualPath)
if err != nil {
return nil, "", c.GetFsError(fs, err)
}
return fs, fsPath, nil
}

File diff suppressed because it is too large Load diff

2415
common/protocol_test.go Normal file

File diff suppressed because it is too large Load diff

View file

@ -180,7 +180,7 @@ func (t *BaseTransfer) getUploadFileSize() (int64, error) {
fileSize = info.Size()
}
if vfs.IsCryptOsFs(t.Fs) && t.ErrTransfer != nil {
errDelete := t.Connection.Fs.Remove(t.fsPath, false)
errDelete := t.Fs.Remove(t.fsPath, false)
if errDelete != nil {
t.Connection.Log(logger.LevelWarn, "error removing partial crypto file %#v: %v", t.fsPath, errDelete)
}
@ -204,7 +204,7 @@ func (t *BaseTransfer) Close() error {
metrics.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.transferType, t.ErrTransfer)
if t.ErrTransfer == ErrQuotaExceeded && t.File != nil {
// if quota is exceeded we try to remove the partial file for uploads to local filesystem
err = t.Connection.Fs.Remove(t.File.Name(), false)
err = t.Fs.Remove(t.File.Name(), false)
if err == nil {
numFiles--
atomic.StoreInt64(&t.BytesReceived, 0)
@ -214,11 +214,11 @@ func (t *BaseTransfer) Close() error {
t.File.Name(), err)
} else if t.transferType == TransferUpload && t.File != nil && t.File.Name() != t.fsPath {
if t.ErrTransfer == nil || Config.UploadMode == UploadModeAtomicWithResume {
err = t.Connection.Fs.Rename(t.File.Name(), t.fsPath)
err = t.Fs.Rename(t.File.Name(), t.fsPath)
t.Connection.Log(logger.LevelDebug, "atomic upload completed, rename: %#v -> %#v, error: %v",
t.File.Name(), t.fsPath, err)
} else {
err = t.Connection.Fs.Remove(t.File.Name(), false)
err = t.Fs.Remove(t.File.Name(), false)
t.Connection.Log(logger.LevelWarn, "atomic upload completed with error: \"%v\", delete temporary file: %#v, "+
"deletion error: %v", t.ErrTransfer, t.File.Name(), err)
if err == nil {

View file

@ -16,12 +16,12 @@ import (
)
func TestTransferUpdateQuota(t *testing.T) {
conn := NewBaseConnection("", ProtocolSFTP, dataprovider.User{}, nil)
conn := NewBaseConnection("", ProtocolSFTP, dataprovider.User{})
transfer := BaseTransfer{
Connection: conn,
transferType: TransferUpload,
BytesReceived: 123,
Fs: vfs.NewOsFs("", os.TempDir(), nil),
Fs: vfs.NewOsFs("", os.TempDir(), ""),
}
errFake := errors.New("fake error")
transfer.TransferError(errFake)
@ -54,14 +54,14 @@ func TestTransferThrottling(t *testing.T) {
UploadBandwidth: 50,
DownloadBandwidth: 40,
}
fs := vfs.NewOsFs("", os.TempDir(), nil)
fs := vfs.NewOsFs("", os.TempDir(), "")
testFileSize := int64(131072)
wantedUploadElapsed := 1000 * (testFileSize / 1024) / u.UploadBandwidth
wantedDownloadElapsed := 1000 * (testFileSize / 1024) / u.DownloadBandwidth
// some tolerance
wantedUploadElapsed -= wantedDownloadElapsed / 10
wantedDownloadElapsed -= wantedDownloadElapsed / 10
conn := NewBaseConnection("id", ProtocolSCP, u, nil)
conn := NewBaseConnection("id", ProtocolSCP, u)
transfer := NewBaseTransfer(nil, conn, nil, "", "", TransferUpload, 0, 0, 0, true, fs)
transfer.BytesReceived = testFileSize
transfer.Connection.UpdateLastActivity()
@ -86,7 +86,7 @@ func TestTransferThrottling(t *testing.T) {
func TestRealPath(t *testing.T) {
testFile := filepath.Join(os.TempDir(), "afile.txt")
fs := vfs.NewOsFs("123", os.TempDir(), nil)
fs := vfs.NewOsFs("123", os.TempDir(), "")
u := dataprovider.User{
Username: "user",
HomeDir: os.TempDir(),
@ -95,7 +95,7 @@ func TestRealPath(t *testing.T) {
u.Permissions["/"] = []string{dataprovider.PermAny}
file, err := os.Create(testFile)
require.NoError(t, err)
conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, u, fs)
conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, u)
transfer := NewBaseTransfer(file, conn, nil, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs)
rPath := transfer.GetRealFsPath(testFile)
assert.Equal(t, testFile, rPath)
@ -117,7 +117,7 @@ func TestRealPath(t *testing.T) {
func TestTruncate(t *testing.T) {
testFile := filepath.Join(os.TempDir(), "transfer_test_file")
fs := vfs.NewOsFs("123", os.TempDir(), nil)
fs := vfs.NewOsFs("123", os.TempDir(), "")
u := dataprovider.User{
Username: "user",
HomeDir: os.TempDir(),
@ -130,10 +130,10 @@ func TestTruncate(t *testing.T) {
}
_, err = file.Write([]byte("hello"))
assert.NoError(t, err)
conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, u, fs)
conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, u)
transfer := NewBaseTransfer(file, conn, nil, testFile, "/transfer_test_file", TransferUpload, 0, 5, 100, false, fs)
err = conn.SetStat(testFile, "/transfer_test_file", &StatAttributes{
err = conn.SetStat("/transfer_test_file", &StatAttributes{
Size: 2,
Flags: StatAttrSize,
})
@ -150,7 +150,7 @@ func TestTruncate(t *testing.T) {
transfer = NewBaseTransfer(file, conn, nil, testFile, "/transfer_test_file", TransferUpload, 0, 0, 100, true, fs)
// file.Stat will fail on a closed file
err = conn.SetStat(testFile, "/transfer_test_file", &StatAttributes{
err = conn.SetStat("/transfer_test_file", &StatAttributes{
Size: 2,
Flags: StatAttrSize,
})
@ -181,7 +181,7 @@ func TestTransferErrors(t *testing.T) {
isCancelled = true
}
testFile := filepath.Join(os.TempDir(), "transfer_test_file")
fs := vfs.NewOsFs("id", os.TempDir(), nil)
fs := vfs.NewOsFs("id", os.TempDir(), "")
u := dataprovider.User{
Username: "test",
HomeDir: os.TempDir(),
@ -192,7 +192,7 @@ func TestTransferErrors(t *testing.T) {
if !assert.NoError(t, err) {
assert.FailNow(t, "unable to open test file")
}
conn := NewBaseConnection("id", ProtocolSFTP, u, fs)
conn := NewBaseConnection("id", ProtocolSFTP, u)
transfer := NewBaseTransfer(file, conn, nil, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs)
assert.Nil(t, transfer.cancelFn)
assert.Equal(t, testFile, transfer.GetFsPath())
@ -255,13 +255,13 @@ func TestTransferErrors(t *testing.T) {
func TestRemovePartialCryptoFile(t *testing.T) {
testFile := filepath.Join(os.TempDir(), "transfer_test_file")
fs, err := vfs.NewCryptFs("id", os.TempDir(), vfs.CryptFsConfig{Passphrase: kms.NewPlainSecret("secret")})
fs, err := vfs.NewCryptFs("id", os.TempDir(), "", vfs.CryptFsConfig{Passphrase: kms.NewPlainSecret("secret")})
require.NoError(t, err)
u := dataprovider.User{
Username: "test",
HomeDir: os.TempDir(),
}
conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, u, fs)
conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, u)
transfer := NewBaseTransfer(nil, conn, nil, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs)
transfer.ErrTransfer = errors.New("test error")
_, err = transfer.getUploadFileSize()

View file

@ -44,7 +44,7 @@ func initializeBoltProvider(basePath string) error {
logSender = fmt.Sprintf("dataprovider_%v", BoltDataProviderName)
dbPath := config.Name
if !utils.IsFileInputValid(dbPath) {
return fmt.Errorf("Invalid database path: %#v", dbPath)
return fmt.Errorf("invalid database path: %#v", dbPath)
}
if !filepath.IsAbs(dbPath) {
dbPath = filepath.Join(basePath, dbPath)
@ -119,7 +119,7 @@ func (p *BoltProvider) validateUserAndTLSCert(username, protocol string, tlsCert
func (p *BoltProvider) validateUserAndPass(username, password, ip, protocol string) (User, error) {
var user User
if password == "" {
return user, errors.New("Credentials cannot be null or empty")
return user, errors.New("credentials cannot be null or empty")
}
user, err := p.userExists(username)
if err != nil {
@ -142,7 +142,7 @@ func (p *BoltProvider) validateAdminAndPass(username, password, ip string) (Admi
func (p *BoltProvider) validateUserAndPubKey(username string, pubKey []byte) (User, string, error) {
var user User
if len(pubKey) == 0 {
return user, "", errors.New("Credentials cannot be null or empty")
return user, "", errors.New("credentials cannot be null or empty")
}
user, err := p.userExists(username)
if err != nil {
@ -435,8 +435,8 @@ func (p *BoltProvider) addUser(user *User) error {
user.UsedQuotaSize = 0
user.UsedQuotaFiles = 0
user.LastLogin = 0
for _, folder := range user.VirtualFolders {
err = addUserToFolderMapping(folder, user, folderBucket)
for idx := range user.VirtualFolders {
err = addUserToFolderMapping(&user.VirtualFolders[idx].BaseVirtualFolder, user, folderBucket)
if err != nil {
return err
}
@ -472,14 +472,14 @@ func (p *BoltProvider) updateUser(user *User) error {
if err != nil {
return err
}
for _, folder := range oldUser.VirtualFolders {
err = removeUserFromFolderMapping(folder, &oldUser, folderBucket)
for idx := range oldUser.VirtualFolders {
err = removeUserFromFolderMapping(&oldUser.VirtualFolders[idx], &oldUser, folderBucket)
if err != nil {
return err
}
}
for _, folder := range user.VirtualFolders {
err = addUserToFolderMapping(folder, user, folderBucket)
for idx := range user.VirtualFolders {
err = addUserToFolderMapping(&user.VirtualFolders[idx].BaseVirtualFolder, user, folderBucket)
if err != nil {
return err
}
@ -508,8 +508,8 @@ func (p *BoltProvider) deleteUser(user *User) error {
if err != nil {
return err
}
for _, folder := range user.VirtualFolders {
err = removeUserFromFolderMapping(folder, user, folderBucket)
for idx := range user.VirtualFolders {
err = removeUserFromFolderMapping(&user.VirtualFolders[idx], user, folderBucket)
if err != nil {
return err
}
@ -649,6 +649,7 @@ func (p *BoltProvider) getFolders(limit, offset int, order string) ([]vfs.BaseVi
if err != nil {
return err
}
folder.HideConfidentialData()
folders = append(folders, folder)
if len(folders) >= limit {
break
@ -665,6 +666,7 @@ func (p *BoltProvider) getFolders(limit, offset int, order string) ([]vfs.BaseVi
if err != nil {
return err
}
folder.HideConfidentialData()
folders = append(folders, folder)
if len(folders) >= limit {
break
@ -703,8 +705,7 @@ func (p *BoltProvider) addFolder(folder *vfs.BaseVirtualFolder) error {
return fmt.Errorf("folder %v already exists", folder.Name)
}
folder.Users = nil
_, err = addFolderInternal(*folder, bucket)
return err
return addFolderInternal(*folder, bucket)
})
}
@ -867,7 +868,7 @@ func (p *BoltProvider) migrateDatabase() error {
boltDatabaseVersion)
return nil
}
return fmt.Errorf("Database version not handled: %v", version)
return fmt.Errorf("database version not handled: %v", version)
}
}
@ -893,17 +894,14 @@ func joinUserAndFolders(u []byte, foldersBucket *bolt.Bucket) (User, error) {
}
if len(user.VirtualFolders) > 0 {
var folders []vfs.VirtualFolder
for _, folder := range user.VirtualFolders {
for idx := range user.VirtualFolders {
folder := &user.VirtualFolders[idx]
baseFolder, err := folderExistsInternal(folder.Name, foldersBucket)
if err != nil {
continue
}
folder.MappedPath = baseFolder.MappedPath
folder.UsedQuotaFiles = baseFolder.UsedQuotaFiles
folder.UsedQuotaSize = baseFolder.UsedQuotaSize
folder.LastQuotaUpdate = baseFolder.LastQuotaUpdate
folder.ID = baseFolder.ID
folders = append(folders, folder)
folder.BaseVirtualFolder = baseFolder
folders = append(folders, *folder)
}
user.VirtualFolders = folders
}
@ -922,50 +920,50 @@ func folderExistsInternal(name string, bucket *bolt.Bucket) (vfs.BaseVirtualFold
return folder, err
}
func addFolderInternal(folder vfs.BaseVirtualFolder, bucket *bolt.Bucket) (vfs.BaseVirtualFolder, error) {
func addFolderInternal(folder vfs.BaseVirtualFolder, bucket *bolt.Bucket) error {
id, err := bucket.NextSequence()
if err != nil {
return folder, err
return err
}
folder.ID = int64(id)
buf, err := json.Marshal(folder)
if err != nil {
return folder, err
return err
}
err = bucket.Put([]byte(folder.Name), buf)
return folder, err
return bucket.Put([]byte(folder.Name), buf)
}
func addUserToFolderMapping(folder vfs.VirtualFolder, user *User, bucket *bolt.Bucket) error {
var baseFolder vfs.BaseVirtualFolder
var err error
if f := bucket.Get([]byte(folder.Name)); f == nil {
func addUserToFolderMapping(baseFolder *vfs.BaseVirtualFolder, user *User, bucket *bolt.Bucket) error {
f := bucket.Get([]byte(baseFolder.Name))
if f == nil {
// folder does not exists, try to create
folder.LastQuotaUpdate = 0
folder.UsedQuotaFiles = 0
folder.UsedQuotaSize = 0
baseFolder, err = addFolderInternal(folder.BaseVirtualFolder, bucket)
} else {
err = json.Unmarshal(f, &baseFolder)
baseFolder.LastQuotaUpdate = 0
baseFolder.UsedQuotaFiles = 0
baseFolder.UsedQuotaSize = 0
baseFolder.Users = []string{user.Username}
return addFolderInternal(*baseFolder, bucket)
}
var oldFolder vfs.BaseVirtualFolder
err := json.Unmarshal(f, &oldFolder)
if err != nil {
return err
}
baseFolder.ID = oldFolder.ID
baseFolder.LastQuotaUpdate = oldFolder.LastQuotaUpdate
baseFolder.UsedQuotaFiles = oldFolder.UsedQuotaFiles
baseFolder.UsedQuotaSize = oldFolder.UsedQuotaSize
baseFolder.Users = oldFolder.Users
if !utils.IsStringInSlice(user.Username, baseFolder.Users) {
baseFolder.Users = append(baseFolder.Users, user.Username)
buf, err := json.Marshal(baseFolder)
if err != nil {
return err
}
err = bucket.Put([]byte(folder.Name), buf)
if err != nil {
return err
}
}
return err
buf, err := json.Marshal(baseFolder)
if err != nil {
return err
}
return bucket.Put([]byte(baseFolder.Name), buf)
}
func removeUserFromFolderMapping(folder vfs.VirtualFolder, user *User, bucket *bolt.Bucket) error {
func removeUserFromFolderMapping(folder *vfs.VirtualFolder, user *User, bucket *bolt.Bucket) error {
var f []byte
if f = bucket.Get([]byte(folder.Name)); f == nil {
// the folder does not exists so there is no associated user

View file

@ -105,7 +105,7 @@ var (
// ValidProtocols defines all the valid protcols
ValidProtocols = []string{"SSH", "FTP", "DAV"}
// ErrNoInitRequired defines the error returned by InitProvider if no inizialization/update is required
ErrNoInitRequired = errors.New("The data provider is up to date")
ErrNoInitRequired = errors.New("the data provider is up to date")
// ErrInvalidCredentials defines the error to return if the supplied credentials are invalid
ErrInvalidCredentials = errors.New("invalid credentials")
validTLSUsernames = []string{string(TLSUsernameNone), string(TLSUsernameCN)}
@ -410,6 +410,11 @@ type Provider interface {
revertDatabase(targetVersion int) error
}
type fsValidatorHelper interface {
GetGCSCredentialsFilePath() string
GetEncrytionAdditionalData() string
}
// Initialize the data provider.
// An error is returned if the configured driver is invalid or if the data provider cannot be initialized
func Initialize(cnf Config, basePath string, checkAdmins bool) error {
@ -421,6 +426,7 @@ func Initialize(cnf Config, basePath string, checkAdmins bool) error {
} else {
credentialsDirPath = filepath.Join(basePath, config.CredentialsPath)
}
vfs.SetCredentialsDirPath(credentialsDirPath)
if err = validateHooks(); err != nil {
return err
@ -498,7 +504,7 @@ func validateSQLTablesPrefix() error {
if len(config.SQLTablesPrefix) > 0 {
for _, char := range config.SQLTablesPrefix {
if !strings.Contains(sqlPrefixValidChars, strings.ToLower(string(char))) {
return errors.New("Invalid sql_tables_prefix only chars in range 'a..z', 'A..Z' and '_' are allowed")
return errors.New("invalid sql_tables_prefix only chars in range 'a..z', 'A..Z' and '_' are allowed")
}
}
sqlTableUsers = config.SQLTablesPrefix + sqlTableUsers
@ -583,7 +589,7 @@ func CheckCachedUserCredentials(user *CachedUser, password, loginMethod, protoco
}
if loginMethod == LoginMethodTLSCertificate {
if !user.User.IsLoginMethodAllowed(LoginMethodTLSCertificate, nil) {
return fmt.Errorf("Certificate login method is not allowed for user %#v", user.User.Username)
return fmt.Errorf("certificate login method is not allowed for user %#v", user.User.Username)
}
return nil
}
@ -628,7 +634,7 @@ func CheckCompositeCredentials(username, password, ip, loginMethod, protocol str
return user, loginMethod, err
}
if loginMethod == LoginMethodTLSCertificate && !user.IsLoginMethodAllowed(LoginMethodTLSCertificate, nil) {
return user, loginMethod, fmt.Errorf("Certificate login method is not allowed for user %#v", user.Username)
return user, loginMethod, fmt.Errorf("certificate login method is not allowed for user %#v", user.Username)
}
if loginMethod == LoginMethodTLSCertificateAndPwd {
if config.ExternalAuthHook != "" && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&1 != 0) {
@ -876,8 +882,14 @@ func AddFolder(folder *vfs.BaseVirtualFolder) error {
}
// UpdateFolder updates the specified virtual folder
func UpdateFolder(folder *vfs.BaseVirtualFolder) error {
return provider.updateFolder(folder)
func UpdateFolder(folder *vfs.BaseVirtualFolder, users []string) error {
err := provider.updateFolder(folder)
if err == nil {
for _, user := range users {
RemoveCachedWebDAVUser(user)
}
}
return err
}
// DeleteFolder deletes an existing folder.
@ -886,7 +898,13 @@ func DeleteFolder(folderName string) error {
if err != nil {
return err
}
return provider.deleteFolder(&folder)
err = provider.deleteFolder(&folder)
if err == nil {
for _, user := range folder.Users {
RemoveCachedWebDAVUser(user)
}
}
return err
}
// GetFolderByName returns the folder with the specified name if any
@ -981,41 +999,45 @@ func buildUserHomeDir(user *User) {
if user.HomeDir == "" {
if config.UsersBaseDir != "" {
user.HomeDir = filepath.Join(config.UsersBaseDir, user.Username)
} else if user.FsConfig.Provider == SFTPFilesystemProvider {
} else if user.FsConfig.Provider == vfs.SFTPFilesystemProvider {
user.HomeDir = filepath.Join(os.TempDir(), user.Username)
}
}
}
func isVirtualDirOverlapped(dir1, dir2 string) bool {
func isVirtualDirOverlapped(dir1, dir2 string, fullCheck bool) bool {
if dir1 == dir2 {
return true
}
if len(dir1) > len(dir2) {
if strings.HasPrefix(dir1, dir2+"/") {
return true
if fullCheck {
if len(dir1) > len(dir2) {
if strings.HasPrefix(dir1, dir2+"/") {
return true
}
}
}
if len(dir2) > len(dir1) {
if strings.HasPrefix(dir2, dir1+"/") {
return true
if len(dir2) > len(dir1) {
if strings.HasPrefix(dir2, dir1+"/") {
return true
}
}
}
return false
}
func isMappedDirOverlapped(dir1, dir2 string) bool {
func isMappedDirOverlapped(dir1, dir2 string, fullCheck bool) bool {
if dir1 == dir2 {
return true
}
if len(dir1) > len(dir2) {
if strings.HasPrefix(dir1, dir2+string(os.PathSeparator)) {
return true
if fullCheck {
if len(dir1) > len(dir2) {
if strings.HasPrefix(dir1, dir2+string(os.PathSeparator)) {
return true
}
}
}
if len(dir2) > len(dir1) {
if strings.HasPrefix(dir2, dir1+string(os.PathSeparator)) {
return true
if len(dir2) > len(dir1) {
if strings.HasPrefix(dir2, dir1+string(os.PathSeparator)) {
return true
}
}
}
return false
@ -1046,19 +1068,33 @@ func getVirtualFolderIfInvalid(folder *vfs.BaseVirtualFolder) *vfs.BaseVirtualFo
if folder.Name == "" {
return folder
}
if folder.FsConfig.Provider != vfs.LocalFilesystemProvider {
return folder
}
if f, err := GetFolderByName(folder.Name); err == nil {
return &f
}
return folder
}
func hasSFTPLoopForFolder(user *User, folder *vfs.BaseVirtualFolder) bool {
if folder.FsConfig.Provider == vfs.SFTPFilesystemProvider {
// FIXME: this could be inaccurate, it is not easy to check the endpoint too
if folder.FsConfig.SFTPConfig.Username == user.Username {
return true
}
}
return false
}
func validateUserVirtualFolders(user *User) error {
if len(user.VirtualFolders) == 0 || user.FsConfig.Provider != LocalFilesystemProvider {
if len(user.VirtualFolders) == 0 {
user.VirtualFolders = []vfs.VirtualFolder{}
return nil
}
var virtualFolders []vfs.VirtualFolder
mappedPaths := make(map[string]string)
mappedPaths := make(map[string]bool)
virtualPaths := make(map[string]bool)
for _, v := range user.VirtualFolders {
cleanedVPath := filepath.ToSlash(path.Clean(v.VirtualPath))
if !path.IsAbs(cleanedVPath) || cleanedVPath == "/" {
@ -1071,34 +1107,37 @@ func validateUserVirtualFolders(user *User) error {
if err := ValidateFolder(folder); err != nil {
return err
}
cleanedMPath := folder.MappedPath
if isMappedDirOverlapped(cleanedMPath, user.GetHomeDir()) {
return &ValidationError{err: fmt.Sprintf("invalid mapped folder %#v cannot be inside or contain the user home dir %#v",
folder.MappedPath, user.GetHomeDir())}
if hasSFTPLoopForFolder(user, folder) {
return &ValidationError{err: fmt.Sprintf("SFTP folder %#v could point to the same SFTPGo account, this is not allowed",
folder.Name)}
}
cleanedMPath := folder.MappedPath
if folder.IsLocalOrLocalCrypted() {
if isMappedDirOverlapped(cleanedMPath, user.GetHomeDir(), true) {
return &ValidationError{err: fmt.Sprintf("invalid mapped folder %#v cannot be inside or contain the user home dir %#v",
folder.MappedPath, user.GetHomeDir())}
}
for mPath := range mappedPaths {
if folder.IsLocalOrLocalCrypted() && isMappedDirOverlapped(mPath, cleanedMPath, false) {
return &ValidationError{err: fmt.Sprintf("invalid mapped folder %#v overlaps with mapped folder %#v",
v.MappedPath, mPath)}
}
}
mappedPaths[cleanedMPath] = true
}
for vPath := range virtualPaths {
if isVirtualDirOverlapped(vPath, cleanedVPath, false) {
return &ValidationError{err: fmt.Sprintf("invalid virtual folder %#v overlaps with virtual folder %#v",
v.VirtualPath, vPath)}
}
}
virtualPaths[cleanedVPath] = true
virtualFolders = append(virtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: *folder,
VirtualPath: cleanedVPath,
QuotaSize: v.QuotaSize,
QuotaFiles: v.QuotaFiles,
})
for k, virtual := range mappedPaths {
if GetQuotaTracking() > 0 {
if isMappedDirOverlapped(k, cleanedMPath) {
return &ValidationError{err: fmt.Sprintf("invalid mapped folder %#v overlaps with mapped folder %#v",
v.MappedPath, k)}
}
} else {
if k == cleanedMPath {
return &ValidationError{err: fmt.Sprintf("duplicated mapped folder %#v", v.MappedPath)}
}
}
if isVirtualDirOverlapped(virtual, cleanedVPath) {
return &ValidationError{err: fmt.Sprintf("invalid virtual folder %#v overlaps with virtual folder %#v",
v.VirtualPath, virtual)}
}
}
mappedPaths[cleanedMPath] = cleanedVPath
}
user.VirtualFolders = virtualFolders
return nil
@ -1297,35 +1336,35 @@ func validateFilters(user *User) error {
return validateFileFilters(user)
}
func saveGCSCredentials(user *User) error {
if user.FsConfig.Provider != GCSFilesystemProvider {
func saveGCSCredentials(fsConfig *vfs.Filesystem, helper fsValidatorHelper) error {
if fsConfig.Provider != vfs.GCSFilesystemProvider {
return nil
}
if user.FsConfig.GCSConfig.Credentials.GetPayload() == "" {
if fsConfig.GCSConfig.Credentials.GetPayload() == "" {
return nil
}
if config.PreferDatabaseCredentials {
if user.FsConfig.GCSConfig.Credentials.IsPlain() {
user.FsConfig.GCSConfig.Credentials.SetAdditionalData(user.Username)
err := user.FsConfig.GCSConfig.Credentials.Encrypt()
if fsConfig.GCSConfig.Credentials.IsPlain() {
fsConfig.GCSConfig.Credentials.SetAdditionalData(helper.GetEncrytionAdditionalData())
err := fsConfig.GCSConfig.Credentials.Encrypt()
if err != nil {
return err
}
}
return nil
}
if user.FsConfig.GCSConfig.Credentials.IsPlain() {
user.FsConfig.GCSConfig.Credentials.SetAdditionalData(user.Username)
err := user.FsConfig.GCSConfig.Credentials.Encrypt()
if fsConfig.GCSConfig.Credentials.IsPlain() {
fsConfig.GCSConfig.Credentials.SetAdditionalData(helper.GetEncrytionAdditionalData())
err := fsConfig.GCSConfig.Credentials.Encrypt()
if err != nil {
return &ValidationError{err: fmt.Sprintf("could not encrypt GCS credentials: %v", err)}
}
}
creds, err := json.Marshal(user.FsConfig.GCSConfig.Credentials)
creds, err := json.Marshal(fsConfig.GCSConfig.Credentials)
if err != nil {
return &ValidationError{err: fmt.Sprintf("could not marshal GCS credentials: %v", err)}
}
credentialsFilePath := user.getGCSCredentialsFilePath()
credentialsFilePath := helper.GetGCSCredentialsFilePath()
err = os.MkdirAll(filepath.Dir(credentialsFilePath), 0700)
if err != nil {
return &ValidationError{err: fmt.Sprintf("could not create GCS credentials dir: %v", err)}
@ -1334,75 +1373,75 @@ func saveGCSCredentials(user *User) error {
if err != nil {
return &ValidationError{err: fmt.Sprintf("could not save GCS credentials: %v", err)}
}
user.FsConfig.GCSConfig.Credentials = kms.NewEmptySecret()
fsConfig.GCSConfig.Credentials = kms.NewEmptySecret()
return nil
}
func validateFilesystemConfig(user *User) error {
if user.FsConfig.Provider == S3FilesystemProvider {
if err := user.FsConfig.S3Config.Validate(); err != nil {
func validateFilesystemConfig(fsConfig *vfs.Filesystem, helper fsValidatorHelper) error {
if fsConfig.Provider == vfs.S3FilesystemProvider {
if err := fsConfig.S3Config.Validate(); err != nil {
return &ValidationError{err: fmt.Sprintf("could not validate s3config: %v", err)}
}
if err := user.FsConfig.S3Config.EncryptCredentials(user.Username); err != nil {
if err := fsConfig.S3Config.EncryptCredentials(helper.GetEncrytionAdditionalData()); err != nil {
return &ValidationError{err: fmt.Sprintf("could not encrypt s3 access secret: %v", err)}
}
user.FsConfig.GCSConfig = vfs.GCSFsConfig{}
user.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{}
user.FsConfig.CryptConfig = vfs.CryptFsConfig{}
user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{}
fsConfig.GCSConfig = vfs.GCSFsConfig{}
fsConfig.AzBlobConfig = vfs.AzBlobFsConfig{}
fsConfig.CryptConfig = vfs.CryptFsConfig{}
fsConfig.SFTPConfig = vfs.SFTPFsConfig{}
return nil
} else if user.FsConfig.Provider == GCSFilesystemProvider {
if err := user.FsConfig.GCSConfig.Validate(user.getGCSCredentialsFilePath()); err != nil {
} else if fsConfig.Provider == vfs.GCSFilesystemProvider {
if err := fsConfig.GCSConfig.Validate(helper.GetGCSCredentialsFilePath()); err != nil {
return &ValidationError{err: fmt.Sprintf("could not validate GCS config: %v", err)}
}
user.FsConfig.S3Config = vfs.S3FsConfig{}
user.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{}
user.FsConfig.CryptConfig = vfs.CryptFsConfig{}
user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{}
fsConfig.S3Config = vfs.S3FsConfig{}
fsConfig.AzBlobConfig = vfs.AzBlobFsConfig{}
fsConfig.CryptConfig = vfs.CryptFsConfig{}
fsConfig.SFTPConfig = vfs.SFTPFsConfig{}
return nil
} else if user.FsConfig.Provider == AzureBlobFilesystemProvider {
if err := user.FsConfig.AzBlobConfig.Validate(); err != nil {
} else if fsConfig.Provider == vfs.AzureBlobFilesystemProvider {
if err := fsConfig.AzBlobConfig.Validate(); err != nil {
return &ValidationError{err: fmt.Sprintf("could not validate Azure Blob config: %v", err)}
}
if err := user.FsConfig.AzBlobConfig.EncryptCredentials(user.Username); err != nil {
if err := fsConfig.AzBlobConfig.EncryptCredentials(helper.GetEncrytionAdditionalData()); err != nil {
return &ValidationError{err: fmt.Sprintf("could not encrypt Azure blob account key: %v", err)}
}
user.FsConfig.S3Config = vfs.S3FsConfig{}
user.FsConfig.GCSConfig = vfs.GCSFsConfig{}
user.FsConfig.CryptConfig = vfs.CryptFsConfig{}
user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{}
fsConfig.S3Config = vfs.S3FsConfig{}
fsConfig.GCSConfig = vfs.GCSFsConfig{}
fsConfig.CryptConfig = vfs.CryptFsConfig{}
fsConfig.SFTPConfig = vfs.SFTPFsConfig{}
return nil
} else if user.FsConfig.Provider == CryptedFilesystemProvider {
if err := user.FsConfig.CryptConfig.Validate(); err != nil {
} else if fsConfig.Provider == vfs.CryptedFilesystemProvider {
if err := fsConfig.CryptConfig.Validate(); err != nil {
return &ValidationError{err: fmt.Sprintf("could not validate Crypt fs config: %v", err)}
}
if err := user.FsConfig.CryptConfig.EncryptCredentials(user.Username); err != nil {
if err := fsConfig.CryptConfig.EncryptCredentials(helper.GetEncrytionAdditionalData()); err != nil {
return &ValidationError{err: fmt.Sprintf("could not encrypt Crypt fs passphrase: %v", err)}
}
user.FsConfig.S3Config = vfs.S3FsConfig{}
user.FsConfig.GCSConfig = vfs.GCSFsConfig{}
user.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{}
user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{}
fsConfig.S3Config = vfs.S3FsConfig{}
fsConfig.GCSConfig = vfs.GCSFsConfig{}
fsConfig.AzBlobConfig = vfs.AzBlobFsConfig{}
fsConfig.SFTPConfig = vfs.SFTPFsConfig{}
return nil
} else if user.FsConfig.Provider == SFTPFilesystemProvider {
if err := user.FsConfig.SFTPConfig.Validate(); err != nil {
} else if fsConfig.Provider == vfs.SFTPFilesystemProvider {
if err := fsConfig.SFTPConfig.Validate(); err != nil {
return &ValidationError{err: fmt.Sprintf("could not validate SFTP fs config: %v", err)}
}
if err := user.FsConfig.SFTPConfig.EncryptCredentials(user.Username); err != nil {
if err := fsConfig.SFTPConfig.EncryptCredentials(helper.GetEncrytionAdditionalData()); err != nil {
return &ValidationError{err: fmt.Sprintf("could not encrypt SFTP fs credentials: %v", err)}
}
user.FsConfig.S3Config = vfs.S3FsConfig{}
user.FsConfig.GCSConfig = vfs.GCSFsConfig{}
user.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{}
user.FsConfig.CryptConfig = vfs.CryptFsConfig{}
fsConfig.S3Config = vfs.S3FsConfig{}
fsConfig.GCSConfig = vfs.GCSFsConfig{}
fsConfig.AzBlobConfig = vfs.AzBlobFsConfig{}
fsConfig.CryptConfig = vfs.CryptFsConfig{}
return nil
}
user.FsConfig.Provider = LocalFilesystemProvider
user.FsConfig.S3Config = vfs.S3FsConfig{}
user.FsConfig.GCSConfig = vfs.GCSFsConfig{}
user.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{}
user.FsConfig.CryptConfig = vfs.CryptFsConfig{}
user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{}
fsConfig.Provider = vfs.LocalFilesystemProvider
fsConfig.S3Config = vfs.S3FsConfig{}
fsConfig.GCSConfig = vfs.GCSFsConfig{}
fsConfig.AzBlobConfig = vfs.AzBlobFsConfig{}
fsConfig.CryptConfig = vfs.CryptFsConfig{}
fsConfig.SFTPConfig = vfs.SFTPFsConfig{}
return nil
}
@ -1447,11 +1486,23 @@ func ValidateFolder(folder *vfs.BaseVirtualFolder) error {
return &ValidationError{err: fmt.Sprintf("folder name %#v is not valid, the following characters are allowed: a-zA-Z0-9-_.~",
folder.Name)}
}
cleanedMPath := filepath.Clean(folder.MappedPath)
if !filepath.IsAbs(cleanedMPath) {
return &ValidationError{err: fmt.Sprintf("invalid folder mapped path %#v", folder.MappedPath)}
if folder.FsConfig.Provider == vfs.LocalFilesystemProvider || folder.FsConfig.Provider == vfs.CryptedFilesystemProvider ||
folder.MappedPath != "" {
cleanedMPath := filepath.Clean(folder.MappedPath)
if !filepath.IsAbs(cleanedMPath) {
return &ValidationError{err: fmt.Sprintf("invalid folder mapped path %#v", folder.MappedPath)}
}
folder.MappedPath = cleanedMPath
}
if folder.HasRedactedSecret() {
return errors.New("cannot save a folder with a redacted secret")
}
if err := validateFilesystemConfig(&folder.FsConfig, folder); err != nil {
return err
}
if err := saveGCSCredentials(&folder.FsConfig, folder); err != nil {
return err
}
folder.MappedPath = cleanedMPath
return nil
}
@ -1466,7 +1517,10 @@ func ValidateUser(user *User) error {
if err := validatePermissions(user); err != nil {
return err
}
if err := validateFilesystemConfig(user); err != nil {
if user.hasRedactedSecret() {
return errors.New("cannot save a user with a redacted secret")
}
if err := validateFilesystemConfig(&user.FsConfig, user); err != nil {
return err
}
if err := validateUserVirtualFolders(user); err != nil {
@ -1484,7 +1538,7 @@ func ValidateUser(user *User) error {
if err := validateFilters(user); err != nil {
return err
}
if err := saveGCSCredentials(user); err != nil {
if err := saveGCSCredentials(&user.FsConfig, user); err != nil {
return err
}
return nil
@ -1555,12 +1609,12 @@ func checkUserAndPass(user *User, password, ip, protocol string) (User, error) {
return *user, err
}
if user.Password == "" {
return *user, errors.New("Credentials cannot be null or empty")
return *user, errors.New("credentials cannot be null or empty")
}
hookResponse, err := executeCheckPasswordHook(user.Username, password, ip, protocol)
if err != nil {
providerLog(logger.LevelDebug, "error executing check password hook: %v", err)
return *user, errors.New("Unable to check credentials")
return *user, errors.New("unable to check credentials")
}
switch hookResponse.Status {
case -1:
@ -1664,7 +1718,10 @@ func comparePbkdf2PasswordAndHash(password, hashedPassword string) (bool, error)
}
func addCredentialsToUser(user *User) error {
if user.FsConfig.Provider != GCSFilesystemProvider {
if err := addFolderCredentialsToUser(user); err != nil {
return err
}
if user.FsConfig.Provider != vfs.GCSFilesystemProvider {
return nil
}
if user.FsConfig.GCSConfig.AutomaticCredentials > 0 {
@ -1676,13 +1733,38 @@ func addCredentialsToUser(user *User) error {
return nil
}
cred, err := os.ReadFile(user.getGCSCredentialsFilePath())
cred, err := os.ReadFile(user.GetGCSCredentialsFilePath())
if err != nil {
return err
}
return json.Unmarshal(cred, &user.FsConfig.GCSConfig.Credentials)
}
func addFolderCredentialsToUser(user *User) error {
for idx := range user.VirtualFolders {
f := &user.VirtualFolders[idx]
if f.FsConfig.Provider != vfs.GCSFilesystemProvider {
continue
}
if f.FsConfig.GCSConfig.AutomaticCredentials > 0 {
continue
}
// Don't read from file if credentials have already been set
if f.FsConfig.GCSConfig.Credentials.IsValid() {
continue
}
cred, err := os.ReadFile(f.GetGCSCredentialsFilePath())
if err != nil {
return err
}
err = json.Unmarshal(cred, f.FsConfig.GCSConfig.Credentials)
if err != nil {
return err
}
}
return nil
}
func getSSLMode() string {
if config.Driver == PGSQLDataProviderName {
if config.SSLMode == 0 {
@ -2014,7 +2096,9 @@ func executeCheckPasswordHook(username, password, ip, protocol string) (checkPas
return response, nil
}
startTime := time.Now()
out, err := getPasswordHookResponse(username, password, ip, protocol)
providerLog(logger.LevelDebug, "check password hook executed, error: %v, elapsed: %v", err, time.Since(startTime))
if err != nil {
return response, err
}
@ -2078,10 +2162,12 @@ func executePreLoginHook(username, loginMethod, ip, protocol string) (User, erro
if err != nil {
return u, err
}
startTime := time.Now()
out, err := getPreLoginHookResponse(loginMethod, ip, protocol, userAsJSON)
if err != nil {
return u, fmt.Errorf("Pre-login hook error: %v", err)
return u, fmt.Errorf("pre-login hook error: %v, elapsed %v", err, time.Since(startTime))
}
providerLog(logger.LevelDebug, "pre-login hook completed, elapsed: %v", time.Since(startTime))
if strings.TrimSpace(string(out)) == "" {
providerLog(logger.LevelDebug, "empty response from pre-login hook, no modification requested for user %#v id: %v",
username, u.ID)
@ -2098,7 +2184,7 @@ func executePreLoginHook(username, loginMethod, ip, protocol string) (User, erro
userLastLogin := u.LastLogin
err = json.Unmarshal(out, &u)
if err != nil {
return u, fmt.Errorf("Invalid pre-login hook response %#v, error: %v", string(out), err)
return u, fmt.Errorf("invalid pre-login hook response %#v, error: %v", string(out), err)
}
u.ID = userID
u.UsedQuotaSize = userUsedQuotaSize
@ -2214,9 +2300,11 @@ func getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip,
return result, err
}
defer resp.Body.Close()
providerLog(logger.LevelDebug, "external auth hook executed, response code: %v", resp.StatusCode)
if resp.StatusCode != http.StatusOK {
return result, fmt.Errorf("wrong external auth http status code: %v, expected 200", resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
@ -2250,13 +2338,15 @@ func doExternalAuth(username, password string, pubKey []byte, keyboardInteractiv
return user, err
}
}
startTime := time.Now()
out, err := getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip, protocol, cert)
if err != nil {
return user, fmt.Errorf("External auth error: %v", err)
return user, fmt.Errorf("external auth error: %v, elapsed: %v", err, time.Since(startTime))
}
providerLog(logger.LevelDebug, "external auth completed, elapsed: %v", time.Since(startTime))
err = json.Unmarshal(out, &user)
if err != nil {
return user, fmt.Errorf("Invalid external auth response: %v", err)
return user, fmt.Errorf("invalid external auth response: %v", err)
}
if user.Username == "" {
return user, ErrInvalidCredentials

View file

@ -105,7 +105,7 @@ func (p *MemoryProvider) validateUserAndTLSCert(username, protocol string, tlsCe
func (p *MemoryProvider) validateUserAndPass(username, password, ip, protocol string) (User, error) {
var user User
if password == "" {
return user, errors.New("Credentials cannot be null or empty")
return user, errors.New("credentials cannot be null or empty")
}
user, err := p.userExists(username)
if err != nil {
@ -118,7 +118,7 @@ func (p *MemoryProvider) validateUserAndPass(username, password, ip, protocol st
func (p *MemoryProvider) validateUserAndPubKey(username string, pubKey []byte) (User, string, error) {
var user User
if len(pubKey) == 0 {
return user, "", errors.New("Credentials cannot be null or empty")
return user, "", errors.New("credentials cannot be null or empty")
}
user, err := p.userExists(username)
if err != nil {
@ -548,15 +548,12 @@ func (p *MemoryProvider) getUsedFolderQuota(name string) (int, int64, error) {
func (p *MemoryProvider) joinVirtualFoldersFields(user *User) []vfs.VirtualFolder {
var folders []vfs.VirtualFolder
for _, folder := range user.VirtualFolders {
f, err := p.addOrGetFolderInternal(folder.Name, folder.MappedPath, user.Username)
for idx := range user.VirtualFolders {
folder := &user.VirtualFolders[idx]
f, err := p.addOrUpdateFolderInternal(&folder.BaseVirtualFolder, user.Username, 0, 0, 0)
if err == nil {
folder.UsedQuotaFiles = f.UsedQuotaFiles
folder.UsedQuotaSize = f.UsedQuotaSize
folder.LastQuotaUpdate = f.LastQuotaUpdate
folder.ID = f.ID
folder.MappedPath = f.MappedPath
folders = append(folders, folder)
folder.BaseVirtualFolder = f
folders = append(folders, *folder)
}
}
return folders
@ -584,24 +581,29 @@ func (p *MemoryProvider) updateFoldersMappingInternal(folder vfs.BaseVirtualFold
}
}
func (p *MemoryProvider) addOrGetFolderInternal(folderName, folderMappedPath, username string) (vfs.BaseVirtualFolder, error) {
folder, err := p.folderExistsInternal(folderName)
if _, ok := err.(*RecordNotFoundError); ok {
folder := vfs.BaseVirtualFolder{
ID: p.getNextFolderID(),
Name: folderName,
MappedPath: folderMappedPath,
UsedQuotaSize: 0,
UsedQuotaFiles: 0,
LastQuotaUpdate: 0,
Users: []string{username},
func (p *MemoryProvider) addOrUpdateFolderInternal(baseFolder *vfs.BaseVirtualFolder, username string, usedQuotaSize int64,
usedQuotaFiles int, lastQuotaUpdate int64) (vfs.BaseVirtualFolder, error) {
folder, err := p.folderExistsInternal(baseFolder.Name)
if err == nil {
// exists
folder.MappedPath = baseFolder.MappedPath
folder.Description = baseFolder.Description
folder.FsConfig = baseFolder.FsConfig.GetACopy()
if !utils.IsStringInSlice(username, folder.Users) {
folder.Users = append(folder.Users, username)
}
p.updateFoldersMappingInternal(folder)
return folder, nil
}
if err == nil && !utils.IsStringInSlice(username, folder.Users) {
folder.Users = append(folder.Users, username)
if _, ok := err.(*RecordNotFoundError); ok {
folder = baseFolder.GetACopy()
folder.ID = p.getNextFolderID()
folder.UsedQuotaSize = usedQuotaSize
folder.UsedQuotaFiles = usedQuotaFiles
folder.LastQuotaUpdate = lastQuotaUpdate
folder.Users = []string{username}
p.updateFoldersMappingInternal(folder)
return folder, nil
}
return folder, err
}
@ -631,7 +633,9 @@ func (p *MemoryProvider) getFolders(limit, offset int, order string) ([]vfs.Base
if itNum <= offset {
continue
}
folder := p.dbHandle.vfolders[name]
f := p.dbHandle.vfolders[name]
folder := f.GetACopy()
folder.HideConfidentialData()
folders = append(folders, folder)
if len(folders) >= limit {
break
@ -644,7 +648,9 @@ func (p *MemoryProvider) getFolders(limit, offset int, order string) ([]vfs.Base
continue
}
name := p.dbHandle.vfoldersNames[i]
folder := p.dbHandle.vfolders[name]
f := p.dbHandle.vfolders[name]
folder := f.GetACopy()
folder.HideConfidentialData()
folders = append(folders, folder)
if len(folders) >= limit {
break
@ -660,7 +666,11 @@ func (p *MemoryProvider) getFolderByName(name string) (vfs.BaseVirtualFolder, er
if p.dbHandle.isClosed {
return vfs.BaseVirtualFolder{}, errMemoryProviderClosed
}
return p.folderExistsInternal(name)
folder, err := p.folderExistsInternal(name)
if err != nil {
return vfs.BaseVirtualFolder{}, err
}
return folder.GetACopy(), nil
}
func (p *MemoryProvider) addFolder(folder *vfs.BaseVirtualFolder) error {
@ -708,6 +718,22 @@ func (p *MemoryProvider) updateFolder(folder *vfs.BaseVirtualFolder) error {
folder.UsedQuotaSize = f.UsedQuotaSize
folder.Users = f.Users
p.dbHandle.vfolders[folder.Name] = folder.GetACopy()
// now update the related users
for _, username := range folder.Users {
user, err := p.userExistsInternal(username)
if err == nil {
var folders []vfs.VirtualFolder
for idx := range user.VirtualFolders {
userFolder := &user.VirtualFolders[idx]
if folder.Name == userFolder.Name {
userFolder.BaseVirtualFolder = folder.GetACopy()
}
folders = append(folders, *userFolder)
}
user.VirtualFolders = folders
p.dbHandle.users[user.Username] = user
}
}
return nil
}
@ -726,9 +752,10 @@ func (p *MemoryProvider) deleteFolder(folder *vfs.BaseVirtualFolder) error {
user, err := p.userExistsInternal(username)
if err == nil {
var folders []vfs.VirtualFolder
for _, userFolder := range user.VirtualFolders {
for idx := range user.VirtualFolders {
userFolder := &user.VirtualFolders[idx]
if folder.Name != userFolder.Name {
folders = append(folders, userFolder)
folders = append(folders, *userFolder)
}
}
user.VirtualFolders = folders

View file

@ -257,7 +257,7 @@ func (p *MySQLProvider) migrateDatabase() error {
sqlDatabaseVersion)
return nil
}
return fmt.Errorf("Database version not handled: %v", version)
return fmt.Errorf("database version not handled: %v", version)
}
}
@ -274,7 +274,7 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error {
case 9:
return downgradeMySQLDatabaseFromV9(p.dbHandle)
default:
return fmt.Errorf("Database version not handled: %v", dbVersion.Version)
return fmt.Errorf("database version not handled: %v", dbVersion.Version)
}
}

View file

@ -265,7 +265,7 @@ func (p *PGSQLProvider) migrateDatabase() error {
sqlDatabaseVersion)
return nil
}
return fmt.Errorf("Database version not handled: %v", version)
return fmt.Errorf("database version not handled: %v", version)
}
}
@ -282,7 +282,7 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error {
case 9:
return downgradePGSQLDatabaseFromV9(p.dbHandle)
default:
return fmt.Errorf("Database version not handled: %v", dbVersion.Version)
return fmt.Errorf("database version not handled: %v", dbVersion.Version)
}
}

View file

@ -217,7 +217,7 @@ func sqlCommonGetUserByUsername(username string, dbHandle sqlQuerier) (User, err
func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHandle *sql.DB) (User, error) {
var user User
if password == "" {
return user, errors.New("Credentials cannot be null or empty")
return user, errors.New("credentials cannot be null or empty")
}
user, err := sqlCommonGetUserByUsername(username, dbHandle)
if err != nil {
@ -243,7 +243,7 @@ func sqlCommonValidateUserAndTLSCertificate(username, protocol string, tlsCert *
func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, dbHandle *sql.DB) (User, string, error) {
var user User
if len(pubKey) == 0 {
return user, "", errors.New("Credentials cannot be null or empty")
return user, "", errors.New("credentials cannot be null or empty")
}
user, err := sqlCommonGetUserByUsername(username, dbHandle)
if err != nil {
@ -587,7 +587,7 @@ func getUserFromDbRow(row sqlScanner) (User, error) {
}
}
if fsConfig.Valid {
var fs Filesystem
var fs vfs.Filesystem
err = json.Unmarshal([]byte(fsConfig.String), &fs)
if err == nil {
user.FsConfig = fs
@ -603,7 +603,20 @@ func getUserFromDbRow(row sqlScanner) (User, error) {
return user, err
}
func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQuerier) error {
var folderName string
q := checkFolderNameQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
row := stmt.QueryRowContext(ctx, name)
return row.Scan(&folderName)
}
func sqlCommonGetFolder(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
var folder vfs.BaseVirtualFolder
q := getFolderByNameQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
@ -613,9 +626,9 @@ func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQu
}
defer stmt.Close()
row := stmt.QueryRowContext(ctx, name)
var mappedPath, description sql.NullString
var mappedPath, description, fsConfig sql.NullString
err = row.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, &folder.LastQuotaUpdate,
&folder.Name, &description)
&folder.Name, &description, &fsConfig)
if err == sql.ErrNoRows {
return folder, &RecordNotFoundError{err: err.Error()}
}
@ -625,11 +638,18 @@ func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQu
if description.Valid {
folder.Description = description.String
}
if fsConfig.Valid {
var fs vfs.Filesystem
err = json.Unmarshal([]byte(fsConfig.String), &fs)
if err == nil {
folder.FsConfig = fs
}
}
return folder, err
}
func sqlCommonGetFolderByName(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
folder, err := sqlCommonCheckFolderExists(ctx, name, dbHandle)
folder, err := sqlCommonGetFolder(ctx, name, dbHandle)
if err != nil {
return folder, err
}
@ -643,23 +663,30 @@ func sqlCommonGetFolderByName(ctx context.Context, name string, dbHandle sqlQuer
return folders[0], nil
}
func sqlCommonAddOrGetFolder(ctx context.Context, baseFolder vfs.BaseVirtualFolder, usedQuotaSize int64, usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
folder, err := sqlCommonCheckFolderExists(ctx, baseFolder.Name, dbHandle)
if _, ok := err.(*RecordNotFoundError); ok {
f := &vfs.BaseVirtualFolder{
Name: baseFolder.Name,
MappedPath: baseFolder.MappedPath,
UsedQuotaSize: usedQuotaSize,
UsedQuotaFiles: usedQuotaFiles,
LastQuotaUpdate: lastQuotaUpdate,
}
err = sqlCommonAddFolder(f, dbHandle)
func sqlCommonAddOrUpdateFolder(ctx context.Context, baseFolder *vfs.BaseVirtualFolder, usedQuotaSize int64,
usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
var folder vfs.BaseVirtualFolder
// FIXME: we could use an UPSERT here, this SELECT could be racy
err := sqlCommonCheckFolderExists(ctx, baseFolder.Name, dbHandle)
switch err {
case nil:
err = sqlCommonUpdateFolder(baseFolder, dbHandle)
if err != nil {
return folder, err
}
return sqlCommonCheckFolderExists(ctx, baseFolder.Name, dbHandle)
case sql.ErrNoRows:
baseFolder.UsedQuotaFiles = usedQuotaFiles
baseFolder.UsedQuotaSize = usedQuotaSize
baseFolder.LastQuotaUpdate = lastQuotaUpdate
err = sqlCommonAddFolder(baseFolder, dbHandle)
if err != nil {
return folder, err
}
default:
return folder, err
}
return folder, err
return sqlCommonGetFolder(ctx, baseFolder.Name, dbHandle)
}
func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
@ -667,6 +694,10 @@ func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) erro
if err != nil {
return err
}
fsConfig, err := json.Marshal(folder.FsConfig)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getAddFolderQuery()
@ -677,15 +708,19 @@ func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) erro
}
defer stmt.Close()
_, err = stmt.ExecContext(ctx, folder.MappedPath, folder.UsedQuotaSize, folder.UsedQuotaFiles,
folder.LastQuotaUpdate, folder.Name, folder.Description)
folder.LastQuotaUpdate, folder.Name, folder.Description, string(fsConfig))
return err
}
func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle *sql.DB) error {
func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
err := ValidateFolder(folder)
if err != nil {
return err
}
fsConfig, err := json.Marshal(folder.FsConfig)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getUpdateFolderQuery()
@ -695,7 +730,7 @@ func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle *sql.DB) erro
return err
}
defer stmt.Close()
_, err = stmt.ExecContext(ctx, folder.MappedPath, folder.Description, folder.Name)
_, err = stmt.ExecContext(ctx, folder.MappedPath, folder.Description, string(fsConfig), folder.Name)
return err
}
@ -731,9 +766,9 @@ func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error)
defer rows.Close()
for rows.Next() {
var folder vfs.BaseVirtualFolder
var mappedPath, description sql.NullString
var mappedPath, description, fsConfig sql.NullString
err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
&folder.LastQuotaUpdate, &folder.Name, &description)
&folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
if err != nil {
return folders, err
}
@ -743,6 +778,13 @@ func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error)
if description.Valid {
folder.Description = description.String
}
if fsConfig.Valid {
var fs vfs.Filesystem
err = json.Unmarshal([]byte(fsConfig.String), &fs)
if err == nil {
folder.FsConfig = fs
}
}
folders = append(folders, folder)
}
err = rows.Err()
@ -771,9 +813,9 @@ func sqlCommonGetFolders(limit, offset int, order string, dbHandle sqlQuerier) (
defer rows.Close()
for rows.Next() {
var folder vfs.BaseVirtualFolder
var mappedPath, description sql.NullString
var mappedPath, description, fsConfig sql.NullString
err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
&folder.LastQuotaUpdate, &folder.Name, &description)
&folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
if err != nil {
return folders, err
}
@ -783,6 +825,14 @@ func sqlCommonGetFolders(limit, offset int, order string, dbHandle sqlQuerier) (
if description.Valid {
folder.Description = description.String
}
if fsConfig.Valid {
var fs vfs.Filesystem
err = json.Unmarshal([]byte(fsConfig.String), &fs)
if err == nil {
folder.FsConfig = fs
}
}
folder.HideConfidentialData()
folders = append(folders, folder)
}
@ -805,7 +855,7 @@ func sqlCommonClearFolderMapping(ctx context.Context, user *User, dbHandle sqlQu
return err
}
func sqlCommonAddFolderMapping(ctx context.Context, user *User, folder vfs.VirtualFolder, dbHandle sqlQuerier) error {
func sqlCommonAddFolderMapping(ctx context.Context, user *User, folder *vfs.VirtualFolder, dbHandle sqlQuerier) error {
q := getAddFolderMappingQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
@ -822,8 +872,9 @@ func generateVirtualFoldersMapping(ctx context.Context, user *User, dbHandle sql
if err != nil {
return err
}
for _, vfolder := range user.VirtualFolders {
f, err := sqlCommonAddOrGetFolder(ctx, vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle)
for idx := range user.VirtualFolders {
vfolder := &user.VirtualFolders[idx]
f, err := sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle)
if err != nil {
return err
}
@ -870,15 +921,26 @@ func getUsersWithVirtualFolders(users []User, dbHandle sqlQuerier) ([]User, erro
for rows.Next() {
var folder vfs.VirtualFolder
var userID int64
var mappedPath sql.NullString
var mappedPath, fsConfig, description sql.NullString
err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
&folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &userID)
&folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &userID, &fsConfig,
&description)
if err != nil {
return users, err
}
if mappedPath.Valid {
folder.MappedPath = mappedPath.String
}
if description.Valid {
folder.Description = description.String
}
if fsConfig.Valid {
var fs vfs.Filesystem
err = json.Unmarshal([]byte(fsConfig.String), &fs)
if err == nil {
folder.FsConfig = fs
}
}
usersVirtualFolders[userID] = append(usersVirtualFolders[userID], folder)
}
err = rows.Err()

View file

@ -95,7 +95,7 @@ func initializeSQLiteProvider(basePath string) error {
if config.ConnectionString == "" {
dbPath := config.Name
if !utils.IsFileInputValid(dbPath) {
return fmt.Errorf("Invalid database path: %#v", dbPath)
return fmt.Errorf("invalid database path: %#v", dbPath)
}
if !filepath.IsAbs(dbPath) {
dbPath = filepath.Join(basePath, dbPath)
@ -278,7 +278,7 @@ func (p *SQLiteProvider) migrateDatabase() error {
sqlDatabaseVersion)
return nil
}
return fmt.Errorf("Database version not handled: %v", version)
return fmt.Errorf("database version not handled: %v", version)
}
}
@ -295,7 +295,7 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error {
case 9:
return downgradeSQLiteDatabaseFromV9(p.dbHandle)
default:
return fmt.Errorf("Database version not handled: %v", dbVersion.Version)
return fmt.Errorf("database version not handled: %v", dbVersion.Version)
}
}

View file

@ -12,7 +12,7 @@ const (
selectUserFields = "id,username,password,public_keys,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions,used_quota_size," +
"used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth,expiration_date,last_login,status,filters,filesystem," +
"additional_info,description"
selectFolderFields = "id,path,used_quota_size,used_quota_files,last_quota_update,name,description"
selectFolderFields = "id,path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem"
selectAdminFields = "id,username,password,status,email,permissions,filters,additional_info,description"
)
@ -119,15 +119,19 @@ func getFolderByNameQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v WHERE name = %v`, selectFolderFields, sqlTableFolders, sqlPlaceholders[0])
}
func checkFolderNameQuery() string {
return fmt.Sprintf(`SELECT name FROM %v WHERE name = %v`, sqlTableFolders, sqlPlaceholders[0])
}
func getAddFolderQuery() string {
return fmt.Sprintf(`INSERT INTO %v (path,used_quota_size,used_quota_files,last_quota_update,name,description) VALUES (%v,%v,%v,%v,%v,%v)`,
sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
sqlPlaceholders[5])
return fmt.Sprintf(`INSERT INTO %v (path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem)
VALUES (%v,%v,%v,%v,%v,%v,%v)`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6])
}
func getUpdateFolderQuery() string {
return fmt.Sprintf(`UPDATE %v SET path=%v,description=%v WHERE name = %v`, sqlTableFolders, sqlPlaceholders[0],
sqlPlaceholders[1], sqlPlaceholders[2])
return fmt.Sprintf(`UPDATE %v SET path=%v,description=%v,filesystem=%v WHERE name = %v`, sqlTableFolders, sqlPlaceholders[0],
sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
func getDeleteFolderQuery() string {
@ -177,9 +181,9 @@ func getRelatedFoldersForUsersQuery(users []User) string {
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT f.id,f.name,f.path,f.used_quota_size,f.used_quota_files,f.last_quota_update,fm.virtual_path,fm.quota_size,fm.quota_files,fm.user_id
FROM %v f INNER JOIN %v fm ON f.id = fm.folder_id WHERE fm.user_id IN %v ORDER BY fm.user_id`, sqlTableFolders,
sqlTableFoldersMapping, sb.String())
return fmt.Sprintf(`SELECT f.id,f.name,f.path,f.used_quota_size,f.used_quota_files,f.last_quota_update,fm.virtual_path,
fm.quota_size,fm.quota_files,fm.user_id,f.filesystem,f.description FROM %v f INNER JOIN %v fm ON f.id = fm.folder_id WHERE
fm.user_id IN %v ORDER BY fm.user_id`, sqlTableFolders, sqlTableFoldersMapping, sb.String())
}
func getRelatedUsersForFoldersQuery(folders []vfs.BaseVirtualFolder) string {

View file

@ -161,29 +161,6 @@ type UserFilters struct {
TLSUsername TLSUsername `json:"tls_username,omitempty"`
}
// FilesystemProvider defines the supported storages
type FilesystemProvider int
// supported values for FilesystemProvider
const (
LocalFilesystemProvider FilesystemProvider = iota // Local
S3FilesystemProvider // AWS S3 compatible
GCSFilesystemProvider // Google Cloud Storage
AzureBlobFilesystemProvider // Azure Blob Storage
CryptedFilesystemProvider // Local encrypted
SFTPFilesystemProvider // SFTP
)
// Filesystem defines cloud storage filesystem details
type Filesystem struct {
Provider FilesystemProvider `json:"provider"`
S3Config vfs.S3FsConfig `json:"s3config,omitempty"`
GCSConfig vfs.GCSFsConfig `json:"gcsconfig,omitempty"`
AzBlobConfig vfs.AzBlobFsConfig `json:"azblobconfig,omitempty"`
CryptConfig vfs.CryptFsConfig `json:"cryptconfig,omitempty"`
SFTPConfig vfs.SFTPFsConfig `json:"sftpconfig,omitempty"`
}
// User defines a SFTPGo user
type User struct {
// Database unique identifier
@ -203,8 +180,7 @@ type User struct {
PublicKeys []string `json:"public_keys,omitempty"`
// The user cannot upload or download files outside this directory. Must be an absolute path
HomeDir string `json:"home_dir"`
// Mapping between virtual paths and filesystem paths outside the home directory.
// Supported for local filesystem only
// Mapping between virtual paths and virtual folders
VirtualFolders []vfs.VirtualFolder `json:"virtual_folders,omitempty"`
// If sftpgo runs as root system user then the created files and directories will be assigned to this system UID
UID int `json:"uid"`
@ -233,49 +209,148 @@ type User struct {
// Additional restrictions
Filters UserFilters `json:"filters"`
// Filesystem configuration details
FsConfig Filesystem `json:"filesystem"`
FsConfig vfs.Filesystem `json:"filesystem"`
// optional description, for example full name
Description string `json:"description,omitempty"`
// free form text field for external systems
AdditionalInfo string `json:"additional_info,omitempty"`
// we store the filesystem here using the base path as key.
fsCache map[string]vfs.Fs `json:"-"`
}
// GetFilesystem returns the filesystem for this user
func (u *User) GetFilesystem(connectionID string) (vfs.Fs, error) {
switch u.FsConfig.Provider {
case S3FilesystemProvider:
return vfs.NewS3Fs(connectionID, u.GetHomeDir(), u.FsConfig.S3Config)
case GCSFilesystemProvider:
config := u.FsConfig.GCSConfig
config.CredentialFile = u.getGCSCredentialsFilePath()
return vfs.NewGCSFs(connectionID, u.GetHomeDir(), config)
case AzureBlobFilesystemProvider:
return vfs.NewAzBlobFs(connectionID, u.GetHomeDir(), u.FsConfig.AzBlobConfig)
case CryptedFilesystemProvider:
return vfs.NewCryptFs(connectionID, u.GetHomeDir(), u.FsConfig.CryptConfig)
case SFTPFilesystemProvider:
return vfs.NewSFTPFs(connectionID, u.FsConfig.SFTPConfig)
default:
return vfs.NewOsFs(connectionID, u.GetHomeDir(), u.VirtualFolders), nil
// GetFilesystem returns the base filesystem for this user
func (u *User) GetFilesystem(connectionID string) (fs vfs.Fs, err error) {
fs, err = u.getRootFs(connectionID)
if err != nil {
return fs, err
}
u.fsCache = make(map[string]vfs.Fs)
u.fsCache["/"] = fs
return fs, err
}
func (u *User) getRootFs(connectionID string) (fs vfs.Fs, err error) {
switch u.FsConfig.Provider {
case vfs.S3FilesystemProvider:
return vfs.NewS3Fs(connectionID, u.GetHomeDir(), "", u.FsConfig.S3Config)
case vfs.GCSFilesystemProvider:
config := u.FsConfig.GCSConfig
config.CredentialFile = u.GetGCSCredentialsFilePath()
return vfs.NewGCSFs(connectionID, u.GetHomeDir(), "", config)
case vfs.AzureBlobFilesystemProvider:
return vfs.NewAzBlobFs(connectionID, u.GetHomeDir(), "", u.FsConfig.AzBlobConfig)
case vfs.CryptedFilesystemProvider:
return vfs.NewCryptFs(connectionID, u.GetHomeDir(), "", u.FsConfig.CryptConfig)
case vfs.SFTPFilesystemProvider:
return vfs.NewSFTPFs(connectionID, "", u.FsConfig.SFTPConfig)
default:
return vfs.NewOsFs(connectionID, u.GetHomeDir(), ""), nil
}
}
// CheckFsRoot check the root directory for the main fs and the virtual folders.
// It returns an error if the main filesystem cannot be created
func (u *User) CheckFsRoot(connectionID string) error {
fs, err := u.GetFilesystemForPath("/", connectionID)
if err != nil {
logger.Warn(logSender, connectionID, "could not create main filesystem for user %#v err: %v", u.Username, err)
return err
}
fs.CheckRootPath(u.Username, u.GetUID(), u.GetGID())
for idx := range u.VirtualFolders {
v := &u.VirtualFolders[idx]
fs, err = u.GetFilesystemForPath(v.VirtualPath, connectionID)
if err == nil {
fs.CheckRootPath(u.Username, u.GetUID(), u.GetGID())
}
// now check intermediary folders
fs, err = u.GetFilesystemForPath(path.Dir(v.VirtualPath), connectionID)
if err == nil && !fs.HasVirtualFolders() {
fsPath, err := fs.ResolvePath(v.VirtualPath)
if err != nil {
continue
}
err = fs.MkdirAll(fsPath, u.GetUID(), u.GetGID())
logger.Debug(logSender, connectionID, "create intermediary dir to %#v, path %#v, err: %v",
v.VirtualPath, fsPath, err)
}
}
return nil
}
// HideConfidentialData hides user confidential data
func (u *User) HideConfidentialData() {
u.Password = ""
switch u.FsConfig.Provider {
case S3FilesystemProvider:
case vfs.S3FilesystemProvider:
u.FsConfig.S3Config.AccessSecret.Hide()
case GCSFilesystemProvider:
case vfs.GCSFilesystemProvider:
u.FsConfig.GCSConfig.Credentials.Hide()
case AzureBlobFilesystemProvider:
case vfs.AzureBlobFilesystemProvider:
u.FsConfig.AzBlobConfig.AccountKey.Hide()
case CryptedFilesystemProvider:
case vfs.CryptedFilesystemProvider:
u.FsConfig.CryptConfig.Passphrase.Hide()
case SFTPFilesystemProvider:
case vfs.SFTPFilesystemProvider:
u.FsConfig.SFTPConfig.Password.Hide()
u.FsConfig.SFTPConfig.PrivateKey.Hide()
}
for idx := range u.VirtualFolders {
folder := &u.VirtualFolders[idx]
folder.HideConfidentialData()
}
}
func (u *User) hasRedactedSecret() bool {
switch u.FsConfig.Provider {
case vfs.S3FilesystemProvider:
if u.FsConfig.S3Config.AccessSecret.IsRedacted() {
return true
}
case vfs.GCSFilesystemProvider:
if u.FsConfig.GCSConfig.Credentials.IsRedacted() {
return true
}
case vfs.AzureBlobFilesystemProvider:
if u.FsConfig.AzBlobConfig.AccountKey.IsRedacted() {
return true
}
case vfs.CryptedFilesystemProvider:
if u.FsConfig.CryptConfig.Passphrase.IsRedacted() {
return true
}
case vfs.SFTPFilesystemProvider:
if u.FsConfig.SFTPConfig.Password.IsRedacted() {
return true
}
if u.FsConfig.SFTPConfig.PrivateKey.IsRedacted() {
return true
}
}
for idx := range u.VirtualFolders {
folder := &u.VirtualFolders[idx]
if folder.HasRedactedSecret() {
return true
}
}
return false
}
// CloseFs closes the underlying filesystems
func (u *User) CloseFs() error {
if u.fsCache == nil {
return nil
}
var err error
for _, fs := range u.fsCache {
errClose := fs.Close()
if err == nil {
err = errClose
}
}
return err
}
// IsPasswordHashed returns true if the password is hashed
@ -300,41 +375,10 @@ func (u *User) SetEmptySecrets() {
u.FsConfig.CryptConfig.Passphrase = kms.NewEmptySecret()
u.FsConfig.SFTPConfig.Password = kms.NewEmptySecret()
u.FsConfig.SFTPConfig.PrivateKey = kms.NewEmptySecret()
}
// DecryptSecrets tries to decrypts kms secrets
func (u *User) DecryptSecrets() error {
switch u.FsConfig.Provider {
case S3FilesystemProvider:
if u.FsConfig.S3Config.AccessSecret.IsEncrypted() {
return u.FsConfig.S3Config.AccessSecret.Decrypt()
}
case GCSFilesystemProvider:
if u.FsConfig.GCSConfig.Credentials.IsEncrypted() {
return u.FsConfig.GCSConfig.Credentials.Decrypt()
}
case AzureBlobFilesystemProvider:
if u.FsConfig.AzBlobConfig.AccountKey.IsEncrypted() {
return u.FsConfig.AzBlobConfig.AccountKey.Decrypt()
}
case CryptedFilesystemProvider:
if u.FsConfig.CryptConfig.Passphrase.IsEncrypted() {
return u.FsConfig.CryptConfig.Passphrase.Decrypt()
}
case SFTPFilesystemProvider:
if u.FsConfig.SFTPConfig.Password.IsEncrypted() {
if err := u.FsConfig.SFTPConfig.Password.Decrypt(); err != nil {
return err
}
}
if u.FsConfig.SFTPConfig.PrivateKey.IsEncrypted() {
if err := u.FsConfig.SFTPConfig.PrivateKey.Decrypt(); err != nil {
return err
}
}
for idx := range u.VirtualFolders {
folder := &u.VirtualFolders[idx]
folder.FsConfig.SetEmptySecretsIfNil()
}
return nil
}
// GetPermissionsForPath returns the permissions for the given path.
@ -349,13 +393,13 @@ func (u *User) GetPermissionsForPath(p string) []string {
// fallback permissions
permissions = perms
}
dirsForPath := utils.GetDirsForSFTPPath(p)
dirsForPath := utils.GetDirsForVirtualPath(p)
// dirsForPath contains all the dirs for a given path in reverse order
// for example if the path is: /1/2/3/4 it contains:
// [ "/1/2/3/4", "/1/2/3", "/1/2", "/1", "/" ]
// so the first match is the one we are interested to
for _, val := range dirsForPath {
if perms, ok := u.Permissions[val]; ok {
for idx := range dirsForPath {
if perms, ok := u.Permissions[dirsForPath[idx]]; ok {
permissions = perms
break
}
@ -363,44 +407,120 @@ func (u *User) GetPermissionsForPath(p string) []string {
return permissions
}
// GetVirtualFolderForPath returns the virtual folder containing the specified sftp path.
// GetFilesystemForPath returns the filesystem for the given path
func (u *User) GetFilesystemForPath(virtualPath, connectionID string) (vfs.Fs, error) {
if u.fsCache == nil {
u.fsCache = make(map[string]vfs.Fs)
}
if virtualPath != "" && virtualPath != "/" && len(u.VirtualFolders) > 0 {
folder, err := u.GetVirtualFolderForPath(virtualPath)
if err == nil {
if fs, ok := u.fsCache[folder.VirtualPath]; ok {
return fs, nil
}
fs, err := folder.GetFilesystem(connectionID)
if err == nil {
u.fsCache[folder.VirtualPath] = fs
}
return fs, err
}
}
if val, ok := u.fsCache["/"]; ok {
return val, nil
}
return u.GetFilesystem(connectionID)
}
// GetVirtualFolderForPath returns the virtual folder containing the specified virtual path.
// If the path is not inside a virtual folder an error is returned
func (u *User) GetVirtualFolderForPath(sftpPath string) (vfs.VirtualFolder, error) {
func (u *User) GetVirtualFolderForPath(virtualPath string) (vfs.VirtualFolder, error) {
var folder vfs.VirtualFolder
if len(u.VirtualFolders) == 0 || u.FsConfig.Provider != LocalFilesystemProvider {
if len(u.VirtualFolders) == 0 {
return folder, errNoMatchingVirtualFolder
}
dirsForPath := utils.GetDirsForSFTPPath(sftpPath)
for _, val := range dirsForPath {
for _, v := range u.VirtualFolders {
if v.VirtualPath == val {
return v, nil
dirsForPath := utils.GetDirsForVirtualPath(virtualPath)
for index := range dirsForPath {
for idx := range u.VirtualFolders {
v := &u.VirtualFolders[idx]
if v.VirtualPath == dirsForPath[index] {
return *v, nil
}
}
}
return folder, errNoMatchingVirtualFolder
}
// ScanQuota scans the user home dir and virtual folders, included in its quota,
// and returns the number of files and their size
func (u *User) ScanQuota() (int, int64, error) {
fs, err := u.getRootFs("")
if err != nil {
return 0, 0, err
}
defer fs.Close()
numFiles, size, err := fs.ScanRootDirContents()
if err != nil {
return numFiles, size, err
}
for idx := range u.VirtualFolders {
v := &u.VirtualFolders[idx]
if !v.IsIncludedInUserQuota() {
continue
}
num, s, err := v.ScanQuota()
if err != nil {
return numFiles, size, err
}
numFiles += num
size += s
}
return numFiles, size, nil
}
// GetVirtualFoldersInPath returns the virtual folders inside virtualPath including
// any parents
func (u *User) GetVirtualFoldersInPath(virtualPath string) map[string]bool {
result := make(map[string]bool)
for idx := range u.VirtualFolders {
v := &u.VirtualFolders[idx]
dirsForPath := utils.GetDirsForVirtualPath(v.VirtualPath)
for index := range dirsForPath {
d := dirsForPath[index]
if d == "/" {
continue
}
if path.Dir(d) == virtualPath {
result[d] = true
}
}
}
return result
}
// AddVirtualDirs adds virtual folders, if defined, to the given files list
func (u *User) AddVirtualDirs(list []os.FileInfo, sftpPath string) []os.FileInfo {
func (u *User) AddVirtualDirs(list []os.FileInfo, virtualPath string) []os.FileInfo {
if len(u.VirtualFolders) == 0 {
return list
}
for _, v := range u.VirtualFolders {
if path.Dir(v.VirtualPath) == sftpPath {
fi := vfs.NewFileInfo(v.VirtualPath, true, 0, time.Now(), false)
found := false
for index, f := range list {
if f.Name() == fi.Name() {
list[index] = fi
found = true
break
}
}
if !found {
list = append(list, fi)
for dir := range u.GetVirtualFoldersInPath(virtualPath) {
fi := vfs.NewFileInfo(dir, true, 0, time.Now(), false)
found := false
for index := range list {
if list[index].Name() == fi.Name() {
list[index] = fi
found = true
break
}
}
if !found {
list = append(list, fi)
}
}
return list
}
@ -408,7 +528,8 @@ func (u *User) AddVirtualDirs(list []os.FileInfo, sftpPath string) []os.FileInfo
// IsMappedPath returns true if the specified filesystem path has a virtual folder mapping.
// The filesystem path must be cleaned before calling this method
func (u *User) IsMappedPath(fsPath string) bool {
for _, v := range u.VirtualFolders {
for idx := range u.VirtualFolders {
v := &u.VirtualFolders[idx]
if fsPath == v.MappedPath {
return true
}
@ -416,10 +537,11 @@ func (u *User) IsMappedPath(fsPath string) bool {
return false
}
// IsVirtualFolder returns true if the specified sftp path is a virtual folder
func (u *User) IsVirtualFolder(sftpPath string) bool {
for _, v := range u.VirtualFolders {
if sftpPath == v.VirtualPath {
// IsVirtualFolder returns true if the specified virtual path is a virtual folder
func (u *User) IsVirtualFolder(virtualPath string) bool {
for idx := range u.VirtualFolders {
v := &u.VirtualFolders[idx]
if virtualPath == v.VirtualPath {
return true
}
}
@ -427,14 +549,15 @@ func (u *User) IsVirtualFolder(sftpPath string) bool {
}
// HasVirtualFoldersInside returns true if there are virtual folders inside the
// specified SFTP path. We assume that path are cleaned
func (u *User) HasVirtualFoldersInside(sftpPath string) bool {
if sftpPath == "/" && len(u.VirtualFolders) > 0 {
// specified virtual path. We assume that path are cleaned
func (u *User) HasVirtualFoldersInside(virtualPath string) bool {
if virtualPath == "/" && len(u.VirtualFolders) > 0 {
return true
}
for _, v := range u.VirtualFolders {
if len(v.VirtualPath) > len(sftpPath) {
if strings.HasPrefix(v.VirtualPath, sftpPath+"/") {
for idx := range u.VirtualFolders {
v := &u.VirtualFolders[idx]
if len(v.VirtualPath) > len(virtualPath) {
if strings.HasPrefix(v.VirtualPath, virtualPath+"/") {
return true
}
}
@ -442,32 +565,14 @@ func (u *User) HasVirtualFoldersInside(sftpPath string) bool {
return false
}
// HasPermissionsInside returns true if the specified sftpPath has no permissions itself and
// HasPermissionsInside returns true if the specified virtualPath has no permissions itself and
// no subdirs with defined permissions
func (u *User) HasPermissionsInside(sftpPath string) bool {
func (u *User) HasPermissionsInside(virtualPath string) bool {
for dir := range u.Permissions {
if dir == sftpPath {
if dir == virtualPath {
return true
} else if len(dir) > len(sftpPath) {
if strings.HasPrefix(dir, sftpPath+"/") {
return true
}
}
}
return false
}
// HasOverlappedMappedPaths returns true if this user has virtual folders with overlapped mapped paths
func (u *User) HasOverlappedMappedPaths() bool {
if len(u.VirtualFolders) <= 1 {
return false
}
for _, v1 := range u.VirtualFolders {
for _, v2 := range u.VirtualFolders {
if v1.VirtualPath == v2.VirtualPath {
continue
}
if isMappedDirOverlapped(v1.MappedPath, v2.MappedPath) {
} else if len(dir) > len(virtualPath) {
if strings.HasPrefix(dir, virtualPath+"/") {
return true
}
}
@ -585,7 +690,7 @@ func (u *User) isFileExtensionAllowed(virtualPath string) bool {
if len(u.Filters.FileExtensions) == 0 {
return true
}
dirsForPath := utils.GetDirsForSFTPPath(path.Dir(virtualPath))
dirsForPath := utils.GetDirsForVirtualPath(path.Dir(virtualPath))
var filter ExtensionsFilter
for _, dir := range dirsForPath {
for _, f := range u.Filters.FileExtensions {
@ -619,7 +724,7 @@ func (u *User) isFilePatternAllowed(virtualPath string) bool {
if len(u.Filters.FilePatterns) == 0 {
return true
}
dirsForPath := utils.GetDirsForSFTPPath(path.Dir(virtualPath))
dirsForPath := utils.GetDirsForVirtualPath(path.Dir(virtualPath))
var filter PatternsFilter
for _, dir := range dirsForPath {
for _, f := range u.Filters.FilePatterns {
@ -797,15 +902,15 @@ func (u *User) GetInfoString() string {
result += fmt.Sprintf("Last login: %v ", t.Format("2006-01-02 15:04:05")) // YYYY-MM-DD HH:MM:SS
}
switch u.FsConfig.Provider {
case S3FilesystemProvider:
case vfs.S3FilesystemProvider:
result += "Storage: S3 "
case GCSFilesystemProvider:
case vfs.GCSFilesystemProvider:
result += "Storage: GCS "
case AzureBlobFilesystemProvider:
case vfs.AzureBlobFilesystemProvider:
result += "Storage: Azure "
case CryptedFilesystemProvider:
case vfs.CryptedFilesystemProvider:
result += "Storage: Encrypted "
case SFTPFilesystemProvider:
case vfs.SFTPFilesystemProvider:
result += "Storage: SFTP "
}
if len(u.PublicKeys) > 0 {
@ -850,23 +955,10 @@ func (u *User) GetDeniedIPAsString() string {
// SetEmptySecretsIfNil sets the secrets to empty if nil
func (u *User) SetEmptySecretsIfNil() {
if u.FsConfig.S3Config.AccessSecret == nil {
u.FsConfig.S3Config.AccessSecret = kms.NewEmptySecret()
}
if u.FsConfig.GCSConfig.Credentials == nil {
u.FsConfig.GCSConfig.Credentials = kms.NewEmptySecret()
}
if u.FsConfig.AzBlobConfig.AccountKey == nil {
u.FsConfig.AzBlobConfig.AccountKey = kms.NewEmptySecret()
}
if u.FsConfig.CryptConfig.Passphrase == nil {
u.FsConfig.CryptConfig.Passphrase = kms.NewEmptySecret()
}
if u.FsConfig.SFTPConfig.Password == nil {
u.FsConfig.SFTPConfig.Password = kms.NewEmptySecret()
}
if u.FsConfig.SFTPConfig.PrivateKey == nil {
u.FsConfig.SFTPConfig.PrivateKey = kms.NewEmptySecret()
u.FsConfig.SetEmptySecretsIfNil()
for idx := range u.VirtualFolders {
vfolder := &u.VirtualFolders[idx]
vfolder.FsConfig.SetEmptySecretsIfNil()
}
}
@ -874,8 +966,11 @@ func (u *User) getACopy() User {
u.SetEmptySecretsIfNil()
pubKeys := make([]string, len(u.PublicKeys))
copy(pubKeys, u.PublicKeys)
virtualFolders := make([]vfs.VirtualFolder, len(u.VirtualFolders))
copy(virtualFolders, u.VirtualFolders)
virtualFolders := make([]vfs.VirtualFolder, 0, len(u.VirtualFolders))
for idx := range u.VirtualFolders {
vfolder := u.VirtualFolders[idx].GetACopy()
virtualFolders = append(virtualFolders, vfolder)
}
permissions := make(map[string][]string)
for k, v := range u.Permissions {
perms := make([]string, len(v))
@ -897,55 +992,6 @@ func (u *User) getACopy() User {
copy(filters.FilePatterns, u.Filters.FilePatterns)
filters.DeniedProtocols = make([]string, len(u.Filters.DeniedProtocols))
copy(filters.DeniedProtocols, u.Filters.DeniedProtocols)
fsConfig := Filesystem{
Provider: u.FsConfig.Provider,
S3Config: vfs.S3FsConfig{
Bucket: u.FsConfig.S3Config.Bucket,
Region: u.FsConfig.S3Config.Region,
AccessKey: u.FsConfig.S3Config.AccessKey,
AccessSecret: u.FsConfig.S3Config.AccessSecret.Clone(),
Endpoint: u.FsConfig.S3Config.Endpoint,
StorageClass: u.FsConfig.S3Config.StorageClass,
KeyPrefix: u.FsConfig.S3Config.KeyPrefix,
UploadPartSize: u.FsConfig.S3Config.UploadPartSize,
UploadConcurrency: u.FsConfig.S3Config.UploadConcurrency,
},
GCSConfig: vfs.GCSFsConfig{
Bucket: u.FsConfig.GCSConfig.Bucket,
CredentialFile: u.FsConfig.GCSConfig.CredentialFile,
Credentials: u.FsConfig.GCSConfig.Credentials.Clone(),
AutomaticCredentials: u.FsConfig.GCSConfig.AutomaticCredentials,
StorageClass: u.FsConfig.GCSConfig.StorageClass,
KeyPrefix: u.FsConfig.GCSConfig.KeyPrefix,
},
AzBlobConfig: vfs.AzBlobFsConfig{
Container: u.FsConfig.AzBlobConfig.Container,
AccountName: u.FsConfig.AzBlobConfig.AccountName,
AccountKey: u.FsConfig.AzBlobConfig.AccountKey.Clone(),
Endpoint: u.FsConfig.AzBlobConfig.Endpoint,
SASURL: u.FsConfig.AzBlobConfig.SASURL,
KeyPrefix: u.FsConfig.AzBlobConfig.KeyPrefix,
UploadPartSize: u.FsConfig.AzBlobConfig.UploadPartSize,
UploadConcurrency: u.FsConfig.AzBlobConfig.UploadConcurrency,
UseEmulator: u.FsConfig.AzBlobConfig.UseEmulator,
AccessTier: u.FsConfig.AzBlobConfig.AccessTier,
},
CryptConfig: vfs.CryptFsConfig{
Passphrase: u.FsConfig.CryptConfig.Passphrase.Clone(),
},
SFTPConfig: vfs.SFTPFsConfig{
Endpoint: u.FsConfig.SFTPConfig.Endpoint,
Username: u.FsConfig.SFTPConfig.Username,
Password: u.FsConfig.SFTPConfig.Password.Clone(),
PrivateKey: u.FsConfig.SFTPConfig.PrivateKey.Clone(),
Prefix: u.FsConfig.SFTPConfig.Prefix,
DisableCouncurrentReads: u.FsConfig.SFTPConfig.DisableCouncurrentReads,
},
}
if len(u.FsConfig.SFTPConfig.Fingerprints) > 0 {
fsConfig.SFTPConfig.Fingerprints = make([]string, len(u.FsConfig.SFTPConfig.Fingerprints))
copy(fsConfig.SFTPConfig.Fingerprints, u.FsConfig.SFTPConfig.Fingerprints)
}
return User{
ID: u.ID,
@ -969,7 +1015,7 @@ func (u *User) getACopy() User {
ExpirationDate: u.ExpirationDate,
LastLogin: u.LastLogin,
Filters: filters,
FsConfig: fsConfig,
FsConfig: u.FsConfig.GetACopy(),
AdditionalInfo: u.AdditionalInfo,
Description: u.Description,
}
@ -986,6 +1032,12 @@ func (u *User) getNotificationFieldsAsSlice(action string) []string {
}
}
func (u *User) getGCSCredentialsFilePath() string {
// GetEncrytionAdditionalData returns the additional data to use for AEAD
func (u *User) GetEncrytionAdditionalData() string {
return u.Username
}
// GetGCSCredentialsFilePath returns the path for GCS credentials
func (u *User) GetGCSCredentialsFilePath() string {
return filepath.Join(credentialsDirPath, fmt.Sprintf("%v_gcs_credentials.json", u.Username))
}

View file

@ -36,7 +36,7 @@ SFTPGo supports the following built-in SSH commands:
- `md5sum`, `sha1sum`, `sha256sum`, `sha384sum`, `sha512sum`. Useful to check message digests for uploaded files.
- `cd`, `pwd`. Some SFTP clients do not support the SFTP SSH_FXP_REALPATH packet type, so they use `cd` and `pwd` SSH commands to get the initial directory. Currently `cd` does nothing and `pwd` always returns the `/` path. These commands will work with any storage backend but keep in mind that to calculate the hash we need to read the whole file, for remote backends this means downloading the file, for the encrypted backend this means decrypting the file.
- `sftpgo-copy`. This is a built-in copy implementation. It allows server side copy for files and directories. The first argument is the source file/directory and the second one is the destination file/directory, for example `sftpgo-copy <src> <dst>`. The command will fail if the destination exists. Copy for directories spanning virtual folders is not supported. Only local filesystem is supported: recursive copy for Cloud Storage filesystems requires a new request for every file in any case, so a real server side copy is not possible.
- `sftpgo-remove`. This is a built-in remove implementation. It allows to remove single files and to recursively remove directories. The first argument is the file/directory to remove, for example `sftpgo-remove <dst>`. Only local filesystem is supported: recursive remove for Cloud Storage filesystems requires a new request for every file in any case, so a server side remove is not possible.
- `sftpgo-remove`. This is a built-in remove implementation. It allows to remove single files and to recursively remove directories. The first argument is the file/directory to remove, for example `sftpgo-remove <dst>`. Only local and encrypted filesystems are supported: recursive remove for Cloud Storage filesystems requires a new request for every file in any case, so a server side remove is not possible.
The following SSH commands are enabled by default:

View file

@ -15,6 +15,7 @@ import (
"github.com/drakkan/sftpgo/dataprovider"
"github.com/drakkan/sftpgo/httpdtest"
"github.com/drakkan/sftpgo/kms"
"github.com/drakkan/sftpgo/vfs"
)
func TestBasicFTPHandlingCryptFs(t *testing.T) {
@ -214,7 +215,7 @@ func TestResumeCryptFs(t *testing.T) {
func getTestUserWithCryptFs() dataprovider.User {
user := getTestUser()
user.FsConfig.Provider = dataprovider.CryptedFilesystemProvider
user.FsConfig.Provider = vfs.CryptedFilesystemProvider
user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("testPassphrase")
return user
}

View file

@ -245,6 +245,10 @@ func TestMain(m *testing.M) {
if err != nil {
logger.ErrorToConsole("error creating banner file: %v", err)
}
// we run the test cases with UploadMode atomic and resume support. The non atomic code path
// simply does not execute some code so if it works in atomic mode will
// work in non atomic mode too
os.Setenv("SFTPGO_COMMON__UPLOAD_MODE", "2")
err = config.LoadConfig(configDir, "")
if err != nil {
logger.ErrorToConsole("error loading configuration: %v", err)
@ -254,10 +258,6 @@ func TestMain(m *testing.M) {
logger.InfoToConsole("Starting FTPD tests, provider: %v", providerConf.Driver)
commonConf := config.GetCommonConfig()
// we run the test cases with UploadMode atomic and resume support. The non atomic code path
// simply does not execute some code so if it works in atomic mode will
// work in non atomic mode too
commonConf.UploadMode = 2
homeBasePath = os.TempDir()
if runtime.GOOS != osWindows {
commonConf.Actions.ExecuteOn = []string{"download", "upload", "rename", "delete"}
@ -1274,7 +1274,7 @@ func TestLoginWithIPilters(t *testing.T) {
func TestLoginWithDatabaseCredentials(t *testing.T) {
u := getTestUser()
u.FsConfig.Provider = dataprovider.GCSFilesystemProvider
u.FsConfig.Provider = vfs.GCSFilesystemProvider
u.FsConfig.GCSConfig.Bucket = "test"
u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(`{ "type": "service_account" }`)
@ -1323,7 +1323,7 @@ func TestLoginWithDatabaseCredentials(t *testing.T) {
func TestLoginInvalidFs(t *testing.T) {
u := getTestUser()
u.FsConfig.Provider = dataprovider.GCSFilesystemProvider
u.FsConfig.Provider = vfs.GCSFilesystemProvider
u.FsConfig.GCSConfig.Bucket = "test"
u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials")
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
@ -2153,7 +2153,7 @@ func TestClientCertificateAuth(t *testing.T) {
// TLS username is not enabled, mutual TLS should fail
_, err = getFTPClient(user, true, tlsConfig)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "Login method password is not allowed")
assert.Contains(t, err.Error(), "login method password is not allowed")
}
user.Filters.TLSUsername = dataprovider.TLSUsernameCN
@ -2186,7 +2186,7 @@ func TestClientCertificateAuth(t *testing.T) {
assert.NoError(t, err)
_, err = getFTPClient(user, true, tlsConfig)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "Login method TLSCertificate+password is not allowed")
assert.Contains(t, err.Error(), "login method TLSCertificate+password is not allowed")
}
// disable FTP protocol
@ -2196,7 +2196,7 @@ func TestClientCertificateAuth(t *testing.T) {
assert.NoError(t, err)
_, err = getFTPClient(user, true, tlsConfig)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "Protocol FTP is not allowed")
assert.Contains(t, err.Error(), "protocol FTP is not allowed")
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
@ -2238,7 +2238,7 @@ func TestClientCertificateAndPwdAuth(t *testing.T) {
_, err = getFTPClient(user, true, nil)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "Login method password is not allowed")
assert.Contains(t, err.Error(), "login method password is not allowed")
}
user.Password = defaultPassword + "1"
_, err = getFTPClient(user, true, tlsConfig)
@ -2405,6 +2405,111 @@ func TestPreLoginHookWithClientCert(t *testing.T) {
assert.NoError(t, err)
}
func TestNestedVirtualFolders(t *testing.T) {
u := getTestUser()
localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
u = getTestSFTPUser()
mappedPathCrypt := filepath.Join(os.TempDir(), "crypt")
folderNameCrypt := filepath.Base(mappedPathCrypt)
vdirCryptPath := "/vdir/crypt"
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderNameCrypt,
FsConfig: vfs.Filesystem{
Provider: vfs.CryptedFilesystemProvider,
CryptConfig: vfs.CryptFsConfig{
Passphrase: kms.NewPlainSecret(defaultPassword),
},
},
MappedPath: mappedPathCrypt,
},
VirtualPath: vdirCryptPath,
})
mappedPath := filepath.Join(os.TempDir(), "local")
folderName := filepath.Base(mappedPath)
vdirPath := "/vdir/local"
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderName,
MappedPath: mappedPath,
},
VirtualPath: vdirPath,
})
mappedPathNested := filepath.Join(os.TempDir(), "nested")
folderNameNested := filepath.Base(mappedPathNested)
vdirNestedPath := "/vdir/crypt/nested"
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderNameNested,
MappedPath: mappedPathNested,
},
VirtualPath: vdirNestedPath,
QuotaFiles: -1,
QuotaSize: -1,
})
sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getFTPClient(sftpUser, false, nil)
if assert.NoError(t, err) {
err = checkBasicFTP(client)
assert.NoError(t, err)
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(65535)
err = createTestFile(testFilePath, testFileSize)
assert.NoError(t, err)
localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
assert.NoError(t, err)
err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0)
assert.NoError(t, err)
err = ftpUploadFile(testFilePath, path.Join("/vdir", testFileName), testFileSize, client, 0)
assert.NoError(t, err)
err = ftpDownloadFile(path.Join("/vdir", testFileName), localDownloadPath, testFileSize, client, 0)
assert.NoError(t, err)
err = ftpUploadFile(testFilePath, path.Join(vdirPath, testFileName), testFileSize, client, 0)
assert.NoError(t, err)
err = ftpDownloadFile(path.Join(vdirPath, testFileName), localDownloadPath, testFileSize, client, 0)
assert.NoError(t, err)
err = ftpUploadFile(testFilePath, path.Join(vdirCryptPath, testFileName), testFileSize, client, 0)
assert.NoError(t, err)
err = ftpDownloadFile(path.Join(vdirCryptPath, testFileName), localDownloadPath, testFileSize, client, 0)
assert.NoError(t, err)
err = ftpUploadFile(testFilePath, path.Join(vdirNestedPath, testFileName), testFileSize, client, 0)
assert.NoError(t, err)
err = ftpDownloadFile(path.Join(vdirNestedPath, testFileName), localDownloadPath, testFileSize, client, 0)
assert.NoError(t, err)
err = client.Quit()
assert.NoError(t, err)
err = os.Remove(testFilePath)
assert.NoError(t, err)
err = os.Remove(localDownloadPath)
assert.NoError(t, err)
}
_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameNested}, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(mappedPathCrypt)
assert.NoError(t, err)
err = os.RemoveAll(mappedPath)
assert.NoError(t, err)
err = os.RemoveAll(mappedPathNested)
assert.NoError(t, err)
err = os.RemoveAll(localUser.GetHomeDir())
assert.NoError(t, err)
assert.Eventually(t, func() bool { return len(common.Connections.GetStats()) == 0 }, 1*time.Second, 50*time.Millisecond)
}
func checkBasicFTP(client *ftp.ServerConn) error {
_, err := client.CurrentDir()
if err != nil {
@ -2562,7 +2667,7 @@ func getTestUser() dataprovider.User {
func getTestSFTPUser() dataprovider.User {
u := getTestUser()
u.Username = u.Username + "_sftp"
u.FsConfig.Provider = dataprovider.SFTPFilesystemProvider
u.FsConfig.Provider = vfs.SFTPFilesystemProvider
u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr
u.FsConfig.SFTPConfig.Username = defaultUsername
u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword)

View file

@ -17,7 +17,7 @@ import (
)
var (
errNotImplemented = errors.New("Not implemented")
errNotImplemented = errors.New("not implemented")
errCOMBNotSupported = errors.New("COMB is not supported for this filesystem")
)
@ -63,11 +63,7 @@ func (c *Connection) Create(name string) (afero.File, error) {
func (c *Connection) Mkdir(name string, perm os.FileMode) error {
c.UpdateLastActivity()
p, err := c.Fs.ResolvePath(name)
if err != nil {
return c.GetFsError(err)
}
return c.CreateDir(p, name)
return c.CreateDir(name)
}
// MkdirAll is not implemented, we don't need it
@ -90,22 +86,22 @@ func (c *Connection) OpenFile(name string, flag int, perm os.FileMode) (afero.Fi
func (c *Connection) Remove(name string) error {
c.UpdateLastActivity()
p, err := c.Fs.ResolvePath(name)
fs, p, err := c.GetFsAndResolvedPath(name)
if err != nil {
return c.GetFsError(err)
return err
}
var fi os.FileInfo
if fi, err = c.Fs.Lstat(p); err != nil {
if fi, err = fs.Lstat(p); err != nil {
c.Log(logger.LevelWarn, "failed to remove a file %#v: stat error: %+v", p, err)
return c.GetFsError(err)
return c.GetFsError(fs, err)
}
if fi.IsDir() && fi.Mode()&os.ModeSymlink == 0 {
c.Log(logger.LevelDebug, "cannot remove %#v is not a file/symlink", p)
return c.GetGenericError(nil)
}
return c.RemoveFile(p, name, fi)
return c.RemoveFile(fs, p, name, fi)
}
// RemoveAll is not implemented, we don't need it
@ -117,20 +113,10 @@ func (c *Connection) RemoveAll(path string) error {
func (c *Connection) Rename(oldname, newname string) error {
c.UpdateLastActivity()
p, err := c.Fs.ResolvePath(oldname)
if err != nil {
return c.GetFsError(err)
}
t, err := c.Fs.ResolvePath(newname)
if err != nil {
return c.GetFsError(err)
}
if err = c.BaseConnection.Rename(p, t, oldname, newname); err != nil {
if err := c.BaseConnection.Rename(oldname, newname); err != nil {
return err
}
vfs.SetPathPermissions(c.Fs, t, c.User.GetUID(), c.User.GetGID())
return nil
}
@ -143,14 +129,10 @@ func (c *Connection) Stat(name string) (os.FileInfo, error) {
return nil, c.GetPermissionDeniedError()
}
p, err := c.Fs.ResolvePath(name)
fi, err := c.DoStat(name, 0)
if err != nil {
return nil, c.GetFsError(err)
}
fi, err := c.DoStat(p, 0)
if err != nil {
c.Log(logger.LevelDebug, "error running stat on path %#v: %+v", p, err)
return nil, c.GetFsError(err)
c.Log(logger.LevelDebug, "error running stat on path %#v: %+v", name, err)
return nil, err
}
return fi, nil
}
@ -182,31 +164,23 @@ func (c *Connection) Chown(name string, uid, gid int) error {
func (c *Connection) Chmod(name string, mode os.FileMode) error {
c.UpdateLastActivity()
p, err := c.Fs.ResolvePath(name)
if err != nil {
return c.GetFsError(err)
}
attrs := common.StatAttributes{
Flags: common.StatAttrPerms,
Mode: mode,
}
return c.SetStat(p, name, &attrs)
return c.SetStat(name, &attrs)
}
// Chtimes changes the access and modification times of the named file
func (c *Connection) Chtimes(name string, atime time.Time, mtime time.Time) error {
c.UpdateLastActivity()
p, err := c.Fs.ResolvePath(name)
if err != nil {
return c.GetFsError(err)
}
attrs := common.StatAttributes{
Flags: common.StatAttrTimes,
Atime: atime,
Mtime: mtime,
}
return c.SetStat(p, name, &attrs)
return c.SetStat(name, &attrs)
}
// GetAvailableSpace implements ClientDriverExtensionAvailableSpace interface
@ -224,14 +198,14 @@ func (c *Connection) GetAvailableSpace(dirName string) (int64, error) {
return c.User.Filters.MaxUploadFileSize, nil
}
p, err := c.Fs.ResolvePath(dirName)
fs, p, err := c.GetFsAndResolvedPath(dirName)
if err != nil {
return 0, c.GetFsError(err)
return 0, err
}
statVFS, err := c.Fs.GetAvailableDiskSize(p)
statVFS, err := fs.GetAvailableDiskSize(p)
if err != nil {
return 0, c.GetFsError(err)
return 0, c.GetFsError(fs, err)
}
return int64(statVFS.FreeSpace()), nil
}
@ -281,61 +255,43 @@ func (c *Connection) AllocateSpace(size int) error {
func (c *Connection) RemoveDir(name string) error {
c.UpdateLastActivity()
p, err := c.Fs.ResolvePath(name)
if err != nil {
return c.GetFsError(err)
}
return c.BaseConnection.RemoveDir(p, name)
return c.BaseConnection.RemoveDir(name)
}
// Symlink implements ClientDriverExtensionSymlink
func (c *Connection) Symlink(oldname, newname string) error {
c.UpdateLastActivity()
p, err := c.Fs.ResolvePath(oldname)
if err != nil {
return c.GetFsError(err)
}
t, err := c.Fs.ResolvePath(newname)
if err != nil {
return c.GetFsError(err)
}
return c.BaseConnection.CreateSymlink(p, t, oldname, newname)
return c.BaseConnection.CreateSymlink(oldname, newname)
}
// ReadDir implements ClientDriverExtensionFilelist
func (c *Connection) ReadDir(name string) ([]os.FileInfo, error) {
c.UpdateLastActivity()
p, err := c.Fs.ResolvePath(name)
if err != nil {
return nil, c.GetFsError(err)
}
return c.ListDir(p, name)
return c.ListDir(name)
}
// GetHandle implements ClientDriverExtentionFileTransfer
func (c *Connection) GetHandle(name string, flags int, offset int64) (ftpserver.FileTransfer, error) {
c.UpdateLastActivity()
p, err := c.Fs.ResolvePath(name)
fs, p, err := c.GetFsAndResolvedPath(name)
if err != nil {
return nil, c.GetFsError(err)
return nil, err
}
if c.GetCommand() == "COMB" && !vfs.IsLocalOsFs(c.Fs) {
if c.GetCommand() == "COMB" && !vfs.IsLocalOsFs(fs) {
return nil, errCOMBNotSupported
}
if flags&os.O_WRONLY != 0 {
return c.uploadFile(p, name, flags)
return c.uploadFile(fs, p, name, flags)
}
return c.downloadFile(p, name, offset)
return c.downloadFile(fs, p, name, offset)
}
func (c *Connection) downloadFile(fsPath, ftpPath string, offset int64) (ftpserver.FileTransfer, error) {
func (c *Connection) downloadFile(fs vfs.Fs, fsPath, ftpPath string, offset int64) (ftpserver.FileTransfer, error) {
if !c.User.HasPerm(dataprovider.PermDownload, path.Dir(ftpPath)) {
return nil, c.GetPermissionDeniedError()
}
@ -345,41 +301,41 @@ func (c *Connection) downloadFile(fsPath, ftpPath string, offset int64) (ftpserv
return nil, c.GetPermissionDeniedError()
}
file, r, cancelFn, err := c.Fs.Open(fsPath, offset)
file, r, cancelFn, err := fs.Open(fsPath, offset)
if err != nil {
c.Log(logger.LevelWarn, "could not open file %#v for reading: %+v", fsPath, err)
return nil, c.GetFsError(err)
return nil, c.GetFsError(fs, err)
}
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, ftpPath, common.TransferDownload,
0, 0, 0, false, c.Fs)
0, 0, 0, false, fs)
t := newTransfer(baseTransfer, nil, r, offset)
return t, nil
}
func (c *Connection) uploadFile(fsPath, ftpPath string, flags int) (ftpserver.FileTransfer, error) {
func (c *Connection) uploadFile(fs vfs.Fs, fsPath, ftpPath string, flags int) (ftpserver.FileTransfer, error) {
if !c.User.IsFileAllowed(ftpPath) {
c.Log(logger.LevelWarn, "writing file %#v is not allowed", ftpPath)
return nil, c.GetPermissionDeniedError()
}
filePath := fsPath
if common.Config.IsAtomicUploadEnabled() && c.Fs.IsAtomicUploadSupported() {
filePath = c.Fs.GetAtomicUploadPath(fsPath)
if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
filePath = fs.GetAtomicUploadPath(fsPath)
}
stat, statErr := c.Fs.Lstat(fsPath)
if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || c.Fs.IsNotExist(statErr) {
stat, statErr := fs.Lstat(fsPath)
if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) {
if !c.User.HasPerm(dataprovider.PermUpload, path.Dir(ftpPath)) {
return nil, c.GetPermissionDeniedError()
}
return c.handleFTPUploadToNewFile(fsPath, filePath, ftpPath)
return c.handleFTPUploadToNewFile(fs, fsPath, filePath, ftpPath)
}
if statErr != nil {
c.Log(logger.LevelError, "error performing file stat %#v: %+v", fsPath, statErr)
return nil, c.GetFsError(statErr)
return nil, c.GetFsError(fs, statErr)
}
// This happen if we upload a file that has the same name of an existing directory
@ -392,34 +348,34 @@ func (c *Connection) uploadFile(fsPath, ftpPath string, flags int) (ftpserver.Fi
return nil, c.GetPermissionDeniedError()
}
return c.handleFTPUploadToExistingFile(flags, fsPath, filePath, stat.Size(), ftpPath)
return c.handleFTPUploadToExistingFile(fs, flags, fsPath, filePath, stat.Size(), ftpPath)
}
func (c *Connection) handleFTPUploadToNewFile(resolvedPath, filePath, requestPath string) (ftpserver.FileTransfer, error) {
func (c *Connection) handleFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, requestPath string) (ftpserver.FileTransfer, error) {
quotaResult := c.HasSpace(true, false, requestPath)
if !quotaResult.HasSpace {
c.Log(logger.LevelInfo, "denying file write due to quota limits")
return nil, common.ErrQuotaExceeded
}
file, w, cancelFn, err := c.Fs.Create(filePath, 0)
file, w, cancelFn, err := fs.Create(filePath, 0)
if err != nil {
c.Log(logger.LevelWarn, "error creating file %#v: %+v", resolvedPath, err)
return nil, c.GetFsError(err)
return nil, c.GetFsError(fs, err)
}
vfs.SetPathPermissions(c.Fs, filePath, c.User.GetUID(), c.User.GetGID())
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
// we can get an error only for resume
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0)
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath,
common.TransferUpload, 0, 0, maxWriteSize, true, c.Fs)
common.TransferUpload, 0, 0, maxWriteSize, true, fs)
t := newTransfer(baseTransfer, w, nil, 0)
return t, nil
}
func (c *Connection) handleFTPUploadToExistingFile(flags int, resolvedPath, filePath string, fileSize int64,
func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolvedPath, filePath string, fileSize int64,
requestPath string) (ftpserver.FileTransfer, error) {
var err error
quotaResult := c.HasSpace(false, false, requestPath)
@ -436,25 +392,25 @@ func (c *Connection) handleFTPUploadToExistingFile(flags int, resolvedPath, file
isResume := flags&os.O_TRUNC == 0
// if there is a size limit remaining size cannot be 0 here, since quotaResult.HasSpace
// will return false in this case and we deny the upload before
maxWriteSize, err := c.GetMaxWriteSize(quotaResult, isResume, fileSize)
maxWriteSize, err := c.GetMaxWriteSize(quotaResult, isResume, fileSize, fs.IsUploadResumeSupported())
if err != nil {
c.Log(logger.LevelDebug, "unable to get max write size: %v", err)
return nil, err
}
if common.Config.IsAtomicUploadEnabled() && c.Fs.IsAtomicUploadSupported() {
err = c.Fs.Rename(resolvedPath, filePath)
if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
err = fs.Rename(resolvedPath, filePath)
if err != nil {
c.Log(logger.LevelWarn, "error renaming existing file for atomic upload, source: %#v, dest: %#v, err: %+v",
resolvedPath, filePath, err)
return nil, c.GetFsError(err)
return nil, c.GetFsError(fs, err)
}
}
file, w, cancelFn, err := c.Fs.Create(filePath, flags)
file, w, cancelFn, err := fs.Create(filePath, flags)
if err != nil {
c.Log(logger.LevelWarn, "error opening existing file, flags: %v, source: %#v, err: %+v", flags, filePath, err)
return nil, c.GetFsError(err)
return nil, c.GetFsError(fs, err)
}
initialSize := int64(0)
@ -462,12 +418,12 @@ func (c *Connection) handleFTPUploadToExistingFile(flags int, resolvedPath, file
c.Log(logger.LevelDebug, "upload resume requested, file path: %#v initial size: %v", filePath, fileSize)
minWriteOffset = fileSize
initialSize = fileSize
if vfs.IsSFTPFs(c.Fs) {
if vfs.IsSFTPFs(fs) {
// we need this since we don't allow resume with wrong offset, we should fix this in pkg/sftp
file.Seek(initialSize, io.SeekStart) //nolint:errcheck // for sftp seek cannot file, it simply set the offset
}
} else {
if vfs.IsLocalOrSFTPFs(c.Fs) {
if vfs.IsLocalOrSFTPFs(fs) {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
@ -482,10 +438,10 @@ func (c *Connection) handleFTPUploadToExistingFile(flags int, resolvedPath, file
}
}
vfs.SetPathPermissions(c.Fs, filePath, c.User.GetUID(), c.User.GetGID())
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath,
common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, c.Fs)
common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, fs)
t := newTransfer(baseTransfer, w, nil, 0)
return t, nil

View file

@ -351,7 +351,7 @@ func (fs MockOsFs) Rename(source, target string) error {
func newMockOsFs(err, statErr error, atomicUpload bool, connectionID, rootDir string) vfs.Fs {
return &MockOsFs{
Fs: vfs.NewOsFs(connectionID, rootDir, nil),
Fs: vfs.NewOsFs(connectionID, rootDir, ""),
err: err,
statErr: statErr,
isAtomicUploadSupported: atomicUpload,
@ -492,7 +492,7 @@ func TestClientVersion(t *testing.T) {
connID := fmt.Sprintf("2_%v", mockCC.ID())
user := dataprovider.User{}
connection := &Connection{
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, nil),
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user),
clientContext: mockCC,
}
common.Connections.Add(connection)
@ -509,7 +509,7 @@ func TestDriverMethodsNotImplemented(t *testing.T) {
connID := fmt.Sprintf("2_%v", mockCC.ID())
user := dataprovider.User{}
connection := &Connection{
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, nil),
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user),
clientContext: mockCC,
}
_, err := connection.Create("")
@ -533,9 +533,8 @@ func TestResolvePathErrors(t *testing.T) {
user.Permissions["/"] = []string{dataprovider.PermAny}
mockCC := mockFTPClientContext{}
connID := fmt.Sprintf("%v", mockCC.ID())
fs := vfs.NewOsFs(connID, user.HomeDir, nil)
connection := &Connection{
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, fs),
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user),
clientContext: mockCC,
}
err := connection.Mkdir("", os.ModePerm)
@ -596,9 +595,9 @@ func TestUploadFileStatError(t *testing.T) {
user.Permissions["/"] = []string{dataprovider.PermAny}
mockCC := mockFTPClientContext{}
connID := fmt.Sprintf("%v", mockCC.ID())
fs := vfs.NewOsFs(connID, user.HomeDir, nil)
fs := vfs.NewOsFs(connID, user.HomeDir, "")
connection := &Connection{
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, fs),
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user),
clientContext: mockCC,
}
testFile := filepath.Join(user.HomeDir, "test", "testfile")
@ -608,7 +607,7 @@ func TestUploadFileStatError(t *testing.T) {
assert.NoError(t, err)
err = os.Chmod(filepath.Dir(testFile), 0001)
assert.NoError(t, err)
_, err = connection.uploadFile(testFile, "test", 0)
_, err = connection.uploadFile(fs, testFile, "test", 0)
assert.Error(t, err)
err = os.Chmod(filepath.Dir(testFile), os.ModePerm)
assert.NoError(t, err)
@ -625,9 +624,8 @@ func TestAVBLErrors(t *testing.T) {
user.Permissions["/"] = []string{dataprovider.PermAny}
mockCC := mockFTPClientContext{}
connID := fmt.Sprintf("%v", mockCC.ID())
fs := newMockOsFs(nil, nil, false, connID, user.GetHomeDir())
connection := &Connection{
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, fs),
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user),
clientContext: mockCC,
}
_, err := connection.GetAvailableSpace("/")
@ -648,12 +646,12 @@ func TestUploadOverwriteErrors(t *testing.T) {
connID := fmt.Sprintf("%v", mockCC.ID())
fs := newMockOsFs(nil, nil, false, connID, user.GetHomeDir())
connection := &Connection{
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, fs),
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user),
clientContext: mockCC,
}
flags := 0
flags |= os.O_APPEND
_, err := connection.handleFTPUploadToExistingFile(flags, "", "", 0, "")
_, err := connection.handleFTPUploadToExistingFile(fs, flags, "", "", 0, "")
if assert.Error(t, err) {
assert.EqualError(t, err, common.ErrOpUnsupported.Error())
}
@ -665,7 +663,7 @@ func TestUploadOverwriteErrors(t *testing.T) {
flags = 0
flags |= os.O_CREATE
flags |= os.O_TRUNC
tr, err := connection.handleFTPUploadToExistingFile(flags, f.Name(), f.Name(), 123, f.Name())
tr, err := connection.handleFTPUploadToExistingFile(fs, flags, f.Name(), f.Name(), 123, f.Name())
if assert.NoError(t, err) {
transfer := tr.(*transfer)
transfers := connection.GetTransfers()
@ -680,11 +678,11 @@ func TestUploadOverwriteErrors(t *testing.T) {
err = os.Remove(f.Name())
assert.NoError(t, err)
_, err = connection.handleFTPUploadToExistingFile(os.O_TRUNC, filepath.Join(os.TempDir(), "sub", "file"),
_, err = connection.handleFTPUploadToExistingFile(fs, os.O_TRUNC, filepath.Join(os.TempDir(), "sub", "file"),
filepath.Join(os.TempDir(), "sub", "file1"), 0, "/sub/file1")
assert.Error(t, err)
connection.Fs = vfs.NewOsFs(connID, user.GetHomeDir(), nil)
_, err = connection.handleFTPUploadToExistingFile(0, "missing1", "missing2", 0, "missing")
fs = vfs.NewOsFs(connID, user.GetHomeDir(), "")
_, err = connection.handleFTPUploadToExistingFile(fs, 0, "missing1", "missing2", 0, "missing")
assert.Error(t, err)
}
@ -702,7 +700,7 @@ func TestTransferErrors(t *testing.T) {
connID := fmt.Sprintf("%v", mockCC.ID())
fs := newMockOsFs(nil, nil, false, connID, user.GetHomeDir())
connection := &Connection{
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, fs),
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user),
clientContext: mockCC,
}
baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, file.Name(), testfile, common.TransferDownload,

View file

@ -145,7 +145,7 @@ func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
connID := fmt.Sprintf("%v_%v", s.ID, cc.ID())
user := dataprovider.User{}
connection := &Connection{
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user, nil),
BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, user),
clientContext: cc,
}
common.Connections.Add(connection)
@ -180,7 +180,6 @@ func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string)
if err != nil {
return nil, err
}
connection.Fs.CheckRootPath(connection.GetUsername(), user.GetUID(), user.GetGID())
connection.Log(logger.LevelInfo, "User id: %d, logged in with FTP, username: %#v, home_dir: %#v remote addr: %#v",
user.ID, user.Username, user.HomeDir, ipAddr)
dataprovider.UpdateLastLogin(&user) //nolint:errcheck
@ -219,7 +218,6 @@ func (s *Server) VerifyConnection(cc ftpserver.ClientContext, user string, tlsCo
if err != nil {
return nil, err
}
connection.Fs.CheckRootPath(connection.GetUsername(), dbUser.GetUID(), dbUser.GetGID())
connection.Log(logger.LevelInfo, "User id: %d, logged in with FTP using a TLS certificate, username: %#v, home_dir: %#v remote addr: %#v",
dbUser.ID, dbUser.Username, dbUser.HomeDir, ipAddr)
dataprovider.UpdateLastLogin(&dbUser) //nolint:errcheck
@ -295,11 +293,11 @@ func (s *Server) validateUser(user dataprovider.User, cc ftpserver.ClientContext
}
if utils.IsStringInSlice(common.ProtocolFTP, user.Filters.DeniedProtocols) {
logger.Debug(logSender, connectionID, "cannot login user %#v, protocol FTP is not allowed", user.Username)
return nil, fmt.Errorf("Protocol FTP is not allowed for user %#v", user.Username)
return nil, fmt.Errorf("protocol FTP is not allowed for user %#v", user.Username)
}
if !user.IsLoginMethodAllowed(loginMethod, nil) {
logger.Debug(logSender, connectionID, "cannot login user %#v, %v login method is not allowed", user.Username, loginMethod)
return nil, fmt.Errorf("Login method %v is not allowed for user %#v", loginMethod, user.Username)
return nil, fmt.Errorf("login method %v is not allowed for user %#v", loginMethod, user.Username)
}
if user.MaxSessions > 0 {
activeSessions := common.Connections.GetActiveSessions(user.Username)
@ -309,27 +307,26 @@ func (s *Server) validateUser(user dataprovider.User, cc ftpserver.ClientContext
return nil, fmt.Errorf("too many open sessions: %v", activeSessions)
}
}
if dataprovider.GetQuotaTracking() > 0 && user.HasOverlappedMappedPaths() {
logger.Debug(logSender, connectionID, "cannot login user %#v, overlapping mapped folders are allowed only with quota tracking disabled",
user.Username)
return nil, errors.New("overlapping mapped folders are allowed only with quota tracking disabled")
}
remoteAddr := cc.RemoteAddr().String()
if !user.IsLoginFromAddrAllowed(remoteAddr) {
logger.Debug(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v", user.Username, remoteAddr)
return nil, fmt.Errorf("Login for user %#v is not allowed from this address: %v", user.Username, remoteAddr)
return nil, fmt.Errorf("login for user %#v is not allowed from this address: %v", user.Username, remoteAddr)
}
fs, err := user.GetFilesystem(connectionID)
err := user.CheckFsRoot(connectionID)
if err != nil {
errClose := user.CloseFs()
logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose)
return nil, err
}
connection := &Connection{
BaseConnection: common.NewBaseConnection(fmt.Sprintf("%v_%v", s.ID, cc.ID()), common.ProtocolFTP, user, fs),
BaseConnection: common.NewBaseConnection(fmt.Sprintf("%v_%v", s.ID, cc.ID()), common.ProtocolFTP, user),
clientContext: cc,
}
err = common.Connections.Swap(connection)
if err != nil {
return nil, errors.New("Internal authentication error")
err = user.CloseFs()
logger.Warn(logSender, connectionID, "unable to swap connection, close fs error: %v", err)
return nil, errors.New("internal authentication error")
}
return connection, nil
}

View file

@ -102,7 +102,7 @@ func (t *transfer) Close() error {
if errBaseClose != nil {
err = errBaseClose
}
return t.Connection.GetFsError(err)
return t.Connection.GetFsError(t.Fs, err)
}
func (t *transfer) closeIO() error {

View file

@ -89,11 +89,11 @@ func updateAdmin(w http.ResponseWriter, r *http.Request) {
}
if username == claims.Username {
if claims.isCriticalPermRemoved(admin.Permissions) {
sendAPIResponse(w, r, errors.New("You cannot remove these permissions to yourself"), "", http.StatusBadRequest)
sendAPIResponse(w, r, errors.New("you cannot remove these permissions to yourself"), "", http.StatusBadRequest)
return
}
if admin.Status == 0 {
sendAPIResponse(w, r, errors.New("You cannot disable yourself"), "", http.StatusBadRequest)
sendAPIResponse(w, r, errors.New("you cannot disable yourself"), "", http.StatusBadRequest)
return
}
}
@ -114,7 +114,7 @@ func deleteAdmin(w http.ResponseWriter, r *http.Request) {
return
}
if username == claims.Username {
sendAPIResponse(w, r, errors.New("You cannot delete yourself"), "", http.StatusBadRequest)
sendAPIResponse(w, r, errors.New("you cannot delete yourself"), "", http.StatusBadRequest)
return
}

View file

@ -50,7 +50,20 @@ func updateFolder(w http.ResponseWriter, r *http.Request) {
sendAPIResponse(w, r, err, "", getRespStatus(err))
return
}
users := folder.Users
folderID := folder.ID
currentS3AccessSecret := folder.FsConfig.S3Config.AccessSecret
currentAzAccountKey := folder.FsConfig.AzBlobConfig.AccountKey
currentGCSCredentials := folder.FsConfig.GCSConfig.Credentials
currentCryptoPassphrase := folder.FsConfig.CryptConfig.Passphrase
currentSFTPPassword := folder.FsConfig.SFTPConfig.Password
currentSFTPKey := folder.FsConfig.SFTPConfig.PrivateKey
folder.FsConfig.S3Config = vfs.S3FsConfig{}
folder.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{}
folder.FsConfig.GCSConfig = vfs.GCSFsConfig{}
folder.FsConfig.CryptConfig = vfs.CryptFsConfig{}
folder.FsConfig.SFTPConfig = vfs.SFTPFsConfig{}
err = render.DecodeJSON(r.Body, &folder)
if err != nil {
sendAPIResponse(w, r, err, "", http.StatusBadRequest)
@ -58,7 +71,10 @@ func updateFolder(w http.ResponseWriter, r *http.Request) {
}
folder.ID = folderID
folder.Name = name
err = dataprovider.UpdateFolder(&folder)
folder.FsConfig.SetEmptySecretsIfNil()
updateEncryptedSecrets(&folder.FsConfig, currentS3AccessSecret, currentAzAccountKey, currentGCSCredentials,
currentCryptoPassphrase, currentSFTPPassword, currentSFTPKey)
err = dataprovider.UpdateFolder(&folder, users)
if err != nil {
sendAPIResponse(w, r, err, "", getRespStatus(err))
return
@ -72,6 +88,7 @@ func renderFolder(w http.ResponseWriter, r *http.Request, name string, status in
sendAPIResponse(w, r, err, "", getRespStatus(err))
return
}
folder.HideConfidentialData()
if status != http.StatusOK {
ctx := context.WithValue(r.Context(), render.StatusCtxKey, status)
render.JSON(w, r.WithContext(ctx), folder)

View file

@ -21,13 +21,13 @@ import (
func validateBackupFile(outputFile string) (string, error) {
if outputFile == "" {
return "", errors.New("Invalid or missing output-file")
return "", errors.New("invalid or missing output-file")
}
if filepath.IsAbs(outputFile) {
return "", fmt.Errorf("Invalid output-file %#v: it must be a relative path", outputFile)
return "", fmt.Errorf("invalid output-file %#v: it must be a relative path", outputFile)
}
if strings.Contains(outputFile, "..") {
return "", fmt.Errorf("Invalid output-file %#v", outputFile)
return "", fmt.Errorf("invalid output-file %#v", outputFile)
}
outputFile = filepath.Join(backupsPath, outputFile)
return outputFile, nil
@ -122,7 +122,7 @@ func loadData(w http.ResponseWriter, r *http.Request) {
return
}
if !filepath.IsAbs(inputFile) {
sendAPIResponse(w, r, fmt.Errorf("Invalid input_file %#v: it must be an absolute path", inputFile), "", http.StatusBadRequest)
sendAPIResponse(w, r, fmt.Errorf("invalid input_file %#v: it must be an absolute path", inputFile), "", http.StatusBadRequest)
return
}
fi, err := os.Stat(inputFile)
@ -207,7 +207,7 @@ func RestoreFolders(folders []vfs.BaseVirtualFolder, inputFile string, mode, sca
continue
}
folder.ID = f.ID
err = dataprovider.UpdateFolder(&folder)
err = dataprovider.UpdateFolder(&folder, f.Users)
logger.Debug(logSender, "", "restoring existing folder: %+v, dump file: %#v, error: %v", folder, inputFile, err)
} else {
folder.Users = nil

View file

@ -34,7 +34,7 @@ func updateUserQuotaUsage(w http.ResponseWriter, r *http.Request) {
return
}
if u.UsedQuotaFiles < 0 || u.UsedQuotaSize < 0 {
sendAPIResponse(w, r, errors.New("Invalid used quota parameters, negative values are not allowed"),
sendAPIResponse(w, r, errors.New("invalid used quota parameters, negative values are not allowed"),
"", http.StatusBadRequest)
return
}
@ -75,7 +75,7 @@ func updateVFolderQuotaUsage(w http.ResponseWriter, r *http.Request) {
return
}
if f.UsedQuotaFiles < 0 || f.UsedQuotaSize < 0 {
sendAPIResponse(w, r, errors.New("Invalid used quota parameters, negative values are not allowed"),
sendAPIResponse(w, r, errors.New("invalid used quota parameters, negative values are not allowed"),
"", http.StatusBadRequest)
return
}
@ -154,28 +154,25 @@ func startVFolderQuotaScan(w http.ResponseWriter, r *http.Request) {
func doQuotaScan(user dataprovider.User) error {
defer common.QuotaScans.RemoveUserQuotaScan(user.Username)
fs, err := user.GetFilesystem("")
numFiles, size, err := user.ScanQuota()
if err != nil {
logger.Warn(logSender, "", "unable scan quota for user %#v error creating filesystem: %v", user.Username, err)
return err
}
defer fs.Close()
numFiles, size, err := fs.ScanRootDirContents()
if err != nil {
logger.Warn(logSender, "", "error scanning user home dir %#v: %v", user.Username, err)
logger.Warn(logSender, "", "error scanning user quota %#v: %v", user.Username, err)
return err
}
err = dataprovider.UpdateUserQuota(&user, numFiles, size, true)
logger.Debug(logSender, "", "user home dir scanned, user: %#v, error: %v", user.Username, err)
logger.Debug(logSender, "", "user quota scanned, user: %#v, error: %v", user.Username, err)
return err
}
func doFolderQuotaScan(folder vfs.BaseVirtualFolder) error {
defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name)
fs := vfs.NewOsFs("", "", nil).(*vfs.OsFs)
numFiles, size, err := fs.GetDirSize(folder.MappedPath)
f := vfs.VirtualFolder{
BaseVirtualFolder: folder,
VirtualPath: "/",
}
numFiles, size, err := f.ScanQuota()
if err != nil {
logger.Warn(logSender, "", "error scanning folder %#v: %v", folder.MappedPath, err)
logger.Warn(logSender, "", "error scanning folder %#v: %v", folder.Name, err)
return err
}
err = dataprovider.UpdateVirtualFolderQuota(&folder, numFiles, size, true)
@ -188,7 +185,7 @@ func getQuotaUpdateMode(r *http.Request) (string, error) {
if _, ok := r.URL.Query()["mode"]; ok {
mode = r.URL.Query().Get("mode")
if mode != quotaUpdateModeReset && mode != quotaUpdateModeAdd {
return "", errors.New("Invalid mode")
return "", errors.New("invalid mode")
}
}
return mode, nil

View file

@ -59,27 +59,27 @@ func addUser(w http.ResponseWriter, r *http.Request) {
}
user.SetEmptySecretsIfNil()
switch user.FsConfig.Provider {
case dataprovider.S3FilesystemProvider:
case vfs.S3FilesystemProvider:
if user.FsConfig.S3Config.AccessSecret.IsRedacted() {
sendAPIResponse(w, r, errors.New("invalid access_secret"), "", http.StatusBadRequest)
return
}
case dataprovider.GCSFilesystemProvider:
case vfs.GCSFilesystemProvider:
if user.FsConfig.GCSConfig.Credentials.IsRedacted() {
sendAPIResponse(w, r, errors.New("invalid credentials"), "", http.StatusBadRequest)
return
}
case dataprovider.AzureBlobFilesystemProvider:
case vfs.AzureBlobFilesystemProvider:
if user.FsConfig.AzBlobConfig.AccountKey.IsRedacted() {
sendAPIResponse(w, r, errors.New("invalid account_key"), "", http.StatusBadRequest)
return
}
case dataprovider.CryptedFilesystemProvider:
case vfs.CryptedFilesystemProvider:
if user.FsConfig.CryptConfig.Passphrase.IsRedacted() {
sendAPIResponse(w, r, errors.New("invalid passphrase"), "", http.StatusBadRequest)
return
}
case dataprovider.SFTPFilesystemProvider:
case vfs.SFTPFilesystemProvider:
if user.FsConfig.SFTPConfig.Password.IsRedacted() {
sendAPIResponse(w, r, errors.New("invalid SFTP password"), "", http.StatusBadRequest)
return
@ -131,6 +131,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
user.FsConfig.GCSConfig = vfs.GCSFsConfig{}
user.FsConfig.CryptConfig = vfs.CryptFsConfig{}
user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{}
user.VirtualFolders = nil
err = render.DecodeJSON(r.Body, &user)
if err != nil {
sendAPIResponse(w, r, err, "", http.StatusBadRequest)
@ -143,7 +144,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
if len(user.Permissions) == 0 {
user.Permissions = currentPermissions
}
updateEncryptedSecrets(&user, currentS3AccessSecret, currentAzAccountKey, currentGCSCredentials, currentCryptoPassphrase,
updateEncryptedSecrets(&user.FsConfig, currentS3AccessSecret, currentAzAccountKey, currentGCSCredentials, currentCryptoPassphrase,
currentSFTPPassword, currentSFTPKey)
err = dataprovider.UpdateUser(&user)
if err != nil {
@ -175,32 +176,32 @@ func disconnectUser(username string) {
}
}
func updateEncryptedSecrets(user *dataprovider.User, currentS3AccessSecret, currentAzAccountKey,
func updateEncryptedSecrets(fsConfig *vfs.Filesystem, currentS3AccessSecret, currentAzAccountKey,
currentGCSCredentials, currentCryptoPassphrase, currentSFTPPassword, currentSFTPKey *kms.Secret) {
// we use the new access secret if plain or empty, otherwise the old value
switch user.FsConfig.Provider {
case dataprovider.S3FilesystemProvider:
if user.FsConfig.S3Config.AccessSecret.IsNotPlainAndNotEmpty() {
user.FsConfig.S3Config.AccessSecret = currentS3AccessSecret
switch fsConfig.Provider {
case vfs.S3FilesystemProvider:
if fsConfig.S3Config.AccessSecret.IsNotPlainAndNotEmpty() {
fsConfig.S3Config.AccessSecret = currentS3AccessSecret
}
case dataprovider.AzureBlobFilesystemProvider:
if user.FsConfig.AzBlobConfig.AccountKey.IsNotPlainAndNotEmpty() {
user.FsConfig.AzBlobConfig.AccountKey = currentAzAccountKey
case vfs.AzureBlobFilesystemProvider:
if fsConfig.AzBlobConfig.AccountKey.IsNotPlainAndNotEmpty() {
fsConfig.AzBlobConfig.AccountKey = currentAzAccountKey
}
case dataprovider.GCSFilesystemProvider:
if user.FsConfig.GCSConfig.Credentials.IsNotPlainAndNotEmpty() {
user.FsConfig.GCSConfig.Credentials = currentGCSCredentials
case vfs.GCSFilesystemProvider:
if fsConfig.GCSConfig.Credentials.IsNotPlainAndNotEmpty() {
fsConfig.GCSConfig.Credentials = currentGCSCredentials
}
case dataprovider.CryptedFilesystemProvider:
if user.FsConfig.CryptConfig.Passphrase.IsNotPlainAndNotEmpty() {
user.FsConfig.CryptConfig.Passphrase = currentCryptoPassphrase
case vfs.CryptedFilesystemProvider:
if fsConfig.CryptConfig.Passphrase.IsNotPlainAndNotEmpty() {
fsConfig.CryptConfig.Passphrase = currentCryptoPassphrase
}
case dataprovider.SFTPFilesystemProvider:
if user.FsConfig.SFTPConfig.Password.IsNotPlainAndNotEmpty() {
user.FsConfig.SFTPConfig.Password = currentSFTPPassword
case vfs.SFTPFilesystemProvider:
if fsConfig.SFTPConfig.Password.IsNotPlainAndNotEmpty() {
fsConfig.SFTPConfig.Password = currentSFTPPassword
}
if user.FsConfig.SFTPConfig.PrivateKey.IsNotPlainAndNotEmpty() {
user.FsConfig.SFTPConfig.PrivateKey = currentSFTPKey
if fsConfig.SFTPConfig.PrivateKey.IsNotPlainAndNotEmpty() {
fsConfig.SFTPConfig.PrivateKey = currentSFTPKey
}
}
}

View file

@ -63,7 +63,7 @@ func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string,
if _, ok := r.URL.Query()["limit"]; ok {
limit, err = strconv.Atoi(r.URL.Query().Get("limit"))
if err != nil {
err = errors.New("Invalid limit")
err = errors.New("invalid limit")
sendAPIResponse(w, r, err, "", http.StatusBadRequest)
return limit, offset, order, err
}
@ -74,7 +74,7 @@ func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string,
if _, ok := r.URL.Query()["offset"]; ok {
offset, err = strconv.Atoi(r.URL.Query().Get("offset"))
if err != nil {
err = errors.New("Invalid offset")
err = errors.New("invalid offset")
sendAPIResponse(w, r, err, "", http.StatusBadRequest)
return limit, offset, order, err
}
@ -82,7 +82,7 @@ func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string,
if _, ok := r.URL.Query()["order"]; ok {
order = r.URL.Query().Get("order")
if order != dataprovider.OrderASC && order != dataprovider.OrderDESC {
err = errors.New("Invalid order")
err = errors.New("invalid order")
sendAPIResponse(w, r, err, "", http.StatusBadRequest)
return limit, offset, order, err
}

View file

@ -212,12 +212,12 @@ func verifyCSRFToken(tokenString string) error {
token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString)
if err != nil || token == nil {
logger.Debug(logSender, "", "error validating CSRF: %v", err)
return fmt.Errorf("Unable to verify form token: %v", err)
return fmt.Errorf("unable to verify form token: %v", err)
}
if !utils.IsStringInSlice(tokenAudienceCSRF, token.Audience()) {
logger.Debug(logSender, "", "error validating CSRF token audience")
return errors.New("The form token is not valid")
return errors.New("the form token is not valid")
}
return nil

View file

@ -190,10 +190,10 @@ func (c *Conf) Initialize(configDir string) error {
templatesPath := getConfigPath(c.TemplatesPath, configDir)
enableWebAdmin := staticFilesPath != "" || templatesPath != ""
if backupsPath == "" {
return fmt.Errorf("Required directory is invalid, backup path %#v", backupsPath)
return fmt.Errorf("required directory is invalid, backup path %#v", backupsPath)
}
if enableWebAdmin && (staticFilesPath == "" || templatesPath == "") {
return fmt.Errorf("Required directory is invalid, static file path: %#v template path: %#v",
return fmt.Errorf("required directory is invalid, static file path: %#v template path: %#v",
staticFilesPath, templatesPath)
}
certificateFile := getConfigPath(c.CertificateFile, configDir)

File diff suppressed because it is too large Load diff

View file

@ -299,7 +299,7 @@ func TestGCSWebInvalidFormFile(t *testing.T) {
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
err := req.ParseForm()
assert.NoError(t, err)
_, err = getFsConfigFromUserPostFields(req)
_, err = getFsConfigFromPostFields(req)
assert.EqualError(t, err, http.ErrNotMultipart.Error())
}
@ -373,7 +373,7 @@ func TestCSRFToken(t *testing.T) {
// invalid token
err := verifyCSRFToken("token")
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "Unable to verify form token")
assert.Contains(t, err.Error(), "unable to verify form token")
}
// bad audience
claims := make(map[string]interface{})
@ -403,7 +403,7 @@ func TestCSRFToken(t *testing.T) {
rr = httptest.NewRecorder()
fn.ServeHTTP(rr, req)
assert.Equal(t, http.StatusForbidden, rr.Code)
assert.Contains(t, rr.Body.String(), "The token is not valid")
assert.Contains(t, rr.Body.String(), "the token is not valid")
csrfTokenAuth = jwtauth.New("PS256", utils.GenerateRandomBytes(32), nil)
tokenString = createCSRFToken()
@ -682,8 +682,8 @@ func TestQuotaScanInvalidFs(t *testing.T) {
user := dataprovider.User{
Username: "test",
HomeDir: os.TempDir(),
FsConfig: dataprovider.Filesystem{
Provider: dataprovider.S3FilesystemProvider,
FsConfig: vfs.Filesystem{
Provider: vfs.S3FilesystemProvider,
},
}
common.QuotaScans.AddUserQuotaScan(user.Username)
@ -754,6 +754,44 @@ func TestVerifyTLSConnection(t *testing.T) {
certMgr = oldCertMgr
}
func TestGetFolderFromTemplate(t *testing.T) {
folder := vfs.BaseVirtualFolder{
MappedPath: "Folder%name%",
Description: "Folder %name% desc",
}
folderName := "folderTemplate"
folderTemplate := getFolderFromTemplate(folder, folderName)
require.Equal(t, folderName, folderTemplate.Name)
require.Equal(t, fmt.Sprintf("Folder%v", folderName), folderTemplate.MappedPath)
require.Equal(t, fmt.Sprintf("Folder %v desc", folderName), folderTemplate.Description)
folder.FsConfig.Provider = vfs.CryptedFilesystemProvider
folder.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("%name%")
folderTemplate = getFolderFromTemplate(folder, folderName)
require.Equal(t, folderName, folderTemplate.FsConfig.CryptConfig.Passphrase.GetPayload())
folder.FsConfig.Provider = vfs.GCSFilesystemProvider
folder.FsConfig.GCSConfig.KeyPrefix = "prefix%name%/"
folderTemplate = getFolderFromTemplate(folder, folderName)
require.Equal(t, fmt.Sprintf("prefix%v/", folderName), folderTemplate.FsConfig.GCSConfig.KeyPrefix)
folder.FsConfig.Provider = vfs.AzureBlobFilesystemProvider
folder.FsConfig.AzBlobConfig.KeyPrefix = "a%name%"
folder.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("pwd%name%")
folderTemplate = getFolderFromTemplate(folder, folderName)
require.Equal(t, "a"+folderName, folderTemplate.FsConfig.AzBlobConfig.KeyPrefix)
require.Equal(t, "pwd"+folderName, folderTemplate.FsConfig.AzBlobConfig.AccountKey.GetPayload())
folder.FsConfig.Provider = vfs.SFTPFilesystemProvider
folder.FsConfig.SFTPConfig.Prefix = "%name%"
folder.FsConfig.SFTPConfig.Username = "sftp_%name%"
folder.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("sftp%name%")
folderTemplate = getFolderFromTemplate(folder, folderName)
require.Equal(t, folderName, folderTemplate.FsConfig.SFTPConfig.Prefix)
require.Equal(t, "sftp_"+folderName, folderTemplate.FsConfig.SFTPConfig.Username)
require.Equal(t, "sftp"+folderName, folderTemplate.FsConfig.SFTPConfig.Password.GetPayload())
}
func TestGetUserFromTemplate(t *testing.T) {
user := dataprovider.User{
Status: 1,
@ -775,24 +813,24 @@ func TestGetUserFromTemplate(t *testing.T) {
require.Len(t, userTemplate.VirtualFolders, 1)
require.Equal(t, "Folder"+username, userTemplate.VirtualFolders[0].Name)
user.FsConfig.Provider = dataprovider.CryptedFilesystemProvider
user.FsConfig.Provider = vfs.CryptedFilesystemProvider
user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("%password%")
userTemplate = getUserFromTemplate(user, templateFields)
require.Equal(t, password, userTemplate.FsConfig.CryptConfig.Passphrase.GetPayload())
user.FsConfig.Provider = dataprovider.GCSFilesystemProvider
user.FsConfig.Provider = vfs.GCSFilesystemProvider
user.FsConfig.GCSConfig.KeyPrefix = "%username%%password%"
userTemplate = getUserFromTemplate(user, templateFields)
require.Equal(t, username+password, userTemplate.FsConfig.GCSConfig.KeyPrefix)
user.FsConfig.Provider = dataprovider.AzureBlobFilesystemProvider
user.FsConfig.Provider = vfs.AzureBlobFilesystemProvider
user.FsConfig.AzBlobConfig.KeyPrefix = "a%username%"
user.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("pwd%password%%username%")
userTemplate = getUserFromTemplate(user, templateFields)
require.Equal(t, "a"+username, userTemplate.FsConfig.AzBlobConfig.KeyPrefix)
require.Equal(t, "pwd"+password+username, userTemplate.FsConfig.AzBlobConfig.AccountKey.GetPayload())
user.FsConfig.Provider = dataprovider.SFTPFilesystemProvider
user.FsConfig.Provider = vfs.SFTPFilesystemProvider
user.FsConfig.SFTPConfig.Prefix = "%username%"
user.FsConfig.SFTPConfig.Username = "sftp_%username%"
user.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("sftp%password%")

View file

@ -134,7 +134,7 @@ func verifyCSRFHeader(next http.Handler) http.Handler {
if !utils.IsStringInSlice(tokenAudienceCSRF, token.Audience()) {
logger.Debug(logSender, "", "error validating CSRF header audience")
sendAPIResponse(w, r, errors.New("The token is not valid"), "", http.StatusForbidden)
sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden)
return
}

View file

@ -12,7 +12,7 @@ tags:
info:
title: SFTPGo
description: SFTPGo REST API
version: 2.0.3
version: 2.0.4
contact:
url: 'https://github.com/drakkan/sftpgo'
license:
@ -483,7 +483,7 @@ paths:
tags:
- quota
summary: Start folder quota scan
description: Starts a new quota scan for the given user. A quota scan update the number of files and their total size for the specified folder
description: Starts a new quota scan for the given folder. A quota scan update the number of files and their total size for the specified folder
operationId: start_folder_quota_scan
requestBody:
required: true
@ -1112,7 +1112,7 @@ paths:
- 0
- 1
description: |
output_data:
output data:
* `0` or any other value != 1, the backup will be saved to a file on the server, `output_file` is required
* `1` the backup will be returned as response body
- in: query
@ -1336,6 +1336,19 @@ components:
- manage_system
- manage_defender
- view_defender
description: |
Admin permissions:
* `*` - all permissions are granted
* `add_users` - add new users is allowed
* `edit_users` - change existing users is allowed
* `del_users` - remove users is allowed
* `view_users` - list users is allowed
* `view_conns` - list active connections is allowed
* `close_conns` - close active connections is allowed
* `view_status` - view the server status is allowed
* `manage_admins` - manage other admins is allowed
* `manage_defender` - remove ip from the dynamic blocklist is allowed
* `view_defender` - list the dynamic blocklist is allowed
LoginMethods:
type: string
enum:
@ -1347,13 +1360,25 @@ components:
- TLSCertificate
- TLSCertificate+password
description: |
To enable multi-step authentication you have to allow only multi-step login methods
Available login methods. To enable multi-step authentication you have to allow only multi-step login methods
* `publickey`
* `password`
* `keyboard-interactive`
* `publickey+password` - multi-step auth: public key and password
* `publickey+keyboard-interactive` - multi-step auth: public key and keyboard interactive
* `TLSCertificate`
* `TLSCertificate+password` - multi-step auth: TLS client certificate and password
SupportedProtocols:
type: string
enum:
- SSH
- FTP
- DAV
description: |
Protocols:
* `SSH` - includes both SFTP and SSH commands
* `FTP` - plain FTP and FTPES/FTPS
* `DAV` - WebDAV over HTTP/HTTPS
PatternsFilter:
type: object
properties:
@ -1654,7 +1679,9 @@ components:
items:
type: string
description: list of usernames associated with this virtual folder
description: defines the path for the virtual folder and the used quota limits. The same folder can be shared among multiple users and each user can have different quota limits or a different virtual path.
filesystem:
$ref: '#/components/schemas/FilesystemConfig'
description: Defines the filesystem for the virtual folder and the used quota limits. The same folder can be shared among multiple users and each user can have different quota limits or a different virtual path.
VirtualFolder:
allOf:
- $ref: '#/components/schemas/BaseVirtualFolder'
@ -1837,6 +1864,10 @@ components:
enum:
- upload
- download
description: |
Operations:
* `upload`
* `download`
path:
type: string
description: file path for the upload/download
@ -1899,9 +1930,9 @@ components:
FolderQuotaScan:
type: object
properties:
mapped_path:
name:
type: string
description: path with an active scan
description: folder name with an active scan
start_time:
type: integer
format: int64

View file

@ -41,6 +41,7 @@ const (
const (
templateBase = "base.html"
templateFsConfig = "fsconfig.html"
templateUsers = "users.html"
templateUser = "user.html"
templateAdmins = "admins.html"
@ -196,6 +197,7 @@ func loadTemplates(templatesPath string) {
}
userPaths := []string{
filepath.Join(templatesPath, templateBase),
filepath.Join(templatesPath, templateFsConfig),
filepath.Join(templatesPath, templateUser),
}
adminsPaths := []string{
@ -224,6 +226,7 @@ func loadTemplates(templatesPath string) {
}
folderPath := []string{
filepath.Join(templatesPath, templateBase),
filepath.Join(templatesPath, templateFsConfig),
filepath.Join(templatesPath, templateFolder),
}
statusPath := []string{
@ -392,6 +395,7 @@ func renderUserPage(w http.ResponseWriter, r *http.Request, user *dataprovider.U
if user.Password != "" && user.IsPasswordHashed() && mode == userPageModeUpdate {
user.Password = redactedSecret
}
user.FsConfig.RedactedSecret = redactedSecret
data := userPage{
basePage: getBasePageData(title, currentURL, r),
Mode: mode,
@ -401,7 +405,6 @@ func renderUserPage(w http.ResponseWriter, r *http.Request, user *dataprovider.U
ValidLoginMethods: dataprovider.ValidLoginMethods,
ValidProtocols: dataprovider.ValidProtocols,
RootDirPerms: user.GetPermissionsForPath("/"),
RedactedSecret: redactedSecret,
}
renderTemplate(w, templateUser, data)
}
@ -419,6 +422,9 @@ func renderFolderPage(w http.ResponseWriter, r *http.Request, folder vfs.BaseVir
title = "Folder template"
currentURL = webTemplateFolder
}
folder.FsConfig.RedactedSecret = redactedSecret
folder.FsConfig.SetEmptySecretsIfNil()
data := folderPage{
basePage: getBasePageData(title, currentURL, r),
Error: error,
@ -753,35 +759,35 @@ func getAzureConfig(r *http.Request) (vfs.AzBlobFsConfig, error) {
return config, err
}
func getFsConfigFromUserPostFields(r *http.Request) (dataprovider.Filesystem, error) {
var fs dataprovider.Filesystem
func getFsConfigFromPostFields(r *http.Request) (vfs.Filesystem, error) {
var fs vfs.Filesystem
provider, err := strconv.Atoi(r.Form.Get("fs_provider"))
if err != nil {
provider = int(dataprovider.LocalFilesystemProvider)
provider = int(vfs.LocalFilesystemProvider)
}
fs.Provider = dataprovider.FilesystemProvider(provider)
fs.Provider = vfs.FilesystemProvider(provider)
switch fs.Provider {
case dataprovider.S3FilesystemProvider:
case vfs.S3FilesystemProvider:
config, err := getS3Config(r)
if err != nil {
return fs, err
}
fs.S3Config = config
case dataprovider.AzureBlobFilesystemProvider:
case vfs.AzureBlobFilesystemProvider:
config, err := getAzureConfig(r)
if err != nil {
return fs, err
}
fs.AzBlobConfig = config
case dataprovider.GCSFilesystemProvider:
case vfs.GCSFilesystemProvider:
config, err := getGCSConfig(r)
if err != nil {
return fs, err
}
fs.GCSConfig = config
case dataprovider.CryptedFilesystemProvider:
case vfs.CryptedFilesystemProvider:
fs.CryptConfig.Passphrase = getSecretFromFormField(r, "crypt_passphrase")
case dataprovider.SFTPFilesystemProvider:
case vfs.SFTPFilesystemProvider:
fs.SFTPConfig = getSFTPConfig(r)
}
return fs, nil
@ -822,6 +828,18 @@ func getFolderFromTemplate(folder vfs.BaseVirtualFolder, name string) vfs.BaseVi
folder.MappedPath = replacePlaceholders(folder.MappedPath, replacements)
folder.Description = replacePlaceholders(folder.Description, replacements)
switch folder.FsConfig.Provider {
case vfs.CryptedFilesystemProvider:
folder.FsConfig.CryptConfig = getCryptFsFromTemplate(folder.FsConfig.CryptConfig, replacements)
case vfs.S3FilesystemProvider:
folder.FsConfig.S3Config = getS3FsFromTemplate(folder.FsConfig.S3Config, replacements)
case vfs.GCSFilesystemProvider:
folder.FsConfig.GCSConfig = getGCSFsFromTemplate(folder.FsConfig.GCSConfig, replacements)
case vfs.AzureBlobFilesystemProvider:
folder.FsConfig.AzBlobConfig = getAzBlobFsFromTemplate(folder.FsConfig.AzBlobConfig, replacements)
case vfs.SFTPFilesystemProvider:
folder.FsConfig.SFTPConfig = getSFTPFsFromTemplate(folder.FsConfig.SFTPConfig, replacements)
}
return folder
}
@ -895,15 +913,15 @@ func getUserFromTemplate(user dataprovider.User, template userTemplateFields) da
user.AdditionalInfo = replacePlaceholders(user.AdditionalInfo, replacements)
switch user.FsConfig.Provider {
case dataprovider.CryptedFilesystemProvider:
case vfs.CryptedFilesystemProvider:
user.FsConfig.CryptConfig = getCryptFsFromTemplate(user.FsConfig.CryptConfig, replacements)
case dataprovider.S3FilesystemProvider:
case vfs.S3FilesystemProvider:
user.FsConfig.S3Config = getS3FsFromTemplate(user.FsConfig.S3Config, replacements)
case dataprovider.GCSFilesystemProvider:
case vfs.GCSFilesystemProvider:
user.FsConfig.GCSConfig = getGCSFsFromTemplate(user.FsConfig.GCSConfig, replacements)
case dataprovider.AzureBlobFilesystemProvider:
case vfs.AzureBlobFilesystemProvider:
user.FsConfig.AzBlobConfig = getAzBlobFsFromTemplate(user.FsConfig.AzBlobConfig, replacements)
case dataprovider.SFTPFilesystemProvider:
case vfs.SFTPFilesystemProvider:
user.FsConfig.SFTPConfig = getSFTPFsFromTemplate(user.FsConfig.SFTPConfig, replacements)
}
@ -959,7 +977,7 @@ func getUserFromPostFields(r *http.Request) (dataprovider.User, error) {
}
expirationDateMillis = utils.GetTimeAsMsSinceEpoch(expirationDate)
}
fsConfig, err := getFsConfigFromUserPostFields(r)
fsConfig, err := getFsConfigFromPostFields(r)
if err != nil {
return user, err
}
@ -1257,6 +1275,12 @@ func handleWebTemplateFolderPost(w http.ResponseWriter, r *http.Request) {
templateFolder.MappedPath = r.Form.Get("mapped_path")
templateFolder.Description = r.Form.Get("description")
fsConfig, err := getFsConfigFromPostFields(r)
if err != nil {
renderMessagePage(w, r, "Error parsing folders fields", "", http.StatusBadRequest, err, "")
return
}
templateFolder.FsConfig = fsConfig
var dump dataprovider.BackupData
dump.Version = dataprovider.DumpVersion
@ -1415,7 +1439,7 @@ func handleWebUpdateUserPost(w http.ResponseWriter, r *http.Request) {
if updatedUser.Password == redactedSecret {
updatedUser.Password = user.Password
}
updateEncryptedSecrets(&updatedUser, user.FsConfig.S3Config.AccessSecret, user.FsConfig.AzBlobConfig.AccountKey,
updateEncryptedSecrets(&updatedUser.FsConfig, user.FsConfig.S3Config.AccessSecret, user.FsConfig.AzBlobConfig.AccountKey,
user.FsConfig.GCSConfig.Credentials, user.FsConfig.CryptConfig.Passphrase, user.FsConfig.SFTPConfig.Password,
user.FsConfig.SFTPConfig.PrivateKey)
@ -1466,6 +1490,12 @@ func handleWebAddFolderPost(w http.ResponseWriter, r *http.Request) {
folder.MappedPath = r.Form.Get("mapped_path")
folder.Name = r.Form.Get("name")
folder.Description = r.Form.Get("description")
fsConfig, err := getFsConfigFromPostFields(r)
if err != nil {
renderFolderPage(w, r, folder, folderPageModeAdd, err.Error())
return
}
folder.FsConfig = fsConfig
err = dataprovider.AddFolder(&folder)
if err == nil {
@ -1508,9 +1538,24 @@ func handleWebUpdateFolderPost(w http.ResponseWriter, r *http.Request) {
renderForbiddenPage(w, r, err.Error())
return
}
folder.MappedPath = r.Form.Get("mapped_path")
folder.Description = r.Form.Get("description")
err = dataprovider.UpdateFolder(&folder)
fsConfig, err := getFsConfigFromPostFields(r)
if err != nil {
renderFolderPage(w, r, folder, folderPageModeUpdate, err.Error())
return
}
updatedFolder := &vfs.BaseVirtualFolder{
MappedPath: r.Form.Get("mapped_path"),
Description: r.Form.Get("description"),
}
updatedFolder.ID = folder.ID
updatedFolder.Name = folder.Name
updatedFolder.FsConfig = fsConfig
updatedFolder.FsConfig.SetEmptySecretsIfNil()
updateEncryptedSecrets(&updatedFolder.FsConfig, folder.FsConfig.S3Config.AccessSecret, folder.FsConfig.AzBlobConfig.AccountKey,
folder.FsConfig.GCSConfig.Credentials, folder.FsConfig.CryptConfig.Passphrase, folder.FsConfig.SFTPConfig.Password,
folder.FsConfig.SFTPConfig.PrivateKey)
err = dataprovider.UpdateFolder(updatedFolder, folder.Users)
if err != nil {
renderFolderPage(w, r, folder, folderPageModeUpdate, err.Error())
return

View file

@ -10,7 +10,6 @@ import (
"net/http"
"net/url"
"path"
"path/filepath"
"strconv"
"strings"
@ -835,32 +834,18 @@ func checkFolder(expected *vfs.BaseVirtualFolder, actual *vfs.BaseVirtualFolder)
if expected.MappedPath != actual.MappedPath {
return errors.New("mapped path mismatch")
}
if expected.LastQuotaUpdate != actual.LastQuotaUpdate {
return errors.New("last quota update mismatch")
}
if expected.UsedQuotaSize != actual.UsedQuotaSize {
return errors.New("used quota size mismatch")
}
if expected.UsedQuotaFiles != actual.UsedQuotaFiles {
return errors.New("used quota files mismatch")
}
if expected.Description != actual.Description {
return errors.New("Description mismatch")
return errors.New("description mismatch")
}
if len(expected.Users) != len(actual.Users) {
return errors.New("folder users mismatch")
}
for _, u := range actual.Users {
if !utils.IsStringInSlice(u, expected.Users) {
return errors.New("folder users mismatch")
}
if err := compareFsConfig(&expected.FsConfig, &actual.FsConfig); err != nil {
return err
}
return nil
}
func checkAdmin(expected *dataprovider.Admin, actual *dataprovider.Admin) error {
if actual.Password != "" {
return errors.New("Admin password must not be visible")
return errors.New("admin password must not be visible")
}
if expected.ID <= 0 {
if actual.ID <= 0 {
@ -875,19 +860,19 @@ func checkAdmin(expected *dataprovider.Admin, actual *dataprovider.Admin) error
return err
}
if len(expected.Permissions) != len(actual.Permissions) {
return errors.New("Permissions mismatch")
return errors.New("permissions mismatch")
}
for _, p := range expected.Permissions {
if !utils.IsStringInSlice(p, actual.Permissions) {
return errors.New("Permissions content mismatch")
return errors.New("permissions content mismatch")
}
}
if len(expected.Filters.AllowList) != len(actual.Filters.AllowList) {
return errors.New("AllowList mismatch")
return errors.New("allow list mismatch")
}
for _, v := range expected.Filters.AllowList {
if !utils.IsStringInSlice(v, actual.Filters.AllowList) {
return errors.New("AllowList content mismatch")
return errors.New("allow list content mismatch")
}
}
@ -896,26 +881,26 @@ func checkAdmin(expected *dataprovider.Admin, actual *dataprovider.Admin) error
func compareAdminEqualFields(expected *dataprovider.Admin, actual *dataprovider.Admin) error {
if expected.Username != actual.Username {
return errors.New("Username mismatch")
return errors.New("sername mismatch")
}
if expected.Email != actual.Email {
return errors.New("Email mismatch")
return errors.New("email mismatch")
}
if expected.Status != actual.Status {
return errors.New("Status mismatch")
return errors.New("status mismatch")
}
if expected.Description != actual.Description {
return errors.New("Description mismatch")
return errors.New("description mismatch")
}
if expected.AdditionalInfo != actual.AdditionalInfo {
return errors.New("AdditionalInfo mismatch")
return errors.New("additional info mismatch")
}
return nil
}
func checkUser(expected *dataprovider.User, actual *dataprovider.User) error {
if actual.Password != "" {
return errors.New("User password must not be visible")
return errors.New("user password must not be visible")
}
if expected.ID <= 0 {
if actual.ID <= 0 {
@ -927,23 +912,23 @@ func checkUser(expected *dataprovider.User, actual *dataprovider.User) error {
}
}
if len(expected.Permissions) != len(actual.Permissions) {
return errors.New("Permissions mismatch")
return errors.New("permissions mismatch")
}
for dir, perms := range expected.Permissions {
if actualPerms, ok := actual.Permissions[dir]; ok {
for _, v := range actualPerms {
if !utils.IsStringInSlice(v, perms) {
return errors.New("Permissions contents mismatch")
return errors.New("permissions contents mismatch")
}
}
} else {
return errors.New("Permissions directories mismatch")
return errors.New("permissions directories mismatch")
}
}
if err := compareUserFilters(expected, actual); err != nil {
return err
}
if err := compareUserFsConfig(expected, actual); err != nil {
if err := compareFsConfig(&expected.FsConfig, &actual.FsConfig); err != nil {
return err
}
if err := compareUserVirtualFolders(expected, actual); err != nil {
@ -954,27 +939,35 @@ func checkUser(expected *dataprovider.User, actual *dataprovider.User) error {
func compareUserVirtualFolders(expected *dataprovider.User, actual *dataprovider.User) error {
if len(actual.VirtualFolders) != len(expected.VirtualFolders) {
return errors.New("Virtual folders mismatch")
return errors.New("virtual folders len mismatch")
}
for _, v := range actual.VirtualFolders {
found := false
for _, v1 := range expected.VirtualFolders {
if path.Clean(v.VirtualPath) == path.Clean(v1.VirtualPath) &&
filepath.Clean(v.MappedPath) == filepath.Clean(v1.MappedPath) {
if path.Clean(v.VirtualPath) == path.Clean(v1.VirtualPath) {
if err := checkFolder(&v1.BaseVirtualFolder, &v.BaseVirtualFolder); err != nil {
return err
}
if v.QuotaSize != v1.QuotaSize {
return errors.New("vfolder quota size mismatch")
}
if (v.QuotaFiles) != (v1.QuotaFiles) {
return errors.New("vfolder quota files mismatch")
}
found = true
break
}
}
if !found {
return errors.New("Virtual folders mismatch")
return errors.New("virtual folders mismatch")
}
}
return nil
}
func compareUserFsConfig(expected *dataprovider.User, actual *dataprovider.User) error {
if expected.FsConfig.Provider != actual.FsConfig.Provider {
return errors.New("Fs provider mismatch")
func compareFsConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error {
if expected.Provider != actual.Provider {
return errors.New("fs provider mismatch")
}
if err := compareS3Config(expected, actual); err != nil {
return err
@ -985,7 +978,7 @@ func compareUserFsConfig(expected *dataprovider.User, actual *dataprovider.User)
if err := compareAzBlobConfig(expected, actual); err != nil {
return err
}
if err := checkEncryptedSecret(expected.FsConfig.CryptConfig.Passphrase, actual.FsConfig.CryptConfig.Passphrase); err != nil {
if err := checkEncryptedSecret(expected.CryptConfig.Passphrase, actual.CryptConfig.Passphrase); err != nil {
return err
}
if err := compareSFTPFsConfig(expected, actual); err != nil {
@ -994,118 +987,118 @@ func compareUserFsConfig(expected *dataprovider.User, actual *dataprovider.User)
return nil
}
func compareS3Config(expected *dataprovider.User, actual *dataprovider.User) error {
if expected.FsConfig.S3Config.Bucket != actual.FsConfig.S3Config.Bucket {
return errors.New("S3 bucket mismatch")
func compareS3Config(expected *vfs.Filesystem, actual *vfs.Filesystem) error {
if expected.S3Config.Bucket != actual.S3Config.Bucket {
return errors.New("fs S3 bucket mismatch")
}
if expected.FsConfig.S3Config.Region != actual.FsConfig.S3Config.Region {
return errors.New("S3 region mismatch")
if expected.S3Config.Region != actual.S3Config.Region {
return errors.New("fs S3 region mismatch")
}
if expected.FsConfig.S3Config.AccessKey != actual.FsConfig.S3Config.AccessKey {
return errors.New("S3 access key mismatch")
if expected.S3Config.AccessKey != actual.S3Config.AccessKey {
return errors.New("fs S3 access key mismatch")
}
if err := checkEncryptedSecret(expected.FsConfig.S3Config.AccessSecret, actual.FsConfig.S3Config.AccessSecret); err != nil {
return fmt.Errorf("S3 access secret mismatch: %v", err)
if err := checkEncryptedSecret(expected.S3Config.AccessSecret, actual.S3Config.AccessSecret); err != nil {
return fmt.Errorf("fs S3 access secret mismatch: %v", err)
}
if expected.FsConfig.S3Config.Endpoint != actual.FsConfig.S3Config.Endpoint {
return errors.New("S3 endpoint mismatch")
if expected.S3Config.Endpoint != actual.S3Config.Endpoint {
return errors.New("fs S3 endpoint mismatch")
}
if expected.FsConfig.S3Config.StorageClass != actual.FsConfig.S3Config.StorageClass {
return errors.New("S3 storage class mismatch")
if expected.S3Config.StorageClass != actual.S3Config.StorageClass {
return errors.New("fs S3 storage class mismatch")
}
if expected.FsConfig.S3Config.UploadPartSize != actual.FsConfig.S3Config.UploadPartSize {
return errors.New("S3 upload part size mismatch")
if expected.S3Config.UploadPartSize != actual.S3Config.UploadPartSize {
return errors.New("fs S3 upload part size mismatch")
}
if expected.FsConfig.S3Config.UploadConcurrency != actual.FsConfig.S3Config.UploadConcurrency {
return errors.New("S3 upload concurrency mismatch")
if expected.S3Config.UploadConcurrency != actual.S3Config.UploadConcurrency {
return errors.New("fs S3 upload concurrency mismatch")
}
if expected.FsConfig.S3Config.KeyPrefix != actual.FsConfig.S3Config.KeyPrefix &&
expected.FsConfig.S3Config.KeyPrefix+"/" != actual.FsConfig.S3Config.KeyPrefix {
return errors.New("S3 key prefix mismatch")
if expected.S3Config.KeyPrefix != actual.S3Config.KeyPrefix &&
expected.S3Config.KeyPrefix+"/" != actual.S3Config.KeyPrefix {
return errors.New("fs S3 key prefix mismatch")
}
return nil
}
func compareGCSConfig(expected *dataprovider.User, actual *dataprovider.User) error {
if expected.FsConfig.GCSConfig.Bucket != actual.FsConfig.GCSConfig.Bucket {
func compareGCSConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error {
if expected.GCSConfig.Bucket != actual.GCSConfig.Bucket {
return errors.New("GCS bucket mismatch")
}
if expected.FsConfig.GCSConfig.StorageClass != actual.FsConfig.GCSConfig.StorageClass {
if expected.GCSConfig.StorageClass != actual.GCSConfig.StorageClass {
return errors.New("GCS storage class mismatch")
}
if expected.FsConfig.GCSConfig.KeyPrefix != actual.FsConfig.GCSConfig.KeyPrefix &&
expected.FsConfig.GCSConfig.KeyPrefix+"/" != actual.FsConfig.GCSConfig.KeyPrefix {
if expected.GCSConfig.KeyPrefix != actual.GCSConfig.KeyPrefix &&
expected.GCSConfig.KeyPrefix+"/" != actual.GCSConfig.KeyPrefix {
return errors.New("GCS key prefix mismatch")
}
if expected.FsConfig.GCSConfig.AutomaticCredentials != actual.FsConfig.GCSConfig.AutomaticCredentials {
if expected.GCSConfig.AutomaticCredentials != actual.GCSConfig.AutomaticCredentials {
return errors.New("GCS automatic credentials mismatch")
}
return nil
}
func compareSFTPFsConfig(expected *dataprovider.User, actual *dataprovider.User) error {
if expected.FsConfig.SFTPConfig.Endpoint != actual.FsConfig.SFTPConfig.Endpoint {
func compareSFTPFsConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error {
if expected.SFTPConfig.Endpoint != actual.SFTPConfig.Endpoint {
return errors.New("SFTPFs endpoint mismatch")
}
if expected.FsConfig.SFTPConfig.Username != actual.FsConfig.SFTPConfig.Username {
if expected.SFTPConfig.Username != actual.SFTPConfig.Username {
return errors.New("SFTPFs username mismatch")
}
if expected.FsConfig.SFTPConfig.DisableCouncurrentReads != actual.FsConfig.SFTPConfig.DisableCouncurrentReads {
if expected.SFTPConfig.DisableCouncurrentReads != actual.SFTPConfig.DisableCouncurrentReads {
return errors.New("SFTPFs disable_concurrent_reads mismatch")
}
if err := checkEncryptedSecret(expected.FsConfig.SFTPConfig.Password, actual.FsConfig.SFTPConfig.Password); err != nil {
if err := checkEncryptedSecret(expected.SFTPConfig.Password, actual.SFTPConfig.Password); err != nil {
return fmt.Errorf("SFTPFs password mismatch: %v", err)
}
if err := checkEncryptedSecret(expected.FsConfig.SFTPConfig.PrivateKey, actual.FsConfig.SFTPConfig.PrivateKey); err != nil {
if err := checkEncryptedSecret(expected.SFTPConfig.PrivateKey, actual.SFTPConfig.PrivateKey); err != nil {
return fmt.Errorf("SFTPFs private key mismatch: %v", err)
}
if expected.FsConfig.SFTPConfig.Prefix != actual.FsConfig.SFTPConfig.Prefix {
if expected.FsConfig.SFTPConfig.Prefix != "" && actual.FsConfig.SFTPConfig.Prefix != "/" {
if expected.SFTPConfig.Prefix != actual.SFTPConfig.Prefix {
if expected.SFTPConfig.Prefix != "" && actual.SFTPConfig.Prefix != "/" {
return errors.New("SFTPFs prefix mismatch")
}
}
if len(expected.FsConfig.SFTPConfig.Fingerprints) != len(actual.FsConfig.SFTPConfig.Fingerprints) {
if len(expected.SFTPConfig.Fingerprints) != len(actual.SFTPConfig.Fingerprints) {
return errors.New("SFTPFs fingerprints mismatch")
}
for _, value := range actual.FsConfig.SFTPConfig.Fingerprints {
if !utils.IsStringInSlice(value, expected.FsConfig.SFTPConfig.Fingerprints) {
for _, value := range actual.SFTPConfig.Fingerprints {
if !utils.IsStringInSlice(value, expected.SFTPConfig.Fingerprints) {
return errors.New("SFTPFs fingerprints mismatch")
}
}
return nil
}
func compareAzBlobConfig(expected *dataprovider.User, actual *dataprovider.User) error {
if expected.FsConfig.AzBlobConfig.Container != actual.FsConfig.AzBlobConfig.Container {
return errors.New("Azure Blob container mismatch")
func compareAzBlobConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error {
if expected.AzBlobConfig.Container != actual.AzBlobConfig.Container {
return errors.New("azure Blob container mismatch")
}
if expected.FsConfig.AzBlobConfig.AccountName != actual.FsConfig.AzBlobConfig.AccountName {
return errors.New("Azure Blob account name mismatch")
if expected.AzBlobConfig.AccountName != actual.AzBlobConfig.AccountName {
return errors.New("azure Blob account name mismatch")
}
if err := checkEncryptedSecret(expected.FsConfig.AzBlobConfig.AccountKey, actual.FsConfig.AzBlobConfig.AccountKey); err != nil {
return fmt.Errorf("Azure Blob account key mismatch: %v", err)
if err := checkEncryptedSecret(expected.AzBlobConfig.AccountKey, actual.AzBlobConfig.AccountKey); err != nil {
return fmt.Errorf("azure Blob account key mismatch: %v", err)
}
if expected.FsConfig.AzBlobConfig.Endpoint != actual.FsConfig.AzBlobConfig.Endpoint {
return errors.New("Azure Blob endpoint mismatch")
if expected.AzBlobConfig.Endpoint != actual.AzBlobConfig.Endpoint {
return errors.New("azure Blob endpoint mismatch")
}
if expected.FsConfig.AzBlobConfig.SASURL != actual.FsConfig.AzBlobConfig.SASURL {
return errors.New("Azure Blob SASL URL mismatch")
if expected.AzBlobConfig.SASURL != actual.AzBlobConfig.SASURL {
return errors.New("azure Blob SASL URL mismatch")
}
if expected.FsConfig.AzBlobConfig.UploadPartSize != actual.FsConfig.AzBlobConfig.UploadPartSize {
return errors.New("Azure Blob upload part size mismatch")
if expected.AzBlobConfig.UploadPartSize != actual.AzBlobConfig.UploadPartSize {
return errors.New("azure Blob upload part size mismatch")
}
if expected.FsConfig.AzBlobConfig.UploadConcurrency != actual.FsConfig.AzBlobConfig.UploadConcurrency {
return errors.New("Azure Blob upload concurrency mismatch")
if expected.AzBlobConfig.UploadConcurrency != actual.AzBlobConfig.UploadConcurrency {
return errors.New("azure Blob upload concurrency mismatch")
}
if expected.FsConfig.AzBlobConfig.KeyPrefix != actual.FsConfig.AzBlobConfig.KeyPrefix &&
expected.FsConfig.AzBlobConfig.KeyPrefix+"/" != actual.FsConfig.AzBlobConfig.KeyPrefix {
return errors.New("Azure Blob key prefix mismatch")
if expected.AzBlobConfig.KeyPrefix != actual.AzBlobConfig.KeyPrefix &&
expected.AzBlobConfig.KeyPrefix+"/" != actual.AzBlobConfig.KeyPrefix {
return errors.New("azure Blob key prefix mismatch")
}
if expected.FsConfig.AzBlobConfig.UseEmulator != actual.FsConfig.AzBlobConfig.UseEmulator {
return errors.New("Azure Blob use emulator mismatch")
if expected.AzBlobConfig.UseEmulator != actual.AzBlobConfig.UseEmulator {
return errors.New("azure Blob use emulator mismatch")
}
if expected.FsConfig.AzBlobConfig.AccessTier != actual.FsConfig.AzBlobConfig.AccessTier {
return errors.New("Azure Blob access tier mismatch")
if expected.AzBlobConfig.AccessTier != actual.AzBlobConfig.AccessTier {
return errors.New("azure Blob access tier mismatch")
}
return nil
}
@ -1154,22 +1147,22 @@ func checkEncryptedSecret(expected, actual *kms.Secret) error {
func compareUserFilterSubStructs(expected *dataprovider.User, actual *dataprovider.User) error {
for _, IPMask := range expected.Filters.AllowedIP {
if !utils.IsStringInSlice(IPMask, actual.Filters.AllowedIP) {
return errors.New("AllowedIP contents mismatch")
return errors.New("allowed IP contents mismatch")
}
}
for _, IPMask := range expected.Filters.DeniedIP {
if !utils.IsStringInSlice(IPMask, actual.Filters.DeniedIP) {
return errors.New("DeniedIP contents mismatch")
return errors.New("denied IP contents mismatch")
}
}
for _, method := range expected.Filters.DeniedLoginMethods {
if !utils.IsStringInSlice(method, actual.Filters.DeniedLoginMethods) {
return errors.New("Denied login methods contents mismatch")
return errors.New("denied login methods contents mismatch")
}
}
for _, protocol := range expected.Filters.DeniedProtocols {
if !utils.IsStringInSlice(protocol, actual.Filters.DeniedProtocols) {
return errors.New("Denied protocols contents mismatch")
return errors.New("denied protocols contents mismatch")
}
}
return nil
@ -1177,19 +1170,19 @@ func compareUserFilterSubStructs(expected *dataprovider.User, actual *dataprovid
func compareUserFilters(expected *dataprovider.User, actual *dataprovider.User) error {
if len(expected.Filters.AllowedIP) != len(actual.Filters.AllowedIP) {
return errors.New("AllowedIP mismatch")
return errors.New("allowed IP mismatch")
}
if len(expected.Filters.DeniedIP) != len(actual.Filters.DeniedIP) {
return errors.New("DeniedIP mismatch")
return errors.New("denied IP mismatch")
}
if len(expected.Filters.DeniedLoginMethods) != len(actual.Filters.DeniedLoginMethods) {
return errors.New("Denied login methods mismatch")
return errors.New("denied login methods mismatch")
}
if len(expected.Filters.DeniedProtocols) != len(actual.Filters.DeniedProtocols) {
return errors.New("Denied protocols mismatch")
return errors.New("denied protocols mismatch")
}
if expected.Filters.MaxUploadFileSize != actual.Filters.MaxUploadFileSize {
return errors.New("Max upload file size mismatch")
return errors.New("max upload file size mismatch")
}
if expected.Filters.TLSUsername != actual.Filters.TLSUsername {
return errors.New("TLSUsername mismatch")
@ -1261,10 +1254,10 @@ func compareUserFileExtensionsFilters(expected *dataprovider.User, actual *datap
func compareEqualsUserFields(expected *dataprovider.User, actual *dataprovider.User) error {
if expected.Username != actual.Username {
return errors.New("Username mismatch")
return errors.New("username mismatch")
}
if expected.HomeDir != actual.HomeDir {
return errors.New("HomeDir mismatch")
return errors.New("home dir mismatch")
}
if expected.UID != actual.UID {
return errors.New("UID mismatch")
@ -1282,7 +1275,7 @@ func compareEqualsUserFields(expected *dataprovider.User, actual *dataprovider.U
return errors.New("QuotaFiles mismatch")
}
if len(expected.Permissions) != len(actual.Permissions) {
return errors.New("Permissions mismatch")
return errors.New("permissions mismatch")
}
if expected.UploadBandwidth != actual.UploadBandwidth {
return errors.New("UploadBandwidth mismatch")
@ -1291,7 +1284,7 @@ func compareEqualsUserFields(expected *dataprovider.User, actual *dataprovider.U
return errors.New("DownloadBandwidth mismatch")
}
if expected.Status != actual.Status {
return errors.New("Status mismatch")
return errors.New("status mismatch")
}
if expected.ExpirationDate != actual.ExpirationDate {
return errors.New("ExpirationDate mismatch")
@ -1300,7 +1293,7 @@ func compareEqualsUserFields(expected *dataprovider.User, actual *dataprovider.U
return errors.New("AdditionalInfo mismatch")
}
if expected.Description != actual.Description {
return errors.New("Description mismatch")
return errors.New("description mismatch")
}
return nil
}

View file

@ -21,6 +21,7 @@ import (
"github.com/drakkan/sftpgo/sftpd"
"github.com/drakkan/sftpgo/utils"
"github.com/drakkan/sftpgo/version"
"github.com/drakkan/sftpgo/vfs"
"github.com/drakkan/sftpgo/webdavd"
)
@ -229,9 +230,9 @@ func (s *Service) advertiseServices(advertiseService, advertiseCredentials bool)
func (s *Service) getPortableDirToServe() string {
var dirToServe string
if s.PortableUser.FsConfig.Provider == dataprovider.S3FilesystemProvider {
if s.PortableUser.FsConfig.Provider == vfs.S3FilesystemProvider {
dirToServe = s.PortableUser.FsConfig.S3Config.KeyPrefix
} else if s.PortableUser.FsConfig.Provider == dataprovider.GCSFilesystemProvider {
} else if s.PortableUser.FsConfig.Provider == vfs.GCSFilesystemProvider {
dirToServe = s.PortableUser.FsConfig.GCSConfig.KeyPrefix
} else {
dirToServe = s.PortableUser.HomeDir
@ -263,31 +264,31 @@ func (s *Service) configurePortableUser() string {
func (s *Service) configurePortableSecrets() {
// we created the user before to initialize the KMS so we need to create the secret here
switch s.PortableUser.FsConfig.Provider {
case dataprovider.S3FilesystemProvider:
case vfs.S3FilesystemProvider:
payload := s.PortableUser.FsConfig.S3Config.AccessSecret.GetPayload()
s.PortableUser.FsConfig.S3Config.AccessSecret = kms.NewEmptySecret()
if payload != "" {
s.PortableUser.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret(payload)
}
case dataprovider.GCSFilesystemProvider:
case vfs.GCSFilesystemProvider:
payload := s.PortableUser.FsConfig.GCSConfig.Credentials.GetPayload()
s.PortableUser.FsConfig.GCSConfig.Credentials = kms.NewEmptySecret()
if payload != "" {
s.PortableUser.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(payload)
}
case dataprovider.AzureBlobFilesystemProvider:
case vfs.AzureBlobFilesystemProvider:
payload := s.PortableUser.FsConfig.AzBlobConfig.AccountKey.GetPayload()
s.PortableUser.FsConfig.AzBlobConfig.AccountKey = kms.NewEmptySecret()
if payload != "" {
s.PortableUser.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret(payload)
}
case dataprovider.CryptedFilesystemProvider:
case vfs.CryptedFilesystemProvider:
payload := s.PortableUser.FsConfig.CryptConfig.Passphrase.GetPayload()
s.PortableUser.FsConfig.CryptConfig.Passphrase = kms.NewEmptySecret()
if payload != "" {
s.PortableUser.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(payload)
}
case dataprovider.SFTPFilesystemProvider:
case vfs.SFTPFilesystemProvider:
payload := s.PortableUser.FsConfig.SFTPConfig.Password.GetPayload()
s.PortableUser.FsConfig.SFTPConfig.Password = kms.NewEmptySecret()
if payload != "" {

View file

@ -478,7 +478,7 @@ func getEncryptedFileSize(size int64) (int64, error) {
func getTestUserWithCryptFs(usePubKey bool) dataprovider.User {
u := getTestUser(usePubKey)
u.FsConfig.Provider = dataprovider.CryptedFilesystemProvider
u.FsConfig.Provider = vfs.CryptedFilesystemProvider
u.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(testPassphrase)
return u
}

View file

@ -54,19 +54,19 @@ func (c *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
return nil, sftp.ErrSSHFxPermissionDenied
}
p, err := c.Fs.ResolvePath(request.Filepath)
fs, p, err := c.GetFsAndResolvedPath(request.Filepath)
if err != nil {
return nil, c.GetFsError(err)
return nil, err
}
file, r, cancelFn, err := c.Fs.Open(p, 0)
file, r, cancelFn, err := fs.Open(p, 0)
if err != nil {
c.Log(logger.LevelWarn, "could not open file %#v for reading: %+v", p, err)
return nil, c.GetFsError(err)
return nil, c.GetFsError(fs, err)
}
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, request.Filepath, common.TransferDownload,
0, 0, 0, false, c.Fs)
0, 0, 0, false, fs)
t := newTransfer(baseTransfer, nil, r, nil)
return t, nil
@ -90,18 +90,18 @@ func (c *Connection) handleFilewrite(request *sftp.Request) (sftp.WriterAtReader
return nil, sftp.ErrSSHFxPermissionDenied
}
p, err := c.Fs.ResolvePath(request.Filepath)
fs, p, err := c.GetFsAndResolvedPath(request.Filepath)
if err != nil {
return nil, c.GetFsError(err)
return nil, err
}
filePath := p
if common.Config.IsAtomicUploadEnabled() && c.Fs.IsAtomicUploadSupported() {
filePath = c.Fs.GetAtomicUploadPath(p)
if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
filePath = fs.GetAtomicUploadPath(p)
}
var errForRead error
if !vfs.IsLocalOrSFTPFs(c.Fs) && request.Pflags().Read {
if !vfs.IsLocalOrSFTPFs(fs) && request.Pflags().Read {
// read and write mode is only supported for local filesystem
errForRead = sftp.ErrSSHFxOpUnsupported
}
@ -112,17 +112,17 @@ func (c *Connection) handleFilewrite(request *sftp.Request) (sftp.WriterAtReader
errForRead = os.ErrPermission
}
stat, statErr := c.Fs.Lstat(p)
if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || c.Fs.IsNotExist(statErr) {
stat, statErr := fs.Lstat(p)
if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) {
if !c.User.HasPerm(dataprovider.PermUpload, path.Dir(request.Filepath)) {
return nil, sftp.ErrSSHFxPermissionDenied
}
return c.handleSFTPUploadToNewFile(p, filePath, request.Filepath, errForRead)
return c.handleSFTPUploadToNewFile(fs, p, filePath, request.Filepath, errForRead)
}
if statErr != nil {
c.Log(logger.LevelError, "error performing file stat %#v: %+v", p, statErr)
return nil, c.GetFsError(statErr)
return nil, c.GetFsError(fs, statErr)
}
// This happen if we upload a file that has the same name of an existing directory
@ -135,7 +135,7 @@ func (c *Connection) handleFilewrite(request *sftp.Request) (sftp.WriterAtReader
return nil, sftp.ErrSSHFxPermissionDenied
}
return c.handleSFTPUploadToExistingFile(request.Pflags(), p, filePath, stat.Size(), request.Filepath, errForRead)
return c.handleSFTPUploadToExistingFile(fs, request.Pflags(), p, filePath, stat.Size(), request.Filepath, errForRead)
}
// Filecmd hander for basic SFTP system calls related to files, but not anything to do with reading
@ -143,37 +143,29 @@ func (c *Connection) handleFilewrite(request *sftp.Request) (sftp.WriterAtReader
func (c *Connection) Filecmd(request *sftp.Request) error {
c.UpdateLastActivity()
p, err := c.Fs.ResolvePath(request.Filepath)
if err != nil {
return c.GetFsError(err)
}
target, err := c.getSFTPCmdTargetPath(request.Target)
if err != nil {
return c.GetFsError(err)
}
c.Log(logger.LevelDebug, "new cmd, method: %v, sourcePath: %#v, targetPath: %#v", request.Method, p, target)
c.Log(logger.LevelDebug, "new cmd, method: %v, sourcePath: %#v, targetPath: %#v", request.Method,
request.Filepath, request.Target)
switch request.Method {
case "Setstat":
return c.handleSFTPSetstat(p, request)
return c.handleSFTPSetstat(request)
case "Rename":
if err = c.Rename(p, target, request.Filepath, request.Target); err != nil {
if err := c.Rename(request.Filepath, request.Target); err != nil {
return err
}
case "Rmdir":
return c.RemoveDir(p, request.Filepath)
return c.RemoveDir(request.Filepath)
case "Mkdir":
err = c.CreateDir(p, request.Filepath)
err := c.CreateDir(request.Filepath)
if err != nil {
return err
}
case "Symlink":
if err = c.CreateSymlink(p, target, request.Filepath, request.Target); err != nil {
if err := c.CreateSymlink(request.Filepath, request.Target); err != nil {
return err
}
case "Remove":
return c.handleSFTPRemove(p, request)
return c.handleSFTPRemove(request)
default:
return sftp.ErrSSHFxOpUnsupported
}
@ -185,14 +177,10 @@ func (c *Connection) Filecmd(request *sftp.Request) error {
// a directory as well as perform file/folder stat calls.
func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
c.UpdateLastActivity()
p, err := c.Fs.ResolvePath(request.Filepath)
if err != nil {
return nil, c.GetFsError(err)
}
switch request.Method {
case "List":
files, err := c.ListDir(p, request.Filepath)
files, err := c.ListDir(request.Filepath)
if err != nil {
return nil, err
}
@ -202,10 +190,10 @@ func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
return nil, sftp.ErrSSHFxPermissionDenied
}
s, err := c.DoStat(p, 0)
s, err := c.DoStat(request.Filepath, 0)
if err != nil {
c.Log(logger.LevelDebug, "error running stat on path %#v: %+v", p, err)
return nil, c.GetFsError(err)
c.Log(logger.LevelDebug, "error running stat on path %#v: %+v", request.Filepath, err)
return nil, err
}
return listerAt([]os.FileInfo{s}), nil
@ -214,10 +202,15 @@ func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
return nil, sftp.ErrSSHFxPermissionDenied
}
s, err := c.Fs.Readlink(p)
fs, p, err := c.GetFsAndResolvedPath(request.Filepath)
if err != nil {
return nil, err
}
s, err := fs.Readlink(p)
if err != nil {
c.Log(logger.LevelDebug, "error running readlink on path %#v: %+v", p, err)
return nil, c.GetFsError(err)
return nil, c.GetFsError(fs, err)
}
if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(s)) {
@ -235,19 +228,14 @@ func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
func (c *Connection) Lstat(request *sftp.Request) (sftp.ListerAt, error) {
c.UpdateLastActivity()
p, err := c.Fs.ResolvePath(request.Filepath)
if err != nil {
return nil, c.GetFsError(err)
}
if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(request.Filepath)) {
return nil, sftp.ErrSSHFxPermissionDenied
}
s, err := c.DoStat(p, 1)
s, err := c.DoStat(request.Filepath, 1)
if err != nil {
c.Log(logger.LevelDebug, "error running lstat on path %#v: %+v", p, err)
return nil, c.GetFsError(err)
c.Log(logger.LevelDebug, "error running lstat on path %#v: %+v", request.Filepath, err)
return nil, err
}
return listerAt([]os.FileInfo{s}), nil
@ -263,43 +251,29 @@ func (c *Connection) StatVFS(r *sftp.Request) (*sftp.StatVFS, error) {
// not the limit for a single file upload
quotaResult := c.HasSpace(true, true, path.Join(r.Filepath, "fakefile.txt"))
p, err := c.Fs.ResolvePath(r.Filepath)
fs, p, err := c.GetFsAndResolvedPath(r.Filepath)
if err != nil {
return nil, c.GetFsError(err)
return nil, err
}
if !quotaResult.HasSpace {
return c.getStatVFSFromQuotaResult(p, quotaResult), nil
return c.getStatVFSFromQuotaResult(fs, p, quotaResult), nil
}
if quotaResult.QuotaSize == 0 && quotaResult.QuotaFiles == 0 {
// no quota restrictions
statvfs, err := c.Fs.GetAvailableDiskSize(p)
statvfs, err := fs.GetAvailableDiskSize(p)
if err == vfs.ErrStorageSizeUnavailable {
return c.getStatVFSFromQuotaResult(p, quotaResult), nil
return c.getStatVFSFromQuotaResult(fs, p, quotaResult), nil
}
return statvfs, err
}
// there is free space but some limits are configured
return c.getStatVFSFromQuotaResult(p, quotaResult), nil
return c.getStatVFSFromQuotaResult(fs, p, quotaResult), nil
}
func (c *Connection) getSFTPCmdTargetPath(requestTarget string) (string, error) {
var target string
// If a target is provided in this request validate that it is going to the correct
// location for the server. If it is not, return an error
if requestTarget != "" {
var err error
target, err = c.Fs.ResolvePath(requestTarget)
if err != nil {
return target, err
}
}
return target, nil
}
func (c *Connection) handleSFTPSetstat(filePath string, request *sftp.Request) error {
func (c *Connection) handleSFTPSetstat(request *sftp.Request) error {
attrs := common.StatAttributes{
Flags: 0,
}
@ -322,50 +296,54 @@ func (c *Connection) handleSFTPSetstat(filePath string, request *sftp.Request) e
attrs.Size = int64(request.Attributes().Size)
}
return c.SetStat(filePath, request.Filepath, &attrs)
return c.SetStat(request.Filepath, &attrs)
}
func (c *Connection) handleSFTPRemove(filePath string, request *sftp.Request) error {
func (c *Connection) handleSFTPRemove(request *sftp.Request) error {
fs, fsPath, err := c.GetFsAndResolvedPath(request.Filepath)
if err != nil {
return err
}
var fi os.FileInfo
var err error
if fi, err = c.Fs.Lstat(filePath); err != nil {
c.Log(logger.LevelWarn, "failed to remove a file %#v: stat error: %+v", filePath, err)
return c.GetFsError(err)
if fi, err = fs.Lstat(fsPath); err != nil {
c.Log(logger.LevelDebug, "failed to remove a file %#v: stat error: %+v", fsPath, err)
return c.GetFsError(fs, err)
}
if fi.IsDir() && fi.Mode()&os.ModeSymlink == 0 {
c.Log(logger.LevelDebug, "cannot remove %#v is not a file/symlink", filePath)
c.Log(logger.LevelDebug, "cannot remove %#v is not a file/symlink", fsPath)
return sftp.ErrSSHFxFailure
}
return c.RemoveFile(filePath, request.Filepath, fi)
return c.RemoveFile(fs, fsPath, request.Filepath, fi)
}
func (c *Connection) handleSFTPUploadToNewFile(resolvedPath, filePath, requestPath string, errForRead error) (sftp.WriterAtReaderAt, error) {
func (c *Connection) handleSFTPUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, requestPath string, errForRead error) (sftp.WriterAtReaderAt, error) {
quotaResult := c.HasSpace(true, false, requestPath)
if !quotaResult.HasSpace {
c.Log(logger.LevelInfo, "denying file write due to quota limits")
return nil, sftp.ErrSSHFxFailure
}
file, w, cancelFn, err := c.Fs.Create(filePath, 0)
file, w, cancelFn, err := fs.Create(filePath, 0)
if err != nil {
c.Log(logger.LevelWarn, "error creating file %#v: %+v", resolvedPath, err)
return nil, c.GetFsError(err)
return nil, c.GetFsError(fs, err)
}
vfs.SetPathPermissions(c.Fs, filePath, c.User.GetUID(), c.User.GetGID())
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
// we can get an error only for resume
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0)
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath,
common.TransferUpload, 0, 0, maxWriteSize, true, c.Fs)
common.TransferUpload, 0, 0, maxWriteSize, true, fs)
t := newTransfer(baseTransfer, w, nil, errForRead)
return t, nil
}
func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, resolvedPath, filePath string,
func (c *Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileOpenFlags, resolvedPath, filePath string,
fileSize int64, requestPath string, errForRead error) (sftp.WriterAtReaderAt, error) {
var err error
quotaResult := c.HasSpace(false, false, requestPath)
@ -382,25 +360,25 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r
// if there is a size limit the remaining size cannot be 0 here, since quotaResult.HasSpace
// will return false in this case and we deny the upload before.
// For Cloud FS GetMaxWriteSize will return unsupported operation
maxWriteSize, err := c.GetMaxWriteSize(quotaResult, isResume, fileSize)
maxWriteSize, err := c.GetMaxWriteSize(quotaResult, isResume, fileSize, fs.IsUploadResumeSupported())
if err != nil {
c.Log(logger.LevelDebug, "unable to get max write size: %v", err)
return nil, err
}
if common.Config.IsAtomicUploadEnabled() && c.Fs.IsAtomicUploadSupported() {
err = c.Fs.Rename(resolvedPath, filePath)
if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
err = fs.Rename(resolvedPath, filePath)
if err != nil {
c.Log(logger.LevelWarn, "error renaming existing file for atomic upload, source: %#v, dest: %#v, err: %+v",
resolvedPath, filePath, err)
return nil, c.GetFsError(err)
return nil, c.GetFsError(fs, err)
}
}
file, w, cancelFn, err := c.Fs.Create(filePath, osFlags)
file, w, cancelFn, err := fs.Create(filePath, osFlags)
if err != nil {
c.Log(logger.LevelWarn, "error opening existing file, flags: %v, source: %#v, err: %+v", pflags, filePath, err)
return nil, c.GetFsError(err)
return nil, c.GetFsError(fs, err)
}
initialSize := int64(0)
@ -409,7 +387,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r
minWriteOffset = fileSize
initialSize = fileSize
} else {
if vfs.IsLocalOrSFTPFs(c.Fs) && isTruncate {
if vfs.IsLocalOrSFTPFs(fs) && isTruncate {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
@ -424,10 +402,10 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r
}
}
vfs.SetPathPermissions(c.Fs, filePath, c.User.GetUID(), c.User.GetGID())
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath,
common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, c.Fs)
common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, fs)
t := newTransfer(baseTransfer, w, nil, errForRead)
return t, nil
@ -438,9 +416,9 @@ func (c *Connection) Disconnect() error {
return c.channel.Close()
}
func (c *Connection) getStatVFSFromQuotaResult(name string, quotaResult vfs.QuotaCheckResult) *sftp.StatVFS {
func (c *Connection) getStatVFSFromQuotaResult(fs vfs.Fs, name string, quotaResult vfs.QuotaCheckResult) *sftp.StatVFS {
if quotaResult.QuotaSize == 0 || quotaResult.QuotaFiles == 0 {
s, err := c.Fs.GetAvailableDiskSize(name)
s, err := fs.GetAvailableDiskSize(name)
if err == nil {
if quotaResult.QuotaSize == 0 {
quotaResult.QuotaSize = int64(s.TotalSpace())

View file

@ -20,6 +20,7 @@ import (
"github.com/drakkan/sftpgo/common"
"github.com/drakkan/sftpgo/dataprovider"
"github.com/drakkan/sftpgo/kms"
"github.com/drakkan/sftpgo/utils"
"github.com/drakkan/sftpgo/vfs"
)
@ -124,7 +125,7 @@ func (fs MockOsFs) Rename(source, target string) error {
func newMockOsFs(err, statErr error, atomicUpload bool, connectionID, rootDir string) vfs.Fs {
return &MockOsFs{
Fs: vfs.NewOsFs(connectionID, rootDir, nil),
Fs: vfs.NewOsFs(connectionID, rootDir, ""),
err: err,
statErr: statErr,
isAtomicUploadSupported: atomicUpload,
@ -156,15 +157,15 @@ func TestUploadResumeInvalidOffset(t *testing.T) {
user := dataprovider.User{
Username: "testuser",
}
fs := vfs.NewOsFs("", os.TempDir(), nil)
conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs)
fs := vfs.NewOsFs("", os.TempDir(), "")
conn := common.NewBaseConnection("", common.ProtocolSFTP, user)
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), testfile, common.TransferUpload, 10, 0, 0, false, fs)
transfer := newTransfer(baseTransfer, nil, nil, nil)
_, err = transfer.WriteAt([]byte("test"), 0)
assert.Error(t, err, "upload with invalid offset must fail")
if assert.Error(t, transfer.ErrTransfer) {
assert.EqualError(t, err, transfer.ErrTransfer.Error())
assert.Contains(t, transfer.ErrTransfer.Error(), "Invalid write offset")
assert.Contains(t, transfer.ErrTransfer.Error(), "invalid write offset")
}
err = transfer.Close()
@ -184,8 +185,8 @@ func TestReadWriteErrors(t *testing.T) {
user := dataprovider.User{
Username: "testuser",
}
fs := vfs.NewOsFs("", os.TempDir(), nil)
conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs)
fs := vfs.NewOsFs("", os.TempDir(), "")
conn := common.NewBaseConnection("", common.ProtocolSFTP, user)
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
transfer := newTransfer(baseTransfer, nil, nil, nil)
err = file.Close()
@ -233,8 +234,7 @@ func TestReadWriteErrors(t *testing.T) {
}
func TestUnsupportedListOP(t *testing.T) {
fs := vfs.NewOsFs("", os.TempDir(), nil)
conn := common.NewBaseConnection("", common.ProtocolSFTP, dataprovider.User{}, fs)
conn := common.NewBaseConnection("", common.ProtocolSFTP, dataprovider.User{})
sftpConn := Connection{
BaseConnection: conn,
}
@ -254,8 +254,8 @@ func TestTransferCancelFn(t *testing.T) {
user := dataprovider.User{
Username: "testuser",
}
fs := vfs.NewOsFs("", os.TempDir(), nil)
conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs)
fs := vfs.NewOsFs("", os.TempDir(), "")
conn := common.NewBaseConnection("", common.ProtocolSFTP, user)
baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
transfer := newTransfer(baseTransfer, nil, nil, nil)
@ -274,77 +274,37 @@ func TestTransferCancelFn(t *testing.T) {
assert.NoError(t, err)
}
func TestMockFsErrors(t *testing.T) {
errFake := errors.New("fake error")
fs := newMockOsFs(errFake, errFake, false, "123", os.TempDir())
u := dataprovider.User{}
u.Username = "test_username"
u.Permissions = make(map[string][]string)
u.Permissions["/"] = []string{dataprovider.PermAny}
u.HomeDir = os.TempDir()
c := Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, u, fs),
}
testfile := filepath.Join(u.HomeDir, "testfile")
request := sftp.NewRequest("Remove", testfile)
err := os.WriteFile(testfile, []byte("test"), os.ModePerm)
assert.NoError(t, err)
_, err = c.Filewrite(request)
assert.EqualError(t, err, sftp.ErrSSHFxFailure.Error())
var flags sftp.FileOpenFlags
flags.Write = true
flags.Trunc = false
flags.Append = true
_, err = c.handleSFTPUploadToExistingFile(flags, testfile, testfile, 0, "/testfile", nil)
assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error())
fs = newMockOsFs(errFake, nil, false, "123", os.TempDir())
c.BaseConnection.Fs = fs
err = c.handleSFTPRemove(testfile, request)
assert.EqualError(t, err, sftp.ErrSSHFxFailure.Error())
request = sftp.NewRequest("Rename", filepath.Base(testfile))
request.Target = filepath.Base(testfile) + "1"
err = c.Rename(testfile, testfile+"1", request.Filepath, request.Target)
assert.EqualError(t, err, sftp.ErrSSHFxFailure.Error())
err = os.Remove(testfile)
assert.NoError(t, err)
}
func TestUploadFiles(t *testing.T) {
oldUploadMode := common.Config.UploadMode
common.Config.UploadMode = common.UploadModeAtomic
fs := vfs.NewOsFs("123", os.TempDir(), nil)
fs := vfs.NewOsFs("123", os.TempDir(), "")
u := dataprovider.User{}
c := Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, u, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, u),
}
var flags sftp.FileOpenFlags
flags.Write = true
flags.Trunc = true
_, err := c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0, "/missing_path", nil)
_, err := c.handleSFTPUploadToExistingFile(fs, flags, "missing_path", "other_missing_path", 0, "/missing_path", nil)
assert.Error(t, err, "upload to existing file must fail if one or both paths are invalid")
common.Config.UploadMode = common.UploadModeStandard
_, err = c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0, "/missing_path", nil)
_, err = c.handleSFTPUploadToExistingFile(fs, flags, "missing_path", "other_missing_path", 0, "/missing_path", nil)
assert.Error(t, err, "upload to existing file must fail if one or both paths are invalid")
missingFile := "missing/relative/file.txt"
if runtime.GOOS == osWindows {
missingFile = "missing\\relative\\file.txt"
}
_, err = c.handleSFTPUploadToNewFile(".", missingFile, "/missing", nil)
_, err = c.handleSFTPUploadToNewFile(fs, ".", missingFile, "/missing", nil)
assert.Error(t, err, "upload new file in missing path must fail")
c.BaseConnection.Fs = newMockOsFs(nil, nil, false, "123", os.TempDir())
fs = newMockOsFs(nil, nil, false, "123", os.TempDir())
f, err := os.CreateTemp("", "temp")
assert.NoError(t, err)
err = f.Close()
assert.NoError(t, err)
tr, err := c.handleSFTPUploadToExistingFile(flags, f.Name(), f.Name(), 123, f.Name(), nil)
tr, err := c.handleSFTPUploadToExistingFile(fs, flags, f.Name(), f.Name(), 123, f.Name(), nil)
if assert.NoError(t, err) {
transfer := tr.(*transfer)
transfers := c.GetTransfers()
@ -358,7 +318,7 @@ func TestUploadFiles(t *testing.T) {
}
err = os.Remove(f.Name())
assert.NoError(t, err)
common.Config.UploadMode = oldUploadMode
common.Config.UploadMode = common.UploadModeAtomicWithResume
}
func TestWithInvalidHome(t *testing.T) {
@ -371,9 +331,9 @@ func TestWithInvalidHome(t *testing.T) {
fs, err := u.GetFilesystem("123")
assert.NoError(t, err)
c := Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, u, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, u),
}
_, err = c.Fs.ResolvePath("../upper_path")
_, err = fs.ResolvePath("../upper_path")
assert.Error(t, err, "tested path is not a home subdir")
_, err = c.StatVFS(&sftp.Request{
Method: "StatVFS",
@ -382,25 +342,6 @@ func TestWithInvalidHome(t *testing.T) {
assert.Error(t, err)
}
func TestSFTPCmdTargetPath(t *testing.T) {
u := dataprovider.User{}
if runtime.GOOS == osWindows {
u.HomeDir = "C:\\invalid_home"
} else {
u.HomeDir = "/invalid_home"
}
u.Username = "testuser"
u.Permissions = make(map[string][]string)
u.Permissions["/"] = []string{dataprovider.PermAny}
fs, err := u.GetFilesystem("123")
assert.NoError(t, err)
connection := Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, u, fs),
}
_, err = connection.getSFTPCmdTargetPath("invalid_path")
assert.True(t, os.IsNotExist(err))
}
func TestSFTPGetUsedQuota(t *testing.T) {
u := dataprovider.User{}
u.HomeDir = "home_rel_path"
@ -410,7 +351,7 @@ func TestSFTPGetUsedQuota(t *testing.T) {
u.Permissions = make(map[string][]string)
u.Permissions["/"] = []string{dataprovider.PermAny}
connection := Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, u, nil),
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, u),
}
quotaResult := connection.HasSpace(false, false, "/")
assert.False(t, quotaResult.HasSpace)
@ -506,10 +447,8 @@ func TestSSHCommandErrors(t *testing.T) {
user := dataprovider.User{}
user.Permissions = make(map[string][]string)
user.Permissions["/"] = []string{dataprovider.PermAny}
fs, err := user.GetFilesystem("123")
assert.NoError(t, err)
connection := Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSSH, user),
channel: &mockSSHChannel,
}
cmd := sshCommand{
@ -517,7 +456,7 @@ func TestSSHCommandErrors(t *testing.T) {
connection: &connection,
args: []string{},
}
err = cmd.handle()
err := cmd.handle()
assert.Error(t, err, "ssh command must fail, we are sending a fake error")
cmd = sshCommand{
@ -539,9 +478,8 @@ func TestSSHCommandErrors(t *testing.T) {
cmd.connection.User.HomeDir = filepath.Clean(os.TempDir())
cmd.connection.User.QuotaFiles = 1
cmd.connection.User.UsedQuotaFiles = 2
fs, err = cmd.connection.User.GetFilesystem("123")
fs, err := cmd.connection.User.GetFilesystem("123")
assert.NoError(t, err)
cmd.connection.Fs = fs
err = cmd.handle()
assert.EqualError(t, err, common.ErrQuotaExceeded.Error())
@ -580,10 +518,6 @@ func TestSSHCommandErrors(t *testing.T) {
err = cmd.executeSystemCommand(command)
assert.Error(t, err, "command must fail, pipe was already assigned")
fs, err = user.GetFilesystem("123")
assert.NoError(t, err)
connection.Fs = fs
cmd = sshCommand{
command: "sftpgo-remove",
connection: &connection,
@ -601,22 +535,12 @@ func TestSSHCommandErrors(t *testing.T) {
assert.Error(t, err, "ssh command must fail, we are requesting an invalid path")
cmd.connection.User.HomeDir = filepath.Clean(os.TempDir())
fs, err = cmd.connection.User.GetFilesystem("123")
assert.NoError(t, err)
cmd.connection.Fs = fs
_, _, err = cmd.resolveCopyPaths(".", "../adir")
assert.Error(t, err)
cmd = sshCommand{
command: "sftpgo-copy",
connection: &connection,
args: []string{"src", "dst"},
}
cmd.connection.User.Permissions = make(map[string][]string)
cmd.connection.User.Permissions["/"] = []string{dataprovider.PermDownload}
src, dst, err := cmd.getCopyPaths()
assert.NoError(t, err)
assert.False(t, cmd.hasCopyPermissions(src, dst, nil))
cmd.connection.User.Permissions = make(map[string][]string)
cmd.connection.User.Permissions["/"] = []string{dataprovider.PermAny}
@ -629,7 +553,7 @@ func TestSSHCommandErrors(t *testing.T) {
assert.NoError(t, err)
err = os.Chmod(aDir, 0001)
assert.NoError(t, err)
err = cmd.checkCopyDestination(tmpFile)
err = cmd.checkCopyDestination(fs, tmpFile)
assert.Error(t, err)
err = os.Chmod(aDir, os.ModePerm)
assert.NoError(t, err)
@ -661,10 +585,8 @@ func TestCommandsWithExtensionsFilter(t *testing.T) {
},
}
fs, err := user.GetFilesystem("123")
assert.NoError(t, err)
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSSH, user),
channel: &mockSSHChannel,
}
cmd := sshCommand{
@ -672,7 +594,7 @@ func TestCommandsWithExtensionsFilter(t *testing.T) {
connection: connection,
args: []string{"subdir/test.png"},
}
err = cmd.handleHashCommands()
err := cmd.handleHashCommands()
assert.EqualError(t, err, common.ErrPermissionDenied.Error())
cmd = sshCommand{
@ -715,28 +637,17 @@ func TestSSHCommandsRemoteFs(t *testing.T) {
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
}
server, client := net.Pipe()
defer func() {
err := server.Close()
assert.NoError(t, err)
}()
defer func() {
err := client.Close()
assert.NoError(t, err)
}()
user := dataprovider.User{}
user.FsConfig = dataprovider.Filesystem{
Provider: dataprovider.S3FilesystemProvider,
user.FsConfig = vfs.Filesystem{
Provider: vfs.S3FilesystemProvider,
S3Config: vfs.S3FsConfig{
Bucket: "s3bucket",
Endpoint: "endpoint",
Region: "eu-west-1",
},
}
fs, err := user.GetFilesystem("123")
assert.NoError(t, err)
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user),
channel: &mockSSHChannel,
}
cmd := sshCommand{
@ -755,15 +666,73 @@ func TestSSHCommandsRemoteFs(t *testing.T) {
connection: connection,
args: []string{},
}
err = cmd.handeSFTPGoCopy()
err = cmd.handleSFTPGoCopy()
assert.Error(t, err)
cmd = sshCommand{
command: "sftpgo-remove",
connection: connection,
args: []string{},
}
err = cmd.handeSFTPGoRemove()
err = cmd.handleSFTPGoRemove()
assert.Error(t, err)
// the user has no permissions
assert.False(t, cmd.hasCopyPermissions("", "", nil))
}
func TestSSHCmdGetFsErrors(t *testing.T) {
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
mockSSHChannel := MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
}
user := dataprovider.User{
HomeDir: "relative path",
}
user.Permissions = map[string][]string{}
user.Permissions["/"] = []string{dataprovider.PermAny}
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user),
channel: &mockSSHChannel,
}
cmd := sshCommand{
command: "sftpgo-remove",
connection: connection,
args: []string{"path"},
}
err := cmd.handleSFTPGoRemove()
assert.Error(t, err)
cmd = sshCommand{
command: "sftpgo-copy",
connection: connection,
args: []string{"path1", "path2"},
}
_, _, _, _, _, _, err = cmd.getFsAndCopyPaths() //nolint:dogsled
assert.Error(t, err)
user = dataprovider.User{}
user.HomeDir = filepath.Join(os.TempDir(), "home")
user.VirtualFolders = append(connection.User.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
MappedPath: "relative",
},
VirtualPath: "/vpath",
})
connection.User = user
err = os.MkdirAll(user.GetHomeDir(), os.ModePerm)
assert.NoError(t, err)
cmd = sshCommand{
command: "sftpgo-copy",
connection: connection,
args: []string{"path1", "/vpath/path2"},
}
_, _, _, _, _, _, err = cmd.getFsAndCopyPaths() //nolint:dogsled
assert.Error(t, err)
err = os.Remove(user.GetHomeDir())
assert.NoError(t, err)
}
func TestGitVirtualFolders(t *testing.T) {
@ -773,10 +742,8 @@ func TestGitVirtualFolders(t *testing.T) {
Permissions: permissions,
HomeDir: os.TempDir(),
}
fs, err := user.GetFilesystem("123")
assert.NoError(t, err)
conn := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user),
}
cmd := sshCommand{
command: "git-receive-pack",
@ -789,7 +756,7 @@ func TestGitVirtualFolders(t *testing.T) {
},
VirtualPath: "/vdir",
})
_, err = cmd.getSystemCommand()
_, err := cmd.getSystemCommand()
assert.NoError(t, err)
cmd.args = []string{"/"}
_, err = cmd.getSystemCommand()
@ -821,10 +788,8 @@ func TestRsyncOptions(t *testing.T) {
Permissions: permissions,
HomeDir: os.TempDir(),
}
fs, err := user.GetFilesystem("123")
assert.NoError(t, err)
conn := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user),
}
sshCmd := sshCommand{
command: "rsync",
@ -839,11 +804,9 @@ func TestRsyncOptions(t *testing.T) {
permissions["/"] = []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermCreateDirs,
dataprovider.PermListItems, dataprovider.PermOverwrite, dataprovider.PermDelete, dataprovider.PermRename}
user.Permissions = permissions
fs, err = user.GetFilesystem("123")
assert.NoError(t, err)
conn = &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user),
}
sshCmd = sshCommand{
command: "rsync",
@ -875,14 +838,14 @@ func TestSystemCommandSizeForPath(t *testing.T) {
fs, err := user.GetFilesystem("123")
assert.NoError(t, err)
conn := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user),
}
sshCmd := sshCommand{
command: "rsync",
connection: conn,
args: []string{"--server", "-vlogDtprze.iLsfxC", ".", "/"},
}
_, _, err = sshCmd.getSizeForPath("missing path")
_, _, err = sshCmd.getSizeForPath(fs, "missing path")
assert.NoError(t, err)
testDir := filepath.Join(os.TempDir(), "dir")
err = os.MkdirAll(testDir, os.ModePerm)
@ -892,18 +855,18 @@ func TestSystemCommandSizeForPath(t *testing.T) {
assert.NoError(t, err)
err = os.Symlink(testFile, testFile+".link")
assert.NoError(t, err)
numFiles, size, err := sshCmd.getSizeForPath(testFile + ".link")
numFiles, size, err := sshCmd.getSizeForPath(fs, testFile+".link")
assert.NoError(t, err)
assert.Equal(t, 0, numFiles)
assert.Equal(t, int64(0), size)
numFiles, size, err = sshCmd.getSizeForPath(testFile)
numFiles, size, err = sshCmd.getSizeForPath(fs, testFile)
assert.NoError(t, err)
assert.Equal(t, 1, numFiles)
assert.Equal(t, int64(12), size)
if runtime.GOOS != osWindows {
err = os.Chmod(testDir, 0001)
assert.NoError(t, err)
_, _, err = sshCmd.getSizeForPath(testFile)
_, _, err = sshCmd.getSizeForPath(fs, testFile)
assert.Error(t, err)
err = os.Chmod(testDir, os.ModePerm)
assert.NoError(t, err)
@ -923,15 +886,6 @@ func TestSystemCommandErrors(t *testing.T) {
ReadError: nil,
WriteError: writeErr,
}
server, client := net.Pipe()
defer func() {
err := server.Close()
assert.NoError(t, err)
}()
defer func() {
err := client.Close()
assert.NoError(t, err)
}()
permissions := make(map[string][]string)
permissions["/"] = []string{dataprovider.PermAny}
homeDir := filepath.Join(os.TempDir(), "adir")
@ -946,7 +900,7 @@ func TestSystemCommandErrors(t *testing.T) {
fs, err := user.GetFilesystem("123")
assert.NoError(t, err)
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user),
channel: &mockSSHChannel,
}
var sshCmd sshCommand
@ -1015,6 +969,48 @@ func TestSystemCommandErrors(t *testing.T) {
assert.NoError(t, err)
}
func TestCommandGetFsError(t *testing.T) {
user := dataprovider.User{
FsConfig: vfs.Filesystem{
Provider: vfs.CryptedFilesystemProvider,
},
}
conn := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user),
}
sshCmd := sshCommand{
command: "rsync",
connection: conn,
args: []string{"--server", "-vlogDtprze.iLsfxC", ".", "/"},
}
_, err := sshCmd.getSystemCommand()
assert.Error(t, err)
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
mockSSHChannel := MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
}
conn = &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, user),
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
sshCommand: sshCommand{
command: "scp",
connection: conn,
args: []string{"-t", "/tmp"},
},
}
err = scpCommand.handleRecursiveUpload()
assert.Error(t, err)
err = scpCommand.handleDownload("")
assert.Error(t, err)
}
func TestGetConnectionInfo(t *testing.T) {
c := common.ConnectionStatus{
Username: "test_user",
@ -1088,10 +1084,9 @@ func TestSCPUploadError(t *testing.T) {
Permissions: make(map[string][]string),
}
user.Permissions["/"] = []string{dataprovider.PermAny}
fs := vfs.NewOsFs("", user.HomeDir, nil)
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user),
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
@ -1129,10 +1124,11 @@ func TestSCPInvalidEndDir(t *testing.T) {
Buffer: bytes.NewBuffer([]byte("E\n")),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
}
fs := vfs.NewOsFs("", os.TempDir(), nil)
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, dataprovider.User{}, fs),
channel: &mockSSHChannel,
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, dataprovider.User{
HomeDir: os.TempDir(),
}),
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
sshCommand: sshCommand{
@ -1153,10 +1149,12 @@ func TestSCPParseUploadMessage(t *testing.T) {
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
}
fs := vfs.NewOsFs("", os.TempDir(), nil)
fs := vfs.NewOsFs("", os.TempDir(), "")
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, dataprovider.User{}, fs),
channel: &mockSSHChannel,
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, dataprovider.User{
HomeDir: os.TempDir(),
}),
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
sshCommand: sshCommand{
@ -1165,16 +1163,16 @@ func TestSCPParseUploadMessage(t *testing.T) {
args: []string{"-t", "/tmp"},
},
}
_, _, err := scpCommand.parseUploadMessage("invalid")
_, _, err := scpCommand.parseUploadMessage(fs, "invalid")
assert.Error(t, err, "parsing invalid upload message must fail")
_, _, err = scpCommand.parseUploadMessage("D0755 0")
_, _, err = scpCommand.parseUploadMessage(fs, "D0755 0")
assert.Error(t, err, "parsing incomplete upload message must fail")
_, _, err = scpCommand.parseUploadMessage("D0755 invalidsize testdir")
_, _, err = scpCommand.parseUploadMessage(fs, "D0755 invalidsize testdir")
assert.Error(t, err, "parsing upload message with invalid size must fail")
_, _, err = scpCommand.parseUploadMessage("D0755 0 ")
_, _, err = scpCommand.parseUploadMessage(fs, "D0755 0 ")
assert.Error(t, err, "parsing upload message with invalid name must fail")
}
@ -1190,7 +1188,7 @@ func TestSCPProtocolMessages(t *testing.T) {
WriteError: writeErr,
}
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, nil),
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}),
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
@ -1251,7 +1249,7 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
WriteError: writeErr,
}
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, nil),
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}),
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
@ -1325,7 +1323,7 @@ func TestSCPCommandHandleErrors(t *testing.T) {
assert.NoError(t, err)
}()
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, nil),
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}),
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
@ -1344,7 +1342,6 @@ func TestSCPCommandHandleErrors(t *testing.T) {
func TestSCPErrorsMockFs(t *testing.T) {
errFake := errors.New("fake error")
fs := newMockOsFs(errFake, errFake, false, "1234", os.TempDir())
u := dataprovider.User{}
u.Username = "test"
u.Permissions = make(map[string][]string)
@ -1367,7 +1364,7 @@ func TestSCPErrorsMockFs(t *testing.T) {
}()
connection := &Connection{
channel: &mockSSHChannel,
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, u, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, u),
}
scpCommand := scpCommand{
sshCommand: sshCommand{
@ -1376,22 +1373,12 @@ func TestSCPErrorsMockFs(t *testing.T) {
args: []string{"-r", "-t", "/tmp"},
},
}
err := scpCommand.handleUpload("test", 0)
assert.EqualError(t, err, errFake.Error())
testfile := filepath.Join(u.HomeDir, "testfile")
err = os.WriteFile(testfile, []byte("test"), os.ModePerm)
err := os.WriteFile(testfile, []byte("test"), os.ModePerm)
assert.NoError(t, err)
stat, err := os.Stat(u.HomeDir)
assert.NoError(t, err)
err = scpCommand.handleRecursiveDownload(u.HomeDir, stat)
assert.EqualError(t, err, errFake.Error())
scpCommand.sshCommand.connection.Fs = newMockOsFs(errFake, nil, true, "123", os.TempDir())
err = scpCommand.handleUpload(filepath.Base(testfile), 0)
assert.EqualError(t, err, errFake.Error())
err = scpCommand.handleUploadFile(testfile, testfile, 0, false, 4, "/testfile")
fs := newMockOsFs(errFake, nil, true, "123", os.TempDir())
err = scpCommand.handleUploadFile(fs, testfile, testfile, 0, false, 4, "/testfile")
assert.NoError(t, err)
err = os.Remove(testfile)
assert.NoError(t, err)
@ -1417,10 +1404,12 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
err := client.Close()
assert.NoError(t, err)
}()
fs := vfs.NewOsFs("123", os.TempDir(), nil)
fs := vfs.NewOsFs("123", os.TempDir(), "")
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, fs),
channel: &mockSSHChannel,
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{
HomeDir: os.TempDir(),
}),
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
sshCommand: sshCommand{
@ -1434,7 +1423,7 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
assert.NoError(t, err)
stat, err := os.Stat(path)
assert.NoError(t, err)
err = scpCommand.handleRecursiveDownload("invalid_dir", stat)
err = scpCommand.handleRecursiveDownload(fs, "invalid_dir", "invalid_dir", stat)
assert.EqualError(t, err, writeErr.Error())
mockSSHChannel = MockChannel{
@ -1444,7 +1433,7 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
WriteError: nil,
}
scpCommand.connection.channel = &mockSSHChannel
err = scpCommand.handleRecursiveDownload("invalid_dir", stat)
err = scpCommand.handleRecursiveDownload(fs, "invalid_dir", "invalid_dir", stat)
assert.Error(t, err, "recursive upload download must fail for a non existing dir")
err = os.Remove(path)
@ -1463,7 +1452,7 @@ func TestSCPRecursiveUploadErrors(t *testing.T) {
WriteError: writeErr,
}
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, nil),
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}),
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
@ -1504,7 +1493,7 @@ func TestSCPCreateDirs(t *testing.T) {
fs, err := u.GetFilesystem("123")
assert.NoError(t, err)
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, u, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, u),
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
@ -1514,7 +1503,7 @@ func TestSCPCreateDirs(t *testing.T) {
args: []string{"-r", "-t", "/tmp"},
},
}
err = scpCommand.handleCreateDir("invalid_dir")
err = scpCommand.handleCreateDir(fs, "invalid_dir")
assert.Error(t, err, "create invalid dir must fail")
}
@ -1536,10 +1525,10 @@ func TestSCPDownloadFileData(t *testing.T) {
ReadError: nil,
WriteError: writeErr,
}
fs := vfs.NewOsFs("", os.TempDir(), "")
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{},
vfs.NewOsFs("", os.TempDir(), nil)),
channel: &mockSSHChannelReadErr,
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{HomeDir: os.TempDir()}),
channel: &mockSSHChannelReadErr,
}
scpCommand := scpCommand{
sshCommand: sshCommand{
@ -1552,19 +1541,19 @@ func TestSCPDownloadFileData(t *testing.T) {
assert.NoError(t, err)
stat, err := os.Stat(testfile)
assert.NoError(t, err)
err = scpCommand.sendDownloadFileData(testfile, stat, nil)
err = scpCommand.sendDownloadFileData(fs, testfile, stat, nil)
assert.EqualError(t, err, readErr.Error())
scpCommand.connection.channel = &mockSSHChannelWriteErr
err = scpCommand.sendDownloadFileData(testfile, stat, nil)
err = scpCommand.sendDownloadFileData(fs, testfile, stat, nil)
assert.EqualError(t, err, writeErr.Error())
scpCommand.args = []string{"-r", "-p", "-f", "/tmp"}
err = scpCommand.sendDownloadFileData(testfile, stat, nil)
err = scpCommand.sendDownloadFileData(fs, testfile, stat, nil)
assert.EqualError(t, err, writeErr.Error())
scpCommand.connection.channel = &mockSSHChannelReadErr
err = scpCommand.sendDownloadFileData(testfile, stat, nil)
err = scpCommand.sendDownloadFileData(fs, testfile, stat, nil)
assert.EqualError(t, err, readErr.Error())
err = os.Remove(testfile)
@ -1586,9 +1575,9 @@ func TestSCPUploadFiledata(t *testing.T) {
user := dataprovider.User{
Username: "testuser",
}
fs := vfs.NewOsFs("", os.TempDir(), nil)
fs := vfs.NewOsFs("", os.TempDir(), "")
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, user, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, user),
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
@ -1670,15 +1659,14 @@ func TestSCPUploadFiledata(t *testing.T) {
}
func TestUploadError(t *testing.T) {
oldUploadMode := common.Config.UploadMode
common.Config.UploadMode = common.UploadModeAtomic
user := dataprovider.User{
Username: "testuser",
}
fs := vfs.NewOsFs("", os.TempDir(), nil)
fs := vfs.NewOsFs("", os.TempDir(), "")
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, user, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, user),
}
testfile := "testfile"
@ -1703,19 +1691,26 @@ func TestUploadError(t *testing.T) {
assert.NoFileExists(t, testfile)
assert.NoFileExists(t, fileTempName)
common.Config.UploadMode = oldUploadMode
common.Config.UploadMode = common.UploadModeAtomicWithResume
}
func TestTransferFailingReader(t *testing.T) {
user := dataprovider.User{
Username: "testuser",
HomeDir: os.TempDir(),
FsConfig: vfs.Filesystem{
Provider: vfs.CryptedFilesystemProvider,
CryptConfig: vfs.CryptFsConfig{
Passphrase: kms.NewPlainSecret("crypt secret"),
},
},
}
user.Permissions = make(map[string][]string)
user.Permissions["/"] = []string{dataprovider.PermAny}
fs := newMockOsFs(nil, nil, true, "", os.TempDir())
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user),
}
request := sftp.NewRequest("Open", "afile.txt")
@ -1874,7 +1869,7 @@ func TestRecursiveCopyErrors(t *testing.T) {
fs, err := user.GetFilesystem("123")
assert.NoError(t, err)
conn := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user),
}
sshCmd := sshCommand{
command: "sftpgo-copy",
@ -1882,7 +1877,7 @@ func TestRecursiveCopyErrors(t *testing.T) {
args: []string{"adir", "another"},
}
// try to copy a missing directory
err = sshCmd.checkRecursiveCopyPermissions("adir", "another", "/another")
err = sshCmd.checkRecursiveCopyPermissions(fs, fs, "adir", "another", "/another")
assert.Error(t, err)
}
@ -1893,10 +1888,10 @@ func TestSFTPSubSystem(t *testing.T) {
Permissions: permissions,
HomeDir: os.TempDir(),
}
user.FsConfig.Provider = dataprovider.AzureBlobFilesystemProvider
user.FsConfig.Provider = vfs.AzureBlobFilesystemProvider
err := ServeSubSystemConnection(user, "connID", nil, nil)
assert.Error(t, err)
user.FsConfig.Provider = dataprovider.LocalFilesystemProvider
user.FsConfig.Provider = vfs.LocalFilesystemProvider
buf := make([]byte, 0, 4096)
stdErrBuf := make([]byte, 0, 4096)
@ -1923,7 +1918,7 @@ func TestRecoverer(t *testing.T) {
c.AcceptInboundConnection(nil, nil)
connID := "connectionID"
connection := &Connection{
BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, dataprovider.User{}, nil),
BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, dataprovider.User{}),
}
c.handleSftpConnection(nil, connection)
sshCmd := sshCommand{

View file

@ -76,36 +76,51 @@ func (c *scpCommand) handleRecursiveUpload() error {
numDirs := 0
destPath := c.getDestPath()
for {
fs, err := c.connection.User.GetFilesystemForPath(destPath, c.connection.ID)
if err != nil {
c.connection.Log(logger.LevelWarn, "error uploading file %#v: %+v", destPath, err)
c.sendErrorMessage(nil, fmt.Errorf("unable to get fs for path %#v", destPath))
return err
}
command, err := c.getNextUploadProtocolMessage()
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
c.sendErrorMessage(fs, err)
return err
}
if strings.HasPrefix(command, "E") {
numDirs--
c.connection.Log(logger.LevelDebug, "received end dir command, num dirs: %v", numDirs)
if numDirs < 0 {
return errors.New("unacceptable end dir command")
err = errors.New("unacceptable end dir command")
c.sendErrorMessage(nil, err)
return err
}
// the destination dir is now the parent directory
destPath = path.Join(destPath, "..")
} else {
sizeToRead, name, err := c.parseUploadMessage(command)
sizeToRead, name, err := c.parseUploadMessage(fs, command)
if err != nil {
return err
}
if strings.HasPrefix(command, "D") {
numDirs++
destPath = path.Join(destPath, name)
err = c.handleCreateDir(destPath)
fs, err = c.connection.User.GetFilesystemForPath(destPath, c.connection.ID)
if err != nil {
c.connection.Log(logger.LevelWarn, "error uploading file %#v: %+v", destPath, err)
c.sendErrorMessage(nil, fmt.Errorf("unable to get fs for path %#v", destPath))
return err
}
err = c.handleCreateDir(fs, destPath)
if err != nil {
return err
}
c.connection.Log(logger.LevelDebug, "received start dir command, num dirs: %v destPath: %#v", numDirs, destPath)
} else if strings.HasPrefix(command, "C") {
err = c.handleUpload(c.getFileUploadDestPath(destPath, name), sizeToRead)
err = c.handleUpload(c.getFileUploadDestPath(fs, destPath, name), sizeToRead)
if err != nil {
return err
}
@ -118,21 +133,27 @@ func (c *scpCommand) handleRecursiveUpload() error {
}
}
func (c *scpCommand) handleCreateDir(dirPath string) error {
func (c *scpCommand) handleCreateDir(fs vfs.Fs, dirPath string) error {
c.connection.UpdateLastActivity()
p, err := c.connection.Fs.ResolvePath(dirPath)
p, err := fs.ResolvePath(dirPath)
if err != nil {
c.connection.Log(logger.LevelWarn, "error creating dir: %#v, invalid file path, err: %v", dirPath, err)
c.sendErrorMessage(err)
c.sendErrorMessage(fs, err)
return err
}
if !c.connection.User.HasPerm(dataprovider.PermCreateDirs, path.Dir(dirPath)) {
c.connection.Log(logger.LevelWarn, "error creating dir: %#v, permission denied", dirPath)
c.sendErrorMessage(common.ErrPermissionDenied)
c.sendErrorMessage(fs, common.ErrPermissionDenied)
return common.ErrPermissionDenied
}
err = c.createDir(p)
info, err := c.connection.DoStat(dirPath, 1)
if err == nil && info.IsDir() {
return nil
}
err = c.createDir(fs, p)
if err != nil {
return err
}
@ -156,14 +177,14 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *transfer) err
for {
n, err := c.connection.channel.Read(buf)
if err != nil {
c.sendErrorMessage(err)
c.sendErrorMessage(transfer.Fs, err)
transfer.TransferError(err)
transfer.Close()
return err
}
_, err = transfer.WriteAt(buf[:n], sizeToRead-remaining)
if err != nil {
c.sendErrorMessage(err)
c.sendErrorMessage(transfer.Fs, err)
transfer.Close()
return err
}
@ -184,33 +205,33 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *transfer) err
}
err = transfer.Close()
if err != nil {
c.sendErrorMessage(err)
c.sendErrorMessage(transfer.Fs, err)
return err
}
return nil
}
func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error {
func (c *scpCommand) handleUploadFile(fs vfs.Fs, resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error {
quotaResult := c.connection.HasSpace(isNewFile, false, requestPath)
if !quotaResult.HasSpace {
err := fmt.Errorf("denying file write due to quota limits")
c.connection.Log(logger.LevelWarn, "error uploading file: %#v, err: %v", filePath, err)
c.sendErrorMessage(err)
c.sendErrorMessage(fs, err)
return err
}
maxWriteSize, _ := c.connection.GetMaxWriteSize(quotaResult, false, fileSize)
maxWriteSize, _ := c.connection.GetMaxWriteSize(quotaResult, false, fileSize, fs.IsUploadResumeSupported())
file, w, cancelFn, err := c.connection.Fs.Create(filePath, 0)
file, w, cancelFn, err := fs.Create(filePath, 0)
if err != nil {
c.connection.Log(logger.LevelError, "error creating file %#v: %v", resolvedPath, err)
c.sendErrorMessage(err)
c.sendErrorMessage(fs, err)
return err
}
initialSize := int64(0)
if !isNewFile {
if vfs.IsLocalOrSFTPFs(c.connection.Fs) {
if vfs.IsLocalOrSFTPFs(fs) {
vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
@ -228,10 +249,10 @@ func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead
}
}
vfs.SetPathPermissions(c.connection.Fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID())
vfs.SetPathPermissions(fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID())
baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, resolvedPath, requestPath,
common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, c.connection.Fs)
common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, fs)
t := newTransfer(baseTransfer, w, nil, nil)
return c.getUploadFileData(sizeToRead, t)
@ -240,67 +261,66 @@ func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead
func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error {
c.connection.UpdateLastActivity()
var err error
fs, p, err := c.connection.GetFsAndResolvedPath(uploadFilePath)
if err != nil {
c.connection.Log(logger.LevelWarn, "error uploading file: %#v, err: %v", uploadFilePath, err)
c.sendErrorMessage(nil, err)
return err
}
if !c.connection.User.IsFileAllowed(uploadFilePath) {
c.connection.Log(logger.LevelWarn, "writing file %#v is not allowed", uploadFilePath)
c.sendErrorMessage(common.ErrPermissionDenied)
c.sendErrorMessage(fs, common.ErrPermissionDenied)
return common.ErrPermissionDenied
}
p, err := c.connection.Fs.ResolvePath(uploadFilePath)
if err != nil {
c.connection.Log(logger.LevelWarn, "error uploading file: %#v, err: %v", uploadFilePath, err)
c.sendErrorMessage(err)
return err
}
filePath := p
if common.Config.IsAtomicUploadEnabled() && c.connection.Fs.IsAtomicUploadSupported() {
filePath = c.connection.Fs.GetAtomicUploadPath(p)
if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
filePath = fs.GetAtomicUploadPath(p)
}
stat, statErr := c.connection.Fs.Lstat(p)
if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || c.connection.Fs.IsNotExist(statErr) {
stat, statErr := fs.Lstat(p)
if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) {
if !c.connection.User.HasPerm(dataprovider.PermUpload, path.Dir(uploadFilePath)) {
c.connection.Log(logger.LevelWarn, "cannot upload file: %#v, permission denied", uploadFilePath)
c.sendErrorMessage(common.ErrPermissionDenied)
c.sendErrorMessage(fs, common.ErrPermissionDenied)
return common.ErrPermissionDenied
}
return c.handleUploadFile(p, filePath, sizeToRead, true, 0, uploadFilePath)
return c.handleUploadFile(fs, p, filePath, sizeToRead, true, 0, uploadFilePath)
}
if statErr != nil {
c.connection.Log(logger.LevelError, "error performing file stat %#v: %v", p, statErr)
c.sendErrorMessage(statErr)
c.sendErrorMessage(fs, statErr)
return statErr
}
if stat.IsDir() {
c.connection.Log(logger.LevelWarn, "attempted to open a directory for writing to: %#v", p)
err = fmt.Errorf("Attempted to open a directory for writing: %#v", p)
c.sendErrorMessage(err)
err = fmt.Errorf("attempted to open a directory for writing: %#v", p)
c.sendErrorMessage(fs, err)
return err
}
if !c.connection.User.HasPerm(dataprovider.PermOverwrite, uploadFilePath) {
c.connection.Log(logger.LevelWarn, "cannot overwrite file: %#v, permission denied", uploadFilePath)
c.sendErrorMessage(common.ErrPermissionDenied)
c.sendErrorMessage(fs, common.ErrPermissionDenied)
return common.ErrPermissionDenied
}
if common.Config.IsAtomicUploadEnabled() && c.connection.Fs.IsAtomicUploadSupported() {
err = c.connection.Fs.Rename(p, filePath)
if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
err = fs.Rename(p, filePath)
if err != nil {
c.connection.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %#v, dest: %#v, err: %v",
p, filePath, err)
c.sendErrorMessage(err)
c.sendErrorMessage(fs, err)
return err
}
}
return c.handleUploadFile(p, filePath, sizeToRead, false, stat.Size(), uploadFilePath)
return c.handleUploadFile(fs, p, filePath, sizeToRead, false, stat.Size(), uploadFilePath)
}
func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileInfo) error {
func (c *scpCommand) sendDownloadProtocolMessages(virtualDirPath string, stat os.FileInfo) error {
var err error
if c.sendFileTime() {
modTime := stat.ModTime().UnixNano() / 1000000000
@ -315,12 +335,9 @@ func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileIn
}
}
dirName := filepath.Base(dirPath)
for _, v := range c.connection.User.VirtualFolders {
if v.MappedPath == dirPath {
dirName = path.Base(v.VirtualPath)
break
}
dirName := path.Base(virtualDirPath)
if dirName == "/" || dirName == "." {
dirName = c.connection.User.Username
}
fileMode := fmt.Sprintf("D%v 0 %v\n", getFileModeAsString(stat.Mode(), stat.IsDir()), dirName)
@ -334,23 +351,23 @@ func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileIn
// We send first all the files in the root directory and then the directories.
// For each directory we recursively call this method again
func (c *scpCommand) handleRecursiveDownload(dirPath string, stat os.FileInfo) error {
func (c *scpCommand) handleRecursiveDownload(fs vfs.Fs, dirPath, virtualPath string, stat os.FileInfo) error {
var err error
if c.isRecursive() {
c.connection.Log(logger.LevelDebug, "recursive download, dir path: %#v", dirPath)
err = c.sendDownloadProtocolMessages(dirPath, stat)
c.connection.Log(logger.LevelDebug, "recursive download, dir path %#v virtual path %#v", dirPath, virtualPath)
err = c.sendDownloadProtocolMessages(virtualPath, stat)
if err != nil {
return err
}
files, err := c.connection.Fs.ReadDir(dirPath)
files = c.connection.User.AddVirtualDirs(files, c.connection.Fs.GetRelativePath(dirPath))
files, err := fs.ReadDir(dirPath)
if err != nil {
c.sendErrorMessage(err)
c.sendErrorMessage(fs, err)
return err
}
files = c.connection.User.AddVirtualDirs(files, fs.GetRelativePath(dirPath))
var dirs []string
for _, file := range files {
filePath := c.connection.Fs.GetRelativePath(c.connection.Fs.Join(dirPath, file.Name()))
filePath := fs.GetRelativePath(fs.Join(dirPath, file.Name()))
if file.Mode().IsRegular() || file.Mode()&os.ModeSymlink != 0 {
err = c.handleDownload(filePath)
if err != nil {
@ -361,7 +378,7 @@ func (c *scpCommand) handleRecursiveDownload(dirPath string, stat os.FileInfo) e
}
}
if err != nil {
c.sendErrorMessage(err)
c.sendErrorMessage(fs, err)
return err
}
for _, dir := range dirs {
@ -371,25 +388,21 @@ func (c *scpCommand) handleRecursiveDownload(dirPath string, stat os.FileInfo) e
}
}
if err != nil {
c.sendErrorMessage(err)
c.sendErrorMessage(fs, err)
return err
}
err = c.sendProtocolMessage("E\n")
if err != nil {
return err
}
err = c.readConfirmationMessage()
if err != nil {
return err
}
return err
return c.readConfirmationMessage()
}
err = fmt.Errorf("Unable to send directory for non recursive copy")
c.sendErrorMessage(err)
err = fmt.Errorf("unable to send directory for non recursive copy")
c.sendErrorMessage(nil, err)
return err
}
func (c *scpCommand) sendDownloadFileData(filePath string, stat os.FileInfo, transfer *transfer) error {
func (c *scpCommand) sendDownloadFileData(fs vfs.Fs, filePath string, stat os.FileInfo, transfer *transfer) error {
var err error
if c.sendFileTime() {
modTime := stat.ModTime().UnixNano() / 1000000000
@ -403,8 +416,8 @@ func (c *scpCommand) sendDownloadFileData(filePath string, stat os.FileInfo, tra
return err
}
}
if vfs.IsCryptOsFs(c.connection.Fs) {
stat = c.connection.Fs.(*vfs.CryptFs).ConvertFileInfo(stat)
if vfs.IsCryptOsFs(fs) {
stat = fs.(*vfs.CryptFs).ConvertFileInfo(stat)
}
fileSize := stat.Size()
@ -435,7 +448,7 @@ func (c *scpCommand) sendDownloadFileData(filePath string, stat os.FileInfo, tra
}
}
if err != io.EOF {
c.sendErrorMessage(err)
c.sendErrorMessage(fs, err)
return err
}
err = c.sendConfirmationMessage()
@ -450,55 +463,62 @@ func (c *scpCommand) handleDownload(filePath string) error {
c.connection.UpdateLastActivity()
var err error
p, err := c.connection.Fs.ResolvePath(filePath)
fs, err := c.connection.User.GetFilesystemForPath(filePath, c.connection.ID)
if err != nil {
err := fmt.Errorf("Invalid file path")
c.connection.Log(logger.LevelWarn, "error downloading file %#v: %+v", filePath, err)
c.sendErrorMessage(nil, fmt.Errorf("unable to get fs for path %#v", filePath))
return err
}
p, err := fs.ResolvePath(filePath)
if err != nil {
err := fmt.Errorf("invalid file path %#v", filePath)
c.connection.Log(logger.LevelWarn, "error downloading file: %#v, invalid file path", filePath)
c.sendErrorMessage(err)
c.sendErrorMessage(fs, err)
return err
}
var stat os.FileInfo
if stat, err = c.connection.Fs.Stat(p); err != nil {
if stat, err = fs.Stat(p); err != nil {
c.connection.Log(logger.LevelWarn, "error downloading file: %#v->%#v, err: %v", filePath, p, err)
c.sendErrorMessage(err)
c.sendErrorMessage(fs, err)
return err
}
if stat.IsDir() {
if !c.connection.User.HasPerm(dataprovider.PermDownload, filePath) {
c.connection.Log(logger.LevelWarn, "error downloading dir: %#v, permission denied", filePath)
c.sendErrorMessage(common.ErrPermissionDenied)
c.sendErrorMessage(fs, common.ErrPermissionDenied)
return common.ErrPermissionDenied
}
err = c.handleRecursiveDownload(p, stat)
err = c.handleRecursiveDownload(fs, p, filePath, stat)
return err
}
if !c.connection.User.HasPerm(dataprovider.PermDownload, path.Dir(filePath)) {
c.connection.Log(logger.LevelWarn, "error downloading dir: %#v, permission denied", filePath)
c.sendErrorMessage(common.ErrPermissionDenied)
c.sendErrorMessage(fs, common.ErrPermissionDenied)
return common.ErrPermissionDenied
}
if !c.connection.User.IsFileAllowed(filePath) {
c.connection.Log(logger.LevelWarn, "reading file %#v is not allowed", filePath)
c.sendErrorMessage(common.ErrPermissionDenied)
c.sendErrorMessage(fs, common.ErrPermissionDenied)
return common.ErrPermissionDenied
}
file, r, cancelFn, err := c.connection.Fs.Open(p, 0)
file, r, cancelFn, err := fs.Open(p, 0)
if err != nil {
c.connection.Log(logger.LevelError, "could not open file %#v for reading: %v", p, err)
c.sendErrorMessage(err)
c.sendErrorMessage(fs, err)
return err
}
baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, p, filePath,
common.TransferDownload, 0, 0, 0, false, c.connection.Fs)
common.TransferDownload, 0, 0, 0, false, fs)
t := newTransfer(baseTransfer, nil, r, nil)
err = c.sendDownloadFileData(p, stat, t)
err = c.sendDownloadFileData(fs, p, stat, t)
// we need to call Close anyway and return close error if any and
// if we have no previous error
if err == nil {
@ -578,9 +598,13 @@ func (c *scpCommand) readProtocolMessage() (string, error) {
// send an error message and close the channel
//nolint:errcheck // we don't check write errors here, we have to close the channel anyway
func (c *scpCommand) sendErrorMessage(err error) {
func (c *scpCommand) sendErrorMessage(fs vfs.Fs, err error) {
c.connection.channel.Write(errMsg)
c.connection.channel.Write([]byte(c.connection.GetFsError(err).Error()))
if fs != nil {
c.connection.channel.Write([]byte(c.connection.GetFsError(fs, err).Error()))
} else {
c.connection.channel.Write([]byte(err.Error()))
}
c.connection.channel.Write(newLine)
c.connection.channel.Close()
}
@ -625,22 +649,14 @@ func (c *scpCommand) getNextUploadProtocolMessage() (string, error) {
return command, err
}
func (c *scpCommand) createDir(dirPath string) error {
var err error
var isDir bool
isDir, err = vfs.IsDirectory(c.connection.Fs, dirPath)
if err == nil && isDir {
// if this is a virtual dir the resolved path will exist, we don't need a specific check
// TODO: remember to check if it's okay when we'll add virtual folders support to cloud backends
c.connection.Log(logger.LevelDebug, "directory %#v already exists", dirPath)
return nil
}
if err = c.connection.Fs.Mkdir(dirPath); err != nil {
func (c *scpCommand) createDir(fs vfs.Fs, dirPath string) error {
err := fs.Mkdir(dirPath)
if err != nil {
c.connection.Log(logger.LevelError, "error creating dir %#v: %v", dirPath, err)
c.sendErrorMessage(err)
c.sendErrorMessage(fs, err)
return err
}
vfs.SetPathPermissions(c.connection.Fs, dirPath, c.connection.User.GetUID(), c.connection.User.GetGID())
vfs.SetPathPermissions(fs, dirPath, c.connection.User.GetUID(), c.connection.User.GetGID())
return err
}
@ -649,7 +665,7 @@ func (c *scpCommand) createDir(dirPath string) error {
// or:
// C0644 6 testfile
// and returns file size and file/directory name
func (c *scpCommand) parseUploadMessage(command string) (int64, string, error) {
func (c *scpCommand) parseUploadMessage(fs vfs.Fs, command string) (int64, string, error) {
var size int64
var name string
var err error
@ -657,7 +673,7 @@ func (c *scpCommand) parseUploadMessage(command string) (int64, string, error) {
err = fmt.Errorf("unknown or invalid upload message: %v args: %v user: %v",
command, c.args, c.connection.User.Username)
c.connection.Log(logger.LevelWarn, "error: %v", err)
c.sendErrorMessage(err)
c.sendErrorMessage(fs, err)
return size, name, err
}
parts := strings.SplitN(command, " ", 3)
@ -665,26 +681,26 @@ func (c *scpCommand) parseUploadMessage(command string) (int64, string, error) {
size, err = strconv.ParseInt(parts[1], 10, 64)
if err != nil {
c.connection.Log(logger.LevelWarn, "error getting size from upload message: %v", err)
c.sendErrorMessage(err)
c.sendErrorMessage(fs, err)
return size, name, err
}
name = parts[2]
if name == "" {
err = fmt.Errorf("error getting name from upload message, cannot be empty")
c.connection.Log(logger.LevelWarn, "error: %v", err)
c.sendErrorMessage(err)
c.sendErrorMessage(fs, err)
return size, name, err
}
} else {
err = fmt.Errorf("unable to split upload message: %#v", command)
c.connection.Log(logger.LevelWarn, "error: %v", err)
c.sendErrorMessage(err)
c.sendErrorMessage(fs, err)
return size, name, err
}
return size, name, err
}
func (c *scpCommand) getFileUploadDestPath(scpDestPath, fileName string) string {
func (c *scpCommand) getFileUploadDestPath(fs vfs.Fs, scpDestPath, fileName string) string {
if !c.isRecursive() {
// if the upload is not recursive and the destination path does not end with "/"
// then scpDestPath is the wanted filename, for example:
@ -695,8 +711,8 @@ func (c *scpCommand) getFileUploadDestPath(scpDestPath, fileName string) string
// but if scpDestPath is an existing directory then we put the uploaded file
// inside that directory this is as scp command works, for example:
// scp fileName.txt user@127.0.0.1:/existing_dir
if p, err := c.connection.Fs.ResolvePath(scpDestPath); err == nil {
if stat, err := c.connection.Fs.Stat(p); err == nil {
if p, err := fs.ResolvePath(scpDestPath); err == nil {
if stat, err := fs.Stat(p); err == nil {
if stat.IsDir() {
return path.Join(scpDestPath, fileName)
}

View file

@ -4,7 +4,6 @@ import (
"bytes"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
@ -400,10 +399,14 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
loginType := sconn.Permissions.Extensions["sftpgo_login_method"]
connectionID := hex.EncodeToString(sconn.SessionID())
if err = checkRootPath(&user, connectionID); err != nil {
if err = user.CheckFsRoot(connectionID); err != nil {
errClose := user.CloseFs()
logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose)
return
}
defer user.CloseFs() //nolint:errcheck
logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
"User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
user.ID, loginType, user.Username, user.HomeDir, ipAddr)
@ -445,34 +448,24 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
switch req.Type {
case "subsystem":
if string(req.Payload[4:]) == "sftp" {
fs, err := user.GetFilesystem(connID)
if err == nil {
ok = true
connection := Connection{
BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, user, fs),
ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: conn.RemoteAddr(),
channel: channel,
}
go c.handleSftpConnection(channel, &connection)
} else {
logger.Debug(logSender, connID, "unable to create filesystem: %v", err)
}
}
case "exec":
// protocol will be set later inside processSSHCommand it could be SSH or SCP
fs, err := user.GetFilesystem(connID)
if err == nil {
ok = true
connection := Connection{
BaseConnection: common.NewBaseConnection(connID, "sshd_exec", user, fs),
BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, user),
ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: conn.RemoteAddr(),
channel: channel,
}
ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands)
} else {
logger.Debug(sshCommandLogSender, connID, "unable to create filesystem: %v", err)
go c.handleSftpConnection(channel, &connection)
}
case "exec":
// protocol will be set later inside processSSHCommand it could be SSH or SCP
connection := Connection{
BaseConnection: common.NewBaseConnection(connID, "sshd_exec", user),
ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: conn.RemoteAddr(),
channel: channel,
}
ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands)
}
if req.WantReply {
req.Reply(ok, nil) //nolint:errcheck
@ -541,21 +534,6 @@ func checkAuthError(ip string, err error) {
}
}
func checkRootPath(user *dataprovider.User, connectionID string) error {
if user.FsConfig.Provider != dataprovider.SFTPFilesystemProvider {
// for sftp fs check root path does nothing so don't open a useless SFTP connection
fs, err := user.GetFilesystem(connectionID)
if err != nil {
logger.Warn(logSender, "", "could not create filesystem for user %#v err: %v", user.Username, err)
return err
}
fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID())
fs.Close()
}
return nil
}
func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.ConnMetadata) (*ssh.Permissions, error) {
connectionID := ""
if conn != nil {
@ -568,7 +546,7 @@ func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.
}
if utils.IsStringInSlice(common.ProtocolSSH, user.Filters.DeniedProtocols) {
logger.Debug(logSender, connectionID, "cannot login user %#v, protocol SSH is not allowed", user.Username)
return nil, fmt.Errorf("Protocol SSH is not allowed for user %#v", user.Username)
return nil, fmt.Errorf("protocol SSH is not allowed for user %#v", user.Username)
}
if user.MaxSessions > 0 {
activeSessions := common.Connections.GetActiveSessions(user.Username)
@ -580,17 +558,12 @@ func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.
}
if !user.IsLoginMethodAllowed(loginMethod, conn.PartialSuccessMethods()) {
logger.Debug(logSender, connectionID, "cannot login user %#v, login method %#v is not allowed", user.Username, loginMethod)
return nil, fmt.Errorf("Login method %#v is not allowed for user %#v", loginMethod, user.Username)
}
if dataprovider.GetQuotaTracking() > 0 && user.HasOverlappedMappedPaths() {
logger.Debug(logSender, connectionID, "cannot login user %#v, overlapping mapped folders are allowed only with quota tracking disabled",
user.Username)
return nil, errors.New("overlapping mapped folders are allowed only with quota tracking disabled")
return nil, fmt.Errorf("login method %#v is not allowed for user %#v", loginMethod, user.Username)
}
remoteAddr := conn.RemoteAddr().String()
if !user.IsLoginFromAddrAllowed(remoteAddr) {
logger.Debug(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v", user.Username, remoteAddr)
return nil, fmt.Errorf("Login for user %#v is not allowed from this address: %v", user.Username, remoteAddr)
return nil, fmt.Errorf("login for user %#v is not allowed from this address: %v", user.Username, remoteAddr)
}
json, err := json.Marshal(user)

View file

@ -146,6 +146,7 @@ func TestMain(m *testing.M) {
if err != nil {
logger.ErrorToConsole("error creating login banner: %v", err)
}
os.Setenv("SFTPGO_COMMON__UPLOAD_MODE", "2")
err = config.LoadConfig(configDir, "")
if err != nil {
logger.ErrorToConsole("error loading configuration: %v", err)
@ -155,10 +156,6 @@ func TestMain(m *testing.M) {
logger.InfoToConsole("Starting SFTPD tests, provider: %v", providerConf.Driver)
commonConf := config.GetCommonConfig()
// we run the test cases with UploadMode atomic and resume support. The non atomic code path
// simply does not execute some code so if it works in atomic mode will
// work in non atomic mode too
commonConf.UploadMode = 2
homeBasePath = os.TempDir()
checkSystemCommands()
var scriptArgs string
@ -1130,7 +1127,7 @@ func TestChtimes(t *testing.T) {
defer client.Close()
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(65535)
testDir := "test"
testDir := "test" //nolint:goconst
err = createTestFile(testFilePath, testFileSize)
assert.NoError(t, err)
err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
@ -1687,7 +1684,7 @@ func TestLoginUserExpiration(t *testing.T) {
func TestLoginWithDatabaseCredentials(t *testing.T) {
usePubKey := true
u := getTestUser(usePubKey)
u.FsConfig.Provider = dataprovider.GCSFilesystemProvider
u.FsConfig.Provider = vfs.GCSFilesystemProvider
u.FsConfig.GCSConfig.Bucket = "testbucket"
u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(`{ "type": "service_account" }`)
@ -1736,7 +1733,7 @@ func TestLoginWithDatabaseCredentials(t *testing.T) {
func TestLoginInvalidFs(t *testing.T) {
usePubKey := true
u := getTestUser(usePubKey)
u.FsConfig.Provider = dataprovider.GCSFilesystemProvider
u.FsConfig.Provider = vfs.GCSFilesystemProvider
u.FsConfig.GCSConfig.Bucket = "test"
u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials")
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
@ -2569,6 +2566,29 @@ func TestMaxSessions(t *testing.T) {
assert.NoError(t, err)
}
func TestSupportedExtensions(t *testing.T) {
usePubKey := false
u := getTestUser(usePubKey)
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer client.Close()
assert.NoError(t, checkBasicSFTP(client))
v, ok := client.HasExtension("statvfs@openssh.com")
assert.Equal(t, "2", v)
assert.True(t, ok)
_, ok = client.HasExtension("hardlink@openssh.com")
assert.False(t, ok)
_, ok = client.HasExtension("posix-rename@openssh.com")
assert.False(t, ok)
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
}
func TestQuotaFileReplace(t *testing.T) {
usePubKey := false
u := getTestUser(usePubKey)
@ -3344,6 +3364,196 @@ func TestVirtualFoldersQuotaLimit(t *testing.T) {
assert.NoError(t, err)
}
func TestNestedVirtualFolders(t *testing.T) {
usePubKey := false
baseUser, resp, err := httpdtest.AddUser(getTestUser(false), http.StatusCreated)
assert.NoError(t, err, string(resp))
u := getTestSFTPUser(usePubKey)
u.QuotaFiles = 1000
mappedPathCrypt := filepath.Join(os.TempDir(), "crypt")
folderNameCrypt := filepath.Base(mappedPathCrypt)
vdirCryptPath := "/vdir/crypt"
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderNameCrypt,
FsConfig: vfs.Filesystem{
Provider: vfs.CryptedFilesystemProvider,
CryptConfig: vfs.CryptFsConfig{
Passphrase: kms.NewPlainSecret(defaultPassword),
},
},
MappedPath: mappedPathCrypt,
},
VirtualPath: vdirCryptPath,
QuotaFiles: 100,
})
mappedPath := filepath.Join(os.TempDir(), "local")
folderName := filepath.Base(mappedPath)
vdirPath := "/vdir/local"
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderName,
MappedPath: mappedPath,
},
VirtualPath: vdirPath,
QuotaFiles: -1,
QuotaSize: -1,
})
mappedPathNested := filepath.Join(os.TempDir(), "nested")
folderNameNested := filepath.Base(mappedPathNested)
vdirNestedPath := "/vdir/crypt/nested"
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderNameNested,
MappedPath: mappedPathNested,
},
VirtualPath: vdirNestedPath,
QuotaFiles: -1,
QuotaSize: -1,
})
user, resp, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err, string(resp))
client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer client.Close()
expectedQuotaSize := int64(0)
expectedQuotaFiles := 0
fileSize := int64(32765)
err = writeSFTPFile(testFileName, fileSize, client)
assert.NoError(t, err)
expectedQuotaSize += fileSize
expectedQuotaFiles++
fileSize = 38764
err = writeSFTPFile(path.Join("/vdir", testFileName), fileSize, client)
assert.NoError(t, err)
expectedQuotaSize += fileSize
expectedQuotaFiles++
fileSize = 18769
err = writeSFTPFile(path.Join(vdirPath, testFileName), fileSize, client)
assert.NoError(t, err)
expectedQuotaSize += fileSize
expectedQuotaFiles++
fileSize = 27658
err = writeSFTPFile(path.Join(vdirNestedPath, testFileName), fileSize, client)
assert.NoError(t, err)
expectedQuotaSize += fileSize
expectedQuotaFiles++
fileSize = 39765
err = writeSFTPFile(path.Join(vdirCryptPath, testFileName), fileSize, client)
assert.NoError(t, err)
userGet, _, err := httpdtest.GetUserByUsername(user.Username, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, expectedQuotaFiles, userGet.UsedQuotaFiles)
assert.Equal(t, expectedQuotaSize, userGet.UsedQuotaSize)
folderGet, _, err := httpdtest.GetFolderByName(folderNameCrypt, http.StatusOK)
assert.NoError(t, err)
assert.Greater(t, folderGet.UsedQuotaSize, fileSize)
assert.Equal(t, 1, folderGet.UsedQuotaFiles)
folderGet, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, int64(18769), folderGet.UsedQuotaSize)
assert.Equal(t, 1, folderGet.UsedQuotaFiles)
folderGet, _, err = httpdtest.GetFolderByName(folderNameNested, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, int64(27658), folderGet.UsedQuotaSize)
assert.Equal(t, 1, folderGet.UsedQuotaFiles)
files, err := client.ReadDir("/")
if assert.NoError(t, err) {
assert.Len(t, files, 2)
}
info, err := client.Stat("vdir")
if assert.NoError(t, err) {
assert.True(t, info.IsDir())
}
files, err = client.ReadDir("/vdir")
if assert.NoError(t, err) {
assert.Len(t, files, 3)
}
files, err = client.ReadDir(vdirCryptPath)
if assert.NoError(t, err) {
assert.Len(t, files, 2)
}
info, err = client.Stat(vdirNestedPath)
if assert.NoError(t, err) {
assert.True(t, info.IsDir())
}
// finally add some files directly using os method and then check quota
fName := "testfile"
fileSize = 123456
err = createTestFile(filepath.Join(baseUser.HomeDir, fName), fileSize)
assert.NoError(t, err)
expectedQuotaSize += fileSize
expectedQuotaFiles++
fileSize = 8765
err = createTestFile(filepath.Join(mappedPath, fName), fileSize)
assert.NoError(t, err)
expectedQuotaSize += fileSize
expectedQuotaFiles++
fileSize = 98751
err = createTestFile(filepath.Join(mappedPathNested, fName), fileSize)
assert.NoError(t, err)
expectedQuotaSize += fileSize
expectedQuotaFiles++
err = createTestFile(filepath.Join(mappedPathCrypt, fName), fileSize)
assert.NoError(t, err)
_, err = httpdtest.StartQuotaScan(user, http.StatusAccepted)
assert.NoError(t, err)
assert.Eventually(t, func() bool {
scans, _, err := httpdtest.GetQuotaScans(http.StatusOK)
if err == nil {
return len(scans) == 0
}
return false
}, 1*time.Second, 50*time.Millisecond)
userGet, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, expectedQuotaFiles, userGet.UsedQuotaFiles)
assert.Equal(t, expectedQuotaSize, userGet.UsedQuotaSize)
// the crypt folder is not included within user quota so we need to do a separate scan
_, err = httpdtest.StartFolderQuotaScan(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusAccepted)
assert.NoError(t, err)
assert.Eventually(t, func() bool {
scans, _, err := httpdtest.GetFoldersQuotaScans(http.StatusOK)
if err == nil {
return len(scans) == 0
}
return false
}, 1*time.Second, 50*time.Millisecond)
folderGet, _, err = httpdtest.GetFolderByName(folderNameCrypt, http.StatusOK)
assert.NoError(t, err)
assert.Greater(t, folderGet.UsedQuotaSize, int64(39765+98751))
assert.Equal(t, 2, folderGet.UsedQuotaFiles)
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameNested}, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveUser(baseUser, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
err = os.RemoveAll(baseUser.GetHomeDir())
assert.NoError(t, err)
err = os.RemoveAll(mappedPathCrypt)
assert.NoError(t, err)
err = os.RemoveAll(mappedPath)
assert.NoError(t, err)
err = os.RemoveAll(mappedPathNested)
assert.NoError(t, err)
}
func TestTruncateQuotaLimits(t *testing.T) {
usePubKey := true
u := getTestUser(usePubKey)
@ -4733,208 +4943,6 @@ func TestVirtualFoldersLink(t *testing.T) {
assert.NoError(t, err)
}
func TestOverlappedMappedFolders(t *testing.T) {
err := dataprovider.Close()
assert.NoError(t, err)
err = config.LoadConfig(configDir, "")
assert.NoError(t, err)
providerConf := config.GetProviderConf()
providerConf.TrackQuota = 0
err = dataprovider.Initialize(providerConf, configDir, true)
assert.NoError(t, err)
usePubKey := false
u := getTestUser(usePubKey)
subDir := "subdir"
mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
folderName1 := filepath.Base(mappedPath1)
vdirPath1 := "/vdir1"
mappedPath2 := filepath.Join(os.TempDir(), "vdir1", subDir)
folderName2 := filepath.Base(mappedPath2)
vdirPath2 := "/vdir2"
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderName1,
MappedPath: mappedPath1,
},
VirtualPath: vdirPath1,
})
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderName2,
MappedPath: mappedPath2,
},
VirtualPath: vdirPath2,
})
err = os.MkdirAll(mappedPath1, os.ModePerm)
assert.NoError(t, err)
err = os.MkdirAll(mappedPath2, os.ModePerm)
assert.NoError(t, err)
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer client.Close()
err = checkBasicSFTP(client)
assert.NoError(t, err)
testFileSize := int64(131072)
testFilePath := filepath.Join(homeBasePath, testFileName)
err = createTestFile(testFilePath, testFileSize)
assert.NoError(t, err)
err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
assert.NoError(t, err)
err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client)
assert.NoError(t, err)
err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName), testFileSize, client)
assert.NoError(t, err)
fi, err := client.Stat(path.Join(vdirPath1, subDir, testFileName))
if assert.NoError(t, err) {
assert.Equal(t, testFileSize, fi.Size())
}
err = client.Rename(path.Join(vdirPath1, subDir, testFileName), path.Join(vdirPath2, testFileName+"1"))
assert.NoError(t, err)
err = client.Rename(path.Join(vdirPath2, testFileName+"1"), path.Join(vdirPath1, subDir, testFileName))
assert.NoError(t, err)
err = client.Rename(path.Join(vdirPath1, subDir), path.Join(vdirPath2, subDir))
assert.Error(t, err)
err = client.Mkdir(subDir)
assert.NoError(t, err)
err = client.Rename(subDir, path.Join(vdirPath1, subDir))
assert.Error(t, err)
err = client.RemoveDirectory(path.Join(vdirPath1, subDir))
assert.Error(t, err)
err = client.Symlink(path.Join(vdirPath1, subDir), path.Join(vdirPath1, "adir"))
assert.Error(t, err)
err = client.Mkdir(path.Join(vdirPath1, subDir+"1"))
assert.NoError(t, err)
err = client.Symlink(path.Join(vdirPath1, subDir+"1"), path.Join(vdirPath1, subDir))
assert.Error(t, err)
err = os.Remove(testFilePath)
assert.NoError(t, err)
_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath1, subDir)), user, usePubKey)
assert.Error(t, err)
}
err = dataprovider.Close()
assert.NoError(t, err)
err = config.LoadConfig(configDir, "")
assert.NoError(t, err)
providerConf = config.GetProviderConf()
err = dataprovider.Initialize(providerConf, configDir, true)
assert.NoError(t, err)
if providerConf.Driver != dataprovider.MemoryDataProviderName {
client, err = getSftpClient(user, usePubKey)
if !assert.Error(t, err) {
client.Close()
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
assert.NoError(t, err)
}
_, _, err = httpdtest.AddUser(u, http.StatusCreated)
assert.Error(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
err = os.RemoveAll(mappedPath1)
assert.NoError(t, err)
err = os.RemoveAll(mappedPath2)
assert.NoError(t, err)
}
func TestResolveOverlappedMappedPaths(t *testing.T) {
u := getTestUser(false)
mappedPath1 := filepath.Join(os.TempDir(), "mapped1", "subdir")
folderName1 := filepath.Base(mappedPath1)
vdirPath1 := "/vdir1/subdir"
mappedPath2 := filepath.Join(os.TempDir(), "mapped2")
folderName2 := filepath.Base(mappedPath2)
vdirPath2 := "/vdir2/subdir"
mappedPath3 := filepath.Join(os.TempDir(), "mapped1")
folderName3 := filepath.Base(mappedPath3)
vdirPath3 := "/vdir3"
mappedPath4 := filepath.Join(os.TempDir(), "mapped1", "subdir", "vdir4")
folderName4 := filepath.Base(mappedPath4)
vdirPath4 := "/vdir4"
u.VirtualFolders = []vfs.VirtualFolder{
{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderName1,
MappedPath: mappedPath1,
},
VirtualPath: vdirPath1,
},
{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderName2,
MappedPath: mappedPath2,
},
VirtualPath: vdirPath2,
},
{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderName3,
MappedPath: mappedPath3,
},
VirtualPath: vdirPath3,
},
{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderName4,
MappedPath: mappedPath4,
},
VirtualPath: vdirPath4,
},
}
err := os.MkdirAll(u.GetHomeDir(), os.ModePerm)
assert.NoError(t, err)
err = os.MkdirAll(mappedPath1, os.ModePerm)
assert.NoError(t, err)
err = os.MkdirAll(mappedPath2, os.ModePerm)
assert.NoError(t, err)
err = os.MkdirAll(mappedPath3, os.ModePerm)
assert.NoError(t, err)
err = os.MkdirAll(mappedPath4, os.ModePerm)
assert.NoError(t, err)
fs := vfs.NewOsFs("", u.GetHomeDir(), u.VirtualFolders)
p, err := fs.ResolvePath("/vdir1")
assert.NoError(t, err)
assert.Equal(t, filepath.Join(u.GetHomeDir(), "vdir1"), p)
p, err = fs.ResolvePath("/vdir1/subdir")
assert.NoError(t, err)
assert.Equal(t, mappedPath1, p)
p, err = fs.ResolvePath("/vdir3/subdir/vdir4/file.txt")
assert.NoError(t, err)
assert.Equal(t, filepath.Join(mappedPath4, "file.txt"), p)
p, err = fs.ResolvePath("/vdir4/file.txt")
assert.NoError(t, err)
assert.Equal(t, filepath.Join(mappedPath4, "file.txt"), p)
assert.Equal(t, filepath.Join(mappedPath3, "subdir", "vdir4", "file.txt"), p)
assert.Equal(t, filepath.Join(mappedPath1, "vdir4", "file.txt"), p)
p, err = fs.ResolvePath("/vdir3/subdir/vdir4/../file.txt")
assert.NoError(t, err)
assert.Equal(t, filepath.Join(mappedPath3, "subdir", "file.txt"), p)
assert.Equal(t, filepath.Join(mappedPath1, "file.txt"), p)
err = os.RemoveAll(u.GetHomeDir())
assert.NoError(t, err)
err = os.RemoveAll(mappedPath4)
assert.NoError(t, err)
err = os.RemoveAll(mappedPath1)
assert.NoError(t, err)
err = os.RemoveAll(mappedPath3)
assert.NoError(t, err)
err = os.RemoveAll(mappedPath2)
assert.NoError(t, err)
}
func TestVirtualFolderQuotaScan(t *testing.T) {
mappedPath := filepath.Join(os.TempDir(), "mapped_dir")
folderName := filepath.Base(mappedPath)
@ -5161,19 +5169,28 @@ func TestOpenError(t *testing.T) {
assert.Error(t, err, "file download must fail if we have no filesystem read permissions")
err = sftpUploadFile(localDownloadPath, testFileName, testFileSize, client)
assert.Error(t, err, "upload must fail if we have no filesystem write permissions")
err = client.Mkdir("test")
testDir := "test"
err = client.Mkdir(testDir)
assert.NoError(t, err)
err = createTestFile(filepath.Join(user.GetHomeDir(), testDir, testFileName), testFileSize)
assert.NoError(t, err)
err = os.Chmod(user.GetHomeDir(), 0000)
assert.NoError(t, err)
_, err = client.Lstat(testFileName)
assert.Error(t, err, "file stat must fail if we have no filesystem read permissions")
err = sftpUploadFile(localDownloadPath, path.Join(testDir, testFileName), testFileSize, client)
assert.ErrorIs(t, err, os.ErrPermission)
_, err = client.ReadLink(testFileName)
assert.ErrorIs(t, err, os.ErrPermission)
err = client.Remove(testFileName)
assert.ErrorIs(t, err, os.ErrPermission)
err = os.Chmod(user.GetHomeDir(), os.ModePerm)
assert.NoError(t, err)
err = os.Chmod(filepath.Join(user.GetHomeDir(), "test"), 0000)
err = os.Chmod(filepath.Join(user.GetHomeDir(), testDir), 0000)
assert.NoError(t, err)
err = client.Rename(testFileName, path.Join("test", testFileName))
err = client.Rename(testFileName, path.Join(testDir, testFileName))
assert.True(t, os.IsPermission(err))
err = os.Chmod(filepath.Join(user.GetHomeDir(), "test"), os.ModePerm)
err = os.Chmod(filepath.Join(user.GetHomeDir(), testDir), os.ModePerm)
assert.NoError(t, err)
err = os.Remove(localDownloadPath)
assert.NoError(t, err)
@ -5941,23 +5958,23 @@ func TestRootDirCommands(t *testing.T) {
func TestRelativePaths(t *testing.T) {
user := getTestUser(true)
var path, rel string
filesystems := []vfs.Fs{vfs.NewOsFs("", user.GetHomeDir(), user.VirtualFolders)}
filesystems := []vfs.Fs{vfs.NewOsFs("", user.GetHomeDir(), "")}
keyPrefix := strings.TrimPrefix(user.GetHomeDir(), "/") + "/"
s3config := vfs.S3FsConfig{
KeyPrefix: keyPrefix,
}
s3fs, _ := vfs.NewS3Fs("", user.GetHomeDir(), s3config)
s3fs, _ := vfs.NewS3Fs("", user.GetHomeDir(), "", s3config)
gcsConfig := vfs.GCSFsConfig{
KeyPrefix: keyPrefix,
}
gcsfs, _ := vfs.NewGCSFs("", user.GetHomeDir(), gcsConfig)
gcsfs, _ := vfs.NewGCSFs("", user.GetHomeDir(), "", gcsConfig)
sftpconfig := vfs.SFTPFsConfig{
Endpoint: sftpServerAddr,
Username: defaultUsername,
Password: kms.NewPlainSecret(defaultPassword),
Prefix: keyPrefix,
}
sftpfs, _ := vfs.NewSFTPFs("", sftpconfig)
sftpfs, _ := vfs.NewSFTPFs("", "", sftpconfig)
if runtime.GOOS != osWindows {
filesystems = append(filesystems, s3fs, gcsfs, sftpfs)
}
@ -6000,7 +6017,7 @@ func TestResolvePaths(t *testing.T) {
user := getTestUser(true)
var path, resolved string
var err error
filesystems := []vfs.Fs{vfs.NewOsFs("", user.GetHomeDir(), user.VirtualFolders)}
filesystems := []vfs.Fs{vfs.NewOsFs("", user.GetHomeDir(), "")}
keyPrefix := strings.TrimPrefix(user.GetHomeDir(), "/") + "/"
s3config := vfs.S3FsConfig{
KeyPrefix: keyPrefix,
@ -6009,12 +6026,12 @@ func TestResolvePaths(t *testing.T) {
}
err = os.MkdirAll(user.GetHomeDir(), os.ModePerm)
assert.NoError(t, err)
s3fs, err := vfs.NewS3Fs("", user.GetHomeDir(), s3config)
s3fs, err := vfs.NewS3Fs("", user.GetHomeDir(), "", s3config)
assert.NoError(t, err)
gcsConfig := vfs.GCSFsConfig{
KeyPrefix: keyPrefix,
}
gcsfs, _ := vfs.NewGCSFs("", user.GetHomeDir(), gcsConfig)
gcsfs, _ := vfs.NewGCSFs("", user.GetHomeDir(), "", gcsConfig)
if runtime.GOOS != osWindows {
filesystems = append(filesystems, s3fs, gcsfs)
}
@ -6049,7 +6066,7 @@ func TestResolvePaths(t *testing.T) {
func TestVirtualRelativePaths(t *testing.T) {
user := getTestUser(true)
mappedPath := filepath.Join(os.TempDir(), "vdir")
mappedPath := filepath.Join(os.TempDir(), "mdir")
vdirPath := "/vdir" //nolint:goconst
user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
@ -6059,47 +6076,23 @@ func TestVirtualRelativePaths(t *testing.T) {
})
err := os.MkdirAll(mappedPath, os.ModePerm)
assert.NoError(t, err)
fs := vfs.NewOsFs("", user.GetHomeDir(), user.VirtualFolders)
rel := fs.GetRelativePath(mappedPath)
fsRoot := vfs.NewOsFs("", user.GetHomeDir(), "")
fsVdir := vfs.NewOsFs("", mappedPath, vdirPath)
rel := fsVdir.GetRelativePath(mappedPath)
assert.Equal(t, vdirPath, rel)
rel = fs.GetRelativePath(filepath.Join(mappedPath, ".."))
rel = fsRoot.GetRelativePath(filepath.Join(mappedPath, ".."))
assert.Equal(t, "/", rel)
// path outside home and virtual dir
rel = fs.GetRelativePath(filepath.Join(mappedPath, "../vdir1"))
rel = fsRoot.GetRelativePath(filepath.Join(mappedPath, "../vdir1"))
assert.Equal(t, "/", rel)
rel = fs.GetRelativePath(filepath.Join(mappedPath, "../vdir/file.txt"))
rel = fsVdir.GetRelativePath(filepath.Join(mappedPath, "../vdir1"))
assert.Equal(t, "/vdir", rel)
rel = fsVdir.GetRelativePath(filepath.Join(mappedPath, "file.txt"))
assert.Equal(t, "/vdir/file.txt", rel)
rel = fs.GetRelativePath(filepath.Join(user.HomeDir, "vdir1/file.txt"))
rel = fsRoot.GetRelativePath(filepath.Join(user.HomeDir, "vdir1/file.txt"))
assert.Equal(t, "/vdir1/file.txt", rel)
}
func TestResolveVirtualPaths(t *testing.T) {
user := getTestUser(true)
mappedPath := filepath.Join(os.TempDir(), "vdir")
vdirPath := "/vdir"
user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
MappedPath: mappedPath,
},
VirtualPath: vdirPath,
})
err := os.MkdirAll(mappedPath, os.ModePerm)
assert.NoError(t, err)
osFs := vfs.NewOsFs("", user.GetHomeDir(), user.VirtualFolders).(*vfs.OsFs)
b, f := osFs.GetFsPaths("/vdir/a.txt")
assert.Equal(t, mappedPath, b)
assert.Equal(t, filepath.Join(mappedPath, "a.txt"), f)
b, f = osFs.GetFsPaths("/vdir/sub with space & spécial chars/a.txt")
assert.Equal(t, mappedPath, b)
assert.Equal(t, filepath.Join(mappedPath, "sub with space & spécial chars/a.txt"), f)
b, f = osFs.GetFsPaths("/vdir/../a.txt")
assert.Equal(t, user.GetHomeDir(), b)
assert.Equal(t, filepath.Join(user.GetHomeDir(), "a.txt"), f)
b, f = osFs.GetFsPaths("/vdir1/a.txt")
assert.Equal(t, user.GetHomeDir(), b)
assert.Equal(t, filepath.Join(user.GetHomeDir(), "/vdir1/a.txt"), f)
}
func TestUserPerms(t *testing.T) {
user := getTestUser(true)
user.Permissions = make(map[string][]string)
@ -6504,7 +6497,7 @@ func TestStatVFS(t *testing.T) {
func TestStatVFSCloudBackend(t *testing.T) {
usePubKey := true
u := getTestUser(usePubKey)
u.FsConfig.Provider = dataprovider.AzureBlobFilesystemProvider
u.FsConfig.Provider = vfs.AzureBlobFilesystemProvider
u.FsConfig.AzBlobConfig.SASURL = "https://myaccount.blob.core.windows.net/sasurl"
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
@ -7050,6 +7043,32 @@ func TestSSHCopyQuotaLimits(t *testing.T) {
assert.NoError(t, err)
}
func TestSSHCopyRemoveNonLocalFs(t *testing.T) {
usePubKey := true
localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
assert.NoError(t, err)
sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(sftpUser, usePubKey)
if assert.NoError(t, err) {
defer client.Close()
testDir := "test"
err = client.Mkdir(testDir)
assert.NoError(t, err)
_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", testDir, testDir+"_copy"), sftpUser, usePubKey)
assert.Error(t, err)
_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testDir), sftpUser, usePubKey)
assert.Error(t, err)
}
_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(localUser.GetHomeDir())
assert.NoError(t, err)
}
func TestSSHRemove(t *testing.T) {
usePubKey := false
u := getTestUser(usePubKey)
@ -7188,6 +7207,117 @@ func TestSSHRemove(t *testing.T) {
assert.NoError(t, err)
}
func TestSSHRemoveCryptFs(t *testing.T) {
usePubKey := false
u := getTestUserWithCryptFs(usePubKey)
u.QuotaFiles = 100
mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
folderName1 := filepath.Base(mappedPath1)
vdirPath1 := "/vdir1/sub"
mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
folderName2 := filepath.Base(mappedPath2)
vdirPath2 := "/vdir2/sub"
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderName1,
MappedPath: mappedPath1,
},
VirtualPath: vdirPath1,
QuotaFiles: -1,
QuotaSize: -1,
})
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderName2,
MappedPath: mappedPath2,
FsConfig: vfs.Filesystem{
Provider: vfs.CryptedFilesystemProvider,
CryptConfig: vfs.CryptFsConfig{
Passphrase: kms.NewPlainSecret(defaultPassword),
},
},
},
VirtualPath: vdirPath2,
QuotaFiles: 100,
QuotaSize: 0,
})
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer client.Close()
testDir := "tdir"
err = client.Mkdir(testDir)
assert.NoError(t, err)
err = client.Mkdir(path.Join(vdirPath1, testDir))
assert.NoError(t, err)
err = client.Mkdir(path.Join(vdirPath2, testDir))
assert.NoError(t, err)
testFileSize := int64(32768)
testFileSize1 := int64(65536)
testFileName1 := "test_file1.dat"
err = writeSFTPFile(testFileName, testFileSize, client)
assert.NoError(t, err)
err = writeSFTPFile(testFileName1, testFileSize1, client)
assert.NoError(t, err)
err = writeSFTPFile(path.Join(testDir, testFileName), testFileSize, client)
assert.NoError(t, err)
err = writeSFTPFile(path.Join(testDir, testFileName1), testFileSize1, client)
assert.NoError(t, err)
err = writeSFTPFile(path.Join(vdirPath1, testDir, testFileName), testFileSize, client)
assert.NoError(t, err)
err = writeSFTPFile(path.Join(vdirPath1, testDir, testFileName1), testFileSize1, client)
assert.NoError(t, err)
err = writeSFTPFile(path.Join(vdirPath2, testDir, testFileName), testFileSize, client)
assert.NoError(t, err)
err = writeSFTPFile(path.Join(vdirPath2, testDir, testFileName1), testFileSize1, client)
assert.NoError(t, err)
_, err = runSSHCommand("sftpgo-remove /vdir2", user, usePubKey)
assert.Error(t, err)
out, err := runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testFileName), user, usePubKey)
if assert.NoError(t, err) {
assert.Equal(t, "OK\n", string(out))
_, err := client.Stat(testFileName)
assert.Error(t, err)
}
out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testDir), user, usePubKey)
if assert.NoError(t, err) {
assert.Equal(t, "OK\n", string(out))
}
out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath1, testDir)), user, usePubKey)
if assert.NoError(t, err) {
assert.Equal(t, "OK\n", string(out))
}
out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir, testFileName)), user, usePubKey)
if assert.NoError(t, err) {
assert.Equal(t, "OK\n", string(out))
}
err = writeSFTPFile(path.Join(vdirPath2, testDir, testFileName), testFileSize, client)
assert.NoError(t, err)
out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir)), user, usePubKey)
if assert.NoError(t, err) {
assert.Equal(t, "OK\n", string(out))
}
user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
assert.NoError(t, err)
assert.Equal(t, 1, user.UsedQuotaFiles)
assert.Greater(t, user.UsedQuotaSize, testFileSize1)
}
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
err = os.RemoveAll(mappedPath1)
assert.NoError(t, err)
err = os.RemoveAll(mappedPath2)
assert.NoError(t, err)
}
func TestBasicGitCommands(t *testing.T) {
if len(gitPath) == 0 || len(sshPath) == 0 || runtime.GOOS == osWindows {
t.Skip("git and/or ssh command not found or OS is windows, unable to execute this test")
@ -7667,6 +7797,136 @@ func TestSCPVirtualFolders(t *testing.T) {
assert.NoError(t, err)
}
func TestSCPNestedFolders(t *testing.T) {
if len(scpPath) == 0 {
t.Skip("scp command not found, unable to execute this test")
}
baseUser, resp, err := httpdtest.AddUser(getTestUser(false), http.StatusCreated)
assert.NoError(t, err, string(resp))
usePubKey := true
u := getTestUser(usePubKey)
u.HomeDir += "_folders"
u.Username += "_folders"
mappedPathSFTP := filepath.Join(os.TempDir(), "sftp")
folderNameSFTP := filepath.Base(mappedPathSFTP)
vdirSFTPPath := "/vdir/sftp"
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderNameSFTP,
FsConfig: vfs.Filesystem{
Provider: vfs.SFTPFilesystemProvider,
SFTPConfig: vfs.SFTPFsConfig{
Endpoint: sftpServerAddr,
Username: baseUser.Username,
Password: kms.NewPlainSecret(defaultPassword),
},
},
},
VirtualPath: vdirSFTPPath,
})
mappedPathCrypt := filepath.Join(os.TempDir(), "crypt")
folderNameCrypt := filepath.Base(mappedPathCrypt)
vdirCryptPath := "/vdir/crypt"
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderNameCrypt,
FsConfig: vfs.Filesystem{
Provider: vfs.CryptedFilesystemProvider,
CryptConfig: vfs.CryptFsConfig{
Passphrase: kms.NewPlainSecret(defaultPassword),
},
},
MappedPath: mappedPathCrypt,
},
VirtualPath: vdirCryptPath,
})
user, resp, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err, string(resp))
baseDirDownPath := filepath.Join(os.TempDir(), "basedir-down")
err = os.Mkdir(baseDirDownPath, os.ModePerm)
assert.NoError(t, err)
baseDir := filepath.Join(os.TempDir(), "basedir")
err = os.Mkdir(baseDir, os.ModePerm)
assert.NoError(t, err)
err = os.MkdirAll(filepath.Join(baseDir, vdirSFTPPath), os.ModePerm)
assert.NoError(t, err)
err = os.MkdirAll(filepath.Join(baseDir, vdirCryptPath), os.ModePerm)
assert.NoError(t, err)
err = createTestFile(filepath.Join(baseDir, vdirSFTPPath, testFileName), 32768)
assert.NoError(t, err)
err = createTestFile(filepath.Join(baseDir, vdirCryptPath, testFileName), 65535)
assert.NoError(t, err)
err = createTestFile(filepath.Join(baseDir, "vdir", testFileName), 65536)
assert.NoError(t, err)
remoteRootPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
err = scpUpload(filepath.Join(baseDir, "vdir"), remoteRootPath, true, false)
assert.NoError(t, err)
client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer client.Close()
info, err := client.Stat(path.Join(vdirCryptPath, testFileName))
assert.NoError(t, err)
assert.Equal(t, int64(65535), info.Size())
info, err = client.Stat(path.Join(vdirSFTPPath, testFileName))
assert.NoError(t, err)
assert.Equal(t, int64(32768), info.Size())
info, err = client.Stat(path.Join("/vdir", testFileName))
assert.NoError(t, err)
assert.Equal(t, int64(65536), info.Size())
}
err = scpDownload(baseDirDownPath, remoteRootPath, true, true)
assert.NoError(t, err)
assert.FileExists(t, filepath.Join(baseDirDownPath, user.Username, "vdir", testFileName))
assert.FileExists(t, filepath.Join(baseDirDownPath, user.Username, vdirCryptPath, testFileName))
assert.FileExists(t, filepath.Join(baseDirDownPath, user.Username, vdirSFTPPath, testFileName))
if runtime.GOOS != osWindows {
err = os.Chmod(filepath.Join(baseUser.GetHomeDir(), testFileName), 0001)
assert.NoError(t, err)
err = scpDownload(baseDirDownPath, remoteRootPath, true, true)
assert.Error(t, err)
err = os.Chmod(filepath.Join(baseUser.GetHomeDir(), testFileName), os.ModePerm)
assert.NoError(t, err)
}
// now change the password for the base user, so SFTP folder will not work
baseUser.Password = defaultPassword + "_mod"
_, _, err = httpdtest.UpdateUser(baseUser, http.StatusOK, "")
assert.NoError(t, err)
err = scpUpload(filepath.Join(baseDir, "vdir"), remoteRootPath, true, false)
assert.Error(t, err)
err = scpDownload(baseDirDownPath, remoteRootPath, true, true)
assert.Error(t, err)
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameSFTP}, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveUser(baseUser, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
err = os.RemoveAll(baseUser.GetHomeDir())
assert.NoError(t, err)
err = os.RemoveAll(mappedPathCrypt)
assert.NoError(t, err)
err = os.RemoveAll(mappedPathSFTP)
assert.NoError(t, err)
err = os.RemoveAll(baseDir)
assert.NoError(t, err)
err = os.RemoveAll(baseDirDownPath)
assert.NoError(t, err)
}
func TestSCPVirtualFoldersQuota(t *testing.T) {
if len(scpPath) == 0 {
t.Skip("scp command not found, unable to execute this test")
@ -8194,7 +8454,7 @@ func getTestUser(usePubKey bool) dataprovider.User {
func getTestSFTPUser(usePubKey bool) dataprovider.User {
u := getTestUser(usePubKey)
u.Username = defaultSFTPUsername
u.FsConfig.Provider = dataprovider.SFTPFilesystemProvider
u.FsConfig.Provider = vfs.SFTPFilesystemProvider
u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr
u.FsConfig.SFTPConfig.Username = defaultUsername
u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword)
@ -8377,6 +8637,35 @@ func checkBasicSFTP(client *sftp.Client) error {
return err
}
func writeSFTPFile(name string, size int64, client *sftp.Client) error {
content := make([]byte, size)
_, err := rand.Read(content)
if err != nil {
return err
}
f, err := client.Create(name)
if err != nil {
return err
}
_, err = io.Copy(f, bytes.NewBuffer(content))
if err != nil {
f.Close()
return err
}
err = f.Close()
if err != nil {
return err
}
info, err := client.Stat(name)
if err != nil {
return err
}
if info.Size() != size {
return fmt.Errorf("file size mismatch, wanted %v, actual %v", size, info.Size())
}
return nil
}
func sftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) error {
srcFile, err := os.Open(localSourcePath)
if err != nil {

View file

@ -47,6 +47,7 @@ type systemCommand struct {
cmd *exec.Cmd
fsPath string
quotaCheckPath string
fs vfs.Fs
}
func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommands []string) bool {
@ -82,7 +83,8 @@ func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommand
connection.Log(logger.LevelInfo, "ssh command not enabled/supported: %#v", name)
}
}
connection.Fs.Close()
err := connection.CloseFS()
connection.Log(logger.LevelDebug, "unable to unmarsh ssh command, close fs, err: %v", err)
return false
}
@ -112,45 +114,42 @@ func (c *sshCommand) handle() (err error) {
c.connection.channel.Write([]byte("/\n")) //nolint:errcheck
c.sendExitStatus(nil)
} else if c.command == "sftpgo-copy" {
return c.handeSFTPGoCopy()
return c.handleSFTPGoCopy()
} else if c.command == "sftpgo-remove" {
return c.handeSFTPGoRemove()
return c.handleSFTPGoRemove()
}
return
}
func (c *sshCommand) handeSFTPGoCopy() error {
if !vfs.IsLocalOsFs(c.connection.Fs) {
func (c *sshCommand) handleSFTPGoCopy() error {
fsSrc, fsDst, sshSourcePath, sshDestPath, fsSourcePath, fsDestPath, err := c.getFsAndCopyPaths()
if err != nil {
return c.sendErrorResponse(err)
}
if !c.isLocalCopy(sshSourcePath, sshDestPath) {
return c.sendErrorResponse(errUnsupportedConfig)
}
sshSourcePath, sshDestPath, err := c.getCopyPaths()
if err != nil {
return c.sendErrorResponse(err)
}
fsSourcePath, fsDestPath, err := c.resolveCopyPaths(sshSourcePath, sshDestPath)
if err != nil {
return c.sendErrorResponse(err)
}
if err := c.checkCopyDestination(fsDestPath); err != nil {
return c.sendErrorResponse(err)
if err := c.checkCopyDestination(fsDst, fsDestPath); err != nil {
return c.sendErrorResponse(c.connection.GetFsError(fsDst, err))
}
c.connection.Log(logger.LevelDebug, "requested copy %#v -> %#v sftp paths %#v -> %#v",
fsSourcePath, fsDestPath, sshSourcePath, sshDestPath)
fi, err := c.connection.Fs.Lstat(fsSourcePath)
fi, err := fsSrc.Lstat(fsSourcePath)
if err != nil {
return c.sendErrorResponse(err)
return c.sendErrorResponse(c.connection.GetFsError(fsSrc, err))
}
if err := c.checkCopyPermissions(fsSourcePath, fsDestPath, sshSourcePath, sshDestPath, fi); err != nil {
if err := c.checkCopyPermissions(fsSrc, fsDst, fsSourcePath, fsDestPath, sshSourcePath, sshDestPath, fi); err != nil {
return c.sendErrorResponse(err)
}
filesNum := 0
filesSize := int64(0)
if fi.IsDir() {
filesNum, filesSize, err = c.connection.Fs.GetDirSize(fsSourcePath)
filesNum, filesSize, err = fsSrc.GetDirSize(fsSourcePath)
if err != nil {
return c.sendErrorResponse(err)
return c.sendErrorResponse(c.connection.GetFsError(fsSrc, err))
}
if c.connection.User.HasVirtualFoldersInside(sshSourcePath) {
err := errors.New("unsupported copy source: the source directory contains virtual folders")
@ -177,7 +176,7 @@ func (c *sshCommand) handeSFTPGoCopy() error {
c.connection.Log(logger.LevelDebug, "start copy %#v -> %#v", fsSourcePath, fsDestPath)
err = fscopy.Copy(fsSourcePath, fsDestPath)
if err != nil {
return c.sendErrorResponse(err)
return c.sendErrorResponse(c.connection.GetFsError(fsSrc, err))
}
c.updateQuota(sshDestPath, filesNum, filesSize)
c.connection.channel.Write([]byte("OK\n")) //nolint:errcheck
@ -185,10 +184,7 @@ func (c *sshCommand) handeSFTPGoCopy() error {
return nil
}
func (c *sshCommand) handeSFTPGoRemove() error {
if !vfs.IsLocalOsFs(c.connection.Fs) {
return c.sendErrorResponse(errUnsupportedConfig)
}
func (c *sshCommand) handleSFTPGoRemove() error {
sshDestPath, err := c.getRemovePath()
if err != nil {
return c.sendErrorResponse(err)
@ -196,20 +192,23 @@ func (c *sshCommand) handeSFTPGoRemove() error {
if !c.connection.User.HasPerm(dataprovider.PermDelete, path.Dir(sshDestPath)) {
return c.sendErrorResponse(common.ErrPermissionDenied)
}
fsDestPath, err := c.connection.Fs.ResolvePath(sshDestPath)
fs, fsDestPath, err := c.connection.GetFsAndResolvedPath(sshDestPath)
if err != nil {
return c.sendErrorResponse(err)
}
fi, err := c.connection.Fs.Lstat(fsDestPath)
if !vfs.IsLocalOrCryptoFs(fs) {
return c.sendErrorResponse(errUnsupportedConfig)
}
fi, err := fs.Lstat(fsDestPath)
if err != nil {
return c.sendErrorResponse(err)
return c.sendErrorResponse(c.connection.GetFsError(fs, err))
}
filesNum := 0
filesSize := int64(0)
if fi.IsDir() {
filesNum, filesSize, err = c.connection.Fs.GetDirSize(fsDestPath)
filesNum, filesSize, err = fs.GetDirSize(fsDestPath)
if err != nil {
return c.sendErrorResponse(err)
return c.sendErrorResponse(c.connection.GetFsError(fs, err))
}
if sshDestPath == "/" {
err := errors.New("removing root dir is not allowed")
@ -223,10 +222,6 @@ func (c *sshCommand) handeSFTPGoRemove() error {
err := errors.New("unsupported remove source: this directory is a virtual folder")
return c.sendErrorResponse(err)
}
if c.connection.User.IsMappedPath(fsDestPath) {
err := errors.New("removing a directory mapped as virtual folder is not allowed")
return c.sendErrorResponse(err)
}
} else if fi.Mode().IsRegular() {
filesNum = 1
filesSize = fi.Size()
@ -284,18 +279,18 @@ func (c *sshCommand) handleHashCommands() error {
sshPath := c.getDestPath()
if !c.connection.User.IsFileAllowed(sshPath) {
c.connection.Log(logger.LevelInfo, "hash not allowed for file %#v", sshPath)
return c.sendErrorResponse(common.ErrPermissionDenied)
return c.sendErrorResponse(c.connection.GetPermissionDeniedError())
}
fsPath, err := c.connection.Fs.ResolvePath(sshPath)
fs, fsPath, err := c.connection.GetFsAndResolvedPath(sshPath)
if err != nil {
return c.sendErrorResponse(err)
}
if !c.connection.User.HasPerm(dataprovider.PermListItems, sshPath) {
return c.sendErrorResponse(common.ErrPermissionDenied)
return c.sendErrorResponse(c.connection.GetPermissionDeniedError())
}
hash, err := c.computeHashForFile(h, fsPath)
hash, err := c.computeHashForFile(fs, h, fsPath)
if err != nil {
return c.sendErrorResponse(err)
return c.sendErrorResponse(c.connection.GetFsError(fs, err))
}
response = fmt.Sprintf("%v %v\n", hash, sshPath)
}
@ -305,10 +300,10 @@ func (c *sshCommand) handleHashCommands() error {
}
func (c *sshCommand) executeSystemCommand(command systemCommand) error {
if !vfs.IsLocalOsFs(c.connection.Fs) {
sshDestPath := c.getDestPath()
if !c.isLocalPath(sshDestPath) {
return c.sendErrorResponse(errUnsupportedConfig)
}
sshDestPath := c.getDestPath()
quotaResult := c.connection.HasSpace(true, false, command.quotaCheckPath)
if !quotaResult.HasSpace {
return c.sendErrorResponse(common.ErrQuotaExceeded)
@ -316,10 +311,10 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
perms := []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermCreateDirs, dataprovider.PermListItems,
dataprovider.PermOverwrite, dataprovider.PermDelete}
if !c.connection.User.HasPerms(perms, sshDestPath) {
return c.sendErrorResponse(common.ErrPermissionDenied)
return c.sendErrorResponse(c.connection.GetPermissionDeniedError())
}
initialFiles, initialSize, err := c.getSizeForPath(command.fsPath)
initialFiles, initialSize, err := c.getSizeForPath(command.fs, command.fsPath)
if err != nil {
return c.sendErrorResponse(err)
}
@ -356,7 +351,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
go func() {
defer stdin.Close()
baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, sshDestPath,
common.TransferUpload, 0, 0, remainingQuotaSize, false, c.connection.Fs)
common.TransferUpload, 0, 0, remainingQuotaSize, false, command.fs)
transfer := newTransfer(baseTransfer, nil, nil, nil)
w, e := transfer.copyFromReaderToWriter(stdin, c.connection.channel)
@ -369,7 +364,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
go func() {
baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, sshDestPath,
common.TransferDownload, 0, 0, 0, false, c.connection.Fs)
common.TransferDownload, 0, 0, 0, false, command.fs)
transfer := newTransfer(baseTransfer, nil, nil, nil)
w, e := transfer.copyFromReaderToWriter(c.connection.channel, stdout)
@ -383,7 +378,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
go func() {
baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, sshDestPath,
common.TransferDownload, 0, 0, 0, false, c.connection.Fs)
common.TransferDownload, 0, 0, 0, false, command.fs)
transfer := newTransfer(baseTransfer, nil, nil, nil)
w, e := transfer.copyFromReaderToWriter(c.connection.channel.(ssh.Channel).Stderr(), stderr)
@ -399,14 +394,14 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
err = command.cmd.Wait()
c.sendExitStatus(err)
numFiles, dirSize, errSize := c.getSizeForPath(command.fsPath)
numFiles, dirSize, errSize := c.getSizeForPath(command.fs, command.fsPath)
if errSize == nil {
c.updateQuota(sshDestPath, numFiles-initialFiles, dirSize-initialSize)
}
c.connection.Log(logger.LevelDebug, "command %#v finished for path %#v, initial files %v initial size %v "+
"current files %v current size %v size err: %v", c.connection.command, command.fsPath, initialFiles, initialSize,
numFiles, dirSize, errSize)
return err
return c.connection.GetFsError(command.fs, err)
}
func (c *sshCommand) isSystemCommandAllowed() error {
@ -450,21 +445,26 @@ func (c *sshCommand) isSystemCommandAllowed() error {
func (c *sshCommand) getSystemCommand() (systemCommand, error) {
command := systemCommand{
cmd: nil,
fs: nil,
fsPath: "",
quotaCheckPath: "",
}
args := make([]string, len(c.args))
copy(args, c.args)
var fsPath, quotaPath string
sshPath := c.getDestPath()
fs, err := c.connection.User.GetFilesystemForPath(sshPath, c.connection.ID)
if err != nil {
return command, err
}
if len(c.args) > 0 {
var err error
sshPath := c.getDestPath()
fsPath, err = c.connection.Fs.ResolvePath(sshPath)
fsPath, err = fs.ResolvePath(sshPath)
if err != nil {
return command, err
return command, c.connection.GetFsError(fs, err)
}
quotaPath = sshPath
fi, err := c.connection.Fs.Stat(fsPath)
fi, err := fs.Stat(fsPath)
if err == nil && fi.IsDir() {
// if the target is an existing dir the command will write inside this dir
// so we need to check the quota for this directory and not its parent dir
@ -506,6 +506,7 @@ func (c *sshCommand) getSystemCommand() (systemCommand, error) {
command.cmd = cmd
command.fsPath = fsPath
command.quotaCheckPath = quotaPath
command.fs = fs
return command, nil
}
@ -535,7 +536,7 @@ func cleanCommandPath(name string) string {
return result
}
func (c *sshCommand) getCopyPaths() (string, string, error) {
func (c *sshCommand) getFsAndCopyPaths() (vfs.Fs, vfs.Fs, string, string, string, string, error) {
sshSourcePath := strings.TrimSuffix(c.getSourcePath(), "/")
sshDestPath := c.getDestPath()
if strings.HasSuffix(sshDestPath, "/") {
@ -543,9 +544,17 @@ func (c *sshCommand) getCopyPaths() (string, string, error) {
}
if sshSourcePath == "" || sshDestPath == "" || len(c.args) != 2 {
err := errors.New("usage sftpgo-copy <source dir path> <destination dir path>")
return "", "", err
return nil, nil, "", "", "", "", err
}
return sshSourcePath, sshDestPath, nil
fsSrc, fsSourcePath, err := c.connection.GetFsAndResolvedPath(sshSourcePath)
if err != nil {
return nil, nil, "", "", "", "", err
}
fsDst, fsDestPath, err := c.connection.GetFsAndResolvedPath(sshDestPath)
if err != nil {
return nil, nil, "", "", "", "", err
}
return fsSrc, fsDst, sshSourcePath, sshDestPath, fsSourcePath, fsDestPath, nil
}
func (c *sshCommand) hasCopyPermissions(sshSourcePath, sshDestPath string, srcInfo os.FileInfo) bool {
@ -561,7 +570,7 @@ func (c *sshCommand) hasCopyPermissions(sshSourcePath, sshDestPath string, srcIn
}
// fsSourcePath must be a directory
func (c *sshCommand) checkRecursiveCopyPermissions(fsSourcePath, fsDestPath, sshDestPath string) error {
func (c *sshCommand) checkRecursiveCopyPermissions(fsSrc vfs.Fs, fsDst vfs.Fs, fsSourcePath, fsDestPath, sshDestPath string) error {
if !c.connection.User.HasPerm(dataprovider.PermCreateDirs, path.Dir(sshDestPath)) {
return common.ErrPermissionDenied
}
@ -571,13 +580,13 @@ func (c *sshCommand) checkRecursiveCopyPermissions(fsSourcePath, fsDestPath, ssh
dataprovider.PermUpload,
}
err := c.connection.Fs.Walk(fsSourcePath, func(walkedPath string, info os.FileInfo, err error) error {
err := fsSrc.Walk(fsSourcePath, func(walkedPath string, info os.FileInfo, err error) error {
if err != nil {
return err
return c.connection.GetFsError(fsSrc, err)
}
fsDstSubPath := strings.Replace(walkedPath, fsSourcePath, fsDestPath, 1)
sshSrcSubPath := c.connection.Fs.GetRelativePath(walkedPath)
sshDstSubPath := c.connection.Fs.GetRelativePath(fsDstSubPath)
sshSrcSubPath := fsSrc.GetRelativePath(walkedPath)
sshDstSubPath := fsDst.GetRelativePath(fsDstSubPath)
// If the current dir has no subdirs with defined permissions inside it
// and it has all the possible permissions we can stop scanning
if !c.connection.User.HasPermissionsInside(path.Dir(sshSrcSubPath)) &&
@ -598,12 +607,12 @@ func (c *sshCommand) checkRecursiveCopyPermissions(fsSourcePath, fsDestPath, ssh
return err
}
func (c *sshCommand) checkCopyPermissions(fsSourcePath, fsDestPath, sshSourcePath, sshDestPath string, info os.FileInfo) error {
func (c *sshCommand) checkCopyPermissions(fsSrc vfs.Fs, fsDst vfs.Fs, fsSourcePath, fsDestPath, sshSourcePath, sshDestPath string, info os.FileInfo) error {
if info.IsDir() {
return c.checkRecursiveCopyPermissions(fsSourcePath, fsDestPath, sshDestPath)
return c.checkRecursiveCopyPermissions(fsSrc, fsDst, fsSourcePath, fsDestPath, sshDestPath)
}
if !c.hasCopyPermissions(sshSourcePath, sshDestPath, info) {
return common.ErrPermissionDenied
return c.connection.GetPermissionDeniedError()
}
return nil
}
@ -620,24 +629,28 @@ func (c *sshCommand) getRemovePath() (string, error) {
return sshDestPath, nil
}
func (c *sshCommand) resolveCopyPaths(sshSourcePath, sshDestPath string) (string, string, error) {
fsSourcePath, err := c.connection.Fs.ResolvePath(sshSourcePath)
func (c *sshCommand) isLocalPath(virtualPath string) bool {
folder, err := c.connection.User.GetVirtualFolderForPath(virtualPath)
if err != nil {
return "", "", err
return c.connection.User.FsConfig.Provider == vfs.LocalFilesystemProvider
}
fsDestPath, err := c.connection.Fs.ResolvePath(sshDestPath)
if err != nil {
return "", "", err
}
return fsSourcePath, fsDestPath, nil
return folder.FsConfig.Provider == vfs.LocalFilesystemProvider
}
func (c *sshCommand) checkCopyDestination(fsDestPath string) error {
_, err := c.connection.Fs.Lstat(fsDestPath)
func (c *sshCommand) isLocalCopy(virtualSourcePath, virtualTargetPath string) bool {
if !c.isLocalPath(virtualSourcePath) {
return false
}
return c.isLocalPath(virtualTargetPath)
}
func (c *sshCommand) checkCopyDestination(fs vfs.Fs, fsDestPath string) error {
_, err := fs.Lstat(fsDestPath)
if err == nil {
err := errors.New("invalid copy destination: cannot overwrite an existing file or directory")
return err
} else if !c.connection.Fs.IsNotExist(err) {
} else if !fs.IsNotExist(err) {
return err
}
return nil
@ -667,18 +680,18 @@ func (c *sshCommand) checkCopyQuota(numFiles int, filesSize int64, requestPath s
return nil
}
func (c *sshCommand) getSizeForPath(name string) (int, int64, error) {
func (c *sshCommand) getSizeForPath(fs vfs.Fs, name string) (int, int64, error) {
if dataprovider.GetQuotaTracking() > 0 {
fi, err := c.connection.Fs.Lstat(name)
fi, err := fs.Lstat(name)
if err != nil {
if c.connection.Fs.IsNotExist(err) {
if fs.IsNotExist(err) {
return 0, 0, nil
}
c.connection.Log(logger.LevelDebug, "unable to stat %#v error: %v", name, err)
return 0, 0, err
}
if fi.IsDir() {
files, size, err := c.connection.Fs.GetDirSize(name)
files, size, err := fs.GetDirSize(name)
if err != nil {
c.connection.Log(logger.LevelDebug, "unable to get size for dir %#v error: %v", name, err)
}
@ -691,7 +704,7 @@ func (c *sshCommand) getSizeForPath(name string) (int, int64, error) {
}
func (c *sshCommand) sendErrorResponse(err error) error {
errorString := fmt.Sprintf("%v: %v %v\n", c.command, c.getDestPath(), c.connection.GetFsError(err))
errorString := fmt.Sprintf("%v: %v %v\n", c.command, c.getDestPath(), err)
c.connection.channel.Write([]byte(errorString)) //nolint:errcheck
c.sendExitStatus(err)
return err
@ -722,15 +735,15 @@ func (c *sshCommand) sendExitStatus(err error) {
// for scp we notify single uploads/downloads
if c.command != scpCmdName {
metrics.SSHCommandCompleted(err)
if len(cmdPath) > 0 {
p, e := c.connection.Fs.ResolvePath(cmdPath)
if e == nil {
if cmdPath != "" {
_, p, errFs := c.connection.GetFsAndResolvedPath(cmdPath)
if errFs == nil {
cmdPath = p
}
}
if len(targetPath) > 0 {
p, e := c.connection.Fs.ResolvePath(targetPath)
if e == nil {
if targetPath != "" {
_, p, errFs := c.connection.GetFsAndResolvedPath(targetPath)
if errFs == nil {
targetPath = p
}
}
@ -738,9 +751,9 @@ func (c *sshCommand) sendExitStatus(err error) {
}
}
func (c *sshCommand) computeHashForFile(hasher hash.Hash, path string) (string, error) {
func (c *sshCommand) computeHashForFile(fs vfs.Fs, hasher hash.Hash, path string) (string, error) {
hash := ""
f, r, _, err := c.connection.Fs.Open(path, 0)
f, r, _, err := fs.Open(path, 0)
if err != nil {
return hash, err
}

View file

@ -8,6 +8,7 @@ import (
"github.com/drakkan/sftpgo/common"
"github.com/drakkan/sftpgo/dataprovider"
"github.com/drakkan/sftpgo/logger"
)
type subsystemChannel struct {
@ -36,15 +37,16 @@ func newSubsystemChannel(reader io.Reader, writer io.Writer) *subsystemChannel {
// ServeSubSystemConnection handles a connection as SSH subsystem
func ServeSubSystemConnection(user *dataprovider.User, connectionID string, reader io.Reader, writer io.Writer) error {
fs, err := user.GetFilesystem(connectionID)
err := user.CheckFsRoot(connectionID)
if err != nil {
errClose := user.CloseFs()
logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose)
return err
}
fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID())
dataprovider.UpdateLastLogin(user) //nolint:errcheck
connection := &Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolSFTP, *user, fs),
BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolSFTP, *user),
ClientVersion: "",
RemoteAddr: &net.IPAddr{},
channel: newSubsystemChannel(reader, writer),

View file

@ -110,7 +110,7 @@ func (t *transfer) ReadAt(p []byte, off int64) (n int, err error) {
func (t *transfer) WriteAt(p []byte, off int64) (n int, err error) {
t.Connection.UpdateLastActivity()
if off < t.MinWriteOffset {
err := fmt.Errorf("Invalid write offset: %v minimum valid value: %v", off, t.MinWriteOffset)
err := fmt.Errorf("invalid write offset: %v minimum valid value: %v", off, t.MinWriteOffset)
t.TransferError(err)
return 0, err
}
@ -143,7 +143,7 @@ func (t *transfer) Close() error {
if errBaseClose != nil {
err = errBaseClose
}
return t.Connection.GetFsError(err)
return t.Connection.GetFsError(t.Fs, err)
}
func (t *transfer) closeIO() error {

View file

@ -63,13 +63,28 @@
<label for="idMappedPath" class="col-sm-2 col-form-label">Absolute Path</label>
<div class="col-sm-10">
<input type="text" class="form-control" id="idMappedPath" name="mapped_path" placeholder=""
value="{{.Folder.MappedPath}}" maxlength="512" autocomplete="nope" required>
value="{{.Folder.MappedPath}}" maxlength="512" aria-describedby="mappedPathHelpBlock">
<small id="descriptionHelpBlock" class="form-text text-muted">
Required for local providers. For Cloud providers, if set, it will store temporary files
</small>
</div>
</div>
{{template "fshtml" .Folder.FsConfig}}
<input type="hidden" name="_form_token" value="{{.CSRFToken}}">
<button type="submit" class="btn btn-primary float-right mt-3 px-5 px-3">{{if eq .Mode 3}}Generate and export folders{{else}}Submit{{end}}</button>
</form>
</div>
</div>
{{end}}
{{define "extra_js"}}
<script type="text/javascript">
$(document).ready(function () {
onFilesystemChanged('{{.Folder.FsConfig.Provider}}');
});
{{template "fsjs"}}
</script>
{{end}}

359
templates/fsconfig.html Normal file
View file

@ -0,0 +1,359 @@
{{define "fshtml"}}
<div class="form-group row">
<label for="idFilesystem" class="col-sm-2 col-form-label">Storage</label>
<div class="col-sm-10">
<select class="form-control" id="idFilesystem" name="fs_provider"
onchange="onFilesystemChanged(this.value)">
<option value="0" {{if eq .Provider 0 }}selected{{end}}>Local</option>
<option value="4" {{if eq .Provider 4 }}selected{{end}}>Local encrypted</option>
<option value="1" {{if eq .Provider 1 }}selected{{end}}>AWS S3 (Compatible)</option>
<option value="2" {{if eq .Provider 2 }}selected{{end}}>Google Cloud Storage</option>
<option value="3" {{if eq .Provider 3 }}selected{{end}}>Azure Blob Storage</option>
<option value="5" {{if eq .Provider 5 }}selected{{end}}>SFTP</option>
</select>
</div>
</div>
<div class="form-group row s3">
<label for="idS3Bucket" class="col-sm-2 col-form-label">Bucket</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idS3Bucket" name="s3_bucket" placeholder=""
value="{{.S3Config.Bucket}}" maxlength="255">
</div>
<div class="col-sm-2"></div>
<label for="idS3Region" class="col-sm-2 col-form-label">Region</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idS3Region" name="s3_region" placeholder=""
value="{{.S3Config.Region}}" maxlength="255">
</div>
</div>
<div class="form-group row s3">
<label for="idS3AccessKey" class="col-sm-2 col-form-label">Access Key</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idS3AccessKey" name="s3_access_key" placeholder=""
value="{{.S3Config.AccessKey}}" maxlength="255">
</div>
<div class="col-sm-2"></div>
<label for="idS3AccessSecret" class="col-sm-2 col-form-label">Access Secret</label>
<div class="col-sm-3">
<input type="password" class="form-control" id="idS3AccessSecret" name="s3_access_secret"
placeholder=""
value="{{if .S3Config.AccessSecret.IsEncrypted}}{{.RedactedSecret}}{{else}}{{.S3Config.AccessSecret.GetPayload}}{{end}}"
maxlength="1000">
</div>
</div>
<div class="form-group row s3">
<label for="idS3StorageClass" class="col-sm-2 col-form-label">Storage Class</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idS3StorageClass" name="s3_storage_class" placeholder=""
value="{{.S3Config.StorageClass}}" maxlength="255">
</div>
<div class="col-sm-2"></div>
<label for="idS3Endpoint" class="col-sm-2 col-form-label">Endpoint</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idS3Endpoint" name="s3_endpoint" placeholder=""
value="{{.S3Config.Endpoint}}" maxlength="255">
</div>
</div>
<div class="form-group row s3">
<label for="idS3PartSize" class="col-sm-2 col-form-label">UL Part Size (MB)</label>
<div class="col-sm-3">
<input type="number" class="form-control" id="idS3PartSize" name="s3_upload_part_size"
placeholder="" value="{{.S3Config.UploadPartSize}}"
aria-describedby="S3PartSizeHelpBlock">
<small id="S3PartSizeHelpBlock" class="form-text text-muted">
The buffer size for multipart uploads. Zero means the default (5 MB). Minimum is 5
</small>
</div>
<div class="col-sm-2"></div>
<label for="idS3UploadConcurrency" class="col-sm-2 col-form-label">UL Concurrency</label>
<div class="col-sm-3">
<input type="number" class="form-control" id="idS3UploadConcurrency" name="s3_upload_concurrency"
placeholder="" value="{{.S3Config.UploadConcurrency}}" min="0"
aria-describedby="S3ConcurrencyHelpBlock">
<small id="S3ConcurrencyHelpBlock" class="form-text text-muted">
How many parts are uploaded in parallel. Zero means the default (2)
</small>
</div>
</div>
<div class="form-group row s3">
<label for="idS3KeyPrefix" class="col-sm-2 col-form-label">Key Prefix</label>
<div class="col-sm-10">
<input type="text" class="form-control" id="idS3KeyPrefix" name="s3_key_prefix" placeholder=""
value="{{.S3Config.KeyPrefix}}" maxlength="255"
aria-describedby="S3KeyPrefixHelpBlock">
<small id="S3KeyPrefixHelpBlock" class="form-text text-muted">
Similar to a chroot for local filesystem. Cannot start with "/". Example: "somedir/subdir/".
</small>
</div>
</div>
<div class="form-group row gcs">
<label for="idGCSBucket" class="col-sm-2 col-form-label">Bucket</label>
<div class="col-sm-10">
<input type="text" class="form-control" id="idGCSBucket" name="gcs_bucket" placeholder=""
value="{{.GCSConfig.Bucket}}" maxlength="255">
</div>
</div>
<div class="form-group row gcs">
<label for="idGCSCredentialFile" class="col-sm-2 col-form-label">Credentials file</label>
<div class="col-sm-4">
<input type="file" class="form-control-file" id="idGCSCredentialFile" name="gcs_credential_file"
aria-describedby="GCSCredentialsHelpBlock">
<small id="GCSCredentialsHelpBlock" class="form-text text-muted">
Add or update credentials from a JSON file
</small>
</div>
<div class="col-sm-1"></div>
<label for="idGCSStorageClass" class="col-sm-2 col-form-label">Storage Class</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idGCSStorageClass" name="gcs_storage_class"
placeholder="" value="{{.GCSConfig.StorageClass}}" maxlength="255">
</div>
</div>
<div class="form-group gcs">
<div class="form-check">
<input type="checkbox" class="form-check-input" id="idGCSAutoCredentials"
name="gcs_auto_credentials" {{if gt .GCSConfig.AutomaticCredentials 0}}checked{{end}}>
<label for="idGCSAutoCredentials" class="form-check-label">Automatic credentials</label>
</div>
</div>
<div class="form-group row gcs">
<label for="idGCSKeyPrefix" class="col-sm-2 col-form-label">Key Prefix</label>
<div class="col-sm-10">
<input type="text" class="form-control" id="idGCSKeyPrefix" name="gcs_key_prefix" placeholder=""
value="{{.GCSConfig.KeyPrefix}}" maxlength="255"
aria-describedby="GCSKeyPrefixHelpBlock">
<small id="GCSKeyPrefixHelpBlock" class="form-text text-muted">
Similar to a chroot for local filesystem. Cannot start with "/". Example: "somedir/subdir/".
</small>
</div>
</div>
<div class="form-group row azblob">
<label for="idAzContainer" class="col-sm-2 col-form-label">Container</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idAzContainer" name="az_container" placeholder=""
value="{{.AzBlobConfig.Container}}" maxlength="255">
</div>
<div class="col-sm-2"></div>
<label for="idAzAccountName" class="col-sm-2 col-form-label">Account Name</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idAzAccountName" name="az_account_name" placeholder=""
value="{{.AzBlobConfig.AccountName}}" maxlength="255">
</div>
</div>
<div class="form-group row azblob">
<label for="idAzAccountKey" class="col-sm-2 col-form-label">Account Key</label>
<div class="col-sm-10">
<input type="password" class="form-control" id="idAzAccountKey" name="az_account_key" placeholder=""
value="{{if .AzBlobConfig.AccountKey.IsEncrypted}}{{.RedactedSecret}}{{else}}{{.AzBlobConfig.AccountKey.GetPayload}}{{end}}"
maxlength="1000">
</div>
</div>
<div class="form-group row azblob">
<label for="idAzSASURL" class="col-sm-2 col-form-label">SAS URL</label>
<div class="col-sm-10">
<input type="text" class="form-control" id="idAzSASURL" name="az_sas_url" placeholder=""
value="{{.AzBlobConfig.SASURL}}" maxlength="255">
</div>
</div>
<div class="form-group row azblob">
<label for="idAzEndpoint" class="col-sm-2 col-form-label">Endpoint</label>
<div class="col-sm-10">
<input type="text" class="form-control" id="idAzEndpoint" name="az_endpoint" placeholder=""
value="{{.AzBlobConfig.Endpoint}}" maxlength="255">
</div>
</div>
<div class="form-group row azblob">
<label for="idAzAccessTier" class="col-sm-2 col-form-label">Access Tier</label>
<div class="col-sm-10">
<select class="form-control" id="idAzAccessTier" name="az_access_tier">
<option value="" {{if eq .AzBlobConfig.AccessTier "" }}selected{{end}}>Default</option>
<option value="Hot" {{if eq .AzBlobConfig.AccessTier "Hot" }}selected{{end}}>Hot</option>
<option value="Cool" {{if eq .AzBlobConfig.AccessTier "Cool" }}selected{{end}}>Cool</option>
<option value="Archive" {{if eq .AzBlobConfig.AccessTier "Archive"}}selected{{end}}>Archive</option>
</select>
</div>
</div>
<div class="form-group row azblob">
<label for="idAzPartSize" class="col-sm-2 col-form-label">UL Part Size (MB)</label>
<div class="col-sm-3">
<input type="number" class="form-control" id="idAzPartSize" name="az_upload_part_size"
placeholder="" value="{{.AzBlobConfig.UploadPartSize}}"
aria-describedby="AzPartSizeHelpBlock">
<small id="AzPartSizeHelpBlock" class="form-text text-muted">
The buffer size for multipart uploads. Zero means the default (4 MB)
</small>
</div>
<div class="col-sm-2"></div>
<label for="idAzUploadConcurrency" class="col-sm-2 col-form-label">UL Concurrency</label>
<div class="col-sm-3">
<input type="number" class="form-control" id="idAzUploadConcurrency" name="az_upload_concurrency"
placeholder="" value="{{.AzBlobConfig.UploadConcurrency}}" min="0"
aria-describedby="AzConcurrencyHelpBlock">
<small id="AzConcurrencyHelpBlock" class="form-text text-muted">
How many parts are uploaded in parallel. Zero means the default (2)
</small>
</div>
</div>
<div class="form-group row azblob">
<label for="idAzKeyPrefix" class="col-sm-2 col-form-label">Key Prefix</label>
<div class="col-sm-10">
<input type="text" class="form-control" id="idAzKeyPrefix" name="az_key_prefix" placeholder=""
value="{{.AzBlobConfig.KeyPrefix}}" maxlength="255"
aria-describedby="AzKeyPrefixHelpBlock">
<small id="AzKeyPrefixHelpBlock" class="form-text text-muted">
Similar to a chroot for local filesystem. Cannot start with "/". Example: "somedir/subdir/".
</small>
</div>
</div>
<div class="form-group azblob">
<div class="form-check">
<input type="checkbox" class="form-check-input" id="idUseEmulator" name="az_use_emulator" {{if .AzBlobConfig.UseEmulator}}checked{{end}}>
<label for="idUseEmulator" class="form-check-label">Use Azure Blob emulator</label>
</div>
</div>
<div class="form-group row crypt">
<label for="idCryptPassphrase" class="col-sm-2 col-form-label">Passphrase</label>
<div class="col-sm-10">
<input type="password" class="form-control" id="idCryptPassphrase" name="crypt_passphrase"
placeholder=""
value="{{if .CryptConfig.Passphrase.IsEncrypted}}{{.RedactedSecret}}{{else}}{{.CryptConfig.Passphrase.GetPayload}}{{end}}"
maxlength="1000">
</div>
</div>
<div class="form-group row sftp">
<label for="idSFTPEndpoint" class="col-sm-2 col-form-label">Endpoint</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idSFTPEndpoint" name="sftp_endpoint" placeholder=""
value="{{.SFTPConfig.Endpoint}}" maxlength="255" aria-describedby="SFTPEndpointHelpBlock">
<small id="SFTPEndpointHelpBlock" class="form-text text-muted">
Endpoint as host:port, port is always required
</small>
</div>
<div class="col-sm-2"></div>
<label for="idSFTPUsername" class="col-sm-2 col-form-label">Username</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idSFTPUsername" name="sftp_username" placeholder=""
value="{{.SFTPConfig.Username}}" maxlength="255">
</div>
</div>
<div class="form-group row sftp">
<label for="idSFTPPassword" class="col-sm-2 col-form-label">Password</label>
<div class="col-sm-10">
<input type="password" class="form-control" id="idSFTPPassword" name="sftp_password" placeholder=""
value="{{if .SFTPConfig.Password.IsEncrypted}}{{.RedactedSecret}}{{else}}{{.SFTPConfig.Password.GetPayload}}{{end}}"
maxlength="1000">
</div>
</div>
<div class="form-group row sftp">
<label for="idSFTPPrivateKey" class="col-sm-2 col-form-label">Private key</label>
<div class="col-sm-10">
<textarea type="password" class="form-control" id="idSFTPPrivateKey" name="sftp_private_key"
rows="3">{{if .SFTPConfig.PrivateKey.IsEncrypted}}{{.RedactedSecret}}{{else}}{{.SFTPConfig.PrivateKey.GetPayload}}{{end}}</textarea>
</div>
</div>
<div class="form-group row sftp">
<label for="idSFTPFingerprints" class="col-sm-2 col-form-label">Fingerprints</label>
<div class="col-sm-10">
<textarea class="form-control" id="idSFTPFingerprints" name="sftp_fingerprints" rows="3"
aria-describedby="SFTPFingerprintsHelpBlock">{{range .SFTPConfig.Fingerprints}}{{.}}&#10;{{end}}</textarea>
<small id="SFTPFingerprintsHelpBlock" class="form-text text-muted">
SHA256 fingerprints to validate when connecting to the external SFTP server, one per line. If
empty any host key will be accepted: this is a security risk!
</small>
</div>
</div>
<div class="form-group row sftp">
<label for="idSFTPPrefix" class="col-sm-2 col-form-label">Prefix</label>
<div class="col-sm-10">
<input type="text" class="form-control" id="idSFTPPrefix" name="sftp_prefix" placeholder=""
value="{{.SFTPConfig.Prefix}}" maxlength="255"
aria-describedby="SFTPPrefixHelpBlock">
<small id="SFTPPrefixHelpBlock" class="form-text text-muted">
Similar to a chroot for local filesystem. Example: "/somedir/subdir".
</small>
</div>
</div>
<div class="form-group sftp">
<div class="form-check">
<input type="checkbox" class="form-check-input" id="idDisableConcurrentReads" name="sftp_disable_concurrent_reads" {{if .SFTPConfig.DisableCouncurrentReads}}checked{{end}}>
<label for="idDisableConcurrentReads" class="form-check-label">Disable concurrent reads</label>
</div>
</div>
{{end}}
{{define "fsjs"}}
function onFilesystemChanged(val){
if (val == '1'){
$('.form-group.row.gcs').hide();
$('.form-group.gcs').hide();
$('.form-group.row.azblob').hide();
$('.form-group.azblob').hide();
$('.form-group.crypt').hide();
$('.form-group.sftp').hide();
$('.form-group.row.s3').show();
} else if (val == '2'){
$('.form-group.row.gcs').show();
$('.form-group.gcs').show();
$('.form-group.row.azblob').hide();
$('.form-group.azblob').hide();
$('.form-group.crypt').hide();
$('.form-group.row.s3').hide();
$('.form-group.sftp').hide();
} else if (val == '3'){
$('.form-group.row.azblob').show();
$('.form-group.azblob').show();
$('.form-group.row.gcs').hide();
$('.form-group.gcs').hide();
$('.form-group.crypt').hide();
$('.form-group.row.s3').hide();
$('.form-group.sftp').hide();
} else if (val == '4'){
$('.form-group.row.gcs').hide();
$('.form-group.gcs').hide();
$('.form-group.row.s3').hide();
$('.form-group.row.azblob').hide();
$('.form-group.azblob').hide();
$('.form-group.crypt').show();
$('.form-group.sftp').hide();
} else if (val == '5'){
$('.form-group.row.gcs').hide();
$('.form-group.gcs').hide();
$('.form-group.row.s3').hide();
$('.form-group.row.azblob').hide();
$('.form-group.azblob').hide();
$('.form-group.crypt').hide();
$('.form-group.sftp').show();
} else {
$('.form-group.row.gcs').hide();
$('.form-group.gcs').hide();
$('.form-group.row.s3').hide();
$('.form-group.row.azblob').hide();
$('.form-group.azblob').hide();
$('.form-group.crypt').hide();
$('.form-group.sftp').hide();
}
}
{{end}}

View file

@ -360,309 +360,7 @@
</div>
</div>
<div class="form-group row">
<label for="idFilesystem" class="col-sm-2 col-form-label">Storage</label>
<div class="col-sm-10">
<select class="form-control" id="idFilesystem" name="fs_provider"
onchange="onFilesystemChanged(this.value)">
<option value="0" {{if eq .User.FsConfig.Provider 0 }}selected{{end}}>Local</option>
<option value="4" {{if eq .User.FsConfig.Provider 4 }}selected{{end}}>Local encrypted</option>
<option value="1" {{if eq .User.FsConfig.Provider 1 }}selected{{end}}>AWS S3 (Compatible)</option>
<option value="2" {{if eq .User.FsConfig.Provider 2 }}selected{{end}}>Google Cloud Storage</option>
<option value="3" {{if eq .User.FsConfig.Provider 3 }}selected{{end}}>Azure Blob Storage</option>
<option value="5" {{if eq .User.FsConfig.Provider 5 }}selected{{end}}>SFTP</option>
</select>
</div>
</div>
<div class="form-group row s3">
<label for="idS3Bucket" class="col-sm-2 col-form-label">Bucket</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idS3Bucket" name="s3_bucket" placeholder=""
value="{{.User.FsConfig.S3Config.Bucket}}" maxlength="255">
</div>
<div class="col-sm-2"></div>
<label for="idS3Region" class="col-sm-2 col-form-label">Region</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idS3Region" name="s3_region" placeholder=""
value="{{.User.FsConfig.S3Config.Region}}" maxlength="255">
</div>
</div>
<div class="form-group row s3">
<label for="idS3AccessKey" class="col-sm-2 col-form-label">Access Key</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idS3AccessKey" name="s3_access_key" placeholder=""
value="{{.User.FsConfig.S3Config.AccessKey}}" maxlength="255">
</div>
<div class="col-sm-2"></div>
<label for="idS3AccessSecret" class="col-sm-2 col-form-label">Access Secret</label>
<div class="col-sm-3">
<input type="password" class="form-control" id="idS3AccessSecret" name="s3_access_secret"
placeholder=""
value="{{if .User.FsConfig.S3Config.AccessSecret.IsEncrypted}}{{.RedactedSecret}}{{else}}{{.User.FsConfig.S3Config.AccessSecret.GetPayload}}{{end}}"
maxlength="1000">
</div>
</div>
<div class="form-group row s3">
<label for="idS3StorageClass" class="col-sm-2 col-form-label">Storage Class</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idS3StorageClass" name="s3_storage_class" placeholder=""
value="{{.User.FsConfig.S3Config.StorageClass}}" maxlength="255">
</div>
<div class="col-sm-2"></div>
<label for="idS3Endpoint" class="col-sm-2 col-form-label">Endpoint</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idS3Endpoint" name="s3_endpoint" placeholder=""
value="{{.User.FsConfig.S3Config.Endpoint}}" maxlength="255">
</div>
</div>
<div class="form-group row s3">
<label for="idS3PartSize" class="col-sm-2 col-form-label">UL Part Size (MB)</label>
<div class="col-sm-3">
<input type="number" class="form-control" id="idS3PartSize" name="s3_upload_part_size"
placeholder="" value="{{.User.FsConfig.S3Config.UploadPartSize}}"
aria-describedby="S3PartSizeHelpBlock">
<small id="S3PartSizeHelpBlock" class="form-text text-muted">
The buffer size for multipart uploads. Zero means the default (5 MB). Minimum is 5
</small>
</div>
<div class="col-sm-2"></div>
<label for="idS3UploadConcurrency" class="col-sm-2 col-form-label">UL Concurrency</label>
<div class="col-sm-3">
<input type="number" class="form-control" id="idS3UploadConcurrency" name="s3_upload_concurrency"
placeholder="" value="{{.User.FsConfig.S3Config.UploadConcurrency}}" min="0"
aria-describedby="S3ConcurrencyHelpBlock">
<small id="S3ConcurrencyHelpBlock" class="form-text text-muted">
How many parts are uploaded in parallel. Zero means the default (2)
</small>
</div>
</div>
<div class="form-group row s3">
<label for="idS3KeyPrefix" class="col-sm-2 col-form-label">Key Prefix</label>
<div class="col-sm-10">
<input type="text" class="form-control" id="idS3KeyPrefix" name="s3_key_prefix" placeholder=""
value="{{.User.FsConfig.S3Config.KeyPrefix}}" maxlength="255"
aria-describedby="S3KeyPrefixHelpBlock">
<small id="S3KeyPrefixHelpBlock" class="form-text text-muted">
Similar to a chroot for local filesystem. Cannot start with "/". Example: "somedir/subdir/".
</small>
</div>
</div>
<div class="form-group row gcs">
<label for="idGCSBucket" class="col-sm-2 col-form-label">Bucket</label>
<div class="col-sm-10">
<input type="text" class="form-control" id="idGCSBucket" name="gcs_bucket" placeholder=""
value="{{.User.FsConfig.GCSConfig.Bucket}}" maxlength="255">
</div>
</div>
<div class="form-group row gcs">
<label for="idGCSCredentialFile" class="col-sm-2 col-form-label">Credentials file</label>
<div class="col-sm-4">
<input type="file" class="form-control-file" id="idGCSCredentialFile" name="gcs_credential_file"
aria-describedby="GCSCredentialsHelpBlock">
<small id="GCSCredentialsHelpBlock" class="form-text text-muted">
Add or update credentials from a JSON file
</small>
</div>
<div class="col-sm-1"></div>
<label for="idGCSStorageClass" class="col-sm-2 col-form-label">Storage Class</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idGCSStorageClass" name="gcs_storage_class"
placeholder="" value="{{.User.FsConfig.GCSConfig.StorageClass}}" maxlength="255">
</div>
</div>
<div class="form-group gcs">
<div class="form-check">
<input type="checkbox" class="form-check-input" id="idGCSAutoCredentials"
name="gcs_auto_credentials" {{if gt .User.FsConfig.GCSConfig.AutomaticCredentials 0}}checked{{end}}>
<label for="idGCSAutoCredentials" class="form-check-label">Automatic credentials</label>
</div>
</div>
<div class="form-group row gcs">
<label for="idGCSKeyPrefix" class="col-sm-2 col-form-label">Key Prefix</label>
<div class="col-sm-10">
<input type="text" class="form-control" id="idGCSKeyPrefix" name="gcs_key_prefix" placeholder=""
value="{{.User.FsConfig.GCSConfig.KeyPrefix}}" maxlength="255"
aria-describedby="GCSKeyPrefixHelpBlock">
<small id="GCSKeyPrefixHelpBlock" class="form-text text-muted">
Similar to a chroot for local filesystem. Cannot start with "/". Example: "somedir/subdir/".
</small>
</div>
</div>
<div class="form-group row azblob">
<label for="idAzContainer" class="col-sm-2 col-form-label">Container</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idAzContainer" name="az_container" placeholder=""
value="{{.User.FsConfig.AzBlobConfig.Container}}" maxlength="255">
</div>
<div class="col-sm-2"></div>
<label for="idAzAccountName" class="col-sm-2 col-form-label">Account Name</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idAzAccountName" name="az_account_name" placeholder=""
value="{{.User.FsConfig.AzBlobConfig.AccountName}}" maxlength="255">
</div>
</div>
<div class="form-group row azblob">
<label for="idAzAccountKey" class="col-sm-2 col-form-label">Account Key</label>
<div class="col-sm-10">
<input type="password" class="form-control" id="idAzAccountKey" name="az_account_key" placeholder=""
value="{{if .User.FsConfig.AzBlobConfig.AccountKey.IsEncrypted}}{{.RedactedSecret}}{{else}}{{.User.FsConfig.AzBlobConfig.AccountKey.GetPayload}}{{end}}"
maxlength="1000">
</div>
</div>
<div class="form-group row azblob">
<label for="idAzSASURL" class="col-sm-2 col-form-label">SAS URL</label>
<div class="col-sm-10">
<input type="text" class="form-control" id="idAzSASURL" name="az_sas_url" placeholder=""
value="{{.User.FsConfig.AzBlobConfig.SASURL}}" maxlength="255">
</div>
</div>
<div class="form-group row azblob">
<label for="idAzEndpoint" class="col-sm-2 col-form-label">Endpoint</label>
<div class="col-sm-10">
<input type="text" class="form-control" id="idAzEndpoint" name="az_endpoint" placeholder=""
value="{{.User.FsConfig.AzBlobConfig.Endpoint}}" maxlength="255">
</div>
</div>
<div class="form-group row azblob">
<label for="idAzAccessTier" class="col-sm-2 col-form-label">Access Tier</label>
<div class="col-sm-10">
<select class="form-control" id="idAzAccessTier" name="az_access_tier">
<option value="" {{if eq .User.FsConfig.AzBlobConfig.AccessTier "" }}selected{{end}}>Default</option>
<option value="Hot" {{if eq .User.FsConfig.AzBlobConfig.AccessTier "Hot" }}selected{{end}}>Hot</option>
<option value="Cool" {{if eq .User.FsConfig.AzBlobConfig.AccessTier "Cool" }}selected{{end}}>Cool</option>
<option value="Archive" {{if eq .User.FsConfig.AzBlobConfig.AccessTier "Archive"}}selected{{end}}>Archive</option>
</select>
</div>
</div>
<div class="form-group row azblob">
<label for="idAzPartSize" class="col-sm-2 col-form-label">UL Part Size (MB)</label>
<div class="col-sm-3">
<input type="number" class="form-control" id="idAzPartSize" name="az_upload_part_size"
placeholder="" value="{{.User.FsConfig.AzBlobConfig.UploadPartSize}}"
aria-describedby="AzPartSizeHelpBlock">
<small id="AzPartSizeHelpBlock" class="form-text text-muted">
The buffer size for multipart uploads. Zero means the default (4 MB)
</small>
</div>
<div class="col-sm-2"></div>
<label for="idAzUploadConcurrency" class="col-sm-2 col-form-label">UL Concurrency</label>
<div class="col-sm-3">
<input type="number" class="form-control" id="idAzUploadConcurrency" name="az_upload_concurrency"
placeholder="" value="{{.User.FsConfig.AzBlobConfig.UploadConcurrency}}" min="0"
aria-describedby="AzConcurrencyHelpBlock">
<small id="AzConcurrencyHelpBlock" class="form-text text-muted">
How many parts are uploaded in parallel. Zero means the default (2)
</small>
</div>
</div>
<div class="form-group row azblob">
<label for="idAzKeyPrefix" class="col-sm-2 col-form-label">Key Prefix</label>
<div class="col-sm-10">
<input type="text" class="form-control" id="idAzKeyPrefix" name="az_key_prefix" placeholder=""
value="{{.User.FsConfig.AzBlobConfig.KeyPrefix}}" maxlength="255"
aria-describedby="AzKeyPrefixHelpBlock">
<small id="AzKeyPrefixHelpBlock" class="form-text text-muted">
Similar to a chroot for local filesystem. Cannot start with "/". Example: "somedir/subdir/".
</small>
</div>
</div>
<div class="form-group azblob">
<div class="form-check">
<input type="checkbox" class="form-check-input" id="idUseEmulator" name="az_use_emulator" {{if .User.FsConfig.AzBlobConfig.UseEmulator}}checked{{end}}>
<label for="idUseEmulator" class="form-check-label">Use Azure Blob emulator</label>
</div>
</div>
<div class="form-group row crypt">
<label for="idCryptPassphrase" class="col-sm-2 col-form-label">Passphrase</label>
<div class="col-sm-10">
<input type="password" class="form-control" id="idCryptPassphrase" name="crypt_passphrase"
placeholder=""
value="{{if .User.FsConfig.CryptConfig.Passphrase.IsEncrypted}}{{.RedactedSecret}}{{else}}{{.User.FsConfig.CryptConfig.Passphrase.GetPayload}}{{end}}"
maxlength="1000">
</div>
</div>
<div class="form-group row sftp">
<label for="idSFTPEndpoint" class="col-sm-2 col-form-label">Endpoint</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idSFTPEndpoint" name="sftp_endpoint" placeholder=""
value="{{.User.FsConfig.SFTPConfig.Endpoint}}" maxlength="255" aria-describedby="SFTPEndpointHelpBlock">
<small id="SFTPEndpointHelpBlock" class="form-text text-muted">
Endpoint as host:port, port is always required
</small>
</div>
<div class="col-sm-2"></div>
<label for="idSFTPUsername" class="col-sm-2 col-form-label">Username</label>
<div class="col-sm-3">
<input type="text" class="form-control" id="idSFTPUsername" name="sftp_username" placeholder=""
value="{{.User.FsConfig.SFTPConfig.Username}}" maxlength="255">
</div>
</div>
<div class="form-group sftp">
<div class="form-check">
<input type="checkbox" class="form-check-input" id="idDisableConcurrentReads" name="sftp_disable_concurrent_reads" {{if .User.FsConfig.SFTPConfig.DisableCouncurrentReads}}checked{{end}}>
<label for="idDisableConcurrentReads" class="form-check-label">Disable concurrent reads</label>
</div>
</div>
<div class="form-group row sftp">
<label for="idSFTPPassword" class="col-sm-2 col-form-label">Password</label>
<div class="col-sm-10">
<input type="password" class="form-control" id="idSFTPPassword" name="sftp_password" placeholder=""
value="{{if .User.FsConfig.SFTPConfig.Password.IsEncrypted}}{{.RedactedSecret}}{{else}}{{.User.FsConfig.SFTPConfig.Password.GetPayload}}{{end}}"
maxlength="1000">
</div>
</div>
<div class="form-group row sftp">
<label for="idSFTPPrivateKey" class="col-sm-2 col-form-label">Private key</label>
<div class="col-sm-10">
<textarea type="password" class="form-control" id="idSFTPPrivateKey" name="sftp_private_key"
rows="3">{{if .User.FsConfig.SFTPConfig.PrivateKey.IsEncrypted}}{{.RedactedSecret}}{{else}}{{.User.FsConfig.SFTPConfig.PrivateKey.GetPayload}}{{end}}</textarea>
</div>
</div>
<div class="form-group row sftp">
<label for="idSFTPFingerprints" class="col-sm-2 col-form-label">Fingerprints</label>
<div class="col-sm-10">
<textarea class="form-control" id="idSFTPFingerprints" name="sftp_fingerprints" rows="3"
aria-describedby="SFTPFingerprintsHelpBlock">{{range .User.FsConfig.SFTPConfig.Fingerprints}}{{.}}&#10;{{end}}</textarea>
<small id="SFTPFingerprintsHelpBlock" class="form-text text-muted">
SHA256 fingerprints to validate when connecting to the external SFTP server, one per line. If
empty any host key will be accepted: this is a security risk!
</small>
</div>
</div>
<div class="form-group row sftp">
<label for="idSFTPPrefix" class="col-sm-2 col-form-label">Prefix</label>
<div class="col-sm-10">
<input type="text" class="form-control" id="idSFTPPrefix" name="sftp_prefix" placeholder=""
value="{{.User.FsConfig.SFTPConfig.Prefix}}" maxlength="255"
aria-describedby="SFTPPrefixHelpBlock">
<small id="SFTPPrefixHelpBlock" class="form-text text-muted">
Similar to a chroot for local filesystem. Example: "/somedir/subdir".
</small>
</div>
</div>
{{template "fshtml" .User.FsConfig}}
<div class="form-group row">
<label for="idAdditionalInfo" class="col-sm-2 col-form-label">Additional info</label>
@ -737,56 +435,6 @@
});
function onFilesystemChanged(val){
if (val == '1'){
$('.form-group.row.gcs').hide();
$('.form-group.gcs').hide();
$('.form-group.row.azblob').hide();
$('.form-group.azblob').hide();
$('.form-group.crypt').hide();
$('.form-group.sftp').hide();
$('.form-group.row.s3').show();
} else if (val == '2'){
$('.form-group.row.gcs').show();
$('.form-group.gcs').show();
$('.form-group.row.azblob').hide();
$('.form-group.azblob').hide();
$('.form-group.crypt').hide();
$('.form-group.row.s3').hide();
$('.form-group.sftp').hide();
} else if (val == '3'){
$('.form-group.row.azblob').show();
$('.form-group.azblob').show();
$('.form-group.row.gcs').hide();
$('.form-group.gcs').hide();
$('.form-group.crypt').hide();
$('.form-group.row.s3').hide();
$('.form-group.sftp').hide();
} else if (val == '4'){
$('.form-group.row.gcs').hide();
$('.form-group.gcs').hide();
$('.form-group.row.s3').hide();
$('.form-group.row.azblob').hide();
$('.form-group.azblob').hide();
$('.form-group.crypt').show();
$('.form-group.sftp').hide();
} else if (val == '5'){
$('.form-group.row.gcs').hide();
$('.form-group.gcs').hide();
$('.form-group.row.s3').hide();
$('.form-group.row.azblob').hide();
$('.form-group.azblob').hide();
$('.form-group.crypt').hide();
$('.form-group.sftp').show();
} else {
$('.form-group.row.gcs').hide();
$('.form-group.gcs').hide();
$('.form-group.row.s3').hide();
$('.form-group.row.azblob').hide();
$('.form-group.azblob').hide();
$('.form-group.crypt').hide();
$('.form-group.sftp').hide();
}
}
{{template "fsjs"}}
</script>
{{end}}

View file

@ -312,18 +312,24 @@ func GenerateEd25519Keys(file string) error {
return os.WriteFile(file+".pub", ssh.MarshalAuthorizedKey(pub), 0600)
}
// GetDirsForSFTPPath returns all the directory for the given path in reverse order
// GetDirsForVirtualPath returns all the directory for the given path in reverse order
// for example if the path is: /1/2/3/4 it returns:
// [ "/1/2/3/4", "/1/2/3", "/1/2", "/1", "/" ]
func GetDirsForSFTPPath(p string) []string {
sftpPath := CleanPath(p)
dirsForPath := []string{sftpPath}
func GetDirsForVirtualPath(virtualPath string) []string {
if virtualPath == "." {
virtualPath = "/"
} else {
if !path.IsAbs(virtualPath) {
virtualPath = CleanPath(virtualPath)
}
}
dirsForPath := []string{virtualPath}
for {
if sftpPath == "/" {
if virtualPath == "/" {
break
}
sftpPath = path.Dir(sftpPath)
dirsForPath = append(dirsForPath, sftpPath)
virtualPath = path.Dir(virtualPath)
dirsForPath = append(dirsForPath, virtualPath)
}
return dirsForPath
}

View file

@ -36,8 +36,10 @@ var maxTryTimeout = time.Hour * 24 * 365
// AzureBlobFs is a Fs implementation for Azure Blob storage.
type AzureBlobFs struct {
connectionID string
localTempDir string
connectionID string
localTempDir string
// if not empty this fs is mouted as virtual folder in the specified path
mountPath string
config *AzBlobFsConfig
svc *azblob.ServiceURL
containerURL azblob.ContainerURL
@ -50,10 +52,14 @@ func init() {
}
// NewAzBlobFs returns an AzBlobFs object that allows to interact with Azure Blob storage
func NewAzBlobFs(connectionID, localTempDir string, config AzBlobFsConfig) (Fs, error) {
func NewAzBlobFs(connectionID, localTempDir, mountPath string, config AzBlobFsConfig) (Fs, error) {
if localTempDir == "" {
localTempDir = filepath.Clean(os.TempDir())
}
fs := &AzureBlobFs{
connectionID: connectionID,
localTempDir: localTempDir,
mountPath: mountPath,
config: &config,
ctxTimeout: 30 * time.Second,
ctxLongTimeout: 300 * time.Second,
@ -89,7 +95,7 @@ func NewAzBlobFs(connectionID, localTempDir string, config AzBlobFsConfig) (Fs,
parts := azblob.NewBlobURLParts(*u)
if parts.ContainerName != "" {
if fs.config.Container != "" && fs.config.Container != parts.ContainerName {
return fs, fmt.Errorf("Container name in SAS URL %#v and container provided %#v do not match",
return fs, fmt.Errorf("container name in SAS URL %#v and container provided %#v do not match",
parts.ContainerName, fs.config.Container)
}
fs.svc = nil
@ -282,7 +288,7 @@ func (fs *AzureBlobFs) Rename(source, target string) error {
return err
}
if hasContents {
return fmt.Errorf("Cannot rename non empty directory: %#v", source)
return fmt.Errorf("cannot rename non empty directory: %#v", source)
}
}
dstBlobURL := fs.containerURL.NewBlobURL(target)
@ -318,7 +324,7 @@ func (fs *AzureBlobFs) Rename(source, target string) error {
}
}
if copyStatus != azblob.CopyStatusSuccess {
err := fmt.Errorf("Copy failed with status: %s", copyStatus)
err := fmt.Errorf("copy failed with status: %s", copyStatus)
metrics.AZCopyObjectCompleted(err)
return err
}
@ -334,7 +340,7 @@ func (fs *AzureBlobFs) Remove(name string, isDir bool) error {
return err
}
if hasContents {
return fmt.Errorf("Cannot remove non empty directory: %#v", name)
return fmt.Errorf("cannot remove non empty directory: %#v", name)
}
}
blobBlockURL := fs.containerURL.NewBlockBlobURL(name)
@ -359,6 +365,11 @@ func (fs *AzureBlobFs) Mkdir(name string) error {
return w.Close()
}
// MkdirAll does nothing, we don't have folder
func (*AzureBlobFs) MkdirAll(name string, uid int, gid int) error {
return nil
}
// Symlink creates source as a symbolic link to target.
func (*AzureBlobFs) Symlink(source, target string) error {
return ErrVfsUnsupported
@ -528,7 +539,7 @@ func (*AzureBlobFs) IsNotSupported(err error) bool {
// CheckRootPath creates the specified local root directory if it does not exists
func (fs *AzureBlobFs) CheckRootPath(username string, uid int, gid int) bool {
// we need a local directory for temporary files
osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, nil)
osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "")
return osFs.CheckRootPath(username, uid, gid)
}
@ -607,6 +618,9 @@ func (fs *AzureBlobFs) GetRelativePath(name string) string {
}
rel = path.Clean("/" + strings.TrimPrefix(rel, "/"+fs.config.KeyPrefix))
}
if fs.mountPath != "" {
rel = path.Join(fs.mountPath, rel)
}
return rel
}
@ -675,6 +689,9 @@ func (*AzureBlobFs) HasVirtualFolders() bool {
// ResolvePath returns the matching filesystem path for the specified sftp path
func (fs *AzureBlobFs) ResolvePath(virtualPath string) (string, error) {
if fs.mountPath != "" {
virtualPath = strings.TrimPrefix(virtualPath, fs.mountPath)
}
if !path.IsAbs(virtualPath) {
virtualPath = path.Clean("/" + virtualPath)
}

View file

@ -13,6 +13,6 @@ func init() {
}
// NewAzBlobFs returns an error, Azure Blob storage is disabled
func NewAzBlobFs(connectionID, localTempDir string, config AzBlobFsConfig) (Fs, error) {
func NewAzBlobFs(connectionID, localTempDir, mountPath string, config AzBlobFsConfig) (Fs, error) {
return nil, errors.New("Azure Blob Storage disabled at build time")
}

View file

@ -31,7 +31,7 @@ type CryptFs struct {
}
// NewCryptFs returns a CryptFs object
func NewCryptFs(connectionID, rootDir string, config CryptFsConfig) (Fs, error) {
func NewCryptFs(connectionID, rootDir, mountPath string, config CryptFsConfig) (Fs, error) {
if err := config.Validate(); err != nil {
return nil, err
}
@ -42,10 +42,10 @@ func NewCryptFs(connectionID, rootDir string, config CryptFsConfig) (Fs, error)
}
fs := &CryptFs{
OsFs: &OsFs{
name: cryptFsName,
connectionID: connectionID,
rootDir: rootDir,
virtualFolders: nil,
name: cryptFsName,
connectionID: connectionID,
rootDir: rootDir,
mountPath: mountPath,
},
masterKey: []byte(config.Passphrase.GetPayload()),
}

104
vfs/filesystem.go Normal file
View file

@ -0,0 +1,104 @@
package vfs
import "github.com/drakkan/sftpgo/kms"
// FilesystemProvider defines the supported storage filesystems
type FilesystemProvider int
// supported values for FilesystemProvider
const (
LocalFilesystemProvider FilesystemProvider = iota // Local
S3FilesystemProvider // AWS S3 compatible
GCSFilesystemProvider // Google Cloud Storage
AzureBlobFilesystemProvider // Azure Blob Storage
CryptedFilesystemProvider // Local encrypted
SFTPFilesystemProvider // SFTP
)
// Filesystem defines cloud storage filesystem details
type Filesystem struct {
RedactedSecret string `json:"-"`
Provider FilesystemProvider `json:"provider"`
S3Config S3FsConfig `json:"s3config,omitempty"`
GCSConfig GCSFsConfig `json:"gcsconfig,omitempty"`
AzBlobConfig AzBlobFsConfig `json:"azblobconfig,omitempty"`
CryptConfig CryptFsConfig `json:"cryptconfig,omitempty"`
SFTPConfig SFTPFsConfig `json:"sftpconfig,omitempty"`
}
// SetEmptySecretsIfNil sets the secrets to empty if nil
func (f *Filesystem) SetEmptySecretsIfNil() {
if f.S3Config.AccessSecret == nil {
f.S3Config.AccessSecret = kms.NewEmptySecret()
}
if f.GCSConfig.Credentials == nil {
f.GCSConfig.Credentials = kms.NewEmptySecret()
}
if f.AzBlobConfig.AccountKey == nil {
f.AzBlobConfig.AccountKey = kms.NewEmptySecret()
}
if f.CryptConfig.Passphrase == nil {
f.CryptConfig.Passphrase = kms.NewEmptySecret()
}
if f.SFTPConfig.Password == nil {
f.SFTPConfig.Password = kms.NewEmptySecret()
}
if f.SFTPConfig.PrivateKey == nil {
f.SFTPConfig.PrivateKey = kms.NewEmptySecret()
}
}
// GetACopy returns a copy
func (f *Filesystem) GetACopy() Filesystem {
f.SetEmptySecretsIfNil()
fs := Filesystem{
Provider: f.Provider,
S3Config: S3FsConfig{
Bucket: f.S3Config.Bucket,
Region: f.S3Config.Region,
AccessKey: f.S3Config.AccessKey,
AccessSecret: f.S3Config.AccessSecret.Clone(),
Endpoint: f.S3Config.Endpoint,
StorageClass: f.S3Config.StorageClass,
KeyPrefix: f.S3Config.KeyPrefix,
UploadPartSize: f.S3Config.UploadPartSize,
UploadConcurrency: f.S3Config.UploadConcurrency,
},
GCSConfig: GCSFsConfig{
Bucket: f.GCSConfig.Bucket,
CredentialFile: f.GCSConfig.CredentialFile,
Credentials: f.GCSConfig.Credentials.Clone(),
AutomaticCredentials: f.GCSConfig.AutomaticCredentials,
StorageClass: f.GCSConfig.StorageClass,
KeyPrefix: f.GCSConfig.KeyPrefix,
},
AzBlobConfig: AzBlobFsConfig{
Container: f.AzBlobConfig.Container,
AccountName: f.AzBlobConfig.AccountName,
AccountKey: f.AzBlobConfig.AccountKey.Clone(),
Endpoint: f.AzBlobConfig.Endpoint,
SASURL: f.AzBlobConfig.SASURL,
KeyPrefix: f.AzBlobConfig.KeyPrefix,
UploadPartSize: f.AzBlobConfig.UploadPartSize,
UploadConcurrency: f.AzBlobConfig.UploadConcurrency,
UseEmulator: f.AzBlobConfig.UseEmulator,
AccessTier: f.AzBlobConfig.AccessTier,
},
CryptConfig: CryptFsConfig{
Passphrase: f.CryptConfig.Passphrase.Clone(),
},
SFTPConfig: SFTPFsConfig{
Endpoint: f.SFTPConfig.Endpoint,
Username: f.SFTPConfig.Username,
Password: f.SFTPConfig.Password.Clone(),
PrivateKey: f.SFTPConfig.PrivateKey.Clone(),
Prefix: f.SFTPConfig.Prefix,
DisableCouncurrentReads: f.SFTPConfig.DisableCouncurrentReads,
},
}
if len(f.SFTPConfig.Fingerprints) > 0 {
fs.SFTPConfig.Fingerprints = make([]string, len(f.SFTPConfig.Fingerprints))
copy(fs.SFTPConfig.Fingerprints, f.SFTPConfig.Fingerprints)
}
return fs
}

View file

@ -2,6 +2,7 @@ package vfs
import (
"fmt"
"path/filepath"
"strconv"
"strings"
@ -23,6 +24,18 @@ type BaseVirtualFolder struct {
LastQuotaUpdate int64 `json:"last_quota_update"`
// list of usernames associated with this virtual folder
Users []string `json:"users,omitempty"`
// Filesystem configuration details
FsConfig Filesystem `json:"filesystem"`
}
// GetEncrytionAdditionalData returns the additional data to use for AEAD
func (v *BaseVirtualFolder) GetEncrytionAdditionalData() string {
return fmt.Sprintf("folder_%v", v.Name)
}
// GetGCSCredentialsFilePath returns the path for GCS credentials
func (v *BaseVirtualFolder) GetGCSCredentialsFilePath() string {
return filepath.Join(credentialsDirPath, "folders", fmt.Sprintf("%v_gcs_credentials.json", v.Name))
}
// GetACopy returns a copy
@ -38,6 +51,7 @@ func (v *BaseVirtualFolder) GetACopy() BaseVirtualFolder {
UsedQuotaFiles: v.UsedQuotaFiles,
LastQuotaUpdate: v.LastQuotaUpdate,
Users: users,
FsConfig: v.FsConfig.GetACopy(),
}
}
@ -60,6 +74,58 @@ func (v *BaseVirtualFolder) GetQuotaSummary() string {
return result
}
// IsLocalOrLocalCrypted returns true if the folder provider is local or local encrypted
func (v *BaseVirtualFolder) IsLocalOrLocalCrypted() bool {
return v.FsConfig.Provider == LocalFilesystemProvider || v.FsConfig.Provider == CryptedFilesystemProvider
}
// HideConfidentialData hides folder confidential data
func (v *BaseVirtualFolder) HideConfidentialData() {
switch v.FsConfig.Provider {
case S3FilesystemProvider:
v.FsConfig.S3Config.AccessSecret.Hide()
case GCSFilesystemProvider:
v.FsConfig.GCSConfig.Credentials.Hide()
case AzureBlobFilesystemProvider:
v.FsConfig.AzBlobConfig.AccountKey.Hide()
case CryptedFilesystemProvider:
v.FsConfig.CryptConfig.Passphrase.Hide()
case SFTPFilesystemProvider:
v.FsConfig.SFTPConfig.Password.Hide()
v.FsConfig.SFTPConfig.PrivateKey.Hide()
}
}
// HasRedactedSecret returns true if the folder has a redacted secret
func (v *BaseVirtualFolder) HasRedactedSecret() bool {
switch v.FsConfig.Provider {
case S3FilesystemProvider:
if v.FsConfig.S3Config.AccessSecret.IsRedacted() {
return true
}
case GCSFilesystemProvider:
if v.FsConfig.GCSConfig.Credentials.IsRedacted() {
return true
}
case AzureBlobFilesystemProvider:
if v.FsConfig.AzBlobConfig.AccountKey.IsRedacted() {
return true
}
case CryptedFilesystemProvider:
if v.FsConfig.CryptConfig.Passphrase.IsRedacted() {
return true
}
case SFTPFilesystemProvider:
if v.FsConfig.SFTPConfig.Password.IsRedacted() {
return true
}
if v.FsConfig.SFTPConfig.PrivateKey.IsRedacted() {
return true
}
}
return false
}
// VirtualFolder defines a mapping between an SFTPGo exposed virtual path and a
// filesystem path outside the user home directory.
// The specified paths must be absolute and the virtual path cannot be "/",
@ -75,6 +141,36 @@ type VirtualFolder struct {
QuotaFiles int `json:"quota_files"`
}
func (v *VirtualFolder) GetFilesystem(connectionID string) (Fs, error) {
switch v.FsConfig.Provider {
case S3FilesystemProvider:
return NewS3Fs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.S3Config)
case GCSFilesystemProvider:
config := v.FsConfig.GCSConfig
config.CredentialFile = v.GetGCSCredentialsFilePath()
return NewGCSFs(connectionID, v.MappedPath, v.VirtualPath, config)
case AzureBlobFilesystemProvider:
return NewAzBlobFs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.AzBlobConfig)
case CryptedFilesystemProvider:
return NewCryptFs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.CryptConfig)
case SFTPFilesystemProvider:
return NewSFTPFs(connectionID, v.VirtualPath, v.FsConfig.SFTPConfig)
default:
return NewOsFs(connectionID, v.MappedPath, v.VirtualPath), nil
}
}
// ScanQuota scans the folder and returns the number of files and their size
func (v *VirtualFolder) ScanQuota() (int, int64, error) {
fs, err := v.GetFilesystem("")
if err != nil {
return 0, 0, err
}
defer fs.Close()
return fs.ScanRootDirContents()
}
// IsIncludedInUserQuota returns true if the virtual folder is included in user quota
func (v *VirtualFolder) IsIncludedInUserQuota() bool {
return v.QuotaFiles == -1 && v.QuotaSize == -1
@ -87,3 +183,13 @@ func (v *VirtualFolder) HasNoQuotaRestrictions(checkFiles bool) bool {
}
return false
}
// GetACopy returns a copy
func (v *VirtualFolder) GetACopy() VirtualFolder {
return VirtualFolder{
BaseVirtualFolder: v.BaseVirtualFolder.GetACopy(),
VirtualPath: v.VirtualPath,
QuotaSize: v.QuotaSize,
QuotaFiles: v.QuotaFiles,
}
}

View file

@ -35,8 +35,10 @@ var (
// GCSFs is a Fs implementation for Google Cloud Storage.
type GCSFs struct {
connectionID string
localTempDir string
connectionID string
localTempDir string
// if not empty this fs is mouted as virtual folder in the specified path
mountPath string
config *GCSFsConfig
svc *storage.Client
ctxTimeout time.Duration
@ -48,11 +50,16 @@ func init() {
}
// NewGCSFs returns an GCSFs object that allows to interact with Google Cloud Storage
func NewGCSFs(connectionID, localTempDir string, config GCSFsConfig) (Fs, error) {
func NewGCSFs(connectionID, localTempDir, mountPath string, config GCSFsConfig) (Fs, error) {
if localTempDir == "" {
localTempDir = filepath.Clean(os.TempDir())
}
var err error
fs := &GCSFs{
connectionID: connectionID,
localTempDir: localTempDir,
mountPath: mountPath,
config: &config,
ctxTimeout: 30 * time.Second,
ctxLongTimeout: 300 * time.Second,
@ -152,7 +159,7 @@ func (fs *GCSFs) Open(name string, offset int64) (File, *pipeat.PipeReaderAt, fu
ctx, cancelFn := context.WithCancel(context.Background())
objectReader, err := obj.NewRangeReader(ctx, offset, -1)
if err == nil && offset > 0 && objectReader.Attrs.ContentEncoding == "gzip" {
err = fmt.Errorf("Range request is not possible for gzip content encoding, requested offset %v", offset)
err = fmt.Errorf("range request is not possible for gzip content encoding, requested offset %v", offset)
objectReader.Close()
}
if err != nil {
@ -230,7 +237,7 @@ func (fs *GCSFs) Rename(source, target string) error {
return err
}
if hasContents {
return fmt.Errorf("Cannot rename non empty directory: %#v", source)
return fmt.Errorf("cannot rename non empty directory: %#v", source)
}
}
src := fs.svc.Bucket(fs.config.Bucket).Object(source)
@ -266,7 +273,7 @@ func (fs *GCSFs) Remove(name string, isDir bool) error {
return err
}
if hasContents {
return fmt.Errorf("Cannot remove non empty directory: %#v", name)
return fmt.Errorf("cannot remove non empty directory: %#v", name)
}
}
ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
@ -290,6 +297,11 @@ func (fs *GCSFs) Mkdir(name string) error {
return w.Close()
}
// MkdirAll does nothing, we don't have folder
func (*GCSFs) MkdirAll(name string, uid int, gid int) error {
return nil
}
// Symlink creates source as a symbolic link to target.
func (*GCSFs) Symlink(source, target string) error {
return ErrVfsUnsupported
@ -441,7 +453,7 @@ func (*GCSFs) IsNotSupported(err error) bool {
// CheckRootPath creates the specified local root directory if it does not exists
func (fs *GCSFs) CheckRootPath(username string, uid int, gid int) bool {
// we need a local directory for temporary files
osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, nil)
osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "")
return osFs.CheckRootPath(username, uid, gid)
}
@ -510,6 +522,9 @@ func (fs *GCSFs) GetRelativePath(name string) string {
}
rel = path.Clean("/" + strings.TrimPrefix(rel, "/"+fs.config.KeyPrefix))
}
if fs.mountPath != "" {
rel = path.Join(fs.mountPath, rel)
}
return rel
}
@ -578,6 +593,9 @@ func (GCSFs) HasVirtualFolders() bool {
// ResolvePath returns the matching filesystem path for the specified virtual path
func (fs *GCSFs) ResolvePath(virtualPath string) (string, error) {
if fs.mountPath != "" {
virtualPath = strings.TrimPrefix(virtualPath, fs.mountPath)
}
if !path.IsAbs(virtualPath) {
virtualPath = path.Clean("/" + virtualPath)
}

View file

@ -13,6 +13,6 @@ func init() {
}
// NewGCSFs returns an error, GCS is disabled
func NewGCSFs(connectionID, localTempDir string, config GCSFsConfig) (Fs, error) {
func NewGCSFs(connectionID, localTempDir, mountPath string, config GCSFsConfig) (Fs, error) {
return nil, errors.New("Google Cloud Storage disabled at build time")
}

View file

@ -15,7 +15,6 @@ import (
"github.com/rs/xid"
"github.com/drakkan/sftpgo/logger"
"github.com/drakkan/sftpgo/utils"
)
const (
@ -25,19 +24,20 @@ const (
// OsFs is a Fs implementation that uses functions provided by the os package.
type OsFs struct {
name string
connectionID string
rootDir string
virtualFolders []VirtualFolder
name string
connectionID string
rootDir string
// if not empty this fs is mouted as virtual folder in the specified path
mountPath string
}
// NewOsFs returns an OsFs object that allows to interact with local Os filesystem
func NewOsFs(connectionID, rootDir string, virtualFolders []VirtualFolder) Fs {
func NewOsFs(connectionID, rootDir, mountPath string) Fs {
return &OsFs{
name: osFsName,
connectionID: connectionID,
rootDir: rootDir,
virtualFolders: virtualFolders,
name: osFsName,
connectionID: connectionID,
rootDir: rootDir,
mountPath: mountPath,
}
}
@ -53,32 +53,12 @@ func (fs *OsFs) ConnectionID() string {
// Stat returns a FileInfo describing the named file
func (fs *OsFs) Stat(name string) (os.FileInfo, error) {
fi, err := os.Stat(name)
if err != nil {
return fi, err
}
for _, v := range fs.virtualFolders {
if v.MappedPath == name {
info := NewFileInfo(v.VirtualPath, true, fi.Size(), fi.ModTime(), false)
return info, nil
}
}
return fi, err
return os.Stat(name)
}
// Lstat returns a FileInfo describing the named file
func (fs *OsFs) Lstat(name string) (os.FileInfo, error) {
fi, err := os.Lstat(name)
if err != nil {
return fi, err
}
for _, v := range fs.virtualFolders {
if v.MappedPath == name {
info := NewFileInfo(v.VirtualPath, true, fi.Size(), fi.ModTime(), false)
return info, nil
}
}
return fi, err
return os.Lstat(name)
}
// Open opens the named file for reading
@ -114,6 +94,13 @@ func (*OsFs) Mkdir(name string) error {
return os.Mkdir(name, os.ModePerm)
}
// MkdirAll creates a directory named path, along with any necessary parents,
// and returns nil, or else returns an error.
// If path is already a directory, MkdirAll does nothing and returns nil.
func (fs *OsFs) MkdirAll(name string, uid int, gid int) error {
return fs.createMissingDirs(name, uid, gid)
}
// Symlink creates source as a symbolic link to target.
func (*OsFs) Symlink(source, target string) error {
return os.Symlink(source, target)
@ -205,45 +192,13 @@ func (fs *OsFs) CheckRootPath(username string, uid int, gid int) bool {
SetPathPermissions(fs, fs.rootDir, uid, gid)
}
}
// create any missing dirs to the defined virtual dirs
for _, v := range fs.virtualFolders {
p := filepath.Clean(filepath.Join(fs.rootDir, v.VirtualPath))
err = fs.createMissingDirs(p, uid, gid)
if err != nil {
return false
}
if _, err = fs.Stat(v.MappedPath); fs.IsNotExist(err) {
err = os.MkdirAll(v.MappedPath, os.ModePerm)
fsLog(fs, logger.LevelDebug, "virtual directory %#v for user %#v does not exist, try to create, mkdir error: %v",
v.MappedPath, username, err)
if err == nil {
SetPathPermissions(fs, fs.rootDir, uid, gid)
}
}
}
return (err == nil)
return err == nil
}
// ScanRootDirContents returns the number of files contained in the root
// directory and their size
func (fs *OsFs) ScanRootDirContents() (int, int64, error) {
numFiles, size, err := fs.GetDirSize(fs.rootDir)
for _, v := range fs.virtualFolders {
if !v.IsIncludedInUserQuota() {
continue
}
num, s, err := fs.GetDirSize(v.MappedPath)
if err != nil {
if fs.IsNotExist(err) {
fsLog(fs, logger.LevelWarn, "unable to scan contents for non-existent mapped path: %#v", v.MappedPath)
continue
}
return numFiles, size, err
}
numFiles += num
size += s
}
return numFiles, size, err
return fs.GetDirSize(fs.rootDir)
}
// GetAtomicUploadPath returns the path to use for an atomic upload
@ -254,18 +209,13 @@ func (*OsFs) GetAtomicUploadPath(name string) string {
}
// GetRelativePath returns the path for a file relative to the user's home dir.
// This is the path as seen by SFTP users
// This is the path as seen by SFTPGo users
func (fs *OsFs) GetRelativePath(name string) string {
basePath := fs.rootDir
virtualPath := "/"
for _, v := range fs.virtualFolders {
if strings.HasPrefix(name, v.MappedPath+string(os.PathSeparator)) ||
filepath.Clean(name) == v.MappedPath {
basePath = v.MappedPath
virtualPath = v.VirtualPath
}
if fs.mountPath != "" {
virtualPath = fs.mountPath
}
rel, err := filepath.Rel(basePath, filepath.Clean(name))
rel, err := filepath.Rel(fs.rootDir, filepath.Clean(name))
if err != nil {
return ""
}
@ -287,28 +237,31 @@ func (*OsFs) Join(elem ...string) string {
}
// ResolvePath returns the matching filesystem path for the specified sftp path
func (fs *OsFs) ResolvePath(sftpPath string) (string, error) {
func (fs *OsFs) ResolvePath(virtualPath string) (string, error) {
if !filepath.IsAbs(fs.rootDir) {
return "", fmt.Errorf("Invalid root path: %v", fs.rootDir)
return "", fmt.Errorf("invalid root path: %v", fs.rootDir)
}
basePath, r := fs.GetFsPaths(sftpPath)
if fs.mountPath != "" {
virtualPath = strings.TrimPrefix(virtualPath, fs.mountPath)
}
r := filepath.Clean(filepath.Join(fs.rootDir, virtualPath))
p, err := filepath.EvalSymlinks(r)
if err != nil && !os.IsNotExist(err) {
return "", err
} else if os.IsNotExist(err) {
// The requested path doesn't exist, so at this point we need to iterate up the
// path chain until we hit a directory that _does_ exist and can be validated.
_, err = fs.findFirstExistingDir(r, basePath)
_, err = fs.findFirstExistingDir(r)
if err != nil {
fsLog(fs, logger.LevelWarn, "error resolving non-existent path %#v", err)
}
return r, err
}
err = fs.isSubDir(p, basePath)
err = fs.isSubDir(p)
if err != nil {
fsLog(fs, logger.LevelWarn, "Invalid path resolution, dir %#v original path %#v resolved %#v err: %v",
p, sftpPath, r, err)
p, virtualPath, r, err)
}
return r, err
}
@ -339,43 +292,9 @@ func (*OsFs) HasVirtualFolders() bool {
return false
}
// GetFsPaths returns the base path and filesystem path for the given sftpPath.
// base path is the root dir or matching the virtual folder dir for the sftpPath.
// file path is the filesystem path matching the sftpPath
func (fs *OsFs) GetFsPaths(sftpPath string) (string, string) {
basePath := fs.rootDir
virtualPath, mappedPath := fs.getMappedFolderForPath(sftpPath)
if mappedPath != "" {
basePath = mappedPath
sftpPath = strings.TrimPrefix(utils.CleanPath(sftpPath), virtualPath)
}
r := filepath.Clean(filepath.Join(basePath, sftpPath))
return basePath, r
}
// returns the path for the mapped folders or an empty string
func (fs *OsFs) getMappedFolderForPath(p string) (virtualPath, mappedPath string) {
if len(fs.virtualFolders) == 0 {
return
}
dirsForPath := utils.GetDirsForSFTPPath(p)
// dirsForPath contains all the dirs for a given path in reverse order
// for example if the path is: /1/2/3/4 it contains:
// [ "/1/2/3/4", "/1/2/3", "/1/2", "/1", "/" ]
// so the first match is the one we are interested to
for _, val := range dirsForPath {
for _, v := range fs.virtualFolders {
if val == v.VirtualPath {
return v.VirtualPath, v.MappedPath
}
}
}
return
}
func (fs *OsFs) findNonexistentDirs(path, rootPath string) ([]string, error) {
func (fs *OsFs) findNonexistentDirs(filePath string) ([]string, error) {
results := []string{}
cleanPath := filepath.Clean(path)
cleanPath := filepath.Clean(filePath)
parent := filepath.Dir(cleanPath)
_, err := os.Stat(parent)
@ -391,15 +310,15 @@ func (fs *OsFs) findNonexistentDirs(path, rootPath string) ([]string, error) {
if err != nil {
return results, err
}
err = fs.isSubDir(p, rootPath)
err = fs.isSubDir(p)
if err != nil {
fsLog(fs, logger.LevelWarn, "error finding non existing dir: %v", err)
}
return results, err
}
func (fs *OsFs) findFirstExistingDir(path, rootPath string) (string, error) {
results, err := fs.findNonexistentDirs(path, rootPath)
func (fs *OsFs) findFirstExistingDir(path string) (string, error) {
results, err := fs.findNonexistentDirs(path)
if err != nil {
fsLog(fs, logger.LevelWarn, "unable to find non existent dirs: %v", err)
return "", err
@ -409,7 +328,7 @@ func (fs *OsFs) findFirstExistingDir(path, rootPath string) (string, error) {
lastMissingDir := results[len(results)-1]
parent = filepath.Dir(lastMissingDir)
} else {
parent = rootPath
parent = fs.rootDir
}
p, err := filepath.EvalSymlinks(parent)
if err != nil {
@ -422,15 +341,15 @@ func (fs *OsFs) findFirstExistingDir(path, rootPath string) (string, error) {
if !fileInfo.IsDir() {
return "", fmt.Errorf("resolved path is not a dir: %#v", p)
}
err = fs.isSubDir(p, rootPath)
err = fs.isSubDir(p)
return p, err
}
func (fs *OsFs) isSubDir(sub, rootPath string) error {
// rootPath must exist and it is already a validated absolute path
parent, err := filepath.EvalSymlinks(rootPath)
func (fs *OsFs) isSubDir(sub string) error {
// fs.rootDir must exist and it is already a validated absolute path
parent, err := filepath.EvalSymlinks(fs.rootDir)
if err != nil {
fsLog(fs, logger.LevelWarn, "invalid root path %#v: %v", rootPath, err)
fsLog(fs, logger.LevelWarn, "invalid root path %#v: %v", fs.rootDir, err)
return err
}
if parent == sub {
@ -448,7 +367,7 @@ func (fs *OsFs) isSubDir(sub, rootPath string) error {
}
func (fs *OsFs) createMissingDirs(filePath string, uid, gid int) error {
dirsToCreate, err := fs.findNonexistentDirs(filePath, fs.rootDir)
dirsToCreate, err := fs.findNonexistentDirs(filePath)
if err != nil {
return err
}

View file

@ -30,8 +30,10 @@ import (
// S3Fs is a Fs implementation for AWS S3 compatible object storages
type S3Fs struct {
connectionID string
localTempDir string
connectionID string
localTempDir string
// if not empty this fs is mouted as virtual folder in the specified path
mountPath string
config *S3FsConfig
svc *s3.S3
ctxTimeout time.Duration
@ -44,10 +46,14 @@ func init() {
// NewS3Fs returns an S3Fs object that allows to interact with an s3 compatible
// object storage
func NewS3Fs(connectionID, localTempDir string, config S3FsConfig) (Fs, error) {
func NewS3Fs(connectionID, localTempDir, mountPath string, config S3FsConfig) (Fs, error) {
if localTempDir == "" {
localTempDir = filepath.Clean(os.TempDir())
}
fs := &S3Fs{
connectionID: connectionID,
localTempDir: localTempDir,
mountPath: mountPath,
config: &config,
ctxTimeout: 30 * time.Second,
ctxLongTimeout: 300 * time.Second,
@ -249,7 +255,7 @@ func (fs *S3Fs) Rename(source, target string) error {
return err
}
if hasContents {
return fmt.Errorf("Cannot rename non empty directory: %#v", source)
return fmt.Errorf("cannot rename non empty directory: %#v", source)
}
if !strings.HasSuffix(copySource, "/") {
copySource += "/"
@ -288,7 +294,7 @@ func (fs *S3Fs) Remove(name string, isDir bool) error {
return err
}
if hasContents {
return fmt.Errorf("Cannot remove non empty directory: %#v", name)
return fmt.Errorf("cannot remove non empty directory: %#v", name)
}
if !strings.HasSuffix(name, "/") {
name += "/"
@ -320,6 +326,11 @@ func (fs *S3Fs) Mkdir(name string) error {
return w.Close()
}
// MkdirAll does nothing, we don't have folder
func (*S3Fs) MkdirAll(name string, uid int, gid int) error {
return nil
}
// Symlink creates source as a symbolic link to target.
func (*S3Fs) Symlink(source, target string) error {
return ErrVfsUnsupported
@ -465,7 +476,7 @@ func (*S3Fs) IsNotSupported(err error) bool {
// CheckRootPath creates the specified local root directory if it does not exists
func (fs *S3Fs) CheckRootPath(username string, uid int, gid int) bool {
// we need a local directory for temporary files
osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, nil)
osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "")
return osFs.CheckRootPath(username, uid, gid)
}
@ -522,6 +533,9 @@ func (fs *S3Fs) GetRelativePath(name string) string {
}
rel = path.Clean("/" + strings.TrimPrefix(rel, "/"+fs.config.KeyPrefix))
}
if fs.mountPath != "" {
rel = path.Join(fs.mountPath, rel)
}
return rel
}
@ -574,6 +588,9 @@ func (*S3Fs) HasVirtualFolders() bool {
// ResolvePath returns the matching filesystem path for the specified virtual path
func (fs *S3Fs) ResolvePath(virtualPath string) (string, error) {
if fs.mountPath != "" {
virtualPath = strings.TrimPrefix(virtualPath, fs.mountPath)
}
if !path.IsAbs(virtualPath) {
virtualPath = path.Clean("/" + virtualPath)
}

View file

@ -13,6 +13,6 @@ func init() {
}
// NewS3Fs returns an error, S3 is disabled
func NewS3Fs(connectionID, localTempDir string, config S3FsConfig) (Fs, error) {
func NewS3Fs(connectionID, localTempDir, mountPath string, config S3FsConfig) (Fs, error) {
return nil, errors.New("S3 disabled at build time")
}

View file

@ -110,14 +110,16 @@ func (c *SFTPFsConfig) EncryptCredentials(additionalData string) error {
type SFTPFs struct {
sync.Mutex
connectionID string
config *SFTPFsConfig
sshClient *ssh.Client
sftpClient *sftp.Client
err chan error
// if not empty this fs is mouted as virtual folder in the specified path
mountPath string
config *SFTPFsConfig
sshClient *ssh.Client
sftpClient *sftp.Client
err chan error
}
// NewSFTPFs returns an SFTPFa object that allows to interact with an SFTP server
func NewSFTPFs(connectionID string, config SFTPFsConfig) (Fs, error) {
func NewSFTPFs(connectionID, mountPath string, config SFTPFsConfig) (Fs, error) {
if err := config.Validate(); err != nil {
return nil, err
}
@ -133,6 +135,7 @@ func NewSFTPFs(connectionID string, config SFTPFsConfig) (Fs, error) {
}
sftpFs := &SFTPFs{
connectionID: connectionID,
mountPath: mountPath,
config: &config,
err: make(chan error, 1),
}
@ -226,6 +229,16 @@ func (fs *SFTPFs) Mkdir(name string) error {
return fs.sftpClient.Mkdir(name)
}
// MkdirAll creates a directory named path, along with any necessary parents,
// and returns nil, or else returns an error.
// If path is already a directory, MkdirAll does nothing and returns nil.
func (fs *SFTPFs) MkdirAll(name string, uid int, gid int) error {
if err := fs.checkConnection(); err != nil {
return err
}
return fs.sftpClient.MkdirAll(name)
}
// Symlink creates source as a symbolic link to target.
func (fs *SFTPFs) Symlink(source, target string) error {
if err := fs.checkConnection(); err != nil {
@ -358,6 +371,9 @@ func (fs *SFTPFs) GetRelativePath(name string) string {
}
rel = path.Clean("/" + strings.TrimPrefix(rel, fs.config.Prefix))
}
if fs.mountPath != "" {
rel = path.Join(fs.mountPath, rel)
}
return rel
}
@ -393,6 +409,9 @@ func (*SFTPFs) HasVirtualFolders() bool {
// ResolvePath returns the matching filesystem path for the specified virtual path
func (fs *SFTPFs) ResolvePath(virtualPath string) (string, error) {
if fs.mountPath != "" {
virtualPath = strings.TrimPrefix(virtualPath, fs.mountPath)
}
if !path.IsAbs(virtualPath) {
virtualPath = path.Clean("/" + virtualPath)
}
@ -559,7 +578,7 @@ func (fs *SFTPFs) createConnection() error {
return nil
}
}
return fmt.Errorf("Invalid fingerprint %#v", fp)
return fmt.Errorf("invalid fingerprint %#v", fp)
}
fsLog(fs, logger.LevelWarn, "login without host key validation, please provide at least a fingerprint!")
return nil

View file

@ -27,8 +27,21 @@ var (
validAzAccessTier = []string{"", "Archive", "Hot", "Cool"}
// ErrStorageSizeUnavailable is returned if the storage backend does not support getting the size
ErrStorageSizeUnavailable = errors.New("unable to get available size for this storage backend")
// ErrVfsUnsupported defines the error for an unsupported VFS operation
ErrVfsUnsupported = errors.New("not supported")
credentialsDirPath string
)
// SetCredentialsDirPath sets the credentials dir path
func SetCredentialsDirPath(credentialsPath string) {
credentialsDirPath = credentialsPath
}
// GetCredentialsDirPath returns the credentials dir path
func GetCredentialsDirPath() string {
return credentialsDirPath
}
// Fs defines the interface for filesystem backends
type Fs interface {
Name() string
@ -40,6 +53,7 @@ type Fs interface {
Rename(source, target string) error
Remove(name string, isDir bool) error
Mkdir(name string) error
MkdirAll(name string, uid int, gid int) error
Symlink(source, target string) error
Chown(name string, uid int, gid int) error
Chmod(name string, mode os.FileMode) error
@ -79,9 +93,6 @@ type File interface {
Truncate(size int64) error
}
// ErrVfsUnsupported defines the error for an unsupported VFS operation
var ErrVfsUnsupported = errors.New("Not supported")
// QuotaCheckResult defines the result for a quota check
type QuotaCheckResult struct {
HasSpace bool
@ -438,6 +449,11 @@ func IsLocalOrSFTPFs(fs Fs) bool {
return IsLocalOsFs(fs) || IsSFTPFs(fs)
}
// IsLocalOrCryptoFs returns true if fs is local or local encrypted
func IsLocalOrCryptoFs(fs Fs) bool {
return IsLocalOsFs(fs) || IsCryptOsFs(fs)
}
// SetPathPermissions calls fs.Chown.
// It does nothing for local filesystem on windows
func SetPathPermissions(fs Fs, path string, uid int, gid int) {

View file

@ -84,7 +84,7 @@ func (f *webDavFile) Readdir(count int) ([]os.FileInfo, error) {
if !f.Connection.User.HasPerm(dataprovider.PermListItems, f.GetVirtualPath()) {
return nil, f.Connection.GetPermissionDeniedError()
}
fileInfos, err := f.Connection.ListDir(f.GetFsPath(), f.GetVirtualPath())
fileInfos, err := f.Connection.ListDir(f.GetVirtualPath())
if err != nil {
return nil, err
}
@ -299,7 +299,7 @@ func (f *webDavFile) Close() error {
} else {
f.Connection.RemoveTransfer(f.BaseTransfer)
}
return f.Connection.GetFsError(err)
return f.Connection.GetFsError(f.Fs, err)
}
func (f *webDavFile) closeIO() error {

View file

@ -57,11 +57,7 @@ func (c *Connection) Mkdir(ctx context.Context, name string, perm os.FileMode) e
c.UpdateLastActivity()
name = utils.CleanPath(name)
p, err := c.Fs.ResolvePath(name)
if err != nil {
return c.GetFsError(err)
}
return c.CreateDir(p, name)
return c.CreateDir(name)
}
// Rename renames a file or a directory
@ -71,20 +67,10 @@ func (c *Connection) Rename(ctx context.Context, oldName, newName string) error
oldName = utils.CleanPath(oldName)
newName = utils.CleanPath(newName)
p, err := c.Fs.ResolvePath(oldName)
if err != nil {
return c.GetFsError(err)
}
t, err := c.Fs.ResolvePath(newName)
if err != nil {
return c.GetFsError(err)
}
if err = c.BaseConnection.Rename(p, t, oldName, newName); err != nil {
if err := c.BaseConnection.Rename(oldName, newName); err != nil {
return err
}
vfs.SetPathPermissions(c.Fs, t, c.User.GetUID(), c.User.GetGID())
return nil
}
@ -98,14 +84,10 @@ func (c *Connection) Stat(ctx context.Context, name string) (os.FileInfo, error)
return nil, c.GetPermissionDeniedError()
}
p, err := c.Fs.ResolvePath(name)
fi, err := c.DoStat(name, 0)
if err != nil {
return nil, c.GetFsError(err)
}
fi, err := c.DoStat(p, 0)
if err != nil {
c.Log(logger.LevelDebug, "error running stat on path %#v: %+v", p, err)
return nil, c.GetFsError(err)
c.Log(logger.LevelDebug, "error running stat on path %#v: %+v", name, err)
return nil, err
}
return fi, err
}
@ -116,21 +98,21 @@ func (c *Connection) RemoveAll(ctx context.Context, name string) error {
c.UpdateLastActivity()
name = utils.CleanPath(name)
p, err := c.Fs.ResolvePath(name)
fs, p, err := c.GetFsAndResolvedPath(name)
if err != nil {
return c.GetFsError(err)
return err
}
var fi os.FileInfo
if fi, err = c.Fs.Lstat(p); err != nil {
c.Log(logger.LevelWarn, "failed to remove a file %#v: stat error: %+v", p, err)
return c.GetFsError(err)
if fi, err = fs.Lstat(p); err != nil {
c.Log(logger.LevelDebug, "failed to remove a file %#v: stat error: %+v", p, err)
return c.GetFsError(fs, err)
}
if fi.IsDir() && fi.Mode()&os.ModeSymlink == 0 {
return c.removeDirTree(p, name)
return c.removeDirTree(fs, p, name)
}
return c.RemoveFile(p, name, fi)
return c.RemoveFile(fs, p, name, fi)
}
// OpenFile opens the named file with specified flag.
@ -139,19 +121,19 @@ func (c *Connection) OpenFile(ctx context.Context, name string, flag int, perm o
c.UpdateLastActivity()
name = utils.CleanPath(name)
p, err := c.Fs.ResolvePath(name)
fs, p, err := c.GetFsAndResolvedPath(name)
if err != nil {
return nil, c.GetFsError(err)
return nil, err
}
if flag == os.O_RDONLY || c.request.Method == "PROPPATCH" {
// Download, Stat, Readdir or simply open/close
return c.getFile(p, name)
return c.getFile(fs, p, name)
}
return c.putFile(p, name)
return c.putFile(fs, p, name)
}
func (c *Connection) getFile(fsPath, virtualPath string) (webdav.File, error) {
func (c *Connection) getFile(fs vfs.Fs, fsPath, virtualPath string) (webdav.File, error) {
var err error
var file vfs.File
var r *pipeat.PipeReaderAt
@ -159,42 +141,42 @@ func (c *Connection) getFile(fsPath, virtualPath string) (webdav.File, error) {
// for cloud fs we open the file when we receive the first read to avoid to download the first part of
// the file if it was opened only to do a stat or a readdir and so it is not a real download
if vfs.IsLocalOrSFTPFs(c.Fs) {
file, r, cancelFn, err = c.Fs.Open(fsPath, 0)
if vfs.IsLocalOrSFTPFs(fs) {
file, r, cancelFn, err = fs.Open(fsPath, 0)
if err != nil {
c.Log(logger.LevelWarn, "could not open file %#v for reading: %+v", fsPath, err)
return nil, c.GetFsError(err)
return nil, c.GetFsError(fs, err)
}
}
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, virtualPath, common.TransferDownload,
0, 0, 0, false, c.Fs)
0, 0, 0, false, fs)
return newWebDavFile(baseTransfer, nil, r), nil
}
func (c *Connection) putFile(fsPath, virtualPath string) (webdav.File, error) {
func (c *Connection) putFile(fs vfs.Fs, fsPath, virtualPath string) (webdav.File, error) {
if !c.User.IsFileAllowed(virtualPath) {
c.Log(logger.LevelWarn, "writing file %#v is not allowed", virtualPath)
return nil, c.GetPermissionDeniedError()
}
filePath := fsPath
if common.Config.IsAtomicUploadEnabled() && c.Fs.IsAtomicUploadSupported() {
filePath = c.Fs.GetAtomicUploadPath(fsPath)
if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
filePath = fs.GetAtomicUploadPath(fsPath)
}
stat, statErr := c.Fs.Lstat(fsPath)
if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || c.Fs.IsNotExist(statErr) {
stat, statErr := fs.Lstat(fsPath)
if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) {
if !c.User.HasPerm(dataprovider.PermUpload, path.Dir(virtualPath)) {
return nil, c.GetPermissionDeniedError()
}
return c.handleUploadToNewFile(fsPath, filePath, virtualPath)
return c.handleUploadToNewFile(fs, fsPath, filePath, virtualPath)
}
if statErr != nil {
c.Log(logger.LevelError, "error performing file stat %#v: %+v", fsPath, statErr)
return nil, c.GetFsError(statErr)
return nil, c.GetFsError(fs, statErr)
}
// This happen if we upload a file that has the same name of an existing directory
@ -207,33 +189,33 @@ func (c *Connection) putFile(fsPath, virtualPath string) (webdav.File, error) {
return nil, c.GetPermissionDeniedError()
}
return c.handleUploadToExistingFile(fsPath, filePath, stat.Size(), virtualPath)
return c.handleUploadToExistingFile(fs, fsPath, filePath, stat.Size(), virtualPath)
}
func (c *Connection) handleUploadToNewFile(resolvedPath, filePath, requestPath string) (webdav.File, error) {
func (c *Connection) handleUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, requestPath string) (webdav.File, error) {
quotaResult := c.HasSpace(true, false, requestPath)
if !quotaResult.HasSpace {
c.Log(logger.LevelInfo, "denying file write due to quota limits")
return nil, common.ErrQuotaExceeded
}
file, w, cancelFn, err := c.Fs.Create(filePath, 0)
file, w, cancelFn, err := fs.Create(filePath, 0)
if err != nil {
c.Log(logger.LevelWarn, "error creating file %#v: %+v", resolvedPath, err)
return nil, c.GetFsError(err)
return nil, c.GetFsError(fs, err)
}
vfs.SetPathPermissions(c.Fs, filePath, c.User.GetUID(), c.User.GetGID())
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
// we can get an error only for resume
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0)
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath,
common.TransferUpload, 0, 0, maxWriteSize, true, c.Fs)
common.TransferUpload, 0, 0, maxWriteSize, true, fs)
return newWebDavFile(baseTransfer, w, nil), nil
}
func (c *Connection) handleUploadToExistingFile(resolvedPath, filePath string, fileSize int64,
func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePath string, fileSize int64,
requestPath string) (webdav.File, error) {
var err error
quotaResult := c.HasSpace(false, false, requestPath)
@ -244,24 +226,24 @@ func (c *Connection) handleUploadToExistingFile(resolvedPath, filePath string, f
// if there is a size limit remaining size cannot be 0 here, since quotaResult.HasSpace
// will return false in this case and we deny the upload before
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, fileSize)
maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, fileSize, fs.IsUploadResumeSupported())
if common.Config.IsAtomicUploadEnabled() && c.Fs.IsAtomicUploadSupported() {
err = c.Fs.Rename(resolvedPath, filePath)
if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
err = fs.Rename(resolvedPath, filePath)
if err != nil {
c.Log(logger.LevelWarn, "error renaming existing file for atomic upload, source: %#v, dest: %#v, err: %+v",
resolvedPath, filePath, err)
return nil, c.GetFsError(err)
return nil, c.GetFsError(fs, err)
}
}
file, w, cancelFn, err := c.Fs.Create(filePath, 0)
file, w, cancelFn, err := fs.Create(filePath, 0)
if err != nil {
c.Log(logger.LevelWarn, "error creating file %#v: %+v", resolvedPath, err)
return nil, c.GetFsError(err)
return nil, c.GetFsError(fs, err)
}
initialSize := int64(0)
if vfs.IsLocalOrSFTPFs(c.Fs) {
if vfs.IsLocalOrSFTPFs(fs) {
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck
@ -275,10 +257,10 @@ func (c *Connection) handleUploadToExistingFile(resolvedPath, filePath string, f
initialSize = fileSize
}
vfs.SetPathPermissions(c.Fs, filePath, c.User.GetUID(), c.User.GetGID())
vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath,
common.TransferUpload, 0, initialSize, maxWriteSize, false, c.Fs)
common.TransferUpload, 0, initialSize, maxWriteSize, false, fs)
return newWebDavFile(baseTransfer, w, nil), nil
}
@ -289,22 +271,22 @@ type objectMapping struct {
info os.FileInfo
}
func (c *Connection) removeDirTree(fsPath, virtualPath string) error {
func (c *Connection) removeDirTree(fs vfs.Fs, fsPath, virtualPath string) error {
var dirsToRemove []objectMapping
var filesToRemove []objectMapping
err := c.Fs.Walk(fsPath, func(walkedPath string, info os.FileInfo, err error) error {
err := fs.Walk(fsPath, func(walkedPath string, info os.FileInfo, err error) error {
if err != nil {
return err
}
obj := objectMapping{
fsPath: walkedPath,
virtualPath: c.Fs.GetRelativePath(walkedPath),
virtualPath: fs.GetRelativePath(walkedPath),
info: info,
}
if info.IsDir() {
err = c.IsRemoveDirAllowed(obj.fsPath, obj.virtualPath)
err = c.IsRemoveDirAllowed(fs, obj.fsPath, obj.virtualPath)
isDuplicated := false
for _, d := range dirsToRemove {
if d.fsPath == obj.fsPath {
@ -316,7 +298,7 @@ func (c *Connection) removeDirTree(fsPath, virtualPath string) error {
dirsToRemove = append(dirsToRemove, obj)
}
} else {
err = c.IsRemoveFileAllowed(obj.fsPath, obj.virtualPath)
err = c.IsRemoveFileAllowed(obj.virtualPath)
filesToRemove = append(filesToRemove, obj)
}
if err != nil {
@ -333,7 +315,7 @@ func (c *Connection) removeDirTree(fsPath, virtualPath string) error {
}
for _, fileObj := range filesToRemove {
err = c.RemoveFile(fileObj.fsPath, fileObj.virtualPath, fileObj.info)
err = c.RemoveFile(fs, fileObj.fsPath, fileObj.virtualPath, fileObj.info)
if err != nil {
c.Log(logger.LevelDebug, "unable to remove dir tree, error removing file %#v->%#v: %v",
fileObj.virtualPath, fileObj.fsPath, err)
@ -341,8 +323,8 @@ func (c *Connection) removeDirTree(fsPath, virtualPath string) error {
}
}
for _, dirObj := range c.orderDirsToRemove(dirsToRemove) {
err = c.RemoveDir(dirObj.fsPath, dirObj.virtualPath)
for _, dirObj := range c.orderDirsToRemove(fs, dirsToRemove) {
err = c.RemoveDir(dirObj.virtualPath)
if err != nil {
c.Log(logger.LevelDebug, "unable to remove dir tree, error removing directory %#v->%#v: %v",
dirObj.virtualPath, dirObj.fsPath, err)
@ -354,12 +336,12 @@ func (c *Connection) removeDirTree(fsPath, virtualPath string) error {
}
// order directories so that the empty ones will be at slice start
func (c *Connection) orderDirsToRemove(dirsToRemove []objectMapping) []objectMapping {
func (c *Connection) orderDirsToRemove(fs vfs.Fs, dirsToRemove []objectMapping) []objectMapping {
orderedDirs := make([]objectMapping, 0, len(dirsToRemove))
removedDirs := make([]string, 0, len(dirsToRemove))
pathSeparator := "/"
if vfs.IsLocalOsFs(c.Fs) {
if vfs.IsLocalOsFs(fs) {
pathSeparator = string(os.PathSeparator)
}

View file

@ -22,6 +22,7 @@ import (
"github.com/drakkan/sftpgo/common"
"github.com/drakkan/sftpgo/dataprovider"
"github.com/drakkan/sftpgo/kms"
"github.com/drakkan/sftpgo/vfs"
)
@ -321,7 +322,7 @@ func (fs *MockOsFs) GetMimeType(name string) (string, error) {
func newMockOsFs(err error, atomicUpload bool, connectionID, rootDir string, reader *pipeat.PipeReaderAt) vfs.Fs {
return &MockOsFs{
Fs: vfs.NewOsFs(connectionID, rootDir, nil),
Fs: vfs.NewOsFs(connectionID, rootDir, ""),
err: err,
isAtomicUploadSupported: atomicUpload,
reader: reader,
@ -330,14 +331,14 @@ func newMockOsFs(err error, atomicUpload bool, connectionID, rootDir string, rea
func TestOrderDirsToRemove(t *testing.T) {
user := dataprovider.User{}
fs := vfs.NewOsFs("id", os.TempDir(), nil)
fs := vfs.NewOsFs("id", os.TempDir(), "")
connection := &Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user),
request: nil,
}
dirsToRemove := []objectMapping{}
orderedDirs := connection.orderDirsToRemove(dirsToRemove)
orderedDirs := connection.orderDirsToRemove(fs, dirsToRemove)
assert.Equal(t, len(dirsToRemove), len(orderedDirs))
dirsToRemove = []objectMapping{
@ -346,7 +347,7 @@ func TestOrderDirsToRemove(t *testing.T) {
virtualPath: "",
},
}
orderedDirs = connection.orderDirsToRemove(dirsToRemove)
orderedDirs = connection.orderDirsToRemove(fs, dirsToRemove)
assert.Equal(t, len(dirsToRemove), len(orderedDirs))
dirsToRemove = []objectMapping{
@ -368,7 +369,7 @@ func TestOrderDirsToRemove(t *testing.T) {
},
}
orderedDirs = connection.orderDirsToRemove(dirsToRemove)
orderedDirs = connection.orderDirsToRemove(fs, dirsToRemove)
if assert.Equal(t, len(dirsToRemove), len(orderedDirs)) {
assert.Equal(t, "dir12", orderedDirs[0].fsPath)
assert.Equal(t, filepath.Join("dir1", "a", "b"), orderedDirs[1].fsPath)
@ -403,30 +404,6 @@ func TestUserInvalidParams(t *testing.T) {
assert.EqualError(t, err, fmt.Sprintf("cannot login user with invalid home dir: %#v", u.HomeDir))
}
u.HomeDir = filepath.Clean(os.TempDir())
subDir := "subdir"
mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
vdirPath1 := "/vdir1"
mappedPath2 := filepath.Join(os.TempDir(), "vdir1", subDir)
vdirPath2 := "/vdir2"
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
MappedPath: mappedPath1,
},
VirtualPath: vdirPath1,
})
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
MappedPath: mappedPath2,
},
VirtualPath: vdirPath2,
})
_, err = server.validateUser(u, req, dataprovider.LoginMethodPassword)
if assert.Error(t, err) {
assert.EqualError(t, err, "overlapping mapped folders are allowed only with quota tracking disabled")
}
req.TLS = &tls.ConnectionState{}
writeLog(req, nil)
}
@ -478,9 +455,9 @@ func TestResolvePathErrors(t *testing.T) {
}
user.Permissions = make(map[string][]string)
user.Permissions["/"] = []string{dataprovider.PermAny}
fs := vfs.NewOsFs("connID", user.HomeDir, nil)
fs := vfs.NewOsFs("connID", user.HomeDir, "")
connection := &Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user),
}
err := connection.Mkdir(ctx, "", os.ModePerm)
@ -509,8 +486,9 @@ func TestResolvePathErrors(t *testing.T) {
}
if runtime.GOOS != "windows" {
connection.User.HomeDir = filepath.Clean(os.TempDir())
connection.Fs = vfs.NewOsFs("connID", connection.User.HomeDir, nil)
user.HomeDir = filepath.Clean(os.TempDir())
connection.User = user
fs := vfs.NewOsFs("connID", connection.User.HomeDir, "")
subDir := "sub"
testTxtFile := "file.txt"
err = os.MkdirAll(filepath.Join(os.TempDir(), subDir, subDir), os.ModePerm)
@ -519,11 +497,13 @@ func TestResolvePathErrors(t *testing.T) {
assert.NoError(t, err)
err = os.Chmod(filepath.Join(os.TempDir(), subDir, subDir), 0001)
assert.NoError(t, err)
err = os.WriteFile(filepath.Join(os.TempDir(), testTxtFile), []byte("test content"), os.ModePerm)
assert.NoError(t, err)
err = connection.Rename(ctx, testTxtFile, path.Join(subDir, subDir, testTxtFile))
if assert.Error(t, err) {
assert.EqualError(t, err, common.ErrPermissionDenied.Error())
}
_, err = connection.putFile(filepath.Join(connection.User.HomeDir, subDir, subDir, testTxtFile),
_, err = connection.putFile(fs, filepath.Join(connection.User.HomeDir, subDir, subDir, testTxtFile),
path.Join(subDir, subDir, testTxtFile))
if assert.Error(t, err) {
assert.EqualError(t, err, common.ErrPermissionDenied.Error())
@ -532,6 +512,8 @@ func TestResolvePathErrors(t *testing.T) {
assert.NoError(t, err)
err = os.RemoveAll(filepath.Join(os.TempDir(), subDir))
assert.NoError(t, err)
err = os.Remove(filepath.Join(os.TempDir(), testTxtFile))
assert.NoError(t, err)
}
}
@ -542,9 +524,9 @@ func TestFileAccessErrors(t *testing.T) {
}
user.Permissions = make(map[string][]string)
user.Permissions["/"] = []string{dataprovider.PermAny}
fs := vfs.NewOsFs("connID", user.HomeDir, nil)
fs := vfs.NewOsFs("connID", user.HomeDir, "")
connection := &Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user),
}
missingPath := "missing path"
fsMissingPath := filepath.Join(user.HomeDir, missingPath)
@ -552,26 +534,26 @@ func TestFileAccessErrors(t *testing.T) {
if assert.Error(t, err) {
assert.EqualError(t, err, os.ErrNotExist.Error())
}
_, err = connection.getFile(fsMissingPath, missingPath)
_, err = connection.getFile(fs, fsMissingPath, missingPath)
if assert.Error(t, err) {
assert.EqualError(t, err, os.ErrNotExist.Error())
}
_, err = connection.getFile(fsMissingPath, missingPath)
_, err = connection.getFile(fs, fsMissingPath, missingPath)
if assert.Error(t, err) {
assert.EqualError(t, err, os.ErrNotExist.Error())
}
p := filepath.Join(user.HomeDir, "adir", missingPath)
_, err = connection.handleUploadToNewFile(p, p, path.Join("adir", missingPath))
_, err = connection.handleUploadToNewFile(fs, p, p, path.Join("adir", missingPath))
if assert.Error(t, err) {
assert.EqualError(t, err, os.ErrNotExist.Error())
}
_, err = connection.handleUploadToExistingFile(p, p, 0, path.Join("adir", missingPath))
_, err = connection.handleUploadToExistingFile(fs, p, p, 0, path.Join("adir", missingPath))
if assert.Error(t, err) {
assert.EqualError(t, err, os.ErrNotExist.Error())
}
connection.Fs = newMockOsFs(nil, false, fs.ConnectionID(), user.HomeDir, nil)
_, err = connection.handleUploadToExistingFile(p, p, 0, path.Join("adir", missingPath))
fs = newMockOsFs(nil, false, fs.ConnectionID(), user.HomeDir, nil)
_, err = connection.handleUploadToExistingFile(fs, p, p, 0, path.Join("adir", missingPath))
if assert.Error(t, err) {
assert.EqualError(t, err, os.ErrNotExist.Error())
}
@ -580,7 +562,7 @@ func TestFileAccessErrors(t *testing.T) {
assert.NoError(t, err)
err = f.Close()
assert.NoError(t, err)
davFile, err := connection.handleUploadToExistingFile(f.Name(), f.Name(), 123, f.Name())
davFile, err := connection.handleUploadToExistingFile(fs, f.Name(), f.Name(), 123, f.Name())
if assert.NoError(t, err) {
transfer := davFile.(*webDavFile)
transfers := connection.GetTransfers()
@ -603,46 +585,46 @@ func TestRemoveDirTree(t *testing.T) {
}
user.Permissions = make(map[string][]string)
user.Permissions["/"] = []string{dataprovider.PermAny}
fs := vfs.NewOsFs("connID", user.HomeDir, nil)
fs := vfs.NewOsFs("connID", user.HomeDir, "")
connection := &Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user),
}
vpath := path.Join("adir", "missing")
p := filepath.Join(user.HomeDir, "adir", "missing")
err := connection.removeDirTree(p, vpath)
err := connection.removeDirTree(fs, p, vpath)
if assert.Error(t, err) {
assert.True(t, os.IsNotExist(err))
}
connection.Fs = newMockOsFs(nil, false, "mockID", user.HomeDir, nil)
err = connection.removeDirTree(p, vpath)
fs = newMockOsFs(nil, false, "mockID", user.HomeDir, nil)
err = connection.removeDirTree(fs, p, vpath)
if assert.Error(t, err) {
assert.True(t, os.IsNotExist(err))
assert.True(t, os.IsNotExist(err), "unexpected error: %v", err)
}
errFake := errors.New("fake err")
connection.Fs = newMockOsFs(errFake, false, "mockID", user.HomeDir, nil)
err = connection.removeDirTree(p, vpath)
fs = newMockOsFs(errFake, false, "mockID", user.HomeDir, nil)
err = connection.removeDirTree(fs, p, vpath)
if assert.Error(t, err) {
assert.EqualError(t, err, errFake.Error())
}
connection.Fs = newMockOsFs(errWalkDir, true, "mockID", user.HomeDir, nil)
err = connection.removeDirTree(p, vpath)
fs = newMockOsFs(errWalkDir, true, "mockID", user.HomeDir, nil)
err = connection.removeDirTree(fs, p, vpath)
if assert.Error(t, err) {
assert.True(t, os.IsNotExist(err))
assert.True(t, os.IsPermission(err), "unexpected error: %v", err)
}
connection.Fs = newMockOsFs(errWalkFile, false, "mockID", user.HomeDir, nil)
err = connection.removeDirTree(p, vpath)
fs = newMockOsFs(errWalkFile, false, "mockID", user.HomeDir, nil)
err = connection.removeDirTree(fs, p, vpath)
if assert.Error(t, err) {
assert.EqualError(t, err, errWalkFile.Error())
}
connection.User.Permissions["/"] = []string{dataprovider.PermListItems}
connection.Fs = newMockOsFs(nil, false, "mockID", user.HomeDir, nil)
err = connection.removeDirTree(p, vpath)
fs = newMockOsFs(nil, false, "mockID", user.HomeDir, nil)
err = connection.removeDirTree(fs, p, vpath)
if assert.Error(t, err) {
assert.EqualError(t, err, common.ErrPermissionDenied.Error())
}
@ -654,9 +636,9 @@ func TestContentType(t *testing.T) {
}
user.Permissions = make(map[string][]string)
user.Permissions["/"] = []string{dataprovider.PermAny}
fs := vfs.NewOsFs("connID", user.HomeDir, nil)
fs := vfs.NewOsFs("connID", user.HomeDir, "")
connection := &Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user),
}
testFilePath := filepath.Join(user.HomeDir, testFile)
ctx := context.Background()
@ -679,7 +661,7 @@ func TestContentType(t *testing.T) {
assert.NoError(t, err)
davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.Fs = vfs.NewOsFs("id", user.HomeDir, nil)
davFile.Fs = vfs.NewOsFs("id", user.HomeDir, "")
fi, err = davFile.Stat()
if assert.NoError(t, err) {
ctype, err := fi.(*webDavFileInfo).ContentType(ctx)
@ -703,9 +685,9 @@ func TestTransferReadWriteErrors(t *testing.T) {
}
user.Permissions = make(map[string][]string)
user.Permissions["/"] = []string{dataprovider.PermAny}
fs := vfs.NewOsFs("connID", user.HomeDir, nil)
fs := vfs.NewOsFs("connID", user.HomeDir, "")
connection := &Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user),
}
testFilePath := filepath.Join(user.HomeDir, testFile)
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
@ -796,9 +778,9 @@ func TestTransferSeek(t *testing.T) {
}
user.Permissions = make(map[string][]string)
user.Permissions["/"] = []string{dataprovider.PermAny}
fs := vfs.NewOsFs("connID", user.HomeDir, nil)
fs := vfs.NewOsFs("connID", user.HomeDir, "")
connection := &Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user),
}
testFilePath := filepath.Join(user.HomeDir, testFile)
testFileContents := []byte("content")
@ -991,6 +973,118 @@ func TestBasicUsersCache(t *testing.T) {
assert.NoError(t, err)
}
func TestCachedUserWithFolders(t *testing.T) {
username := "webdav_internal_folder_test"
password := "dav_pwd"
folderName := "test_folder"
u := dataprovider.User{
Username: username,
Password: password,
HomeDir: filepath.Join(os.TempDir(), username),
Status: 1,
ExpirationDate: 0,
}
u.Permissions = make(map[string][]string)
u.Permissions["/"] = []string{dataprovider.PermAny}
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderName,
MappedPath: filepath.Join(os.TempDir(), folderName),
},
VirtualPath: "/vpath",
})
err := dataprovider.AddUser(&u)
assert.NoError(t, err)
user, err := dataprovider.UserExists(u.Username)
assert.NoError(t, err)
c := &Configuration{
Bindings: []Binding{
{
Port: 9000,
},
},
Cache: Cache{
Users: UsersCacheConfig{
MaxSize: 50,
ExpirationTime: 1,
},
},
}
server := webDavServer{
config: c,
binding: c.Bindings[0],
}
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user.Username), nil)
assert.NoError(t, err)
ipAddr := "127.0.0.1"
_, _, _, _, err = server.authenticate(req, ipAddr) //nolint:dogsled
assert.Error(t, err)
now := time.Now()
req.SetBasicAuth(username, password)
_, isCached, _, loginMethod, err := server.authenticate(req, ipAddr)
assert.NoError(t, err)
assert.False(t, isCached)
assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod)
// now the user should be cached
var cachedUser *dataprovider.CachedUser
result, ok := dataprovider.GetCachedWebDAVUser(username)
if assert.True(t, ok) {
cachedUser = result.(*dataprovider.CachedUser)
assert.False(t, cachedUser.IsExpired())
assert.True(t, cachedUser.Expiration.After(now.Add(time.Duration(c.Cache.Users.ExpirationTime)*time.Minute)))
// authenticate must return the cached user now
authUser, isCached, _, _, err := server.authenticate(req, ipAddr)
assert.NoError(t, err)
assert.True(t, isCached)
assert.Equal(t, cachedUser.User, authUser)
}
folder, err := dataprovider.GetFolderByName(folderName)
assert.NoError(t, err)
// updating a used folder should invalidate the cache
err = dataprovider.UpdateFolder(&folder, folder.Users)
assert.NoError(t, err)
_, isCached, _, loginMethod, err = server.authenticate(req, ipAddr)
assert.NoError(t, err)
assert.False(t, isCached)
assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod)
result, ok = dataprovider.GetCachedWebDAVUser(username)
if assert.True(t, ok) {
cachedUser = result.(*dataprovider.CachedUser)
assert.False(t, cachedUser.IsExpired())
}
err = dataprovider.DeleteFolder(folderName)
assert.NoError(t, err)
// removing a used folder should invalidate the cache
_, isCached, _, loginMethod, err = server.authenticate(req, ipAddr)
assert.NoError(t, err)
assert.False(t, isCached)
assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod)
result, ok = dataprovider.GetCachedWebDAVUser(username)
if assert.True(t, ok) {
cachedUser = result.(*dataprovider.CachedUser)
assert.False(t, cachedUser.IsExpired())
}
err = dataprovider.DeleteUser(user.Username)
assert.NoError(t, err)
_, ok = dataprovider.GetCachedWebDAVUser(username)
assert.False(t, ok)
err = os.RemoveAll(u.GetHomeDir())
assert.NoError(t, err)
err = os.RemoveAll(folder.MappedPath)
assert.NoError(t, err)
}
func TestUsersCacheSizeAndExpiration(t *testing.T) {
username := "webdav_internal_test"
password := "pwd"
@ -1188,6 +1282,65 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
assert.NoError(t, err)
}
func TestUserCacheIsolation(t *testing.T) {
username := "webdav_internal_cache_test"
password := "dav_pwd"
u := dataprovider.User{
Username: username,
Password: password,
HomeDir: filepath.Join(os.TempDir(), username),
Status: 1,
ExpirationDate: 0,
}
u.Permissions = make(map[string][]string)
u.Permissions["/"] = []string{dataprovider.PermAny}
err := dataprovider.AddUser(&u)
assert.NoError(t, err)
user, err := dataprovider.UserExists(u.Username)
assert.NoError(t, err)
cachedUser := &dataprovider.CachedUser{
User: user,
Expiration: time.Now().Add(24 * time.Hour),
Password: password,
LockSystem: webdav.NewMemLS(),
}
cachedUser.User.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("test secret")
err = cachedUser.User.FsConfig.S3Config.AccessSecret.Encrypt()
assert.NoError(t, err)
dataprovider.CacheWebDAVUser(cachedUser, 10)
result, ok := dataprovider.GetCachedWebDAVUser(username)
if assert.True(t, ok) {
cachedUser := result.(*dataprovider.CachedUser).User
_, err = cachedUser.GetFilesystem("")
assert.NoError(t, err)
// the filesystem is now cached
}
result, ok = dataprovider.GetCachedWebDAVUser(username)
if assert.True(t, ok) {
cachedUser := result.(*dataprovider.CachedUser).User
assert.True(t, cachedUser.FsConfig.S3Config.AccessSecret.IsEncrypted())
err = cachedUser.FsConfig.S3Config.AccessSecret.Decrypt()
assert.NoError(t, err)
cachedUser.FsConfig.Provider = vfs.S3FilesystemProvider
_, err = cachedUser.GetFilesystem("")
assert.Error(t, err, "we don't have to get the previously cached filesystem!")
}
result, ok = dataprovider.GetCachedWebDAVUser(username)
if assert.True(t, ok) {
cachedUser := result.(*dataprovider.CachedUser).User
assert.Equal(t, vfs.LocalFilesystemProvider, cachedUser.FsConfig.Provider)
// FIXME: should we really allow to modify the cached users concurrently?????
assert.False(t, cachedUser.FsConfig.S3Config.AccessSecret.IsEncrypted())
}
err = dataprovider.DeleteUser(username)
assert.NoError(t, err)
_, ok = dataprovider.GetCachedWebDAVUser(username)
assert.False(t, ok)
}
func TestRecoverer(t *testing.T) {
c := &Configuration{
Bindings: []Binding{

View file

@ -160,7 +160,7 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
return
}
user, _, lockSystem, loginMethod, err := s.authenticate(r, ipAddr)
user, isCached, lockSystem, loginMethod, err := s.authenticate(r, ipAddr)
if err != nil {
w.Header().Set("WWW-Authenticate", "Basic realm=\"SFTPGo WebDAV\"")
http.Error(w, fmt.Sprintf("Authentication error: %v", err), http.StatusUnauthorized)
@ -174,8 +174,14 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
fs, err := user.GetFilesystem(connectionID)
if !isCached {
err = user.CheckFsRoot(connectionID)
} else {
_, err = user.GetFilesystem(connectionID)
}
if err != nil {
errClose := user.CloseFs()
logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose)
updateLoginMetrics(&user, ipAddr, loginMethod, err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@ -187,7 +193,7 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx = context.WithValue(ctx, requestStartKey, time.Now())
connection := &Connection{
BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolWebDAV, user, fs),
BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolWebDAV, user),
request: r,
}
common.Connections.Add(connection)
@ -273,14 +279,6 @@ func (s *webDavServer) authenticate(r *http.Request, ip string) (dataprovider.Us
cachedUser.Expiration = time.Now().Add(time.Duration(s.config.Cache.Users.ExpirationTime) * time.Minute)
}
dataprovider.CacheWebDAVUser(cachedUser, s.config.Cache.Users.MaxSize)
if user.FsConfig.Provider != dataprovider.SFTPFilesystemProvider {
// for sftp fs check root path does nothing so don't open a useless SFTP connection
tempFs, err := user.GetFilesystem("temp")
if err == nil {
tempFs.CheckRootPath(user.Username, user.UID, user.GID)
tempFs.Close()
}
}
return user, false, lockSystem, loginMethod, nil
}
@ -295,11 +293,11 @@ func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, lo
}
if utils.IsStringInSlice(common.ProtocolWebDAV, user.Filters.DeniedProtocols) {
logger.Debug(logSender, connectionID, "cannot login user %#v, protocol DAV is not allowed", user.Username)
return connID, fmt.Errorf("Protocol DAV is not allowed for user %#v", user.Username)
return connID, fmt.Errorf("protocol DAV is not allowed for user %#v", user.Username)
}
if !user.IsLoginMethodAllowed(loginMethod, nil) {
logger.Debug(logSender, connectionID, "cannot login user %#v, %v login method is not allowed", user.Username, loginMethod)
return connID, fmt.Errorf("Login method %v is not allowed for user %#v", loginMethod, user.Username)
return connID, fmt.Errorf("login method %v is not allowed for user %#v", loginMethod, user.Username)
}
if user.MaxSessions > 0 {
activeSessions := common.Connections.GetActiveSessions(user.Username)
@ -309,14 +307,9 @@ func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, lo
return connID, fmt.Errorf("too many open sessions: %v", activeSessions)
}
}
if dataprovider.GetQuotaTracking() > 0 && user.HasOverlappedMappedPaths() {
logger.Debug(logSender, connectionID, "cannot login user %#v, overlapping mapped folders are allowed only with quota tracking disabled",
user.Username)
return connID, errors.New("overlapping mapped folders are allowed only with quota tracking disabled")
}
if !user.IsLoginFromAddrAllowed(r.RemoteAddr) {
logger.Debug(logSender, connectionID, "cannot login user %#v, remote address is not allowed: %v", user.Username, r.RemoteAddr)
return connID, fmt.Errorf("Login for user %#v is not allowed from this address: %v", user.Username, r.RemoteAddr)
return connID, fmt.Errorf("login for user %#v is not allowed from this address: %v", user.Username, r.RemoteAddr)
}
return connID, nil
}

View file

@ -876,9 +876,9 @@ func TestMaxConnections(t *testing.T) {
client := getWebDavClient(user, true, nil)
assert.NoError(t, checkBasicFunc(client))
// now add a fake connection
fs := vfs.NewOsFs("id", os.TempDir(), nil)
fs := vfs.NewOsFs("id", os.TempDir(), "")
connection := &webdavd.Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user),
}
common.Connections.Add(connection)
assert.Error(t, checkBasicFunc(client))
@ -900,9 +900,9 @@ func TestMaxSessions(t *testing.T) {
client := getWebDavClient(user, false, nil)
assert.NoError(t, checkBasicFunc(client))
// now add a fake connection
fs := vfs.NewOsFs("id", os.TempDir(), nil)
fs := vfs.NewOsFs("id", os.TempDir(), "")
connection := &webdavd.Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user),
}
common.Connections.Add(connection)
assert.Error(t, checkBasicFunc(client))
@ -1308,7 +1308,7 @@ func TestClientClose(t *testing.T) {
func TestLoginWithDatabaseCredentials(t *testing.T) {
u := getTestUser()
u.FsConfig.Provider = dataprovider.GCSFilesystemProvider
u.FsConfig.Provider = vfs.GCSFilesystemProvider
u.FsConfig.GCSConfig.Bucket = "test"
u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(`{ "type": "service_account" }`)
@ -1356,7 +1356,7 @@ func TestLoginWithDatabaseCredentials(t *testing.T) {
func TestLoginInvalidFs(t *testing.T) {
u := getTestUser()
u.FsConfig.Provider = dataprovider.GCSFilesystemProvider
u.FsConfig.Provider = vfs.GCSFilesystemProvider
u.FsConfig.GCSConfig.Bucket = "test"
u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials")
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
@ -2006,6 +2006,106 @@ func TestPreLoginHookWithClientCert(t *testing.T) {
assert.NoError(t, err)
}
func TestNestedVirtualFolders(t *testing.T) {
u := getTestUser()
localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
u = getTestSFTPUser()
mappedPathCrypt := filepath.Join(os.TempDir(), "crypt")
folderNameCrypt := filepath.Base(mappedPathCrypt)
vdirCryptPath := "/vdir/crypt"
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderNameCrypt,
FsConfig: vfs.Filesystem{
Provider: vfs.CryptedFilesystemProvider,
CryptConfig: vfs.CryptFsConfig{
Passphrase: kms.NewPlainSecret(defaultPassword),
},
},
MappedPath: mappedPathCrypt,
},
VirtualPath: vdirCryptPath,
})
mappedPath := filepath.Join(os.TempDir(), "local")
folderName := filepath.Base(mappedPath)
vdirPath := "/vdir/local"
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderName,
MappedPath: mappedPath,
},
VirtualPath: vdirPath,
})
mappedPathNested := filepath.Join(os.TempDir(), "nested")
folderNameNested := filepath.Base(mappedPathNested)
vdirNestedPath := "/vdir/crypt/nested"
u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
Name: folderNameNested,
MappedPath: mappedPathNested,
},
VirtualPath: vdirNestedPath,
QuotaFiles: -1,
QuotaSize: -1,
})
sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client := getWebDavClient(sftpUser, true, nil)
assert.NoError(t, checkBasicFunc(client))
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(65535)
localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
err = createTestFile(testFilePath, testFileSize)
assert.NoError(t, err)
err = uploadFile(testFilePath, testFileName, testFileSize, client)
assert.NoError(t, err)
err = downloadFile(testFileName, localDownloadPath, testFileSize, client)
assert.NoError(t, err)
err = uploadFile(testFilePath, path.Join("/vdir", testFileName), testFileSize, client)
assert.NoError(t, err)
err = downloadFile(path.Join("/vdir", testFileName), localDownloadPath, testFileSize, client)
assert.NoError(t, err)
err = uploadFile(testFilePath, path.Join(vdirPath, testFileName), testFileSize, client)
assert.NoError(t, err)
err = downloadFile(path.Join(vdirPath, testFileName), localDownloadPath, testFileSize, client)
assert.NoError(t, err)
err = uploadFile(testFilePath, path.Join(vdirCryptPath, testFileName), testFileSize, client)
assert.NoError(t, err)
err = downloadFile(path.Join(vdirCryptPath, testFileName), localDownloadPath, testFileSize, client)
assert.NoError(t, err)
err = uploadFile(testFilePath, path.Join(vdirNestedPath, testFileName), testFileSize, client)
assert.NoError(t, err)
err = downloadFile(path.Join(vdirNestedPath, testFileName), localDownloadPath, testFileSize, client)
assert.NoError(t, err)
err = os.Remove(testFilePath)
assert.NoError(t, err)
err = os.Remove(localDownloadPath)
assert.NoError(t, err)
_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
assert.NoError(t, err)
_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameNested}, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(mappedPathCrypt)
assert.NoError(t, err)
err = os.RemoveAll(mappedPath)
assert.NoError(t, err)
err = os.RemoveAll(mappedPathNested)
assert.NoError(t, err)
err = os.RemoveAll(localUser.GetHomeDir())
assert.NoError(t, err)
assert.Len(t, common.Connections.GetStats(), 0)
}
func checkBasicFunc(client *gowebdav.Client) error {
err := client.Connect()
if err != nil {
@ -2125,7 +2225,7 @@ func getTestUser() dataprovider.User {
func getTestSFTPUser() dataprovider.User {
u := getTestUser()
u.Username = u.Username + "_sftp"
u.FsConfig.Provider = dataprovider.SFTPFilesystemProvider
u.FsConfig.Provider = vfs.SFTPFilesystemProvider
u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr
u.FsConfig.SFTPConfig.Username = defaultUsername
u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword)
@ -2134,7 +2234,7 @@ func getTestSFTPUser() dataprovider.User {
func getTestUserWithCryptFs() dataprovider.User {
user := getTestUser()
user.FsConfig.Provider = dataprovider.CryptedFilesystemProvider
user.FsConfig.Provider = vfs.CryptedFilesystemProvider
user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("testPassphrase")
return user
}