// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package sftpd
import (
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"errors"
"fmt"
"hash"
"io"
"os"
"os/exec"
"path"
"runtime/debug"
"strings"
"sync"
"time"
"github.com/google/shlex"
"github.com/sftpgo/sdk"
"golang.org/x/crypto/ssh"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/metric"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
)
const (
	// scpCmdName is the command name that is routed to the SCP handler
	// instead of the generic SSH command handler.
	scpCmdName = "scp"
	// sshCommandLogSender is a log sender identifier for SSH commands.
	sshCommandLogSender = "SSHCommand"
)
// errUnsupportedConfig reports that a requested command cannot be served with
// the current configuration.
var errUnsupportedConfig = errors.New("command unsupported for this configuration")
// sshCommand describes a single SSH "exec" request being served on a
// connection.
type sshCommand struct {
	// command is the parsed command name (e.g. "pwd", "cd", "sftpgo-copy").
	command string
	// args holds the parsed command arguments.
	args []string
	// connection is the SFTPGo connection the command runs on.
	connection *Connection
	// startTime records when the command object was created.
	startTime time.Time
}
// systemCommand wraps an external process started to serve an SSH command,
// together with the filesystem context it operates on.
type systemCommand struct {
	// cmd is the external process to execute.
	cmd *exec.Cmd
	// fsPath and quotaCheckPath are paths associated with the command;
	// NOTE(review): presumably the resolved target path and the path used for
	// quota checks respectively — confirm against getSystemCommand.
	fsPath, quotaCheckPath string
	// fs is the virtual filesystem the command operates on.
	fs vfs.Fs
}
// GetSTDs attaches pipes to the standard input, output and error streams of
// the wrapped command and returns them in that order.
//
// If any pipe cannot be created, the pipes obtained so far are closed and the
// error from the failing call is returned together with nil pipes.
func (c *systemCommand) GetSTDs() (io.WriteCloser, io.ReadCloser, io.ReadCloser, error) {
	inPipe, inErr := c.cmd.StdinPipe()
	if inErr != nil {
		return nil, nil, nil, inErr
	}

	outPipe, outErr := c.cmd.StdoutPipe()
	if outErr != nil {
		inPipe.Close()
		return nil, nil, nil, outErr
	}

	errPipe, errPipeErr := c.cmd.StderrPipe()
	if errPipeErr != nil {
		// release the pipes created before the failure
		inPipe.Close()
		outPipe.Close()
		return nil, nil, nil, errPipeErr
	}

	return inPipe, outPipe, errPipe, nil
}
// processSSHCommand decodes an SSH "exec" request payload and, when the
// requested command is in enabledSSHCommands, starts serving it in a new
// goroutine:
//   - "scp" with at least two arguments is handled by scpCommand;
//   - any other enabled command is handled by sshCommand.
//
// It returns true when a handler goroutine was started. On every other path
// (payload that cannot be unmarshaled, command that cannot be parsed or is not
// enabled, scp invoked with fewer than two arguments) the connection's
// filesystem is closed and false is returned.
func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommands []string) bool {
	var msg sshSubsystemExecMsg
	if err := ssh.Unmarshal(payload, &msg); err != nil {
		// log the real decode error instead of discarding it
		connection.Log(logger.LevelError, "unable to unmarshal ssh command: %v", err)
	} else {
		name, args, err := parseCommandPayload(msg.Command)
		connection.Log(logger.LevelDebug, "new ssh command: %q args: %v num args: %d user: %s, error: %v",
			name, args, len(args), connection.User.Username, err)
		if err == nil && util.Contains(enabledSSHCommands, name) {
			connection.command = msg.Command
			switch {
			case name == scpCmdName && len(args) >= 2:
				connection.SetProtocol(common.ProtocolSCP)
				scpCmd := scpCommand{
					sshCommand: sshCommand{
						command:    name,
						connection: connection,
						startTime:  time.Now(),
						args:       args,
					},
				}
				go scpCmd.handle() //nolint:errcheck
				return true
			case name != scpCmdName:
				connection.SetProtocol(common.ProtocolSSH)
				sshCmd := sshCommand{
					command:    name,
					connection: connection,
					startTime:  time.Now(),
					args:       args,
				}
				go sshCmd.handle() //nolint:errcheck
				return true
			default:
				// scp needs at least the mode flag(s) and a target path
				connection.Log(logger.LevelInfo, "scp command requires at least 2 arguments, got %d", len(args))
			}
		} else {
			connection.Log(logger.LevelInfo, "ssh command not enabled/supported: %q", name)
		}
	}
	// the request is rejected: release the filesystem, logging only real failures
	if err := connection.CloseFS(); err != nil {
		connection.Log(logger.LevelError, "unable to close fs: %v", err)
	}
	return false
}
// handle serves the SSH command held by c.
//
// The connection is registered in common.Connections for the duration of the
// call and removed on return; if registration fails, that error is returned
// immediately. After refreshing the last-activity timestamp, the command is
// dispatched by name: hash commands, system commands, "cd", "pwd",
// "sftpgo-copy" and "sftpgo-remove". Any other name returns nil without
// writing a response.
//
// A panic during handling is recovered, logged with its stack trace, and
// turned into common.ErrGenericFailure.
func (c *sshCommand) handle() (err error) {
	defer func() {
		if r := recover(); r != nil {
			logger.Error(logSender, "", "panic in handle ssh command: %q stack trace: %v", r, string(debug.Stack()))
			err = common.ErrGenericFailure
		}
	}()

	if addErr := common.Connections.Add(c.connection); addErr != nil {
		logger.Info(logSender, "", "unable to add SSH command connection: %v", addErr)
		return addErr
	}
	defer common.Connections.Remove(c.connection.GetID())

	c.connection.UpdateLastActivity()

	switch {
	case util.Contains(sshHashCommands, c.command):
		return c.handleHashCommands()
	case util.Contains(systemCommands, c.command):
		sysCmd, cmdErr := c.getSystemCommand()
		if cmdErr != nil {
			return c.sendErrorResponse(cmdErr)
		}
		return c.executeSystemCommand(sysCmd)
	case c.command == "cd":
		c.sendExitStatus(nil)
	case c.command == "pwd":
		// hard coded response to the start directory
		startDir := util.CleanPath(c.connection.User.Filters.StartDirectory)
		c.connection.channel.Write([]byte(startDir + "\n")) //nolint:errcheck
		c.sendExitStatus(nil)
	case c.command == "sftpgo-copy":
		return c.handleSFTPGoCopy()
	case c.command == "sftpgo-remove":
		return c.handleSFTPGoRemove()
	}
	return nil
}
func (c *sshCommand) handleSFTPGoCopy() error {
sshSourcePath := c.getSourcePath()
sshDestPath := c.getDestPath()
if sshSourcePath == "" || sshDestPath == "" || len(c.args) != 2 {
return c.sendErrorResponse(errors.New("usage sftpgo-copy