Forráskód Böngészése

Merge branch 'mobile_face' of https://github.com/ente-io/auth into mobile_face

Neeraj Gupta 1 éve
szülő
commit
1b9c81c50c

+ 24 - 2
mobile/lib/face/db.dart

@@ -481,6 +481,16 @@ class FaceMLDataDB {
     return maps.first['count'] as int;
     return maps.first['count'] as int;
   }
   }
 
 
+  Future<int> getBlurryFaceCount([
+    int blurThreshold = kLaplacianThreshold,
+  ]) async {
+    final db = await instance.database;
+    final List<Map<String, dynamic>> maps = await db.rawQuery(
+      'SELECT COUNT(*) as count FROM $facesTable WHERE $faceBlur <= $blurThreshold AND $faceScore > $kMinHighQualityFaceScore',
+    );
+    return maps.first['count'] as int;
+  }
+
   Future<void> resetClusterIDs() async {
   Future<void> resetClusterIDs() async {
     final db = await instance.database;
     final db = await instance.database;
     await db.execute(dropFaceClustersTable);
     await db.execute(dropFaceClustersTable);
@@ -726,7 +736,7 @@ class FaceMLDataDB {
     for (final enteFile in files) {
     for (final enteFile in files) {
       fileIds.add(enteFile.uploadedFileID.toString());
       fileIds.add(enteFile.uploadedFileID.toString());
     }
     }
-    int maxClusterID = DateTime.now().millisecondsSinceEpoch;
+    int maxClusterID = DateTime.now().microsecondsSinceEpoch;
     final Map<String, int> faceIDToClusterID = {};
     final Map<String, int> faceIDToClusterID = {};
     for (final row in faceIdsResult) {
     for (final row in faceIdsResult) {
       final faceID = row[fcFaceId] as String;
       final faceID = row[fcFaceId] as String;
@@ -752,7 +762,7 @@ class FaceMLDataDB {
     for (final enteFile in files) {
     for (final enteFile in files) {
       fileIds.add(enteFile.uploadedFileID.toString());
       fileIds.add(enteFile.uploadedFileID.toString());
     }
     }
-    int maxClusterID = DateTime.now().millisecondsSinceEpoch;
+    int maxClusterID = DateTime.now().microsecondsSinceEpoch;
     final Map<String, int> faceIDToClusterID = {};
     final Map<String, int> faceIDToClusterID = {};
     for (final row in faceIdsResult) {
     for (final row in faceIdsResult) {
       final faceID = row[fcFaceId] as String;
       final faceID = row[fcFaceId] as String;
@@ -763,4 +773,16 @@ class FaceMLDataDB {
     }
     }
     await forceUpdateClusterIds(faceIDToClusterID);
     await forceUpdateClusterIds(faceIDToClusterID);
   }
   }
+
+  Future<void> addFacesToCluster(
+    List<String> faceIDs,
+    int clusterID,
+  ) async {
+    final faceIDToClusterID = <String, int>{};
+    for (final faceID in faceIDs) {
+      faceIDToClusterID[faceID] = clusterID;
+    }
+
+    await forceUpdateClusterIds(faceIDToClusterID);
+  }
 }
 }

+ 42 - 0
mobile/lib/face/model/detection.dart

@@ -3,6 +3,9 @@ import "package:photos/face/model/landmark.dart";
 
 
 /// Stores the face detection data, notably the bounding box and landmarks.
 /// Stores the face detection data, notably the bounding box and landmarks.
 ///
 ///
+/// - Bounding box: [FaceBox] with xMin, yMin (so top left corner), width, height
+/// - Landmarks: list of [Landmark]s, namely leftEye, rightEye, nose, leftMouth, rightMouth
+///
 /// WARNING: All coordinates are relative to the image size, so in the range [0, 1]!
 /// WARNING: All coordinates are relative to the image size, so in the range [0, 1]!
 class Detection {
 class Detection {
   FaceBox box;
   FaceBox box;
@@ -39,4 +42,43 @@ class Detection {
       ),
       ),
     );
     );
   }
   }
+
+  // TODO: iterate on better area calculation, potentially using actual indexing image dimensions instead of file metadata
+  int getFaceArea(int imageWidth, int imageHeight) {
+    return (box.width * imageWidth * box.height * imageHeight).toInt();
+  }
+
+  // TODO: iterate on better scoring logic, current is a placeholder
+  int getVisibilityScore() {
+    final double aspectRatio = box.width / box.height;
+    final double eyeDistance = (landmarks[1].x - landmarks[0].x).abs();
+    final double mouthDistance = (landmarks[4].x - landmarks[3].x).abs();
+    final double noseEyeDistance =
+        (landmarks[2].y - ((landmarks[0].y + landmarks[1].y) / 2)).abs();
+
+    final double normalizedEyeDistance = eyeDistance / box.width;
+    final double normalizedMouthDistance = mouthDistance / box.width;
+    final double normalizedNoseEyeDistance = noseEyeDistance / box.height;
+
+    const double aspectRatioThreshold = 0.8;
+    const double eyeDistanceThreshold = 0.2;
+    const double mouthDistanceThreshold = 0.3;
+    const double noseEyeDistanceThreshold = 0.1;
+
+    double score = 0;
+    if (aspectRatio >= aspectRatioThreshold) {
+      score += 50;
+    }
+    if (normalizedEyeDistance >= eyeDistanceThreshold) {
+      score += 20;
+    }
+    if (normalizedMouthDistance >= mouthDistanceThreshold) {
+      score += 20;
+    }
+    if (normalizedNoseEyeDistance >= noseEyeDistanceThreshold) {
+      score += 10;
+    }
+
+    return score.clamp(0, 100).toInt();
+  }
 }
 }

+ 20 - 8
mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart

