浏览代码

Ensure unique embeddings for a given file and model (#1638)

Vishnu Mohandas 1 年之前
父节点
当前提交
0da0df0738

+ 2 - 2
lib/db/embeddings_db.dart

@@ -32,14 +32,14 @@ class EmbeddingsDB {
 
   Future<void> put(Embedding embedding) {
     return _isar.writeTxn(() async {
-      await _isar.embeddings.put(embedding);
+      await _isar.embeddings.putByIndex(Embedding.index, embedding);
       Bus.instance.fire(EmbeddingUpdatedEvent());
     });
   }
 
   Future<void> putMany(List<Embedding> embeddings) {
     return _isar.writeTxn(() async {
-      await _isar.embeddings.putAll(embeddings);
+      await _isar.embeddings.putAllByIndex(Embedding.index, embeddings);
       Bus.instance.fire(EmbeddingUpdatedEvent());
     });
   }

+ 4 - 1
lib/models/embedding.dart

@@ -6,9 +6,12 @@ part 'embedding.g.dart';
 
 @collection
 class Embedding {
-  Id id = Isar.autoIncrement; // you can also use id = null to auto increment
+  static const index = 'unique_file_model_embedding';
+
+  Id id = Isar.autoIncrement;
   final int fileID;
   @enumerated
+  @Index(name: index, composite: [CompositeIndex('fileID')], unique: true, replace: true)
   final Model model;
   final List<double> embedding;
   int? updationTime;

+ 304 - 1
lib/models/embedding.g.dart

@@ -44,7 +44,26 @@ const EmbeddingSchema = CollectionSchema(
   deserialize: _embeddingDeserialize,
   deserializeProp: _embeddingDeserializeProp,
   idName: r'id',
-  indexes: {},
+  indexes: {
+    r'unique_file_model_embedding': IndexSchema(
+      id: 6248303800853228628,
+      name: r'unique_file_model_embedding',
+      unique: true,
+      replace: true,
+      properties: [
+        IndexPropertySchema(
+          name: r'model',
+          type: IndexType.value,
+          caseSensitive: false,
+        ),
+        IndexPropertySchema(
+          name: r'fileID',
+          type: IndexType.value,
+          caseSensitive: false,
+        )
+      ],
+    )
+  },
   links: {},
   embeddedSchemas: {},
   getId: _embeddingGetId,
@@ -134,6 +153,95 @@ void _embeddingAttach(IsarCollection<dynamic> col, Id id, Embedding object) {
   object.id = id;
 }
 
+extension EmbeddingByIndex on IsarCollection<Embedding> {
+  Future<Embedding?> getByModelFileID(Model model, int fileID) {
+    return getByIndex(r'unique_file_model_embedding', [model, fileID]);
+  }
+
+  Embedding? getByModelFileIDSync(Model model, int fileID) {
+    return getByIndexSync(r'unique_file_model_embedding', [model, fileID]);
+  }
+
+  Future<bool> deleteByModelFileID(Model model, int fileID) {
+    return deleteByIndex(r'unique_file_model_embedding', [model, fileID]);
+  }
+
+  bool deleteByModelFileIDSync(Model model, int fileID) {
+    return deleteByIndexSync(r'unique_file_model_embedding', [model, fileID]);
+  }
+
+  Future<List<Embedding?>> getAllByModelFileID(
+      List<Model> modelValues, List<int> fileIDValues) {
+    final len = modelValues.length;
+    assert(fileIDValues.length == len,
+        'All index values must have the same length');
+    final values = <List<dynamic>>[];
+    for (var i = 0; i < len; i++) {
+      values.add([modelValues[i], fileIDValues[i]]);
+    }
+
+    return getAllByIndex(r'unique_file_model_embedding', values);
+  }
+
+  List<Embedding?> getAllByModelFileIDSync(
+      List<Model> modelValues, List<int> fileIDValues) {
+    final len = modelValues.length;
+    assert(fileIDValues.length == len,
+        'All index values must have the same length');
+    final values = <List<dynamic>>[];
+    for (var i = 0; i < len; i++) {
+      values.add([modelValues[i], fileIDValues[i]]);
+    }
+
+    return getAllByIndexSync(r'unique_file_model_embedding', values);
+  }
+
+  Future<int> deleteAllByModelFileID(
+      List<Model> modelValues, List<int> fileIDValues) {
+    final len = modelValues.length;
+    assert(fileIDValues.length == len,
+        'All index values must have the same length');
+    final values = <List<dynamic>>[];
+    for (var i = 0; i < len; i++) {
+      values.add([modelValues[i], fileIDValues[i]]);
+    }
+
+    return deleteAllByIndex(r'unique_file_model_embedding', values);
+  }
+
+  int deleteAllByModelFileIDSync(
+      List<Model> modelValues, List<int> fileIDValues) {
+    final len = modelValues.length;
+    assert(fileIDValues.length == len,
+        'All index values must have the same length');
+    final values = <List<dynamic>>[];
+    for (var i = 0; i < len; i++) {
+      values.add([modelValues[i], fileIDValues[i]]);
+    }
+
+    return deleteAllByIndexSync(r'unique_file_model_embedding', values);
+  }
+
+  Future<Id> putByModelFileID(Embedding object) {
+    return putByIndex(r'unique_file_model_embedding', object);
+  }
+
+  Id putByModelFileIDSync(Embedding object, {bool saveLinks = true}) {
+    return putByIndexSync(r'unique_file_model_embedding', object,
+        saveLinks: saveLinks);
+  }
+
+  Future<List<Id>> putAllByModelFileID(List<Embedding> objects) {
+    return putAllByIndex(r'unique_file_model_embedding', objects);
+  }
+
+  List<Id> putAllByModelFileIDSync(List<Embedding> objects,
+      {bool saveLinks = true}) {
+    return putAllByIndexSync(r'unique_file_model_embedding', objects,
+        saveLinks: saveLinks);
+  }
+}
+
 extension EmbeddingQueryWhereSort
     on QueryBuilder<Embedding, Embedding, QWhere> {
   QueryBuilder<Embedding, Embedding, QAfterWhere> anyId() {
@@ -141,6 +249,14 @@ extension EmbeddingQueryWhereSort
       return query.addWhereClause(const IdWhereClause.any());
     });
   }
+
+  QueryBuilder<Embedding, Embedding, QAfterWhere> anyModelFileID() {
+    return QueryBuilder.apply(this, (query) {
+      return query.addWhereClause(
+        const IndexWhereClause.any(indexName: r'unique_file_model_embedding'),
+      );
+    });
+  }
 }
 
 extension EmbeddingQueryWhere
@@ -209,6 +325,193 @@ extension EmbeddingQueryWhere
       ));
     });
   }
