diff --git a/desktop/build/ggmlclip-linux b/desktop/build/ggmlclip-linux deleted file mode 100755 index 4c160b039..000000000 Binary files a/desktop/build/ggmlclip-linux and /dev/null differ diff --git a/desktop/build/ggmlclip-mac b/desktop/build/ggmlclip-mac deleted file mode 100755 index db7c4f249..000000000 Binary files a/desktop/build/ggmlclip-mac and /dev/null differ diff --git a/desktop/build/ggmlclip-windows.exe b/desktop/build/ggmlclip-windows.exe deleted file mode 100755 index 1e197dfe8..000000000 Binary files a/desktop/build/ggmlclip-windows.exe and /dev/null differ diff --git a/desktop/build/msvcp140d.dll b/desktop/build/msvcp140d.dll deleted file mode 100644 index 358e18663..000000000 Binary files a/desktop/build/msvcp140d.dll and /dev/null differ diff --git a/desktop/build/ucrtbased.dll b/desktop/build/ucrtbased.dll deleted file mode 100644 index 78cfcfeeb..000000000 Binary files a/desktop/build/ucrtbased.dll and /dev/null differ diff --git a/desktop/build/vcruntime140_1d.dll b/desktop/build/vcruntime140_1d.dll deleted file mode 100644 index 700cf5f75..000000000 Binary files a/desktop/build/vcruntime140_1d.dll and /dev/null differ diff --git a/desktop/build/vcruntime140d.dll b/desktop/build/vcruntime140d.dll deleted file mode 100644 index 8b5425e0f..000000000 Binary files a/desktop/build/vcruntime140d.dll and /dev/null differ diff --git a/desktop/docs/dependencies.md b/desktop/docs/dependencies.md index 103583a63..5c1b07744 100644 --- a/desktop/docs/dependencies.md +++ b/desktop/docs/dependencies.md @@ -111,11 +111,11 @@ watcher for the watch folders functionality. ### AI/ML -- [onnxruntime-node](https://github.com/Microsoft/onnxruntime) -- html-entities is used by the bundled clip-bpe-ts. -- GGML binaries are bundled -- We also use [jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) for - conversion of all images to JPEG before processing. +- [onnxruntime-node](https://github.com/Microsoft/onnxruntime) is used for + natural language searches based on CLIP. +- html-entities is used by the bundled clip-bpe-ts tokenizer. +- [jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) is used for decoding + JPEG data into raw RGB bytes before passing it to ONNX. 
## ZIP diff --git a/desktop/electron-builder.yml b/desktop/electron-builder.yml index 9189c3435..4fdfc4f55 100644 --- a/desktop/electron-builder.yml +++ b/desktop/electron-builder.yml @@ -19,7 +19,6 @@ mac: arch: [universal] category: public.app-category.photography hardenedRuntime: true - x64ArchFiles: Contents/Resources/ggmlclip-mac afterSign: electron-builder-notarize extraFiles: - from: build diff --git a/desktop/src/main/ipc.ts b/desktop/src/main/ipc.ts index 0fdd10056..f4da569c5 100644 --- a/desktop/src/main/ipc.ts +++ b/desktop/src/main/ipc.ts @@ -17,9 +17,9 @@ import { updateAndRestart, } from "../services/appUpdater"; import { - computeImageEmbedding, - computeTextEmbedding, -} from "../services/clipService"; + clipImageEmbedding, + clipTextEmbedding, +} from "../services/clip-service"; import { runFFmpegCmd } from "../services/ffmpeg"; import { getDirFiles } from "../services/fs"; import { @@ -44,12 +44,7 @@ import { updateWatchMappingIgnoredFiles, updateWatchMappingSyncedFiles, } from "../services/watch"; -import type { - ElectronFile, - FILE_PATH_TYPE, - Model, - WatchMapping, -} from "../types/ipc"; +import type { ElectronFile, FILE_PATH_TYPE, WatchMapping } from "../types/ipc"; import { selectDirectory, showUploadDirsDialog, @@ -148,14 +143,12 @@ export const attachIPCHandlers = () => { // - ML - ipcMain.handle( - "computeImageEmbedding", - (_, model: Model, imageData: Uint8Array) => - computeImageEmbedding(model, imageData), + ipcMain.handle("clipImageEmbedding", (_, jpegImageData: Uint8Array) => + clipImageEmbedding(jpegImageData), ); - ipcMain.handle("computeTextEmbedding", (_, model: Model, text: string) => - computeTextEmbedding(model, text), + ipcMain.handle("clipTextEmbedding", (_, text: string) => + clipTextEmbedding(text), ); // - File selection diff --git a/desktop/src/main/log.ts b/desktop/src/main/log.ts index 0e504115c..04ecb6ea3 100644 --- a/desktop/src/main/log.ts +++ b/desktop/src/main/log.ts @@ -64,7 +64,10 @@ const logInfo = (...params: any[]) => { }; const logDebug = (param: () => any) => { - if (isDev) console.log(`[debug] ${util.inspect(param())}`); + if (isDev) { + const p = param(); + console.log(`[debug] ${typeof p == "string" ? 
p : util.inspect(p)}`); + } }; /** diff --git a/desktop/src/preload.ts b/desktop/src/preload.ts index aa528b7ad..cb718f950 100644 --- a/desktop/src/preload.ts +++ b/desktop/src/preload.ts @@ -45,7 +45,6 @@ import type { AppUpdateInfo, ElectronFile, FILE_PATH_TYPE, - Model, WatchMapping, } from "./types/ipc"; @@ -141,17 +140,11 @@ const runFFmpegCmd = ( // - ML -const computeImageEmbedding = ( - model: Model, - imageData: Uint8Array, -): Promise => - ipcRenderer.invoke("computeImageEmbedding", model, imageData); +const clipImageEmbedding = (jpegImageData: Uint8Array): Promise => + ipcRenderer.invoke("clipImageEmbedding", jpegImageData); -const computeTextEmbedding = ( - model: Model, - text: string, -): Promise => - ipcRenderer.invoke("computeTextEmbedding", model, text); +const clipTextEmbedding = (text: string): Promise => + ipcRenderer.invoke("clipTextEmbedding", text); // - File selection @@ -332,8 +325,8 @@ contextBridge.exposeInMainWorld("electron", { runFFmpegCmd, // - ML - computeImageEmbedding, - computeTextEmbedding, + clipImageEmbedding, + clipTextEmbedding, // - File selection selectDirectory, diff --git a/desktop/src/services/clip-service.ts b/desktop/src/services/clip-service.ts new file mode 100644 index 000000000..5de05e601 --- /dev/null +++ b/desktop/src/services/clip-service.ts @@ -0,0 +1,288 @@ +/** + * @file Compute CLIP embeddings + * + * @see `web/apps/photos/src/services/clip-service.ts` for more details. This + * file implements the Node.js implementation of the actual embedding + * computation. By doing it in the Node.js layer, we can use the binary ONNX + * runtimes which are 10-20x faster than the WASM based web ones. + * + * The embeddings are computed using ONNX runtime. The model itself is not + * shipped with the app but is downloaded on demand. + */ +import { app, net } from "electron/main"; +import { existsSync } from "fs"; +import fs from "node:fs/promises"; +import path from "node:path"; +import { writeStream } from "../main/fs"; +import log from "../main/log"; +import { CustomErrors } from "../types/ipc"; +import Tokenizer from "../utils/clip-bpe-ts/mod"; +import { generateTempFilePath } from "../utils/temp"; +import { deleteTempFile } from "./ffmpeg"; +const jpeg = require("jpeg-js"); +const ort = require("onnxruntime-node"); + +const textModelName = "clip-text-vit-32-uint8.onnx"; +const textModelByteSize = 64173509; // 61.2 MB + +const imageModelName = "clip-image-vit-32-float32.onnx"; +const imageModelByteSize = 351468764; // 335.2 MB + +/** Return the path where the given {@link modelName} is meant to be saved */ +const modelSavePath = (modelName: string) => + path.join(app.getPath("userData"), "models", modelName); + +const downloadModel = async (saveLocation: string, name: string) => { + // `mkdir -p` the directory where we want to save the model. 
+ const saveDir = path.dirname(saveLocation); + await fs.mkdir(saveDir, { recursive: true }); + // Download + log.info(`Downloading CLIP model from ${name}`); + const url = `https://models.ente.io/${name}`; + const res = await net.fetch(url); + if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`); + // Save + await writeStream(saveLocation, res.body); + log.info(`Downloaded CLIP model ${name}`); +}; + +let activeImageModelDownload: Promise | undefined; + +const imageModelPathDownloadingIfNeeded = async () => { + try { + const modelPath = modelSavePath(imageModelName); + if (activeImageModelDownload) { + log.info("Waiting for CLIP image model download to finish"); + await activeImageModelDownload; + } else { + if (!existsSync(modelPath)) { + log.info("CLIP image model not found, downloading"); + activeImageModelDownload = downloadModel( + modelPath, + imageModelName, + ); + await activeImageModelDownload; + } else { + const localFileSize = (await fs.stat(modelPath)).size; + if (localFileSize !== imageModelByteSize) { + log.error( + `CLIP image model size ${localFileSize} does not match the expected size, downloading again`, + ); + activeImageModelDownload = downloadModel( + modelPath, + imageModelName, + ); + await activeImageModelDownload; + } + } + } + return modelPath; + } finally { + activeImageModelDownload = undefined; + } +}; + +let textModelDownloadInProgress = false; + +const textModelPathDownloadingIfNeeded = async () => { + if (textModelDownloadInProgress) + throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING); + + const modelPath = modelSavePath(textModelName); + if (!existsSync(modelPath)) { + log.info("CLIP text model not found, downloading"); + textModelDownloadInProgress = true; + downloadModel(modelPath, textModelName) + .catch((e) => { + // log but otherwise ignore + log.error("CLIP text model download failed", e); + }) + .finally(() => { + textModelDownloadInProgress = false; + }); + throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING); + } else { + const localFileSize = (await fs.stat(modelPath)).size; + if (localFileSize !== textModelByteSize) { + log.error( + `CLIP text model size ${localFileSize} does not match the expected size, downloading again`, + ); + textModelDownloadInProgress = true; + downloadModel(modelPath, textModelName) + .catch((e) => { + // log but otherwise ignore + log.error("CLIP text model download failed", e); + }) + .finally(() => { + textModelDownloadInProgress = false; + }); + throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING); + } + } + + return modelPath; +}; + +const createInferenceSession = async (modelPath: string) => { + return await ort.InferenceSession.create(modelPath, { + intraOpNumThreads: 1, + enableCpuMemArena: false, + }); +}; + +let imageSessionPromise: Promise | undefined; + +const onnxImageSession = async () => { + if (!imageSessionPromise) { + imageSessionPromise = (async () => { + const modelPath = await imageModelPathDownloadingIfNeeded(); + return createInferenceSession(modelPath); + })(); + } + return imageSessionPromise; +}; + +let _textSession: any = null; + +const onnxTextSession = async () => { + if (!_textSession) { + const modelPath = await textModelPathDownloadingIfNeeded(); + _textSession = await createInferenceSession(modelPath); + } + return _textSession; +}; + +export const clipImageEmbedding = async (jpegImageData: Uint8Array) => { + const tempFilePath = await generateTempFilePath(""); + const imageStream = new Response(jpegImageData.buffer).body; + await writeStream(tempFilePath, imageStream); + try 
{ + return await clipImageEmbedding_(tempFilePath); + } finally { + await deleteTempFile(tempFilePath); + } +}; + +const clipImageEmbedding_ = async (jpegFilePath: string) => { + const imageSession = await onnxImageSession(); + const t1 = Date.now(); + const rgbData = await getRGBData(jpegFilePath); + const feeds = { + input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]), + }; + const t2 = Date.now(); + const results = await imageSession.run(feeds); + log.debug( + () => + `CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`, + ); + const imageEmbedding = results["output"].data; // Float32Array + return normalizeEmbedding(imageEmbedding); +}; + +const getRGBData = async (jpegFilePath: string) => { + const jpegData = await fs.readFile(jpegFilePath); + const rawImageData = jpeg.decode(jpegData, { + useTArray: true, + formatAsRGBA: false, + }); + + const nx: number = rawImageData.width; + const ny: number = rawImageData.height; + const inputImage: Uint8Array = rawImageData.data; + + const nx2: number = 224; + const ny2: number = 224; + const totalSize: number = 3 * nx2 * ny2; + + const result: number[] = Array(totalSize).fill(0); + const scale: number = Math.max(nx, ny) / 224; + + const nx3: number = Math.round(nx / scale); + const ny3: number = Math.round(ny / scale); + + const mean: number[] = [0.48145466, 0.4578275, 0.40821073]; + const std: number[] = [0.26862954, 0.26130258, 0.27577711]; + + for (let y = 0; y < ny3; y++) { + for (let x = 0; x < nx3; x++) { + for (let c = 0; c < 3; c++) { + // Linear interpolation + const sx: number = (x + 0.5) * scale - 0.5; + const sy: number = (y + 0.5) * scale - 0.5; + + const x0: number = Math.max(0, Math.floor(sx)); + const y0: number = Math.max(0, Math.floor(sy)); + + const x1: number = Math.min(x0 + 1, nx - 1); + const y1: number = Math.min(y0 + 1, ny - 1); + + const dx: number = sx - x0; + const dy: number = sy - y0; + + const j00: number = 3 * (y0 * nx + x0) + c; + const j01: number = 3 * (y0 * nx + x1) + c; + const j10: number = 3 * (y1 * nx + x0) + c; + const j11: number = 3 * (y1 * nx + x1) + c; + + const v00: number = inputImage[j00]; + const v01: number = inputImage[j01]; + const v10: number = inputImage[j10]; + const v11: number = inputImage[j11]; + + const v0: number = v00 * (1 - dx) + v01 * dx; + const v1: number = v10 * (1 - dx) + v11 * dx; + + const v: number = v0 * (1 - dy) + v1 * dy; + + const v2: number = Math.min(Math.max(Math.round(v), 0), 255); + + // createTensorWithDataList is dumb compared to reshape and + // hence has to be given with one channel after another + const i: number = y * nx3 + x + (c % 3) * 224 * 224; + + result[i] = (v2 / 255 - mean[c]) / std[c]; + } + } + } + + return result; +}; + +const normalizeEmbedding = (embedding: Float32Array) => { + let normalization = 0; + for (let index = 0; index < embedding.length; index++) { + normalization += embedding[index] * embedding[index]; + } + const sqrtNormalization = Math.sqrt(normalization); + for (let index = 0; index < embedding.length; index++) { + embedding[index] = embedding[index] / sqrtNormalization; + } + return embedding; +}; + +let _tokenizer: Tokenizer = null; +const getTokenizer = () => { + if (!_tokenizer) { + _tokenizer = new Tokenizer(); + } + return _tokenizer; +}; + +export const clipTextEmbedding = async (text: string) => { + const imageSession = await onnxTextSession(); + const t1 = Date.now(); + const tokenizer = getTokenizer(); + const tokenizedText = 
Int32Array.from(tokenizer.encodeForCLIP(text)); + const feeds = { + input: new ort.Tensor("int32", tokenizedText, [1, 77]), + }; + const t2 = Date.now(); + const results = await imageSession.run(feeds); + log.debug( + () => + `CLIP text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`, + ); + const textEmbedding = results["output"].data; + return normalizeEmbedding(textEmbedding); +}; diff --git a/desktop/src/services/clipService.ts b/desktop/src/services/clipService.ts deleted file mode 100644 index 32d404912..000000000 --- a/desktop/src/services/clipService.ts +++ /dev/null @@ -1,463 +0,0 @@ -import { app, net } from "electron/main"; -import { existsSync } from "fs"; -import fs from "node:fs/promises"; -import path from "node:path"; -import { writeStream } from "../main/fs"; -import log from "../main/log"; -import { execAsync, isDev } from "../main/util"; -import { CustomErrors, Model, isModel } from "../types/ipc"; -import Tokenizer from "../utils/clip-bpe-ts/mod"; -import { getPlatform } from "../utils/common/platform"; -import { generateTempFilePath } from "../utils/temp"; -import { deleteTempFile } from "./ffmpeg"; -const jpeg = require("jpeg-js"); - -const CLIP_MODEL_PATH_PLACEHOLDER = "CLIP_MODEL"; -const GGMLCLIP_PATH_PLACEHOLDER = "GGML_PATH"; -const INPUT_PATH_PLACEHOLDER = "INPUT"; - -const IMAGE_EMBEDDING_EXTRACT_CMD: string[] = [ - GGMLCLIP_PATH_PLACEHOLDER, - "-mv", - CLIP_MODEL_PATH_PLACEHOLDER, - "--image", - INPUT_PATH_PLACEHOLDER, -]; - -const TEXT_EMBEDDING_EXTRACT_CMD: string[] = [ - GGMLCLIP_PATH_PLACEHOLDER, - "-mt", - CLIP_MODEL_PATH_PLACEHOLDER, - "--text", - INPUT_PATH_PLACEHOLDER, -]; -const ort = require("onnxruntime-node"); - -const TEXT_MODEL_DOWNLOAD_URL = { - ggml: "https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf", - onnx: "https://models.ente.io/clip-text-vit-32-uint8.onnx", -}; -const IMAGE_MODEL_DOWNLOAD_URL = { - ggml: "https://models.ente.io/clip-vit-base-patch32_ggml-vision-model-f16.gguf", - onnx: "https://models.ente.io/clip-image-vit-32-float32.onnx", -}; - -const TEXT_MODEL_NAME = { - ggml: "clip-vit-base-patch32_ggml-text-model-f16.gguf", - onnx: "clip-text-vit-32-uint8.onnx", -}; -const IMAGE_MODEL_NAME = { - ggml: "clip-vit-base-patch32_ggml-vision-model-f16.gguf", - onnx: "clip-image-vit-32-float32.onnx", -}; - -const IMAGE_MODEL_SIZE_IN_BYTES = { - ggml: 175957504, // 167.8 MB - onnx: 351468764, // 335.2 MB -}; -const TEXT_MODEL_SIZE_IN_BYTES = { - ggml: 127853440, // 121.9 MB, - onnx: 64173509, // 61.2 MB -}; - -/** Return the path where the given {@link modelName} is meant to be saved */ -const getModelSavePath = (modelName: string) => - path.join(app.getPath("userData"), "models", modelName); - -async function downloadModel(saveLocation: string, url: string) { - // confirm that the save location exists - const saveDir = path.dirname(saveLocation); - await fs.mkdir(saveDir, { recursive: true }); - log.info("downloading clip model"); - const res = await net.fetch(url); - if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`); - await writeStream(saveLocation, res.body); - log.info("clip model downloaded"); -} - -let imageModelDownloadInProgress: Promise = null; - -const getClipImageModelPath = async (type: "ggml" | "onnx") => { - try { - const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME[type]); - if (imageModelDownloadInProgress) { - log.info("waiting for image model download to finish"); - await imageModelDownloadInProgress; - } else { - if 
(!existsSync(modelSavePath)) { - log.info("CLIP image model not found, downloading"); - imageModelDownloadInProgress = downloadModel( - modelSavePath, - IMAGE_MODEL_DOWNLOAD_URL[type], - ); - await imageModelDownloadInProgress; - } else { - const localFileSize = (await fs.stat(modelSavePath)).size; - if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES[type]) { - log.info( - `CLIP image model size mismatch, downloading again got: ${localFileSize}`, - ); - imageModelDownloadInProgress = downloadModel( - modelSavePath, - IMAGE_MODEL_DOWNLOAD_URL[type], - ); - await imageModelDownloadInProgress; - } - } - } - return modelSavePath; - } finally { - imageModelDownloadInProgress = null; - } -}; - -let textModelDownloadInProgress: boolean = false; - -const getClipTextModelPath = async (type: "ggml" | "onnx") => { - const modelSavePath = getModelSavePath(TEXT_MODEL_NAME[type]); - if (textModelDownloadInProgress) { - throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING); - } else { - if (!existsSync(modelSavePath)) { - log.info("CLIP text model not found, downloading"); - textModelDownloadInProgress = true; - downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type]) - .catch((e) => { - // log but otherwise ignore - log.error("CLIP text model download failed", e); - }) - .finally(() => { - textModelDownloadInProgress = false; - }); - throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING); - } else { - const localFileSize = (await fs.stat(modelSavePath)).size; - if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES[type]) { - log.info( - `CLIP text model size mismatch, downloading again got: ${localFileSize}`, - ); - textModelDownloadInProgress = true; - downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type]) - .catch((e) => { - // log but otherwise ignore - log.error("CLIP text model download failed", e); - }) - .finally(() => { - textModelDownloadInProgress = false; - }); - throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING); - } - } - } - return modelSavePath; -}; - -function getGGMLClipPath() { - return isDev - ? 
path.join("./build", `ggmlclip-${getPlatform()}`) - : path.join(process.resourcesPath, `ggmlclip-${getPlatform()}`); -} - -async function createOnnxSession(modelPath: string) { - return await ort.InferenceSession.create(modelPath, { - intraOpNumThreads: 1, - enableCpuMemArena: false, - }); -} - -let onnxImageSessionPromise: Promise = null; - -async function getOnnxImageSession() { - if (!onnxImageSessionPromise) { - onnxImageSessionPromise = (async () => { - const clipModelPath = await getClipImageModelPath("onnx"); - return createOnnxSession(clipModelPath); - })(); - } - return onnxImageSessionPromise; -} - -let onnxTextSession: any = null; - -async function getOnnxTextSession() { - if (!onnxTextSession) { - const clipModelPath = await getClipTextModelPath("onnx"); - onnxTextSession = await createOnnxSession(clipModelPath); - } - return onnxTextSession; -} - -let tokenizer: Tokenizer = null; -function getTokenizer() { - if (!tokenizer) { - tokenizer = new Tokenizer(); - } - return tokenizer; -} - -export const computeImageEmbedding = async ( - model: Model, - imageData: Uint8Array, -): Promise => { - if (!isModel(model)) throw new Error(`Invalid CLIP model ${model}`); - - let tempInputFilePath = null; - try { - tempInputFilePath = await generateTempFilePath(""); - const imageStream = new Response(imageData.buffer).body; - await writeStream(tempInputFilePath, imageStream); - const embedding = await computeImageEmbedding_( - model, - tempInputFilePath, - ); - return embedding; - } catch (err) { - if (isExecError(err)) { - const parsedExecError = parseExecError(err); - throw Error(parsedExecError); - } else { - throw err; - } - } finally { - if (tempInputFilePath) { - await deleteTempFile(tempInputFilePath); - } - } -}; - -const isExecError = (err: any) => { - return err.message.includes("Command failed:"); -}; - -const parseExecError = (err: any) => { - const errMessage = err.message; - if (errMessage.includes("Bad CPU type in executable")) { - return CustomErrors.UNSUPPORTED_PLATFORM( - process.platform, - process.arch, - ); - } else { - return errMessage; - } -}; - -async function computeImageEmbedding_( - model: Model, - inputFilePath: string, -): Promise { - if (!existsSync(inputFilePath)) { - throw new Error("Invalid file path"); - } - switch (model) { - case "ggml-clip": - return await computeGGMLImageEmbedding(inputFilePath); - case "onnx-clip": - return await computeONNXImageEmbedding(inputFilePath); - } -} - -const computeGGMLImageEmbedding = async ( - inputFilePath: string, -): Promise => { - const clipModelPath = await getClipImageModelPath("ggml"); - const ggmlclipPath = getGGMLClipPath(); - const cmd = IMAGE_EMBEDDING_EXTRACT_CMD.map((cmdPart) => { - if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) { - return ggmlclipPath; - } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) { - return clipModelPath; - } else if (cmdPart === INPUT_PATH_PLACEHOLDER) { - return inputFilePath; - } else { - return cmdPart; - } - }); - - const { stdout } = await execAsync(cmd); - // parse stdout and return embedding - // get the last line of stdout - const lines = stdout.split("\n"); - const lastLine = lines[lines.length - 1]; - const embedding = JSON.parse(lastLine); - const embeddingArray = new Float32Array(embedding); - return embeddingArray; -}; - -const computeONNXImageEmbedding = async ( - inputFilePath: string, -): Promise => { - const imageSession = await getOnnxImageSession(); - const t1 = Date.now(); - const rgbData = await getRGBData(inputFilePath); - const feeds = { - input: new 
ort.Tensor("float32", rgbData, [1, 3, 224, 224]), - }; - const t2 = Date.now(); - const results = await imageSession.run(feeds); - log.info( - `onnx image embedding time: ${Date.now() - t1} ms (prep:${ - t2 - t1 - } ms, extraction: ${Date.now() - t2} ms)`, - ); - const imageEmbedding = results["output"].data; // Float32Array - return normalizeEmbedding(imageEmbedding); -}; - -async function getRGBData(inputFilePath: string) { - const jpegData = await fs.readFile(inputFilePath); - const rawImageData = jpeg.decode(jpegData, { - useTArray: true, - formatAsRGBA: false, - }); - - const nx: number = rawImageData.width; - const ny: number = rawImageData.height; - const inputImage: Uint8Array = rawImageData.data; - - const nx2: number = 224; - const ny2: number = 224; - const totalSize: number = 3 * nx2 * ny2; - - const result: number[] = Array(totalSize).fill(0); - const scale: number = Math.max(nx, ny) / 224; - - const nx3: number = Math.round(nx / scale); - const ny3: number = Math.round(ny / scale); - - const mean: number[] = [0.48145466, 0.4578275, 0.40821073]; - const std: number[] = [0.26862954, 0.26130258, 0.27577711]; - - for (let y = 0; y < ny3; y++) { - for (let x = 0; x < nx3; x++) { - for (let c = 0; c < 3; c++) { - // linear interpolation - const sx: number = (x + 0.5) * scale - 0.5; - const sy: number = (y + 0.5) * scale - 0.5; - - const x0: number = Math.max(0, Math.floor(sx)); - const y0: number = Math.max(0, Math.floor(sy)); - - const x1: number = Math.min(x0 + 1, nx - 1); - const y1: number = Math.min(y0 + 1, ny - 1); - - const dx: number = sx - x0; - const dy: number = sy - y0; - - const j00: number = 3 * (y0 * nx + x0) + c; - const j01: number = 3 * (y0 * nx + x1) + c; - const j10: number = 3 * (y1 * nx + x0) + c; - const j11: number = 3 * (y1 * nx + x1) + c; - - const v00: number = inputImage[j00]; - const v01: number = inputImage[j01]; - const v10: number = inputImage[j10]; - const v11: number = inputImage[j11]; - - const v0: number = v00 * (1 - dx) + v01 * dx; - const v1: number = v10 * (1 - dx) + v11 * dx; - - const v: number = v0 * (1 - dy) + v1 * dy; - - const v2: number = Math.min(Math.max(Math.round(v), 0), 255); - - // createTensorWithDataList is dump compared to reshape and hence has to be given with one channel after another - const i: number = y * nx3 + x + (c % 3) * 224 * 224; - - result[i] = (v2 / 255 - mean[c]) / std[c]; - } - } - } - - return result; -} - -const normalizeEmbedding = (embedding: Float32Array) => { - let normalization = 0; - for (let index = 0; index < embedding.length; index++) { - normalization += embedding[index] * embedding[index]; - } - const sqrtNormalization = Math.sqrt(normalization); - for (let index = 0; index < embedding.length; index++) { - embedding[index] = embedding[index] / sqrtNormalization; - } - return embedding; -}; - -export async function computeTextEmbedding( - model: Model, - text: string, -): Promise { - if (!isModel(model)) throw new Error(`Invalid CLIP model ${model}`); - - try { - const embedding = computeTextEmbedding_(model, text); - return embedding; - } catch (err) { - if (isExecError(err)) { - const parsedExecError = parseExecError(err); - throw Error(parsedExecError); - } else { - throw err; - } - } -} - -async function computeTextEmbedding_( - model: Model, - text: string, -): Promise { - switch (model) { - case "ggml-clip": - return await computeGGMLTextEmbedding(text); - case "onnx-clip": - return await computeONNXTextEmbedding(text); - } -} - -export async function computeGGMLTextEmbedding( - text: string, 
-): Promise { - const clipModelPath = await getClipTextModelPath("ggml"); - const ggmlclipPath = getGGMLClipPath(); - const cmd = TEXT_EMBEDDING_EXTRACT_CMD.map((cmdPart) => { - if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) { - return ggmlclipPath; - } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) { - return clipModelPath; - } else if (cmdPart === INPUT_PATH_PLACEHOLDER) { - return text; - } else { - return cmdPart; - } - }); - - const { stdout } = await execAsync(cmd); - // parse stdout and return embedding - // get the last line of stdout - const lines = stdout.split("\n"); - const lastLine = lines[lines.length - 1]; - const embedding = JSON.parse(lastLine); - const embeddingArray = new Float32Array(embedding); - return embeddingArray; -} - -export async function computeONNXTextEmbedding( - text: string, -): Promise { - const imageSession = await getOnnxTextSession(); - const t1 = Date.now(); - const tokenizer = getTokenizer(); - const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text)); - const feeds = { - input: new ort.Tensor("int32", tokenizedText, [1, 77]), - }; - const t2 = Date.now(); - const results = await imageSession.run(feeds); - log.info( - `onnx text embedding time: ${Date.now() - t1} ms (prep:${ - t2 - t1 - } ms, extraction: ${Date.now() - t2} ms)`, - ); - const textEmbedding = results["output"].data; // Float32Array - return normalizeEmbedding(textEmbedding); -} diff --git a/desktop/src/types/ipc.ts b/desktop/src/types/ipc.ts index d1e99b24c..3dba231f2 100644 --- a/desktop/src/types/ipc.ts +++ b/desktop/src/types/ipc.ts @@ -80,7 +80,3 @@ export interface AppUpdateInfo { autoUpdatable: boolean; version: string; } - -export type Model = "ggml-clip" | "onnx-clip"; - -export const isModel = (s: unknown) => s == "ggml-clip" || s == "onnx-clip"; diff --git a/web/apps/photos/src/components/Sidebar/AdvancedSettings.tsx b/web/apps/photos/src/components/Sidebar/AdvancedSettings.tsx index 6668fef1f..817aecb2b 100644 --- a/web/apps/photos/src/components/Sidebar/AdvancedSettings.tsx +++ b/web/apps/photos/src/components/Sidebar/AdvancedSettings.tsx @@ -14,7 +14,7 @@ import { EnteMenuItem } from "components/Menu/EnteMenuItem"; import { MenuItemGroup } from "components/Menu/MenuItemGroup"; import isElectron from "is-electron"; import { AppContext } from "pages/_app"; -import { ClipExtractionStatus, ClipService } from "services/clipService"; +import { CLIPIndexingStatus, clipService } from "services/clip-service"; import { formatNumber } from "utils/number/format"; export default function AdvancedSettings({ open, onClose, onRootClose }) { @@ -44,17 +44,15 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) { log.error("toggleFasterUpload failed", e); } }; - const [indexingStatus, setIndexingStatus] = useState({ + const [indexingStatus, setIndexingStatus] = useState({ indexed: 0, pending: 0, }); useEffect(() => { - const main = async () => { - setIndexingStatus(await ClipService.getIndexingStatus()); - ClipService.setOnUpdateHandler(setIndexingStatus); - }; - main(); + clipService.setOnUpdateHandler(setIndexingStatus); + clipService.getIndexingStatus().then((st) => setIndexingStatus(st)); + return () => clipService.setOnUpdateHandler(undefined); }, []); return ( diff --git a/web/apps/photos/src/pages/gallery/index.tsx b/web/apps/photos/src/pages/gallery/index.tsx index 69b833802..b772771c4 100644 --- a/web/apps/photos/src/pages/gallery/index.tsx +++ b/web/apps/photos/src/pages/gallery/index.tsx @@ -102,7 +102,7 @@ import { } from "constants/collection"; 
import { SYNC_INTERVAL_IN_MICROSECONDS } from "constants/gallery"; import { AppContext } from "pages/_app"; -import { ClipService } from "services/clipService"; +import { clipService } from "services/clip-service"; import { constructUserIDToEmailMap } from "services/collectionService"; import downloadManager from "services/download"; import { syncEmbeddings } from "services/embeddingService"; @@ -362,7 +362,7 @@ export default function Gallery() { syncWithRemote(false, true); }, SYNC_INTERVAL_IN_MICROSECONDS); if (electron) { - void ClipService.setupOnFileUploadListener(); + void clipService.setupOnFileUploadListener(); electron.registerForegroundEventListener(() => { syncWithRemote(false, true); }); @@ -373,7 +373,7 @@ export default function Gallery() { clearInterval(syncInterval.current); if (electron) { electron.registerForegroundEventListener(() => {}); - ClipService.removeOnFileUploadListener(); + clipService.removeOnFileUploadListener(); } }; }, []); @@ -704,8 +704,8 @@ export default function Gallery() { await syncEntities(); await syncMapEnabled(); await syncEmbeddings(); - if (ClipService.isPlatformSupported()) { - void ClipService.scheduleImageEmbeddingExtraction(); + if (clipService.isPlatformSupported()) { + void clipService.scheduleImageEmbeddingExtraction(); } } catch (e) { switch (e.message) { diff --git a/web/apps/photos/src/services/clipService.ts b/web/apps/photos/src/services/clip-service.ts similarity index 72% rename from web/apps/photos/src/services/clipService.ts rename to web/apps/photos/src/services/clip-service.ts index 53e026d4f..c6a94213f 100644 --- a/web/apps/photos/src/services/clipService.ts +++ b/web/apps/photos/src/services/clip-service.ts @@ -1,5 +1,6 @@ import { ensureElectron } from "@/next/electron"; import log from "@/next/log"; +import type { Electron } from "@/next/types/ipc"; import ComlinkCryptoWorker from "@ente/shared/crypto"; import { CustomError } from "@ente/shared/error"; import { Events, eventBus } from "@ente/shared/events"; @@ -7,29 +8,71 @@ import { LS_KEYS, getData } from "@ente/shared/storage/localStorage"; import { FILE_TYPE } from "constants/file"; import isElectron from "is-electron"; import PQueue from "p-queue"; -import { Embedding, Model } from "types/embedding"; +import { Embedding } from "types/embedding"; import { EnteFile } from "types/file"; import { getPersonalFiles } from "utils/file"; import downloadManager from "./download"; import { getLocalEmbeddings, putEmbedding } from "./embeddingService"; import { getAllLocalFiles, getLocalFiles } from "./fileService"; -const CLIP_EMBEDDING_LENGTH = 512; - -export interface ClipExtractionStatus { +/** Status of CLIP indexing on the images in the user's local library. */ +export interface CLIPIndexingStatus { + /** Number of items pending indexing. */ pending: number; + /** Number of items that have already been indexed. */ indexed: number; } -class ClipServiceImpl { +/** + * Use a CLIP based neural network for natural language search. + * + * [Note: CLIP based magic search] + * + * CLIP (Contrastive Language-Image Pretraining) is a neural network trained on + * (image, text) pairs. It can be thought of as two separate (but jointly + * trained) encoders - one for images, and one for text - that both map to the + * same embedding space. + * + * We use this for natural language search within the app (aka "magic search"): + * + * 1. Pre-compute an embedding for each image. + * + * 2. When the user searches, compute an embedding for the search term. + * + * 3. 
Use cosine similarity to find the image (embedding) closest to + * the text (embedding). + * + * More details are in our [blog + * post](https://ente.io/blog/image-search-with-clip-ggml/) that describes the + * initial launch of this feature using the GGML runtime. + * + * Since the initial launch, we've switched over to another runtime, + * [ONNX](https://onnxruntime.ai). + * + * Note that we don't train the neural network - we only use one of the publicly + * available pre-trained neural networks for inference. These neural networks + * are wholly defined by their connectivity and weights. ONNX, our ML runtime, + * loads these weights and instantiates a running network that we can use to + * compute the embeddings. + * + * Theoretically, the same CLIP model can be loaded by different frameworks / + * runtimes, but in practice each runtime has its own preferred format, and + * there are also quantization tradeoffs. So there is a specific model (a binary + * encoding of weights) tied to our current runtime that we use. + * + * To ensure that the embeddings, for the most part, can be shared, whenever + * possible we try to ensure that all the preprocessing steps, and the model + * itself, are the same across clients - web and mobile. + */ +class CLIPService { + private electron: Electron; private embeddingExtractionInProgress: AbortController | null = null; private reRunNeeded = false; - private clipExtractionStatus: ClipExtractionStatus = { + private indexingStatus: CLIPIndexingStatus = { pending: 0, indexed: 0, }; - private onUpdateHandler: ((status: ClipExtractionStatus) => void) | null = - null; + private onUpdateHandler: ((status: CLIPIndexingStatus) => void) | undefined; private liveEmbeddingExtractionQueue: PQueue; private onFileUploadedHandler: | ((arg: { enteFile: EnteFile; localFile: globalThis.File }) => void) @@ -37,6 +80,7 @@ class ClipServiceImpl { private unsupportedPlatform = false; constructor() { + this.electron = ensureElectron(); this.liveEmbeddingExtractionQueue = new PQueue({ concurrency: 1, }); @@ -96,28 +140,23 @@ class ClipServiceImpl { }; getIndexingStatus = async () => { - try { - if ( - !this.clipExtractionStatus || - (this.clipExtractionStatus.pending === 0 && - this.clipExtractionStatus.indexed === 0) - ) { - this.clipExtractionStatus = await getClipExtractionStatus(); - } - return this.clipExtractionStatus; - } catch (e) { - log.error("failed to get clip indexing status", e); + if ( + this.indexingStatus.pending === 0 && + this.indexingStatus.indexed === 0 + ) { + this.indexingStatus = await initialIndexingStatus(); } + return this.indexingStatus; }; - setOnUpdateHandler = (handler: (status: ClipExtractionStatus) => void) => { + /** + * Set the {@link handler} to invoke whenever our indexing status changes.
+ */ + setOnUpdateHandler = (handler?: (status: CLIPIndexingStatus) => void) => { this.onUpdateHandler = handler; - handler(this.clipExtractionStatus); }; - scheduleImageEmbeddingExtraction = async ( - model: Model = Model.ONNX_CLIP, - ) => { + scheduleImageEmbeddingExtraction = async () => { try { if (this.embeddingExtractionInProgress) { log.info( @@ -133,7 +172,7 @@ class ClipServiceImpl { const canceller = new AbortController(); this.embeddingExtractionInProgress = canceller; try { - await this.runClipEmbeddingExtraction(canceller, model); + await this.runClipEmbeddingExtraction(canceller); } finally { this.embeddingExtractionInProgress = null; if (!canceller.signal.aborted && this.reRunNeeded) { @@ -152,25 +191,19 @@ class ClipServiceImpl { } }; - getTextEmbedding = async ( - text: string, - model: Model = Model.ONNX_CLIP, - ): Promise => { + getTextEmbedding = async (text: string): Promise => { try { - return ensureElectron().computeTextEmbedding(model, text); + return electron.clipTextEmbedding(text); } catch (e) { if (e?.message?.includes(CustomError.UNSUPPORTED_PLATFORM)) { this.unsupportedPlatform = true; } - log.error("failed to compute text embedding", e); + log.error("Failed to compute CLIP text embedding", e); throw e; } }; - private runClipEmbeddingExtraction = async ( - canceller: AbortController, - model: Model, - ) => { + private runClipEmbeddingExtraction = async (canceller: AbortController) => { try { if (this.unsupportedPlatform) { log.info( @@ -183,12 +216,12 @@ class ClipServiceImpl { return; } const localFiles = getPersonalFiles(await getAllLocalFiles(), user); - const existingEmbeddings = await getLocalEmbeddings(model); + const existingEmbeddings = await getLocalEmbeddings(); const pendingFiles = await getNonClipEmbeddingExtractedFiles( localFiles, existingEmbeddings, ); - this.updateClipEmbeddingExtractionStatus({ + this.updateIndexingStatus({ indexed: existingEmbeddings.length, pending: pendingFiles.length, }); @@ -208,15 +241,11 @@ class ClipServiceImpl { throw Error(CustomError.REQUEST_CANCELLED); } const embeddingData = - await this.extractFileClipImageEmbedding(model, file); + await this.extractFileClipImageEmbedding(file); log.info( `successfully extracted clip embedding for file: ${file.metadata.title} fileID: ${file.id} embedding length: ${embeddingData?.length}`, ); - await this.encryptAndUploadEmbedding( - model, - file, - embeddingData, - ); + await this.encryptAndUploadEmbedding(file, embeddingData); this.onSuccessStatusUpdater(); log.info( `successfully put clip embedding to server for file: ${file.metadata.title} fileID: ${file.id}`, @@ -249,13 +278,10 @@ class ClipServiceImpl { } }; - private async runLocalFileClipExtraction( - arg: { - enteFile: EnteFile; - localFile: globalThis.File; - }, - model: Model = Model.ONNX_CLIP, - ) { + private async runLocalFileClipExtraction(arg: { + enteFile: EnteFile; + localFile: globalThis.File; + }) { const { enteFile, localFile } = arg; log.info( `clip embedding extraction onFileUploadedHandler file: ${enteFile.metadata.title} fileID: ${enteFile.id}`, @@ -279,15 +305,9 @@ class ClipServiceImpl { ); try { await this.liveEmbeddingExtractionQueue.add(async () => { - const embedding = await this.extractLocalFileClipImageEmbedding( - model, - localFile, - ); - await this.encryptAndUploadEmbedding( - model, - enteFile, - embedding, - ); + const embedding = + await this.extractLocalFileClipImageEmbedding(localFile); + await this.encryptAndUploadEmbedding(enteFile, embedding); }); log.info( `successfully extracted clip 
embedding for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`, @@ -297,26 +317,18 @@ class ClipServiceImpl { } } - private extractLocalFileClipImageEmbedding = async ( - model: Model, - localFile: File, - ) => { + private extractLocalFileClipImageEmbedding = async (localFile: File) => { const file = await localFile .arrayBuffer() .then((buffer) => new Uint8Array(buffer)); - const embedding = await ensureElectron().computeImageEmbedding( - model, - file, - ); - return embedding; + return await electron.clipImageEmbedding(file); }; private encryptAndUploadEmbedding = async ( - model: Model, file: EnteFile, embeddingData: Float32Array, ) => { - if (embeddingData?.length !== CLIP_EMBEDDING_LENGTH) { + if (embeddingData?.length !== 512) { throw Error( `invalid length embedding data length: ${embeddingData?.length}`, ); @@ -331,38 +343,31 @@ class ClipServiceImpl { fileID: file.id, encryptedEmbedding: encryptedEmbeddingData.encryptedData, decryptionHeader: encryptedEmbeddingData.decryptionHeader, - model, + model: "onnx-clip", }); }; - updateClipEmbeddingExtractionStatus = (status: ClipExtractionStatus) => { - this.clipExtractionStatus = status; - if (this.onUpdateHandler) { - this.onUpdateHandler(status); - } + private updateIndexingStatus = (status: CLIPIndexingStatus) => { + this.indexingStatus = status; + const handler = this.onUpdateHandler; + if (handler) handler(status); }; - private extractFileClipImageEmbedding = async ( - model: Model, - file: EnteFile, - ) => { + private extractFileClipImageEmbedding = async (file: EnteFile) => { const thumb = await downloadManager.getThumbnail(file); - const embedding = await ensureElectron().computeImageEmbedding( - model, - thumb, - ); + const embedding = await ensureElectron().clipImageEmbedding(thumb); return embedding; }; private onSuccessStatusUpdater = () => { - this.updateClipEmbeddingExtractionStatus({ - pending: this.clipExtractionStatus.pending - 1, - indexed: this.clipExtractionStatus.indexed + 1, + this.updateIndexingStatus({ + pending: this.indexingStatus.pending - 1, + indexed: this.indexingStatus.indexed + 1, }); }; } -export const ClipService = new ClipServiceImpl(); +export const clipService = new CLIPService(); const getNonClipEmbeddingExtractedFiles = async ( files: EnteFile[], @@ -412,14 +417,10 @@ export const computeClipMatchScore = async ( return score; }; -const getClipExtractionStatus = async ( - model: Model = Model.ONNX_CLIP, -): Promise => { +const initialIndexingStatus = async (): Promise => { const user = getData(LS_KEYS.USER); - if (!user) { - return; - } - const allEmbeddings = await getLocalEmbeddings(model); + if (!user) throw new Error("Orphan CLIP indexing without a login"); + const allEmbeddings = await getLocalEmbeddings(); const localFiles = getPersonalFiles(await getLocalFiles(), user); const pendingFiles = await getNonClipEmbeddingExtractedFiles( localFiles, diff --git a/web/apps/photos/src/services/embeddingService.ts b/web/apps/photos/src/services/embeddingService.ts index 882cdd16c..c4c0075c6 100644 --- a/web/apps/photos/src/services/embeddingService.ts +++ b/web/apps/photos/src/services/embeddingService.ts @@ -5,11 +5,11 @@ import HTTPService from "@ente/shared/network/HTTPService"; import { getEndpoint } from "@ente/shared/network/api"; import localForage from "@ente/shared/storage/localForage"; import { getToken } from "@ente/shared/storage/localStorage/helpers"; -import { +import type { Embedding, + EmbeddingModel, EncryptedEmbedding, GetEmbeddingDiffResponse, - Model, PutEmbeddingRequest, } 
from "types/embedding"; import { EnteFile } from "types/file"; @@ -38,12 +38,12 @@ export const getAllLocalEmbeddings = async () => { return embeddings; }; -export const getLocalEmbeddings = async (model: Model) => { +export const getLocalEmbeddings = async () => { const embeddings = await getAllLocalEmbeddings(); - return embeddings.filter((embedding) => embedding.model === model); + return embeddings.filter((embedding) => embedding.model === "onnx-clip"); }; -const getModelEmbeddingSyncTime = async (model: Model) => { +const getModelEmbeddingSyncTime = async (model: EmbeddingModel) => { return ( (await localForage.getItem( `${model}-${EMBEDDING_SYNC_TIME_TABLE}`, @@ -51,11 +51,15 @@ const getModelEmbeddingSyncTime = async (model: Model) => { ); }; -const setModelEmbeddingSyncTime = async (model: Model, time: number) => { +const setModelEmbeddingSyncTime = async ( + model: EmbeddingModel, + time: number, +) => { await localForage.setItem(`${model}-${EMBEDDING_SYNC_TIME_TABLE}`, time); }; -export const syncEmbeddings = async (models: Model[] = [Model.ONNX_CLIP]) => { +export const syncEmbeddings = async () => { + const models: EmbeddingModel[] = ["onnx-clip"]; try { let allEmbeddings = await getAllLocalEmbeddings(); const localFiles = await getAllLocalFiles(); @@ -138,7 +142,7 @@ export const syncEmbeddings = async (models: Model[] = [Model.ONNX_CLIP]) => { export const getEmbeddingsDiff = async ( sinceTime: number, - model: Model, + model: EmbeddingModel, ): Promise => { try { const token = getToken(); diff --git a/web/apps/photos/src/services/searchService.ts b/web/apps/photos/src/services/searchService.ts index 692b4ac84..b85005db0 100644 --- a/web/apps/photos/src/services/searchService.ts +++ b/web/apps/photos/src/services/searchService.ts @@ -4,7 +4,6 @@ import * as chrono from "chrono-node"; import { FILE_TYPE } from "constants/file"; import { t } from "i18next"; import { Collection } from "types/collection"; -import { Model } from "types/embedding"; import { EntityType, LocationTag, LocationTagData } from "types/entity"; import { EnteFile } from "types/file"; import { Person, Thing } from "types/machineLearning"; @@ -22,7 +21,7 @@ import { getAllPeople } from "utils/machineLearning"; import { getMLSyncConfig } from "utils/machineLearning/config"; import { getFormattedDate } from "utils/search"; import mlIDbStorage from "utils/storage/mlIDbStorage"; -import { ClipService, computeClipMatchScore } from "./clipService"; +import { clipService, computeClipMatchScore } from "./clip-service"; import { getLocalEmbeddings } from "./embeddingService"; import { getLatestEntities } from "./entityService"; import locationSearchService, { City } from "./locationSearchService"; @@ -305,7 +304,7 @@ async function getThingSuggestion(searchPhrase: string): Promise { async function getClipSuggestion(searchPhrase: string): Promise { try { - if (!ClipService.isPlatformSupported()) { + if (!clipService.isPlatformSupported()) { return null; } @@ -396,8 +395,8 @@ async function searchThing(searchPhrase: string) { } async function searchClip(searchPhrase: string): Promise { - const imageEmbeddings = await getLocalEmbeddings(Model.ONNX_CLIP); - const textEmbedding = await ClipService.getTextEmbedding(searchPhrase); + const imageEmbeddings = await getLocalEmbeddings(); + const textEmbedding = await clipService.getTextEmbedding(searchPhrase); const clipSearchResult = new Map( ( await Promise.all( diff --git a/web/apps/photos/src/types/embedding.tsx b/web/apps/photos/src/types/embedding.tsx index 
3626e0fad..c0014d01e 100644 --- a/web/apps/photos/src/types/embedding.tsx +++ b/web/apps/photos/src/types/embedding.tsx @@ -1,11 +1,16 @@ -export enum Model { - GGML_CLIP = "ggml-clip", - ONNX_CLIP = "onnx-clip", -} +/** + * The embedding models that we support. + * + * This is an exhaustive set of values we pass when PUT-ting encrypted + * embeddings on the server. However, we should be prepared to receive an + * {@link EncryptedEmbedding} with a model value distinct from one of these. + */ +export type EmbeddingModel = "onnx-clip"; export interface EncryptedEmbedding { fileID: number; - model: Model; + /** @see {@link EmbeddingModel} */ + model: string; encryptedEmbedding: string; decryptionHeader: string; updatedAt: number; @@ -25,7 +30,7 @@ export interface GetEmbeddingDiffResponse { export interface PutEmbeddingRequest { fileID: number; - model: Model; + model: EmbeddingModel; encryptedEmbedding: string; decryptionHeader: string; } diff --git a/web/packages/next/types/ipc.ts b/web/packages/next/types/ipc.ts index d13c775f4..8451b045e 100644 --- a/web/packages/next/types/ipc.ts +++ b/web/packages/next/types/ipc.ts @@ -10,11 +10,6 @@ export interface AppUpdateInfo { version: string; } -export enum Model { - GGML_CLIP = "ggml-clip", - ONNX_CLIP = "onnx-clip", -} - export enum FILE_PATH_TYPE { FILES = "files", ZIPS = "zips", @@ -147,12 +142,27 @@ export interface Electron { // - ML - computeImageEmbedding: ( - model: Model, - imageData: Uint8Array, - ) => Promise<Float32Array>; + /** + * Compute and return a CLIP embedding of the given image. + * + * See: [Note: CLIP based magic search] + * + * @param jpegImageData The raw bytes of the image encoded as a JPEG. + * + * @returns A CLIP embedding. + */ + clipImageEmbedding: (jpegImageData: Uint8Array) => Promise<Float32Array>; - computeTextEmbedding: (model: Model, text: string) => Promise<Float32Array>; + /** + * Compute and return a CLIP embedding of the given text. + * + * See: [Note: CLIP based magic search] + * + * @param text The string whose embedding we want to compute. + * + * @returns A CLIP embedding. + */ + clipTextEmbedding: (text: string) => Promise<Float32Array>; // - File selection // TODO: Deprecated - use dialogs on the renderer process itself
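
A note on consuming the renamed ML bridge above (illustrative only, not part of this diff): both `clipImageEmbedding` and `clipTextEmbedding` resolve to 512-dimensional embeddings that the services in this PR normalize to unit length (see `normalizeEmbedding`), so cosine similarity between a text query and an image reduces to a plain dot product. The sketch below shows, under those assumptions, how renderer code could rank locally synced image embeddings against a search phrase. The `IndexedEmbedding` shape and the `rankByCLIPScore` / `dotProduct` helpers are hypothetical names used for illustration; the production path is the `searchClip` / `computeClipMatchScore` pair touched in `searchService.ts` and `clip-service.ts` above.

import { ensureElectron } from "@/next/electron";

/** A locally synced CLIP image embedding, keyed by the file that owns it. */
interface IndexedEmbedding {
    fileID: number;
    /** 512-dimensional and L2-normalized (see `normalizeEmbedding`). */
    embedding: Float32Array;
}

/** Cosine similarity of two unit-length vectors is just their dot product. */
const dotProduct = (a: Float32Array, b: Float32Array) => {
    if (a.length !== b.length) throw new Error("Embedding length mismatch");
    let score = 0;
    for (let i = 0; i < a.length; i++) score += a[i] * b[i];
    return score;
};

/** Rank image embeddings against a natural language query, best match first. */
export const rankByCLIPScore = async (
    searchPhrase: string,
    imageEmbeddings: IndexedEmbedding[],
) => {
    // Single IPC round trip; the ONNX text model runs in the Node.js layer.
    const textEmbedding = await ensureElectron().clipTextEmbedding(searchPhrase);
    return imageEmbeddings
        .map(({ fileID, embedding }) => ({
            fileID,
            score: dotProduct(embedding, textEmbedding),
        }))
        .sort((a, b) => b.score - a.score);
};

This sketch omits the error handling (e.g. surfacing CustomError.UNSUPPORTED_PLATFORM) and the pending/indexed bookkeeping that `clipService` performs in the real code.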