[mob][photos] Precompute cluster summaries incrementally during clustering

This commit is contained in:
laurenspriem 2024-04-20 14:38:46 +05:30
parent cde17441d6
commit edf99385dc
2 changed files with 104 additions and 17 deletions

View file

@ -42,6 +42,15 @@ class FaceInfo {
enum ClusterOperation { linearIncrementalClustering, dbscanClustering }
/// Result of a clustering run.
///
/// Bundles the per-face cluster assignments with the incrementally updated
/// cluster summaries so callers can persist both in one step.
class ClusteringResult {
// Maps each newly clustered face ID to the cluster ID it was assigned.
final Map<String, int> newFaceIdToCluster;
// Updated summary per cluster ID: the serialized mean embedding and the
// number of faces folded into that mean. Null when the caller did not
// supply old summaries to update (summaries were not computed).
final Map<int, (Uint8List, int)>? newClusterSummaries;
ClusteringResult({
required this.newFaceIdToCluster,
required this.newClusterSummaries,
});
}
class FaceClusteringService {
final _logger = Logger("FaceLinearClustering");
final _computer = Computer.shared();
@ -191,13 +200,14 @@ class FaceClusteringService {
/// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset.
///
/// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can be less deterministic.
Future<Map<String, int>?> predictLinear(
Future<ClusteringResult?> predictLinear(
Set<FaceInfoForClustering> input, {
Map<int, int>? fileIDToCreationTime,
double distanceThreshold = kRecommendedDistanceThreshold,
double conservativeDistanceThreshold = kConservativeDistanceThreshold,
bool useDynamicThreshold = true,
int? offset,
required Map<int, (Uint8List, int)> oldClusterSummaries,
}) async {
if (input.isEmpty) {
_logger.warning(
@ -219,7 +229,7 @@ class FaceClusteringService {
final stopwatchClustering = Stopwatch()..start();
// final Map<String, int> faceIdToCluster =
// await _runLinearClusteringInComputer(input);
final Map<String, int> faceIdToCluster = await _runInIsolate(
final ClusteringResult? faceIdToCluster = await _runInIsolate(
(
ClusterOperation.linearIncrementalClustering,
{
@ -229,6 +239,7 @@ class FaceClusteringService {
'conservativeDistanceThreshold': conservativeDistanceThreshold,
'useDynamicThreshold': useDynamicThreshold,
'offset': offset,
'oldClusterSummaries': oldClusterSummaries,
}
),
);
@ -247,7 +258,7 @@ class FaceClusteringService {
}
/// Runs the clustering algorithm [runLinearClustering] on the given [input], in computer, without any dynamic thresholding
Future<Map<String, int>?> predictLinearComputer(
Future<ClusteringResult?> predictLinearComputer(
Map<String, Uint8List> input, {
Map<int, int>? fileIDToCreationTime,
double distanceThreshold = kRecommendedDistanceThreshold,
@ -256,7 +267,7 @@ class FaceClusteringService {
_logger.warning(
"Linear Clustering dataset of embeddings is empty, returning empty list.",
);
return {};
return null;
}
// Clustering inside the isolate
@ -290,7 +301,7 @@ class FaceClusteringService {
"useDynamicThreshold": false,
},
taskName: "createImageEmbedding",
) as Map<String, int>;
) as ClusteringResult;
final endTime = DateTime.now();
_logger.info(
"Linear Clustering took: ${endTime.difference(startTime).inMilliseconds}ms",
@ -369,11 +380,12 @@ class FaceClusteringService {
_logger.info(
'Running linear clustering on ${input.length} faces with distance threshold $distanceThreshold',
);
return predictLinearComputer(
final clusterResult = await predictLinearComputer(
input,
fileIDToCreationTime: fileIDToCreationTime,
distanceThreshold: distanceThreshold,
);
return clusterResult?.newFaceIdToCluster;
}
} catch (e, s) {
_logger.severe(e, s);
@ -430,7 +442,7 @@ class FaceClusteringService {
return clusterFaceIDs;
}
static Map<String, int> runLinearClustering(Map args) {
static ClusteringResult? runLinearClustering(Map args) {
// final input = args['input'] as Map<String, (int?, Uint8List)>;
final input = args['input'] as Set<FaceInfoForClustering>;
final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
@ -439,6 +451,8 @@ class FaceClusteringService {
args['conservativeDistanceThreshold'] as double;
final useDynamicThreshold = args['useDynamicThreshold'] as bool;
final offset = args['offset'] as int?;
final oldClusterSummaries =
args['oldClusterSummaries'] as Map<int, (Uint8List, int)>?;
log(
"[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces",
@ -507,7 +521,7 @@ class FaceClusteringService {
int dynamicThresholdCount = 0;
if (sortedFaceInfos.isEmpty) {
return {};
return null;
}
// Start actual clustering
@ -584,7 +598,9 @@ class FaceClusteringService {
// Finally, assign the new clusterId to the faces
final Map<String, int> newFaceIdToCluster = {};
for (final faceInfo in sortedFaceInfos.sublist(alreadyClusteredCount)) {
final newClusteredFaceInfos =
sortedFaceInfos.sublist(alreadyClusteredCount);
for (final faceInfo in newClusteredFaceInfos) {
newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!;
}
@ -598,10 +614,69 @@ class FaceClusteringService {
);
}
// Now calculate the mean of the embeddings for each cluster and update the cluster summaries
Map<int, (Uint8List, int)>? newClusterSummaries;
if (oldClusterSummaries != null) {
newClusterSummaries = FaceClusteringService.updateClusterSummaries(
oldSummary: oldClusterSummaries,
newFaceInfos: newClusteredFaceInfos,
);
}
// analyze the results
FaceClusteringService._analyzeClusterResults(sortedFaceInfos);
return newFaceIdToCluster;
return ClusteringResult(
newFaceIdToCluster: newFaceIdToCluster,
newClusterSummaries: newClusterSummaries,
);
}
/// Folds the newly clustered faces in [newFaceInfos] into the existing
/// per-cluster summaries from [oldSummary].
///
/// A summary is the serialized mean embedding of a cluster plus the count of
/// faces it covers. For clusters already present in [oldSummary] the stored
/// mean is converted back into a running sum so the new embeddings can be
/// merged into an exact updated mean; unseen clusters get a fresh summary.
/// Returns only the summaries of clusters touched by [newFaceInfos].
static Map<int, (Uint8List, int)> updateClusterSummaries({
  required Map<int, (Uint8List, int)> oldSummary,
  required List<FaceInfo> newFaceInfos,
}) {
  final calcSummariesStart = DateTime.now();

  // Group the newly clustered faces by their assigned cluster ID.
  final Map<int, List<FaceInfo>> facesByCluster = {};
  for (final info in newFaceInfos) {
    facesByCluster.putIfAbsent(info.clusterId!, () => []).add(info);
  }

  final Map<int, (Uint8List, int)> updatedSummaries = {};
  for (final entry in facesByCluster.entries) {
    final clusterId = entry.key;
    final embeddings =
        entry.value.map((info) => info.vEmbedding!).toList();
    final addedCount = embeddings.length;

    final existing = oldSummary[clusterId];
    if (existing != null) {
      // Reconstruct the running sum from the stored mean and count, then
      // fold the new embeddings in to get the exact updated mean.
      final (serializedMean, previousCount) = existing;
      final previousMean = Vector.fromList(
        EVector.fromBuffer(serializedMean).values,
        dtype: DType.float32,
      );
      embeddings.add(previousMean * previousCount);
      final totalCount = previousCount + addedCount;
      final mean = embeddings.reduce((a, b) => a + b) / totalCount;
      updatedSummaries[clusterId] =
          (EVector(values: mean.toList()).writeToBuffer(), totalCount);
    } else {
      // First time this cluster is summarized: the summary is simply the
      // mean of the new embeddings.
      final mean = embeddings.reduce((a, b) => a + b) / addedCount;
      updatedSummaries[clusterId] =
          (EVector(values: mean.toList()).writeToBuffer(), addedCount);
    }
  }

  log(
    "[ClusterIsolate] ${DateTime.now()} Calculated cluster summaries in ${DateTime.now().difference(calcSummariesStart).inMilliseconds}ms",
  );
  return updatedSummaries;
}
static void _analyzeClusterResults(List<FaceInfo> sortedFaceInfos) {

View file

@ -300,6 +300,10 @@ class FaceMlService {
// Get a sense of the total number of faces in the database
final int totalFaces = await FaceMLDataDB.instance
.getTotalFaceCount(minFaceScore: minFaceScore);
// Get the current cluster statistics
final Map<int, (Uint8List, int)> oldClusterSummaries =
await FaceMLDataDB.instance.getAllClusterSummary();
if (clusterInBuckets) {
// read the creation times from Files DB, in a map from fileID to creation time
final fileIDToCreationTime =
@ -332,18 +336,22 @@ class FaceMlService {
break;
}
final faceIdToCluster =
final clusteringResult =
await FaceClusteringService.instance.predictLinear(
faceInfoForClustering,
fileIDToCreationTime: fileIDToCreationTime,
offset: offset,
oldClusterSummaries: oldClusterSummaries,
);
if (faceIdToCluster == null) {
if (clusteringResult == null) {
_logger.warning("faceIdToCluster is null");
return;
}
await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster);
await FaceMLDataDB.instance
.updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster);
await FaceMLDataDB.instance
.clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
_logger.info(
'Done with clustering ${offset + faceInfoForClustering.length} embeddings (${(100 * (offset + faceInfoForClustering.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset',
);
@ -374,12 +382,13 @@ class FaceMlService {
'${DateTime.now().difference(gotFaceEmbeddingsTime).inMilliseconds} ms');
// Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID
final faceIdToCluster =
final clusteringResult =
await FaceClusteringService.instance.predictLinear(
faceInfoForClustering,
fileIDToCreationTime: fileIDToCreationTime,
oldClusterSummaries: oldClusterSummaries,
);
if (faceIdToCluster == null) {
if (clusteringResult == null) {
_logger.warning("faceIdToCluster is null");
return;
}
@ -390,9 +399,12 @@ class FaceMlService {
// Store the updated clusterIDs in the database
_logger.info(
'Updating ${faceIdToCluster.length} FaceIDs with clusterIDs in the DB',
'Updating ${clusteringResult.newFaceIdToCluster.length} FaceIDs with clusterIDs in the DB',
);
await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster);
await FaceMLDataDB.instance
.updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster);
await FaceMLDataDB.instance
.clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
_logger.info('Done updating FaceIDs with clusterIDs in the DB, in '
'${DateTime.now().difference(clusterDoneTime).inSeconds} seconds');
}