[mob][photos] Precompute cluster summaries incrementally during clustering
This commit is contained in:
parent
cde17441d6
commit
edf99385dc
2 changed files with 104 additions and 17 deletions
|
@ -42,6 +42,15 @@ class FaceInfo {
|
|||
|
||||
enum ClusterOperation { linearIncrementalClustering, dbscanClustering }
|
||||
|
||||
class ClusteringResult {
|
||||
final Map<String, int> newFaceIdToCluster;
|
||||
final Map<int, (Uint8List, int)>? newClusterSummaries;
|
||||
ClusteringResult({
|
||||
required this.newFaceIdToCluster,
|
||||
required this.newClusterSummaries,
|
||||
});
|
||||
}
|
||||
|
||||
class FaceClusteringService {
|
||||
final _logger = Logger("FaceLinearClustering");
|
||||
final _computer = Computer.shared();
|
||||
|
@ -191,13 +200,14 @@ class FaceClusteringService {
|
|||
/// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset.
|
||||
///
|
||||
/// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can less less deterministic.
|
||||
Future<Map<String, int>?> predictLinear(
|
||||
Future<ClusteringResult?> predictLinear(
|
||||
Set<FaceInfoForClustering> input, {
|
||||
Map<int, int>? fileIDToCreationTime,
|
||||
double distanceThreshold = kRecommendedDistanceThreshold,
|
||||
double conservativeDistanceThreshold = kConservativeDistanceThreshold,
|
||||
bool useDynamicThreshold = true,
|
||||
int? offset,
|
||||
required Map<int, (Uint8List, int)> oldClusterSummaries,
|
||||
}) async {
|
||||
if (input.isEmpty) {
|
||||
_logger.warning(
|
||||
|
@ -219,7 +229,7 @@ class FaceClusteringService {
|
|||
final stopwatchClustering = Stopwatch()..start();
|
||||
// final Map<String, int> faceIdToCluster =
|
||||
// await _runLinearClusteringInComputer(input);
|
||||
final Map<String, int> faceIdToCluster = await _runInIsolate(
|
||||
final ClusteringResult? faceIdToCluster = await _runInIsolate(
|
||||
(
|
||||
ClusterOperation.linearIncrementalClustering,
|
||||
{
|
||||
|
@ -229,6 +239,7 @@ class FaceClusteringService {
|
|||
'conservativeDistanceThreshold': conservativeDistanceThreshold,
|
||||
'useDynamicThreshold': useDynamicThreshold,
|
||||
'offset': offset,
|
||||
'oldClusterSummaries': oldClusterSummaries,
|
||||
}
|
||||
),
|
||||
);
|
||||
|
@ -247,7 +258,7 @@ class FaceClusteringService {
|
|||
}
|
||||
|
||||
/// Runs the clustering algorithm [runLinearClustering] on the given [input], in computer, without any dynamic thresholding
|
||||
Future<Map<String, int>?> predictLinearComputer(
|
||||
Future<ClusteringResult?> predictLinearComputer(
|
||||
Map<String, Uint8List> input, {
|
||||
Map<int, int>? fileIDToCreationTime,
|
||||
double distanceThreshold = kRecommendedDistanceThreshold,
|
||||
|
@ -256,7 +267,7 @@ class FaceClusteringService {
|
|||
_logger.warning(
|
||||
"Linear Clustering dataset of embeddings is empty, returning empty list.",
|
||||
);
|
||||
return {};
|
||||
return null;
|
||||
}
|
||||
|
||||
// Clustering inside the isolate
|
||||
|
@ -290,7 +301,7 @@ class FaceClusteringService {
|
|||
"useDynamicThreshold": false,
|
||||
},
|
||||
taskName: "createImageEmbedding",
|
||||
) as Map<String, int>;
|
||||
) as ClusteringResult;
|
||||
final endTime = DateTime.now();
|
||||
_logger.info(
|
||||
"Linear Clustering took: ${endTime.difference(startTime).inMilliseconds}ms",
|
||||
|
@ -369,11 +380,12 @@ class FaceClusteringService {
|
|||
_logger.info(
|
||||
'Running linear clustering on ${input.length} faces with distance threshold $distanceThreshold',
|
||||
);
|
||||
return predictLinearComputer(
|
||||
final clusterResult = await predictLinearComputer(
|
||||
input,
|
||||
fileIDToCreationTime: fileIDToCreationTime,
|
||||
distanceThreshold: distanceThreshold,
|
||||
);
|
||||
return clusterResult?.newFaceIdToCluster;
|
||||
}
|
||||
} catch (e, s) {
|
||||
_logger.severe(e, s);
|
||||
|
@ -430,7 +442,7 @@ class FaceClusteringService {
|
|||
return clusterFaceIDs;
|
||||
}
|
||||
|
||||
static Map<String, int> runLinearClustering(Map args) {
|
||||
static ClusteringResult? runLinearClustering(Map args) {
|
||||
// final input = args['input'] as Map<String, (int?, Uint8List)>;
|
||||
final input = args['input'] as Set<FaceInfoForClustering>;
|
||||
final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
|
||||
|
@ -439,6 +451,8 @@ class FaceClusteringService {
|
|||
args['conservativeDistanceThreshold'] as double;
|
||||
final useDynamicThreshold = args['useDynamicThreshold'] as bool;
|
||||
final offset = args['offset'] as int?;
|
||||
final oldClusterSummaries =
|
||||
args['oldClusterSummaries'] as Map<int, (Uint8List, int)>?;
|
||||
|
||||
log(
|
||||
"[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces",
|
||||
|
@ -507,7 +521,7 @@ class FaceClusteringService {
|
|||
int dynamicThresholdCount = 0;
|
||||
|
||||
if (sortedFaceInfos.isEmpty) {
|
||||
return {};
|
||||
return null;
|
||||
}
|
||||
|
||||
// Start actual clustering
|
||||
|
@ -584,7 +598,9 @@ class FaceClusteringService {
|
|||
|
||||
// Finally, assign the new clusterId to the faces
|
||||
final Map<String, int> newFaceIdToCluster = {};
|
||||
for (final faceInfo in sortedFaceInfos.sublist(alreadyClusteredCount)) {
|
||||
final newClusteredFaceInfos =
|
||||
sortedFaceInfos.sublist(alreadyClusteredCount);
|
||||
for (final faceInfo in newClusteredFaceInfos) {
|
||||
newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!;
|
||||
}
|
||||
|
||||
|
@ -598,10 +614,69 @@ class FaceClusteringService {
|
|||
);
|
||||
}
|
||||
|
||||
// Now calculate the mean of the embeddings for each cluster and update the cluster summaries
|
||||
Map<int, (Uint8List, int)>? newClusterSummaries;
|
||||
if (oldClusterSummaries != null) {
|
||||
newClusterSummaries = FaceClusteringService.updateClusterSummaries(
|
||||
oldSummary: oldClusterSummaries,
|
||||
newFaceInfos: newClusteredFaceInfos,
|
||||
);
|
||||
}
|
||||
|
||||
// analyze the results
|
||||
FaceClusteringService._analyzeClusterResults(sortedFaceInfos);
|
||||
|
||||
return newFaceIdToCluster;
|
||||
return ClusteringResult(
|
||||
newFaceIdToCluster: newFaceIdToCluster,
|
||||
newClusterSummaries: newClusterSummaries,
|
||||
);
|
||||
}
|
||||
|
||||
static Map<int, (Uint8List, int)> updateClusterSummaries({
|
||||
required Map<int, (Uint8List, int)> oldSummary,
|
||||
required List<FaceInfo> newFaceInfos,
|
||||
}) {
|
||||
final calcSummariesStart = DateTime.now();
|
||||
final Map<int, List<FaceInfo>> newClusterIdToFaceInfos = {};
|
||||
for (final faceInfo in newFaceInfos) {
|
||||
if (newClusterIdToFaceInfos.containsKey(faceInfo.clusterId!)) {
|
||||
newClusterIdToFaceInfos[faceInfo.clusterId!]!.add(faceInfo);
|
||||
} else {
|
||||
newClusterIdToFaceInfos[faceInfo.clusterId!] = [faceInfo];
|
||||
}
|
||||
}
|
||||
|
||||
final Map<int, (Uint8List, int)> newClusterSummaries = {};
|
||||
for (final clusterId in newClusterIdToFaceInfos.keys) {
|
||||
final List<Vector> newEmbeddings = newClusterIdToFaceInfos[clusterId]!
|
||||
.map((faceInfo) => faceInfo.vEmbedding!)
|
||||
.toList();
|
||||
final newCount = newEmbeddings.length;
|
||||
if (oldSummary.containsKey(clusterId)) {
|
||||
final oldMean = Vector.fromList(
|
||||
EVector.fromBuffer(oldSummary[clusterId]!.$1).values,
|
||||
dtype: DType.float32,
|
||||
);
|
||||
final oldCount = oldSummary[clusterId]!.$2;
|
||||
final oldEmbeddings = oldMean * oldCount;
|
||||
newEmbeddings.add(oldEmbeddings);
|
||||
final newMeanVector =
|
||||
newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount);
|
||||
newClusterSummaries[clusterId] = (
|
||||
EVector(values: newMeanVector.toList()).writeToBuffer(),
|
||||
oldCount + newCount
|
||||
);
|
||||
} else {
|
||||
final newMeanVector = newEmbeddings.reduce((a, b) => a + b) / newCount;
|
||||
newClusterSummaries[clusterId] =
|
||||
(EVector(values: newMeanVector.toList()).writeToBuffer(), newCount);
|
||||
}
|
||||
}
|
||||
log(
|
||||
"[ClusterIsolate] ${DateTime.now()} Calculated cluster summaries in ${DateTime.now().difference(calcSummariesStart).inMilliseconds}ms",
|
||||
);
|
||||
|
||||
return newClusterSummaries;
|
||||
}
|
||||
|
||||
static void _analyzeClusterResults(List<FaceInfo> sortedFaceInfos) {
|
||||
|
|
|
@ -300,6 +300,10 @@ class FaceMlService {
|
|||
// Get a sense of the total number of faces in the database
|
||||
final int totalFaces = await FaceMLDataDB.instance
|
||||
.getTotalFaceCount(minFaceScore: minFaceScore);
|
||||
|
||||
// Get the current cluster statistics
|
||||
final Map<int, (Uint8List, int)> oldClusterSummaries =
|
||||
await FaceMLDataDB.instance.getAllClusterSummary();
|
||||
if (clusterInBuckets) {
|
||||
// read the creation times from Files DB, in a map from fileID to creation time
|
||||
final fileIDToCreationTime =
|
||||
|
@ -332,18 +336,22 @@ class FaceMlService {
|
|||
break;
|
||||
}
|
||||
|
||||
final faceIdToCluster =
|
||||
final clusteringResult =
|
||||
await FaceClusteringService.instance.predictLinear(
|
||||
faceInfoForClustering,
|
||||
fileIDToCreationTime: fileIDToCreationTime,
|
||||
offset: offset,
|
||||
oldClusterSummaries: oldClusterSummaries,
|
||||
);
|
||||
if (faceIdToCluster == null) {
|
||||
if (clusteringResult == null) {
|
||||
_logger.warning("faceIdToCluster is null");
|
||||
return;
|
||||
}
|
||||
|
||||
await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster);
|
||||
await FaceMLDataDB.instance
|
||||
.updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster);
|
||||
await FaceMLDataDB.instance
|
||||
.clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
|
||||
_logger.info(
|
||||
'Done with clustering ${offset + faceInfoForClustering.length} embeddings (${(100 * (offset + faceInfoForClustering.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset',
|
||||
);
|
||||
|
@ -374,12 +382,13 @@ class FaceMlService {
|
|||
'${DateTime.now().difference(gotFaceEmbeddingsTime).inMilliseconds} ms');
|
||||
|
||||
// Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID
|
||||
final faceIdToCluster =
|
||||
final clusteringResult =
|
||||
await FaceClusteringService.instance.predictLinear(
|
||||
faceInfoForClustering,
|
||||
fileIDToCreationTime: fileIDToCreationTime,
|
||||
oldClusterSummaries: oldClusterSummaries,
|
||||
);
|
||||
if (faceIdToCluster == null) {
|
||||
if (clusteringResult == null) {
|
||||
_logger.warning("faceIdToCluster is null");
|
||||
return;
|
||||
}
|
||||
|
@ -390,9 +399,12 @@ class FaceMlService {
|
|||
|
||||
// Store the updated clusterIDs in the database
|
||||
_logger.info(
|
||||
'Updating ${faceIdToCluster.length} FaceIDs with clusterIDs in the DB',
|
||||
'Updating ${clusteringResult.newFaceIdToCluster.length} FaceIDs with clusterIDs in the DB',
|
||||
);
|
||||
await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster);
|
||||
await FaceMLDataDB.instance
|
||||
.updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster);
|
||||
await FaceMLDataDB.instance
|
||||
.clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
|
||||
_logger.info('Done updating FaceIDs with clusterIDs in the DB, in '
|
||||
'${DateTime.now().difference(clusterDoneTime).inSeconds} seconds');
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue