|
@@ -1,7 +1,9 @@
|
|
|
package sftpd_test
|
|
|
|
|
|
import (
|
|
|
+ "bytes"
|
|
|
"crypto/rand"
|
|
|
+ "crypto/sha256"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"io/ioutil"
|
|
@@ -112,10 +114,10 @@ func TestMain(m *testing.M) {
|
|
|
sftpdConf.LoginBannerFile = loginBannerFileName
|
|
|
// we need to test SCP support
|
|
|
sftpdConf.IsSCPEnabled = true
|
|
|
- // we run the test cases with UploadMode atomic. The non atomic code path
|
|
|
+ // we run the test cases with UploadMode atomic and resume support. The non atomic code path
|
|
|
// simply does not execute some code so if it works in atomic mode will
|
|
|
// work in non atomic mode too
|
|
|
- sftpdConf.UploadMode = 1
|
|
|
+ sftpdConf.UploadMode = 2
|
|
|
if runtime.GOOS == "windows" {
|
|
|
homeBasePath = "C:\\"
|
|
|
} else {
|
|
@@ -187,6 +189,7 @@ func TestBasicSFTPHandling(t *testing.T) {
|
|
|
if err != nil {
|
|
|
t.Errorf("unable to add user: %v", err)
|
|
|
}
|
|
|
+ os.RemoveAll(user.GetHomeDir())
|
|
|
client, err := getSftpClient(user, usePubKey)
|
|
|
if err != nil {
|
|
|
t.Errorf("unable to create sftp client: %v", err)
|
|
@@ -246,6 +249,67 @@ func TestBasicSFTPHandling(t *testing.T) {
|
|
|
os.RemoveAll(user.GetHomeDir())
|
|
|
}
|
|
|
|
|
|
+func TestUploadResume(t *testing.T) {
|
|
|
+ usePubKey := false
|
|
|
+ u := getTestUser(usePubKey)
|
|
|
+ user, _, err := httpd.AddUser(u, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to add user: %v", err)
|
|
|
+ }
|
|
|
+ os.RemoveAll(user.GetHomeDir())
|
|
|
+ client, err := getSftpClient(user, usePubKey)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to create sftp client: %v", err)
|
|
|
+ } else {
|
|
|
+ defer client.Close()
|
|
|
+ testFileName := "test_file.dat"
|
|
|
+ testFilePath := filepath.Join(homeBasePath, testFileName)
|
|
|
+ testFileSize := int64(65535)
|
|
|
+ appendDataSize := int64(65535)
|
|
|
+ err = createTestFile(testFilePath, testFileSize)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to create test file: %v", err)
|
|
|
+ }
|
|
|
+ err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("file upload error: %v", err)
|
|
|
+ }
|
|
|
+ err = appendToTestFile(testFilePath, appendDataSize)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to append to test file: %v", err)
|
|
|
+ }
|
|
|
+ err = sftpUploadResumeFile(testFilePath, testFileName, testFileSize+appendDataSize, false, client)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("file upload resume error: %v", err)
|
|
|
+ }
|
|
|
+ localDownloadPath := filepath.Join(homeBasePath, "test_download.dat")
|
|
|
+ err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize+appendDataSize, client)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("file download error: %v", err)
|
|
|
+ }
|
|
|
+ initialHash, err := computeFileHash(localDownloadPath)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error computing file hash: %v", err)
|
|
|
+ }
|
|
|
+ donwloadedFileHash, err := computeFileHash(localDownloadPath)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("error computing downloaded file hash: %v", err)
|
|
|
+ }
|
|
|
+ if donwloadedFileHash != initialHash {
|
|
|
+ t.Errorf("resume failed: file hash does not match")
|
|
|
+ }
|
|
|
+ err = sftpUploadResumeFile(testFilePath, testFileName, testFileSize+appendDataSize, true, client)
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("file upload resume with invalid offset must fail")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ _, err = httpd.RemoveUser(user, http.StatusOK)
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unable to remove user: %v", err)
|
|
|
+ }
|
|
|
+ os.RemoveAll(user.GetHomeDir())
|
|
|
+}
|
|
|
+
|
|
|
func TestDirCommands(t *testing.T) {
|
|
|
usePubKey := false
|
|
|
user, _, err := httpd.AddUser(getTestUser(usePubKey), http.StatusOK)
|
|
@@ -2301,6 +2365,26 @@ func createTestFile(path string, size int64) error {
|
|
|
return ioutil.WriteFile(path, content, 0666)
|
|
|
}
|
|
|
|
|
|
+func appendToTestFile(path string, size int64) error {
|
|
|
+ content := make([]byte, size)
|
|
|
+ _, err := rand.Read(content)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0666)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ written, err := io.Copy(f, bytes.NewReader(content))
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if int64(written) != size {
|
|
|
+ return fmt.Errorf("write error, written: %v/%v", written, size)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
func sftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) error {
|
|
|
srcFile, err := os.Open(localSourcePath)
|
|
|
if err != nil {
|
|
@@ -2331,6 +2415,53 @@ func sftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
+func sftpUploadResumeFile(localSourcePath string, remoteDestPath string, expectedSize int64, invalidOffset bool,
|
|
|
+ client *sftp.Client) error {
|
|
|
+ srcFile, err := os.Open(localSourcePath)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ defer srcFile.Close()
|
|
|
+ fi, err := client.Lstat(remoteDestPath)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if !invalidOffset {
|
|
|
+ _, err = srcFile.Seek(fi.Size(), 0)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+ destFile, err := client.OpenFile(remoteDestPath, os.O_WRONLY|os.O_APPEND)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if !invalidOffset {
|
|
|
+ _, err = destFile.Seek(fi.Size(), 0)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+ _, err = io.Copy(destFile, srcFile)
|
|
|
+ if err != nil {
|
|
|
+ destFile.Close()
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ // we need to close the file to trigger the close method on server
|
|
|
+ // we cannot defer closing or Lstat will fail for upload atomic mode
|
|
|
+ destFile.Close()
|
|
|
+ if expectedSize > 0 {
|
|
|
+ fi, err := client.Lstat(remoteDestPath)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if fi.Size() != expectedSize {
|
|
|
+ return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return err
|
|
|
+}
|
|
|
+
|
|
|
func sftpDownloadFile(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) error {
|
|
|
downloadDest, err := os.Create(localDestPath)
|
|
|
if err != nil {
|
|
@@ -2432,6 +2563,21 @@ func getScpUploadCommand(localPath, remotePath string, preserveTime, remoteToRem
|
|
|
return exec.Command(scpPath, args...)
|
|
|
}
|
|
|
|
|
|
+func computeFileHash(path string) (string, error) {
|
|
|
+ hash := ""
|
|
|
+ f, err := os.Open(path)
|
|
|
+ if err != nil {
|
|
|
+ return hash, err
|
|
|
+ }
|
|
|
+ defer f.Close()
|
|
|
+ h := sha256.New()
|
|
|
+ if _, err := io.Copy(h, f); err != nil {
|
|
|
+ return hash, err
|
|
|
+ }
|
|
|
+ hash = fmt.Sprintf("%x", h.Sum(nil))
|
|
|
+ return hash, err
|
|
|
+}
|
|
|
+
|
|
|
func waitForNoActiveTransfer() {
|
|
|
for len(sftpd.GetConnectionsStats()) > 0 {
|
|
|
time.Sleep(100 * time.Millisecond)
|