diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index 84c34189d..8ccb43cc0 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -678,7 +678,7 @@ func main() { pushHandler := &api.PushHandler{PushController: pushController} privateAPI.POST("/push/token", pushHandler.AddToken) - embeddingController := &embeddingCtrl.Controller{Repo: embeddingRepo, AccessCtrl: accessCtrl, ObjectCleanupController: objectCleanupController, S3Config: s3Config, FileRepo: fileRepo, CollectionRepo: collectionRepo, QueueRepo: queueRepo, TaskLockingRepo: taskLockingRepo, HostName: hostName} + embeddingController := embeddingCtrl.New(embeddingRepo, accessCtrl, objectCleanupController, s3Config, queueRepo, taskLockingRepo, fileRepo, collectionRepo, hostName) embeddingHandler := &api.EmbeddingHandler{Controller: embeddingController} privateAPI.PUT("/embeddings", embeddingHandler.InsertOrUpdate) diff --git a/server/configurations/local.yaml b/server/configurations/local.yaml index 196c56f1f..a3cb71c9d 100644 --- a/server/configurations/local.yaml +++ b/server/configurations/local.yaml @@ -125,6 +125,16 @@ s3: endpoint: region: bucket: + wasabi-eu-central-2-embeddings: + key: + secret: + endpoint: + region: + bucket: + # Embeddings bucket is used for storing embeddings and other derived data from a file. + # By default, it is the same as the hot storage bucket. + # embeddings-bucket: wasabi-eu-central-2-embeddings + # If true, enable some workarounds to allow us to use a local minio instance # for object storage. # diff --git a/server/pkg/controller/embedding/controller.go b/server/pkg/controller/embedding/controller.go index e8d6c347a..0212fb0de 100644 --- a/server/pkg/controller/embedding/controller.go +++ b/server/pkg/controller/embedding/controller.go @@ -46,6 +46,24 @@ type Controller struct { CollectionRepo *repo.CollectionRepository HostName string cleanupCronRunning bool + embeddingS3Client *s3.S3 + embeddingBucket *string +} + +func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller { + return &Controller{ + Repo: repo, + AccessCtrl: accessCtrl, + ObjectCleanupController: objectCleanupController, + S3Config: s3Config, + QueueRepo: queueRepo, + TaskLockingRepo: taskLockingRepo, + FileRepo: fileRepo, + CollectionRepo: collectionRepo, + HostName: hostName, + embeddingS3Client: s3Config.GetEmbeddingsS3Client(), + embeddingBucket: s3Config.GetEmbeddingsBucket(), + } } func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) { @@ -245,7 +263,7 @@ func (c *Controller) deleteEmbedding(qItem repo.QueueItem) { } prefix := c.getEmbeddingObjectPrefix(ownerID, fileID) - err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter()) + err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetEmbeddingsDataCenter()) if err != nil { ctxLogger.WithError(err).Error("Failed to delete all objects") return @@ -277,9 +295,9 @@ func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string // uploadObject uploads the embedding object to the object store and returns the object size func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) { embeddingObj, _ := json.Marshal(obj) - uploader := s3manager.NewUploaderWithClient(c.S3Config.GetHotS3Client()) + uploader := s3manager.NewUploaderWithClient(c.embeddingS3Client) up := s3manager.UploadInput{ - Bucket: c.S3Config.GetHotBucket(), + Bucket: c.embeddingBucket, Key: &key, Body: bytes.NewReader(embeddingObj), } @@ -301,7 +319,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em var wg sync.WaitGroup var errs []error embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys)) - downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client()) + downloader := s3manager.NewDownloaderWithClient(c.embeddingS3Client) for i, objectKey := range objectKeys { wg.Add(1) @@ -338,7 +356,7 @@ type embeddingObjectResult struct { func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) { var wg sync.WaitGroup embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows)) - downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client()) + downloader := s3manager.NewDownloaderWithClient(c.embeddingS3Client) for i, dbEmbeddingRow := range dbEmbeddingRows { wg.Add(1) @@ -416,7 +434,7 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, downl var obj ente.EmbeddingObject buff := &aws.WriteAtBuffer{} _, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{ - Bucket: c.S3Config.GetHotBucket(), + Bucket: c.embeddingBucket, Key: &objectKey, }) if err != nil { diff --git a/server/pkg/utils/s3config/s3config.go b/server/pkg/utils/s3config/s3config.go index 9b273bd61..02a7fbd26 100644 --- a/server/pkg/utils/s3config/s3config.go +++ b/server/pkg/utils/s3config/s3config.go @@ -28,6 +28,8 @@ type S3Config struct { hotDC string // Secondary (hot) data center secondaryHotDC string + // Bucket for storing ml embeddings & preview files + embeddingsDC string // A map from data centers to S3 configurations s3Configs map[string]*aws.Config // A map from data centers to pre-created S3 clients @@ -71,6 +73,7 @@ var ( dcWasabiEuropeCentralDeprecated string = "wasabi-eu-central-2" dcWasabiEuropeCentral_v3 string = "wasabi-eu-central-2-v3" dcSCWEuropeFrance_v3 string = "scw-eu-fr-v3" + dcWasabiEuropeCentral string = "wasabi-eu-central-2-embeddings" ) // Number of days that the wasabi bucket is configured to retain objects. @@ -86,9 +89,9 @@ func NewS3Config() *S3Config { } func (config *S3Config) initialize() { - dcs := [5]string{ + dcs := [6]string{ dcB2EuropeCentral, dcSCWEuropeFranceLockedDeprecated, dcWasabiEuropeCentralDeprecated, - dcWasabiEuropeCentral_v3, dcSCWEuropeFrance_v3} + dcWasabiEuropeCentral_v3, dcSCWEuropeFrance_v3, dcWasabiEuropeCentral} config.hotDC = dcB2EuropeCentral config.secondaryHotDC = dcWasabiEuropeCentral_v3 @@ -99,6 +102,12 @@ func (config *S3Config) initialize() { config.secondaryHotDC = hs2 log.Infof("Hot storage: %s (secondary: %s)", hs1, hs2) } + config.embeddingsDC = config.hotDC + embeddingsDC := viper.GetString("s3.embeddings-bucket") + if embeddingsDC != "" && array.StringInList(embeddingsDC, dcs[:]) { + config.embeddingsDC = embeddingsDC + log.Infof("Embeddings bucket: %s", embeddingsDC) + } config.buckets = make(map[string]string) config.s3Configs = make(map[string]*aws.Config) @@ -171,6 +180,18 @@ func (config *S3Config) GetHotS3Client() *s3.S3 { return &s3Client } +func (config *S3Config) GetEmbeddingsDataCenter() string { + return config.embeddingsDC +} +func (config *S3Config) GetEmbeddingsBucket() *string { + return config.GetBucket(config.embeddingsDC) +} + +func (config *S3Config) GetEmbeddingsS3Client() *s3.S3 { + s3Client := config.GetS3Client(config.embeddingsDC) + return &s3Client +} + // Return the name of the hot Backblaze data center func (config *S3Config) GetHotBackblazeDC() string { return dcB2EuropeCentral