Manav Rathi, 1 year ago
Parent commit: 2b6047a979

+ 248 - 0
desktop/src/main/services/ml-clip.ts

@@ -0,0 +1,248 @@
+/**
+ * @file Compute CLIP embeddings for images and text.
+ *
+ * The embeddings are computed using ONNX runtime, with CLIP as the model.
+ *
+ * @see `web/apps/photos/src/services/clip-service.ts` for more details.
+ */
+import { existsSync } from "fs";
+import jpeg from "jpeg-js";
+import fs from "node:fs/promises";
+import * as ort from "onnxruntime-node";
+import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
+import { CustomErrors } from "../../types/ipc";
+import { writeStream } from "../fs";
+import log from "../log";
+import { generateTempFilePath } from "../temp";
+import { deleteTempFile } from "./ffmpeg";
+import {
+    createInferenceSession,
+    downloadModel,
+    modelPathDownloadingIfNeeded,
+    modelSavePath,
+} from "./ml";
+
+const textModelName = "clip-text-vit-32-uint8.onnx";
+const textModelByteSize = 64173509; // 61.2 MB
+
+const imageModelName = "clip-image-vit-32-float32.onnx";
+const imageModelByteSize = 351468764; // 335.2 MB
+
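+/**
+ * An in-flight download of the CLIP image model, if any.
+ *
+ * This lets concurrent requests reuse a download that is already in progress
+ * instead of starting another one.
+ */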
+let activeImageModelDownload: Promise<string> | undefined;
+
+const imageModelPathDownloadingIfNeeded = async () => {
+    try {
+        if (activeImageModelDownload) {
+            log.info("Waiting for CLIP image model download to finish");
+            return await activeImageModelDownload;
+        } else {
+            activeImageModelDownload = modelPathDownloadingIfNeeded(
+                imageModelName,
+                imageModelByteSize,
+            );
+            return await activeImageModelDownload;
+        }
+    } finally {
+        activeImageModelDownload = undefined;
+    }
+};
+
+let textModelDownloadInProgress = false;
+
+/* TODO(MR): Use the generic method. Then we can remove the exports of the
+   internal helper functions that we use here. */
+const textModelPathDownloadingIfNeeded = async () => {
+    if (textModelDownloadInProgress)
+        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
+
+    const modelPath = modelSavePath(textModelName);
+    if (!existsSync(modelPath)) {
+        log.info("CLIP text model not found, downloading");
+        textModelDownloadInProgress = true;
+        downloadModel(modelPath, textModelName)
+            .catch((e) => {
+                // log but otherwise ignore
+                log.error("CLIP text model download failed", e);
+            })
+            .finally(() => {
+                textModelDownloadInProgress = false;
+            });
+        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
+    } else {
+        const localFileSize = (await fs.stat(modelPath)).size;
+        if (localFileSize !== textModelByteSize) {
+            log.error(
+                `CLIP text model size ${localFileSize} does not match the expected size, downloading again`,
+            );
+            textModelDownloadInProgress = true;
+            downloadModel(modelPath, textModelName)
+                .catch((e) => {
+                    // log but otherwise ignore
+                    log.error("CLIP text model download failed", e);
+                })
+                .finally(() => {
+                    textModelDownloadInProgress = false;
+                });
+            throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
+        }
+    }
+
+    return modelPath;
+};
+
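+/** Cached promise for the CLIP image model ONNX inference session. */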
+let imageSessionPromise: Promise<ort.InferenceSession> | undefined;
+
+const onnxImageSession = async () => {
+    if (!imageSessionPromise) {
+        imageSessionPromise = (async () => {
+            const modelPath = await imageModelPathDownloadingIfNeeded();
+            return createInferenceSession(modelPath);
+        })();
+    }
+    return imageSessionPromise;
+};
+
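+/** Cached CLIP text model ONNX inference session. */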
+let _textSession: ort.InferenceSession | undefined;
+
+const onnxTextSession = async () => {
+    if (!_textSession) {
+        const modelPath = await textModelPathDownloadingIfNeeded();
+        _textSession = await createInferenceSession(modelPath);
+    }
+    return _textSession;
+};
+
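+/**
+ * Compute the CLIP embedding of the given image.
+ *
+ * @param jpegImageData The raw bytes of the image, encoded as a JPEG.
+ *
+ * @returns A normalized (unit length) embedding.
+ */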
+export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
+    const tempFilePath = await generateTempFilePath("");
+    const imageStream = new Response(jpegImageData.buffer).body;
+    await writeStream(tempFilePath, imageStream);
+    try {
+        return await clipImageEmbedding_(tempFilePath);
+    } finally {
+        await deleteTempFile(tempFilePath);
+    }
+};
+
+const clipImageEmbedding_ = async (jpegFilePath: string) => {
+    const imageSession = await onnxImageSession();
+    const t1 = Date.now();
+    const rgbData = await getRGBData(jpegFilePath);
+    const feeds = {
+        input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]),
+    };
+    const t2 = Date.now();
+    const results = await imageSession.run(feeds);
+    log.debug(
+        () =>
+            `CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
+    );
+    const imageEmbedding = results["output"].data; // Float32Array
+    return normalizeEmbedding(imageEmbedding);
+};
+
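+/**
+ * Decode the JPEG at {@link jpegFilePath} and return its pixel data in the
+ * layout expected by the CLIP image model: a planar (channel by channel)
+ * 3 x 224 x 224 array, bilinearly resized and normalized with the CLIP mean
+ * and standard deviation.
+ */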
+const getRGBData = async (jpegFilePath: string) => {
+    const jpegData = await fs.readFile(jpegFilePath);
+    const rawImageData = jpeg.decode(jpegData, {
+        useTArray: true,
+        formatAsRGBA: false,
+    });
+
+    const nx: number = rawImageData.width;
+    const ny: number = rawImageData.height;
+    const inputImage: Uint8Array = rawImageData.data;
+
+    const nx2: number = 224;
+    const ny2: number = 224;
+    const totalSize: number = 3 * nx2 * ny2;
+
+    const result: number[] = Array(totalSize).fill(0);
+    const scale: number = Math.max(nx, ny) / 224;
+
+    const nx3: number = Math.round(nx / scale);
+    const ny3: number = Math.round(ny / scale);
+
+    const mean: number[] = [0.48145466, 0.4578275, 0.40821073];
+    const std: number[] = [0.26862954, 0.26130258, 0.27577711];
+
+    for (let y = 0; y < ny3; y++) {
+        for (let x = 0; x < nx3; x++) {
+            for (let c = 0; c < 3; c++) {
+                // Bilinear interpolation
+                const sx: number = (x + 0.5) * scale - 0.5;
+                const sy: number = (y + 0.5) * scale - 0.5;
+
+                const x0: number = Math.max(0, Math.floor(sx));
+                const y0: number = Math.max(0, Math.floor(sy));
+
+                const x1: number = Math.min(x0 + 1, nx - 1);
+                const y1: number = Math.min(y0 + 1, ny - 1);
+
+                const dx: number = sx - x0;
+                const dy: number = sy - y0;
+
+                const j00: number = 3 * (y0 * nx + x0) + c;
+                const j01: number = 3 * (y0 * nx + x1) + c;
+                const j10: number = 3 * (y1 * nx + x0) + c;
+                const j11: number = 3 * (y1 * nx + x1) + c;
+
+                const v00: number = inputImage[j00];
+                const v01: number = inputImage[j01];
+                const v10: number = inputImage[j10];
+                const v11: number = inputImage[j11];
+
+                const v0: number = v00 * (1 - dx) + v01 * dx;
+                const v1: number = v10 * (1 - dx) + v11 * dx;
+
+                const v: number = v0 * (1 - dy) + v1 * dy;
+
+                const v2: number = Math.min(Math.max(Math.round(v), 0), 255);
+
+                // The model expects planar (CHW) data, so write each
+                // channel's values into its own 224 * 224 plane.
+                const i: number = y * nx3 + x + c * 224 * 224;
+
+                result[i] = (v2 / 255 - mean[c]) / std[c];
+            }
+        }
+    }
+
+    return result;
+};
+
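+/**
+ * Normalize the given embedding in place to unit length, so that the
+ * similarity of two embeddings can be computed as their dot product.
+ */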
+const normalizeEmbedding = (embedding: Float32Array) => {
+    let normalization = 0;
+    for (let index = 0; index < embedding.length; index++) {
+        normalization += embedding[index] * embedding[index];
+    }
+    const sqrtNormalization = Math.sqrt(normalization);
+    for (let index = 0; index < embedding.length; index++) {
+        embedding[index] = embedding[index] / sqrtNormalization;
+    }
+    return embedding;
+};
+
+let _tokenizer: Tokenizer | undefined;
+const getTokenizer = () => {
+    if (!_tokenizer) {
+        _tokenizer = new Tokenizer();
+    }
+    return _tokenizer;
+};
+
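+/**
+ * Compute the CLIP embedding of the given text.
+ *
+ * @returns A normalized (unit length) embedding.
+ */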
+export const clipTextEmbedding = async (text: string) => {
+    const textSession = await onnxTextSession();
+    const t1 = Date.now();
+    const tokenizer = getTokenizer();
+    const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
+    const feeds = {
+        input: new ort.Tensor("int32", tokenizedText, [1, 77]),
+    };
+    const t2 = Date.now();
+    const results = await textSession.run(feeds);
+    log.debug(
+        () =>
+            `CLIP text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
+    );
+    const textEmbedding = results["output"].data;
+    return normalizeEmbedding(textEmbedding);
+};
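Since both embeddings are normalized to unit length, comparing an image against a text query reduces to a dot product. A minimal sketch of how a caller might use these exports (hypothetical usage, not part of this commit; `jpegBytes` stands in for the raw contents of a JPEG file):

```ts
import { clipImageEmbedding, clipTextEmbedding } from "./ml-clip";

declare const jpegBytes: Uint8Array; // assumed to hold a JPEG file's bytes

// Cosine similarity of two unit length vectors is just their dot product.
const cosineSimilarity = (a: Float32Array, b: Float32Array) => {
    let score = 0;
    for (let i = 0; i < a.length; i++) score += a[i] * b[i];
    return score;
};

const imageEmbedding = await clipImageEmbedding(jpegBytes);
const textEmbedding = await clipTextEmbedding("a photo of a dog on a beach");
console.log("similarity:", cosineSimilarity(imageEmbedding, textEmbedding));
```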

+ 77 - 0
desktop/src/main/services/ml-face.ts

@@ -0,0 +1,77 @@
+/**
+ * @file Various face recognition related tasks.
+ *
+ * - Face detection with the YOLO model.
+ * - Face embedding with the mobilefacenet model.
+ *
+ * The runtime used is ONNX.
+ */
+import * as ort from "onnxruntime-node";
+import log from "../log";
+import { createInferenceSession, modelPathDownloadingIfNeeded } from "./ml";
+
+const faceDetectionModelName = "yolov5s_face_640_640_dynamic.onnx";
+const faceDetectionModelByteSize = 30762872; // 29.3 MB
+
+const faceEmbeddingModelName = "mobilefacenet_opset15.onnx";
+const faceEmbeddingModelByteSize = 5286998; // 5 MB
+
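+/**
+ * An in-flight download of the face detection model, if any, shared across
+ * concurrent requests.
+ */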
+let activeFaceDetectionModelDownload: Promise<string> | undefined;
+
+const faceDetectionModelPathDownloadingIfNeeded = async () => {
+    try {
+        if (activeFaceDetectionModelDownload) {
+            log.info("Waiting for face detection model download to finish");
+            return await activeFaceDetectionModelDownload;
+        } else {
+            activeFaceDetectionModelDownload = modelPathDownloadingIfNeeded(
+                faceDetectionModelName,
+                faceDetectionModelByteSize,
+            );
+            return await activeFaceDetectionModelDownload;
+        }
+    } finally {
+        activeFaceDetectionModelDownload = undefined;
+    }
+};
+
+let _faceDetectionSession: Promise<ort.InferenceSession> | undefined;
+
+const faceDetectionSession = async () => {
+    if (!_faceDetectionSession) {
+        _faceDetectionSession =
+            faceDetectionModelPathDownloadingIfNeeded().then((modelPath) =>
+                createInferenceSession(modelPath),
+            );
+    }
+    return _faceDetectionSession;
+};
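The commit also declares `faceEmbeddingModelName` and its byte size, but does not yet create a session for it. A follow-up would presumably mirror the detection session above; a minimal sketch under that assumption (not part of this commit):

```ts
let _faceEmbeddingSession: Promise<ort.InferenceSession> | undefined;

// Hypothetical: lazily create (and cache) an inference session for the
// mobilefacenet embedding model, reusing the shared helpers from ml.ts.
const faceEmbeddingSession = async () => {
    if (!_faceEmbeddingSession) {
        _faceEmbeddingSession = modelPathDownloadingIfNeeded(
            faceEmbeddingModelName,
            faceEmbeddingModelByteSize,
        ).then((modelPath) => createInferenceSession(modelPath));
    }
    return _faceEmbeddingSession;
};
```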

+ 79 - 0
desktop/src/main/services/ml.ts

@@ -0,0 +1,79 @@
+/**
+ * @file AI/ML related functionality.
+ *
+ * @see also `ml-clip.ts`, `ml-face.ts`.
+ *
+ * The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models
+ * for various tasks are not shipped with the app but are downloaded on demand.
+ *
+ * The primary reason for doing these tasks in the Node.js layer is so that we
+ * can use the binary ONNX runtime, which is 10-20x faster than the WASM-based
+ * web one.
+ */
+import { app, net } from "electron/main";
+import { existsSync } from "fs";
+import fs from "node:fs/promises";
+import path from "node:path";
+import * as ort from "onnxruntime-node";
+import { writeStream } from "../fs";
+import log from "../log";
+
+/**
+ * Download the model named {@link modelName} if we don't already have it.
+ *
+ * Also verify that the size of the model on disk matches
+ * {@link expectedByteSize} (if not, redownload it).
+ *
+ * @returns the path to the model on the local machine.
+ */
+export const modelPathDownloadingIfNeeded = async (
+    modelName: string,
+    expectedByteSize: number,
+) => {
+    const modelPath = modelSavePath(modelName);
+
+    if (!existsSync(modelPath)) {
+        log.info(`Model ${modelName} not found, downloading`);
+        await downloadModel(modelPath, modelName);
+    } else {
+        const size = (await fs.stat(modelPath)).size;
+        if (size !== expectedByteSize) {
+            log.error(
+                `The size ${size} of model ${modelName} does not match the expected size, downloading again`,
+            );
+            await downloadModel(modelPath, modelName);
+        }
+    }
+
+    return modelPath;
+};
+
+/** Return the path where the given {@link modelName} is meant to be saved */
+export const modelSavePath = (modelName: string) =>
+    path.join(app.getPath("userData"), "models", modelName);
+
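+/**
+ * Download the model with the given {@link name} from
+ * `https://models.ente.io` and save it to {@link saveLocation}.
+ */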
+export const downloadModel = async (saveLocation: string, name: string) => {
+    // `mkdir -p` the directory where we want to save the model.
+    const saveDir = path.dirname(saveLocation);
+    await fs.mkdir(saveDir, { recursive: true });
+    // Download
+    log.info(`Downloading ML model ${name}`);
+    const url = `https://models.ente.io/${name}`;
+    const res = await net.fetch(url);
+    if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
+    // Save
+    await writeStream(saveLocation, res.body);
+    log.info(`Downloaded ML model ${name}`);
+};
+
+/**
+ * Create an ONNX {@link InferenceSession} with some sensible defaults.
+ */
+export const createInferenceSession = async (modelPath: string) => {
+    return await ort.InferenceSession.create(modelPath, {
+        // Restrict the number of threads to 1
+        intraOpNumThreads: 1,
+        // Be more conservative with RAM usage
+        enableCpuMemArena: false,
+    });
+};
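Putting the pieces together, a per-model wrapper (as in `ml-clip.ts` and `ml-face.ts`) downloads the model on first use, creates a session, and runs inference on it. An illustrative sketch; the model name, byte size, input name, and shape below are placeholders:

```ts
import * as ort from "onnxruntime-node";
import { createInferenceSession, modelPathDownloadingIfNeeded } from "./ml";

// Illustration only: "example-model.onnx" and its size are not real models
// shipped by the app.
const session = await createInferenceSession(
    await modelPathDownloadingIfNeeded("example-model.onnx", 12345678),
);
const feeds = {
    input: new ort.Tensor("float32", new Float32Array(10), [1, 10]),
};
const results = await session.run(feeds);
console.log(results["output"].data);
```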