[server] Avoid pulling files with no-embedding info
This commit is contained in:
parent
c29db9bcfb
commit
87a2f0e0df
3 changed files with 29 additions and 13 deletions
|
@ -7,6 +7,7 @@ type Embedding struct {
|
|||
DecryptionHeader string `json:"decryptionHeader"`
|
||||
UpdatedAt int64 `json:"updatedAt"`
|
||||
Version *int `json:"version,omitempty"`
|
||||
Size *int64
|
||||
}
|
||||
|
||||
type InsertOrUpdateEmbeddingRequest struct {
|
||||
|
@ -30,9 +31,10 @@ type GetFilesEmbeddingRequest struct {
|
|||
}
|
||||
|
||||
type GetFilesEmbeddingResponse struct {
|
||||
Embeddings []Embedding `json:"embeddings"`
|
||||
NoDataFileIDs []int64 `json:"noDataFileIDs"`
|
||||
ErrFileIDs []int64 `json:"errFileIDs"`
|
||||
Embeddings []Embedding `json:"embeddings"`
|
||||
PendingIndexFileIDs []int64 `json:"pendingIndexFileIDs"`
|
||||
ErrFileIDs []int64 `json:"errFileIDs"`
|
||||
NoEmbeddingFileIDs []int64 `json:"noEmbeddingFileIDs"`
|
||||
}
|
||||
|
||||
type Model string
|
||||
|
|
|
@ -26,6 +26,11 @@ import (
|
|||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxEmbeddingDataSize is the min size of an embedding object in bytes
|
||||
minEmbeddingDataSize = 2048
|
||||
)
|
||||
|
||||
type Controller struct {
|
||||
Repo *embedding.Repository
|
||||
AccessCtrl access.Controller
|
||||
|
@ -135,15 +140,23 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbedd
|
|||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
|
||||
embeddingsWithData := make([]ente.Embedding, 0)
|
||||
noEmbeddingFileIds := make([]int64, 0)
|
||||
dbFileIds := make([]int64, 0)
|
||||
for _, embedding := range userFileEmbeddings {
|
||||
dbFileIds = append(dbFileIds, embedding.FileID)
|
||||
// fileIDs that were indexed but they don't contain any embedding information
|
||||
for i, _ := range userFileEmbeddings {
|
||||
dbFileIds = append(dbFileIds, userFileEmbeddings[i].FileID)
|
||||
if userFileEmbeddings[i].Size != nil && *userFileEmbeddings[i].Size < minEmbeddingDataSize {
|
||||
noEmbeddingFileIds = append(noEmbeddingFileIds, userFileEmbeddings[i].FileID)
|
||||
} else {
|
||||
embeddingsWithData = append(embeddingsWithData, userFileEmbeddings[i])
|
||||
}
|
||||
}
|
||||
missingFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds)
|
||||
pendingIndexFileIds := array.FindMissingElementsInSecondList(req.FileIDs, dbFileIds)
|
||||
errFileIds := make([]int64, 0)
|
||||
|
||||
// Fetch missing userFileEmbeddings in parallel
|
||||
embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, userFileEmbeddings)
|
||||
embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData)
|
||||
if err != nil {
|
||||
return nil, stacktrace.Propagate(err, "")
|
||||
}
|
||||
|
@ -166,9 +179,10 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbedd
|
|||
}
|
||||
|
||||
return &ente.GetFilesEmbeddingResponse{
|
||||
Embeddings: fetchedEmbeddings,
|
||||
NoDataFileIDs: missingFileIds,
|
||||
ErrFileIDs: errFileIds,
|
||||
Embeddings: fetchedEmbeddings,
|
||||
PendingIndexFileIDs: pendingIndexFileIds,
|
||||
ErrFileIDs: errFileIds,
|
||||
NoEmbeddingFileIDs: noEmbeddingFileIds,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -45,7 +45,7 @@ func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry en
|
|||
|
||||
// GetDiff returns the embeddings that have been updated since the given time
|
||||
func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Model, sinceTime int64, limit int16) ([]ente.Embedding, error) {
|
||||
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version
|
||||
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version, size
|
||||
FROM embeddings
|
||||
WHERE owner_id = $1 AND model = $2 AND updated_at > $3
|
||||
ORDER BY updated_at ASC
|
||||
|
@ -57,7 +57,7 @@ func (r *Repository) GetDiff(ctx context.Context, ownerID int64, model ente.Mode
|
|||
}
|
||||
|
||||
func (r *Repository) GetFilesEmbedding(ctx context.Context, ownerID int64, model ente.Model, fileIDs []int64) ([]ente.Embedding, error) {
|
||||
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version
|
||||
rows, err := r.DB.QueryContext(ctx, `SELECT file_id, model, encrypted_embedding, decryption_header, updated_at, version, size
|
||||
FROM embeddings
|
||||
WHERE owner_id = $1 AND model = $2 AND file_id = ANY($3)`, ownerID, model, pq.Array(fileIDs))
|
||||
if err != nil {
|
||||
|
@ -94,7 +94,7 @@ func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) {
|
|||
embedding := ente.Embedding{}
|
||||
var encryptedEmbedding, decryptionHeader sql.NullString
|
||||
var version sql.NullInt32
|
||||
err := rows.Scan(&embedding.FileID, &embedding.Model, &encryptedEmbedding, &decryptionHeader, &embedding.UpdatedAt, &version)
|
||||
err := rows.Scan(&embedding.FileID, &embedding.Model, &encryptedEmbedding, &decryptionHeader, &embedding.UpdatedAt, &version, &embedding.Size)
|
||||
if encryptedEmbedding.Valid && len(encryptedEmbedding.String) > 0 {
|
||||
embedding.EncryptedEmbedding = encryptedEmbedding.String
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue