Refactor of clustering
This commit is contained in:
parent
212208ae01
commit
b5cff212bb
1 changed files with 72 additions and 99 deletions
|
@ -7,6 +7,7 @@ import "dart:typed_data";
|
|||
import "package:logging/logging.dart";
|
||||
import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
||||
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
|
||||
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
||||
import "package:synchronized/synchronized.dart";
|
||||
|
||||
class FaceInfo {
|
||||
|
@ -15,10 +16,12 @@ class FaceInfo {
|
|||
int? clusterId;
|
||||
String? closestFaceId;
|
||||
int? closestDist;
|
||||
int? fileCreationTime;
|
||||
FaceInfo({
|
||||
required this.faceID,
|
||||
required this.embedding,
|
||||
this.clusterId,
|
||||
this.fileCreationTime,
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -31,7 +34,6 @@ class FaceLinearClustering {
|
|||
final Duration _inactivityDuration = const Duration(seconds: 30);
|
||||
int _activeTasks = 0;
|
||||
|
||||
|
||||
final _initLock = Lock();
|
||||
|
||||
late Isolate _isolate;
|
||||
|
@ -151,8 +153,8 @@ class FaceLinearClustering {
|
|||
_resetInactivityTimer();
|
||||
} else {
|
||||
_logger.info(
|
||||
'Clustering Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.',
|
||||
);
|
||||
'Clustering Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.',
|
||||
);
|
||||
dispose();
|
||||
}
|
||||
});
|
||||
|
@ -220,6 +222,8 @@ class FaceLinearClustering {
|
|||
log(
|
||||
"[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces",
|
||||
);
|
||||
|
||||
// Organize everything into a list of FaceInfo objects
|
||||
final List<FaceInfo> faceInfos = [];
|
||||
for (final entry in x.entries) {
|
||||
faceInfos.add(
|
||||
|
@ -249,59 +253,61 @@ class FaceLinearClustering {
|
|||
}
|
||||
|
||||
// Sort the faceInfos such that the ones with null clusterId are at the end
|
||||
faceInfos.sort((a, b) {
|
||||
if (a.clusterId == null && b.clusterId == null) {
|
||||
return 0;
|
||||
} else if (a.clusterId == null) {
|
||||
return 1;
|
||||
} else if (b.clusterId == null) {
|
||||
return -1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
});
|
||||
// Count the amount of null values at the end
|
||||
int nullCount = 0;
|
||||
for (final faceInfo in faceInfos.reversed) {
|
||||
final List<FaceInfo> facesWithClusterID = <FaceInfo>[];
|
||||
final List<FaceInfo> facesWithoutClusterID = <FaceInfo>[];
|
||||
for (final FaceInfo faceInfo in faceInfos) {
|
||||
if (faceInfo.clusterId == null) {
|
||||
nullCount++;
|
||||
facesWithoutClusterID.add(faceInfo);
|
||||
} else {
|
||||
break;
|
||||
facesWithClusterID.add(faceInfo);
|
||||
}
|
||||
}
|
||||
final sortedFaceInfos = <FaceInfo>[];
|
||||
sortedFaceInfos.addAll(facesWithClusterID);
|
||||
sortedFaceInfos.addAll(facesWithoutClusterID);
|
||||
|
||||
log(
|
||||
"[ClusterIsolate] ${DateTime.now()} Clustering $nullCount new faces without clusterId, and ${faceInfos.length - nullCount} faces with clusterId",
|
||||
"[ClusterIsolate] ${DateTime.now()} Clustering ${facesWithoutClusterID.length} new faces without clusterId, and ${facesWithClusterID.length} faces with clusterId",
|
||||
);
|
||||
for (final clusteredFaceInfo
|
||||
in faceInfos.sublist(0, faceInfos.length - nullCount)) {
|
||||
assert(clusteredFaceInfo.clusterId != null);
|
||||
|
||||
// Make sure the first face has a clusterId
|
||||
final int totalFaces = sortedFaceInfos.length;
|
||||
int clusterID = 1;
|
||||
if (sortedFaceInfos.isNotEmpty) {
|
||||
if (sortedFaceInfos.first.clusterId == null) {
|
||||
sortedFaceInfos.first.clusterId = clusterID;
|
||||
} else {
|
||||
clusterID = sortedFaceInfos.first.clusterId!;
|
||||
}
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
|
||||
final int totalFaces = faceInfos.length;
|
||||
int clusterID = 1;
|
||||
if (faceInfos.isNotEmpty) {
|
||||
faceInfos.first.clusterId = clusterID;
|
||||
}
|
||||
// Start actual clustering
|
||||
log(
|
||||
"[ClusterIsolate] ${DateTime.now()} Processing $totalFaces faces",
|
||||
);
|
||||
final Map<String, int> newFaceIdToCluster = {};
|
||||
final stopwatchClustering = Stopwatch()..start();
|
||||
for (int i = 1; i < totalFaces; i++) {
|
||||
// Incremental clustering, so we can skip faces that already have a clusterId
|
||||
if (faceInfos[i].clusterId != null) {
|
||||
clusterID = max(clusterID, faceInfos[i].clusterId!);
|
||||
if (sortedFaceInfos[i].clusterId != null) {
|
||||
clusterID = max(clusterID, sortedFaceInfos[i].clusterId!);
|
||||
if (i % 250 == 0) {
|
||||
log("[ClusterIsolate] ${DateTime.now()} First $i faces already had a clusterID");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
final currentEmbedding = faceInfos[i].embedding;
|
||||
final currentEmbedding = sortedFaceInfos[i].embedding;
|
||||
int closestIdx = -1;
|
||||
double closestDistance = double.infinity;
|
||||
if (i % 250 == 0) {
|
||||
log("[ClusterIsolate] ${DateTime.now()} Processing $i faces");
|
||||
}
|
||||
for (int j = 0; j < i; j++) {
|
||||
for (int j = i - 1; j >= 0; j--) {
|
||||
final double distance = cosineDistForNormVectors(
|
||||
currentEmbedding,
|
||||
faceInfos[j].embedding,
|
||||
sortedFaceInfos[j].embedding,
|
||||
);
|
||||
if (distance < closestDistance) {
|
||||
closestDistance = distance;
|
||||
|
@ -310,42 +316,43 @@ class FaceLinearClustering {
|
|||
}
|
||||
|
||||
if (closestDistance < recommendedDistanceThreshold) {
|
||||
if (faceInfos[closestIdx].clusterId == null) {
|
||||
if (sortedFaceInfos[closestIdx].clusterId == null) {
|
||||
// Ideally this should never happen, but just in case log it
|
||||
log(
|
||||
" [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID",
|
||||
" [ClusterIsolate] [WARNING] ${DateTime.now()} Found new cluster $clusterID",
|
||||
);
|
||||
clusterID++;
|
||||
faceInfos[closestIdx].clusterId = clusterID;
|
||||
sortedFaceInfos[closestIdx].clusterId = clusterID;
|
||||
newFaceIdToCluster[sortedFaceInfos[closestIdx].faceID] = clusterID;
|
||||
}
|
||||
faceInfos[i].clusterId = faceInfos[closestIdx].clusterId;
|
||||
sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId;
|
||||
newFaceIdToCluster[sortedFaceInfos[i].faceID] =
|
||||
sortedFaceInfos[closestIdx].clusterId!;
|
||||
} else {
|
||||
clusterID++;
|
||||
faceInfos[i].clusterId = clusterID;
|
||||
sortedFaceInfos[i].clusterId = clusterID;
|
||||
newFaceIdToCluster[sortedFaceInfos[i].faceID] = clusterID;
|
||||
}
|
||||
}
|
||||
final Map<String, int> result = {};
|
||||
for (final faceInfo in faceInfos) {
|
||||
result[faceInfo.faceID] = faceInfo.clusterId!;
|
||||
}
|
||||
|
||||
stopwatchClustering.stop();
|
||||
log(
|
||||
' [ClusterIsolate] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings (${faceInfos[0].embedding.length} size) executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID',
|
||||
' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings (${sortedFaceInfos[0].embedding.length} size) executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID',
|
||||
);
|
||||
// return result;
|
||||
|
||||
// NOTe: The main clustering logic is done, the following is just filtering and logging
|
||||
final input = x;
|
||||
final faceIdToCluster = result;
|
||||
stopwatchClustering.reset();
|
||||
stopwatchClustering.start();
|
||||
// analyze the results
|
||||
FaceLinearClustering._analyzeClusterResults(sortedFaceInfos);
|
||||
|
||||
final Set<String> newFaceIds = <String>{};
|
||||
input.forEach((key, value) {
|
||||
if (value.$1 == null) {
|
||||
newFaceIds.add(key);
|
||||
}
|
||||
});
|
||||
return newFaceIdToCluster;
|
||||
}
|
||||
|
||||
static void _analyzeClusterResults(List<FaceInfo> sortedFaceInfos) {
|
||||
final stopwatch = Stopwatch()..start();
|
||||
|
||||
final Map<String, int> faceIdToCluster = {};
|
||||
for (final faceInfo in sortedFaceInfos) {
|
||||
faceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!;
|
||||
}
|
||||
|
||||
// Find faceIDs that are part of a cluster which is larger than 5 and are new faceIDs
|
||||
final Map<int, int> clusterIdToSize = {};
|
||||
|
@ -356,12 +363,6 @@ class FaceLinearClustering {
|
|||
clusterIdToSize[value] = 1;
|
||||
}
|
||||
});
|
||||
final Map<String, int> faceIdToClusterFiltered = {};
|
||||
for (final entry in faceIdToCluster.entries) {
|
||||
if (clusterIdToSize[entry.value]! > 0 && newFaceIds.contains(entry.key)) {
|
||||
faceIdToClusterFiltered[entry.key] = entry.value;
|
||||
}
|
||||
}
|
||||
|
||||
// print top 10 cluster ids and their sizes based on the internal cluster id
|
||||
final clusterIds = faceIdToCluster.values.toSet();
|
||||
|
@ -369,7 +370,7 @@ class FaceLinearClustering {
|
|||
return faceIdToCluster.values.where((id) => id == clusterId).length;
|
||||
}).toList();
|
||||
clusterSizes.sort();
|
||||
// find clusters whose size is graeter than 1
|
||||
// find clusters whose size is greater than 1
|
||||
int oneClusterCount = 0;
|
||||
int moreThan5Count = 0;
|
||||
int moreThan10Count = 0;
|
||||
|
@ -377,57 +378,29 @@ class FaceLinearClustering {
|
|||
int moreThan50Count = 0;
|
||||
int moreThan100Count = 0;
|
||||
|
||||
// for (int i = 0; i < clusterSizes.length; i++) {
|
||||
// if (clusterSizes[i] > 100) {
|
||||
// moreThan100Count++;
|
||||
// } else if (clusterSizes[i] > 50) {
|
||||
// moreThan50Count++;
|
||||
// } else if (clusterSizes[i] > 20) {
|
||||
// moreThan20Count++;
|
||||
// } else if (clusterSizes[i] > 10) {
|
||||
// moreThan10Count++;
|
||||
// } else if (clusterSizes[i] > 5) {
|
||||
// moreThan5Count++;
|
||||
// } else if (clusterSizes[i] == 1) {
|
||||
// oneClusterCount++;
|
||||
// }
|
||||
// }
|
||||
for (int i = 0; i < clusterSizes.length; i++) {
|
||||
if (clusterSizes[i] > 100) {
|
||||
moreThan100Count++;
|
||||
}
|
||||
if (clusterSizes[i] > 50) {
|
||||
} else if (clusterSizes[i] > 50) {
|
||||
moreThan50Count++;
|
||||
}
|
||||
if (clusterSizes[i] > 20) {
|
||||
} else if (clusterSizes[i] > 20) {
|
||||
moreThan20Count++;
|
||||
}
|
||||
if (clusterSizes[i] > 10) {
|
||||
} else if (clusterSizes[i] > 10) {
|
||||
moreThan10Count++;
|
||||
}
|
||||
if (clusterSizes[i] > 5) {
|
||||
} else if (clusterSizes[i] > 5) {
|
||||
moreThan5Count++;
|
||||
}
|
||||
if (clusterSizes[i] == 1) {
|
||||
} else if (clusterSizes[i] == 1) {
|
||||
oneClusterCount++;
|
||||
}
|
||||
}
|
||||
|
||||
// print the metrics
|
||||
log(
|
||||
'[ClusterIsolate] Total clusters ${clusterIds.length}, '
|
||||
'oneClusterCount $oneClusterCount, '
|
||||
'moreThan5Count $moreThan5Count, '
|
||||
'moreThan10Count $moreThan10Count, '
|
||||
'moreThan20Count $moreThan20Count, '
|
||||
'moreThan50Count $moreThan50Count, '
|
||||
'moreThan100Count $moreThan100Count',
|
||||
"[ClusterIsolate] Total clusters ${clusterIds.length}: \n oneClusterCount $oneClusterCount \n moreThan5Count $moreThan5Count \n moreThan10Count $moreThan10Count \n moreThan20Count $moreThan20Count \n moreThan50Count $moreThan50Count \n moreThan100Count $moreThan100Count",
|
||||
);
|
||||
stopwatchClustering.stop();
|
||||
stopwatch.stop();
|
||||
log(
|
||||
"[ClusterIsolate] Clustering additional steps took ${stopwatchClustering.elapsedMilliseconds} ms",
|
||||
"[ClusterIsolate] Clustering additional analysis took ${stopwatch.elapsedMilliseconds} ms",
|
||||
);
|
||||
|
||||
// log('Top clusters count ${clusterSizes.reversed.take(10).toList()}');
|
||||
return faceIdToClusterFiltered;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue