mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-22 07:30:25 +00:00
add a recoverer where appropriate
I have never seen this, but a malformed packet can easily crash pkg/sftp
This commit is contained in:
parent
fcfdd633f6
commit
950a5ad9ea
7 changed files with 73 additions and 6 deletions
|
@ -1846,3 +1846,28 @@ func TestSFTPSubSystem(t *testing.T) {
|
||||||
err = subsystemChannel.Close()
|
err = subsystemChannel.Close()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRecoverer(t *testing.T) {
|
||||||
|
c := Configuration{}
|
||||||
|
c.AcceptInboundConnection(nil, nil)
|
||||||
|
connID := "connectionID"
|
||||||
|
connection := &Connection{
|
||||||
|
BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, dataprovider.User{}, nil),
|
||||||
|
}
|
||||||
|
c.handleSftpConnection(nil, connection)
|
||||||
|
sshCmd := sshCommand{
|
||||||
|
command: "cd",
|
||||||
|
connection: connection,
|
||||||
|
}
|
||||||
|
err := sshCmd.handle()
|
||||||
|
assert.EqualError(t, err, common.ErrGenericFailure.Error())
|
||||||
|
scpCmd := scpCommand{
|
||||||
|
sshCommand: sshCommand{
|
||||||
|
command: "scp",
|
||||||
|
connection: connection,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = scpCmd.handle()
|
||||||
|
assert.EqualError(t, err, common.ErrGenericFailure.Error())
|
||||||
|
assert.Len(t, common.Connections.GetStats(), 0)
|
||||||
|
}
|
||||||
|
|
10
sftpd/scp.go
10
sftpd/scp.go
|
@ -7,6 +7,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime/debug"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -28,11 +29,16 @@ type scpCommand struct {
|
||||||
sshCommand
|
sshCommand
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *scpCommand) handle() error {
|
func (c *scpCommand) handle() (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.Error(logSender, "", "panic in handle scp command: %#v stack strace: %v", r, string(debug.Stack()))
|
||||||
|
err = common.ErrGenericFailure
|
||||||
|
}
|
||||||
|
}()
|
||||||
common.Connections.Add(c.connection)
|
common.Connections.Add(c.connection)
|
||||||
defer common.Connections.Remove(c.connection.GetID())
|
defer common.Connections.Remove(c.connection.GetID())
|
||||||
|
|
||||||
var err error
|
|
||||||
destPath := c.getDestPath()
|
destPath := c.getDestPath()
|
||||||
commandType := c.getCommandType()
|
commandType := c.getCommandType()
|
||||||
c.connection.Log(logger.LevelDebug, "handle scp command, args: %v user: %v command type: %v, dest path: %#v",
|
c.connection.Log(logger.LevelDebug, "handle scp command, args: %v user: %v command type: %v, dest path: %#v",
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -266,6 +267,11 @@ func (c Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Server
|
||||||
|
|
||||||
// AcceptInboundConnection handles an inbound connection to the server instance and determines if the request should be served or not.
|
// AcceptInboundConnection handles an inbound connection to the server instance and determines if the request should be served or not.
|
||||||
func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {
|
func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.Error(logSender, "", "panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack()))
|
||||||
|
}
|
||||||
|
}()
|
||||||
// Before beginning a handshake must be performed on the incoming net.Conn
|
// Before beginning a handshake must be performed on the incoming net.Conn
|
||||||
// we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
|
// we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
|
||||||
conn.SetDeadline(time.Now().Add(handshakeTimeout)) //nolint:errcheck
|
conn.SetDeadline(time.Now().Add(handshakeTimeout)) //nolint:errcheck
|
||||||
|
@ -374,6 +380,11 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Configuration) handleSftpConnection(channel ssh.Channel, connection *Connection) {
|
func (c Configuration) handleSftpConnection(channel ssh.Channel, connection *Connection) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.Error(logSender, "", "panic in handleSftpConnection: %#v stack strace: %v", r, string(debug.Stack()))
|
||||||
|
}
|
||||||
|
}()
|
||||||
common.Connections.Add(connection)
|
common.Connections.Add(connection)
|
||||||
defer common.Connections.Remove(connection.GetID())
|
defer common.Connections.Remove(connection.GetID())
|
||||||
|
|
||||||
|
|
|
@ -462,13 +462,12 @@ func TestConcurrency(t *testing.T) {
|
||||||
|
|
||||||
client, err := getSftpClient(user, usePubKey)
|
client, err := getSftpClient(user, usePubKey)
|
||||||
if assert.NoError(t, err) {
|
if assert.NoError(t, err) {
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
err = checkBasicSFTP(client)
|
err = checkBasicSFTP(client)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = sftpUploadFile(testFilePath, testFileName+strconv.Itoa(counter), testFileSize, client)
|
err = sftpUploadFile(testFilePath, testFileName+strconv.Itoa(counter), testFileSize, client)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Greater(t, common.Connections.GetActiveSessions(defaultUsername), 0)
|
assert.Greater(t, common.Connections.GetActiveSessions(defaultUsername), 0)
|
||||||
|
client.Close()
|
||||||
}
|
}
|
||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path"
|
"path"
|
||||||
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
@ -84,7 +85,13 @@ func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommand
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sshCommand) handle() error {
|
func (c *sshCommand) handle() (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.Error(logSender, "", "panic in handle ssh command: %#v stack strace: %v", r, string(debug.Stack()))
|
||||||
|
err = common.ErrGenericFailure
|
||||||
|
}
|
||||||
|
}()
|
||||||
common.Connections.Add(c.connection)
|
common.Connections.Add(c.connection)
|
||||||
defer common.Connections.Remove(c.connection.GetID())
|
defer common.Connections.Remove(c.connection.GetID())
|
||||||
|
|
||||||
|
@ -108,7 +115,7 @@ func (c *sshCommand) handle() error {
|
||||||
} else if c.command == "sftpgo-remove" {
|
} else if c.command == "sftpgo-remove" {
|
||||||
return c.handeSFTPGoRemove()
|
return c.handeSFTPGoRemove()
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *sshCommand) handeSFTPGoCopy() error {
|
func (c *sshCommand) handeSFTPGoCopy() error {
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -862,3 +863,14 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
|
||||||
_, err = httpd.RemoveUser(user4, http.StatusOK)
|
_, err = httpd.RemoveUser(user4, http.StatusOK)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRecoverer(t *testing.T) {
|
||||||
|
c := &Configuration{
|
||||||
|
BindPort: 9000,
|
||||||
|
}
|
||||||
|
server, err := newServer(c, configDir)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
server.ServeHTTP(rr, nil)
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, rr.Code)
|
||||||
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -85,6 +86,12 @@ func (s *webDavServer) listenAndServe() error {
|
||||||
|
|
||||||
// ServeHTTP implements the http.Handler interface
|
// ServeHTTP implements the http.Handler interface
|
||||||
func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
logger.Error(logSender, "", "panic in ServeHTTP: %#v stack strace: %v", r, string(debug.Stack()))
|
||||||
|
http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}()
|
||||||
checkRemoteAddress(r)
|
checkRemoteAddress(r)
|
||||||
if err := common.Config.ExecutePostConnectHook(r.RemoteAddr, common.ProtocolWebDAV); err != nil {
|
if err := common.Config.ExecutePostConnectHook(r.RemoteAddr, common.ProtocolWebDAV); err != nil {
|
||||||
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
|
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)
|
||||||
|
|
Loading…
Reference in a new issue