Ensure unique embeddings for a given file and model (#1638)

This commit is contained in:
Vishnu Mohandas 2024-01-06 20:34:02 +05:30 committed by GitHub
commit 0da0df0738
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 327 additions and 10 deletions

View file

@ -32,14 +32,14 @@ class EmbeddingsDB {
/// Stores [embedding] inside a single write transaction and then fires an
/// [EmbeddingUpdatedEvent] on the event bus.
///
/// The write goes through `putByIndex` with the index named
/// [Embedding.index], so the object is keyed by that index. A plain `put`
/// before it would only write the same object a second time within the
/// transaction, so it is not issued.
Future<void> put(Embedding embedding) {
  return _isar.writeTxn(() async {
    await _isar.embeddings.putByIndex(Embedding.index, embedding);
    Bus.instance.fire(EmbeddingUpdatedEvent());
  });
}
/// Stores all of [embeddings] inside a single write transaction and then
/// fires one [EmbeddingUpdatedEvent] on the event bus.
///
/// The write goes through `putAllByIndex` with the index named
/// [Embedding.index], so each object is keyed by that index. A plain `putAll`
/// before it would only write the same objects a second time within the
/// transaction, so it is not issued.
Future<void> putMany(List<Embedding> embeddings) {
  return _isar.writeTxn(() async {
    await _isar.embeddings.putAllByIndex(Embedding.index, embeddings);
    Bus.instance.fire(EmbeddingUpdatedEvent());
  });
}

View file

@ -6,9 +6,12 @@ part 'embedding.g.dart';
@collection
class Embedding {
Id id = Isar.autoIncrement; // you can also use id = null to auto increment
static const index = 'unique_file_model_embedding';
Id id = Isar.autoIncrement;
final int fileID;
@enumerated
@Index(name: index, composite: [CompositeIndex('fileID')], unique: true, replace: true)
final Model model;
final List<double> embedding;
int? updationTime;

View file

@ -44,7 +44,26 @@ const EmbeddingSchema = CollectionSchema(
deserialize: _embeddingDeserialize,
deserializeProp: _embeddingDeserializeProp,
idName: r'id',
indexes: {},
indexes: {
r'unique_file_model_embedding': IndexSchema(
id: 6248303800853228628,
name: r'unique_file_model_embedding',
unique: true,
replace: true,
properties: [
IndexPropertySchema(
name: r'model',
type: IndexType.value,
caseSensitive: false,
),
IndexPropertySchema(
name: r'fileID',
type: IndexType.value,
caseSensitive: false,
)
],
)
},
links: {},
embeddedSchemas: {},
getId: _embeddingGetId,
@ -134,6 +153,95 @@ void _embeddingAttach(IsarCollection<dynamic> col, Id id, Embedding object) {
object.id = id;
}
/// Index-based accessors for the `unique_file_model_embedding` index, whose
/// key is the pair `[model, fileID]`.
extension EmbeddingByIndex on IsarCollection<Embedding> {
  /// Pairs up [modelValues] and [fileIDValues] position by position into the
  /// `[model, fileID]` key lists expected by the bulk index operations.
  ///
  /// Both lists must have the same length (checked with an `assert`).
  static List<List<dynamic>> _pairModelFileIDs(
    List<Model> modelValues,
    List<int> fileIDValues,
  ) {
    final count = modelValues.length;
    assert(fileIDValues.length == count,
        'All index values must have the same length');
    return <List<dynamic>>[
      for (var position = 0; position < count; position++)
        [modelValues[position], fileIDValues[position]],
    ];
  }

  /// Looks up the embedding stored under `[model, fileID]`.
  Future<Embedding?> getByModelFileID(Model model, int fileID) =>
      getByIndex(r'unique_file_model_embedding', [model, fileID]);

  /// Synchronous variant of [getByModelFileID].
  Embedding? getByModelFileIDSync(Model model, int fileID) =>
      getByIndexSync(r'unique_file_model_embedding', [model, fileID]);

  /// Deletes the embedding stored under `[model, fileID]`.
  Future<bool> deleteByModelFileID(Model model, int fileID) =>
      deleteByIndex(r'unique_file_model_embedding', [model, fileID]);

  /// Synchronous variant of [deleteByModelFileID].
  bool deleteByModelFileIDSync(Model model, int fileID) =>
      deleteByIndexSync(r'unique_file_model_embedding', [model, fileID]);

  /// Looks up one embedding per `[modelValues[i], fileIDValues[i]]` key.
  Future<List<Embedding?>> getAllByModelFileID(
      List<Model> modelValues, List<int> fileIDValues) {
    final keys = _pairModelFileIDs(modelValues, fileIDValues);
    return getAllByIndex(r'unique_file_model_embedding', keys);
  }

  /// Synchronous variant of [getAllByModelFileID].
  List<Embedding?> getAllByModelFileIDSync(
      List<Model> modelValues, List<int> fileIDValues) {
    final keys = _pairModelFileIDs(modelValues, fileIDValues);
    return getAllByIndexSync(r'unique_file_model_embedding', keys);
  }

  /// Deletes the embeddings stored under each
  /// `[modelValues[i], fileIDValues[i]]` key.
  Future<int> deleteAllByModelFileID(
      List<Model> modelValues, List<int> fileIDValues) {
    final keys = _pairModelFileIDs(modelValues, fileIDValues);
    return deleteAllByIndex(r'unique_file_model_embedding', keys);
  }

  /// Synchronous variant of [deleteAllByModelFileID].
  int deleteAllByModelFileIDSync(
      List<Model> modelValues, List<int> fileIDValues) {
    final keys = _pairModelFileIDs(modelValues, fileIDValues);
    return deleteAllByIndexSync(r'unique_file_model_embedding', keys);
  }

  /// Writes [object] through the `unique_file_model_embedding` index.
  Future<Id> putByModelFileID(Embedding object) =>
      putByIndex(r'unique_file_model_embedding', object);

  /// Synchronous variant of [putByModelFileID].
  Id putByModelFileIDSync(Embedding object, {bool saveLinks = true}) =>
      putByIndexSync(
        r'unique_file_model_embedding',
        object,
        saveLinks: saveLinks,
      );

  /// Writes all of [objects] through the `unique_file_model_embedding` index.
  Future<List<Id>> putAllByModelFileID(List<Embedding> objects) =>
      putAllByIndex(r'unique_file_model_embedding', objects);

  /// Synchronous variant of [putAllByModelFileID].
  List<Id> putAllByModelFileIDSync(List<Embedding> objects,
          {bool saveLinks = true}) =>
      putAllByIndexSync(
        r'unique_file_model_embedding',
        objects,
        saveLinks: saveLinks,
      );
}
extension EmbeddingQueryWhereSort
on QueryBuilder<Embedding, Embedding, QWhere> {
QueryBuilder<Embedding, Embedding, QAfterWhere> anyId() {
@ -141,6 +249,14 @@ extension EmbeddingQueryWhereSort
return query.addWhereClause(const IdWhereClause.any());
});
}
/// Adds an unconstrained where clause over the `unique_file_model_embedding`
/// index.
QueryBuilder<Embedding, Embedding, QAfterWhere> anyModelFileID() {
  return QueryBuilder.apply(
    this,
    (query) => query.addWhereClause(
      const IndexWhereClause.any(indexName: r'unique_file_model_embedding'),
    ),
  );
}
}
extension EmbeddingQueryWhere
@ -209,6 +325,193 @@ extension EmbeddingQueryWhere
));
});
}
/// Adds an equality where clause on the `unique_file_model_embedding` index
/// that fixes only the `model` component to [model]; no `fileID` value is
/// supplied.
QueryBuilder<Embedding, Embedding, QAfterWhereClause> modelEqualToAnyFileID(
    Model model) {
  return QueryBuilder.apply(
    this,
    (query) => query.addWhereClause(
      IndexWhereClause.equalTo(
        indexName: r'unique_file_model_embedding',
        value: [model],
      ),
    ),
  );
}
/// Adds two range clauses on the `unique_file_model_embedding` index that
/// together exclude [model]: one strictly below it and one strictly above it.
///
/// The clauses are added in the order that follows the query's current sort
/// direction (below-then-above when ascending, above-then-below otherwise).
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
    modelNotEqualToAnyFileID(Model model) {
  return QueryBuilder.apply(this, (query) {
    final belowModel = IndexWhereClause.between(
      indexName: r'unique_file_model_embedding',
      lower: [],
      upper: [model],
      includeUpper: false,
    );
    final aboveModel = IndexWhereClause.between(
      indexName: r'unique_file_model_embedding',
      lower: [model],
      includeLower: false,
      upper: [],
    );
    final isAscending = query.whereSort == Sort.asc;
    return query
        .addWhereClause(isAscending ? belowModel : aboveModel)
        .addWhereClause(isAscending ? aboveModel : belowModel);
  });
}
/// Adds a range clause on the `unique_file_model_embedding` index whose lower
/// bound is [model] and whose upper bound is left empty.
///
/// [include] controls whether the lower bound itself is included.
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
    modelGreaterThanAnyFileID(
  Model model, {
  bool include = false,
}) {
  return QueryBuilder.apply(
    this,
    (query) => query.addWhereClause(
      IndexWhereClause.between(
        indexName: r'unique_file_model_embedding',
        lower: [model],
        includeLower: include,
        upper: [],
      ),
    ),
  );
}
/// Adds a range clause on the `unique_file_model_embedding` index whose upper
/// bound is [model] and whose lower bound is left empty.
///
/// [include] controls whether the upper bound itself is included.
QueryBuilder<Embedding, Embedding, QAfterWhereClause> modelLessThanAnyFileID(
  Model model, {
  bool include = false,
}) {
  return QueryBuilder.apply(
    this,
    (query) => query.addWhereClause(
      IndexWhereClause.between(
        indexName: r'unique_file_model_embedding',
        lower: [],
        upper: [model],
        includeUpper: include,
      ),
    ),
  );
}
/// Adds a range clause on the `unique_file_model_embedding` index bounded by
/// [lowerModel] and [upperModel] on the `model` component; no `fileID` values
/// are supplied.
///
/// [includeLower] and [includeUpper] control whether each bound is included.
QueryBuilder<Embedding, Embedding, QAfterWhereClause> modelBetweenAnyFileID(
  Model lowerModel,
  Model upperModel, {
  bool includeLower = true,
  bool includeUpper = true,
}) {
  return QueryBuilder.apply(
    this,
    (query) => query.addWhereClause(
      IndexWhereClause.between(
        indexName: r'unique_file_model_embedding',
        lower: [lowerModel],
        includeLower: includeLower,
        upper: [upperModel],
        includeUpper: includeUpper,
      ),
    ),
  );
}
/// Adds an equality where clause on the `unique_file_model_embedding` index
/// for the full key `[model, fileID]`.
QueryBuilder<Embedding, Embedding, QAfterWhereClause> modelFileIDEqualTo(
    Model model, int fileID) {
  return QueryBuilder.apply(
    this,
    (query) => query.addWhereClause(
      IndexWhereClause.equalTo(
        indexName: r'unique_file_model_embedding',
        value: [model, fileID],
      ),
    ),
  );
}
/// Adds two range clauses on the `unique_file_model_embedding` index that
/// stay within [model] but exclude the key `[model, fileID]`: one ending
/// strictly before it and one starting strictly after it.
///
/// The clauses are added in the order that follows the query's current sort
/// direction (before-then-after when ascending, after-then-before otherwise).
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
    modelEqualToFileIDNotEqualTo(Model model, int fileID) {
  return QueryBuilder.apply(this, (query) {
    final beforeFileID = IndexWhereClause.between(
      indexName: r'unique_file_model_embedding',
      lower: [model],
      upper: [model, fileID],
      includeUpper: false,
    );
    final afterFileID = IndexWhereClause.between(
      indexName: r'unique_file_model_embedding',
      lower: [model, fileID],
      includeLower: false,
      upper: [model],
    );
    final isAscending = query.whereSort == Sort.asc;
    return query
        .addWhereClause(isAscending ? beforeFileID : afterFileID)
        .addWhereClause(isAscending ? afterFileID : beforeFileID);
  });
}
/// Adds a range clause on the `unique_file_model_embedding` index that starts
/// at the key `[model, fileID]` and is capped by the `[model]` prefix.
///
/// [include] controls whether the `[model, fileID]` bound itself is included.
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
    modelEqualToFileIDGreaterThan(
  Model model,
  int fileID, {
  bool include = false,
}) {
  return QueryBuilder.apply(
    this,
    (query) => query.addWhereClause(
      IndexWhereClause.between(
        indexName: r'unique_file_model_embedding',
        lower: [model, fileID],
        includeLower: include,
        upper: [model],
      ),
    ),
  );
}
/// Adds a range clause on the `unique_file_model_embedding` index that starts
/// at the `[model]` prefix and ends at the key `[model, fileID]`.
///
/// [include] controls whether the `[model, fileID]` bound itself is included.
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
    modelEqualToFileIDLessThan(
  Model model,
  int fileID, {
  bool include = false,
}) {
  return QueryBuilder.apply(
    this,
    (query) => query.addWhereClause(
      IndexWhereClause.between(
        indexName: r'unique_file_model_embedding',
        lower: [model],
        upper: [model, fileID],
        includeUpper: include,
      ),
    ),
  );
}
/// Adds a range clause on the `unique_file_model_embedding` index from the
/// key `[model, lowerFileID]` to the key `[model, upperFileID]`.
///
/// [includeLower] and [includeUpper] control whether each bound is included.
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
    modelEqualToFileIDBetween(
  Model model,
  int lowerFileID,
  int upperFileID, {
  bool includeLower = true,
  bool includeUpper = true,
}) {
  return QueryBuilder.apply(
    this,
    (query) => query.addWhereClause(
      IndexWhereClause.between(
        indexName: r'unique_file_model_embedding',
        lower: [model, lowerFileID],
        includeLower: includeLower,
        upper: [model, upperFileID],
        includeUpper: includeUpper,
      ),
    ),
  );
}
}
extension EmbeddingQueryFilter

