Merge branch 'mobile_face' of https://github.com/ente-io/auth into mobile_face

This commit is contained in:
Neeraj Gupta 2024-04-20 16:01:08 +05:30
commit cc682a0a09
21 changed files with 1259 additions and 547 deletions

View file

@ -1316,8 +1316,8 @@ class FilesDB {
} }
Future<Map<int, int>> getFileIDToCreationTime() async { Future<Map<int, int>> getFileIDToCreationTime() async {
final db = await instance.database; final db = await instance.sqliteAsyncDB;
final rows = await db.rawQuery( final rows = await db.getAll(
''' '''
SELECT $columnUploadedFileID, $columnCreationTime SELECT $columnUploadedFileID, $columnCreationTime
FROM $filesTable FROM $filesTable

View file

@ -27,4 +27,5 @@ enum EventType {
unhide, unhide,
coverChanged, coverChanged,
peopleChanged, peopleChanged,
peopleClusterChanged,
} }

View file

@ -1,3 +1,22 @@
import "package:photos/events/event.dart"; import "package:photos/events/event.dart";
import "package:photos/models/file/file.dart";
class PeopleChangedEvent extends Event {} class PeopleChangedEvent extends Event {
final List<EnteFile>? relevantFiles;
final PeopleEventType type;
final String source;
PeopleChangedEvent({
this.relevantFiles,
this.type = PeopleEventType.defaultType,
this.source = "",
});
@override
String get reason => '$runtimeType{type: ${type.name}, "via": $source}';
}
/// Kinds of change a [PeopleChangedEvent] can carry in its `type` field.
enum PeopleEventType {
// Value [PeopleChangedEvent] falls back to when the caller passes no type.
defaultType,
// NOTE(review): by its name, signals that files were removed from a cluster;
// the emitting code is not visible here — confirm against the call sites.
removedFilesFromCluster,
}

View file

@ -12,6 +12,7 @@ import 'package:photos/face/db_fields.dart';
import "package:photos/face/db_model_mappers.dart"; import "package:photos/face/db_model_mappers.dart";
import "package:photos/face/model/face.dart"; import "package:photos/face/model/face.dart";
import "package:photos/models/file/file.dart"; import "package:photos/models/file/file.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
import 'package:sqflite/sqflite.dart'; import 'package:sqflite/sqflite.dart';
import 'package:sqlite_async/sqlite_async.dart' as sqlite_async; import 'package:sqlite_async/sqlite_async.dart' as sqlite_async;
@ -160,27 +161,27 @@ class FaceMLDataDB {
final db = await instance.database; final db = await instance.database;
// find out clusterIds that are assigned to other persons using the clusters table // find out clusterIds that are assigned to other persons using the clusters table
final List<Map<String, dynamic>> maps = await db.rawQuery( final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT $cluserIDColumn FROM $clusterPersonTable WHERE $personIdColumn != ? AND $personIdColumn IS NOT NULL', 'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn != ? AND $personIdColumn IS NOT NULL',
[personID], [personID],
); );
final Set<int> ignoredClusterIDs = final Set<int> ignoredClusterIDs =
maps.map((e) => e[cluserIDColumn] as int).toSet(); maps.map((e) => e[clusterIDColumn] as int).toSet();
final List<Map<String, dynamic>> rejectMaps = await db.rawQuery( final List<Map<String, dynamic>> rejectMaps = await db.rawQuery(
'SELECT $cluserIDColumn FROM $notPersonFeedback WHERE $personIdColumn = ?', 'SELECT $clusterIDColumn FROM $notPersonFeedback WHERE $personIdColumn = ?',
[personID], [personID],
); );
final Set<int> rejectClusterIDs = final Set<int> rejectClusterIDs =
rejectMaps.map((e) => e[cluserIDColumn] as int).toSet(); rejectMaps.map((e) => e[clusterIDColumn] as int).toSet();
return ignoredClusterIDs.union(rejectClusterIDs); return ignoredClusterIDs.union(rejectClusterIDs);
} }
Future<Set<int>> getPersonClusterIDs(String personID) async { Future<Set<int>> getPersonClusterIDs(String personID) async {
final db = await instance.database; final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.rawQuery( final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT $cluserIDColumn FROM $clusterPersonTable WHERE $personIdColumn = ?', 'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn = ?',
[personID], [personID],
); );
return maps.map((e) => e[cluserIDColumn] as int).toSet(); return maps.map((e) => e[clusterIDColumn] as int).toSet();
} }
Future<void> clearTable() async { Future<void> clearTable() async {
@ -249,16 +250,16 @@ class FaceMLDataDB {
} }
final cluterRows = await db.query( final cluterRows = await db.query(
clusterPersonTable, clusterPersonTable,
columns: [cluserIDColumn], columns: [clusterIDColumn],
where: '$personIdColumn = ?', where: '$personIdColumn = ?',
whereArgs: [personID], whereArgs: [personID],
); );
final clusterIDs = final clusterIDs =
cluterRows.map((e) => e[cluserIDColumn] as int).toList(); cluterRows.map((e) => e[clusterIDColumn] as int).toList();
final List<Map<String, dynamic>> faceMaps = await db.rawQuery( final List<Map<String, dynamic>> faceMaps = await db.rawQuery(
'SELECT * FROM $facesTable where ' 'SELECT * FROM $facesTable where '
'$faceIDColumn in (SELECT $fcFaceId from $faceClustersTable where $fcClusterID IN (${clusterIDs.join(",")}))' '$faceIDColumn in (SELECT $fcFaceId from $faceClustersTable where $fcClusterID IN (${clusterIDs.join(",")}))'
'AND $fileIDColumn in (${fileId.join(",")}) AND $faceScore > $kMinHighQualityFaceScore ORDER BY $faceScore DESC', 'AND $fileIDColumn in (${fileId.join(",")}) AND $faceScore > $kMinimumQualityFaceScore ORDER BY $faceScore DESC',
); );
if (faceMaps.isNotEmpty) { if (faceMaps.isNotEmpty) {
if (avatarFileId != null) { if (avatarFileId != null) {
@ -308,8 +309,6 @@ class FaceMLDataDB {
faceBlur, faceBlur,
imageHeight, imageHeight,
imageWidth, imageWidth,
faceArea,
faceVisibilityScore,
mlVersionColumn, mlVersionColumn,
], ],
where: '$fileIDColumn = ?', where: '$fileIDColumn = ?',
@ -334,16 +333,60 @@ class FaceMLDataDB {
} }
Future<Iterable<String>> getFaceIDsForCluster(int clusterID) async { Future<Iterable<String>> getFaceIDsForCluster(int clusterID) async {
final db = await instance.database; final db = await instance.sqliteAsyncDB;
final List<Map<String, dynamic>> maps = await db.query( final List<Map<String, dynamic>> maps = await db.getAll(
faceClustersTable, 'SELECT $fcFaceId FROM $faceClustersTable '
columns: [fcFaceId], 'WHERE $faceClustersTable.$fcClusterID = ?',
where: '$fcClusterID = ?', [clusterID],
whereArgs: [clusterID],
); );
return maps.map((e) => e[fcFaceId] as String).toSet(); return maps.map((e) => e[fcFaceId] as String).toSet();
} }
/// Collects the IDs of every face whose cluster is mapped to [personID].
///
/// Joins [faceClustersTable] with [clusterPersonTable] on the cluster ID and
/// keeps the rows whose person column equals [personID]. The IDs are returned
/// as a set, so each face ID appears at most once.
Future<Iterable<String>> getFaceIDsForPerson(String personID) async {
  final asyncDb = await instance.sqliteAsyncDB;
  final rows = await asyncDb.getAll(
    'SELECT $fcFaceId FROM $faceClustersTable LEFT JOIN $clusterPersonTable '
    'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn '
    'WHERE $clusterPersonTable.$personIdColumn = ?',
    [personID],
  );
  // The WHERE clause filters on the joined table, so LEFT JOIN rows without a
  // matching person (NULL person column) never reach this loop.
  final faceIds = <String>{};
  for (final row in rows) {
    faceIds.add(row[fcFaceId] as String);
  }
  return faceIds;
}
/// Reads the blur values of the faces assigned to [clusterID].
///
/// Joins [facesTable] with [faceClustersTable] on the face ID and selects the
/// blur column of every face in the cluster. An equivalent `IN (subquery)`
/// formulation was considered in the original code and left disabled.
///
/// NOTE(review): the values are collected into a set, so faces that share the
/// exact same blur value collapse into one entry — confirm callers do not
/// rely on the number of returned values matching the number of faces.
Future<Iterable<double>> getBlurValuesForCluster(int clusterID) async {
  const String blurQuery = '''
SELECT $facesTable.$faceBlur
FROM $facesTable
JOIN $faceClustersTable ON $facesTable.$faceIDColumn = $faceClustersTable.$fcFaceId
WHERE $faceClustersTable.$fcClusterID = ?
''';
  final asyncDb = await instance.sqliteAsyncDB;
  final rows = await asyncDb.getAll(blurQuery, [clusterID]);
  final blurValues = <double>{};
  for (final row in rows) {
    blurValues.add(row[faceBlur] as double);
  }
  return blurValues;
}
/// Maps face ID to blur value for faces whose blur lies strictly between 1
/// and [maxBlurValue].
///
/// Rows are read in ascending blur order, and the returned map preserves that
/// insertion order.
Future<Map<String, double>> getFaceIDsToBlurValues(
  int maxBlurValue,
) async {
  final asyncDb = await instance.sqliteAsyncDB;
  final rows = await asyncDb.getAll(
    'SELECT $faceIDColumn, $faceBlur FROM $facesTable WHERE $faceBlur < $maxBlurValue AND $faceBlur > 1 ORDER BY $faceBlur ASC',
  );
  return <String, double>{
    for (final row in rows) row[faceIDColumn] as String: row[faceBlur] as double,
  };
}
Future<Map<String, int?>> getFaceIdsToClusterIds( Future<Map<String, int?>> getFaceIdsToClusterIds(
Iterable<String> faceIds, Iterable<String> faceIds,
) async { ) async {
@ -376,14 +419,14 @@ class FaceMLDataDB {
} }
Future<void> forceUpdateClusterIds( Future<void> forceUpdateClusterIds(
Map<String, int> faceIDToPersonID, Map<String, int> faceIDToClusterID,
) async { ) async {
final db = await instance.database; final db = await instance.database;
// Start a batch // Start a batch
final batch = db.batch(); final batch = db.batch();
for (final map in faceIDToPersonID.entries) { for (final map in faceIDToClusterID.entries) {
final faceID = map.key; final faceID = map.key;
final clusterID = map.value; final clusterID = map.value;
batch.insert( batch.insert(
@ -410,12 +453,64 @@ class FaceMLDataDB {
); );
} }
/// Loads the faces eligible for clustering together with their current
/// cluster assignment.
///
/// Reads [facesTable] rows whose score is greater than [minScore] and whose
/// blur value is greater than [minClarity], ordered by face ID descending, in
/// pages of [batchSize] rows starting at [offset]. For every page the current
/// cluster IDs are looked up through [getFaceIdsToClusterIds].
///
/// Paging stops when a page comes back empty or once at least [maxFaces]
/// entries have been collected. Whole pages are always added, so the result
/// may hold more than [maxFaces] entries (by up to one page).
Future<Set<FaceInfoForClustering>> getFaceInfoForClustering({
  double minScore = kMinimumQualityFaceScore,
  int minClarity = kLaplacianHardThreshold,
  int maxFaces = 20000,
  int offset = 0,
  int batchSize = 10000,
}) async {
  // Label the stopwatch with this method's own name; it previously reused the
  // "getFaceEmbeddingMap" label, which misattributed the timing logs.
  final EnteWatch w = EnteWatch("getFaceInfoForClustering")..start();
  w.logAndReset(
    'reading as float offset: $offset, maxFaces: $maxFaces, batchSize: $batchSize',
  );
  final db = await instance.sqliteAsyncDB;

  final Set<FaceInfoForClustering> result = {};
  while (true) {
    // Query a batch of rows
    final List<Map<String, dynamic>> maps = await db.getAll(
      'SELECT $faceIDColumn, $faceEmbeddingBlob, $faceScore, $faceBlur, $isSideways FROM $facesTable'
      ' WHERE $faceScore > $minScore AND $faceBlur > $minClarity'
      ' ORDER BY $faceIDColumn'
      ' DESC LIMIT $batchSize OFFSET $offset',
    );
    // Break the loop if no more rows
    if (maps.isEmpty) {
      break;
    }
    // Resolve the cluster assignment for the whole page in one lookup.
    final List<String> faceIds = [
      for (final map in maps) map[faceIDColumn] as String,
    ];
    final faceIdToClusterId = await getFaceIdsToClusterIds(faceIds);
    for (final map in maps) {
      final faceID = map[faceIDColumn] as String;
      final faceInfo = FaceInfoForClustering(
        faceID: faceID,
        // Null when the face has no entry in the cluster lookup.
        clusterId: faceIdToClusterId[faceID],
        embeddingBytes: map[faceEmbeddingBlob] as Uint8List,
        faceScore: map[faceScore] as double,
        blurValue: map[faceBlur] as double,
        // Stored as an INTEGER flag; 1 means sideways.
        isSideways: (map[isSideways] as int) == 1,
      );
      result.add(faceInfo);
    }
    if (result.length >= maxFaces) {
      break;
    }
    offset += batchSize;
  }
  w.stopWithLog('done reading face embeddings ${result.length}');
  return result;
}
/// Returns a map of faceID to record of clusterId and faceEmbeddingBlob /// Returns a map of faceID to record of clusterId and faceEmbeddingBlob
/// ///
/// Only selects faces with score greater than [minScore] and blur score greater than [minClarity] /// Only selects faces with score greater than [minScore] and blur score greater than [minClarity]
Future<Map<String, (int?, Uint8List)>> getFaceEmbeddingMap({ Future<Map<String, (int?, Uint8List)>> getFaceEmbeddingMap({
double minScore = kMinHighQualityFaceScore, double minScore = kMinimumQualityFaceScore,
int minClarity = kLaplacianThreshold, int minClarity = kLaplacianHardThreshold,
int maxFaces = 20000, int maxFaces = 20000,
int offset = 0, int offset = 0,
int batchSize = 10000, int batchSize = 10000,
@ -481,7 +576,7 @@ class FaceMLDataDB {
facesTable, facesTable,
columns: [faceIDColumn, faceEmbeddingBlob], columns: [faceIDColumn, faceEmbeddingBlob],
where: where:
'$faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianThreshold AND $fileIDColumn IN (${fileIDs.join(",")})', '$faceScore > $kMinimumQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold AND $fileIDColumn IN (${fileIDs.join(",")})',
limit: batchSize, limit: batchSize,
offset: offset, offset: offset,
orderBy: '$faceIDColumn DESC', orderBy: '$faceIDColumn DESC',
@ -503,12 +598,50 @@ class FaceMLDataDB {
return result; return result;
} }
/// Returns a map of faceID to embedding blob for the given [faceIDs].
///
/// Rows are read from [facesTable] in pages of 10000, ordered by face ID
/// descending. An empty [faceIDs] yields an empty map without querying.
///
/// NOTE(review): paging stops as soon as more than 10000 embeddings have been
/// collected, so requests for very large ID sets come back truncated. This
/// cap is preserved from the original implementation — confirm whether it is
/// intentional.
Future<Map<String, Uint8List>> getFaceEmbeddingMapForFaces(
  Iterable<String> faceIDs,
) async {
  // Materialize once: [faceIDs] may be a lazy iterable, and both its length
  // and its elements are needed below.
  final ids = faceIDs.toList(growable: false);
  _logger.info('reading face embeddings for ${ids.length} faces');

  final Map<String, Uint8List> result = {};
  if (ids.isNotEmpty) {
    final db = await instance.sqliteAsyncDB;

    // Define the batch size
    const batchSize = 10000;
    int offset = 0;

    // The IN-list does not change between pages, so build it once. The IDs
    // are inlined rather than bound (the list can be far longer than a bound
    // parameter list comfortably allows), so single quotes are doubled to
    // keep every ID a well-formed SQL string literal.
    final inList = ids.map((id) {
      final escaped = id.replaceAll("'", "''");
      return "'$escaped'";
    }).join(",");

    while (true) {
      // Query a batch of rows
      final String query = '''
        SELECT $faceIDColumn, $faceEmbeddingBlob
        FROM $facesTable
        WHERE $faceIDColumn IN ($inList)
        ORDER BY $faceIDColumn DESC
        LIMIT $batchSize OFFSET $offset
      ''';
      final List<Map<String, dynamic>> maps = await db.getAll(query);
      // Break the loop if no more rows
      if (maps.isEmpty) {
        break;
      }
      for (final map in maps) {
        final faceID = map[faceIDColumn] as String;
        result[faceID] = map[faceEmbeddingBlob] as Uint8List;
      }
      // Preserved cap, see the NOTE(review) in the doc comment.
      if (result.length > 10000) {
        break;
      }
      offset += batchSize;
    }
  }
  _logger.info('done reading face embeddings for ${ids.length} faces');
  return result;
}
Future<int> getTotalFaceCount({ Future<int> getTotalFaceCount({
double minFaceScore = kMinHighQualityFaceScore, double minFaceScore = kMinimumQualityFaceScore,
}) async { }) async {
final db = await instance.sqliteAsyncDB; final db = await instance.sqliteAsyncDB;
final List<Map<String, dynamic>> maps = await db.getAll( final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $minFaceScore AND $faceBlur > $kLaplacianThreshold', 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $minFaceScore AND $faceBlur > $kLaplacianHardThreshold',
); );
return maps.first['count'] as int; return maps.first['count'] as int;
} }
@ -517,7 +650,7 @@ class FaceMLDataDB {
final db = await instance.sqliteAsyncDB; final db = await instance.sqliteAsyncDB;
final List<Map<String, dynamic>> totalFacesMaps = await db.getAll( final List<Map<String, dynamic>> totalFacesMaps = await db.getAll(
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianThreshold', 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinimumQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold',
); );
final int totalFaces = totalFacesMaps.first['count'] as int; final int totalFaces = totalFacesMaps.first['count'] as int;
@ -530,11 +663,11 @@ class FaceMLDataDB {
} }
Future<int> getBlurryFaceCount([ Future<int> getBlurryFaceCount([
int blurThreshold = kLaplacianThreshold, int blurThreshold = kLaplacianHardThreshold,
]) async { ]) async {
final db = await instance.database; final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.rawQuery( final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceBlur <= $blurThreshold AND $faceScore > $kMinHighQualityFaceScore', 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceBlur <= $blurThreshold AND $faceScore > $kMinimumQualityFaceScore',
); );
return maps.first['count'] as int; return maps.first['count'] as int;
} }
@ -555,7 +688,7 @@ class FaceMLDataDB {
clusterPersonTable, clusterPersonTable,
{ {
personIdColumn: personID, personIdColumn: personID,
cluserIDColumn: clusterID, clusterIDColumn: clusterID,
}, },
); );
} }
@ -572,7 +705,7 @@ class FaceMLDataDB {
clusterPersonTable, clusterPersonTable,
{ {
personIdColumn: personID, personIdColumn: personID,
cluserIDColumn: clusterID, clusterIDColumn: clusterID,
}, },
conflictAlgorithm: ConflictAlgorithm.replace, conflictAlgorithm: ConflictAlgorithm.replace,
); );
@ -589,11 +722,31 @@ class FaceMLDataDB {
notPersonFeedback, notPersonFeedback,
{ {
personIdColumn: personID, personIdColumn: personID,
cluserIDColumn: clusterID, clusterIDColumn: clusterID,
}, },
); );
} }
/// Records "this cluster is not this person" feedback for many clusters in a
/// single batch.
///
/// [clusterToPersonID] maps a cluster ID to the person ID it was rejected
/// for. One row per entry is queued into [notPersonFeedback] and the batch is
/// committed without collecting per-row results.
///
/// NOTE(review): the CREATE TABLE statement visible for [notPersonFeedback]
/// declares no primary key or unique constraint, so
/// [ConflictAlgorithm.replace] presumably never fires and repeated feedback
/// would add duplicate rows — confirm against the schema.
Future<void> bulkCaptureNotPersonFeedback(
  Map<int, String> clusterToPersonID,
) async {
  final database = await instance.database;
  final pendingInserts = database.batch();
  for (final MapEntry(key: clusterID, value: personID)
      in clusterToPersonID.entries) {
    pendingInserts.insert(
      notPersonFeedback,
      {personIdColumn: personID, clusterIDColumn: clusterID},
      conflictAlgorithm: ConflictAlgorithm.replace,
    );
  }
  await pendingInserts.commit(noResult: true);
}
Future<int> removeClusterToPerson({ Future<int> removeClusterToPerson({
required String personID, required String personID,
required int clusterID, required int clusterID,
@ -601,7 +754,7 @@ class FaceMLDataDB {
final db = await instance.database; final db = await instance.database;
return db.delete( return db.delete(
clusterPersonTable, clusterPersonTable,
where: '$personIdColumn = ? AND $cluserIDColumn = ?', where: '$personIdColumn = ? AND $clusterIDColumn = ?',
whereArgs: [personID, clusterID], whereArgs: [personID, clusterID],
); );
} }
@ -613,13 +766,13 @@ class FaceMLDataDB {
final List<Map<String, dynamic>> maps = await db.rawQuery( final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT $faceClustersTable.$fcClusterID, $fcFaceId FROM $faceClustersTable ' 'SELECT $faceClustersTable.$fcClusterID, $fcFaceId FROM $faceClustersTable '
'INNER JOIN $clusterPersonTable ' 'INNER JOIN $clusterPersonTable '
'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$cluserIDColumn ' 'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn '
'WHERE $clusterPersonTable.$personIdColumn = ?', 'WHERE $clusterPersonTable.$personIdColumn = ?',
[personID], [personID],
); );
final Map<int, Set<int>> result = {}; final Map<int, Set<int>> result = {};
for (final map in maps) { for (final map in maps) {
final clusterID = map[cluserIDColumn] as int; final clusterID = map[clusterIDColumn] as int;
final String faceID = map[fcFaceId] as String; final String faceID = map[fcFaceId] as String;
final fileID = int.parse(faceID.split('_').first); final fileID = int.parse(faceID.split('_').first);
result[fileID] = (result[fileID] ?? {})..add(clusterID); result[fileID] = (result[fileID] ?? {})..add(clusterID);
@ -664,7 +817,7 @@ class FaceMLDataDB {
batch.insert( batch.insert(
clusterSummaryTable, clusterSummaryTable,
{ {
cluserIDColumn: cluserID, clusterIDColumn: cluserID,
avgColumn: avg, avgColumn: avg,
countColumn: count, countColumn: count,
}, },
@ -676,12 +829,16 @@ class FaceMLDataDB {
} }
/// Returns a map of clusterID to (avg embedding, count) /// Returns a map of clusterID to (avg embedding, count)
Future<Map<int, (Uint8List, int)>> clusterSummaryAll() async { Future<Map<int, (Uint8List, int)>> getAllClusterSummary([
final db = await instance.database; int? minClusterSize,
]) async {
final db = await instance.sqliteAsyncDB;
final Map<int, (Uint8List, int)> result = {}; final Map<int, (Uint8List, int)> result = {};
final rows = await db.rawQuery('SELECT * from $clusterSummaryTable'); final rows = await db.getAll(
'SELECT * FROM $clusterSummaryTable${minClusterSize != null ? ' WHERE $countColumn >= $minClusterSize' : ''}',
);
for (final r in rows) { for (final r in rows) {
final id = r[cluserIDColumn] as int; final id = r[clusterIDColumn] as int;
final avg = r[avgColumn] as Uint8List; final avg = r[avgColumn] as Uint8List;
final count = r[countColumn] as int; final count = r[countColumn] as int;
result[id] = (avg, count); result[id] = (avg, count);
@ -692,11 +849,11 @@ class FaceMLDataDB {
Future<Map<int, String>> getClusterIDToPersonID() async { Future<Map<int, String>> getClusterIDToPersonID() async {
final db = await instance.database; final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.rawQuery( final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT $personIdColumn, $cluserIDColumn FROM $clusterPersonTable', 'SELECT $personIdColumn, $clusterIDColumn FROM $clusterPersonTable',
); );
final Map<int, String> result = {}; final Map<int, String> result = {};
for (final map in maps) { for (final map in maps) {
result[map[cluserIDColumn] as int] = map[personIdColumn] as String; result[map[clusterIDColumn] as int] = map[personIdColumn] as String;
} }
return result; return result;
} }
@ -741,7 +898,7 @@ class FaceMLDataDB {
final db = await instance.database; final db = await instance.database;
final faceIdsResult = await db.rawQuery( final faceIdsResult = await db.rawQuery(
'SELECT $fcFaceId FROM $faceClustersTable LEFT JOIN $clusterPersonTable ' 'SELECT $fcFaceId FROM $faceClustersTable LEFT JOIN $clusterPersonTable '
'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$cluserIDColumn ' 'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn '
'WHERE $clusterPersonTable.$personIdColumn = ?', 'WHERE $clusterPersonTable.$personIdColumn = ?',
[personID], [personID],
); );

View file

@ -8,8 +8,7 @@ const faceDetectionColumn = 'detection';
const faceEmbeddingBlob = 'eBlob'; const faceEmbeddingBlob = 'eBlob';
const faceScore = 'score'; const faceScore = 'score';
const faceBlur = 'blur'; const faceBlur = 'blur';
const faceArea = 'area'; const isSideways = 'is_sideways';
const faceVisibilityScore = 'visibility';
const imageWidth = 'width'; const imageWidth = 'width';
const imageHeight = 'height'; const imageHeight = 'height';
const faceClusterId = 'cluster_id'; const faceClusterId = 'cluster_id';
@ -22,10 +21,9 @@ const createFacesTable = '''CREATE TABLE IF NOT EXISTS $facesTable (
$faceEmbeddingBlob BLOB NOT NULL, $faceEmbeddingBlob BLOB NOT NULL,
$faceScore REAL NOT NULL, $faceScore REAL NOT NULL,
$faceBlur REAL NOT NULL DEFAULT $kLapacianDefault, $faceBlur REAL NOT NULL DEFAULT $kLapacianDefault,
$isSideways INTEGER NOT NULL DEFAULT 0,
$imageHeight INTEGER NOT NULL DEFAULT 0, $imageHeight INTEGER NOT NULL DEFAULT 0,
$imageWidth INTEGER NOT NULL DEFAULT 0, $imageWidth INTEGER NOT NULL DEFAULT 0,
$faceArea INTEGER NOT NULL DEFAULT 0,
$faceVisibilityScore INTEGER NOT NULL DEFAULT -1,
$mlVersionColumn INTEGER NOT NULL DEFAULT -1, $mlVersionColumn INTEGER NOT NULL DEFAULT -1,
PRIMARY KEY($fileIDColumn, $faceIDColumn) PRIMARY KEY($fileIDColumn, $faceIDColumn)
); );
@ -62,13 +60,13 @@ const deletePersonTable = 'DROP TABLE IF EXISTS $personTable';
// Clusters Table Fields & Schema Queries // Clusters Table Fields & Schema Queries
const clusterPersonTable = 'cluster_person'; const clusterPersonTable = 'cluster_person';
const personIdColumn = 'person_id'; const personIdColumn = 'person_id';
const cluserIDColumn = 'cluster_id'; const clusterIDColumn = 'cluster_id';
const createClusterPersonTable = ''' const createClusterPersonTable = '''
CREATE TABLE IF NOT EXISTS $clusterPersonTable ( CREATE TABLE IF NOT EXISTS $clusterPersonTable (
$personIdColumn TEXT NOT NULL, $personIdColumn TEXT NOT NULL,
$cluserIDColumn INTEGER NOT NULL, $clusterIDColumn INTEGER NOT NULL,
PRIMARY KEY($personIdColumn, $cluserIDColumn) PRIMARY KEY($personIdColumn, $clusterIDColumn)
); );
'''; ''';
const dropClusterPersonTable = 'DROP TABLE IF EXISTS $clusterPersonTable'; const dropClusterPersonTable = 'DROP TABLE IF EXISTS $clusterPersonTable';
@ -80,10 +78,10 @@ const avgColumn = 'avg';
const countColumn = 'count'; const countColumn = 'count';
const createClusterSummaryTable = ''' const createClusterSummaryTable = '''
CREATE TABLE IF NOT EXISTS $clusterSummaryTable ( CREATE TABLE IF NOT EXISTS $clusterSummaryTable (
$cluserIDColumn INTEGER NOT NULL, $clusterIDColumn INTEGER NOT NULL,
$avgColumn BLOB NOT NULL, $avgColumn BLOB NOT NULL,
$countColumn INTEGER NOT NULL, $countColumn INTEGER NOT NULL,
PRIMARY KEY($cluserIDColumn) PRIMARY KEY($clusterIDColumn)
); );
'''; ''';
@ -97,7 +95,7 @@ const notPersonFeedback = 'not_person_feedback';
const createNotPersonFeedbackTable = ''' const createNotPersonFeedbackTable = '''
CREATE TABLE IF NOT EXISTS $notPersonFeedback ( CREATE TABLE IF NOT EXISTS $notPersonFeedback (
$personIdColumn TEXT NOT NULL, $personIdColumn TEXT NOT NULL,
$cluserIDColumn INTEGER NOT NULL $clusterIDColumn INTEGER NOT NULL
); );
'''; ''';
const dropNotPersonFeedbackTable = 'DROP TABLE IF EXISTS $notPersonFeedback'; const dropNotPersonFeedbackTable = 'DROP TABLE IF EXISTS $notPersonFeedback';

View file

@ -34,9 +34,8 @@ Map<String, dynamic> mapRemoteToFaceDB(Face face) {
).writeToBuffer(), ).writeToBuffer(),
faceScore: face.score, faceScore: face.score,
faceBlur: face.blur, faceBlur: face.blur,
isSideways: face.detection.faceIsSideways() ? 1 : 0,
mlVersionColumn: faceMlVersion, mlVersionColumn: faceMlVersion,
faceArea: face.area(),
faceVisibilityScore: face.visibility,
imageWidth: face.fileInfo?.imageWidth ?? 0, imageWidth: face.fileInfo?.imageWidth ?? 0,
imageHeight: face.fileInfo?.imageHeight ?? 0, imageHeight: face.fileInfo?.imageHeight ?? 0,
}; };

View file

@ -1,6 +1,9 @@
import "dart:math" show min, max;
import "package:logging/logging.dart"; import "package:logging/logging.dart";
import "package:photos/face/model/box.dart"; import "package:photos/face/model/box.dart";
import "package:photos/face/model/landmark.dart"; import "package:photos/face/model/landmark.dart";
import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart";
/// Stores the face detection data, notably the bounding box and landmarks. /// Stores the face detection data, notably the bounding box and landmarks.
/// ///
@ -19,7 +22,7 @@ class Detection {
bool get isEmpty => box.width == 0 && box.height == 0 && landmarks.isEmpty; bool get isEmpty => box.width == 0 && box.height == 0 && landmarks.isEmpty;
// emoty box // empty box
Detection.empty() Detection.empty()
: box = FaceBox( : box = FaceBox(
xMin: 0, xMin: 0,
@ -89,4 +92,72 @@ class Detection {
return -1; return -1;
} }
} }
/// Estimates which way the face is turned, based on where the nose sits
/// relative to the eyes and mouth corners.
///
/// Landmarks are read by index: 0 and 1 as the left and right eye, 2 as the
/// nose, 3 and 4 as the left and right mouth corner.
///
/// Returns [FaceDirection.left] or [FaceDirection.right] when the nose lies
/// horizontally beyond both eyes and both mouth corners on that side, or when
/// the face is upright and the nose is within 0.2 eye-distances of that
/// side's eye. Returns [FaceDirection.straight] otherwise, including for
/// empty detections and detections with fewer than five landmarks.
FaceDirection getFaceDirection() {
  // [isEmpty] only holds when the box is zero-sized AND there are no
  // landmarks, so also guard against a non-empty box with too few landmarks:
  // indexing below would otherwise throw a RangeError.
  if (isEmpty || landmarks.length < 5) {
    return FaceDirection.straight;
  }
  final leftEye = landmarks[0];
  final rightEye = landmarks[1];
  final nose = landmarks[2];
  final leftMouth = landmarks[3];
  final rightMouth = landmarks[4];

  final double eyeDistanceX = (rightEye.x - leftEye.x).abs();
  final double eyeDistanceY = (rightEye.y - leftEye.y).abs();
  final double mouthDistanceY = (rightMouth.y - leftMouth.y).abs();

  // Upright: both eyes have a smaller y than the nose and the nose a smaller
  // y than both mouth corners, each with half the eye/mouth tilt as margin.
  final bool faceIsUpright =
      (max(leftEye.y, rightEye.y) + 0.5 * eyeDistanceY < nose.y) &&
          (nose.y + 0.5 * mouthDistanceY < min(leftMouth.y, rightMouth.y));

  final bool noseStickingOutLeft = (nose.x < min(leftEye.x, rightEye.x)) &&
      (nose.x < min(leftMouth.x, rightMouth.x));
  final bool noseStickingOutRight = (nose.x > max(leftEye.x, rightEye.x)) &&
      (nose.x > max(leftMouth.x, rightMouth.x));

  final bool noseCloseToLeftEye =
      (nose.x - leftEye.x).abs() < 0.2 * eyeDistanceX;
  final bool noseCloseToRightEye =
      (nose.x - rightEye.x).abs() < 0.2 * eyeDistanceX;

  // A nose sticking out counts regardless of uprightness; mere closeness to
  // an eye only counts for upright faces. (A stricter variant that required
  // uprightness for both cases was left disabled in the original code.)
  if (noseStickingOutLeft || (faceIsUpright && noseCloseToLeftEye)) {
    return FaceDirection.left;
  } else if (noseStickingOutRight || (faceIsUpright && noseCloseToRightEye)) {
    return FaceDirection.right;
  }

  return FaceDirection.straight;
}
/// Whether the face is upright but turned clearly to one side.
///
/// Landmarks are read by index: 0 and 1 as the left and right eye, 2 as the
/// nose, 3 and 4 as the left and right mouth corner.
///
/// The face counts as sideways when it is upright and the nose lies more than
/// half an eye-distance beyond the outermost eye on one side, and also beyond
/// both mouth corners on that side. Returns `false` for empty detections and
/// detections with fewer than five landmarks.
bool faceIsSideways() {
  // [isEmpty] only holds when the box is zero-sized AND there are no
  // landmarks, so also guard against a non-empty box with too few landmarks:
  // indexing below would otherwise throw a RangeError.
  if (isEmpty || landmarks.length < 5) {
    return false;
  }
  final leftEye = landmarks[0];
  final rightEye = landmarks[1];
  final nose = landmarks[2];
  final leftMouth = landmarks[3];
  final rightMouth = landmarks[4];

  final double eyeDistanceX = (rightEye.x - leftEye.x).abs();
  final double eyeDistanceY = (rightEye.y - leftEye.y).abs();
  final double mouthDistanceY = (rightMouth.y - leftMouth.y).abs();

  // Upright: both eyes have a smaller y than the nose and the nose a smaller
  // y than both mouth corners, each with half the eye/mouth tilt as margin.
  final bool faceIsUpright =
      (max(leftEye.y, rightEye.y) + 0.5 * eyeDistanceY < nose.y) &&
          (nose.y + 0.5 * mouthDistanceY < min(leftMouth.y, rightMouth.y));

  final bool noseStickingOutLeft =
      (nose.x < min(leftEye.x, rightEye.x) - 0.5 * eyeDistanceX) &&
          (nose.x < min(leftMouth.x, rightMouth.x));
  // Mirror of the left-side check: the margin must be ADDED to the rightmost
  // eye. Subtracting it (as before) made the right side far more permissive
  // than the left.
  final bool noseStickingOutRight =
      (nose.x > max(leftEye.x, rightEye.x) + 0.5 * eyeDistanceX) &&
          (nose.x > max(leftMouth.x, rightMouth.x));

  return faceIsUpright && (noseStickingOutLeft || noseStickingOutRight);
}
} }

View file

@ -20,9 +20,9 @@ class Face {
final double blur; final double blur;
FileInfo? fileInfo; FileInfo? fileInfo;
bool get isBlurry => blur < kLaplacianThreshold; bool get isBlurry => blur < kLaplacianHardThreshold;
bool get hasHighScore => score > kMinHighQualityFaceScore; bool get hasHighScore => score > kMinimumQualityFaceScore;
bool get isHighQuality => (!isBlurry) && hasHighScore; bool get isHighQuality => (!isBlurry) && hasHighScore;

View file

@ -2,19 +2,26 @@ import "dart:async";
import "dart:developer"; import "dart:developer";
import "dart:isolate"; import "dart:isolate";
import "dart:math" show max; import "dart:math" show max;
import "dart:typed_data"; import "dart:typed_data" show Uint8List;
import "package:computer/computer.dart";
import "package:flutter/foundation.dart" show kDebugMode;
import "package:logging/logging.dart"; import "package:logging/logging.dart";
import "package:ml_linalg/dtype.dart"; import "package:ml_linalg/dtype.dart";
import "package:ml_linalg/vector.dart"; import "package:ml_linalg/vector.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart"; import "package:photos/generated/protos/ente/common/vector.pb.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
import "package:simple_cluster/simple_cluster.dart"; import "package:simple_cluster/simple_cluster.dart";
import "package:synchronized/synchronized.dart"; import "package:synchronized/synchronized.dart";
class FaceInfo { class FaceInfo {
final String faceID; final String faceID;
final double? faceScore;
final double? blurValue;
final bool? badFace;
final List<double>? embedding; final List<double>? embedding;
final Vector? vEmbedding; final Vector? vEmbedding;
int? clusterId; int? clusterId;
@ -23,6 +30,9 @@ class FaceInfo {
int? fileCreationTime; int? fileCreationTime;
FaceInfo({ FaceInfo({
required this.faceID, required this.faceID,
this.faceScore,
this.blurValue,
this.badFace,
this.embedding, this.embedding,
this.vEmbedding, this.vEmbedding,
this.clusterId, this.clusterId,
@ -32,8 +42,18 @@ class FaceInfo {
enum ClusterOperation { linearIncrementalClustering, dbscanClustering } enum ClusterOperation { linearIncrementalClustering, dbscanClustering }
/// Value object bundling the outputs of a clustering run.
class ClusteringResult {
  /// Creates a result; both fields must be passed explicitly, although
  /// [newClusterSummaries] may be `null`.
  ClusteringResult({
    required this.newFaceIdToCluster,
    required this.newClusterSummaries,
  });

  /// Cluster ID for each face ID (presumably only the faces this run
  /// assigned — verify against the producer).
  final Map<String, int> newFaceIdToCluster;

  /// Optional per-cluster summaries keyed by cluster ID.
  ///
  /// NOTE(review): the record presumably holds (average embedding bytes,
  /// face count), mirroring the cluster summary table — confirm.
  final Map<int, (Uint8List, int)>? newClusterSummaries;
}
class FaceClusteringService { class FaceClusteringService {
final _logger = Logger("FaceLinearClustering"); final _logger = Logger("FaceLinearClustering");
final _computer = Computer.shared();
Timer? _inactivityTimer; Timer? _inactivityTimer;
final Duration _inactivityDuration = const Duration(minutes: 3); final Duration _inactivityDuration = const Duration(minutes: 3);
@ -49,6 +69,7 @@ class FaceClusteringService {
bool isRunning = false; bool isRunning = false;
static const kRecommendedDistanceThreshold = 0.24; static const kRecommendedDistanceThreshold = 0.24;
static const kConservativeDistanceThreshold = 0.06;
// singleton pattern // singleton pattern
FaceClusteringService._privateConstructor(); FaceClusteringService._privateConstructor();
@ -100,31 +121,11 @@ class FaceClusteringService {
try { try {
switch (function) { switch (function) {
case ClusterOperation.linearIncrementalClustering: case ClusterOperation.linearIncrementalClustering:
final input = args['input'] as Map<String, (int?, Uint8List)>; final result = FaceClusteringService.runLinearClustering(args);
final fileIDToCreationTime =
args['fileIDToCreationTime'] as Map<int, int>?;
final distanceThreshold = args['distanceThreshold'] as double;
final offset = args['offset'] as int?;
final result = FaceClusteringService._runLinearClustering(
input,
fileIDToCreationTime: fileIDToCreationTime,
distanceThreshold: distanceThreshold,
offset: offset,
);
sendPort.send(result); sendPort.send(result);
break; break;
case ClusterOperation.dbscanClustering: case ClusterOperation.dbscanClustering:
final input = args['input'] as Map<String, Uint8List>; final result = FaceClusteringService._runDbscanClustering(args);
final fileIDToCreationTime =
args['fileIDToCreationTime'] as Map<int, int>?;
final eps = args['eps'] as double;
final minPts = args['minPts'] as int;
final result = FaceClusteringService._runDbscanClustering(
input,
fileIDToCreationTime: fileIDToCreationTime,
eps: eps,
minPts: minPts,
);
sendPort.send(result); sendPort.send(result);
break; break;
} }
@ -194,16 +195,19 @@ class FaceClusteringService {
_inactivityTimer?.cancel(); _inactivityTimer?.cancel();
} }
/// Runs the clustering algorithm on the given [input], in an isolate. /// Runs the clustering algorithm [runLinearClustering] on the given [input], in an isolate.
/// ///
/// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset. /// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset.
/// ///
/// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can less less deterministic. /// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can less less deterministic.
Future<Map<String, int>?> predictLinear( Future<ClusteringResult?> predictLinear(
Map<String, (int?, Uint8List)> input, { Set<FaceInfoForClustering> input, {
Map<int, int>? fileIDToCreationTime, Map<int, int>? fileIDToCreationTime,
double distanceThreshold = kRecommendedDistanceThreshold, double distanceThreshold = kRecommendedDistanceThreshold,
double conservativeDistanceThreshold = kConservativeDistanceThreshold,
bool useDynamicThreshold = true,
int? offset, int? offset,
required Map<int, (Uint8List, int)> oldClusterSummaries,
}) async { }) async {
if (input.isEmpty) { if (input.isEmpty) {
_logger.warning( _logger.warning(
@ -225,20 +229,23 @@ class FaceClusteringService {
final stopwatchClustering = Stopwatch()..start(); final stopwatchClustering = Stopwatch()..start();
// final Map<String, int> faceIdToCluster = // final Map<String, int> faceIdToCluster =
// await _runLinearClusteringInComputer(input); // await _runLinearClusteringInComputer(input);
final Map<String, int> faceIdToCluster = await _runInIsolate( final ClusteringResult? faceIdToCluster = await _runInIsolate(
( (
ClusterOperation.linearIncrementalClustering, ClusterOperation.linearIncrementalClustering,
{ {
'input': input, 'input': input,
'fileIDToCreationTime': fileIDToCreationTime, 'fileIDToCreationTime': fileIDToCreationTime,
'distanceThreshold': distanceThreshold, 'distanceThreshold': distanceThreshold,
'conservativeDistanceThreshold': conservativeDistanceThreshold,
'useDynamicThreshold': useDynamicThreshold,
'offset': offset, 'offset': offset,
'oldClusterSummaries': oldClusterSummaries,
} }
), ),
); );
// return _runLinearClusteringInComputer(input); // return _runLinearClusteringInComputer(input);
_logger.info( _logger.info(
'Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds', 'predictLinear Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds',
); );
isRunning = false; isRunning = false;
@ -250,6 +257,142 @@ class FaceClusteringService {
} }
} }
/// Runs [runLinearClustering] on [input] through the shared computer, with
/// dynamic thresholding switched off.
///
/// [input] maps each face id to its serialized embedding bytes. Every entry
/// is wrapped in a [FaceInfoForClustering] with a face score just above
/// [kMinimumQualityFaceScore] and the default blur value, and the same
/// [distanceThreshold] is passed as both the regular and the conservative
/// threshold.
///
/// No old cluster summaries are passed along, so the returned
/// [ClusteringResult.newClusterSummaries] is `null`.
///
/// Returns `null` when [input] is empty or when the clustering itself yields
/// no result. Errors raised by the computation are logged and rethrown.
Future<ClusteringResult?> predictLinearComputer(
  Map<String, Uint8List> input, {
  Map<int, int>? fileIDToCreationTime,
  double distanceThreshold = kRecommendedDistanceThreshold,
}) async {
  if (input.isEmpty) {
    _logger.warning(
      "Linear Clustering dataset of embeddings is empty, returning null.",
    );
    return null;
  }

  // Clustering inside the isolate
  _logger.info(
    "Start Linear clustering on ${input.length} embeddings inside computer isolate",
  );

  try {
    final clusteringInput = <FaceInfoForClustering>{
      for (final entry in input.entries)
        FaceInfoForClustering(
          faceID: entry.key,
          embeddingBytes: entry.value,
          faceScore: kMinimumQualityFaceScore + 0.01,
          blurValue: kLapacianDefault,
        ),
    };

    final startTime = DateTime.now();
    // `runLinearClustering` is declared to return a nullable result, so the
    // cast must accept null instead of throwing a type error.
    final faceIdToCluster = await _computer.compute(
      runLinearClustering,
      param: {
        "input": clusteringInput,
        "fileIDToCreationTime": fileIDToCreationTime,
        "distanceThreshold": distanceThreshold,
        "conservativeDistanceThreshold": distanceThreshold,
        "useDynamicThreshold": false,
      },
      taskName: "runLinearClustering",
    ) as ClusteringResult?;
    final endTime = DateTime.now();
    _logger.info(
      "Linear Clustering took: ${endTime.difference(startTime).inMilliseconds}ms",
    );
    return faceIdToCluster;
  } catch (e, s) {
    _logger.severe(e, s);
    rethrow;
  }
}
/// Runs [runCompleteClustering] on [input] through the shared computer.
///
/// [input] maps each face id to its serialized embedding bytes.
/// [distanceThreshold] is the threshold used when assigning faces to
/// clusters, and [mergeThreshold] is the threshold used when merging the
/// resulting clusters by their mean embedding.
///
/// Returns a map from face id to cluster id, or an empty map when [input] is
/// empty. Errors raised by the computation are logged and rethrown.
///
/// WARNING: Only use on small datasets, as it is not optimized for large datasets.
Future<Map<String, int>> predictCompleteComputer(
  Map<String, Uint8List> input, {
  Map<int, int>? fileIDToCreationTime,
  double distanceThreshold = kRecommendedDistanceThreshold,
  double mergeThreshold = 0.30,
}) async {
  if (input.isEmpty) {
    _logger.warning(
      "Complete Clustering dataset of embeddings is empty, returning empty map.",
    );
    return {};
  }

  // Clustering inside the isolate
  _logger.info(
    "Start Complete clustering on ${input.length} embeddings inside computer isolate",
  );

  try {
    final startTime = DateTime.now();
    final faceIdToCluster = await _computer.compute(
      runCompleteClustering,
      param: {
        "input": input,
        "fileIDToCreationTime": fileIDToCreationTime,
        "distanceThreshold": distanceThreshold,
        "mergeThreshold": mergeThreshold,
      },
      // Label the task after the function it actually runs.
      taskName: "runCompleteClustering",
    ) as Map<String, int>;
    final endTime = DateTime.now();
    _logger.info(
      "Complete Clustering took: ${endTime.difference(startTime).inMilliseconds}ms",
    );
    return faceIdToCluster;
  } catch (e, s) {
    _logger.severe(e, s);
    rethrow;
  }
}
/// Clusters the faces in [input], choosing the algorithm by input size.
///
/// Fewer than 100 faces go through [predictCompleteComputer], using
/// [distanceThreshold] for clustering and `distanceThreshold + 0.06` for
/// merging. Larger inputs go through [predictLinearComputer] with
/// [distanceThreshold].
///
/// Returns a map from face id to cluster id. The linear path may return
/// `null`. Errors are logged and rethrown.
Future<Map<String, int>?> predictWithinClusterComputer(
  Map<String, Uint8List> input, {
  Map<int, int>? fileIDToCreationTime,
  double distanceThreshold = kRecommendedDistanceThreshold,
}) async {
  _logger.info(
    '`predictWithinClusterComputer` called with ${input.length} faces and distance threshold $distanceThreshold',
  );
  try {
    if (input.length < 100) {
      final mergeThreshold = distanceThreshold + 0.06;
      _logger.info(
        'Running complete clustering on ${input.length} faces with distance threshold $distanceThreshold and merge threshold $mergeThreshold',
      );
      // Forward the caller's threshold. Without it the complete clustering
      // falls back to its own default and ignores the requested value.
      return await predictCompleteComputer(
        input,
        fileIDToCreationTime: fileIDToCreationTime,
        distanceThreshold: distanceThreshold,
        mergeThreshold: mergeThreshold,
      );
    }

    _logger.info(
      'Running linear clustering on ${input.length} faces with distance threshold $distanceThreshold',
    );
    final clusterResult = await predictLinearComputer(
      input,
      fileIDToCreationTime: fileIDToCreationTime,
      distanceThreshold: distanceThreshold,
    );
    return clusterResult?.newFaceIdToCluster;
  } catch (e, s) {
    _logger.severe(e, s);
    rethrow;
  }
}
Future<List<List<String>>> predictDbscan( Future<List<List<String>>> predictDbscan(
Map<String, Uint8List> input, { Map<String, Uint8List> input, {
Map<int, int>? fileIDToCreationTime, Map<int, int>? fileIDToCreationTime,
@ -299,29 +442,42 @@ class FaceClusteringService {
return clusterFaceIDs; return clusterFaceIDs;
} }
static Map<String, int> _runLinearClustering( static ClusteringResult? runLinearClustering(Map args) {
Map<String, (int?, Uint8List)> x, { // final input = args['input'] as Map<String, (int?, Uint8List)>;
Map<int, int>? fileIDToCreationTime, final input = args['input'] as Set<FaceInfoForClustering>;
double distanceThreshold = kRecommendedDistanceThreshold, final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
int? offset, final distanceThreshold = args['distanceThreshold'] as double;
}) { final conservativeDistanceThreshold =
args['conservativeDistanceThreshold'] as double;
final useDynamicThreshold = args['useDynamicThreshold'] as bool;
final offset = args['offset'] as int?;
final oldClusterSummaries =
args['oldClusterSummaries'] as Map<int, (Uint8List, int)>?;
log( log(
"[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces", "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces",
); );
// Organize everything into a list of FaceInfo objects // Organize everything into a list of FaceInfo objects
final List<FaceInfo> faceInfos = []; final List<FaceInfo> faceInfos = [];
for (final entry in x.entries) { for (final face in input) {
faceInfos.add( faceInfos.add(
FaceInfo( FaceInfo(
faceID: entry.key, faceID: face.faceID,
faceScore: face.faceScore,
blurValue: face.blurValue,
badFace: face.faceScore < kMinimumQualityFaceScore ||
face.blurValue < kLaplacianSoftThreshold ||
(face.blurValue < kLaplacianVerySoftThreshold &&
face.faceScore < kMediumQualityFaceScore) ||
face.isSideways,
vEmbedding: Vector.fromList( vEmbedding: Vector.fromList(
EVector.fromBuffer(entry.value.$2).values, EVector.fromBuffer(face.embeddingBytes).values,
dtype: DType.float32, dtype: DType.float32,
), ),
clusterId: entry.value.$1, clusterId: face.clusterId,
fileCreationTime: fileCreationTime:
fileIDToCreationTime?[getFileIdFromFaceId(entry.key)], fileIDToCreationTime?[getFileIdFromFaceId(face.faceID)],
), ),
); );
} }
@ -351,19 +507,21 @@ class FaceClusteringService {
facesWithClusterID.add(faceInfo); facesWithClusterID.add(faceInfo);
} }
} }
final alreadyClusteredCount = facesWithClusterID.length;
final sortedFaceInfos = <FaceInfo>[]; final sortedFaceInfos = <FaceInfo>[];
sortedFaceInfos.addAll(facesWithClusterID); sortedFaceInfos.addAll(facesWithClusterID);
sortedFaceInfos.addAll(facesWithoutClusterID); sortedFaceInfos.addAll(facesWithoutClusterID);
log( log(
"[ClusterIsolate] ${DateTime.now()} Clustering ${facesWithoutClusterID.length} new faces without clusterId, and ${facesWithClusterID.length} faces with clusterId", "[ClusterIsolate] ${DateTime.now()} Clustering ${facesWithoutClusterID.length} new faces without clusterId, and $alreadyClusteredCount faces with clusterId",
); );
// Make sure the first face has a clusterId // Make sure the first face has a clusterId
final int totalFaces = sortedFaceInfos.length; final int totalFaces = sortedFaceInfos.length;
int dynamicThresholdCount = 0;
if (sortedFaceInfos.isEmpty) { if (sortedFaceInfos.isEmpty) {
return {}; return null;
} }
// Start actual clustering // Start actual clustering
@ -377,7 +535,6 @@ class FaceClusteringService {
sortedFaceInfos[0].clusterId = clusterID; sortedFaceInfos[0].clusterId = clusterID;
clusterID++; clusterID++;
} }
final Map<String, int> newFaceIdToCluster = {};
final stopwatchClustering = Stopwatch()..start(); final stopwatchClustering = Stopwatch()..start();
for (int i = 1; i < totalFaces; i++) { for (int i = 1; i < totalFaces; i++) {
// Incremental clustering, so we can skip faces that already have a clusterId // Incremental clustering, so we can skip faces that already have a clusterId
@ -388,6 +545,15 @@ class FaceClusteringService {
int closestIdx = -1; int closestIdx = -1;
double closestDistance = double.infinity; double closestDistance = double.infinity;
late double thresholdValue;
if (useDynamicThreshold) {
thresholdValue = sortedFaceInfos[i].badFace!
? conservativeDistanceThreshold
: distanceThreshold;
if (sortedFaceInfos[i].badFace!) dynamicThresholdCount++;
} else {
thresholdValue = distanceThreshold;
}
if (i % 250 == 0) { if (i % 250 == 0) {
log("[ClusterIsolate] ${DateTime.now()} Processed ${offset != null ? i + offset : i} faces"); log("[ClusterIsolate] ${DateTime.now()} Processed ${offset != null ? i + offset : i} faces");
} }
@ -405,18 +571,16 @@ class FaceClusteringService {
); );
} }
if (distance < closestDistance) { if (distance < closestDistance) {
if (sortedFaceInfos[j].badFace! &&
distance > conservativeDistanceThreshold) {
continue;
}
closestDistance = distance; closestDistance = distance;
closestIdx = j; closestIdx = j;
// if (distance < distanceThreshold) {
// if (sortedFaceInfos[j].faceID.startsWith("14914702") ||
// sortedFaceInfos[j].faceID.startsWith("15488756")) {
// log('[XXX] faceIDs: ${sortedFaceInfos[j].faceID} and ${sortedFaceInfos[i].faceID} with distance $distance');
// }
// }
} }
} }
if (closestDistance < distanceThreshold) { if (closestDistance < thresholdValue) {
if (sortedFaceInfos[closestIdx].clusterId == null) { if (sortedFaceInfos[closestIdx].clusterId == null) {
// Ideally this should never happen, but just in case log it // Ideally this should never happen, but just in case log it
log( log(
@ -424,42 +588,99 @@ class FaceClusteringService {
); );
clusterID++; clusterID++;
sortedFaceInfos[closestIdx].clusterId = clusterID; sortedFaceInfos[closestIdx].clusterId = clusterID;
newFaceIdToCluster[sortedFaceInfos[closestIdx].faceID] = clusterID;
} }
// if (sortedFaceInfos[i].faceID.startsWith("14914702") ||
// sortedFaceInfos[i].faceID.startsWith("15488756")) {
// log(
// "[XXX] [ClusterIsolate] ${DateTime.now()} Found similar face ${sortedFaceInfos[i].faceID} to ${sortedFaceInfos[closestIdx].faceID} with distance $closestDistance",
// );
// }
sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId; sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId;
newFaceIdToCluster[sortedFaceInfos[i].faceID] =
sortedFaceInfos[closestIdx].clusterId!;
} else { } else {
// if (sortedFaceInfos[i].faceID.startsWith("14914702") ||
// sortedFaceInfos[i].faceID.startsWith("15488756")) {
// log(
// "[XXX] [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID for face ${sortedFaceInfos[i].faceID}",
// );
// }
clusterID++; clusterID++;
sortedFaceInfos[i].clusterId = clusterID; sortedFaceInfos[i].clusterId = clusterID;
newFaceIdToCluster[sortedFaceInfos[i].faceID] = clusterID;
} }
} }
// Finally, assign the new clusterId to the faces
final Map<String, int> newFaceIdToCluster = {};
final newClusteredFaceInfos =
sortedFaceInfos.sublist(alreadyClusteredCount);
for (final faceInfo in newClusteredFaceInfos) {
newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!;
}
stopwatchClustering.stop(); stopwatchClustering.stop();
log( log(
' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms', ' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms',
); );
if (useDynamicThreshold) {
log(
"[ClusterIsolate] ${DateTime.now()} Dynamic thresholding: $dynamicThresholdCount faces had a low face score or low blur clarity",
);
}
// Now calculate the mean of the embeddings for each cluster and update the cluster summaries
Map<int, (Uint8List, int)>? newClusterSummaries;
if (oldClusterSummaries != null) {
newClusterSummaries = FaceClusteringService.updateClusterSummaries(
oldSummary: oldClusterSummaries,
newFaceInfos: newClusteredFaceInfos,
);
}
// analyze the results // analyze the results
FaceClusteringService._analyzeClusterResults(sortedFaceInfos); FaceClusteringService._analyzeClusterResults(sortedFaceInfos);
return newFaceIdToCluster; return ClusteringResult(
newFaceIdToCluster: newFaceIdToCluster,
newClusterSummaries: newClusterSummaries,
);
}
/// Computes fresh `(serialized mean embedding, face count)` summaries for
/// every cluster that received at least one face from [newFaceInfos].
///
/// When [oldSummary] already holds an entry for a cluster, its mean is
/// weighted by its stored count and combined with the new embeddings.
/// Otherwise the summary is the plain mean of the new embeddings. Clusters
/// without any face in [newFaceInfos] do not appear in the result.
///
/// Every face in [newFaceInfos] must have a non-null `clusterId` and
/// `vEmbedding`.
static Map<int, (Uint8List, int)> updateClusterSummaries({
  required Map<int, (Uint8List, int)> oldSummary,
  required List<FaceInfo> newFaceInfos,
}) {
  final startedAt = DateTime.now();

  // Bucket the new embeddings per cluster id, in first-seen order.
  final Map<int, List<Vector>> embeddingsByCluster = {};
  for (final face in newFaceInfos) {
    (embeddingsByCluster[face.clusterId!] ??= []).add(face.vEmbedding!);
  }

  final Map<int, (Uint8List, int)> summaries = {};
  for (final bucket in embeddingsByCluster.entries) {
    final clusterId = bucket.key;
    Vector embeddingSum = bucket.value.reduce((a, b) => a + b);
    int faceCount = bucket.value.length;

    final previous = oldSummary[clusterId];
    if (previous != null) {
      // Undo the old averaging (mean * count) so that old and new faces
      // carry equal weight in the combined mean.
      final previousMean = Vector.fromList(
        EVector.fromBuffer(previous.$1).values,
        dtype: DType.float32,
      );
      embeddingSum = embeddingSum + previousMean * previous.$2;
      faceCount = previous.$2 + faceCount;
    }

    final meanVector = embeddingSum / faceCount;
    summaries[clusterId] = (
      EVector(values: meanVector.toList()).writeToBuffer(),
      faceCount,
    );
  }

  log(
    "[ClusterIsolate] ${DateTime.now()} Calculated cluster summaries in ${DateTime.now().difference(startedAt).inMilliseconds}ms",
  );

  return summaries;
}
static void _analyzeClusterResults(List<FaceInfo> sortedFaceInfos) { static void _analyzeClusterResults(List<FaceInfo> sortedFaceInfos) {
if (!kDebugMode) return;
final stopwatch = Stopwatch()..start(); final stopwatch = Stopwatch()..start();
final Map<String, int> faceIdToCluster = {}; final Map<String, int> faceIdToCluster = {};
@ -517,14 +738,185 @@ class FaceClusteringService {
); );
} }
static List<List<String>> _runDbscanClustering( static Map<String, int> runCompleteClustering(Map args) {
Map<String, Uint8List> x, { final input = args['input'] as Map<String, Uint8List>;
Map<int, int>? fileIDToCreationTime, final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
double eps = 0.3, final distanceThreshold = args['distanceThreshold'] as double;
int minPts = 5, final mergeThreshold = args['mergeThreshold'] as double;
}) {
log( log(
"[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces", "[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering",
);
// Organize everything into a list of FaceInfo objects
final List<FaceInfo> faceInfos = [];
for (final entry in input.entries) {
faceInfos.add(
FaceInfo(
faceID: entry.key,
vEmbedding: Vector.fromList(
EVector.fromBuffer(entry.value).values,
dtype: DType.float32,
),
fileCreationTime:
fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
),
);
}
// Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
if (fileIDToCreationTime != null) {
faceInfos.sort((a, b) {
if (a.fileCreationTime == null && b.fileCreationTime == null) {
return 0;
} else if (a.fileCreationTime == null) {
return 1;
} else if (b.fileCreationTime == null) {
return -1;
} else {
return a.fileCreationTime!.compareTo(b.fileCreationTime!);
}
});
}
if (faceInfos.isEmpty) {
return {};
}
final int totalFaces = faceInfos.length;
// Start actual clustering
log(
"[CompleteClustering] ${DateTime.now()} Processing $totalFaces faces in one single round of complete clustering",
);
// set current epoch time as clusterID
int clusterID = DateTime.now().microsecondsSinceEpoch;
// Start actual clustering
final Map<String, int> newFaceIdToCluster = {};
final stopwatchClustering = Stopwatch()..start();
for (int i = 0; i < totalFaces; i++) {
if ((i + 1) % 250 == 0) {
log("[CompleteClustering] ${DateTime.now()} Processed ${i + 1} faces");
}
if (faceInfos[i].clusterId != null) continue;
int closestIdx = -1;
double closestDistance = double.infinity;
for (int j = 0; j < totalFaces; j++) {
if (i == j) continue;
final double distance =
1.0 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!);
if (distance < closestDistance) {
closestDistance = distance;
closestIdx = j;
}
}
if (closestDistance < distanceThreshold) {
if (faceInfos[closestIdx].clusterId == null) {
clusterID++;
faceInfos[closestIdx].clusterId = clusterID;
}
faceInfos[i].clusterId = faceInfos[closestIdx].clusterId!;
} else {
clusterID++;
faceInfos[i].clusterId = clusterID;
}
}
// Now calculate the mean of the embeddings for each cluster
final Map<int, List<FaceInfo>> clusterIdToFaceInfos = {};
for (final faceInfo in faceInfos) {
if (clusterIdToFaceInfos.containsKey(faceInfo.clusterId)) {
clusterIdToFaceInfos[faceInfo.clusterId]!.add(faceInfo);
} else {
clusterIdToFaceInfos[faceInfo.clusterId!] = [faceInfo];
}
}
final Map<int, (Vector, int)> clusterIdToMeanEmbeddingAndWeight = {};
for (final clusterId in clusterIdToFaceInfos.keys) {
final List<Vector> embeddings = clusterIdToFaceInfos[clusterId]!
.map((faceInfo) => faceInfo.vEmbedding!)
.toList();
final count = clusterIdToFaceInfos[clusterId]!.length;
final Vector meanEmbedding = embeddings.reduce((a, b) => a + b) / count;
clusterIdToMeanEmbeddingAndWeight[clusterId] = (meanEmbedding, count);
}
// Now merge the clusters that are close to each other, based on mean embedding
final List<(int, int)> mergedClustersList = [];
final List<int> clusterIds =
clusterIdToMeanEmbeddingAndWeight.keys.toList();
log(' [CompleteClustering] ${DateTime.now()} ${clusterIds.length} clusters found, now checking for merges');
while (true) {
if (clusterIds.length < 2) break;
double distance = double.infinity;
(int, int) clusterIDsToMerge = (-1, -1);
for (int i = 0; i < clusterIds.length; i++) {
for (int j = 0; j < clusterIds.length; j++) {
if (i == j) continue;
final double newDistance = 1.0 -
clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1.dot(
clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
);
if (newDistance < distance) {
distance = newDistance;
clusterIDsToMerge = (clusterIds[i], clusterIds[j]);
}
}
}
if (distance < mergeThreshold) {
mergedClustersList.add(clusterIDsToMerge);
final clusterID1 = clusterIDsToMerge.$1;
final clusterID2 = clusterIDsToMerge.$2;
final mean1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$1;
final mean2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$1;
final count1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$2;
final count2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$2;
final weight1 = count1 / (count1 + count2);
final weight2 = count2 / (count1 + count2);
clusterIdToMeanEmbeddingAndWeight[clusterID1] = (
mean1 * weight1 + mean2 * weight2,
count1 + count2,
);
clusterIdToMeanEmbeddingAndWeight.remove(clusterID2);
clusterIds.remove(clusterID2);
} else {
break;
}
}
log(' [CompleteClustering] ${DateTime.now()} ${mergedClustersList.length} clusters merged');
// Now assign the new clusterId to the faces
for (final faceInfo in faceInfos) {
for (final mergedClusters in mergedClustersList) {
if (faceInfo.clusterId == mergedClusters.$2) {
faceInfo.clusterId = mergedClusters.$1;
}
}
}
// Finally, assign the new clusterId to the faces
for (final faceInfo in faceInfos) {
newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!;
}
stopwatchClustering.stop();
log(
' [CompleteClustering] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms',
);
return newFaceIdToCluster;
}
static List<List<String>> _runDbscanClustering(Map args) {
final input = args['input'] as Map<String, Uint8List>;
final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
final eps = args['eps'] as double;
final minPts = args['minPts'] as int;
log(
"[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces",
); );
final DBSCAN dbscan = DBSCAN( final DBSCAN dbscan = DBSCAN(
@ -535,7 +927,7 @@ class FaceClusteringService {
// Organize everything into a list of FaceInfo objects // Organize everything into a list of FaceInfo objects
final List<FaceInfo> faceInfos = []; final List<FaceInfo> faceInfos = [];
for (final entry in x.entries) { for (final entry in input.entries) {
faceInfos.add( faceInfos.add(
FaceInfo( FaceInfo(
faceID: entry.key, faceID: entry.key,

View file

@ -0,0 +1,19 @@
import "dart:typed_data" show Uint8List;
/// Input record describing one face handed to the clustering code.
class FaceInfoForClustering {
  /// Identifier of the face.
  final String faceID;

  /// Serialized embedding of the face.
  final Uint8List embeddingBytes;

  /// Detection score of the face; clustering compares it against the face
  /// quality score constants.
  final double faceScore;

  /// Blur measure of the face; clustering compares it against the Laplacian
  /// blur thresholds.
  final double blurValue;

  /// Whether the face is flagged as sideways. Defaults to `false`.
  final bool isSideways;

  /// Cluster the face already belongs to, or `null` when it has none yet.
  final int? clusterId;

  FaceInfoForClustering({
    required this.faceID,
    required this.embeddingBytes,
    required this.faceScore,
    required this.blurValue,
    this.isSideways = false,
    this.clusterId,
  });
}

View file

@ -1,7 +1,24 @@
import 'dart:math' show sqrt, pow; import 'dart:math' show max, min, pow, sqrt;
import "package:photos/face/model/dimension.dart"; import "package:photos/face/model/dimension.dart";
/// Direction classification of a detected face.
enum FaceDirection { left, right, straight }

extension FaceDirectionExtension on FaceDirection {
  /// Returns the capitalized name of this direction: `'Left'`, `'Right'` or
  /// `'Straight'`.
  ///
  /// The switch is exhaustive over the enum, so no fallback branch is needed
  /// and adding a new value becomes a compile-time error here.
  String toDirectionString() => switch (this) {
        FaceDirection.left => 'Left',
        FaceDirection.right => 'Right',
        FaceDirection.straight => 'Straight',
      };
}
abstract class Detection { abstract class Detection {
final double score; final double score;
@ -16,6 +33,7 @@ abstract class Detection {
String toString(); String toString();
} }
@Deprecated('Old method only used in other deprecated methods')
extension BBoxExtension on List<double> { extension BBoxExtension on List<double> {
void roundBoxToDouble() { void roundBoxToDouble() {
final widthRounded = (this[2] - this[0]).roundToDouble(); final widthRounded = (this[2] - this[0]).roundToDouble();
@ -425,6 +443,37 @@ class FaceDetectionAbsolute extends Detection {
/// The height of the bounding box of the face detection, in number of pixels, range [0, imageHeight]. /// The height of the bounding box of the face detection, in number of pixels, range [0, imageHeight].
double get height => yMaxBox - yMinBox; double get height => yMaxBox - yMinBox;
/// Classifies the face as left, right or straight from its landmarks.
///
/// Index 0 of each landmark is used as the x coordinate and index 1 as the
/// y coordinate. The face is:
///  * left when the nose x lies below the x of both eyes and both mouth
///    corners, or when the face is upright and the nose x is within 20% of
///    the horizontal eye distance from the left eye;
///  * right under the mirrored conditions (checked second);
///  * straight otherwise.
FaceDirection getFaceDirection() {
  final eyeGapX = (rightEye[0] - leftEye[0]).abs();
  final eyeGapY = (rightEye[1] - leftEye[1]).abs();
  final mouthGapY = (rightMouth[1] - leftMouth[1]).abs();
  final noseX = nose[0];
  final noseY = nose[1];

  // Upright: the nose y sits beyond the eyes and short of the mouth corners,
  // each with a margin of half the respective vertical spread.
  final bool upright =
      (max(leftEye[1], rightEye[1]) + 0.5 * eyeGapY < noseY) &&
          (noseY + 0.5 * mouthGapY < min(leftMouth[1], rightMouth[1]));
  final nearEyeLimit = 0.2 * eyeGapX;

  // Uprightness is only required for the "nose close to an eye" rule; a nose
  // sticking out past all other landmarks counts regardless of orientation.
  final bool noseBeyondLeft = (noseX < min(leftEye[0], rightEye[0])) &&
      (noseX < min(leftMouth[0], rightMouth[0]));
  final bool noseNearLeftEye = (noseX - leftEye[0]).abs() < nearEyeLimit;
  if (noseBeyondLeft || (upright && noseNearLeftEye)) {
    return FaceDirection.left;
  }

  final bool noseBeyondRight = (noseX > max(leftEye[0], rightEye[0])) &&
      (noseX > max(leftMouth[0], rightMouth[0]));
  final bool noseNearRightEye = (noseX - rightEye[0]).abs() < nearEyeLimit;
  if (noseBeyondRight || (upright && noseNearRightEye)) {
    return FaceDirection.right;
  }

  return FaceDirection.straight;
}
} }
List<FaceDetectionAbsolute> relativeToAbsoluteDetections({ List<FaceDetectionAbsolute> relativeToAbsoluteDetections({

View file

@ -1,4 +1,5 @@
import 'package:logging/logging.dart'; import 'package:logging/logging.dart';
import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart";
import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
class BlurDetectionService { class BlurDetectionService {
@ -11,9 +12,11 @@ class BlurDetectionService {
Future<(bool, double)> predictIsBlurGrayLaplacian( Future<(bool, double)> predictIsBlurGrayLaplacian(
List<List<int>> grayImage, { List<List<int>> grayImage, {
int threshold = kLaplacianThreshold, int threshold = kLaplacianHardThreshold,
FaceDirection faceDirection = FaceDirection.straight,
}) async { }) async {
final List<List<int>> laplacian = _applyLaplacian(grayImage); final List<List<int>> laplacian =
_applyLaplacian(grayImage, faceDirection: faceDirection);
final double variance = _calculateVariance(laplacian); final double variance = _calculateVariance(laplacian);
_logger.info('Variance: $variance'); _logger.info('Variance: $variance');
return (variance < threshold, variance); return (variance < threshold, variance);
@ -46,43 +49,80 @@ class BlurDetectionService {
return variance; return variance;
} }
List<List<int>> _padImage(List<List<int>> image) { List<List<int>> _padImage(
List<List<int>> image, {
int removeSideColumns = 56,
FaceDirection faceDirection = FaceDirection.straight,
}) {
// Exception is removeSideColumns is not even
if (removeSideColumns % 2 != 0) {
throw Exception('removeSideColumns must be even');
}
final int numRows = image.length; final int numRows = image.length;
final int numCols = image[0].length; final int numCols = image[0].length;
final int paddedNumCols = numCols + 2 - removeSideColumns;
final int paddedNumRows = numRows + 2;
// Create a new matrix with extra padding // Create a new matrix with extra padding
final List<List<int>> paddedImage = List.generate( final List<List<int>> paddedImage = List.generate(
numRows + 2, paddedNumRows,
(i) => List.generate(numCols + 2, (j) => 0, growable: false), (i) => List.generate(
paddedNumCols,
(j) => 0,
growable: false,
),
growable: false, growable: false,
); );
// Copy original image into the center of the padded image // Copy original image into the center of the padded image, taking into account the face direction
if (faceDirection == FaceDirection.straight) {
for (int i = 0; i < numRows; i++) { for (int i = 0; i < numRows; i++) {
for (int j = 0; j < numCols; j++) { for (int j = 0; j < (paddedNumCols - 2); j++) {
paddedImage[i + 1][j + 1] =
image[i][j + (removeSideColumns / 2).round()];
}
}
// If the face is facing left, we only take the right side of the face image
} else if (faceDirection == FaceDirection.left) {
for (int i = 0; i < numRows; i++) {
for (int j = 0; j < (paddedNumCols - 2); j++) {
paddedImage[i + 1][j + 1] = image[i][j + removeSideColumns];
}
}
// If the face is facing right, we only take the left side of the face image
} else if (faceDirection == FaceDirection.right) {
for (int i = 0; i < numRows; i++) {
for (int j = 0; j < (paddedNumCols - 2); j++) {
paddedImage[i + 1][j + 1] = image[i][j]; paddedImage[i + 1][j + 1] = image[i][j];
} }
} }
}
// Reflect padding // Reflect padding
// Top and bottom rows // Top and bottom rows
for (int j = 1; j <= numCols; j++) { for (int j = 1; j <= (paddedNumCols - 2); j++) {
paddedImage[0][j] = paddedImage[2][j]; // Top row paddedImage[0][j] = paddedImage[2][j]; // Top row
paddedImage[numRows + 1][j] = paddedImage[numRows - 1][j]; // Bottom row paddedImage[numRows + 1][j] = paddedImage[numRows - 1][j]; // Bottom row
} }
// Left and right columns // Left and right columns
for (int i = 0; i < numRows + 2; i++) { for (int i = 0; i < numRows + 2; i++) {
paddedImage[i][0] = paddedImage[i][2]; // Left column paddedImage[i][0] = paddedImage[i][2]; // Left column
paddedImage[i][numCols + 1] = paddedImage[i][numCols - 1]; // Right column paddedImage[i][paddedNumCols - 1] =
paddedImage[i][paddedNumCols - 3]; // Right column
} }
return paddedImage; return paddedImage;
} }
List<List<int>> _applyLaplacian(List<List<int>> image) { List<List<int>> _applyLaplacian(
final List<List<int>> paddedImage = _padImage(image); List<List<int>> image, {
final int numRows = image.length; FaceDirection faceDirection = FaceDirection.straight,
final int numCols = image[0].length; }) {
final List<List<int>> paddedImage =
_padImage(image, faceDirection: faceDirection);
final int numRows = paddedImage.length - 2;
final int numCols = paddedImage[0].length - 2;
final List<List<int>> outputImage = List.generate( final List<List<int>> outputImage = List.generate(
numRows, numRows,
(i) => List.generate(numCols, (j) => 0, growable: false), (i) => List.generate(numCols, (j) => 0, growable: false),

View file

@ -1,13 +1,17 @@
import 'package:photos/services/machine_learning/face_ml/face_detection/face_detection_service.dart'; import 'package:photos/services/machine_learning/face_ml/face_detection/face_detection_service.dart';
/// Blur detection threshold /// Blur detection threshold
const kLaplacianThreshold = 15; const kLaplacianHardThreshold = 15;
const kLaplacianSoftThreshold = 100;
const kLaplacianVerySoftThreshold = 200;
/// Default blur value /// Default blur value
const kLapacianDefault = 10000.0; const kLapacianDefault = 10000.0;
/// The minimum score for a face to be considered a high quality face for clustering and person detection /// The minimum score for a face to be considered a high quality face for clustering and person detection
const kMinHighQualityFaceScore = 0.80; const kMinimumQualityFaceScore = 0.80;
const kMediumQualityFaceScore = 0.85;
const kHighQualityFaceScore = 0.90;
/// The minimum score for a face to be detected, regardless of quality. Use [kMinHighQualityFaceScore] for high quality faces. /// The minimum score for a face to be detected, regardless of quality. Use [kMinimumQualityFaceScore] for high quality faces.
const kMinFaceDetectionScore = FaceDetectionService.kMinScoreSigmoidThreshold; const kMinFaceDetectionScore = FaceDetectionService.kMinScoreSigmoidThreshold;

View file

@ -1,284 +1,18 @@
import "dart:convert" show jsonEncode, jsonDecode; import "dart:convert" show jsonEncode, jsonDecode;
import "package:flutter/material.dart" show debugPrint, immutable; import "package:flutter/material.dart" show immutable;
import "package:logging/logging.dart"; import "package:logging/logging.dart";
import "package:photos/face/model/dimension.dart"; import "package:photos/face/model/dimension.dart";
import "package:photos/models/file/file.dart"; import "package:photos/models/file/file.dart";
import 'package:photos/models/ml/ml_typedefs.dart'; import 'package:photos/models/ml/ml_typedefs.dart';
import "package:photos/models/ml/ml_versions.dart"; import "package:photos/models/ml/ml_versions.dart";
import 'package:photos/services/machine_learning/face_ml/face_alignment/alignment_result.dart'; import 'package:photos/services/machine_learning/face_ml/face_alignment/alignment_result.dart';
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart'; import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart';
import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
import 'package:photos/services/machine_learning/face_ml/face_ml_methods.dart'; import 'package:photos/services/machine_learning/face_ml/face_ml_methods.dart';
final _logger = Logger('ClusterResult_FaceMlResult'); final _logger = Logger('ClusterResult_FaceMlResult');
// TODO: should I add [faceMlVersion] and [clusterMlVersion] to the [ClusterResult] class?
/// In-memory description of one face cluster: the faces and files it
/// contains, its thumbnail face, an optional user-given name and the medoid
/// embedding with its distance threshold.
///
/// Can be serialized with [toJsonString] and restored with [fromJsonString].
@Deprecated('We are now just storing the cluster results directly in DB')
class ClusterResult {
  final int personId;

  /// Name given by the user, or `null` when none was set.
  String? userDefinedName;

  /// Whether [userDefinedName] has been set.
  bool get hasUserDefinedName => userDefinedName != null;

  String _thumbnailFaceId;

  /// Whether the thumbnail was explicitly chosen through [setThumbnailFaceId].
  bool thumbnailFaceIdIsUserDefined;

  // Parallel lists: `_faceIds[i]` is a face that belongs to file `_fileIds[i]`,
  // so a file with several faces in this cluster appears several times.
  final List<int> _fileIds;
  final List<String> _faceIds;

  final Embedding medoid;
  double medoidDistanceThreshold;

  /// The file IDs of the cluster with duplicates removed.
  List<int> get uniqueFileIds => _fileIds.toSet().toList();

  /// The raw file ID list, one entry per face.
  List<int> get fileIDsIncludingPotentialDuplicates => _fileIds;

  List<String> get faceIDs => _faceIds;

  String get thumbnailFaceId => _thumbnailFaceId;

  int get thumbnailFileId => getFileIdFromFaceId(_thumbnailFaceId);

  /// Sets the thumbnail faceId to the given faceId.
  /// Throws an exception if the faceId is not in the list of faceIds.
  set setThumbnailFaceId(String faceId) {
    if (!_faceIds.contains(faceId)) {
      throw Exception(
        "The faceId $faceId is not in the list of faceIds: $_faceIds",
      );
    }
    _thumbnailFaceId = faceId;
    thumbnailFaceIdIsUserDefined = true;
  }

  /// Sets the [userDefinedName] to the given [customName]
  set setUserDefinedName(String customName) {
    userDefinedName = customName;
  }

  /// Number of distinct files in the cluster.
  int get clusterSize => _fileIds.toSet().length;

  ClusterResult({
    required this.personId,
    required String thumbnailFaceId,
    required List<int> fileIds,
    required List<String> faceIds,
    required this.medoid,
    required this.medoidDistanceThreshold,
    this.userDefinedName,
    this.thumbnailFaceIdIsUserDefined = false,
  })  : _thumbnailFaceId = thumbnailFaceId,
        _faceIds = faceIds,
        _fileIds = fileIds;

  /// Appends the given files and faces; both lists must have the same length.
  void addFileIDsAndFaceIDs(List<int> fileIDs, List<String> faceIDs) {
    assert(fileIDs.length == faceIDs.length);
    _fileIds.addAll(fileIDs);
    _faceIds.addAll(faceIDs);
  }

  // TODO: Consider if we should recalculated the medoid and threshold when deleting or adding a file from the cluster
  /// Removes every entry of [fileId] (and the matching face) from the cluster
  /// and returns how many entries were removed.
  ///
  /// Throws an exception if [fileId] is not part of the cluster, or if the
  /// cluster ends up smaller than `minimumClusterSize` after the removal.
  int removeFileId(int fileId) {
    assert(_fileIds.length == _faceIds.length);
    if (!_fileIds.contains(fileId)) {
      throw Exception(
        "The fileId $fileId is not in the list of fileIds: $_fileIds, so it's not in the cluster and cannot be removed.",
      );
    }

    int removedCount = 0;
    for (var i = 0; i < _fileIds.length; i++) {
      if (_fileIds[i] == fileId) {
        assert(getFileIdFromFaceId(_faceIds[i]) == fileId);
        _fileIds.removeAt(i);
        _faceIds.removeAt(i);
        // `i + removedCount` is the position the entry had before any removal.
        debugPrint(
          "Removed fileId $fileId from cluster $personId at index ${i + removedCount}",
        );
        i--; // Adjust index due to removal
        removedCount++;
      }
    }

    _ensureClusterSizeIsAboveMinimum();

    return removedCount;
  }

  /// Adds [fileID] with an empty face ID unless the file is already present.
  ///
  /// Returns the number of entries added (0 or 1).
  int addFileID(int fileID) {
    assert(_fileIds.length == _faceIds.length);
    if (_fileIds.contains(fileID)) {
      return 0;
    }

    _fileIds.add(fileID);
    _faceIds.add(FaceDetectionRelative.toFaceIDEmpty(fileID: fileID));

    return 1;
  }

  /// Falls back to the first face of the cluster when the current thumbnail
  /// face is no longer part of it.
  void ensureThumbnailFaceIdIsInCluster() {
    if (!_faceIds.contains(_thumbnailFaceId)) {
      _thumbnailFaceId = _faceIds[0];
    }
  }

  void _ensureClusterSizeIsAboveMinimum() {
    if (clusterSize < minimumClusterSize) {
      throw Exception(
        "Cluster size is below minimum cluster size of $minimumClusterSize",
      );
    }
  }

  Map<String, dynamic> _toJson() => {
        'personId': personId,
        'thumbnailFaceId': _thumbnailFaceId,
        'fileIds': _fileIds,
        'faceIds': _faceIds,
        'medoid': medoid,
        'medoidDistanceThreshold': medoidDistanceThreshold,
        if (userDefinedName != null) 'userDefinedName': userDefinedName,
        'thumbnailFaceIdIsUserDefined': thumbnailFaceIdIsUserDefined,
      };

  String toJsonString() => jsonEncode(_toJson());

  static ClusterResult _fromJson(Map<String, dynamic> json) {
    return ClusterResult(
      personId: json['personId'] ?? -1,
      thumbnailFaceId: json['thumbnailFaceId'] ?? '',
      fileIds:
          (json['fileIds'] as List?)?.map((item) => item as int).toList() ?? [],
      faceIds:
          (json['faceIds'] as List?)?.map((item) => item as String).toList() ??
              [],
      // JSON numbers without a fractional part decode as `int`, so go through
      // `num` instead of casting straight to `double`.
      medoid: (json['medoid'] as List?)
              ?.map((item) => (item as num).toDouble())
              .toList() ??
          [],
      medoidDistanceThreshold:
          (json['medoidDistanceThreshold'] as num?)?.toDouble() ?? 0.0,
      userDefinedName: json['userDefinedName'],
      // Default to `false` (the constructor default) when the key is absent,
      // like every other field here, instead of failing the cast.
      thumbnailFaceIdIsUserDefined:
          (json['thumbnailFaceIdIsUserDefined'] as bool?) ?? false,
    );
  }

  /// Restores a [ClusterResult] from a string produced by [toJsonString].
  static ClusterResult fromJsonString(String jsonString) {
    return _fromJson(jsonDecode(jsonString));
  }
}
/// Mutable accumulator for the data of one cluster: its files, faces and
/// embeddings, plus the medoid and distance threshold derived from them.
class ClusterResultBuilder {
  int personId = -1;
  String? userDefinedName;
  String thumbnailFaceId = '';
  bool thumbnailFaceIdIsUserDefined = false;

  List<int> fileIds = <int>[];
  List<String> faceIds = <String>[];

  List<Embedding> embeddings = <Embedding>[];
  Embedding medoid = <double>[];
  double medoidDistanceThreshold = 0;

  /// Whether [calculateAndSetMedoidAndThreshold] has completed.
  bool medoidAndThresholdCalculated = false;

  /// The threshold is the distance from the medoid to its k-th nearest
  /// neighbor (or to the farthest one when there are fewer than k).
  final int k = 5;

  /// Builds the cluster from the entries of the `all*` lists selected by
  /// [clusterIndices].
  ///
  /// The person ID is the label of the first selected index and the first
  /// selected face becomes the thumbnail.
  ClusterResultBuilder.createFromIndices({
    required List<int> clusterIndices,
    required List<int> labels,
    required List<Embedding> allEmbeddings,
    required List<int> allFileIds,
    required List<String> allFaceIds,
  }) {
    final clusteredFileIds =
        clusterIndices.map((fileIndex) => allFileIds[fileIndex]).toList();
    final clusteredFaceIds =
        clusterIndices.map((fileIndex) => allFaceIds[fileIndex]).toList();
    final clusteredEmbeddings =
        clusterIndices.map((fileIndex) => allEmbeddings[fileIndex]).toList();
    personId = labels[clusterIndices[0]];
    fileIds = clusteredFileIds;
    faceIds = clusteredFaceIds;
    thumbnailFaceId = faceIds[0];
    embeddings = clusteredEmbeddings;
  }

  /// Computes [medoid] and [medoidDistanceThreshold] from [embeddings].
  ///
  /// Throws an exception when [embeddings] is empty.
  void calculateAndSetMedoidAndThreshold() {
    if (embeddings.isEmpty) {
      throw Exception("Cannot calculate medoid and threshold for empty list");
    }

    // Calculate the medoid and threshold
    final (tempMedoid, distanceThreshold) =
        _calculateMedoidAndDistanceTreshold(embeddings);

    // Update the medoid
    medoid = List.from(tempMedoid);

    // Update the medoidDistanceThreshold as the distance of the medoid to its k-th nearest neighbor
    medoidDistanceThreshold = distanceThreshold;

    medoidAndThresholdCalculated = true;
  }

  /// Returns the embedding with the smallest summed cosine distance to all
  /// others, together with its distance to the k-th nearest neighbor.
  (List<double>, double) _calculateMedoidAndDistanceTreshold(
    List<List<double>> embeddings,
  ) {
    double minDistance = double.infinity;
    List<double>? medoid;

    // Calculate the distance between all pairs
    for (int i = 0; i < embeddings.length; ++i) {
      double totalDistance = 0;
      for (int j = 0; j < embeddings.length; ++j) {
        if (i != j) {
          totalDistance += cosineDistance(embeddings[i], embeddings[j]);

          // Break early if we already exceed minDistance
          if (totalDistance > minDistance) {
            break;
          }
        }
      }

      // Find the minimum total distance
      if (totalDistance < minDistance) {
        minDistance = totalDistance;
        medoid = embeddings[i];
      }
    }

    // Now, calculate k-th nearest neighbor for the medoid
    final List<double> distancesToMedoid = [];
    for (List<double> embedding in embeddings) {
      // Lists compare by identity, so this skips only the medoid object itself.
      if (!identical(embedding, medoid)) {
        distancesToMedoid.add(cosineDistance(medoid!, embedding));
      }
    }
    distancesToMedoid.sort();

    // A cluster with a single embedding has no neighbors: without this guard
    // the lookup below would index position -1 and throw a RangeError.
    if (distancesToMedoid.isEmpty) {
      return (medoid!, 0.0);
    }

    // TODO: empirically find the best k. Probably it should be dynamic in some way, so for instance larger for larger clusters and smaller for smaller clusters, especially since there are a lot of really small clusters and a few really large ones.
    final double kthDistance = distancesToMedoid[
        distancesToMedoid.length >= k ? k - 1 : distancesToMedoid.length - 1];

    return (medoid!, kthDistance);
  }

  /// Makes [faceId] the thumbnail; throws an exception if it is not part of
  /// [faceIds].
  void changeThumbnailFaceId(String faceId) {
    if (!faceIds.contains(faceId)) {
      throw Exception(
        "The faceId $faceId is not in the list of faceIds: $faceIds",
      );
    }
    thumbnailFaceId = faceId;
  }

  /// Appends the given files and faces; both lists must have the same length.
  void addFileIDsAndFaceIDs(List<int> addedFileIDs, List<String> addedFaceIDs) {
    assert(addedFileIDs.length == addedFaceIDs.length);
    fileIds.addAll(addedFileIDs);
    faceIds.addAll(addedFaceIDs);
  }
}
@immutable @immutable
class FaceMlResult { class FaceMlResult {
final int fileId; final int fileId;
@ -504,7 +238,7 @@ class FaceResult {
final int fileId; final int fileId;
final String faceId; final String faceId;
bool get isBlurry => blurValue < kLaplacianThreshold; bool get isBlurry => blurValue < kLaplacianHardThreshold;
const FaceResult({ const FaceResult({
required this.detection, required this.detection,
@ -545,7 +279,7 @@ class FaceResultBuilder {
int fileId = -1; int fileId = -1;
String faceId = ''; String faceId = '';
bool get isBlurry => blurValue < kLaplacianThreshold; bool get isBlurry => blurValue < kLaplacianHardThreshold;
FaceResultBuilder({ FaceResultBuilder({
required this.fileId, required this.fileId,

View file

@ -204,83 +204,13 @@ class FaceMlService {
try { try {
switch (function) { switch (function) {
case FaceMlOperation.analyzeImage: case FaceMlOperation.analyzeImage:
final int enteFileID = args["enteFileID"] as int; final time = DateTime.now();
final String imagePath = args["filePath"] as String; final FaceMlResult result =
final int faceDetectionAddress = await FaceMlService.analyzeImageSync(args);
args["faceDetectionAddress"] as int;
final int faceEmbeddingAddress =
args["faceEmbeddingAddress"] as int;
final resultBuilder =
FaceMlResultBuilder.fromEnteFileID(enteFileID);
dev.log( dev.log(
"Start analyzing image with uploadedFileID: $enteFileID inside the isolate", "`analyzeImageSync` function executed in ${DateTime.now().difference(time).inMilliseconds} ms",
); );
final stopwatchTotal = Stopwatch()..start(); sendPort.send(result.toJsonString());
final stopwatch = Stopwatch()..start();
// Decode the image once to use for both face detection and alignment
final imageData = await File(imagePath).readAsBytes();
final image = await decodeImageFromData(imageData);
final ByteData imgByteData = await getByteDataFromImage(image);
dev.log('Reading and decoding image took '
'${stopwatch.elapsedMilliseconds} ms');
stopwatch.reset();
// Get the faces
final List<FaceDetectionRelative> faceDetectionResult =
await FaceMlService.detectFacesSync(
image,
imgByteData,
faceDetectionAddress,
resultBuilder: resultBuilder,
);
dev.log(
"${faceDetectionResult.length} faces detected with scores ${faceDetectionResult.map((e) => e.score).toList()}: completed `detectFacesSync` function, in "
"${stopwatch.elapsedMilliseconds} ms");
// If no faces were detected, return a result with no faces. Otherwise, continue.
if (faceDetectionResult.isEmpty) {
dev.log(
"No faceDetectionResult, Completed analyzing image with uploadedFileID $enteFileID, in "
"${stopwatch.elapsedMilliseconds} ms");
sendPort.send(resultBuilder.buildNoFaceDetected().toJsonString());
break;
}
stopwatch.reset();
// Align the faces
final Float32List faceAlignmentResult =
await FaceMlService.alignFacesSync(
image,
imgByteData,
faceDetectionResult,
resultBuilder: resultBuilder,
);
dev.log("Completed `alignFacesSync` function, in "
"${stopwatch.elapsedMilliseconds} ms");
stopwatch.reset();
// Get the embeddings of the faces
final embeddings = await FaceMlService.embedFacesSync(
faceAlignmentResult,
faceEmbeddingAddress,
resultBuilder: resultBuilder,
);
dev.log("Completed `embedFacesSync` function, in "
"${stopwatch.elapsedMilliseconds} ms");
stopwatch.stop();
stopwatchTotal.stop();
dev.log("Finished Analyze image (${embeddings.length} faces) with "
"uploadedFileID $enteFileID, in "
"${stopwatchTotal.elapsedMilliseconds} ms");
sendPort.send(resultBuilder.build().toJsonString());
break; break;
} }
} catch (e, stackTrace) { } catch (e, stackTrace) {
@ -361,7 +291,7 @@ class FaceMlService {
} }
Future<void> clusterAllImages({ Future<void> clusterAllImages({
double minFaceScore = kMinHighQualityFaceScore, double minFaceScore = kMinimumQualityFaceScore,
bool clusterInBuckets = true, bool clusterInBuckets = true,
}) async { }) async {
_logger.info("`clusterAllImages()` called"); _logger.info("`clusterAllImages()` called");
@ -370,6 +300,10 @@ class FaceMlService {
// Get a sense of the total number of faces in the database // Get a sense of the total number of faces in the database
final int totalFaces = await FaceMLDataDB.instance final int totalFaces = await FaceMLDataDB.instance
.getTotalFaceCount(minFaceScore: minFaceScore); .getTotalFaceCount(minFaceScore: minFaceScore);
// Get the current cluster statistics
final Map<int, (Uint8List, int)> oldClusterSummaries =
await FaceMLDataDB.instance.getAllClusterSummary();
if (clusterInBuckets) { if (clusterInBuckets) {
// read the creation times from Files DB, in a map from fileID to creation time // read the creation times from Files DB, in a map from fileID to creation time
final fileIDToCreationTime = final fileIDToCreationTime =
@ -382,14 +316,14 @@ class FaceMlService {
int bucket = 1; int bucket = 1;
while (true) { while (true) {
final faceIdToEmbeddingBucket = final faceInfoForClustering =
await FaceMLDataDB.instance.getFaceEmbeddingMap( await FaceMLDataDB.instance.getFaceInfoForClustering(
minScore: minFaceScore, minScore: minFaceScore,
maxFaces: bucketSize, maxFaces: bucketSize,
offset: offset, offset: offset,
batchSize: batchSize, batchSize: batchSize,
); );
if (faceIdToEmbeddingBucket.isEmpty) { if (faceInfoForClustering.isEmpty) {
_logger.warning( _logger.warning(
'faceIdToEmbeddingBucket is empty, this should ideally not happen as it should have stopped earlier. offset: $offset, totalFaces: $totalFaces', 'faceIdToEmbeddingBucket is empty, this should ideally not happen as it should have stopped earlier. offset: $offset, totalFaces: $totalFaces',
); );
@ -402,20 +336,24 @@ class FaceMlService {
break; break;
} }
final faceIdToCluster = final clusteringResult =
await FaceClusteringService.instance.predictLinear( await FaceClusteringService.instance.predictLinear(
faceIdToEmbeddingBucket, faceInfoForClustering,
fileIDToCreationTime: fileIDToCreationTime, fileIDToCreationTime: fileIDToCreationTime,
offset: offset, offset: offset,
oldClusterSummaries: oldClusterSummaries,
); );
if (faceIdToCluster == null) { if (clusteringResult == null) {
_logger.warning("faceIdToCluster is null"); _logger.warning("faceIdToCluster is null");
return; return;
} }
await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster); await FaceMLDataDB.instance
.updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster);
await FaceMLDataDB.instance
.clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
_logger.info( _logger.info(
'Done with clustering ${offset + faceIdToEmbeddingBucket.length} embeddings (${(100 * (offset + faceIdToEmbeddingBucket.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset', 'Done with clustering ${offset + faceInfoForClustering.length} embeddings (${(100 * (offset + faceInfoForClustering.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset',
); );
if (offset + bucketSize >= totalFaces) { if (offset + bucketSize >= totalFaces) {
_logger.info('All faces clustered'); _logger.info('All faces clustered');
@ -427,14 +365,14 @@ class FaceMlService {
} else { } else {
// Read all the embeddings from the database, in a map from faceID to embedding // Read all the embeddings from the database, in a map from faceID to embedding
final clusterStartTime = DateTime.now(); final clusterStartTime = DateTime.now();
final faceIdToEmbedding = final faceInfoForClustering =
await FaceMLDataDB.instance.getFaceEmbeddingMap( await FaceMLDataDB.instance.getFaceInfoForClustering(
minScore: minFaceScore, minScore: minFaceScore,
maxFaces: totalFaces, maxFaces: totalFaces,
); );
final gotFaceEmbeddingsTime = DateTime.now(); final gotFaceEmbeddingsTime = DateTime.now();
_logger.info( _logger.info(
'read embeddings ${faceIdToEmbedding.length} in ${gotFaceEmbeddingsTime.difference(clusterStartTime).inMilliseconds} ms', 'read embeddings ${faceInfoForClustering.length} in ${gotFaceEmbeddingsTime.difference(clusterStartTime).inMilliseconds} ms',
); );
// Read the creation times from Files DB, in a map from fileID to creation time // Read the creation times from Files DB, in a map from fileID to creation time
@ -444,25 +382,29 @@ class FaceMlService {
'${DateTime.now().difference(gotFaceEmbeddingsTime).inMilliseconds} ms'); '${DateTime.now().difference(gotFaceEmbeddingsTime).inMilliseconds} ms');
// Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID // Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID
final faceIdToCluster = final clusteringResult =
await FaceClusteringService.instance.predictLinear( await FaceClusteringService.instance.predictLinear(
faceIdToEmbedding, faceInfoForClustering,
fileIDToCreationTime: fileIDToCreationTime, fileIDToCreationTime: fileIDToCreationTime,
oldClusterSummaries: oldClusterSummaries,
); );
if (faceIdToCluster == null) { if (clusteringResult == null) {
_logger.warning("faceIdToCluster is null"); _logger.warning("faceIdToCluster is null");
return; return;
} }
final clusterDoneTime = DateTime.now(); final clusterDoneTime = DateTime.now();
_logger.info( _logger.info(
'done with clustering ${faceIdToEmbedding.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ', 'done with clustering ${faceInfoForClustering.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ',
); );
// Store the updated clusterIDs in the database // Store the updated clusterIDs in the database
_logger.info( _logger.info(
'Updating ${faceIdToCluster.length} FaceIDs with clusterIDs in the DB', 'Updating ${clusteringResult.newFaceIdToCluster.length} FaceIDs with clusterIDs in the DB',
); );
await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster); await FaceMLDataDB.instance
.updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster);
await FaceMLDataDB.instance
.clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
_logger.info('Done updating FaceIDs with clusterIDs in the DB, in ' _logger.info('Done updating FaceIDs with clusterIDs in the DB, in '
'${DateTime.now().difference(clusterDoneTime).inSeconds} seconds'); '${DateTime.now().difference(clusterDoneTime).inSeconds} seconds');
} }
@ -875,6 +817,7 @@ class FaceMlService {
} }
} }
/// Analyzes the given image data by running the full pipeline for faces, using [analyzeImageSync] in the isolate.
Future<FaceMlResult?> analyzeImageInSingleIsolate(EnteFile enteFile) async { Future<FaceMlResult?> analyzeImageInSingleIsolate(EnteFile enteFile) async {
_checkEnteFileForID(enteFile); _checkEnteFileForID(enteFile);
await ensureInitialized(); await ensureInitialized();
@ -931,6 +874,87 @@ class FaceMlService {
return result; return result;
} }
/// Runs face detection, alignment and embedding on the image stored at
/// `args["filePath"]` and returns the assembled [FaceMlResult].
///
/// [args] must contain `enteFileID`, `filePath`, `faceDetectionAddress` and
/// `faceEmbeddingAddress`. When no face is detected the result is built with
/// `buildNoFaceDetected()` and the later steps are skipped. Any error is
/// logged and rethrown.
static Future<FaceMlResult> analyzeImageSync(Map args) async {
  try {
    final enteFileID = args["enteFileID"] as int;
    final path = args["filePath"] as String;
    final detectionAddress = args["faceDetectionAddress"] as int;
    final embeddingAddress = args["faceEmbeddingAddress"] as int;

    final builder = FaceMlResultBuilder.fromEnteFileID(enteFileID);

    dev.log(
      "Start analyzing image with uploadedFileID: $enteFileID inside the isolate",
    );
    // `stopwatchTotal` times the whole run; `stopwatch` is reset per step.
    final stopwatchTotal = Stopwatch()..start();
    final stopwatch = Stopwatch()..start();

    // The decoded image and its bytes are shared by detection and alignment.
    final rawBytes = await File(path).readAsBytes();
    final decodedImage = await decodeImageFromData(rawBytes);
    final ByteData pixelData = await getByteDataFromImage(decodedImage);
    dev.log('Reading and decoding image took '
        '${stopwatch.elapsedMilliseconds} ms');
    stopwatch.reset();

    // Step 1: detection
    final faceDetectionResult = await FaceMlService.detectFacesSync(
      decodedImage,
      pixelData,
      detectionAddress,
      resultBuilder: builder,
    );

    dev.log(
        "${faceDetectionResult.length} faces detected with scores ${faceDetectionResult.map((e) => e.score).toList()}: completed `detectFacesSync` function, in "
        "${stopwatch.elapsedMilliseconds} ms");

    // Nothing to align or embed without faces.
    if (faceDetectionResult.isEmpty) {
      dev.log(
          "No faceDetectionResult, Completed analyzing image with uploadedFileID $enteFileID, in "
          "${stopwatch.elapsedMilliseconds} ms");
      return builder.buildNoFaceDetected();
    }

    stopwatch.reset();

    // Step 2: alignment
    final Float32List alignedFaces = await FaceMlService.alignFacesSync(
      decodedImage,
      pixelData,
      faceDetectionResult,
      resultBuilder: builder,
    );

    dev.log("Completed `alignFacesSync` function, in "
        "${stopwatch.elapsedMilliseconds} ms");

    stopwatch.reset();

    // Step 3: embedding
    final embeddings = await FaceMlService.embedFacesSync(
      alignedFaces,
      embeddingAddress,
      resultBuilder: builder,
    );

    dev.log("Completed `embedFacesSync` function, in "
        "${stopwatch.elapsedMilliseconds} ms");

    stopwatch.stop();
    stopwatchTotal.stop();
    dev.log("Finished Analyze image (${embeddings.length} faces) with "
        "uploadedFileID $enteFileID, in "
        "${stopwatchTotal.elapsedMilliseconds} ms");

    return builder.build();
  } catch (e, s) {
    dev.log("Could not analyze image: \n e: $e \n s: $s");
    rethrow;
  }
}
Future<String?> _getImagePathForML( Future<String?> _getImagePathForML(
EnteFile enteFile, { EnteFile enteFile, {
FileDataForML typeOfData = FileDataForML.fileData, FileDataForML typeOfData = FileDataForML.fileData,

View file

@ -5,6 +5,8 @@ import "package:flutter/foundation.dart";
import "package:logging/logging.dart"; import "package:logging/logging.dart";
import "package:photos/core/event_bus.dart"; import "package:photos/core/event_bus.dart";
import "package:photos/db/files_db.dart"; import "package:photos/db/files_db.dart";
// import "package:photos/events/files_updated_event.dart";
// import "package:photos/events/local_photos_updated_event.dart";
import "package:photos/events/people_changed_event.dart"; import "package:photos/events/people_changed_event.dart";
import "package:photos/extensions/stop_watch.dart"; import "package:photos/extensions/stop_watch.dart";
import "package:photos/face/db.dart"; import "package:photos/face/db.dart";
@ -115,19 +117,105 @@ class ClusterFeedbackService {
List<EnteFile> files, List<EnteFile> files,
PersonEntity p, PersonEntity p,
) async { ) async {
await FaceMLDataDB.instance.removeFilesFromPerson(files, p.remoteID); try {
// Get the relevant faces to be removed
final faceIDs = await FaceMLDataDB.instance
.getFaceIDsForPerson(p.remoteID)
.then((iterable) => iterable.toList());
faceIDs.retainWhere((faceID) {
final fileID = getFileIdFromFaceId(faceID);
return files.any((file) => file.uploadedFileID == fileID);
});
final embeddings =
await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs);
final fileIDToCreationTime =
await FilesDB.instance.getFileIDToCreationTime();
// Re-cluster within the deleted faces
final newFaceIdToClusterID =
await FaceClusteringService.instance.predictWithinClusterComputer(
embeddings,
fileIDToCreationTime: fileIDToCreationTime,
distanceThreshold: 0.20,
);
if (newFaceIdToClusterID == null || newFaceIdToClusterID.isEmpty) {
return;
}
// Update the deleted faces
await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID);
// Make sure the deleted faces don't get suggested in the future
final notClusterIdToPersonId = <int, String>{};
for (final clusterId in newFaceIdToClusterID.values.toSet()) {
notClusterIdToPersonId[clusterId] = p.remoteID;
}
await FaceMLDataDB.instance
.bulkCaptureNotPersonFeedback(notClusterIdToPersonId);
Bus.instance.fire(PeopleChangedEvent()); Bus.instance.fire(PeopleChangedEvent());
return;
} catch (e, s) {
_logger.severe("Error in removeFilesFromPerson", e, s);
rethrow;
}
} }
Future<void> removeFilesFromCluster( Future<void> removeFilesFromCluster(
List<EnteFile> files, List<EnteFile> files,
int clusterID, int clusterID,
) async { ) async {
await FaceMLDataDB.instance.removeFilesFromCluster(files, clusterID); try {
Bus.instance.fire(PeopleChangedEvent()); // Get the relevant faces to be removed
final faceIDs = await FaceMLDataDB.instance
.getFaceIDsForCluster(clusterID)
.then((iterable) => iterable.toList());
faceIDs.retainWhere((faceID) {
final fileID = getFileIdFromFaceId(faceID);
return files.any((file) => file.uploadedFileID == fileID);
});
final embeddings =
await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs);
final fileIDToCreationTime =
await FilesDB.instance.getFileIDToCreationTime();
// Re-cluster within the deleted faces
final newFaceIdToClusterID =
await FaceClusteringService.instance.predictWithinClusterComputer(
embeddings,
fileIDToCreationTime: fileIDToCreationTime,
distanceThreshold: 0.20,
);
if (newFaceIdToClusterID == null || newFaceIdToClusterID.isEmpty) {
return; return;
} }
// Update the deleted faces
await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID);
Bus.instance.fire(
PeopleChangedEvent(
relevantFiles: files,
type: PeopleEventType.removedFilesFromCluster,
source: "$clusterID",
),
);
// Bus.instance.fire(
// LocalPhotosUpdatedEvent(
// files,
// type: EventType.peopleClusterChanged,
// source: "$clusterID",
// ),
// );
return;
} catch (e, s) {
_logger.severe("Error in removeFilesFromCluster", e, s);
rethrow;
}
}
Future<void> addFilesToCluster(List<String> faceIDs, int clusterID) async { Future<void> addFilesToCluster(List<String> faceIDs, int clusterID) async {
await FaceMLDataDB.instance.addFacesToCluster(faceIDs, clusterID); await FaceMLDataDB.instance.addFacesToCluster(faceIDs, clusterID);
Bus.instance.fire(PeopleChangedEvent()); Bus.instance.fire(PeopleChangedEvent());
@ -194,7 +282,7 @@ class ClusterFeedbackService {
// TODO: iterate over this method to find sweet spot // TODO: iterate over this method to find sweet spot
Future<Map<int, List<String>>> breakUpCluster( Future<Map<int, List<String>>> breakUpCluster(
int clusterID, { int clusterID, {
useDbscan = false, bool useDbscan = false,
}) async { }) async {
_logger.info( _logger.info(
'breakUpCluster called for cluster $clusterID with dbscan $useDbscan', 'breakUpCluster called for cluster $clusterID with dbscan $useDbscan',
@ -203,10 +291,8 @@ class ClusterFeedbackService {
final faceIDs = await faceMlDb.getFaceIDsForCluster(clusterID); final faceIDs = await faceMlDb.getFaceIDsForCluster(clusterID);
final originalFaceIDsSet = faceIDs.toSet(); final originalFaceIDsSet = faceIDs.toSet();
final fileIDs = faceIDs.map((e) => getFileIdFromFaceId(e)).toList();
final embeddings = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs); final embeddings = await faceMlDb.getFaceEmbeddingMapForFaces(faceIDs);
embeddings.removeWhere((key, value) => !faceIDs.contains(key));
final fileIDToCreationTime = final fileIDToCreationTime =
await FilesDB.instance.getFileIDToCreationTime(); await FilesDB.instance.getFileIDToCreationTime();
@ -232,18 +318,14 @@ class ClusterFeedbackService {
maxClusterID++; maxClusterID++;
} }
} else { } else {
final clusteringInput = embeddings.map((key, value) {
return MapEntry(key, (null, value));
});
final faceIdToCluster = final faceIdToCluster =
await FaceClusteringService.instance.predictLinear( await FaceClusteringService.instance.predictWithinClusterComputer(
clusteringInput, embeddings,
fileIDToCreationTime: fileIDToCreationTime, fileIDToCreationTime: fileIDToCreationTime,
distanceThreshold: 0.23, distanceThreshold: 0.22,
); );
if (faceIdToCluster == null) { if (faceIdToCluster == null || faceIdToCluster.isEmpty) {
_logger.info('No clusters found'); _logger.info('No clusters found');
return {}; return {};
} else { } else {
@ -295,6 +377,62 @@ class ClusterFeedbackService {
return clusterIdToFaceIds; return clusterIdToFaceIds;
} }
/// WARNING: debugging aid only, never call this in production.
///
/// Wipes the existing cluster IDs, the cluster/person tables and all
/// persons, then assigns faces to fake clusters by blur value: faces whose
/// blur values share the same `blurValue ~/ 10` end up in the same cluster.
Future<void> createFakeClustersByBlurValue() async {
  try {
    final faceDb = FaceMLDataDB.instance;

    // Start from a clean slate: no cluster IDs, no cluster tables, no persons.
    await faceDb.resetClusterIDs();
    await faceDb.dropClustersAndPersonTable();
    final existingPersons = await PersonService.instance.getPersons();
    for (final person in existingPersons) {
      await PersonService.instance.deletePerson(person.remoteID);
    }

    // Cluster IDs start at the current timestamp; each step of 10 in blur
    // value moves a face to the next ID.
    final int baseClusterID = DateTime.now().microsecondsSinceEpoch;
    final blurByFaceID = await faceDb.getFaceIDsToBlurValues(200);
    final faceIdToCluster = <String, int>{
      for (final entry in blurByFaceID.entries)
        entry.key: baseClusterID + entry.value ~/ 10,
    };
    await faceDb.updateClusterIdToFaceId(faceIdToCluster);

    Bus.instance.fire(PeopleChangedEvent());
  } catch (e, s) {
    _logger.severe("Error in createFakeClustersByBlurValue", e, s);
    rethrow;
  }
}
/// Logs the blur values of all faces in cluster [clusterID], rounded to
/// integers and sorted ascending.
///
/// When [clusterSize] is given it is mentioned in the log line as the number
/// of photos in the cluster.
Future<void> debugLogClusterBlurValues(
  int clusterID, {
  int? clusterSize,
}) async {
  final rawBlurValues =
      await FaceMLDataDB.instance.getBlurValuesForCluster(clusterID);

  // Round first, then sort, so the log shows an ascending list of integers.
  final blurValuesIntegers = [
    for (final blur in rawBlurValues) blur.round(),
  ]..sort();

  _logger.info(
    "Blur values for cluster $clusterID${clusterSize != null ? ' with $clusterSize photos' : ''}: $blurValuesIntegers",
  );
}
/// Returns a map of the person's clusterID to a list of the closest clusterIDs, each paired with its distance /// Returns a map of the person's clusterID to a list of the closest clusterIDs, each paired with its distance
Future<Map<int, List<(int, double)>>> getSuggestionsUsingMean( Future<Map<int, List<(int, double)>>> getSuggestionsUsingMean(
PersonEntity p, { PersonEntity p, {
@ -523,7 +661,7 @@ class ClusterFeedbackService {
); );
final Map<int, (Uint8List, int)> clusterToSummary = final Map<int, (Uint8List, int)> clusterToSummary =
await faceMlDb.clusterSummaryAll(); await faceMlDb.getAllClusterSummary();
final Map<int, (Uint8List, int)> updatesForClusterSummary = {}; final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
final Map<int, List<double>> clusterAvg = {}; final Map<int, List<double>> clusterAvg = {};
@ -714,7 +852,7 @@ class ClusterFeedbackService {
// Get the cluster averages for the person's clusters and the suggestions' clusters // Get the cluster averages for the person's clusters and the suggestions' clusters
final Map<int, (Uint8List, int)> clusterToSummary = final Map<int, (Uint8List, int)> clusterToSummary =
await faceMlDb.clusterSummaryAll(); await faceMlDb.getAllClusterSummary();
// Calculate the avg embedding of the person // Calculate the avg embedding of the person
final personClusters = await faceMlDb.getPersonClusterIDs(person.remoteID); final personClusters = await faceMlDb.getPersonClusterIDs(person.remoteID);

View file

@ -824,7 +824,7 @@ class SearchService {
"Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}", "Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}",
); );
} }
if (files.length < 3 && sortedClusterIds.length > 3) { if (files.length < 20 && sortedClusterIds.length > 3) {
continue; continue;
} }
facesResult.add( facesResult.add(

View file

@ -8,6 +8,7 @@ import "package:photos/events/people_changed_event.dart";
import "package:photos/face/db.dart"; import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart"; import "package:photos/face/model/person.dart";
import 'package:photos/services/machine_learning/face_ml/face_ml_service.dart'; import 'package:photos/services/machine_learning/face_ml/face_ml_service.dart';
import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
import 'package:photos/theme/ente_theme.dart'; import 'package:photos/theme/ente_theme.dart';
import 'package:photos/ui/components/captioned_text_widget.dart'; import 'package:photos/ui/components/captioned_text_widget.dart';
@ -284,6 +285,34 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
}, },
), ),
sectionOptionSpacing, sectionOptionSpacing,
MenuItemWidget(
captionedTextWidget: const CaptionedTextWidget(
title: "Rank blurs",
),
pressedColor: getEnteColorScheme(context).fillFaint,
trailingIcon: Icons.chevron_right_outlined,
trailingIconIsMuted: true,
onTap: () async {
await showChoiceDialog(
context,
title: "Are you sure?",
body:
"This will delete all clusters and put blurry faces in separate clusters per ten points.",
firstButtonLabel: "Yes, confirm",
firstButtonOnTap: () async {
try {
await ClusterFeedbackService.instance
.createFakeClustersByBlurValue();
showShortToast(context, "Done");
} catch (e, s) {
_logger.warning('Failed to rank faces on blur values ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
},
);
},
),
sectionOptionSpacing,
MenuItemWidget( MenuItemWidget(
captionedTextWidget: const CaptionedTextWidget( captionedTextWidget: const CaptionedTextWidget(
title: "Drop embeddings & feedback", title: "Drop embeddings & feedback",

View file

@ -9,6 +9,7 @@ import "package:photos/face/db.dart";
import "package:photos/face/model/face.dart"; import "package:photos/face/model/face.dart";
import "package:photos/face/model/person.dart"; import "package:photos/face/model/person.dart";
import 'package:photos/models/file/file.dart'; import 'package:photos/models/file/file.dart';
import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart";
import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
import "package:photos/services/search_service.dart"; import "package:photos/services/search_service.dart";
import "package:photos/theme/ente_theme.dart"; import "package:photos/theme/ente_theme.dart";
@ -47,7 +48,7 @@ class _FaceWidgetState extends State<FaceWidget> {
@override @override
Widget build(BuildContext context) { Widget build(BuildContext context) {
if (Platform.isIOS || Platform.isAndroid) { if (Platform.isIOS) {
return FutureBuilder<Uint8List?>( return FutureBuilder<Uint8List?>(
future: getFaceCrop(), future: getFaceCrop(),
builder: (context, snapshot) { builder: (context, snapshot) {
@ -164,19 +165,19 @@ class _FaceWidgetState extends State<FaceWidget> {
), ),
if (kDebugMode) if (kDebugMode)
Text( Text(
'B: ${widget.face.blur.toStringAsFixed(3)}', 'B: ${widget.face.blur.toStringAsFixed(0)}',
style: Theme.of(context).textTheme.bodySmall, style: Theme.of(context).textTheme.bodySmall,
maxLines: 1, maxLines: 1,
), ),
if (kDebugMode) if (kDebugMode)
Text( Text(
'V: ${widget.face.visibility}', 'D: ${widget.face.detection.getFaceDirection().toDirectionString()}',
style: Theme.of(context).textTheme.bodySmall, style: Theme.of(context).textTheme.bodySmall,
maxLines: 1, maxLines: 1,
), ),
if (kDebugMode) if (kDebugMode)
Text( Text(
'A: ${widget.face.area()}', 'Sideways: ${widget.face.detection.faceIsSideways().toString()}',
style: Theme.of(context).textTheme.bodySmall, style: Theme.of(context).textTheme.bodySmall,
maxLines: 1, maxLines: 1,
), ),
@ -303,6 +304,24 @@ class _FaceWidgetState extends State<FaceWidget> {
style: Theme.of(context).textTheme.bodySmall, style: Theme.of(context).textTheme.bodySmall,
maxLines: 1, maxLines: 1,
), ),
if (kDebugMode)
Text(
'B: ${widget.face.blur.toStringAsFixed(0)}',
style: Theme.of(context).textTheme.bodySmall,
maxLines: 1,
),
if (kDebugMode)
Text(
'D: ${widget.face.detection.getFaceDirection().toDirectionString()}',
style: Theme.of(context).textTheme.bodySmall,
maxLines: 1,
),
if (kDebugMode)
Text(
'Sideways: ${widget.face.detection.faceIsSideways().toString()}',
style: Theme.of(context).textTheme.bodySmall,
maxLines: 1,
),
], ],
), ),
); );

View file

@ -1,10 +1,12 @@
import "dart:async"; import "dart:async";
import "package:flutter/foundation.dart";
import 'package:flutter/material.dart'; import 'package:flutter/material.dart';
import "package:flutter_animate/flutter_animate.dart"; import "package:flutter_animate/flutter_animate.dart";
import 'package:photos/core/event_bus.dart'; import 'package:photos/core/event_bus.dart';
import 'package:photos/events/files_updated_event.dart'; import 'package:photos/events/files_updated_event.dart';
import 'package:photos/events/local_photos_updated_event.dart'; import 'package:photos/events/local_photos_updated_event.dart';
import "package:photos/events/people_changed_event.dart";
import "package:photos/face/model/person.dart"; import "package:photos/face/model/person.dart";
import "package:photos/generated/l10n.dart"; import "package:photos/generated/l10n.dart";
import 'package:photos/models/file/file.dart'; import 'package:photos/models/file/file.dart';
@ -51,6 +53,7 @@ class _ClusterPageState extends State<ClusterPage> {
final _selectedFiles = SelectedFiles(); final _selectedFiles = SelectedFiles();
late final List<EnteFile> files; late final List<EnteFile> files;
late final StreamSubscription<LocalPhotosUpdatedEvent> _filesUpdatedEvent; late final StreamSubscription<LocalPhotosUpdatedEvent> _filesUpdatedEvent;
late final StreamSubscription<PeopleChangedEvent> _peopleChangedEvent;
@override @override
void initState() { void initState() {
@ -69,11 +72,27 @@ class _ClusterPageState extends State<ClusterPage> {
setState(() {}); setState(() {});
} }
}); });
_peopleChangedEvent = Bus.instance.on<PeopleChangedEvent>().listen((event) {
if (event.type == PeopleEventType.removedFilesFromCluster &&
(event.source == widget.clusterID.toString())) {
for (var updatedFile in event.relevantFiles!) {
files.remove(updatedFile);
}
setState(() {});
}
});
kDebugMode
? ClusterFeedbackService.instance.debugLogClusterBlurValues(
widget.clusterID,
clusterSize: files.length,
)
: null;
} }
@override @override
void dispose() { void dispose() {
_filesUpdatedEvent.cancel(); _filesUpdatedEvent.cancel();
_peopleChangedEvent.cancel();
super.dispose(); super.dispose();
} }
@ -96,10 +115,12 @@ class _ClusterPageState extends State<ClusterPage> {
); );
}, },
reloadEvent: Bus.instance.on<LocalPhotosUpdatedEvent>(), reloadEvent: Bus.instance.on<LocalPhotosUpdatedEvent>(),
forceReloadEvents: [Bus.instance.on<PeopleChangedEvent>()],
removalEventTypes: const { removalEventTypes: const {
EventType.deletedFromRemote, EventType.deletedFromRemote,
EventType.deletedFromEverywhere, EventType.deletedFromEverywhere,
EventType.hide, EventType.hide,
EventType.peopleClusterChanged,
}, },
tagPrefix: widget.tagPrefix + widget.tagPrefix, tagPrefix: widget.tagPrefix + widget.tagPrefix,
selectedFiles: _selectedFiles, selectedFiles: _selectedFiles,
@ -111,9 +132,10 @@ class _ClusterPageState extends State<ClusterPage> {
preferredSize: const Size.fromHeight(50.0), preferredSize: const Size.fromHeight(50.0),
child: ClusterAppBar( child: ClusterAppBar(
SearchResultPage.appBarType, SearchResultPage.appBarType,
"${widget.searchResult.length} memories${widget.appendTitle}", "${files.length} memories${widget.appendTitle}",
_selectedFiles, _selectedFiles,
widget.clusterID, widget.clusterID,
key: ValueKey(files.length),
), ),
), ),
body: Column( body: Column(

View file

@ -1099,19 +1099,16 @@ Future<(Float32List, List<AlignmentResult>, List<bool>, List<double>, Size)>
imageHeight: image.height, imageHeight: image.height,
); );
final List<List<List<double>>> faceLandmarks =
absoluteFaces.map((face) => face.allKeypoints).toList();
final alignedImagesFloat32List = final alignedImagesFloat32List =
Float32List(3 * width * height * faceLandmarks.length); Float32List(3 * width * height * absoluteFaces.length);
final alignmentResults = <AlignmentResult>[]; final alignmentResults = <AlignmentResult>[];
final isBlurs = <bool>[]; final isBlurs = <bool>[];
final blurValues = <double>[]; final blurValues = <double>[];
int alignedImageIndex = 0; int alignedImageIndex = 0;
for (final faceLandmark in faceLandmarks) { for (final face in absoluteFaces) {
final (alignmentResult, correctlyEstimated) = final (alignmentResult, correctlyEstimated) =
SimilarityTransform.instance.estimate(faceLandmark); SimilarityTransform.instance.estimate(face.allKeypoints);
if (!correctlyEstimated) { if (!correctlyEstimated) {
alignedImageIndex += 3 * width * height; alignedImageIndex += 3 * width * height;
alignmentResults.add(AlignmentResult.empty()); alignmentResults.add(AlignmentResult.empty());
@ -1137,7 +1134,7 @@ Future<(Float32List, List<AlignmentResult>, List<bool>, List<double>, Size)>
final grayscalems = blurDetectionStopwatch.elapsedMilliseconds; final grayscalems = blurDetectionStopwatch.elapsedMilliseconds;
log('creating grayscale matrix took $grayscalems ms'); log('creating grayscale matrix took $grayscalems ms');
final (isBlur, blurValue) = await BlurDetectionService.instance final (isBlur, blurValue) = await BlurDetectionService.instance
.predictIsBlurGrayLaplacian(faceGrayMatrix); .predictIsBlurGrayLaplacian(faceGrayMatrix, faceDirection: face.getFaceDirection());
final blurms = blurDetectionStopwatch.elapsedMilliseconds - grayscalems; final blurms = blurDetectionStopwatch.elapsedMilliseconds - grayscalems;
log('blur detection took $blurms ms'); log('blur detection took $blurms ms');
log( log(