Browse Source

hooks: preserve MFA related configs

if a user is updated using pre-login or external auth hook we need to
preserve the MFA related configs in the same way we do if the user is
updated using the REST API
Nicola Murino 3 years ago
parent
commit
1472a0f415
2 changed files with 200 additions and 0 deletions
  1. 11 0
      dataprovider/dataprovider.go
  2. 189 0
      sftpd/sftpd_test.go

+ 11 - 0
dataprovider/dataprovider.go

@@ -2665,6 +2665,8 @@ func executePreLoginHook(username, loginMethod, ip, protocol string) (User, erro
 	userLastQuotaUpdate := u.LastQuotaUpdate
 	userLastLogin := u.LastLogin
 	userCreatedAt := u.CreatedAt
+	totpConfig := u.Filters.TOTPConfig
+	recoveryCodes := u.Filters.RecoveryCodes
 	err = json.Unmarshal(out, &u)
 	if err != nil {
 		return u, fmt.Errorf("invalid pre-login hook response %#v, error: %v", string(out), err)
@@ -2679,6 +2681,9 @@ func executePreLoginHook(username, loginMethod, ip, protocol string) (User, erro
 		err = provider.addUser(&u)
 	} else {
 		u.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
+		// preserve TOTP config and recovery codes
+		u.Filters.TOTPConfig = totpConfig
+		u.Filters.RecoveryCodes = recoveryCodes
 		err = provider.updateUser(&u)
 		if err == nil {
 			webDAVUsersCache.swap(&u)
@@ -2881,6 +2886,9 @@ func doExternalAuth(username, password string, pubKey []byte, keyboardInteractiv
 		user.LastLogin = u.LastLogin
 		user.CreatedAt = u.CreatedAt
 		user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
+		// preserve TOTP config and recovery codes
+		user.Filters.TOTPConfig = u.Filters.TOTPConfig
+		user.Filters.RecoveryCodes = u.Filters.RecoveryCodes
 		err = provider.updateUser(&user)
 		if err == nil {
 			webDAVUsersCache.swap(&user)
@@ -2942,6 +2950,9 @@ func doPluginAuth(username, password string, pubKey []byte, ip, protocol string,
 		user.UsedQuotaFiles = u.UsedQuotaFiles
 		user.LastQuotaUpdate = u.LastQuotaUpdate
 		user.LastLogin = u.LastLogin
+		// preserve TOTP config and recovery codes
+		user.Filters.TOTPConfig = u.Filters.TOTPConfig
+		user.Filters.RecoveryCodes = u.Filters.RecoveryCodes
 		err = provider.updateUser(&user)
 		if err == nil {
 			webDAVUsersCache.swap(&user)

+ 189 - 0
sftpd/sftpd_test.go

@@ -2489,6 +2489,100 @@ func TestPreLoginUserCreation(t *testing.T) {
 	assert.NoError(t, err)
 }
 
+func TestPreLoginHookPreserveMFAConfig(t *testing.T) {
+	if runtime.GOOS == osWindows {
+		t.Skip("this test is not available on Windows")
+	}
+	usePubKey := false
+	u := getTestUser(usePubKey)
+	err := dataprovider.Close()
+	assert.NoError(t, err)
+	err = config.LoadConfig(configDir, "")
+	assert.NoError(t, err)
+	providerConf := config.GetProviderConf()
+	err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm)
+	assert.NoError(t, err)
+	providerConf.PreLoginHook = preLoginPath
+	err = dataprovider.Initialize(providerConf, configDir, true)
+	assert.NoError(t, err)
+
+	conn, client, err := getSftpClient(u, usePubKey)
+	if assert.NoError(t, err) {
+		defer conn.Close()
+		defer client.Close()
+		assert.NoError(t, checkBasicSFTP(client))
+	}
+	// add multi-factor authentication
+	user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
+	assert.NoError(t, err)
+	assert.Len(t, user.Filters.RecoveryCodes, 0)
+	assert.False(t, user.Filters.TOTPConfig.Enabled)
+	configName, _, secret, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
+	assert.NoError(t, err)
+	user.Password = defaultPassword
+	user.Filters.TOTPConfig = sdk.TOTPConfig{
+		Enabled:    true,
+		ConfigName: configName,
+		Secret:     kms.NewPlainSecret(secret),
+		Protocols:  []string{common.ProtocolSSH},
+	}
+	for i := 0; i < 12; i++ {
+		user.Filters.RecoveryCodes = append(user.Filters.RecoveryCodes, sdk.RecoveryCode{
+			Secret: kms.NewPlainSecret(fmt.Sprintf("RC-%v", strings.ToUpper(util.GenerateUniqueID()))),
+		})
+	}
+	err = dataprovider.UpdateUser(&user, "", "")
+	assert.NoError(t, err)
+
+	conn, client, err = getSftpClient(u, usePubKey)
+	if assert.NoError(t, err) {
+		defer conn.Close()
+		defer client.Close()
+		assert.NoError(t, checkBasicSFTP(client))
+	}
+
+	user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
+	assert.NoError(t, err)
+	assert.Len(t, user.Filters.RecoveryCodes, 12)
+	assert.True(t, user.Filters.TOTPConfig.Enabled)
+	assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName)
+	assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols)
+	assert.Equal(t, kms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus())
+
+	err = os.WriteFile(extAuthPath, getExitCodeScriptContent(0), os.ModePerm)
+	assert.NoError(t, err)
+
+	conn, client, err = getSftpClient(u, usePubKey)
+	if assert.NoError(t, err) {
+		defer conn.Close()
+		defer client.Close()
+		assert.NoError(t, checkBasicSFTP(client))
+	}
+
+	user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
+	assert.NoError(t, err)
+	assert.Len(t, user.Filters.RecoveryCodes, 12)
+	assert.True(t, user.Filters.TOTPConfig.Enabled)
+	assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName)
+	assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols)
+	assert.Equal(t, kms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus())
+
+	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+
+	err = dataprovider.Close()
+	assert.NoError(t, err)
+	err = config.LoadConfig(configDir, "")
+	assert.NoError(t, err)
+	providerConf = config.GetProviderConf()
+	err = dataprovider.Initialize(providerConf, configDir, true)
+	assert.NoError(t, err)
+	err = os.Remove(preLoginPath)
+	assert.NoError(t, err)
+}
+
 func TestPreDownloadHook(t *testing.T) {
 	if runtime.GOOS == osWindows {
 		t.Skip("this test is not available on Windows")
@@ -3336,6 +3430,101 @@ func TestLoginExternalAuthErrors(t *testing.T) {
 	assert.NoError(t, err)
 }
 
+func TestExternalAuthPreserveMFAConfig(t *testing.T) {
+	if runtime.GOOS == osWindows {
+		t.Skip("this test is not available on Windows")
+	}
+	usePubKey := false
+	u := getTestUser(usePubKey)
+	err := dataprovider.Close()
+	assert.NoError(t, err)
+	err = config.LoadConfig(configDir, "")
+	assert.NoError(t, err)
+	providerConf := config.GetProviderConf()
+	err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm)
+	assert.NoError(t, err)
+	providerConf.ExternalAuthHook = extAuthPath
+	providerConf.ExternalAuthScope = 0
+	err = dataprovider.Initialize(providerConf, configDir, true)
+	assert.NoError(t, err)
+
+	conn, client, err := getSftpClient(u, usePubKey)
+	if assert.NoError(t, err) {
+		defer conn.Close()
+		defer client.Close()
+		assert.NoError(t, checkBasicSFTP(client))
+	}
+	// add multi-factor authentication
+	user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
+	assert.NoError(t, err)
+	assert.Len(t, user.Filters.RecoveryCodes, 0)
+	assert.False(t, user.Filters.TOTPConfig.Enabled)
+	configName, _, secret, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
+	assert.NoError(t, err)
+	user.Password = defaultPassword
+	user.Filters.TOTPConfig = sdk.TOTPConfig{
+		Enabled:    true,
+		ConfigName: configName,
+		Secret:     kms.NewPlainSecret(secret),
+		Protocols:  []string{common.ProtocolSSH},
+	}
+	for i := 0; i < 12; i++ {
+		user.Filters.RecoveryCodes = append(user.Filters.RecoveryCodes, sdk.RecoveryCode{
+			Secret: kms.NewPlainSecret(fmt.Sprintf("RC-%v", strings.ToUpper(util.GenerateUniqueID()))),
+		})
+	}
+	err = dataprovider.UpdateUser(&user, "", "")
+	assert.NoError(t, err)
+	// login again and check that the MFA configs are preserved
+	conn, client, err = getSftpClient(u, usePubKey)
+	if assert.NoError(t, err) {
+		defer conn.Close()
+		defer client.Close()
+		assert.NoError(t, checkBasicSFTP(client))
+	}
+
+	user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
+	assert.NoError(t, err)
+	assert.Len(t, user.Filters.RecoveryCodes, 12)
+	assert.True(t, user.Filters.TOTPConfig.Enabled)
+	assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName)
+	assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols)
+	assert.Equal(t, kms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus())
+
+	err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, true, ""), os.ModePerm)
+	assert.NoError(t, err)
+
+	conn, client, err = getSftpClient(u, usePubKey)
+	if assert.NoError(t, err) {
+		defer conn.Close()
+		defer client.Close()
+		assert.NoError(t, checkBasicSFTP(client))
+	}
+
+	user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
+	assert.NoError(t, err)
+	assert.Len(t, user.Filters.RecoveryCodes, 12)
+	assert.True(t, user.Filters.TOTPConfig.Enabled)
+	assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName)
+	assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols)
+	assert.Equal(t, kms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus())
+
+	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+
+	err = dataprovider.Close()
+	assert.NoError(t, err)
+	err = config.LoadConfig(configDir, "")
+	assert.NoError(t, err)
+	providerConf = config.GetProviderConf()
+	err = dataprovider.Initialize(providerConf, configDir, true)
+	assert.NoError(t, err)
+	err = os.Remove(extAuthPath)
+	assert.NoError(t, err)
+}
+
 func TestQuotaDisabledError(t *testing.T) {
 	err := dataprovider.Close()
 	assert.NoError(t, err)