@@ -2,19 +2,26 @@ import "dart:async";
 import "dart:developer";
 import "dart:isolate";
 import "dart:math" show max;
-import "dart:typed_data";
+import "dart:typed_data" show Uint8List;
 
+import "package:computer/computer.dart";
+import "package:flutter/foundation.dart" show kDebugMode;
 import "package:logging/logging.dart";
 import "package:ml_linalg/dtype.dart";
 import "package:ml_linalg/vector.dart";
 import "package:photos/generated/protos/ente/common/vector.pb.dart";
 import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
+import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
+import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
 import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
 import "package:simple_cluster/simple_cluster.dart";
 import "package:synchronized/synchronized.dart";
 
 class FaceInfo {
   final String faceID;
+  final double? faceScore;
+  final double? blurValue;
+  final bool? badFace;
   final List<double>? embedding;
   final Vector? vEmbedding;
   int? clusterId;
@@ -23,6 +30,9 @@ class FaceInfo {
   int? fileCreationTime;
   FaceInfo({
     required this.faceID,
+    this.faceScore,
+    this.blurValue,
+    this.badFace,
     this.embedding,
     this.vEmbedding,
     this.clusterId,
@@ -32,8 +42,18 @@ class FaceInfo {
 
 enum ClusterOperation { linearIncrementalClustering, dbscanClustering }
 
+class ClusteringResult {
+  final Map<String, int> newFaceIdToCluster;
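+  /// Maps each clusterID to its summary: (serialized mean embedding, face count), as computed in [updateClusterSummaries] below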
+  final Map<int, (Uint8List, int)>? newClusterSummaries;
+  ClusteringResult({
+    required this.newFaceIdToCluster,
+    required this.newClusterSummaries,
+  });
+}
+
 class FaceClusteringService {
   final _logger = Logger("FaceLinearClustering");
+  final _computer = Computer.shared();
 
   Timer? _inactivityTimer;
   final Duration _inactivityDuration = const Duration(minutes: 3);
@@ -49,6 +69,7 @@ class FaceClusteringService {
   bool isRunning = false;
 
   static const kRecommendedDistanceThreshold = 0.24;
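+  // Stricter cutoff applied to low-quality ("bad") faces when dynamic thresholding is enabled, see [runLinearClustering]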
+  static const kConservativeDistanceThreshold = 0.06;
 
   // singleton pattern
   FaceClusteringService._privateConstructor();
@@ -100,31 +121,11 @@ class FaceClusteringService {
       try {
         switch (function) {
           case ClusterOperation.linearIncrementalClustering:
-            final input = args['input'] as Map<String, (int?, Uint8List)>;
-            final fileIDToCreationTime =
-                args['fileIDToCreationTime'] as Map<int, int>?;
-            final distanceThreshold = args['distanceThreshold'] as double;
-            final offset = args['offset'] as int?;
-            final result = FaceClusteringService._runLinearClustering(
-              input,
-              fileIDToCreationTime: fileIDToCreationTime,
-              distanceThreshold: distanceThreshold,
-              offset: offset,
-            );
+            final result = FaceClusteringService.runLinearClustering(args);
            sendPort.send(result);
            break;
          case ClusterOperation.dbscanClustering:
-            final input = args['input'] as Map<String, Uint8List>;
-            final fileIDToCreationTime =
-                args['fileIDToCreationTime'] as Map<int, int>?;
-            final eps = args['eps'] as double;
-            final minPts = args['minPts'] as int;
-            final result = FaceClusteringService._runDbscanClustering(
-              input,
-              fileIDToCreationTime: fileIDToCreationTime,
-              eps: eps,
-              minPts: minPts,
-            );
+            final result = FaceClusteringService._runDbscanClustering(args);
            sendPort.send(result);
            break;
        }
@@ -194,16 +195,19 @@ class FaceClusteringService {
     _inactivityTimer?.cancel();
   }
 
-  /// Runs the clustering algorithm on the given [input], in an isolate.
+  /// Runs the clustering algorithm [runLinearClustering] on the given [input], in an isolate.
   ///
   /// Returns the clustering result, mapping each newly clustered faceID to a clusterID, together with the updated cluster summaries.
   ///
   /// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can be less deterministic.
-  Future<Map<String, int>?> predictLinear(
-    Map<String, (int?, Uint8List)> input, {
+  Future<ClusteringResult?> predictLinear(
+    Set<FaceInfoForClustering> input, {
     Map<int, int>? fileIDToCreationTime,
     double distanceThreshold = kRecommendedDistanceThreshold,
+    double conservativeDistanceThreshold = kConservativeDistanceThreshold,
+    bool useDynamicThreshold = true,
     int? offset,
+    required Map<int, (Uint8List, int)> oldClusterSummaries,
   }) async {
     if (input.isEmpty) {
       _logger.warning(
@@ -225,20 +229,23 @@ class FaceClusteringService {
     final stopwatchClustering = Stopwatch()..start();
     // final Map<String, int> faceIdToCluster =
     //     await _runLinearClusteringInComputer(input);
-    final Map<String, int> faceIdToCluster = await _runInIsolate(
+    final ClusteringResult? faceIdToCluster = await _runInIsolate(
       (
         ClusterOperation.linearIncrementalClustering,
         {
           'input': input,
           'fileIDToCreationTime': fileIDToCreationTime,
           'distanceThreshold': distanceThreshold,
+          'conservativeDistanceThreshold': conservativeDistanceThreshold,
+          'useDynamicThreshold': useDynamicThreshold,
           'offset': offset,
+          'oldClusterSummaries': oldClusterSummaries,
         }
       ),
     );
     // return _runLinearClusteringInComputer(input);
     _logger.info(
-      'Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds',
+      'predictLinear Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds',
     );
 
     isRunning = false;
@@ -250,6 +257,142 @@ class FaceClusteringService {
     }
   }
+
+  /// Runs the clustering algorithm [runLinearClustering] on the given [input], in computer, without any dynamic thresholding.
+  Future<ClusteringResult?> predictLinearComputer(
+    Map<String, Uint8List> input, {
+    Map<int, int>? fileIDToCreationTime,
+    double distanceThreshold = kRecommendedDistanceThreshold,
+  }) async {
+    if (input.isEmpty) {
+      _logger.warning(
+        "Linear Clustering dataset of embeddings is empty, returning null.",
+      );
+      return null;
+    }
+
+    // Clustering inside the isolate
+    _logger.info(
+      "Start Linear clustering on ${input.length} embeddings inside computer isolate",
+    );
+
+    try {
+      final clusteringInput = input
+          .map((key, value) {
+            return MapEntry(
+              key,
+              FaceInfoForClustering(
+                faceID: key,
+                embeddingBytes: value,
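+                // Assumption: synthetic score and blur values just above the quality cutoffs, so no face is flagged as a bad face here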
+                faceScore: kMinimumQualityFaceScore + 0.01,
+                blurValue: kLapacianDefault,
+              ),
+            );
+          })
+          .values
+          .toSet();
+      final startTime = DateTime.now();
+      final faceIdToCluster = await _computer.compute(
+        runLinearClustering,
+        param: {
+          "input": clusteringInput,
+          "fileIDToCreationTime": fileIDToCreationTime,
+          "distanceThreshold": distanceThreshold,
+          "conservativeDistanceThreshold": distanceThreshold,
+          "useDynamicThreshold": false,
+        },
+        taskName: "linearClustering",
+      ) as ClusteringResult;
+      final endTime = DateTime.now();
+      _logger.info(
+        "Linear Clustering took: ${endTime.difference(startTime).inMilliseconds}ms",
+      );
+      return faceIdToCluster;
+    } catch (e, s) {
+      _logger.severe(e, s);
+      rethrow;
+    }
+  }
+
+  /// Runs the clustering algorithm [runCompleteClustering] on the given [input], in computer.
+  ///
+  /// WARNING: Only use on small datasets, as it is not optimized for large datasets.
+  Future<Map<String, int>> predictCompleteComputer(
+    Map<String, Uint8List> input, {
+    Map<int, int>? fileIDToCreationTime,
+    double distanceThreshold = kRecommendedDistanceThreshold,
+    double mergeThreshold = 0.30,
+  }) async {
+    if (input.isEmpty) {
+      _logger.warning(
+        "Complete Clustering dataset of embeddings is empty, returning empty map.",
+      );
+      return {};
+    }
+
+    // Clustering inside the isolate
+    _logger.info(
+      "Start Complete clustering on ${input.length} embeddings inside computer isolate",
+    );
+
+    try {
+      final startTime = DateTime.now();
+      final faceIdToCluster = await _computer.compute(
+        runCompleteClustering,
+        param: {
+          "input": input,
+          "fileIDToCreationTime": fileIDToCreationTime,
+          "distanceThreshold": distanceThreshold,
+          "mergeThreshold": mergeThreshold,
+        },
+        taskName: "completeClustering",
+      ) as Map<String, int>;
+      final endTime = DateTime.now();
+      _logger.info(
+        "Complete Clustering took: ${endTime.difference(startTime).inMilliseconds}ms",
+      );
+      return faceIdToCluster;
+    } catch (e, s) {
+      _logger.severe(e, s);
+      rethrow;
+    }
+  }
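+
+  /// Clusters the given [input] faces in the computer isolate, switching between
+  /// [predictCompleteComputer] (fewer than 100 faces) and [predictLinearComputer].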
+  Future<Map<String, int>?> predictWithinClusterComputer(
+    Map<String, Uint8List> input, {
+    Map<int, int>? fileIDToCreationTime,
+    double distanceThreshold = kRecommendedDistanceThreshold,
+  }) async {
+    _logger.info(
+      '`predictWithinClusterComputer` called with ${input.length} faces and distance threshold $distanceThreshold',
+    );
+    try {
+      if (input.length < 100) {
+        final mergeThreshold = distanceThreshold + 0.06;
+        _logger.info(
+          'Running complete clustering on ${input.length} faces with merge threshold $mergeThreshold',
+        );
+        return predictCompleteComputer(
+          input,
+          fileIDToCreationTime: fileIDToCreationTime,
+          mergeThreshold: mergeThreshold,
+        );
+      } else {
+        _logger.info(
+          'Running linear clustering on ${input.length} faces with distance threshold $distanceThreshold',
+        );
+        final clusterResult = await predictLinearComputer(
+          input,
+          fileIDToCreationTime: fileIDToCreationTime,
+          distanceThreshold: distanceThreshold,
+        );
+        return clusterResult?.newFaceIdToCluster;
+      }
+    } catch (e, s) {
+      _logger.severe(e, s);
+      rethrow;
+    }
+  }
+
   Future<List<List<String>>> predictDbscan(
     Map<String, Uint8List> input, {
     Map<int, int>? fileIDToCreationTime,
@@ -299,29 +442,42 @@ class FaceClusteringService {
     return clusterFaceIDs;
   }
 
-  static Map<String, int> _runLinearClustering(
-    Map<String, (int?, Uint8List)> x, {
-    Map<int, int>? fileIDToCreationTime,
-    double distanceThreshold = kRecommendedDistanceThreshold,
-    int? offset,
-  }) {
+  static ClusteringResult? runLinearClustering(Map args) {
+    // final input = args['input'] as Map<String, (int?, Uint8List)>;
+    final input = args['input'] as Set<FaceInfoForClustering>;
+    final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
+    final distanceThreshold = args['distanceThreshold'] as double;
+    final conservativeDistanceThreshold =
+        args['conservativeDistanceThreshold'] as double;
+    final useDynamicThreshold = args['useDynamicThreshold'] as bool;
+    final offset = args['offset'] as int?;
+    final oldClusterSummaries =
+        args['oldClusterSummaries'] as Map<int, (Uint8List, int)>?;
+
     log(
-      "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces",
+      "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces",
     );
 
     // Organize everything into a list of FaceInfo objects
     final List<FaceInfo> faceInfos = [];
-    for (final entry in x.entries) {
+    for (final face in input) {
       faceInfos.add(
         FaceInfo(
-          faceID: entry.key,
+          faceID: face.faceID,
+          faceScore: face.faceScore,
+          blurValue: face.blurValue,
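+          // A face counts as "bad" when its detection score or blur clarity falls below the face-filtering thresholds, or when it is sideways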
+          badFace: face.faceScore < kMinimumQualityFaceScore ||
+              face.blurValue < kLaplacianSoftThreshold ||
+              (face.blurValue < kLaplacianVerySoftThreshold &&
+                  face.faceScore < kMediumQualityFaceScore) ||
+              face.isSideways,
           vEmbedding: Vector.fromList(
-            EVector.fromBuffer(entry.value.$2).values,
+            EVector.fromBuffer(face.embeddingBytes).values,
             dtype: DType.float32,
           ),
-          clusterId: entry.value.$1,
+          clusterId: face.clusterId,
           fileCreationTime:
-              fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
+              fileIDToCreationTime?[getFileIdFromFaceId(face.faceID)],
         ),
       );
     }
@@ -351,19 +507,21 @@ class FaceClusteringService {
         facesWithClusterID.add(faceInfo);
       }
     }
+    final alreadyClusteredCount = facesWithClusterID.length;
     final sortedFaceInfos = <FaceInfo>[];
     sortedFaceInfos.addAll(facesWithClusterID);
     sortedFaceInfos.addAll(facesWithoutClusterID);
 
     log(
-      "[ClusterIsolate] ${DateTime.now()} Clustering ${facesWithoutClusterID.length} new faces without clusterId, and ${facesWithClusterID.length} faces with clusterId",
+      "[ClusterIsolate] ${DateTime.now()} Clustering ${facesWithoutClusterID.length} new faces without clusterId, and $alreadyClusteredCount faces with clusterId",
     );
 
     // Make sure the first face has a clusterId
     final int totalFaces = sortedFaceInfos.length;
+    int dynamicThresholdCount = 0;
 
     if (sortedFaceInfos.isEmpty) {
-      return {};
+      return null;
     }
 
     // Start actual clustering
@@ -377,7 +535,6 @@ class FaceClusteringService {
       sortedFaceInfos[0].clusterId = clusterID;
       clusterID++;
     }
-    final Map<String, int> newFaceIdToCluster = {};
     final stopwatchClustering = Stopwatch()..start();
     for (int i = 1; i < totalFaces; i++) {
       // Incremental clustering, so we can skip faces that already have a clusterId
@@ -388,6 +545,15 @@ class FaceClusteringService {
 
       int closestIdx = -1;
       double closestDistance = double.infinity;
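+      // With dynamic thresholding, a low-quality face must clear the stricter conservative threshold to join a cluster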
+      late double thresholdValue;
+      if (useDynamicThreshold) {
+        thresholdValue = sortedFaceInfos[i].badFace!
+            ? conservativeDistanceThreshold
+            : distanceThreshold;
+        if (sortedFaceInfos[i].badFace!) dynamicThresholdCount++;
+      } else {
+        thresholdValue = distanceThreshold;
+      }
       if (i % 250 == 0) {
         log("[ClusterIsolate] ${DateTime.now()} Processed ${offset != null ? i + offset : i} faces");
       }
@@ -405,18 +571,16 @@ class FaceClusteringService {
         );
       }
       if (distance < closestDistance) {
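+          // Skip a low-quality candidate as nearest neighbor unless it is within the conservative threshold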
+          if (sortedFaceInfos[j].badFace! &&
+              distance > conservativeDistanceThreshold) {
+            continue;
+          }
           closestDistance = distance;
           closestIdx = j;
-          // if (distance < distanceThreshold) {
-          //   if (sortedFaceInfos[j].faceID.startsWith("14914702") ||
-          //       sortedFaceInfos[j].faceID.startsWith("15488756")) {
-          //     log('[XXX] faceIDs: ${sortedFaceInfos[j].faceID} and ${sortedFaceInfos[i].faceID} with distance $distance');
-          //   }
-          // }
         }
       }
 
-      if (closestDistance < distanceThreshold) {
+      if (closestDistance < thresholdValue) {
         if (sortedFaceInfos[closestIdx].clusterId == null) {
           // Ideally this should never happen, but just in case log it
           log(
@@ -424,42 +588,99 @@ class FaceClusteringService {
           );
           clusterID++;
           sortedFaceInfos[closestIdx].clusterId = clusterID;
-          newFaceIdToCluster[sortedFaceInfos[closestIdx].faceID] = clusterID;
         }
-        // if (sortedFaceInfos[i].faceID.startsWith("14914702") ||
-        //     sortedFaceInfos[i].faceID.startsWith("15488756")) {
-        //   log(
-        //     "[XXX] [ClusterIsolate] ${DateTime.now()} Found similar face ${sortedFaceInfos[i].faceID} to ${sortedFaceInfos[closestIdx].faceID} with distance $closestDistance",
-        //   );
-        // }
         sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId;
-        newFaceIdToCluster[sortedFaceInfos[i].faceID] =
-            sortedFaceInfos[closestIdx].clusterId!;
       } else {
-        // if (sortedFaceInfos[i].faceID.startsWith("14914702") ||
-        //     sortedFaceInfos[i].faceID.startsWith("15488756")) {
-        //   log(
-        //     "[XXX] [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID for face ${sortedFaceInfos[i].faceID}",
-        //   );
-        // }
         clusterID++;
         sortedFaceInfos[i].clusterId = clusterID;
-        newFaceIdToCluster[sortedFaceInfos[i].faceID] = clusterID;
       }
     }
 
+    // Finally, collect the new faceID to clusterID assignments; sortedFaceInfos
+    // holds the already-clustered faces first, so the sublist contains exactly
+    // the newly clustered faces
+    final Map<String, int> newFaceIdToCluster = {};
+    final newClusteredFaceInfos =
+        sortedFaceInfos.sublist(alreadyClusteredCount);
+    for (final faceInfo in newClusteredFaceInfos) {
+      newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!;
+    }
+
     stopwatchClustering.stop();
     log(
       ' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms',
     );
+    if (useDynamicThreshold) {
+      log(
+        "[ClusterIsolate] ${DateTime.now()} Dynamic thresholding: $dynamicThresholdCount faces had a low face score or low blur clarity",
+      );
+    }
+
+    // Now calculate the mean of the embeddings for each cluster and update the cluster summaries
+    Map<int, (Uint8List, int)>? newClusterSummaries;
+    if (oldClusterSummaries != null) {
+      newClusterSummaries = FaceClusteringService.updateClusterSummaries(
+        oldSummary: oldClusterSummaries,
+        newFaceInfos: newClusteredFaceInfos,
+      );
+    }
 
     // analyze the results
     FaceClusteringService._analyzeClusterResults(sortedFaceInfos);
 
-    return newFaceIdToCluster;
+    return ClusteringResult(
+      newFaceIdToCluster: newFaceIdToCluster,
+      newClusterSummaries: newClusterSummaries,
+    );
+  }
+
+  static Map<int, (Uint8List, int)> updateClusterSummaries({
+    required Map<int, (Uint8List, int)> oldSummary,
+    required List<FaceInfo> newFaceInfos,
+  }) {
+    final calcSummariesStart = DateTime.now();
+    final Map<int, List<FaceInfo>> newClusterIdToFaceInfos = {};
+    for (final faceInfo in newFaceInfos) {
+      if (newClusterIdToFaceInfos.containsKey(faceInfo.clusterId!)) {
+        newClusterIdToFaceInfos[faceInfo.clusterId!]!.add(faceInfo);
+      } else {
+        newClusterIdToFaceInfos[faceInfo.clusterId!] = [faceInfo];
+      }
+    }
+
+    final Map<int, (Uint8List, int)> newClusterSummaries = {};
+    for (final clusterId in newClusterIdToFaceInfos.keys) {
+      final List<Vector> newEmbeddings = newClusterIdToFaceInfos[clusterId]!
+          .map((faceInfo) => faceInfo.vEmbedding!)
+          .toList();
+      final newCount = newEmbeddings.length;
+      if (oldSummary.containsKey(clusterId)) {
+        final oldMean = Vector.fromList(
+          EVector.fromBuffer(oldSummary[clusterId]!.$1).values,
+          dtype: DType.float32,
+        );
+        final oldCount = oldSummary[clusterId]!.$2;
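+        // mean * count reconstructs the old cluster's embedding sum, so the running mean can be updated without the original embeddings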
+        final oldEmbeddings = oldMean * oldCount;
+        newEmbeddings.add(oldEmbeddings);
+        final newMeanVector =
+            newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount);
+        newClusterSummaries[clusterId] = (
+          EVector(values: newMeanVector.toList()).writeToBuffer(),
+          oldCount + newCount
+        );
+      } else {
+        final newMeanVector = newEmbeddings.reduce((a, b) => a + b) / newCount;
+        newClusterSummaries[clusterId] =
+            (EVector(values: newMeanVector.toList()).writeToBuffer(), newCount);
+      }
+    }
+    log(
+      "[ClusterIsolate] ${DateTime.now()} Calculated cluster summaries in ${DateTime.now().difference(calcSummariesStart).inMilliseconds}ms",
+    );
+
+    return newClusterSummaries;
   }
 
   static void _analyzeClusterResults(List<FaceInfo> sortedFaceInfos) {
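+    // Cluster analysis is only a debug aid, so skip the extra work in release builds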
+    if (!kDebugMode) return;
     final stopwatch = Stopwatch()..start();
 
     final Map<String, int> faceIdToCluster = {};
@@ -517,14 +738,185 @@ class FaceClusteringService {
     );
   }
 
-  static List<List<String>> _runDbscanClustering(
-    Map<String, Uint8List> x, {
-    Map<int, int>? fileIDToCreationTime,
-    double eps = 0.3,
-    int minPts = 5,
-  }) {
+  static Map<String, int> runCompleteClustering(Map args) {
+    final input = args['input'] as Map<String, Uint8List>;
+    final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
+    final distanceThreshold = args['distanceThreshold'] as double;
+    final mergeThreshold = args['mergeThreshold'] as double;
+
+    log(
+      "[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering",
+    );
+
+    // Organize everything into a list of FaceInfo objects
+    final List<FaceInfo> faceInfos = [];
+    for (final entry in input.entries) {
+      faceInfos.add(
+        FaceInfo(
+          faceID: entry.key,
+          vEmbedding: Vector.fromList(
+            EVector.fromBuffer(entry.value).values,
+            dtype: DType.float32,
+          ),
+          fileCreationTime:
+              fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
+        ),
+      );
+    }
+
+    // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
+    if (fileIDToCreationTime != null) {
+      faceInfos.sort((a, b) {
+        if (a.fileCreationTime == null && b.fileCreationTime == null) {
+          return 0;
+        } else if (a.fileCreationTime == null) {
+          return 1;
+        } else if (b.fileCreationTime == null) {
+          return -1;
+        } else {
+          return a.fileCreationTime!.compareTo(b.fileCreationTime!);
+        }
+      });
+    }
+
+    if (faceInfos.isEmpty) {
+      return {};
+    }
+    final int totalFaces = faceInfos.length;
+
+    // Start actual clustering
+    log(
+      "[CompleteClustering] ${DateTime.now()} Processing $totalFaces faces in one single round of complete clustering",
+    );
+
+    // set current epoch time as clusterID
+    int clusterID = DateTime.now().microsecondsSinceEpoch;
+
+    // Start actual clustering
+    final Map<String, int> newFaceIdToCluster = {};
+    final stopwatchClustering = Stopwatch()..start();
+    for (int i = 0; i < totalFaces; i++) {
+      if ((i + 1) % 250 == 0) {
+        log("[CompleteClustering] ${DateTime.now()} Processed ${i + 1} faces");
+      }
+      if (faceInfos[i].clusterId != null) continue;
+      int closestIdx = -1;
+      double closestDistance = double.infinity;
+      for (int j = 0; j < totalFaces; j++) {
+        if (i == j) continue;
+        final double distance =
+            1.0 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!);
+        if (distance < closestDistance) {
+          closestDistance = distance;
+          closestIdx = j;
+        }
+      }
+
+      if (closestDistance < distanceThreshold) {
+        if (faceInfos[closestIdx].clusterId == null) {
+          clusterID++;
+          faceInfos[closestIdx].clusterId = clusterID;
+        }
+        faceInfos[i].clusterId = faceInfos[closestIdx].clusterId!;
+      } else {
+        clusterID++;
+        faceInfos[i].clusterId = clusterID;
+      }
+    }
+
+    // Now calculate the mean of the embeddings for each cluster
+    final Map<int, List<FaceInfo>> clusterIdToFaceInfos = {};
+    for (final faceInfo in faceInfos) {
+      if (clusterIdToFaceInfos.containsKey(faceInfo.clusterId)) {
+        clusterIdToFaceInfos[faceInfo.clusterId]!.add(faceInfo);
+      } else {
+        clusterIdToFaceInfos[faceInfo.clusterId!] = [faceInfo];
+      }
+    }
+    final Map<int, (Vector, int)> clusterIdToMeanEmbeddingAndWeight = {};
+    for (final clusterId in clusterIdToFaceInfos.keys) {
+      final List<Vector> embeddings = clusterIdToFaceInfos[clusterId]!
+          .map((faceInfo) => faceInfo.vEmbedding!)
+          .toList();
+      final count = clusterIdToFaceInfos[clusterId]!.length;
+      final Vector meanEmbedding = embeddings.reduce((a, b) => a + b) / count;
+      clusterIdToMeanEmbeddingAndWeight[clusterId] = (meanEmbedding, count);
+    }
+
+    // Now merge the clusters that are close to each other, based on mean embedding
+    final List<(int, int)> mergedClustersList = [];
+    final List<int> clusterIds =
+        clusterIdToMeanEmbeddingAndWeight.keys.toList();
+    log(' [CompleteClustering] ${DateTime.now()} ${clusterIds.length} clusters found, now checking for merges');
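+    // Greedy agglomerative merging: repeatedly merge the two closest cluster means until no pair is within mergeThreshold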
+    while (true) {
+      if (clusterIds.length < 2) break;
+      double distance = double.infinity;
+      (int, int) clusterIDsToMerge = (-1, -1);
+      for (int i = 0; i < clusterIds.length; i++) {
+        for (int j = 0; j < clusterIds.length; j++) {
+          if (i == j) continue;
+          final double newDistance = 1.0 -
+              clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1.dot(
+                    clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
+                  );
+          if (newDistance < distance) {
+            distance = newDistance;
+            clusterIDsToMerge = (clusterIds[i], clusterIds[j]);
+          }
+        }
+      }
+      if (distance < mergeThreshold) {
+        mergedClustersList.add(clusterIDsToMerge);
+        final clusterID1 = clusterIDsToMerge.$1;
+        final clusterID2 = clusterIDsToMerge.$2;
+        final mean1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$1;
+        final mean2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$1;
+        final count1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$2;
+        final count2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$2;
+        final weight1 = count1 / (count1 + count2);
+        final weight2 = count2 / (count1 + count2);
+        clusterIdToMeanEmbeddingAndWeight[clusterID1] = (
+          mean1 * weight1 + mean2 * weight2,
+          count1 + count2,
+        );
+        clusterIdToMeanEmbeddingAndWeight.remove(clusterID2);
+        clusterIds.remove(clusterID2);
+      } else {
+        break;
+      }
+    }
+    log(' [CompleteClustering] ${DateTime.now()} ${mergedClustersList.length} clusters merged');
+
+    // Now assign the new clusterId to the faces
+    for (final faceInfo in faceInfos) {
+      for (final mergedClusters in mergedClustersList) {
+        if (faceInfo.clusterId == mergedClusters.$2) {
+          faceInfo.clusterId = mergedClusters.$1;
+        }
+      }
+    }
+
+    // Finally, collect the faceID to clusterID assignments for all faces
+    for (final faceInfo in faceInfos) {
+      newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!;
+    }
+
+    stopwatchClustering.stop();
+    log(
+      ' [CompleteClustering] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms',
+    );
+
+    return newFaceIdToCluster;
+  }
+
+  static List<List<String>> _runDbscanClustering(Map args) {
+    final input = args['input'] as Map<String, Uint8List>;
+    final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
+    final eps = args['eps'] as double;
+    final minPts = args['minPts'] as int;
+
     log(
-      "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces",
+      "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces",
     );
 
     final DBSCAN dbscan = DBSCAN(
@@ -535,7 +927,7 @@ class FaceClusteringService {
 
     // Organize everything into a list of FaceInfo objects
     final List<FaceInfo> faceInfos = [];
-    for (final entry in x.entries) {
+    for (final entry in input.entries) {
       faceInfos.add(
         FaceInfo(
           faceID: entry.key,