|
@@ -6,15 +6,12 @@ import "package:logging/logging.dart";
|
|
|
import "package:ml_linalg/linalg.dart";
|
|
|
import "package:photos/core/event_bus.dart";
|
|
|
import "package:photos/db/files_db.dart";
|
|
|
-// import "package:photos/events/files_updated_event.dart";
|
|
|
-// import "package:photos/events/local_photos_updated_event.dart";
|
|
|
import "package:photos/events/people_changed_event.dart";
|
|
|
import "package:photos/extensions/stop_watch.dart";
|
|
|
import "package:photos/face/db.dart";
|
|
|
import "package:photos/face/model/person.dart";
|
|
|
import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
|
|
import "package:photos/models/file/file.dart";
|
|
|
-import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
|
|
|
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
|
|
|
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
|
|
|
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
|
@@ -555,17 +552,22 @@ class ClusterFeedbackService {
|
|
|
// Take the embeddings from the person's clusters in one big list and sample from it
|
|
|
final List<Uint8List> personEmbeddingsProto = [];
|
|
|
for (final clusterID in personClusters) {
|
|
|
- final Iterable<Uint8List> embedings =
|
|
|
+ final Iterable<Uint8List> embeddings =
|
|
|
await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
|
|
|
- personEmbeddingsProto.addAll(embedings);
|
|
|
+ personEmbeddingsProto.addAll(embeddings);
|
|
|
}
|
|
|
final List<Uint8List> sampledEmbeddingsProto =
|
|
|
_randomSampleWithoutReplacement(
|
|
|
personEmbeddingsProto,
|
|
|
sampleSize,
|
|
|
);
|
|
|
- final List<List<double>> sampledEmbeddings = sampledEmbeddingsProto
|
|
|
- .map((embedding) => EVector.fromBuffer(embedding).values)
|
|
|
+ final List<Vector> sampledEmbeddings = sampledEmbeddingsProto
|
|
|
+ .map(
|
|
|
+ (embedding) => Vector.fromList(
|
|
|
+ EVector.fromBuffer(embedding).values,
|
|
|
+ dtype: DType.float32,
|
|
|
+ ),
|
|
|
+ )
|
|
|
.toList(growable: false);
|
|
|
|
|
|
// Find the actual closest clusters for the person using median
|
|
@@ -581,16 +583,20 @@ class ClusterFeedbackService {
|
|
|
otherEmbeddingsProto,
|
|
|
sampleSize,
|
|
|
);
|
|
|
- final List<List<double>> sampledOtherEmbeddings =
|
|
|
- sampledOtherEmbeddingsProto
|
|
|
- .map((embedding) => EVector.fromBuffer(embedding).values)
|
|
|
- .toList(growable: false);
|
|
|
+ final List<Vector> sampledOtherEmbeddings = sampledOtherEmbeddingsProto
|
|
|
+ .map(
|
|
|
+ (embedding) => Vector.fromList(
|
|
|
+ EVector.fromBuffer(embedding).values,
|
|
|
+ dtype: DType.float32,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ .toList(growable: false);
|
|
|
|
|
|
// Calculate distances and find the median
|
|
|
final List<double> distances = [];
|
|
|
for (final otherEmbedding in sampledOtherEmbeddings) {
|
|
|
for (final embedding in sampledEmbeddings) {
|
|
|
- distances.add(cosineDistForNormVectors(embedding, otherEmbedding));
|
|
|
+ distances.add(1 - embedding.dot(otherEmbedding));
|
|
|
}
|
|
|
}
|
|
|
distances.sort();
|
|
@@ -671,8 +677,9 @@ class ClusterFeedbackService {
|
|
|
if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
|
|
|
allClusterIds.remove(id);
|
|
|
clusterAvg[id] = Vector.fromList(
|
|
|
- EVector.fromBuffer(clusterToSummary[id]!.$1).values,
|
|
|
- dtype: DType.float32,);
|
|
|
+ EVector.fromBuffer(clusterToSummary[id]!.$1).values,
|
|
|
+ dtype: DType.float32,
|
|
|
+ );
|
|
|
alreadyUpdatedClustersCnt++;
|
|
|
}
|
|
|
if (allClusterIdsToCountMap[id]! < minClusterSize) {
|
|
@@ -738,10 +745,12 @@ class ClusterFeedbackService {
|
|
|
|
|
|
for (final clusterID in clusterEmbeddings.keys) {
|
|
|
final Iterable<Uint8List> embeddings = clusterEmbeddings[clusterID]!;
|
|
|
- final Iterable<Vector> vectors = embeddings.map((e) => Vector.fromList(
|
|
|
- EVector.fromBuffer(e).values,
|
|
|
- dtype: DType.float32,
|
|
|
- ),);
|
|
|
+ final Iterable<Vector> vectors = embeddings.map(
|
|
|
+ (e) => Vector.fromList(
|
|
|
+ EVector.fromBuffer(e).values,
|
|
|
+ dtype: DType.float32,
|
|
|
+ ),
|
|
|
+ );
|
|
|
final avg = vectors.reduce((a, b) => a + b) / vectors.length;
|
|
|
final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
|
|
|
updatesForClusterSummary[clusterID] =
|
|
@@ -908,16 +917,16 @@ class ClusterFeedbackService {
|
|
|
final personEmbeddingsCount = personClusters
|
|
|
.map((e) => personClusterToSummary[e]!.$2)
|
|
|
.reduce((a, b) => a + b);
|
|
|
- final List<double> personAvg = List.filled(192, 0);
|
|
|
+ Vector personAvg = Vector.filled(192, 0);
|
|
|
for (final personClusterID in personClusters) {
|
|
|
final personClusterBlob = personClusterToSummary[personClusterID]!.$1;
|
|
|
- final personClusterAvg = EVector.fromBuffer(personClusterBlob).values;
|
|
|
+ final personClusterAvg = Vector.fromList(
|
|
|
+ EVector.fromBuffer(personClusterBlob).values,
|
|
|
+ dtype: DType.float32,
|
|
|
+ );
|
|
|
final clusterWeight =
|
|
|
personClusterToSummary[personClusterID]!.$2 / personEmbeddingsCount;
|
|
|
- for (int i = 0; i < personClusterAvg.length; i++) {
|
|
|
- personAvg[i] += personClusterAvg[i] *
|
|
|
- clusterWeight; // Weighted sum of the cluster averages
|
|
|
- }
|
|
|
+ personAvg += personClusterAvg * clusterWeight;
|
|
|
}
|
|
|
w?.log('calculated person avg');
|
|
|
|
|
@@ -933,16 +942,22 @@ class ClusterFeedbackService {
|
|
|
final faceIdToEmbeddingMap = await faceMlDb.getFaceEmbeddingMapForFaces(
|
|
|
faceIDs,
|
|
|
);
|
|
|
+ final faceIdToVectorMap = faceIdToEmbeddingMap.map(
|
|
|
+ (key, value) => MapEntry(
|
|
|
+ key,
|
|
|
+ Vector.fromList(
|
|
|
+ EVector.fromBuffer(value).values,
|
|
|
+ dtype: DType.float32,
|
|
|
+ ),
|
|
|
+ ),
|
|
|
+ );
|
|
|
w?.log(
|
|
|
'got ${faceIdToEmbeddingMap.values.length} embeddings for ${suggestion.filesInCluster.length} files for cluster $clusterID',
|
|
|
);
|
|
|
final fileIdToDistanceMap = {};
|
|
|
- for (final entry in faceIdToEmbeddingMap.entries) {
|
|
|
+ for (final entry in faceIdToVectorMap.entries) {
|
|
|
fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
|
|
|
- cosineDistForNormVectors(
|
|
|
- personAvg,
|
|
|
- EVector.fromBuffer(entry.value).values,
|
|
|
- );
|
|
|
+ 1 - personAvg.dot(entry.value);
|
|
|
}
|
|
|
w?.log('calculated distances for cluster $clusterID');
|
|
|
suggestion.filesInCluster.sort((b, a) {
|