Browse Source

[mob][photos] Use SIMD in sorting suggestions too

laurenspriem 1 year ago
parent
commit
3806ee3232

+ 44 - 29
mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart

@@ -6,15 +6,12 @@ import "package:logging/logging.dart";
 import "package:ml_linalg/linalg.dart";
 import "package:photos/core/event_bus.dart";
 import "package:photos/db/files_db.dart";
-// import "package:photos/events/files_updated_event.dart";
-// import "package:photos/events/local_photos_updated_event.dart";
 import "package:photos/events/people_changed_event.dart";
 import "package:photos/extensions/stop_watch.dart";
 import "package:photos/face/db.dart";
 import "package:photos/face/model/person.dart";
 import "package:photos/generated/protos/ente/common/vector.pb.dart";
 import "package:photos/models/file/file.dart";
-import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
 import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
 import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
 import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
@@ -555,17 +552,22 @@ class ClusterFeedbackService {
     // Take the embeddings from the person's clusters in one big list and sample from it
     final List<Uint8List> personEmbeddingsProto = [];
     for (final clusterID in personClusters) {
-      final Iterable<Uint8List> embedings =
+      final Iterable<Uint8List> embeddings =
           await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
-      personEmbeddingsProto.addAll(embedings);
+      personEmbeddingsProto.addAll(embeddings);
     }
     final List<Uint8List> sampledEmbeddingsProto =
         _randomSampleWithoutReplacement(
       personEmbeddingsProto,
       sampleSize,
     );
-    final List<List<double>> sampledEmbeddings = sampledEmbeddingsProto
-        .map((embedding) => EVector.fromBuffer(embedding).values)
+    final List<Vector> sampledEmbeddings = sampledEmbeddingsProto
+        .map(
+          (embedding) => Vector.fromList(
+            EVector.fromBuffer(embedding).values,
+            dtype: DType.float32,
+          ),
+        )
         .toList(growable: false);
 
     // Find the actual closest clusters for the person using median
@@ -581,16 +583,20 @@ class ClusterFeedbackService {
         otherEmbeddingsProto,
         sampleSize,
       );
-      final List<List<double>> sampledOtherEmbeddings =
-          sampledOtherEmbeddingsProto
-              .map((embedding) => EVector.fromBuffer(embedding).values)
-              .toList(growable: false);
+      final List<Vector> sampledOtherEmbeddings = sampledOtherEmbeddingsProto
+          .map(
+            (embedding) => Vector.fromList(
+              EVector.fromBuffer(embedding).values,
+              dtype: DType.float32,
+            ),
+          )
+          .toList(growable: false);
 
       // Calculate distances and find the median
       final List<double> distances = [];
       for (final otherEmbedding in sampledOtherEmbeddings) {
         for (final embedding in sampledEmbeddings) {
-          distances.add(cosineDistForNormVectors(embedding, otherEmbedding));
+          distances.add(1 - embedding.dot(otherEmbedding));
         }
       }
       distances.sort();
@@ -671,8 +677,9 @@ class ClusterFeedbackService {
       if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
         allClusterIds.remove(id);
         clusterAvg[id] = Vector.fromList(
-            EVector.fromBuffer(clusterToSummary[id]!.$1).values,
-            dtype: DType.float32,);
+          EVector.fromBuffer(clusterToSummary[id]!.$1).values,
+          dtype: DType.float32,
+        );
         alreadyUpdatedClustersCnt++;
       }
       if (allClusterIdsToCountMap[id]! < minClusterSize) {
@@ -738,10 +745,12 @@ class ClusterFeedbackService {
 
     for (final clusterID in clusterEmbeddings.keys) {
       final Iterable<Uint8List> embeddings = clusterEmbeddings[clusterID]!;
-      final Iterable<Vector> vectors = embeddings.map((e) => Vector.fromList(
-            EVector.fromBuffer(e).values,
-            dtype: DType.float32,
-          ),);
+      final Iterable<Vector> vectors = embeddings.map(
+        (e) => Vector.fromList(
+          EVector.fromBuffer(e).values,
+          dtype: DType.float32,
+        ),
+      );
       final avg = vectors.reduce((a, b) => a + b) / vectors.length;
       final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
       updatesForClusterSummary[clusterID] =
@@ -908,16 +917,16 @@ class ClusterFeedbackService {
     final personEmbeddingsCount = personClusters
         .map((e) => personClusterToSummary[e]!.$2)
         .reduce((a, b) => a + b);
-    final List<double> personAvg = List.filled(192, 0);
+    Vector personAvg = Vector.filled(192, 0);
     for (final personClusterID in personClusters) {
       final personClusterBlob = personClusterToSummary[personClusterID]!.$1;
-      final personClusterAvg = EVector.fromBuffer(personClusterBlob).values;
+      final personClusterAvg = Vector.fromList(
+        EVector.fromBuffer(personClusterBlob).values,
+        dtype: DType.float32,
+      );
       final clusterWeight =
           personClusterToSummary[personClusterID]!.$2 / personEmbeddingsCount;
-      for (int i = 0; i < personClusterAvg.length; i++) {
-        personAvg[i] += personClusterAvg[i] *
-            clusterWeight; // Weighted sum of the cluster averages
-      }
+      personAvg += personClusterAvg * clusterWeight;
     }
     w?.log('calculated person avg');
 
@@ -933,16 +942,22 @@ class ClusterFeedbackService {
       final faceIdToEmbeddingMap = await faceMlDb.getFaceEmbeddingMapForFaces(
         faceIDs,
       );
+      final faceIdToVectorMap = faceIdToEmbeddingMap.map(
+        (key, value) => MapEntry(
+          key,
+          Vector.fromList(
+            EVector.fromBuffer(value).values,
+            dtype: DType.float32,
+          ),
+        ),
+      );
       w?.log(
         'got ${faceIdToEmbeddingMap.values.length} embeddings for ${suggestion.filesInCluster.length} files for cluster $clusterID',
       );
       final fileIdToDistanceMap = {};
-      for (final entry in faceIdToEmbeddingMap.entries) {
+      for (final entry in faceIdToVectorMap.entries) {
         fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
-            cosineDistForNormVectors(
-          personAvg,
-          EVector.fromBuffer(entry.value).values,
-        );
+            1 - personAvg.dot(entry.value);
       }
       w?.log('calculated distances for cluster $clusterID');
       suggestion.filesInCluster.sort((b, a) {