Ver Fonte

Add fallback logic to read embedding from hot bucket

Neeraj Gupta há 1 ano atrás
pai
commit
835a773f13
1 ficheiros alterados com 60 adições e 27 exclusões
  1. 60 27
      server/pkg/controller/embedding/controller.go

+ 60 - 27
server/pkg/controller/embedding/controller.go

@@ -36,33 +36,35 @@ const (
 )
 
 type Controller struct {
-	Repo                    *embedding.Repository
-	AccessCtrl              access.Controller
-	ObjectCleanupController *controller.ObjectCleanupController
-	S3Config                *s3config.S3Config
-	QueueRepo               *repo.QueueRepository
-	TaskLockingRepo         *repo.TaskLockRepository
-	FileRepo                *repo.FileRepository
-	CollectionRepo          *repo.CollectionRepository
-	HostName                string
-	cleanupCronRunning      bool
-	embeddingS3Client       *s3.S3
-	embeddingBucket         *string
+	Repo                         *embedding.Repository
+	AccessCtrl                   access.Controller
+	ObjectCleanupController      *controller.ObjectCleanupController
+	S3Config                     *s3config.S3Config
+	QueueRepo                    *repo.QueueRepository
+	TaskLockingRepo              *repo.TaskLockRepository
+	FileRepo                     *repo.FileRepository
+	CollectionRepo               *repo.CollectionRepository
+	HostName                     string
+	cleanupCronRunning           bool
+	embeddingS3Client            *s3.S3  // set once in New from S3Config.GetEmbeddingsS3Client()
+	embeddingBucket              *string // set once in New from S3Config.GetEmbeddingsBucket()
+	areEmbeddingAndHotBucketSame bool    // computed once in New; when false, deleteEmbedding also purges the hot DC and getEmbeddingObject falls back to the hot bucket on NoSuchKey
 }
 
 func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
+	// New wires the given repositories/controllers into a Controller and caches
+	// the embeddings S3 client and bucket taken from s3Config.
+	//
+	// The bucket getters return *string, so `==` on them only tests pointer
+	// identity. Compare the bucket names they point to instead, otherwise the
+	// flag can be false even when both buckets are the same
+	// (NOTE(review): the getters' implementation is not visible here — confirm
+	// whether they ever return the same pointer).
+	embeddingBucket := s3Config.GetEmbeddingsBucket()
+	hotBucket := s3Config.GetHotBucket()
+	sameBucket := embeddingBucket == hotBucket ||
+		(embeddingBucket != nil && hotBucket != nil && *embeddingBucket == *hotBucket)
 	return &Controller{
-		Repo:                    repo,
-		AccessCtrl:              accessCtrl,
-		ObjectCleanupController: objectCleanupController,
-		S3Config:                s3Config,
-		QueueRepo:               queueRepo,
-		TaskLockingRepo:         taskLockingRepo,
-		FileRepo:                fileRepo,
-		CollectionRepo:          collectionRepo,
-		HostName:                hostName,
-		embeddingS3Client:       s3Config.GetEmbeddingsS3Client(),
-		embeddingBucket:         s3Config.GetEmbeddingsBucket(),
+		Repo:                         repo,
+		AccessCtrl:                   accessCtrl,
+		ObjectCleanupController:      objectCleanupController,
+		S3Config:                     s3Config,
+		QueueRepo:                    queueRepo,
+		TaskLockingRepo:              taskLockingRepo,
+		FileRepo:                     fileRepo,
+		CollectionRepo:               collectionRepo,
+		HostName:                     hostName,
+		embeddingS3Client:            s3Config.GetEmbeddingsS3Client(),
+		embeddingBucket:              embeddingBucket,
+		areEmbeddingAndHotBucketSame: sameBucket,
 	}
 }
 
@@ -269,7 +271,7 @@ func (c *Controller) deleteEmbedding(qItem repo.QueueItem) {
 		return
 	}
 	// if Embeddings DC is different from hot DC, delete from hot DC as well
-	if c.S3Config.GetEmbeddingsDataCenter() != c.S3Config.GetHotDataCenter() {
+	if !c.areEmbeddingAndHotBucketSame {
 		err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter())
 		if err != nil {
 			ctxLogger.WithError(err).Error("Failed to delete all objects from hot DC")
@@ -425,10 +427,21 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
 				ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err())
 			} else {
 				// check if the error is due to object not found
-				if s3Err, ok := errors.Unwrap(err).(awserr.Error); ok {
+				if s3Err, ok := err.(awserr.RequestFailure); ok {
 					if s3Err.Code() == s3.ErrCodeNoSuchKey {
-						ctxLogger.Warn("Object not found: ", s3Err)
-						return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
+						if c.areEmbeddingAndHotBucketSame {
+							ctxLogger.Error("Object not found: ", s3Err)
+						} else {
+							// If embedding and hot bucket are different, try to copy from hot bucket
+							copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey)
+							if err == nil {
+								ctxLogger.Info("Got the object from hot bucket object")
+								return *copyEmbeddingObject, nil
+							} else {
+								ctxLogger.WithError(err).Error("Failed to copy from hot bucket object")
+							}
+							return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
+						}
 					}
 				}
 				ctxLogger.Error("Failed to fetch object: ", err)
@@ -455,6 +468,26 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, downl
 	return obj, nil
 }
 
+// copyEmbeddingObject downloads the embedding object stored under objectKey
+// from the hot bucket and returns it. As a side effect it starts a background
+// goroutine that uploads the object to the embeddings bucket via uploadObject;
+// a failed upload is only logged and does not affect the returned value.
+//
+// It returns an error when the embeddings bucket and the hot bucket are the
+// same (there is nothing to fall back to) or when the hot-bucket download fails.
+func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
+	hotBucket := c.S3Config.GetHotBucket()
+	// Both buckets are *string: compare the names they point to rather than the
+	// pointers themselves, which may differ even when the bucket is the same.
+	if c.embeddingBucket == hotBucket ||
+		(c.embeddingBucket != nil && hotBucket != nil && *c.embeddingBucket == *hotBucket) {
+		return nil, stacktrace.Propagate(errors.New("embedding bucket and hot bucket are same"), "")
+	}
+	downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
+	obj, err := c.downloadObject(ctx, objectKey, downloader, hotBucket)
+	if err != nil {
+		return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
+	}
+	go func() {
+		// Keep the upload error local to the goroutine instead of writing to the
+		// enclosing function's err after that function has already returned.
+		if _, uploadErr := c.uploadObject(obj, objectKey); uploadErr != nil {
+			log.WithField("object", objectKey).Error("Failed to copy  to embeddings bucket: ", uploadErr)
+		}
+	}()
+
+	return &obj, nil
+}
+
 func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error {
 	if req.Model == "" {
 		return ente.NewBadRequestWithMessage("model is required")