Browse Source

vfs: store root dir

so we don't need to pass it over and over
Nicola Murino 5 years ago
parent
commit
d75f56b914
12 changed files with 65 additions and 57 deletions
  1. 2 2
      README.md
  2. 1 1
      dataprovider/user.go
  3. 1 1
      httpd/api_quota.go
  4. 5 5
      sftpd/handler.go
  5. 13 10
      sftpd/internal_test.go
  6. 4 4
      sftpd/scp.go
  7. 1 1
      sftpd/server.go
  8. 2 1
      sftpd/sftpd_test.go
  9. 4 4
      sftpd/ssh_cmd.go
  10. 22 19
      vfs/osfs.go
  11. 6 6
      vfs/s3fs.go
  12. 4 3
      vfs/vfs.go

+ 2 - 2
README.md

@@ -425,13 +425,13 @@ SFTPGo uses multipart uploads and parallel downloads for storing and retrieving
 Some SFTP commands doesn't work over S3:
 
 - `symlink` and `chtimes` will fail
-- `chown`, `chmod` are silently ignored
+- `chown` and `chmod` are silently ignored
 - upload resume is not supported
 - upload mode `atomic` is ignored since S3 uploads are already atomic
 
 Other notes:
 
-- `rename` is a two steps operation: server-side copy and then deletion. So it is not atomic as for local filesystem
+- `rename` is a two steps operation: server-side copy and then deletion. So it is not atomic as for local filesystem.
 - We don't support renaming non empty directories since we should rename all the contents too and this could take long time: think about directories with thousands of files, for each file we should do an AWS API call.
 - For server side encryption you have to configure the mapped bucket to automatically encrypt objects.
 - A local home directory is still required to store temporary files.

+ 1 - 1
dataprovider/user.go

@@ -114,7 +114,7 @@ func (u *User) GetFilesystem(connectionID string) (vfs.Fs, error) {
 	if u.FsConfig.Provider == 1 {
 		return vfs.NewS3Fs(connectionID, u.GetHomeDir(), u.FsConfig.S3Config)
 	}
-	return vfs.NewOsFs(connectionID), nil
+	return vfs.NewOsFs(connectionID, u.GetHomeDir()), nil
 }
 
 // GetPermissionsForPath returns the permissions for the given path.

+ 1 - 1
httpd/api_quota.go

@@ -40,7 +40,7 @@ func doQuotaScan(user dataprovider.User) error {
 		logger.Warn(logSender, "", "unable scan quota for user %#v error creating filesystem: %v", user.Username, err)
 		return err
 	}