@@ -353,14 +353,8 @@ class FaceClustering {
     // Make sure the first face has a clusterId
     // Make sure the first face has a clusterId
     final int totalFaces = sortedFaceInfos.length;
     final int totalFaces = sortedFaceInfos.length;
     // set current epoch time as clusterID
     // set current epoch time as clusterID
-    int clusterID = DateTime.now().millisecondsSinceEpoch;
-    if (sortedFaceInfos.isNotEmpty) {
-      if (sortedFaceInfos.first.clusterId == null) {
-        sortedFaceInfos.first.clusterId = clusterID;
-      } else {
-        clusterID = sortedFaceInfos.first.clusterId!;
-      }
-    } else {
+    int clusterID = DateTime.now().microsecondsSinceEpoch;
+    if (sortedFaceInfos.isEmpty) {
       return {};
       return {};
     }
     }
 
 
@@ -401,6 +395,12 @@ class FaceClustering {
         if (distance < closestDistance) {
         if (distance < closestDistance) {
           closestDistance = distance;
           closestDistance = distance;
           closestIdx = j;
           closestIdx = j;
+          // if (distance < distanceThreshold) {
+          //   if (sortedFaceInfos[j].faceID.startsWith("14914702") ||
+          //       sortedFaceInfos[j].faceID.startsWith("15488756")) {
+          //     log('[XXX] faceIDs: ${sortedFaceInfos[j].faceID} and ${sortedFaceInfos[i].faceID} with distance $distance');
+          //   }
+          // }
         }
         }
       }
       }
 
 
@@ -414,10 +414,22 @@ class FaceClustering {
           sortedFaceInfos[closestIdx].clusterId = clusterID;
           sortedFaceInfos[closestIdx].clusterId = clusterID;
           newFaceIdToCluster[sortedFaceInfos[closestIdx].faceID] = clusterID;
           newFaceIdToCluster[sortedFaceInfos[closestIdx].faceID] = clusterID;
         }
         }
+        // if (sortedFaceInfos[i].faceID.startsWith("14914702") ||
+        //     sortedFaceInfos[i].faceID.startsWith("15488756")) {
+        //   log(
+        //     "[XXX]  [ClusterIsolate] ${DateTime.now()} Found similar face ${sortedFaceInfos[i].faceID} to ${sortedFaceInfos[closestIdx].faceID} with distance $closestDistance",
+        //   );
+        // }
         sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId;
         sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId;
         newFaceIdToCluster[sortedFaceInfos[i].faceID] =
         newFaceIdToCluster[sortedFaceInfos[i].faceID] =
             sortedFaceInfos[closestIdx].clusterId!;
             sortedFaceInfos[closestIdx].clusterId!;
       } else {
       } else {
+        // if (sortedFaceInfos[i].faceID.startsWith("14914702") ||
+        //     sortedFaceInfos[i].faceID.startsWith("15488756")) {
+        //   log(
+        //     "[XXX]  [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID for face ${sortedFaceInfos[i].faceID}",
+        //   );
+        // }
         clusterID++;
         clusterID++;
         sortedFaceInfos[i].clusterId = clusterID;
         sortedFaceInfos[i].clusterId = clusterID;
         newFaceIdToCluster[sortedFaceInfos[i].faceID] = clusterID;
         newFaceIdToCluster[sortedFaceInfos[i].faceID] = clusterID;

+ 1 - 1
mobile/lib/services/machine_learning/face_ml/face_ml_service.dart

@@ -654,7 +654,7 @@ class FaceMlService {
                 .map(
                 .map(
                   (keypoint) => Landmark(
                   (keypoint) => Landmark(
                     x: keypoint[0],
                     x: keypoint[0],
-                    y: keypoint[0],
+                    y: keypoint[1],
                   ),
                   ),
                 )
                 )
                 .toList(),
                 .toList(),

+ 18 - 5
mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart

@@ -325,12 +325,25 @@ class ClusterFeedbackService {
     }
     }
   }
   }
 
 
-  Future<void> removeFilesFromPerson(List<EnteFile> files, Person p) {
-    return FaceMLDataDB.instance.removeFilesFromPerson(files, p);
+  Future<void> removeFilesFromPerson(List<EnteFile> files, Person p) async {
+    await FaceMLDataDB.instance.removeFilesFromPerson(files, p);
+    Bus.instance.fire(PeopleChangedEvent());
+    return;
+  }
+
+  Future<void> removeFilesFromCluster(
+    List<EnteFile> files,
+    int clusterID,
+  ) async {
+    await FaceMLDataDB.instance.removeFilesFromCluster(files, clusterID);
+    Bus.instance.fire(PeopleChangedEvent());
+    return;
   }
   }
 
 
-  Future<void> removeFilesFromCluster(List<EnteFile> files, int clusterID) {
-    return FaceMLDataDB.instance.removeFilesFromCluster(files, clusterID);
+  Future<void> addFilesToCluster(List<String> faceIDs, int clusterID) async {
+    await FaceMLDataDB.instance.addFacesToCluster(faceIDs, clusterID);
+    Bus.instance.fire(PeopleChangedEvent());
+    return;
   }
   }
 
 
   Future<bool> checkAndDoAutomaticMerges(Person p) async {
   Future<bool> checkAndDoAutomaticMerges(Person p) async {
@@ -413,7 +426,7 @@ class ClusterFeedbackService {
         embeddings,
         embeddings,
         fileIDToCreationTime: fileIDToCreationTime,
         fileIDToCreationTime: fileIDToCreationTime,
         eps: 0.30,
         eps: 0.30,
-        minPts: 5,
+        minPts: 8,
       );
       );
 
 
       if (dbscanClusters.isEmpty) {
       if (dbscanClusters.isEmpty) {

+ 3 - 1
mobile/lib/ui/settings/debug/face_debug_section_widget.dart

@@ -114,7 +114,9 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
                 .getTotalFaceCount(minFaceScore: 0.75);
                 .getTotalFaceCount(minFaceScore: 0.75);
             final faces78 = await FaceMLDataDB.instance
             final faces78 = await FaceMLDataDB.instance
                 .getTotalFaceCount(minFaceScore: kMinHighQualityFaceScore);
                 .getTotalFaceCount(minFaceScore: kMinHighQualityFaceScore);
-            showShortToast(context, "Faces75: $faces75, Faces78: $faces78");
+            final blurryFaceCount =
+                await FaceMLDataDB.instance.getBlurryFaceCount(15);
+            showShortToast(context, "$blurryFaceCount blurry faces");
           },
           },
         ),
         ),
         // MenuItemWidget(
         // MenuItemWidget(

