add a recoverer where appropriate

I have never seen this, but a malformed packet can easily crash pkg/sftp
This commit is contained in:
Nicola Murino 2020-10-31 11:02:04 +01:00
parent fcfdd633f6
commit 950a5ad9ea
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
7 changed files with 73 additions and 6 deletions

View file

@ -1846,3 +1846,28 @@ func TestSFTPSubSystem(t *testing.T) {
err = subsystemChannel.Close()
assert.NoError(t, err)
}
func TestRecoverer(t *testing.T) {
c := Configuration{}
c.AcceptInboundConnection(nil, nil)
connID := "connectionID"
connection := &Connection{
BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, dataprovider.User{}, nil),
}
c.handleSftpConnection(nil, connection)
sshCmd := sshCommand{
command: "cd",
connection: connection,
}
err := sshCmd.handle()
assert.EqualError(t, err, common.ErrGenericFailure.Error())
scpCmd := scpCommand{
sshCommand: sshCommand{
command: "scp",
connection: connection,
},
}
err = scpCmd.handle()
assert.EqualError(t, err, common.ErrGenericFailure.Error())
assert.Len(t, common.Connections.GetStats(), 0)
}

View file

@ -7,6 +7,7 @@ import (
"os"
"path"
"path/filepath"
"runtime/debug"
"strconv"
"strings"
@ -28,11 +29,16 @@ type scpCommand struct {
sshCommand
}
func (c *scpCommand) handle() error {
func (c *scpCommand) handle() (err error) {
defer func() {
if r := recover(); r != nil {
logger.Error(logSender, "", "panic in handle scp command: %#v stack strace: %v", r, string(debug.Stack()))
err = common.ErrGenericFailure
}
}()
common.Connections.Add(c.connection)
defer common.Connections.Remove(c.connection.GetID())
var err error
destPath := c.getDestPath()
commandType := c.getCommandType()
c.connection.Log(logger.LevelDebug, "handle scp command, args: %v user: %v command type: %v, dest path: %#v",

View file

@ -11,6 +11,7 @@ import (
"net"
"os"
"path/filepath"
"runtime/debug"
"strings"
"time"
@ -266,6 +267,11 @@ func (c Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Server
// AcceptInboundConnection handles an inbound connection to the server instance and determines if the request should be served or not.
func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {
defer func() {
if r := recover(); r != nil {
logger.Error(logSender, "", "panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack()))
}
}()
// Before beginning a handshake must be performed on the incoming net.Conn
// we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
conn.SetDeadline(time.Now().Add(handshakeTimeout)) //nolint:errcheck
@ -374,6 +380,11 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
}
func (c Configuration) handleSftpConnection(channel ssh.Channel, connection *Connection) {
defer func() {
if r := recover(); r != nil {
logger.Error(logSender, "", "panic in handleSftpConnection: %#v stack strace: %v", r, string(debug.Stack()))
}
}()
common.Connections.Add(connection)
defer common.Connections.Remove(connection.GetID())

View file

@ -462,13 +462,12 @@ func TestConcurrency(t *testing.T) {
client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer client.Close()
err = checkBasicSFTP(client)
assert.NoError(t, err)
err = sftpUploadFile(testFilePath, testFileName+strconv.Itoa(counter), testFileSize, client)
assert.NoError(t, err)
assert.Greater(t, common.Connections.GetActiveSessions(defaultUsername), 0)
client.Close()
}
}(i)
}

View file

@ -12,6 +12,7 @@ import (
"os"
"os/exec"
"path"
"runtime/debug"
"strings"
"sync"
@ -84,7 +85,13 @@ func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommand
return false
}
func (c *sshCommand) handle() error {
func (c *sshCommand) handle() (err error) {
defer func() {
if r := recover(); r != nil {
logger.Error(logSender, "", "panic in handle ssh command: %#v stack strace: %v", r, string(debug.Stack()))
err = common.ErrGenericFailure
}
}()
common.Connections.Add(c.connection)
defer common.Connections.Remove(c.connection.GetID())
@ -108,7 +115,7 @@ func (c *sshCommand) handle() error {
} else if c.command == "sftpgo-remove" {
return c.handeSFTPGoRemove()
}
return nil
return
}
func (c *sshCommand) handeSFTPGoCopy() error {

View file

@ -8,6 +8,7 @@ import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path"
"path/filepath"
@ -862,3 +863,14 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
_, err = httpd.RemoveUser(user4, http.StatusOK)
assert.NoError(t, err)
}
func TestRecoverer(t *testing.T) {
c := &Configuration{
BindPort: 9000,
}
server, err := newServer(c, configDir)
assert.NoError(t, err)
rr := httptest.NewRecorder()
server.ServeHTTP(rr, nil)
assert.Equal(t, http.StatusInternalServerError, rr.Code)
}

View file

@ -8,6 +8,7 @@ import (
"net/http"
"path"
"path/filepath"
"runtime/debug"
"strings"
"time"
@ -85,6 +86,12 @@ func (s *webDavServer) listenAndServe() error {
// ServeHTTP implements the http.Handler interface
func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer func() {
if r := recover(); r != nil {
logger.Error(logSender, "", "panic in ServeHTTP: %#v stack strace: %v", r, string(debug.Stack()))
http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError)
}
}()
checkRemoteAddress(r)
if err := common.Config.ExecutePostConnectHook(r.RemoteAddr, common.ProtocolWebDAV); err != nil {
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)