This commit is contained in:
Manav Rathi 2024-05-16 13:50:16 +05:30
parent 43a3df5bbf
commit 84c737ddd3
No known key found for this signature in database
5 changed files with 10 additions and 36 deletions

View file

@ -1,6 +1,6 @@
import { Face } from "services/face/types";
import { createGrayscaleIntMatrixFromNormalized2List } from "utils/image";
import { mobileFaceNetFaceSize } from "../machineLearning/embed";
import { mobileFaceNetFaceSize } from "./embed";
/**
* Laplacian blur detection.

View file

@ -8,7 +8,7 @@ export const mobileFaceNetFaceSize = 112;
*
* The model used is MobileFaceNet, running in an ONNX runtime.
*/
export const getFaceEmbeddings = async (
export const faceEmbeddings = async (
faceData: Float32Array,
): Promise<Array<FaceEmbedding>> => {
const outputData = await workerBridge.faceEmbeddings(faceData);

View file

@ -199,7 +199,6 @@ export interface MLSyncContext {
userID: number;
faceCropService: FaceCropService;
faceEmbeddingService: FaceEmbeddingService;
localFilesMap: Map<number, EnteFile>;
outOfSyncFiles: EnteFile[];
@ -248,13 +247,6 @@ export interface FaceCropService {
): Promise<FaceCrop>;
}
export interface FaceEmbeddingService {
method: Versioned<FaceEmbeddingMethod>;
faceSize: number;
getFaceEmbeddings(faceImages: Float32Array): Promise<Array<FaceEmbedding>>;
}
export interface MachineLearningWorker {
closeLocalSyncContext(): Promise<void>;

View file

@ -3,6 +3,7 @@ import log from "@/next/log";
import { faceAlignment } from "services/face/align";
import mlIDbStorage from "services/face/db";
import { detectFaces, getRelativeDetection } from "services/face/detect";
import { faceEmbeddings, mobileFaceNetFaceSize } from "services/face/embed";
import {
DetectedFace,
Face,
@ -11,9 +12,9 @@ import {
type FaceAlignment,
} from "services/face/types";
import { imageBitmapToBlob, warpAffineFloat32List } from "utils/image";
import { detectBlur } from "../face/blur";
import { clusterFaces } from "../face/cluster";
import { getFaceCrop } from "../face/crop";
import { detectBlur } from "../face/blur";
import {
fetchImageBitmap,
fetchImageBitmapForContext,
@ -86,7 +87,7 @@ class FaceService {
const faceAlignments = newMlFile.faces.map((f) => f.alignment);
const faceImages = await extractFaceImagesToFloat32(
faceAlignments,
syncContext.faceEmbeddingService.faceSize,
mobileFaceNetFaceSize,
imageBitmap,
);
const blurValues = detectBlur(faceImages, newMlFile.faces);
@ -104,15 +105,15 @@ class FaceService {
alignedFacesInput: Float32Array,
) {
const { newMlFile } = fileContext;
newMlFile.faceEmbeddingMethod = syncContext.faceEmbeddingService.method;
newMlFile.faceEmbeddingMethod = {
value: "MobileFaceNet",
version: 2,
};
// TODO: when not storing face crops, image will be needed to extract faces
// fileContext.imageBitmap ||
// (await this.getImageBitmap(fileContext));
const embeddings =
await syncContext.faceEmbeddingService.getFaceEmbeddings(
alignedFacesInput,
);
const embeddings = await faceEmbeddings(alignedFacesInput);
newMlFile.faces.forEach((f, i) => (f.embedding = embeddings[i]));
log.info("[MLService] facesWithEmbeddings: ", newMlFile.faces.length);

View file

@ -15,8 +15,6 @@ import {
Face,
FaceCropService,
FaceDetection,
FaceEmbeddingMethod,
FaceEmbeddingService,
Landmark,
MLLibraryData,
MLSearchConfig,
@ -31,7 +29,6 @@ import { EnteFile } from "types/file";
import { isInternalUserForML } from "utils/user";
import { fetchImageBitmapForContext } from "../face/image";
import { syncPeopleIndex } from "../face/people";
import mobileFaceNetEmbeddingService from "./embed";
import FaceService from "./faceService";
/**
@ -105,24 +102,11 @@ export async function updateMLSearchConfig(newConfig: MLSearchConfig) {
return mlIDbStorage.putConfig(ML_SEARCH_CONFIG_NAME, newConfig);
}
export class MLFactory {
public static getFaceEmbeddingService(
method: FaceEmbeddingMethod,
): FaceEmbeddingService {
if (method === "MobileFaceNet") {
return mobileFaceNetEmbeddingService;
}
throw Error("Unknon face embedding method: " + method);
}
}
export class LocalMLSyncContext implements MLSyncContext {
public token: string;
public userID: number;
public faceCropService: FaceCropService;
public faceEmbeddingService: FaceEmbeddingService;
public localFilesMap: Map<number, EnteFile>;
public outOfSyncFiles: EnteFile[];
@ -148,9 +132,6 @@ export class LocalMLSyncContext implements MLSyncContext {
this.token = token;
this.userID = userID;
this.faceEmbeddingService =
MLFactory.getFaceEmbeddingService("MobileFaceNet");
this.outOfSyncFiles = [];
this.nSyncedFiles = 0;
this.nSyncedFaces = 0;