+
+  QueryBuilder<Embedding, Embedding, QAfterWhereClause> modelEqualToAnyFileID(
+      Model model) {
+    return QueryBuilder.apply(this, (query) {
+      return query.addWhereClause(IndexWhereClause.equalTo(
+        indexName: r'unique_file_model_embedding',
+        value: [model],
+      ));
+    });
+  }
+
+  QueryBuilder<Embedding, Embedding, QAfterWhereClause>
+      modelNotEqualToAnyFileID(Model model) {
+    return QueryBuilder.apply(this, (query) {
+      if (query.whereSort == Sort.asc) {
+        return query
+            .addWhereClause(IndexWhereClause.between(
+              indexName: r'unique_file_model_embedding',
+              lower: [],
+              upper: [model],
+              includeUpper: false,
+            ))
+            .addWhereClause(IndexWhereClause.between(
+              indexName: r'unique_file_model_embedding',
+              lower: [model],
+              includeLower: false,
+              upper: [],
+            ));
+      } else {
+        return query
+            .addWhereClause(IndexWhereClause.between(
+              indexName: r'unique_file_model_embedding',
+              lower: [model],
+              includeLower: false,
+              upper: [],
+            ))
+            .addWhereClause(IndexWhereClause.between(
+              indexName: r'unique_file_model_embedding',
+              lower: [],
+              upper: [model],
+              includeUpper: false,
+            ));
+      }
+    });
+  }
+
+  QueryBuilder<Embedding, Embedding, QAfterWhereClause>
+      modelGreaterThanAnyFileID(
+    Model model, {
+    bool include = false,
+  }) {
+    return QueryBuilder.apply(this, (query) {
+      return query.addWhereClause(IndexWhereClause.between(
+        indexName: r'unique_file_model_embedding',
+        lower: [model],
+        includeLower: include,
+        upper: [],
+      ));
+    });
+  }
+
+  QueryBuilder<Embedding, Embedding, QAfterWhereClause> modelLessThanAnyFileID(
+    Model model, {
+    bool include = false,
+  }) {
+    return QueryBuilder.apply(this, (query) {
+      return query.addWhereClause(IndexWhereClause.between(
+        indexName: r'unique_file_model_embedding',
+        lower: [],
+        upper: [model],
+        includeUpper: include,
+      ));
+    });
+  }
+
+  QueryBuilder<Embedding, Embedding, QAfterWhereClause> modelBetweenAnyFileID(
+    Model lowerModel,
+    Model upperModel, {
+    bool includeLower = true,
+    bool includeUpper = true,
+  }) {
+    return QueryBuilder.apply(this, (query) {
+      return query.addWhereClause(IndexWhereClause.between(
+        indexName: r'unique_file_model_embedding',
+        lower: [lowerModel],
+        includeLower: includeLower,
+        upper: [upperModel],
+        includeUpper: includeUpper,
+      ));
+    });
+  }
+
+  QueryBuilder<Embedding, Embedding, QAfterWhereClause> modelFileIDEqualTo(
+      Model model, int fileID) {
+    return QueryBuilder.apply(this, (query) {
+      return query.addWhereClause(IndexWhereClause.equalTo(
+        indexName: r'unique_file_model_embedding',
+        value: [model, fileID],
+      ));
+    });
+  }
+
+  QueryBuilder<Embedding, Embedding, QAfterWhereClause>
+      modelEqualToFileIDNotEqualTo(Model model, int fileID) {
+    return QueryBuilder.apply(this, (query) {
+      if (query.whereSort == Sort.asc) {
+        return query
+            .addWhereClause(IndexWhereClause.between(
+              indexName: r'unique_file_model_embedding',
+              lower: [model],
+              upper: [model, fileID],
+              includeUpper: false,
+            ))
+            .addWhereClause(IndexWhereClause.between(
+              indexName: r'unique_file_model_embedding',
+              lower: [model, fileID],
+              includeLower: false,
+              upper: [model],
+            ));
+      } else {
+        return query
+            .addWhereClause(IndexWhereClause.between(
+              indexName: r'unique_file_model_embedding',
+              lower: [model, fileID],
+              includeLower: false,
+              upper: [model],
+            ))
+            .addWhereClause(IndexWhereClause.between(
+              indexName: r'unique_file_model_embedding',
+              lower: [model],
+              upper: [model, fileID],
+              includeUpper: false,
+            ));
+      }
+    });
+  }
+
+  QueryBuilder<Embedding, Embedding, QAfterWhereClause>
+      modelEqualToFileIDGreaterThan(
+    Model model,
+    int fileID, {
+    bool include = false,
+  }) {
+    return QueryBuilder.apply(this, (query) {
+      return query.addWhereClause(IndexWhereClause.between(
+        indexName: r'unique_file_model_embedding',
+        lower: [model, fileID],
+        includeLower: include,
+        upper: [model],
+      ));
+    });
+  }
+
+  QueryBuilder<Embedding, Embedding, QAfterWhereClause>
+      modelEqualToFileIDLessThan(
+    Model model,
+    int fileID, {
+    bool include = false,
+  }) {
+    return QueryBuilder.apply(this, (query) {
+      return query.addWhereClause(IndexWhereClause.between(
+        indexName: r'unique_file_model_embedding',
+        lower: [model],
+        upper: [model, fileID],
+        includeUpper: include,
+      ));
+    });
+  }
+
+  QueryBuilder<Embedding, Embedding, QAfterWhereClause>
+      modelEqualToFileIDBetween(
+    Model model,
+    int lowerFileID,
+    int upperFileID, {
+    bool includeLower = true,
+    bool includeUpper = true,
+  }) {
+    return QueryBuilder.apply(this, (query) {
+      return query.addWhereClause(IndexWhereClause.between(
+        indexName: r'unique_file_model_embedding',
+        lower: [model, lowerFileID],
+        includeLower: includeLower,
+        upper: [model, upperFileID],
+        includeUpper: includeUpper,
+      ));
+    });
+  }
 }
 
 extension EmbeddingQueryFilter

