浏览代码

ML fixes (#1812)

## Description

- Fixed some issues in face indexing
- Cleaned up some functions in FaceMlService
- Hooked iOS onto MLController for battery check, for faces and clip

## Tests

Tested on my Pixel phone only
Laurens Priem 1 年之前
父节点
当前提交
201286f59a

+ 0 - 2
mobile/lib/main.dart

@@ -242,8 +242,6 @@ Future<void> _init(bool isBackground, {String via = ''}) async {
     // unawaited(ObjectDetectionService.instance.init());
     if (flagService.faceSearchEnabled) {
       unawaited(FaceMlService.instance.init());
-      FaceMlService.instance.listenIndexOnDiffSync();
-      FaceMlService.instance.listenOnPeopleChangedSync();
     } else {
       if (LocalSettings.instance.isFaceIndexingEnabled) {
         unawaited(LocalSettings.instance.toggleFaceIndexing());

+ 275 - 417
mobile/lib/services/machine_learning/face_ml/face_ml_service.dart

@@ -9,7 +9,6 @@ import "dart:ui" show Image;
 import "package:computer/computer.dart";
 import "package:dart_ui_isolate/dart_ui_isolate.dart";
 import "package:flutter/foundation.dart" show debugPrint, kDebugMode;
-import "package:flutter_image_compress/flutter_image_compress.dart";
 import "package:logging/logging.dart";
 import "package:onnxruntime/onnxruntime.dart";
 import "package:package_info_plus/package_info_plus.dart";
@@ -74,7 +73,7 @@ class FaceMlService {
   late ReceivePort _receivePort = ReceivePort();
   late SendPort _mainSendPort;
 
-  bool isIsolateSpawned = false;
+  bool _isIsolateSpawned = false;
 
   // singleton pattern
   FaceMlService._privateConstructor();
@@ -91,12 +90,14 @@ class FaceMlService {
   bool isInitialized = false;
   late String client;
 
-  bool canRunMLController = false;
-  bool isImageIndexRunning = false;
-  bool isClusteringRunning = false;
-  bool shouldSyncPeople = false;
+  bool debugIndexingDisabled = false;
+  bool _mlControllerStatus = false;
+  bool _isIndexingOrClusteringRunning = false;
+  bool _shouldPauseIndexingAndClustering = false;
+  bool _shouldSyncPeople = false;
+  bool _isSyncing = false;
 
-  final int _fileDownloadLimit = 15;
+  final int _fileDownloadLimit = 10;
   final int _embeddingFetchLimit = 200;
 
   Future<void> init({bool initializeImageMlIsolate = false}) async {
@@ -133,31 +134,28 @@ class FaceMlService {
       _logger.info("client: $client");
 
       isInitialized = true;
-      canRunMLController = !Platform.isAndroid || kDebugMode;
+      _mlControllerStatus = !Platform.isAndroid;
 
       /// hooking FaceML into [MachineLearningController]
-      if (Platform.isAndroid && !kDebugMode) {
-        Bus.instance.on<MachineLearningControlEvent>().listen((event) {
-          if (LocalSettings.instance.isFaceIndexingEnabled == false) {
-            return;
-          }
-          canRunMLController = event.shouldRun;
-          if (canRunMLController) {
-            _logger.info(
-              "MLController allowed running ML, faces indexing starting",
-            );
-            unawaited(indexAndClusterAll());
-          } else {
-            _logger
-                .info("MLController stopped running ML, faces indexing paused");
-            pauseIndexing();
-          }
-        });
-      } else {
-        if (!kDebugMode) {
+      Bus.instance.on<MachineLearningControlEvent>().listen((event) {
+        if (LocalSettings.instance.isFaceIndexingEnabled == false) {
+          return;
+        }
+        _mlControllerStatus = event.shouldRun;
+        if (_mlControllerStatus) {
+          _logger.info(
+            "MLController allowed running ML, faces indexing starting (unless it's already fetching embeddings)",
+          );
           unawaited(indexAndClusterAll());
+        } else {
+          _logger
+              .info("MLController stopped running ML, faces indexing will be paused (unless it's fetching embeddings)");
+          pauseIndexingAndClustering();
         }
-      }
+      });
+
+      _listenIndexOnDiffSync();
+      _listenOnPeopleChangedSync();
     });
   }
 
@@ -165,24 +163,15 @@ class FaceMlService {
     OrtEnv.instance.init();
   }
 
-  void listenIndexOnDiffSync() {
+  void _listenIndexOnDiffSync() {
     Bus.instance.on<DiffSyncCompleteEvent>().listen((event) async {
-      if (LocalSettings.instance.isFaceIndexingEnabled == false || kDebugMode) {
-        return;
-      }
-      // [neeraj] intentional delay in starting indexing on diff sync, this gives time for the user
-      // to disable face-indexing in case it's causing crash. In the future, we
-      // should have a better way to handle this.
-      shouldSyncPeople = true;
-      Future.delayed(const Duration(seconds: 10), () {
-        unawaited(indexAndClusterAll());
-      });
+      unawaited(sync());
     });
   }
 
-  void listenOnPeopleChangedSync() {
+  void _listenOnPeopleChangedSync() {
     Bus.instance.on<PeopleChangedEvent>().listen((event) {
-      shouldSyncPeople = true;
+      _shouldSyncPeople = true;
     });
   }
 
@@ -218,9 +207,9 @@ class FaceMlService {
     });
   }
 
-  Future<void> initIsolate() async {
+  Future<void> _initIsolate() async {
     return _initLockIsolate.synchronized(() async {
-      if (isIsolateSpawned) return;
+      if (_isIsolateSpawned) return;
       _logger.info("initIsolate called");
 
       _receivePort = ReceivePort();
@@ -231,19 +220,19 @@ class FaceMlService {
           _receivePort.sendPort,
         );
         _mainSendPort = await _receivePort.first as SendPort;
-        isIsolateSpawned = true;
+        _isIsolateSpawned = true;
 
         _resetInactivityTimer();
       } catch (e) {
         _logger.severe('Could not spawn isolate', e);
-        isIsolateSpawned = false;
+        _isIsolateSpawned = false;
       }
     });
   }
 
-  Future<void> ensureSpawnedIsolate() async {
-    if (!isIsolateSpawned) {
-      await initIsolate();
+  Future<void> _ensureSpawnedIsolate() async {
+    if (!_isIsolateSpawned) {
+      await _initIsolate();
     }
   }
 
@@ -286,11 +275,11 @@ class FaceMlService {
   Future<dynamic> _runInIsolate(
     (FaceMlOperation, Map<String, dynamic>) message,
   ) async {
-    await ensureSpawnedIsolate();
+    await _ensureSpawnedIsolate();
     return _functionLock.synchronized(() async {
       _resetInactivityTimer();
 
-      if (isImageIndexRunning == false || canRunMLController == false) {
+      if (_shouldPauseIndexingAndClustering == false) {
         return null;
       }
 
@@ -332,35 +321,42 @@ class FaceMlService {
         _logger.info(
           'Clustering Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.',
         );
-        disposeIsolate();
+        _disposeIsolate();
       }
     });
   }
 
-  void disposeIsolate() async {
-    if (!isIsolateSpawned) return;
+  void _disposeIsolate() async {
+    if (!_isIsolateSpawned) return;
     await release();
 
-    isIsolateSpawned = false;
+    _isIsolateSpawned = false;
     _isolate.kill();
     _receivePort.close();
     _inactivityTimer?.cancel();
   }
 
-  Future<void> indexAndClusterAll() async {
-    if (isClusteringRunning || isImageIndexRunning) {
-      _logger.info("indexing or clustering is already running, skipping");
+  Future<void> sync({bool forceSync = true}) async {
+    if (_isSyncing) {
       return;
     }
-    if (shouldSyncPeople) {
+    _isSyncing = true;
+    if (forceSync) {
       await PersonService.instance.reconcileClusters();
-      shouldSyncPeople = false;
+      _shouldSyncPeople = false;
     }
+    _isSyncing = false;
+  }
+
+  Future<void> indexAndClusterAll() async {
+    if (_cannotRunMLFunction()) return;
+
+    await sync(forceSync: _shouldSyncPeople);
     await indexAllImages();
     final indexingCompleteRatio = await _getIndexedDoneRatio();
     if (indexingCompleteRatio < 0.95) {
       _logger.info(
-        "Indexing is not far enough, skipping clustering. Indexing is at $indexingCompleteRatio",
+        "Indexing is not far enough to start clustering, skipping clustering. Indexing is at $indexingCompleteRatio",
       );
       return;
     } else {
@@ -368,174 +364,9 @@ class FaceMlService {
     }
   }
 
-  Future<void> clusterAllImages({
-    double minFaceScore = kMinimumQualityFaceScore,
-    bool clusterInBuckets = true,
-  }) async {
-    if (!canRunMLController) {
-      _logger
-          .info("MLController does not allow running ML, skipping clustering");
-      return;
-    }
-    if (isClusteringRunning) {
-      _logger.info("clusterAllImages is already running, skipping");
-      return;
-    }
-    // verify faces is enabled
-    if (LocalSettings.instance.isFaceIndexingEnabled == false) {
-      _logger.warning("clustering is disabled by user");
-      return;
-    }
-
-    final indexingCompleteRatio = await _getIndexedDoneRatio();
-    if (indexingCompleteRatio < 0.95) {
-      _logger.info(
-        "Indexing is not far enough, skipping clustering. Indexing is at $indexingCompleteRatio",
-      );
-      return;
-    }
-
-    _logger.info("`clusterAllImages()` called");
-    isClusteringRunning = true;
-    final clusterAllImagesTime = DateTime.now();
-
-    try {
-      // Get a sense of the total number of faces in the database
-      final int totalFaces = await FaceMLDataDB.instance
-          .getTotalFaceCount(minFaceScore: minFaceScore);
-      final fileIDToCreationTime =
-          await FilesDB.instance.getFileIDToCreationTime();
-      final startEmbeddingFetch = DateTime.now();
-      // read all embeddings
-      final result = await FaceMLDataDB.instance.getFaceInfoForClustering(
-        minScore: minFaceScore,
-        maxFaces: totalFaces,
-      );
-      final Set<int> missingFileIDs = {};
-      final allFaceInfoForClustering = <FaceInfoForClustering>[];
-      for (final faceInfo in result) {
-        if (!fileIDToCreationTime.containsKey(faceInfo.fileID)) {
-          missingFileIDs.add(faceInfo.fileID);
-        } else {
-          allFaceInfoForClustering.add(faceInfo);
-        }
-      }
-      // sort the embeddings based on file creation time, oldest first
-      allFaceInfoForClustering.sort((a, b) {
-        return fileIDToCreationTime[a.fileID]!
-            .compareTo(fileIDToCreationTime[b.fileID]!);
-      });
-      _logger.info(
-        'Getting and sorting embeddings took ${DateTime.now().difference(startEmbeddingFetch).inMilliseconds} ms for ${allFaceInfoForClustering.length} embeddings'
-        'and ${missingFileIDs.length} missing fileIDs',
-      );
-
-      // Get the current cluster statistics
-      final Map<int, (Uint8List, int)> oldClusterSummaries =
-          await FaceMLDataDB.instance.getAllClusterSummary();
-
-      if (clusterInBuckets) {
-        const int bucketSize = 20000;
-        const int offsetIncrement = 7500;
-        int offset = 0;
-        int bucket = 1;
-
-        while (true) {
-          if (!canRunMLController) {
-            _logger.info(
-              "MLController does not allow running ML, stopping before clustering bucket $bucket",
-            );
-            break;
-          }
-          if (offset > allFaceInfoForClustering.length - 1) {
-            _logger.warning(
-              'faceIdToEmbeddingBucket is empty, this should ideally not happen as it should have stopped earlier. offset: $offset, totalFaces: $totalFaces',
-            );
-            break;
-          }
-          if (offset > totalFaces) {
-            _logger.warning(
-              'offset > totalFaces, this should ideally not happen. offset: $offset, totalFaces: $totalFaces',
-            );
-            break;
-          }
-
-          final bucketStartTime = DateTime.now();
-          final faceInfoForClustering = allFaceInfoForClustering.sublist(
-            offset,
-            min(offset + bucketSize, allFaceInfoForClustering.length),
-          );
-
-          final clusteringResult =
-              await FaceClusteringService.instance.predictLinear(
-            faceInfoForClustering.toSet(),
-            fileIDToCreationTime: fileIDToCreationTime,
-            offset: offset,
-            oldClusterSummaries: oldClusterSummaries,
-          );
-          if (clusteringResult == null) {
-            _logger.warning("faceIdToCluster is null");
-            return;
-          }
-
-          await FaceMLDataDB.instance
-              .updateFaceIdToClusterId(clusteringResult.newFaceIdToCluster);
-          await FaceMLDataDB.instance
-              .clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
-          for (final faceInfo in faceInfoForClustering) {
-            faceInfo.clusterId ??=
-                clusteringResult.newFaceIdToCluster[faceInfo.faceID];
-          }
-          for (final clusterUpdate
-              in clusteringResult.newClusterSummaries!.entries) {
-            oldClusterSummaries[clusterUpdate.key] = clusterUpdate.value;
-          }
-          _logger.info(
-            'Done with clustering ${offset + faceInfoForClustering.length} embeddings (${(100 * (offset + faceInfoForClustering.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset, in ${DateTime.now().difference(bucketStartTime).inSeconds} seconds',
-          );
-          if (offset + bucketSize >= totalFaces) {
-            _logger.info('All faces clustered');
-            break;
-          }
-          offset += offsetIncrement;
-          bucket++;
-        }
-      } else {
-        final clusterStartTime = DateTime.now();
-        // Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID
-        final clusteringResult =
-            await FaceClusteringService.instance.predictLinear(
-          allFaceInfoForClustering.toSet(),
-          fileIDToCreationTime: fileIDToCreationTime,
-          oldClusterSummaries: oldClusterSummaries,
-        );
-        if (clusteringResult == null) {
-          _logger.warning("faceIdToCluster is null");
-          return;
-        }
-        final clusterDoneTime = DateTime.now();
-        _logger.info(
-          'done with clustering ${allFaceInfoForClustering.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ',
-        );
-
-        // Store the updated clusterIDs in the database
-        _logger.info(
-          'Updating ${clusteringResult.newFaceIdToCluster.length} FaceIDs with clusterIDs in the DB',
-        );
-        await FaceMLDataDB.instance
-            .updateFaceIdToClusterId(clusteringResult.newFaceIdToCluster);
-        await FaceMLDataDB.instance
-            .clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
-        _logger.info('Done updating FaceIDs with clusterIDs in the DB, in '
-            '${DateTime.now().difference(clusterDoneTime).inSeconds} seconds');
-      }
-      Bus.instance.fire(PeopleChangedEvent());
-      _logger.info('clusterAllImages() finished, in '
-          '${DateTime.now().difference(clusterAllImagesTime).inSeconds} seconds');
-    } catch (e, s) {
-      _logger.severe("`clusterAllImages` failed", e, s);
-    } finally {
-      isClusteringRunning = false;
+  void pauseIndexingAndClustering() {
+    if (_isIndexingOrClusteringRunning) {
+      _shouldPauseIndexingAndClustering = true;
     }
   }
 
@@ -543,17 +374,10 @@ class FaceMlService {
   ///
   /// This function first checks if the image has already been analyzed with the latest faceMlVersion and stored in the database. If so, it skips the image.
   Future<void> indexAllImages({int retryFetchCount = 10}) async {
-    if (isImageIndexRunning) {
-      _logger.warning("indexAllImages is already running, skipping");
-      return;
-    }
-    // verify faces is enabled
-    if (LocalSettings.instance.isFaceIndexingEnabled == false) {
-      _logger.warning("indexing is disabled by user");
-      return;
-    }
+    if (_cannotRunMLFunction()) return;
+
     try {
-      isImageIndexRunning = true;
+      _isIndexingOrClusteringRunning = true;
       _logger.info('starting image indexing');
 
       final w = (kDebugMode ? EnteWatch('prepare indexing files') : null)
@@ -608,6 +432,7 @@ class FaceMlService {
       w?.log('preparing all files to index');
       final List<List<EnteFile>> chunks =
           sortedBylocalID.chunks(_embeddingFetchLimit);
+      int fetchedCount = 0;
       outerLoop:
       for (final chunk in chunks) {
         final futures = <Future<bool>>[];
@@ -619,17 +444,15 @@ class FaceMlService {
             for (final f in chunk) {
               fileIds.add(f.uploadedFileID!);
             }
-            final EnteWatch? w =
-                flagService.internalUser ? EnteWatch("face_em_fetch") : null;
-            w?.start();
-            w?.log('starting remote fetch for ${fileIds.length} files');
+            _logger.info('starting remote fetch for ${fileIds.length} files');
             final res =
                 await RemoteFileMLService.instance.getFilessEmbedding(fileIds);
-            w?.logAndReset('fetched ${res.mlData.length} embeddings');
+            _logger.info('fetched ${res.mlData.length} embeddings');
+            fetchedCount += res.mlData.length;
             final List<Face> faces = [];
             final remoteFileIdToVersion = <int, int>{};
             for (FileMl fileMl in res.mlData.values) {
-              if (shouldDiscardRemoteEmbedding(fileMl)) continue;
+              if (_shouldDiscardRemoteEmbedding(fileMl)) continue;
               if (fileMl.faceEmbedding.faces.isEmpty) {
                 faces.add(
                   Face.empty(
@@ -659,7 +482,7 @@ class FaceMlService {
             }
 
             await FaceMLDataDB.instance.bulkInsertFaces(faces);
-            w?.logAndReset('stored embeddings');
+            _logger.info('stored embeddings');
             for (final entry in remoteFileIdToVersion.entries) {
               alreadyIndexedFiles[entry.key] = entry.value;
             }
@@ -688,7 +511,7 @@ class FaceMlService {
         final smallerChunks = chunk.chunks(_fileDownloadLimit);
         for (final smallestChunk in smallerChunks) {
           for (final enteFile in smallestChunk) {
-            if (isImageIndexRunning == false) {
+            if (_shouldPauseIndexingAndClustering) {
               _logger.info("indexAllImages() was paused, stopping");
               break outerLoop;
             }
@@ -712,16 +535,168 @@ class FaceMlService {
 
       stopwatch.stop();
       _logger.info(
-        "`indexAllImages()` finished. Analyzed $fileAnalyzedCount images, in ${stopwatch.elapsed.inSeconds} seconds (avg of ${stopwatch.elapsed.inSeconds / fileAnalyzedCount} seconds per image, skipped $fileSkippedCount images. MLController status: $canRunMLController)",
+        "`indexAllImages()` finished. Fetched $fetchedCount and analyzed $fileAnalyzedCount images, in ${stopwatch.elapsed.inSeconds} seconds (avg of ${stopwatch.elapsed.inSeconds / fileAnalyzedCount} seconds per image, skipped $fileSkippedCount images. MLController status: $_mlControllerStatus)",
       );
     } catch (e, s) {
       _logger.severe("indexAllImages failed", e, s);
     } finally {
-      isImageIndexRunning = false;
+      _isIndexingOrClusteringRunning = false;
+      _shouldPauseIndexingAndClustering = false;
+    }
+  }
+
+  Future<void> clusterAllImages({
+    double minFaceScore = kMinimumQualityFaceScore,
+    bool clusterInBuckets = true,
+  }) async {
+    if (_cannotRunMLFunction()) return;
+
+    _logger.info("`clusterAllImages()` called");
+    _isIndexingOrClusteringRunning = true;
+    final clusterAllImagesTime = DateTime.now();
+
+    try {
+      // Get a sense of the total number of faces in the database
+      final int totalFaces = await FaceMLDataDB.instance
+          .getTotalFaceCount(minFaceScore: minFaceScore);
+      final fileIDToCreationTime =
+          await FilesDB.instance.getFileIDToCreationTime();
+      final startEmbeddingFetch = DateTime.now();
+      // read all embeddings
+      final result = await FaceMLDataDB.instance.getFaceInfoForClustering(
+        minScore: minFaceScore,
+        maxFaces: totalFaces,
+      );
+      final Set<int> missingFileIDs = {};
+      final allFaceInfoForClustering = <FaceInfoForClustering>[];
+      for (final faceInfo in result) {
+        if (!fileIDToCreationTime.containsKey(faceInfo.fileID)) {
+          missingFileIDs.add(faceInfo.fileID);
+        } else {
+          allFaceInfoForClustering.add(faceInfo);
+        }
+      }
+      // sort the embeddings based on file creation time, oldest first
+      allFaceInfoForClustering.sort((a, b) {
+        return fileIDToCreationTime[a.fileID]!
+            .compareTo(fileIDToCreationTime[b.fileID]!);
+      });
+      _logger.info(
+        'Getting and sorting embeddings took ${DateTime.now().difference(startEmbeddingFetch).inMilliseconds} ms for ${allFaceInfoForClustering.length} embeddings'
+        'and ${missingFileIDs.length} missing fileIDs',
+      );
+
+      // Get the current cluster statistics
+      final Map<int, (Uint8List, int)> oldClusterSummaries =
+          await FaceMLDataDB.instance.getAllClusterSummary();
+
+      if (clusterInBuckets) {
+        const int bucketSize = 20000;
+        const int offsetIncrement = 7500;
+        int offset = 0;
+        int bucket = 1;
+
+        while (true) {
+          if (_shouldPauseIndexingAndClustering) {
+            _logger.info(
+              "MLController does not allow running ML, stopping before clustering bucket $bucket",
+            );
+            break;
+          }
+          if (offset > allFaceInfoForClustering.length - 1) {
+            _logger.warning(
+              'faceIdToEmbeddingBucket is empty, this should ideally not happen as it should have stopped earlier. offset: $offset, totalFaces: $totalFaces',
+            );
+            break;
+          }
+          if (offset > totalFaces) {
+            _logger.warning(
+              'offset > totalFaces, this should ideally not happen. offset: $offset, totalFaces: $totalFaces',
+            );
+            break;
+          }
+
+          final bucketStartTime = DateTime.now();
+          final faceInfoForClustering = allFaceInfoForClustering.sublist(
+            offset,
+            min(offset + bucketSize, allFaceInfoForClustering.length),
+          );
+
+          final clusteringResult =
+              await FaceClusteringService.instance.predictLinear(
+            faceInfoForClustering.toSet(),
+            fileIDToCreationTime: fileIDToCreationTime,
+            offset: offset,
+            oldClusterSummaries: oldClusterSummaries,
+          );
+          if (clusteringResult == null) {
+            _logger.warning("faceIdToCluster is null");
+            return;
+          }
+
+          await FaceMLDataDB.instance
+              .updateFaceIdToClusterId(clusteringResult.newFaceIdToCluster);
+          await FaceMLDataDB.instance
+              .clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
+          for (final faceInfo in faceInfoForClustering) {
+            faceInfo.clusterId ??=
+                clusteringResult.newFaceIdToCluster[faceInfo.faceID];
+          }
+          for (final clusterUpdate
+              in clusteringResult.newClusterSummaries!.entries) {
+            oldClusterSummaries[clusterUpdate.key] = clusterUpdate.value;
+          }
+          _logger.info(
+            'Done with clustering ${offset + faceInfoForClustering.length} embeddings (${(100 * (offset + faceInfoForClustering.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset, in ${DateTime.now().difference(bucketStartTime).inSeconds} seconds',
+          );
+          if (offset + bucketSize >= totalFaces) {
+            _logger.info('All faces clustered');
+            break;
+          }
+          offset += offsetIncrement;
+          bucket++;
+        }
+      } else {
+        final clusterStartTime = DateTime.now();
+        // Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID
+        final clusteringResult =
+            await FaceClusteringService.instance.predictLinear(
+          allFaceInfoForClustering.toSet(),
+          fileIDToCreationTime: fileIDToCreationTime,
+          oldClusterSummaries: oldClusterSummaries,
+        );
+        if (clusteringResult == null) {
+          _logger.warning("faceIdToCluster is null");
+          return;
+        }
+        final clusterDoneTime = DateTime.now();
+        _logger.info(
+          'done with clustering ${allFaceInfoForClustering.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ',
+        );
+
+        // Store the updated clusterIDs in the database
+        _logger.info(
+          'Updating ${clusteringResult.newFaceIdToCluster.length} FaceIDs with clusterIDs in the DB',
+        );
+        await FaceMLDataDB.instance
+            .updateFaceIdToClusterId(clusteringResult.newFaceIdToCluster);
+        await FaceMLDataDB.instance
+            .clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
+        _logger.info('Done updating FaceIDs with clusterIDs in the DB, in '
+            '${DateTime.now().difference(clusterDoneTime).inSeconds} seconds');
+      }
+      Bus.instance.fire(PeopleChangedEvent());
+      _logger.info('clusterAllImages() finished, in '
+          '${DateTime.now().difference(clusterAllImagesTime).inSeconds} seconds');
+    } catch (e, s) {
+      _logger.severe("`clusterAllImages` failed", e, s);
+    } finally {
+      _isIndexingOrClusteringRunning = false;
+      _shouldPauseIndexingAndClustering = false;
     }
   }
 
-  bool shouldDiscardRemoteEmbedding(FileMl fileMl) {
+  bool _shouldDiscardRemoteEmbedding(FileMl fileMl) {
     if (fileMl.faceEmbedding.version < faceMlVersion) {
       debugPrint("Discarding remote embedding for fileID ${fileMl.fileID} "
           "because version is ${fileMl.faceEmbedding.version} and we need $faceMlVersion");
@@ -769,7 +744,7 @@ class FaceMlService {
     );
 
     try {
-      final FaceMlResult? result = await analyzeImageInSingleIsolate(
+      final FaceMlResult? result = await _analyzeImageInSingleIsolate(
         enteFile,
         // preferUsingThumbnailForEverything: false,
         // disposeImageIsolateAfterUse: false,
@@ -861,12 +836,8 @@ class FaceMlService {
     }
   }
 
-  void pauseIndexing() {
-    isImageIndexRunning = false;
-  }
-
   /// Analyzes the given image data by running the full pipeline for faces, using [analyzeImageSync] in the isolate.
-  Future<FaceMlResult?> analyzeImageInSingleIsolate(EnteFile enteFile) async {
+  Future<FaceMlResult?> _analyzeImageInSingleIsolate(EnteFile enteFile) async {
     _checkEnteFileForID(enteFile);
     await ensureInitialized();
 
@@ -1057,94 +1028,6 @@ class FaceMlService {
     return imagePath;
   }
 
-  @Deprecated('Deprecated in favor of `_getImagePathForML`')
-  Future<Uint8List?> _getDataForML(
-    EnteFile enteFile, {
-    FileDataForML typeOfData = FileDataForML.fileData,
-  }) async {
-    Uint8List? data;
-
-    switch (typeOfData) {
-      case FileDataForML.fileData:
-        final stopwatch = Stopwatch()..start();
-        final File? actualIoFile = await getFile(enteFile, isOrigin: true);
-        if (actualIoFile != null) {
-          data = await actualIoFile.readAsBytes();
-        }
-        stopwatch.stop();
-        _logger.info(
-          "Getting file data for uploadedFileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms",
-        );
-
-        break;
-
-      case FileDataForML.thumbnailData:
-        final stopwatch = Stopwatch()..start();
-        data = await getThumbnail(enteFile);
-        stopwatch.stop();
-        _logger.info(
-          "Getting thumbnail data for uploadedFileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms",
-        );
-        break;
-
-      case FileDataForML.compressedFileData:
-        final stopwatch = Stopwatch()..start();
-        final String tempPath = Configuration.instance.getTempDirectory() +
-            "${enteFile.uploadedFileID!}";
-        final File? actualIoFile = await getFile(enteFile);
-        if (actualIoFile != null) {
-          final compressResult = await FlutterImageCompress.compressAndGetFile(
-            actualIoFile.path,
-            tempPath + ".jpg",
-          );
-          if (compressResult != null) {
-            data = await compressResult.readAsBytes();
-          }
-        }
-        stopwatch.stop();
-        _logger.info(
-          "Getting compressed file data for uploadedFileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms",
-        );
-        break;
-    }
-
-    return data;
-  }
-
-  /// Detects faces in the given image data.
-  ///
-  /// `imageData`: The image data to analyze.
-  ///
-  /// Returns a list of face detection results.
-  ///
-  /// Throws [CouldNotInitializeFaceDetector], [CouldNotRunFaceDetector] or [GeneralFaceMlException] if something goes wrong.
-  Future<List<FaceDetectionRelative>> _detectFacesIsolate(
-    String imagePath,
-    // Uint8List fileData,
-    {
-    FaceMlResultBuilder? resultBuilder,
-  }) async {
-    try {
-      // Get the bounding boxes of the faces
-      final (List<FaceDetectionRelative> faces, dataSize) =
-          await FaceDetectionService.instance.predictInComputer(imagePath);
-
-      // Add detected faces to the resultBuilder
-      if (resultBuilder != null) {
-        resultBuilder.addNewlyDetectedFaces(faces, dataSize);
-      }
-
-      return faces;
-    } on YOLOFaceInterpreterInitializationException {
-      throw CouldNotInitializeFaceDetector();
-    } on YOLOFaceInterpreterRunException {
-      throw CouldNotRunFaceDetector();
-    } catch (e) {
-      _logger.severe('Face detection failed: $e');
-      throw GeneralFaceMlException('Face detection failed: $e');
-    }
-  }
-
   /// Detects faces in the given image data.
   ///
   /// `imageData`: The image data to analyze.
@@ -1183,38 +1066,6 @@ class FaceMlService {
     }
   }
 
-  /// Aligns multiple faces from the given image data.
-  ///
-  /// `imageData`: The image data in [Uint8List] that contains the faces.
-  /// `faces`: The face detection results in a list of [FaceDetectionAbsolute] for the faces to align.
-  ///
-  /// Returns a list of the aligned faces as image data.
-  ///
-  /// Throws [CouldNotWarpAffine] or [GeneralFaceMlException] if the face alignment fails.
-  Future<Float32List> _alignFaces(
-    String imagePath,
-    List<FaceDetectionRelative> faces, {
-    FaceMlResultBuilder? resultBuilder,
-  }) async {
-    try {
-      final (alignedFaces, alignmentResults, _, blurValues, _) =
-          await ImageMlIsolate.instance
-              .preprocessMobileFaceNetOnnx(imagePath, faces);
-
-      if (resultBuilder != null) {
-        resultBuilder.addAlignmentResults(
-          alignmentResults,
-          blurValues,
-        );
-      }
-
-      return alignedFaces;
-    } catch (e, s) {
-      _logger.severe('Face alignment failed: $e', e, s);
-      throw CouldNotWarpAffine();
-    }
-  }
-
   /// Aligns multiple faces from the given image data.
   ///
   /// `imageData`: The image data in [Uint8List] that contains the faces.
@@ -1256,45 +1107,6 @@ class FaceMlService {
     }
   }
 
-  /// Embeds multiple faces from the given input matrices.
-  ///
-  /// `facesMatrices`: The input matrices of the faces to embed.
-  ///
-  /// Returns a list of the face embeddings as lists of doubles.
-  ///
-  /// Throws [CouldNotInitializeFaceEmbeddor], [CouldNotRunFaceEmbeddor], [InputProblemFaceEmbeddor] or [GeneralFaceMlException] if the face embedding fails.
-  Future<List<List<double>>> _embedFaces(
-    Float32List facesList, {
-    FaceMlResultBuilder? resultBuilder,
-  }) async {
-    try {
-      // Get the embedding of the faces
-      final List<List<double>> embeddings =
-          await FaceEmbeddingService.instance.predictInComputer(facesList);
-
-      // Add the embeddings to the resultBuilder
-      if (resultBuilder != null) {
-        resultBuilder.addEmbeddingsToExistingFaces(embeddings);
-      }
-
-      return embeddings;
-    } on MobileFaceNetInterpreterInitializationException {
-      throw CouldNotInitializeFaceEmbeddor();
-    } on MobileFaceNetInterpreterRunException {
-      throw CouldNotRunFaceEmbeddor();
-    } on MobileFaceNetEmptyInput {
-      throw InputProblemFaceEmbeddor("Input is empty");
-    } on MobileFaceNetWrongInputSize {
-      throw InputProblemFaceEmbeddor("Input size is wrong");
-    } on MobileFaceNetWrongInputRange {
-      throw InputProblemFaceEmbeddor("Input range is wrong");
-      // ignore: avoid_catches_without_on_clauses
-    } catch (e) {
-      _logger.severe('Face embedding (batch) failed: $e');
-      throw GeneralFaceMlException('Face embedding (batch) failed: $e');
-    }
-  }
-
   static Future<List<List<double>>> embedFacesSync(
     Float32List facesList,
     int interpreterAddress, {
@@ -1334,10 +1146,9 @@ class FaceMlService {
       _logger.warning(
         '''Skipped analysis of image with enteFile, it might be the wrong format or has no uploadedFileID, or MLController doesn't allow it to run.
         enteFile: ${enteFile.toString()}
-        isImageIndexRunning: $isImageIndexRunning
-        canRunML: $canRunMLController
         ''',
       );
+      _logStatus();
       throw CouldNotRetrieveAnyFileData();
     }
   }
@@ -1361,7 +1172,8 @@ class FaceMlService {
   }
 
   bool _skipAnalysisEnteFile(EnteFile enteFile, Map<int, int> indexedFileIds) {
-    if (isImageIndexRunning == false || canRunMLController == false) {
+    if (_isIndexingOrClusteringRunning == false ||
+        _mlControllerStatus == false) {
       return true;
     }
     // Skip if the file is not uploaded or not owned by the user
@@ -1378,4 +1190,50 @@ class FaceMlService {
     return indexedFileIds.containsKey(id) &&
         indexedFileIds[id]! >= faceMlVersion;
   }
+
+  bool _cannotRunMLFunction({String function = ""}) {
+    if (_isIndexingOrClusteringRunning) {
+      _logger.info(
+        "Cannot run $function because indexing or clustering is already running",
+      );
+      _logStatus();
+      return true;
+    }
+    if (_mlControllerStatus == false) {
+      _logger.info(
+        "Cannot run $function because MLController does not allow it",
+      );
+      _logStatus();
+      return true;
+    }
+    if (debugIndexingDisabled) {
+      _logger.info(
+        "Cannot run $function because debugIndexingDisabled is true",
+      );
+      _logStatus();
+      return true;
+    }
+    if (_shouldPauseIndexingAndClustering) {
+      // This should ideally not be triggered, because one of the above should be triggered instead.
+      _logger.warning(
+        "Cannot run $function because indexing and clustering is being paused",
+      );
+      _logStatus();
+      return true;
+    }
+    return false;
+  }
+
+  void _logStatus() {
+    final String status = '''
+    isInternalUser: ${flagService.internalUser}
+    isFaceIndexingEnabled: ${LocalSettings.instance.isFaceIndexingEnabled}
+    canRunMLController: $_mlControllerStatus
+    isIndexingOrClusteringRunning: $_isIndexingOrClusteringRunning
+    shouldPauseIndexingAndClustering: $_shouldPauseIndexingAndClustering
+    debugIndexingDisabled: $debugIndexingDisabled
+    shouldSyncPeople: $_shouldSyncPeople
+    ''';
+    _logger.info(status);
+  }
 }

+ 28 - 10
mobile/lib/services/machine_learning/machine_learning_controller.dart

@@ -3,6 +3,8 @@ import "dart:io";
 
 import "package:battery_info/battery_info_plugin.dart";
 import "package:battery_info/model/android_battery_info.dart";
+import "package:battery_info/model/iso_battery_info.dart";
+import "package:flutter/foundation.dart" show kDebugMode;
 import "package:logging/logging.dart";
 import "package:photos/core/event_bus.dart";
 import "package:photos/events/machine_learning_control_event.dart";
@@ -17,7 +19,8 @@ class MachineLearningController {
 
   static const kMaximumTemperature = 42; // 42 degree celsius
   static const kMinimumBatteryLevel = 20; // 20%
-  static const kDefaultInteractionTimeout = Duration(seconds: 15);
+  static const kDefaultInteractionTimeout =
+      kDebugMode ? Duration(seconds: 3) : Duration(seconds: 5);
   static const kUnhealthyStates = ["over_heat", "over_voltage", "dead"];
 
   bool _isDeviceHealthy = true;
@@ -31,13 +34,17 @@ class MachineLearningController {
       BatteryInfoPlugin()
           .androidBatteryInfoStream
           .listen((AndroidBatteryInfo? batteryInfo) {
-        _onBatteryStateUpdate(batteryInfo);
+        _onAndroidBatteryStateUpdate(batteryInfo);
       });
-    } else {
-      // Always run Machine Learning on iOS
-      _canRunML = true;
-      Bus.instance.fire(MachineLearningControlEvent(true));
     }
+    if (Platform.isIOS) {
+      BatteryInfoPlugin()
+          .iosBatteryInfoStream
+          .listen((IosBatteryInfo? batteryInfo) {
+        _oniOSBatteryStateUpdate(batteryInfo);
+      });
+    }
+    _fireControlEvent();
   }
 
   void onUserInteraction() {
@@ -53,7 +60,8 @@ class MachineLearningController {
   }
 
   void _fireControlEvent() {
-    final shouldRunML = _isDeviceHealthy && !_isUserInteracting;
+    final shouldRunML =
+        _isDeviceHealthy && (Platform.isAndroid ? !_isUserInteracting : true);
     if (shouldRunML != _canRunML) {
       _canRunML = shouldRunML;
       _logger.info(
@@ -76,18 +84,28 @@ class MachineLearningController {
     _startInteractionTimer();
   }
 
-  void _onBatteryStateUpdate(AndroidBatteryInfo? batteryInfo) {
+  void _onAndroidBatteryStateUpdate(AndroidBatteryInfo? batteryInfo) {
     _logger.info("Battery info: ${batteryInfo!.toJson()}");
-    _isDeviceHealthy = _computeIsDeviceHealthy(batteryInfo);
+    _isDeviceHealthy = _computeIsAndroidDeviceHealthy(batteryInfo);
     _fireControlEvent();
   }
 
-  bool _computeIsDeviceHealthy(AndroidBatteryInfo info) {
+  void _oniOSBatteryStateUpdate(IosBatteryInfo? batteryInfo) {
+    _logger.info("Battery info: ${batteryInfo!.toJson()}");
+    _isDeviceHealthy = _computeIsiOSDeviceHealthy(batteryInfo);
+    _fireControlEvent();
+  }
+
+  bool _computeIsAndroidDeviceHealthy(AndroidBatteryInfo info) {
     return _hasSufficientBattery(info.batteryLevel ?? kMinimumBatteryLevel) &&
         _isAcceptableTemperature(info.temperature ?? kMaximumTemperature) &&
         _isBatteryHealthy(info.health ?? "");
   }
 
+  bool _computeIsiOSDeviceHealthy(IosBatteryInfo info) {
+    return _hasSufficientBattery(info.batteryLevel ?? kMinimumBatteryLevel);
+  }
+
   bool _hasSufficientBattery(int batteryLevel) {
     return batteryLevel >= kMinimumBatteryLevel;
   }

+ 7 - 12
mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart

@@ -1,6 +1,5 @@
 import "dart:async";
 import "dart:collection";
-import "dart:io";
 import "dart:math" show min;
 
 import "package:computer/computer.dart";
@@ -103,17 +102,13 @@ class SemanticSearchService {
     if (shouldSyncImmediately) {
       unawaited(sync());
     }
-    if (Platform.isAndroid) {
-      Bus.instance.on<MachineLearningControlEvent>().listen((event) {
-        if (event.shouldRun) {
-          _startIndexing();
-        } else {
-          _pauseIndexing();
-        }
-      });
-    } else {
-      _startIndexing();
-    }
+    Bus.instance.on<MachineLearningControlEvent>().listen((event) {
+      if (event.shouldRun) {
+        _startIndexing();
+      } else {
+        _pauseIndexing();
+      }
+    });
   }
 
   Future<void> release() async {

+ 3 - 2
mobile/lib/services/search_service.dart

@@ -848,8 +848,9 @@ class SearchService {
         final String clusterName = "$clusterId";
 
         if (clusterIDToPersonID[clusterId] != null) {
-          throw Exception(
-            "Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}",
+          // This should not happen, means a faceID is assigned to multiple persons.
+          _logger.severe(
+            "`getAllFace`: Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}",
           );
         }
         if (files.length < kMinimumClusterSizeSearchResult &&

+ 14 - 8
mobile/lib/ui/settings/debug/face_debug_section_widget.dart

@@ -79,7 +79,7 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
               final isEnabled =
                   await LocalSettings.instance.toggleFaceIndexing();
               if (!isEnabled) {
-                FaceMlService.instance.pauseIndexing();
+                FaceMlService.instance.pauseIndexingAndClustering();
               }
               if (mounted) {
                 setState(() {});
@@ -107,7 +107,7 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
                 setState(() {});
               }
             } catch (e, s) {
-              _logger.warning('indexing failed ', e, s);
+              _logger.warning('Remote fetch toggle failed ', e, s);
               await showGenericErrorDialog(context: context, error: e);
             }
           },
@@ -115,22 +115,25 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
         sectionOptionSpacing,
         MenuItemWidget(
           captionedTextWidget: CaptionedTextWidget(
-            title: FaceMlService.instance.canRunMLController
-                ? "canRunML enabled"
-                : "canRunML disabled",
+            title: FaceMlService.instance.debugIndexingDisabled
+                ? "Debug enable indexing again"
+                : "Debug disable indexing",
           ),
           pressedColor: getEnteColorScheme(context).fillFaint,
           trailingIcon: Icons.chevron_right_outlined,
           trailingIconIsMuted: true,
           onTap: () async {
             try {
-              FaceMlService.instance.canRunMLController =
-                  !FaceMlService.instance.canRunMLController;
+              FaceMlService.instance.debugIndexingDisabled =
+                  !FaceMlService.instance.debugIndexingDisabled;
+              if (FaceMlService.instance.debugIndexingDisabled) {
+                FaceMlService.instance.pauseIndexingAndClustering();
+              }
               if (mounted) {
                 setState(() {});
               }
             } catch (e, s) {
-              _logger.warning('canRunML toggle failed ', e, s);
+              _logger.warning('debugIndexingDisabled toggle failed ', e, s);
               await showGenericErrorDialog(context: context, error: e);
             }
           },
@@ -145,6 +148,7 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
           trailingIconIsMuted: true,
           onTap: () async {
             try {
+              FaceMlService.instance.debugIndexingDisabled = false;
               unawaited(FaceMlService.instance.indexAndClusterAll());
             } catch (e, s) {
               _logger.warning('indexAndClusterAll failed ', e, s);
@@ -162,6 +166,7 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
           trailingIconIsMuted: true,
           onTap: () async {
             try {
+              FaceMlService.instance.debugIndexingDisabled = false;
               unawaited(FaceMlService.instance.indexAllImages());
             } catch (e, s) {
               _logger.warning('indexing failed ', e, s);
@@ -189,6 +194,7 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
           onTap: () async {
             try {
               await PersonService.instance.storeRemoteFeedback();
+              FaceMlService.instance.debugIndexingDisabled = false;
               await FaceMlService.instance
                   .clusterAllImages(clusterInBuckets: true);
               Bus.instance.fire(PeopleChangedEvent());

+ 1 - 1
mobile/lib/ui/settings/machine_learning_settings_page.dart

@@ -208,7 +208,7 @@ class _MachineLearningSettingsPageState
               if (isEnabled) {
                 unawaited(FaceMlService.instance.ensureInitialized());
               } else {
-                FaceMlService.instance.pauseIndexing();
+                FaceMlService.instance.pauseIndexingAndClustering();
               }
               if (mounted) {
                 setState(() {});

+ 1 - 1
mobile/pubspec.yaml

@@ -12,7 +12,7 @@ description: ente photos application
 # Read more about iOS versioning at
 # https://developer.apple.com/library/archive/documentation/General/Reference/InfoPlistKeyReference/Articles/CoreFoundationKeys.html
 
-version: 0.8.98+618
+version: 0.8.101+624
 publish_to: none
 
 environment: