diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index 4b728b5a1..bef9089fe 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -386,7 +386,26 @@ class FaceMLDataDB { return maps.map((e) => e[fcFaceId] as String).toSet(); } - Future> getFaceIDsForPerson(String personID) async { + // Get Map of personID to Map of clusterID to faceIDs + Future>>> + getPersonToClusterIdToFaceIds() async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT $personIdColumn, $faceClustersTable.$fcClusterID, $fcFaceId FROM $clusterPersonTable ' + 'LEFT JOIN $faceClustersTable ON $clusterPersonTable.$clusterIDColumn = $faceClustersTable.$fcClusterID', + ); + final Map>> result = {}; + for (final map in maps) { + final personID = map[personIdColumn] as String; + final clusterID = map[fcClusterID] as int; + final faceID = map[fcFaceId] as String; + result.putIfAbsent(personID, () => {}).putIfAbsent(clusterID, () => {}) + ..add(faceID); + } + return result; + } + + Future> getFaceIDsForPerson(String personID) async { final db = await instance.asyncDB; final faceIdsResult = await db.getAll( 'SELECT $fcFaceId FROM $faceClustersTable LEFT JOIN $clusterPersonTable ' diff --git a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart index d860ca4b3..913e52268 100644 --- a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart @@ -1,10 +1,12 @@ import "dart:async" show unawaited; import "dart:convert"; +import "dart:developer"; import "package:flutter/foundation.dart"; import "package:logging/logging.dart"; import "package:photos/core/event_bus.dart"; import "package:photos/events/people_changed_event.dart"; +import "package:photos/extensions/stop_watch.dart"; import "package:photos/face/db.dart"; import "package:photos/face/model/person.dart"; import "package:photos/models/api/entity/type.dart"; @@ -69,6 +71,89 @@ class PersonService { return entities.map((e) => e.id).toSet(); } + Future reconcileClusters() async { + final EnteWatch? w = kDebugMode ? EnteWatch("reconcileClusters") : null; + w?.start(); + await storeRemoteFeedback(); + w?.log("Stored remote feedback"); + final dbPersonClusterInfo = + await faceMLDataDB.getPersonToClusterIdToFaceIds(); + w?.log("Got DB person cluster info"); + final persons = await getPersonsMap(); + w?.log("Got persons"); + for (var personID in dbPersonClusterInfo.keys) { + final person = persons[personID]; + if (person == null) { + logger.warning("Person $personID not found"); + continue; + } + final personData = person.data; + final Map> dbPersonCluster = + dbPersonClusterInfo[personID]!; + if (_shouldUpdateRemotePerson(personData, dbPersonCluster)) { + final personData = person.data; + personData.assigned = dbPersonCluster.entries + .map( + (e) => ClusterInfo( + id: e.key, + faces: e.value, + ), + ) + .toList(); + entityService + .addOrUpdate( + EntityType.person, + json.encode(personData.toJson()), + id: personID, + ) + .ignore(); + personData.logStats(); + } + } + w?.log("Reconciled clusters for ${persons.length} persons"); + } + + bool _shouldUpdateRemotePerson( + PersonData personData, Map> dbPersonCluster) { + bool result = false; + if ((personData.assigned?.length ?? 0) != dbPersonCluster.length) { + log( + "Person ${personData.name} has ${personData.assigned?.length} clusters, but ${dbPersonCluster.length} clusters found in DB", + name: "PersonService", + ); + result = true; + } else { + for (ClusterInfo info in personData.assigned!) { + final dbCluster = dbPersonCluster[info.id]; + if (dbCluster == null) { + log( + "Cluster ${info.id} not found in DB for person ${personData.name}", + name: "PersonService", + ); + result = true; + continue; + } + if (info.faces.length != dbCluster.length) { + log( + "Cluster ${info.id} has ${info.faces.length} faces, but ${dbCluster.length} faces found in DB", + name: "PersonService", + ); + result = true; + } + for (var faceId in info.faces) { + if (!dbCluster.contains(faceId)) { + log( + "Face $faceId not found in cluster ${info.id} for person ${personData.name}", + name: "PersonService", + ); + result = true; + } + } + } + } + return result; + } + Future addPerson(String name, int clusterID) async { final faceIds = await faceMLDataDB.getFaceIDsForCluster(clusterID); final data = PersonData( diff --git a/mobile/lib/ui/settings/debug/face_debug_section_widget.dart b/mobile/lib/ui/settings/debug/face_debug_section_widget.dart index 0112490a4..08fb3d9d2 100644 --- a/mobile/lib/ui/settings/debug/face_debug_section_widget.dart +++ b/mobile/lib/ui/settings/debug/face_debug_section_widget.dart @@ -288,6 +288,24 @@ class _FaceDebugSectionWidgetState extends State { ); }, ), + MenuItemWidget( + captionedTextWidget: const CaptionedTextWidget( + title: "Sync person mappings ", + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + try { + await PersonService.instance.reconcileClusters(); + Bus.instance.fire(PeopleChangedEvent()); + showShortToast(context, "Done"); + } catch (e, s) { + _logger.warning('sync person mappings failed ', e, s); + await showGenericErrorDialog(context: context, error: e); + } + }, + ), // sectionOptionSpacing, // MenuItemWidget( // captionedTextWidget: const CaptionedTextWidget(