+ 5 - 0
lib/services/semantic_search/embedding_store.dart

@@ -61,6 +61,11 @@ class EmbeddingStore {
     unawaited(_pushEmbedding(file, embedding));
   }
 
+  Future<void> clearEmbeddings(Model model) async {
+    await EmbeddingsDB.instance.deleteAllForModel(model);
+    await _preferences.remove(kEmbeddingsSyncTimeKey);
+  }
+
   Future<void> _pushEmbedding(EnteFile file, Embedding embedding) async {
     final encryptionKey = getFileKey(file);
     final embeddingJSON = jsonEncode(embedding.embedding);

+ 5 - 2
lib/services/semantic_search/semantic_search_service.dart

@@ -51,7 +51,7 @@ class SemanticSearchService {
 
   get hasInitialized => _hasInitialized;
 
-  Future<void> init() async {
+  Future<void> init({bool shouldSyncImmediately = false}) async {
     if (!LocalSettings.instance.hasEnabledMagicSearch()) {
       return;
     }
@@ -87,6 +87,9 @@ class SemanticSearchService {
     Bus.instance.on<FileUploadedEvent>().listen((event) async {
       _addToQueue(event.file);
     });
+    if (shouldSyncImmediately) {
+      unawaited(sync());
+    }
   }
 
   Future<void> release() async {
@@ -146,7 +149,7 @@ class SemanticSearchService {
   }
 
   Future<void> clearIndexes() async {
-    await EmbeddingsDB.instance.deleteAllForModel(kCurrentModel);
+    await EmbeddingStore.instance.clearEmbeddings(kCurrentModel);
     _logger.info("Indexes cleared for $kCurrentModel");
   }
 

+ 6 - 3
lib/ui/settings/machine_learning_settings_page.dart

@@ -1,11 +1,11 @@
 import "dart:async";
 
-import "package:flutter/foundation.dart";
 import "package:flutter/material.dart";
 import "package:intl/intl.dart";
 import "package:photos/core/event_bus.dart";
 import 'package:photos/events/embedding_updated_event.dart';
 import "package:photos/generated/l10n.dart";
+import "package:photos/services/feature_flag_service.dart";
 import "package:photos/services/semantic_search/semantic_search_service.dart";
 import "package:photos/theme/ente_theme.dart";
 import "package:photos/ui/common/loading_widget.dart";
@@ -92,7 +92,10 @@ class _MachineLearningSettingsPageState
                 !LocalSettings.instance.hasEnabledMagicSearch(),
               );
               if (LocalSettings.instance.hasEnabledMagicSearch()) {
-                unawaited(SemanticSearchService.instance.init());
+                unawaited(
+                  SemanticSearchService.instance
+                      .init(shouldSyncImmediately: true),
+                );
               } else {
                 await SemanticSearchService.instance.clearQueue();
               }
@@ -129,7 +132,7 @@ class _MachineLearningSettingsPageState
                   const SizedBox(
                     height: 12,
                   ),
-                  kDebugMode
+                  FeatureFlagService.instance.isInternalUserOrDebugBuild()
                       ? MenuItemWidget(
                           leadingIcon: Icons.delete_sweep_outlined,
                           captionedTextWidget: CaptionedTextWidget(

+ 1 - 1
pubspec.yaml

@@ -12,7 +12,7 @@ description: ente photos application
 # Read more about iOS versioning at
 # https://developer.apple.com/library/archive/documentation/General/Reference/InfoPlistKeyReference/Articles/CoreFoundationKeys.html
 
-version: 0.8.31+551
+version: 0.8.32+552
 
 environment:
   sdk: ">=3.0.0 <4.0.0"