Do 2nd pass of blazeface on close-ups
for better accuracy; add transform utils
This commit is contained in:
parent
72b2a6ad8b
commit
385acec1ab
7 changed files with 229 additions and 50 deletions
|
@ -14,14 +14,14 @@ const FaceChipContainer = styled.div`
|
||||||
flex-wrap: wrap;
|
flex-wrap: wrap;
|
||||||
justify-content: center;
|
justify-content: center;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
margin-top: 10px;
|
margin-top: 5px;
|
||||||
margin-bottom: 10px;
|
margin-bottom: 5px;
|
||||||
`;
|
`;
|
||||||
|
|
||||||
const FaceChip = styled.div`
|
const FaceChip = styled.div`
|
||||||
width: 112px;
|
width: 112px;
|
||||||
height: 112px;
|
height: 112px;
|
||||||
margin-right: 10px;
|
margin: 5px;
|
||||||
border-radius: 50%;
|
border-radius: 50%;
|
||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
position: relative;
|
position: relative;
|
||||||
|
|
|
@ -10,6 +10,7 @@ import {
|
||||||
BLAZEFACE_INPUT_SIZE,
|
BLAZEFACE_INPUT_SIZE,
|
||||||
BLAZEFACE_IOU_THRESHOLD,
|
BLAZEFACE_IOU_THRESHOLD,
|
||||||
BLAZEFACE_MAX_FACES,
|
BLAZEFACE_MAX_FACES,
|
||||||
|
BLAZEFACE_PASS1_SCORE_THRESHOLD,
|
||||||
BLAZEFACE_SCORE_THRESHOLD,
|
BLAZEFACE_SCORE_THRESHOLD,
|
||||||
DetectedFace,
|
DetectedFace,
|
||||||
FaceDetectionMethod,
|
FaceDetectionMethod,
|
||||||
|
@ -17,7 +18,13 @@ import {
|
||||||
Versioned,
|
Versioned,
|
||||||
} from 'types/machineLearning';
|
} from 'types/machineLearning';
|
||||||
import { Box, Point } from '../../../thirdparty/face-api/classes';
|
import { Box, Point } from '../../../thirdparty/face-api/classes';
|
||||||
import { resizeToSquare } from 'utils/image';
|
import { addPadding, crop, resizeToSquare } from 'utils/image';
|
||||||
|
import {
|
||||||
|
computeTransformToBox,
|
||||||
|
transformBox,
|
||||||
|
transformPoints,
|
||||||
|
} from 'utils/machineLearning/transform';
|
||||||
|
import { enlargeBox, newBox, normFaceBox } from 'utils/machineLearning';
|
||||||
|
|
||||||
class TFJSFaceDetectionService implements FaceDetectionService {
|
class TFJSFaceDetectionService implements FaceDetectionService {
|
||||||
private blazeFaceModel: Promise<BlazeFaceModel>;
|
private blazeFaceModel: Promise<BlazeFaceModel>;
|
||||||
|
@ -38,7 +45,7 @@ class TFJSFaceDetectionService implements FaceDetectionService {
|
||||||
private async init() {
|
private async init() {
|
||||||
this.blazeFaceModel = blazeFaceLoad({
|
this.blazeFaceModel = blazeFaceLoad({
|
||||||
maxFaces: BLAZEFACE_MAX_FACES,
|
maxFaces: BLAZEFACE_MAX_FACES,
|
||||||
scoreThreshold: BLAZEFACE_SCORE_THRESHOLD,
|
scoreThreshold: BLAZEFACE_PASS1_SCORE_THRESHOLD,
|
||||||
iouThreshold: BLAZEFACE_IOU_THRESHOLD,
|
iouThreshold: BLAZEFACE_IOU_THRESHOLD,
|
||||||
modelUrl: '/models/blazeface/back/model.json',
|
modelUrl: '/models/blazeface/back/model.json',
|
||||||
inputHeight: BLAZEFACE_INPUT_SIZE,
|
inputHeight: BLAZEFACE_INPUT_SIZE,
|
||||||
|
@ -149,42 +156,79 @@ class TFJSFaceDetectionService implements FaceDetectionService {
|
||||||
return this.blazeFaceModel;
|
return this.blazeFaceModel;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async detectFaces(
|
private async estimateFaces(
|
||||||
imageBitmap: ImageBitmap
|
imageBitmap: ImageBitmap
|
||||||
): Promise<Array<DetectedFace>> {
|
): Promise<Array<DetectedFace>> {
|
||||||
const resized = resizeToSquare(imageBitmap, BLAZEFACE_INPUT_SIZE);
|
const resized = resizeToSquare(imageBitmap, BLAZEFACE_INPUT_SIZE);
|
||||||
const widthRatio = imageBitmap.width / resized.width;
|
|
||||||
const heightRatio = imageBitmap.height / resized.height;
|
|
||||||
const tfImage = tf.browser.fromPixels(resized.image);
|
const tfImage = tf.browser.fromPixels(resized.image);
|
||||||
const blazeFaceModel = await this.getBlazefaceModel();
|
const blazeFaceModel = await this.getBlazefaceModel();
|
||||||
const faces = await blazeFaceModel.estimateFaces(tfImage);
|
const faces = await blazeFaceModel.estimateFaces(tfImage);
|
||||||
tf.dispose(tfImage);
|
tf.dispose(tfImage);
|
||||||
|
|
||||||
const detectedFaces: Array<DetectedFace> = faces?.map(
|
const inBox = newBox(0, 0, resized.width, resized.height);
|
||||||
(normalizedFace) => {
|
const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height);
|
||||||
const landmarks = normalizedFace.landmarks as number[][];
|
const transform = computeTransformToBox(inBox, toBox);
|
||||||
return {
|
// console.log("1st pass: ", { transform });
|
||||||
box: new Box({
|
|
||||||
left: normalizedFace.topLeft[0] * widthRatio,
|
const detectedFaces: Array<DetectedFace> = faces?.map((f) => {
|
||||||
top: normalizedFace.topLeft[1] * heightRatio,
|
const box = transformBox(normFaceBox(f), transform);
|
||||||
right: normalizedFace.bottomRight[0] * widthRatio,
|
const normLandmarks = (f.landmarks as number[][])?.map(
|
||||||
bottom: normalizedFace.bottomRight[1] * heightRatio,
|
(l) => new Point(l[0], l[1])
|
||||||
}),
|
);
|
||||||
landmarks:
|
const landmarks = transformPoints(normLandmarks, transform);
|
||||||
landmarks &&
|
return {
|
||||||
landmarks.map(
|
box,
|
||||||
(l) =>
|
landmarks,
|
||||||
new Point(l[0] * widthRatio, l[1] * heightRatio)
|
probability: f.probability as number,
|
||||||
),
|
// detectionMethod: this.method,
|
||||||
probability: normalizedFace.probability as number,
|
} as DetectedFace;
|
||||||
// detectionMethod: this.method,
|
});
|
||||||
} as DetectedFace;
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
return detectedFaces;
|
return detectedFaces;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Detects faces in two passes for better accuracy.
 *
 * Pass 1 runs the detector on the whole image. Pass 2 re-runs it on a padded
 * close-up of every pass-1 candidate and keeps the close-up detection only
 * when its probability is at least BLAZEFACE_SCORE_THRESHOLD.
 *
 * @param imageBitmap Full image to search for faces.
 * @returns Pass-2 detections whose box and landmarks are mapped back into
 * the coordinate space of `imageBitmap`.
 */
public async detectFaces(
    imageBitmap: ImageBitmap
): Promise<Array<DetectedFace>> {
    // estimateFaces maps with `faces?.map`, so it may resolve to undefined.
    const pass1Detections = (await this.estimateFaces(imageBitmap)) ?? [];

    // run 2nd pass for accuracy
    const detections: Array<DetectedFace> = [];
    for (const face of pass1Detections) {
        const imageBox = enlargeBox(face.box, 2);
        const paddedBox = enlargeBox(imageBox, 2);
        const faceImage = crop(
            imageBitmap,
            imageBox,
            BLAZEFACE_INPUT_SIZE / 2
        );
        // Capture the close-up size before the bitmap is released below.
        const inBox = newBox(0, 0, faceImage.width, faceImage.height);

        let pass2Detections: Array<DetectedFace>;
        try {
            const paddedImage = addPadding(faceImage, 0.5);
            try {
                pass2Detections = await this.estimateFaces(paddedImage);
            } finally {
                // Intermediate bitmaps are owned by this loop iteration;
                // release them eagerly instead of waiting for GC.
                paddedImage.close();
            }
        } finally {
            faceImage.close();
        }

        // TODO: select based on nearest under certain threshold
        // this will matter based on our IOU threshold and faces near each other
        const selected = pass2Detections?.[0];

        // TODO: we might miss 1st pass face actually having score within threshold
        if (selected && selected.probability >= BLAZEFACE_SCORE_THRESHOLD) {
            // Pass-2 coordinates are relative to the padded image. Padding
            // only adds a border (no rescaling), so the pixel scale is still
            // imageBox.width / faceImage.width, while the origin moves to the
            // top-left corner of paddedBox.
            // NOTE(review): assumes enlargeBox grows the box around its
            // center - confirm against its implementation.
            const toBox = newBox(
                paddedBox.x,
                paddedBox.y,
                imageBox.width,
                imageBox.height
            );
            const transform = computeTransformToBox(inBox, toBox);

            selected.box = transformBox(selected.box, transform);
            selected.landmarks = transformPoints(
                selected.landmarks,
                transform
            );
            detections.push(selected);
        }
    }

    return detections;
}
|
||||||
|
|
||||||
public async dispose() {
|
public async dispose() {
|
||||||
const blazeFaceModel = await this.getBlazefaceModel();
|
const blazeFaceModel = await this.getBlazefaceModel();
|
||||||
blazeFaceModel?.dispose();
|
blazeFaceModel?.dispose();
|
||||||
|
|
|
@ -273,7 +273,8 @@ export declare type MLIndex = 'files' | 'people';
|
||||||
export const BLAZEFACE_MAX_FACES = 20;
|
export const BLAZEFACE_MAX_FACES = 20;
|
||||||
export const BLAZEFACE_INPUT_SIZE = 256;
|
export const BLAZEFACE_INPUT_SIZE = 256;
|
||||||
export const BLAZEFACE_IOU_THRESHOLD = 0.3;
|
export const BLAZEFACE_IOU_THRESHOLD = 0.3;
|
||||||
export const BLAZEFACE_SCORE_THRESHOLD = 0.65;
|
export const BLAZEFACE_SCORE_THRESHOLD = 0.7;
|
||||||
|
export const BLAZEFACE_PASS1_SCORE_THRESHOLD = 0.4;
|
||||||
export const BLAZEFACE_FACE_SIZE = 112;
|
export const BLAZEFACE_FACE_SIZE = 112;
|
||||||
|
|
||||||
export interface FaceDetectionService {
|
export interface FaceDetectionService {
|
||||||
|
|
|
@ -39,6 +39,15 @@ export function transform(
|
||||||
return offscreen.transferToImageBitmap();
|
return offscreen.transferToImageBitmap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Crops `cropBox` out of `imageBitmap` into a square bitmap, without rotation.
 *
 * @param imageBitmap Source image.
 * @param cropBox Region of the source image to extract.
 * @param size Used as both width and height of the dimensions handed to
 * cropWithRotation.
 * @returns Whatever cropWithRotation produces for a rotation of 0.
 */
export function crop(imageBitmap: ImageBitmap, cropBox: Box, size: number) {
    // The same square dimensions are passed for both dimension arguments.
    const square: Dimensions = { width: size, height: size };
    const noRotation = 0;

    return cropWithRotation(imageBitmap, cropBox, noRotation, square, square);
}
|
||||||
|
|
||||||
export function cropWithRotation(
|
export function cropWithRotation(
|
||||||
imageBitmap: ImageBitmap,
|
imageBitmap: ImageBitmap,
|
||||||
cropBox: Box,
|
cropBox: Box,
|
||||||
|
@ -105,6 +114,24 @@ export function cropWithRotation(
|
||||||
return offscreen.transferToImageBitmap();
|
return offscreen.transferToImageBitmap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Returns a new bitmap with `image` centered on a larger canvas.
 *
 * The border is left as the canvas default (nothing is drawn into it) and the
 * image itself is drawn at its original size, i.e. it is not rescaled.
 *
 * @param image Bitmap to pad; it is not modified or closed.
 * @param padding Border added on each side, as a fraction of the image size
 * (0.5 adds half the image width left and right, doubling the total width).
 * @returns A new ImageBitmap of size (1 + 2 * padding) times the input size,
 * rounded to whole pixels.
 * @throws Error if a 2d context cannot be obtained for the canvas.
 */
export function addPadding(image: ImageBitmap, padding: number) {
    const scale = 1 + padding * 2;
    // Canvas dimensions are integral; round up front so the centering offsets
    // below are computed against the size the canvas really has.
    const width = Math.round(scale * image.width);
    const height = Math.round(scale * image.height);

    const offscreen = new OffscreenCanvas(width, height);
    const ctx = offscreen.getContext('2d');
    if (!ctx) {
        throw Error('Failed to get 2d context for padded image');
    }
    ctx.imageSmoothingEnabled = false;
    ctx.drawImage(
        image,
        (width - image.width) / 2,
        (height - image.height) / 2,
        image.width,
        image.height
    );

    return offscreen.transferToImageBitmap();
}
|
||||||
|
|
||||||
export async function imageBitmapToBlob(
|
export async function imageBitmapToBlob(
|
||||||
imageBitmap: ImageBitmap,
|
imageBitmap: ImageBitmap,
|
||||||
options?: BlobOptions
|
options?: BlobOptions
|
||||||
|
|
|
@ -40,25 +40,25 @@ export async function getStoredFaceCrop(
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function ibExtractFaceImageFromCrop(
|
export function extractFaceImageFromCrop(
|
||||||
alignedFace: AlignedFace,
|
faceCrop: FaceCrop,
|
||||||
|
box: Box,
|
||||||
|
rotation: number,
|
||||||
faceSize: number
|
faceSize: number
|
||||||
): Promise<ImageBitmap> {
|
): ImageBitmap {
|
||||||
const image = alignedFace.faceCrop?.image;
|
const faceCropImage = faceCrop?.image;
|
||||||
let imageBox = alignedFace.faceCrop?.imageBox;
|
let imageBox = faceCrop?.imageBox;
|
||||||
if (!image || !imageBox) {
|
if (!faceCropImage || !imageBox) {
|
||||||
throw Error('Face crop not present');
|
throw Error('Face crop not present');
|
||||||
}
|
}
|
||||||
|
|
||||||
const box = getAlignedFaceBox(alignedFace);
|
|
||||||
const faceCropImage = await createImageBitmap(alignedFace.faceCrop.image);
|
|
||||||
|
|
||||||
// TODO: Have better serialization to avoid creating new object manually when calling class methods
|
// TODO: Have better serialization to avoid creating new object manually when calling class methods
|
||||||
imageBox = new Box(imageBox);
|
imageBox = new Box(imageBox);
|
||||||
const scale = faceCropImage.width / imageBox.width;
|
const scale = faceCropImage.width / imageBox.width;
|
||||||
const scaledImageBox = imageBox.rescale(scale).round();
|
const transformedBox = box
|
||||||
const scaledBox = box.rescale(scale).round();
|
.shift(-imageBox.x, -imageBox.y)
|
||||||
const shiftedBox = scaledBox.shift(-scaledImageBox.x, -scaledImageBox.y);
|
.rescale(scale)
|
||||||
|
.round();
|
||||||
// console.log({ box, imageBox, faceCropImage, scale, scaledBox, scaledImageBox, shiftedBox });
|
// console.log({ box, imageBox, faceCropImage, scale, scaledBox, scaledImageBox, shiftedBox });
|
||||||
|
|
||||||
const faceSizeDimentions: Dimensions = {
|
const faceSizeDimentions: Dimensions = {
|
||||||
|
@ -67,8 +67,8 @@ export async function ibExtractFaceImageFromCrop(
|
||||||
};
|
};
|
||||||
const faceImage = cropWithRotation(
|
const faceImage = cropWithRotation(
|
||||||
faceCropImage,
|
faceCropImage,
|
||||||
shiftedBox,
|
transformedBox,
|
||||||
alignedFace.rotation,
|
rotation,
|
||||||
faceSizeDimentions,
|
faceSizeDimentions,
|
||||||
faceSizeDimentions
|
faceSizeDimentions
|
||||||
);
|
);
|
||||||
|
@ -76,6 +76,21 @@ export async function ibExtractFaceImageFromCrop(
|
||||||
return faceImage;
|
return faceImage;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Extracts the aligned face image of `alignedFace` from its stored face crop.
 *
 * The stored crop image is decoded into a temporary ImageBitmap, the face is
 * cut out of it via extractFaceImageFromCrop, and the temporary bitmap is
 * released again.
 *
 * @param alignedFace Face whose `faceCrop` holds the crop image and its box.
 * @param faceSize Size forwarded to extractFaceImageFromCrop.
 * @returns The extracted face image.
 * @throws Error('Face crop not present') when the face has no crop image or
 * crop image box.
 */
export async function ibExtractFaceImageFromCrop(
    alignedFace: AlignedFace,
    faceSize: number
): Promise<ImageBitmap> {
    // Fail with the documented error before touching faceCrop fields, rather
    // than with a TypeError from dereferencing an absent crop.
    const faceCrop = alignedFace.faceCrop;
    if (!faceCrop?.image || !faceCrop?.imageBox) {
        throw Error('Face crop not present');
    }

    const box = getAlignedFaceBox(alignedFace);
    const faceCropImage = await createImageBitmap(faceCrop.image);

    try {
        return extractFaceImageFromCrop(
            { image: faceCropImage, imageBox: faceCrop.imageBox },
            box,
            alignedFace.rotation,
            faceSize
        );
    } finally {
        // The decoded crop is only an intermediate; the returned face image
        // is a separate bitmap produced by cropWithRotation.
        faceCropImage.close();
    }
}
|
||||||
|
|
||||||
export async function ibExtractFaceImagesFromCrops(
|
export async function ibExtractFaceImagesFromCrops(
|
||||||
faces: AlignedFace[],
|
faces: AlignedFace[],
|
||||||
faceSize: number
|
faceSize: number
|
||||||
|
|
|
@ -16,6 +16,7 @@ import { mlFilesStore, mlPeopleStore } from 'utils/storage/mlStorage';
|
||||||
import { convertForPreview, needsConversionForPreview } from 'utils/file';
|
import { convertForPreview, needsConversionForPreview } from 'utils/file';
|
||||||
import { cached } from 'utils/storage/cache';
|
import { cached } from 'utils/storage/cache';
|
||||||
import { imageBitmapToBlob } from 'utils/image';
|
import { imageBitmapToBlob } from 'utils/image';
|
||||||
|
import { NormalizedFace } from '@tensorflow-models/blazeface';
|
||||||
|
|
||||||
export function f32Average(descriptors: Float32Array[]) {
|
export function f32Average(descriptors: Float32Array[]) {
|
||||||
if (descriptors.length < 1) {
|
if (descriptors.length < 1) {
|
||||||
|
@ -123,6 +124,28 @@ export function extractFaces(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Creates a Box from its top-left corner and its size.
 *
 * @param x Left edge.
 * @param y Top edge.
 * @param width Box width.
 * @param height Box height.
 */
export function newBox(x: number, y: number, width: number, height: number) {
    const rect = { x, y, width, height };
    return new Box(rect);
}
|
||||||
|
|
||||||
|
/**
 * Creates a Box from its four edge coordinates.
 *
 * @param left Left edge.
 * @param top Top edge.
 * @param right Right edge.
 * @param bottom Bottom edge.
 */
export function newBoxFromPoints(
    left: number,
    top: number,
    right: number,
    bottom: number
) {
    const edges = { left, top, right, bottom };
    return new Box(edges);
}
|
||||||
|
|
||||||
|
/**
 * Converts the bounding corners of a blazeface NormalizedFace into a Box.
 *
 * @param face Detection whose `topLeft` and `bottomRight` are indexed as
 * [x, y] pairs.
 */
export function normFaceBox(face: NormalizedFace) {
    const { topLeft, bottomRight } = face;

    return newBoxFromPoints(
        topLeft[0],
        topLeft[1],
        bottomRight[0],
        bottomRight[1]
    );
}
|
||||||
|
|
||||||
export function getBoxCenterPt(topLeft: Point, bottomRight: Point): Point {
|
export function getBoxCenterPt(topLeft: Point, bottomRight: Point): Point {
|
||||||
return topLeft.add(bottomRight.sub(topLeft).div(new Point(2, 2)));
|
return topLeft.add(bottomRight.sub(topLeft).div(new Point(2, 2)));
|
||||||
}
|
}
|
||||||
|
@ -204,6 +227,15 @@ async function getOriginalImageFile(file: File, token: string) {
|
||||||
return new Response(fileStream).blob();
|
return new Response(fileStream).blob();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Fetches the original file and, when the file needs it, converts it into a
 * previewable form.
 *
 * @param file File to fetch.
 * @param token Token forwarded to getOriginalImageFile.
 * @returns The original blob, or the converted blob when
 * needsConversionForPreview(file) is true.
 */
async function getOriginalFile(file: File, token: string) {
    const originalBlob = await getOriginalImageFile(file, token);
    if (!needsConversionForPreview(file)) {
        return originalBlob;
    }

    return convertForPreview(file, originalBlob);
}
|
||||||
|
|
||||||
export async function getOriginalImageBitmap(
|
export async function getOriginalImageBitmap(
|
||||||
file: File,
|
file: File,
|
||||||
token: string,
|
token: string,
|
||||||
|
@ -213,17 +245,13 @@ export async function getOriginalImageBitmap(
|
||||||
|
|
||||||
if (useCache) {
|
if (useCache) {
|
||||||
fileBlob = await cached('files', '/' + file.id.toString(), () => {
|
fileBlob = await cached('files', '/' + file.id.toString(), () => {
|
||||||
return getOriginalImageFile(file, token);
|
return getOriginalFile(file, token);
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
fileBlob = await getOriginalImageFile(file, token);
|
fileBlob = await getOriginalFile(file, token);
|
||||||
}
|
}
|
||||||
console.log('[MLService] Got file: ', file.id.toString());
|
console.log('[MLService] Got file: ', file.id.toString());
|
||||||
|
|
||||||
if (needsConversionForPreview(file)) {
|
|
||||||
fileBlob = await convertForPreview(file, fileBlob);
|
|
||||||
}
|
|
||||||
|
|
||||||
return getImageBitmap(fileBlob);
|
return getImageBitmap(fileBlob);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
64
src/utils/machineLearning/transform.ts
Normal file
64
src/utils/machineLearning/transform.ts
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
import { Box, Point } from '../../../thirdparty/face-api/classes';
|
||||||
|
import { Matrix } from 'ml-matrix';
|
||||||
|
import { newBoxFromPoints } from '.';
|
||||||
|
|
||||||
|
/**
 * Builds the 3x3 homogeneous matrix that shifts 2D points by (x, y).
 *
 * @param x Offset along the x axis.
 * @param y Offset along the y axis.
 */
export function translation(x: number, y: number) {
    const rows = [
        [1, 0, x],
        [0, 1, y],
        [0, 0, 1],
    ];

    return new Matrix(rows);
}
|
||||||
|
|
||||||
|
/**
 * Builds the 3x3 homogeneous matrix that scales 2D points about the origin.
 *
 * @param sx Scale factor along the x axis.
 * @param sy Scale factor along the y axis.
 */
export function scale(sx: number, sy: number) {
    const rows = [
        [sx, 0, 0],
        [0, sy, 0],
        [0, 0, 1],
    ];

    return new Matrix(rows);
}
|
||||||
|
|
||||||
|
/**
 * Builds the 3x3 homogeneous 2D rotation matrix
 * [[cos, -sin, 0], [sin, cos, 0], [0, 0, 1]] about the origin.
 *
 * @param angle Rotation angle in radians.
 */
export function rotation(angle: number) {
    const c = Math.cos(angle);
    const s = Math.sin(angle);
    const rows = [
        [c, -s, 0],
        [s, c, 0],
        [0, 0, 1],
    ];

    return new Matrix(rows);
}
|
||||||
|
|
||||||
|
/**
 * Computes the transform that maps coordinates relative to `inBox` onto
 * `toBox`: scale by the size ratio first, then translate to toBox's corner.
 *
 * Only the width and height of `inBox` are read, so input coordinates are
 * treated as relative to inBox's own top-left corner.
 *
 * @param inBox Box describing the size of the source coordinate space.
 * @param toBox Target box (position and size) in the destination space.
 * @returns 3x3 homogeneous matrix equal to translation * scale.
 */
export function computeTransformToBox(inBox: Box, toBox: Box): Matrix {
    const scaleX = toBox.width / inBox.width;
    const scaleY = toBox.height / inBox.height;
    const moveToTarget = translation(toBox.x, toBox.y);

    return moveToTarget.mmul(scale(scaleX, scaleY));
}
|
||||||
|
|
||||||
|
/**
 * Converts a Point into an `[x, y]` array.
 *
 * @param point Point to convert.
 */
export function pointToArray(point: Point) {
    const { x, y } = point;
    return [x, y];
}
|
||||||
|
|
||||||
|
/**
 * Applies a 3x3 homogeneous transform to a 2D point given as an array.
 *
 * The input array is not modified; only its first two entries (x, y) are
 * read.
 *
 * @param point Point as `[x, y, ...]`.
 * @param transform 3x3 homogeneous transform matrix.
 * @returns The product `transform * [x, y, 1]^T` flattened to a 3-element
 * array `[x', y', w]`.
 */
export function transformPointVec(point: number[], transform: Matrix) {
    // Build the homogeneous column vector from a fresh array so the caller's
    // `point` is left untouched.
    const homogeneous = new Matrix([[point[0]], [point[1]], [1]]);

    return transform.mmul(homogeneous).to1DArray();
}
|
||||||
|
|
||||||
|
/**
 * Applies a 3x3 homogeneous transform to a Point.
 *
 * @param point Point to transform.
 * @param transform Transform matrix.
 * @returns A new Point built from the first two components of the result.
 */
export function transformPoint(point: Point, transform: Matrix) {
    const [x, y] = transformPointVec(pointToArray(point), transform);
    return new Point(x, y);
}
|
||||||
|
|
||||||
|
/**
 * Applies a transform to every point of a list.
 *
 * @param points Points to transform.
 * @param transform Transform matrix.
 * @returns A new array with the transformed points, in input order.
 */
export function transformPoints(points: Point[], transform: Matrix) {
    return points.map(function applyTransform(point) {
        return transformPoint(point, transform);
    });
}
|
||||||
|
|
||||||
|
/**
 * Applies a transform to a Box by transforming its top-left and bottom-right
 * corners and building a new Box from the results.
 *
 * Only those two corners are transformed, so the result is the transformed
 * box itself only for transforms that keep boxes axis-aligned (translation,
 * scale).
 *
 * @param box Box to transform.
 * @param transform Transform matrix.
 */
export function transformBox(box: Box, transform: Matrix) {
    const { x: left, y: top } = transformPoint(box.topLeft, transform);
    const { x: right, y: bottom } = transformPoint(box.bottomRight, transform);

    return newBoxFromPoints(left, top, right, bottom);
}
|
Loading…
Reference in a new issue