Browse Source

add more test cases

Nicola Murino 6 years ago
parent
commit
9f61415832
3 changed files with 311 additions and 164 deletions
  1. 44 44
      api/api_test.go
  2. 3 2
      sftpd/sftpd.go
  3. 264 118
      sftpd/sftpd_test.go

+ 44 - 44
api/api_test.go

@@ -96,50 +96,6 @@ func TestMain(m *testing.M) {
 	os.Exit(exitCode)
 }
 
-func waitTCPListening(address string) {
-	for {
-		conn, err := net.Dial("tcp", address)
-		if err != nil {
-			fmt.Printf("tcp server %v not listening: %v\n", address, err)
-			time.Sleep(100 * time.Millisecond)
-			continue
-		}
-		fmt.Printf("tcp server %v now listening\n", address)
-		defer conn.Close()
-		break
-	}
-}
-
-func getTestUser() dataprovider.User {
-	return dataprovider.User{
-		Username:    defaultUsername,
-		Password:    defaultPassword,
-		HomeDir:     filepath.Join(homeBasePath, defaultUsername),
-		Permissions: defaultPerms,
-	}
-}
-
-func getUserAsJSON(t *testing.T, user dataprovider.User) []byte {
-	json, err := json.Marshal(user)
-	if err != nil {
-		t.Errorf("error get user as json: %v", err)
-		return []byte("{}")
-	}
-	return json
-}
-
-func executeRequest(req *http.Request) *httptest.ResponseRecorder {
-	rr := httptest.NewRecorder()
-	testServer.Config.Handler.ServeHTTP(rr, req)
-	return rr
-}
-
-func checkResponseCode(t *testing.T, expected, actual int) {
-	if expected != actual {
-		t.Errorf("Expected response code %d. Got %d", expected, actual)
-	}
-}
-
 func TestBasicUserHandling(t *testing.T) {
 	user, err := api.AddUser(getTestUser(), http.StatusOK)
 	if err != nil {
@@ -669,3 +625,47 @@ func TestMethodNotAllowedMock(t *testing.T) {
 	rr := executeRequest(req)
 	checkResponseCode(t, http.StatusMethodNotAllowed, rr.Code)
 }
+
+func waitTCPListening(address string) {
+	for {
+		conn, err := net.Dial("tcp", address)
+		if err != nil {
+			fmt.Printf("tcp server %v not listening: %v\n", address, err)
+			time.Sleep(100 * time.Millisecond)
+			continue
+		}
+		fmt.Printf("tcp server %v now listening\n", address)
+		defer conn.Close()
+		break
+	}
+}
+
+func getTestUser() dataprovider.User {
+	return dataprovider.User{
+		Username:    defaultUsername,
+		Password:    defaultPassword,
+		HomeDir:     filepath.Join(homeBasePath, defaultUsername),
+		Permissions: defaultPerms,
+	}
+}
+
+func getUserAsJSON(t *testing.T, user dataprovider.User) []byte {
+	json, err := json.Marshal(user)
+	if err != nil {
+		t.Errorf("error get user as json: %v", err)
+		return []byte("{}")
+	}
+	return json
+}
+
+func executeRequest(req *http.Request) *httptest.ResponseRecorder {
+	rr := httptest.NewRecorder()
+	testServer.Config.Handler.ServeHTTP(rr, req)
+	return rr
+}
+
+func checkResponseCode(t *testing.T, expected, actual int) {
+	if expected != actual {
+		t.Errorf("Expected response code %d. Got %d", expected, actual)
+	}
+}

+ 3 - 2
sftpd/sftpd.go

@@ -185,12 +185,13 @@ func startIdleTimer(maxIdleTime time.Duration) {
 	go func() {
 		for t := range idleConnectionTicker.C {
 			logger.Debug(logSender, "idle connections check ticker %v", t)
-			checkIdleConnections()
+			CheckIdleConnections()
 		}
 	}()
 }
 
-func checkIdleConnections() {
+// CheckIdleConnections disconnects idle clients
+func CheckIdleConnections() {
 	mutex.RLock()
 	defer mutex.RUnlock()
 	for _, c := range openConnections {

+ 264 - 118
sftpd/sftpd_test.go

@@ -135,123 +135,6 @@ func TestMain(m *testing.M) {
 	os.Exit(exitCode)
 }
 
-func waitTCPListening(address string) {
-	for {
-		conn, err := net.Dial("tcp", address)
-		if err != nil {
-			fmt.Printf("tcp server %v not listening: %v\n", address, err)
-			time.Sleep(100 * time.Millisecond)
-			continue
-		}
-		fmt.Printf("tcp server %v now listening\n", address)
-		defer conn.Close()
-		break
-	}
-}
-
-func getSftpClient(user dataprovider.User, usePubKey bool) (*sftp.Client, error) {
-	var sftpClient *sftp.Client
-	config := &ssh.ClientConfig{
-		User: defaultUsername,
-		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
-			return nil
-		},
-	}
-	if usePubKey {
-		key, err := ssh.ParsePrivateKey([]byte(testPrivateKey))
-		if err != nil {
-			return nil, err
-		}
-		config.Auth = []ssh.AuthMethod{ssh.PublicKeys(key)}
-	} else {
-		config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)}
-	}
-	conn, err := ssh.Dial("tcp", sftpServerAddr, config)
-	if err != nil {
-		return sftpClient, err
-	}
-	sftpClient, err = sftp.NewClient(conn)
-	return sftpClient, err
-}
-
-func createTestFile(path string, size int64) error {
-	content := make([]byte, size)
-	_, err := rand.Read(content)
-	if err != nil {
-		return err
-	}
-	return ioutil.WriteFile(path, content, 0666)
-}
-
-func getTestUser(usePubKey bool) dataprovider.User {
-	user := dataprovider.User{
-		Username:    defaultUsername,
-		Password:    defaultPassword,
-		HomeDir:     filepath.Join(homeBasePath, defaultUsername),
-		Permissions: allPerms,
-	}
-	if usePubKey {
-		user.PublicKey = testPubKey
-		user.Password = ""
-	}
-	return user
-}
-
-func sftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) error {
-	srcFile, err := os.Open(localSourcePath)
-	if err != nil {
-		return err
-	}
-	defer srcFile.Close()
-	destFile, err := client.Create(remoteDestPath)
-	if err != nil {
-		return err
-	}
-	defer destFile.Close()
-	_, err = io.Copy(destFile, srcFile)
-	if expectedSize > 0 {
-		fi, err := client.Lstat(remoteDestPath)
-		if err != nil {
-			return err
-		}
-		if fi.Size() != expectedSize {
-			return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize)
-		}
-	}
-	return err
-}
-
-func sftpDownloadFile(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) error {
-	downloadDest, err := os.Create(localDestPath)
-	if err != nil {
-		return err
-	}
-	defer downloadDest.Close()
-	sftpSrcFile, err := client.Open(remoteSourcePath)
-	if err != nil {
-		return err
-	}
-	defer sftpSrcFile.Close()
-	_, err = io.Copy(downloadDest, sftpSrcFile)
-	if err != nil {
-		return err
-	}
-	err = downloadDest.Sync()
-	if err != nil {
-		return err
-	}
-	if expectedSize > 0 {
-		fi, err := downloadDest.Stat()
-		if err != nil {
-			return err
-		}
-		if fi.Size() != expectedSize {
-			return fmt.Errorf("downloaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize)
-		}
-	}
-	return err
-}
-
 func TestBasicSFTPHandling(t *testing.T) {
 	usePubKey := false
 	user, err := api.AddUser(getTestUser(usePubKey), http.StatusOK)
@@ -647,7 +530,6 @@ func TestQuotaScan(t *testing.T) {
 		defer client.Close()
 		testFileName := "test_file.dat"
 		testFilePath := filepath.Join(homeBasePath, testFileName)
-		testFileSize := int64(65535)
 		err = createTestFile(testFilePath, testFileSize)
 		if err != nil {
 			t.Errorf("unable to create test file: %v", err)
@@ -697,6 +579,118 @@ func TestQuotaScan(t *testing.T) {
 	}
 }
 
+func TestMultipleQuotaScans(t *testing.T) {
+	if !sftpd.AddQuotaScan(defaultUsername) {
+		t.Errorf("add quota failed")
+	}
+	if sftpd.AddQuotaScan(defaultUsername) {
+		t.Errorf("add quota must fail if another scan is already active")
+	}
+	sftpd.RemoveQuotaScan(defaultPassword)
+}
+
+func TestQuotaSize(t *testing.T) {
+	usePubKey := false
+	testFileSize := int64(65535)
+	u := getTestUser(usePubKey)
+	u.QuotaFiles = 1
+	u.QuotaSize = testFileSize - 1
+	user, err := api.AddUser(u, http.StatusOK)
+	if err != nil {
+		t.Errorf("unable to add user: %v", err)
+	}
+	client, err := getSftpClient(user, usePubKey)
+	if err != nil {
+		t.Errorf("unable to create sftp client: %v", err)
+	} else {
+		defer client.Close()
+		testFileName := "test_file.dat"
+		testFilePath := filepath.Join(homeBasePath, testFileName)
+		err = createTestFile(testFilePath, testFileSize)
+		if err != nil {
+			t.Errorf("unable to create test file: %v", err)
+		}
+		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
+		if err != nil {
+			t.Errorf("file upload error: %v", err)
+		}
+		err = sftpUploadFile(testFilePath, testFileName+".1", testFileSize, client)
+		if err == nil {
+			t.Errorf("user is over quota file upload must fail")
+		}
+	}
+	err = api.RemoveUser(user, http.StatusOK)
+	if err != nil {
+		t.Errorf("unable to remove user: %v", err)
+	}
+}
+
+func TestBandwidthAndConnections(t *testing.T) {
+	usePubKey := false
+	testFileSize := int64(131072)
+	u := getTestUser(usePubKey)
+	u.UploadBandwidth = 30
+	u.DownloadBandwidth = 25
+	wantedUploadElapsed := 1000 * (testFileSize / 1000) / u.UploadBandwidth
+	wantedDownloadElapsed := 1000 * (testFileSize / 1000) / u.DownloadBandwidth
+	// 100 ms tolerance
+	wantedUploadElapsed -= 100
+	wantedDownloadElapsed -= 100
+	user, err := api.AddUser(u, http.StatusOK)
+	if err != nil {
+		t.Errorf("unable to add user: %v", err)
+	}
+	client, err := getSftpClient(user, usePubKey)
+	if err != nil {
+		t.Errorf("unable to create sftp client: %v", err)
+	} else {
+		defer client.Close()
+		testFileName := "test_file.dat"
+		testFilePath := filepath.Join(homeBasePath, testFileName)
+		err = createTestFile(testFilePath, testFileSize)
+		if err != nil {
+			t.Errorf("unable to create test file: %v", err)
+		}
+		startTime := time.Now()
+		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
+		if err != nil {
+			t.Errorf("file upload error: %v", err)
+		}
+		elapsed := time.Since(startTime).Nanoseconds() / 1000000
+		if elapsed < (wantedUploadElapsed) {
+			t.Errorf("upload bandwidth throttling not respected, elapsed: %v, wanted: %v", elapsed, wantedUploadElapsed)
+		}
+		startTime = time.Now()
+		localDownloadPath := filepath.Join(homeBasePath, "test_download.dat")
+		c := sftpDownloadNonBlocking(testFileName, localDownloadPath, testFileSize, client)
+		waitForActiveTransfer()
+		err = <-c
+		if err != nil {
+			t.Errorf("file download error: %v", err)
+		}
+		elapsed = time.Since(startTime).Nanoseconds() / 1000000
+		if elapsed < (wantedDownloadElapsed) {
+			t.Errorf("download bandwidth throttling not respected, elapsed: %v, wanted: %v", elapsed, wantedDownloadElapsed)
+		}
+		// test disconnection
+		c = sftpUploadNonBlocking(testFilePath, testFileName, testFileSize, client)
+		waitForActiveTransfer()
+		sftpd.CheckIdleConnections()
+		stats := sftpd.GetConnectionsStats()
+		for _, stat := range stats {
+			sftpd.CloseActiveConnection(stat.ConnectionID)
+		}
+		err = <-c
+		if err == nil {
+			t.Errorf("connection closed upload must fail")
+		}
+	}
+	err = api.RemoveUser(user, http.StatusOK)
+	if err != nil {
+		t.Errorf("unable to remove user: %v", err)
+	}
+}
+
 func TestPermList(t *testing.T) {
 	usePubKey := true
 	u := getTestUser(usePubKey)
@@ -935,3 +929,155 @@ func TestPermSymlink(t *testing.T) {
 		t.Errorf("unable to remove user: %v", err)
 	}
 }
+
+func waitTCPListening(address string) {
+	for {
+		conn, err := net.Dial("tcp", address)
+		if err != nil {
+			fmt.Printf("tcp server %v not listening: %v\n", address, err)
+			time.Sleep(100 * time.Millisecond)
+			continue
+		}
+		fmt.Printf("tcp server %v now listening\n", address)
+		defer conn.Close()
+		break
+	}
+}
+
+func getSftpClient(user dataprovider.User, usePubKey bool) (*sftp.Client, error) {
+	var sftpClient *sftp.Client
+	config := &ssh.ClientConfig{
+		User: defaultUsername,
+		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
+			return nil
+		},
+	}
+	if usePubKey {
+		key, err := ssh.ParsePrivateKey([]byte(testPrivateKey))
+		if err != nil {
+			return nil, err
+		}
+		config.Auth = []ssh.AuthMethod{ssh.PublicKeys(key)}
+	} else {
+		config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)}
+	}
+	conn, err := ssh.Dial("tcp", sftpServerAddr, config)
+	if err != nil {
+		return sftpClient, err
+	}
+	sftpClient, err = sftp.NewClient(conn)
+	return sftpClient, err
+}
+
+func createTestFile(path string, size int64) error {
+	content := make([]byte, size)
+	_, err := rand.Read(content)
+	if err != nil {
+		return err
+	}
+	return ioutil.WriteFile(path, content, 0666)
+}
+
+func getTestUser(usePubKey bool) dataprovider.User {
+	user := dataprovider.User{
+		Username:    defaultUsername,
+		Password:    defaultPassword,
+		HomeDir:     filepath.Join(homeBasePath, defaultUsername),
+		Permissions: allPerms,
+	}
+	if usePubKey {
+		user.PublicKey = testPubKey
+		user.Password = ""
+	}
+	return user
+}
+
+func sftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) error {
+	srcFile, err := os.Open(localSourcePath)
+	if err != nil {
+		return err
+	}
+	defer srcFile.Close()
+	destFile, err := client.Create(remoteDestPath)
+	if err != nil {
+		return err
+	}
+	defer destFile.Close()
+	_, err = io.Copy(destFile, srcFile)
+	if expectedSize > 0 {
+		fi, err := client.Lstat(remoteDestPath)
+		if err != nil {
+			return err
+		}
+		if fi.Size() != expectedSize {
+			return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize)
+		}
+	}
+	return err
+}
+
+func sftpDownloadFile(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) error {
+	downloadDest, err := os.Create(localDestPath)
+	if err != nil {
+		return err
+	}
+	defer downloadDest.Close()
+	sftpSrcFile, err := client.Open(remoteSourcePath)
+	if err != nil {
+		return err
+	}
+	defer sftpSrcFile.Close()
+	_, err = io.Copy(downloadDest, sftpSrcFile)
+	if err != nil {
+		return err
+	}
+	err = downloadDest.Sync()
+	if err != nil {
+		return err
+	}
+	if expectedSize > 0 {
+		fi, err := downloadDest.Stat()
+		if err != nil {
+			return err
+		}
+		if fi.Size() != expectedSize {
+			return fmt.Errorf("downloaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize)
+		}
+	}
+	return err
+}
+
+func sftpUploadNonBlocking(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) <-chan error {
+	c := make(chan error)
+	go func() {
+		c <- sftpUploadFile(localSourcePath, remoteDestPath, expectedSize, client)
+	}()
+	return c
+}
+
+func sftpDownloadNonBlocking(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) <-chan error {
+	c := make(chan error)
+	go func() {
+		c <- sftpDownloadFile(remoteSourcePath, localDestPath, expectedSize, client)
+	}()
+	return c
+}
+
+func waitForActiveTransfer() {
+	stats := sftpd.GetConnectionsStats()
+	for len(stats) < 1 {
+		stats = sftpd.GetConnectionsStats()
+	}
+	activeTransferFound := false
+	for !activeTransferFound {
+		stats = sftpd.GetConnectionsStats()
+		if len(stats) == 0 {
+			break
+		}
+		for _, stat := range stats {
+			if len(stat.Transfers) > 0 {
+				activeTransferFound = true
+			}
+		}
+	}
+}