Ver código fonte

[mob] Add merges to predictComplete method

laurenspriem 1 ano atrás
pai
commit
ba58ac1358

+ 83 - 8
mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart

@@ -249,6 +249,7 @@ class FaceClusteringService {
     Map<String, Uint8List> input, {
     Map<int, int>? fileIDToCreationTime,
     double distanceThreshold = kRecommendedDistanceThreshold,
+    double mergeThreshold = 0.30,
   }) async {
     if (input.isEmpty) {
       _logger.warning(
@@ -270,6 +271,7 @@ class FaceClusteringService {
           "input": input,
           "fileIDToCreationTime": fileIDToCreationTime,
           "distanceThreshold": distanceThreshold,
+          "mergeThreshold": mergeThreshold,
         },
         taskName: "createImageEmbedding",
       ) as Map<String, int>;
@@ -578,12 +580,11 @@ class FaceClusteringService {
     );
   }
 
-  
-
   static Map<String, int> runCompleteClustering(Map args) {
     final input = args['input'] as Map<String, Uint8List>;
     final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
     final distanceThreshold = args['distanceThreshold'] as double;
+    final mergeThreshold = args['mergeThreshold'] as double;
 
     log(
       "[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering",
@@ -637,10 +638,11 @@ class FaceClusteringService {
     final Map<String, int> newFaceIdToCluster = {};
     final stopwatchClustering = Stopwatch()..start();
     for (int i = 0; i < totalFaces; i++) {
+      if (faceInfos[i].clusterId != null) continue;
       int closestIdx = -1;
       double closestDistance = double.infinity;
-      if (i % 250 == 0) {
-        log("[CompleteClustering] ${DateTime.now()} Processed $i faces");
+      if (i + 1 % 250 == 0) {
+        log("[CompleteClustering] ${DateTime.now()} Processed ${i - 1} faces");
       }
       for (int j = 0; j < totalFaces; j++) {
         if (i == j) continue;
@@ -656,18 +658,91 @@ class FaceClusteringService {
         if (faceInfos[closestIdx].clusterId == null) {
           clusterID++;
           faceInfos[closestIdx].clusterId = clusterID;
-          newFaceIdToCluster[faceInfos[closestIdx].faceID] = clusterID;
         }
         faceInfos[i].clusterId = faceInfos[closestIdx].clusterId!;
-        newFaceIdToCluster[faceInfos[i].faceID] =
-            faceInfos[closestIdx].clusterId!;
       } else {
         clusterID++;
         faceInfos[i].clusterId = clusterID;
-        newFaceIdToCluster[faceInfos[i].faceID] = clusterID;
       }
     }
 
+    // Now calculate the mean of the embeddings for each cluster
+    final Map<int, List<FaceInfo>> clusterIdToFaceInfos = {};
+    for (final faceInfo in faceInfos) {
+      if (clusterIdToFaceInfos.containsKey(faceInfo.clusterId)) {
+        clusterIdToFaceInfos[faceInfo.clusterId]!.add(faceInfo);
+      } else {
+        clusterIdToFaceInfos[faceInfo.clusterId!] = [faceInfo];
+      }
+    }
+    final Map<int, (Vector, int)> clusterIdToMeanEmbeddingAndWeight = {};
+    for (final clusterId in clusterIdToFaceInfos.keys) {
+      final List<Vector> embeddings = clusterIdToFaceInfos[clusterId]!
+          .map((faceInfo) => faceInfo.vEmbedding!)
+          .toList();
+      final count = clusterIdToFaceInfos[clusterId]!.length;
+      final Vector meanEmbedding = embeddings.reduce((a, b) => a + b) / count;
+      clusterIdToMeanEmbeddingAndWeight[clusterId] = (meanEmbedding, count);
+    }
+
+    // Now merge the clusters that are close to each other, based on mean embedding
+    final List<(int, int)> mergedClustersList = [];
+    final List<int> clusterIds =
+        clusterIdToMeanEmbeddingAndWeight.keys.toList();
+    log(' [CompleteClustering] ${DateTime.now()} ${clusterIds.length} clusters found, now checking for merges');
+    while (true) {
+      if (clusterIds.length < 2) break;
+      double distance = double.infinity;
+      (int, int) clusterIDsToMerge = (-1, -1);
+      for (int i = 0; i < clusterIds.length; i++) {
+        for (int j = 0; j < clusterIds.length; j++) {
+          if (i == j) continue;
+          final double newDistance = 1.0 -
+              clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1.dot(
+                    clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
+                  );
+          if (newDistance < distance) {
+            distance = newDistance;
+            clusterIDsToMerge = (clusterIds[i], clusterIds[j]);
+          }
+        }
+      }
+      if (distance < mergeThreshold) {
+        mergedClustersList.add(clusterIDsToMerge);
+        final clusterID1 = clusterIDsToMerge.$1;
+        final clusterID2 = clusterIDsToMerge.$2;
+        final mean1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$1;
+        final mean2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$1;
+        final count1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$2;
+        final count2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$2;
+        final weight1 = count1 / (count1 + count2);
+        final weight2 = count2 / (count1 + count2);
+        clusterIdToMeanEmbeddingAndWeight[clusterID1] = (
+          mean1 * weight1 + mean2 * weight2,
+          count1 + count2,
+        );
+        clusterIdToMeanEmbeddingAndWeight.remove(clusterID2);
+        clusterIds.remove(clusterID2);
+      } else {
+        break;
+      }
+    }
+    log(' [CompleteClustering] ${DateTime.now()} ${mergedClustersList.length} clusters merged');
+
+    // Now assign the new clusterId to the faces
+    for (final faceInfo in faceInfos) {
+      for (final mergedClusters in mergedClustersList) {
+        if (faceInfo.clusterId == mergedClusters.$2) {
+          faceInfo.clusterId = mergedClusters.$1;
+        }
+      }
+    }
+
+    // Finally, record each face's final clusterId in the result map
+    for (final faceInfo in faceInfos) {
+      newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!;
+    }
+
     stopwatchClustering.stop();
     log(
       ' [CompleteClustering] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms',

+ 25 - 1
mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart

@@ -13,6 +13,8 @@ import "package:photos/generated/protos/ente/common/vector.pb.dart";
 import "package:photos/models/file/file.dart";
 import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
 import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
+// import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
+// import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
 import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
 import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
 import "package:photos/services/search_service.dart";
@@ -232,14 +234,36 @@ class ClusterFeedbackService {
         maxClusterID++;
       }
     } else {
+      // final clusteringInput = embeddings
+      //     .map((key, value) {
+      //       return MapEntry(
+      //         key,
+      //         FaceInfoForClustering(
+      //           faceID: key,
+      //           embeddingBytes: value,
+      //           faceScore: kMinHighQualityFaceScore + 0.01,
+      //           blurValue: kLapacianDefault,
+      //         ),
+      //       );
+      //     })
+      //     .values
+      //     .toSet();
+      // final faceIdToCluster =
+      //     await FaceClusteringService.instance.predictLinear(
+      //   clusteringInput,
+      //   fileIDToCreationTime: fileIDToCreationTime,
+      //   distanceThreshold: 0.23,
+      //   useDynamicThreshold: false,
+      // );
       final faceIdToCluster =
           await FaceClusteringService.instance.predictComplete(
         embeddings,
         fileIDToCreationTime: fileIDToCreationTime,
         distanceThreshold: 0.30,
+        mergeThreshold: 0.30,
       );
 
-      if (faceIdToCluster.isEmpty) {
+      if (faceIdToCluster == null || faceIdToCluster.isEmpty) {
         _logger.info('No clusters found');
         return {};
       } else {