Refactor
This commit is contained in:
parent
3485b31475
commit
20e9a6a1fc
1 changed files with 15 additions and 17 deletions
|
@ -33,6 +33,14 @@ const (
|
|||
embeddingFetchTimeout = 15 * gTime.Second
|
||||
)
|
||||
|
||||
// _fetchConfig is the configuration for the fetching objects from S3
|
||||
type _fetchConfig struct {
|
||||
RetryCount int
|
||||
FetchTimeOut gTime.Duration
|
||||
}
|
||||
|
||||
var _defaultFetchConfig = _fetchConfig{RetryCount: 3, FetchTimeOut: 15 * gTime.Second}
|
||||
|
||||
type Controller struct {
|
||||
Repo *embedding.Repository
|
||||
AccessCtrl access.Controller
|
||||
|
@ -251,7 +259,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
|
|||
defer wg.Done()
|
||||
defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
|
||||
|
||||
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
|
||||
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
log.Error("error fetching embedding object: "+objectKey, err)
|
||||
|
@ -289,7 +297,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
|
|||
defer wg.Done()
|
||||
defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
|
||||
objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
|
||||
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader, nil)
|
||||
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
|
||||
if err != nil {
|
||||
log.Error("error fetching embedding object: "+objectKey, err)
|
||||
embeddingObjects[i] = embeddingObjectResult{
|
||||
|
@ -309,18 +317,8 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
|
|||
return embeddingObjects, nil
|
||||
}
|
||||
|
||||
type getOptions struct {
|
||||
RetryCount int
|
||||
FetchTimeOut gTime.Duration
|
||||
}
|
||||
|
||||
func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader, opt *getOptions) (ente.EmbeddingObject, error) {
|
||||
if opt == nil {
|
||||
opt = &getOptions{
|
||||
RetryCount: 3,
|
||||
FetchTimeOut: embeddingFetchTimeout,
|
||||
}
|
||||
}
|
||||
func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
|
||||
opt := _defaultFetchConfig
|
||||
ctxLogger := log.WithField("objectKey", objectKey)
|
||||
totalAttempts := opt.RetryCount + 1
|
||||
for i := 0; i < totalAttempts; i++ {
|
||||
|
@ -346,7 +344,7 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
|
|||
// check if the error is due to object not found
|
||||
if s3Err, ok := err.(awserr.RequestFailure); ok {
|
||||
if s3Err.Code() == s3.ErrCodeNoSuchKey {
|
||||
if c.derivedStorageDataCenter == c.S3Config.GetHotDataCenter() {
|
||||
if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() {
|
||||
ctxLogger.Error("Object not found: ", s3Err)
|
||||
return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
|
||||
} else {
|
||||
|
@ -389,11 +387,11 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, downl
|
|||
|
||||
// download the embedding object from hot bucket and upload to embeddings bucket
|
||||
func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
|
||||
if c.derivedStorageDataCenter == c.S3Config.GetHotDataCenter() {
|
||||
if c.derivedStorageDataCenter == c.S3Config.GetHotBackblazeDC() {
|
||||
return nil, stacktrace.Propagate(errors.New("derived DC bucket and hot DC are same"), "")
|
||||
}
|
||||
downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
|
||||
obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotDataCenter())
|
||||
obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotBackblazeDC())
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue