Rely on Isar's index to remove duplicates

This commit is contained in:
vishnukvmd 2024-01-06 20:22:03 +05:30
parent 41ae10b454
commit 99575825c0
4 changed files with 312 additions and 7 deletions

View file

@ -30,17 +30,16 @@ class EmbeddingsDB {
return _isar.embeddings.filter().modelEqualTo(model).findAll(); return _isar.embeddings.filter().modelEqualTo(model).findAll();
} }
Future<int> put(Embedding embedding) { Future<void> put(Embedding embedding) {
return _isar.writeTxn(() async { return _isar.writeTxn(() async {
final id = await _isar.embeddings.put(embedding); await _isar.embeddings.putByIndex(Embedding.index, embedding);
Bus.instance.fire(EmbeddingUpdatedEvent()); Bus.instance.fire(EmbeddingUpdatedEvent());
return id;
}); });
} }
Future<void> putMany(List<Embedding> embeddings) { Future<void> putMany(List<Embedding> embeddings) {
return _isar.writeTxn(() async { return _isar.writeTxn(() async {
await _isar.embeddings.putAll(embeddings); await _isar.embeddings.putAllByIndex(Embedding.index, embeddings);
Bus.instance.fire(EmbeddingUpdatedEvent()); Bus.instance.fire(EmbeddingUpdatedEvent());
}); });
} }

View file

