add a recoverer where appropriate
I have never seen this, but a malformed packet can easily crash pkg/sftp
This commit is contained in:
parent
fcfdd633f6
commit
950a5ad9ea
7 changed files with 73 additions and 6 deletions
|
@ -1846,3 +1846,28 @@ func TestSFTPSubSystem(t *testing.T) {
|
|||
err = subsystemChannel.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestRecoverer(t *testing.T) {
|
||||
c := Configuration{}
|
||||
c.AcceptInboundConnection(nil, nil)
|
||||
connID := "connectionID"
|
||||
connection := &Connection{
|
||||
BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, dataprovider.User{}, nil),
|
||||
}
|
||||
c.handleSftpConnection(nil, connection)
|
||||
sshCmd := sshCommand{
|
||||
command: "cd",
|
||||
connection: connection,
|
||||
}
|
||||
err := sshCmd.handle()
|
||||
assert.EqualError(t, err, common.ErrGenericFailure.Error())
|
||||
scpCmd := scpCommand{
|
||||
sshCommand: sshCommand{
|
||||
command: "scp",
|
||||
connection: connection,
|
||||
},
|
||||
}
|
||||
err = scpCmd.handle()
|
||||
assert.EqualError(t, err, common.ErrGenericFailure.Error())
|
||||
assert.Len(t, common.Connections.GetStats(), 0)
|
||||
}
|
||||
|
|
10
sftpd/scp.go
10
sftpd/scp.go
|
@ -7,6 +7,7 @@ import (
|
|||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
|
@ -28,11 +29,16 @@ type scpCommand struct {
|
|||
sshCommand
|
||||
}
|
||||
|
||||
func (c *scpCommand) handle() error {
|
||||
func (c *scpCommand) handle() (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error(logSender, "", "panic in handle scp command: %#v stack strace: %v", r, string(debug.Stack()))
|
||||
err = common.ErrGenericFailure
|
||||
}
|
||||
}()
|
||||
common.Connections.Add(c.connection)
|
||||
defer common.Connections.Remove(c.connection.GetID())
|
||||
|
||||
var err error
|
||||
destPath := c.getDestPath()
|
||||
commandType := c.getCommandType()
|
||||
c.connection.Log(logger.LevelDebug, "handle scp command, args: %v user: %v command type: %v, dest path: %#v",
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -266,6 +267,11 @@ func (c Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Server
|
|||
|
||||
// AcceptInboundConnection handles an inbound connection to the server instance and determines if the request should be served or not.
|
||||
func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error(logSender, "", "panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
// Before beginning a handshake must be performed on the incoming net.Conn
|
||||
// we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
|
||||
conn.SetDeadline(time.Now().Add(handshakeTimeout)) //nolint:errcheck
|
||||
|
@ -374,6 +380,11 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
|
|||
}
|
||||
|
||||
func (c Configuration) handleSftpConnection(channel ssh.Channel, connection *Connection) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error(logSender, "", "panic in handleSftpConnection: %#v stack strace: %v", r, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
common.Connections.Add(connection)
|
||||
defer common.Connections.Remove(connection.GetID())
|
||||
|
||||
|
|
|
@ -462,13 +462,12 @@ func TestConcurrency(t *testing.T) {
|
|||
|
||||
client, err := getSftpClient(user, usePubKey)
|
||||
if assert.NoError(t, err) {
|
||||
defer client.Close()
|
||||
|
||||
err = checkBasicSFTP(client)
|
||||
assert.NoError(t, err)
|
||||
err = sftpUploadFile(testFilePath, testFileName+strconv.Itoa(counter), testFileSize, client)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, common.Connections.GetActiveSessions(defaultUsername), 0)
|
||||
client.Close()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
|
@ -84,7 +85,13 @@ func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommand
|
|||
return false
|
||||
}
|
||||
|
||||
func (c *sshCommand) handle() error {
|
||||
func (c *sshCommand) handle() (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error(logSender, "", "panic in handle ssh command: %#v stack strace: %v", r, string(debug.Stack()))
|
||||
err = common.ErrGenericFailure
|
||||
}
|
||||
}()
|
||||
common.Connections.Add(c.connection)
|
||||
defer common.Connections.Remove(c.connection.GetID())
|
||||
|
||||
|
@ -108,7 +115,7 @@ func (c *sshCommand) handle() error {
|
|||
} else if c.command == "sftpgo-remove" {
|
||||
return c.handeSFTPGoRemove()
|
||||
}
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
func (c *sshCommand) handeSFTPGoCopy() error {
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
@ -862,3 +863,14 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
|
|||
_, err = httpd.RemoveUser(user4, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestRecoverer(t *testing.T) {
|
||||
c := &Configuration{
|
||||
BindPort: 9000,
|
||||
}
|
||||
server, err := newServer(c, configDir)
|
||||
assert.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
server.ServeHTTP(rr, nil)
|
||||
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net/http"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -85,6 +86,12 @@ func (s *webDavServer) listenAndServe() error {
|
|||
|
||||
// ServeHTTP implements the http.Handler interface
|
||||
func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error(logSender, "", "panic in ServeHTTP: %#v stack strace: %v", r, string(debug.Stack()))
|
||||
http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
checkRemoteAddress(r)
|
||||
if err := common.Config.ExecutePostConnectHook(r.RemoteAddr, common.ProtocolWebDAV); err != nil {
|
||||
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
|
||||
|
|
Loading…
Reference in a new issue