hooks: preserve MFA related configs
if a user is updated using pre-login or external auth hook we need to preserve the MFA related configs in the same way we do if the user is updated using the REST API
This commit is contained in:
parent
0bb141960f
commit
1472a0f415
2 changed files with 200 additions and 0 deletions
|
@ -2665,6 +2665,8 @@ func executePreLoginHook(username, loginMethod, ip, protocol string) (User, erro
|
|||
userLastQuotaUpdate := u.LastQuotaUpdate
|
||||
userLastLogin := u.LastLogin
|
||||
userCreatedAt := u.CreatedAt
|
||||
totpConfig := u.Filters.TOTPConfig
|
||||
recoveryCodes := u.Filters.RecoveryCodes
|
||||
err = json.Unmarshal(out, &u)
|
||||
if err != nil {
|
||||
return u, fmt.Errorf("invalid pre-login hook response %#v, error: %v", string(out), err)
|
||||
|
@ -2679,6 +2681,9 @@ func executePreLoginHook(username, loginMethod, ip, protocol string) (User, erro
|
|||
err = provider.addUser(&u)
|
||||
} else {
|
||||
u.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
|
||||
// preserve TOTP config and recovery codes
|
||||
u.Filters.TOTPConfig = totpConfig
|
||||
u.Filters.RecoveryCodes = recoveryCodes
|
||||
err = provider.updateUser(&u)
|
||||
if err == nil {
|
||||
webDAVUsersCache.swap(&u)
|
||||
|
@ -2881,6 +2886,9 @@ func doExternalAuth(username, password string, pubKey []byte, keyboardInteractiv
|
|||
user.LastLogin = u.LastLogin
|
||||
user.CreatedAt = u.CreatedAt
|
||||
user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
|
||||
// preserve TOTP config and recovery codes
|
||||
user.Filters.TOTPConfig = u.Filters.TOTPConfig
|
||||
user.Filters.RecoveryCodes = u.Filters.RecoveryCodes
|
||||
err = provider.updateUser(&user)
|
||||
if err == nil {
|
||||
webDAVUsersCache.swap(&user)
|
||||
|
@ -2942,6 +2950,9 @@ func doPluginAuth(username, password string, pubKey []byte, ip, protocol string,
|
|||
user.UsedQuotaFiles = u.UsedQuotaFiles
|
||||
user.LastQuotaUpdate = u.LastQuotaUpdate
|
||||
user.LastLogin = u.LastLogin
|
||||
// preserve TOTP config and recovery codes
|
||||
user.Filters.TOTPConfig = u.Filters.TOTPConfig
|
||||
user.Filters.RecoveryCodes = u.Filters.RecoveryCodes
|
||||
err = provider.updateUser(&user)
|
||||
if err == nil {
|
||||
webDAVUsersCache.swap(&user)
|
||||
|
|
|
@ -2489,6 +2489,100 @@ func TestPreLoginUserCreation(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPreLoginHookPreserveMFAConfig(t *testing.T) {
|
||||
if runtime.GOOS == osWindows {
|
||||
t.Skip("this test is not available on Windows")
|
||||
}
|
||||
usePubKey := false
|
||||
u := getTestUser(usePubKey)
|
||||
err := dataprovider.Close()
|
||||
assert.NoError(t, err)
|
||||
err = config.LoadConfig(configDir, "")
|
||||
assert.NoError(t, err)
|
||||
providerConf := config.GetProviderConf()
|
||||
err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
providerConf.PreLoginHook = preLoginPath
|
||||
err = dataprovider.Initialize(providerConf, configDir, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
conn, client, err := getSftpClient(u, usePubKey)
|
||||
if assert.NoError(t, err) {
|
||||
defer conn.Close()
|
||||
defer client.Close()
|
||||
assert.NoError(t, checkBasicSFTP(client))
|
||||
}
|
||||
// add multi-factor authentication
|
||||
user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, user.Filters.RecoveryCodes, 0)
|
||||
assert.False(t, user.Filters.TOTPConfig.Enabled)
|
||||
configName, _, secret, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
|
||||
assert.NoError(t, err)
|
||||
user.Password = defaultPassword
|
||||
user.Filters.TOTPConfig = sdk.TOTPConfig{
|
||||
Enabled: true,
|
||||
ConfigName: configName,
|
||||
Secret: kms.NewPlainSecret(secret),
|
||||
Protocols: []string{common.ProtocolSSH},
|
||||
}
|
||||
for i := 0; i < 12; i++ {
|
||||
user.Filters.RecoveryCodes = append(user.Filters.RecoveryCodes, sdk.RecoveryCode{
|
||||
Secret: kms.NewPlainSecret(fmt.Sprintf("RC-%v", strings.ToUpper(util.GenerateUniqueID()))),
|
||||
})
|
||||
}
|
||||
err = dataprovider.UpdateUser(&user, "", "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
conn, client, err = getSftpClient(u, usePubKey)
|
||||
if assert.NoError(t, err) {
|
||||
defer conn.Close()
|
||||
defer client.Close()
|
||||
assert.NoError(t, checkBasicSFTP(client))
|
||||
}
|
||||
|
||||
user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, user.Filters.RecoveryCodes, 12)
|
||||
assert.True(t, user.Filters.TOTPConfig.Enabled)
|
||||
assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName)
|
||||
assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols)
|
||||
assert.Equal(t, kms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus())
|
||||
|
||||
err = os.WriteFile(extAuthPath, getExitCodeScriptContent(0), os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
|
||||
conn, client, err = getSftpClient(u, usePubKey)
|
||||
if assert.NoError(t, err) {
|
||||
defer conn.Close()
|
||||
defer client.Close()
|
||||
assert.NoError(t, checkBasicSFTP(client))
|
||||
}
|
||||
|
||||
user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, user.Filters.RecoveryCodes, 12)
|
||||
assert.True(t, user.Filters.TOTPConfig.Enabled)
|
||||
assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName)
|
||||
assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols)
|
||||
assert.Equal(t, kms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus())
|
||||
|
||||
_, err = httpdtest.RemoveUser(user, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = dataprovider.Close()
|
||||
assert.NoError(t, err)
|
||||
err = config.LoadConfig(configDir, "")
|
||||
assert.NoError(t, err)
|
||||
providerConf = config.GetProviderConf()
|
||||
err = dataprovider.Initialize(providerConf, configDir, true)
|
||||
assert.NoError(t, err)
|
||||
err = os.Remove(preLoginPath)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPreDownloadHook(t *testing.T) {
|
||||
if runtime.GOOS == osWindows {
|
||||
t.Skip("this test is not available on Windows")
|
||||
|
@ -3336,6 +3430,101 @@ func TestLoginExternalAuthErrors(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestExternalAuthPreserveMFAConfig(t *testing.T) {
|
||||
if runtime.GOOS == osWindows {
|
||||
t.Skip("this test is not available on Windows")
|
||||
}
|
||||
usePubKey := false
|
||||
u := getTestUser(usePubKey)
|
||||
err := dataprovider.Close()
|
||||
assert.NoError(t, err)
|
||||
err = config.LoadConfig(configDir, "")
|
||||
assert.NoError(t, err)
|
||||
providerConf := config.GetProviderConf()
|
||||
err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
providerConf.ExternalAuthHook = extAuthPath
|
||||
providerConf.ExternalAuthScope = 0
|
||||
err = dataprovider.Initialize(providerConf, configDir, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
conn, client, err := getSftpClient(u, usePubKey)
|
||||
if assert.NoError(t, err) {
|
||||
defer conn.Close()
|
||||
defer client.Close()
|
||||
assert.NoError(t, checkBasicSFTP(client))
|
||||
}
|
||||
// add multi-factor authentication
|
||||
user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, user.Filters.RecoveryCodes, 0)
|
||||
assert.False(t, user.Filters.TOTPConfig.Enabled)
|
||||
configName, _, secret, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
|
||||
assert.NoError(t, err)
|
||||
user.Password = defaultPassword
|
||||
user.Filters.TOTPConfig = sdk.TOTPConfig{
|
||||
Enabled: true,
|
||||
ConfigName: configName,
|
||||
Secret: kms.NewPlainSecret(secret),
|
||||
Protocols: []string{common.ProtocolSSH},
|
||||
}
|
||||
for i := 0; i < 12; i++ {
|
||||
user.Filters.RecoveryCodes = append(user.Filters.RecoveryCodes, sdk.RecoveryCode{
|
||||
Secret: kms.NewPlainSecret(fmt.Sprintf("RC-%v", strings.ToUpper(util.GenerateUniqueID()))),
|
||||
})
|
||||
}
|
||||
err = dataprovider.UpdateUser(&user, "", "")
|
||||
assert.NoError(t, err)
|
||||
// login again and check that the MFA configs are preserved
|
||||
conn, client, err = getSftpClient(u, usePubKey)
|
||||
if assert.NoError(t, err) {
|
||||
defer conn.Close()
|
||||
defer client.Close()
|
||||
assert.NoError(t, checkBasicSFTP(client))
|
||||
}
|
||||
|
||||
user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, user.Filters.RecoveryCodes, 12)
|
||||
assert.True(t, user.Filters.TOTPConfig.Enabled)
|
||||
assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName)
|
||||
assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols)
|
||||
assert.Equal(t, kms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus())
|
||||
|
||||
err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, true, ""), os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
|
||||
conn, client, err = getSftpClient(u, usePubKey)
|
||||
if assert.NoError(t, err) {
|
||||
defer conn.Close()
|
||||
defer client.Close()
|
||||
assert.NoError(t, checkBasicSFTP(client))
|
||||
}
|
||||
|
||||
user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, user.Filters.RecoveryCodes, 12)
|
||||
assert.True(t, user.Filters.TOTPConfig.Enabled)
|
||||
assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName)
|
||||
assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols)
|
||||
assert.Equal(t, kms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus())
|
||||
|
||||
_, err = httpdtest.RemoveUser(user, http.StatusOK)
|
||||
assert.NoError(t, err)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = dataprovider.Close()
|
||||
assert.NoError(t, err)
|
||||
err = config.LoadConfig(configDir, "")
|
||||
assert.NoError(t, err)
|
||||
providerConf = config.GetProviderConf()
|
||||
err = dataprovider.Initialize(providerConf, configDir, true)
|
||||
assert.NoError(t, err)
|
||||
err = os.Remove(extAuthPath)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestQuotaDisabledError(t *testing.T) {
|
||||
err := dataprovider.Close()
|
||||
assert.NoError(t, err)
|
||||
|
|
Loading…
Reference in a new issue