Neeraj Gupta hace 1 año
padre
commit
20e9a6a1fc
Se han modificado 1 ficheros con 15 adiciones y 17 borrados
  1. 15 17
      server/pkg/controller/embedding/controller.go

+ 15 - 17
server/pkg/controller/embedding/controller.go

@@ -33,6 +33,14 @@ const (
 	embeddingFetchTimeout = 15 * gTime.Second
 )
 
+// _fetchConfig is the configuration for the fetching objects from S3
+type _fetchConfig struct {
+	RetryCount   int
+	FetchTimeOut gTime.Duration
+}
+
+var _defaultFetchConfig = _fetchConfig{RetryCount: 3, FetchTimeOut: 15 * gTime.Second}
+
 type Controller struct {
 	Repo                     *embedding.Repository
 	AccessCtrl               access.Controller
@@ -251,7 +259,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
 			defer wg.Done()
 			defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
 
-			obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
+			obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
 			if err != nil {
 				errs = append(errs, err)
 				log.Error("error fetching embedding object: "+objectKey, err)
@@ -289,7 +297,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
 			defer wg.Done()
 			defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
 			objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
-			obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
+			obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
 			if err != nil {
 				log.Error("error fetching embedding object: "+objectKey, err)
 				embeddingObjects[i] = embeddingObjectResult{
@@ -309,18 +317,8 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
 	return embeddingObjects, nil
 }
 
-type getOptions struct {
-	RetryCount   int
-	FetchTimeOut gTime.Duration
-}
-
-func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, opt *getOptions) (ente.EmbeddingObject, error) {
-	if opt == nil {
-		opt = &getOptions{
-			RetryCount:   3,
-			FetchTimeOut: embeddingFetchTimeout,
-		}
-	}
+func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
+	opt := _defaultFetchConfig
 	ctxLogger := log.WithField("objectKey", objectKey)
 	totalAttempts := opt.RetryCount + 1
 	for i := 0; i < totalAttempts; i++ {
@@ -346,7 +344,7 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
 				// check if the error is due to object not found
 				if s3Err, ok := err.(awserr.RequestFailure); ok {
 					if s3Err.Code() == s3.ErrCodeNoSuchKey {
-						if c.derivedStorageDataCenter == c.S3Config.GetHotDataCenter() {
+						if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() {
 							ctxLogger.Error("Object not found: ", s3Err)
 							return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
 						} else {
@@ -389,11 +387,11 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, downl
 
 // download the embedding object from hot bucket and upload to embeddings bucket
 func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
-	if c.derivedStorageDataCenter == c.S3Config.GetHotDataCenter() {
+	if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() {
 		return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "")
 	}
 	downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
-	obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotDataCenter())
+	obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotBackblazeDC())
 	if err != nil {
 		return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
 	}