Manav Rathi, 1 year ago
Current commit 926bc33c79
3 files changed: 65 insertions, 151 deletions

  1. desktop/src/main/services/ml-clip.ts (+8 -36)
  2. desktop/src/main/services/ml-face.ts (+13 -114)
  3. desktop/src/main/services/ml.ts (+44 -1)

desktop/src/main/services/ml-clip.ts (+8 -36)

@@ -18,34 +18,17 @@ import { deleteTempFile } from "./ffmpeg";
 import {
     createInferenceSession,
     downloadModel,
-    modelPathDownloadingIfNeeded,
+    makeCachedInferenceSession,
     modelSavePath,
 } from "./ml";
 
 const textModelName = "clip-text-vit-32-uint8.onnx";
 const textModelByteSize = 64173509; // 61.2 MB
 
-const imageModelName = "clip-image-vit-32-float32.onnx";
-const imageModelByteSize = 351468764; // 335.2 MB
-
-let activeImageModelDownload: Promise<string> | undefined;
-
-const imageModelPathDownloadingIfNeeded = async () => {
-    try {
-        if (activeImageModelDownload) {
-            log.info("Waiting for CLIP image model download to finish");
-            await activeImageModelDownload;
-        } else {
-            activeImageModelDownload = modelPathDownloadingIfNeeded(
-                imageModelName,
-                imageModelByteSize,
-            );
-            return await activeImageModelDownload;
-        }
-    } finally {
-        activeImageModelDownload = undefined;
-    }
-};
+const cachedCLIPImageSession = makeCachedInferenceSession(
+    "clip-image-vit-32-float32.onnx",
+    351468764 /* 335.2 MB */,
+);
 
 let textModelDownloadInProgress = false;
 
@@ -90,18 +73,6 @@ const textModelPathDownloadingIfNeeded = async () => {
     return modelPath;
 };
 
-let imageSessionPromise: Promise<any> | undefined;
-
-const onnxImageSession = async () => {
-    if (!imageSessionPromise) {
-        imageSessionPromise = (async () => {
-            const modelPath = await imageModelPathDownloadingIfNeeded();
-            return createInferenceSession(modelPath);
-        })();
-    }
-    return imageSessionPromise;
-};
-
 let _textSession: any = null;
 
 const onnxTextSession = async () => {
@@ -124,7 +95,7 @@ export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
 };
 
 const clipImageEmbedding_ = async (jpegFilePath: string) => {
-    const session = await onnxImageSession();
+    const session = await cachedCLIPImageSession();
     const t1 = Date.now();
     const rgbData = await getRGBData(jpegFilePath);
     const feeds = {
@@ -136,7 +107,8 @@ const clipImageEmbedding_ = async (jpegFilePath: string) => {
         () =>
             `onnx/clip image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
     );
-    const imageEmbedding = results["output"].data; // Float32Array
+    /* Need this model-specific cast to type the result */
+    const imageEmbedding = results["output"].data as Float32Array;
     return normalizeEmbedding(imageEmbedding);
 };
 
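
Aside: the rewritten clipImageEmbedding_ above ends by passing the raw
output through normalizeEmbedding, which lives elsewhere in ml-clip.ts and
is not part of this diff. Presumably it scales the vector to unit L2 norm,
so that downstream cosine-similarity comparisons reduce to plain dot
products. A minimal sketch of such a helper under that assumption (this is
illustrative, not the implementation in the repository):

    /* Hypothetical sketch of normalizeEmbedding: return a copy of the
     * embedding scaled to unit Euclidean (L2) norm. */
    const normalizeEmbedding = (embedding: Float32Array): Float32Array => {
        let sumOfSquares = 0;
        for (const v of embedding) sumOfSquares += v * v;
        const norm = Math.sqrt(sumOfSquares);
        // Leave an all-zero vector as-is to avoid dividing by zero.
        return norm === 0 ? embedding : embedding.map((v) => v / norm);
    };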

desktop/src/main/services/ml-face.ts (+13 -114)

@@ -8,121 +8,20 @@
  */
 import * as ort from "onnxruntime-node";
 import log from "../log";
-import { createInferenceSession, modelPathDownloadingIfNeeded } from "./ml";
+import { makeCachedInferenceSession } from "./ml";
 
-const faceDetectionModelName = "yolov5s_face_640_640_dynamic.onnx";
-const faceDetectionModelByteSize = 30762872; // 29.3 MB
+const cachedFaceDetectionSession = makeCachedInferenceSession(
+    "yolov5s_face_640_640_dynamic.onnx",
+    30762872 /* 29.3 MB */,
+);
 
-const faceEmbeddingModelName = "mobilefacenet_opset15.onnx";
-const faceEmbeddingModelByteSize = 5286998; // 5 MB
-
-/**
- * Return a function that can be used to trigger a download of the specified
- * model, and the creation of an ONNX inference session initialized using it.
- *
- * Multiple parallel calls to the returned function are fine: it ensures that
- * the model will be downloaded and the session created from it only once.
- * Meanwhile, all pending calls will just await the same promise.
- *
- * And once the promise resolves, the created ONNX inference session will be
- * cached, so subsequent calls to the returned function will just reuse the
- * same session.
- *
- * {@link makeCachedInferenceSession} can itself be called anytime; it doesn't
- * trigger a download until the returned function is called.
- *
- * @param modelName The name of the model to download.
- * @param modelByteSize The size in bytes that we expect the model to have. If
- * the size of the downloaded model does not match the expected size, then we
- * will redownload it.
- *
- * @returns A function. Calling that function returns a promise that resolves
- * to an ONNX inference session.
- */
-const makeCachedInferenceSession = (
-    modelName: string,
-    modelByteSize: number,
-) => {
-    let session: Promise<ort.InferenceSession> | undefined;
-
-    const download = () =>
-        modelPathDownloadingIfNeeded(modelName, modelByteSize);
-
-    const createSession = (modelPath: string) =>
-        createInferenceSession(modelPath);
-
-    const cachedInferenceSession = () => {
-        if (!session) session = download().then(createSession);
-        return session;
-    };
-
-    return cachedInferenceSession;
-};
-
-let activeFaceDetectionModelDownload: Promise<string> | undefined;
-
-const faceDetectionModelPathDownloadingIfNeeded = async () => {
-    try {
-        if (activeFaceDetectionModelDownload) {
-            log.info("Waiting for face detection model download to finish");
-            await activeFaceDetectionModelDownload;
-        } else {
-            activeFaceDetectionModelDownload = modelPathDownloadingIfNeeded(
-                faceDetectionModelName,
-                faceDetectionModelByteSize,
-            );
-            return await activeFaceDetectionModelDownload;
-        }
-    } finally {
-        activeFaceDetectionModelDownload = undefined;
-    }
-};
-
-let _faceDetectionSession: Promise<ort.InferenceSession> | undefined;
-
-const faceDetectionSession = async () => {
-    if (!_faceDetectionSession) {
-        _faceDetectionSession =
-            faceDetectionModelPathDownloadingIfNeeded().then((modelPath) =>
-                createInferenceSession(modelPath),
-            );
-    }
-    return _faceDetectionSession;
-};
-
-let activeFaceEmbeddingModelDownload: Promise<string> | undefined;
-
-const faceEmbeddingModelPathDownloadingIfNeeded = async () => {
-    try {
-        if (activeFaceEmbeddingModelDownload) {
-            log.info("Waiting for face embedding model download to finish");
-            await activeFaceEmbeddingModelDownload;
-        } else {
-            activeFaceEmbeddingModelDownload = modelPathDownloadingIfNeeded(
-                faceEmbeddingModelName,
-                faceEmbeddingModelByteSize,
-            );
-            return await activeFaceEmbeddingModelDownload;
-        }
-    } finally {
-        activeFaceEmbeddingModelDownload = undefined;
-    }
-};
-
-let _faceEmbeddingSession: Promise<ort.InferenceSession> | undefined;
-
-const faceEmbeddingSession = async () => {
-    if (!_faceEmbeddingSession) {
-        _faceEmbeddingSession =
-            faceEmbeddingModelPathDownloadingIfNeeded().then((modelPath) =>
-                createInferenceSession(modelPath),
-            );
-    }
-    return _faceEmbeddingSession;
-};
+const cachedFaceEmbeddingSession = makeCachedInferenceSession(
+    "mobilefacenet_opset15.onnx",
+    5286998 /* 5 MB */,
+);
 
 export const detectFaces = async (input: Float32Array) => {
-    const session = await faceDetectionSession();
+    const session = await cachedFaceDetectionSession();
     const t = Date.now();
     const feeds = {
         input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
@@ -141,11 +40,11 @@ export const faceEmbedding = async (input: Float32Array) => {
     const n = Math.round(input.length / (z * z * 3));
     const inputTensor = new ort.Tensor("float32", input, [n, z, z, 3]);
 
-    const session = await faceEmbeddingSession();
+    const session = await cachedFaceEmbeddingSession();
     const t = Date.now();
     const feeds = { img_inputs: inputTensor };
     const results = await session.run(feeds);
     log.debug(() => `onnx/yolo face embedding took ${Date.now() - t} ms`);
-    // TODO: What's with this type? It works in practice, but double check.
-    return (results.embeddings as unknown as any)["cpuData"]; // as Float32Array;
+    /* Need these model-specific casts to extract and type the result */
+    return (results.embeddings as unknown as any)["cpuData"] as Float32Array;
 };
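
Aside: the "cpuData" lookup above reaches into an internal field of the
tensor returned by onnxruntime-node; the documented accessor is the
tensor's data property, which yields a typed array matching the tensor's
dtype. A hedged sketch of the same extraction through the public API,
assuming the model's "embeddings" output is a float32 tensor (a
suggestion, not what this commit does):

    import * as ort from "onnxruntime-node";

    /* Sketch: read the "embeddings" output via the public Tensor.data
     * accessor instead of the internal "cpuData" field. For a float32
     * tensor, data is a Float32Array. */
    const embeddingsOf = (results: {
        [name: string]: ort.Tensor;
    }): Float32Array => results.embeddings.data as Float32Array;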

desktop/src/main/services/ml.ts (+44 -1)

@@ -18,6 +18,49 @@ import * as ort from "onnxruntime-node";
 import log from "../log";
 import { writeStream } from "../stream";
 
+/**
+ * Return a function that can be used to trigger a download of the specified
+ * model, and the creation of an ONNX inference session initialized using it.
+ *
+ * Multiple parallel calls to the returned function are fine: it ensures that
+ * the model will be downloaded and the session created from it only once.
+ * Meanwhile, all pending calls will just await the same promise.
+ *
+ * And once the promise resolves, the created ONNX inference session will be
+ * cached, so subsequent calls to the returned function will just reuse the
+ * same session.
+ *
+ * {@link makeCachedInferenceSession} can itself be called anytime; it doesn't
+ * trigger a download until the returned function is called.
+ *
+ * @param modelName The name of the model to download.
+ * @param modelByteSize The size in bytes that we expect the model to have. If
+ * the size of the downloaded model does not match the expected size, then we
+ * will redownload it.
+ *
+ * @returns A function. Calling that function returns a promise that resolves
+ * to an ONNX inference session.
+ */
+export const makeCachedInferenceSession = (
+    modelName: string,
+    modelByteSize: number,
+) => {
+    let session: Promise<ort.InferenceSession> | undefined;
+
+    const download = () =>
+        modelPathDownloadingIfNeeded(modelName, modelByteSize);
+
+    const createSession = (modelPath: string) =>
+        createInferenceSession(modelPath);
+
+    const cachedInferenceSession = () => {
+        if (!session) session = download().then(createSession);
+        return session;
+    };
+
+    return cachedInferenceSession;
+};
+
 /**
  * Download the model named {@link modelName} if we don't already have it.
  *
@@ -26,7 +69,7 @@ import { writeStream } from "../stream";
  *
  * @returns the path to the model on the local machine.
  */
-export const modelPathDownloadingIfNeeded = async (
+const modelPathDownloadingIfNeeded = async (
     modelName: string,
     expectedByteSize: number,
 ) => {
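
To make the caching contract described in the doc comment concrete:
because makeCachedInferenceSession memoizes the promise itself rather
than the resolved session, concurrent callers coalesce onto a single
download and a single session creation. A minimal usage sketch; the
model name and byte size below are hypothetical placeholders, not
assets that exist in the repository:

    // Usage sketch for makeCachedInferenceSession (placeholder model).
    const cachedExampleSession = makeCachedInferenceSession(
        "example-model.onnx" /* hypothetical */,
        1234567 /* expected byte size */,
    );

    const demo = async () => {
        // Both calls race before the download finishes, yet they await
        // the same memoized promise, so the model is fetched and the
        // session created at most once.
        const [s1, s2] = await Promise.all([
            cachedExampleSession(),
            cachedExampleSession(),
        ]);
        console.assert(s1 === s2); // one cached ort.InferenceSession
    };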