+ 131 - 57
mobile/lib/ui/viewer/file_details/face_widget.dart

@@ -2,12 +2,14 @@ import "dart:developer" show log;
 import "dart:io" show Platform;
 import "dart:io" show Platform;
 import "dart:typed_data";
 import "dart:typed_data";
 
 
+import "package:flutter/cupertino.dart";
 import "package:flutter/foundation.dart" show kDebugMode;
 import "package:flutter/foundation.dart" show kDebugMode;
 import "package:flutter/material.dart";
 import "package:flutter/material.dart";
 import "package:photos/face/db.dart";
 import "package:photos/face/db.dart";
 import "package:photos/face/model/face.dart";
 import "package:photos/face/model/face.dart";
 import "package:photos/face/model/person.dart";
 import "package:photos/face/model/person.dart";
 import 'package:photos/models/file/file.dart';
 import 'package:photos/models/file/file.dart';
+import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
 import "package:photos/services/search_service.dart";
 import "package:photos/services/search_service.dart";
 import "package:photos/theme/ente_theme.dart";
 import "package:photos/theme/ente_theme.dart";
 import "package:photos/ui/viewer/file/no_thumbnail_widget.dart";
 import "package:photos/ui/viewer/file/no_thumbnail_widget.dart";
@@ -16,13 +18,15 @@ import "package:photos/ui/viewer/people/cropped_face_image_view.dart";
 import "package:photos/ui/viewer/people/people_page.dart";
 import "package:photos/ui/viewer/people/people_page.dart";
 import "package:photos/utils/face/face_box_crop.dart";
 import "package:photos/utils/face/face_box_crop.dart";
 import "package:photos/utils/thumbnail_util.dart";
 import "package:photos/utils/thumbnail_util.dart";
