diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index 34f9f276..ff99ca87 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "io/ioutil" + "net" "os" "runtime" "testing" @@ -446,8 +447,12 @@ func TestSCPCommandHandleErrors(t *testing.T) { ReadError: readErr, WriteError: writeErr, } + server, client := net.Pipe() + defer server.Close() + defer client.Close() connection := Connection{ channel: &mockSSHChannel, + netConn: client, } scpCommand := scpCommand{ connection: connection, @@ -475,8 +480,12 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) { ReadError: readErr, WriteError: writeErr, } + server, client := net.Pipe() + defer server.Close() + defer client.Close() connection := Connection{ channel: &mockSSHChannel, + netConn: client, } scpCommand := scpCommand{ connection: connection, diff --git a/sftpd/scp.go b/sftpd/scp.go index dfbd802b..50e631e2 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -39,8 +39,8 @@ type scpCommand struct { func (c *scpCommand) handle() error { var err error - addConnection(c.connection.ID, c.connection) - defer removeConnection(c.connection.ID) + addConnection(c.connection) + defer removeConnection(c.connection) destPath := c.getDestPath() commandType := c.getCommandType() c.connection.Log(logger.LevelDebug, logSenderSCP, "handle scp command, args: %v user: %v command type: %v, dest path: %#v", diff --git a/sftpd/server.go b/sftpd/server.go index 9819a96f..f9c532d1 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -299,7 +299,8 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server } func (c Configuration) handleSftpConnection(channel ssh.Channel, connection Connection) { - addConnection(connection.ID, connection) + addConnection(connection) + defer removeConnection(connection) // Create a new handler for the currently logged in user's server. handler := c.createHandler(connection) @@ -312,8 +313,6 @@ func (c Configuration) handleSftpConnection(channel ssh.Channel, connection Conn } else if err != nil { connection.Log(logger.LevelWarn, logSender, "connection closed with error: %v", err) } - - removeConnection(connection.ID) } func (c Configuration) createHandler(connection Connection) sftp.Handlers { diff --git a/sftpd/sftpd.go b/sftpd/sftpd.go index 94fa7cd8..39061f83 100644 --- a/sftpd/sftpd.go +++ b/sftpd/sftpd.go @@ -310,20 +310,27 @@ func CheckIdleConnections() { logger.Debug(logSender, "", "check idle connections ended") } -func addConnection(id string, c Connection) { +func addConnection(c Connection) { mutex.Lock() defer mutex.Unlock() - openConnections[id] = c + openConnections[c.ID] = c metrics.UpdateActiveConnectionsSize(len(openConnections)) c.Log(logger.LevelDebug, logSender, "connection added, num open connections: %v", len(openConnections)) } -func removeConnection(id string) { +func removeConnection(c Connection) { mutex.Lock() defer mutex.Unlock() - c := openConnections[id] - delete(openConnections, id) + delete(openConnections, c.ID) metrics.UpdateActiveConnectionsSize(len(openConnections)) + // we have finished to send data here and most of the time the underlying network connection + // is already closed. Sometime a client can still be reading, the last sended data, from the + // connection so we set a deadline instead of directly closing the network connection. + // Setting a deadline on an already closed connection has no effect. + // We only need to ensure that a connection will not remain undefinitely open and so the + // underlying file descriptor is not released. + // This should protect us against buggy clients and edge cases. + c.netConn.SetDeadline(time.Now().Add(2 * time.Minute)) c.Log(logger.LevelDebug, logSender, "connection removed, num open connections: %v", len(openConnections)) }