diff --git a/web/apps/photos/src/utils/machineLearning/faceAlign.ts b/web/apps/photos/src/utils/machineLearning/faceAlign.ts index bedef835f..3b8918dac 100644 --- a/web/apps/photos/src/utils/machineLearning/faceAlign.ts +++ b/web/apps/photos/src/utils/machineLearning/faceAlign.ts @@ -131,25 +131,6 @@ export function extractFaceImage( }); } -export function tfExtractFaceImages( - image: tf.Tensor3D | tf.Tensor4D, - alignments: Array, - faceSize: number, -): tf.Tensor4D { - return tf.tidy(() => { - const tf4dFloat32Image = toTensor4D(image, "float32"); - const faceImages = new Array(alignments.length); - for (let i = 0; i < alignments.length; i++) { - faceImages[i] = tf.squeeze( - extractFaceImage(tf4dFloat32Image, alignments[i], faceSize), - [0], - ); - } - - return tf.stack(faceImages) as tf.Tensor4D; - }); -} - export function getAlignedFaceBox(alignment: FaceAlignment) { return new Box({ x: alignment.center.x - alignment.size / 2, @@ -200,59 +181,3 @@ export function ibExtractFaceImages( ibExtractFaceImage(image, alignment, faceSize), ); } - -const BLAZEFACE_LEFT_EYE_INDEX = 0; -const BLAZEFACE_RIGHT_EYE_INDEX = 1; -// const BLAZEFACE_NOSE_INDEX = 2; -const BLAZEFACE_MOUTH_INDEX = 3; - -export function getRotatedFaceImage( - image: tf.Tensor3D | tf.Tensor4D, - faceDetection: FaceDetection, - padding: number = 1.5, -): tf.Tensor4D { - const paddedBox = enlargeBox(faceDetection.box, padding); - // log.info("paddedBox", paddedBox); - const landmarkPoints = faceDetection.landmarks; - - return tf.tidy(() => { - const tf4dFloat32Image = toTensor4D(image, "float32"); - let angle = 0; - const leftEye = landmarkPoints[BLAZEFACE_LEFT_EYE_INDEX]; - const rightEye = landmarkPoints[BLAZEFACE_RIGHT_EYE_INDEX]; - const foreheadCenter = getBoxCenterPt(leftEye, rightEye); - - angle = computeRotation( - landmarkPoints[BLAZEFACE_MOUTH_INDEX], - foreheadCenter, - ); // landmarkPoints[BLAZEFACE_NOSE_INDEX] - // angle = computeRotation(leftEye, rightEye); - // log.info('angle: ', angle); - - const faceCenter = getBoxCenter(faceDetection.box); - // log.info('faceCenter: ', faceCenter); - const faceCenterNormalized: [number, number] = [ - faceCenter.x / tf4dFloat32Image.shape[2], - faceCenter.y / tf4dFloat32Image.shape[1], - ]; - // log.info('faceCenterNormalized: ', faceCenterNormalized); - - let rotatedImage = tf4dFloat32Image; - if (angle !== 0) { - rotatedImage = tf.image.rotateWithOffset( - tf4dFloat32Image, - angle, - 0, - faceCenterNormalized, - ); - } - - const faceImageTensor = extractFaces( - rotatedImage, - [paddedBox], - paddedBox.width > 224 ? 448 : 224, - ); - return faceImageTensor; - // return tf.gather(faceImageTensor, 0); - }); -} diff --git a/web/apps/photos/src/utils/machineLearning/index.ts b/web/apps/photos/src/utils/machineLearning/index.ts index 6603dde69..9b47f1a7b 100644 --- a/web/apps/photos/src/utils/machineLearning/index.ts +++ b/web/apps/photos/src/utils/machineLearning/index.ts @@ -1,7 +1,6 @@ import log from "@/next/log"; import { CACHES } from "@ente/shared/storage/cacheStorage/constants"; import { cached } from "@ente/shared/storage/cacheStorage/helpers"; -import * as tf from "@tensorflow/tfjs-core"; import { NormalizedFace } from "blazeface-back"; import { FILE_TYPE } from "constants/file"; import { BLAZEFACE_FACE_SIZE } from "constants/mlConfig"; @@ -52,89 +51,6 @@ export function f32Average(descriptors: Float32Array[]) { return avg; } -export function isTensor(tensor: any, dim: number) { - return tensor instanceof tf.Tensor && tensor.shape.length === dim; -} - -export function isTensor1D(tensor: any): tensor is tf.Tensor1D { - return isTensor(tensor, 1); -} - -export function isTensor2D(tensor: any): tensor is tf.Tensor2D { - return isTensor(tensor, 2); -} - -export function isTensor3D(tensor: any): tensor is tf.Tensor3D { - return isTensor(tensor, 3); -} - -export function isTensor4D(tensor: any): tensor is tf.Tensor4D { - return isTensor(tensor, 4); -} - -export function toTensor4D( - image: tf.Tensor3D | tf.Tensor4D, - dtype?: tf.DataType, -) { - return tf.tidy(() => { - let reshapedImage: tf.Tensor4D; - if (isTensor3D(image)) { - reshapedImage = tf.expandDims(image, 0); - } else if (isTensor4D(image)) { - reshapedImage = image; - } else { - throw Error("toTensor4D only supports Tensor3D and Tensor4D input"); - } - if (dtype) { - reshapedImage = tf.cast(reshapedImage, dtype); - } - - return reshapedImage; - }); -} - -export function imageBitmapsToTensor4D(imageBitmaps: Array) { - return tf.tidy(() => { - const tfImages = imageBitmaps.map((ib) => tf.browser.fromPixels(ib)); - return tf.stack(tfImages) as tf.Tensor4D; - }); -} - -export function extractFaces( - image: tf.Tensor3D | tf.Tensor4D, - facebBoxes: Array, - faceSize: number, -) { - return tf.tidy(() => { - const reshapedImage = toTensor4D(image, "float32"); - - const boxes = facebBoxes.map((box) => { - const normalized = box.rescale({ - width: 1 / reshapedImage.shape[2], - height: 1 / reshapedImage.shape[1], - }); - - return [ - normalized.top, - normalized.left, - normalized.bottom, - normalized.right, - ]; - }); - - // log.info('boxes: ', boxes[0]); - - const faceImagesTensor = tf.image.cropAndResize( - reshapedImage, - boxes, - tf.fill([boxes.length], 0, "int32"), - [faceSize, faceSize], - ); - - return faceImagesTensor; - }); -} - export function newBox(x: number, y: number, width: number, height: number) { return new Box({ x, y, width, height }); } @@ -304,25 +220,10 @@ export function getFaceId(detectedFace: DetectedFace, imageDims: Dimensions) { return faceID; } -export async function getTFImage(blob): Promise { - const imageBitmap = await createImageBitmap(blob); - const tfImage = tf.browser.fromPixels(imageBitmap); - imageBitmap.close(); - - return tfImage; -} - export async function getImageBlobBitmap(blob: Blob): Promise { return await createImageBitmap(blob); } -// export async function getTFImageUsingJpegJS(blob: Blob): Promise { -// const imageData = jpegjs.decode(await blob.arrayBuffer()); -// const tfImage = tf.browser.fromPixels(imageData); - -// return new TFImageBitmap(undefined, tfImage); -// } - async function getOriginalFile(file: EnteFile, queue?: PQueue) { let fileStream; if (queue) {