Merge branch 'mobile_face' of https://github.com/ente-io/auth into mobile_face

This commit is contained in:
Neeraj Gupta 2024-04-12 15:53:34 +05:30
commit fbec7db865
4 changed files with 145 additions and 68 deletions

View file

@ -1,4 +1,5 @@
import 'dart:async';
import "dart:io" show Directory;
import "dart:math";
import "package:collection/collection.dart";
@ -13,6 +14,7 @@ import "package:photos/face/model/face.dart";
import "package:photos/models/file/file.dart";
import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
import 'package:sqflite/sqflite.dart';
import 'package:sqlite_async/sqlite_async.dart' as sqlite_async;
/// Stores all data for the ML-related features. The database can be accessed by `MlDataDB.instance.database`.
///
@ -29,13 +31,20 @@ class FaceMLDataDB {
static final FaceMLDataDB instance = FaceMLDataDB._privateConstructor();
// only have a single app-wide reference to the database
static Future<Database>? _dbFuture;
static Future<sqlite_async.SqliteDatabase>? _sqliteAsyncDBFuture;
Future<Database> get database async {
_dbFuture ??= _initDatabase();
return _dbFuture!;
}
Future<sqlite_async.SqliteDatabase> get sqliteAsyncDB async {
_sqliteAsyncDBFuture ??= _initSqliteAsyncDatabase();
return _sqliteAsyncDBFuture!;
}
Future<Database> _initDatabase() async {
final documentsDirectory = await getApplicationDocumentsDirectory();
final String databaseDirectory =
@ -47,6 +56,15 @@ class FaceMLDataDB {
);
}
Future<sqlite_async.SqliteDatabase> _initSqliteAsyncDatabase() async {
final Directory documentsDirectory =
await getApplicationDocumentsDirectory();
final String databaseDirectory =
join(documentsDirectory.path, _databaseName);
_logger.info("Opening sqlite_async access: DB path " + databaseDirectory);
return sqlite_async.SqliteDatabase(path: databaseDirectory, maxReaders: 1);
}
Future _onCreate(Database db, int version) async {
await db.execute(createFacesTable);
await db.execute(createFaceClustersTable);
@ -107,8 +125,8 @@ class FaceMLDataDB {
/// Returns a map of fileID to the indexed ML version
Future<Map<int, int>> getIndexedFileIds() async {
final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.rawQuery(
final db = await instance.sqliteAsyncDB;
final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT $fileIDColumn, $mlVersionColumn FROM $facesTable',
);
final Map<int, int> result = {};
@ -398,18 +416,22 @@ class FaceMLDataDB {
w.logAndReset(
'reading as float offset: $offset, maxFaces: $maxFaces, batchSize: $batchSize',
);
final db = await instance.database;
final db = await instance.sqliteAsyncDB;
final Map<String, (int?, Uint8List)> result = {};
while (true) {
// Query a batch of rows
final List<Map<String, dynamic>> maps = await db.query(
facesTable,
columns: [faceIDColumn, faceEmbeddingBlob],
where: '$faceScore > $minScore and $faceBlur > $minClarity',
limit: batchSize,
offset: offset,
orderBy: '$faceIDColumn DESC',
final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT $faceIDColumn, $faceEmbeddingBlob FROM $facesTable'
' WHERE $faceScore > $minScore AND $faceBlur > $minClarity'
' ORDER BY $faceIDColumn'
' DESC LIMIT $batchSize OFFSET $offset',
// facesTable,
// columns: [faceIDColumn, faceEmbeddingBlob],
// where: '$faceScore > $minScore and $faceBlur > $minClarity',
// limit: batchSize,
// offset: offset,
// orderBy: '$faceIDColumn DESC',
);
// Break the loop if no more rows
if (maps.isEmpty) {
@ -476,8 +498,8 @@ class FaceMLDataDB {
Future<int> getTotalFaceCount({
double minFaceScore = kMinHighQualityFaceScore,
}) async {
final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.rawQuery(
final db = await instance.sqliteAsyncDB;
final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $minFaceScore AND $faceBlur > $kLaplacianThreshold',
);
return maps.first['count'] as int;

View file

@ -36,7 +36,7 @@ class FaceClustering {
final _logger = Logger("FaceLinearClustering");
Timer? _inactivityTimer;
final Duration _inactivityDuration = const Duration(seconds: 30);
final Duration _inactivityDuration = const Duration(seconds: 90);
int _activeTasks = 0;
final _initLock = Lock();

View file

@ -420,11 +420,15 @@ class FaceMlService {
offset += offsetIncrement;
}
} else {
final int totalFaces = await FaceMLDataDB.instance
.getTotalFaceCount(minFaceScore: minFaceScore);
// Read all the embeddings from the database, in a map from faceID to embedding
final clusterStartTime = DateTime.now();
final faceIdToEmbedding =
await FaceMLDataDB.instance.getFaceEmbeddingMap(
minScore: minFaceScore,
maxFaces: totalFaces,
);
final gotFaceEmbeddingsTime = DateTime.now();
_logger.info(
@ -514,7 +518,7 @@ class FaceMlService {
/// Analyzes all the images in the database with the latest ml version and stores the results in the database.
///
/// This function first checks if the image has already been analyzed with the lastest faceMlVersion and stored in the database. If so, it skips the image.
Future<void> indexAllImages() async {
Future<void> indexAllImages({withFetching = true}) async {
if (isImageIndexRunning) {
_logger.warning("indexAllImages is already running, skipping");
return;
@ -566,44 +570,48 @@ class FaceMlService {
for (final f in chunk) {
fileIds.add(f.uploadedFileID!);
}
try {
final EnteWatch? w = kDebugMode ? EnteWatch("face_em_fetch") : null;
w?.start();
w?.log('starting remote fetch for ${fileIds.length} files');
final res =
await RemoteFileMLService.instance.getFilessEmbedding(fileIds);
w?.logAndReset('fetched ${res.mlData.length} embeddings');
final List<Face> faces = [];
final remoteFileIdToVersion = <int, int>{};
for (FileMl fileMl in res.mlData.values) {
if (shouldDiscardRemoteEmbedding(fileMl)) continue;
if (fileMl.faceEmbedding.faces.isEmpty) {
faces.add(
Face.empty(
fileMl.fileID,
error: (fileMl.faceEmbedding.error ?? false),
),
);
} else {
for (final f in fileMl.faceEmbedding.faces) {
f.fileInfo = FileInfo(
imageHeight: fileMl.height,
imageWidth: fileMl.width,
if (withFetching) {
try {
final EnteWatch? w = kDebugMode ? EnteWatch("face_em_fetch") : null;
w?.start();
w?.log('starting remote fetch for ${fileIds.length} files');
final res =
await RemoteFileMLService.instance.getFilessEmbedding(fileIds);
w?.logAndReset('fetched ${res.mlData.length} embeddings');
final List<Face> faces = [];
final remoteFileIdToVersion = <int, int>{};
for (FileMl fileMl in res.mlData.values) {
if (shouldDiscardRemoteEmbedding(fileMl)) continue;
if (fileMl.faceEmbedding.faces.isEmpty) {
faces.add(
Face.empty(
fileMl.fileID,
error: (fileMl.faceEmbedding.error ?? false),
),
);
faces.add(f);
} else {
for (final f in fileMl.faceEmbedding.faces) {
f.fileInfo = FileInfo(
imageHeight: fileMl.height,
imageWidth: fileMl.width,
);
faces.add(f);
}
}
remoteFileIdToVersion[fileMl.fileID] =
fileMl.faceEmbedding.version;
}
remoteFileIdToVersion[fileMl.fileID] = fileMl.faceEmbedding.version;
await FaceMLDataDB.instance.bulkInsertFaces(faces);
w?.logAndReset('stored embeddings');
for (final entry in remoteFileIdToVersion.entries) {
alreadyIndexedFiles[entry.key] = entry.value;
}
_logger
.info('already indexed files ${remoteFileIdToVersion.length}');
} catch (e, s) {
_logger.severe("err while getting files embeddings", e, s);
rethrow;
}
await FaceMLDataDB.instance.bulkInsertFaces(faces);
w?.logAndReset('stored embeddings');
for (final entry in remoteFileIdToVersion.entries) {
alreadyIndexedFiles[entry.key] = entry.value;
}
_logger.info('already indexed files ${remoteFileIdToVersion.length}');
} catch (e, s) {
_logger.severe("err while getting files embeddings", e, s);
rethrow;
}
for (final enteFile in chunk) {

View file

@ -5,10 +5,8 @@ import 'package:flutter/material.dart';
import "package:logging/logging.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/events/people_changed_event.dart";
import "package:photos/extensions/stop_watch.dart";
import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import 'package:photos/services/machine_learning/face_ml/face_ml_service.dart';
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
import 'package:photos/theme/ente_theme.dart';
@ -65,7 +63,7 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
if (snapshot.hasData) {
return CaptionedTextWidget(
title: LocalSettings.instance.isFaceIndexingEnabled
? "Disable Indexing (${snapshot.data!.length})"
? "Disable indexing (${snapshot.data!.length})"
: "Enable indexing (${snapshot.data!.length})",
);
}
@ -94,12 +92,14 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
},
),
MenuItemWidget(
captionedTextWidget: FutureBuilder<int>(
future: FaceMLDataDB.instance.getTotalFaceCount(),
captionedTextWidget: FutureBuilder<Map<int, int>>(
future: FaceMLDataDB.instance.getIndexedFileIds(),
builder: (context, snapshot) {
if (snapshot.hasData) {
return CaptionedTextWidget(
title: "${snapshot.data!} high quality faces",
title: LocalSettings.instance.isFaceIndexingEnabled
? "Disable indexing (no fetch) (${snapshot.data!.length})"
: "Enable indexing (${snapshot.data!.length})",
);
}
return const SizedBox.shrink();
@ -109,16 +109,51 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
trailingIcon: Icons.chevron_right_outlined,
trailingIconIsMuted: true,
onTap: () async {
final faces75 = await FaceMLDataDB.instance
.getTotalFaceCount(minFaceScore: 0.75);
final faces78 = await FaceMLDataDB.instance
.getTotalFaceCount(minFaceScore: kMinHighQualityFaceScore);
final blurryFaceCount =
await FaceMLDataDB.instance.getBlurryFaceCount(15);
showShortToast(context, "$blurryFaceCount blurry faces");
try {
final isEnabled =
await LocalSettings.instance.toggleFaceIndexing();
if (isEnabled) {
FaceMlService.instance
.indexAllImages(withFetching: false)
.ignore();
} else {
FaceMlService.instance.pauseIndexing();
}
if (mounted) {
setState(() {});
}
} catch (e, s) {
_logger.warning('indexing failed ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
},
),
// MenuItemWidget(
// captionedTextWidget: FutureBuilder<int>(
// future: FaceMLDataDB.instance.getTotalFaceCount(),
// builder: (context, snapshot) {
// if (snapshot.hasData) {
// return CaptionedTextWidget(
// title: "${snapshot.data!} high quality faces",
// );
// }
// return const SizedBox.shrink();
// },
// ),
// pressedColor: getEnteColorScheme(context).fillFaint,
// trailingIcon: Icons.chevron_right_outlined,
// trailingIconIsMuted: true,
// onTap: () async {
// final faces75 = await FaceMLDataDB.instance
// .getTotalFaceCount(minFaceScore: 0.75);
// final faces78 = await FaceMLDataDB.instance
// .getTotalFaceCount(minFaceScore: kMinHighQualityFaceScore);
// final blurryFaceCount =
// await FaceMLDataDB.instance.getBlurryFaceCount(15);
// showShortToast(context, "$blurryFaceCount blurry faces");
// },
// ),
// MenuItemWidget(
// captionedTextWidget: const CaptionedTextWidget(
// title: "Analyze file ID 25728869",
// ),
@ -296,13 +331,25 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
trailingIcon: Icons.chevron_right_outlined,
trailingIconIsMuted: true,
onTap: () async {
final EnteWatch watch = EnteWatch("read_embeddings")..start();
final result = await FaceMLDataDB.instance.getFaceEmbeddingMap();
watch.logAndReset('read embeddings ${result.length} ');
showShortToast(
context,
"Read ${result.length} face embeddings in ${watch.elapsed.inSeconds} secs",
);
final int totalFaces =
await FaceMLDataDB.instance.getTotalFaceCount();
_logger.info('start reading embeddings for $totalFaces faces');
final time = DateTime.now();
try {
final result = await FaceMLDataDB.instance
.getFaceEmbeddingMap(maxFaces: totalFaces);
final endTime = DateTime.now();
_logger.info(
'Read embeddings of ${result.length} faces in ${time.difference(endTime).inSeconds} secs',
);
showShortToast(
context,
"Read embeddings of ${result.length} faces in ${time.difference(endTime).inSeconds} secs",
);
} catch (e, s) {
_logger.warning('read embeddings failed ', e, s);
await showGenericErrorDialog(context: context, error: e);
}
},
),
],