Support for clustering in buckets

This commit is contained in:
laurenspriem 2024-03-22 11:49:23 +05:30
parent 85f76497b4
commit b1b3bcc534
2 changed files with 96 additions and 44 deletions

View file

@ -343,15 +343,13 @@ class FaceMLDataDB {
Future<Map<String, (int?, Uint8List)>> getFaceEmbeddingMap({
double minScore = kMinHighQualityFaceScore,
int minClarity = kLaplacianThreshold,
int maxRows = 20000,
int maxFaces = 20000,
int offset = 0,
int batchSize = 10000,
}) async {
_logger.info('reading as float');
final db = await instance.database;
// Define the batch size
const batchSize = 10000;
int offset = 0;
final Map<String, (int?, Uint8List)> result = {};
while (true) {
// Query a batch of rows
@ -373,7 +371,7 @@ class FaceMLDataDB {
result[faceID] =
(map[faceClusterId] as int?, map[faceEmbeddingBlob] as Uint8List);
}
if (result.length >= 20000) {
if (result.length >= maxFaces) {
break;
}
offset += batchSize;
@ -419,10 +417,12 @@ class FaceMLDataDB {
return result;
}
Future<int> getTotalFaceCount() async {
Future<int> getTotalFaceCount({
double minFaceScore = kMinHighQualityFaceScore,
}) async {
final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianThreshold',
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $minFaceScore AND $faceBlur > $kLaplacianThreshold',
);
return maps.first['count'] as int;
}

View file

@ -365,48 +365,98 @@ class FaceMlService {
Future<void> clusterAllImages({
double minFaceScore = kMinHighQualityFaceScore,
bool clusterInBuckets = false,
}) async {
_logger.info("`clusterAllImages()` called");
try {
// Read all the embeddings from the database, in a map from faceID to embedding
final clusterStartTime = DateTime.now();
final faceIdToEmbedding = await FaceMLDataDB.instance.getFaceEmbeddingMap(
minScore: minFaceScore,
);
final gotFaceEmbeddingsTime = DateTime.now();
_logger.info(
'read embeddings ${faceIdToEmbedding.length} in ${gotFaceEmbeddingsTime.difference(clusterStartTime).inMilliseconds} ms',
);
if (clusterInBuckets) {
// Get a sense of the total number of faces in the database
final int totalFaces = await FaceMLDataDB.instance
.getTotalFaceCount(minFaceScore: minFaceScore);
// Read the creation times from Files DB, in a map from fileID to creation time
final fileIDToCreationTime =
await FilesDB.instance.getFileIDToCreationTime();
_logger.info('read creation times from FilesDB in '
'${DateTime.now().difference(gotFaceEmbeddingsTime).inMilliseconds} ms');
// read the creation times from Files DB, in a map from fileID to creation time
final fileIDToCreationTime =
await FilesDB.instance.getFileIDToCreationTime();
// Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID
final faceIdToCluster = await FaceLinearClustering.instance.predict(
faceIdToEmbedding,
fileIDToCreationTime: fileIDToCreationTime,
);
if (faceIdToCluster == null) {
_logger.warning("faceIdToCluster is null");
return;
const int bucketSize = 10000;
const int offsetIncrement = 7500;
const int batchSize = 5000;
int offset = 0;
while (true) {
final faceIdToEmbeddingBucket =
await FaceMLDataDB.instance.getFaceEmbeddingMap(
minScore: minFaceScore,
maxFaces: bucketSize,
offset: offset,
batchSize: batchSize,
);
if (faceIdToEmbeddingBucket.isEmpty) {
break;
}
if (offset > totalFaces) {
_logger.warning(
'offset > totalFaces, this should ideally not happen. offset: $offset, totalFaces: $totalFaces',
);
break;
}
final faceIdToCluster = await FaceLinearClustering.instance.predict(
faceIdToEmbeddingBucket,
fileIDToCreationTime: fileIDToCreationTime,
);
if (faceIdToCluster == null) {
_logger.warning("faceIdToCluster is null");
return;
}
await FaceMLDataDB.instance
.updatePersonIDForFaceIDIFNotSet(faceIdToCluster);
offset += offsetIncrement;
}
} else {
// Read all the embeddings from the database, in a map from faceID to embedding
final clusterStartTime = DateTime.now();
final faceIdToEmbedding =
await FaceMLDataDB.instance.getFaceEmbeddingMap(
minScore: minFaceScore,
);
final gotFaceEmbeddingsTime = DateTime.now();
_logger.info(
'read embeddings ${faceIdToEmbedding.length} in ${gotFaceEmbeddingsTime.difference(clusterStartTime).inMilliseconds} ms',
);
// Read the creation times from Files DB, in a map from fileID to creation time
final fileIDToCreationTime =
await FilesDB.instance.getFileIDToCreationTime();
_logger.info('read creation times from FilesDB in '
'${DateTime.now().difference(gotFaceEmbeddingsTime).inMilliseconds} ms');
// Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID
final faceIdToCluster = await FaceLinearClustering.instance.predict(
faceIdToEmbedding,
fileIDToCreationTime: fileIDToCreationTime,
);
if (faceIdToCluster == null) {
_logger.warning("faceIdToCluster is null");
return;
}
final clusterDoneTime = DateTime.now();
_logger.info(
'done with clustering ${faceIdToEmbedding.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ',
);
// Store the updated clusterIDs in the database
_logger.info(
'Updating ${faceIdToCluster.length} FaceIDs with clusterIDs in the DB',
);
await FaceMLDataDB.instance
.updatePersonIDForFaceIDIFNotSet(faceIdToCluster);
_logger.info('Done updating FaceIDs with clusterIDs in the DB, in '
'${DateTime.now().difference(clusterDoneTime).inSeconds} seconds');
}
final clusterDoneTime = DateTime.now();
_logger.info(
'done with clustering ${faceIdToEmbedding.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ',
);
// Store the updated clusterIDs in the database
_logger.info(
'Updating ${faceIdToCluster.length} FaceIDs with clusterIDs in the DB',
);
await FaceMLDataDB.instance
.updatePersonIDForFaceIDIFNotSet(faceIdToCluster);
_logger.info('Done updating FaceIDs with clusterIDs in the DB, in '
'${DateTime.now().difference(clusterDoneTime).inSeconds} seconds');
} catch (e, s) {
_logger.severe("`clusterAllImages` failed", e, s);
}
@ -522,7 +572,9 @@ class FaceMlService {
_logger.info(
"indexAllImages() analyzed $fileAnalyzedCount images, cooldown for 1 minute",
);
await Future.delayed(const Duration(minutes: 1));
await Future.delayed(const Duration(minutes: 1), () {
_logger.info("indexAllImages() cooldown finished");
});
}
}