[mob][photos] Use SIMD in sorting suggestions too

This commit is contained in:
laurenspriem 2024-04-24 16:19:10 +05:30
parent e829f7b62f
commit 3806ee3232

View file

@ -6,15 +6,12 @@ import "package:logging/logging.dart";
import "package:ml_linalg/linalg.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/db/files_db.dart";
// import "package:photos/events/files_updated_event.dart";
// import "package:photos/events/local_photos_updated_event.dart";
import "package:photos/events/people_changed_event.dart";
import "package:photos/extensions/stop_watch.dart";
import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import "package:photos/models/file/file.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
@ -555,17 +552,22 @@ class ClusterFeedbackService {
// Take the embeddings from the person's clusters in one big list and sample from it
final List<Uint8List> personEmbeddingsProto = [];
for (final clusterID in personClusters) {
final Iterable<Uint8List> embedings =
final Iterable<Uint8List> embeddings =
await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
personEmbeddingsProto.addAll(embedings);
personEmbeddingsProto.addAll(embeddings);
}
final List<Uint8List> sampledEmbeddingsProto =
_randomSampleWithoutReplacement(
personEmbeddingsProto,
sampleSize,
);
final List<List<double>> sampledEmbeddings = sampledEmbeddingsProto
.map((embedding) => EVector.fromBuffer(embedding).values)
final List<Vector> sampledEmbeddings = sampledEmbeddingsProto
.map(
(embedding) => Vector.fromList(
EVector.fromBuffer(embedding).values,
dtype: DType.float32,
),
)
.toList(growable: false);
// Find the actual closest clusters for the person using median
@ -581,16 +583,20 @@ class ClusterFeedbackService {
otherEmbeddingsProto,
sampleSize,
);
final List<List<double>> sampledOtherEmbeddings =
sampledOtherEmbeddingsProto
.map((embedding) => EVector.fromBuffer(embedding).values)
.toList(growable: false);
final List<Vector> sampledOtherEmbeddings = sampledOtherEmbeddingsProto
.map(
(embedding) => Vector.fromList(
EVector.fromBuffer(embedding).values,
dtype: DType.float32,
),
)
.toList(growable: false);
// Calculate distances and find the median
final List<double> distances = [];
for (final otherEmbedding in sampledOtherEmbeddings) {
for (final embedding in sampledEmbeddings) {
distances.add(cosineDistForNormVectors(embedding, otherEmbedding));
distances.add(1 - embedding.dot(otherEmbedding));
}
}
distances.sort();
@ -671,8 +677,9 @@ class ClusterFeedbackService {
if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
allClusterIds.remove(id);
clusterAvg[id] = Vector.fromList(
EVector.fromBuffer(clusterToSummary[id]!.$1).values,
dtype: DType.float32,);
EVector.fromBuffer(clusterToSummary[id]!.$1).values,
dtype: DType.float32,
);
alreadyUpdatedClustersCnt++;
}
if (allClusterIdsToCountMap[id]! < minClusterSize) {
@ -738,10 +745,12 @@ class ClusterFeedbackService {
for (final clusterID in clusterEmbeddings.keys) {
final Iterable<Uint8List> embeddings = clusterEmbeddings[clusterID]!;
final Iterable<Vector> vectors = embeddings.map((e) => Vector.fromList(
EVector.fromBuffer(e).values,
dtype: DType.float32,
),);
final Iterable<Vector> vectors = embeddings.map(
(e) => Vector.fromList(
EVector.fromBuffer(e).values,
dtype: DType.float32,
),
);
final avg = vectors.reduce((a, b) => a + b) / vectors.length;
final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
updatesForClusterSummary[clusterID] =
@ -908,16 +917,16 @@ class ClusterFeedbackService {
final personEmbeddingsCount = personClusters
.map((e) => personClusterToSummary[e]!.$2)
.reduce((a, b) => a + b);
final List<double> personAvg = List.filled(192, 0);
Vector personAvg = Vector.filled(192, 0);
for (final personClusterID in personClusters) {
final personClusterBlob = personClusterToSummary[personClusterID]!.$1;
final personClusterAvg = EVector.fromBuffer(personClusterBlob).values;
final personClusterAvg = Vector.fromList(
EVector.fromBuffer(personClusterBlob).values,
dtype: DType.float32,
);
final clusterWeight =
personClusterToSummary[personClusterID]!.$2 / personEmbeddingsCount;
for (int i = 0; i < personClusterAvg.length; i++) {
personAvg[i] += personClusterAvg[i] *
clusterWeight; // Weighted sum of the cluster averages
}
personAvg += personClusterAvg * clusterWeight;
}
w?.log('calculated person avg');
@ -933,16 +942,22 @@ class ClusterFeedbackService {
final faceIdToEmbeddingMap = await faceMlDb.getFaceEmbeddingMapForFaces(
faceIDs,
);
final faceIdToVectorMap = faceIdToEmbeddingMap.map(
(key, value) => MapEntry(
key,
Vector.fromList(
EVector.fromBuffer(value).values,
dtype: DType.float32,
),
),
);
w?.log(
'got ${faceIdToEmbeddingMap.values.length} embeddings for ${suggestion.filesInCluster.length} files for cluster $clusterID',
);
final fileIdToDistanceMap = {};
for (final entry in faceIdToEmbeddingMap.entries) {
for (final entry in faceIdToVectorMap.entries) {
fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
cosineDistForNormVectors(
personAvg,
EVector.fromBuffer(entry.value).values,
);
1 - personAvg.dot(entry.value);
}
w?.log('calculated distances for cluster $clusterID');
suggestion.filesInCluster.sort((b, a) {