[mob] compute suggestion in small batches
parent e2ed836b16
commit faa07a0704
1 changed file with 53 additions and 31 deletions
@@ -367,11 +367,13 @@ class ClusterFeedbackService {
   Future<Map<int, List<double>>> _getUpdateClusterAvg(
     Map<int, int> allClusterIdsToCountMap,
-    Set<int> ignoredClusters,
-  ) async {
+    Set<int> ignoredClusters, {
+    int minClusterSize = 1,
+    int maxClusterInCurrentRun = 500,
+  }) async {
     final faceMlDb = FaceMLDataDB.instance;
     _logger.info(
-      'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters',
+      'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters, minClusterSize $minClusterSize, maxClusterInCurrentRun $maxClusterInCurrentRun',
     );
 
     final Map<int, (Uint8List, int)> clusterToSummary =
@@ -380,42 +382,61 @@ class ClusterFeedbackService {
     final Map<int, List<double>> clusterAvg = {};
 
-    final allClusterIds = allClusterIdsToCountMap.keys;
-    for (final clusterID in allClusterIds) {
-      if (ignoredClusters.contains(clusterID)) {
-        continue;
+    final allClusterIds = allClusterIdsToCountMap.keys.toSet();
+    int ignoredClustersCnt = 0, alreadyUpdatedClustersCnt = 0;
+    int smallerClustersCnt = 0;
+    for (final id in allClusterIdsToCountMap.keys) {
+      if (ignoredClusters.contains(id)) {
+        allClusterIds.remove(id);
+        ignoredClustersCnt++;
       }
-      if (allClusterIdsToCountMap[clusterID]! < 2) {
-        continue;
+      if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
+        allClusterIds.remove(id);
+        clusterAvg[id] = EVector.fromBuffer(clusterToSummary[id]!.$1).values;
+        alreadyUpdatedClustersCnt++;
       }
+      if (allClusterIdsToCountMap[id]! < minClusterSize) {
+        allClusterIds.remove(id);
+        smallerClustersCnt++;
+      }
+    }
+    _logger.info(
+      'Ignored $ignoredClustersCnt clusters, already updated $alreadyUpdatedClustersCnt clusters, $smallerClustersCnt clusters are smaller than $minClusterSize',
+    );
+    // get clusterIDs sorted by count in descending order
+    final sortedClusterIDs = allClusterIds.toList();
+    sortedClusterIDs.sort(
+      (a, b) =>
+          allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!),
+    );
+    int indexedInCurrentRun = 0;
 
-      late List<double> avg;
-      if (clusterToSummary[clusterID]?.$2 ==
-          allClusterIdsToCountMap[clusterID]) {
-        avg = EVector.fromBuffer(clusterToSummary[clusterID]!.$1).values;
-      } else {
-        final Iterable<Uint8List> embedings =
-            await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
-        final List<double> sum = List.filled(192, 0);
-        for (final embedding in embedings) {
-          final data = EVector.fromBuffer(embedding).values;
-          for (int i = 0; i < sum.length; i++) {
-            sum[i] += data[i];
-          }
-        }
-        avg = sum.map((e) => e / embedings.length).toList();
-        final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer();
-        updatesForClusterSummary[clusterID] =
-            (avgEmbeedingBuffer, embedings.length);
+    for (final clusterID in sortedClusterIDs) {
+      if (maxClusterInCurrentRun-- <= 0) {
+        break;
       }
+      indexedInCurrentRun++;
+      late List<double> avg;
+      final Iterable<Uint8List> embedings =
+          await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
+      final List<double> sum = List.filled(192, 0);
+      for (final embedding in embedings) {
+        final data = EVector.fromBuffer(embedding).values;
+        for (int i = 0; i < sum.length; i++) {
+          sum[i] += data[i];
+        }
+      }
+      avg = sum.map((e) => e / embedings.length).toList();
+      final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer();
+      updatesForClusterSummary[clusterID] =
+          (avgEmbeedingBuffer, embedings.length);
       // store the intermediate updates
       if (updatesForClusterSummary.length > 100) {
         await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
         updatesForClusterSummary.clear();
         if (kDebugMode) {
           _logger.info(
-            'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters',
-          );
+              'getUpdateClusterAvg $indexedInCurrentRun clusters in current one');
         }
       }
       clusterAvg[clusterID] = avg;
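The reworked loop follows a simple batching pattern: filter and sort the work once, cap how many items a single run handles, and flush intermediate results in chunks instead of only at the end. A standalone sketch of that pattern with plain types; every name below is illustrative and not taken from the codebase, and persist() stands in for a database write such as clusterSummaryUpdate:

Future<Map<int, double>> averageInBatches(
  Map<int, List<double>> valuesPerId, {
  int maxPerRun = 500,
  int flushEvery = 100,
}) async {
  // Largest groups first, so the most useful averages land early.
  final ids = valuesPerId.keys.toList()
    ..sort((a, b) => valuesPerId[b]!.length.compareTo(valuesPerId[a]!.length));

  final results = <int, double>{};
  final pending = <int, double>{};
  for (final id in ids.take(maxPerRun)) {
    final values = valuesPerId[id]!;
    if (values.isEmpty) continue;
    final avg = values.reduce((a, b) => a + b) / values.length;
    results[id] = avg;
    pending[id] = avg;
    // Flush in chunks so progress survives an interrupted run.
    if (pending.length >= flushEvery) {
      await persist(pending);
      pending.clear();
    }
  }
  if (pending.isNotEmpty) await persist(pending);
  return results;
}

// Hypothetical stand-in for persisting a batch of computed averages.
Future<void> persist(Map<int, double> batch) async {}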
@@ -549,8 +570,9 @@ class ClusterFeedbackService {
       );
     }
     suggestion.$4.sort((b, a) {
-      final double distanceA = fileIdToDistanceMap[a.uploadedFileID!];
-      final double distanceB = fileIdToDistanceMap[b.uploadedFileID!];
+      //todo: review with @laurens, added this to avoid null safety issue
+      final double distanceA = fileIdToDistanceMap[a.uploadedFileID!] ?? -1;
+      final double distanceB = fileIdToDistanceMap[b.uploadedFileID!] ?? -1;
       return distanceA.compareTo(distanceB);
     });
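The ?? -1 fallback keeps the comparator from throwing when a file ID has no entry in the distance map: such files get a sentinel distance and are grouped together at one end of the ordering. A small standalone example of the pattern, with illustrative names and values and a plain ascending sort rather than the reversed comparator above:

void main() {
  final distanceByFileId = <int, double>{1: 0.42, 2: 0.13};
  final fileIds = [1, 2, 3]; // 3 has no recorded distance

  // Entries missing from the map fall back to -1 instead of
  // triggering a null-check error inside the comparator.
  fileIds.sort((a, b) {
    final double distanceA = distanceByFileId[a] ?? -1;
    final double distanceB = distanceByFileId[b] ?? -1;
    return distanceA.compareTo(distanceB);
  });

  print(fileIds); // [3, 2, 1]
}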