[mob][photos] Improve suggestions by improving speed and preferring big clusters
This commit is contained in:
parent
37ab467da5
commit
fa466d715f
3 changed files with 104 additions and 100 deletions
|
@ -15,3 +15,6 @@ const kHighQualityFaceScore = 0.90;
|
|||
|
||||
/// The minimum score for a face to be detected, regardless of quality. Use [kMinimumQualityFaceScore] for high quality faces.
|
||||
const kMinFaceDetectionScore = FaceDetectionService.kMinScoreSigmoidThreshold;
|
||||
|
||||
/// The minimum cluster size for displaying a cluster in the UI
|
||||
const kMinimumClusterSizeSearchResult = 20;
|
||||
|
|
|
@ -15,6 +15,7 @@ import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
|||
import "package:photos/models/file/file.dart";
|
||||
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
|
||||
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
|
||||
import "package:photos/services/search_service.dart";
|
||||
|
@ -65,7 +66,7 @@ class ClusterFeedbackService {
|
|||
try {
|
||||
// Get the suggestions for the person using centroids and median
|
||||
final List<(int, double, bool)> suggestClusterIds =
|
||||
await _getSuggestionsUsingMedian(person);
|
||||
await _getSuggestions(person);
|
||||
|
||||
// Get the files for the suggestions
|
||||
final Map<int, Set<int>> fileIdToClusterID =
|
||||
|
@ -241,7 +242,7 @@ class ClusterFeedbackService {
|
|||
watch.log('computed avg for ${clusterAvg.length} clusters');
|
||||
|
||||
// Find the actual closest clusters for the person
|
||||
final Map<int, List<(int, double)>> suggestions = _calcSuggestionsMean(
|
||||
final List<(int, double)> suggestions = _calcSuggestionsMean(
|
||||
clusterAvg,
|
||||
personClusters,
|
||||
ignoredClusters,
|
||||
|
@ -257,21 +258,17 @@ class ClusterFeedbackService {
|
|||
}
|
||||
|
||||
// log suggestions
|
||||
for (final entry in suggestions.entries) {
|
||||
dev.log(
|
||||
' ${entry.value.length} suggestion for ${p.data.name} for cluster ID ${entry.key} are suggestions ${entry.value}}',
|
||||
name: "ClusterFeedbackService",
|
||||
);
|
||||
}
|
||||
dev.log(
|
||||
'suggestions for ${p.data.name} for cluster ID ${p.remoteID} are suggestions $suggestions}',
|
||||
name: "ClusterFeedbackService",
|
||||
);
|
||||
|
||||
for (final suggestionsPerCluster in suggestions.values) {
|
||||
for (final suggestion in suggestionsPerCluster) {
|
||||
final clusterID = suggestion.$1;
|
||||
await PersonService.instance.assignClusterToPerson(
|
||||
personID: p.remoteID,
|
||||
clusterID: clusterID,
|
||||
);
|
||||
}
|
||||
for (final suggestion in suggestions) {
|
||||
final clusterID = suggestion.$1;
|
||||
await PersonService.instance.assignClusterToPerson(
|
||||
personID: p.remoteID,
|
||||
clusterID: clusterID,
|
||||
);
|
||||
}
|
||||
|
||||
Bus.instance.fire(PeopleChangedEvent());
|
||||
|
@ -433,111 +430,77 @@ class ClusterFeedbackService {
|
|||
return;
|
||||
}
|
||||
|
||||
/// Returns a map from each of the person's cluster IDs to a list of the closest cluster IDs with their distances
|
||||
Future<Map<int, List<(int, double)>>> getSuggestionsUsingMean(
|
||||
PersonEntity p, {
|
||||
double maxClusterDistance = 0.4,
|
||||
}) async {
|
||||
// Get all the cluster data
|
||||
final faceMlDb = FaceMLDataDB.instance;
|
||||
|
||||
final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount());
|
||||
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
|
||||
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
|
||||
dev.log(
|
||||
'existing clusters for ${p.data.name} are $personClusters',
|
||||
name: "ClusterFeedbackService",
|
||||
);
|
||||
|
||||
// Get and update the cluster summary to get the avg (centroid) and count
|
||||
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
|
||||
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
|
||||
allClusterIdsToCountMap,
|
||||
ignoredClusters,
|
||||
);
|
||||
watch.log('computed avg for ${clusterAvg.length} clusters');
|
||||
|
||||
// Find the actual closest clusters for the person
|
||||
final Map<int, List<(int, double)>> suggestions = _calcSuggestionsMean(
|
||||
clusterAvg,
|
||||
personClusters,
|
||||
ignoredClusters,
|
||||
maxClusterDistance,
|
||||
);
|
||||
|
||||
// log suggestions
|
||||
for (final entry in suggestions.entries) {
|
||||
dev.log(
|
||||
' ${entry.value.length} suggestion for ${p.data.name} for cluster ID ${entry.key} are suggestions ${entry.value}}',
|
||||
name: "ClusterFeedbackService",
|
||||
);
|
||||
}
|
||||
return suggestions;
|
||||
}
|
||||
|
||||
/// Returns a list of suggestions. For each suggestion we return a record consisting of the following elements:
|
||||
/// 1. clusterID: the ID of the cluster
|
||||
/// 2. distance: the distance between the person's cluster and the suggestion
|
||||
/// 3. usedMean: whether the suggestion was found using the mean (true) or the median (false)
|
||||
Future<List<(int, double, bool)>> _getSuggestionsUsingMedian(
|
||||
Future<List<(int, double, bool)>> _getSuggestions(
|
||||
PersonEntity p, {
|
||||
int sampleSize = 50,
|
||||
double maxMedianDistance = 0.65,
|
||||
double goodMedianDistance = 0.55,
|
||||
double maxMeanDistance = 0.65,
|
||||
double goodMeanDistance = 0.4,
|
||||
double goodMeanDistance = 0.5,
|
||||
}) async {
|
||||
// Get all the cluster data
|
||||
final startTime = DateTime.now();
|
||||
final faceMlDb = FaceMLDataDB.instance;
|
||||
// final Map<int, List<(int, double)>> suggestions = {};
|
||||
final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount());
|
||||
final allClusterIdsToCountMap = await faceMlDb.clusterIdToFaceCount();
|
||||
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
|
||||
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
|
||||
dev.log(
|
||||
'existing clusters for ${p.data.name} are $personClusters',
|
||||
'existing clusters for ${p.data.name} are $personClusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||
name: "getSuggestionsUsingMedian",
|
||||
);
|
||||
|
||||
// Get and update the cluster summary to get the avg (centroid) and count
|
||||
// First only do a simple check on the big clusters
|
||||
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
|
||||
final Map<int, List<double>> clusterAvgBigClusters =
|
||||
await _getUpdateClusterAvg(
|
||||
allClusterIdsToCountMap,
|
||||
ignoredClusters,
|
||||
minClusterSize: kMinimumClusterSizeSearchResult,
|
||||
);
|
||||
dev.log(
|
||||
'computed avg for ${clusterAvgBigClusters.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||
);
|
||||
final List<(int, double)> suggestionsMeanBigClusters = _calcSuggestionsMean(
|
||||
clusterAvgBigClusters,
|
||||
personClusters,
|
||||
ignoredClusters,
|
||||
goodMeanDistance,
|
||||
);
|
||||
if (suggestionsMeanBigClusters.isNotEmpty) {
|
||||
return suggestionsMeanBigClusters
|
||||
.map((e) => (e.$1, e.$2, true))
|
||||
.toList(growable: false);
|
||||
}
|
||||
|
||||
// Get and update the cluster summary to get the avg (centroid) and count
|
||||
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
|
||||
allClusterIdsToCountMap,
|
||||
ignoredClusters,
|
||||
);
|
||||
watch.log('computed avg for ${clusterAvg.length} clusters');
|
||||
dev.log(
|
||||
'computed avg for ${clusterAvg.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||
);
|
||||
|
||||
// Find the other cluster candidates based on the mean
|
||||
final Map<int, List<(int, double)>> suggestionsMean = _calcSuggestionsMean(
|
||||
final List<(int, double)> suggestionsMean = _calcSuggestionsMean(
|
||||
clusterAvg,
|
||||
personClusters,
|
||||
ignoredClusters,
|
||||
goodMeanDistance,
|
||||
);
|
||||
if (suggestionsMean.isNotEmpty) {
|
||||
final List<(int, double)> suggestClusterIds = [];
|
||||
for (final List<(int, double)> suggestion in suggestionsMean.values) {
|
||||
suggestClusterIds.addAll(suggestion);
|
||||
}
|
||||
suggestClusterIds.sort(
|
||||
(a, b) => allClusterIdsToCountMap[b.$1]!
|
||||
.compareTo(allClusterIdsToCountMap[a.$1]!),
|
||||
);
|
||||
final suggestClusterIdsSizes = suggestClusterIds
|
||||
.map((e) => allClusterIdsToCountMap[e.$1]!)
|
||||
.toList(growable: false);
|
||||
final suggestClusterIdsDistances =
|
||||
suggestClusterIds.map((e) => e.$2).toList(growable: false);
|
||||
_logger.info(
|
||||
"Already found good suggestions using mean: $suggestClusterIds, with sizes $suggestClusterIdsSizes and distances $suggestClusterIdsDistances",
|
||||
);
|
||||
return suggestClusterIds
|
||||
return suggestionsMean
|
||||
.map((e) => (e.$1, e.$2, true))
|
||||
.toList(growable: false);
|
||||
}
|
||||
|
||||
// Find the other cluster candidates based on the median
|
||||
final Map<int, List<(int, double)>> moreSuggestionsMean =
|
||||
_calcSuggestionsMean(
|
||||
final List<(int, double)> moreSuggestionsMean = _calcSuggestionsMean(
|
||||
clusterAvg,
|
||||
personClusters,
|
||||
ignoredClusters,
|
||||
|
@ -549,12 +512,8 @@ class ClusterFeedbackService {
|
|||
return [];
|
||||
}
|
||||
|
||||
final List<(int, double)> temp = [];
|
||||
for (final List<(int, double)> suggestion in moreSuggestionsMean.values) {
|
||||
temp.addAll(suggestion);
|
||||
}
|
||||
temp.sort((a, b) => a.$2.compareTo(b.$2));
|
||||
final otherClusterIdsCandidates = temp
|
||||
moreSuggestionsMean.sort((a, b) => a.$2.compareTo(b.$2));
|
||||
final otherClusterIdsCandidates = moreSuggestionsMean
|
||||
.map(
|
||||
(e) => e.$1,
|
||||
)
|
||||
|
@ -655,20 +614,26 @@ class ClusterFeedbackService {
|
|||
int maxClusterInCurrentRun = 500,
|
||||
int maxEmbeddingToRead = 10000,
|
||||
}) async {
|
||||
final startTime = DateTime.now();
|
||||
final faceMlDb = FaceMLDataDB.instance;
|
||||
_logger.info(
|
||||
'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters, minClusterSize $minClusterSize, maxClusterInCurrentRun $maxClusterInCurrentRun',
|
||||
);
|
||||
|
||||
final Map<int, (Uint8List, int)> clusterToSummary =
|
||||
await faceMlDb.getAllClusterSummary();
|
||||
await faceMlDb.getAllClusterSummary(minClusterSize);
|
||||
final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
|
||||
|
||||
final Map<int, List<double>> clusterAvg = {};
|
||||
|
||||
dev.log(
|
||||
'getUpdateClusterAvg database call for getAllClusterSummary took ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||
);
|
||||
|
||||
final allClusterIds = allClusterIdsToCountMap.keys.toSet();
|
||||
int ignoredClustersCnt = 0, alreadyUpdatedClustersCnt = 0;
|
||||
int smallerClustersCnt = 0;
|
||||
final serializationTime = DateTime.now();
|
||||
for (final id in allClusterIdsToCountMap.keys) {
|
||||
if (ignoredClusters.contains(id)) {
|
||||
allClusterIds.remove(id);
|
||||
|
@ -684,9 +649,20 @@ class ClusterFeedbackService {
|
|||
smallerClustersCnt++;
|
||||
}
|
||||
}
|
||||
dev.log(
|
||||
'serialization of embeddings took ${DateTime.now().difference(serializationTime).inMilliseconds} ms',
|
||||
);
|
||||
_logger.info(
|
||||
'Ignored $ignoredClustersCnt clusters, already updated $alreadyUpdatedClustersCnt clusters, $smallerClustersCnt clusters are smaller than $minClusterSize',
|
||||
);
|
||||
|
||||
if (allClusterIds.isEmpty) {
|
||||
_logger.info(
|
||||
'No clusters to update, getUpdateClusterAvg done in ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||
);
|
||||
return clusterAvg;
|
||||
}
|
||||
|
||||
// get clusterIDs sorted by count in descending order
|
||||
final sortedClusterIDs = allClusterIds.toList();
|
||||
sortedClusterIDs.sort(
|
||||
|
@ -760,18 +736,21 @@ class ClusterFeedbackService {
|
|||
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
|
||||
}
|
||||
w?.logAndReset('done computing avg ');
|
||||
_logger.info('end getUpdateClusterAvg for ${clusterAvg.length} clusters');
|
||||
_logger.info(
|
||||
'end getUpdateClusterAvg for ${clusterAvg.length} clusters, done in ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||
);
|
||||
|
||||
return clusterAvg;
|
||||
}
|
||||
|
||||
/// Returns the clusters closest to the person's clusters, each paired with its distance
|
||||
Map<int, List<(int, double)>> _calcSuggestionsMean(
|
||||
List<(int, double)> _calcSuggestionsMean(
|
||||
Map<int, List<double>> clusterAvg,
|
||||
Set<int> personClusters,
|
||||
Set<int> ignoredClusters,
|
||||
double maxClusterDistance,
|
||||
) {
|
||||
double maxClusterDistance, {
|
||||
Map<int, int>? allClusterIdsToCountMap,
|
||||
}) {
|
||||
final Map<int, List<(int, double)>> suggestions = {};
|
||||
for (final otherClusterID in clusterAvg.keys) {
|
||||
// ignore the cluster that belong to the person or is ignored
|
||||
|
@ -802,11 +781,32 @@ class ClusterFeedbackService {
|
|||
.add((otherClusterID, minDistance));
|
||||
}
|
||||
}
|
||||
for (final entry in suggestions.entries) {
|
||||
entry.value.sort((a, b) => a.$1.compareTo(b.$1));
|
||||
}
|
||||
|
||||
return suggestions;
|
||||
if (suggestions.isNotEmpty) {
|
||||
final List<(int, double)> suggestClusterIds = [];
|
||||
for (final List<(int, double)> suggestion in suggestions.values) {
|
||||
suggestClusterIds.addAll(suggestion);
|
||||
}
|
||||
List<int>? suggestClusterIdsSizes;
|
||||
if (allClusterIdsToCountMap != null) {
|
||||
suggestClusterIds.sort(
|
||||
(a, b) => allClusterIdsToCountMap[b.$1]!
|
||||
.compareTo(allClusterIdsToCountMap[a.$1]!),
|
||||
);
|
||||
suggestClusterIdsSizes = suggestClusterIds
|
||||
.map((e) => allClusterIdsToCountMap[e.$1]!)
|
||||
.toList(growable: false);
|
||||
}
|
||||
final suggestClusterIdsDistances =
|
||||
suggestClusterIds.map((e) => e.$2).toList(growable: false);
|
||||
_logger.info(
|
||||
"Already found good suggestions using mean: $suggestClusterIds, ${suggestClusterIdsSizes != null ? 'with sizes $suggestClusterIdsSizes' : ''} and distances $suggestClusterIdsDistances",
|
||||
);
|
||||
return suggestClusterIds;
|
||||
} else {
|
||||
_logger.info("No suggestions found using mean");
|
||||
return <(int, double)>[];
|
||||
}
|
||||
}
|
||||
|
||||
List<T> _randomSampleWithoutReplacement<T>(
|
||||
|
|
|
@ -28,6 +28,7 @@ import "package:photos/models/search/search_constants.dart";
|
|||
import "package:photos/models/search/search_types.dart";
|
||||
import 'package:photos/services/collections_service.dart';
|
||||
import "package:photos/services/location_service.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
|
||||
import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart';
|
||||
import "package:photos/states/location_screen_state.dart";
|
||||
|
@ -824,7 +825,7 @@ class SearchService {
|
|||
"Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}",
|
||||
);
|
||||
}
|
||||
if (files.length < 20 && sortedClusterIds.length > 3) {
|
||||
if (files.length < kMinimumClusterSizeSearchResult && sortedClusterIds.length > 3) {
|
||||
continue;
|
||||
}
|
||||
facesResult.add(
|
||||
|
|
Loading…
Reference in a new issue