diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart index ee7322456..8203958e9 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart @@ -395,6 +395,12 @@ class FaceClustering { if (distance < closestDistance) { closestDistance = distance; closestIdx = j; + // if (distance < distanceThreshold) { + // if (sortedFaceInfos[j].faceID.startsWith("14914702") || + // sortedFaceInfos[j].faceID.startsWith("15488756")) { + // log('[XXX] faceIDs: ${sortedFaceInfos[j].faceID} and ${sortedFaceInfos[i].faceID} with distance $distance'); + // } + // } } } @@ -408,10 +414,22 @@ class FaceClustering { sortedFaceInfos[closestIdx].clusterId = clusterID; newFaceIdToCluster[sortedFaceInfos[closestIdx].faceID] = clusterID; } + // if (sortedFaceInfos[i].faceID.startsWith("14914702") || + // sortedFaceInfos[i].faceID.startsWith("15488756")) { + // log( + // "[XXX] [ClusterIsolate] ${DateTime.now()} Found similar face ${sortedFaceInfos[i].faceID} to ${sortedFaceInfos[closestIdx].faceID} with distance $closestDistance", + // ); + // } sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId; newFaceIdToCluster[sortedFaceInfos[i].faceID] = sortedFaceInfos[closestIdx].clusterId!; } else { + // if (sortedFaceInfos[i].faceID.startsWith("14914702") || + // sortedFaceInfos[i].faceID.startsWith("15488756")) { + // log( + // "[XXX] [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID for face ${sortedFaceInfos[i].faceID}", + // ); + // } clusterID++; sortedFaceInfos[i].clusterId = clusterID; newFaceIdToCluster[sortedFaceInfos[i].faceID] = clusterID; diff --git a/mobile/lib/ui/viewer/people/cluster_app_bar.dart b/mobile/lib/ui/viewer/people/cluster_app_bar.dart index 82744a3d5..02dde594b 100644 --- a/mobile/lib/ui/viewer/people/cluster_app_bar.dart +++ b/mobile/lib/ui/viewer/people/cluster_app_bar.dart @@ -3,14 +3,18 @@ import 'dart:async'; import "package:flutter/foundation.dart"; import 'package:flutter/material.dart'; import 'package:logging/logging.dart'; +import "package:ml_linalg/linalg.dart"; import 'package:photos/core/configuration.dart'; import 'package:photos/core/event_bus.dart'; import "package:photos/db/files_db.dart"; import 'package:photos/events/subscription_purchased_event.dart'; +import "package:photos/face/db.dart"; import "package:photos/face/model/person.dart"; +import "package:photos/generated/protos/ente/common/vector.pb.dart"; import 'package:photos/models/gallery_type.dart'; import 'package:photos/models/selected_files.dart'; import 'package:photos/services/collections_service.dart'; +import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; import 'package:photos/ui/actions/collection/collection_sharing_actions.dart'; @@ -39,6 +43,7 @@ class ClusterAppBar extends StatefulWidget { enum ClusterPopupAction { setCover, breakupCluster, + validateCluster, hide, } @@ -127,6 +132,18 @@ class _AppBarWidgetState extends State { ], ), ), + const PopupMenuItem( + value: ClusterPopupAction.validateCluster, + child: Row( + children: [ + Icon(Icons.search_off_outlined), + Padding( + padding: EdgeInsets.all(8), + ), + Text('Validate cluster'), + ], + ), + ), // PopupMenuItem( // value: ClusterPopupAction.hide, // child: Row( @@ -152,6 +169,8 @@ class _AppBarWidgetState extends State { if (value == ClusterPopupAction.breakupCluster) { // ignore: unawaited_futures await _breakUpCluster(context); + } else if (value == ClusterPopupAction.validateCluster) { + await _validateCluster(context); } // else if (value == ClusterPopupAction.setCover) { // await setCoverPhoto(context); @@ -166,6 +185,55 @@ class _AppBarWidgetState extends State { return actions; } + Future _validateCluster(BuildContext context) async { + _logger.info('_validateCluster called'); + final faceMlDb = FaceMLDataDB.instance; + + final faceIDs = await faceMlDb.getFaceIDsForCluster(widget.clusterID); + final fileIDs = faceIDs.map((e) => getFileIdFromFaceId(e)).toList(); + + final embeddingsBlobs = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs); + embeddingsBlobs.removeWhere((key, value) => !faceIDs.contains(key)); + final embeddings = embeddingsBlobs + .map((key, value) => MapEntry(key, EVector.fromBuffer(value).values)); + + for (final MapEntry> embedding in embeddings.entries) { + double closestDistance = double.infinity; + double closestDistance32 = double.infinity; + double closestDistance64 = double.infinity; + String? closestFaceID; + for (final MapEntry> otherEmbedding + in embeddings.entries) { + if (embedding.key == otherEmbedding.key) { + continue; + } + final distance64 = 1.0 - + Vector.fromList(embedding.value, dtype: DType.float64).dot( + Vector.fromList(otherEmbedding.value, dtype: DType.float64), + ); + final distance32 = 1.0 - + Vector.fromList(embedding.value, dtype: DType.float32).dot( + Vector.fromList(otherEmbedding.value, dtype: DType.float32), + ); + final distance = cosineDistForNormVectors( + embedding.value, + otherEmbedding.value, + ); + if (distance < closestDistance) { + closestDistance = distance; + closestDistance32 = distance32; + closestDistance64 = distance64; + closestFaceID = otherEmbedding.key; + } + } + if (closestDistance > 0.3) { + _logger.severe( + "Face ${embedding.key} is similar to $closestFaceID with distance $closestDistance, and float32 distance $closestDistance32, and float64 distance $closestDistance64", + ); + } + } + } + Future _breakUpCluster(BuildContext context) async { final newClusterIDToFaceIDs = await ClusterFeedbackService.instance.breakUpCluster(widget.clusterID);