|
@@ -12,6 +12,7 @@ import "package:photos/face/db.dart";
|
|
import "package:photos/face/model/person.dart";
|
|
import "package:photos/face/model/person.dart";
|
|
import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
|
import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
|
import "package:photos/models/file/file.dart";
|
|
import "package:photos/models/file/file.dart";
|
|
|
|
+import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
|
|
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
|
|
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
|
|
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
|
|
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
|
|
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
|
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
|
@@ -594,7 +595,7 @@ class ClusterFeedbackService {
|
|
final List<double> distances = [];
|
|
final List<double> distances = [];
|
|
for (final otherEmbedding in sampledOtherEmbeddings) {
|
|
for (final otherEmbedding in sampledOtherEmbeddings) {
|
|
for (final embedding in sampledEmbeddings) {
|
|
for (final embedding in sampledEmbeddings) {
|
|
- distances.add(1 - embedding.dot(otherEmbedding));
|
|
|
|
|
|
+ distances.add(cosineDistanceSIMD(embedding,otherEmbedding));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
distances.sort();
|
|
distances.sort();
|
|
@@ -799,7 +800,7 @@ class ClusterFeedbackService {
|
|
continue;
|
|
continue;
|
|
}
|
|
}
|
|
final Vector avg = clusterAvg[personCluster]!;
|
|
final Vector avg = clusterAvg[personCluster]!;
|
|
- final distance = 1 - avg.dot(otherAvg);
|
|
|
|
|
|
+ final distance = cosineDistanceSIMD(avg,otherAvg);
|
|
if (distance < maxClusterDistance) {
|
|
if (distance < maxClusterDistance) {
|
|
if (minDistance == null || distance < minDistance) {
|
|
if (minDistance == null || distance < minDistance) {
|
|
minDistance = distance;
|
|
minDistance = distance;
|
|
@@ -950,7 +951,7 @@ class ClusterFeedbackService {
|
|
final fileIdToDistanceMap = {};
|
|
final fileIdToDistanceMap = {};
|
|
for (final entry in faceIdToVectorMap.entries) {
|
|
for (final entry in faceIdToVectorMap.entries) {
|
|
fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
|
|
fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
|
|
- 1 - personAvg.dot(entry.value);
|
|
|
|
|
|
+ cosineDistanceSIMD(personAvg,entry.value);
|
|
}
|
|
}
|
|
w?.log('calculated distances for cluster $clusterID');
|
|
w?.log('calculated distances for cluster $clusterID');
|
|
suggestion.filesInCluster.sort((b, a) {
|
|
suggestion.filesInCluster.sort((b, a) {
|