Remove tf utils

This commit is contained in:
Manav Rathi 2024-04-11 11:13:21 +05:30
parent 1ad5cb83f9
commit 3182d67ca1
No known key found for this signature in database
2 changed files with 0 additions and 174 deletions

View file

@@ -131,25 +131,6 @@ export function extractFaceImage(
});
}
export function tfExtractFaceImages(
image: tf.Tensor3D | tf.Tensor4D,
alignments: Array<FaceAlignment>,
faceSize: number,
): tf.Tensor4D {
return tf.tidy(() => {
const tf4dFloat32Image = toTensor4D(image, "float32");
const faceImages = new Array<tf.Tensor3D>(alignments.length);
for (let i = 0; i < alignments.length; i++) {
faceImages[i] = tf.squeeze(
extractFaceImage(tf4dFloat32Image, alignments[i], faceSize),
[0],
);
}
return tf.stack(faceImages) as tf.Tensor4D;
});
}
export function getAlignedFaceBox(alignment: FaceAlignment) {
return new Box({
x: alignment.center.x - alignment.size / 2,
@@ -200,59 +181,3 @@ export function ibExtractFaceImages(
ibExtractFaceImage(image, alignment, faceSize),
);
}
// Indices into the `landmarks` array produced by the BlazeFace detector.
// The ordering is fixed by the model's output format.
const BLAZEFACE_LEFT_EYE_INDEX = 0;
const BLAZEFACE_RIGHT_EYE_INDEX = 1;
// const BLAZEFACE_NOSE_INDEX = 2;
const BLAZEFACE_MOUTH_INDEX = 3;

/**
 * Crop a rotation-corrected face image out of {@link image}.
 *
 * The face's in-plane rotation is estimated from the mouth and the midpoint
 * of the eyes, the image is counter-rotated about the face center, and the
 * (enlarged) detection box is then cropped out.
 *
 * @param image A 3D (single) or 4D (batched) image tensor.
 * @param faceDetection The detection whose box and landmarks to use.
 * @param padding Factor by which the detection box is enlarged before
 * cropping (default 1.5).
 * @returns A 4D tensor containing the cropped face image(s).
 */
export function getRotatedFaceImage(
    image: tf.Tensor3D | tf.Tensor4D,
    faceDetection: FaceDetection,
    padding: number = 1.5,
): tf.Tensor4D {
    const paddedBox = enlargeBox(faceDetection.box, padding);
    // log.info("paddedBox", paddedBox);
    const landmarkPoints = faceDetection.landmarks;
    return tf.tidy(() => {
        const tf4dFloat32Image = toTensor4D(image, "float32");
        let angle = 0;
        const leftEye = landmarkPoints[BLAZEFACE_LEFT_EYE_INDEX];
        const rightEye = landmarkPoints[BLAZEFACE_RIGHT_EYE_INDEX];
        // Midpoint between the two eyes, used as an approximation of the
        // forehead center.
        const foreheadCenter = getBoxCenterPt(leftEye, rightEye);
        // Angle of the mouth → forehead axis (semantics per computeRotation).
        // Commented-out alternatives below record other axes that were tried.
        angle = computeRotation(
            landmarkPoints[BLAZEFACE_MOUTH_INDEX],
            foreheadCenter,
        ); // landmarkPoints[BLAZEFACE_NOSE_INDEX]
        // angle = computeRotation(leftEye, rightEye);
        // log.info('angle: ', angle);
        const faceCenter = getBoxCenter(faceDetection.box);
        // log.info('faceCenter: ', faceCenter);
        // rotateWithOffset takes the rotation center as (x, y) normalized to
        // [0, 1]; shape[2] is the image width, shape[1] its height.
        const faceCenterNormalized: [number, number] = [
            faceCenter.x / tf4dFloat32Image.shape[2],
            faceCenter.y / tf4dFloat32Image.shape[1],
        ];
        // log.info('faceCenterNormalized: ', faceCenterNormalized);
        let rotatedImage = tf4dFloat32Image;
        if (angle !== 0) {
            rotatedImage = tf.image.rotateWithOffset(
                tf4dFloat32Image,
                angle,
                0,
                faceCenterNormalized,
            );
        }
        // Larger faces get a larger output crop to retain detail.
        const faceImageTensor = extractFaces(
            rotatedImage,
            [paddedBox],
            paddedBox.width > 224 ? 448 : 224,
        );
        return faceImageTensor;
        // return tf.gather(faceImageTensor, 0);
    });
}

View file

@@ -1,7 +1,6 @@
import log from "@/next/log";
import { CACHES } from "@ente/shared/storage/cacheStorage/constants";
import { cached } from "@ente/shared/storage/cacheStorage/helpers";
import * as tf from "@tensorflow/tfjs-core";
import { NormalizedFace } from "blazeface-back";
import { FILE_TYPE } from "constants/file";
import { BLAZEFACE_FACE_SIZE } from "constants/mlConfig";
@@ -52,89 +51,6 @@ export function f32Average(descriptors: Float32Array[]) {
return avg;
}
export function isTensor(tensor: any, dim: number) {
return tensor instanceof tf.Tensor && tensor.shape.length === dim;
}
/** Type guard: true if {@link tensor} is a 1-dimensional tf.Tensor. */
export function isTensor1D(tensor: unknown): tensor is tf.Tensor1D {
    return isTensor(tensor, 1);
}
/** Type guard: true if {@link tensor} is a 2-dimensional tf.Tensor. */
export function isTensor2D(tensor: unknown): tensor is tf.Tensor2D {
    return isTensor(tensor, 2);
}
/** Type guard: true if {@link tensor} is a 3-dimensional tf.Tensor. */
export function isTensor3D(tensor: unknown): tensor is tf.Tensor3D {
    return isTensor(tensor, 3);
}
/** Type guard: true if {@link tensor} is a 4-dimensional tf.Tensor. */
export function isTensor4D(tensor: unknown): tensor is tf.Tensor4D {
    return isTensor(tensor, 4);
}
/**
 * Ensure that the given image tensor has a leading batch dimension.
 *
 * A 4D input is passed through unchanged; a 3D input gets a batch dimension
 * of size 1 prepended. When {@link dtype} is given, the result is cast to it.
 *
 * @throws If {@link image} is neither a 3D nor a 4D tensor at runtime.
 */
export function toTensor4D(
    image: tf.Tensor3D | tf.Tensor4D,
    dtype?: tf.DataType,
) {
    return tf.tidy(() => {
        let batched: tf.Tensor4D;
        if (isTensor4D(image)) {
            batched = image;
        } else if (isTensor3D(image)) {
            batched = tf.expandDims(image, 0);
        } else {
            throw Error("toTensor4D only supports Tensor3D and Tensor4D input");
        }
        return dtype ? tf.cast(batched, dtype) : batched;
    });
}
/**
 * Stack the pixels of the given {@link ImageBitmap}s into a single 4D
 * (batched) tensor.
 */
export function imageBitmapsToTensor4D(imageBitmaps: Array<ImageBitmap>) {
    return tf.tidy(() => {
        const perImageTensors = imageBitmaps.map((imageBitmap) =>
            tf.browser.fromPixels(imageBitmap),
        );
        return tf.stack(perImageTensors) as tf.Tensor4D;
    });
}
/**
 * Crop and resize the given face boxes out of {@link image}.
 *
 * @param image A 3D (single) or 4D (batched) image tensor. All crops are
 * taken from the first image in the batch.
 * @param faceBoxes Face bounding boxes in pixel coordinates of the image.
 * (Parameter renamed from the misspelled `facebBoxes`; positional calls are
 * unaffected.)
 * @param faceSize Width and height (px) of each square output crop.
 * @returns A 4D tensor with one `faceSize` × `faceSize` crop per box.
 */
export function extractFaces(
    image: tf.Tensor3D | tf.Tensor4D,
    faceBoxes: Array<Box>,
    faceSize: number,
) {
    return tf.tidy(() => {
        const reshapedImage = toTensor4D(image, "float32");
        // cropAndResize expects boxes normalized to [0, 1] in
        // [top, left, bottom, right] order; shape[2] is the image width,
        // shape[1] its height.
        const boxes = faceBoxes.map((box) => {
            const normalized = box.rescale({
                width: 1 / reshapedImage.shape[2],
                height: 1 / reshapedImage.shape[1],
            });
            return [
                normalized.top,
                normalized.left,
                normalized.bottom,
                normalized.right,
            ];
        });
        // log.info('boxes: ', boxes[0]);
        const faceImagesTensor = tf.image.cropAndResize(
            reshapedImage,
            boxes,
            // Box index 0 for every box: crop from the first batch image.
            tf.fill([boxes.length], 0, "int32"),
            [faceSize, faceSize],
        );
        return faceImagesTensor;
    });
}
/** Convenience constructor for a {@link Box} from its four components. */
export function newBox(x: number, y: number, width: number, height: number) {
    return new Box({ x, y, width, height });
}
@@ -304,25 +220,10 @@ export function getFaceId(detectedFace: DetectedFace, imageDims: Dimensions) {
return faceID;
}
/**
 * Decode the given blob into a tensor of its pixels.
 *
 * The intermediate ImageBitmap is closed in a `finally` so its graphics
 * resources are released even if pixel extraction throws (the original
 * leaked it on that path). The previously implicitly-`any` parameter is now
 * typed `Blob`, matching the sibling {@link getImageBlobBitmap}.
 */
export async function getTFImage(blob: Blob): Promise<tf.Tensor3D> {
    const imageBitmap = await createImageBitmap(blob);
    try {
        return tf.browser.fromPixels(imageBitmap);
    } finally {
        imageBitmap.close();
    }
}
/** Decode the given {@link Blob} into an {@link ImageBitmap}. */
export async function getImageBlobBitmap(blob: Blob): Promise<ImageBitmap> {
    const bitmap = await createImageBitmap(blob);
    return bitmap;
}
// export async function getTFImageUsingJpegJS(blob: Blob): Promise<TFImageBitmap> {
// const imageData = jpegjs.decode(await blob.arrayBuffer());
// const tfImage = tf.browser.fromPixels(imageData);
// return new TFImageBitmap(undefined, tfImage);
// }
async function getOriginalFile(file: EnteFile, queue?: PQueue) {
let fileStream;
if (queue) {