From 242dde4480c0e39ad9dd50eeecb7eba0418188c2 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Fri, 18 Sep 2020 18:15:28 +0200 Subject: [PATCH] sftpd: ensure to always close idle connections after the last commit this wasn't the case anymore Completly fixes #169 --- common/common.go | 98 ++++++++++++++++++++++++++++++++++++++++++- common/common_test.go | 98 +++++++++++++++++++++++++++++++++++++++---- sftpd/server.go | 12 ++++-- 3 files changed, 196 insertions(+), 12 deletions(-) diff --git a/common/common.go b/common/common.go index c2e8de12..5d222332 100644 --- a/common/common.go +++ b/common/common.go @@ -13,6 +13,7 @@ import ( "path/filepath" "strings" "sync" + "sync/atomic" "time" "github.com/pires/go-proxyproto" @@ -336,10 +337,48 @@ func (c *Configuration) ExecutePostConnectHook(remoteAddr, protocol string) erro return err } +// SSHConnection defines an ssh connection. +// Each SSH connection can open several channels for SFTP or SSH commands +type SSHConnection struct { + id string + conn net.Conn + lastActivity int64 +} + +// NewSSHConnection returns a new SSHConnection +func NewSSHConnection(id string, conn net.Conn) *SSHConnection { + return &SSHConnection{ + id: id, + conn: conn, + lastActivity: time.Now().UnixNano(), + } +} + +// GetID returns the ID for this SSHConnection +func (c *SSHConnection) GetID() string { + return c.id +} + +// UpdateLastActivity updates last activity for this connection +func (c *SSHConnection) UpdateLastActivity() { + atomic.StoreInt64(&c.lastActivity, time.Now().UnixNano()) +} + +// GetLastActivity returns the last connection activity +func (c *SSHConnection) GetLastActivity() time.Time { + return time.Unix(0, atomic.LoadInt64(&c.lastActivity)) +} + +// Close closes the underlying network connection +func (c *SSHConnection) Close() error { + return c.conn.Close() +} + // ActiveConnections holds the currect active connections with the associated transfers type ActiveConnections struct { sync.RWMutex - connections []ActiveConnection + connections []ActiveConnection + sshConnections []*SSHConnection } // GetActiveSessions returns the number of active sessions for the given username. @@ -431,9 +470,64 @@ func (conns *ActiveConnections) Close(connectionID string) bool { return result } +// AddSSHConnection adds a new ssh connection to the active ones +func (conns *ActiveConnections) AddSSHConnection(c *SSHConnection) { + conns.Lock() + defer conns.Unlock() + + conns.sshConnections = append(conns.sshConnections, c) + logger.Debug(logSender, c.GetID(), "ssh connection added, num open connections: %v", len(conns.sshConnections)) +} + +// RemoveSSHConnection removes a connection from the active ones +func (conns *ActiveConnections) RemoveSSHConnection(connectionID string) { + conns.Lock() + defer conns.Unlock() + + var c *SSHConnection + indexToRemove := -1 + for i, conn := range conns.sshConnections { + if conn.GetID() == connectionID { + indexToRemove = i + c = conn + break + } + } + if indexToRemove >= 0 { + conns.sshConnections[indexToRemove] = conns.sshConnections[len(conns.sshConnections)-1] + conns.sshConnections[len(conns.sshConnections)-1] = nil + conns.sshConnections = conns.sshConnections[:len(conns.sshConnections)-1] + logger.Debug(logSender, c.GetID(), "ssh connection removed, num open ssh connections: %v", len(conns.sshConnections)) + } else { + logger.Warn(logSender, "", "ssh connection to remove with id %#v not found!", connectionID) + } +} + func (conns *ActiveConnections) checkIdleConnections() { conns.RLock() + for _, sshConn := range conns.sshConnections { + idleTime := time.Since(sshConn.GetLastActivity()) + if idleTime > Config.idleTimeoutAsDuration { + // we close the an ssh connection if it has no active connections associated + idToMatch := fmt.Sprintf("_%v_", sshConn.GetID()) + toClose := true + for _, conn := range conns.connections { + if strings.Contains(conn.GetID(), idToMatch) { + toClose = false + break + } + } + if toClose { + defer func(c *SSHConnection) { + err := c.Close() + logger.Debug(logSender, c.GetID(), "close idle SSH connection, idle time: %v, close err: %v", + time.Since(c.GetLastActivity()), err) + }(sshConn) + } + } + } + for _, c := range conns.connections { idleTime := time.Since(c.GetLastActivity()) isUnauthenticatedFTPUser := (c.GetProtocol() == ProtocolFTP && len(c.GetUsername()) == 0) @@ -442,7 +536,7 @@ func (conns *ActiveConnections) checkIdleConnections() { defer func(conn ActiveConnection, isFTPNoAuth bool) { err := conn.Disconnect() logger.Debug(conn.GetProtocol(), conn.GetID(), "close idle connection, idle time: %v, username: %#v close err: %v", - idleTime, conn.GetUsername(), err) + time.Since(conn.GetLastActivity()), conn.GetUsername(), err) if isFTPNoAuth { ip := utils.GetIPFromRemoteAddress(c.GetRemoteAddress()) logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, c.GetProtocol(), "client idle") diff --git a/common/common_test.go b/common/common_test.go index 9f812dda..c75651ec 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -68,6 +68,18 @@ func (c *fakeConnection) GetRemoteAddress() string { return "" } +type customNetConn struct { + net.Conn + id string + isClosed bool +} + +func (c *customNetConn) Close() error { + Connections.RemoveSSHConnection(c.id) + c.isClosed = true + return c.Conn.Close() +} + func TestMain(m *testing.M) { logfilePath := "common_test.log" logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel) @@ -168,40 +180,112 @@ func closeDataprovider() error { return dataprovider.Close() } +func TestSSHConnections(t *testing.T) { + conn1, conn2 := net.Pipe() + now := time.Now() + sshConn1 := NewSSHConnection("id1", conn1) + sshConn2 := NewSSHConnection("id2", conn2) + assert.Equal(t, "id1", sshConn1.GetID()) + assert.Equal(t, "id2", sshConn2.GetID()) + sshConn1.UpdateLastActivity() + assert.GreaterOrEqual(t, sshConn1.GetLastActivity().UnixNano(), now.UnixNano()) + Connections.AddSSHConnection(sshConn1) + Connections.AddSSHConnection(sshConn2) + Connections.RLock() + assert.Len(t, Connections.sshConnections, 2) + Connections.RUnlock() + Connections.RemoveSSHConnection(sshConn1.id) + Connections.RLock() + assert.Len(t, Connections.sshConnections, 1) + Connections.RUnlock() + Connections.RemoveSSHConnection(sshConn1.id) + Connections.RLock() + assert.Len(t, Connections.sshConnections, 1) + Connections.RUnlock() + Connections.RemoveSSHConnection(sshConn2.id) + Connections.RLock() + assert.Len(t, Connections.sshConnections, 0) + Connections.RUnlock() + assert.NoError(t, sshConn1.Close()) + assert.NoError(t, sshConn2.Close()) +} + func TestIdleConnections(t *testing.T) { configCopy := Config Config.IdleTimeout = 1 Initialize(Config) + conn1, conn2 := net.Pipe() + customConn1 := &customNetConn{ + Conn: conn1, + id: "id1", + } + customConn2 := &customNetConn{ + Conn: conn2, + id: "id2", + } + sshConn1 := NewSSHConnection(customConn1.id, customConn1) + sshConn2 := NewSSHConnection(customConn2.id, customConn2) + username := "test_user" user := dataprovider.User{ Username: username, } - c := NewBaseConnection("id1", ProtocolSFTP, user, nil) + c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, user, nil) c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano() fakeConn := &fakeConnection{ BaseConnection: c, } + // both ssh connections are expired but they should get removed only + // if there is no associated connection + sshConn1.lastActivity = c.lastActivity + sshConn2.lastActivity = c.lastActivity + Connections.AddSSHConnection(sshConn1) Connections.Add(fakeConn) assert.Equal(t, Connections.GetActiveSessions(username), 1) - c = NewBaseConnection("id2", ProtocolFTP, dataprovider.User{}, nil) - c.lastActivity = time.Now().UnixNano() + c = NewBaseConnection(sshConn2.id+"_1", ProtocolSSH, user, nil) fakeConn = &fakeConnection{ BaseConnection: c, } + Connections.AddSSHConnection(sshConn2) Connections.Add(fakeConn) - assert.Equal(t, Connections.GetActiveSessions(username), 1) - assert.Len(t, Connections.GetStats(), 2) + assert.Equal(t, Connections.GetActiveSessions(username), 2) + + cFTP := NewBaseConnection("id2", ProtocolFTP, dataprovider.User{}, nil) + cFTP.lastActivity = time.Now().UnixNano() + fakeConn = &fakeConnection{ + BaseConnection: cFTP, + } + Connections.Add(fakeConn) + assert.Equal(t, Connections.GetActiveSessions(username), 2) + assert.Len(t, Connections.GetStats(), 3) + Connections.RLock() + assert.Len(t, Connections.sshConnections, 2) + Connections.RUnlock() startIdleTimeoutTicker(100 * time.Millisecond) - assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 0 }, 1*time.Second, 200*time.Millisecond) + assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 1*time.Second, 200*time.Millisecond) + assert.Eventually(t, func() bool { + Connections.RLock() + defer Connections.RUnlock() + return len(Connections.sshConnections) == 1 + }, 1*time.Second, 200*time.Millisecond) stopIdleTimeoutTicker() - assert.Len(t, Connections.GetStats(), 1) + assert.Len(t, Connections.GetStats(), 2) c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano() + cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano() + sshConn2.lastActivity = c.lastActivity startIdleTimeoutTicker(100 * time.Millisecond) assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond) + assert.Eventually(t, func() bool { + Connections.RLock() + defer Connections.RUnlock() + return len(Connections.sshConnections) == 0 + }, 1*time.Second, 200*time.Millisecond) stopIdleTimeoutTicker() + assert.True(t, customConn1.isClosed) + assert.True(t, customConn2.isClosed) Config = configCopy } diff --git a/sftpd/server.go b/sftpd/server.go index 48728318..a386af5b 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -311,9 +311,14 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String()) dataprovider.UpdateLastLogin(user) //nolint:errcheck + sshConnection := common.NewSSHConnection(connectionID, conn) + common.Connections.AddSSHConnection(sshConnection) + + defer common.Connections.RemoveSSHConnection(connectionID) + go ssh.DiscardRequests(reqs) - channelCounter := 0 + channelCounter := int64(0) for newChannel := range chans { // If its not a session channel we just move on because its not something we // know how to handle at this point. @@ -331,9 +336,10 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server } channelCounter++ + sshConnection.UpdateLastActivity() // Channels have a type that is dependent on the protocol. For SFTP this is "subsystem" // with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc) - go func(in <-chan *ssh.Request, counter int) { + go func(in <-chan *ssh.Request, counter int64) { for req := range in { ok := false connID := fmt.Sprintf("%v_%v", connectionID, counter) @@ -353,7 +359,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server case "exec": // protocol will be set later inside processSSHCommand it could be SSH or SCP connection := Connection{ - BaseConnection: common.NewBaseConnection(connID, "sshd", user, fs), + BaseConnection: common.NewBaseConnection(connID, "sshd_exec", user, fs), ClientVersion: string(sconn.ClientVersion()), RemoteAddr: remoteAddr, channel: channel,