Browse Source

ensure no client is connected before running max connections test cases

Nicola Murino 4 năm trước cách đây
mục cha
commit
b67cd0d3df
6 tập tin đã thay đổi với 56 bổ sung2 xóa
  1. 5 0
      common/common.go
  2. 5 0
      common/common_test.go
  3. 2 0
      ftpd/cryptfs_test.go
  4. 14 0
      ftpd/ftpd_test.go
  5. 8 0
      sftpd/sftpd_test.go
  6. 22 2
      webdavd/webdavd_test.go

+ 5 - 0
common/common.go

@@ -711,6 +711,11 @@ func (conns *ActiveConnections) RemoveClientConnection(ipAddr string) {
 	conns.clients.remove(ipAddr)
 }
 
+// GetClientConnections returns the total number of client connections
+func (conns *ActiveConnections) GetClientConnections() int32 {
+	return conns.clients.getTotal()
+}
+
 // IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
 func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool {
 	if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 {

+ 5 - 0
common/common_test.go

@@ -276,9 +276,13 @@ func TestMaxConnectionPerHost(t *testing.T) {
 
 	Connections.AddClientConnection(ipAddr)
 	assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
+	assert.Equal(t, int32(3), Connections.GetClientConnections())
 
 	Connections.RemoveClientConnection(ipAddr)
 	Connections.RemoveClientConnection(ipAddr)
+	Connections.RemoveClientConnection(ipAddr)
+
+	assert.Equal(t, int32(0), Connections.GetClientConnections())
 
 	Config.MaxPerHostConnections = oldValue
 }
@@ -357,6 +361,7 @@ func TestIdleConnections(t *testing.T) {
 		defer Connections.RUnlock()
 		return len(Connections.sshConnections) == 0
 	}, 1*time.Second, 200*time.Millisecond)
+	assert.Equal(t, int32(0), Connections.GetClientConnections())
 	stopIdleTimeoutTicker()
 	assert.True(t, customConn1.isClosed)
 	assert.True(t, customConn2.isClosed)

+ 2 - 0
ftpd/cryptfs_test.go

@@ -117,6 +117,8 @@ func TestBasicFTPHandlingCryptFs(t *testing.T) {
 	err = os.RemoveAll(user.GetHomeDir())
 	assert.NoError(t, err)
 	assert.Eventually(t, func() bool { return len(common.Connections.GetStats()) == 0 }, 1*time.Second, 50*time.Millisecond)
+	assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
+		50*time.Millisecond)
 }
 
 func TestZeroBytesTransfersCryptFs(t *testing.T) {

+ 14 - 0
ftpd/ftpd_test.go

@@ -560,6 +560,8 @@ func TestBasicFTPHandling(t *testing.T) {
 	err = os.RemoveAll(localUser.GetHomeDir())
 	assert.NoError(t, err)
 	assert.Eventually(t, func() bool { return len(common.Connections.GetStats()) == 0 }, 1*time.Second, 50*time.Millisecond)
+	assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
+		50*time.Millisecond)
 }
 
 func TestLoginInvalidCredentials(t *testing.T) {
@@ -756,10 +758,15 @@ func TestPostConnectHook(t *testing.T) {
 	common.Config.PostConnectHook = ""
 }
 
+//nolint:dupl
 func TestMaxConnections(t *testing.T) {
 	oldValue := common.Config.MaxTotalConnections
 	common.Config.MaxTotalConnections = 1
 
+	assert.Eventually(t, func() bool {
+		return common.Connections.GetClientConnections() == 0
+	}, 1000*time.Millisecond, 50*time.Millisecond)
+
 	user := getTestUser()
 	err := dataprovider.AddUser(&user)
 	assert.NoError(t, err)
@@ -781,10 +788,15 @@ func TestMaxConnections(t *testing.T) {
 	common.Config.MaxTotalConnections = oldValue
 }
 
+//nolint:dupl
 func TestMaxPerHostConnections(t *testing.T) {
 	oldValue := common.Config.MaxPerHostConnections
 	common.Config.MaxPerHostConnections = 1
 
+	assert.Eventually(t, func() bool {
+		return common.Connections.GetClientConnections() == 0
+	}, 1000*time.Millisecond, 50*time.Millisecond)
+
 	user := getTestUser()
 	err := dataprovider.AddUser(&user)
 	assert.NoError(t, err)
@@ -2689,6 +2701,8 @@ func TestNestedVirtualFolders(t *testing.T) {
 	err = os.RemoveAll(localUser.GetHomeDir())
 	assert.NoError(t, err)
 	assert.Eventually(t, func() bool { return len(common.Connections.GetStats()) == 0 }, 1*time.Second, 50*time.Millisecond)
+	assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
+		50*time.Millisecond)
 }
 
 func checkBasicFTP(client *ftp.ServerConn) error {

+ 8 - 0
sftpd/sftpd_test.go

@@ -2906,6 +2906,10 @@ func TestMaxConnections(t *testing.T) {
 	oldValue := common.Config.MaxTotalConnections
 	common.Config.MaxTotalConnections = 1
 
+	assert.Eventually(t, func() bool {
+		return common.Connections.GetClientConnections() == 0
+	}, 1000*time.Millisecond, 50*time.Millisecond)
+
 	usePubKey := true
 	user := getTestUser(usePubKey)
 	err := dataprovider.AddUser(&user)
@@ -2937,6 +2941,10 @@ func TestMaxPerHostConnections(t *testing.T) {
 	oldValue := common.Config.MaxPerHostConnections
 	common.Config.MaxPerHostConnections = 1
 
+	assert.Eventually(t, func() bool {
+		return common.Connections.GetClientConnections() == 0
+	}, 1000*time.Millisecond, 50*time.Millisecond)
+
 	usePubKey := true
 	user := getTestUser(usePubKey)
 	err := dataprovider.AddUser(&user)

+ 22 - 2
webdavd/webdavd_test.go

@@ -919,6 +919,10 @@ func TestMaxConnections(t *testing.T) {
 	oldValue := common.Config.MaxTotalConnections
 	common.Config.MaxTotalConnections = 1
 
+	assert.Eventually(t, func() bool {
+		return common.Connections.GetClientConnections() == 0
+	}, 1000*time.Millisecond, 50*time.Millisecond)
+
 	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
 	assert.NoError(t, err)
 	client := getWebDavClient(user, true, nil)
@@ -944,6 +948,10 @@ func TestMaxPerHostConnections(t *testing.T) {
 	oldValue := common.Config.MaxPerHostConnections
 	common.Config.MaxPerHostConnections = 1
 
+	assert.Eventually(t, func() bool {
+		return common.Connections.GetClientConnections() == 0
+	}, 1000*time.Millisecond, 50*time.Millisecond)
+
 	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
 	assert.NoError(t, err)
 	client := getWebDavClient(user, true, nil)
@@ -1188,7 +1196,7 @@ func TestQuotaLimits(t *testing.T) {
 		if !assert.NoError(t, err, "username: %v", user.Username) {
 			info, err := os.Stat(testFilePath)
 			if assert.NoError(t, err) {
-				fmt.Printf("local file size %v", info.Size())
+				fmt.Printf("local file size: %v\n", info.Size())
 			}
 			printLatestLogs(20)
 		}
@@ -2580,7 +2588,19 @@ func createTestFile(path string, size int64) error {
 	if err != nil {
 		return err
 	}
-	return os.WriteFile(path, content, os.ModePerm)
+
+	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm)
+	if err != nil {
+		return err
+	}
+	_, err = f.Write(content)
+	if err == nil {
+		err = f.Sync()
+	}
+	if err1 := f.Close(); err1 != nil && err == nil {
+		err = err1
+	}
+	return err
 }
 
 func printLatestLogs(maxNumberOfLines int) {