From 0661876e99f4691084e2175985a602d41ff771ce Mon Sep 17 00:00:00 2001 From: Maharanjan Date: Sat, 25 Sep 2021 19:06:13 +0200 Subject: [PATCH] Added email field for user account --- dataprovider/admin.go | 2 +- dataprovider/bolt.go | 10 +++++++--- dataprovider/dataprovider.go | 3 +++ dataprovider/mysql.go | 37 +++++++++++++++++++++++++++++++++++- dataprovider/pgsql.go | 36 ++++++++++++++++++++++++++++++++++- dataprovider/sqlcommon.go | 14 +++++++++----- dataprovider/sqlite.go | 36 ++++++++++++++++++++++++++++++++++- dataprovider/sqlqueries.go | 15 ++++++++------- dataprovider/user.go | 1 + httpd/httpd_test.go | 27 ++++++++++++++++++++++++-- httpd/schema/openapi.yaml | 3 +++ httpd/webadmin.go | 3 +++ httpdtest/httpdtest.go | 31 ++++++++++++++++++++---------- sdk/user.go | 2 ++ templates/webadmin/user.html | 8 ++++++++ 15 files changed, 197 insertions(+), 31 deletions(-) diff --git a/dataprovider/admin.go b/dataprovider/admin.go index 7e34f564..166643d4 100644 --- a/dataprovider/admin.go +++ b/dataprovider/admin.go @@ -104,7 +104,7 @@ type Admin struct { // Username Username string `json:"username"` Password string `json:"password,omitempty"` - Email string `json:"email"` + Email string `json:"email,omitempty"` Permissions []string `json:"permissions"` Filters AdminFilters `json:"filters,omitempty"` Description string `json:"description,omitempty"` diff --git a/dataprovider/bolt.go b/dataprovider/bolt.go index 0f3c095f..eec999b1 100644 --- a/dataprovider/bolt.go +++ b/dataprovider/bolt.go @@ -20,7 +20,7 @@ import ( ) const ( - boltDatabaseVersion = 12 + boltDatabaseVersion = 13 ) var ( @@ -1155,9 +1155,11 @@ func (p *BoltProvider) migrateDatabase() error { logger.ErrorToConsole("%v", err) return err case version == 10: - return updateBoltDatabaseVersion(p.dbHandle, 12) + return updateBoltDatabaseVersion(p.dbHandle, 13) case version == 11: - return updateBoltDatabaseVersion(p.dbHandle, 12) + return updateBoltDatabaseVersion(p.dbHandle, 13) + case version == 12: + return updateBoltDatabaseVersion(p.dbHandle, 13) default: if version > boltDatabaseVersion { providerLog(logger.LevelWarn, "database version %v is newer than the supported one: %v", version, @@ -1179,6 +1181,8 @@ func (p *BoltProvider) revertDatabase(targetVersion int) error { return errors.New("current version match target version, nothing to do") } switch dbVersion.Version { + case 13: + return updateBoltDatabaseVersion(p.dbHandle, 10) case 12: return updateBoltDatabaseVersion(p.dbHandle, 10) case 11: diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go index d3e208a9..bdf0d962 100644 --- a/dataprovider/dataprovider.go +++ b/dataprovider/dataprovider.go @@ -1595,6 +1595,9 @@ func validateBaseParams(user *User) error { if user.Username == "" { return util.NewValidationError("username is mandatory") } + if user.Email != "" && !emailRegex.MatchString(user.Email) { + return util.NewValidationError(fmt.Sprintf("email %#v is not valid", user.Email)) + } if !config.SkipNaturalKeysValidation && !usernameRegex.MatchString(user.Username) { return util.NewValidationError(fmt.Sprintf("username %#v is not valid, the following characters are allowed: a-zA-Z0-9-_.~", user.Username)) diff --git a/dataprovider/mysql.go b/dataprovider/mysql.go index 24ceb99e..0f9970d8 100644 --- a/dataprovider/mysql.go +++ b/dataprovider/mysql.go @@ -63,6 +63,9 @@ const ( "ALTER TABLE `{{admins}}` DROP COLUMN `last_login`;" + "ALTER TABLE `{{users}}` DROP COLUMN `created_at`;" + "ALTER TABLE `{{users}}` DROP COLUMN `updated_at`;" + + mysqlV13SQL = "ALTER TABLE `{{users}}` ADD COLUMN `email` varchar(255) NULL;" + mysqlV13DownSQL = "ALTER TABLE `{{users}}` DROP COLUMN `email`;" ) // MySQLProvider auth provider for MySQL/MariaDB database @@ -307,6 +310,8 @@ func (p *MySQLProvider) migrateDatabase() error { return updateMySQLDatabaseFromV10(p.dbHandle) case version == 11: return updateMySQLDatabaseFromV11(p.dbHandle) + case version == 12: + return updateMySQLDatabaseFromV12(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelWarn, "database version %v is newer than the supported one: %v", version, @@ -329,6 +334,8 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error { } switch dbVersion.Version { + case 13: + return downgradeMySQLDatabaseFromV13(p.dbHandle) case 12: return downgradeMySQLDatabaseFromV12(p.dbHandle) case 11: @@ -346,7 +353,21 @@ func updateMySQLDatabaseFromV10(dbHandle *sql.DB) error { } func updateMySQLDatabaseFromV11(dbHandle *sql.DB) error { - return updateMySQLDatabaseFrom11To12(dbHandle) + if err := updateMySQLDatabaseFrom11To12(dbHandle); err != nil { + return err + } + return updateMySQLDatabaseFromV12(dbHandle) +} + +func updateMySQLDatabaseFromV12(dbHandle *sql.DB) error { + return updateMySQLDatabaseFrom12To13(dbHandle) +} + +func downgradeMySQLDatabaseFromV13(dbHandle *sql.DB) error { + if err := downgradeMySQLDatabaseFrom13To12(dbHandle); err != nil { + return err + } + return downgradeMySQLDatabaseFromV12(dbHandle) } func downgradeMySQLDatabaseFromV12(dbHandle *sql.DB) error { @@ -360,6 +381,20 @@ func downgradeMySQLDatabaseFromV11(dbHandle *sql.DB) error { return downgradeMySQLDatabaseFrom11To10(dbHandle) } +func updateMySQLDatabaseFrom12To13(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database version: 12 -> 13") + providerLog(logger.LevelInfo, "updating database version: 12 -> 13") + sql := strings.ReplaceAll(mysqlV13SQL, "{{users}}", sqlTableUsers) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 13) +} + +func downgradeMySQLDatabaseFrom13To12(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database version: 13 -> 12") + providerLog(logger.LevelInfo, "downgrading database version: 13 -> 12") + sql := strings.ReplaceAll(mysqlV13DownSQL, "{{users}}", sqlTableUsers) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 12) +} + func updateMySQLDatabaseFrom11To12(dbHandle *sql.DB) error { logger.InfoToConsole("updating database version: 11 -> 12") providerLog(logger.LevelInfo, "updating database version: 11 -> 12") diff --git a/dataprovider/pgsql.go b/dataprovider/pgsql.go index 44e98536..13b0ede5 100644 --- a/dataprovider/pgsql.go +++ b/dataprovider/pgsql.go @@ -76,6 +76,8 @@ ALTER TABLE "{{admins}}" DROP COLUMN "created_at" CASCADE; ALTER TABLE "{{admins}}" DROP COLUMN "updated_at" CASCADE; ALTER TABLE "{{admins}}" DROP COLUMN "last_login" CASCADE; ` + pgsqlV13SQL = `ALTER TABLE "{{users}}" ADD COLUMN "email" varchar(255) NULL;` + pgsqlV13DownSQL = `ALTER TABLE "{{users}}" DROP COLUMN "email" CASCADE;` ) // PGSQLProvider auth provider for PostgreSQL database @@ -326,6 +328,8 @@ func (p *PGSQLProvider) migrateDatabase() error { return updatePGSQLDatabaseFromV10(p.dbHandle) case version == 11: return updatePGSQLDatabaseFromV11(p.dbHandle) + case version == 12: + return updatePGSQLDatabaseFromV12(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelWarn, "database version %v is newer than the supported one: %v", version, @@ -348,6 +352,8 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error { } switch dbVersion.Version { + case 13: + return downgradePGSQLDatabaseFromV13(p.dbHandle) case 12: return downgradePGSQLDatabaseFromV12(p.dbHandle) case 11: @@ -365,7 +371,21 @@ func updatePGSQLDatabaseFromV10(dbHandle *sql.DB) error { } func updatePGSQLDatabaseFromV11(dbHandle *sql.DB) error { - return updatePGSQLDatabaseFrom11To12(dbHandle) + if err := updatePGSQLDatabaseFrom11To12(dbHandle); err != nil { + return err + } + return updatePGSQLDatabaseFromV12(dbHandle) +} + +func updatePGSQLDatabaseFromV12(dbHandle *sql.DB) error { + return updatePGSQLDatabaseFrom12To13(dbHandle) +} + +func downgradePGSQLDatabaseFromV13(dbHandle *sql.DB) error { + if err := downgradePGSQLDatabaseFrom13To12(dbHandle); err != nil { + return err + } + return downgradePGSQLDatabaseFromV12(dbHandle) } func downgradePGSQLDatabaseFromV12(dbHandle *sql.DB) error { @@ -379,6 +399,20 @@ func downgradePGSQLDatabaseFromV11(dbHandle *sql.DB) error { return downgradePGSQLDatabaseFrom11To10(dbHandle) } +func updatePGSQLDatabaseFrom12To13(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database version: 12 -> 13") + providerLog(logger.LevelInfo, "updating database version: 12 -> 13") + sql := strings.ReplaceAll(pgsqlV13SQL, "{{users}}", sqlTableUsers) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 13) +} + +func downgradePGSQLDatabaseFrom13To12(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database version: 13 -> 12") + providerLog(logger.LevelInfo, "downgrading database version: 13 -> 12") + sql := strings.ReplaceAll(pgsqlV13DownSQL, "{{users}}", sqlTableUsers) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 12) +} + func updatePGSQLDatabaseFrom11To12(dbHandle *sql.DB) error { logger.InfoToConsole("updating database version: 11 -> 12") providerLog(logger.LevelInfo, "updating database version: 11 -> 12") diff --git a/dataprovider/sqlcommon.go b/dataprovider/sqlcommon.go index 542ff0ed..977b4a57 100644 --- a/dataprovider/sqlcommon.go +++ b/dataprovider/sqlcommon.go @@ -19,7 +19,7 @@ import ( ) const ( - sqlDatabaseVersion = 12 + sqlDatabaseVersion = 13 defaultSQLQueryTimeout = 10 * time.Second longSQLQueryTimeout = 60 * time.Second ) @@ -577,7 +577,7 @@ func sqlCommonAddUser(user *User, dbHandle *sql.DB) error { } _, err = stmt.ExecContext(ctx, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters), - string(fsConfig), user.AdditionalInfo, user.Description, util.GetTimeAsMsSinceEpoch(time.Now()), + string(fsConfig), user.AdditionalInfo, user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now())) if err != nil { return err @@ -620,7 +620,8 @@ func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error { } _, err = stmt.ExecContext(ctx, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize, user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, - string(filters), string(fsConfig), user.AdditionalInfo, user.Description, util.GetTimeAsMsSinceEpoch(time.Now()), user.ID) + string(filters), string(fsConfig), user.AdditionalInfo, user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()), + user.ID) if err != nil { return err } @@ -817,12 +818,12 @@ func getUserFromDbRow(row sqlScanner) (User, error) { var publicKey sql.NullString var filters sql.NullString var fsConfig sql.NullString - var additionalInfo, description sql.NullString + var additionalInfo, description, email sql.NullString err := row.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions, &user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate, &user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig, - &additionalInfo, &description, &user.CreatedAt, &user.UpdatedAt) + &additionalInfo, &description, &email, &user.CreatedAt, &user.UpdatedAt) if err != nil { if err == sql.ErrNoRows { return user, util.NewRecordNotFoundError(err.Error()) @@ -871,6 +872,9 @@ func getUserFromDbRow(row sqlScanner) (User, error) { if description.Valid { user.Description = description.String } + if email.Valid { + user.Email = email.String + } user.SetEmptySecretsIfNil() return user, nil } diff --git a/dataprovider/sqlite.go b/dataprovider/sqlite.go index 23e89ccf..71affacd 100644 --- a/dataprovider/sqlite.go +++ b/dataprovider/sqlite.go @@ -67,6 +67,8 @@ ALTER TABLE "{{admins}}" DROP COLUMN "created_at"; ALTER TABLE "{{admins}}" DROP COLUMN "updated_at"; ALTER TABLE "{{admins}}" DROP COLUMN "last_login"; ` + sqliteV13SQL = `ALTER TABLE "{{users}}" ADD COLUMN "email" varchar(255) NULL;` + sqliteV13DownSQL = `ALTER TABLE "{{users}}" DROP COLUMN "email";` ) // SQLiteProvider auth provider for SQLite database @@ -304,6 +306,8 @@ func (p *SQLiteProvider) migrateDatabase() error { return updateSQLiteDatabaseFromV10(p.dbHandle) case version == 11: return updateSQLiteDatabaseFromV11(p.dbHandle) + case version == 12: + return updateSQLiteDatabaseFromV12(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelWarn, "database version %v is newer than the supported one: %v", version, @@ -326,6 +330,8 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error { } switch dbVersion.Version { + case 13: + return downgradeSQLiteDatabaseFromV13(p.dbHandle) case 12: return downgradeSQLiteDatabaseFromV12(p.dbHandle) case 11: @@ -343,7 +349,21 @@ func updateSQLiteDatabaseFromV10(dbHandle *sql.DB) error { } func updateSQLiteDatabaseFromV11(dbHandle *sql.DB) error { - return updateSQLiteDatabaseFrom11To12(dbHandle) + if err := updateSQLiteDatabaseFrom11To12(dbHandle); err != nil { + return err + } + return updateSQLiteDatabaseFromV12(dbHandle) +} + +func updateSQLiteDatabaseFromV12(dbHandle *sql.DB) error { + return updateSQLiteDatabaseFrom12To13(dbHandle) +} + +func downgradeSQLiteDatabaseFromV13(dbHandle *sql.DB) error { + if err := downgradeSQLiteDatabaseFrom13To12(dbHandle); err != nil { + return err + } + return downgradeSQLiteDatabaseFromV12(dbHandle) } func downgradeSQLiteDatabaseFromV12(dbHandle *sql.DB) error { @@ -357,6 +377,20 @@ func downgradeSQLiteDatabaseFromV11(dbHandle *sql.DB) error { return downgradeSQLiteDatabaseFrom11To10(dbHandle) } +func updateSQLiteDatabaseFrom12To13(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database version: 12 -> 13") + providerLog(logger.LevelInfo, "updating database version: 12 -> 13") + sql := strings.ReplaceAll(sqliteV13SQL, "{{users}}", sqlTableUsers) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 13) +} + +func downgradeSQLiteDatabaseFrom13To12(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database version: 13 -> 12") + providerLog(logger.LevelInfo, "downgrading database version: 13 -> 12") + sql := strings.ReplaceAll(sqliteV13DownSQL, "{{users}}", sqlTableUsers) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 12) +} + func updateSQLiteDatabaseFrom11To12(dbHandle *sql.DB) error { logger.InfoToConsole("updating database version: 11 -> 12") providerLog(logger.LevelInfo, "updating database version: 11 -> 12") diff --git a/dataprovider/sqlqueries.go b/dataprovider/sqlqueries.go index 8310f98b..21ee6cb3 100644 --- a/dataprovider/sqlqueries.go +++ b/dataprovider/sqlqueries.go @@ -11,7 +11,7 @@ import ( const ( selectUserFields = "id,username,password,public_keys,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions,used_quota_size," + "used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth,expiration_date,last_login,status,filters,filesystem," + - "additional_info,description,created_at,updated_at" + "additional_info,description,email,created_at,updated_at" selectFolderFields = "id,path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem" selectAdminFields = "id,username,password,status,email,permissions,filters,additional_info,description,created_at,updated_at,last_login" selectAPIKeyFields = "key_id,name,api_key,scope,created_at,updated_at,last_use_at,expires_at,description,user_id,admin_id" @@ -19,7 +19,7 @@ const ( func getSQLPlaceholders() []string { var placeholders []string - for i := 1; i <= 20; i++ { + for i := 1; i <= 30; i++ { if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName { placeholders = append(placeholders, fmt.Sprintf("$%v", i)) } else { @@ -185,20 +185,21 @@ func getQuotaQuery() string { func getAddUserQuery() string { return fmt.Sprintf(`INSERT INTO %v (username,password,public_keys,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions, used_quota_size,used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth,status,last_login,expiration_date,filters, - filesystem,additional_info,description,created_at,updated_at) - VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,0,0,0,%v,%v,%v,0,%v,%v,%v,%v,%v,%v,%v)`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], + filesystem,additional_info,description,email,created_at,updated_at) + VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,0,0,0,%v,%v,%v,0,%v,%v,%v,%v,%v,%v,%v,%v)`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13], - sqlPlaceholders[14], sqlPlaceholders[15], sqlPlaceholders[16], sqlPlaceholders[17], sqlPlaceholders[18], sqlPlaceholders[19]) + sqlPlaceholders[14], sqlPlaceholders[15], sqlPlaceholders[16], sqlPlaceholders[17], sqlPlaceholders[18], sqlPlaceholders[19], + sqlPlaceholders[20]) } func getUpdateUserQuery() string { return fmt.Sprintf(`UPDATE %v SET password=%v,public_keys=%v,home_dir=%v,uid=%v,gid=%v,max_sessions=%v,quota_size=%v, quota_files=%v,permissions=%v,upload_bandwidth=%v,download_bandwidth=%v,status=%v,expiration_date=%v,filters=%v,filesystem=%v, - additional_info=%v,description=%v,updated_at=%v WHERE id = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], + additional_info=%v,description=%v,email=%v,updated_at=%v WHERE id = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13], sqlPlaceholders[14], sqlPlaceholders[15], - sqlPlaceholders[16], sqlPlaceholders[17], sqlPlaceholders[18]) + sqlPlaceholders[16], sqlPlaceholders[17], sqlPlaceholders[18], sqlPlaceholders[19]) } func getDeleteUserQuery() string { diff --git a/dataprovider/user.go b/dataprovider/user.go index 819653a7..39c51ff3 100644 --- a/dataprovider/user.go +++ b/dataprovider/user.go @@ -1075,6 +1075,7 @@ func (u *User) getACopy() User { BaseUser: sdk.BaseUser{ ID: u.ID, Username: u.Username, + Email: u.Email, Password: u.Password, PublicKeys: pubKeys, HomeDir: u.HomeDir, diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index 4621704c..fe69bb85 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -427,7 +427,9 @@ func TestInitialization(t *testing.T) { } func TestBasicUserHandling(t *testing.T) { - user, resp, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + u := getTestUser() + u.Email = "user@user.com" + user, resp, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err, string(resp)) user.MaxSessions = 10 user.QuotaSize = 4096 @@ -437,6 +439,7 @@ func TestBasicUserHandling(t *testing.T) { user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now()) user.AdditionalInfo = "some free text" user.Filters.TLSUsername = sdk.TLSUsernameCN + user.Email = "user@example.net" user.Filters.WebClient = append(user.Filters.WebClient, sdk.WebClientPubKeyChangeDisabled, sdk.WebClientWriteDisabled) originalUser := user @@ -446,6 +449,12 @@ func TestBasicUserHandling(t *testing.T) { user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) + + user.Email = "invalid@email" + _, body, err := httpdtest.UpdateUser(user, http.StatusBadRequest, "") + assert.NoError(t, err) + assert.Contains(t, string(body), "Validation error: email") + _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) } @@ -1219,6 +1228,14 @@ func TestAddUserNoPerms(t *testing.T) { assert.NoError(t, err) } +func TestAddUserInvalidEmail(t *testing.T) { + u := getTestUser() + u.Email = "invalid_email" + _, body, err := httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(body), "Validation error: email") +} + func TestAddUserInvalidPerms(t *testing.T) { u := getTestUser() u.Permissions["/"] = []string{"invalidPerm"} @@ -3338,7 +3355,7 @@ func TestSkipNaturalKeysValidation(t *testing.T) { assert.NoError(t, err) u := getTestUser() - u.Username = "user@example.com" + u.Username = "user@user.me" user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) user.AdditionalInfo = "info" @@ -10856,6 +10873,7 @@ func TestWebUserAddMock(t *testing.T) { user.UID = 1000 user.AdditionalInfo = "info" user.Description = "user dsc" + user.Email = "test@test.com" mappedDir := filepath.Join(os.TempDir(), "mapped") folderName := filepath.Base(mappedDir) f := vfs.BaseVirtualFolder{ @@ -10872,6 +10890,7 @@ func TestWebUserAddMock(t *testing.T) { form := make(url.Values) form.Set(csrfFormToken, csrfToken) form.Set("username", user.Username) + form.Set("email", user.Email) form.Set("home_dir", user.HomeDir) form.Set("password", user.Password) form.Set("status", strconv.Itoa(user.Status)) @@ -11068,6 +11087,7 @@ func TestWebUserAddMock(t *testing.T) { assert.False(t, newUser.Filters.Hooks.CheckPasswordDisabled) assert.True(t, newUser.Filters.DisableFsChecks) assert.False(t, newUser.Filters.AllowAPIKeyAuth) + assert.Equal(t, user.Email, newUser.Email) assert.True(t, util.IsStringInSlice(testPubKey, newUser.PublicKeys)) if val, ok := newUser.Permissions["/subdir"]; ok { assert.True(t, util.IsStringInSlice(dataprovider.PermListItems, val)) @@ -11172,8 +11192,10 @@ func TestWebUserUpdateMock(t *testing.T) { user.GID = 1000 user.Filters.AllowAPIKeyAuth = true user.AdditionalInfo = "new additional info" + user.Email = "user@example.com" form := make(url.Values) form.Set("username", user.Username) + form.Set("email", user.Email) form.Set("password", "") form.Set("public_keys", testPubKey) form.Set("home_dir", user.HomeDir) @@ -11257,6 +11279,7 @@ func TestWebUserUpdateMock(t *testing.T) { var updateUser dataprovider.User err = render.DecodeJSON(rr.Body, &updateUser) assert.NoError(t, err) + assert.Equal(t, user.Email, updateUser.Email) assert.Equal(t, user.HomeDir, updateUser.HomeDir) assert.Equal(t, user.MaxSessions, updateUser.MaxSessions) assert.Equal(t, user.QuotaFiles, updateUser.QuotaFiles) diff --git a/httpd/schema/openapi.yaml b/httpd/schema/openapi.yaml index 0df46cb2..dddfe044 100644 --- a/httpd/schema/openapi.yaml +++ b/httpd/schema/openapi.yaml @@ -3660,6 +3660,9 @@ components: username: type: string description: username is unique + email: + type: string + format: email description: type: string description: 'optional description, for example the user full name' diff --git a/httpd/webadmin.go b/httpd/webadmin.go index a25acf30..fc0c62de 100644 --- a/httpd/webadmin.go +++ b/httpd/webadmin.go @@ -1084,6 +1084,7 @@ func getUserFromPostFields(r *http.Request) (dataprovider.User, error) { user = dataprovider.User{ BaseUser: sdk.BaseUser{ Username: r.Form.Get("username"), + Email: r.Form.Get("email"), Password: r.Form.Get("password"), PublicKeys: r.Form["public_keys"], HomeDir: r.Form.Get("home_dir"), @@ -1475,6 +1476,8 @@ func handleWebTemplateUserGet(w http.ResponseWriter, r *http.Request) { user, err := dataprovider.UserExists(username) if err == nil { user.SetEmptySecrets() + user.Email = "" + user.Description = "" renderUserPage(w, r, &user, userPageModeTemplate, "") } else if _, ok := err.(*util.RecordNotFoundError); ok { renderNotFoundPage(w, r, err) diff --git a/httpdtest/httpdtest.go b/httpdtest/httpdtest.go index eeb4c98d..663e78b2 100644 --- a/httpdtest/httpdtest.go +++ b/httpdtest/httpdtest.go @@ -1161,6 +1161,26 @@ func checkUser(expected *dataprovider.User, actual *dataprovider.User) error { return fmt.Errorf("created_at mismatch %v != %v", expected.CreatedAt, actual.CreatedAt) } } + + if expected.Email != actual.Email { + return errors.New("email mismatch") + } + if err := compareUserPermissions(expected, actual); err != nil { + return err + } + if err := compareUserFilters(expected, actual); err != nil { + return err + } + if err := compareFsConfig(&expected.FsConfig, &actual.FsConfig); err != nil { + return err + } + if err := compareUserVirtualFolders(expected, actual); err != nil { + return err + } + return compareEqualsUserFields(expected, actual) +} + +func compareUserPermissions(expected *dataprovider.User, actual *dataprovider.User) error { if len(expected.Permissions) != len(actual.Permissions) { return errors.New("permissions mismatch") } @@ -1175,16 +1195,7 @@ func checkUser(expected *dataprovider.User, actual *dataprovider.User) error { return errors.New("permissions directories mismatch") } } - if err := compareUserFilters(expected, actual); err != nil { - return err - } - if err := compareFsConfig(&expected.FsConfig, &actual.FsConfig); err != nil { - return err - } - if err := compareUserVirtualFolders(expected, actual); err != nil { - return err - } - return compareEqualsUserFields(expected, actual) + return nil } func compareUserVirtualFolders(expected *dataprovider.User, actual *dataprovider.User) error { diff --git a/sdk/user.go b/sdk/user.go index fa568554..8db50e0b 100644 --- a/sdk/user.go +++ b/sdk/user.go @@ -178,6 +178,8 @@ type BaseUser struct { Status int `json:"status"` // Username Username string `json:"username"` + // Email + Email string `json:"email,omitempty"` // Account expiration date as unix timestamp in milliseconds. An expired account cannot login. // 0 means no expiration ExpirationDate int64 `json:"expiration_date"` diff --git a/templates/webadmin/user.html b/templates/webadmin/user.html index ae38c47c..5edf3693 100644 --- a/templates/webadmin/user.html +++ b/templates/webadmin/user.html @@ -86,6 +86,14 @@ {{end}} +
+ +
+ +
+
+