|
@@ -11,7 +11,7 @@ import * as ort from "onnxruntime-node";
|
|
|
import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
|
|
|
import log from "../log";
|
|
|
import { writeStream } from "../stream";
|
|
|
-import { ensure } from "../utils/common";
|
|
|
+import { ensure, wait } from "../utils/common";
|
|
|
import { deleteTempFile, makeTempFilePath } from "../utils/temp";
|
|
|
import { makeCachedInferenceSession } from "./ml";
|
|
|
|
|
@@ -20,7 +20,7 @@ const cachedCLIPImageSession = makeCachedInferenceSession(
|
|
|
351468764 /* 335.2 MB */,
|
|
|
);
|
|
|
|
|
|
-export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
|
|
|
+export const computeCLIPImageEmbedding = async (jpegImageData: Uint8Array) => {
|
|
|
const tempFilePath = await makeTempFilePath();
|
|
|
const imageStream = new Response(jpegImageData.buffer).body;
|
|
|
await writeStream(tempFilePath, ensure(imageStream));
|
|
@@ -42,7 +42,7 @@ const clipImageEmbedding_ = async (jpegFilePath: string) => {
|
|
|
const results = await session.run(feeds);
|
|
|
log.debug(
|
|
|
() =>
|
|
|
- `onnx/clip image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
|
|
|
+ `ONNX/CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
|
|
|
);
|
|
|
/* Need these model specific casts to type the result */
|
|
|
const imageEmbedding = ensure(results.output).data as Float32Array;
|
|
@@ -140,21 +140,23 @@ const getTokenizer = () => {
|
|
|
return _tokenizer;
|
|
|
};
|
|
|
|
|
|
-export const clipTextEmbeddingIfAvailable = async (text: string) => {
|
|
|
- const sessionOrStatus = await Promise.race([
|
|
|
+export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
|
|
|
+ const sessionOrSkip = await Promise.race([
|
|
|
cachedCLIPTextSession(),
|
|
|
- "downloading-model",
|
|
|
+ // Wait for a tick to get the session promise to resolved the first time
|
|
|
+ // this code runs on each app start (and the model has been downloaded).
|
|
|
+ wait(0).then(() => 1),
|
|
|
]);
|
|
|
|
|
|
- // Don't wait for the download to complete
|
|
|
- if (typeof sessionOrStatus == "string") {
|
|
|
+ // Don't wait for the download to complete.
|
|
|
+ if (typeof sessionOrSkip == "number") {
|
|
|
log.info(
|
|
|
"Ignoring CLIP text embedding request because model download is pending",
|
|
|
);
|
|
|
return undefined;
|
|
|
}
|
|
|
|
|
|
- const session = sessionOrStatus;
|
|
|
+ const session = sessionOrSkip;
|
|
|
const t1 = Date.now();
|
|
|
const tokenizer = getTokenizer();
|
|
|
const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
|
|
@@ -165,7 +167,7 @@ export const clipTextEmbeddingIfAvailable = async (text: string) => {
|
|
|
const results = await session.run(feeds);
|
|
|
log.debug(
|
|
|
() =>
|
|
|
- `onnx/clip text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
|
|
|
+ `ONNX/CLIP text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
|
|
|
);
|
|
|
const textEmbedding = ensure(results.output).data as Float32Array;
|
|
|
return normalizeEmbedding(textEmbedding);
|