|
@@ -36,33 +36,35 @@ const (
|
|
|
)
|
|
|
|
|
|
type Controller struct {
|
|
|
- Repo *embedding.Repository
|
|
|
- AccessCtrl access.Controller
|
|
|
- ObjectCleanupController *controller.ObjectCleanupController
|
|
|
- S3Config *s3config.S3Config
|
|
|
- QueueRepo *repo.QueueRepository
|
|
|
- TaskLockingRepo *repo.TaskLockRepository
|
|
|
- FileRepo *repo.FileRepository
|
|
|
- CollectionRepo *repo.CollectionRepository
|
|
|
- HostName string
|
|
|
- cleanupCronRunning bool
|
|
|
- embeddingS3Client *s3.S3
|
|
|
- embeddingBucket *string
|
|
|
+ Repo *embedding.Repository
|
|
|
+ AccessCtrl access.Controller
|
|
|
+ ObjectCleanupController *controller.ObjectCleanupController
|
|
|
+ S3Config *s3config.S3Config
|
|
|
+ QueueRepo *repo.QueueRepository
|
|
|
+ TaskLockingRepo *repo.TaskLockRepository
|
|
|
+ FileRepo *repo.FileRepository
|
|
|
+ CollectionRepo *repo.CollectionRepository
|
|
|
+ HostName string
|
|
|
+ cleanupCronRunning bool
|
|
|
+ embeddingS3Client *s3.S3
|
|
|
+ embeddingBucket *string
|
|
|
+ areEmbeddingAndHotBucketSame bool
|
|
|
}
|
|
|
|
|
|
func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
|
|
|
return &Controller{
|
|
|
- Repo: repo,
|
|
|
- AccessCtrl: accessCtrl,
|
|
|
- ObjectCleanupController: objectCleanupController,
|
|
|
- S3Config: s3Config,
|
|
|
- QueueRepo: queueRepo,
|
|
|
- TaskLockingRepo: taskLockingRepo,
|
|
|
- FileRepo: fileRepo,
|
|
|
- CollectionRepo: collectionRepo,
|
|
|
- HostName: hostName,
|
|
|
- embeddingS3Client: s3Config.GetEmbeddingsS3Client(),
|
|
|
- embeddingBucket: s3Config.GetEmbeddingsBucket(),
|
|
|
+ Repo: repo,
|
|
|
+ AccessCtrl: accessCtrl,
|
|
|
+ ObjectCleanupController: objectCleanupController,
|
|
|
+ S3Config: s3Config,
|
|
|
+ QueueRepo: queueRepo,
|
|
|
+ TaskLockingRepo: taskLockingRepo,
|
|
|
+ FileRepo: fileRepo,
|
|
|
+ CollectionRepo: collectionRepo,
|
|
|
+ HostName: hostName,
|
|
|
+ embeddingS3Client: s3Config.GetEmbeddingsS3Client(),
|
|
|
+ embeddingBucket: s3Config.GetEmbeddingsBucket(),
|
|
|
+ areEmbeddingAndHotBucketSame: s3Config.GetEmbeddingsBucket() == s3Config.GetHotBucket(),
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -269,7 +271,7 @@ func (c *Controller) deleteEmbedding(qItem repo.QueueItem) {
|
|
|
return
|
|
|
}
|
|
|
// if the embeddings bucket is different from the hot bucket, delete from hot DC as well
|
|
|
- if c.S3Config.GetEmbeddingsDataCenter() != c.S3Config.GetHotDataCenter() {
|
|
|
+ if !c.areEmbeddingAndHotBucketSame {
|
|
|
err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter())
|
|
|
if err != nil {
|
|
|
ctxLogger.WithError(err).Error("Failed to delete all objects from hot DC")
|
|
@@ -425,10 +427,21 @@ func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, d
|
|
|
ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err())
|
|
|
} else {
|
|
|
// check if the error is due to object not found
|
|
|
- if s3Err, ok := errors.Unwrap(err).(awserr.Error); ok {
|
|
|
+ if s3Err, ok := err.(awserr.RequestFailure); ok {
|
|
|
if s3Err.Code() == s3.ErrCodeNoSuchKey {
|
|
|
- ctxLogger.Warn("Object not found: ", s3Err)
|
|
|
- return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
|
|
|
+ if c.areEmbeddingAndHotBucketSame {
|
|
|
+ ctxLogger.Error("Object not found: ", s3Err)
|
|
|
+ } else {
|
|
|
+ // If embedding and hot bucket are different, try to copy from hot bucket
|
|
|
+ copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey)
|
|
|
+ if err == nil {
|
|
|
+ ctxLogger.Info("Got the object from hot bucket object")
|
|
|
+ return *copyEmbeddingObject, nil
|
|
|
+ } else {
|
|
|
+ ctxLogger.WithError(err).Error("Failed to copy from hot bucket object")
|
|
|
+ }
|
|
|
+ return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
ctxLogger.Error("Failed to fetch object: ", err)
|
|
@@ -455,6 +468,26 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, downl
|
|
|
return obj, nil
|
|
|
}
|
|
|
|
|
|
+// download the embedding object from hot bucket and upload to embeddings bucket
|
|
|
+func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string) (*ente.EmbeddingObject, error) {
|
|
|
+ if c.embeddingBucket == c.S3Config.GetHotBucket() {
|
|
|
+ return nil, stacktrace.Propagate(errors.New("embedding bucket and hot bucket are same"), "")
|
|
|
+ }
|
|
|
+ downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
|
|
|
+ obj, err := c.downloadObject(ctx, objectKey, downloader, c.S3Config.GetHotBucket())
|
|
|
+ if err != nil {
|
|
|
+ return nil, stacktrace.Propagate(err, "failed to download from hot bucket")
|
|
|
+ }
|
|
|
+ go func() {
|
|
|
+ _, err = c.uploadObject(obj, objectKey)
|
|
|
+ if err != nil {
|
|
|
+ log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ return &obj, nil
|
|
|
+}
|
|
|
+
|
|
|
func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error {
|
|
|
if req.Model == "" {
|
|
|
return ente.NewBadRequestWithMessage("model is required")
|