Browse Source

Merge branch 'mobile_face' into fix_face_thumbnail

ashilkn 1 year ago
parent
commit
8225697e43

+ 1 - 1
mobile/lib/face/model/detection.dart

@@ -155,7 +155,7 @@ class Detection {
         (nose[0] < min(leftEye[0], rightEye[0]) - 0.5 * eyeDistanceX) &&
             (nose[0] < min(leftMouth[0], rightMouth[0]));
     final bool noseStickingOutRight =
-        (nose[0] > max(leftEye[0], rightEye[0]) - 0.5 * eyeDistanceX) &&
+        (nose[0] > max(leftEye[0], rightEye[0]) + 0.5 * eyeDistanceX) &&
             (nose[0] > max(leftMouth[0], rightMouth[0]));
 
     return faceIsUpright && (noseStickingOutLeft || noseStickingOutRight);

+ 3 - 0
mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart

@@ -15,3 +15,6 @@ const kHighQualityFaceScore = 0.90;
 
 /// The minimum score for a face to be detected, regardless of quality. Use [kMinimumQualityFaceScore] for high quality faces.
 const kMinFaceDetectionScore = FaceDetectionService.kMinScoreSigmoidThreshold;
+
+/// The minimum cluster size for displaying a cluster in the UI
+const kMinimumClusterSizeSearchResult = 20;

+ 99 - 99
mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart

@@ -15,6 +15,7 @@ import "package:photos/generated/protos/ente/common/vector.pb.dart";
 import "package:photos/models/file/file.dart";
 import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
 import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
+import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
 import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
 import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
 import "package:photos/services/search_service.dart";
@@ -65,7 +66,7 @@ class ClusterFeedbackService {
     try {
       // Get the suggestions for the person using centroids and median
       final List<(int, double, bool)> suggestClusterIds =
-          await _getSuggestionsUsingMedian(person);
+          await _getSuggestions(person);
 
       // Get the files for the suggestions
       final Map<int, Set<int>> fileIdToClusterID =
@@ -241,7 +242,7 @@ class ClusterFeedbackService {
     watch.log('computed avg for ${clusterAvg.length} clusters');
 
     // Find the actual closest clusters for the person
-    final Map<int, List<(int, double)>> suggestions = _calcSuggestionsMean(
+    final List<(int, double)> suggestions = _calcSuggestionsMean(
       clusterAvg,
       personClusters,
       ignoredClusters,
@@ -257,21 +258,17 @@ class ClusterFeedbackService {
     }
 
     // log suggestions
-    for (final entry in suggestions.entries) {
-      dev.log(
-        ' ${entry.value.length} suggestion for ${p.data.name} for cluster ID ${entry.key} are  suggestions ${entry.value}}',
-        name: "ClusterFeedbackService",
-      );
-    }
+    dev.log(
+      'suggestions for ${p.data.name} for cluster ID ${p.remoteID} are  suggestions $suggestions}',
+      name: "ClusterFeedbackService",
+    );
 
-    for (final suggestionsPerCluster in suggestions.values) {
-      for (final suggestion in suggestionsPerCluster) {
-        final clusterID = suggestion.$1;
-        await PersonService.instance.assignClusterToPerson(
-          personID: p.remoteID,
-          clusterID: clusterID,
-        );
-      }
+    for (final suggestion in suggestions) {
+      final clusterID = suggestion.$1;
+      await PersonService.instance.assignClusterToPerson(
+        personID: p.remoteID,
+        clusterID: clusterID,
+      );
     }
 
     Bus.instance.fire(PeopleChangedEvent());
@@ -433,111 +430,77 @@ class ClusterFeedbackService {
     return;
   }
 
-  /// Returns a map of person's clusterID to map of closest clusterID to with disstance
-  Future<Map<int, List<(int, double)>>> getSuggestionsUsingMean(
-    PersonEntity p, {
-    double maxClusterDistance = 0.4,
-  }) async {
-    // Get all the cluster data
-    final faceMlDb = FaceMLDataDB.instance;
-
-    final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount());
-    final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
-    final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
-    dev.log(
-      'existing clusters for ${p.data.name} are $personClusters',
-      name: "ClusterFeedbackService",
-    );
-
-    // Get and update the cluster summary to get the avg (centroid) and count
-    final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
-    final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
-      allClusterIdsToCountMap,
-      ignoredClusters,
-    );
-    watch.log('computed avg for ${clusterAvg.length} clusters');
-
-    // Find the actual closest clusters for the person
-    final Map<int, List<(int, double)>> suggestions = _calcSuggestionsMean(
-      clusterAvg,
-      personClusters,
-      ignoredClusters,
-      maxClusterDistance,
-    );
-
-    // log suggestions
-    for (final entry in suggestions.entries) {
-      dev.log(
-        ' ${entry.value.length} suggestion for ${p.data.name} for cluster ID ${entry.key} are  suggestions ${entry.value}}',
-        name: "ClusterFeedbackService",
-      );
-    }
-    return suggestions;
-  }
-
   /// Returns a list of suggestions. For each suggestion we return a record consisting of the following elements:
   /// 1. clusterID: the ID of the cluster
   /// 2. distance: the distance between the person's cluster and the suggestion
   /// 3. usedMean: whether the suggestion was found using the mean (true) or the median (false)
-  Future<List<(int, double, bool)>> _getSuggestionsUsingMedian(
+  Future<List<(int, double, bool)>> _getSuggestions(
     PersonEntity p, {
     int sampleSize = 50,
     double maxMedianDistance = 0.65,
     double goodMedianDistance = 0.55,
     double maxMeanDistance = 0.65,
-    double goodMeanDistance = 0.4,
+    double goodMeanDistance = 0.5,
   }) async {
     // Get all the cluster data
+    final startTime = DateTime.now();
     final faceMlDb = FaceMLDataDB.instance;
     // final Map<int, List<(int, double)>> suggestions = {};
-    final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount());
+    final allClusterIdsToCountMap = await faceMlDb.clusterIdToFaceCount();
     final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
     final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
     dev.log(
-      'existing clusters for ${p.data.name} are $personClusters',
+      'existing clusters for ${p.data.name} are $personClusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms',
       name: "getSuggestionsUsingMedian",
     );
 
-    // Get and update the cluster summary to get the avg (centroid) and count
+    // First only do a simple check on the big clusters
     final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
+    final Map<int, List<double>> clusterAvgBigClusters =
+        await _getUpdateClusterAvg(
+      allClusterIdsToCountMap,
+      ignoredClusters,
+      minClusterSize: kMinimumClusterSizeSearchResult,
+    );
+    dev.log(
+      'computed avg for ${clusterAvgBigClusters.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
+    );
+    final List<(int, double)> suggestionsMeanBigClusters = _calcSuggestionsMean(
+      clusterAvgBigClusters,
+      personClusters,
+      ignoredClusters,
+      goodMeanDistance,
+    );
+    if (suggestionsMeanBigClusters.isNotEmpty) {
+      return suggestionsMeanBigClusters
+          .map((e) => (e.$1, e.$2, true))
+          .toList(growable: false);
+    }
+
+    // Get and update the cluster summary to get the avg (centroid) and count
     final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
       allClusterIdsToCountMap,
       ignoredClusters,
     );
-    watch.log('computed avg for ${clusterAvg.length} clusters');
+    dev.log(
+      'computed avg for ${clusterAvg.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
+    );
 
     // Find the other cluster candidates based on the mean
-    final Map<int, List<(int, double)>> suggestionsMean = _calcSuggestionsMean(
+    final List<(int, double)> suggestionsMean = _calcSuggestionsMean(
       clusterAvg,
       personClusters,
       ignoredClusters,
       goodMeanDistance,
     );
     if (suggestionsMean.isNotEmpty) {
-      final List<(int, double)> suggestClusterIds = [];
-      for (final List<(int, double)> suggestion in suggestionsMean.values) {
-        suggestClusterIds.addAll(suggestion);
-      }
-      suggestClusterIds.sort(
-        (a, b) => allClusterIdsToCountMap[b.$1]!
-            .compareTo(allClusterIdsToCountMap[a.$1]!),
-      );
-      final suggestClusterIdsSizes = suggestClusterIds
-          .map((e) => allClusterIdsToCountMap[e.$1]!)
-          .toList(growable: false);
-      final suggestClusterIdsDistances =
-          suggestClusterIds.map((e) => e.$2).toList(growable: false);
-      _logger.info(
-        "Already found good suggestions using mean: $suggestClusterIds, with sizes $suggestClusterIdsSizes and distances $suggestClusterIdsDistances",
-      );
-      return suggestClusterIds
+      return suggestionsMean
           .map((e) => (e.$1, e.$2, true))
           .toList(growable: false);
     }
 
     // Find the other cluster candidates based on the median
-    final Map<int, List<(int, double)>> moreSuggestionsMean =
-        _calcSuggestionsMean(
+    final List<(int, double)> moreSuggestionsMean = _calcSuggestionsMean(
       clusterAvg,
       personClusters,
       ignoredClusters,
@@ -549,12 +512,8 @@ class ClusterFeedbackService {
       return [];
     }
 
-    final List<(int, double)> temp = [];
-    for (final List<(int, double)> suggestion in moreSuggestionsMean.values) {
-      temp.addAll(suggestion);
-    }
-    temp.sort((a, b) => a.$2.compareTo(b.$2));
-    final otherClusterIdsCandidates = temp
+    moreSuggestionsMean.sort((a, b) => a.$2.compareTo(b.$2));
+    final otherClusterIdsCandidates = moreSuggestionsMean
         .map(
           (e) => e.$1,
         )
@@ -655,20 +614,26 @@ class ClusterFeedbackService {
     int maxClusterInCurrentRun = 500,
     int maxEmbeddingToRead = 10000,
   }) async {
+    final startTime = DateTime.now();
     final faceMlDb = FaceMLDataDB.instance;
     _logger.info(
       'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters, minClusterSize $minClusterSize, maxClusterInCurrentRun $maxClusterInCurrentRun',
     );
 
     final Map<int, (Uint8List, int)> clusterToSummary =
-        await faceMlDb.getAllClusterSummary();
+        await faceMlDb.getAllClusterSummary(minClusterSize);
     final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
 
     final Map<int, List<double>> clusterAvg = {};
 
+    dev.log(
+      'getUpdateClusterAvg database call for getAllClusterSummary took ${DateTime.now().difference(startTime).inMilliseconds} ms',
+    );
+
     final allClusterIds = allClusterIdsToCountMap.keys.toSet();
     int ignoredClustersCnt = 0, alreadyUpdatedClustersCnt = 0;
     int smallerClustersCnt = 0;
+    final serializationTime = DateTime.now();
     for (final id in allClusterIdsToCountMap.keys) {
       if (ignoredClusters.contains(id)) {
         allClusterIds.remove(id);
@@ -684,9 +649,20 @@ class ClusterFeedbackService {
         smallerClustersCnt++;
       }
     }
+    dev.log(
+      'serialization of embeddings took ${DateTime.now().difference(serializationTime).inMilliseconds} ms',
+    );
     _logger.info(
       'Ignored $ignoredClustersCnt clusters, already updated $alreadyUpdatedClustersCnt clusters, $smallerClustersCnt clusters are smaller than $minClusterSize',
     );
+
+    if (allClusterIds.isEmpty) {
+      _logger.info(
+        'No clusters to update, getUpdateClusterAvg done in ${DateTime.now().difference(startTime).inMilliseconds} ms',
+      );
+      return clusterAvg;
+    }
+
     // get clusterIDs sorted by count in descending order
     final sortedClusterIDs = allClusterIds.toList();
     sortedClusterIDs.sort(
@@ -760,18 +736,21 @@ class ClusterFeedbackService {
       await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
     }
     w?.logAndReset('done computing avg ');
-    _logger.info('end getUpdateClusterAvg for ${clusterAvg.length} clusters');
+    _logger.info(
+      'end getUpdateClusterAvg for ${clusterAvg.length} clusters, done in ${DateTime.now().difference(startTime).inMilliseconds} ms',
+    );
 
     return clusterAvg;
   }
 
   /// Returns a map of person's clusterID to map of closest clusterID to with disstance
-  Map<int, List<(int, double)>> _calcSuggestionsMean(
+  List<(int, double)> _calcSuggestionsMean(
     Map<int, List<double>> clusterAvg,
     Set<int> personClusters,
     Set<int> ignoredClusters,
-    double maxClusterDistance,
-  ) {
+    double maxClusterDistance, {
+    Map<int, int>? allClusterIdsToCountMap,
+  }) {
     final Map<int, List<(int, double)>> suggestions = {};
     for (final otherClusterID in clusterAvg.keys) {
       // ignore the cluster that belong to the person or is ignored
@@ -802,11 +781,32 @@ class ClusterFeedbackService {
             .add((otherClusterID, minDistance));
       }
     }
-    for (final entry in suggestions.entries) {
-      entry.value.sort((a, b) => a.$1.compareTo(b.$1));
-    }
 
-    return suggestions;
+    if (suggestions.isNotEmpty) {
+      final List<(int, double)> suggestClusterIds = [];
+      for (final List<(int, double)> suggestion in suggestions.values) {
+        suggestClusterIds.addAll(suggestion);
+      }
+      List<int>? suggestClusterIdsSizes;
+      if (allClusterIdsToCountMap != null) {
+        suggestClusterIds.sort(
+          (a, b) => allClusterIdsToCountMap[b.$1]!
+              .compareTo(allClusterIdsToCountMap[a.$1]!),
+        );
+        suggestClusterIdsSizes = suggestClusterIds
+            .map((e) => allClusterIdsToCountMap[e.$1]!)
+            .toList(growable: false);
+      }
+      final suggestClusterIdsDistances =
+          suggestClusterIds.map((e) => e.$2).toList(growable: false);
+      _logger.info(
+        "Already found good suggestions using mean: $suggestClusterIds, ${suggestClusterIdsSizes != null ? 'with sizes $suggestClusterIdsSizes' : ''} and distances $suggestClusterIdsDistances",
+      );
+      return suggestClusterIds;
+    } else {
+      _logger.info("No suggestions found using mean");
+      return <(int, double)>[];
+    }
   }
 
   List<T> _randomSampleWithoutReplacement<T>(

+ 2 - 1
mobile/lib/services/search_service.dart

@@ -28,6 +28,7 @@ import "package:photos/models/search/search_constants.dart";
 import "package:photos/models/search/search_types.dart";
 import 'package:photos/services/collections_service.dart';
 import "package:photos/services/location_service.dart";
+import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
 import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
 import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart';
 import "package:photos/states/location_screen_state.dart";
@@ -824,7 +825,7 @@ class SearchService {
             "Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}",
           );
         }
-        if (files.length < 20 && sortedClusterIds.length > 3) {
+        if (files.length < kMinimumClusterSizeSearchResult && sortedClusterIds.length > 3) {
           continue;
         }
         facesResult.add(

+ 28 - 29
mobile/lib/ui/settings/debug/face_debug_section_widget.dart

@@ -8,7 +8,6 @@ import "package:photos/events/people_changed_event.dart";
 import "package:photos/face/db.dart";
 import "package:photos/face/model/person.dart";
 import 'package:photos/services/machine_learning/face_ml/face_ml_service.dart';
-import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
 import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
 import 'package:photos/theme/ente_theme.dart';
 import 'package:photos/ui/components/captioned_text_widget.dart';
@@ -284,34 +283,34 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
             );
           },
         ),
-        sectionOptionSpacing,
-        MenuItemWidget(
-          captionedTextWidget: const CaptionedTextWidget(
-            title: "Rank blurs",
-          ),
-          pressedColor: getEnteColorScheme(context).fillFaint,
-          trailingIcon: Icons.chevron_right_outlined,
-          trailingIconIsMuted: true,
-          onTap: () async {
-            await showChoiceDialog(
-              context,
-              title: "Are you sure?",
-              body:
-                  "This will delete all clusters and put blurry faces in separate clusters per ten points.",
-              firstButtonLabel: "Yes, confirm",
-              firstButtonOnTap: () async {
-                try {
-                  await ClusterFeedbackService.instance
-                      .createFakeClustersByBlurValue();
-                  showShortToast(context, "Done");
-                } catch (e, s) {
-                  _logger.warning('Failed to rank faces on blur values ', e, s);
-                  await showGenericErrorDialog(context: context, error: e);
-                }
-              },
-            );
-          },
-        ),
+        // sectionOptionSpacing,
+        // MenuItemWidget(
+        //   captionedTextWidget: const CaptionedTextWidget(
+        //     title: "Rank blurs",
+        //   ),
+        //   pressedColor: getEnteColorScheme(context).fillFaint,
+        //   trailingIcon: Icons.chevron_right_outlined,
+        //   trailingIconIsMuted: true,
+        //   onTap: () async {
+        //     await showChoiceDialog(
+        //       context,
+        //       title: "Are you sure?",
+        //       body:
+        //           "This will delete all clusters and put blurry faces in separate clusters per ten points.",
+        //       firstButtonLabel: "Yes, confirm",
+        //       firstButtonOnTap: () async {
+        //         try {
+        //           await ClusterFeedbackService.instance
+        //               .createFakeClustersByBlurValue();
+        //           showShortToast(context, "Done");
+        //         } catch (e, s) {
+        //           _logger.warning('Failed to rank faces on blur values ', e, s);
+        //           await showGenericErrorDialog(context: context, error: e);
+        //         }
+        //       },
+        //     );
+        //   },
+        // ),
         sectionOptionSpacing,
         MenuItemWidget(
           captionedTextWidget: const CaptionedTextWidget(

+ 3 - 2
mobile/lib/ui/viewer/file_details/face_widget.dart

@@ -1,5 +1,4 @@
 import "dart:developer" show log;
-import "dart:io" show Platform;
 import "dart:typed_data";
 
 import "package:flutter/cupertino.dart";
@@ -21,6 +20,8 @@ import "package:photos/utils/face/face_box_crop.dart";
 import "package:photos/utils/thumbnail_util.dart";
 // import "package:photos/utils/toast_util.dart";
 
+const useGeneratedFaceCrops = false;
+
 class FaceWidget extends StatefulWidget {
   final EnteFile file;
   final Face face;
@@ -48,7 +49,7 @@ class _FaceWidgetState extends State<FaceWidget> {
 
   @override
   Widget build(BuildContext context) {
-    if (Platform.isIOS) {
+    if (useGeneratedFaceCrops) {
       return FutureBuilder<Uint8List?>(
         future: getFaceCrop(),
         builder: (context, snapshot) {

+ 3 - 2
mobile/lib/ui/viewer/search/result/person_face_widget.dart

@@ -1,5 +1,5 @@
 import "dart:developer";
-import "dart:io";
+// import "dart:io";
 import "dart:typed_data";
 
 import 'package:flutter/widgets.dart';
@@ -10,6 +10,7 @@ import "package:photos/face/model/person.dart";
 import 'package:photos/models/file/file.dart';
 import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
 import 'package:photos/ui/viewer/file/thumbnail_widget.dart';
+import "package:photos/ui/viewer/file_details/face_widget.dart";
 import "package:photos/ui/viewer/people/cropped_face_image_view.dart";
 import "package:photos/utils/face/face_box_crop.dart";
 import "package:photos/utils/thumbnail_util.dart";
@@ -34,7 +35,7 @@ class PersonFaceWidget extends StatelessWidget {
 
   @override
   Widget build(BuildContext context) {
-    if (Platform.isIOS || Platform.isAndroid) {
+    if (useGeneratedFaceCrops) {
       return FutureBuilder<Uint8List?>(
         future: getFaceCrop(),
         builder: (context, snapshot) {