[mob][photos] Improve suggestions by improving speed and preferring big clusters

This commit is contained in:
laurenspriem 2024-04-22 16:40:31 +05:30
parent 37ab467da5
commit fa466d715f
3 changed files with 104 additions and 100 deletions

View file

@ -15,3 +15,6 @@ const kHighQualityFaceScore = 0.90;
/// The minimum score for a face to be detected, regardless of quality. Use [kMinimumQualityFaceScore] for high quality faces.
const kMinFaceDetectionScore = FaceDetectionService.kMinScoreSigmoidThreshold;
/// The minimum cluster size for displaying a cluster in the UI
const kMinimumClusterSizeSearchResult = 20;

View file

@ -15,6 +15,7 @@ import "package:photos/generated/protos/ente/common/vector.pb.dart";
import "package:photos/models/file/file.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
import "package:photos/services/search_service.dart";
@ -65,7 +66,7 @@ class ClusterFeedbackService {
try {
// Get the suggestions for the person using centroids and median
final List<(int, double, bool)> suggestClusterIds =
await _getSuggestionsUsingMedian(person);
await _getSuggestions(person);
// Get the files for the suggestions
final Map<int, Set<int>> fileIdToClusterID =
@ -241,7 +242,7 @@ class ClusterFeedbackService {
watch.log('computed avg for ${clusterAvg.length} clusters');
// Find the actual closest clusters for the person
final Map<int, List<(int, double)>> suggestions = _calcSuggestionsMean(
final List<(int, double)> suggestions = _calcSuggestionsMean(
clusterAvg,
personClusters,
ignoredClusters,
@ -257,21 +258,17 @@ class ClusterFeedbackService {
}
// log suggestions
for (final entry in suggestions.entries) {
dev.log(
' ${entry.value.length} suggestion for ${p.data.name} for cluster ID ${entry.key} are suggestions ${entry.value}}',
name: "ClusterFeedbackService",
);
}
dev.log(
'suggestions for ${p.data.name} for cluster ID ${p.remoteID} are suggestions $suggestions}',
name: "ClusterFeedbackService",
);
for (final suggestionsPerCluster in suggestions.values) {
for (final suggestion in suggestionsPerCluster) {
final clusterID = suggestion.$1;
await PersonService.instance.assignClusterToPerson(
personID: p.remoteID,
clusterID: clusterID,
);
}
for (final suggestion in suggestions) {
final clusterID = suggestion.$1;
await PersonService.instance.assignClusterToPerson(
personID: p.remoteID,
clusterID: clusterID,
);
}
Bus.instance.fire(PeopleChangedEvent());
@ -433,111 +430,77 @@ class ClusterFeedbackService {
return;
}
/// Returns a map of person's clusterID to map of closest clusterID to with disstance
Future<Map<int, List<(int, double)>>> getSuggestionsUsingMean(
PersonEntity p, {
double maxClusterDistance = 0.4,
}) async {
// Get all the cluster data
final faceMlDb = FaceMLDataDB.instance;
final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount());
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
dev.log(
'existing clusters for ${p.data.name} are $personClusters',
name: "ClusterFeedbackService",
);
// Get and update the cluster summary to get the avg (centroid) and count
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
allClusterIdsToCountMap,
ignoredClusters,
);
watch.log('computed avg for ${clusterAvg.length} clusters');
// Find the actual closest clusters for the person
final Map<int, List<(int, double)>> suggestions = _calcSuggestionsMean(
clusterAvg,
personClusters,
ignoredClusters,
maxClusterDistance,
);
// log suggestions
for (final entry in suggestions.entries) {
dev.log(
' ${entry.value.length} suggestion for ${p.data.name} for cluster ID ${entry.key} are suggestions ${entry.value}}',
name: "ClusterFeedbackService",
);
}
return suggestions;
}
/// Returns a list of suggestions. For each suggestion we return a record consisting of the following elements:
/// 1. clusterID: the ID of the cluster
/// 2. distance: the distance between the person's cluster and the suggestion
/// 3. usedMean: whether the suggestion was found using the mean (true) or the median (false)
Future<List<(int, double, bool)>> _getSuggestionsUsingMedian(
Future<List<(int, double, bool)>> _getSuggestions(
PersonEntity p, {
int sampleSize = 50,
double maxMedianDistance = 0.65,
double goodMedianDistance = 0.55,
double maxMeanDistance = 0.65,
double goodMeanDistance = 0.4,
double goodMeanDistance = 0.5,
}) async {
// Get all the cluster data
final startTime = DateTime.now();
final faceMlDb = FaceMLDataDB.instance;
// final Map<int, List<(int, double)>> suggestions = {};
final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount());
final allClusterIdsToCountMap = await faceMlDb.clusterIdToFaceCount();
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
dev.log(
'existing clusters for ${p.data.name} are $personClusters',
'existing clusters for ${p.data.name} are $personClusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms',
name: "getSuggestionsUsingMedian",
);
// Get and update the cluster summary to get the avg (centroid) and count
// First only do a simple check on the big clusters
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
final Map<int, List<double>> clusterAvgBigClusters =
await _getUpdateClusterAvg(
allClusterIdsToCountMap,
ignoredClusters,
minClusterSize: kMinimumClusterSizeSearchResult,
);
dev.log(
'computed avg for ${clusterAvgBigClusters.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
);
final List<(int, double)> suggestionsMeanBigClusters = _calcSuggestionsMean(
clusterAvgBigClusters,
personClusters,
ignoredClusters,
goodMeanDistance,
);
if (suggestionsMeanBigClusters.isNotEmpty) {
return suggestionsMeanBigClusters
.map((e) => (e.$1, e.$2, true))
.toList(growable: false);
}
// Get and update the cluster summary to get the avg (centroid) and count
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
allClusterIdsToCountMap,
ignoredClusters,
);
watch.log('computed avg for ${clusterAvg.length} clusters');
dev.log(
'computed avg for ${clusterAvg.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
);
// Find the other cluster candidates based on the mean
final Map<int, List<(int, double)>> suggestionsMean = _calcSuggestionsMean(
final List<(int, double)> suggestionsMean = _calcSuggestionsMean(
clusterAvg,
personClusters,
ignoredClusters,
goodMeanDistance,
);
if (suggestionsMean.isNotEmpty) {
final List<(int, double)> suggestClusterIds = [];
for (final List<(int, double)> suggestion in suggestionsMean.values) {
suggestClusterIds.addAll(suggestion);
}
suggestClusterIds.sort(
(a, b) => allClusterIdsToCountMap[b.$1]!
.compareTo(allClusterIdsToCountMap[a.$1]!),
);
final suggestClusterIdsSizes = suggestClusterIds
.map((e) => allClusterIdsToCountMap[e.$1]!)
.toList(growable: false);
final suggestClusterIdsDistances =
suggestClusterIds.map((e) => e.$2).toList(growable: false);
_logger.info(
"Already found good suggestions using mean: $suggestClusterIds, with sizes $suggestClusterIdsSizes and distances $suggestClusterIdsDistances",
);
return suggestClusterIds
return suggestionsMean
.map((e) => (e.$1, e.$2, true))
.toList(growable: false);
}
// Find the other cluster candidates based on the median
final Map<int, List<(int, double)>> moreSuggestionsMean =
_calcSuggestionsMean(
final List<(int, double)> moreSuggestionsMean = _calcSuggestionsMean(
clusterAvg,
personClusters,
ignoredClusters,
@ -549,12 +512,8 @@ class ClusterFeedbackService {
return [];
}
final List<(int, double)> temp = [];
for (final List<(int, double)> suggestion in moreSuggestionsMean.values) {
temp.addAll(suggestion);
}
temp.sort((a, b) => a.$2.compareTo(b.$2));
final otherClusterIdsCandidates = temp
moreSuggestionsMean.sort((a, b) => a.$2.compareTo(b.$2));
final otherClusterIdsCandidates = moreSuggestionsMean
.map(
(e) => e.$1,
)
@ -655,20 +614,26 @@ class ClusterFeedbackService {
int maxClusterInCurrentRun = 500,
int maxEmbeddingToRead = 10000,
}) async {
final startTime = DateTime.now();
final faceMlDb = FaceMLDataDB.instance;
_logger.info(
'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters, minClusterSize $minClusterSize, maxClusterInCurrentRun $maxClusterInCurrentRun',
);
final Map<int, (Uint8List, int)> clusterToSummary =
await faceMlDb.getAllClusterSummary();
await faceMlDb.getAllClusterSummary(minClusterSize);
final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
final Map<int, List<double>> clusterAvg = {};
dev.log(
'getUpdateClusterAvg database call for getAllClusterSummary took ${DateTime.now().difference(startTime).inMilliseconds} ms',
);
final allClusterIds = allClusterIdsToCountMap.keys.toSet();
int ignoredClustersCnt = 0, alreadyUpdatedClustersCnt = 0;
int smallerClustersCnt = 0;
final serializationTime = DateTime.now();
for (final id in allClusterIdsToCountMap.keys) {
if (ignoredClusters.contains(id)) {
allClusterIds.remove(id);
@ -684,9 +649,20 @@ class ClusterFeedbackService {
smallerClustersCnt++;
}
}
dev.log(
'serialization of embeddings took ${DateTime.now().difference(serializationTime).inMilliseconds} ms',
);
_logger.info(
'Ignored $ignoredClustersCnt clusters, already updated $alreadyUpdatedClustersCnt clusters, $smallerClustersCnt clusters are smaller than $minClusterSize',
);
if (allClusterIds.isEmpty) {
_logger.info(
'No clusters to update, getUpdateClusterAvg done in ${DateTime.now().difference(startTime).inMilliseconds} ms',
);
return clusterAvg;
}
// get clusterIDs sorted by count in descending order
final sortedClusterIDs = allClusterIds.toList();
sortedClusterIDs.sort(
@ -760,18 +736,21 @@ class ClusterFeedbackService {
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
}
w?.logAndReset('done computing avg ');
_logger.info('end getUpdateClusterAvg for ${clusterAvg.length} clusters');
_logger.info(
'end getUpdateClusterAvg for ${clusterAvg.length} clusters, done in ${DateTime.now().difference(startTime).inMilliseconds} ms',
);
return clusterAvg;
}
/// Returns a map of person's clusterID to map of closest clusterID to with disstance
Map<int, List<(int, double)>> _calcSuggestionsMean(
List<(int, double)> _calcSuggestionsMean(
Map<int, List<double>> clusterAvg,
Set<int> personClusters,
Set<int> ignoredClusters,
double maxClusterDistance,
) {
double maxClusterDistance, {
Map<int, int>? allClusterIdsToCountMap,
}) {
final Map<int, List<(int, double)>> suggestions = {};
for (final otherClusterID in clusterAvg.keys) {
// ignore the cluster that belong to the person or is ignored
@ -802,11 +781,32 @@ class ClusterFeedbackService {
.add((otherClusterID, minDistance));
}
}
for (final entry in suggestions.entries) {
entry.value.sort((a, b) => a.$1.compareTo(b.$1));
}
return suggestions;
if (suggestions.isNotEmpty) {
final List<(int, double)> suggestClusterIds = [];
for (final List<(int, double)> suggestion in suggestions.values) {
suggestClusterIds.addAll(suggestion);
}
List<int>? suggestClusterIdsSizes;
if (allClusterIdsToCountMap != null) {
suggestClusterIds.sort(
(a, b) => allClusterIdsToCountMap[b.$1]!
.compareTo(allClusterIdsToCountMap[a.$1]!),
);
suggestClusterIdsSizes = suggestClusterIds
.map((e) => allClusterIdsToCountMap[e.$1]!)
.toList(growable: false);
}
final suggestClusterIdsDistances =
suggestClusterIds.map((e) => e.$2).toList(growable: false);
_logger.info(
"Already found good suggestions using mean: $suggestClusterIds, ${suggestClusterIdsSizes != null ? 'with sizes $suggestClusterIdsSizes' : ''} and distances $suggestClusterIdsDistances",
);
return suggestClusterIds;
} else {
_logger.info("No suggestions found using mean");
return <(int, double)>[];
}
}
List<T> _randomSampleWithoutReplacement<T>(

View file

@ -28,6 +28,7 @@ import "package:photos/models/search/search_constants.dart";
import "package:photos/models/search/search_types.dart";
import 'package:photos/services/collections_service.dart';
import "package:photos/services/location_service.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart';
import "package:photos/states/location_screen_state.dart";
@ -824,7 +825,7 @@ class SearchService {
"Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}",
);
}
if (files.length < 20 && sortedClusterIds.length > 3) {
if (files.length < kMinimumClusterSizeSearchResult && sortedClusterIds.length > 3) {
continue;
}
facesResult.add(