add more test cases
This commit is contained in:
parent
ff8fb80e3c
commit
9f61415832
3 changed files with 311 additions and 164 deletions
|
@ -96,50 +96,6 @@ func TestMain(m *testing.M) {
|
|||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
func waitTCPListening(address string) {
|
||||
for {
|
||||
conn, err := net.Dial("tcp", address)
|
||||
if err != nil {
|
||||
fmt.Printf("tcp server %v not listening: %v\n", address, err)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
fmt.Printf("tcp server %v now listening\n", address)
|
||||
defer conn.Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func getTestUser() dataprovider.User {
|
||||
return dataprovider.User{
|
||||
Username: defaultUsername,
|
||||
Password: defaultPassword,
|
||||
HomeDir: filepath.Join(homeBasePath, defaultUsername),
|
||||
Permissions: defaultPerms,
|
||||
}
|
||||
}
|
||||
|
||||
func getUserAsJSON(t *testing.T, user dataprovider.User) []byte {
|
||||
json, err := json.Marshal(user)
|
||||
if err != nil {
|
||||
t.Errorf("error get user as json: %v", err)
|
||||
return []byte("{}")
|
||||
}
|
||||
return json
|
||||
}
|
||||
|
||||
func executeRequest(req *http.Request) *httptest.ResponseRecorder {
|
||||
rr := httptest.NewRecorder()
|
||||
testServer.Config.Handler.ServeHTTP(rr, req)
|
||||
return rr
|
||||
}
|
||||
|
||||
func checkResponseCode(t *testing.T, expected, actual int) {
|
||||
if expected != actual {
|
||||
t.Errorf("Expected response code %d. Got %d", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicUserHandling(t *testing.T) {
|
||||
user, err := api.AddUser(getTestUser(), http.StatusOK)
|
||||
if err != nil {
|
||||
|
@ -669,3 +625,47 @@ func TestMethodNotAllowedMock(t *testing.T) {
|
|||
rr := executeRequest(req)
|
||||
checkResponseCode(t, http.StatusMethodNotAllowed, rr.Code)
|
||||
}
|
||||
|
||||
func waitTCPListening(address string) {
|
||||
for {
|
||||
conn, err := net.Dial("tcp", address)
|
||||
if err != nil {
|
||||
fmt.Printf("tcp server %v not listening: %v\n", address, err)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
fmt.Printf("tcp server %v now listening\n", address)
|
||||
defer conn.Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func getTestUser() dataprovider.User {
|
||||
return dataprovider.User{
|
||||
Username: defaultUsername,
|
||||
Password: defaultPassword,
|
||||
HomeDir: filepath.Join(homeBasePath, defaultUsername),
|
||||
Permissions: defaultPerms,
|
||||
}
|
||||
}
|
||||
|
||||
func getUserAsJSON(t *testing.T, user dataprovider.User) []byte {
|
||||
json, err := json.Marshal(user)
|
||||
if err != nil {
|
||||
t.Errorf("error get user as json: %v", err)
|
||||
return []byte("{}")
|
||||
}
|
||||
return json
|
||||
}
|
||||
|
||||
func executeRequest(req *http.Request) *httptest.ResponseRecorder {
|
||||
rr := httptest.NewRecorder()
|
||||
testServer.Config.Handler.ServeHTTP(rr, req)
|
||||
return rr
|
||||
}
|
||||
|
||||
func checkResponseCode(t *testing.T, expected, actual int) {
|
||||
if expected != actual {
|
||||
t.Errorf("Expected response code %d. Got %d", expected, actual)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -185,12 +185,13 @@ func startIdleTimer(maxIdleTime time.Duration) {
|
|||
go func() {
|
||||
for t := range idleConnectionTicker.C {
|
||||
logger.Debug(logSender, "idle connections check ticker %v", t)
|
||||
checkIdleConnections()
|
||||
CheckIdleConnections()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func checkIdleConnections() {
|
||||
// CheckIdleConnections disconnects idle clients
|
||||
func CheckIdleConnections() {
|
||||
mutex.RLock()
|
||||
defer mutex.RUnlock()
|
||||
for _, c := range openConnections {
|
||||
|
|
|
@ -135,123 +135,6 @@ func TestMain(m *testing.M) {
|
|||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
func waitTCPListening(address string) {
|
||||
for {
|
||||
conn, err := net.Dial("tcp", address)
|
||||
if err != nil {
|
||||
fmt.Printf("tcp server %v not listening: %v\n", address, err)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
fmt.Printf("tcp server %v now listening\n", address)
|
||||
defer conn.Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func getSftpClient(user dataprovider.User, usePubKey bool) (*sftp.Client, error) {
|
||||
var sftpClient *sftp.Client
|
||||
config := &ssh.ClientConfig{
|
||||
User: defaultUsername,
|
||||
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
if usePubKey {
|
||||
key, err := ssh.ParsePrivateKey([]byte(testPrivateKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.Auth = []ssh.AuthMethod{ssh.PublicKeys(key)}
|
||||
} else {
|
||||
config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)}
|
||||
}
|
||||
conn, err := ssh.Dial("tcp", sftpServerAddr, config)
|
||||
if err != nil {
|
||||
return sftpClient, err
|
||||
}
|
||||
sftpClient, err = sftp.NewClient(conn)
|
||||
return sftpClient, err
|
||||
}
|
||||
|
||||
func createTestFile(path string, size int64) error {
|
||||
content := make([]byte, size)
|
||||
_, err := rand.Read(content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ioutil.WriteFile(path, content, 0666)
|
||||
}
|
||||
|
||||
func getTestUser(usePubKey bool) dataprovider.User {
|
||||
user := dataprovider.User{
|
||||
Username: defaultUsername,
|
||||
Password: defaultPassword,
|
||||
HomeDir: filepath.Join(homeBasePath, defaultUsername),
|
||||
Permissions: allPerms,
|
||||
}
|
||||
if usePubKey {
|
||||
user.PublicKey = testPubKey
|
||||
user.Password = ""
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
func sftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) error {
|
||||
srcFile, err := os.Open(localSourcePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer srcFile.Close()
|
||||
destFile, err := client.Create(remoteDestPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer destFile.Close()
|
||||
_, err = io.Copy(destFile, srcFile)
|
||||
if expectedSize > 0 {
|
||||
fi, err := client.Lstat(remoteDestPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fi.Size() != expectedSize {
|
||||
return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sftpDownloadFile(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) error {
|
||||
downloadDest, err := os.Create(localDestPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer downloadDest.Close()
|
||||
sftpSrcFile, err := client.Open(remoteSourcePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sftpSrcFile.Close()
|
||||
_, err = io.Copy(downloadDest, sftpSrcFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = downloadDest.Sync()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if expectedSize > 0 {
|
||||
fi, err := downloadDest.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fi.Size() != expectedSize {
|
||||
return fmt.Errorf("downloaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func TestBasicSFTPHandling(t *testing.T) {
|
||||
usePubKey := false
|
||||
user, err := api.AddUser(getTestUser(usePubKey), http.StatusOK)
|
||||
|
@ -647,7 +530,6 @@ func TestQuotaScan(t *testing.T) {
|
|||
defer client.Close()
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(65535)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
|
@ -697,6 +579,118 @@ func TestQuotaScan(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestMultipleQuotaScans(t *testing.T) {
|
||||
if !sftpd.AddQuotaScan(defaultUsername) {
|
||||
t.Errorf("add quota failed")
|
||||
}
|
||||
if sftpd.AddQuotaScan(defaultUsername) {
|
||||
t.Errorf("add quota must fail if another scan is already active")
|
||||
}
|
||||
sftpd.RemoveQuotaScan(defaultPassword)
|
||||
}
|
||||
|
||||
func TestQuotaSize(t *testing.T) {
|
||||
usePubKey := false
|
||||
testFileSize := int64(65535)
|
||||
u := getTestUser(usePubKey)
|
||||
u.QuotaFiles = 1
|
||||
u.QuotaSize = testFileSize - 1
|
||||
user, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
client, err := getSftpClient(user, usePubKey)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create sftp client: %v", err)
|
||||
} else {
|
||||
defer client.Close()
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
|
||||
if err != nil {
|
||||
t.Errorf("file upload error: %v", err)
|
||||
}
|
||||
err = sftpUploadFile(testFilePath, testFileName+".1", testFileSize, client)
|
||||
if err == nil {
|
||||
t.Errorf("user is over quota file upload must fail")
|
||||
}
|
||||
}
|
||||
err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBandwidthAndConnections(t *testing.T) {
|
||||
usePubKey := false
|
||||
testFileSize := int64(131072)
|
||||
u := getTestUser(usePubKey)
|
||||
u.UploadBandwidth = 30
|
||||
u.DownloadBandwidth = 25
|
||||
wantedUploadElapsed := 1000 * (testFileSize / 1000) / u.UploadBandwidth
|
||||
wantedDownloadElapsed := 1000 * (testFileSize / 1000) / u.DownloadBandwidth
|
||||
// 100 ms tolerance
|
||||
wantedUploadElapsed -= 100
|
||||
wantedDownloadElapsed -= 100
|
||||
user, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
client, err := getSftpClient(user, usePubKey)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create sftp client: %v", err)
|
||||
} else {
|
||||
defer client.Close()
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
startTime := time.Now()
|
||||
err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
|
||||
if err != nil {
|
||||
t.Errorf("file upload error: %v", err)
|
||||
}
|
||||
elapsed := time.Since(startTime).Nanoseconds() / 1000000
|
||||
if elapsed < (wantedUploadElapsed) {
|
||||
t.Errorf("upload bandwidth throttling not respected, elapsed: %v, wanted: %v", elapsed, wantedUploadElapsed)
|
||||
}
|
||||
startTime = time.Now()
|
||||
localDownloadPath := filepath.Join(homeBasePath, "test_download.dat")
|
||||
c := sftpDownloadNonBlocking(testFileName, localDownloadPath, testFileSize, client)
|
||||
waitForActiveTransfer()
|
||||
err = <-c
|
||||
if err != nil {
|
||||
t.Errorf("file download error: %v", err)
|
||||
}
|
||||
elapsed = time.Since(startTime).Nanoseconds() / 1000000
|
||||
if elapsed < (wantedDownloadElapsed) {
|
||||
t.Errorf("download bandwidth throttling not respected, elapsed: %v, wanted: %v", elapsed, wantedDownloadElapsed)
|
||||
}
|
||||
// test disconnection
|
||||
c = sftpUploadNonBlocking(testFilePath, testFileName, testFileSize, client)
|
||||
waitForActiveTransfer()
|
||||
sftpd.CheckIdleConnections()
|
||||
stats := sftpd.GetConnectionsStats()
|
||||
for _, stat := range stats {
|
||||
sftpd.CloseActiveConnection(stat.ConnectionID)
|
||||
}
|
||||
err = <-c
|
||||
if err == nil {
|
||||
t.Errorf("connection closed upload must fail")
|
||||
}
|
||||
}
|
||||
err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermList(t *testing.T) {
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
|
@ -935,3 +929,155 @@ func TestPermSymlink(t *testing.T) {
|
|||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func waitTCPListening(address string) {
|
||||
for {
|
||||
conn, err := net.Dial("tcp", address)
|
||||
if err != nil {
|
||||
fmt.Printf("tcp server %v not listening: %v\n", address, err)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
fmt.Printf("tcp server %v now listening\n", address)
|
||||
defer conn.Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func getSftpClient(user dataprovider.User, usePubKey bool) (*sftp.Client, error) {
|
||||
var sftpClient *sftp.Client
|
||||
config := &ssh.ClientConfig{
|
||||
User: defaultUsername,
|
||||
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
if usePubKey {
|
||||
key, err := ssh.ParsePrivateKey([]byte(testPrivateKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.Auth = []ssh.AuthMethod{ssh.PublicKeys(key)}
|
||||
} else {
|
||||
config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)}
|
||||
}
|
||||
conn, err := ssh.Dial("tcp", sftpServerAddr, config)
|
||||
if err != nil {
|
||||
return sftpClient, err
|
||||
}
|
||||
sftpClient, err = sftp.NewClient(conn)
|
||||
return sftpClient, err
|
||||
}
|
||||
|
||||
func createTestFile(path string, size int64) error {
|
||||
content := make([]byte, size)
|
||||
_, err := rand.Read(content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ioutil.WriteFile(path, content, 0666)
|
||||
}
|
||||
|
||||
func getTestUser(usePubKey bool) dataprovider.User {
|
||||
user := dataprovider.User{
|
||||
Username: defaultUsername,
|
||||
Password: defaultPassword,
|
||||
HomeDir: filepath.Join(homeBasePath, defaultUsername),
|
||||
Permissions: allPerms,
|
||||
}
|
||||
if usePubKey {
|
||||
user.PublicKey = testPubKey
|
||||
user.Password = ""
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
func sftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) error {
|
||||
srcFile, err := os.Open(localSourcePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer srcFile.Close()
|
||||
destFile, err := client.Create(remoteDestPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer destFile.Close()
|
||||
_, err = io.Copy(destFile, srcFile)
|
||||
if expectedSize > 0 {
|
||||
fi, err := client.Lstat(remoteDestPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fi.Size() != expectedSize {
|
||||
return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sftpDownloadFile(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) error {
|
||||
downloadDest, err := os.Create(localDestPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer downloadDest.Close()
|
||||
sftpSrcFile, err := client.Open(remoteSourcePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sftpSrcFile.Close()
|
||||
_, err = io.Copy(downloadDest, sftpSrcFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = downloadDest.Sync()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if expectedSize > 0 {
|
||||
fi, err := downloadDest.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fi.Size() != expectedSize {
|
||||
return fmt.Errorf("downloaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sftpUploadNonBlocking(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) <-chan error {
|
||||
c := make(chan error)
|
||||
go func() {
|
||||
c <- sftpUploadFile(localSourcePath, remoteDestPath, expectedSize, client)
|
||||
}()
|
||||
return c
|
||||
}
|
||||
|
||||
func sftpDownloadNonBlocking(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) <-chan error {
|
||||
c := make(chan error)
|
||||
go func() {
|
||||
c <- sftpDownloadFile(remoteSourcePath, localDestPath, expectedSize, client)
|
||||
}()
|
||||
return c
|
||||
}
|
||||
|
||||
func waitForActiveTransfer() {
|
||||
stats := sftpd.GetConnectionsStats()
|
||||
for len(stats) < 1 {
|
||||
stats = sftpd.GetConnectionsStats()
|
||||
}
|
||||
activeTransferFound := false
|
||||
for !activeTransferFound {
|
||||
stats = sftpd.GetConnectionsStats()
|
||||
if len(stats) == 0 {
|
||||
break
|
||||
}
|
||||
for _, stat := range stats {
|
||||
if len(stat.Transfers) > 0 {
|
||||
activeTransferFound = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue