diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go
index 9403efaa5..5dc1bc0b7 100644
--- a/server/cmd/museum/main.go
+++ b/server/cmd/museum/main.go
@@ -679,6 +679,7 @@ func main() {
 
 	privateAPI.PUT("/embeddings", embeddingHandler.InsertOrUpdate)
 	privateAPI.GET("/embeddings/diff", embeddingHandler.GetDiff)
+	privateAPI.GET("/embeddings/files", embeddingHandler.GetFilesEmbedding)
 	privateAPI.DELETE("/embeddings", embeddingHandler.DeleteAll)
 
 	offerHandler := &api.OfferHandler{Controller: offerController}
diff --git a/server/ente/embedding.go b/server/ente/embedding.go
index b59332ec6..2a92adf79 100644
--- a/server/ente/embedding.go
+++ b/server/ente/embedding.go
@@ -1,11 +1,12 @@
 package ente
 
 type Embedding struct {
-	FileID             int64  `json:"fileID"`
-	Model              string `json:"model"`
-	EncryptedEmbedding string `json:"encryptedEmbedding"`
-	DecryptionHeader   string `json:"decryptionHeader"`
-	UpdatedAt          int64  `json:"updatedAt"`
+	FileID             int64   `json:"fileID"`
+	Model              string  `json:"model"`
+	EncryptedEmbedding string  `json:"encryptedEmbedding"`
+	DecryptionHeader   string  `json:"decryptionHeader"`
+	UpdatedAt          int64   `json:"updatedAt"`
+	Client             *string `json:"client,omitempty"`
 }
 
 type InsertOrUpdateEmbeddingRequest struct {
@@ -22,11 +23,27 @@ type GetEmbeddingDiffRequest struct {
 	Limit     int16  `form:"limit" binding:"required"`
 }
 
+// GetFilesEmbeddingRequest holds the query parameters for fetching the
+// embeddings of specific files for a single model.
+type GetFilesEmbeddingRequest struct {
+	Model   Model   `form:"model" binding:"required"`
+	FileIDs []int64 `form:"fileIDs" binding:"required"`
+}
+
+// GetFilesEmbeddingResponse reports the fetched embeddings together with the
+// requested file IDs that yielded none (NoDataFileIDs) or failed (ErrFileIDs).
+type GetFilesEmbeddingResponse struct {
+	Embeddings    []Embedding `json:"embeddings"`
+	NoDataFileIDs []int64     `json:"noDataFileIDs"`
+	ErrFileIDs    []int64     `json:"errFileIDs"`
+}
+
 type Model string
 
 const (
-	OnnxClip Model = "onnx-clip"
-	GgmlClip Model = "ggml-clip"
+	OnnxClip           Model = "onnx-clip"
+	GgmlClip           Model = "ggml-clip"
+	OnnxYolo5MobileNet Model = "onnx-yolo5-mobile"
 )
 
 type EmbeddingObject struct {
diff --git a/server/pkg/api/embedding.go b/server/pkg/api/embedding.go
index 983bed52c..4b072b0b7 100644
--- a/server/pkg/api/embedding.go
+++ b/server/pkg/api/embedding.go
@@ -50,6 +50,22 @@ func (h *EmbeddingHandler) GetDiff(c *gin.Context) {
 	})
 }
 
+// GetFilesEmbedding returns the embeddings for the files
+func (h *EmbeddingHandler) GetFilesEmbedding(c *gin.Context) {
+	var request ente.GetFilesEmbeddingRequest
+	if err := c.ShouldBindQuery(&request); err != nil {
+		handler.Error(c,
+			stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("Request binding failed %s", err)))
+		return
+	}
+	resp, err := h.Controller.GetFilesEmbedding(c, request)
+	if err != nil {
+		handler.Error(c, stacktrace.Propagate(err, ""))
+		return
+	}
+	c.JSON(http.StatusOK, resp)
+}
+
 // DeleteAll handler for deleting all embeddings for the user
 func (h *EmbeddingHandler) DeleteAll(c *gin.Context) {
 	err := h.Controller.DeleteAll(c)
diff --git a/server/pkg/controller/embedding/controller.go b/server/pkg/controller/embedding/controller.go
index ce086aadb..fd124e12a 100644
--- a/server/pkg/controller/embedding/controller.go
+++ b/server/pkg/controller/embedding/controller.go
@@ -1,11 +1,12 @@
 package embedding
 
 import (
+	"bytes"
 	"encoding/json"
 	"errors"
 	"fmt"
+	"github.com/ente-io/museum/pkg/utils/array"
 	"strconv"
-	"strings"
 	"sync"
 
 	"github.com/aws/aws-sdk-go/aws"
@@ -118,6 +119,60 @@ func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest)
 	return embeddings, nil
 }
 
