Do 2nd pass of blazeface on close ups
for better accuracy Transform utils
This commit is contained in:
parent
72b2a6ad8b
commit
385acec1ab
7 changed files with 229 additions and 50 deletions
|
@ -14,14 +14,14 @@ const FaceChipContainer = styled.div`
|
|||
flex-wrap: wrap;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
margin-top: 10px;
|
||||
margin-bottom: 10px;
|
||||
margin-top: 5px;
|
||||
margin-bottom: 5px;
|
||||
`;
|
||||
|
||||
const FaceChip = styled.div`
|
||||
width: 112px;
|
||||
height: 112px;
|
||||
margin-right: 10px;
|
||||
margin: 5px;
|
||||
border-radius: 50%;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
|
|
|
@ -10,6 +10,7 @@ import {
|
|||
BLAZEFACE_INPUT_SIZE,
|
||||
BLAZEFACE_IOU_THRESHOLD,
|
||||
BLAZEFACE_MAX_FACES,
|
||||
BLAZEFACE_PASS1_SCORE_THRESHOLD,
|
||||
BLAZEFACE_SCORE_THRESHOLD,
|
||||
DetectedFace,
|
||||
FaceDetectionMethod,
|
||||
|
@ -17,7 +18,13 @@ import {
|
|||
Versioned,
|
||||
} from 'types/machineLearning';
|
||||
import { Box, Point } from '../../../thirdparty/face-api/classes';
|
||||
import { resizeToSquare } from 'utils/image';
|
||||
import { addPadding, crop, resizeToSquare } from 'utils/image';
|
||||
import {
|
||||
computeTransformToBox,
|
||||
transformBox,
|
||||
transformPoints,
|
||||
} from 'utils/machineLearning/transform';
|
||||
import { enlargeBox, newBox, normFaceBox } from 'utils/machineLearning';
|
||||
|
||||
class TFJSFaceDetectionService implements FaceDetectionService {
|
||||
private blazeFaceModel: Promise<BlazeFaceModel>;
|
||||
|
@ -38,7 +45,7 @@ class TFJSFaceDetectionService implements FaceDetectionService {
|
|||
private async init() {
|
||||
this.blazeFaceModel = blazeFaceLoad({
|
||||
maxFaces: BLAZEFACE_MAX_FACES,
|
||||
scoreThreshold: BLAZEFACE_SCORE_THRESHOLD,
|
||||
scoreThreshold: BLAZEFACE_PASS1_SCORE_THRESHOLD,
|
||||
iouThreshold: BLAZEFACE_IOU_THRESHOLD,
|
||||
modelUrl: '/models/blazeface/back/model.json',
|
||||
inputHeight: BLAZEFACE_INPUT_SIZE,
|
||||
|
@ -149,42 +156,79 @@ class TFJSFaceDetectionService implements FaceDetectionService {
|
|||
return this.blazeFaceModel;
|
||||
}
|
||||
|
||||
public async detectFaces(
|
||||
private async estimateFaces(
|
||||
imageBitmap: ImageBitmap
|
||||
): Promise<Array<DetectedFace>> {
|
||||
const resized = resizeToSquare(imageBitmap, BLAZEFACE_INPUT_SIZE);
|
||||
const widthRatio = imageBitmap.width / resized.width;
|
||||
const heightRatio = imageBitmap.height / resized.height;
|
||||
const tfImage = tf.browser.fromPixels(resized.image);
|
||||
const blazeFaceModel = await this.getBlazefaceModel();
|
||||
const faces = await blazeFaceModel.estimateFaces(tfImage);
|
||||
tf.dispose(tfImage);
|
||||
|
||||
const detectedFaces: Array<DetectedFace> = faces?.map(
|
||||
(normalizedFace) => {
|
||||
const landmarks = normalizedFace.landmarks as number[][];
|
||||
const inBox = newBox(0, 0, resized.width, resized.height);
|
||||
const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height);
|
||||
const transform = computeTransformToBox(inBox, toBox);
|
||||
// console.log("1st pass: ", { transform });
|
||||
|
||||
const detectedFaces: Array<DetectedFace> = faces?.map((f) => {
|
||||
const box = transformBox(normFaceBox(f), transform);
|
||||
const normLandmarks = (f.landmarks as number[][])?.map(
|
||||
(l) => new Point(l[0], l[1])
|
||||
);
|
||||
const landmarks = transformPoints(normLandmarks, transform);
|
||||
return {
|
||||
box: new Box({
|
||||
left: normalizedFace.topLeft[0] * widthRatio,
|
||||
top: normalizedFace.topLeft[1] * heightRatio,
|
||||
right: normalizedFace.bottomRight[0] * widthRatio,
|
||||
bottom: normalizedFace.bottomRight[1] * heightRatio,
|
||||
}),
|
||||
landmarks:
|
||||
landmarks &&
|
||||
landmarks.map(
|
||||
(l) =>
|
||||
new Point(l[0] * widthRatio, l[1] * heightRatio)
|
||||
),
|
||||
probability: normalizedFace.probability as number,
|
||||
box,
|
||||
landmarks,
|
||||
probability: f.probability as number,
|
||||
// detectionMethod: this.method,
|
||||
} as DetectedFace;
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
return detectedFaces;
|
||||
}
|
||||
|
||||
public async detectFaces(
|
||||
imageBitmap: ImageBitmap
|
||||
): Promise<Array<DetectedFace>> {
|
||||
const pass1Detections = await this.estimateFaces(imageBitmap);
|
||||
|
||||
// run 2nd pass for accuracy
|
||||
const detections = [];
|
||||
for (const face of pass1Detections) {
|
||||
const imageBox = enlargeBox(face.box, 2);
|
||||
const faceImage = crop(
|
||||
imageBitmap,
|
||||
imageBox,
|
||||
BLAZEFACE_INPUT_SIZE / 2
|
||||
);
|
||||
const paddedImage = addPadding(faceImage, 0.5);
|
||||
const paddedBox = enlargeBox(imageBox, 2);
|
||||
const pass2Detections = await this.estimateFaces(paddedImage);
|
||||
// TODO: select based on nearest under certain threshold
|
||||
// this will matter based on our IOU threshold and faces near each other
|
||||
const selected = pass2Detections[0];
|
||||
// console.log("pass2: ", face.probability, selected.probability);
|
||||
|
||||
// TODO: we might miss 1st pass face actually having score within threshold
|
||||
if (selected && selected.probability >= BLAZEFACE_SCORE_THRESHOLD) {
|
||||
const inBox = newBox(0, 0, faceImage.width, faceImage.height);
|
||||
imageBox.x = paddedBox.x;
|
||||
imageBox.y = paddedBox.y;
|
||||
const transform = computeTransformToBox(inBox, imageBox);
|
||||
|
||||
selected.box = transformBox(selected.box, transform);
|
||||
selected.landmarks = transformPoints(
|
||||
selected.landmarks,
|
||||
transform
|
||||
);
|
||||
// console.log("pass2: ", { imageBox, paddedBox, transform, selected });
|
||||
detections.push(selected);
|
||||
}
|
||||
}
|
||||
|
||||
return detections;
|
||||
}
|
||||
|
||||
public async dispose() {
|
||||
const blazeFaceModel = await this.getBlazefaceModel();
|
||||
blazeFaceModel?.dispose();
|
||||
|
|
|
@ -273,7 +273,8 @@ export declare type MLIndex = 'files' | 'people';
|
|||
export const BLAZEFACE_MAX_FACES = 20;
|
||||
export const BLAZEFACE_INPUT_SIZE = 256;
|
||||
export const BLAZEFACE_IOU_THRESHOLD = 0.3;
|
||||
export const BLAZEFACE_SCORE_THRESHOLD = 0.65;
|
||||
export const BLAZEFACE_SCORE_THRESHOLD = 0.7;
|
||||
export const BLAZEFACE_PASS1_SCORE_THRESHOLD = 0.4;
|
||||
export const BLAZEFACE_FACE_SIZE = 112;
|
||||
|
||||
export interface FaceDetectionService {
|
||||
|
|
|
@ -39,6 +39,15 @@ export function transform(
|
|||
return offscreen.transferToImageBitmap();
|
||||
}
|
||||
|
||||
export function crop(imageBitmap: ImageBitmap, cropBox: Box, size: number) {
|
||||
const dimensions: Dimensions = {
|
||||
width: size,
|
||||
height: size,
|
||||
};
|
||||
|
||||
return cropWithRotation(imageBitmap, cropBox, 0, dimensions, dimensions);
|
||||
}
|
||||
|
||||
export function cropWithRotation(
|
||||
imageBitmap: ImageBitmap,
|
||||
cropBox: Box,
|
||||
|
@ -105,6 +114,24 @@ export function cropWithRotation(
|
|||
return offscreen.transferToImageBitmap();
|
||||
}
|
||||
|
||||
export function addPadding(image: ImageBitmap, padding: number) {
|
||||
const scale = 1 + padding * 2;
|
||||
const width = scale * image.width;
|
||||
const height = scale * image.height;
|
||||
const offscreen = new OffscreenCanvas(width, height);
|
||||
const ctx = offscreen.getContext('2d');
|
||||
ctx.imageSmoothingEnabled = false;
|
||||
ctx.drawImage(
|
||||
image,
|
||||
width / 2 - image.width / 2,
|
||||
height / 2 - image.height / 2,
|
||||
image.width,
|
||||
image.height
|
||||
);
|
||||
|
||||
return offscreen.transferToImageBitmap();
|
||||
}
|
||||
|
||||
export async function imageBitmapToBlob(
|
||||
imageBitmap: ImageBitmap,
|
||||
options?: BlobOptions
|
||||
|
|
|
@ -40,25 +40,25 @@ export async function getStoredFaceCrop(
|
|||
};
|
||||
}
|
||||
|
||||
export async function ibExtractFaceImageFromCrop(
|
||||
alignedFace: AlignedFace,
|
||||
export function extractFaceImageFromCrop(
|
||||
faceCrop: FaceCrop,
|
||||
box: Box,
|
||||
rotation: number,
|
||||
faceSize: number
|
||||
): Promise<ImageBitmap> {
|
||||
const image = alignedFace.faceCrop?.image;
|
||||
let imageBox = alignedFace.faceCrop?.imageBox;
|
||||
if (!image || !imageBox) {
|
||||
): ImageBitmap {
|
||||
const faceCropImage = faceCrop?.image;
|
||||
let imageBox = faceCrop?.imageBox;
|
||||
if (!faceCropImage || !imageBox) {
|
||||
throw Error('Face crop not present');
|
||||
}
|
||||
|
||||
const box = getAlignedFaceBox(alignedFace);
|
||||
const faceCropImage = await createImageBitmap(alignedFace.faceCrop.image);
|
||||
|
||||
// TODO: Have better serialization to avoid creating new object manually when calling class methods
|
||||
imageBox = new Box(imageBox);
|
||||
const scale = faceCropImage.width / imageBox.width;
|
||||
const scaledImageBox = imageBox.rescale(scale).round();
|
||||
const scaledBox = box.rescale(scale).round();
|
||||
const shiftedBox = scaledBox.shift(-scaledImageBox.x, -scaledImageBox.y);
|
||||
const transformedBox = box
|
||||
.shift(-imageBox.x, -imageBox.y)
|
||||
.rescale(scale)
|
||||
.round();
|
||||
// console.log({ box, imageBox, faceCropImage, scale, scaledBox, scaledImageBox, shiftedBox });
|
||||
|
||||
const faceSizeDimentions: Dimensions = {
|
||||
|
@ -67,8 +67,8 @@ export async function ibExtractFaceImageFromCrop(
|
|||
};
|
||||
const faceImage = cropWithRotation(
|
||||
faceCropImage,
|
||||
shiftedBox,
|
||||
alignedFace.rotation,
|
||||
transformedBox,
|
||||
rotation,
|
||||
faceSizeDimentions,
|
||||
faceSizeDimentions
|
||||
);
|
||||
|
@ -76,6 +76,21 @@ export async function ibExtractFaceImageFromCrop(
|
|||
return faceImage;
|
||||
}
|
||||
|
||||
export async function ibExtractFaceImageFromCrop(
|
||||
alignedFace: AlignedFace,
|
||||
faceSize: number
|
||||
): Promise<ImageBitmap> {
|
||||
const box = getAlignedFaceBox(alignedFace);
|
||||
const faceCropImage = await createImageBitmap(alignedFace.faceCrop.image);
|
||||
|
||||
return extractFaceImageFromCrop(
|
||||
{ image: faceCropImage, imageBox: alignedFace.faceCrop?.imageBox },
|
||||
box,
|
||||
alignedFace.rotation,
|
||||
faceSize
|
||||
);
|
||||
}
|
||||
|
||||
export async function ibExtractFaceImagesFromCrops(
|
||||
faces: AlignedFace[],
|
||||
faceSize: number
|
||||
|
|
|
@ -16,6 +16,7 @@ import { mlFilesStore, mlPeopleStore } from 'utils/storage/mlStorage';
|
|||
import { convertForPreview, needsConversionForPreview } from 'utils/file';
|
||||
import { cached } from 'utils/storage/cache';
|
||||
import { imageBitmapToBlob } from 'utils/image';
|
||||
import { NormalizedFace } from '@tensorflow-models/blazeface';
|
||||
|
||||
export function f32Average(descriptors: Float32Array[]) {
|
||||
if (descriptors.length < 1) {
|
||||
|
@ -123,6 +124,28 @@ export function extractFaces(
|
|||
});
|
||||
}
|
||||
|
||||
export function newBox(x: number, y: number, width: number, height: number) {
|
||||
return new Box({ x, y, width, height });
|
||||
}
|
||||
|
||||
export function newBoxFromPoints(
|
||||
left: number,
|
||||
top: number,
|
||||
right: number,
|
||||
bottom: number
|
||||
) {
|
||||
return new Box({ left, top, right, bottom });
|
||||
}
|
||||
|
||||
export function normFaceBox(face: NormalizedFace) {
|
||||
return newBoxFromPoints(
|
||||
face.topLeft[0],
|
||||
face.topLeft[1],
|
||||
face.bottomRight[0],
|
||||
face.bottomRight[1]
|
||||
);
|
||||
}
|
||||
|
||||
export function getBoxCenterPt(topLeft: Point, bottomRight: Point): Point {
|
||||
return topLeft.add(bottomRight.sub(topLeft).div(new Point(2, 2)));
|
||||
}
|
||||
|
@ -204,6 +227,15 @@ async function getOriginalImageFile(file: File, token: string) {
|
|||
return new Response(fileStream).blob();
|
||||
}
|
||||
|
||||
async function getOriginalFile(file: File, token: string) {
|
||||
let fileBlob = await getOriginalImageFile(file, token);
|
||||
if (needsConversionForPreview(file)) {
|
||||
fileBlob = await convertForPreview(file, fileBlob);
|
||||
}
|
||||
|
||||
return fileBlob;
|
||||
}
|
||||
|
||||
export async function getOriginalImageBitmap(
|
||||
file: File,
|
||||
token: string,
|
||||
|
@ -213,17 +245,13 @@ export async function getOriginalImageBitmap(
|
|||
|
||||
if (useCache) {
|
||||
fileBlob = await cached('files', '/' + file.id.toString(), () => {
|
||||
return getOriginalImageFile(file, token);
|
||||
return getOriginalFile(file, token);
|
||||
});
|
||||
} else {
|
||||
fileBlob = await getOriginalImageFile(file, token);
|
||||
fileBlob = await getOriginalFile(file, token);
|
||||
}
|
||||
console.log('[MLService] Got file: ', file.id.toString());
|
||||
|
||||
if (needsConversionForPreview(file)) {
|
||||
fileBlob = await convertForPreview(file, fileBlob);
|
||||
}
|
||||
|
||||
return getImageBitmap(fileBlob);
|
||||
}
|
||||
|
||||
|
|
64
src/utils/machineLearning/transform.ts
Normal file
64
src/utils/machineLearning/transform.ts
Normal file
|
@ -0,0 +1,64 @@
|
|||
import { Box, Point } from '../../../thirdparty/face-api/classes';
|
||||
import { Matrix } from 'ml-matrix';
|
||||
import { newBoxFromPoints } from '.';
|
||||
|
||||
export function translation(x: number, y: number) {
|
||||
return new Matrix([
|
||||
[1, 0, x],
|
||||
[0, 1, y],
|
||||
[0, 0, 1],
|
||||
]);
|
||||
}
|
||||
|
||||
export function scale(sx: number, sy: number) {
|
||||
return new Matrix([
|
||||
[sx, 0, 0],
|
||||
[0, sy, 0],
|
||||
[0, 0, 1],
|
||||
]);
|
||||
}
|
||||
|
||||
export function rotation(angle: number) {
|
||||
const cosa = Math.cos(angle);
|
||||
const sina = Math.sin(angle);
|
||||
return new Matrix([
|
||||
[cosa, -sina, 0],
|
||||
[sina, cosa, 0],
|
||||
[0, 0, 1],
|
||||
]);
|
||||
}
|
||||
|
||||
export function computeTransformToBox(inBox: Box, toBox: Box): Matrix {
|
||||
return translation(toBox.x, toBox.y).mmul(
|
||||
scale(toBox.width / inBox.width, toBox.height / inBox.height)
|
||||
);
|
||||
}
|
||||
|
||||
export function pointToArray(point: Point) {
|
||||
return [point.x, point.y];
|
||||
}
|
||||
|
||||
export function transformPointVec(point: number[], transform: Matrix) {
|
||||
point[2] = 1;
|
||||
const mat = new Matrix([point]).transpose();
|
||||
const mulmat = new Matrix(transform).mmul(mat).to1DArray();
|
||||
// console.log({point, mat, mulmat});
|
||||
|
||||
return mulmat;
|
||||
}
|
||||
|
||||
export function transformPoint(point: Point, transform: Matrix) {
|
||||
const pointVec = transformPointVec(pointToArray(point), transform);
|
||||
return new Point(pointVec[0], pointVec[1]);
|
||||
}
|
||||
|
||||
export function transformPoints(points: Point[], transform: Matrix) {
|
||||
return points.map((p) => transformPoint(p, transform));
|
||||
}
|
||||
|
||||
export function transformBox(box: Box, transform: Matrix) {
|
||||
const topLeft = transformPoint(box.topLeft, transform);
|
||||
const bottomRight = transformPoint(box.bottomRight, transform);
|
||||
|
||||
return newBoxFromPoints(topLeft.x, topLeft.y, bottomRight.x, bottomRight.y);
|
||||
}
|
Loading…
Reference in a new issue