diff --git a/server/pkg/controller/embedding/controller.go b/server/pkg/controller/embedding/controller.go index fb79fd36e..a217fb5f8 100644 --- a/server/pkg/controller/embedding/controller.go +++ b/server/pkg/controller/embedding/controller.go @@ -53,9 +53,16 @@ type Controller struct { HostName string cleanupCronRunning bool derivedStorageDataCenter string + downloadManagerCache map[string]*s3manager.Downloader } func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller { + embeddingDcs := []string{s3Config.GetHotBackblazeDC(), s3Config.GetHotWasabiDC(), s3Config.GetDerivedStorageDataCenter()} + cache := make(map[string]*s3manager.Downloader, len(embeddingDcs)) + for i := range embeddingDcs { + s3Client := s3Config.GetS3Client(embeddingDcs[i]) + cache[embeddingDcs[i]] = s3manager.NewDownloaderWithClient(&s3Client) + } return &Controller{ Repo: repo, AccessCtrl: accessCtrl, @@ -67,6 +74,7 @@ func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanup CollectionRepo: collectionRepo, HostName: hostName, derivedStorageDataCenter: s3Config.GetDerivedStorageDataCenter(), + downloadManagerCache: cache, } } @@ -136,7 +144,7 @@ func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest) // Fetch missing embeddings in parallel if len(objectKeys) > 0 { - embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys) + embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys, c.derivedStorageDataCenter) if err != nil { return nil, stacktrace.Propagate(err, "") } @@ -182,7 +190,7 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbedd errFileIds := make([]int64, 0) // Fetch missing userFileEmbeddings in parallel - embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData) + embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData, c.derivedStorageDataCenter) if err != nil { return nil, stacktrace.Propagate(err, "") } @@ -245,13 +253,10 @@ var globalDiffFetchSemaphore = make(chan struct{}, 300) var globalFileFetchSemaphore = make(chan struct{}, 400) -func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.EmbeddingObject, error) { +func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string, dc string) ([]ente.EmbeddingObject, error) { var wg sync.WaitGroup var errs []error embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys)) - s3Client := c.S3Config.GetS3Client(c.derivedStorageDataCenter) - downloader := s3manager.NewDownloaderWithClient(&s3Client) - for i, objectKey := range objectKeys { wg.Add(1) globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore @@ -259,7 +264,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em defer wg.Done() defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore - obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader) + obj, err := c.getEmbeddingObject(context.Background(), objectKey, dc) if err != nil { errs = append(errs, err) log.Error("error fetching embedding object: "+objectKey, err) @@ -284,11 +289,9 @@ type embeddingObjectResult struct { err error } -func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) { +func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding, dc string) ([]embeddingObjectResult, error) { var wg sync.WaitGroup embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows)) - s3Client := c.S3Config.GetS3Client(c.derivedStorageDataCenter) - downloader := s3manager.NewDownloaderWithClient(&s3Client) for i, dbEmbeddingRow := range dbEmbeddingRows { wg.Add(1) @@ -297,7 +300,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows defer wg.Done() defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model) - obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader) + obj, err := c.getEmbeddingObject(context.Background(), objectKey, dc) if err != nil { log.Error("error fetching embedding object: "+objectKey, err) embeddingObjects[i] = embeddingObjectResult{ @@ -317,7 +320,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows return embeddingObjects, nil } -func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) { +func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) { opt := _defaultFetchConfig ctxLogger := log.WithField("objectKey", objectKey) totalAttempts := opt.RetryCount + 1 @@ -329,7 +332,7 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d cancel() return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "") default: - obj, err := c.downloadObject(fetchCtx, objectKey, downloader, c.derivedStorageDataCenter) + obj, err := c.downloadObject(fetchCtx, objectKey, dc) cancel() // Ensure cancel is called to release resources if err == nil { if i > 0 { @@ -367,10 +370,11 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "") } -func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, dc string) (ente.EmbeddingObject, error) { +func (c *Controller) downloadObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) { var obj ente.EmbeddingObject buff := &aws.WriteAtBuffer{} bucket := c.S3Config.GetBucket(dc) + downloader := c.downloadManagerCache[dc] _, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{ Bucket: bucket, Key: &objectKey, @@ -390,8 +394,7 @@ func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() { return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "") } - downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client()) - obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotBackblazeDC()) + obj, err := c.downloadObject(ctx, objectKey, c.S3Config.GetHotBackblazeDC()) if err != nil { return nil, stacktrace.Propagate(err, "failed to download from hot bucket") }