Procházet zdrojové kódy

simplify scp upload code and add some test cases

Nicola Murino před 6 roky
rodič
revize
1c5aac0dc4
3 změnil soubory, kde provedl 38 přidání a 21 odebrání
  1. 2 2
      sftpd/internal_test.go
  2. 28 19
      sftpd/scp.go
  3. 8 0
      sftpd/sftpd_test.go

+ 2 - 2
sftpd/internal_test.go

@@ -436,7 +436,7 @@ func TestSCPCommandHandleErrors(t *testing.T) {
 	}
 }
 
-func TestRecursiveDownloadErrors(t *testing.T) {
+func TestSCPRecursiveDownloadErrors(t *testing.T) {
 	connection := Connection{}
 	buf := make([]byte, 65535)
 	stdErrBuf := make([]byte, 65535)
@@ -475,7 +475,7 @@ func TestRecursiveDownloadErrors(t *testing.T) {
 	os.Remove(path)
 }
 
-func TestRecursiveUploadErrors(t *testing.T) {
+func TestSCPRecursiveUploadErrors(t *testing.T) {
 	connection := Connection{}
 	buf := make([]byte, 65535)
 	stdErrBuf := make([]byte, 65535)

+ 28 - 19
sftpd/scp.go

@@ -101,32 +101,16 @@ func (c *scpCommand) handleRecursiveUpload() error {
 			if err != nil {
 				return err
 			}
-			objPath := path.Join(destPath, name)
 			if strings.HasPrefix(command, "D") {
 				numDirs++
-				err = c.handleCreateDir(objPath)
+				destPath = path.Join(destPath, name)
+				err = c.handleCreateDir(destPath)
 				if err != nil {
 					return err
 				}
-				destPath = objPath
 				logger.Debug(logSenderSCP, "received start dir command, num dirs: %v destPath: %v", numDirs, destPath)
 			} else if strings.HasPrefix(command, "C") {
-				// if the upload is not recursive and the destination path does not end with "/"
-				// then this is the wanted filename ...
-				if !c.isRecursive() {
-					if !strings.HasSuffix(destPath, "/") {
-						objPath = destPath
-						// ... but if the requested path is an existing directory then put the uploaded file inside that directory
-						if p, err := c.connection.buildPath(objPath); err == nil {
-							if stat, err := os.Stat(p); err == nil {
-								if stat.IsDir() {
-									objPath = path.Join(destPath, name)
-								}
-							}
-						}
-					}
-				}
-				err = c.handleUpload(objPath, sizeToRead)
+				err = c.handleUpload(c.getFileUploadDestPath(destPath, name), sizeToRead)
 				if err != nil {
 					return err
 				}
@@ -690,6 +674,31 @@ func (c *scpCommand) parseUploadMessage(command string) (int64, string, error) {
 	return size, name, err
 }
 
+func (c *scpCommand) getFileUploadDestPath(scpDestPath, fileName string) string {
+	if !c.isRecursive() {
+		// if the upload is not recursive and the destination path does not end with "/"
+		// then scpDestPath is the wanted filename, for example:
+		// scp fileName.txt user@127.0.0.1:/newFileName.txt
+		// or
+		// scp fileName.txt user@127.0.0.1:/fileName.txt
+		if !strings.HasSuffix(scpDestPath, "/") {
+			// but if scpDestPath is an existing directory then we put the uploaded file
+			// inside that directory this is as scp command works, for example:
+			// scp fileName.txt user@127.0.0.1:/existing_dir
+			if p, err := c.connection.buildPath(scpDestPath); err == nil {
+				if stat, err := os.Stat(p); err == nil {
+					if stat.IsDir() {
+						return path.Join(scpDestPath, fileName)
+					}
+				}
+			}
+			return scpDestPath
+		}
+	}
+	// if the upload is recursive then the destination file is relative to the current scpDestPath
+	return path.Join(scpDestPath, fileName)
+}
+
 func getFileModeAsString(fileMode os.FileMode, isDir bool) string {
 	var defaultMode string
 	if isDir {

+ 8 - 0
sftpd/sftpd_test.go

@@ -275,6 +275,10 @@ func TestDirCommands(t *testing.T) {
 		if err != nil {
 			t.Errorf("error mkdir all: %v", err)
 		}
+		_, err = client.ReadDir("/this/dir/does/not/exist")
+		if err == nil {
+			t.Errorf("reading a missing dir must fail")
+		}
 		testFileName := "/test_file.dat"
 		testFilePath := filepath.Join(homeBasePath, testFileName)
 		testFileSize := int64(65535)
@@ -334,6 +338,10 @@ func TestSymlink(t *testing.T) {
 		if err != nil {
 			t.Errorf("error creating symlink: %v", err)
 		}
+		_, err = client.ReadLink(testFileName + ".link")
+		if err == nil {
+			t.Errorf("readlink is currently not implemented so must fail")
+		}
 		err = client.Symlink(testFileName, testFileName+".link")
 		if err == nil {
 			t.Errorf("creating a symlink to an existing one must fail")