Add support for configuring a different bucket for embeddings
This commit is contained in:
parent
bce3f40a16
commit
3e7b16288f
4 changed files with 58 additions and 9 deletions
|
@ -678,7 +678,7 @@ func main() {
|
|||
pushHandler := &api.PushHandler{PushController: pushController}
|
||||
privateAPI.POST("/push/token", pushHandler.AddToken)
|
||||
|
||||
embeddingController := &embeddingCtrl.Controller{Repo: embeddingRepo, AccessCtrl: accessCtrl, ObjectCleanupController: objectCleanupController, S3Config: s3Config, FileRepo: fileRepo, CollectionRepo: collectionRepo, QueueRepo: queueRepo, TaskLockingRepo: taskLockingRepo, HostName: hostName}
|
||||
embeddingController := embeddingCtrl.New(embeddingRepo, accessCtrl, objectCleanupController, s3Config, queueRepo, taskLockingRepo, fileRepo, collectionRepo, hostName)
|
||||
embeddingHandler := &api.EmbeddingHandler{Controller: embeddingController}
|
||||
|
||||
privateAPI.PUT("/embeddings", embeddingHandler.InsertOrUpdate)
|
||||
|
|
|
@ -125,6 +125,16 @@ s3:
|
|||
endpoint:
|
||||
region:
|
||||
bucket:
|
||||
wasabi-eu-central-2-embeddings:
|
||||
key:
|
||||
secret:
|
||||
endpoint:
|
||||
region:
|
||||
bucket:
|
||||
# Embeddings bucket is used for storing embeddings and other derived data from a file.
|
||||
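# By default, it is the same as the hot storage bucket. The value must be the name of one of the known data centers; any other value is ignored and the default is used.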
|
||||
# embeddings-bucket: wasabi-eu-central-2-embeddings
|
||||
|
||||
# If true, enable some workarounds to allow us to use a local minio instance
|
||||
# for object storage.
|
||||
#
|
||||
|
|
|
@ -46,6 +46,24 @@ type Controller struct {
|
|||
CollectionRepo *repo.CollectionRepository
|
||||
HostName string
|
||||
cleanupCronRunning bool
|
||||
embeddingS3Client *s3.S3
|
||||
embeddingBucket *string
|
||||
}
|
||||
|
||||
// New returns a Controller wired with the given dependencies. In addition to
// storing the dependencies, it asks s3Config once for the embeddings S3 client
// and the embeddings bucket and keeps them in the controller's unexported
// embeddingS3Client and embeddingBucket fields.
//
// NOTE(review): the parameter named repo shadows the imported repo package
// inside the function body. The parameter types in the signature still resolve
// to the package, but consider renaming the parameter to avoid confusion.
func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
|
||||
return &Controller{
|
||||
Repo: repo,
|
||||
AccessCtrl: accessCtrl,
|
||||
ObjectCleanupController: objectCleanupController,
|
||||
S3Config: s3Config,
|
||||
QueueRepo: queueRepo,
|
||||
TaskLockingRepo: taskLockingRepo,
|
||||
FileRepo: fileRepo,
|
||||
CollectionRepo: collectionRepo,
|
||||
HostName: hostName,
|
||||
// Resolved once here from the embeddings data center held by s3Config.
embeddingS3Client: s3Config.GetEmbeddingsS3Client(),
|
||||
embeddingBucket: s3Config.GetEmbeddingsBucket(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) {
|
||||
|
@ -245,7 +263,7 @@ func (c *Controller) deleteEmbedding(qItem repo.QueueItem) {
|
|||
}
|
||||
prefix := c.getEmbeddingObjectPrefix(ownerID, fileID)
|
||||
|
||||
err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter())
|
||||
err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetEmbeddingsDataCenter())
|
||||
if err != nil {
|
||||
ctxLogger.WithError(err).Error("Failed to delete all objects")
|
||||
return
|
||||
|
@ -277,9 +295,9 @@ func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string
|
|||
// uploadObject uploads the embedding object to the object store and returns the object size
|
||||
func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) {
|
||||
embeddingObj, _ := json.Marshal(obj)
|
||||
uploader := s3manager.NewUploaderWithClient(c.S3Config.GetHotS3Client())
|
||||
uploader := s3manager.NewUploaderWithClient(c.embeddingS3Client)
|
||||
up := s3manager.UploadInput{
|
||||
Bucket: c.S3Config.GetHotBucket(),
|
||||
Bucket: c.embeddingBucket,
|
||||
Key: &key,
|
||||
Body: bytes.NewReader(embeddingObj),
|
||||
}
|
||||
|
@ -301,7 +319,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
|
|||
var wg sync.WaitGroup
|
||||
var errs []error
|
||||
embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
|
||||
downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
|
||||
downloader := s3manager.NewDownloaderWithClient(c.embeddingS3Client)
|
||||
|
||||
for i, objectKey := range objectKeys {
|
||||
wg.Add(1)
|
||||
|
@ -338,7 +356,7 @@ type embeddingObjectResult struct {
|
|||
func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
|
||||
var wg sync.WaitGroup
|
||||
embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
|
||||
downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
|
||||
downloader := s3manager.NewDownloaderWithClient(c.embeddingS3Client)
|
||||
|
||||
for i, dbEmbeddingRow := range dbEmbeddingRows {
|
||||
wg.Add(1)
|
||||
|
@ -416,7 +434,7 @@ func (c *Controller) downloadObject(ctx context.Context, objectKey string, downl
|
|||
var obj ente.EmbeddingObject
|
||||
buff := &aws.WriteAtBuffer{}
|
||||
_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
|
||||
Bucket: c.S3Config.GetHotBucket(),
|
||||
Bucket: c.embeddingBucket,
|
||||
Key: &objectKey,
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
@ -28,6 +28,8 @@ type S3Config struct {
|
|||
hotDC string
|
||||
// Secondary (hot) data center
|
||||
secondaryHotDC string
|
||||
// Data center whose bucket stores ML embeddings & preview files
|
||||
embeddingsDC string
|
||||
// A map from data centers to S3 configurations
|
||||
s3Configs map[string]*aws.Config
|
||||
// A map from data centers to pre-created S3 clients
|
||||
|
@ -71,6 +73,7 @@ var (
|
|||
dcWasabiEuropeCentralDeprecated string = "wasabi-eu-central-2"
|
||||
dcWasabiEuropeCentral_v3 string = "wasabi-eu-central-2-v3"
|
||||
dcSCWEuropeFrance_v3 string = "scw-eu-fr-v3"
|
||||
dcWasabiEuropeCentral string = "wasabi-eu-central-2-embeddings"
|
||||
)
|
||||
|
||||
// Number of days that the wasabi bucket is configured to retain objects.
|
||||
|
@ -86,9 +89,9 @@ func NewS3Config() *S3Config {
|
|||
}
|
||||
|
||||
func (config *S3Config) initialize() {
|
||||
dcs := [5]string{
|
||||
dcs := [6]string{
|
||||
dcB2EuropeCentral, dcSCWEuropeFranceLockedDeprecated, dcWasabiEuropeCentralDeprecated,
|
||||
dcWasabiEuropeCentral_v3, dcSCWEuropeFrance_v3}
|
||||
dcWasabiEuropeCentral_v3, dcSCWEuropeFrance_v3, dcWasabiEuropeCentral}
|
||||
|
||||
config.hotDC = dcB2EuropeCentral
|
||||
config.secondaryHotDC = dcWasabiEuropeCentral_v3
|
||||
|
@ -99,6 +102,12 @@ func (config *S3Config) initialize() {
|
|||
config.secondaryHotDC = hs2
|
||||
log.Infof("Hot storage: %s (secondary: %s)", hs1, hs2)
|
||||
}
|
||||
config.embeddingsDC = config.hotDC
|
||||
embeddingsDC := viper.GetString("s3.embeddings-bucket")
|
||||
if embeddingsDC != "" && array.StringInList(embeddingsDC, dcs[:]) {
|
||||
config.embeddingsDC = embeddingsDC
|
||||
log.Infof("Embeddings bucket: %s", embeddingsDC)
|
||||
}
|
||||
|
||||
config.buckets = make(map[string]string)
|
||||
config.s3Configs = make(map[string]*aws.Config)
|
||||
|
@ -171,6 +180,18 @@ func (config *S3Config) GetHotS3Client() *s3.S3 {
|
|||
return &s3Client
|
||||
}
|
||||
|
||||
// GetEmbeddingsDataCenter returns the name of the data center stored in
// config.embeddingsDC, i.e. the data center used for embeddings.
func (config *S3Config) GetEmbeddingsDataCenter() string {
|
||||
return config.embeddingsDC
|
||||
}
|
||||
// GetEmbeddingsBucket returns the bucket that GetBucket reports for the
// embeddings data center (config.embeddingsDC).
func (config *S3Config) GetEmbeddingsBucket() *string {
|
||||
return config.GetBucket(config.embeddingsDC)
|
||||
}
|
||||
|
||||
// GetEmbeddingsS3Client returns a pointer to the S3 client that GetS3Client
// reports for the embeddings data center (config.embeddingsDC).
//
// The client is first assigned to a local variable, so the returned pointer
// refers to that local copy rather than to whatever GetS3Client holds.
func (config *S3Config) GetEmbeddingsS3Client() *s3.S3 {
|
||||
s3Client := config.GetS3Client(config.embeddingsDC)
|
||||
return &s3Client
|
||||
}
|
||||
|
||||
// Return the name of the hot Backblaze data center
|
||||
func (config *S3Config) GetHotBackblazeDC() string {
|
||||
return dcB2EuropeCentral
|
||||
|
|
Loading…
Reference in a new issue