Browse Source

Inline clustering

Manav Rathi 1 year ago
parent
commit
2a35b0ec9c

+ 34 - 0
web/apps/photos/src/services/face/cluster.ts

@@ -0,0 +1,34 @@
+import { Hdbscan, type DebugInfo } from "hdbscan";
+import { type Cluster } from "services/face/types";
+
+export interface ClusterFacesResult {
+    clusters: Array<Cluster>;
+    noise: Cluster;
+    debugInfo?: DebugInfo;
+}
+
+/**
+ * Cluster the given {@link faceEmbeddings}.
+ *
+ * @param faceEmbeddings An array of embeddings produced by our face indexing
+ * pipeline. Each embedding is for a face detected in an image (a single image
+ * may have multiple faces detected within it).
+ */
+export const clusterFaces = async (
+    faceEmbeddings: Array<Array<number>>,
+): Promise<ClusterFacesResult> => {
+    const hdbscan = new Hdbscan({
+        input: faceEmbeddings,
+        minClusterSize: 3,
+        minSamples: 5,
+        clusterSelectionEpsilon: 0.6,
+        clusterSelectionMethod: "leaf",
+        debug: true,
+    });
+
+    return {
+        clusters: hdbscan.getClusters(),
+        noise: hdbscan.getNoise(),
+        debugInfo: hdbscan.getDebugInfo(),
+    };
+};

+ 2 - 20
web/apps/photos/src/services/face/types.ts

@@ -1,5 +1,5 @@
-import { DebugInfo } from "hdbscan";
 import PQueue from "p-queue";
+import type { ClusterFacesResult } from "services/face/cluster";
 import { Dimensions } from "services/face/geom";
 import { EnteFile } from "types/file";
 import { Box, Point } from "./geom";
@@ -17,15 +17,6 @@ export declare type FaceDescriptor = Float32Array;
 
 export declare type Cluster = Array<number>;
 
