mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-22 07:30:25 +00:00
parent
98a6d138d4
commit
2df0dd1f70
10 changed files with 57 additions and 61 deletions
|
@ -160,7 +160,6 @@ type ActiveConnection interface {
|
|||
GetLastActivity() time.Time
|
||||
GetCommand() string
|
||||
Disconnect() error
|
||||
SetConnDeadline()
|
||||
AddTransfer(t ActiveTransfer)
|
||||
RemoveTransfer(t ActiveTransfer)
|
||||
GetTransfers() []ConnectionTransfer
|
||||
|
@ -405,16 +404,7 @@ func (conns *ActiveConnections) Remove(connectionID string) {
|
|||
conns.connections[len(conns.connections)-1] = nil
|
||||
conns.connections = conns.connections[:len(conns.connections)-1]
|
||||
metrics.UpdateActiveConnectionsSize(len(conns.connections))
|
||||
logger.Debug(c.GetProtocol(), c.GetID(), "connection removed, num open connections: %v",
|
||||
len(conns.connections))
|
||||
// we have finished to send data here and most of the time the underlying network connection
|
||||
// is already closed. Sometime a client can still be reading the last sended data, so we set
|
||||
// a deadline instead of directly closing the network connection.
|
||||
// Setting a deadline on an already closed connection has no effect.
|
||||
// We only need to ensure that a connection will not remain indefinitely open and so the
|
||||
// underlying file descriptor is not released.
|
||||
// This should protect us against buggy clients and edge cases.
|
||||
c.SetConnDeadline()
|
||||
logger.Debug(c.GetProtocol(), c.GetID(), "connection removed, num open connections: %v", len(conns.connections))
|
||||
} else {
|
||||
logger.Warn(logSender, "", "connection to remove with id %#v not found!", connectionID)
|
||||
}
|
||||
|
|
|
@ -68,8 +68,6 @@ func (c *fakeConnection) GetRemoteAddress() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func (c *fakeConnection) SetConnDeadline() {}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
logfilePath := "common_test.log"
|
||||
logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel)
|
||||
|
|
|
@ -42,9 +42,6 @@ func (c *Connection) GetRemoteAddress() string {
|
|||
return c.clientContext.RemoteAddr().String()
|
||||
}
|
||||
|
||||
// SetConnDeadline does nothing
|
||||
func (c *Connection) SetConnDeadline() {}
|
||||
|
||||
// Disconnect disconnects the client
|
||||
func (c *Connection) Disconnect() error {
|
||||
return c.clientContext.Close(ftpserver.StatusServiceNotAvailable, "connection closed")
|
||||
|
|
|
@ -114,8 +114,6 @@ func (c *fakeConnection) GetRemoteAddress() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func (c *fakeConnection) SetConnDeadline() {}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
homeBasePath = os.TempDir()
|
||||
logfilePath := filepath.Join(configDir, "sftpgo_api_test.log")
|
||||
|
|
|
@ -23,7 +23,6 @@ type Connection struct {
|
|||
ClientVersion string
|
||||
// Remote address for this connection
|
||||
RemoteAddr net.Addr
|
||||
netConn net.Conn
|
||||
channel ssh.Channel
|
||||
command string
|
||||
}
|
||||
|
@ -38,11 +37,6 @@ func (c *Connection) GetRemoteAddress() string {
|
|||
return c.RemoteAddr.String()
|
||||
}
|
||||
|
||||
// SetConnDeadline sets a deadline on the network connection so it will be eventually closed
|
||||
func (c *Connection) SetConnDeadline() {
|
||||
c.netConn.SetDeadline(time.Now().Add(2 * time.Minute)) //nolint:errcheck
|
||||
}
|
||||
|
||||
// GetCommand returns the SSH command, if any
|
||||
func (c *Connection) GetCommand() string {
|
||||
return c.command
|
||||
|
@ -413,11 +407,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r
|
|||
|
||||
// Disconnect disconnects the client closing the network connection
|
||||
func (c *Connection) Disconnect() error {
|
||||
if c.channel != nil {
|
||||
err := c.channel.Close()
|
||||
c.Log(logger.LevelInfo, "channel close, err: %v", err)
|
||||
}
|
||||
return c.netConn.Close()
|
||||
return c.channel.Close()
|
||||
}
|
||||
|
||||
func getOSOpenFlags(requestFlags sftp.FileOpenFlags) (flags int) {
|
||||
|
|
|
@ -518,7 +518,6 @@ func TestSSHCommandErrors(t *testing.T) {
|
|||
connection := Connection{
|
||||
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
|
||||
channel: &mockSSHChannel,
|
||||
netConn: client,
|
||||
}
|
||||
cmd := sshCommand{
|
||||
command: "md5sum",
|
||||
|
@ -674,7 +673,6 @@ func TestCommandsWithExtensionsFilter(t *testing.T) {
|
|||
connection := &Connection{
|
||||
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
|
||||
channel: &mockSSHChannel,
|
||||
netConn: client,
|
||||
}
|
||||
cmd := sshCommand{
|
||||
command: "md5sum",
|
||||
|
@ -747,7 +745,6 @@ func TestSSHCommandsRemoteFs(t *testing.T) {
|
|||
connection := &Connection{
|
||||
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
|
||||
channel: &mockSSHChannel,
|
||||
netConn: client,
|
||||
}
|
||||
cmd := sshCommand{
|
||||
command: "md5sum",
|
||||
|
@ -960,7 +957,6 @@ func TestSystemCommandErrors(t *testing.T) {
|
|||
connection := &Connection{
|
||||
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
|
||||
channel: &mockSSHChannel,
|
||||
netConn: client,
|
||||
}
|
||||
var sshCmd sshCommand
|
||||
if runtime.GOOS == osWindows {
|
||||
|
@ -1268,7 +1264,6 @@ func TestSCPCommandHandleErrors(t *testing.T) {
|
|||
connection := &Connection{
|
||||
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, nil),
|
||||
channel: &mockSSHChannel,
|
||||
netConn: client,
|
||||
}
|
||||
scpCommand := scpCommand{
|
||||
sshCommand: sshCommand{
|
||||
|
@ -1309,7 +1304,6 @@ func TestSCPErrorsMockFs(t *testing.T) {
|
|||
}()
|
||||
connection := &Connection{
|
||||
channel: &mockSSHChannel,
|
||||
netConn: client,
|
||||
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, u, fs),
|
||||
}
|
||||
scpCommand := scpCommand{
|
||||
|
@ -1364,7 +1358,6 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
|
|||
connection := &Connection{
|
||||
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, fs),
|
||||
channel: &mockSSHChannel,
|
||||
netConn: client,
|
||||
}
|
||||
scpCommand := scpCommand{
|
||||
sshCommand: sshCommand{
|
||||
|
|
|
@ -287,6 +287,8 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
|
|||
// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
|
||||
conn.SetDeadline(time.Time{}) //nolint:errcheck
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
var user dataprovider.User
|
||||
|
||||
// Unmarshal cannot fails here and even if it fails we'll have a user with no permissions
|
||||
|
@ -299,62 +301,68 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
|
|||
|
||||
if err != nil {
|
||||
logger.Warn(logSender, "", "could create filesystem for user %#v err: %v", user.Username, err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
connection := Connection{
|
||||
BaseConnection: common.NewBaseConnection(connectionID, "sftpd", user, fs),
|
||||
ClientVersion: string(sconn.ClientVersion()),
|
||||
RemoteAddr: remoteAddr,
|
||||
netConn: conn,
|
||||
channel: nil,
|
||||
}
|
||||
fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID())
|
||||
|
||||
connection.Fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID())
|
||||
|
||||
connection.Log(logger.LevelInfo, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
|
||||
logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
|
||||
"User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
|
||||
user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String())
|
||||
dataprovider.UpdateLastLogin(user) //nolint:errcheck
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
channelCounter := 0
|
||||
for newChannel := range chans {
|
||||
// If its not a session channel we just move on because its not something we
|
||||
// know how to handle at this point.
|
||||
if newChannel.ChannelType() != "session" {
|
||||
connection.Log(logger.LevelDebug, "received an unknown channel type: %v", newChannel.ChannelType())
|
||||
logger.Log(logger.LevelDebug, common.ProtocolSSH, connectionID, "received an unknown channel type: %v",
|
||||
newChannel.ChannelType())
|
||||
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") //nolint:errcheck
|
||||
continue
|
||||
}
|
||||
|
||||
channel, requests, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
connection.Log(logger.LevelWarn, "could not accept a channel: %v", err)
|
||||
logger.Log(logger.LevelWarn, common.ProtocolSSH, connectionID, "could not accept a channel: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
channelCounter++
|
||||
// Channels have a type that is dependent on the protocol. For SFTP this is "subsystem"
|
||||
// with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc)
|
||||
go func(in <-chan *ssh.Request) {
|
||||
go func(in <-chan *ssh.Request, counter int) {
|
||||
for req := range in {
|
||||
ok := false
|
||||
connID := fmt.Sprintf("%v_%v", connectionID, counter)
|
||||
|
||||
switch req.Type {
|
||||
case "subsystem":
|
||||
if string(req.Payload[4:]) == "sftp" {
|
||||
ok = true
|
||||
connection.SetProtocol(common.ProtocolSFTP)
|
||||
connection.channel = channel
|
||||
connection := Connection{
|
||||
BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, user, fs),
|
||||
ClientVersion: string(sconn.ClientVersion()),
|
||||
RemoteAddr: remoteAddr,
|
||||
channel: channel,
|
||||
}
|
||||
go c.handleSftpConnection(channel, &connection)
|
||||
}
|
||||
case "exec":
|
||||
connection.SetProtocol(common.ProtocolSSH)
|
||||
ok = processSSHCommand(req.Payload, &connection, channel, c.EnabledSSHCommands)
|
||||
// protocol will be set later inside processSSHCommand it could be SSH or SCP
|
||||
connection := Connection{
|
||||
BaseConnection: common.NewBaseConnection(connID, "sshd", user, fs),
|
||||
ClientVersion: string(sconn.ClientVersion()),
|
||||
RemoteAddr: remoteAddr,
|
||||
channel: channel,
|
||||
}
|
||||
ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands)
|
||||
}
|
||||
req.Reply(ok, nil) //nolint:errcheck
|
||||
}
|
||||
}(requests)
|
||||
}(requests, channelCounter)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5368,6 +5368,33 @@ func TestPermsSubDirsSetstat(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestOpenUnhandledChannel(t *testing.T) {
|
||||
u := getTestUser(false)
|
||||
user, _, err := httpd.AddUser(u, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: user.Username,
|
||||
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
return nil
|
||||
},
|
||||
Auth: []ssh.AuthMethod{ssh.Password(defaultPassword)},
|
||||
}
|
||||
conn, err := ssh.Dial("tcp", sftpServerAddr, config)
|
||||
if assert.NoError(t, err) {
|
||||
_, _, err = conn.OpenChannel("unhandled", nil)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "unknown channel type")
|
||||
}
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
_, err = httpd.RemoveUser(user, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPermsSubDirsCommands(t *testing.T) {
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
|
|
|
@ -48,7 +48,7 @@ type systemCommand struct {
|
|||
quotaCheckPath string
|
||||
}
|
||||
|
||||
func processSSHCommand(payload []byte, connection *Connection, channel ssh.Channel, enabledSSHCommands []string) bool {
|
||||
func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommands []string) bool {
|
||||
var msg sshSubsystemExecMsg
|
||||
if err := ssh.Unmarshal(payload, &msg); err == nil {
|
||||
name, args, err := parseCommandPayload(msg.Command)
|
||||
|
@ -58,7 +58,6 @@ func processSSHCommand(payload []byte, connection *Connection, channel ssh.Chann
|
|||
connection.command = msg.Command
|
||||
if name == scpCmdName && len(args) >= 2 {
|
||||
connection.SetProtocol(common.ProtocolSCP)
|
||||
connection.channel = channel
|
||||
scpCommand := scpCommand{
|
||||
sshCommand: sshCommand{
|
||||
command: name,
|
||||
|
@ -70,7 +69,6 @@ func processSSHCommand(payload []byte, connection *Connection, channel ssh.Chann
|
|||
}
|
||||
if name != scpCmdName {
|
||||
connection.SetProtocol(common.ProtocolSSH)
|
||||
connection.channel = channel
|
||||
sshCommand := sshCommand{
|
||||
command: name,
|
||||
connection: connection,
|
||||
|
|
|
@ -39,9 +39,6 @@ func (c *Connection) GetRemoteAddress() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
// SetConnDeadline does nothing
|
||||
func (c *Connection) SetConnDeadline() {}
|
||||
|
||||
// Disconnect closes the active transfer
|
||||
func (c *Connection) Disconnect() error {
|
||||
return c.SignalTransfersAbort()
|
||||
|
|
Loading…
Reference in a new issue