|
@@ -8,6 +8,8 @@ import (
|
|
|
"net"
|
|
|
"net/http"
|
|
|
"os"
|
|
|
+ "os/exec"
|
|
|
+ "path"
|
|
|
"path/filepath"
|
|
|
"runtime"
|
|
|
"testing"
|
|
@@ -77,8 +79,11 @@ iixITGvaNZh/tjAAAACW5pY29sYUBwMQE=
|
|
|
)
|
|
|
|
|
|
var (
|
|
|
- allPerms = []string{dataprovider.PermAny}
|
|
|
- homeBasePath string
|
|
|
+ allPerms = []string{dataprovider.PermAny}
|
|
|
+ homeBasePath string
|
|
|
+ scpPath string
|
|
|
+ pubKeyPath string
|
|
|
+ privateKeyPath string
|
|
|
)
|
|
|
|
|
|
func TestMain(m *testing.M) {
|
|
@@ -97,6 +102,8 @@ func TestMain(m *testing.M) {
|
|
|
httpdConf := config.GetHTTPDConfig()
|
|
|
router := api.GetHTTPRouter()
|
|
|
sftpdConf.BindPort = 2022
|
|
|
+ // we need to test SCP support
|
|
|
+ sftpdConf.IsSCPEnabled = true
|
|
|
// we run the test cases with UploadMode atomic. The non atomic code path
|
|
|
// simply does not execute some code so if it works in atomic mode will
|
|
|
// work in non atomic mode too
|
|
@@ -109,10 +116,27 @@ func TestMain(m *testing.M) {
|
|
|
sftpdConf.Actions.Command = "/usr/bin/true"
|
|
|
sftpdConf.Actions.HTTPNotificationURL = "http://127.0.0.1:8080/"
|
|
|
}
|
|
|
+ pubKeyPath = filepath.Join(homeBasePath, "ssh_key.pub")
|
|
|
+ privateKeyPath = filepath.Join(homeBasePath, "ssh_key")
|
|
|
+ err = ioutil.WriteFile(pubKeyPath, []byte(testPubKey+"\n"), 0600)
|
|
|
+ if err != nil {
|
|
|
+ logger.WarnToConsole("unable to save public key to file: %v", err)
|
|
|
+ }
|
|
|
+ err = ioutil.WriteFile(privateKeyPath, []byte(testPrivateKey+"\n"), 0600)
|
|
|
+ if err != nil {
|
|
|
+ logger.WarnToConsole("unable to save private key to file: %v", err)
|
|
|
+ }
|
|
|
|
|
|
sftpd.SetDataProvider(dataProvider)
|
|
|
api.SetDataProvider(dataProvider)
|
|
|
|
|
|
+ scpPath, err = exec.LookPath("scp")
|
|
|
+ if err != nil {
|
|
|
+ logger.Warn(logSender, "unable to get scp command. SCP tests will be skipped, err: %v", err)
|
|
|
+ logger.WarnToConsole("unable to get scp command. SCP tests will be skipped, err: %v", err)
|
|
|
+ scpPath = ""
|
|
|
+ }
|
|
|
+
|
|
|
go func() {
|
|
|
logger.Debug(logSender, "initializing SFTP server with config %+v", sftpdConf)
|
|
|
if err := sftpdConf.Initialize(configDir); err != nil {
|
|
@@ -1399,6 +1423,503 @@ func TestSSHConnection(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+// Start SCP tests
|
|
|
+func TestSCPBasicHandling(t *testing.T) {
|
|
|
+ if len(scpPath) == 0 {
|
|
|
+ t.Skip("scp command not found, unable to execute this test")
|
|
|
+ }
|
|
|
+ usePubKey := true
|
|
|
+ u := getTestUser(usePubKey)
|
|
|
+ u.QuotaSize = 6553600
|
|
|
+ user, _, err := api.AddUser(u, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to add user: %v", err)
|
|
|
+ }
|
|
|
+ testFileName := "test_file.dat"
|
|
|
+ testFilePath := filepath.Join(homeBasePath, testFileName)
|
|
|
+ testFileSize := int64(131074)
|
|
|
+ expectedQuotaSize := user.UsedQuotaSize + testFileSize
|
|
|
+ expectedQuotaFiles := user.UsedQuotaFiles + 1
|
|
|
+ err = createTestFile(testFilePath, testFileSize)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to create test file: %v", err)
|
|
|
+ }
|
|
|
+ remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
|
|
|
+ remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
|
|
|
+ localPath := filepath.Join(homeBasePath, "scp_download.dat")
|
|
|
+ // test to download a missing file
|
|
|
+ err = scpDownload(localPath, remoteDownPath, false, false)
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("downloading a missing file via scp must fail")
|
|
|
+ }
|
|
|
+ err = scpUpload(testFilePath, remoteUpPath, false)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error uploading file via scp: %v", err)
|
|
|
+ }
|
|
|
+ err = scpDownload(localPath, remoteDownPath, false, false)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error downloading file via scp: %v", err)
|
|
|
+ }
|
|
|
+ fi, err := os.Stat(localPath)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("stat for the downloaded file must succeed")
|
|
|
+ } else {
|
|
|
+ if fi.Size() != testFileSize {
|
|
|
+ t.Errorf("size of the file downloaded via SCP does not match the expected one")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ os.Remove(localPath)
|
|
|
+ user, _, err = api.GetUserByID(user.ID, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error getting user: %v", err)
|
|
|
+ }
|
|
|
+ if expectedQuotaFiles != user.UsedQuotaFiles {
|
|
|
+ t.Errorf("quota files does not match, expected: %v, actual: %v", expectedQuotaFiles, user.UsedQuotaFiles)
|
|
|
+ }
|
|
|
+ if expectedQuotaSize != user.UsedQuotaSize {
|
|
|
+ t.Errorf("quota size does not match, expected: %v, actual: %v", expectedQuotaSize, user.UsedQuotaSize)
|
|
|
+ }
|
|
|
+ err = os.RemoveAll(user.GetHomeDir())
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error removing uploaded files")
|
|
|
+ }
|
|
|
+ _, err = api.RemoveUser(user, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to remove user: %v", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestSCPUploadFileOverwrite(t *testing.T) {
|
|
|
+ if len(scpPath) == 0 {
|
|
|
+ t.Skip("scp command not found, unable to execute this test")
|
|
|
+ }
|
|
|
+ usePubKey := true
|
|
|
+ u := getTestUser(usePubKey)
|
|
|
+ user, _, err := api.AddUser(u, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to add user: %v", err)
|
|
|
+ }
|
|
|
+ testFileName := "test_file.dat"
|
|
|
+ testFilePath := filepath.Join(homeBasePath, testFileName)
|
|
|
+ testFileSize := int64(32760)
|
|
|
+ err = createTestFile(testFilePath, testFileSize)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to create test file: %v", err)
|
|
|
+ }
|
|
|
+ remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, filepath.Join("/", testFileName))
|
|
|
+ err = scpUpload(testFilePath, remoteUpPath, true)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error uploading file via scp: %v", err)
|
|
|
+ }
|
|
|
+ // test a new upload that must overwrite the existing file
|
|
|
+ err = scpUpload(testFilePath, remoteUpPath, true)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error uploading existing file via scp: %v", err)
|
|
|
+ }
|
|
|
+ remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
|
|
|
+ localPath := filepath.Join(homeBasePath, "scp_download.dat")
|
|
|
+ err = scpDownload(localPath, remoteDownPath, false, false)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error downloading file via scp: %v", err)
|
|
|
+ }
|
|
|
+ fi, err := os.Stat(localPath)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("stat for the downloaded file must succeed")
|
|
|
+ } else {
|
|
|
+ if fi.Size() != testFileSize {
|
|
|
+ t.Errorf("size of the file downloaded via SCP does not match the expected one")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ os.Remove(localPath)
|
|
|
+ err = os.RemoveAll(user.GetHomeDir())
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error removing uploaded files")
|
|
|
+ }
|
|
|
+ _, err = api.RemoveUser(user, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to remove user: %v", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestSCPRecursive(t *testing.T) {
|
|
|
+ if len(scpPath) == 0 {
|
|
|
+ t.Skip("scp command not found, unable to execute this test")
|
|
|
+ }
|
|
|
+ usePubKey := true
|
|
|
+ u := getTestUser(usePubKey)
|
|
|
+ user, _, err := api.AddUser(u, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to add user: %v", err)
|
|
|
+ }
|
|
|
+ testFileName := "test_file.dat"
|
|
|
+ testBaseDirName := "test_dir"
|
|
|
+ testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName)
|
|
|
+ testBaseDirDownName := "test_dir_down"
|
|
|
+ testBaseDirDownPath := filepath.Join(homeBasePath, testBaseDirDownName)
|
|
|
+ testFilePath := filepath.Join(homeBasePath, testBaseDirName, testFileName)
|
|
|
+ testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testBaseDirName, testFileName)
|
|
|
+ testFileSize := int64(131074)
|
|
|
+ createTestFile(testFilePath, testFileSize)
|
|
|
+ createTestFile(testFilePath1, testFileSize)
|
|
|
+ remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testBaseDirName))
|
|
|
+ // test to download a missing dir
|
|
|
+ err = scpDownload(testBaseDirDownPath, remoteDownPath, true, true)
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("downloading a missing dir via scp must fail")
|
|
|
+ }
|
|
|
+ remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
|
|
|
+ err = scpUpload(testBaseDirPath, remoteUpPath, true)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error uploading dir via scp: %v", err)
|
|
|
+ }
|
|
|
+ err = scpDownload(testBaseDirDownPath, remoteDownPath, true, true)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error downloading dir via scp: %v", err)
|
|
|
+ }
|
|
|
+ // test download without passing -r
|
|
|
+ err = scpDownload(testBaseDirDownPath, remoteDownPath, true, false)
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("recursive download without -r must fail")
|
|
|
+ }
|
|
|
+ fi, err := os.Stat(filepath.Join(testBaseDirDownPath, testFileName))
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error downloading file using scp recursive: %v", err)
|
|
|
+ } else {
|
|
|
+ if fi.Size() != testFileSize {
|
|
|
+ t.Errorf("size for file downloaded using recursive scp does not match, actual: %v, expected: %v", fi.Size(), testFileSize)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ fi, err = os.Stat(filepath.Join(testBaseDirDownPath, testBaseDirName, testFileName))
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error downloading file using scp recursive: %v", err)
|
|
|
+ } else {
|
|
|
+ if fi.Size() != testFileSize {
|
|
|
+ t.Errorf("size for file downloaded using recursive scp does not match, actual: %v, expected: %v", fi.Size(), testFileSize)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // upload to a non existent dir
|
|
|
+ remoteUpPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/non_existent_dir")
|
|
|
+ err = scpUpload(testBaseDirPath, remoteUpPath, true)
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("uploading via scp to a non existent dir must fail")
|
|
|
+ }
|
|
|
+ os.RemoveAll(testBaseDirPath)
|
|
|
+ os.RemoveAll(testBaseDirDownPath)
|
|
|
+ err = os.RemoveAll(user.GetHomeDir())
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error removing uploaded files")
|
|
|
+ }
|
|
|
+ _, err = api.RemoveUser(user, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to remove user: %v", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestSCPPermCreateDirs(t *testing.T) {
|
|
|
+ if len(scpPath) == 0 {
|
|
|
+ t.Skip("scp command not found, unable to execute this test")
|
|
|
+ }
|
|
|
+ usePubKey := true
|
|
|
+ u := getTestUser(usePubKey)
|
|
|
+ u.Permissions = []string{dataprovider.PermDownload, dataprovider.PermUpload}
|
|
|
+ user, _, err := api.AddUser(u, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to add user: %v", err)
|
|
|
+ }
|
|
|
+ testFileName := "test_file.dat"
|
|
|
+ testFilePath := filepath.Join(homeBasePath, testFileName)
|
|
|
+ testFileSize := int64(32760)
|
|
|
+ testBaseDirName := "test_dir"
|
|
|
+ testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName)
|
|
|
+ testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testFileName)
|
|
|
+ err = createTestFile(testFilePath, testFileSize)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to create test file: %v", err)
|
|
|
+ }
|
|
|
+ err = createTestFile(testFilePath1, testFileSize)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to create test file: %v", err)
|
|
|
+ }
|
|
|
+ remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/tmp/")
|
|
|
+ err = scpUpload(testFilePath, remoteUpPath, true)
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("scp upload must fail, the user cannot create new dirs")
|
|
|
+ }
|
|
|
+ err = scpUpload(testBaseDirPath, remoteUpPath, true)
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("scp upload must fail, the user cannot create new dirs")
|
|
|
+ }
|
|
|
+ err = os.Remove(testFilePath)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error removing test file")
|
|
|
+ }
|
|
|
+ os.RemoveAll(testBaseDirPath)
|
|
|
+ err = os.RemoveAll(user.GetHomeDir())
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error removing uploaded files")
|
|
|
+ }
|
|
|
+ _, err = api.RemoveUser(user, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to remove user: %v", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestSCPPermUpload(t *testing.T) {
|
|
|
+ if len(scpPath) == 0 {
|
|
|
+ t.Skip("scp command not found, unable to execute this test")
|
|
|
+ }
|
|
|
+ usePubKey := true
|
|
|
+ u := getTestUser(usePubKey)
|
|
|
+ u.Permissions = []string{dataprovider.PermDownload, dataprovider.PermCreateDirs}
|
|
|
+ user, _, err := api.AddUser(u, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to add user: %v", err)
|
|
|
+ }
|
|
|
+ testFileName := "test_file.dat"
|
|
|
+ testFilePath := filepath.Join(homeBasePath, testFileName)
|
|
|
+ testFileSize := int64(65536)
|
|
|
+ err = createTestFile(testFilePath, testFileSize)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to create test file: %v", err)
|
|
|
+ }
|
|
|
+ remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/tmp")
|
|
|
+ err = scpUpload(testFilePath, remoteUpPath, true)
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("scp upload must fail, the user cannot upload")
|
|
|
+ }
|
|
|
+ err = os.Remove(testFilePath)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error removing test file")
|
|
|
+ }
|
|
|
+ err = os.RemoveAll(user.GetHomeDir())
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error removing uploaded files")
|
|
|
+ }
|
|
|
+ _, err = api.RemoveUser(user, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to remove user: %v", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestSCPPermDownload(t *testing.T) {
|
|
|
+ if len(scpPath) == 0 {
|
|
|
+ t.Skip("scp command not found, unable to execute this test")
|
|
|
+ }
|
|
|
+ usePubKey := true
|
|
|
+ u := getTestUser(usePubKey)
|
|
|
+ u.Permissions = []string{dataprovider.PermUpload, dataprovider.PermCreateDirs}
|
|
|
+ user, _, err := api.AddUser(u, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to add user: %v", err)
|
|
|
+ }
|
|
|
+ testFileName := "test_file.dat"
|
|
|
+ testFilePath := filepath.Join(homeBasePath, testFileName)
|
|
|
+ testFileSize := int64(65537)
|
|
|
+ err = createTestFile(testFilePath, testFileSize)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to create test file: %v", err)
|
|
|
+ }
|
|
|
+ remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "tmp")
|
|
|
+ err = scpUpload(testFilePath, remoteUpPath, true)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error uploading existing file via scp: %v", err)
|
|
|
+ }
|
|
|
+ remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/tmp", testFileName))
|
|
|
+ localPath := filepath.Join(homeBasePath, "scp_download.dat")
|
|
|
+ err = scpDownload(localPath, remoteDownPath, false, false)
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("scp download must fail, the user cannot download")
|
|
|
+ }
|
|
|
+ err = os.Remove(testFilePath)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error removing test file")
|
|
|
+ }
|
|
|
+ err = os.RemoveAll(user.GetHomeDir())
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error removing uploaded files")
|
|
|
+ }
|
|
|
+ _, err = api.RemoveUser(user, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to remove user: %v", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestSCPQuotaSize(t *testing.T) {
|
|
|
+ if len(scpPath) == 0 {
|
|
|
+ t.Skip("scp command not found, unable to execute this test")
|
|
|
+ }
|
|
|
+ usePubKey := true
|
|
|
+ testFileSize := int64(65535)
|
|
|
+ u := getTestUser(usePubKey)
|
|
|
+ u.QuotaFiles = 1
|
|
|
+ u.QuotaSize = testFileSize - 1
|
|
|
+ user, _, err := api.AddUser(u, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to add user: %v", err)
|
|
|
+ }
|
|
|
+ testFileName := "test_file.dat"
|
|
|
+ testFilePath := filepath.Join(homeBasePath, testFileName)
|
|
|
+ err = createTestFile(testFilePath, testFileSize)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to create test file: %v", err)
|
|
|
+ }
|
|
|
+ remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
|
|
|
+ err = scpUpload(testFilePath, remoteUpPath, true)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error uploading existing file via scp: %v", err)
|
|
|
+ }
|
|
|
+ err = scpUpload(testFilePath, remoteUpPath+".quota", true)
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("user is over quota scp upload must fail")
|
|
|
+ }
|
|
|
+ err = os.Remove(testFilePath)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error removing test file")
|
|
|
+ }
|
|
|
+ err = os.RemoveAll(user.GetHomeDir())
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error removing uploaded files")
|
|
|
+ }
|
|
|
+ _, err = api.RemoveUser(user, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to remove user: %v", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestSCPEscapeHomeDir(t *testing.T) {
|
|
|
+ if len(scpPath) == 0 {
|
|
|
+ t.Skip("scp command not found, unable to execute this test")
|
|
|
+ }
|
|
|
+ usePubKey := true
|
|
|
+ user, _, err := api.AddUser(getTestUser(usePubKey), http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to add user: %v", err)
|
|
|
+ }
|
|
|
+ os.MkdirAll(user.GetHomeDir(), 0777)
|
|
|
+ testDir := "testDir"
|
|
|
+ linkPath := filepath.Join(homeBasePath, defaultUsername, testDir)
|
|
|
+ err = os.Symlink(homeBasePath, linkPath)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error making local symlink: %v", err)
|
|
|
+ }
|
|
|
+ testFileName := "test_file.dat"
|
|
|
+ testFilePath := filepath.Join(homeBasePath, testFileName)
|
|
|
+ testFileSize := int64(65535)
|
|
|
+ err = createTestFile(testFilePath, testFileSize)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to create test file: %v", err)
|
|
|
+ }
|
|
|
+ remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDir, testDir))
|
|
|
+ err = scpUpload(testFilePath, remoteUpPath, false)
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("uploading to a dir with a symlink outside home dir must fail")
|
|
|
+ }
|
|
|
+ remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testDir, testFileName))
|
|
|
+ localPath := filepath.Join(homeBasePath, "scp_download.dat")
|
|
|
+ err = scpDownload(localPath, remoteDownPath, false, false)
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("scp download must fail, the requested file has a symlink outside user home")
|
|
|
+ }
|
|
|
+ remoteDownPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testDir))
|
|
|
+ err = scpDownload(homeBasePath, remoteDownPath, false, true)
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("scp download must fail, the requested dir is a symlink outside user home")
|
|
|
+ }
|
|
|
+ err = os.Remove(testFilePath)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error removing test file")
|
|
|
+ }
|
|
|
+ err = os.RemoveAll(user.GetHomeDir())
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error removing uploaded files")
|
|
|
+ }
|
|
|
+ _, err = api.RemoveUser(user, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to remove user: %v", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestSCPUploadPaths(t *testing.T) {
|
|
|
+ if len(scpPath) == 0 {
|
|
|
+ t.Skip("scp command not found, unable to execute this test")
|
|
|
+ }
|
|
|
+ usePubKey := true
|
|
|
+ user, _, err := api.AddUser(getTestUser(usePubKey), http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to add user: %v", err)
|
|
|
+ }
|
|
|
+ testFileName := "test_file.dat"
|
|
|
+ testFilePath := filepath.Join(homeBasePath, testFileName)
|
|
|
+ testFileSize := int64(65535)
|
|
|
+ testDirName := "testDir"
|
|
|
+ testDirPath := filepath.Join(user.GetHomeDir(), testDirName)
|
|
|
+ os.MkdirAll(testDirPath, 0777)
|
|
|
+ err = createTestFile(testFilePath, testFileSize)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to create test file: %v", err)
|
|
|
+ }
|
|
|
+ remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, testDirName)
|
|
|
+ remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDirName, testFileName))
|
|
|
+ localPath := filepath.Join(homeBasePath, "scp_download.dat")
|
|
|
+ err = scpUpload(testFilePath, remoteUpPath, false)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("scp upload error: %v", err)
|
|
|
+ }
|
|
|
+ err = scpDownload(localPath, remoteDownPath, false, false)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("scp download error: %v", err)
|
|
|
+ }
|
|
|
+ // upload a file to a missing dir
|
|
|
+ remoteUpPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDirName, testDirName, testFileName))
|
|
|
+ err = scpUpload(testFilePath, remoteUpPath, false)
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("scp upload to a missing dir must fail")
|
|
|
+ }
|
|
|
+ err = os.RemoveAll(user.GetHomeDir())
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error removing uploaded files")
|
|
|
+ }
|
|
|
+ _, err = api.RemoveUser(user, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to remove user: %v", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestSCPOverwriteDirWithFile(t *testing.T) {
|
|
|
+ if len(scpPath) == 0 {
|
|
|
+ t.Skip("scp command not found, unable to execute this test")
|
|
|
+ }
|
|
|
+ usePubKey := true
|
|
|
+ user, _, err := api.AddUser(getTestUser(usePubKey), http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to add user: %v", err)
|
|
|
+ }
|
|
|
+ testFileName := "test_file.dat"
|
|
|
+ testFilePath := filepath.Join(homeBasePath, testFileName)
|
|
|
+ testFileSize := int64(65535)
|
|
|
+ testDirPath := filepath.Join(user.GetHomeDir(), testFileName)
|
|
|
+ os.MkdirAll(testDirPath, 0777)
|
|
|
+ err = createTestFile(testFilePath, testFileSize)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to create test file: %v", err)
|
|
|
+ }
|
|
|
+ remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
|
|
|
+ err = scpUpload(testFilePath, remoteUpPath, false)
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("copying a file over an existing dir must fail")
|
|
|
+ }
|
|
|
+ err = os.RemoveAll(user.GetHomeDir())
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error removing uploaded files")
|
|
|
+ }
|
|
|
+ _, err = api.RemoveUser(user, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to remove user: %v", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// End SCP tests
|
|
|
+
|
|
|
func waitTCPListening(address string) {
|
|
|
for {
|
|
|
conn, err := net.Dial("tcp", address)
|
|
@@ -1487,6 +2008,10 @@ func getSftpClient(user dataprovider.User, usePubKey bool) (*sftp.Client, error)
|
|
|
}
|
|
|
|
|
|
func createTestFile(path string, size int64) error {
|
|
|
+ baseDir := filepath.Dir(path)
|
|
|
+ if _, err := os.Stat(baseDir); os.IsNotExist(err) {
|
|
|
+ os.MkdirAll(baseDir, 0777)
|
|
|
+ }
|
|
|
content := make([]byte, size)
|
|
|
_, err := rand.Read(content)
|
|
|
if err != nil {
|
|
@@ -1572,6 +2097,49 @@ func sftpDownloadNonBlocking(remoteSourcePath string, localDestPath string, expe
|
|
|
return c
|
|
|
}
|
|
|
|
|
|
+func scpUpload(localPath, remotePath string, preserveTime bool) error {
|
|
|
+ var args []string
|
|
|
+ if preserveTime {
|
|
|
+ args = append(args, "-p")
|
|
|
+ }
|
|
|
+ fi, err := os.Stat(localPath)
|
|
|
+ if err == nil {
|
|
|
+ if fi.IsDir() {
|
|
|
+ args = append(args, "-r")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ args = append(args, "-P")
|
|
|
+ args = append(args, "2022")
|
|
|
+ args = append(args, "-o")
|
|
|
+ args = append(args, "StrictHostKeyChecking=no")
|
|
|
+ args = append(args, "-i")
|
|
|
+ args = append(args, privateKeyPath)
|
|
|
+ args = append(args, localPath)
|
|
|
+ args = append(args, remotePath)
|
|
|
+ cmd := exec.Command(scpPath, args...)
|
|
|
+ return cmd.Run()
|
|
|
+}
|
|
|
+
|
|
|
+func scpDownload(localPath, remotePath string, preserveTime, recursive bool) error {
|
|
|
+ var args []string
|
|
|
+ if preserveTime {
|
|
|
+ args = append(args, "-p")
|
|
|
+ }
|
|
|
+ if recursive {
|
|
|
+ args = append(args, "-r")
|
|
|
+ }
|
|
|
+ args = append(args, "-P")
|
|
|
+ args = append(args, "2022")
|
|
|
+ args = append(args, "-o")
|
|
|
+ args = append(args, "StrictHostKeyChecking=no")
|
|
|
+ args = append(args, "-i")
|
|
|
+ args = append(args, privateKeyPath)
|
|
|
+ args = append(args, remotePath)
|
|
|
+ args = append(args, localPath)
|
|
|
+ cmd := exec.Command(scpPath, args...)
|
|
|
+ return cmd.Run()
|
|
|
+}
|
|
|
+
|
|
|
func waitForActiveTransfer() {
|
|
|
stats := sftpd.GetConnectionsStats()
|
|
|
for len(stats) < 1 {
|