-export interface ClusteringResults {
-    clusters: Array<Cluster>;
-    noise: Cluster;
-}
-
-export interface HdbscanResults extends ClusteringResults {
-    debugInfo?: DebugInfo;
-}
-
 export interface FacesCluster {
     faces: Cluster;
     summary?: FaceDescriptor;
@@ -212,7 +203,6 @@ export interface MLSyncContext {
     faceCropService: FaceCropService;
     faceEmbeddingService: FaceEmbeddingService;
     blurDetectionService: BlurDetectionService;
-    faceClusteringService: ClusteringService;
 
     localFilesMap: Map<number, EnteFile>;
     outOfSyncFiles: EnteFile[];
@@ -246,7 +236,7 @@ export interface MLSyncFileContext {
 
 export interface MLLibraryData {
     faceClusteringMethod?: Versioned<ClusteringMethod>;
-    faceClusteringResults?: ClusteringResults;
+    faceClusteringResults?: ClusterFacesResult;
     faceClustersWithNoise?: FacesClustersWithNoise;
 }
 
@@ -283,14 +273,6 @@ export interface BlurDetectionService {
     detectBlur(alignedFaces: Float32Array, faces: Face[]): number[];
 }
 
-export interface ClusteringService {
-    method: Versioned<ClusteringMethod>;
-
-    cluster(input: ClusteringInput): Promise<ClusteringResults>;
-}
-
-export declare type ClusteringInput = Array<Array<number>>;
-
 export interface MachineLearningWorker {
     closeLocalSyncContext(): Promise<void>;
 

+ 8 - 6
web/apps/photos/src/services/machineLearning/faceService.ts

@@ -11,6 +11,7 @@ import {
     type Versioned,
 } from "services/face/types";
 import { imageBitmapToBlob, warpAffineFloat32List } from "utils/image";
+import { clusterFaces } from "../face/cluster";
 import {
     fetchImageBitmap,
     fetchImageBitmapForContext,
@@ -257,12 +258,13 @@ class FaceService {
         }
 
         log.info("Running clustering allFaces: ", allFaces.length);
-        syncContext.mlLibraryData.faceClusteringResults =
-            await syncContext.faceClusteringService.cluster(
-                allFaces.map((f) => Array.from(f.embedding)),
-            );
-        syncContext.mlLibraryData.faceClusteringMethod =
-            syncContext.faceClusteringService.method;
+        syncContext.mlLibraryData.faceClusteringResults = await clusterFaces(
+            allFaces.map((f) => Array.from(f.embedding)),
+        );
+        syncContext.mlLibraryData.faceClusteringMethod = {
+            value: "Hdbscan",
+            version: 1,
+        };
         log.info(
             "[MLService] Got face clustering results: ",
             JSON.stringify(syncContext.mlLibraryData.faceClusteringResults),

+ 0 - 40
web/apps/photos/src/services/machineLearning/hdbscanClusteringService.ts

@@ -1,40 +0,0 @@
-import { Hdbscan } from "hdbscan";
-import {
-    ClusteringInput,
-    ClusteringMethod,
-    ClusteringService,
-    HdbscanResults,
-    Versioned,
-} from "services/face/types";
-
-class HdbscanClusteringService implements ClusteringService {
-    public method: Versioned<ClusteringMethod>;
-
-    constructor() {
-        this.method = {
-            value: "Hdbscan",
-            version: 1,
-        };
-    }
-
-    public async cluster(input: ClusteringInput): Promise<HdbscanResults> {
-        // log.info('Clustering input: ', input);
-        const hdbscan = new Hdbscan({
-            input,
-
-            minClusterSize: 3,
-            minSamples: 5,
-            clusterSelectionEpsilon: 0.6,
-            clusterSelectionMethod: "leaf",
-            debug: true,
-        });
-
-        return {
-            clusters: hdbscan.getClusters(),
-            noise: hdbscan.getNoise(),
-            debugInfo: hdbscan.getDebugInfo(),
-        };
-    }
-}
-
-export default new HdbscanClusteringService();

+ 0 - 15
web/apps/photos/src/services/machineLearning/machineLearningService.ts

@@ -14,8 +14,6 @@ import mlIDbStorage, { ML_SEARCH_CONFIG_NAME } from "services/face/db";
 import {
     BlurDetectionMethod,
     BlurDetectionService,
-    ClusteringMethod,
-    ClusteringService,
     Face,
     FaceCropMethod,
     FaceCropService,
@@ -38,7 +36,6 @@ import { EnteFile } from "types/file";
 import { isInternalUserForML } from "utils/user";
 import arcfaceCropService from "./arcfaceCropService";
 import FaceService from "./faceService";
-import hdbscanClusteringService from "./hdbscanClusteringService";
 import laplacianBlurDetectionService from "./laplacianBlurDetectionService";
 import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService";
 import PeopleService from "./peopleService";
@@ -155,16 +152,6 @@ export class MLFactory {
 
         throw Error("Unknon face embedding method: " + method);
     }
-
-    public static getClusteringService(
-        method: ClusteringMethod,
-    ): ClusteringService {
-        if (method === "Hdbscan") {
-            return hdbscanClusteringService;
-        }
-
-        throw Error("Unknon clustering method: " + method);
-    }
 }
 
 export class LocalMLSyncContext implements MLSyncContext {
@@ -176,7 +163,6 @@ export class LocalMLSyncContext implements MLSyncContext {
     public faceCropService: FaceCropService;
     public blurDetectionService: BlurDetectionService;
     public faceEmbeddingService: FaceEmbeddingService;
-    public faceClusteringService: ClusteringService;
 
     public localFilesMap: Map<number, EnteFile>;
     public outOfSyncFiles: EnteFile[];
@@ -215,7 +201,6 @@ export class LocalMLSyncContext implements MLSyncContext {
             MLFactory.getBlurDetectionService("Laplacian");
         this.faceEmbeddingService =
             MLFactory.getFaceEmbeddingService("MobileFaceNet");
-        this.faceClusteringService = MLFactory.getClusteringService("Hdbscan");
 
         this.outOfSyncFiles = [];
         this.nSyncedFiles = 0;

+ 4 - 4
web/apps/photos/src/services/machineLearning/peopleService.ts

@@ -9,10 +9,10 @@ class PeopleService {
         const filesVersion = await mlIDbStorage.getIndexVersion("files");
         if (
             filesVersion <= (await mlIDbStorage.getIndexVersion("people")) &&
-            !isDifferentOrOld(
-                syncContext.mlLibraryData?.faceClusteringMethod,
-                syncContext.faceClusteringService.method,
-            )
+            !isDifferentOrOld(syncContext.mlLibraryData?.faceClusteringMethod, {
+                value: "Hdbscan",
+                version: 1,
+            })
         ) {
             log.info(
                 "[MLService] Skipping people index as already synced to latest version",