From 324d695d93ab327af8bde283f4f4f2e0606b6a24 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Thu, 8 Jun 2023 19:41:58 +0200 Subject: [PATCH] try to fix a randomly failing test case Signed-off-by: Nicola Murino --- internal/sftpd/server.go | 49 +++++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/internal/sftpd/server.go b/internal/sftpd/server.go index eb6ae298..b4e0ec3e 100644 --- a/internal/sftpd/server.go +++ b/internal/sftpd/server.go @@ -571,25 +571,6 @@ func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Serve serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyboardInteractive) } -func canAcceptConnection(ip string) bool { - if common.IsBanned(ip, common.ProtocolSSH) { - logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, ip %q is banned", ip) - return false - } - if err := common.Connections.IsNewConnectionAllowed(ip, common.ProtocolSSH); err != nil { - logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection not allowed from ip %q: %v", ip, err) - return false - } - _, err := common.LimitRate(common.ProtocolSSH, ip) - if err != nil { - return false - } - if err := common.Config.ExecutePostConnectHook(ip, common.ProtocolSSH); err != nil { - return false - } - return true -} - // AcceptInboundConnection handles an inbound connection to the server instance and determines if the request should be served or not. func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) { defer func() { @@ -618,6 +599,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve } // handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on conn.SetDeadline(time.Time{}) //nolint:errcheck + go ssh.DiscardRequests(reqs) defer conn.Close() @@ -632,6 +614,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve defer user.CloseFs() //nolint:errcheck if err = user.CheckFsRoot(connectionID); err != nil { logger.Warn(logSender, connectionID, "unable to check fs root for user %q: %v", user.Username, err) + go discardAllChannels(chans, "invalid root fs", connectionID) return } @@ -645,8 +628,6 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve defer common.Connections.RemoveSSHConnection(connectionID) - go ssh.DiscardRequests(reqs) - channelCounter := int64(0) for newChannel := range chans { // If its not a session channel we just move on because its not something we @@ -756,6 +737,32 @@ func (c *Configuration) createHandlers(connection *Connection) sftp.Handlers { } } +func canAcceptConnection(ip string) bool { + if common.IsBanned(ip, common.ProtocolSSH) { + logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, ip %q is banned", ip) + return false + } + if err := common.Connections.IsNewConnectionAllowed(ip, common.ProtocolSSH); err != nil { + logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection not allowed from ip %q: %v", ip, err) + return false + } + _, err := common.LimitRate(common.ProtocolSSH, ip) + if err != nil { + return false + } + if err := common.Config.ExecutePostConnectHook(ip, common.ProtocolSSH); err != nil { + return false + } + return true +} + +func discardAllChannels(in <-chan ssh.NewChannel, message, connectionID string) { + for req := range in { + err := req.Reject(ssh.ConnectionFailed, message) + logger.Debug(logSender, connectionID, "discarded channel request, message %q err: %v", message, err) + } +} + func checkAuthError(ip string, err error) { if authErrors, ok := err.(*ssh.ServerAuthError); ok { // check public key auth errors here