فهرست منبع

Update dc while copying derived file

Neeraj Gupta 1 سال پیش
والد
کامیت
b404b77da3
2 فایل‌های تغییر یافته به همراه 43 افزوده شده و 10 حذف شده
  1. 26 10
      server/pkg/controller/embedding/controller.go
  2. 17 0
      server/pkg/repo/embedding/repository.go

+ 26 - 10
server/pkg/controller/embedding/controller.go

@@ -5,9 +5,11 @@ import (
 	"context"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"github.com/aws/aws-sdk-go/aws/awserr"
 	"github.com/ente-io/museum/pkg/utils/array"
 	"strconv"
+	"strings"
 	"sync"
 	gTime "time"
 
@@ -228,6 +230,15 @@ func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string
 	return strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/"
 }
 
+// getEmbeddingObjectDetails extracts the user ID, model name and file ID from
+// an embedding object key. It reads the user ID from the first "/"-separated
+// segment, the file ID from the third, and the model from the fourth segment
+// up to its first "." (i.e. "<userID>/<x>/<fileID>/<model>.<ext>").
+//
+// A key with fewer than four segments yields zero values (0, "", 0) rather
+// than panicking on an out-of-range index.
+func (c *Controller) getEmbeddingObjectDetails(objectKey string) (userID int64, model string, fileID int64) {
+	split := strings.Split(objectKey, "/")
+	if len(split) < 4 {
+		// Malformed key: indexing split[2]/split[3] below would panic.
+		return 0, "", 0
+	}
+	// Parse errors are intentionally ignored here because the signature has no
+	// error return; a non-numeric segment leaves whatever ParseInt returned.
+	userID, _ = strconv.ParseInt(split[0], 10, 64)
+	fileID, _ = strconv.ParseInt(split[2], 10, 64)
+	// strings.Split always returns at least one element, so [0] is safe.
+	model = strings.Split(split[3], ".")[0]
+	return userID, model, fileID
+}
+
 // uploadObject uploads the embedding object to the object store and returns the object size
 func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) {
 	embeddingObj, _ := json.Marshal(obj)
@@ -352,7 +363,7 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
 							return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
 						} else {
 							// If derived and hot bucket are different, try to copy from hot bucket
-							copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey)
+							copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey, c.S3Config.GetHotBackblazeDC(), c.derivedStorageDataCenter)
 							if err == nil {
 								ctxLogger.Info("Got the object from hot bucket object")
 								return *copyEmbeddingObject, nil
@@ -390,21 +401,26 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, dc st
 }
 
 // download the embedding object from hot bucket and upload to embeddings bucket
-func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
-	if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() {
-		return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "")
+func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string, srcDC, destDC string) (*ente.EmbeddingObject, error) {
+	if srcDC == destDC {
+		return nil, stacktrace.Propagate(errors.New("src and dest dc can not be same"), "")
 	}
-	obj, err := c.downloadObject(ctx, objectKey, c.S3Config.GetHotBackblazeDC())
+	obj, err := c.downloadObject(ctx, objectKey, srcDC)
 	if err != nil {
-		return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
+		return nil, stacktrace.Propagate(err, fmt.Sprintf("failed to download object from %s", srcDC))
 	}
 	go func() {
-		_, err = c.uploadObject(obj, objectKey, c.derivedStorageDataCenter)
-		if err != nil {
-			log.WithField("object", objectKey).Error("Failed to copy  to embeddings bucket: ", err)
+		userID, modelName, fileID := c.getEmbeddingObjectDetails(objectKey)
+		size, uploadErr := c.uploadObject(obj, objectKey, c.derivedStorageDataCenter)
+		if uploadErr != nil {
+			log.WithField("object", objectKey).Error("Failed to copy  to embeddings bucket: ", uploadErr)
+		}
+		updateDcErr := c.Repo.AddNewDC(context.Background(), fileID, ente.Model(modelName), userID, size, destDC)
+		if updateDcErr != nil {
+			log.WithField("object", objectKey).Error("Failed to update dc in db: ", updateDcErr)
+			return
 		}
 	}()
-
 	return &obj, nil
 }
 

+ 17 - 0
server/pkg/repo/embedding/repository.go

@@ -3,6 +3,7 @@ package embedding
 import (
 	"context"
 	"database/sql"
+	"errors"
 	"fmt"
 	"github.com/lib/pq"
 
@@ -126,6 +127,22 @@ func (r *Repository) RemoveDatacenter(ctx context.Context, fileID int64, dc stri
 	return nil
 }
 
+// AddNewDC adds the dc name to the list of datacenters, if it doesn't exist already, for a given file, model and user. It also updates the size of the embedding.
+//
+// It returns an error when the statement fails, when the affected-row count
+// cannot be read, or when no row matches (fileID, model, userID).
+func (r *Repository) AddNewDC(ctx context.Context, fileID int64, model ente.Model, userID int64, size int, dc string) error {
+	// The CASE leaves datacenters untouched when dc is already listed, so
+	// repeated copies do not accumulate duplicate entries. The row is still
+	// updated (size is always set), which keeps the rows-affected check valid.
+	res, err := r.DB.ExecContext(ctx, `UPDATE embeddings
+		SET size = $1,
+		    datacenters = CASE
+		        WHEN $2::s3region = ANY(COALESCE(datacenters, ARRAY[]::s3region[])) THEN datacenters
+		        ELSE array_append(COALESCE(datacenters, ARRAY[]::s3region[]), $2::s3region)
+		    END
+		WHERE file_id = $3 AND model = $4 AND owner_id = $5`, size, dc, fileID, model, userID)
+	if err != nil {
+		return stacktrace.Propagate(err, "")
+	}
+	rowsAffected, err := res.RowsAffected()
+	if err != nil {
+		return stacktrace.Propagate(err, "")
+	}
+	if rowsAffected == 0 {
+		// No embedding row exists for this file/model/owner combination.
+		return stacktrace.Propagate(errors.New("no  row got updated"), "")
+	}
+	return nil
+}
+
 func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) {
 	defer func() {
 		if err := rows.Close(); err != nil {