Преглед на файлове

hooks: preserve MFA related configs

if a user is updated using pre-login or external auth hook we need to
preserve the MFA related configs in the same way we do if the user is
updated using the REST API
Nicola Murino преди 3 години
родител
ревизия
1472a0f415
променени са 2 файла, в които са добавени 200 реда и са изтрити 0 реда
  1. 11 0
      dataprovider/dataprovider.go
  2. 189 0
      sftpd/sftpd_test.go

+ 11 - 0
dataprovider/dataprovider.go

@@ -2665,6 +2665,8 @@ func executePreLoginHook(username, loginMethod, ip, protocol string) (User, erro
 	userLastQuotaUpdate := u.LastQuotaUpdate
 	userLastLogin := u.LastLogin
 	userCreatedAt := u.CreatedAt
+	totpConfig := u.Filters.TOTPConfig
+	recoveryCodes := u.Filters.RecoveryCodes
 	err = json.Unmarshal(out, &u)
 	if err != nil {
 		return u, fmt.Errorf("invalid pre-login hook response %#v, error: %v", string(out), err)
@@ -2679,6 +2681,9 @@ func executePreLoginHook(username, loginMethod, ip, protocol string) (User, erro
 		err = provider.addUser(&u)
 	} else {
 		u.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
+		// preserve TOTP config and recovery codes
+		u.Filters.TOTPConfig = totpConfig
+		u.Filters.RecoveryCodes = recoveryCodes
 		err = provider.updateUser(&u)
 		if err == nil {
 			webDAVUsersCache.swap(&u)
@@ -2881,6 +2886,9 @@ func doExternalAuth(username, password string, pubKey []byte, keyboardInteractiv
 		user.LastLogin = u.LastLogin
 		user.CreatedAt = u.CreatedAt
 		user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
+		// preserve TOTP config and recovery codes
+		user.Filters.TOTPConfig = u.Filters.TOTPConfig
+		user.Filters.RecoveryCodes = u.Filters.RecoveryCodes
 		err = provider.updateUser(&user)
 		if err == nil {
 			webDAVUsersCache.swap(&user)
@@ -2942,6 +2950,9 @@ func doPluginAuth(username, password string, pubKey []byte, ip, protocol string,
 		user.UsedQuotaFiles = u.UsedQuotaFiles
 		user.LastQuotaUpdate = u.LastQuotaUpdate
 		user.LastLogin = u.LastLogin
+		// preserve TOTP config and recovery codes
+		user.Filters.TOTPConfig = u.Filters.TOTPConfig
+		user.Filters.RecoveryCodes = u.Filters.RecoveryCodes
 		err = provider.updateUser(&user)
 		if err == nil {
 			webDAVUsersCache.swap(&user)

+ 189 - 0
sftpd/sftpd_test.go

@@ -2489,6 +2489,100 @@ func TestPreLoginUserCreation(t *testing.T) {
 	assert.NoError(t, err)
 }
 
+func TestPreLoginHookPreserveMFAConfig(t *testing.T) {
+	if runtime.GOOS == osWindows {
+		t.Skip("this test is not available on Windows")
+	}
+	usePubKey := false
+	u := getTestUser(usePubKey)
+	err := dataprovider.Close()
+	assert.NoError(t, err)
+	err = config.LoadConfig(configDir, "")
+	assert.NoError(t, err)
+	providerConf := config.GetProviderConf()
+	err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm)
+	assert.NoError(t, err)
+	providerConf.PreLoginHook = preLoginPath
+	err = dataprovider.Initialize(providerConf, configDir, true)
+	assert.NoError(t, err)
+
+	conn, client, err := getSftpClient(u, usePubKey)
+	if assert.NoError(t, err) {
+		defer conn.Close()
+		defer client.Close()
+		assert.NoError(t, checkBasicSFTP(client))
+	}
+	// add multi-factor authentication
+	user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
+	assert.NoError(t, err)
+	assert.Len(t, user.Filters.RecoveryCodes, 0)
+	assert.False(t, user.Filters.TOTPConfig.Enabled)
+	configName, _, secret, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
+	assert.NoError(t, err)
+	user.Password = defaultPassword
+	user.Filters.TOTPConfig = sdk.TOTPConfig{
+		Enabled:    true,
+		ConfigName: configName,
+		Secret:     kms.NewPlainSecret(secret),
+		Protocols:  []string{common.ProtocolSSH},
+	}
+	for i := 0; i < 12; i++ {
+		user.Filters.RecoveryCodes = append(user.Filters.RecoveryCodes, sdk.RecoveryCode{
+			Secret: kms.NewPlainSecret(fmt.Sprintf("RC-%v", strings.ToUpper(util.GenerateUniqueID()))),
+		})
+	}
+	err = dataprovider.UpdateUser(&user, "", "")
+	assert.NoError(t, err)
+
+	conn, client, err = getSftpClient(u, usePubKey)
+	if assert.NoError(t, err) {
+		defer conn.Close()
+		defer client.Close()
+		assert.NoError(t, checkBasicSFTP(client))
+	}
+
+	user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
+	assert.NoError(t, err)
+	assert.Len(t, user.Filters.RecoveryCodes, 12)
+	assert.True(t, user.Filters.TOTPConfig.Enabled)
+	assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName)
+	assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols)
+	assert.Equal(t, kms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus())
+
+	err = os.WriteFile(extAuthPath, getExitCodeScriptContent(0), os.ModePerm)
+	assert.NoError(t, err)
+
+	conn, client, err = getSftpClient(u, usePubKey)
+	if assert.NoError(t, err) {
+		defer conn.Close()
+		defer client.Close()
+		assert.NoError(t, checkBasicSFTP(client))
+	}
+
+	user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
+	assert.NoError(t, err)
+	assert.Len(t, user.Filters.RecoveryCodes, 12)
+	assert.True(t, user.Filters.TOTPConfig.Enabled)
+	assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName)
+	assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols)
+	assert.Equal(t, kms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus())
+
+	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+
+	err = dataprovider.Close()
+	assert.NoError(t, err)
+	err = config.LoadConfig(configDir, "")
+	assert.NoError(t, err)
+	providerConf = config.GetProviderConf()
+	err = dataprovider.Initialize(providerConf, configDir, true)
+	assert.NoError(t, err)
+	err = os.Remove(preLoginPath)
+	assert.NoError(t, err)
+}
+
 func TestPreDownloadHook(t *testing.T) {
 	if runtime.GOOS == osWindows {
 		t.Skip("this test is not available on Windows")
@@ -3336,6 +3430,101 @@ func TestLoginExternalAuthErrors(t *testing.T) {
 	assert.NoError(t, err)
 }
 
+func TestExternalAuthPreserveMFAConfig(t *testing.T) {
+	if runtime.GOOS == osWindows {
+		t.Skip("this test is not available on Windows")
+	}
+	usePubKey := false
+	u := getTestUser(usePubKey)
+	err := dataprovider.Close()
+	assert.NoError(t, err)
+	err = config.LoadConfig(configDir, "")
+	assert.NoError(t, err)
+	providerConf := config.GetProviderConf()
+	err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm)
+	assert.NoError(t, err)
+	providerConf.ExternalAuthHook = extAuthPath
+	providerConf.ExternalAuthScope = 0
+	err = dataprovider.Initialize(providerConf, configDir, true)
+	assert.NoError(t, err)
+
+	conn, client, err := getSftpClient(u, usePubKey)
+	if assert.NoError(t, err) {
+		defer conn.Close()
+		defer client.Close()
+		assert.NoError(t, checkBasicSFTP(client))
+	}
+	// add multi-factor authentication
+	user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
+	assert.NoError(t, err)
+	assert.Len(t, user.Filters.RecoveryCodes, 0)
+	assert.False(t, user.Filters.TOTPConfig.Enabled)
+	configName, _, secret, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
+	assert.NoError(t, err)
+	user.Password = defaultPassword
+	user.Filters.TOTPConfig = sdk.TOTPConfig{
+		Enabled:    true,
+		ConfigName: configName,
+		Secret:     kms.NewPlainSecret(secret),
+		Protocols:  []string{common.ProtocolSSH},
+	}
+	for i := 0; i < 12; i++ {
+		user.Filters.RecoveryCodes = append(user.Filters.RecoveryCodes, sdk.RecoveryCode{
+			Secret: kms.NewPlainSecret(fmt.Sprintf("RC-%v", strings.ToUpper(util.GenerateUniqueID()))),
+		})
+	}
+	err = dataprovider.UpdateUser(&user, "", "")
+	assert.NoError(t, err)
+	// login again and check that the MFA configs are preserved
+	conn, client, err = getSftpClient(u, usePubKey)
+	if assert.NoError(t, err) {
+		defer conn.Close()
+		defer client.Close()
+		assert.NoError(t, checkBasicSFTP(client))
+	}
+
+	user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
+	assert.NoError(t, err)
+	assert.Len(t, user.Filters.RecoveryCodes, 12)
+	assert.True(t, user.Filters.TOTPConfig.Enabled)
+	assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName)
+	assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols)
+	assert.Equal(t, kms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus())
+
+	err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, true, ""), os.ModePerm)
+	assert.NoError(t, err)
+
+	conn, client, err = getSftpClient(u, usePubKey)
+	if assert.NoError(t, err) {
+		defer conn.Close()
+		defer client.Close()
+		assert.NoError(t, checkBasicSFTP(client))
+	}
+
+	user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
+	assert.NoError(t, err)
+	assert.Len(t, user.Filters.RecoveryCodes, 12)
+	assert.True(t, user.Filters.TOTPConfig.Enabled)
+	assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName)
+	assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols)
+	assert.Equal(t, kms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus())
+
+	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+
+	err = dataprovider.Close()
+	assert.NoError(t, err)
+	err = config.LoadConfig(configDir, "")
+	assert.NoError(t, err)
+	providerConf = config.GetProviderConf()
+	err = dataprovider.Initialize(providerConf, configDir, true)
+	assert.NoError(t, err)
+	err = os.Remove(extAuthPath)
+	assert.NoError(t, err)
+}
+
 func TestQuotaDisabledError(t *testing.T) {
 	err := dataprovider.Close()
 	assert.NoError(t, err)