+// import "package:photos/utils/toast_util.dart";
 
 
-class FaceWidget extends StatelessWidget {
+class FaceWidget extends StatefulWidget {
   final EnteFile file;
   final EnteFile file;
   final Face face;
   final Face face;
   final Person? person;
   final Person? person;
   final int? clusterID;
   final int? clusterID;
   final bool highlight;
   final bool highlight;
+  final bool editMode;
 
 
   const FaceWidget(
   const FaceWidget(
     this.file,
     this.file,
@@ -30,9 +34,17 @@ class FaceWidget extends StatelessWidget {
     this.person,
     this.person,
     this.clusterID,
     this.clusterID,
     this.highlight = false,
     this.highlight = false,
+    this.editMode = false,
     Key? key,
     Key? key,
   }) : super(key: key);
   }) : super(key: key);
 
 
+  @override
+  State<FaceWidget> createState() => _FaceWidgetState();
+}
+
+class _FaceWidgetState extends State<FaceWidget> {
+  bool isJustRemoved = false;
+
   @override
   @override
   Widget build(BuildContext context) {
   Widget build(BuildContext context) {
     if (Platform.isIOS || Platform.isAndroid) {
     if (Platform.isIOS || Platform.isAndroid) {
@@ -43,22 +55,24 @@ class FaceWidget extends StatelessWidget {
             final ImageProvider imageProvider = MemoryImage(snapshot.data!);
             final ImageProvider imageProvider = MemoryImage(snapshot.data!);
             return GestureDetector(
             return GestureDetector(
               onTap: () async {
               onTap: () async {
+                if (widget.editMode) return;
+
                 log(
                 log(
-                  "FaceWidget is tapped, with person $person and clusterID $clusterID",
+                  "FaceWidget is tapped, with person ${widget.person} and clusterID ${widget.clusterID}",
                   name: "FaceWidget",
                   name: "FaceWidget",
                 );
                 );
-                if (person == null && clusterID == null) {
+                if (widget.person == null && widget.clusterID == null) {
                   return;
                   return;
                 }
                 }
-                if (person != null) {
+                if (widget.person != null) {
                   await Navigator.of(context).push(
                   await Navigator.of(context).push(
                     MaterialPageRoute(
                     MaterialPageRoute(
                       builder: (context) => PeoplePage(
                       builder: (context) => PeoplePage(
-                        person: person!,
+                        person: widget.person!,
                       ),
                       ),
                     ),
                     ),
                   );
                   );
-                } else if (clusterID != null) {
+                } else if (widget.clusterID != null) {
                   final fileIdsToClusterIds =
                   final fileIdsToClusterIds =
                       await FaceMLDataDB.instance.getFileIdToClusterIds();
                       await FaceMLDataDB.instance.getFileIdToClusterIds();
                   final files = await SearchService.instance.getAllFiles();
                   final files = await SearchService.instance.getAllFiles();
@@ -66,7 +80,7 @@ class FaceWidget extends StatelessWidget {
                       .where(
                       .where(
                         (file) =>
                         (file) =>
                             fileIdsToClusterIds[file.uploadedFileID]
                             fileIdsToClusterIds[file.uploadedFileID]
-                                ?.contains(clusterID) ??
+                                ?.contains(widget.clusterID) ??
                             false,
                             false,
                       )
                       )
                       .toList();
                       .toList();
@@ -74,7 +88,7 @@ class FaceWidget extends StatelessWidget {
                     MaterialPageRoute(
                     MaterialPageRoute(
                       builder: (context) => ClusterPage(
                       builder: (context) => ClusterPage(
                         clusterFiles,
                         clusterFiles,
-                        clusterID: clusterID!,
+                        clusterID: widget.clusterID!,
                       ),
                       ),
                     ),
                     ),
                   );
                   );
@@ -82,46 +96,87 @@ class FaceWidget extends StatelessWidget {
               },
               },
               child: Column(
               child: Column(
                 children: [
                 children: [
-                  // TODO: the edges of the green line are still not properly rounded around ClipRRect
-                  Container(
-                    height: 60,
-                    width: 60,
-                    decoration: ShapeDecoration(
-                      shape: RoundedRectangleBorder(
-                        borderRadius:
-                            const BorderRadius.all(Radius.elliptical(16, 12)),
-                        side: highlight
-                            ? BorderSide(
-                                color: getEnteColorScheme(context).primary700,
-                                width: 2.0,
-                              )
-                            : BorderSide.none,
-                      ),
-                    ),
-                    child: ClipRRect(
-                      borderRadius:
-                          const BorderRadius.all(Radius.elliptical(16, 12)),
-                      child: SizedBox(
-                        width: 60,
+                  Stack(
+                    children: [
+                      Container(
                         height: 60,
                         height: 60,
-                        child: Image(
-                          image: imageProvider,
-                          fit: BoxFit.cover,
+                        width: 60,
+                        decoration: ShapeDecoration(
+                          shape: RoundedRectangleBorder(
+                            borderRadius: const BorderRadius.all(
+                              Radius.elliptical(16, 12),
+                            ),
+                            side: widget.highlight
+                                ? BorderSide(
+                                    color:
+                                        getEnteColorScheme(context).primary700,
+                                    width: 1.0,
+                                  )
+                                : BorderSide.none,
+                          ),
+                        ),
+                        child: ClipRRect(
+                          borderRadius:
+                              const BorderRadius.all(Radius.elliptical(16, 12)),
+                          child: SizedBox(
+                            width: 60,
+                            height: 60,
+                            child: Image(
+                              image: imageProvider,
+                              fit: BoxFit.cover,
+                            ),
+                          ),
                         ),
                         ),
                       ),
                       ),
-                    ),
+                      // TODO: the edges of the green line are still not properly rounded around ClipRRect
+                      if (widget.editMode)
+                        Positioned(
+                          right: 0,
+                          top: 0,
+                          child: GestureDetector(
+                            onTap: _cornerIconPressed,
+                            child: isJustRemoved
+                                ? const Icon(
+                                    CupertinoIcons.add_circled_solid,
+                                    color: Colors.green,
+                                  )
+                                : const Icon(
+                                    Icons.cancel,
+                                    color: Colors.red,
+                                  ),
+                          ),
+                        ),
+                    ],
                   ),
                   ),
                   const SizedBox(height: 8),
                   const SizedBox(height: 8),
-                  if (person != null)
+                  if (widget.person != null)
                     Text(
                     Text(
-                      person!.attr.name.trim(),
+                      widget.person!.attr.name.trim(),
                       style: Theme.of(context).textTheme.bodySmall,
                       style: Theme.of(context).textTheme.bodySmall,
                       overflow: TextOverflow.ellipsis,
                       overflow: TextOverflow.ellipsis,
                       maxLines: 1,
                       maxLines: 1,
                     ),
                     ),
                   if (kDebugMode)
                   if (kDebugMode)
                     Text(
                     Text(
-                      'S: ${face.score.toStringAsFixed(3)}',
+                      'S: ${widget.face.score.toStringAsFixed(3)}',
+                      style: Theme.of(context).textTheme.bodySmall,
+                      maxLines: 1,
+                    ),
+                  if (kDebugMode)
+                    Text(
+                      'B: ${widget.face.blur.toStringAsFixed(3)}',
+                      style: Theme.of(context).textTheme.bodySmall,
+                      maxLines: 1,
+                    ),
+                  if (kDebugMode)
+                    Text(
+                      'V: ${widget.face.detection.getVisibilityScore()}',
+                      style: Theme.of(context).textTheme.bodySmall,
+                      maxLines: 1,
+                    ),
+                  if (kDebugMode)
+                    Text(
+                      'A: ${widget.face.detection.getFaceArea(widget.file.width, widget.file.height)}',
                       style: Theme.of(context).textTheme.bodySmall,
                       style: Theme.of(context).textTheme.bodySmall,
                       maxLines: 1,
                       maxLines: 1,
                     ),
                     ),
@@ -168,21 +223,21 @@ class FaceWidget extends StatelessWidget {
           return GestureDetector(
           return GestureDetector(
             onTap: () async {
             onTap: () async {
               log(
               log(
-                "FaceWidget is tapped, with person $person and clusterID $clusterID",
+                "FaceWidget is tapped, with person ${widget.person} and clusterID ${widget.clusterID}",
                 name: "FaceWidget",
                 name: "FaceWidget",
               );
               );
-              if (person == null && clusterID == null) {
+              if (widget.person == null && widget.clusterID == null) {
                 return;
                 return;
               }
               }
-              if (person != null) {
+              if (widget.person != null) {
                 await Navigator.of(context).push(
                 await Navigator.of(context).push(
                   MaterialPageRoute(
                   MaterialPageRoute(
                     builder: (context) => PeoplePage(
                     builder: (context) => PeoplePage(
-                      person: person!,
+                      person: widget.person!,
                     ),
                     ),
                   ),
                   ),
                 );
                 );
-              } else if (clusterID != null) {
+              } else if (widget.clusterID != null) {
                 final fileIdsToClusterIds =
                 final fileIdsToClusterIds =
                     await FaceMLDataDB.instance.getFileIdToClusterIds();
                     await FaceMLDataDB.instance.getFileIdToClusterIds();
                 final files = await SearchService.instance.getAllFiles();
                 final files = await SearchService.instance.getAllFiles();
@@ -190,7 +245,7 @@ class FaceWidget extends StatelessWidget {
                     .where(
                     .where(
                       (file) =>
                       (file) =>
                           fileIdsToClusterIds[file.uploadedFileID]
                           fileIdsToClusterIds[file.uploadedFileID]
-                              ?.contains(clusterID) ??
+                              ?.contains(widget.clusterID) ??
                           false,
                           false,
                     )
                     )
                     .toList();
                     .toList();
@@ -198,7 +253,7 @@ class FaceWidget extends StatelessWidget {
                   MaterialPageRoute(
                   MaterialPageRoute(
                     builder: (context) => ClusterPage(
                     builder: (context) => ClusterPage(
                       clusterFiles,
                       clusterFiles,
-                      clusterID: clusterID!,
+                      clusterID: widget.clusterID!,
                     ),
                     ),
                   ),
                   ),
                 );
                 );
@@ -213,7 +268,7 @@ class FaceWidget extends StatelessWidget {
                     shape: RoundedRectangleBorder(
                     shape: RoundedRectangleBorder(
                       borderRadius:
                       borderRadius:
                           const BorderRadius.all(Radius.elliptical(16, 12)),
                           const BorderRadius.all(Radius.elliptical(16, 12)),
-                      side: highlight
+                      side: widget.highlight
                           ? BorderSide(
                           ? BorderSide(
                               color: getEnteColorScheme(context).primary700,
                               color: getEnteColorScheme(context).primary700,
                               width: 2.0,
                               width: 2.0,
@@ -228,23 +283,23 @@ class FaceWidget extends StatelessWidget {
                       width: 60,
                       width: 60,
                       height: 60,
                       height: 60,
                       child: CroppedFaceImageView(
                       child: CroppedFaceImageView(
-                        enteFile: file,
-                        face: face,
+                        enteFile: widget.file,
+                        face: widget.face,
                       ),
                       ),
                     ),
                     ),
                   ),
                   ),
                 ),
                 ),
                 const SizedBox(height: 8),
                 const SizedBox(height: 8),
-                if (person != null)
+                if (widget.person != null)
                   Text(
                   Text(
-                    person!.attr.name.trim(),
+                    widget.person!.attr.name.trim(),
                     style: Theme.of(context).textTheme.bodySmall,
                     style: Theme.of(context).textTheme.bodySmall,
                     overflow: TextOverflow.ellipsis,
                     overflow: TextOverflow.ellipsis,
                     maxLines: 1,
                     maxLines: 1,
                   ),
                   ),
                 if (kDebugMode)
                 if (kDebugMode)
                   Text(
                   Text(
-                    'S: ${face.score.toStringAsFixed(3)}',
+                    'S: ${widget.face.score.toStringAsFixed(3)}',
                     style: Theme.of(context).textTheme.bodySmall,
                     style: Theme.of(context).textTheme.bodySmall,
                     maxLines: 1,
                     maxLines: 1,
                   ),
                   ),
@@ -256,36 +311,55 @@ class FaceWidget extends StatelessWidget {
     }
     }
   }
   }
 
 
+  void _cornerIconPressed() async {
+    log('face widget (file info) corner icon is pressed');
+    try {
+      if (isJustRemoved) {
+        await ClusterFeedbackService.instance
+            .addFilesToCluster([widget.face.faceID], widget.clusterID!);
+      } else {
+        await ClusterFeedbackService.instance
+            .removeFilesFromCluster([widget.file], widget.clusterID!);
+      }
+
+      setState(() {
+        isJustRemoved = !isJustRemoved;
+      });
+    } catch (e, s) {
+      log("removing face/file from cluster from file info widget failed: $e, \n $s");
+    }
+  }
+
   Future<Uint8List?> getFaceCrop() async {
   Future<Uint8List?> getFaceCrop() async {
     try {
     try {
-      final Uint8List? cachedFace = faceCropCache.get(face.faceID);
+      final Uint8List? cachedFace = faceCropCache.get(widget.face.faceID);
       if (cachedFace != null) {
       if (cachedFace != null) {
         return cachedFace;
         return cachedFace;
       }
       }
-      final faceCropCacheFile = cachedFaceCropPath(face.faceID);
+      final faceCropCacheFile = cachedFaceCropPath(widget.face.faceID);
       if ((await faceCropCacheFile.exists())) {
       if ((await faceCropCacheFile.exists())) {
         final data = await faceCropCacheFile.readAsBytes();
         final data = await faceCropCacheFile.readAsBytes();
-        faceCropCache.put(face.faceID, data);
+        faceCropCache.put(widget.face.faceID, data);
         return data;
         return data;
       }
       }
 
 
       final result = await pool.withResource(
       final result = await pool.withResource(
         () async => await getFaceCrops(
         () async => await getFaceCrops(
-          file,
+          widget.file,
           {
           {
-            face.faceID: face.detection.box,
+            widget.face.faceID: widget.face.detection.box,
           },
           },
         ),
         ),
       );
       );
-      final Uint8List? computedCrop = result?[face.faceID];
+      final Uint8List? computedCrop = result?[widget.face.faceID];
       if (computedCrop != null) {
       if (computedCrop != null) {
-        faceCropCache.put(face.faceID, computedCrop);
+        faceCropCache.put(widget.face.faceID, computedCrop);
         faceCropCacheFile.writeAsBytes(computedCrop).ignore();
         faceCropCacheFile.writeAsBytes(computedCrop).ignore();
       }
       }
       return computedCrop;
       return computedCrop;
     } catch (e, s) {
     } catch (e, s) {
       log(
       log(
-        "Error getting face for faceID: ${face.faceID}",
+        "Error getting face for faceID: ${widget.face.faceID}",
         error: e,
         error: e,
         stackTrace: s,
         stackTrace: s,
       );
       );

+ 32 - 7
mobile/lib/ui/viewer/file_details/faces_item_widget.dart

@@ -1,3 +1,4 @@
+import "package:flutter/foundation.dart" show kDebugMode;
 import "package:flutter/material.dart";
 import "package:flutter/material.dart";
 import "package:logging/logging.dart";
 import "package:logging/logging.dart";
 import "package:photos/face/db.dart";
 import "package:photos/face/db.dart";
@@ -9,23 +10,44 @@ import "package:photos/ui/components/buttons/chip_button_widget.dart";
 import "package:photos/ui/components/info_item_widget.dart";
 import "package:photos/ui/components/info_item_widget.dart";
 import "package:photos/ui/viewer/file_details/face_widget.dart";
 import "package:photos/ui/viewer/file_details/face_widget.dart";
 
 
-class FacesItemWidget extends StatelessWidget {
+class FacesItemWidget extends StatefulWidget {
   final EnteFile file;
   final EnteFile file;
   const FacesItemWidget(this.file, {super.key});
   const FacesItemWidget(this.file, {super.key});
 
 
+  @override
+  State<FacesItemWidget> createState() => _FacesItemWidgetState();
+}
+
+class _FacesItemWidgetState extends State<FacesItemWidget> {
+  bool editMode = false;
+
+  @override
+  void initState() {
+    super.initState();
+    setState(() {});
+  }
+
   @override
   @override
   Widget build(BuildContext context) {
   Widget build(BuildContext context) {
     return InfoItemWidget(
     return InfoItemWidget(
       key: const ValueKey("Faces"),
       key: const ValueKey("Faces"),
       leadingIcon: Icons.face_retouching_natural_outlined,
       leadingIcon: Icons.face_retouching_natural_outlined,
-      subtitleSection: _faceWidgets(context, file),
+      subtitleSection: _faceWidgets(context, widget.file, editMode),
       hasChipButtons: true,
       hasChipButtons: true,
+      editOnTap: _toggleEditMode,
     );
     );
   }
   }
 
 
+  void _toggleEditMode() {
+    setState(() {
+      editMode = !editMode;
+    });
+  }
+
   Future<List<Widget>> _faceWidgets(
   Future<List<Widget>> _faceWidgets(
     BuildContext context,
     BuildContext context,
     EnteFile file,
     EnteFile file,
+    bool editMode,
   ) async {
   ) async {
     try {
     try {
       if (file.uploadedFileID == null) {
       if (file.uploadedFileID == null) {
@@ -47,8 +69,13 @@ class FacesItemWidget extends StatelessWidget {
           ),
           ),
         ];
         ];
       }
       }
-      if (faces.isEmpty ||
-          faces.every((face) => face.score < 0.75 || face.isBlurry)) {
+
+      // Remove faces with low scores and blurry faces
+      if (!kDebugMode) {
+        faces.removeWhere((face) => (face.isBlurry || face.score < 0.75));
+      }
+
+      if (faces.isEmpty) {
         return [
         return [
           const ChipButtonWidget(
           const ChipButtonWidget(
             "No faces found",
             "No faces found",
@@ -60,9 +87,6 @@ class FacesItemWidget extends StatelessWidget {
       // Sort the faces by score in descending order, so that the highest scoring face is first.
       // Sort the faces by score in descending order, so that the highest scoring face is first.
       faces.sort((Face a, Face b) => b.score.compareTo(a.score));
       faces.sort((Face a, Face b) => b.score.compareTo(a.score));
 
 
-      // Remove faces with low scores and blurry faces
-      faces.removeWhere((face) => (face.isBlurry || face.score < 0.75));
-
       // TODO: add deduplication of faces of same person
       // TODO: add deduplication of faces of same person
       final faceIdsToClusterIds = await FaceMLDataDB.instance
       final faceIdsToClusterIds = await FaceMLDataDB.instance
           .getFaceIdsToClusterIds(faces.map((face) => face.faceID));
           .getFaceIdsToClusterIds(faces.map((face) => face.faceID));
@@ -84,6 +108,7 @@ class FacesItemWidget extends StatelessWidget {
             clusterID: clusterID,
             clusterID: clusterID,
             person: person,
             person: person,
             highlight: highlight,
             highlight: highlight,
+            editMode: highlight ? editMode : false,
           ),
           ),
         );
         );
       }
       }

+ 94 - 22
mobile/lib/ui/viewer/people/cluster_app_bar.dart

@@ -3,21 +3,22 @@ import 'dart:async';
 import "package:flutter/foundation.dart";
 import "package:flutter/foundation.dart";
 import 'package:flutter/material.dart';
 import 'package:flutter/material.dart';
 import 'package:logging/logging.dart';
 import 'package:logging/logging.dart';
+import "package:ml_linalg/linalg.dart";
 import 'package:photos/core/configuration.dart';
 import 'package:photos/core/configuration.dart';
 import 'package:photos/core/event_bus.dart';
 import 'package:photos/core/event_bus.dart';
 import "package:photos/db/files_db.dart";
 import "package:photos/db/files_db.dart";
-// import "package:photos/events/people_changed_event.dart";
 import 'package:photos/events/subscription_purchased_event.dart';
 import 'package:photos/events/subscription_purchased_event.dart';
-// import "package:photos/face/db.dart";
+import "package:photos/face/db.dart";
 import "package:photos/face/model/person.dart";
 import "package:photos/face/model/person.dart";
+import "package:photos/generated/protos/ente/common/vector.pb.dart";
 import 'package:photos/models/gallery_type.dart';
 import 'package:photos/models/gallery_type.dart';
 import 'package:photos/models/selected_files.dart';
 import 'package:photos/models/selected_files.dart';
 import 'package:photos/services/collections_service.dart';
 import 'package:photos/services/collections_service.dart';
+import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
 import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
 import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
 import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
 import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
 import 'package:photos/ui/actions/collection/collection_sharing_actions.dart';
 import 'package:photos/ui/actions/collection/collection_sharing_actions.dart';
-import "package:photos/ui/viewer/people/cluster_page.dart";
-// import "package:photos/utils/dialog_util.dart";
+import "package:photos/ui/viewer/people/cluster_breakup_page.dart";
 
 
 class ClusterAppBar extends StatefulWidget {
 class ClusterAppBar extends StatefulWidget {
   final GalleryType type;
   final GalleryType type;
@@ -42,6 +43,7 @@ class ClusterAppBar extends StatefulWidget {
 enum ClusterPopupAction {
 enum ClusterPopupAction {
   setCover,
   setCover,
   breakupCluster,
   breakupCluster,
+  validateCluster,
   hide,
   hide,
 }
 }
 
 
@@ -130,6 +132,18 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
             ],
             ],
           ),
           ),
         ),
         ),
+        const PopupMenuItem(
+          value: ClusterPopupAction.validateCluster,
+          child: Row(
+            children: [
+              Icon(Icons.search_off_outlined),
+              Padding(
+                padding: EdgeInsets.all(8),
+              ),
+              Text('Validate cluster'),
+            ],
+          ),
+        ),
         // PopupMenuItem(
         // PopupMenuItem(
         //   value: ClusterPopupAction.hide,
         //   value: ClusterPopupAction.hide,
         //   child: Row(
         //   child: Row(
@@ -155,6 +169,8 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
             if (value == ClusterPopupAction.breakupCluster) {
             if (value == ClusterPopupAction.breakupCluster) {
               // ignore: unawaited_futures
               // ignore: unawaited_futures
               await _breakUpCluster(context);
               await _breakUpCluster(context);
+            } else if (value == ClusterPopupAction.validateCluster) {
+              await _validateCluster(context);
             }
             }
             // else if (value == ClusterPopupAction.setCover) {
             // else if (value == ClusterPopupAction.setCover) {
             //   await setCoverPhoto(context);
             //   await setCoverPhoto(context);
@@ -169,28 +185,84 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
     return actions;
     return actions;
   }
   }
 
 
+  Future<void> _validateCluster(BuildContext context) async {
+    _logger.info('_validateCluster called');
+    final faceMlDb = FaceMLDataDB.instance;
+
+    final faceIDs = await faceMlDb.getFaceIDsForCluster(widget.clusterID);
+    final fileIDs = faceIDs.map((e) => getFileIdFromFaceId(e)).toList();
+
+    final embeddingsBlobs = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs);
+    embeddingsBlobs.removeWhere((key, value) => !faceIDs.contains(key));
+    final embeddings = embeddingsBlobs
+        .map((key, value) => MapEntry(key, EVector.fromBuffer(value).values));
+
+    for (final MapEntry<String, List<double>> embedding in embeddings.entries) {
+      double closestDistance = double.infinity;
+      double closestDistance32 = double.infinity;
+      double closestDistance64 = double.infinity;
+      String? closestFaceID;
+      for (final MapEntry<String, List<double>> otherEmbedding
+          in embeddings.entries) {
+        if (embedding.key == otherEmbedding.key) {
+          continue;
+        }
+        final distance64 = 1.0 -
+            Vector.fromList(embedding.value, dtype: DType.float64).dot(
+              Vector.fromList(otherEmbedding.value, dtype: DType.float64),
+            );
+        final distance32 = 1.0 -
+            Vector.fromList(embedding.value, dtype: DType.float32).dot(
+              Vector.fromList(otherEmbedding.value, dtype: DType.float32),
+            );
+        final distance = cosineDistForNormVectors(
+          embedding.value,
+          otherEmbedding.value,
+        );
+        if (distance < closestDistance) {
+          closestDistance = distance;
+          closestDistance32 = distance32;
+          closestDistance64 = distance64;
+          closestFaceID = otherEmbedding.key;
+        }
+      }
+      if (closestDistance > 0.3) {
+        _logger.severe(
+          "Face ${embedding.key} is similar to $closestFaceID with distance $closestDistance, and float32 distance $closestDistance32, and float64 distance $closestDistance64",
+        );
+      }
+    }
+  }
+
   Future<void> _breakUpCluster(BuildContext context) async {
   Future<void> _breakUpCluster(BuildContext context) async {
     final newClusterIDToFaceIDs =
     final newClusterIDToFaceIDs =
         await ClusterFeedbackService.instance.breakUpCluster(widget.clusterID);
         await ClusterFeedbackService.instance.breakUpCluster(widget.clusterID);
 
 
-    for (final cluster in newClusterIDToFaceIDs.entries) {
-      // ignore: unawaited_futures
-      final newClusterID = cluster.key;
-      final faceIDs = cluster.value;
-      final files = await FilesDB.instance
-          .getFilesFromIDs(faceIDs.map((e) => getFileIdFromFaceId(e)).toList());
-      unawaited(
-        Navigator.of(context).push(
-          MaterialPageRoute(
-            builder: (context) => ClusterPage(
-              files.values.toList(),
-              appendTitle:
-                  (newClusterID == -1) ? "(Analysis noise)" : "(Analysis)",
-              clusterID: newClusterID,
-            ),
-          ),
+    final allFileIDs = newClusterIDToFaceIDs.values
+        .expand((e) => e)
+        .map((e) => getFileIdFromFaceId(e))
+        .toList();
+
+    final fileIDtoFile = await FilesDB.instance.getFilesFromIDs(
+      allFileIDs,
+    );
+
+    final newClusterIDToFiles = newClusterIDToFaceIDs.map(
+      (key, value) => MapEntry(
+        key,
+        value
+            .map((faceId) => fileIDtoFile[getFileIdFromFaceId(faceId)]!)
+            .toList(),
+      ),
+    );
+
+    await Navigator.of(context).push(
+      MaterialPageRoute(
+        builder: (context) => ClusterBreakupPage(
+          newClusterIDToFiles,
+          "(Analysis)",
         ),
         ),
-      );
-    }
+      ),
+    );
   }
   }
 }
 }

+ 124 - 0
mobile/lib/ui/viewer/people/cluster_breakup_page.dart

@@ -0,0 +1,124 @@
+import "package:flutter/material.dart";
+import "package:photos/models/file/file.dart";
+import "package:photos/theme/ente_theme.dart";
+import "package:photos/ui/viewer/file/no_thumbnail_widget.dart";
+import "package:photos/ui/viewer/people/cluster_page.dart";
+import "package:photos/ui/viewer/search/result/person_face_widget.dart";
+
+class ClusterBreakupPage extends StatefulWidget {
+  final Map<int, List<EnteFile>> newClusterIDsToFiles;
+  final String title;
+
+  const ClusterBreakupPage(
+    this.newClusterIDsToFiles,
+    this.title, {
+    super.key,
+  });
+
+  @override
+  State<ClusterBreakupPage> createState() => _ClusterBreakupPageState();
+}
+
+class _ClusterBreakupPageState extends State<ClusterBreakupPage> {
+  @override
+  Widget build(BuildContext context) {
+    final keys = widget.newClusterIDsToFiles.keys.toList();
+    final clusterIDsToFiles = widget.newClusterIDsToFiles;
+
+    return Scaffold(
+      appBar: AppBar(
+        title: Text(widget.title),
+      ),
+      body: ListView.builder(
+        itemCount: widget.newClusterIDsToFiles.keys.length,
+        itemBuilder: (context, index) {
+          final int clusterID = keys[index];
+          final List<EnteFile> files = clusterIDsToFiles[keys[index]]!;
+          return InkWell(
+            onTap: () {
+              Navigator.of(context).push(
+                MaterialPageRoute(
+                  builder: (context) => ClusterPage(
+                    files,
+                    clusterID: index,
+                    appendTitle: "(Analysis)",
+                  ),
+                ),
+              );
+            },
+            child: Container(
+              padding: const EdgeInsets.all(8.0),
+              child: Row(
+                children: <Widget>[
+                  SizedBox(
+                    width: 64,
+                    height: 64,
+                    child: files.isNotEmpty
+                        ? ClipRRect(
+                            borderRadius: const BorderRadius.all(
+                                Radius.elliptical(16, 12),),
+                            child: PersonFaceWidget(
+                              files.first,
+                              clusterID: clusterID,
+                            ),
+                          )
+                        : const ClipRRect(
+                            borderRadius:
+                                BorderRadius.all(Radius.elliptical(16, 12)),
+                            child: NoThumbnailWidget(
+                              addBorder: false,
+                            ),
+                          ),
+                  ),
+                  const SizedBox(
+                    width: 8.0,
+                  ), // Add some spacing between the thumbnail and the text
+                  Expanded(
+                    child: Padding(
+                      padding: const EdgeInsets.symmetric(horizontal: 8.0),
+                      child: Row(
+                        mainAxisAlignment: MainAxisAlignment.spaceBetween,
+                        children: <Widget>[
+                          Text(
+                            "${clusterIDsToFiles[keys[index]]!.length} photos",
+                            style: getEnteTextTheme(context).body,
+                          ),
+                          // GestureDetector(
+                          //   onTap: () async {
+                          //     try {
+                          //       final int result = await FaceMLDataDB
+                          //           .instance
+                          //           .removeClusterToPerson(
+                          //         personID: widget.person.remoteID,
+                          //         clusterID: clusterID,
+                          //       );
+                          //       _logger.info(
+                          //         "Removed cluster $clusterID from person ${widget.person.remoteID}, result: $result",
+                          //       );
+                          //       Bus.instance.fire(PeopleChangedEvent());
+                          //       setState(() {});
+                          //     } catch (e) {
+                          //       _logger.severe(
+                          //         "removing cluster from person,",
+                          //         e,
+                          //       );
+                          //     }
+                          //   },
+                          //   child: const Icon(
+                          //     CupertinoIcons.minus_circled,
+                          //     color: Colors.red,
+                          //   ),
+                          // ),
+                        ],
+                      ),
+                    ),
+                  ),
+                ],
+              ),
+            ),
+          );
+        },
+      ),
+    );
+  }
+}

+ 2 - 2
mobile/lib/ui/viewer/people/people_app_bar.dart

@@ -14,8 +14,8 @@ import 'package:photos/models/gallery_type.dart';
 import 'package:photos/models/selected_files.dart';
 import 'package:photos/models/selected_files.dart';
 import 'package:photos/services/collections_service.dart';
 import 'package:photos/services/collections_service.dart';
 import 'package:photos/ui/actions/collection/collection_sharing_actions.dart';
 import 'package:photos/ui/actions/collection/collection_sharing_actions.dart';
-import "package:photos/ui/viewer/people/person_cluserts.dart";
 import "package:photos/ui/viewer/people/person_cluster_suggestion.dart";
 import "package:photos/ui/viewer/people/person_cluster_suggestion.dart";
+import 'package:photos/ui/viewer/people/person_clusters_page.dart';
 import "package:photos/utils/dialog_util.dart";
 import "package:photos/utils/dialog_util.dart";
 
 
 class PeopleAppBar extends StatefulWidget {
 class PeopleAppBar extends StatefulWidget {
@@ -215,7 +215,7 @@ class _AppBarWidgetState extends State<PeopleAppBar> {
               unawaited(
               unawaited(
                 Navigator.of(context).push(
                 Navigator.of(context).push(
                   MaterialPageRoute(
                   MaterialPageRoute(
-                    builder: (context) => PersonClusters(widget.person),
+                    builder: (context) => PersonClustersPage(widget.person),
                   ),
                   ),
                 ),
                 ),
               );
               );

+ 12 - 6
mobile/lib/ui/viewer/people/person_cluserts.dart → mobile/lib/ui/viewer/people/person_clusters_page.dart

@@ -13,19 +13,19 @@ import "package:photos/ui/viewer/file/no_thumbnail_widget.dart";
 import "package:photos/ui/viewer/people/cluster_page.dart";
 import "package:photos/ui/viewer/people/cluster_page.dart";
 import "package:photos/ui/viewer/search/result/person_face_widget.dart";
 import "package:photos/ui/viewer/search/result/person_face_widget.dart";
 
 
-class PersonClusters extends StatefulWidget {
+class PersonClustersPage extends StatefulWidget {
   final Person person;
   final Person person;
 
 
-  const PersonClusters(
+  const PersonClustersPage(
     this.person, {
     this.person, {
     super.key,
     super.key,
   });
   });
 
 
   @override
   @override
-  State<PersonClusters> createState() => _PersonClustersState();
+  State<PersonClustersPage> createState() => _PersonClustersPageState();
 }
 }
 
 
-class _PersonClustersState extends State<PersonClusters> {
+class _PersonClustersPageState extends State<PersonClustersPage> {
   final Logger _logger = Logger("_PersonClustersState");
   final Logger _logger = Logger("_PersonClustersState");
   @override
   @override
   Widget build(BuildContext context) {
   Widget build(BuildContext context) {
@@ -64,13 +64,19 @@ class _PersonClustersState extends State<PersonClusters> {
                           width: 64,
                           width: 64,
                           height: 64,
                           height: 64,
                           child: files.isNotEmpty
                           child: files.isNotEmpty
-                              ? ClipOval(
+                              ? ClipRRect(
+                                  borderRadius: const BorderRadius.all(
+                                    Radius.elliptical(16, 12),
+                                  ),
                                   child: PersonFaceWidget(
                                   child: PersonFaceWidget(
                                     files.first,
                                     files.first,
                                     clusterID: clusterID,
                                     clusterID: clusterID,
                                   ),
                                   ),
                                 )
                                 )
-                              : const ClipOval(
+                              : const ClipRRect(
+                                  borderRadius: BorderRadius.all(
+                                    Radius.elliptical(16, 12),
+                                  ),
                                   child: NoThumbnailWidget(
                                   child: NoThumbnailWidget(
                                     addBorder: false,
                                     addBorder: false,
                                   ),
                                   ),