|
@@ -45,7 +45,7 @@ type Controller struct {
|
|
|
HostName string
|
|
|
cleanupCronRunning bool
|
|
|
derivedStorageS3Client *s3.S3
|
|
|
- derivedStorageBucket *string
|
|
|
+ derivedStorageDataCenter string
|
|
|
areDerivedAndHotBucketSame bool
|
|
|
}
|
|
|
|
|
@@ -61,8 +61,8 @@ func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanup
|
|
|
CollectionRepo: collectionRepo,
|
|
|
HostName: hostName,
|
|
|
derivedStorageS3Client: s3Config.GetDerivedStorageS3Client(),
|
|
|
- derivedStorageBucket: s3Config.GetDerivedStorageBucket(),
|
|
|
- areDerivedAndHotBucketSame: s3Config.GetDerivedStorageBucket() == s3Config.GetHotBucket(),
|
|
|
+ derivedStorageDataCenter: s3Config.GetDerivedStorageDataCenter(),
|
|
|
+ areDerivedAndHotBucketSame: s3Config.GetDerivedStorageDataCenter() == s3Config.GetDerivedStorageDataCenter(),
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -96,12 +96,12 @@ func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmb
|
|
|
DecryptionHeader: req.DecryptionHeader,
|
|
|
Client: network.GetPrettyUA(ctx.GetHeader("User-Agent")) + "/" + ctx.GetHeader("X-Client-Version"),
|
|
|
}
|
|
|
- size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model))
|
|
|
+ size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model), c.derivedStorageDataCenter)
|
|
|
if uploadErr != nil {
|
|
|
log.Error(uploadErr)
|
|
|
return nil, stacktrace.Propagate(uploadErr, "")
|
|
|
}
|
|
|
- embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.S3Config.GetDerivedStorageDataCenter())
|
|
|
+ embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.derivedStorageDataCenter)
|
|
|
embedding.Version = &version
|
|
|
if err != nil {
|
|
|
return nil, stacktrace.Propagate(err, "")
|
|
@@ -217,11 +217,13 @@ func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string
|
|
|
}
|
|
|
|
|
|
// uploadObject uploads the embedding object to the object store and returns the object size
|
|
|
-func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) {
|
|
|
+func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) {
|
|
|
embeddingObj, _ := json.Marshal(obj)
|
|
|
- uploader := s3manager.NewUploaderWithClient(c.derivedStorageS3Client)
|
|
|
+ s3Client := c.S3Config.GetS3Client(dc)
|
|
|
+ s3Bucket := c.S3Config.GetBucket(dc)
|
|
|
+ uploader := s3manager.NewUploaderWithClient(&s3Client)
|
|
|
up := s3manager.UploadInput{
|
|
|
- Bucket: c.derivedStorageBucket,
|
|
|
+ Bucket: s3Bucket,
|
|
|
Key: &key,
|
|
|
Body: bytes.NewReader(embeddingObj),
|
|
|
}
|
|
@@ -331,7 +333,7 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
|
|
|
cancel()
|
|
|
return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
|
|
|
default:
|
|
|
- obj, err := c.downloadObject(fetchCtx, objectKey, downloader, c.derivedStorageBucket)
|
|
|
+ obj, err := c.downloadObject(fetchCtx, objectKey, downloader, c.derivedStorageDataCenter)
|
|
|
cancel() // Ensure cancel is called to release resources
|
|
|
if err == nil {
|
|
|
if i > 0 {
|
|
@@ -368,9 +370,10 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
|
|
|
return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
|
|
|
}
|
|
|
|
|
|
-func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, bucket *string) (ente.EmbeddingObject, error) {
|
|
|
+func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, dc string) (ente.EmbeddingObject, error) {
|
|
|
var obj ente.EmbeddingObject
|
|
|
buff := &aws.WriteAtBuffer{}
|
|
|
+ bucket := c.S3Config.GetBucket(dc)
|
|
|
_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
|
|
|
Bucket: bucket,
|
|
|
Key: &objectKey,
|
|
@@ -387,16 +390,16 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, downl
|
|
|
|
|
|
// download the embedding object from hot bucket and upload to embeddings bucket
|
|
|
func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
|
|
|
- if c.derivedStorageBucket == c.S3Config.GetHotBucket() {
|
|
|
- return nil, stacktrace.Propagate(errors.New("embedding bucket and hot bucket are same"), "")
|
|
|
+ if c.derivedStorageDataCenter == c.S3Config.GetHotDataCenter() {
|
|
|
+ return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "")
|
|
|
}
|
|
|
downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
|
|
|
- obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotBucket())
|
|
|
+ obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotDataCenter())
|
|
|
if err != nil {
|
|
|
return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
|
|
|
}
|
|
|
go func() {
|
|
|
- _, err = c.uploadObject(obj, objectKey)
|
|
|
+ _, err = c.uploadObject(obj, objectKey, c.derivedStorageDataCenter)
|
|
|
if err != nil {
|
|
|
log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", err)
|
|
|
}
|