[server] Add timeout while fetching embedding
This commit is contained in:
parent
5caa9c5f61
commit
3a70dcd930
1 changed files with 11 additions and 7 deletions
|
@ -2,12 +2,14 @@ package embedding
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/pkg/utils/array"
|
||||
"strconv"
|
||||
"sync"
|
||||
gTime "time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
|
@ -306,7 +308,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
|
|||
defer wg.Done()
|
||||
defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
|
||||
|
||||
obj, err := c.getEmbeddingObject(objectKey, downloader)
|
||||
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
log.Error("error fetching embedding object: "+objectKey, err)
|
||||
|
@ -343,7 +345,9 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
|
|||
defer wg.Done()
|
||||
defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
|
||||
objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
|
||||
obj, err := c.getEmbeddingObject(objectKey, downloader)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), gTime.Second*10) // 10 seconds timeout
|
||||
defer cancel()
|
||||
obj, err := c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 0)
|
||||
if err != nil {
|
||||
log.Error("error fetching embedding object: "+objectKey, err)
|
||||
embeddingObjects[i] = embeddingObjectResult{
|
||||
|
@ -363,21 +367,21 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
|
|||
return embeddingObjects, nil
|
||||
}
|
||||
|
||||
func (c *Controller) getEmbeddingObject(objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
|
||||
return c.getEmbeddingObjectWithRetries(objectKey, downloader, 3)
|
||||
func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
|
||||
return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 3)
|
||||
}
|
||||
|
||||
func (c *Controller) getEmbeddingObjectWithRetries(objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) {
|
||||
func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) {
|
||||
var obj ente.EmbeddingObject
|
||||
buff := &aws.WriteAtBuffer{}
|
||||
_, err := downloader.Download(buff, &s3.GetObjectInput{
|
||||
_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
|
||||
Bucket: c.S3Config.GetHotBucket(),
|
||||
Key: &objectKey,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
if retryCount > 0 {
|
||||
return c.getEmbeddingObjectWithRetries(objectKey, downloader, retryCount-1)
|
||||
return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, retryCount-1)
|
||||
}
|
||||
return obj, stacktrace.Propagate(err, "")
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue