Add fallback logic to read embedding from hot bucket

This commit is contained in:
Neeraj Gupta 2024-05-14 17:00:16 +05:30
parent 87b087f295
commit 835a773f13

View file

@ -36,33 +36,35 @@ const (
)
// Controller bundles the repositories, the S3 configuration and the cached S3
// handles used by the embedding code paths in this file.
//
// NOTE(review): this span is a rendered diff hunk, so the field list appears
// twice: the declaration before the change, followed by the declaration after
// it. The second list differs only by the added areEmbeddingAndHotBucketSame
// field.
type Controller struct {
Repo *embedding.Repository
AccessCtrl access.Controller
ObjectCleanupController *controller.ObjectCleanupController
S3Config *s3config.S3Config
QueueRepo *repo.QueueRepository
TaskLockingRepo *repo.TaskLockRepository
FileRepo *repo.FileRepository
CollectionRepo *repo.CollectionRepository
HostName string
cleanupCronRunning bool
embeddingS3Client *s3.S3
embeddingBucket *string
Repo *embedding.Repository
AccessCtrl access.Controller
ObjectCleanupController *controller.ObjectCleanupController
S3Config *s3config.S3Config
QueueRepo *repo.QueueRepository
TaskLockingRepo *repo.TaskLockRepository
FileRepo *repo.FileRepository
CollectionRepo *repo.CollectionRepository
HostName string
cleanupCronRunning bool
// embeddingS3Client is the value of S3Config.GetEmbeddingsS3Client() captured in New.
embeddingS3Client *s3.S3
// embeddingBucket is the value of S3Config.GetEmbeddingsBucket() captured in New.
embeddingBucket *string
// areEmbeddingAndHotBucketSame is computed once in New. When it is false,
// deleteEmbedding also deletes from the hot DC and getEmbeddingObject falls
// back to copying a missing object from the hot bucket.
areEmbeddingAndHotBucketSame bool
}
// New builds a Controller from the supplied dependencies and caches the
// embeddings S3 client and embeddings bucket obtained from s3Config.
//
// NOTE(review): this span is a rendered diff hunk, so the composite literal
// lists every field twice: the initialisers before the change, followed by the
// initialisers after it (which add areEmbeddingAndHotBucketSame).
func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
return &Controller{
Repo: repo,
AccessCtrl: accessCtrl,
ObjectCleanupController: objectCleanupController,
S3Config: s3Config,
QueueRepo: queueRepo,
TaskLockingRepo: taskLockingRepo,
FileRepo: fileRepo,
CollectionRepo: collectionRepo,
HostName: hostName,
embeddingS3Client: s3Config.GetEmbeddingsS3Client(),
embeddingBucket: s3Config.GetEmbeddingsBucket(),
Repo: repo,
AccessCtrl: accessCtrl,
ObjectCleanupController: objectCleanupController,
S3Config: s3Config,
QueueRepo: queueRepo,
TaskLockingRepo: taskLockingRepo,
FileRepo: fileRepo,
CollectionRepo: collectionRepo,
HostName: hostName,
embeddingS3Client: s3Config.GetEmbeddingsS3Client(),
embeddingBucket: s3Config.GetEmbeddingsBucket(),
// NOTE(review): GetEmbeddingsBucket() is stored in a *string field above, so
// this == compares two pointers rather than the bucket names they point to.
// Unless both getters hand back the very same pointer this is always false;
// confirm against the s3config implementation and compare the dereferenced
// names (or the data-center identifiers) if that is not intended.
areEmbeddingAndHotBucketSame: s3Config.GetEmbeddingsBucket() == s3Config.GetHotBucket(),
}
}
@ -269,7 +271,7 @@ func (c *Controller) deleteEmbedding(qItem repo.QueueItem) {
return
}
// if Embeddings DC is different from hot DC, delete from hot DC as well
if c.S3Config.GetEmbeddingsDataCenter() != c.S3Config.GetHotDataCenter() {
if !c.areEmbeddingAndHotBucketSame {
err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter())
if err != nil {
ctxLogger.WithError(err).Error("Failed to delete all objects from hot DC")
@ -425,10 +427,21 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err())
} else {
// check if the error is due to object not found
if s3Err, ok := errors.Unwrap(err).(awserr.Error); ok {
if s3Err, ok := err.(awserr.RequestFailure); ok {
if s3Err.Code() == s3.ErrCodeNoSuchKey {
ctxLogger.Warn("Object not found: ", s3Err)
return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
if c.areEmbeddingAndHotBucketSame {
ctxLogger.Error("Object not found: ", s3Err)
} else {
// If embedding and hot bucket are different, try to copy from hot bucket
copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey)
if err == nil {
ctxLogger.Info("Got the object from hot bucket object")
return *copyEmbeddingObject, nil
} else {
ctxLogger.WithError(err).Error("Failed to copy from hot bucket object")
}
return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
}
}
}
ctxLogger.Error("Failed to fetch object: ", err)
@ -455,6 +468,26 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, downl
return obj, nil
}
// copyEmbeddingObject downloads the embedding object stored under objectKey
// from the hot bucket and returns it, and re-uploads it to the embeddings
// bucket in a background goroutine.
//
// It returns an error when the embeddings bucket and the hot bucket are the
// same (there is nothing to copy from) or when the download from the hot
// bucket fails. A failure of the background upload is only logged; it does not
// affect the returned object.
func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
	hotBucket := c.S3Config.GetHotBucket()
	// Both values are *string. Compare the bucket names they point to: comparing
	// the pointers themselves only detects that both refer to the same variable,
	// which would let this guard pass even when the names are equal.
	sameBucket := c.embeddingBucket == hotBucket ||
		(c.embeddingBucket != nil && hotBucket != nil && *c.embeddingBucket == *hotBucket)
	if sameBucket {
		return nil, stacktrace.Propagate(errors.New("embedding bucket and hot bucket are same"), "")
	}
	downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
	obj, err := c.downloadObject(ctx, objectKey, downloader, hotBucket)
	if err != nil {
		return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
	}
	// Best-effort backfill of the embeddings bucket. The goroutine keeps its own
	// error variable so it never writes to this function's err after returning.
	go func() {
		if _, uploadErr := c.uploadObject(obj, objectKey); uploadErr != nil {
			log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", uploadErr)
		}
	}()
	return &obj, nil
}
func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error {
if req.Model == "" {
return ente.NewBadRequestWithMessage("model is required")