diff --git a/mobile/lib/db/files_db.dart b/mobile/lib/db/files_db.dart index dc821d793..0926276e4 100644 --- a/mobile/lib/db/files_db.dart +++ b/mobile/lib/db/files_db.dart @@ -1316,8 +1316,8 @@ class FilesDB { } Future> getFileIDToCreationTime() async { - final db = await instance.database; - final rows = await db.rawQuery( + final db = await instance.sqliteAsyncDB; + final rows = await db.getAll( ''' SELECT $columnUploadedFileID, $columnCreationTime FROM $filesTable diff --git a/mobile/lib/events/files_updated_event.dart b/mobile/lib/events/files_updated_event.dart index 7d7779d49..2fc67d646 100644 --- a/mobile/lib/events/files_updated_event.dart +++ b/mobile/lib/events/files_updated_event.dart @@ -27,4 +27,5 @@ enum EventType { unhide, coverChanged, peopleChanged, + peopleClusterChanged, } diff --git a/mobile/lib/events/people_changed_event.dart b/mobile/lib/events/people_changed_event.dart index e2d135866..51f4eaeef 100644 --- a/mobile/lib/events/people_changed_event.dart +++ b/mobile/lib/events/people_changed_event.dart @@ -1,3 +1,22 @@ import "package:photos/events/event.dart"; +import "package:photos/models/file/file.dart"; -class PeopleChangedEvent extends Event {} +class PeopleChangedEvent extends Event { + final List? relevantFiles; + final PeopleEventType type; + final String source; + + PeopleChangedEvent({ + this.relevantFiles, + this.type = PeopleEventType.defaultType, + this.source = "", + }); + + @override + String get reason => '$runtimeType{type: ${type.name}, "via": $source}'; +} + +enum PeopleEventType { + defaultType, + removedFilesFromCluster, +} \ No newline at end of file diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index f192607ec..594517547 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -12,6 +12,7 @@ import 'package:photos/face/db_fields.dart'; import "package:photos/face/db_model_mappers.dart"; import "package:photos/face/model/face.dart"; import "package:photos/models/file/file.dart"; +import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart"; import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; import 'package:sqflite/sqflite.dart'; import 'package:sqlite_async/sqlite_async.dart' as sqlite_async; @@ -160,27 +161,27 @@ class FaceMLDataDB { final db = await instance.database; // find out clusterIds that are assigned to other persons using the clusters table final List> maps = await db.rawQuery( - 'SELECT $cluserIDColumn FROM $clusterPersonTable WHERE $personIdColumn != ? AND $personIdColumn IS NOT NULL', + 'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn != ? AND $personIdColumn IS NOT NULL', [personID], ); final Set ignoredClusterIDs = - maps.map((e) => e[cluserIDColumn] as int).toSet(); + maps.map((e) => e[clusterIDColumn] as int).toSet(); final List> rejectMaps = await db.rawQuery( - 'SELECT $cluserIDColumn FROM $notPersonFeedback WHERE $personIdColumn = ?', + 'SELECT $clusterIDColumn FROM $notPersonFeedback WHERE $personIdColumn = ?', [personID], ); final Set rejectClusterIDs = - rejectMaps.map((e) => e[cluserIDColumn] as int).toSet(); + rejectMaps.map((e) => e[clusterIDColumn] as int).toSet(); return ignoredClusterIDs.union(rejectClusterIDs); } Future> getPersonClusterIDs(String personID) async { final db = await instance.database; final List> maps = await db.rawQuery( - 'SELECT $cluserIDColumn FROM $clusterPersonTable WHERE $personIdColumn = ?', + 'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn = ?', [personID], ); - return maps.map((e) => e[cluserIDColumn] as int).toSet(); + return maps.map((e) => e[clusterIDColumn] as int).toSet(); } Future clearTable() async { @@ -249,16 +250,16 @@ class FaceMLDataDB { } final cluterRows = await db.query( clusterPersonTable, - columns: [cluserIDColumn], + columns: [clusterIDColumn], where: '$personIdColumn = ?', whereArgs: [personID], ); final clusterIDs = - cluterRows.map((e) => e[cluserIDColumn] as int).toList(); + cluterRows.map((e) => e[clusterIDColumn] as int).toList(); final List> faceMaps = await db.rawQuery( 'SELECT * FROM $facesTable where ' '$faceIDColumn in (SELECT $fcFaceId from $faceClustersTable where $fcClusterID IN (${clusterIDs.join(",")}))' - 'AND $fileIDColumn in (${fileId.join(",")}) AND $faceScore > $kMinHighQualityFaceScore ORDER BY $faceScore DESC', + 'AND $fileIDColumn in (${fileId.join(",")}) AND $faceScore > $kMinimumQualityFaceScore ORDER BY $faceScore DESC', ); if (faceMaps.isNotEmpty) { if (avatarFileId != null) { @@ -308,8 +309,6 @@ class FaceMLDataDB { faceBlur, imageHeight, imageWidth, - faceArea, - faceVisibilityScore, mlVersionColumn, ], where: '$fileIDColumn = ?', @@ -334,16 +333,60 @@ class FaceMLDataDB { } Future> getFaceIDsForCluster(int clusterID) async { - final db = await instance.database; - final List> maps = await db.query( - faceClustersTable, - columns: [fcFaceId], - where: '$fcClusterID = ?', - whereArgs: [clusterID], + final db = await instance.sqliteAsyncDB; + final List> maps = await db.getAll( + 'SELECT $fcFaceId FROM $faceClustersTable ' + 'WHERE $faceClustersTable.$fcClusterID = ?', + [clusterID], ); return maps.map((e) => e[fcFaceId] as String).toSet(); } + Future> getFaceIDsForPerson(String personID) async { + final db = await instance.sqliteAsyncDB; + final faceIdsResult = await db.getAll( + 'SELECT $fcFaceId FROM $faceClustersTable LEFT JOIN $clusterPersonTable ' + 'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn ' + 'WHERE $clusterPersonTable.$personIdColumn = ?', + [personID], + ); + return faceIdsResult.map((e) => e[fcFaceId] as String).toSet(); + } + + Future> getBlurValuesForCluster(int clusterID) async { + final db = await instance.sqliteAsyncDB; + const String query = ''' + SELECT $facesTable.$faceBlur + FROM $facesTable + JOIN $faceClustersTable ON $facesTable.$faceIDColumn = $faceClustersTable.$fcFaceId + WHERE $faceClustersTable.$fcClusterID = ? + '''; + // const String query2 = ''' + // SELECT $faceBlur + // FROM $facesTable + // WHERE $faceIDColumn IN (SELECT $fcFaceId FROM $faceClustersTable WHERE $fcClusterID = ?) + // '''; + final List> maps = await db.getAll( + query, + [clusterID], + ); + return maps.map((e) => e[faceBlur] as double).toSet(); + } + + Future> getFaceIDsToBlurValues( + int maxBlurValue, + ) async { + final db = await instance.sqliteAsyncDB; + final List> maps = await db.getAll( + 'SELECT $faceIDColumn, $faceBlur FROM $facesTable WHERE $faceBlur < $maxBlurValue AND $faceBlur > 1 ORDER BY $faceBlur ASC', + ); + final Map result = {}; + for (final map in maps) { + result[map[faceIDColumn] as String] = map[faceBlur] as double; + } + return result; + } + Future> getFaceIdsToClusterIds( Iterable faceIds, ) async { @@ -376,14 +419,14 @@ class FaceMLDataDB { } Future forceUpdateClusterIds( - Map faceIDToPersonID, + Map faceIDToClusterID, ) async { final db = await instance.database; // Start a batch final batch = db.batch(); - for (final map in faceIDToPersonID.entries) { + for (final map in faceIDToClusterID.entries) { final faceID = map.key; final clusterID = map.value; batch.insert( @@ -410,12 +453,64 @@ class FaceMLDataDB { ); } + Future> getFaceInfoForClustering({ + double minScore = kMinimumQualityFaceScore, + int minClarity = kLaplacianHardThreshold, + int maxFaces = 20000, + int offset = 0, + int batchSize = 10000, + }) async { + final EnteWatch w = EnteWatch("getFaceEmbeddingMap")..start(); + w.logAndReset( + 'reading as float offset: $offset, maxFaces: $maxFaces, batchSize: $batchSize', + ); + final db = await instance.sqliteAsyncDB; + + final Set result = {}; + while (true) { + // Query a batch of rows + final List> maps = await db.getAll( + 'SELECT $faceIDColumn, $faceEmbeddingBlob, $faceScore, $faceBlur, $isSideways FROM $facesTable' + ' WHERE $faceScore > $minScore AND $faceBlur > $minClarity' + ' ORDER BY $faceIDColumn' + ' DESC LIMIT $batchSize OFFSET $offset', + ); + // Break the loop if no more rows + if (maps.isEmpty) { + break; + } + final List faceIds = []; + for (final map in maps) { + faceIds.add(map[faceIDColumn] as String); + } + final faceIdToClusterId = await getFaceIdsToClusterIds(faceIds); + for (final map in maps) { + final faceID = map[faceIDColumn] as String; + final faceInfo = FaceInfoForClustering( + faceID: faceID, + clusterId: faceIdToClusterId[faceID], + embeddingBytes: map[faceEmbeddingBlob] as Uint8List, + faceScore: map[faceScore] as double, + blurValue: map[faceBlur] as double, + isSideways: (map[isSideways] as int) == 1, + ); + result.add(faceInfo); + } + if (result.length >= maxFaces) { + break; + } + offset += batchSize; + } + w.stopWithLog('done reading face embeddings ${result.length}'); + return result; + } + /// Returns a map of faceID to record of clusterId and faceEmbeddingBlob /// /// Only selects faces with score greater than [minScore] and blur score greater than [minClarity] Future> getFaceEmbeddingMap({ - double minScore = kMinHighQualityFaceScore, - int minClarity = kLaplacianThreshold, + double minScore = kMinimumQualityFaceScore, + int minClarity = kLaplacianHardThreshold, int maxFaces = 20000, int offset = 0, int batchSize = 10000, @@ -481,7 +576,7 @@ class FaceMLDataDB { facesTable, columns: [faceIDColumn, faceEmbeddingBlob], where: - '$faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianThreshold AND $fileIDColumn IN (${fileIDs.join(",")})', + '$faceScore > $kMinimumQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold AND $fileIDColumn IN (${fileIDs.join(",")})', limit: batchSize, offset: offset, orderBy: '$faceIDColumn DESC', @@ -503,12 +598,50 @@ class FaceMLDataDB { return result; } + Future> getFaceEmbeddingMapForFaces( + Iterable faceIDs, + ) async { + _logger.info('reading face embeddings for ${faceIDs.length} faces'); + final db = await instance.sqliteAsyncDB; + + // Define the batch size + const batchSize = 10000; + int offset = 0; + + final Map result = {}; + while (true) { + // Query a batch of rows + final String query = ''' + SELECT $faceIDColumn, $faceEmbeddingBlob + FROM $facesTable + WHERE $faceIDColumn IN (${faceIDs.map((id) => "'$id'").join(",")}) + ORDER BY $faceIDColumn DESC + LIMIT $batchSize OFFSET $offset + '''; + final List> maps = await db.getAll(query); + // Break the loop if no more rows + if (maps.isEmpty) { + break; + } + for (final map in maps) { + final faceID = map[faceIDColumn] as String; + result[faceID] = map[faceEmbeddingBlob] as Uint8List; + } + if (result.length > 10000) { + break; + } + offset += batchSize; + } + _logger.info('done reading face embeddings for ${faceIDs.length} faces'); + return result; + } + Future getTotalFaceCount({ - double minFaceScore = kMinHighQualityFaceScore, + double minFaceScore = kMinimumQualityFaceScore, }) async { final db = await instance.sqliteAsyncDB; final List> maps = await db.getAll( - 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $minFaceScore AND $faceBlur > $kLaplacianThreshold', + 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $minFaceScore AND $faceBlur > $kLaplacianHardThreshold', ); return maps.first['count'] as int; } @@ -517,7 +650,7 @@ class FaceMLDataDB { final db = await instance.sqliteAsyncDB; final List> totalFacesMaps = await db.getAll( - 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianThreshold', + 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinimumQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold', ); final int totalFaces = totalFacesMaps.first['count'] as int; @@ -530,11 +663,11 @@ class FaceMLDataDB { } Future getBlurryFaceCount([ - int blurThreshold = kLaplacianThreshold, + int blurThreshold = kLaplacianHardThreshold, ]) async { final db = await instance.database; final List> maps = await db.rawQuery( - 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceBlur <= $blurThreshold AND $faceScore > $kMinHighQualityFaceScore', + 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceBlur <= $blurThreshold AND $faceScore > $kMinimumQualityFaceScore', ); return maps.first['count'] as int; } @@ -555,7 +688,7 @@ class FaceMLDataDB { clusterPersonTable, { personIdColumn: personID, - cluserIDColumn: clusterID, + clusterIDColumn: clusterID, }, ); } @@ -572,7 +705,7 @@ class FaceMLDataDB { clusterPersonTable, { personIdColumn: personID, - cluserIDColumn: clusterID, + clusterIDColumn: clusterID, }, conflictAlgorithm: ConflictAlgorithm.replace, ); @@ -589,11 +722,31 @@ class FaceMLDataDB { notPersonFeedback, { personIdColumn: personID, - cluserIDColumn: clusterID, + clusterIDColumn: clusterID, }, ); } + Future bulkCaptureNotPersonFeedback( + Map clusterToPersonID, + ) async { + final db = await instance.database; + final batch = db.batch(); + for (final entry in clusterToPersonID.entries) { + final clusterID = entry.key; + final personID = entry.value; + batch.insert( + notPersonFeedback, + { + personIdColumn: personID, + clusterIDColumn: clusterID, + }, + conflictAlgorithm: ConflictAlgorithm.replace, + ); + } + await batch.commit(noResult: true); + } + Future removeClusterToPerson({ required String personID, required int clusterID, @@ -601,7 +754,7 @@ class FaceMLDataDB { final db = await instance.database; return db.delete( clusterPersonTable, - where: '$personIdColumn = ? AND $cluserIDColumn = ?', + where: '$personIdColumn = ? AND $clusterIDColumn = ?', whereArgs: [personID, clusterID], ); } @@ -613,13 +766,13 @@ class FaceMLDataDB { final List> maps = await db.rawQuery( 'SELECT $faceClustersTable.$fcClusterID, $fcFaceId FROM $faceClustersTable ' 'INNER JOIN $clusterPersonTable ' - 'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$cluserIDColumn ' + 'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn ' 'WHERE $clusterPersonTable.$personIdColumn = ?', [personID], ); final Map> result = {}; for (final map in maps) { - final clusterID = map[cluserIDColumn] as int; + final clusterID = map[clusterIDColumn] as int; final String faceID = map[fcFaceId] as String; final fileID = int.parse(faceID.split('_').first); result[fileID] = (result[fileID] ?? {})..add(clusterID); @@ -664,7 +817,7 @@ class FaceMLDataDB { batch.insert( clusterSummaryTable, { - cluserIDColumn: cluserID, + clusterIDColumn: cluserID, avgColumn: avg, countColumn: count, }, @@ -676,12 +829,16 @@ class FaceMLDataDB { } /// Returns a map of clusterID to (avg embedding, count) - Future> clusterSummaryAll() async { - final db = await instance.database; + Future> getAllClusterSummary([ + int? minClusterSize, + ]) async { + final db = await instance.sqliteAsyncDB; final Map result = {}; - final rows = await db.rawQuery('SELECT * from $clusterSummaryTable'); + final rows = await db.getAll( + 'SELECT * FROM $clusterSummaryTable${minClusterSize != null ? ' WHERE $countColumn >= $minClusterSize' : ''}', + ); for (final r in rows) { - final id = r[cluserIDColumn] as int; + final id = r[clusterIDColumn] as int; final avg = r[avgColumn] as Uint8List; final count = r[countColumn] as int; result[id] = (avg, count); @@ -692,11 +849,11 @@ class FaceMLDataDB { Future> getClusterIDToPersonID() async { final db = await instance.database; final List> maps = await db.rawQuery( - 'SELECT $personIdColumn, $cluserIDColumn FROM $clusterPersonTable', + 'SELECT $personIdColumn, $clusterIDColumn FROM $clusterPersonTable', ); final Map result = {}; for (final map in maps) { - result[map[cluserIDColumn] as int] = map[personIdColumn] as String; + result[map[clusterIDColumn] as int] = map[personIdColumn] as String; } return result; } @@ -741,7 +898,7 @@ class FaceMLDataDB { final db = await instance.database; final faceIdsResult = await db.rawQuery( 'SELECT $fcFaceId FROM $faceClustersTable LEFT JOIN $clusterPersonTable ' - 'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$cluserIDColumn ' + 'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn ' 'WHERE $clusterPersonTable.$personIdColumn = ?', [personID], ); diff --git a/mobile/lib/face/db_fields.dart b/mobile/lib/face/db_fields.dart index 2dc98ac1f..c7d0c703c 100644 --- a/mobile/lib/face/db_fields.dart +++ b/mobile/lib/face/db_fields.dart @@ -8,8 +8,7 @@ const faceDetectionColumn = 'detection'; const faceEmbeddingBlob = 'eBlob'; const faceScore = 'score'; const faceBlur = 'blur'; -const faceArea = 'area'; -const faceVisibilityScore = 'visibility'; +const isSideways = 'is_sideways'; const imageWidth = 'width'; const imageHeight = 'height'; const faceClusterId = 'cluster_id'; @@ -22,10 +21,9 @@ const createFacesTable = '''CREATE TABLE IF NOT EXISTS $facesTable ( $faceEmbeddingBlob BLOB NOT NULL, $faceScore REAL NOT NULL, $faceBlur REAL NOT NULL DEFAULT $kLapacianDefault, + $isSideways INTEGER NOT NULL DEFAULT 0, $imageHeight INTEGER NOT NULL DEFAULT 0, $imageWidth INTEGER NOT NULL DEFAULT 0, - $faceArea INTEGER NOT NULL DEFAULT 0, - $faceVisibilityScore INTEGER NOT NULL DEFAULT -1, $mlVersionColumn INTEGER NOT NULL DEFAULT -1, PRIMARY KEY($fileIDColumn, $faceIDColumn) ); @@ -62,13 +60,13 @@ const deletePersonTable = 'DROP TABLE IF EXISTS $personTable'; // Clusters Table Fields & Schema Queries const clusterPersonTable = 'cluster_person'; const personIdColumn = 'person_id'; -const cluserIDColumn = 'cluster_id'; +const clusterIDColumn = 'cluster_id'; const createClusterPersonTable = ''' CREATE TABLE IF NOT EXISTS $clusterPersonTable ( $personIdColumn TEXT NOT NULL, - $cluserIDColumn INTEGER NOT NULL, - PRIMARY KEY($personIdColumn, $cluserIDColumn) + $clusterIDColumn INTEGER NOT NULL, + PRIMARY KEY($personIdColumn, $clusterIDColumn) ); '''; const dropClusterPersonTable = 'DROP TABLE IF EXISTS $clusterPersonTable'; @@ -80,10 +78,10 @@ const avgColumn = 'avg'; const countColumn = 'count'; const createClusterSummaryTable = ''' CREATE TABLE IF NOT EXISTS $clusterSummaryTable ( - $cluserIDColumn INTEGER NOT NULL, + $clusterIDColumn INTEGER NOT NULL, $avgColumn BLOB NOT NULL, $countColumn INTEGER NOT NULL, - PRIMARY KEY($cluserIDColumn) + PRIMARY KEY($clusterIDColumn) ); '''; @@ -97,7 +95,7 @@ const notPersonFeedback = 'not_person_feedback'; const createNotPersonFeedbackTable = ''' CREATE TABLE IF NOT EXISTS $notPersonFeedback ( $personIdColumn TEXT NOT NULL, - $cluserIDColumn INTEGER NOT NULL + $clusterIDColumn INTEGER NOT NULL ); '''; const dropNotPersonFeedbackTable = 'DROP TABLE IF EXISTS $notPersonFeedback'; diff --git a/mobile/lib/face/db_model_mappers.dart b/mobile/lib/face/db_model_mappers.dart index 4e33a0bfd..70dc77915 100644 --- a/mobile/lib/face/db_model_mappers.dart +++ b/mobile/lib/face/db_model_mappers.dart @@ -34,9 +34,8 @@ Map mapRemoteToFaceDB(Face face) { ).writeToBuffer(), faceScore: face.score, faceBlur: face.blur, + isSideways: face.detection.faceIsSideways() ? 1 : 0, mlVersionColumn: faceMlVersion, - faceArea: face.area(), - faceVisibilityScore: face.visibility, imageWidth: face.fileInfo?.imageWidth ?? 0, imageHeight: face.fileInfo?.imageHeight ?? 0, }; diff --git a/mobile/lib/face/model/detection.dart b/mobile/lib/face/model/detection.dart index cd7ff6c64..49e8c3652 100644 --- a/mobile/lib/face/model/detection.dart +++ b/mobile/lib/face/model/detection.dart @@ -1,6 +1,9 @@ +import "dart:math" show min, max; + import "package:logging/logging.dart"; import "package:photos/face/model/box.dart"; import "package:photos/face/model/landmark.dart"; +import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart"; /// Stores the face detection data, notably the bounding box and landmarks. /// @@ -19,7 +22,7 @@ class Detection { bool get isEmpty => box.width == 0 && box.height == 0 && landmarks.isEmpty; - // emoty box + // empty box Detection.empty() : box = FaceBox( xMin: 0, @@ -89,4 +92,72 @@ class Detection { return -1; } } + + FaceDirection getFaceDirection() { + if (isEmpty) { + return FaceDirection.straight; + } + final leftEye = [landmarks[0].x, landmarks[0].y]; + final rightEye = [landmarks[1].x, landmarks[1].y]; + final nose = [landmarks[2].x, landmarks[2].y]; + final leftMouth = [landmarks[3].x, landmarks[3].y]; + final rightMouth = [landmarks[4].x, landmarks[4].y]; + + final double eyeDistanceX = (rightEye[0] - leftEye[0]).abs(); + final double eyeDistanceY = (rightEye[1] - leftEye[1]).abs(); + final double mouthDistanceY = (rightMouth[1] - leftMouth[1]).abs(); + + final bool faceIsUpright = + (max(leftEye[1], rightEye[1]) + 0.5 * eyeDistanceY < nose[1]) && + (nose[1] + 0.5 * mouthDistanceY < min(leftMouth[1], rightMouth[1])); + + final bool noseStickingOutLeft = (nose[0] < min(leftEye[0], rightEye[0])) && + (nose[0] < min(leftMouth[0], rightMouth[0])); + final bool noseStickingOutRight = + (nose[0] > max(leftEye[0], rightEye[0])) && + (nose[0] > max(leftMouth[0], rightMouth[0])); + + final bool noseCloseToLeftEye = + (nose[0] - leftEye[0]).abs() < 0.2 * eyeDistanceX; + final bool noseCloseToRightEye = + (nose[0] - rightEye[0]).abs() < 0.2 * eyeDistanceX; + + // if (faceIsUpright && (noseStickingOutLeft || noseCloseToLeftEye)) { + if (noseStickingOutLeft || (faceIsUpright && noseCloseToLeftEye)) { + return FaceDirection.left; + // } else if (faceIsUpright && (noseStickingOutRight || noseCloseToRightEye)) { + } else if (noseStickingOutRight || (faceIsUpright && noseCloseToRightEye)) { + return FaceDirection.right; + } + + return FaceDirection.straight; + } + + bool faceIsSideways() { + if (isEmpty) { + return false; + } + final leftEye = [landmarks[0].x, landmarks[0].y]; + final rightEye = [landmarks[1].x, landmarks[1].y]; + final nose = [landmarks[2].x, landmarks[2].y]; + final leftMouth = [landmarks[3].x, landmarks[3].y]; + final rightMouth = [landmarks[4].x, landmarks[4].y]; + + final double eyeDistanceX = (rightEye[0] - leftEye[0]).abs(); + final double eyeDistanceY = (rightEye[1] - leftEye[1]).abs(); + final double mouthDistanceY = (rightMouth[1] - leftMouth[1]).abs(); + + final bool faceIsUpright = + (max(leftEye[1], rightEye[1]) + 0.5 * eyeDistanceY < nose[1]) && + (nose[1] + 0.5 * mouthDistanceY < min(leftMouth[1], rightMouth[1])); + + final bool noseStickingOutLeft = + (nose[0] < min(leftEye[0], rightEye[0]) - 0.5 * eyeDistanceX) && + (nose[0] < min(leftMouth[0], rightMouth[0])); + final bool noseStickingOutRight = + (nose[0] > max(leftEye[0], rightEye[0]) - 0.5 * eyeDistanceX) && + (nose[0] > max(leftMouth[0], rightMouth[0])); + + return faceIsUpright && (noseStickingOutLeft || noseStickingOutRight); + } } diff --git a/mobile/lib/face/model/face.dart b/mobile/lib/face/model/face.dart index 631eeb141..fc4bb57d5 100644 --- a/mobile/lib/face/model/face.dart +++ b/mobile/lib/face/model/face.dart @@ -20,9 +20,9 @@ class Face { final double blur; FileInfo? fileInfo; - bool get isBlurry => blur < kLaplacianThreshold; + bool get isBlurry => blur < kLaplacianHardThreshold; - bool get hasHighScore => score > kMinHighQualityFaceScore; + bool get hasHighScore => score > kMinimumQualityFaceScore; bool get isHighQuality => (!isBlurry) && hasHighScore; diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index b5f17b54c..80fff99c6 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -2,19 +2,26 @@ import "dart:async"; import "dart:developer"; import "dart:isolate"; import "dart:math" show max; -import "dart:typed_data"; +import "dart:typed_data" show Uint8List; +import "package:computer/computer.dart"; +import "package:flutter/foundation.dart" show kDebugMode; import "package:logging/logging.dart"; import "package:ml_linalg/dtype.dart"; import "package:ml_linalg/vector.dart"; import "package:photos/generated/protos/ente/common/vector.pb.dart"; import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; +import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart"; +import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:simple_cluster/simple_cluster.dart"; import "package:synchronized/synchronized.dart"; class FaceInfo { final String faceID; + final double? faceScore; + final double? blurValue; + final bool? badFace; final List? embedding; final Vector? vEmbedding; int? clusterId; @@ -23,6 +30,9 @@ class FaceInfo { int? fileCreationTime; FaceInfo({ required this.faceID, + this.faceScore, + this.blurValue, + this.badFace, this.embedding, this.vEmbedding, this.clusterId, @@ -32,8 +42,18 @@ class FaceInfo { enum ClusterOperation { linearIncrementalClustering, dbscanClustering } +class ClusteringResult { + final Map newFaceIdToCluster; + final Map? newClusterSummaries; + ClusteringResult({ + required this.newFaceIdToCluster, + required this.newClusterSummaries, + }); +} + class FaceClusteringService { final _logger = Logger("FaceLinearClustering"); + final _computer = Computer.shared(); Timer? _inactivityTimer; final Duration _inactivityDuration = const Duration(minutes: 3); @@ -49,6 +69,7 @@ class FaceClusteringService { bool isRunning = false; static const kRecommendedDistanceThreshold = 0.24; + static const kConservativeDistanceThreshold = 0.06; // singleton pattern FaceClusteringService._privateConstructor(); @@ -100,31 +121,11 @@ class FaceClusteringService { try { switch (function) { case ClusterOperation.linearIncrementalClustering: - final input = args['input'] as Map; - final fileIDToCreationTime = - args['fileIDToCreationTime'] as Map?; - final distanceThreshold = args['distanceThreshold'] as double; - final offset = args['offset'] as int?; - final result = FaceClusteringService._runLinearClustering( - input, - fileIDToCreationTime: fileIDToCreationTime, - distanceThreshold: distanceThreshold, - offset: offset, - ); + final result = FaceClusteringService.runLinearClustering(args); sendPort.send(result); break; case ClusterOperation.dbscanClustering: - final input = args['input'] as Map; - final fileIDToCreationTime = - args['fileIDToCreationTime'] as Map?; - final eps = args['eps'] as double; - final minPts = args['minPts'] as int; - final result = FaceClusteringService._runDbscanClustering( - input, - fileIDToCreationTime: fileIDToCreationTime, - eps: eps, - minPts: minPts, - ); + final result = FaceClusteringService._runDbscanClustering(args); sendPort.send(result); break; } @@ -194,16 +195,19 @@ class FaceClusteringService { _inactivityTimer?.cancel(); } - /// Runs the clustering algorithm on the given [input], in an isolate. + /// Runs the clustering algorithm [runLinearClustering] on the given [input], in an isolate. /// /// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset. /// /// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can less less deterministic. - Future?> predictLinear( - Map input, { + Future predictLinear( + Set input, { Map? fileIDToCreationTime, double distanceThreshold = kRecommendedDistanceThreshold, + double conservativeDistanceThreshold = kConservativeDistanceThreshold, + bool useDynamicThreshold = true, int? offset, + required Map oldClusterSummaries, }) async { if (input.isEmpty) { _logger.warning( @@ -225,20 +229,23 @@ class FaceClusteringService { final stopwatchClustering = Stopwatch()..start(); // final Map faceIdToCluster = // await _runLinearClusteringInComputer(input); - final Map faceIdToCluster = await _runInIsolate( + final ClusteringResult? faceIdToCluster = await _runInIsolate( ( ClusterOperation.linearIncrementalClustering, { 'input': input, 'fileIDToCreationTime': fileIDToCreationTime, 'distanceThreshold': distanceThreshold, + 'conservativeDistanceThreshold': conservativeDistanceThreshold, + 'useDynamicThreshold': useDynamicThreshold, 'offset': offset, + 'oldClusterSummaries': oldClusterSummaries, } ), ); // return _runLinearClusteringInComputer(input); _logger.info( - 'Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds', + 'predictLinear Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds', ); isRunning = false; @@ -250,6 +257,142 @@ class FaceClusteringService { } } + /// Runs the clustering algorithm [runLinearClustering] on the given [input], in computer, without any dynamic thresholding + Future predictLinearComputer( + Map input, { + Map? fileIDToCreationTime, + double distanceThreshold = kRecommendedDistanceThreshold, + }) async { + if (input.isEmpty) { + _logger.warning( + "Linear Clustering dataset of embeddings is empty, returning empty list.", + ); + return null; + } + + // Clustering inside the isolate + _logger.info( + "Start Linear clustering on ${input.length} embeddings inside computer isolate", + ); + + try { + final clusteringInput = input + .map((key, value) { + return MapEntry( + key, + FaceInfoForClustering( + faceID: key, + embeddingBytes: value, + faceScore: kMinimumQualityFaceScore + 0.01, + blurValue: kLapacianDefault, + ), + ); + }) + .values + .toSet(); + final startTime = DateTime.now(); + final faceIdToCluster = await _computer.compute( + runLinearClustering, + param: { + "input": clusteringInput, + "fileIDToCreationTime": fileIDToCreationTime, + "distanceThreshold": distanceThreshold, + "conservativeDistanceThreshold": distanceThreshold, + "useDynamicThreshold": false, + }, + taskName: "createImageEmbedding", + ) as ClusteringResult; + final endTime = DateTime.now(); + _logger.info( + "Linear Clustering took: ${endTime.difference(startTime).inMilliseconds}ms", + ); + return faceIdToCluster; + } catch (e, s) { + _logger.severe(e, s); + rethrow; + } + } + + /// Runs the clustering algorithm [runCompleteClustering] on the given [input], in computer. + /// + /// WARNING: Only use on small datasets, as it is not optimized for large datasets. + Future> predictCompleteComputer( + Map input, { + Map? fileIDToCreationTime, + double distanceThreshold = kRecommendedDistanceThreshold, + double mergeThreshold = 0.30, + }) async { + if (input.isEmpty) { + _logger.warning( + "Complete Clustering dataset of embeddings is empty, returning empty list.", + ); + return {}; + } + + // Clustering inside the isolate + _logger.info( + "Start Complete clustering on ${input.length} embeddings inside computer isolate", + ); + + try { + final startTime = DateTime.now(); + final faceIdToCluster = await _computer.compute( + runCompleteClustering, + param: { + "input": input, + "fileIDToCreationTime": fileIDToCreationTime, + "distanceThreshold": distanceThreshold, + "mergeThreshold": mergeThreshold, + }, + taskName: "createImageEmbedding", + ) as Map; + final endTime = DateTime.now(); + _logger.info( + "Complete Clustering took: ${endTime.difference(startTime).inMilliseconds}ms", + ); + return faceIdToCluster; + } catch (e, s) { + _logger.severe(e, s); + rethrow; + } + } + + Future?> predictWithinClusterComputer( + Map input, { + Map? fileIDToCreationTime, + double distanceThreshold = kRecommendedDistanceThreshold, + }) async { + _logger.info( + '`predictWithinClusterComputer` called with ${input.length} faces and distance threshold $distanceThreshold', + ); + try { + if (input.length < 100) { + final mergeThreshold = distanceThreshold + 0.06; + _logger.info( + 'Running complete clustering on ${input.length} faces with distance threshold $mergeThreshold', + ); + return predictCompleteComputer( + input, + fileIDToCreationTime: fileIDToCreationTime, + mergeThreshold: mergeThreshold, + ); + } else { + _logger.info( + 'Running linear clustering on ${input.length} faces with distance threshold $distanceThreshold', + ); + final clusterResult = await predictLinearComputer( + input, + fileIDToCreationTime: fileIDToCreationTime, + distanceThreshold: distanceThreshold, + ); + return clusterResult?.newFaceIdToCluster; + } + } catch (e, s) { + _logger.severe(e, s); + rethrow; + } + } + Future>> predictDbscan( Map input, { Map? fileIDToCreationTime, @@ -299,29 +442,42 @@ class FaceClusteringService { return clusterFaceIDs; } - static Map _runLinearClustering( - Map x, { - Map? fileIDToCreationTime, - double distanceThreshold = kRecommendedDistanceThreshold, - int? offset, - }) { + static ClusteringResult? runLinearClustering(Map args) { + // final input = args['input'] as Map; + final input = args['input'] as Set; + final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; + final distanceThreshold = args['distanceThreshold'] as double; + final conservativeDistanceThreshold = + args['conservativeDistanceThreshold'] as double; + final useDynamicThreshold = args['useDynamicThreshold'] as bool; + final offset = args['offset'] as int?; + final oldClusterSummaries = + args['oldClusterSummaries'] as Map?; + log( - "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces", + "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces", ); // Organize everything into a list of FaceInfo objects final List faceInfos = []; - for (final entry in x.entries) { + for (final face in input) { faceInfos.add( FaceInfo( - faceID: entry.key, + faceID: face.faceID, + faceScore: face.faceScore, + blurValue: face.blurValue, + badFace: face.faceScore < kMinimumQualityFaceScore || + face.blurValue < kLaplacianSoftThreshold || + (face.blurValue < kLaplacianVerySoftThreshold && + face.faceScore < kMediumQualityFaceScore) || + face.isSideways, vEmbedding: Vector.fromList( - EVector.fromBuffer(entry.value.$2).values, + EVector.fromBuffer(face.embeddingBytes).values, dtype: DType.float32, ), - clusterId: entry.value.$1, + clusterId: face.clusterId, fileCreationTime: - fileIDToCreationTime?[getFileIdFromFaceId(entry.key)], + fileIDToCreationTime?[getFileIdFromFaceId(face.faceID)], ), ); } @@ -351,19 +507,21 @@ class FaceClusteringService { facesWithClusterID.add(faceInfo); } } + final alreadyClusteredCount = facesWithClusterID.length; final sortedFaceInfos = []; sortedFaceInfos.addAll(facesWithClusterID); sortedFaceInfos.addAll(facesWithoutClusterID); log( - "[ClusterIsolate] ${DateTime.now()} Clustering ${facesWithoutClusterID.length} new faces without clusterId, and ${facesWithClusterID.length} faces with clusterId", + "[ClusterIsolate] ${DateTime.now()} Clustering ${facesWithoutClusterID.length} new faces without clusterId, and $alreadyClusteredCount faces with clusterId", ); // Make sure the first face has a clusterId final int totalFaces = sortedFaceInfos.length; + int dynamicThresholdCount = 0; if (sortedFaceInfos.isEmpty) { - return {}; + return null; } // Start actual clustering @@ -377,7 +535,6 @@ class FaceClusteringService { sortedFaceInfos[0].clusterId = clusterID; clusterID++; } - final Map newFaceIdToCluster = {}; final stopwatchClustering = Stopwatch()..start(); for (int i = 1; i < totalFaces; i++) { // Incremental clustering, so we can skip faces that already have a clusterId @@ -388,6 +545,15 @@ class FaceClusteringService { int closestIdx = -1; double closestDistance = double.infinity; + late double thresholdValue; + if (useDynamicThreshold) { + thresholdValue = sortedFaceInfos[i].badFace! + ? conservativeDistanceThreshold + : distanceThreshold; + if (sortedFaceInfos[i].badFace!) dynamicThresholdCount++; + } else { + thresholdValue = distanceThreshold; + } if (i % 250 == 0) { log("[ClusterIsolate] ${DateTime.now()} Processed ${offset != null ? i + offset : i} faces"); } @@ -405,18 +571,16 @@ class FaceClusteringService { ); } if (distance < closestDistance) { + if (sortedFaceInfos[j].badFace! && + distance > conservativeDistanceThreshold) { + continue; + } closestDistance = distance; closestIdx = j; - // if (distance < distanceThreshold) { - // if (sortedFaceInfos[j].faceID.startsWith("14914702") || - // sortedFaceInfos[j].faceID.startsWith("15488756")) { - // log('[XXX] faceIDs: ${sortedFaceInfos[j].faceID} and ${sortedFaceInfos[i].faceID} with distance $distance'); - // } - // } } } - if (closestDistance < distanceThreshold) { + if (closestDistance < thresholdValue) { if (sortedFaceInfos[closestIdx].clusterId == null) { // Ideally this should never happen, but just in case log it log( @@ -424,42 +588,99 @@ class FaceClusteringService { ); clusterID++; sortedFaceInfos[closestIdx].clusterId = clusterID; - newFaceIdToCluster[sortedFaceInfos[closestIdx].faceID] = clusterID; } - // if (sortedFaceInfos[i].faceID.startsWith("14914702") || - // sortedFaceInfos[i].faceID.startsWith("15488756")) { - // log( - // "[XXX] [ClusterIsolate] ${DateTime.now()} Found similar face ${sortedFaceInfos[i].faceID} to ${sortedFaceInfos[closestIdx].faceID} with distance $closestDistance", - // ); - // } sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId; - newFaceIdToCluster[sortedFaceInfos[i].faceID] = - sortedFaceInfos[closestIdx].clusterId!; } else { - // if (sortedFaceInfos[i].faceID.startsWith("14914702") || - // sortedFaceInfos[i].faceID.startsWith("15488756")) { - // log( - // "[XXX] [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID for face ${sortedFaceInfos[i].faceID}", - // ); - // } clusterID++; sortedFaceInfos[i].clusterId = clusterID; - newFaceIdToCluster[sortedFaceInfos[i].faceID] = clusterID; } } + // Finally, assign the new clusterId to the faces + final Map newFaceIdToCluster = {}; + final newClusteredFaceInfos = + sortedFaceInfos.sublist(alreadyClusteredCount); + for (final faceInfo in newClusteredFaceInfos) { + newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!; + } + stopwatchClustering.stop(); log( ' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms', ); + if (useDynamicThreshold) { + log( + "[ClusterIsolate] ${DateTime.now()} Dynamic thresholding: $dynamicThresholdCount faces had a low face score or low blur clarity", + ); + } + + // Now calculate the mean of the embeddings for each cluster and update the cluster summaries + Map? newClusterSummaries; + if (oldClusterSummaries != null) { + newClusterSummaries = FaceClusteringService.updateClusterSummaries( + oldSummary: oldClusterSummaries, + newFaceInfos: newClusteredFaceInfos, + ); + } // analyze the results FaceClusteringService._analyzeClusterResults(sortedFaceInfos); - return newFaceIdToCluster; + return ClusteringResult( + newFaceIdToCluster: newFaceIdToCluster, + newClusterSummaries: newClusterSummaries, + ); + } + + static Map updateClusterSummaries({ + required Map oldSummary, + required List newFaceInfos, + }) { + final calcSummariesStart = DateTime.now(); + final Map> newClusterIdToFaceInfos = {}; + for (final faceInfo in newFaceInfos) { + if (newClusterIdToFaceInfos.containsKey(faceInfo.clusterId!)) { + newClusterIdToFaceInfos[faceInfo.clusterId!]!.add(faceInfo); + } else { + newClusterIdToFaceInfos[faceInfo.clusterId!] = [faceInfo]; + } + } + + final Map newClusterSummaries = {}; + for (final clusterId in newClusterIdToFaceInfos.keys) { + final List newEmbeddings = newClusterIdToFaceInfos[clusterId]! + .map((faceInfo) => faceInfo.vEmbedding!) + .toList(); + final newCount = newEmbeddings.length; + if (oldSummary.containsKey(clusterId)) { + final oldMean = Vector.fromList( + EVector.fromBuffer(oldSummary[clusterId]!.$1).values, + dtype: DType.float32, + ); + final oldCount = oldSummary[clusterId]!.$2; + final oldEmbeddings = oldMean * oldCount; + newEmbeddings.add(oldEmbeddings); + final newMeanVector = + newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount); + newClusterSummaries[clusterId] = ( + EVector(values: newMeanVector.toList()).writeToBuffer(), + oldCount + newCount + ); + } else { + final newMeanVector = newEmbeddings.reduce((a, b) => a + b) / newCount; + newClusterSummaries[clusterId] = + (EVector(values: newMeanVector.toList()).writeToBuffer(), newCount); + } + } + log( + "[ClusterIsolate] ${DateTime.now()} Calculated cluster summaries in ${DateTime.now().difference(calcSummariesStart).inMilliseconds}ms", + ); + + return newClusterSummaries; } static void _analyzeClusterResults(List sortedFaceInfos) { + if (!kDebugMode) return; final stopwatch = Stopwatch()..start(); final Map faceIdToCluster = {}; @@ -517,14 +738,185 @@ class FaceClusteringService { ); } - static List> _runDbscanClustering( - Map x, { - Map? fileIDToCreationTime, - double eps = 0.3, - int minPts = 5, - }) { + static Map runCompleteClustering(Map args) { + final input = args['input'] as Map; + final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; + final distanceThreshold = args['distanceThreshold'] as double; + final mergeThreshold = args['mergeThreshold'] as double; + log( - "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces", + "[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering", + ); + + // Organize everything into a list of FaceInfo objects + final List faceInfos = []; + for (final entry in input.entries) { + faceInfos.add( + FaceInfo( + faceID: entry.key, + vEmbedding: Vector.fromList( + EVector.fromBuffer(entry.value).values, + dtype: DType.float32, + ), + fileCreationTime: + fileIDToCreationTime?[getFileIdFromFaceId(entry.key)], + ), + ); + } + + // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first + if (fileIDToCreationTime != null) { + faceInfos.sort((a, b) { + if (a.fileCreationTime == null && b.fileCreationTime == null) { + return 0; + } else if (a.fileCreationTime == null) { + return 1; + } else if (b.fileCreationTime == null) { + return -1; + } else { + return a.fileCreationTime!.compareTo(b.fileCreationTime!); + } + }); + } + + if (faceInfos.isEmpty) { + return {}; + } + final int totalFaces = faceInfos.length; + + // Start actual clustering + log( + "[CompleteClustering] ${DateTime.now()} Processing $totalFaces faces in one single round of complete clustering", + ); + + // set current epoch time as clusterID + int clusterID = DateTime.now().microsecondsSinceEpoch; + + // Start actual clustering + final Map newFaceIdToCluster = {}; + final stopwatchClustering = Stopwatch()..start(); + for (int i = 0; i < totalFaces; i++) { + if ((i + 1) % 250 == 0) { + log("[CompleteClustering] ${DateTime.now()} Processed ${i + 1} faces"); + } + if (faceInfos[i].clusterId != null) continue; + int closestIdx = -1; + double closestDistance = double.infinity; + for (int j = 0; j < totalFaces; j++) { + if (i == j) continue; + final double distance = + 1.0 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!); + if (distance < closestDistance) { + closestDistance = distance; + closestIdx = j; + } + } + + if (closestDistance < distanceThreshold) { + if (faceInfos[closestIdx].clusterId == null) { + clusterID++; + faceInfos[closestIdx].clusterId = clusterID; + } + faceInfos[i].clusterId = faceInfos[closestIdx].clusterId!; + } else { + clusterID++; + faceInfos[i].clusterId = clusterID; + } + } + + // Now calculate the mean of the embeddings for each cluster + final Map> clusterIdToFaceInfos = {}; + for (final faceInfo in faceInfos) { + if (clusterIdToFaceInfos.containsKey(faceInfo.clusterId)) { + clusterIdToFaceInfos[faceInfo.clusterId]!.add(faceInfo); + } else { + clusterIdToFaceInfos[faceInfo.clusterId!] = [faceInfo]; + } + } + final Map clusterIdToMeanEmbeddingAndWeight = {}; + for (final clusterId in clusterIdToFaceInfos.keys) { + final List embeddings = clusterIdToFaceInfos[clusterId]! + .map((faceInfo) => faceInfo.vEmbedding!) + .toList(); + final count = clusterIdToFaceInfos[clusterId]!.length; + final Vector meanEmbedding = embeddings.reduce((a, b) => a + b) / count; + clusterIdToMeanEmbeddingAndWeight[clusterId] = (meanEmbedding, count); + } + + // Now merge the clusters that are close to each other, based on mean embedding + final List<(int, int)> mergedClustersList = []; + final List clusterIds = + clusterIdToMeanEmbeddingAndWeight.keys.toList(); + log(' [CompleteClustering] ${DateTime.now()} ${clusterIds.length} clusters found, now checking for merges'); + while (true) { + if (clusterIds.length < 2) break; + double distance = double.infinity; + (int, int) clusterIDsToMerge = (-1, -1); + for (int i = 0; i < clusterIds.length; i++) { + for (int j = 0; j < clusterIds.length; j++) { + if (i == j) continue; + final double newDistance = 1.0 - + clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1.dot( + clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1, + ); + if (newDistance < distance) { + distance = newDistance; + clusterIDsToMerge = (clusterIds[i], clusterIds[j]); + } + } + } + if (distance < mergeThreshold) { + mergedClustersList.add(clusterIDsToMerge); + final clusterID1 = clusterIDsToMerge.$1; + final clusterID2 = clusterIDsToMerge.$2; + final mean1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$1; + final mean2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$1; + final count1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$2; + final count2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$2; + final weight1 = count1 / (count1 + count2); + final weight2 = count2 / (count1 + count2); + clusterIdToMeanEmbeddingAndWeight[clusterID1] = ( + mean1 * weight1 + mean2 * weight2, + count1 + count2, + ); + clusterIdToMeanEmbeddingAndWeight.remove(clusterID2); + clusterIds.remove(clusterID2); + } else { + break; + } + } + log(' [CompleteClustering] ${DateTime.now()} ${mergedClustersList.length} clusters merged'); + + // Now assign the new clusterId to the faces + for (final faceInfo in faceInfos) { + for (final mergedClusters in mergedClustersList) { + if (faceInfo.clusterId == mergedClusters.$2) { + faceInfo.clusterId = mergedClusters.$1; + } + } + } + + // Finally, assign the new clusterId to the faces + for (final faceInfo in faceInfos) { + newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!; + } + + stopwatchClustering.stop(); + log( + ' [CompleteClustering] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms', + ); + + return newFaceIdToCluster; + } + + static List> _runDbscanClustering(Map args) { + final input = args['input'] as Map; + final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; + final eps = args['eps'] as double; + final minPts = args['minPts'] as int; + + log( + "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces", ); final DBSCAN dbscan = DBSCAN( @@ -535,7 +927,7 @@ class FaceClusteringService { // Organize everything into a list of FaceInfo objects final List faceInfos = []; - for (final entry in x.entries) { + for (final entry in input.entries) { faceInfos.add( FaceInfo( faceID: entry.key, diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart new file mode 100644 index 000000000..273d85da5 --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart @@ -0,0 +1,19 @@ +import "dart:typed_data" show Uint8List; + +class FaceInfoForClustering { + final String faceID; + final int? clusterId; + final Uint8List embeddingBytes; + final double faceScore; + final double blurValue; + final bool isSideways; + + FaceInfoForClustering({ + required this.faceID, + this.clusterId, + required this.embeddingBytes, + required this.faceScore, + required this.blurValue, + this.isSideways = false, + }); +} diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart index 7aa088141..de8535c87 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart @@ -1,7 +1,24 @@ -import 'dart:math' show sqrt, pow; +import 'dart:math' show max, min, pow, sqrt; import "package:photos/face/model/dimension.dart"; +enum FaceDirection { left, right, straight } + +extension FaceDirectionExtension on FaceDirection { + String toDirectionString() { + switch (this) { + case FaceDirection.left: + return 'Left'; + case FaceDirection.right: + return 'Right'; + case FaceDirection.straight: + return 'Straight'; + default: + throw Exception('Unknown FaceDirection'); + } + } +} + abstract class Detection { final double score; @@ -16,6 +33,7 @@ abstract class Detection { String toString(); } +@Deprecated('Old method only used in other deprecated methods') extension BBoxExtension on List { void roundBoxToDouble() { final widthRounded = (this[2] - this[0]).roundToDouble(); @@ -425,6 +443,37 @@ class FaceDetectionAbsolute extends Detection { /// The height of the bounding box of the face detection, in number of pixels, range [0, imageHeight]. double get height => yMaxBox - yMinBox; + + FaceDirection getFaceDirection() { + final double eyeDistanceX = (rightEye[0] - leftEye[0]).abs(); + final double eyeDistanceY = (rightEye[1] - leftEye[1]).abs(); + final double mouthDistanceY = (rightMouth[1] - leftMouth[1]).abs(); + + final bool faceIsUpright = + (max(leftEye[1], rightEye[1]) + 0.5 * eyeDistanceY < nose[1]) && + (nose[1] + 0.5 * mouthDistanceY < min(leftMouth[1], rightMouth[1])); + + final bool noseStickingOutLeft = (nose[0] < min(leftEye[0], rightEye[0])) && + (nose[0] < min(leftMouth[0], rightMouth[0])); + final bool noseStickingOutRight = + (nose[0] > max(leftEye[0], rightEye[0])) && + (nose[0] > max(leftMouth[0], rightMouth[0])); + + final bool noseCloseToLeftEye = + (nose[0] - leftEye[0]).abs() < 0.2 * eyeDistanceX; + final bool noseCloseToRightEye = + (nose[0] - rightEye[0]).abs() < 0.2 * eyeDistanceX; + + // if (faceIsUpright && (noseStickingOutLeft || noseCloseToLeftEye)) { + if (noseStickingOutLeft || (faceIsUpright && noseCloseToLeftEye)) { + return FaceDirection.left; + // } else if (faceIsUpright && (noseStickingOutRight || noseCloseToRightEye)) { + } else if (noseStickingOutRight || (faceIsUpright && noseCloseToRightEye)) { + return FaceDirection.right; + } + + return FaceDirection.straight; + } } List relativeToAbsoluteDetections({ diff --git a/mobile/lib/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart b/mobile/lib/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart index 43f6b252d..9c8d2d8c8 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart @@ -1,4 +1,5 @@ import 'package:logging/logging.dart'; +import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart"; import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; class BlurDetectionService { @@ -11,9 +12,11 @@ class BlurDetectionService { Future<(bool, double)> predictIsBlurGrayLaplacian( List> grayImage, { - int threshold = kLaplacianThreshold, + int threshold = kLaplacianHardThreshold, + FaceDirection faceDirection = FaceDirection.straight, }) async { - final List> laplacian = _applyLaplacian(grayImage); + final List> laplacian = + _applyLaplacian(grayImage, faceDirection: faceDirection); final double variance = _calculateVariance(laplacian); _logger.info('Variance: $variance'); return (variance < threshold, variance); @@ -46,43 +49,80 @@ class BlurDetectionService { return variance; } - List> _padImage(List> image) { + List> _padImage( + List> image, { + int removeSideColumns = 56, + FaceDirection faceDirection = FaceDirection.straight, + }) { + // Exception is removeSideColumns is not even + if (removeSideColumns % 2 != 0) { + throw Exception('removeSideColumns must be even'); + } + final int numRows = image.length; final int numCols = image[0].length; + final int paddedNumCols = numCols + 2 - removeSideColumns; + final int paddedNumRows = numRows + 2; // Create a new matrix with extra padding final List> paddedImage = List.generate( - numRows + 2, - (i) => List.generate(numCols + 2, (j) => 0, growable: false), + paddedNumRows, + (i) => List.generate( + paddedNumCols, + (j) => 0, + growable: false, + ), growable: false, ); - // Copy original image into the center of the padded image - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numCols; j++) { - paddedImage[i + 1][j + 1] = image[i][j]; + // Copy original image into the center of the padded image, taking into account the face direction + if (faceDirection == FaceDirection.straight) { + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < (paddedNumCols - 2); j++) { + paddedImage[i + 1][j + 1] = + image[i][j + (removeSideColumns / 2).round()]; + } + } + // If the face is facing left, we only take the right side of the face image + } else if (faceDirection == FaceDirection.left) { + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < (paddedNumCols - 2); j++) { + paddedImage[i + 1][j + 1] = image[i][j + removeSideColumns]; + } + } + // If the face is facing right, we only take the left side of the face image + } else if (faceDirection == FaceDirection.right) { + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < (paddedNumCols - 2); j++) { + paddedImage[i + 1][j + 1] = image[i][j]; + } } } // Reflect padding // Top and bottom rows - for (int j = 1; j <= numCols; j++) { + for (int j = 1; j <= (paddedNumCols - 2); j++) { paddedImage[0][j] = paddedImage[2][j]; // Top row paddedImage[numRows + 1][j] = paddedImage[numRows - 1][j]; // Bottom row } // Left and right columns for (int i = 0; i < numRows + 2; i++) { paddedImage[i][0] = paddedImage[i][2]; // Left column - paddedImage[i][numCols + 1] = paddedImage[i][numCols - 1]; // Right column + paddedImage[i][paddedNumCols - 1] = + paddedImage[i][paddedNumCols - 3]; // Right column } return paddedImage; } - List> _applyLaplacian(List> image) { - final List> paddedImage = _padImage(image); - final int numRows = image.length; - final int numCols = image[0].length; + List> _applyLaplacian( + List> image, { + FaceDirection faceDirection = FaceDirection.straight, + }) { + final List> paddedImage = + _padImage(image, faceDirection: faceDirection); + final int numRows = paddedImage.length - 2; + final int numCols = paddedImage[0].length - 2; final List> outputImage = List.generate( numRows, (i) => List.generate(numCols, (j) => 0, growable: false), diff --git a/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart b/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart index 8fb1838c0..b1f2f6018 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart @@ -1,13 +1,17 @@ import 'package:photos/services/machine_learning/face_ml/face_detection/face_detection_service.dart'; /// Blur detection threshold -const kLaplacianThreshold = 15; +const kLaplacianHardThreshold = 15; +const kLaplacianSoftThreshold = 100; +const kLaplacianVerySoftThreshold = 200; /// Default blur value const kLapacianDefault = 10000.0; /// The minimum score for a face to be considered a high quality face for clustering and person detection -const kMinHighQualityFaceScore = 0.80; +const kMinimumQualityFaceScore = 0.80; +const kMediumQualityFaceScore = 0.85; +const kHighQualityFaceScore = 0.90; -/// The minimum score for a face to be detected, regardless of quality. Use [kMinHighQualityFaceScore] for high quality faces. +/// The minimum score for a face to be detected, regardless of quality. Use [kMinimumQualityFaceScore] for high quality faces. const kMinFaceDetectionScore = FaceDetectionService.kMinScoreSigmoidThreshold; diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart index 5ad0c4eee..19f954013 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart @@ -1,284 +1,18 @@ import "dart:convert" show jsonEncode, jsonDecode; -import "package:flutter/material.dart" show debugPrint, immutable; +import "package:flutter/material.dart" show immutable; import "package:logging/logging.dart"; import "package:photos/face/model/dimension.dart"; import "package:photos/models/file/file.dart"; import 'package:photos/models/ml/ml_typedefs.dart'; import "package:photos/models/ml/ml_versions.dart"; import 'package:photos/services/machine_learning/face_ml/face_alignment/alignment_result.dart'; -import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart'; import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; import 'package:photos/services/machine_learning/face_ml/face_ml_methods.dart'; final _logger = Logger('ClusterResult_FaceMlResult'); -// TODO: should I add [faceMlVersion] and [clusterMlVersion] to the [ClusterResult] class? -@Deprecated('We are now just storing the cluster results directly in DB') -class ClusterResult { - final int personId; - String? userDefinedName; - bool get hasUserDefinedName => userDefinedName != null; - - String _thumbnailFaceId; - bool thumbnailFaceIdIsUserDefined; - - final List _fileIds; - final List _faceIds; - - final Embedding medoid; - double medoidDistanceThreshold; - - List get uniqueFileIds => _fileIds.toSet().toList(); - List get fileIDsIncludingPotentialDuplicates => _fileIds; - - List get faceIDs => _faceIds; - - String get thumbnailFaceId => _thumbnailFaceId; - - int get thumbnailFileId => getFileIdFromFaceId(_thumbnailFaceId); - - /// Sets the thumbnail faceId to the given faceId. - /// Throws an exception if the faceId is not in the list of faceIds. - set setThumbnailFaceId(String faceId) { - if (!_faceIds.contains(faceId)) { - throw Exception( - "The faceId $faceId is not in the list of faceIds: $faceId", - ); - } - _thumbnailFaceId = faceId; - thumbnailFaceIdIsUserDefined = true; - } - - /// Sets the [userDefinedName] to the given [customName] - set setUserDefinedName(String customName) { - userDefinedName = customName; - } - - int get clusterSize => _fileIds.toSet().length; - - ClusterResult({ - required this.personId, - required String thumbnailFaceId, - required List fileIds, - required List faceIds, - required this.medoid, - required this.medoidDistanceThreshold, - this.userDefinedName, - this.thumbnailFaceIdIsUserDefined = false, - }) : _thumbnailFaceId = thumbnailFaceId, - _faceIds = faceIds, - _fileIds = fileIds; - - void addFileIDsAndFaceIDs(List fileIDs, List faceIDs) { - assert(fileIDs.length == faceIDs.length); - _fileIds.addAll(fileIDs); - _faceIds.addAll(faceIDs); - } - - // TODO: Consider if we should recalculated the medoid and threshold when deleting or adding a file from the cluster - int removeFileId(int fileId) { - assert(_fileIds.length == _faceIds.length); - if (!_fileIds.contains(fileId)) { - throw Exception( - "The fileId $fileId is not in the list of fileIds: $fileId, so it's not in the cluster and cannot be removed.", - ); - } - - int removedCount = 0; - for (var i = 0; i < _fileIds.length; i++) { - if (_fileIds[i] == fileId) { - assert(getFileIdFromFaceId(_faceIds[i]) == fileId); - _fileIds.removeAt(i); - _faceIds.removeAt(i); - debugPrint( - "Removed fileId $fileId from cluster $personId at index ${i + removedCount}}", - ); - i--; // Adjust index due to removal - removedCount++; - } - } - - _ensureClusterSizeIsAboveMinimum(); - - return removedCount; - } - - int addFileID(int fileID) { - assert(_fileIds.length == _faceIds.length); - if (_fileIds.contains(fileID)) { - return 0; - } - - _fileIds.add(fileID); - _faceIds.add(FaceDetectionRelative.toFaceIDEmpty(fileID: fileID)); - - return 1; - } - - void ensureThumbnailFaceIdIsInCluster() { - if (!_faceIds.contains(_thumbnailFaceId)) { - _thumbnailFaceId = _faceIds[0]; - } - } - - void _ensureClusterSizeIsAboveMinimum() { - if (clusterSize < minimumClusterSize) { - throw Exception( - "Cluster size is below minimum cluster size of $minimumClusterSize", - ); - } - } - - Map _toJson() => { - 'personId': personId, - 'thumbnailFaceId': _thumbnailFaceId, - 'fileIds': _fileIds, - 'faceIds': _faceIds, - 'medoid': medoid, - 'medoidDistanceThreshold': medoidDistanceThreshold, - if (userDefinedName != null) 'userDefinedName': userDefinedName, - 'thumbnailFaceIdIsUserDefined': thumbnailFaceIdIsUserDefined, - }; - - String toJsonString() => jsonEncode(_toJson()); - - static ClusterResult _fromJson(Map json) { - return ClusterResult( - personId: json['personId'] ?? -1, - thumbnailFaceId: json['thumbnailFaceId'] ?? '', - fileIds: - (json['fileIds'] as List?)?.map((item) => item as int).toList() ?? [], - faceIds: - (json['faceIds'] as List?)?.map((item) => item as String).toList() ?? - [], - medoid: - (json['medoid'] as List?)?.map((item) => item as double).toList() ?? - [], - medoidDistanceThreshold: json['medoidDistanceThreshold'] ?? 0, - userDefinedName: json['userDefinedName'], - thumbnailFaceIdIsUserDefined: - json['thumbnailFaceIdIsUserDefined'] as bool, - ); - } - - static ClusterResult fromJsonString(String jsonString) { - return _fromJson(jsonDecode(jsonString)); - } -} - -class ClusterResultBuilder { - int personId = -1; - String? userDefinedName; - String thumbnailFaceId = ''; - bool thumbnailFaceIdIsUserDefined = false; - - List fileIds = []; - List faceIds = []; - - List embeddings = []; - Embedding medoid = []; - double medoidDistanceThreshold = 0; - bool medoidAndThresholdCalculated = false; - final int k = 5; - - ClusterResultBuilder.createFromIndices({ - required List clusterIndices, - required List labels, - required List allEmbeddings, - required List allFileIds, - required List allFaceIds, - }) { - final clusteredFileIds = - clusterIndices.map((fileIndex) => allFileIds[fileIndex]).toList(); - final clusteredFaceIds = - clusterIndices.map((fileIndex) => allFaceIds[fileIndex]).toList(); - final clusteredEmbeddings = - clusterIndices.map((fileIndex) => allEmbeddings[fileIndex]).toList(); - personId = labels[clusterIndices[0]]; - fileIds = clusteredFileIds; - faceIds = clusteredFaceIds; - thumbnailFaceId = faceIds[0]; - embeddings = clusteredEmbeddings; - } - - void calculateAndSetMedoidAndThreshold() { - if (embeddings.isEmpty) { - throw Exception("Cannot calculate medoid and threshold for empty list"); - } - - // Calculate the medoid and threshold - final (tempMedoid, distanceThreshold) = - _calculateMedoidAndDistanceTreshold(embeddings); - - // Update the medoid - medoid = List.from(tempMedoid); - - // Update the medoidDistanceThreshold as the distance of the medoid to its k-th nearest neighbor - medoidDistanceThreshold = distanceThreshold; - - medoidAndThresholdCalculated = true; - } - - (List, double) _calculateMedoidAndDistanceTreshold( - List> embeddings, - ) { - double minDistance = double.infinity; - List? medoid; - - // Calculate the distance between all pairs - for (int i = 0; i < embeddings.length; ++i) { - double totalDistance = 0; - for (int j = 0; j < embeddings.length; ++j) { - if (i != j) { - totalDistance += cosineDistance(embeddings[i], embeddings[j]); - - // Break early if we already exceed minDistance - if (totalDistance > minDistance) { - break; - } - } - } - - // Find the minimum total distance - if (totalDistance < minDistance) { - minDistance = totalDistance; - medoid = embeddings[i]; - } - } - - // Now, calculate k-th nearest neighbor for the medoid - final List distancesToMedoid = []; - for (List embedding in embeddings) { - if (embedding != medoid) { - distancesToMedoid.add(cosineDistance(medoid!, embedding)); - } - } - distancesToMedoid.sort(); - // TODO: empirically find the best k. Probably it should be dynamic in some way, so for instance larger for larger clusters and smaller for smaller clusters, especially since there are a lot of really small clusters and a few really large ones. - final double kthDistance = distancesToMedoid[ - distancesToMedoid.length >= k ? k - 1 : distancesToMedoid.length - 1]; - - return (medoid!, kthDistance); - } - - void changeThumbnailFaceId(String faceId) { - if (!faceIds.contains(faceId)) { - throw Exception( - "The faceId $faceId is not in the list of faceIds: $faceIds", - ); - } - thumbnailFaceId = faceId; - } - - void addFileIDsAndFaceIDs(List addedFileIDs, List addedFaceIDs) { - assert(addedFileIDs.length == addedFaceIDs.length); - fileIds.addAll(addedFileIDs); - faceIds.addAll(addedFaceIDs); - } -} - @immutable class FaceMlResult { final int fileId; @@ -504,7 +238,7 @@ class FaceResult { final int fileId; final String faceId; - bool get isBlurry => blurValue < kLaplacianThreshold; + bool get isBlurry => blurValue < kLaplacianHardThreshold; const FaceResult({ required this.detection, @@ -545,7 +279,7 @@ class FaceResultBuilder { int fileId = -1; String faceId = ''; - bool get isBlurry => blurValue < kLaplacianThreshold; + bool get isBlurry => blurValue < kLaplacianHardThreshold; FaceResultBuilder({ required this.fileId, diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart index 4efcc444d..3df9b3056 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart @@ -204,83 +204,13 @@ class FaceMlService { try { switch (function) { case FaceMlOperation.analyzeImage: - final int enteFileID = args["enteFileID"] as int; - final String imagePath = args["filePath"] as String; - final int faceDetectionAddress = - args["faceDetectionAddress"] as int; - final int faceEmbeddingAddress = - args["faceEmbeddingAddress"] as int; - - final resultBuilder = - FaceMlResultBuilder.fromEnteFileID(enteFileID); - + final time = DateTime.now(); + final FaceMlResult result = + await FaceMlService.analyzeImageSync(args); dev.log( - "Start analyzing image with uploadedFileID: $enteFileID inside the isolate", + "`analyzeImageSync` function executed in ${DateTime.now().difference(time).inMilliseconds} ms", ); - final stopwatchTotal = Stopwatch()..start(); - final stopwatch = Stopwatch()..start(); - - // Decode the image once to use for both face detection and alignment - final imageData = await File(imagePath).readAsBytes(); - final image = await decodeImageFromData(imageData); - final ByteData imgByteData = await getByteDataFromImage(image); - dev.log('Reading and decoding image took ' - '${stopwatch.elapsedMilliseconds} ms'); - stopwatch.reset(); - - // Get the faces - final List faceDetectionResult = - await FaceMlService.detectFacesSync( - image, - imgByteData, - faceDetectionAddress, - resultBuilder: resultBuilder, - ); - - dev.log( - "${faceDetectionResult.length} faces detected with scores ${faceDetectionResult.map((e) => e.score).toList()}: completed `detectFacesSync` function, in " - "${stopwatch.elapsedMilliseconds} ms"); - - // If no faces were detected, return a result with no faces. Otherwise, continue. - if (faceDetectionResult.isEmpty) { - dev.log( - "No faceDetectionResult, Completed analyzing image with uploadedFileID $enteFileID, in " - "${stopwatch.elapsedMilliseconds} ms"); - sendPort.send(resultBuilder.buildNoFaceDetected().toJsonString()); - break; - } - - stopwatch.reset(); - // Align the faces - final Float32List faceAlignmentResult = - await FaceMlService.alignFacesSync( - image, - imgByteData, - faceDetectionResult, - resultBuilder: resultBuilder, - ); - - dev.log("Completed `alignFacesSync` function, in " - "${stopwatch.elapsedMilliseconds} ms"); - - stopwatch.reset(); - // Get the embeddings of the faces - final embeddings = await FaceMlService.embedFacesSync( - faceAlignmentResult, - faceEmbeddingAddress, - resultBuilder: resultBuilder, - ); - - dev.log("Completed `embedFacesSync` function, in " - "${stopwatch.elapsedMilliseconds} ms"); - - stopwatch.stop(); - stopwatchTotal.stop(); - dev.log("Finished Analyze image (${embeddings.length} faces) with " - "uploadedFileID $enteFileID, in " - "${stopwatchTotal.elapsedMilliseconds} ms"); - - sendPort.send(resultBuilder.build().toJsonString()); + sendPort.send(result.toJsonString()); break; } } catch (e, stackTrace) { @@ -361,7 +291,7 @@ class FaceMlService { } Future clusterAllImages({ - double minFaceScore = kMinHighQualityFaceScore, + double minFaceScore = kMinimumQualityFaceScore, bool clusterInBuckets = true, }) async { _logger.info("`clusterAllImages()` called"); @@ -370,6 +300,10 @@ class FaceMlService { // Get a sense of the total number of faces in the database final int totalFaces = await FaceMLDataDB.instance .getTotalFaceCount(minFaceScore: minFaceScore); + + // Get the current cluster statistics + final Map oldClusterSummaries = + await FaceMLDataDB.instance.getAllClusterSummary(); if (clusterInBuckets) { // read the creation times from Files DB, in a map from fileID to creation time final fileIDToCreationTime = @@ -382,14 +316,14 @@ class FaceMlService { int bucket = 1; while (true) { - final faceIdToEmbeddingBucket = - await FaceMLDataDB.instance.getFaceEmbeddingMap( + final faceInfoForClustering = + await FaceMLDataDB.instance.getFaceInfoForClustering( minScore: minFaceScore, maxFaces: bucketSize, offset: offset, batchSize: batchSize, ); - if (faceIdToEmbeddingBucket.isEmpty) { + if (faceInfoForClustering.isEmpty) { _logger.warning( 'faceIdToEmbeddingBucket is empty, this should ideally not happen as it should have stopped earlier. offset: $offset, totalFaces: $totalFaces', ); @@ -402,20 +336,24 @@ class FaceMlService { break; } - final faceIdToCluster = + final clusteringResult = await FaceClusteringService.instance.predictLinear( - faceIdToEmbeddingBucket, + faceInfoForClustering, fileIDToCreationTime: fileIDToCreationTime, offset: offset, + oldClusterSummaries: oldClusterSummaries, ); - if (faceIdToCluster == null) { + if (clusteringResult == null) { _logger.warning("faceIdToCluster is null"); return; } - await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster); + await FaceMLDataDB.instance + .updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster); + await FaceMLDataDB.instance + .clusterSummaryUpdate(clusteringResult.newClusterSummaries!); _logger.info( - 'Done with clustering ${offset + faceIdToEmbeddingBucket.length} embeddings (${(100 * (offset + faceIdToEmbeddingBucket.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset', + 'Done with clustering ${offset + faceInfoForClustering.length} embeddings (${(100 * (offset + faceInfoForClustering.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset', ); if (offset + bucketSize >= totalFaces) { _logger.info('All faces clustered'); @@ -427,14 +365,14 @@ class FaceMlService { } else { // Read all the embeddings from the database, in a map from faceID to embedding final clusterStartTime = DateTime.now(); - final faceIdToEmbedding = - await FaceMLDataDB.instance.getFaceEmbeddingMap( + final faceInfoForClustering = + await FaceMLDataDB.instance.getFaceInfoForClustering( minScore: minFaceScore, maxFaces: totalFaces, ); final gotFaceEmbeddingsTime = DateTime.now(); _logger.info( - 'read embeddings ${faceIdToEmbedding.length} in ${gotFaceEmbeddingsTime.difference(clusterStartTime).inMilliseconds} ms', + 'read embeddings ${faceInfoForClustering.length} in ${gotFaceEmbeddingsTime.difference(clusterStartTime).inMilliseconds} ms', ); // Read the creation times from Files DB, in a map from fileID to creation time @@ -444,25 +382,29 @@ class FaceMlService { '${DateTime.now().difference(gotFaceEmbeddingsTime).inMilliseconds} ms'); // Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID - final faceIdToCluster = + final clusteringResult = await FaceClusteringService.instance.predictLinear( - faceIdToEmbedding, + faceInfoForClustering, fileIDToCreationTime: fileIDToCreationTime, + oldClusterSummaries: oldClusterSummaries, ); - if (faceIdToCluster == null) { + if (clusteringResult == null) { _logger.warning("faceIdToCluster is null"); return; } final clusterDoneTime = DateTime.now(); _logger.info( - 'done with clustering ${faceIdToEmbedding.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ', + 'done with clustering ${faceInfoForClustering.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ', ); // Store the updated clusterIDs in the database _logger.info( - 'Updating ${faceIdToCluster.length} FaceIDs with clusterIDs in the DB', + 'Updating ${clusteringResult.newFaceIdToCluster.length} FaceIDs with clusterIDs in the DB', ); - await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster); + await FaceMLDataDB.instance + .updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster); + await FaceMLDataDB.instance + .clusterSummaryUpdate(clusteringResult.newClusterSummaries!); _logger.info('Done updating FaceIDs with clusterIDs in the DB, in ' '${DateTime.now().difference(clusterDoneTime).inSeconds} seconds'); } @@ -875,6 +817,7 @@ class FaceMlService { } } + /// Analyzes the given image data by running the full pipeline for faces, using [analyzeImageSync] in the isolate. Future analyzeImageInSingleIsolate(EnteFile enteFile) async { _checkEnteFileForID(enteFile); await ensureInitialized(); @@ -931,6 +874,87 @@ class FaceMlService { return result; } + static Future analyzeImageSync(Map args) async { + try { + final int enteFileID = args["enteFileID"] as int; + final String imagePath = args["filePath"] as String; + final int faceDetectionAddress = args["faceDetectionAddress"] as int; + final int faceEmbeddingAddress = args["faceEmbeddingAddress"] as int; + + final resultBuilder = FaceMlResultBuilder.fromEnteFileID(enteFileID); + + dev.log( + "Start analyzing image with uploadedFileID: $enteFileID inside the isolate", + ); + final stopwatchTotal = Stopwatch()..start(); + final stopwatch = Stopwatch()..start(); + + // Decode the image once to use for both face detection and alignment + final imageData = await File(imagePath).readAsBytes(); + final image = await decodeImageFromData(imageData); + final ByteData imgByteData = await getByteDataFromImage(image); + dev.log('Reading and decoding image took ' + '${stopwatch.elapsedMilliseconds} ms'); + stopwatch.reset(); + + // Get the faces + final List faceDetectionResult = + await FaceMlService.detectFacesSync( + image, + imgByteData, + faceDetectionAddress, + resultBuilder: resultBuilder, + ); + + dev.log( + "${faceDetectionResult.length} faces detected with scores ${faceDetectionResult.map((e) => e.score).toList()}: completed `detectFacesSync` function, in " + "${stopwatch.elapsedMilliseconds} ms"); + + // If no faces were detected, return a result with no faces. Otherwise, continue. + if (faceDetectionResult.isEmpty) { + dev.log( + "No faceDetectionResult, Completed analyzing image with uploadedFileID $enteFileID, in " + "${stopwatch.elapsedMilliseconds} ms"); + return resultBuilder.buildNoFaceDetected(); + } + + stopwatch.reset(); + // Align the faces + final Float32List faceAlignmentResult = + await FaceMlService.alignFacesSync( + image, + imgByteData, + faceDetectionResult, + resultBuilder: resultBuilder, + ); + + dev.log("Completed `alignFacesSync` function, in " + "${stopwatch.elapsedMilliseconds} ms"); + + stopwatch.reset(); + // Get the embeddings of the faces + final embeddings = await FaceMlService.embedFacesSync( + faceAlignmentResult, + faceEmbeddingAddress, + resultBuilder: resultBuilder, + ); + + dev.log("Completed `embedFacesSync` function, in " + "${stopwatch.elapsedMilliseconds} ms"); + + stopwatch.stop(); + stopwatchTotal.stop(); + dev.log("Finished Analyze image (${embeddings.length} faces) with " + "uploadedFileID $enteFileID, in " + "${stopwatchTotal.elapsedMilliseconds} ms"); + + return resultBuilder.build(); + } catch (e, s) { + dev.log("Could not analyze image: \n e: $e \n s: $s"); + rethrow; + } + } + Future _getImagePathForML( EnteFile enteFile, { FileDataForML typeOfData = FileDataForML.fileData, diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index d89b007cc..95497a90d 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -5,6 +5,8 @@ import "package:flutter/foundation.dart"; import "package:logging/logging.dart"; import "package:photos/core/event_bus.dart"; import "package:photos/db/files_db.dart"; +// import "package:photos/events/files_updated_event.dart"; +// import "package:photos/events/local_photos_updated_event.dart"; import "package:photos/events/people_changed_event.dart"; import "package:photos/extensions/stop_watch.dart"; import "package:photos/face/db.dart"; @@ -115,17 +117,103 @@ class ClusterFeedbackService { List files, PersonEntity p, ) async { - await FaceMLDataDB.instance.removeFilesFromPerson(files, p.remoteID); - Bus.instance.fire(PeopleChangedEvent()); + try { + // Get the relevant faces to be removed + final faceIDs = await FaceMLDataDB.instance + .getFaceIDsForPerson(p.remoteID) + .then((iterable) => iterable.toList()); + faceIDs.retainWhere((faceID) { + final fileID = getFileIdFromFaceId(faceID); + return files.any((file) => file.uploadedFileID == fileID); + }); + final embeddings = + await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs); + + final fileIDToCreationTime = + await FilesDB.instance.getFileIDToCreationTime(); + + // Re-cluster within the deleted faces + final newFaceIdToClusterID = + await FaceClusteringService.instance.predictWithinClusterComputer( + embeddings, + fileIDToCreationTime: fileIDToCreationTime, + distanceThreshold: 0.20, + ); + if (newFaceIdToClusterID == null || newFaceIdToClusterID.isEmpty) { + return; + } + + // Update the deleted faces + await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID); + + // Make sure the deleted faces don't get suggested in the future + final notClusterIdToPersonId = {}; + for (final clusterId in newFaceIdToClusterID.values.toSet()) { + notClusterIdToPersonId[clusterId] = p.remoteID; + } + await FaceMLDataDB.instance + .bulkCaptureNotPersonFeedback(notClusterIdToPersonId); + + Bus.instance.fire(PeopleChangedEvent()); + return; + } catch (e, s) { + _logger.severe("Error in removeFilesFromPerson", e, s); + rethrow; + } } Future removeFilesFromCluster( List files, int clusterID, ) async { - await FaceMLDataDB.instance.removeFilesFromCluster(files, clusterID); - Bus.instance.fire(PeopleChangedEvent()); - return; + try { + // Get the relevant faces to be removed + final faceIDs = await FaceMLDataDB.instance + .getFaceIDsForCluster(clusterID) + .then((iterable) => iterable.toList()); + faceIDs.retainWhere((faceID) { + final fileID = getFileIdFromFaceId(faceID); + return files.any((file) => file.uploadedFileID == fileID); + }); + final embeddings = + await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs); + + final fileIDToCreationTime = + await FilesDB.instance.getFileIDToCreationTime(); + + // Re-cluster within the deleted faces + final newFaceIdToClusterID = + await FaceClusteringService.instance.predictWithinClusterComputer( + embeddings, + fileIDToCreationTime: fileIDToCreationTime, + distanceThreshold: 0.20, + ); + if (newFaceIdToClusterID == null || newFaceIdToClusterID.isEmpty) { + return; + } + + // Update the deleted faces + await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID); + + Bus.instance.fire( + PeopleChangedEvent( + relevantFiles: files, + type: PeopleEventType.removedFilesFromCluster, + source: "$clusterID", + ), + ); + // Bus.instance.fire( + // LocalPhotosUpdatedEvent( + // files, + // type: EventType.peopleClusterChanged, + // source: "$clusterID", + // ), + // ); + return; + } catch (e, s) { + _logger.severe("Error in removeFilesFromCluster", e, s); + rethrow; + } } Future addFilesToCluster(List faceIDs, int clusterID) async { @@ -194,7 +282,7 @@ class ClusterFeedbackService { // TODO: iterate over this method to find sweet spot Future>> breakUpCluster( int clusterID, { - useDbscan = false, + bool useDbscan = false, }) async { _logger.info( 'breakUpCluster called for cluster $clusterID with dbscan $useDbscan', @@ -203,10 +291,8 @@ class ClusterFeedbackService { final faceIDs = await faceMlDb.getFaceIDsForCluster(clusterID); final originalFaceIDsSet = faceIDs.toSet(); - final fileIDs = faceIDs.map((e) => getFileIdFromFaceId(e)).toList(); - final embeddings = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs); - embeddings.removeWhere((key, value) => !faceIDs.contains(key)); + final embeddings = await faceMlDb.getFaceEmbeddingMapForFaces(faceIDs); final fileIDToCreationTime = await FilesDB.instance.getFileIDToCreationTime(); @@ -232,18 +318,14 @@ class ClusterFeedbackService { maxClusterID++; } } else { - final clusteringInput = embeddings.map((key, value) { - return MapEntry(key, (null, value)); - }); - final faceIdToCluster = - await FaceClusteringService.instance.predictLinear( - clusteringInput, + await FaceClusteringService.instance.predictWithinClusterComputer( + embeddings, fileIDToCreationTime: fileIDToCreationTime, - distanceThreshold: 0.23, + distanceThreshold: 0.22, ); - if (faceIdToCluster == null) { + if (faceIdToCluster == null || faceIdToCluster.isEmpty) { _logger.info('No clusters found'); return {}; } else { @@ -295,6 +377,62 @@ class ClusterFeedbackService { return clusterIdToFaceIds; } + /// WARNING: this method is purely for debugging purposes, never use in production + Future createFakeClustersByBlurValue() async { + try { + // Delete old clusters + await FaceMLDataDB.instance.resetClusterIDs(); + await FaceMLDataDB.instance.dropClustersAndPersonTable(); + final List persons = + await PersonService.instance.getPersons(); + for (final PersonEntity p in persons) { + await PersonService.instance.deletePerson(p.remoteID); + } + + // Create new fake clusters based on blur value. One for values between 0 and 10, one for 10-20, etc till 200 + final int startClusterID = DateTime.now().microsecondsSinceEpoch; + final faceIDsToBlurValues = + await FaceMLDataDB.instance.getFaceIDsToBlurValues(200); + final faceIdToCluster = {}; + for (final entry in faceIDsToBlurValues.entries) { + final faceID = entry.key; + final blurValue = entry.value; + final newClusterID = startClusterID + blurValue ~/ 10; + faceIdToCluster[faceID] = newClusterID; + } + await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster); + + Bus.instance.fire(PeopleChangedEvent()); + } catch (e, s) { + _logger.severe("Error in createFakeClustersByBlurValue", e, s); + rethrow; + } + } + + Future debugLogClusterBlurValues( + int clusterID, { + int? clusterSize, + }) async { + final List blurValues = await FaceMLDataDB.instance + .getBlurValuesForCluster(clusterID) + .then((value) => value.toList()); + + // Round the blur values to integers + final blurValuesIntegers = + blurValues.map((value) => value.round()).toList(); + + // Sort the blur values in ascending order + blurValuesIntegers.sort(); + + // Log the sorted blur values + + _logger.info( + "Blur values for cluster $clusterID${clusterSize != null ? ' with $clusterSize photos' : ''}: $blurValuesIntegers", + ); + + return; + } + /// Returns a map of person's clusterID to map of closest clusterID to with disstance Future>> getSuggestionsUsingMean( PersonEntity p, { @@ -523,7 +661,7 @@ class ClusterFeedbackService { ); final Map clusterToSummary = - await faceMlDb.clusterSummaryAll(); + await faceMlDb.getAllClusterSummary(); final Map updatesForClusterSummary = {}; final Map> clusterAvg = {}; @@ -714,7 +852,7 @@ class ClusterFeedbackService { // Get the cluster averages for the person's clusters and the suggestions' clusters final Map clusterToSummary = - await faceMlDb.clusterSummaryAll(); + await faceMlDb.getAllClusterSummary(); // Calculate the avg embedding of the person final personClusters = await faceMlDb.getPersonClusterIDs(person.remoteID); diff --git a/mobile/lib/services/search_service.dart b/mobile/lib/services/search_service.dart index 7e1eb5eaa..3f54187c1 100644 --- a/mobile/lib/services/search_service.dart +++ b/mobile/lib/services/search_service.dart @@ -824,7 +824,7 @@ class SearchService { "Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}", ); } - if (files.length < 3 && sortedClusterIds.length > 3) { + if (files.length < 20 && sortedClusterIds.length > 3) { continue; } facesResult.add( diff --git a/mobile/lib/ui/settings/debug/face_debug_section_widget.dart b/mobile/lib/ui/settings/debug/face_debug_section_widget.dart index 85aa992a3..f1c835fa2 100644 --- a/mobile/lib/ui/settings/debug/face_debug_section_widget.dart +++ b/mobile/lib/ui/settings/debug/face_debug_section_widget.dart @@ -8,6 +8,7 @@ import "package:photos/events/people_changed_event.dart"; import "package:photos/face/db.dart"; import "package:photos/face/model/person.dart"; import 'package:photos/services/machine_learning/face_ml/face_ml_service.dart'; +import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import 'package:photos/theme/ente_theme.dart'; import 'package:photos/ui/components/captioned_text_widget.dart'; @@ -284,6 +285,34 @@ class _FaceDebugSectionWidgetState extends State { }, ), sectionOptionSpacing, + MenuItemWidget( + captionedTextWidget: const CaptionedTextWidget( + title: "Rank blurs", + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + await showChoiceDialog( + context, + title: "Are you sure?", + body: + "This will delete all clusters and put blurry faces in separate clusters per ten points.", + firstButtonLabel: "Yes, confirm", + firstButtonOnTap: () async { + try { + await ClusterFeedbackService.instance + .createFakeClustersByBlurValue(); + showShortToast(context, "Done"); + } catch (e, s) { + _logger.warning('Failed to rank faces on blur values ', e, s); + await showGenericErrorDialog(context: context, error: e); + } + }, + ); + }, + ), + sectionOptionSpacing, MenuItemWidget( captionedTextWidget: const CaptionedTextWidget( title: "Drop embeddings & feedback", diff --git a/mobile/lib/ui/viewer/file_details/face_widget.dart b/mobile/lib/ui/viewer/file_details/face_widget.dart index 4bd4a7bb5..79637af82 100644 --- a/mobile/lib/ui/viewer/file_details/face_widget.dart +++ b/mobile/lib/ui/viewer/file_details/face_widget.dart @@ -9,6 +9,7 @@ import "package:photos/face/db.dart"; import "package:photos/face/model/face.dart"; import "package:photos/face/model/person.dart"; import 'package:photos/models/file/file.dart'; +import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart"; import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; import "package:photos/services/search_service.dart"; import "package:photos/theme/ente_theme.dart"; @@ -47,7 +48,7 @@ class _FaceWidgetState extends State { @override Widget build(BuildContext context) { - if (Platform.isIOS || Platform.isAndroid) { + if (Platform.isIOS) { return FutureBuilder( future: getFaceCrop(), builder: (context, snapshot) { @@ -164,19 +165,19 @@ class _FaceWidgetState extends State { ), if (kDebugMode) Text( - 'B: ${widget.face.blur.toStringAsFixed(3)}', + 'B: ${widget.face.blur.toStringAsFixed(0)}', style: Theme.of(context).textTheme.bodySmall, maxLines: 1, ), if (kDebugMode) Text( - 'V: ${widget.face.visibility}', + 'D: ${widget.face.detection.getFaceDirection().toDirectionString()}', style: Theme.of(context).textTheme.bodySmall, maxLines: 1, ), if (kDebugMode) Text( - 'A: ${widget.face.area()}', + 'Sideways: ${widget.face.detection.faceIsSideways().toString()}', style: Theme.of(context).textTheme.bodySmall, maxLines: 1, ), @@ -303,6 +304,24 @@ class _FaceWidgetState extends State { style: Theme.of(context).textTheme.bodySmall, maxLines: 1, ), + if (kDebugMode) + Text( + 'B: ${widget.face.blur.toStringAsFixed(0)}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), + if (kDebugMode) + Text( + 'D: ${widget.face.detection.getFaceDirection().toDirectionString()}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), + if (kDebugMode) + Text( + 'Sideways: ${widget.face.detection.faceIsSideways().toString()}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), ], ), ); diff --git a/mobile/lib/ui/viewer/people/cluster_page.dart b/mobile/lib/ui/viewer/people/cluster_page.dart index e2cdaca52..12b830932 100644 --- a/mobile/lib/ui/viewer/people/cluster_page.dart +++ b/mobile/lib/ui/viewer/people/cluster_page.dart @@ -1,10 +1,12 @@ import "dart:async"; +import "package:flutter/foundation.dart"; import 'package:flutter/material.dart'; import "package:flutter_animate/flutter_animate.dart"; import 'package:photos/core/event_bus.dart'; import 'package:photos/events/files_updated_event.dart'; import 'package:photos/events/local_photos_updated_event.dart'; +import "package:photos/events/people_changed_event.dart"; import "package:photos/face/model/person.dart"; import "package:photos/generated/l10n.dart"; import 'package:photos/models/file/file.dart'; @@ -51,6 +53,7 @@ class _ClusterPageState extends State { final _selectedFiles = SelectedFiles(); late final List files; late final StreamSubscription _filesUpdatedEvent; + late final StreamSubscription _peopleChangedEvent; @override void initState() { @@ -69,11 +72,27 @@ class _ClusterPageState extends State { setState(() {}); } }); + _peopleChangedEvent = Bus.instance.on().listen((event) { + if (event.type == PeopleEventType.removedFilesFromCluster && + (event.source == widget.clusterID.toString())) { + for (var updatedFile in event.relevantFiles!) { + files.remove(updatedFile); + } + setState(() {}); + } + }); + kDebugMode + ? ClusterFeedbackService.instance.debugLogClusterBlurValues( + widget.clusterID, + clusterSize: files.length, + ) + : null; } @override void dispose() { _filesUpdatedEvent.cancel(); + _peopleChangedEvent.cancel(); super.dispose(); } @@ -96,10 +115,12 @@ class _ClusterPageState extends State { ); }, reloadEvent: Bus.instance.on(), + forceReloadEvents: [Bus.instance.on()], removalEventTypes: const { EventType.deletedFromRemote, EventType.deletedFromEverywhere, EventType.hide, + EventType.peopleClusterChanged, }, tagPrefix: widget.tagPrefix + widget.tagPrefix, selectedFiles: _selectedFiles, @@ -111,9 +132,10 @@ class _ClusterPageState extends State { preferredSize: const Size.fromHeight(50.0), child: ClusterAppBar( SearchResultPage.appBarType, - "${widget.searchResult.length} memories${widget.appendTitle}", + "${files.length} memories${widget.appendTitle}", _selectedFiles, widget.clusterID, + key: ValueKey(files.length), ), ), body: Column( diff --git a/mobile/lib/utils/image_ml_util.dart b/mobile/lib/utils/image_ml_util.dart index ab29eb919..7ce10e306 100644 --- a/mobile/lib/utils/image_ml_util.dart +++ b/mobile/lib/utils/image_ml_util.dart @@ -1099,19 +1099,16 @@ Future<(Float32List, List, List, List, Size)> imageHeight: image.height, ); - final List>> faceLandmarks = - absoluteFaces.map((face) => face.allKeypoints).toList(); - final alignedImagesFloat32List = - Float32List(3 * width * height * faceLandmarks.length); + Float32List(3 * width * height * absoluteFaces.length); final alignmentResults = []; final isBlurs = []; final blurValues = []; int alignedImageIndex = 0; - for (final faceLandmark in faceLandmarks) { + for (final face in absoluteFaces) { final (alignmentResult, correctlyEstimated) = - SimilarityTransform.instance.estimate(faceLandmark); + SimilarityTransform.instance.estimate(face.allKeypoints); if (!correctlyEstimated) { alignedImageIndex += 3 * width * height; alignmentResults.add(AlignmentResult.empty()); @@ -1137,7 +1134,7 @@ Future<(Float32List, List, List, List, Size)> final grayscalems = blurDetectionStopwatch.elapsedMilliseconds; log('creating grayscale matrix took $grayscalems ms'); final (isBlur, blurValue) = await BlurDetectionService.instance - .predictIsBlurGrayLaplacian(faceGrayMatrix); + .predictIsBlurGrayLaplacian(faceGrayMatrix, faceDirection: face.getFaceDirection()); final blurms = blurDetectionStopwatch.elapsedMilliseconds - grayscalems; log('blur detection took $blurms ms'); log(