Explorar el Código

Merge branch 'mobile_face' of https://github.com/ente-io/auth into mobile_face

Neeraj Gupta hace 1 año
padre
commit
cc682a0a09

+ 2 - 2
mobile/lib/db/files_db.dart

@@ -1316,8 +1316,8 @@ class FilesDB {
   }
 
   Future<Map<int, int>> getFileIDToCreationTime() async {
-    final db = await instance.database;
-    final rows = await db.rawQuery(
+    final db = await instance.sqliteAsyncDB;
+    final rows = await db.getAll(
       '''
       SELECT $columnUploadedFileID, $columnCreationTime
       FROM $filesTable

+ 1 - 0
mobile/lib/events/files_updated_event.dart

@@ -27,4 +27,5 @@ enum EventType {
   unhide,
   coverChanged,
   peopleChanged,
+  peopleClusterChanged,
 }

+ 20 - 1
mobile/lib/events/people_changed_event.dart

@@ -1,3 +1,22 @@
 import "package:photos/events/event.dart";
+import "package:photos/models/file/file.dart";
 
-class PeopleChangedEvent extends Event {}
+/// Event carrying the details of a people-related change.
+class PeopleChangedEvent extends Event {
+  /// Files the change applies to, when the emitter supplies them.
+  final List<EnteFile>? relevantFiles;
+
+  /// Kind of change being reported; defaults to [PeopleEventType.defaultType].
+  final PeopleEventType type;
+
+  /// Free-form label naming the emitter; empty when not provided.
+  final String source;
+
+  PeopleChangedEvent({
+    this.relevantFiles,
+    this.type = PeopleEventType.defaultType,
+    this.source = "",
+  });
+
+  /// Describes the event by runtime type, event type name and [source].
+  @override
+  String get reason {
+    final typeName = type.name;
+    return '$runtimeType{type: $typeName, "via": $source}';
+  }
+}
+
+/// Kinds of change a [PeopleChangedEvent] can report.
+enum PeopleEventType {
+  /// Value used when the emitter does not specify a type.
+  defaultType,
+
+  /// NOTE(review): presumably emitted after files were removed from a
+  /// cluster — confirm against the emitters.
+  removedFilesFromCluster,
+}

+ 198 - 41
mobile/lib/face/db.dart

@@ -12,6 +12,7 @@ import 'package:photos/face/db_fields.dart';
 import "package:photos/face/db_model_mappers.dart";
 import "package:photos/face/model/face.dart";
 import "package:photos/models/file/file.dart";
+import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
 import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
 import 'package:sqflite/sqflite.dart';
 import 'package:sqlite_async/sqlite_async.dart' as sqlite_async;
@@ -160,27 +161,27 @@ class FaceMLDataDB {
     final db = await instance.database;
     // find out clusterIds that are assigned to other persons using the clusters table
     final List<Map<String, dynamic>> maps = await db.rawQuery(
-      'SELECT $cluserIDColumn FROM $clusterPersonTable WHERE $personIdColumn != ? AND $personIdColumn IS NOT NULL',
+      'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn != ? AND $personIdColumn IS NOT NULL',
       [personID],
     );
     final Set<int> ignoredClusterIDs =
-        maps.map((e) => e[cluserIDColumn] as int).toSet();
+        maps.map((e) => e[clusterIDColumn] as int).toSet();
     final List<Map<String, dynamic>> rejectMaps = await db.rawQuery(
-      'SELECT $cluserIDColumn FROM $notPersonFeedback WHERE $personIdColumn = ?',
+      'SELECT $clusterIDColumn FROM $notPersonFeedback WHERE $personIdColumn = ?',
       [personID],
     );
     final Set<int> rejectClusterIDs =
-        rejectMaps.map((e) => e[cluserIDColumn] as int).toSet();
+        rejectMaps.map((e) => e[clusterIDColumn] as int).toSet();
     return ignoredClusterIDs.union(rejectClusterIDs);
   }
 
   Future<Set<int>> getPersonClusterIDs(String personID) async {
     final db = await instance.database;
     final List<Map<String, dynamic>> maps = await db.rawQuery(
-      'SELECT $cluserIDColumn FROM $clusterPersonTable WHERE $personIdColumn = ?',
+      'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn = ?',
       [personID],
     );
-    return maps.map((e) => e[cluserIDColumn] as int).toSet();
+    return maps.map((e) => e[clusterIDColumn] as int).toSet();
   }
 
   Future<void> clearTable() async {
@@ -249,16 +250,16 @@ class FaceMLDataDB {
       }
       final cluterRows = await db.query(
         clusterPersonTable,
-        columns: [cluserIDColumn],
+        columns: [clusterIDColumn],
         where: '$personIdColumn = ?',
         whereArgs: [personID],
       );
       final clusterIDs =
-          cluterRows.map((e) => e[cluserIDColumn] as int).toList();
+          cluterRows.map((e) => e[clusterIDColumn] as int).toList();
       final List<Map<String, dynamic>> faceMaps = await db.rawQuery(
         'SELECT * FROM $facesTable where '
         '$faceIDColumn in (SELECT $fcFaceId from $faceClustersTable where  $fcClusterID IN (${clusterIDs.join(",")}))'
-        'AND $fileIDColumn in (${fileId.join(",")}) AND $faceScore > $kMinHighQualityFaceScore ORDER BY $faceScore DESC',
+        'AND $fileIDColumn in (${fileId.join(",")}) AND $faceScore > $kMinimumQualityFaceScore ORDER BY $faceScore DESC',
       );
       if (faceMaps.isNotEmpty) {
         if (avatarFileId != null) {
@@ -308,8 +309,6 @@ class FaceMLDataDB {
         faceBlur,
         imageHeight,
         imageWidth,
-        faceArea,
-        faceVisibilityScore,
         mlVersionColumn,
       ],
       where: '$fileIDColumn = ?',
@@ -334,16 +333,60 @@ class FaceMLDataDB {
   }
 
   Future<Iterable<String>> getFaceIDsForCluster(int clusterID) async {
-    final db = await instance.database;
-    final List<Map<String, dynamic>> maps = await db.query(
-      faceClustersTable,
-      columns: [fcFaceId],
-      where: '$fcClusterID = ?',
-      whereArgs: [clusterID],
+    final db = await instance.sqliteAsyncDB;
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      'SELECT $fcFaceId FROM $faceClustersTable '
+      'WHERE $faceClustersTable.$fcClusterID = ?',
+      [clusterID],
     );
     return maps.map((e) => e[fcFaceId] as String).toSet();
   }
 
+  /// Returns the IDs of the faces whose cluster is mapped to [personID].
+  ///
+  /// Joins [faceClustersTable] with [clusterPersonTable] on the cluster ID and
+  /// collects the matching face IDs into a set.
+  Future<Iterable<String>> getFaceIDsForPerson(String personID) async {
+    final db = await instance.sqliteAsyncDB;
+    const sql =
+        'SELECT $fcFaceId FROM $faceClustersTable LEFT JOIN $clusterPersonTable '
+        'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn '
+        'WHERE $clusterPersonTable.$personIdColumn = ?';
+    final rows = await db.getAll(sql, [personID]);
+    return <String>{for (final row in rows) row[fcFaceId] as String};
+  }
+
+  /// Returns the blur values of the faces assigned to [clusterID].
+  ///
+  /// NOTE(review): the values are collected into a set, so faces sharing the
+  /// exact same blur value collapse into one entry — confirm callers do not
+  /// need one value per face.
+  Future<Iterable<double>> getBlurValuesForCluster(int clusterID) async {
+    final db = await instance.sqliteAsyncDB;
+    const String sql = '''
+        SELECT $facesTable.$faceBlur 
+        FROM $facesTable 
+        JOIN $faceClustersTable ON $facesTable.$faceIDColumn = $faceClustersTable.$fcFaceId 
+        WHERE $faceClustersTable.$fcClusterID = ?
+      ''';
+    // Alternative formulation kept for reference:
+    // const String query2 = '''
+    //     SELECT $faceBlur
+    //     FROM $facesTable
+    //     WHERE $faceIDColumn IN (SELECT $fcFaceId FROM $faceClustersTable WHERE $fcClusterID = ?)
+    //   ''';
+    final List<Map<String, dynamic>> rows = await db.getAll(sql, [clusterID]);
+    return <double>{for (final row in rows) row[faceBlur] as double};
+  }
+
+  /// Maps face IDs to blur values for faces whose blur is greater than 1 and
+  /// less than [maxBlurValue].
+  ///
+  /// Rows are read in ascending blur order, and the returned map keeps that
+  /// insertion order.
+  Future<Map<String, double>> getFaceIDsToBlurValues(
+    int maxBlurValue,
+  ) async {
+    final db = await instance.sqliteAsyncDB;
+    final List<Map<String, dynamic>> rows = await db.getAll(
+      'SELECT $faceIDColumn, $faceBlur FROM $facesTable WHERE $faceBlur < $maxBlurValue AND $faceBlur > 1 ORDER BY $faceBlur ASC',
+    );
+    return <String, double>{
+      for (final row in rows) row[faceIDColumn] as String: row[faceBlur] as double,
+    };
+  }
+
   Future<Map<String, int?>> getFaceIdsToClusterIds(
     Iterable<String> faceIds,
   ) async {
@@ -376,14 +419,14 @@ class FaceMLDataDB {
   }
 
   Future<void> forceUpdateClusterIds(
-    Map<String, int> faceIDToPersonID,
+    Map<String, int> faceIDToClusterID,
   ) async {
     final db = await instance.database;
 
     // Start a batch
     final batch = db.batch();
 
-    for (final map in faceIDToPersonID.entries) {
+    for (final map in faceIDToClusterID.entries) {
       final faceID = map.key;
       final clusterID = map.value;
       batch.insert(
@@ -410,12 +453,64 @@ class FaceMLDataDB {
     );
   }
 
+  /// Reads the faces that qualify for clustering, in batches.
+  ///
+  /// Selects rows from [facesTable] whose score is greater than [minScore] and
+  /// whose blur value is greater than [minClarity], ordered by face ID
+  /// descending. Rows are fetched [batchSize] at a time starting at [offset],
+  /// and each face is paired with the cluster ID returned for it by
+  /// [getFaceIdsToClusterIds].
+  ///
+  /// Reading stops when a batch comes back empty or once the result holds at
+  /// least [maxFaces] entries. The limit is only checked after a whole batch
+  /// has been added, so the result may contain more than [maxFaces] faces.
+  Future<Set<FaceInfoForClustering>> getFaceInfoForClustering({
+    double minScore = kMinimumQualityFaceScore,
+    int minClarity = kLaplacianHardThreshold,
+    int maxFaces = 20000,
+    int offset = 0,
+    int batchSize = 10000,
+  }) async {
+    // Label the timing logs with this method's own name.
+    final EnteWatch w = EnteWatch("getFaceInfoForClustering")..start();
+    w.logAndReset(
+      'reading as float offset: $offset, maxFaces: $maxFaces, batchSize: $batchSize',
+    );
+    final db = await instance.sqliteAsyncDB;
+
+    final Set<FaceInfoForClustering> result = {};
+    while (true) {
+      // Query a batch of rows
+      final List<Map<String, dynamic>> rows = await db.getAll(
+        'SELECT $faceIDColumn, $faceEmbeddingBlob, $faceScore, $faceBlur, $isSideways FROM $facesTable'
+        ' WHERE $faceScore > $minScore AND $faceBlur > $minClarity'
+        ' ORDER BY $faceIDColumn'
+        ' DESC LIMIT $batchSize OFFSET $offset',
+      );
+      // Break the loop if no more rows
+      if (rows.isEmpty) {
+        break;
+      }
+      // Resolve the cluster assignment of every face in this batch at once.
+      final faceIds = <String>[
+        for (final row in rows) row[faceIDColumn] as String,
+      ];
+      final faceIdToClusterId = await getFaceIdsToClusterIds(faceIds);
+      for (final row in rows) {
+        final faceID = row[faceIDColumn] as String;
+        result.add(
+          FaceInfoForClustering(
+            faceID: faceID,
+            clusterId: faceIdToClusterId[faceID],
+            embeddingBytes: row[faceEmbeddingBlob] as Uint8List,
+            faceScore: row[faceScore] as double,
+            blurValue: row[faceBlur] as double,
+            // Stored as an integer flag; 1 means sideways.
+            isSideways: (row[isSideways] as int) == 1,
+          ),
+        );
+      }
+      if (result.length >= maxFaces) {
+        break;
+      }
+      offset += batchSize;
+    }
+    w.stopWithLog('done reading face embeddings ${result.length}');
+    return result;
+  }
+
   /// Returns a map of faceID to record of clusterId and faceEmbeddingBlob
   ///
   /// Only selects faces with score greater than [minScore] and blur score greater than [minClarity]
   Future<Map<String, (int?, Uint8List)>> getFaceEmbeddingMap({
-    double minScore = kMinHighQualityFaceScore,
-    int minClarity = kLaplacianThreshold,
+    double minScore = kMinimumQualityFaceScore,
+    int minClarity = kLaplacianHardThreshold,
     int maxFaces = 20000,
     int offset = 0,
     int batchSize = 10000,
@@ -481,7 +576,7 @@ class FaceMLDataDB {
         facesTable,
         columns: [faceIDColumn, faceEmbeddingBlob],
         where:
-            '$faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianThreshold AND $fileIDColumn IN (${fileIDs.join(",")})',
+            '$faceScore > $kMinimumQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold AND $fileIDColumn IN (${fileIDs.join(",")})',
         limit: batchSize,
         offset: offset,
         orderBy: '$faceIDColumn DESC',
@@ -503,12 +598,50 @@ class FaceMLDataDB {
     return result;
   }
 
+  /// Returns the embedding blob for each of the given [faceIDs].
+  ///
+  /// Faces are read from [facesTable] in batches, ordered by face ID
+  /// descending. Returns an empty map without querying when [faceIDs] is
+  /// empty.
+  ///
+  /// NOTE(review): reading stops after the batch that pushes the result past
+  /// the batch size (10000 entries), so very large inputs are silently
+  /// truncated — confirm this cap is intended.
+  Future<Map<String, Uint8List>> getFaceEmbeddingMapForFaces(
+    Iterable<String> faceIDs,
+  ) async {
+    _logger.info('reading face embeddings for ${faceIDs.length} faces');
+    if (faceIDs.isEmpty) {
+      return {};
+    }
+    final db = await instance.sqliteAsyncDB;
+
+    // Define the batch size
+    const batchSize = 10000;
+    int offset = 0;
+
+    // Build the IN list once; it is identical for every batch. Single quotes
+    // are doubled so an ID can never terminate its SQL string literal.
+    final quotedIDs = faceIDs.map((id) {
+      final escaped = id.replaceAll("'", "''");
+      return "'$escaped'";
+    }).join(",");
+
+    final Map<String, Uint8List> result = {};
+    while (true) {
+      // Query a batch of rows
+      final String query = '''
+        SELECT $faceIDColumn, $faceEmbeddingBlob
+        FROM $facesTable
+        WHERE $faceIDColumn IN ($quotedIDs)
+        ORDER BY $faceIDColumn DESC
+        LIMIT $batchSize OFFSET $offset
+      ''';
+      final List<Map<String, dynamic>> rows = await db.getAll(query);
+      // Break the loop if no more rows
+      if (rows.isEmpty) {
+        break;
+      }
+      for (final row in rows) {
+        result[row[faceIDColumn] as String] =
+            row[faceEmbeddingBlob] as Uint8List;
+      }
+      // Same cap as before (batchSize is 10000), now tied to the constant.
+      if (result.length > batchSize) {
+        break;
+      }
+      offset += batchSize;
+    }
+    _logger.info('done reading face embeddings for ${faceIDs.length} faces');
+    return result;
+  }
+
   Future<int> getTotalFaceCount({
-    double minFaceScore = kMinHighQualityFaceScore,
+    double minFaceScore = kMinimumQualityFaceScore,
   }) async {
     final db = await instance.sqliteAsyncDB;
     final List<Map<String, dynamic>> maps = await db.getAll(
-      'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $minFaceScore AND $faceBlur > $kLaplacianThreshold',
+      'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $minFaceScore AND $faceBlur > $kLaplacianHardThreshold',
     );
     return maps.first['count'] as int;
   }
@@ -517,7 +650,7 @@ class FaceMLDataDB {
     final db = await instance.sqliteAsyncDB;
 
     final List<Map<String, dynamic>> totalFacesMaps = await db.getAll(
-      'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianThreshold',
+      'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinimumQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold',
     );
     final int totalFaces = totalFacesMaps.first['count'] as int;
 
@@ -530,11 +663,11 @@ class FaceMLDataDB {
   }
 
   Future<int> getBlurryFaceCount([
-    int blurThreshold = kLaplacianThreshold,
+    int blurThreshold = kLaplacianHardThreshold,
   ]) async {
     final db = await instance.database;
     final List<Map<String, dynamic>> maps = await db.rawQuery(
-      'SELECT COUNT(*) as count FROM $facesTable WHERE $faceBlur <= $blurThreshold AND $faceScore > $kMinHighQualityFaceScore',
+      'SELECT COUNT(*) as count FROM $facesTable WHERE $faceBlur <= $blurThreshold AND $faceScore > $kMinimumQualityFaceScore',
     );
     return maps.first['count'] as int;
   }
@@ -555,7 +688,7 @@ class FaceMLDataDB {
       clusterPersonTable,
       {
         personIdColumn: personID,
-        cluserIDColumn: clusterID,
+        clusterIDColumn: clusterID,
       },
     );
   }
@@ -572,7 +705,7 @@ class FaceMLDataDB {
         clusterPersonTable,
         {
           personIdColumn: personID,
-          cluserIDColumn: clusterID,
+          clusterIDColumn: clusterID,
         },
         conflictAlgorithm: ConflictAlgorithm.replace,
       );
@@ -589,11 +722,31 @@ class FaceMLDataDB {
       notPersonFeedback,
       {
         personIdColumn: personID,
-        cluserIDColumn: clusterID,
+        clusterIDColumn: clusterID,
       },
     );
   }
 
+  /// Inserts one [notPersonFeedback] row per entry of [clusterToPersonID]
+  /// (cluster ID → person ID), committing all inserts as a single batch.
+  Future<void> bulkCaptureNotPersonFeedback(
+    Map<int, String> clusterToPersonID,
+  ) async {
+    final db = await instance.database;
+    final batch = db.batch();
+    for (final MapEntry(key: clusterID, value: personID)
+        in clusterToPersonID.entries) {
+      final row = {
+        personIdColumn: personID,
+        clusterIDColumn: clusterID,
+      };
+      batch.insert(
+        notPersonFeedback,
+        row,
+        conflictAlgorithm: ConflictAlgorithm.replace,
+      );
+    }
+    // The per-insert results are not needed.
+    await batch.commit(noResult: true);
+  }
+
   Future<int> removeClusterToPerson({
     required String personID,
     required int clusterID,
@@ -601,7 +754,7 @@ class FaceMLDataDB {
     final db = await instance.database;
     return db.delete(
       clusterPersonTable,
-      where: '$personIdColumn = ? AND $cluserIDColumn = ?',
+      where: '$personIdColumn = ? AND $clusterIDColumn = ?',
       whereArgs: [personID, clusterID],
     );
   }
@@ -613,13 +766,13 @@ class FaceMLDataDB {
       final List<Map<String, dynamic>> maps = await db.rawQuery(
         'SELECT $faceClustersTable.$fcClusterID, $fcFaceId FROM $faceClustersTable '
         'INNER JOIN $clusterPersonTable '
-        'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$cluserIDColumn '
+        'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn '
         'WHERE $clusterPersonTable.$personIdColumn = ?',
         [personID],
       );
       final Map<int, Set<int>> result = {};
       for (final map in maps) {
-        final clusterID = map[cluserIDColumn] as int;
+        final clusterID = map[clusterIDColumn] as int;
         final String faceID = map[fcFaceId] as String;
         final fileID = int.parse(faceID.split('_').first);
         result[fileID] = (result[fileID] ?? {})..add(clusterID);
@@ -664,7 +817,7 @@ class FaceMLDataDB {
       batch.insert(
         clusterSummaryTable,
         {
-          cluserIDColumn: cluserID,
+          clusterIDColumn: cluserID,
           avgColumn: avg,
           countColumn: count,
         },
@@ -676,12 +829,16 @@ class FaceMLDataDB {
   }
 
   /// Returns a map of clusterID to (avg embedding, count)
-  Future<Map<int, (Uint8List, int)>> clusterSummaryAll() async {
-    final db = await instance.database;
+  Future<Map<int, (Uint8List, int)>> getAllClusterSummary([
+    int? minClusterSize,
+  ]) async {
+    final db = await instance.sqliteAsyncDB;
     final Map<int, (Uint8List, int)> result = {};
-    final rows = await db.rawQuery('SELECT * from $clusterSummaryTable');
+    final rows = await db.getAll(
+      'SELECT * FROM $clusterSummaryTable${minClusterSize != null ? ' WHERE $countColumn >= $minClusterSize' : ''}',
+    );
     for (final r in rows) {
-      final id = r[cluserIDColumn] as int;
+      final id = r[clusterIDColumn] as int;
       final avg = r[avgColumn] as Uint8List;
       final count = r[countColumn] as int;
       result[id] = (avg, count);
@@ -692,11 +849,11 @@ class FaceMLDataDB {
   Future<Map<int, String>> getClusterIDToPersonID() async {
     final db = await instance.database;
     final List<Map<String, dynamic>> maps = await db.rawQuery(
-      'SELECT $personIdColumn, $cluserIDColumn FROM $clusterPersonTable',
+      'SELECT $personIdColumn, $clusterIDColumn FROM $clusterPersonTable',
     );
     final Map<int, String> result = {};
     for (final map in maps) {
-      result[map[cluserIDColumn] as int] = map[personIdColumn] as String;
+      result[map[clusterIDColumn] as int] = map[personIdColumn] as String;
     }
     return result;
   }
@@ -741,7 +898,7 @@ class FaceMLDataDB {
     final db = await instance.database;
     final faceIdsResult = await db.rawQuery(
       'SELECT $fcFaceId FROM $faceClustersTable LEFT JOIN $clusterPersonTable '
-      'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$cluserIDColumn '
+      'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn '
       'WHERE $clusterPersonTable.$personIdColumn = ?',
       [personID],
     );

+ 8 - 10
mobile/lib/face/db_fields.dart

@@ -8,8 +8,7 @@ const faceDetectionColumn = 'detection';
 const faceEmbeddingBlob = 'eBlob';
 const faceScore = 'score';
 const faceBlur = 'blur';
-const faceArea = 'area';
-const faceVisibilityScore = 'visibility';
+const isSideways = 'is_sideways';
 const imageWidth = 'width';
 const imageHeight = 'height';
 const faceClusterId = 'cluster_id';
@@ -22,10 +21,9 @@ const createFacesTable = '''CREATE TABLE IF NOT EXISTS $facesTable (
   $faceEmbeddingBlob BLOB NOT NULL,
   $faceScore  REAL NOT NULL,
   $faceBlur REAL NOT NULL DEFAULT $kLapacianDefault,
+  $isSideways	INTEGER NOT NULL DEFAULT 0,
   $imageHeight	INTEGER NOT NULL DEFAULT 0,
   $imageWidth	INTEGER NOT NULL DEFAULT 0,
-  $faceArea	INTEGER NOT NULL DEFAULT 0,
-  $faceVisibilityScore	INTEGER NOT NULL DEFAULT -1,
   $mlVersionColumn	INTEGER NOT NULL DEFAULT -1,
   PRIMARY KEY($fileIDColumn, $faceIDColumn)
   );
@@ -62,13 +60,13 @@ const deletePersonTable = 'DROP TABLE IF EXISTS $personTable';
 // Clusters Table Fields & Schema Queries
 const clusterPersonTable = 'cluster_person';
 const personIdColumn = 'person_id';
-const cluserIDColumn = 'cluster_id';
+const clusterIDColumn = 'cluster_id';
 
 const createClusterPersonTable = '''
 CREATE TABLE IF NOT EXISTS $clusterPersonTable (
   $personIdColumn	TEXT NOT NULL,
-  $cluserIDColumn	INTEGER NOT NULL,
-  PRIMARY KEY($personIdColumn, $cluserIDColumn)
+  $clusterIDColumn	INTEGER NOT NULL,
+  PRIMARY KEY($personIdColumn, $clusterIDColumn)
 );
 ''';
 const dropClusterPersonTable = 'DROP TABLE IF EXISTS $clusterPersonTable';
@@ -80,10 +78,10 @@ const avgColumn = 'avg';
 const countColumn = 'count';
 const createClusterSummaryTable = '''
 CREATE TABLE IF NOT EXISTS $clusterSummaryTable (
-  $cluserIDColumn	INTEGER NOT NULL,
+  $clusterIDColumn	INTEGER NOT NULL,
   $avgColumn BLOB NOT NULL,
   $countColumn INTEGER NOT NULL,
-  PRIMARY KEY($cluserIDColumn)
+  PRIMARY KEY($clusterIDColumn)
 );
 ''';
 
@@ -97,7 +95,7 @@ const notPersonFeedback = 'not_person_feedback';
 const createNotPersonFeedbackTable = '''
 CREATE TABLE IF NOT EXISTS $notPersonFeedback (
   $personIdColumn	TEXT NOT NULL,
-  $cluserIDColumn	INTEGER NOT NULL
+  $clusterIDColumn	INTEGER NOT NULL
 );
 ''';
 const dropNotPersonFeedbackTable = 'DROP TABLE IF EXISTS $notPersonFeedback';

+ 1 - 2
mobile/lib/face/db_model_mappers.dart

@@ -34,9 +34,8 @@ Map<String, dynamic> mapRemoteToFaceDB(Face face) {
     ).writeToBuffer(),
     faceScore: face.score,
     faceBlur: face.blur,
+    isSideways: face.detection.faceIsSideways() ? 1 : 0,
     mlVersionColumn: faceMlVersion,
-    faceArea: face.area(),
-    faceVisibilityScore: face.visibility,
     imageWidth: face.fileInfo?.imageWidth ?? 0,
     imageHeight: face.fileInfo?.imageHeight ?? 0,
   };

+ 72 - 1
mobile/lib/face/model/detection.dart

@@ -1,6 +1,9 @@
+import "dart:math" show min, max;
+
 import "package:logging/logging.dart";
 import "package:photos/face/model/box.dart";
 import "package:photos/face/model/landmark.dart";
+import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart";
 
 /// Stores the face detection data, notably the bounding box and landmarks.
 ///
@@ -19,7 +22,7 @@ class Detection {
 
   bool get isEmpty => box.width == 0 && box.height == 0 && landmarks.isEmpty;
 
-  // emoty box
+  // empty box
   Detection.empty()
       : box = FaceBox(
           xMin: 0,
@@ -89,4 +92,72 @@ class Detection {
       return -1;
     }
   }
+
+  /// Estimates which way the face is turned from its landmarks.
+  ///
+  /// Reads the first five [landmarks] as left eye, right eye, nose, left mouth
+  /// corner and right mouth corner. Returns [FaceDirection.left] or
+  /// [FaceDirection.right] when the nose lies horizontally outside both the
+  /// eyes and the mouth corners on that side, or when the face is upright and
+  /// the nose is within 20% of the horizontal eye distance of that side's
+  /// eye. Returns [FaceDirection.straight] otherwise, including for an empty
+  /// detection.
+  FaceDirection getFaceDirection() {
+    if (isEmpty) {
+      return FaceDirection.straight;
+    }
+    final leftEye = landmarks[0];
+    final rightEye = landmarks[1];
+    final nose = landmarks[2];
+    final leftMouth = landmarks[3];
+    final rightMouth = landmarks[4];
+
+    final double eyeDistanceX = (rightEye.x - leftEye.x).abs();
+    final double eyeDistanceY = (rightEye.y - leftEye.y).abs();
+    final double mouthDistanceY = (rightMouth.y - leftMouth.y).abs();
+
+    // Upright: the nose's y lies between the eyes and the mouth corners, with
+    // a margin of half the respective vertical spread on each side.
+    final bool noseClearsEyes =
+        max(leftEye.y, rightEye.y) + 0.5 * eyeDistanceY < nose.y;
+    final bool noseClearsMouth =
+        nose.y + 0.5 * mouthDistanceY < min(leftMouth.y, rightMouth.y);
+    final bool faceIsUpright = noseClearsEyes && noseClearsMouth;
+
+    final bool noseStickingOutLeft = nose.x < min(leftEye.x, rightEye.x) &&
+        nose.x < min(leftMouth.x, rightMouth.x);
+    final bool noseStickingOutRight = nose.x > max(leftEye.x, rightEye.x) &&
+        nose.x > max(leftMouth.x, rightMouth.x);
+
+    final double closeRange = 0.2 * eyeDistanceX;
+    final bool noseCloseToLeftEye = (nose.x - leftEye.x).abs() < closeRange;
+    final bool noseCloseToRightEye = (nose.x - rightEye.x).abs() < closeRange;
+
+    // A protruding nose counts regardless of uprightness; mere closeness to
+    // an eye only counts for upright faces. Left is checked first.
+    final bool turnedLeft =
+        noseStickingOutLeft || (faceIsUpright && noseCloseToLeftEye);
+    if (turnedLeft) {
+      return FaceDirection.left;
+    }
+    final bool turnedRight =
+        noseStickingOutRight || (faceIsUpright && noseCloseToRightEye);
+    if (turnedRight) {
+      return FaceDirection.right;
+    }
+    return FaceDirection.straight;
+  }
+
+  /// Whether the face is upright but turned strongly to one side.
+  ///
+  /// Reads the first five [landmarks] as left eye, right eye, nose, left mouth
+  /// corner and right mouth corner. The face counts as sideways when it is
+  /// upright and the nose lies more than half the horizontal eye distance
+  /// outside the eyes, and also outside the mouth corners, on the same side.
+  /// Returns `false` for an empty detection.
+  bool faceIsSideways() {
+    if (isEmpty) {
+      return false;
+    }
+    final leftEye = [landmarks[0].x, landmarks[0].y];
+    final rightEye = [landmarks[1].x, landmarks[1].y];
+    final nose = [landmarks[2].x, landmarks[2].y];
+    final leftMouth = [landmarks[3].x, landmarks[3].y];
+    final rightMouth = [landmarks[4].x, landmarks[4].y];
+
+    final double eyeDistanceX = (rightEye[0] - leftEye[0]).abs();
+    final double eyeDistanceY = (rightEye[1] - leftEye[1]).abs();
+    final double mouthDistanceY = (rightMouth[1] - leftMouth[1]).abs();
+
+    // Upright: the nose's y lies between the eyes and the mouth corners, with
+    // a margin of half the respective vertical spread on each side.
+    final bool faceIsUpright =
+        (max(leftEye[1], rightEye[1]) + 0.5 * eyeDistanceY < nose[1]) &&
+            (nose[1] + 0.5 * mouthDistanceY < min(leftMouth[1], rightMouth[1]));
+
+    final bool noseStickingOutLeft =
+        (nose[0] < min(leftEye[0], rightEye[0]) - 0.5 * eyeDistanceX) &&
+            (nose[0] < min(leftMouth[0], rightMouth[0]));
+    // Mirror of the left-hand check: the margin is added beyond the rightmost
+    // eye. Subtracting it instead would accept any nose right of the eye
+    // midpoint.
+    final bool noseStickingOutRight =
+        (nose[0] > max(leftEye[0], rightEye[0]) + 0.5 * eyeDistanceX) &&
+            (nose[0] > max(leftMouth[0], rightMouth[0]));
+
+    return faceIsUpright && (noseStickingOutLeft || noseStickingOutRight);
+  }
 }

+ 2 - 2
mobile/lib/face/model/face.dart

@@ -20,9 +20,9 @@ class Face {
   final double blur;
   FileInfo? fileInfo;
 
-  bool get isBlurry => blur < kLaplacianThreshold;
+  bool get isBlurry => blur < kLaplacianHardThreshold;
 
-  bool get hasHighScore => score > kMinHighQualityFaceScore;
+  bool get hasHighScore => score > kMinimumQualityFaceScore;
 
   bool get isHighQuality => (!isBlurry) && hasHighScore;
 

+ 467 - 75
mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart

@@ -2,19 +2,26 @@ import "dart:async";
 import "dart:developer";
 import "dart:isolate";
 import "dart:math" show max;
-import "dart:typed_data";
+import "dart:typed_data" show Uint8List;
 
+import "package:computer/computer.dart";
+import "package:flutter/foundation.dart" show kDebugMode;
 import "package:logging/logging.dart";
 import "package:ml_linalg/dtype.dart";
 import "package:ml_linalg/vector.dart";
 import "package:photos/generated/protos/ente/common/vector.pb.dart";
 import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
+import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
+import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
 import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
 import "package:simple_cluster/simple_cluster.dart";
 import "package:synchronized/synchronized.dart";
 
 class FaceInfo {
   final String faceID;
+  final double? faceScore;
+  final double? blurValue;
+  final bool? badFace;
   final List<double>? embedding;
   final Vector? vEmbedding;
   int? clusterId;
@@ -23,6 +30,9 @@ class FaceInfo {
   int? fileCreationTime;
   FaceInfo({
     required this.faceID,
+    this.faceScore,
+    this.blurValue,
+    this.badFace,
     this.embedding,
     this.vEmbedding,
     this.clusterId,
@@ -32,8 +42,18 @@ class FaceInfo {
 
 enum ClusterOperation { linearIncrementalClustering, dbscanClustering }
 
+/// Outcome of a clustering run.
+class ClusteringResult {
+  ClusteringResult({
+    required this.newFaceIdToCluster,
+    required this.newClusterSummaries,
+  });
+
+  /// Cluster ID assigned to each face ID.
+  final Map<String, int> newFaceIdToCluster;
+
+  /// Per-cluster `(bytes, count)` records keyed by cluster ID; may be null.
+  ///
+  /// NOTE(review): presumably the bytes hold the cluster's average embedding
+  /// — confirm against the producer.
+  final Map<int, (Uint8List, int)>? newClusterSummaries;
+}
+
 class FaceClusteringService {
   final _logger = Logger("FaceLinearClustering");
+  final _computer = Computer.shared();
 
   Timer? _inactivityTimer;
   final Duration _inactivityDuration = const Duration(minutes: 3);
@@ -49,6 +69,7 @@ class FaceClusteringService {
   bool isRunning = false;
 
   static const kRecommendedDistanceThreshold = 0.24;
+  static const kConservativeDistanceThreshold = 0.06;
 
   // singleton pattern
   FaceClusteringService._privateConstructor();
@@ -100,31 +121,11 @@ class FaceClusteringService {
       try {
         switch (function) {
           case ClusterOperation.linearIncrementalClustering:
-            final input = args['input'] as Map<String, (int?, Uint8List)>;
-            final fileIDToCreationTime =
-                args['fileIDToCreationTime'] as Map<int, int>?;
-            final distanceThreshold = args['distanceThreshold'] as double;
-            final offset = args['offset'] as int?;
-            final result = FaceClusteringService._runLinearClustering(
-              input,
-              fileIDToCreationTime: fileIDToCreationTime,
-              distanceThreshold: distanceThreshold,
-              offset: offset,
-            );
+            final result = FaceClusteringService.runLinearClustering(args);
             sendPort.send(result);
             break;
           case ClusterOperation.dbscanClustering:
-            final input = args['input'] as Map<String, Uint8List>;
-            final fileIDToCreationTime =
-                args['fileIDToCreationTime'] as Map<int, int>?;
-            final eps = args['eps'] as double;
-            final minPts = args['minPts'] as int;
-            final result = FaceClusteringService._runDbscanClustering(
-              input,
-              fileIDToCreationTime: fileIDToCreationTime,
-              eps: eps,
-              minPts: minPts,
-            );
+            final result = FaceClusteringService._runDbscanClustering(args);
             sendPort.send(result);
             break;
         }
@@ -194,16 +195,19 @@ class FaceClusteringService {
     _inactivityTimer?.cancel();
   }
 
-  /// Runs the clustering algorithm on the given [input], in an isolate.
+  /// Runs the clustering algorithm [runLinearClustering] on the given [input], in an isolate.
   ///
   /// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset.
   ///
   /// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can be less deterministic.
-  Future<Map<String, int>?> predictLinear(
-    Map<String, (int?, Uint8List)> input, {
+  Future<ClusteringResult?> predictLinear(
+    Set<FaceInfoForClustering> input, {
     Map<int, int>? fileIDToCreationTime,
     double distanceThreshold = kRecommendedDistanceThreshold,
+    double conservativeDistanceThreshold = kConservativeDistanceThreshold,
+    bool useDynamicThreshold = true,
     int? offset,
+    required Map<int, (Uint8List, int)> oldClusterSummaries,
   }) async {
     if (input.isEmpty) {
       _logger.warning(
@@ -225,20 +229,23 @@ class FaceClusteringService {
       final stopwatchClustering = Stopwatch()..start();
       // final Map<String, int> faceIdToCluster =
       //     await _runLinearClusteringInComputer(input);
-      final Map<String, int> faceIdToCluster = await _runInIsolate(
+      final ClusteringResult? faceIdToCluster = await _runInIsolate(
         (
           ClusterOperation.linearIncrementalClustering,
           {
             'input': input,
             'fileIDToCreationTime': fileIDToCreationTime,
             'distanceThreshold': distanceThreshold,
+            'conservativeDistanceThreshold': conservativeDistanceThreshold,
+            'useDynamicThreshold': useDynamicThreshold,
             'offset': offset,
+            'oldClusterSummaries': oldClusterSummaries,
           }
         ),
       );
       // return _runLinearClusteringInComputer(input);
       _logger.info(
-        'Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds',
+        'predictLinear Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds',
       );
 
       isRunning = false;
@@ -250,6 +257,142 @@ class FaceClusteringService {
     }
   }
 
+  /// Runs [runLinearClustering] on the given [input] through `_computer`,
+  /// without any dynamic thresholding.
+  ///
+  /// [input] maps each face ID to its serialized embedding bytes. Every entry
+  /// is wrapped in a [FaceInfoForClustering] with a face score just above
+  /// [kMinimumQualityFaceScore] and a blur value of [kLapacianDefault].
+  /// [distanceThreshold] is passed as both the regular and the conservative
+  /// threshold, so a single threshold applies to every face.
+  ///
+  /// Returns `null` when [input] is empty. Errors raised while clustering are
+  /// logged and rethrown.
+  Future<ClusteringResult?> predictLinearComputer(
+    Map<String, Uint8List> input, {
+    Map<int, int>? fileIDToCreationTime,
+    double distanceThreshold = kRecommendedDistanceThreshold,
+  }) async {
+    if (input.isEmpty) {
+      _logger.warning(
+        "Linear Clustering dataset of embeddings is empty, returning null.",
+      );
+      return null;
+    }
+
+    // Clustering inside the isolate
+    _logger.info(
+      "Start Linear clustering on ${input.length} embeddings inside computer isolate",
+    );
+
+    try {
+      // Only embeddings are available here, so fixed quality values are used.
+      final clusteringInput = <FaceInfoForClustering>{
+        for (final entry in input.entries)
+          FaceInfoForClustering(
+            faceID: entry.key,
+            embeddingBytes: entry.value,
+            faceScore: kMinimumQualityFaceScore + 0.01,
+            blurValue: kLapacianDefault,
+          ),
+      };
+      final startTime = DateTime.now();
+      // [runLinearClustering] is declared to return a nullable result, so the
+      // cast is nullable too instead of throwing on `null`.
+      final faceIdToCluster = await _computer.compute(
+        runLinearClustering,
+        param: {
+          "input": clusteringInput,
+          "fileIDToCreationTime": fileIDToCreationTime,
+          "distanceThreshold": distanceThreshold,
+          "conservativeDistanceThreshold": distanceThreshold,
+          "useDynamicThreshold": false,
+        },
+        taskName: "runLinearClustering",
+      ) as ClusteringResult?;
+      final endTime = DateTime.now();
+      _logger.info(
+        "Linear Clustering took: ${endTime.difference(startTime).inMilliseconds}ms",
+      );
+      return faceIdToCluster;
+    } catch (e, s) {
+      _logger.severe(e, s);
+      rethrow;
+    }
+  }
+
+  /// Runs the clustering algorithm [runCompleteClustering] on the given [input], in computer.
+  ///
+  /// [input] maps each face ID to its serialized embedding bytes.
+  /// [distanceThreshold] and [mergeThreshold] are forwarded unchanged to
+  /// [runCompleteClustering], together with [fileIDToCreationTime].
+  ///
+  /// Returns a map from face ID to cluster ID, or an empty map when [input] is
+  /// empty. Errors raised while clustering are logged and rethrown.
+  ///
+  /// WARNING: Only use on small datasets, as it is not optimized for large datasets.
+  Future<Map<String, int>> predictCompleteComputer(
+    Map<String, Uint8List> input, {
+    Map<int, int>? fileIDToCreationTime,
+    double distanceThreshold = kRecommendedDistanceThreshold,
+    double mergeThreshold = 0.30,
+  }) async {
+    if (input.isEmpty) {
+      _logger.warning(
+        "Complete Clustering dataset of embeddings is empty, returning empty list.",
+      );
+      return {};
+    }
+
+    // Clustering inside the isolate
+    _logger.info(
+      "Start Complete clustering on ${input.length} embeddings inside computer isolate",
+    );
+
+    try {
+      final startTime = DateTime.now();
+      final faceIdToCluster = await _computer.compute(
+        runCompleteClustering,
+        param: {
+          "input": input,
+          "fileIDToCreationTime": fileIDToCreationTime,
+          "distanceThreshold": distanceThreshold,
+          "mergeThreshold": mergeThreshold,
+        },
+        // Name the task after the function that actually runs.
+        taskName: "runCompleteClustering",
+      ) as Map<String, int>;
+      final endTime = DateTime.now();
+      _logger.info(
+        "Complete Clustering took: ${endTime.difference(startTime).inMilliseconds}ms",
+      );
+      return faceIdToCluster;
+    } catch (e, s) {
+      _logger.severe(e, s);
+      rethrow;
+    }
+  }
+
+  /// Clusters the faces in [input], choosing the algorithm by input size.
+  ///
+  /// Fewer than 100 faces are clustered with [predictCompleteComputer], using a
+  /// merge threshold of [distanceThreshold] + 0.06. Note that
+  /// [distanceThreshold] itself is not forwarded in that case, so
+  /// [predictCompleteComputer] applies its own default. Larger inputs go
+  /// through [predictLinearComputer] with [distanceThreshold].
+  ///
+  /// Returns a map from face ID to cluster ID, or `null` when the linear
+  /// clustering yields no result. Errors are logged and rethrown.
+  Future<Map<String, int>?> predictWithinClusterComputer(
+    Map<String, Uint8List> input, {
+    Map<int, int>? fileIDToCreationTime,
+    double distanceThreshold = kRecommendedDistanceThreshold,
+  }) async {
+    _logger.info(
+      '`predictWithinClusterComputer` called with ${input.length} faces and distance threshold $distanceThreshold',
+    );
+    try {
+      if (input.length < 100) {
+        final mergeThreshold = distanceThreshold + 0.06;
+        _logger.info(
+          'Running complete clustering on ${input.length} faces with distance threshold $mergeThreshold',
+        );
+        // Await inside the try block so that a failure of the returned future
+        // reaches the catch clause below, as it does in the linear branch.
+        return await predictCompleteComputer(
+          input,
+          fileIDToCreationTime: fileIDToCreationTime,
+          mergeThreshold: mergeThreshold,
+        );
+      } else {
+        _logger.info(
+          'Running linear clustering on ${input.length} faces with distance threshold $distanceThreshold',
+        );
+        final clusterResult = await predictLinearComputer(
+          input,
+          fileIDToCreationTime: fileIDToCreationTime,
+          distanceThreshold: distanceThreshold,
+        );
+        return clusterResult?.newFaceIdToCluster;
+      }
+    } catch (e, s) {
+      _logger.severe(e, s);
+      rethrow;
+    }
+  }
+
   Future<List<List<String>>> predictDbscan(
     Map<String, Uint8List> input, {
     Map<int, int>? fileIDToCreationTime,
@@ -299,29 +442,42 @@ class FaceClusteringService {
     return clusterFaceIDs;
   }
 
-  static Map<String, int> _runLinearClustering(
-    Map<String, (int?, Uint8List)> x, {
-    Map<int, int>? fileIDToCreationTime,
-    double distanceThreshold = kRecommendedDistanceThreshold,
-    int? offset,
-  }) {
+  static ClusteringResult? runLinearClustering(Map args) {
+    // final input = args['input'] as Map<String, (int?, Uint8List)>;
+    final input = args['input'] as Set<FaceInfoForClustering>;
+    final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
+    final distanceThreshold = args['distanceThreshold'] as double;
+    final conservativeDistanceThreshold =
+        args['conservativeDistanceThreshold'] as double;
+    final useDynamicThreshold = args['useDynamicThreshold'] as bool;
+    final offset = args['offset'] as int?;
+    final oldClusterSummaries =
+        args['oldClusterSummaries'] as Map<int, (Uint8List, int)>?;
+
     log(
-      "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces",
+      "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces",
     );
 
     // Organize everything into a list of FaceInfo objects
     final List<FaceInfo> faceInfos = [];
-    for (final entry in x.entries) {
+    for (final face in input) {
       faceInfos.add(
         FaceInfo(
-          faceID: entry.key,
+          faceID: face.faceID,
+          faceScore: face.faceScore,
+          blurValue: face.blurValue,
+          badFace: face.faceScore < kMinimumQualityFaceScore ||
+              face.blurValue < kLaplacianSoftThreshold ||
+              (face.blurValue < kLaplacianVerySoftThreshold &&
+                  face.faceScore < kMediumQualityFaceScore) ||
+              face.isSideways,
           vEmbedding: Vector.fromList(
-            EVector.fromBuffer(entry.value.$2).values,
+            EVector.fromBuffer(face.embeddingBytes).values,
             dtype: DType.float32,
           ),
-          clusterId: entry.value.$1,
+          clusterId: face.clusterId,
           fileCreationTime:
-              fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
+              fileIDToCreationTime?[getFileIdFromFaceId(face.faceID)],
         ),
       );
     }
@@ -351,19 +507,21 @@ class FaceClusteringService {
         facesWithClusterID.add(faceInfo);
       }
     }
+    final alreadyClusteredCount = facesWithClusterID.length;
     final sortedFaceInfos = <FaceInfo>[];
     sortedFaceInfos.addAll(facesWithClusterID);
     sortedFaceInfos.addAll(facesWithoutClusterID);
 
     log(
-      "[ClusterIsolate] ${DateTime.now()} Clustering ${facesWithoutClusterID.length} new faces without clusterId, and ${facesWithClusterID.length} faces with clusterId",
+      "[ClusterIsolate] ${DateTime.now()} Clustering ${facesWithoutClusterID.length} new faces without clusterId, and $alreadyClusteredCount faces with clusterId",
     );
 
     // Make sure the first face has a clusterId
     final int totalFaces = sortedFaceInfos.length;
+    int dynamicThresholdCount = 0;
 
     if (sortedFaceInfos.isEmpty) {
-      return {};
+      return null;
     }
 
     // Start actual clustering
@@ -377,7 +535,6 @@ class FaceClusteringService {
       sortedFaceInfos[0].clusterId = clusterID;
       clusterID++;
     }
-    final Map<String, int> newFaceIdToCluster = {};
     final stopwatchClustering = Stopwatch()..start();
     for (int i = 1; i < totalFaces; i++) {
       // Incremental clustering, so we can skip faces that already have a clusterId
@@ -388,6 +545,15 @@ class FaceClusteringService {
 
       int closestIdx = -1;
       double closestDistance = double.infinity;
+      late double thresholdValue;
+      if (useDynamicThreshold) {
+        thresholdValue = sortedFaceInfos[i].badFace!
+            ? conservativeDistanceThreshold
+            : distanceThreshold;
+        if (sortedFaceInfos[i].badFace!) dynamicThresholdCount++;
+      } else {
+        thresholdValue = distanceThreshold;
+      }
       if (i % 250 == 0) {
         log("[ClusterIsolate] ${DateTime.now()} Processed ${offset != null ? i + offset : i} faces");
       }
@@ -405,18 +571,16 @@ class FaceClusteringService {
           );
         }
         if (distance < closestDistance) {
+          if (sortedFaceInfos[j].badFace! &&
+              distance > conservativeDistanceThreshold) {
+            continue;
+          }
           closestDistance = distance;
           closestIdx = j;
-          // if (distance < distanceThreshold) {
-          //   if (sortedFaceInfos[j].faceID.startsWith("14914702") ||
-          //       sortedFaceInfos[j].faceID.startsWith("15488756")) {
-          //     log('[XXX] faceIDs: ${sortedFaceInfos[j].faceID} and ${sortedFaceInfos[i].faceID} with distance $distance');
-          //   }
-          // }
         }
       }
 
-      if (closestDistance < distanceThreshold) {
+      if (closestDistance < thresholdValue) {
         if (sortedFaceInfos[closestIdx].clusterId == null) {
           // Ideally this should never happen, but just in case log it
           log(
@@ -424,42 +588,99 @@ class FaceClusteringService {
           );
           clusterID++;
           sortedFaceInfos[closestIdx].clusterId = clusterID;
-          newFaceIdToCluster[sortedFaceInfos[closestIdx].faceID] = clusterID;
         }
-        // if (sortedFaceInfos[i].faceID.startsWith("14914702") ||
-        //     sortedFaceInfos[i].faceID.startsWith("15488756")) {
-        //   log(
-        //     "[XXX]  [ClusterIsolate] ${DateTime.now()} Found similar face ${sortedFaceInfos[i].faceID} to ${sortedFaceInfos[closestIdx].faceID} with distance $closestDistance",
-        //   );
-        // }
         sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId;
-        newFaceIdToCluster[sortedFaceInfos[i].faceID] =
-            sortedFaceInfos[closestIdx].clusterId!;
       } else {
-        // if (sortedFaceInfos[i].faceID.startsWith("14914702") ||
-        //     sortedFaceInfos[i].faceID.startsWith("15488756")) {
-        //   log(
-        //     "[XXX]  [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID for face ${sortedFaceInfos[i].faceID}",
-        //   );
-        // }
         clusterID++;
         sortedFaceInfos[i].clusterId = clusterID;
-        newFaceIdToCluster[sortedFaceInfos[i].faceID] = clusterID;
       }
     }
 
+    // Finally, assign the new clusterId to the faces
+    final Map<String, int> newFaceIdToCluster = {};
+    final newClusteredFaceInfos =
+        sortedFaceInfos.sublist(alreadyClusteredCount);
+    for (final faceInfo in newClusteredFaceInfos) {
+      newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!;
+    }
+
     stopwatchClustering.stop();
     log(
       ' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms',
     );
+    if (useDynamicThreshold) {
+      log(
+        "[ClusterIsolate] ${DateTime.now()} Dynamic thresholding: $dynamicThresholdCount faces had a low face score or low blur clarity",
+      );
+    }
+
+    // Now calculate the mean of the embeddings for each cluster and update the cluster summaries
+    Map<int, (Uint8List, int)>? newClusterSummaries;
+    if (oldClusterSummaries != null) {
+      newClusterSummaries = FaceClusteringService.updateClusterSummaries(
+        oldSummary: oldClusterSummaries,
+        newFaceInfos: newClusteredFaceInfos,
+      );
+    }
 
     // analyze the results
     FaceClusteringService._analyzeClusterResults(sortedFaceInfos);
 
-    return newFaceIdToCluster;
+    return ClusteringResult(
+      newFaceIdToCluster: newFaceIdToCluster,
+      newClusterSummaries: newClusterSummaries,
+    );
+  }
+
+  /// Builds the summary (mean embedding and face count) of every cluster that
+  /// received at least one face from [newFaceInfos].
+  ///
+  /// When [oldSummary] already has an entry for a cluster, its mean is scaled
+  /// back up by its count and combined with the new embeddings; otherwise the
+  /// summary is computed from the new faces alone. Clusters that received no
+  /// new face are not part of the returned map.
+  static Map<int, (Uint8List, int)> updateClusterSummaries({
+    required Map<int, (Uint8List, int)> oldSummary,
+    required List<FaceInfo> newFaceInfos,
+  }) {
+    final calcSummariesStart = DateTime.now();
+
+    // Bucket the new faces by cluster, in first-seen order.
+    final facesByCluster = <int, List<FaceInfo>>{};
+    for (final face in newFaceInfos) {
+      facesByCluster.putIfAbsent(face.clusterId!, () => []).add(face);
+    }
+
+    final summaries = <int, (Uint8List, int)>{};
+    for (final bucket in facesByCluster.entries) {
+      final clusterId = bucket.key;
+      final addedCount = bucket.value.length;
+      // Sum of the new embeddings, accumulated left to right.
+      Vector embeddingSum =
+          bucket.value.map((face) => face.vEmbedding!).reduce((a, b) => a + b);
+      int totalCount = addedCount;
+
+      final previous = oldSummary[clusterId];
+      if (previous != null) {
+        final previousMean = Vector.fromList(
+          EVector.fromBuffer(previous.$1).values,
+          dtype: DType.float32,
+        );
+        // mean * count restores the sum of the previously seen embeddings.
+        embeddingSum = embeddingSum + previousMean * previous.$2;
+        totalCount = previous.$2 + addedCount;
+      }
+
+      final meanVector = embeddingSum / totalCount;
+      summaries[clusterId] = (
+        EVector(values: meanVector.toList()).writeToBuffer(),
+        totalCount,
+      );
+    }
+    log(
+      "[ClusterIsolate] ${DateTime.now()} Calculated cluster summaries in ${DateTime.now().difference(calcSummariesStart).inMilliseconds}ms",
+    );
+
+    return summaries;
+  }
 
   static void _analyzeClusterResults(List<FaceInfo> sortedFaceInfos) {
+    if (!kDebugMode) return;
     final stopwatch = Stopwatch()..start();
 
     final Map<String, int> faceIdToCluster = {};
@@ -517,14 +738,185 @@ class FaceClusteringService {
     );
   }
 
-  static List<List<String>> _runDbscanClustering(
-    Map<String, Uint8List> x, {
-    Map<int, int>? fileIDToCreationTime,
-    double eps = 0.3,
-    int minPts = 5,
-  }) {
+  /// Clusters every input face against all other faces, then merges clusters
+  /// whose mean embeddings are close together.
+  ///
+  /// Takes a single [args] map with these keys:
+  /// - `input`: `Map<String, Uint8List>` from face ID to serialized embedding.
+  /// - `fileIDToCreationTime`: optional `Map<int, int>`; when present, faces
+  ///   are processed in ascending creation time, faces without one last.
+  /// - `distanceThreshold`: a face joins the cluster of its closest face when
+  ///   their distance (1 - dot product) is below this value.
+  /// - `mergeThreshold`: two clusters are merged while the distance between
+  ///   the closest pair of cluster means is below this value.
+  ///
+  /// Returns a map from face ID to cluster ID covering every input face.
+  /// Cluster IDs are counted up from the current time in microseconds.
+  static Map<String, int> runCompleteClustering(Map args) {
+    final input = args['input'] as Map<String, Uint8List>;
+    final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
+    final distanceThreshold = args['distanceThreshold'] as double;
+    final mergeThreshold = args['mergeThreshold'] as double;
+
+    log(
+      "[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering",
+    );
+
+    // Organize everything into a list of FaceInfo objects
+    final List<FaceInfo> faceInfos = [
+      for (final entry in input.entries)
+        FaceInfo(
+          faceID: entry.key,
+          vEmbedding: Vector.fromList(
+            EVector.fromBuffer(entry.value).values,
+            dtype: DType.float32,
+          ),
+          fileCreationTime:
+              fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
+        ),
+    ];
+
+    // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
+    if (fileIDToCreationTime != null) {
+      faceInfos.sort((a, b) {
+        if (a.fileCreationTime == null && b.fileCreationTime == null) {
+          return 0;
+        } else if (a.fileCreationTime == null) {
+          return 1;
+        } else if (b.fileCreationTime == null) {
+          return -1;
+        } else {
+          return a.fileCreationTime!.compareTo(b.fileCreationTime!);
+        }
+      });
+    }
+
+    if (faceInfos.isEmpty) {
+      return {};
+    }
+    final int totalFaces = faceInfos.length;
+
+    // Start actual clustering
+    log(
+      "[CompleteClustering] ${DateTime.now()} Processing $totalFaces faces in one single round of complete clustering",
+    );
+
+    // set current epoch time as clusterID
+    int clusterID = DateTime.now().microsecondsSinceEpoch;
+
+    // First pass: attach every face to the cluster of its nearest neighbour,
+    // or open a new cluster when the nearest neighbour is too far away.
+    final stopwatchClustering = Stopwatch()..start();
+    for (int i = 0; i < totalFaces; i++) {
+      if ((i + 1) % 250 == 0) {
+        log("[CompleteClustering] ${DateTime.now()} Processed ${i + 1} faces");
+      }
+      if (faceInfos[i].clusterId != null) continue;
+      int closestIdx = -1;
+      double closestDistance = double.infinity;
+      for (int j = 0; j < totalFaces; j++) {
+        if (i == j) continue;
+        final double distance =
+            1.0 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!);
+        if (distance < closestDistance) {
+          closestDistance = distance;
+          closestIdx = j;
+        }
+      }
+
+      if (closestDistance < distanceThreshold) {
+        if (faceInfos[closestIdx].clusterId == null) {
+          clusterID++;
+          faceInfos[closestIdx].clusterId = clusterID;
+        }
+        faceInfos[i].clusterId = faceInfos[closestIdx].clusterId!;
+      } else {
+        clusterID++;
+        faceInfos[i].clusterId = clusterID;
+      }
+    }
+
+    // Now calculate the mean of the embeddings for each cluster
+    final Map<int, List<FaceInfo>> clusterIdToFaceInfos = {};
+    for (final faceInfo in faceInfos) {
+      clusterIdToFaceInfos
+          .putIfAbsent(faceInfo.clusterId!, () => [])
+          .add(faceInfo);
+    }
+    final Map<int, (Vector, int)> clusterIdToMeanEmbeddingAndWeight = {};
+    for (final cluster in clusterIdToFaceInfos.entries) {
+      final count = cluster.value.length;
+      final Vector meanEmbedding = cluster.value
+              .map((faceInfo) => faceInfo.vEmbedding!)
+              .reduce((a, b) => a + b) /
+          count;
+      clusterIdToMeanEmbeddingAndWeight[cluster.key] = (meanEmbedding, count);
+    }
+
+    // Now merge the clusters that are close to each other, based on mean embedding
+    final List<(int, int)> mergedClustersList = [];
+    final List<int> clusterIds =
+        clusterIdToMeanEmbeddingAndWeight.keys.toList();
+    log(' [CompleteClustering] ${DateTime.now()} ${clusterIds.length} clusters found, now checking for merges');
+    while (true) {
+      if (clusterIds.length < 2) break;
+      double distance = double.infinity;
+      (int, int) clusterIDsToMerge = (-1, -1);
+      for (int i = 0; i < clusterIds.length; i++) {
+        final meanI = clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1;
+        // The distance is symmetric, so every unordered pair is checked once.
+        // Pairs are still visited in the same (i, j) order with i < j, which
+        // keeps the chosen pair the same as when scanning all ordered pairs.
+        for (int j = i + 1; j < clusterIds.length; j++) {
+          final double newDistance = 1.0 -
+              meanI.dot(
+                clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
+              );
+          if (newDistance < distance) {
+            distance = newDistance;
+            clusterIDsToMerge = (clusterIds[i], clusterIds[j]);
+          }
+        }
+      }
+      if (distance < mergeThreshold) {
+        mergedClustersList.add(clusterIDsToMerge);
+        final clusterID1 = clusterIDsToMerge.$1;
+        final clusterID2 = clusterIDsToMerge.$2;
+        final mean1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$1;
+        final mean2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$1;
+        final count1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$2;
+        final count2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$2;
+        // Weighted average of both means; the second cluster is absorbed.
+        final weight1 = count1 / (count1 + count2);
+        final weight2 = count2 / (count1 + count2);
+        clusterIdToMeanEmbeddingAndWeight[clusterID1] = (
+          mean1 * weight1 + mean2 * weight2,
+          count1 + count2,
+        );
+        clusterIdToMeanEmbeddingAndWeight.remove(clusterID2);
+        clusterIds.remove(clusterID2);
+      } else {
+        break;
+      }
+    }
+    log(' [CompleteClustering] ${DateTime.now()} ${mergedClustersList.length} clusters merged');
+
+    // Replay the merges in the order they happened, so chained merges resolve
+    // to the surviving cluster, then record the final clusterId of each face.
+    final Map<String, int> newFaceIdToCluster = {};
+    for (final faceInfo in faceInfos) {
+      for (final mergedClusters in mergedClustersList) {
+        if (faceInfo.clusterId == mergedClusters.$2) {
+          faceInfo.clusterId = mergedClusters.$1;
+        }
+      }
+      newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!;
+    }
+
+    stopwatchClustering.stop();
+    log(
+      ' [CompleteClustering] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms',
+    );
+
+    return newFaceIdToCluster;
+  }
+
+  static List<List<String>> _runDbscanClustering(Map args) {
+    final input = args['input'] as Map<String, Uint8List>;
+    final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
+    final eps = args['eps'] as double;
+    final minPts = args['minPts'] as int;
+
     log(
-      "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces",
+      "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces",
     );
 
     final DBSCAN dbscan = DBSCAN(
@@ -535,7 +927,7 @@ class FaceClusteringService {
 
     // Organize everything into a list of FaceInfo objects
     final List<FaceInfo> faceInfos = [];
-    for (final entry in x.entries) {
+    for (final entry in input.entries) {
       faceInfos.add(
         FaceInfo(
           faceID: entry.key,

+ 19 - 0
mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart

@@ -0,0 +1,19 @@
+import "dart:typed_data" show Uint8List;
+
+/// Input record describing a single face that is handed to the clustering code.
+class FaceInfoForClustering {
+  /// Identifier of the face.
+  final String faceID;
+
+  /// Cluster the face already belongs to, or `null` when it has none yet.
+  final int? clusterId;
+
+  /// Serialized embedding of the face, as raw bytes.
+  final Uint8List embeddingBytes;
+
+  /// Score of the face, compared against the quality score constants.
+  final double faceScore;
+
+  /// Blur value of the face, compared against the Laplacian thresholds.
+  final double blurValue;
+
+  /// Whether the face is sideways; defaults to `false`.
+  final bool isSideways;
+
+  FaceInfoForClustering({
+    required this.faceID,
+    this.clusterId,
+    required this.embeddingBytes,
+    required this.faceScore,
+    required this.blurValue,
+    this.isSideways = false,
+  });
+}

+ 50 - 1
mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart

@@ -1,7 +1,24 @@
-import 'dart:math' show sqrt, pow;
+import 'dart:math' show max, min, pow, sqrt;
 
 import "package:photos/face/model/dimension.dart";
 
+enum FaceDirection { left, right, straight }
+
+extension FaceDirectionExtension on FaceDirection {
+  /// Returns the capitalized English label of this direction: `'Left'`,
+  /// `'Right'` or `'Straight'`.
+  ///
+  /// The switch is exhaustive without a `default` case, so adding a value to
+  /// [FaceDirection] fails at compile time here instead of throwing at runtime.
+  String toDirectionString() => switch (this) {
+        FaceDirection.left => 'Left',
+        FaceDirection.right => 'Right',
+        FaceDirection.straight => 'Straight',
+      };
+}
+
 abstract class Detection {
   final double score;
 
@@ -16,6 +33,7 @@ abstract class Detection {
   String toString();
 }
 
+@Deprecated('Old method only used in other deprecated methods')
 extension BBoxExtension on List<double> {
   void roundBoxToDouble() {
     final widthRounded = (this[2] - this[0]).roundToDouble();
@@ -425,6 +443,37 @@ class FaceDetectionAbsolute extends Detection {
 
   /// The height of the bounding box of the face detection, in number of pixels, range [0, imageHeight].
   double get height => yMaxBox - yMinBox;
+
+  /// Classifies the face as facing left, right or straight from the positions
+  /// of the eye, nose and mouth landmarks.
+  ///
+  /// Index 0 of each landmark is used as the horizontal coordinate and index 1
+  /// as the vertical one. A nose lying beyond both eyes and both mouth corners
+  /// on one side decides the direction on its own; a nose that is merely close
+  /// to one eye (within 20% of the horizontal eye distance) only counts when
+  /// the face is upright, i.e. eyes, nose and mouth follow each other
+  /// vertically with some margin.
+  FaceDirection getFaceDirection() {
+    final noseX = nose[0];
+    final noseY = nose[1];
+
+    final double eyeGapX = (rightEye[0] - leftEye[0]).abs();
+    final double eyeGapY = (rightEye[1] - leftEye[1]).abs();
+    final double mouthGapY = (rightMouth[1] - leftMouth[1]).abs();
+
+    final bool eyesBeforeNose =
+        max(leftEye[1], rightEye[1]) + 0.5 * eyeGapY < noseY;
+    final bool noseBeforeMouth =
+        noseY + 0.5 * mouthGapY < min(leftMouth[1], rightMouth[1]);
+    final bool upright = eyesBeforeNose && noseBeforeMouth;
+
+    final bool beyondLeft = noseX < min(leftEye[0], rightEye[0]) &&
+        noseX < min(leftMouth[0], rightMouth[0]);
+    final bool beyondRight = noseX > max(leftEye[0], rightEye[0]) &&
+        noseX > max(leftMouth[0], rightMouth[0]);
+
+    final double tolerance = 0.2 * eyeGapX;
+    final bool nearLeftEye = (noseX - leftEye[0]).abs() < tolerance;
+    final bool nearRightEye = (noseX - rightEye[0]).abs() < tolerance;
+
+    // Left is tested first, so it wins if both sides would match.
+    if (beyondLeft || (upright && nearLeftEye)) return FaceDirection.left;
+    if (beyondRight || (upright && nearRightEye)) return FaceDirection.right;
+    return FaceDirection.straight;
+  }
 }
 
 List<FaceDetectionAbsolute> relativeToAbsoluteDetections({

+ 55 - 15
mobile/lib/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart

@@ -1,4 +1,5 @@
 import 'package:logging/logging.dart';
+import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart";
 import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
 
 class BlurDetectionService {
@@ -11,9 +12,11 @@ class BlurDetectionService {
 
   Future<(bool, double)> predictIsBlurGrayLaplacian(
     List<List<int>> grayImage, {
-    int threshold = kLaplacianThreshold,
+    int threshold = kLaplacianHardThreshold,
+    FaceDirection faceDirection = FaceDirection.straight,
   }) async {
-    final List<List<int>> laplacian = _applyLaplacian(grayImage);
+    final List<List<int>> laplacian =
+        _applyLaplacian(grayImage, faceDirection: faceDirection);
     final double variance = _calculateVariance(laplacian);
     _logger.info('Variance: $variance');
     return (variance < threshold, variance);
@@ -46,43 +49,80 @@ class BlurDetectionService {
     return variance;
   }
 
-  List<List<int>> _padImage(List<List<int>> image) {
+  List<List<int>> _padImage(
+    List<List<int>> image, {
+    int removeSideColumns = 56,
+    FaceDirection faceDirection = FaceDirection.straight,
+  }) {
+    // Throw an exception if removeSideColumns is not even
+    if (removeSideColumns % 2 != 0) {
+      throw Exception('removeSideColumns must be even');
+    }
+
     final int numRows = image.length;
     final int numCols = image[0].length;
+    final int paddedNumCols = numCols + 2 - removeSideColumns;
+    final int paddedNumRows = numRows + 2;
 
     // Create a new matrix with extra padding
     final List<List<int>> paddedImage = List.generate(
-      numRows + 2,
-      (i) => List.generate(numCols + 2, (j) => 0, growable: false),
+      paddedNumRows,
+      (i) => List.generate(
+        paddedNumCols,
+        (j) => 0,
+        growable: false,
+      ),
       growable: false,
     );
 
-    // Copy original image into the center of the padded image
-    for (int i = 0; i < numRows; i++) {
-      for (int j = 0; j < numCols; j++) {
-        paddedImage[i + 1][j + 1] = image[i][j];
+    // Copy original image into the center of the padded image, taking into account the face direction
+    if (faceDirection == FaceDirection.straight) {
+      for (int i = 0; i < numRows; i++) {
+        for (int j = 0; j < (paddedNumCols - 2); j++) {
+          paddedImage[i + 1][j + 1] =
+              image[i][j + (removeSideColumns / 2).round()];
+        }
+      }
+      // If the face is facing left, we only take the right side of the face image
+    } else if (faceDirection == FaceDirection.left) {
+      for (int i = 0; i < numRows; i++) {
+        for (int j = 0; j < (paddedNumCols - 2); j++) {
+          paddedImage[i + 1][j + 1] = image[i][j + removeSideColumns];
+        }
+      }
+      // If the face is facing right, we only take the left side of the face image
+    } else if (faceDirection == FaceDirection.right) {
+      for (int i = 0; i < numRows; i++) {
+        for (int j = 0; j < (paddedNumCols - 2); j++) {
+          paddedImage[i + 1][j + 1] = image[i][j];
+        }
       }
     }
 
     // Reflect padding
     // Top and bottom rows
-    for (int j = 1; j <= numCols; j++) {
+    for (int j = 1; j <= (paddedNumCols - 2); j++) {
       paddedImage[0][j] = paddedImage[2][j]; // Top row
       paddedImage[numRows + 1][j] = paddedImage[numRows - 1][j]; // Bottom row
     }
     // Left and right columns
     for (int i = 0; i < numRows + 2; i++) {
       paddedImage[i][0] = paddedImage[i][2]; // Left column
-      paddedImage[i][numCols + 1] = paddedImage[i][numCols - 1]; // Right column
+      paddedImage[i][paddedNumCols - 1] =
+          paddedImage[i][paddedNumCols - 3]; // Right column
     }
 
     return paddedImage;
   }
 
-  List<List<int>> _applyLaplacian(List<List<int>> image) {
-    final List<List<int>> paddedImage = _padImage(image);
-    final int numRows = image.length;
-    final int numCols = image[0].length;
+  List<List<int>> _applyLaplacian(
+    List<List<int>> image, {
+    FaceDirection faceDirection = FaceDirection.straight,
+  }) {
+    final List<List<int>> paddedImage =
+        _padImage(image, faceDirection: faceDirection);
+    final int numRows = paddedImage.length - 2;
+    final int numCols = paddedImage[0].length - 2;
     final List<List<int>> outputImage = List.generate(
       numRows,
       (i) => List.generate(numCols, (j) => 0, growable: false),

+ 7 - 3
mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart

@@ -1,13 +1,17 @@
 import 'package:photos/services/machine_learning/face_ml/face_detection/face_detection_service.dart';
 
 /// Blur detection threshold
-const kLaplacianThreshold = 15;
+const kLaplacianHardThreshold = 15;
+const kLaplacianSoftThreshold = 100;
+const kLaplacianVerySoftThreshold = 200;
 
 /// Default blur value
 const kLapacianDefault = 10000.0;
 
 /// The minimum score for a face to be considered a high quality face for clustering and person detection
-const kMinHighQualityFaceScore = 0.80;
+const kMinimumQualityFaceScore = 0.80;
+const kMediumQualityFaceScore = 0.85;
+const kHighQualityFaceScore = 0.90;
 
-/// The minimum score for a face to be detected, regardless of quality. Use [kMinHighQualityFaceScore] for high quality faces.
+/// The minimum score for a face to be detected, regardless of quality. Use [kMinimumQualityFaceScore] for high quality faces.
 const kMinFaceDetectionScore = FaceDetectionService.kMinScoreSigmoidThreshold;

+ 3 - 269
mobile/lib/services/machine_learning/face_ml/face_ml_result.dart

@@ -1,284 +1,18 @@
 import "dart:convert" show jsonEncode, jsonDecode;
 
-import "package:flutter/material.dart" show debugPrint, immutable;
+import "package:flutter/material.dart" show immutable;
 import "package:logging/logging.dart";
 import "package:photos/face/model/dimension.dart";
 import "package:photos/models/file/file.dart";
 import 'package:photos/models/ml/ml_typedefs.dart';
 import "package:photos/models/ml/ml_versions.dart";
 import 'package:photos/services/machine_learning/face_ml/face_alignment/alignment_result.dart';
-import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
 import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart';
 import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
 import 'package:photos/services/machine_learning/face_ml/face_ml_methods.dart';
 
 final _logger = Logger('ClusterResult_FaceMlResult');
 
-// TODO: should I add [faceMlVersion] and [clusterMlVersion] to the [ClusterResult] class?
-@Deprecated('We are now just storing the cluster results directly in DB')
-class ClusterResult {
-  final int personId;
-  String? userDefinedName;
-  bool get hasUserDefinedName => userDefinedName != null;
-
-  String _thumbnailFaceId;
-  bool thumbnailFaceIdIsUserDefined;
-
-  final List<int> _fileIds;
-  final List<String> _faceIds;
-
-  final Embedding medoid;
-  double medoidDistanceThreshold;
-
-  List<int> get uniqueFileIds => _fileIds.toSet().toList();
-  List<int> get fileIDsIncludingPotentialDuplicates => _fileIds;
-
-  List<String> get faceIDs => _faceIds;
-
-  String get thumbnailFaceId => _thumbnailFaceId;
-
-  int get thumbnailFileId => getFileIdFromFaceId(_thumbnailFaceId);
-
-  /// Sets the thumbnail faceId to the given faceId.
-  /// Throws an exception if the faceId is not in the list of faceIds.
-  set setThumbnailFaceId(String faceId) {
-    if (!_faceIds.contains(faceId)) {
-      throw Exception(
-        "The faceId $faceId is not in the list of faceIds: $faceId",
-      );
-    }
-    _thumbnailFaceId = faceId;
-    thumbnailFaceIdIsUserDefined = true;
-  }
-
-  /// Sets the [userDefinedName] to the given [customName]
-  set setUserDefinedName(String customName) {
-    userDefinedName = customName;
-  }
-
-  int get clusterSize => _fileIds.toSet().length;
-
-  ClusterResult({
-    required this.personId,
-    required String thumbnailFaceId,
-    required List<int> fileIds,
-    required List<String> faceIds,
-    required this.medoid,
-    required this.medoidDistanceThreshold,
-    this.userDefinedName,
-    this.thumbnailFaceIdIsUserDefined = false,
-  })  : _thumbnailFaceId = thumbnailFaceId,
-        _faceIds = faceIds,
-        _fileIds = fileIds;
-
-  void addFileIDsAndFaceIDs(List<int> fileIDs, List<String> faceIDs) {
-    assert(fileIDs.length == faceIDs.length);
-    _fileIds.addAll(fileIDs);
-    _faceIds.addAll(faceIDs);
-  }
-
-  // TODO: Consider if we should recalculated the medoid and threshold when deleting or adding a file from the cluster
-  int removeFileId(int fileId) {
-    assert(_fileIds.length == _faceIds.length);
-    if (!_fileIds.contains(fileId)) {
-      throw Exception(
-        "The fileId $fileId is not in the list of fileIds: $fileId, so it's not in the cluster and cannot be removed.",
-      );
-    }
-
-    int removedCount = 0;
-    for (var i = 0; i < _fileIds.length; i++) {
-      if (_fileIds[i] == fileId) {
-        assert(getFileIdFromFaceId(_faceIds[i]) == fileId);
-        _fileIds.removeAt(i);
-        _faceIds.removeAt(i);
-        debugPrint(
-          "Removed fileId $fileId from cluster $personId at index ${i + removedCount}}",
-        );
-        i--; // Adjust index due to removal
-        removedCount++;
-      }
-    }
-
-    _ensureClusterSizeIsAboveMinimum();
-
-    return removedCount;
-  }
-
-  int addFileID(int fileID) {
-    assert(_fileIds.length == _faceIds.length);
-    if (_fileIds.contains(fileID)) {
-      return 0;
-    }
-
-    _fileIds.add(fileID);
-    _faceIds.add(FaceDetectionRelative.toFaceIDEmpty(fileID: fileID));
-
-    return 1;
-  }
-
-  void ensureThumbnailFaceIdIsInCluster() {
-    if (!_faceIds.contains(_thumbnailFaceId)) {
-      _thumbnailFaceId = _faceIds[0];
-    }
-  }
-
-  void _ensureClusterSizeIsAboveMinimum() {
-    if (clusterSize < minimumClusterSize) {
-      throw Exception(
-        "Cluster size is below minimum cluster size of $minimumClusterSize",
-      );
-    }
-  }
-
-  Map<String, dynamic> _toJson() => {
-        'personId': personId,
-        'thumbnailFaceId': _thumbnailFaceId,
-        'fileIds': _fileIds,
-        'faceIds': _faceIds,
-        'medoid': medoid,
-        'medoidDistanceThreshold': medoidDistanceThreshold,
-        if (userDefinedName != null) 'userDefinedName': userDefinedName,
-        'thumbnailFaceIdIsUserDefined': thumbnailFaceIdIsUserDefined,
-      };
-
-  String toJsonString() => jsonEncode(_toJson());
-
-  static ClusterResult _fromJson(Map<String, dynamic> json) {
-    return ClusterResult(
-      personId: json['personId'] ?? -1,
-      thumbnailFaceId: json['thumbnailFaceId'] ?? '',
-      fileIds:
-          (json['fileIds'] as List?)?.map((item) => item as int).toList() ?? [],
-      faceIds:
-          (json['faceIds'] as List?)?.map((item) => item as String).toList() ??
-              [],
-      medoid:
-          (json['medoid'] as List?)?.map((item) => item as double).toList() ??
-              [],
-      medoidDistanceThreshold: json['medoidDistanceThreshold'] ?? 0,
-      userDefinedName: json['userDefinedName'],
-      thumbnailFaceIdIsUserDefined:
-          json['thumbnailFaceIdIsUserDefined'] as bool,
-    );
-  }
-
-  static ClusterResult fromJsonString(String jsonString) {
-    return _fromJson(jsonDecode(jsonString));
-  }
-}
-
-class ClusterResultBuilder {
-  int personId = -1;
-  String? userDefinedName;
-  String thumbnailFaceId = '';
-  bool thumbnailFaceIdIsUserDefined = false;
-
-  List<int> fileIds = <int>[];
-  List<String> faceIds = <String>[];
-
-  List<Embedding> embeddings = <Embedding>[];
-  Embedding medoid = <double>[];
-  double medoidDistanceThreshold = 0;
-  bool medoidAndThresholdCalculated = false;
-  final int k = 5;
-
-  ClusterResultBuilder.createFromIndices({
-    required List<int> clusterIndices,
-    required List<int> labels,
-    required List<Embedding> allEmbeddings,
-    required List<int> allFileIds,
-    required List<String> allFaceIds,
-  }) {
-    final clusteredFileIds =
-        clusterIndices.map((fileIndex) => allFileIds[fileIndex]).toList();
-    final clusteredFaceIds =
-        clusterIndices.map((fileIndex) => allFaceIds[fileIndex]).toList();
-    final clusteredEmbeddings =
-        clusterIndices.map((fileIndex) => allEmbeddings[fileIndex]).toList();
-    personId = labels[clusterIndices[0]];
-    fileIds = clusteredFileIds;
-    faceIds = clusteredFaceIds;
-    thumbnailFaceId = faceIds[0];
-    embeddings = clusteredEmbeddings;
-  }
-
-  void calculateAndSetMedoidAndThreshold() {
-    if (embeddings.isEmpty) {
-      throw Exception("Cannot calculate medoid and threshold for empty list");
-    }
-
-    // Calculate the medoid and threshold
-    final (tempMedoid, distanceThreshold) =
-        _calculateMedoidAndDistanceTreshold(embeddings);
-
-    // Update the medoid
-    medoid = List.from(tempMedoid);
-
-    // Update the medoidDistanceThreshold as the distance of the medoid to its k-th nearest neighbor
-    medoidDistanceThreshold = distanceThreshold;
-
-    medoidAndThresholdCalculated = true;
-  }
-
-  (List<double>, double) _calculateMedoidAndDistanceTreshold(
-    List<List<double>> embeddings,
-  ) {
-    double minDistance = double.infinity;
-    List<double>? medoid;
-
-    // Calculate the distance between all pairs
-    for (int i = 0; i < embeddings.length; ++i) {
-      double totalDistance = 0;
-      for (int j = 0; j < embeddings.length; ++j) {
-        if (i != j) {
-          totalDistance += cosineDistance(embeddings[i], embeddings[j]);
-
-          // Break early if we already exceed minDistance
-          if (totalDistance > minDistance) {
-            break;
-          }
-        }
-      }
-
-      // Find the minimum total distance
-      if (totalDistance < minDistance) {
-        minDistance = totalDistance;
-        medoid = embeddings[i];
-      }
-    }
-
-    // Now, calculate k-th nearest neighbor for the medoid
-    final List<double> distancesToMedoid = [];
-    for (List<double> embedding in embeddings) {
-      if (embedding != medoid) {
-        distancesToMedoid.add(cosineDistance(medoid!, embedding));
-      }
-    }
-    distancesToMedoid.sort();
-    // TODO: empirically find the best k. Probably it should be dynamic in some way, so for instance larger for larger clusters and smaller for smaller clusters, especially since there are a lot of really small clusters and a few really large ones.
-    final double kthDistance = distancesToMedoid[
-        distancesToMedoid.length >= k ? k - 1 : distancesToMedoid.length - 1];
-
-    return (medoid!, kthDistance);
-  }
-
-  void changeThumbnailFaceId(String faceId) {
-    if (!faceIds.contains(faceId)) {
-      throw Exception(
-        "The faceId $faceId is not in the list of faceIds: $faceIds",
-      );
-    }
-    thumbnailFaceId = faceId;
-  }
-
-  void addFileIDsAndFaceIDs(List<int> addedFileIDs, List<String> addedFaceIDs) {
-    assert(addedFileIDs.length == addedFaceIDs.length);
-    fileIds.addAll(addedFileIDs);
-    faceIds.addAll(addedFaceIDs);
-  }
-}
-
 @immutable
 class FaceMlResult {
   final int fileId;
@@ -504,7 +238,7 @@ class FaceResult {
   final int fileId;
   final String faceId;
 
-  bool get isBlurry => blurValue < kLaplacianThreshold;
+  bool get isBlurry => blurValue < kLaplacianHardThreshold;
 
   const FaceResult({
     required this.detection,
@@ -545,7 +279,7 @@ class FaceResultBuilder {
   int fileId = -1;
   String faceId = '';
 
-  bool get isBlurry => blurValue < kLaplacianThreshold;
+  bool get isBlurry => blurValue < kLaplacianHardThreshold;
 
   FaceResultBuilder({
     required this.fileId,

+ 117 - 93
mobile/lib/services/machine_learning/face_ml/face_ml_service.dart

@@ -204,83 +204,13 @@ class FaceMlService {
       try {
         switch (function) {
           case FaceMlOperation.analyzeImage:
-            final int enteFileID = args["enteFileID"] as int;
-            final String imagePath = args["filePath"] as String;
-            final int faceDetectionAddress =
-                args["faceDetectionAddress"] as int;
-            final int faceEmbeddingAddress =
-                args["faceEmbeddingAddress"] as int;
-
-            final resultBuilder =
-                FaceMlResultBuilder.fromEnteFileID(enteFileID);
-
-            dev.log(
-              "Start analyzing image with uploadedFileID: $enteFileID inside the isolate",
-            );
-            final stopwatchTotal = Stopwatch()..start();
-            final stopwatch = Stopwatch()..start();
-
-            // Decode the image once to use for both face detection and alignment
-            final imageData = await File(imagePath).readAsBytes();
-            final image = await decodeImageFromData(imageData);
-            final ByteData imgByteData = await getByteDataFromImage(image);
-            dev.log('Reading and decoding image took '
-                '${stopwatch.elapsedMilliseconds} ms');
-            stopwatch.reset();
-
-            // Get the faces
-            final List<FaceDetectionRelative> faceDetectionResult =
-                await FaceMlService.detectFacesSync(
-              image,
-              imgByteData,
-              faceDetectionAddress,
-              resultBuilder: resultBuilder,
-            );
-
+            final time = DateTime.now();
+            final FaceMlResult result =
+                await FaceMlService.analyzeImageSync(args);
             dev.log(
-                "${faceDetectionResult.length} faces detected with scores ${faceDetectionResult.map((e) => e.score).toList()}: completed `detectFacesSync` function, in "
-                "${stopwatch.elapsedMilliseconds} ms");
-
-            // If no faces were detected, return a result with no faces. Otherwise, continue.
-            if (faceDetectionResult.isEmpty) {
-              dev.log(
-                  "No faceDetectionResult, Completed analyzing image with uploadedFileID $enteFileID, in "
-                  "${stopwatch.elapsedMilliseconds} ms");
-              sendPort.send(resultBuilder.buildNoFaceDetected().toJsonString());
-              break;
-            }
-
-            stopwatch.reset();
-            // Align the faces
-            final Float32List faceAlignmentResult =
-                await FaceMlService.alignFacesSync(
-              image,
-              imgByteData,
-              faceDetectionResult,
-              resultBuilder: resultBuilder,
+              "`analyzeImageSync` function executed in ${DateTime.now().difference(time).inMilliseconds} ms",
             );
-
-            dev.log("Completed `alignFacesSync` function, in "
-                "${stopwatch.elapsedMilliseconds} ms");
-
-            stopwatch.reset();
-            // Get the embeddings of the faces
-            final embeddings = await FaceMlService.embedFacesSync(
-              faceAlignmentResult,
-              faceEmbeddingAddress,
-              resultBuilder: resultBuilder,
-            );
-
-            dev.log("Completed `embedFacesSync` function, in "
-                "${stopwatch.elapsedMilliseconds} ms");
-
-            stopwatch.stop();
-            stopwatchTotal.stop();
-            dev.log("Finished Analyze image (${embeddings.length} faces) with "
-                "uploadedFileID $enteFileID, in "
-                "${stopwatchTotal.elapsedMilliseconds} ms");
-
-            sendPort.send(resultBuilder.build().toJsonString());
+            sendPort.send(result.toJsonString());
             break;
         }
       } catch (e, stackTrace) {
@@ -361,7 +291,7 @@ class FaceMlService {
   }
 
   Future<void> clusterAllImages({
-    double minFaceScore = kMinHighQualityFaceScore,
+    double minFaceScore = kMinimumQualityFaceScore,
     bool clusterInBuckets = true,
   }) async {
     _logger.info("`clusterAllImages()` called");
@@ -370,6 +300,10 @@ class FaceMlService {
       // Get a sense of the total number of faces in the database
       final int totalFaces = await FaceMLDataDB.instance
           .getTotalFaceCount(minFaceScore: minFaceScore);
+
+      // Get the current cluster statistics
+      final Map<int, (Uint8List, int)> oldClusterSummaries =
+          await FaceMLDataDB.instance.getAllClusterSummary();
       if (clusterInBuckets) {
         // read the creation times from Files DB, in a map from fileID to creation time
         final fileIDToCreationTime =
@@ -382,14 +316,14 @@ class FaceMlService {
         int bucket = 1;
 
         while (true) {
-          final faceIdToEmbeddingBucket =
-              await FaceMLDataDB.instance.getFaceEmbeddingMap(
+          final faceInfoForClustering =
+              await FaceMLDataDB.instance.getFaceInfoForClustering(
             minScore: minFaceScore,
             maxFaces: bucketSize,
             offset: offset,
             batchSize: batchSize,
           );
-          if (faceIdToEmbeddingBucket.isEmpty) {
+          if (faceInfoForClustering.isEmpty) {
             _logger.warning(
               'faceIdToEmbeddingBucket is empty, this should ideally not happen as it should have stopped earlier. offset: $offset, totalFaces: $totalFaces',
             );
@@ -402,20 +336,24 @@ class FaceMlService {
             break;
           }
 
-          final faceIdToCluster =
+          final clusteringResult =
               await FaceClusteringService.instance.predictLinear(
-            faceIdToEmbeddingBucket,
+            faceInfoForClustering,
             fileIDToCreationTime: fileIDToCreationTime,
             offset: offset,
+            oldClusterSummaries: oldClusterSummaries,
           );
-          if (faceIdToCluster == null) {
+          if (clusteringResult == null) {
             _logger.warning("faceIdToCluster is null");
             return;
           }
 
-          await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster);
+          await FaceMLDataDB.instance
+              .updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster);
+          await FaceMLDataDB.instance
+              .clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
           _logger.info(
-            'Done with clustering ${offset + faceIdToEmbeddingBucket.length} embeddings (${(100 * (offset + faceIdToEmbeddingBucket.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset',
+            'Done with clustering ${offset + faceInfoForClustering.length} embeddings (${(100 * (offset + faceInfoForClustering.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset',
           );
           if (offset + bucketSize >= totalFaces) {
             _logger.info('All faces clustered');
@@ -427,14 +365,14 @@ class FaceMlService {
       } else {
         // Read all the embeddings from the database, in a map from faceID to embedding
         final clusterStartTime = DateTime.now();
-        final faceIdToEmbedding =
-            await FaceMLDataDB.instance.getFaceEmbeddingMap(
+        final faceInfoForClustering =
+            await FaceMLDataDB.instance.getFaceInfoForClustering(
           minScore: minFaceScore,
           maxFaces: totalFaces,
         );
         final gotFaceEmbeddingsTime = DateTime.now();
         _logger.info(
-          'read embeddings ${faceIdToEmbedding.length} in ${gotFaceEmbeddingsTime.difference(clusterStartTime).inMilliseconds} ms',
+          'read embeddings ${faceInfoForClustering.length} in ${gotFaceEmbeddingsTime.difference(clusterStartTime).inMilliseconds} ms',
         );
 
         // Read the creation times from Files DB, in a map from fileID to creation time
@@ -444,25 +382,29 @@ class FaceMlService {
             '${DateTime.now().difference(gotFaceEmbeddingsTime).inMilliseconds} ms');
 
         // Cluster the embeddings using the linear clustering algorithm, returning the new faceID to clusterID map together with the new cluster summaries
-        final faceIdToCluster =
+        final clusteringResult =
             await FaceClusteringService.instance.predictLinear(
-          faceIdToEmbedding,
+          faceInfoForClustering,
           fileIDToCreationTime: fileIDToCreationTime,
+          oldClusterSummaries: oldClusterSummaries,
         );
-        if (faceIdToCluster == null) {
+        if (clusteringResult == null) {
           _logger.warning("faceIdToCluster is null");
           return;
         }
         final clusterDoneTime = DateTime.now();
         _logger.info(
-          'done with clustering ${faceIdToEmbedding.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ',
+          'done with clustering ${faceInfoForClustering.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ',
         );
 
         // Store the updated clusterIDs in the database
         _logger.info(
-          'Updating ${faceIdToCluster.length} FaceIDs with clusterIDs in the DB',
+          'Updating ${clusteringResult.newFaceIdToCluster.length} FaceIDs with clusterIDs in the DB',
         );
-        await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster);
+        await FaceMLDataDB.instance
+            .updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster);
+        await FaceMLDataDB.instance
+            .clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
         _logger.info('Done updating FaceIDs with clusterIDs in the DB, in '
             '${DateTime.now().difference(clusterDoneTime).inSeconds} seconds');
       }
@@ -875,6 +817,7 @@ class FaceMlService {
     }
   }
 
+  /// Analyzes the given image data by running the full pipeline for faces, using [analyzeImageSync] in the isolate.
   Future<FaceMlResult?> analyzeImageInSingleIsolate(EnteFile enteFile) async {
     _checkEnteFileForID(enteFile);
     await ensureInitialized();
@@ -931,6 +874,87 @@ class FaceMlService {
     return result;
   }
 
+  static Future<FaceMlResult> analyzeImageSync(Map args) async {
+    try {
+      final int enteFileID = args["enteFileID"] as int;
+      final String imagePath = args["filePath"] as String;
+      final int faceDetectionAddress = args["faceDetectionAddress"] as int;
+      final int faceEmbeddingAddress = args["faceEmbeddingAddress"] as int;
+
+      final resultBuilder = FaceMlResultBuilder.fromEnteFileID(enteFileID);
+
+      dev.log(
+        "Start analyzing image with uploadedFileID: $enteFileID inside the isolate",
+      );
+      final stopwatchTotal = Stopwatch()..start();
+      final stopwatch = Stopwatch()..start();
+
+      // Decode the image once to use for both face detection and alignment
+      final imageData = await File(imagePath).readAsBytes();
+      final image = await decodeImageFromData(imageData);
+      final ByteData imgByteData = await getByteDataFromImage(image);
+      dev.log('Reading and decoding image took '
+          '${stopwatch.elapsedMilliseconds} ms');
+      stopwatch.reset();
+
+      // Get the faces
+      final List<FaceDetectionRelative> faceDetectionResult =
+          await FaceMlService.detectFacesSync(
+        image,
+        imgByteData,
+        faceDetectionAddress,
+        resultBuilder: resultBuilder,
+      );
+
+      dev.log(
+          "${faceDetectionResult.length} faces detected with scores ${faceDetectionResult.map((e) => e.score).toList()}: completed `detectFacesSync` function, in "
+          "${stopwatch.elapsedMilliseconds} ms");
+
+      // If no faces were detected, return a result with no faces. Otherwise, continue.
+      if (faceDetectionResult.isEmpty) {
+        dev.log(
+            "No faceDetectionResult, Completed analyzing image with uploadedFileID $enteFileID, in "
+            "${stopwatch.elapsedMilliseconds} ms");
+        return resultBuilder.buildNoFaceDetected();
+      }
+
+      stopwatch.reset();
+      // Align the faces
+      final Float32List faceAlignmentResult =
+          await FaceMlService.alignFacesSync(
+        image,
+        imgByteData,
+        faceDetectionResult,
+        resultBuilder: resultBuilder,
+      );
+
+      dev.log("Completed `alignFacesSync` function, in "
+          "${stopwatch.elapsedMilliseconds} ms");
+
+      stopwatch.reset();
+      // Get the embeddings of the faces
+      final embeddings = await FaceMlService.embedFacesSync(
+        faceAlignmentResult,
+        faceEmbeddingAddress,
+        resultBuilder: resultBuilder,
+      );
+
+      dev.log("Completed `embedFacesSync` function, in "
+          "${stopwatch.elapsedMilliseconds} ms");
+
+      stopwatch.stop();
+      stopwatchTotal.stop();
+      dev.log("Finished Analyze image (${embeddings.length} faces) with "
+          "uploadedFileID $enteFileID, in "
+          "${stopwatchTotal.elapsedMilliseconds} ms");
+
+      return resultBuilder.build();
+    } catch (e, s) {
+      dev.log("Could not analyze image: \n e: $e \n s: $s");
+      rethrow;
+    }
+  }
+
   Future<String?> _getImagePathForML(
     EnteFile enteFile, {
     FileDataForML typeOfData = FileDataForML.fileData,

+ 157 - 19
mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart

@@ -5,6 +5,8 @@ import "package:flutter/foundation.dart";
 import "package:logging/logging.dart";
 import "package:photos/core/event_bus.dart";
 import "package:photos/db/files_db.dart";
+// import "package:photos/events/files_updated_event.dart";
+// import "package:photos/events/local_photos_updated_event.dart";
 import "package:photos/events/people_changed_event.dart";
 import "package:photos/extensions/stop_watch.dart";
 import "package:photos/face/db.dart";
@@ -115,17 +117,103 @@ class ClusterFeedbackService {
     List<EnteFile> files,
     PersonEntity p,
   ) async {
-    await FaceMLDataDB.instance.removeFilesFromPerson(files, p.remoteID);
-    Bus.instance.fire(PeopleChangedEvent());
+    try {
+      // Get the relevant faces to be removed
+      final faceIDs = await FaceMLDataDB.instance
+          .getFaceIDsForPerson(p.remoteID)
+          .then((iterable) => iterable.toList());
+      faceIDs.retainWhere((faceID) {
+        final fileID = getFileIdFromFaceId(faceID);
+        return files.any((file) => file.uploadedFileID == fileID);
+      });
+      final embeddings =
+          await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs);
+
+      final fileIDToCreationTime =
+          await FilesDB.instance.getFileIDToCreationTime();
+
+      // Re-cluster within the deleted faces
+      final newFaceIdToClusterID =
+          await FaceClusteringService.instance.predictWithinClusterComputer(
+        embeddings,
+        fileIDToCreationTime: fileIDToCreationTime,
+        distanceThreshold: 0.20,
+      );
+      if (newFaceIdToClusterID == null || newFaceIdToClusterID.isEmpty) {
+        return;
+      }
+
+      // Update the deleted faces
+      await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID);
+
+      // Make sure the deleted faces don't get suggested in the future
+      final notClusterIdToPersonId = <int, String>{};
+      for (final clusterId in newFaceIdToClusterID.values.toSet()) {
+        notClusterIdToPersonId[clusterId] = p.remoteID;
+      }
+      await FaceMLDataDB.instance
+          .bulkCaptureNotPersonFeedback(notClusterIdToPersonId);
+
+      Bus.instance.fire(PeopleChangedEvent());
+      return;
+    } catch (e, s) {
+      _logger.severe("Error in removeFilesFromPerson", e, s);
+      rethrow;
+    }
   }
 
   Future<void> removeFilesFromCluster(
     List<EnteFile> files,
     int clusterID,
   ) async {
-    await FaceMLDataDB.instance.removeFilesFromCluster(files, clusterID);
-    Bus.instance.fire(PeopleChangedEvent());
-    return;
+    try {
+      // Get the relevant faces to be removed
+      final faceIDs = await FaceMLDataDB.instance
+          .getFaceIDsForCluster(clusterID)
+          .then((iterable) => iterable.toList());
+      faceIDs.retainWhere((faceID) {
+        final fileID = getFileIdFromFaceId(faceID);
+        return files.any((file) => file.uploadedFileID == fileID);
+      });
+      final embeddings =
+          await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs);
+
+      final fileIDToCreationTime =
+          await FilesDB.instance.getFileIDToCreationTime();
+
+      // Re-cluster within the deleted faces
+      final newFaceIdToClusterID =
+          await FaceClusteringService.instance.predictWithinClusterComputer(
+        embeddings,
+        fileIDToCreationTime: fileIDToCreationTime,
+        distanceThreshold: 0.20,
+      );
+      if (newFaceIdToClusterID == null || newFaceIdToClusterID.isEmpty) {
+        return;
+      }
+
+      // Update the deleted faces
+      await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID);
+
+      Bus.instance.fire(
+        PeopleChangedEvent(
+          relevantFiles: files,
+          type: PeopleEventType.removedFilesFromCluster,
+          source: "$clusterID",
+        ),
+      );
+      // Bus.instance.fire(
+      //   LocalPhotosUpdatedEvent(
+      //     files,
+      //     type: EventType.peopleClusterChanged,
+      //     source: "$clusterID",
+      //   ),
+      // );
+      return;
+    } catch (e, s) {
+      _logger.severe("Error in removeFilesFromCluster", e, s);
+      rethrow;
+    }
   }
 
   Future<void> addFilesToCluster(List<String> faceIDs, int clusterID) async {
@@ -194,7 +282,7 @@ class ClusterFeedbackService {
   // TODO: iterate over this method to find sweet spot
   Future<Map<int, List<String>>> breakUpCluster(
     int clusterID, {
-    useDbscan = false,
+    bool useDbscan = false,
   }) async {
     _logger.info(
       'breakUpCluster called for cluster $clusterID with dbscan $useDbscan',
@@ -203,10 +291,8 @@ class ClusterFeedbackService {
 
     final faceIDs = await faceMlDb.getFaceIDsForCluster(clusterID);
     final originalFaceIDsSet = faceIDs.toSet();
-    final fileIDs = faceIDs.map((e) => getFileIdFromFaceId(e)).toList();
 
-    final embeddings = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs);
-    embeddings.removeWhere((key, value) => !faceIDs.contains(key));
+    final embeddings = await faceMlDb.getFaceEmbeddingMapForFaces(faceIDs);
 
     final fileIDToCreationTime =
         await FilesDB.instance.getFileIDToCreationTime();
@@ -232,18 +318,14 @@ class ClusterFeedbackService {
         maxClusterID++;
       }
     } else {
-      final clusteringInput = embeddings.map((key, value) {
-        return MapEntry(key, (null, value));
-      });
-
       final faceIdToCluster =
-          await FaceClusteringService.instance.predictLinear(
-        clusteringInput,
+          await FaceClusteringService.instance.predictWithinClusterComputer(
+        embeddings,
         fileIDToCreationTime: fileIDToCreationTime,
-        distanceThreshold: 0.23,
+        distanceThreshold: 0.22,
       );
 
-      if (faceIdToCluster == null) {
+      if (faceIdToCluster == null || faceIdToCluster.isEmpty) {
         _logger.info('No clusters found');
         return {};
       } else {
@@ -295,6 +377,62 @@ class ClusterFeedbackService {
     return clusterIdToFaceIds;
   }
 
+  /// WARNING: this method is purely for debugging purposes, never use in production
+  Future<void> createFakeClustersByBlurValue() async {
+    try {
+      // Delete old clusters
+      await FaceMLDataDB.instance.resetClusterIDs();
+      await FaceMLDataDB.instance.dropClustersAndPersonTable();
+      final List<PersonEntity> persons =
+          await PersonService.instance.getPersons();
+      for (final PersonEntity p in persons) {
+        await PersonService.instance.deletePerson(p.remoteID);
+      }
+
+      // Create new fake clusters based on blur value. One for values between 0 and 10, one for 10-20, etc till 200
+      final int startClusterID = DateTime.now().microsecondsSinceEpoch;
+      final faceIDsToBlurValues =
+          await FaceMLDataDB.instance.getFaceIDsToBlurValues(200);
+      final faceIdToCluster = <String, int>{};
+      for (final entry in faceIDsToBlurValues.entries) {
+        final faceID = entry.key;
+        final blurValue = entry.value;
+        final newClusterID = startClusterID + blurValue ~/ 10;
+        faceIdToCluster[faceID] = newClusterID;
+      }
+      await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster);
+
+      Bus.instance.fire(PeopleChangedEvent());
+    } catch (e, s) {
+      _logger.severe("Error in createFakeClustersByBlurValue", e, s);
+      rethrow;
+    }
+  }
+
+  Future<void> debugLogClusterBlurValues(
+    int clusterID, {
+    int? clusterSize,
+  }) async {
+    final List<double> blurValues = await FaceMLDataDB.instance
+        .getBlurValuesForCluster(clusterID)
+        .then((value) => value.toList());
+
+    // Round the blur values to integers
+    final blurValuesIntegers =
+        blurValues.map((value) => value.round()).toList();
+
+    // Sort the blur values in ascending order
+    blurValuesIntegers.sort();
+
+    // Log the sorted blur values
+
+    _logger.info(
+      "Blur values for cluster $clusterID${clusterSize != null ? ' with $clusterSize photos' : ''}: $blurValuesIntegers",
+    );
+
+    return;
+  }
+
   /// Returns a map from each of the person's clusterIDs to a list of (closest clusterID, distance) pairs
   Future<Map<int, List<(int, double)>>> getSuggestionsUsingMean(
     PersonEntity p, {
@@ -523,7 +661,7 @@ class ClusterFeedbackService {
     );
 
     final Map<int, (Uint8List, int)> clusterToSummary =
-        await faceMlDb.clusterSummaryAll();
+        await faceMlDb.getAllClusterSummary();
     final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
 
     final Map<int, List<double>> clusterAvg = {};
@@ -714,7 +852,7 @@ class ClusterFeedbackService {
 
     // Get the cluster averages for the person's clusters and the suggestions' clusters
     final Map<int, (Uint8List, int)> clusterToSummary =
-        await faceMlDb.clusterSummaryAll();
+        await faceMlDb.getAllClusterSummary();
 
     // Calculate the avg embedding of the person
     final personClusters = await faceMlDb.getPersonClusterIDs(person.remoteID);

+ 1 - 1
mobile/lib/services/search_service.dart

@@ -824,7 +824,7 @@ class SearchService {
             "Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}",
           );
         }
-        if (files.length < 3 && sortedClusterIds.length > 3) {
+        if (files.length < 20 && sortedClusterIds.length > 3) {
           continue;
         }
         facesResult.add(

+ 29 - 0
mobile/lib/ui/settings/debug/face_debug_section_widget.dart

@@ -8,6 +8,7 @@ import "package:photos/events/people_changed_event.dart";
 import "package:photos/face/db.dart";
 import "package:photos/face/model/person.dart";
 import 'package:photos/services/machine_learning/face_ml/face_ml_service.dart';
+import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
 import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
 import 'package:photos/theme/ente_theme.dart';
 import 'package:photos/ui/components/captioned_text_widget.dart';
@@ -284,6 +285,34 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
           },
         ),
         sectionOptionSpacing,
+        MenuItemWidget(
+          captionedTextWidget: const CaptionedTextWidget(
+            title: "Rank blurs",
+          ),
+          pressedColor: getEnteColorScheme(context).fillFaint,
+          trailingIcon: Icons.chevron_right_outlined,
+          trailingIconIsMuted: true,
+          onTap: () async {
+            await showChoiceDialog(
+              context,
+              title: "Are you sure?",
+              body:
+                  "This will delete all clusters and put blurry faces in separate clusters per ten points.",
+              firstButtonLabel: "Yes, confirm",
+              firstButtonOnTap: () async {
+                try {
+                  await ClusterFeedbackService.instance
+                      .createFakeClustersByBlurValue();
+                  showShortToast(context, "Done");
+                } catch (e, s) {
+                  _logger.warning('Failed to rank faces on blur values ', e, s);
+                  await showGenericErrorDialog(context: context, error: e);
+                }
+              },
+            );
+          },
+        ),
+        sectionOptionSpacing,
         MenuItemWidget(
           captionedTextWidget: const CaptionedTextWidget(
             title: "Drop embeddings & feedback",

+ 23 - 4
mobile/lib/ui/viewer/file_details/face_widget.dart

@@ -9,6 +9,7 @@ import "package:photos/face/db.dart";
 import "package:photos/face/model/face.dart";
 import "package:photos/face/model/person.dart";
 import 'package:photos/models/file/file.dart';
+import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart";
 import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
 import "package:photos/services/search_service.dart";
 import "package:photos/theme/ente_theme.dart";
@@ -47,7 +48,7 @@ class _FaceWidgetState extends State<FaceWidget> {
 
   @override
   Widget build(BuildContext context) {
-    if (Platform.isIOS || Platform.isAndroid) {
+    if (Platform.isIOS) {
       return FutureBuilder<Uint8List?>(
         future: getFaceCrop(),
         builder: (context, snapshot) {
@@ -164,19 +165,19 @@ class _FaceWidgetState extends State<FaceWidget> {
                     ),
                   if (kDebugMode)
                     Text(
-                      'B: ${widget.face.blur.toStringAsFixed(3)}',
+                      'B: ${widget.face.blur.toStringAsFixed(0)}',
                       style: Theme.of(context).textTheme.bodySmall,
                       maxLines: 1,
                     ),
                   if (kDebugMode)
                     Text(
-                      'V: ${widget.face.visibility}',
+                      'D: ${widget.face.detection.getFaceDirection().toDirectionString()}',
                       style: Theme.of(context).textTheme.bodySmall,
                       maxLines: 1,
                     ),
                   if (kDebugMode)
                     Text(
-                      'A: ${widget.face.area()}',
+                      'Sideways: ${widget.face.detection.faceIsSideways().toString()}',
                       style: Theme.of(context).textTheme.bodySmall,
                       maxLines: 1,
                     ),
@@ -303,6 +304,24 @@ class _FaceWidgetState extends State<FaceWidget> {
                     style: Theme.of(context).textTheme.bodySmall,
                     maxLines: 1,
                   ),
+                if (kDebugMode)
+                  Text(
+                    'B: ${widget.face.blur.toStringAsFixed(0)}',
+                    style: Theme.of(context).textTheme.bodySmall,
+                    maxLines: 1,
+                  ),
+                if (kDebugMode)
+                  Text(
+                    'D: ${widget.face.detection.getFaceDirection().toDirectionString()}',
+                    style: Theme.of(context).textTheme.bodySmall,
+                    maxLines: 1,
+                  ),
+                if (kDebugMode)
+                  Text(
+                    'Sideways: ${widget.face.detection.faceIsSideways().toString()}',
+                    style: Theme.of(context).textTheme.bodySmall,
+                    maxLines: 1,
+                  ),
               ],
             ),
           );

+ 23 - 1
mobile/lib/ui/viewer/people/cluster_page.dart

@@ -1,10 +1,12 @@
 import "dart:async";
 
+import "package:flutter/foundation.dart";
 import 'package:flutter/material.dart';
 import "package:flutter_animate/flutter_animate.dart";
 import 'package:photos/core/event_bus.dart';
 import 'package:photos/events/files_updated_event.dart';
 import 'package:photos/events/local_photos_updated_event.dart';
+import "package:photos/events/people_changed_event.dart";
 import "package:photos/face/model/person.dart";
 import "package:photos/generated/l10n.dart";
 import 'package:photos/models/file/file.dart';
@@ -51,6 +53,7 @@ class _ClusterPageState extends State<ClusterPage> {
   final _selectedFiles = SelectedFiles();
   late final List<EnteFile> files;
   late final StreamSubscription<LocalPhotosUpdatedEvent> _filesUpdatedEvent;
+  late final StreamSubscription<PeopleChangedEvent> _peopleChangedEvent;
 
   @override
   void initState() {
@@ -69,11 +72,27 @@ class _ClusterPageState extends State<ClusterPage> {
         setState(() {});
       }
     });
+    _peopleChangedEvent = Bus.instance.on<PeopleChangedEvent>().listen((event) {
+      if (event.type == PeopleEventType.removedFilesFromCluster &&
+          (event.source == widget.clusterID.toString())) {
+        for (var updatedFile in event.relevantFiles!) {
+          files.remove(updatedFile);
+        }
+        setState(() {});
+      }
+    });
+    kDebugMode
+        ? ClusterFeedbackService.instance.debugLogClusterBlurValues(
+            widget.clusterID,
+            clusterSize: files.length,
+          )
+        : null;
   }
 
   @override
   void dispose() {
     _filesUpdatedEvent.cancel();
+    _peopleChangedEvent.cancel();
     super.dispose();
   }
 
@@ -96,10 +115,12 @@ class _ClusterPageState extends State<ClusterPage> {
         );
       },
       reloadEvent: Bus.instance.on<LocalPhotosUpdatedEvent>(),
+      forceReloadEvents: [Bus.instance.on<PeopleChangedEvent>()],
       removalEventTypes: const {
         EventType.deletedFromRemote,
         EventType.deletedFromEverywhere,
         EventType.hide,
+        EventType.peopleClusterChanged,
       },
       tagPrefix: widget.tagPrefix + widget.tagPrefix,
       selectedFiles: _selectedFiles,
@@ -111,9 +132,10 @@ class _ClusterPageState extends State<ClusterPage> {
         preferredSize: const Size.fromHeight(50.0),
         child: ClusterAppBar(
           SearchResultPage.appBarType,
-          "${widget.searchResult.length} memories${widget.appendTitle}",
+          "${files.length} memories${widget.appendTitle}",
           _selectedFiles,
           widget.clusterID,
+          key: ValueKey(files.length),
         ),
       ),
       body: Column(

+ 4 - 7
mobile/lib/utils/image_ml_util.dart

@@ -1099,19 +1099,16 @@ Future<(Float32List, List<AlignmentResult>, List<bool>, List<double>, Size)>
     imageHeight: image.height,
   );
 
-  final List<List<List<double>>> faceLandmarks =
-      absoluteFaces.map((face) => face.allKeypoints).toList();
-
   final alignedImagesFloat32List =
-      Float32List(3 * width * height * faceLandmarks.length);
+      Float32List(3 * width * height * absoluteFaces.length);
   final alignmentResults = <AlignmentResult>[];
   final isBlurs = <bool>[];
   final blurValues = <double>[];
 
   int alignedImageIndex = 0;
-  for (final faceLandmark in faceLandmarks) {
+  for (final face in absoluteFaces) {
     final (alignmentResult, correctlyEstimated) =
-        SimilarityTransform.instance.estimate(faceLandmark);
+        SimilarityTransform.instance.estimate(face.allKeypoints);
     if (!correctlyEstimated) {
       alignedImageIndex += 3 * width * height;
       alignmentResults.add(AlignmentResult.empty());
@@ -1137,7 +1134,7 @@ Future<(Float32List, List<AlignmentResult>, List<bool>, List<double>, Size)>
     final grayscalems = blurDetectionStopwatch.elapsedMilliseconds;
     log('creating grayscale matrix took $grayscalems ms');
     final (isBlur, blurValue) = await BlurDetectionService.instance
-        .predictIsBlurGrayLaplacian(faceGrayMatrix);
+        .predictIsBlurGrayLaplacian(faceGrayMatrix, faceDirection: face.getFaceDirection());
     final blurms = blurDetectionStopwatch.elapsedMilliseconds - grayscalems;
     log('blur detection took $blurms ms');
     log(