diff --git a/go.mod b/go.mod index e35c4416..81a83e84 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/miekg/dns v1.1.22 // indirect github.com/nathanaelle/password v1.0.0 github.com/pelletier/go-toml v1.6.0 // indirect - github.com/pkg/sftp v1.10.2-0.20191102210727-6d50bf4a2122 + github.com/pkg/sftp v1.10.2-0.20191111234405-8488d36edee7 github.com/prometheus/client_golang v1.2.1 github.com/rs/xid v1.2.1 github.com/rs/zerolog v1.16.0 diff --git a/go.sum b/go.sum index 797bbf5f..a2a9cb4a 100644 --- a/go.sum +++ b/go.sum @@ -113,8 +113,8 @@ github.com/pelletier/go-toml v1.6.0/go.mod h1:5N711Q9dKgbdkxHL+MEfF31hpT7l0S0s/t github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/sftp v1.10.2-0.20191102210727-6d50bf4a2122 h1:sb1Pv18vtpHTpRq4zlPIaiBw815nIkFIrARKIRSVBjM= -github.com/pkg/sftp v1.10.2-0.20191102210727-6d50bf4a2122/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= +github.com/pkg/sftp v1.10.2-0.20191111234405-8488d36edee7 h1:0aliGCO3gzhJZYrCyPwl/H631u53ol99CoxH1Xx3ROk= +github.com/pkg/sftp v1.10.2-0.20191111234405-8488d36edee7/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index 9f0ab73a..fea3677c 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -802,3 +802,18 @@ func TestConnectionStatusStruct(t *testing.T) { t.Errorf("error getting connection info") } } + +func TestSFTPExtensions(t *testing.T) { + initialSFTPExtensions := sftpExtensions + c := Configuration{} + err := c.configureSFTPExtensions() + if err != nil { + t.Errorf("error configuring SFTP extensions") + } + sftpExtensions = append(sftpExtensions, "invalid@example.com") + err = c.configureSFTPExtensions() + if err == nil { + t.Errorf("configuring invalid SFTP extensions must fail") + } + sftpExtensions = initialSFTPExtensions +} diff --git a/sftpd/server.go b/sftpd/server.go index f7f161cd..ac503ca7 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -28,6 +28,8 @@ import ( const defaultPrivateKeyName = "id_rsa" +var sftpExtensions = []string{"posix-rename@openssh.com"} + // Configuration for the SFTP server type Configuration struct { // Identification string used by the server @@ -153,6 +155,7 @@ func (c Configuration) Initialize(configDir string) error { c.configureSecurityOptions(serverConfig) c.configureLoginBanner(serverConfig, configDir) + c.configureSFTPExtensions() listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", c.BindAddress, c.BindPort)) if err != nil { @@ -208,6 +211,15 @@ func (c Configuration) configureLoginBanner(serverConfig *ssh.ServerConfig, conf return err } +func (c Configuration) configureSFTPExtensions() error { + err := sftp.SetSFTPExtensions(sftpExtensions...) + if err != nil { + logger.WarnToConsole("unable to configure SFTP extensions: %v", err) + logger.Warn(logSender, "", "unable to configure SFTP extensions: %v", err) + } + return err +} + // AcceptInboundConnection handles an inbound connection to the server instance and determines if the request should be served or not. func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {