[server] Add endpoint to get embedding for givenFilesIds

This commit is contained in:
Neeraj Gupta 2024-03-17 09:37:02 +05:30
parent 36982c5332
commit e927064476
5 changed files with 162 additions and 9 deletions

View file

@ -679,6 +679,7 @@ func main() {
privateAPI.PUT("/embeddings", embeddingHandler.InsertOrUpdate)
privateAPI.GET("/embeddings/diff", embeddingHandler.GetDiff)
privateAPI.GET("/embeddings/files", embeddingHandler.GetFilesEmbedding)
privateAPI.DELETE("/embeddings", embeddingHandler.DeleteAll)
offerHandler := &api.OfferHandler{Controller: offerController}

View file

@ -1,11 +1,12 @@
package ente
type Embedding struct {
FileID int64 `json:"fileID"`
Model string `json:"model"`
EncryptedEmbedding string `json:"encryptedEmbedding"`
DecryptionHeader string `json:"decryptionHeader"`
UpdatedAt int64 `json:"updatedAt"`
FileID int64 `json:"fileID"`
Model string `json:"model"`
EncryptedEmbedding string `json:"encryptedEmbedding"`
DecryptionHeader string `json:"decryptionHeader"`
UpdatedAt int64 `json:"updatedAt"`
Client *string `json:"client,omitempty"`
}
type InsertOrUpdateEmbeddingRequest struct {
@ -22,11 +23,23 @@ type GetEmbeddingDiffRequest struct {
Limit int16 `form:"limit" binding:"required"`
}
type GetFilesEmbeddingRequest struct {
Model Model `form:"model" binding:"required"`
FileIDs []int64 `form:"fileIDs" binding:"required"`
}
type GetFilesEmbeddingResponse struct {
Embeddings []Embedding `json:"embeddings"`
NoDataFileIDs []int64 `json:"noDataFileIDs"`
ErrFileIDs []int64 `json:"errFileIDs"`
}
type Model string
const (
OnnxClip Model = "onnx-clip"
GgmlClip Model = "ggml-clip"
OnnxClip Model = "onnx-clip"
GgmlClip Model = "ggml-clip"
OnnxYolo5MobileNet Model = "onnx-yolo5-mobile"
)
type EmbeddingObject struct {

View file

@ -50,6 +50,22 @@ func (h *EmbeddingHandler) GetDiff(c *gin.Context) {
})
}
// GetFilesEmbedding returns the embeddings for the files
func (h *EmbeddingHandler) GetFilesEmbedding(c *gin.Context) {
var request ente.GetFilesEmbeddingRequest
if err := c.ShouldBindQuery(&request); err != nil {
handler.Error(c,
stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("Request binding failed %s", err)))
return
}
resp, err := h.Controller.GetFilesEmbedding(c, request)
if err != nil {
handler.Error(c, stacktrace.Propagate(err, ""))
return
}
c.JSON(http.StatusOK, resp)
}
// DeleteAll handler for deleting all embeddings for the user
func (h *EmbeddingHandler) DeleteAll(c *gin.Context) {
err := h.Controller.DeleteAll(c)

View file

@ -1,11 +1,12 @@
package embedding
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/ente-io/museum/pkg/utils/array"
"strconv"
"strings"
"sync"
"github.com/aws/aws-sdk-go/aws"
@ -118,6 +119,61 @@ func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest)
return embeddings, nil
}
func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbeddingRequest) (*ente.GetFilesEmbeddingResponse, error) {
userID := auth.GetUserID(ctx.Request.Header)
if err := c._validateGetFileEmbeddingsRequest(ctx, userID, req); err != nil {
return nil, stacktrace.Propagate(err, "")
}
userFileEmbeddings, err := c.Repo.GetFilesEmbedding(ctx, userID, req.Model, req.FileIDs)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
dbFileIds := make([]int64, 0)
for _, embedding := range userFileEmbeddings {
dbFileIds = append(dbFileIds, embedding.FileID)
}
missingFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds)
errFileIds := make([]int64, 0)
// Collect object keys for userFileEmbeddings with missing data
var objectKeys []string
for i := range userFileEmbeddings {
objectKey := c.getObjectKey(userID, userFileEmbeddings[i].FileID, userFileEmbeddings[i].Model)
objectKeys = append(objectKeys, objectKey)
}
// Fetch missing userFileEmbeddings in parallel
embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, userFileEmbeddings)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
fetchedEmbeddings := make([]ente.Embedding, 0)
// Populate missing data in userFileEmbeddings from fetched objects
for _, obj := range embeddingObjects {
if obj.err != nil {
errFileIds = append(errFileIds, obj.dbEmbeddingRow.FileID)
} else {
fetchedEmbeddings = append(fetchedEmbeddings, ente.Embedding{
FileID: obj.dbEmbeddingRow.FileID,
Model: obj.dbEmbeddingRow.Model,
EncryptedEmbedding: obj.embeddingObject.EncryptedEmbedding,
DecryptionHeader: obj.embeddingObject.DecryptionHeader,
UpdatedAt: obj.dbEmbeddingRow.UpdatedAt,
Client: obj.dbEmbeddingRow.Client,
})
}
}
return &ente.GetFilesEmbeddingResponse{
Embeddings: fetchedEmbeddings,
NoDataFileIDs: missingFileIds,
ErrFileIDs: errFileIds,
}, nil
}
func (c *Controller) DeleteAll(ctx *gin.Context) error {
userID := auth.GetUserID(ctx.Request.Header)
@ -208,7 +264,7 @@ func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) error {
up := s3manager.UploadInput{
Bucket: c.S3Config.GetHotBucket(),
Key: &key,
Body: strings.NewReader(string(embeddingObj)),
Body: bytes.NewReader(embeddingObj),
}
result, err := uploader.Upload(&up)
if err != nil {
@ -253,6 +309,44 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
return embeddingObjects, nil
}
type embeddingObjectResult struct {
embeddingObject ente.EmbeddingObject
dbEmbeddingRow ente.Embedding
err error
}
func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
var wg sync.WaitGroup
embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
for i, dbEmbeddingRow := range dbEmbeddingRows {
wg.Add(1)
globalFetchSemaphore <- struct{}{} // Acquire from global semaphore
go func(i int, dbEmbeddingRow ente.Embedding) {
defer wg.Done()
defer func() { <-globalFetchSemaphore }() // Release back to global semaphore
objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
obj, err := c.getEmbeddingObject(objectKey, downloader)
if err != nil {
log.Error("error fetching embedding object: "+objectKey, err)
embeddingObjects[i] = embeddingObjectResult{
err: err,
dbEmbeddingRow: dbEmbeddingRow,
}
} else {
embeddingObjects[i] = embeddingObjectResult{
embeddingObject: obj,
dbEmbeddingRow: dbEmbeddingRow,
}
}
}(i, dbEmbeddingRow)
}
wg.Wait()
return embeddingObjects, nil
}
func (c *Controller) getEmbeddingObject(objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
var obj ente.EmbeddingObject
buff := &aws.WriteAtBuffer{}
@ -271,3 +365,22 @@ func (c *Controller) getEmbeddingObject(objectKey string, downloader *s3manager.
}
return obj, nil
}
func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error {
if req.Model == "" {
return ente.NewBadRequestWithMessage("model is required")
}
if len(req.FileIDs) == 0 {
return ente.NewBadRequestWithMessage("fileIDs are required")
}
if len(req.FileIDs) > 100 {
return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 100")
}
if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
ActorUserId: userID,
FileIDs: req.FileIDs,
}); err != nil {
return stacktrace.Propagate(err, "User does not own some file(s)")
}
return nil
}

View file

@ -55,6 +55,16 @@ func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Mode
return convertRowsToEmbeddings(rows)
}
func (r *Repository) GetFilesEmbedding(ctx context.Context, ownerID int64, model ente.Model, fileIDs []int64) ([]ente.Embedding, error) {
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at
FROM embeddings
WHERE owner_id = $1 AND model = $2 AND file_id = ANY($3)`, ownerID, model, fileIDs)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
return convertRowsToEmbeddings(rows)
}
func (r *Repository) DeleteAll(ctx context.Context, ownerID int64) error {
_, err := r.DB.ExecContext(ctx, "DELETE FROM embeddings WHERE owner_id = $1", ownerID)
if err != nil {