Make embedding bucket configurable (#1726)
## Description ## Tests - [x] New ML data is going to the new bucket. - [x] For existing embeddings, the fallback logic works as expected: we return the object immediately and copy it to the new bucket asynchronously. - [x] Verified that the DC values were updated correctly on copy or insert. - [x] Verified that on deletion, we delete files from all DCs where the derived file is present.
This commit is contained in:
commit
401cf92695
8 changed files with 464 additions and 129 deletions
|
@ -678,7 +678,7 @@ func main() {
|
|||
pushHandler := &api.PushHandler{PushController: pushController}
|
||||
privateAPI.POST("/push/token", pushHandler.AddToken)
|
||||
|
||||
embeddingController := &embeddingCtrl.Controller{Repo: embeddingRepo, AccessCtrl: accessCtrl, ObjectCleanupController: objectCleanupController, S3Config: s3Config, FileRepo: fileRepo, CollectionRepo: collectionRepo, QueueRepo: queueRepo, TaskLockingRepo: taskLockingRepo, HostName: hostName}
|
||||
embeddingController := embeddingCtrl.New(embeddingRepo, accessCtrl, objectCleanupController, s3Config, queueRepo, taskLockingRepo, fileRepo, collectionRepo, hostName)
|
||||
embeddingHandler := &api.EmbeddingHandler{Controller: embeddingController}
|
||||
|
||||
privateAPI.PUT("/embeddings", embeddingHandler.InsertOrUpdate)
|
||||
|
|
|
@ -125,6 +125,16 @@ s3:
|
|||
endpoint:
|
||||
region:
|
||||
bucket:
|
||||
wasabi-eu-central-2-derived:
|
||||
key:
|
||||
secret:
|
||||
endpoint:
|
||||
region:
|
||||
bucket:
|
||||
# Derived storage bucket is used for storing derived data like embeddings, preview etc.
|
||||
# By default, it is the same as the hot storage bucket.
|
||||
# derived-storage: wasabi-eu-central-2-derived
|
||||
|
||||
# If true, enable some workarounds to allow us to use a local minio instance
|
||||
# for object storage.
|
||||
#
|
||||
|
|
18
server/migrations/86_add_dc_embedding.down.sql
Normal file
18
server/migrations/86_add_dc_embedding.down.sql
Normal file
|
@ -0,0 +1,18 @@
|
|||
-- Migration 86 (down): revert the derived-data DC tracking added by the up migration.
-- Note: the 'wasabi-eu-central-2-derived' value added to the s3region enum by the
-- up migration is intentionally NOT removed here (PostgreSQL cannot drop enum values).

-- Drop the per-row datacenter bookkeeping column.
ALTER TABLE embeddings DROP COLUMN IF EXISTS datacenters;

-- Restore the updated_at auto-bump trigger that the up migration dropped.
DO
$$
BEGIN
    IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'update_embeddings_updated_at') THEN
        CREATE TRIGGER update_embeddings_updated_at
            BEFORE UPDATE
            ON embeddings
            FOR EACH ROW
        EXECUTE PROCEDURE
            trigger_updated_at_microseconds_column();
    ELSE
        RAISE NOTICE 'Trigger update_embeddings_updated_at already exists.';
    END IF;
END
$$;
|
4
server/migrations/86_add_dc_embedding.up.sql
Normal file
4
server/migrations/86_add_dc_embedding.up.sql
Normal file
|
@ -0,0 +1,4 @@
|
|||
-- Migration 86 (up): make the embedding (derived data) storage bucket configurable.

-- Register the new derived-data datacenter in the s3region enum.
ALTER TYPE s3region ADD VALUE 'wasabi-eu-central-2-derived';
-- Drop the auto-bump trigger on updated_at — presumably so background DC
-- bookkeeping writes do not advance updated_at; confirm (the down migration
-- recreates this trigger).
DROP TRIGGER IF EXISTS update_embeddings_updated_at ON embeddings;
-- Track which DCs hold each embedding; all pre-existing rows were written to b2-eu-cen.
ALTER TABLE embeddings ADD COLUMN IF NOT EXISTS datacenters s3region[] default '{b2-eu-cen}';
|
|
@ -6,8 +6,10 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/ente-io/museum/pkg/utils/array"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
gTime "time"
|
||||
|
||||
|
@ -22,7 +24,6 @@ import (
|
|||
"github.com/ente-io/museum/pkg/utils/auth"
|
||||
"github.com/ente-io/museum/pkg/utils/network"
|
||||
"github.com/ente-io/museum/pkg/utils/s3config"
|
||||
"github.com/ente-io/museum/pkg/utils/time"
|
||||
"github.com/ente-io/stacktrace"
|
||||
"github.com/gin-gonic/gin"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
@ -31,20 +32,54 @@ import (
|
|||
const (
|
||||
// maxEmbeddingDataSize is the min size of an embedding object in bytes
|
||||
minEmbeddingDataSize = 2048
|
||||
embeddingFetchTimeout = 15 * gTime.Second
|
||||
embeddingFetchTimeout = 10 * gTime.Second
|
||||
)
|
||||
|
||||
// _fetchConfig is the configuration for the fetching objects from S3
|
||||
type _fetchConfig struct {
|
||||
RetryCount int
|
||||
InitialTimeout gTime.Duration
|
||||
MaxTimeout gTime.Duration
|
||||
}
|
||||
|
||||
var _defaultFetchConfig = _fetchConfig{RetryCount: 3, InitialTimeout: 10 * gTime.Second, MaxTimeout: 30 * gTime.Second}
|
||||
var _b2FetchConfig = _fetchConfig{RetryCount: 3, InitialTimeout: 15 * gTime.Second, MaxTimeout: 30 * gTime.Second}
|
||||
|
||||
type Controller struct {
|
||||
Repo *embedding.Repository
|
||||
AccessCtrl access.Controller
|
||||
ObjectCleanupController *controller.ObjectCleanupController
|
||||
S3Config *s3config.S3Config
|
||||
QueueRepo *repo.QueueRepository
|
||||
TaskLockingRepo *repo.TaskLockRepository
|
||||
FileRepo *repo.FileRepository
|
||||
CollectionRepo *repo.CollectionRepository
|
||||
HostName string
|
||||
cleanupCronRunning bool
|
||||
Repo *embedding.Repository
|
||||
AccessCtrl access.Controller
|
||||
ObjectCleanupController *controller.ObjectCleanupController
|
||||
S3Config *s3config.S3Config
|
||||
QueueRepo *repo.QueueRepository
|
||||
TaskLockingRepo *repo.TaskLockRepository
|
||||
FileRepo *repo.FileRepository
|
||||
CollectionRepo *repo.CollectionRepository
|
||||
HostName string
|
||||
cleanupCronRunning bool
|
||||
derivedStorageDataCenter string
|
||||
downloadManagerCache map[string]*s3manager.Downloader
|
||||
}
|
||||
|
||||
func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
|
||||
embeddingDcs := []string{s3Config.GetHotBackblazeDC(), s3Config.GetHotWasabiDC(), s3Config.GetWasabiDerivedDC(), s3Config.GetDerivedStorageDataCenter()}
|
||||
cache := make(map[string]*s3manager.Downloader, len(embeddingDcs))
|
||||
for i := range embeddingDcs {
|
||||
s3Client := s3Config.GetS3Client(embeddingDcs[i])
|
||||
cache[embeddingDcs[i]] = s3manager.NewDownloaderWithClient(&s3Client)
|
||||
}
|
||||
return &Controller{
|
||||
Repo: repo,
|
||||
AccessCtrl: accessCtrl,
|
||||
ObjectCleanupController: objectCleanupController,
|
||||
S3Config: s3Config,
|
||||
QueueRepo: queueRepo,
|
||||
TaskLockingRepo: taskLockingRepo,
|
||||
FileRepo: fileRepo,
|
||||
CollectionRepo: collectionRepo,
|
||||
HostName: hostName,
|
||||
derivedStorageDataCenter: s3Config.GetDerivedStorageDataCenter(),
|
||||
downloadManagerCache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) {
|
||||
|
@ -77,12 +112,12 @@ func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmb
|
|||
DecryptionHeader: req.DecryptionHeader,
|
||||
Client: network.GetPrettyUA(ctx.GetHeader("User-Agent")) + "/" + ctx.GetHeader("X-Client-Version"),
|
||||
}
|
||||
size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model))
|
||||
size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model), c.derivedStorageDataCenter)
|
||||
if uploadErr != nil {
|
||||
log.Error(uploadErr)
|
||||
return nil, stacktrace.Propagate(uploadErr, "")
|
||||
}
|
||||
embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version)
|
||||
embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.derivedStorageDataCenter)
|
||||
embedding.Version = &version
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
|
@ -113,7 +148,7 @@ func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest)
|
|||
|
||||
// Fetch missing embeddings in parallel
|
||||
if len(objectKeys) > 0 {
|
||||
embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys)
|
||||
embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys, c.derivedStorageDataCenter)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
|
@ -146,7 +181,7 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbedd
|
|||
embeddingsWithData := make([]ente.Embedding, 0)
|
||||
noEmbeddingFileIds := make([]int64, 0)
|
||||
dbFileIds := make([]int64, 0)
|
||||
// fileIDs that were indexed but they don't contain any embedding information
|
||||
// fileIDs that were indexed, but they don't contain any embedding information
|
||||
for i := range userFileEmbeddings {
|
||||
dbFileIds = append(dbFileIds, userFileEmbeddings[i].FileID)
|
||||
if userFileEmbeddings[i].Size != nil && *userFileEmbeddings[i].Size < minEmbeddingDataSize {
|
||||
|
@ -159,7 +194,7 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbedd
|
|||
errFileIds := make([]int64, 0)
|
||||
|
||||
// Fetch missing userFileEmbeddings in parallel
|
||||
embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData)
|
||||
embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData, c.derivedStorageDataCenter)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
|
@ -189,82 +224,6 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbedd
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (c *Controller) DeleteAll(ctx *gin.Context) error {
|
||||
userID := auth.GetUserID(ctx.Request.Header)
|
||||
|
||||
err := c.Repo.DeleteAll(ctx, userID)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupDeletedEmbeddings clears all embeddings for deleted files from the object store
|
||||
func (c *Controller) CleanupDeletedEmbeddings() {
|
||||
log.Info("Cleaning up deleted embeddings")
|
||||
if c.cleanupCronRunning {
|
||||
log.Info("Skipping CleanupDeletedEmbeddings cron run as another instance is still running")
|
||||
return
|
||||
}
|
||||
c.cleanupCronRunning = true
|
||||
defer func() {
|
||||
c.cleanupCronRunning = false
|
||||
}()
|
||||
items, err := c.QueueRepo.GetItemsReadyForDeletion(repo.DeleteEmbeddingsQueue, 200)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to fetch items from queue")
|
||||
return
|
||||
}
|
||||
for _, i := range items {
|
||||
c.deleteEmbedding(i)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) deleteEmbedding(qItem repo.QueueItem) {
|
||||
lockName := fmt.Sprintf("Embedding:%s", qItem.Item)
|
||||
lockStatus, err := c.TaskLockingRepo.AcquireLock(lockName, time.MicrosecondsAfterHours(1), c.HostName)
|
||||
ctxLogger := log.WithField("item", qItem.Item).WithField("queue_id", qItem.Id)
|
||||
if err != nil || !lockStatus {
|
||||
ctxLogger.Warn("unable to acquire lock")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
err = c.TaskLockingRepo.ReleaseLock(lockName)
|
||||
if err != nil {
|
||||
ctxLogger.Errorf("Error while releasing lock %s", err)
|
||||
}
|
||||
}()
|
||||
ctxLogger.Info("Deleting all embeddings")
|
||||
|
||||
fileID, _ := strconv.ParseInt(qItem.Item, 10, 64)
|
||||
ownerID, err := c.FileRepo.GetOwnerID(fileID)
|
||||
if err != nil {
|
||||
ctxLogger.WithError(err).Error("Failed to fetch ownerID")
|
||||
return
|
||||
}
|
||||
prefix := c.getEmbeddingObjectPrefix(ownerID, fileID)
|
||||
|
||||
err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter())
|
||||
if err != nil {
|
||||
ctxLogger.WithError(err).Error("Failed to delete all objects")
|
||||
return
|
||||
}
|
||||
|
||||
err = c.Repo.Delete(fileID)
|
||||
if err != nil {
|
||||
ctxLogger.WithError(err).Error("Failed to remove from db")
|
||||
return
|
||||
}
|
||||
|
||||
err = c.QueueRepo.DeleteItem(repo.DeleteEmbeddingsQueue, qItem.Item)
|
||||
if err != nil {
|
||||
ctxLogger.WithError(err).Error("Failed to remove item from the queue")
|
||||
return
|
||||
}
|
||||
|
||||
ctxLogger.Info("Successfully deleted all embeddings")
|
||||
}
|
||||
|
||||
func (c *Controller) getObjectKey(userID int64, fileID int64, model string) string {
|
||||
return c.getEmbeddingObjectPrefix(userID, fileID) + model + ".json"
|
||||
}
|
||||
|
@ -273,12 +232,23 @@ func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string
|
|||
return strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/"
|
||||
}
|
||||
|
||||
// Get userId, model and fileID from the object key
|
||||
func (c *Controller) getEmbeddingObjectDetails(objectKey string) (userID int64, model string, fileID int64) {
|
||||
split := strings.Split(objectKey, "/")
|
||||
userID, _ = strconv.ParseInt(split[0], 10, 64)
|
||||
fileID, _ = strconv.ParseInt(split[2], 10, 64)
|
||||
model = strings.Split(split[3], ".")[0]
|
||||
return userID, model, fileID
|
||||
}
|
||||
|
||||
// uploadObject uploads the embedding object to the object store and returns the object size
|
||||
func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) {
|
||||
func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) {
|
||||
embeddingObj, _ := json.Marshal(obj)
|
||||
uploader := s3manager.NewUploaderWithClient(c.S3Config.GetHotS3Client())
|
||||
s3Client := c.S3Config.GetS3Client(dc)
|
||||
s3Bucket := c.S3Config.GetBucket(dc)
|
||||
uploader := s3manager.NewUploaderWithClient(&s3Client)
|
||||
up := s3manager.UploadInput{
|
||||
Bucket: c.S3Config.GetHotBucket(),
|
||||
Bucket: s3Bucket,
|
||||
Key: &key,
|
||||
Body: bytes.NewReader(embeddingObj),
|
||||
}
|
||||
|
@ -296,12 +266,10 @@ var globalDiffFetchSemaphore = make(chan struct{}, 300)
|
|||
|
||||
var globalFileFetchSemaphore = make(chan struct{}, 400)
|
||||
|
||||
func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.EmbeddingObject, error) {
|
||||
func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string, dc string) ([]ente.EmbeddingObject, error) {
|
||||
var wg sync.WaitGroup
|
||||
var errs []error
|
||||
embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
|
||||
downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
|
||||
|
||||
for i, objectKey := range objectKeys {
|
||||
wg.Add(1)
|
||||
globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore
|
||||
|
@ -309,7 +277,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
|
|||
defer wg.Done()
|
||||
defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
|
||||
|
||||
obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
|
||||
obj, err := c.getEmbeddingObject(context.Background(), objectKey, dc)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
log.Error("error fetching embedding object: "+objectKey, err)
|
||||
|
@ -334,10 +302,9 @@ type embeddingObjectResult struct {
|
|||
err error
|
||||
}
|
||||
|
||||
func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
|
||||
func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding, dc string) ([]embeddingObjectResult, error) {
|
||||
var wg sync.WaitGroup
|
||||
embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
|
||||
downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
|
||||
|
||||
for i, dbEmbeddingRow := range dbEmbeddingRows {
|
||||
wg.Add(1)
|
||||
|
@ -346,9 +313,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
|
|||
defer wg.Done()
|
||||
defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
|
||||
objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), embeddingFetchTimeout)
|
||||
defer cancel()
|
||||
obj, err := c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 0)
|
||||
obj, err := c.getEmbeddingObject(context.Background(), objectKey, dc)
|
||||
if err != nil {
|
||||
log.Error("error fetching embedding object: "+objectKey, err)
|
||||
embeddingObjects[i] = embeddingObjectResult{
|
||||
|
@ -368,32 +333,125 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows
|
|||
return embeddingObjects, nil
|
||||
}
|
||||
|
||||
func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
|
||||
return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 3)
|
||||
func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) {
|
||||
opt := _defaultFetchConfig
|
||||
if dc == c.S3Config.GetHotBackblazeDC() {
|
||||
opt = _b2FetchConfig
|
||||
}
|
||||
ctxLogger := log.WithField("objectKey", objectKey).WithField("dc", dc)
|
||||
totalAttempts := opt.RetryCount + 1
|
||||
timeout := opt.InitialTimeout
|
||||
for i := 0; i < totalAttempts; i++ {
|
||||
if i > 0 {
|
||||
timeout = timeout * 2
|
||||
if timeout > opt.MaxTimeout {
|
||||
timeout = opt.MaxTimeout
|
||||
}
|
||||
}
|
||||
fetchCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cancel()
|
||||
return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
|
||||
default:
|
||||
obj, err := c.downloadObject(fetchCtx, objectKey, dc)
|
||||
cancel() // Ensure cancel is called to release resources
|
||||
if err == nil {
|
||||
if i > 0 {
|
||||
ctxLogger.Infof("Fetched object after %d attempts", i)
|
||||
}
|
||||
return obj, nil
|
||||
}
|
||||
// Check if the error is due to context timeout or cancellation
|
||||
if err == nil && fetchCtx.Err() != nil {
|
||||
ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err())
|
||||
} else {
|
||||
// check if the error is due to object not found
|
||||
if s3Err, ok := err.(awserr.RequestFailure); ok {
|
||||
if s3Err.Code() == s3.ErrCodeNoSuchKey {
|
||||
var srcDc, destDc string
|
||||
destDc = c.S3Config.GetDerivedStorageDataCenter()
|
||||
// todo:(neeraj) Refactor this later to get available the DC from the DB instead of
|
||||
// querying the DB. This will help in case of multiple DCs and avoid querying the DB
|
||||
// for each object.
|
||||
// For initial migration, as we know that original DC was b2, and if the embedding is not found
|
||||
// in the new derived DC, we can try to fetch it from the B2 DC.
|
||||
if c.derivedStorageDataCenter != c.S3Config.GetHotBackblazeDC() {
|
||||
// embeddings ideally should ideally be in the default hot bucket b2
|
||||
srcDc = c.S3Config.GetHotBackblazeDC()
|
||||
} else {
|
||||
_, modelName, fileID := c.getEmbeddingObjectDetails(objectKey)
|
||||
activeDcs, err := c.Repo.GetOtherDCsForFileAndModel(context.Background(), fileID, modelName, c.derivedStorageDataCenter)
|
||||
if err != nil {
|
||||
return ente.EmbeddingObject{}, stacktrace.Propagate(err, "failed to get other dc")
|
||||
}
|
||||
if len(activeDcs) > 0 {
|
||||
srcDc = activeDcs[0]
|
||||
} else {
|
||||
ctxLogger.Error("Object not found in any dc ", s3Err)
|
||||
return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
|
||||
}
|
||||
}
|
||||
copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey, srcDc, destDc)
|
||||
if err == nil {
|
||||
ctxLogger.Infof("Got object from dc %s", srcDc)
|
||||
return *copyEmbeddingObject, nil
|
||||
} else {
|
||||
ctxLogger.WithError(err).Errorf("Failed to get object from fallback dc %s", srcDc)
|
||||
}
|
||||
return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
|
||||
}
|
||||
}
|
||||
ctxLogger.Error("Failed to fetch object: ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
|
||||
}
|
||||
|
||||
func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) {
|
||||
func (c *Controller) downloadObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) {
|
||||
var obj ente.EmbeddingObject
|
||||
buff := &aws.WriteAtBuffer{}
|
||||
bucket := c.S3Config.GetBucket(dc)
|
||||
downloader := c.downloadManagerCache[dc]
|
||||
_, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
|
||||
Bucket: c.S3Config.GetHotBucket(),
|
||||
Bucket: bucket,
|
||||
Key: &objectKey,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
if retryCount > 0 {
|
||||
return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, retryCount-1)
|
||||
}
|
||||
return obj, stacktrace.Propagate(err, "")
|
||||
return obj, err
|
||||
}
|
||||
err = json.Unmarshal(buff.Bytes(), &obj)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return obj, stacktrace.Propagate(err, "")
|
||||
return obj, stacktrace.Propagate(err, "unmarshal failed")
|
||||
}
|
||||
return obj, nil
|
||||
}
|
||||
|
||||
// download the embedding object from hot bucket and upload to embeddings bucket
|
||||
func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string, srcDC, destDC string) (*ente.EmbeddingObject, error) {
|
||||
if srcDC == destDC {
|
||||
return nil, stacktrace.Propagate(errors.New("src and dest dc can not be same"), "")
|
||||
}
|
||||
obj, err := c.downloadObject(ctx, objectKey, srcDC)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, fmt.Sprintf("failed to download object from %s", srcDC))
|
||||
}
|
||||
go func() {
|
||||
userID, modelName, fileID := c.getEmbeddingObjectDetails(objectKey)
|
||||
size, uploadErr := c.uploadObject(obj, objectKey, c.derivedStorageDataCenter)
|
||||
if uploadErr != nil {
|
||||
log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", uploadErr)
|
||||
}
|
||||
updateDcErr := c.Repo.AddNewDC(context.Background(), fileID, ente.Model(modelName), userID, size, destDC)
|
||||
if updateDcErr != nil {
|
||||
log.WithField("object", objectKey).Error("Failed to update dc in db: ", updateDcErr)
|
||||
return
|
||||
}
|
||||
}()
|
||||
return &obj, nil
|
||||
}
|
||||
|
||||
func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error {
|
||||
if req.Model == "" {
|
||||
return ente.NewBadRequestWithMessage("model is required")
|
||||
|
|
126
server/pkg/controller/embedding/delete.go
Normal file
126
server/pkg/controller/embedding/delete.go
Normal file
|
@ -0,0 +1,126 @@
|
|||
package embedding
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/pkg/repo"
|
||||
"github.com/ente-io/museum/pkg/utils/auth"
|
||||
"github.com/ente-io/museum/pkg/utils/time"
|
||||
"github.com/ente-io/stacktrace"
|
||||
"github.com/gin-gonic/gin"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func (c *Controller) DeleteAll(ctx *gin.Context) error {
|
||||
userID := auth.GetUserID(ctx.Request.Header)
|
||||
|
||||
err := c.Repo.DeleteAll(ctx, userID)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupDeletedEmbeddings clears all embeddings for deleted files from the object store
|
||||
func (c *Controller) CleanupDeletedEmbeddings() {
|
||||
log.Info("Cleaning up deleted embeddings")
|
||||
if c.cleanupCronRunning {
|
||||
log.Info("Skipping CleanupDeletedEmbeddings cron run as another instance is still running")
|
||||
return
|
||||
}
|
||||
c.cleanupCronRunning = true
|
||||
defer func() {
|
||||
c.cleanupCronRunning = false
|
||||
}()
|
||||
items, err := c.QueueRepo.GetItemsReadyForDeletion(repo.DeleteEmbeddingsQueue, 200)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to fetch items from queue")
|
||||
return
|
||||
}
|
||||
for _, i := range items {
|
||||
c.deleteEmbedding(i)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) deleteEmbedding(qItem repo.QueueItem) {
|
||||
lockName := fmt.Sprintf("Embedding:%s", qItem.Item)
|
||||
lockStatus, err := c.TaskLockingRepo.AcquireLock(lockName, time.MicrosecondsAfterHours(1), c.HostName)
|
||||
ctxLogger := log.WithField("item", qItem.Item).WithField("queue_id", qItem.Id)
|
||||
if err != nil || !lockStatus {
|
||||
ctxLogger.Warn("unable to acquire lock")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
err = c.TaskLockingRepo.ReleaseLock(lockName)
|
||||
if err != nil {
|
||||
ctxLogger.Errorf("Error while releasing lock %s", err)
|
||||
}
|
||||
}()
|
||||
ctxLogger.Info("Deleting all embeddings")
|
||||
|
||||
fileID, _ := strconv.ParseInt(qItem.Item, 10, 64)
|
||||
ownerID, err := c.FileRepo.GetOwnerID(fileID)
|
||||
if err != nil {
|
||||
ctxLogger.WithError(err).Error("Failed to fetch ownerID")
|
||||
return
|
||||
}
|
||||
prefix := c.getEmbeddingObjectPrefix(ownerID, fileID)
|
||||
datacenters, err := c.Repo.GetDatacenters(context.Background(), fileID)
|
||||
if err != nil {
|
||||
ctxLogger.WithError(err).Error("Failed to fetch datacenters")
|
||||
return
|
||||
}
|
||||
// Ensure that the object are deleted from active derived storage dc. Ideally, this section should never be executed
|
||||
// unless there's a bug in storing the DC or the service restarts before removing the rows from the table
|
||||
// todo:(neeraj): remove this section after a few weeks of deployment
|
||||
if len(datacenters) == 0 {
|
||||
ctxLogger.Warn("No datacenters found for file, ensuring deletion from derived storage and hot DC")
|
||||
err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetDerivedStorageDataCenter())
|
||||
if err != nil {
|
||||
ctxLogger.WithError(err).Error("Failed to delete all objects")
|
||||
return
|
||||
}
|
||||
// if Derived DC is different from hot DC, delete from hot DC as well
|
||||
if c.derivedStorageDataCenter != c.S3Config.GetHotDataCenter() {
|
||||
err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter())
|
||||
if err != nil {
|
||||
ctxLogger.WithError(err).Error("Failed to delete all objects from hot DC")
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ctxLogger.Infof("Deleting from all datacenters %v", datacenters)
|
||||
}
|
||||
|
||||
for i := range datacenters {
|
||||
err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, datacenters[i])
|
||||
if err != nil {
|
||||
ctxLogger.WithError(err).Errorf("Failed to delete all objects from %s", datacenters[i])
|
||||
return
|
||||
} else {
|
||||
removeErr := c.Repo.RemoveDatacenter(context.Background(), fileID, datacenters[i])
|
||||
if removeErr != nil {
|
||||
ctxLogger.WithError(removeErr).Error("Failed to remove datacenter from db")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
noDcs, noDcErr := c.Repo.GetDatacenters(context.Background(), fileID)
|
||||
if len(noDcs) > 0 || noDcErr != nil {
|
||||
ctxLogger.Errorf("Failed to delete from all datacenters %s", noDcs)
|
||||
return
|
||||
}
|
||||
err = c.Repo.Delete(fileID)
|
||||
if err != nil {
|
||||
ctxLogger.WithError(err).Error("Failed to remove from db")
|
||||
return
|
||||
}
|
||||
err = c.QueueRepo.DeleteItem(repo.DeleteEmbeddingsQueue, qItem.Item)
|
||||
if err != nil {
|
||||
ctxLogger.WithError(err).Error("Failed to remove item from the queue")
|
||||
return
|
||||
}
|
||||
ctxLogger.Info("Successfully deleted all embeddings")
|
||||
}
|
|
@ -3,11 +3,11 @@ package embedding
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/ente-io/museum/ente"
|
||||
"github.com/ente-io/stacktrace"
|
||||
"github.com/lib/pq"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
@ -18,15 +18,26 @@ type Repository struct {
|
|||
}
|
||||
|
||||
// Create inserts a new embedding
|
||||
|
||||
func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry ente.InsertOrUpdateEmbeddingRequest, size int, version int) (ente.Embedding, error) {
|
||||
func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry ente.InsertOrUpdateEmbeddingRequest, size int, version int, dc string) (ente.Embedding, error) {
|
||||
var updatedAt int64
|
||||
err := r.DB.QueryRowContext(ctx, `INSERT INTO embeddings
|
||||
(file_id, owner_id, model, size, version)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT ON CONSTRAINT unique_embeddings_file_id_model
|
||||
DO UPDATE SET updated_at = now_utc_micro_seconds(), size = $4, version = $5
|
||||
RETURNING updated_at`, entry.FileID, ownerID, entry.Model, size, version).Scan(&updatedAt)
|
||||
err := r.DB.QueryRowContext(ctx, `
|
||||
INSERT INTO embeddings
|
||||
(file_id, owner_id, model, size, version, datacenters)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, ARRAY[$6]::s3region[])
|
||||
ON CONFLICT ON CONSTRAINT unique_embeddings_file_id_model
|
||||
DO UPDATE
|
||||
SET
|
||||
updated_at = now_utc_micro_seconds(),
|
||||
size = $4,
|
||||
version = $5,
|
||||
datacenters = CASE
|
||||
WHEN $6 = ANY(COALESCE(embeddings.datacenters, ARRAY['b2-eu-cen']::s3region[])) THEN embeddings.datacenters
|
||||
ELSE array_append(COALESCE(embeddings.datacenters, ARRAY['b2-eu-cen']::s3region[]), $6::s3region)
|
||||
END
|
||||
RETURNING updated_at`,
|
||||
entry.FileID, ownerID, entry.Model, size, version, dc).Scan(&updatedAt)
|
||||
|
||||
if err != nil {
|
||||
// check if error is due to model enum invalid value
|
||||
if err.Error() == fmt.Sprintf("pq: invalid input value for enum model: \"%s\"", entry.Model) {
|
||||
|
@ -82,6 +93,89 @@ func (r *Repository) Delete(fileID int64) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// GetDatacenters returns unique list of datacenters where derived embeddings are stored
|
||||
func (r *Repository) GetDatacenters(ctx context.Context, fileID int64) ([]string, error) {
|
||||
rows, err := r.DB.QueryContext(ctx, `SELECT datacenters FROM embeddings WHERE file_id = $1`, fileID)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
uniqueDatacenters := make(map[string]struct{})
|
||||
for rows.Next() {
|
||||
var datacenters []string
|
||||
err = rows.Scan(pq.Array(&datacenters))
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
for _, dc := range datacenters {
|
||||
uniqueDatacenters[dc] = struct{}{}
|
||||
}
|
||||
}
|
||||
datacenters := make([]string, 0, len(uniqueDatacenters))
|
||||
for dc := range uniqueDatacenters {
|
||||
datacenters = append(datacenters, dc)
|
||||
}
|
||||
return datacenters, nil
|
||||
}
|
||||
|
||||
// GetOtherDCsForFileAndModel returns the list of datacenters where the embeddings are stored for a given file and model, excluding the ignoredDC
|
||||
func (r *Repository) GetOtherDCsForFileAndModel(ctx context.Context, fileID int64, model string, ignoredDC string) ([]string, error) {
|
||||
rows, err := r.DB.QueryContext(ctx, `SELECT datacenters FROM embeddings WHERE file_id = $1 AND model = $2`, fileID, model)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
uniqueDatacenters := make(map[string]bool)
|
||||
for rows.Next() {
|
||||
var datacenters []string
|
||||
err = rows.Scan(pq.Array(&datacenters))
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
for _, dc := range datacenters {
|
||||
// add to uniqueDatacenters if it is not the ignoredDC
|
||||
if dc != ignoredDC {
|
||||
uniqueDatacenters[dc] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
datacenters := make([]string, 0, len(uniqueDatacenters))
|
||||
for dc := range uniqueDatacenters {
|
||||
datacenters = append(datacenters, dc)
|
||||
}
|
||||
return datacenters, nil
|
||||
}
|
||||
|
||||
// RemoveDatacenter removes the given datacenter from the list of datacenters
|
||||
func (r *Repository) RemoveDatacenter(ctx context.Context, fileID int64, dc string) error {
|
||||
_, err := r.DB.ExecContext(ctx, `UPDATE embeddings SET datacenters = array_remove(datacenters, $1) WHERE file_id = $2`, dc, fileID)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddNewDC adds the dc name to the list of datacenters, if it doesn't exist already, for a given file, model and user. It also updates the size of the embedding
|
||||
func (r *Repository) AddNewDC(ctx context.Context, fileID int64, model ente.Model, userID int64, size int, dc string) error {
|
||||
res, err := r.DB.ExecContext(ctx, `
|
||||
UPDATE embeddings
|
||||
SET size = $1,
|
||||
datacenters = CASE
|
||||
WHEN $2::s3region = ANY(datacenters) THEN datacenters
|
||||
ELSE array_append(datacenters, $2::s3region)
|
||||
END
|
||||
WHERE file_id = $3 AND model = $4 AND owner_id = $5`, size, dc, fileID, model, userID)
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return stacktrace.Propagate(err, "")
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
return stacktrace.Propagate(errors.New("no row got updated"), "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) {
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
|
|
|
@ -28,6 +28,8 @@ type S3Config struct {
|
|||
hotDC string
|
||||
// Secondary (hot) data center
|
||||
secondaryHotDC string
|
||||
//Derived data data center for derived files like ml embeddings & preview files
|
||||
derivedStorageDC string
|
||||
// A map from data centers to S3 configurations
|
||||
s3Configs map[string]*aws.Config
|
||||
// A map from data centers to pre-created S3 clients
|
||||
|
@ -71,6 +73,7 @@ var (
|
|||
dcWasabiEuropeCentralDeprecated string = "wasabi-eu-central-2"
|
||||
dcWasabiEuropeCentral_v3 string = "wasabi-eu-central-2-v3"
|
||||
dcSCWEuropeFrance_v3 string = "scw-eu-fr-v3"
|
||||
dcWasabiEuropeCentralDerived string = "wasabi-eu-central-2-derived"
|
||||
)
|
||||
|
||||
// Number of days that the wasabi bucket is configured to retain objects.
|
||||
|
@ -86,9 +89,9 @@ func NewS3Config() *S3Config {
|
|||
}
|
||||
|
||||
func (config *S3Config) initialize() {
|
||||
dcs := [5]string{
|
||||
dcs := [6]string{
|
||||
dcB2EuropeCentral, dcSCWEuropeFranceLockedDeprecated, dcWasabiEuropeCentralDeprecated,
|
||||
dcWasabiEuropeCentral_v3, dcSCWEuropeFrance_v3}
|
||||
dcWasabiEuropeCentral_v3, dcSCWEuropeFrance_v3, dcWasabiEuropeCentralDerived}
|
||||
|
||||
config.hotDC = dcB2EuropeCentral
|
||||
config.secondaryHotDC = dcWasabiEuropeCentral_v3
|
||||
|
@ -99,6 +102,12 @@ func (config *S3Config) initialize() {
|
|||
config.secondaryHotDC = hs2
|
||||
log.Infof("Hot storage: %s (secondary: %s)", hs1, hs2)
|
||||
}
|
||||
config.derivedStorageDC = config.hotDC
|
||||
embeddingsDC := viper.GetString("s3.derived-storage")
|
||||
if embeddingsDC != "" && array.StringInList(embeddingsDC, dcs[:]) {
|
||||
config.derivedStorageDC = embeddingsDC
|
||||
log.Infof("Embeddings bucket: %s", embeddingsDC)
|
||||
}
|
||||
|
||||
config.buckets = make(map[string]string)
|
||||
config.s3Configs = make(map[string]*aws.Config)
|
||||
|
@ -171,6 +180,18 @@ func (config *S3Config) GetHotS3Client() *s3.S3 {
|
|||
return &s3Client
|
||||
}
|
||||
|
||||
// GetDerivedStorageDataCenter returns the name of the datacenter configured
// to hold derived data (e.g. ML embeddings, previews); see derivedStorageDC.
func (config *S3Config) GetDerivedStorageDataCenter() string {
	return config.derivedStorageDC
}
|
||||
// GetDerivedStorageBucket returns the bucket name mapped to the derived
// storage datacenter.
func (config *S3Config) GetDerivedStorageBucket() *string {
	return config.GetBucket(config.derivedStorageDC)
}
|
||||
|
||||
// GetDerivedStorageS3Client returns an S3 client for the derived storage
// datacenter.
func (config *S3Config) GetDerivedStorageS3Client() *s3.S3 {
	s3Client := config.GetS3Client(config.derivedStorageDC)
	return &s3Client
}
|
||||
|
||||
// Return the name of the hot Backblaze data center
|
||||
func (config *S3Config) GetHotBackblazeDC() string {
|
||||
return dcB2EuropeCentral
|
||||
|
@ -181,6 +202,10 @@ func (config *S3Config) GetHotWasabiDC() string {
|
|||
return dcWasabiEuropeCentral_v3
|
||||
}
|
||||
|
||||
// GetWasabiDerivedDC returns the name of the Wasabi derived-data datacenter
// ("wasabi-eu-central-2-derived").
func (config *S3Config) GetWasabiDerivedDC() string {
	return dcWasabiEuropeCentralDerived
}
|
||||
|
||||
// Return the name of the cold Scaleway data center
|
||||
func (config *S3Config) GetColdScalewayDC() string {
|
||||
return dcSCWEuropeFrance_v3
|
||||
|
|
Loading…
Add table
Reference in a new issue