Import ONNX-YOLO face changes from the web_face branch

Laurens has made the relevant changes to get ONNX-YOLO face detection working in
a manner where the generated embeddings are the same as what get generated by
the corresponding ML stack in the mobile client.

This commit cherry picks his ML related changes from the branch he was working
in, but leaves out the surrounding scaffolding (We cannot merge that branch
directly because it relies on wasm that we don't need and don't want to commit
to main).

At this point this functionality is correct but not usable - the next step will
be to tie this to the ONNX runtime that we already have on the Node.js layer of
our desktop app.
This commit is contained in:
laurenspriem 2024-04-10 14:15:02 +05:30 committed by Manav Rathi
parent ea18608727
commit 3eb95bd822
No known key found for this signature in database
20 changed files with 1286 additions and 186 deletions

View file

@ -12,7 +12,7 @@ export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
batchSize: 200,
imageSource: "Original",
faceDetection: {
method: "BlazeFace",
method: "YoloFace",
minFaceSize: 32,
},
faceCrop: {
@ -28,6 +28,10 @@ export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
faceAlignment: {
method: "ArcFace",
},
blurDetection: {
method: "Laplacian",
threshold: 15,
},
faceEmbedding: {
method: "MobileFaceNet",
faceSize: 112,
@ -70,7 +74,7 @@ export const ML_SYNC_DOWNLOAD_TIMEOUT_MS = 300000;
export const MAX_FACE_DISTANCE_PERCENT = Math.sqrt(2) / 100;
export const MAX_ML_SYNC_ERROR_COUNT = 4;
export const MAX_ML_SYNC_ERROR_COUNT = 1;
export const TEXT_DETECTION_TIMEOUT_MS = [10000, 30000, 60000, 120000, 240000];
@ -81,6 +85,7 @@ export const BLAZEFACE_SCORE_THRESHOLD = 0.75;
export const BLAZEFACE_PASS1_SCORE_THRESHOLD = 0.4;
export const BLAZEFACE_FACE_SIZE = 112;
export const MOBILEFACENET_FACE_SIZE = 112;
export const MOBILEFACENET_EMBEDDING_SIZE = 192;
// scene detection model takes fixed-shaped (224x224) inputs
// https://tfhub.dev/sayannath/lite-model/image-scene/1

View file

@ -51,6 +51,11 @@ class BlazeFaceDetectionService implements FaceDetectionService {
this.desiredFaceSize = desiredFaceSize;
}
public getRelativeDetection(): FaceDetection {
// TODO(MR): onnx-yolo
throw new Error();
}
private async init() {
this.blazeFaceModel = blazeFaceLoad({
maxFaces: BLAZEFACE_MAX_FACES,

View file

@ -8,7 +8,7 @@ import {
import { imageBitmapToBlob } from "utils/image";
import {
areFaceIdsSame,
extractFaceImages,
extractFaceImagesToFloat32,
getFaceId,
getLocalFile,
getOriginalImageBitmap,
@ -49,8 +49,12 @@ class FaceService {
syncContext,
fileContext,
);
const timerId = `faceDetection-${fileContext.enteFile.id}`;
console.time(timerId);
const faceDetections =
await syncContext.faceDetectionService.detectFaces(imageBitmap);
console.timeEnd(timerId);
console.log("faceDetections: ", faceDetections?.length);
// log.info('3 TF Memory stats: ',JSON.stringify(tf.memory()));
// TODO: reenable faces filtering based on width
const detectedFaces = faceDetections?.map((detection) => {
@ -104,7 +108,7 @@ class FaceService {
async syncFileFaceAlignments(
syncContext: MLSyncContext,
fileContext: MLSyncFileContext,
) {
): Promise<Float32Array> {
const { oldMlFile, newMlFile } = fileContext;
if (
!fileContext.newDetection &&
@ -123,18 +127,37 @@ class FaceService {
newMlFile.faceAlignmentMethod = syncContext.faceAlignmentService.method;
fileContext.newAlignment = true;
const imageBitmap =
fileContext.imageBitmap ||
(await ReaderService.getImageBitmap(syncContext, fileContext));
// Execute the face alignment calculations
for (const face of newMlFile.faces) {
face.alignment = syncContext.faceAlignmentService.getFaceAlignment(
face.detection,
);
}
// Extract face images and convert to Float32Array
const faceAlignments = newMlFile.faces.map((f) => f.alignment);
const faceImages = await extractFaceImagesToFloat32(
faceAlignments,
syncContext.faceEmbeddingService.faceSize,
imageBitmap,
);
const blurValues =
syncContext.blurDetectionService.detectBlur(faceImages);
newMlFile.faces.forEach((f, i) => (f.blurValue = blurValues[i]));
imageBitmap.close();
log.info("[MLService] alignedFaces: ", newMlFile.faces?.length);
// log.info('4 TF Memory stats: ',JSON.stringify(tf.memory()));
return faceImages;
}
async syncFileFaceEmbeddings(
syncContext: MLSyncContext,
fileContext: MLSyncFileContext,
alignedFacesInput: Float32Array,
) {
const { oldMlFile, newMlFile } = fileContext;
if (
@ -156,22 +179,43 @@ class FaceService {
// TODO: when not storing face crops, image will be needed to extract faces
// fileContext.imageBitmap ||
// (await this.getImageBitmap(syncContext, fileContext));
const faceImages = await extractFaceImages(
newMlFile.faces,
syncContext.faceEmbeddingService.faceSize,
);
const embeddings =
await syncContext.faceEmbeddingService.getFaceEmbeddings(
faceImages,
alignedFacesInput,
);
faceImages.forEach((faceImage) => faceImage.close());
newMlFile.faces.forEach((f, i) => (f.embedding = embeddings[i]));
log.info("[MLService] facesWithEmbeddings: ", newMlFile.faces.length);
// log.info('5 TF Memory stats: ',JSON.stringify(tf.memory()));
}
async syncFileFaceMakeRelativeDetections(
syncContext: MLSyncContext,
fileContext: MLSyncFileContext,
) {
const { oldMlFile, newMlFile } = fileContext;
if (
!fileContext.newAlignment &&
!isDifferentOrOld(
oldMlFile?.faceEmbeddingMethod,
syncContext.faceEmbeddingService.method,
) &&
areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
) {
return;
}
for (let i = 0; i < newMlFile.faces.length; i++) {
const face = newMlFile.faces[i];
if (face.detection.box.x + face.detection.box.width < 2) continue; // Skip if somehow already relative
face.detection =
syncContext.faceDetectionService.getRelativeDetection(
face.detection,
newMlFile.imageDimensions,
);
}
}
async saveFaceCrop(
imageBitmap: ImageBitmap,
face: Face,

View file

@ -0,0 +1,131 @@
import { MOBILEFACENET_FACE_SIZE } from "constants/mlConfig";
import {
BlurDetectionMethod,
BlurDetectionService,
Versioned,
} from "types/machineLearning";
import { createGrayscaleIntMatrixFromNormalized2List } from "utils/image";
class LaplacianBlurDetectionService implements BlurDetectionService {
public method: Versioned<BlurDetectionMethod>;
public constructor() {
this.method = {
value: "Laplacian",
version: 1,
};
}
public detectBlur(alignedFaces: Float32Array): number[] {
const numFaces = Math.round(
alignedFaces.length /
(MOBILEFACENET_FACE_SIZE * MOBILEFACENET_FACE_SIZE * 3),
);
const blurValues: number[] = [];
for (let i = 0; i < numFaces; i++) {
const faceImage = createGrayscaleIntMatrixFromNormalized2List(
alignedFaces,
i,
);
const laplacian = this.applyLaplacian(faceImage);
const variance = this.calculateVariance(laplacian);
blurValues.push(variance);
}
return blurValues;
}
private calculateVariance(matrix: number[][]): number {
const numRows = matrix.length;
const numCols = matrix[0].length;
const totalElements = numRows * numCols;
// Calculate the mean
let mean: number = 0;
matrix.forEach((row) => {
row.forEach((value) => {
mean += value;
});
});
mean /= totalElements;
// Calculate the variance
let variance: number = 0;
matrix.forEach((row) => {
row.forEach((value) => {
const diff: number = value - mean;
variance += diff * diff;
});
});
variance /= totalElements;
return variance;
}
private padImage(image: number[][]): number[][] {
const numRows = image.length;
const numCols = image[0].length;
// Create a new matrix with extra padding
const paddedImage: number[][] = Array.from(
{ length: numRows + 2 },
() => new Array(numCols + 2).fill(0),
);
// Copy original image into the center of the padded image
for (let i = 0; i < numRows; i++) {
for (let j = 0; j < numCols; j++) {
paddedImage[i + 1][j + 1] = image[i][j];
}
}
// Reflect padding
// Top and bottom rows
for (let j = 1; j <= numCols; j++) {
paddedImage[0][j] = paddedImage[2][j]; // Top row
paddedImage[numRows + 1][j] = paddedImage[numRows - 1][j]; // Bottom row
}
// Left and right columns
for (let i = 0; i < numRows + 2; i++) {
paddedImage[i][0] = paddedImage[i][2]; // Left column
paddedImage[i][numCols + 1] = paddedImage[i][numCols - 1]; // Right column
}
return paddedImage;
}
private applyLaplacian(image: number[][]): number[][] {
const paddedImage: number[][] = this.padImage(image);
const numRows = image.length;
const numCols = image[0].length;
// Create an output image initialized to 0
const outputImage: number[][] = Array.from({ length: numRows }, () =>
new Array(numCols).fill(0),
);
// Define the Laplacian kernel
const kernel: number[][] = [
[0, 1, 0],
[1, -4, 1],
[0, 1, 0],
];
// Apply the kernel to each pixel
for (let i = 0; i < numRows; i++) {
for (let j = 0; j < numCols; j++) {
let sum = 0;
for (let ki = 0; ki < 3; ki++) {
for (let kj = 0; kj < 3; kj++) {
sum += paddedImage[i + ki][j + kj] * kernel[ki][kj];
}
}
// Adjust the output value if necessary (e.g., clipping)
outputImage[i][j] = sum;
}
}
return outputImage;
}
}
export default new LaplacianBlurDetectionService();

View file

@ -6,6 +6,8 @@ import { DedicatedCryptoWorker } from "@ente/shared/crypto/internal/crypto.worke
import PQueue from "p-queue";
import { EnteFile } from "types/file";
import {
BlurDetectionMethod,
BlurDetectionService,
ClusteringMethod,
ClusteringService,
Face,
@ -28,19 +30,20 @@ import {
import { logQueueStats } from "utils/machineLearning";
import arcfaceAlignmentService from "./arcfaceAlignmentService";
import arcfaceCropService from "./arcfaceCropService";
import blazeFaceDetectionService from "./blazeFaceDetectionService";
import dbscanClusteringService from "./dbscanClusteringService";
import hdbscanClusteringService from "./hdbscanClusteringService";
import imageSceneService from "./imageSceneService";
import laplacianBlurDetectionService from "./laplacianBlurDetectionService";
import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService";
import ssdMobileNetV2Service from "./ssdMobileNetV2Service";
import yoloFaceDetectionService from "./yoloFaceDetectionService";
export class MLFactory {
public static getFaceDetectionService(
method: FaceDetectionMethod,
): FaceDetectionService {
if (method === "BlazeFace") {
return blazeFaceDetectionService;
if (method === "YoloFace") {
return yoloFaceDetectionService;
}
throw Error("Unknon face detection method: " + method);
@ -84,6 +87,16 @@ export class MLFactory {
throw Error("Unknon face alignment method: " + method);
}
public static getBlurDetectionService(
method: BlurDetectionMethod,
): BlurDetectionService {
if (method === "Laplacian") {
return laplacianBlurDetectionService;
}
throw Error("Unknon blur detection method: " + method);
}
public static getFaceEmbeddingService(
method: FaceEmbeddingMethod,
): FaceEmbeddingService {
@ -131,6 +144,7 @@ export class LocalMLSyncContext implements MLSyncContext {
public faceDetectionService: FaceDetectionService;
public faceCropService: FaceCropService;
public faceAlignmentService: FaceAlignmentService;
public blurDetectionService: BlurDetectionService;
public faceEmbeddingService: FaceEmbeddingService;
public faceClusteringService: ClusteringService;
public objectDetectionService: ObjectDetectionService;
@ -178,6 +192,9 @@ export class LocalMLSyncContext implements MLSyncContext {
this.faceAlignmentService = MLFactory.getFaceAlignmentService(
this.config.faceAlignment.method,
);
this.blurDetectionService = MLFactory.getBlurDetectionService(
this.config.blurDetection.method,
);
this.faceEmbeddingService = MLFactory.getFaceEmbeddingService(
this.config.faceEmbedding.method,
);
@ -196,7 +213,7 @@ export class LocalMLSyncContext implements MLSyncContext {
this.nSyncedFiles = 0;
this.nSyncedFaces = 0;
this.concurrency = concurrency || getConcurrency();
this.concurrency = concurrency ?? getConcurrency();
log.info("Using concurrency: ", this.concurrency);
// timeout is added on downloads
@ -212,6 +229,7 @@ export class LocalMLSyncContext implements MLSyncContext {
public async getEnteWorker(id: number): Promise<any> {
const wid = id % this.enteWorkers.length;
console.log("getEnteWorker: ", id, wid);
if (!this.enteWorkers[wid]) {
this.comlinkCryptoWorker[wid] = getDedicatedCryptoWorker();
this.enteWorkers[wid] = await this.comlinkCryptoWorker[wid].remote;

View file

@ -34,11 +34,6 @@ class MachineLearningService {
}
await downloadManager.init(APPS.PHOTOS, { token });
// await this.init();
// Used to debug tf memory leak, all tf memory
// needs to be cleaned using tf.dispose or tf.tidy
// tf.engine().startScope();
const syncContext = await this.getSyncContext(token, userID);
@ -185,6 +180,50 @@ class MachineLearningService {
log.info("getOutOfSyncFiles", Date.now() - startTime, "ms");
}
// TODO: optimize, use indexdb indexes, move facecrops to cache to reduce io
// remove, already done
private async getUniqueOutOfSyncFilesNoIdx(
syncContext: MLSyncContext,
files: EnteFile[],
) {
const limit = syncContext.config.batchSize;
const mlVersion = syncContext.config.mlVersion;
const uniqueFiles: Map<number, EnteFile> = new Map<number, EnteFile>();
for (let i = 0; uniqueFiles.size < limit && i < files.length; i++) {
const mlFileData = await this.getMLFileData(files[i].id);
const mlFileVersion = mlFileData?.mlVersion || 0;
if (
!uniqueFiles.has(files[i].id) &&
(!mlFileData?.errorCount || mlFileData.errorCount < 2) &&
(mlFileVersion < mlVersion ||
syncContext.config.imageSource !== mlFileData.imageSource)
) {
uniqueFiles.set(files[i].id, files[i]);
}
}
return [...uniqueFiles.values()];
}
private async getOutOfSyncFilesNoIdx(syncContext: MLSyncContext) {
const existingFilesMap = await this.getLocalFilesMap(syncContext);
// existingFiles.sort(
// (a, b) => b.metadata.creationTime - a.metadata.creationTime
// );
console.time("getUniqueOutOfSyncFiles");
syncContext.outOfSyncFiles = await this.getUniqueOutOfSyncFilesNoIdx(
syncContext,
[...existingFilesMap.values()],
);
// addLogLine("getUniqueOutOfSyncFiles");
// addLogLine(
// "Got unique outOfSyncFiles: ",
// syncContext.outOfSyncFiles.length,
// "for batchSize: ",
// syncContext.config.batchSize,
// );
}
private async syncFiles(syncContext: MLSyncContext) {
try {
const functions = syncContext.outOfSyncFiles.map(
@ -283,6 +322,11 @@ class MachineLearningService {
textDetectionTimeoutIndex?: number,
): Promise<MlFileData> {
try {
console.log(
"Start index for ",
enteFile.title ?? "no title",
enteFile.id,
);
const mlFileData = await this.syncFile(
syncContext,
enteFile,
@ -319,6 +363,12 @@ class MachineLearningService {
await this.persistMLFileSyncError(syncContext, enteFile, error);
syncContext.nSyncedFiles += 1;
} finally {
console.log(
"done index for ",
enteFile.title ?? "no title",
enteFile.id,
);
// addLogLine('TF Memory stats: ', JSON.stringify(tf.memory()));
log.info("TF Memory stats: ", JSON.stringify(tf.memory()));
}
}
@ -330,6 +380,7 @@ class MachineLearningService {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
textDetectionTimeoutIndex?: number,
) {
console.log("Syncing for file" + enteFile.title);
const fileContext: MLSyncFileContext = { enteFile, localFile };
const oldMlFile =
(fileContext.oldMlFile = await this.getMLFileData(enteFile.id)) ??
@ -351,11 +402,16 @@ class MachineLearningService {
try {
await ReaderService.getImageBitmap(syncContext, fileContext);
await Promise.all([
this.syncFaceDetections(syncContext, fileContext),
ObjectService.syncFileObjectDetections(
syncContext,
fileContext,
),
this.syncFileAnalyzeFaces(syncContext, fileContext),
// ObjectService.syncFileObjectDetections(
// syncContext,
// fileContext
// ),
// TextService.syncFileTextDetections(
// syncContext,
// fileContext,
// textDetectionTimeoutIndex
// ),
]);
newMlFile.errorCount = 0;
newMlFile.lastErrorMessage = undefined;
@ -448,7 +504,7 @@ class MachineLearningService {
await this.persistMLLibraryData(syncContext);
}
private async syncFaceDetections(
private async syncFileAnalyzeFaces(
syncContext: MLSyncContext,
fileContext: MLSyncFileContext,
) {
@ -459,9 +515,21 @@ class MachineLearningService {
if (newMlFile.faces && newMlFile.faces.length > 0) {
await FaceService.syncFileFaceCrops(syncContext, fileContext);
await FaceService.syncFileFaceAlignments(syncContext, fileContext);
const alignedFacesData = await FaceService.syncFileFaceAlignments(
syncContext,
fileContext,
);
await FaceService.syncFileFaceEmbeddings(syncContext, fileContext);
await FaceService.syncFileFaceEmbeddings(
syncContext,
fileContext,
alignedFacesData,
);
await FaceService.syncFileFaceMakeRelativeDetections(
syncContext,
fileContext,
);
}
log.info(
`face detection time taken ${fileContext.enteFile.id}`,

View file

@ -30,7 +30,7 @@ class MLWorkManager {
constructor() {
this.liveSyncQueue = new PQueue({
concurrency: 1,
concurrency: 4,
// TODO: temp, remove
timeout: LIVE_SYNC_QUEUE_TIMEOUT_SEC * 1000,
throwOnTimeout: true,

View file

@ -1,23 +1,38 @@
import log from "@/next/log";
import * as tf from "@tensorflow/tfjs-core";
import { TFLiteModel } from "@tensorflow/tfjs-tflite";
import { MOBILEFACENET_FACE_SIZE } from "constants/mlConfig";
import PQueue from "p-queue";
import {
MOBILEFACENET_EMBEDDING_SIZE,
MOBILEFACENET_FACE_SIZE,
} from "constants/mlConfig";
// import { TFLiteModel } from "@tensorflow/tfjs-tflite";
// import PQueue from "p-queue";
import {
FaceEmbedding,
FaceEmbeddingMethod,
FaceEmbeddingService,
Versioned,
} from "types/machineLearning";
import { imageBitmapsToTensor4D } from "utils/machineLearning";
// TODO(MR): onnx-yolo
// import * as ort from "onnxruntime-web";
// import { env } from "onnxruntime-web";
const ort: any = {};
import {
clamp,
getPixelBilinear,
normalizePixelBetweenMinus1And1,
} from "utils/image";
// TODO(MR): onnx-yolo
// env.wasm.wasmPaths = "/js/onnx/";
class MobileFaceNetEmbeddingService implements FaceEmbeddingService {
// TODO(MR): onnx-yolo
// private onnxInferenceSession?: ort.InferenceSession;
private onnxInferenceSession?: any;
public method: Versioned<FaceEmbeddingMethod>;
public faceSize: number;
private mobileFaceNetModel: Promise<TFLiteModel>;
private serialQueue: PQueue;
public constructor(faceSize: number = MOBILEFACENET_FACE_SIZE) {
this.method = {
value: "MobileFaceNet",
@ -25,81 +40,156 @@ class MobileFaceNetEmbeddingService implements FaceEmbeddingService {
};
this.faceSize = faceSize;
// TODO: set timeout
this.serialQueue = new PQueue({ concurrency: 1 });
}
private async init() {
// TODO: can also create new instance per new syncContext
const tflite = await import("@tensorflow/tfjs-tflite");
tflite.setWasmPath("/js/tflite/");
private async initOnnx() {
console.log("start ort mobilefacenet");
this.onnxInferenceSession = await ort.InferenceSession.create(
"/models/mobilefacenet/mobilefacenet_opset15.onnx",
);
const faceBatchSize = 1;
const data = new Float32Array(
faceBatchSize * 3 * this.faceSize * this.faceSize,
);
const inputTensor = new ort.Tensor("float32", data, [
faceBatchSize,
this.faceSize,
this.faceSize,
3,
]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
const name = this.onnxInferenceSession.inputNames[0];
feeds[name] = inputTensor;
await this.onnxInferenceSession.run(feeds);
console.log("start end mobilefacenet");
}
this.mobileFaceNetModel = tflite.loadTFLiteModel(
"/models/mobilefacenet/mobilefacenet.tflite",
private async getOnnxInferenceSession() {
if (!this.onnxInferenceSession) {
await this.initOnnx();
}
return this.onnxInferenceSession;
}
private preprocessImageBitmapToFloat32(
imageBitmap: ImageBitmap,
requiredWidth: number = this.faceSize,
requiredHeight: number = this.faceSize,
maintainAspectRatio: boolean = true,
normFunction: (
pixelValue: number,
) => number = normalizePixelBetweenMinus1And1,
) {
// Create an OffscreenCanvas and set its size
const offscreenCanvas = new OffscreenCanvas(
imageBitmap.width,
imageBitmap.height,
);
const ctx = offscreenCanvas.getContext("2d");
ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height);
const imageData = ctx.getImageData(
0,
0,
imageBitmap.width,
imageBitmap.height,
);
const pixelData = imageData.data;
let scaleW = requiredWidth / imageBitmap.width;
let scaleH = requiredHeight / imageBitmap.height;
if (maintainAspectRatio) {
const scale = Math.min(
requiredWidth / imageBitmap.width,
requiredHeight / imageBitmap.height,
);
scaleW = scale;
scaleH = scale;
}
const scaledWidth = clamp(
Math.round(imageBitmap.width * scaleW),
0,
requiredWidth,
);
const scaledHeight = clamp(
Math.round(imageBitmap.height * scaleH),
0,
requiredHeight,
);
const processedImage = new Float32Array(
1 * requiredWidth * requiredHeight * 3,
);
log.info("loaded mobileFaceNetModel: ", tf.getBackend());
}
private async getMobileFaceNetModel() {
if (!this.mobileFaceNetModel) {
await this.init();
// Populate the Float32Array with normalized pixel values
for (let h = 0; h < requiredHeight; h++) {
for (let w = 0; w < requiredWidth; w++) {
let pixel: {
r: number;
g: number;
b: number;
};
if (w >= scaledWidth || h >= scaledHeight) {
pixel = { r: 114, g: 114, b: 114 };
} else {
pixel = getPixelBilinear(
w / scaleW,
h / scaleH,
pixelData,
imageBitmap.width,
imageBitmap.height,
);
}
const pixelIndex = 3 * (h * requiredWidth + w);
processedImage[pixelIndex] = normFunction(pixel.r);
processedImage[pixelIndex + 1] = normFunction(pixel.g);
processedImage[pixelIndex + 2] = normFunction(pixel.b);
}
}
return this.mobileFaceNetModel;
}
public getFaceEmbeddingTF(
faceTensor: tf.Tensor4D,
mobileFaceNetModel: TFLiteModel,
): tf.Tensor2D {
return tf.tidy(() => {
const normalizedFace = tf.sub(tf.div(faceTensor, 127.5), 1.0);
return mobileFaceNetModel.predict(normalizedFace) as tf.Tensor2D;
});
}
// Do not use this, use getFaceEmbedding which calls this through serialqueue
private async getFaceEmbeddingNoQueue(
faceImage: ImageBitmap,
): Promise<FaceEmbedding> {
const mobileFaceNetModel = await this.getMobileFaceNetModel();
const embeddingTensor = tf.tidy(() => {
const faceTensor = imageBitmapsToTensor4D([faceImage]);
const embeddingsTensor = this.getFaceEmbeddingTF(
faceTensor,
mobileFaceNetModel,
);
return tf.squeeze(embeddingsTensor, [0]);
});
const embedding = new Float32Array(await embeddingTensor.data());
embeddingTensor.dispose();
return embedding;
}
// TODO: TFLiteModel seems to not work concurrenly,
// remove serialqueue if that is not the case
private async getFaceEmbedding(
faceImage: ImageBitmap,
): Promise<FaceEmbedding> {
// @ts-expect-error "TODO: Fix ML related type errors"
return this.serialQueue.add(() =>
this.getFaceEmbeddingNoQueue(faceImage),
);
return processedImage;
}
public async getFaceEmbeddings(
faceImages: Array<ImageBitmap>,
faceData: Float32Array,
): Promise<Array<FaceEmbedding>> {
return Promise.all(
faceImages.map((faceImage) => this.getFaceEmbedding(faceImage)),
const inputTensor = new ort.Tensor("float32", faceData, [
Math.round(faceData.length / (this.faceSize * this.faceSize * 3)),
this.faceSize,
this.faceSize,
3,
]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
feeds["img_inputs"] = inputTensor;
const inferenceSession = await this.getOnnxInferenceSession();
// TODO(MR): onnx-yolo
// const runout: ort.InferenceSession.OnnxValueMapType =
const runout: any = await inferenceSession.run(feeds);
// const test = runout.embeddings;
// const test2 = test.cpuData;
const outputData = runout.embeddings["cpuData"] as Float32Array;
const embeddings = new Array<FaceEmbedding>(
outputData.length / MOBILEFACENET_EMBEDDING_SIZE,
);
for (let i = 0; i < embeddings.length; i++) {
embeddings[i] = new Float32Array(
outputData.slice(
i * MOBILEFACENET_EMBEDDING_SIZE,
(i + 1) * MOBILEFACENET_EMBEDDING_SIZE,
),
);
}
return embeddings;
}
public async dispose() {
this.mobileFaceNetModel = undefined;
const inferenceSession = await this.getOnnxInferenceSession();
inferenceSession?.release();
this.onnxInferenceSession = undefined;
}
}

View file

@ -0,0 +1,331 @@
import {
BLAZEFACE_FACE_SIZE,
MAX_FACE_DISTANCE_PERCENT,
} from "constants/mlConfig";
import { Dimensions } from "types/image";
import {
FaceDetection,
FaceDetectionMethod,
FaceDetectionService,
Versioned,
} from "types/machineLearning";
import {
clamp,
getPixelBilinear,
normalizePixelBetween0And1,
} from "utils/image";
import { newBox } from "utils/machineLearning";
import { removeDuplicateDetections } from "utils/machineLearning/faceDetection";
import {
computeTransformToBox,
transformBox,
transformPoints,
} from "utils/machineLearning/transform";
import { Box, Point } from "../../../thirdparty/face-api/classes";
// TODO(MR): onnx-yolo
// import * as ort from "onnxruntime-web";
// import { env } from "onnxruntime-web";
const ort: any = {};
// TODO(MR): onnx-yolo
// env.wasm.wasmPaths = "/js/onnx/";
class YoloFaceDetectionService implements FaceDetectionService {
// TODO(MR): onnx-yolo
// private onnxInferenceSession?: ort.InferenceSession;
private onnxInferenceSession?: any;
public method: Versioned<FaceDetectionMethod>;
private desiredFaceSize;
public constructor(desiredFaceSize: number = BLAZEFACE_FACE_SIZE) {
this.method = {
value: "YoloFace",
version: 1,
};
this.desiredFaceSize = desiredFaceSize;
}
private async initOnnx() {
console.log("start ort");
this.onnxInferenceSession = await ort.InferenceSession.create(
"/models/yoloface/yolov5s_face_640_640_dynamic.onnx",
);
const data = new Float32Array(1 * 3 * 640 * 640);
const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
const name = this.onnxInferenceSession.inputNames[0];
feeds[name] = inputTensor;
await this.onnxInferenceSession.run(feeds);
console.log("start end");
}
private async getOnnxInferenceSession() {
if (!this.onnxInferenceSession) {
await this.initOnnx();
}
return this.onnxInferenceSession;
}
private preprocessImageBitmapToFloat32ChannelsFirst(
imageBitmap: ImageBitmap,
requiredWidth: number,
requiredHeight: number,
maintainAspectRatio: boolean = true,
normFunction: (
pixelValue: number,
) => number = normalizePixelBetween0And1,
) {
// Create an OffscreenCanvas and set its size
const offscreenCanvas = new OffscreenCanvas(
imageBitmap.width,
imageBitmap.height,
);
const ctx = offscreenCanvas.getContext("2d");
ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height);
const imageData = ctx.getImageData(
0,
0,
imageBitmap.width,
imageBitmap.height,
);
const pixelData = imageData.data;
let scaleW = requiredWidth / imageBitmap.width;
let scaleH = requiredHeight / imageBitmap.height;
if (maintainAspectRatio) {
const scale = Math.min(
requiredWidth / imageBitmap.width,
requiredHeight / imageBitmap.height,
);
scaleW = scale;
scaleH = scale;
}
const scaledWidth = clamp(
Math.round(imageBitmap.width * scaleW),
0,
requiredWidth,
);
const scaledHeight = clamp(
Math.round(imageBitmap.height * scaleH),
0,
requiredHeight,
);
const processedImage = new Float32Array(
1 * 3 * requiredWidth * requiredHeight,
);
// Populate the Float32Array with normalized pixel values
let pixelIndex = 0;
const channelOffsetGreen = requiredHeight * requiredWidth;
const channelOffsetBlue = 2 * requiredHeight * requiredWidth;
for (let h = 0; h < requiredHeight; h++) {
for (let w = 0; w < requiredWidth; w++) {
let pixel: {
r: number;
g: number;
b: number;
};
if (w >= scaledWidth || h >= scaledHeight) {
pixel = { r: 114, g: 114, b: 114 };
} else {
pixel = getPixelBilinear(
w / scaleW,
h / scaleH,
pixelData,
imageBitmap.width,
imageBitmap.height,
);
}
processedImage[pixelIndex] = normFunction(pixel.r);
processedImage[pixelIndex + channelOffsetGreen] = normFunction(
pixel.g,
);
processedImage[pixelIndex + channelOffsetBlue] = normFunction(
pixel.b,
);
pixelIndex++;
}
}
return {
data: processedImage,
originalSize: {
width: imageBitmap.width,
height: imageBitmap.height,
},
newSize: { width: scaledWidth, height: scaledHeight },
};
}
/**
* @deprecated The method should not be used
*/
private imageBitmapToTensorData(imageBitmap) {
// Create an OffscreenCanvas and set its size
const offscreenCanvas = new OffscreenCanvas(
imageBitmap.width,
imageBitmap.height,
);
const ctx = offscreenCanvas.getContext("2d");
ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height);
const imageData = ctx.getImageData(
0,
0,
imageBitmap.width,
imageBitmap.height,
);
const pixelData = imageData.data;
const data = new Float32Array(
1 * 3 * imageBitmap.width * imageBitmap.height,
);
// Populate the Float32Array with normalized pixel values
for (let i = 0; i < pixelData.length; i += 4) {
// Normalize pixel values to the range [0, 1]
data[i / 4] = pixelData[i] / 255.0; // Red channel
data[i / 4 + imageBitmap.width * imageBitmap.height] =
pixelData[i + 1] / 255.0; // Green channel
data[i / 4 + 2 * imageBitmap.width * imageBitmap.height] =
pixelData[i + 2] / 255.0; // Blue channel
}
return {
data: data,
shape: [1, 3, imageBitmap.width, imageBitmap.height],
};
}
// The rowOutput is a Float32Array of shape [25200, 16], where each row represents a bounding box.
private getFacesFromYoloOutput(
rowOutput: Float32Array,
minScore: number,
): Array<FaceDetection> {
const faces: Array<FaceDetection> = [];
// iterate over each row
for (let i = 0; i < rowOutput.length; i += 16) {
const score = rowOutput[i + 4];
if (score < minScore) {
continue;
}
// The first 4 values represent the bounding box's coordinates (x1, y1, x2, y2)
const xCenter = rowOutput[i];
const yCenter = rowOutput[i + 1];
const width = rowOutput[i + 2];
const height = rowOutput[i + 3];
const xMin = xCenter - width / 2.0; // topLeft
const yMin = yCenter - height / 2.0; // topLeft
const leftEyeX = rowOutput[i + 5];
const leftEyeY = rowOutput[i + 6];
const rightEyeX = rowOutput[i + 7];
const rightEyeY = rowOutput[i + 8];
const noseX = rowOutput[i + 9];
const noseY = rowOutput[i + 10];
const leftMouthX = rowOutput[i + 11];
const leftMouthY = rowOutput[i + 12];
const rightMouthX = rowOutput[i + 13];
const rightMouthY = rowOutput[i + 14];
const box = new Box({
x: xMin,
y: yMin,
width: width,
height: height,
});
const probability = score as number;
const landmarks = [
new Point(leftEyeX, leftEyeY),
new Point(rightEyeX, rightEyeY),
new Point(noseX, noseY),
new Point(leftMouthX, leftMouthY),
new Point(rightMouthX, rightMouthY),
];
const face: FaceDetection = {
box,
landmarks,
probability,
// detectionMethod: this.method,
};
faces.push(face);
}
return faces;
}
public getRelativeDetection(
faceDetection: FaceDetection,
dimensions: Dimensions,
): FaceDetection {
const oldBox: Box = faceDetection.box;
const box = new Box({
x: oldBox.x / dimensions.width,
y: oldBox.y / dimensions.height,
width: oldBox.width / dimensions.width,
height: oldBox.height / dimensions.height,
});
const oldLandmarks: Point[] = faceDetection.landmarks;
const landmarks = oldLandmarks.map((l) => {
return new Point(l.x / dimensions.width, l.y / dimensions.height);
});
return {
box,
landmarks,
probability: faceDetection.probability,
};
}
private async estimateOnnx(imageBitmap: ImageBitmap) {
const maxFaceDistance = imageBitmap.width * MAX_FACE_DISTANCE_PERCENT;
const preprocessResult =
this.preprocessImageBitmapToFloat32ChannelsFirst(
imageBitmap,
640,
640,
);
const data = preprocessResult.data;
const resized = preprocessResult.newSize;
const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
feeds["input"] = inputTensor;
const inferenceSession = await this.getOnnxInferenceSession();
const runout = await inferenceSession.run(feeds);
const outputData = runout.output.data;
const faces = this.getFacesFromYoloOutput(
outputData as Float32Array,
0.7,
);
const inBox = newBox(0, 0, resized.width, resized.height);
const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height);
const transform = computeTransformToBox(inBox, toBox);
const faceDetections: Array<FaceDetection> = faces?.map((f) => {
const box = transformBox(f.box, transform);
const normLandmarks = f.landmarks;
const landmarks = transformPoints(normLandmarks, transform);
return {
box,
landmarks,
probability: f.probability as number,
} as FaceDetection;
});
return removeDuplicateDetections(faceDetections, maxFaceDistance);
}
public async detectFaces(
imageBitmap: ImageBitmap,
): Promise<Array<FaceDetection>> {
// measure time taken
const facesFromOnnx = await this.estimateOnnx(imageBitmap);
return facesFromOnnx;
}
public async dispose() {
const inferenceSession = await this.getOnnxInferenceSession();
inferenceSession?.release();
this.onnxInferenceSession = undefined;
}
}
export default new YoloFaceDetectionService();

View file

@ -332,8 +332,10 @@ function searchCollection(
}
function searchFilesByName(searchPhrase: string, files: EnteFile[]) {
return files.filter((file) =>
file.metadata.title.toLowerCase().includes(searchPhrase),
return files.filter(
(file) =>
file.id.toString().includes(searchPhrase) ||
file.metadata.title.toLowerCase().includes(searchPhrase),
);
}

View file

@ -6,3 +6,11 @@ export const ARCFACE_LANDMARKS = [
] as Array<[number, number]>;
export const ARCFACE_LANDMARKS_FACE_SIZE = 112;
export const ARC_FACE_5_LANDMARKS = [
[38.2946, 51.6963],
[73.5318, 51.5014],
[56.0252, 71.7366],
[41.5493, 92.3655],
[70.7299, 92.2041],
] as Array<[number, number]>;

View file

@ -0,0 +1,4 @@
export interface ClipEmbedding {
embedding: Float32Array;
model: "ggml-clip" | "onnx-clip";
}

View file

@ -0,0 +1,27 @@
/// [`x`] and [y] are the coordinates of the top left corner of the box, so the minimim values
/// [width] and [height] are the width and height of the box.
/// All values are in absolute pixels relative to the original image size.
export interface CenterBox {
x: number;
y: number;
height: number;
width: number;
}
export interface Point {
x: number;
y: number;
}
export interface Detection {
box: CenterBox;
landmarks: Point[];
}
export interface Face {
id: string;
confidence: number;
blur: number;
embedding: Float32Array;
detection: Detection;
}

View file

@ -0,0 +1,12 @@
import { ClipEmbedding } from "./clip";
import { Face } from "./face";
export interface FileML {
fileID: number;
clip?: ClipEmbedding;
faces: Face[];
height: number;
width: number;
version: number;
error?: string;
}

View file

@ -1,4 +1,5 @@
import * as tf from "@tensorflow/tfjs-core";
import { DebugInfo } from "hdbscan";
import PQueue from "p-queue";
import { EnteFile } from "types/file";
@ -15,6 +16,14 @@ export interface MLSyncResult {
error?: Error;
}
export interface DebugFace {
fileId: string;
// face: FaceApiResult;
face: AlignedFace;
embedding: FaceEmbedding;
faceImage: FaceImage;
}
export declare type FaceImage = Array<Array<Array<number>>>;
export declare type FaceImageBlob = Blob;
@ -50,7 +59,10 @@ export declare type Landmark = Point;
export declare type ImageType = "Original" | "Preview";
export declare type FaceDetectionMethod = "BlazeFace" | "FaceApiSSD";
export declare type FaceDetectionMethod =
| "BlazeFace"
| "FaceApiSSD"
| "YoloFace";
export declare type ObjectDetectionMethod = "SSDMobileNetV2";
@ -65,6 +77,8 @@ export declare type FaceAlignmentMethod =
export declare type FaceEmbeddingMethod = "MobileFaceNet" | "FaceApiDlib";
export declare type BlurDetectionMethod = "Laplacian";
export declare type ClusteringMethod = "Hdbscan" | "Dbscan";
export class AlignedBox {
@ -120,6 +134,7 @@ export interface FaceAlignment {
export interface AlignedFace extends CroppedFace {
alignment?: FaceAlignment;
blurValue?: number;
}
export declare type FaceEmbedding = Float32Array;
@ -215,6 +230,11 @@ export interface FaceAlignmentConfig {
method: FaceAlignmentMethod;
}
export interface BlurDetectionConfig {
method: BlurDetectionMethod;
threshold: number;
}
export interface FaceEmbeddingConfig {
method: FaceEmbeddingMethod;
faceSize: number;
@ -241,6 +261,7 @@ export interface MLSyncConfig {
faceDetection: FaceDetectionConfig;
faceCrop: FaceCropConfig;
faceAlignment: FaceAlignmentConfig;
blurDetection: BlurDetectionConfig;
faceEmbedding: FaceEmbeddingConfig;
faceClustering: FaceClusteringConfig;
objectDetection: ObjectDetectionConfig;
@ -263,6 +284,7 @@ export interface MLSyncContext {
faceCropService: FaceCropService;
faceAlignmentService: FaceAlignmentService;
faceEmbeddingService: FaceEmbeddingService;
blurDetectionService: BlurDetectionService;
faceClusteringService: ClusteringService;
objectDetectionService: ObjectDetectionService;
sceneDetectionService: SceneDetectionService;
@ -312,6 +334,10 @@ export interface FaceDetectionService {
method: Versioned<FaceDetectionMethod>;
// init(): Promise<void>;
detectFaces(image: ImageBitmap): Promise<Array<FaceDetection>>;
getRelativeDetection(
faceDetection: FaceDetection,
imageDimensions: Dimensions,
): FaceDetection;
dispose(): Promise<void>;
}
@ -354,12 +380,15 @@ export interface FaceEmbeddingService {
method: Versioned<FaceEmbeddingMethod>;
faceSize: number;
// init(): Promise<void>;
getFaceEmbeddings(
faceImages: Array<ImageBitmap>,
): Promise<Array<FaceEmbedding>>;
getFaceEmbeddings(faceImages: Float32Array): Promise<Array<FaceEmbedding>>;
dispose(): Promise<void>;
}
export interface BlurDetectionService {
method: Versioned<BlurDetectionMethod>;
detectBlur(alignedFaces: Float32Array): number[];
}
export interface ClusteringService {
method: Versioned<ClusteringMethod>;
@ -396,18 +425,3 @@ export interface MachineLearningWorker {
close(): void;
}
// export class TFImageBitmap {
// imageBitmap: ImageBitmap;
// tfImage: tf.Tensor3D;
// constructor(imageBitmap: ImageBitmap, tfImage: tf.Tensor3D) {
// this.imageBitmap = imageBitmap;
// this.tfImage = tfImage;
// }
// async dispose() {
// this.tfImage && (await tf.dispose(this.tfImage));
// this.imageBitmap && this.imageBitmap.close();
// }
// }

View file

@ -1,9 +1,324 @@
// these utils only work in env where OffscreenCanvas is available
import { Matrix, inverse } from "ml-matrix";
import { BlobOptions, Dimensions } from "types/image";
import { FaceAlignment } from "types/machineLearning";
import { enlargeBox } from "utils/machineLearning";
import { Box } from "../../../thirdparty/face-api/classes";
export function normalizePixelBetween0And1(pixelValue: number) {
return pixelValue / 255.0;
}
export function normalizePixelBetweenMinus1And1(pixelValue: number) {
return pixelValue / 127.5 - 1.0;
}
export function unnormalizePixelFromBetweenMinus1And1(pixelValue: number) {
return clamp(Math.round((pixelValue + 1.0) * 127.5), 0, 255);
}
export function readPixelColor(
imageData: Uint8ClampedArray,
width: number,
height: number,
x: number,
y: number,
) {
if (x < 0 || x >= width || y < 0 || y >= height) {
return { r: 0, g: 0, b: 0, a: 0 };
}
const index = (y * width + x) * 4;
return {
r: imageData[index],
g: imageData[index + 1],
b: imageData[index + 2],
a: imageData[index + 3],
};
}
export function clamp(value: number, min: number, max: number) {
return Math.min(max, Math.max(min, value));
}
export function getPixelBicubic(
fx: number,
fy: number,
imageData: Uint8ClampedArray,
imageWidth: number,
imageHeight: number,
) {
// Clamp to image boundaries
fx = clamp(fx, 0, imageWidth - 1);
fy = clamp(fy, 0, imageHeight - 1);
const x = Math.trunc(fx) - (fx >= 0.0 ? 0 : 1);
const px = x - 1;
const nx = x + 1;
const ax = x + 2;
const y = Math.trunc(fy) - (fy >= 0.0 ? 0 : 1);
const py = y - 1;
const ny = y + 1;
const ay = y + 2;
const dx = fx - x;
const dy = fy - y;
function cubic(
dx: number,
ipp: number,
icp: number,
inp: number,
iap: number,
) {
return (
icp +
0.5 *
(dx * (-ipp + inp) +
dx * dx * (2 * ipp - 5 * icp + 4 * inp - iap) +
dx * dx * dx * (-ipp + 3 * icp - 3 * inp + iap))
);
}
const icc = readPixelColor(imageData, imageWidth, imageHeight, x, y);
const ipp =
px < 0 || py < 0
? icc
: readPixelColor(imageData, imageWidth, imageHeight, px, py);
const icp =
px < 0
? icc
: readPixelColor(imageData, imageWidth, imageHeight, x, py);
const inp =
py < 0 || nx >= imageWidth
? icc
: readPixelColor(imageData, imageWidth, imageHeight, nx, py);
const iap =
ax >= imageWidth || py < 0
? icc
: readPixelColor(imageData, imageWidth, imageHeight, ax, py);
const ip0 = cubic(dx, ipp.r, icp.r, inp.r, iap.r);
const ip1 = cubic(dx, ipp.g, icp.g, inp.g, iap.g);
const ip2 = cubic(dx, ipp.b, icp.b, inp.b, iap.b);
// const ip3 = cubic(dx, ipp.a, icp.a, inp.a, iap.a);
const ipc =
px < 0
? icc
: readPixelColor(imageData, imageWidth, imageHeight, px, y);
const inc =
nx >= imageWidth
? icc
: readPixelColor(imageData, imageWidth, imageHeight, nx, y);
const iac =
ax >= imageWidth
? icc
: readPixelColor(imageData, imageWidth, imageHeight, ax, y);
const ic0 = cubic(dx, ipc.r, icc.r, inc.r, iac.r);
const ic1 = cubic(dx, ipc.g, icc.g, inc.g, iac.g);
const ic2 = cubic(dx, ipc.b, icc.b, inc.b, iac.b);
// const ic3 = cubic(dx, ipc.a, icc.a, inc.a, iac.a);
const ipn =
px < 0 || ny >= imageHeight
? icc
: readPixelColor(imageData, imageWidth, imageHeight, px, ny);
const icn =
ny >= imageHeight
? icc
: readPixelColor(imageData, imageWidth, imageHeight, x, ny);
const inn =
nx >= imageWidth || ny >= imageHeight
? icc
: readPixelColor(imageData, imageWidth, imageHeight, nx, ny);
const ian =
ax >= imageWidth || ny >= imageHeight
? icc
: readPixelColor(imageData, imageWidth, imageHeight, ax, ny);
const in0 = cubic(dx, ipn.r, icn.r, inn.r, ian.r);
const in1 = cubic(dx, ipn.g, icn.g, inn.g, ian.g);
const in2 = cubic(dx, ipn.b, icn.b, inn.b, ian.b);
// const in3 = cubic(dx, ipn.a, icn.a, inn.a, ian.a);
const ipa =
px < 0 || ay >= imageHeight
? icc
: readPixelColor(imageData, imageWidth, imageHeight, px, ay);
const ica =
ay >= imageHeight
? icc
: readPixelColor(imageData, imageWidth, imageHeight, x, ay);
const ina =
nx >= imageWidth || ay >= imageHeight
? icc
: readPixelColor(imageData, imageWidth, imageHeight, nx, ay);
const iaa =
ax >= imageWidth || ay >= imageHeight
? icc
: readPixelColor(imageData, imageWidth, imageHeight, ax, ay);
const ia0 = cubic(dx, ipa.r, ica.r, ina.r, iaa.r);
const ia1 = cubic(dx, ipa.g, ica.g, ina.g, iaa.g);
const ia2 = cubic(dx, ipa.b, ica.b, ina.b, iaa.b);
// const ia3 = cubic(dx, ipa.a, ica.a, ina.a, iaa.a);
const c0 = Math.trunc(clamp(cubic(dy, ip0, ic0, in0, ia0), 0, 255));
const c1 = Math.trunc(clamp(cubic(dy, ip1, ic1, in1, ia1), 0, 255));
const c2 = Math.trunc(clamp(cubic(dy, ip2, ic2, in2, ia2), 0, 255));
// const c3 = cubic(dy, ip3, ic3, in3, ia3);
return { r: c0, g: c1, b: c2 };
}
/// Returns the pixel value (RGB) at the given coordinates using bilinear interpolation.
export function getPixelBilinear(
fx: number,
fy: number,
imageData: Uint8ClampedArray,
imageWidth: number,
imageHeight: number,
) {
// Clamp to image boundaries
fx = clamp(fx, 0, imageWidth - 1);
fy = clamp(fy, 0, imageHeight - 1);
// Get the surrounding coordinates and their weights
const x0 = Math.floor(fx);
const x1 = Math.ceil(fx);
const y0 = Math.floor(fy);
const y1 = Math.ceil(fy);
const dx = fx - x0;
const dy = fy - y0;
const dx1 = 1.0 - dx;
const dy1 = 1.0 - dy;
// Get the original pixels
const pixel1 = readPixelColor(imageData, imageWidth, imageHeight, x0, y0);
const pixel2 = readPixelColor(imageData, imageWidth, imageHeight, x1, y0);
const pixel3 = readPixelColor(imageData, imageWidth, imageHeight, x0, y1);
const pixel4 = readPixelColor(imageData, imageWidth, imageHeight, x1, y1);
function bilinear(val1: number, val2: number, val3: number, val4: number) {
return Math.round(
val1 * dx1 * dy1 +
val2 * dx * dy1 +
val3 * dx1 * dy +
val4 * dx * dy,
);
}
// Interpolate the pixel values
const red = bilinear(pixel1.r, pixel2.r, pixel3.r, pixel4.r);
const green = bilinear(pixel1.g, pixel2.g, pixel3.g, pixel4.g);
const blue = bilinear(pixel1.b, pixel2.b, pixel3.b, pixel4.b);
return { r: red, g: green, b: blue };
}
export function warpAffineFloat32List(
imageBitmap: ImageBitmap,
faceAlignment: FaceAlignment,
faceSize: number,
inputData: Float32Array,
inputStartIndex: number,
): void {
// Get the pixel data
const offscreenCanvas = new OffscreenCanvas(
imageBitmap.width,
imageBitmap.height,
);
const ctx = offscreenCanvas.getContext("2d");
ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height);
const imageData = ctx.getImageData(
0,
0,
imageBitmap.width,
imageBitmap.height,
);
const pixelData = imageData.data;
const transformationMatrix = faceAlignment.affineMatrix.map((row) =>
row.map((val) => (val != 1.0 ? val * faceSize : 1.0)),
); // 3x3
const A: Matrix = new Matrix([
[transformationMatrix[0][0], transformationMatrix[0][1]],
[transformationMatrix[1][0], transformationMatrix[1][1]],
]);
const Ainverse = inverse(A);
const b00 = transformationMatrix[0][2];
const b10 = transformationMatrix[1][2];
const a00Prime = Ainverse.get(0, 0);
const a01Prime = Ainverse.get(0, 1);
const a10Prime = Ainverse.get(1, 0);
const a11Prime = Ainverse.get(1, 1);
for (let yTrans = 0; yTrans < faceSize; ++yTrans) {
for (let xTrans = 0; xTrans < faceSize; ++xTrans) {
// Perform inverse affine transformation
const xOrigin =
a00Prime * (xTrans - b00) + a01Prime * (yTrans - b10);
const yOrigin =
a10Prime * (xTrans - b00) + a11Prime * (yTrans - b10);
// Get the pixel from interpolation
const pixel = getPixelBicubic(
xOrigin,
yOrigin,
pixelData,
imageBitmap.width,
imageBitmap.height,
);
// Set the pixel in the input data
const index = (yTrans * faceSize + xTrans) * 3;
inputData[inputStartIndex + index] =
normalizePixelBetweenMinus1And1(pixel.r);
inputData[inputStartIndex + index + 1] =
normalizePixelBetweenMinus1And1(pixel.g);
inputData[inputStartIndex + index + 2] =
normalizePixelBetweenMinus1And1(pixel.b);
}
}
}
export function createGrayscaleIntMatrixFromNormalized2List(
imageList: Float32Array,
faceNumber: number,
width: number = 112,
height: number = 112,
): number[][] {
const startIndex = faceNumber * width * height * 3;
return Array.from({ length: height }, (_, y) =>
Array.from({ length: width }, (_, x) => {
// 0.299 ∙ Red + 0.587 ∙ Green + 0.114 ∙ Blue
const pixelIndex = startIndex + 3 * (y * width + x);
return clamp(
Math.round(
0.299 *
unnormalizePixelFromBetweenMinus1And1(
imageList[pixelIndex],
) +
0.587 *
unnormalizePixelFromBetweenMinus1And1(
imageList[pixelIndex + 1],
) +
0.114 *
unnormalizePixelFromBetweenMinus1And1(
imageList[pixelIndex + 2],
),
),
0,
255,
);
}),
);
}
export function resizeToSquare(img: ImageBitmap, size: number) {
const scale = size / Math.max(img.height, img.width);
const width = scale * img.width;

View file

@ -6,6 +6,7 @@ import { FaceAlignment, FaceDetection } from "types/machineLearning";
import {
ARCFACE_LANDMARKS,
ARCFACE_LANDMARKS_FACE_SIZE,
ARC_FACE_5_LANDMARKS,
} from "types/machineLearning/archface";
import { cropWithRotation, transform } from "utils/image";
import {
@ -21,7 +22,7 @@ import { Box, Point } from "../../../thirdparty/face-api/classes";
export function normalizeLandmarks(
landmarks: Array<[number, number]>,
faceSize: number,
) {
): Array<[number, number]> {
return landmarks.map((landmark) =>
landmark.map((p) => p / faceSize),
) as Array<[number, number]>;
@ -74,9 +75,13 @@ export function getFaceAlignmentUsingSimilarityTransform(
export function getArcfaceAlignment(
faceDetection: FaceDetection,
): FaceAlignment {
const landmarkCount = faceDetection.landmarks.length;
return getFaceAlignmentUsingSimilarityTransform(
faceDetection,
normalizeLandmarks(ARCFACE_LANDMARKS, ARCFACE_LANDMARKS_FACE_SIZE),
normalizeLandmarks(
landmarkCount === 5 ? ARC_FACE_5_LANDMARKS : ARCFACE_LANDMARKS,
ARCFACE_LANDMARKS_FACE_SIZE,
),
);
}
@ -161,6 +166,7 @@ export function ibExtractFaceImage(
);
}
// Used in MLDebugViewOnly
export function ibExtractFaceImageUsingTransform(
image: ImageBitmap,
alignment: FaceAlignment,
@ -183,42 +189,6 @@ export function ibExtractFaceImages(
);
}
export function extractArcfaceAlignedFaceImage(
image: tf.Tensor4D,
faceDetection: FaceDetection,
faceSize: number,
): tf.Tensor4D {
const alignment = getFaceAlignmentUsingSimilarityTransform(
faceDetection,
ARCFACE_LANDMARKS,
);
return extractFaceImage(image, alignment, faceSize);
}
export function extractArcfaceAlignedFaceImages(
image: tf.Tensor3D | tf.Tensor4D,
faceDetections: Array<FaceDetection>,
faceSize: number,
): tf.Tensor4D {
return tf.tidy(() => {
const tf4dFloat32Image = toTensor4D(image, "float32");
const faceImages = new Array<tf.Tensor3D>(faceDetections.length);
for (let i = 0; i < faceDetections.length; i++) {
faceImages[i] = tf.squeeze(
extractArcfaceAlignedFaceImage(
tf4dFloat32Image,
faceDetections[i],
faceSize,
),
[0],
);
}
return tf.stack(faceImages) as tf.Tensor4D;
});
}
const BLAZEFACE_LEFT_EYE_INDEX = 0;
const BLAZEFACE_RIGHT_EYE_INDEX = 1;
// const BLAZEFACE_NOSE_INDEX = 2;

View file

@ -35,6 +35,18 @@ export function getDetectionCenter(detection: FaceDetection) {
return center.div({ x: 4, y: 4 });
}
/**
* Finds the nearest face detection from a list of detections to a specified detection.
*
* This function calculates the center of each detection and then finds the detection whose center is nearest to the center of the specified detection.
* If a maximum distance is specified, only detections within that distance are considered.
*
* @param toDetection - The face detection to find the nearest detection to.
* @param fromDetections - An array of face detections to search in.
* @param maxDistance - The maximum distance between the centers of the two detections for a detection to be considered. If not specified, all detections are considered.
*
* @returns The nearest face detection from the list, or `undefined` if no detection is within the maximum distance.
*/
export function getNearestDetection(
toDetection: FaceDetection,
fromDetections: Array<FaceDetection>,
@ -47,7 +59,18 @@ export function getNearestDetection(
return nearestIndex >= 0 && fromDetections[nearestIndex];
}
// TODO: can also be done through tf.image.nonMaxSuppression
/**
* Removes duplicate face detections from an array of detections.
*
* This function sorts the detections by their probability in descending order, then iterates over them.
* For each detection, it calculates the Euclidean distance to all other detections.
* If the distance is less than or equal to the specified threshold (`withinDistance`), the other detection is considered a duplicate and is removed.
*
* @param detections - An array of face detections to remove duplicates from.
* @param withinDistance - The maximum Euclidean distance between two detections for them to be considered duplicates.
*
* @returns An array of face detections with duplicates removed.
*/
export function removeDuplicateDetections(
detections: Array<FaceDetection>,
withinDistance: number,

View file

@ -17,6 +17,7 @@ import {
DetectedFace,
DetectedObject,
Face,
FaceAlignment,
FaceImageBlob,
MlFileData,
Person,
@ -24,18 +25,11 @@ import {
Versioned,
} from "types/machineLearning";
import { getRenderableImage } from "utils/file";
import { imageBitmapToBlob } from "utils/image";
import { clamp, imageBitmapToBlob, warpAffineFloat32List } from "utils/image";
import mlIDbStorage from "utils/storage/mlIDbStorage";
import { Box, Point } from "../../../thirdparty/face-api/classes";
import {
getArcfaceAlignment,
ibExtractFaceImage,
ibExtractFaceImages,
} from "./faceAlign";
import {
getFaceCropBlobFromStorage,
ibExtractFaceImagesFromCrops,
} from "./faceCrop";
import { ibExtractFaceImage, ibExtractFaceImages } from "./faceAlign";
import { getFaceCropBlobFromStorage } from "./faceCrop";
export function f32Average(descriptors: Float32Array[]) {
if (descriptors.length < 1) {
@ -241,9 +235,10 @@ export async function extractFaceImages(
faceSize: number,
image?: ImageBitmap,
) {
if (faces.length === faces.filter((f) => f.crop).length) {
return ibExtractFaceImagesFromCrops(faces, faceSize);
} else if (image) {
// if (faces.length === faces.filter((f) => f.crop).length) {
// return ibExtractFaceImagesFromCrops(faces, faceSize);
// } else
if (image) {
const faceAlignments = faces.map((f) => f.alignment);
return ibExtractFaceImages(image, faceAlignments, faceSize);
} else {
@ -253,31 +248,68 @@ export async function extractFaceImages(
}
}
export async function extractFaceImagesToFloat32(
faceAlignments: Array<FaceAlignment>,
faceSize: number,
image: ImageBitmap,
): Promise<Float32Array> {
const faceData = new Float32Array(
faceAlignments.length * faceSize * faceSize * 3,
);
for (let i = 0; i < faceAlignments.length; i++) {
const alignedFace = faceAlignments[i];
const faceDataOffset = i * faceSize * faceSize * 3;
warpAffineFloat32List(
image,
alignedFace,
faceSize,
faceData,
faceDataOffset,
);
}
return faceData;
}
export function leftFillNum(num: number, length: number, padding: number) {
return num.toString().padStart(length, padding.toString());
}
// TODO: same face can not be only based on this id,
// this gives same id to faces whose arcface center lies in same box of 1% image grid
// maximum distance for same id will be around √2%
// will give same id in most of the cases, except for face centers lying near grid edges
// faces with same id should be treated as same face, and diffrent id should be tested further
// further test can rely on nearest face within certain threshold in same image
// can also explore spatial index similar to Geohash for indexing, but overkill
// for mostly single digit faces in one image
// also check if this needs to be globally unique or unique for a user
export function getFaceId(detectedFace: DetectedFace, imageDims: Dimensions) {
const arcFaceAlignedFace = getArcfaceAlignment(detectedFace.detection);
const imgDimPoint = new Point(imageDims.width, imageDims.height);
const gridPt = arcFaceAlignedFace.center
.mul(new Point(100, 100))
.div(imgDimPoint)
.floor()
.bound(0, 99);
const gridPaddedX = leftFillNum(gridPt.x, 2, 0);
const gridPaddedY = leftFillNum(gridPt.y, 2, 0);
const xMin = clamp(
detectedFace.detection.box.x / imageDims.width,
0.0,
0.999999,
)
.toFixed(5)
.substring(2);
const yMin = clamp(
detectedFace.detection.box.y / imageDims.height,
0.0,
0.999999,
)
.toFixed(5)
.substring(2);
const xMax = clamp(
(detectedFace.detection.box.x + detectedFace.detection.box.width) /
imageDims.width,
0.0,
0.999999,
)
.toFixed(5)
.substring(2);
const yMax = clamp(
(detectedFace.detection.box.y + detectedFace.detection.box.height) /
imageDims.height,
0.0,
0.999999,
)
.toFixed(5)
.substring(2);
return `${detectedFace.fileId}-${gridPaddedX}-${gridPaddedY}`;
const rawFaceID = `${xMin}_${yMin}_${xMax}_${yMax}`;
const faceID = `${detectedFace.fileId}_${rawFaceID}`;
return faceID;
}
export function getObjectId(

View file

@ -24,7 +24,8 @@ module.exports = {
"max-len": "off",
"new-cap": "off",
"no-invalid-this": "off",
eqeqeq: "error",
// TODO(MR): We want this off anyway, for now forcing it here
eqeqeq: "off",
"object-curly-spacing": ["error", "always"],
"space-before-function-paren": "off",
"operator-linebreak": [