diff --git a/server/pkg/api/remotestore.go b/server/pkg/api/remotestore.go index ea6e621a3..9f03554de 100644 --- a/server/pkg/api/remotestore.go +++ b/server/pkg/api/remotestore.go @@ -49,3 +49,13 @@ func (h *RemoteStoreHandler) GetKey(c *gin.Context) { } c.JSON(http.StatusOK, resp) } + +// GetFeatureFlags returns all the feature flags and value for given user +func (h *RemoteStoreHandler) GetFeatureFlags(c *gin.Context) { + resp, err := h.Controller.GetFeatureFlags(c) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "failed to get feature flags")) + return + } + c.JSON(http.StatusOK, resp) +} diff --git a/server/pkg/controller/remotestore/controller.go b/server/pkg/controller/remotestore/controller.go index d774c742e..0a4e8cbb1 100644 --- a/server/pkg/controller/remotestore/controller.go +++ b/server/pkg/controller/remotestore/controller.go @@ -21,6 +21,15 @@ const ( PassKeyEnabled FlagKey = "passKeyEnabled" ) +func isBoolType(key FlagKey) bool { + switch key { + case RecoveryKeyVerified, MapEnabled, FaceSearchEnabled, PassKeyEnabled: + return true + default: + return false + } +} + var ( _allowKeys = map[FlagKey]*bool{ RecoveryKeyVerified: nil, @@ -57,11 +66,41 @@ func (c *Controller) Get(ctx *gin.Context, req ente.GetValueRequest) (*ente.GetV return &ente.GetValueResponse{Value: value}, nil } +func (c *Controller) GetFeatureFlags(ctx *gin.Context) (*ente.FeatureFlagResponse, error) { + userID := auth.GetUserID(ctx.Request.Header) + values, err := c.Repo.GetAllValues(ctx, userID) + if err != nil { + return nil, stacktrace.Propagate(err, "") + } + response := &ente.FeatureFlagResponse{ + EnableStripe: true, // enable stripe for all + DisableCFWorker: false, + } + for key, value := range values { + flag := FlagKey(key) + if !isBoolType(flag) { + continue + } + switch flag { + case RecoveryKeyVerified: + response.RestoreKeyVerified = value == "true" + case MapEnabled: + response.MapEnabled = value == "true" + case FaceSearchEnabled: + response.FaceSearchEnabled = value == "true" + case PassKeyEnabled: + response.PassKeyEnabled = value == "true" + } + } + return response, nil +} + func _validateRequest(request ente.UpdateKeyValueRequest) error { - if _, ok := _allowKeys[FlagKey(request.Key)]; !ok { + flag := FlagKey(request.Key) + if _, ok := _allowKeys[flag]; !ok { return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("key %s is not allowed", request.Key)), "key not allowed") } - if request.Value != "true" && request.Value != "false" { + if isBoolType(flag) && request.Value != "true" && request.Value != "false" { return stacktrace.Propagate(ente.NewBadRequestWithMessage(fmt.Sprintf("value %s is not allowed", request.Value)), "value not allowed") } return nil diff --git a/server/pkg/repo/remotestore/repository.go b/server/pkg/repo/remotestore/repository.go index dc54b0cfc..2548f4901 100644 --- a/server/pkg/repo/remotestore/repository.go +++ b/server/pkg/repo/remotestore/repository.go @@ -13,7 +13,6 @@ type Repository struct { DB *sql.DB } -// func (r *Repository) InsertOrUpdate(ctx context.Context, userID int64, key string, value string) error { _, err := r.DB.ExecContext(ctx, `INSERT INTO remote_store(user_id, key_name, key_value) VALUES ($1,$2,$3) ON CONFLICT (user_id, key_name) DO UPDATE SET key_value = $3; @@ -40,3 +39,25 @@ func (r *Repository) GetValue(ctx context.Context, userID int64, key string) (st } return keyValue, nil } + +// GetAllValues fetches and return all the key value pairs for given user_id +func (r *Repository) GetAllValues(ctx context.Context, userID int64) (map[string]string, error) { + rows, err := r.DB.QueryContext(ctx, `SELECT key_name, key_value FROM remote_store + WHERE user_id = $1`, + userID, // $1 + ) + if err != nil { + return nil, stacktrace.Propagate(err, "reading value failed") + } + defer rows.Close() + values := make(map[string]string) + for rows.Next() { + var key, value string + err := rows.Scan(&key, &value) + if err != nil { + return nil, stacktrace.Propagate(err, "reading value failed") + } + values[key] = value + } + return values, nil +}