mirror of
https://github.com/drakkan/sftpgo.git
synced 2024-11-21 23:20:24 +00:00
add SCP support
SCP is an experimental feature, we have our own SCP implementation since we can't rely on scp system command to proper handle permissions, quota and user's home dir restrictions. The SCP protocol is quite simple but there is no official docs about it, so we need more testing and feedbacks before enabling it by default. We may not handle some borderline cases or have sneaky bugs. This commit contains some breaking changes to the REST API. SFTPGo API should be stable now and I hope no more breaking changes before the first stable release.
This commit is contained in:
parent
2c05791624
commit
e50c521c33
19 changed files with 2077 additions and 128 deletions
22
README.md
22
README.md
|
@ -19,6 +19,7 @@ Full featured and highly configurable SFTP server software
|
|||
- Log files are accurate and they are saved in the easily parsable JSON format
|
||||
- Automatically terminating idle connections
|
||||
- Atomic uploads are supported
|
||||
- Optional SCP support
|
||||
|
||||
## Platforms
|
||||
|
||||
|
@ -55,7 +56,7 @@ Version info, such as git commit and build date, can be embedded setting the fol
|
|||
For example you can build using the following command:
|
||||
|
||||
```bash
|
||||
go build -i -ldflags "-s -w -X github.com/drakkan/sftpgo/utils.commit=`git describe --tags --always --dirty` -X github.com/drakkan/sftpgo/utils.date=`date --utc +%FT%TZ`" -o sftpgo
|
||||
go build -i -ldflags "-s -w -X github.com/drakkan/sftpgo/utils.commit=`git describe --tags --always --dirty` -X github.com/drakkan/sftpgo/utils.date=`date -u +%FT%TZ`" -o sftpgo
|
||||
```
|
||||
|
||||
and you will get a version that includes git commit and build date like this one:
|
||||
|
@ -129,6 +130,7 @@ The `sftpgo` configuration file contains the following sections:
|
|||
- `target_path`, added for `rename` action only
|
||||
- `keys`, struct array. It contains the daemon's private keys. If empty or missing the daemon will search or try to generate `id_rsa` in the configuration directory.
|
||||
- `private_key`, path to the private key file. It can be a path relative to the config dir or an absolute one.
|
||||
- `enable_scp`, boolean. Default disabled. Set to `true` to enable SCP support. SCP is an experimental feature, we have our own SCP implementation since we can't rely on `scp` system command to proper handle permissions, quota and user's home dir restrictions. The SCP protocol is quite simple but there is no official docs about it, so we need more testing and feedbacks before enabling it by default. We may not handle some borderline cases or have sneaky bugs. Please do accurate tests yourself before enabling SCP and let us known if something does not work as expected for your use cases.
|
||||
- **"data_provider"**, the configuration for the data provider
|
||||
- `driver`, string. Supported drivers are `sqlite`, `mysql`, `postgresql`, `bolt`
|
||||
- `name`, string. Database name. For driver `sqlite` this can be the database name relative to the config dir or the absolute path to the SQLite database.
|
||||
|
@ -164,7 +166,8 @@ Here is a full example showing the default config in json format:
|
|||
"command": "",
|
||||
"http_notification_url": ""
|
||||
},
|
||||
"keys": []
|
||||
"keys": [],
|
||||
"enable_scp": false
|
||||
},
|
||||
"data_provider": {
|
||||
"driver": "sqlite",
|
||||
|
@ -287,22 +290,24 @@ The logs can be divided into the following categories:
|
|||
- `time` string. Date/time with millisecond precision
|
||||
- `level` string
|
||||
- `message` string
|
||||
- **"transfer logs"**, SFTP transfer logs:
|
||||
- `sender` string. `SFTPUpload` or `SFTPDownload`
|
||||
- **"transfer logs"**, SFTP/SCP transfer logs:
|
||||
- `sender` string. `Upload` or `Download`
|
||||
- `time` string. Date/time with millisecond precision
|
||||
- `level` string
|
||||
- `elapsed_ms`, int64. Elapsed time, as milliseconds, for the upload/download
|
||||
- `size_bytes`, int64. Size, as bytes, of the download/upload
|
||||
- `username`, string
|
||||
- `file_path` string
|
||||
- `connection_id` string. Unique SFTP connection identifier
|
||||
- **"command logs"**, SFTP command logs:
|
||||
- `sender` string. `SFTPRename`, `SFTPRmdir`, `SFTPMkdir`, `SFTPSymlink`, `SFTPRemove`
|
||||
- `connection_id` string. Unique connection identifier
|
||||
- `protocol` string. `SFTP` or `SCP`
|
||||
- **"command logs"**, SFTP/SCP command logs:
|
||||
- `sender` string. `Rename`, `Rmdir`, `Mkdir`, `Symlink`, `Remove`
|
||||
- `level` string
|
||||
- `username`, string
|
||||
- `file_path` string
|
||||
- `target_path` string
|
||||
- `connection_id` string. Unique SFTP connection identifier
|
||||
- `connection_id` string. Unique connection identifier
|
||||
- `protocol` string. `SFTP` or `SCP`
|
||||
- **"http logs"**, REST API logs:
|
||||
- `sender` string. `httpd`
|
||||
- `level` string
|
||||
|
@ -329,6 +334,7 @@ The logs can be divided into the following categories:
|
|||
- [lib/pq](https://github.com/lib/pq)
|
||||
- [viper](https://github.com/spf13/viper)
|
||||
- [cobra](https://github.com/spf13/cobra)
|
||||
- [xid](https://github.com/rs/xid)
|
||||
|
||||
Some code was initially taken from [Pterodactyl sftp server](https://github.com/pterodactyl/sftp-server)
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
|
||||
const (
|
||||
logSender = "api"
|
||||
activeConnectionsPath = "/api/v1/sftp_connection"
|
||||
activeConnectionsPath = "/api/v1/connection"
|
||||
quotaScanPath = "/api/v1/quota_scan"
|
||||
userPath = "/api/v1/user"
|
||||
versionPath = "/api/v1/version"
|
||||
|
|
|
@ -33,7 +33,7 @@ const (
|
|||
testPubKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1"
|
||||
logSender = "APITesting"
|
||||
userPath = "/api/v1/user"
|
||||
activeConnectionsPath = "/api/v1/sftp_connection"
|
||||
activeConnectionsPath = "/api/v1/connection"
|
||||
quotaScanPath = "/api/v1/quota_scan"
|
||||
versionPath = "/api/v1/version"
|
||||
)
|
||||
|
@ -405,19 +405,19 @@ func TestGetVersion(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGetSFTPConnections(t *testing.T) {
|
||||
_, _, err := api.GetSFTPConnections(http.StatusOK)
|
||||
func TestGetConnections(t *testing.T) {
|
||||
_, _, err := api.GetConnections(http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to get sftp connections: %v", err)
|
||||
}
|
||||
_, _, err = api.GetSFTPConnections(http.StatusInternalServerError)
|
||||
_, _, err = api.GetConnections(http.StatusInternalServerError)
|
||||
if err == nil {
|
||||
t.Errorf("get sftp connections request must succeed, we requested to check a wrong status code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseActiveSFTPConnection(t *testing.T) {
|
||||
_, err := api.CloseSFTPConnection("non_existent_id", http.StatusNotFound)
|
||||
func TestCloseActiveConnection(t *testing.T) {
|
||||
_, err := api.CloseConnection("non_existent_id", http.StatusNotFound)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error closing non existent sftp connection: %v", err)
|
||||
}
|
||||
|
@ -686,7 +686,7 @@ func TestGetVersionMock(t *testing.T) {
|
|||
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
func TestGetSFTPConnectionsMock(t *testing.T) {
|
||||
func TestGetConnectionsMock(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodGet, activeConnectionsPath, nil)
|
||||
rr := executeRequest(req)
|
||||
checkResponseCode(t, http.StatusOK, rr.Code)
|
||||
|
|
|
@ -207,8 +207,8 @@ func StartQuotaScan(user dataprovider.User, expectedStatusCode int) ([]byte, err
|
|||
return body, checkResponse(resp.StatusCode, expectedStatusCode)
|
||||
}
|
||||
|
||||
// GetSFTPConnections returns status and stats for active SFTP connections
|
||||
func GetSFTPConnections(expectedStatusCode int) ([]sftpd.ConnectionStatus, []byte, error) {
|
||||
// GetConnections returns status and stats for active SFTP/SCP connections
|
||||
func GetConnections(expectedStatusCode int) ([]sftpd.ConnectionStatus, []byte, error) {
|
||||
var connections []sftpd.ConnectionStatus
|
||||
var body []byte
|
||||
resp, err := getHTTPClient().Get(buildURLRelativeToBase(activeConnectionsPath))
|
||||
|
@ -225,8 +225,8 @@ func GetSFTPConnections(expectedStatusCode int) ([]sftpd.ConnectionStatus, []byt
|
|||
return connections, body, err
|
||||
}
|
||||
|
||||
// CloseSFTPConnection closes an active SFTP connection identified by connectionID
|
||||
func CloseSFTPConnection(connectionID string, expectedStatusCode int) ([]byte, error) {
|
||||
// CloseConnection closes an active connection identified by connectionID
|
||||
func CloseConnection(connectionID string, expectedStatusCode int) ([]byte, error) {
|
||||
var body []byte
|
||||
req, err := http.NewRequest(http.MethodDelete, buildURLRelativeToBase(activeConnectionsPath, connectionID), nil)
|
||||
if err != nil {
|
||||
|
|
|
@ -161,7 +161,7 @@ func TestApiCallsWithBadURL(t *testing.T) {
|
|||
if err == nil {
|
||||
t.Errorf("request with invalid URL must fail")
|
||||
}
|
||||
_, err = CloseSFTPConnection("non_existent_id", http.StatusNotFound)
|
||||
_, err = CloseConnection("non_existent_id", http.StatusNotFound)
|
||||
if err == nil {
|
||||
t.Errorf("request with invalid URL must fail")
|
||||
}
|
||||
|
@ -200,11 +200,11 @@ func TestApiCallToNotListeningServer(t *testing.T) {
|
|||
if err == nil {
|
||||
t.Errorf("request to an inactive URL must fail")
|
||||
}
|
||||
_, _, err = GetSFTPConnections(http.StatusOK)
|
||||
_, _, err = GetConnections(http.StatusOK)
|
||||
if err == nil {
|
||||
t.Errorf("request to an inactive URL must fail")
|
||||
}
|
||||
_, err = CloseSFTPConnection("non_existent_id", http.StatusNotFound)
|
||||
_, err = CloseConnection("non_existent_id", http.StatusNotFound)
|
||||
if err == nil {
|
||||
t.Errorf("request to an inactive URL must fail")
|
||||
}
|
||||
|
@ -215,13 +215,13 @@ func TestApiCallToNotListeningServer(t *testing.T) {
|
|||
SetBaseURL(oldBaseURL)
|
||||
}
|
||||
|
||||
func TestCloseSFTPConnectionHandler(t *testing.T) {
|
||||
func TestCloseConnectionHandler(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID", nil)
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("connectionID", "")
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
rr := httptest.NewRecorder()
|
||||
handleCloseSFTPConnection(rr, req)
|
||||
handleCloseConnection(rr, req)
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected response code 400. Got %d", rr.Code)
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ func initializeRouter() {
|
|||
})
|
||||
|
||||
router.Delete(activeConnectionsPath+"/{connectionID}", func(w http.ResponseWriter, r *http.Request) {
|
||||
handleCloseSFTPConnection(w, r)
|
||||
handleCloseConnection(w, r)
|
||||
})
|
||||
|
||||
router.Get(quotaScanPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -72,7 +72,7 @@ func initializeRouter() {
|
|||
})
|
||||
}
|
||||
|
||||
func handleCloseSFTPConnection(w http.ResponseWriter, r *http.Request) {
|
||||
func handleCloseConnection(w http.ResponseWriter, r *http.Request) {
|
||||
connectionID := chi.URLParam(r, "connectionID")
|
||||
if connectionID == "" {
|
||||
sendAPIResponse(w, r, nil, "connectionID is mandatory", http.StatusBadRequest)
|
||||
|
|
|
@ -22,12 +22,12 @@ paths:
|
|||
type: array
|
||||
items:
|
||||
$ref : '#/components/schemas/VersionInfo'
|
||||
/sftp_connection:
|
||||
/connection:
|
||||
get:
|
||||
tags:
|
||||
- connections
|
||||
summary: Get the active sftp users and info about their uploads/downloads
|
||||
operationId: get_sftp_connections
|
||||
summary: Get the active users and info about their uploads/downloads
|
||||
operationId: get_connections
|
||||
responses:
|
||||
200:
|
||||
description: successful operation
|
||||
|
@ -37,12 +37,12 @@ paths:
|
|||
type: array
|
||||
items:
|
||||
$ref : '#/components/schemas/ConnectionStatus'
|
||||
/sftp_connection/{connectionID}:
|
||||
/connection/{connectionID}:
|
||||
delete:
|
||||
tags:
|
||||
- connections
|
||||
summary: Terminate an active SFTP connection
|
||||
operationId: close_sftp_connection
|
||||
summary: Terminate an active connection
|
||||
operationId: close_connection
|
||||
parameters:
|
||||
- name: connectionID
|
||||
in: path
|
||||
|
@ -183,7 +183,7 @@ paths:
|
|||
get:
|
||||
tags:
|
||||
- users
|
||||
summary: Returns an array with one or more SFTP users
|
||||
summary: Returns an array with one or more users
|
||||
description: For security reasons password and public key are empty in the response
|
||||
operationId: get_users
|
||||
parameters:
|
||||
|
@ -261,7 +261,7 @@ paths:
|
|||
post:
|
||||
tags:
|
||||
- users
|
||||
summary: Adds a new SFTP user
|
||||
summary: Adds a new SFTP/SCP user
|
||||
operationId: add_user
|
||||
requestBody:
|
||||
required: true
|
||||
|
@ -562,15 +562,15 @@ components:
|
|||
max_sessions:
|
||||
type: integer
|
||||
format: int32
|
||||
description: limit the sessions that an sftp user can open. 0 means unlimited
|
||||
description: limit the sessions that an user can open. 0 means unlimited
|
||||
quota_size:
|
||||
type: integer
|
||||
format: int64
|
||||
description: quota as size. 0 menas unlimited. Please note that quota is updated if files are added/removed via sftp otherwise a quota scan is needed
|
||||
description: quota as size. 0 menas unlimited. Please note that quota is updated if files are added/removed via SFTP/SCP otherwise a quota scan is needed
|
||||
quota_files:
|
||||
type: integer
|
||||
format: int32
|
||||
description: quota as number of files. 0 menas unlimited. Please note that quota is updated if files are added/removed via sftp otherwise a quota scan is needed
|
||||
description: quota as number of files. 0 menas unlimited. Please note that quota is updated if files are added/removed via SFTP/SCP otherwise a quota scan is needed
|
||||
permissions:
|
||||
type: array
|
||||
items:
|
||||
|
@ -594,7 +594,7 @@ components:
|
|||
type: integer
|
||||
format: int32
|
||||
description: Maximum download bandwidth as KB/s, 0 means unlimited
|
||||
SFTPTransfer:
|
||||
Transfer:
|
||||
type: object
|
||||
properties:
|
||||
operation_type:
|
||||
|
@ -604,7 +604,7 @@ components:
|
|||
- download
|
||||
path:
|
||||
type: string
|
||||
description: SFTP file path for the upload/download
|
||||
description: SFTP/SCP file path for the upload/download
|
||||
start_time:
|
||||
type: integer
|
||||
format: int64
|
||||
|
@ -625,13 +625,13 @@ components:
|
|||
description: connected username
|
||||
connection_id:
|
||||
type: string
|
||||
description: unique sftp connection identifier
|
||||
description: unique connection identifier
|
||||
client_version:
|
||||
type: string
|
||||
description: SFTP client version
|
||||
description: SFTP/SCP client version
|
||||
remote_address:
|
||||
type: string
|
||||
description: Remote address for the connected SFTP client
|
||||
description: Remote address for the connected SFTP/SCP client
|
||||
connection_time:
|
||||
type: integer
|
||||
format: int64
|
||||
|
@ -640,10 +640,15 @@ components:
|
|||
type: integer
|
||||
format: int64
|
||||
description: last client activity as unix timestamp in milliseconds
|
||||
protocol:
|
||||
type: string
|
||||
enum:
|
||||
- SFTP
|
||||
- SCP
|
||||
active_transfers:
|
||||
type: array
|
||||
items:
|
||||
$ref : '#/components/schemas/SFTPTransfer'
|
||||
$ref : '#/components/schemas/Transfer'
|
||||
QuotaScan:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
@ -53,7 +53,8 @@ func init() {
|
|||
Command: "",
|
||||
HTTPNotificationURL: "",
|
||||
},
|
||||
Keys: []sftpd.Key{},
|
||||
Keys: []sftpd.Key{},
|
||||
IsSCPEnabled: false,
|
||||
},
|
||||
ProviderConf: dataprovider.Config{
|
||||
Driver: "sqlite",
|
||||
|
|
|
@ -90,8 +90,8 @@ func ErrorToConsole(format string, v ...interface{}) {
|
|||
consoleLogger.Error().Msg(fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
// TransferLog logs an SFTP upload or download
|
||||
func TransferLog(operation string, path string, elapsed int64, size int64, user string, connectionID string) {
|
||||
// TransferLog logs an SFTP/SCP upload or download
|
||||
func TransferLog(operation string, path string, elapsed int64, size int64, user string, connectionID string, protocol string) {
|
||||
logger.Info().
|
||||
Str("sender", operation).
|
||||
Int64("elapsed_ms", elapsed).
|
||||
|
@ -99,16 +99,18 @@ func TransferLog(operation string, path string, elapsed int64, size int64, user
|
|||
Str("username", user).
|
||||
Str("file_path", path).
|
||||
Str("connection_id", connectionID).
|
||||
Str("protocol", protocol).
|
||||
Msg("")
|
||||
}
|
||||
|
||||
// CommandLog logs an SFTP command
|
||||
func CommandLog(command string, path string, target string, user string, connectionID string) {
|
||||
// CommandLog logs an SFTP/SCP command
|
||||
func CommandLog(command string, path string, target string, user string, connectionID string, protocol string) {
|
||||
logger.Info().
|
||||
Str("sender", command).
|
||||
Str("username", user).
|
||||
Str("file_path", path).
|
||||
Str("target_path", target).
|
||||
Str("connection_id", connectionID).
|
||||
Str("protocol", protocol).
|
||||
Msg("")
|
||||
}
|
||||
|
|
|
@ -154,12 +154,12 @@ Output:
|
|||
]
|
||||
```
|
||||
|
||||
### Get SFTP connections
|
||||
### Get active connections
|
||||
|
||||
Command:
|
||||
|
||||
```
|
||||
python sftpgo_api_cli.py get-sftp-connections
|
||||
python sftpgo_api_cli.py get-connections
|
||||
```
|
||||
|
||||
Output:
|
||||
|
@ -173,9 +173,11 @@ Output:
|
|||
"remote_address": "127.0.0.1:41622",
|
||||
"connection_time": 1564696137971,
|
||||
"last_activity": 1564696159605,
|
||||
"protocol": "SFTP",
|
||||
"active_transfers": [
|
||||
{
|
||||
"operation_type": "upload",
|
||||
"path": "/test_upload.gz",
|
||||
"start_time": 1564696149783,
|
||||
"size": 1146880,
|
||||
"last_activity": 1564696159605
|
||||
|
@ -185,12 +187,12 @@ Output:
|
|||
]
|
||||
```
|
||||
|
||||
### Close SFTP connection
|
||||
### Close connection
|
||||
|
||||
Command:
|
||||
|
||||
```
|
||||
python sftpgo_api_cli.py close-sftp-connection 76a11b22260ee4249328df28bef34dc64c70f7c097db52159fc24049eeb0e32c
|
||||
python sftpgo_api_cli.py close-connection 76a11b22260ee4249328df28bef34dc64c70f7c097db52159fc24049eeb0e32c
|
||||
```
|
||||
|
||||
Output:
|
||||
|
|
|
@ -22,7 +22,7 @@ class SFTPGoApiRequests:
|
|||
def __init__(self, debug, baseUrl, authType, authUser, authPassword, secure, no_color):
|
||||
self.userPath = urlparse.urljoin(baseUrl, '/api/v1/user')
|
||||
self.quotaScanPath = urlparse.urljoin(baseUrl, '/api/v1/quota_scan')
|
||||
self.activeConnectionsPath = urlparse.urljoin(baseUrl, '/api/v1/sftp_connection')
|
||||
self.activeConnectionsPath = urlparse.urljoin(baseUrl, '/api/v1/connection')
|
||||
self.versionPath = urlparse.urljoin(baseUrl, '/api/v1/version')
|
||||
self.debug = debug
|
||||
if authType == 'basic':
|
||||
|
@ -101,12 +101,12 @@ class SFTPGoApiRequests:
|
|||
r = requests.delete(urlparse.urljoin(self.userPath, "user/" + str(user_id)), auth=self.auth, verify=self.verify)
|
||||
self.printResponse(r)
|
||||
|
||||
def getSFTPConnections(self):
|
||||
def getConnections(self):
|
||||
r = requests.get(self.activeConnectionsPath, auth=self.auth, verify=self.verify)
|
||||
self.printResponse(r)
|
||||
|
||||
def closeSFTPConnection(self, connectionID):
|
||||
r = requests.delete(urlparse.urljoin(self.userPath, "sftp_connection/" + str(connectionID)), auth=self.auth)
|
||||
def closeConnection(self, connectionID):
|
||||
r = requests.delete(urlparse.urljoin(self.activeConnectionsPath, "connection/" + str(connectionID)), auth=self.auth)
|
||||
self.printResponse(r)
|
||||
|
||||
def getQuotaScans(self):
|
||||
|
@ -187,11 +187,11 @@ if __name__ == '__main__':
|
|||
parserGetUserByID = subparsers.add_parser('get-user-by-id', help='Find user by ID')
|
||||
parserGetUserByID.add_argument('id', type=int)
|
||||
|
||||
parserGetSFTPConnections = subparsers.add_parser('get-sftp-connections',
|
||||
help='Get the active sftp users and info about their uploads/downloads')
|
||||
parserGetConnections = subparsers.add_parser('get-connections',
|
||||
help='Get the active users and info about their uploads/downloads')
|
||||
|
||||
parserCloseSFTPConnection = subparsers.add_parser('close-sftp-connection', help='Terminate an active SFTP connection')
|
||||
parserCloseSFTPConnection.add_argument('connectionID', type=str)
|
||||
parserCloseConnection = subparsers.add_parser('close-connection', help='Terminate an active SFTP/SCP connection')
|
||||
parserCloseConnection.add_argument('connectionID', type=str)
|
||||
|
||||
parserGetQuotaScans = subparsers.add_parser('get-quota-scans', help='Get the active quota scans')
|
||||
|
||||
|
@ -219,10 +219,10 @@ if __name__ == '__main__':
|
|||
api.getUsers(args.limit, args.offset, args.order, args.username)
|
||||
elif args.command == 'get-user-by-id':
|
||||
api.getUserByID(args.id)
|
||||
elif args.command == 'get-sftp-connections':
|
||||
api.getSFTPConnections()
|
||||
elif args.command == 'close-sftp-connection':
|
||||
api.closeSFTPConnection(args.connectionID)
|
||||
elif args.command == 'get-connections':
|
||||
api.getConnections()
|
||||
elif args.command == 'close-connection':
|
||||
api.closeConnection(args.connectionID)
|
||||
elif args.command == 'get-quota-scans':
|
||||
api.getQuotaScans()
|
||||
elif args.command == 'start-quota-scan':
|
||||
|
|
|
@ -35,6 +35,7 @@ type Connection struct {
|
|||
StartTime time.Time
|
||||
// last activity for this connection
|
||||
lastActivity time.Time
|
||||
protocol string
|
||||
lock *sync.Mutex
|
||||
sshConn *ssh.ServerConn
|
||||
}
|
||||
|
@ -78,6 +79,7 @@ func (c Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
|
|||
transferType: transferDownload,
|
||||
lastActivity: time.Now(),
|
||||
isNewFile: false,
|
||||
protocol: c.protocol,
|
||||
}
|
||||
addTransfer(&transfer)
|
||||
return &transfer, nil
|
||||
|
@ -254,7 +256,7 @@ func (c Connection) handleSFTPRename(sourcePath string, targetPath string) error
|
|||
logger.Error(logSender, "failed to rename file, source: %v target: %v: %v", sourcePath, targetPath, err)
|
||||
return sftp.ErrSshFxFailure
|
||||
}
|
||||
logger.CommandLog(sftpdRenameLogSender, sourcePath, targetPath, c.User.Username, c.ID)
|
||||
logger.CommandLog(renameLogSender, sourcePath, targetPath, c.User.Username, c.ID, c.protocol)
|
||||
executeAction(operationRename, c.User.Username, sourcePath, targetPath)
|
||||
return nil
|
||||
}
|
||||
|
@ -274,7 +276,7 @@ func (c Connection) handleSFTPRmdir(path string) error {
|
|||
return sftp.ErrSshFxFailure
|
||||
}
|
||||
|
||||
logger.CommandLog(sftpdRmdirLogSender, path, "", c.User.Username, c.ID)
|
||||
logger.CommandLog(rmdirLogSender, path, "", c.User.Username, c.ID, c.protocol)
|
||||
dataprovider.UpdateUserQuota(dataProvider, c.User, -numFiles, -size, false)
|
||||
for _, p := range fileList {
|
||||
executeAction(operationDelete, c.User.Username, p, "")
|
||||
|
@ -291,7 +293,7 @@ func (c Connection) handleSFTPSymlink(sourcePath string, targetPath string) erro
|
|||
return sftp.ErrSshFxFailure
|
||||
}
|
||||
|
||||
logger.CommandLog(sftpdSymlinkLogSender, sourcePath, targetPath, c.User.Username, c.ID)
|
||||
logger.CommandLog(symlinkLogSender, sourcePath, targetPath, c.User.Username, c.ID, c.protocol)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -304,7 +306,7 @@ func (c Connection) handleSFTPMkdir(path string) error {
|
|||
logger.Error(logSender, "error making missing dir for path %v: %v", path, err)
|
||||
return sftp.ErrSshFxFailure
|
||||
}
|
||||
logger.CommandLog(sftpdMkdirLogSender, path, "", c.User.Username, c.ID)
|
||||
logger.CommandLog(mkdirLogSender, path, "", c.User.Username, c.ID, c.protocol)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -326,7 +328,7 @@ func (c Connection) handleSFTPRemove(path string) error {
|
|||
return sftp.ErrSshFxFailure
|
||||
}
|
||||
|
||||
logger.CommandLog(sftpdRemoveLogSender, path, "", c.User.Username, c.ID)
|
||||
logger.CommandLog(removeLogSender, path, "", c.User.Username, c.ID, c.protocol)
|
||||
if fi.Mode()&os.ModeSymlink != os.ModeSymlink {
|
||||
dataprovider.UpdateUserQuota(dataProvider, c.User, -1, -size, false)
|
||||
}
|
||||
|
@ -372,6 +374,7 @@ func (c Connection) handleSFTPUploadToNewFile(requestPath, filePath string) (io.
|
|||
transferType: transferUpload,
|
||||
lastActivity: time.Now(),
|
||||
isNewFile: true,
|
||||
protocol: c.protocol,
|
||||
}
|
||||
addTransfer(&transfer)
|
||||
return &transfer, nil
|
||||
|
@ -426,6 +429,7 @@ func (c Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, re
|
|||
transferType: transferUpload,
|
||||
lastActivity: time.Now(),
|
||||
isNewFile: false,
|
||||
protocol: c.protocol,
|
||||
}
|
||||
addTransfer(&transfer)
|
||||
return &transfer, nil
|
||||
|
|
|
@ -1,14 +1,56 @@
|
|||
package sftpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/drakkan/sftpgo/dataprovider"
|
||||
"github.com/pkg/sftp"
|
||||
)
|
||||
|
||||
type MockChannel struct {
|
||||
Buffer *bytes.Buffer
|
||||
StdErrBuffer *bytes.Buffer
|
||||
ReadError error
|
||||
WriteError error
|
||||
}
|
||||
|
||||
func (c *MockChannel) Read(data []byte) (int, error) {
|
||||
if c.ReadError != nil {
|
||||
return 0, c.ReadError
|
||||
}
|
||||
return c.Buffer.Read(data)
|
||||
}
|
||||
|
||||
func (c *MockChannel) Write(data []byte) (int, error) {
|
||||
if c.WriteError != nil {
|
||||
return 0, c.WriteError
|
||||
}
|
||||
return c.Buffer.Write(data)
|
||||
}
|
||||
|
||||
func (c *MockChannel) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockChannel) CloseWrite() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (c *MockChannel) Stderr() io.ReadWriter {
|
||||
return c.StdErrBuffer
|
||||
}
|
||||
|
||||
func TestWrongActions(t *testing.T) {
|
||||
actionsCopy := actions
|
||||
badCommand := "/bad/command"
|
||||
|
@ -114,3 +156,519 @@ func TestWithInvalidHome(t *testing.T) {
|
|||
t.Errorf("tested path is not a home subdir")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSFTPCmdTargetPath(t *testing.T) {
|
||||
u := dataprovider.User{}
|
||||
u.HomeDir = "home_rel_path"
|
||||
u.Username = "test"
|
||||
u.Permissions = []string{"*"}
|
||||
connection := Connection{
|
||||
User: u,
|
||||
}
|
||||
_, err := connection.getSFTPCmdTargetPath("invalid_path")
|
||||
if err != sftp.ErrSshFxOpUnsupported {
|
||||
t.Errorf("getSFTPCmdTargetPath must fal with the expected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSFTPGetUsedQuota(t *testing.T) {
|
||||
u := dataprovider.User{}
|
||||
u.HomeDir = "home_rel_path"
|
||||
u.Username = "test_invalid_user"
|
||||
u.QuotaSize = 4096
|
||||
u.QuotaFiles = 1
|
||||
u.Permissions = []string{"*"}
|
||||
connection := Connection{
|
||||
User: u,
|
||||
}
|
||||
res := connection.hasSpace(false)
|
||||
if res != false {
|
||||
t.Errorf("has space must return false if the user is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPFileMode(t *testing.T) {
|
||||
mode := getFileModeAsString(0, true)
|
||||
if mode != "0755" {
|
||||
t.Errorf("invalid file mode: %v expected: 0755", mode)
|
||||
}
|
||||
mode = getFileModeAsString(0700, true)
|
||||
if mode != "0700" {
|
||||
t.Errorf("invalid file mode: %v expected: 0700", mode)
|
||||
}
|
||||
mode = getFileModeAsString(0750, true)
|
||||
if mode != "0750" {
|
||||
t.Errorf("invalid file mode: %v expected: 0750", mode)
|
||||
}
|
||||
mode = getFileModeAsString(0777, true)
|
||||
if mode != "0777" {
|
||||
t.Errorf("invalid file mode: %v expected: 0777", mode)
|
||||
}
|
||||
mode = getFileModeAsString(0640, false)
|
||||
if mode != "0640" {
|
||||
t.Errorf("invalid file mode: %v expected: 0640", mode)
|
||||
}
|
||||
mode = getFileModeAsString(0600, false)
|
||||
if mode != "0600" {
|
||||
t.Errorf("invalid file mode: %v expected: 0600", mode)
|
||||
}
|
||||
mode = getFileModeAsString(0, false)
|
||||
if mode != "0644" {
|
||||
t.Errorf("invalid file mode: %v expected: 0644", mode)
|
||||
}
|
||||
fileMode := uint32(0777)
|
||||
fileMode = fileMode | uint32(os.ModeSetgid)
|
||||
fileMode = fileMode | uint32(os.ModeSetuid)
|
||||
fileMode = fileMode | uint32(os.ModeSticky)
|
||||
mode = getFileModeAsString(os.FileMode(fileMode), false)
|
||||
if mode != "7777" {
|
||||
t.Errorf("invalid file mode: %v expected: 7777", mode)
|
||||
}
|
||||
fileMode = uint32(0644)
|
||||
fileMode = fileMode | uint32(os.ModeSetgid)
|
||||
mode = getFileModeAsString(os.FileMode(fileMode), false)
|
||||
if mode != "4644" {
|
||||
t.Errorf("invalid file mode: %v expected: 4644", mode)
|
||||
}
|
||||
fileMode = uint32(0600)
|
||||
fileMode = fileMode | uint32(os.ModeSetuid)
|
||||
mode = getFileModeAsString(os.FileMode(fileMode), false)
|
||||
if mode != "2600" {
|
||||
t.Errorf("invalid file mode: %v expected: 2600", mode)
|
||||
}
|
||||
fileMode = uint32(0044)
|
||||
fileMode = fileMode | uint32(os.ModeSticky)
|
||||
mode = getFileModeAsString(os.FileMode(fileMode), false)
|
||||
if mode != "1044" {
|
||||
t.Errorf("invalid file mode: %v expected: 1044", mode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPGetNonExistingDirContent(t *testing.T) {
|
||||
_, err := getDirContents("non_existing")
|
||||
if err == nil {
|
||||
t.Errorf("get non existing dir contents must fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPParseUploadMessage(t *testing.T) {
|
||||
connection := Connection{}
|
||||
buf := make([]byte, 65535)
|
||||
stdErrBuf := make([]byte, 65535)
|
||||
mockSSHChannel := MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: nil,
|
||||
}
|
||||
scpCommand := scpCommand{
|
||||
connection: connection,
|
||||
args: []string{"-t", "/tmp"},
|
||||
channel: &mockSSHChannel,
|
||||
}
|
||||
_, _, err := scpCommand.parseUploadMessage("invalid")
|
||||
if err == nil {
|
||||
t.Errorf("parsing invalid upload message must fail")
|
||||
}
|
||||
_, _, err = scpCommand.parseUploadMessage("D0755 0")
|
||||
if err == nil {
|
||||
t.Errorf("parsing incomplete upload message must fail")
|
||||
}
|
||||
_, _, err = scpCommand.parseUploadMessage("D0755 invalidsize testdir")
|
||||
if err == nil {
|
||||
t.Errorf("parsing upload message with invalid size must fail")
|
||||
}
|
||||
_, _, err = scpCommand.parseUploadMessage("D0755 0 ")
|
||||
if err == nil {
|
||||
t.Errorf("parsing upload message with invalid name must fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPProtocolMessages(t *testing.T) {
|
||||
connection := Connection{}
|
||||
buf := make([]byte, 65535)
|
||||
stdErrBuf := make([]byte, 65535)
|
||||
readErr := fmt.Errorf("test read error")
|
||||
writeErr := fmt.Errorf("test write error")
|
||||
mockSSHChannel := MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: readErr,
|
||||
WriteError: writeErr,
|
||||
}
|
||||
scpCommand := scpCommand{
|
||||
connection: connection,
|
||||
args: []string{"-t", "/tmp"},
|
||||
channel: &mockSSHChannel,
|
||||
}
|
||||
_, err := scpCommand.readProtocolMessage()
|
||||
if err == nil || err != readErr {
|
||||
t.Errorf("read protocol message must fail, we are sending a fake error")
|
||||
}
|
||||
err = scpCommand.sendConfirmationMessage()
|
||||
if err != writeErr {
|
||||
t.Errorf("write confirmation message must fail, we are sending a fake error")
|
||||
}
|
||||
err = scpCommand.sendProtocolMessage("E\n")
|
||||
if err != writeErr {
|
||||
t.Errorf("write confirmation message must fail, we are sending a fake error")
|
||||
}
|
||||
_, err = scpCommand.getNextUploadProtocolMessage()
|
||||
if err == nil || err != readErr {
|
||||
t.Errorf("read next upload protocol message must fail, we are sending a fake read error")
|
||||
}
|
||||
mockSSHChannel = MockChannel{
|
||||
Buffer: bytes.NewBuffer([]byte("T1183832947 0 1183833773 0\n")),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: nil,
|
||||
WriteError: writeErr,
|
||||
}
|
||||
scpCommand.channel = &mockSSHChannel
|
||||
_, err = scpCommand.getNextUploadProtocolMessage()
|
||||
if err == nil || err != writeErr {
|
||||
t.Errorf("read next upload protocol message must fail, we are sending a fake write error")
|
||||
}
|
||||
respBuffer := []byte{0x02}
|
||||
protocolErrorMsg := "protocol error msg"
|
||||
respBuffer = append(respBuffer, protocolErrorMsg...)
|
||||
respBuffer = append(respBuffer, 0x0A)
|
||||
mockSSHChannel = MockChannel{
|
||||
Buffer: bytes.NewBuffer(respBuffer),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: nil,
|
||||
WriteError: nil,
|
||||
}
|
||||
scpCommand.channel = &mockSSHChannel
|
||||
err = scpCommand.readConfirmationMessage()
|
||||
if err == nil || err.Error() != protocolErrorMsg {
|
||||
t.Errorf("read confirmation message must return the expected protocol error, actual err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPTestDownloadProtocolMessages(t *testing.T) {
|
||||
connection := Connection{}
|
||||
buf := make([]byte, 65535)
|
||||
stdErrBuf := make([]byte, 65535)
|
||||
readErr := fmt.Errorf("test read error")
|
||||
writeErr := fmt.Errorf("test write error")
|
||||
mockSSHChannel := MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: readErr,
|
||||
WriteError: writeErr,
|
||||
}
|
||||
scpCommand := scpCommand{
|
||||
connection: connection,
|
||||
args: []string{"-f", "-p", "/tmp"},
|
||||
channel: &mockSSHChannel,
|
||||
}
|
||||
path := "testDir"
|
||||
os.Mkdir(path, 0777)
|
||||
stat, _ := os.Stat(path)
|
||||
err := scpCommand.sendDownloadProtocolMessages(path, stat)
|
||||
if err != writeErr {
|
||||
t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
|
||||
}
|
||||
|
||||
mockSSHChannel = MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: readErr,
|
||||
WriteError: nil,
|
||||
}
|
||||
|
||||
err = scpCommand.sendDownloadProtocolMessages(path, stat)
|
||||
if err != readErr {
|
||||
t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
|
||||
}
|
||||
|
||||
mockSSHChannel = MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: readErr,
|
||||
WriteError: writeErr,
|
||||
}
|
||||
scpCommand.args = []string{"-f", "/tmp"}
|
||||
scpCommand.channel = &mockSSHChannel
|
||||
err = scpCommand.sendDownloadProtocolMessages(path, stat)
|
||||
if err != writeErr {
|
||||
t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
|
||||
}
|
||||
|
||||
mockSSHChannel = MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: readErr,
|
||||
WriteError: nil,
|
||||
}
|
||||
scpCommand.channel = &mockSSHChannel
|
||||
err = scpCommand.sendDownloadProtocolMessages(path, stat)
|
||||
if err != readErr {
|
||||
t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
|
||||
}
|
||||
os.Remove(path)
|
||||
}
|
||||
|
||||
func TestSCPCommandHandleErrors(t *testing.T) {
|
||||
connection := Connection{}
|
||||
buf := make([]byte, 65535)
|
||||
stdErrBuf := make([]byte, 65535)
|
||||
readErr := fmt.Errorf("test read error")
|
||||
writeErr := fmt.Errorf("test write error")
|
||||
mockSSHChannel := MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: readErr,
|
||||
WriteError: writeErr,
|
||||
}
|
||||
scpCommand := scpCommand{
|
||||
connection: connection,
|
||||
args: []string{"-f", "/tmp"},
|
||||
channel: &mockSSHChannel,
|
||||
}
|
||||
err := scpCommand.handle()
|
||||
if err == nil || err != readErr {
|
||||
t.Errorf("scp download must fail, we are sending a fake error")
|
||||
}
|
||||
scpCommand.args = []string{"-i", "/tmp"}
|
||||
err = scpCommand.handle()
|
||||
if err == nil {
|
||||
t.Errorf("invalid scp command must fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecursiveDownloadErrors(t *testing.T) {
|
||||
connection := Connection{}
|
||||
buf := make([]byte, 65535)
|
||||
stdErrBuf := make([]byte, 65535)
|
||||
readErr := fmt.Errorf("test read error")
|
||||
writeErr := fmt.Errorf("test write error")
|
||||
mockSSHChannel := MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: readErr,
|
||||
WriteError: writeErr,
|
||||
}
|
||||
scpCommand := scpCommand{
|
||||
connection: connection,
|
||||
args: []string{"-r", "-f", "/tmp"},
|
||||
channel: &mockSSHChannel,
|
||||
}
|
||||
path := "testDir"
|
||||
os.Mkdir(path, 0777)
|
||||
stat, _ := os.Stat(path)
|
||||
err := scpCommand.handleRecursiveDownload("invalid_dir", stat)
|
||||
if err != writeErr {
|
||||
t.Errorf("recursive upload download must fail with the expected error: %v", err)
|
||||
}
|
||||
mockSSHChannel = MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: nil,
|
||||
WriteError: nil,
|
||||
}
|
||||
scpCommand.channel = &mockSSHChannel
|
||||
err = scpCommand.handleRecursiveDownload("invalid_dir", stat)
|
||||
if err == nil {
|
||||
t.Errorf("recursive upload download must fail for a non existing dir")
|
||||
}
|
||||
|
||||
os.Remove(path)
|
||||
}
|
||||
|
||||
func TestRecursiveUploadErrors(t *testing.T) {
|
||||
connection := Connection{}
|
||||
buf := make([]byte, 65535)
|
||||
stdErrBuf := make([]byte, 65535)
|
||||
readErr := fmt.Errorf("test read error")
|
||||
writeErr := fmt.Errorf("test write error")
|
||||
mockSSHChannel := MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: readErr,
|
||||
WriteError: writeErr,
|
||||
}
|
||||
scpCommand := scpCommand{
|
||||
connection: connection,
|
||||
args: []string{"-r", "-t", "/tmp"},
|
||||
channel: &mockSSHChannel,
|
||||
}
|
||||
err := scpCommand.handleRecursiveUpload()
|
||||
if err == nil {
|
||||
t.Errorf("recursive upload must fail, we send a fake error message")
|
||||
}
|
||||
mockSSHChannel = MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: readErr,
|
||||
WriteError: nil,
|
||||
}
|
||||
scpCommand.channel = &mockSSHChannel
|
||||
err = scpCommand.handleRecursiveUpload()
|
||||
if err == nil {
|
||||
t.Errorf("recursive upload must fail, we send a fake error message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPCreateDirs(t *testing.T) {
|
||||
buf := make([]byte, 65535)
|
||||
stdErrBuf := make([]byte, 65535)
|
||||
u := dataprovider.User{}
|
||||
u.HomeDir = "home_rel_path"
|
||||
u.Username = "test"
|
||||
u.Permissions = []string{"*"}
|
||||
connection := Connection{
|
||||
User: u,
|
||||
}
|
||||
mockSSHChannel := MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: nil,
|
||||
WriteError: nil,
|
||||
}
|
||||
scpCommand := scpCommand{
|
||||
connection: connection,
|
||||
args: []string{"-r", "-t", "/tmp"},
|
||||
channel: &mockSSHChannel,
|
||||
}
|
||||
err := scpCommand.handleCreateDir("invalid_dir")
|
||||
if err == nil {
|
||||
t.Errorf("create invalid dir must fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPDownloadFileData(t *testing.T) {
|
||||
testfile := "testfile"
|
||||
buf := make([]byte, 65535)
|
||||
readErr := fmt.Errorf("test read error")
|
||||
writeErr := fmt.Errorf("test write error")
|
||||
stdErrBuf := make([]byte, 65535)
|
||||
connection := Connection{}
|
||||
mockSSHChannelReadErr := MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: readErr,
|
||||
WriteError: nil,
|
||||
}
|
||||
mockSSHChannelWriteErr := MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: nil,
|
||||
WriteError: writeErr,
|
||||
}
|
||||
scpCommand := scpCommand{
|
||||
connection: connection,
|
||||
args: []string{"-r", "-f", "/tmp"},
|
||||
channel: &mockSSHChannelReadErr,
|
||||
}
|
||||
ioutil.WriteFile(testfile, []byte("test"), 0666)
|
||||
stat, _ := os.Stat(testfile)
|
||||
err := scpCommand.sendDownloadFileData(testfile, stat, nil)
|
||||
if err != readErr {
|
||||
t.Errorf("send download file data must fail with the expected error: %v", err)
|
||||
}
|
||||
scpCommand.channel = &mockSSHChannelWriteErr
|
||||
err = scpCommand.sendDownloadFileData(testfile, stat, nil)
|
||||
if err != writeErr {
|
||||
t.Errorf("send download file data must fail with the expected error: %v", err)
|
||||
}
|
||||
scpCommand.args = []string{"-r", "-p", "-f", "/tmp"}
|
||||
err = scpCommand.sendDownloadFileData(testfile, stat, nil)
|
||||
if err != writeErr {
|
||||
t.Errorf("send download file data must fail with the expected error: %v", err)
|
||||
}
|
||||
scpCommand.channel = &mockSSHChannelReadErr
|
||||
err = scpCommand.sendDownloadFileData(testfile, stat, nil)
|
||||
if err != readErr {
|
||||
t.Errorf("send download file data must fail with the expected error: %v", err)
|
||||
}
|
||||
os.Remove(testfile)
|
||||
}
|
||||
|
||||
func TestSCPUploadFiledata(t *testing.T) {
|
||||
testfile := "testfile"
|
||||
connection := Connection{
|
||||
User: dataprovider.User{
|
||||
Username: "testuser",
|
||||
},
|
||||
protocol: protocolSCP,
|
||||
}
|
||||
buf := make([]byte, 65535)
|
||||
stdErrBuf := make([]byte, 65535)
|
||||
readErr := fmt.Errorf("test read error")
|
||||
writeErr := fmt.Errorf("test write error")
|
||||
mockSSHChannel := MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: readErr,
|
||||
WriteError: writeErr,
|
||||
}
|
||||
scpCommand := scpCommand{
|
||||
connection: connection,
|
||||
args: []string{"-r", "-t", "/tmp"},
|
||||
channel: &mockSSHChannel,
|
||||
}
|
||||
file, _ := os.Create(testfile)
|
||||
transfer := Transfer{
|
||||
file: file,
|
||||
path: file.Name(),
|
||||
start: time.Now(),
|
||||
bytesSent: 0,
|
||||
bytesReceived: 0,
|
||||
user: scpCommand.connection.User,
|
||||
connectionID: "",
|
||||
transferType: transferDownload,
|
||||
lastActivity: time.Now(),
|
||||
isNewFile: true,
|
||||
protocol: connection.protocol,
|
||||
}
|
||||
addTransfer(&transfer)
|
||||
err := scpCommand.getUploadFileData(2, &transfer)
|
||||
if err == nil {
|
||||
t.Errorf("upload must fail, we send a fake write error message")
|
||||
}
|
||||
mockSSHChannel = MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: readErr,
|
||||
WriteError: nil,
|
||||
}
|
||||
scpCommand.channel = &mockSSHChannel
|
||||
file, _ = os.Create(testfile)
|
||||
transfer.file = file
|
||||
addTransfer(&transfer)
|
||||
err = scpCommand.getUploadFileData(2, &transfer)
|
||||
if err == nil {
|
||||
t.Errorf("upload must fail, we send a fake read error message")
|
||||
}
|
||||
|
||||
respBuffer := []byte("12")
|
||||
respBuffer = append(respBuffer, 0x02)
|
||||
mockSSHChannel = MockChannel{
|
||||
Buffer: bytes.NewBuffer(respBuffer),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: nil,
|
||||
WriteError: nil,
|
||||
}
|
||||
scpCommand.channel = &mockSSHChannel
|
||||
file, _ = os.Create(testfile)
|
||||
transfer.file = file
|
||||
addTransfer(&transfer)
|
||||
err = scpCommand.getUploadFileData(2, &transfer)
|
||||
if err == nil {
|
||||
t.Errorf("upload must fail, we have not enough data to read")
|
||||
}
|
||||
|
||||
// the file is already closed so we have an error on trasfer closing
|
||||
mockSSHChannel = MockChannel{
|
||||
Buffer: bytes.NewBuffer(buf),
|
||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||
ReadError: nil,
|
||||
WriteError: nil,
|
||||
}
|
||||
addTransfer(&transfer)
|
||||
err = scpCommand.getUploadFileData(0, &transfer)
|
||||
if err == nil {
|
||||
t.Errorf("upload must fail, the file is closed")
|
||||
}
|
||||
os.Remove(testfile)
|
||||
}
|
||||
|
|
758
sftpd/scp.go
Normal file
758
sftpd/scp.go
Normal file
|
@ -0,0 +1,758 @@
|
|||
package sftpd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/drakkan/sftpgo/dataprovider"
|
||||
"github.com/drakkan/sftpgo/logger"
|
||||
"github.com/drakkan/sftpgo/utils"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
okMsg = []byte{0x00}
|
||||
warnMsg = []byte{0x01} // must be followed by an optional message and a newline
|
||||
errMsg = []byte{0x02} // must be followed by an optional message and a newline
|
||||
newLine = []byte{0x0A}
|
||||
)
|
||||
|
||||
type execMsg struct {
|
||||
Command string
|
||||
}
|
||||
|
||||
type exitStatusMsg struct {
|
||||
Status uint32
|
||||
}
|
||||
|
||||
type scpCommand struct {
|
||||
connection Connection
|
||||
args []string
|
||||
channel ssh.Channel
|
||||
}
|
||||
|
||||
func (c *scpCommand) handle() error {
|
||||
var err error
|
||||
addConnection(c.connection.ID, c.connection)
|
||||
defer removeConnection(c.connection.ID)
|
||||
destPath := c.getDestPath()
|
||||
commandType := c.getCommandType()
|
||||
logger.Debug(logSenderSCP, "handle scp command, args: %v user: %v command type: %v, dest path: %v",
|
||||
c.args, c.connection.User.Username, commandType, destPath)
|
||||
if commandType == "-t" {
|
||||
// -t means "to", so upload
|
||||
err = c.handleRecursiveUpload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
} else if commandType == "-f" {
|
||||
// -f means "from" so download
|
||||
err = c.readConfirmationMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.handleDownload(destPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("scp command not supported, args: %v", c.args)
|
||||
}
|
||||
c.sendExitStatus(err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *scpCommand) handleRecursiveUpload() error {
|
||||
var err error
|
||||
numDirs := 0
|
||||
destPath := c.getDestPath()
|
||||
for {
|
||||
err = c.sendConfirmationMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
command, err := c.getNextUploadProtocolMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.HasPrefix(command, "E") {
|
||||
numDirs--
|
||||
logger.Debug(logSenderSCP, "received end dir command, num dirs: %v", numDirs)
|
||||
if numDirs == 0 {
|
||||
// upload is now complete send confirmation message
|
||||
err = c.sendConfirmationMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// the destination dir is now the parent directory
|
||||
destPath = filepath.Join(destPath, "..")
|
||||
}
|
||||
} else {
|
||||
sizeToRead, name, err := c.parseUploadMessage(command)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
objPath := path.Join(destPath, name)
|
||||
if strings.HasPrefix(command, "D") {
|
||||
numDirs++
|
||||
err = c.handleCreateDir(objPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
destPath = objPath
|
||||
logger.Debug(logSenderSCP, "received start dir command, num dirs: %v destPath: %v", numDirs, destPath)
|
||||
} else if strings.HasPrefix(command, "C") {
|
||||
// if the upload is not recursive and the destination path does not end with "/"
|
||||
// then this is the wanted filename ...
|
||||
if !c.isRecursive() {
|
||||
if !strings.HasSuffix(destPath, "/") {
|
||||
objPath = destPath
|
||||
// ... but if the requested path is an existing directory then put the uploaded file inside that directory
|
||||
if p, err := c.connection.buildPath(objPath); err == nil {
|
||||
if stat, err := os.Stat(p); err == nil {
|
||||
if stat.IsDir() {
|
||||
objPath = path.Join(destPath, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
err = c.handleUpload(objPath, sizeToRead)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil || numDirs == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *scpCommand) handleCreateDir(dirPath string) error {
|
||||
updateConnectionActivity(c.connection.ID)
|
||||
if !c.connection.User.HasPerm(dataprovider.PermCreateDirs) {
|
||||
err := fmt.Errorf("Permission denied")
|
||||
logger.Warn(logSenderSCP, "error creating dir: %v, permission denied", dirPath)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
p, err := c.connection.buildPath(dirPath)
|
||||
if err != nil {
|
||||
logger.Warn(logSenderSCP, "error creating dir: %v, invalid file path, err: %v", dirPath, err)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.createDir(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.CommandLog(mkdirLogSender, dirPath, "", c.connection.User.Username, c.connection.ID, c.connection.protocol)
|
||||
return nil
|
||||
}
|
||||
|
||||
// we need to close the transfer if we have an error
|
||||
func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) error {
|
||||
err := c.sendConfirmationMessage()
|
||||
if err != nil {
|
||||
transfer.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if sizeToRead > 0 {
|
||||
remaining := sizeToRead
|
||||
buf := make([]byte, int64(math.Min(32768, float64(sizeToRead))))
|
||||
for {
|
||||
n, err := c.channel.Read(buf)
|
||||
if err != nil {
|
||||
c.sendErrorMessage(err.Error())
|
||||
transfer.Close()
|
||||
return err
|
||||
}
|
||||
transfer.WriteAt(buf[:n], sizeToRead-remaining)
|
||||
remaining -= int64(n)
|
||||
if remaining <= 0 {
|
||||
break
|
||||
}
|
||||
if remaining < int64(len(buf)) {
|
||||
buf = make([]byte, remaining)
|
||||
}
|
||||
}
|
||||
}
|
||||
err = c.readConfirmationMessage()
|
||||
if err != nil {
|
||||
transfer.Close()
|
||||
return err
|
||||
}
|
||||
err = transfer.Close()
|
||||
if err != nil {
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
return c.sendConfirmationMessage()
|
||||
}
|
||||
|
||||
func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead int64) error {
|
||||
logger.Debug(logSenderSCP, "upload to new file: %v", filePath)
|
||||
if !c.connection.hasSpace(true) {
|
||||
err := fmt.Errorf("denying file write due to space limit")
|
||||
logger.Warn(logSenderSCP, "error uploading file: %v, err: %v", filePath, err)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Dir(requestPath)); os.IsNotExist(err) {
|
||||
if !c.connection.User.HasPerm(dataprovider.PermCreateDirs) {
|
||||
err := fmt.Errorf("Permission denied")
|
||||
logger.Warn(logSenderSCP, "error uploading file: %v, permission denied", requestPath)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
logger.Error(logSenderSCP, "error creating file %v: %v", requestPath, err)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
utils.SetPathPermissions(filePath, c.connection.User.GetUID(), c.connection.User.GetGID())
|
||||
|
||||
transfer := Transfer{
|
||||
file: file,
|
||||
path: requestPath,
|
||||
start: time.Now(),
|
||||
bytesSent: 0,
|
||||
bytesReceived: 0,
|
||||
user: c.connection.User,
|
||||
connectionID: c.connection.ID,
|
||||
transferType: transferUpload,
|
||||
lastActivity: time.Now(),
|
||||
isNewFile: true,
|
||||
protocol: c.connection.protocol,
|
||||
}
|
||||
addTransfer(&transfer)
|
||||
|
||||
return c.getUploadFileData(sizeToRead, &transfer)
|
||||
}
|
||||
|
||||
func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error {
|
||||
var err error
|
||||
|
||||
updateConnectionActivity(c.connection.ID)
|
||||
if !c.connection.User.HasPerm(dataprovider.PermUpload) {
|
||||
err := fmt.Errorf("Permission denied")
|
||||
logger.Warn(logSenderSCP, "error uploading file: %v, permission denied", uploadFilePath)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
p, err := c.connection.buildPath(uploadFilePath)
|
||||
if err != nil {
|
||||
logger.Warn(logSenderSCP, "error uploading file: %v, err: %v", uploadFilePath, err)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
filePath := p
|
||||
if uploadMode == uploadModeAtomic {
|
||||
filePath = getUploadTempFilePath(p)
|
||||
}
|
||||
stat, statErr := os.Stat(p)
|
||||
if os.IsNotExist(statErr) {
|
||||
return c.handleUploadFile(p, filePath, sizeToRead)
|
||||
}
|
||||
|
||||
if statErr != nil {
|
||||
logger.Error(logSenderSCP, "error performing file stat %v: %v", p, statErr)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if stat.IsDir() {
|
||||
logger.Warn(logSenderSCP, "attempted to open a directory for writing to: %v", p)
|
||||
err = fmt.Errorf("Attempted to open a directory for writing: %v", p)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if uploadMode == uploadModeAtomic {
|
||||
err = os.Rename(p, filePath)
|
||||
if err != nil {
|
||||
logger.Error(logSenderSCP, "error renaming existing file for atomic upload, source: %v, dest: %v, err: %v",
|
||||
p, filePath, err)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
dataprovider.UpdateUserQuota(dataProvider, c.connection.User, 0, -stat.Size(), false)
|
||||
|
||||
return c.handleUploadFile(p, filePath, sizeToRead)
|
||||
}
|
||||
|
||||
func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileInfo) error {
|
||||
var err error
|
||||
if c.sendFileTime() {
|
||||
modTime := stat.ModTime().UnixNano() / 1000000000
|
||||
tCommand := fmt.Sprintf("T%v 0 %v 0\n", modTime, modTime)
|
||||
err = c.sendProtocolMessage(tCommand)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.readConfirmationMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
fileMode := fmt.Sprintf("D%v 0 %v\n", getFileModeAsString(stat.Mode(), stat.IsDir()), filepath.Base(dirPath))
|
||||
err = c.sendProtocolMessage(fileMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.readConfirmationMessage()
|
||||
return err
|
||||
}
|
||||
|
||||
// we send first all the files in the roor directory and then the directories
|
||||
// for each directory we recursively call this method again
|
||||
func (c *scpCommand) handleRecursiveDownload(dirPath string, stat os.FileInfo) error {
|
||||
var err error
|
||||
if c.isRecursive() {
|
||||
logger.Debug(logSenderSCP, "recursive download, dir path: %v", dirPath)
|
||||
err = c.sendDownloadProtocolMessages(dirPath, stat)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
files, err := getDirContents(dirPath)
|
||||
if err != nil {
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
var dirs []string
|
||||
for _, file := range files {
|
||||
filePath := c.connection.User.GetRelativePath(filepath.Join(dirPath, file.Name()))
|
||||
if file.Mode().IsRegular() || file.Mode()&os.ModeSymlink == os.ModeSymlink {
|
||||
err = c.handleDownload(filePath)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
} else if file.IsDir() {
|
||||
dirs = append(dirs, filePath)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
for _, dir := range dirs {
|
||||
err = c.handleDownload(dir)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
err = c.sendProtocolMessage("E\n")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.readConfirmationMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
err = fmt.Errorf("Unable to send directory for non recursive copy")
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *scpCommand) sendDownloadFileData(filePath string, stat os.FileInfo, transfer *Transfer) error {
|
||||
var err error
|
||||
if c.sendFileTime() {
|
||||
modTime := stat.ModTime().UnixNano() / 1000000000
|
||||
tCommand := fmt.Sprintf("T%v 0 %v 0\n", modTime, modTime)
|
||||
err = c.sendProtocolMessage(tCommand)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.readConfirmationMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
fileSize := stat.Size()
|
||||
readed := int64(0)
|
||||
fileMode := fmt.Sprintf("C%v %v %v\n", getFileModeAsString(stat.Mode(), stat.IsDir()), fileSize, filepath.Base(filePath))
|
||||
err = c.sendProtocolMessage(fileMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.readConfirmationMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]byte, 32768)
|
||||
for {
|
||||
n, err := transfer.ReadAt(buf, readed)
|
||||
if err == nil || err == io.EOF {
|
||||
if n > 0 {
|
||||
_, err = c.channel.Write(buf[:n])
|
||||
}
|
||||
}
|
||||
readed += int64(n)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil && err != io.EOF {
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
err = c.sendConfirmationMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.readConfirmationMessage()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *scpCommand) handleDownload(filePath string) error {
|
||||
var err error
|
||||
|
||||
updateConnectionActivity(c.connection.ID)
|
||||
|
||||
if !c.connection.User.HasPerm(dataprovider.PermDownload) {
|
||||
err := fmt.Errorf("Permission denied")
|
||||
logger.Warn(logSenderSCP, "error downloading file: %v, permission denied", filePath)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
p, err := c.connection.buildPath(filePath)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("Invalid file path")
|
||||
logger.Warn(logSenderSCP, "error downloading file: %v, invalid file path", filePath)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
var stat os.FileInfo
|
||||
if stat, err = os.Stat(p); os.IsNotExist(err) {
|
||||
logger.Warn(logSenderSCP, "error downloading file: %v, err: %v", p, err)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if stat.IsDir() {
|
||||
err = c.handleRecursiveDownload(p, stat)
|
||||
return err
|
||||
}
|
||||
|
||||
file, err := os.Open(p)
|
||||
if err != nil {
|
||||
logger.Error(logSenderSCP, "could not open file \"%v\" for reading: %v", p, err)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
transfer := Transfer{
|
||||
file: file,
|
||||
path: p,
|
||||
start: time.Now(),
|
||||
bytesSent: 0,
|
||||
bytesReceived: 0,
|
||||
user: c.connection.User,
|
||||
connectionID: c.connection.ID,
|
||||
transferType: transferDownload,
|
||||
lastActivity: time.Now(),
|
||||
isNewFile: false,
|
||||
protocol: c.connection.protocol,
|
||||
}
|
||||
addTransfer(&transfer)
|
||||
|
||||
err = c.sendDownloadFileData(p, stat, &transfer)
|
||||
// we need to call Close anyway and return close error if any and
|
||||
// if we have no previous error
|
||||
if err == nil {
|
||||
err = transfer.Close()
|
||||
} else {
|
||||
transfer.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// returns the SCP destination path.
|
||||
// We ensure that the path is absolute and in SFTP (UNIX) format
|
||||
func (c *scpCommand) getDestPath() string {
|
||||
destPath := filepath.ToSlash(c.args[len(c.args)-1])
|
||||
if !filepath.IsAbs(destPath) {
|
||||
destPath = "/" + destPath
|
||||
}
|
||||
return destPath
|
||||
}
|
||||
|
||||
func (c *scpCommand) getCommandType() string {
|
||||
return c.args[len(c.args)-2]
|
||||
}
|
||||
|
||||
func (c *scpCommand) sendFileTime() bool {
|
||||
return utils.IsStringInSlice("-p", c.args)
|
||||
}
|
||||
|
||||
func (c *scpCommand) isRecursive() bool {
|
||||
return utils.IsStringInSlice("-r", c.args)
|
||||
}
|
||||
|
||||
// read the SCP confirmation message and the optional text message
|
||||
// the channel will be closed on errors
|
||||
func (c *scpCommand) readConfirmationMessage() error {
|
||||
var msg strings.Builder
|
||||
buf := make([]byte, 1)
|
||||
n, err := c.channel.Read(buf)
|
||||
if err != nil {
|
||||
c.channel.Close()
|
||||
return err
|
||||
}
|
||||
if n == 1 && (buf[0] == warnMsg[0] || buf[0] == errMsg[0]) {
|
||||
isError := buf[0] == errMsg[0]
|
||||
for {
|
||||
n, err = c.channel.Read(buf)
|
||||
readed := buf[:n]
|
||||
if err != nil || (n == 1 && readed[0] == newLine[0]) {
|
||||
break
|
||||
}
|
||||
if n > 0 {
|
||||
msg.WriteString(string(readed))
|
||||
}
|
||||
}
|
||||
logger.Info(logSenderSCP, "scp error message received: %v is error: %v", msg.String(), isError)
|
||||
err = fmt.Errorf("%v", msg.String())
|
||||
c.channel.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// protool messages are newline terminated
|
||||
func (c *scpCommand) readProtocolMessage() (string, error) {
|
||||
var command strings.Builder
|
||||
var err error
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
var n int
|
||||
n, err = c.channel.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if n > 0 {
|
||||
readed := buf[:n]
|
||||
if n == 1 && readed[0] == newLine[0] {
|
||||
break
|
||||
}
|
||||
command.WriteString(string(readed))
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
c.channel.Close()
|
||||
}
|
||||
return command.String(), err
|
||||
}
|
||||
|
||||
// send an error message and close the channel
|
||||
func (c *scpCommand) sendErrorMessage(error string) {
|
||||
c.channel.Write(errMsg)
|
||||
c.channel.Write([]byte(error))
|
||||
c.channel.Write(newLine)
|
||||
c.channel.Close()
|
||||
}
|
||||
|
||||
// send scp confirmation message and close the channel if an error happen
|
||||
func (c *scpCommand) sendConfirmationMessage() error {
|
||||
_, err := c.channel.Write(okMsg)
|
||||
if err != nil {
|
||||
c.channel.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// sends a protocol message and close the channel on error
|
||||
func (c *scpCommand) sendProtocolMessage(message string) error {
|
||||
_, err := c.channel.Write([]byte(message))
|
||||
if err != nil {
|
||||
logger.Warn(logSenderSCP, "error sending protocol message: %v, err: %v", message, err)
|
||||
c.channel.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// sends the SCP command exit status
|
||||
func (c *scpCommand) sendExitStatus(err error) {
|
||||
status := uint32(0)
|
||||
if err != nil {
|
||||
status = 1
|
||||
}
|
||||
ex := exitStatusMsg{
|
||||
Status: status,
|
||||
}
|
||||
logger.Debug(logSenderSCP, "send exit status for command with args: %v user: %v err: %v",
|
||||
c.args, c.connection.User.Username, err)
|
||||
c.channel.SendRequest("exit-status", false, ssh.Marshal(&ex))
|
||||
c.channel.Close()
|
||||
}
|
||||
|
||||
// get the next upload protocol message ignoring T command if any
|
||||
// we use our own user setting for permissions
|
||||
func (c *scpCommand) getNextUploadProtocolMessage() (string, error) {
|
||||
var command string
|
||||
var err error
|
||||
for {
|
||||
command, err = c.readProtocolMessage()
|
||||
if err != nil {
|
||||
return command, err
|
||||
}
|
||||
if strings.HasPrefix(command, "T") {
|
||||
err = c.sendConfirmationMessage()
|
||||
if err != nil {
|
||||
return command, err
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return command, err
|
||||
}
|
||||
|
||||
func (c *scpCommand) createDir(dirPath string) error {
|
||||
var err error
|
||||
if err = os.Mkdir(dirPath, 0777); err != nil {
|
||||
logger.Error(logSenderSCP, "error creating dir: %v", dirPath)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return err
|
||||
}
|
||||
utils.SetPathPermissions(dirPath, c.connection.User.GetUID(), c.connection.User.GetGID())
|
||||
return err
|
||||
}
|
||||
|
||||
// parse protocol messages such as:
|
||||
// D0755 0 testdir
|
||||
// or:
|
||||
// C0644 6 testfile
|
||||
// and returns file size and file/directory name
|
||||
func (c *scpCommand) parseUploadMessage(command string) (int64, string, error) {
|
||||
var size int64
|
||||
var name string
|
||||
var err error
|
||||
if !strings.HasPrefix(command, "C") && !strings.HasPrefix(command, "D") {
|
||||
err = fmt.Errorf("unknown or invalid upload message: %v args: %v user: %v",
|
||||
command, c.args, c.connection.User.Username)
|
||||
logger.Warn(logSenderSCP, "error: %v", err)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return size, name, err
|
||||
}
|
||||
parts := strings.Split(command, " ")
|
||||
if len(parts) == 3 {
|
||||
size, err = strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil {
|
||||
logger.Warn(logSenderSCP, "error getting size from upload message: %v", err)
|
||||
c.sendErrorMessage(fmt.Sprintf("Error getting size: %v", err))
|
||||
return size, name, err
|
||||
}
|
||||
name = parts[2]
|
||||
if len(name) == 0 {
|
||||
err = fmt.Errorf("error getting name from upload message, cannot be empty")
|
||||
logger.Warn(logSenderSCP, "error: %v", err)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return size, name, err
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("Error splitting upload message: %v", command)
|
||||
logger.Warn(logSenderSCP, "error: %v", err)
|
||||
c.sendErrorMessage(err.Error())
|
||||
return size, name, err
|
||||
}
|
||||
return size, name, err
|
||||
}
|
||||
|
||||
func getFileModeAsString(fileMode os.FileMode, isDir bool) string {
|
||||
var defaultMode string
|
||||
if isDir {
|
||||
defaultMode = "0755"
|
||||
} else {
|
||||
defaultMode = "0644"
|
||||
}
|
||||
if fileMode == 0 {
|
||||
return defaultMode
|
||||
}
|
||||
modeString := []byte(fileMode.String())
|
||||
nullPerm := []byte("-")
|
||||
u := 0
|
||||
g := 0
|
||||
o := 0
|
||||
s := 0
|
||||
lastChar := len(modeString) - 1
|
||||
if fileMode&os.ModeSticky != 0 {
|
||||
s++
|
||||
}
|
||||
if fileMode&os.ModeSetuid != 0 {
|
||||
s += 2
|
||||
}
|
||||
if fileMode&os.ModeSetgid != 0 {
|
||||
s += 4
|
||||
}
|
||||
if modeString[lastChar-8] != nullPerm[0] {
|
||||
u += 4
|
||||
}
|
||||
if modeString[lastChar-7] != nullPerm[0] {
|
||||
u += 2
|
||||
}
|
||||
if modeString[lastChar-6] != nullPerm[0] {
|
||||
u++
|
||||
}
|
||||
if modeString[lastChar-5] != nullPerm[0] {
|
||||
g += 4
|
||||
}
|
||||
if modeString[lastChar-4] != nullPerm[0] {
|
||||
g += 2
|
||||
}
|
||||
if modeString[lastChar-3] != nullPerm[0] {
|
||||
g++
|
||||
}
|
||||
if modeString[lastChar-2] != nullPerm[0] {
|
||||
o += 4
|
||||
}
|
||||
if modeString[lastChar-1] != nullPerm[0] {
|
||||
o += 2
|
||||
}
|
||||
if modeString[lastChar] != nullPerm[0] {
|
||||
o++
|
||||
}
|
||||
return fmt.Sprintf("%v%v%v%v", s, u, g, o)
|
||||
}
|
||||
|
||||
func getDirContents(path string) ([]os.FileInfo, error) {
|
||||
var files []os.FileInfo
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return files, err
|
||||
}
|
||||
files, err = f.Readdir(-1)
|
||||
f.Close()
|
||||
return files, err
|
||||
}
|
115
sftpd/server.go
115
sftpd/server.go
|
@ -15,6 +15,7 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -52,6 +53,15 @@ type Configuration struct {
|
|||
Actions Actions `json:"actions" mapstructure:"actions"`
|
||||
// Keys are a list of host keys
|
||||
Keys []Key `json:"keys" mapstructure:"keys"`
|
||||
// IsSCPEnabled determines if experimental SCP support is enabled.
|
||||
// We have our own SCP implementation since we can't rely on scp system
|
||||
// command to properly handle permissions, quota and user's home dir restrictions.
|
||||
// The SCP protocol is quite simple but there is no official docs about it,
|
||||
// so we need more testing and feedbacks before enabling it by default.
|
||||
// We may not handle some borderline cases or have sneaky bugs.
|
||||
// Please do accurate tests yourself before enabling SCP and let us known
|
||||
// if something does not work as expected for your use cases
|
||||
IsSCPEnabled bool `json:"enable_scp" mapstructure:"enable_scp"`
|
||||
}
|
||||
|
||||
// Key contains information about host keys
|
||||
|
@ -152,6 +162,28 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
|
|||
|
||||
logger.Debug(logSender, "accepted inbound connection, ip: %v", conn.RemoteAddr().String())
|
||||
|
||||
var user dataprovider.User
|
||||
|
||||
err = json.Unmarshal([]byte(sconn.Permissions.Extensions["user"]), &user)
|
||||
|
||||
if err != nil {
|
||||
logger.Warn(logSender, "Unable to deserialize user info, cannot serve connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
connectionID := hex.EncodeToString(sconn.SessionID())
|
||||
|
||||
connection := Connection{
|
||||
ID: connectionID,
|
||||
User: user,
|
||||
ClientVersion: string(sconn.ClientVersion()),
|
||||
RemoteAddr: conn.RemoteAddr(),
|
||||
StartTime: time.Now(),
|
||||
lastActivity: time.Now(),
|
||||
lock: new(sync.Mutex),
|
||||
sshConn: sconn,
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
for newChannel := range chans {
|
||||
|
@ -179,55 +211,54 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
|
|||
case "subsystem":
|
||||
if string(req.Payload[4:]) == "sftp" {
|
||||
ok = true
|
||||
connection.protocol = protocolSFTP
|
||||
go c.handleSftpConnection(channel, connection)
|
||||
}
|
||||
case "exec":
|
||||
if c.IsSCPEnabled {
|
||||
var msg execMsg
|
||||
if err := ssh.Unmarshal(req.Payload, &msg); err == nil {
|
||||
name, scpArgs, err := parseCommandPayload(msg.Command)
|
||||
logger.Debug(logSender, "new exec command: %v args: %v user: %v, error: %v", name, scpArgs,
|
||||
connection.User.Username, err)
|
||||
if err == nil && name == "scp" && len(scpArgs) >= 2 {
|
||||
ok = true
|
||||
connection.protocol = protocolSCP
|
||||
scpCommand := scpCommand{
|
||||
connection: connection,
|
||||
args: scpArgs,
|
||||
channel: channel,
|
||||
}
|
||||
go scpCommand.handle()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
req.Reply(ok, nil)
|
||||
}
|
||||
}(requests)
|
||||
|
||||
var user dataprovider.User
|
||||
|
||||
err = json.Unmarshal([]byte(sconn.Permissions.Extensions["user"]), &user)
|
||||
|
||||
if err != nil {
|
||||
logger.Warn(logSender, "Unable to deserialize user info, cannot serve connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
connectionID := hex.EncodeToString(sconn.SessionID())
|
||||
|
||||
// Create a new handler for the currently logged in user's server.
|
||||
handler := c.createHandler(sconn, user, connectionID)
|
||||
|
||||
// Create the server instance for the channel using the handler we created above.
|
||||
server := sftp.NewRequestServer(channel, handler)
|
||||
|
||||
if err := server.Serve(); err == io.EOF {
|
||||
logger.Debug(logSender, "connection closed, id: %v", connectionID)
|
||||
server.Close()
|
||||
} else if err != nil {
|
||||
logger.Error(logSender, "sftp connection closed with error id %v: %v", connectionID, err)
|
||||
}
|
||||
|
||||
removeConnection(connectionID)
|
||||
}
|
||||
}
|
||||
|
||||
func (c Configuration) createHandler(conn *ssh.ServerConn, user dataprovider.User, connectionID string) sftp.Handlers {
|
||||
func (c Configuration) handleSftpConnection(channel io.ReadWriteCloser, connection Connection) {
|
||||
addConnection(connection.ID, connection)
|
||||
// Create a new handler for the currently logged in user's server.
|
||||
handler := c.createHandler(connection)
|
||||
|
||||
connection := Connection{
|
||||
ID: connectionID,
|
||||
User: user,
|
||||
ClientVersion: string(conn.ClientVersion()),
|
||||
RemoteAddr: conn.RemoteAddr(),
|
||||
StartTime: time.Now(),
|
||||
lastActivity: time.Now(),
|
||||
lock: new(sync.Mutex),
|
||||
sshConn: conn,
|
||||
// Create the server instance for the channel using the handler we created above.
|
||||
server := sftp.NewRequestServer(channel, handler)
|
||||
|
||||
if err := server.Serve(); err == io.EOF {
|
||||
logger.Debug(logSender, "connection closed, id: %v", connection.ID)
|
||||
server.Close()
|
||||
} else if err != nil {
|
||||
logger.Error(logSender, "sftp connection closed with error id %v: %v", connection.ID, err)
|
||||
}
|
||||
|
||||
addConnection(connectionID, connection)
|
||||
removeConnection(connection.ID)
|
||||
}
|
||||
|
||||
func (c Configuration) createHandler(connection Connection) sftp.Handlers {
|
||||
|
||||
return sftp.Handlers{
|
||||
FileGet: connection,
|
||||
|
@ -331,3 +362,11 @@ func (c Configuration) generatePrivateKey(file string) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseCommandPayload(command string) (string, []string, error) {
|
||||
parts := strings.Split(command, " ")
|
||||
if len(parts) < 2 {
|
||||
return parts[0], []string{}, nil
|
||||
}
|
||||
return parts[0], parts[1:], nil
|
||||
}
|
||||
|
|
|
@ -19,18 +19,21 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
logSender = "sftpd"
|
||||
sftpUploadLogSender = "SFTPUpload"
|
||||
sftpdDownloadLogSender = "SFTPDownload"
|
||||
sftpdRenameLogSender = "SFTPRename"
|
||||
sftpdRmdirLogSender = "SFTPRmdir"
|
||||
sftpdMkdirLogSender = "SFTPMkdir"
|
||||
sftpdSymlinkLogSender = "SFTPSymlink"
|
||||
sftpdRemoveLogSender = "SFTPRemove"
|
||||
operationDownload = "download"
|
||||
operationUpload = "upload"
|
||||
operationDelete = "delete"
|
||||
operationRename = "rename"
|
||||
logSender = "sftpd"
|
||||
logSenderSCP = "scp"
|
||||
uploadLogSender = "Upload"
|
||||
downloadLogSender = "Download"
|
||||
renameLogSender = "Rename"
|
||||
rmdirLogSender = "Rmdir"
|
||||
mkdirLogSender = "Mkdir"
|
||||
symlinkLogSender = "Symlink"
|
||||
removeLogSender = "Remove"
|
||||
operationDownload = "download"
|
||||
operationUpload = "upload"
|
||||
operationDelete = "delete"
|
||||
operationRename = "rename"
|
||||
protocolSFTP = "SFTP"
|
||||
protocolSCP = "SCP"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -86,6 +89,8 @@ type ConnectionStatus struct {
|
|||
ConnectionTime int64 `json:"connection_time"`
|
||||
// Last activity as unix timestamp in milliseconds
|
||||
LastActivity int64 `json:"last_activity"`
|
||||
// Protocol for this connection: SFTP or SCP
|
||||
Protocol string `json:"protocol"`
|
||||
// active uploads/downloads
|
||||
Transfers []connectionTransfer `json:"active_transfers"`
|
||||
}
|
||||
|
@ -190,6 +195,7 @@ func GetConnectionsStats() []ConnectionStatus {
|
|||
RemoteAddress: c.RemoteAddr.String(),
|
||||
ConnectionTime: utils.GetTimeAsMsSinceEpoch(c.StartTime),
|
||||
LastActivity: utils.GetTimeAsMsSinceEpoch(c.lastActivity),
|
||||
Protocol: c.protocol,
|
||||
Transfers: []connectionTransfer{},
|
||||
}
|
||||
for _, t := range activeTransfers {
|
||||
|
@ -250,9 +256,7 @@ func CheckIdleConnections() {
|
|||
if idleTime > idleTimeout {
|
||||
logger.Debug(logSender, "close idle connection id: %v idle time: %v", c.ID, idleTime)
|
||||
err := c.sshConn.Close()
|
||||
if err != nil {
|
||||
logger.Warn(logSender, "error closing idle connection: %v", err)
|
||||
}
|
||||
logger.Debug(logSender, "idle connection closed id: %v, err: %v", c.ID, err)
|
||||
}
|
||||
}
|
||||
logger.Debug(logSender, "check idle connections ended")
|
||||
|
|
|
@ -8,6 +8,8 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
@ -77,8 +79,11 @@ iixITGvaNZh/tjAAAACW5pY29sYUBwMQE=
|
|||
)
|
||||
|
||||
var (
|
||||
allPerms = []string{dataprovider.PermAny}
|
||||
homeBasePath string
|
||||
allPerms = []string{dataprovider.PermAny}
|
||||
homeBasePath string
|
||||
scpPath string
|
||||
pubKeyPath string
|
||||
privateKeyPath string
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
@ -97,6 +102,8 @@ func TestMain(m *testing.M) {
|
|||
httpdConf := config.GetHTTPDConfig()
|
||||
router := api.GetHTTPRouter()
|
||||
sftpdConf.BindPort = 2022
|
||||
// we need to test SCP support
|
||||
sftpdConf.IsSCPEnabled = true
|
||||
// we run the test cases with UploadMode atomic. The non atomic code path
|
||||
// simply does not execute some code so if it works in atomic mode will
|
||||
// work in non atomic mode too
|
||||
|
@ -109,10 +116,27 @@ func TestMain(m *testing.M) {
|
|||
sftpdConf.Actions.Command = "/usr/bin/true"
|
||||
sftpdConf.Actions.HTTPNotificationURL = "http://127.0.0.1:8080/"
|
||||
}
|
||||
pubKeyPath = filepath.Join(homeBasePath, "ssh_key.pub")
|
||||
privateKeyPath = filepath.Join(homeBasePath, "ssh_key")
|
||||
err = ioutil.WriteFile(pubKeyPath, []byte(testPubKey+"\n"), 0600)
|
||||
if err != nil {
|
||||
logger.WarnToConsole("unable to save public key to file: %v", err)
|
||||
}
|
||||
err = ioutil.WriteFile(privateKeyPath, []byte(testPrivateKey+"\n"), 0600)
|
||||
if err != nil {
|
||||
logger.WarnToConsole("unable to save private key to file: %v", err)
|
||||
}
|
||||
|
||||
sftpd.SetDataProvider(dataProvider)
|
||||
api.SetDataProvider(dataProvider)
|
||||
|
||||
scpPath, err = exec.LookPath("scp")
|
||||
if err != nil {
|
||||
logger.Warn(logSender, "unable to get scp command. SCP tests will be skipped, err: %v", err)
|
||||
logger.WarnToConsole("unable to get scp command. SCP tests will be skipped, err: %v", err)
|
||||
scpPath = ""
|
||||
}
|
||||
|
||||
go func() {
|
||||
logger.Debug(logSender, "initializing SFTP server with config %+v", sftpdConf)
|
||||
if err := sftpdConf.Initialize(configDir); err != nil {
|
||||
|
@ -1399,6 +1423,503 @@ func TestSSHConnection(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Start SCP tests
|
||||
func TestSCPBasicHandling(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
u.QuotaSize = 6553600
|
||||
user, _, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(131074)
|
||||
expectedQuotaSize := user.UsedQuotaSize + testFileSize
|
||||
expectedQuotaFiles := user.UsedQuotaFiles + 1
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
|
||||
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
|
||||
localPath := filepath.Join(homeBasePath, "scp_download.dat")
|
||||
// test to download a missing file
|
||||
err = scpDownload(localPath, remoteDownPath, false, false)
|
||||
if err == nil {
|
||||
t.Errorf("downloading a missing file via scp must fail")
|
||||
}
|
||||
err = scpUpload(testFilePath, remoteUpPath, false)
|
||||
if err != nil {
|
||||
t.Errorf("error uploading file via scp: %v", err)
|
||||
}
|
||||
err = scpDownload(localPath, remoteDownPath, false, false)
|
||||
if err != nil {
|
||||
t.Errorf("error downloading file via scp: %v", err)
|
||||
}
|
||||
fi, err := os.Stat(localPath)
|
||||
if err != nil {
|
||||
t.Errorf("stat for the downloaded file must succeed")
|
||||
} else {
|
||||
if fi.Size() != testFileSize {
|
||||
t.Errorf("size of the file downloaded via SCP does not match the expected one")
|
||||
}
|
||||
}
|
||||
os.Remove(localPath)
|
||||
user, _, err = api.GetUserByID(user.ID, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("error getting user: %v", err)
|
||||
}
|
||||
if expectedQuotaFiles != user.UsedQuotaFiles {
|
||||
t.Errorf("quota files does not match, expected: %v, actual: %v", expectedQuotaFiles, user.UsedQuotaFiles)
|
||||
}
|
||||
if expectedQuotaSize != user.UsedQuotaSize {
|
||||
t.Errorf("quota size does not match, expected: %v, actual: %v", expectedQuotaSize, user.UsedQuotaSize)
|
||||
}
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPUploadFileOverwrite(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
user, _, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(32760)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, filepath.Join("/", testFileName))
|
||||
err = scpUpload(testFilePath, remoteUpPath, true)
|
||||
if err != nil {
|
||||
t.Errorf("error uploading file via scp: %v", err)
|
||||
}
|
||||
// test a new upload that must overwrite the existing file
|
||||
err = scpUpload(testFilePath, remoteUpPath, true)
|
||||
if err != nil {
|
||||
t.Errorf("error uploading existing file via scp: %v", err)
|
||||
}
|
||||
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
|
||||
localPath := filepath.Join(homeBasePath, "scp_download.dat")
|
||||
err = scpDownload(localPath, remoteDownPath, false, false)
|
||||
if err != nil {
|
||||
t.Errorf("error downloading file via scp: %v", err)
|
||||
}
|
||||
fi, err := os.Stat(localPath)
|
||||
if err != nil {
|
||||
t.Errorf("stat for the downloaded file must succeed")
|
||||
} else {
|
||||
if fi.Size() != testFileSize {
|
||||
t.Errorf("size of the file downloaded via SCP does not match the expected one")
|
||||
}
|
||||
}
|
||||
os.Remove(localPath)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPRecursive(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
user, _, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testBaseDirName := "test_dir"
|
||||
testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName)
|
||||
testBaseDirDownName := "test_dir_down"
|
||||
testBaseDirDownPath := filepath.Join(homeBasePath, testBaseDirDownName)
|
||||
testFilePath := filepath.Join(homeBasePath, testBaseDirName, testFileName)
|
||||
testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testBaseDirName, testFileName)
|
||||
testFileSize := int64(131074)
|
||||
createTestFile(testFilePath, testFileSize)
|
||||
createTestFile(testFilePath1, testFileSize)
|
||||
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testBaseDirName))
|
||||
// test to download a missing dir
|
||||
err = scpDownload(testBaseDirDownPath, remoteDownPath, true, true)
|
||||
if err == nil {
|
||||
t.Errorf("downloading a missing dir via scp must fail")
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
|
||||
err = scpUpload(testBaseDirPath, remoteUpPath, true)
|
||||
if err != nil {
|
||||
t.Errorf("error uploading dir via scp: %v", err)
|
||||
}
|
||||
err = scpDownload(testBaseDirDownPath, remoteDownPath, true, true)
|
||||
if err != nil {
|
||||
t.Errorf("error downloading dir via scp: %v", err)
|
||||
}
|
||||
// test download without passing -r
|
||||
err = scpDownload(testBaseDirDownPath, remoteDownPath, true, false)
|
||||
if err == nil {
|
||||
t.Errorf("recursive download without -r must fail")
|
||||
}
|
||||
fi, err := os.Stat(filepath.Join(testBaseDirDownPath, testFileName))
|
||||
if err != nil {
|
||||
t.Errorf("error downloading file using scp recursive: %v", err)
|
||||
} else {
|
||||
if fi.Size() != testFileSize {
|
||||
t.Errorf("size for file downloaded using recursive scp does not match, actual: %v, expected: %v", fi.Size(), testFileSize)
|
||||
}
|
||||
}
|
||||
fi, err = os.Stat(filepath.Join(testBaseDirDownPath, testBaseDirName, testFileName))
|
||||
if err != nil {
|
||||
t.Errorf("error downloading file using scp recursive: %v", err)
|
||||
} else {
|
||||
if fi.Size() != testFileSize {
|
||||
t.Errorf("size for file downloaded using recursive scp does not match, actual: %v, expected: %v", fi.Size(), testFileSize)
|
||||
}
|
||||
}
|
||||
// upload to a non existent dir
|
||||
remoteUpPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/non_existent_dir")
|
||||
err = scpUpload(testBaseDirPath, remoteUpPath, true)
|
||||
if err == nil {
|
||||
t.Errorf("uploading via scp to a non existent dir must fail")
|
||||
}
|
||||
os.RemoveAll(testBaseDirPath)
|
||||
os.RemoveAll(testBaseDirDownPath)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPPermCreateDirs(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
u.Permissions = []string{dataprovider.PermDownload, dataprovider.PermUpload}
|
||||
user, _, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(32760)
|
||||
testBaseDirName := "test_dir"
|
||||
testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName)
|
||||
testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testFileName)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
err = createTestFile(testFilePath1, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/tmp/")
|
||||
err = scpUpload(testFilePath, remoteUpPath, true)
|
||||
if err == nil {
|
||||
t.Errorf("scp upload must fail, the user cannot create new dirs")
|
||||
}
|
||||
err = scpUpload(testBaseDirPath, remoteUpPath, true)
|
||||
if err == nil {
|
||||
t.Errorf("scp upload must fail, the user cannot create new dirs")
|
||||
}
|
||||
err = os.Remove(testFilePath)
|
||||
if err != nil {
|
||||
t.Errorf("error removing test file")
|
||||
}
|
||||
os.RemoveAll(testBaseDirPath)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPPermUpload(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
u.Permissions = []string{dataprovider.PermDownload, dataprovider.PermCreateDirs}
|
||||
user, _, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(65536)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/tmp")
|
||||
err = scpUpload(testFilePath, remoteUpPath, true)
|
||||
if err == nil {
|
||||
t.Errorf("scp upload must fail, the user cannot upload")
|
||||
}
|
||||
err = os.Remove(testFilePath)
|
||||
if err != nil {
|
||||
t.Errorf("error removing test file")
|
||||
}
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPPermDownload(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
u.Permissions = []string{dataprovider.PermUpload, dataprovider.PermCreateDirs}
|
||||
user, _, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(65537)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "tmp")
|
||||
err = scpUpload(testFilePath, remoteUpPath, true)
|
||||
if err != nil {
|
||||
t.Errorf("error uploading existing file via scp: %v", err)
|
||||
}
|
||||
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/tmp", testFileName))
|
||||
localPath := filepath.Join(homeBasePath, "scp_download.dat")
|
||||
err = scpDownload(localPath, remoteDownPath, false, false)
|
||||
if err == nil {
|
||||
t.Errorf("scp download must fail, the user cannot download")
|
||||
}
|
||||
err = os.Remove(testFilePath)
|
||||
if err != nil {
|
||||
t.Errorf("error removing test file")
|
||||
}
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPQuotaSize(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
testFileSize := int64(65535)
|
||||
u := getTestUser(usePubKey)
|
||||
u.QuotaFiles = 1
|
||||
u.QuotaSize = testFileSize - 1
|
||||
user, _, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
|
||||
err = scpUpload(testFilePath, remoteUpPath, true)
|
||||
if err != nil {
|
||||
t.Errorf("error uploading existing file via scp: %v", err)
|
||||
}
|
||||
err = scpUpload(testFilePath, remoteUpPath+".quota", true)
|
||||
if err == nil {
|
||||
t.Errorf("user is over quota scp upload must fail")
|
||||
}
|
||||
err = os.Remove(testFilePath)
|
||||
if err != nil {
|
||||
t.Errorf("error removing test file")
|
||||
}
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPEscapeHomeDir(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
user, _, err := api.AddUser(getTestUser(usePubKey), http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
os.MkdirAll(user.GetHomeDir(), 0777)
|
||||
testDir := "testDir"
|
||||
linkPath := filepath.Join(homeBasePath, defaultUsername, testDir)
|
||||
err = os.Symlink(homeBasePath, linkPath)
|
||||
if err != nil {
|
||||
t.Errorf("error making local symlink: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(65535)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDir, testDir))
|
||||
err = scpUpload(testFilePath, remoteUpPath, false)
|
||||
if err == nil {
|
||||
t.Errorf("uploading to a dir with a symlink outside home dir must fail")
|
||||
}
|
||||
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testDir, testFileName))
|
||||
localPath := filepath.Join(homeBasePath, "scp_download.dat")
|
||||
err = scpDownload(localPath, remoteDownPath, false, false)
|
||||
if err == nil {
|
||||
t.Errorf("scp download must fail, the requested file has a symlink outside user home")
|
||||
}
|
||||
remoteDownPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testDir))
|
||||
err = scpDownload(homeBasePath, remoteDownPath, false, true)
|
||||
if err == nil {
|
||||
t.Errorf("scp download must fail, the requested dir is a symlink outside user home")
|
||||
}
|
||||
err = os.Remove(testFilePath)
|
||||
if err != nil {
|
||||
t.Errorf("error removing test file")
|
||||
}
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPUploadPaths(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
user, _, err := api.AddUser(getTestUser(usePubKey), http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(65535)
|
||||
testDirName := "testDir"
|
||||
testDirPath := filepath.Join(user.GetHomeDir(), testDirName)
|
||||
os.MkdirAll(testDirPath, 0777)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, testDirName)
|
||||
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDirName, testFileName))
|
||||
localPath := filepath.Join(homeBasePath, "scp_download.dat")
|
||||
err = scpUpload(testFilePath, remoteUpPath, false)
|
||||
if err != nil {
|
||||
t.Errorf("scp upload error: %v", err)
|
||||
}
|
||||
err = scpDownload(localPath, remoteDownPath, false, false)
|
||||
if err != nil {
|
||||
t.Errorf("scp download error: %v", err)
|
||||
}
|
||||
// upload a file to a missing dir
|
||||
remoteUpPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDirName, testDirName, testFileName))
|
||||
err = scpUpload(testFilePath, remoteUpPath, false)
|
||||
if err == nil {
|
||||
t.Errorf("scp upload to a missing dir must fail")
|
||||
}
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPOverwriteDirWithFile(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
user, _, err := api.AddUser(getTestUser(usePubKey), http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(65535)
|
||||
testDirPath := filepath.Join(user.GetHomeDir(), testFileName)
|
||||
os.MkdirAll(testDirPath, 0777)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
|
||||
err = scpUpload(testFilePath, remoteUpPath, false)
|
||||
if err == nil {
|
||||
t.Errorf("copying a file over an existing dir must fail")
|
||||
}
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// End SCP tests
|
||||
|
||||
func waitTCPListening(address string) {
|
||||
for {
|
||||
conn, err := net.Dial("tcp", address)
|
||||
|
@ -1487,6 +2008,10 @@ func getSftpClient(user dataprovider.User, usePubKey bool) (*sftp.Client, error)
|
|||
}
|
||||
|
||||
func createTestFile(path string, size int64) error {
|
||||
baseDir := filepath.Dir(path)
|
||||
if _, err := os.Stat(baseDir); os.IsNotExist(err) {
|
||||
os.MkdirAll(baseDir, 0777)
|
||||
}
|
||||
content := make([]byte, size)
|
||||
_, err := rand.Read(content)
|
||||
if err != nil {
|
||||
|
@ -1572,6 +2097,49 @@ func sftpDownloadNonBlocking(remoteSourcePath string, localDestPath string, expe
|
|||
return c
|
||||
}
|
||||
|
||||
func scpUpload(localPath, remotePath string, preserveTime bool) error {
|
||||
var args []string
|
||||
if preserveTime {
|
||||
args = append(args, "-p")
|
||||
}
|
||||
fi, err := os.Stat(localPath)
|
||||
if err == nil {
|
||||
if fi.IsDir() {
|
||||
args = append(args, "-r")
|
||||
}
|
||||
}
|
||||
args = append(args, "-P")
|
||||
args = append(args, "2022")
|
||||
args = append(args, "-o")
|
||||
args = append(args, "StrictHostKeyChecking=no")
|
||||
args = append(args, "-i")
|
||||
args = append(args, privateKeyPath)
|
||||
args = append(args, localPath)
|
||||
args = append(args, remotePath)
|
||||
cmd := exec.Command(scpPath, args...)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func scpDownload(localPath, remotePath string, preserveTime, recursive bool) error {
|
||||
var args []string
|
||||
if preserveTime {
|
||||
args = append(args, "-p")
|
||||
}
|
||||
if recursive {
|
||||
args = append(args, "-r")
|
||||
}
|
||||
args = append(args, "-P")
|
||||
args = append(args, "2022")
|
||||
args = append(args, "-o")
|
||||
args = append(args, "StrictHostKeyChecking=no")
|
||||
args = append(args, "-i")
|
||||
args = append(args, privateKeyPath)
|
||||
args = append(args, remotePath)
|
||||
args = append(args, localPath)
|
||||
cmd := exec.Command(scpPath, args...)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func waitForActiveTransfer() {
|
||||
stats := sftpd.GetConnectionsStats()
|
||||
for len(stats) < 1 {
|
||||
|
|
|
@ -31,6 +31,7 @@ type Transfer struct {
|
|||
transferType int
|
||||
lastActivity time.Time
|
||||
isNewFile bool
|
||||
protocol string
|
||||
}
|
||||
|
||||
// ReadAt reads len(p) bytes from the File to download starting at byte offset off and updates the bytes sent.
|
||||
|
@ -64,10 +65,10 @@ func (t *Transfer) Close() error {
|
|||
}
|
||||
elapsed := time.Since(t.start).Nanoseconds() / 1000000
|
||||
if t.transferType == transferDownload {
|
||||
logger.TransferLog(sftpdDownloadLogSender, t.path, elapsed, t.bytesSent, t.user.Username, t.connectionID)
|
||||
logger.TransferLog(downloadLogSender, t.path, elapsed, t.bytesSent, t.user.Username, t.connectionID, t.protocol)
|
||||
executeAction(operationDownload, t.user.Username, t.path, "")
|
||||
} else {
|
||||
logger.TransferLog(sftpUploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID)
|
||||
logger.TransferLog(uploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID, t.protocol)
|
||||
executeAction(operationUpload, t.user.Username, t.path, "")
|
||||
}
|
||||
removeTransfer(t)
|
||||
|
|
|
@ -12,7 +12,8 @@
|
|||
"command": "",
|
||||
"http_notification_url": ""
|
||||
},
|
||||
"keys": []
|
||||
"keys": [],
|
||||
"enable_scp": false
|
||||
},
|
||||
"data_provider": {
|
||||
"driver": "sqlite",
|
||||
|
|
Loading…
Reference in a new issue