Add support for configuring diff bucket for embeddings

Neeraj Gupta 2024-05-14 14:03:35 +05:30
parent bce3f40a16
commit 3e7b16288f
4 changed files with 58 additions and 9 deletions


@@ -678,7 +678,7 @@ func main() {
     pushHandler := &api.PushHandler{PushController: pushController}
     privateAPI.POST("/push/token", pushHandler.AddToken)
-    embeddingController := &embeddingCtrl.Controller{Repo: embeddingRepo, AccessCtrl: accessCtrl, ObjectCleanupController: objectCleanupController, S3Config: s3Config, FileRepo: fileRepo, CollectionRepo: collectionRepo, QueueRepo: queueRepo, TaskLockingRepo: taskLockingRepo, HostName: hostName}
+    embeddingController := embeddingCtrl.New(embeddingRepo, accessCtrl, objectCleanupController, s3Config, queueRepo, taskLockingRepo, fileRepo, collectionRepo, hostName)
     embeddingHandler := &api.EmbeddingHandler{Controller: embeddingController}
     privateAPI.PUT("/embeddings", embeddingHandler.InsertOrUpdate)


@@ -125,6 +125,16 @@ s3:
         endpoint:
         region:
         bucket:
+    wasabi-eu-central-2-embeddings:
+        key:
+        secret:
+        endpoint:
+        region:
+        bucket:
+    # Embeddings bucket is used for storing embeddings and other derived data from a file.
+    # By default, it is the same as the hot storage bucket.
+    # embeddings-bucket: wasabi-eu-central-2-embeddings
     # If true, enable some workarounds to allow us to use a local minio instance
     # for object storage.
     #
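The commented block above introduces two related settings: a credentials section for the new bucket and an opt-in s3.embeddings-bucket key that selects it. A minimal sketch of how these keys could be read, assuming viper is pointed at the museum YAML config; the key names come from this diff, everything else is illustrative:

package main

import (
    "fmt"

    "github.com/spf13/viper"
)

func main() {
    // Illustrative only: load the YAML config before reading keys.
    viper.SetConfigFile("local.yaml")
    if err := viper.ReadInConfig(); err != nil {
        panic(err)
    }

    // Credentials section added by this commit (key, secret, endpoint, region, bucket).
    section := viper.GetStringMapString("s3.wasabi-eu-central-2-embeddings")

    // Optional selector; when unset, embeddings stay in the hot storage bucket.
    if selected := viper.GetString("s3.embeddings-bucket"); selected == "" {
        fmt.Println("s3.embeddings-bucket not set; embeddings fall back to hot storage")
    } else {
        fmt.Printf("embeddings go to %q (bucket %q)\n", selected, section["bucket"])
    }
}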


@@ -46,6 +46,24 @@ type Controller struct {
     CollectionRepo          *repo.CollectionRepository
     HostName                string
     cleanupCronRunning      bool
+    embeddingS3Client       *s3.S3
+    embeddingBucket         *string
+}
+
+func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
+    return &Controller{
+        Repo:                    repo,
+        AccessCtrl:              accessCtrl,
+        ObjectCleanupController: objectCleanupController,
+        S3Config:                s3Config,
+        QueueRepo:               queueRepo,
+        TaskLockingRepo:         taskLockingRepo,
+        FileRepo:                fileRepo,
+        CollectionRepo:          collectionRepo,
+        HostName:                hostName,
+        embeddingS3Client:       s3Config.GetEmbeddingsS3Client(),
+        embeddingBucket:         s3Config.GetEmbeddingsBucket(),
+    }
 }
 
 func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) {
@@ -245,7 +263,7 @@ func (c *Controller) deleteEmbedding(qItem repo.QueueItem) {
     }
     prefix := c.getEmbeddingObjectPrefix(ownerID, fileID)
-    err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter())
+    err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetEmbeddingsDataCenter())
     if err != nil {
         ctxLogger.WithError(err).Error("Failed to delete all objects")
         return
@@ -277,9 +295,9 @@ func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string
 // uploadObject uploads the embedding object to the object store and returns the object size
 func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) {
     embeddingObj, _ := json.Marshal(obj)
-    uploader := s3manager.NewUploaderWithClient(c.S3Config.GetHotS3Client())
+    uploader := s3manager.NewUploaderWithClient(c.embeddingS3Client)
     up := s3manager.UploadInput{
-        Bucket: c.S3Config.GetHotBucket(),
+        Bucket: c.embeddingBucket,
         Key:    &key,
         Body:   bytes.NewReader(embeddingObj),
     }
@@ -301,7 +319,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
     var wg sync.WaitGroup
     var errs []error
     embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
-    downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
+    downloader := s3manager.NewDownloaderWithClient(c.embeddingS3Client)
     for i, objectKey := range objectKeys {
         wg.Add(1)
@@ -338,7 +356,7 @@ type embeddingObjectResult struct {
 func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
     var wg sync.WaitGroup
     embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
-    downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
+    downloader := s3manager.NewDownloaderWithClient(c.embeddingS3Client)
     for i, dbEmbeddingRow := range dbEmbeddingRows {
         wg.Add(1)
@@ -416,7 +434,7 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, downl
     var obj ente.EmbeddingObject
     buff := &aws.WriteAtBuffer{}
     _, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
-        Bucket: c.S3Config.GetHotBucket(),
+        Bucket: c.embeddingBucket,
         Key:    &objectKey,
     })
     if err != nil {


@@ -28,6 +28,8 @@ type S3Config struct {
     hotDC string
     // Secondary (hot) data center
     secondaryHotDC string
+    // Bucket for storing ml embeddings & preview files
+    embeddingsDC string
     // A map from data centers to S3 configurations
     s3Configs map[string]*aws.Config
     // A map from data centers to pre-created S3 clients
@@ -71,6 +73,7 @@ var (
     dcWasabiEuropeCentralDeprecated string = "wasabi-eu-central-2"
     dcWasabiEuropeCentral_v3        string = "wasabi-eu-central-2-v3"
     dcSCWEuropeFrance_v3            string = "scw-eu-fr-v3"
+    dcWasabiEuropeCentral           string = "wasabi-eu-central-2-embeddings"
 )
 
 // Number of days that the wasabi bucket is configured to retain objects.
@@ -86,9 +89,9 @@ func NewS3Config() *S3Config {
 }
 
 func (config *S3Config) initialize() {
-    dcs := [5]string{
+    dcs := [6]string{
         dcB2EuropeCentral, dcSCWEuropeFranceLockedDeprecated, dcWasabiEuropeCentralDeprecated,
-        dcWasabiEuropeCentral_v3, dcSCWEuropeFrance_v3}
+        dcWasabiEuropeCentral_v3, dcSCWEuropeFrance_v3, dcWasabiEuropeCentral}
 
     config.hotDC = dcB2EuropeCentral
     config.secondaryHotDC = dcWasabiEuropeCentral_v3
@@ -99,6 +102,12 @@ func (config *S3Config) initialize() {
         config.secondaryHotDC = hs2
         log.Infof("Hot storage: %s (secondary: %s)", hs1, hs2)
     }
+    config.embeddingsDC = config.hotDC
+    embeddingsDC := viper.GetString("s3.embeddings-bucket")
+    if embeddingsDC != "" && array.StringInList(embeddingsDC, dcs[:]) {
+        config.embeddingsDC = embeddingsDC
+        log.Infof("Embeddings bucket: %s", embeddingsDC)
+    }
 
     config.buckets = make(map[string]string)
     config.s3Configs = make(map[string]*aws.Config)
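The lookup above only accepts the override when it names one of the known data centers, using a small helper from the repo's array utilities. Its implementation is not part of this diff; it is assumed to behave roughly like this sketch:

// Assumed behaviour of array.StringInList (not shown in this diff):
// report whether x equals one of the entries in list.
func StringInList(x string, list []string) bool {
    for _, item := range list {
        if item == x {
            return true
        }
    }
    return false
}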
@@ -171,6 +180,18 @@ func (config *S3Config) GetHotS3Client() *s3.S3 {
     return &s3Client
 }
 
+func (config *S3Config) GetEmbeddingsDataCenter() string {
+    return config.embeddingsDC
+}
+
+func (config *S3Config) GetEmbeddingsBucket() *string {
+    return config.GetBucket(config.embeddingsDC)
+}
+
+func (config *S3Config) GetEmbeddingsS3Client() *s3.S3 {
+    s3Client := config.GetS3Client(config.embeddingsDC)
+    return &s3Client
+}
+
 // Return the name of the hot Backblaze data center
 func (config *S3Config) GetHotBackblazeDC() string {
     return dcB2EuropeCentral
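Together, the new getters give callers (such as the embedding controller constructor above) one place to resolve where embeddings live, falling back to the hot data center when s3.embeddings-bucket is not set. A hypothetical usage sketch; only the getter names come from this diff, while the import paths and logging are assumptions:

package main

import (
    "github.com/aws/aws-sdk-go/service/s3/s3manager"
    log "github.com/sirupsen/logrus"

    "github.com/ente-io/museum/pkg/utils/s3config" // assumed module path
)

func main() {
    // Resolve the embeddings destination once at startup, as the new constructor does.
    cfg := s3config.NewS3Config()

    dc := cfg.GetEmbeddingsDataCenter()   // hot DC unless s3.embeddings-bucket overrides it
    bucket := cfg.GetEmbeddingsBucket()   // *string, ready to drop into aws-sdk-go inputs
    client := cfg.GetEmbeddingsS3Client() // pre-created *s3.S3 client for that DC

    uploader := s3manager.NewUploaderWithClient(client)
    log.Infof("embedding objects will be written to bucket %q in DC %s", *bucket, dc)
    _ = uploader
}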