diff --git a/server/pkg/controller/embedding/controller.go b/server/pkg/controller/embedding/controller.go index a217fb5f8..df5301904 100644 --- a/server/pkg/controller/embedding/controller.go +++ b/server/pkg/controller/embedding/controller.go @@ -5,9 +5,11 @@ import ( "context" "encoding/json" "errors" + "fmt" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/ente-io/museum/pkg/utils/array" "strconv" + "strings" "sync" gTime "time" @@ -228,6 +230,15 @@ func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string return strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/" } +// Get userId, model and fileID from the object key +func (c *Controller) getEmbeddingObjectDetails(objectKey string) (userID int64, model string, fileID int64) { + split := strings.Split(objectKey, "/") + userID, _ = strconv.ParseInt(split[0], 10, 64) + fileID, _ = strconv.ParseInt(split[2], 10, 64) + model = strings.Split(split[3], ".")[0] + return userID, model, fileID +} + // uploadObject uploads the embedding object to the object store and returns the object size func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) { embeddingObj, _ := json.Marshal(obj) @@ -352,7 +363,7 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "") } else { // If derived and hot bucket are different, try to copy from hot bucket - copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey) + copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey, c.S3Config.GetHotBackblazeDC(), c.derivedStorageDataCenter) if err == nil { ctxLogger.Info("Got the object from hot bucket object") return *copyEmbeddingObject, nil @@ -390,21 +401,26 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, dc st } // download the embedding object from hot bucket and upload to embeddings bucket -func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) { - if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() { - return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "") +func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string, srcDC, destDC string) (*ente.EmbeddingObject, error) { + if srcDC == destDC { + return nil, stacktrace.Propagate(errors.New("src and dest dc can not be same"), "") } - obj, err := c.downloadObject(ctx, objectKey, c.S3Config.GetHotBackblazeDC()) + obj, err := c.downloadObject(ctx, objectKey, srcDC) if err != nil { - return nil, stacktrace.Propagate(err, "failed to download from hot bucket") + return nil, stacktrace.Propagate(err, fmt.Sprintf("failed to download object from %s", srcDC)) } go func() { - _, err = c.uploadObject(obj, objectKey, c.derivedStorageDataCenter) - if err != nil { - log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", err) + userID, modelName, fileID := c.getEmbeddingObjectDetails(objectKey) + size, uploadErr := c.uploadObject(obj, objectKey, c.derivedStorageDataCenter) + if uploadErr != nil { + log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", uploadErr) + } + updateDcErr := c.Repo.AddNewDC(context.Background(), fileID, ente.Model(modelName), userID, size, destDC) + if updateDcErr != nil { + log.WithField("object", objectKey).Error("Failed to update dc in db: ", updateDcErr) + return } }() - return &obj, nil } diff --git a/server/pkg/repo/embedding/repository.go b/server/pkg/repo/embedding/repository.go index d2e73145a..1288512b3 100644 --- a/server/pkg/repo/embedding/repository.go +++ b/server/pkg/repo/embedding/repository.go @@ -3,6 +3,7 @@ package embedding import ( "context" "database/sql" + "errors" "fmt" "github.com/lib/pq" @@ -126,6 +127,22 @@ func (r *Repository) RemoveDatacenter(ctx context.Context, fileID int64, dc stri return nil } +// AddNewDC adds the dc name to the list of datacenters, if it doesn't exist already, for a given file, model and user. It also updates the size of the embedding +func (r *Repository) AddNewDC(ctx context.Context, fileID int64, model ente.Model, userID int64, size int, dc string) error { + res, err := r.DB.ExecContext(ctx, `UPDATE embeddings SET size = $1, datacenters = array_append(COALESCE(datacenters, ARRAY[]::s3region[]), $2::s3region) WHERE file_id = $3 AND model = $4 AND owner_id = $5`, size, dc, fileID, model, userID) + if err != nil { + return stacktrace.Propagate(err, "") + } + rowsAffected, err := res.RowsAffected() + if err != nil { + return stacktrace.Propagate(err, "") + } + if rowsAffected == 0 { + return stacktrace.Propagate(errors.New("no row got updated"), "") + } + return nil +} + func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) { defer func() { if err := rows.Close(); err != nil {