[server] Avoid pulling files with no-embedding info

This commit is contained in:
Neeraj Gupta 2024-05-10 16:09:49 +05:30
parent c29db9bcfb
commit 87a2f0e0df
3 changed files with 29 additions and 13 deletions

View file

@ -7,6 +7,7 @@ type Embedding struct {
DecryptionHeader string `json:"decryptionHeader"`
UpdatedAt int64 `json:"updatedAt"`
Version *int `json:"version,omitempty"`
Size *int64
}
type InsertOrUpdateEmbeddingRequest struct {
@ -30,9 +31,10 @@ type GetFilesEmbeddingRequest struct {
}
type GetFilesEmbeddingResponse struct {
Embeddings []Embedding `json:"embeddings"`
NoDataFileIDs []int64 `json:"noDataFileIDs"`
ErrFileIDs []int64 `json:"errFileIDs"`
Embeddings []Embedding `json:"embeddings"`
PendingIndexFileIDs []int64 `json:"pendingIndexFileIDs"`
ErrFileIDs []int64 `json:"errFileIDs"`
NoEmbeddingFileIDs []int64 `json:"noEmbeddingFileIDs"`
}
type Model string

View file

@ -26,6 +26,11 @@ import (
log "github.com/sirupsen/logrus"
)
const (
// maxEmbeddingDataSize is the min size of an embedding object in bytes
minEmbeddingDataSize = 2048
)
type Controller struct {
Repo *embedding.Repository
AccessCtrl access.Controller
@ -135,15 +140,23 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbedd
return nil, stacktrace.Propagate(err, "")
}
embeddingsWithData := make([]ente.Embedding, 0)
noEmbeddingFileIds := make([]int64, 0)
dbFileIds := make([]int64, 0)
for _, embedding := range userFileEmbeddings {
dbFileIds = append(dbFileIds, embedding.FileID)
// fileIDs that were indexed but they don't contain any embedding information
for i, _ := range userFileEmbeddings {
dbFileIds = append(dbFileIds, userFileEmbeddings[i].FileID)
if userFileEmbeddings[i].Size != nil && *userFileEmbeddings[i].Size < minEmbeddingDataSize {
noEmbeddingFileIds = append(noEmbeddingFileIds, userFileEmbeddings[i].FileID)
} else {
embeddingsWithData = append(embeddingsWithData, userFileEmbeddings[i])
}
}
missingFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds)
pendingIndexFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds)
errFileIds := make([]int64, 0)
// Fetch missing userFileEmbeddings in parallel
embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, userFileEmbeddings)
embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData)
if err != nil {
return nil, stacktrace.Propagate(err, "")
}
@ -166,9 +179,10 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbedd
}
return &ente.GetFilesEmbeddingResponse{
Embeddings: fetchedEmbeddings,
NoDataFileIDs: missingFileIds,
ErrFileIDs: errFileIds,
Embeddings: fetchedEmbeddings,
PendingIndexFileIDs: pendingIndexFileIds,
ErrFileIDs: errFileIds,
NoEmbeddingFileIDs: noEmbeddingFileIds,
}, nil
}

View file

@ -45,7 +45,7 @@ func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry en
// GetDiff returns the embeddings that have been updated since the given time
func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Model, sinceTime int64, limit int16) ([]ente.Embedding, error) {
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version, size
FROM embeddings
WHERE owner_id = $1 AND model = $2 AND updated_at > $3
ORDER BY updated_at ASC
@ -57,7 +57,7 @@ func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Mode
}
func (r *Repository) GetFilesEmbedding(ctx context.Context, ownerID int64, model ente.Model, fileIDs []int64) ([]ente.Embedding, error) {
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version, size
FROM embeddings
WHERE owner_id = $1 AND model = $2 AND file_id = ANY($3)`, ownerID, model, pq.Array(fileIDs))
if err != nil {
@ -94,7 +94,7 @@ func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) {
embedding := ente.Embedding{}
var encryptedEmbedding, decryptionHeader sql.NullString
var version sql.NullInt32
err := rows.Scan(&embedding.FileID, &embedding.Model, &encryptedEmbedding, &decryptionHeader, &embedding.UpdatedAt, &version)
err := rows.Scan(&embedding.FileID, &embedding.Model, &encryptedEmbedding, &decryptionHeader, &embedding.UpdatedAt, &version, &embedding.Size)
if encryptedEmbedding.Valid && len(encryptedEmbedding.String) > 0 {
embedding.EncryptedEmbedding = encryptedEmbedding.String
}