Manav Rathi 1 rok temu
rodzic
commit
a1d6ef43b4

+ 2 - 2
desktop/src/main/ipc.ts

@@ -147,8 +147,8 @@ export const attachIPCHandlers = () => {
         clipTextEmbedding(text),
     );
 
-    ipcMain.handle("detectFaces", (_, imageData: Uint8Array) =>
-        detectFaces(imageData),
+    ipcMain.handle("detectFaces", (_, input: Float32Array) =>
+        detectFaces(input),
     );
 
     ipcMain.handle("faceEmbedding", (_, input: Float32Array) =>

+ 23 - 15
desktop/src/main/services/ml-face.ts

@@ -78,8 +78,29 @@ const faceEmbeddingSession = async () => {
     return _faceEmbeddingSession;
 };
 
-export const detectFaces = async (inputImage: Uint8Array) => {
-    throw new Error("test");
+export const detectFaces = async (input: Float32Array) => {
+    // console.log("start ort");
+    // this.onnxInferenceSession = await ort.InferenceSession.create(
+    //     "/models/yoloface/yolov5s_face_640_640_dynamic.onnx",
+    // );
+    // const data = new Float32Array(1 * 3 * 640 * 640);
+    // const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]);
+    // // TODO(MR): onnx-yolo
+    // // const feeds: Record<string, ort.Tensor> = {};
+    // const feeds: Record<string, any> = {};
+    // const name = this.onnxInferenceSession.inputNames[0];
+    // feeds[name] = inputTensor;
+    // await this.onnxInferenceSession.run(feeds);
+    // console.log("start end");
+
+    const session = await faceDetectionSession();
+    const t = Date.now();
+    const feeds = {
+        input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
+    };
+    const results = await session.run(feeds);
+    log.debug(() => `onnx/yolo inference took ${Date.now() - t} ms`);
+    return results["output"].data;
 };
 
 export const faceEmbedding = async (input: Float32Array) => {
@@ -89,19 +110,6 @@ export const faceEmbedding = async (input: Float32Array) => {
 /*
 
 private async initOnnx() {
-    console.log("start ort");
-    this.onnxInferenceSession = await ort.InferenceSession.create(
-        "/models/yoloface/yolov5s_face_640_640_dynamic.onnx",
-    );
-    const data = new Float32Array(1 * 3 * 640 * 640);
-    const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]);
-    // TODO(MR): onnx-yolo
-    // const feeds: Record<string, ort.Tensor> = {};
-    const feeds: Record<string, any> = {};
-    const name = this.onnxInferenceSession.inputNames[0];
-    feeds[name] = inputTensor;
-    await this.onnxInferenceSession.run(feeds);
-    console.log("start end");
 }
 
 private async getOnnxInferenceSession() {

+ 2 - 2
desktop/src/preload.ts

@@ -143,8 +143,8 @@ const clipImageEmbedding = (jpegImageData: Uint8Array): Promise<Float32Array> =>
 const clipTextEmbedding = (text: string): Promise<Float32Array> =>
     ipcRenderer.invoke("clipTextEmbedding", text);
 
-const detectFaces = (imageData: Uint8Array): Promise<Float32Array> =>
-    ipcRenderer.invoke("detectFaces", imageData);
+const detectFaces = (input: Float32Array): Promise<Float32Array> =>
+    ipcRenderer.invoke("detectFaces", input);
 
 const faceEmbedding = (input: Float32Array): Promise<Float32Array> =>
     ipcRenderer.invoke("faceEmbedding", input);

+ 2 - 8
web/apps/photos/src/services/machineLearning/yoloFaceDetectionService.ts

@@ -1,3 +1,4 @@
+import { ensureElectron } from "@/next/electron";
 import { MAX_FACE_DISTANCE_PERCENT } from "constants/mlConfig";
 import { euclidean } from "hdbscan";
 import {
@@ -44,14 +45,7 @@ class YoloFaceDetectionService implements FaceDetectionService {
             );
         const data = preprocessResult.data;
         const resized = preprocessResult.newSize;
-        const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]);
-        // TODO(MR): onnx-yolo
-        // const feeds: Record<string, ort.Tensor> = {};
-        const feeds: Record<string, any> = {};
-        feeds["input"] = inputTensor;
-        const inferenceSession = await this.getOnnxInferenceSession();
-        const runout = await inferenceSession.run(feeds);
-        const outputData = runout.output.data;
+        const outputData = await ensureElectron().detectFaces(data);
         const faces = this.getFacesFromYoloOutput(
             outputData as Float32Array,
             0.7,

+ 4 - 5
web/packages/next/types/ipc.ts

@@ -221,16 +221,15 @@ export interface Electron {
      * Detect faces in the given image using YOLO.
      *
      * Both the input and output are opaque binary data whose internal structure
-     * is model (YOLO) and our implementation specific. That said, specifically
-     * the {@link inputImage} a particular bitmap encoding of an image.
+     * is specific to our implementation and the model (YOLO) we use.
      */
-    detectFaces: (inputImage: Uint8Array) => Promise<Float32Array>;
+    detectFaces: (input: Float32Array) => Promise<Float32Array>;
 
     /**
-     * Return a mobilefacenet embedding for the given face data.
+     * Return a MobileFaceNet embedding for the given face data.
      *
      * Both the input and output are opaque binary data whose internal structure
-     * is model (mobilefacenet) and our implementation specific.
+     * is specific to our implementation and the model (MobileFaceNet) we use.
      */
     faceEmbedding: (input: Float32Array) => Promise<Float32Array>;