add SCP support

SCP is an experimental feature, we have our own SCP implementation
since we can't rely on scp system command to proper handle permissions,
quota and user's home dir restrictions. The SCP protocol is quite simple
but there is no official docs about it, so we need more testing and
feedbacks before enabling it by default.
We may not handle some borderline cases or have sneaky bugs.

This commit contains some breaking changes to the REST API.
SFTPGo API should be stable now and I hope no more breaking changes
before the first stable release.
This commit is contained in:
Nicola Murino 2019-08-24 14:41:15 +02:00
parent 2c05791624
commit e50c521c33
19 changed files with 2077 additions and 128 deletions

View file

@ -19,6 +19,7 @@ Full featured and highly configurable SFTP server software
- Log files are accurate and they are saved in the easily parsable JSON format
- Automatically terminating idle connections
- Atomic uploads are supported
- Optional SCP support
## Platforms
@ -55,7 +56,7 @@ Version info, such as git commit and build date, can be embedded setting the fol
For example you can build using the following command:
```bash
go build -i -ldflags "-s -w -X github.com/drakkan/sftpgo/utils.commit=`git describe --tags --always --dirty` -X github.com/drakkan/sftpgo/utils.date=`date --utc +%FT%TZ`" -o sftpgo
go build -i -ldflags "-s -w -X github.com/drakkan/sftpgo/utils.commit=`git describe --tags --always --dirty` -X github.com/drakkan/sftpgo/utils.date=`date -u +%FT%TZ`" -o sftpgo
```
and you will get a version that includes git commit and build date like this one:
@ -129,6 +130,7 @@ The `sftpgo` configuration file contains the following sections:
- `target_path`, added for `rename` action only
- `keys`, struct array. It contains the daemon's private keys. If empty or missing the daemon will search or try to generate `id_rsa` in the configuration directory.
- `private_key`, path to the private key file. It can be a path relative to the config dir or an absolute one.
- `enable_scp`, boolean. Default disabled. Set to `true` to enable SCP support. SCP is an experimental feature, we have our own SCP implementation since we can't rely on `scp` system command to proper handle permissions, quota and user's home dir restrictions. The SCP protocol is quite simple but there is no official docs about it, so we need more testing and feedbacks before enabling it by default. We may not handle some borderline cases or have sneaky bugs. Please do accurate tests yourself before enabling SCP and let us known if something does not work as expected for your use cases.
- **"data_provider"**, the configuration for the data provider
- `driver`, string. Supported drivers are `sqlite`, `mysql`, `postgresql`, `bolt`
- `name`, string. Database name. For driver `sqlite` this can be the database name relative to the config dir or the absolute path to the SQLite database.
@ -164,7 +166,8 @@ Here is a full example showing the default config in json format:
"command": "",
"http_notification_url": ""
},
"keys": []
"keys": [],
"enable_scp": false
},
"data_provider": {
"driver": "sqlite",
@ -287,22 +290,24 @@ The logs can be divided into the following categories:
- `time` string. Date/time with millisecond precision
- `level` string
- `message` string
- **"transfer logs"**, SFTP transfer logs:
- `sender` string. `SFTPUpload` or `SFTPDownload`
- **"transfer logs"**, SFTP/SCP transfer logs:
- `sender` string. `Upload` or `Download`
- `time` string. Date/time with millisecond precision
- `level` string
- `elapsed_ms`, int64. Elapsed time, as milliseconds, for the upload/download
- `size_bytes`, int64. Size, as bytes, of the download/upload
- `username`, string
- `file_path` string
- `connection_id` string. Unique SFTP connection identifier
- **"command logs"**, SFTP command logs:
- `sender` string. `SFTPRename`, `SFTPRmdir`, `SFTPMkdir`, `SFTPSymlink`, `SFTPRemove`
- `connection_id` string. Unique connection identifier
- `protocol` string. `SFTP` or `SCP`
- **"command logs"**, SFTP/SCP command logs:
- `sender` string. `Rename`, `Rmdir`, `Mkdir`, `Symlink`, `Remove`
- `level` string
- `username`, string
- `file_path` string
- `target_path` string
- `connection_id` string. Unique SFTP connection identifier
- `connection_id` string. Unique connection identifier
- `protocol` string. `SFTP` or `SCP`
- **"http logs"**, REST API logs:
- `sender` string. `httpd`
- `level` string
@ -329,6 +334,7 @@ The logs can be divided into the following categories:
- [lib/pq](https://github.com/lib/pq)
- [viper](https://github.com/spf13/viper)
- [cobra](https://github.com/spf13/cobra)
- [xid](https://github.com/rs/xid)
Some code was initially taken from [Pterodactyl sftp server](https://github.com/pterodactyl/sftp-server)

View file

@ -15,7 +15,7 @@ import (
const (
logSender = "api"
activeConnectionsPath = "/api/v1/sftp_connection"
activeConnectionsPath = "/api/v1/connection"
quotaScanPath = "/api/v1/quota_scan"
userPath = "/api/v1/user"
versionPath = "/api/v1/version"

View file

@ -33,7 +33,7 @@ const (
testPubKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1"
logSender = "APITesting"
userPath = "/api/v1/user"
activeConnectionsPath = "/api/v1/sftp_connection"
activeConnectionsPath = "/api/v1/connection"
quotaScanPath = "/api/v1/quota_scan"
versionPath = "/api/v1/version"
)
@ -405,19 +405,19 @@ func TestGetVersion(t *testing.T) {
}
}
func TestGetSFTPConnections(t *testing.T) {
_, _, err := api.GetSFTPConnections(http.StatusOK)
func TestGetConnections(t *testing.T) {
_, _, err := api.GetConnections(http.StatusOK)
if err != nil {
t.Errorf("unable to get sftp connections: %v", err)
}
_, _, err = api.GetSFTPConnections(http.StatusInternalServerError)
_, _, err = api.GetConnections(http.StatusInternalServerError)
if err == nil {
t.Errorf("get sftp connections request must succeed, we requested to check a wrong status code")
}
}
func TestCloseActiveSFTPConnection(t *testing.T) {
_, err := api.CloseSFTPConnection("non_existent_id", http.StatusNotFound)
func TestCloseActiveConnection(t *testing.T) {
_, err := api.CloseConnection("non_existent_id", http.StatusNotFound)
if err != nil {
t.Errorf("unexpected error closing non existent sftp connection: %v", err)
}
@ -686,7 +686,7 @@ func TestGetVersionMock(t *testing.T) {
checkResponseCode(t, http.StatusOK, rr.Code)
}
func TestGetSFTPConnectionsMock(t *testing.T) {
func TestGetConnectionsMock(t *testing.T) {
req, _ := http.NewRequest(http.MethodGet, activeConnectionsPath, nil)
rr := executeRequest(req)
checkResponseCode(t, http.StatusOK, rr.Code)

View file

@ -207,8 +207,8 @@ func StartQuotaScan(user dataprovider.User, expectedStatusCode int) ([]byte, err
return body, checkResponse(resp.StatusCode, expectedStatusCode)
}
// GetSFTPConnections returns status and stats for active SFTP connections
func GetSFTPConnections(expectedStatusCode int) ([]sftpd.ConnectionStatus, []byte, error) {
// GetConnections returns status and stats for active SFTP/SCP connections
func GetConnections(expectedStatusCode int) ([]sftpd.ConnectionStatus, []byte, error) {
var connections []sftpd.ConnectionStatus
var body []byte
resp, err := getHTTPClient().Get(buildURLRelativeToBase(activeConnectionsPath))
@ -225,8 +225,8 @@ func GetSFTPConnections(expectedStatusCode int) ([]sftpd.ConnectionStatus, []byt
return connections, body, err
}
// CloseSFTPConnection closes an active SFTP connection identified by connectionID
func CloseSFTPConnection(connectionID string, expectedStatusCode int) ([]byte, error) {
// CloseConnection closes an active connection identified by connectionID
func CloseConnection(connectionID string, expectedStatusCode int) ([]byte, error) {
var body []byte
req, err := http.NewRequest(http.MethodDelete, buildURLRelativeToBase(activeConnectionsPath, connectionID), nil)
if err != nil {

View file

@ -161,7 +161,7 @@ func TestApiCallsWithBadURL(t *testing.T) {
if err == nil {
t.Errorf("request with invalid URL must fail")
}
_, err = CloseSFTPConnection("non_existent_id", http.StatusNotFound)
_, err = CloseConnection("non_existent_id", http.StatusNotFound)
if err == nil {
t.Errorf("request with invalid URL must fail")
}
@ -200,11 +200,11 @@ func TestApiCallToNotListeningServer(t *testing.T) {
if err == nil {
t.Errorf("request to an inactive URL must fail")
}
_, _, err = GetSFTPConnections(http.StatusOK)
_, _, err = GetConnections(http.StatusOK)
if err == nil {
t.Errorf("request to an inactive URL must fail")
}
_, err = CloseSFTPConnection("non_existent_id", http.StatusNotFound)
_, err = CloseConnection("non_existent_id", http.StatusNotFound)
if err == nil {
t.Errorf("request to an inactive URL must fail")
}
@ -215,13 +215,13 @@ func TestApiCallToNotListeningServer(t *testing.T) {
SetBaseURL(oldBaseURL)
}
func TestCloseSFTPConnectionHandler(t *testing.T) {
func TestCloseConnectionHandler(t *testing.T) {
req, _ := http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID", nil)
rctx := chi.NewRouteContext()
rctx.URLParams.Add("connectionID", "")
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
rr := httptest.NewRecorder()
handleCloseSFTPConnection(rr, req)
handleCloseConnection(rr, req)
if rr.Code != http.StatusBadRequest {
t.Errorf("Expected response code 400. Got %d", rr.Code)
}

View file

@ -40,7 +40,7 @@ func initializeRouter() {
})
router.Delete(activeConnectionsPath+"/{connectionID}", func(w http.ResponseWriter, r *http.Request) {
handleCloseSFTPConnection(w, r)
handleCloseConnection(w, r)
})
router.Get(quotaScanPath, func(w http.ResponseWriter, r *http.Request) {
@ -72,7 +72,7 @@ func initializeRouter() {
})
}
func handleCloseSFTPConnection(w http.ResponseWriter, r *http.Request) {
func handleCloseConnection(w http.ResponseWriter, r *http.Request) {
connectionID := chi.URLParam(r, "connectionID")
if connectionID == "" {
sendAPIResponse(w, r, nil, "connectionID is mandatory", http.StatusBadRequest)

View file

@ -22,12 +22,12 @@ paths:
type: array
items:
$ref : '#/components/schemas/VersionInfo'
/sftp_connection:
/connection:
get:
tags:
- connections
summary: Get the active sftp users and info about their uploads/downloads
operationId: get_sftp_connections
summary: Get the active users and info about their uploads/downloads
operationId: get_connections
responses:
200:
description: successful operation
@ -37,12 +37,12 @@ paths:
type: array
items:
$ref : '#/components/schemas/ConnectionStatus'
/sftp_connection/{connectionID}:
/connection/{connectionID}:
delete:
tags:
- connections
summary: Terminate an active SFTP connection
operationId: close_sftp_connection
summary: Terminate an active connection
operationId: close_connection
parameters:
- name: connectionID
in: path
@ -183,7 +183,7 @@ paths:
get:
tags:
- users
summary: Returns an array with one or more SFTP users
summary: Returns an array with one or more users
description: For security reasons password and public key are empty in the response
operationId: get_users
parameters:
@ -261,7 +261,7 @@ paths:
post:
tags:
- users
summary: Adds a new SFTP user
summary: Adds a new SFTP/SCP user
operationId: add_user
requestBody:
required: true
@ -562,15 +562,15 @@ components:
max_sessions:
type: integer
format: int32
description: limit the sessions that an sftp user can open. 0 means unlimited
description: limit the sessions that an user can open. 0 means unlimited
quota_size:
type: integer
format: int64
description: quota as size. 0 menas unlimited. Please note that quota is updated if files are added/removed via sftp otherwise a quota scan is needed
description: quota as size. 0 menas unlimited. Please note that quota is updated if files are added/removed via SFTP/SCP otherwise a quota scan is needed
quota_files:
type: integer
format: int32
description: quota as number of files. 0 menas unlimited. Please note that quota is updated if files are added/removed via sftp otherwise a quota scan is needed
description: quota as number of files. 0 menas unlimited. Please note that quota is updated if files are added/removed via SFTP/SCP otherwise a quota scan is needed
permissions:
type: array
items:
@ -594,7 +594,7 @@ components:
type: integer
format: int32
description: Maximum download bandwidth as KB/s, 0 means unlimited
SFTPTransfer:
Transfer:
type: object
properties:
operation_type:
@ -604,7 +604,7 @@ components:
- download
path:
type: string
description: SFTP file path for the upload/download
description: SFTP/SCP file path for the upload/download
start_time:
type: integer
format: int64
@ -625,13 +625,13 @@ components:
description: connected username
connection_id:
type: string
description: unique sftp connection identifier
description: unique connection identifier
client_version:
type: string
description: SFTP client version
description: SFTP/SCP client version
remote_address:
type: string
description: Remote address for the connected SFTP client
description: Remote address for the connected SFTP/SCP client
connection_time:
type: integer
format: int64
@ -640,10 +640,15 @@ components:
type: integer
format: int64
description: last client activity as unix timestamp in milliseconds
protocol:
type: string
enum:
- SFTP
- SCP
active_transfers:
type: array
items:
$ref : '#/components/schemas/SFTPTransfer'
$ref : '#/components/schemas/Transfer'
QuotaScan:
type: object
properties:

View file

@ -53,7 +53,8 @@ func init() {
Command: "",
HTTPNotificationURL: "",
},
Keys: []sftpd.Key{},
Keys: []sftpd.Key{},
IsSCPEnabled: false,
},
ProviderConf: dataprovider.Config{
Driver: "sqlite",

View file

@ -90,8 +90,8 @@ func ErrorToConsole(format string, v ...interface{}) {
consoleLogger.Error().Msg(fmt.Sprintf(format, v...))
}
// TransferLog logs an SFTP upload or download
func TransferLog(operation string, path string, elapsed int64, size int64, user string, connectionID string) {
// TransferLog logs an SFTP/SCP upload or download
func TransferLog(operation string, path string, elapsed int64, size int64, user string, connectionID string, protocol string) {
logger.Info().
Str("sender", operation).
Int64("elapsed_ms", elapsed).
@ -99,16 +99,18 @@ func TransferLog(operation string, path string, elapsed int64, size int64, user
Str("username", user).
Str("file_path", path).
Str("connection_id", connectionID).
Str("protocol", protocol).
Msg("")
}
// CommandLog logs an SFTP command
func CommandLog(command string, path string, target string, user string, connectionID string) {
// CommandLog logs an SFTP/SCP command
func CommandLog(command string, path string, target string, user string, connectionID string, protocol string) {
logger.Info().
Str("sender", command).
Str("username", user).
Str("file_path", path).
Str("target_path", target).
Str("connection_id", connectionID).
Str("protocol", protocol).
Msg("")
}

View file

@ -154,12 +154,12 @@ Output:
]
```
### Get SFTP connections
### Get active connections
Command:
```
python sftpgo_api_cli.py get-sftp-connections
python sftpgo_api_cli.py get-connections
```
Output:
@ -173,9 +173,11 @@ Output:
"remote_address": "127.0.0.1:41622",
"connection_time": 1564696137971,
"last_activity": 1564696159605,
"protocol": "SFTP",
"active_transfers": [
{
"operation_type": "upload",
"path": "/test_upload.gz",
"start_time": 1564696149783,
"size": 1146880,
"last_activity": 1564696159605
@ -185,12 +187,12 @@ Output:
]
```
### Close SFTP connection
### Close connection
Command:
```
python sftpgo_api_cli.py close-sftp-connection 76a11b22260ee4249328df28bef34dc64c70f7c097db52159fc24049eeb0e32c
python sftpgo_api_cli.py close-connection 76a11b22260ee4249328df28bef34dc64c70f7c097db52159fc24049eeb0e32c
```
Output:

View file

@ -22,7 +22,7 @@ class SFTPGoApiRequests:
def __init__(self, debug, baseUrl, authType, authUser, authPassword, secure, no_color):
self.userPath = urlparse.urljoin(baseUrl, '/api/v1/user')
self.quotaScanPath = urlparse.urljoin(baseUrl, '/api/v1/quota_scan')
self.activeConnectionsPath = urlparse.urljoin(baseUrl, '/api/v1/sftp_connection')
self.activeConnectionsPath = urlparse.urljoin(baseUrl, '/api/v1/connection')
self.versionPath = urlparse.urljoin(baseUrl, '/api/v1/version')
self.debug = debug
if authType == 'basic':
@ -101,12 +101,12 @@ class SFTPGoApiRequests:
r = requests.delete(urlparse.urljoin(self.userPath, "user/" + str(user_id)), auth=self.auth, verify=self.verify)
self.printResponse(r)
def getSFTPConnections(self):
def getConnections(self):
r = requests.get(self.activeConnectionsPath, auth=self.auth, verify=self.verify)
self.printResponse(r)
def closeSFTPConnection(self, connectionID):
r = requests.delete(urlparse.urljoin(self.userPath, "sftp_connection/" + str(connectionID)), auth=self.auth)
def closeConnection(self, connectionID):
r = requests.delete(urlparse.urljoin(self.activeConnectionsPath, "connection/" + str(connectionID)), auth=self.auth)
self.printResponse(r)
def getQuotaScans(self):
@ -187,11 +187,11 @@ if __name__ == '__main__':
parserGetUserByID = subparsers.add_parser('get-user-by-id', help='Find user by ID')
parserGetUserByID.add_argument('id', type=int)
parserGetSFTPConnections = subparsers.add_parser('get-sftp-connections',
help='Get the active sftp users and info about their uploads/downloads')
parserGetConnections = subparsers.add_parser('get-connections',
help='Get the active users and info about their uploads/downloads')
parserCloseSFTPConnection = subparsers.add_parser('close-sftp-connection', help='Terminate an active SFTP connection')
parserCloseSFTPConnection.add_argument('connectionID', type=str)
parserCloseConnection = subparsers.add_parser('close-connection', help='Terminate an active SFTP/SCP connection')
parserCloseConnection.add_argument('connectionID', type=str)
parserGetQuotaScans = subparsers.add_parser('get-quota-scans', help='Get the active quota scans')
@ -219,10 +219,10 @@ if __name__ == '__main__':
api.getUsers(args.limit, args.offset, args.order, args.username)
elif args.command == 'get-user-by-id':
api.getUserByID(args.id)
elif args.command == 'get-sftp-connections':
api.getSFTPConnections()
elif args.command == 'close-sftp-connection':
api.closeSFTPConnection(args.connectionID)
elif args.command == 'get-connections':
api.getConnections()
elif args.command == 'close-connection':
api.closeConnection(args.connectionID)
elif args.command == 'get-quota-scans':
api.getQuotaScans()
elif args.command == 'start-quota-scan':

View file

@ -35,6 +35,7 @@ type Connection struct {
StartTime time.Time
// last activity for this connection
lastActivity time.Time
protocol string
lock *sync.Mutex
sshConn *ssh.ServerConn
}
@ -78,6 +79,7 @@ func (c Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
transferType: transferDownload,
lastActivity: time.Now(),
isNewFile: false,
protocol: c.protocol,
}
addTransfer(&transfer)
return &transfer, nil
@ -254,7 +256,7 @@ func (c Connection) handleSFTPRename(sourcePath string, targetPath string) error
logger.Error(logSender, "failed to rename file, source: %v target: %v: %v", sourcePath, targetPath, err)
return sftp.ErrSshFxFailure
}
logger.CommandLog(sftpdRenameLogSender, sourcePath, targetPath, c.User.Username, c.ID)
logger.CommandLog(renameLogSender, sourcePath, targetPath, c.User.Username, c.ID, c.protocol)
executeAction(operationRename, c.User.Username, sourcePath, targetPath)
return nil
}
@ -274,7 +276,7 @@ func (c Connection) handleSFTPRmdir(path string) error {
return sftp.ErrSshFxFailure
}
logger.CommandLog(sftpdRmdirLogSender, path, "", c.User.Username, c.ID)
logger.CommandLog(rmdirLogSender, path, "", c.User.Username, c.ID, c.protocol)
dataprovider.UpdateUserQuota(dataProvider, c.User, -numFiles, -size, false)
for _, p := range fileList {
executeAction(operationDelete, c.User.Username, p, "")
@ -291,7 +293,7 @@ func (c Connection) handleSFTPSymlink(sourcePath string, targetPath string) erro
return sftp.ErrSshFxFailure
}
logger.CommandLog(sftpdSymlinkLogSender, sourcePath, targetPath, c.User.Username, c.ID)
logger.CommandLog(symlinkLogSender, sourcePath, targetPath, c.User.Username, c.ID, c.protocol)
return nil
}
@ -304,7 +306,7 @@ func (c Connection) handleSFTPMkdir(path string) error {
logger.Error(logSender, "error making missing dir for path %v: %v", path, err)
return sftp.ErrSshFxFailure
}
logger.CommandLog(sftpdMkdirLogSender, path, "", c.User.Username, c.ID)
logger.CommandLog(mkdirLogSender, path, "", c.User.Username, c.ID, c.protocol)
return nil
}
@ -326,7 +328,7 @@ func (c Connection) handleSFTPRemove(path string) error {
return sftp.ErrSshFxFailure
}
logger.CommandLog(sftpdRemoveLogSender, path, "", c.User.Username, c.ID)
logger.CommandLog(removeLogSender, path, "", c.User.Username, c.ID, c.protocol)
if fi.Mode()&os.ModeSymlink != os.ModeSymlink {
dataprovider.UpdateUserQuota(dataProvider, c.User, -1, -size, false)
}
@ -372,6 +374,7 @@ func (c Connection) handleSFTPUploadToNewFile(requestPath, filePath string) (io.
transferType: transferUpload,
lastActivity: time.Now(),
isNewFile: true,
protocol: c.protocol,
}
addTransfer(&transfer)
return &transfer, nil
@ -426,6 +429,7 @@ func (c Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, re
transferType: transferUpload,
lastActivity: time.Now(),
isNewFile: false,
protocol: c.protocol,
}
addTransfer(&transfer)
return &transfer, nil

View file

@ -1,14 +1,56 @@
package sftpd
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
"runtime"
"testing"
"time"
"github.com/drakkan/sftpgo/dataprovider"
"github.com/pkg/sftp"
)
type MockChannel struct {
Buffer *bytes.Buffer
StdErrBuffer *bytes.Buffer
ReadError error
WriteError error
}
func (c *MockChannel) Read(data []byte) (int, error) {
if c.ReadError != nil {
return 0, c.ReadError
}
return c.Buffer.Read(data)
}
func (c *MockChannel) Write(data []byte) (int, error) {
if c.WriteError != nil {
return 0, c.WriteError
}
return c.Buffer.Write(data)
}
func (c *MockChannel) Close() error {
return nil
}
func (c *MockChannel) CloseWrite() error {
return nil
}
func (c *MockChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
return true, nil
}
func (c *MockChannel) Stderr() io.ReadWriter {
return c.StdErrBuffer
}
func TestWrongActions(t *testing.T) {
actionsCopy := actions
badCommand := "/bad/command"
@ -114,3 +156,519 @@ func TestWithInvalidHome(t *testing.T) {
t.Errorf("tested path is not a home subdir")
}
}
func TestSFTPCmdTargetPath(t *testing.T) {
u := dataprovider.User{}
u.HomeDir = "home_rel_path"
u.Username = "test"
u.Permissions = []string{"*"}
connection := Connection{
User: u,
}
_, err := connection.getSFTPCmdTargetPath("invalid_path")
if err != sftp.ErrSshFxOpUnsupported {
t.Errorf("getSFTPCmdTargetPath must fal with the expected error: %v", err)
}
}
func TestSFTPGetUsedQuota(t *testing.T) {
u := dataprovider.User{}
u.HomeDir = "home_rel_path"
u.Username = "test_invalid_user"
u.QuotaSize = 4096
u.QuotaFiles = 1
u.Permissions = []string{"*"}
connection := Connection{
User: u,
}
res := connection.hasSpace(false)
if res != false {
t.Errorf("has space must return false if the user is invalid")
}
}
func TestSCPFileMode(t *testing.T) {
mode := getFileModeAsString(0, true)
if mode != "0755" {
t.Errorf("invalid file mode: %v expected: 0755", mode)
}
mode = getFileModeAsString(0700, true)
if mode != "0700" {
t.Errorf("invalid file mode: %v expected: 0700", mode)
}
mode = getFileModeAsString(0750, true)
if mode != "0750" {
t.Errorf("invalid file mode: %v expected: 0750", mode)
}
mode = getFileModeAsString(0777, true)
if mode != "0777" {
t.Errorf("invalid file mode: %v expected: 0777", mode)
}
mode = getFileModeAsString(0640, false)
if mode != "0640" {
t.Errorf("invalid file mode: %v expected: 0640", mode)
}
mode = getFileModeAsString(0600, false)
if mode != "0600" {
t.Errorf("invalid file mode: %v expected: 0600", mode)
}
mode = getFileModeAsString(0, false)
if mode != "0644" {
t.Errorf("invalid file mode: %v expected: 0644", mode)
}
fileMode := uint32(0777)
fileMode = fileMode | uint32(os.ModeSetgid)
fileMode = fileMode | uint32(os.ModeSetuid)
fileMode = fileMode | uint32(os.ModeSticky)
mode = getFileModeAsString(os.FileMode(fileMode), false)
if mode != "7777" {
t.Errorf("invalid file mode: %v expected: 7777", mode)
}
fileMode = uint32(0644)
fileMode = fileMode | uint32(os.ModeSetgid)
mode = getFileModeAsString(os.FileMode(fileMode), false)
if mode != "4644" {
t.Errorf("invalid file mode: %v expected: 4644", mode)
}
fileMode = uint32(0600)
fileMode = fileMode | uint32(os.ModeSetuid)
mode = getFileModeAsString(os.FileMode(fileMode), false)
if mode != "2600" {
t.Errorf("invalid file mode: %v expected: 2600", mode)
}
fileMode = uint32(0044)
fileMode = fileMode | uint32(os.ModeSticky)
mode = getFileModeAsString(os.FileMode(fileMode), false)
if mode != "1044" {
t.Errorf("invalid file mode: %v expected: 1044", mode)
}
}
func TestSCPGetNonExistingDirContent(t *testing.T) {
_, err := getDirContents("non_existing")
if err == nil {
t.Errorf("get non existing dir contents must fail")
}
}
func TestSCPParseUploadMessage(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
mockSSHChannel := MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-t", "/tmp"},
channel: &mockSSHChannel,
}
_, _, err := scpCommand.parseUploadMessage("invalid")
if err == nil {
t.Errorf("parsing invalid upload message must fail")
}
_, _, err = scpCommand.parseUploadMessage("D0755 0")
if err == nil {
t.Errorf("parsing incomplete upload message must fail")
}
_, _, err = scpCommand.parseUploadMessage("D0755 invalidsize testdir")
if err == nil {
t.Errorf("parsing upload message with invalid size must fail")
}
_, _, err = scpCommand.parseUploadMessage("D0755 0 ")
if err == nil {
t.Errorf("parsing upload message with invalid name must fail")
}
}
func TestSCPProtocolMessages(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error")
writeErr := fmt.Errorf("test write error")
mockSSHChannel := MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: readErr,
WriteError: writeErr,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-t", "/tmp"},
channel: &mockSSHChannel,
}
_, err := scpCommand.readProtocolMessage()
if err == nil || err != readErr {
t.Errorf("read protocol message must fail, we are sending a fake error")
}
err = scpCommand.sendConfirmationMessage()
if err != writeErr {
t.Errorf("write confirmation message must fail, we are sending a fake error")
}
err = scpCommand.sendProtocolMessage("E\n")
if err != writeErr {
t.Errorf("write confirmation message must fail, we are sending a fake error")
}
_, err = scpCommand.getNextUploadProtocolMessage()
if err == nil || err != readErr {
t.Errorf("read next upload protocol message must fail, we are sending a fake read error")
}
mockSSHChannel = MockChannel{
Buffer: bytes.NewBuffer([]byte("T1183832947 0 1183833773 0\n")),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
WriteError: writeErr,
}
scpCommand.channel = &mockSSHChannel
_, err = scpCommand.getNextUploadProtocolMessage()
if err == nil || err != writeErr {
t.Errorf("read next upload protocol message must fail, we are sending a fake write error")
}
respBuffer := []byte{0x02}
protocolErrorMsg := "protocol error msg"
respBuffer = append(respBuffer, protocolErrorMsg...)
respBuffer = append(respBuffer, 0x0A)
mockSSHChannel = MockChannel{
Buffer: bytes.NewBuffer(respBuffer),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
WriteError: nil,
}
scpCommand.channel = &mockSSHChannel
err = scpCommand.readConfirmationMessage()
if err == nil || err.Error() != protocolErrorMsg {
t.Errorf("read confirmation message must return the expected protocol error, actual err: %v", err)
}
}
func TestSCPTestDownloadProtocolMessages(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error")
writeErr := fmt.Errorf("test write error")
mockSSHChannel := MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: readErr,
WriteError: writeErr,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-f", "-p", "/tmp"},
channel: &mockSSHChannel,
}
path := "testDir"
os.Mkdir(path, 0777)
stat, _ := os.Stat(path)
err := scpCommand.sendDownloadProtocolMessages(path, stat)
if err != writeErr {
t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
}
mockSSHChannel = MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: readErr,
WriteError: nil,
}
err = scpCommand.sendDownloadProtocolMessages(path, stat)
if err != readErr {
t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
}
mockSSHChannel = MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: readErr,
WriteError: writeErr,
}
scpCommand.args = []string{"-f", "/tmp"}
scpCommand.channel = &mockSSHChannel
err = scpCommand.sendDownloadProtocolMessages(path, stat)
if err != writeErr {
t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
}
mockSSHChannel = MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: readErr,
WriteError: nil,
}
scpCommand.channel = &mockSSHChannel
err = scpCommand.sendDownloadProtocolMessages(path, stat)
if err != readErr {
t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
}
os.Remove(path)
}
func TestSCPCommandHandleErrors(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error")
writeErr := fmt.Errorf("test write error")
mockSSHChannel := MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: readErr,
WriteError: writeErr,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-f", "/tmp"},
channel: &mockSSHChannel,
}
err := scpCommand.handle()
if err == nil || err != readErr {
t.Errorf("scp download must fail, we are sending a fake error")
}
scpCommand.args = []string{"-i", "/tmp"}
err = scpCommand.handle()
if err == nil {
t.Errorf("invalid scp command must fail")
}
}
func TestRecursiveDownloadErrors(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error")
writeErr := fmt.Errorf("test write error")
mockSSHChannel := MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: readErr,
WriteError: writeErr,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-r", "-f", "/tmp"},
channel: &mockSSHChannel,
}
path := "testDir"
os.Mkdir(path, 0777)
stat, _ := os.Stat(path)
err := scpCommand.handleRecursiveDownload("invalid_dir", stat)
if err != writeErr {
t.Errorf("recursive upload download must fail with the expected error: %v", err)
}
mockSSHChannel = MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
WriteError: nil,
}
scpCommand.channel = &mockSSHChannel
err = scpCommand.handleRecursiveDownload("invalid_dir", stat)
if err == nil {
t.Errorf("recursive upload download must fail for a non existing dir")
}
os.Remove(path)
}
func TestRecursiveUploadErrors(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error")
writeErr := fmt.Errorf("test write error")
mockSSHChannel := MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: readErr,
WriteError: writeErr,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-r", "-t", "/tmp"},
channel: &mockSSHChannel,
}
err := scpCommand.handleRecursiveUpload()
if err == nil {
t.Errorf("recursive upload must fail, we send a fake error message")
}
mockSSHChannel = MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: readErr,
WriteError: nil,
}
scpCommand.channel = &mockSSHChannel
err = scpCommand.handleRecursiveUpload()
if err == nil {
t.Errorf("recursive upload must fail, we send a fake error message")
}
}
func TestSCPCreateDirs(t *testing.T) {
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
u := dataprovider.User{}
u.HomeDir = "home_rel_path"
u.Username = "test"
u.Permissions = []string{"*"}
connection := Connection{
User: u,
}
mockSSHChannel := MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
WriteError: nil,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-r", "-t", "/tmp"},
channel: &mockSSHChannel,
}
err := scpCommand.handleCreateDir("invalid_dir")
if err == nil {
t.Errorf("create invalid dir must fail")
}
}
func TestSCPDownloadFileData(t *testing.T) {
testfile := "testfile"
buf := make([]byte, 65535)
readErr := fmt.Errorf("test read error")
writeErr := fmt.Errorf("test write error")
stdErrBuf := make([]byte, 65535)
connection := Connection{}
mockSSHChannelReadErr := MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: readErr,
WriteError: nil,
}
mockSSHChannelWriteErr := MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
WriteError: writeErr,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-r", "-f", "/tmp"},
channel: &mockSSHChannelReadErr,
}
ioutil.WriteFile(testfile, []byte("test"), 0666)
stat, _ := os.Stat(testfile)
err := scpCommand.sendDownloadFileData(testfile, stat, nil)
if err != readErr {
t.Errorf("send download file data must fail with the expected error: %v", err)
}
scpCommand.channel = &mockSSHChannelWriteErr
err = scpCommand.sendDownloadFileData(testfile, stat, nil)
if err != writeErr {
t.Errorf("send download file data must fail with the expected error: %v", err)
}
scpCommand.args = []string{"-r", "-p", "-f", "/tmp"}
err = scpCommand.sendDownloadFileData(testfile, stat, nil)
if err != writeErr {
t.Errorf("send download file data must fail with the expected error: %v", err)
}
scpCommand.channel = &mockSSHChannelReadErr
err = scpCommand.sendDownloadFileData(testfile, stat, nil)
if err != readErr {
t.Errorf("send download file data must fail with the expected error: %v", err)
}
os.Remove(testfile)
}
func TestSCPUploadFiledata(t *testing.T) {
testfile := "testfile"
connection := Connection{
User: dataprovider.User{
Username: "testuser",
},
protocol: protocolSCP,
}
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error")
writeErr := fmt.Errorf("test write error")
mockSSHChannel := MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: readErr,
WriteError: writeErr,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-r", "-t", "/tmp"},
channel: &mockSSHChannel,
}
file, _ := os.Create(testfile)
transfer := Transfer{
file: file,
path: file.Name(),
start: time.Now(),
bytesSent: 0,
bytesReceived: 0,
user: scpCommand.connection.User,
connectionID: "",
transferType: transferDownload,
lastActivity: time.Now(),
isNewFile: true,
protocol: connection.protocol,
}
addTransfer(&transfer)
err := scpCommand.getUploadFileData(2, &transfer)
if err == nil {
t.Errorf("upload must fail, we send a fake write error message")
}
mockSSHChannel = MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: readErr,
WriteError: nil,
}
scpCommand.channel = &mockSSHChannel
file, _ = os.Create(testfile)
transfer.file = file
addTransfer(&transfer)
err = scpCommand.getUploadFileData(2, &transfer)
if err == nil {
t.Errorf("upload must fail, we send a fake read error message")
}
respBuffer := []byte("12")
respBuffer = append(respBuffer, 0x02)
mockSSHChannel = MockChannel{
Buffer: bytes.NewBuffer(respBuffer),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
WriteError: nil,
}
scpCommand.channel = &mockSSHChannel
file, _ = os.Create(testfile)
transfer.file = file
addTransfer(&transfer)
err = scpCommand.getUploadFileData(2, &transfer)
if err == nil {
t.Errorf("upload must fail, we have not enough data to read")
}
// the file is already closed so we have an error on trasfer closing
mockSSHChannel = MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
WriteError: nil,
}
addTransfer(&transfer)
err = scpCommand.getUploadFileData(0, &transfer)
if err == nil {
t.Errorf("upload must fail, the file is closed")
}
os.Remove(testfile)
}

758
sftpd/scp.go Normal file
View file

@ -0,0 +1,758 @@
package sftpd
import (
"fmt"
"io"
"math"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/drakkan/sftpgo/dataprovider"
"github.com/drakkan/sftpgo/logger"
"github.com/drakkan/sftpgo/utils"
"golang.org/x/crypto/ssh"
)
var (
okMsg = []byte{0x00}
warnMsg = []byte{0x01} // must be followed by an optional message and a newline
errMsg = []byte{0x02} // must be followed by an optional message and a newline
newLine = []byte{0x0A}
)
type execMsg struct {
Command string
}
type exitStatusMsg struct {
Status uint32
}
type scpCommand struct {
connection Connection
args []string
channel ssh.Channel
}
func (c *scpCommand) handle() error {
var err error
addConnection(c.connection.ID, c.connection)
defer removeConnection(c.connection.ID)
destPath := c.getDestPath()
commandType := c.getCommandType()
logger.Debug(logSenderSCP, "handle scp command, args: %v user: %v command type: %v, dest path: %v",
c.args, c.connection.User.Username, commandType, destPath)
if commandType == "-t" {
// -t means "to", so upload
err = c.handleRecursiveUpload()
if err != nil {
return err
}
} else if commandType == "-f" {
// -f means "from" so download
err = c.readConfirmationMessage()
if err != nil {
return err
}
err = c.handleDownload(destPath)
if err != nil {
return err
}
} else {
err = fmt.Errorf("scp command not supported, args: %v", c.args)
}
c.sendExitStatus(err)
return err
}
func (c *scpCommand) handleRecursiveUpload() error {
var err error
numDirs := 0
destPath := c.getDestPath()
for {
err = c.sendConfirmationMessage()
if err != nil {
return err
}
command, err := c.getNextUploadProtocolMessage()
if err != nil {
return err
}
if strings.HasPrefix(command, "E") {
numDirs--
logger.Debug(logSenderSCP, "received end dir command, num dirs: %v", numDirs)
if numDirs == 0 {
// upload is now complete send confirmation message
err = c.sendConfirmationMessage()
if err != nil {
return err
}
} else {
// the destination dir is now the parent directory
destPath = filepath.Join(destPath, "..")
}
} else {
sizeToRead, name, err := c.parseUploadMessage(command)
if err != nil {
return err
}
objPath := path.Join(destPath, name)
if strings.HasPrefix(command, "D") {
numDirs++
err = c.handleCreateDir(objPath)
if err != nil {
return err
}
destPath = objPath
logger.Debug(logSenderSCP, "received start dir command, num dirs: %v destPath: %v", numDirs, destPath)
} else if strings.HasPrefix(command, "C") {
// if the upload is not recursive and the destination path does not end with "/"
// then this is the wanted filename ...
if !c.isRecursive() {
if !strings.HasSuffix(destPath, "/") {
objPath = destPath
// ... but if the requested path is an existing directory then put the uploaded file inside that directory
if p, err := c.connection.buildPath(objPath); err == nil {
if stat, err := os.Stat(p); err == nil {
if stat.IsDir() {
objPath = path.Join(destPath, name)
}
}
}
}
}
err = c.handleUpload(objPath, sizeToRead)
if err != nil {
return err
}
}
}
if err != nil || numDirs == 0 {
break
}
}
return err
}
func (c *scpCommand) handleCreateDir(dirPath string) error {
updateConnectionActivity(c.connection.ID)
if !c.connection.User.HasPerm(dataprovider.PermCreateDirs) {
err := fmt.Errorf("Permission denied")
logger.Warn(logSenderSCP, "error creating dir: %v, permission denied", dirPath)
c.sendErrorMessage(err.Error())
return err
}
p, err := c.connection.buildPath(dirPath)
if err != nil {
logger.Warn(logSenderSCP, "error creating dir: %v, invalid file path, err: %v", dirPath, err)
c.sendErrorMessage(err.Error())
return err
}
err = c.createDir(p)
if err != nil {
return err
}
logger.CommandLog(mkdirLogSender, dirPath, "", c.connection.User.Username, c.connection.ID, c.connection.protocol)
return nil
}
// we need to close the transfer if we have an error
func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) error {
err := c.sendConfirmationMessage()
if err != nil {
transfer.Close()
return err
}
if sizeToRead > 0 {
remaining := sizeToRead
buf := make([]byte, int64(math.Min(32768, float64(sizeToRead))))
for {
n, err := c.channel.Read(buf)
if err != nil {
c.sendErrorMessage(err.Error())
transfer.Close()
return err
}
transfer.WriteAt(buf[:n], sizeToRead-remaining)
remaining -= int64(n)
if remaining <= 0 {
break
}
if remaining < int64(len(buf)) {
buf = make([]byte, remaining)
}
}
}
err = c.readConfirmationMessage()
if err != nil {
transfer.Close()
return err
}
err = transfer.Close()
if err != nil {
c.sendErrorMessage(err.Error())
return err
}
return c.sendConfirmationMessage()
}
func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead int64) error {
logger.Debug(logSenderSCP, "upload to new file: %v", filePath)
if !c.connection.hasSpace(true) {
err := fmt.Errorf("denying file write due to space limit")
logger.Warn(logSenderSCP, "error uploading file: %v, err: %v", filePath, err)
c.sendErrorMessage(err.Error())
return err
}
if _, err := os.Stat(filepath.Dir(requestPath)); os.IsNotExist(err) {
if !c.connection.User.HasPerm(dataprovider.PermCreateDirs) {
err := fmt.Errorf("Permission denied")
logger.Warn(logSenderSCP, "error uploading file: %v, permission denied", requestPath)
c.sendErrorMessage(err.Error())
return err
}
}
file, err := os.Create(filePath)
if err != nil {
logger.Error(logSenderSCP, "error creating file %v: %v", requestPath, err)
c.sendErrorMessage(err.Error())
return err
}
utils.SetPathPermissions(filePath, c.connection.User.GetUID(), c.connection.User.GetGID())
transfer := Transfer{
file: file,
path: requestPath,
start: time.Now(),
bytesSent: 0,
bytesReceived: 0,
user: c.connection.User,
connectionID: c.connection.ID,
transferType: transferUpload,
lastActivity: time.Now(),
isNewFile: true,
protocol: c.connection.protocol,
}
addTransfer(&transfer)
return c.getUploadFileData(sizeToRead, &transfer)
}
func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error {
var err error
updateConnectionActivity(c.connection.ID)
if !c.connection.User.HasPerm(dataprovider.PermUpload) {
err := fmt.Errorf("Permission denied")
logger.Warn(logSenderSCP, "error uploading file: %v, permission denied", uploadFilePath)
c.sendErrorMessage(err.Error())
return err
}
p, err := c.connection.buildPath(uploadFilePath)
if err != nil {
logger.Warn(logSenderSCP, "error uploading file: %v, err: %v", uploadFilePath, err)
c.sendErrorMessage(err.Error())
return err
}
filePath := p
if uploadMode == uploadModeAtomic {
filePath = getUploadTempFilePath(p)
}
stat, statErr := os.Stat(p)
if os.IsNotExist(statErr) {
return c.handleUploadFile(p, filePath, sizeToRead)
}
if statErr != nil {
logger.Error(logSenderSCP, "error performing file stat %v: %v", p, statErr)
c.sendErrorMessage(err.Error())
return err
}
if stat.IsDir() {
logger.Warn(logSenderSCP, "attempted to open a directory for writing to: %v", p)
err = fmt.Errorf("Attempted to open a directory for writing: %v", p)
c.sendErrorMessage(err.Error())
return err
}
if uploadMode == uploadModeAtomic {
err = os.Rename(p, filePath)
if err != nil {
logger.Error(logSenderSCP, "error renaming existing file for atomic upload, source: %v, dest: %v, err: %v",
p, filePath, err)
c.sendErrorMessage(err.Error())
return err
}
}
dataprovider.UpdateUserQuota(dataProvider, c.connection.User, 0, -stat.Size(), false)
return c.handleUploadFile(p, filePath, sizeToRead)
}
func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileInfo) error {
var err error
if c.sendFileTime() {
modTime := stat.ModTime().UnixNano() / 1000000000
tCommand := fmt.Sprintf("T%v 0 %v 0\n", modTime, modTime)
err = c.sendProtocolMessage(tCommand)
if err != nil {
return err
}
err = c.readConfirmationMessage()
if err != nil {
return err
}
}
fileMode := fmt.Sprintf("D%v 0 %v\n", getFileModeAsString(stat.Mode(), stat.IsDir()), filepath.Base(dirPath))
err = c.sendProtocolMessage(fileMode)
if err != nil {
return err
}
err = c.readConfirmationMessage()
return err
}
// we send first all the files in the roor directory and then the directories
// for each directory we recursively call this method again
func (c *scpCommand) handleRecursiveDownload(dirPath string, stat os.FileInfo) error {
var err error
if c.isRecursive() {
logger.Debug(logSenderSCP, "recursive download, dir path: %v", dirPath)
err = c.sendDownloadProtocolMessages(dirPath, stat)
if err != nil {
return err
}
files, err := getDirContents(dirPath)
if err != nil {
c.sendErrorMessage(err.Error())
return err
}
var dirs []string
for _, file := range files {
filePath := c.connection.User.GetRelativePath(filepath.Join(dirPath, file.Name()))
if file.Mode().IsRegular() || file.Mode()&os.ModeSymlink == os.ModeSymlink {
err = c.handleDownload(filePath)
if err != nil {
break
}
} else if file.IsDir() {
dirs = append(dirs, filePath)
}
}
if err != nil {
c.sendErrorMessage(err.Error())
return err
}
for _, dir := range dirs {
err = c.handleDownload(dir)
if err != nil {
break
}
}
if err != nil {
c.sendErrorMessage(err.Error())
return err
}
err = c.sendProtocolMessage("E\n")
if err != nil {
return err
}
err = c.readConfirmationMessage()
if err != nil {
return err
}
return err
}
err = fmt.Errorf("Unable to send directory for non recursive copy")
c.sendErrorMessage(err.Error())
return err
}
func (c *scpCommand) sendDownloadFileData(filePath string, stat os.FileInfo, transfer *Transfer) error {
var err error
if c.sendFileTime() {
modTime := stat.ModTime().UnixNano() / 1000000000
tCommand := fmt.Sprintf("T%v 0 %v 0\n", modTime, modTime)
err = c.sendProtocolMessage(tCommand)
if err != nil {
return err
}
err = c.readConfirmationMessage()
if err != nil {
return err
}
}
fileSize := stat.Size()
readed := int64(0)
fileMode := fmt.Sprintf("C%v %v %v\n", getFileModeAsString(stat.Mode(), stat.IsDir()), fileSize, filepath.Base(filePath))
err = c.sendProtocolMessage(fileMode)
if err != nil {
return err
}
err = c.readConfirmationMessage()
if err != nil {
return err
}
buf := make([]byte, 32768)
for {
n, err := transfer.ReadAt(buf, readed)
if err == nil || err == io.EOF {
if n > 0 {
_, err = c.channel.Write(buf[:n])
}
}
readed += int64(n)
if err != nil {
break
}
}
if err != nil && err != io.EOF {
c.sendErrorMessage(err.Error())
return err
}
err = c.sendConfirmationMessage()
if err != nil {
return err
}
err = c.readConfirmationMessage()
return err
}
func (c *scpCommand) handleDownload(filePath string) error {
var err error
updateConnectionActivity(c.connection.ID)
if !c.connection.User.HasPerm(dataprovider.PermDownload) {
err := fmt.Errorf("Permission denied")
logger.Warn(logSenderSCP, "error downloading file: %v, permission denied", filePath)
c.sendErrorMessage(err.Error())
return err
}
p, err := c.connection.buildPath(filePath)
if err != nil {
err := fmt.Errorf("Invalid file path")
logger.Warn(logSenderSCP, "error downloading file: %v, invalid file path", filePath)
c.sendErrorMessage(err.Error())
return err
}
var stat os.FileInfo
if stat, err = os.Stat(p); os.IsNotExist(err) {
logger.Warn(logSenderSCP, "error downloading file: %v, err: %v", p, err)
c.sendErrorMessage(err.Error())
return err
}
if stat.IsDir() {
err = c.handleRecursiveDownload(p, stat)
return err
}
file, err := os.Open(p)
if err != nil {
logger.Error(logSenderSCP, "could not open file \"%v\" for reading: %v", p, err)
c.sendErrorMessage(err.Error())
return err
}
transfer := Transfer{
file: file,
path: p,
start: time.Now(),
bytesSent: 0,
bytesReceived: 0,
user: c.connection.User,
connectionID: c.connection.ID,
transferType: transferDownload,
lastActivity: time.Now(),
isNewFile: false,
protocol: c.connection.protocol,
}
addTransfer(&transfer)
err = c.sendDownloadFileData(p, stat, &transfer)
// we need to call Close anyway and return close error if any and
// if we have no previous error
if err == nil {
err = transfer.Close()
} else {
transfer.Close()
}
return err
}
// returns the SCP destination path.
// We ensure that the path is absolute and in SFTP (UNIX) format
func (c *scpCommand) getDestPath() string {
destPath := filepath.ToSlash(c.args[len(c.args)-1])
if !filepath.IsAbs(destPath) {
destPath = "/" + destPath
}
return destPath
}
func (c *scpCommand) getCommandType() string {
return c.args[len(c.args)-2]
}
func (c *scpCommand) sendFileTime() bool {
return utils.IsStringInSlice("-p", c.args)
}
func (c *scpCommand) isRecursive() bool {
return utils.IsStringInSlice("-r", c.args)
}
// read the SCP confirmation message and the optional text message
// the channel will be closed on errors
func (c *scpCommand) readConfirmationMessage() error {
var msg strings.Builder
buf := make([]byte, 1)
n, err := c.channel.Read(buf)
if err != nil {
c.channel.Close()
return err
}
if n == 1 && (buf[0] == warnMsg[0] || buf[0] == errMsg[0]) {
isError := buf[0] == errMsg[0]
for {
n, err = c.channel.Read(buf)
readed := buf[:n]
if err != nil || (n == 1 && readed[0] == newLine[0]) {
break
}
if n > 0 {
msg.WriteString(string(readed))
}
}
logger.Info(logSenderSCP, "scp error message received: %v is error: %v", msg.String(), isError)
err = fmt.Errorf("%v", msg.String())
c.channel.Close()
}
return err
}
// protool messages are newline terminated
func (c *scpCommand) readProtocolMessage() (string, error) {
var command strings.Builder
var err error
buf := make([]byte, 1)
for {
var n int
n, err = c.channel.Read(buf)
if err != nil {
break
}
if n > 0 {
readed := buf[:n]
if n == 1 && readed[0] == newLine[0] {
break
}
command.WriteString(string(readed))
}
}
if err != nil {
c.channel.Close()
}
return command.String(), err
}
// send an error message and close the channel
func (c *scpCommand) sendErrorMessage(error string) {
c.channel.Write(errMsg)
c.channel.Write([]byte(error))
c.channel.Write(newLine)
c.channel.Close()
}
// send scp confirmation message and close the channel if an error happen
func (c *scpCommand) sendConfirmationMessage() error {
_, err := c.channel.Write(okMsg)
if err != nil {
c.channel.Close()
}
return err
}
// sends a protocol message and close the channel on error
func (c *scpCommand) sendProtocolMessage(message string) error {
_, err := c.channel.Write([]byte(message))
if err != nil {
logger.Warn(logSenderSCP, "error sending protocol message: %v, err: %v", message, err)
c.channel.Close()
}
return err
}
// sends the SCP command exit status
func (c *scpCommand) sendExitStatus(err error) {
status := uint32(0)
if err != nil {
status = 1
}
ex := exitStatusMsg{
Status: status,
}
logger.Debug(logSenderSCP, "send exit status for command with args: %v user: %v err: %v",
c.args, c.connection.User.Username, err)
c.channel.SendRequest("exit-status", false, ssh.Marshal(&ex))
c.channel.Close()
}
// get the next upload protocol message ignoring T command if any
// we use our own user setting for permissions
func (c *scpCommand) getNextUploadProtocolMessage() (string, error) {
var command string
var err error
for {
command, err = c.readProtocolMessage()
if err != nil {
return command, err
}
if strings.HasPrefix(command, "T") {
err = c.sendConfirmationMessage()
if err != nil {
return command, err
}
} else {
break
}
}
return command, err
}
func (c *scpCommand) createDir(dirPath string) error {
var err error
if err = os.Mkdir(dirPath, 0777); err != nil {
logger.Error(logSenderSCP, "error creating dir: %v", dirPath)
c.sendErrorMessage(err.Error())
return err
}
utils.SetPathPermissions(dirPath, c.connection.User.GetUID(), c.connection.User.GetGID())
return err
}
// parse protocol messages such as:
// D0755 0 testdir
// or:
// C0644 6 testfile
// and returns file size and file/directory name
func (c *scpCommand) parseUploadMessage(command string) (int64, string, error) {
var size int64
var name string
var err error
if !strings.HasPrefix(command, "C") && !strings.HasPrefix(command, "D") {
err = fmt.Errorf("unknown or invalid upload message: %v args: %v user: %v",
command, c.args, c.connection.User.Username)
logger.Warn(logSenderSCP, "error: %v", err)
c.sendErrorMessage(err.Error())
return size, name, err
}
parts := strings.Split(command, " ")
if len(parts) == 3 {
size, err = strconv.ParseInt(parts[1], 10, 64)
if err != nil {
logger.Warn(logSenderSCP, "error getting size from upload message: %v", err)
c.sendErrorMessage(fmt.Sprintf("Error getting size: %v", err))
return size, name, err
}
name = parts[2]
if len(name) == 0 {
err = fmt.Errorf("error getting name from upload message, cannot be empty")
logger.Warn(logSenderSCP, "error: %v", err)
c.sendErrorMessage(err.Error())
return size, name, err
}
} else {
err = fmt.Errorf("Error splitting upload message: %v", command)
logger.Warn(logSenderSCP, "error: %v", err)
c.sendErrorMessage(err.Error())
return size, name, err
}
return size, name, err
}
func getFileModeAsString(fileMode os.FileMode, isDir bool) string {
var defaultMode string
if isDir {
defaultMode = "0755"
} else {
defaultMode = "0644"
}
if fileMode == 0 {
return defaultMode
}
modeString := []byte(fileMode.String())
nullPerm := []byte("-")
u := 0
g := 0
o := 0
s := 0
lastChar := len(modeString) - 1
if fileMode&os.ModeSticky != 0 {
s++
}
if fileMode&os.ModeSetuid != 0 {
s += 2
}
if fileMode&os.ModeSetgid != 0 {
s += 4
}
if modeString[lastChar-8] != nullPerm[0] {
u += 4
}
if modeString[lastChar-7] != nullPerm[0] {
u += 2
}
if modeString[lastChar-6] != nullPerm[0] {
u++
}
if modeString[lastChar-5] != nullPerm[0] {
g += 4
}
if modeString[lastChar-4] != nullPerm[0] {
g += 2
}
if modeString[lastChar-3] != nullPerm[0] {
g++
}
if modeString[lastChar-2] != nullPerm[0] {
o += 4
}
if modeString[lastChar-1] != nullPerm[0] {
o += 2
}
if modeString[lastChar] != nullPerm[0] {
o++
}
return fmt.Sprintf("%v%v%v%v", s, u, g, o)
}
func getDirContents(path string) ([]os.FileInfo, error) {
var files []os.FileInfo
f, err := os.Open(path)
if err != nil {
return files, err
}
files, err = f.Readdir(-1)
f.Close()
return files, err
}

