diff --git a/desktop/src/main/services/ml-clip.ts b/desktop/src/main/services/ml-clip.ts
new file mode 100644
index 000000000..3fe6da2eb
--- /dev/null
+++ b/desktop/src/main/services/ml-clip.ts
@@ -0,0 +1,248 @@
+/**
+ * @file Compute CLIP embeddings for images and text.
+ *
+ * The embeddings are computed using the ONNX runtime, with CLIP as the model.
+ *
+ * @see `web/apps/photos/src/services/clip-service.ts` for more details.
+ */
+import { existsSync } from "fs";
+import jpeg from "jpeg-js";
+import fs from "node:fs/promises";
+import * as ort from "onnxruntime-node";
+import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
+import { CustomErrors } from "../../types/ipc";
+import { writeStream } from "../fs";
+import log from "../log";
+import { generateTempFilePath } from "../temp";
+import { deleteTempFile } from "./ffmpeg";
+import {
+    createInferenceSession,
+    downloadModel,
+    modelPathDownloadingIfNeeded,
+    modelSavePath,
+} from "./ml";
+
+const textModelName = "clip-text-vit-32-uint8.onnx";
+const textModelByteSize = 64173509; // 61.2 MB
+
+const imageModelName = "clip-image-vit-32-float32.onnx";
+const imageModelByteSize = 351468764; // 335.2 MB
+
+let activeImageModelDownload: Promise<string> | undefined;
+
+const imageModelPathDownloadingIfNeeded = async () => {
+    try {
+        if (activeImageModelDownload) {
+            log.info("Waiting for CLIP image model download to finish");
+            return await activeImageModelDownload;
+        } else {
+            activeImageModelDownload = modelPathDownloadingIfNeeded(
+                imageModelName,
+                imageModelByteSize,
+            );
+            return await activeImageModelDownload;
+        }
+    } finally {
+        activeImageModelDownload = undefined;
+    }
+};
+
+let textModelDownloadInProgress = false;
+
+/* TODO(MR): Use the generic method. Then we can remove the exports of the
+   internal helper functions that we use here. */
+const textModelPathDownloadingIfNeeded = async () => {
+    if (textModelDownloadInProgress)
+        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
+
+    const modelPath = modelSavePath(textModelName);
+    if (!existsSync(modelPath)) {
+        log.info("CLIP text model not found, downloading");
+        textModelDownloadInProgress = true;
+        downloadModel(modelPath, textModelName)
+            .catch((e) => {
+                // Log but otherwise ignore.
+                log.error("CLIP text model download failed", e);
+            })
+            .finally(() => {
+                textModelDownloadInProgress = false;
+            });
+        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
+    } else {
+        const localFileSize = (await fs.stat(modelPath)).size;
+        if (localFileSize !== textModelByteSize) {
+            log.error(
+                `CLIP text model size ${localFileSize} does not match the expected size, downloading again`,
+            );
+            textModelDownloadInProgress = true;
+            downloadModel(modelPath, textModelName)
+                .catch((e) => {
+                    // Log but otherwise ignore.
+                    log.error("CLIP text model download failed", e);
+                })
+                .finally(() => {
+                    textModelDownloadInProgress = false;
+                });
+            throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
+        }
+    }
+
+    return modelPath;
+};
+
+let imageSessionPromise: Promise<ort.InferenceSession> | undefined;
+
+const onnxImageSession = async () => {
+    if (!imageSessionPromise) {
+        imageSessionPromise = (async () => {
+            const modelPath = await imageModelPathDownloadingIfNeeded();
+            return createInferenceSession(modelPath);
+        })();
+    }
+    return imageSessionPromise;
+};
+
+let _textSession: any = null;
+
+const onnxTextSession = async () => {
+    if (!_textSession) {
+        const modelPath = await textModelPathDownloadingIfNeeded();
+        _textSession = await createInferenceSession(modelPath);
+    }
+    return _textSession;
+};
+
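The TODO above presumably amounts to replacing the hand-rolled text model handling with the same single-flight use of the generic modelPathDownloadingIfNeeded that the image model already uses. A rough sketch of that refactor follows (not part of this diff; the V2 suffix is only to avoid clashing with the existing function, and note it would also change the behaviour from throwing MODEL_DOWNLOAD_PENDING to simply waiting for the download to finish):

let activeTextModelDownload: Promise<string> | undefined;

const textModelPathDownloadingIfNeededV2 = async () => {
    try {
        if (activeTextModelDownload) {
            // Another caller already kicked off the download; wait for it.
            log.info("Waiting for CLIP text model download to finish");
            return await activeTextModelDownload;
        } else {
            activeTextModelDownload = modelPathDownloadingIfNeeded(
                textModelName,
                textModelByteSize,
            );
            return await activeTextModelDownload;
        }
    } finally {
        activeTextModelDownload = undefined;
    }
};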
+export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
+    const tempFilePath = await generateTempFilePath("");
+    const imageStream = new Response(jpegImageData.buffer).body;
+    await writeStream(tempFilePath, imageStream);
+    try {
+        return await clipImageEmbedding_(tempFilePath);
+    } finally {
+        await deleteTempFile(tempFilePath);
+    }
+};
+
+const clipImageEmbedding_ = async (jpegFilePath: string) => {
+    const imageSession = await onnxImageSession();
+    const t1 = Date.now();
+    const rgbData = await getRGBData(jpegFilePath);
+    const feeds = {
+        input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]),
+    };
+    const t2 = Date.now();
+    const results = await imageSession.run(feeds);
+    log.debug(
+        () =>
+            `CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
+    );
+    const imageEmbedding = results["output"].data; // Float32Array
+    return normalizeEmbedding(imageEmbedding);
+};
+
+const getRGBData = async (jpegFilePath: string) => {
+    const jpegData = await fs.readFile(jpegFilePath);
+    const rawImageData = jpeg.decode(jpegData, {
+        useTArray: true,
+        formatAsRGBA: false,
+    });
+
+    const nx: number = rawImageData.width;
+    const ny: number = rawImageData.height;
+    const inputImage: Uint8Array = rawImageData.data;
+
+    const nx2: number = 224;
+    const ny2: number = 224;
+    const totalSize: number = 3 * nx2 * ny2;
+
+    const result: number[] = Array(totalSize).fill(0);
+    const scale: number = Math.max(nx, ny) / 224;
+
+    const nx3: number = Math.round(nx / scale);
+    const ny3: number = Math.round(ny / scale);
+
+    const mean: number[] = [0.48145466, 0.4578275, 0.40821073];
+    const std: number[] = [0.26862954, 0.26130258, 0.27577711];
+
+    for (let y = 0; y < ny3; y++) {
+        for (let x = 0; x < nx3; x++) {
+            for (let c = 0; c < 3; c++) {
+                // Bilinear interpolation
+                const sx: number = (x + 0.5) * scale - 0.5;
+                const sy: number = (y + 0.5) * scale - 0.5;
+
+                const x0: number = Math.max(0, Math.floor(sx));
+                const y0: number = Math.max(0, Math.floor(sy));
+
+                const x1: number = Math.min(x0 + 1, nx - 1);
+                const y1: number = Math.min(y0 + 1, ny - 1);
+
+                const dx: number = sx - x0;
+                const dy: number = sy - y0;
+
+                const j00: number = 3 * (y0 * nx + x0) + c;
+                const j01: number = 3 * (y0 * nx + x1) + c;
+                const j10: number = 3 * (y1 * nx + x0) + c;
+                const j11: number = 3 * (y1 * nx + x1) + c;
+
+                const v00: number = inputImage[j00];
+                const v01: number = inputImage[j01];
+                const v10: number = inputImage[j10];
+                const v11: number = inputImage[j11];
+
+                const v0: number = v00 * (1 - dx) + v01 * dx;
+                const v1: number = v10 * (1 - dx) + v11 * dx;
+
+                const v: number = v0 * (1 - dy) + v1 * dy;
+
+                const v2: number = Math.min(Math.max(Math.round(v), 0), 255);
+
+                // The flat array is filled in planar (CHW) order, one full
+                // channel after another, to match the [1, 3, 224, 224] feed.
+                const i: number = y * nx3 + x + (c % 3) * 224 * 224;
+
+                result[i] = (v2 / 255 - mean[c]) / std[c];
+            }
+        }
+    }
+
+    return result;
+};
+
+const normalizeEmbedding = (embedding: Float32Array) => {
+    let normalization = 0;
+    for (let index = 0; index < embedding.length; index++) {
+        normalization += embedding[index] * embedding[index];
+    }
+    const sqrtNormalization = Math.sqrt(normalization);
+    for (let index = 0; index < embedding.length; index++) {
+        embedding[index] = embedding[index] / sqrtNormalization;
+    }
+    return embedding;
+};
+
+let _tokenizer: Tokenizer = null;
+const getTokenizer = () => {
+    if (!_tokenizer) {
+        _tokenizer = new Tokenizer();
+    }
+    return _tokenizer;
+};
+
+export const clipTextEmbedding = async (text: string) => {
+    const textSession = await onnxTextSession();
+    const t1 = Date.now();
+    const tokenizer = getTokenizer();
+    const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
+    const feeds = {
+        input: new ort.Tensor("int32", tokenizedText, [1, 77]),
+    };
+    const t2 = Date.now();
+    const results = await textSession.run(feeds);
+    log.debug(
+        () =>
+            `CLIP text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
+    );
+    const textEmbedding = results["output"].data;
+    return normalizeEmbedding(textEmbedding);
+};
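Since normalizeEmbedding scales every embedding to unit length, the cosine similarity that CLIP search relies on reduces to a plain dot product. A minimal usage sketch (illustrative only, not part of this diff) of scoring an image against a text query:

const clipScore = async (jpegImageData: Uint8Array, query: string) => {
    const imageEmbedding = await clipImageEmbedding(jpegImageData);
    const textEmbedding = await clipTextEmbedding(query);
    // Both embeddings are L2 normalized, so their dot product is the cosine
    // similarity between them.
    let score = 0;
    for (let i = 0; i < imageEmbedding.length; i++)
        score += imageEmbedding[i] * textEmbedding[i];
    return score;
};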
diff --git a/desktop/src/main/services/ml-face.ts b/desktop/src/main/services/ml-face.ts
new file mode 100644
index 000000000..c547885bb
--- /dev/null
+++ b/desktop/src/main/services/ml-face.ts
@@ -0,0 +1,77 @@
+/**
+ * @file Various face recognition related tasks.
+ *
+ * - Face detection with the YOLO model.
+ * - Face embedding with the mobilefacenet model.
+ *
+ * The runtime used is ONNX.
+ */
+import * as ort from "onnxruntime-node";
+import log from "../log";
+import { createInferenceSession, modelPathDownloadingIfNeeded } from "./ml";
+
+const faceDetectionModelName = "yolov5s_face_640_640_dynamic.onnx";
+const faceDetectionModelByteSize = 30762872; // 29.3 MB
+
+const faceEmbeddingModelName = "mobilefacenet_opset15.onnx";
+const faceEmbeddingModelByteSize = 5286998; // 5 MB
+
+let activeFaceDetectionModelDownload: Promise<string> | undefined;
+
+const faceDetectionModelPathDownloadingIfNeeded = async () => {
+    try {
+        if (activeFaceDetectionModelDownload) {
+            log.info("Waiting for face detection model download to finish");
+            return await activeFaceDetectionModelDownload;
+        } else {
+            activeFaceDetectionModelDownload = modelPathDownloadingIfNeeded(
+                faceDetectionModelName,
+                faceDetectionModelByteSize,
+            );
+            return await activeFaceDetectionModelDownload;
+        }
+    } finally {
+        activeFaceDetectionModelDownload = undefined;
+    }
+};
+
+let _faceDetectionSession: Promise<ort.InferenceSession> | undefined;
+
+const faceDetectionSession = async () => {
+    if (!_faceDetectionSession) {
+        _faceDetectionSession =
+            faceDetectionModelPathDownloadingIfNeeded().then((modelPath) =>
+                createInferenceSession(modelPath),
+            );
+    }
+    return _faceDetectionSession;
+};
+
+
+// export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
+//     const tempFilePath = await generateTempFilePath("");
+//     const imageStream = new Response(jpegImageData.buffer).body;
+//     await writeStream(tempFilePath, imageStream);
+//     try {
+//         return await clipImageEmbedding_(tempFilePath);
+//     } finally {
+//         await deleteTempFile(tempFilePath);
+//     }
+// };
+
+// const clipImageEmbedding_ = async (jpegFilePath: string) => {
+//     const imageSession = await onnxImageSession();
+//     const t1 = Date.now();
+//     const rgbData = await getRGBData(jpegFilePath);
+//     const feeds = {
+//         input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]),
+//     };
+//     const t2 = Date.now();
+//     const results = await imageSession.run(feeds);
+//     log.debug(
+//         () =>
+//             `CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
+//     );
+//     const imageEmbedding = results["output"].data; // Float32Array
+//     return normalizeEmbedding(imageEmbedding);
+// };
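The mobilefacenet constants above are declared but not yet wired up in this portion of the diff. A hedged sketch of how the embedding session might be set up, following the same lazy pattern as faceDetectionSession (not part of the diff; since the session promise is itself cached, the separate single-flight download wrapper is omitted here):

let _faceEmbeddingSession: Promise<ort.InferenceSession> | undefined;

const faceEmbeddingSession = async () => {
    if (!_faceEmbeddingSession) {
        // Download the mobilefacenet model if needed, then create the session.
        _faceEmbeddingSession = modelPathDownloadingIfNeeded(
            faceEmbeddingModelName,
            faceEmbeddingModelByteSize,
        ).then((modelPath) => createInferenceSession(modelPath));
    }
    return _faceEmbeddingSession;
};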
diff --git a/desktop/src/main/services/ml.ts b/desktop/src/main/services/ml.ts
new file mode 100644
index 000000000..10402db21
--- /dev/null
+++ b/desktop/src/main/services/ml.ts
@@ -0,0 +1,79 @@
+/**
+ * @file AI/ML related functionality.
+ *
+ * @see also `ml-clip.ts`, `ml-face.ts`.
+ *
+ * The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models
+ * for various tasks are not shipped with the app but are downloaded on demand.
+ *
+ * The primary reason for doing these tasks in the Node.js layer is so that we
+ * can use the binary ONNX runtime, which is 10-20x faster than the WASM based
+ * web one.
+ */
+import { app, net } from "electron/main";
+import { existsSync } from "fs";
+import fs from "node:fs/promises";
+import path from "node:path";
+import * as ort from "onnxruntime-node";
+import { writeStream } from "../fs";
+import log from "../log";
+
+/**
+ * Download the model named {@link modelName} if we don't already have it.
+ *
+ * Also verify that the size of the model we get matches
+ * {@link expectedByteSize} (if not, redownload it).
+ *
+ * @returns the path to the model on the local machine.
+ */
+export const modelPathDownloadingIfNeeded = async (
+    modelName: string,
+    expectedByteSize: number,
+) => {
+    const modelPath = modelSavePath(modelName);
+
+    if (!existsSync(modelPath)) {
+        log.info(`Model ${modelName} not found, downloading`);
+        await downloadModel(modelPath, modelName);
+    } else {
+        const size = (await fs.stat(modelPath)).size;
+        if (size !== expectedByteSize) {
+            log.error(
+                `The size ${size} of model ${modelName} does not match the expected size, downloading again`,
+            );
+            await downloadModel(modelPath, modelName);
+        }
+    }
+
+    return modelPath;
+};
+
+/** Return the path where the given {@link modelName} is meant to be saved. */
+export const modelSavePath = (modelName: string) =>
+    path.join(app.getPath("userData"), "models", modelName);
+
+export const downloadModel = async (saveLocation: string, name: string) => {
+    // `mkdir -p` the directory where we want to save the model.
+    const saveDir = path.dirname(saveLocation);
+    await fs.mkdir(saveDir, { recursive: true });
+    // Download
+    log.info(`Downloading ML model ${name}`);
+    const url = `https://models.ente.io/${name}`;
+    const res = await net.fetch(url);
+    if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
+    // Save
+    await writeStream(saveLocation, res.body);
+    log.info(`Downloaded ML model ${name}`);
+};
+
+/**
+ * Create an ONNX {@link InferenceSession} with some defaults.
+ */
+export const createInferenceSession = async (modelPath: string) => {
+    return await ort.InferenceSession.create(modelPath, {
+        // Restrict the number of threads to 1
+        intraOpNumThreads: 1,
+        // Be more conservative with RAM usage
+        enableCpuMemArena: false,
+    });
+};
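For reference, this is roughly how a consumer such as ml-clip.ts or ml-face.ts above is expected to compose these exports to obtain a ready session; the model name and byte size below are placeholders, not real models served from models.ente.io:

const exampleSession = async () => {
    // Download (or reuse) the model file, verifying its size, then create the
    // ONNX inference session from it.
    const modelPath = await modelPathDownloadingIfNeeded(
        "example-model.onnx", // placeholder model name
        12345678, // placeholder expected size in bytes
    );
    return createInferenceSession(modelPath);
};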