@@ -3,6 +3,7 @@ import "dart:math" show Random, min;
 
 import "package:flutter/foundation.dart";
 import "package:logging/logging.dart";
+import "package:ml_linalg/linalg.dart";
 import "package:photos/core/event_bus.dart";
 import "package:photos/db/files_db.dart";
 // import "package:photos/events/files_updated_event.dart";
@@ -245,13 +246,13 @@ class ClusterFeedbackService {
     final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
     final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
     dev.log(
-      'existing clusters for ${p.data.name} are $personClusters',
+      '${p.data.name} has ${personClusters.length} existing clusters',
       name: "ClusterFeedbackService",
     );
 
     // Get and update the cluster summary to get the avg (centroid) and count
     final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
-    final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
+    final Map<int, Vector> clusterAvg = await _getUpdateClusterAvg(
       allClusterIdsToCountMap,
       ignoredClusters,
     );
@@ -466,19 +467,19 @@ class ClusterFeedbackService {
     final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
     final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
     dev.log(
-      'existing clusters for ${p.data.name} are $personClusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms',
+      '${p.data.name} has ${personClusters.length} existing clusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms',
       name: "getSuggestionsUsingMedian",
     );
 
     // First only do a simple check on the big clusters, if the person does not have small clusters yet
-    final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
+    final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
     final smallestPersonClusterSize = personClusters
         .map((clusterID) => allClusterIdsToCountMap[clusterID] ?? 0)
         .reduce((value, element) => min(value, element));
     final checkSizes = [kMinimumClusterSizeSearchResult, 20, 10, 5, 1];
     for (final minimumSize in checkSizes.toSet()) {
       if (smallestPersonClusterSize >= minimumSize) {
-        final Map<int, List<double>> clusterAvgBigClusters =
+        final Map<int, Vector> clusterAvgBigClusters =
             await _getUpdateClusterAvg(
           allClusterIdsToCountMap,
           ignoredClusters,
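
The `w` introduced in the hunk above is only constructed in debug builds, and every later call on it is null-aware, so release builds skip the timing work entirely. A minimal sketch of that pattern, using a hypothetical `SimpleWatch` stand-in for the app's `EnteWatch`:

```dart
import 'package:flutter/foundation.dart' show debugPrint, kDebugMode;

/// Hypothetical stand-in for EnteWatch, only to illustrate the pattern.
class SimpleWatch {
  SimpleWatch(this.name);
  final String name;
  final Stopwatch _sw = Stopwatch();

  void start() => _sw.start();
  void log(String msg) =>
      debugPrint('[$name] $msg (${_sw.elapsedMilliseconds} ms)');
  void logAndReset(String msg) {
    log(msg);
    _sw.reset();
  }

  void reset() => _sw.reset();
}

void timedWork() {
  // Created only in debug mode; all later calls are null-aware, so release
  // builds pay no timing cost.
  final w = (kDebugMode ? SimpleWatch('getSuggestions') : null)?..start();
  // ... expensive work ...
  w?.log('finished expensive work');
}
```
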
@@ -487,6 +488,7 @@ class ClusterFeedbackService {
         dev.log(
           'computed avg for ${clusterAvgBigClusters.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
         );
+        w?.log('Calculate avg for min size $minimumSize');
         final List<(int, double)> suggestionsMeanBigClusters =
             _calcSuggestionsMean(
           clusterAvgBigClusters,
@@ -494,6 +496,7 @@ class ClusterFeedbackService {
           ignoredClusters,
           goodMeanDistance,
         );
+        w?.log('Calculate suggestions using mean for min size $minimumSize');
         if (suggestionsMeanBigClusters.isNotEmpty) {
           return suggestionsMeanBigClusters
               .map((e) => (e.$1, e.$2, true))
@@ -501,9 +504,10 @@ class ClusterFeedbackService {
         }
       }
     }
+    w?.reset();
 
     // Get and update the cluster summary to get the avg (centroid) and count
-    final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
+    final Map<int, Vector> clusterAvg = await _getUpdateClusterAvg(
       allClusterIdsToCountMap,
       ignoredClusters,
     );
@@ -547,7 +551,7 @@ class ClusterFeedbackService {
       "Found potential suggestions from loose mean for median test: $otherClusterIdsCandidates",
     );
 
-    watch.logAndReset("Starting median test");
+    w?.logAndReset("Starting median test");
     // Take the embeddings from the person's clusters in one big list and sample from it
     final List<Uint8List> personEmbeddingsProto = [];
     for (final clusterID in personClusters) {
@@ -600,7 +604,7 @@ class ClusterFeedbackService {
         }
       }
     }
-    watch.log("Finished median test");
+    w?.log("Finished median test");
     if (suggestionsMedian.isEmpty) {
       _logger.info("No suggestions found using median");
       return [];
@@ -632,7 +636,7 @@ class ClusterFeedbackService {
     return finalSuggestionsMedian;
   }
 
-  Future<Map<int, List<double>>> _getUpdateClusterAvg(
+  Future<Map<int, Vector>> _getUpdateClusterAvg(
     Map<int, int> allClusterIdsToCountMap,
     Set<int> ignoredClusters, {
     int minClusterSize = 1,
@@ -649,7 +653,7 @@ class ClusterFeedbackService {
         await faceMlDb.getAllClusterSummary(minClusterSize);
     final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
 
-    final Map<int, List<double>> clusterAvg = {};
+    final Map<int, Vector> clusterAvg = {};
 
     dev.log(
       'getUpdateClusterAvg database call for getAllClusterSummary took ${DateTime.now().difference(startTime).inMilliseconds} ms',
@@ -666,7 +670,9 @@ class ClusterFeedbackService {
       }
       if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
         allClusterIds.remove(id);
-        clusterAvg[id] = EVector.fromBuffer(clusterToSummary[id]!.$1).values;
+        clusterAvg[id] = Vector.fromList(
+          EVector.fromBuffer(clusterToSummary[id]!.$1).values,
+          dtype: DType.float32,);
         alreadyUpdatedClustersCnt++;
       }
       if (allClusterIdsToCountMap[id]! < minClusterSize) {
@@ -731,19 +737,15 @@ class ClusterFeedbackService {
     );
 
     for (final clusterID in clusterEmbeddings.keys) {
-      late List<double> avg;
-      final Iterable<Uint8List> embedings = clusterEmbeddings[clusterID]!;
-      final List<double> sum = List.filled(192, 0);
-      for (final embedding in embedings) {
-        final data = EVector.fromBuffer(embedding).values;
-        for (int i = 0; i < sum.length; i++) {
-          sum[i] += data[i];
-        }
-      }
-      avg = sum.map((e) => e / embedings.length).toList();
-      final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer();
+      final Iterable<Uint8List> embeddings = clusterEmbeddings[clusterID]!;
+      final Iterable<Vector> vectors = embeddings.map((e) => Vector.fromList(
+        EVector.fromBuffer(e).values,
+        dtype: DType.float32,
+      ),);
+      final avg = vectors.reduce((a, b) => a + b) / vectors.length;
+      final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
       updatesForClusterSummary[clusterID] =
-          (avgEmbeedingBuffer, embedings.length);
+          (avgEmbeddingBuffer, embeddings.length);
       // store the intermediate updates
       indexedInCurrentRun++;
       if (updatesForClusterSummary.length > 100) {
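
The hunk above swaps the hand-rolled 192-element summation loop for ml_linalg vector arithmetic. A minimal, self-contained sketch of the same centroid computation, with `decode` standing in for the app's `EVector.fromBuffer(buffer).values` deserialization (assumes a non-empty cluster):

```dart
import 'dart:typed_data';

import 'package:ml_linalg/linalg.dart';

/// Averages one cluster's serialized embeddings into a centroid vector.
/// [decode] is a hypothetical stand-in for the protobuf deserialization.
Vector averageEmbeddings(
  List<Uint8List> buffers,
  List<double> Function(Uint8List buffer) decode,
) {
  // Deserialize each embedding into a float32 SIMD-backed vector.
  final vectors = buffers.map(
    (buffer) => Vector.fromList(decode(buffer), dtype: DType.float32),
  );
  // Element-wise sum divided by the count gives the cluster centroid.
  return vectors.reduce((a, b) => a + b) / vectors.length;
}
```
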
@@ -770,7 +772,7 @@ class ClusterFeedbackService {
 
   /// Returns a map of person's clusterID to map of closest clusterID to with disstance
   List<(int, double)> _calcSuggestionsMean(
-    Map<int, List<double>> clusterAvg,
+    Map<int, Vector> clusterAvg,
     Set<int> personClusters,
     Set<int> ignoredClusters,
     double maxClusterDistance, {
@@ -779,23 +781,14 @@ class ClusterFeedbackService {
     final Map<int, List<(int, double)>> suggestions = {};
     int suggestionCount = 0;
     final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
-    final clusterAvgVectors = clusterAvg.map(
-      (key, value) => MapEntry(
-        key,
-        Vector.fromList(
-          value,
-          dtype: DType.float32,
-        ),
-      ),
-    );
     w?.log('converted avg to vectors for ${clusterAvg.length} averages');
-    for (final otherClusterID in clusterAvgVectors.keys) {
+    for (final otherClusterID in clusterAvg.keys) {
       // ignore the cluster that belong to the person or is ignored
       if (personClusters.contains(otherClusterID) ||
           ignoredClusters.contains(otherClusterID)) {
         continue;
       }
-      final otherAvg = clusterAvgVectors[otherClusterID]!;
+      final Vector otherAvg = clusterAvg[otherClusterID]!;
       int? nearestPersonCluster;
       double? minDistance;
       for (final personCluster in personClusters) {
@@ -803,7 +796,7 @@ class ClusterFeedbackService {
           _logger.info('no avg for cluster $personCluster');
           continue;
         }
-        final avg = clusterAvgVectors[personCluster]!;
+        final Vector avg = clusterAvg[personCluster]!;
         final distance = 1 - avg.dot(otherAvg);
         if (distance < maxClusterDistance) {
           if (minDistance == null || distance < minDistance) {
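
For reference, the distance check kept in the hunk above treats `1 - dot` of two centroids as cosine distance, which assumes the embeddings (and hence their averages) are roughly L2-normalized. A hedged sketch of the nearest-cluster lookup that `_calcSuggestionsMean` performs, with `closestPersonCluster` as a hypothetical helper name:

```dart
import 'package:ml_linalg/linalg.dart';

/// Finds the person's cluster whose centroid is closest to [candidateAvg]
/// by cosine distance (1 - dot product of assumed-normalized centroids).
/// Returns null when no centroid is available.
(int, double)? closestPersonCluster(
  Vector candidateAvg,
  Map<int, Vector> clusterAvg,
  Set<int> personClusters,
) {
  int? nearest;
  double? minDistance;
  for (final personCluster in personClusters) {
    final avg = clusterAvg[personCluster];
    if (avg == null) continue; // no centroid computed for this cluster
    final distance = 1 - avg.dot(candidateAvg);
    if (minDistance == null || distance < minDistance) {
      minDistance = distance;
      nearest = personCluster;
    }
  }
  return nearest == null ? null : (nearest, minDistance!);
}
```
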