diff --git a/internal/vfs/sftpfs.go b/internal/vfs/sftpfs.go index 869c315e..445327bc 100644 --- a/internal/vfs/sftpfs.go +++ b/internal/vfs/sftpfs.go @@ -69,6 +69,17 @@ type SFTPFsConfig struct { forbiddenSelfUsernames []string `json:"-"` } +func (c *SFTPFsConfig) getKeySigner() (ssh.Signer, error) { + privPayload := c.PrivateKey.GetPayload() + if privPayload == "" { + return nil, nil + } + if key := c.KeyPassphrase.GetPayload(); key != "" { + return ssh.ParsePrivateKeyWithPassphrase([]byte(privPayload), []byte(key)) + } + return ssh.ParsePrivateKey([]byte(privPayload)) +} + // HideConfidentialData hides confidential data func (c *SFTPFsConfig) HideConfidentialData() { if c.Password != nil { @@ -185,25 +196,20 @@ func (c *SFTPFsConfig) validate() error { func (c *SFTPFsConfig) validatePrivateKey() error { if c.PrivateKey.IsPlain() { - var signer ssh.Signer - var err error - if c.KeyPassphrase.IsPlain() { - signer, err = ssh.ParsePrivateKeyWithPassphrase([]byte(c.PrivateKey.GetPayload()), - []byte(c.KeyPassphrase.GetPayload())) - } else { - signer, err = ssh.ParsePrivateKey([]byte(c.PrivateKey.GetPayload())) - } + signer, err := c.getKeySigner() if err != nil { return util.NewI18nError(fmt.Errorf("invalid private key: %w", err), util.I18nErrorPrivKeyInvalid) } - if key, ok := signer.PublicKey().(ssh.CryptoPublicKey); ok { - cryptoKey := key.CryptoPublicKey() - if rsaKey, ok := cryptoKey.(*rsa.PublicKey); ok { - if size := rsaKey.N.BitLen(); size < 2048 { - return util.NewI18nError( - fmt.Errorf("rsa key with size %d not accepted, minimum 2048", size), - util.I18nErrorKeySizeInvalid, - ) + if signer != nil { + if key, ok := signer.PublicKey().(ssh.CryptoPublicKey); ok { + cryptoKey := key.CryptoPublicKey() + if rsaKey, ok := cryptoKey.(*rsa.PublicKey); ok { + if size := rsaKey.N.BitLen(); size < 2048 { + return util.NewI18nError( + fmt.Errorf("rsa key with size %d not accepted, minimum 2048", size), + util.I18nErrorKeySizeInvalid, + ) + } } } } @@ -331,15 +337,19 @@ func NewSFTPFs(connectionID, mountPath, localTempDir string, forbiddenSelfUserna return nil, err } } + conn, err := sftpConnsCache.Get(&config, connectionID) + if err != nil { + return nil, err + } config.forbiddenSelfUsernames = forbiddenSelfUsernames sftpFs := &SFTPFs{ connectionID: connectionID, mountPath: getMountPath(mountPath), localTempDir: localTempDir, config: &config, - conn: sftpConnsCache.Get(&config, connectionID), + conn: conn, } - err := sftpFs.createConnection() + err = sftpFs.createConnection() if err != nil { sftpFs.Close() //nolint:errcheck } @@ -910,6 +920,7 @@ type sftpConnection struct { isConnected bool sessions map[string]bool lastActivity time.Time + signer ssh.Signer } func newSFTPConnection(config *SFTPFsConfig, sessionID string) *sftpConnection { @@ -919,6 +930,7 @@ func newSFTPConnection(config *SFTPFsConfig, sessionID string) *sftpConnection { isConnected: false, sessions: map[string]bool{}, lastActivity: time.Now().UTC(), + signer: nil, } c.sessions[sessionID] = true return c @@ -931,17 +943,6 @@ func (c *sftpConnection) OpenConnection() error { return c.openConnNoLock() } -func (c *sftpConnection) getKeySigner() (ssh.Signer, error) { - privPayload := c.config.PrivateKey.GetPayload() - if privPayload == "" { - return nil, nil - } - if key := c.config.KeyPassphrase.GetPayload(); key != "" { - return ssh.ParsePrivateKeyWithPassphrase([]byte(privPayload), []byte(key)) - } - return ssh.ParsePrivateKey([]byte(privPayload)) -} - func (c *sftpConnection) openConnNoLock() error { if c.isConnected { logger.Debug(c.logSender, "", "reusing connection") @@ -979,12 +980,8 @@ func (c *sftpConnection) openConnNoLock() error { Timeout: 15 * time.Second, ClientVersion: fmt.Sprintf("SSH-2.0-%s", version.GetServerVersion("_", false)), } - signer, err := c.getKeySigner() - if err != nil { - return fmt.Errorf("sftpfs: unable to parse the private key: %w", err) - } - if signer != nil { - clientConfig.Auth = append(clientConfig.Auth, ssh.PublicKeys(signer)) + if c.signer != nil { + clientConfig.Auth = append(clientConfig.Auth, ssh.PublicKeys(c.signer)) } if pwd := c.config.Password.GetPayload(); pwd != "" { clientConfig.Auth = append(clientConfig.Auth, ssh.Password(pwd)) @@ -1156,7 +1153,7 @@ func newSFTPConnectionCache() *sftpConnectionsCache { return c } -func (c *sftpConnectionsCache) Get(config *SFTPFsConfig, sessionID string) *sftpConnection { +func (c *sftpConnectionsCache) Get(config *SFTPFsConfig, sessionID string) (*sftpConnection, error) { partition := 0 key := config.getUniqueID(partition) @@ -1172,7 +1169,7 @@ func (c *sftpConnectionsCache) Get(config *SFTPFsConfig, sessionID string) *sftp "reusing connection for session ID %q, key: %d, active sessions %d, active connections: %d", sessionID, key, activeSessions+1, len(c.items)) val.AddSession(sessionID) - return val + return val, nil } partition++ oldKey = key @@ -1182,11 +1179,16 @@ func (c *sftpConnectionsCache) Get(config *SFTPFsConfig, sessionID string) *sftp partition, activeSessions, oldKey, key) } else { conn := newSFTPConnection(config, sessionID) + signer, err := config.getKeySigner() + if err != nil { + return nil, fmt.Errorf("sftpfs: unable to parse the private key: %w", err) + } + conn.signer = signer c.items[key] = conn logger.Debug(logSenderSFTPCache, "", "adding new connection for session ID %q, partition: %d, key: %d, active connections: %d", sessionID, partition, key, len(c.items)) - return conn + return conn, nil } } }