Ver Fonte

Refactor

Neeraj Gupta há 1 ano atrás
pai
commit
a522631c2b
1 ficheiros alterados com 19 adições e 16 exclusões
  1. 19 16
      server/pkg/controller/embedding/controller.go

+ 19 - 16
server/pkg/controller/embedding/controller.go

@@ -53,9 +53,16 @@ type Controller struct {
 	HostName                 string
 	cleanupCronRunning       bool
 	derivedStorageDataCenter string
+	downloadManagerCache     map[string]*s3manager.Downloader
 }
 
 func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
+	embeddingDcs := []string{s3Config.GetHotBackblazeDC(), s3Config.GetHotWasabiDC(), s3Config.GetDerivedStorageDataCenter()}
+	cache := make(map[string]*s3manager.Downloader, len(embeddingDcs))
+	for i := range embeddingDcs {
+		s3Client := s3Config.GetS3Client(embeddingDcs[i])
+		cache[embeddingDcs[i]] = s3manager.NewDownloaderWithClient(&s3Client)
+	}
 	return &Controller{
 		Repo:                     repo,
 		AccessCtrl:               accessCtrl,
@@ -67,6 +74,7 @@ func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanup
 		CollectionRepo:           collectionRepo,
 		HostName:                 hostName,
 		derivedStorageDataCenter: s3Config.GetDerivedStorageDataCenter(),
+		downloadManagerCache:     cache,
 	}
 }
 
@@ -136,7 +144,7 @@ func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest)
 
 	// Fetch missing embeddings in parallel
 	if len(objectKeys) > 0 {
-		embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys)
+		embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys, c.derivedStorageDataCenter)
 		if err != nil {
 			return nil, stacktrace.Propagate(err, "")
 		}
@@ -182,7 +190,7 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbedd
 	errFileIds := make([]int64, 0)
 
 	// Fetch missing userFileEmbeddings in parallel
-	embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData)
+	embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData, c.derivedStorageDataCenter)
 	if err != nil {
 		return nil, stacktrace.Propagate(err, "")
 	}
@@ -245,13 +253,10 @@ var globalDiffFetchSemaphore = make(chan struct{}, 300)
 
 var globalFileFetchSemaphore = make(chan struct{}, 400)
 
-func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.EmbeddingObject, error) {
+func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string, dc string) ([]ente.EmbeddingObject, error) {
 	var wg sync.WaitGroup
 	var errs []error
 	embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
-	s3Client := c.S3Config.GetS3Client(c.derivedStorageDataCenter)
-	downloader := s3manager.NewDownloaderWithClient(&s3Client)
-
 	for i, objectKey := range objectKeys {
 		wg.Add(1)
 		globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore
@@ -259,7 +264,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
 			defer wg.Done()
 			defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
 
-			obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
+			obj, err := c.getEmbeddingObject(context.Background(), objectKey, dc)
 			if err != nil {
 				errs = append(errs, err)
 				log.Error("error fetching embedding object: "+objectKey, err)
@@ -284,11 +289,9 @@ type embeddingObjectResult struct {
 	err             error
 }
 
-func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
+func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding, dc string) ([]embeddingObjectResult, error) {
 	var wg sync.WaitGroup
 	embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
-	s3Client := c.S3Config.GetS3Client(c.derivedStorageDataCenter)
-	downloader := s3manager.NewDownloaderWithClient(&s3Client)
 
 	for i, dbEmbeddingRow := range dbEmbeddingRows {
 		wg.Add(1)
@@ -297,7 +300,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
 			defer wg.Done()
 			defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
 			objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
-			obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
+			obj, err := c.getEmbeddingObject(context.Background(), objectKey, dc)
 			if err != nil {
 				log.Error("error fetching embedding object: "+objectKey, err)
 				embeddingObjects[i] = embeddingObjectResult{
@@ -317,7 +320,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
 	return embeddingObjects, nil
 }
 
-func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
+func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) {
 	opt := _defaultFetchConfig
 	ctxLogger := log.WithField("objectKey", objectKey)
 	totalAttempts := opt.RetryCount + 1
@@ -329,7 +332,7 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
 			cancel()
 			return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
 		default:
-			obj, err := c.downloadObject(fetchCtx, objectKey, downloader, c.derivedStorageDataCenter)
+			obj, err := c.downloadObject(fetchCtx, objectKey, dc)
 			cancel() // Ensure cancel is called to release resources
 			if err == nil {
 				if i > 0 {
@@ -367,10 +370,11 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
 	return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
 }
 
-func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, dc string) (ente.EmbeddingObject, error) {
+func (c *Controller) downloadObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) {
 	var obj ente.EmbeddingObject
 	buff := &aws.WriteAtBuffer{}
 	bucket := c.S3Config.GetBucket(dc)
+	downloader := c.downloadManagerCache[dc]
 	_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
 		Bucket: bucket,
 		Key:    &objectKey,
@@ -390,8 +394,7 @@ func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string)
 	if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() {
 		return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "")
 	}
-	downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
-	obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotBackblazeDC())
+	obj, err := c.downloadObject(ctx, objectKey, c.S3Config.GetHotBackblazeDC())
 	if err != nil {
 		return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
 	}