[mobile] Patch faces mvp from photos-app repo

This commit is contained in:
Neeraj Gupta 2024-03-08 09:36:03 +05:30
parent 457b1c1abd
commit d2bf4846a5
122 changed files with 13969 additions and 57 deletions

View file

@ -47,7 +47,7 @@ android {
defaultConfig {
applicationId "io.ente.photos"
minSdkVersion 21
minSdkVersion 26
targetSdkVersion 33
versionCode flutterVersionCode.toInteger()
versionName flutterVersionName
@ -74,6 +74,10 @@ android {
dimension "default"
applicationIdSuffix ".dev"
}
face {
dimension "default"
applicationIdSuffix ".face"
}
playstore {
dimension "default"
}

View file

@ -0,0 +1,10 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="io.ente.photos">
<!-- Flutter needs it to communicate with the running application
to allow setting breakpoints, to provide hot reload, etc.
-->
<uses-permission android:name="android.permission.INTERNET"/>
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
</manifest>

View file

@ -0,0 +1,4 @@
<resources>
<string name="app_name">ente face</string>
<string name="backup">backup face</string>
</resources>

View file

@ -59,6 +59,8 @@ PODS:
- flutter_inappwebview/Core (0.0.1):
- Flutter
- OrderedSet (~> 5.0)
- flutter_isolate (0.0.1):
- Flutter
- flutter_local_notifications (0.0.1):
- Flutter
- flutter_native_splash (0.0.1):
@ -197,6 +199,28 @@ PODS:
- sqlite3/fts5
- sqlite3/perf-threadsafe
- sqlite3/rtree
- TensorFlowLiteC (2.12.0):
- TensorFlowLiteC/Core (= 2.12.0)
- TensorFlowLiteC/Core (2.12.0)
- TensorFlowLiteC/CoreML (2.12.0):
- TensorFlowLiteC/Core
- TensorFlowLiteC/Metal (2.12.0):
- TensorFlowLiteC/Core
- TensorFlowLiteSwift (2.12.0):
- TensorFlowLiteSwift/Core (= 2.12.0)
- TensorFlowLiteSwift/Core (2.12.0):
- TensorFlowLiteC (= 2.12.0)
- TensorFlowLiteSwift/CoreML (2.12.0):
- TensorFlowLiteC/CoreML (= 2.12.0)
- TensorFlowLiteSwift/Core (= 2.12.0)
- TensorFlowLiteSwift/Metal (2.12.0):
- TensorFlowLiteC/Metal (= 2.12.0)
- TensorFlowLiteSwift/Core (= 2.12.0)
- tflite_flutter (0.0.1):
- Flutter
- TensorFlowLiteSwift (= 2.12.0)
- TensorFlowLiteSwift/CoreML (= 2.12.0)
- TensorFlowLiteSwift/Metal (= 2.12.0)
- Toast (4.1.0)
- uni_links (0.0.1):
- Flutter
@ -228,6 +252,7 @@ DEPENDENCIES:
- flutter_email_sender (from `.symlinks/plugins/flutter_email_sender/ios`)
- flutter_image_compress (from `.symlinks/plugins/flutter_image_compress/ios`)
- flutter_inappwebview (from `.symlinks/plugins/flutter_inappwebview/ios`)
- flutter_isolate (from `.symlinks/plugins/flutter_isolate/ios`)
- flutter_local_notifications (from `.symlinks/plugins/flutter_local_notifications/ios`)
- flutter_native_splash (from `.symlinks/plugins/flutter_native_splash/ios`)
- flutter_secure_storage (from `.symlinks/plugins/flutter_secure_storage/ios`)
@ -257,6 +282,7 @@ DEPENDENCIES:
- shared_preferences_foundation (from `.symlinks/plugins/shared_preferences_foundation/darwin`)
- sqflite (from `.symlinks/plugins/sqflite/darwin`)
- sqlite3_flutter_libs (from `.symlinks/plugins/sqlite3_flutter_libs/ios`)
- tflite_flutter (from `.symlinks/plugins/tflite_flutter/ios`)
- uni_links (from `.symlinks/plugins/uni_links/ios`)
- url_launcher_ios (from `.symlinks/plugins/url_launcher_ios/ios`)
- video_player_avfoundation (from `.symlinks/plugins/video_player_avfoundation/darwin`)
@ -287,6 +313,8 @@ SPEC REPOS:
- Sentry
- SentryPrivate
- sqlite3
- TensorFlowLiteC
- TensorFlowLiteSwift
- Toast
EXTERNAL SOURCES:
@ -314,6 +342,8 @@ EXTERNAL SOURCES:
:path: ".symlinks/plugins/flutter_image_compress/ios"
flutter_inappwebview:
:path: ".symlinks/plugins/flutter_inappwebview/ios"
flutter_isolate:
:path: ".symlinks/plugins/flutter_isolate/ios"
flutter_local_notifications:
:path: ".symlinks/plugins/flutter_local_notifications/ios"
flutter_native_splash:
@ -372,6 +402,8 @@ EXTERNAL SOURCES:
:path: ".symlinks/plugins/sqflite/darwin"
sqlite3_flutter_libs:
:path: ".symlinks/plugins/sqlite3_flutter_libs/ios"
tflite_flutter:
:path: ".symlinks/plugins/tflite_flutter/ios"
uni_links:
:path: ".symlinks/plugins/uni_links/ios"
url_launcher_ios:
@ -405,6 +437,7 @@ SPEC CHECKSUMS:
flutter_email_sender: 02d7443217d8c41483223627972bfdc09f74276b
flutter_image_compress: 5a5e9aee05b6553048b8df1c3bc456d0afaac433
flutter_inappwebview: 3d32228f1304635e7c028b0d4252937730bbc6cf
flutter_isolate: 0edf5081826d071adf21759d1eb10ff5c24503b5
flutter_local_notifications: 0c0b1ae97e741e1521e4c1629a459d04b9aec743
flutter_native_splash: 52501b97d1c0a5f898d687f1646226c1f93c56ef
flutter_secure_storage: 23fc622d89d073675f2eaa109381aefbcf5a49be
@ -449,6 +482,9 @@ SPEC CHECKSUMS:
sqflite: 673a0e54cc04b7d6dba8d24fb8095b31c3a99eec
sqlite3: 73b7fc691fdc43277614250e04d183740cb15078
sqlite3_flutter_libs: aeb4d37509853dfa79d9b59386a2dac5dd079428
TensorFlowLiteC: 20785a69299185a379ba9852b6625f00afd7984a
TensorFlowLiteSwift: 3a4928286e9e35bdd3e17970f48e53c80d25e793
tflite_flutter: 9433d086a3060431bbc9f3c7c20d017db0e72d08
Toast: ec33c32b8688982cecc6348adeae667c1b9938da
uni_links: d97da20c7701486ba192624d99bffaaffcfc298a
url_launcher_ios: bbd758c6e7f9fd7b5b1d4cde34d2b95fcce5e812

View file

@ -299,6 +299,7 @@
"${BUILT_PRODUCTS_DIR}/flutter_email_sender/flutter_email_sender.framework",
"${BUILT_PRODUCTS_DIR}/flutter_image_compress/flutter_image_compress.framework",
"${BUILT_PRODUCTS_DIR}/flutter_inappwebview/flutter_inappwebview.framework",
"${BUILT_PRODUCTS_DIR}/flutter_isolate/flutter_isolate.framework",
"${BUILT_PRODUCTS_DIR}/flutter_local_notifications/flutter_local_notifications.framework",
"${BUILT_PRODUCTS_DIR}/flutter_native_splash/flutter_native_splash.framework",
"${BUILT_PRODUCTS_DIR}/flutter_secure_storage/flutter_secure_storage.framework",
@ -382,6 +383,7 @@
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_email_sender.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_image_compress.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_inappwebview.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_isolate.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_local_notifications.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_native_splash.framework",
"${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_secure_storage.framework",

View file

@ -18,6 +18,7 @@ import 'package:photos/db/trash_db.dart';
import 'package:photos/db/upload_locks_db.dart';
import 'package:photos/events/signed_in_event.dart';
import 'package:photos/events/user_logged_out_event.dart';
import "package:photos/face/db.dart";
import 'package:photos/models/key_attributes.dart';
import 'package:photos/models/key_gen_result.dart';
import 'package:photos/models/private_key_attributes.dart';
@ -164,6 +165,7 @@ class Configuration {
: null;
await CollectionsDB.instance.clearTable();
await MemoriesDB.instance.clearTable();
await FaceMLDataDB.instance.clearTable();
await UploadLocksDB.instance.clearTable();
await IgnoredFilesService.instance.reset();

View file

@ -0,0 +1,714 @@
import 'dart:async';
import 'package:logging/logging.dart';
import 'package:path/path.dart' show join;
import 'package:path_provider/path_provider.dart';
import 'package:photos/models/ml/ml_typedefs.dart';
import "package:photos/services/face_ml/face_feedback.dart/cluster_feedback.dart";
import "package:photos/services/face_ml/face_feedback.dart/feedback_types.dart";
import "package:photos/services/face_ml/face_ml_result.dart";
import 'package:sqflite/sqflite.dart';
/// Stores all data for the ML-related features. The database can be accessed by `MlDataDB.instance.database`.
///
/// This includes:
/// [facesTable] - Stores all the detected faces and its embeddings in the images.
/// [peopleTable] - Stores all the clusters of faces which are considered to be the same person.
class MlDataDB {
static final Logger _logger = Logger("MlDataDB");
// TODO: [BOB] put the db in files
static const _databaseName = "ente.ml_data.db";
static const _databaseVersion = 1;
static const facesTable = 'faces';
static const fileIDColumn = 'file_id';
static const faceMlResultColumn = 'face_ml_result';
static const mlVersionColumn = 'ml_version';
static const peopleTable = 'people';
static const personIDColumn = 'person_id';
static const clusterResultColumn = 'cluster_result';
static const centroidColumn = 'cluster_centroid';
static const centroidDistanceThresholdColumn = 'centroid_distance_threshold';
static const feedbackTable = 'feedback';
static const feedbackIDColumn = 'feedback_id';
static const feedbackTypeColumn = 'feedback_type';
static const feedbackDataColumn = 'feedback_data';
static const feedbackTimestampColumn = 'feedback_timestamp';
static const feedbackFaceMlVersionColumn = 'feedback_face_ml_version';
static const feedbackClusterMlVersionColumn = 'feedback_cluster_ml_version';
static const createFacesTable = '''CREATE TABLE IF NOT EXISTS $facesTable (
$fileIDColumn INTEGER NOT NULL UNIQUE,
$faceMlResultColumn TEXT NOT NULL,
$mlVersionColumn INTEGER NOT NULL,
PRIMARY KEY($fileIDColumn)
);
''';
static const createPeopleTable = '''CREATE TABLE IF NOT EXISTS $peopleTable (
$personIDColumn INTEGER NOT NULL UNIQUE,
$clusterResultColumn TEXT NOT NULL,
$centroidColumn TEXT NOT NULL,
$centroidDistanceThresholdColumn REAL NOT NULL,
PRIMARY KEY($personIDColumn)
);
''';
static const createFeedbackTable =
'''CREATE TABLE IF NOT EXISTS $feedbackTable (
$feedbackIDColumn TEXT NOT NULL UNIQUE,
$feedbackTypeColumn TEXT NOT NULL,
$feedbackDataColumn TEXT NOT NULL,
$feedbackTimestampColumn TEXT NOT NULL,
$feedbackFaceMlVersionColumn INTEGER NOT NULL,
$feedbackClusterMlVersionColumn INTEGER NOT NULL,
PRIMARY KEY($feedbackIDColumn)
);
''';
static const _deleteFacesTable = 'DROP TABLE IF EXISTS $facesTable';
static const _deletePeopleTable = 'DROP TABLE IF EXISTS $peopleTable';
static const _deleteFeedbackTable = 'DROP TABLE IF EXISTS $feedbackTable';
MlDataDB._privateConstructor();
static final MlDataDB instance = MlDataDB._privateConstructor();
static Future<Database>? _dbFuture;
Future<Database> get database async {
_dbFuture ??= _initDatabase();
return _dbFuture!;
}
Future<Database> _initDatabase() async {
final documentsDirectory = await getApplicationDocumentsDirectory();
final String databaseDirectory =
join(documentsDirectory.path, _databaseName);
return await openDatabase(
databaseDirectory,
version: _databaseVersion,
onCreate: _onCreate,
);
}
Future _onCreate(Database db, int version) async {
await db.execute(createFacesTable);
await db.execute(createPeopleTable);
await db.execute(createFeedbackTable);
}
/// WARNING: This will delete ALL data in the database! Only use this for debug/testing purposes!
Future<void> cleanTables({
bool cleanFaces = false,
bool cleanPeople = false,
bool cleanFeedback = false,
}) async {
_logger.fine('`cleanTables()` called');
final db = await instance.database;
if (cleanFaces) {
_logger.fine('`cleanTables()`: Cleaning faces table');
await db.execute(_deleteFacesTable);
}
if (cleanPeople) {
_logger.fine('`cleanTables()`: Cleaning people table');
await db.execute(_deletePeopleTable);
}
if (cleanFeedback) {
_logger.fine('`cleanTables()`: Cleaning feedback table');
await db.execute(_deleteFeedbackTable);
}
if (!cleanFaces && !cleanPeople && !cleanFeedback) {
_logger.fine(
'`cleanTables()`: No tables cleaned, since no table was specified. Please be careful with this function!',
);
}
await db.execute(createFacesTable);
await db.execute(createPeopleTable);
await db.execute(createFeedbackTable);
}
Future<void> createFaceMlResult(FaceMlResult faceMlResult) async {
_logger.fine('createFaceMlResult called');
final existingResult = await getFaceMlResult(faceMlResult.fileId);
if (existingResult != null) {
if (faceMlResult.mlVersion <= existingResult.mlVersion) {
_logger.fine(
'FaceMlResult with file ID ${faceMlResult.fileId} already exists with equal or higher version. Skipping insert.',
);
return;
}
}
final db = await instance.database;
await db.insert(
facesTable,
{
fileIDColumn: faceMlResult.fileId,
faceMlResultColumn: faceMlResult.toJsonString(),
mlVersionColumn: faceMlResult.mlVersion,
},
conflictAlgorithm: ConflictAlgorithm.replace,
);
}
Future<bool> doesFaceMlResultExist(int fileId, {int? mlVersion}) async {
_logger.fine('doesFaceMlResultExist called');
final db = await instance.database;
String whereString = '$fileIDColumn = ?';
final List<dynamic> whereArgs = [fileId];
if (mlVersion != null) {
whereString += ' AND $mlVersionColumn = ?';
whereArgs.add(mlVersion);
}
final result = await db.query(
facesTable,
where: whereString,
whereArgs: whereArgs,
limit: 1,
);
return result.isNotEmpty;
}
Future<FaceMlResult?> getFaceMlResult(int fileId, {int? mlVersion}) async {
_logger.fine('getFaceMlResult called');
final db = await instance.database;
String whereString = '$fileIDColumn = ?';
final List<dynamic> whereArgs = [fileId];
if (mlVersion != null) {
whereString += ' AND $mlVersionColumn = ?';
whereArgs.add(mlVersion);
}
final result = await db.query(
facesTable,
where: whereString,
whereArgs: whereArgs,
limit: 1,
);
if (result.isNotEmpty) {
return FaceMlResult.fromJsonString(
result.first[faceMlResultColumn] as String,
);
}
_logger.fine(
'No faceMlResult found for fileID $fileId and mlVersion $mlVersion (null if not specified)',
);
return null;
}
/// Returns the faceMlResults for the given [fileIds].
Future<List<FaceMlResult>> getSelectedFaceMlResults(
List<int> fileIds,
) async {
_logger.fine('getSelectedFaceMlResults called');
final db = await instance.database;
if (fileIds.isEmpty) {
_logger.warning('getSelectedFaceMlResults called with empty fileIds');
return <FaceMlResult>[];
}
final List<Map<String, Object?>> results = await db.query(
facesTable,
columns: [faceMlResultColumn],
where: '$fileIDColumn IN (${fileIds.join(',')})',
orderBy: fileIDColumn,
);
return results
.map(
(result) =>
FaceMlResult.fromJsonString(result[faceMlResultColumn] as String),
)
.toList();
}
Future<List<FaceMlResult>> getAllFaceMlResults({int? mlVersion}) async {
_logger.fine('getAllFaceMlResults called');
final db = await instance.database;
String? whereString;
List<dynamic>? whereArgs;
if (mlVersion != null) {
whereString = '$mlVersionColumn = ?';
whereArgs = [mlVersion];
}
final results = await db.query(
facesTable,
where: whereString,
whereArgs: whereArgs,
orderBy: fileIDColumn,
);
return results
.map(
(result) =>
FaceMlResult.fromJsonString(result[faceMlResultColumn] as String),
)
.toList();
}
/// getAllFileIDs returns a set of all fileIDs from the facesTable, meaning all the fileIDs for which a FaceMlResult exists, optionally filtered by mlVersion.
Future<Set<int>> getAllFaceMlResultFileIDs({int? mlVersion}) async {
_logger.fine('getAllFaceMlResultFileIDs called');
final db = await instance.database;
String? whereString;
List<dynamic>? whereArgs;
if (mlVersion != null) {
whereString = '$mlVersionColumn = ?';
whereArgs = [mlVersion];
}
final List<Map<String, Object?>> results = await db.query(
facesTable,
where: whereString,
whereArgs: whereArgs,
orderBy: fileIDColumn,
);
return results.map((result) => result[fileIDColumn] as int).toSet();
}
Future<Set<int>> getAllFaceMlResultFileIDsProcessedWithThumbnailOnly({
int? mlVersion,
}) async {
_logger.fine('getAllFaceMlResultFileIDsProcessedWithThumbnailOnly called');
final db = await instance.database;
String? whereString;
List<dynamic>? whereArgs;
if (mlVersion != null) {
whereString = '$mlVersionColumn = ?';
whereArgs = [mlVersion];
}
final List<Map<String, Object?>> results = await db.query(
facesTable,
where: whereString,
whereArgs: whereArgs,
orderBy: fileIDColumn,
);
return results
.map(
(result) =>
FaceMlResult.fromJsonString(result[faceMlResultColumn] as String),
)
.where((element) => element.onlyThumbnailUsed)
.map((result) => result.fileId)
.toSet();
}
/// Updates the faceMlResult for the given [faceMlResult.fileId]. Update is done regardless of the [faceMlResult.mlVersion].
/// However, if [updateHigherVersionOnly] is set to true, the update is only done if the [faceMlResult.mlVersion] is higher than the existing one.
Future<int> updateFaceMlResult(
FaceMlResult faceMlResult, {
bool updateHigherVersionOnly = false,
}) async {
_logger.fine('updateFaceMlResult called');
if (updateHigherVersionOnly) {
final existingResult = await getFaceMlResult(faceMlResult.fileId);
if (existingResult != null) {
if (faceMlResult.mlVersion <= existingResult.mlVersion) {
_logger.fine(
'FaceMlResult with file ID ${faceMlResult.fileId} already exists with equal or higher version. Skipping update.',
);
return 0;
}
}
}
final db = await instance.database;
return await db.update(
facesTable,
{
fileIDColumn: faceMlResult.fileId,
faceMlResultColumn: faceMlResult.toJsonString(),
mlVersionColumn: faceMlResult.mlVersion,
},
where: '$fileIDColumn = ?',
whereArgs: [faceMlResult.fileId],
);
}
Future<int> deleteFaceMlResult(int fileId) async {
_logger.fine('deleteFaceMlResult called');
final db = await instance.database;
final deleteCount = await db.delete(
facesTable,
where: '$fileIDColumn = ?',
whereArgs: [fileId],
);
_logger.fine('Deleted $deleteCount faceMlResults');
return deleteCount;
}
Future<void> createAllClusterResults(
List<ClusterResult> clusterResults, {
bool cleanExistingClusters = true,
}) async {
_logger.fine('createClusterResults called');
final db = await instance.database;
if (clusterResults.isEmpty) {
_logger.fine('No clusterResults given, skipping insert.');
return;
}
// Completely clean the table and start fresh
if (cleanExistingClusters) {
await deleteAllClusterResults();
}
// Insert all the cluster results
for (final clusterResult in clusterResults) {
await db.insert(
peopleTable,
{
personIDColumn: clusterResult.personId,
clusterResultColumn: clusterResult.toJsonString(),
centroidColumn: clusterResult.medoid.toString(),
centroidDistanceThresholdColumn:
clusterResult.medoidDistanceThreshold,
},
conflictAlgorithm: ConflictAlgorithm.replace,
);
}
}
Future<ClusterResult?> getClusterResult(int personId) async {
_logger.fine('getClusterResult called');
final db = await instance.database;
final result = await db.query(
peopleTable,
where: '$personIDColumn = ?',
whereArgs: [personId],
limit: 1,
);
if (result.isNotEmpty) {
return ClusterResult.fromJsonString(
result.first[clusterResultColumn] as String,
);
}
_logger.fine('No clusterResult found for personID $personId');
return null;
}
/// Returns the ClusterResult objects for the given [personIDs].
Future<List<ClusterResult>> getSelectedClusterResults(
List<int> personIDs,
) async {
_logger.fine('getSelectedClusterResults called');
final db = await instance.database;
if (personIDs.isEmpty) {
_logger.warning('getSelectedClusterResults called with empty personIDs');
return <ClusterResult>[];
}
final results = await db.query(
peopleTable,
where: '$personIDColumn IN (${personIDs.join(',')})',
orderBy: personIDColumn,
);
return results
.map(
(result) => ClusterResult.fromJsonString(
result[clusterResultColumn] as String,
),
)
.toList();
}
Future<List<ClusterResult>> getAllClusterResults() async {
_logger.fine('getAllClusterResults called');
final db = await instance.database;
final results = await db.query(
peopleTable,
);
return results
.map(
(result) => ClusterResult.fromJsonString(
result[clusterResultColumn] as String,
),
)
.toList();
}
/// Returns the personIDs of all clustered people in the database.
Future<List<int>> getAllClusterIds() async {
_logger.fine('getAllClusterIds called');
final db = await instance.database;
final results = await db.query(
peopleTable,
columns: [personIDColumn],
);
return results.map((result) => result[personIDColumn] as int).toList();
}
/// Returns the fileIDs of all files associated with a given [personId].
Future<List<int>> getClusterFileIds(int personId) async {
_logger.fine('getClusterFileIds called');
final ClusterResult? clusterResult = await getClusterResult(personId);
if (clusterResult == null) {
return <int>[];
}
return clusterResult.uniqueFileIds;
}
Future<List<String>> getClusterFaceIds(int personId) async {
_logger.fine('getClusterFaceIds called');
final ClusterResult? clusterResult = await getClusterResult(personId);
if (clusterResult == null) {
return <String>[];
}
return clusterResult.faceIDs;
}
Future<List<Embedding>> getClusterEmbeddings(
int personId,
) async {
_logger.fine('getClusterEmbeddings called');
final ClusterResult? clusterResult = await getClusterResult(personId);
if (clusterResult == null) return <Embedding>[];
final fileIds = clusterResult.uniqueFileIds;
final faceIds = clusterResult.faceIDs;
if (fileIds.length != faceIds.length) {
_logger.severe(
'fileIds and faceIds have different lengths: ${fileIds.length} vs ${faceIds.length}. This should not happen!',
);
return <Embedding>[];
}
final faceMlResults = await getSelectedFaceMlResults(fileIds);
if (faceMlResults.isEmpty) return <Embedding>[];
final embeddings = <Embedding>[];
for (var i = 0; i < faceMlResults.length; i++) {
final faceMlResult = faceMlResults[i];
final int faceIndex = faceMlResult.allFaceIds.indexOf(faceIds[i]);
if (faceIndex == -1) {
_logger.severe(
'Could not find faceIndex for faceId ${faceIds[i]} in faceMlResult ${faceMlResult.fileId}',
);
return <Embedding>[];
}
embeddings.add(faceMlResult.faces[faceIndex].embedding);
}
return embeddings;
}
Future<void> updateClusterResult(ClusterResult clusterResult) async {
_logger.fine('updateClusterResult called');
final db = await instance.database;
await db.update(
peopleTable,
{
personIDColumn: clusterResult.personId,
clusterResultColumn: clusterResult.toJsonString(),
centroidColumn: clusterResult.medoid.toString(),
centroidDistanceThresholdColumn: clusterResult.medoidDistanceThreshold,
},
where: '$personIDColumn = ?',
whereArgs: [clusterResult.personId],
);
}
Future<int> deleteClusterResult(int personId) async {
_logger.fine('deleteClusterResult called');
final db = await instance.database;
final deleteCount = await db.delete(
peopleTable,
where: '$personIDColumn = ?',
whereArgs: [personId],
);
_logger.fine('Deleted $deleteCount clusterResults');
return deleteCount;
}
Future<void> deleteAllClusterResults() async {
_logger.fine('deleteAllClusterResults called');
final db = await instance.database;
await db.execute(_deletePeopleTable);
await db.execute(createPeopleTable);
}
// TODO: current function implementation will skip inserting for a similar feedback, which means I can't remove two photos from the same person in a row
Future<void> createClusterFeedback<T extends ClusterFeedback>(
T feedback, {
bool skipIfSimilarFeedbackExists = false,
}) async {
_logger.fine('createClusterFeedback called');
// TODO: this skipping might cause issues for adding photos to the same person in a row!!
if (skipIfSimilarFeedbackExists &&
await doesSimilarClusterFeedbackExist(feedback)) {
_logger.fine(
'ClusterFeedback with ID ${feedback.feedbackID} already has a similar feedback installed. Skipping insert.',
);
return;
}
final db = await instance.database;
await db.insert(
feedbackTable,
{
feedbackIDColumn: feedback.feedbackID,
feedbackTypeColumn: feedback.typeString,
feedbackDataColumn: feedback.toJsonString(),
feedbackTimestampColumn: feedback.timestampString,
feedbackFaceMlVersionColumn: feedback.madeOnFaceMlVersion,
feedbackClusterMlVersionColumn: feedback.madeOnClusterMlVersion,
},
conflictAlgorithm: ConflictAlgorithm.replace,
);
return;
}
Future<bool> doesSimilarClusterFeedbackExist<T extends ClusterFeedback>(
T feedback,
) async {
_logger.fine('doesClusterFeedbackExist called');
final List<T> existingFeedback =
await getAllClusterFeedback<T>(type: feedback.type);
if (existingFeedback.isNotEmpty) {
for (final existingFeedbackItem in existingFeedback) {
assert(
existingFeedbackItem.type == feedback.type,
'Feedback types should be the same!',
);
if (feedback.looselyMatchesMedoid(existingFeedbackItem)) {
_logger.fine(
'ClusterFeedback of type ${feedback.typeString} with ID ${feedback.feedbackID} already has a similar feedback installed!',
);
return true;
}
}
}
return false;
}
/// Returns all the clusterFeedbacks of type [T] which match the given [feedback], sorted by timestamp (latest first).
Future<List<T>> getAllMatchingClusterFeedback<T extends ClusterFeedback>(
T feedback, {
bool sortNewestFirst = true,
}) async {
_logger.fine('getAllMatchingClusterFeedback called');
final List<T> existingFeedback =
await getAllClusterFeedback<T>(type: feedback.type);
final List<T> matchingFeedback = <T>[];
if (existingFeedback.isNotEmpty) {
for (final existingFeedbackItem in existingFeedback) {
assert(
existingFeedbackItem.type == feedback.type,
'Feedback types should be the same!',
);
if (feedback.looselyMatchesMedoid(existingFeedbackItem)) {
_logger.fine(
'ClusterFeedback of type ${feedback.typeString} with ID ${feedback.feedbackID} already has a similar feedback installed!',
);
matchingFeedback.add(existingFeedbackItem);
}
}
}
if (sortNewestFirst) {
matchingFeedback.sort((a, b) => b.timestamp.compareTo(a.timestamp));
}
return matchingFeedback;
}
Future<List<T>> getAllClusterFeedback<T extends ClusterFeedback>({
required FeedbackType type,
int? mlVersion,
int? clusterMlVersion,
}) async {
_logger.fine('getAllClusterFeedback called');
final db = await instance.database;
// TODO: implement the versions for FeedbackType.imageFeedback and FeedbackType.faceFeedback and rename this function to getAllFeedback?
String whereString = '$feedbackTypeColumn = ?';
final List<dynamic> whereArgs = [type.toValueString()];
if (mlVersion != null) {
whereString += ' AND $feedbackFaceMlVersionColumn = ?';
whereArgs.add(mlVersion);
}
if (clusterMlVersion != null) {
whereString += ' AND $feedbackClusterMlVersionColumn = ?';
whereArgs.add(clusterMlVersion);
}
final results = await db.query(
feedbackTable,
where: whereString,
whereArgs: whereArgs,
);
if (results.isNotEmpty) {
if (ClusterFeedback.fromJsonStringRegistry.containsKey(type)) {
final Function(String) fromJsonString =
ClusterFeedback.fromJsonStringRegistry[type]!;
return results
.map((e) => fromJsonString(e[feedbackDataColumn] as String) as T)
.toList();
} else {
_logger.severe(
'No fromJsonString function found for type ${type.name}. This should not happen!',
);
}
}
_logger.fine(
'No clusterFeedback results found of type $type' +
(mlVersion != null ? ' and mlVersion $mlVersion' : '') +
(clusterMlVersion != null
? ' and clusterMlVersion $clusterMlVersion'
: ''),
);
return <T>[];
}
Future<int> deleteClusterFeedback<T extends ClusterFeedback>(
T feedback,
) async {
_logger.fine('deleteClusterFeedback called');
final db = await instance.database;
final deleteCount = await db.delete(
feedbackTable,
where: '$feedbackIDColumn = ?',
whereArgs: [feedback.feedbackID],
);
_logger.fine('Deleted $deleteCount clusterFeedbacks');
return deleteCount;
}
}

View file

@ -26,4 +26,5 @@ enum EventType {
hide,
unhide,
coverChanged,
peopleChanged,
}

View file

@ -0,0 +1,3 @@
import "package:photos/events/event.dart";
class PeopleChangedEvent extends Event {}

View file

@ -0,0 +1,193 @@
import 'dart:math' as math show sin, cos, atan2, sqrt, pow;
import 'package:ml_linalg/linalg.dart';
extension SetVectorValues on Vector {
Vector setValues(int start, int end, Iterable<double> values) {
if (values.length > length) {
throw Exception('Values cannot be larger than vector');
} else if (end - start != values.length) {
throw Exception('Values must be same length as range');
} else if (start < 0 || end > length) {
throw Exception('Range must be within vector');
}
final tempList = toList();
tempList.replaceRange(start, end, values);
final newVector = Vector.fromList(tempList);
return newVector;
}
}
extension SetMatrixValues on Matrix {
Matrix setSubMatrix(
int startRow,
int endRow,
int startColumn,
int endColumn,
Iterable<Iterable<double>> values,
) {
if (values.length > rowCount) {
throw Exception('New values cannot have more rows than original matrix');
} else if (values.elementAt(0).length > columnCount) {
throw Exception(
'New values cannot have more columns than original matrix',
);
} else if (endRow - startRow != values.length) {
throw Exception('Values (number of rows) must be same length as range');
} else if (endColumn - startColumn != values.elementAt(0).length) {
throw Exception(
'Values (number of columns) must be same length as range',
);
} else if (startRow < 0 ||
endRow > rowCount ||
startColumn < 0 ||
endColumn > columnCount) {
throw Exception('Range must be within matrix');
}
final tempList = asFlattenedList
.toList(); // You need `.toList()` here to make sure the list is growable, otherwise `replaceRange` will throw an error
for (var i = startRow; i < endRow; i++) {
tempList.replaceRange(
i * columnCount + startColumn,
i * columnCount + endColumn,
values.elementAt(i).toList(),
);
}
final newMatrix = Matrix.fromFlattenedList(tempList, rowCount, columnCount);
return newMatrix;
}
Matrix setValues(
int startRow,
int endRow,
int startColumn,
int endColumn,
Iterable<double> values,
) {
if ((startRow - endRow) * (startColumn - endColumn) != values.length) {
throw Exception('Values must be same length as range');
} else if (startRow < 0 ||
endRow > rowCount ||
startColumn < 0 ||
endColumn > columnCount) {
throw Exception('Range must be within matrix');
}
final tempList = asFlattenedList
.toList(); // You need `.toList()` here to make sure the list is growable, otherwise `replaceRange` will throw an error
var index = 0;
for (var i = startRow; i < endRow; i++) {
for (var j = startColumn; j < endColumn; j++) {
tempList[i * columnCount + j] = values.elementAt(index);
index++;
}
}
final newMatrix = Matrix.fromFlattenedList(tempList, rowCount, columnCount);
return newMatrix;
}
Matrix setValue(int row, int column, double value) {
if (row < 0 || row > rowCount || column < 0 || column > columnCount) {
throw Exception('Index must be within range of matrix');
}
final tempList = asFlattenedList;
tempList[row * columnCount + column] = value;
final newMatrix = Matrix.fromFlattenedList(tempList, rowCount, columnCount);
return newMatrix;
}
Matrix appendRow(List<double> row) {
final oldNumberOfRows = rowCount;
final oldNumberOfColumns = columnCount;
if (row.length != oldNumberOfColumns) {
throw Exception('Row must have same number of columns as matrix');
}
final flatListMatrix = asFlattenedList;
flatListMatrix.addAll(row);
return Matrix.fromFlattenedList(
flatListMatrix,
oldNumberOfRows + 1,
oldNumberOfColumns,
);
}
}
extension MatrixCalculations on Matrix {
double determinant() {
final int length = rowCount;
if (length != columnCount) {
throw Exception('Matrix must be square');
}
if (length == 1) {
return this[0][0];
} else if (length == 2) {
return this[0][0] * this[1][1] - this[0][1] * this[1][0];
} else {
throw Exception('Determinant for Matrix larger than 2x2 not implemented');
}
}
/// Computes the singular value decomposition of a matrix, using https://lucidar.me/en/mathematics/singular-value-decomposition-of-a-2x2-matrix/ as reference, but with slightly different signs for the second columns of U and V
Map<String, dynamic> svd() {
if (rowCount != 2 || columnCount != 2) {
throw Exception('Matrix must be 2x2');
}
final a = this[0][0];
final b = this[0][1];
final c = this[1][0];
final d = this[1][1];
// Computation of U matrix
final tempCalc = a * a + b * b - c * c - d * d;
final theta = 0.5 * math.atan2(2 * a * c + 2 * b * d, tempCalc);
final U = Matrix.fromList([
[math.cos(theta), math.sin(theta)],
[math.sin(theta), -math.cos(theta)],
]);
// Computation of S matrix
// ignore: non_constant_identifier_names
final S1 = a * a + b * b + c * c + d * d;
// ignore: non_constant_identifier_names
final S2 =
math.sqrt(math.pow(tempCalc, 2) + 4 * math.pow(a * c + b * d, 2));
final sigma1 = math.sqrt((S1 + S2) / 2);
final sigma2 = math.sqrt((S1 - S2) / 2);
final S = Vector.fromList([sigma1, sigma2]);
// Computation of V matrix
final tempCalc2 = a * a - b * b + c * c - d * d;
final phi = 0.5 * math.atan2(2 * a * b + 2 * c * d, tempCalc2);
final s11 = (a * math.cos(theta) + c * math.sin(theta)) * math.cos(phi) +
(b * math.cos(theta) + d * math.sin(theta)) * math.sin(phi);
final s22 = (a * math.sin(theta) - c * math.cos(theta)) * math.sin(phi) +
(-b * math.sin(theta) + d * math.cos(theta)) * math.cos(phi);
final V = Matrix.fromList([
[s11.sign * math.cos(phi), s22.sign * math.sin(phi)],
[s11.sign * math.sin(phi), -s22.sign * math.cos(phi)],
]);
return {
'U': U,
'S': S,
'V': V,
};
}
int matrixRank() {
final svdResult = svd();
final Vector S = svdResult['S']!;
final rank = S.toList().where((element) => element > 1e-10).length;
return rank;
}
}
extension TransformMatrix on Matrix {
List<List<double>> to2DList() {
final List<List<double>> outerList = [];
for (var i = 0; i < rowCount; i++) {
final innerList = this[i].toList();
outerList.add(innerList);
}
return outerList;
}
}

679
mobile/lib/face/db.dart Normal file
View file

@ -0,0 +1,679 @@
import 'dart:async';
import "dart:math";
import "dart:typed_data";
import "package:collection/collection.dart";
import "package:flutter/foundation.dart";
import 'package:logging/logging.dart';
import 'package:path/path.dart' show join;
import 'package:path_provider/path_provider.dart';
import 'package:photos/face/db_fields.dart';
import "package:photos/face/db_model_mappers.dart";
import "package:photos/face/model/face.dart";
import "package:photos/face/model/person.dart";
import "package:photos/models/file/file.dart";
import "package:photos/services/face_ml/blur_detection/blur_constants.dart";
import 'package:sqflite/sqflite.dart';
/// Stores all data for the ML-related features. The database can be accessed by `MlDataDB.instance.database`.
///
/// This includes:
/// [facesTable] - Stores all the detected faces and its embeddings in the images.
/// [peopleTable] - Stores all the clusters of faces which are considered to be the same person.
class FaceMLDataDB {
static final Logger _logger = Logger("FaceMLDataDB");
static const _databaseName = "ente.face_ml_db.db";
static const _databaseVersion = 1;
FaceMLDataDB._privateConstructor();
static final FaceMLDataDB instance = FaceMLDataDB._privateConstructor();
static Future<Database>? _dbFuture;
Future<Database> get database async {
_dbFuture ??= _initDatabase();
return _dbFuture!;
}
Future<Database> _initDatabase() async {
final documentsDirectory = await getApplicationDocumentsDirectory();
final String databaseDirectory =
join(documentsDirectory.path, _databaseName);
return await openDatabase(
databaseDirectory,
version: _databaseVersion,
onCreate: _onCreate,
);
}
Future _onCreate(Database db, int version) async {
await db.execute(createFacesTable);
await db.execute(createPeopleTable);
await db.execute(createClusterTable);
await db.execute(createClusterSummaryTable);
await db.execute(createNotPersonFeedbackTable);
}
// bulkInsertFaces inserts the faces in the database in batches of 1000.
// This is done to avoid the error "too many SQL variables" when inserting
// a large number of faces.
Future<void> bulkInsertFaces(List<Face> faces) async {
final db = await instance.database;
const batchSize = 500;
final numBatches = (faces.length / batchSize).ceil();
for (int i = 0; i < numBatches; i++) {
final start = i * batchSize;
final end = min((i + 1) * batchSize, faces.length);
final batch = faces.sublist(start, end);
final batchInsert = db.batch();
for (final face in batch) {
batchInsert.insert(
facesTable,
mapRemoteToFaceDB(face),
conflictAlgorithm: ConflictAlgorithm.ignore,
);
}
await batchInsert.commit(noResult: true);
}
}
Future<Set<int>> getIndexedFileIds() async {
final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT DISTINCT $fileIDColumn FROM $facesTable',
);
return maps.map((e) => e[fileIDColumn] as int).toSet();
}
Future<Map<int, int>> clusterIdToFaceCount() async {
final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT $cluserIDColumn, COUNT(*) as count FROM $facesTable where $cluserIDColumn IS NOT NULL GROUP BY $cluserIDColumn ',
);
final Map<int, int> result = {};
for (final map in maps) {
result[map[cluserIDColumn] as int] = map['count'] as int;
}
return result;
}
Future<Set<int>> getPersonIgnoredClusters(String personID) async {
final db = await instance.database;
// find out clusterIds that are assigned to other persons using the clusters table
final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT $cluserIDColumn FROM $clustersTable WHERE $personIdColumn != ? AND $personIdColumn IS NOT NULL',
[personID],
);
final Set<int> ignoredClusterIDs =
maps.map((e) => e[cluserIDColumn] as int).toSet();
final List<Map<String, dynamic>> rejectMaps = await db.rawQuery(
'SELECT $cluserIDColumn FROM $notPersonFeedback WHERE $personIdColumn = ?',
[personID],
);
final Set<int> rejectClusterIDs =
rejectMaps.map((e) => e[cluserIDColumn] as int).toSet();
return ignoredClusterIDs.union(rejectClusterIDs);
}
Future<Set<int>> getPersonClusterIDs(String personID) async {
final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT $cluserIDColumn FROM $clustersTable WHERE $personIdColumn = ?',
[personID],
);
return maps.map((e) => e[cluserIDColumn] as int).toSet();
}
Future<void> clearTable() async {
final db = await instance.database;
await db.delete(facesTable);
await db.delete(createClusterTable);
await db.delete(clusterSummaryTable);
await db.delete(peopleTable);
await db.delete(notPersonFeedback);
}
Future<Iterable<Uint8List>> getFaceEmbeddingsForCluster(
int clusterID, {
int? limit,
}) async {
final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT $faceEmbeddingBlob FROM $facesTable WHERE $cluserIDColumn = ? ${limit != null ? 'LIMIT $limit' : ''}',
[clusterID],
);
return maps.map((e) => e[faceEmbeddingBlob] as Uint8List);
}
Future<Map<int, int>> getFileIdToCount() async {
final Map<int, int> result = {};
final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT $fileIDColumn, COUNT(*) as count FROM $facesTable where $faceScore > 0.8 GROUP BY $fileIDColumn',
);
for (final map in maps) {
result[map[fileIDColumn] as int] = map['count'] as int;
}
return result;
}
Future<Face?> getCoverFaceForPerson({
required int recentFileID,
String? personID,
int? clusterID,
}) async {
// read person from db
final db = await instance.database;
if (personID != null) {
final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT * FROM $peopleTable where $idColumn = ?',
[personID],
);
if (maps.isEmpty) {
throw Exception("Person with id $personID not found");
}
final person = mapRowToPerson(maps.first);
final List<int> fileId = [recentFileID];
int? avatarFileId;
if (person.attr.avatarFaceId != null) {
avatarFileId = int.tryParse(person.attr.avatarFaceId!.split('-')[0]);
if (avatarFileId != null) {
fileId.add(avatarFileId);
}
}
final cluterRows = await db.query(
clustersTable,
columns: [cluserIDColumn],
where: '$personIdColumn = ?',
whereArgs: [personID],
);
final clusterIDs =
cluterRows.map((e) => e[cluserIDColumn] as int).toList();
final List<Map<String, dynamic>> faceMaps = await db.rawQuery(
'SELECT * FROM $facesTable where $faceClusterId IN (${clusterIDs.join(",")}) AND $fileIDColumn in (${fileId.join(",")}) AND $faceScore > 0.8 ORDER BY $faceScore DESC',
);
if (faceMaps.isNotEmpty) {
if (avatarFileId != null) {
final row = faceMaps.firstWhereOrNull(
(element) => (element[fileIDColumn] as int) == avatarFileId,
);
if (row != null) {
return mapRowToFace(row);
}
}
return mapRowToFace(faceMaps.first);
}
}
if (clusterID != null) {
final clusterIDs = [clusterID];
final List<Map<String, dynamic>> faceMaps = await db.rawQuery(
'SELECT * FROM $facesTable where $faceClusterId IN (${clusterIDs.join(",")}) AND $fileIDColumn = $recentFileID ',
);
if (faceMaps.isNotEmpty) {
return mapRowToFace(faceMaps.first);
}
}
if (personID == null && clusterID == null) {
throw Exception("personID and clusterID cannot be null");
}
return null;
}
Future<List<Face>> getFacesForGivenFileID(int fileUploadID) async {
final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.query(
facesTable,
columns: [
fileIDColumn,
faceIDColumn,
faceDetectionColumn,
faceEmbeddingBlob,
faceScore,
faceBlur,
faceClusterId,
faceClosestDistColumn,
faceClosestFaceID,
faceConfirmedColumn,
mlVersionColumn,
],
where: '$fileIDColumn = ?',
whereArgs: [fileUploadID],
);
return maps.map((e) => mapRowToFace(e)).toList();
}
Future<Face?> getFaceForFaceID(String faceID) async {
final db = await instance.database;
final result = await db.rawQuery(
'SELECT * FROM $facesTable where $faceIDColumn = ?',
[faceID],
);
if (result.isEmpty) {
return null;
}
return mapRowToFace(result.first);
}
Future<Map<String, int?>> getFaceIdsToClusterIds(
Iterable<String> faceIds,
) async {
final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT $faceIDColumn, $faceClusterId FROM $facesTable where $faceIDColumn IN (${faceIds.map((id) => "'$id'").join(",")})',
);
final Map<String, int?> result = {};
for (final map in maps) {
result[map[faceIDColumn] as String] = map[faceClusterId] as int?;
}
return result;
}
Future<Map<int, Set<int>>> getFileIdToClusterIds() async {
final Map<int, Set<int>> result = {};
final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT $faceClusterId, $fileIDColumn FROM $facesTable where $faceClusterId IS NOT NULL',
);
for (final map in maps) {
final personID = map[faceClusterId] as int;
final fileID = map[fileIDColumn] as int;
result[fileID] = (result[fileID] ?? {})..add(personID);
}
return result;
}
Future<void> updatePersonIDForFaceIDIFNotSet(
Map<String, int> faceIDToPersonID,
) async {
final db = await instance.database;
// Start a batch
final batch = db.batch();
for (final map in faceIDToPersonID.entries) {
final faceID = map.key;
final personID = map.value;
batch.update(
facesTable,
{faceClusterId: personID},
where: '$faceIDColumn = ? AND $faceClusterId IS NULL',
whereArgs: [faceID],
);
}
// Commit the batch
await batch.commit(noResult: true);
}
Future<void> forceUpdateClusterIds(
Map<String, int> faceIDToPersonID,
) async {
final db = await instance.database;
// Start a batch
final batch = db.batch();
for (final map in faceIDToPersonID.entries) {
final faceID = map.key;
final personID = map.value;
batch.update(
facesTable,
{faceClusterId: personID},
where: '$faceIDColumn = ?',
whereArgs: [faceID],
);
}
// Commit the batch
await batch.commit(noResult: true);
}
/// Returns a map of faceID to record of faceClusterID and faceEmbeddingBlob
///
/// Only selects faces with score greater than [minScore] and blur score greater than [minClarity]
Future<Map<String, (int?, Uint8List)>> getFaceEmbeddingMap({
double minScore = 0.8,
int minClarity = kLaplacianThreshold,
int maxRows = 20000,
}) async {
_logger.info('reading as float');
final db = await instance.database;
// Define the batch size
const batchSize = 10000;
int offset = 0;
final Map<String, (int?, Uint8List)> result = {};
while (true) {
// Query a batch of rows
final List<Map<String, dynamic>> maps = await db.query(
facesTable,
columns: [faceIDColumn, faceClusterId, faceEmbeddingBlob],
where: '$faceScore > $minScore and $faceBlur > $minClarity',
limit: batchSize,
offset: offset,
// orderBy: '$faceClusterId DESC',
orderBy: '$faceIDColumn DESC',
);
// Break the loop if no more rows
if (maps.isEmpty) {
break;
}
for (final map in maps) {
final faceID = map[faceIDColumn] as String;
result[faceID] =
(map[faceClusterId] as int?, map[faceEmbeddingBlob] as Uint8List);
}
if (result.length >= 20000) {
break;
}
offset += batchSize;
}
return result;
}
Future<Map<String, Uint8List>> getFaceEmbeddingMapForFile(
List<int> fileIDs,
) async {
_logger.info('reading as float');
final db = await instance.database;
// Define the batch size
const batchSize = 10000;
int offset = 0;
final Map<String, Uint8List> result = {};
while (true) {
// Query a batch of rows
final List<Map<String, dynamic>> maps = await db.query(
facesTable,
columns: [faceIDColumn, faceEmbeddingBlob],
where:
'$faceScore > 0.8 AND $faceBlur > $kLaplacianThreshold AND $fileIDColumn IN (${fileIDs.join(",")})',
limit: batchSize,
offset: offset,
orderBy: '$faceIDColumn DESC',
);
// Break the loop if no more rows
if (maps.isEmpty) {
break;
}
for (final map in maps) {
final faceID = map[faceIDColumn] as String;
result[faceID] = map[faceEmbeddingBlob] as Uint8List;
}
if (result.length > 10000) {
break;
}
offset += batchSize;
}
return result;
}
Future<void> resetClusterIDs() async {
final db = await instance.database;
await db.update(
facesTable,
{faceClusterId: null},
);
}
Future<void> insert(Person p, int cluserID) async {
debugPrint("inserting person");
final db = await instance.database;
await db.insert(
peopleTable,
mapPersonToRow(p),
conflictAlgorithm: ConflictAlgorithm.replace,
);
await db.insert(
clustersTable,
{
personIdColumn: p.remoteID,
cluserIDColumn: cluserID,
},
conflictAlgorithm: ConflictAlgorithm.replace,
);
}
Future<void> updatePerson(Person p) async {
final db = await instance.database;
await db.update(
peopleTable,
mapPersonToRow(p),
where: '$idColumn = ?',
whereArgs: [p.remoteID],
);
}
Future<void> assignClusterToPerson({
required String personID,
required int clusterID,
}) async {
final db = await instance.database;
await db.insert(
clustersTable,
{
personIdColumn: personID,
cluserIDColumn: clusterID,
},
);
}
Future<void> captureNotPersonFeedback({
required String personID,
required int clusterID,
}) async {
final db = await instance.database;
await db.insert(
notPersonFeedback,
{
personIdColumn: personID,
cluserIDColumn: clusterID,
},
);
}
Future<int> removeClusterToPerson({
required String personID,
required int clusterID,
}) async {
final db = await instance.database;
return db.delete(
clustersTable,
where: '$personIdColumn = ? AND $cluserIDColumn = ?',
whereArgs: [personID, clusterID],
);
}
// for a given personID, return a map of clusterID to fileIDs using join query
Future<Map<int, Set<int>>> getFileIdToClusterIDSet(String personID) {
final db = instance.database;
return db.then((db) async {
final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT $clustersTable.$cluserIDColumn, $fileIDColumn FROM $facesTable '
'INNER JOIN $clustersTable '
'ON $facesTable.$faceClusterId = $clustersTable.$cluserIDColumn '
'WHERE $clustersTable.$personIdColumn = ?',
[personID],
);
final Map<int, Set<int>> result = {};
for (final map in maps) {
final clusterID = map[cluserIDColumn] as int;
final fileID = map[fileIDColumn] as int;
result[fileID] = (result[fileID] ?? {})..add(clusterID);
}
return result;
});
}
Future<Map<int, Set<int>>> getFileIdToClusterIDSetForCluster(
Set<int> clusterIDs,
) {
final db = instance.database;
return db.then((db) async {
final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT $cluserIDColumn, $fileIDColumn FROM $facesTable '
'WHERE $cluserIDColumn IN (${clusterIDs.join(",")})',
);
final Map<int, Set<int>> result = {};
for (final map in maps) {
final clusterID = map[cluserIDColumn] as int;
final fileID = map[fileIDColumn] as int;
result[fileID] = (result[fileID] ?? {})..add(clusterID);
}
return result;
});
}
Future<void> clusterSummaryUpdate(Map<int, (Uint8List, int)> summary) async {
final db = await instance.database;
var batch = db.batch();
int batchCounter = 0;
for (final entry in summary.entries) {
if (batchCounter == 400) {
await batch.commit(noResult: true);
batch = db.batch();
batchCounter = 0;
}
final int cluserID = entry.key;
final int count = entry.value.$2;
final Uint8List avg = entry.value.$1;
batch.insert(
clusterSummaryTable,
{
cluserIDColumn: cluserID,
avgColumn: avg,
countColumn: count,
},
conflictAlgorithm: ConflictAlgorithm.replace,
);
batchCounter++;
}
await batch.commit(noResult: true);
}
/// Returns a map of clusterID to (avg embedding, count)
Future<Map<int, (Uint8List, int)>> clusterSummaryAll() async {
final db = await instance.database;
final Map<int, (Uint8List, int)> result = {};
final rows = await db.rawQuery('SELECT * from $clusterSummaryTable');
for (final r in rows) {
final id = r[cluserIDColumn] as int;
final avg = r[avgColumn] as Uint8List;
final count = r[countColumn] as int;
result[id] = (avg, count);
}
return result;
}
Future<Map<int, String>> getCluserIDToPersonMap() async {
final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT $personIdColumn, $cluserIDColumn FROM $clustersTable',
);
final Map<int, String> result = {};
for (final map in maps) {
result[map[cluserIDColumn] as int] = map[personIdColumn] as String;
}
return result;
}
Future<(Map<int, Person>, Map<String, Person>)> getClusterIdToPerson() async {
final db = await instance.database;
final Map<String, Person> peopleMap = await getPeopleMap();
final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT $personIdColumn, $cluserIDColumn FROM $clustersTable',
);
final Map<int, Person> result = {};
for (final map in maps) {
final Person? p = peopleMap[map[personIdColumn] as String];
if (p != null) {
result[map[cluserIDColumn] as int] = p;
} else {
_logger.warning(
'Person with id ${map[personIdColumn]} not found',
);
}
}
return (result, peopleMap);
}
Future<Map<String, Person>> getPeopleMap() async {
final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.query(
peopleTable,
columns: [
idColumn,
nameColumn,
personHiddenColumn,
clusterToFaceIdJson,
coverFaceIDColumn,
],
);
final Map<String, Person> result = {};
for (final map in maps) {
result[map[idColumn] as String] = mapRowToPerson(map);
}
return result;
}
Future<List<Person>> getPeople() async {
final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.query(
peopleTable,
columns: [
idColumn,
nameColumn,
personHiddenColumn,
clusterToFaceIdJson,
coverFaceIDColumn,
],
);
return maps.map((map) => mapRowToPerson(map)).toList();
}
/// WARNING: This will delete ALL data in the database! Only use this for debug/testing purposes!
Future<void> dropClustersAndPeople({bool faces = false}) async {
final db = await instance.database;
if (faces) {
await db.execute(deleteFacesTable);
await db.execute(createFacesTable);
}
await db.execute(deletePeopleTable);
await db.execute(dropClustersTable);
await db.execute(dropClusterSummaryTable);
await db.execute(dropNotPersonFeedbackTable);
// await db.execute(createFacesTable);
await db.execute(createPeopleTable);
await db.execute(createClusterTable);
await db.execute(createNotPersonFeedbackTable);
await db.execute(createClusterSummaryTable);
}
Future<void> removePersonFromFiles(List<EnteFile> files, Person p) async {
final db = await instance.database;
final result = await db.rawQuery(
'SELECT $faceIDColumn FROM $facesTable LEFT JOIN $clustersTable '
'ON $facesTable.$faceClusterId = $clustersTable.$cluserIDColumn '
'WHERE $clustersTable.$personIdColumn = ? AND $facesTable.$fileIDColumn IN (${files.map((e) => e.uploadedFileID).join(",")})',
[p.remoteID],
);
// get max clusterID
final maxRows =
await db.rawQuery('SELECT max($faceClusterId) from $facesTable');
int maxClusterID = maxRows.first.values.first as int;
final Map<String, int> faceIDToClusterID = {};
for (final faceRow in result) {
final faceID = faceRow[faceIDColumn] as String;
faceIDToClusterID[faceID] = maxClusterID + 1;
maxClusterID = maxClusterID + 1;
}
await forceUpdateClusterIds(faceIDToClusterID);
}
}

View file

@ -0,0 +1,99 @@
// Faces Table Fields & Schema Queries
import "package:photos/services/face_ml/blur_detection/blur_constants.dart";
const facesTable = 'faces';
const fileIDColumn = 'file_id';
const faceIDColumn = 'face_id';
const faceDetectionColumn = 'detection';
const faceEmbeddingBlob = 'eBlob';
const faceScore = 'score';
const faceBlur = 'blur';
const faceClusterId = 'cluster_id';
const faceConfirmedColumn = 'confirmed';
const faceClosestDistColumn = 'close_dist';
const faceClosestFaceID = 'close_face_id';
const mlVersionColumn = 'ml_version';
const createFacesTable = '''CREATE TABLE IF NOT EXISTS $facesTable (
$fileIDColumn INTEGER NOT NULL,
$faceIDColumn TEXT NOT NULL,
$faceDetectionColumn TEXT NOT NULL,
$faceEmbeddingBlob BLOB NOT NULL,
$faceScore REAL NOT NULL,
$faceBlur REAL NOT NULL DEFAULT $kLapacianDefault,
$faceClusterId INTEGER,
$faceClosestDistColumn REAL,
$faceClosestFaceID TEXT,
$faceConfirmedColumn INTEGER NOT NULL DEFAULT 0,
$mlVersionColumn INTEGER NOT NULL DEFAULT -1,
PRIMARY KEY($fileIDColumn, $faceIDColumn)
);
''';
const deleteFacesTable = 'DROP TABLE IF EXISTS $facesTable';
// End of Faces Table Fields & Schema Queries
// People Table Fields & Schema Queries
const peopleTable = 'people';
const idColumn = 'id';
const nameColumn = 'name';
const personHiddenColumn = 'hidden';
const clusterToFaceIdJson = 'clusterToFaceIds';
const coverFaceIDColumn = 'cover_face_id';
const createPeopleTable = '''CREATE TABLE IF NOT EXISTS $peopleTable (
$idColumn TEXT NOT NULL UNIQUE,
$nameColumn TEXT NOT NULL DEFAULT '',
$personHiddenColumn INTEGER NOT NULL DEFAULT 0,
$clusterToFaceIdJson TEXT NOT NULL DEFAULT '{}',
$coverFaceIDColumn TEXT,
PRIMARY KEY($idColumn)
);
''';
const deletePeopleTable = 'DROP TABLE IF EXISTS $peopleTable';
//End People Table Fields & Schema Queries
// Clusters Table Fields & Schema Queries
const clustersTable = 'clusters';
const personIdColumn = 'person_id';
const cluserIDColumn = 'cluster_id';
const createClusterTable = '''
CREATE TABLE IF NOT EXISTS $clustersTable (
$personIdColumn TEXT NOT NULL,
$cluserIDColumn INTEGER NOT NULL,
PRIMARY KEY($personIdColumn, $cluserIDColumn)
);
''';
const dropClustersTable = 'DROP TABLE IF EXISTS $clustersTable';
// End Clusters Table Fields & Schema Queries
/// Cluster Summary Table Fields & Schema Queries
const clusterSummaryTable = 'cluster_summary';
const avgColumn = 'avg';
const countColumn = 'count';
const createClusterSummaryTable = '''
CREATE TABLE IF NOT EXISTS $clusterSummaryTable (
$cluserIDColumn INTEGER NOT NULL,
$avgColumn BLOB NOT NULL,
$countColumn INTEGER NOT NULL,
PRIMARY KEY($cluserIDColumn)
);
''';
const dropClusterSummaryTable = 'DROP TABLE IF EXISTS $clusterSummaryTable';
/// End Cluster Summary Table Fields & Schema Queries
/// notPersonFeedback Table Fields & Schema Queries
const notPersonFeedback = 'not_person_feedback';
const createNotPersonFeedbackTable = '''
CREATE TABLE IF NOT EXISTS $notPersonFeedback (
$personIdColumn TEXT NOT NULL,
$cluserIDColumn INTEGER NOT NULL
);
''';
const dropNotPersonFeedbackTable = 'DROP TABLE IF EXISTS $notPersonFeedback';
// End Clusters Table Fields & Schema Queries

View file

@ -0,0 +1,86 @@
import "dart:convert";
import 'package:photos/face/db_fields.dart';
import "package:photos/face/model/detection.dart";
import "package:photos/face/model/face.dart";
import "package:photos/face/model/person.dart";
import 'package:photos/face/model/person_face.dart';
import "package:photos/generated/protos/ente/common/vector.pb.dart";
int boolToSQLInt(bool? value, {bool defaultValue = false}) {
final bool v = value ?? defaultValue;
if (v == false) {
return 0;
} else {
return 1;
}
}
bool sqlIntToBool(int? value, {bool defaultValue = false}) {
final int v = value ?? (defaultValue ? 1 : 0);
if (v == 0) {
return false;
} else {
return true;
}
}
Map<String, dynamic> mapToFaceDB(PersonFace personFace) {
return {
faceIDColumn: personFace.face.faceID,
faceDetectionColumn: json.encode(personFace.face.detection.toJson()),
faceConfirmedColumn: boolToSQLInt(personFace.confirmed),
faceClusterId: personFace.personID,
faceClosestDistColumn: personFace.closeDist,
faceClosestFaceID: personFace.closeFaceID,
};
}
Map<String, dynamic> mapPersonToRow(Person p) {
return {
idColumn: p.remoteID,
nameColumn: p.attr.name,
personHiddenColumn: boolToSQLInt(p.attr.isHidden),
coverFaceIDColumn: p.attr.avatarFaceId,
clusterToFaceIdJson: jsonEncode(p.attr.faces.toList()),
};
}
Person mapRowToPerson(Map<String, dynamic> row) {
return Person(
row[idColumn] as String,
PersonAttr(
name: row[nameColumn] as String,
isHidden: sqlIntToBool(row[personHiddenColumn] as int),
avatarFaceId: row[coverFaceIDColumn] as String?,
faces: (jsonDecode(row[clusterToFaceIdJson]) as List)
.map((e) => e.toString())
.toList(),
),
);
}
Map<String, dynamic> mapRemoteToFaceDB(Face face) {
return {
faceIDColumn: face.faceID,
fileIDColumn: face.fileID,
faceDetectionColumn: json.encode(face.detection.toJson()),
faceEmbeddingBlob: EVector(
values: face.embedding,
).writeToBuffer(),
faceScore: face.score,
faceBlur: face.blur,
mlVersionColumn: 1,
};
}
Face mapRowToFace(Map<String, dynamic> row) {
return Face(
row[faceIDColumn] as String,
row[fileIDColumn] as int,
EVector.fromBuffer(row[faceEmbeddingBlob] as List<int>).values,
row[faceScore] as double,
Detection.fromJson(json.decode(row[faceDetectionColumn] as String)),
row[faceBlur] as double,
);
}

View file

View file

@ -0,0 +1,42 @@
/// Bounding box of a face.
///
/// [`x`] and [y] are the coordinates of the top left corner of the box, so the minimim values
/// [width] and [height] are the width and height of the box.
/// All values are in absolute pixels relative to the original image size.
class FaceBox {
final double x;
final double y;
final double width;
final double height;
FaceBox({
required this.x,
required this.y,
required this.width,
required this.height,
});
factory FaceBox.fromJson(Map<String, dynamic> json) {
return FaceBox(
x: (json['x'] is int
? (json['x'] as int).toDouble()
: json['x'] as double),
y: (json['y'] is int
? (json['y'] as int).toDouble()
: json['y'] as double),
width: (json['width'] is int
? (json['width'] as int).toDouble()
: json['width'] as double),
height: (json['height'] is int
? (json['height'] as int).toDouble()
: json['height'] as double),
);
}
Map<String, dynamic> toJson() => {
'x': x,
'y': y,
'width': width,
'height': height,
};
}

View file

@ -0,0 +1,37 @@
import "package:photos/face/model/box.dart";
import "package:photos/face/model/landmark.dart";
class Detection {
FaceBox box;
List<Landmark> landmarks;
Detection({
required this.box,
required this.landmarks,
});
// emoty box
Detection.empty()
: box = FaceBox(
x: 0,
y: 0,
width: 0,
height: 0,
),
landmarks = [];
Map<String, dynamic> toJson() => {
'box': box.toJson(),
'landmarks': landmarks.map((x) => x.toJson()).toList(),
};
factory Detection.fromJson(Map<String, dynamic> json) {
return Detection(
box: FaceBox.fromJson(json['box'] as Map<String, dynamic>),
landmarks: List<Landmark>.from(
json['landmarks']
.map((x) => Landmark.fromJson(x as Map<String, dynamic>)),
),
);
}
}

View file

@ -0,0 +1,43 @@
import "package:photos/face/model/detection.dart";
import "package:photos/services/face_ml/blur_detection/blur_constants.dart";
class Face {
final int fileID;
final String faceID;
final List<double> embedding;
Detection detection;
final double score;
final double blur;
bool get isBlurry => blur < kLaplacianThreshold;
Face(
this.faceID,
this.fileID,
this.embedding,
this.score,
this.detection,
this.blur,
);
factory Face.fromJson(Map<String, dynamic> json) {
return Face(
json['faceID'] as String,
json['fileID'] as int,
List<double>.from(json['embeddings'] as List),
json['score'] as double,
Detection.fromJson(json['detection'] as Map<String, dynamic>),
// high value means t
(json['blur'] ?? kLapacianDefault) as double,
);
}
Map<String, dynamic> toJson() => {
'faceID': faceID,
'fileID': fileID,
'embeddings': embedding,
'detection': detection.toJson(),
'score': score,
'blur': blur,
};
}

View file

@ -0,0 +1,26 @@
// Class for the 'landmark' sub-object
class Landmark {
double x;
double y;
Landmark({
required this.x,
required this.y,
});
Map<String, dynamic> toJson() => {
'x': x,
'y': y,
};
factory Landmark.fromJson(Map<String, dynamic> json) {
return Landmark(
x: (json['x'] is int
? (json['x'] as int).toDouble()
: json['x'] as double),
y: (json['y'] is int
? (json['y'] as int).toDouble()
: json['y'] as double),
);
}
}

View file

@ -0,0 +1,70 @@
class Person {
final String remoteID;
final PersonAttr attr;
Person(
this.remoteID,
this.attr,
);
// copyWith
Person copyWith({
String? remoteID,
PersonAttr? attr,
}) {
return Person(
remoteID ?? this.remoteID,
attr ?? this.attr,
);
}
}
class PersonAttr {
final String name;
final bool isHidden;
String? avatarFaceId;
final List<String> faces;
final String? birthDatae;
PersonAttr({
required this.name,
required this.faces,
this.avatarFaceId,
this.isHidden = false,
this.birthDatae,
});
// copyWith
PersonAttr copyWith({
String? name,
List<String>? faces,
String? avatarFaceId,
bool? isHidden,
String? birthDatae,
}) {
return PersonAttr(
name: name ?? this.name,
faces: faces ?? this.faces,
avatarFaceId: avatarFaceId ?? this.avatarFaceId,
isHidden: isHidden ?? this.isHidden,
birthDatae: birthDatae ?? this.birthDatae,
);
}
// toJson
Map<String, dynamic> toJson() => {
'name': name,
'faces': faces.toList(),
'avatarFaceId': avatarFaceId,
'isHidden': isHidden,
'birthDatae': birthDatae,
};
// fromJson
factory PersonAttr.fromJson(Map<String, dynamic> json) {
return PersonAttr(
name: json['name'] as String,
faces: List<String>.from(json['faces'] as List<dynamic>),
avatarFaceId: json['avatarFaceId'] as String?,
isHidden: json['isHidden'] as bool? ?? false,
birthDatae: json['birthDatae'] as String?,
);
}
}

View file

@ -0,0 +1,37 @@
import 'package:photos/face/model/face.dart';
class PersonFace {
final Face face;
int? personID;
bool? confirmed;
double? closeDist;
String? closeFaceID;
PersonFace(
this.face,
this.personID,
this.closeDist,
this.closeFaceID, {
this.confirmed,
});
// toJson
Map<String, dynamic> toJson() => {
'face': face.toJson(),
'personID': personID,
'confirmed': confirmed ?? false,
'close_dist': closeDist,
'close_face_id': closeFaceID,
};
// fromJson
factory PersonFace.fromJson(Map<String, dynamic> json) {
return PersonFace(
Face.fromJson(json['face'] as Map<String, dynamic>),
json['personID'] as int?,
json['close_dist'] as double?,
json['close_face_id'] as String?,
confirmed: json['confirmed'] as bool?,
);
}
}

View file

@ -0,0 +1,44 @@
// import "dart:io";
import "package:dio/dio.dart";
import "package:logging/logging.dart";
import "package:photos/core/configuration.dart";
import "package:photos/core/network/network.dart";
import "package:photos/face/model/face.dart";
final _logger = Logger("import_from_zip");
Future<List<Face>> downloadZip() async {
final List<Face> result = [];
for (int i = 0; i < 2; i++) {
_logger.info("downloading $i");
final remoteZipUrl = "http://192.168.1.13:8700/ml/cx_ml_json_${i}.json";
final response = await NetworkClient.instance.getDio().get(
remoteZipUrl,
options: Options(
headers: {"X-Auth-Token": Configuration.instance.getToken()},
),
);
if (response.statusCode != 200) {
_logger.warning('download failed ${response.toString()}');
throw Exception("download failed");
}
final res = response.data as List<dynamic>;
for (final item in res) {
try {
result.add(Face.fromJson(item));
} catch (e) {
_logger.warning("failed to parse $item");
rethrow;
}
}
}
Set<String> faceID = {};
for (final face in result) {
if (faceID.contains(face.faceID)) {
_logger.warning("duplicate faceID ${face.faceID}");
}
faceID.add(face.faceID);
}
return result;
}

View file

@ -973,6 +973,7 @@ class MessageLookup extends MessageLookupByLibrary {
"paymentFailedWithReason": m36,
"pendingItems": MessageLookupByLibrary.simpleMessage("Pending items"),
"pendingSync": MessageLookupByLibrary.simpleMessage("Pending sync"),
"people": MessageLookupByLibrary.simpleMessage("People"),
"peopleUsingYourCode":
MessageLookupByLibrary.simpleMessage("People using your code"),
"permDeleteWarning": MessageLookupByLibrary.simpleMessage(

View file

@ -8158,6 +8158,16 @@ class S {
);
}
/// `People`
String get people {
return Intl.message(
'People',
name: 'people',
desc: '',
args: [],
);
}
/// `Contents`
String get contents {
return Intl.message(

View file

@ -0,0 +1,111 @@
//
// Generated code. Do not modify.
// source: ente/common/box.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import
import 'dart:core' as $core;
import 'package:protobuf/protobuf.dart' as $pb;
/// CenterBox is a box where x,y is the center of the box
class CenterBox extends $pb.GeneratedMessage {
factory CenterBox({
$core.double? x,
$core.double? y,
$core.double? height,
$core.double? width,
}) {
final $result = create();
if (x != null) {
$result.x = x;
}
if (y != null) {
$result.y = y;
}
if (height != null) {
$result.height = height;
}
if (width != null) {
$result.width = width;
}
return $result;
}
CenterBox._() : super();
factory CenterBox.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r);
factory CenterBox.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r);
static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'CenterBox', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.common'), createEmptyInstance: create)
..a<$core.double>(1, _omitFieldNames ? '' : 'x', $pb.PbFieldType.OF)
..a<$core.double>(2, _omitFieldNames ? '' : 'y', $pb.PbFieldType.OF)
..a<$core.double>(3, _omitFieldNames ? '' : 'height', $pb.PbFieldType.OF)
..a<$core.double>(4, _omitFieldNames ? '' : 'width', $pb.PbFieldType.OF)
..hasRequiredFields = false
;
@$core.Deprecated(
'Using this can add significant overhead to your binary. '
'Use [GeneratedMessageGenericExtensions.deepCopy] instead. '
'Will be removed in next major version')
CenterBox clone() => CenterBox()..mergeFromMessage(this);
@$core.Deprecated(
'Using this can add significant overhead to your binary. '
'Use [GeneratedMessageGenericExtensions.rebuild] instead. '
'Will be removed in next major version')
CenterBox copyWith(void Function(CenterBox) updates) => super.copyWith((message) => updates(message as CenterBox)) as CenterBox;
$pb.BuilderInfo get info_ => _i;
@$core.pragma('dart2js:noInline')
static CenterBox create() => CenterBox._();
CenterBox createEmptyInstance() => create();
static $pb.PbList<CenterBox> createRepeated() => $pb.PbList<CenterBox>();
@$core.pragma('dart2js:noInline')
static CenterBox getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor<CenterBox>(create);
static CenterBox? _defaultInstance;
@$pb.TagNumber(1)
$core.double get x => $_getN(0);
@$pb.TagNumber(1)
set x($core.double v) { $_setFloat(0, v); }
@$pb.TagNumber(1)
$core.bool hasX() => $_has(0);
@$pb.TagNumber(1)
void clearX() => clearField(1);
@$pb.TagNumber(2)
$core.double get y => $_getN(1);
@$pb.TagNumber(2)
set y($core.double v) { $_setFloat(1, v); }
@$pb.TagNumber(2)
$core.bool hasY() => $_has(1);
@$pb.TagNumber(2)
void clearY() => clearField(2);
@$pb.TagNumber(3)
$core.double get height => $_getN(2);
@$pb.TagNumber(3)
set height($core.double v) { $_setFloat(2, v); }
@$pb.TagNumber(3)
$core.bool hasHeight() => $_has(2);
@$pb.TagNumber(3)
void clearHeight() => clearField(3);
@$pb.TagNumber(4)
$core.double get width => $_getN(3);
@$pb.TagNumber(4)
set width($core.double v) { $_setFloat(3, v); }
@$pb.TagNumber(4)
$core.bool hasWidth() => $_has(3);
@$pb.TagNumber(4)
void clearWidth() => clearField(4);
}
const _omitFieldNames = $core.bool.fromEnvironment('protobuf.omit_field_names');
const _omitMessageNames = $core.bool.fromEnvironment('protobuf.omit_message_names');

View file

@ -0,0 +1,11 @@
//
// Generated code. Do not modify.
// source: ente/common/box.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import

View file

@ -0,0 +1,38 @@
//
// Generated code. Do not modify.
// source: ente/common/box.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import
import 'dart:convert' as $convert;
import 'dart:core' as $core;
import 'dart:typed_data' as $typed_data;
@$core.Deprecated('Use centerBoxDescriptor instead')
const CenterBox$json = {
'1': 'CenterBox',
'2': [
{'1': 'x', '3': 1, '4': 1, '5': 2, '9': 0, '10': 'x', '17': true},
{'1': 'y', '3': 2, '4': 1, '5': 2, '9': 1, '10': 'y', '17': true},
{'1': 'height', '3': 3, '4': 1, '5': 2, '9': 2, '10': 'height', '17': true},
{'1': 'width', '3': 4, '4': 1, '5': 2, '9': 3, '10': 'width', '17': true},
],
'8': [
{'1': '_x'},
{'1': '_y'},
{'1': '_height'},
{'1': '_width'},
],
};
/// Descriptor for `CenterBox`. Decode as a `google.protobuf.DescriptorProto`.
final $typed_data.Uint8List centerBoxDescriptor = $convert.base64Decode(
'CglDZW50ZXJCb3gSEQoBeBgBIAEoAkgAUgF4iAEBEhEKAXkYAiABKAJIAVIBeYgBARIbCgZoZW'
'lnaHQYAyABKAJIAlIGaGVpZ2h0iAEBEhkKBXdpZHRoGAQgASgCSANSBXdpZHRoiAEBQgQKAl94'
'QgQKAl95QgkKB19oZWlnaHRCCAoGX3dpZHRo');

View file

@ -0,0 +1,14 @@
//
// Generated code. Do not modify.
// source: ente/common/box.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names
// ignore_for_file: deprecated_member_use_from_same_package, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import
export 'box.pb.dart';

View file

@ -0,0 +1,83 @@
//
// Generated code. Do not modify.
// source: ente/common/point.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import
import 'dart:core' as $core;
import 'package:protobuf/protobuf.dart' as $pb;
/// EPoint is a point in 2D space
class EPoint extends $pb.GeneratedMessage {
factory EPoint({
$core.double? x,
$core.double? y,
}) {
final $result = create();
if (x != null) {
$result.x = x;
}
if (y != null) {
$result.y = y;
}
return $result;
}
EPoint._() : super();
factory EPoint.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r);
factory EPoint.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r);
static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'EPoint', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.common'), createEmptyInstance: create)
..a<$core.double>(1, _omitFieldNames ? '' : 'x', $pb.PbFieldType.OF)
..a<$core.double>(2, _omitFieldNames ? '' : 'y', $pb.PbFieldType.OF)
..hasRequiredFields = false
;
@$core.Deprecated(
'Using this can add significant overhead to your binary. '
'Use [GeneratedMessageGenericExtensions.deepCopy] instead. '
'Will be removed in next major version')
EPoint clone() => EPoint()..mergeFromMessage(this);
@$core.Deprecated(
'Using this can add significant overhead to your binary. '
'Use [GeneratedMessageGenericExtensions.rebuild] instead. '
'Will be removed in next major version')
EPoint copyWith(void Function(EPoint) updates) => super.copyWith((message) => updates(message as EPoint)) as EPoint;
$pb.BuilderInfo get info_ => _i;
@$core.pragma('dart2js:noInline')
static EPoint create() => EPoint._();
EPoint createEmptyInstance() => create();
static $pb.PbList<EPoint> createRepeated() => $pb.PbList<EPoint>();
@$core.pragma('dart2js:noInline')
static EPoint getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor<EPoint>(create);
static EPoint? _defaultInstance;
@$pb.TagNumber(1)
$core.double get x => $_getN(0);
@$pb.TagNumber(1)
set x($core.double v) { $_setFloat(0, v); }
@$pb.TagNumber(1)
$core.bool hasX() => $_has(0);
@$pb.TagNumber(1)
void clearX() => clearField(1);
@$pb.TagNumber(2)
$core.double get y => $_getN(1);
@$pb.TagNumber(2)
set y($core.double v) { $_setFloat(1, v); }
@$pb.TagNumber(2)
$core.bool hasY() => $_has(1);
@$pb.TagNumber(2)
void clearY() => clearField(2);
}
const _omitFieldNames = $core.bool.fromEnvironment('protobuf.omit_field_names');
const _omitMessageNames = $core.bool.fromEnvironment('protobuf.omit_message_names');

View file

@ -0,0 +1,11 @@
//
// Generated code. Do not modify.
// source: ente/common/point.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import

View file

@ -0,0 +1,33 @@
//
// Generated code. Do not modify.
// source: ente/common/point.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import
import 'dart:convert' as $convert;
import 'dart:core' as $core;
import 'dart:typed_data' as $typed_data;
@$core.Deprecated('Use ePointDescriptor instead')
const EPoint$json = {
'1': 'EPoint',
'2': [
{'1': 'x', '3': 1, '4': 1, '5': 2, '9': 0, '10': 'x', '17': true},
{'1': 'y', '3': 2, '4': 1, '5': 2, '9': 1, '10': 'y', '17': true},
],
'8': [
{'1': '_x'},
{'1': '_y'},
],
};
/// Descriptor for `EPoint`. Decode as a `google.protobuf.DescriptorProto`.
final $typed_data.Uint8List ePointDescriptor = $convert.base64Decode(
'CgZFUG9pbnQSEQoBeBgBIAEoAkgAUgF4iAEBEhEKAXkYAiABKAJIAVIBeYgBAUIECgJfeEIECg'
'JfeQ==');

View file

@ -0,0 +1,14 @@
//
// Generated code. Do not modify.
// source: ente/common/point.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names
// ignore_for_file: deprecated_member_use_from_same_package, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import
export 'point.pb.dart';

View file

@ -0,0 +1,64 @@
//
// Generated code. Do not modify.
// source: ente/common/vector.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import
import 'dart:core' as $core;
import 'package:protobuf/protobuf.dart' as $pb;
/// Vector is generic message for dealing with lists of doubles
/// It should ideally be used independently and not as a submessage
class EVector extends $pb.GeneratedMessage {
factory EVector({
$core.Iterable<$core.double>? values,
}) {
final $result = create();
if (values != null) {
$result.values.addAll(values);
}
return $result;
}
EVector._() : super();
factory EVector.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r);
factory EVector.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r);
static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'EVector', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.common'), createEmptyInstance: create)
..p<$core.double>(1, _omitFieldNames ? '' : 'values', $pb.PbFieldType.KD)
..hasRequiredFields = false
;
@$core.Deprecated(
'Using this can add significant overhead to your binary. '
'Use [GeneratedMessageGenericExtensions.deepCopy] instead. '
'Will be removed in next major version')
EVector clone() => EVector()..mergeFromMessage(this);
@$core.Deprecated(
'Using this can add significant overhead to your binary. '
'Use [GeneratedMessageGenericExtensions.rebuild] instead. '
'Will be removed in next major version')
EVector copyWith(void Function(EVector) updates) => super.copyWith((message) => updates(message as EVector)) as EVector;
$pb.BuilderInfo get info_ => _i;
@$core.pragma('dart2js:noInline')
static EVector create() => EVector._();
EVector createEmptyInstance() => create();
static $pb.PbList<EVector> createRepeated() => $pb.PbList<EVector>();
@$core.pragma('dart2js:noInline')
static EVector getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor<EVector>(create);
static EVector? _defaultInstance;
@$pb.TagNumber(1)
$core.List<$core.double> get values => $_getList(0);
}
const _omitFieldNames = $core.bool.fromEnvironment('protobuf.omit_field_names');
const _omitMessageNames = $core.bool.fromEnvironment('protobuf.omit_message_names');

View file

@ -0,0 +1,11 @@
//
// Generated code. Do not modify.
// source: ente/common/vector.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import

View file

@ -0,0 +1,27 @@
//
// Generated code. Do not modify.
// source: ente/common/vector.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import
import 'dart:convert' as $convert;
import 'dart:core' as $core;
import 'dart:typed_data' as $typed_data;
@$core.Deprecated('Use eVectorDescriptor instead')
const EVector$json = {
'1': 'EVector',
'2': [
{'1': 'values', '3': 1, '4': 3, '5': 1, '10': 'values'},
],
};
/// Descriptor for `EVector`. Decode as a `google.protobuf.DescriptorProto`.
final $typed_data.Uint8List eVectorDescriptor = $convert.base64Decode(
'CgdFVmVjdG9yEhYKBnZhbHVlcxgBIAMoAVIGdmFsdWVz');

View file

@ -0,0 +1,14 @@
//
// Generated code. Do not modify.
// source: ente/common/vector.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names
// ignore_for_file: deprecated_member_use_from_same_package, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import
export 'vector.pb.dart';

View file

@ -0,0 +1,169 @@
//
// Generated code. Do not modify.
// source: ente/ml/face.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import
import 'dart:core' as $core;
import 'package:protobuf/protobuf.dart' as $pb;
import '../common/box.pb.dart' as $0;
import '../common/point.pb.dart' as $1;
class Detection extends $pb.GeneratedMessage {
factory Detection({
$0.CenterBox? box,
$1.EPoint? landmarks,
}) {
final $result = create();
if (box != null) {
$result.box = box;
}
if (landmarks != null) {
$result.landmarks = landmarks;
}
return $result;
}
Detection._() : super();
factory Detection.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r);
factory Detection.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r);
static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'Detection', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.ml'), createEmptyInstance: create)
..aOM<$0.CenterBox>(1, _omitFieldNames ? '' : 'box', subBuilder: $0.CenterBox.create)
..aOM<$1.EPoint>(2, _omitFieldNames ? '' : 'landmarks', subBuilder: $1.EPoint.create)
..hasRequiredFields = false
;
@$core.Deprecated(
'Using this can add significant overhead to your binary. '
'Use [GeneratedMessageGenericExtensions.deepCopy] instead. '
'Will be removed in next major version')
Detection clone() => Detection()..mergeFromMessage(this);
@$core.Deprecated(
'Using this can add significant overhead to your binary. '
'Use [GeneratedMessageGenericExtensions.rebuild] instead. '
'Will be removed in next major version')
Detection copyWith(void Function(Detection) updates) => super.copyWith((message) => updates(message as Detection)) as Detection;
$pb.BuilderInfo get info_ => _i;
@$core.pragma('dart2js:noInline')
static Detection create() => Detection._();
Detection createEmptyInstance() => create();
static $pb.PbList<Detection> createRepeated() => $pb.PbList<Detection>();
@$core.pragma('dart2js:noInline')
static Detection getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor<Detection>(create);
static Detection? _defaultInstance;
@$pb.TagNumber(1)
$0.CenterBox get box => $_getN(0);
@$pb.TagNumber(1)
set box($0.CenterBox v) { setField(1, v); }
@$pb.TagNumber(1)
$core.bool hasBox() => $_has(0);
@$pb.TagNumber(1)
void clearBox() => clearField(1);
@$pb.TagNumber(1)
$0.CenterBox ensureBox() => $_ensure(0);
@$pb.TagNumber(2)
$1.EPoint get landmarks => $_getN(1);
@$pb.TagNumber(2)
set landmarks($1.EPoint v) { setField(2, v); }
@$pb.TagNumber(2)
$core.bool hasLandmarks() => $_has(1);
@$pb.TagNumber(2)
void clearLandmarks() => clearField(2);
@$pb.TagNumber(2)
$1.EPoint ensureLandmarks() => $_ensure(1);
}
class Face extends $pb.GeneratedMessage {
factory Face({
$core.String? id,
Detection? detection,
$core.double? confidence,
}) {
final $result = create();
if (id != null) {
$result.id = id;
}
if (detection != null) {
$result.detection = detection;
}
if (confidence != null) {
$result.confidence = confidence;
}
return $result;
}
Face._() : super();
factory Face.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r);
factory Face.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r);
static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'Face', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.ml'), createEmptyInstance: create)
..aOS(1, _omitFieldNames ? '' : 'id')
..aOM<Detection>(2, _omitFieldNames ? '' : 'detection', subBuilder: Detection.create)
..a<$core.double>(3, _omitFieldNames ? '' : 'confidence', $pb.PbFieldType.OF)
..hasRequiredFields = false
;
@$core.Deprecated(
'Using this can add significant overhead to your binary. '
'Use [GeneratedMessageGenericExtensions.deepCopy] instead. '
'Will be removed in next major version')
Face clone() => Face()..mergeFromMessage(this);
@$core.Deprecated(
'Using this can add significant overhead to your binary. '
'Use [GeneratedMessageGenericExtensions.rebuild] instead. '
'Will be removed in next major version')
Face copyWith(void Function(Face) updates) => super.copyWith((message) => updates(message as Face)) as Face;
$pb.BuilderInfo get info_ => _i;
@$core.pragma('dart2js:noInline')
static Face create() => Face._();
Face createEmptyInstance() => create();
static $pb.PbList<Face> createRepeated() => $pb.PbList<Face>();
@$core.pragma('dart2js:noInline')
static Face getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor<Face>(create);
static Face? _defaultInstance;
@$pb.TagNumber(1)
$core.String get id => $_getSZ(0);
@$pb.TagNumber(1)
set id($core.String v) { $_setString(0, v); }
@$pb.TagNumber(1)
$core.bool hasId() => $_has(0);
@$pb.TagNumber(1)
void clearId() => clearField(1);
@$pb.TagNumber(2)
Detection get detection => $_getN(1);
@$pb.TagNumber(2)
set detection(Detection v) { setField(2, v); }
@$pb.TagNumber(2)
$core.bool hasDetection() => $_has(1);
@$pb.TagNumber(2)
void clearDetection() => clearField(2);
@$pb.TagNumber(2)
Detection ensureDetection() => $_ensure(1);
@$pb.TagNumber(3)
$core.double get confidence => $_getN(2);
@$pb.TagNumber(3)
set confidence($core.double v) { $_setFloat(2, v); }
@$pb.TagNumber(3)
$core.bool hasConfidence() => $_has(2);
@$pb.TagNumber(3)
void clearConfidence() => clearField(3);
}
const _omitFieldNames = $core.bool.fromEnvironment('protobuf.omit_field_names');
const _omitMessageNames = $core.bool.fromEnvironment('protobuf.omit_message_names');

View file

@ -0,0 +1,11 @@
//
// Generated code. Do not modify.
// source: ente/ml/face.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import

View file

@ -0,0 +1,55 @@
//
// Generated code. Do not modify.
// source: ente/ml/face.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import
import 'dart:convert' as $convert;
import 'dart:core' as $core;
import 'dart:typed_data' as $typed_data;
@$core.Deprecated('Use detectionDescriptor instead')
const Detection$json = {
'1': 'Detection',
'2': [
{'1': 'box', '3': 1, '4': 1, '5': 11, '6': '.ente.common.CenterBox', '9': 0, '10': 'box', '17': true},
{'1': 'landmarks', '3': 2, '4': 1, '5': 11, '6': '.ente.common.EPoint', '9': 1, '10': 'landmarks', '17': true},
],
'8': [
{'1': '_box'},
{'1': '_landmarks'},
],
};
/// Descriptor for `Detection`. Decode as a `google.protobuf.DescriptorProto`.
final $typed_data.Uint8List detectionDescriptor = $convert.base64Decode(
'CglEZXRlY3Rpb24SLQoDYm94GAEgASgLMhYuZW50ZS5jb21tb24uQ2VudGVyQm94SABSA2JveI'
'gBARI2CglsYW5kbWFya3MYAiABKAsyEy5lbnRlLmNvbW1vbi5FUG9pbnRIAVIJbGFuZG1hcmtz'
'iAEBQgYKBF9ib3hCDAoKX2xhbmRtYXJrcw==');
@$core.Deprecated('Use faceDescriptor instead')
const Face$json = {
'1': 'Face',
'2': [
{'1': 'id', '3': 1, '4': 1, '5': 9, '9': 0, '10': 'id', '17': true},
{'1': 'detection', '3': 2, '4': 1, '5': 11, '6': '.ente.ml.Detection', '9': 1, '10': 'detection', '17': true},
{'1': 'confidence', '3': 3, '4': 1, '5': 2, '9': 2, '10': 'confidence', '17': true},
],
'8': [
{'1': '_id'},
{'1': '_detection'},
{'1': '_confidence'},
],
};
/// Descriptor for `Face`. Decode as a `google.protobuf.DescriptorProto`.
final $typed_data.Uint8List faceDescriptor = $convert.base64Decode(
'CgRGYWNlEhMKAmlkGAEgASgJSABSAmlkiAEBEjUKCWRldGVjdGlvbhgCIAEoCzISLmVudGUubW'
'wuRGV0ZWN0aW9uSAFSCWRldGVjdGlvbogBARIjCgpjb25maWRlbmNlGAMgASgCSAJSCmNvbmZp'
'ZGVuY2WIAQFCBQoDX2lkQgwKCl9kZXRlY3Rpb25CDQoLX2NvbmZpZGVuY2U=');

View file

@ -0,0 +1,14 @@
//
// Generated code. Do not modify.
// source: ente/ml/face.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names
// ignore_for_file: deprecated_member_use_from_same_package, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import
export 'face.pb.dart';

View file

@ -0,0 +1,179 @@
//
// Generated code. Do not modify.
// source: ente/ml/fileml.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import
import 'dart:core' as $core;
import 'package:fixnum/fixnum.dart' as $fixnum;
import 'package:protobuf/protobuf.dart' as $pb;
import 'face.pb.dart' as $2;
class FileML extends $pb.GeneratedMessage {
factory FileML({
$fixnum.Int64? id,
$core.Iterable<$core.double>? clip,
}) {
final $result = create();
if (id != null) {
$result.id = id;
}
if (clip != null) {
$result.clip.addAll(clip);
}
return $result;
}
FileML._() : super();
factory FileML.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r);
factory FileML.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r);
static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'FileML', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.ml'), createEmptyInstance: create)
..aInt64(1, _omitFieldNames ? '' : 'id')
..p<$core.double>(2, _omitFieldNames ? '' : 'clip', $pb.PbFieldType.KD)
..hasRequiredFields = false
;
@$core.Deprecated(
'Using this can add significant overhead to your binary. '
'Use [GeneratedMessageGenericExtensions.deepCopy] instead. '
'Will be removed in next major version')
FileML clone() => FileML()..mergeFromMessage(this);
@$core.Deprecated(
'Using this can add significant overhead to your binary. '
'Use [GeneratedMessageGenericExtensions.rebuild] instead. '
'Will be removed in next major version')
FileML copyWith(void Function(FileML) updates) => super.copyWith((message) => updates(message as FileML)) as FileML;
$pb.BuilderInfo get info_ => _i;
@$core.pragma('dart2js:noInline')
static FileML create() => FileML._();
FileML createEmptyInstance() => create();
static $pb.PbList<FileML> createRepeated() => $pb.PbList<FileML>();
@$core.pragma('dart2js:noInline')
static FileML getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor<FileML>(create);
static FileML? _defaultInstance;
@$pb.TagNumber(1)
$fixnum.Int64 get id => $_getI64(0);
@$pb.TagNumber(1)
set id($fixnum.Int64 v) { $_setInt64(0, v); }
@$pb.TagNumber(1)
$core.bool hasId() => $_has(0);
@$pb.TagNumber(1)
void clearId() => clearField(1);
@$pb.TagNumber(2)
$core.List<$core.double> get clip => $_getList(1);
}
class FileFaces extends $pb.GeneratedMessage {
factory FileFaces({
$core.Iterable<$2.Face>? faces,
$core.int? height,
$core.int? width,
$core.int? version,
$core.String? error,
}) {
final $result = create();
if (faces != null) {
$result.faces.addAll(faces);
}
if (height != null) {
$result.height = height;
}
if (width != null) {
$result.width = width;
}
if (version != null) {
$result.version = version;
}
if (error != null) {
$result.error = error;
}
return $result;
}
FileFaces._() : super();
factory FileFaces.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r);
factory FileFaces.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r);
static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'FileFaces', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.ml'), createEmptyInstance: create)
..pc<$2.Face>(1, _omitFieldNames ? '' : 'faces', $pb.PbFieldType.PM, subBuilder: $2.Face.create)
..a<$core.int>(2, _omitFieldNames ? '' : 'height', $pb.PbFieldType.O3)
..a<$core.int>(3, _omitFieldNames ? '' : 'width', $pb.PbFieldType.O3)
..a<$core.int>(4, _omitFieldNames ? '' : 'version', $pb.PbFieldType.O3)
..aOS(5, _omitFieldNames ? '' : 'error')
..hasRequiredFields = false
;
@$core.Deprecated(
'Using this can add significant overhead to your binary. '
'Use [GeneratedMessageGenericExtensions.deepCopy] instead. '
'Will be removed in next major version')
FileFaces clone() => FileFaces()..mergeFromMessage(this);
@$core.Deprecated(
'Using this can add significant overhead to your binary. '
'Use [GeneratedMessageGenericExtensions.rebuild] instead. '
'Will be removed in next major version')
FileFaces copyWith(void Function(FileFaces) updates) => super.copyWith((message) => updates(message as FileFaces)) as FileFaces;
$pb.BuilderInfo get info_ => _i;
@$core.pragma('dart2js:noInline')
static FileFaces create() => FileFaces._();
FileFaces createEmptyInstance() => create();
static $pb.PbList<FileFaces> createRepeated() => $pb.PbList<FileFaces>();
@$core.pragma('dart2js:noInline')
static FileFaces getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor<FileFaces>(create);
static FileFaces? _defaultInstance;
@$pb.TagNumber(1)
$core.List<$2.Face> get faces => $_getList(0);
@$pb.TagNumber(2)
$core.int get height => $_getIZ(1);
@$pb.TagNumber(2)
set height($core.int v) { $_setSignedInt32(1, v); }
@$pb.TagNumber(2)
$core.bool hasHeight() => $_has(1);
@$pb.TagNumber(2)
void clearHeight() => clearField(2);
@$pb.TagNumber(3)
$core.int get width => $_getIZ(2);
@$pb.TagNumber(3)
set width($core.int v) { $_setSignedInt32(2, v); }
@$pb.TagNumber(3)
$core.bool hasWidth() => $_has(2);
@$pb.TagNumber(3)
void clearWidth() => clearField(3);
@$pb.TagNumber(4)
$core.int get version => $_getIZ(3);
@$pb.TagNumber(4)
set version($core.int v) { $_setSignedInt32(3, v); }
@$pb.TagNumber(4)
$core.bool hasVersion() => $_has(3);
@$pb.TagNumber(4)
void clearVersion() => clearField(4);
@$pb.TagNumber(5)
$core.String get error => $_getSZ(4);
@$pb.TagNumber(5)
set error($core.String v) { $_setString(4, v); }
@$pb.TagNumber(5)
$core.bool hasError() => $_has(4);
@$pb.TagNumber(5)
void clearError() => clearField(5);
}
const _omitFieldNames = $core.bool.fromEnvironment('protobuf.omit_field_names');
const _omitMessageNames = $core.bool.fromEnvironment('protobuf.omit_message_names');

View file

@ -0,0 +1,11 @@
//
// Generated code. Do not modify.
// source: ente/ml/fileml.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import

View file

@ -0,0 +1,57 @@
//
// Generated code. Do not modify.
// source: ente/ml/fileml.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import
import 'dart:convert' as $convert;
import 'dart:core' as $core;
import 'dart:typed_data' as $typed_data;
@$core.Deprecated('Use fileMLDescriptor instead')
const FileML$json = {
'1': 'FileML',
'2': [
{'1': 'id', '3': 1, '4': 1, '5': 3, '9': 0, '10': 'id', '17': true},
{'1': 'clip', '3': 2, '4': 3, '5': 1, '10': 'clip'},
],
'8': [
{'1': '_id'},
],
};
/// Descriptor for `FileML`. Decode as a `google.protobuf.DescriptorProto`.
final $typed_data.Uint8List fileMLDescriptor = $convert.base64Decode(
'CgZGaWxlTUwSEwoCaWQYASABKANIAFICaWSIAQESEgoEY2xpcBgCIAMoAVIEY2xpcEIFCgNfaW'
'Q=');
@$core.Deprecated('Use fileFacesDescriptor instead')
const FileFaces$json = {
'1': 'FileFaces',
'2': [
{'1': 'faces', '3': 1, '4': 3, '5': 11, '6': '.ente.ml.Face', '10': 'faces'},
{'1': 'height', '3': 2, '4': 1, '5': 5, '9': 0, '10': 'height', '17': true},
{'1': 'width', '3': 3, '4': 1, '5': 5, '9': 1, '10': 'width', '17': true},
{'1': 'version', '3': 4, '4': 1, '5': 5, '9': 2, '10': 'version', '17': true},
{'1': 'error', '3': 5, '4': 1, '5': 9, '9': 3, '10': 'error', '17': true},
],
'8': [
{'1': '_height'},
{'1': '_width'},
{'1': '_version'},
{'1': '_error'},
],
};
/// Descriptor for `FileFaces`. Decode as a `google.protobuf.DescriptorProto`.
final $typed_data.Uint8List fileFacesDescriptor = $convert.base64Decode(
'CglGaWxlRmFjZXMSIwoFZmFjZXMYASADKAsyDS5lbnRlLm1sLkZhY2VSBWZhY2VzEhsKBmhlaW'
'dodBgCIAEoBUgAUgZoZWlnaHSIAQESGQoFd2lkdGgYAyABKAVIAVIFd2lkdGiIAQESHQoHdmVy'
'c2lvbhgEIAEoBUgCUgd2ZXJzaW9uiAEBEhkKBWVycm9yGAUgASgJSANSBWVycm9yiAEBQgkKB1'
'9oZWlnaHRCCAoGX3dpZHRoQgoKCF92ZXJzaW9uQggKBl9lcnJvcg==');

View file

@ -0,0 +1,14 @@
//
// Generated code. Do not modify.
// source: ente/ml/fileml.proto
//
// @dart = 2.12
// ignore_for_file: annotate_overrides, camel_case_types, comment_references
// ignore_for_file: constant_identifier_names
// ignore_for_file: deprecated_member_use_from_same_package, library_prefixes
// ignore_for_file: non_constant_identifier_names, prefer_final_fields
// ignore_for_file: unnecessary_import, unnecessary_this, unused_import
export 'fileml.pb.dart';

View file

@ -1170,6 +1170,7 @@
}
},
"faces": "Faces",
"people": "People",
"contents": "Contents",
"addNew": "Add new",
"@addNew": {

View file

@ -25,6 +25,7 @@ import 'package:photos/services/app_lifecycle_service.dart';
import 'package:photos/services/billing_service.dart';
import 'package:photos/services/collections_service.dart';
import "package:photos/services/entity_service.dart";
import "package:photos/services/face_ml/face_ml_service.dart";
import 'package:photos/services/favorites_service.dart';
import 'package:photos/services/feature_flag_service.dart';
import 'package:photos/services/local_file_update_service.dart';
@ -242,9 +243,11 @@ Future<void> _init(bool isBackground, {String via = ''}) async {
// Can not including existing tf/ml binaries as they are not being built
// from source.
// See https://gitlab.com/fdroid/fdroiddata/-/merge_requests/12671#note_1294346819
// if (!UpdateService.instance.isFdroidFlavor()) {
// unawaited(ObjectDetectionService.instance.init());
// }
if (!UpdateService.instance.isFdroidFlavor()) {
// unawaited(ObjectDetectionService.instance.init());
unawaited(FaceMlService.instance.init());
FaceMlService.instance.listenIndexOnDiffSync();
}
_logger.info("Initialization done");
}

View file

@ -18,6 +18,8 @@ enum GalleryType {
searchResults,
locationTag,
quickLink,
peopleTag,
cluster,
}
extension GalleyTypeExtension on GalleryType {
@ -32,12 +34,14 @@ extension GalleyTypeExtension on GalleryType {
case GalleryType.locationTag:
case GalleryType.quickLink:
case GalleryType.uncategorized:
case GalleryType.peopleTag:
return true;
case GalleryType.hiddenSection:
case GalleryType.hiddenOwnedCollection:
case GalleryType.trash:
case GalleryType.sharedCollection:
case GalleryType.cluster:
return false;
}
}
@ -50,6 +54,7 @@ extension GalleyTypeExtension on GalleryType {
return true;
case GalleryType.hiddenSection:
case GalleryType.peopleTag:
case GalleryType.hiddenOwnedCollection:
case GalleryType.favorite:
case GalleryType.searchResults:
@ -59,6 +64,7 @@ extension GalleyTypeExtension on GalleryType {
case GalleryType.trash:
case GalleryType.sharedCollection:
case GalleryType.locationTag:
case GalleryType.cluster:
return false;
}
}
@ -75,12 +81,14 @@ extension GalleyTypeExtension on GalleryType {
case GalleryType.uncategorized:
case GalleryType.locationTag:
case GalleryType.quickLink:
case GalleryType.peopleTag:
return true;
case GalleryType.trash:
case GalleryType.archive:
case GalleryType.hiddenSection:
case GalleryType.hiddenOwnedCollection:
case GalleryType.sharedCollection:
case GalleryType.cluster:
return false;
}
}
@ -98,8 +106,10 @@ extension GalleyTypeExtension on GalleryType {
case GalleryType.localFolder:
case GalleryType.locationTag:
case GalleryType.quickLink:
case GalleryType.peopleTag:
return true;
case GalleryType.trash:
case GalleryType.cluster:
case GalleryType.sharedCollection:
return false;
}
@ -114,8 +124,10 @@ extension GalleyTypeExtension on GalleryType {
case GalleryType.archive:
case GalleryType.uncategorized:
case GalleryType.locationTag:
case GalleryType.peopleTag:
return true;
case GalleryType.hiddenSection:
case GalleryType.cluster:
case GalleryType.hiddenOwnedCollection:
case GalleryType.localFolder:
case GalleryType.trash:
@ -132,6 +144,7 @@ extension GalleyTypeExtension on GalleryType {
case GalleryType.quickLink:
return true;
case GalleryType.hiddenSection:
case GalleryType.peopleTag:
case GalleryType.hiddenOwnedCollection:
case GalleryType.uncategorized:
case GalleryType.favorite:
@ -139,6 +152,7 @@ extension GalleyTypeExtension on GalleryType {
case GalleryType.homepage:
case GalleryType.archive:
case GalleryType.localFolder:
case GalleryType.cluster:
case GalleryType.trash:
case GalleryType.locationTag:
return false;
@ -154,6 +168,7 @@ extension GalleyTypeExtension on GalleryType {
return true;
case GalleryType.hiddenSection:
case GalleryType.peopleTag:
case GalleryType.hiddenOwnedCollection:
case GalleryType.favorite:
case GalleryType.searchResults:
@ -162,6 +177,7 @@ extension GalleyTypeExtension on GalleryType {
case GalleryType.trash:
case GalleryType.sharedCollection:
case GalleryType.locationTag:
case GalleryType.cluster:
return false;
}
}
@ -182,10 +198,12 @@ extension GalleyTypeExtension on GalleryType {
return true;
case GalleryType.hiddenSection:
case GalleryType.peopleTag:
case GalleryType.hiddenOwnedCollection:
case GalleryType.localFolder:
case GalleryType.trash:
case GalleryType.favorite:
case GalleryType.cluster:
case GalleryType.sharedCollection:
return false;
}
@ -203,12 +221,14 @@ extension GalleyTypeExtension on GalleryType {
case GalleryType.searchResults:
case GalleryType.uncategorized:
case GalleryType.locationTag:
case GalleryType.peopleTag:
return true;
case GalleryType.hiddenSection:
case GalleryType.hiddenOwnedCollection:
case GalleryType.quickLink:
case GalleryType.favorite:
case GalleryType.cluster:
case GalleryType.archive:
case GalleryType.localFolder:
case GalleryType.trash:
@ -244,7 +264,7 @@ extension GalleyTypeExtension on GalleryType {
}
bool showEditLocation() {
return this != GalleryType.sharedCollection;
return this != GalleryType.sharedCollection && this != GalleryType.cluster;
}
}
@ -334,7 +354,9 @@ extension GalleryAppBarExtn on GalleryType {
case GalleryType.locationTag:
case GalleryType.searchResults:
return false;
case GalleryType.cluster:
case GalleryType.uncategorized:
case GalleryType.peopleTag:
case GalleryType.ownedCollection:
case GalleryType.sharedCollection:
case GalleryType.quickLink:

View file

@ -0,0 +1,7 @@
typedef Embedding = List<double>;
typedef Num3DInputMatrix = List<List<List<num>>>;
typedef Int3DInputMatrix = List<List<List<int>>>;
typedef Double3DInputMatrix = List<List<List<double>>>;

View file

@ -0,0 +1,3 @@
const faceMlVersion = 1;
const clusterMlVersion = 1;
const minimumClusterSize = 2;

View file

@ -8,8 +8,15 @@ class GenericSearchResult extends SearchResult {
final List<EnteFile> _files;
final ResultType _type;
final Function(BuildContext context)? onResultTap;
final Map<String, dynamic> params;
GenericSearchResult(this._type, this._name, this._files, {this.onResultTap});
GenericSearchResult(
this._type,
this._name,
this._files, {
this.onResultTap,
this.params = const {},
});
@override
String name() {

View file

@ -0,0 +1,3 @@
const kPersonParamID = 'person_id';
const kClusterParamId = 'cluster_id';
const kFileID = 'file_id';

View file

@ -33,6 +33,7 @@ enum ResultType {
fileCaption,
event,
shared,
faces,
magic,
}
@ -55,7 +56,7 @@ extension SectionTypeExtensions on SectionType {
String sectionTitle(BuildContext context) {
switch (this) {
case SectionType.face:
return S.of(context).faces;
return S.of(context).people;
case SectionType.content:
return S.of(context).contents;
case SectionType.moment:
@ -99,7 +100,7 @@ extension SectionTypeExtensions on SectionType {
bool get isCTAVisible {
switch (this) {
case SectionType.face:
return false;
return true;
case SectionType.content:
return false;
case SectionType.moment:
@ -117,6 +118,8 @@ extension SectionTypeExtensions on SectionType {
}
}
bool get sortByName => this != SectionType.face;
bool get isEmptyCTAVisible {
switch (this) {
case SectionType.face:
@ -245,8 +248,7 @@ extension SectionTypeExtensions on SectionType {
}) {
switch (this) {
case SectionType.face:
return Future.value(List<GenericSearchResult>.empty());
return SearchService.instance.getAllFace(limit);
case SectionType.content:
return Future.value(List<GenericSearchResult>.empty());

View file

@ -0,0 +1,2 @@
const kLaplacianThreshold = 10;
const kLapacianDefault = 10000.0;

View file

@ -0,0 +1,115 @@
import 'package:logging/logging.dart';
import "package:photos/services/face_ml/blur_detection/blur_constants.dart";
class BlurDetectionService {
final _logger = Logger('BlurDetectionService');
// singleton pattern
BlurDetectionService._privateConstructor();
static final instance = BlurDetectionService._privateConstructor();
factory BlurDetectionService() => instance;
Future<(bool, double)> predictIsBlurGrayLaplacian(
List<List<int>> grayImage, {
int threshold = kLaplacianThreshold,
}) async {
final List<List<int>> laplacian = _applyLaplacian(grayImage);
final double variance = _calculateVariance(laplacian);
_logger.info('Variance: $variance');
return (variance < threshold, variance);
}
double _calculateVariance(List<List<int>> matrix) {
final int numRows = matrix.length;
final int numCols = matrix[0].length;
final int totalElements = numRows * numCols;
// Calculate the mean
double mean = 0;
for (var row in matrix) {
for (var value in row) {
mean += value;
}
}
mean /= totalElements;
// Calculate the variance
double variance = 0;
for (var row in matrix) {
for (var value in row) {
final double diff = value - mean;
variance += diff * diff;
}
}
variance /= totalElements;
return variance;
}
List<List<int>> _padImage(List<List<int>> image) {
final int numRows = image.length;
final int numCols = image[0].length;
// Create a new matrix with extra padding
final List<List<int>> paddedImage = List.generate(
numRows + 2,
(i) => List.generate(numCols + 2, (j) => 0, growable: false),
growable: false,
);
// Copy original image into the center of the padded image
for (int i = 0; i < numRows; i++) {
for (int j = 0; j < numCols; j++) {
paddedImage[i + 1][j + 1] = image[i][j];
}
}
// Reflect padding
// Top and bottom rows
for (int j = 1; j <= numCols; j++) {
paddedImage[0][j] = paddedImage[2][j]; // Top row
paddedImage[numRows + 1][j] = paddedImage[numRows - 1][j]; // Bottom row
}
// Left and right columns
for (int i = 0; i < numRows + 2; i++) {
paddedImage[i][0] = paddedImage[i][2]; // Left column
paddedImage[i][numCols + 1] = paddedImage[i][numCols - 1]; // Right column
}
return paddedImage;
}
List<List<int>> _applyLaplacian(List<List<int>> image) {
final List<List<int>> paddedImage = _padImage(image);
final int numRows = image.length;
final int numCols = image[0].length;
final List<List<int>> outputImage = List.generate(
numRows,
(i) => List.generate(numCols, (j) => 0, growable: false),
growable: false,
);
// Define the Laplacian kernel
final List<List<int>> kernel = [
[0, 1, 0],
[1, -4, 1],
[0, 1, 0],
];
// Apply the kernel to each pixel
for (int i = 0; i < numRows; i++) {
for (int j = 0; j < numCols; j++) {
int sum = 0;
for (int ki = 0; ki < 3; ki++) {
for (int kj = 0; kj < 3; kj++) {
sum += paddedImage[i + ki][j + kj] * kernel[ki][kj];
}
}
// Adjust the output value if necessary (e.g., clipping)
outputImage[i][j] = sum; //.clamp(0, 255);
}
}
return outputImage;
}
}

View file

@ -0,0 +1,36 @@
class AlignmentResult {
final List<List<double>> affineMatrix; // 3x3
final List<double> center; // [x, y]
final double size; // 1 / scale
final double rotation; // atan2(simRotation[1][0], simRotation[0][0]);
AlignmentResult({required this.affineMatrix, required this.center, required this.size, required this.rotation});
AlignmentResult.empty()
: affineMatrix = <List<double>>[
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
],
center = <double>[0, 0],
size = 1,
rotation = 0;
factory AlignmentResult.fromJson(Map<String, dynamic> json) {
return AlignmentResult(
affineMatrix: (json['affineMatrix'] as List)
.map((item) => List<double>.from(item))
.toList(),
center: List<double>.from(json['center'] as List),
size: json['size'] as double,
rotation: json['rotation'] as double,
);
}
Map<String, dynamic> toJson() => {
'affineMatrix': affineMatrix,
'center': center,
'size': size,
'rotation': rotation,
};
}

View file

@ -0,0 +1,171 @@
import 'dart:math' show atan2;
import 'package:ml_linalg/linalg.dart';
import 'package:photos/extensions/ml_linalg_extensions.dart';
import "package:photos/services/face_ml/face_alignment/alignment_result.dart";
/// Class to compute the similarity transform between two sets of points.
///
/// The class estimates the parameters of the similarity transformation via the `estimate` function.
/// After estimation, the transformation can be applied to an image using the `warpAffine` function.
class SimilarityTransform {
Matrix _params = Matrix.fromList([
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0, 0, 1],
]);
List<double> _center = <double>[0, 0]; // [x, y]
double _size = 1; // 1 / scale
double _rotation = 0; // atan2(simRotation[1][0], simRotation[0][0]);
final arcface4Landmarks = [
<double>[38.2946, 51.6963],
<double>[73.5318, 51.5014],
<double>[56.0252, 71.7366],
<double>[56.1396, 92.2848],
];
final arcface5Landmarks = [
<double>[38.2946, 51.6963],
<double>[73.5318, 51.5014],
<double>[56.0252, 71.7366],
<double>[41.5493, 92.3655],
<double>[70.7299, 92.2041],
];
get arcfaceNormalized4 => arcface4Landmarks
.map((list) => list.map((value) => value / 112.0).toList())
.toList();
get arcfaceNormalized5 => arcface5Landmarks
.map((list) => list.map((value) => value / 112.0).toList())
.toList();
List<List<double>> get paramsList => _params.to2DList();
// singleton pattern
SimilarityTransform._privateConstructor();
static final instance = SimilarityTransform._privateConstructor();
factory SimilarityTransform() => instance;
void _cleanParams() {
_params = Matrix.fromList([
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0, 0, 1],
]);
_center = <double>[0, 0];
_size = 1;
_rotation = 0;
}
/// Function to estimate the parameters of the affine transformation. These parameters are stored in the class variable params.
///
/// Returns a tuple of (AlignmentResult, bool). The bool indicates whether the parameters are valid or not.
///
/// Runs efficiently in about 1-3 ms after initial warm-up.
///
/// It takes the source and destination points as input and returns the
/// parameters of the affine transformation as output. The function
/// returns false if the parameters cannot be estimated. The function
/// estimates the parameters by solving a least-squares problem using
/// the Umeyama algorithm, via [_umeyama].
(AlignmentResult, bool) estimate(List<List<double>> src) {
_cleanParams();
final (params, center, size, rotation) =
_umeyama(src, arcfaceNormalized5, true);
_params = params;
_center = center;
_size = size;
_rotation = rotation;
final alignmentResult = AlignmentResult(
affineMatrix: paramsList,
center: _center,
size: _size,
rotation: _rotation,
);
// We check for NaN in the transformation matrix params.
final isNoNanInParam =
!_params.asFlattenedList.any((element) => element.isNaN);
return (alignmentResult, isNoNanInParam);
}
static (Matrix, List<double>, double, double) _umeyama(
List<List<double>> src,
List<List<double>> dst, [
bool estimateScale = true,
]) {
final srcMat = Matrix.fromList(
src,
// .map((list) => list.map((value) => value.toDouble()).toList())
// .toList(),
);
final dstMat = Matrix.fromList(dst);
final num = srcMat.rowCount;
final dim = srcMat.columnCount;
// Compute mean of src and dst.
final srcMean = srcMat.mean(Axis.columns);
final dstMean = dstMat.mean(Axis.columns);
// Subtract mean from src and dst.
final srcDemean = srcMat.mapRows((vector) => vector - srcMean);
final dstDemean = dstMat.mapRows((vector) => vector - dstMean);
// Eq. (38).
final A = (dstDemean.transpose() * srcDemean) / num;
// Eq. (39).
var d = Vector.filled(dim, 1.0);
if (A.determinant() < 0) {
d = d.set(dim - 1, -1);
}
var T = Matrix.identity(dim + 1);
final svdResult = A.svd();
final Matrix U = svdResult['U']!;
final Vector S = svdResult['S']!;
final Matrix V = svdResult['V']!;
// Eq. (40) and (43).
final rank = A.matrixRank();
if (rank == 0) {
return (T * double.nan, <double>[0, 0], 1, 0);
} else if (rank == dim - 1) {
if (U.determinant() * V.determinant() > 0) {
T = T.setSubMatrix(0, dim, 0, dim, U * V);
} else {
final s = d[dim - 1];
d = d.set(dim - 1, -1);
final replacement = U * Matrix.diagonal(d.toList()) * V;
T = T.setSubMatrix(0, dim, 0, dim, replacement);
d = d.set(dim - 1, s);
}
} else {
final replacement = U * Matrix.diagonal(d.toList()) * V;
T = T.setSubMatrix(0, dim, 0, dim, replacement);
}
final Matrix simRotation = U * Matrix.diagonal(d.toList()) * V;
var scale = 1.0;
if (estimateScale) {
// Eq. (41) and (42).
scale = 1.0 / srcDemean.variance(Axis.columns).sum() * (S * d).sum();
}
final subTIndices = Iterable<int>.generate(dim, (index) => index);
final subT = T.sample(rowIndices: subTIndices, columnIndices: subTIndices);
final newSubT = dstMean - (subT * srcMean) * scale;
T = T.setValues(0, dim, dim, dim + 1, newSubT);
final newNewSubT =
T.sample(rowIndices: subTIndices, columnIndices: subTIndices) * scale;
T = T.setSubMatrix(0, dim, 0, dim, newNewSubT);
// final List<double> translation = [T[0][2], T[1][2]];
// final simRotation = replacement?;
final size = 1 / scale;
final rotation = atan2(simRotation[1][0], simRotation[0][0]);
final meanTranslation = (dstMean - 0.5) * size;
final centerMat = srcMean - meanTranslation;
final List<double> center = [centerMat[0], centerMat[1]];
return (T, center, size, rotation);
}
}

View file

@ -0,0 +1,55 @@
import 'dart:math' show sqrt;
/// Calculates the cosine distance between two embeddings/vectors.
///
/// Throws an ArgumentError if the vectors are of different lengths or
/// if either of the vectors has a magnitude of zero.
double cosineDistance(List<double> vector1, List<double> vector2) {
if (vector1.length != vector2.length) {
throw ArgumentError('Vectors must be the same length');
}
double dotProduct = 0.0;
double magnitude1 = 0.0;
double magnitude2 = 0.0;
for (int i = 0; i < vector1.length; i++) {
dotProduct += vector1[i] * vector2[i];
magnitude1 += vector1[i] * vector1[i];
magnitude2 += vector2[i] * vector2[i];
}
magnitude1 = sqrt(magnitude1);
magnitude2 = sqrt(magnitude2);
// Avoid division by zero. This should never happen. If it does, then one of the vectors contains only zeros.
if (magnitude1 == 0 || magnitude2 == 0) {
throw ArgumentError('Vectors must not have a magnitude of zero');
}
final double similarity = dotProduct / (magnitude1 * magnitude2);
// Cosine distance is the complement of cosine similarity
return 1.0 - similarity;
}
// cosineDistForNormVectors calculates the cosine distance between two normalized embeddings/vectors.
@pragma('vm:entry-point')
double cosineDistForNormVectors(List<double> vector1, List<double> vector2) {
if (vector1.length != vector2.length) {
throw ArgumentError('Vectors must be the same length');
}
double dotProduct = 0.0;
for (int i = 0; i < vector1.length; i++) {
dotProduct += vector1[i] * vector2[i];
}
return 1.0 - dotProduct;
}
double calculateSqrDistance(List<double> v1, List<double> v2) {
double sum = 0;
for (int i = 0; i < v1.length; i++) {
sum += (v1[i] - v2[i]) * (v1[i] - v2[i]);
}
return sqrt(sum);
}

View file

@ -0,0 +1,405 @@
import "dart:async";
import "dart:developer";
import "dart:isolate";
import "dart:math" show max;
import "dart:typed_data";
import "package:logging/logging.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import "package:photos/services/face_ml/face_clustering/cosine_distance.dart";
import "package:synchronized/synchronized.dart";
class FaceInfo {
final String faceID;
final List<double> embedding;
int? clusterId;
String? closestFaceId;
int? closestDist;
FaceInfo({
required this.faceID,
required this.embedding,
this.clusterId,
});
}
enum ClusterOperation { linearIncrementalClustering }
class FaceLinearClustering {
final _logger = Logger("FaceLinearClustering");
Timer? _inactivityTimer;
final Duration _inactivityDuration = const Duration(seconds: 30);
int _activeTasks = 0;
final _initLock = Lock();
late Isolate _isolate;
late ReceivePort _receivePort = ReceivePort();
late SendPort _mainSendPort;
bool isSpawned = false;
bool isRunning = false;
static const recommendedDistanceThreshold = 0.3;
// singleton pattern
FaceLinearClustering._privateConstructor();
/// Use this instance to access the FaceClustering service.
/// e.g. `FaceLinearClustering.instance.predict(dataset)`
static final instance = FaceLinearClustering._privateConstructor();
factory FaceLinearClustering() => instance;
Future<void> init() async {
return _initLock.synchronized(() async {
if (isSpawned) return;
_receivePort = ReceivePort();
try {
_isolate = await Isolate.spawn(
_isolateMain,
_receivePort.sendPort,
);
_mainSendPort = await _receivePort.first as SendPort;
isSpawned = true;
_resetInactivityTimer();
} catch (e) {
_logger.severe('Could not spawn isolate', e);
isSpawned = false;
}
});
}
Future<void> ensureSpawned() async {
if (!isSpawned) {
await init();
}
}
/// The main execution function of the isolate.
static void _isolateMain(SendPort mainSendPort) async {
final receivePort = ReceivePort();
mainSendPort.send(receivePort.sendPort);
receivePort.listen((message) async {
final functionIndex = message[0] as int;
final function = ClusterOperation.values[functionIndex];
final args = message[1] as Map<String, dynamic>;
final sendPort = message[2] as SendPort;
try {
switch (function) {
case ClusterOperation.linearIncrementalClustering:
final input = args['input'] as Map<String, (int?, Uint8List)>;
final result = FaceLinearClustering._runLinearClustering(input);
sendPort.send(result);
break;
}
} catch (e, stackTrace) {
sendPort
.send({'error': e.toString(), 'stackTrace': stackTrace.toString()});
}
});
}
/// The common method to run any operation in the isolate. It sends the [message] to [_isolateMain] and waits for the result.
Future<dynamic> _runInIsolate(
(ClusterOperation, Map<String, dynamic>) message,
) async {
await ensureSpawned();
_resetInactivityTimer();
final completer = Completer<dynamic>();
final answerPort = ReceivePort();
_activeTasks++;
_mainSendPort.send([message.$1.index, message.$2, answerPort.sendPort]);
answerPort.listen((receivedMessage) {
if (receivedMessage is Map && receivedMessage.containsKey('error')) {
// Handle the error
final errorMessage = receivedMessage['error'];
final errorStackTrace = receivedMessage['stackTrace'];
final exception = Exception(errorMessage);
final stackTrace = StackTrace.fromString(errorStackTrace);
completer.completeError(exception, stackTrace);
} else {
completer.complete(receivedMessage);
}
});
_activeTasks--;
return completer.future;
}
/// Resets a timer that kills the isolate after a certain amount of inactivity.
///
/// Should be called after initialization (e.g. inside `init()`) and after every call to isolate (e.g. inside `_runInIsolate()`)
void _resetInactivityTimer() {
_inactivityTimer?.cancel();
_inactivityTimer = Timer(_inactivityDuration, () {
if (_activeTasks > 0) {
_logger.info('Tasks are still running. Delaying isolate disposal.');
// Optionally, reschedule the timer to check again later.
_resetInactivityTimer();
} else {
_logger.info(
'Clustering Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.',
);
dispose();
}
});
}
/// Disposes the isolate worker.
void dispose() {
if (!isSpawned) return;
isSpawned = false;
_isolate.kill();
_receivePort.close();
_inactivityTimer?.cancel();
}
/// Runs the clustering algorithm on the given [input], in an isolate.
///
/// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset.
///
/// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can less less deterministic.
Future<Map<String, int>?> predict(
Map<String, (int?, Uint8List)> input,
) async {
if (input.isEmpty) {
_logger.warning(
"Clustering dataset of embeddings is empty, returning empty list.",
);
return null;
}
if (isRunning) {
_logger.warning("Clustering is already running, returning empty list.");
return null;
}
isRunning = true;
// Clustering inside the isolate
_logger.info(
"Start clustering on ${input.length} embeddings inside computer isolate",
);
final stopwatchClustering = Stopwatch()..start();
// final Map<String, int> faceIdToCluster =
// await _runLinearClusteringInComputer(input);
final Map<String, int> faceIdToCluster = await _runInIsolate(
(ClusterOperation.linearIncrementalClustering, {'input': input}),
);
// return _runLinearClusteringInComputer(input);
_logger.info(
'Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds',
);
isRunning = false;
return faceIdToCluster;
}
static Map<String, int> _runLinearClustering(
Map<String, (int?, Uint8List)> x,
) {
log(
"[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces",
);
final List<FaceInfo> faceInfos = [];
for (final entry in x.entries) {
faceInfos.add(
FaceInfo(
faceID: entry.key,
embedding: EVector.fromBuffer(entry.value.$2).values,
clusterId: entry.value.$1,
),
);
}
// Sort the faceInfos such that the ones with null clusterId are at the end
faceInfos.sort((a, b) {
if (a.clusterId == null && b.clusterId == null) {
return 0;
} else if (a.clusterId == null) {
return 1;
} else if (b.clusterId == null) {
return -1;
} else {
return 0;
}
});
// Count the amount of null values at the end
int nullCount = 0;
for (final faceInfo in faceInfos.reversed) {
if (faceInfo.clusterId == null) {
nullCount++;
} else {
break;
}
}
log(
"[ClusterIsolate] ${DateTime.now()} Clustering $nullCount new faces without clusterId, and ${faceInfos.length - nullCount} faces with clusterId",
);
for (final clusteredFaceInfo
in faceInfos.sublist(0, faceInfos.length - nullCount)) {
assert(clusteredFaceInfo.clusterId != null);
}
final int totalFaces = faceInfos.length;
int clusterID = 1;
if (faceInfos.isNotEmpty) {
faceInfos.first.clusterId = clusterID;
}
log(
"[ClusterIsolate] ${DateTime.now()} Processing $totalFaces faces",
);
final stopwatchClustering = Stopwatch()..start();
for (int i = 1; i < totalFaces; i++) {
// Incremental clustering, so we can skip faces that already have a clusterId
if (faceInfos[i].clusterId != null) {
clusterID = max(clusterID, faceInfos[i].clusterId!);
continue;
}
final currentEmbedding = faceInfos[i].embedding;
int closestIdx = -1;
double closestDistance = double.infinity;
if (i % 250 == 0) {
log("[ClusterIsolate] ${DateTime.now()} Processing $i faces");
}
for (int j = 0; j < i; j++) {
final double distance = cosineDistForNormVectors(
currentEmbedding,
faceInfos[j].embedding,
);
if (distance < closestDistance) {
closestDistance = distance;
closestIdx = j;
}
}
if (closestDistance < recommendedDistanceThreshold) {
if (faceInfos[closestIdx].clusterId == null) {
// Ideally this should never happen, but just in case log it
log(
" [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID",
);
clusterID++;
faceInfos[closestIdx].clusterId = clusterID;
}
faceInfos[i].clusterId = faceInfos[closestIdx].clusterId;
} else {
clusterID++;
faceInfos[i].clusterId = clusterID;
}
}
final Map<String, int> result = {};
for (final faceInfo in faceInfos) {
result[faceInfo.faceID] = faceInfo.clusterId!;
}
stopwatchClustering.stop();
log(
' [ClusterIsolate] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings (${faceInfos[0].embedding.length} size) executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID',
);
// return result;
// NOTe: The main clustering logic is done, the following is just filtering and logging
final input = x;
final faceIdToCluster = result;
stopwatchClustering.reset();
stopwatchClustering.start();
final Set<String> newFaceIds = <String>{};
input.forEach((key, value) {
if (value.$1 == null) {
newFaceIds.add(key);
}
});
// Find faceIDs that are part of a cluster which is larger than 5 and are new faceIDs
final Map<int, int> clusterIdToSize = {};
faceIdToCluster.forEach((key, value) {
if (clusterIdToSize.containsKey(value)) {
clusterIdToSize[value] = clusterIdToSize[value]! + 1;
} else {
clusterIdToSize[value] = 1;
}
});
final Map<String, int> faceIdToClusterFiltered = {};
for (final entry in faceIdToCluster.entries) {
if (clusterIdToSize[entry.value]! > 0 && newFaceIds.contains(entry.key)) {
faceIdToClusterFiltered[entry.key] = entry.value;
}
}
// print top 10 cluster ids and their sizes based on the internal cluster id
final clusterIds = faceIdToCluster.values.toSet();
final clusterSizes = clusterIds.map((clusterId) {
return faceIdToCluster.values.where((id) => id == clusterId).length;
}).toList();
clusterSizes.sort();
// find clusters whose size is graeter than 1
int oneClusterCount = 0;
int moreThan5Count = 0;
int moreThan10Count = 0;
int moreThan20Count = 0;
int moreThan50Count = 0;
int moreThan100Count = 0;
// for (int i = 0; i < clusterSizes.length; i++) {
// if (clusterSizes[i] > 100) {
// moreThan100Count++;
// } else if (clusterSizes[i] > 50) {
// moreThan50Count++;
// } else if (clusterSizes[i] > 20) {
// moreThan20Count++;
// } else if (clusterSizes[i] > 10) {
// moreThan10Count++;
// } else if (clusterSizes[i] > 5) {
// moreThan5Count++;
// } else if (clusterSizes[i] == 1) {
// oneClusterCount++;
// }
// }
for (int i = 0; i < clusterSizes.length; i++) {
if (clusterSizes[i] > 100) {
moreThan100Count++;
}
if (clusterSizes[i] > 50) {
moreThan50Count++;
}
if (clusterSizes[i] > 20) {
moreThan20Count++;
}
if (clusterSizes[i] > 10) {
moreThan10Count++;
}
if (clusterSizes[i] > 5) {
moreThan5Count++;
}
if (clusterSizes[i] == 1) {
oneClusterCount++;
}
}
// print the metrics
log(
'[ClusterIsolate] Total clusters ${clusterIds.length}, '
'oneClusterCount $oneClusterCount, '
'moreThan5Count $moreThan5Count, '
'moreThan10Count $moreThan10Count, '
'moreThan20Count $moreThan20Count, '
'moreThan50Count $moreThan50Count, '
'moreThan100Count $moreThan100Count',
);
stopwatchClustering.stop();
log(
"[ClusterIsolate] Clustering additional steps took ${stopwatchClustering.elapsedMilliseconds} ms",
);
// log('Top clusters count ${clusterSizes.reversed.take(10).toList()}');
return faceIdToClusterFiltered;
}
}

View file

@ -0,0 +1,469 @@
import 'dart:convert' show utf8;
import 'dart:math' show sqrt, pow;
import 'dart:ui' show Size;
import 'package:crypto/crypto.dart' show sha256;
abstract class Detection {
final double score;
Detection({required this.score});
const Detection.empty() : score = 0;
get width;
get height;
@override
String toString();
}
extension BBoxExtension on List<double> {
void roundBoxToDouble() {
final widthRounded = (this[2] - this[0]).roundToDouble();
final heightRounded = (this[3] - this[1]).roundToDouble();
this[0] = this[0].roundToDouble();
this[1] = this[1].roundToDouble();
this[2] = this[0] + widthRounded;
this[3] = this[1] + heightRounded;
}
// double get xMinBox =>
// isNotEmpty ? this[0] : throw IndexError.withLength(0, length);
// double get yMinBox =>
// length >= 2 ? this[1] : throw IndexError.withLength(1, length);
// double get xMaxBox =>
// length >= 3 ? this[2] : throw IndexError.withLength(2, length);
// double get yMaxBox =>
// length >= 4 ? this[3] : throw IndexError.withLength(3, length);
}
/// This class represents a face detection with relative coordinates in the range [0, 1].
/// The coordinates are relative to the image size. The pattern for the coordinates is always [x, y], where x is the horizontal coordinate and y is the vertical coordinate.
///
/// The [score] attribute is a double representing the confidence of the face detection.
///
/// The [box] attribute is a list of 4 doubles, representing the coordinates of the bounding box of the face detection.
/// The four values of the box in order are: [xMinBox, yMinBox, xMaxBox, yMaxBox].
///
/// The [allKeypoints] attribute is a list of 6 lists of 2 doubles, representing the coordinates of the keypoints of the face detection.
/// The six lists of two values in order are: [leftEye, rightEye, nose, mouth, leftEar, rightEar]. Again, all in [x, y] order.
class FaceDetectionRelative extends Detection {
final List<double> box;
final List<List<double>> allKeypoints;
double get xMinBox => box[0];
double get yMinBox => box[1];
double get xMaxBox => box[2];
double get yMaxBox => box[3];
List<double> get leftEye => allKeypoints[0];
List<double> get rightEye => allKeypoints[1];
List<double> get nose => allKeypoints[2];
List<double> get leftMouth => allKeypoints[3];
List<double> get rightMouth => allKeypoints[4];
FaceDetectionRelative({
required double score,
required List<double> box,
required List<List<double>> allKeypoints,
}) : assert(
box.every((e) => e >= -0.1 && e <= 1.1),
"Bounding box values must be in the range [0, 1], with only a small margin of error allowed.",
),
assert(
allKeypoints
.every((sublist) => sublist.every((e) => e >= -0.1 && e <= 1.1)),
"All keypoints must be in the range [0, 1], with only a small margin of error allowed.",
),
box = List<double>.from(box.map((e) => e.clamp(0.0, 1.0))),
allKeypoints = allKeypoints
.map(
(sublist) =>
List<double>.from(sublist.map((e) => e.clamp(0.0, 1.0))),
)
.toList(),
super(score: score);
factory FaceDetectionRelative.zero() {
return FaceDetectionRelative(
score: 0,
box: <double>[0, 0, 0, 0],
allKeypoints: <List<double>>[
[0, 0],
[0, 0],
[0, 0],
[0, 0],
[0, 0],
],
);
}
/// This is used to initialize the FaceDetectionRelative object with default values.
/// This constructor is useful because it can be used to initialize a FaceDetectionRelative object as a constant.
/// Contrary to the `FaceDetectionRelative.zero()` constructor, this one gives immutable attributes [box] and [allKeypoints].
FaceDetectionRelative.defaultInitialization()
: box = const <double>[0, 0, 0, 0],
allKeypoints = const <List<double>>[
[0, 0],
[0, 0],
[0, 0],
[0, 0],
[0, 0],
],
super.empty();
FaceDetectionRelative getNearestDetection(
List<FaceDetectionRelative> detections,
) {
if (detections.isEmpty) {
throw ArgumentError("The detection list cannot be empty.");
}
var nearestDetection = detections[0];
var minDistance = double.infinity;
// Calculate the center of the current instance
final centerX1 = (xMinBox + xMaxBox) / 2;
final centerY1 = (yMinBox + yMaxBox) / 2;
for (var detection in detections) {
final centerX2 = (detection.xMinBox + detection.xMaxBox) / 2;
final centerY2 = (detection.yMinBox + detection.yMaxBox) / 2;
final distance =
sqrt(pow(centerX2 - centerX1, 2) + pow(centerY2 - centerY1, 2));
if (distance < minDistance) {
minDistance = distance;
nearestDetection = detection;
}
}
return nearestDetection;
}
void transformRelativeToOriginalImage(
List<double> fromBox, // [xMin, yMin, xMax, yMax]
List<double> toBox, // [xMin, yMin, xMax, yMax]
) {
// Return if all elements of fromBox and toBox are equal
for (int i = 0; i < fromBox.length; i++) {
if (fromBox[i] != toBox[i]) {
break;
}
if (i == fromBox.length - 1) {
return;
}
}
// Account for padding
final double paddingXRatio =
(fromBox[0] - toBox[0]) / (toBox[2] - toBox[0]);
final double paddingYRatio =
(fromBox[1] - toBox[1]) / (toBox[3] - toBox[1]);
// Calculate the scaling and translation
final double scaleX = (fromBox[2] - fromBox[0]) / (1 - 2 * paddingXRatio);
final double scaleY = (fromBox[3] - fromBox[1]) / (1 - 2 * paddingYRatio);
final double translateX = fromBox[0] - paddingXRatio * scaleX;
final double translateY = fromBox[1] - paddingYRatio * scaleY;
// Transform Box
_transformBox(box, scaleX, scaleY, translateX, translateY);
// Transform All Keypoints
for (int i = 0; i < allKeypoints.length; i++) {
allKeypoints[i] = _transformPoint(
allKeypoints[i],
scaleX,
scaleY,
translateX,
translateY,
);
}
}
void correctForMaintainedAspectRatio(
Size originalSize,
Size newSize,
) {
// Return if both are the same size, meaning no scaling was done on both width and height
if (originalSize == newSize) {
return;
}
// Calculate the scaling
final double scaleX = originalSize.width / newSize.width;
final double scaleY = originalSize.height / newSize.height;
const double translateX = 0;
const double translateY = 0;
// Transform Box
_transformBox(box, scaleX, scaleY, translateX, translateY);
// Transform All Keypoints
for (int i = 0; i < allKeypoints.length; i++) {
allKeypoints[i] = _transformPoint(
allKeypoints[i],
scaleX,
scaleY,
translateX,
translateY,
);
}
}
void _transformBox(
List<double> box,
double scaleX,
double scaleY,
double translateX,
double translateY,
) {
box[0] = (box[0] * scaleX + translateX).clamp(0.0, 1.0);
box[1] = (box[1] * scaleY + translateY).clamp(0.0, 1.0);
box[2] = (box[2] * scaleX + translateX).clamp(0.0, 1.0);
box[3] = (box[3] * scaleY + translateY).clamp(0.0, 1.0);
}
List<double> _transformPoint(
List<double> point,
double scaleX,
double scaleY,
double translateX,
double translateY,
) {
return [
(point[0] * scaleX + translateX).clamp(0.0, 1.0),
(point[1] * scaleY + translateY).clamp(0.0, 1.0),
];
}
FaceDetectionAbsolute toAbsolute({
required int imageWidth,
required int imageHeight,
}) {
final scoreCopy = score;
final boxCopy = List<double>.from(box, growable: false);
final allKeypointsCopy = allKeypoints
.map((sublist) => List<double>.from(sublist, growable: false))
.toList();
boxCopy[0] *= imageWidth;
boxCopy[1] *= imageHeight;
boxCopy[2] *= imageWidth;
boxCopy[3] *= imageHeight;
// final intbox = boxCopy.map((e) => e.toInt()).toList();
for (List<double> keypoint in allKeypointsCopy) {
keypoint[0] *= imageWidth;
keypoint[1] *= imageHeight;
}
// final intKeypoints =
// allKeypointsCopy.map((e) => e.map((e) => e.toInt()).toList()).toList();
return FaceDetectionAbsolute(
score: scoreCopy,
box: boxCopy,
allKeypoints: allKeypointsCopy,
);
}
String toFaceID({required int fileID}) {
// Assert that the values are within the expected range
assert(
(xMinBox >= 0 && xMinBox <= 1) &&
(yMinBox >= 0 && yMinBox <= 1) &&
(xMaxBox >= 0 && xMaxBox <= 1) &&
(yMaxBox >= 0 && yMaxBox <= 1),
"Bounding box values must be in the range [0, 1]",
);
// Extract bounding box values
final String xMin =
xMinBox.clamp(0.0, 0.999999).toStringAsFixed(5).substring(2);
final String yMin =
yMinBox.clamp(0.0, 0.999999).toStringAsFixed(5).substring(2);
final String xMax =
xMaxBox.clamp(0.0, 0.999999).toStringAsFixed(5).substring(2);
final String yMax =
yMaxBox.clamp(0.0, 0.999999).toStringAsFixed(5).substring(2);
// Convert the bounding box values to string and concatenate
final String rawID = "${xMin}_${yMin}_${xMax}_$yMax";
// Hash the concatenated string using SHA256
final digest = sha256.convert(utf8.encode(rawID));
// Return the hexadecimal representation of the hash
return fileID.toString() + '_' + digest.toString();
}
/// This method is used to generate a faceID for a face detection that was manually added by the user.
static String toFaceIDEmpty({required int fileID}) {
return fileID.toString() + '_0';
}
/// This method is used to check if a faceID corresponds to a manually added face detection and not an actual face detection.
static bool isFaceIDEmpty(String faceID) {
return faceID.split('_')[1] == '0';
}
@override
String toString() {
return 'FaceDetectionRelative( with relative coordinates: \n score: $score \n Box: xMinBox: $xMinBox, yMinBox: $yMinBox, xMaxBox: $xMaxBox, yMaxBox: $yMaxBox, \n Keypoints: leftEye: $leftEye, rightEye: $rightEye, nose: $nose, leftMouth: $leftMouth, rightMouth: $rightMouth \n )';
}
Map<String, dynamic> toJson() {
return {
'score': score,
'box': box,
'allKeypoints': allKeypoints,
};
}
factory FaceDetectionRelative.fromJson(Map<String, dynamic> json) {
return FaceDetectionRelative(
score: (json['score'] as num).toDouble(),
box: List<double>.from(json['box']),
allKeypoints: (json['allKeypoints'] as List)
.map((item) => List<double>.from(item))
.toList(),
);
}
@override
/// The width of the bounding box of the face detection, in relative range [0, 1].
double get width => xMaxBox - xMinBox;
@override
/// The height of the bounding box of the face detection, in relative range [0, 1].
double get height => yMaxBox - yMinBox;
}
/// This class represents a face detection with absolute coordinates in pixels, in the range [0, imageWidth] for the horizontal coordinates and [0, imageHeight] for the vertical coordinates.
/// The pattern for the coordinates is always [x, y], where x is the horizontal coordinate and y is the vertical coordinate.
///
/// The [score] attribute is a double representing the confidence of the face detection.
///
/// The [box] attribute is a list of 4 integers, representing the coordinates of the bounding box of the face detection.
/// The four values of the box in order are: [xMinBox, yMinBox, xMaxBox, yMaxBox].
///
/// The [allKeypoints] attribute is a list of 6 lists of 2 integers, representing the coordinates of the keypoints of the face detection.
/// The six lists of two values in order are: [leftEye, rightEye, nose, mouth, leftEar, rightEar]. Again, all in [x, y] order.
class FaceDetectionAbsolute extends Detection {
final List<double> box;
final List<List<double>> allKeypoints;
double get xMinBox => box[0];
double get yMinBox => box[1];
double get xMaxBox => box[2];
double get yMaxBox => box[3];
List<double> get leftEye => allKeypoints[0];
List<double> get rightEye => allKeypoints[1];
List<double> get nose => allKeypoints[2];
List<double> get leftMouth => allKeypoints[3];
List<double> get rightMouth => allKeypoints[4];
FaceDetectionAbsolute({
required double score,
required this.box,
required this.allKeypoints,
}) : super(score: score);
factory FaceDetectionAbsolute._zero() {
return FaceDetectionAbsolute(
score: 0,
box: <double>[0, 0, 0, 0],
allKeypoints: <List<double>>[
[0, 0],
[0, 0],
[0, 0],
[0, 0],
[0, 0],
],
);
}
FaceDetectionAbsolute.defaultInitialization()
: box = const <double>[0, 0, 0, 0],
allKeypoints = const <List<double>>[
[0, 0],
[0, 0],
[0, 0],
[0, 0],
[0, 0],
],
super.empty();
@override
String toString() {
return 'FaceDetectionAbsolute( with absolute coordinates: \n score: $score \n Box: xMinBox: $xMinBox, yMinBox: $yMinBox, xMaxBox: $xMaxBox, yMaxBox: $yMaxBox, \n Keypoints: leftEye: $leftEye, rightEye: $rightEye, nose: $nose, leftMouth: $leftMouth, rightMouth: $rightMouth \n )';
}
Map<String, dynamic> toJson() {
return {
'score': score,
'box': box,
'allKeypoints': allKeypoints,
};
}
factory FaceDetectionAbsolute.fromJson(Map<String, dynamic> json) {
return FaceDetectionAbsolute(
score: (json['score'] as num).toDouble(),
box: List<double>.from(json['box']),
allKeypoints: (json['allKeypoints'] as List)
.map((item) => List<double>.from(item))
.toList(),
);
}
static FaceDetectionAbsolute empty = FaceDetectionAbsolute._zero();
@override
/// The width of the bounding box of the face detection, in number of pixels, range [0, imageWidth].
double get width => xMaxBox - xMinBox;
@override
/// The height of the bounding box of the face detection, in number of pixels, range [0, imageHeight].
double get height => yMaxBox - yMinBox;
}
List<FaceDetectionAbsolute> relativeToAbsoluteDetections({
required List<FaceDetectionRelative> relativeDetections,
required int imageWidth,
required int imageHeight,
}) {
final numberOfDetections = relativeDetections.length;
final absoluteDetections = List<FaceDetectionAbsolute>.filled(
numberOfDetections,
FaceDetectionAbsolute._zero(),
);
for (var i = 0; i < relativeDetections.length; i++) {
final relativeDetection = relativeDetections[i];
final absoluteDetection = relativeDetection.toAbsolute(
imageWidth: imageWidth,
imageHeight: imageHeight,
);
absoluteDetections[i] = absoluteDetection;
}
return absoluteDetections;
}
/// Returns an enlarged version of the [box] by a factor of [factor].
List<double> getEnlargedRelativeBox(List<double> box, [double factor = 2]) {
final boxCopy = List<double>.from(box, growable: false);
// The four values of the box in order are: [xMinBox, yMinBox, xMaxBox, yMaxBox].
final width = boxCopy[2] - boxCopy[0];
final height = boxCopy[3] - boxCopy[1];
boxCopy[0] -= width * (factor - 1) / 2;
boxCopy[1] -= height * (factor - 1) / 2;
boxCopy[2] += width * (factor - 1) / 2;
boxCopy[3] += height * (factor - 1) / 2;
return boxCopy;
}

View file

@ -0,0 +1,49 @@
import 'dart:math' as math show max, min;
import "package:photos/services/face_ml/face_detection/detection.dart";
List<FaceDetectionRelative> naiveNonMaxSuppression({
required List<FaceDetectionRelative> detections,
required double iouThreshold,
}) {
// Sort the detections by score, the highest first
detections.sort((a, b) => b.score.compareTo(a.score));
// Loop through the detections and calculate the IOU
for (var i = 0; i < detections.length - 1; i++) {
for (var j = i + 1; j < detections.length; j++) {
final iou = _calculateIOU(detections[i], detections[j]);
if (iou >= iouThreshold) {
detections.removeAt(j);
j--;
}
}
}
return detections;
}
double _calculateIOU(
FaceDetectionRelative detectionA,
FaceDetectionRelative detectionB,
) {
final areaA = detectionA.width * detectionA.height;
final areaB = detectionB.width * detectionB.height;
final intersectionMinX = math.max(detectionA.xMinBox, detectionB.xMinBox);
final intersectionMinY = math.max(detectionA.yMinBox, detectionB.yMinBox);
final intersectionMaxX = math.min(detectionA.xMaxBox, detectionB.xMaxBox);
final intersectionMaxY = math.min(detectionA.yMaxBox, detectionB.yMaxBox);
final intersectionWidth = intersectionMaxX - intersectionMinX;
final intersectionHeight = intersectionMaxY - intersectionMinY;
if (intersectionWidth < 0 || intersectionHeight < 0) {
return 0.0; // If boxes do not overlap, IoU is 0
}
final intersectionArea = intersectionWidth * intersectionHeight;
final unionArea = areaA + areaB - intersectionArea;
return intersectionArea / unionArea;
}

View file

@ -0,0 +1,786 @@
import "dart:async";
import "dart:developer" as dev show log;
import "dart:io" show File;
import "dart:isolate";
import 'dart:typed_data' show Float32List, Uint8List;
import "package:computer/computer.dart";
import 'package:flutter/material.dart';
import 'package:logging/logging.dart';
import 'package:onnxruntime/onnxruntime.dart';
import "package:photos/services/face_ml/face_detection/detection.dart";
import "package:photos/services/face_ml/face_detection/naive_non_max_suppression.dart";
import "package:photos/services/face_ml/face_detection/yolov5face/yolo_face_detection_exceptions.dart";
import "package:photos/services/face_ml/face_detection/yolov5face/yolo_filter_extract_detections.dart";
import "package:photos/services/remote_assets_service.dart";
import "package:photos/utils/image_ml_isolate.dart";
import "package:photos/utils/image_ml_util.dart";
import "package:synchronized/synchronized.dart";
enum FaceDetectionOperation { yoloInferenceAndPostProcessing }
class YoloOnnxFaceDetection {
static final _logger = Logger('YOLOFaceDetectionService');
final _computer = Computer.shared();
int sessionAddress = 0;
static const kModelBucketEndpoint = "https://models.ente.io/";
static const kRemoteBucketModelPath = "yolov5s_face_640_640_dynamic.onnx";
// static const kRemoteBucketModelPath = "yolov5n_face_640_640.onnx";
static const modelRemotePath = kModelBucketEndpoint + kRemoteBucketModelPath;
static const kInputWidth = 640;
static const kInputHeight = 640;
static const kIouThreshold = 0.4;
static const kMinScoreSigmoidThreshold = 0.8;
bool isInitialized = false;
// Isolate things
Timer? _inactivityTimer;
final Duration _inactivityDuration = const Duration(seconds: 30);
final _initLock = Lock();
final _computerLock = Lock();
late Isolate _isolate;
late ReceivePort _receivePort = ReceivePort();
late SendPort _mainSendPort;
bool isSpawned = false;
bool isRunning = false;
// singleton pattern
YoloOnnxFaceDetection._privateConstructor();
/// Use this instance to access the FaceDetection service. Make sure to call `init()` before using it.
/// e.g. `await FaceDetection.instance.init();`
///
/// Then you can use `predict()` to get the bounding boxes of the faces, so `FaceDetection.instance.predict(imageData)`
///
/// config options: yoloV5FaceN //
static final instance = YoloOnnxFaceDetection._privateConstructor();
factory YoloOnnxFaceDetection() => instance;
/// Check if the interpreter is initialized, if not initialize it with `loadModel()`
Future<void> init() async {
if (!isInitialized) {
_logger.info('init is called');
final model =
await RemoteAssetsService.instance.getAsset(modelRemotePath);
final startTime = DateTime.now();
// Doing this from main isolate since `rootBundle` cannot be accessed outside it
sessionAddress = await _computer.compute(
_loadModel,
param: {
"modelPath": model.path,
},
);
final endTime = DateTime.now();
_logger.info(
"Face detection model loaded, took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms",
);
if (sessionAddress != -1) {
isInitialized = true;
}
}
}
Future<void> release() async {
if (isInitialized) {
await _computer
.compute(_releaseModel, param: {'address': sessionAddress});
isInitialized = false;
sessionAddress = 0;
}
}
Future<void> initIsolate() async {
return _initLock.synchronized(() async {
if (isSpawned) return;
_receivePort = ReceivePort();
try {
_isolate = await Isolate.spawn(
_isolateMain,
_receivePort.sendPort,
);
_mainSendPort = await _receivePort.first as SendPort;
isSpawned = true;
_resetInactivityTimer();
} catch (e) {
_logger.severe('Could not spawn isolate', e);
isSpawned = false;
}
});
}
Future<void> ensureSpawnedIsolate() async {
if (!isSpawned) {
await initIsolate();
}
}
/// The main execution function of the isolate.
static void _isolateMain(SendPort mainSendPort) async {
final receivePort = ReceivePort();
mainSendPort.send(receivePort.sendPort);
receivePort.listen((message) async {
final functionIndex = message[0] as int;
final function = FaceDetectionOperation.values[functionIndex];
final args = message[1] as Map<String, dynamic>;
final sendPort = message[2] as SendPort;
try {
switch (function) {
case FaceDetectionOperation.yoloInferenceAndPostProcessing:
final inputImageList = args['inputImageList'] as Float32List;
final inputShape = args['inputShape'] as List<int>;
final newSize = args['newSize'] as Size;
final sessionAddress = args['sessionAddress'] as int;
final timeSentToIsolate = args['timeNow'] as DateTime;
final delaySentToIsolate =
DateTime.now().difference(timeSentToIsolate).inMilliseconds;
final Stopwatch stopwatchPrepare = Stopwatch()..start();
final inputOrt = OrtValueTensor.createTensorWithDataList(
inputImageList,
inputShape,
);
final inputs = {'input': inputOrt};
stopwatchPrepare.stop();
dev.log(
'[YOLOFaceDetectionService] data preparation is finished, in ${stopwatchPrepare.elapsedMilliseconds}ms',
);
stopwatchPrepare.reset();
stopwatchPrepare.start();
final runOptions = OrtRunOptions();
final session = OrtSession.fromAddress(sessionAddress);
stopwatchPrepare.stop();
dev.log(
'[YOLOFaceDetectionService] session preparation is finished, in ${stopwatchPrepare.elapsedMilliseconds}ms',
);
final stopwatchInterpreter = Stopwatch()..start();
late final List<OrtValue?> outputs;
try {
outputs = session.run(runOptions, inputs);
} catch (e, s) {
dev.log(
'[YOLOFaceDetectionService] Error while running inference: $e \n $s',
);
throw YOLOInterpreterRunException();
}
stopwatchInterpreter.stop();
dev.log(
'[YOLOFaceDetectionService] interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms',
);
final relativeDetections =
_yoloPostProcessOutputs(outputs, newSize);
sendPort
.send((relativeDetections, delaySentToIsolate, DateTime.now()));
break;
}
} catch (e, stackTrace) {
sendPort
.send({'error': e.toString(), 'stackTrace': stackTrace.toString()});
}
});
}
/// The common method to run any operation in the isolate. It sends the [message] to [_isolateMain] and waits for the result.
Future<dynamic> _runInIsolate(
(FaceDetectionOperation, Map<String, dynamic>) message,
) async {
await ensureSpawnedIsolate();
_resetInactivityTimer();
final completer = Completer<dynamic>();
final answerPort = ReceivePort();
_mainSendPort.send([message.$1.index, message.$2, answerPort.sendPort]);
answerPort.listen((receivedMessage) {
if (receivedMessage is Map && receivedMessage.containsKey('error')) {
// Handle the error
final errorMessage = receivedMessage['error'];
final errorStackTrace = receivedMessage['stackTrace'];
final exception = Exception(errorMessage);
final stackTrace = StackTrace.fromString(errorStackTrace);
completer.completeError(exception, stackTrace);
} else {
completer.complete(receivedMessage);
}
});
return completer.future;
}
/// Resets a timer that kills the isolate after a certain amount of inactivity.
///
/// Should be called after initialization (e.g. inside `init()`) and after every call to isolate (e.g. inside `_runInIsolate()`)
void _resetInactivityTimer() {
_inactivityTimer?.cancel();
_inactivityTimer = Timer(_inactivityDuration, () {
_logger.info(
'Face detection (YOLO ONNX) Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds. Killing isolate.',
);
disposeIsolate();
});
}
/// Disposes the isolate worker.
void disposeIsolate() {
if (!isSpawned) return;
isSpawned = false;
_isolate.kill();
_receivePort.close();
_inactivityTimer?.cancel();
}
/// Detects faces in the given image data.
Future<(List<FaceDetectionRelative>, Size)> predict(
Uint8List imageData,
) async {
assert(isInitialized);
final stopwatch = Stopwatch()..start();
final stopwatchDecoding = Stopwatch()..start();
final (inputImageList, originalSize, newSize) =
await ImageMlIsolate.instance.preprocessImageYoloOnnx(
imageData,
normalize: true,
requiredWidth: kInputWidth,
requiredHeight: kInputHeight,
maintainAspectRatio: true,
quality: FilterQuality.medium,
);
// final input = [inputImageList];
final inputShape = [
1,
3,
kInputHeight,
kInputWidth,
];
final inputOrt = OrtValueTensor.createTensorWithDataList(
inputImageList,
inputShape,
);
final inputs = {'input': inputOrt};
stopwatchDecoding.stop();
_logger.info(
'Image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms',
);
_logger.info('original size: $originalSize \n new size: $newSize');
// Run inference
final stopwatchInterpreter = Stopwatch()..start();
List<OrtValue?>? outputs;
try {
final runOptions = OrtRunOptions();
final session = OrtSession.fromAddress(sessionAddress);
outputs = session.run(runOptions, inputs);
// inputOrt.release();
// runOptions.release();
} catch (e, s) {
_logger.severe('Error while running inference: $e \n $s');
throw YOLOInterpreterRunException();
}
stopwatchInterpreter.stop();
_logger.info(
'interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms',
);
final relativeDetections = _yoloPostProcessOutputs(outputs, newSize);
stopwatch.stop();
_logger.info(
'predict() face detection executed in ${stopwatch.elapsedMilliseconds}ms',
);
return (relativeDetections, originalSize);
}
/// Detects faces in the given image data.
static Future<(List<FaceDetectionRelative>, Size)> predictSync(
String imagePath,
int sessionAddress,
) async {
assert(sessionAddress != 0 && sessionAddress != -1);
final stopwatch = Stopwatch()..start();
final stopwatchDecoding = Stopwatch()..start();
final imageData = await File(imagePath).readAsBytes();
final (inputImageList, originalSize, newSize) =
await preprocessImageToFloat32ChannelsFirst(
imageData,
normalization: 1,
requiredWidth: kInputWidth,
requiredHeight: kInputHeight,
maintainAspectRatio: true,
quality: FilterQuality.medium,
);
// final input = [inputImageList];
final inputShape = [
1,
3,
kInputHeight,
kInputWidth,
];
final inputOrt = OrtValueTensor.createTensorWithDataList(
inputImageList,
inputShape,
);
final inputs = {'input': inputOrt};
stopwatchDecoding.stop();
dev.log(
'Face detection image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms',
);
_logger.info(
'Image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms',
);
_logger.info('original size: $originalSize \n new size: $newSize');
// Run inference
final stopwatchInterpreter = Stopwatch()..start();
List<OrtValue?>? outputs;
try {
final runOptions = OrtRunOptions();
final session = OrtSession.fromAddress(sessionAddress);
outputs = session.run(runOptions, inputs);
// inputOrt.release();
// runOptions.release();
} catch (e, s) {
_logger.severe('Error while running inference: $e \n $s');
throw YOLOInterpreterRunException();
}
stopwatchInterpreter.stop();
_logger.info(
'interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms',
);
final relativeDetections = _yoloPostProcessOutputs(outputs, newSize);
stopwatch.stop();
_logger.info(
'predict() face detection executed in ${stopwatch.elapsedMilliseconds}ms',
);
return (relativeDetections, originalSize);
}
/// Detects faces in the given image data.
Future<(List<FaceDetectionRelative>, Size)> predictInIsolate(
Uint8List imageData,
) async {
await ensureSpawnedIsolate();
assert(isInitialized);
_logger.info('predictInIsolate() is called');
final stopwatch = Stopwatch()..start();
final stopwatchDecoding = Stopwatch()..start();
final (inputImageList, originalSize, newSize) =
await ImageMlIsolate.instance.preprocessImageYoloOnnx(
imageData,
normalize: true,
requiredWidth: kInputWidth,
requiredHeight: kInputHeight,
maintainAspectRatio: true,
quality: FilterQuality.medium,
);
// final input = [inputImageList];
final inputShape = [
1,
3,
kInputHeight,
kInputWidth,
];
stopwatchDecoding.stop();
_logger.info(
'Image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms',
);
_logger.info('original size: $originalSize \n new size: $newSize');
final (
List<FaceDetectionRelative> relativeDetections,
delaySentToIsolate,
timeSentToMain
) = await _runInIsolate(
(
FaceDetectionOperation.yoloInferenceAndPostProcessing,
{
'inputImageList': inputImageList,
'inputShape': inputShape,
'newSize': newSize,
'sessionAddress': sessionAddress,
'timeNow': DateTime.now(),
}
),
) as (List<FaceDetectionRelative>, int, DateTime);
final delaySentToMain =
DateTime.now().difference(timeSentToMain).inMilliseconds;
stopwatch.stop();
_logger.info(
'predictInIsolate() face detection executed in ${stopwatch.elapsedMilliseconds}ms, with ${delaySentToIsolate}ms delay sent to isolate, and ${delaySentToMain}ms delay sent to main, for a total of ${delaySentToIsolate + delaySentToMain}ms delay due to isolate',
);
return (relativeDetections, originalSize);
}
Future<(List<FaceDetectionRelative>, Size)> predictInComputer(
String imagePath,
) async {
assert(isInitialized);
_logger.info('predictInComputer() is called');
final stopwatch = Stopwatch()..start();
final stopwatchDecoding = Stopwatch()..start();
final imageData = await File(imagePath).readAsBytes();
final (inputImageList, originalSize, newSize) =
await ImageMlIsolate.instance.preprocessImageYoloOnnx(
imageData,
normalize: true,
requiredWidth: kInputWidth,
requiredHeight: kInputHeight,
maintainAspectRatio: true,
quality: FilterQuality.medium,
);
// final input = [inputImageList];
return await _computerLock.synchronized(() async {
final inputShape = [
1,
3,
kInputHeight,
kInputWidth,
];
stopwatchDecoding.stop();
_logger.info(
'Image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms',
);
_logger.info('original size: $originalSize \n new size: $newSize');
final (
List<FaceDetectionRelative> relativeDetections,
delaySentToIsolate,
timeSentToMain
) = await _computer.compute(
inferenceAndPostProcess,
param: {
'inputImageList': inputImageList,
'inputShape': inputShape,
'newSize': newSize,
'sessionAddress': sessionAddress,
'timeNow': DateTime.now(),
},
) as (List<FaceDetectionRelative>, int, DateTime);
final delaySentToMain =
DateTime.now().difference(timeSentToMain).inMilliseconds;
stopwatch.stop();
_logger.info(
'predictInIsolate() face detection executed in ${stopwatch.elapsedMilliseconds}ms, with ${delaySentToIsolate}ms delay sent to isolate, and ${delaySentToMain}ms delay sent to main, for a total of ${delaySentToIsolate + delaySentToMain}ms delay due to isolate',
);
return (relativeDetections, originalSize);
});
}
/// Detects faces in the given image data.
/// This method is optimized for batch processing.
///
/// `imageDataList`: The image data to analyze.
///
/// WARNING: Currently this method only returns the detections for the first image in the batch.
/// Change the function to output all detection before actually using it in production.
Future<List<FaceDetectionRelative>> predictBatch(
List<Uint8List> imageDataList,
) async {
assert(isInitialized);
final stopwatch = Stopwatch()..start();
final stopwatchDecoding = Stopwatch()..start();
final List<Float32List> inputImageDataLists = [];
final List<(Size, Size)> originalAndNewSizeList = [];
int concatenatedImageInputsLength = 0;
for (final imageData in imageDataList) {
final (inputImageList, originalSize, newSize) =
await ImageMlIsolate.instance.preprocessImageYoloOnnx(
imageData,
normalize: true,
requiredWidth: kInputWidth,
requiredHeight: kInputHeight,
maintainAspectRatio: true,
quality: FilterQuality.medium,
);
inputImageDataLists.add(inputImageList);
originalAndNewSizeList.add((originalSize, newSize));
concatenatedImageInputsLength += inputImageList.length;
}
final inputImageList = Float32List(concatenatedImageInputsLength);
int offset = 0;
for (int i = 0; i < inputImageDataLists.length; i++) {
final inputImageData = inputImageDataLists[i];
inputImageList.setRange(
offset,
offset + inputImageData.length,
inputImageData,
);
offset += inputImageData.length;
}
// final input = [inputImageList];
final inputShape = [
inputImageDataLists.length,
3,
kInputHeight,
kInputWidth,
];
final inputOrt = OrtValueTensor.createTensorWithDataList(
inputImageList,
inputShape,
);
final inputs = {'input': inputOrt};
stopwatchDecoding.stop();
_logger.info(
'Image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms',
);
// _logger.info('original size: $originalSize \n new size: $newSize');
_logger.info('interpreter.run is called');
// Run inference
final stopwatchInterpreter = Stopwatch()..start();
List<OrtValue?>? outputs;
try {
final runOptions = OrtRunOptions();
final session = OrtSession.fromAddress(sessionAddress);
outputs = session.run(runOptions, inputs);
inputOrt.release();
runOptions.release();
} catch (e, s) {
_logger.severe('Error while running inference: $e \n $s');
throw YOLOInterpreterRunException();
}
stopwatchInterpreter.stop();
_logger.info(
'interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms, or ${stopwatchInterpreter.elapsedMilliseconds / inputImageDataLists.length} ms per image',
);
_logger.info('outputs: $outputs');
const int imageOutputToUse = 0;
// // Get output tensors
final nestedResults =
outputs[0]?.value as List<List<List<double>>>; // [b, 25200, 16]
final selectedResults = nestedResults[imageOutputToUse]; // [25200, 16]
// final rawScores = <double>[];
// for (final result in firstResults) {
// rawScores.add(result[4]);
// }
// final rawScoresCopy = List<double>.from(rawScores);
// rawScoresCopy.sort();
// _logger.info('rawScores minimum: ${rawScoresCopy.first}');
// _logger.info('rawScores maximum: ${rawScoresCopy.last}');
var relativeDetections = yoloOnnxFilterExtractDetections(
kMinScoreSigmoidThreshold,
kInputWidth,
kInputHeight,
results: selectedResults,
);
// Release outputs
for (var element in outputs) {
element?.release();
}
// Account for the fact that the aspect ratio was maintained
for (final faceDetection in relativeDetections) {
faceDetection.correctForMaintainedAspectRatio(
Size(
kInputWidth.toDouble(),
kInputHeight.toDouble(),
),
originalAndNewSizeList[imageOutputToUse].$2,
);
}
// Non-maximum suppression to remove duplicate detections
relativeDetections = naiveNonMaxSuppression(
detections: relativeDetections,
iouThreshold: kIouThreshold,
);
if (relativeDetections.isEmpty) {
_logger.info('No face detected');
return <FaceDetectionRelative>[];
}
stopwatch.stop();
_logger.info(
'predict() face detection executed in ${stopwatch.elapsedMilliseconds}ms',
);
return relativeDetections;
}
static List<FaceDetectionRelative> _yoloPostProcessOutputs(
List<OrtValue?>? outputs,
Size newSize,
) {
// // Get output tensors
final nestedResults =
outputs?[0]?.value as List<List<List<double>>>; // [1, 25200, 16]
final firstResults = nestedResults[0]; // [25200, 16]
// final rawScores = <double>[];
// for (final result in firstResults) {
// rawScores.add(result[4]);
// }
// final rawScoresCopy = List<double>.from(rawScores);
// rawScoresCopy.sort();
// _logger.info('rawScores minimum: ${rawScoresCopy.first}');
// _logger.info('rawScores maximum: ${rawScoresCopy.last}');
var relativeDetections = yoloOnnxFilterExtractDetections(
kMinScoreSigmoidThreshold,
kInputWidth,
kInputHeight,
results: firstResults,
);
// Release outputs
// outputs?.forEach((element) {
// element?.release();
// });
// Account for the fact that the aspect ratio was maintained
for (final faceDetection in relativeDetections) {
faceDetection.correctForMaintainedAspectRatio(
Size(
kInputWidth.toDouble(),
kInputHeight.toDouble(),
),
newSize,
);
}
// Non-maximum suppression to remove duplicate detections
relativeDetections = naiveNonMaxSuppression(
detections: relativeDetections,
iouThreshold: kIouThreshold,
);
dev.log(
'[YOLOFaceDetectionService] ${relativeDetections.length} faces detected',
);
return relativeDetections;
}
/// Initialize the interpreter by loading the model file.
static Future<int> _loadModel(Map args) async {
final sessionOptions = OrtSessionOptions()
..setInterOpNumThreads(1)
..setIntraOpNumThreads(1)
..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
try {
// _logger.info('Loading face embedding model');
final session =
OrtSession.fromFile(File(args["modelPath"]), sessionOptions);
// _logger.info('Face embedding model loaded');
return session.address;
} catch (e, _) {
// _logger.severe('Face embedding model not loaded', e, s);
}
return -1;
}
static Future<void> _releaseModel(Map args) async {
final address = args['address'] as int;
if (address == 0) {
return;
}
final session = OrtSession.fromAddress(address);
session.release();
return;
}
static Future<(List<FaceDetectionRelative>, int, DateTime)>
inferenceAndPostProcess(
Map args,
) async {
final inputImageList = args['inputImageList'] as Float32List;
final inputShape = args['inputShape'] as List<int>;
final newSize = args['newSize'] as Size;
final sessionAddress = args['sessionAddress'] as int;
final timeSentToIsolate = args['timeNow'] as DateTime;
final delaySentToIsolate =
DateTime.now().difference(timeSentToIsolate).inMilliseconds;
final Stopwatch stopwatchPrepare = Stopwatch()..start();
final inputOrt = OrtValueTensor.createTensorWithDataList(
inputImageList,
inputShape,
);
final inputs = {'input': inputOrt};
stopwatchPrepare.stop();
dev.log(
'[YOLOFaceDetectionService] data preparation is finished, in ${stopwatchPrepare.elapsedMilliseconds}ms',
);
stopwatchPrepare.reset();
stopwatchPrepare.start();
final runOptions = OrtRunOptions();
final session = OrtSession.fromAddress(sessionAddress);
stopwatchPrepare.stop();
dev.log(
'[YOLOFaceDetectionService] session preparation is finished, in ${stopwatchPrepare.elapsedMilliseconds}ms',
);
final stopwatchInterpreter = Stopwatch()..start();
late final List<OrtValue?> outputs;
try {
outputs = session.run(runOptions, inputs);
} catch (e, s) {
dev.log(
'[YOLOFaceDetectionService] Error while running inference: $e \n $s',
);
throw YOLOInterpreterRunException();
}
stopwatchInterpreter.stop();
dev.log(
'[YOLOFaceDetectionService] interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms',
);
final relativeDetections = _yoloPostProcessOutputs(outputs, newSize);
return (relativeDetections, delaySentToIsolate, DateTime.now());
}
}

View file

@ -0,0 +1,3 @@
class YOLOInterpreterInitializationException implements Exception {}
class YOLOInterpreterRunException implements Exception {}

View file

@ -0,0 +1,31 @@
import 'dart:math' as math show log;
class FaceDetectionOptionsYOLO {
final double minScoreSigmoidThreshold;
final double iouThreshold;
final int inputWidth;
final int inputHeight;
final int numCoords;
final int numKeypoints;
final int numValuesPerKeypoint;
final int maxNumFaces;
final double scoreClippingThresh;
final double inverseSigmoidMinScoreThreshold;
final bool useSigmoidScore;
final bool flipVertically;
FaceDetectionOptionsYOLO({
required this.minScoreSigmoidThreshold,
required this.iouThreshold,
required this.inputWidth,
required this.inputHeight,
this.numCoords = 14,
this.numKeypoints = 5,
this.numValuesPerKeypoint = 2,
this.maxNumFaces = 100,
this.scoreClippingThresh = 100.0,
this.useSigmoidScore = true,
this.flipVertically = false,
}) : inverseSigmoidMinScoreThreshold =
math.log(minScoreSigmoidThreshold / (1 - minScoreSigmoidThreshold));
}

View file

@ -0,0 +1,81 @@
import "package:photos/services/face_ml/face_detection/detection.dart";
List<FaceDetectionRelative> yoloOnnxFilterExtractDetections(
double minScoreSigmoidThreshold,
int inputWidth,
int inputHeight, {
required List<List<double>> results, // // [25200, 16]
}) {
final outputDetections = <FaceDetectionRelative>[];
final output = <List<double>>[];
// Go through the raw output and check the scores
for (final result in results) {
// Filter out raw detections with low scores
if (result[4] < minScoreSigmoidThreshold) {
continue;
}
// Get the raw detection
final rawDetection = List<double>.from(result);
// Append the processed raw detection to the output
output.add(rawDetection);
}
for (final List<double> rawDetection in output) {
// Get absolute bounding box coordinates in format [xMin, yMin, xMax, yMax] https://github.com/deepcam-cn/yolov5-face/blob/eb23d18defe4a76cc06449a61cd51004c59d2697/utils/general.py#L216
final xMinAbs = rawDetection[0] - rawDetection[2] / 2;
final yMinAbs = rawDetection[1] - rawDetection[3] / 2;
final xMaxAbs = rawDetection[0] + rawDetection[2] / 2;
final yMaxAbs = rawDetection[1] + rawDetection[3] / 2;
// Get the relative bounding box coordinates in format [xMin, yMin, xMax, yMax]
final box = [
xMinAbs / inputWidth,
yMinAbs / inputHeight,
xMaxAbs / inputWidth,
yMaxAbs / inputHeight,
];
// Get the keypoints coordinates in format [x, y]
final allKeypoints = <List<double>>[
[
rawDetection[5] / inputWidth,
rawDetection[6] / inputHeight,
],
[
rawDetection[7] / inputWidth,
rawDetection[8] / inputHeight,
],
[
rawDetection[9] / inputWidth,
rawDetection[10] / inputHeight,
],
[
rawDetection[11] / inputWidth,
rawDetection[12] / inputHeight,
],
[
rawDetection[13] / inputWidth,
rawDetection[14] / inputHeight,
],
];
// Get the score
final score =
rawDetection[4]; // Or should it be rawDetection[4]*rawDetection[15]?
// Create the relative detection
final detection = FaceDetectionRelative(
score: score,
box: box,
allKeypoints: allKeypoints,
);
// Append the relative detection to the output
outputDetections.add(detection);
}
return outputDetections;
}

View file

@ -0,0 +1,22 @@
import "package:photos/services/face_ml/face_detection/yolov5face/yolo_face_detection_options.dart";
import "package:photos/services/face_ml/model_file.dart";
class YOLOModelConfig {
final String modelPath;
final FaceDetectionOptionsYOLO faceOptions;
YOLOModelConfig({
required this.modelPath,
required this.faceOptions,
});
}
final YOLOModelConfig yoloV5FaceS640x640DynamicBatchonnx = YOLOModelConfig(
modelPath: ModelFile.yoloV5FaceS640x640DynamicBatchonnx,
faceOptions: FaceDetectionOptionsYOLO(
minScoreSigmoidThreshold: 0.8,
iouThreshold: 0.4,
inputWidth: 640,
inputHeight: 640,
),
);

View file

@ -0,0 +1,11 @@
class MobileFaceNetInterpreterInitializationException implements Exception {}
class MobileFaceNetImagePreprocessingException implements Exception {}
class MobileFaceNetEmptyInput implements Exception {}
class MobileFaceNetWrongInputSize implements Exception {}
class MobileFaceNetWrongInputRange implements Exception {}
class MobileFaceNetInterpreterRunException implements Exception {}

View file

@ -0,0 +1,15 @@
class FaceEmbeddingOptions {
final int inputWidth;
final int inputHeight;
final int embeddingLength;
final int numChannels;
final bool preWhiten;
FaceEmbeddingOptions({
required this.inputWidth,
required this.inputHeight,
this.embeddingLength = 192,
this.numChannels = 3,
this.preWhiten = false,
});
}

View file

@ -0,0 +1,279 @@
import 'dart:io';
import "dart:math" show min, max, sqrt;
// import 'dart:math' as math show min, max;
import 'dart:typed_data' show Uint8List;
import "package:flutter/foundation.dart";
import "package:logging/logging.dart";
import 'package:photos/models/ml/ml_typedefs.dart';
import "package:photos/services/face_ml/face_detection/detection.dart";
import "package:photos/services/face_ml/face_embedding/face_embedding_exceptions.dart";
import "package:photos/services/face_ml/face_embedding/face_embedding_options.dart";
import "package:photos/services/face_ml/face_embedding/mobilefacenet_model_config.dart";
import 'package:photos/utils/image_ml_isolate.dart';
import 'package:photos/utils/image_ml_util.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
/// This class is responsible for running the MobileFaceNet model, and can be accessed through the singleton `FaceEmbedding.instance`.
class FaceEmbedding {
Interpreter? _interpreter;
IsolateInterpreter? _isolateInterpreter;
int get getAddress => _interpreter!.address;
final outputShapes = <List<int>>[];
final outputTypes = <TensorType>[];
final _logger = Logger("FaceEmbeddingService");
final MobileFaceNetModelConfig config;
final FaceEmbeddingOptions embeddingOptions;
// singleton pattern
FaceEmbedding._privateConstructor({required this.config})
: embeddingOptions = config.faceEmbeddingOptions;
/// Use this instance to access the FaceEmbedding service. Make sure to call `init()` before using it.
/// e.g. `await FaceEmbedding.instance.init();`
///
/// Then you can use `predict()` to get the embedding of a face, so `FaceEmbedding.instance.predict(imageData)`
///
/// config options: faceEmbeddingEnte
static final instance =
FaceEmbedding._privateConstructor(config: faceEmbeddingEnte);
factory FaceEmbedding() => instance;
/// Check if the interpreter is initialized, if not initialize it with `loadModel()`
Future<void> init() async {
if (_interpreter == null || _isolateInterpreter == null) {
await _loadModel();
}
}
Future<void> dispose() async {
_logger.info('dispose() is called');
try {
_interpreter?.close();
_interpreter = null;
await _isolateInterpreter?.close();
_isolateInterpreter = null;
} catch (e) {
_logger.severe('Error while closing interpreter: $e');
rethrow;
}
}
/// WARNING: This function only works for one face at a time. it's better to use [predict], which can handle both single and multiple faces.
Future<List<double>> predictSingle(
Uint8List imageData,
FaceDetectionRelative face,
) async {
assert(_interpreter != null && _isolateInterpreter != null);
final stopwatch = Stopwatch()..start();
// Image decoding and preprocessing
List<List<List<List<num>>>> input;
List output;
try {
final stopwatchDecoding = Stopwatch()..start();
final (inputImageMatrix, _, _, _, _) =
await ImageMlIsolate.instance.preprocessMobileFaceNet(
imageData,
[face],
);
input = inputImageMatrix;
stopwatchDecoding.stop();
_logger.info(
'Image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms',
);
output = createEmptyOutputMatrix(outputShapes[0]);
} catch (e) {
_logger.severe('Error while decoding and preprocessing image: $e');
throw MobileFaceNetImagePreprocessingException();
}
_logger.info('interpreter.run is called');
// Run inference
try {
await _isolateInterpreter!.run(input, output);
// _interpreter!.run(input, output);
// ignore: avoid_catches_without_on_clauses
} catch (e) {
_logger.severe('Error while running inference: $e');
throw MobileFaceNetInterpreterRunException();
}
_logger.info('interpreter.run is finished');
// Get output tensors
final embedding = output[0] as List<double>;
// Normalize the embedding
final norm = sqrt(embedding.map((e) => e * e).reduce((a, b) => a + b));
for (int i = 0; i < embedding.length; i++) {
embedding[i] /= norm;
}
stopwatch.stop();
_logger.info(
'predict() executed in ${stopwatch.elapsedMilliseconds}ms',
);
// _logger.info(
// 'results (only first few numbers): embedding ${embedding.sublist(0, 5)}',
// );
// _logger.info(
// 'Mean of embedding: ${embedding.reduce((a, b) => a + b) / embedding.length}',
// );
// _logger.info(
// 'Max of embedding: ${embedding.reduce(math.max)}',
// );
// _logger.info(
// 'Min of embedding: ${embedding.reduce(math.min)}',
// );
return embedding;
}
Future<List<List<double>>> predict(
List<Num3DInputMatrix> inputImageMatrix,
) async {
assert(_interpreter != null && _isolateInterpreter != null);
final stopwatch = Stopwatch()..start();
_checkPreprocessedInput(inputImageMatrix); // [inputHeight, inputWidth, 3]
final input = [inputImageMatrix];
// await encodeAndSaveData(inputImageMatrix, 'input_mobilefacenet');
final output = <int, Object>{};
final outputShape = outputShapes[0];
outputShape[0] = inputImageMatrix.length;
output[0] = createEmptyOutputMatrix(outputShape);
// for (int i = 0; i < faces.length; i++) {
// output[i] = createEmptyOutputMatrix(outputShapes[0]);
// }
_logger.info('interpreter.run is called');
// Run inference
final stopwatchInterpreter = Stopwatch()..start();
try {
await _isolateInterpreter!.runForMultipleInputs(input, output);
// _interpreter!.runForMultipleInputs(input, output);
// ignore: avoid_catches_without_on_clauses
} catch (e) {
_logger.severe('Error while running inference: $e');
throw MobileFaceNetInterpreterRunException();
}
stopwatchInterpreter.stop();
_logger.info(
'interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds}ms',
);
// _logger.info('output: $output');
// Get output tensors
final embeddings = <List<double>>[];
final outerEmbedding = output[0]! as Iterable<dynamic>;
for (int i = 0; i < inputImageMatrix.length; i++) {
final embedding = List<double>.from(outerEmbedding.toList()[i]);
// _logger.info("The $i-th embedding: $embedding");
embeddings.add(embedding);
}
// await encodeAndSaveData(embeddings, 'output_mobilefacenet');
// Normalize the embedding
for (int i = 0; i < embeddings.length; i++) {
final embedding = embeddings[i];
final norm = sqrt(embedding.map((e) => e * e).reduce((a, b) => a + b));
for (int j = 0; j < embedding.length; j++) {
embedding[j] /= norm;
}
}
stopwatch.stop();
_logger.info(
'predictBatch() executed in ${stopwatch.elapsedMilliseconds}ms',
);
return embeddings;
}
Future<void> _loadModel() async {
_logger.info('loadModel is called');
try {
final interpreterOptions = InterpreterOptions();
// Android Delegates
// TODO: Make sure this works on both platforms: Android and iOS
if (Platform.isAndroid) {
// Use GPU Delegate (GPU). WARNING: It doesn't work on emulator
// if (!kDebugMode) {
// interpreterOptions.addDelegate(GpuDelegateV2());
// }
// Use XNNPACK Delegate (CPU)
interpreterOptions.addDelegate(XNNPackDelegate());
}
// iOS Delegates
if (Platform.isIOS) {
// Use Metal Delegate (GPU)
interpreterOptions.addDelegate(GpuDelegate());
}
// Load model from assets
_interpreter ??= await Interpreter.fromAsset(
config.modelPath,
options: interpreterOptions,
);
_isolateInterpreter ??=
await IsolateInterpreter.create(address: _interpreter!.address);
_logger.info('Interpreter created from asset: ${config.modelPath}');
// Get tensor input shape [1, 112, 112, 3]
final inputTensors = _interpreter!.getInputTensors().first;
_logger.info('Input Tensors: $inputTensors');
// Get tensour output shape [1, 192]
final outputTensors = _interpreter!.getOutputTensors();
final outputTensor = outputTensors.first;
_logger.info('Output Tensors: $outputTensor');
for (var tensor in outputTensors) {
outputShapes.add(tensor.shape);
outputTypes.add(tensor.type);
}
_logger.info('outputShapes: $outputShapes');
_logger.info('loadModel is finished');
// ignore: avoid_catches_without_on_clauses
} catch (e) {
_logger.severe('Error while creating interpreter: $e');
throw MobileFaceNetInterpreterInitializationException();
}
}
void _checkPreprocessedInput(
List<Num3DInputMatrix> inputMatrix,
) {
final embeddingOptions = config.faceEmbeddingOptions;
if (inputMatrix.isEmpty) {
// Check if the input is empty
throw MobileFaceNetEmptyInput();
}
// Check if the input is the correct size
if (inputMatrix[0].length != embeddingOptions.inputHeight ||
inputMatrix[0][0].length != embeddingOptions.inputWidth) {
throw MobileFaceNetWrongInputSize();
}
final flattened = inputMatrix[0].expand((i) => i).expand((i) => i);
final minValue = flattened.reduce(min);
final maxValue = flattened.reduce(max);
if (minValue < -1 || maxValue > 1) {
throw MobileFaceNetWrongInputRange();
}
}
}

View file

@ -0,0 +1,20 @@
import "package:photos/services/face_ml/face_embedding/face_embedding_options.dart";
import "package:photos/services/face_ml/model_file.dart";
class MobileFaceNetModelConfig {
final String modelPath;
final FaceEmbeddingOptions faceEmbeddingOptions;
MobileFaceNetModelConfig({
required this.modelPath,
required this.faceEmbeddingOptions,
});
}
final MobileFaceNetModelConfig faceEmbeddingEnte = MobileFaceNetModelConfig(
modelPath: ModelFile.faceEmbeddingEnte,
faceEmbeddingOptions: FaceEmbeddingOptions(
inputWidth: 112,
inputHeight: 112,
),
);

View file

@ -0,0 +1,245 @@
import "dart:io" show File;
import 'dart:math' as math show max, min, sqrt;
import 'dart:typed_data' show Float32List;
import 'package:computer/computer.dart';
import 'package:logging/logging.dart';
import 'package:onnxruntime/onnxruntime.dart';
import "package:photos/services/face_ml/face_detection/detection.dart";
import "package:photos/services/remote_assets_service.dart";
import "package:photos/utils/image_ml_isolate.dart";
import "package:synchronized/synchronized.dart";
class FaceEmbeddingOnnx {
static const kModelBucketEndpoint = "https://models.ente.io/";
static const kRemoteBucketModelPath = "mobilefacenet_opset15.onnx";
static const modelRemotePath = kModelBucketEndpoint + kRemoteBucketModelPath;
static const int kInputSize = 112;
static const int kEmbeddingSize = 192;
static final _logger = Logger('FaceEmbeddingOnnx');
bool isInitialized = false;
int sessionAddress = 0;
final _computer = Computer.shared();
final _computerLock = Lock();
// singleton pattern
FaceEmbeddingOnnx._privateConstructor();
/// Use this instance to access the FaceEmbedding service. Make sure to call `init()` before using it.
/// e.g. `await FaceEmbedding.instance.init();`
///
/// Then you can use `predict()` to get the embedding of a face, so `FaceEmbedding.instance.predict(imageData)`
///
/// config options: faceEmbeddingEnte
static final instance = FaceEmbeddingOnnx._privateConstructor();
factory FaceEmbeddingOnnx() => instance;
/// Check if the interpreter is initialized, if not initialize it with `loadModel()`
Future<void> init() async {
if (!isInitialized) {
_logger.info('init is called');
final model =
await RemoteAssetsService.instance.getAsset(modelRemotePath);
final startTime = DateTime.now();
// Doing this from main isolate since `rootBundle` cannot be accessed outside it
sessionAddress = await _computer.compute(
_loadModel,
param: {
"modelPath": model.path,
},
);
final endTime = DateTime.now();
_logger.info(
"Face embedding model loaded, took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms",
);
if (sessionAddress != -1) {
isInitialized = true;
}
}
}
Future<void> release() async {
if (isInitialized) {
await _computer.compute(_releaseModel, param: {'address': sessionAddress});
isInitialized = false;
sessionAddress = 0;
}
}
static Future<int> _loadModel(Map args) async {
final sessionOptions = OrtSessionOptions()
..setInterOpNumThreads(1)
..setIntraOpNumThreads(1)
..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
try {
// _logger.info('Loading face embedding model');
final session =
OrtSession.fromFile(File(args["modelPath"]), sessionOptions);
// _logger.info('Face embedding model loaded');
return session.address;
} catch (e, _) {
// _logger.severe('Face embedding model not loaded', e, s);
}
return -1;
}
static Future<void> _releaseModel(Map args) async {
final address = args['address'] as int;
if (address == 0) {
return;
}
final session = OrtSession.fromAddress(address);
session.release();
return;
}
Future<(List<double>, bool, double)> predictFromImageDataInComputer(
String imagePath,
FaceDetectionRelative face,
) async {
assert(sessionAddress != 0 && sessionAddress != -1 && isInitialized);
try {
final stopwatchDecoding = Stopwatch()..start();
final (inputImageList, alignmentResults, isBlur, blurValue, _) =
await ImageMlIsolate.instance.preprocessMobileFaceNetOnnx(
imagePath,
[face],
);
stopwatchDecoding.stop();
_logger.info(
'MobileFaceNet image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms',
);
final stopwatch = Stopwatch()..start();
_logger.info('MobileFaceNet interpreter.run is called');
final embedding = await _computer.compute(
inferFromMap,
param: {
'input': inputImageList,
'address': sessionAddress,
'inputSize': kInputSize,
},
taskName: 'createFaceEmbedding',
) as List<double>;
stopwatch.stop();
_logger.info(
'MobileFaceNet interpreter.run is finished, in ${stopwatch.elapsedMilliseconds}ms',
);
_logger.info(
'MobileFaceNet results (only first few numbers): embedding ${embedding.sublist(0, 5)}',
);
_logger.info(
'Mean of embedding: ${embedding.reduce((a, b) => a + b) / embedding.length}',
);
_logger.info(
'Max of embedding: ${embedding.reduce(math.max)}',
);
_logger.info(
'Min of embedding: ${embedding.reduce(math.min)}',
);
return (embedding, isBlur[0], blurValue[0]);
} catch (e) {
_logger.info('MobileFaceNet Error while running inference: $e');
rethrow;
}
}
Future<List<List<double>>> predictInComputer(Float32List input) async {
assert(sessionAddress != 0 && sessionAddress != -1 && isInitialized);
return await _computerLock.synchronized(() async {
try {
final stopwatch = Stopwatch()..start();
_logger.info('MobileFaceNet interpreter.run is called');
final embeddings = await _computer.compute(
inferFromMap,
param: {
'input': input,
'address': sessionAddress,
'inputSize': kInputSize,
},
taskName: 'createFaceEmbedding',
) as List<List<double>>;
stopwatch.stop();
_logger.info(
'MobileFaceNet interpreter.run is finished, in ${stopwatch.elapsedMilliseconds}ms',
);
return embeddings;
} catch (e) {
_logger.info('MobileFaceNet Error while running inference: $e');
rethrow;
}
});
}
static Future<List<List<double>>> predictSync(
Float32List input,
int sessionAddress,
) async {
assert(sessionAddress != 0 && sessionAddress != -1);
try {
final stopwatch = Stopwatch()..start();
_logger.info('MobileFaceNet interpreter.run is called');
final embeddings = await infer(
input,
sessionAddress,
kInputSize,
);
stopwatch.stop();
_logger.info(
'MobileFaceNet interpreter.run is finished, in ${stopwatch.elapsedMilliseconds}ms',
);
return embeddings;
} catch (e) {
_logger.info('MobileFaceNet Error while running inference: $e');
rethrow;
}
}
static Future<List<List<double>>> inferFromMap(Map args) async {
final inputImageList = args['input'] as Float32List;
final address = args['address'] as int;
final inputSize = args['inputSize'] as int;
return await infer(inputImageList, address, inputSize);
}
static Future<List<List<double>>> infer(
Float32List inputImageList,
int address,
int inputSize,
) async {
final runOptions = OrtRunOptions();
final int numberOfFaces =
inputImageList.length ~/ (inputSize * inputSize * 3);
final inputOrt = OrtValueTensor.createTensorWithDataList(
inputImageList,
[numberOfFaces, inputSize, inputSize, 3],
);
final inputs = {'img_inputs': inputOrt};
final session = OrtSession.fromAddress(address);
final List<OrtValue?> outputs = session.run(runOptions, inputs);
final embeddings = outputs[0]?.value as List<List<double>>;
for (final embedding in embeddings) {
double normalization = 0;
for (int i = 0; i < kEmbeddingSize; i++) {
normalization += embedding[i] * embedding[i];
}
final double sqrtNormalization = math.sqrt(normalization);
for (int i = 0; i < kEmbeddingSize; i++) {
embedding[i] = embedding[i] / sqrtNormalization;
}
}
return embeddings;
}
}

View file

@ -0,0 +1,379 @@
import "dart:convert";
import "package:photos/services/face_ml/face_clustering/cosine_distance.dart";
import "package:photos/services/face_ml/face_feedback.dart/feedback.dart";
import "package:photos/services/face_ml/face_feedback.dart/feedback_types.dart";
abstract class ClusterFeedback extends Feedback {
static final Map<FeedbackType, Function(String)> fromJsonStringRegistry = {
FeedbackType.deleteClusterFeedback: DeleteClusterFeedback.fromJsonString,
FeedbackType.mergeClusterFeedback: MergeClusterFeedback.fromJsonString,
FeedbackType.renameOrCustomThumbnailClusterFeedback:
RenameOrCustomThumbnailClusterFeedback.fromJsonString,
FeedbackType.removePhotosClusterFeedback:
RemovePhotosClusterFeedback.fromJsonString,
FeedbackType.addPhotosClusterFeedback:
AddPhotosClusterFeedback.fromJsonString,
};
final List<double> medoid;
final double medoidDistanceThreshold;
// TODO: work out the optimal distance threshold so there's never an overlap between clusters
ClusterFeedback(
FeedbackType type,
this.medoid,
this.medoidDistanceThreshold, {
String? feedbackID,
DateTime? timestamp,
int? madeOnFaceMlVersion,
int? madeOnClusterMlVersion,
}) : super(
type,
feedbackID: feedbackID,
timestamp: timestamp,
madeOnFaceMlVersion: madeOnFaceMlVersion,
madeOnClusterMlVersion: madeOnClusterMlVersion,
);
/// Compares this feedback with another [ClusterFeedback] to see if they are similar enough that only one should be kept.
///
/// It checks this by comparing the distance between the two medoids with the medoidDistanceThreshold of each feedback.
///
/// Returns true if they are similar enough, false otherwise.
/// // TODO: Should it maybe return a merged feedback instead, when you are similar enough?
bool looselyMatchesMedoid(ClusterFeedback other) {
// Using the cosineDistance function you mentioned
final double distance = cosineDistance(medoid, other.medoid);
// Check if the distance is less than either of the threshold values
return distance < medoidDistanceThreshold ||
distance < other.medoidDistanceThreshold;
}
bool exactlyMatchesMedoid(ClusterFeedback other) {
if (medoid.length != other.medoid.length) {
return false;
}
for (int i = 0; i < medoid.length; i++) {
if (medoid[i] != other.medoid[i]) {
return false;
}
}
return true;
}
}
class DeleteClusterFeedback extends ClusterFeedback {
DeleteClusterFeedback({
required List<double> medoid,
required double medoidDistanceThreshold,
String? feedbackID,
DateTime? timestamp,
int? madeOnFaceMlVersion,
int? madeOnClusterMlVersion,
}) : super(
FeedbackType.deleteClusterFeedback,
medoid,
medoidDistanceThreshold,
feedbackID: feedbackID,
timestamp: timestamp,
madeOnFaceMlVersion: madeOnFaceMlVersion,
madeOnClusterMlVersion: madeOnClusterMlVersion,
);
@override
Map<String, dynamic> toJson() {
return {
'type': type.toValueString(),
'medoid': medoid,
'medoidDistanceThreshold': medoidDistanceThreshold,
'feedbackID': feedbackID,
'timestamp': timestamp.toIso8601String(),
'madeOnFaceMlVersion': madeOnFaceMlVersion,
'madeOnClusterMlVersion': madeOnClusterMlVersion,
};
}
@override
String toJsonString() => jsonEncode(toJson());
static DeleteClusterFeedback fromJson(Map<String, dynamic> json) {
assert(json['type'] == FeedbackType.deleteClusterFeedback.toValueString());
return DeleteClusterFeedback(
medoid:
(json['medoid'] as List?)?.map((item) => item as double).toList() ??
[],
medoidDistanceThreshold: json['medoidDistanceThreshold'],
feedbackID: json['feedbackID'],
timestamp: DateTime.parse(json['timestamp']),
madeOnFaceMlVersion: json['madeOnFaceMlVersion'],
madeOnClusterMlVersion: json['madeOnClusterMlVersion'],
);
}
static fromJsonString(String jsonString) {
return fromJson(jsonDecode(jsonString));
}
}
class MergeClusterFeedback extends ClusterFeedback {
final List<double> medoidToMoveTo;
MergeClusterFeedback({
required List<double> medoid,
required double medoidDistanceThreshold,
required this.medoidToMoveTo,
String? feedbackID,
DateTime? timestamp,
int? madeOnFaceMlVersion,
int? madeOnClusterMlVersion,
}) : super(
FeedbackType.mergeClusterFeedback,
medoid,
medoidDistanceThreshold,
feedbackID: feedbackID,
timestamp: timestamp,
madeOnFaceMlVersion: madeOnFaceMlVersion,
madeOnClusterMlVersion: madeOnClusterMlVersion,
);
@override
Map<String, dynamic> toJson() {
return {
'type': type.toValueString(),
'medoid': medoid,
'medoidDistanceThreshold': medoidDistanceThreshold,
'medoidToMoveTo': medoidToMoveTo,
'feedbackID': feedbackID,
'timestamp': timestamp.toIso8601String(),
'madeOnFaceMlVersion': madeOnFaceMlVersion,
'madeOnClusterMlVersion': madeOnClusterMlVersion,
};
}
@override
String toJsonString() => jsonEncode(toJson());
static MergeClusterFeedback fromJson(Map<String, dynamic> json) {
assert(json['type'] == FeedbackType.mergeClusterFeedback.toValueString());
return MergeClusterFeedback(
medoid:
(json['medoid'] as List?)?.map((item) => item as double).toList() ??
[],
medoidDistanceThreshold: json['medoidDistanceThreshold'],
medoidToMoveTo: (json['medoidToMoveTo'] as List?)
?.map((item) => item as double)
.toList() ??
[],
feedbackID: json['feedbackID'],
timestamp: DateTime.parse(json['timestamp']),
madeOnFaceMlVersion: json['madeOnFaceMlVersion'],
madeOnClusterMlVersion: json['madeOnClusterMlVersion'],
);
}
static MergeClusterFeedback fromJsonString(String jsonString) {
return fromJson(jsonDecode(jsonString));
}
}
class RenameOrCustomThumbnailClusterFeedback extends ClusterFeedback {
String? customName;
String? customThumbnailFaceId;
RenameOrCustomThumbnailClusterFeedback({
required List<double> medoid,
required double medoidDistanceThreshold,
this.customName,
this.customThumbnailFaceId,
String? feedbackID,
DateTime? timestamp,
int? madeOnFaceMlVersion,
int? madeOnClusterMlVersion,
}) : assert(
customName != null || customThumbnailFaceId != null,
"Either customName or customThumbnailFaceId must be non-null!",
),
super(
FeedbackType.renameOrCustomThumbnailClusterFeedback,
medoid,
medoidDistanceThreshold,
feedbackID: feedbackID,
timestamp: timestamp,
madeOnFaceMlVersion: madeOnFaceMlVersion,
madeOnClusterMlVersion: madeOnClusterMlVersion,
);
@override
Map<String, dynamic> toJson() {
return {
'type': type.toValueString(),
'medoid': medoid,
'medoidDistanceThreshold': medoidDistanceThreshold,
if (customName != null) 'customName': customName,
if (customThumbnailFaceId != null)
'customThumbnailFaceId': customThumbnailFaceId,
'feedbackID': feedbackID,
'timestamp': timestamp.toIso8601String(),
'madeOnFaceMlVersion': madeOnFaceMlVersion,
'madeOnClusterMlVersion': madeOnClusterMlVersion,
};
}
@override
String toJsonString() => jsonEncode(toJson());
static RenameOrCustomThumbnailClusterFeedback fromJson(
Map<String, dynamic> json,
) {
assert(
json['type'] ==
FeedbackType.renameOrCustomThumbnailClusterFeedback.toValueString(),
);
return RenameOrCustomThumbnailClusterFeedback(
medoid:
(json['medoid'] as List?)?.map((item) => item as double).toList() ??
[],
medoidDistanceThreshold: json['medoidDistanceThreshold'],
customName: json['customName'],
customThumbnailFaceId: json['customThumbnailFaceId'],
feedbackID: json['feedbackID'],
timestamp: DateTime.parse(json['timestamp']),
madeOnFaceMlVersion: json['madeOnFaceMlVersion'],
madeOnClusterMlVersion: json['madeOnClusterMlVersion'],
);
}
static RenameOrCustomThumbnailClusterFeedback fromJsonString(
String jsonString,
) {
return fromJson(jsonDecode(jsonString));
}
}
class RemovePhotosClusterFeedback extends ClusterFeedback {
final List<int> removedPhotosFileID;
RemovePhotosClusterFeedback({
required List<double> medoid,
required double medoidDistanceThreshold,
required this.removedPhotosFileID,
String? feedbackID,
DateTime? timestamp,
int? madeOnFaceMlVersion,
int? madeOnClusterMlVersion,
}) : super(
FeedbackType.removePhotosClusterFeedback,
medoid,
medoidDistanceThreshold,
feedbackID: feedbackID,
timestamp: timestamp,
madeOnFaceMlVersion: madeOnFaceMlVersion,
madeOnClusterMlVersion: madeOnClusterMlVersion,
);
@override
Map<String, dynamic> toJson() {
return {
'type': type.toValueString(),
'medoid': medoid,
'medoidDistanceThreshold': medoidDistanceThreshold,
'removedPhotosFileID': removedPhotosFileID,
'feedbackID': feedbackID,
'timestamp': timestamp.toIso8601String(),
'madeOnFaceMlVersion': madeOnFaceMlVersion,
'madeOnClusterMlVersion': madeOnClusterMlVersion,
};
}
@override
String toJsonString() => jsonEncode(toJson());
static RemovePhotosClusterFeedback fromJson(Map<String, dynamic> json) {
assert(
json['type'] == FeedbackType.removePhotosClusterFeedback.toValueString(),
);
return RemovePhotosClusterFeedback(
medoid:
(json['medoid'] as List?)?.map((item) => item as double).toList() ??
[],
medoidDistanceThreshold: json['medoidDistanceThreshold'],
removedPhotosFileID: (json['removedPhotosFileID'] as List?)
?.map((item) => item as int)
.toList() ??
[],
feedbackID: json['feedbackID'],
timestamp: DateTime.parse(json['timestamp']),
madeOnFaceMlVersion: json['madeOnFaceMlVersion'],
madeOnClusterMlVersion: json['madeOnClusterMlVersion'],
);
}
static RemovePhotosClusterFeedback fromJsonString(String jsonString) {
return fromJson(jsonDecode(jsonString));
}
}
class AddPhotosClusterFeedback extends ClusterFeedback {
final List<int> addedPhotoFileIDs;
AddPhotosClusterFeedback({
required List<double> medoid,
required double medoidDistanceThreshold,
required this.addedPhotoFileIDs,
String? feedbackID,
DateTime? timestamp,
int? madeOnFaceMlVersion,
int? madeOnClusterMlVersion,
}) : super(
FeedbackType.addPhotosClusterFeedback,
medoid,
medoidDistanceThreshold,
feedbackID: feedbackID,
timestamp: timestamp,
madeOnFaceMlVersion: madeOnFaceMlVersion,
madeOnClusterMlVersion: madeOnClusterMlVersion,
);
@override
Map<String, dynamic> toJson() {
return {
'type': type.toValueString(),
'medoid': medoid,
'medoidDistanceThreshold': medoidDistanceThreshold,
'addedPhotoFileIDs': addedPhotoFileIDs,
'feedbackID': feedbackID,
'timestamp': timestamp.toIso8601String(),
'madeOnFaceMlVersion': madeOnFaceMlVersion,
'madeOnClusterMlVersion': madeOnClusterMlVersion,
};
}
@override
String toJsonString() => jsonEncode(toJson());
static AddPhotosClusterFeedback fromJson(Map<String, dynamic> json) {
assert(
json['type'] == FeedbackType.addPhotosClusterFeedback.toValueString(),
);
return AddPhotosClusterFeedback(
medoid:
(json['medoid'] as List?)?.map((item) => item as double).toList() ??
[],
medoidDistanceThreshold: json['medoidDistanceThreshold'],
addedPhotoFileIDs: (json['addedPhotoFileIDs'] as List?)
?.map((item) => item as int)
.toList() ??
[],
feedbackID: json['feedbackID'],
timestamp: DateTime.parse(json['timestamp']),
madeOnFaceMlVersion: json['madeOnFaceMlVersion'],
madeOnClusterMlVersion: json['madeOnClusterMlVersion'],
);
}
static AddPhotosClusterFeedback fromJsonString(String jsonString) {
return fromJson(jsonDecode(jsonString));
}
}

View file

@ -0,0 +1,416 @@
import "package:logging/logging.dart";
import "package:photos/db/ml_data_db.dart";
import "package:photos/services/face_ml/face_detection/detection.dart";
import "package:photos/services/face_ml/face_feedback.dart/cluster_feedback.dart";
import "package:photos/services/face_ml/face_ml_result.dart";
class FaceFeedbackService {
final _logger = Logger("FaceFeedbackService");
final _mlDatabase = MlDataDB.instance;
int executedFeedbackCount = 0;
final int _reclusterFeedbackThreshold = 10;
// singleton pattern
FaceFeedbackService._privateConstructor();
static final instance = FaceFeedbackService._privateConstructor();
factory FaceFeedbackService() => instance;
/// Returns the updated cluster after removing the given file from the given person's cluster.
///
/// If the file is not in the cluster, returns null.
///
/// The updated cluster is also updated in [MlDataDB].
Future<ClusterResult> removePhotosFromCluster(
List<int> fileIDs,
int personID,
) async {
// TODO: check if photo was originally added to cluster by user. If so, we should remove that addition instead of changing the embedding, because there is no embedding...
_logger.info(
'removePhotoFromCluster called with fileIDs $fileIDs and personID $personID',
);
if (fileIDs.isEmpty) {
_logger.severe(
"No fileIDs given, unable to add photos to cluster!",
);
throw ArgumentError(
"No fileIDs given, unable to add photos to cluster!",
);
}
// Get the relevant cluster
final ClusterResult? cluster = await _mlDatabase.getClusterResult(personID);
if (cluster == null) {
_logger.severe(
"No cluster found for personID $personID, unable to remove photo from non-existent cluster!",
);
throw ArgumentError(
"No cluster found for personID $personID, unable to remove photo from non-existent cluster!",
);
}
// Get the relevant faceMlResults
final List<FaceMlResult> faceMlResults =
await _mlDatabase.getSelectedFaceMlResults(fileIDs);
if (faceMlResults.length != fileIDs.length) {
final List<int> foundFileIDs =
faceMlResults.map((faceMlResult) => faceMlResult.fileId).toList();
_logger.severe(
"Couldn't find all facemlresults for fileIDs $fileIDs, only found for $foundFileIDs. Unable to remove unindexed photos from cluster!",
);
throw ArgumentError(
"Couldn't find all facemlresults for fileIDs $fileIDs, only found for $foundFileIDs. Unable to remove unindexed photos from cluster!",
);
}
// Check if at least one of the files is in the cluster. If all files are already not in the cluster, return the cluster.
final List<int> fileIDsInCluster = fileIDs
.where((fileID) => cluster.uniqueFileIds.contains(fileID))
.toList();
if (fileIDsInCluster.isEmpty) {
_logger.warning(
"All fileIDs are already not in the cluster, unable to remove photos from cluster!",
);
return cluster;
}
final List<FaceMlResult> faceMlResultsInCluster = faceMlResults
.where((faceMlResult) => fileIDsInCluster.contains(faceMlResult.fileId))
.toList();
assert(faceMlResultsInCluster.length == fileIDsInCluster.length);
for (var i = 0; i < fileIDsInCluster.length; i++) {
// Find the faces/embeddings associated with both the fileID and personID
final List<String> faceIDs = faceMlResultsInCluster[i].allFaceIds;
final List<String> faceIDsInCluster = cluster.faceIDs;
final List<String> relevantFaceIDs =
faceIDsInCluster.where((faceID) => faceIDs.contains(faceID)).toList();
if (relevantFaceIDs.isEmpty) {
_logger.severe(
"No faces found in both cluster and file, unable to remove photo from cluster!",
);
throw ArgumentError(
"No faces found in both cluster and file, unable to remove photo from cluster!",
);
}
// Set the embeddings to [10, 10,..., 10] and save the updated faceMlResult
faceMlResultsInCluster[i].setEmbeddingsToTen(relevantFaceIDs);
await _mlDatabase.updateFaceMlResult(faceMlResultsInCluster[i]);
// Make sure there is a manual override for [10, 10,..., 10] embeddings (not actually here, but in building the clusters, see _checkIfClusterIsDeleted function)
// Manually remove the fileID from the cluster
cluster.removeFileId(fileIDsInCluster[i]);
}
// TODO: see below
// Re-cluster and check if this leads to more deletions. If so, save them and ask the user if they want to delete them too.
executedFeedbackCount++;
if (executedFeedbackCount % _reclusterFeedbackThreshold == 0) {
// await recluster();
}
// Update the cluster in the database
await _mlDatabase.updateClusterResult(cluster);
// TODO: see below
// Safe the given feedback to the database
final removePhotoFeedback = RemovePhotosClusterFeedback(
medoid: cluster.medoid,
medoidDistanceThreshold: cluster.medoidDistanceThreshold,
removedPhotosFileID: fileIDsInCluster,
);
await _mlDatabase.createClusterFeedback(
removePhotoFeedback,
skipIfSimilarFeedbackExists: false,
);
// Return the updated cluster
return cluster;
}
Future<ClusterResult> addPhotosToCluster(List<int> fileIDs, personID) async {
_logger.info(
'addPhotosToCluster called with fileIDs $fileIDs and personID $personID',
);
if (fileIDs.isEmpty) {
_logger.severe(
"No fileIDs given, unable to add photos to cluster!",
);
throw ArgumentError(
"No fileIDs given, unable to add photos to cluster!",
);
}
// Get the relevant cluster
final ClusterResult? cluster = await _mlDatabase.getClusterResult(personID);
if (cluster == null) {
_logger.severe(
"No cluster found for personID $personID, unable to add photos to non-existent cluster!",
);
throw ArgumentError(
"No cluster found for personID $personID, unable to add photos to non-existent cluster!",
);
}
// Check if at least one of the files is not in the cluster. If all files are already in the cluster, return the cluster.
final List<int> fileIDsNotInCluster = fileIDs
.where((fileID) => !cluster.uniqueFileIds.contains(fileID))
.toList();
if (fileIDsNotInCluster.isEmpty) {
_logger.warning(
"All fileIDs are already in the cluster, unable to add new photos to cluster!",
);
return cluster;
}
final List<String> faceIDsNotInCluster = fileIDsNotInCluster
.map((fileID) => FaceDetectionRelative.toFaceIDEmpty(fileID: fileID))
.toList();
// Add the new files to the cluster
cluster.addFileIDsAndFaceIDs(fileIDsNotInCluster, faceIDsNotInCluster);
// Update the cluster in the database
await _mlDatabase.updateClusterResult(cluster);
// Build the addPhotoFeedback
final AddPhotosClusterFeedback addPhotosFeedback = AddPhotosClusterFeedback(
medoid: cluster.medoid,
medoidDistanceThreshold: cluster.medoidDistanceThreshold,
addedPhotoFileIDs: fileIDsNotInCluster,
);
// TODO: check for exact match and update feedback if necessary
// Save the addPhotoFeedback to the database
await _mlDatabase.createClusterFeedback(
addPhotosFeedback,
skipIfSimilarFeedbackExists: false,
);
// Return the updated cluster
return cluster;
}
/// Deletes the given cluster completely.
Future<void> deleteCluster(int personID) async {
_logger.info(
'deleteCluster called with personID $personID',
);
// Get the relevant cluster
final cluster = await _mlDatabase.getClusterResult(personID);
if (cluster == null) {
_logger.severe(
"No cluster found for personID $personID, unable to delete non-existent cluster!",
);
throw ArgumentError(
"No cluster found for personID $personID, unable to delete non-existent cluster!",
);
}
// Delete the cluster from the database
await _mlDatabase.deleteClusterResult(cluster.personId);
// TODO: look into the right threshold distance.
// Build the deleteClusterFeedback
final DeleteClusterFeedback deleteClusterFeedback = DeleteClusterFeedback(
medoid: cluster.medoid,
medoidDistanceThreshold: cluster.medoidDistanceThreshold,
);
// TODO: maybe I should merge the two feedbacks if they are similar enough? Or alternatively, I keep them both?
// Check if feedback doesn't already exist
if (await _mlDatabase
.doesSimilarClusterFeedbackExist(deleteClusterFeedback)) {
_logger.warning(
"Feedback already exists for deleting cluster $personID, unable to delete cluster!",
);
return;
}
// Save the deleteClusterFeedback to the database
await _mlDatabase.createClusterFeedback(deleteClusterFeedback);
}
/// Renames the given cluster and/or sets the thumbnail of the given cluster.
///
/// Requires either a [customName] or a [customFaceID]. If both are given, both are used. If neither are given, an error is thrown.
Future<ClusterResult> renameOrSetThumbnailCluster(
int personID, {
String? customName,
String? customFaceID,
}) async {
_logger.info(
'renameOrSetThumbnailCluster called with personID $personID, customName $customName, and customFaceID $customFaceID',
);
if (customFaceID != null &&
FaceDetectionRelative.isFaceIDEmpty(customFaceID)) {
_logger.severe(
"customFaceID $customFaceID is belongs to empty detection, unable to set as thumbnail of cluster!",
);
customFaceID = null;
}
if (customName == null && customFaceID == null) {
_logger.severe(
"No name or faceID given, unable to rename or set thumbnail of cluster!",
);
throw ArgumentError(
"No name or faceID given, unable to rename or set thumbnail of cluster!",
);
}
// Get the relevant cluster
final cluster = await _mlDatabase.getClusterResult(personID);
if (cluster == null) {
_logger.severe(
"No cluster found for personID $personID, unable to delete non-existent cluster!",
);
throw ArgumentError(
"No cluster found for personID $personID, unable to delete non-existent cluster!",
);
}
// Update the cluster
if (customName != null) cluster.setUserDefinedName = customName;
if (customFaceID != null) cluster.setThumbnailFaceId = customFaceID;
// Update the cluster in the database
await _mlDatabase.updateClusterResult(cluster);
// Build the RenameOrCustomThumbnailClusterFeedback
final RenameOrCustomThumbnailClusterFeedback renameClusterFeedback =
RenameOrCustomThumbnailClusterFeedback(
medoid: cluster.medoid,
medoidDistanceThreshold: cluster.medoidDistanceThreshold,
customName: customName,
customThumbnailFaceId: customFaceID,
);
// TODO: maybe I should merge the two feedbacks if they are similar enough?
// Check if feedback doesn't already exist
final matchingFeedbacks =
await _mlDatabase.getAllMatchingClusterFeedback(renameClusterFeedback);
for (final matchingFeedback in matchingFeedbacks) {
// Update the current feedback wherever possible
renameClusterFeedback.customName ??= matchingFeedback.customName;
renameClusterFeedback.customThumbnailFaceId ??=
matchingFeedback.customThumbnailFaceId;
// Delete the old feedback (since we want the user to be able to overwrite their earlier feedback)
await _mlDatabase.deleteClusterFeedback(matchingFeedback);
}
// Save the RenameOrCustomThumbnailClusterFeedback to the database
await _mlDatabase.createClusterFeedback(renameClusterFeedback);
// Return the updated cluster
return cluster;
}
/// Merges the given clusters. The largest cluster is kept and the other clusters are deleted.
///
/// Requires either a [clusters] or [personIDs]. If both are given, the [clusters] are used.
Future<ClusterResult> mergeClusters(List<int> personIDs) async {
_logger.info(
'mergeClusters called with personIDs $personIDs',
);
// Get the relevant clusters
final List<ClusterResult> clusters =
await _mlDatabase.getSelectedClusterResults(personIDs);
if (clusters.length <= 1) {
_logger.severe(
"${clusters.length} clusters found for personIDs $personIDs, unable to merge non-existent clusters!",
);
throw ArgumentError(
"${clusters.length} clusters found for personIDs $personIDs, unable to merge non-existent clusters!",
);
}
// Find the largest cluster
clusters.sort((a, b) => b.clusterSize.compareTo(a.clusterSize));
final ClusterResult largestCluster = clusters.first;
// Now iterate through the clusters to be merged and deleted
for (var i = 1; i < clusters.length; i++) {
final ClusterResult clusterToBeMerged = clusters[i];
// Add the files and faces of the cluster to be merged to the largest cluster
largestCluster.addFileIDsAndFaceIDs(
clusterToBeMerged.fileIDsIncludingPotentialDuplicates,
clusterToBeMerged.faceIDs,
);
// TODO: maybe I should wrap the logic below in a separate function, since it's also used in renameOrSetThumbnailCluster
// Merge any names and thumbnails if the largest cluster doesn't have them
bool shouldCreateNamingFeedback = false;
String? nameToBeMerged;
String? thumbnailToBeMerged;
if (!largestCluster.hasUserDefinedName &&
clusterToBeMerged.hasUserDefinedName) {
largestCluster.setUserDefinedName = clusterToBeMerged.userDefinedName!;
nameToBeMerged = clusterToBeMerged.userDefinedName!;
shouldCreateNamingFeedback = true;
}
if (!largestCluster.thumbnailFaceIdIsUserDefined &&
clusterToBeMerged.thumbnailFaceIdIsUserDefined) {
largestCluster.setThumbnailFaceId = clusterToBeMerged.thumbnailFaceId;
thumbnailToBeMerged = clusterToBeMerged.thumbnailFaceId;
shouldCreateNamingFeedback = true;
}
if (shouldCreateNamingFeedback) {
final RenameOrCustomThumbnailClusterFeedback renameClusterFeedback =
RenameOrCustomThumbnailClusterFeedback(
medoid: largestCluster.medoid,
medoidDistanceThreshold: largestCluster.medoidDistanceThreshold,
customName: nameToBeMerged,
customThumbnailFaceId: thumbnailToBeMerged,
);
// Check if feedback doesn't already exist
final matchingFeedbacks = await _mlDatabase
.getAllMatchingClusterFeedback(renameClusterFeedback);
for (final matchingFeedback in matchingFeedbacks) {
// Update the current feedback wherever possible
renameClusterFeedback.customName ??= matchingFeedback.customName;
renameClusterFeedback.customThumbnailFaceId ??=
matchingFeedback.customThumbnailFaceId;
// Delete the old feedback (since we want the user to be able to overwrite their earlier feedback)
await _mlDatabase.deleteClusterFeedback(matchingFeedback);
}
// Save the RenameOrCustomThumbnailClusterFeedback to the database
await _mlDatabase.createClusterFeedback(renameClusterFeedback);
}
// Build the mergeClusterFeedback
final MergeClusterFeedback mergeClusterFeedback = MergeClusterFeedback(
medoid: clusterToBeMerged.medoid,
medoidDistanceThreshold: clusterToBeMerged.medoidDistanceThreshold,
medoidToMoveTo: largestCluster.medoid,
);
// Save the mergeClusterFeedback to the database and delete any old matching feedbacks
final matchingFeedbacks =
await _mlDatabase.getAllMatchingClusterFeedback(mergeClusterFeedback);
for (final matchingFeedback in matchingFeedbacks) {
await _mlDatabase.deleteClusterFeedback(matchingFeedback);
}
await _mlDatabase.createClusterFeedback(mergeClusterFeedback);
// Delete the cluster from the database
await _mlDatabase.deleteClusterResult(clusterToBeMerged.personId);
}
// TODO: should I update the medoid of this new cluster? My intuition says no, but I'm not sure.
// Update the largest cluster in the database
await _mlDatabase.updateClusterResult(largestCluster);
// Return the merged cluster
return largestCluster;
}
}

View file

@ -0,0 +1,34 @@
import "package:photos/models/ml/ml_versions.dart";
import "package:photos/services/face_ml/face_feedback.dart/feedback_types.dart";
import "package:uuid/uuid.dart";
abstract class Feedback {
final FeedbackType type;
final String feedbackID;
final DateTime timestamp;
final int madeOnFaceMlVersion;
final int madeOnClusterMlVersion;
get typeString => type.toValueString();
get timestampString => timestamp.toIso8601String();
Feedback(
this.type, {
String? feedbackID,
DateTime? timestamp,
int? madeOnFaceMlVersion,
int? madeOnClusterMlVersion,
}) : feedbackID = feedbackID ?? const Uuid().v4(),
timestamp = timestamp ?? DateTime.now(),
madeOnFaceMlVersion = madeOnFaceMlVersion ?? faceMlVersion,
madeOnClusterMlVersion = madeOnClusterMlVersion ?? clusterMlVersion;
Map<String, dynamic> toJson();
String toJsonString();
// Feedback fromJson(Map<String, dynamic> json);
// Feedback fromJsonString(String jsonString);
}

View file

@ -0,0 +1,26 @@
enum FeedbackType {
removePhotosClusterFeedback,
addPhotosClusterFeedback,
deleteClusterFeedback,
mergeClusterFeedback,
renameOrCustomThumbnailClusterFeedback; // I have merged renameClusterFeedback and customThumbnailClusterFeedback, since I suspect they will be used together often
factory FeedbackType.fromValueString(String value) {
switch (value) {
case 'deleteClusterFeedback':
return FeedbackType.deleteClusterFeedback;
case 'mergeClusterFeedback':
return FeedbackType.mergeClusterFeedback;
case 'renameOrCustomThumbnailClusterFeedback':
return FeedbackType.renameOrCustomThumbnailClusterFeedback;
case 'removePhotoClusterFeedback':
return FeedbackType.removePhotosClusterFeedback;
case 'addPhotoClusterFeedback':
return FeedbackType.addPhotosClusterFeedback;
default:
throw Exception('Invalid FeedbackType: $value');
}
}
String toValueString() => name;
}

View file

@ -0,0 +1,30 @@
class GeneralFaceMlException implements Exception {
final String message;
GeneralFaceMlException(this.message);
@override
String toString() => 'GeneralFaceMlException: $message';
}
class CouldNotRetrieveAnyFileData implements Exception {}
class CouldNotInitializeFaceDetector implements Exception {}
class CouldNotRunFaceDetector implements Exception {}
class CouldNotWarpAffine implements Exception {}
class CouldNotInitializeFaceEmbeddor implements Exception {}
class InputProblemFaceEmbeddor implements Exception {
final String message;
InputProblemFaceEmbeddor(this.message);
@override
String toString() => 'InputProblemFaceEmbeddor: $message';
}
class CouldNotRunFaceEmbeddor implements Exception {}

View file

@ -0,0 +1,90 @@
import "package:photos/services/face_ml/face_ml_version.dart";
/// Represents a face detection method with a specific version.
class FaceDetectionMethod extends VersionedMethod {
/// Creates a [FaceDetectionMethod] instance with a specific `method` and `version` (default `1`)
FaceDetectionMethod(String method, {int version = 1})
: super(method, version);
/// Creates a [FaceDetectionMethod] instance with 'Empty method' as the method, and a specific `version` (default `1`)
const FaceDetectionMethod.empty() : super.empty();
/// Creates a [FaceDetectionMethod] instance with 'BlazeFace' as the method, and a specific `version` (default `1`)
FaceDetectionMethod.blazeFace({int version = 1})
: super('BlazeFace', version);
static FaceDetectionMethod fromMlVersion(int version) {
switch (version) {
case 1:
return FaceDetectionMethod.blazeFace(version: version);
default:
return const FaceDetectionMethod.empty();
}
}
static FaceDetectionMethod fromJson(Map<String, dynamic> json) {
return FaceDetectionMethod(
json['method'],
version: json['version'],
);
}
}
/// Represents a face alignment method with a specific version.
class FaceAlignmentMethod extends VersionedMethod {
/// Creates a [FaceAlignmentMethod] instance with a specific `method` and `version` (default `1`)
FaceAlignmentMethod(String method, {int version = 1})
: super(method, version);
/// Creates a [FaceAlignmentMethod] instance with 'Empty method' as the method, and a specific `version` (default `1`)
const FaceAlignmentMethod.empty() : super.empty();
/// Creates a [FaceAlignmentMethod] instance with 'ArcFace' as the method, and a specific `version` (default `1`)
FaceAlignmentMethod.arcFace({int version = 1}) : super('ArcFace', version);
static FaceAlignmentMethod fromMlVersion(int version) {
switch (version) {
case 1:
return FaceAlignmentMethod.arcFace(version: version);
default:
return const FaceAlignmentMethod.empty();
}
}
static FaceAlignmentMethod fromJson(Map<String, dynamic> json) {
return FaceAlignmentMethod(
json['method'],
version: json['version'],
);
}
}
/// Represents a face embedding method with a specific version.
class FaceEmbeddingMethod extends VersionedMethod {
/// Creates a [FaceEmbeddingMethod] instance with a specific `method` and `version` (default `1`)
FaceEmbeddingMethod(String method, {int version = 1})
: super(method, version);
/// Creates a [FaceEmbeddingMethod] instance with 'Empty method' as the method, and a specific `version` (default `1`)
const FaceEmbeddingMethod.empty() : super.empty();
/// Creates a [FaceEmbeddingMethod] instance with 'MobileFaceNet' as the method, and a specific `version` (default `1`)
FaceEmbeddingMethod.mobileFaceNet({int version = 1})
: super('MobileFaceNet', version);
static FaceEmbeddingMethod fromMlVersion(int version) {
switch (version) {
case 1:
return FaceEmbeddingMethod.mobileFaceNet(version: version);
default:
return const FaceEmbeddingMethod.empty();
}
}
static FaceEmbeddingMethod fromJson(Map<String, dynamic> json) {
return FaceEmbeddingMethod(
json['method'],
version: json['version'],
);
}
}

View file

@ -0,0 +1,753 @@
import "dart:convert" show jsonEncode, jsonDecode;
import "package:flutter/material.dart" show Size, debugPrint, immutable;
import "package:logging/logging.dart";
import "package:photos/db/ml_data_db.dart";
import "package:photos/models/file/file.dart";
import 'package:photos/models/ml/ml_typedefs.dart';
import "package:photos/models/ml/ml_versions.dart";
import "package:photos/services/face_ml/blur_detection/blur_constants.dart";
import "package:photos/services/face_ml/face_alignment/alignment_result.dart";
import "package:photos/services/face_ml/face_clustering/cosine_distance.dart";
import "package:photos/services/face_ml/face_detection/detection.dart";
import "package:photos/services/face_ml/face_feedback.dart/cluster_feedback.dart";
import "package:photos/services/face_ml/face_ml_methods.dart";
final _logger = Logger('ClusterResult_FaceMlResult');
// TODO: should I add [faceMlVersion] and [clusterMlVersion] to the [ClusterResult] class?
class ClusterResult {
final int personId;
String? userDefinedName;
bool get hasUserDefinedName => userDefinedName != null;
String _thumbnailFaceId;
bool thumbnailFaceIdIsUserDefined;
final List<int> _fileIds;
final List<String> _faceIds;
final Embedding medoid;
double medoidDistanceThreshold;
List<int> get uniqueFileIds => _fileIds.toSet().toList();
List<int> get fileIDsIncludingPotentialDuplicates => _fileIds;
List<String> get faceIDs => _faceIds;
String get thumbnailFaceId => _thumbnailFaceId;
int get thumbnailFileId => _getFileIdFromFaceId(_thumbnailFaceId);
/// Sets the thumbnail faceId to the given faceId.
/// Throws an exception if the faceId is not in the list of faceIds.
set setThumbnailFaceId(String faceId) {
if (!_faceIds.contains(faceId)) {
throw Exception(
"The faceId $faceId is not in the list of faceIds: $faceId",
);
}
_thumbnailFaceId = faceId;
thumbnailFaceIdIsUserDefined = true;
}
/// Sets the [userDefinedName] to the given [customName]
set setUserDefinedName(String customName) {
userDefinedName = customName;
}
int get clusterSize => _fileIds.toSet().length;
ClusterResult({
required this.personId,
required String thumbnailFaceId,
required List<int> fileIds,
required List<String> faceIds,
required this.medoid,
required this.medoidDistanceThreshold,
this.userDefinedName,
this.thumbnailFaceIdIsUserDefined = false,
}) : _thumbnailFaceId = thumbnailFaceId,
_faceIds = faceIds,
_fileIds = fileIds;
void addFileIDsAndFaceIDs(List<int> fileIDs, List<String> faceIDs) {
assert(fileIDs.length == faceIDs.length);
_fileIds.addAll(fileIDs);
_faceIds.addAll(faceIDs);
}
// TODO: Consider if we should recalculated the medoid and threshold when deleting or adding a file from the cluster
int removeFileId(int fileId) {
assert(_fileIds.length == _faceIds.length);
if (!_fileIds.contains(fileId)) {
throw Exception(
"The fileId $fileId is not in the list of fileIds: $fileId, so it's not in the cluster and cannot be removed.",
);
}
int removedCount = 0;
for (var i = 0; i < _fileIds.length; i++) {
if (_fileIds[i] == fileId) {
assert(_getFileIdFromFaceId(_faceIds[i]) == fileId);
_fileIds.removeAt(i);
_faceIds.removeAt(i);
debugPrint(
"Removed fileId $fileId from cluster $personId at index ${i + removedCount}}",
);
i--; // Adjust index due to removal
removedCount++;
}
}
_ensureClusterSizeIsAboveMinimum();
return removedCount;
}
int addFileID(int fileID) {
assert(_fileIds.length == _faceIds.length);
if (_fileIds.contains(fileID)) {
return 0;
}
_fileIds.add(fileID);
_faceIds.add(FaceDetectionRelative.toFaceIDEmpty(fileID: fileID));
return 1;
}
void ensureThumbnailFaceIdIsInCluster() {
if (!_faceIds.contains(_thumbnailFaceId)) {
_thumbnailFaceId = _faceIds[0];
}
}
void _ensureClusterSizeIsAboveMinimum() {
if (clusterSize < minimumClusterSize) {
throw Exception(
"Cluster size is below minimum cluster size of $minimumClusterSize",
);
}
}
Map<String, dynamic> _toJson() => {
'personId': personId,
'thumbnailFaceId': _thumbnailFaceId,
'fileIds': _fileIds,
'faceIds': _faceIds,
'medoid': medoid,
'medoidDistanceThreshold': medoidDistanceThreshold,
if (userDefinedName != null) 'userDefinedName': userDefinedName,
'thumbnailFaceIdIsUserDefined': thumbnailFaceIdIsUserDefined,
};
String toJsonString() => jsonEncode(_toJson());
static ClusterResult _fromJson(Map<String, dynamic> json) {
return ClusterResult(
personId: json['personId'] ?? -1,
thumbnailFaceId: json['thumbnailFaceId'] ?? '',
fileIds:
(json['fileIds'] as List?)?.map((item) => item as int).toList() ?? [],
faceIds:
(json['faceIds'] as List?)?.map((item) => item as String).toList() ??
[],
medoid:
(json['medoid'] as List?)?.map((item) => item as double).toList() ??
[],
medoidDistanceThreshold: json['medoidDistanceThreshold'] ?? 0,
userDefinedName: json['userDefinedName'],
thumbnailFaceIdIsUserDefined:
json['thumbnailFaceIdIsUserDefined'] as bool,
);
}
static ClusterResult fromJsonString(String jsonString) {
return _fromJson(jsonDecode(jsonString));
}
}
class ClusterResultBuilder {
int personId = -1;
String? userDefinedName;
String thumbnailFaceId = '';
bool thumbnailFaceIdIsUserDefined = false;
List<int> fileIds = <int>[];
List<String> faceIds = <String>[];
List<Embedding> embeddings = <Embedding>[];
Embedding medoid = <double>[];
double medoidDistanceThreshold = 0;
bool medoidAndThresholdCalculated = false;
final int k = 5;
ClusterResultBuilder.createFromIndices({
required List<int> clusterIndices,
required List<int> labels,
required List<Embedding> allEmbeddings,
required List<int> allFileIds,
required List<String> allFaceIds,
}) {
final clusteredFileIds =
clusterIndices.map((fileIndex) => allFileIds[fileIndex]).toList();
final clusteredFaceIds =
clusterIndices.map((fileIndex) => allFaceIds[fileIndex]).toList();
final clusteredEmbeddings =
clusterIndices.map((fileIndex) => allEmbeddings[fileIndex]).toList();
personId = labels[clusterIndices[0]];
fileIds = clusteredFileIds;
faceIds = clusteredFaceIds;
thumbnailFaceId = faceIds[0];
embeddings = clusteredEmbeddings;
}
void calculateAndSetMedoidAndThreshold() {
if (embeddings.isEmpty) {
throw Exception("Cannot calculate medoid and threshold for empty list");
}
// Calculate the medoid and threshold
final (tempMedoid, distanceThreshold) =
_calculateMedoidAndDistanceTreshold(embeddings);
// Update the medoid
medoid = List.from(tempMedoid);
// Update the medoidDistanceThreshold as the distance of the medoid to its k-th nearest neighbor
medoidDistanceThreshold = distanceThreshold;
medoidAndThresholdCalculated = true;
}
(List<double>, double) _calculateMedoidAndDistanceTreshold(
List<List<double>> embeddings,
) {
double minDistance = double.infinity;
List<double>? medoid;
// Calculate the distance between all pairs
for (int i = 0; i < embeddings.length; ++i) {
double totalDistance = 0;
for (int j = 0; j < embeddings.length; ++j) {
if (i != j) {
totalDistance += cosineDistance(embeddings[i], embeddings[j]);
// Break early if we already exceed minDistance
if (totalDistance > minDistance) {
break;
}
}
}
// Find the minimum total distance
if (totalDistance < minDistance) {
minDistance = totalDistance;
medoid = embeddings[i];
}
}
// Now, calculate k-th nearest neighbor for the medoid
final List<double> distancesToMedoid = [];
for (List<double> embedding in embeddings) {
if (embedding != medoid) {
distancesToMedoid.add(cosineDistance(medoid!, embedding));
}
}
distancesToMedoid.sort();
// TODO: empirically find the best k. Probably it should be dynamic in some way, so for instance larger for larger clusters and smaller for smaller clusters, especially since there are a lot of really small clusters and a few really large ones.
final double kthDistance = distancesToMedoid[
distancesToMedoid.length >= k ? k - 1 : distancesToMedoid.length - 1];
return (medoid!, kthDistance);
}
Future<bool> _checkIfClusterIsDeleted() async {
assert(medoidAndThresholdCalculated);
// Check if the medoid is the default medoid for deleted faces
if (cosineDistance(medoid, List.filled(medoid.length, 10.0)) < 0.001) {
return true;
}
final tempFeedback = DeleteClusterFeedback(
medoid: medoid,
medoidDistanceThreshold: medoidDistanceThreshold,
);
return await MlDataDB.instance
.doesSimilarClusterFeedbackExist(tempFeedback);
}
Future<void> _checkAndAddPhotos() async {
assert(medoidAndThresholdCalculated);
final tempFeedback = AddPhotosClusterFeedback(
medoid: medoid,
medoidDistanceThreshold: medoidDistanceThreshold,
addedPhotoFileIDs: [],
);
final allAddPhotosFeedbacks =
await MlDataDB.instance.getAllMatchingClusterFeedback(tempFeedback);
for (final addPhotosFeedback in allAddPhotosFeedbacks) {
final fileIDsToAdd = addPhotosFeedback.addedPhotoFileIDs;
final faceIDsToAdd = fileIDsToAdd
.map((fileID) => FaceDetectionRelative.toFaceIDEmpty(fileID: fileID))
.toList();
addFileIDsAndFaceIDs(fileIDsToAdd, faceIDsToAdd);
}
}
Future<void> _checkAndAddCustomName() async {
assert(medoidAndThresholdCalculated);
final tempFeedback = RenameOrCustomThumbnailClusterFeedback(
medoid: medoid,
medoidDistanceThreshold: medoidDistanceThreshold,
customName: 'test',
);
final allRenameFeedbacks =
await MlDataDB.instance.getAllMatchingClusterFeedback(tempFeedback);
for (final nameFeedback in allRenameFeedbacks) {
userDefinedName ??= nameFeedback.customName;
if (!thumbnailFaceIdIsUserDefined) {
thumbnailFaceId = nameFeedback.customThumbnailFaceId ?? thumbnailFaceId;
thumbnailFaceIdIsUserDefined =
nameFeedback.customThumbnailFaceId != null;
}
}
return;
}
void changeThumbnailFaceId(String faceId) {
if (!faceIds.contains(faceId)) {
throw Exception(
"The faceId $faceId is not in the list of faceIds: $faceIds",
);
}
thumbnailFaceId = faceId;
}
void addFileIDsAndFaceIDs(List<int> addedFileIDs, List<String> addedFaceIDs) {
assert(addedFileIDs.length == addedFaceIDs.length);
fileIds.addAll(addedFileIDs);
faceIds.addAll(addedFaceIDs);
}
static Future<List<ClusterResult>> buildClusters(
List<ClusterResultBuilder> clusterBuilders,
) async {
final List<int> deletedClusterIndices = [];
for (var i = 0; i < clusterBuilders.length; i++) {
final clusterBuilder = clusterBuilders[i];
clusterBuilder.calculateAndSetMedoidAndThreshold();
// Check if the cluster has been deleted
if (await clusterBuilder._checkIfClusterIsDeleted()) {
deletedClusterIndices.add(i);
}
await clusterBuilder._checkAndAddPhotos();
}
// Check if a cluster should be merged with another cluster
for (var i = 0; i < clusterBuilders.length; i++) {
// Don't check for clusters that have been deleted
if (deletedClusterIndices.contains(i)) {
continue;
}
final clusterBuilder = clusterBuilders[i];
final List<MergeClusterFeedback> allMatchingMergeFeedback =
await MlDataDB.instance.getAllMatchingClusterFeedback(
MergeClusterFeedback(
medoid: clusterBuilder.medoid,
medoidDistanceThreshold: clusterBuilder.medoidDistanceThreshold,
medoidToMoveTo: clusterBuilder.medoid,
),
);
if (allMatchingMergeFeedback.isEmpty) {
continue;
}
// Merge the cluster with the first merge feedback
final mainFeedback = allMatchingMergeFeedback.first;
if (allMatchingMergeFeedback.length > 1) {
// This is the BUG!!!!
_logger.warning(
"There are ${allMatchingMergeFeedback.length} merge feedbacks for cluster ${clusterBuilder.personId}. Using the first one.",
);
}
for (var j = 0; j < clusterBuilders.length; j++) {
if (i == j) continue;
final clusterBuilderToMergeTo = clusterBuilders[j];
final distance = cosineDistance(
// BUG: it hasn't calculated the medoid for every clusterBuilder yet!!!
mainFeedback.medoidToMoveTo,
clusterBuilderToMergeTo.medoid,
);
if (distance < mainFeedback.medoidDistanceThreshold ||
distance < clusterBuilderToMergeTo.medoidDistanceThreshold) {
clusterBuilderToMergeTo.addFileIDsAndFaceIDs(
clusterBuilder.fileIds,
clusterBuilder.faceIds,
);
deletedClusterIndices.add(i);
}
}
}
final clusterResults = <ClusterResult>[];
for (var i = 0; i < clusterBuilders.length; i++) {
// Don't build the cluster if it has been deleted or merged
if (deletedClusterIndices.contains(i)) {
continue;
}
final clusterBuilder = clusterBuilders[i];
// Check if the cluster has a custom name or thumbnail
await clusterBuilder._checkAndAddCustomName();
// Build the clusterResult
clusterResults.add(
ClusterResult(
personId: clusterBuilder.personId,
thumbnailFaceId: clusterBuilder.thumbnailFaceId,
fileIds: clusterBuilder.fileIds,
faceIds: clusterBuilder.faceIds,
medoid: clusterBuilder.medoid,
medoidDistanceThreshold: clusterBuilder.medoidDistanceThreshold,
userDefinedName: clusterBuilder.userDefinedName,
thumbnailFaceIdIsUserDefined:
clusterBuilder.thumbnailFaceIdIsUserDefined,
),
);
}
return clusterResults;
}
// TODO: This function should include the feedback from the user. Should also be nullable, since user might want to delete the cluster.
Future<ClusterResult?> _buildSingleCluster() async {
calculateAndSetMedoidAndThreshold();
if (await _checkIfClusterIsDeleted()) {
return null;
}
await _checkAndAddCustomName();
return ClusterResult(
personId: personId,
thumbnailFaceId: thumbnailFaceId,
fileIds: fileIds,
faceIds: faceIds,
medoid: medoid,
medoidDistanceThreshold: medoidDistanceThreshold,
);
}
}
@immutable
class FaceMlResult {
final int fileId;
final List<FaceResult> faces;
final Size? faceDetectionImageSize;
final Size? faceAlignmentImageSize;
final int mlVersion;
final bool errorOccured;
final bool onlyThumbnailUsed;
bool get hasFaces => faces.isNotEmpty;
int get numberOfFaces => faces.length;
List<Embedding> get allFaceEmbeddings {
return faces.map((face) => face.embedding).toList();
}
List<String> get allFaceIds {
return faces.map((face) => face.faceId).toList();
}
List<int> get fileIdForEveryFace {
return List<int>.filled(faces.length, fileId);
}
FaceDetectionMethod get faceDetectionMethod =>
FaceDetectionMethod.fromMlVersion(mlVersion);
FaceAlignmentMethod get faceAlignmentMethod =>
FaceAlignmentMethod.fromMlVersion(mlVersion);
FaceEmbeddingMethod get faceEmbeddingMethod =>
FaceEmbeddingMethod.fromMlVersion(mlVersion);
const FaceMlResult({
required this.fileId,
required this.faces,
required this.mlVersion,
required this.errorOccured,
required this.onlyThumbnailUsed,
required this.faceDetectionImageSize,
this.faceAlignmentImageSize,
});
Map<String, dynamic> _toJson() => {
'fileId': fileId,
'faces': faces.map((face) => face.toJson()).toList(),
'mlVersion': mlVersion,
'errorOccured': errorOccured,
'onlyThumbnailUsed': onlyThumbnailUsed,
if (faceDetectionImageSize != null)
'faceDetectionImageSize': {
'width': faceDetectionImageSize!.width,
'height': faceDetectionImageSize!.height,
},
if (faceAlignmentImageSize != null)
'faceAlignmentImageSize': {
'width': faceAlignmentImageSize!.width,
'height': faceAlignmentImageSize!.height,
},
};
String toJsonString() => jsonEncode(_toJson());
static FaceMlResult _fromJson(Map<String, dynamic> json) {
return FaceMlResult(
fileId: json['fileId'],
faces: (json['faces'] as List)
.map((item) => FaceResult.fromJson(item as Map<String, dynamic>))
.toList(),
mlVersion: json['mlVersion'],
errorOccured: json['errorOccured'] ?? false,
onlyThumbnailUsed: json['onlyThumbnailUsed'] ?? false,
faceDetectionImageSize: json['faceDetectionImageSize'] == null
? null
: Size(
json['faceDetectionImageSize']['width'],
json['faceDetectionImageSize']['height'],
),
faceAlignmentImageSize: json['faceAlignmentImageSize'] == null
? null
: Size(
json['faceAlignmentImageSize']['width'],
json['faceAlignmentImageSize']['height'],
),
);
}
static FaceMlResult fromJsonString(String jsonString) {
return _fromJson(jsonDecode(jsonString));
}
/// Sets the embeddings of the faces with the given faceIds to [10, 10,..., 10].
///
/// Throws an exception if a faceId is not found in the FaceMlResult.
void setEmbeddingsToTen(List<String> faceIds) {
for (final faceId in faceIds) {
final faceIndex = faces.indexWhere((face) => face.faceId == faceId);
if (faceIndex == -1) {
throw Exception("No face found with faceId $faceId");
}
for (var i = 0; i < faces[faceIndex].embedding.length; i++) {
faces[faceIndex].embedding[i] = 10;
}
}
}
FaceDetectionRelative getDetectionForFaceId(String faceId) {
final faceIndex = faces.indexWhere((face) => face.faceId == faceId);
if (faceIndex == -1) {
throw Exception("No face found with faceId $faceId");
}
return faces[faceIndex].detection;
}
}
class FaceMlResultBuilder {
int fileId;
List<FaceResultBuilder> faces = <FaceResultBuilder>[];
Size? faceDetectionImageSize;
Size? faceAlignmentImageSize;
int mlVersion;
bool errorOccured;
bool onlyThumbnailUsed;
FaceMlResultBuilder({
this.fileId = -1,
this.mlVersion = faceMlVersion,
this.errorOccured = false,
this.onlyThumbnailUsed = false,
});
FaceMlResultBuilder.fromEnteFile(
EnteFile file, {
this.mlVersion = faceMlVersion,
this.errorOccured = false,
this.onlyThumbnailUsed = false,
}) : fileId = file.uploadedFileID ?? -1;
FaceMlResultBuilder.fromEnteFileID(
int fileID, {
this.mlVersion = faceMlVersion,
this.errorOccured = false,
this.onlyThumbnailUsed = false,
}) : fileId = fileID;
void addNewlyDetectedFaces(
List<FaceDetectionRelative> faceDetections,
Size originalSize,
) {
faceDetectionImageSize = originalSize;
for (var i = 0; i < faceDetections.length; i++) {
faces.add(
FaceResultBuilder.fromFaceDetection(
faceDetections[i],
resultBuilder: this,
),
);
}
}
void addAlignmentResults(
List<AlignmentResult> alignmentResults,
List<double> blurValues,
Size imageSizeUsedForAlignment,
) {
if (alignmentResults.length != faces.length) {
throw Exception(
"The amount of alignment results (${alignmentResults.length}) does not match the number of faces (${faces.length})",
);
}
for (var i = 0; i < alignmentResults.length; i++) {
faces[i].alignment = alignmentResults[i];
faces[i].blurValue = blurValues[i];
}
faceAlignmentImageSize = imageSizeUsedForAlignment;
}
void addEmbeddingsToExistingFaces(
List<Embedding> embeddings,
) {
if (embeddings.length != faces.length) {
throw Exception(
"The amount of embeddings (${embeddings.length}) does not match the number of faces (${faces.length})",
);
}
for (var faceIndex = 0; faceIndex < faces.length; faceIndex++) {
faces[faceIndex].embedding = embeddings[faceIndex];
}
}
FaceMlResult build() {
final faceResults = <FaceResult>[];
for (var i = 0; i < faces.length; i++) {
faceResults.add(faces[i].build());
}
return FaceMlResult(
fileId: fileId,
faces: faceResults,
mlVersion: mlVersion,
errorOccured: errorOccured,
onlyThumbnailUsed: onlyThumbnailUsed,
faceDetectionImageSize: faceDetectionImageSize,
faceAlignmentImageSize: faceAlignmentImageSize,
);
}
FaceMlResult buildNoFaceDetected() {
faces = <FaceResultBuilder>[];
return build();
}
FaceMlResult buildErrorOccurred() {
faces = <FaceResultBuilder>[];
errorOccured = true;
return build();
}
}
@immutable
class FaceResult {
final FaceDetectionRelative detection;
final double blurValue;
final AlignmentResult alignment;
final Embedding embedding;
final int fileId;
final String faceId;
bool get isBlurry => blurValue < kLaplacianThreshold;
const FaceResult({
required this.detection,
required this.blurValue,
required this.alignment,
required this.embedding,
required this.fileId,
required this.faceId,
});
Map<String, dynamic> toJson() => {
'detection': detection.toJson(),
'blurValue': blurValue,
'alignment': alignment.toJson(),
'embedding': embedding,
'fileId': fileId,
'faceId': faceId,
};
static FaceResult fromJson(Map<String, dynamic> json) {
return FaceResult(
detection: FaceDetectionRelative.fromJson(json['detection']),
blurValue: json['blurValue'],
alignment: AlignmentResult.fromJson(json['alignment']),
embedding: Embedding.from(json['embedding']),
fileId: json['fileId'],
faceId: json['faceId'],
);
}
}
class FaceResultBuilder {
FaceDetectionRelative detection =
FaceDetectionRelative.defaultInitialization();
double blurValue = 1000;
AlignmentResult alignment = AlignmentResult.empty();
Embedding embedding = <double>[];
int fileId = -1;
String faceId = '';
bool get isBlurry => blurValue < kLaplacianThreshold;
FaceResultBuilder({
required this.fileId,
required this.faceId,
});
FaceResultBuilder.fromFaceDetection(
FaceDetectionRelative faceDetection, {
required FaceMlResultBuilder resultBuilder,
}) {
fileId = resultBuilder.fileId;
faceId = faceDetection.toFaceID(fileID: resultBuilder.fileId);
detection = faceDetection;
}
FaceResult build() {
assert(detection.allKeypoints[0][0] <= 1);
assert(detection.box[0] <= 1);
return FaceResult(
detection: detection,
blurValue: blurValue,
alignment: alignment,
embedding: embedding,
fileId: fileId,
faceId: faceId,
);
}
}
int _getFileIdFromFaceId(String faceId) {
return int.parse(faceId.split("_")[0]);
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,15 @@
abstract class VersionedMethod {
final String method;
final int version;
VersionedMethod(this.method, [this.version = 0]);
const VersionedMethod.empty()
: method = 'Empty method',
version = 0;
Map<String, dynamic> toJson() => {
'method': method,
'version': version,
};
}

View file

@ -0,0 +1,120 @@
import "dart:typed_data";
import "package:logging/logging.dart";
import "package:photos/db/files_db.dart";
import "package:photos/db/ml_data_db.dart";
import "package:photos/models/file/file.dart";
import 'package:photos/utils/image_ml_isolate.dart';
import "package:photos/utils/thumbnail_util.dart";
class FaceSearchService {
final _logger = Logger("FaceSearchService");
final _mlDatabase = MlDataDB.instance;
final _filesDatabase = FilesDB.instance;
// singleton pattern
FaceSearchService._privateConstructor();
static final instance = FaceSearchService._privateConstructor();
factory FaceSearchService() => instance;
/// Returns the personIDs of all clustered people in the database.
Future<List<int>> getAllPeople() async {
final peopleIds = await _mlDatabase.getAllClusterIds();
return peopleIds;
}
/// Returns the thumbnail associated with a given personId.
Future<Uint8List?> getPersonThumbnail(int personID) async {
// get the cluster associated with the personID
final cluster = await _mlDatabase.getClusterResult(personID);
if (cluster == null) {
_logger.warning(
"No cluster found for personID $personID, unable to get thumbnail.",
);
return null;
}
// get the faceID and fileID you want to use to generate the thumbnail
final String thumbnailFaceID = cluster.thumbnailFaceId;
final int thumbnailFileID = cluster.thumbnailFileId;
// get the full file thumbnail
final EnteFile enteFile = await _filesDatabase
.getFilesFromIDs([thumbnailFileID]).then((value) => value.values.first);
final Uint8List? fileThumbnail = await getThumbnail(enteFile);
if (fileThumbnail == null) {
_logger.warning(
"No full file thumbnail found for thumbnail faceID $thumbnailFaceID, unable to get thumbnail.",
);
return null;
}
// get the face detection for the thumbnail
final thumbnailMlResult =
await _mlDatabase.getFaceMlResult(thumbnailFileID);
if (thumbnailMlResult == null) {
_logger.warning(
"No face ml result found for thumbnail faceID $thumbnailFaceID, unable to get thumbnail.",
);
return null;
}
final detection = thumbnailMlResult.getDetectionForFaceId(thumbnailFaceID);
// create the thumbnail from the full file thumbnail and the face detection
Uint8List faceThumbnail;
try {
faceThumbnail = await ImageMlIsolate.instance.generateFaceThumbnail(
fileThumbnail,
detection,
);
} catch (e, s) {
_logger.warning(
"Unable to generate face thumbnail for thumbnail faceID $thumbnailFaceID, unable to get thumbnail.",
e,
s,
);
return null;
}
return faceThumbnail;
}
/// Returns all files associated with a given personId.
Future<List<EnteFile>> getFilesForPerson(int personID) async {
final fileIDs = await _mlDatabase.getClusterFileIds(personID);
final Map<int, EnteFile> files =
await _filesDatabase.getFilesFromIDs(fileIDs);
return files.values.toList();
}
Future<List<EnteFile>> getFilesForIntersectOfPeople(
List<int> personIDs,
) async {
if (personIDs.length <= 1) {
_logger
.warning('Cannot get intersection of files for less than 2 people');
return <EnteFile>[];
}
final Set<int> fileIDsFirstCluster = await _mlDatabase
.getClusterFileIds(personIDs.first)
.then((value) => value.toSet());
for (final personID in personIDs.sublist(1)) {
final fileIDsSingleCluster =
await _mlDatabase.getClusterFileIds(personID);
fileIDsFirstCluster.retainAll(fileIDsSingleCluster);
// Early termination if intersection is empty
if (fileIDsFirstCluster.isEmpty) {
return <EnteFile>[];
}
}
final Map<int, EnteFile> files =
await _filesDatabase.getFilesFromIDs(fileIDsFirstCluster.toList());
return files.values.toList();
}
}

View file

@ -0,0 +1,464 @@
import 'dart:developer' as dev;
import "dart:math" show Random;
import "dart:typed_data";
import "package:flutter/foundation.dart";
import "package:logging/logging.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/events/people_changed_event.dart";
import "package:photos/extensions/stop_watch.dart";
import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import "package:photos/models/file/file.dart";
import "package:photos/services/face_ml/face_clustering/cosine_distance.dart";
import "package:photos/services/search_service.dart";
class ClusterFeedbackService {
final Logger _logger = Logger("ClusterFeedbackService");
ClusterFeedbackService._privateConstructor();
static final ClusterFeedbackService instance =
ClusterFeedbackService._privateConstructor();
/// Returns a map of person's clusterID to map of closest clusterID to with disstance
Future<Map<int, List<(int, double)>>> getSuggestionsUsingMean(
Person p, {
double maxClusterDistance = 0.4,
}) async {
// Get all the cluster data
final faceMlDb = FaceMLDataDB.instance;
final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount());
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
dev.log(
'existing clusters for ${p.attr.name} are $personClusters',
name: "ClusterFeedbackService",
);
// Get and update the cluster summary to get the avg (centroid) and count
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
allClusterIdsToCountMap,
ignoredClusters,
);
watch.log('computed avg for ${clusterAvg.length} clusters');
// Find the actual closest clusters for the person
final Map<int, List<(int, double)>> suggestions = _calcSuggestionsMean(
clusterAvg,
personClusters,
ignoredClusters,
maxClusterDistance,
);
// log suggestions
for (final entry in suggestions.entries) {
dev.log(
' ${entry.value.length} suggestion for ${p.attr.name} for cluster ID ${entry.key} are suggestions ${entry.value}}',
name: "ClusterFeedbackService",
);
}
return suggestions;
}
Future<List<int>> getSuggestionsUsingMedian(
Person p, {
int sampleSize = 50,
double maxMedianDistance = 0.65,
double goodMedianDistance = 0.55,
double maxMeanDistance = 0.65,
double goodMeanDistance = 0.4,
}) async {
// Get all the cluster data
final faceMlDb = FaceMLDataDB.instance;
// final Map<int, List<(int, double)>> suggestions = {};
final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount());
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
dev.log(
'existing clusters for ${p.attr.name} are $personClusters',
name: "ClusterFeedbackService",
);
// Get and update the cluster summary to get the avg (centroid) and count
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
allClusterIdsToCountMap,
ignoredClusters,
);
watch.log('computed avg for ${clusterAvg.length} clusters');
// Find the other cluster candidates based on the mean
final Map<int, List<(int, double)>> suggestionsMean = _calcSuggestionsMean(
clusterAvg,
personClusters,
ignoredClusters,
goodMeanDistance,
);
if (suggestionsMean.isNotEmpty) {
final List<(int, double)> suggestClusterIds = [];
for (final List<(int, double)> suggestion in suggestionsMean.values) {
suggestClusterIds.addAll(suggestion);
}
suggestClusterIds.sort(
(a, b) => allClusterIdsToCountMap[b.$1]!
.compareTo(allClusterIdsToCountMap[a.$1]!),
);
final suggestClusterIdsSizes = suggestClusterIds
.map((e) => allClusterIdsToCountMap[e.$1]!)
.toList(growable: false);
final suggestClusterIdsDistances =
suggestClusterIds.map((e) => e.$2).toList(growable: false);
_logger.info(
"Already found good suggestions using mean: $suggestClusterIds, with sizes $suggestClusterIdsSizes and distances $suggestClusterIdsDistances",
);
return suggestClusterIds.map((e) => e.$1).toList(growable: false);
}
// Find the other cluster candidates based on the median
final Map<int, List<(int, double)>> moreSuggestionsMean =
_calcSuggestionsMean(
clusterAvg,
personClusters,
ignoredClusters,
maxMeanDistance,
);
if (moreSuggestionsMean.isEmpty) {
_logger
.info("No suggestions found using mean, even with higher threshold");
return <int>[];
}
final List<(int, double)> temp = [];
for (final List<(int, double)> suggestion in moreSuggestionsMean.values) {
temp.addAll(suggestion);
}
temp.sort((a, b) => a.$2.compareTo(b.$2));
final otherClusterIdsCandidates = temp
.map(
(e) => e.$1,
)
.toList(growable: false);
_logger.info(
"Found potential suggestions from loose mean for median test: $otherClusterIdsCandidates",
);
watch.logAndReset("Starting median test");
// Take the embeddings from the person's clusters in one big list and sample from it
final List<Uint8List> personEmbeddingsProto = [];
for (final clusterID in personClusters) {
final Iterable<Uint8List> embedings =
await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
personEmbeddingsProto.addAll(embedings);
}
final List<Uint8List> sampledEmbeddingsProto =
_randomSampleWithoutReplacement(
personEmbeddingsProto,
sampleSize,
);
final List<List<double>> sampledEmbeddings = sampledEmbeddingsProto
.map((embedding) => EVector.fromBuffer(embedding).values)
.toList(growable: false);
// Find the actual closest clusters for the person using median
final List<(int, double)> suggestionsMedian = [];
final List<(int, double)> greatSuggestionsMedian = [];
double minMedianDistance = maxMedianDistance;
for (final otherClusterId in otherClusterIdsCandidates) {
final Iterable<Uint8List> otherEmbeddingsProto =
await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(
otherClusterId,
);
final sampledOtherEmbeddingsProto = _randomSampleWithoutReplacement(
otherEmbeddingsProto,
sampleSize,
);
final List<List<double>> sampledOtherEmbeddings =
sampledOtherEmbeddingsProto
.map((embedding) => EVector.fromBuffer(embedding).values)
.toList(growable: false);
// Calculate distances and find the median
final List<double> distances = [];
for (final otherEmbedding in sampledOtherEmbeddings) {
for (final embedding in sampledEmbeddings) {
distances.add(cosineDistForNormVectors(embedding, otherEmbedding));
}
}
distances.sort();
final double medianDistance = distances[distances.length ~/ 2];
if (medianDistance < minMedianDistance) {
suggestionsMedian.add((otherClusterId, medianDistance));
minMedianDistance = medianDistance;
if (medianDistance < goodMedianDistance) {
greatSuggestionsMedian.add((otherClusterId, medianDistance));
break;
}
}
}
watch.log("Finished median test");
if (suggestionsMedian.isEmpty) {
_logger.info("No suggestions found using median");
return <int>[];
} else {
_logger.info("Found suggestions using median: $suggestionsMedian");
}
final List<int> finalSuggestionsMedian = suggestionsMedian
.map(((e) => e.$1))
.toList(growable: false)
.reversed
.toList(growable: false);
if (greatSuggestionsMedian.isNotEmpty) {
_logger.info(
"Found great suggestion using median: $greatSuggestionsMedian",
);
// // Return the largest size cluster by using allClusterIdsToCountMap
// final List<int> greatSuggestionsMedianClusterIds =
// greatSuggestionsMedian.map((e) => e.$1).toList(growable: false);
// greatSuggestionsMedianClusterIds.sort(
// (a, b) =>
// allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!),
// );
// return [greatSuggestionsMedian.last.$1, ...finalSuggestionsMedian];
}
return finalSuggestionsMedian;
}
Future<List<(int, List<EnteFile>)>> getClusterFilesForPersonID(
Person person,
) async {
_logger.info(
'getClusterFilesForPersonID ${kDebugMode ? person.attr.name : person.remoteID}',
);
// Get the suggestions for the person using only centroids
// final Map<int, List<(int, double)>> suggestions =
// await getSuggestionsUsingMean(person);
// final Set<int> suggestClusterIds = {};
// for (final List<(int, double)> suggestion in suggestions.values) {
// for (final clusterNeighbors in suggestion) {
// suggestClusterIds.add(clusterNeighbors.$1);
// }
// }
try {
// Get the suggestions for the person using centroids and median
final List<int> suggestClusterIds =
await getSuggestionsUsingMedian(person);
// Get the files for the suggestions
final Map<int, Set<int>> fileIdToClusterID = await FaceMLDataDB.instance
.getFileIdToClusterIDSetForCluster(suggestClusterIds.toSet());
final Map<int, List<EnteFile>> clusterIDToFiles = {};
final allFiles = await SearchService.instance.getAllFiles();
for (final f in allFiles) {
if (!fileIdToClusterID.containsKey(f.uploadedFileID ?? -1)) {
continue;
}
final cluserIds = fileIdToClusterID[f.uploadedFileID ?? -1]!;
for (final cluster in cluserIds) {
if (clusterIDToFiles.containsKey(cluster)) {
clusterIDToFiles[cluster]!.add(f);
} else {
clusterIDToFiles[cluster] = [f];
}
}
}
final List<(int, List<EnteFile>)> clusterIdAndFiles = [];
for (final clusterId in suggestClusterIds) {
if (clusterIDToFiles.containsKey(clusterId)) {
clusterIdAndFiles.add(
(clusterId, clusterIDToFiles[clusterId]!),
);
}
}
return clusterIdAndFiles;
} catch (e, s) {
_logger.severe("Error in getClusterFilesForPersonID", e, s);
rethrow;
}
}
Future<void> removePersonFromFiles(List<EnteFile> files, Person p) {
return FaceMLDataDB.instance.removePersonFromFiles(files, p);
}
Future<bool> checkAndDoAutomaticMerges(Person p) async {
final faceMlDb = FaceMLDataDB.instance;
final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount());
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
dev.log(
'existing clusters for ${p.attr.name} are $personClusters',
name: "ClusterFeedbackService",
);
// Get and update the cluster summary to get the avg (centroid) and count
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
allClusterIdsToCountMap,
ignoredClusters,
);
watch.log('computed avg for ${clusterAvg.length} clusters');
// Find the actual closest clusters for the person
final Map<int, List<(int, double)>> suggestions = _calcSuggestionsMean(
clusterAvg,
personClusters,
ignoredClusters,
0.3,
);
if (suggestions.isEmpty) {
dev.log(
'No automatic merge suggestions for ${p.attr.name}',
name: "ClusterFeedbackService",
);
return false;
}
// log suggestions
for (final entry in suggestions.entries) {
dev.log(
' ${entry.value.length} suggestion for ${p.attr.name} for cluster ID ${entry.key} are suggestions ${entry.value}}',
name: "ClusterFeedbackService",
);
}
for (final suggestionsPerCluster in suggestions.values) {
for (final suggestion in suggestionsPerCluster) {
final clusterID = suggestion.$1;
await faceMlDb.assignClusterToPerson(
personID: p.remoteID,
clusterID: clusterID,
);
}
}
Bus.instance.fire(PeopleChangedEvent());
return true;
}
Future<Map<int, List<double>>> _getUpdateClusterAvg(
Map<int, int> allClusterIdsToCountMap,
Set<int> ignoredClusters,
) async {
final faceMlDb = FaceMLDataDB.instance;
final Map<int, (Uint8List, int)> clusterToSummary =
await faceMlDb.clusterSummaryAll();
final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
final Map<int, List<double>> clusterAvg = {};
final allClusterIds = allClusterIdsToCountMap.keys;
for (final clusterID in allClusterIds) {
if (ignoredClusters.contains(clusterID)) {
continue;
}
late List<double> avg;
if (clusterToSummary[clusterID]?.$2 ==
allClusterIdsToCountMap[clusterID]) {
avg = EVector.fromBuffer(clusterToSummary[clusterID]!.$1).values;
} else {
final Iterable<Uint8List> embedings =
await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
final List<double> sum = List.filled(192, 0);
for (final embedding in embedings) {
final data = EVector.fromBuffer(embedding).values;
for (int i = 0; i < sum.length; i++) {
sum[i] += data[i];
}
}
avg = sum.map((e) => e / embedings.length).toList();
final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer();
updatesForClusterSummary[clusterID] =
(avgEmbeedingBuffer, embedings.length);
}
clusterAvg[clusterID] = avg;
}
if (updatesForClusterSummary.isNotEmpty) {
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
}
return clusterAvg;
}
Map<int, List<(int, double)>> _calcSuggestionsMean(
Map<int, List<double>> clusterAvg,
Set<int> personClusters,
Set<int> ignoredClusters,
double maxClusterDistance,
) {
final Map<int, List<(int, double)>> suggestions = {};
for (final otherClusterID in clusterAvg.keys) {
// ignore the cluster that belong to the person or is ignored
if (personClusters.contains(otherClusterID) ||
ignoredClusters.contains(otherClusterID)) {
continue;
}
final otherAvg = clusterAvg[otherClusterID]!;
int? nearestPersonCluster;
double? minDistance;
for (final personCluster in personClusters) {
final avg = clusterAvg[personCluster]!;
final distance = cosineDistForNormVectors(avg, otherAvg);
if (distance < maxClusterDistance) {
if (minDistance == null || distance < minDistance) {
minDistance = distance;
nearestPersonCluster = personCluster;
}
}
}
if (nearestPersonCluster != null && minDistance != null) {
suggestions
.putIfAbsent(nearestPersonCluster, () => [])
.add((otherClusterID, minDistance));
}
}
for (final entry in suggestions.entries) {
entry.value.sort((a, b) => a.$1.compareTo(b.$1));
}
return suggestions;
}
List<T> _randomSampleWithoutReplacement<T>(
Iterable<T> embeddings,
int sampleSize,
) {
final random = Random();
if (sampleSize >= embeddings.length) {
return embeddings.toList();
}
// If sampleSize is more than half the list size, shuffle and take first sampleSize elements
if (sampleSize > embeddings.length / 2) {
final List<T> shuffled = List<T>.from(embeddings)..shuffle(random);
return shuffled.take(sampleSize).toList(growable: false);
}
// Otherwise, use the set-based method for efficiency
final selectedIndices = <int>{};
final sampledEmbeddings = <T>[];
while (sampledEmbeddings.length < sampleSize) {
final int index = random.nextInt(embeddings.length);
if (!selectedIndices.contains(index)) {
selectedIndices.add(index);
sampledEmbeddings.add(embeddings.elementAt(index));
}
}
return sampledEmbeddings;
}
}

View file

@ -0,0 +1,11 @@
mixin ModelFile {
static const String faceDetectionBackWeb =
'assets/models/blazeface/blazeface_back_ente_web.tflite';
// TODO: which of the two mobilefacenet model should I use now??
// static const String faceEmbeddingEnte =
// 'assets/models/mobilefacenet/mobilefacenet_ente_web.tflite';
static const String faceEmbeddingEnte =
'assets/models/mobilefacenet/mobilefacenet_unq_TF211.tflite';
static const String yoloV5FaceS640x640DynamicBatchonnx =
'assets/models/yolov5face/yolov5s_face_640_640_dynamic.onnx';
}

View file

@ -11,6 +11,8 @@ import 'package:photos/data/years.dart';
import 'package:photos/db/files_db.dart';
import 'package:photos/events/local_photos_updated_event.dart';
import "package:photos/extensions/string_ext.dart";
import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart";
import "package:photos/models/api/collection/user.dart";
import 'package:photos/models/collection/collection.dart';
import 'package:photos/models/collection/collection_items.dart';
@ -22,6 +24,7 @@ import "package:photos/models/location/location.dart";
import "package:photos/models/location_tag/location_tag.dart";
import 'package:photos/models/search/album_search_result.dart';
import 'package:photos/models/search/generic_search_result.dart';
import "package:photos/models/search/search_constants.dart";
import "package:photos/models/search/search_types.dart";
import 'package:photos/services/collections_service.dart';
import "package:photos/services/location_service.dart";
@ -29,6 +32,8 @@ import 'package:photos/services/machine_learning/semantic_search/semantic_search
import "package:photos/states/location_screen_state.dart";
import "package:photos/ui/viewer/location/add_location_sheet.dart";
import "package:photos/ui/viewer/location/location_screen.dart";
import "package:photos/ui/viewer/people/cluster_page.dart";
import "package:photos/ui/viewer/people/people_page.dart";
import 'package:photos/utils/date_time_util.dart';
import "package:photos/utils/navigation_util.dart";
import 'package:tuple/tuple.dart';
@ -704,6 +709,146 @@ class SearchService {
return searchResults;
}
Future<Map<int, List<EnteFile>>> getClusterFilesForPersonID(
String personID,
) async {
_logger.info('getClusterFilesForPersonID $personID');
final Map<int, Set<int>> fileIdToClusterID =
await FaceMLDataDB.instance.getFileIdToClusterIDSet(personID);
_logger.info('faceDbDone getClusterFilesForPersonID $personID');
final Map<int, List<EnteFile>> clusterIDToFiles = {};
final allFiles = await getAllFiles();
for (final f in allFiles) {
if (!fileIdToClusterID.containsKey(f.uploadedFileID ?? -1)) {
continue;
}
final cluserIds = fileIdToClusterID[f.uploadedFileID ?? -1]!;
for (final cluster in cluserIds) {
if (clusterIDToFiles.containsKey(cluster)) {
clusterIDToFiles[cluster]!.add(f);
} else {
clusterIDToFiles[cluster] = [f];
}
}
}
_logger.info('done getClusterFilesForPersonID $personID');
return clusterIDToFiles;
}
Future<List<GenericSearchResult>> getAllFace(int? limit) async {
debugPrint("getting faces");
final Map<int, Set<int>> fileIdToClusterID =
await FaceMLDataDB.instance.getFileIdToClusterIds();
final (clusterIDToPerson, personIdToPerson) =
await FaceMLDataDB.instance.getClusterIdToPerson();
debugPrint("building result");
final List<GenericSearchResult> facesResult = [];
final Map<int, List<EnteFile>> clusterIdToFiles = {};
final Map<String, List<EnteFile>> personIdToFiles = {};
final allFiles = await getAllFiles();
for (final f in allFiles) {
if (!fileIdToClusterID.containsKey(f.uploadedFileID ?? -1)) {
continue;
}
final cluserIds = fileIdToClusterID[f.uploadedFileID ?? -1]!;
for (final cluster in cluserIds) {
final Person? p = clusterIDToPerson[cluster];
if (p != null) {
if (personIdToFiles.containsKey(p.remoteID)) {
personIdToFiles[p.remoteID]!.add(f);
} else {
personIdToFiles[p.remoteID] = [f];
}
} else {
if (clusterIdToFiles.containsKey(cluster)) {
clusterIdToFiles[cluster]!.add(f);
} else {
clusterIdToFiles[cluster] = [f];
}
}
}
}
// get sorted personId by files count
final sortedPersonIds = personIdToFiles.keys.toList()
..sort(
(a, b) => personIdToFiles[b]!.length.compareTo(
personIdToFiles[a]!.length,
),
);
for (final personID in sortedPersonIds) {
final files = personIdToFiles[personID]!;
if (files.isEmpty) {
continue;
}
final Person p = personIdToPerson[personID]!;
facesResult.add(
GenericSearchResult(
ResultType.faces,
p.attr.name,
files,
params: {
kPersonParamID: personID,
kFileID: files.first.uploadedFileID,
},
onResultTap: (ctx) {
routeToPage(
ctx,
PeoplePage(
tagPrefix: "${ResultType.faces.toString()}_${p.attr.name}",
person: p,
),
);
},
),
);
}
final sortedClusterIds = clusterIdToFiles.keys.toList()
..sort(
(a, b) =>
clusterIdToFiles[b]!.length.compareTo(clusterIdToFiles[a]!.length),
);
for (final clusterId in sortedClusterIds) {
final files = clusterIdToFiles[clusterId]!;
// final String clusterName = "ID:$clusterId, ${files.length}";
final String clusterName = "${files.length}";
final Person? p = clusterIDToPerson[clusterId];
if (p != null) {
throw Exception("Person should be null");
}
if (files.length < 3) {
continue;
}
facesResult.add(
GenericSearchResult(
ResultType.faces,
clusterName,
files,
params: {
kClusterParamId: clusterId,
kFileID: files.first.uploadedFileID,
},
onResultTap: (ctx) {
routeToPage(
ctx,
ClusterPage(
files,
tagPrefix: "${ResultType.faces.toString()}_$clusterName",
cluserID: clusterId,
),
);
},
),
);
}
if (limit != null) {
return facesResult.sublist(0, min(limit, facesResult.length));
} else {
return facesResult;
}
}
Future<List<GenericSearchResult>> getAllLocationTags(int? limit) async {
try {
final Map<LocalEntity<LocationTag>, List<EnteFile>> tagToItemsMap = {};

View file

@ -6,6 +6,7 @@ import "package:logging/logging.dart";
import "package:photos/core/constants.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/events/files_updated_event.dart";
import "package:photos/events/people_changed_event.dart";
import "package:photos/events/tab_changed_event.dart";
import "package:photos/models/search/search_result.dart";
import "package:photos/models/search/search_types.dart";
@ -31,6 +32,7 @@ class _AllSectionsExamplesProviderState
Future<List<List<SearchResult>>> allSectionsExamplesFuture = Future.value([]);
late StreamSubscription<FilesUpdatedEvent> _filesUpdatedEvent;
late StreamSubscription<PeopleChangedEvent> _onPeopleChangedEvent;
late StreamSubscription<TabChangedEvent> _tabChangeEvent;
bool hasPendingUpdate = false;
bool isOnSearchTab = false;
@ -46,16 +48,11 @@ class _AllSectionsExamplesProviderState
super.initState();
//add all common events for all search sections to reload to here.
_filesUpdatedEvent = Bus.instance.on<FilesUpdatedEvent>().listen((event) {
if (!isOnSearchTab) {
if (kDebugMode) {
_logger.finest('Skip reload till user clicks on search tab');
}
hasPendingUpdate = true;
return;
} else {
hasPendingUpdate = false;
reloadAllSections();
}
onDataUpdate();
});
_onPeopleChangedEvent =
Bus.instance.on<PeopleChangedEvent>().listen((event) {
onDataUpdate();
});
_tabChangeEvent = Bus.instance.on<TabChangedEvent>().listen((event) {
if (event.source == TabChangedEventSource.pageView &&
@ -72,6 +69,18 @@ class _AllSectionsExamplesProviderState
reloadAllSections();
}
void onDataUpdate() {
if (!isOnSearchTab) {
if (kDebugMode) {
_logger.finest('Skip reload till user clicks on search tab');
}
hasPendingUpdate = true;
} else {
hasPendingUpdate = false;
reloadAllSections();
}
}
void reloadAllSections() {
_logger.info('queue reload all sections');
_debouncer.run(() async {
@ -79,22 +88,28 @@ class _AllSectionsExamplesProviderState
_logger.info("'_debounceTimer: reloading all sections in search tab");
final allSectionsExamples = <Future<List<SearchResult>>>[];
for (SectionType sectionType in SectionType.values) {
if (sectionType == SectionType.face ||
sectionType == SectionType.content) {
if (sectionType == SectionType.content) {
continue;
}
allSectionsExamples.add(
sectionType.getData(context, limit: kSearchSectionLimit),
);
}
allSectionsExamplesFuture =
Future.wait<List<SearchResult>>(allSectionsExamples);
try {
allSectionsExamplesFuture = Future.wait<List<SearchResult>>(
allSectionsExamples,
eagerError: false,
);
} catch (e) {
_logger.severe("Error reloading all sections: $e");
}
});
});
}
@override
void dispose() {
_onPeopleChangedEvent.cancel();
_filesUpdatedEvent.cancel();
_tabChangeEvent.cancel();
_debouncer.cancelDebounce();

View file

@ -1,5 +1,6 @@
import 'package:flutter/material.dart';
import 'package:photos/core/constants.dart';
import "package:photos/face/model/person.dart";
import 'package:photos/models/collection/collection.dart';
import "package:photos/models/gallery_type.dart";
import 'package:photos/models/selected_files.dart';
@ -11,6 +12,8 @@ import "package:photos/ui/viewer/actions/file_selection_actions_widget.dart";
class BottomActionBarWidget extends StatelessWidget {
final GalleryType galleryType;
final Collection? collection;
final Person? person;
final int? clusterID;
final SelectedFiles selectedFiles;
final VoidCallback? onCancel;
final Color? backgroundColor;
@ -19,6 +22,8 @@ class BottomActionBarWidget extends StatelessWidget {
required this.galleryType,
required this.selectedFiles,
this.collection,
this.person,
this.clusterID,
this.onCancel,
this.backgroundColor,
super.key,
@ -54,6 +59,8 @@ class BottomActionBarWidget extends StatelessWidget {
galleryType,
selectedFiles,
collection: collection,
person: person,
clusterID: clusterID,
),
const DividerWidget(dividerType: DividerType.bottomBar),
ActionBarWidget(

View file

@ -67,7 +67,6 @@ class DebugSectionWidget extends StatelessWidget {
showShortToast(context, "Done");
},
),
sectionOptionSpacing,
],
);
}

View file

@ -0,0 +1,214 @@
import "dart:async";
import "package:flutter/foundation.dart";
import 'package:flutter/material.dart';
import "package:logging/logging.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/events/people_changed_event.dart";
import "package:photos/extensions/stop_watch.dart";
import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart";
import "package:photos/services/face_ml/face_ml_service.dart";
import "package:photos/services/face_ml/feedback/cluster_feedback.dart";
import 'package:photos/theme/ente_theme.dart';
import 'package:photos/ui/components/captioned_text_widget.dart';
import 'package:photos/ui/components/expandable_menu_item_widget.dart';
import 'package:photos/ui/components/menu_item_widget/menu_item_widget.dart';
import 'package:photos/ui/settings/common_settings.dart';
import "package:photos/utils/dialog_util.dart";
import "package:photos/utils/local_settings.dart";
import 'package:photos/utils/toast_util.dart';
class FaceDebugSectionWidget extends StatefulWidget {
const FaceDebugSectionWidget({Key? key}) : super(key: key);
@override
State<FaceDebugSectionWidget> createState() => _FaceDebugSectionWidgetState();
}
class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
Timer? _timer;
@override
void initState() {
super.initState();
_timer = Timer.periodic(const Duration(seconds: 5), (timer) {
setState(() {
// Your state update logic here
});
});
}
@override
void dispose() {
_timer?.cancel();
super.dispose();
}
@override
Widget build(BuildContext context) {
return ExpandableMenuItemWidget(
title: "Face Beta",
selectionOptionsWidget: _getSectionOptions(context),
leadingIcon: Icons.bug_report_outlined,
);
}
Widget _getSectionOptions(BuildContext context) {
final Logger _logger = Logger("FaceDebugSectionWidget");
return Column(
children: [
MenuItemWidget(
captionedTextWidget: FutureBuilder<Set<int>>(
future: FaceMLDataDB.instance.getIndexedFileIds(),
builder: (context, snapshot) {
if (snapshot.hasData) {
return CaptionedTextWidget(
title: LocalSettings.instance.isFaceIndexingEnabled
? "Disable Indexing (${snapshot.data!.length})"
: "Enable indexing (${snapshot.data!.length})",
);
}
return const SizedBox.shrink();
},
),
pressedColor: getEnteColorScheme(context).fillFaint,
trailingIcon: Icons.chevron_right_outlined,
trailingIconIsMuted: true,
onTap: () async {
try {
final isEnabled =
await LocalSettings.instance.toggleFaceIndexing();
if (isEnabled) {
FaceMlService.instance.indexAllImages().ignore();
} else {
FaceMlService.instance.pauseIndexing();
}
if (mounted) {
setState(() {});
}
} catch (e, s) {
_logger.warning('indexing failed ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
},
),
MenuItemWidget(
captionedTextWidget: const CaptionedTextWidget(
title: "Run Clustering",
),
pressedColor: getEnteColorScheme(context).fillFaint,
trailingIcon: Icons.chevron_right_outlined,
trailingIconIsMuted: true,
onTap: () async {
await FaceMlService.instance.clusterAllImages(minFaceScore: 0.75);
Bus.instance.fire(PeopleChangedEvent());
showShortToast(context, "Done");
},
),
sectionOptionSpacing,
MenuItemWidget(
captionedTextWidget: const CaptionedTextWidget(
title: "Reset feedback & labels",
),
pressedColor: getEnteColorScheme(context).fillFaint,
trailingIcon: Icons.chevron_right_outlined,
trailingIconIsMuted: true,
onTap: () async {
await FaceMLDataDB.instance.resetClusterIDs();
await FaceMLDataDB.instance.dropClustersAndPeople();
Bus.instance.fire(PeopleChangedEvent());
showShortToast(context, "Done");
},
),
sectionOptionSpacing,
MenuItemWidget(
captionedTextWidget: const CaptionedTextWidget(
title: "Drop embeddings & feedback",
),
pressedColor: getEnteColorScheme(context).fillFaint,
trailingIcon: Icons.chevron_right_outlined,
trailingIconIsMuted: true,
onTap: () async {
await showChoiceDialog(
context,
title: "Are you sure?",
body:
"You will need to again re-index all the faces. You can drop feedback if you want to label again",
firstButtonLabel: "Yes, confirm",
firstButtonOnTap: () async {
await FaceMLDataDB.instance.dropClustersAndPeople(faces: true);
Bus.instance.fire(PeopleChangedEvent());
showShortToast(context, "Done");
},
);
},
),
if (kDebugMode) sectionOptionSpacing,
if (kDebugMode)
MenuItemWidget(
captionedTextWidget: const CaptionedTextWidget(
title: "Pull Embeddings From Local",
),
pressedColor: getEnteColorScheme(context).fillFaint,
trailingIcon: Icons.chevron_right_outlined,
trailingIconIsMuted: true,
onTap: () async {
try {
final List<Person> persons =
await FaceMLDataDB.instance.getPeople();
final EnteWatch w = EnteWatch('feedback')..start();
for (final Person p in persons) {
await ClusterFeedbackService.instance
.getSuggestionsUsingMean(p);
w.logAndReset('suggestion calculated for ${p.attr.name}');
}
w.log("done with feedback");
showShortToast(context, "done avg");
// await FaceMLDataDB.instance.bulkInsertFaces([]);
// final EnteWatch watch = EnteWatch("face_time")..start();
// final results = await downloadZip();
// watch.logAndReset('downloaded and de-serialized');
// await FaceMLDataDB.instance.bulkInsertFaces(results);
// watch.logAndReset('inserted in to db');
// showShortToast(context, "Got ${results.length} results");
} catch (e, s) {
_logger.warning('download failed ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
// _showKeyAttributesDialog(context);
},
),
if (kDebugMode) sectionOptionSpacing,
if (kDebugMode)
MenuItemWidget(
captionedTextWidget: FutureBuilder<Set<int>>(
future: FaceMLDataDB.instance.getIndexedFileIds(),
builder: (context, snapshot) {
if (snapshot.hasData) {
return CaptionedTextWidget(
title: "Read embeddings for ${snapshot.data!.length} files",
);
}
return const CaptionedTextWidget(
title: "Loading...",
);
},
),
pressedColor: getEnteColorScheme(context).fillFaint,
trailingIcon: Icons.chevron_right_outlined,
trailingIconIsMuted: true,
onTap: () async {
final EnteWatch watch = EnteWatch("read_embeddings")..start();
final result = await FaceMLDataDB.instance.getFaceEmbeddingMap();
watch.logAndReset('read embeddings ${result.length} ');
showShortToast(
context,
"Done in ${watch.elapsed.inSeconds} secs",
);
},
),
],
);
}
}

View file

@ -7,7 +7,6 @@ import 'package:photos/core/configuration.dart';
import 'package:photos/core/event_bus.dart';
import 'package:photos/events/opened_settings_event.dart';
import "package:photos/generated/l10n.dart";
import 'package:photos/services/feature_flag_service.dart';
import "package:photos/services/storage_bonus_service.dart";
import 'package:photos/theme/colors.dart';
import 'package:photos/theme/ente_theme.dart';
@ -17,7 +16,8 @@ import 'package:photos/ui/settings/about_section_widget.dart';
import 'package:photos/ui/settings/account_section_widget.dart';
import 'package:photos/ui/settings/app_version_widget.dart';
import 'package:photos/ui/settings/backup/backup_section_widget.dart';
import 'package:photos/ui/settings/debug_section_widget.dart';
import 'package:photos/ui/settings/debug/debug_section_widget.dart';
import "package:photos/ui/settings/debug/face_debug_section_widget.dart";
import 'package:photos/ui/settings/general_section_widget.dart';
import 'package:photos/ui/settings/inherited_settings_state.dart';
import 'package:photos/ui/settings/security_section_widget.dart';
@ -52,6 +52,10 @@ class SettingsPage extends StatelessWidget {
final hasLoggedIn = Configuration.instance.isLoggedIn();
final enteTextTheme = getEnteTextTheme(context);
final List<Widget> contents = [];
const sectionSpacing = SizedBox(height: 8);
if (kDebugMode) {
contents.addAll([const FaceDebugSectionWidget(), sectionSpacing]);
}
contents.add(
GestureDetector(
onDoubleTap: () {
@ -81,7 +85,7 @@ class SettingsPage extends StatelessWidget {
),
),
);
const sectionSpacing = SizedBox(height: 8);
contents.add(const SizedBox(height: 8));
if (hasLoggedIn) {
final showStorageBonusBanner =
@ -139,9 +143,9 @@ class SettingsPage extends StatelessWidget {
const AboutSectionWidget(),
]);
if (hasLoggedIn &&
FeatureFlagService.instance.isInternalUserOrDebugBuild()) {
if (hasLoggedIn) {
contents.addAll([sectionSpacing, const DebugSectionWidget()]);
contents.addAll([sectionSpacing, const FaceDebugSectionWidget()]);
}
contents.add(const AppVersionWidget());
contents.add(

View file

@ -113,6 +113,7 @@ class _AppLockState extends State<AppLock> with WidgetsBindingObserver {
theme: widget.lightTheme,
darkTheme: widget.darkTheme,
locale: widget.locale,
debugShowCheckedModeBanner: false,
supportedLocales: appSupportedLocales,
localeListResolutionCallback: localResolutionCallBack,
localizationsDelegates: const [

View file

@ -1,10 +1,15 @@
import "dart:async";
import 'package:fast_base58/fast_base58.dart';
import "package:flutter/cupertino.dart";
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import "package:modal_bottom_sheet/modal_bottom_sheet.dart";
import 'package:photos/core/configuration.dart';
import "package:photos/core/event_bus.dart";
import "package:photos/events/people_changed_event.dart";
import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart";
import "package:photos/generated/l10n.dart";
import 'package:photos/models/collection/collection.dart';
import 'package:photos/models/device_collection.dart';
@ -15,6 +20,7 @@ import 'package:photos/models/gallery_type.dart';
import "package:photos/models/metadata/common_keys.dart";
import 'package:photos/models/selected_files.dart';
import 'package:photos/services/collections_service.dart';
import "package:photos/services/face_ml/feedback/cluster_feedback.dart";
import 'package:photos/services/hidden_service.dart';
import "package:photos/theme/colors.dart";
import "package:photos/theme/ente_theme.dart";
@ -39,12 +45,16 @@ class FileSelectionActionsWidget extends StatefulWidget {
final Collection? collection;
final DeviceCollection? deviceCollection;
final SelectedFiles selectedFiles;
final Person? person;
final int? clusterID;
const FileSelectionActionsWidget(
this.type,
this.selectedFiles, {
Key? key,
this.collection,
this.person,
this.clusterID,
this.deviceCollection,
}) : super(key: key);
@ -116,7 +126,24 @@ class _FileSelectionActionsWidgetState
//and set [shouldShow] to false for items that should not be shown and true
//for items that should be shown.
final List<SelectionActionButton> items = [];
if (widget.type == GalleryType.peopleTag && widget.person != null) {
items.add(
SelectionActionButton(
icon: Icons.remove_circle_outline,
labelText: 'Not ${widget.person!.attr.name}?',
onTap: anyUploadedFiles ? _onNotpersonClicked : null,
),
);
if (ownedFilesCount == 1) {
items.add(
SelectionActionButton(
icon: Icons.image_outlined,
labelText: 'Use as cover',
onTap: anyUploadedFiles ? _setPersonCover : null,
),
);
}
}
if (widget.type.showCreateLink()) {
if (_cachedCollectionForSharedLink != null && anyUploadedFiles) {
items.add(
@ -374,6 +401,16 @@ class _FileSelectionActionsWidgetState
),
);
if (widget.type == GalleryType.cluster) {
items.add(
SelectionActionButton(
labelText: 'Remove',
icon: CupertinoIcons.minus,
onTap: () => showToast(context, 'yet to implement'),
),
);
}
if (items.isNotEmpty) {
final scrollController = ScrollController();
// h4ck: https://github.com/flutter/flutter/issues/57920#issuecomment-893970066
@ -613,6 +650,59 @@ class _FileSelectionActionsWidgetState
}
}
Future<void> _setPersonCover() async {
final EnteFile file = widget.selectedFiles.files.first;
final Person newPerson = widget.person!.copyWith(
attr: widget.person!.attr
.copyWith(avatarFaceId: file.uploadedFileID.toString()),
);
await FaceMLDataDB.instance.updatePerson(newPerson);
widget.selectedFiles.clearAll();
if (mounted) {
setState(() => {});
}
Bus.instance.fire(PeopleChangedEvent());
}
Future<void> _onNotpersonClicked() async {
final actionResult = await showActionSheet(
context: context,
buttons: [
ButtonWidget(
labelText: S.of(context).yesRemove,
buttonType: ButtonType.neutral,
buttonSize: ButtonSize.large,
shouldStickToDarkTheme: true,
buttonAction: ButtonAction.first,
isInAlert: true,
),
ButtonWidget(
labelText: S.of(context).cancel,
buttonType: ButtonType.secondary,
buttonSize: ButtonSize.large,
buttonAction: ButtonAction.second,
shouldStickToDarkTheme: true,
isInAlert: true,
),
],
title: "Remove these photos for ${widget.person!.attr.name}?",
actionSheetType: ActionSheetType.defaultActionSheet,
);
if (actionResult?.action != null) {
if (actionResult!.action == ButtonAction.first) {
await ClusterFeedbackService.instance.removePersonFromFiles(
widget.selectedFiles.files.toList(),
widget.person!,
);
}
Bus.instance.fire(PeopleChangedEvent());
}
widget.selectedFiles.clearAll();
if (mounted) {
setState(() => {});
}
}
Future<void> _copyLink() async {
if (_cachedCollectionForSharedLink != null) {
final String collectionKey = Base58Encode(

View file

@ -1,4 +1,5 @@
import 'package:flutter/material.dart';
import "package:photos/face/model/person.dart";
import 'package:photos/models/collection/collection.dart';
import 'package:photos/models/gallery_type.dart';
import 'package:photos/models/selected_files.dart';
@ -10,12 +11,14 @@ class FileSelectionOverlayBar extends StatefulWidget {
final SelectedFiles selectedFiles;
final Collection? collection;
final Color? backgroundColor;
final Person? person;
const FileSelectionOverlayBar(
this.galleryType,
this.selectedFiles, {
this.collection,
this.backgroundColor,
this.person,
Key? key,
}) : super(key: key);
@ -65,6 +68,7 @@ class _FileSelectionOverlayBarState extends State<FileSelectionOverlayBar> {
selectedFiles: widget.selectedFiles,
galleryType: widget.galleryType,
collection: widget.collection,
person: widget.person,
onCancel: () {
if (widget.selectedFiles.files.isNotEmpty) {
widget.selectedFiles.clearAll();

View file

@ -18,9 +18,9 @@ import "package:photos/ui/viewer/file_details/albums_item_widget.dart";
import 'package:photos/ui/viewer/file_details/backed_up_time_item_widget.dart';
import "package:photos/ui/viewer/file_details/creation_time_item_widget.dart";
import 'package:photos/ui/viewer/file_details/exif_item_widgets.dart';
import "package:photos/ui/viewer/file_details/faces_item_widget.dart";
import "package:photos/ui/viewer/file_details/file_properties_item_widget.dart";
import "package:photos/ui/viewer/file_details/location_tags_widget.dart";
import "package:photos/ui/viewer/file_details/objects_item_widget.dart";
import "package:photos/utils/exif_util.dart";
class FileDetailsWidget extends StatefulWidget {
@ -221,7 +221,8 @@ class _FileDetailsWidgetState extends State<FileDetailsWidget> {
if (!UpdateService.instance.isFdroidFlavor()) {
fileDetailsTiles.addAll([
ObjectsItemWidget(file),
// ObjectsItemWidget(file),
FacesItemWidget(file),
const FileDetailsDivider(),
]);
}

View file

@ -1,5 +1,5 @@
import 'dart:async';
import 'dart:io';
import 'dart:io' as io;
import 'package:flutter/material.dart';
import 'package:flutter/widgets.dart';
@ -198,7 +198,7 @@ class _ZoomableImageState extends State<ZoomableImage>
_loadingFinalImage = true;
getFile(
_photo,
isOrigin: Platform.isIOS &&
isOrigin: io.Platform.isIOS &&
_isGIF(), // since on iOS GIFs playback only when origin-files are loaded
).then((file) {
if (file != null && file.existsSync()) {
@ -240,7 +240,25 @@ class _ZoomableImageState extends State<ZoomableImage>
}
}
void _onFinalImageLoaded(ImageProvider imageProvider) {
void _onFinalImageLoaded(ImageProvider imageProvider) async {
// // final result = await FaceMlService.instance.analyzeImage(
// // _photo,
// // preferUsingThumbnailForEverything: false,
// // disposeImageIsolateAfterUse: false,
// // );
// // _logger.info("FaceMlService result: $result");
// // _logger.info("Number of faces detected: ${result.faces.length}");
// // _logger.info("Box: ${result.faces[0].detection.box}");
// // _logger.info("Landmarks: ${result.faces[0].detection.allKeypoints}");
// // final embedding = result.faces[0].embedding;
// // Calculate the magnitude of the embedding vector
// double sum = 0;
// for (final double value in embedding) {
// sum += value * value;
// }
// final magnitude = math.sqrt(sum);
// log("Magnitude: $magnitude");
// log("Embedding: $embedding");
if (mounted) {
precacheImage(imageProvider, context).then((value) async {
if (mounted) {

View file

@ -0,0 +1,160 @@
import "dart:developer" show log;
import "dart:typed_data";
import "package:flutter/material.dart";
import "package:photos/face/db.dart";
import "package:photos/face/model/face.dart";
import "package:photos/face/model/person.dart";
import 'package:photos/models/file/file.dart';
import "package:photos/services/search_service.dart";
import "package:photos/ui/viewer/file/no_thumbnail_widget.dart";
import "package:photos/ui/viewer/people/cluster_page.dart";
import "package:photos/ui/viewer/people/people_page.dart";
import "package:photos/utils/face/face_box_crop.dart";
import "package:photos/utils/thumbnail_util.dart";
class FaceWidget extends StatelessWidget {
final EnteFile file;
final Face face;
final Person? person;
final int? clusterID;
const FaceWidget(
this.file,
this.face, {
this.person,
this.clusterID,
Key? key,
}) : super(key: key);
@override
Widget build(BuildContext context) {
return FutureBuilder<Uint8List?>(
future: getFaceCrop(),
builder: (context, snapshot) {
if (snapshot.hasData) {
final ImageProvider imageProvider = MemoryImage(snapshot.data!);
return GestureDetector(
onTap: () async {
log(
"FaceWidget is tapped, with person $person and clusterID $clusterID",
name: "FaceWidget",
);
if (person == null && clusterID == null) {
return;
}
if (person != null) {
await Navigator.of(context).push(
MaterialPageRoute(
builder: (context) => PeoplePage(
person: person!,
),
),
);
} else if (clusterID != null) {
final fileIdsToClusterIds =
await FaceMLDataDB.instance.getFileIdToClusterIds();
final files = await SearchService.instance.getAllFiles();
final clusterFiles = files
.where(
(file) =>
fileIdsToClusterIds[file.uploadedFileID]
?.contains(clusterID) ??
false,
)
.toList();
await Navigator.of(context).push(
MaterialPageRoute(
builder: (context) => ClusterPage(
clusterFiles,
cluserID: clusterID!,
),
),
);
}
},
child: Column(
children: [
ClipOval(
child: SizedBox(
width: 60,
height: 60,
child: Image(
image: imageProvider,
fit: BoxFit.cover,
),
),
),
const SizedBox(height: 8),
if (person != null)
Text(
person!.attr.name.trim(),
style: Theme.of(context).textTheme.bodySmall,
overflow: TextOverflow.ellipsis,
maxLines: 1,
),
],
),
);
} else {
if (snapshot.connectionState == ConnectionState.waiting) {
return const ClipOval(
child: SizedBox(
width: 60, // Ensure consistent sizing
height: 60,
child: CircularProgressIndicator(),
),
);
}
if (snapshot.hasError) {
log('Error getting face: ${snapshot.error}');
}
return const ClipOval(
child: SizedBox(
width: 60, // Ensure consistent sizing
height: 60,
child: NoThumbnailWidget(),
),
);
}
},
);
}
Future<Uint8List?> getFaceCrop() async {
try {
final Uint8List? cachedFace = faceCropCache.get(face.faceID);
if (cachedFace != null) {
return cachedFace;
}
final faceCropCacheFile = cachedFaceCropPath(face.faceID);
if ((await faceCropCacheFile.exists())) {
final data = await faceCropCacheFile.readAsBytes();
faceCropCache.put(face.faceID, data);
return data;
}
final result = await pool.withResource(
() async => await getFaceCrops(
file,
{
face.faceID: face.detection.box,
},
),
);
final Uint8List? computedCrop = result?[face.faceID];
if (computedCrop != null) {
faceCropCache.put(face.faceID, computedCrop);
faceCropCacheFile.writeAsBytes(computedCrop).ignore();
}
return computedCrop;
} catch (e, s) {
log(
"Error getting face for faceID: ${face.faceID}",
error: e,
stackTrace: s,
);
return null;
}
}
}

View file

@ -0,0 +1,79 @@
import "package:flutter/material.dart";
import "package:logging/logging.dart";
import "package:photos/face/db.dart";
import "package:photos/face/model/face.dart";
import "package:photos/face/model/person.dart";
import "package:photos/models/file/file.dart";
import "package:photos/ui/components/buttons/chip_button_widget.dart";
import "package:photos/ui/components/info_item_widget.dart";
import "package:photos/ui/viewer/file_details/face_widget.dart";
class FacesItemWidget extends StatelessWidget {
final EnteFile file;
const FacesItemWidget(this.file, {super.key});
@override
Widget build(BuildContext context) {
return InfoItemWidget(
key: const ValueKey("Faces"),
leadingIcon: Icons.face_retouching_natural_outlined,
subtitleSection: _faceWidgets(context, file),
hasChipButtons: true,
);
}
Future<List<Widget>> _faceWidgets(
BuildContext context,
EnteFile file,
) async {
try {
if (file.uploadedFileID == null) {
return [
const ChipButtonWidget(
"File not uploaded yet",
noChips: true,
),
];
}
final List<Face> faces = await FaceMLDataDB.instance
.getFacesForGivenFileID(file.uploadedFileID!);
if (faces.isEmpty || faces.every((face) => face.score < 0.5)) {
return [
const ChipButtonWidget(
"No faces found",
noChips: true,
),
];
}
// Sort the faces by score in descending order, so that the highest scoring face is first.
faces.sort((Face a, Face b) => b.score.compareTo(a.score));
// TODO: add deduplication of faces of same person
final faceIdsToClusterIds = await FaceMLDataDB.instance
.getFaceIdsToClusterIds(faces.map((face) => face.faceID));
final (clusterIDToPerson, personIdToPerson) =
await FaceMLDataDB.instance.getClusterIdToPerson();
final faceWidgets = <FaceWidget>[];
for (final Face face in faces) {
final int? clusterID = faceIdsToClusterIds[face.faceID];
final Person? person = clusterIDToPerson[clusterID];
faceWidgets.add(
FaceWidget(
file,
face,
clusterID: clusterID,
person: person,
),
);
}
return faceWidgets;
} catch (e, s) {
Logger("FacesItemWidget").info(e, s);
return <FaceWidget>[];
}
}
}

View file

@ -27,6 +27,7 @@ class ObjectsItemWidget extends StatelessWidget {
try {
final chipButtons = <ChipButtonWidget>[];
var objectTags = <String, double>{};
// final thumbnail = await getThumbnail(file);
// if (thumbnail != null) {
// objectTags = await ObjectDetectionService.instance.predict(thumbnail);

View file

@ -0,0 +1,301 @@
import "dart:async";
import "dart:developer";
import "dart:math" as math;
import 'package:flutter/material.dart';
import "package:logging/logging.dart";
import 'package:modal_bottom_sheet/modal_bottom_sheet.dart';
import "package:photos/core/event_bus.dart";
import "package:photos/events/people_changed_event.dart";
import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart";
import "package:photos/generated/l10n.dart";
import "package:photos/services/face_ml/feedback/cluster_feedback.dart";
import 'package:photos/theme/colors.dart';
import 'package:photos/theme/ente_theme.dart';
import 'package:photos/ui/common/loading_widget.dart';
import 'package:photos/ui/components/bottom_of_title_bar_widget.dart';
import 'package:photos/ui/components/buttons/button_widget.dart';
import 'package:photos/ui/components/models/button_type.dart';
import "package:photos/ui/components/text_input_widget.dart";
import 'package:photos/ui/components/title_bar_title_widget.dart';
import "package:photos/ui/viewer/people/new_person_item_widget.dart";
import "package:photos/ui/viewer/people/person_row_item.dart";
import "package:photos/utils/dialog_util.dart";
import "package:photos/utils/toast_util.dart";
import "package:uuid/uuid.dart";
enum PersonActionType {
assignPerson,
}
String _actionName(
BuildContext context,
PersonActionType type,
) {
String text = "";
switch (type) {
case PersonActionType.assignPerson:
text = "Add name";
break;
}
return text;
}
Future<dynamic> showAssignPersonAction(
BuildContext context, {
required int clusterID,
PersonActionType actionType = PersonActionType.assignPerson,
bool showOptionToCreateNewAlbum = true,
}) {
return showBarModalBottomSheet(
context: context,
builder: (context) {
return PersonActionSheet(
actionType: actionType,
showOptionToCreateNewAlbum: showOptionToCreateNewAlbum,
cluserID: clusterID,
);
},
shape: const RoundedRectangleBorder(
side: BorderSide(width: 0),
borderRadius: BorderRadius.vertical(
top: Radius.circular(5),
),
),
topControl: const SizedBox.shrink(),
backgroundColor: getEnteColorScheme(context).backgroundElevated,
barrierColor: backdropFaintDark,
enableDrag: false,
);
}
class PersonActionSheet extends StatefulWidget {
final PersonActionType actionType;
final int cluserID;
final bool showOptionToCreateNewAlbum;
const PersonActionSheet({
required this.actionType,
required this.cluserID,
required this.showOptionToCreateNewAlbum,
super.key,
});
@override
State<PersonActionSheet> createState() => _PersonActionSheetState();
}
class _PersonActionSheetState extends State<PersonActionSheet> {
static const int cancelButtonSize = 80;
String _searchQuery = "";
@override
void initState() {
super.initState();
}
@override
Widget build(BuildContext context) {
final bottomInset = MediaQuery.of(context).viewInsets.bottom;
final isKeyboardUp = bottomInset > 100;
return Padding(
padding: EdgeInsets.only(
bottom: isKeyboardUp ? bottomInset - cancelButtonSize : 0,
),
child: Row(
mainAxisAlignment: MainAxisAlignment.center,
children: [
ConstrainedBox(
constraints: BoxConstraints(
maxWidth: math.min(428, MediaQuery.of(context).size.width),
),
child: Padding(
padding: const EdgeInsets.fromLTRB(0, 32, 0, 8),
child: Column(
mainAxisSize: MainAxisSize.max,
children: [
Expanded(
child: Column(
children: [
BottomOfTitleBarWidget(
title: TitleBarTitleWidget(
title: _actionName(context, widget.actionType),
),
// caption: 'Select or create a ',
),
Padding(
padding: const EdgeInsets.only(
top: 16,
left: 16,
right: 16,
),
child: TextInputWidget(
hintText: 'Person name',
prefixIcon: Icons.search_rounded,
onChange: (value) {
setState(() {
_searchQuery = value;
});
},
isClearable: true,
shouldUnfocusOnClearOrSubmit: true,
borderRadius: 2,
),
),
_getPersonItems(),
],
),
),
SafeArea(
child: Container(
//inner stroke of 1pt + 15 pts of top padding = 16 pts
padding: const EdgeInsets.fromLTRB(16, 15, 16, 8),
decoration: BoxDecoration(
border: Border(
top: BorderSide(
color: getEnteColorScheme(context).strokeFaint,
),
),
),
child: ButtonWidget(
buttonType: ButtonType.secondary,
buttonAction: ButtonAction.cancel,
isInAlert: true,
labelText: S.of(context).cancel,
),
),
),
],
),
),
),
],
),
);
}
Flexible _getPersonItems() {
return Flexible(
child: Padding(
padding: const EdgeInsets.fromLTRB(16, 24, 4, 0),
child: FutureBuilder<List<Person>>(
future: _getPersons(),
builder: (context, snapshot) {
if (snapshot.hasError) {
log("Error: ${snapshot.error} ${snapshot.stackTrace}}");
//Need to show an error on the UI here
return const SizedBox.shrink();
} else if (snapshot.hasData) {
final persons = snapshot.data as List<Person>;
final searchResults = _searchQuery.isNotEmpty
? persons
.where(
(element) => element.attr.name
.toLowerCase()
.contains(_searchQuery),
)
.toList()
: persons;
final shouldShowCreateAlbum = widget.showOptionToCreateNewAlbum &&
(_searchQuery.isEmpty || searchResults.isEmpty);
return Scrollbar(
thumbVisibility: true,
radius: const Radius.circular(2),
child: Padding(
padding: const EdgeInsets.only(right: 12),
child: ListView.separated(
itemCount:
searchResults.length + (shouldShowCreateAlbum ? 1 : 0),
itemBuilder: (context, index) {
if (index == 0 && shouldShowCreateAlbum) {
return GestureDetector(
child: const NewPersonItemWidget(),
onTap: () async => {
addNewPerson(
context,
initValue: _searchQuery.trim(),
clusterID: widget.cluserID,
),
},
);
}
final person = searchResults[
index - (shouldShowCreateAlbum ? 1 : 0)];
return PersonRowItem(
person: person,
onTap: () async {
await FaceMLDataDB.instance.assignClusterToPerson(
personID: person.remoteID,
clusterID: widget.cluserID,
);
Bus.instance.fire(PeopleChangedEvent());
Navigator.pop(context, person);
},
);
},
separatorBuilder: (context, index) {
return const SizedBox(height: 2);
},
),
),
);
} else {
return const EnteLoadingWidget();
}
},
),
),
);
}
Future<void> addNewPerson(
BuildContext context, {
String initValue = '',
required int clusterID,
}) async {
final result = await showTextInputDialog(
context,
title: "New person",
submitButtonLabel: 'Add',
hintText: 'Add name',
alwaysShowSuccessState: false,
initialValue: initValue,
textCapitalization: TextCapitalization.words,
onSubmit: (String text) async {
// indicates user cancelled the rename request
if (text.trim() == "") {
return;
}
try {
final String id = const Uuid().v4().toString();
final Person p = Person(
id,
PersonAttr(name: text, faces: <String>[]),
);
await FaceMLDataDB.instance.insert(p, clusterID);
final bool extraPhotosFound =
await ClusterFeedbackService.instance.checkAndDoAutomaticMerges(p);
if (extraPhotosFound) {
showShortToast(context, "Extra photos found for $text");
}
Bus.instance.fire(PeopleChangedEvent());
Navigator.pop(context, p);
log("inserted person");
} catch (e, s) {
Logger("_PersonActionSheetState")
.severe("Failed to rename album", e, s);
rethrow;
}
},
);
if (result is Exception) {
await showGenericErrorDialog(context: context, error: result);
}
}
Future<List<Person>> _getPersons() async {
return FaceMLDataDB.instance.getPeople();
}
}

View file

@ -0,0 +1,140 @@
import "dart:async";
import 'package:flutter/material.dart';
import 'package:photos/core/event_bus.dart';
import 'package:photos/events/files_updated_event.dart';
import 'package:photos/events/local_photos_updated_event.dart';
import "package:photos/face/model/person.dart";
import 'package:photos/models/file/file.dart';
import 'package:photos/models/file_load_result.dart';
import 'package:photos/models/gallery_type.dart';
import 'package:photos/models/selected_files.dart';
import 'package:photos/ui/viewer/actions/file_selection_overlay_bar.dart';
import 'package:photos/ui/viewer/gallery/gallery.dart';
import 'package:photos/ui/viewer/gallery/gallery_app_bar_widget.dart';
import "package:photos/ui/viewer/people/add_person_action_sheet.dart";
import "package:photos/ui/viewer/people/people_page.dart";
import "package:photos/ui/viewer/search/result/search_result_page.dart";
import "package:photos/utils/navigation_util.dart";
import "package:photos/utils/toast_util.dart";
class ClusterPage extends StatefulWidget {
final List<EnteFile> searchResult;
final bool enableGrouping;
final String tagPrefix;
final int cluserID;
final Person? personID;
static const GalleryType appBarType = GalleryType.cluster;
static const GalleryType overlayType = GalleryType.cluster;
const ClusterPage(
this.searchResult, {
this.enableGrouping = true,
this.tagPrefix = "",
required this.cluserID,
this.personID,
Key? key,
}) : super(key: key);
@override
State<ClusterPage> createState() => _ClusterPageState();
}
class _ClusterPageState extends State<ClusterPage> {
final _selectedFiles = SelectedFiles();
late final List<EnteFile> files;
late final StreamSubscription<LocalPhotosUpdatedEvent> _filesUpdatedEvent;
@override
void initState() {
super.initState();
files = widget.searchResult;
_filesUpdatedEvent =
Bus.instance.on<LocalPhotosUpdatedEvent>().listen((event) {
if (event.type == EventType.deletedFromDevice ||
event.type == EventType.deletedFromEverywhere ||
event.type == EventType.deletedFromRemote ||
event.type == EventType.hide) {
for (var updatedFile in event.updatedFiles) {
files.remove(updatedFile);
}
setState(() {});
}
});
}
@override
void dispose() {
_filesUpdatedEvent.cancel();
super.dispose();
}
@override
Widget build(BuildContext context) {
final gallery = Gallery(
asyncLoader: (creationStartTime, creationEndTime, {limit, asc}) {
final result = files
.where(
(file) =>
file.creationTime! >= creationStartTime &&
file.creationTime! <= creationEndTime,
)
.toList();
return Future.value(
FileLoadResult(
result,
result.length < files.length,
),
);
},
reloadEvent: Bus.instance.on<LocalPhotosUpdatedEvent>(),
removalEventTypes: const {
EventType.deletedFromRemote,
EventType.deletedFromEverywhere,
EventType.hide,
},
tagPrefix: widget.tagPrefix + widget.tagPrefix,
selectedFiles: _selectedFiles,
enableFileGrouping: widget.enableGrouping,
initialFiles: [widget.searchResult.first],
);
return Scaffold(
appBar: PreferredSize(
preferredSize: const Size.fromHeight(50.0),
child: GestureDetector(
onTap: () async {
if (widget.personID == null) {
final result = await showAssignPersonAction(
context,
clusterID: widget.cluserID,
);
if (result != null && result is Person) {
Navigator.pop(context);
// ignore: unawaited_futures
routeToPage(context, PeoplePage(person: result));
}
} else {
showShortToast(context, "11No personID or clusterID");
}
},
child: GalleryAppBarWidget(
SearchResultPage.appBarType,
widget.personID != null ? widget.personID!.attr.name : "Add name",
_selectedFiles,
),
),
),
body: Stack(
alignment: Alignment.bottomCenter,
children: [
gallery,
FileSelectionOverlayBar(
ClusterPage.overlayType,
_selectedFiles,
),
],
),
);
}
}

View file

@ -0,0 +1,73 @@
import 'package:dotted_border/dotted_border.dart';
import 'package:flutter/material.dart';
import 'package:photos/theme/ente_theme.dart';
///https://www.figma.com/file/SYtMyLBs5SAOkTbfMMzhqt/ente-Visual-Design?node-id=10854%3A57947&t=H5AvR79OYDnB9ekw-4
class NewPersonItemWidget extends StatelessWidget {
const NewPersonItemWidget({
super.key,
});
@override
Widget build(BuildContext context) {
final textTheme = getEnteTextTheme(context);
final colorScheme = getEnteColorScheme(context);
const sideOfThumbnail = 60.0;
return LayoutBuilder(
builder: (context, constraints) {
return Stack(
alignment: Alignment.center,
children: [
Row(
children: [
ClipRRect(
borderRadius: const BorderRadius.horizontal(
left: Radius.circular(4),
),
child: SizedBox(
height: sideOfThumbnail,
width: sideOfThumbnail,
child: Icon(
Icons.add_outlined,
color: colorScheme.strokeMuted,
),
),
),
Padding(
padding: const EdgeInsets.only(left: 12),
child: Text(
'Add person',
style:
textTheme.body.copyWith(color: colorScheme.textMuted),
),
),
],
),
IgnorePointer(
child: DottedBorder(
dashPattern: const [4],
color: colorScheme.strokeFainter,
strokeWidth: 1,
padding: const EdgeInsets.all(0),
borderType: BorderType.RRect,
radius: const Radius.circular(4),
child: SizedBox(
//Have to decrease the height and width by 1 pt as the stroke
//dotted border gives is of strokeAlign.center, so 0.5 inside and
// outside. Here for the row, stroke should be inside so we
//decrease the size of this sizedBox by 1 (so it shrinks 0.5 from
//every side) so that the strokeAlign.center of this sizedBox
//looks like a strokeAlign.inside in the row.
height: sideOfThumbnail - 1,
//This width will work for this only if the row widget takes up the
//full size it's parent (stack).
width: constraints.maxWidth - 1,
),
),
),
],
);
},
);
}
}

View file

@ -0,0 +1,256 @@
import 'dart:async';
import "package:flutter/cupertino.dart";
import 'package:flutter/material.dart';
import 'package:logging/logging.dart';
import 'package:photos/core/configuration.dart';
import 'package:photos/core/event_bus.dart';
import "package:photos/events/people_changed_event.dart";
import 'package:photos/events/subscription_purchased_event.dart';
import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart";
import "package:photos/generated/l10n.dart";
import 'package:photos/models/gallery_type.dart';
import 'package:photos/models/selected_files.dart';
import 'package:photos/services/collections_service.dart';
import 'package:photos/ui/actions/collection/collection_sharing_actions.dart';
import "package:photos/ui/viewer/people/person_cluserts.dart";
import "package:photos/ui/viewer/people/person_cluster_suggestion.dart";
import "package:photos/utils/dialog_util.dart";
class PeopleAppBar extends StatefulWidget {
final GalleryType type;
final String? title;
final SelectedFiles selectedFiles;
final Person person;
const PeopleAppBar(
this.type,
this.title,
this.selectedFiles,
this.person, {
Key? key,
}) : super(key: key);
@override
State<PeopleAppBar> createState() => _AppBarWidgetState();
}
enum PeoplPopupAction {
rename,
setCover,
viewPhotos,
confirmPhotos,
hide,
}
class _AppBarWidgetState extends State<PeopleAppBar> {
final _logger = Logger("_AppBarWidgetState");
late StreamSubscription _userAuthEventSubscription;
late Function() _selectedFilesListener;
String? _appBarTitle;
late CollectionActions collectionActions;
final GlobalKey shareButtonKey = GlobalKey();
bool isQuickLink = false;
late GalleryType galleryType;
@override
void initState() {
super.initState();
_selectedFilesListener = () {
setState(() {});
};
collectionActions = CollectionActions(CollectionsService.instance);
widget.selectedFiles.addListener(_selectedFilesListener);
_userAuthEventSubscription =
Bus.instance.on<SubscriptionPurchasedEvent>().listen((event) {
setState(() {});
});
_appBarTitle = widget.title;
galleryType = widget.type;
}
@override
void dispose() {
_userAuthEventSubscription.cancel();
widget.selectedFiles.removeListener(_selectedFilesListener);
super.dispose();
}
@override
Widget build(BuildContext context) {
return AppBar(
elevation: 0,
centerTitle: false,
title: Text(
_appBarTitle!,
style:
Theme.of(context).textTheme.headlineSmall!.copyWith(fontSize: 16),
maxLines: 2,
overflow: TextOverflow.ellipsis,
),
actions: _getDefaultActions(context),
);
}
Future<dynamic> _renameAlbum(BuildContext context) async {
final result = await showTextInputDialog(
context,
title: 'Rename',
submitButtonLabel: S.of(context).done,
hintText: S.of(context).enterAlbumName,
alwaysShowSuccessState: true,
initialValue: widget.person.attr.name,
textCapitalization: TextCapitalization.words,
onSubmit: (String text) async {
// indicates user cancelled the rename request
if (text == "" || text == _appBarTitle!) {
return;
}
try {
final updatePerson = widget.person
.copyWith(attr: widget.person.attr.copyWith(name: text));
await FaceMLDataDB.instance.updatePerson(updatePerson);
if (mounted) {
_appBarTitle = text;
setState(() {});
}
Bus.instance.fire(PeopleChangedEvent());
} catch (e, s) {
_logger.severe("Failed to rename album", e, s);
rethrow;
}
},
);
if (result is Exception) {
await showGenericErrorDialog(context: context, error: result);
}
}
List<Widget> _getDefaultActions(BuildContext context) {
final List<Widget> actions = <Widget>[];
// If the user has selected files, don't show any actions
if (widget.selectedFiles.files.isNotEmpty ||
!Configuration.instance.hasConfiguredAccount()) {
return actions;
}
final List<PopupMenuItem<PeoplPopupAction>> items = [];
items.addAll(
[
PopupMenuItem(
value: PeoplPopupAction.rename,
child: Row(
children: [
const Icon(Icons.edit),
const Padding(
padding: EdgeInsets.all(8),
),
Text(S.of(context).rename),
],
),
),
// PopupMenuItem(
// value: PeoplPopupAction.setCover,
// child: Row(
// children: [
// const Icon(Icons.image_outlined),
// const Padding(
// padding: EdgeInsets.all(8),
// ),
// Text(S.of(context).setCover),
// ],
// ),
// ),
// PopupMenuItem(
// value: PeoplPopupAction.rename,
// child: Row(
// children: [
// const Icon(Icons.visibility_off),
// const Padding(
// padding: EdgeInsets.all(8),
// ),
// Text(S.of(context).hide),
// ],
// ),
// ),
const PopupMenuItem(
value: PeoplPopupAction.viewPhotos,
child: Row(
children: [
Icon(Icons.view_array_outlined),
Padding(
padding: EdgeInsets.all(8),
),
Text('View confirmed photos'),
],
),
),
const PopupMenuItem(
value: PeoplPopupAction.confirmPhotos,
child: Row(
children: [
Icon(CupertinoIcons.square_stack_3d_down_right),
Padding(
padding: EdgeInsets.all(8),
),
Text('Review suggestions'),
],
),
),
],
);
if (items.isNotEmpty) {
actions.add(
PopupMenuButton(
itemBuilder: (context) {
return items;
},
onSelected: (PeoplPopupAction value) async {
if (value == PeoplPopupAction.viewPhotos) {
// ignore: unawaited_futures
unawaited(
Navigator.of(context).push(
MaterialPageRoute(
builder: (context) => PersonClusters(widget.person),
),
),
);
} else if (value == PeoplPopupAction.confirmPhotos) {
// ignore: unawaited_futures
unawaited(
Navigator.of(context).push(
MaterialPageRoute(
builder: (context) =>
PersonReviewClusterSuggestion(widget.person),
),
),
);
} else if (value == PeoplPopupAction.rename) {
await _renameAlbum(context);
} else if (value == PeoplPopupAction.setCover) {
await setCoverPhoto(context);
} else if (value == PeoplPopupAction.hide) {
// ignore: unawaited_futures
}
},
),
);
}
return actions;
}
Future<void> setCoverPhoto(BuildContext context) async {
// final int? coverPhotoID = await showPickCoverPhotoSheet(
// context,
// widget.collection!,
// );
// if (coverPhotoID != null) {
// unawaited(changeCoverPhoto(context, widget.collection!, coverPhotoID));
// }
}
}

View file

@ -0,0 +1,155 @@
import "dart:async";
import "dart:developer";
import 'package:flutter/material.dart';
import "package:logging/logging.dart";
import 'package:photos/core/event_bus.dart';
import 'package:photos/events/files_updated_event.dart';
import 'package:photos/events/local_photos_updated_event.dart';
import "package:photos/events/people_changed_event.dart";
import "package:photos/face/model/person.dart";
import 'package:photos/models/file/file.dart';
import 'package:photos/models/file_load_result.dart';
import 'package:photos/models/gallery_type.dart';
import 'package:photos/models/selected_files.dart';
import "package:photos/services/search_service.dart";
import 'package:photos/ui/viewer/actions/file_selection_overlay_bar.dart';
import 'package:photos/ui/viewer/gallery/gallery.dart';
import "package:photos/ui/viewer/people/people_app_bar.dart";
class PeoplePage extends StatefulWidget {
final String tagPrefix;
final Person person;
static const GalleryType appBarType = GalleryType.peopleTag;
static const GalleryType overlayType = GalleryType.peopleTag;
const PeoplePage({
this.tagPrefix = "",
required this.person,
Key? key,
}) : super(key: key);
@override
State<PeoplePage> createState() => _PeoplePageState();
}
class _PeoplePageState extends State<PeoplePage> {
final Logger _logger = Logger("_PeoplePageState");
final _selectedFiles = SelectedFiles();
List<EnteFile>? files;
late final StreamSubscription<LocalPhotosUpdatedEvent> _filesUpdatedEvent;
late final StreamSubscription<PeopleChangedEvent> _peopleChangedEvent;
@override
void initState() {
super.initState();
_peopleChangedEvent = Bus.instance.on<PeopleChangedEvent>().listen((event) {
setState(() {});
});
_filesUpdatedEvent =
Bus.instance.on<LocalPhotosUpdatedEvent>().listen((event) {
if (event.type == EventType.deletedFromDevice ||
event.type == EventType.deletedFromEverywhere ||
event.type == EventType.deletedFromRemote ||
event.type == EventType.hide) {
for (var updatedFile in event.updatedFiles) {
files?.remove(updatedFile);
}
setState(() {});
}
});
}
Future<List<EnteFile>> loadPersonFiles() async {
log("loadPersonFiles");
final result = await SearchService.instance
.getClusterFilesForPersonID(widget.person.remoteID);
final List<EnteFile> resultFiles = [];
for (final e in result.entries) {
resultFiles.addAll(e.value);
}
files = resultFiles;
return resultFiles;
}
@override
void dispose() {
_filesUpdatedEvent.cancel();
_peopleChangedEvent.cancel();
super.dispose();
}
@override
Widget build(BuildContext context) {
_logger.info("Building for ${widget.person.attr.name}");
return Scaffold(
appBar: PreferredSize(
preferredSize: const Size.fromHeight(50.0),
child: PeopleAppBar(
GalleryType.peopleTag,
widget.person.attr.name,
_selectedFiles,
widget.person,
),
),
body: Stack(
alignment: Alignment.bottomCenter,
children: [
FutureBuilder<List<EnteFile>>(
future: loadPersonFiles(),
builder: (context, snapshot) {
if (snapshot.hasData) {
final personFiles = snapshot.data as List<EnteFile>;
return Gallery(
asyncLoader: (
creationStartTime,
creationEndTime, {
limit,
asc,
}) async {
final result = await loadPersonFiles();
return Future.value(
FileLoadResult(
result,
false,
),
);
},
reloadEvent: Bus.instance.on<LocalPhotosUpdatedEvent>(),
forceReloadEvents: [
Bus.instance.on<PeopleChangedEvent>(),
],
removalEventTypes: const {
EventType.deletedFromRemote,
EventType.deletedFromEverywhere,
EventType.hide,
},
tagPrefix: widget.tagPrefix + widget.tagPrefix,
selectedFiles: _selectedFiles,
initialFiles:
personFiles.isNotEmpty ? [personFiles.first] : [],
);
} else if (snapshot.hasError) {
log("Error: ${snapshot.error} ${snapshot.stackTrace}}");
//Need to show an error on the UI here
return const SizedBox.shrink();
} else {
return const Center(
child: CircularProgressIndicator(),
);
}
},
),
FileSelectionOverlayBar(
PeoplePage.overlayType,
_selectedFiles,
person: widget.person,
),
],
),
);
}
}

Some files were not shown because too many files have changed in this diff Show more