[mob] Sort clustering on fileCreationTime asc

This commit is contained in:
laurenspriem 2024-03-21 15:41:34 +05:30
parent a9fdee96a8
commit a2bca84b91
4 changed files with 63 additions and 11 deletions

View file

@ -1304,6 +1304,23 @@ class FilesDB {
return result;
}
Future<Map<int,int>> getFileIDToCreationTime() async {
final db = await instance.database;
final rows = await db.rawQuery(
'''
SELECT $columnUploadedFileID, $columnCreationTime
FROM $filesTable
WHERE
($columnUploadedFileID IS NOT NULL AND $columnUploadedFileID IS NOT -1);
''',
);
final result = <int, int>{};
for (final row in rows) {
result[row[columnUploadedFileID] as int] = row[columnCreationTime] as int;
}
return result;
}
// getCollectionFileFirstOrLast returns the first or last uploaded file in
// the collection based on the given collectionID and the order.
Future<EnteFile?> getCollectionFileFirstOrLast(

View file

@ -94,7 +94,12 @@ class FaceLinearClustering {
switch (function) {
case ClusterOperation.linearIncrementalClustering:
final input = args['input'] as Map<String, (int?, Uint8List)>;
final result = FaceLinearClustering._runLinearClustering(input);
final fileIDToCreationTime =
args['fileIDToCreationTime'] as Map<int, int>?;
final result = FaceLinearClustering._runLinearClustering(
input,
fileIDToCreationTime: fileIDToCreationTime,
);
sendPort.send(result);
break;
}
@ -169,8 +174,9 @@ class FaceLinearClustering {
///
/// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can less less deterministic.
Future<Map<String, int>?> predict(
Map<String, (int?, Uint8List)> input,
) async {
Map<String, (int?, Uint8List)> input, {
Map<int, int>? fileIDToCreationTime,
}) async {
if (input.isEmpty) {
_logger.warning(
"Clustering dataset of embeddings is empty, returning empty list.",
@ -192,7 +198,10 @@ class FaceLinearClustering {
// final Map<String, int> faceIdToCluster =
// await _runLinearClusteringInComputer(input);
final Map<String, int> faceIdToCluster = await _runInIsolate(
(ClusterOperation.linearIncrementalClustering, {'input': input}),
(
ClusterOperation.linearIncrementalClustering,
{'input': input, 'fileIDToCreationTime': fileIDToCreationTime}
),
);
// return _runLinearClusteringInComputer(input);
_logger.info(
@ -205,8 +214,9 @@ class FaceLinearClustering {
}
static Map<String, int> _runLinearClustering(
Map<String, (int?, Uint8List)> x,
) {
Map<String, (int?, Uint8List)> x, {
Map<int, int>? fileIDToCreationTime,
}) {
log(
"[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces",
);
@ -217,9 +227,27 @@ class FaceLinearClustering {
faceID: entry.key,
embedding: EVector.fromBuffer(entry.value.$2).values,
clusterId: entry.value.$1,
fileCreationTime:
fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
),
);
}
// Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first
if (fileIDToCreationTime != null) {
faceInfos.sort((a, b) {
if (a.fileCreationTime == null && b.fileCreationTime == null) {
return 0;
} else if (a.fileCreationTime == null) {
return 1;
} else if (b.fileCreationTime == null) {
return -1;
} else {
return a.fileCreationTime!.compareTo(b.fileCreationTime!);
}
});
}
// Sort the faceInfos such that the ones with null clusterId are at the end
faceInfos.sort((a, b) {
if (a.clusterId == null && b.clusterId == null) {

View file

@ -37,7 +37,7 @@ class ClusterResult {
String get thumbnailFaceId => _thumbnailFaceId;
int get thumbnailFileId => _getFileIdFromFaceId(_thumbnailFaceId);
int get thumbnailFileId => getFileIdFromFaceId(_thumbnailFaceId);
/// Sets the thumbnail faceId to the given faceId.
/// Throws an exception if the faceId is not in the list of faceIds.
@ -89,7 +89,7 @@ class ClusterResult {
int removedCount = 0;
for (var i = 0; i < _fileIds.length; i++) {
if (_fileIds[i] == fileId) {
assert(_getFileIdFromFaceId(_faceIds[i]) == fileId);
assert(getFileIdFromFaceId(_faceIds[i]) == fileId);
_fileIds.removeAt(i);
_faceIds.removeAt(i);
debugPrint(
@ -748,6 +748,6 @@ class FaceResultBuilder {
}
}
int _getFileIdFromFaceId(String faceId) {
int getFileIdFromFaceId(String faceId) {
return int.parse(faceId.split("_")[0]);
}

View file

@ -13,6 +13,7 @@ import "package:logging/logging.dart";
import "package:onnxruntime/onnxruntime.dart";
import "package:photos/core/configuration.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/db/files_db.dart";
import "package:photos/db/ml_data_db.dart";
import "package:photos/events/diff_sync_complete_event.dart";
import "package:photos/extensions/list.dart";
@ -375,9 +376,15 @@ class FaceMlService {
);
_logger.info('read embeddings ${faceIdToEmbedding.length} ');
// Read the creation times from Files DB, in a map from fileID to creation time
final fileIDToCreationTime =
await FilesDB.instance.getFileIDToCreationTime();
// Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID
final faceIdToCluster =
await FaceLinearClustering.instance.predict(faceIdToEmbedding);
final faceIdToCluster = await FaceLinearClustering.instance.predict(
faceIdToEmbedding,
fileIDToCreationTime: fileIDToCreationTime,
);
if (faceIdToCluster == null) {
_logger.warning("faceIdToCluster is null");
return;