Pārlūkot izejas kodu

copy: use server side copy if available

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Nicola Murino 2 gadi atpakaļ
vecāks
revīzija
93e5cb36df

+ 18 - 4
internal/common/connection.go

@@ -570,6 +570,20 @@ func (c *BaseConnection) copyFile(virtualSourcePath, virtualTargetPath string, s
 	if ok, _ := c.User.IsFileAllowed(virtualTargetPath); !ok {
 	if ok, _ := c.User.IsFileAllowed(virtualTargetPath); !ok {
 		return fmt.Errorf("file %q is not allowed: %w", virtualTargetPath, c.GetPermissionDeniedError())
 		return fmt.Errorf("file %q is not allowed: %w", virtualTargetPath, c.GetPermissionDeniedError())
 	}
 	}
+	if c.isSameResource(virtualSourcePath, virtualSourcePath) {
+		fs, fsTargetPath, err := c.GetFsAndResolvedPath(virtualTargetPath)
+		if err != nil {
+			return err
+		}
+		if copier, ok := fs.(vfs.FsFileCopier); ok {
+			_, fsSourcePath, err := c.GetFsAndResolvedPath(virtualSourcePath)
+			if err != nil {
+				return err
+			}
+			return copier.CopyFile(fsSourcePath, fsTargetPath, srcSize)
+		}
+	}
+
 	reader, rCancelFn, err := getFileReader(c, virtualSourcePath)
 	reader, rCancelFn, err := getFileReader(c, virtualSourcePath)
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("unable to get reader for path %q: %w", virtualSourcePath, err)
 		return fmt.Errorf("unable to get reader for path %q: %w", virtualSourcePath, err)
@@ -835,7 +849,7 @@ func (c *BaseConnection) doStatInternal(virtualPath string, mode int, checkFileP
 
 
 	fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath)
 	fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath)
 	if err != nil {
 	if err != nil {
-		return info, err
+		return nil, err
 	}
 	}
 
 
 	if mode == 1 {
 	if mode == 1 {
@@ -847,7 +861,7 @@ func (c *BaseConnection) doStatInternal(virtualPath string, mode int, checkFileP
 		if !fs.IsNotExist(err) {
 		if !fs.IsNotExist(err) {
 			c.Log(logger.LevelWarn, "stat error for path %q: %+v", virtualPath, err)
 			c.Log(logger.LevelWarn, "stat error for path %q: %+v", virtualPath, err)
 		}
 		}
-		return info, c.GetFsError(fs, err)
+		return nil, c.GetFsError(fs, err)
 	}
 	}
 	if convertResult && vfs.IsCryptOsFs(fs) {
 	if convertResult && vfs.IsCryptOsFs(fs) {
 		info = fs.(*vfs.CryptFs).ConvertFileInfo(info)
 		info = fs.(*vfs.CryptFs).ConvertFileInfo(info)
@@ -1108,7 +1122,7 @@ func (c *BaseConnection) checkFolderRename(fsSrc, fsDst vfs.Fs, fsSourcePath, fs
 func (c *BaseConnection) isRenamePermitted(fsSrc, fsDst vfs.Fs, fsSourcePath, fsTargetPath, virtualSourcePath,
 func (c *BaseConnection) isRenamePermitted(fsSrc, fsDst vfs.Fs, fsSourcePath, fsTargetPath, virtualSourcePath,
 	virtualTargetPath string, fi os.FileInfo,
 	virtualTargetPath string, fi os.FileInfo,
 ) bool {
 ) bool {
-	if !c.isSameResourceRename(virtualSourcePath, virtualTargetPath) {
+	if !c.isSameResource(virtualSourcePath, virtualTargetPath) {
 		c.Log(logger.LevelInfo, "rename %#q->%q is not allowed: the paths must be on the same resource",
 		c.Log(logger.LevelInfo, "rename %#q->%q is not allowed: the paths must be on the same resource",
 			virtualSourcePath, virtualTargetPath)
 			virtualSourcePath, virtualTargetPath)
 		return false
 		return false
@@ -1359,7 +1373,7 @@ func (c *BaseConnection) HasSpace(checkFiles, getUsage bool, requestPath string)
 	return result, transferQuota
 	return result, transferQuota
 }
 }
 
 
-func (c *BaseConnection) isSameResourceRename(virtualSourcePath, virtualTargetPath string) bool {
+func (c *BaseConnection) isSameResource(virtualSourcePath, virtualTargetPath string) bool {
 	sourceFolder, errSrc := c.User.GetVirtualFolderForPath(virtualSourcePath)
 	sourceFolder, errSrc := c.User.GetVirtualFolderForPath(virtualSourcePath)
 	dstFolder, errDst := c.User.GetVirtualFolderForPath(virtualTargetPath)
 	dstFolder, errDst := c.User.GetVirtualFolderForPath(virtualTargetPath)
 	if errSrc != nil && errDst != nil {
 	if errSrc != nil && errDst != nil {

+ 18 - 0
internal/common/connection_test.go

@@ -627,3 +627,21 @@ func TestConnectionKeepAlive(t *testing.T) {
 	keepConnectionAlive(conn, done, 50*time.Millisecond)
 	keepConnectionAlive(conn, done, 50*time.Millisecond)
 	assert.Greater(t, conn.GetLastActivity(), lastActivity)
 	assert.Greater(t, conn.GetLastActivity(), lastActivity)
 }
 }
+
+func TestFsFileCopier(t *testing.T) {
+	fs := vfs.Fs(&vfs.AzureBlobFs{})
+	_, ok := fs.(vfs.FsFileCopier)
+	assert.True(t, ok)
+	fs = vfs.Fs(&vfs.OsFs{})
+	_, ok = fs.(vfs.FsFileCopier)
+	assert.False(t, ok)
+	fs = vfs.Fs(&vfs.SFTPFs{})
+	_, ok = fs.(vfs.FsFileCopier)
+	assert.False(t, ok)
+	fs = vfs.Fs(&vfs.GCSFs{})
+	_, ok = fs.(vfs.FsFileCopier)
+	assert.True(t, ok)
+	fs = vfs.Fs(&vfs.S3Fs{})
+	_, ok = fs.(vfs.FsFileCopier)
+	assert.True(t, ok)
+}

+ 1 - 1
internal/sftpd/scp.go

@@ -174,7 +174,7 @@ func (c *scpCommand) handleCreateDir(fs vfs.Fs, dirPath string) error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	c.connection.Log(logger.LevelDebug, "created dir %#v", dirPath)
+	c.connection.Log(logger.LevelDebug, "created dir %q", dirPath)
 	return nil
 	return nil
 }
 }
 
 

+ 5 - 0
internal/vfs/azblobfs.go

@@ -709,6 +709,11 @@ func (fs *AzureBlobFs) ResolvePath(virtualPath string) (string, error) {
 	return fs.Join(fs.config.KeyPrefix, strings.TrimPrefix(virtualPath, "/")), nil
 	return fs.Join(fs.config.KeyPrefix, strings.TrimPrefix(virtualPath, "/")), nil
 }
 }
 
 
+// CopyFile implements the FsFileCopier interface
+func (fs *AzureBlobFs) CopyFile(source, target string, srcSize int64) error {
+	return fs.copyFileInternal(source, target)
+}
+
 func (fs *AzureBlobFs) headObject(name string) (blob.GetPropertiesResponse, error) {
 func (fs *AzureBlobFs) headObject(name string) (blob.GetPropertiesResponse, error) {
 	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
 	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
 	defer cancelFn()
 	defer cancelFn()

+ 5 - 0
internal/vfs/gcsfs.go

@@ -687,6 +687,11 @@ func (fs *GCSFs) ResolvePath(virtualPath string) (string, error) {
 	return fs.Join(fs.config.KeyPrefix, strings.TrimPrefix(virtualPath, "/")), nil
 	return fs.Join(fs.config.KeyPrefix, strings.TrimPrefix(virtualPath, "/")), nil
 }
 }
 
 
+// CopyFile implements the FsFileCopier interface
+func (fs *GCSFs) CopyFile(source, target string, srcSize int64) error {
+	return fs.copyFileInternal(source, target)
+}
+
 func (fs *GCSFs) resolve(name, prefix, contentType string) (string, bool) {
 func (fs *GCSFs) resolve(name, prefix, contentType string) (string, bool) {
 	result := strings.TrimPrefix(name, prefix)
 	result := strings.TrimPrefix(name, prefix)
 	isDir := strings.HasSuffix(result, "/")
 	isDir := strings.HasSuffix(result, "/")

+ 5 - 0
internal/vfs/s3fs.go

@@ -680,6 +680,11 @@ func (fs *S3Fs) ResolvePath(virtualPath string) (string, error) {
 	return fs.Join(fs.config.KeyPrefix, strings.TrimPrefix(virtualPath, "/")), nil
 	return fs.Join(fs.config.KeyPrefix, strings.TrimPrefix(virtualPath, "/")), nil
 }
 }
 
 
+// CopyFile implements the FsFileCopier interface
+func (fs *S3Fs) CopyFile(source, target string, srcSize int64) error {
+	return fs.copyFileInternal(source, target, srcSize)
+}
+
 func (fs *S3Fs) resolve(name *string, prefix string) (string, bool) {
 func (fs *S3Fs) resolve(name *string, prefix string) (string, bool) {
 	result := strings.TrimPrefix(util.GetStringFromPointer(name), prefix)
 	result := strings.TrimPrefix(util.GetStringFromPointer(name), prefix)
 	isDir := strings.HasSuffix(result, "/")
 	isDir := strings.HasSuffix(result, "/")

+ 6 - 0
internal/vfs/vfs.go

@@ -133,6 +133,12 @@ type fsMetadataChecker interface {
 	getFileNamesInPrefix(fsPrefix string) (map[string]bool, error)
 	getFileNamesInPrefix(fsPrefix string) (map[string]bool, error)
 }
 }
 
 
+// FsFileCopier is a Fs that implements the CopyFile method.
+type FsFileCopier interface {
+	Fs
+	CopyFile(source, target string, srcSize int64) error
+}
+
 // File defines an interface representing a SFTPGo file
 // File defines an interface representing a SFTPGo file
 type File interface {
 type File interface {
 	io.Reader
 	io.Reader