Manav Rathi 2024-05-16 13:25:47 +05:30
parent 36af1cfacd
commit 839b4c04a9
No known key found for this signature in database
4 changed files with 212 additions and 257 deletions


@@ -7,12 +7,7 @@ import {
boxFromBoundingBox,
newBox,
} from "services/face/geom";
import {
FaceDetection,
FaceDetectionMethod,
FaceDetectionService,
Versioned,
} from "services/face/types";
import { FaceDetection } from "services/face/types";
import {
Matrix,
applyToPoint,
@@ -26,222 +21,208 @@ import {
normalizePixelBetween0And1,
} from "utils/image";
class YoloFaceDetectionService implements FaceDetectionService {
public method: Versioned<FaceDetectionMethod>;
public constructor() {
this.method = {
value: "YoloFace",
version: 1,
};
}
public async detectFaces(
imageBitmap: ImageBitmap,
): Promise<Array<FaceDetection>> {
const maxFaceDistancePercent = Math.sqrt(2) / 100;
const maxFaceDistance = imageBitmap.width * maxFaceDistancePercent;
const preprocessResult =
this.preprocessImageBitmapToFloat32ChannelsFirst(
imageBitmap,
640,
640,
);
const data = preprocessResult.data;
const resized = preprocessResult.newSize;
const outputData = await workerBridge.detectFaces(data);
const faces = this.getFacesFromYoloOutput(
outputData as Float32Array,
0.7,
);
const inBox = newBox(0, 0, resized.width, resized.height);
const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height);
const transform = computeTransformToBox(inBox, toBox);
const faceDetections: Array<FaceDetection> = faces?.map((f) => {
const box = transformBox(f.box, transform);
const normLandmarks = f.landmarks;
const landmarks = transformPoints(normLandmarks, transform);
return {
box,
landmarks,
probability: f.probability as number,
} as FaceDetection;
});
return removeDuplicateDetections(faceDetections, maxFaceDistance);
}
private preprocessImageBitmapToFloat32ChannelsFirst(
imageBitmap: ImageBitmap,
requiredWidth: number,
requiredHeight: number,
maintainAspectRatio: boolean = true,
normFunction: (
pixelValue: number,
) => number = normalizePixelBetween0And1,
) {
// Create an OffscreenCanvas and set its size
const offscreenCanvas = new OffscreenCanvas(
imageBitmap.width,
imageBitmap.height,
);
const ctx = offscreenCanvas.getContext("2d");
ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height);
const imageData = ctx.getImageData(
0,
0,
imageBitmap.width,
imageBitmap.height,
);
const pixelData = imageData.data;
let scaleW = requiredWidth / imageBitmap.width;
let scaleH = requiredHeight / imageBitmap.height;
if (maintainAspectRatio) {
const scale = Math.min(
requiredWidth / imageBitmap.width,
requiredHeight / imageBitmap.height,
);
scaleW = scale;
scaleH = scale;
}
const scaledWidth = clamp(
Math.round(imageBitmap.width * scaleW),
0,
requiredWidth,
);
const scaledHeight = clamp(
Math.round(imageBitmap.height * scaleH),
0,
requiredHeight,
);
const processedImage = new Float32Array(
1 * 3 * requiredWidth * requiredHeight,
);
// Populate the Float32Array with normalized pixel values
let pixelIndex = 0;
const channelOffsetGreen = requiredHeight * requiredWidth;
const channelOffsetBlue = 2 * requiredHeight * requiredWidth;
for (let h = 0; h < requiredHeight; h++) {
for (let w = 0; w < requiredWidth; w++) {
let pixel: {
r: number;
g: number;
b: number;
};
if (w >= scaledWidth || h >= scaledHeight) {
pixel = { r: 114, g: 114, b: 114 };
} else {
pixel = getPixelBilinear(
w / scaleW,
h / scaleH,
pixelData,
imageBitmap.width,
imageBitmap.height,
);
}
processedImage[pixelIndex] = normFunction(pixel.r);
processedImage[pixelIndex + channelOffsetGreen] = normFunction(
pixel.g,
);
processedImage[pixelIndex + channelOffsetBlue] = normFunction(
pixel.b,
);
pixelIndex++;
}
}
return {
data: processedImage,
originalSize: {
width: imageBitmap.width,
height: imageBitmap.height,
},
newSize: { width: scaledWidth, height: scaledHeight },
};
}
// The rowOutput is a Float32Array of shape [25200, 16], where each row represents a bounding box.
private getFacesFromYoloOutput(
rowOutput: Float32Array,
minScore: number,
): Array<FaceDetection> {
const faces: Array<FaceDetection> = [];
// iterate over each row
for (let i = 0; i < rowOutput.length; i += 16) {
const score = rowOutput[i + 4];
if (score < minScore) {
continue;
}
// The first 4 values represent the bounding box's center and size (xCenter, yCenter, width, height)
const xCenter = rowOutput[i];
const yCenter = rowOutput[i + 1];
const width = rowOutput[i + 2];
const height = rowOutput[i + 3];
const xMin = xCenter - width / 2.0; // topLeft
const yMin = yCenter - height / 2.0; // topLeft
const leftEyeX = rowOutput[i + 5];
const leftEyeY = rowOutput[i + 6];
const rightEyeX = rowOutput[i + 7];
const rightEyeY = rowOutput[i + 8];
const noseX = rowOutput[i + 9];
const noseY = rowOutput[i + 10];
const leftMouthX = rowOutput[i + 11];
const leftMouthY = rowOutput[i + 12];
const rightMouthX = rowOutput[i + 13];
const rightMouthY = rowOutput[i + 14];
const box = new Box({
x: xMin,
y: yMin,
width: width,
height: height,
});
const probability = score as number;
const landmarks = [
new Point(leftEyeX, leftEyeY),
new Point(rightEyeX, rightEyeY),
new Point(noseX, noseY),
new Point(leftMouthX, leftMouthY),
new Point(rightMouthX, rightMouthY),
];
const face: FaceDetection = {
box,
landmarks,
probability,
// detectionMethod: this.method,
};
faces.push(face);
}
return faces;
}
public getRelativeDetection(
faceDetection: FaceDetection,
dimensions: Dimensions,
): FaceDetection {
const oldBox: Box = faceDetection.box;
const box = new Box({
x: oldBox.x / dimensions.width,
y: oldBox.y / dimensions.height,
width: oldBox.width / dimensions.width,
height: oldBox.height / dimensions.height,
});
const oldLandmarks: Point[] = faceDetection.landmarks;
const landmarks = oldLandmarks.map((l) => {
return new Point(l.x / dimensions.width, l.y / dimensions.height);
});
return {
box,
landmarks,
probability: faceDetection.probability,
};
}
}
export default new YoloFaceDetectionService();
/**
* Detect faces in the given {@link imageBitmap}.
*
* The ML model used is YOLO, running in an ONNX runtime.
*/
export const detectFaces = async (
imageBitmap: ImageBitmap,
): Promise<Array<FaceDetection>> => {
const maxFaceDistancePercent = Math.sqrt(2) / 100;
const maxFaceDistance = imageBitmap.width * maxFaceDistancePercent;
const preprocessResult = preprocessImageBitmapToFloat32ChannelsFirst(
imageBitmap,
640,
640,
);
const data = preprocessResult.data;
const resized = preprocessResult.newSize;
const outputData = await workerBridge.detectFaces(data);
const faces = getFacesFromYoloOutput(outputData as Float32Array, 0.7);
const inBox = newBox(0, 0, resized.width, resized.height);
const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height);
const transform = computeTransformToBox(inBox, toBox);
const faceDetections: Array<FaceDetection> = faces?.map((f) => {
const box = transformBox(f.box, transform);
const normLandmarks = f.landmarks;
const landmarks = transformPoints(normLandmarks, transform);
return {
box,
landmarks,
probability: f.probability as number,
} as FaceDetection;
});
return removeDuplicateDetections(faceDetections, maxFaceDistance);
};
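For context, a minimal calling sketch (not part of this commit; obtaining the bitmap via createImageBitmap on a Blob is one assumed possibility):

const logDetections = async (blob: Blob) => {
    const imageBitmap = await createImageBitmap(blob);
    const detections = await detectFaces(imageBitmap);
    for (const { box, probability } of detections)
        console.log(`face at (${box.x}, ${box.y}), p = ${probability}`);
};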
const preprocessImageBitmapToFloat32ChannelsFirst = (
imageBitmap: ImageBitmap,
requiredWidth: number,
requiredHeight: number,
maintainAspectRatio: boolean = true,
normFunction: (pixelValue: number) => number = normalizePixelBetween0And1,
) => {
// Create an OffscreenCanvas and set its size
const offscreenCanvas = new OffscreenCanvas(
imageBitmap.width,
imageBitmap.height,
);
const ctx = offscreenCanvas.getContext("2d");
ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height);
const imageData = ctx.getImageData(
0,
0,
imageBitmap.width,
imageBitmap.height,
);
const pixelData = imageData.data;
let scaleW = requiredWidth / imageBitmap.width;
let scaleH = requiredHeight / imageBitmap.height;
if (maintainAspectRatio) {
const scale = Math.min(
requiredWidth / imageBitmap.width,
requiredHeight / imageBitmap.height,
);
scaleW = scale;
scaleH = scale;
}
const scaledWidth = clamp(
Math.round(imageBitmap.width * scaleW),
0,
requiredWidth,
);
const scaledHeight = clamp(
Math.round(imageBitmap.height * scaleH),
0,
requiredHeight,
);
const processedImage = new Float32Array(
1 * 3 * requiredWidth * requiredHeight,
);
// Populate the Float32Array with normalized pixel values
let pixelIndex = 0;
const channelOffsetGreen = requiredHeight * requiredWidth;
const channelOffsetBlue = 2 * requiredHeight * requiredWidth;
for (let h = 0; h < requiredHeight; h++) {
for (let w = 0; w < requiredWidth; w++) {
let pixel: {
r: number;
g: number;
b: number;
};
if (w >= scaledWidth || h >= scaledHeight) {
pixel = { r: 114, g: 114, b: 114 };
} else {
pixel = getPixelBilinear(
w / scaleW,
h / scaleH,
pixelData,
imageBitmap.width,
imageBitmap.height,
);
}
processedImage[pixelIndex] = normFunction(pixel.r);
processedImage[pixelIndex + channelOffsetGreen] = normFunction(
pixel.g,
);
processedImage[pixelIndex + channelOffsetBlue] = normFunction(
pixel.b,
);
pixelIndex++;
}
}
return {
data: processedImage,
originalSize: {
width: imageBitmap.width,
height: imageBitmap.height,
},
newSize: { width: scaledWidth, height: scaledHeight },
};
};
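The buffer returned above is channels-first (CHW): the full red plane, then green, then blue, each in row-major order. A sketch of the equivalent index arithmetic (illustrative helper, not in the source):

// Offset of channel c (0 = R, 1 = G, 2 = B) for pixel (w, h) in a CHW
// Float32Array covering a width × height image.
const chwOffset = (c: number, w: number, h: number, width: number, height: number) =>
    c * width * height + h * width + w;

For c = 1 this reduces to channelOffsetGreen + pixelIndex in the loop above.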
// The rowOutput is a Float32Array of shape [25200, 16], where each row represents a bounding box.
const getFacesFromYoloOutput = (
rowOutput: Float32Array,
minScore: number,
): Array<FaceDetection> => {
const faces: Array<FaceDetection> = [];
// iterate over each row
for (let i = 0; i < rowOutput.length; i += 16) {
const score = rowOutput[i + 4];
if (score < minScore) {
continue;
}
// The first 4 values represent the bounding box's center and size (xCenter, yCenter, width, height)
const xCenter = rowOutput[i];
const yCenter = rowOutput[i + 1];
const width = rowOutput[i + 2];
const height = rowOutput[i + 3];
const xMin = xCenter - width / 2.0; // topLeft
const yMin = yCenter - height / 2.0; // topLeft
const leftEyeX = rowOutput[i + 5];
const leftEyeY = rowOutput[i + 6];
const rightEyeX = rowOutput[i + 7];
const rightEyeY = rowOutput[i + 8];
const noseX = rowOutput[i + 9];
const noseY = rowOutput[i + 10];
const leftMouthX = rowOutput[i + 11];
const leftMouthY = rowOutput[i + 12];
const rightMouthX = rowOutput[i + 13];
const rightMouthY = rowOutput[i + 14];
const box = new Box({
x: xMin,
y: yMin,
width: width,
height: height,
});
const probability = score as number;
const landmarks = [
new Point(leftEyeX, leftEyeY),
new Point(rightEyeX, rightEyeY),
new Point(noseX, noseY),
new Point(leftMouthX, leftMouthY),
new Point(rightMouthX, rightMouthY),
];
const face: FaceDetection = {
box,
landmarks,
probability,
// detectionMethod: this.method,
};
faces.push(face);
}
return faces;
};
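Put differently, each 16-float row is (xCenter, yCenter, width, height, score, then five landmark x/y pairs); the final element (index 15) is not read by this code. An illustrative decoder for a single row (helper not in the source):

const decodeYoloRow = (row: Float32Array): FaceDetection => ({
    box: new Box({
        x: row[0] - row[2] / 2,
        y: row[1] - row[3] / 2,
        width: row[2],
        height: row[3],
    }),
    // leftEye, rightEye, nose, leftMouth, rightMouth
    landmarks: [5, 7, 9, 11, 13].map((j) => new Point(row[j], row[j + 1])),
    probability: row[4],
});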
export const getRelativeDetection = (
faceDetection: FaceDetection,
dimensions: Dimensions,
): FaceDetection => {
const oldBox: Box = faceDetection.box;
const box = new Box({
x: oldBox.x / dimensions.width,
y: oldBox.y / dimensions.height,
width: oldBox.width / dimensions.width,
height: oldBox.height / dimensions.height,
});
const oldLandmarks: Point[] = faceDetection.landmarks;
const landmarks = oldLandmarks.map((l) => {
return new Point(l.x / dimensions.width, l.y / dimensions.height);
});
return {
box,
landmarks,
probability: faceDetection.probability,
};
};
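A worked example of the normalization (values illustrative; a single landmark shown for brevity). For a 1000 × 500 image:

const rel = getRelativeDetection(
    {
        box: new Box({ x: 100, y: 50, width: 200, height: 100 }),
        landmarks: [new Point(500, 250)],
        probability: 0.9,
    },
    { width: 1000, height: 500 },
);
// rel.box is { x: 0.1, y: 0.1, width: 0.2, height: 0.2 } and the
// landmark becomes (0.5, 0.5).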
/**
* Removes duplicate face detections from an array of detections.


@@ -198,7 +198,6 @@ export interface MLSyncContext {
token: string;
userID: number;
faceDetectionService: FaceDetectionService;
faceCropService: FaceCropService;
faceEmbeddingService: FaceEmbeddingService;
@@ -240,16 +239,6 @@ export interface MLLibraryData {
export declare type MLIndex = "files" | "people";
export interface FaceDetectionService {
method: Versioned<FaceDetectionMethod>;
detectFaces(image: ImageBitmap): Promise<Array<FaceDetection>>;
getRelativeDetection(
faceDetection: FaceDetection,
imageDimensions: Dimensions,
): FaceDetection;
}
export interface FaceCropService {
method: Versioned<FaceCropMethod>;


@@ -2,6 +2,7 @@ import { openCache } from "@/next/blob-cache";
import log from "@/next/log";
import { faceAlignment } from "services/face/align";
import mlIDbStorage from "services/face/db";
import { detectFaces, getRelativeDetection } from "services/face/detect-face";
import {
DetectedFace,
Face,
@@ -10,9 +11,9 @@ import {
type FaceAlignment,
} from "services/face/types";
import { imageBitmapToBlob, warpAffineFloat32List } from "utils/image";
import { detectBlur } from "../face/detect-blur";
import { clusterFaces } from "../face/cluster";
import { getFaceCrop } from "../face/crop";
import { detectBlur } from "../face/detect-blur";
import {
fetchImageBitmap,
fetchImageBitmapForContext,
@@ -26,11 +27,13 @@ class FaceService {
fileContext: MLSyncFileContext,
) {
const { newMlFile } = fileContext;
newMlFile.faceDetectionMethod = syncContext.faceDetectionService.method;
newMlFile.faceDetectionMethod = {
value: "YoloFace",
version: 1,
};
fileContext.newDetection = true;
const imageBitmap = await fetchImageBitmapForContext(fileContext);
const faceDetections =
await syncContext.faceDetectionService.detectFaces(imageBitmap);
const faceDetections = await detectFaces(imageBitmap);
// TODO: re-enable face filtering based on width
const detectedFaces = faceDetections?.map((detection) => {
return {
@@ -123,11 +126,10 @@
for (let i = 0; i < newMlFile.faces.length; i++) {
const face = newMlFile.faces[i];
if (face.detection.box.x + face.detection.box.width < 2) continue; // Skip if somehow already relative
face.detection =
syncContext.faceDetectionService.getRelativeDetection(
face.detection,
newMlFile.imageDimensions,
);
face.detection = getRelativeDetection(
face.detection,
newMlFile.imageDimensions,
);
}
}


@@ -15,8 +15,6 @@ import {
Face,
FaceCropService,
FaceDetection,
FaceDetectionMethod,
FaceDetectionService,
FaceEmbeddingMethod,
FaceEmbeddingService,
Landmark,
@@ -31,12 +29,10 @@ import {
import { getLocalFiles } from "services/fileService";
import { EnteFile } from "types/file";
import { isInternalUserForML } from "utils/user";
import FaceService from "./faceService";
import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService";
import { fetchImageBitmapForContext } from "../face/image";
import { syncPeopleIndex } from "../face/people";
import yoloFaceDetectionService from "../face/detect-face";
import FaceService from "./faceService";
import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService";
/**
* TODO-ML(MR): What and why.
@@ -110,16 +106,6 @@ export async function updateMLSearchConfig(newConfig: MLSearchConfig) {
}
export class MLFactory {
public static getFaceDetectionService(
method: FaceDetectionMethod,
): FaceDetectionService {
if (method === "YoloFace") {
return yoloFaceDetectionService;
}
throw Error("Unknon face detection method: " + method);
}
public static getFaceEmbeddingService(
method: FaceEmbeddingMethod,
): FaceEmbeddingService {
@@ -135,7 +121,6 @@ export class LocalMLSyncContext implements MLSyncContext {
public token: string;
public userID: number;
public faceDetectionService: FaceDetectionService;
public faceCropService: FaceCropService;
public faceEmbeddingService: FaceEmbeddingService;
@@ -163,8 +148,6 @@ export class LocalMLSyncContext implements MLSyncContext {
this.token = token;
this.userID = userID;
this.faceDetectionService =
MLFactory.getFaceDetectionService("YoloFace");
this.faceEmbeddingService =
MLFactory.getFaceEmbeddingService("MobileFaceNet");