Face small improvements (#1839)

## Description

- Fix the remote embeddings fetch by collecting file IDs in a `Set`, since duplicate IDs caused the server to return a 400
- Decrypt fetched embeddings in a background `Computer` isolate instead of on the main isolate (see the sketch below)
- Sort faces for clustering newest-first and remove the progress-ratio restrictions on clustering and face search
- Clean up the faces status page: drop the found-faces row and report clustering progress against indexable files
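
The second bullet refers to moving embedding decryption off the UI isolate with the `computer` package; only `Computer.shared()` and `compute(..., param: ...)` are taken from the diff below. The following is a minimal, self-contained sketch of that pattern, not the actual `RemoteFileMLService` code: the input/output types, the `turnOn`/`turnOff` setup, and the stand-in "decryption" body are illustrative assumptions.

```dart
import "dart:convert";

import "package:computer/computer.dart";

/// Isolate entry point: it must be a top-level (or static) function so the
/// `computer` package can hand it to a worker isolate.
Future<Map<int, String>> _decryptEmbeddingsInIsolate(
  Map<String, dynamic> args,
) async {
  final Map<int, List<int>> inputs = args["inputs"] as Map<int, List<int>>;
  final result = <int, String>{};
  inputs.forEach((fileID, bytes) {
    // Stand-in for the real work: the service decrypts and JSON-decodes the
    // embedding payload per file; here the bytes are simply UTF-8 decoded.
    result[fileID] = utf8.decode(bytes);
  });
  return result;
}

/// Runs the decryption on a worker isolate and returns its result.
Future<Map<int, String>> decryptOffMainIsolate(Map<int, List<int>> inputs) {
  return Computer.shared().compute<Map<String, dynamic>, Map<int, String>>(
    _decryptEmbeddingsInIsolate,
    param: {"inputs": inputs},
  );
}

Future<void> main() async {
  // Assumed setup: spin up a small worker pool before the first compute call.
  await Computer.shared().turnOn(workersCount: 2);
  final decrypted = await decryptOffMainIsolate({
    1: utf8.encode("fake-embedding-bytes"),
  });
  print(decrypted); // {1: fake-embedding-bytes}
  await Computer.shared().turnOff();
}
```

The entry point has to be sendable to another isolate, which is presumably why the diff renames `decryptFileMLComputer` to `_decryptFileMLComputer` and moves it out of the service class body.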


## Tests

Tested in debug mode on a Pixel phone.
Commit 776dba4fb0 by Neeraj Gupta on 2024-05-24 12:52:41 +05:30, committed via GitHub
11 changed files with 105 additions and 127 deletions

View file

@ -13,6 +13,8 @@ import "package:photos/face/model/face.dart";
import "package:photos/models/file/file.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
import "package:photos/utils/ml_util.dart";
import 'package:sqlite_async/sqlite_async.dart';
/// Stores all data for the FacesML-related features. The database can be accessed by `FaceMLDataDB.instance.database`.
@ -249,7 +251,7 @@ class FaceMLDataDB {
final List<int> fileId = [recentFileID];
int? avatarFileId;
if (avatarFaceId != null) {
avatarFileId = int.tryParse(avatarFaceId.split('_')[0]);
avatarFileId = tryGetFileIdFromFaceId(avatarFaceId);
if (avatarFileId != null) {
fileId.add(avatarFileId);
}
@ -401,8 +403,10 @@ class FaceMLDataDB {
final personID = map[personIdColumn] as String;
final clusterID = map[fcClusterID] as int;
final faceID = map[fcFaceId] as String;
result.putIfAbsent(personID, () => {}).putIfAbsent(clusterID, () => {})
.add(faceID);
result
.putIfAbsent(personID, () => {})
.putIfAbsent(clusterID, () => {})
.add(faceID);
}
return result;
}
@ -476,8 +480,7 @@ class FaceMLDataDB {
for (final map in maps) {
final clusterID = map[fcClusterID] as int;
final faceID = map[fcFaceId] as String;
final x = faceID.split('_').first;
final fileID = int.parse(x);
final fileID = getFileIdFromFaceId(faceID);
result[fileID] = (result[fileID] ?? {})..add(clusterID);
}
return result;
@ -665,19 +668,38 @@ class FaceMLDataDB {
return maps.first['count'] as int;
}
Future<int> getClusteredFaceCount() async {
Future<int> getClusteredFileCount() async {
final db = await instance.asyncDB;
final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT COUNT(DISTINCT $fcFaceId) as count FROM $faceClustersTable',
'SELECT $fcFaceId FROM $faceClustersTable',
);
return maps.first['count'] as int;
final Set<int> fileIDs = {};
for (final map in maps) {
final int fileID = getFileIdFromFaceId(map[fcFaceId] as String);
fileIDs.add(fileID);
}
return fileIDs.length;
}
Future<double> getClusteredToTotalFacesRatio() async {
final int totalFaces = await getTotalFaceCount();
final int clusteredFaces = await getClusteredFaceCount();
Future<double> getClusteredToIndexableFilesRatio() async {
final int indexableFiles = (await getIndexableFileIDs()).length;
final int clusteredFiles = await getClusteredFileCount();
return clusteredFaces / totalFaces;
return clusteredFiles / indexableFiles;
}
Future<int> getUnclusteredFaceCount() async {
final db = await instance.asyncDB;
const String query = '''
SELECT f.$faceIDColumn
FROM $facesTable f
LEFT JOIN $faceClustersTable fc ON f.$faceIDColumn = fc.$fcFaceId
WHERE f.$faceScore > $kMinimumQualityFaceScore
AND f.$faceBlur > $kLaplacianHardThreshold
AND fc.$fcFaceId IS NULL
''';
final List<Map<String, dynamic>> maps = await db.getAll(query);
return maps.length;
}
Future<int> getBlurryFaceCount([
@ -795,7 +817,7 @@ class FaceMLDataDB {
for (final map in maps) {
final clusterID = map[clusterIDColumn] as int;
final String faceID = map[fcFaceId] as String;
final fileID = int.parse(faceID.split('_').first);
final fileID = getFileIdFromFaceId(faceID);
result[fileID] = (result[fileID] ?? {})..add(clusterID);
}
return result;
@ -814,8 +836,8 @@ class FaceMLDataDB {
final Map<int, Set<int>> result = {};
for (final map in maps) {
final clusterID = map[fcClusterID] as int;
final faceId = map[fcFaceId] as String;
final fileID = int.parse(faceId.split("_").first);
final faceID = map[fcFaceId] as String;
final fileID = getFileIdFromFaceId(faceID);
result[fileID] = (result[fileID] ?? {})..add(clusterID);
}
return result;
@ -964,7 +986,7 @@ class FaceMLDataDB {
final Map<String, int> faceIDToClusterID = {};
for (final row in faceIdsResult) {
final faceID = row[fcFaceId] as String;
if (fileIds.contains(faceID.split('_').first)) {
if (fileIds.contains(getFileIdFromFaceId(faceID))) {
maxClusterID += 1;
faceIDToClusterID[faceID] = maxClusterID;
}
@ -990,7 +1012,7 @@ class FaceMLDataDB {
final Map<String, int> faceIDToClusterID = {};
for (final row in faceIdsResult) {
final faceID = row[fcFaceId] as String;
if (fileIds.contains(faceID.split('_').first)) {
if (fileIds.contains(getFileIdFromFaceId(faceID))) {
maxClusterID += 1;
faceIDToClusterID[faceID] = maxClusterID;
}

View file

@ -498,19 +498,8 @@ class FaceClusteringService {
}
}
// Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
if (fileIDToCreationTime != null) {
faceInfos.sort((a, b) {
if (a.fileCreationTime == null && b.fileCreationTime == null) {
return 0;
} else if (a.fileCreationTime == null) {
return 1;
} else if (b.fileCreationTime == null) {
return -1;
} else {
return a.fileCreationTime!.compareTo(b.fileCreationTime!);
}
});
_sortFaceInfosOnCreationTime(faceInfos);
}
// Sort the faceInfos such that the ones with null clusterId are at the end
@ -796,19 +785,8 @@ class FaceClusteringService {
);
}
// Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
if (fileIDToCreationTime != null) {
faceInfos.sort((a, b) {
if (a.fileCreationTime == null && b.fileCreationTime == null) {
return 0;
} else if (a.fileCreationTime == null) {
return 1;
} else if (b.fileCreationTime == null) {
return -1;
} else {
return a.fileCreationTime!.compareTo(b.fileCreationTime!);
}
});
_sortFaceInfosOnCreationTime(faceInfos);
}
if (faceInfos.isEmpty) {
@ -996,19 +974,8 @@ class FaceClusteringService {
);
}
// Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
if (fileIDToCreationTime != null) {
faceInfos.sort((a, b) {
if (a.fileCreationTime == null && b.fileCreationTime == null) {
return 0;
} else if (a.fileCreationTime == null) {
return 1;
} else if (b.fileCreationTime == null) {
return -1;
} else {
return a.fileCreationTime!.compareTo(b.fileCreationTime!);
}
});
_sortFaceInfosOnCreationTime(faceInfos);
}
// Get the embeddings
@ -1027,3 +994,20 @@ class FaceClusteringService {
return clusteredFaceIDs;
}
}
/// Sort the faceInfos based on fileCreationTime, in descending order, so newest faces are first
void _sortFaceInfosOnCreationTime(
List<FaceInfo> faceInfos,
) {
faceInfos.sort((b, a) {
if (a.fileCreationTime == null && b.fileCreationTime == null) {
return 0;
} else if (a.fileCreationTime == null) {
return 1;
} else if (b.fileCreationTime == null) {
return -1;
} else {
return a.fileCreationTime!.compareTo(b.fileCreationTime!);
}
});
}

View file

@ -310,5 +310,9 @@ class FaceResultBuilder {
}
int getFileIdFromFaceId(String faceId) {
return int.parse(faceId.split("_")[0]);
return int.parse(faceId.split("_").first);
}
int? tryGetFileIdFromFaceId(String faceId) {
return int.tryParse(faceId.split("_").first);
}

View file

@ -12,7 +12,6 @@ import "package:flutter/foundation.dart" show debugPrint, kDebugMode;
import "package:logging/logging.dart";
import "package:onnxruntime/onnxruntime.dart";
import "package:package_info_plus/package_info_plus.dart";
import "package:photos/core/configuration.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/db/files_db.dart";
import "package:photos/events/diff_sync_complete_event.dart";
@ -99,6 +98,7 @@ class FaceMlService {
final int _fileDownloadLimit = 5;
final int _embeddingFetchLimit = 200;
final int _kForceClusteringFaceCount = 4000;
Future<void> init({bool initializeImageMlIsolate = false}) async {
if (LocalSettings.instance.isFaceIndexingEnabled == false) {
@ -358,16 +358,17 @@ class FaceMlService {
if (_cannotRunMLFunction()) return;
await sync(forceSync: _shouldSyncPeople);
await indexAllImages();
final indexingCompleteRatio = await _getIndexedDoneRatio();
if (indexingCompleteRatio < 0.95) {
final int unclusteredFacesCount =
await FaceMLDataDB.instance.getUnclusteredFaceCount();
if (unclusteredFacesCount > _kForceClusteringFaceCount) {
_logger.info(
"Indexing is not far enough to start clustering, skipping clustering. Indexing is at $indexingCompleteRatio",
"There are $unclusteredFacesCount unclustered faces, doing clustering first",
);
return;
} else {
await clusterAllImages();
}
await indexAllImages();
await clusterAllImages();
}
void pauseIndexingAndClustering() {
@ -445,7 +446,7 @@ class FaceMlService {
if (LocalSettings.instance.remoteFetchEnabled) {
try {
final List<int> fileIds = [];
final Set<int> fileIds = {}; // if there are duplicates here server returns 400
// Try to find embeddings on the remote server
for (final f in chunk) {
fileIds.add(f.uploadedFileID!);
@ -590,8 +591,8 @@ class FaceMlService {
allFaceInfoForClustering.add(faceInfo);
}
}
// sort the embeddings based on file creation time, oldest first
allFaceInfoForClustering.sort((a, b) {
// sort the embeddings based on file creation time, newest first
allFaceInfoForClustering.sort((b, a) {
return fileIDToCreationTime[a.fileID]!
.compareTo(fileIDToCreationTime[b.fileID]!);
});
@ -1171,24 +1172,6 @@ class FaceMlService {
}
}
Future<double> _getIndexedDoneRatio() async {
final w = (kDebugMode ? EnteWatch('_getIndexedDoneRatio') : null)?..start();
final int alreadyIndexedCount = await FaceMLDataDB.instance
.getIndexedFileCount(minimumMlVersion: faceMlVersion);
final int totalIndexableCount = (await getIndexableFileIDs()).length;
final ratio = alreadyIndexedCount / totalIndexableCount;
w?.log('getIndexedDoneRatio');
return ratio;
}
static Future<List<int>> getIndexableFileIDs() async {
return FilesDB.instance
.getOwnedFileIDs(Configuration.instance.getUserID()!);
}
bool _skipAnalysisEnteFile(EnteFile enteFile, Map<int, int> indexedFileIds) {
if (_isIndexingOrClusteringRunning == false ||
_mlControllerStatus == false) {

View file

@ -1,6 +1,7 @@
import "dart:async";
import "dart:convert";
import "package:computer/computer.dart";
import "package:logging/logging.dart";
import "package:photos/core/network/network.dart";
import "package:photos/db/files_db.dart";
@ -16,6 +17,8 @@ import "package:shared_preferences/shared_preferences.dart";
class RemoteFileMLService {
RemoteFileMLService._privateConstructor();
static final Computer _computer = Computer.shared();
static final RemoteFileMLService instance =
RemoteFileMLService._privateConstructor();
@ -52,13 +55,13 @@ class RemoteFileMLService {
}
Future<FilesMLDataResponse> getFilessEmbedding(
List<int> fileIds,
Set<int> fileIds,
) async {
try {
final res = await _dio.post(
"/embeddings/files",
data: {
"fileIDs": fileIds,
"fileIDs": fileIds.toList(),
"model": 'file-ml-clip-face',
},
);
@ -107,15 +110,17 @@ class RemoteFileMLService {
final input = EmbeddingsDecoderInput(embedding, fileKey);
inputs.add(input);
}
// todo: use compute or isolate
return decryptFileMLComputer(
{
return _computer.compute<Map<String, dynamic>, Map<int, FileMl>>(
_decryptFileMLComputer,
param: {
"inputs": inputs,
},
);
}
Future<Map<int, FileMl>> decryptFileMLComputer(
}
Future<Map<int, FileMl>> _decryptFileMLComputer(
Map<String, dynamic> args,
) async {
final result = <int, FileMl>{};
@ -134,5 +139,4 @@ class RemoteFileMLService {
result[input.embedding.fileID] = decodedEmbedding;
}
return result;
}
}
}

View file

@ -23,6 +23,7 @@ import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx
import "package:photos/utils/debouncer.dart";
import "package:photos/utils/device_info.dart";
import "package:photos/utils/local_settings.dart";
import "package:photos/utils/ml_util.dart";
import "package:photos/utils/thumbnail_util.dart";
class SemanticSearchService {
@ -160,8 +161,7 @@ class SemanticSearchService {
}
Future<IndexStatus> getIndexStatus() async {
final indexableFileIDs = await FilesDB.instance
.getOwnedFileIDs(Configuration.instance.getUserID()!);
final indexableFileIDs = await getIndexableFileIDs();
return IndexStatus(
min(_cachedEmbeddings.length, indexableFileIDs.length),
(await _getFileIDsToBeIndexed()).length,
@ -222,8 +222,7 @@ class SemanticSearchService {
}
Future<List<int>> _getFileIDsToBeIndexed() async {
final uploadedFileIDs = await FilesDB.instance
.getOwnedFileIDs(Configuration.instance.getUserID()!);
final uploadedFileIDs = await getIndexableFileIDs();
final embeddedFileIDs =
await EmbeddingsDB.instance.getFileIDs(_currentModel);

View file

@ -754,15 +754,6 @@ class SearchService {
Future<List<GenericSearchResult>> getAllFace(int? limit) async {
try {
// Don't return anything if clustering is not nearly complete yet
final foundFaces = await FaceMLDataDB.instance.getTotalFaceCount();
final clusteredFaces =
await FaceMLDataDB.instance.getClusteredFaceCount();
final clusteringDoneRatio = clusteredFaces / foundFaces;
if (clusteringDoneRatio < 0.9) {
return [];
}
debugPrint("getting faces");
final Map<int, Set<int>> fileIdToClusterID =
await FaceMLDataDB.instance.getFileIdToClusterIds();

View file

@ -177,7 +177,7 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
sectionOptionSpacing,
MenuItemWidget(
captionedTextWidget: FutureBuilder<double>(
future: FaceMLDataDB.instance.getClusteredToTotalFacesRatio(),
future: FaceMLDataDB.instance.getClusteredToIndexableFilesRatio(),
builder: (context, snapshot) {
if (snapshot.hasData) {
return CaptionedTextWidget(

View file

@ -26,6 +26,7 @@ import "package:photos/ui/components/title_bar_widget.dart";
import "package:photos/ui/components/toggle_switch_widget.dart";
import "package:photos/utils/data_util.dart";
import "package:photos/utils/local_settings.dart";
import "package:photos/utils/ml_util.dart";
final _logger = Logger("MachineLearningSettingsPage");
@ -438,19 +439,16 @@ class FaceRecognitionStatusWidgetState
});
}
Future<(int, int, int, double)> getIndexStatus() async {
Future<(int, int, double)> getIndexStatus() async {
try {
final indexedFiles = await FaceMLDataDB.instance
.getIndexedFileCount(minimumMlVersion: faceMlVersion);
final indexableFiles = (await FaceMlService.getIndexableFileIDs()).length;
final indexableFiles = (await getIndexableFileIDs()).length;
final showIndexedFiles = min(indexedFiles, indexableFiles);
final pendingFiles = max(indexableFiles - indexedFiles, 0);
final foundFaces = await FaceMLDataDB.instance.getTotalFaceCount();
final clusteredFaces =
await FaceMLDataDB.instance.getClusteredFaceCount();
final clusteringDoneRatio = clusteredFaces / foundFaces;
final clusteringDoneRatio = await FaceMLDataDB.instance.getClusteredToIndexableFilesRatio();
return (showIndexedFiles, pendingFiles, foundFaces, clusteringDoneRatio);
return (showIndexedFiles, pendingFiles, clusteringDoneRatio);
} catch (e, s) {
_logger.severe('Error getting face recognition status', e, s);
rethrow;
@ -479,8 +477,7 @@ class FaceRecognitionStatusWidgetState
if (snapshot.hasData) {
final int indexedFiles = snapshot.data!.$1;
final int pendingFiles = snapshot.data!.$2;
final int foundFaces = snapshot.data!.$3;
final double clusteringDoneRatio = snapshot.data!.$4;
final double clusteringDoneRatio = snapshot.data!.$3;
final double clusteringPercentage =
(clusteringDoneRatio * 100).clamp(0, 100);
@ -512,19 +509,6 @@ class FaceRecognitionStatusWidgetState
isGestureDetectorDisabled: true,
key: ValueKey("pending_items_" + pendingFiles.toString()),
),
MenuItemWidget(
captionedTextWidget: CaptionedTextWidget(
title: S.of(context).foundFaces,
),
trailingWidget: Text(
NumberFormat().format(foundFaces),
style: Theme.of(context).textTheme.bodySmall,
),
singleBorderRadius: 8,
alignCaptionedTextToLeft: true,
isGestureDetectorDisabled: true,
key: ValueKey("found_faces_" + foundFaces.toString()),
),
MenuItemWidget(
captionedTextWidget: CaptionedTextWidget(
title: S.of(context).clusteringProgress,

View file

@ -0,0 +1,7 @@
import "package:photos/core/configuration.dart";
import "package:photos/db/files_db.dart";
Future<List<int>> getIndexableFileIDs() async {
return FilesDB.instance
.getOwnedFileIDs(Configuration.instance.getUserID()!);
}

View file

@ -12,7 +12,7 @@ description: ente photos application
# Read more about iOS versioning at
# https://developer.apple.com/library/archive/documentation/General/Reference/InfoPlistKeyReference/Articles/CoreFoundationKeys.html
version: 0.8.109+633
version: 0.8.110+634
publish_to: none
environment: