diff --git a/internal/common/connection.go b/internal/common/connection.go index 60f96b14..9795289c 100644 --- a/internal/common/connection.go +++ b/internal/common/connection.go @@ -570,6 +570,20 @@ func (c *BaseConnection) copyFile(virtualSourcePath, virtualTargetPath string, s if ok, _ := c.User.IsFileAllowed(virtualTargetPath); !ok { return fmt.Errorf("file %q is not allowed: %w", virtualTargetPath, c.GetPermissionDeniedError()) } + if c.isSameResource(virtualSourcePath, virtualSourcePath) { + fs, fsTargetPath, err := c.GetFsAndResolvedPath(virtualTargetPath) + if err != nil { + return err + } + if copier, ok := fs.(vfs.FsFileCopier); ok { + _, fsSourcePath, err := c.GetFsAndResolvedPath(virtualSourcePath) + if err != nil { + return err + } + return copier.CopyFile(fsSourcePath, fsTargetPath, srcSize) + } + } + reader, rCancelFn, err := getFileReader(c, virtualSourcePath) if err != nil { return fmt.Errorf("unable to get reader for path %q: %w", virtualSourcePath, err) @@ -835,7 +849,7 @@ func (c *BaseConnection) doStatInternal(virtualPath string, mode int, checkFileP fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath) if err != nil { - return info, err + return nil, err } if mode == 1 { @@ -847,7 +861,7 @@ func (c *BaseConnection) doStatInternal(virtualPath string, mode int, checkFileP if !fs.IsNotExist(err) { c.Log(logger.LevelWarn, "stat error for path %q: %+v", virtualPath, err) } - return info, c.GetFsError(fs, err) + return nil, c.GetFsError(fs, err) } if convertResult && vfs.IsCryptOsFs(fs) { info = fs.(*vfs.CryptFs).ConvertFileInfo(info) @@ -1108,7 +1122,7 @@ func (c *BaseConnection) checkFolderRename(fsSrc, fsDst vfs.Fs, fsSourcePath, fs func (c *BaseConnection) isRenamePermitted(fsSrc, fsDst vfs.Fs, fsSourcePath, fsTargetPath, virtualSourcePath, virtualTargetPath string, fi os.FileInfo, ) bool { - if !c.isSameResourceRename(virtualSourcePath, virtualTargetPath) { + if !c.isSameResource(virtualSourcePath, virtualTargetPath) { c.Log(logger.LevelInfo, "rename %#q->%q is not allowed: the paths must be on the same resource", virtualSourcePath, virtualTargetPath) return false @@ -1359,7 +1373,7 @@ func (c *BaseConnection) HasSpace(checkFiles, getUsage bool, requestPath string) return result, transferQuota } -func (c *BaseConnection) isSameResourceRename(virtualSourcePath, virtualTargetPath string) bool { +func (c *BaseConnection) isSameResource(virtualSourcePath, virtualTargetPath string) bool { sourceFolder, errSrc := c.User.GetVirtualFolderForPath(virtualSourcePath) dstFolder, errDst := c.User.GetVirtualFolderForPath(virtualTargetPath) if errSrc != nil && errDst != nil { diff --git a/internal/common/connection_test.go b/internal/common/connection_test.go index e183c7f7..68e711cb 100644 --- a/internal/common/connection_test.go +++ b/internal/common/connection_test.go @@ -627,3 +627,21 @@ func TestConnectionKeepAlive(t *testing.T) { keepConnectionAlive(conn, done, 50*time.Millisecond) assert.Greater(t, conn.GetLastActivity(), lastActivity) } + +func TestFsFileCopier(t *testing.T) { + fs := vfs.Fs(&vfs.AzureBlobFs{}) + _, ok := fs.(vfs.FsFileCopier) + assert.True(t, ok) + fs = vfs.Fs(&vfs.OsFs{}) + _, ok = fs.(vfs.FsFileCopier) + assert.False(t, ok) + fs = vfs.Fs(&vfs.SFTPFs{}) + _, ok = fs.(vfs.FsFileCopier) + assert.False(t, ok) + fs = vfs.Fs(&vfs.GCSFs{}) + _, ok = fs.(vfs.FsFileCopier) + assert.True(t, ok) + fs = vfs.Fs(&vfs.S3Fs{}) + _, ok = fs.(vfs.FsFileCopier) + assert.True(t, ok) +} diff --git a/internal/sftpd/scp.go b/internal/sftpd/scp.go index cd241e73..9ed12727 100644 --- a/internal/sftpd/scp.go +++ b/internal/sftpd/scp.go @@ -174,7 +174,7 @@ func (c *scpCommand) handleCreateDir(fs vfs.Fs, dirPath string) error { if err != nil { return err } - c.connection.Log(logger.LevelDebug, "created dir %#v", dirPath) + c.connection.Log(logger.LevelDebug, "created dir %q", dirPath) return nil } diff --git a/internal/vfs/azblobfs.go b/internal/vfs/azblobfs.go index 02491f66..95bed93d 100644 --- a/internal/vfs/azblobfs.go +++ b/internal/vfs/azblobfs.go @@ -709,6 +709,11 @@ func (fs *AzureBlobFs) ResolvePath(virtualPath string) (string, error) { return fs.Join(fs.config.KeyPrefix, strings.TrimPrefix(virtualPath, "/")), nil } +// CopyFile implements the FsFileCopier interface +func (fs *AzureBlobFs) CopyFile(source, target string, srcSize int64) error { + return fs.copyFileInternal(source, target) +} + func (fs *AzureBlobFs) headObject(name string) (blob.GetPropertiesResponse, error) { ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) defer cancelFn() diff --git a/internal/vfs/gcsfs.go b/internal/vfs/gcsfs.go index 7b5c39d1..f9d08397 100644 --- a/internal/vfs/gcsfs.go +++ b/internal/vfs/gcsfs.go @@ -687,6 +687,11 @@ func (fs *GCSFs) ResolvePath(virtualPath string) (string, error) { return fs.Join(fs.config.KeyPrefix, strings.TrimPrefix(virtualPath, "/")), nil } +// CopyFile implements the FsFileCopier interface +func (fs *GCSFs) CopyFile(source, target string, srcSize int64) error { + return fs.copyFileInternal(source, target) +} + func (fs *GCSFs) resolve(name, prefix, contentType string) (string, bool) { result := strings.TrimPrefix(name, prefix) isDir := strings.HasSuffix(result, "/") diff --git a/internal/vfs/s3fs.go b/internal/vfs/s3fs.go index 27326018..d4b97527 100644 --- a/internal/vfs/s3fs.go +++ b/internal/vfs/s3fs.go @@ -680,6 +680,11 @@ func (fs *S3Fs) ResolvePath(virtualPath string) (string, error) { return fs.Join(fs.config.KeyPrefix, strings.TrimPrefix(virtualPath, "/")), nil } +// CopyFile implements the FsFileCopier interface +func (fs *S3Fs) CopyFile(source, target string, srcSize int64) error { + return fs.copyFileInternal(source, target, srcSize) +} + func (fs *S3Fs) resolve(name *string, prefix string) (string, bool) { result := strings.TrimPrefix(util.GetStringFromPointer(name), prefix) isDir := strings.HasSuffix(result, "/") diff --git a/internal/vfs/vfs.go b/internal/vfs/vfs.go index e80594be..c14fa2f1 100644 --- a/internal/vfs/vfs.go +++ b/internal/vfs/vfs.go @@ -133,6 +133,12 @@ type fsMetadataChecker interface { getFileNamesInPrefix(fsPrefix string) (map[string]bool, error) } +// FsFileCopier is a Fs that implements the CopyFile method. +type FsFileCopier interface { + Fs + CopyFile(source, target string, srcSize int64) error +} + // File defines an interface representing a SFTPGo file type File interface { io.Reader