View file

@ -15,6 +15,7 @@ import (
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
@ -52,6 +53,15 @@ type Configuration struct {
Actions Actions `json:"actions" mapstructure:"actions"`
// Keys are a list of host keys
Keys []Key `json:"keys" mapstructure:"keys"`
// IsSCPEnabled determines if experimental SCP support is enabled.
// We have our own SCP implementation since we can't rely on scp system
// command to properly handle permissions, quota and user's home dir restrictions.
// The SCP protocol is quite simple but there is no official docs about it,
// so we need more testing and feedbacks before enabling it by default.
// We may not handle some borderline cases or have sneaky bugs.
// Please do accurate tests yourself before enabling SCP and let us known
// if something does not work as expected for your use cases
IsSCPEnabled bool `json:"enable_scp" mapstructure:"enable_scp"`
}
// Key contains information about host keys
@ -152,6 +162,28 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
logger.Debug(logSender, "accepted inbound connection, ip: %v", conn.RemoteAddr().String())
var user dataprovider.User
err = json.Unmarshal([]byte(sconn.Permissions.Extensions["user"]), &user)
if err != nil {
logger.Warn(logSender, "Unable to deserialize user info, cannot serve connection: %v", err)
return
}
connectionID := hex.EncodeToString(sconn.SessionID())
connection := Connection{
ID: connectionID,
User: user,
ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: conn.RemoteAddr(),
StartTime: time.Now(),
lastActivity: time.Now(),
lock: new(sync.Mutex),
sshConn: sconn,
}
go ssh.DiscardRequests(reqs)
for newChannel := range chans {
@ -179,55 +211,54 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
case "subsystem":
if string(req.Payload[4:]) == "sftp" {
ok = true
connection.protocol = protocolSFTP
go c.handleSftpConnection(channel, connection)
}
case "exec":
if c.IsSCPEnabled {
var msg execMsg
if err := ssh.Unmarshal(req.Payload, &msg); err == nil {
name, scpArgs, err := parseCommandPayload(msg.Command)
logger.Debug(logSender, "new exec command: %v args: %v user: %v, error: %v", name, scpArgs,
connection.User.Username, err)
if err == nil && name == "scp" && len(scpArgs) >= 2 {
ok = true
connection.protocol = protocolSCP
scpCommand := scpCommand{
connection: connection,
args: scpArgs,
channel: channel,
}
go scpCommand.handle()
}
}
}
}
req.Reply(ok, nil)
}
}(requests)
var user dataprovider.User
err = json.Unmarshal([]byte(sconn.Permissions.Extensions["user"]), &user)
if err != nil {
logger.Warn(logSender, "Unable to deserialize user info, cannot serve connection: %v", err)
return
}
connectionID := hex.EncodeToString(sconn.SessionID())
// Create a new handler for the currently logged in user's server.
handler := c.createHandler(sconn, user, connectionID)
// Create the server instance for the channel using the handler we created above.
server := sftp.NewRequestServer(channel, handler)
if err := server.Serve(); err == io.EOF {
logger.Debug(logSender, "connection closed, id: %v", connectionID)
server.Close()
} else if err != nil {
logger.Error(logSender, "sftp connection closed with error id %v: %v", connectionID, err)
}
removeConnection(connectionID)
}
}
func (c Configuration) createHandler(conn *ssh.ServerConn, user dataprovider.User, connectionID string) sftp.Handlers {
func (c Configuration) handleSftpConnection(channel io.ReadWriteCloser, connection Connection) {
addConnection(connection.ID, connection)
// Create a new handler for the currently logged in user's server.
handler := c.createHandler(connection)
connection := Connection{
ID: connectionID,
User: user,
ClientVersion: string(conn.ClientVersion()),
RemoteAddr: conn.RemoteAddr(),
StartTime: time.Now(),
lastActivity: time.Now(),
lock: new(sync.Mutex),
sshConn: conn,
// Create the server instance for the channel using the handler we created above.
server := sftp.NewRequestServer(channel, handler)
if err := server.Serve(); err == io.EOF {
logger.Debug(logSender, "connection closed, id: %v", connection.ID)
server.Close()
} else if err != nil {
logger.Error(logSender, "sftp connection closed with error id %v: %v", connection.ID, err)
}
addConnection(connectionID, connection)
removeConnection(connection.ID)
}
func (c Configuration) createHandler(connection Connection) sftp.Handlers {
return sftp.Handlers{
FileGet: connection,
@ -331,3 +362,11 @@ func (c Configuration) generatePrivateKey(file string) error {
return nil
}
func parseCommandPayload(command string) (string, []string, error) {
parts := strings.Split(command, " ")
if len(parts) < 2 {
return parts[0], []string{}, nil
}
return parts[0], parts[1:], nil
}

View file

@ -19,18 +19,21 @@ import (
)
const (
logSender = "sftpd"
sftpUploadLogSender = "SFTPUpload"
sftpdDownloadLogSender = "SFTPDownload"
sftpdRenameLogSender = "SFTPRename"
sftpdRmdirLogSender = "SFTPRmdir"
sftpdMkdirLogSender = "SFTPMkdir"
sftpdSymlinkLogSender = "SFTPSymlink"
sftpdRemoveLogSender = "SFTPRemove"
operationDownload = "download"
operationUpload = "upload"
operationDelete = "delete"
operationRename = "rename"
logSender = "sftpd"
logSenderSCP = "scp"
uploadLogSender = "Upload"
downloadLogSender = "Download"
renameLogSender = "Rename"
rmdirLogSender = "Rmdir"
mkdirLogSender = "Mkdir"
symlinkLogSender = "Symlink"
removeLogSender = "Remove"
operationDownload = "download"
operationUpload = "upload"
operationDelete = "delete"
operationRename = "rename"
protocolSFTP = "SFTP"
protocolSCP = "SCP"
)
var (
@ -86,6 +89,8 @@ type ConnectionStatus struct {
ConnectionTime int64 `json:"connection_time"`
// Last activity as unix timestamp in milliseconds
LastActivity int64 `json:"last_activity"`
// Protocol for this connection: SFTP or SCP
Protocol string `json:"protocol"`
// active uploads/downloads
Transfers []connectionTransfer `json:"active_transfers"`
}
@ -190,6 +195,7 @@ func GetConnectionsStats() []ConnectionStatus {
RemoteAddress: c.RemoteAddr.String(),
ConnectionTime: utils.GetTimeAsMsSinceEpoch(c.StartTime),
LastActivity: utils.GetTimeAsMsSinceEpoch(c.lastActivity),
Protocol: c.protocol,
Transfers: []connectionTransfer{},
}
for _, t := range activeTransfers {
@ -250,9 +256,7 @@ func CheckIdleConnections() {
if idleTime > idleTimeout {
logger.Debug(logSender, "close idle connection id: %v idle time: %v", c.ID, idleTime)
err := c.sshConn.Close()
if err != nil {
logger.Warn(logSender, "error closing idle connection: %v", err)
}
logger.Debug(logSender, "idle connection closed id: %v, err: %v", c.ID, err)
}
}
logger.Debug(logSender, "check idle connections ended")

View file

@ -8,6 +8,8 @@ import (
"net"
"net/http"
"os"
"os/exec"
"path"
"path/filepath"
"runtime"
"testing"
@ -77,8 +79,11 @@ iixITGvaNZh/tjAAAACW5pY29sYUBwMQE=
)
var (
allPerms = []string{dataprovider.PermAny}
homeBasePath string
allPerms = []string{dataprovider.PermAny}
homeBasePath string
scpPath string
pubKeyPath string
privateKeyPath string
)
func TestMain(m *testing.M) {
@ -97,6 +102,8 @@ func TestMain(m *testing.M) {
httpdConf := config.GetHTTPDConfig()
router := api.GetHTTPRouter()
sftpdConf.BindPort = 2022
// we need to test SCP support
sftpdConf.IsSCPEnabled = true
// we run the test cases with UploadMode atomic. The non atomic code path
// simply does not execute some code so if it works in atomic mode will
// work in non atomic mode too
@ -109,10 +116,27 @@ func TestMain(m *testing.M) {
sftpdConf.Actions.Command = "/usr/bin/true"
sftpdConf.Actions.HTTPNotificationURL = "http://127.0.0.1:8080/"
}
pubKeyPath = filepath.Join(homeBasePath, "ssh_key.pub")
privateKeyPath = filepath.Join(homeBasePath, "ssh_key")
err = ioutil.WriteFile(pubKeyPath, []byte(testPubKey+"\n"), 0600)
if err != nil {
logger.WarnToConsole("unable to save public key to file: %v", err)
}
err = ioutil.WriteFile(privateKeyPath, []byte(testPrivateKey+"\n"), 0600)
if err != nil {
logger.WarnToConsole("unable to save private key to file: %v", err)
}
sftpd.SetDataProvider(dataProvider)
api.SetDataProvider(dataProvider)
scpPath, err = exec.LookPath("scp")
if err != nil {
logger.Warn(logSender, "unable to get scp command. SCP tests will be skipped, err: %v", err)
logger.WarnToConsole("unable to get scp command. SCP tests will be skipped, err: %v", err)
scpPath = ""
}
go func() {
logger.Debug(logSender, "initializing SFTP server with config %+v", sftpdConf)
if err := sftpdConf.Initialize(configDir); err != nil {
@ -1399,6 +1423,503 @@ func TestSSHConnection(t *testing.T) {
}
}
// Start SCP tests
func TestSCPBasicHandling(t *testing.T) {
if len(scpPath) == 0 {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
u := getTestUser(usePubKey)
u.QuotaSize = 6553600
user, _, err := api.AddUser(u, http.StatusOK)
if err != nil {
t.Errorf("unable to add user: %v", err)
}
testFileName := "test_file.dat"
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(131074)
expectedQuotaSize := user.UsedQuotaSize + testFileSize
expectedQuotaFiles := user.UsedQuotaFiles + 1
err = createTestFile(testFilePath, testFileSize)
if err != nil {
t.Errorf("unable to create test file: %v", err)
}
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
localPath := filepath.Join(homeBasePath, "scp_download.dat")
// test to download a missing file
err = scpDownload(localPath, remoteDownPath, false, false)
if err == nil {
t.Errorf("downloading a missing file via scp must fail")
}
err = scpUpload(testFilePath, remoteUpPath, false)
if err != nil {
t.Errorf("error uploading file via scp: %v", err)
}
err = scpDownload(localPath, remoteDownPath, false, false)
if err != nil {
t.Errorf("error downloading file via scp: %v", err)
}
fi, err := os.Stat(localPath)
if err != nil {
t.Errorf("stat for the downloaded file must succeed")
} else {
if fi.Size() != testFileSize {
t.Errorf("size of the file downloaded via SCP does not match the expected one")
}
}
os.Remove(localPath)
user, _, err = api.GetUserByID(user.ID, http.StatusOK)
if err != nil {
t.Errorf("error getting user: %v", err)
}
if expectedQuotaFiles != user.UsedQuotaFiles {
t.Errorf("quota files does not match, expected: %v, actual: %v", expectedQuotaFiles, user.UsedQuotaFiles)
}
if expectedQuotaSize != user.UsedQuotaSize {
t.Errorf("quota size does not match, expected: %v, actual: %v", expectedQuotaSize, user.UsedQuotaSize)
}
err = os.RemoveAll(user.GetHomeDir())
if err != nil {
t.Errorf("error removing uploaded files")
}
_, err = api.RemoveUser(user, http.StatusOK)
if err != nil {
t.Errorf("unable to remove user: %v", err)
}
}
func TestSCPUploadFileOverwrite(t *testing.T) {
if len(scpPath) == 0 {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
u := getTestUser(usePubKey)
user, _, err := api.AddUser(u, http.StatusOK)
if err != nil {
t.Errorf("unable to add user: %v", err)
}
testFileName := "test_file.dat"
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(32760)
err = createTestFile(testFilePath, testFileSize)
if err != nil {
t.Errorf("unable to create test file: %v", err)
}
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, filepath.Join("/", testFileName))
err = scpUpload(testFilePath, remoteUpPath, true)
if err != nil {
t.Errorf("error uploading file via scp: %v", err)
}
// test a new upload that must overwrite the existing file
err = scpUpload(testFilePath, remoteUpPath, true)
if err != nil {
t.Errorf("error uploading existing file via scp: %v", err)
}
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
localPath := filepath.Join(homeBasePath, "scp_download.dat")
err = scpDownload(localPath, remoteDownPath, false, false)
if err != nil {
t.Errorf("error downloading file via scp: %v", err)
}
fi, err := os.Stat(localPath)
if err != nil {
t.Errorf("stat for the downloaded file must succeed")
} else {
if fi.Size() != testFileSize {
t.Errorf("size of the file downloaded via SCP does not match the expected one")
}
}
os.Remove(localPath)
err = os.RemoveAll(user.GetHomeDir())
if err != nil {
t.Errorf("error removing uploaded files")
}
_, err = api.RemoveUser(user, http.StatusOK)
if err != nil {
t.Errorf("unable to remove user: %v", err)
}
}
func TestSCPRecursive(t *testing.T) {
if len(scpPath) == 0 {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
u := getTestUser(usePubKey)
user, _, err := api.AddUser(u, http.StatusOK)
if err != nil {
t.Errorf("unable to add user: %v", err)
}
testFileName := "test_file.dat"
testBaseDirName := "test_dir"
testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName)
testBaseDirDownName := "test_dir_down"
testBaseDirDownPath := filepath.Join(homeBasePath, testBaseDirDownName)
testFilePath := filepath.Join(homeBasePath, testBaseDirName, testFileName)
testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testBaseDirName, testFileName)
testFileSize := int64(131074)
createTestFile(testFilePath, testFileSize)
createTestFile(testFilePath1, testFileSize)
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testBaseDirName))
// test to download a missing dir
err = scpDownload(testBaseDirDownPath, remoteDownPath, true, true)
if err == nil {
t.Errorf("downloading a missing dir via scp must fail")
}
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
err = scpUpload(testBaseDirPath, remoteUpPath, true)
if err != nil {
t.Errorf("error uploading dir via scp: %v", err)
}
err = scpDownload(testBaseDirDownPath, remoteDownPath, true, true)
if err != nil {
t.Errorf("error downloading dir via scp: %v", err)
}
// test download without passing -r
err = scpDownload(testBaseDirDownPath, remoteDownPath, true, false)
if err == nil {
t.Errorf("recursive download without -r must fail")
}
fi, err := os.Stat(filepath.Join(testBaseDirDownPath, testFileName))
if err != nil {
t.Errorf("error downloading file using scp recursive: %v", err)
} else {
if fi.Size() != testFileSize {
t.Errorf("size for file downloaded using recursive scp does not match, actual: %v, expected: %v", fi.Size(), testFileSize)
}
}
fi, err = os.Stat(filepath.Join(testBaseDirDownPath, testBaseDirName, testFileName))
if err != nil {
t.Errorf("error downloading file using scp recursive: %v", err)
} else {
if fi.Size() != testFileSize {
t.Errorf("size for file downloaded using recursive scp does not match, actual: %v, expected: %v", fi.Size(), testFileSize)
}
}
// upload to a non existent dir
remoteUpPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/non_existent_dir")
err = scpUpload(testBaseDirPath, remoteUpPath, true)
if err == nil {
t.Errorf("uploading via scp to a non existent dir must fail")
}
os.RemoveAll(testBaseDirPath)
os.RemoveAll(testBaseDirDownPath)
err = os.RemoveAll(user.GetHomeDir())
if err != nil {
t.Errorf("error removing uploaded files")
}
_, err = api.RemoveUser(user, http.StatusOK)
if err != nil {
t.Errorf("unable to remove user: %v", err)
}
}
func TestSCPPermCreateDirs(t *testing.T) {
if len(scpPath) == 0 {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
u := getTestUser(usePubKey)
u.Permissions = []string{dataprovider.PermDownload, dataprovider.PermUpload}
user, _, err := api.AddUser(u, http.StatusOK)
if err != nil {
t.Errorf("unable to add user: %v", err)
}
testFileName := "test_file.dat"
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(32760)
testBaseDirName := "test_dir"
testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName)
testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testFileName)
err = createTestFile(testFilePath, testFileSize)
if err != nil {
t.Errorf("unable to create test file: %v", err)
}
err = createTestFile(testFilePath1, testFileSize)
if err != nil {
t.Errorf("unable to create test file: %v", err)
}
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/tmp/")
err = scpUpload(testFilePath, remoteUpPath, true)
if err == nil {
t.Errorf("scp upload must fail, the user cannot create new dirs")
}
err = scpUpload(testBaseDirPath, remoteUpPath, true)
if err == nil {
t.Errorf("scp upload must fail, the user cannot create new dirs")
}
err = os.Remove(testFilePath)
if err != nil {
t.Errorf("error removing test file")
}
os.RemoveAll(testBaseDirPath)
err = os.RemoveAll(user.GetHomeDir())
if err != nil {
t.Errorf("error removing uploaded files")
}
_, err = api.RemoveUser(user, http.StatusOK)
if err != nil {
t.Errorf("unable to remove user: %v", err)
}
}
func TestSCPPermUpload(t *testing.T) {
if len(scpPath) == 0 {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
u := getTestUser(usePubKey)
u.Permissions = []string{dataprovider.PermDownload, dataprovider.PermCreateDirs}
user, _, err := api.AddUser(u, http.StatusOK)
if err != nil {
t.Errorf("unable to add user: %v", err)
}
testFileName := "test_file.dat"
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(65536)
err = createTestFile(testFilePath, testFileSize)
if err != nil {
t.Errorf("unable to create test file: %v", err)
}
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/tmp")
err = scpUpload(testFilePath, remoteUpPath, true)
if err == nil {
t.Errorf("scp upload must fail, the user cannot upload")
}
err = os.Remove(testFilePath)
if err != nil {
t.Errorf("error removing test file")
}
err = os.RemoveAll(user.GetHomeDir())
if err != nil {
t.Errorf("error removing uploaded files")
}
_, err = api.RemoveUser(user, http.StatusOK)
if err != nil {
t.Errorf("unable to remove user: %v", err)
}
}
func TestSCPPermDownload(t *testing.T) {
if len(scpPath) == 0 {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
u := getTestUser(usePubKey)
u.Permissions = []string{dataprovider.PermUpload, dataprovider.PermCreateDirs}
user, _, err := api.AddUser(u, http.StatusOK)
if err != nil {
t.Errorf("unable to add user: %v", err)
}
testFileName := "test_file.dat"
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(65537)
err = createTestFile(testFilePath, testFileSize)
if err != nil {
t.Errorf("unable to create test file: %v", err)
}
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "tmp")
err = scpUpload(testFilePath, remoteUpPath, true)
if err != nil {
t.Errorf("error uploading existing file via scp: %v", err)
}
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/tmp", testFileName))
localPath := filepath.Join(homeBasePath, "scp_download.dat")
err = scpDownload(localPath, remoteDownPath, false, false)
if err == nil {
t.Errorf("scp download must fail, the user cannot download")
}
err = os.Remove(testFilePath)
if err != nil {
t.Errorf("error removing test file")
}
err = os.RemoveAll(user.GetHomeDir())
if err != nil {
t.Errorf("error removing uploaded files")
}
_, err = api.RemoveUser(user, http.StatusOK)
if err != nil {
t.Errorf("unable to remove user: %v", err)
}
}
func TestSCPQuotaSize(t *testing.T) {
if len(scpPath) == 0 {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
testFileSize := int64(65535)
u := getTestUser(usePubKey)
u.QuotaFiles = 1
u.QuotaSize = testFileSize - 1
user, _, err := api.AddUser(u, http.StatusOK)
if err != nil {
t.Errorf("unable to add user: %v", err)
}
testFileName := "test_file.dat"
testFilePath := filepath.Join(homeBasePath, testFileName)
err = createTestFile(testFilePath, testFileSize)
if err != nil {
t.Errorf("unable to create test file: %v", err)
}
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
err = scpUpload(testFilePath, remoteUpPath, true)
if err != nil {
t.Errorf("error uploading existing file via scp: %v", err)
}
err = scpUpload(testFilePath, remoteUpPath+".quota", true)
if err == nil {
t.Errorf("user is over quota scp upload must fail")
}
err = os.Remove(testFilePath)
if err != nil {
t.Errorf("error removing test file")
}
err = os.RemoveAll(user.GetHomeDir())
if err != nil {
t.Errorf("error removing uploaded files")
}
_, err = api.RemoveUser(user, http.StatusOK)
if err != nil {
t.Errorf("unable to remove user: %v", err)
}
}
func TestSCPEscapeHomeDir(t *testing.T) {
if len(scpPath) == 0 {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
user, _, err := api.AddUser(getTestUser(usePubKey), http.StatusOK)
if err != nil {
t.Errorf("unable to add user: %v", err)
}
os.MkdirAll(user.GetHomeDir(), 0777)
testDir := "testDir"
linkPath := filepath.Join(homeBasePath, defaultUsername, testDir)
err = os.Symlink(homeBasePath, linkPath)
if err != nil {
t.Errorf("error making local symlink: %v", err)
}
testFileName := "test_file.dat"
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(65535)
err = createTestFile(testFilePath, testFileSize)
if err != nil {
t.Errorf("unable to create test file: %v", err)
}
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDir, testDir))
err = scpUpload(testFilePath, remoteUpPath, false)
if err == nil {
t.Errorf("uploading to a dir with a symlink outside home dir must fail")
}
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testDir, testFileName))
localPath := filepath.Join(homeBasePath, "scp_download.dat")
err = scpDownload(localPath, remoteDownPath, false, false)
if err == nil {
t.Errorf("scp download must fail, the requested file has a symlink outside user home")
}
remoteDownPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testDir))
err = scpDownload(homeBasePath, remoteDownPath, false, true)
if err == nil {
t.Errorf("scp download must fail, the requested dir is a symlink outside user home")
}
err = os.Remove(testFilePath)
if err != nil {
t.Errorf("error removing test file")
}
err = os.RemoveAll(user.GetHomeDir())
if err != nil {
t.Errorf("error removing uploaded files")
}
_, err = api.RemoveUser(user, http.StatusOK)
if err != nil {
t.Errorf("unable to remove user: %v", err)
}
}
func TestSCPUploadPaths(t *testing.T) {
if len(scpPath) == 0 {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
user, _, err := api.AddUser(getTestUser(usePubKey), http.StatusOK)
if err != nil {
t.Errorf("unable to add user: %v", err)
}
testFileName := "test_file.dat"
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(65535)
testDirName := "testDir"
testDirPath := filepath.Join(user.GetHomeDir(), testDirName)
os.MkdirAll(testDirPath, 0777)
err = createTestFile(testFilePath, testFileSize)
if err != nil {
t.Errorf("unable to create test file: %v", err)
}
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, testDirName)
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDirName, testFileName))
localPath := filepath.Join(homeBasePath, "scp_download.dat")
err = scpUpload(testFilePath, remoteUpPath, false)
if err != nil {
t.Errorf("scp upload error: %v", err)
}
err = scpDownload(localPath, remoteDownPath, false, false)
if err != nil {
t.Errorf("scp download error: %v", err)
}
// upload a file to a missing dir
remoteUpPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDirName, testDirName, testFileName))
err = scpUpload(testFilePath, remoteUpPath, false)
if err == nil {
t.Errorf("scp upload to a missing dir must fail")
}
err = os.RemoveAll(user.GetHomeDir())
if err != nil {
t.Errorf("error removing uploaded files")
}
_, err = api.RemoveUser(user, http.StatusOK)
if err != nil {
t.Errorf("unable to remove user: %v", err)
}
}
func TestSCPOverwriteDirWithFile(t *testing.T) {
if len(scpPath) == 0 {
t.Skip("scp command not found, unable to execute this test")
}
usePubKey := true
user, _, err := api.AddUser(getTestUser(usePubKey), http.StatusOK)
if err != nil {
t.Errorf("unable to add user: %v", err)
}
testFileName := "test_file.dat"
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(65535)
testDirPath := filepath.Join(user.GetHomeDir(), testFileName)
os.MkdirAll(testDirPath, 0777)
err = createTestFile(testFilePath, testFileSize)
if err != nil {
t.Errorf("unable to create test file: %v", err)
}
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
err = scpUpload(testFilePath, remoteUpPath, false)
if err == nil {
t.Errorf("copying a file over an existing dir must fail")
}
err = os.RemoveAll(user.GetHomeDir())
if err != nil {
t.Errorf("error removing uploaded files")
}
_, err = api.RemoveUser(user, http.StatusOK)
if err != nil {
t.Errorf("unable to remove user: %v", err)
}
}
// End SCP tests
func waitTCPListening(address string) {
for {
conn, err := net.Dial("tcp", address)
@ -1487,6 +2008,10 @@ func getSftpClient(user dataprovider.User, usePubKey bool) (*sftp.Client, error)
}
func createTestFile(path string, size int64) error {
baseDir := filepath.Dir(path)
if _, err := os.Stat(baseDir); os.IsNotExist(err) {
os.MkdirAll(baseDir, 0777)
}
content := make([]byte, size)
_, err := rand.Read(content)
if err != nil {
@ -1572,6 +2097,49 @@ func sftpDownloadNonBlocking(remoteSourcePath string, localDestPath string, expe
return c
}
func scpUpload(localPath, remotePath string, preserveTime bool) error {
var args []string
if preserveTime {
args = append(args, "-p")
}
fi, err := os.Stat(localPath)
if err == nil {
if fi.IsDir() {
args = append(args, "-r")
}
}
args = append(args, "-P")
args = append(args, "2022")
args = append(args, "-o")
args = append(args, "StrictHostKeyChecking=no")
args = append(args, "-i")
args = append(args, privateKeyPath)
args = append(args, localPath)
args = append(args, remotePath)
cmd := exec.Command(scpPath, args...)
return cmd.Run()
}
func scpDownload(localPath, remotePath string, preserveTime, recursive bool) error {
var args []string
if preserveTime {
args = append(args, "-p")
}
if recursive {
args = append(args, "-r")
}
args = append(args, "-P")
args = append(args, "2022")
args = append(args, "-o")
args = append(args, "StrictHostKeyChecking=no")
args = append(args, "-i")
args = append(args, privateKeyPath)
args = append(args, remotePath)
args = append(args, localPath)
cmd := exec.Command(scpPath, args...)
return cmd.Run()
}
func waitForActiveTransfer() {
stats := sftpd.GetConnectionsStats()
for len(stats) < 1 {

View file

@ -31,6 +31,7 @@ type Transfer struct {
transferType int
lastActivity time.Time
isNewFile bool
protocol string
}
// ReadAt reads len(p) bytes from the File to download starting at byte offset off and updates the bytes sent.
@ -64,10 +65,10 @@ func (t *Transfer) Close() error {
}
elapsed := time.Since(t.start).Nanoseconds() / 1000000
if t.transferType == transferDownload {
logger.TransferLog(sftpdDownloadLogSender, t.path, elapsed, t.bytesSent, t.user.Username, t.connectionID)
logger.TransferLog(downloadLogSender, t.path, elapsed, t.bytesSent, t.user.Username, t.connectionID, t.protocol)
executeAction(operationDownload, t.user.Username, t.path, "")
} else {
logger.TransferLog(sftpUploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID)
logger.TransferLog(uploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID, t.protocol)
executeAction(operationUpload, t.user.Username, t.path, "")
}
removeTransfer(t)

View file

@ -12,7 +12,8 @@
"command": "",
"http_notification_url": ""
},
"keys": []
"keys": [],
"enable_scp": false
},
"data_provider": {
"driver": "sqlite",