diff --git a/server/pkg/controller/embedding/controller.go b/server/pkg/controller/embedding/controller.go index 6f0f6e69a..3fff6b568 100644 --- a/server/pkg/controller/embedding/controller.go +++ b/server/pkg/controller/embedding/controller.go @@ -45,7 +45,7 @@ type Controller struct { HostName string cleanupCronRunning bool derivedStorageS3Client *s3.S3 - derivedStorageBucket *string + derivedStorageDataCenter string areDerivedAndHotBucketSame bool } @@ -61,8 +61,8 @@ func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanup CollectionRepo: collectionRepo, HostName: hostName, derivedStorageS3Client: s3Config.GetDerivedStorageS3Client(), - derivedStorageBucket: s3Config.GetDerivedStorageBucket(), - areDerivedAndHotBucketSame: s3Config.GetDerivedStorageBucket() == s3Config.GetHotBucket(), + derivedStorageDataCenter: s3Config.GetDerivedStorageDataCenter(), + areDerivedAndHotBucketSame: s3Config.GetDerivedStorageDataCenter() == s3Config.GetDerivedStorageDataCenter(), } } @@ -96,12 +96,12 @@ func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmb DecryptionHeader: req.DecryptionHeader, Client: network.GetPrettyUA(ctx.GetHeader("User-Agent")) + "/" + ctx.GetHeader("X-Client-Version"), } - size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model)) + size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model), c.derivedStorageDataCenter) if uploadErr != nil { log.Error(uploadErr) return nil, stacktrace.Propagate(uploadErr, "") } - embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.S3Config.GetDerivedStorageDataCenter()) + embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.derivedStorageDataCenter) embedding.Version = &version if err != nil { return nil, stacktrace.Propagate(err, "") @@ -217,11 +217,13 @@ func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string } // uploadObject uploads the embedding object to the object store and returns the object size -func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) { +func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) { embeddingObj, _ := json.Marshal(obj) - uploader := s3manager.NewUploaderWithClient(c.derivedStorageS3Client) + s3Client := c.S3Config.GetS3Client(dc) + s3Bucket := c.S3Config.GetBucket(dc) + uploader := s3manager.NewUploaderWithClient(&s3Client) up := s3manager.UploadInput{ - Bucket: c.derivedStorageBucket, + Bucket: s3Bucket, Key: &key, Body: bytes.NewReader(embeddingObj), } @@ -331,7 +333,7 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d cancel() return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "") default: - obj, err := c.downloadObject(fetchCtx, objectKey, downloader, c.derivedStorageBucket) + obj, err := c.downloadObject(fetchCtx, objectKey, downloader, c.derivedStorageDataCenter) cancel() // Ensure cancel is called to release resources if err == nil { if i > 0 { @@ -368,9 +370,10 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "") } -func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, bucket *string) (ente.EmbeddingObject, error) { +func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, dc string) (ente.EmbeddingObject, error) { var obj ente.EmbeddingObject buff := &aws.WriteAtBuffer{} + bucket := c.S3Config.GetBucket(dc) _, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{ Bucket: bucket, Key: &objectKey, @@ -387,16 +390,16 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, downl // download the embedding object from hot bucket and upload to embeddings bucket func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) { - if c.derivedStorageBucket == c.S3Config.GetHotBucket() { - return nil, stacktrace.Propagate(errors.New("embedding bucket and hot bucket are same"), "") + if c.derivedStorageDataCenter == c.S3Config.GetHotDataCenter() { + return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "") } downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client()) - obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotBucket()) + obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotDataCenter()) if err != nil { return nil, stacktrace.Propagate(err, "failed to download from hot bucket") } go func() { - _, err = c.uploadObject(obj, objectKey) + _, err = c.uploadObject(obj, objectKey, c.derivedStorageDataCenter) if err != nil { log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", err) }