Ver Fonte

[mob][photos] Precompute cluster summaries incrementally during clustering

laurenspriem há 1 ano atrás
pai
commit
edf99385dc

+ 85 - 10
mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart

@@ -42,6 +42,15 @@ class FaceInfo {
 
 enum ClusterOperation { linearIncrementalClustering, dbscanClustering }
 
+class ClusteringResult {
+  final Map<String, int> newFaceIdToCluster;
+  final Map<int, (Uint8List, int)>? newClusterSummaries;
+  ClusteringResult({
+    required this.newFaceIdToCluster,
+    required this.newClusterSummaries,
+  });
+}
+
 class FaceClusteringService {
   final _logger = Logger("FaceLinearClustering");
   final _computer = Computer.shared();
@@ -191,13 +200,14 @@ class FaceClusteringService {
   /// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset.
   ///
   /// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can less less deterministic.
-  Future<Map<String, int>?> predictLinear(
+  Future<ClusteringResult?> predictLinear(
     Set<FaceInfoForClustering> input, {
     Map<int, int>? fileIDToCreationTime,
     double distanceThreshold = kRecommendedDistanceThreshold,
     double conservativeDistanceThreshold = kConservativeDistanceThreshold,
     bool useDynamicThreshold = true,
     int? offset,
+    required Map<int, (Uint8List, int)> oldClusterSummaries,
   }) async {
     if (input.isEmpty) {
       _logger.warning(
@@ -219,7 +229,7 @@ class FaceClusteringService {
       final stopwatchClustering = Stopwatch()..start();
       // final Map<String, int> faceIdToCluster =
       //     await _runLinearClusteringInComputer(input);
-      final Map<String, int> faceIdToCluster = await _runInIsolate(
+      final ClusteringResult? faceIdToCluster = await _runInIsolate(
         (
           ClusterOperation.linearIncrementalClustering,
           {
@@ -229,6 +239,7 @@ class FaceClusteringService {
             'conservativeDistanceThreshold': conservativeDistanceThreshold,
             'useDynamicThreshold': useDynamicThreshold,
             'offset': offset,
+            'oldClusterSummaries': oldClusterSummaries,
           }
         ),
       );
@@ -247,7 +258,7 @@ class FaceClusteringService {
   }
 
   /// Runs the clustering algorithm [runLinearClustering] on the given [input], in computer, without any dynamic thresholding
-  Future<Map<String, int>?> predictLinearComputer(
+  Future<ClusteringResult?> predictLinearComputer(
     Map<String, Uint8List> input, {
     Map<int, int>? fileIDToCreationTime,
     double distanceThreshold = kRecommendedDistanceThreshold,
@@ -256,7 +267,7 @@ class FaceClusteringService {
       _logger.warning(
         "Linear Clustering dataset of embeddings is empty, returning empty list.",
       );
-      return {};
+      return null;
     }
 
     // Clustering inside the isolate
@@ -290,7 +301,7 @@ class FaceClusteringService {
           "useDynamicThreshold": false,
         },
         taskName: "createImageEmbedding",
-      ) as Map<String, int>;
+      ) as ClusteringResult;
       final endTime = DateTime.now();
       _logger.info(
         "Linear Clustering took: ${endTime.difference(startTime).inMilliseconds}ms",
@@ -369,11 +380,12 @@ class FaceClusteringService {
         _logger.info(
           'Running linear clustering on ${input.length} faces with distance threshold $distanceThreshold',
         );
-        return predictLinearComputer(
+        final clusterResult = await predictLinearComputer(
           input,
           fileIDToCreationTime: fileIDToCreationTime,
           distanceThreshold: distanceThreshold,
         );
+        return clusterResult?.newFaceIdToCluster;
       }
     } catch (e, s) {
       _logger.severe(e, s);
@@ -430,7 +442,7 @@ class FaceClusteringService {
     return clusterFaceIDs;
   }
 
-  static Map<String, int> runLinearClustering(Map args) {
+  static ClusteringResult? runLinearClustering(Map args) {
     // final input = args['input'] as Map<String, (int?, Uint8List)>;
     final input = args['input'] as Set<FaceInfoForClustering>;
     final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
@@ -439,6 +451,8 @@ class FaceClusteringService {
         args['conservativeDistanceThreshold'] as double;
     final useDynamicThreshold = args['useDynamicThreshold'] as bool;
     final offset = args['offset'] as int?;
+    final oldClusterSummaries =
+        args['oldClusterSummaries'] as Map<int, (Uint8List, int)>?;
 
     log(
       "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces",
@@ -507,7 +521,7 @@ class FaceClusteringService {
     int dynamicThresholdCount = 0;
 
     if (sortedFaceInfos.isEmpty) {
-      return {};
+      return null;
     }
 
     // Start actual clustering
@@ -584,7 +598,9 @@ class FaceClusteringService {
 
     // Finally, assign the new clusterId to the faces
     final Map<String, int> newFaceIdToCluster = {};
-    for (final faceInfo in sortedFaceInfos.sublist(alreadyClusteredCount)) {
+    final newClusteredFaceInfos =
+        sortedFaceInfos.sublist(alreadyClusteredCount);
+    for (final faceInfo in newClusteredFaceInfos) {
       newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!;
     }
 
@@ -598,10 +614,69 @@ class FaceClusteringService {
       );
     }
 
+    // Now calculate the mean of the embeddings for each cluster and update the cluster summaries
+    Map<int, (Uint8List, int)>? newClusterSummaries;
+    if (oldClusterSummaries != null) {
+      newClusterSummaries = FaceClusteringService.updateClusterSummaries(
+        oldSummary: oldClusterSummaries,
+        newFaceInfos: newClusteredFaceInfos,
+      );
+    }
+
     // analyze the results
     FaceClusteringService._analyzeClusterResults(sortedFaceInfos);
 
-    return newFaceIdToCluster;
+    return ClusteringResult(
+      newFaceIdToCluster: newFaceIdToCluster,
+      newClusterSummaries: newClusterSummaries,
+    );
+  }
+
+  static Map<int, (Uint8List, int)> updateClusterSummaries({
+    required Map<int, (Uint8List, int)> oldSummary,
+    required List<FaceInfo> newFaceInfos,
+  }) {
+    final calcSummariesStart = DateTime.now();
+    final Map<int, List<FaceInfo>> newClusterIdToFaceInfos = {};
+    for (final faceInfo in newFaceInfos) {
+      if (newClusterIdToFaceInfos.containsKey(faceInfo.clusterId!)) {
+        newClusterIdToFaceInfos[faceInfo.clusterId!]!.add(faceInfo);
+      } else {
+        newClusterIdToFaceInfos[faceInfo.clusterId!] = [faceInfo];
+      }
+    }
+
+    final Map<int, (Uint8List, int)> newClusterSummaries = {};
+    for (final clusterId in newClusterIdToFaceInfos.keys) {
+      final List<Vector> newEmbeddings = newClusterIdToFaceInfos[clusterId]!
+          .map((faceInfo) => faceInfo.vEmbedding!)
+          .toList();
+      final newCount = newEmbeddings.length;
+      if (oldSummary.containsKey(clusterId)) {
+        final oldMean = Vector.fromList(
+          EVector.fromBuffer(oldSummary[clusterId]!.$1).values,
+          dtype: DType.float32,
+        );
+        final oldCount = oldSummary[clusterId]!.$2;
+        final oldEmbeddings = oldMean * oldCount;
+        newEmbeddings.add(oldEmbeddings);
+        final newMeanVector =
+            newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount);
+        newClusterSummaries[clusterId] = (
+          EVector(values: newMeanVector.toList()).writeToBuffer(),
+          oldCount + newCount
+        );
+      } else {
+        final newMeanVector = newEmbeddings.reduce((a, b) => a + b) / newCount;
+        newClusterSummaries[clusterId] =
+            (EVector(values: newMeanVector.toList()).writeToBuffer(), newCount);
+      }
+    }
+    log(
+      "[ClusterIsolate] ${DateTime.now()} Calculated cluster summaries in ${DateTime.now().difference(calcSummariesStart).inMilliseconds}ms",
+    );
+
+    return newClusterSummaries;
   }
 
   static void _analyzeClusterResults(List<FaceInfo> sortedFaceInfos) {

+ 19 - 7
mobile/lib/services/machine_learning/face_ml/face_ml_service.dart

@@ -300,6 +300,10 @@ class FaceMlService {
       // Get a sense of the total number of faces in the database
       final int totalFaces = await FaceMLDataDB.instance
           .getTotalFaceCount(minFaceScore: minFaceScore);
+
+      // Get the current cluster statistics
+      final Map<int, (Uint8List, int)> oldClusterSummaries =
+          await FaceMLDataDB.instance.getAllClusterSummary();
       if (clusterInBuckets) {
         // read the creation times from Files DB, in a map from fileID to creation time
         final fileIDToCreationTime =
@@ -332,18 +336,22 @@ class FaceMlService {
             break;
           }
 
-          final faceIdToCluster =
+          final clusteringResult =
               await FaceClusteringService.instance.predictLinear(
             faceInfoForClustering,
             fileIDToCreationTime: fileIDToCreationTime,
             offset: offset,
+            oldClusterSummaries: oldClusterSummaries,
           );
-          if (faceIdToCluster == null) {
+          if (clusteringResult == null) {
             _logger.warning("faceIdToCluster is null");
             return;
           }
 
-          await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster);
+          await FaceMLDataDB.instance
+              .updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster);
+          await FaceMLDataDB.instance
+              .clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
           _logger.info(
             'Done with clustering ${offset + faceInfoForClustering.length} embeddings (${(100 * (offset + faceInfoForClustering.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset',
           );
@@ -374,12 +382,13 @@ class FaceMlService {
             '${DateTime.now().difference(gotFaceEmbeddingsTime).inMilliseconds} ms');
 
         // Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID
-        final faceIdToCluster =
+        final clusteringResult =
             await FaceClusteringService.instance.predictLinear(
           faceInfoForClustering,
           fileIDToCreationTime: fileIDToCreationTime,
+          oldClusterSummaries: oldClusterSummaries,
         );
-        if (faceIdToCluster == null) {
+        if (clusteringResult == null) {
           _logger.warning("faceIdToCluster is null");
           return;
         }
@@ -390,9 +399,12 @@ class FaceMlService {
 
         // Store the updated clusterIDs in the database
         _logger.info(
-          'Updating ${faceIdToCluster.length} FaceIDs with clusterIDs in the DB',
+          'Updating ${clusteringResult.newFaceIdToCluster.length} FaceIDs with clusterIDs in the DB',
         );
-        await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster);
+        await FaceMLDataDB.instance
+            .updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster);
+        await FaceMLDataDB.instance
+            .clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
         _logger.info('Done updating FaceIDs with clusterIDs in the DB, in '
             '${DateTime.now().difference(clusterDoneTime).inSeconds} seconds');
       }