diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index b81a0c1ab..d11afa180 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -3,6 +3,7 @@ import "dart:math" show Random, min; import "package:flutter/foundation.dart"; import "package:logging/logging.dart"; +import "package:ml_linalg/linalg.dart"; import "package:photos/core/event_bus.dart"; import "package:photos/db/files_db.dart"; // import "package:photos/events/files_updated_event.dart"; @@ -245,13 +246,13 @@ class ClusterFeedbackService { final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID); final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID); dev.log( - 'existing clusters for ${p.data.name} are $personClusters', + '${p.data.name} has ${personClusters.length} existing clusters', name: "ClusterFeedbackService", ); // Get and update the cluster summary to get the avg (centroid) and count final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start(); - final Map> clusterAvg = await _getUpdateClusterAvg( + final Map clusterAvg = await _getUpdateClusterAvg( allClusterIdsToCountMap, ignoredClusters, ); @@ -466,19 +467,19 @@ class ClusterFeedbackService { final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID); final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID); dev.log( - 'existing clusters for ${p.data.name} are $personClusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms', + '${p.data.name} has ${personClusters.length} existing clusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms', name: "getSuggestionsUsingMedian", ); // First only do a simple check on the big clusters, if the person does not have small clusters yet - final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start(); + final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start(); final smallestPersonClusterSize = personClusters .map((clusterID) => allClusterIdsToCountMap[clusterID] ?? 0) .reduce((value, element) => min(value, element)); final checkSizes = [kMinimumClusterSizeSearchResult, 20, 10, 5, 1]; for (final minimumSize in checkSizes.toSet()) { if (smallestPersonClusterSize >= minimumSize) { - final Map> clusterAvgBigClusters = + final Map clusterAvgBigClusters = await _getUpdateClusterAvg( allClusterIdsToCountMap, ignoredClusters, @@ -487,6 +488,7 @@ class ClusterFeedbackService { dev.log( 'computed avg for ${clusterAvgBigClusters.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms', ); + w?.log('Calculate avg for min size $minimumSize'); final List<(int, double)> suggestionsMeanBigClusters = _calcSuggestionsMean( clusterAvgBigClusters, @@ -494,6 +496,7 @@ class ClusterFeedbackService { ignoredClusters, goodMeanDistance, ); + w?.log('Calculate suggestions using mean for min size $minimumSize'); if (suggestionsMeanBigClusters.isNotEmpty) { return suggestionsMeanBigClusters .map((e) => (e.$1, e.$2, true)) @@ -501,9 +504,10 @@ class ClusterFeedbackService { } } } + w?.reset(); // Get and update the cluster summary to get the avg (centroid) and count - final Map> clusterAvg = await _getUpdateClusterAvg( + final Map clusterAvg = await _getUpdateClusterAvg( allClusterIdsToCountMap, ignoredClusters, ); @@ -547,7 +551,7 @@ class ClusterFeedbackService { "Found potential suggestions from loose mean for median test: $otherClusterIdsCandidates", ); - watch.logAndReset("Starting median test"); + w?.logAndReset("Starting median test"); // Take the embeddings from the person's clusters in one big list and sample from it final List personEmbeddingsProto = []; for (final clusterID in personClusters) { @@ -600,7 +604,7 @@ class ClusterFeedbackService { } } } - watch.log("Finished median test"); + w?.log("Finished median test"); if (suggestionsMedian.isEmpty) { _logger.info("No suggestions found using median"); return []; @@ -632,7 +636,7 @@ class ClusterFeedbackService { return finalSuggestionsMedian; } - Future>> _getUpdateClusterAvg( + Future> _getUpdateClusterAvg( Map allClusterIdsToCountMap, Set ignoredClusters, { int minClusterSize = 1, @@ -649,7 +653,7 @@ class ClusterFeedbackService { await faceMlDb.getAllClusterSummary(minClusterSize); final Map updatesForClusterSummary = {}; - final Map> clusterAvg = {}; + final Map clusterAvg = {}; dev.log( 'getUpdateClusterAvg database call for getAllClusterSummary took ${DateTime.now().difference(startTime).inMilliseconds} ms', @@ -666,7 +670,9 @@ class ClusterFeedbackService { } if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) { allClusterIds.remove(id); - clusterAvg[id] = EVector.fromBuffer(clusterToSummary[id]!.$1).values; + clusterAvg[id] = Vector.fromList( + EVector.fromBuffer(clusterToSummary[id]!.$1).values, + dtype: DType.float32,); alreadyUpdatedClustersCnt++; } if (allClusterIdsToCountMap[id]! < minClusterSize) { @@ -731,19 +737,15 @@ class ClusterFeedbackService { ); for (final clusterID in clusterEmbeddings.keys) { - late List avg; - final Iterable embedings = clusterEmbeddings[clusterID]!; - final List sum = List.filled(192, 0); - for (final embedding in embedings) { - final data = EVector.fromBuffer(embedding).values; - for (int i = 0; i < sum.length; i++) { - sum[i] += data[i]; - } - } - avg = sum.map((e) => e / embedings.length).toList(); - final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer(); + final Iterable embeddings = clusterEmbeddings[clusterID]!; + final Iterable vectors = embeddings.map((e) => Vector.fromList( + EVector.fromBuffer(e).values, + dtype: DType.float32, + ),); + final avg = vectors.reduce((a, b) => a + b) / vectors.length; + final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer(); updatesForClusterSummary[clusterID] = - (avgEmbeedingBuffer, embedings.length); + (avgEmbeddingBuffer, embeddings.length); // store the intermediate updates indexedInCurrentRun++; if (updatesForClusterSummary.length > 100) { @@ -770,7 +772,7 @@ class ClusterFeedbackService { /// Returns a map of person's clusterID to map of closest clusterID to with disstance List<(int, double)> _calcSuggestionsMean( - Map> clusterAvg, + Map clusterAvg, Set personClusters, Set ignoredClusters, double maxClusterDistance, { @@ -779,23 +781,14 @@ class ClusterFeedbackService { final Map> suggestions = {}; int suggestionCount = 0; final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start(); - final clusterAvgVectors = clusterAvg.map( - (key, value) => MapEntry( - key, - Vector.fromList( - value, - dtype: DType.float32, - ), - ), - ); w?.log('converted avg to vectors for ${clusterAvg.length} averages'); - for (final otherClusterID in clusterAvgVectors.keys) { + for (final otherClusterID in clusterAvg.keys) { // ignore the cluster that belong to the person or is ignored if (personClusters.contains(otherClusterID) || ignoredClusters.contains(otherClusterID)) { continue; } - final otherAvg = clusterAvgVectors[otherClusterID]!; + final Vector otherAvg = clusterAvg[otherClusterID]!; int? nearestPersonCluster; double? minDistance; for (final personCluster in personClusters) { @@ -803,7 +796,7 @@ class ClusterFeedbackService { _logger.info('no avg for cluster $personCluster'); continue; } - final avg = clusterAvgVectors[personCluster]!; + final Vector avg = clusterAvg[personCluster]!; final distance = 1 - avg.dot(otherAvg); if (distance < maxClusterDistance) { if (minDistance == null || distance < minDistance) {