[mob] Add more validation for clustering
This commit is contained in:
parent
723253a12c
commit
0c72fd2a69
2 changed files with 86 additions and 0 deletions
|
@ -395,6 +395,12 @@ class FaceClustering {
|
|||
if (distance < closestDistance) {
|
||||
closestDistance = distance;
|
||||
closestIdx = j;
|
||||
// if (distance < distanceThreshold) {
|
||||
// if (sortedFaceInfos[j].faceID.startsWith("14914702") ||
|
||||
// sortedFaceInfos[j].faceID.startsWith("15488756")) {
|
||||
// log('[XXX] faceIDs: ${sortedFaceInfos[j].faceID} and ${sortedFaceInfos[i].faceID} with distance $distance');
|
||||
// }
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -408,10 +414,22 @@ class FaceClustering {
|
|||
sortedFaceInfos[closestIdx].clusterId = clusterID;
|
||||
newFaceIdToCluster[sortedFaceInfos[closestIdx].faceID] = clusterID;
|
||||
}
|
||||
// if (sortedFaceInfos[i].faceID.startsWith("14914702") ||
|
||||
// sortedFaceInfos[i].faceID.startsWith("15488756")) {
|
||||
// log(
|
||||
// "[XXX] [ClusterIsolate] ${DateTime.now()} Found similar face ${sortedFaceInfos[i].faceID} to ${sortedFaceInfos[closestIdx].faceID} with distance $closestDistance",
|
||||
// );
|
||||
// }
|
||||
sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId;
|
||||
newFaceIdToCluster[sortedFaceInfos[i].faceID] =
|
||||
sortedFaceInfos[closestIdx].clusterId!;
|
||||
} else {
|
||||
// if (sortedFaceInfos[i].faceID.startsWith("14914702") ||
|
||||
// sortedFaceInfos[i].faceID.startsWith("15488756")) {
|
||||
// log(
|
||||
// "[XXX] [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID for face ${sortedFaceInfos[i].faceID}",
|
||||
// );
|
||||
// }
|
||||
clusterID++;
|
||||
sortedFaceInfos[i].clusterId = clusterID;
|
||||
newFaceIdToCluster[sortedFaceInfos[i].faceID] = clusterID;
|
||||
|
|
|
@ -3,14 +3,18 @@ import 'dart:async';
|
|||
import "package:flutter/foundation.dart";
|
||||
import 'package:flutter/material.dart';
|
||||
import 'package:logging/logging.dart';
|
||||
import "package:ml_linalg/linalg.dart";
|
||||
import 'package:photos/core/configuration.dart';
|
||||
import 'package:photos/core/event_bus.dart';
|
||||
import "package:photos/db/files_db.dart";
|
||||
import 'package:photos/events/subscription_purchased_event.dart';
|
||||
import "package:photos/face/db.dart";
|
||||
import "package:photos/face/model/person.dart";
|
||||
import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
||||
import 'package:photos/models/gallery_type.dart';
|
||||
import 'package:photos/models/selected_files.dart';
|
||||
import 'package:photos/services/collections_service.dart';
|
||||
import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
|
||||
import 'package:photos/ui/actions/collection/collection_sharing_actions.dart';
|
||||
|
@ -39,6 +43,7 @@ class ClusterAppBar extends StatefulWidget {
|
|||
enum ClusterPopupAction {
|
||||
setCover,
|
||||
breakupCluster,
|
||||
validateCluster,
|
||||
hide,
|
||||
}
|
||||
|
||||
|
@ -127,6 +132,18 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
|
|||
],
|
||||
),
|
||||
),
|
||||
const PopupMenuItem(
|
||||
value: ClusterPopupAction.validateCluster,
|
||||
child: Row(
|
||||
children: [
|
||||
Icon(Icons.search_off_outlined),
|
||||
Padding(
|
||||
padding: EdgeInsets.all(8),
|
||||
),
|
||||
Text('Validate cluster'),
|
||||
],
|
||||
),
|
||||
),
|
||||
// PopupMenuItem(
|
||||
// value: ClusterPopupAction.hide,
|
||||
// child: Row(
|
||||
|
@ -152,6 +169,8 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
|
|||
if (value == ClusterPopupAction.breakupCluster) {
|
||||
// ignore: unawaited_futures
|
||||
await _breakUpCluster(context);
|
||||
} else if (value == ClusterPopupAction.validateCluster) {
|
||||
await _validateCluster(context);
|
||||
}
|
||||
// else if (value == ClusterPopupAction.setCover) {
|
||||
// await setCoverPhoto(context);
|
||||
|
@ -166,6 +185,55 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
|
|||
return actions;
|
||||
}
|
||||
|
||||
Future<void> _validateCluster(BuildContext context) async {
|
||||
_logger.info('_validateCluster called');
|
||||
final faceMlDb = FaceMLDataDB.instance;
|
||||
|
||||
final faceIDs = await faceMlDb.getFaceIDsForCluster(widget.clusterID);
|
||||
final fileIDs = faceIDs.map((e) => getFileIdFromFaceId(e)).toList();
|
||||
|
||||
final embeddingsBlobs = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs);
|
||||
embeddingsBlobs.removeWhere((key, value) => !faceIDs.contains(key));
|
||||
final embeddings = embeddingsBlobs
|
||||
.map((key, value) => MapEntry(key, EVector.fromBuffer(value).values));
|
||||
|
||||
for (final MapEntry<String, List<double>> embedding in embeddings.entries) {
|
||||
double closestDistance = double.infinity;
|
||||
double closestDistance32 = double.infinity;
|
||||
double closestDistance64 = double.infinity;
|
||||
String? closestFaceID;
|
||||
for (final MapEntry<String, List<double>> otherEmbedding
|
||||
in embeddings.entries) {
|
||||
if (embedding.key == otherEmbedding.key) {
|
||||
continue;
|
||||
}
|
||||
final distance64 = 1.0 -
|
||||
Vector.fromList(embedding.value, dtype: DType.float64).dot(
|
||||
Vector.fromList(otherEmbedding.value, dtype: DType.float64),
|
||||
);
|
||||
final distance32 = 1.0 -
|
||||
Vector.fromList(embedding.value, dtype: DType.float32).dot(
|
||||
Vector.fromList(otherEmbedding.value, dtype: DType.float32),
|
||||
);
|
||||
final distance = cosineDistForNormVectors(
|
||||
embedding.value,
|
||||
otherEmbedding.value,
|
||||
);
|
||||
if (distance < closestDistance) {
|
||||
closestDistance = distance;
|
||||
closestDistance32 = distance32;
|
||||
closestDistance64 = distance64;
|
||||
closestFaceID = otherEmbedding.key;
|
||||
}
|
||||
}
|
||||
if (closestDistance > 0.3) {
|
||||
_logger.severe(
|
||||
"Face ${embedding.key} is similar to $closestFaceID with distance $closestDistance, and float32 distance $closestDistance32, and float64 distance $closestDistance64",
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Future<void> _breakUpCluster(BuildContext context) async {
|
||||
final newClusterIDToFaceIDs =
|
||||
await ClusterFeedbackService.instance.breakUpCluster(widget.clusterID);
|
||||
|
|
Loading…
Add table
Reference in a new issue