ensure no client is connected before running max connections test cases

This commit is contained in:
Nicola Murino 2021-05-11 08:04:57 +02:00
parent c8f7fc9bc9
commit b67cd0d3df
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
6 changed files with 56 additions and 2 deletions

View file

@ -711,6 +711,11 @@ func (conns *ActiveConnections) RemoveClientConnection(ipAddr string) {
conns.clients.remove(ipAddr)
}
// GetClientConnections returns the total number of client connections
func (conns *ActiveConnections) GetClientConnections() int32 {
return conns.clients.getTotal()
}
// IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool {
if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 {

View file

@ -276,9 +276,13 @@ func TestMaxConnectionPerHost(t *testing.T) {
Connections.AddClientConnection(ipAddr)
assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
assert.Equal(t, int32(3), Connections.GetClientConnections())
Connections.RemoveClientConnection(ipAddr)
Connections.RemoveClientConnection(ipAddr)
Connections.RemoveClientConnection(ipAddr)
assert.Equal(t, int32(0), Connections.GetClientConnections())
Config.MaxPerHostConnections = oldValue
}
@ -357,6 +361,7 @@ func TestIdleConnections(t *testing.T) {
defer Connections.RUnlock()
return len(Connections.sshConnections) == 0
}, 1*time.Second, 200*time.Millisecond)
assert.Equal(t, int32(0), Connections.GetClientConnections())
stopIdleTimeoutTicker()
assert.True(t, customConn1.isClosed)
assert.True(t, customConn2.isClosed)

View file

@ -117,6 +117,8 @@ func TestBasicFTPHandlingCryptFs(t *testing.T) {
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
assert.Eventually(t, func() bool { return len(common.Connections.GetStats()) == 0 }, 1*time.Second, 50*time.Millisecond)
assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
50*time.Millisecond)
}
func TestZeroBytesTransfersCryptFs(t *testing.T) {

View file

@ -560,6 +560,8 @@ func TestBasicFTPHandling(t *testing.T) {
err = os.RemoveAll(localUser.GetHomeDir())
assert.NoError(t, err)
assert.Eventually(t, func() bool { return len(common.Connections.GetStats()) == 0 }, 1*time.Second, 50*time.Millisecond)
assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
50*time.Millisecond)
}
func TestLoginInvalidCredentials(t *testing.T) {
@ -756,10 +758,15 @@ func TestPostConnectHook(t *testing.T) {
common.Config.PostConnectHook = ""
}
//nolint:dupl
func TestMaxConnections(t *testing.T) {
oldValue := common.Config.MaxTotalConnections
common.Config.MaxTotalConnections = 1
assert.Eventually(t, func() bool {
return common.Connections.GetClientConnections() == 0
}, 1000*time.Millisecond, 50*time.Millisecond)
user := getTestUser()
err := dataprovider.AddUser(&user)
assert.NoError(t, err)
@ -781,10 +788,15 @@ func TestMaxConnections(t *testing.T) {
common.Config.MaxTotalConnections = oldValue
}
//nolint:dupl
func TestMaxPerHostConnections(t *testing.T) {
oldValue := common.Config.MaxPerHostConnections
common.Config.MaxPerHostConnections = 1
assert.Eventually(t, func() bool {
return common.Connections.GetClientConnections() == 0
}, 1000*time.Millisecond, 50*time.Millisecond)
user := getTestUser()
err := dataprovider.AddUser(&user)
assert.NoError(t, err)
@ -2689,6 +2701,8 @@ func TestNestedVirtualFolders(t *testing.T) {
err = os.RemoveAll(localUser.GetHomeDir())
assert.NoError(t, err)
assert.Eventually(t, func() bool { return len(common.Connections.GetStats()) == 0 }, 1*time.Second, 50*time.Millisecond)
assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
50*time.Millisecond)
}
func checkBasicFTP(client *ftp.ServerConn) error {

View file

@ -2906,6 +2906,10 @@ func TestMaxConnections(t *testing.T) {
oldValue := common.Config.MaxTotalConnections
common.Config.MaxTotalConnections = 1
assert.Eventually(t, func() bool {
return common.Connections.GetClientConnections() == 0
}, 1000*time.Millisecond, 50*time.Millisecond)
usePubKey := true
user := getTestUser(usePubKey)
err := dataprovider.AddUser(&user)
@ -2937,6 +2941,10 @@ func TestMaxPerHostConnections(t *testing.T) {
oldValue := common.Config.MaxPerHostConnections
common.Config.MaxPerHostConnections = 1
assert.Eventually(t, func() bool {
return common.Connections.GetClientConnections() == 0
}, 1000*time.Millisecond, 50*time.Millisecond)
usePubKey := true
user := getTestUser(usePubKey)
err := dataprovider.AddUser(&user)

View file

@ -919,6 +919,10 @@ func TestMaxConnections(t *testing.T) {
oldValue := common.Config.MaxTotalConnections
common.Config.MaxTotalConnections = 1
assert.Eventually(t, func() bool {
return common.Connections.GetClientConnections() == 0
}, 1000*time.Millisecond, 50*time.Millisecond)
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
assert.NoError(t, err)
client := getWebDavClient(user, true, nil)
@ -944,6 +948,10 @@ func TestMaxPerHostConnections(t *testing.T) {
oldValue := common.Config.MaxPerHostConnections
common.Config.MaxPerHostConnections = 1
assert.Eventually(t, func() bool {
return common.Connections.GetClientConnections() == 0
}, 1000*time.Millisecond, 50*time.Millisecond)
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
assert.NoError(t, err)
client := getWebDavClient(user, true, nil)
@ -1188,7 +1196,7 @@ func TestQuotaLimits(t *testing.T) {
if !assert.NoError(t, err, "username: %v", user.Username) {
info, err := os.Stat(testFilePath)
if assert.NoError(t, err) {
fmt.Printf("local file size %v", info.Size())
fmt.Printf("local file size: %v\n", info.Size())
}
printLatestLogs(20)
}
@ -2580,7 +2588,19 @@ func createTestFile(path string, size int64) error {
if err != nil {
return err
}
return os.WriteFile(path, content, os.ModePerm)
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm)
if err != nil {
return err
}
_, err = f.Write(content)
if err == nil {
err = f.Sync()
}
if err1 := f.Close(); err1 != nil && err == nil {
err = err1
}
return err
}
func printLatestLogs(maxNumberOfLines int) {