simplify scp upload code and add some test cases

This commit is contained in:
Nicola Murino 2019-08-24 22:44:01 +02:00
parent e50c521c33
commit 1c5aac0dc4
3 changed files with 38 additions and 21 deletions

View file

@ -436,7 +436,7 @@ func TestSCPCommandHandleErrors(t *testing.T) {
}
}
func TestRecursiveDownloadErrors(t *testing.T) {
func TestSCPRecursiveDownloadErrors(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
@ -475,7 +475,7 @@ func TestRecursiveDownloadErrors(t *testing.T) {
os.Remove(path)
}
func TestRecursiveUploadErrors(t *testing.T) {
func TestSCPRecursiveUploadErrors(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)

View file

@ -101,32 +101,16 @@ func (c *scpCommand) handleRecursiveUpload() error {
if err != nil {
return err
}
objPath := path.Join(destPath, name)
if strings.HasPrefix(command, "D") {
numDirs++
err = c.handleCreateDir(objPath)
destPath = path.Join(destPath, name)
err = c.handleCreateDir(destPath)
if err != nil {
return err
}
destPath = objPath
logger.Debug(logSenderSCP, "received start dir command, num dirs: %v destPath: %v", numDirs, destPath)
} else if strings.HasPrefix(command, "C") {
// if the upload is not recursive and the destination path does not end with "/"
// then this is the wanted filename ...
if !c.isRecursive() {
if !strings.HasSuffix(destPath, "/") {
objPath = destPath
// ... but if the requested path is an existing directory then put the uploaded file inside that directory
if p, err := c.connection.buildPath(objPath); err == nil {
if stat, err := os.Stat(p); err == nil {
if stat.IsDir() {
objPath = path.Join(destPath, name)
}
}
}
}
}
err = c.handleUpload(objPath, sizeToRead)
err = c.handleUpload(c.getFileUploadDestPath(destPath, name), sizeToRead)
if err != nil {
return err
}
@ -690,6 +674,31 @@ func (c *scpCommand) parseUploadMessage(command string) (int64, string, error) {
return size, name, err
}
func (c *scpCommand) getFileUploadDestPath(scpDestPath, fileName string) string {
if !c.isRecursive() {
// if the upload is not recursive and the destination path does not end with "/"
// then scpDestPath is the wanted filename, for example:
// scp fileName.txt user@127.0.0.1:/newFileName.txt
// or
// scp fileName.txt user@127.0.0.1:/fileName.txt
if !strings.HasSuffix(scpDestPath, "/") {
// but if scpDestPath is an existing directory then we put the uploaded file
// inside that directory this is as scp command works, for example:
// scp fileName.txt user@127.0.0.1:/existing_dir
if p, err := c.connection.buildPath(scpDestPath); err == nil {
if stat, err := os.Stat(p); err == nil {
if stat.IsDir() {
return path.Join(scpDestPath, fileName)
}
}
}
return scpDestPath
}
}
// if the upload is recursive then the destination file is relative to the current scpDestPath
return path.Join(scpDestPath, fileName)
}
func getFileModeAsString(fileMode os.FileMode, isDir bool) string {
var defaultMode string
if isDir {

View file

@ -275,6 +275,10 @@ func TestDirCommands(t *testing.T) {
if err != nil {
t.Errorf("error mkdir all: %v", err)
}
_, err = client.ReadDir("/this/dir/does/not/exist")
if err == nil {
t.Errorf("reading a missing dir must fail")
}
testFileName := "/test_file.dat"
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(65535)
@ -334,6 +338,10 @@ func TestSymlink(t *testing.T) {
if err != nil {
t.Errorf("error creating symlink: %v", err)
}
_, err = client.ReadLink(testFileName + ".link")
if err == nil {
t.Errorf("readlink is currently not implemented so must fail")
}
err = client.Symlink(testFileName, testFileName+".link")
if err == nil {
t.Errorf("creating a symlink to an existing one must fail")