diff --git a/internal/common/protocol_test.go b/internal/common/protocol_test.go index 9d68c095..23fac6b9 100644 --- a/internal/common/protocol_test.go +++ b/internal/common/protocol_test.go @@ -81,6 +81,45 @@ const ( osWindows = "windows" testFileName = "test_file_common_sftp.dat" testDir = "test_dir_common" + testPubKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1" + testPrivateKey = `-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABlwAAAAdzc2gtcn +NhAAAAAwEAAQAAAYEAtN449A/nY5O6cSH/9Doa8a3ISU0WZJaHydTaCLuO+dkqtNpnV5mq +zFbKidXAI1eSwVctw9ReVOl1uK6aZF3lbXdOD8W9PXobR9KUUT2qBx5QC4ibfAqDKWymDA +PG9ylzz64hsYBqJr7VNk9kTFEUsDmWzLabLoH42Elnp8mF/lTkWIcpVp0ly/etS08gttXo +XenekJ1vRuxOYWDCEzGPU7kGc920TmM14k7IDdPoOh5+3sRUKedKeOUrVDH1f0n7QjHQsZ +cbshp8tgqzf734zu8cTqNrr+6taptdEOOij1iUL/qYGfzny/hA48tO5+UFUih5W8ftp0+E +NBIDkkGgk2MJ92I7QAXyMVsIABXco+mJT7pQi9tqlODGIQ3AOj0gcA3X/Ib8QX77Ih3TPi +XEh77/P1XiYZOgpp2cRmNH8QbqaL9u898hDvJwIPJPuj2lIltTElH7hjBf5LQfCzrLV7BD +10rM7sl4jr+A2q8jl1Ikp+25kainBBZSbrDummT9AAAFgDU/VLk1P1S5AAAAB3NzaC1yc2 +EAAAGBALTeOPQP52OTunEh//Q6GvGtyElNFmSWh8nU2gi7jvnZKrTaZ1eZqsxWyonVwCNX +ksFXLcPUXlTpdbiummRd5W13Tg/FvT16G0fSlFE9qgceUAuIm3wKgylspgwDxvcpc8+uIb +GAaia+1TZPZExRFLA5lsy2my6B+NhJZ6fJhf5U5FiHKVadJcv3rUtPILbV6F3p3pCdb0bs +TmFgwhMxj1O5BnPdtE5jNeJOyA3T6Doeft7EVCnnSnjlK1Qx9X9J+0Ix0LGXG7IafLYKs3 ++9+M7vHE6ja6/urWqbXRDjoo9YlC/6mBn858v4QOPLTuflBVIoeVvH7adPhDQSA5JBoJNj +CfdiO0AF8jFbCAAV3KPpiU+6UIvbapTgxiENwDo9IHAN1/yG/EF++yId0z4lxIe+/z9V4m +GToKadnEZjR/EG6mi/bvPfIQ7ycCDyT7o9pSJbUxJR+4YwX+S0Hws6y1ewQ9dKzO7JeI6/ +gNqvI5dSJKftuZGopwQWUm6w7ppk/QAAAAMBAAEAAAGAHKnC+Nq0XtGAkIFE4N18e6SAwy +0WSWaZqmCzFQM0S2AhJnweOIG/0ZZHjsRzKKauOTmppQk40dgVsejpytIek9R+aH172gxJ +2n4Cx0UwduRU5x8FFQlNc/kl722B0JWfJuB/snOZXv6LJ4o5aObIkozt2w9tVFeAqjYn2S +1UsNOfRHBXGsTYwpRDwFWP56nKo2d2wBBTHDhCy6fb2dLW1fvSi/YspueOGIlHpvlYKi2/ +CWqvs9xVrwcScMtiDoQYq0khhO0efLCxvg/o+W9CLMVM2ms4G1zoSUQKN0oYWWQJyW4+VI +YneWO8UpN0J3ElXKi7bhgAat7dBaM1g9IrAzk153DiEFZNsPxGOgL/+YdQN7zUBx/z7EkI +jyv80RV7fpUXvcq2p+qNl6UVig3VSzRrnsaJkUWu/A0u59ha7ocv6NxDIXjxpIDJme16GF +quiGVBQNnYJymS/vFEbGf6bgf7iRmMCRUMG4nqLA6fPYP9uAtch+CmDfVLZC/fIdC5AAAA +wQCDissV4zH6bfqgxJSuYNk8Vbb+19cF3b7gH1rVlB3zxpCAgcRgMHC+dP1z2NRx7UW9MR +nye6kjpkzZZ0OigLqo7TtEq8uTglD9o6W7mRXqhy5A/ySOmqPL3ernHHQhGuoNODYAHkOU +u2Rh8HXi+VLwKZcLInPOYJvcuLG4DxN8WfeVvlMHwhAOaTNNOtL4XZDHQeIPc4qHmJymmv +sV7GuyQ6yW5C10uoGdxRPd90Bh4z4h2bKfZFjvEBbSBVkqrlAAAADBAN/zNtNayd/dX7Cr +Nb4sZuzCh+CW4BH8GOePZWNCATwBbNXBVb5cR+dmuTqYm+Ekz0VxVQRA1TvKncluJOQpoa +Xj8r0xdIgqkehnfDPMKtYVor06B9Fl1jrXtXU0Vrr6QcBWruSVyK1ZxqcmcNK/+KolVepe +A6vcl/iKaG4U7su166nxLST06M2EgcSVsFJHpKn5+WAXC+X0Gx8kNjWIIb3GpiChdc0xZD +mq02xZthVJrTCVw/e7gfDoB2QRsNV8HwAAAMEAzsCghZVp+0YsYg9oOrw4tEqcbEXEMhwY +0jW8JNL8Spr1Ibp5Dw6bRSk5azARjmJtnMJhJ3oeHfF0eoISqcNuQXGndGQbVM9YzzAzc1 +NbbCNsVroqKlChT5wyPNGS+phi2bPARBno7WSDvshTZ7dAVEP2c9MJW0XwoSevwKlhgSdt +RLFFQ/5nclJSdzPBOmQouC0OBcMFSrYtMeknJ4VvueVvve5HcHFaEsaMc7ABAGaLYaBQOm +iixITGvaNZh/tjAAAACW5pY29sYUBwMQE= +-----END OPENSSH PRIVATE KEY-----` ) var ( @@ -7812,6 +7851,69 @@ func TestBuiltinKeyboardInteractiveAuthentication(t *testing.T) { assert.NoError(t, err) } +func TestMultiStepBuiltinKeyboardAuth(t *testing.T) { + u := getTestUser() + u.PublicKeys = []string{testPubKey} + u.Filters.DeniedLoginMethods = []string{ + dataprovider.SSHLoginMethodPublicKey, + dataprovider.LoginMethodPassword, + dataprovider.SSHLoginMethodKeyboardInteractive, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey)) + assert.NoError(t, err) + // public key + password + authMethods := []ssh.AuthMethod{ + ssh.PublicKeys(signer), + ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) { + return []string{defaultPassword}, nil + }), + } + conn, client, err := getCustomAuthSftpClient(user, authMethods) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + err = writeSFTPFile(testFileName, 4096, client) + assert.NoError(t, err) + } + // add multi-factor authentication + configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) + assert.NoError(t, err) + user.Password = defaultPassword + user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: configName, + Secret: kms.NewPlainSecret(key.Secret()), + Protocols: []string{common.ProtocolSSH}, + } + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + passcode, err := generateTOTPPasscode(key.Secret(), otp.AlgorithmSHA1) + assert.NoError(t, err) + // public key + passcode + authMethods = []ssh.AuthMethod{ + ssh.PublicKeys(signer), + ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) { + return []string{passcode}, nil + }), + } + conn, client, err = getCustomAuthSftpClient(user, authMethods) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + err = writeSFTPFile(testFileName, 4096, client) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + func TestRenameSymlink(t *testing.T) { u := getTestUser() testDir := "/dir-no-create-links" diff --git a/internal/dataprovider/dataprovider.go b/internal/dataprovider/dataprovider.go index a5cfafcc..cef652a1 100644 --- a/internal/dataprovider/dataprovider.go +++ b/internal/dataprovider/dataprovider.go @@ -1371,7 +1371,9 @@ func CheckUserAndPubKey(username string, pubKey []byte, ip, protocol string, isS // CheckKeyboardInteractiveAuth checks the keyboard interactive authentication and returns // the authenticated user or an error -func CheckKeyboardInteractiveAuth(username, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (User, error) { +func CheckKeyboardInteractiveAuth(username, authHook string, client ssh.KeyboardInteractiveChallenge, + ip, protocol string, isPartialAuth bool, +) (User, error) { var user User var err error username = config.convertName(username) @@ -1387,7 +1389,7 @@ func CheckKeyboardInteractiveAuth(username, authHook string, client ssh.Keyboard if err != nil { return user, err } - return doKeyboardInteractiveAuth(&user, authHook, client, ip, protocol) + return doKeyboardInteractiveAuth(&user, authHook, client, ip, protocol, isPartialAuth) } // GetFTPPreAuthUser returns the SFTPGo user with the specified username @@ -3624,21 +3626,25 @@ func sendKeyboardAuthHTTPReq(url string, request *plugin.KeyboardAuthRequest) (* return &response, err } -func doBuiltinKeyboardInteractiveAuth(user *User, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) { - answers, err := client("", "", []string{"Password: "}, []bool{false}) - if err != nil { +func doBuiltinKeyboardInteractiveAuth(user *User, client ssh.KeyboardInteractiveChallenge, + ip, protocol string, isPartialAuth bool, +) (int, error) { + if err := user.LoadAndApplyGroupSettings(); err != nil { return 0, err } - if len(answers) != 1 { - return 0, fmt.Errorf("unexpected number of answers: %d", len(answers)) - } - err = user.LoadAndApplyGroupSettings() - if err != nil { - return 0, err - } - _, err = checkUserAndPass(user, answers[0], ip, protocol) - if err != nil { - return 0, err + hasSecondFactor := user.Filters.TOTPConfig.Enabled && util.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH) + if !isPartialAuth || !hasSecondFactor { + answers, err := client("", "", []string{"Password: "}, []bool{false}) + if err != nil { + return 0, err + } + if len(answers) != 1 { + return 0, fmt.Errorf("unexpected number of answers: %d", len(answers)) + } + _, err = checkUserAndPass(user, answers[0], ip, protocol) + if err != nil { + return 0, err + } } return checkKeyboardInteractiveSecondFactor(user, client, protocol) } @@ -3881,7 +3887,9 @@ func executeKeyboardInteractiveProgram(user *User, authHook string, client ssh.K return authResult, err } -func doKeyboardInteractiveAuth(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (User, error) { +func doKeyboardInteractiveAuth(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, + ip, protocol string, isPartialAuth bool, +) (User, error) { if err := user.LoadAndApplyGroupSettings(); err != nil { return *user, err } @@ -3900,10 +3908,10 @@ func doKeyboardInteractiveAuth(user *User, authHook string, client ssh.KeyboardI authResult, err = executeKeyboardInteractiveProgram(user, authHook, client, ip, protocol) } } else { - authResult, err = doBuiltinKeyboardInteractiveAuth(user, client, ip, protocol) + authResult, err = doBuiltinKeyboardInteractiveAuth(user, client, ip, protocol, isPartialAuth) } } else { - authResult, err = doBuiltinKeyboardInteractiveAuth(user, client, ip, protocol) + authResult, err = doBuiltinKeyboardInteractiveAuth(user, client, ip, protocol, isPartialAuth) } if err != nil { return *user, err diff --git a/internal/sftpd/server.go b/internal/sftpd/server.go index ca3bbf5b..20a574b6 100644 --- a/internal/sftpd/server.go +++ b/internal/sftpd/server.go @@ -588,7 +588,7 @@ func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Serve } } serverConfig.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { - return c.validateKeyboardInteractiveCredentials(conn, client, dataprovider.SSHLoginMethodKeyboardInteractive) + return c.validateKeyboardInteractiveCredentials(conn, client, dataprovider.SSHLoginMethodKeyboardInteractive, false) } serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyboardInteractive) @@ -1193,7 +1193,7 @@ func (c *Configuration) getPartialSuccessError(nextAuthMethods []string) error { } if c.KeyboardInteractiveAuthentication && util.Contains(nextAuthMethods, dataprovider.SSHLoginMethodKeyboardInteractive) { err.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { - return c.validateKeyboardInteractiveCredentials(conn, client, dataprovider.SSHLoginMethodKeyAndKeyboardInt) + return c.validateKeyboardInteractiveCredentials(conn, client, dataprovider.SSHLoginMethodKeyAndKeyboardInt, true) } } return err @@ -1288,7 +1288,7 @@ func (c *Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass } func (c *Configuration) validateKeyboardInteractiveCredentials(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge, - method string, + method string, isPartialAuth bool, ) (*ssh.Permissions, error) { var err error var user dataprovider.User @@ -1296,7 +1296,7 @@ func (c *Configuration) validateKeyboardInteractiveCredentials(conn ssh.ConnMeta ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String()) if user, err = dataprovider.CheckKeyboardInteractiveAuth(conn.User(), c.KeyboardInteractiveHook, client, - ipAddr, common.ProtocolSSH); err == nil { + ipAddr, common.ProtocolSSH, isPartialAuth); err == nil { sshPerm, err = loginUser(&user, method, "", conn) } user.Username = conn.User() diff --git a/internal/sftpd/sftpd_test.go b/internal/sftpd/sftpd_test.go index a5813322..15b5ddc0 100644 --- a/internal/sftpd/sftpd_test.go +++ b/internal/sftpd/sftpd_test.go @@ -11310,7 +11310,7 @@ func getCustomAuthSftpClient(user dataprovider.User, authMethods []ssh.AuthMetho } var err error var conn *ssh.Client - if len(addr) > 0 { + if addr != "" { conn, err = ssh.Dial("tcp", addr, config) } else { conn, err = ssh.Dial("tcp", sftpServerAddr, config)