浏览代码

sftpd: refactor connection closing

we have not known bugs with the previous implementation anyway this one
is cleaner: the underlying network connection is directly related with
SFTP/SCP connections.
This should better protect us against buggy clients and edge cases
Nicola Murino 5 年之前
父节点
当前提交
871e2ccbbf
共有 4 个文件被更改,包括 25 次插入10 次删除
  1. 9 0
      sftpd/internal_test.go
  2. 2 2
      sftpd/scp.go
  3. 2 3
      sftpd/server.go
  4. 12 5
      sftpd/sftpd.go

+ 9 - 0
sftpd/internal_test.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"io/ioutil"
 	"io/ioutil"
+	"net"
 	"os"
 	"os"
 	"runtime"
 	"runtime"
 	"testing"
 	"testing"
@@ -446,8 +447,12 @@ func TestSCPCommandHandleErrors(t *testing.T) {
 		ReadError:    readErr,
 		ReadError:    readErr,
 		WriteError:   writeErr,
 		WriteError:   writeErr,
 	}
 	}
+	server, client := net.Pipe()
+	defer server.Close()
+	defer client.Close()
 	connection := Connection{
 	connection := Connection{
 		channel: &mockSSHChannel,
 		channel: &mockSSHChannel,
+		netConn: client,
 	}
 	}
 	scpCommand := scpCommand{
 	scpCommand := scpCommand{
 		connection: connection,
 		connection: connection,
@@ -475,8 +480,12 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
 		ReadError:    readErr,
 		ReadError:    readErr,
 		WriteError:   writeErr,
 		WriteError:   writeErr,
 	}
 	}
+	server, client := net.Pipe()
+	defer server.Close()
+	defer client.Close()
 	connection := Connection{
 	connection := Connection{
 		channel: &mockSSHChannel,
 		channel: &mockSSHChannel,
+		netConn: client,
 	}
 	}
 	scpCommand := scpCommand{
 	scpCommand := scpCommand{
 		connection: connection,
 		connection: connection,

+ 2 - 2
sftpd/scp.go

@@ -39,8 +39,8 @@ type scpCommand struct {
 
 
 func (c *scpCommand) handle() error {
 func (c *scpCommand) handle() error {
 	var err error
 	var err error
-	addConnection(c.connection.ID, c.connection)
-	defer removeConnection(c.connection.ID)
+	addConnection(c.connection)
+	defer removeConnection(c.connection)
 	destPath := c.getDestPath()
 	destPath := c.getDestPath()
 	commandType := c.getCommandType()
 	commandType := c.getCommandType()
 	c.connection.Log(logger.LevelDebug, logSenderSCP, "handle scp command, args: %v user: %v command type: %v, dest path: %#v",
 	c.connection.Log(logger.LevelDebug, logSenderSCP, "handle scp command, args: %v user: %v command type: %v, dest path: %#v",

+ 2 - 3
sftpd/server.go

@@ -299,7 +299,8 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
 }
 }
 
 
 func (c Configuration) handleSftpConnection(channel ssh.Channel, connection Connection) {
 func (c Configuration) handleSftpConnection(channel ssh.Channel, connection Connection) {
-	addConnection(connection.ID, connection)
+	addConnection(connection)
+	defer removeConnection(connection)
 	// Create a new handler for the currently logged in user's server.
 	// Create a new handler for the currently logged in user's server.
 	handler := c.createHandler(connection)
 	handler := c.createHandler(connection)
 
 
@@ -312,8 +313,6 @@ func (c Configuration) handleSftpConnection(channel ssh.Channel, connection Conn
 	} else if err != nil {
 	} else if err != nil {
 		connection.Log(logger.LevelWarn, logSender, "connection closed with error: %v", err)
 		connection.Log(logger.LevelWarn, logSender, "connection closed with error: %v", err)
 	}
 	}
-
-	removeConnection(connection.ID)
 }
 }
 
 
 func (c Configuration) createHandler(connection Connection) sftp.Handlers {
 func (c Configuration) createHandler(connection Connection) sftp.Handlers {

+ 12 - 5
sftpd/sftpd.go

@@ -310,20 +310,27 @@ func CheckIdleConnections() {
 	logger.Debug(logSender, "", "check idle connections ended")
 	logger.Debug(logSender, "", "check idle connections ended")
 }
 }
 
 
-func addConnection(id string, c Connection) {
+func addConnection(c Connection) {
 	mutex.Lock()
 	mutex.Lock()
 	defer mutex.Unlock()
 	defer mutex.Unlock()
-	openConnections[id] = c
+	openConnections[c.ID] = c
 	metrics.UpdateActiveConnectionsSize(len(openConnections))
 	metrics.UpdateActiveConnectionsSize(len(openConnections))
 	c.Log(logger.LevelDebug, logSender, "connection added, num open connections: %v", len(openConnections))
 	c.Log(logger.LevelDebug, logSender, "connection added, num open connections: %v", len(openConnections))
 }
 }
 
 
-func removeConnection(id string) {
+func removeConnection(c Connection) {
 	mutex.Lock()
 	mutex.Lock()
 	defer mutex.Unlock()
 	defer mutex.Unlock()
-	c := openConnections[id]
-	delete(openConnections, id)
+	delete(openConnections, c.ID)
 	metrics.UpdateActiveConnectionsSize(len(openConnections))
 	metrics.UpdateActiveConnectionsSize(len(openConnections))
+	// we have finished to send data here and most of the time the underlying network connection
+	// is already closed. Sometime a client can still be reading, the last sended data, from the
+	// connection so we set a deadline instead of directly closing the network connection.
+	// Setting a deadline on an already closed connection has no effect.
+	// We only need to ensure that a connection will not remain undefinitely open and so the
+	// underlying file descriptor is not released.
+	// This should protect us against buggy clients and edge cases.
+	c.netConn.SetDeadline(time.Now().Add(2 * time.Minute))
 	c.Log(logger.LevelDebug, logSender, "connection removed, num open connections: %v", len(openConnections))
 	c.Log(logger.LevelDebug, logSender, "connection removed, num open connections: %v", len(openConnections))
 }
 }