Sfoglia il codice sorgente

[mob] Re-cluster when removing file from person/cluster

laurenspriem 1 anno fa
parent
commit
e20f13f02b

+ 69 - 0
mobile/lib/face/db.dart

@@ -344,6 +344,17 @@ class FaceMLDataDB {
     return maps.map((e) => e[fcFaceId] as String).toSet();
   }
 
+  /// Returns the IDs of all faces whose cluster is mapped to [personID].
+  ///
+  /// Joins [faceClustersTable] with [clusterPersonTable] on the cluster ID and
+  /// collects the [fcFaceId] column of every matching row into a set, so each
+  /// face ID appears at most once.
+  Future<Iterable<String>> getFaceIDsForPerson(String personID) async {
+    final sqliteDb = await instance.sqliteAsyncDB;
+    final selectFaceIdsOfPerson =
+        'SELECT $fcFaceId FROM $faceClustersTable LEFT JOIN $clusterPersonTable '
+        'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn '
+        'WHERE $clusterPersonTable.$personIdColumn = ?';
+    // The person ID is bound as a query parameter, not spliced into the SQL.
+    final rows = await sqliteDb.getAll(selectFaceIdsOfPerson, [personID]);
+    return <String>{for (final row in rows) row[fcFaceId] as String};
+  }
+
   Future<Iterable<double>> getBlurValuesForCluster(int clusterID) async {
     final db = await instance.sqliteAsyncDB;
     const String query = '''
@@ -588,6 +599,44 @@ class FaceMLDataDB {
     return result;
   }
 
+  /// Reads the stored embeddings for the given face IDs.
+  ///
+  /// Returns a map from face ID to the [faceEmbeddingBlob] value found in
+  /// [facesTable]. IDs that have no row in that table are simply absent from
+  /// the result. An empty [faceIDs] yields an empty map without touching the
+  /// database.
+  ///
+  /// Rows are fetched in pages of 10000, and reading stops once more than
+  /// 10000 embeddings have been collected, so the result can be incomplete
+  /// for very large inputs.
+  Future<Map<String, Uint8List>> getFaceEmbeddingMapForFaces(
+    Iterable<String> faceIDs,
+  ) async {
+    // Walk the (possibly lazy) iterable exactly once instead of once per log
+    // line and once per page.
+    final requestedIDs = faceIDs.toList(growable: false);
+    _logger.info('reading face embeddings for ${requestedIDs.length} faces');
+
+    final Map<String, Uint8List> result = {};
+    if (requestedIDs.isEmpty) {
+      // Nothing to look up, so skip the query entirely.
+      _logger.info(
+        'done reading face embeddings for ${requestedIDs.length} faces',
+      );
+      return result;
+    }
+    final db = await instance.sqliteAsyncDB;
+
+    // The IN (...) list does not change between pages, so build it once.
+    // Single quotes are doubled so an ID can never terminate its SQL string
+    // literal early.
+    String quoted(String id) {
+      final escaped = id.replaceAll("'", "''");
+      return "'$escaped'";
+    }
+
+    final inList = requestedIDs.map(quoted).join(',');
+
+    // Define the batch size
+    const batchSize = 10000;
+    int offset = 0;
+
+    while (true) {
+      // Query a batch of rows; the ORDER BY keeps OFFSET paging stable.
+      final String query = '''
+        SELECT $faceIDColumn, $faceEmbeddingBlob
+        FROM $facesTable
+        WHERE $faceIDColumn IN ($inList)
+        ORDER BY $faceIDColumn DESC
+        LIMIT $batchSize OFFSET $offset
+      ''';
+      final List<Map<String, dynamic>> maps = await db.getAll(query);
+      for (final map in maps) {
+        final faceID = map[faceIDColumn] as String;
+        result[faceID] = map[faceEmbeddingBlob] as Uint8List;
+      }
+      // A short (or empty) page is the last one; no need for another query.
+      if (maps.length < batchSize) {
+        break;
+      }
+      // NOTE(review): cap carried over unchanged from the original code. It
+      // silently truncates the result after the second page; confirm whether
+      // this limit is intentional.
+      if (result.length > 10000) {
+        break;
+      }
+      offset += batchSize;
+    }
+    _logger.info('done reading face embeddings for ${requestedIDs.length} faces');
+    return result;
+  }
+
   Future<int> getTotalFaceCount({
     double minFaceScore = kMinHighQualityFaceScore,
   }) async {
@@ -679,6 +728,26 @@ class FaceMLDataDB {
     );
   }
 
+  /// Stores "not this person" feedback for several clusters in one batch.
+  ///
+  /// For every entry of [clusterToPersonID] (cluster ID -> person ID) one row
+  /// with [personIdColumn] and [clusterIDColumn] is inserted into
+  /// [notPersonFeedback]; a conflicting existing row is replaced.
+  Future<void> bulkCaptureNotPersonFeedback(
+    Map<int, String> clusterToPersonID,
+  ) async {
+    final database = await instance.database;
+    final pendingInserts = database.batch();
+    clusterToPersonID.forEach((clusterID, personID) {
+      pendingInserts.insert(
+        notPersonFeedback,
+        {
+          personIdColumn: personID,
+          clusterIDColumn: clusterID,
+        },
+        conflictAlgorithm: ConflictAlgorithm.replace,
+      );
+    });
+    // The per-row insert results are not needed, so do not collect them.
+    await pendingInserts.commit(noResult: true);
+  }
+
   Future<int> removeClusterToPerson({
     required String personID,
     required int clusterID,

+ 79 - 8
mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart

@@ -117,17 +117,90 @@ class ClusterFeedbackService {
     List<EnteFile> files,
     PersonEntity p,
   ) async {
-    await FaceMLDataDB.instance.removeFilesFromPerson(files, p.remoteID);
-    Bus.instance.fire(PeopleChangedEvent());
+    try {
+      // Get the relevant faces to be removed
+      final faceIDs = await FaceMLDataDB.instance
+          .getFaceIDsForPerson(p.remoteID)
+          .then((iterable) => iterable.toList());
+      faceIDs.retainWhere((faceID) {
+        final fileID = getFileIdFromFaceId(faceID);
+        return files.any((file) => file.uploadedFileID == fileID);
+      });
+      final embeddings =
+          await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs);
+
+      final fileIDToCreationTime =
+          await FilesDB.instance.getFileIDToCreationTime();
+
+      // Re-cluster within the deleted faces
+      final newFaceIdToClusterID =
+          await FaceClusteringService.instance.predictWithinClusterComputer(
+        embeddings,
+        fileIDToCreationTime: fileIDToCreationTime,
+        distanceThreshold: 0.20,
+      );
+      if (newFaceIdToClusterID == null || newFaceIdToClusterID.isEmpty) {
+        return;
+      }
+
+      // Update the deleted faces
+      await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID);
+
+      // Make sure the deleted faces don't get suggested in the future
+      final notClusterIdToPersonId = <int, String>{};
+      for (final clusterId in newFaceIdToClusterID.values.toSet()) {
+        notClusterIdToPersonId[clusterId] = p.remoteID;
+      }
+      await FaceMLDataDB.instance
+          .bulkCaptureNotPersonFeedback(notClusterIdToPersonId);
+
+      Bus.instance.fire(PeopleChangedEvent());
+      return;
+    } catch (e, s) {
+      _logger.severe("Error in removeFilesFromPerson", e, s);
+      rethrow;
+    }
   }
 
   Future<void> removeFilesFromCluster(
     List<EnteFile> files,
     int clusterID,
   ) async {
-    await FaceMLDataDB.instance.removeFilesFromCluster(files, clusterID);
-    Bus.instance.fire(PeopleChangedEvent());
-    return;
+    try {
+      // Get the relevant faces to be removed
+      final faceIDs = await FaceMLDataDB.instance
+          .getFaceIDsForCluster(clusterID)
+          .then((iterable) => iterable.toList());
+      faceIDs.retainWhere((faceID) {
+        final fileID = getFileIdFromFaceId(faceID);
+        return files.any((file) => file.uploadedFileID == fileID);
+      });
+      final embeddings =
+          await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs);
+
+      final fileIDToCreationTime =
+          await FilesDB.instance.getFileIDToCreationTime();
+
+      // Re-cluster within the deleted faces
+      final newFaceIdToClusterID =
+          await FaceClusteringService.instance.predictWithinClusterComputer(
+        embeddings,
+        fileIDToCreationTime: fileIDToCreationTime,
+        distanceThreshold: 0.20,
+      );
+      if (newFaceIdToClusterID == null || newFaceIdToClusterID.isEmpty) {
+        return;
+      }
+
+      // Update the deleted faces
+      await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID);
+
+      Bus.instance.fire(PeopleChangedEvent());
+      return;
+    } catch (e, s) {
+      _logger.severe("Error in removeFilesFromCluster", e, s);
+      rethrow;
+    }
   }
 
   Future<void> addFilesToCluster(List<String> faceIDs, int clusterID) async {
@@ -205,10 +278,8 @@ class ClusterFeedbackService {
 
     final faceIDs = await faceMlDb.getFaceIDsForCluster(clusterID);
     final originalFaceIDsSet = faceIDs.toSet();
-    final fileIDs = faceIDs.map((e) => getFileIdFromFaceId(e)).toList();
 
-    final embeddings = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs);
-    embeddings.removeWhere((key, value) => !faceIDs.contains(key));
+    final embeddings = await faceMlDb.getFaceEmbeddingMapForFaces(faceIDs);
 
     final fileIDToCreationTime =
         await FilesDB.instance.getFileIDToCreationTime();