+// GetFilesEmbedding returns the stored embeddings of the requested files for
+// the calling user. IDs that array.FindMissingElementsInSecondList reports as
+// absent from the database rows go to NoDataFileIDs; rows whose object could
+// not be fetched go to ErrFileIDs instead of failing the whole request.
+func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbeddingRequest) (*ente.GetFilesEmbeddingResponse, error) {
+	userID := auth.GetUserID(ctx.Request.Header)
+	if err := c._validateGetFileEmbeddingsRequest(ctx, userID, req); err != nil {
+		return nil, stacktrace.Propagate(err, "")
+	}
+
+	userFileEmbeddings, err := c.Repo.GetFilesEmbedding(ctx, userID, req.Model, req.FileIDs)
+	if err != nil {
+		return nil, stacktrace.Propagate(err, "")
+	}
+
+	dbFileIds := make([]int64, 0, len(userFileEmbeddings))
+	for _, embedding := range userFileEmbeddings {
+		dbFileIds = append(dbFileIds, embedding.FileID)
+	}
+	missingFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds)
+
+	// Fetch the stored embedding object for every database row in parallel.
+	embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, userFileEmbeddings)
+	if err != nil {
+		return nil, stacktrace.Propagate(err, "")
+	}
+
+	// Partition the results: rows whose object fetch failed are reported by
+	// file ID, the rest are merged with their object payload. Both slices are
+	// created non-nil so they encode as [] rather than null.
+	errFileIds := make([]int64, 0)
+	fetchedEmbeddings := make([]ente.Embedding, 0, len(embeddingObjects))
+	for _, obj := range embeddingObjects {
+		if obj.err != nil {
+			errFileIds = append(errFileIds, obj.dbEmbeddingRow.FileID)
+			continue
+		}
+		fetchedEmbeddings = append(fetchedEmbeddings, ente.Embedding{
+			FileID:             obj.dbEmbeddingRow.FileID,
+			Model:              obj.dbEmbeddingRow.Model,
+			EncryptedEmbedding: obj.embeddingObject.EncryptedEmbedding,
+			DecryptionHeader:   obj.embeddingObject.DecryptionHeader,
+			UpdatedAt:          obj.dbEmbeddingRow.UpdatedAt,
+			Client:             obj.dbEmbeddingRow.Client,
+		})
+	}
+
+	return &ente.GetFilesEmbeddingResponse{
+		Embeddings:    fetchedEmbeddings,
+		NoDataFileIDs: missingFileIds,
+		ErrFileIDs:    errFileIds,
+	}, nil
+}
+
 func (c *Controller) DeleteAll(ctx *gin.Context) error {
 	userID := auth.GetUserID(ctx.Request.Header)
 
@@ -208,7 +263,7 @@ func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) error {
 	up := s3manager.UploadInput{
 		Bucket: c.S3Config.GetHotBucket(),
 		Key:    &key,
-		Body:   strings.NewReader(string(embeddingObj)),
+		Body:   bytes.NewReader(embeddingObj),
 	}
 	result, err := uploader.Upload(&up)
 	if err != nil {
@@ -253,6 +308,49 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
 	return embeddingObjects, nil
 }
 
+// embeddingObjectResult pairs a database row with the outcome of fetching its
+// stored object; err is non-nil on failure and embeddingObject is then zero.
+type embeddingObjectResult struct {
+	embeddingObject ente.EmbeddingObject
+	dbEmbeddingRow  ente.Embedding
+	err             error
+}
+
+// getEmbeddingObjectsParallelV2 downloads the embedding object of every row
+// concurrently; each download holds a slot of globalFetchSemaphore. Results
+// keep the order of dbEmbeddingRows. A failed download is recorded in that
+// row's result rather than aborting, so the returned error is always nil.
+func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
+	var wg sync.WaitGroup
+	embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
+	downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
+
+	for i, dbEmbeddingRow := range dbEmbeddingRows {
+		wg.Add(1)
+		globalFetchSemaphore <- struct{}{} // Acquire from global semaphore
+		go func(i int, dbEmbeddingRow ente.Embedding) {
+			defer wg.Done()
+			defer func() { <-globalFetchSemaphore }() // Release back to global semaphore
+			objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
+			obj, err := c.getEmbeddingObject(objectKey, downloader)
+			if err != nil {
+				log.Error("error fetching embedding object: "+objectKey, err)
+				embeddingObjects[i] = embeddingObjectResult{
+					err:            err,
+					dbEmbeddingRow: dbEmbeddingRow,
+				}
+				return
+			}
+			embeddingObjects[i] = embeddingObjectResult{
+				embeddingObject: obj,
+				dbEmbeddingRow:  dbEmbeddingRow,
+			}
+		}(i, dbEmbeddingRow)
+	}
+	wg.Wait()
+	return embeddingObjects, nil
+}
+
 func (c *Controller) getEmbeddingObject(objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
 	var obj ente.EmbeddingObject
 	buff := &aws.WriteAtBuffer{}
@@ -271,3 +369,25 @@ func (c *Controller) getEmbeddingObject(objectKey string, downloader *s3manager.
 	}
 	return obj, nil
 }
