[server] Refactor embedding fetch

This commit is contained in:
Neeraj Gupta 2024-05-13 16:33:36 +05:30
parent 282611610d
commit be44665128

View file

@ -309,7 +309,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
defer wg.Done()
defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
if err != nil {
errs = append(errs, err)
log.Error("error fetching embedding object: "+objectKey, err)
@ -346,9 +346,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
defer wg.Done()
defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
ctx, cancel := context.WithTimeout(context.Background(), embeddingFetchTimeout)
defer cancel()
obj, err := c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 0)
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
if err != nil {
log.Error("error fetching embedding object: "+objectKey, err)
embeddingObjects[i] = embeddingObjectResult{
@ -368,11 +366,45 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
return embeddingObjects, nil
}
func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 3)
// getOptions tunes how getEmbeddingObject fetches a single object.
// Passing a nil *getOptions to getEmbeddingObject selects its defaults
// (3 retries, embeddingFetchTimeout per attempt).
type getOptions struct {
// RetryCount is the number of additional attempts made after the first one
// fails; the total number of attempts is RetryCount + 1.
RetryCount int
// FetchTimeOut bounds each individual download attempt; it is applied via
// context.WithTimeout on every attempt rather than once for the whole call.
// NOTE(review): gTime is presumably an alias for the standard "time"
// package — confirm against the file's imports.
FetchTimeOut gTime.Duration
}
func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) {
// getEmbeddingObject downloads the embedding object stored under objectKey,
// retrying failed attempts.
//
// opt controls the retry count and the per-attempt timeout. When opt is nil
// the defaults are 3 retries and embeddingFetchTimeout per attempt. Every
// attempt runs under its own timeout context derived from ctx, so one slow
// attempt does not eat into the time budget of the next.
//
// It returns the object from the first successful attempt. If ctx is already
// done before an attempt starts, ctx.Err() is returned without trying again.
// If every attempt fails, the error from the last attempt is returned
// (wrapped), so callers can see why the fetch ultimately failed.
func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, opt *getOptions) (ente.EmbeddingObject, error) {
	if opt == nil {
		opt = &getOptions{
			RetryCount:   3,
			FetchTimeOut: embeddingFetchTimeout,
		}
	}
	ctxLogger := log.WithField("objectKey", objectKey)
	totalAttempts := opt.RetryCount + 1
	var lastErr error
	for i := 0; i < totalAttempts; i++ {
		// Stop retrying as soon as the caller's context is done.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ente.EmbeddingObject{}, stacktrace.Propagate(ctxErr, "")
		}
		// Create a new context with a timeout for each fetch.
		fetchCtx, cancel := context.WithTimeout(ctx, opt.FetchTimeOut)
		obj, err := c.downloadObject(fetchCtx, objectKey, downloader)
		// Read the context state BEFORE calling cancel: once cancel has run,
		// fetchCtx.Err() is always non-nil (context.Canceled), which would make
		// every failure look like a timeout and hide the real download error.
		fetchCtxErr := fetchCtx.Err()
		cancel() // Release the timer associated with this attempt.
		if err == nil {
			return obj, nil
		}
		lastErr = err
		if fetchCtxErr != nil {
			ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtxErr)
		} else {
			ctxLogger.Error("Failed to fetch object: ", err)
		}
	}
	if lastErr != nil {
		return ente.EmbeddingObject{}, stacktrace.Propagate(lastErr, "failed to fetch object")
	}
	// No attempt was made (e.g. a negative RetryCount). stacktrace.Propagate
	// returns nil for a nil cause, so build an explicit error here.
	return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
}
func (c *Controller) downloadObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
var obj ente.EmbeddingObject
buff := &aws.WriteAtBuffer{}
_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
@ -380,16 +412,11 @@ func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKe
Key: &objectKey,
})
if err != nil {
log.Error(err)
if retryCount > 0 {
return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, retryCount-1)
}
return obj, stacktrace.Propagate(err, "")
return obj, stacktrace.Propagate(err, "downloadFailed")
}
err = json.Unmarshal(buff.Bytes(), &obj)
if err != nil {
log.Error(err)
return obj, stacktrace.Propagate(err, "")
return obj, stacktrace.Propagate(err, "unmarshal failed")
}
return obj, nil
}