View file

@ -61,6 +61,11 @@ class EmbeddingStore {
unawaited(_pushEmbedding(file, embedding));
}
/// Deletes every embedding for [model] from [EmbeddingsDB], then removes the
/// `kEmbeddingsSyncTimeKey` entry from preferences.
Future<void> clearEmbeddings(Model model) async {
  final embeddingsDB = EmbeddingsDB.instance;
  await embeddingsDB.deleteAllForModel(model);
  await _preferences.remove(kEmbeddingsSyncTimeKey);
}
Future<void> _pushEmbedding(EnteFile file, Embedding embedding) async {
final encryptionKey = getFileKey(file);
final embeddingJSON = jsonEncode(embedding.embedding);

View file

@ -51,7 +51,7 @@ class SemanticSearchService {
// Exposes the private `_hasInitialized` field.
// NOTE(review): the getter has no declared return type; presumably this is a
// bool flag — confirm against the field's declaration.
get hasInitialized => _hasInitialized;
Future<void> init() async {
Future<void> init({bool shouldSyncImmediately = false}) async {
if (!LocalSettings.instance.hasEnabledMagicSearch()) {
return;
}
@ -87,6 +87,9 @@ class SemanticSearchService {
Bus.instance.on<FileUploadedEvent>().listen((event) async {
_addToQueue(event.file);
});
if (shouldSyncImmediately) {
unawaited(sync());
}
}
Future<void> release() async {
@ -146,7 +149,7 @@ class SemanticSearchService {
}
/// Clears everything stored for [kCurrentModel] and logs the result.
///
/// Delegates to `EmbeddingStore.instance.clearEmbeddings`, which already
/// calls `EmbeddingsDB.instance.deleteAllForModel` and also removes the
/// embeddings sync-time preference. Calling `deleteAllForModel` here as well
/// would just run the same delete twice.
Future<void> clearIndexes() async {
  await EmbeddingStore.instance.clearEmbeddings(kCurrentModel);
  _logger.info("Indexes cleared for $kCurrentModel");
}

View file

@ -1,11 +1,11 @@
import "dart:async";
import "package:flutter/foundation.dart";
import "package:flutter/material.dart";
import "package:intl/intl.dart";
import "package:photos/core/event_bus.dart";
import 'package:photos/events/embedding_updated_event.dart';
import "package:photos/generated/l10n.dart";
import "package:photos/services/feature_flag_service.dart";
import "package:photos/services/semantic_search/semantic_search_service.dart";
import "package:photos/theme/ente_theme.dart";
import "package:photos/ui/common/loading_widget.dart";
@ -92,7 +92,10 @@ class _MachineLearningSettingsPageState
!LocalSettings.instance.hasEnabledMagicSearch(),
);
if (LocalSettings.instance.hasEnabledMagicSearch()) {
unawaited(SemanticSearchService.instance.init());
unawaited(
SemanticSearchService.instance
.init(shouldSyncImmediately: true),
);
} else {
await SemanticSearchService.instance.clearQueue();
}
@ -129,7 +132,7 @@ class _MachineLearningSettingsPageState
const SizedBox(
height: 12,
),
kDebugMode
FeatureFlagService.instance.isInternalUserOrDebugBuild()
? MenuItemWidget(
leadingIcon: Icons.delete_sweep_outlined,
captionedTextWidget: CaptionedTextWidget(

View file

@ -12,7 +12,7 @@ description: ente photos application
# Read more about iOS versioning at
# https://developer.apple.com/library/archive/documentation/General/Reference/InfoPlistKeyReference/Articles/CoreFoundationKeys.html
version: 0.8.31+551
version: 0.8.32+552
environment:
sdk: ">=3.0.0 <4.0.0"