diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index 8be76120d..fc2300d93 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -194,7 +194,7 @@ func main() { commonBillController := commonbilling.NewController(storagBonusRepo, userRepo, usageRepo) appStoreController := controller.NewAppStoreController(defaultPlan, billingRepo, fileRepo, userRepo, commonBillController) - + remoteStoreController := &remoteStoreCtrl.Controller{Repo: remoteStoreRepository} playStoreController := controller.NewPlayStoreController(defaultPlan, billingRepo, fileRepo, userRepo, storagBonusRepo, commonBillController) stripeController := controller.NewStripeController(plans, stripeClients, @@ -610,6 +610,7 @@ func main() { UserAuthRepo: userAuthRepo, UserController: userController, FamilyController: familyController, + RemoteStoreController: remoteStoreController, FileRepo: fileRepo, StorageBonusRepo: storagBonusRepo, BillingRepo: billingRepo, @@ -631,6 +632,7 @@ func main() { adminAPI.PUT("/user/change-email", adminHandler.ChangeEmail) adminAPI.DELETE("/user/delete", adminHandler.DeleteUser) adminAPI.POST("/user/recover", adminHandler.RecoverAccount) + adminAPI.POST("/user/update-flag", adminHandler.UpdateFeatureFlag) adminAPI.GET("/email-hash", adminHandler.GetEmailHash) adminAPI.POST("/emails-from-hashes", adminHandler.GetEmailsFromHashes) adminAPI.PUT("/user/subscription", adminHandler.UpdateSubscription) @@ -658,7 +660,6 @@ func main() { privateAPI.DELETE("/authenticator/entity", authenticatorHandler.DeleteEntity) privateAPI.GET("/authenticator/entity/diff", authenticatorHandler.GetDiff) - remoteStoreController := &remoteStoreCtrl.Controller{Repo: remoteStoreRepository} dataCleanupController := &dataCleanupCtrl.DeleteUserCleanupController{ Repo: dataCleanupRepository, UserRepo: userRepo, @@ -672,6 +673,7 @@ func main() { privateAPI.POST("/remote-store/update", remoteStoreHandler.InsertOrUpdate) privateAPI.GET("/remote-store", remoteStoreHandler.GetKey) + privateAPI.GET("/remote-store/feature-flags", remoteStoreHandler.GetFeatureFlags) pushHandler := &api.PushHandler{PushController: pushController} privateAPI.POST("/push/token", pushHandler.AddToken) diff --git a/server/ente/remotestore.go b/server/ente/remotestore.go index 02eb93232..8f518f2a1 100644 --- a/server/ente/remotestore.go +++ b/server/ente/remotestore.go @@ -13,3 +13,66 @@ type UpdateKeyValueRequest struct { Key string `json:"key" binding:"required"` Value string `json:"value" binding:"required"` } + +type AdminUpdateKeyValueRequest struct { + UserID int64 `json:"userID" binding:"required"` + Key string `json:"key" binding:"required"` + Value string `json:"value" binding:"required"` +} + +type FeatureFlagResponse struct { + EnableStripe bool `json:"enableStripe"` + // If true, the mobile client will stop using CF worker to download files + DisableCFWorker bool `json:"disableCFWorker"` + MapEnabled bool `json:"mapEnabled"` + FaceSearchEnabled bool `json:"faceSearchEnabled"` + PassKeyEnabled bool `json:"passKeyEnabled"` + RecoveryKeyVerified bool `json:"recoveryKeyVerified"` + InternalUser bool `json:"internalUser"` + BetaUser bool `json:"betaUser"` +} + +type FlagKey string + +const ( + RecoveryKeyVerified FlagKey = "recoveryKeyVerified" + MapEnabled FlagKey = "mapEnabled" + FaceSearchEnabled FlagKey = "faceSearchEnabled" + PassKeyEnabled FlagKey = "passKeyEnabled" + IsInternalUser FlagKey = "internalUser" + IsBetaUser FlagKey = "betaUser" +) + +func (k FlagKey) String() string { + return string(k) +} + +// UserEditable returns true if the key is user editable +func (k FlagKey) UserEditable() bool { + switch k { + case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled, PassKeyEnabled: + return true + default: + return false + } +} + +func (k FlagKey) IsAdminEditable() bool { + switch k { + case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled: + return false + case IsInternalUser, IsBetaUser, PassKeyEnabled: + return true + default: + return true + } +} + +func (k FlagKey) IsBoolType() bool { + switch k { + case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled, PassKeyEnabled, IsInternalUser, IsBetaUser: + return true + default: + return false + } +} diff --git a/server/pkg/api/admin.go b/server/pkg/api/admin.go index b153e19bb..0b6ac18ef 100644 --- a/server/pkg/api/admin.go +++ b/server/pkg/api/admin.go @@ -3,6 +3,7 @@ package api import ( "errors" "fmt" + "github.com/ente-io/museum/pkg/controller/remotestore" "net/http" "strconv" "strings" @@ -43,6 +44,7 @@ type AdminHandler struct { BillingController *controller.BillingController UserController *user.UserController FamilyController *family.Controller + RemoteStoreController *remotestore.Controller ObjectCleanupController *controller.ObjectCleanupController MailingListsController *controller.MailingListsController DiscordController *discord.DiscordController @@ -260,6 +262,32 @@ func (h *AdminHandler) RemovePasskeys(c *gin.Context) { c.JSON(http.StatusOK, gin.H{}) } +func (h *AdminHandler) UpdateFeatureFlag(c *gin.Context) { + var request ente.AdminUpdateKeyValueRequest + if err := c.ShouldBindJSON(&request); err != nil { + handler.Error(c, stacktrace.Propagate(ente.ErrBadRequest, "Bad request")) + return + } + go h.DiscordController.NotifyAdminAction( + fmt.Sprintf("Admin (%d) updating flag:%s to val:%s for %d", auth.GetUserID(c.Request.Header), request.Key, request.Value, request.UserID)) + + logger := logrus.WithFields(logrus.Fields{ + "user_id": request.UserID, + "admin_id": auth.GetUserID(c.Request.Header), + "req_id": requestid.Get(c), + "req_ctx": "update_feature_flag", + }) + logger.Info("Start update") + err := h.RemoteStoreController.AdminInsertOrUpdate(c, request) + if err != nil { + logger.WithError(err).Error("Failed to update flag") + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + logger.Info("successfully updated flag") + c.JSON(http.StatusOK, gin.H{}) +} + func (h *AdminHandler) CloseFamily(c *gin.Context) { var request ente.AdminOpsForUserRequest diff --git a/server/pkg/api/remotestore.go b/server/pkg/api/remotestore.go index ea6e621a3..9f03554de 100644 --- a/server/pkg/api/remotestore.go +++ b/server/pkg/api/remotestore.go @@ -49,3 +49,13 @@ func (h *RemoteStoreHandler) GetKey(c *gin.Context) { } c.JSON(http.StatusOK, resp) } + +// GetFeatureFlags returns all the feature flags and value for given user +func (h *RemoteStoreHandler) GetFeatureFlags(c *gin.Context) { + resp, err := h.Controller.GetFeatureFlags(c) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "failed to get feature flags")) + return + } + c.JSON(http.StatusOK, resp) +} diff --git a/server/pkg/controller/remotestore/controller.go b/server/pkg/controller/remotestore/controller.go index d41bf7e5f..bf8e4acfc 100644 --- a/server/pkg/controller/remotestore/controller.go +++ b/server/pkg/controller/remotestore/controller.go @@ -3,6 +3,7 @@ package remotestore import ( "database/sql" "errors" + "fmt" "github.com/ente-io/museum/ente" "github.com/ente-io/museum/pkg/repo/remotestore" @@ -16,12 +17,22 @@ type Controller struct { Repo *remotestore.Repository } -// Insert of update the key's value +// InsertOrUpdate the key's value func (c *Controller) InsertOrUpdate(ctx *gin.Context, request ente.UpdateKeyValueRequest) error { + if err := _validateRequest(request.Key, request.Value, false); err != nil { + return err + } userID := auth.GetUserID(ctx.Request.Header) return c.Repo.InsertOrUpdate(ctx, userID, request.Key, request.Value) } +func (c *Controller) AdminInsertOrUpdate(ctx *gin.Context, request ente.AdminUpdateKeyValueRequest) error { + if err := _validateRequest(request.Key, request.Value, true); err != nil { + return err + } + return c.Repo.InsertOrUpdate(ctx, request.UserID, request.Key, request.Value) +} + func (c *Controller) Get(ctx *gin.Context, req ente.GetValueRequest) (*ente.GetValueResponse, error) { userID := auth.GetUserID(ctx.Request.Header) value, err := c.Repo.GetValue(ctx, userID, req.Key) @@ -34,3 +45,50 @@ func (c *Controller) Get(ctx *gin.Context, req ente.GetValueRequest) (*ente.GetV } return &ente.GetValueResponse{Value: value}, nil } + +func (c *Controller) GetFeatureFlags(ctx *gin.Context) (*ente.FeatureFlagResponse, error) { + userID := auth.GetUserID(ctx.Request.Header) + values, err := c.Repo.GetAllValues(ctx, userID) + if err != nil { + return nil, stacktrace.Propagate(err, "") + } + response := &ente.FeatureFlagResponse{ + EnableStripe: true, // enable stripe for all + DisableCFWorker: false, + } + for key, value := range values { + flag := ente.FlagKey(key) + if !flag.IsBoolType() { + continue + } + switch flag { + case ente.RecoveryKeyVerified: + response.RecoveryKeyVerified = value == "true" + case ente.MapEnabled: + response.MapEnabled = value == "true" + case ente.FaceSearchEnabled: + response.FaceSearchEnabled = value == "true" + case ente.PassKeyEnabled: + response.PassKeyEnabled = value == "true" + case ente.IsInternalUser: + response.InternalUser = value == "true" + case ente.IsBetaUser: + response.BetaUser = value == "true" + } + } + return response, nil +} + +func _validateRequest(key, value string, byAdmin bool) error { + flag := ente.FlagKey(key) + if !flag.UserEditable() && !byAdmin { + return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not user editable", key)), "key not user editable") + } + if byAdmin && !flag.IsAdminEditable() { + return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not admin editable", key)), "key not admin editable") + } + if flag.IsBoolType() && value != "true" && value != "false" { + return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("value %s is not allowed", value)), "value not allowed") + } + return nil +} diff --git a/server/pkg/repo/remotestore/repository.go b/server/pkg/repo/remotestore/repository.go index dc54b0cfc..2548f4901 100644 --- a/server/pkg/repo/remotestore/repository.go +++ b/server/pkg/repo/remotestore/repository.go @@ -13,7 +13,6 @@ type Repository struct { DB *sql.DB } -// func (r *Repository) InsertOrUpdate(ctx context.Context, userID int64, key string, value string) error { _, err := r.DB.ExecContext(ctx, `INSERT INTO remote_store(user_id, key_name, key_value) VALUES ($1,$2,$3) ON CONFLICT (user_id, key_name) DO UPDATE SET key_value = $3; @@ -40,3 +39,25 @@ func (r *Repository) GetValue(ctx context.Context, userID int64, key string) (st } return keyValue, nil } + +// GetAllValues fetches and return all the key value pairs for given user_id +func (r *Repository) GetAllValues(ctx context.Context, userID int64) (map[string]string, error) { + rows, err := r.DB.QueryContext(ctx, `SELECT key_name, key_value FROM remote_store + WHERE user_id = $1`, + userID, // $1 + ) + if err != nil { + return nil, stacktrace.Propagate(err, "reading value failed") + } + defer rows.Close() + values := make(map[string]string) + for rows.Next() { + var key, value string + err := rows.Scan(&key, &value) + if err != nil { + return nil, stacktrace.Propagate(err, "reading value failed") + } + values[key] = value + } + return values, nil +}