
[desktop] Remove GGML (#1394)

Manav Rathi 1 year ago
parent commit eebb90fb40

BIN
desktop/build/ggmlclip-linux


BIN
desktop/build/ggmlclip-mac


BIN
desktop/build/ggmlclip-windows.exe


BIN
desktop/build/msvcp140d.dll


BIN
desktop/build/ucrtbased.dll


BIN
desktop/build/vcruntime140_1d.dll


BIN
desktop/build/vcruntime140d.dll


+ 5 - 5
desktop/docs/dependencies.md

@@ -111,11 +111,11 @@ watcher for the watch folders functionality.
 
 ### AI/ML
 
--   [onnxruntime-node](https://github.com/Microsoft/onnxruntime)
--   html-entities is used by the bundled clip-bpe-ts.
--   GGML binaries are bundled
--   We also use [jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) for
-    conversion of all images to JPEG before processing.
+-   [onnxruntime-node](https://github.com/Microsoft/onnxruntime) is used for
+    natural language searches based on CLIP.
+-   html-entities is used by the bundled clip-bpe-ts tokenizer.
+-   [jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) is used for decoding
+    JPEG data into raw RGB bytes before passing it to ONNX.
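
An editorial note on the onnxruntime-node and jpeg-js bullets above: together they describe the preprocessing path that the new `desktop/src/services/clip-service.ts` (added later in this diff) implements. A minimal sketch of that path (not part of this commit): jpeg-js turns the JPEG bytes into interleaved RGB, `getRGBData` resizes and normalizes that into a planar 3×224×224 float buffer, and onnxruntime-node wraps the buffer in the model's input tensor. The `planarFloats` placeholder stands in for `getRGBData`'s output.

```ts
import fs from "node:fs/promises";
const jpeg = require("jpeg-js");
const ort = require("onnxruntime-node");

const clipImageInputFromJPEG = async (jpegFilePath: string) => {
    const jpegData = await fs.readFile(jpegFilePath);
    // Decode to interleaved RGB bytes (no alpha), plus the image dimensions.
    const rgb = jpeg.decode(jpegData, { useTArray: true, formatAsRGBA: false });
    // getRGBData (later in this diff) bilinearly resizes rgb.data from
    // rgb.width x rgb.height to fit within 224x224 and normalizes with the
    // CLIP mean/std, yielding 3 * 224 * 224 floats laid out channel by channel.
    const planarFloats = new Float32Array(3 * 224 * 224); // placeholder for getRGBData's output
    return new ort.Tensor("float32", planarFloats, [1, 3, 224, 224]);
};
```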
 
 ## ZIP
 

+ 0 - 1
desktop/electron-builder.yml

@@ -19,7 +19,6 @@ mac:
         arch: [universal]
     category: public.app-category.photography
     hardenedRuntime: true
-    x64ArchFiles: Contents/Resources/ggmlclip-mac
 afterSign: electron-builder-notarize
 extraFiles:
     - from: build

+ 8 - 15
desktop/src/main/ipc.ts

@@ -17,9 +17,9 @@ import {
     updateAndRestart,
 } from "../services/appUpdater";
 import {
-    computeImageEmbedding,
-    computeTextEmbedding,
-} from "../services/clipService";
+    clipImageEmbedding,
+    clipTextEmbedding,
+} from "../services/clip-service";
 import { runFFmpegCmd } from "../services/ffmpeg";
 import { getDirFiles } from "../services/fs";
 import {
@@ -44,12 +44,7 @@ import {
     updateWatchMappingIgnoredFiles,
     updateWatchMappingSyncedFiles,
 } from "../services/watch";
-import type {
-    ElectronFile,
-    FILE_PATH_TYPE,
-    Model,
-    WatchMapping,
-} from "../types/ipc";
+import type { ElectronFile, FILE_PATH_TYPE, WatchMapping } from "../types/ipc";
 import {
     selectDirectory,
     showUploadDirsDialog,
@@ -148,14 +143,12 @@ export const attachIPCHandlers = () => {
 
     // - ML
 
-    ipcMain.handle(
-        "computeImageEmbedding",
-        (_, model: Model, imageData: Uint8Array) =>
-            computeImageEmbedding(model, imageData),
+    ipcMain.handle("clipImageEmbedding", (_, jpegImageData: Uint8Array) =>
+        clipImageEmbedding(jpegImageData),
     );
 
-    ipcMain.handle("computeTextEmbedding", (_, model: Model, text: string) =>
-        computeTextEmbedding(model, text),
+    ipcMain.handle("clipTextEmbedding", (_, text: string) =>
+        clipTextEmbedding(text),
     );
 
     // - File selection

+ 4 - 1
desktop/src/main/log.ts

@@ -64,7 +64,10 @@ const logInfo = (...params: any[]) => {
 };
 
 const logDebug = (param: () => any) => {
-    if (isDev) console.log(`[debug] ${util.inspect(param())}`);
+    if (isDev) {
+        const p = param();
+        console.log(`[debug] ${typeof p == "string" ? p : util.inspect(p)}`);
+    }
 };
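
An editorial aside on this change: `logDebug` takes a thunk so that its argument is only evaluated when `isDev` is true, and plain strings are now printed as-is instead of going through `util.inspect` (which would wrap them in quotes). A minimal usage sketch, mirroring how the new clip-service.ts calls it:

```ts
import log from "../main/log";

const t1 = Date.now();
// ... work whose duration we want to trace ...
log.debug(() => `CLIP image embedding took ${Date.now() - t1} ms`);
// The thunk (and thus the template string) is never evaluated in production builds.
```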
 
 /**

+ 6 - 13
desktop/src/preload.ts

@@ -45,7 +45,6 @@ import type {
     AppUpdateInfo,
     ElectronFile,
     FILE_PATH_TYPE,
-    Model,
     WatchMapping,
 } from "./types/ipc";
 
@@ -141,17 +140,11 @@ const runFFmpegCmd = (
 
 // - ML
 
-const computeImageEmbedding = (
-    model: Model,
-    imageData: Uint8Array,
-): Promise<Float32Array> =>
-    ipcRenderer.invoke("computeImageEmbedding", model, imageData);
+const clipImageEmbedding = (jpegImageData: Uint8Array): Promise<Float32Array> =>
+    ipcRenderer.invoke("clipImageEmbedding", jpegImageData);
 
-const computeTextEmbedding = (
-    model: Model,
-    text: string,
-): Promise<Float32Array> =>
-    ipcRenderer.invoke("computeTextEmbedding", model, text);
+const clipTextEmbedding = (text: string): Promise<Float32Array> =>
+    ipcRenderer.invoke("clipTextEmbedding", text);
 
 // - File selection
 
@@ -332,8 +325,8 @@ contextBridge.exposeInMainWorld("electron", {
     runFFmpegCmd,
 
     // - ML
-    computeImageEmbedding,
-    computeTextEmbedding,
+    clipImageEmbedding,
+    clipTextEmbedding,
 
     // - File selection
     selectDirectory,
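
To sketch the full round trip after this rename (an editorial illustration, assuming the renderer reaches the bridge as `window.electron`, which is what the `Electron` type in `web/packages/next/types/ipc.ts` below models): the renderer calls the exposed function, `ipcRenderer.invoke` carries it to the main process, and the `ipcMain.handle("clipImageEmbedding", ...)` registration in `desktop/src/main/ipc.ts` runs the implementation in the new clip-service. The `jpegBytes` value is a hypothetical input.

```ts
// Renderer-side usage sketch; jpegBytes stands in for the JPEG bytes of a
// thumbnail or file the caller already has.
declare const jpegBytes: Uint8Array;

const imageEmbedding = await window.electron.clipImageEmbedding(jpegBytes);
const textEmbedding = await window.electron.clipTextEmbedding("winter sunrise");
console.log(imageEmbedding.length, textEmbedding.length); // 512-dimensional Float32Arrays
```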

+ 288 - 0
desktop/src/services/clip-service.ts

@@ -0,0 +1,288 @@
+/**
+ * @file Compute CLIP embeddings
+ *
+ * @see `web/apps/photos/src/services/clip-service.ts` for more details. This
+ * file contains the Node.js implementation of the actual embedding
+ * computation. By doing it in the Node.js layer, we can use the binary ONNX
+ * runtimes which are 10-20x faster than the WASM based web ones.
+ *
+ * The embeddings are computed using ONNX runtime. The model itself is not
+ * shipped with the app but is downloaded on demand.
+ */
+import { app, net } from "electron/main";
+import { existsSync } from "fs";
+import fs from "node:fs/promises";
+import path from "node:path";
+import { writeStream } from "../main/fs";
+import log from "../main/log";
+import { CustomErrors } from "../types/ipc";
+import Tokenizer from "../utils/clip-bpe-ts/mod";
+import { generateTempFilePath } from "../utils/temp";
+import { deleteTempFile } from "./ffmpeg";
+const jpeg = require("jpeg-js");
+const ort = require("onnxruntime-node");
+
+const textModelName = "clip-text-vit-32-uint8.onnx";
+const textModelByteSize = 64173509; // 61.2 MB
+
+const imageModelName = "clip-image-vit-32-float32.onnx";
+const imageModelByteSize = 351468764; // 335.2 MB
+
+/** Return the path where the given {@link modelName} is meant to be saved */
+const modelSavePath = (modelName: string) =>
+    path.join(app.getPath("userData"), "models", modelName);
+
+const downloadModel = async (saveLocation: string, name: string) => {
+    // `mkdir -p` the directory where we want to save the model.
+    const saveDir = path.dirname(saveLocation);
+    await fs.mkdir(saveDir, { recursive: true });
+    // Download
+    const url = `https://models.ente.io/${name}`;
+    log.info(`Downloading CLIP model from ${url}`);
+    const res = await net.fetch(url);
+    if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
+    // Save
+    await writeStream(saveLocation, res.body);
+    log.info(`Downloaded CLIP model ${name}`);
+};
+
+let activeImageModelDownload: Promise<void> | undefined;
+
+const imageModelPathDownloadingIfNeeded = async () => {
+    try {
+        const modelPath = modelSavePath(imageModelName);
+        if (activeImageModelDownload) {
+            log.info("Waiting for CLIP image model download to finish");
+            await activeImageModelDownload;
+        } else {
+            if (!existsSync(modelPath)) {
+                log.info("CLIP image model not found, downloading");
+                activeImageModelDownload = downloadModel(
+                    modelPath,
+                    imageModelName,
+                );
+                await activeImageModelDownload;
+            } else {
+                const localFileSize = (await fs.stat(modelPath)).size;
+                if (localFileSize !== imageModelByteSize) {
+                    log.error(
+                        `CLIP image model size ${localFileSize} does not match the expected size, downloading again`,
+                    );
+                    activeImageModelDownload = downloadModel(
+                        modelPath,
+                        imageModelName,
+                    );
+                    await activeImageModelDownload;
+                }
+            }
+        }
+        return modelPath;
+    } finally {
+        activeImageModelDownload = undefined;
+    }
+};
+
+let textModelDownloadInProgress = false;
+
+const textModelPathDownloadingIfNeeded = async () => {
+    if (textModelDownloadInProgress)
+        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
+
+    const modelPath = modelSavePath(textModelName);
+    if (!existsSync(modelPath)) {
+        log.info("CLIP text model not found, downloading");
+        textModelDownloadInProgress = true;
+        downloadModel(modelPath, textModelName)
+            .catch((e) => {
+                // log but otherwise ignore
+                log.error("CLIP text model download failed", e);
+            })
+            .finally(() => {
+                textModelDownloadInProgress = false;
+            });
+        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
+    } else {
+        const localFileSize = (await fs.stat(modelPath)).size;
+        if (localFileSize !== textModelByteSize) {
+            log.error(
+                `CLIP text model size ${localFileSize} does not match the expected size, downloading again`,
+            );
+            textModelDownloadInProgress = true;
+            downloadModel(modelPath, textModelName)
+                .catch((e) => {
+                    // log but otherwise ignore
+                    log.error("CLIP text model download failed", e);
+                })
+                .finally(() => {
+                    textModelDownloadInProgress = false;
+                });
+            throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
+        }
+    }
+
+    return modelPath;
+};
+
+const createInferenceSession = async (modelPath: string) => {
+    return await ort.InferenceSession.create(modelPath, {
+        intraOpNumThreads: 1,
+        enableCpuMemArena: false,
+    });
+};
+
+let imageSessionPromise: Promise<any> | undefined;
+
+const onnxImageSession = async () => {
+    if (!imageSessionPromise) {
+        imageSessionPromise = (async () => {
+            const modelPath = await imageModelPathDownloadingIfNeeded();
+            return createInferenceSession(modelPath);
+        })();
+    }
+    return imageSessionPromise;
+};
+
+let _textSession: any = null;
+
+const onnxTextSession = async () => {
+    if (!_textSession) {
+        const modelPath = await textModelPathDownloadingIfNeeded();
+        _textSession = await createInferenceSession(modelPath);
+    }
+    return _textSession;
+};
+
+export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
+    const tempFilePath = await generateTempFilePath("");
+    const imageStream = new Response(jpegImageData.buffer).body;
+    await writeStream(tempFilePath, imageStream);
+    try {
+        return await clipImageEmbedding_(tempFilePath);
+    } finally {
+        await deleteTempFile(tempFilePath);
+    }
+};
+
+const clipImageEmbedding_ = async (jpegFilePath: string) => {
+    const imageSession = await onnxImageSession();
+    const t1 = Date.now();
+    const rgbData = await getRGBData(jpegFilePath);
+    const feeds = {
+        input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]),
+    };
+    const t2 = Date.now();
+    const results = await imageSession.run(feeds);
+    log.debug(
+        () =>
+            `CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
+    );
+    const imageEmbedding = results["output"].data; // Float32Array
+    return normalizeEmbedding(imageEmbedding);
+};
+
+const getRGBData = async (jpegFilePath: string) => {
+    const jpegData = await fs.readFile(jpegFilePath);
+    const rawImageData = jpeg.decode(jpegData, {
+        useTArray: true,
+        formatAsRGBA: false,
+    });
+
+    const nx: number = rawImageData.width;
+    const ny: number = rawImageData.height;
+    const inputImage: Uint8Array = rawImageData.data;
+
+    const nx2: number = 224;
+    const ny2: number = 224;
+    const totalSize: number = 3 * nx2 * ny2;
+
+    const result: number[] = Array(totalSize).fill(0);
+    const scale: number = Math.max(nx, ny) / 224;
+
+    const nx3: number = Math.round(nx / scale);
+    const ny3: number = Math.round(ny / scale);
+
+    const mean: number[] = [0.48145466, 0.4578275, 0.40821073];
+    const std: number[] = [0.26862954, 0.26130258, 0.27577711];
+
+    for (let y = 0; y < ny3; y++) {
+        for (let x = 0; x < nx3; x++) {
+            for (let c = 0; c < 3; c++) {
+                // Linear interpolation
+                const sx: number = (x + 0.5) * scale - 0.5;
+                const sy: number = (y + 0.5) * scale - 0.5;
+
+                const x0: number = Math.max(0, Math.floor(sx));
+                const y0: number = Math.max(0, Math.floor(sy));
+
+                const x1: number = Math.min(x0 + 1, nx - 1);
+                const y1: number = Math.min(y0 + 1, ny - 1);
+
+                const dx: number = sx - x0;
+                const dy: number = sy - y0;
+
+                const j00: number = 3 * (y0 * nx + x0) + c;
+                const j01: number = 3 * (y0 * nx + x1) + c;
+                const j10: number = 3 * (y1 * nx + x0) + c;
+                const j11: number = 3 * (y1 * nx + x1) + c;
+
+                const v00: number = inputImage[j00];
+                const v01: number = inputImage[j01];
+                const v10: number = inputImage[j10];
+                const v11: number = inputImage[j11];
+
+                const v0: number = v00 * (1 - dx) + v01 * dx;
+                const v1: number = v10 * (1 - dx) + v11 * dx;
+
+                const v: number = v0 * (1 - dy) + v1 * dy;
+
+                const v2: number = Math.min(Math.max(Math.round(v), 0), 255);
+
+                // createTensorWithDataList is dumb compared to reshape, and
+                // hence the data has to be provided one channel after another
+                const i: number = y * nx3 + x + (c % 3) * 224 * 224;
+
+                result[i] = (v2 / 255 - mean[c]) / std[c];
+            }
+        }
+    }
+
+    return result;
+};
+
+const normalizeEmbedding = (embedding: Float32Array) => {
+    let normalization = 0;
+    for (let index = 0; index < embedding.length; index++) {
+        normalization += embedding[index] * embedding[index];
+    }
+    const sqrtNormalization = Math.sqrt(normalization);
+    for (let index = 0; index < embedding.length; index++) {
+        embedding[index] = embedding[index] / sqrtNormalization;
+    }
+    return embedding;
+};
+
+let _tokenizer: Tokenizer = null;
+const getTokenizer = () => {
+    if (!_tokenizer) {
+        _tokenizer = new Tokenizer();
+    }
+    return _tokenizer;
+};
+
+export const clipTextEmbedding = async (text: string) => {
+    const textSession = await onnxTextSession();
+    const t1 = Date.now();
+    const tokenizer = getTokenizer();
+    const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
+    const feeds = {
+        input: new ort.Tensor("int32", tokenizedText, [1, 77]),
+    };
+    const t2 = Date.now();
+    const results = await textSession.run(feeds);
+    log.debug(
+        () =>
+            `CLIP text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
+    );
+    const textEmbedding = results["output"].data;
+    return normalizeEmbedding(textEmbedding);
+};
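
An editorial note on `normalizeEmbedding` above: both exported functions return unit-length vectors, so the cosine similarity used to rank matches reduces to a plain dot product. A minimal sketch; the `dot` helper and the sample inputs are illustrative, not part of this commit:

```ts
import { clipImageEmbedding, clipTextEmbedding } from "./clip-service";

const dot = (a: Float32Array, b: Float32Array) => {
    let sum = 0;
    for (let i = 0; i < a.length; i++) sum += a[i] * b[i];
    return sum;
};

// jpegBytes stands in for the bytes of some JPEG image.
declare const jpegBytes: Uint8Array;

const imageEmbedding = await clipImageEmbedding(jpegBytes);
const textEmbedding = await clipTextEmbedding("a dog playing in the snow");
// Both vectors are unit-normalized, so their dot product is the cosine similarity.
const score = dot(imageEmbedding, textEmbedding);
console.log(`match score: ${score}`);
```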

+ 0 - 463
desktop/src/services/clipService.ts

@@ -1,463 +0,0 @@
-import { app, net } from "electron/main";
-import { existsSync } from "fs";
-import fs from "node:fs/promises";
-import path from "node:path";
-import { writeStream } from "../main/fs";
-import log from "../main/log";
-import { execAsync, isDev } from "../main/util";
-import { CustomErrors, Model, isModel } from "../types/ipc";
-import Tokenizer from "../utils/clip-bpe-ts/mod";
-import { getPlatform } from "../utils/common/platform";
-import { generateTempFilePath } from "../utils/temp";
-import { deleteTempFile } from "./ffmpeg";
-const jpeg = require("jpeg-js");
-
-const CLIP_MODEL_PATH_PLACEHOLDER = "CLIP_MODEL";
-const GGMLCLIP_PATH_PLACEHOLDER = "GGML_PATH";
-const INPUT_PATH_PLACEHOLDER = "INPUT";
-
-const IMAGE_EMBEDDING_EXTRACT_CMD: string[] = [
-    GGMLCLIP_PATH_PLACEHOLDER,
-    "-mv",
-    CLIP_MODEL_PATH_PLACEHOLDER,
-    "--image",
-    INPUT_PATH_PLACEHOLDER,
-];
-
-const TEXT_EMBEDDING_EXTRACT_CMD: string[] = [
-    GGMLCLIP_PATH_PLACEHOLDER,
-    "-mt",
-    CLIP_MODEL_PATH_PLACEHOLDER,
-    "--text",
-    INPUT_PATH_PLACEHOLDER,
-];
-const ort = require("onnxruntime-node");
-
-const TEXT_MODEL_DOWNLOAD_URL = {
-    ggml: "https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf",
-    onnx: "https://models.ente.io/clip-text-vit-32-uint8.onnx",
-};
-const IMAGE_MODEL_DOWNLOAD_URL = {
-    ggml: "https://models.ente.io/clip-vit-base-patch32_ggml-vision-model-f16.gguf",
-    onnx: "https://models.ente.io/clip-image-vit-32-float32.onnx",
-};
-
-const TEXT_MODEL_NAME = {
-    ggml: "clip-vit-base-patch32_ggml-text-model-f16.gguf",
-    onnx: "clip-text-vit-32-uint8.onnx",
-};
-const IMAGE_MODEL_NAME = {
-    ggml: "clip-vit-base-patch32_ggml-vision-model-f16.gguf",
-    onnx: "clip-image-vit-32-float32.onnx",
-};
-
-const IMAGE_MODEL_SIZE_IN_BYTES = {
-    ggml: 175957504, // 167.8 MB
-    onnx: 351468764, // 335.2 MB
-};
-const TEXT_MODEL_SIZE_IN_BYTES = {
-    ggml: 127853440, // 121.9 MB,
-    onnx: 64173509, // 61.2 MB
-};
-
-/** Return the path where the given {@link modelName} is meant to be saved */
-const getModelSavePath = (modelName: string) =>
-    path.join(app.getPath("userData"), "models", modelName);
-
-async function downloadModel(saveLocation: string, url: string) {
-    // confirm that the save location exists
-    const saveDir = path.dirname(saveLocation);
-    await fs.mkdir(saveDir, { recursive: true });
-    log.info("downloading clip model");
-    const res = await net.fetch(url);
-    if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
-    await writeStream(saveLocation, res.body);
-    log.info("clip model downloaded");
-}
-
-let imageModelDownloadInProgress: Promise<void> = null;
-
-const getClipImageModelPath = async (type: "ggml" | "onnx") => {
-    try {
-        const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME[type]);
-        if (imageModelDownloadInProgress) {
-            log.info("waiting for image model download to finish");
-            await imageModelDownloadInProgress;
-        } else {
-            if (!existsSync(modelSavePath)) {
-                log.info("CLIP image model not found, downloading");
-                imageModelDownloadInProgress = downloadModel(
-                    modelSavePath,
-                    IMAGE_MODEL_DOWNLOAD_URL[type],
-                );
-                await imageModelDownloadInProgress;
-            } else {
-                const localFileSize = (await fs.stat(modelSavePath)).size;
-                if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES[type]) {
-                    log.info(
-                        `CLIP image model size mismatch, downloading again got: ${localFileSize}`,
-                    );
-                    imageModelDownloadInProgress = downloadModel(
-                        modelSavePath,
-                        IMAGE_MODEL_DOWNLOAD_URL[type],
-                    );
-                    await imageModelDownloadInProgress;
-                }
-            }
-        }
-        return modelSavePath;
-    } finally {
-        imageModelDownloadInProgress = null;
-    }
-};
-
-let textModelDownloadInProgress: boolean = false;
-
-const getClipTextModelPath = async (type: "ggml" | "onnx") => {
-    const modelSavePath = getModelSavePath(TEXT_MODEL_NAME[type]);
-    if (textModelDownloadInProgress) {
-        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
-    } else {
-        if (!existsSync(modelSavePath)) {
-            log.info("CLIP text model not found, downloading");
-            textModelDownloadInProgress = true;
-            downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
-                .catch((e) => {
-                    // log but otherwise ignore
-                    log.error("CLIP text model download failed", e);
-                })
-                .finally(() => {
-                    textModelDownloadInProgress = false;
-                });
-            throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
-        } else {
-            const localFileSize = (await fs.stat(modelSavePath)).size;
-            if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES[type]) {
-                log.info(
-                    `CLIP text model size mismatch, downloading again got: ${localFileSize}`,
-                );
-                textModelDownloadInProgress = true;
-                downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
-                    .catch((e) => {
-                        // log but otherwise ignore
-                        log.error("CLIP text model download failed", e);
-                    })
-                    .finally(() => {
-                        textModelDownloadInProgress = false;
-                    });
-                throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
-            }
-        }
-    }
-    return modelSavePath;
-};
-
-function getGGMLClipPath() {
-    return isDev
-        ? path.join("./build", `ggmlclip-${getPlatform()}`)
-        : path.join(process.resourcesPath, `ggmlclip-${getPlatform()}`);
-}
-
-async function createOnnxSession(modelPath: string) {
-    return await ort.InferenceSession.create(modelPath, {
-        intraOpNumThreads: 1,
-        enableCpuMemArena: false,
-    });
-}
-
-let onnxImageSessionPromise: Promise<any> = null;
-
-async function getOnnxImageSession() {
-    if (!onnxImageSessionPromise) {
-        onnxImageSessionPromise = (async () => {
-            const clipModelPath = await getClipImageModelPath("onnx");
-            return createOnnxSession(clipModelPath);
-        })();
-    }
-    return onnxImageSessionPromise;
-}
-
-let onnxTextSession: any = null;
-
-async function getOnnxTextSession() {
-    if (!onnxTextSession) {
-        const clipModelPath = await getClipTextModelPath("onnx");
-        onnxTextSession = await createOnnxSession(clipModelPath);
-    }
-    return onnxTextSession;
-}
-
-let tokenizer: Tokenizer = null;
-function getTokenizer() {
-    if (!tokenizer) {
-        tokenizer = new Tokenizer();
-    }
-    return tokenizer;
-}
-
-export const computeImageEmbedding = async (
-    model: Model,
-    imageData: Uint8Array,
-): Promise<Float32Array> => {
-    if (!isModel(model)) throw new Error(`Invalid CLIP model ${model}`);
-
-    let tempInputFilePath = null;
-    try {
-        tempInputFilePath = await generateTempFilePath("");
-        const imageStream = new Response(imageData.buffer).body;
-        await writeStream(tempInputFilePath, imageStream);
-        const embedding = await computeImageEmbedding_(
-            model,
-            tempInputFilePath,
-        );
-        return embedding;
-    } catch (err) {
-        if (isExecError(err)) {
-            const parsedExecError = parseExecError(err);
-            throw Error(parsedExecError);
-        } else {
-            throw err;
-        }
-    } finally {
-        if (tempInputFilePath) {
-            await deleteTempFile(tempInputFilePath);
-        }
-    }
-};
-
-const isExecError = (err: any) => {
-    return err.message.includes("Command failed:");
-};
-
-const parseExecError = (err: any) => {
-    const errMessage = err.message;
-    if (errMessage.includes("Bad CPU type in executable")) {
-        return CustomErrors.UNSUPPORTED_PLATFORM(
-            process.platform,
-            process.arch,
-        );
-    } else {
-        return errMessage;
-    }
-};
-
-async function computeImageEmbedding_(
-    model: Model,
-    inputFilePath: string,
-): Promise<Float32Array> {
-    if (!existsSync(inputFilePath)) {
-        throw new Error("Invalid file path");
-    }
-    switch (model) {
-        case "ggml-clip":
-            return await computeGGMLImageEmbedding(inputFilePath);
-        case "onnx-clip":
-            return await computeONNXImageEmbedding(inputFilePath);
-    }
-}
-
-const computeGGMLImageEmbedding = async (
-    inputFilePath: string,
-): Promise<Float32Array> => {
-    const clipModelPath = await getClipImageModelPath("ggml");
-    const ggmlclipPath = getGGMLClipPath();
-    const cmd = IMAGE_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
-        if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
-            return ggmlclipPath;
-        } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
-            return clipModelPath;
-        } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
-            return inputFilePath;
-        } else {
-            return cmdPart;
-        }
-    });
-
-    const { stdout } = await execAsync(cmd);
-    // parse stdout and return embedding
-    // get the last line of stdout
-    const lines = stdout.split("\n");
-    const lastLine = lines[lines.length - 1];
-    const embedding = JSON.parse(lastLine);
-    const embeddingArray = new Float32Array(embedding);
-    return embeddingArray;
-};
-
-const computeONNXImageEmbedding = async (
-    inputFilePath: string,
-): Promise<Float32Array> => {
-    const imageSession = await getOnnxImageSession();
-    const t1 = Date.now();
-    const rgbData = await getRGBData(inputFilePath);
-    const feeds = {
-        input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]),
-    };
-    const t2 = Date.now();
-    const results = await imageSession.run(feeds);
-    log.info(
-        `onnx image embedding time: ${Date.now() - t1} ms (prep:${
-            t2 - t1
-        } ms, extraction: ${Date.now() - t2} ms)`,
-    );
-    const imageEmbedding = results["output"].data; // Float32Array
-    return normalizeEmbedding(imageEmbedding);
-};
-
-async function getRGBData(inputFilePath: string) {
-    const jpegData = await fs.readFile(inputFilePath);
-    const rawImageData = jpeg.decode(jpegData, {
-        useTArray: true,
-        formatAsRGBA: false,
-    });
-
-    const nx: number = rawImageData.width;
-    const ny: number = rawImageData.height;
-    const inputImage: Uint8Array = rawImageData.data;
-
-    const nx2: number = 224;
-    const ny2: number = 224;
-    const totalSize: number = 3 * nx2 * ny2;
-
-    const result: number[] = Array(totalSize).fill(0);
-    const scale: number = Math.max(nx, ny) / 224;
-
-    const nx3: number = Math.round(nx / scale);
-    const ny3: number = Math.round(ny / scale);
-
-    const mean: number[] = [0.48145466, 0.4578275, 0.40821073];
-    const std: number[] = [0.26862954, 0.26130258, 0.27577711];
-
-    for (let y = 0; y < ny3; y++) {
-        for (let x = 0; x < nx3; x++) {
-            for (let c = 0; c < 3; c++) {
-                // linear interpolation
-                const sx: number = (x + 0.5) * scale - 0.5;
-                const sy: number = (y + 0.5) * scale - 0.5;
-
-                const x0: number = Math.max(0, Math.floor(sx));
-                const y0: number = Math.max(0, Math.floor(sy));
-
-                const x1: number = Math.min(x0 + 1, nx - 1);
-                const y1: number = Math.min(y0 + 1, ny - 1);
-
-                const dx: number = sx - x0;
-                const dy: number = sy - y0;
-
-                const j00: number = 3 * (y0 * nx + x0) + c;
-                const j01: number = 3 * (y0 * nx + x1) + c;
-                const j10: number = 3 * (y1 * nx + x0) + c;
-                const j11: number = 3 * (y1 * nx + x1) + c;
-
-                const v00: number = inputImage[j00];
-                const v01: number = inputImage[j01];
-                const v10: number = inputImage[j10];
-                const v11: number = inputImage[j11];
-
-                const v0: number = v00 * (1 - dx) + v01 * dx;
-                const v1: number = v10 * (1 - dx) + v11 * dx;
-
-                const v: number = v0 * (1 - dy) + v1 * dy;
-
-                const v2: number = Math.min(Math.max(Math.round(v), 0), 255);
-
-                // createTensorWithDataList is dump compared to reshape and hence has to be given with one channel after another
-                const i: number = y * nx3 + x + (c % 3) * 224 * 224;
-
-                result[i] = (v2 / 255 - mean[c]) / std[c];
-            }
-        }
-    }
-
-    return result;
-}
-
-const normalizeEmbedding = (embedding: Float32Array) => {
-    let normalization = 0;
-    for (let index = 0; index < embedding.length; index++) {
-        normalization += embedding[index] * embedding[index];
-    }
-    const sqrtNormalization = Math.sqrt(normalization);
-    for (let index = 0; index < embedding.length; index++) {
-        embedding[index] = embedding[index] / sqrtNormalization;
-    }
-    return embedding;
-};
-
-export async function computeTextEmbedding(
-    model: Model,
-    text: string,
-): Promise<Float32Array> {
-    if (!isModel(model)) throw new Error(`Invalid CLIP model ${model}`);
-
-    try {
-        const embedding = computeTextEmbedding_(model, text);
-        return embedding;
-    } catch (err) {
-        if (isExecError(err)) {
-            const parsedExecError = parseExecError(err);
-            throw Error(parsedExecError);
-        } else {
-            throw err;
-        }
-    }
-}
-
-async function computeTextEmbedding_(
-    model: Model,
-    text: string,
-): Promise<Float32Array> {
-    switch (model) {
-        case "ggml-clip":
-            return await computeGGMLTextEmbedding(text);
-        case "onnx-clip":
-            return await computeONNXTextEmbedding(text);
-    }
-}
-
-export async function computeGGMLTextEmbedding(
-    text: string,
-): Promise<Float32Array> {
-    const clipModelPath = await getClipTextModelPath("ggml");
-    const ggmlclipPath = getGGMLClipPath();
-    const cmd = TEXT_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
-        if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
-            return ggmlclipPath;
-        } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
-            return clipModelPath;
-        } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
-            return text;
-        } else {
-            return cmdPart;
-        }
-    });
-
-    const { stdout } = await execAsync(cmd);
-    // parse stdout and return embedding
-    // get the last line of stdout
-    const lines = stdout.split("\n");
-    const lastLine = lines[lines.length - 1];
-    const embedding = JSON.parse(lastLine);
-    const embeddingArray = new Float32Array(embedding);
-    return embeddingArray;
-}
-
-export async function computeONNXTextEmbedding(
-    text: string,
-): Promise<Float32Array> {
-    const imageSession = await getOnnxTextSession();
-    const t1 = Date.now();
-    const tokenizer = getTokenizer();
-    const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
-    const feeds = {
-        input: new ort.Tensor("int32", tokenizedText, [1, 77]),
-    };
-    const t2 = Date.now();
-    const results = await imageSession.run(feeds);
-    log.info(
-        `onnx text embedding time: ${Date.now() - t1} ms (prep:${
-            t2 - t1
-        } ms, extraction: ${Date.now() - t2} ms)`,
-    );
-    const textEmbedding = results["output"].data; // Float32Array
-    return normalizeEmbedding(textEmbedding);
-}

+ 0 - 4
desktop/src/types/ipc.ts

@@ -80,7 +80,3 @@ export interface AppUpdateInfo {
     autoUpdatable: boolean;
     version: string;
 }
-
-export type Model = "ggml-clip" | "onnx-clip";
-
-export const isModel = (s: unknown) => s == "ggml-clip" || s == "onnx-clip";

+ 5 - 7
web/apps/photos/src/components/Sidebar/AdvancedSettings.tsx

@@ -14,7 +14,7 @@ import { EnteMenuItem } from "components/Menu/EnteMenuItem";
 import { MenuItemGroup } from "components/Menu/MenuItemGroup";
 import isElectron from "is-electron";
 import { AppContext } from "pages/_app";
-import { ClipExtractionStatus, ClipService } from "services/clipService";
+import { CLIPIndexingStatus, clipService } from "services/clip-service";
 import { formatNumber } from "utils/number/format";
 
 export default function AdvancedSettings({ open, onClose, onRootClose }) {
@@ -44,17 +44,15 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) {
             log.error("toggleFasterUpload failed", e);
         }
     };
-    const [indexingStatus, setIndexingStatus] = useState<ClipExtractionStatus>({
+    const [indexingStatus, setIndexingStatus] = useState<CLIPIndexingStatus>({
         indexed: 0,
         pending: 0,
     });
 
     useEffect(() => {
-        const main = async () => {
-            setIndexingStatus(await ClipService.getIndexingStatus());
-            ClipService.setOnUpdateHandler(setIndexingStatus);
-        };
-        main();
+        clipService.setOnUpdateHandler(setIndexingStatus);
+        clipService.getIndexingStatus().then((st) => setIndexingStatus(st));
+        return () => clipService.setOnUpdateHandler(undefined);
     }, []);
 
     return (

+ 5 - 5
web/apps/photos/src/pages/gallery/index.tsx

@@ -102,7 +102,7 @@ import {
 } from "constants/collection";
 import { SYNC_INTERVAL_IN_MICROSECONDS } from "constants/gallery";
 import { AppContext } from "pages/_app";
-import { ClipService } from "services/clipService";
+import { clipService } from "services/clip-service";
 import { constructUserIDToEmailMap } from "services/collectionService";
 import downloadManager from "services/download";
 import { syncEmbeddings } from "services/embeddingService";
@@ -362,7 +362,7 @@ export default function Gallery() {
                 syncWithRemote(false, true);
             }, SYNC_INTERVAL_IN_MICROSECONDS);
             if (electron) {
-                void ClipService.setupOnFileUploadListener();
+                void clipService.setupOnFileUploadListener();
                 electron.registerForegroundEventListener(() => {
                     syncWithRemote(false, true);
                 });
@@ -373,7 +373,7 @@ export default function Gallery() {
             clearInterval(syncInterval.current);
             if (electron) {
                 electron.registerForegroundEventListener(() => {});
-                ClipService.removeOnFileUploadListener();
+                clipService.removeOnFileUploadListener();
             }
         };
     }, []);
@@ -704,8 +704,8 @@ export default function Gallery() {
             await syncEntities();
             await syncMapEnabled();
             await syncEmbeddings();
-            if (ClipService.isPlatformSupported()) {
-                void ClipService.scheduleImageEmbeddingExtraction();
+            if (clipService.isPlatformSupported()) {
+                void clipService.scheduleImageEmbeddingExtraction();
             }
         } catch (e) {
             switch (e.message) {

+ 96 - 95
web/apps/photos/src/services/clipService.ts → web/apps/photos/src/services/clip-service.ts

@@ -1,5 +1,6 @@
 import { ensureElectron } from "@/next/electron";
 import log from "@/next/log";
+import type { Electron } from "@/next/types/ipc";
 import ComlinkCryptoWorker from "@ente/shared/crypto";
 import { CustomError } from "@ente/shared/error";
 import { Events, eventBus } from "@ente/shared/events";
@@ -7,29 +8,71 @@ import { LS_KEYS, getData } from "@ente/shared/storage/localStorage";
 import { FILE_TYPE } from "constants/file";
 import isElectron from "is-electron";
 import PQueue from "p-queue";
-import { Embedding, Model } from "types/embedding";
+import { Embedding } from "types/embedding";
 import { EnteFile } from "types/file";
 import { getPersonalFiles } from "utils/file";
 import downloadManager from "./download";
 import { getLocalEmbeddings, putEmbedding } from "./embeddingService";
 import { getAllLocalFiles, getLocalFiles } from "./fileService";
 
-const CLIP_EMBEDDING_LENGTH = 512;
-
-export interface ClipExtractionStatus {
+/** Status of CLIP indexing on the images in the user's local library. */
+export interface CLIPIndexingStatus {
+    /** Number of items pending indexing. */
     pending: number;
+    /** Number of items that have already been indexed. */
     indexed: number;
 }
 
-class ClipServiceImpl {
+/**
+ * Use a CLIP based neural network for natural language search.
+ *
+ * [Note: CLIP based magic search]
+ *
+ * CLIP (Contrastive Language-Image Pretraining) is a neural network trained on
+ * (image, text) pairs. It can be thought of as two separate (but jointly
+ * trained) encoders - one for images, and one for text - that both map to the
+ * same embedding space.
+ *
+ * We use this for natural language search within the app (aka "magic search"):
+ *
+ * 1. Pre-compute an embedding for each image.
+ *
+ * 2. When the user searches, compute an embedding for the search term.
+ *
+ * 3. Use cosine similarity to find the image (embedding) closest to
+ *    the text (embedding).
+ *
+ * More details are in our [blog
+ * post](https://ente.io/blog/image-search-with-clip-ggml/) that describes the
+ * initial launch of this feature using the GGML runtime.
+ *
+ * Since the initial launch, we've switched over to another runtime,
+ * [ONNX](https://onnxruntime.ai).
+ *
+ * Note that we don't train the neural network - we only use one of the publicly
+ * available pre-trained neural networks for inference. These neural networks
+ * are wholly defined by their connectivity and weights. ONNX, our ML runtime,
+ * loads these weights and instantiates a running network that we can use to
+ * compute the embeddings.
+ *
+ * Theoretically, the same CLIP model can be loaded by different frameworks /
+ * runtimes, but in practice each runtime has its own preferred format, and
+ * there are also quantization tradeoffs. So there is a specific model (a binary
+ * encoding of weights) tied to our current runtime that we use.
+ *
+ * To ensure that the embeddings, for the most part, can be shared, whenever
+ * possible we try to ensure that all the preprocessing steps, and the model
+ * itself, are the same across clients - web and mobile.
+ */
+class CLIPService {
+    private electron: Electron;
     private embeddingExtractionInProgress: AbortController | null = null;
     private reRunNeeded = false;
-    private clipExtractionStatus: ClipExtractionStatus = {
+    private indexingStatus: CLIPIndexingStatus = {
         pending: 0,
         indexed: 0,
     };
-    private onUpdateHandler: ((status: ClipExtractionStatus) => void) | null =
-        null;
+    private onUpdateHandler: ((status: CLIPIndexingStatus) => void) | undefined;
     private liveEmbeddingExtractionQueue: PQueue;
     private onFileUploadedHandler:
         | ((arg: { enteFile: EnteFile; localFile: globalThis.File }) => void)
@@ -37,6 +80,7 @@ class ClipServiceImpl {
     private unsupportedPlatform = false;
 
     constructor() {
+        this.electron = ensureElectron();
         this.liveEmbeddingExtractionQueue = new PQueue({
             concurrency: 1,
         });
@@ -96,28 +140,23 @@ class ClipServiceImpl {
     };
 
     getIndexingStatus = async () => {
-        try {
-            if (
-                !this.clipExtractionStatus ||
-                (this.clipExtractionStatus.pending === 0 &&
-                    this.clipExtractionStatus.indexed === 0)
-            ) {
-                this.clipExtractionStatus = await getClipExtractionStatus();
-            }
-            return this.clipExtractionStatus;
-        } catch (e) {
-            log.error("failed to get clip indexing status", e);
+        if (
+            this.indexingStatus.pending === 0 &&
+            this.indexingStatus.indexed === 0
+        ) {
+            this.indexingStatus = await initialIndexingStatus();
         }
+        return this.indexingStatus;
     };
 
-    setOnUpdateHandler = (handler: (status: ClipExtractionStatus) => void) => {
+    /**
+     * Set the {@link handler} to invoke whenever our indexing status changes.
+     */
+    setOnUpdateHandler = (handler?: (status: CLIPIndexingStatus) => void) => {
         this.onUpdateHandler = handler;
-        handler(this.clipExtractionStatus);
     };
 
-    scheduleImageEmbeddingExtraction = async (
-        model: Model = Model.ONNX_CLIP,
-    ) => {
+    scheduleImageEmbeddingExtraction = async () => {
         try {
             if (this.embeddingExtractionInProgress) {
                 log.info(
@@ -133,7 +172,7 @@ class ClipServiceImpl {
             const canceller = new AbortController();
             this.embeddingExtractionInProgress = canceller;
             try {
-                await this.runClipEmbeddingExtraction(canceller, model);
+                await this.runClipEmbeddingExtraction(canceller);
             } finally {
                 this.embeddingExtractionInProgress = null;
                 if (!canceller.signal.aborted && this.reRunNeeded) {
@@ -152,25 +191,19 @@ class ClipServiceImpl {
         }
     };
 
-    getTextEmbedding = async (
-        text: string,
-        model: Model = Model.ONNX_CLIP,
-    ): Promise<Float32Array> => {
+    getTextEmbedding = async (text: string): Promise<Float32Array> => {
         try {
-            return ensureElectron().computeTextEmbedding(model, text);
+            return this.electron.clipTextEmbedding(text);
         } catch (e) {
             if (e?.message?.includes(CustomError.UNSUPPORTED_PLATFORM)) {
                 this.unsupportedPlatform = true;
             }
-            log.error("failed to compute text embedding", e);
+            log.error("Failed to compute CLIP text embedding", e);
             throw e;
         }
     };
 
-    private runClipEmbeddingExtraction = async (
-        canceller: AbortController,
-        model: Model,
-    ) => {
+    private runClipEmbeddingExtraction = async (canceller: AbortController) => {
         try {
             if (this.unsupportedPlatform) {
                 log.info(
@@ -183,12 +216,12 @@ class ClipServiceImpl {
                 return;
             }
             const localFiles = getPersonalFiles(await getAllLocalFiles(), user);
-            const existingEmbeddings = await getLocalEmbeddings(model);
+            const existingEmbeddings = await getLocalEmbeddings();
             const pendingFiles = await getNonClipEmbeddingExtractedFiles(
                 localFiles,
                 existingEmbeddings,
             );
-            this.updateClipEmbeddingExtractionStatus({
+            this.updateIndexingStatus({
                 indexed: existingEmbeddings.length,
                 pending: pendingFiles.length,
             });
@@ -208,15 +241,11 @@ class ClipServiceImpl {
                         throw Error(CustomError.REQUEST_CANCELLED);
                     }
                     const embeddingData =
-                        await this.extractFileClipImageEmbedding(model, file);
+                        await this.extractFileClipImageEmbedding(file);
                     log.info(
                         `successfully extracted clip embedding for file: ${file.metadata.title} fileID: ${file.id} embedding length: ${embeddingData?.length}`,
                     );
-                    await this.encryptAndUploadEmbedding(
-                        model,
-                        file,
-                        embeddingData,
-                    );
+                    await this.encryptAndUploadEmbedding(file, embeddingData);
                     this.onSuccessStatusUpdater();
                     log.info(
                         `successfully put clip embedding to server for file: ${file.metadata.title} fileID: ${file.id}`,
@@ -249,13 +278,10 @@ class ClipServiceImpl {
         }
     };
 
-    private async runLocalFileClipExtraction(
-        arg: {
-            enteFile: EnteFile;
-            localFile: globalThis.File;
-        },
-        model: Model = Model.ONNX_CLIP,
-    ) {
+    private async runLocalFileClipExtraction(arg: {
+        enteFile: EnteFile;
+        localFile: globalThis.File;
+    }) {
         const { enteFile, localFile } = arg;
         log.info(
             `clip embedding extraction onFileUploadedHandler file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
@@ -279,15 +305,9 @@ class ClipServiceImpl {
         );
         try {
             await this.liveEmbeddingExtractionQueue.add(async () => {
-                const embedding = await this.extractLocalFileClipImageEmbedding(
-                    model,
-                    localFile,
-                );
-                await this.encryptAndUploadEmbedding(
-                    model,
-                    enteFile,
-                    embedding,
-                );
+                const embedding =
+                    await this.extractLocalFileClipImageEmbedding(localFile);
+                await this.encryptAndUploadEmbedding(enteFile, embedding);
             });
             log.info(
                 `successfully extracted clip embedding for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
@@ -297,26 +317,18 @@ class ClipServiceImpl {
         }
     }
 
-    private extractLocalFileClipImageEmbedding = async (
-        model: Model,
-        localFile: File,
-    ) => {
+    private extractLocalFileClipImageEmbedding = async (localFile: File) => {
         const file = await localFile
             .arrayBuffer()
             .then((buffer) => new Uint8Array(buffer));
-        const embedding = await ensureElectron().computeImageEmbedding(
-            model,
-            file,
-        );
-        return embedding;
+        return await this.electron.clipImageEmbedding(file);
     };
 
     private encryptAndUploadEmbedding = async (
-        model: Model,
         file: EnteFile,
         embeddingData: Float32Array,
     ) => {
-        if (embeddingData?.length !== CLIP_EMBEDDING_LENGTH) {
+        if (embeddingData?.length !== 512) {
             throw Error(
                 `invalid length embedding data length: ${embeddingData?.length}`,
             );
@@ -331,38 +343,31 @@ class ClipServiceImpl {
             fileID: file.id,
             encryptedEmbedding: encryptedEmbeddingData.encryptedData,
             decryptionHeader: encryptedEmbeddingData.decryptionHeader,
-            model,
+            model: "onnx-clip",
         });
     };
 
-    updateClipEmbeddingExtractionStatus = (status: ClipExtractionStatus) => {
-        this.clipExtractionStatus = status;
-        if (this.onUpdateHandler) {
-            this.onUpdateHandler(status);
-        }
+    private updateIndexingStatus = (status: CLIPIndexingStatus) => {
+        this.indexingStatus = status;
+        const handler = this.onUpdateHandler;
+        if (handler) handler(status);
     };
 
-    private extractFileClipImageEmbedding = async (
-        model: Model,
-        file: EnteFile,
-    ) => {
+    private extractFileClipImageEmbedding = async (file: EnteFile) => {
         const thumb = await downloadManager.getThumbnail(file);
-        const embedding = await ensureElectron().computeImageEmbedding(
-            model,
-            thumb,
-        );
+        const embedding = await ensureElectron().clipImageEmbedding(thumb);
         return embedding;
     };
 
     private onSuccessStatusUpdater = () => {
-        this.updateClipEmbeddingExtractionStatus({
-            pending: this.clipExtractionStatus.pending - 1,
-            indexed: this.clipExtractionStatus.indexed + 1,
+        this.updateIndexingStatus({
+            pending: this.indexingStatus.pending - 1,
+            indexed: this.indexingStatus.indexed + 1,
         });
     };
 }
 
-export const ClipService = new ClipServiceImpl();
+export const clipService = new CLIPService();
 
 const getNonClipEmbeddingExtractedFiles = async (
     files: EnteFile[],
@@ -412,14 +417,10 @@ export const computeClipMatchScore = async (
     return score;
 };
 
-const getClipExtractionStatus = async (
-    model: Model = Model.ONNX_CLIP,
-): Promise<ClipExtractionStatus> => {
+const initialIndexingStatus = async (): Promise<CLIPIndexingStatus> => {
     const user = getData(LS_KEYS.USER);
-    if (!user) {
-        return;
-    }
-    const allEmbeddings = await getLocalEmbeddings(model);
+    if (!user) throw new Error("Orphan CLIP indexing without a login");
+    const allEmbeddings = await getLocalEmbeddings();
     const localFiles = getPersonalFiles(await getLocalFiles(), user);
     const pendingFiles = await getNonClipEmbeddingExtractedFiles(
         localFiles,
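
To round out the [Note: CLIP based magic search] comment above, an editorial sketch of how the search side consumes this service, along the lines of `searchClip` in `searchService.ts` later in this diff. The `fileID` and `embedding` field names on the decrypted embedding, the argument order of `computeClipMatchScore`, and the 0.2 cutoff are assumptions for illustration only.

```ts
import { clipService, computeClipMatchScore } from "services/clip-service";
import { getLocalEmbeddings } from "services/embeddingService";

const clipSearchScores = async (searchPhrase: string) => {
    const imageEmbeddings = await getLocalEmbeddings(); // "onnx-clip" embeddings only
    const textEmbedding = await clipService.getTextEmbedding(searchPhrase);
    const scores = new Map<number, number>(); // fileID -> match score
    for (const imageEmbedding of imageEmbeddings) {
        const score = await computeClipMatchScore(
            imageEmbedding.embedding,
            textEmbedding,
        );
        if (score >= 0.2) scores.set(imageEmbedding.fileID, score);
    }
    return scores;
};
```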

+ 12 - 8
web/apps/photos/src/services/embeddingService.ts

@@ -5,11 +5,11 @@ import HTTPService from "@ente/shared/network/HTTPService";
 import { getEndpoint } from "@ente/shared/network/api";
 import localForage from "@ente/shared/storage/localForage";
 import { getToken } from "@ente/shared/storage/localStorage/helpers";
-import {
+import type {
     Embedding,
+    EmbeddingModel,
     EncryptedEmbedding,
     GetEmbeddingDiffResponse,
-    Model,
     PutEmbeddingRequest,
 } from "types/embedding";
 import { EnteFile } from "types/file";
@@ -38,12 +38,12 @@ export const getAllLocalEmbeddings = async () => {
     return embeddings;
 };
 
-export const getLocalEmbeddings = async (model: Model) => {
+export const getLocalEmbeddings = async () => {
     const embeddings = await getAllLocalEmbeddings();
-    return embeddings.filter((embedding) => embedding.model === model);
+    return embeddings.filter((embedding) => embedding.model === "onnx-clip");
 };
 
-const getModelEmbeddingSyncTime = async (model: Model) => {
+const getModelEmbeddingSyncTime = async (model: EmbeddingModel) => {
     return (
         (await localForage.getItem<number>(
             `${model}-${EMBEDDING_SYNC_TIME_TABLE}`,
@@ -51,11 +51,15 @@ const getModelEmbeddingSyncTime = async (model: Model) => {
     );
 };
 
-const setModelEmbeddingSyncTime = async (model: Model, time: number) => {
+const setModelEmbeddingSyncTime = async (
+    model: EmbeddingModel,
+    time: number,
+) => {
     await localForage.setItem(`${model}-${EMBEDDING_SYNC_TIME_TABLE}`, time);
 };
 
-export const syncEmbeddings = async (models: Model[] = [Model.ONNX_CLIP]) => {
+export const syncEmbeddings = async () => {
+    const models: EmbeddingModel[] = ["onnx-clip"];
     try {
         let allEmbeddings = await getAllLocalEmbeddings();
         const localFiles = await getAllLocalFiles();
@@ -138,7 +142,7 @@ export const syncEmbeddings = async (models: Model[] = [Model.ONNX_CLIP]) => {
 
 export const getEmbeddingsDiff = async (
     sinceTime: number,
-    model: Model,
+    model: EmbeddingModel,
 ): Promise<GetEmbeddingDiffResponse> => {
     try {
         const token = getToken();

+ 4 - 5
web/apps/photos/src/services/searchService.ts

@@ -4,7 +4,6 @@ import * as chrono from "chrono-node";
 import { FILE_TYPE } from "constants/file";
 import { t } from "i18next";
 import { Collection } from "types/collection";
-import { Model } from "types/embedding";
 import { EntityType, LocationTag, LocationTagData } from "types/entity";
 import { EnteFile } from "types/file";
 import { Person, Thing } from "types/machineLearning";
@@ -22,7 +21,7 @@ import { getAllPeople } from "utils/machineLearning";
 import { getMLSyncConfig } from "utils/machineLearning/config";
 import { getFormattedDate } from "utils/search";
 import mlIDbStorage from "utils/storage/mlIDbStorage";
-import { ClipService, computeClipMatchScore } from "./clipService";
+import { clipService, computeClipMatchScore } from "./clip-service";
 import { getLocalEmbeddings } from "./embeddingService";
 import { getLatestEntities } from "./entityService";
 import locationSearchService, { City } from "./locationSearchService";
@@ -305,7 +304,7 @@ async function getThingSuggestion(searchPhrase: string): Promise<Suggestion[]> {
 
 async function getClipSuggestion(searchPhrase: string): Promise<Suggestion> {
     try {
-        if (!ClipService.isPlatformSupported()) {
+        if (!clipService.isPlatformSupported()) {
             return null;
         }
 
@@ -396,8 +395,8 @@ async function searchThing(searchPhrase: string) {
 }
 
 async function searchClip(searchPhrase: string): Promise<ClipSearchScores> {
-    const imageEmbeddings = await getLocalEmbeddings(Model.ONNX_CLIP);
-    const textEmbedding = await ClipService.getTextEmbedding(searchPhrase);
+    const imageEmbeddings = await getLocalEmbeddings();
+    const textEmbedding = await clipService.getTextEmbedding(searchPhrase);
     const clipSearchResult = new Map<number, number>(
         (
             await Promise.all(

+ 11 - 6
web/apps/photos/src/types/embedding.tsx

@@ -1,11 +1,16 @@
-export enum Model {
-    GGML_CLIP = "ggml-clip",
-    ONNX_CLIP = "onnx-clip",
-}
+/**
+ * The embeddings models that we support.
+ *
+ * This is an exhaustive set of values we pass when PUT-ting encrypted
+ * embeddings on the server. However, we should be prepared to receive an
+ * {@link EncryptedEmbedding} with a model value distinct from one of these.
+ */
+export type EmbeddingModel = "onnx-clip";
 
 export interface EncryptedEmbedding {
     fileID: number;
-    model: Model;
+    /** @see {@link EmbeddingModel} */
+    model: string;
     encryptedEmbedding: string;
     decryptionHeader: string;
     updatedAt: number;
@@ -25,7 +30,7 @@ export interface GetEmbeddingDiffResponse {
 
 export interface PutEmbeddingRequest {
     fileID: number;
-    model: Model;
+    model: EmbeddingModel;
     encryptedEmbedding: string;
     decryptionHeader: string;
 }
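
A small editorial sketch of the asymmetry the `EmbeddingModel` comment above describes: this client only PUTs models it knows, but a diff response may carry model strings it does not, so received embeddings are narrowed before use (much like `getLocalEmbeddings` filtering on "onnx-clip"). The `isKnownModel` helper is illustrative, not part of this diff.

```ts
import type { EmbeddingModel, EncryptedEmbedding } from "types/embedding";

// Narrow an arbitrary model string from the server to one this client understands.
const isKnownModel = (model: string): model is EmbeddingModel =>
    model === "onnx-clip";

const usableEmbeddings = (received: EncryptedEmbedding[]) =>
    received.filter((e) => isKnownModel(e.model));
```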

+ 20 - 10
web/packages/next/types/ipc.ts

@@ -10,11 +10,6 @@ export interface AppUpdateInfo {
     version: string;
 }
 
-export enum Model {
-    GGML_CLIP = "ggml-clip",
-    ONNX_CLIP = "onnx-clip",
-}
-
 export enum FILE_PATH_TYPE {
     FILES = "files",
     ZIPS = "zips",
@@ -147,12 +142,27 @@ export interface Electron {
 
     // - ML
 
-    computeImageEmbedding: (
-        model: Model,
-        imageData: Uint8Array,
-    ) => Promise<Float32Array>;
+    /**
+     * Compute and return a CLIP embedding of the given image.
+     *
+     * See: [Note: CLIP based magic search]
+     *
+     * @param jpegImageData The raw bytes of the image encoded as a JPEG.
+     *
+     * @returns A CLIP embedding.
+     */
+    clipImageEmbedding: (jpegImageData: Uint8Array) => Promise<Float32Array>;
 
-    computeTextEmbedding: (model: Model, text: string) => Promise<Float32Array>;
+    /**
+     * Compute and return a CLIP embedding of the given text.
+     *
+     * See: [Note: CLIP based magic search]
+     *
+     * @param text The string whose embedding we want to compute.
+     *
+     * @returns A CLIP embedding.
+     */
+    clipTextEmbedding: (text: string) => Promise<Float32Array>;
 
     // - File selection
     // TODO: Deprecated - use dialogs on the renderer process itself