Inline clustering

This commit is contained in:
Manav Rathi 2024-05-16 12:04:44 +05:30
parent fbebbd3583
commit 2a35b0ec9c
No known key found for this signature in database
6 changed files with 48 additions and 85 deletions

View file

@ -0,0 +1,34 @@
import { Hdbscan, type DebugInfo } from "hdbscan";
import { type Cluster } from "services/face/types";
/**
 * The value that {@link clusterFaces} resolves with.
 */
export interface ClusterFacesResult {
    /** The clusters reported by the Hdbscan instance (`getClusters()`). */
    clusters: Cluster[];
    /** The noise reported by the Hdbscan instance (`getNoise()`). */
    noise: Cluster;
    /** Debug information reported by the Hdbscan instance, if any. */
    debugInfo?: DebugInfo;
}
/**
 * Run HDBSCAN over the given {@link faceEmbeddings} and collect its output.
 *
 * The clustering parameters (minimum cluster size, minimum samples, selection
 * epsilon and selection method) are fixed here, and debug output is always
 * requested from the clusterer.
 *
 * @param faceEmbeddings An array of embeddings produced by our face indexing
 * pipeline. Each embedding is for a face detected in an image (a single image
 * may have multiple faces detected within it).
 *
 * @returns The clusters, the noise, and the debug info reported by the
 * clusterer.
 */
export const clusterFaces = async (
    faceEmbeddings: Array<Array<number>>,
): Promise<ClusterFacesResult> => {
    const clusterer = new Hdbscan({
        input: faceEmbeddings,
        minClusterSize: 3,
        minSamples: 5,
        clusterSelectionEpsilon: 0.6,
        clusterSelectionMethod: "leaf",
        debug: true,
    });

    // Read the results in a fixed order: clusters, then noise, then debug info.
    const clusters = clusterer.getClusters();
    const noise = clusterer.getNoise();
    const debugInfo = clusterer.getDebugInfo();

    return { clusters, noise, debugInfo };
};

View file

@ -1,5 +1,5 @@
import { DebugInfo } from "hdbscan";
import PQueue from "p-queue";
import type { ClusterFacesResult } from "services/face/cluster";
import { Dimensions } from "services/face/geom";
import { EnteFile } from "types/file";
import { Box, Point } from "./geom";
@ -17,15 +17,6 @@ export declare type FaceDescriptor = Float32Array;
/**
 * A single cluster, represented as a list of numbers.
 *
 * NOTE(review): presumably each number is the index of a face in the array
 * that was given to the clustering routine - confirm against the callers.
 */
export declare type Cluster = number[];
/**
 * The outcome of a clustering run: the clusters that were found, plus the
 * noise, which is itself shaped like a single {@link Cluster}.
 */
export interface ClusteringResults {
    /** The clusters that were found. */
    clusters: Cluster[];
    /** The noise, in the same shape as a single cluster. */
    noise: Cluster;
}
/**
 * {@link ClusteringResults} extended with the optional debug information
 * that the "hdbscan" package can provide.
 */
