diff --git a/desktop/src/main/services/ml-clip.ts b/desktop/src/main/services/ml-clip.ts index dc9e67f1b..cea1574e0 100644 --- a/desktop/src/main/services/ml-clip.ts +++ b/desktop/src/main/services/ml-clip.ts @@ -11,7 +11,7 @@ import * as ort from "onnxruntime-node"; import Tokenizer from "../../thirdparty/clip-bpe-ts/mod"; import log from "../log"; import { writeStream } from "../stream"; -import { ensure } from "../utils/common"; +import { ensure, wait } from "../utils/common"; import { deleteTempFile, makeTempFilePath } from "../utils/temp"; import { makeCachedInferenceSession } from "./ml"; @@ -141,20 +141,22 @@ const getTokenizer = () => { }; export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => { - const sessionOrStatus = await Promise.race([ + const sessionOrSkip = await Promise.race([ cachedCLIPTextSession(), - "downloading-model", + // Wait for a tick to get the session promise to resolved the first time + // this code runs on each app start (and the model has been downloaded). + wait(0).then(() => 1), ]); - // Don't wait for the download to complete - if (typeof sessionOrStatus == "string") { + // Don't wait for the download to complete. + if (typeof sessionOrSkip == "number") { log.info( "Ignoring CLIP text embedding request because model download is pending", ); return undefined; } - const session = sessionOrStatus; + const session = sessionOrSkip; const t1 = Date.now(); const tokenizer = getTokenizer(); const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text)); diff --git a/desktop/src/main/utils/common.ts b/desktop/src/main/utils/common.ts index 5ed46aa8a..929281d74 100644 --- a/desktop/src/main/utils/common.ts +++ b/desktop/src/main/utils/common.ts @@ -13,3 +13,12 @@ export const ensure = (v: T | null | undefined): T => { if (v === undefined) throw new Error("Required value was not found"); return v; }; + +/** + * Wait for {@link ms} milliseconds + * + * This function is a promisified `setTimeout`. It returns a promise that + * resolves after {@link ms} milliseconds. + */ +export const wait = (ms: number) => + new Promise((resolve) => setTimeout(resolve, ms)); diff --git a/web/packages/utils/promise.ts b/web/packages/utils/promise.ts index 4cb7648fd..34f821b6d 100644 --- a/web/packages/utils/promise.ts +++ b/web/packages/utils/promise.ts @@ -10,6 +10,10 @@ export const wait = (ms: number) => /** * Await the given {@link promise} for {@link timeoutMS} milliseconds. If it * does not resolve within {@link timeoutMS}, then reject with a timeout error. + * + * Note that this does not abort {@link promise} itself - it will still get + * resolved to completion, just its result will be ignored if it gets resolved + * after we've already timed out. */ export const withTimeout = async (promise: Promise, ms: number) => { let timeoutId: ReturnType;