sshd: map each channel with a new connection

Fixes #169
This commit is contained in:
Nicola Murino 2020-09-18 10:52:53 +02:00
parent 98a6d138d4
commit 2df0dd1f70
10 changed files with 57 additions and 61 deletions

View file

@ -160,7 +160,6 @@ type ActiveConnection interface {
GetLastActivity() time.Time
GetCommand() string
Disconnect() error
SetConnDeadline()
AddTransfer(t ActiveTransfer)
RemoveTransfer(t ActiveTransfer)
GetTransfers() []ConnectionTransfer
@ -405,16 +404,7 @@ func (conns *ActiveConnections) Remove(connectionID string) {
conns.connections[len(conns.connections)-1] = nil
conns.connections = conns.connections[:len(conns.connections)-1]
metrics.UpdateActiveConnectionsSize(len(conns.connections))
logger.Debug(c.GetProtocol(), c.GetID(), "connection removed, num open connections: %v",
len(conns.connections))
// we have finished to send data here and most of the time the underlying network connection
// is already closed. Sometime a client can still be reading the last sended data, so we set
// a deadline instead of directly closing the network connection.
// Setting a deadline on an already closed connection has no effect.
// We only need to ensure that a connection will not remain indefinitely open and so the
// underlying file descriptor is not released.
// This should protect us against buggy clients and edge cases.
c.SetConnDeadline()
logger.Debug(c.GetProtocol(), c.GetID(), "connection removed, num open connections: %v", len(conns.connections))
} else {
logger.Warn(logSender, "", "connection to remove with id %#v not found!", connectionID)
}

View file

@ -68,8 +68,6 @@ func (c *fakeConnection) GetRemoteAddress() string {
return ""
}
func (c *fakeConnection) SetConnDeadline() {}
func TestMain(m *testing.M) {
logfilePath := "common_test.log"
logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel)

View file

@ -42,9 +42,6 @@ func (c *Connection) GetRemoteAddress() string {
return c.clientContext.RemoteAddr().String()
}
// SetConnDeadline does nothing
func (c *Connection) SetConnDeadline() {}
// Disconnect disconnects the client
func (c *Connection) Disconnect() error {
return c.clientContext.Close(ftpserver.StatusServiceNotAvailable, "connection closed")

View file

@ -114,8 +114,6 @@ func (c *fakeConnection) GetRemoteAddress() string {
return ""
}
func (c *fakeConnection) SetConnDeadline() {}
func TestMain(m *testing.M) {
homeBasePath = os.TempDir()
logfilePath := filepath.Join(configDir, "sftpgo_api_test.log")

View file

@ -23,7 +23,6 @@ type Connection struct {
ClientVersion string
// Remote address for this connection
RemoteAddr net.Addr
netConn net.Conn
channel ssh.Channel
command string
}
@ -38,11 +37,6 @@ func (c *Connection) GetRemoteAddress() string {
return c.RemoteAddr.String()
}
// SetConnDeadline sets a deadline on the network connection so it will be eventually closed
func (c *Connection) SetConnDeadline() {
c.netConn.SetDeadline(time.Now().Add(2 * time.Minute)) //nolint:errcheck
}
// GetCommand returns the SSH command, if any
func (c *Connection) GetCommand() string {
return c.command
@ -413,11 +407,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r
// Disconnect disconnects the client closing the network connection
func (c *Connection) Disconnect() error {
if c.channel != nil {
err := c.channel.Close()
c.Log(logger.LevelInfo, "channel close, err: %v", err)
}
return c.netConn.Close()
return c.channel.Close()
}
func getOSOpenFlags(requestFlags sftp.FileOpenFlags) (flags int) {

View file

@ -518,7 +518,6 @@ func TestSSHCommandErrors(t *testing.T) {
connection := Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
channel: &mockSSHChannel,
netConn: client,
}
cmd := sshCommand{
command: "md5sum",
@ -674,7 +673,6 @@ func TestCommandsWithExtensionsFilter(t *testing.T) {
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
channel: &mockSSHChannel,
netConn: client,
}
cmd := sshCommand{
command: "md5sum",
@ -747,7 +745,6 @@ func TestSSHCommandsRemoteFs(t *testing.T) {
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
channel: &mockSSHChannel,
netConn: client,
}
cmd := sshCommand{
command: "md5sum",
@ -960,7 +957,6 @@ func TestSystemCommandErrors(t *testing.T) {
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
channel: &mockSSHChannel,
netConn: client,
}
var sshCmd sshCommand
if runtime.GOOS == osWindows {
@ -1268,7 +1264,6 @@ func TestSCPCommandHandleErrors(t *testing.T) {
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, nil),
channel: &mockSSHChannel,
netConn: client,
}
scpCommand := scpCommand{
sshCommand: sshCommand{
@ -1309,7 +1304,6 @@ func TestSCPErrorsMockFs(t *testing.T) {
}()
connection := &Connection{
channel: &mockSSHChannel,
netConn: client,
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, u, fs),
}
scpCommand := scpCommand{
@ -1364,7 +1358,6 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, fs),
channel: &mockSSHChannel,
netConn: client,
}
scpCommand := scpCommand{
sshCommand: sshCommand{

View file

@ -287,6 +287,8 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
conn.SetDeadline(time.Time{}) //nolint:errcheck
defer conn.Close()
var user dataprovider.User
// Unmarshal cannot fails here and even if it fails we'll have a user with no permissions
@ -299,62 +301,68 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
if err != nil {
logger.Warn(logSender, "", "could create filesystem for user %#v err: %v", user.Username, err)
conn.Close()
return
}
connection := Connection{
BaseConnection: common.NewBaseConnection(connectionID, "sftpd", user, fs),
ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: remoteAddr,
netConn: conn,
channel: nil,
}
fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID())
connection.Fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID())
connection.Log(logger.LevelInfo, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
"User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String())
dataprovider.UpdateLastLogin(user) //nolint:errcheck
go ssh.DiscardRequests(reqs)
channelCounter := 0
for newChannel := range chans {
// If its not a session channel we just move on because its not something we
// know how to handle at this point.
if newChannel.ChannelType() != "session" {
connection.Log(logger.LevelDebug, "received an unknown channel type: %v", newChannel.ChannelType())
logger.Log(logger.LevelDebug, common.ProtocolSSH, connectionID, "received an unknown channel type: %v",
newChannel.ChannelType())
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") //nolint:errcheck
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
connection.Log(logger.LevelWarn, "could not accept a channel: %v", err)
logger.Log(logger.LevelWarn, common.ProtocolSSH, connectionID, "could not accept a channel: %v", err)
continue
}
channelCounter++
// Channels have a type that is dependent on the protocol. For SFTP this is "subsystem"
// with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc)
go func(in <-chan *ssh.Request) {
go func(in <-chan *ssh.Request, counter int) {
for req := range in {
ok := false
connID := fmt.Sprintf("%v_%v", connectionID, counter)
switch req.Type {
case "subsystem":
if string(req.Payload[4:]) == "sftp" {
ok = true
connection.SetProtocol(common.ProtocolSFTP)
connection.channel = channel
connection := Connection{
BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, user, fs),
ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: remoteAddr,
channel: channel,
}
go c.handleSftpConnection(channel, &connection)
}
case "exec":
connection.SetProtocol(common.ProtocolSSH)
ok = processSSHCommand(req.Payload, &connection, channel, c.EnabledSSHCommands)
// protocol will be set later inside processSSHCommand it could be SSH or SCP
connection := Connection{
BaseConnection: common.NewBaseConnection(connID, "sshd", user, fs),
ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: remoteAddr,
channel: channel,
}
ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands)
}
req.Reply(ok, nil) //nolint:errcheck
}
}(requests)
}(requests, channelCounter)
}
}

View file

@ -5368,6 +5368,33 @@ func TestPermsSubDirsSetstat(t *testing.T) {
assert.NoError(t, err)
}
func TestOpenUnhandledChannel(t *testing.T) {
u := getTestUser(false)
user, _, err := httpd.AddUser(u, http.StatusOK)
assert.NoError(t, err)
config := &ssh.ClientConfig{
User: user.Username,
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
Auth: []ssh.AuthMethod{ssh.Password(defaultPassword)},
}
conn, err := ssh.Dial("tcp", sftpServerAddr, config)
if assert.NoError(t, err) {
_, _, err = conn.OpenChannel("unhandled", nil)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "unknown channel type")
}
err = conn.Close()
assert.NoError(t, err)
}
_, err = httpd.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
}
func TestPermsSubDirsCommands(t *testing.T) {
usePubKey := true
u := getTestUser(usePubKey)

View file

@ -48,7 +48,7 @@ type systemCommand struct {
quotaCheckPath string
}
func processSSHCommand(payload []byte, connection *Connection, channel ssh.Channel, enabledSSHCommands []string) bool {
func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommands []string) bool {
var msg sshSubsystemExecMsg
if err := ssh.Unmarshal(payload, &msg); err == nil {
name, args, err := parseCommandPayload(msg.Command)
@ -58,7 +58,6 @@ func processSSHCommand(payload []byte, connection *Connection, channel ssh.Chann
connection.command = msg.Command
if name == scpCmdName && len(args) >= 2 {
connection.SetProtocol(common.ProtocolSCP)
connection.channel = channel
scpCommand := scpCommand{
sshCommand: sshCommand{
command: name,
@ -70,7 +69,6 @@ func processSSHCommand(payload []byte, connection *Connection, channel ssh.Chann
}
if name != scpCmdName {
connection.SetProtocol(common.ProtocolSSH)
connection.channel = channel
sshCommand := sshCommand{
command: name,
connection: connection,

View file

@ -39,9 +39,6 @@ func (c *Connection) GetRemoteAddress() string {
return ""
}
// SetConnDeadline does nothing
func (c *Connection) SetConnDeadline() {}
// Disconnect closes the active transfer
func (c *Connection) Disconnect() error {
return c.SignalTransfersAbort()