diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 95b3f2b08..26623fecc 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -481,46 +481,46 @@ class ClusterFeedbackService { late Map clusterAvgBigClusters; final List<(int, double)> suggestionsMean = []; for (final minimumSize in checkSizes.toSet()) { - // if (smallestPersonClusterSize >= minimumSize) { - clusterAvgBigClusters = await _getUpdateClusterAvg( - allClusterIdsToCountMap, - ignoredClusters, - minClusterSize: minimumSize, - ); - w?.log( - 'Calculate avg for ${clusterAvgBigClusters.length} clusters of min size $minimumSize', - ); - final List<(int, double)> suggestionsMeanBigClusters = - _calcSuggestionsMean( - clusterAvgBigClusters, - personClusters, - ignoredClusters, - goodMeanDistance, - ); - w?.log( - 'Calculate suggestions using mean for ${clusterAvgBigClusters.length} clusters of min size $minimumSize', - ); - for (final suggestion in suggestionsMeanBigClusters) { - // Skip suggestions that have a high overlap with the person's files - final suggestionSet = allClusterIdToFaceIDs[suggestion.$1]! - .map((faceID) => getFileIdFromFaceId(faceID)) - .toSet(); - final overlap = personFileIDs.intersection(suggestionSet); - if (overlap.isNotEmpty && - ((overlap.length / suggestionSet.length) > 0.5)) { - await FaceMLDataDB.instance.captureNotPersonFeedback( - personID: p.remoteID, - clusterID: suggestion.$1, - ); - continue; + if (smallestPersonClusterSize >= minimumSize) { + clusterAvgBigClusters = await _getUpdateClusterAvg( + allClusterIdsToCountMap, + ignoredClusters, + minClusterSize: minimumSize, + ); + w?.log( + 'Calculate avg for ${clusterAvgBigClusters.length} clusters of min size $minimumSize', + ); + final List<(int, double)> suggestionsMeanBigClusters = + _calcSuggestionsMean( + clusterAvgBigClusters, + personClusters, + ignoredClusters, + goodMeanDistance, + ); + w?.log( + 'Calculate suggestions using mean for ${clusterAvgBigClusters.length} clusters of min size $minimumSize', + ); + for (final suggestion in suggestionsMeanBigClusters) { + // Skip suggestions that have a high overlap with the person's files + final suggestionSet = allClusterIdToFaceIDs[suggestion.$1]! + .map((faceID) => getFileIdFromFaceId(faceID)) + .toSet(); + final overlap = personFileIDs.intersection(suggestionSet); + if (overlap.isNotEmpty && + ((overlap.length / suggestionSet.length) > 0.5)) { + await FaceMLDataDB.instance.captureNotPersonFeedback( + personID: p.remoteID, + clusterID: suggestion.$1, + ); + continue; + } + suggestionsMean.add(suggestion); + } + if (suggestionsMean.isNotEmpty) { + return suggestionsMean + .map((e) => (e.$1, e.$2, true)) + .toList(growable: false); } - suggestionsMean.add(suggestion); - } - if (suggestionsMean.isNotEmpty) { - return suggestionsMean - .map((e) => (e.$1, e.$2, true)) - .toList(growable: false); - // } } } w?.reset(); @@ -784,24 +784,31 @@ class ClusterFeedbackService { Map? allClusterIdsToCountMap, }) { final Map> suggestions = {}; + const suggestionMax = 2000; int suggestionCount = 0; + int comparisons = 0; final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start(); - for (final otherClusterID in clusterAvg.keys) { - // ignore the cluster that belong to the person or is ignored - if (personClusters.contains(otherClusterID) || - ignoredClusters.contains(otherClusterID)) { + + // ignore the clusters that belong to the person or is ignored + Set otherClusters = clusterAvg.keys.toSet().difference(personClusters); + otherClusters = otherClusters.difference(ignoredClusters); + + for (final otherClusterID in otherClusters) { + final Vector? otherAvg = clusterAvg[otherClusterID]; + if (otherAvg == null) { + _logger.warning('no avg for othercluster $otherClusterID'); continue; } - final Vector otherAvg = clusterAvg[otherClusterID]!; int? nearestPersonCluster; double? minDistance; for (final personCluster in personClusters) { if (clusterAvg[personCluster] == null) { - _logger.info('no avg for cluster $personCluster'); + _logger.warning('no avg for personcluster $personCluster'); continue; } final Vector avg = clusterAvg[personCluster]!; final distance = cosineDistanceSIMD(avg, otherAvg); + comparisons++; if (distance < maxClusterDistance) { if (minDistance == null || distance < minDistance) { minDistance = distance; @@ -815,11 +822,13 @@ class ClusterFeedbackService { .add((otherClusterID, minDistance)); suggestionCount++; } - if (suggestionCount >= 2000) { + if (suggestionCount >= suggestionMax) { break; } } - w?.log('calculation inside calcSuggestionsMean'); + w?.log( + 'calculation inside calcSuggestionsMean for ${personClusters.length} person clusters and ${otherClusters.length} other clusters (so ${personClusters.length * otherClusters.length} combinations, $comparisons comparisons made resulted in $suggestionCount suggestions)', + ); if (suggestions.isNotEmpty) { final List<(int, double)> suggestClusterIds = []; diff --git a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart index 913e52268..8617e9cda 100644 --- a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart @@ -114,7 +114,9 @@ class PersonService { } bool _shouldUpdateRemotePerson( - PersonData personData, Map> dbPersonCluster) { + PersonData personData, + Map> dbPersonCluster, + ) { bool result = false; if ((personData.assigned?.length ?? 0) != dbPersonCluster.length) { log(