export interface HdbscanResults extends ClusteringResults {
/** Debug information, typed as `DebugInfo` from "hdbscan"; may be absent. */
debugInfo?: DebugInfo;
}
export interface FacesCluster {
faces: Cluster;
summary?: FaceDescriptor;
@ -212,7 +203,6 @@ export interface MLSyncContext {
faceCropService: FaceCropService;
faceEmbeddingService: FaceEmbeddingService;
blurDetectionService: BlurDetectionService;
faceClusteringService: ClusteringService;
localFilesMap: Map<number, EnteFile>;
outOfSyncFiles: EnteFile[];
@ -246,7 +236,7 @@ export interface MLSyncFileContext {
export interface MLLibraryData {
faceClusteringMethod?: Versioned<ClusteringMethod>;
faceClusteringResults?: ClusteringResults;
faceClusteringResults?: ClusterFacesResult;
faceClustersWithNoise?: FacesClustersWithNoise;
}
@ -283,14 +273,6 @@ export interface BlurDetectionService {
detectBlur(alignedFaces: Float32Array, faces: Face[]): number[];
}
/**
 * Contract for an implementation that can cluster a {@link ClusteringInput}.
 */
export interface ClusteringService {
/** The clustering method, along with its version, that this service implements. */
method: Versioned<ClusteringMethod>;
/**
 * Cluster the given input.
 *
 * @param input The vectors to cluster.
 * @returns A promise for the resulting {@link ClusteringResults}.
 */
cluster(input: ClusteringInput): Promise<ClusteringResults>;
}
/**
 * The input to a clustering run: a list of vectors, each of which is itself a
 * list of numbers.
 */
export declare type ClusteringInput = number[][];
export interface MachineLearningWorker {
closeLocalSyncContext(): Promise<void>;

View file

@ -11,6 +11,7 @@ import {
type Versioned,
} from "services/face/types";
import { imageBitmapToBlob, warpAffineFloat32List } from "utils/image";
import { clusterFaces } from "../face/cluster";
import {
fetchImageBitmap,
fetchImageBitmapForContext,
@ -257,12 +258,13 @@ class FaceService {
}
log.info("Running clustering allFaces: ", allFaces.length);
syncContext.mlLibraryData.faceClusteringResults =
await syncContext.faceClusteringService.cluster(
allFaces.map((f) => Array.from(f.embedding)),
);
syncContext.mlLibraryData.faceClusteringMethod =
syncContext.faceClusteringService.method;
syncContext.mlLibraryData.faceClusteringResults = await clusterFaces(
allFaces.map((f) => Array.from(f.embedding)),
);
syncContext.mlLibraryData.faceClusteringMethod = {
value: "Hdbscan",
version: 1,
};
log.info(
"[MLService] Got face clustering results: ",
JSON.stringify(syncContext.mlLibraryData.faceClusteringResults),

View file

@ -1,40 +0,0 @@
import { Hdbscan } from "hdbscan";
import {
ClusteringInput,
ClusteringMethod,
ClusteringService,
HdbscanResults,
Versioned,
} from "services/face/types";
/**
 * A {@link ClusteringService} backed by the `Hdbscan` class from the "hdbscan"
 * package.
 */
class HdbscanClusteringService implements ClusteringService {
    /** Name and version of the clustering method implemented by this service. */
    public method: Versioned<ClusteringMethod> = {
        value: "Hdbscan",
        version: 1,
    };

    /**
     * Cluster the given input using fixed HDBSCAN parameters, always
     * requesting debug output from the clusterer.
     *
     * @param input The vectors to cluster; passed as is to `Hdbscan`.
     * @returns The clusters, the noise, and the debug info reported by the
     * clusterer.
     */
    public async cluster(input: ClusteringInput): Promise<HdbscanResults> {
        const clusterer = new Hdbscan({
            input,
            minClusterSize: 3,
            minSamples: 5,
            clusterSelectionEpsilon: 0.6,
            clusterSelectionMethod: "leaf",
            debug: true,
        });

        // Read the results in a fixed order: clusters, noise, debug info.
        const clusters = clusterer.getClusters();
        const noise = clusterer.getNoise();
        const debugInfo = clusterer.getDebugInfo();

        return { clusters, noise, debugInfo };
    }
}
export default new HdbscanClusteringService();

View file

@ -14,8 +14,6 @@ import mlIDbStorage, { ML_SEARCH_CONFIG_NAME } from "services/face/db";
import {
BlurDetectionMethod,
BlurDetectionService,
ClusteringMethod,
ClusteringService,
Face,
FaceCropMethod,
FaceCropService,
@ -38,7 +36,6 @@ import { EnteFile } from "types/file";
import { isInternalUserForML } from "utils/user";
import arcfaceCropService from "./arcfaceCropService";
import FaceService from "./faceService";
import hdbscanClusteringService from "./hdbscanClusteringService";
import laplacianBlurDetectionService from "./laplacianBlurDetectionService";
import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService";
import PeopleService from "./peopleService";
@ -155,16 +152,6 @@ export class MLFactory {
throw Error("Unknon face embedding method: " + method);
}
/**
 * Return the clustering service that implements the given method.
 *
 * @param method The clustering method for which a service is needed. Only
 * "Hdbscan" is handled.
 * @returns The instance default-exported by "./hdbscanClusteringService" when
 * {@link method} is "Hdbscan".
 * @throws Error if {@link method} is anything else.
 */
public static getClusteringService(
    method: ClusteringMethod,
): ClusteringService {
    if (method === "Hdbscan") {
        return hdbscanClusteringService;
    }

    // Fixed the misspelling ("Unknon") in the message shown to developers.
    throw Error("Unknown clustering method: " + method);
}
}
export class LocalMLSyncContext implements MLSyncContext {
@ -176,7 +163,6 @@ export class LocalMLSyncContext implements MLSyncContext {
public faceCropService: FaceCropService;
public blurDetectionService: BlurDetectionService;
public faceEmbeddingService: FaceEmbeddingService;
public faceClusteringService: ClusteringService;
public localFilesMap: Map<number, EnteFile>;
public outOfSyncFiles: EnteFile[];
@ -215,7 +201,6 @@ export class LocalMLSyncContext implements MLSyncContext {
MLFactory.getBlurDetectionService("Laplacian");
this.faceEmbeddingService =
MLFactory.getFaceEmbeddingService("MobileFaceNet");
this.faceClusteringService = MLFactory.getClusteringService("Hdbscan");
this.outOfSyncFiles = [];
this.nSyncedFiles = 0;

View file

@ -9,10 +9,10 @@ class PeopleService {
const filesVersion = await mlIDbStorage.getIndexVersion("files");
if (
filesVersion <= (await mlIDbStorage.getIndexVersion("people")) &&
!isDifferentOrOld(
syncContext.mlLibraryData?.faceClusteringMethod,
syncContext.faceClusteringService.method,
)
!isDifferentOrOld(syncContext.mlLibraryData?.faceClusteringMethod, {
value: "Hdbscan",
version: 1,
})
) {
log.info(
"[MLService] Skipping people index as already synced to latest version",