mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-22 07:30:25 +00:00
simplify scp upload code and add some test cases
This commit is contained in:
parent
e50c521c33
commit
1c5aac0dc4
3 changed files with 38 additions and 21 deletions
|
@ -436,7 +436,7 @@ func TestSCPCommandHandleErrors(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRecursiveDownloadErrors(t *testing.T) {
|
||||
func TestSCPRecursiveDownloadErrors(t *testing.T) {
|
||||
connection := Connection{}
|
||||
buf := make([]byte, 65535)
|
||||
stdErrBuf := make([]byte, 65535)
|
||||
|
@ -475,7 +475,7 @@ func TestRecursiveDownloadErrors(t *testing.T) {
|
|||
os.Remove(path)
|
||||
}
|
||||
|
||||
func TestRecursiveUploadErrors(t *testing.T) {
|
||||
func TestSCPRecursiveUploadErrors(t *testing.T) {
|
||||
connection := Connection{}
|
||||
buf := make([]byte, 65535)
|
||||
stdErrBuf := make([]byte, 65535)
|
||||
|
|
47
sftpd/scp.go
47
sftpd/scp.go
|
@ -101,32 +101,16 @@ func (c *scpCommand) handleRecursiveUpload() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
objPath := path.Join(destPath, name)
|
||||
if strings.HasPrefix(command, "D") {
|
||||
numDirs++
|
||||
err = c.handleCreateDir(objPath)
|
||||
destPath = path.Join(destPath, name)
|
||||
err = c.handleCreateDir(destPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
destPath = objPath
|
||||
logger.Debug(logSenderSCP, "received start dir command, num dirs: %v destPath: %v", numDirs, destPath)
|
||||
} else if strings.HasPrefix(command, "C") {
|
||||
// if the upload is not recursive and the destination path does not end with "/"
|
||||
// then this is the wanted filename ...
|
||||
if !c.isRecursive() {
|
||||
if !strings.HasSuffix(destPath, "/") {
|
||||
objPath = destPath
|
||||
// ... but if the requested path is an existing directory then put the uploaded file inside that directory
|
||||
if p, err := c.connection.buildPath(objPath); err == nil {
|
||||
if stat, err := os.Stat(p); err == nil {
|
||||
if stat.IsDir() {
|
||||
objPath = path.Join(destPath, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
err = c.handleUpload(objPath, sizeToRead)
|
||||
err = c.handleUpload(c.getFileUploadDestPath(destPath, name), sizeToRead)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -690,6 +674,31 @@ func (c *scpCommand) parseUploadMessage(command string) (int64, string, error) {
|
|||
return size, name, err
|
||||
}
|
||||
|
||||
func (c *scpCommand) getFileUploadDestPath(scpDestPath, fileName string) string {
|
||||
if !c.isRecursive() {
|
||||
// if the upload is not recursive and the destination path does not end with "/"
|
||||
// then scpDestPath is the wanted filename, for example:
|
||||
// scp fileName.txt user@127.0.0.1:/newFileName.txt
|
||||
// or
|
||||
// scp fileName.txt user@127.0.0.1:/fileName.txt
|
||||
if !strings.HasSuffix(scpDestPath, "/") {
|
||||
// but if scpDestPath is an existing directory then we put the uploaded file
|
||||
// inside that directory this is as scp command works, for example:
|
||||
// scp fileName.txt user@127.0.0.1:/existing_dir
|
||||
if p, err := c.connection.buildPath(scpDestPath); err == nil {
|
||||
if stat, err := os.Stat(p); err == nil {
|
||||
if stat.IsDir() {
|
||||
return path.Join(scpDestPath, fileName)
|
||||
}
|
||||
}
|
||||
}
|
||||
return scpDestPath
|
||||
}
|
||||
}
|
||||
// if the upload is recursive then the destination file is relative to the current scpDestPath
|
||||
return path.Join(scpDestPath, fileName)
|
||||
}
|
||||
|
||||
func getFileModeAsString(fileMode os.FileMode, isDir bool) string {
|
||||
var defaultMode string
|
||||
if isDir {
|
||||
|
|
|
@ -275,6 +275,10 @@ func TestDirCommands(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Errorf("error mkdir all: %v", err)
|
||||
}
|
||||
_, err = client.ReadDir("/this/dir/does/not/exist")
|
||||
if err == nil {
|
||||
t.Errorf("reading a missing dir must fail")
|
||||
}
|
||||
testFileName := "/test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(65535)
|
||||
|
@ -334,6 +338,10 @@ func TestSymlink(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Errorf("error creating symlink: %v", err)
|
||||
}
|
||||
_, err = client.ReadLink(testFileName + ".link")
|
||||
if err == nil {
|
||||
t.Errorf("readlink is currently not implemented so must fail")
|
||||
}
|
||||
err = client.Symlink(testFileName, testFileName+".link")
|
||||
if err == nil {
|
||||
t.Errorf("creating a symlink to an existing one must fail")
|
||||
|
|
Loading…
Reference in a new issue