From fa466d715f7862ed39bf3352217cb82b4e3ff97a Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Mon, 22 Apr 2024 16:40:31 +0530 Subject: [PATCH] [mob][photos] Improve suggestions by improving speed and preferring big clusters --- .../face_filtering_constants.dart | 3 + .../face_ml/feedback/cluster_feedback.dart | 198 +++++++++--------- mobile/lib/services/search_service.dart | 3 +- 3 files changed, 104 insertions(+), 100 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart b/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart index b1f2f6018..0feb275a7 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart @@ -15,3 +15,6 @@ const kHighQualityFaceScore = 0.90; /// The minimum score for a face to be detected, regardless of quality. Use [kMinimumQualityFaceScore] for high quality faces. const kMinFaceDetectionScore = FaceDetectionService.kMinScoreSigmoidThreshold; + +/// The minimum cluster size for displaying a cluster in the UI +const kMinimumClusterSizeSearchResult = 20; diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 95497a90d..e30dc375f 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -15,6 +15,7 @@ import "package:photos/generated/protos/ente/common/vector.pb.dart"; import "package:photos/models/file/file.dart"; import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart"; +import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/services/search_service.dart"; @@ -65,7 +66,7 @@ class ClusterFeedbackService { try { // Get the suggestions for the person using centroids and median final List<(int, double, bool)> suggestClusterIds = - await _getSuggestionsUsingMedian(person); + await _getSuggestions(person); // Get the files for the suggestions final Map> fileIdToClusterID = @@ -241,7 +242,7 @@ class ClusterFeedbackService { watch.log('computed avg for ${clusterAvg.length} clusters'); // Find the actual closest clusters for the person - final Map> suggestions = _calcSuggestionsMean( + final List<(int, double)> suggestions = _calcSuggestionsMean( clusterAvg, personClusters, ignoredClusters, @@ -257,21 +258,17 @@ class ClusterFeedbackService { } // log suggestions - for (final entry in suggestions.entries) { - dev.log( - ' ${entry.value.length} suggestion for ${p.data.name} for cluster ID ${entry.key} are suggestions ${entry.value}}', - name: "ClusterFeedbackService", - ); - } + dev.log( + 'suggestions for ${p.data.name} for cluster ID ${p.remoteID} are suggestions $suggestions}', + name: "ClusterFeedbackService", + ); - for (final suggestionsPerCluster in suggestions.values) { - for (final suggestion in suggestionsPerCluster) { - final clusterID = suggestion.$1; - await PersonService.instance.assignClusterToPerson( - personID: p.remoteID, - clusterID: clusterID, - ); - } + for (final suggestion in suggestions) { + final clusterID = suggestion.$1; + await PersonService.instance.assignClusterToPerson( + personID: p.remoteID, + clusterID: clusterID, + ); } Bus.instance.fire(PeopleChangedEvent()); @@ -433,111 +430,77 @@ class ClusterFeedbackService { return; } - /// Returns a map of person's clusterID to map of closest clusterID to with disstance - Future>> getSuggestionsUsingMean( - PersonEntity p, { - double maxClusterDistance = 0.4, - }) async { - // Get all the cluster data - final faceMlDb = FaceMLDataDB.instance; - - final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount()); - final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID); - final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID); - dev.log( - 'existing clusters for ${p.data.name} are $personClusters', - name: "ClusterFeedbackService", - ); - - // Get and update the cluster summary to get the avg (centroid) and count - final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start(); - final Map> clusterAvg = await _getUpdateClusterAvg( - allClusterIdsToCountMap, - ignoredClusters, - ); - watch.log('computed avg for ${clusterAvg.length} clusters'); - - // Find the actual closest clusters for the person - final Map> suggestions = _calcSuggestionsMean( - clusterAvg, - personClusters, - ignoredClusters, - maxClusterDistance, - ); - - // log suggestions - for (final entry in suggestions.entries) { - dev.log( - ' ${entry.value.length} suggestion for ${p.data.name} for cluster ID ${entry.key} are suggestions ${entry.value}}', - name: "ClusterFeedbackService", - ); - } - return suggestions; - } - /// Returns a list of suggestions. For each suggestion we return a record consisting of the following elements: /// 1. clusterID: the ID of the cluster /// 2. distance: the distance between the person's cluster and the suggestion /// 3. usedMean: whether the suggestion was found using the mean (true) or the median (false) - Future> _getSuggestionsUsingMedian( + Future> _getSuggestions( PersonEntity p, { int sampleSize = 50, double maxMedianDistance = 0.65, double goodMedianDistance = 0.55, double maxMeanDistance = 0.65, - double goodMeanDistance = 0.4, + double goodMeanDistance = 0.5, }) async { // Get all the cluster data + final startTime = DateTime.now(); final faceMlDb = FaceMLDataDB.instance; // final Map> suggestions = {}; - final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount()); + final allClusterIdsToCountMap = await faceMlDb.clusterIdToFaceCount(); final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID); final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID); dev.log( - 'existing clusters for ${p.data.name} are $personClusters', + 'existing clusters for ${p.data.name} are $personClusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms', name: "getSuggestionsUsingMedian", ); - // Get and update the cluster summary to get the avg (centroid) and count + // First only do a simple check on the big clusters final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start(); + final Map> clusterAvgBigClusters = + await _getUpdateClusterAvg( + allClusterIdsToCountMap, + ignoredClusters, + minClusterSize: kMinimumClusterSizeSearchResult, + ); + dev.log( + 'computed avg for ${clusterAvgBigClusters.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms', + ); + final List<(int, double)> suggestionsMeanBigClusters = _calcSuggestionsMean( + clusterAvgBigClusters, + personClusters, + ignoredClusters, + goodMeanDistance, + ); + if (suggestionsMeanBigClusters.isNotEmpty) { + return suggestionsMeanBigClusters + .map((e) => (e.$1, e.$2, true)) + .toList(growable: false); + } + + // Get and update the cluster summary to get the avg (centroid) and count final Map> clusterAvg = await _getUpdateClusterAvg( allClusterIdsToCountMap, ignoredClusters, ); - watch.log('computed avg for ${clusterAvg.length} clusters'); + dev.log( + 'computed avg for ${clusterAvg.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms', + ); // Find the other cluster candidates based on the mean - final Map> suggestionsMean = _calcSuggestionsMean( + final List<(int, double)> suggestionsMean = _calcSuggestionsMean( clusterAvg, personClusters, ignoredClusters, goodMeanDistance, ); if (suggestionsMean.isNotEmpty) { - final List<(int, double)> suggestClusterIds = []; - for (final List<(int, double)> suggestion in suggestionsMean.values) { - suggestClusterIds.addAll(suggestion); - } - suggestClusterIds.sort( - (a, b) => allClusterIdsToCountMap[b.$1]! - .compareTo(allClusterIdsToCountMap[a.$1]!), - ); - final suggestClusterIdsSizes = suggestClusterIds - .map((e) => allClusterIdsToCountMap[e.$1]!) - .toList(growable: false); - final suggestClusterIdsDistances = - suggestClusterIds.map((e) => e.$2).toList(growable: false); - _logger.info( - "Already found good suggestions using mean: $suggestClusterIds, with sizes $suggestClusterIdsSizes and distances $suggestClusterIdsDistances", - ); - return suggestClusterIds + return suggestionsMean .map((e) => (e.$1, e.$2, true)) .toList(growable: false); } // Find the other cluster candidates based on the median - final Map> moreSuggestionsMean = - _calcSuggestionsMean( + final List<(int, double)> moreSuggestionsMean = _calcSuggestionsMean( clusterAvg, personClusters, ignoredClusters, @@ -549,12 +512,8 @@ class ClusterFeedbackService { return []; } - final List<(int, double)> temp = []; - for (final List<(int, double)> suggestion in moreSuggestionsMean.values) { - temp.addAll(suggestion); - } - temp.sort((a, b) => a.$2.compareTo(b.$2)); - final otherClusterIdsCandidates = temp + moreSuggestionsMean.sort((a, b) => a.$2.compareTo(b.$2)); + final otherClusterIdsCandidates = moreSuggestionsMean .map( (e) => e.$1, ) @@ -655,20 +614,26 @@ class ClusterFeedbackService { int maxClusterInCurrentRun = 500, int maxEmbeddingToRead = 10000, }) async { + final startTime = DateTime.now(); final faceMlDb = FaceMLDataDB.instance; _logger.info( 'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters, minClusterSize $minClusterSize, maxClusterInCurrentRun $maxClusterInCurrentRun', ); final Map clusterToSummary = - await faceMlDb.getAllClusterSummary(); + await faceMlDb.getAllClusterSummary(minClusterSize); final Map updatesForClusterSummary = {}; final Map> clusterAvg = {}; + dev.log( + 'getUpdateClusterAvg database call for getAllClusterSummary took ${DateTime.now().difference(startTime).inMilliseconds} ms', + ); + final allClusterIds = allClusterIdsToCountMap.keys.toSet(); int ignoredClustersCnt = 0, alreadyUpdatedClustersCnt = 0; int smallerClustersCnt = 0; + final serializationTime = DateTime.now(); for (final id in allClusterIdsToCountMap.keys) { if (ignoredClusters.contains(id)) { allClusterIds.remove(id); @@ -684,9 +649,20 @@ class ClusterFeedbackService { smallerClustersCnt++; } } + dev.log( + 'serialization of embeddings took ${DateTime.now().difference(serializationTime).inMilliseconds} ms', + ); _logger.info( 'Ignored $ignoredClustersCnt clusters, already updated $alreadyUpdatedClustersCnt clusters, $smallerClustersCnt clusters are smaller than $minClusterSize', ); + + if (allClusterIds.isEmpty) { + _logger.info( + 'No clusters to update, getUpdateClusterAvg done in ${DateTime.now().difference(startTime).inMilliseconds} ms', + ); + return clusterAvg; + } + // get clusterIDs sorted by count in descending order final sortedClusterIDs = allClusterIds.toList(); sortedClusterIDs.sort( @@ -760,18 +736,21 @@ class ClusterFeedbackService { await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary); } w?.logAndReset('done computing avg '); - _logger.info('end getUpdateClusterAvg for ${clusterAvg.length} clusters'); + _logger.info( + 'end getUpdateClusterAvg for ${clusterAvg.length} clusters, done in ${DateTime.now().difference(startTime).inMilliseconds} ms', + ); return clusterAvg; } /// Returns a map of person's clusterID to map of closest clusterID to with disstance - Map> _calcSuggestionsMean( + List<(int, double)> _calcSuggestionsMean( Map> clusterAvg, Set personClusters, Set ignoredClusters, - double maxClusterDistance, - ) { + double maxClusterDistance, { + Map? allClusterIdsToCountMap, + }) { final Map> suggestions = {}; for (final otherClusterID in clusterAvg.keys) { // ignore the cluster that belong to the person or is ignored @@ -802,11 +781,32 @@ class ClusterFeedbackService { .add((otherClusterID, minDistance)); } } - for (final entry in suggestions.entries) { - entry.value.sort((a, b) => a.$1.compareTo(b.$1)); - } - return suggestions; + if (suggestions.isNotEmpty) { + final List<(int, double)> suggestClusterIds = []; + for (final List<(int, double)> suggestion in suggestions.values) { + suggestClusterIds.addAll(suggestion); + } + List? suggestClusterIdsSizes; + if (allClusterIdsToCountMap != null) { + suggestClusterIds.sort( + (a, b) => allClusterIdsToCountMap[b.$1]! + .compareTo(allClusterIdsToCountMap[a.$1]!), + ); + suggestClusterIdsSizes = suggestClusterIds + .map((e) => allClusterIdsToCountMap[e.$1]!) + .toList(growable: false); + } + final suggestClusterIdsDistances = + suggestClusterIds.map((e) => e.$2).toList(growable: false); + _logger.info( + "Already found good suggestions using mean: $suggestClusterIds, ${suggestClusterIdsSizes != null ? 'with sizes $suggestClusterIdsSizes' : ''} and distances $suggestClusterIdsDistances", + ); + return suggestClusterIds; + } else { + _logger.info("No suggestions found using mean"); + return <(int, double)>[]; + } } List _randomSampleWithoutReplacement( diff --git a/mobile/lib/services/search_service.dart b/mobile/lib/services/search_service.dart index 3f54187c1..5f04e5ba4 100644 --- a/mobile/lib/services/search_service.dart +++ b/mobile/lib/services/search_service.dart @@ -28,6 +28,7 @@ import "package:photos/models/search/search_constants.dart"; import "package:photos/models/search/search_types.dart"; import 'package:photos/services/collections_service.dart'; import "package:photos/services/location_service.dart"; +import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart'; import "package:photos/states/location_screen_state.dart"; @@ -824,7 +825,7 @@ class SearchService { "Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}", ); } - if (files.length < 20 && sortedClusterIds.length > 3) { + if (files.length < kMinimumClusterSizeSearchResult && sortedClusterIds.length > 3) { continue; } facesResult.add(