Pārlūkot izejas kodu

[mob][photos] Use vectors everywhere in cluster suggestion

laurenspriem 1 gadu atpakaļ
vecāks
revīzija
e829f7b62f

+ 29 - 36
mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart

@@ -3,6 +3,7 @@ import "dart:math" show Random, min;
 
 import "package:flutter/foundation.dart";
 import "package:logging/logging.dart";
+import "package:ml_linalg/linalg.dart";
 import "package:photos/core/event_bus.dart";
 import "package:photos/db/files_db.dart";
 // import "package:photos/events/files_updated_event.dart";
@@ -245,13 +246,13 @@ class ClusterFeedbackService {
     final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
     final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
     dev.log(
-      'existing clusters for ${p.data.name} are $personClusters',
+      '${p.data.name} has ${personClusters.length} existing clusters',
       name: "ClusterFeedbackService",
     );
 
     // Get and update the cluster summary to get the avg (centroid) and count
     final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
-    final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
+    final Map<int, Vector> clusterAvg = await _getUpdateClusterAvg(
       allClusterIdsToCountMap,
       ignoredClusters,
     );
@@ -466,19 +467,19 @@ class ClusterFeedbackService {
     final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
     final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
     dev.log(
-      'existing clusters for ${p.data.name} are $personClusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms',
+      '${p.data.name} has ${personClusters.length} existing clusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms',
       name: "getSuggestionsUsingMedian",
     );
 
     // First only do a simple check on the big clusters, if the person does not have small clusters yet
-    final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
+    final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
     final smallestPersonClusterSize = personClusters
         .map((clusterID) => allClusterIdsToCountMap[clusterID] ?? 0)
         .reduce((value, element) => min(value, element));
     final checkSizes = [kMinimumClusterSizeSearchResult, 20, 10, 5, 1];
     for (final minimumSize in checkSizes.toSet()) {
       if (smallestPersonClusterSize >= minimumSize) {
-        final Map<int, List<double>> clusterAvgBigClusters =
+        final Map<int, Vector> clusterAvgBigClusters =
             await _getUpdateClusterAvg(
           allClusterIdsToCountMap,
           ignoredClusters,
@@ -487,6 +488,7 @@ class ClusterFeedbackService {
         dev.log(
           'computed avg for ${clusterAvgBigClusters.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
         );
+        w?.log('Calculate avg for min size $minimumSize');
         final List<(int, double)> suggestionsMeanBigClusters =
             _calcSuggestionsMean(
           clusterAvgBigClusters,
@@ -494,6 +496,7 @@ class ClusterFeedbackService {
           ignoredClusters,
           goodMeanDistance,
         );
+        w?.log('Calculate suggestions using mean for min size $minimumSize');
         if (suggestionsMeanBigClusters.isNotEmpty) {
           return suggestionsMeanBigClusters
               .map((e) => (e.$1, e.$2, true))
@@ -501,9 +504,10 @@ class ClusterFeedbackService {
         }
       }
     }
+    w?.reset();
 
     // Get and update the cluster summary to get the avg (centroid) and count
-    final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
+    final Map<int, Vector> clusterAvg = await _getUpdateClusterAvg(
       allClusterIdsToCountMap,
       ignoredClusters,
     );
@@ -547,7 +551,7 @@ class ClusterFeedbackService {
       "Found potential suggestions from loose mean for median test: $otherClusterIdsCandidates",
     );
 
-    watch.logAndReset("Starting median test");
+    w?.logAndReset("Starting median test");
     // Take the embeddings from the person's clusters in one big list and sample from it
     final List<Uint8List> personEmbeddingsProto = [];
     for (final clusterID in personClusters) {
@@ -600,7 +604,7 @@ class ClusterFeedbackService {
         }
       }
     }
-    watch.log("Finished median test");
+    w?.log("Finished median test");
     if (suggestionsMedian.isEmpty) {
       _logger.info("No suggestions found using median");
       return [];
@@ -632,7 +636,7 @@ class ClusterFeedbackService {
     return finalSuggestionsMedian;
   }
 
-  Future<Map<int, List<double>>> _getUpdateClusterAvg(
+  Future<Map<int, Vector>> _getUpdateClusterAvg(
     Map<int, int> allClusterIdsToCountMap,
     Set<int> ignoredClusters, {
     int minClusterSize = 1,
@@ -649,7 +653,7 @@ class ClusterFeedbackService {
         await faceMlDb.getAllClusterSummary(minClusterSize);
     final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
 
-    final Map<int, List<double>> clusterAvg = {};
+    final Map<int, Vector> clusterAvg = {};
 
     dev.log(
       'getUpdateClusterAvg database call for getAllClusterSummary took ${DateTime.now().difference(startTime).inMilliseconds} ms',
@@ -666,7 +670,9 @@ class ClusterFeedbackService {
       }
       if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
         allClusterIds.remove(id);
-        clusterAvg[id] = EVector.fromBuffer(clusterToSummary[id]!.$1).values;
+        clusterAvg[id] = Vector.fromList(
+            EVector.fromBuffer(clusterToSummary[id]!.$1).values,
+            dtype: DType.float32,);
         alreadyUpdatedClustersCnt++;
       }
       if (allClusterIdsToCountMap[id]! < minClusterSize) {
@@ -731,19 +737,15 @@ class ClusterFeedbackService {
     );
 
     for (final clusterID in clusterEmbeddings.keys) {
-      late List<double> avg;
-      final Iterable<Uint8List> embedings = clusterEmbeddings[clusterID]!;
-      final List<double> sum = List.filled(192, 0);
-      for (final embedding in embedings) {
-        final data = EVector.fromBuffer(embedding).values;
-        for (int i = 0; i < sum.length; i++) {
-          sum[i] += data[i];
-        }
-      }
-      avg = sum.map((e) => e / embedings.length).toList();
-      final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer();
+      final Iterable<Uint8List> embeddings = clusterEmbeddings[clusterID]!;
+      final Iterable<Vector> vectors = embeddings.map((e) => Vector.fromList(
+            EVector.fromBuffer(e).values,
+            dtype: DType.float32,
+          ),);
+      final avg = vectors.reduce((a, b) => a + b) / vectors.length;
+      final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
       updatesForClusterSummary[clusterID] =
-          (avgEmbeedingBuffer, embedings.length);
+          (avgEmbeddingBuffer, embeddings.length);
       // store the intermediate updates
       indexedInCurrentRun++;
       if (updatesForClusterSummary.length > 100) {
@@ -770,7 +772,7 @@ class ClusterFeedbackService {
 
   /// Returns a map of person's clusterID to map of closest clusterID to with disstance
   List<(int, double)> _calcSuggestionsMean(
-    Map<int, List<double>> clusterAvg,
+    Map<int, Vector> clusterAvg,
     Set<int> personClusters,
     Set<int> ignoredClusters,
     double maxClusterDistance, {
@@ -779,23 +781,14 @@ class ClusterFeedbackService {
     final Map<int, List<(int, double)>> suggestions = {};
     int suggestionCount = 0;
     final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
-    final clusterAvgVectors = clusterAvg.map(
-      (key, value) => MapEntry(
-        key,
-        Vector.fromList(
-          value,
-          dtype: DType.float32,
-        ),
-      ),
-    );
     w?.log('converted avg to vectors for ${clusterAvg.length} averages');
-    for (final otherClusterID in clusterAvgVectors.keys) {
+    for (final otherClusterID in clusterAvg.keys) {
       // ignore the cluster that belong to the person or is ignored
       if (personClusters.contains(otherClusterID) ||
           ignoredClusters.contains(otherClusterID)) {
         continue;
       }
-      final otherAvg = clusterAvgVectors[otherClusterID]!;
+      final Vector otherAvg = clusterAvg[otherClusterID]!;
       int? nearestPersonCluster;
       double? minDistance;
       for (final personCluster in personClusters) {
@@ -803,7 +796,7 @@ class ClusterFeedbackService {
           _logger.info('no avg for cluster $personCluster');
           continue;
         }
-        final avg = clusterAvgVectors[personCluster]!;
+        final Vector avg = clusterAvg[personCluster]!;
         final distance = 1 - avg.dot(otherAvg);
         if (distance < maxClusterDistance) {
           if (minDistance == null || distance < minDistance) {