Просмотр исходного кода

[mob] Add more validation for clustering

laurenspriem 1 год назад
Родитель
Сommit
0c72fd2a69

+ 18 - 0
mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart

@@ -395,6 +395,12 @@ class FaceClustering {
         if (distance < closestDistance) {
           closestDistance = distance;
           closestIdx = j;
+          // if (distance < distanceThreshold) {
+          //   if (sortedFaceInfos[j].faceID.startsWith("14914702") ||
+          //       sortedFaceInfos[j].faceID.startsWith("15488756")) {
+          //     log('[XXX] faceIDs: ${sortedFaceInfos[j].faceID} and ${sortedFaceInfos[i].faceID} with distance $distance');
+          //   }
+          // }
         }
       }
 
@@ -408,10 +414,22 @@ class FaceClustering {
           sortedFaceInfos[closestIdx].clusterId = clusterID;
           newFaceIdToCluster[sortedFaceInfos[closestIdx].faceID] = clusterID;
         }
+        // if (sortedFaceInfos[i].faceID.startsWith("14914702") ||
+        //     sortedFaceInfos[i].faceID.startsWith("15488756")) {
+        //   log(
+        //     "[XXX]  [ClusterIsolate] ${DateTime.now()} Found similar face ${sortedFaceInfos[i].faceID} to ${sortedFaceInfos[closestIdx].faceID} with distance $closestDistance",
+        //   );
+        // }
         sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId;
         newFaceIdToCluster[sortedFaceInfos[i].faceID] =
             sortedFaceInfos[closestIdx].clusterId!;
       } else {
+        // if (sortedFaceInfos[i].faceID.startsWith("14914702") ||
+        //     sortedFaceInfos[i].faceID.startsWith("15488756")) {
+        //   log(
+        //     "[XXX]  [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID for face ${sortedFaceInfos[i].faceID}",
+        //   );
+        // }
         clusterID++;
         sortedFaceInfos[i].clusterId = clusterID;
         newFaceIdToCluster[sortedFaceInfos[i].faceID] = clusterID;

+ 68 - 0
mobile/lib/ui/viewer/people/cluster_app_bar.dart

@@ -3,14 +3,18 @@ import 'dart:async';
 import "package:flutter/foundation.dart";
 import 'package:flutter/material.dart';
 import 'package:logging/logging.dart';
+import "package:ml_linalg/linalg.dart";
 import 'package:photos/core/configuration.dart';
 import 'package:photos/core/event_bus.dart';
 import "package:photos/db/files_db.dart";
 import 'package:photos/events/subscription_purchased_event.dart';
+import "package:photos/face/db.dart";
 import "package:photos/face/model/person.dart";
+import "package:photos/generated/protos/ente/common/vector.pb.dart";
 import 'package:photos/models/gallery_type.dart';
 import 'package:photos/models/selected_files.dart';
 import 'package:photos/services/collections_service.dart';
+import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
 import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
 import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
 import 'package:photos/ui/actions/collection/collection_sharing_actions.dart';
@@ -39,6 +43,7 @@ class ClusterAppBar extends StatefulWidget {
 enum ClusterPopupAction {
   setCover,
   breakupCluster,
+  validateCluster,
   hide,
 }
 
@@ -127,6 +132,18 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
             ],
           ),
         ),
+        const PopupMenuItem(
+          value: ClusterPopupAction.validateCluster,
+          child: Row(
+            children: [
+              Icon(Icons.search_off_outlined),
+              Padding(
+                padding: EdgeInsets.all(8),
+              ),
+              Text('Validate cluster'),
+            ],
+          ),
+        ),
         // PopupMenuItem(
         //   value: ClusterPopupAction.hide,
         //   child: Row(
@@ -152,6 +169,8 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
             if (value == ClusterPopupAction.breakupCluster) {
               // ignore: unawaited_futures
               await _breakUpCluster(context);
+            } else if (value == ClusterPopupAction.validateCluster) {
+              await _validateCluster(context);
             }
             // else if (value == ClusterPopupAction.setCover) {
             //   await setCoverPhoto(context);
@@ -166,6 +185,55 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
     return actions;
   }
 
+  Future<void> _validateCluster(BuildContext context) async {
+    _logger.info('_validateCluster called');
+    final faceMlDb = FaceMLDataDB.instance;
+
+    final faceIDs = await faceMlDb.getFaceIDsForCluster(widget.clusterID);
+    final fileIDs = faceIDs.map((e) => getFileIdFromFaceId(e)).toList();
+
+    final embeddingsBlobs = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs);
+    embeddingsBlobs.removeWhere((key, value) => !faceIDs.contains(key));
+    final embeddings = embeddingsBlobs
+        .map((key, value) => MapEntry(key, EVector.fromBuffer(value).values));
+
+    for (final MapEntry<String, List<double>> embedding in embeddings.entries) {
+      double closestDistance = double.infinity;
+      double closestDistance32 = double.infinity;
+      double closestDistance64 = double.infinity;
+      String? closestFaceID;
+      for (final MapEntry<String, List<double>> otherEmbedding
+          in embeddings.entries) {
+        if (embedding.key == otherEmbedding.key) {
+          continue;
+        }
+        final distance64 = 1.0 -
+            Vector.fromList(embedding.value, dtype: DType.float64).dot(
+              Vector.fromList(otherEmbedding.value, dtype: DType.float64),
+            );
+        final distance32 = 1.0 -
+            Vector.fromList(embedding.value, dtype: DType.float32).dot(
+              Vector.fromList(otherEmbedding.value, dtype: DType.float32),
+            );
+        final distance = cosineDistForNormVectors(
+          embedding.value,
+          otherEmbedding.value,
+        );
+        if (distance < closestDistance) {
+          closestDistance = distance;
+          closestDistance32 = distance32;
+          closestDistance64 = distance64;
+          closestFaceID = otherEmbedding.key;
+        }
+      }
+      if (closestDistance > 0.3) {
+        _logger.severe(
+          "Face ${embedding.key} is similar to $closestFaceID with distance $closestDistance, and float32 distance $closestDistance32, and float64 distance $closestDistance64",
+        );
+      }
+    }
+  }
+
   Future<void> _breakUpCluster(BuildContext context) async {
     final newClusterIDToFaceIDs =
         await ClusterFeedbackService.instance.breakUpCluster(widget.clusterID);