Browse Source

Handle first search on app start

Manav Rathi 1 year ago
parent
commit
10934b08a8

+ 8 - 6
desktop/src/main/services/ml-clip.ts

@@ -11,7 +11,7 @@ import * as ort from "onnxruntime-node";
 import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
 import log from "../log";
 import { writeStream } from "../stream";
-import { ensure } from "../utils/common";
+import { ensure, wait } from "../utils/common";
 import { deleteTempFile, makeTempFilePath } from "../utils/temp";
 import { makeCachedInferenceSession } from "./ml";
 
@@ -141,20 +141,22 @@ const getTokenizer = () => {
 };
 
 export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
-    const sessionOrStatus = await Promise.race([
+    const sessionOrSkip = await Promise.race([
         cachedCLIPTextSession(),
-        "downloading-model",
+        // Wait for a tick to get the session promise to resolved the first time
+        // this code runs on each app start (and the model has been downloaded).
+        wait(0).then(() => 1),
     ]);
 
-    // Don't wait for the download to complete
-    if (typeof sessionOrStatus == "string") {
+    // Don't wait for the download to complete.
+    if (typeof sessionOrSkip == "number") {
         log.info(
             "Ignoring CLIP text embedding request because model download is pending",
         );
         return undefined;
     }
 
-    const session = sessionOrStatus;
+    const session = sessionOrSkip;
     const t1 = Date.now();
     const tokenizer = getTokenizer();
     const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));

+ 9 - 0
desktop/src/main/utils/common.ts

@@ -13,3 +13,12 @@ export const ensure = <T>(v: T | null | undefined): T => {
     if (v === undefined) throw new Error("Required value was not found");
     return v;
 };
+
+/**
+ * Wait for {@link ms} milliseconds
+ *
+ * This function is a promisified `setTimeout`. It returns a promise that
+ * resolves after {@link ms} milliseconds.
+ */
+export const wait = (ms: number) =>
+    new Promise((resolve) => setTimeout(resolve, ms));

+ 4 - 0
web/packages/utils/promise.ts

@@ -10,6 +10,10 @@ export const wait = (ms: number) =>
 /**
  * Await the given {@link promise} for {@link timeoutMS} milliseconds. If it
  * does not resolve within {@link timeoutMS}, then reject with a timeout error.
+ *
+ * Note that this does not abort {@link promise} itself - it will still get
+ * resolved to completion, just its result will be ignored if it gets resolved
+ * after we've already timed out.
  */
 export const withTimeout = async <T>(promise: Promise<T>, ms: number) => {
     let timeoutId: ReturnType<typeof setTimeout>;