[mob] Add merges to predictComplete method

parent 7a5e1263e0
commit ba58ac1358

2 changed files with 108 additions and 9 deletions
@@ -249,6 +249,7 @@ class FaceClusteringService {
     Map<String, Uint8List> input, {
     Map<int, int>? fileIDToCreationTime,
     double distanceThreshold = kRecommendedDistanceThreshold,
+    double mergeThreshold = 0.30,
   }) async {
     if (input.isEmpty) {
       _logger.warning(
@@ -270,6 +271,7 @@ class FaceClusteringService {
         "input": input,
         "fileIDToCreationTime": fileIDToCreationTime,
         "distanceThreshold": distanceThreshold,
+        "mergeThreshold": mergeThreshold,
       },
       taskName: "createImageEmbedding",
     ) as Map<String, int>;
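For context, a minimal sketch of how the extended method might be invoked. The parameter names come from this diff; the faceEmbeddings map, the assumed loadFaceEmbeddings() helper, and the nullable-return handling (mirroring the call-site change further below) are illustrative assumptions, not part of the commit:

    // Hypothetical call site: cluster raw face embeddings, merging close clusters.
    final Map<String, Uint8List> faceEmbeddings = await loadFaceEmbeddings(); // assumed helper
    final faceIdToCluster = await FaceClusteringService.instance.predictComplete(
      faceEmbeddings,
      fileIDToCreationTime: fileIDToCreationTime,
      distanceThreshold: 0.30, // max cosine distance from a face to its nearest face
      mergeThreshold: 0.30, // max cosine distance between cluster means to merge them
    );
    if (faceIdToCluster == null || faceIdToCluster.isEmpty) {
      return; // nothing clustered
    }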
@@ -578,12 +580,11 @@ class FaceClusteringService {
     );
   }
 
-
-
   static Map<String, int> runCompleteClustering(Map args) {
     final input = args['input'] as Map<String, Uint8List>;
     final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
     final distanceThreshold = args['distanceThreshold'] as double;
+    final mergeThreshold = args['mergeThreshold'] as double;
 
     log(
       "[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering",
@@ -637,10 +638,11 @@ class FaceClusteringService {
     final Map<String, int> newFaceIdToCluster = {};
     final stopwatchClustering = Stopwatch()..start();
     for (int i = 0; i < totalFaces; i++) {
+      if (faceInfos[i].clusterId != null) continue;
       int closestIdx = -1;
       double closestDistance = double.infinity;
-      if (i % 250 == 0) {
-        log("[CompleteClustering] ${DateTime.now()} Processed $i faces");
+      if ((i + 1) % 250 == 0) {
+        log("[CompleteClustering] ${DateTime.now()} Processed ${i + 1} faces");
       }
       for (int j = 0; j < totalFaces; j++) {
         if (i == j) continue;
@@ -656,18 +658,91 @@ class FaceClusteringService {
         if (faceInfos[closestIdx].clusterId == null) {
           clusterID++;
           faceInfos[closestIdx].clusterId = clusterID;
           newFaceIdToCluster[faceInfos[closestIdx].faceID] = clusterID;
         }
         faceInfos[i].clusterId = faceInfos[closestIdx].clusterId!;
         newFaceIdToCluster[faceInfos[i].faceID] =
             faceInfos[closestIdx].clusterId!;
       } else {
         clusterID++;
         faceInfos[i].clusterId = clusterID;
         newFaceIdToCluster[faceInfos[i].faceID] = clusterID;
       }
     }
 
+    // Now calculate the mean of the embeddings for each cluster
+    final Map<int, List<FaceInfo>> clusterIdToFaceInfos = {};
+    for (final faceInfo in faceInfos) {
+      if (clusterIdToFaceInfos.containsKey(faceInfo.clusterId)) {
+        clusterIdToFaceInfos[faceInfo.clusterId]!.add(faceInfo);
+      } else {
+        clusterIdToFaceInfos[faceInfo.clusterId!] = [faceInfo];
+      }
+    }
+    final Map<int, (Vector, int)> clusterIdToMeanEmbeddingAndWeight = {};
+    for (final clusterId in clusterIdToFaceInfos.keys) {
+      final List<Vector> embeddings = clusterIdToFaceInfos[clusterId]!
+          .map((faceInfo) => faceInfo.vEmbedding!)
+          .toList();
+      final count = clusterIdToFaceInfos[clusterId]!.length;
+      final Vector meanEmbedding = embeddings.reduce((a, b) => a + b) / count;
+      clusterIdToMeanEmbeddingAndWeight[clusterId] = (meanEmbedding, count);
+    }
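+    // If the embeddings are unit-norm (as the dot-product distance used below
+    // implies), their arithmetic mean generally is not, so "1.0 - dot" between
+    // two cluster means only approximates their cosine distance;
+    // re-normalizing the means would make it exact.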
+    // Now merge the clusters that are close to each other, based on mean embedding
+    final List<(int, int)> mergedClustersList = [];
+    final List<int> clusterIds = clusterIdToMeanEmbeddingAndWeight.keys.toList();
+    log(' [CompleteClustering] ${DateTime.now()} ${clusterIds.length} clusters found, now checking for merges');
+    while (true) {
+      if (clusterIds.length < 2) break;
+      double distance = double.infinity;
+      (int, int) clusterIDsToMerge = (-1, -1);
+      for (int i = 0; i < clusterIds.length; i++) {
+        for (int j = 0; j < clusterIds.length; j++) {
+          if (i == j) continue;
+          final double newDistance = 1.0 -
+              clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1.dot(
+                    clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
+                  );
+          if (newDistance < distance) {
+            distance = newDistance;
+            clusterIDsToMerge = (clusterIds[i], clusterIds[j]);
+          }
+        }
+      }
+      if (distance < mergeThreshold) {
+        mergedClustersList.add(clusterIDsToMerge);
+        final clusterID1 = clusterIDsToMerge.$1;
+        final clusterID2 = clusterIDsToMerge.$2;
+        final mean1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$1;
+        final mean2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$1;
+        final count1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$2;
+        final count2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$2;
+        final weight1 = count1 / (count1 + count2);
+        final weight2 = count2 / (count1 + count2);
+        clusterIdToMeanEmbeddingAndWeight[clusterID1] = (
+          mean1 * weight1 + mean2 * weight2,
+          count1 + count2,
+        );
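+        // This size-weighted update keeps the stored vector equal to the true
+        // mean of all member embeddings:
+        // (count1 * mean1 + count2 * mean2) / (count1 + count2).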
+        clusterIdToMeanEmbeddingAndWeight.remove(clusterID2);
+        clusterIds.remove(clusterID2);
+      } else {
+        break;
+      }
+    }
+    log(' [CompleteClustering] ${DateTime.now()} ${mergedClustersList.length} clusters merged');
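+    // Note that the merge phase is greedy: every pass rescans all cluster
+    // pairs (O(k^2)) and merges only the single closest pair, so k merges
+    // cost O(k^3) in the worst case.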
+
+    // Redirect faces of merged-away clusters to the surviving cluster; the
+    // list is walked in merge order, so chained merges resolve correctly.
+    for (final faceInfo in faceInfos) {
+      for (final mergedClusters in mergedClustersList) {
+        if (faceInfo.clusterId == mergedClusters.$2) {
+          faceInfo.clusterId = mergedClusters.$1;
+        }
+      }
+    }
+
+    // Finally, write the resolved clusterId of every face into the result map
+    for (final faceInfo in faceInfos) {
+      newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!;
+    }
 
     stopwatchClustering.stop();
     log(
       ' [CompleteClustering] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms',
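The merge loop above is the core of this commit. Below is a self-contained sketch of the same greedy mean-based merging on toy 2-D vectors, using the same ml_linalg Vector type. The data, cluster sizes, and prints are invented for illustration; unlike the commit's code, the merged mean is re-normalized here (assuming unit-norm embeddings) so that 1 - dot remains an exact cosine distance:

    import 'package:ml_linalg/linalg.dart';

    void main() {
      // Toy unit vectors: clusters 1 and 2 point in nearly the same direction.
      final means = <int, (Vector, int)>{
        1: (Vector.fromList([1.0, 0.0]), 4),
        2: (Vector.fromList([0.98, 0.199]), 2),
        3: (Vector.fromList([0.0, 1.0]), 3),
      };
      const mergeThreshold = 0.30;
      final ids = means.keys.toList();

      while (ids.length >= 2) {
        var best = double.infinity;
        var pair = (-1, -1);
        // Find the pair of cluster means with the smallest cosine distance.
        for (final a in ids) {
          for (final b in ids) {
            if (a == b) continue;
            final d = 1.0 - means[a]!.$1.dot(means[b]!.$1);
            if (d < best) {
              best = d;
              pair = (a, b);
            }
          }
        }
        if (best >= mergeThreshold) break; // closest pair is too far apart
        final (keep, drop) = pair;
        final (m1, n1) = means[keep]!;
        final (m2, n2) = means[drop]!;
        // Size-weighted mean of the two clusters, re-normalized to unit length.
        final merged =
            (m1 * (n1 / (n1 + n2)) + m2 * (n2 / (n1 + n2))).normalize();
        means[keep] = (merged, n1 + n2);
        means.remove(drop);
        ids.remove(drop);
        print('merged cluster $drop into $keep (distance ${best.toStringAsFixed(3)})');
      }
      print('remaining clusters: $ids'); // expected: [1, 3]
    }

With this sample data, clusters 1 and 2 merge (distance of roughly 0.02) and cluster 3 survives, since its distance to the merged mean stays well above the threshold.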
@@ -13,6 +13,8 @@ import "package:photos/generated/protos/ente/common/vector.pb.dart";
 import "package:photos/models/file/file.dart";
 import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
 import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
+// import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
+// import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
 import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
 import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
 import "package:photos/services/search_service.dart";
@@ -232,14 +234,36 @@ class ClusterFeedbackService {
         maxClusterID++;
       }
     } else {
+      // final clusteringInput = embeddings
+      //     .map((key, value) {
+      //       return MapEntry(
+      //         key,
+      //         FaceInfoForClustering(
+      //           faceID: key,
+      //           embeddingBytes: value,
+      //           faceScore: kMinHighQualityFaceScore + 0.01,
+      //           blurValue: kLapacianDefault,
+      //         ),
+      //       );
+      //     })
+      //     .values
+      //     .toSet();
+      // final faceIdToCluster =
+      //     await FaceClusteringService.instance.predictLinear(
+      //   clusteringInput,
+      //   fileIDToCreationTime: fileIDToCreationTime,
+      //   distanceThreshold: 0.23,
+      //   useDynamicThreshold: false,
+      // );
+      final faceIdToCluster =
+          await FaceClusteringService.instance.predictComplete(
+        embeddings,
+        fileIDToCreationTime: fileIDToCreationTime,
+        distanceThreshold: 0.30,
+        mergeThreshold: 0.30,
+      );
 
-      if (faceIdToCluster.isEmpty) {
+      if (faceIdToCluster == null || faceIdToCluster.isEmpty) {
         _logger.info('No clusters found');
         return {};
       } else {