Ensure unique embeddings for a given file and model (#1638)
This commit is contained in:
commit
0da0df0738
7 changed files with 327 additions and 10 deletions
|
@ -32,14 +32,14 @@ class EmbeddingsDB {
|
|||
|
||||
/// Stores [embedding], addressed through the unique index named by
/// [Embedding.index].
///
/// The write runs inside a single Isar write transaction; an
/// [EmbeddingUpdatedEvent] is fired on the bus from within that transaction
/// once the write has completed.
Future<void> put(Embedding embedding) {
  return _isar.writeTxn(() async {
    // A plain `put` before this call would write the same object a second
    // time in the same transaction; the index-aware write is sufficient.
    await _isar.embeddings.putByIndex(Embedding.index, embedding);
    Bus.instance.fire(EmbeddingUpdatedEvent());
  });
}
|
||||
|
||||
/// Stores all [embeddings], addressed through the unique index named by
/// [Embedding.index].
///
/// The batch is written inside a single Isar write transaction; one
/// [EmbeddingUpdatedEvent] is fired on the bus from within that transaction
/// after the batch write has completed.
Future<void> putMany(List<Embedding> embeddings) {
  return _isar.writeTxn(() async {
    // A plain `putAll` before this call would write every object a second
    // time in the same transaction; the index-aware write is sufficient.
    await _isar.embeddings.putAllByIndex(Embedding.index, embeddings);
    Bus.instance.fire(EmbeddingUpdatedEvent());
  });
}
|
||||
|
|
|
@ -6,9 +6,12 @@ part 'embedding.g.dart';
|
|||
|
||||
@collection
|
||||
class Embedding {
|
||||
Id id = Isar.autoIncrement; // you can also use id = null to auto increment
|
||||
static const index = 'unique_file_model_embedding';
|
||||
|
||||
Id id = Isar.autoIncrement;
|
||||
final int fileID;
|
||||
@enumerated
|
||||
@Index(name: index, composite: [CompositeIndex('fileID')], unique: true, replace: true)
|
||||
final Model model;
|
||||
final List<double> embedding;
|
||||
int? updationTime;
|
||||
|
|
|
@ -44,7 +44,26 @@ const EmbeddingSchema = CollectionSchema(
|
|||
deserialize: _embeddingDeserialize,
|
||||
deserializeProp: _embeddingDeserializeProp,
|
||||
idName: r'id',
|
||||
indexes: {},
|
||||
indexes: {
|
||||
r'unique_file_model_embedding': IndexSchema(
|
||||
id: 6248303800853228628,
|
||||
name: r'unique_file_model_embedding',
|
||||
unique: true,
|
||||
replace: true,
|
||||
properties: [
|
||||
IndexPropertySchema(
|
||||
name: r'model',
|
||||
type: IndexType.value,
|
||||
caseSensitive: false,
|
||||
),
|
||||
IndexPropertySchema(
|
||||
name: r'fileID',
|
||||
type: IndexType.value,
|
||||
caseSensitive: false,
|
||||
)
|
||||
],
|
||||
)
|
||||
},
|
||||
links: {},
|
||||
embeddedSchemas: {},
|
||||
getId: _embeddingGetId,
|
||||
|
@ -134,6 +153,95 @@ void _embeddingAttach(IsarCollection<dynamic> col, Id id, Embedding object) {
|
|||
object.id = id;
|
||||
}
|
||||
|
||||
/// Accessors on the [Embedding] collection that address objects through the
/// composite `unique_file_model_embedding` index, whose key is the pair
/// (model, fileID).
///
/// Every member forwards to the matching `*ByIndex` method of the
/// collection; the bulk variants take two parallel lists that are paired up
/// position by position.
extension EmbeddingByIndex on IsarCollection<Embedding> {
  static const _indexName = r'unique_file_model_embedding';

  /// Pairs [modelValues] and [fileIDValues] element-wise into index keys.
  ///
  /// Both lists must have the same length; this is checked with an assert
  /// only, so it is enforced in debug builds.
  static List<List<dynamic>> _indexKeys(
    List<Model> modelValues,
    List<int> fileIDValues,
  ) {
    final len = modelValues.length;
    assert(fileIDValues.length == len,
        'All index values must have the same length');
    return <List<dynamic>>[
      for (var i = 0; i < len; i++) [modelValues[i], fileIDValues[i]],
    ];
  }

  /// Looks up the embedding stored under the key ([model], [fileID]).
  Future<Embedding?> getByModelFileID(Model model, int fileID) =>
      getByIndex(_indexName, [model, fileID]);

  /// Synchronous variant of [getByModelFileID].
  Embedding? getByModelFileIDSync(Model model, int fileID) =>
      getByIndexSync(_indexName, [model, fileID]);

  /// Deletes the embedding stored under the key ([model], [fileID]) and
  /// returns the result of the collection's `deleteByIndex`.
  Future<bool> deleteByModelFileID(Model model, int fileID) =>
      deleteByIndex(_indexName, [model, fileID]);

  /// Synchronous variant of [deleteByModelFileID].
  bool deleteByModelFileIDSync(Model model, int fileID) =>
      deleteByIndexSync(_indexName, [model, fileID]);

  /// Looks up one embedding per (model, fileID) pair formed from
  /// [modelValues] and [fileIDValues].
  Future<List<Embedding?>> getAllByModelFileID(
    List<Model> modelValues,
    List<int> fileIDValues,
  ) =>
      getAllByIndex(_indexName, _indexKeys(modelValues, fileIDValues));

  /// Synchronous variant of [getAllByModelFileID].
  List<Embedding?> getAllByModelFileIDSync(
    List<Model> modelValues,
    List<int> fileIDValues,
  ) =>
      getAllByIndexSync(_indexName, _indexKeys(modelValues, fileIDValues));

  /// Deletes the embeddings addressed by the (model, fileID) pairs formed
  /// from [modelValues] and [fileIDValues]; returns the result of the
  /// collection's `deleteAllByIndex`.
  Future<int> deleteAllByModelFileID(
    List<Model> modelValues,
    List<int> fileIDValues,
  ) =>
      deleteAllByIndex(_indexName, _indexKeys(modelValues, fileIDValues));

  /// Synchronous variant of [deleteAllByModelFileID].
  int deleteAllByModelFileIDSync(
    List<Model> modelValues,
    List<int> fileIDValues,
  ) =>
      deleteAllByIndexSync(_indexName, _indexKeys(modelValues, fileIDValues));

  /// Writes [object] through the index and returns the resulting id.
  Future<Id> putByModelFileID(Embedding object) =>
      putByIndex(_indexName, object);

  /// Synchronous variant of [putByModelFileID]; [saveLinks] is forwarded
  /// unchanged.
  Id putByModelFileIDSync(Embedding object, {bool saveLinks = true}) =>
      putByIndexSync(_indexName, object, saveLinks: saveLinks);

  /// Writes all [objects] through the index and returns the resulting ids.
  Future<List<Id>> putAllByModelFileID(List<Embedding> objects) =>
      putAllByIndex(_indexName, objects);

  /// Synchronous variant of [putAllByModelFileID]; [saveLinks] is forwarded
  /// unchanged.
  List<Id> putAllByModelFileIDSync(
    List<Embedding> objects, {
    bool saveLinks = true,
  }) =>
      putAllByIndexSync(_indexName, objects, saveLinks: saveLinks);
}
|
||||
|
||||
extension EmbeddingQueryWhereSort
|
||||
on QueryBuilder<Embedding, Embedding, QWhere> {
|
||||
QueryBuilder<Embedding, Embedding, QAfterWhere> anyId() {
|
||||
|
@ -141,6 +249,14 @@ extension EmbeddingQueryWhereSort
|
|||
return query.addWhereClause(const IdWhereClause.any());
|
||||
});
|
||||
}
|
||||
|
||||
/// Adds an `any` where clause over the `unique_file_model_embedding` index.
QueryBuilder<Embedding, Embedding, QAfterWhere> anyModelFileID() {
  const wholeIndex =
      IndexWhereClause.any(indexName: r'unique_file_model_embedding');
  return QueryBuilder.apply(
    this,
    (query) => query.addWhereClause(wholeIndex),
  );
}
|
||||
}
|
||||
|
||||
extension EmbeddingQueryWhere
|
||||
|
@ -209,6 +325,193 @@ extension EmbeddingQueryWhere
|
|||
));
|
||||
});
|
||||
}
|
||||
|
||||
/// Adds an equality where clause on the `unique_file_model_embedding` index
/// using only [model] as the key, i.e. without a fileID component.
QueryBuilder<Embedding, Embedding, QAfterWhereClause> modelEqualToAnyFileID(
  Model model,
) {
  return QueryBuilder.apply(this, (query) {
    final sameModel = IndexWhereClause.equalTo(
      indexName: r'unique_file_model_embedding',
      value: [model],
    );
    return query.addWhereClause(sameModel);
  });
}
|
||||
|
||||
/// Adds two range where clauses on the `unique_file_model_embedding` index
/// that together exclude the key `[model]`: one ending just before it and
/// one starting just after it.
///
/// The clauses are added below-then-above when the query's where-sort is
/// ascending, and above-then-below otherwise.
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
    modelNotEqualToAnyFileID(Model model) {
  return QueryBuilder.apply(this, (query) {
    final belowModel = IndexWhereClause.between(
      indexName: r'unique_file_model_embedding',
      lower: [],
      upper: [model],
      includeUpper: false,
    );
    final aboveModel = IndexWhereClause.between(
      indexName: r'unique_file_model_embedding',
      lower: [model],
      includeLower: false,
      upper: [],
    );
    final ascending = query.whereSort == Sort.asc;
    final first = ascending ? belowModel : aboveModel;
    final second = ascending ? aboveModel : belowModel;
    return query.addWhereClause(first).addWhereClause(second);
  });
}
|
||||
|
||||
/// Adds a range where clause on the `unique_file_model_embedding` index
/// whose lower key is `[model]` and whose upper key is empty.
///
/// [include] controls whether the lower key itself is part of the range.
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
    modelGreaterThanAnyFileID(
  Model model, {
  bool include = false,
}) {
  return QueryBuilder.apply(this, (query) {
    final fromModel = IndexWhereClause.between(
      indexName: r'unique_file_model_embedding',
      lower: [model],
      includeLower: include,
      upper: [],
    );
    return query.addWhereClause(fromModel);
  });
}
|
||||
|
||||
/// Adds a range where clause on the `unique_file_model_embedding` index
/// whose lower key is empty and whose upper key is `[model]`.
///
/// [include] controls whether the upper key itself is part of the range.
QueryBuilder<Embedding, Embedding, QAfterWhereClause> modelLessThanAnyFileID(
  Model model, {
  bool include = false,
}) {
  return QueryBuilder.apply(this, (query) {
    final upToModel = IndexWhereClause.between(
      indexName: r'unique_file_model_embedding',
      lower: [],
      upper: [model],
      includeUpper: include,
    );
    return query.addWhereClause(upToModel);
  });
}
|
||||
|
||||
/// Adds a range where clause on the `unique_file_model_embedding` index
/// running from `[lowerModel]` to `[upperModel]`.
///
/// [includeLower] and [includeUpper] control whether each bound is part of
/// the range; both default to `true`.
QueryBuilder<Embedding, Embedding, QAfterWhereClause> modelBetweenAnyFileID(
  Model lowerModel,
  Model upperModel, {
  bool includeLower = true,
  bool includeUpper = true,
}) {
  return QueryBuilder.apply(this, (query) {
    final modelRange = IndexWhereClause.between(
      indexName: r'unique_file_model_embedding',
      lower: [lowerModel],
      includeLower: includeLower,
      upper: [upperModel],
      includeUpper: includeUpper,
    );
    return query.addWhereClause(modelRange);
  });
}
|
||||
|
||||
/// Adds an equality where clause on the `unique_file_model_embedding` index
/// for the full key `[model, fileID]`.
QueryBuilder<Embedding, Embedding, QAfterWhereClause> modelFileIDEqualTo(
  Model model,
  int fileID,
) {
  return QueryBuilder.apply(this, (query) {
    final exactKey = IndexWhereClause.equalTo(
      indexName: r'unique_file_model_embedding',
      value: [model, fileID],
    );
    return query.addWhereClause(exactKey);
  });
}
|
||||
|
||||
/// Adds two range where clauses on the `unique_file_model_embedding` index,
/// both bounded by the key `[model]`, that together exclude the key
/// `[model, fileID]`: one ending just before it and one starting just
/// after it.
///
/// The clauses are added below-then-above when the query's where-sort is
/// ascending, and above-then-below otherwise.
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
    modelEqualToFileIDNotEqualTo(Model model, int fileID) {
  return QueryBuilder.apply(this, (query) {
    final belowFileID = IndexWhereClause.between(
      indexName: r'unique_file_model_embedding',
      lower: [model],
      upper: [model, fileID],
      includeUpper: false,
    );
    final aboveFileID = IndexWhereClause.between(
      indexName: r'unique_file_model_embedding',
      lower: [model, fileID],
      includeLower: false,
      upper: [model],
    );
    final ascending = query.whereSort == Sort.asc;
    final first = ascending ? belowFileID : aboveFileID;
    final second = ascending ? aboveFileID : belowFileID;
    return query.addWhereClause(first).addWhereClause(second);
  });
}
|
||||
|
||||
/// Adds a range where clause on the `unique_file_model_embedding` index
/// running from the key `[model, fileID]` up to the key `[model]`.
///
/// [include] controls whether the lower key itself is part of the range.
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
    modelEqualToFileIDGreaterThan(
  Model model,
  int fileID, {
  bool include = false,
}) {
  return QueryBuilder.apply(this, (query) {
    final fromFileID = IndexWhereClause.between(
      indexName: r'unique_file_model_embedding',
      lower: [model, fileID],
      includeLower: include,
      upper: [model],
    );
    return query.addWhereClause(fromFileID);
  });
}
|
||||
|
||||
/// Adds a range where clause on the `unique_file_model_embedding` index
/// running from the key `[model]` up to the key `[model, fileID]`.
///
/// [include] controls whether the upper key itself is part of the range.
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
    modelEqualToFileIDLessThan(
  Model model,
  int fileID, {
  bool include = false,
}) {
  return QueryBuilder.apply(this, (query) {
    final upToFileID = IndexWhereClause.between(
      indexName: r'unique_file_model_embedding',
      lower: [model],
      upper: [model, fileID],
      includeUpper: include,
    );
    return query.addWhereClause(upToFileID);
  });
}
|
||||
|
||||
/// Adds a range where clause on the `unique_file_model_embedding` index
/// running from `[model, lowerFileID]` to `[model, upperFileID]`.
///
/// [includeLower] and [includeUpper] control whether each bound is part of
/// the range; both default to `true`.
QueryBuilder<Embedding, Embedding, QAfterWhereClause>
    modelEqualToFileIDBetween(
  Model model,
  int lowerFileID,
  int upperFileID, {
  bool includeLower = true,
  bool includeUpper = true,
}) {
  return QueryBuilder.apply(this, (query) {
    final fileIDRange = IndexWhereClause.between(
      indexName: r'unique_file_model_embedding',
      lower: [model, lowerFileID],
      includeLower: includeLower,
      upper: [model, upperFileID],
      includeUpper: includeUpper,
    );
    return query.addWhereClause(fileIDRange);
  });
}
|
||||
}
|
||||
|
||||
extension EmbeddingQueryFilter
|
||||
|
|
|
@ -61,6 +61,11 @@ class EmbeddingStore {
|
|||
unawaited(_pushEmbedding(file, embedding));
|
||||
}
|
||||
|
||||
/// Calls [EmbeddingsDB.deleteAllForModel] for [model] and, once that has
/// completed, removes the `kEmbeddingsSyncTimeKey` entry from the
/// preferences.
Future<void> clearEmbeddings(Model model) async {
  final db = EmbeddingsDB.instance;
  await db.deleteAllForModel(model);
  await _preferences.remove(kEmbeddingsSyncTimeKey);
}
|
||||
|
||||
Future<void> _pushEmbedding(EnteFile file, Embedding embedding) async {
|
||||
final encryptionKey = getFileKey(file);
|
||||
final embeddingJSON = jsonEncode(embedding.embedding);
|
||||
|
|
|
@ -51,7 +51,7 @@ class SemanticSearchService {
|
|||
|
||||
get hasInitialized => _hasInitialized;
|
||||
|
||||
Future<void> init() async {
|
||||
Future<void> init({bool shouldSyncImmediately = false}) async {
|
||||
if (!LocalSettings.instance.hasEnabledMagicSearch()) {
|
||||
return;
|
||||
}
|
||||
|
@ -87,6 +87,9 @@ class SemanticSearchService {
|
|||
Bus.instance.on<FileUploadedEvent>().listen((event) async {
|
||||
_addToQueue(event.file);
|
||||
});
|
||||
if (shouldSyncImmediately) {
|
||||
unawaited(sync());
|
||||
}
|
||||
}
|
||||
|
||||
Future<void> release() async {
|
||||
|
@ -146,7 +149,7 @@ class SemanticSearchService {
|
|||
}
|
||||
|
||||
/// Clears the stored embeddings for [kCurrentModel] via
/// [EmbeddingStore.clearEmbeddings] and logs the outcome.
Future<void> clearIndexes() async {
  // `clearEmbeddings` already calls `EmbeddingsDB.deleteAllForModel`, so a
  // separate direct call here would only repeat the same deletion.
  await EmbeddingStore.instance.clearEmbeddings(kCurrentModel);
  _logger.info("Indexes cleared for $kCurrentModel");
}
|
||||
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import "dart:async";
|
||||
|
||||
import "package:flutter/foundation.dart";
|
||||
import "package:flutter/material.dart";
|
||||
import "package:intl/intl.dart";
|
||||
import "package:photos/core/event_bus.dart";
|
||||
import 'package:photos/events/embedding_updated_event.dart';
|
||||
import "package:photos/generated/l10n.dart";
|
||||
import "package:photos/services/feature_flag_service.dart";
|
||||
import "package:photos/services/semantic_search/semantic_search_service.dart";
|
||||
import "package:photos/theme/ente_theme.dart";
|
||||
import "package:photos/ui/common/loading_widget.dart";
|
||||
|
@ -92,7 +92,10 @@ class _MachineLearningSettingsPageState
|
|||
!LocalSettings.instance.hasEnabledMagicSearch(),
|
||||
);
|
||||
if (LocalSettings.instance.hasEnabledMagicSearch()) {
|
||||
unawaited(SemanticSearchService.instance.init());
|
||||
unawaited(
|
||||
SemanticSearchService.instance
|
||||
.init(shouldSyncImmediately: true),
|
||||
);
|
||||
} else {
|
||||
await SemanticSearchService.instance.clearQueue();
|
||||
}
|
||||
|
@ -129,7 +132,7 @@ class _MachineLearningSettingsPageState
|
|||
const SizedBox(
|
||||
height: 12,
|
||||
),
|
||||
kDebugMode
|
||||
FeatureFlagService.instance.isInternalUserOrDebugBuild()
|
||||
? MenuItemWidget(
|
||||
leadingIcon: Icons.delete_sweep_outlined,
|
||||
captionedTextWidget: CaptionedTextWidget(
|
||||
|
|
|
@ -12,7 +12,7 @@ description: ente photos application
|
|||
# Read more about iOS versioning at
|
||||
# https://developer.apple.com/library/archive/documentation/General/Reference/InfoPlistKeyReference/Articles/CoreFoundationKeys.html
|
||||
|
||||
version: 0.8.31+551
|
||||
version: 0.8.32+552
|
||||
|
||||
environment:
|
||||
sdk: ">=3.0.0 <4.0.0"
|
||||
|
|
Loading…
Add table
Reference in a new issue