Browse Source

[mob][photos] Use cosineDistanceSIMD

laurenspriem 1 year ago
parent
commit
462d1d4854

+ 13 - 0
mobile/lib/services/machine_learning/face_ml/face_clustering/cosine_distance.dart

@@ -1,5 +1,18 @@
 import 'dart:math' show sqrt;
 
+import "package:ml_linalg/vector.dart";
+
+/// Returns the cosine distance (one minus the dot product) of two embeddings,
+/// using SIMD from ml_linalg for the dot product.
+///
+/// WARNING: This assumes both vectors are already normalized! Throws an
+/// [ArgumentError] if the two vectors differ in length.
+double cosineDistanceSIMD(Vector vector1, Vector vector2) {
+  final sameLength = vector1.length == vector2.length;
+  if (!sameLength) throw ArgumentError('Vectors must be the same length');
+  return 1 - vector1.dot(vector2);
+}
+
 /// Calculates the cosine distance between two embeddings/vectors.
 ///
 /// Throws an ArgumentError if the vectors are of different lengths or

+ 12 - 10
mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart

@@ -560,10 +560,10 @@ class FaceClusteringService {
       for (int j = i - 1; j >= 0; j--) {
         late double distance;
         if (sortedFaceInfos[i].vEmbedding != null) {
-          distance = 1.0 -
-              sortedFaceInfos[i]
-                  .vEmbedding!
-                  .dot(sortedFaceInfos[j].vEmbedding!);
+          distance = cosineDistanceSIMD(
+            sortedFaceInfos[i].vEmbedding!,
+            sortedFaceInfos[j].vEmbedding!,
+          );
         } else {
           distance = cosineDistForNormVectors(
             sortedFaceInfos[i].embedding!,
@@ -804,8 +804,10 @@ class FaceClusteringService {
       double closestDistance = double.infinity;
       for (int j = 0; j < totalFaces; j++) {
         if (i == j) continue;
-        final double distance =
-            1.0 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!);
+        final double distance = cosineDistanceSIMD(
+          faceInfos[i].vEmbedding!,
+          faceInfos[j].vEmbedding!,
+        );
         if (distance < closestDistance) {
           closestDistance = distance;
           closestIdx = j;
@@ -855,10 +857,10 @@ class FaceClusteringService {
       for (int i = 0; i < clusterIds.length; i++) {
         for (int j = 0; j < clusterIds.length; j++) {
           if (i == j) continue;
-          final double newDistance = 1.0 -
-              clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1.dot(
-                    clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
-                  );
+          final double newDistance = cosineDistanceSIMD(
+            clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1,
+            clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
+          );
           if (newDistance < distance) {
             distance = newDistance;
             clusterIDsToMerge = (clusterIds[i], clusterIds[j]);

+ 4 - 3
mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart

@@ -12,6 +12,7 @@ import "package:photos/face/db.dart";
 import "package:photos/face/model/person.dart";
 import "package:photos/generated/protos/ente/common/vector.pb.dart";
 import "package:photos/models/file/file.dart";
+import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
 import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
 import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
 import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
@@ -594,7 +595,7 @@ class ClusterFeedbackService {
       final List<double> distances = [];
       for (final otherEmbedding in sampledOtherEmbeddings) {
         for (final embedding in sampledEmbeddings) {
-          distances.add(1 - embedding.dot(otherEmbedding));
+          distances.add(cosineDistanceSIMD(embedding,otherEmbedding));
         }
       }
       distances.sort();
@@ -799,7 +800,7 @@ class ClusterFeedbackService {
           continue;
         }
         final Vector avg = clusterAvg[personCluster]!;
-        final distance = 1 - avg.dot(otherAvg);
+        final distance = cosineDistanceSIMD(avg,otherAvg);
         if (distance < maxClusterDistance) {
           if (minDistance == null || distance < minDistance) {
             minDistance = distance;
@@ -950,7 +951,7 @@ class ClusterFeedbackService {
       final fileIdToDistanceMap = {};
       for (final entry in faceIdToVectorMap.entries) {
         fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
-            1 - personAvg.dot(entry.value);
+            cosineDistanceSIMD(personAvg,entry.value);
       }
       w?.log('calculated distances for cluster $clusterID');
       suggestion.filesInCluster.sort((b, a) {

+ 8 - 8
mobile/lib/ui/viewer/people/cluster_app_bar.dart

@@ -207,14 +207,14 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
         if (embedding.key == otherEmbedding.key) {
           continue;
         }
-        final distance64 = 1.0 -
-            Vector.fromList(embedding.value, dtype: DType.float64).dot(
-              Vector.fromList(otherEmbedding.value, dtype: DType.float64),
-            );
-        final distance32 = 1.0 -
-            Vector.fromList(embedding.value, dtype: DType.float32).dot(
-              Vector.fromList(otherEmbedding.value, dtype: DType.float32),
-            );
+        final distance64 = cosineDistanceSIMD(
+          Vector.fromList(embedding.value, dtype: DType.float64),
+          Vector.fromList(otherEmbedding.value, dtype: DType.float64),
+        );
+        final distance32 = cosineDistanceSIMD(
+          Vector.fromList(embedding.value, dtype: DType.float32),
+          Vector.fromList(otherEmbedding.value, dtype: DType.float32),
+        );
         final distance = cosineDistForNormVectors(
           embedding.value,
           otherEmbedding.value,