ソースを参照

connections: close the ssh channel before the network connection

This way if pkg/sftp is stuck in Serve() method should be unlocked.
Nicola Murino 5 年 前
コミット
3d13fe15c3
4 ファイル変更70 行追加56 行削除
  1. 6 0
      sftpd/handler.go
  2. 42 35
      sftpd/internal_test.go
  3. 18 19
      sftpd/scp.go
  4. 4 2
      sftpd/server.go

+ 6 - 0
sftpd/handler.go

@@ -13,6 +13,7 @@ import (
 
 	"github.com/drakkan/sftpgo/utils"
 	"github.com/rs/xid"
+	"golang.org/x/crypto/ssh"
 
 	"github.com/drakkan/sftpgo/dataprovider"
 	"github.com/drakkan/sftpgo/logger"
@@ -37,6 +38,7 @@ type Connection struct {
 	protocol     string
 	lock         *sync.Mutex
 	netConn      net.Conn
+	channel      ssh.Channel
 }
 
 // Log outputs a log entry to the configured logger
@@ -580,6 +582,10 @@ func (c Connection) createMissingDirs(filePath string) error {
 }
 
 func (c Connection) close() error {
+	if c.channel != nil {
+		err := c.channel.Close()
+		c.Log(logger.LevelInfo, logSender, "channel close, err: %v", err)
+	}
 	return c.netConn.Close()
 }
 

+ 42 - 35
sftpd/internal_test.go

@@ -252,7 +252,6 @@ func TestSCPGetNonExistingDirContent(t *testing.T) {
 }
 
 func TestSCPParseUploadMessage(t *testing.T) {
-	connection := Connection{}
 	buf := make([]byte, 65535)
 	stdErrBuf := make([]byte, 65535)
 	mockSSHChannel := MockChannel{
@@ -260,10 +259,12 @@ func TestSCPParseUploadMessage(t *testing.T) {
 		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
 		ReadError:    nil,
 	}
+	connection := Connection{
+		channel: &mockSSHChannel,
+	}
 	scpCommand := scpCommand{
 		connection: connection,
 		args:       []string{"-t", "/tmp"},
-		channel:    &mockSSHChannel,
 	}
 	_, _, err := scpCommand.parseUploadMessage("invalid")
 	if err == nil {
@@ -284,7 +285,6 @@ func TestSCPParseUploadMessage(t *testing.T) {
 }
 
 func TestSCPProtocolMessages(t *testing.T) {
-	connection := Connection{}
 	buf := make([]byte, 65535)
 	stdErrBuf := make([]byte, 65535)
 	readErr := fmt.Errorf("test read error")
@@ -295,10 +295,12 @@ func TestSCPProtocolMessages(t *testing.T) {
 		ReadError:    readErr,
 		WriteError:   writeErr,
 	}
+	connection := Connection{
+		channel: &mockSSHChannel,
+	}
 	scpCommand := scpCommand{
 		connection: connection,
 		args:       []string{"-t", "/tmp"},
-		channel:    &mockSSHChannel,
 	}
 	_, err := scpCommand.readProtocolMessage()
 	if err == nil || err != readErr {
@@ -322,7 +324,7 @@ func TestSCPProtocolMessages(t *testing.T) {
 		ReadError:    nil,
 		WriteError:   writeErr,
 	}
-	scpCommand.channel = &mockSSHChannel
+	scpCommand.connection.channel = &mockSSHChannel
 	_, err = scpCommand.getNextUploadProtocolMessage()
 	if err == nil || err != writeErr {
 		t.Errorf("read next upload protocol message must fail, we are sending a fake write error")
@@ -337,7 +339,7 @@ func TestSCPProtocolMessages(t *testing.T) {
 		ReadError:    nil,
 		WriteError:   nil,
 	}
-	scpCommand.channel = &mockSSHChannel
+	scpCommand.connection.channel = &mockSSHChannel
 	err = scpCommand.readConfirmationMessage()
 	if err == nil || err.Error() != protocolErrorMsg {
 		t.Errorf("read confirmation message must return the expected protocol error, actual err: %v", err)
@@ -345,7 +347,6 @@ func TestSCPProtocolMessages(t *testing.T) {
 }
 
 func TestSCPTestDownloadProtocolMessages(t *testing.T) {
-	connection := Connection{}
 	buf := make([]byte, 65535)
 	stdErrBuf := make([]byte, 65535)
 	readErr := fmt.Errorf("test read error")
@@ -356,10 +357,12 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
 		ReadError:    readErr,
 		WriteError:   writeErr,
 	}
+	connection := Connection{
+		channel: &mockSSHChannel,
+	}
 	scpCommand := scpCommand{
 		connection: connection,
 		args:       []string{"-f", "-p", "/tmp"},
-		channel:    &mockSSHChannel,
 	}
 	path := "testDir"
 	os.Mkdir(path, 0777)
@@ -388,7 +391,7 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
 		WriteError:   writeErr,
 	}
 	scpCommand.args = []string{"-f", "/tmp"}
-	scpCommand.channel = &mockSSHChannel
+	scpCommand.connection.channel = &mockSSHChannel
 	err = scpCommand.sendDownloadProtocolMessages(path, stat)
 	if err != writeErr {
 		t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
@@ -400,7 +403,7 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
 		ReadError:    readErr,
 		WriteError:   nil,
 	}
-	scpCommand.channel = &mockSSHChannel
+	scpCommand.connection.channel = &mockSSHChannel
 	err = scpCommand.sendDownloadProtocolMessages(path, stat)
 	if err != readErr {
 		t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
@@ -409,7 +412,6 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
 }
 
 func TestSCPCommandHandleErrors(t *testing.T) {
-	connection := Connection{}
 	buf := make([]byte, 65535)
 	stdErrBuf := make([]byte, 65535)
 	readErr := fmt.Errorf("test read error")
@@ -420,10 +422,12 @@ func TestSCPCommandHandleErrors(t *testing.T) {
 		ReadError:    readErr,
 		WriteError:   writeErr,
 	}
+	connection := Connection{
+		channel: &mockSSHChannel,
+	}
 	scpCommand := scpCommand{
 		connection: connection,
 		args:       []string{"-f", "/tmp"},
-		channel:    &mockSSHChannel,
 	}
 	err := scpCommand.handle()
 	if err == nil || err != readErr {
@@ -437,7 +441,6 @@ func TestSCPCommandHandleErrors(t *testing.T) {
 }
 
 func TestSCPRecursiveDownloadErrors(t *testing.T) {
-	connection := Connection{}
 	buf := make([]byte, 65535)
 	stdErrBuf := make([]byte, 65535)
 	readErr := fmt.Errorf("test read error")
@@ -448,10 +451,12 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
 		ReadError:    readErr,
 		WriteError:   writeErr,
 	}
+	connection := Connection{
+		channel: &mockSSHChannel,
+	}
 	scpCommand := scpCommand{
 		connection: connection,
 		args:       []string{"-r", "-f", "/tmp"},
-		channel:    &mockSSHChannel,
 	}
 	path := "testDir"
 	os.Mkdir(path, 0777)
@@ -466,7 +471,7 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
 		ReadError:    nil,
 		WriteError:   nil,
 	}
-	scpCommand.channel = &mockSSHChannel
+	scpCommand.connection.channel = &mockSSHChannel
 	err = scpCommand.handleRecursiveDownload("invalid_dir", stat)
 	if err == nil {
 		t.Errorf("recursive upload download must fail for a non existing dir")
@@ -476,7 +481,6 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
 }
 
 func TestSCPRecursiveUploadErrors(t *testing.T) {
-	connection := Connection{}
 	buf := make([]byte, 65535)
 	stdErrBuf := make([]byte, 65535)
 	readErr := fmt.Errorf("test read error")
@@ -487,10 +491,12 @@ func TestSCPRecursiveUploadErrors(t *testing.T) {
 		ReadError:    readErr,
 		WriteError:   writeErr,
 	}
+	connection := Connection{
+		channel: &mockSSHChannel,
+	}
 	scpCommand := scpCommand{
 		connection: connection,
 		args:       []string{"-r", "-t", "/tmp"},
-		channel:    &mockSSHChannel,
 	}
 	err := scpCommand.handleRecursiveUpload()
 	if err == nil {
@@ -502,7 +508,7 @@ func TestSCPRecursiveUploadErrors(t *testing.T) {
 		ReadError:    readErr,
 		WriteError:   nil,
 	}
-	scpCommand.channel = &mockSSHChannel
+	scpCommand.connection.channel = &mockSSHChannel
 	err = scpCommand.handleRecursiveUpload()
 	if err == nil {
 		t.Errorf("recursive upload must fail, we send a fake error message")
@@ -516,19 +522,19 @@ func TestSCPCreateDirs(t *testing.T) {
 	u.HomeDir = "home_rel_path"
 	u.Username = "test"
 	u.Permissions = []string{"*"}
-	connection := Connection{
-		User: u,
-	}
 	mockSSHChannel := MockChannel{
 		Buffer:       bytes.NewBuffer(buf),
 		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
 		ReadError:    nil,
 		WriteError:   nil,
 	}
+	connection := Connection{
+		User:    u,
+		channel: &mockSSHChannel,
+	}
 	scpCommand := scpCommand{
 		connection: connection,
 		args:       []string{"-r", "-t", "/tmp"},
-		channel:    &mockSSHChannel,
 	}
 	err := scpCommand.handleCreateDir("invalid_dir")
 	if err == nil {
@@ -542,7 +548,6 @@ func TestSCPDownloadFileData(t *testing.T) {
 	readErr := fmt.Errorf("test read error")
 	writeErr := fmt.Errorf("test write error")
 	stdErrBuf := make([]byte, 65535)
-	connection := Connection{}
 	mockSSHChannelReadErr := MockChannel{
 		Buffer:       bytes.NewBuffer(buf),
 		StdErrBuffer: bytes.NewBuffer(stdErrBuf),
@@ -555,10 +560,12 @@ func TestSCPDownloadFileData(t *testing.T) {
 		ReadError:    nil,
 		WriteError:   writeErr,
 	}
+	connection := Connection{
+		channel: &mockSSHChannelReadErr,
+	}
 	scpCommand := scpCommand{
 		connection: connection,
 		args:       []string{"-r", "-f", "/tmp"},
-		channel:    &mockSSHChannelReadErr,
 	}
 	ioutil.WriteFile(testfile, []byte("test"), 0666)
 	stat, _ := os.Stat(testfile)
@@ -566,7 +573,7 @@ func TestSCPDownloadFileData(t *testing.T) {
 	if err != readErr {
 		t.Errorf("send download file data must fail with the expected error: %v", err)
 	}
-	scpCommand.channel = &mockSSHChannelWriteErr
+	scpCommand.connection.channel = &mockSSHChannelWriteErr
 	err = scpCommand.sendDownloadFileData(testfile, stat, nil)
 	if err != writeErr {
 		t.Errorf("send download file data must fail with the expected error: %v", err)
@@ -576,7 +583,7 @@ func TestSCPDownloadFileData(t *testing.T) {
 	if err != writeErr {
 		t.Errorf("send download file data must fail with the expected error: %v", err)
 	}
-	scpCommand.channel = &mockSSHChannelReadErr
+	scpCommand.connection.channel = &mockSSHChannelReadErr
 	err = scpCommand.sendDownloadFileData(testfile, stat, nil)
 	if err != readErr {
 		t.Errorf("send download file data must fail with the expected error: %v", err)
@@ -586,12 +593,6 @@ func TestSCPDownloadFileData(t *testing.T) {
 
 func TestSCPUploadFiledata(t *testing.T) {
 	testfile := "testfile"
-	connection := Connection{
-		User: dataprovider.User{
-			Username: "testuser",
-		},
-		protocol: protocolSCP,
-	}
 	buf := make([]byte, 65535)
 	stdErrBuf := make([]byte, 65535)
 	readErr := fmt.Errorf("test read error")
@@ -602,10 +603,16 @@ func TestSCPUploadFiledata(t *testing.T) {
 		ReadError:    readErr,
 		WriteError:   writeErr,
 	}
+	connection := Connection{
+		User: dataprovider.User{
+			Username: "testuser",
+		},
+		protocol: protocolSCP,
+		channel:  &mockSSHChannel,
+	}
 	scpCommand := scpCommand{
 		connection: connection,
 		args:       []string{"-r", "-t", "/tmp"},
-		channel:    &mockSSHChannel,
 	}
 	file, _ := os.Create(testfile)
 	transfer := Transfer{
@@ -634,7 +641,7 @@ func TestSCPUploadFiledata(t *testing.T) {
 		ReadError:    readErr,
 		WriteError:   nil,
 	}
-	scpCommand.channel = &mockSSHChannel
+	scpCommand.connection.channel = &mockSSHChannel
 	file, _ = os.Create(testfile)
 	transfer.file = file
 	addTransfer(&transfer)
@@ -651,7 +658,7 @@ func TestSCPUploadFiledata(t *testing.T) {
 		ReadError:    nil,
 		WriteError:   nil,
 	}
-	scpCommand.channel = &mockSSHChannel
+	scpCommand.connection.channel = &mockSSHChannel
 	file, _ = os.Create(testfile)
 	transfer.file = file
 	addTransfer(&transfer)

+ 18 - 19
sftpd/scp.go

@@ -35,7 +35,6 @@ type exitStatusMsg struct {
 type scpCommand struct {
 	connection Connection
 	args       []string
-	channel    ssh.Channel
 }
 
 func (c *scpCommand) handle() error {
@@ -160,7 +159,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err
 		remaining := sizeToRead
 		buf := make([]byte, int64(math.Min(32768, float64(sizeToRead))))
 		for {
-			n, err := c.channel.Read(buf)
+			n, err := c.connection.channel.Read(buf)
 			if err != nil {
 				c.sendErrorMessage(err.Error())
 				transfer.TransferError(err)
@@ -403,7 +402,7 @@ func (c *scpCommand) sendDownloadFileData(filePath string, stat os.FileInfo, tra
 		n, err := transfer.ReadAt(buf, readed)
 		if err == nil || err == io.EOF {
 			if n > 0 {
-				_, err = c.channel.Write(buf[:n])
+				_, err = c.connection.channel.Write(buf[:n])
 			}
 		}
 		readed += int64(n)
@@ -517,15 +516,15 @@ func (c *scpCommand) isRecursive() bool {
 func (c *scpCommand) readConfirmationMessage() error {
 	var msg strings.Builder
 	buf := make([]byte, 1)
-	n, err := c.channel.Read(buf)
+	n, err := c.connection.channel.Read(buf)
 	if err != nil {
-		c.channel.Close()
+		c.connection.channel.Close()
 		return err
 	}
 	if n == 1 && (buf[0] == warnMsg[0] || buf[0] == errMsg[0]) {
 		isError := buf[0] == errMsg[0]
 		for {
-			n, err = c.channel.Read(buf)
+			n, err = c.connection.channel.Read(buf)
 			readed := buf[:n]
 			if err != nil || (n == 1 && readed[0] == newLine[0]) {
 				break
@@ -536,7 +535,7 @@ func (c *scpCommand) readConfirmationMessage() error {
 		}
 		c.connection.Log(logger.LevelInfo, logSenderSCP, "scp error message received: %v is error: %v", msg.String(), isError)
 		err = fmt.Errorf("%v", msg.String())
-		c.channel.Close()
+		c.connection.channel.Close()
 	}
 	return err
 }
@@ -548,7 +547,7 @@ func (c *scpCommand) readProtocolMessage() (string, error) {
 	buf := make([]byte, 1)
 	for {
 		var n int
-		n, err = c.channel.Read(buf)
+		n, err = c.connection.channel.Read(buf)
 		if err != nil {
 			break
 		}
@@ -561,34 +560,34 @@ func (c *scpCommand) readProtocolMessage() (string, error) {
 		}
 	}
 	if err != nil {
-		c.channel.Close()
+		c.connection.channel.Close()
 	}
 	return command.String(), err
 }
 
 // send an error message and close the channel
 func (c *scpCommand) sendErrorMessage(error string) {
-	c.channel.Write(errMsg)
-	c.channel.Write([]byte(error))
-	c.channel.Write(newLine)
-	c.channel.Close()
+	c.connection.channel.Write(errMsg)
+	c.connection.channel.Write([]byte(error))
+	c.connection.channel.Write(newLine)
+	c.connection.channel.Close()
 }
 
 // send scp confirmation message and close the channel if an error happen
 func (c *scpCommand) sendConfirmationMessage() error {
-	_, err := c.channel.Write(okMsg)
+	_, err := c.connection.channel.Write(okMsg)
 	if err != nil {
-		c.channel.Close()
+		c.connection.channel.Close()
 	}
 	return err
 }
 
 // sends a protocol message and close the channel on error
 func (c *scpCommand) sendProtocolMessage(message string) error {
-	_, err := c.channel.Write([]byte(message))
+	_, err := c.connection.channel.Write([]byte(message))
 	if err != nil {
 		c.connection.Log(logger.LevelWarn, logSenderSCP, "error sending protocol message: %v, err: %v", message, err)
-		c.channel.Close()
+		c.connection.channel.Close()
 	}
 	return err
 }
@@ -604,8 +603,8 @@ func (c *scpCommand) sendExitStatus(err error) {
 	}
 	c.connection.Log(logger.LevelDebug, logSenderSCP, "send exit status for command with args: %v user: %v err: %v",
 		c.args, c.connection.User.Username, err)
-	c.channel.SendRequest("exit-status", false, ssh.Marshal(&ex))
-	c.channel.Close()
+	c.connection.channel.SendRequest("exit-status", false, ssh.Marshal(&ex))
+	c.connection.channel.Close()
 }
 
 // get the next upload protocol message ignoring T command if any

+ 4 - 2
sftpd/server.go

@@ -229,6 +229,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
 		lastActivity:  time.Now(),
 		lock:          new(sync.Mutex),
 		netConn:       conn,
+		channel:       nil,
 	}
 	connection.Log(logger.LevelInfo, logSender, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v",
 		user.ID, loginType, user.Username, user.HomeDir)
@@ -261,6 +262,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
 					if string(req.Payload[4:]) == "sftp" {
 						ok = true
 						connection.protocol = protocolSFTP
+						connection.channel = channel
 						go c.handleSftpConnection(channel, connection)
 					}
 				case "exec":
@@ -274,10 +276,10 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
 							if err == nil && name == "scp" && len(scpArgs) >= 2 {
 								ok = true
 								connection.protocol = protocolSCP
+								connection.channel = channel
 								scpCommand := scpCommand{
 									connection: connection,
 									args:       scpArgs,
-									channel:    channel,
 								}
 								go scpCommand.handle()
 							}
@@ -290,7 +292,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
 	}
 }
 
-func (c Configuration) handleSftpConnection(channel io.ReadWriteCloser, connection Connection) {
+func (c Configuration) handleSftpConnection(channel ssh.Channel, connection Connection) {
 	addConnection(connection.ID, connection)
 	// Create a new handler for the currently logged in user's server.
 	handler := c.createHandler(connection)