Przeglądaj źródła

[mob][photos] Assert that embeddings are always normalized

laurenspriem 1 rok temu
rodzic
commit
bd495c3860

+ 12 - 2
mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart

@@ -482,6 +482,14 @@ class FaceClusteringService {
       );
       );
     }
     }
 
 
+    // Debug-only check: every non-null embedding must have unit norm (within 1e-5)
+    for (final faceInfo in faceInfos) {
+      if (faceInfo.vEmbedding != null) {
+        final norm = faceInfo.vEmbedding!.norm();
+        assert((norm - 1.0).abs() < 1e-5);
+      }
+    }
+
     // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
     // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
     if (fileIDToCreationTime != null) {
     if (fileIDToCreationTime != null) {
       faceInfos.sort((a, b) {
       faceInfos.sort((a, b) {
@@ -670,8 +678,10 @@ class FaceClusteringService {
       } else {
       } else {
         final newMeanVector = newEmbeddings.reduce((a, b) => a + b);
         final newMeanVector = newEmbeddings.reduce((a, b) => a + b);
         final newMeanVectorNormalized = newMeanVector / newMeanVector.norm();
         final newMeanVectorNormalized = newMeanVector / newMeanVector.norm();
-        newClusterSummaries[clusterId] =
-            (EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(), newCount);
+        newClusterSummaries[clusterId] = (
+          EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(),
+          newCount
+        );
       }
       }
     }
     }
     log(
     log(

+ 5 - 0
mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart

@@ -808,6 +808,11 @@ class ClusterFeedbackService {
     final alreadyUpdatedClustersCnt = serializationEmbeddings.$4;
     final alreadyUpdatedClustersCnt = serializationEmbeddings.$4;
     final smallerClustersCnt = serializationEmbeddings.$5;
     final smallerClustersCnt = serializationEmbeddings.$5;
 
 
+    // Debug-only check: every existing cluster average must have unit norm (within 1e-5)
+    for (final avg in clusterAvg.values) {
+      assert((avg.norm() - 1.0).abs() < 1e-5);
+    }
+
     w?.log(
     w?.log(
       'serialization of embeddings',
       'serialization of embeddings',
     );
     );