Nicola Murino 4 лет назад
Родитель
Сommit
c0e09374a8
3 измененных файлов с 91 добавлено и 20 удалено
  1. 72 0
      sftpd/internal_test.go
  2. 17 19
      sftpd/scp.go
  3. 2 1
      sftpd/ssh_cmd.go

+ 72 - 0
sftpd/internal_test.go

@@ -1069,6 +1069,78 @@ func TestSCPFileMode(t *testing.T) {
 	assert.Equal(t, "1044", mode)
 }
 
+func TestSCPUploadError(t *testing.T) {
+	buf := make([]byte, 65535)
+	stdErrBuf := make([]byte, 65535)
+	writeErr := fmt.Errorf("test write error")
+	mockSSHChannel := MockChannel{
+		Buffer:       bytes.NewBuffer(buf),
+		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
+		ReadError:    nil,
+		WriteError:   writeErr,
+	}
+	user := dataprovider.User{
+		HomeDir:     filepath.Join(os.TempDir()),
+		Permissions: make(map[string][]string),
+	}
+	user.Permissions["/"] = []string{dataprovider.PermAny}
+	fs := vfs.NewOsFs("", user.HomeDir, nil)
+
+	connection := &Connection{
+		BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
+		channel:        &mockSSHChannel,
+	}
+	scpCommand := scpCommand{
+		sshCommand: sshCommand{
+			command:    "scp",
+			connection: connection,
+			args:       []string{"-t", "/"},
+		},
+	}
+	err := scpCommand.handle()
+	assert.EqualError(t, err, writeErr.Error())
+
+	mockSSHChannel = MockChannel{
+		Buffer:       bytes.NewBuffer([]byte("D0755 0 testdir\n")),
+		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
+		ReadError:    nil,
+		WriteError:   writeErr,
+	}
+	err = scpCommand.handleRecursiveUpload()
+	assert.EqualError(t, err, writeErr.Error())
+
+	mockSSHChannel = MockChannel{
+		Buffer:       bytes.NewBuffer([]byte("D0755 a testdir\n")),
+		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
+		ReadError:    nil,
+		WriteError:   nil,
+	}
+	err = scpCommand.handleRecursiveUpload()
+	assert.Error(t, err)
+}
+
+func TestSCPInvalidEndDir(t *testing.T) {
+	stdErrBuf := make([]byte, 65535)
+	mockSSHChannel := MockChannel{
+		Buffer:       bytes.NewBuffer([]byte("E\n")),
+		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
+	}
+	fs := vfs.NewOsFs("", os.TempDir(), nil)
+	connection := &Connection{
+		BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, dataprovider.User{}, fs),
+		channel:        &mockSSHChannel,
+	}
+	scpCommand := scpCommand{
+		sshCommand: sshCommand{
+			command:    "scp",
+			connection: connection,
+			args:       []string{"-t", "/tmp"},
+		},
+	}
+	err := scpCommand.handleRecursiveUpload()
+	assert.EqualError(t, err, "unacceptable end dir command")
+}
+
 func TestSCPParseUploadMessage(t *testing.T) {
 	buf := make([]byte, 65535)
 	stdErrBuf := make([]byte, 65535)

+ 17 - 19
sftpd/scp.go

@@ -1,6 +1,7 @@
 package sftpd
 
 import (
+	"errors"
 	"fmt"
 	"io"
 	"math"
@@ -45,6 +46,10 @@ func (c *scpCommand) handle() (err error) {
 		c.args, c.connection.User.Username, commandType, destPath)
 	if commandType == "-t" {
 		// -t means "to", so upload
+		err = c.sendConfirmationMessage()
+		if err != nil {
+			return err
+		}
 		err = c.handleRecursiveUpload()
 		if err != nil {
 			return err
@@ -68,31 +73,24 @@ func (c *scpCommand) handle() (err error) {
 }
 
 func (c *scpCommand) handleRecursiveUpload() error {
-	var err error
 	numDirs := 0
 	destPath := c.getDestPath()
 	for {
-		err = c.sendConfirmationMessage()
-		if err != nil {
-			return err
-		}
 		command, err := c.getNextUploadProtocolMessage()
 		if err != nil {
+			if errors.Is(err, io.EOF) {
+				return nil
+			}
 			return err
 		}
 		if strings.HasPrefix(command, "E") {
 			numDirs--
 			c.connection.Log(logger.LevelDebug, "received end dir command, num dirs: %v", numDirs)
-			if numDirs == 0 {
-				// upload is now complete send confirmation message
-				err = c.sendConfirmationMessage()
-				if err != nil {
-					return err
-				}
-			} else {
-				// the destination dir is now the parent directory
-				destPath = path.Join(destPath, "..")
+			if numDirs < 0 {
+				return errors.New("unacceptable end dir command")
 			}
+			// the destination dir is now the parent directory
+			destPath = path.Join(destPath, "..")
 		} else {
 			sizeToRead, name, err := c.parseUploadMessage(command)
 			if err != nil {
@@ -113,11 +111,11 @@ func (c *scpCommand) handleRecursiveUpload() error {
 				}
 			}
 		}
-		if err != nil || numDirs == 0 {
-			break
+		err = c.sendConfirmationMessage()
+		if err != nil {
+			return err
 		}
 	}
-	return err
 }
 
 func (c *scpCommand) handleCreateDir(dirPath string) error {
@@ -189,7 +187,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *transfer) err
 		c.sendErrorMessage(err)
 		return err
 	}
-	return c.sendConfirmationMessage()
+	return nil
 }
 
 func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error {
@@ -572,7 +570,7 @@ func (c *scpCommand) readProtocolMessage() (string, error) {
 			command.Write(readed)
 		}
 	}
-	if err != nil {
+	if err != nil && !errors.Is(err, io.EOF) {
 		c.connection.channel.Close()
 	}
 	return command.String(), err

+ 2 - 1
sftpd/ssh_cmd.go

@@ -712,7 +712,8 @@ func (c *sshCommand) sendExitStatus(err error) {
 	exitStatus := sshSubsystemExitStatus{
 		Status: status,
 	}
-	c.connection.channel.(ssh.Channel).SendRequest("exit-status", false, ssh.Marshal(&exitStatus)) //nolint:errcheck
+	_, err = c.connection.channel.(ssh.Channel).SendRequest("exit-status", false, ssh.Marshal(&exitStatus))
+	c.connection.Log(logger.LevelDebug, "exit status sent, error: %v", err)
 	c.connection.channel.Close()
 	// for scp we notify single uploads/downloads
 	if c.command != scpCmdName {