@@ -5,86 +5,21 @@
  *
  * @see `web/apps/photos/src/services/clip-service.ts` for more details.
  */
-import { existsSync } from "fs";
 import jpeg from "jpeg-js";
 import fs from "node:fs/promises";
 import * as ort from "onnxruntime-node";
 import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
-import { CustomErrors } from "../../types/ipc";
 import log from "../log";
 import { writeStream } from "../stream";
 import { generateTempFilePath } from "../temp";
 import { deleteTempFile } from "./ffmpeg";
-import {
-    createInferenceSession,
-    downloadModel,
-    makeCachedInferenceSession,
-    modelSavePath,
-} from "./ml";
+import { makeCachedInferenceSession } from "./ml";

 const cachedCLIPImageSession = makeCachedInferenceSession(
     "clip-image-vit-32-float32.onnx",
     351468764 /* 335.2 MB */,
 );

-const cachedCLIPTextSession = makeCachedInferenceSession(
-    "clip-text-vit-32-uint8.onnx",
-    64173509 /* 61.2 MB */,
-);
-
-let textModelDownloadInProgress = false;
-
-/* TODO(MR): use the generic method. Then we can remove the exports for the
-   internal details functions that we use here */
-const textModelPathDownloadingIfNeeded = async () => {
-    if (textModelDownloadInProgress)
-        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
-
-    const modelPath = modelSavePath(textModelName);
-    if (!existsSync(modelPath)) {
-        log.info("CLIP text model not found, downloading");
-        textModelDownloadInProgress = true;
-        downloadModel(modelPath, textModelName)
-            .catch((e) => {
-                // log but otherwise ignore
-                log.error("CLIP text model download failed", e);
-            })
-            .finally(() => {
-                textModelDownloadInProgress = false;
-            });
-        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
-    } else {
-        const localFileSize = (await fs.stat(modelPath)).size;
-        if (localFileSize !== textModelByteSize) {
-            log.error(
-                `CLIP text model size ${localFileSize} does not match the expected size, downloading again`,
-            );
-            textModelDownloadInProgress = true;
-            downloadModel(modelPath, textModelName)
-                .catch((e) => {
-                    // log but otherwise ignore
-                    log.error("CLIP text model download failed", e);
-                })
-                .finally(() => {
-                    textModelDownloadInProgress = false;
-                });
-            throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
-        }
-    }
-
-    return modelPath;
-};
-
-let _textSession: any = null;
-
-const onnxTextSession = async () => {
-    if (!_textSession) {
-        const modelPath = await textModelPathDownloadingIfNeeded();
-        _textSession = await createInferenceSession(modelPath);
-    }
-    return _textSession;
-};
-
 export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
     const tempFilePath = await generateTempFilePath("");
     const imageStream = new Response(jpegImageData.buffer).body;
@@ -195,6 +130,13 @@ const normalizeEmbedding = (embedding: Float32Array) => {
     return embedding;
 };

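+// A cached promise for the CLIP text model's ONNX inference session; the
+// model is downloaded on first use if it is not already present locally.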
+const cachedCLIPTextSession = makeCachedInferenceSession(
+    "clip-text-vit-32-uint8.onnx",
+    64173509 /* 61.2 MB */,
+);
+
 let _tokenizer: Tokenizer = null;
 const getTokenizer = () => {
     if (!_tokenizer) {
@@ -203,14 +145,24 @@ const getTokenizer = () => {
     return _tokenizer;
 };

-export const clipTextEmbedding = async (text: string) => {
-    const session = await Promise.race([
+export const clipTextEmbeddingIfAvailable = async (text: string) => {
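+    // If the cached session promise has already resolved, it wins this race
+    // with the plain string value; otherwise the sentinel resolves first.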
+    const sessionOrStatus = await Promise.race([
         cachedCLIPTextSession(),
-        new Promise<"downloading-model">((resolve) =>
-            setTimeout(() => resolve("downloading-model"), 100),
-        ),
+        "downloading-model",
     ]);
-    await onnxTextSession();
+
+    // Don't wait for the download to complete
+    if (typeof sessionOrStatus == "string") {
+        log.info(
+            "Ignoring CLIP text embedding request because model download is pending",
+        );
+        return undefined;
+    }
+
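+    // The early return above narrows sessionOrStatus to the inference session.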
+    const session = sessionOrStatus;
     const t1 = Date.now();
     const tokenizer = getTokenizer();
     const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
@@ -223,6 +175,8 @@ export const clipTextEmbedding = async (text: string) => {
         () =>
             `onnx/clip text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
     );
-    const textEmbedding = results["output"].data;
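+    // onnxruntime tensor data is typed as a union of typed arrays; the CLIP
+    // text output is float32, so narrow it before normalizing.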
+    const textEmbedding = results["output"].data as Float32Array;
     return normalizeEmbedding(textEmbedding);
 };