From be446651289d2d3e38b045a7be5da8e1525cbac8 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Mon, 13 May 2024 16:33:36 +0530 Subject: [PATCH] [server] Refactor embedding fetch --- server/pkg/controller/embedding/controller.go | 55 ++++++++++++++----- 1 file changed, 41 insertions(+), 14 deletions(-) diff --git a/server/pkg/controller/embedding/controller.go b/server/pkg/controller/embedding/controller.go index bf317ccfe..b14f5d893 100644 --- a/server/pkg/controller/embedding/controller.go +++ b/server/pkg/controller/embedding/controller.go @@ -309,7 +309,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em defer wg.Done() defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore - obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader) + obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil) if err != nil { errs = append(errs, err) log.Error("error fetching embedding object: "+objectKey, err) @@ -346,9 +346,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows defer wg.Done() defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model) - ctx, cancel := context.WithTimeout(context.Background(), embeddingFetchTimeout) - defer cancel() - obj, err := c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 0) + obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil) if err != nil { log.Error("error fetching embedding object: "+objectKey, err) embeddingObjects[i] = embeddingObjectResult{ @@ -368,11 +366,45 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows return embeddingObjects, nil } -func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) { - return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 3) +type getOptions struct { + RetryCount int + FetchTimeOut gTime.Duration } -func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) { +func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, opt *getOptions) (ente.EmbeddingObject, error) { + if opt == nil { + opt = &getOptions{ + RetryCount: 3, + FetchTimeOut: embeddingFetchTimeout, + } + } + ctxLogger := log.WithField("objectKey", objectKey) + totalAttempts := opt.RetryCount + 1 + for i := 0; i < totalAttempts; i++ { + // Create a new context with a timeout for each fetch + fetchCtx, cancel := context.WithTimeout(ctx, opt.FetchTimeOut) + select { + case <-ctx.Done(): + cancel() + return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "") + default: + obj, err := c.downloadObject(fetchCtx, objectKey, downloader) + cancel() // Ensure cancel is called to release resources + if err == nil { + return obj, nil + } + // Check if the error is due to context timeout or cancellation + if fetchCtx.Err() != nil { + ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err()) + } else { + ctxLogger.Error("Failed to fetch object: ", err) + } + } + } + return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "") +} + +func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) { var obj ente.EmbeddingObject buff := &aws.WriteAtBuffer{} _, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{ @@ -380,16 +412,11 @@ func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKe Key: &objectKey, }) if err != nil { - log.Error(err) - if retryCount > 0 { - return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, retryCount-1) - } - return obj, stacktrace.Propagate(err, "") + return obj, stacktrace.Propagate(err, "downloadFailed") } err = json.Unmarshal(buff.Bytes(), &obj) if err != nil { - log.Error(err) - return obj, stacktrace.Propagate(err, "") + return obj, stacktrace.Propagate(err, "unmarshal failed") } return obj, nil }