[mob][photos] Improve suggestions by improving speed and preferring big clusters

This commit is contained in:
laurenspriem 2024-04-22 16:40:31 +05:30
parent 37ab467da5
commit fa466d715f
3 changed files with 104 additions and 100 deletions

View file

@ -15,3 +15,6 @@ const kHighQualityFaceScore = 0.90;
/// The minimum score for a face to be detected, regardless of quality. Use [kMinimumQualityFaceScore] for high quality faces. /// The minimum score for a face to be detected, regardless of quality. Use [kMinimumQualityFaceScore] for high quality faces.
const kMinFaceDetectionScore = FaceDetectionService.kMinScoreSigmoidThreshold; const kMinFaceDetectionScore = FaceDetectionService.kMinScoreSigmoidThreshold;
/// The minimum cluster size for displaying a cluster in the UI
const kMinimumClusterSizeSearchResult = 20;

View file

@ -15,6 +15,7 @@ import "package:photos/generated/protos/ente/common/vector.pb.dart";
import "package:photos/models/file/file.dart"; import "package:photos/models/file/file.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart"; import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
import "package:photos/services/search_service.dart"; import "package:photos/services/search_service.dart";
@ -65,7 +66,7 @@ class ClusterFeedbackService {
try { try {
// Get the suggestions for the person using centroids and median // Get the suggestions for the person using centroids and median
final List<(int, double, bool)> suggestClusterIds = final List<(int, double, bool)> suggestClusterIds =
await _getSuggestionsUsingMedian(person); await _getSuggestions(person);
// Get the files for the suggestions // Get the files for the suggestions
final Map<int, Set<int>> fileIdToClusterID = final Map<int, Set<int>> fileIdToClusterID =
@ -241,7 +242,7 @@ class ClusterFeedbackService {
watch.log('computed avg for ${clusterAvg.length} clusters'); watch.log('computed avg for ${clusterAvg.length} clusters');
// Find the actual closest clusters for the person // Find the actual closest clusters for the person
final Map<int, List<(int, double)>> suggestions = _calcSuggestionsMean( final List<(int, double)> suggestions = _calcSuggestionsMean(
clusterAvg, clusterAvg,
personClusters, personClusters,
ignoredClusters, ignoredClusters,
@ -257,21 +258,17 @@ class ClusterFeedbackService {
} }
// log suggestions // log suggestions
for (final entry in suggestions.entries) { dev.log(
dev.log( 'suggestions for ${p.data.name} for cluster ID ${p.remoteID} are suggestions $suggestions}',
' ${entry.value.length} suggestion for ${p.data.name} for cluster ID ${entry.key} are suggestions ${entry.value}}', name: "ClusterFeedbackService",
name: "ClusterFeedbackService", );
);
}
for (final suggestionsPerCluster in suggestions.values) { for (final suggestion in suggestions) {
for (final suggestion in suggestionsPerCluster) { final clusterID = suggestion.$1;
final clusterID = suggestion.$1; await PersonService.instance.assignClusterToPerson(
await PersonService.instance.assignClusterToPerson( personID: p.remoteID,
personID: p.remoteID, clusterID: clusterID,
clusterID: clusterID, );
);
}
} }
Bus.instance.fire(PeopleChangedEvent()); Bus.instance.fire(PeopleChangedEvent());
@ -433,111 +430,77 @@ class ClusterFeedbackService {
return; return;
} }
/// Returns a map from each of the person's clusterIDs to a list of the closest clusterIDs with their distance
Future<Map<int, List<(int, double)>>> getSuggestionsUsingMean(
PersonEntity p, {
double maxClusterDistance = 0.4,
}) async {
// Get all the cluster data
final faceMlDb = FaceMLDataDB.instance;
final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount());
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
dev.log(
'existing clusters for ${p.data.name} are $personClusters',
name: "ClusterFeedbackService",
);
// Get and update the cluster summary to get the avg (centroid) and count
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
allClusterIdsToCountMap,
ignoredClusters,
);
watch.log('computed avg for ${clusterAvg.length} clusters');
// Find the actual closest clusters for the person
final Map<int, List<(int, double)>> suggestions = _calcSuggestionsMean(
clusterAvg,
personClusters,
ignoredClusters,
maxClusterDistance,
);
// log suggestions
for (final entry in suggestions.entries) {
dev.log(
' ${entry.value.length} suggestion for ${p.data.name} for cluster ID ${entry.key} are suggestions ${entry.value}}',
name: "ClusterFeedbackService",
);
}
return suggestions;
}
/// Returns a list of suggestions. For each suggestion we return a record consisting of the following elements: /// Returns a list of suggestions. For each suggestion we return a record consisting of the following elements:
/// 1. clusterID: the ID of the cluster /// 1. clusterID: the ID of the cluster
/// 2. distance: the distance between the person's cluster and the suggestion /// 2. distance: the distance between the person's cluster and the suggestion
/// 3. usedMean: whether the suggestion was found using the mean (true) or the median (false) /// 3. usedMean: whether the suggestion was found using the mean (true) or the median (false)
Future<List<(int, double, bool)>> _getSuggestionsUsingMedian( Future<List<(int, double, bool)>> _getSuggestions(
PersonEntity p, { PersonEntity p, {
int sampleSize = 50, int sampleSize = 50,
double maxMedianDistance = 0.65, double maxMedianDistance = 0.65,
double goodMedianDistance = 0.55, double goodMedianDistance = 0.55,
double maxMeanDistance = 0.65, double maxMeanDistance = 0.65,
double goodMeanDistance = 0.4, double goodMeanDistance = 0.5,
}) async { }) async {
// Get all the cluster data // Get all the cluster data
final startTime = DateTime.now();
final faceMlDb = FaceMLDataDB.instance; final faceMlDb = FaceMLDataDB.instance;
// final Map<int, List<(int, double)>> suggestions = {}; // final Map<int, List<(int, double)>> suggestions = {};
final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount()); final allClusterIdsToCountMap = await faceMlDb.clusterIdToFaceCount();
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID); final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID); final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
dev.log( dev.log(
'existing clusters for ${p.data.name} are $personClusters', 'existing clusters for ${p.data.name} are $personClusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms',
name: "getSuggestionsUsingMedian", name: "getSuggestionsUsingMedian",
); );
// Get and update the cluster summary to get the avg (centroid) and count // First only do a simple check on the big clusters
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start(); final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
final Map<int, List<double>> clusterAvgBigClusters =
await _getUpdateClusterAvg(
allClusterIdsToCountMap,
ignoredClusters,
minClusterSize: kMinimumClusterSizeSearchResult,
);
dev.log(
'computed avg for ${clusterAvgBigClusters.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
);
final List<(int, double)> suggestionsMeanBigClusters = _calcSuggestionsMean(
clusterAvgBigClusters,
personClusters,
ignoredClusters,
goodMeanDistance,
);
if (suggestionsMeanBigClusters.isNotEmpty) {
return suggestionsMeanBigClusters
.map((e) => (e.$1, e.$2, true))
.toList(growable: false);
}
// Get and update the cluster summary to get the avg (centroid) and count
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg( final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
allClusterIdsToCountMap, allClusterIdsToCountMap,
ignoredClusters, ignoredClusters,
); );
watch.log('computed avg for ${clusterAvg.length} clusters'); dev.log(
'computed avg for ${clusterAvg.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
);
// Find the other cluster candidates based on the mean // Find the other cluster candidates based on the mean
final Map<int, List<(int, double)>> suggestionsMean = _calcSuggestionsMean( final List<(int, double)> suggestionsMean = _calcSuggestionsMean(
clusterAvg, clusterAvg,
personClusters, personClusters,
ignoredClusters, ignoredClusters,
goodMeanDistance, goodMeanDistance,
); );
if (suggestionsMean.isNotEmpty) { if (suggestionsMean.isNotEmpty) {
final List<(int, double)> suggestClusterIds = []; return suggestionsMean
for (final List<(int, double)> suggestion in suggestionsMean.values) {
suggestClusterIds.addAll(suggestion);
}
suggestClusterIds.sort(
(a, b) => allClusterIdsToCountMap[b.$1]!
.compareTo(allClusterIdsToCountMap[a.$1]!),
);
final suggestClusterIdsSizes = suggestClusterIds
.map((e) => allClusterIdsToCountMap[e.$1]!)
.toList(growable: false);
final suggestClusterIdsDistances =
suggestClusterIds.map((e) => e.$2).toList(growable: false);
_logger.info(
"Already found good suggestions using mean: $suggestClusterIds, with sizes $suggestClusterIdsSizes and distances $suggestClusterIdsDistances",
);
return suggestClusterIds
.map((e) => (e.$1, e.$2, true)) .map((e) => (e.$1, e.$2, true))
.toList(growable: false); .toList(growable: false);
} }
// Find the other cluster candidates based on the median // Find the other cluster candidates based on the median
final Map<int, List<(int, double)>> moreSuggestionsMean = final List<(int, double)> moreSuggestionsMean = _calcSuggestionsMean(
_calcSuggestionsMean(
clusterAvg, clusterAvg,
personClusters, personClusters,
ignoredClusters, ignoredClusters,
@ -549,12 +512,8 @@ class ClusterFeedbackService {
return []; return [];
} }
final List<(int, double)> temp = []; moreSuggestionsMean.sort((a, b) => a.$2.compareTo(b.$2));
for (final List<(int, double)> suggestion in moreSuggestionsMean.values) { final otherClusterIdsCandidates = moreSuggestionsMean
temp.addAll(suggestion);
}
temp.sort((a, b) => a.$2.compareTo(b.$2));
final otherClusterIdsCandidates = temp
.map( .map(
(e) => e.$1, (e) => e.$1,
) )
@ -655,20 +614,26 @@ class ClusterFeedbackService {
int maxClusterInCurrentRun = 500, int maxClusterInCurrentRun = 500,
int maxEmbeddingToRead = 10000, int maxEmbeddingToRead = 10000,
}) async { }) async {
final startTime = DateTime.now();
final faceMlDb = FaceMLDataDB.instance; final faceMlDb = FaceMLDataDB.instance;
_logger.info( _logger.info(
'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters, minClusterSize $minClusterSize, maxClusterInCurrentRun $maxClusterInCurrentRun', 'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters, minClusterSize $minClusterSize, maxClusterInCurrentRun $maxClusterInCurrentRun',
); );
final Map<int, (Uint8List, int)> clusterToSummary = final Map<int, (Uint8List, int)> clusterToSummary =
await faceMlDb.getAllClusterSummary(); await faceMlDb.getAllClusterSummary(minClusterSize);
final Map<int, (Uint8List, int)> updatesForClusterSummary = {}; final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
final Map<int, List<double>> clusterAvg = {}; final Map<int, List<double>> clusterAvg = {};
dev.log(
'getUpdateClusterAvg database call for getAllClusterSummary took ${DateTime.now().difference(startTime).inMilliseconds} ms',
);
final allClusterIds = allClusterIdsToCountMap.keys.toSet(); final allClusterIds = allClusterIdsToCountMap.keys.toSet();
int ignoredClustersCnt = 0, alreadyUpdatedClustersCnt = 0; int ignoredClustersCnt = 0, alreadyUpdatedClustersCnt = 0;
int smallerClustersCnt = 0; int smallerClustersCnt = 0;
final serializationTime = DateTime.now();
for (final id in allClusterIdsToCountMap.keys) { for (final id in allClusterIdsToCountMap.keys) {
if (ignoredClusters.contains(id)) { if (ignoredClusters.contains(id)) {
allClusterIds.remove(id); allClusterIds.remove(id);
@ -684,9 +649,20 @@ class ClusterFeedbackService {
smallerClustersCnt++; smallerClustersCnt++;
} }
} }
dev.log(
'serialization of embeddings took ${DateTime.now().difference(serializationTime).inMilliseconds} ms',
);
_logger.info( _logger.info(
'Ignored $ignoredClustersCnt clusters, already updated $alreadyUpdatedClustersCnt clusters, $smallerClustersCnt clusters are smaller than $minClusterSize', 'Ignored $ignoredClustersCnt clusters, already updated $alreadyUpdatedClustersCnt clusters, $smallerClustersCnt clusters are smaller than $minClusterSize',
); );
if (allClusterIds.isEmpty) {
_logger.info(
'No clusters to update, getUpdateClusterAvg done in ${DateTime.now().difference(startTime).inMilliseconds} ms',
);
return clusterAvg;
}
// get clusterIDs sorted by count in descending order // get clusterIDs sorted by count in descending order
final sortedClusterIDs = allClusterIds.toList(); final sortedClusterIDs = allClusterIds.toList();
sortedClusterIDs.sort( sortedClusterIDs.sort(
@ -760,18 +736,21 @@ class ClusterFeedbackService {
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary); await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
} }
w?.logAndReset('done computing avg '); w?.logAndReset('done computing avg ');
_logger.info('end getUpdateClusterAvg for ${clusterAvg.length} clusters'); _logger.info(
'end getUpdateClusterAvg for ${clusterAvg.length} clusters, done in ${DateTime.now().difference(startTime).inMilliseconds} ms',
);
return clusterAvg; return clusterAvg;
} }
/// Returns a map of person's clusterID to map of closest clusterID with distance /// Returns a map of person's clusterID to map of closest clusterID with distance
Map<int, List<(int, double)>> _calcSuggestionsMean( List<(int, double)> _calcSuggestionsMean(
Map<int, List<double>> clusterAvg, Map<int, List<double>> clusterAvg,
Set<int> personClusters, Set<int> personClusters,
Set<int> ignoredClusters, Set<int> ignoredClusters,
double maxClusterDistance, double maxClusterDistance, {
) { Map<int, int>? allClusterIdsToCountMap,
}) {
final Map<int, List<(int, double)>> suggestions = {}; final Map<int, List<(int, double)>> suggestions = {};
for (final otherClusterID in clusterAvg.keys) { for (final otherClusterID in clusterAvg.keys) {
// ignore the cluster that belong to the person or is ignored // ignore the cluster that belong to the person or is ignored
@ -802,11 +781,32 @@ class ClusterFeedbackService {
.add((otherClusterID, minDistance)); .add((otherClusterID, minDistance));
} }
} }
for (final entry in suggestions.entries) {
entry.value.sort((a, b) => a.$1.compareTo(b.$1));
}
return suggestions; if (suggestions.isNotEmpty) {
final List<(int, double)> suggestClusterIds = [];
for (final List<(int, double)> suggestion in suggestions.values) {
suggestClusterIds.addAll(suggestion);
}
List<int>? suggestClusterIdsSizes;
if (allClusterIdsToCountMap != null) {
suggestClusterIds.sort(
(a, b) => allClusterIdsToCountMap[b.$1]!
.compareTo(allClusterIdsToCountMap[a.$1]!),
);
suggestClusterIdsSizes = suggestClusterIds
.map((e) => allClusterIdsToCountMap[e.$1]!)
.toList(growable: false);
}
final suggestClusterIdsDistances =
suggestClusterIds.map((e) => e.$2).toList(growable: false);
_logger.info(
"Already found good suggestions using mean: $suggestClusterIds, ${suggestClusterIdsSizes != null ? 'with sizes $suggestClusterIdsSizes' : ''} and distances $suggestClusterIdsDistances",
);
return suggestClusterIds;
} else {
_logger.info("No suggestions found using mean");
return <(int, double)>[];
}
} }
List<T> _randomSampleWithoutReplacement<T>( List<T> _randomSampleWithoutReplacement<T>(

View file

@ -28,6 +28,7 @@ import "package:photos/models/search/search_constants.dart";
import "package:photos/models/search/search_types.dart"; import "package:photos/models/search/search_types.dart";
import 'package:photos/services/collections_service.dart'; import 'package:photos/services/collections_service.dart';
import "package:photos/services/location_service.dart"; import "package:photos/services/location_service.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart'; import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart';
import "package:photos/states/location_screen_state.dart"; import "package:photos/states/location_screen_state.dart";
@ -824,7 +825,7 @@ class SearchService {
"Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}", "Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}",
); );
} }
if (files.length < 20 && sortedClusterIds.length > 3) { if (files.length < kMinimumClusterSizeSearchResult && sortedClusterIds.length > 3) {
continue; continue;
} }
facesResult.add( facesResult.add(