-	numFiles, size, err := fs.ScanDirContents(user.HomeDir)
+	numFiles, size, err := fs.ScanRootDirContents()
 	if err != nil {
 		logger.Warn(logSender, "", "error scanning user home dir %#v: %v", user.HomeDir, err)
 	} else {

+ 5 - 5
sftpd/handler.go

@@ -51,7 +51,7 @@ func (c Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
 		return nil, sftp.ErrSSHFxPermissionDenied
 	}
 
-	p, err := c.fs.ResolvePath(request.Filepath, c.User.GetHomeDir())
+	p, err := c.fs.ResolvePath(request.Filepath)
 	if err != nil {
 		return nil, vfs.GetSFTPError(c.fs, err)
 	}
@@ -97,7 +97,7 @@ func (c Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
 // Filewrite handles the write actions for a file on the system.
 func (c Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) {
 	updateConnectionActivity(c.ID)
-	p, err := c.fs.ResolvePath(request.Filepath, c.User.GetHomeDir())
+	p, err := c.fs.ResolvePath(request.Filepath)
 	if err != nil {
 		return nil, vfs.GetSFTPError(c.fs, err)
 	}
@@ -138,7 +138,7 @@ func (c Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) {
 func (c Connection) Filecmd(request *sftp.Request) error {
 	updateConnectionActivity(c.ID)
 
-	p, err := c.fs.ResolvePath(request.Filepath, c.User.GetHomeDir())
+	p, err := c.fs.ResolvePath(request.Filepath)
 	if err != nil {
 		return vfs.GetSFTPError(c.fs, err)
 	}
@@ -194,7 +194,7 @@ func (c Connection) Filecmd(request *sftp.Request) error {
 // a directory as well as perform file/folder stat calls.
 func (c Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
 	updateConnectionActivity(c.ID)
-	p, err := c.fs.ResolvePath(request.Filepath, c.User.GetHomeDir())
+	p, err := c.fs.ResolvePath(request.Filepath)
 	if err != nil {
 		return nil, vfs.GetSFTPError(c.fs, err)
 	}
@@ -238,7 +238,7 @@ func (c Connection) getSFTPCmdTargetPath(requestTarget string) (string, error) {
 	// location for the server. If it is not, return an error
 	if len(requestTarget) > 0 {
 		var err error
-		target, err = c.fs.ResolvePath(requestTarget, c.User.GetHomeDir())
+		target, err = c.fs.ResolvePath(requestTarget)
 		if err != nil {
 			return target, vfs.GetSFTPError(c.fs, err)
 		}

+ 13 - 10
sftpd/internal_test.go

@@ -65,7 +65,7 @@ func (c *MockChannel) Stderr() io.ReadWriter {
 
 // MockOsFs mockable OsFs
 type MockOsFs struct {
-	vfs.OsFs
+	vfs.Fs
 	err                     error
 	statErr                 error
 	isAtomicUploadSupported bool
@@ -110,8 +110,9 @@ func (fs MockOsFs) Rename(source, target string) error {
 	return os.Rename(source, target)
 }
 
-func newMockOsFs(err, statErr error, atomicUpload bool) vfs.Fs {
+func newMockOsFs(err, statErr error, atomicUpload bool, connectionID, rootDir string) vfs.Fs {
 	return &MockOsFs{
+		Fs:                      vfs.NewOsFs(connectionID, rootDir),
 		err:                     err,
 		statErr:                 statErr,
 		isAtomicUploadSupported: atomicUpload,
@@ -366,7 +367,7 @@ func TestTransferCancelFn(t *testing.T) {
 
 func TestMockFsErrors(t *testing.T) {
 	errFake := errors.New("fake error")
-	fs := newMockOsFs(errFake, errFake, false)
+	fs := newMockOsFs(errFake, errFake, false, "123", os.TempDir())
 	u := dataprovider.User{}
 	u.Username = "test"
 	u.Permissions = make(map[string][]string)
@@ -402,7 +403,7 @@ func TestUploadFiles(t *testing.T) {
 	oldUploadMode := uploadMode
 	uploadMode = uploadModeAtomic
 	c := Connection{
-		fs: vfs.NewOsFs("123"),
+		fs: vfs.NewOsFs("123", os.TempDir()),
 	}
 	var flags sftp.FileOpenFlags
 	flags.Write = true
@@ -434,13 +435,13 @@ func TestWithInvalidHome(t *testing.T) {
 	if err == nil {
 		t.Errorf("login a user with an invalid home_dir must fail")
 	}
+	u.HomeDir = os.TempDir()
 	fs, _ := u.GetFilesystem("123")
 	c := Connection{
 		User: u,
 		fs:   fs,
 	}
-	u.HomeDir = os.TempDir()
-	_, err = c.fs.ResolvePath("../upper_path", u.GetHomeDir())
+	_, err = c.fs.ResolvePath("../upper_path")
 	if err == nil {
 		t.Errorf("tested path is not a home subdir")
 	}
@@ -469,7 +470,7 @@ func TestSFTPCmdTargetPath(t *testing.T) {
 
 func TestGetSFTPErrorFromOSError(t *testing.T) {
 	err := os.ErrNotExist
-	fs := vfs.NewOsFs("")
+	fs := vfs.NewOsFs("", os.TempDir())
 	err = vfs.GetSFTPError(fs, err)
 	if err != sftp.ErrSSHFxNoSuchFile {
 		t.Errorf("unexpected error: %v", err)
@@ -644,6 +645,8 @@ func TestSSHCommandErrors(t *testing.T) {
 	cmd.connection.User.HomeDir = os.TempDir()
 	cmd.connection.User.QuotaFiles = 1
 	cmd.connection.User.UsedQuotaFiles = 2
+	fs, _ = cmd.connection.User.GetFilesystem("123")
+	cmd.connection.fs = fs
 	err = cmd.handle()
 	if err != errQuotaExceeded {
 		t.Errorf("unexpected error: %v", err)
@@ -1175,7 +1178,7 @@ func TestSCPCommandHandleErrors(t *testing.T) {
 
 func TestSCPErrorsMockFs(t *testing.T) {
 	errFake := errors.New("fake error")
-	fs := newMockOsFs(errFake, errFake, false)
+	fs := newMockOsFs(errFake, errFake, false, "123", os.TempDir())
 	u := dataprovider.User{}
 	u.Username = "test"
 	u.Permissions = make(map[string][]string)
@@ -1214,7 +1217,7 @@ func TestSCPErrorsMockFs(t *testing.T) {
 	if err != errFake {
 		t.Errorf("unexpected error: %v", err)
 	}
-	scpCommand.sshCommand.connection.fs = newMockOsFs(errFake, nil, true)
+	scpCommand.sshCommand.connection.fs = newMockOsFs(errFake, nil, true, "123", os.TempDir())
 	err = scpCommand.handleUpload(filepath.Base(testfile), 0)
 	if err != errFake {
 		t.Errorf("unexpected error: %v", err)
@@ -1239,7 +1242,7 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
 	connection := Connection{
 		channel: &mockSSHChannel,
 		netConn: client,
-		fs:      vfs.NewOsFs("123"),
+		fs:      vfs.NewOsFs("123", os.TempDir()),
 	}
 	scpCommand := scpCommand{
 		sshCommand: sshCommand{

+ 4 - 4
sftpd/scp.go

@@ -116,7 +116,7 @@ func (c *scpCommand) handleRecursiveUpload() error {
 
 func (c *scpCommand) handleCreateDir(dirPath string) error {
 	updateConnectionActivity(c.connection.ID)
-	p, err := c.connection.fs.ResolvePath(dirPath, c.connection.User.GetHomeDir())
+	p, err := c.connection.fs.ResolvePath(dirPath)
 	if err != nil {
 		c.connection.Log(logger.LevelWarn, logSenderSCP, "error creating dir: %#v, invalid file path, err: %v", dirPath, err)
 		c.sendErrorMessage(err.Error())
@@ -228,7 +228,7 @@ func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error
 
 	updateConnectionActivity(c.connection.ID)
 
-	p, err := c.connection.fs.ResolvePath(uploadFilePath, c.connection.User.GetHomeDir())
+	p, err := c.connection.fs.ResolvePath(uploadFilePath)
 	if err != nil {
 		c.connection.Log(logger.LevelWarn, logSenderSCP, "error uploading file: %#v, err: %v", uploadFilePath, err)
 		c.sendErrorMessage(err.Error())
@@ -422,7 +422,7 @@ func (c *scpCommand) handleDownload(filePath string) error {
 
 	updateConnectionActivity(c.connection.ID)
 
-	p, err := c.connection.fs.ResolvePath(filePath, c.connection.User.GetHomeDir())
+	p, err := c.connection.fs.ResolvePath(filePath)
 	if err != nil {
 		err := fmt.Errorf("Invalid file path")
 		c.connection.Log(logger.LevelWarn, logSenderSCP, "error downloading file: %#v, invalid file path", filePath)
@@ -674,7 +674,7 @@ func (c *scpCommand) getFileUploadDestPath(scpDestPath, fileName string) string
 			// but if scpDestPath is an existing directory then we put the uploaded file
 			// inside that directory this is as scp command works, for example:
 			// scp fileName.txt user@127.0.0.1:/existing_dir
-			if p, err := c.connection.fs.ResolvePath(scpDestPath, c.connection.User.GetHomeDir()); err == nil {
+			if p, err := c.connection.fs.ResolvePath(scpDestPath); err == nil {
 				if stat, err := c.connection.fs.Stat(p); err == nil {
 					if stat.IsDir() {
 						return path.Join(scpDestPath, fileName)

+ 1 - 1
sftpd/server.go

@@ -286,7 +286,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
 		fs:            fs,
 	}
 
-	connection.fs.CheckRootPath(user.GetHomeDir(), user.Username, user.GetUID(), user.GetGID())
+	connection.fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID())
 
 	connection.Log(logger.LevelInfo, logSender, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
 		user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String())

+ 2 - 1
sftpd/sftpd_test.go

@@ -1515,6 +1515,7 @@ func TestQuotaDisabledError(t *testing.T) {
 func TestMaxSessions(t *testing.T) {
 	usePubKey := false
 	u := getTestUser(usePubKey)
+	u.Username += "1"
 	u.MaxSessions = 1
 	user, _, err := httpd.AddUser(u, http.StatusOK)
 	if err != nil {
@@ -2904,7 +2905,7 @@ func TestRootDirCommands(t *testing.T) {
 func TestRelativePaths(t *testing.T) {
 	user := getTestUser(true)
 	path := filepath.Join(user.HomeDir, "/")
-	fs := vfs.NewOsFs("")
+	fs := vfs.NewOsFs("", user.GetHomeDir())
 	rel := fs.GetRelativePath(path, user.GetHomeDir())
 	if rel != "/" {
 		t.Errorf("Unexpected relative path: %v", rel)

+ 4 - 4
sftpd/ssh_cmd.go

@@ -130,7 +130,7 @@ func (c *sshCommand) handleHashCommands() error {
 		response = fmt.Sprintf("%x  -\n", h.Sum(nil))
 	} else {
 		sshPath := c.getDestPath()
-		fsPath, err := c.connection.fs.ResolvePath(sshPath, c.connection.User.GetHomeDir())
+		fsPath, err := c.connection.fs.ResolvePath(sshPath)
 		if err != nil {
 			return c.sendErrorResponse(err)
 		}
@@ -296,7 +296,7 @@ func (c *sshCommand) getSystemCommand() (systemCommand, error) {
 	if len(c.args) > 0 {
 		var err error
 		sshPath := c.getDestPath()
-		path, err = c.connection.fs.ResolvePath(sshPath, c.connection.User.GetHomeDir())
+		path, err = c.connection.fs.ResolvePath(sshPath)
 		if err != nil {
 			return command, err
 		}
@@ -339,7 +339,7 @@ func (c *sshCommand) rescanHomeDir() error {
 	var numFiles int
 	var size int64
 	if AddQuotaScan(c.connection.User.Username) {
-		numFiles, size, err = c.connection.fs.ScanDirContents(c.connection.User.HomeDir)
+		numFiles, size, err = c.connection.fs.ScanRootDirContents()
 		if err != nil {
 			c.connection.Log(logger.LevelWarn, logSenderSSH, "error scanning user home dir %#v: %v", c.connection.User.HomeDir, err)
 		} else {
@@ -397,7 +397,7 @@ func (c *sshCommand) sendExitStatus(err error) {
 	if err == nil && c.command != "scp" {
 		realPath := c.getDestPath()
 		if len(realPath) > 0 {
-			p, err := c.connection.fs.ResolvePath(realPath, c.connection.User.GetHomeDir())
+			p, err := c.connection.fs.ResolvePath(realPath)
 			if err == nil {
 				realPath = p
 			}

+ 22 - 19
vfs/osfs.go

@@ -22,13 +22,16 @@ const (
 type OsFs struct {
 	name         string
 	connectionID string
+	rootDir      string
 }
 
 // NewOsFs returns an OsFs object that allows to interact with local Os filesystem
-func NewOsFs(connectionID string) Fs {
+func NewOsFs(connectionID, rootDir string) Fs {
 	return &OsFs{
 		name:         osFsName,
-		connectionID: connectionID}
+		connectionID: connectionID,
+		rootDir:      rootDir,
+	}
 }
 
 // Name returns the name for the Fs implementation
@@ -133,28 +136,28 @@ func (OsFs) IsPermission(err error) bool {
 	return os.IsPermission(err)
 }
 
-// CheckRootPath creates the specified root directory if it does not exists
-func (fs OsFs) CheckRootPath(rootPath, username string, uid int, gid int) bool {
+// CheckRootPath creates the root directory if it does not exists
+func (fs OsFs) CheckRootPath(username string, uid int, gid int) bool {
 	var err error
-	if _, err = fs.Stat(rootPath); fs.IsNotExist(err) {
-		err = os.MkdirAll(rootPath, 0777)
+	if _, err = fs.Stat(fs.rootDir); fs.IsNotExist(err) {
+		err = os.MkdirAll(fs.rootDir, 0777)
 		fsLog(fs, logger.LevelDebug, "root directory %#v for user %#v does not exist, try to create, mkdir error: %v",
-			rootPath, username, err)
+			fs.rootDir, username, err)
 		if err == nil {
-			SetPathPermissions(fs, rootPath, uid, gid)
+			SetPathPermissions(fs, fs.rootDir, uid, gid)
 		}
 	}
 	return (err == nil)
 }
 
-// ScanDirContents returns the number of files contained in a directory and
+// ScanRootDirContents returns the number of files contained in a directory and
 // their size
-func (fs OsFs) ScanDirContents(dirPath string) (int, int64, error) {
+func (fs OsFs) ScanRootDirContents() (int, int64, error) {
 	numFiles := 0
 	size := int64(0)
-	isDir, err := IsDirectory(fs, dirPath)
+	isDir, err := IsDirectory(fs, fs.rootDir)
 	if err == nil && isDir {
-		err = filepath.Walk(dirPath, func(path string, info os.FileInfo, err error) error {
+		err = filepath.Walk(fs.rootDir, func(path string, info os.FileInfo, err error) error {
 			if err != nil {
 				return err
 			}
@@ -194,27 +197,27 @@ func (OsFs) Join(elem ...string) string {
 }
 
 // ResolvePath returns the matching filesystem path for the specified sftp path
-func (fs OsFs) ResolvePath(sftpPath, rootPath string) (string, error) {
-	if !filepath.IsAbs(rootPath) {
-		return "", fmt.Errorf("Invalid root path: %v", rootPath)
+func (fs OsFs) ResolvePath(sftpPath string) (string, error) {
+	if !filepath.IsAbs(fs.rootDir) {
+		return "", fmt.Errorf("Invalid root path: %v", fs.rootDir)
 	}
-	r := filepath.Clean(filepath.Join(rootPath, sftpPath))
+	r := filepath.Clean(filepath.Join(fs.rootDir, sftpPath))
 	p, err := filepath.EvalSymlinks(r)
 	if err != nil && !os.IsNotExist(err) {
 		return "", err
 	} else if os.IsNotExist(err) {
 		// The requested path doesn't exist, so at this point we need to iterate up the
 		// path chain until we hit a directory that _does_ exist and can be validated.
-		_, err = fs.findFirstExistingDir(r, rootPath)
+		_, err = fs.findFirstExistingDir(r, fs.rootDir)
 		if err != nil {
 			fsLog(fs, logger.LevelWarn, "error resolving not existent path: %#v", err)
 		}
 		return r, err
 	}
 
-	err = fs.isSubDir(p, rootPath)
+	err = fs.isSubDir(p, fs.rootDir)
 	if err != nil {
-		fsLog(fs, logger.LevelWarn, "Invalid path resolution, dir: %#v outside user home: %#v err: %v", p, rootPath, err)
+		fsLog(fs, logger.LevelWarn, "Invalid path resolution, dir: %#v outside user home: %#v err: %v", p, fs.rootDir, err)
 	}
 	return r, err
 }

+ 6 - 6
vfs/s3fs.go

@@ -372,10 +372,10 @@ func (S3Fs) IsPermission(err error) bool {
 }
 
 // CheckRootPath creates the specified root directory if it does not exists
-func (fs S3Fs) CheckRootPath(rootPath, username string, uid int, gid int) bool {
+func (fs S3Fs) CheckRootPath(username string, uid int, gid int) bool {
 	// we need a local directory for temporary files
-	osFs := NewOsFs(fs.ConnectionID())
-	osFs.CheckRootPath(fs.localTempDir, username, uid, gid)
+	osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir)
+	osFs.CheckRootPath(username, uid, gid)
 	err := fs.checkIfBucketExists()
 	if err == nil {
 		return true
@@ -394,9 +394,9 @@ func (fs S3Fs) CheckRootPath(rootPath, username string, uid int, gid int) bool {
 	return err == nil
 }
 
-// ScanDirContents returns the number of files contained in the bucket,
+// ScanRootDirContents returns the number of files contained in the bucket,
 // and their size
-func (fs S3Fs) ScanDirContents(dirPath string) (int, int64, error) {
+func (fs S3Fs) ScanRootDirContents() (int, int64, error) {
 	numFiles := 0
 	size := int64(0)
 	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxLongTimeout))
@@ -440,7 +440,7 @@ func (S3Fs) Join(elem ...string) string {
 }
 
 // ResolvePath returns the matching filesystem path for the specified sftp path
-func (fs S3Fs) ResolvePath(sftpPath, rootPath string) (string, error) {
+func (fs S3Fs) ResolvePath(sftpPath string) (string, error) {
 	return sftpPath, nil
 }
 

+ 4 - 3
vfs/vfs.go

@@ -1,3 +1,4 @@
+// Package vfs provides local and remote filesystems support
 package vfs
 
 import (
@@ -29,11 +30,11 @@ type Fs interface {
 	ReadDir(dirname string) ([]os.FileInfo, error)
 	IsUploadResumeSupported() bool
 	IsAtomicUploadSupported() bool
-	CheckRootPath(rootPath, username string, uid int, gid int) bool
-	ResolvePath(sftpPath, rootPath string) (string, error)
+	CheckRootPath(username string, uid int, gid int) bool
+	ResolvePath(sftpPath string) (string, error)
 	IsNotExist(err error) bool
 	IsPermission(err error) bool
-	ScanDirContents(dirPath string) (int, int64, error)
+	ScanRootDirContents() (int, int64, error)
 	GetAtomicUploadPath(name string) string
 	GetRelativePath(name, rootPath string) string
 	Join(elem ...string) string