Selaa lähdekoodia

[mob][photos] Correct suggestion logic again

laurenspriem 1 vuosi sitten
vanhempi
commit
e0fbb2620b

+ 7 - 28
mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart

@@ -455,7 +455,7 @@ class ClusterFeedbackService {
     double maxMedianDistance = 0.62,
     double goodMedianDistance = 0.55,
     double maxMeanDistance = 0.65,
-    double goodMeanDistance = 0.54,
+    double goodMeanDistance = 0.50,
   }) async {
     final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
     // Get all the cluster data
@@ -472,10 +472,10 @@ class ClusterFeedbackService {
         .map((clusterID) => allClusterIdsToCountMap[clusterID] ?? 0)
         .reduce((value, element) => min(value, element));
     final checkSizes = [kMinimumClusterSizeSearchResult, 20, 10, 5, 1];
+    late final Map<int, Vector> clusterAvgBigClusters;
     for (final minimumSize in checkSizes.toSet()) {
       if (smallestPersonClusterSize >= minimumSize) {
-        final Map<int, Vector> clusterAvgBigClusters =
-            await _getUpdateClusterAvg(
+        clusterAvgBigClusters = await _getUpdateClusterAvg(
           allClusterIdsToCountMap,
           ignoredClusters,
           minClusterSize: minimumSize,
@@ -502,29 +502,8 @@ class ClusterFeedbackService {
     }
     w?.reset();
 
-    // Get and update the cluster summary to get the avg (centroid) and count
-    final Map<int, Vector> clusterAvg = await _getUpdateClusterAvg(
-      allClusterIdsToCountMap,
-      ignoredClusters,
-    );
-    w?.log(
-      'computed avg for ${clusterAvg.length} clusters,',
-    );
-
-    // Find the other cluster candidates based on the mean
-    final List<(int, double)> suggestionsMean = _calcSuggestionsMean(
-      clusterAvg,
-      personClusters,
-      ignoredClusters,
-      goodMeanDistance,
-    );
-    if (suggestionsMean.isNotEmpty) {
-      return suggestionsMean
-          .map((e) => (e.$1, e.$2, true))
-          .toList(growable: false);
-    }
-
     // Find the other cluster candidates based on the median
+    final clusterAvg = clusterAvgBigClusters;
     final List<(int, double)> moreSuggestionsMean = _calcSuggestionsMean(
       clusterAvg,
       personClusters,
@@ -595,7 +574,7 @@ class ClusterFeedbackService {
       final List<double> distances = [];
       for (final otherEmbedding in sampledOtherEmbeddings) {
         for (final embedding in sampledEmbeddings) {
-          distances.add(cosineDistanceSIMD(embedding,otherEmbedding));
+          distances.add(cosineDistanceSIMD(embedding, otherEmbedding));
         }
       }
       distances.sort();
@@ -800,7 +779,7 @@ class ClusterFeedbackService {
           continue;
         }
         final Vector avg = clusterAvg[personCluster]!;
-        final distance = cosineDistanceSIMD(avg,otherAvg);
+        final distance = cosineDistanceSIMD(avg, otherAvg);
         if (distance < maxClusterDistance) {
           if (minDistance == null || distance < minDistance) {
             minDistance = distance;
@@ -951,7 +930,7 @@ class ClusterFeedbackService {
       final fileIdToDistanceMap = {};
       for (final entry in faceIdToVectorMap.entries) {
         fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
-            cosineDistanceSIMD(personAvg,entry.value);
+            cosineDistanceSIMD(personAvg, entry.value);
       }
       w?.log('calculated distances for cluster $clusterID');
       suggestion.filesInCluster.sort((b, a) {