Browse Source

[server] Refactor embedding fetch

Neeraj Gupta 1 year ago
parent
commit
be44665128
1 changed file with 41 additions and 14 deletions
  1. 41 14
      server/pkg/controller/embedding/controller.go

+ 41 - 14
server/pkg/controller/embedding/controller.go

@@ -309,7 +309,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
 			defer wg.Done()
 			defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
 
-			obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
+			obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
 			if err != nil {
 				errs = append(errs, err)
 				log.Error("error fetching embedding object: "+objectKey, err)
@@ -346,9 +346,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
 			defer wg.Done()
 			defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
 			objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
-			ctx, cancel := context.WithTimeout(context.Background(), embeddingFetchTimeout)
-			defer cancel()
-			obj, err := c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 0)
+			obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
 			if err != nil {
 				log.Error("error fetching embedding object: "+objectKey, err)
 				embeddingObjects[i] = embeddingObjectResult{
@@ -368,11 +366,45 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
 	return embeddingObjects, nil
 }
 
-func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
-	return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 3)
+type getOptions struct { // per-call settings for getEmbeddingObject; passing nil there selects defaults
+	RetryCount   int            // retries after the first attempt (total attempts = RetryCount + 1)
+	FetchTimeOut gTime.Duration // timeout applied to each individual fetch attempt
 }
 
-func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) {
+// getEmbeddingObject downloads objectKey, making up to opt.RetryCount+1 attempts, each bounded by
+// opt.FetchTimeOut. A nil opt means 3 retries with embeddingFetchTimeout. Returns the last error.
+func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, opt *getOptions) (ente.EmbeddingObject, error) {
+	if opt == nil {
+		opt = &getOptions{RetryCount: 3, FetchTimeOut: embeddingFetchTimeout}
+	}
+	ctxLogger := log.WithField("objectKey", objectKey)
+	lastErr := errors.New("no fetch attempt was made") // keeps the result non-nil if the loop never runs
+	for i := 0; i <= opt.RetryCount; i++ {
+		// Stop retrying as soon as the caller's context is done.
+		if err := ctx.Err(); err != nil {
+			return ente.EmbeddingObject{}, stacktrace.Propagate(err, "")
+		}
+		// Each attempt gets its own timeout, derived from the caller's context.
+		fetchCtx, cancel := context.WithTimeout(ctx, opt.FetchTimeOut)
+		obj, err := c.downloadObject(fetchCtx, objectKey, downloader)
+		// Read Err() before cancel(): cancel() itself makes Err() non-nil, which would
+		// misreport every failure as a timeout/cancellation and hide the real error.
+		ctxErr := fetchCtx.Err()
+		cancel()
+		if err == nil {
+			return obj, nil
+		}
+		lastErr = err
+		if ctxErr != nil {
+			ctxLogger.Error("Fetch timed out or cancelled: ", ctxErr)
+		} else {
+			ctxLogger.Error("Failed to fetch object: ", err)
+		}
+	}
+	return ente.EmbeddingObject{}, stacktrace.Propagate(lastErr, "failed to fetch object")
+}
+
+func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
 	var obj ente.EmbeddingObject
 	buff := &aws.WriteAtBuffer{}
 	_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
@@ -380,16 +412,11 @@ func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKe
 		Key:    &objectKey,
 	})
 	if err != nil {
-		log.Error(err)
-		if retryCount > 0 {
-			return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, retryCount-1)
-		}
-		return obj, stacktrace.Propagate(err, "")
+		return obj, stacktrace.Propagate(err, "downloadFailed")
 	}
 	err = json.Unmarshal(buff.Bytes(), &obj)
 	if err != nil {
-		log.Error(err)
-		return obj, stacktrace.Propagate(err, "")
+		return obj, stacktrace.Propagate(err, "unmarshal failed")
 	}
 	return obj, nil
 }