Manav Rathi 1 rok temu
rodzic
commit
46a53d5fdf
1 zmienionych plików z 43 dodań i 0 usunięć
  1. 43 0
      desktop/src/main/services/ml-face.ts

+ 43 - 0
desktop/src/main/services/ml-face.ts

@@ -16,6 +16,49 @@ const faceDetectionModelByteSize = 30762872; // 29.3 MB
 const faceEmbeddingModelName = "mobilefacenet_opset15.onnx";
 const faceEmbeddingModelByteSize = 5286998; // 5 MB
 
+/**
+ * Return a function that can be used to trigger a download of the specified
+ * model, and the creating of an ONNX inference session initialized using it.
+ *
+ * Multiple parallel calls to the returned function are fine, it ensures that
+ * the the model will be downloaded and the session created using it only once.
+ * All pending calls to it meanwhile will just await on the same promise.
+ *
+ * And once the promise is resolved, the create ONNX inference session will be
+ * cached, so subsequent calls to the returned function will just reuse the same
+ * session.
+ *
+ * {@link makeCachedInferenceSession} can itself be called anytime, it doesn't
+ * actively trigger a download until the returned function is called.
+ *
+ * @param modelName The name of the model to download.
+ * @param modelByteSize The size in bytes that we expect the model to have. If
+ * the size of the downloaded model does not match the expected size, then we
+ * will redownload it.
+ *
+ * @returns a function. calling that function returns a promise to an ONNX
+ * session.
+ */
+const makeCachedInferenceSession = (
+    modelName: string,
+    modelByteSize: number,
+) => {
+    let session: Promise<ort.InferenceSession> | undefined;
+
+    const download = () =>
+        modelPathDownloadingIfNeeded(modelName, modelByteSize);
+
+    const createSession = (modelPath: string) =>
+        createInferenceSession(modelPath);
+
+    const cachedInferenceSession = () => {
+        if (!session) session = download().then(createSession);
+        return session;
+    };
+
+    return cachedInferenceSession;
+};
+
 let activeFaceDetectionModelDownload: Promise<string> | undefined;
 
 const faceDetectionModelPathDownloadingIfNeeded = async () => {