Prechádzať zdrojové kódy

[mob][photos] Normalize weighted embeddings for cluster summary

laurenspriem 1 rok pred
rodič
commit
b4736fb1d6

+ 5 - 3
mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart

@@ -662,14 +662,16 @@ class FaceClusteringService {
         newEmbeddings.add(oldEmbeddings);
         newEmbeddings.add(oldEmbeddings);
         final newMeanVector =
         final newMeanVector =
             newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount);
             newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount);
+        final newMeanVectorNormalized = newMeanVector / newMeanVector.norm();
         newClusterSummaries[clusterId] = (
         newClusterSummaries[clusterId] = (
-          EVector(values: newMeanVector.toList()).writeToBuffer(),
+          EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(),
           oldCount + newCount
           oldCount + newCount
         );
         );
       } else {
       } else {
-        final newMeanVector = newEmbeddings.reduce((a, b) => a + b) / newCount;
+        final newMeanVector = newEmbeddings.reduce((a, b) => a + b);
+        final newMeanVectorNormalized = newMeanVector / newMeanVector.norm();
         newClusterSummaries[clusterId] =
         newClusterSummaries[clusterId] =
-            (EVector(values: newMeanVector.toList()).writeToBuffer(), newCount);
+            (EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(), newCount);
       }
       }
     }
     }
     log(
     log(

+ 3 - 2
mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart

@@ -868,7 +868,8 @@ class ClusterFeedbackService {
         ),
         ),
       );
       );
       final avg = vectors.reduce((a, b) => a + b) / vectors.length;
       final avg = vectors.reduce((a, b) => a + b) / vectors.length;
-      final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
+      final avgNormalized = avg / avg.norm();
+      final avgEmbeddingBuffer = EVector(values: avgNormalized).writeToBuffer();
       updatesForClusterSummary[clusterID] =
       updatesForClusterSummary[clusterID] =
           (avgEmbeddingBuffer, embeddings.length);
           (avgEmbeddingBuffer, embeddings.length);
       // store the intermediate updates
       // store the intermediate updates
@@ -882,7 +883,7 @@ class ClusterFeedbackService {
           );
           );
         }
         }
       }
       }
-      clusterAvg[clusterID] = avg;
+      clusterAvg[clusterID] = avgNormalized;
     }
     }
     if (updatesForClusterSummary.isNotEmpty) {
     if (updatesForClusterSummary.isNotEmpty) {
       await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
       await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);