[mob][photos] Use SIMD in sorting suggestions too
This commit is contained in:
parent
e829f7b62f
commit
3806ee3232
1 changed file with 44 additions and 29 deletions
|
@@ -6,15 +6,12 @@ import "package:logging/logging.dart";
|
|||
import "package:ml_linalg/linalg.dart";
|
||||
import "package:photos/core/event_bus.dart";
|
||||
import "package:photos/db/files_db.dart";
|
||||
// import "package:photos/events/files_updated_event.dart";
|
||||
// import "package:photos/events/local_photos_updated_event.dart";
|
||||
import "package:photos/events/people_changed_event.dart";
|
||||
import "package:photos/extensions/stop_watch.dart";
|
||||
import "package:photos/face/db.dart";
|
||||
import "package:photos/face/model/person.dart";
|
||||
import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
||||
import "package:photos/models/file/file.dart";
|
||||
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
|
||||
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
||||
|
@@ -555,17 +552,22 @@ class ClusterFeedbackService {
|
|||
// Take the embeddings from the person's clusters in one big list and sample from it
|
||||
final List<Uint8List> personEmbeddingsProto = [];
|
||||
for (final clusterID in personClusters) {
|
||||
final Iterable<Uint8List> embedings =
|
||||
final Iterable<Uint8List> embeddings =
|
||||
await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
|
||||
personEmbeddingsProto.addAll(embedings);
|
||||
personEmbeddingsProto.addAll(embeddings);
|
||||
}
|
||||
final List<Uint8List> sampledEmbeddingsProto =
|
||||
_randomSampleWithoutReplacement(
|
||||
personEmbeddingsProto,
|
||||
sampleSize,
|
||||
);
|
||||
final List<List<double>> sampledEmbeddings = sampledEmbeddingsProto
|
||||
.map((embedding) => EVector.fromBuffer(embedding).values)
|
||||
final List<Vector> sampledEmbeddings = sampledEmbeddingsProto
|
||||
.map(
|
||||
(embedding) => Vector.fromList(
|
||||
EVector.fromBuffer(embedding).values,
|
||||
dtype: DType.float32,
|
||||
),
|
||||
)
|
||||
.toList(growable: false);
|
||||
|
||||
// Find the actual closest clusters for the person using median
|
||||
|
@@ -581,16 +583,20 @@ class ClusterFeedbackService {
|
|||
otherEmbeddingsProto,
|
||||
sampleSize,
|
||||
);
|
||||
final List<List<double>> sampledOtherEmbeddings =
|
||||
sampledOtherEmbeddingsProto
|
||||
.map((embedding) => EVector.fromBuffer(embedding).values)
|
||||
.toList(growable: false);
|
||||
final List<Vector> sampledOtherEmbeddings = sampledOtherEmbeddingsProto
|
||||
.map(
|
||||
(embedding) => Vector.fromList(
|
||||
EVector.fromBuffer(embedding).values,
|
||||
dtype: DType.float32,
|
||||
),
|
||||
)
|
||||
.toList(growable: false);
|
||||
|
||||
// Calculate distances and find the median
|
||||
final List<double> distances = [];
|
||||
for (final otherEmbedding in sampledOtherEmbeddings) {
|
||||
for (final embedding in sampledEmbeddings) {
|
||||
distances.add(cosineDistForNormVectors(embedding, otherEmbedding));
|
||||
distances.add(1 - embedding.dot(otherEmbedding));
|
||||
}
|
||||
}
|
||||
distances.sort();
|
||||
|
@@ -671,8 +677,9 @@ class ClusterFeedbackService {
|
|||
if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
|
||||
allClusterIds.remove(id);
|
||||
clusterAvg[id] = Vector.fromList(
|
||||
EVector.fromBuffer(clusterToSummary[id]!.$1).values,
|
||||
dtype: DType.float32,);
|
||||
EVector.fromBuffer(clusterToSummary[id]!.$1).values,
|
||||
dtype: DType.float32,
|
||||
);
|
||||
alreadyUpdatedClustersCnt++;
|
||||
}
|
||||
if (allClusterIdsToCountMap[id]! < minClusterSize) {
|
||||
|
@@ -738,10 +745,12 @@ class ClusterFeedbackService {
|
|||
|
||||
for (final clusterID in clusterEmbeddings.keys) {
|
||||
final Iterable<Uint8List> embeddings = clusterEmbeddings[clusterID]!;
|
||||
final Iterable<Vector> vectors = embeddings.map((e) => Vector.fromList(
|
||||
EVector.fromBuffer(e).values,
|
||||
dtype: DType.float32,
|
||||
),);
|
||||
final Iterable<Vector> vectors = embeddings.map(
|
||||
(e) => Vector.fromList(
|
||||
EVector.fromBuffer(e).values,
|
||||
dtype: DType.float32,
|
||||
),
|
||||
);
|
||||
final avg = vectors.reduce((a, b) => a + b) / vectors.length;
|
||||
final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
|
||||
updatesForClusterSummary[clusterID] =
|
||||
|
@@ -908,16 +917,16 @@ class ClusterFeedbackService {
|
|||
final personEmbeddingsCount = personClusters
|
||||
.map((e) => personClusterToSummary[e]!.$2)
|
||||
.reduce((a, b) => a + b);
|
||||
final List<double> personAvg = List.filled(192, 0);
|
||||
Vector personAvg = Vector.filled(192, 0);
|
||||
for (final personClusterID in personClusters) {
|
||||
final personClusterBlob = personClusterToSummary[personClusterID]!.$1;
|
||||
final personClusterAvg = EVector.fromBuffer(personClusterBlob).values;
|
||||
final personClusterAvg = Vector.fromList(
|
||||
EVector.fromBuffer(personClusterBlob).values,
|
||||
dtype: DType.float32,
|
||||
);
|
||||
final clusterWeight =
|
||||
personClusterToSummary[personClusterID]!.$2 / personEmbeddingsCount;
|
||||
for (int i = 0; i < personClusterAvg.length; i++) {
|
||||
personAvg[i] += personClusterAvg[i] *
|
||||
clusterWeight; // Weighted sum of the cluster averages
|
||||
}
|
||||
personAvg += personClusterAvg * clusterWeight;
|
||||
}
|
||||
w?.log('calculated person avg');
|
||||
|
||||
|
@@ -933,16 +942,22 @@ class ClusterFeedbackService {
|
|||
final faceIdToEmbeddingMap = await faceMlDb.getFaceEmbeddingMapForFaces(
|
||||
faceIDs,
|
||||
);
|
||||
final faceIdToVectorMap = faceIdToEmbeddingMap.map(
|
||||
(key, value) => MapEntry(
|
||||
key,
|
||||
Vector.fromList(
|
||||
EVector.fromBuffer(value).values,
|
||||
dtype: DType.float32,
|
||||
),
|
||||
),
|
||||
);
|
||||
w?.log(
|
||||
'got ${faceIdToEmbeddingMap.values.length} embeddings for ${suggestion.filesInCluster.length} files for cluster $clusterID',
|
||||
);
|
||||
final fileIdToDistanceMap = {};
|
||||
for (final entry in faceIdToEmbeddingMap.entries) {
|
||||
for (final entry in faceIdToVectorMap.entries) {
|
||||
fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
|
||||
cosineDistForNormVectors(
|
||||
personAvg,
|
||||
EVector.fromBuffer(entry.value).values,
|
||||
);
|
||||
1 - personAvg.dot(entry.value);
|
||||
}
|
||||
w?.log('calculated distances for cluster $clusterID');
|
||||
suggestion.filesInCluster.sort((b, a) {
|
||||
|
|
Loading…
Add table
Reference in a new issue