diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index 0dca25e93..63ddef47e 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -481,6 +481,16 @@ class FaceMLDataDB { return maps.first['count'] as int; } + Future getBlurryFaceCount([ + int blurThreshold = kLaplacianThreshold, + ]) async { + final db = await instance.database; + final List> maps = await db.rawQuery( + 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceBlur <= $blurThreshold AND $faceScore > $kMinHighQualityFaceScore', + ); + return maps.first['count'] as int; + } + Future resetClusterIDs() async { final db = await instance.database; await db.execute(dropFaceClustersTable); @@ -726,7 +736,7 @@ class FaceMLDataDB { for (final enteFile in files) { fileIds.add(enteFile.uploadedFileID.toString()); } - int maxClusterID = DateTime.now().millisecondsSinceEpoch; + int maxClusterID = DateTime.now().microsecondsSinceEpoch; final Map faceIDToClusterID = {}; for (final row in faceIdsResult) { final faceID = row[fcFaceId] as String; @@ -752,7 +762,7 @@ class FaceMLDataDB { for (final enteFile in files) { fileIds.add(enteFile.uploadedFileID.toString()); } - int maxClusterID = DateTime.now().millisecondsSinceEpoch; + int maxClusterID = DateTime.now().microsecondsSinceEpoch; final Map faceIDToClusterID = {}; for (final row in faceIdsResult) { final faceID = row[fcFaceId] as String; @@ -763,4 +773,16 @@ class FaceMLDataDB { } await forceUpdateClusterIds(faceIDToClusterID); } + + Future addFacesToCluster( + List faceIDs, + int clusterID, + ) async { + final faceIDToClusterID = {}; + for (final faceID in faceIDs) { + faceIDToClusterID[faceID] = clusterID; + } + + await forceUpdateClusterIds(faceIDToClusterID); + } } diff --git a/mobile/lib/face/model/detection.dart b/mobile/lib/face/model/detection.dart index 44acd85bb..1ffaa1eb3 100644 --- a/mobile/lib/face/model/detection.dart +++ b/mobile/lib/face/model/detection.dart @@ -3,6 +3,9 @@ import "package:photos/face/model/landmark.dart"; /// Stores the face detection data, notably the bounding box and landmarks. /// +/// - Bounding box: [FaceBox] with xMin, yMin (so top left corner), width, height +/// - Landmarks: list of [Landmark]s, namely leftEye, rightEye, nose, leftMouth, rightMouth +/// /// WARNING: All coordinates are relative to the image size, so in the range [0, 1]! class Detection { FaceBox box; @@ -39,4 +42,43 @@ class Detection { ), ); } + + // TODO: iterate on better area calculation, potentially using actual indexing image dimensions instead of file metadata + int getFaceArea(int imageWidth, int imageHeight) { + return (box.width * imageWidth * box.height * imageHeight).toInt(); + } + + // TODO: iterate on better scoring logic, current is a placeholder + int getVisibilityScore() { + final double aspectRatio = box.width / box.height; + final double eyeDistance = (landmarks[1].x - landmarks[0].x).abs(); + final double mouthDistance = (landmarks[4].x - landmarks[3].x).abs(); + final double noseEyeDistance = + (landmarks[2].y - ((landmarks[0].y + landmarks[1].y) / 2)).abs(); + + final double normalizedEyeDistance = eyeDistance / box.width; + final double normalizedMouthDistance = mouthDistance / box.width; + final double normalizedNoseEyeDistance = noseEyeDistance / box.height; + + const double aspectRatioThreshold = 0.8; + const double eyeDistanceThreshold = 0.2; + const double mouthDistanceThreshold = 0.3; + const double noseEyeDistanceThreshold = 0.1; + + double score = 0; + if (aspectRatio >= aspectRatioThreshold) { + score += 50; + } + if (normalizedEyeDistance >= eyeDistanceThreshold) { + score += 20; + } + if (normalizedMouthDistance >= mouthDistanceThreshold) { + score += 20; + } + if (normalizedNoseEyeDistance >= noseEyeDistanceThreshold) { + score += 10; + } + + return score.clamp(0, 100).toInt(); + } } diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart index 86b747551..8203958e9 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart @@ -353,14 +353,8 @@ class FaceClustering { // Make sure the first face has a clusterId final int totalFaces = sortedFaceInfos.length; // set current epoch time as clusterID - int clusterID = DateTime.now().millisecondsSinceEpoch; - if (sortedFaceInfos.isNotEmpty) { - if (sortedFaceInfos.first.clusterId == null) { - sortedFaceInfos.first.clusterId = clusterID; - } else { - clusterID = sortedFaceInfos.first.clusterId!; - } - } else { + int clusterID = DateTime.now().microsecondsSinceEpoch; + if (sortedFaceInfos.isEmpty) { return {}; } @@ -401,6 +395,12 @@ class FaceClustering { if (distance < closestDistance) { closestDistance = distance; closestIdx = j; + // if (distance < distanceThreshold) { + // if (sortedFaceInfos[j].faceID.startsWith("14914702") || + // sortedFaceInfos[j].faceID.startsWith("15488756")) { + // log('[XXX] faceIDs: ${sortedFaceInfos[j].faceID} and ${sortedFaceInfos[i].faceID} with distance $distance'); + // } + // } } } @@ -414,10 +414,22 @@ class FaceClustering { sortedFaceInfos[closestIdx].clusterId = clusterID; newFaceIdToCluster[sortedFaceInfos[closestIdx].faceID] = clusterID; } + // if (sortedFaceInfos[i].faceID.startsWith("14914702") || + // sortedFaceInfos[i].faceID.startsWith("15488756")) { + // log( + // "[XXX] [ClusterIsolate] ${DateTime.now()} Found similar face ${sortedFaceInfos[i].faceID} to ${sortedFaceInfos[closestIdx].faceID} with distance $closestDistance", + // ); + // } sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId; newFaceIdToCluster[sortedFaceInfos[i].faceID] = sortedFaceInfos[closestIdx].clusterId!; } else { + // if (sortedFaceInfos[i].faceID.startsWith("14914702") || + // sortedFaceInfos[i].faceID.startsWith("15488756")) { + // log( + // "[XXX] [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID for face ${sortedFaceInfos[i].faceID}", + // ); + // } clusterID++; sortedFaceInfos[i].clusterId = clusterID; newFaceIdToCluster[sortedFaceInfos[i].faceID] = clusterID; diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart index e163defe6..d84c5dce2 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart @@ -654,7 +654,7 @@ class FaceMlService { .map( (keypoint) => Landmark( x: keypoint[0], - y: keypoint[0], + y: keypoint[1], ), ) .toList(), diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 7a83c8d82..cef19c7a2 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -325,12 +325,25 @@ class ClusterFeedbackService { } } - Future removeFilesFromPerson(List files, Person p) { - return FaceMLDataDB.instance.removeFilesFromPerson(files, p); + Future removeFilesFromPerson(List files, Person p) async { + await FaceMLDataDB.instance.removeFilesFromPerson(files, p); + Bus.instance.fire(PeopleChangedEvent()); + return; } - Future removeFilesFromCluster(List files, int clusterID) { - return FaceMLDataDB.instance.removeFilesFromCluster(files, clusterID); + Future removeFilesFromCluster( + List files, + int clusterID, + ) async { + await FaceMLDataDB.instance.removeFilesFromCluster(files, clusterID); + Bus.instance.fire(PeopleChangedEvent()); + return; + } + + Future addFilesToCluster(List faceIDs, int clusterID) async { + await FaceMLDataDB.instance.addFacesToCluster(faceIDs, clusterID); + Bus.instance.fire(PeopleChangedEvent()); + return; } Future checkAndDoAutomaticMerges(Person p) async { @@ -413,7 +426,7 @@ class ClusterFeedbackService { embeddings, fileIDToCreationTime: fileIDToCreationTime, eps: 0.30, - minPts: 5, + minPts: 8, ); if (dbscanClusters.isEmpty) { diff --git a/mobile/lib/ui/settings/debug/face_debug_section_widget.dart b/mobile/lib/ui/settings/debug/face_debug_section_widget.dart index e09574356..559eed18b 100644 --- a/mobile/lib/ui/settings/debug/face_debug_section_widget.dart +++ b/mobile/lib/ui/settings/debug/face_debug_section_widget.dart @@ -114,7 +114,9 @@ class _FaceDebugSectionWidgetState extends State { .getTotalFaceCount(minFaceScore: 0.75); final faces78 = await FaceMLDataDB.instance .getTotalFaceCount(minFaceScore: kMinHighQualityFaceScore); - showShortToast(context, "Faces75: $faces75, Faces78: $faces78"); + final blurryFaceCount = + await FaceMLDataDB.instance.getBlurryFaceCount(15); + showShortToast(context, "$blurryFaceCount blurry faces"); }, ), // MenuItemWidget( diff --git a/mobile/lib/ui/viewer/file_details/face_widget.dart b/mobile/lib/ui/viewer/file_details/face_widget.dart index 35163750f..9ee41add2 100644 --- a/mobile/lib/ui/viewer/file_details/face_widget.dart +++ b/mobile/lib/ui/viewer/file_details/face_widget.dart @@ -2,12 +2,14 @@ import "dart:developer" show log; import "dart:io" show Platform; import "dart:typed_data"; +import "package:flutter/cupertino.dart"; import "package:flutter/foundation.dart" show kDebugMode; import "package:flutter/material.dart"; import "package:photos/face/db.dart"; import "package:photos/face/model/face.dart"; import "package:photos/face/model/person.dart"; import 'package:photos/models/file/file.dart'; +import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; import "package:photos/services/search_service.dart"; import "package:photos/theme/ente_theme.dart"; import "package:photos/ui/viewer/file/no_thumbnail_widget.dart"; @@ -16,13 +18,15 @@ import "package:photos/ui/viewer/people/cropped_face_image_view.dart"; import "package:photos/ui/viewer/people/people_page.dart"; import "package:photos/utils/face/face_box_crop.dart"; import "package:photos/utils/thumbnail_util.dart"; +// import "package:photos/utils/toast_util.dart"; -class FaceWidget extends StatelessWidget { +class FaceWidget extends StatefulWidget { final EnteFile file; final Face face; final Person? person; final int? clusterID; final bool highlight; + final bool editMode; const FaceWidget( this.file, @@ -30,9 +34,17 @@ class FaceWidget extends StatelessWidget { this.person, this.clusterID, this.highlight = false, + this.editMode = false, Key? key, }) : super(key: key); + @override + State createState() => _FaceWidgetState(); +} + +class _FaceWidgetState extends State { + bool isJustRemoved = false; + @override Widget build(BuildContext context) { if (Platform.isIOS || Platform.isAndroid) { @@ -43,22 +55,24 @@ class FaceWidget extends StatelessWidget { final ImageProvider imageProvider = MemoryImage(snapshot.data!); return GestureDetector( onTap: () async { + if (widget.editMode) return; + log( - "FaceWidget is tapped, with person $person and clusterID $clusterID", + "FaceWidget is tapped, with person ${widget.person} and clusterID ${widget.clusterID}", name: "FaceWidget", ); - if (person == null && clusterID == null) { + if (widget.person == null && widget.clusterID == null) { return; } - if (person != null) { + if (widget.person != null) { await Navigator.of(context).push( MaterialPageRoute( builder: (context) => PeoplePage( - person: person!, + person: widget.person!, ), ), ); - } else if (clusterID != null) { + } else if (widget.clusterID != null) { final fileIdsToClusterIds = await FaceMLDataDB.instance.getFileIdToClusterIds(); final files = await SearchService.instance.getAllFiles(); @@ -66,7 +80,7 @@ class FaceWidget extends StatelessWidget { .where( (file) => fileIdsToClusterIds[file.uploadedFileID] - ?.contains(clusterID) ?? + ?.contains(widget.clusterID) ?? false, ) .toList(); @@ -74,7 +88,7 @@ class FaceWidget extends StatelessWidget { MaterialPageRoute( builder: (context) => ClusterPage( clusterFiles, - clusterID: clusterID!, + clusterID: widget.clusterID!, ), ), ); @@ -82,46 +96,87 @@ class FaceWidget extends StatelessWidget { }, child: Column( children: [ - // TODO: the edges of the green line are still not properly rounded around ClipRRect - Container( - height: 60, - width: 60, - decoration: ShapeDecoration( - shape: RoundedRectangleBorder( - borderRadius: - const BorderRadius.all(Radius.elliptical(16, 12)), - side: highlight - ? BorderSide( - color: getEnteColorScheme(context).primary700, - width: 2.0, - ) - : BorderSide.none, - ), - ), - child: ClipRRect( - borderRadius: - const BorderRadius.all(Radius.elliptical(16, 12)), - child: SizedBox( - width: 60, + Stack( + children: [ + Container( height: 60, - child: Image( - image: imageProvider, - fit: BoxFit.cover, + width: 60, + decoration: ShapeDecoration( + shape: RoundedRectangleBorder( + borderRadius: const BorderRadius.all( + Radius.elliptical(16, 12), + ), + side: widget.highlight + ? BorderSide( + color: + getEnteColorScheme(context).primary700, + width: 1.0, + ) + : BorderSide.none, + ), + ), + child: ClipRRect( + borderRadius: + const BorderRadius.all(Radius.elliptical(16, 12)), + child: SizedBox( + width: 60, + height: 60, + child: Image( + image: imageProvider, + fit: BoxFit.cover, + ), + ), ), ), - ), + // TODO: the edges of the green line are still not properly rounded around ClipRRect + if (widget.editMode) + Positioned( + right: 0, + top: 0, + child: GestureDetector( + onTap: _cornerIconPressed, + child: isJustRemoved + ? const Icon( + CupertinoIcons.add_circled_solid, + color: Colors.green, + ) + : const Icon( + Icons.cancel, + color: Colors.red, + ), + ), + ), + ], ), const SizedBox(height: 8), - if (person != null) + if (widget.person != null) Text( - person!.attr.name.trim(), + widget.person!.attr.name.trim(), style: Theme.of(context).textTheme.bodySmall, overflow: TextOverflow.ellipsis, maxLines: 1, ), if (kDebugMode) Text( - 'S: ${face.score.toStringAsFixed(3)}', + 'S: ${widget.face.score.toStringAsFixed(3)}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), + if (kDebugMode) + Text( + 'B: ${widget.face.blur.toStringAsFixed(3)}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), + if (kDebugMode) + Text( + 'V: ${widget.face.detection.getVisibilityScore()}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), + if (kDebugMode) + Text( + 'A: ${widget.face.detection.getFaceArea(widget.file.width, widget.file.height)}', style: Theme.of(context).textTheme.bodySmall, maxLines: 1, ), @@ -168,21 +223,21 @@ class FaceWidget extends StatelessWidget { return GestureDetector( onTap: () async { log( - "FaceWidget is tapped, with person $person and clusterID $clusterID", + "FaceWidget is tapped, with person ${widget.person} and clusterID ${widget.clusterID}", name: "FaceWidget", ); - if (person == null && clusterID == null) { + if (widget.person == null && widget.clusterID == null) { return; } - if (person != null) { + if (widget.person != null) { await Navigator.of(context).push( MaterialPageRoute( builder: (context) => PeoplePage( - person: person!, + person: widget.person!, ), ), ); - } else if (clusterID != null) { + } else if (widget.clusterID != null) { final fileIdsToClusterIds = await FaceMLDataDB.instance.getFileIdToClusterIds(); final files = await SearchService.instance.getAllFiles(); @@ -190,7 +245,7 @@ class FaceWidget extends StatelessWidget { .where( (file) => fileIdsToClusterIds[file.uploadedFileID] - ?.contains(clusterID) ?? + ?.contains(widget.clusterID) ?? false, ) .toList(); @@ -198,7 +253,7 @@ class FaceWidget extends StatelessWidget { MaterialPageRoute( builder: (context) => ClusterPage( clusterFiles, - clusterID: clusterID!, + clusterID: widget.clusterID!, ), ), ); @@ -213,7 +268,7 @@ class FaceWidget extends StatelessWidget { shape: RoundedRectangleBorder( borderRadius: const BorderRadius.all(Radius.elliptical(16, 12)), - side: highlight + side: widget.highlight ? BorderSide( color: getEnteColorScheme(context).primary700, width: 2.0, @@ -228,23 +283,23 @@ class FaceWidget extends StatelessWidget { width: 60, height: 60, child: CroppedFaceImageView( - enteFile: file, - face: face, + enteFile: widget.file, + face: widget.face, ), ), ), ), const SizedBox(height: 8), - if (person != null) + if (widget.person != null) Text( - person!.attr.name.trim(), + widget.person!.attr.name.trim(), style: Theme.of(context).textTheme.bodySmall, overflow: TextOverflow.ellipsis, maxLines: 1, ), if (kDebugMode) Text( - 'S: ${face.score.toStringAsFixed(3)}', + 'S: ${widget.face.score.toStringAsFixed(3)}', style: Theme.of(context).textTheme.bodySmall, maxLines: 1, ), @@ -256,36 +311,55 @@ class FaceWidget extends StatelessWidget { } } + void _cornerIconPressed() async { + log('face widget (file info) corner icon is pressed'); + try { + if (isJustRemoved) { + await ClusterFeedbackService.instance + .addFilesToCluster([widget.face.faceID], widget.clusterID!); + } else { + await ClusterFeedbackService.instance + .removeFilesFromCluster([widget.file], widget.clusterID!); + } + + setState(() { + isJustRemoved = !isJustRemoved; + }); + } catch (e, s) { + log("removing face/file from cluster from file info widget failed: $e, \n $s"); + } + } + Future getFaceCrop() async { try { - final Uint8List? cachedFace = faceCropCache.get(face.faceID); + final Uint8List? cachedFace = faceCropCache.get(widget.face.faceID); if (cachedFace != null) { return cachedFace; } - final faceCropCacheFile = cachedFaceCropPath(face.faceID); + final faceCropCacheFile = cachedFaceCropPath(widget.face.faceID); if ((await faceCropCacheFile.exists())) { final data = await faceCropCacheFile.readAsBytes(); - faceCropCache.put(face.faceID, data); + faceCropCache.put(widget.face.faceID, data); return data; } final result = await pool.withResource( () async => await getFaceCrops( - file, + widget.file, { - face.faceID: face.detection.box, + widget.face.faceID: widget.face.detection.box, }, ), ); - final Uint8List? computedCrop = result?[face.faceID]; + final Uint8List? computedCrop = result?[widget.face.faceID]; if (computedCrop != null) { - faceCropCache.put(face.faceID, computedCrop); + faceCropCache.put(widget.face.faceID, computedCrop); faceCropCacheFile.writeAsBytes(computedCrop).ignore(); } return computedCrop; } catch (e, s) { log( - "Error getting face for faceID: ${face.faceID}", + "Error getting face for faceID: ${widget.face.faceID}", error: e, stackTrace: s, ); diff --git a/mobile/lib/ui/viewer/file_details/faces_item_widget.dart b/mobile/lib/ui/viewer/file_details/faces_item_widget.dart index 3a541a477..d06f83974 100644 --- a/mobile/lib/ui/viewer/file_details/faces_item_widget.dart +++ b/mobile/lib/ui/viewer/file_details/faces_item_widget.dart @@ -1,3 +1,4 @@ +import "package:flutter/foundation.dart" show kDebugMode; import "package:flutter/material.dart"; import "package:logging/logging.dart"; import "package:photos/face/db.dart"; @@ -9,23 +10,44 @@ import "package:photos/ui/components/buttons/chip_button_widget.dart"; import "package:photos/ui/components/info_item_widget.dart"; import "package:photos/ui/viewer/file_details/face_widget.dart"; -class FacesItemWidget extends StatelessWidget { +class FacesItemWidget extends StatefulWidget { final EnteFile file; const FacesItemWidget(this.file, {super.key}); + @override + State createState() => _FacesItemWidgetState(); +} + +class _FacesItemWidgetState extends State { + bool editMode = false; + + @override + void initState() { + super.initState(); + setState(() {}); + } + @override Widget build(BuildContext context) { return InfoItemWidget( key: const ValueKey("Faces"), leadingIcon: Icons.face_retouching_natural_outlined, - subtitleSection: _faceWidgets(context, file), + subtitleSection: _faceWidgets(context, widget.file, editMode), hasChipButtons: true, + editOnTap: _toggleEditMode, ); } + void _toggleEditMode() { + setState(() { + editMode = !editMode; + }); + } + Future> _faceWidgets( BuildContext context, EnteFile file, + bool editMode, ) async { try { if (file.uploadedFileID == null) { @@ -47,8 +69,13 @@ class FacesItemWidget extends StatelessWidget { ), ]; } - if (faces.isEmpty || - faces.every((face) => face.score < 0.75 || face.isBlurry)) { + + // Remove faces with low scores and blurry faces + if (!kDebugMode) { + faces.removeWhere((face) => (face.isBlurry || face.score < 0.75)); + } + + if (faces.isEmpty) { return [ const ChipButtonWidget( "No faces found", @@ -60,9 +87,6 @@ class FacesItemWidget extends StatelessWidget { // Sort the faces by score in descending order, so that the highest scoring face is first. faces.sort((Face a, Face b) => b.score.compareTo(a.score)); - // Remove faces with low scores and blurry faces - faces.removeWhere((face) => (face.isBlurry || face.score < 0.75)); - // TODO: add deduplication of faces of same person final faceIdsToClusterIds = await FaceMLDataDB.instance .getFaceIdsToClusterIds(faces.map((face) => face.faceID)); @@ -84,6 +108,7 @@ class FacesItemWidget extends StatelessWidget { clusterID: clusterID, person: person, highlight: highlight, + editMode: highlight ? editMode : false, ), ); } diff --git a/mobile/lib/ui/viewer/people/cluster_app_bar.dart b/mobile/lib/ui/viewer/people/cluster_app_bar.dart index bc32c9088..02dde594b 100644 --- a/mobile/lib/ui/viewer/people/cluster_app_bar.dart +++ b/mobile/lib/ui/viewer/people/cluster_app_bar.dart @@ -3,21 +3,22 @@ import 'dart:async'; import "package:flutter/foundation.dart"; import 'package:flutter/material.dart'; import 'package:logging/logging.dart'; +import "package:ml_linalg/linalg.dart"; import 'package:photos/core/configuration.dart'; import 'package:photos/core/event_bus.dart'; import "package:photos/db/files_db.dart"; -// import "package:photos/events/people_changed_event.dart"; import 'package:photos/events/subscription_purchased_event.dart'; -// import "package:photos/face/db.dart"; +import "package:photos/face/db.dart"; import "package:photos/face/model/person.dart"; +import "package:photos/generated/protos/ente/common/vector.pb.dart"; import 'package:photos/models/gallery_type.dart'; import 'package:photos/models/selected_files.dart'; import 'package:photos/services/collections_service.dart'; +import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; import 'package:photos/ui/actions/collection/collection_sharing_actions.dart'; -import "package:photos/ui/viewer/people/cluster_page.dart"; -// import "package:photos/utils/dialog_util.dart"; +import "package:photos/ui/viewer/people/cluster_breakup_page.dart"; class ClusterAppBar extends StatefulWidget { final GalleryType type; @@ -42,6 +43,7 @@ class ClusterAppBar extends StatefulWidget { enum ClusterPopupAction { setCover, breakupCluster, + validateCluster, hide, } @@ -130,6 +132,18 @@ class _AppBarWidgetState extends State { ], ), ), + const PopupMenuItem( + value: ClusterPopupAction.validateCluster, + child: Row( + children: [ + Icon(Icons.search_off_outlined), + Padding( + padding: EdgeInsets.all(8), + ), + Text('Validate cluster'), + ], + ), + ), // PopupMenuItem( // value: ClusterPopupAction.hide, // child: Row( @@ -155,6 +169,8 @@ class _AppBarWidgetState extends State { if (value == ClusterPopupAction.breakupCluster) { // ignore: unawaited_futures await _breakUpCluster(context); + } else if (value == ClusterPopupAction.validateCluster) { + await _validateCluster(context); } // else if (value == ClusterPopupAction.setCover) { // await setCoverPhoto(context); @@ -169,28 +185,84 @@ class _AppBarWidgetState extends State { return actions; } + Future _validateCluster(BuildContext context) async { + _logger.info('_validateCluster called'); + final faceMlDb = FaceMLDataDB.instance; + + final faceIDs = await faceMlDb.getFaceIDsForCluster(widget.clusterID); + final fileIDs = faceIDs.map((e) => getFileIdFromFaceId(e)).toList(); + + final embeddingsBlobs = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs); + embeddingsBlobs.removeWhere((key, value) => !faceIDs.contains(key)); + final embeddings = embeddingsBlobs + .map((key, value) => MapEntry(key, EVector.fromBuffer(value).values)); + + for (final MapEntry> embedding in embeddings.entries) { + double closestDistance = double.infinity; + double closestDistance32 = double.infinity; + double closestDistance64 = double.infinity; + String? closestFaceID; + for (final MapEntry> otherEmbedding + in embeddings.entries) { + if (embedding.key == otherEmbedding.key) { + continue; + } + final distance64 = 1.0 - + Vector.fromList(embedding.value, dtype: DType.float64).dot( + Vector.fromList(otherEmbedding.value, dtype: DType.float64), + ); + final distance32 = 1.0 - + Vector.fromList(embedding.value, dtype: DType.float32).dot( + Vector.fromList(otherEmbedding.value, dtype: DType.float32), + ); + final distance = cosineDistForNormVectors( + embedding.value, + otherEmbedding.value, + ); + if (distance < closestDistance) { + closestDistance = distance; + closestDistance32 = distance32; + closestDistance64 = distance64; + closestFaceID = otherEmbedding.key; + } + } + if (closestDistance > 0.3) { + _logger.severe( + "Face ${embedding.key} is similar to $closestFaceID with distance $closestDistance, and float32 distance $closestDistance32, and float64 distance $closestDistance64", + ); + } + } + } + Future _breakUpCluster(BuildContext context) async { final newClusterIDToFaceIDs = await ClusterFeedbackService.instance.breakUpCluster(widget.clusterID); - for (final cluster in newClusterIDToFaceIDs.entries) { - // ignore: unawaited_futures - final newClusterID = cluster.key; - final faceIDs = cluster.value; - final files = await FilesDB.instance - .getFilesFromIDs(faceIDs.map((e) => getFileIdFromFaceId(e)).toList()); - unawaited( - Navigator.of(context).push( - MaterialPageRoute( - builder: (context) => ClusterPage( - files.values.toList(), - appendTitle: - (newClusterID == -1) ? "(Analysis noise)" : "(Analysis)", - clusterID: newClusterID, - ), - ), + final allFileIDs = newClusterIDToFaceIDs.values + .expand((e) => e) + .map((e) => getFileIdFromFaceId(e)) + .toList(); + + final fileIDtoFile = await FilesDB.instance.getFilesFromIDs( + allFileIDs, + ); + + final newClusterIDToFiles = newClusterIDToFaceIDs.map( + (key, value) => MapEntry( + key, + value + .map((faceId) => fileIDtoFile[getFileIdFromFaceId(faceId)]!) + .toList(), + ), + ); + + await Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => ClusterBreakupPage( + newClusterIDToFiles, + "(Analysis)", ), - ); - } + ), + ); } } diff --git a/mobile/lib/ui/viewer/people/cluster_breakup_page.dart b/mobile/lib/ui/viewer/people/cluster_breakup_page.dart new file mode 100644 index 000000000..e91909f47 --- /dev/null +++ b/mobile/lib/ui/viewer/people/cluster_breakup_page.dart @@ -0,0 +1,124 @@ +import "package:flutter/material.dart"; +import "package:photos/models/file/file.dart"; +import "package:photos/theme/ente_theme.dart"; +import "package:photos/ui/viewer/file/no_thumbnail_widget.dart"; +import "package:photos/ui/viewer/people/cluster_page.dart"; +import "package:photos/ui/viewer/search/result/person_face_widget.dart"; + +class ClusterBreakupPage extends StatefulWidget { + final Map> newClusterIDsToFiles; + final String title; + + const ClusterBreakupPage( + this.newClusterIDsToFiles, + this.title, { + super.key, + }); + + @override + State createState() => _ClusterBreakupPageState(); +} + +class _ClusterBreakupPageState extends State { + @override + Widget build(BuildContext context) { + final keys = widget.newClusterIDsToFiles.keys.toList(); + final clusterIDsToFiles = widget.newClusterIDsToFiles; + + return Scaffold( + appBar: AppBar( + title: Text(widget.title), + ), + body: ListView.builder( + itemCount: widget.newClusterIDsToFiles.keys.length, + itemBuilder: (context, index) { + final int clusterID = keys[index]; + final List files = clusterIDsToFiles[keys[index]]!; + return InkWell( + onTap: () { + Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => ClusterPage( + files, + clusterID: index, + appendTitle: "(Analysis)", + ), + ), + ); + }, + child: Container( + padding: const EdgeInsets.all(8.0), + child: Row( + children: [ + SizedBox( + width: 64, + height: 64, + child: files.isNotEmpty + ? ClipRRect( + borderRadius: const BorderRadius.all( + Radius.elliptical(16, 12),), + child: PersonFaceWidget( + files.first, + clusterID: clusterID, + ), + ) + : const ClipRRect( + borderRadius: + BorderRadius.all(Radius.elliptical(16, 12)), + child: NoThumbnailWidget( + addBorder: false, + ), + ), + ), + const SizedBox( + width: 8.0, + ), // Add some spacing between the thumbnail and the text + Expanded( + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 8.0), + child: Row( + mainAxisAlignment: MainAxisAlignment.spaceBetween, + children: [ + Text( + "${clusterIDsToFiles[keys[index]]!.length} photos", + style: getEnteTextTheme(context).body, + ), + // GestureDetector( + // onTap: () async { + // try { + // final int result = await FaceMLDataDB + // .instance + // .removeClusterToPerson( + // personID: widget.person.remoteID, + // clusterID: clusterID, + // ); + // _logger.info( + // "Removed cluster $clusterID from person ${widget.person.remoteID}, result: $result", + // ); + // Bus.instance.fire(PeopleChangedEvent()); + // setState(() {}); + // } catch (e) { + // _logger.severe( + // "removing cluster from person,", + // e, + // ); + // } + // }, + // child: const Icon( + // CupertinoIcons.minus_circled, + // color: Colors.red, + // ), + // ), + ], + ), + ), + ), + ], + ), + ), + ); + }, + ), + ); + } +} diff --git a/mobile/lib/ui/viewer/people/people_app_bar.dart b/mobile/lib/ui/viewer/people/people_app_bar.dart index fe5af20bf..b6379153c 100644 --- a/mobile/lib/ui/viewer/people/people_app_bar.dart +++ b/mobile/lib/ui/viewer/people/people_app_bar.dart @@ -14,8 +14,8 @@ import 'package:photos/models/gallery_type.dart'; import 'package:photos/models/selected_files.dart'; import 'package:photos/services/collections_service.dart'; import 'package:photos/ui/actions/collection/collection_sharing_actions.dart'; -import "package:photos/ui/viewer/people/person_cluserts.dart"; import "package:photos/ui/viewer/people/person_cluster_suggestion.dart"; +import 'package:photos/ui/viewer/people/person_clusters_page.dart'; import "package:photos/utils/dialog_util.dart"; class PeopleAppBar extends StatefulWidget { @@ -215,7 +215,7 @@ class _AppBarWidgetState extends State { unawaited( Navigator.of(context).push( MaterialPageRoute( - builder: (context) => PersonClusters(widget.person), + builder: (context) => PersonClustersPage(widget.person), ), ), ); diff --git a/mobile/lib/ui/viewer/people/person_cluserts.dart b/mobile/lib/ui/viewer/people/person_clusters_page.dart similarity index 89% rename from mobile/lib/ui/viewer/people/person_cluserts.dart rename to mobile/lib/ui/viewer/people/person_clusters_page.dart index ebac4d46c..044cd90d2 100644 --- a/mobile/lib/ui/viewer/people/person_cluserts.dart +++ b/mobile/lib/ui/viewer/people/person_clusters_page.dart @@ -13,19 +13,19 @@ import "package:photos/ui/viewer/file/no_thumbnail_widget.dart"; import "package:photos/ui/viewer/people/cluster_page.dart"; import "package:photos/ui/viewer/search/result/person_face_widget.dart"; -class PersonClusters extends StatefulWidget { +class PersonClustersPage extends StatefulWidget { final Person person; - const PersonClusters( + const PersonClustersPage( this.person, { super.key, }); @override - State createState() => _PersonClustersState(); + State createState() => _PersonClustersPageState(); } -class _PersonClustersState extends State { +class _PersonClustersPageState extends State { final Logger _logger = Logger("_PersonClustersState"); @override Widget build(BuildContext context) { @@ -64,13 +64,19 @@ class _PersonClustersState extends State { width: 64, height: 64, child: files.isNotEmpty - ? ClipOval( + ? ClipRRect( + borderRadius: const BorderRadius.all( + Radius.elliptical(16, 12), + ), child: PersonFaceWidget( files.first, clusterID: clusterID, ), ) - : const ClipOval( + : const ClipRRect( + borderRadius: BorderRadius.all( + Radius.elliptical(16, 12), + ), child: NoThumbnailWidget( addBorder: false, ),