|
@@ -3,14 +3,18 @@ import 'dart:async';
|
|
|
import "package:flutter/foundation.dart";
|
|
|
import 'package:flutter/material.dart';
|
|
|
import 'package:logging/logging.dart';
|
|
|
+import "package:ml_linalg/linalg.dart";
|
|
|
import 'package:photos/core/configuration.dart';
|
|
|
import 'package:photos/core/event_bus.dart';
|
|
|
import "package:photos/db/files_db.dart";
|
|
|
import 'package:photos/events/subscription_purchased_event.dart';
|
|
|
+import "package:photos/face/db.dart";
|
|
|
import "package:photos/face/model/person.dart";
|
|
|
+import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
|
|
import 'package:photos/models/gallery_type.dart';
|
|
|
import 'package:photos/models/selected_files.dart';
|
|
|
import 'package:photos/services/collections_service.dart';
|
|
|
+import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
|
|
|
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
|
|
import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
|
|
|
import 'package:photos/ui/actions/collection/collection_sharing_actions.dart';
|
|
@@ -39,6 +43,7 @@ class ClusterAppBar extends StatefulWidget {
|
|
|
enum ClusterPopupAction {
|
|
|
setCover,
|
|
|
breakupCluster,
|
|
|
+ validateCluster,
|
|
|
hide,
|
|
|
}
|
|
|
|
|
@@ -127,6 +132,18 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
|
|
|
],
|
|
|
),
|
|
|
),
|
|
|
+ const PopupMenuItem(
|
|
|
+ value: ClusterPopupAction.validateCluster,
|
|
|
+ child: Row(
|
|
|
+ children: [
|
|
|
+ Icon(Icons.search_off_outlined),
|
|
|
+ Padding(
|
|
|
+ padding: EdgeInsets.all(8),
|
|
|
+ ),
|
|
|
+ Text('Validate cluster'),
|
|
|
+ ],
|
|
|
+ ),
|
|
|
+ ),
|
|
|
// PopupMenuItem(
|
|
|
// value: ClusterPopupAction.hide,
|
|
|
// child: Row(
|
|
@@ -152,6 +169,8 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
|
|
|
if (value == ClusterPopupAction.breakupCluster) {
|
|
|
// ignore: unawaited_futures
|
|
|
await _breakUpCluster(context);
|
|
|
+ } else if (value == ClusterPopupAction.validateCluster) {
|
|
|
+ await _validateCluster(context);
|
|
|
}
|
|
|
// else if (value == ClusterPopupAction.setCover) {
|
|
|
// await setCoverPhoto(context);
|
|
@@ -166,6 +185,55 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
|
|
|
return actions;
|
|
|
}
|
|
|
|
|
|
+ Future<void> _validateCluster(BuildContext context) async {
|
|
|
+ _logger.info('_validateCluster called');
|
|
|
+ final faceMlDb = FaceMLDataDB.instance;
|
|
|
+
|
|
|
+ final faceIDs = await faceMlDb.getFaceIDsForCluster(widget.clusterID);
|
|
|
+ final fileIDs = faceIDs.map((e) => getFileIdFromFaceId(e)).toList();
|
|
|
+
|
|
|
+ final embeddingsBlobs = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs);
|
|
|
+ embeddingsBlobs.removeWhere((key, value) => !faceIDs.contains(key));
|
|
|
+ final embeddings = embeddingsBlobs
|
|
|
+ .map((key, value) => MapEntry(key, EVector.fromBuffer(value).values));
|
|
|
+
|
|
|
+ for (final MapEntry<String, List<double>> embedding in embeddings.entries) {
|
|
|
+ double closestDistance = double.infinity;
|
|
|
+ double closestDistance32 = double.infinity;
|
|
|
+ double closestDistance64 = double.infinity;
|
|
|
+ String? closestFaceID;
|
|
|
+ for (final MapEntry<String, List<double>> otherEmbedding
|
|
|
+ in embeddings.entries) {
|
|
|
+ if (embedding.key == otherEmbedding.key) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ final distance64 = 1.0 -
|
|
|
+ Vector.fromList(embedding.value, dtype: DType.float64).dot(
|
|
|
+ Vector.fromList(otherEmbedding.value, dtype: DType.float64),
|
|
|
+ );
|
|
|
+ final distance32 = 1.0 -
|
|
|
+ Vector.fromList(embedding.value, dtype: DType.float32).dot(
|
|
|
+ Vector.fromList(otherEmbedding.value, dtype: DType.float32),
|
|
|
+ );
|
|
|
+ final distance = cosineDistForNormVectors(
|
|
|
+ embedding.value,
|
|
|
+ otherEmbedding.value,
|
|
|
+ );
|
|
|
+ if (distance < closestDistance) {
|
|
|
+ closestDistance = distance;
|
|
|
+ closestDistance32 = distance32;
|
|
|
+ closestDistance64 = distance64;
|
|
|
+ closestFaceID = otherEmbedding.key;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (closestDistance > 0.3) {
|
|
|
+ _logger.severe(
|
|
|
+ "Face ${embedding.key} is similar to $closestFaceID with distance $closestDistance, and float32 distance $closestDistance32, and float64 distance $closestDistance64",
|
|
|
+ );
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
Future<void> _breakUpCluster(BuildContext context) async {
|
|
|
final newClusterIDToFaceIDs =
|
|
|
await ClusterFeedbackService.instance.breakUpCluster(widget.clusterID);
|