Roundtrip

This commit is contained in:
Manav Rathi 2024-04-11 14:37:44 +05:30
parent 41f7b30ca0
commit a1d6ef43b4
No known key found for this signature in database
5 changed files with 33 additions and 32 deletions

View file

@@ -147,8 +147,8 @@ export const attachIPCHandlers = () => {
clipTextEmbedding(text),
);
ipcMain.handle("detectFaces", (_, imageData: Uint8Array) =>
detectFaces(imageData),
ipcMain.handle("detectFaces", (_, input: Float32Array) =>
detectFaces(input),
);
ipcMain.handle("faceEmbedding", (_, input: Float32Array) =>

View file

@@ -78,8 +78,29 @@ const faceEmbeddingSession = async () => {
return _faceEmbeddingSession;
};
export const detectFaces = async (inputImage: Uint8Array) => {
throw new Error("test");
/**
 * Run YOLO face detection over an already-preprocessed image tensor.
 *
 * `input` is the raw Float32Array data for a [1, 3, 640, 640] tensor — that is
 * the shape the ort.Tensor below is constructed with; preprocessing to that
 * shape happens on the caller's side (see the renderer-side change in this
 * commit, which now sends Float32Array over IPC instead of Uint8Array).
 *
 * Returns the data of the model's "output" tensor as opaque binary results;
 * interpreting it is up to the caller.
 */
export const detectFaces = async (input: Float32Array) => {
// NOTE(review): the commented-out block below is dead scratch code from an
// earlier in-class implementation (it references `this`, which does not exist
// here) — presumably kept only as a reference; candidate for deletion.
// console.log("start ort");
// this.onnxInferenceSession = await ort.InferenceSession.create(
// "/models/yoloface/yolov5s_face_640_640_dynamic.onnx",
// );
// const data = new Float32Array(1 * 3 * 640 * 640);
// const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]);
// // TODO(MR): onnx-yolo
// // const feeds: Record<string, ort.Tensor> = {};
// const feeds: Record<string, any> = {};
// const name = this.onnxInferenceSession.inputNames[0];
// feeds[name] = inputTensor;
// await this.onnxInferenceSession.run(feeds);
// console.log("start end");
// Lazily created (and cached) ONNX inference session for the face model.
const session = await faceDetectionSession();
// Wall-clock start; used only for the debug timing log below.
const t = Date.now();
const feeds = {
input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
};
const results = await session.run(feeds);
log.debug(() => `onnx/yolo inference took ${Date.now() - t} ms`);
return results["output"].data;
};
export const faceEmbedding = async (input: Float32Array) => {
@@ -89,19 +110,6 @@ export const faceEmbedding = async (input: Float32Array) => {
/*
private async initOnnx() {
console.log("start ort");
this.onnxInferenceSession = await ort.InferenceSession.create(
"/models/yoloface/yolov5s_face_640_640_dynamic.onnx",
);
const data = new Float32Array(1 * 3 * 640 * 640);
const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
const name = this.onnxInferenceSession.inputNames[0];
feeds[name] = inputTensor;
await this.onnxInferenceSession.run(feeds);
console.log("start end");
}
private async getOnnxInferenceSession() {

View file

@@ -143,8 +143,8 @@ const clipImageEmbedding = (jpegImageData: Uint8Array): Promise<Float32Array> =>
const clipTextEmbedding = (text: string): Promise<Float32Array> =>
    ipcRenderer.invoke("clipTextEmbedding", text);
// NOTE(review): the next two `detectFaces` declarations are the removed (old)
// and added (new) lines of a diff whose +/- markers were stripped — this is
// not two live declarations. Old signature took the image's Uint8Array bytes:
const detectFaces = (imageData: Uint8Array): Promise<Float32Array> =>
    ipcRenderer.invoke("detectFaces", imageData);
// New signature takes the preprocessed Float32Array tensor data instead,
// matching the main-process handler change elsewhere in this commit:
const detectFaces = (input: Float32Array): Promise<Float32Array> =>
    ipcRenderer.invoke("detectFaces", input);
const faceEmbedding = (input: Float32Array): Promise<Float32Array> =>
    ipcRenderer.invoke("faceEmbedding", input);

View file

@@ -1,3 +1,4 @@
import { ensureElectron } from "@/next/electron";
import { MAX_FACE_DISTANCE_PERCENT } from "constants/mlConfig";
import { euclidean } from "hdbscan";
import {
@@ -44,14 +45,7 @@ class YoloFaceDetectionService implements FaceDetectionService {
);
const data = preprocessResult.data;
const resized = preprocessResult.newSize;
const inputTensor = new ort.Tensor("float32", data, [1, 3, 640, 640]);
// TODO(MR): onnx-yolo
// const feeds: Record<string, ort.Tensor> = {};
const feeds: Record<string, any> = {};
feeds["input"] = inputTensor;
const inferenceSession = await this.getOnnxInferenceSession();
const runout = await inferenceSession.run(feeds);
const outputData = runout.output.data;
const outputData = await ensureElectron().detectFaces(data);
const faces = this.getFacesFromYoloOutput(
outputData as Float32Array,
0.7,

View file

@@ -221,16 +221,15 @@ export interface Electron {
* Detect faces in the given image using YOLO.
*
* Both the input and output are opaque binary data whose internal structure
* is model (YOLO) and our implementation specific. That said, specifically
* the {@link inputImage} a particular bitmap encoding of an image.
* is specific to our implementation and the model (YOLO) we use.
*/
detectFaces: (inputImage: Uint8Array) => Promise<Float32Array>;
detectFaces: (input: Float32Array) => Promise<Float32Array>;
/**
* Return a mobilefacenet embedding for the given face data.
* Return a MobileFaceNet embedding for the given face data.
*
* Both the input and output are opaque binary data whose internal structure
* is model (mobilefacenet) and our implementation specific.
* is specific to our implementation and the model (MobileFaceNet) we use.
*/
faceEmbedding: (input: Float32Array) => Promise<Float32Array>;