[mob] Speed up cluster avg calculation
This commit is contained in:
parent
faa07a0704
commit
8e6617eed5
2 changed files with 64 additions and 4 deletions
|
@ -181,6 +181,32 @@ class FaceMLDataDB {
|
|||
return maps.map((e) => e[faceEmbeddingBlob] as Uint8List);
|
||||
}
|
||||
|
||||
Future<Map<int, Iterable<Uint8List>>> getFaceEmbeddingsForClusters(
|
||||
Iterable<int> clusterIDs, {
|
||||
int? limit,
|
||||
}) async {
|
||||
final db = await instance.database;
|
||||
final Map<int, List<Uint8List>> result = {};
|
||||
|
||||
final selectQuery = '''
|
||||
SELECT fc.$fcClusterID, fe.$faceEmbeddingBlob
|
||||
FROM $faceClustersTable fc
|
||||
INNER JOIN $facesTable fe ON fc.$fcFaceId = fe.$faceIDColumn
|
||||
WHERE fc.$fcClusterID IN (${clusterIDs.join(',')})
|
||||
${limit != null ? 'LIMIT $limit' : ''}
|
||||
''';
|
||||
|
||||
final List<Map<String, dynamic>> maps = await db.rawQuery(selectQuery);
|
||||
|
||||
for (final map in maps) {
|
||||
final clusterID = map[fcClusterID] as int;
|
||||
final faceEmbedding = map[faceEmbeddingBlob] as Uint8List;
|
||||
result.putIfAbsent(clusterID, () => <Uint8List>[]).add(faceEmbedding);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
Future<Face?> getCoverFaceForPerson({
|
||||
required int recentFileID,
|
||||
String? personID,
|
||||
|
@ -668,9 +694,11 @@ class FaceMLDataDB {
|
|||
await db.execute(deletePersonTable);
|
||||
await db.execute(dropClusterPersonTable);
|
||||
await db.execute(dropNotPersonFeedbackTable);
|
||||
await db.execute(dropClusterSummaryTable);
|
||||
await db.execute(createPersonTable);
|
||||
await db.execute(createClusterPersonTable);
|
||||
await db.execute(createNotPersonFeedbackTable);
|
||||
await db.execute(createClusterSummaryTable);
|
||||
}
|
||||
|
||||
Future<void> removeFilesFromPerson(List<EnteFile> files, Person p) async {
|
||||
|
|
|
@ -410,15 +410,44 @@ class ClusterFeedbackService {
|
|||
allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!),
|
||||
);
|
||||
int indexedInCurrentRun = 0;
|
||||
final EnteWatch? w = kDebugMode ? EnteWatch("computeAvg") : null;
|
||||
w?.start();
|
||||
|
||||
w?.log(
|
||||
'reading embeddings for $maxClusterInCurrentRun or ${sortedClusterIDs.length} clusters',
|
||||
);
|
||||
final int maxEmbeddingToRead = 10000;
|
||||
int currentPendingRead = 0;
|
||||
List<int> clusterIdsToRead = [];
|
||||
for (final clusterID in sortedClusterIDs) {
|
||||
if (maxClusterInCurrentRun-- <= 0) {
|
||||
break;
|
||||
}
|
||||
indexedInCurrentRun++;
|
||||
if (currentPendingRead == 0) {
|
||||
currentPendingRead = allClusterIdsToCountMap[clusterID] ?? 0;
|
||||
clusterIdsToRead.add(clusterID);
|
||||
} else {
|
||||
if ((currentPendingRead + allClusterIdsToCountMap[clusterID]!) <
|
||||
maxEmbeddingToRead) {
|
||||
clusterIdsToRead.add(clusterID);
|
||||
currentPendingRead += allClusterIdsToCountMap[clusterID]!;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
final Map<int, Iterable<Uint8List>> clusterEmbeddings = await FaceMLDataDB
|
||||
.instance
|
||||
.getFaceEmbeddingsForClusters(clusterIdsToRead);
|
||||
|
||||
w?.logAndReset(
|
||||
'read $currentPendingRead embeddings for ${clusterEmbeddings.length} clusters',
|
||||
);
|
||||
|
||||
for (final clusterID in clusterEmbeddings.keys) {
|
||||
late List<double> avg;
|
||||
final Iterable<Uint8List> embedings =
|
||||
await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
|
||||
final Iterable<Uint8List> embedings = clusterEmbeddings[clusterID]!;
|
||||
final List<double> sum = List.filled(192, 0);
|
||||
for (final embedding in embedings) {
|
||||
final data = EVector.fromBuffer(embedding).values;
|
||||
|
@ -431,12 +460,14 @@ class ClusterFeedbackService {
|
|||
updatesForClusterSummary[clusterID] =
|
||||
(avgEmbeedingBuffer, embedings.length);
|
||||
// store the intermediate updates
|
||||
indexedInCurrentRun++;
|
||||
if (updatesForClusterSummary.length > 100) {
|
||||
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
|
||||
updatesForClusterSummary.clear();
|
||||
if (kDebugMode) {
|
||||
_logger.info(
|
||||
'getUpdateClusterAvg $indexedInCurrentRun clusters in current one');
|
||||
'getUpdateClusterAvg $indexedInCurrentRun clusters in current one',
|
||||
);
|
||||
}
|
||||
}
|
||||
clusterAvg[clusterID] = avg;
|
||||
|
@ -444,6 +475,7 @@ class ClusterFeedbackService {
|
|||
if (updatesForClusterSummary.isNotEmpty) {
|
||||
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
|
||||
}
|
||||
w?.logAndReset('done computing avg ');
|
||||
_logger.info('end getUpdateClusterAvg for ${clusterAvg.length} clusters');
|
||||
|
||||
return clusterAvg;
|
||||
|
|
Loading…
Add table
Reference in a new issue