Do 2nd pass of BlazeFace on close-ups

for better accuracy
Add transform utils
Shailesh Pandit 2021-12-28 20:30:52 +05:30
parent 72b2a6ad8b
commit 385acec1ab
7 changed files with 229 additions and 50 deletions

View file

@@ -14,14 +14,14 @@ const FaceChipContainer = styled.div`
flex-wrap: wrap;
justify-content: center;
align-items: center;
- margin-top: 10px;
- margin-bottom: 10px;
+ margin-top: 5px;
+ margin-bottom: 5px;
`;
const FaceChip = styled.div`
width: 112px;
height: 112px;
- margin-right: 10px;
+ margin: 5px;
border-radius: 50%;
overflow: hidden;
position: relative;

View file

@@ -10,6 +10,7 @@ import {
BLAZEFACE_INPUT_SIZE,
BLAZEFACE_IOU_THRESHOLD,
BLAZEFACE_MAX_FACES,
+ BLAZEFACE_PASS1_SCORE_THRESHOLD,
BLAZEFACE_SCORE_THRESHOLD,
DetectedFace,
FaceDetectionMethod,
@@ -17,7 +18,13 @@ import {
Versioned,
} from 'types/machineLearning';
import { Box, Point } from '../../../thirdparty/face-api/classes';
- import { resizeToSquare } from 'utils/image';
+ import { addPadding, crop, resizeToSquare } from 'utils/image';
+ import {
+ computeTransformToBox,
+ transformBox,
+ transformPoints,
+ } from 'utils/machineLearning/transform';
+ import { enlargeBox, newBox, normFaceBox } from 'utils/machineLearning';
class TFJSFaceDetectionService implements FaceDetectionService {
private blazeFaceModel: Promise<BlazeFaceModel>;
@@ -38,7 +45,7 @@ class TFJSFaceDetectionService implements FaceDetectionService {
private async init() {
this.blazeFaceModel = blazeFaceLoad({
maxFaces: BLAZEFACE_MAX_FACES,
- scoreThreshold: BLAZEFACE_SCORE_THRESHOLD,
+ scoreThreshold: BLAZEFACE_PASS1_SCORE_THRESHOLD,
iouThreshold: BLAZEFACE_IOU_THRESHOLD,
modelUrl: '/models/blazeface/back/model.json',
inputHeight: BLAZEFACE_INPUT_SIZE,
@@ -149,42 +156,79 @@ class TFJSFaceDetectionService implements FaceDetectionService {
return this.blazeFaceModel;
}
- public async detectFaces(
+ private async estimateFaces(
imageBitmap: ImageBitmap
): Promise<Array<DetectedFace>> {
const resized = resizeToSquare(imageBitmap, BLAZEFACE_INPUT_SIZE);
- const widthRatio = imageBitmap.width / resized.width;
- const heightRatio = imageBitmap.height / resized.height;
const tfImage = tf.browser.fromPixels(resized.image);
const blazeFaceModel = await this.getBlazefaceModel();
const faces = await blazeFaceModel.estimateFaces(tfImage);
tf.dispose(tfImage);
- const detectedFaces: Array<DetectedFace> = faces?.map(
- (normalizedFace) => {
- const landmarks = normalizedFace.landmarks as number[][];
- return {
- box: new Box({
- left: normalizedFace.topLeft[0] * widthRatio,
- top: normalizedFace.topLeft[1] * heightRatio,
- right: normalizedFace.bottomRight[0] * widthRatio,
- bottom: normalizedFace.bottomRight[1] * heightRatio,
- }),
- landmarks:
- landmarks &&
- landmarks.map(
- (l) =>
- new Point(l[0] * widthRatio, l[1] * heightRatio)
- ),
- probability: normalizedFace.probability as number,
- // detectionMethod: this.method,
- } as DetectedFace;
- }
- );
+ const inBox = newBox(0, 0, resized.width, resized.height);
+ const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height);
+ const transform = computeTransformToBox(inBox, toBox);
+ // console.log("1st pass: ", { transform });
+ const detectedFaces: Array<DetectedFace> = faces?.map((f) => {
+ const box = transformBox(normFaceBox(f), transform);
+ const normLandmarks = (f.landmarks as number[][])?.map(
+ (l) => new Point(l[0], l[1])
+ );
+ const landmarks = transformPoints(normLandmarks, transform);
+ return {
+ box,
+ landmarks,
+ probability: f.probability as number,
+ // detectionMethod: this.method,
+ } as DetectedFace;
+ });
return detectedFaces;
}
+ public async detectFaces(
+ imageBitmap: ImageBitmap
+ ): Promise<Array<DetectedFace>> {
+ const pass1Detections = await this.estimateFaces(imageBitmap);
+ // run 2nd pass for accuracy
+ const detections = [];
+ for (const face of pass1Detections) {
+ const imageBox = enlargeBox(face.box, 2);
+ const faceImage = crop(
+ imageBitmap,
+ imageBox,
+ BLAZEFACE_INPUT_SIZE / 2
+ );
+ const paddedImage = addPadding(faceImage, 0.5);
+ const paddedBox = enlargeBox(imageBox, 2);
+ const pass2Detections = await this.estimateFaces(paddedImage);
+ // TODO: select based on nearest under certain threshold
+ // this will matter based on our IOU threshold and faces near each other
+ const selected = pass2Detections[0];
+ // console.log("pass2: ", face.probability, selected.probability);
+ // TODO: we might miss 1st pass face actually having score within threshold
+ if (selected && selected.probability >= BLAZEFACE_SCORE_THRESHOLD) {
+ const inBox = newBox(0, 0, faceImage.width, faceImage.height);
+ imageBox.x = paddedBox.x;
+ imageBox.y = paddedBox.y;
+ const transform = computeTransformToBox(inBox, imageBox);
+ selected.box = transformBox(selected.box, transform);
+ selected.landmarks = transformPoints(
+ selected.landmarks,
+ transform
+ );
+ // console.log("pass2: ", { imageBox, paddedBox, transform, selected });
+ detections.push(selected);
+ }
+ }
+ return detections;
+ }
public async dispose() {
const blazeFaceModel = await this.getBlazefaceModel();
blazeFaceModel?.dispose();
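Note: the detection now happens in two passes. The model is loaded with the looser BLAZEFACE_PASS1_SCORE_THRESHOLD (0.4), so borderline candidates such as close-up faces survive estimateFaces; detectFaces then re-runs estimateFaces on a padded crop around each candidate and keeps only results scoring at least BLAZEFACE_SCORE_THRESHOLD (0.7), after mapping boxes and landmarks back to original image coordinates. A rough numeric check of that coordinate bookkeeping, with illustrative values and assuming enlargeBox grows a box about its centre and crop scales the cropped region to the requested square size:

face.box = 100×100 at (400, 300)
imageBox = enlargeBox(face.box, 2) -> 200×200 at (350, 250)
faceImage = crop(imageBitmap, imageBox, BLAZEFACE_INPUT_SIZE / 2) -> 128×128 bitmap covering imageBox
paddedImage = addPadding(faceImage, 0.5) -> 256×256, faceImage centered (matches the model input size)
paddedBox = enlargeBox(imageBox, 2) -> 400×400 at (250, 150)

A pass-2 box in paddedImage coordinates scales back by imageBox.width / faceImage.width (200 / 128), which equals paddedBox.width / paddedImage.width (400 / 256), and shifts by paddedBox's origin; that is why the loop keeps imageBox's size but overwrites its x/y with paddedBox's before calling computeTransformToBox(inBox, imageBox).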

View file

@@ -273,7 +273,8 @@ export declare type MLIndex = 'files' | 'people';
export const BLAZEFACE_MAX_FACES = 20;
export const BLAZEFACE_INPUT_SIZE = 256;
export const BLAZEFACE_IOU_THRESHOLD = 0.3;
- export const BLAZEFACE_SCORE_THRESHOLD = 0.65;
+ export const BLAZEFACE_SCORE_THRESHOLD = 0.7;
+ export const BLAZEFACE_PASS1_SCORE_THRESHOLD = 0.4;
export const BLAZEFACE_FACE_SIZE = 112;
export interface FaceDetectionService {

View file

@@ -39,6 +39,15 @@ export function transform(
return offscreen.transferToImageBitmap();
}
+ export function crop(imageBitmap: ImageBitmap, cropBox: Box, size: number) {
+ const dimensions: Dimensions = {
+ width: size,
+ height: size,
+ };
+ return cropWithRotation(imageBitmap, cropBox, 0, dimensions, dimensions);
+ }
export function cropWithRotation(
imageBitmap: ImageBitmap,
cropBox: Box,
@@ -105,6 +114,24 @@ export function cropWithRotation(
return offscreen.transferToImageBitmap();
}
+ export function addPadding(image: ImageBitmap, padding: number) {
+ const scale = 1 + padding * 2;
+ const width = scale * image.width;
+ const height = scale * image.height;
+ const offscreen = new OffscreenCanvas(width, height);
+ const ctx = offscreen.getContext('2d');
+ ctx.imageSmoothingEnabled = false;
+ ctx.drawImage(
+ image,
+ width / 2 - image.width / 2,
+ height / 2 - image.height / 2,
+ image.width,
+ image.height
+ );
+ return offscreen.transferToImageBitmap();
+ }
export async function imageBitmapToBlob(
imageBitmap: ImageBitmap,
options?: BlobOptions
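Note: crop(bitmap, box, size) is a thin wrapper over cropWithRotation with no rotation and a square size×size output, and addPadding(image, p) draws the image centered on a fresh canvas scaled by 1 + 2 * p, leaving an untouched (transparent) border of p * width on the left/right and p * height on the top/bottom. For example, with hypothetical inputs:

const faceImage = crop(imageBitmap, imageBox, 128); // 128×128 bitmap of the imageBox region
const padded = addPadding(faceImage, 0.5); // 256×256 bitmap, faceImage centered, transparent border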

View file

@@ -40,25 +40,25 @@ export async function getStoredFaceCrop(
};
}
- export async function ibExtractFaceImageFromCrop(
- alignedFace: AlignedFace,
+ export function extractFaceImageFromCrop(
+ faceCrop: FaceCrop,
+ box: Box,
+ rotation: number,
faceSize: number
- ): Promise<ImageBitmap> {
- const image = alignedFace.faceCrop?.image;
- let imageBox = alignedFace.faceCrop?.imageBox;
- if (!image || !imageBox) {
+ ): ImageBitmap {
+ const faceCropImage = faceCrop?.image;
+ let imageBox = faceCrop?.imageBox;
+ if (!faceCropImage || !imageBox) {
throw Error('Face crop not present');
}
- const box = getAlignedFaceBox(alignedFace);
- const faceCropImage = await createImageBitmap(alignedFace.faceCrop.image);
// TODO: Have better serialization to avoid creating new object manually when calling class methods
imageBox = new Box(imageBox);
const scale = faceCropImage.width / imageBox.width;
- const scaledImageBox = imageBox.rescale(scale).round();
- const scaledBox = box.rescale(scale).round();
- const shiftedBox = scaledBox.shift(-scaledImageBox.x, -scaledImageBox.y);
+ const transformedBox = box
+ .shift(-imageBox.x, -imageBox.y)
+ .rescale(scale)
+ .round();
// console.log({ box, imageBox, faceCropImage, scale, scaledBox, scaledImageBox, shiftedBox });
const faceSizeDimentions: Dimensions = {
@@ -67,8 +67,8 @@ export async function ibExtractFaceImageFromCrop(
};
const faceImage = cropWithRotation(
faceCropImage,
- shiftedBox,
- alignedFace.rotation,
+ transformedBox,
+ rotation,
faceSizeDimentions,
faceSizeDimentions
);
@@ -76,6 +76,21 @@ export async function ibExtractFaceImageFromCrop(
return faceImage;
}
+ export async function ibExtractFaceImageFromCrop(
+ alignedFace: AlignedFace,
+ faceSize: number
+ ): Promise<ImageBitmap> {
+ const box = getAlignedFaceBox(alignedFace);
+ const faceCropImage = await createImageBitmap(alignedFace.faceCrop.image);
+ return extractFaceImageFromCrop(
+ { image: faceCropImage, imageBox: alignedFace.faceCrop?.imageBox },
+ box,
+ alignedFace.rotation,
+ faceSize
+ );
+ }
export async function ibExtractFaceImagesFromCrops(
faces: AlignedFace[],
faceSize: number
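Note: up to rounding, the new transformedBox matches the old shiftedBox, since scaling is linear: rescaling box and imageBox separately and then subtracting origins (scale * box.x - scale * imageBox.x) equals shifting first and rescaling after (scale * (box.x - imageBox.x)). With illustrative numbers, scale = 112 / 200 = 0.56, box.x = 150, imageBox.x = 100: old 84 - 56 = 28, new 50 * 0.56 = 28. The async createImageBitmap call and the getAlignedFaceBox lookup now live only in the ibExtractFaceImageFromCrop wrapper, so extractFaceImageFromCrop itself can stay synchronous.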

View file

@@ -16,6 +16,7 @@ import { mlFilesStore, mlPeopleStore } from 'utils/storage/mlStorage';
import { convertForPreview, needsConversionForPreview } from 'utils/file';
import { cached } from 'utils/storage/cache';
import { imageBitmapToBlob } from 'utils/image';
+ import { NormalizedFace } from '@tensorflow-models/blazeface';
export function f32Average(descriptors: Float32Array[]) {
if (descriptors.length < 1) {
@@ -123,6 +124,28 @@ export function extractFaces(
});
}
+ export function newBox(x: number, y: number, width: number, height: number) {
+ return new Box({ x, y, width, height });
+ }
+ export function newBoxFromPoints(
+ left: number,
+ top: number,
+ right: number,
+ bottom: number
+ ) {
+ return new Box({ left, top, right, bottom });
+ }
+ export function normFaceBox(face: NormalizedFace) {
+ return newBoxFromPoints(
+ face.topLeft[0],
+ face.topLeft[1],
+ face.bottomRight[0],
+ face.bottomRight[1]
+ );
+ }
export function getBoxCenterPt(topLeft: Point, bottomRight: Point): Point {
return topLeft.add(bottomRight.sub(topLeft).div(new Point(2, 2)));
}
@@ -204,6 +227,15 @@ async function getOriginalImageFile(file: File, token: string) {
return new Response(fileStream).blob();
}
+ async function getOriginalFile(file: File, token: string) {
+ let fileBlob = await getOriginalImageFile(file, token);
+ if (needsConversionForPreview(file)) {
+ fileBlob = await convertForPreview(file, fileBlob);
+ }
+ return fileBlob;
+ }
export async function getOriginalImageBitmap(
file: File,
token: string,
@@ -213,17 +245,13 @@ export async function getOriginalImageBitmap(
if (useCache) {
fileBlob = await cached('files', '/' + file.id.toString(), () => {
- return getOriginalImageFile(file, token);
+ return getOriginalFile(file, token);
});
} else {
- fileBlob = await getOriginalImageFile(file, token);
+ fileBlob = await getOriginalFile(file, token);
}
console.log('[MLService] Got file: ', file.id.toString());
- if (needsConversionForPreview(file)) {
- fileBlob = await convertForPreview(file, fileBlob);
- }
return getImageBitmap(fileBlob);
}
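Note: getOriginalFile folds the preview conversion (convertForPreview) into the fetch itself, so the blob stored by cached('files', ...) is already converted and later cache hits skip the conversion; previously the conversion ran after the cache lookup on every call. newBox and newBoxFromPoints are small constructors over the face-api Box, and normFaceBox wraps a BlazeFace NormalizedFace's topLeft/bottomRight corners into such a Box.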

View file

@@ -0,0 +1,64 @@
import { Box, Point } from '../../../thirdparty/face-api/classes';
import { Matrix } from 'ml-matrix';
import { newBoxFromPoints } from '.';
export function translation(x: number, y: number) {
return new Matrix([
[1, 0, x],
[0, 1, y],
[0, 0, 1],
]);
}
export function scale(sx: number, sy: number) {
return new Matrix([
[sx, 0, 0],
[0, sy, 0],
[0, 0, 1],
]);
}
export function rotation(angle: number) {
const cosa = Math.cos(angle);
const sina = Math.sin(angle);
return new Matrix([
[cosa, -sina, 0],
[sina, cosa, 0],
[0, 0, 1],
]);
}
export function computeTransformToBox(inBox: Box, toBox: Box): Matrix {
return translation(toBox.x, toBox.y).mmul(
scale(toBox.width / inBox.width, toBox.height / inBox.height)
);
}
export function pointToArray(point: Point) {
return [point.x, point.y];
}
export function transformPointVec(point: number[], transform: Matrix) {
point[2] = 1;
const mat = new Matrix([point]).transpose();
const mulmat = new Matrix(transform).mmul(mat).to1DArray();
// console.log({point, mat, mulmat});
return mulmat;
}
export function transformPoint(point: Point, transform: Matrix) {
const pointVec = transformPointVec(pointToArray(point), transform);
return new Point(pointVec[0], pointVec[1]);
}
export function transformPoints(points: Point[], transform: Matrix) {
return points.map((p) => transformPoint(p, transform));
}
export function transformBox(box: Box, transform: Matrix) {
const topLeft = transformPoint(box.topLeft, transform);
const bottomRight = transformPoint(box.bottomRight, transform);
return newBoxFromPoints(topLeft.x, topLeft.y, bottomRight.x, bottomRight.y);
}
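Note: computeTransformToBox builds translation(toBox.x, toBox.y).mmul(scale(toBox.width / inBox.width, toBox.height / inBox.height)) as a 3×3 homogeneous matrix, i.e. it assumes inBox sits at the origin, which holds for both call sites in this commit. A small worked example with illustrative sizes, mapping detections made on a 256×256 resized frame back to a 1024×768 original:

const transform = computeTransformToBox(newBox(0, 0, 256, 256), newBox(0, 0, 1024, 768));
// scale factors are 4 and 3, translation is (0, 0)
transformPoint(new Point(64, 128), transform); // -> Point(256, 384)
transformBox(newBox(32, 32, 64, 64), transform); // -> box from (128, 96) to (384, 288)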