fix max connections check

Also make sure to close the ssh client connection in test cases
This commit is contained in:
Nicola Murino 2021-04-20 18:12:16 +02:00
parent 92638ce93d
commit f4369cdbef
No known key found for this signature in database
GPG key ID: 2F1FB59433D5A8CB
8 changed files with 503 additions and 240 deletions

View file

@ -524,6 +524,9 @@ func (c *SSHConnection) Close() error {
// ActiveConnections holds the currect active connections with the associated transfers
type ActiveConnections struct {
// networkConnections is the counter for the network connections, it contains
// both authenticated and estabilished connections and the ones waiting for authentication
networkConnections int32
sync.RWMutex
connections []ActiveConnection
sshConnections []*SSHConnection
@ -690,12 +693,30 @@ func (conns *ActiveConnections) checkIdles() {
conns.RUnlock()
}
// AddNetworkConnection increments the network connections counter
func (conns *ActiveConnections) AddNetworkConnection() {
atomic.AddInt32(&conns.networkConnections, 1)
}
// RemoveNetworkConnection decrements the network connections counter
func (conns *ActiveConnections) RemoveNetworkConnection() {
atomic.AddInt32(&conns.networkConnections, -1)
}
// IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
func (conns *ActiveConnections) IsNewConnectionAllowed() bool {
if Config.MaxTotalConnections == 0 {
return true
}
num := atomic.LoadInt32(&conns.networkConnections)
if num > int32(Config.MaxTotalConnections) {
logger.Debug(logSender, "", "active network connections %v/%v", num, Config.MaxTotalConnections)
return false
}
// on a single SFTP connection we could have multiple SFTP channels or commands
// so we check the estabilished connections too
conns.RLock()
defer conns.RUnlock()

View file

@ -241,6 +241,14 @@ func TestMaxConnections(t *testing.T) {
assert.True(t, res)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
assert.True(t, Connections.IsNewConnectionAllowed())
Connections.AddNetworkConnection()
Connections.AddNetworkConnection()
assert.False(t, Connections.IsNewConnectionAllowed())
Connections.RemoveNetworkConnection()
assert.True(t, Connections.IsNewConnectionAllowed())
Connections.RemoveNetworkConnection()
Config.MaxTotalConnections = oldValue
}

View file

@ -165,8 +165,9 @@ func TestBaseConnection(t *testing.T) {
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
assert.NoError(t, checkBasicSFTP(client))
_, err = client.ReadDir(testDir)
@ -241,8 +242,9 @@ func TestSetStat(t *testing.T) {
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
f, err := client.Create(testFileName)
assert.NoError(t, err)
@ -302,8 +304,9 @@ func TestPermissionErrors(t *testing.T) {
u.Permissions[subDir] = nil
sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
err = client.MkdirAll(path.Join(subDir, subDir))
assert.NoError(t, err)
@ -315,8 +318,9 @@ func TestPermissionErrors(t *testing.T) {
assert.NoError(t, err)
}
}
client, err = getSftpClient(sftpUser)
conn, client, err = getSftpClient(sftpUser)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
assert.NoError(t, checkBasicSFTP(client))
_, err = client.ReadDir(subDir)
@ -357,8 +361,9 @@ func TestFileNotAllowedErrors(t *testing.T) {
}
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
testFile := filepath.Join(u.GetHomeDir(), deniedDir, "file.txt")
err = os.MkdirAll(filepath.Join(u.GetHomeDir(), deniedDir), os.ModePerm)
@ -412,8 +417,9 @@ func TestTruncateQuotaLimits(t *testing.T) {
sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
for _, user := range []dataprovider.User{localUser, sftpUser} {
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
f, err := client.OpenFile(testFileName, os.O_WRONLY)
if assert.NoError(t, err) {
@ -662,8 +668,9 @@ func TestVirtualFoldersQuotaRenameOverwrite(t *testing.T) {
})
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
err = writeSFTPFile(path.Join(vdirPath1, testFileName), testFileSize, client)
assert.NoError(t, err)
@ -750,8 +757,9 @@ func TestQuotaRenameOverwrite(t *testing.T) {
u.QuotaFiles = 100
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
testFileSize := int64(131072)
testFileSize1 := int64(65537)
@ -820,8 +828,9 @@ func TestVirtualFoldersQuotaValues(t *testing.T) {
})
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
testFileSize := int64(131072)
err = writeSFTPFile(testFileName, testFileSize, client)
@ -909,8 +918,9 @@ func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) {
})
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
testFileName1 := "test_file1.dat"
testFileSize := int64(131072)
@ -1088,8 +1098,9 @@ func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
})
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
testFileName1 := "test_file1.dat"
testFileSize := int64(131072)
@ -1283,8 +1294,9 @@ func TestQuotaRenameFromVirtualFolder(t *testing.T) {
})
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
testFileName1 := "test_file1.dat"
testFileSize := int64(131072)
@ -1483,8 +1495,9 @@ func TestQuotaRenameToVirtualFolder(t *testing.T) {
dataprovider.PermOverwrite, dataprovider.PermDelete, dataprovider.PermCreateSymlinks, dataprovider.PermCreateDirs}
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
testFileName1 := "test_file1.dat"
testFileSize := int64(131072)
@ -1693,8 +1706,9 @@ func TestVirtualFoldersLink(t *testing.T) {
})
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
testFileSize := int64(131072)
testDir := "adir"
@ -1794,8 +1808,9 @@ func TestDirs(t *testing.T) {
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
info, err := client.ReadDir("/")
if assert.NoError(t, err) {
@ -1855,8 +1870,9 @@ func TestDirs(t *testing.T) {
func TestCryptFsStat(t *testing.T) {
user, _, err := httpdtest.AddUser(getCryptFsUser(), http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
testFileSize := int64(4096)
err = writeSFTPFile(testFileName, testFileSize, client)
@ -1882,8 +1898,9 @@ func TestFsPermissionErrors(t *testing.T) {
}
user, _, err := httpdtest.AddUser(getCryptFsUser(), http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
testDir := "tDir"
err = client.Mkdir(testDir)
@ -1978,8 +1995,9 @@ func TestUserPasswordHashing(t *testing.T) {
assert.NoError(t, err)
assert.True(t, strings.HasPrefix(currentUser.Password, "$argon2id$"))
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
err = checkBasicSFTP(client)
assert.NoError(t, err)
@ -1998,8 +2016,9 @@ func TestUserPasswordHashing(t *testing.T) {
assert.NoError(t, err)
assert.True(t, strings.HasPrefix(currentUser.Password, "$2a$"))
client, err = getSftpClient(user)
conn, client, err = getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
err = checkBasicSFTP(client)
assert.NoError(t, err)
@ -2114,15 +2133,16 @@ func TestPasswordCaching(t *testing.T) {
assert.False(t, match)
user.Password = "wrong"
_, err = getSftpClient(user)
_, _, err = getSftpClient(user)
assert.Error(t, err)
found, match = dataprovider.CheckCachedPassword(user.Username, defaultPassword)
assert.False(t, found)
assert.False(t, match)
user.Password = ""
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
err = checkBasicSFTP(client)
assert.NoError(t, err)
@ -2145,8 +2165,9 @@ func TestPasswordCaching(t *testing.T) {
assert.False(t, found)
assert.False(t, match)
client, err = getSftpClient(user)
conn, client, err = getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
err = checkBasicSFTP(client)
assert.NoError(t, err)
@ -2177,8 +2198,9 @@ func TestQuotaTrackDisabled(t *testing.T) {
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
err = writeSFTPFile(testFileName, 32, client)
assert.NoError(t, err)
@ -2218,8 +2240,9 @@ func TestGetQuotaError(t *testing.T) {
})
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
err = writeSFTPFile(testFileName, 32, client)
assert.NoError(t, err)
@ -2252,8 +2275,9 @@ func TestRenameDir(t *testing.T) {
u.Permissions[testDir] = []string{dataprovider.PermListItems, dataprovider.PermUpload}
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
err = client.Mkdir(testDir)
assert.NoError(t, err)
@ -2276,8 +2300,9 @@ func TestRenameSymlink(t *testing.T) {
dataprovider.PermCreateDirs}
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
err = client.Mkdir(otherDir)
assert.NoError(t, err)
@ -2398,8 +2423,9 @@ func TestNonLocalCrossRename(t *testing.T) {
})
user, resp, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err, string(resp))
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
assert.NoError(t, checkBasicSFTP(client))
err = writeSFTPFile(testFileName, 4096, client)
@ -2498,8 +2524,9 @@ func TestNonLocalCrossRenameNonLocalBaseUser(t *testing.T) {
})
user, resp, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err, string(resp))
client, err := getSftpClient(user)
conn, client, err := getSftpClient(user)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
assert.NoError(t, checkBasicSFTP(client))
err = writeSFTPFile(testFileName, 4096, client)
@ -2608,7 +2635,7 @@ func checkBasicSFTP(client *sftp.Client) error {
return err
}
func getSftpClient(user dataprovider.User) (*sftp.Client, error) {
func getSftpClient(user dataprovider.User) (*ssh.Client, *sftp.Client, error) {
var sftpClient *sftp.Client
config := &ssh.ClientConfig{
User: user.Username,
@ -2624,10 +2651,13 @@ func getSftpClient(user dataprovider.User) (*sftp.Client, error) {
conn, err := ssh.Dial("tcp", sftpServerAddr, config)
if err != nil {
return sftpClient, err
return conn, sftpClient, err
}
sftpClient, err = sftp.NewClient(conn)
return sftpClient, err
if err != nil {
conn.Close()
}
return conn, sftpClient, err
}
func getTestUser() dataprovider.User {

View file

@ -135,6 +135,7 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) {
// ClientConnected is called to send the very first welcome message
func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
common.Connections.AddNetworkConnection()
ipAddr := utils.GetIPFromRemoteAddress(cc.RemoteAddr().String())
if common.IsBanned(ipAddr) {
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, ip %#v is banned", ipAddr)
@ -166,6 +167,7 @@ func (s *Server) ClientDisconnected(cc ftpserver.ClientContext) {
s.cleanTLSConnVerification(cc.ID())
connID := fmt.Sprintf("%v_%v_%v", common.ProtocolFTP, s.ID, cc.ID())
common.Connections.Remove(connID)
common.Connections.RemoveNetworkConnection()
}
// AuthUser authenticates the user and selects an handling driver

View file

@ -29,8 +29,9 @@ func TestBasicSFTPCryptoHandling(t *testing.T) {
u.QuotaSize = 6553600
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user, usePubKey)
conn, client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(65535)
@ -95,8 +96,9 @@ func TestOpenReadWriteCryptoFs(t *testing.T) {
u.QuotaSize = 6553600
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user, usePubKey)
conn, client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC)
if assert.NoError(t, err) {
@ -124,8 +126,9 @@ func TestEmptyFile(t *testing.T) {
u := getTestUserWithCryptFs(usePubKey)
user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user, usePubKey)
conn, client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC)
if assert.NoError(t, err) {
@ -166,8 +169,9 @@ func TestUploadResumeCryptFs(t *testing.T) {
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
client, err := getSftpClient(user, usePubKey)
conn, client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(65535)
@ -201,8 +205,9 @@ func TestQuotaFileReplaceCryptFs(t *testing.T) {
testFilePath := filepath.Join(homeBasePath, testFileName)
encryptedFileSize, err := getEncryptedFileSize(testFileSize)
assert.NoError(t, err)
client, err := getSftpClient(user, usePubKey)
conn, client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) { //nolint:dupl
defer conn.Close()
defer client.Close()
expectedQuotaSize := user.UsedQuotaSize + encryptedFileSize
expectedQuotaFiles := user.UsedQuotaFiles + 1
@ -238,8 +243,9 @@ func TestQuotaFileReplaceCryptFs(t *testing.T) {
user.QuotaSize = encryptedFileSize*2 - 1
user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
assert.NoError(t, err)
client, err = getSftpClient(user, usePubKey)
conn, client, err = getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
assert.Error(t, err, "quota size exceeded, file upload must fail")
@ -263,8 +269,9 @@ func TestQuotaScanCryptFs(t *testing.T) {
assert.NoError(t, err)
expectedQuotaSize := user.UsedQuotaSize + encryptedFileSize
expectedQuotaFiles := user.UsedQuotaFiles + 1
client, err := getSftpClient(user, usePubKey)
conn, client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
testFilePath := filepath.Join(homeBasePath, testFileName)
err = createTestFile(testFilePath, testFileSize)
@ -302,8 +309,9 @@ func TestGetMimeTypeCryptFs(t *testing.T) {
usePubKey := true
user, _, err := httpdtest.AddUser(getTestUserWithCryptFs(usePubKey), http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user, usePubKey)
conn, client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC)
if assert.NoError(t, err) {
@ -336,8 +344,9 @@ func TestTruncate(t *testing.T) {
usePubKey := true
user, _, err := httpdtest.AddUser(getTestUserWithCryptFs(usePubKey), http.StatusCreated)
assert.NoError(t, err)
client, err := getSftpClient(user, usePubKey)
conn, client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer conn.Close()
defer client.Close()
f, err := client.OpenFile(testFileName, os.O_WRONLY)
if assert.NoError(t, err) {

View file

@ -377,6 +377,10 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
logger.Error(logSender, "", "panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack()))
}
}()
common.Connections.AddNetworkConnection()
defer common.Connections.RemoveNetworkConnection()
ipAddr := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String())
if !canAcceptConnection(ipAddr) {
conn.Close()

File diff suppressed because it is too large Load diff

View file

@ -147,6 +147,9 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError)
}
}()
common.Connections.AddNetworkConnection()
defer common.Connections.RemoveNetworkConnection()
if !common.Connections.IsNewConnectionAllowed() {
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusServiceUnavailable)