diff --git a/server/ente/embedding.go b/server/ente/embedding.go index 2990a779a..fabde44a5 100644 --- a/server/ente/embedding.go +++ b/server/ente/embedding.go @@ -7,6 +7,7 @@ type Embedding struct { DecryptionHeader string `json:"decryptionHeader"` UpdatedAt int64 `json:"updatedAt"` Version *int `json:"version,omitempty"` + Size *int64 } type InsertOrUpdateEmbeddingRequest struct { @@ -30,9 +31,10 @@ type GetFilesEmbeddingRequest struct { } type GetFilesEmbeddingResponse struct { - Embeddings []Embedding `json:"embeddings"` - NoDataFileIDs []int64 `json:"noDataFileIDs"` - ErrFileIDs []int64 `json:"errFileIDs"` + Embeddings []Embedding `json:"embeddings"` + PendingIndexFileIDs []int64 `json:"pendingIndexFileIDs"` + ErrFileIDs []int64 `json:"errFileIDs"` + NoEmbeddingFileIDs []int64 `json:"noEmbeddingFileIDs"` } type Model string diff --git a/server/pkg/controller/embedding/controller.go b/server/pkg/controller/embedding/controller.go index 342411ea3..a30043e7f 100644 --- a/server/pkg/controller/embedding/controller.go +++ b/server/pkg/controller/embedding/controller.go @@ -26,6 +26,11 @@ import ( log "github.com/sirupsen/logrus" ) +const ( + // minEmbeddingDataSize is the min size of an embedding object in bytes + minEmbeddingDataSize = 2048 +) + type Controller struct { Repo *embedding.Repository AccessCtrl access.Controller @@ -135,15 +140,23 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbedd return nil, stacktrace.Propagate(err, "") } + embeddingsWithData := make([]ente.Embedding, 0) + noEmbeddingFileIds := make([]int64, 0) dbFileIds := make([]int64, 0) - for _, embedding := range userFileEmbeddings { - dbFileIds = append(dbFileIds, embedding.FileID) + // fileIDs that were indexed but they don't contain any embedding information + for i := range userFileEmbeddings { + dbFileIds = append(dbFileIds, userFileEmbeddings[i].FileID) + if userFileEmbeddings[i].Size != nil && *userFileEmbeddings[i].Size < minEmbeddingDataSize { + 
noEmbeddingFileIds = append(noEmbeddingFileIds, userFileEmbeddings[i].FileID) + } else { + embeddingsWithData = append(embeddingsWithData, userFileEmbeddings[i]) + } } - missingFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds) + pendingIndexFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds) errFileIds := make([]int64, 0) // Fetch missing userFileEmbeddings in parallel - embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, userFileEmbeddings) + embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData) if err != nil { return nil, stacktrace.Propagate(err, "") } @@ -166,9 +179,10 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbedd } return &ente.GetFilesEmbeddingResponse{ - Embeddings: fetchedEmbeddings, - NoDataFileIDs: missingFileIds, - ErrFileIDs: errFileIds, + Embeddings: fetchedEmbeddings, + PendingIndexFileIDs: pendingIndexFileIds, + ErrFileIDs: errFileIds, + NoEmbeddingFileIDs: noEmbeddingFileIds, }, nil } diff --git a/server/pkg/repo/embedding/repository.go b/server/pkg/repo/embedding/repository.go index f21e3b4f1..86915fde5 100644 --- a/server/pkg/repo/embedding/repository.go +++ b/server/pkg/repo/embedding/repository.go @@ -45,7 +45,7 @@ func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry en // GetDiff returns the embeddings that have been updated since the given time func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Model, sinceTime int64, limit int16) ([]ente.Embedding, error) { - rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version + rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version, size FROM embeddings WHERE owner_id = $1 AND model = $2 AND updated_at > $3 ORDER BY updated_at ASC @@ -57,7 +57,7 @@ func (r *Repository) GetDiff(ctx context.Context, ownerID 
int64, model ente.Mode } func (r *Repository) GetFilesEmbedding(ctx context.Context, ownerID int64, model ente.Model, fileIDs []int64) ([]ente.Embedding, error) { - rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version + rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version, size FROM embeddings WHERE owner_id = $1 AND model = $2 AND file_id = ANY($3)`, ownerID, model, pq.Array(fileIDs)) if err != nil { @@ -94,7 +94,7 @@ func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) { embedding := ente.Embedding{} var encryptedEmbedding, decryptionHeader sql.NullString var version sql.NullInt32 - err := rows.Scan(&embedding.FileID, &embedding.Model, &encryptedEmbedding, &decryptionHeader, &embedding.UpdatedAt, &version) + err := rows.Scan(&embedding.FileID, &embedding.Model, &encryptedEmbedding, &decryptionHeader, &embedding.UpdatedAt, &version, &embedding.Size) if encryptedEmbedding.Valid && len(encryptedEmbedding.String) > 0 { embedding.EncryptedEmbedding = encryptedEmbedding.String }