@ -6,9 +6,12 @@ part 'embedding.g.dart';
@collection @collection
class Embedding { class Embedding {
Id id = Isar.autoIncrement; // you can also use id = null to auto increment static const index = 'unique_file_model_embedding';
Id id = Isar.autoIncrement;
final int fileID; final int fileID;
@enumerated @enumerated
@Index(name: index, composite: [CompositeIndex('fileID')], unique: true, replace: true)
final Model model; final Model model;
final List<double> embedding; final List<double> embedding;
int? updationTime; int? updationTime;

View file

@ -44,7 +44,26 @@ const EmbeddingSchema = CollectionSchema(
deserialize: _embeddingDeserialize, deserialize: _embeddingDeserialize,
deserializeProp: _embeddingDeserializeProp, deserializeProp: _embeddingDeserializeProp,
idName: r'id', idName: r'id',
indexes: {}, indexes: {
r'unique_file_model_embedding': IndexSchema(
id: 6248303800853228628,
name: r'unique_file_model_embedding',
unique: true,
replace: true,
properties: [
IndexPropertySchema(
name: r'model',
type: IndexType.value,
caseSensitive: false,
),
IndexPropertySchema(
name: r'fileID',
type: IndexType.value,
caseSensitive: false,
)
],
)
},
links: {}, links: {},
embeddedSchemas: {}, embeddedSchemas: {},
getId: _embeddingGetId, getId: _embeddingGetId,
@ -134,6 +153,95 @@ void _embeddingAttach(IsarCollection<dynamic> col, Id id, Embedding object) {
object.id = id; object.id = id;
} }
extension EmbeddingByIndex on IsarCollection<Embedding> {
Future<Embedding?> getByModelFileID(Model model, int fileID) {
return getByIndex(r'unique_file_model_embedding', [model, fileID]);
}
Embedding? getByModelFileIDSync(Model model, int fileID) {
return getByIndexSync(r'unique_file_model_embedding', [model, fileID]);
}
Future<bool> deleteByModelFileID(Model model, int fileID) {
return deleteByIndex(r'unique_file_model_embedding', [model, fileID]);
}
bool deleteByModelFileIDSync(Model model, int fileID) {
return deleteByIndexSync(r'unique_file_model_embedding', [model, fileID]);
}
Future<List<Embedding?>> getAllByModelFileID(
List<Model> modelValues, List<int> fileIDValues) {
final len = modelValues.length;
assert(fileIDValues.length == len,
'All index values must have the same length');
final values = <List<dynamic>>[];
for (var i = 0; i < len; i++) {
values.add([modelValues[i], fileIDValues[i]]);
}
return getAllByIndex(r'unique_file_model_embedding', values);
}
List<Embedding?> getAllByModelFileIDSync(
List<Model> modelValues, List<int> fileIDValues) {
final len = modelValues.length;
assert(fileIDValues.length == len,
'All index values must have the same length');
final values = <List<dynamic>>[];
for (var i = 0; i < len; i++) {
values.add([modelValues[i], fileIDValues[i]]);
}
return getAllByIndexSync(r'unique_file_model_embedding', values);
}
Future<int> deleteAllByModelFileID(
List<Model> modelValues, List<int> fileIDValues) {
final len = modelValues.length;
assert(fileIDValues.length == len,
'All index values must have the same length');
final values = <List<dynamic>>[];
for (var i = 0; i < len; i++) {
values.add([modelValues[i], fileIDValues[i]]);
}
return deleteAllByIndex(r'unique_file_model_embedding', values);
}
int deleteAllByModelFileIDSync(
List<Model> modelValues, List<int> fileIDValues) {
final len = modelValues.length;
assert(fileIDValues.length == len,
'All index values must have the same length');
final values = <List<dynamic>>[];
for (var i = 0; i < len; i++) {
values.add([modelValues[i], fileIDValues[i]]);
}
return deleteAllByIndexSync(r'unique_file_model_embedding', values);
}
Future<Id> putByModelFileID(Embedding object) {
return putByIndex(r'unique_file_model_embedding', object);
}
Id putByModelFileIDSync(Embedding object, {bool saveLinks = true}) {
return putByIndexSync(r'unique_file_model_embedding', object,
saveLinks: saveLinks);
}
Future<List<Id>> putAllByModelFileID(List<Embedding> objects) {
return putAllByIndex(r'unique_file_model_embedding', objects);
}
List<Id> putAllByModelFileIDSync(List<Embedding> objects,
{bool saveLinks = true}) {
return putAllByIndexSync(r'unique_file_model_embedding', objects,
saveLinks: saveLinks);
}
}
extension EmbeddingQueryWhereSort extension EmbeddingQueryWhereSort
on QueryBuilder<Embedding, Embedding, QWhere> { on QueryBuilder<Embedding, Embedding, QWhere> {
QueryBuilder<Embedding, Embedding, QAfterWhere> anyId() { QueryBuilder<Embedding, Embedding, QAfterWhere> anyId() {
@ -141,6 +249,14 @@ extension EmbeddingQueryWhereSort
return query.addWhereClause(const IdWhereClause.any()); return query.addWhereClause(const IdWhereClause.any());
}); });
} }
QueryBuilder<Embedding, Embedding, QAfterWhere> anyModelFileID() {
return QueryBuilder.apply(this, (query) {
return query.addWhereClause(
const IndexWhereClause.any(indexName: r'unique_file_model_embedding'),
);
});
}
} }
extension EmbeddingQueryWhere extension EmbeddingQueryWhere
@ -209,6 +325,193 @@ extension EmbeddingQueryWhere
)); ));
}); });
} }
QueryBuilder<Embedding, Embedding, QAfterWhereClause> modelEqualToAnyFileID(
Model model) {
return QueryBuilder.apply(this, (query) {
return query.addWhereClause(IndexWhereClause.equalTo(
indexName: r'unique_file_model_embedding',
value: [model],
));
});
}
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
modelNotEqualToAnyFileID(Model model) {
return QueryBuilder.apply(this, (query) {
if (query.whereSort == Sort.asc) {
return query
.addWhereClause(IndexWhereClause.between(
indexName: r'unique_file_model_embedding',
lower: [],
upper: [model],
includeUpper: false,
))
.addWhereClause(IndexWhereClause.between(
indexName: r'unique_file_model_embedding',
lower: [model],
includeLower: false,
upper: [],
));
} else {
return query
.addWhereClause(IndexWhereClause.between(
indexName: r'unique_file_model_embedding',
lower: [model],
includeLower: false,
upper: [],
))
.addWhereClause(IndexWhereClause.between(
indexName: r'unique_file_model_embedding',
lower: [],
upper: [model],
includeUpper: false,
));
}
});
}
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
modelGreaterThanAnyFileID(
Model model, {
bool include = false,
}) {
return QueryBuilder.apply(this, (query) {
return query.addWhereClause(IndexWhereClause.between(
indexName: r'unique_file_model_embedding',
lower: [model],
includeLower: include,
upper: [],
));
});
}
QueryBuilder<Embedding, Embedding, QAfterWhereClause> modelLessThanAnyFileID(
Model model, {
bool include = false,
}) {
return QueryBuilder.apply(this, (query) {
return query.addWhereClause(IndexWhereClause.between(
indexName: r'unique_file_model_embedding',
lower: [],
upper: [model],
includeUpper: include,
));
});
}
QueryBuilder<Embedding, Embedding, QAfterWhereClause> modelBetweenAnyFileID(
Model lowerModel,
Model upperModel, {
bool includeLower = true,
bool includeUpper = true,
}) {
return QueryBuilder.apply(this, (query) {
return query.addWhereClause(IndexWhereClause.between(
indexName: r'unique_file_model_embedding',
lower: [lowerModel],
includeLower: includeLower,
upper: [upperModel],
includeUpper: includeUpper,
));
});
}
QueryBuilder<Embedding, Embedding, QAfterWhereClause> modelFileIDEqualTo(
Model model, int fileID) {
return QueryBuilder.apply(this, (query) {
return query.addWhereClause(IndexWhereClause.equalTo(
indexName: r'unique_file_model_embedding',
value: [model, fileID],
));
});
}
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
modelEqualToFileIDNotEqualTo(Model model, int fileID) {
return QueryBuilder.apply(this, (query) {
if (query.whereSort == Sort.asc) {
return query
.addWhereClause(IndexWhereClause.between(
indexName: r'unique_file_model_embedding',
lower: [model],
upper: [model, fileID],
includeUpper: false,
))
.addWhereClause(IndexWhereClause.between(
indexName: r'unique_file_model_embedding',
lower: [model, fileID],
includeLower: false,
upper: [model],
));
} else {
return query
.addWhereClause(IndexWhereClause.between(
indexName: r'unique_file_model_embedding',
lower: [model, fileID],
includeLower: false,
upper: [model],
))
.addWhereClause(IndexWhereClause.between(
indexName: r'unique_file_model_embedding',
lower: [model],
upper: [model, fileID],
includeUpper: false,
));
}
});
}
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
modelEqualToFileIDGreaterThan(
Model model,
int fileID, {
bool include = false,
}) {
return QueryBuilder.apply(this, (query) {
return query.addWhereClause(IndexWhereClause.between(
indexName: r'unique_file_model_embedding',
lower: [model, fileID],
includeLower: include,
upper: [model],
));
});
}
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
modelEqualToFileIDLessThan(
Model model,
int fileID, {
bool include = false,
}) {
return QueryBuilder.apply(this, (query) {
return query.addWhereClause(IndexWhereClause.between(
indexName: r'unique_file_model_embedding',
lower: [model],
upper: [model, fileID],
includeUpper: include,
));
});
}
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
modelEqualToFileIDBetween(
Model model,
int lowerFileID,
int upperFileID, {
bool includeLower = true,
bool includeUpper = true,
}) {
return QueryBuilder.apply(this, (query) {
return query.addWhereClause(IndexWhereClause.between(
indexName: r'unique_file_model_embedding',
lower: [model, lowerFileID],
includeLower: includeLower,
upper: [model, upperFileID],
includeUpper: includeUpper,
));
});
}
} }
extension EmbeddingQueryFilter extension EmbeddingQueryFilter

View file

@ -57,7 +57,7 @@ class EmbeddingStore {
} }
Future<void> storeEmbedding(EnteFile file, Embedding embedding) async { Future<void> storeEmbedding(EnteFile file, Embedding embedding) async {
embedding.id = await EmbeddingsDB.instance.put(embedding); await EmbeddingsDB.instance.put(embedding);
unawaited(_pushEmbedding(file, embedding)); unawaited(_pushEmbedding(file, embedding));
} }