diff --git a/api/api_test.go b/api/api_test.go index 7bb57d80..5de07c07 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -96,50 +96,6 @@ func TestMain(m *testing.M) { os.Exit(exitCode) } -func waitTCPListening(address string) { - for { - conn, err := net.Dial("tcp", address) - if err != nil { - fmt.Printf("tcp server %v not listening: %v\n", address, err) - time.Sleep(100 * time.Millisecond) - continue - } - fmt.Printf("tcp server %v now listening\n", address) - defer conn.Close() - break - } -} - -func getTestUser() dataprovider.User { - return dataprovider.User{ - Username: defaultUsername, - Password: defaultPassword, - HomeDir: filepath.Join(homeBasePath, defaultUsername), - Permissions: defaultPerms, - } -} - -func getUserAsJSON(t *testing.T, user dataprovider.User) []byte { - json, err := json.Marshal(user) - if err != nil { - t.Errorf("error get user as json: %v", err) - return []byte("{}") - } - return json -} - -func executeRequest(req *http.Request) *httptest.ResponseRecorder { - rr := httptest.NewRecorder() - testServer.Config.Handler.ServeHTTP(rr, req) - return rr -} - -func checkResponseCode(t *testing.T, expected, actual int) { - if expected != actual { - t.Errorf("Expected response code %d. Got %d", expected, actual) - } -} - func TestBasicUserHandling(t *testing.T) { user, err := api.AddUser(getTestUser(), http.StatusOK) if err != nil { @@ -669,3 +625,47 @@ func TestMethodNotAllowedMock(t *testing.T) { rr := executeRequest(req) checkResponseCode(t, http.StatusMethodNotAllowed, rr.Code) } + +func waitTCPListening(address string) { + for { + conn, err := net.Dial("tcp", address) + if err != nil { + fmt.Printf("tcp server %v not listening: %v\n", address, err) + time.Sleep(100 * time.Millisecond) + continue + } + fmt.Printf("tcp server %v now listening\n", address) + defer conn.Close() + break + } +} + +func getTestUser() dataprovider.User { + return dataprovider.User{ + Username: defaultUsername, + Password: defaultPassword, + HomeDir: filepath.Join(homeBasePath, defaultUsername), + Permissions: defaultPerms, + } +} + +func getUserAsJSON(t *testing.T, user dataprovider.User) []byte { + json, err := json.Marshal(user) + if err != nil { + t.Errorf("error get user as json: %v", err) + return []byte("{}") + } + return json +} + +func executeRequest(req *http.Request) *httptest.ResponseRecorder { + rr := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + return rr +} + +func checkResponseCode(t *testing.T, expected, actual int) { + if expected != actual { + t.Errorf("Expected response code %d. Got %d", expected, actual) + } +} diff --git a/sftpd/sftpd.go b/sftpd/sftpd.go index 0fdbe897..ae56d3f1 100644 --- a/sftpd/sftpd.go +++ b/sftpd/sftpd.go @@ -185,12 +185,13 @@ func startIdleTimer(maxIdleTime time.Duration) { go func() { for t := range idleConnectionTicker.C { logger.Debug(logSender, "idle connections check ticker %v", t) - checkIdleConnections() + CheckIdleConnections() } }() } -func checkIdleConnections() { +// CheckIdleConnections disconnects idle clients +func CheckIdleConnections() { mutex.RLock() defer mutex.RUnlock() for _, c := range openConnections { diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index eb0f3df9..53260f26 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -135,123 +135,6 @@ func TestMain(m *testing.M) { os.Exit(exitCode) } -func waitTCPListening(address string) { - for { - conn, err := net.Dial("tcp", address) - if err != nil { - fmt.Printf("tcp server %v not listening: %v\n", address, err) - time.Sleep(100 * time.Millisecond) - continue - } - fmt.Printf("tcp server %v now listening\n", address) - defer conn.Close() - break - } -} - -func getSftpClient(user dataprovider.User, usePubKey bool) (*sftp.Client, error) { - var sftpClient *sftp.Client - config := &ssh.ClientConfig{ - User: defaultUsername, - HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { - return nil - }, - } - if usePubKey { - key, err := ssh.ParsePrivateKey([]byte(testPrivateKey)) - if err != nil { - return nil, err - } - config.Auth = []ssh.AuthMethod{ssh.PublicKeys(key)} - } else { - config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)} - } - conn, err := ssh.Dial("tcp", sftpServerAddr, config) - if err != nil { - return sftpClient, err - } - sftpClient, err = sftp.NewClient(conn) - return sftpClient, err -} - -func createTestFile(path string, size int64) error { - content := make([]byte, size) - _, err := rand.Read(content) - if err != nil { - return err - } - return ioutil.WriteFile(path, content, 0666) -} - -func getTestUser(usePubKey bool) dataprovider.User { - user := dataprovider.User{ - Username: defaultUsername, - Password: defaultPassword, - HomeDir: filepath.Join(homeBasePath, defaultUsername), - Permissions: allPerms, - } - if usePubKey { - user.PublicKey = testPubKey - user.Password = "" - } - return user -} - -func sftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) error { - srcFile, err := os.Open(localSourcePath) - if err != nil { - return err - } - defer srcFile.Close() - destFile, err := client.Create(remoteDestPath) - if err != nil { - return err - } - defer destFile.Close() - _, err = io.Copy(destFile, srcFile) - if expectedSize > 0 { - fi, err := client.Lstat(remoteDestPath) - if err != nil { - return err - } - if fi.Size() != expectedSize { - return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize) - } - } - return err -} - -func sftpDownloadFile(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) error { - downloadDest, err := os.Create(localDestPath) - if err != nil { - return err - } - defer downloadDest.Close() - sftpSrcFile, err := client.Open(remoteSourcePath) - if err != nil { - return err - } - defer sftpSrcFile.Close() - _, err = io.Copy(downloadDest, sftpSrcFile) - if err != nil { - return err - } - err = downloadDest.Sync() - if err != nil { - return err - } - if expectedSize > 0 { - fi, err := downloadDest.Stat() - if err != nil { - return err - } - if fi.Size() != expectedSize { - return fmt.Errorf("downloaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize) - } - } - return err -} - func TestBasicSFTPHandling(t *testing.T) { usePubKey := false user, err := api.AddUser(getTestUser(usePubKey), http.StatusOK) @@ -647,7 +530,6 @@ func TestQuotaScan(t *testing.T) { defer client.Close() testFileName := "test_file.dat" testFilePath := filepath.Join(homeBasePath, testFileName) - testFileSize := int64(65535) err = createTestFile(testFilePath, testFileSize) if err != nil { t.Errorf("unable to create test file: %v", err) @@ -697,6 +579,118 @@ func TestQuotaScan(t *testing.T) { } } +func TestMultipleQuotaScans(t *testing.T) { + if !sftpd.AddQuotaScan(defaultUsername) { + t.Errorf("add quota failed") + } + if sftpd.AddQuotaScan(defaultUsername) { + t.Errorf("add quota must fail if another scan is already active") + } + sftpd.RemoveQuotaScan(defaultPassword) +} + +func TestQuotaSize(t *testing.T) { + usePubKey := false + testFileSize := int64(65535) + u := getTestUser(usePubKey) + u.QuotaFiles = 1 + u.QuotaSize = testFileSize - 1 + user, err := api.AddUser(u, http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + testFileName := "test_file.dat" + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + if err != nil { + t.Errorf("unable to create test file: %v", err) + } + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + if err != nil { + t.Errorf("file upload error: %v", err) + } + err = sftpUploadFile(testFilePath, testFileName+".1", testFileSize, client) + if err == nil { + t.Errorf("user is over quota file upload must fail") + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + +func TestBandwidthAndConnections(t *testing.T) { + usePubKey := false + testFileSize := int64(131072) + u := getTestUser(usePubKey) + u.UploadBandwidth = 30 + u.DownloadBandwidth = 25 + wantedUploadElapsed := 1000 * (testFileSize / 1000) / u.UploadBandwidth + wantedDownloadElapsed := 1000 * (testFileSize / 1000) / u.DownloadBandwidth + // 100 ms tolerance + wantedUploadElapsed -= 100 + wantedDownloadElapsed -= 100 + user, err := api.AddUser(u, http.StatusOK) + if err != nil { + t.Errorf("unable to add user: %v", err) + } + client, err := getSftpClient(user, usePubKey) + if err != nil { + t.Errorf("unable to create sftp client: %v", err) + } else { + defer client.Close() + testFileName := "test_file.dat" + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + if err != nil { + t.Errorf("unable to create test file: %v", err) + } + startTime := time.Now() + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + if err != nil { + t.Errorf("file upload error: %v", err) + } + elapsed := time.Since(startTime).Nanoseconds() / 1000000 + if elapsed < (wantedUploadElapsed) { + t.Errorf("upload bandwidth throttling not respected, elapsed: %v, wanted: %v", elapsed, wantedUploadElapsed) + } + startTime = time.Now() + localDownloadPath := filepath.Join(homeBasePath, "test_download.dat") + c := sftpDownloadNonBlocking(testFileName, localDownloadPath, testFileSize, client) + waitForActiveTransfer() + err = <-c + if err != nil { + t.Errorf("file download error: %v", err) + } + elapsed = time.Since(startTime).Nanoseconds() / 1000000 + if elapsed < (wantedDownloadElapsed) { + t.Errorf("download bandwidth throttling not respected, elapsed: %v, wanted: %v", elapsed, wantedDownloadElapsed) + } + // test disconnection + c = sftpUploadNonBlocking(testFilePath, testFileName, testFileSize, client) + waitForActiveTransfer() + sftpd.CheckIdleConnections() + stats := sftpd.GetConnectionsStats() + for _, stat := range stats { + sftpd.CloseActiveConnection(stat.ConnectionID) + } + err = <-c + if err == nil { + t.Errorf("connection closed upload must fail") + } + } + err = api.RemoveUser(user, http.StatusOK) + if err != nil { + t.Errorf("unable to remove user: %v", err) + } +} + func TestPermList(t *testing.T) { usePubKey := true u := getTestUser(usePubKey) @@ -935,3 +929,155 @@ func TestPermSymlink(t *testing.T) { t.Errorf("unable to remove user: %v", err) } } + +func waitTCPListening(address string) { + for { + conn, err := net.Dial("tcp", address) + if err != nil { + fmt.Printf("tcp server %v not listening: %v\n", address, err) + time.Sleep(100 * time.Millisecond) + continue + } + fmt.Printf("tcp server %v now listening\n", address) + defer conn.Close() + break + } +} + +func getSftpClient(user dataprovider.User, usePubKey bool) (*sftp.Client, error) { + var sftpClient *sftp.Client + config := &ssh.ClientConfig{ + User: defaultUsername, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + if usePubKey { + key, err := ssh.ParsePrivateKey([]byte(testPrivateKey)) + if err != nil { + return nil, err + } + config.Auth = []ssh.AuthMethod{ssh.PublicKeys(key)} + } else { + config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)} + } + conn, err := ssh.Dial("tcp", sftpServerAddr, config) + if err != nil { + return sftpClient, err + } + sftpClient, err = sftp.NewClient(conn) + return sftpClient, err +} + +func createTestFile(path string, size int64) error { + content := make([]byte, size) + _, err := rand.Read(content) + if err != nil { + return err + } + return ioutil.WriteFile(path, content, 0666) +} + +func getTestUser(usePubKey bool) dataprovider.User { + user := dataprovider.User{ + Username: defaultUsername, + Password: defaultPassword, + HomeDir: filepath.Join(homeBasePath, defaultUsername), + Permissions: allPerms, + } + if usePubKey { + user.PublicKey = testPubKey + user.Password = "" + } + return user +} + +func sftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) error { + srcFile, err := os.Open(localSourcePath) + if err != nil { + return err + } + defer srcFile.Close() + destFile, err := client.Create(remoteDestPath) + if err != nil { + return err + } + defer destFile.Close() + _, err = io.Copy(destFile, srcFile) + if expectedSize > 0 { + fi, err := client.Lstat(remoteDestPath) + if err != nil { + return err + } + if fi.Size() != expectedSize { + return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize) + } + } + return err +} + +func sftpDownloadFile(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) error { + downloadDest, err := os.Create(localDestPath) + if err != nil { + return err + } + defer downloadDest.Close() + sftpSrcFile, err := client.Open(remoteSourcePath) + if err != nil { + return err + } + defer sftpSrcFile.Close() + _, err = io.Copy(downloadDest, sftpSrcFile) + if err != nil { + return err + } + err = downloadDest.Sync() + if err != nil { + return err + } + if expectedSize > 0 { + fi, err := downloadDest.Stat() + if err != nil { + return err + } + if fi.Size() != expectedSize { + return fmt.Errorf("downloaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize) + } + } + return err +} + +func sftpUploadNonBlocking(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) <-chan error { + c := make(chan error) + go func() { + c <- sftpUploadFile(localSourcePath, remoteDestPath, expectedSize, client) + }() + return c +} + +func sftpDownloadNonBlocking(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) <-chan error { + c := make(chan error) + go func() { + c <- sftpDownloadFile(remoteSourcePath, localDestPath, expectedSize, client) + }() + return c +} + +func waitForActiveTransfer() { + stats := sftpd.GetConnectionsStats() + for len(stats) < 1 { + stats = sftpd.GetConnectionsStats() + } + activeTransferFound := false + for !activeTransferFound { + stats = sftpd.GetConnectionsStats() + if len(stats) == 0 { + break + } + for _, stat := range stats { + if len(stat.Transfers) > 0 { + activeTransferFound = true + } + } + } +}