[mob] Add completeClustering functionality
This commit is contained in:
parent
e3fd836901
commit
45d18b187c
1 changed files with 139 additions and 0 deletions
|
@ -4,6 +4,7 @@ import "dart:isolate";
|
|||
import "dart:math" show max;
|
||||
import "dart:typed_data";
|
||||
|
||||
import "package:computer/computer.dart";
|
||||
import "package:logging/logging.dart";
|
||||
import "package:ml_linalg/dtype.dart";
|
||||
import "package:ml_linalg/vector.dart";
|
||||
|
@ -42,6 +43,7 @@ enum ClusterOperation { linearIncrementalClustering, dbscanClustering }
|
|||
|
||||
class FaceClusteringService {
|
||||
final _logger = Logger("FaceLinearClustering");
|
||||
final _computer = Computer.shared();
|
||||
|
||||
Timer? _inactivityTimer;
|
||||
final Duration _inactivityDuration = const Duration(minutes: 3);
|
||||
|
@ -243,6 +245,45 @@ class FaceClusteringService {
|
|||
}
|
||||
}
|
||||
|
||||
/// Runs a single round of complete clustering over every embedding in
/// [input] on the shared computer isolate.
///
/// [input] maps a face ID to the raw bytes of its embedding; the bytes are
/// decoded inside the isolate by [runCompleteClustering].
///
/// When [fileIDToCreationTime] is provided, the worker sorts faces by the
/// creation time of their file (oldest first) before clustering.
///
/// [distanceThreshold] is the maximum distance at which a face is attached
/// to its nearest neighbour's cluster instead of starting a new one.
///
/// Returns a map from face ID to the cluster ID assigned to it. An empty
/// [input] short-circuits to an empty map without touching the isolate.
///
/// Any error raised by the isolate call is logged and rethrown.
Future<Map<String, int>> predictComplete(
  Map<String, Uint8List> input, {
  Map<int, int>? fileIDToCreationTime,
  double distanceThreshold = kRecommendedDistanceThreshold,
}) async {
  if (input.isEmpty) {
    _logger.warning(
      "Complete Clustering dataset of embeddings is empty, returning empty map.",
    );
    return {};
  }

  // Clustering inside the isolate
  _logger.info(
    "Start Complete clustering on ${input.length} embeddings inside computer isolate",
  );

  try {
    // A Stopwatch is monotonic, unlike a DateTime.now() difference, and is
    // what the worker itself uses for its timing log.
    final stopwatch = Stopwatch()..start();
    final faceIdToCluster = await _computer.compute(
      runCompleteClustering,
      param: {
        "input": input,
        "fileIDToCreationTime": fileIDToCreationTime,
        "distanceThreshold": distanceThreshold,
      },
      // Previously mislabelled as "createImageEmbedding" (copy-paste).
      taskName: "completeClustering",
    ) as Map<String, int>;
    stopwatch.stop();
    _logger.info(
      "Complete Clustering took: ${stopwatch.elapsedMilliseconds}ms",
    );
    return faceIdToCluster;
  } catch (e, s) {
    // Logger.severe takes (message, error, stackTrace); passing only (e, s)
    // would file the stack trace under `error` and drop the real trace slot.
    _logger.severe("Complete clustering failed", e, s);
    rethrow;
  }
}
|
||||
|
||||
Future<List<List<String>>> predictDbscan(
|
||||
Map<String, Uint8List> input, {
|
||||
Map<int, int>? fileIDToCreationTime,
|
||||
|
@ -537,6 +578,104 @@ class FaceClusteringService {
|
|||
);
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Worker for complete clustering; handed to the computer isolate by
/// `predictComplete`.
///
/// Expects [args] to contain:
///  * `'input'`: `Map<String, Uint8List>` of face ID to embedding bytes,
///  * `'fileIDToCreationTime'`: optional `Map<int, int>` used to order faces,
///  * `'distanceThreshold'`: `double` cut-off for joining a neighbour.
///
/// Each face is compared against every other face, with distance computed as
/// `1.0 - dot(a, b)`. If the nearest neighbour is closer than the threshold,
/// the face takes that neighbour's cluster (a fresh cluster ID is minted for
/// the neighbour first if it has none). Otherwise the face gets a fresh
/// cluster ID of its own.
///
/// Returns a map from face ID to cluster ID; empty when there are no faces.
static Map<String, int> runCompleteClustering(Map args) {
  final embeddingsByFaceId = args['input'] as Map<String, Uint8List>;
  final creationTimes = args['fileIDToCreationTime'] as Map<int, int>?;
  final threshold = args['distanceThreshold'] as double;

  log(
    "[CompleteClustering] ${DateTime.now()} Copied to isolate ${embeddingsByFaceId.length} faces for clustering",
  );

  // Decode every embedding into a FaceInfo record.
  final faces = <FaceInfo>[
    for (final entry in embeddingsByFaceId.entries)
      FaceInfo(
        faceID: entry.key,
        vEmbedding: Vector.fromList(
          EVector.fromBuffer(entry.value).values,
          dtype: DType.float32,
        ),
        fileCreationTime: creationTimes?[getFileIdFromFaceId(entry.key)],
      ),
  ];

  // Oldest files first; faces without a known creation time go last.
  if (creationTimes != null) {
    faces.sort((a, b) {
      final timeA = a.fileCreationTime;
      final timeB = b.fileCreationTime;
      if (timeA == null) {
        return timeB == null ? 0 : 1;
      }
      if (timeB == null) {
        return -1;
      }
      return timeA.compareTo(timeB);
    });
  }

  if (faces.isEmpty) {
    return {};
  }
  final int faceCount = faces.length;

  log(
    "[CompleteClustering] ${DateTime.now()} Processing $faceCount faces in one single round of complete clustering",
  );

  // Cluster IDs are seeded from the current epoch time (microseconds) and
  // incremented for each new cluster.
  int clusterID = DateTime.now().microsecondsSinceEpoch;

  final Map<String, int> labels = {};
  final clusteringTimer = Stopwatch()..start();
  for (int i = 0; i < faceCount; i++) {
    final current = faces[i];
    int nearestIdx = -1;
    double nearestDistance = double.infinity;
    if (i % 250 == 0) {
      log("[CompleteClustering] ${DateTime.now()} Processed $i faces");
    }

    // Exhaustive nearest-neighbour search over all other faces.
    for (int j = 0; j < faceCount; j++) {
      if (j != i) {
        final double distance =
            1.0 - current.vEmbedding!.dot(faces[j].vEmbedding!);
        if (distance < nearestDistance) {
          nearestDistance = distance;
          nearestIdx = j;
        }
      }
    }

    final int assignedCluster;
    if (nearestDistance < threshold) {
      final nearest = faces[nearestIdx];
      if (nearest.clusterId == null) {
        clusterID++;
        nearest.clusterId = clusterID;
        labels[nearest.faceID] = clusterID;
      }
      assignedCluster = nearest.clusterId!;
    } else {
      clusterID++;
      assignedCluster = clusterID;
    }
    // NOTE(review): this overwrites any cluster the current face was given
    // earlier as some other face's nearest neighbour — confirm this
    // reassignment is intended.
    current.clusterId = assignedCluster;
    labels[current.faceID] = assignedCluster;
  }

  clusteringTimer.stop();
  log(
    ' [CompleteClustering] ${DateTime.now()} Clustering for ${faces.length} embeddings executed in ${clusteringTimer.elapsedMilliseconds}ms',
  );

  return labels;
}
|
||||
|
||||
static List<List<String>> _runDbscanClustering(Map args) {
|
||||
final input = args['input'] as Map<String, Uint8List>;
|
||||
final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
|
||||
|
|
Loading…
Add table
Reference in a new issue