|
@@ -309,7 +309,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
|
|
|
defer wg.Done()
|
|
|
defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
|
|
|
|
|
|
- obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
|
|
|
+ obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
|
|
|
if err != nil {
|
|
|
errs = append(errs, err)
|
|
|
log.Error("error fetching embedding object: "+objectKey, err)
|
|
@@ -346,9 +346,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
|
|
|
defer wg.Done()
|
|
|
defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
|
|
|
objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
|
|
|
- ctx, cancel := context.WithTimeout(context.Background(), embeddingFetchTimeout)
|
|
|
- defer cancel()
|
|
|
- obj, err := c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 0)
|
|
|
+ obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
|
|
|
if err != nil {
|
|
|
log.Error("error fetching embedding object: "+objectKey, err)
|
|
|
embeddingObjects[i] = embeddingObjectResult{
|
|
@@ -368,11 +366,45 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
|
|
|
return embeddingObjects, nil
|
|
|
}
|
|
|
|
|
|
-func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
|
|
|
- return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 3)
|
|
|
+type getOptions struct {
|
|
|
+ RetryCount int
|
|
|
+ FetchTimeOut gTime.Duration
|
|
|
}
|
|
|
|
|
|
-func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) {
|
|
|
+func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, opt *getOptions) (ente.EmbeddingObject, error) {
|
|
|
+ if opt == nil {
|
|
|
+ opt = &getOptions{
|
|
|
+ RetryCount: 3,
|
|
|
+ FetchTimeOut: embeddingFetchTimeout,
|
|
|
+ }
|
|
|
+ }
|
|
|
+ ctxLogger := log.WithField("objectKey", objectKey)
|
|
|
+ totalAttempts := opt.RetryCount + 1
|
|
|
+ for i := 0; i < totalAttempts; i++ {
|
|
|
+ // Create a new context with a timeout for each fetch
|
|
|
+ fetchCtx, cancel := context.WithTimeout(ctx, opt.FetchTimeOut)
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ cancel()
|
|
|
+ return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
|
|
|
+ default:
|
|
|
+ obj, err := c.downloadObject(fetchCtx, objectKey, downloader)
|
|
|
+ cancel() // Ensure cancel is called to release resources
|
|
|
+ if err == nil {
|
|
|
+ return obj, nil
|
|
|
+ }
|
|
|
+ // Check if the error is due to context timeout or cancellation
|
|
|
+ if fetchCtx.Err() != nil {
|
|
|
+ ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err())
|
|
|
+ } else {
|
|
|
+ ctxLogger.Error("Failed to fetch object: ", err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
|
|
|
+}
|
|
|
+
|
|
|
+func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
|
|
|
var obj ente.EmbeddingObject
|
|
|
buff := &aws.WriteAtBuffer{}
|
|
|
_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
|
|
@@ -380,16 +412,11 @@ func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKe
|
|
|
Key: &objectKey,
|
|
|
})
|
|
|
if err != nil {
|
|
|
- log.Error(err)
|
|
|
- if retryCount > 0 {
|
|
|
- return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, retryCount-1)
|
|
|
- }
|
|
|
- return obj, stacktrace.Propagate(err, "")
|
|
|
+ return obj, stacktrace.Propagate(err, "downloadFailed")
|
|
|
}
|
|
|
err = json.Unmarshal(buff.Bytes(), &obj)
|
|
|
if err != nil {
|
|
|
- log.Error(err)
|
|
|
- return obj, stacktrace.Propagate(err, "")
|
|
|
+ return obj, stacktrace.Propagate(err, "unmarshal failed")
|
|
|
}
|
|
|
return obj, nil
|
|
|
}
|