|
@@ -868,7 +868,8 @@ class ClusterFeedbackService {
|
|
),
|
|
),
|
|
);
|
|
);
|
|
final avg = vectors.reduce((a, b) => a + b) / vectors.length;
|
|
final avg = vectors.reduce((a, b) => a + b) / vectors.length;
|
|
- final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
|
|
|
|
|
|
+ final avgNormalized = avg / avg.norm();
|
|
|
|
+ final avgEmbeddingBuffer = EVector(values: avgNormalized).writeToBuffer();
|
|
updatesForClusterSummary[clusterID] =
|
|
updatesForClusterSummary[clusterID] =
|
|
(avgEmbeddingBuffer, embeddings.length);
|
|
(avgEmbeddingBuffer, embeddings.length);
|
|
// store the intermediate updates
|
|
// store the intermediate updates
|
|
@@ -882,7 +883,7 @@ class ClusterFeedbackService {
|
|
);
|
|
);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
- clusterAvg[clusterID] = avg;
|
|
|
|
|
|
+ clusterAvg[clusterID] = avgNormalized;
|
|
}
|
|
}
|
|
if (updatesForClusterSummary.isNotEmpty) {
|
|
if (updatesForClusterSummary.isNotEmpty) {
|
|
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
|
|
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
|