Neeraj Gupta vor 1 Jahr
Ursprung
Commit
e0738db6ae
1 geänderte Dateien mit 17 neuen und 14 gelöschten Zeilen
  1. 17 14
      server/pkg/controller/embedding/controller.go

+ 17 - 14
server/pkg/controller/embedding/controller.go

@@ -45,7 +45,7 @@ type Controller struct {
 	HostName                   string
 	cleanupCronRunning         bool
 	derivedStorageS3Client     *s3.S3
-	derivedStorageBucket       *string
+	derivedStorageDataCenter   string
 	areDerivedAndHotBucketSame bool
 }
 
@@ -61,8 +61,8 @@ func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanup
 		CollectionRepo:             collectionRepo,
 		HostName:                   hostName,
 		derivedStorageS3Client:     s3Config.GetDerivedStorageS3Client(),
-		derivedStorageBucket:       s3Config.GetDerivedStorageBucket(),
-		areDerivedAndHotBucketSame: s3Config.GetDerivedStorageBucket() == s3Config.GetHotBucket(),
+		derivedStorageDataCenter:   s3Config.GetDerivedStorageDataCenter(),
+		areDerivedAndHotBucketSame: s3Config.GetDerivedStorageDataCenter() == s3Config.GetDerivedStorageDataCenter(),
 	}
 }
 
@@ -96,12 +96,12 @@ func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmb
 		DecryptionHeader:   req.DecryptionHeader,
 		Client:             network.GetPrettyUA(ctx.GetHeader("User-Agent")) + "/" + ctx.GetHeader("X-Client-Version"),
 	}
-	size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model))
+	size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model), c.derivedStorageDataCenter)
 	if uploadErr != nil {
 		log.Error(uploadErr)
 		return nil, stacktrace.Propagate(uploadErr, "")
 	}
-	embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.S3Config.GetDerivedStorageDataCenter())
+	embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.derivedStorageDataCenter)
 	embedding.Version = &version
 	if err != nil {
 		return nil, stacktrace.Propagate(err, "")
@@ -217,11 +217,13 @@ func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string
 }
 
 // uploadObject uploads the embedding object to the object store and returns the object size
-func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) {
+func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) {
 	embeddingObj, _ := json.Marshal(obj)
-	uploader := s3manager.NewUploaderWithClient(c.derivedStorageS3Client)
+	s3Client := c.S3Config.GetS3Client(dc)
+	s3Bucket := c.S3Config.GetBucket(dc)
+	uploader := s3manager.NewUploaderWithClient(&s3Client)
 	up := s3manager.UploadInput{
-		Bucket: c.derivedStorageBucket,
+		Bucket: s3Bucket,
 		Key:    &key,
 		Body:   bytes.NewReader(embeddingObj),
 	}
@@ -331,7 +333,7 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
 			cancel()
 			return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
 		default:
-			obj, err := c.downloadObject(fetchCtx, objectKey, downloader, c.derivedStorageBucket)
+			obj, err := c.downloadObject(fetchCtx, objectKey, downloader, c.derivedStorageDataCenter)
 			cancel() // Ensure cancel is called to release resources
 			if err == nil {
 				if i > 0 {
@@ -368,9 +370,10 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
 	return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
 }
 
-func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, bucket *string) (ente.EmbeddingObject, error) {
+func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, dc string) (ente.EmbeddingObject, error) {
 	var obj ente.EmbeddingObject
 	buff := &aws.WriteAtBuffer{}
+	bucket := c.S3Config.GetBucket(dc)
 	_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
 		Bucket: bucket,
 		Key:    &objectKey,
@@ -387,16 +390,16 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, downl
 
 // download the embedding object from hot bucket and upload to embeddings bucket
 func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
-	if c.derivedStorageBucket == c.S3Config.GetHotBucket() {
-		return nil, stacktrace.Propagate(errors.New("embedding bucket and hot bucket are same"), "")
+	if c.derivedStorageDataCenter == c.S3Config.GetHotDataCenter() {
+		return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "")
 	}
 	downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
-	obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotBucket())
+	obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotDataCenter())
 	if err != nil {
 		return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
 	}
 	go func() {
-		_, err = c.uploadObject(obj, objectKey)
+		_, err = c.uploadObject(obj, objectKey, c.derivedStorageDataCenter)
 		if err != nil {
 			log.WithField("object", objectKey).Error("Failed to copy  to embeddings bucket: ", err)
 		}