Compare commits

...

5 commits

Author SHA1 Message Date
Nicola Murino
b4c06c46e1
EventManager: always close the connection filesystem
Some checks are pending
Code scanning - action / CodeQL-Build (push) Waiting to run
CI / Test and deploy (push) Waiting to run
CI / Test build flags (push) Waiting to run
CI / Test with PgSQL/MySQL/Cockroach (push) Waiting to run
CI / Build Linux packages (push) Waiting to run
CI / golangci-lint (push) Waiting to run
Docker / Build (push) Waiting to run
closing the user filesystem is not enough here

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-11-20 21:18:01 +01:00
Nicola Murino
ee6049bdc3
test cases: fix some random failures
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-11-20 18:29:43 +01:00
Nicola Murino
6bc2f8d16e
upgrade nfpm to 2.41.1
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-11-20 18:29:19 +01:00
Nicola Murino
f89d72f685
OIDC cookie: use a cryptographically secure random string
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-11-20 18:28:43 +01:00
Nicola Murino
d0d8a1999f
sftpd: remove allocator
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2024-11-20 18:28:15 +01:00
10 changed files with 50 additions and 29 deletions

View file

@ -1374,6 +1374,9 @@ func getHTTPRuleActionBody(c *dataprovider.EventActionHTTPConfig, replacer *stri
go func() { go func() {
defer w.Close() defer w.Close()
defer user.CloseFs() //nolint:errcheck defer user.CloseFs() //nolint:errcheck
if conn != nil {
defer conn.CloseFS() //nolint:errcheck
}
for _, part := range c.Parts { for _, part := range c.Parts {
h := make(textproto.MIMEHeader) h := make(textproto.MIMEHeader)
@ -1591,6 +1594,8 @@ func executeEmailRuleAction(c dataprovider.EventActionEmailConfig, params *Event
return fmt.Errorf("error getting email attachments, unable to check root fs for user %q: %w", user.Username, err) return fmt.Errorf("error getting email attachments, unable to check root fs for user %q: %w", user.Username, err)
} }
conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user) conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user)
defer conn.CloseFS() //nolint:errcheck
res, err := getMailAttachments(conn, fileAttachments, replacer) res, err := getMailAttachments(conn, fileAttachments, replacer)
if err != nil { if err != nil {
return err return err
@ -1652,6 +1657,8 @@ func executeDeleteFsActionForUser(deletes []string, replacer *strings.Replacer,
return fmt.Errorf("delete error, unable to check root fs for user %q: %w", user.Username, err) return fmt.Errorf("delete error, unable to check root fs for user %q: %w", user.Username, err)
} }
conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user) conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user)
defer conn.CloseFS() //nolint:errcheck
for _, item := range replacePathsPlaceholders(deletes, replacer) { for _, item := range replacePathsPlaceholders(deletes, replacer) {
info, err := conn.DoStat(item, 0, false) info, err := conn.DoStat(item, 0, false)
if err != nil { if err != nil {
@ -1720,6 +1727,8 @@ func executeMkDirsFsActionForUser(dirs []string, replacer *strings.Replacer, use
return fmt.Errorf("mkdir error, unable to check root fs for user %q: %w", user.Username, err) return fmt.Errorf("mkdir error, unable to check root fs for user %q: %w", user.Username, err)
} }
conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user) conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user)
defer conn.CloseFS() //nolint:errcheck
for _, item := range replacePathsPlaceholders(dirs, replacer) { for _, item := range replacePathsPlaceholders(dirs, replacer) {
if err = conn.CheckParentDirs(path.Dir(item)); err != nil { if err = conn.CheckParentDirs(path.Dir(item)); err != nil {
return fmt.Errorf("unable to check parent dirs for %q, user %q: %w", item, user.Username, err) return fmt.Errorf("unable to check parent dirs for %q, user %q: %w", item, user.Username, err)
@ -1779,6 +1788,8 @@ func executeRenameFsActionForUser(renames []dataprovider.RenameConfig, replacer
return fmt.Errorf("rename error, unable to check root fs for user %q: %w", user.Username, err) return fmt.Errorf("rename error, unable to check root fs for user %q: %w", user.Username, err)
} }
conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user) conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user)
defer conn.CloseFS() //nolint:errcheck
for _, item := range renames { for _, item := range renames {
source := util.CleanPath(replaceWithReplacer(item.Key, replacer)) source := util.CleanPath(replaceWithReplacer(item.Key, replacer))
target := util.CleanPath(replaceWithReplacer(item.Value, replacer)) target := util.CleanPath(replaceWithReplacer(item.Value, replacer))
@ -1808,6 +1819,8 @@ func executeCopyFsActionForUser(keyVals []dataprovider.KeyValue, replacer *strin
return fmt.Errorf("copy error, unable to check root fs for user %q: %w", user.Username, err) return fmt.Errorf("copy error, unable to check root fs for user %q: %w", user.Username, err)
} }
conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user) conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user)
defer conn.CloseFS() //nolint:errcheck
for _, item := range keyVals { for _, item := range keyVals {
source := util.CleanPath(replaceWithReplacer(item.Key, replacer)) source := util.CleanPath(replaceWithReplacer(item.Key, replacer))
target := util.CleanPath(replaceWithReplacer(item.Value, replacer)) target := util.CleanPath(replaceWithReplacer(item.Value, replacer))
@ -1839,6 +1852,8 @@ func executeExistFsActionForUser(exist []string, replacer *strings.Replacer,
return fmt.Errorf("existence check error, unable to check root fs for user %q: %w", user.Username, err) return fmt.Errorf("existence check error, unable to check root fs for user %q: %w", user.Username, err)
} }
conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user) conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user)
defer conn.CloseFS() //nolint:errcheck
for _, item := range replacePathsPlaceholders(exist, replacer) { for _, item := range replacePathsPlaceholders(exist, replacer) {
if _, err = conn.DoStat(item, 0, false); err != nil { if _, err = conn.DoStat(item, 0, false); err != nil {
return fmt.Errorf("error checking existence for path %q, user %q: %w", item, user.Username, err) return fmt.Errorf("error checking existence for path %q, user %q: %w", item, user.Username, err)
@ -1997,6 +2012,8 @@ func executeCompressFsActionForUser(c dataprovider.EventActionFsCompress, replac
return fmt.Errorf("compress error, unable to check root fs for user %q: %w", user.Username, err) return fmt.Errorf("compress error, unable to check root fs for user %q: %w", user.Username, err)
} }
conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user) conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user)
defer conn.CloseFS() //nolint:errcheck
name := util.CleanPath(replaceWithReplacer(c.Name, replacer)) name := util.CleanPath(replaceWithReplacer(c.Name, replacer))
conn.CheckParentDirs(path.Dir(name)) //nolint:errcheck conn.CheckParentDirs(path.Dir(name)) //nolint:errcheck
paths := make([]string, 0, len(c.Paths)) paths := make([]string, 0, len(c.Paths))

View file

@ -575,6 +575,7 @@ func validateBrowsableShare(share dataprovider.Share, connection *Connection) er
basePath := share.Paths[0] basePath := share.Paths[0]
info, err := connection.Stat(basePath, 0) info, err := connection.Stat(basePath, 0)
if err != nil { if err != nil {
connection.CloseFS() //nolint:errcheck
return util.NewI18nError( return util.NewI18nError(
fmt.Errorf("unable to check the share directory: %w", err), fmt.Errorf("unable to check the share directory: %w", err),
util.I18nErrorShareInvalidPath, util.I18nErrorShareInvalidPath,

View file

@ -13556,7 +13556,9 @@ func TestMaxTransfers(t *testing.T) {
err = os.RemoveAll(user.GetHomeDir()) err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, common.Connections.GetStats(""), 0) assert.Len(t, common.Connections.GetStats(""), 0)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) assert.Eventually(t, func() bool {
return common.Connections.GetTotalTransfers() == 0
}, 1000*time.Millisecond, 50*time.Millisecond)
common.Config.MaxPerHostConnections = oldValue common.Config.MaxPerHostConnections = oldValue
} }

View file

@ -15,8 +15,6 @@
package httpd package httpd
import ( import (
"crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"sync" "sync"
@ -53,10 +51,8 @@ type oauth2PendingAuth struct {
} }
func newOAuth2PendingAuth(provider int, redirectURL, clientID string, clientSecret *kms.Secret) oauth2PendingAuth { func newOAuth2PendingAuth(provider int, redirectURL, clientID string, clientSecret *kms.Secret) oauth2PendingAuth {
state := sha256.Sum256(util.GenerateRandomBytes(32))
return oauth2PendingAuth{ return oauth2PendingAuth{
State: hex.EncodeToString(state[:]), State: util.GenerateOpaqueString(),
Provider: provider, Provider: provider,
ClientID: clientID, ClientID: clientID,
ClientSecret: clientSecret, ClientSecret: clientSecret,

View file

@ -16,8 +16,6 @@ package httpd
import ( import (
"context" "context"
"crypto/sha256"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -204,12 +202,9 @@ type oidcPendingAuth struct {
} }
func newOIDCPendingAuth(audience tokenAudience) oidcPendingAuth { func newOIDCPendingAuth(audience tokenAudience) oidcPendingAuth {
state := sha256.Sum256(util.GenerateRandomBytes(32))
nonce := util.GenerateUniqueID()
return oidcPendingAuth{ return oidcPendingAuth{
State: hex.EncodeToString(state[:]), State: util.GenerateOpaqueString(),
Nonce: nonce, Nonce: util.GenerateOpaqueString(),
Audience: audience, Audience: audience,
IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now()), IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now()),
} }
@ -684,7 +679,7 @@ func (s *httpdServer) handleOIDCRedirect(w http.ResponseWriter, r *http.Request)
RefreshToken: oauth2Token.RefreshToken, RefreshToken: oauth2Token.RefreshToken,
IDToken: rawIDToken, IDToken: rawIDToken,
Nonce: idToken.Nonce, Nonce: idToken.Nonce,
Cookie: xid.New().String(), Cookie: util.GenerateOpaqueString(),
} }
if !oauth2Token.Expiry.IsZero() { if !oauth2Token.Expiry.IsZero() {
token.ExpiresAt = util.GetTimeAsMsSinceEpoch(oauth2Token.Expiry) token.ExpiresAt = util.GetTimeAsMsSinceEpoch(oauth2Token.Expiry)

View file

@ -152,8 +152,8 @@ func TestOIDCLoginLogout(t *testing.T) {
assert.Contains(t, rr.Body.String(), util.I18nInvalidAuth) assert.Contains(t, rr.Body.String(), util.I18nInvalidAuth)
expiredAuthReq := oidcPendingAuth{ expiredAuthReq := oidcPendingAuth{
State: xid.New().String(), State: util.GenerateOpaqueString(),
Nonce: xid.New().String(), Nonce: util.GenerateOpaqueString(),
Audience: tokenAudienceWebClient, Audience: tokenAudienceWebClient,
IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)), IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
} }
@ -564,7 +564,7 @@ func TestOIDCRefreshToken(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, webUsersPath, nil) r, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
assert.NoError(t, err) assert.NoError(t, err)
token := oidcToken{ token := oidcToken{
Cookie: xid.New().String(), Cookie: util.GenerateOpaqueString(),
AccessToken: xid.New().String(), AccessToken: xid.New().String(),
TokenType: "Bearer", TokenType: "Bearer",
ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-1 * time.Minute)), ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-1 * time.Minute)),
@ -668,7 +668,7 @@ func TestOIDCRefreshToken(t *testing.T) {
func TestOIDCRefreshUser(t *testing.T) { func TestOIDCRefreshUser(t *testing.T) {
token := oidcToken{ token := oidcToken{
Cookie: xid.New().String(), Cookie: util.GenerateOpaqueString(),
AccessToken: xid.New().String(), AccessToken: xid.New().String(),
TokenType: "Bearer", TokenType: "Bearer",
ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Minute)), ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Minute)),
@ -782,7 +782,7 @@ func TestValidateOIDCToken(t *testing.T) {
}, },
} }
token := oidcToken{ token := oidcToken{
Cookie: xid.New().String(), Cookie: util.GenerateOpaqueString(),
AccessToken: xid.New().String(), AccessToken: xid.New().String(),
ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)), ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)),
} }
@ -798,8 +798,8 @@ func TestValidateOIDCToken(t *testing.T) {
server.tokenAuth = jwtauth.New("PS256", util.GenerateRandomBytes(32), nil) server.tokenAuth = jwtauth.New("PS256", util.GenerateRandomBytes(32), nil)
token = oidcToken{ token = oidcToken{
Cookie: xid.New().String(), Cookie: util.GenerateOpaqueString(),
AccessToken: xid.New().String(), AccessToken: util.GenerateUniqueID(),
} }
oidcMgr.addToken(token) oidcMgr.addToken(token)
rr = httptest.NewRecorder() rr = httptest.NewRecorder()
@ -813,7 +813,7 @@ func TestValidateOIDCToken(t *testing.T) {
assert.Len(t, oidcMgr.tokens, 0) assert.Len(t, oidcMgr.tokens, 0)
token = oidcToken{ token = oidcToken{
Cookie: xid.New().String(), Cookie: util.GenerateOpaqueString(),
AccessToken: xid.New().String(), AccessToken: xid.New().String(),
Role: "admin", Role: "admin",
} }
@ -1107,7 +1107,7 @@ func TestMemoryOIDCManager(t *testing.T) {
AccessToken: xid.New().String(), AccessToken: xid.New().String(),
Nonce: xid.New().String(), Nonce: xid.New().String(),
SessionID: xid.New().String(), SessionID: xid.New().String(),
Cookie: xid.New().String(), Cookie: util.GenerateOpaqueString(),
Username: xid.New().String(), Username: xid.New().String(),
Role: "admin", Role: "admin",
Permissions: []string{dataprovider.PermAdminAny}, Permissions: []string{dataprovider.PermAdminAny},
@ -1157,7 +1157,7 @@ func TestMemoryOIDCManager(t *testing.T) {
token.UsedAt = usedAt token.UsedAt = usedAt
oidcMgr.tokens[token.Cookie] = token oidcMgr.tokens[token.Cookie] = token
newToken := oidcToken{ newToken := oidcToken{
Cookie: xid.New().String(), Cookie: util.GenerateOpaqueString(),
} }
oidcMgr.addToken(newToken) oidcMgr.addToken(newToken)
oidcMgr.cleanup() oidcMgr.cleanup()
@ -1663,7 +1663,7 @@ func TestDbOIDCManager(t *testing.T) {
} }
token := oidcToken{ token := oidcToken{
Cookie: xid.New().String(), Cookie: util.GenerateOpaqueString(),
AccessToken: xid.New().String(), AccessToken: xid.New().String(),
TokenType: "Bearer", TokenType: "Bearer",
RefreshToken: xid.New().String(), RefreshToken: xid.New().String(),

View file

@ -694,7 +694,7 @@ func (c *Configuration) handleSftpConnection(channel ssh.Channel, connection *Co
defer common.Connections.Remove(connection.GetID()) defer common.Connections.Remove(connection.GetID())
// Create the server instance for the channel using the handler we created above. // Create the server instance for the channel using the handler we created above.
server := sftp.NewRequestServer(channel, c.createHandlers(connection), sftp.WithRSAllocator(), server := sftp.NewRequestServer(channel, c.createHandlers(connection),
sftp.WithStartDirectory(connection.User.Filters.StartDirectory)) sftp.WithStartDirectory(connection.User.Filters.StartDirectory))
defer server.Close() defer server.Close()

View file

@ -4457,7 +4457,9 @@ func TestMaxTransfers(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir()) err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) assert.Eventually(t, func() bool {
return common.Connections.GetTotalTransfers() == 0
}, 1000*time.Millisecond, 50*time.Millisecond)
common.Config.MaxPerHostConnections = oldValue common.Config.MaxPerHostConnections = oldValue
} }

View file

@ -22,8 +22,10 @@ import (
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/sha256"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/hex"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"errors" "errors"
@ -550,7 +552,7 @@ func createDirPathIfMissing(file string, perm os.FileMode) error {
return nil return nil
} }
// GenerateRandomBytes generates the secret to use for JWT auth // GenerateRandomBytes generates random bytes with the specified length
func GenerateRandomBytes(length int) []byte { func GenerateRandomBytes(length int) []byte {
b := make([]byte, length) b := make([]byte, length)
_, err := io.ReadFull(rand.Reader, b) _, err := io.ReadFull(rand.Reader, b)
@ -560,6 +562,12 @@ func GenerateRandomBytes(length int) []byte {
return b return b
} }
// GenerateOpaqueString generates a cryptographically secure opaque string
func GenerateOpaqueString() string {
randomBytes := sha256.Sum256(GenerateRandomBytes(32))
return hex.EncodeToString(randomBytes[:])
}
// GenerateUniqueID returns an unique ID // GenerateUniqueID returns an unique ID
func GenerateUniqueID() string { func GenerateUniqueID() string {
u, err := uuid.NewRandom() u, err := uuid.NewRandom()

View file

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
NFPM_VERSION=2.41.0 NFPM_VERSION=2.41.1
NFPM_ARCH=${NFPM_ARCH:-amd64} NFPM_ARCH=${NFPM_ARCH:-amd64}
if [ -z ${SFTPGO_VERSION} ] if [ -z ${SFTPGO_VERSION} ]
then then