[mob][photos] Use vectors everywhere in cluster suggestion
This commit is contained in:
parent
4b6641d7d8
commit
e829f7b62f
1 changed files with 29 additions and 36 deletions
|
@ -3,6 +3,7 @@ import "dart:math" show Random, min;
|
|||
|
||||
import "package:flutter/foundation.dart";
|
||||
import "package:logging/logging.dart";
|
||||
import "package:ml_linalg/linalg.dart";
|
||||
import "package:photos/core/event_bus.dart";
|
||||
import "package:photos/db/files_db.dart";
|
||||
// import "package:photos/events/files_updated_event.dart";
|
||||
|
@ -245,13 +246,13 @@ class ClusterFeedbackService {
|
|||
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
|
||||
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
|
||||
dev.log(
|
||||
'existing clusters for ${p.data.name} are $personClusters',
|
||||
'${p.data.name} has ${personClusters.length} existing clusters',
|
||||
name: "ClusterFeedbackService",
|
||||
);
|
||||
|
||||
// Get and update the cluster summary to get the avg (centroid) and count
|
||||
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
|
||||
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
|
||||
final Map<int, Vector> clusterAvg = await _getUpdateClusterAvg(
|
||||
allClusterIdsToCountMap,
|
||||
ignoredClusters,
|
||||
);
|
||||
|
@ -466,19 +467,19 @@ class ClusterFeedbackService {
|
|||
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
|
||||
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
|
||||
dev.log(
|
||||
'existing clusters for ${p.data.name} are $personClusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||
'${p.data.name} has ${personClusters.length} existing clusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||
name: "getSuggestionsUsingMedian",
|
||||
);
|
||||
|
||||
// First only do a simple check on the big clusters, if the person does not have small clusters yet
|
||||
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
|
||||
final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
|
||||
final smallestPersonClusterSize = personClusters
|
||||
.map((clusterID) => allClusterIdsToCountMap[clusterID] ?? 0)
|
||||
.reduce((value, element) => min(value, element));
|
||||
final checkSizes = [kMinimumClusterSizeSearchResult, 20, 10, 5, 1];
|
||||
for (final minimumSize in checkSizes.toSet()) {
|
||||
if (smallestPersonClusterSize >= minimumSize) {
|
||||
final Map<int, List<double>> clusterAvgBigClusters =
|
||||
final Map<int, Vector> clusterAvgBigClusters =
|
||||
await _getUpdateClusterAvg(
|
||||
allClusterIdsToCountMap,
|
||||
ignoredClusters,
|
||||
|
@ -487,6 +488,7 @@ class ClusterFeedbackService {
|
|||
dev.log(
|
||||
'computed avg for ${clusterAvgBigClusters.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||
);
|
||||
w?.log('Calculate avg for min size $minimumSize');
|
||||
final List<(int, double)> suggestionsMeanBigClusters =
|
||||
_calcSuggestionsMean(
|
||||
clusterAvgBigClusters,
|
||||
|
@ -494,6 +496,7 @@ class ClusterFeedbackService {
|
|||
ignoredClusters,
|
||||
goodMeanDistance,
|
||||
);
|
||||
w?.log('Calculate suggestions using mean for min size $minimumSize');
|
||||
if (suggestionsMeanBigClusters.isNotEmpty) {
|
||||
return suggestionsMeanBigClusters
|
||||
.map((e) => (e.$1, e.$2, true))
|
||||
|
@ -501,9 +504,10 @@ class ClusterFeedbackService {
|
|||
}
|
||||
}
|
||||
}
|
||||
w?.reset();
|
||||
|
||||
// Get and update the cluster summary to get the avg (centroid) and count
|
||||
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
|
||||
final Map<int, Vector> clusterAvg = await _getUpdateClusterAvg(
|
||||
allClusterIdsToCountMap,
|
||||
ignoredClusters,
|
||||
);
|
||||
|
@ -547,7 +551,7 @@ class ClusterFeedbackService {
|
|||
"Found potential suggestions from loose mean for median test: $otherClusterIdsCandidates",
|
||||
);
|
||||
|
||||
watch.logAndReset("Starting median test");
|
||||
w?.logAndReset("Starting median test");
|
||||
// Take the embeddings from the person's clusters in one big list and sample from it
|
||||
final List<Uint8List> personEmbeddingsProto = [];
|
||||
for (final clusterID in personClusters) {
|
||||
|
@ -600,7 +604,7 @@ class ClusterFeedbackService {
|
|||
}
|
||||
}
|
||||
}
|
||||
watch.log("Finished median test");
|
||||
w?.log("Finished median test");
|
||||
if (suggestionsMedian.isEmpty) {
|
||||
_logger.info("No suggestions found using median");
|
||||
return [];
|
||||
|
@ -632,7 +636,7 @@ class ClusterFeedbackService {
|
|||
return finalSuggestionsMedian;
|
||||
}
|
||||
|
||||
Future<Map<int, List<double>>> _getUpdateClusterAvg(
|
||||
Future<Map<int, Vector>> _getUpdateClusterAvg(
|
||||
Map<int, int> allClusterIdsToCountMap,
|
||||
Set<int> ignoredClusters, {
|
||||
int minClusterSize = 1,
|
||||
|
@ -649,7 +653,7 @@ class ClusterFeedbackService {
|
|||
await faceMlDb.getAllClusterSummary(minClusterSize);
|
||||
final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
|
||||
|
||||
final Map<int, List<double>> clusterAvg = {};
|
||||
final Map<int, Vector> clusterAvg = {};
|
||||
|
||||
dev.log(
|
||||
'getUpdateClusterAvg database call for getAllClusterSummary took ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||
|
@ -666,7 +670,9 @@ class ClusterFeedbackService {
|
|||
}
|
||||
if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
|
||||
allClusterIds.remove(id);
|
||||
clusterAvg[id] = EVector.fromBuffer(clusterToSummary[id]!.$1).values;
|
||||
clusterAvg[id] = Vector.fromList(
|
||||
EVector.fromBuffer(clusterToSummary[id]!.$1).values,
|
||||
dtype: DType.float32,);
|
||||
alreadyUpdatedClustersCnt++;
|
||||
}
|
||||
if (allClusterIdsToCountMap[id]! < minClusterSize) {
|
||||
|
@ -731,19 +737,15 @@ class ClusterFeedbackService {
|
|||
);
|
||||
|
||||
for (final clusterID in clusterEmbeddings.keys) {
|
||||
late List<double> avg;
|
||||
final Iterable<Uint8List> embedings = clusterEmbeddings[clusterID]!;
|
||||
final List<double> sum = List.filled(192, 0);
|
||||
for (final embedding in embedings) {
|
||||
final data = EVector.fromBuffer(embedding).values;
|
||||
for (int i = 0; i < sum.length; i++) {
|
||||
sum[i] += data[i];
|
||||
}
|
||||
}
|
||||
avg = sum.map((e) => e / embedings.length).toList();
|
||||
final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer();
|
||||
final Iterable<Uint8List> embeddings = clusterEmbeddings[clusterID]!;
|
||||
final Iterable<Vector> vectors = embeddings.map((e) => Vector.fromList(
|
||||
EVector.fromBuffer(e).values,
|
||||
dtype: DType.float32,
|
||||
),);
|
||||
final avg = vectors.reduce((a, b) => a + b) / vectors.length;
|
||||
final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
|
||||
updatesForClusterSummary[clusterID] =
|
||||
(avgEmbeedingBuffer, embedings.length);
|
||||
(avgEmbeddingBuffer, embeddings.length);
|
||||
// store the intermediate updates
|
||||
indexedInCurrentRun++;
|
||||
if (updatesForClusterSummary.length > 100) {
|
||||
|
@ -770,7 +772,7 @@ class ClusterFeedbackService {
|
|||
|
||||
/// Returns a map of person's clusterID to map of closest clusterID to with disstance
|
||||
List<(int, double)> _calcSuggestionsMean(
|
||||
Map<int, List<double>> clusterAvg,
|
||||
Map<int, Vector> clusterAvg,
|
||||
Set<int> personClusters,
|
||||
Set<int> ignoredClusters,
|
||||
double maxClusterDistance, {
|
||||
|
@ -779,23 +781,14 @@ class ClusterFeedbackService {
|
|||
final Map<int, List<(int, double)>> suggestions = {};
|
||||
int suggestionCount = 0;
|
||||
final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
|
||||
final clusterAvgVectors = clusterAvg.map(
|
||||
(key, value) => MapEntry(
|
||||
key,
|
||||
Vector.fromList(
|
||||
value,
|
||||
dtype: DType.float32,
|
||||
),
|
||||
),
|
||||
);
|
||||
w?.log('converted avg to vectors for ${clusterAvg.length} averages');
|
||||
for (final otherClusterID in clusterAvgVectors.keys) {
|
||||
for (final otherClusterID in clusterAvg.keys) {
|
||||
// ignore the cluster that belong to the person or is ignored
|
||||
if (personClusters.contains(otherClusterID) ||
|
||||
ignoredClusters.contains(otherClusterID)) {
|
||||
continue;
|
||||
}
|
||||
final otherAvg = clusterAvgVectors[otherClusterID]!;
|
||||
final Vector otherAvg = clusterAvg[otherClusterID]!;
|
||||
int? nearestPersonCluster;
|
||||
double? minDistance;
|
||||
for (final personCluster in personClusters) {
|
||||
|
@ -803,7 +796,7 @@ class ClusterFeedbackService {
|
|||
_logger.info('no avg for cluster $personCluster');
|
||||
continue;
|
||||
}
|
||||
final avg = clusterAvgVectors[personCluster]!;
|
||||
final Vector avg = clusterAvg[personCluster]!;
|
||||
final distance = 1 - avg.dot(otherAvg);
|
||||
if (distance < maxClusterDistance) {
|
||||
if (minDistance == null || distance < minDistance) {
|
||||
|
|
Loading…
Add table
Reference in a new issue