Add support for configuring a different bucket for embeddings
This commit is contained in:
parent
bce3f40a16
commit
3e7b16288f
4 changed files with 58 additions and 9 deletions
|
@ -678,7 +678,7 @@ func main() {
|
||||||
pushHandler := &api.PushHandler{PushController: pushController}
|
pushHandler := &api.PushHandler{PushController: pushController}
|
||||||
privateAPI.POST("/push/token", pushHandler.AddToken)
|
privateAPI.POST("/push/token", pushHandler.AddToken)
|
||||||
|
|
||||||
embeddingController := &embeddingCtrl.Controller{Repo: embeddingRepo, AccessCtrl: accessCtrl, ObjectCleanupController: objectCleanupController, S3Config: s3Config, FileRepo: fileRepo, CollectionRepo: collectionRepo, QueueRepo: queueRepo, TaskLockingRepo: taskLockingRepo, HostName: hostName}
|
embeddingController := embeddingCtrl.New(embeddingRepo, accessCtrl, objectCleanupController, s3Config, queueRepo, taskLockingRepo, fileRepo, collectionRepo, hostName)
|
||||||
embeddingHandler := &api.EmbeddingHandler{Controller: embeddingController}
|
embeddingHandler := &api.EmbeddingHandler{Controller: embeddingController}
|
||||||
|
|
||||||
privateAPI.PUT("/embeddings", embeddingHandler.InsertOrUpdate)
|
privateAPI.PUT("/embeddings", embeddingHandler.InsertOrUpdate)
|
||||||
|
|
|
@ -125,6 +125,16 @@ s3:
|
||||||
endpoint:
|
endpoint:
|
||||||
region:
|
region:
|
||||||
bucket:
|
bucket:
|
||||||
|
wasabi-eu-central-2-embeddings:
|
||||||
|
key:
|
||||||
|
secret:
|
||||||
|
endpoint:
|
||||||
|
region:
|
||||||
|
bucket:
|
||||||
|
# Embeddings bucket is used for storing embeddings and other derived data from a file.
|
||||||
|
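# By default, it is the same as the hot storage bucket. If set, the value must be the name of a data center known to the server; any other value is ignored and the default is used.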
|
||||||
|
# embeddings-bucket: wasabi-eu-central-2-embeddings
|
||||||
|
|
||||||
# If true, enable some workarounds to allow us to use a local minio instance
|
# If true, enable some workarounds to allow us to use a local minio instance
|
||||||
# for object storage.
|
# for object storage.
|
||||||
#
|
#
|
||||||
|
|
|
@ -46,6 +46,24 @@ type Controller struct {
|
||||||
CollectionRepo *repo.CollectionRepository
|
CollectionRepo *repo.CollectionRepository
|
||||||
HostName string
|
HostName string
|
||||||
cleanupCronRunning bool
|
cleanupCronRunning bool
|
||||||
|
embeddingS3Client *s3.S3
|
||||||
|
embeddingBucket *string
|
||||||
|
}
|
||||||
|
|
||||||
|
// New builds a Controller from the given dependencies.
//
// Besides storing each argument in the matching exported field, it resolves
// the embeddings S3 client and the embeddings bucket once, via
// s3Config.GetEmbeddingsS3Client and s3Config.GetEmbeddingsBucket, and caches
// them in the unexported embeddingS3Client / embeddingBucket fields.
//
// s3Config is used to call those two getters, so it must be non-nil.
//
// Note: the first parameter is named repo, which shadows the imported repo
// package inside the function body (the package is still what the parameter
// types in the signature refer to).
func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
|
||||||
|
return &Controller{
|
||||||
|
Repo: repo,
|
||||||
|
AccessCtrl: accessCtrl,
|
||||||
|
ObjectCleanupController: objectCleanupController,
|
||||||
|
S3Config: s3Config,
|
||||||
|
QueueRepo: queueRepo,
|
||||||
|
TaskLockingRepo: taskLockingRepo,
|
||||||
|
FileRepo: fileRepo,
|
||||||
|
CollectionRepo: collectionRepo,
|
||||||
|
HostName: hostName,
|
||||||
|
embeddingS3Client: s3Config.GetEmbeddingsS3Client(),
|
||||||
|
embeddingBucket: s3Config.GetEmbeddingsBucket(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) {
|
func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) {
|
||||||
|
@ -245,7 +263,7 @@ func (c *Controller) deleteEmbedding(qItem repo.QueueItem) {
|
||||||
}
|
}
|
||||||
prefix := c.getEmbeddingObjectPrefix(ownerID, fileID)
|
prefix := c.getEmbeddingObjectPrefix(ownerID, fileID)
|
||||||
|
|
||||||
err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter())
|
err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetEmbeddingsDataCenter())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctxLogger.WithError(err).Error("Failed to delete all objects")
|
ctxLogger.WithError(err).Error("Failed to delete all objects")
|
||||||
return
|
return
|
||||||
|
@ -277,9 +295,9 @@ func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string
|
||||||
// uploadObject uploads the embedding object to the object store and returns the object size
|
// uploadObject uploads the embedding object to the object store and returns the object size
|
||||||
func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) {
|
func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) {
|
||||||
embeddingObj, _ := json.Marshal(obj)
|
embeddingObj, _ := json.Marshal(obj)
|
||||||
uploader := s3manager.NewUploaderWithClient(c.S3Config.GetHotS3Client())
|
uploader := s3manager.NewUploaderWithClient(c.embeddingS3Client)
|
||||||
up := s3manager.UploadInput{
|
up := s3manager.UploadInput{
|
||||||
Bucket: c.S3Config.GetHotBucket(),
|
Bucket: c.embeddingBucket,
|
||||||
Key: &key,
|
Key: &key,
|
||||||
Body: bytes.NewReader(embeddingObj),
|
Body: bytes.NewReader(embeddingObj),
|
||||||
}
|
}
|
||||||
|
@ -301,7 +319,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
var errs []error
|
var errs []error
|
||||||
embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
|
embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
|
||||||
downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
|
downloader := s3manager.NewDownloaderWithClient(c.embeddingS3Client)
|
||||||
|
|
||||||
for i, objectKey := range objectKeys {
|
for i, objectKey := range objectKeys {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
@ -338,7 +356,7 @@ type embeddingObjectResult struct {
|
||||||
func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
|
func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
|
embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
|
||||||
downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
|
downloader := s3manager.NewDownloaderWithClient(c.embeddingS3Client)
|
||||||
|
|
||||||
for i, dbEmbeddingRow := range dbEmbeddingRows {
|
for i, dbEmbeddingRow := range dbEmbeddingRows {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
@ -416,7 +434,7 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, downl
|
||||||
var obj ente.EmbeddingObject
|
var obj ente.EmbeddingObject
|
||||||
buff := &aws.WriteAtBuffer{}
|
buff := &aws.WriteAtBuffer{}
|
||||||
_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
|
_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
|
||||||
Bucket: c.S3Config.GetHotBucket(),
|
Bucket: c.embeddingBucket,
|
||||||
Key: &objectKey,
|
Key: &objectKey,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -28,6 +28,8 @@ type S3Config struct {
|
||||||
hotDC string
|
hotDC string
|
||||||
// Secondary (hot) data center
|
// Secondary (hot) data center
|
||||||
secondaryHotDC string
|
secondaryHotDC string
|
||||||
|
// Data center whose bucket stores ML embeddings & preview files
|
||||||
|
embeddingsDC string
|
||||||
// A map from data centers to S3 configurations
|
// A map from data centers to S3 configurations
|
||||||
s3Configs map[string]*aws.Config
|
s3Configs map[string]*aws.Config
|
||||||
// A map from data centers to pre-created S3 clients
|
// A map from data centers to pre-created S3 clients
|
||||||
|
@ -71,6 +73,7 @@ var (
|
||||||
dcWasabiEuropeCentralDeprecated string = "wasabi-eu-central-2"
|
dcWasabiEuropeCentralDeprecated string = "wasabi-eu-central-2"
|
||||||
dcWasabiEuropeCentral_v3 string = "wasabi-eu-central-2-v3"
|
dcWasabiEuropeCentral_v3 string = "wasabi-eu-central-2-v3"
|
||||||
dcSCWEuropeFrance_v3 string = "scw-eu-fr-v3"
|
dcSCWEuropeFrance_v3 string = "scw-eu-fr-v3"
|
||||||
|
dcWasabiEuropeCentral string = "wasabi-eu-central-2-embeddings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Number of days that the wasabi bucket is configured to retain objects.
|
// Number of days that the wasabi bucket is configured to retain objects.
|
||||||
|
@ -86,9 +89,9 @@ func NewS3Config() *S3Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (config *S3Config) initialize() {
|
func (config *S3Config) initialize() {
|
||||||
dcs := [5]string{
|
dcs := [6]string{
|
||||||
dcB2EuropeCentral, dcSCWEuropeFranceLockedDeprecated, dcWasabiEuropeCentralDeprecated,
|
dcB2EuropeCentral, dcSCWEuropeFranceLockedDeprecated, dcWasabiEuropeCentralDeprecated,
|
||||||
dcWasabiEuropeCentral_v3, dcSCWEuropeFrance_v3}
|
dcWasabiEuropeCentral_v3, dcSCWEuropeFrance_v3, dcWasabiEuropeCentral}
|
||||||
|
|
||||||
config.hotDC = dcB2EuropeCentral
|
config.hotDC = dcB2EuropeCentral
|
||||||
config.secondaryHotDC = dcWasabiEuropeCentral_v3
|
config.secondaryHotDC = dcWasabiEuropeCentral_v3
|
||||||
|
@ -99,6 +102,12 @@ func (config *S3Config) initialize() {
|
||||||
config.secondaryHotDC = hs2
|
config.secondaryHotDC = hs2
|
||||||
log.Infof("Hot storage: %s (secondary: %s)", hs1, hs2)
|
log.Infof("Hot storage: %s (secondary: %s)", hs1, hs2)
|
||||||
}
|
}
|
||||||
|
config.embeddingsDC = config.hotDC
|
||||||
|
embeddingsDC := viper.GetString("s3.embeddings-bucket")
|
||||||
|
if embeddingsDC != "" && array.StringInList(embeddingsDC, dcs[:]) {
|
||||||
|
config.embeddingsDC = embeddingsDC
|
||||||
|
log.Infof("Embeddings bucket: %s", embeddingsDC)
|
||||||
|
}
|
||||||
|
|
||||||
config.buckets = make(map[string]string)
|
config.buckets = make(map[string]string)
|
||||||
config.s3Configs = make(map[string]*aws.Config)
|
config.s3Configs = make(map[string]*aws.Config)
|
||||||
|
@ -171,6 +180,18 @@ func (config *S3Config) GetHotS3Client() *s3.S3 {
|
||||||
return &s3Client
|
return &s3Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetEmbeddingsDataCenter returns the name of the data center used for
// embeddings (the embeddingsDC field). In initialize this is set to the hot
// data center and replaced by the "s3.embeddings-bucket" setting only when
// that value is non-empty and names one of the known data centers.
func (config *S3Config) GetEmbeddingsDataCenter() string {
|
||||||
|
return config.embeddingsDC
|
||||||
|
}
|
||||||
|
// GetEmbeddingsBucket returns the bucket associated with the embeddings data
// center, as looked up by GetBucket(config.embeddingsDC).
func (config *S3Config) GetEmbeddingsBucket() *string {
|
||||||
|
return config.GetBucket(config.embeddingsDC)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEmbeddingsS3Client returns an S3 client for the embeddings data center.
// The client obtained from GetS3Client(config.embeddingsDC) is stored in a
// local variable and the address of that local is returned.
func (config *S3Config) GetEmbeddingsS3Client() *s3.S3 {
|
||||||
|
s3Client := config.GetS3Client(config.embeddingsDC)
|
||||||
|
return &s3Client
|
||||||
|
}
|
||||||
|
|
||||||
// Return the name of the hot Backblaze data center
|
// Return the name of the hot Backblaze data center
|
||||||
func (config *S3Config) GetHotBackblazeDC() string {
|
func (config *S3Config) GetHotBackblazeDC() string {
|
||||||
return dcB2EuropeCentral
|
return dcB2EuropeCentral
|
||||||
|
|
Loading…
Reference in a new issue