[server] Refactor embedding fetch
This commit is contained in:
parent
282611610d
commit
be44665128
1 changed files with 41 additions and 14 deletions
|
@ -309,7 +309,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
|
|||
defer wg.Done()
|
||||
defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
|
||||
|
||||
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
|
||||
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
log.Error("error fetching embedding object: "+objectKey, err)
|
||||
|
@ -346,9 +346,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
|
|||
defer wg.Done()
|
||||
defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
|
||||
objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), embeddingFetchTimeout)
|
||||
defer cancel()
|
||||
obj, err := c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 0)
|
||||
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
|
||||
if err != nil {
|
||||
log.Error("error fetching embedding object: "+objectKey, err)
|
||||
embeddingObjects[i] = embeddingObjectResult{
|
||||
|
@ -368,11 +366,45 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
|
|||
return embeddingObjects, nil
|
||||
}
|
||||
|
||||
func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
|
||||
return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 3)
|
||||
type getOptions struct {
|
||||
RetryCount int
|
||||
FetchTimeOut gTime.Duration
|
||||
}
|
||||
|
||||
func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) {
|
||||
func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, opt *getOptions) (ente.EmbeddingObject, error) {
|
||||
if opt == nil {
|
||||
opt = &getOptions{
|
||||
RetryCount: 3,
|
||||
FetchTimeOut: embeddingFetchTimeout,
|
||||
}
|
||||
}
|
||||
ctxLogger := log.WithField("objectKey", objectKey)
|
||||
totalAttempts := opt.RetryCount + 1
|
||||
for i := 0; i < totalAttempts; i++ {
|
||||
// Create a new context with a timeout for each fetch
|
||||
fetchCtx, cancel := context.WithTimeout(ctx, opt.FetchTimeOut)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cancel()
|
||||
return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
|
||||
default:
|
||||
obj, err := c.downloadObject(fetchCtx, objectKey, downloader)
|
||||
cancel() // Ensure cancel is called to release resources
|
||||
if err == nil {
|
||||
return obj, nil
|
||||
}
|
||||
// Check if the error is due to context timeout or cancellation
|
||||
if fetchCtx.Err() != nil {
|
||||
ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err())
|
||||
} else {
|
||||
ctxLogger.Error("Failed to fetch object: ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
|
||||
}
|
||||
|
||||
func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
|
||||
var obj ente.EmbeddingObject
|
||||
buff := &aws.WriteAtBuffer{}
|
||||
_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
|
||||
|
@ -380,16 +412,11 @@ func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKe
|
|||
Key: &objectKey,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
if retryCount > 0 {
|
||||
return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, retryCount-1)
|
||||
}
|
||||
return obj, stacktrace.Propagate(err, "")
|
||||
return obj, stacktrace.Propagate(err, "downloadFailed")
|
||||
}
|
||||
err = json.Unmarshal(buff.Bytes(), &obj)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return obj, stacktrace.Propagate(err, "")
|
||||
return obj, stacktrace.Propagate(err, "unmarshal failed")
|
||||
}
|
||||
return obj, nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue