From 2df0dd1f70f0c4cd2128ba3ad85f8a7cf725c035 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Fri, 18 Sep 2020 10:52:53 +0200 Subject: [PATCH] sshd: map each channel with a new connection Fixes #169 --- common/common.go | 12 +---------- common/common_test.go | 2 -- ftpd/handler.go | 3 --- httpd/httpd_test.go | 2 -- sftpd/handler.go | 12 +---------- sftpd/internal_test.go | 7 ------- sftpd/server.go | 46 +++++++++++++++++++++++++----------------- sftpd/sftpd_test.go | 27 +++++++++++++++++++++++++ sftpd/ssh_cmd.go | 4 +--- webdavd/handler.go | 3 --- 10 files changed, 57 insertions(+), 61 deletions(-) diff --git a/common/common.go b/common/common.go index 0eccee64..c2e8de12 100644 --- a/common/common.go +++ b/common/common.go @@ -160,7 +160,6 @@ type ActiveConnection interface { GetLastActivity() time.Time GetCommand() string Disconnect() error - SetConnDeadline() AddTransfer(t ActiveTransfer) RemoveTransfer(t ActiveTransfer) GetTransfers() []ConnectionTransfer @@ -405,16 +404,7 @@ func (conns *ActiveConnections) Remove(connectionID string) { conns.connections[len(conns.connections)-1] = nil conns.connections = conns.connections[:len(conns.connections)-1] metrics.UpdateActiveConnectionsSize(len(conns.connections)) - logger.Debug(c.GetProtocol(), c.GetID(), "connection removed, num open connections: %v", - len(conns.connections)) - // we have finished to send data here and most of the time the underlying network connection - // is already closed. Sometime a client can still be reading the last sended data, so we set - // a deadline instead of directly closing the network connection. - // Setting a deadline on an already closed connection has no effect. - // We only need to ensure that a connection will not remain indefinitely open and so the - // underlying file descriptor is not released. - // This should protect us against buggy clients and edge cases. - c.SetConnDeadline() + logger.Debug(c.GetProtocol(), c.GetID(), "connection removed, num open connections: %v", len(conns.connections)) } else { logger.Warn(logSender, "", "connection to remove with id %#v not found!", connectionID) } diff --git a/common/common_test.go b/common/common_test.go index 4e221610..9f812dda 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -68,8 +68,6 @@ func (c *fakeConnection) GetRemoteAddress() string { return "" } -func (c *fakeConnection) SetConnDeadline() {} - func TestMain(m *testing.M) { logfilePath := "common_test.log" logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel) diff --git a/ftpd/handler.go b/ftpd/handler.go index 6cdec4e2..f530be81 100644 --- a/ftpd/handler.go +++ b/ftpd/handler.go @@ -42,9 +42,6 @@ func (c *Connection) GetRemoteAddress() string { return c.clientContext.RemoteAddr().String() } -// SetConnDeadline does nothing -func (c *Connection) SetConnDeadline() {} - // Disconnect disconnects the client func (c *Connection) Disconnect() error { return c.clientContext.Close(ftpserver.StatusServiceNotAvailable, "connection closed") diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index de901799..69333b21 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -114,8 +114,6 @@ func (c *fakeConnection) GetRemoteAddress() string { return "" } -func (c *fakeConnection) SetConnDeadline() {} - func TestMain(m *testing.M) { homeBasePath = os.TempDir() logfilePath := filepath.Join(configDir, "sftpgo_api_test.log") diff --git a/sftpd/handler.go b/sftpd/handler.go index 174fc84f..18148182 100644 --- a/sftpd/handler.go +++ b/sftpd/handler.go @@ -23,7 +23,6 @@ type Connection struct { ClientVersion string // Remote address for this connection RemoteAddr net.Addr - netConn net.Conn channel ssh.Channel command string } @@ -38,11 +37,6 @@ func (c *Connection) GetRemoteAddress() string { return c.RemoteAddr.String() } -// SetConnDeadline sets a deadline on the network connection so it will be eventually closed -func (c *Connection) SetConnDeadline() { - c.netConn.SetDeadline(time.Now().Add(2 * time.Minute)) //nolint:errcheck -} - // GetCommand returns the SSH command, if any func (c *Connection) GetCommand() string { return c.command @@ -413,11 +407,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r // Disconnect disconnects the client closing the network connection func (c *Connection) Disconnect() error { - if c.channel != nil { - err := c.channel.Close() - c.Log(logger.LevelInfo, "channel close, err: %v", err) - } - return c.netConn.Close() + return c.channel.Close() } func getOSOpenFlags(requestFlags sftp.FileOpenFlags) (flags int) { diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index 3699f7c8..2f564d57 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -518,7 +518,6 @@ func TestSSHCommandErrors(t *testing.T) { connection := Connection{ BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), channel: &mockSSHChannel, - netConn: client, } cmd := sshCommand{ command: "md5sum", @@ -674,7 +673,6 @@ func TestCommandsWithExtensionsFilter(t *testing.T) { connection := &Connection{ BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), channel: &mockSSHChannel, - netConn: client, } cmd := sshCommand{ command: "md5sum", @@ -747,7 +745,6 @@ func TestSSHCommandsRemoteFs(t *testing.T) { connection := &Connection{ BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), channel: &mockSSHChannel, - netConn: client, } cmd := sshCommand{ command: "md5sum", @@ -960,7 +957,6 @@ func TestSystemCommandErrors(t *testing.T) { connection := &Connection{ BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), channel: &mockSSHChannel, - netConn: client, } var sshCmd sshCommand if runtime.GOOS == osWindows { @@ -1268,7 +1264,6 @@ func TestSCPCommandHandleErrors(t *testing.T) { connection := &Connection{ BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, nil), channel: &mockSSHChannel, - netConn: client, } scpCommand := scpCommand{ sshCommand: sshCommand{ @@ -1309,7 +1304,6 @@ func TestSCPErrorsMockFs(t *testing.T) { }() connection := &Connection{ channel: &mockSSHChannel, - netConn: client, BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, u, fs), } scpCommand := scpCommand{ @@ -1364,7 +1358,6 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) { connection := &Connection{ BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, fs), channel: &mockSSHChannel, - netConn: client, } scpCommand := scpCommand{ sshCommand: sshCommand{ diff --git a/sftpd/server.go b/sftpd/server.go index 60f450a7..48728318 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -287,6 +287,8 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server // handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on conn.SetDeadline(time.Time{}) //nolint:errcheck + defer conn.Close() + var user dataprovider.User // Unmarshal cannot fails here and even if it fails we'll have a user with no permissions @@ -299,62 +301,68 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server if err != nil { logger.Warn(logSender, "", "could create filesystem for user %#v err: %v", user.Username, err) - conn.Close() return } - connection := Connection{ - BaseConnection: common.NewBaseConnection(connectionID, "sftpd", user, fs), - ClientVersion: string(sconn.ClientVersion()), - RemoteAddr: remoteAddr, - netConn: conn, - channel: nil, - } + fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID()) - connection.Fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID()) - - connection.Log(logger.LevelInfo, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v", + logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID, + "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v", user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String()) dataprovider.UpdateLastLogin(user) //nolint:errcheck go ssh.DiscardRequests(reqs) + channelCounter := 0 for newChannel := range chans { // If its not a session channel we just move on because its not something we // know how to handle at this point. if newChannel.ChannelType() != "session" { - connection.Log(logger.LevelDebug, "received an unknown channel type: %v", newChannel.ChannelType()) + logger.Log(logger.LevelDebug, common.ProtocolSSH, connectionID, "received an unknown channel type: %v", + newChannel.ChannelType()) newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") //nolint:errcheck continue } channel, requests, err := newChannel.Accept() if err != nil { - connection.Log(logger.LevelWarn, "could not accept a channel: %v", err) + logger.Log(logger.LevelWarn, common.ProtocolSSH, connectionID, "could not accept a channel: %v", err) continue } + channelCounter++ // Channels have a type that is dependent on the protocol. For SFTP this is "subsystem" // with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc) - go func(in <-chan *ssh.Request) { + go func(in <-chan *ssh.Request, counter int) { for req := range in { ok := false + connID := fmt.Sprintf("%v_%v", connectionID, counter) switch req.Type { case "subsystem": if string(req.Payload[4:]) == "sftp" { ok = true - connection.SetProtocol(common.ProtocolSFTP) - connection.channel = channel + connection := Connection{ + BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, user, fs), + ClientVersion: string(sconn.ClientVersion()), + RemoteAddr: remoteAddr, + channel: channel, + } go c.handleSftpConnection(channel, &connection) } case "exec": - connection.SetProtocol(common.ProtocolSSH) - ok = processSSHCommand(req.Payload, &connection, channel, c.EnabledSSHCommands) + // protocol will be set later inside processSSHCommand it could be SSH or SCP + connection := Connection{ + BaseConnection: common.NewBaseConnection(connID, "sshd", user, fs), + ClientVersion: string(sconn.ClientVersion()), + RemoteAddr: remoteAddr, + channel: channel, + } + ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands) } req.Reply(ok, nil) //nolint:errcheck } - }(requests) + }(requests, channelCounter) } } diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index ac282b3b..c130bb69 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -5368,6 +5368,33 @@ func TestPermsSubDirsSetstat(t *testing.T) { assert.NoError(t, err) } +func TestOpenUnhandledChannel(t *testing.T) { + u := getTestUser(false) + user, _, err := httpd.AddUser(u, http.StatusOK) + assert.NoError(t, err) + + config := &ssh.ClientConfig{ + User: user.Username, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + Auth: []ssh.AuthMethod{ssh.Password(defaultPassword)}, + } + conn, err := ssh.Dial("tcp", sftpServerAddr, config) + if assert.NoError(t, err) { + _, _, err = conn.OpenChannel("unhandled", nil) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unknown channel type") + } + err = conn.Close() + assert.NoError(t, err) + } + _, err = httpd.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + func TestPermsSubDirsCommands(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) diff --git a/sftpd/ssh_cmd.go b/sftpd/ssh_cmd.go index fbd58b22..51f115ca 100644 --- a/sftpd/ssh_cmd.go +++ b/sftpd/ssh_cmd.go @@ -48,7 +48,7 @@ type systemCommand struct { quotaCheckPath string } -func processSSHCommand(payload []byte, connection *Connection, channel ssh.Channel, enabledSSHCommands []string) bool { +func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommands []string) bool { var msg sshSubsystemExecMsg if err := ssh.Unmarshal(payload, &msg); err == nil { name, args, err := parseCommandPayload(msg.Command) @@ -58,7 +58,6 @@ func processSSHCommand(payload []byte, connection *Connection, channel ssh.Chann connection.command = msg.Command if name == scpCmdName && len(args) >= 2 { connection.SetProtocol(common.ProtocolSCP) - connection.channel = channel scpCommand := scpCommand{ sshCommand: sshCommand{ command: name, @@ -70,7 +69,6 @@ func processSSHCommand(payload []byte, connection *Connection, channel ssh.Chann } if name != scpCmdName { connection.SetProtocol(common.ProtocolSSH) - connection.channel = channel sshCommand := sshCommand{ command: name, connection: connection, diff --git a/webdavd/handler.go b/webdavd/handler.go index 7fd3297a..251aab6b 100644 --- a/webdavd/handler.go +++ b/webdavd/handler.go @@ -39,9 +39,6 @@ func (c *Connection) GetRemoteAddress() string { return "" } -// SetConnDeadline does nothing -func (c *Connection) SetConnDeadline() {} - // Disconnect closes the active transfer func (c *Connection) Disconnect() error { return c.SignalTransfersAbort()