[server] API to fetch ML embeddings for given fileIDs (#1144)
## Description - Also started storing the embedding size and version in the table. This will come handy while comparing overall size when different compression and serialization formats are used.. - Also, it can be used to smartly auto-download the embeddings or inform the user about approximate size when they decide to enable indexing or search on a particular client. ## Tests ✅ Verified that embedding fetch and store is working fine. ✅ Verified that embeddings/diff API is working fine.
This commit is contained in:
commit
449284a6a1
8 changed files with 212 additions and 18 deletions
|
@ -681,6 +681,7 @@ func main() {
|
|||
|
||||
privateAPI.PUT("/embeddings", embeddingHandler.InsertOrUpdate)
|
||||
privateAPI.GET("/embeddings/diff", embeddingHandler.GetDiff)
|
||||
privateAPI.POST("/embeddings/files", embeddingHandler.GetFilesEmbedding)
|
||||
privateAPI.DELETE("/embeddings", embeddingHandler.DeleteAll)
|
||||
|
||||
offerHandler := &api.OfferHandler{Controller: offerController}
|
||||
|
|
|
@ -6,6 +6,7 @@ type Embedding struct {
|
|||
EncryptedEmbedding string `json:"encryptedEmbedding"`
|
||||
DecryptionHeader string `json:"decryptionHeader"`
|
||||
UpdatedAt int64 `json:"updatedAt"`
|
||||
Version *int `json:"version,omitempty"`
|
||||
}
|
||||
|
||||
type InsertOrUpdateEmbeddingRequest struct {
|
||||
|
@ -13,6 +14,7 @@ type InsertOrUpdateEmbeddingRequest struct {
|
|||
Model string `json:"model" binding:"required"`
|
||||
EncryptedEmbedding string `json:"encryptedEmbedding" binding:"required"`
|
||||
DecryptionHeader string `json:"decryptionHeader" binding:"required"`
|
||||
Version *int `json:"version,omitempty"`
|
||||
}
|
||||
|
||||
type GetEmbeddingDiffRequest struct {
|
||||
|
@ -22,11 +24,25 @@ type GetEmbeddingDiffRequest struct {
|
|||
Limit int16 `form:"limit" binding:"required"`
|
||||
}
|
||||
|
||||
type GetFilesEmbeddingRequest struct {
|
||||
Model Model `form:"model" binding:"required"`
|
||||
FileIDs []int64 `form:"fileIDs" binding:"required"`
|
||||
}
|
||||
|
||||
type GetFilesEmbeddingResponse struct {
|
||||
Embeddings []Embedding `json:"embeddings"`
|
||||
NoDataFileIDs []int64 `json:"noDataFileIDs"`
|
||||
ErrFileIDs []int64 `json:"errFileIDs"`
|
||||
}
|
||||
|
||||
type Model string
|
||||
|
||||
const (
|
||||
OnnxClip Model = "onnx-clip"
|
||||
GgmlClip Model = "ggml-clip"
|
||||
|
||||
// FileMlClipFace is a model for face embeddings, it is used in request validation.
|
||||
FileMlClipFace Model = "file-ml-clip-face"
|
||||
)
|
||||
|
||||
type EmbeddingObject struct {
|
||||
|
|
3
server/migrations/81_embeddings_type_and_size.down.sql
Normal file
3
server/migrations/81_embeddings_type_and_size.down.sql
Normal file
|
@ -0,0 +1,3 @@
|
|||
ALTER TABLE embeddings
|
||||
DROP COLUMN IF EXISTS size,
|
||||
DROP COLUMN IF EXISTS version;
|
4
server/migrations/81_embeddings_type_and_size.up.sql
Normal file
4
server/migrations/81_embeddings_type_and_size.up.sql
Normal file
|
@ -0,0 +1,4 @@
|
|||
ALTER TYPE model ADD VALUE IF NOT EXISTS 'file-ml-clip-face';
|
||||
ALTER TABLE embeddings
|
||||
ADD COLUMN size int DEFAULT NULL,
|
||||
ADD COLUMN version int DEFAULT 1;
|
|
@ -50,6 +50,22 @@ func (h *EmbeddingHandler) GetDiff(c *gin.Context) {
|
|||
})
|
||||
}
|
||||
|
||||
// GetFilesEmbedding returns the embeddings for the files
|
||||
func (h *EmbeddingHandler) GetFilesEmbedding(c *gin.Context) {
|
||||
var request ente.GetFilesEmbeddingRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
handler.Error(c,
|
||||
stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("Request binding failed %s", err)))
|
||||
return
|
||||
}
|
||||
resp, err := h.Controller.GetFilesEmbedding(c, request)
|
||||
if err != nil {
|
||||
handler.Error(c, stacktrace.Propagate(err, ""))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// DeleteAll handler for deleting all embeddings for the user
|
||||
func (h *EmbeddingHandler) DeleteAll(c *gin.Context) {
|
||||
err := h.Controller.DeleteAll(c)
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
package embedding
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/ente-io/museum/pkg/utils/array"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
|
@ -57,19 +58,24 @@ func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmb
|
|||
if count < 1 {
|
||||
return nil, stacktrace.Propagate(ente.ErrNotFound, "")
|
||||
}
|
||||
version := 1
|
||||
if req.Version != nil {
|
||||
version = *req.Version
|
||||
}
|
||||
|
||||
obj := ente.EmbeddingObject{
|
||||
Version: 1,
|
||||
Version: version,
|
||||
EncryptedEmbedding: req.EncryptedEmbedding,
|
||||
DecryptionHeader: req.DecryptionHeader,
|
||||
Client: network.GetPrettyUA(ctx.GetHeader("User-Agent")) + "/" + ctx.GetHeader("X-Client-Version"),
|
||||
}
|
||||
err = c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model))
|
||||
if uploadErr != nil {
|
||||
log.Error(uploadErr)
|
||||
return nil, stacktrace.Propagate(uploadErr, "")
|
||||
}
|
||||
embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req)
|
||||
embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version)
|
||||
embedding.Version = &version
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
|
@ -118,6 +124,54 @@ func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest)
|
|||
return embeddings, nil
|
||||
}
|
||||
|
||||
func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbeddingRequest) (*ente.GetFilesEmbeddingResponse, error) {
|
||||
userID := auth.GetUserID(ctx.Request.Header)
|
||||
if err := c._validateGetFileEmbeddingsRequest(ctx, userID, req); err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
|
||||
userFileEmbeddings, err := c.Repo.GetFilesEmbedding(ctx, userID, req.Model, req.FileIDs)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
|
||||
dbFileIds := make([]int64, 0)
|
||||
for _, embedding := range userFileEmbeddings {
|
||||
dbFileIds = append(dbFileIds, embedding.FileID)
|
||||
}
|
||||
missingFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds)
|
||||
errFileIds := make([]int64, 0)
|
||||
|
||||
// Fetch missing userFileEmbeddings in parallel
|
||||
embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, userFileEmbeddings)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
fetchedEmbeddings := make([]ente.Embedding, 0)
|
||||
|
||||
// Populate missing data in userFileEmbeddings from fetched objects
|
||||
for _, obj := range embeddingObjects {
|
||||
if obj.err != nil {
|
||||
errFileIds = append(errFileIds, obj.dbEmbeddingRow.FileID)
|
||||
} else {
|
||||
fetchedEmbeddings = append(fetchedEmbeddings, ente.Embedding{
|
||||
FileID: obj.dbEmbeddingRow.FileID,
|
||||
Model: obj.dbEmbeddingRow.Model,
|
||||
EncryptedEmbedding: obj.embeddingObject.EncryptedEmbedding,
|
||||
DecryptionHeader: obj.embeddingObject.DecryptionHeader,
|
||||
UpdatedAt: obj.dbEmbeddingRow.UpdatedAt,
|
||||
Version: obj.dbEmbeddingRow.Version,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return &ente.GetFilesEmbeddingResponse{
|
||||
Embeddings: fetchedEmbeddings,
|
||||
NoDataFileIDs: missingFileIds,
|
||||
ErrFileIDs: errFileIds,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Controller) DeleteAll(ctx *gin.Context) error {
|
||||
userID := auth.GetUserID(ctx.Request.Header)
|
||||
|
||||
|
@ -202,21 +256,23 @@ func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string
|
|||
return strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/"
|
||||
}
|
||||
|
||||
func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) error {
|
||||
// uploadObject uploads the embedding object to the object store and returns the object size
|
||||
func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) {
|
||||
embeddingObj, _ := json.Marshal(obj)
|
||||
uploader := s3manager.NewUploaderWithClient(c.S3Config.GetHotS3Client())
|
||||
up := s3manager.UploadInput{
|
||||
Bucket: c.S3Config.GetHotBucket(),
|
||||
Key: &key,
|
||||
Body: strings.NewReader(string(embeddingObj)),
|
||||
Body: bytes.NewReader(embeddingObj),
|
||||
}
|
||||
result, err := uploader.Upload(&up)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return stacktrace.Propagate(err, "")
|
||||
return -1, stacktrace.Propagate(err, "")
|
||||
}
|
||||
|
||||
log.Infof("Uploaded to bucket %s", result.Location)
|
||||
return nil
|
||||
return len(embeddingObj), nil
|
||||
}
|
||||
|
||||
var globalFetchSemaphore = make(chan struct{}, 300)
|
||||
|
@ -253,6 +309,44 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.Em
|
|||
return embeddingObjects, nil
|
||||
}
|
||||
|
||||
type embeddingObjectResult struct {
|
||||
embeddingObject ente.EmbeddingObject
|
||||
dbEmbeddingRow ente.Embedding
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
|
||||
var wg sync.WaitGroup
|
||||
embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
|
||||
downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
|
||||
|
||||
for i, dbEmbeddingRow := range dbEmbeddingRows {
|
||||
wg.Add(1)
|
||||
globalFetchSemaphore <- struct{}{} // Acquire from global semaphore
|
||||
go func(i int, dbEmbeddingRow ente.Embedding) {
|
||||
defer wg.Done()
|
||||
defer func() { <-globalFetchSemaphore }() // Release back to global semaphore
|
||||
objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
|
||||
obj, err := c.getEmbeddingObject(objectKey, downloader)
|
||||
if err != nil {
|
||||
log.Error("error fetching embedding object: "+objectKey, err)
|
||||
embeddingObjects[i] = embeddingObjectResult{
|
||||
err: err,
|
||||
dbEmbeddingRow: dbEmbeddingRow,
|
||||
}
|
||||
|
||||
} else {
|
||||
embeddingObjects[i] = embeddingObjectResult{
|
||||
embeddingObject: obj,
|
||||
dbEmbeddingRow: dbEmbeddingRow,
|
||||
}
|
||||
}
|
||||
}(i, dbEmbeddingRow)
|
||||
}
|
||||
wg.Wait()
|
||||
return embeddingObjects, nil
|
||||
}
|
||||
|
||||
func (c *Controller) getEmbeddingObject(objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
|
||||
var obj ente.EmbeddingObject
|
||||
buff := &aws.WriteAtBuffer{}
|
||||
|
@ -271,3 +365,22 @@ func (c *Controller) getEmbeddingObject(objectKey string, downloader *s3manager.
|
|||
}
|
||||
return obj, nil
|
||||
}
|
||||
|
||||
func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error {
|
||||
if req.Model == "" {
|
||||
return ente.NewBadRequestWithMessage("model is required")
|
||||
}
|
||||
if len(req.FileIDs) == 0 {
|
||||
return ente.NewBadRequestWithMessage("fileIDs are required")
|
||||
}
|
||||
if len(req.FileIDs) > 100 {
|
||||
return ente.NewBadRequestWithMessage("fileIDs should be less than or equal to 100")
|
||||
}
|
||||
if err := c.AccessCtrl.VerifyFileOwnership(ctx, &access.VerifyFileOwnershipParams{
|
||||
ActorUserId: userID,
|
||||
FileIDs: req.FileIDs,
|
||||
}); err != nil {
|
||||
return stacktrace.Propagate(err, "User does not own some file(s)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/ente-io/museum/ente"
|
||||
"github.com/ente-io/stacktrace"
|
||||
|
@ -18,14 +19,14 @@ type Repository struct {
|
|||
|
||||
// Create inserts a new embedding
|
||||
|
||||
func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry ente.InsertOrUpdateEmbeddingRequest) (ente.Embedding, error) {
|
||||
func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry ente.InsertOrUpdateEmbeddingRequest, size int, version int) (ente.Embedding, error) {
|
||||
var updatedAt int64
|
||||
err := r.DB.QueryRowContext(ctx, `INSERT INTO embeddings
|
||||
(file_id, owner_id, model)
|
||||
VALUES ($1, $2, $3)
|
||||
(file_id, owner_id, model, size, version)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT ON CONSTRAINT unique_embeddings_file_id_model
|
||||
DO UPDATE SET updated_at = now_utc_micro_seconds()
|
||||
RETURNING updated_at`, entry.FileID, ownerID, entry.Model).Scan(&updatedAt)
|
||||
DO UPDATE SET updated_at = now_utc_micro_seconds(), size = $4, version = $5
|
||||
RETURNING updated_at`, entry.FileID, ownerID, entry.Model, size, version).Scan(&updatedAt)
|
||||
if err != nil {
|
||||
// check if error is due to model enum invalid value
|
||||
if err.Error() == fmt.Sprintf("pq: invalid input value for enum model: \"%s\"", entry.Model) {
|
||||
|
@ -44,7 +45,7 @@ func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry en
|
|||
|
||||
// GetDiff returns the embeddings that have been updated since the given time
|
||||
func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Model, sinceTime int64, limit int16) ([]ente.Embedding, error) {
|
||||
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at
|
||||
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version
|
||||
FROM embeddings
|
||||
WHERE owner_id = $1 AND model = $2 AND updated_at > $3
|
||||
ORDER BY updated_at ASC
|
||||
|
@ -55,6 +56,16 @@ func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Mode
|
|||
return convertRowsToEmbeddings(rows)
|
||||
}
|
||||
|
||||
func (r *Repository) GetFilesEmbedding(ctx context.Context, ownerID int64, model ente.Model, fileIDs []int64) ([]ente.Embedding, error) {
|
||||
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version
|
||||
FROM embeddings
|
||||
WHERE owner_id = $1 AND model = $2 AND file_id = ANY($3)`, ownerID, model, pq.Array(fileIDs))
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
return convertRowsToEmbeddings(rows)
|
||||
}
|
||||
|
||||
func (r *Repository) DeleteAll(ctx context.Context, ownerID int64) error {
|
||||
_, err := r.DB.ExecContext(ctx, "DELETE FROM embeddings WHERE owner_id = $1", ownerID)
|
||||
if err != nil {
|
||||
|
@ -82,13 +93,19 @@ func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) {
|
|||
for rows.Next() {
|
||||
embedding := ente.Embedding{}
|
||||
var encryptedEmbedding, decryptionHeader sql.NullString
|
||||
err := rows.Scan(&embedding.FileID, &embedding.Model, &encryptedEmbedding, &decryptionHeader, &embedding.UpdatedAt)
|
||||
var version sql.NullInt32
|
||||
err := rows.Scan(&embedding.FileID, &embedding.Model, &encryptedEmbedding, &decryptionHeader, &embedding.UpdatedAt, &version)
|
||||
if encryptedEmbedding.Valid && len(encryptedEmbedding.String) > 0 {
|
||||
embedding.EncryptedEmbedding = encryptedEmbedding.String
|
||||
}
|
||||
if decryptionHeader.Valid && len(decryptionHeader.String) > 0 {
|
||||
embedding.DecryptionHeader = decryptionHeader.String
|
||||
}
|
||||
v := 1
|
||||
if version.Valid {
|
||||
v = int(version.Int32)
|
||||
}
|
||||
embedding.Version = &v
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
|
|
|
@ -47,3 +47,27 @@ func Int64InList(a int64, list []int64) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FindMissingElementsInSecondList identifies elements in 'sourceList' that are not present in 'targetList'.
|
||||
// Returns:
|
||||
// - A slice of int64 representing the elements found in 'sourceList' but not in 'targetList'.
|
||||
// If all elements of 'sourceList' are present in 'targetList', an empty slice is returned.
|
||||
//
|
||||
// Example usage:
|
||||
// missingElements := FindMissingElementsInSecondList([]int64{1, 2, 3, 4}, []int64{2, 4, 6})
|
||||
// fmt.Println(missingElements) // Output: [1, 3]
|
||||
func FindMissingElementsInSecondList(sourceList []int64, targetList []int64) []int64 {
|
||||
targetSet := make(map[int64]struct{})
|
||||
for _, item := range targetList {
|
||||
targetSet[item] = struct{}{}
|
||||
}
|
||||
|
||||
var missingElements = make([]int64, 0)
|
||||
for _, item := range sourceList {
|
||||
if _, found := targetSet[item]; !found {
|
||||
missingElements = append(missingElements, item)
|
||||
}
|
||||
}
|
||||
|
||||
return missingElements
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue