diff --git a/lib/db/embeddings_db.dart b/lib/db/embeddings_db.dart index d1a71609e..eeb5b15c6 100644 --- a/lib/db/embeddings_db.dart +++ b/lib/db/embeddings_db.dart @@ -30,17 +30,16 @@ class EmbeddingsDB { return _isar.embeddings.filter().modelEqualTo(model).findAll(); } - Future put(Embedding embedding) { + Future put(Embedding embedding) { return _isar.writeTxn(() async { - final id = await _isar.embeddings.put(embedding); + await _isar.embeddings.putByIndex(Embedding.index, embedding); Bus.instance.fire(EmbeddingUpdatedEvent()); - return id; }); } Future putMany(List embeddings) { return _isar.writeTxn(() async { - await _isar.embeddings.putAll(embeddings); + await _isar.embeddings.putAllByIndex(Embedding.index, embeddings); Bus.instance.fire(EmbeddingUpdatedEvent()); }); } diff --git a/lib/models/embedding.dart b/lib/models/embedding.dart index ddf6c33f8..1f78687b9 100644 --- a/lib/models/embedding.dart +++ b/lib/models/embedding.dart @@ -6,9 +6,12 @@ part 'embedding.g.dart'; @collection class Embedding { - Id id = Isar.autoIncrement; // you can also use id = null to auto increment + static const index = 'unique_file_model_embedding'; + + Id id = Isar.autoIncrement; final int fileID; @enumerated + @Index(name: index, composite: [CompositeIndex('fileID')], unique: true, replace: true) final Model model; final List embedding; int? updationTime; diff --git a/lib/models/embedding.g.dart b/lib/models/embedding.g.dart index 3f8fcfa07..ca041a0d0 100644 --- a/lib/models/embedding.g.dart +++ b/lib/models/embedding.g.dart @@ -44,7 +44,26 @@ const EmbeddingSchema = CollectionSchema( deserialize: _embeddingDeserialize, deserializeProp: _embeddingDeserializeProp, idName: r'id', - indexes: {}, + indexes: { + r'unique_file_model_embedding': IndexSchema( + id: 6248303800853228628, + name: r'unique_file_model_embedding', + unique: true, + replace: true, + properties: [ + IndexPropertySchema( + name: r'model', + type: IndexType.value, + caseSensitive: false, + ), + IndexPropertySchema( + name: r'fileID', + type: IndexType.value, + caseSensitive: false, + ) + ], + ) + }, links: {}, embeddedSchemas: {}, getId: _embeddingGetId, @@ -134,6 +153,95 @@ void _embeddingAttach(IsarCollection col, Id id, Embedding object) { object.id = id; } +extension EmbeddingByIndex on IsarCollection { + Future getByModelFileID(Model model, int fileID) { + return getByIndex(r'unique_file_model_embedding', [model, fileID]); + } + + Embedding? getByModelFileIDSync(Model model, int fileID) { + return getByIndexSync(r'unique_file_model_embedding', [model, fileID]); + } + + Future deleteByModelFileID(Model model, int fileID) { + return deleteByIndex(r'unique_file_model_embedding', [model, fileID]); + } + + bool deleteByModelFileIDSync(Model model, int fileID) { + return deleteByIndexSync(r'unique_file_model_embedding', [model, fileID]); + } + + Future> getAllByModelFileID( + List modelValues, List fileIDValues) { + final len = modelValues.length; + assert(fileIDValues.length == len, + 'All index values must have the same length'); + final values = >[]; + for (var i = 0; i < len; i++) { + values.add([modelValues[i], fileIDValues[i]]); + } + + return getAllByIndex(r'unique_file_model_embedding', values); + } + + List getAllByModelFileIDSync( + List modelValues, List fileIDValues) { + final len = modelValues.length; + assert(fileIDValues.length == len, + 'All index values must have the same length'); + final values = >[]; + for (var i = 0; i < len; i++) { + values.add([modelValues[i], fileIDValues[i]]); + } + + return getAllByIndexSync(r'unique_file_model_embedding', values); + } + + Future deleteAllByModelFileID( + List modelValues, List fileIDValues) { + final len = modelValues.length; + assert(fileIDValues.length == len, + 'All index values must have the same length'); + final values = >[]; + for (var i = 0; i < len; i++) { + values.add([modelValues[i], fileIDValues[i]]); + } + + return deleteAllByIndex(r'unique_file_model_embedding', values); + } + + int deleteAllByModelFileIDSync( + List modelValues, List fileIDValues) { + final len = modelValues.length; + assert(fileIDValues.length == len, + 'All index values must have the same length'); + final values = >[]; + for (var i = 0; i < len; i++) { + values.add([modelValues[i], fileIDValues[i]]); + } + + return deleteAllByIndexSync(r'unique_file_model_embedding', values); + } + + Future putByModelFileID(Embedding object) { + return putByIndex(r'unique_file_model_embedding', object); + } + + Id putByModelFileIDSync(Embedding object, {bool saveLinks = true}) { + return putByIndexSync(r'unique_file_model_embedding', object, + saveLinks: saveLinks); + } + + Future> putAllByModelFileID(List objects) { + return putAllByIndex(r'unique_file_model_embedding', objects); + } + + List putAllByModelFileIDSync(List objects, + {bool saveLinks = true}) { + return putAllByIndexSync(r'unique_file_model_embedding', objects, + saveLinks: saveLinks); + } +} + extension EmbeddingQueryWhereSort on QueryBuilder { QueryBuilder anyId() { @@ -141,6 +249,14 @@ extension EmbeddingQueryWhereSort return query.addWhereClause(const IdWhereClause.any()); }); } + + QueryBuilder anyModelFileID() { + return QueryBuilder.apply(this, (query) { + return query.addWhereClause( + const IndexWhereClause.any(indexName: r'unique_file_model_embedding'), + ); + }); + } } extension EmbeddingQueryWhere @@ -209,6 +325,193 @@ extension EmbeddingQueryWhere )); }); } + + QueryBuilder modelEqualToAnyFileID( + Model model) { + return QueryBuilder.apply(this, (query) { + return query.addWhereClause(IndexWhereClause.equalTo( + indexName: r'unique_file_model_embedding', + value: [model], + )); + }); + } + + QueryBuilder + modelNotEqualToAnyFileID(Model model) { + return QueryBuilder.apply(this, (query) { + if (query.whereSort == Sort.asc) { + return query + .addWhereClause(IndexWhereClause.between( + indexName: r'unique_file_model_embedding', + lower: [], + upper: [model], + includeUpper: false, + )) + .addWhereClause(IndexWhereClause.between( + indexName: r'unique_file_model_embedding', + lower: [model], + includeLower: false, + upper: [], + )); + } else { + return query + .addWhereClause(IndexWhereClause.between( + indexName: r'unique_file_model_embedding', + lower: [model], + includeLower: false, + upper: [], + )) + .addWhereClause(IndexWhereClause.between( + indexName: r'unique_file_model_embedding', + lower: [], + upper: [model], + includeUpper: false, + )); + } + }); + } + + QueryBuilder + modelGreaterThanAnyFileID( + Model model, { + bool include = false, + }) { + return QueryBuilder.apply(this, (query) { + return query.addWhereClause(IndexWhereClause.between( + indexName: r'unique_file_model_embedding', + lower: [model], + includeLower: include, + upper: [], + )); + }); + } + + QueryBuilder modelLessThanAnyFileID( + Model model, { + bool include = false, + }) { + return QueryBuilder.apply(this, (query) { + return query.addWhereClause(IndexWhereClause.between( + indexName: r'unique_file_model_embedding', + lower: [], + upper: [model], + includeUpper: include, + )); + }); + } + + QueryBuilder modelBetweenAnyFileID( + Model lowerModel, + Model upperModel, { + bool includeLower = true, + bool includeUpper = true, + }) { + return QueryBuilder.apply(this, (query) { + return query.addWhereClause(IndexWhereClause.between( + indexName: r'unique_file_model_embedding', + lower: [lowerModel], + includeLower: includeLower, + upper: [upperModel], + includeUpper: includeUpper, + )); + }); + } + + QueryBuilder modelFileIDEqualTo( + Model model, int fileID) { + return QueryBuilder.apply(this, (query) { + return query.addWhereClause(IndexWhereClause.equalTo( + indexName: r'unique_file_model_embedding', + value: [model, fileID], + )); + }); + } + + QueryBuilder + modelEqualToFileIDNotEqualTo(Model model, int fileID) { + return QueryBuilder.apply(this, (query) { + if (query.whereSort == Sort.asc) { + return query + .addWhereClause(IndexWhereClause.between( + indexName: r'unique_file_model_embedding', + lower: [model], + upper: [model, fileID], + includeUpper: false, + )) + .addWhereClause(IndexWhereClause.between( + indexName: r'unique_file_model_embedding', + lower: [model, fileID], + includeLower: false, + upper: [model], + )); + } else { + return query + .addWhereClause(IndexWhereClause.between( + indexName: r'unique_file_model_embedding', + lower: [model, fileID], + includeLower: false, + upper: [model], + )) + .addWhereClause(IndexWhereClause.between( + indexName: r'unique_file_model_embedding', + lower: [model], + upper: [model, fileID], + includeUpper: false, + )); + } + }); + } + + QueryBuilder + modelEqualToFileIDGreaterThan( + Model model, + int fileID, { + bool include = false, + }) { + return QueryBuilder.apply(this, (query) { + return query.addWhereClause(IndexWhereClause.between( + indexName: r'unique_file_model_embedding', + lower: [model, fileID], + includeLower: include, + upper: [model], + )); + }); + } + + QueryBuilder + modelEqualToFileIDLessThan( + Model model, + int fileID, { + bool include = false, + }) { + return QueryBuilder.apply(this, (query) { + return query.addWhereClause(IndexWhereClause.between( + indexName: r'unique_file_model_embedding', + lower: [model], + upper: [model, fileID], + includeUpper: include, + )); + }); + } + + QueryBuilder + modelEqualToFileIDBetween( + Model model, + int lowerFileID, + int upperFileID, { + bool includeLower = true, + bool includeUpper = true, + }) { + return QueryBuilder.apply(this, (query) { + return query.addWhereClause(IndexWhereClause.between( + indexName: r'unique_file_model_embedding', + lower: [model, lowerFileID], + includeLower: includeLower, + upper: [model, upperFileID], + includeUpper: includeUpper, + )); + }); + } } extension EmbeddingQueryFilter diff --git a/lib/services/semantic_search/embedding_store.dart b/lib/services/semantic_search/embedding_store.dart index 67f4522d7..143cc59d6 100644 --- a/lib/services/semantic_search/embedding_store.dart +++ b/lib/services/semantic_search/embedding_store.dart @@ -57,7 +57,7 @@ class EmbeddingStore { } Future storeEmbedding(EnteFile file, Embedding embedding) async { - embedding.id = await EmbeddingsDB.instance.put(embedding); + await EmbeddingsDB.instance.put(embedding); unawaited(_pushEmbedding(file, embedding)); }