Explorar o código

[web][photos] solve TODOs

laurenspriem hai 1 ano
pai
achega
ae5496f306

+ 110 - 58
web/apps/photos/src/services/face/f-index.ts

@@ -27,8 +27,13 @@ import {
     pixelRGBBilinear,
     warpAffineFloat32List,
 } from "./image";
-import { transformFaceDetections } from "./transform-box";
-
+import {
+    Matrix as transformMatrix,
+    applyToPoint,
+    compose,
+    scale,
+    translate,
+} from "transformation-matrix";
 /**
  * Index faces in the given file.
  *
@@ -138,7 +143,7 @@ const indexFaces_ = async (enteFile: EnteFile, imageBitmap: ImageBitmap) => {
 /**
  * Detect faces in the given {@link imageBitmap}.
  *
- * The model used is YOLO, running in an ONNX runtime.
+ * The model used is YOLOv5Face, running in an ONNX runtime.
  */
 const detectFaces = async (
     imageBitmap: ImageBitmap,
@@ -149,16 +154,14 @@ const detectFaces = async (
     const { yoloInput, yoloSize } =
         convertToYOLOInputFloat32ChannelsFirst(imageBitmap);
     const yoloOutput = await workerBridge.detectFaces(yoloInput);
-    const faces = faceDetectionsFromYOLOOutput(yoloOutput);
+    const faces = filterExtractDetectionsFromYOLOOutput(yoloOutput);
     const faceDetections = transformFaceDetections(
         faces,
         rect(yoloSize),
         rect(imageBitmap),
     );
 
-    const maxFaceDistancePercent = Math.sqrt(2) / 100;
-    const maxFaceDistance = imageBitmap.width * maxFaceDistancePercent;
-    return removeDuplicateDetections(faceDetections, maxFaceDistance);
+    return naiveNonMaxSuppression(faceDetections, 0.4);
 };
 
 /**
@@ -214,14 +217,16 @@ const convertToYOLOInputFloat32ChannelsFirst = (imageBitmap: ImageBitmap) => {
 };
 
 /**
- * Extract detected faces from the YOLO's output.
+ * Extract detected faces from the YOLOv5Face's output.
  *
  * Only detections that exceed a minimum score are returned.
  *
  * @param rows A Float32Array of shape [25200, 16], where each row
  * represents a bounding box.
  */
-const faceDetectionsFromYOLOOutput = (rows: Float32Array): FaceDetection[] => {
+const filterExtractDetectionsFromYOLOOutput = (
+    rows: Float32Array,
+): FaceDetection[] => {
     const faces: FaceDetection[] = [];
     // Iterate over each row.
     for (let i = 0; i < rows.length; i += 16) {
@@ -266,61 +271,111 @@ const faceDetectionsFromYOLOOutput = (rows: Float32Array): FaceDetection[] => {
 };
 
 /**
- * Removes duplicate face detections from an array of detections.
- *
- * This function sorts the detections by their probability in descending order,
- * then iterates over them.
- *
- * For each detection, it calculates the Euclidean distance to all other
- * detections.
- *
- * If the distance is less than or equal to the specified threshold
- * (`withinDistance`), the other detection is considered a duplicate and is
- * removed.
- *
- * @param detections - An array of face detections to remove duplicates from.
- *
- * @param withinDistance - The maximum Euclidean distance between two detections
- * for them to be considered duplicates.
- *
- * @returns An array of face detections with duplicates removed.
+ * Transform the given {@link faceDetections} from their coordinate system in
+ * which they were detected ({@link inBox}) back to the coordinate system of the
+ * original image ({@link toBox}).
  */
-const removeDuplicateDetections = (
-    detections: FaceDetection[],
-    withinDistance: number,
-) => {
-    detections.sort((a, b) => b.probability - a.probability);
+const transformFaceDetections = (
+    faceDetections: FaceDetection[],
+    inBox: Box,
+    toBox: Box,
+): FaceDetection[] => {
+    const transform = boxTransformationMatrix(inBox, toBox);
+    return faceDetections.map((f) => ({
+        box: transformBox(f.box, transform),
+        landmarks: f.landmarks.map((p) => transformPoint(p, transform)),
+        probability: f.probability,
+    }));
+};
 
-    const dupIndices = new Set<number>();
-    for (let i = 0; i < detections.length; i++) {
-        if (dupIndices.has(i)) continue;
+const boxTransformationMatrix = (inBox: Box, toBox: Box): transformMatrix =>
+    compose(
+        translate(toBox.x, toBox.y),
+        scale(toBox.width / inBox.width, toBox.height / inBox.height),
+    );
 
-        for (let j = i + 1; j < detections.length; j++) {
-            if (dupIndices.has(j)) continue;
+const transformPoint = (point: Point, transform: transformMatrix) => {
+    const txdPoint = applyToPoint(transform, point);
+    return new Point(txdPoint.x, txdPoint.y);
+};
+
+const transformBox = (box: Box, transform: transformMatrix) => {
+    const topLeft = transformPoint(new Point(box.x, box.y), transform);
+    const bottomRight = transformPoint(
+        new Point(box.x + box.width, box.y + box.height),
+        transform,
+    );
 
-            const centeri = faceDetectionCenter(detections[i]);
-            const centerj = faceDetectionCenter(detections[j]);
-            const dist = euclidean(
-                [centeri.x, centeri.y],
-                [centerj.x, centerj.y],
-            );
+    return new Box({
+        x: topLeft.x,
+        y: topLeft.y,
+        width: bottomRight.x - topLeft.x,
+        height: bottomRight.y - topLeft.y,
+    });
+};
 
-            if (dist <= withinDistance) dupIndices.add(j);
+/**
+ * Remove overlapping faces from an array of face detections through non-maximum suppression algorithm.
+ * 
+ * This function sorts the detections by their probability in descending order, then iterates over them.
+ * 
+ * For each detection, it calculates the Intersection over Union (IoU) with all other detections.
+ * 
+ * If the IoU is greater than or equal to the specified threshold (`iouThreshold`), the other detection is considered overlapping and is removed.
+ * 
+ * @param detections - An array of face detections to remove overlapping faces from.
+ * 
+ * @param iouThreshold - The minimum IoU between two detections for them to be considered overlapping.
+ * 
+ * @returns An array of face detections with overlapping faces removed
+ */
+const naiveNonMaxSuppression = (
+    detections: FaceDetection[],
+    iouThreshold: number,
+): FaceDetection[] => {
+    // Sort the detections by score, the highest first
+    detections.sort((a, b) => b.probability - a.probability);
+
+    // Loop through the detections and calculate the IOU
+    for (let i = 0; i < detections.length - 1; i++) {
+        for (let j = i + 1; j < detections.length; j++) {
+            const iou = calculateIOU(detections[i], detections[j]);
+            if (iou >= iouThreshold) {
+                detections.splice(j, 1);
+                j--;
+            }
         }
     }
 
-    return detections.filter((_, i) => !dupIndices.has(i));
+    return detections;
 };
 
-const faceDetectionCenter = (detection: FaceDetection) => {
-    const center = new Point(0, 0);
-    // TODO-ML(LAURENS): first 4 landmarks is applicable to blazeface only this
-    // needs to consider eyes, nose and mouth landmarks to take center
-    detection.landmarks?.slice(0, 4).forEach((p) => {
-        center.x += p.x;
-        center.y += p.y;
-    });
-    return new Point(center.x / 4, center.y / 4);
+const calculateIOU = (a: FaceDetection, b: FaceDetection): number => {
+    const intersectionMinX = Math.max(a.box.x, b.box.x);
+    const intersectionMinY = Math.max(a.box.y, b.box.y);
+    const intersectionMaxX = Math.min(
+        a.box.x + a.box.width,
+        b.box.x + b.box.width,
+    );
+    const intersectionMaxY = Math.min(
+        a.box.y + a.box.height,
+        b.box.y + b.box.height,
+    );
+
+    const intersectionWidth = intersectionMaxX - intersectionMinX;
+    const intersectionHeight = intersectionMaxY - intersectionMinY;
+
+    if (intersectionWidth < 0 || intersectionHeight < 0) {
+        return 0.0; // If boxes do not overlap, IoU is 0
+    }
+
+    const areaA = a.box.width * a.box.height;
+    const areaB = b.box.width * b.box.height;
+
+    const intersectionArea = intersectionWidth * intersectionHeight;
+    const unionArea = areaA + areaB - intersectionArea;
+
+    return intersectionArea / unionArea;
 };
 
 const makeFaceID = (
@@ -689,14 +744,12 @@ const extractFaceCrop = (
     const scaleForPadding = 1 + padding * 2;
     const paddedBox = roundBox(enlargeBox(alignmentBox, scaleForPadding));
 
-    // TODO-ML(LAURENS): The rotation doesn't seem to be used? it's set to 0.
-    return cropWithRotation(imageBitmap, paddedBox, 0, 256);
+    return cropWithRotation(imageBitmap, paddedBox, 256);
 };
 
 const cropWithRotation = (
     imageBitmap: ImageBitmap,
     cropBox: Box,
-    rotation: number,
     maxDimension: number,
 ) => {
     const box = roundBox(cropBox);
@@ -714,7 +767,6 @@ const cropWithRotation = (
     offscreenCtx.imageSmoothingQuality = "high";
 
     offscreenCtx.translate(outputSize.width / 2, outputSize.height / 2);
-    rotation && offscreenCtx.rotate(rotation);
 
     const outputBox = new Box({
         x: -outputSize.width / 2,

+ 0 - 57
web/apps/photos/src/services/face/transform-box.ts

@@ -1,57 +0,0 @@
-import { Box, Point } from "services/face/geom";
-import type { FaceDetection } from "services/face/types";
-// TODO-ML(LAURENS): Do we need two separate Matrix libraries?
-//
-// Keeping this in a separate file so that we can audit this. If these can be
-// expressed using ml-matrix, then we can move this code to f-index.ts
-import {
-    Matrix,
-    applyToPoint,
-    compose,
-    scale,
-    translate,
-} from "transformation-matrix";
-
-/**
- * Transform the given {@link faceDetections} from their coordinate system in
- * which they were detected ({@link inBox}) back to the coordinate system of the
- * original image ({@link toBox}).
- */
-export const transformFaceDetections = (
-    faceDetections: FaceDetection[],
-    inBox: Box,
-    toBox: Box,
-): FaceDetection[] => {
-    const transform = boxTransformationMatrix(inBox, toBox);
-    return faceDetections.map((f) => ({
-        box: transformBox(f.box, transform),
-        landmarks: f.landmarks.map((p) => transformPoint(p, transform)),
-        probability: f.probability,
-    }));
-};
-
-const boxTransformationMatrix = (inBox: Box, toBox: Box): Matrix =>
-    compose(
-        translate(toBox.x, toBox.y),
-        scale(toBox.width / inBox.width, toBox.height / inBox.height),
-    );
-
-const transformPoint = (point: Point, transform: Matrix) => {
-    const txdPoint = applyToPoint(transform, point);
-    return new Point(txdPoint.x, txdPoint.y);
-};
-
-const transformBox = (box: Box, transform: Matrix) => {
-    const topLeft = transformPoint(new Point(box.x, box.y), transform);
-    const bottomRight = transformPoint(
-        new Point(box.x + box.width, box.y + box.height),
-        transform,
-    );
-
-    return new Box({
-        x: topLeft.x,
-        y: topLeft.y,
-        width: bottomRight.x - topLeft.x,
-        height: bottomRight.y - topLeft.y,
-    });
-};