+
+// _validateGetFileEmbeddingsRequest rejects requests with an empty model, no
+// file IDs, or more than 100 file IDs, and then checks via AccessCtrl that
+// userID owns every requested file.
+func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error {
+	if req.Model == "" {
+		return ente.NewBadRequestWithMessage("model is required")
+	}
+	if len(req.FileIDs) == 0 {
+		return ente.NewBadRequestWithMessage("fileIDs are required")
+	}
+	if len(req.FileIDs) > 100 {
+		return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 100")
+	}
+	if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
+		ActorUserId: userID,
+		FileIDs:     req.FileIDs,
+	}); err != nil {
+		return stacktrace.Propagate(err, "User does not own some file(s)")
+	}
+	return nil
+}
diff --git a/server/pkg/repo/embedding/repository.go b/server/pkg/repo/embedding/repository.go
index e44753b24..90e8a8264 100644
--- a/server/pkg/repo/embedding/repository.go
+++ b/server/pkg/repo/embedding/repository.go
@@ -55,6 +55,22 @@ func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Mode
 	return convertRowsToEmbeddings(rows)
 }
 
+// GetFilesEmbedding returns the embedding rows of the given model that belong
+// to ownerID and whose file_id is in fileIDs.
+//
+// NOTE(review): database/sql does not convert a []int64 argument on its own;
+// with lib/pq this usually has to be pq.Array(fileIDs) - confirm which driver
+// backs r.DB and that this query has been run against a real database.
+func (r *Repository) GetFilesEmbedding(ctx context.Context, ownerID int64, model ente.Model, fileIDs []int64) ([]ente.Embedding, error) {
+	rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at
+	FROM embeddings
+	WHERE owner_id = $1 AND model = $2 AND file_id = ANY($3)`, ownerID, model, fileIDs)
+	if err != nil {
+		return nil, stacktrace.Propagate(err, "")
+	}
+	return convertRowsToEmbeddings(rows)
+}
+
 func (r *Repository) DeleteAll(ctx context.Context, ownerID int64) error {
 	_, err := r.DB.ExecContext(ctx, "DELETE FROM embeddings WHERE owner_id = $1", ownerID)
 	if err != nil {