فهرست منبع

[mob] Speed up cluster avg calculation

Neeraj Gupta 1 سال پیش
والد
کامیت
8e6617eed5
2فایلهای تغییر یافته به همراه64 افزوده شده و 4 حذف شده
  1. 28 0
      mobile/lib/face/db.dart
  2. 36 4
      mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart

+ 28 - 0
mobile/lib/face/db.dart

@@ -181,6 +181,32 @@ class FaceMLDataDB {
     return maps.map((e) => e[faceEmbeddingBlob] as Uint8List);
   }
 
+  Future<Map<int, Iterable<Uint8List>>> getFaceEmbeddingsForClusters(
+    Iterable<int> clusterIDs, {
+    int? limit,
+  }) async {
+    final db = await instance.database;
+    final Map<int, List<Uint8List>> result = {};
+
+    final selectQuery = '''
+    SELECT fc.$fcClusterID, fe.$faceEmbeddingBlob
+    FROM $faceClustersTable fc
+    INNER JOIN $facesTable fe ON fc.$fcFaceId = fe.$faceIDColumn
+    WHERE fc.$fcClusterID IN (${clusterIDs.join(',')})
+    ${limit != null ? 'LIMIT $limit' : ''}
+  ''';
+
+    final List<Map<String, dynamic>> maps = await db.rawQuery(selectQuery);
+
+    for (final map in maps) {
+      final clusterID = map[fcClusterID] as int;
+      final faceEmbedding = map[faceEmbeddingBlob] as Uint8List;
+      result.putIfAbsent(clusterID, () => <Uint8List>[]).add(faceEmbedding);
+    }
+
+    return result;
+  }
+
   Future<Face?> getCoverFaceForPerson({
     required int recentFileID,
     String? personID,
@@ -668,9 +694,11 @@ class FaceMLDataDB {
     await db.execute(deletePersonTable);
     await db.execute(dropClusterPersonTable);
     await db.execute(dropNotPersonFeedbackTable);
+    await db.execute(dropClusterSummaryTable);
     await db.execute(createPersonTable);
     await db.execute(createClusterPersonTable);
     await db.execute(createNotPersonFeedbackTable);
+    await db.execute(createClusterSummaryTable);
   }
 
   Future<void> removeFilesFromPerson(List<EnteFile> files, Person p) async {

+ 36 - 4
mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart

@@ -410,15 +410,44 @@ class ClusterFeedbackService {
           allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!),
     );
     int indexedInCurrentRun = 0;
+    final EnteWatch? w = kDebugMode ? EnteWatch("computeAvg") : null;
+    w?.start();
 
+    w?.log(
+      'reading embeddings for $maxClusterInCurrentRun or ${sortedClusterIDs.length} clusters',
+    );
+    final int maxEmbeddingToRead = 10000;
+    int currentPendingRead = 0;
+    List<int> clusterIdsToRead = [];
     for (final clusterID in sortedClusterIDs) {
       if (maxClusterInCurrentRun-- <= 0) {
         break;
       }
-      indexedInCurrentRun++;
+      if (currentPendingRead == 0) {
+        currentPendingRead = allClusterIdsToCountMap[clusterID] ?? 0;
+        clusterIdsToRead.add(clusterID);
+      } else {
+        if ((currentPendingRead + allClusterIdsToCountMap[clusterID]!) <
+            maxEmbeddingToRead) {
+          clusterIdsToRead.add(clusterID);
+          currentPendingRead += allClusterIdsToCountMap[clusterID]!;
+        } else {
+          break;
+        }
+      }
+    }
+
+    final Map<int, Iterable<Uint8List>> clusterEmbeddings = await FaceMLDataDB
+        .instance
+        .getFaceEmbeddingsForClusters(clusterIdsToRead);
+
+    w?.logAndReset(
+      'read  $currentPendingRead embeddings for ${clusterEmbeddings.length} clusters',
+    );
+
+    for (final clusterID in clusterEmbeddings.keys) {
       late List<double> avg;
-      final Iterable<Uint8List> embedings =
-          await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
+      final Iterable<Uint8List> embedings = clusterEmbeddings[clusterID]!;
       final List<double> sum = List.filled(192, 0);
       for (final embedding in embedings) {
         final data = EVector.fromBuffer(embedding).values;
@@ -431,12 +460,14 @@ class ClusterFeedbackService {
       updatesForClusterSummary[clusterID] =
           (avgEmbeedingBuffer, embedings.length);
       // store the intermediate updates
+      indexedInCurrentRun++;
       if (updatesForClusterSummary.length > 100) {
         await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
         updatesForClusterSummary.clear();
         if (kDebugMode) {
           _logger.info(
-              'getUpdateClusterAvg $indexedInCurrentRun clusters in current one');
+            'getUpdateClusterAvg $indexedInCurrentRun clusters in current one',
+          );
         }
       }
       clusterAvg[clusterID] = avg;
@@ -444,6 +475,7 @@ class ClusterFeedbackService {
     if (updatesForClusterSummary.isNotEmpty) {
       await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
     }
+    w?.logAndReset('done computing avg ');
     _logger.info('end getUpdateClusterAvg for ${clusterAvg.length} clusters');
 
     return clusterAvg;