diff --git a/desktop/build/ggmlclip-linux b/desktop/build/ggmlclip-linux deleted file mode 100755 index 4c160b039..000000000 Binary files a/desktop/build/ggmlclip-linux and /dev/null differ diff --git a/desktop/build/ggmlclip-mac b/desktop/build/ggmlclip-mac deleted file mode 100755 index db7c4f249..000000000 Binary files a/desktop/build/ggmlclip-mac and /dev/null differ diff --git a/desktop/build/ggmlclip-windows.exe b/desktop/build/ggmlclip-windows.exe deleted file mode 100755 index 1e197dfe8..000000000 Binary files a/desktop/build/ggmlclip-windows.exe and /dev/null differ diff --git a/desktop/docs/dependencies.md b/desktop/docs/dependencies.md index 9c5f86e59..5c1b07744 100644 --- a/desktop/docs/dependencies.md +++ b/desktop/docs/dependencies.md @@ -116,7 +116,6 @@ watcher for the watch folders functionality. - html-entities is used by the bundled clip-bpe-ts tokenizer. - [jpeg-js](https://github.com/jpeg-js/jpeg-js#readme) is used for decoding JPEG data into raw RGB bytes before passing it to ONNX. -- GGML binaries are bundled. ## ZIP diff --git a/desktop/src/main/ipc.ts b/desktop/src/main/ipc.ts index 0fdd10056..f4da569c5 100644 --- a/desktop/src/main/ipc.ts +++ b/desktop/src/main/ipc.ts @@ -17,9 +17,9 @@ import { updateAndRestart, } from "../services/appUpdater"; import { - computeImageEmbedding, - computeTextEmbedding, -} from "../services/clipService"; + clipImageEmbedding, + clipTextEmbedding, +} from "../services/clip-service"; import { runFFmpegCmd } from "../services/ffmpeg"; import { getDirFiles } from "../services/fs"; import { @@ -44,12 +44,7 @@ import { updateWatchMappingIgnoredFiles, updateWatchMappingSyncedFiles, } from "../services/watch"; -import type { - ElectronFile, - FILE_PATH_TYPE, - Model, - WatchMapping, -} from "../types/ipc"; +import type { ElectronFile, FILE_PATH_TYPE, WatchMapping } from "../types/ipc"; import { selectDirectory, showUploadDirsDialog, @@ -148,14 +143,12 @@ export const attachIPCHandlers = () => { // - ML - ipcMain.handle( - "computeImageEmbedding", - (_, model: Model, imageData: Uint8Array) => - computeImageEmbedding(model, imageData), + ipcMain.handle("clipImageEmbedding", (_, jpegImageData: Uint8Array) => + clipImageEmbedding(jpegImageData), ); - ipcMain.handle("computeTextEmbedding", (_, model: Model, text: string) => - computeTextEmbedding(model, text), + ipcMain.handle("clipTextEmbedding", (_, text: string) => + clipTextEmbedding(text), ); // - File selection diff --git a/desktop/src/preload.ts b/desktop/src/preload.ts index aa528b7ad..cb718f950 100644 --- a/desktop/src/preload.ts +++ b/desktop/src/preload.ts @@ -45,7 +45,6 @@ import type { AppUpdateInfo, ElectronFile, FILE_PATH_TYPE, - Model, WatchMapping, } from "./types/ipc"; @@ -141,17 +140,11 @@ const runFFmpegCmd = ( // - ML -const computeImageEmbedding = ( - model: Model, - imageData: Uint8Array, -): Promise<Float32Array> => - ipcRenderer.invoke("computeImageEmbedding", model, imageData); +const clipImageEmbedding = (jpegImageData: Uint8Array): Promise<Float32Array> => + ipcRenderer.invoke("clipImageEmbedding", jpegImageData); -const computeTextEmbedding = ( - model: Model, - text: string, -): Promise<Float32Array> => - ipcRenderer.invoke("computeTextEmbedding", model, text); +const clipTextEmbedding = (text: string): Promise<Float32Array> => + ipcRenderer.invoke("clipTextEmbedding", text); // - File selection @@ -332,8 +325,8 @@ contextBridge.exposeInMainWorld("electron", { runFFmpegCmd, // - ML - computeImageEmbedding, - computeTextEmbedding, + clipImageEmbedding, + clipTextEmbedding, // - File selection selectDirectory, diff --git 
a/desktop/src/services/clip-service.ts b/desktop/src/services/clip-service.ts new file mode 100644 index 000000000..add37656e --- /dev/null +++ b/desktop/src/services/clip-service.ts @@ -0,0 +1,290 @@ +/** + * @file Compute CLIP embeddings + * + * @see `web/apps/photos/src/services/clip-service.ts` for more details. This + * file implements the Node.js implementation of the actual embedding + * computation. By doing it in the Node.js layer, we can use the binary ONNX + * runtimes which are 10-20x faster than the WASM based web ones. + * + * The embeddings are computed using ONNX runtime. The model itself is not + * shipped with the app but is downloaded on demand. + */ +import { app, net } from "electron/main"; +import { existsSync } from "fs"; +import fs from "node:fs/promises"; +import path from "node:path"; +import { writeStream } from "../main/fs"; +import log from "../main/log"; +import { CustomErrors } from "../types/ipc"; +import Tokenizer from "../utils/clip-bpe-ts/mod"; +import { generateTempFilePath } from "../utils/temp"; +import { deleteTempFile } from "./ffmpeg"; +const jpeg = require("jpeg-js"); +const ort = require("onnxruntime-node"); + +const textModelName = "clip-text-vit-32-uint8.onnx"; +const textModelByteSize = 64173509; // 61.2 MB + +const imageModelName = "clip-image-vit-32-float32.onnx"; +const imageModelByteSize = 351468764; // 335.2 MB + +/** Return the path where the given {@link modelName} is meant to be saved */ +const modelSavePath = (modelName: string) => + path.join(app.getPath("userData"), "models", modelName); + +const downloadModel = async (saveLocation: string, name: string) => { + // `mkdir -p` the directory where we want to save the model. + const saveDir = path.dirname(saveLocation); + await fs.mkdir(saveDir, { recursive: true }); + // Download + log.info(`Downloading CLIP model from ${name}`); + const url = `https://models.ente.io/${name}`; + const res = await net.fetch(url); + if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`); + // Save + await writeStream(saveLocation, res.body); + log.info(`Downloaded CLIP model ${name}`); +}; + +let activeImageModelDownload: Promise<void> | undefined; + +const imageModelPathDownloadingIfNeeded = async () => { + try { + const modelPath = modelSavePath(imageModelName); + if (activeImageModelDownload) { + log.info("Waiting for CLIP image model download to finish"); + await activeImageModelDownload; + } else { + if (!existsSync(modelPath)) { + log.info("CLIP image model not found, downloading"); + activeImageModelDownload = downloadModel( + modelPath, + imageModelName, + ); + await activeImageModelDownload; + } else { + const localFileSize = (await fs.stat(modelPath)).size; + if (localFileSize !== imageModelByteSize) { + log.error( + `CLIP image model size ${localFileSize} does not match the expected size, downloading again`, + ); + activeImageModelDownload = downloadModel( + modelPath, + imageModelName, + ); + await activeImageModelDownload; + } + } + } + return modelPath; + } finally { + activeImageModelDownload = undefined; + } +}; + +let textModelDownloadInProgress = false; + +const textModelPathDownloadingIfNeeded = async () => { + if (textModelDownloadInProgress) + throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING); + + const modelPath = modelSavePath(textModelName); + if (!existsSync(modelPath)) { + log.info("CLIP text model not found, downloading"); + textModelDownloadInProgress = true; + downloadModel(modelPath, textModelName) + .catch((e) => { + // log but otherwise ignore + log.error("CLIP 
text model download failed", e); + }) + .finally(() => { + textModelDownloadInProgress = false; + }); + throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING); + } else { + const localFileSize = (await fs.stat(modelPath)).size; + if (localFileSize !== textModelByteSize) { + log.error( + `CLIP text model size ${localFileSize} does not match the expected size, downloading again`, + ); + textModelDownloadInProgress = true; + downloadModel(modelPath, textModelName) + .catch((e) => { + // log but otherwise ignore + log.error("CLIP text model download failed", e); + }) + .finally(() => { + textModelDownloadInProgress = false; + }); + throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING); + } + } + + return modelPath; +}; + +const createInferenceSession = async (modelPath: string) => { + return await ort.InferenceSession.create(modelPath, { + intraOpNumThreads: 1, + enableCpuMemArena: false, + }); +}; + +let imageSessionPromise: Promise | undefined; + +const onnxImageSession = async () => { + if (!imageSessionPromise) { + imageSessionPromise = (async () => { + const modelPath = await imageModelPathDownloadingIfNeeded(); + return createInferenceSession(modelPath); + })(); + } + return imageSessionPromise; +}; + +let _textSession: any = null; + +const onnxTextSession = async () => { + if (!_textSession) { + const modelPath = await textModelPathDownloadingIfNeeded(); + _textSession = await createInferenceSession(modelPath); + } + return _textSession; +}; + +export const clipImageEmbedding = async (jpegImageData: Uint8Array) => { + const tempFilePath = await generateTempFilePath(""); + const imageStream = new Response(jpegImageData.buffer).body; + await writeStream(tempFilePath, imageStream); + try { + return await clipImageEmbedding_(tempFilePath); + } finally { + await deleteTempFile(tempFilePath); + } +}; + +const clipImageEmbedding_ = async (jpegFilePath: string) => { + const imageSession = await onnxImageSession(); + const t1 = Date.now(); + const rgbData = await getRGBData(jpegFilePath); + const feeds = { + input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]), + }; + const t2 = Date.now(); + const results = await imageSession.run(feeds); + log.info( + `onnx image embedding time: ${Date.now() - t1} ms (prep:${ + t2 - t1 + } ms, extraction: ${Date.now() - t2} ms)`, + ); + const imageEmbedding = results["output"].data; // Float32Array + return normalizeEmbedding(imageEmbedding); +}; + +const getRGBData = async (jpegFilePath: string) => { + const jpegData = await fs.readFile(jpegFilePath); + const rawImageData = jpeg.decode(jpegData, { + useTArray: true, + formatAsRGBA: false, + }); + + const nx: number = rawImageData.width; + const ny: number = rawImageData.height; + const inputImage: Uint8Array = rawImageData.data; + + const nx2: number = 224; + const ny2: number = 224; + const totalSize: number = 3 * nx2 * ny2; + + const result: number[] = Array(totalSize).fill(0); + const scale: number = Math.max(nx, ny) / 224; + + const nx3: number = Math.round(nx / scale); + const ny3: number = Math.round(ny / scale); + + const mean: number[] = [0.48145466, 0.4578275, 0.40821073]; + const std: number[] = [0.26862954, 0.26130258, 0.27577711]; + + for (let y = 0; y < ny3; y++) { + for (let x = 0; x < nx3; x++) { + for (let c = 0; c < 3; c++) { + // Linear interpolation + const sx: number = (x + 0.5) * scale - 0.5; + const sy: number = (y + 0.5) * scale - 0.5; + + const x0: number = Math.max(0, Math.floor(sx)); + const y0: number = Math.max(0, Math.floor(sy)); + + const x1: number = Math.min(x0 + 1, nx - 1); + const 
y1: number = Math.min(y0 + 1, ny - 1); + + const dx: number = sx - x0; + const dy: number = sy - y0; + + const j00: number = 3 * (y0 * nx + x0) + c; + const j01: number = 3 * (y0 * nx + x1) + c; + const j10: number = 3 * (y1 * nx + x0) + c; + const j11: number = 3 * (y1 * nx + x1) + c; + + const v00: number = inputImage[j00]; + const v01: number = inputImage[j01]; + const v10: number = inputImage[j10]; + const v11: number = inputImage[j11]; + + const v0: number = v00 * (1 - dx) + v01 * dx; + const v1: number = v10 * (1 - dx) + v11 * dx; + + const v: number = v0 * (1 - dy) + v1 * dy; + + const v2: number = Math.min(Math.max(Math.round(v), 0), 255); + + // createTensorWithDataList is dumb compared to reshape and + // hence has to be given with one channel after another + const i: number = y * nx3 + x + (c % 3) * 224 * 224; + + result[i] = (v2 / 255 - mean[c]) / std[c]; + } + } + } + + return result; +}; + +const normalizeEmbedding = (embedding: Float32Array) => { + let normalization = 0; + for (let index = 0; index < embedding.length; index++) { + normalization += embedding[index] * embedding[index]; + } + const sqrtNormalization = Math.sqrt(normalization); + for (let index = 0; index < embedding.length; index++) { + embedding[index] = embedding[index] / sqrtNormalization; + } + return embedding; +}; + +let _tokenizer: Tokenizer = null; +const getTokenizer = () => { + if (!_tokenizer) { + _tokenizer = new Tokenizer(); + } + return _tokenizer; +}; + +export const clipTextEmbedding = async (text: string) => { + const imageSession = await onnxTextSession(); + const t1 = Date.now(); + const tokenizer = getTokenizer(); + const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text)); + const feeds = { + input: new ort.Tensor("int32", tokenizedText, [1, 77]), + }; + const t2 = Date.now(); + const results = await imageSession.run(feeds); + log.info( + `onnx text embedding time: ${Date.now() - t1} ms (prep:${ + t2 - t1 + } ms, extraction: ${Date.now() - t2} ms)`, + ); + const textEmbedding = results["output"].data; + return normalizeEmbedding(textEmbedding); +}; diff --git a/desktop/src/services/clipService.ts b/desktop/src/services/clipService.ts deleted file mode 100644 index cb41ae32e..000000000 --- a/desktop/src/services/clipService.ts +++ /dev/null @@ -1,474 +0,0 @@ -/** - * @file Compute CLIP embeddings - * - * @see `web/apps/photos/services/clip-service.ts` for more details. This file - * implements the Node.js implementation of the actual embedding computation. By - * doing it in the Node.js layer, we can use the binary ONNX runtimes which are - * 10-20x faster than the WASM based web ones. - * - * The embeddings are computed using ONNX runtime. The model itself is not - * shipped with the app but is downloaded on demand. 
- */ -import { app, net } from "electron/main"; -import { existsSync } from "fs"; -import fs from "node:fs/promises"; -import path from "node:path"; -import { writeStream } from "../main/fs"; -import log from "../main/log"; -import { execAsync, isDev } from "../main/util"; -import { CustomErrors, Model, isModel } from "../types/ipc"; -import Tokenizer from "../utils/clip-bpe-ts/mod"; -import { getPlatform } from "../utils/common/platform"; -import { generateTempFilePath } from "../utils/temp"; -import { deleteTempFile } from "./ffmpeg"; -const jpeg = require("jpeg-js"); - -const CLIP_MODEL_PATH_PLACEHOLDER = "CLIP_MODEL"; -const GGMLCLIP_PATH_PLACEHOLDER = "GGML_PATH"; -const INPUT_PATH_PLACEHOLDER = "INPUT"; - -const IMAGE_EMBEDDING_EXTRACT_CMD: string[] = [ - GGMLCLIP_PATH_PLACEHOLDER, - "-mv", - CLIP_MODEL_PATH_PLACEHOLDER, - "--image", - INPUT_PATH_PLACEHOLDER, -]; - -const TEXT_EMBEDDING_EXTRACT_CMD: string[] = [ - GGMLCLIP_PATH_PLACEHOLDER, - "-mt", - CLIP_MODEL_PATH_PLACEHOLDER, - "--text", - INPUT_PATH_PLACEHOLDER, -]; -const ort = require("onnxruntime-node"); - -const TEXT_MODEL_DOWNLOAD_URL = { - ggml: "https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf", - onnx: "https://models.ente.io/clip-text-vit-32-uint8.onnx", -}; -const IMAGE_MODEL_DOWNLOAD_URL = { - ggml: "https://models.ente.io/clip-vit-base-patch32_ggml-vision-model-f16.gguf", - onnx: "https://models.ente.io/clip-image-vit-32-float32.onnx", -}; - -const TEXT_MODEL_NAME = { - ggml: "clip-vit-base-patch32_ggml-text-model-f16.gguf", - onnx: "clip-text-vit-32-uint8.onnx", -}; -const IMAGE_MODEL_NAME = { - ggml: "clip-vit-base-patch32_ggml-vision-model-f16.gguf", - onnx: "clip-image-vit-32-float32.onnx", -}; - -const IMAGE_MODEL_SIZE_IN_BYTES = { - ggml: 175957504, // 167.8 MB - onnx: 351468764, // 335.2 MB -}; -const TEXT_MODEL_SIZE_IN_BYTES = { - ggml: 127853440, // 121.9 MB, - onnx: 64173509, // 61.2 MB -}; - -/** Return the path where the given {@link modelName} is meant to be saved */ -const getModelSavePath = (modelName: string) => - path.join(app.getPath("userData"), "models", modelName); - -async function downloadModel(saveLocation: string, url: string) { - // confirm that the save location exists - const saveDir = path.dirname(saveLocation); - await fs.mkdir(saveDir, { recursive: true }); - log.info("downloading clip model"); - const res = await net.fetch(url); - if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`); - await writeStream(saveLocation, res.body); - log.info("clip model downloaded"); -} - -let imageModelDownloadInProgress: Promise = null; - -const getClipImageModelPath = async (type: "ggml" | "onnx") => { - try { - const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME[type]); - if (imageModelDownloadInProgress) { - log.info("waiting for image model download to finish"); - await imageModelDownloadInProgress; - } else { - if (!existsSync(modelSavePath)) { - log.info("CLIP image model not found, downloading"); - imageModelDownloadInProgress = downloadModel( - modelSavePath, - IMAGE_MODEL_DOWNLOAD_URL[type], - ); - await imageModelDownloadInProgress; - } else { - const localFileSize = (await fs.stat(modelSavePath)).size; - if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES[type]) { - log.info( - `CLIP image model size mismatch, downloading again got: ${localFileSize}`, - ); - imageModelDownloadInProgress = downloadModel( - modelSavePath, - IMAGE_MODEL_DOWNLOAD_URL[type], - ); - await imageModelDownloadInProgress; - } - } - } - return modelSavePath; - } finally { - 
imageModelDownloadInProgress = null; - } -}; - -let textModelDownloadInProgress: boolean = false; - -const getClipTextModelPath = async (type: "ggml" | "onnx") => { - const modelSavePath = getModelSavePath(TEXT_MODEL_NAME[type]); - if (textModelDownloadInProgress) { - throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING); - } else { - if (!existsSync(modelSavePath)) { - log.info("CLIP text model not found, downloading"); - textModelDownloadInProgress = true; - downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type]) - .catch((e) => { - // log but otherwise ignore - log.error("CLIP text model download failed", e); - }) - .finally(() => { - textModelDownloadInProgress = false; - }); - throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING); - } else { - const localFileSize = (await fs.stat(modelSavePath)).size; - if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES[type]) { - log.info( - `CLIP text model size mismatch, downloading again got: ${localFileSize}`, - ); - textModelDownloadInProgress = true; - downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type]) - .catch((e) => { - // log but otherwise ignore - log.error("CLIP text model download failed", e); - }) - .finally(() => { - textModelDownloadInProgress = false; - }); - throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING); - } - } - } - return modelSavePath; -}; - -function getGGMLClipPath() { - return isDev - ? path.join("./build", `ggmlclip-${getPlatform()}`) - : path.join(process.resourcesPath, `ggmlclip-${getPlatform()}`); -} - -async function createOnnxSession(modelPath: string) { - return await ort.InferenceSession.create(modelPath, { - intraOpNumThreads: 1, - enableCpuMemArena: false, - }); -} - -let onnxImageSessionPromise: Promise = null; - -async function getOnnxImageSession() { - if (!onnxImageSessionPromise) { - onnxImageSessionPromise = (async () => { - const clipModelPath = await getClipImageModelPath("onnx"); - return createOnnxSession(clipModelPath); - })(); - } - return onnxImageSessionPromise; -} - -let onnxTextSession: any = null; - -async function getOnnxTextSession() { - if (!onnxTextSession) { - const clipModelPath = await getClipTextModelPath("onnx"); - onnxTextSession = await createOnnxSession(clipModelPath); - } - return onnxTextSession; -} - -let tokenizer: Tokenizer = null; -function getTokenizer() { - if (!tokenizer) { - tokenizer = new Tokenizer(); - } - return tokenizer; -} - -export const computeImageEmbedding = async ( - model: Model, - imageData: Uint8Array, -): Promise => { - if (!isModel(model)) throw new Error(`Invalid CLIP model ${model}`); - - let tempInputFilePath = null; - try { - tempInputFilePath = await generateTempFilePath(""); - const imageStream = new Response(imageData.buffer).body; - await writeStream(tempInputFilePath, imageStream); - const embedding = await computeImageEmbedding_( - model, - tempInputFilePath, - ); - return embedding; - } catch (err) { - if (isExecError(err)) { - const parsedExecError = parseExecError(err); - throw Error(parsedExecError); - } else { - throw err; - } - } finally { - if (tempInputFilePath) { - await deleteTempFile(tempInputFilePath); - } - } -}; - -const isExecError = (err: any) => { - return err.message.includes("Command failed:"); -}; - -const parseExecError = (err: any) => { - const errMessage = err.message; - if (errMessage.includes("Bad CPU type in executable")) { - return CustomErrors.UNSUPPORTED_PLATFORM( - process.platform, - process.arch, - ); - } else { - return errMessage; - } -}; - -async function computeImageEmbedding_( - model: Model, - inputFilePath: 
string, -): Promise { - if (!existsSync(inputFilePath)) { - throw new Error("Invalid file path"); - } - switch (model) { - case "ggml-clip": - return await computeGGMLImageEmbedding(inputFilePath); - case "onnx-clip": - return await computeONNXImageEmbedding(inputFilePath); - } -} - -const computeGGMLImageEmbedding = async ( - inputFilePath: string, -): Promise => { - const clipModelPath = await getClipImageModelPath("ggml"); - const ggmlclipPath = getGGMLClipPath(); - const cmd = IMAGE_EMBEDDING_EXTRACT_CMD.map((cmdPart) => { - if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) { - return ggmlclipPath; - } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) { - return clipModelPath; - } else if (cmdPart === INPUT_PATH_PLACEHOLDER) { - return inputFilePath; - } else { - return cmdPart; - } - }); - - const { stdout } = await execAsync(cmd); - // parse stdout and return embedding - // get the last line of stdout - const lines = stdout.split("\n"); - const lastLine = lines[lines.length - 1]; - const embedding = JSON.parse(lastLine); - const embeddingArray = new Float32Array(embedding); - return embeddingArray; -}; - -const computeONNXImageEmbedding = async ( - inputFilePath: string, -): Promise => { - const imageSession = await getOnnxImageSession(); - const t1 = Date.now(); - const rgbData = await getRGBData(inputFilePath); - const feeds = { - input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]), - }; - const t2 = Date.now(); - const results = await imageSession.run(feeds); - log.info( - `onnx image embedding time: ${Date.now() - t1} ms (prep:${ - t2 - t1 - } ms, extraction: ${Date.now() - t2} ms)`, - ); - const imageEmbedding = results["output"].data; // Float32Array - return normalizeEmbedding(imageEmbedding); -}; - -async function getRGBData(inputFilePath: string) { - const jpegData = await fs.readFile(inputFilePath); - const rawImageData = jpeg.decode(jpegData, { - useTArray: true, - formatAsRGBA: false, - }); - - const nx: number = rawImageData.width; - const ny: number = rawImageData.height; - const inputImage: Uint8Array = rawImageData.data; - - const nx2: number = 224; - const ny2: number = 224; - const totalSize: number = 3 * nx2 * ny2; - - const result: number[] = Array(totalSize).fill(0); - const scale: number = Math.max(nx, ny) / 224; - - const nx3: number = Math.round(nx / scale); - const ny3: number = Math.round(ny / scale); - - const mean: number[] = [0.48145466, 0.4578275, 0.40821073]; - const std: number[] = [0.26862954, 0.26130258, 0.27577711]; - - for (let y = 0; y < ny3; y++) { - for (let x = 0; x < nx3; x++) { - for (let c = 0; c < 3; c++) { - // linear interpolation - const sx: number = (x + 0.5) * scale - 0.5; - const sy: number = (y + 0.5) * scale - 0.5; - - const x0: number = Math.max(0, Math.floor(sx)); - const y0: number = Math.max(0, Math.floor(sy)); - - const x1: number = Math.min(x0 + 1, nx - 1); - const y1: number = Math.min(y0 + 1, ny - 1); - - const dx: number = sx - x0; - const dy: number = sy - y0; - - const j00: number = 3 * (y0 * nx + x0) + c; - const j01: number = 3 * (y0 * nx + x1) + c; - const j10: number = 3 * (y1 * nx + x0) + c; - const j11: number = 3 * (y1 * nx + x1) + c; - - const v00: number = inputImage[j00]; - const v01: number = inputImage[j01]; - const v10: number = inputImage[j10]; - const v11: number = inputImage[j11]; - - const v0: number = v00 * (1 - dx) + v01 * dx; - const v1: number = v10 * (1 - dx) + v11 * dx; - - const v: number = v0 * (1 - dy) + v1 * dy; - - const v2: number = Math.min(Math.max(Math.round(v), 0), 255); - - // 
createTensorWithDataList is dump compared to reshape and hence has to be given with one channel after another - const i: number = y * nx3 + x + (c % 3) * 224 * 224; - - result[i] = (v2 / 255 - mean[c]) / std[c]; - } - } - } - - return result; -} - -const normalizeEmbedding = (embedding: Float32Array) => { - let normalization = 0; - for (let index = 0; index < embedding.length; index++) { - normalization += embedding[index] * embedding[index]; - } - const sqrtNormalization = Math.sqrt(normalization); - for (let index = 0; index < embedding.length; index++) { - embedding[index] = embedding[index] / sqrtNormalization; - } - return embedding; -}; - -export async function computeTextEmbedding( - model: Model, - text: string, -): Promise { - if (!isModel(model)) throw new Error(`Invalid CLIP model ${model}`); - - try { - const embedding = computeTextEmbedding_(model, text); - return embedding; - } catch (err) { - if (isExecError(err)) { - const parsedExecError = parseExecError(err); - throw Error(parsedExecError); - } else { - throw err; - } - } -} - -async function computeTextEmbedding_( - model: Model, - text: string, -): Promise { - switch (model) { - case "ggml-clip": - return await computeGGMLTextEmbedding(text); - case "onnx-clip": - return await computeONNXTextEmbedding(text); - } -} - -export async function computeGGMLTextEmbedding( - text: string, -): Promise { - const clipModelPath = await getClipTextModelPath("ggml"); - const ggmlclipPath = getGGMLClipPath(); - const cmd = TEXT_EMBEDDING_EXTRACT_CMD.map((cmdPart) => { - if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) { - return ggmlclipPath; - } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) { - return clipModelPath; - } else if (cmdPart === INPUT_PATH_PLACEHOLDER) { - return text; - } else { - return cmdPart; - } - }); - - const { stdout } = await execAsync(cmd); - // parse stdout and return embedding - // get the last line of stdout - const lines = stdout.split("\n"); - const lastLine = lines[lines.length - 1]; - const embedding = JSON.parse(lastLine); - const embeddingArray = new Float32Array(embedding); - return embeddingArray; -} - -export async function computeONNXTextEmbedding( - text: string, -): Promise { - const imageSession = await getOnnxTextSession(); - const t1 = Date.now(); - const tokenizer = getTokenizer(); - const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text)); - const feeds = { - input: new ort.Tensor("int32", tokenizedText, [1, 77]), - }; - const t2 = Date.now(); - const results = await imageSession.run(feeds); - log.info( - `onnx text embedding time: ${Date.now() - t1} ms (prep:${ - t2 - t1 - } ms, extraction: ${Date.now() - t2} ms)`, - ); - const textEmbedding = results["output"].data; // Float32Array - return normalizeEmbedding(textEmbedding); -} diff --git a/desktop/src/types/ipc.ts b/desktop/src/types/ipc.ts index d1e99b24c..3dba231f2 100644 --- a/desktop/src/types/ipc.ts +++ b/desktop/src/types/ipc.ts @@ -80,7 +80,3 @@ export interface AppUpdateInfo { autoUpdatable: boolean; version: string; } - -export type Model = "ggml-clip" | "onnx-clip"; - -export const isModel = (s: unknown) => s == "ggml-clip" || s == "onnx-clip";
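Note: the following sketch is illustrative only and not part of the diff above. It shows how the renderer layer might call the renamed bridge methods that preload.ts now exposes on `window.electron`; the `searchScore` helper and the source of the JPEG bytes are assumptions, while the method names, their arguments, and the fact that both embeddings come back L2-normalized (see normalizeEmbedding in clip-service.ts) are taken from the changes. Callers should also expect clipTextEmbedding to reject with CustomErrors.MODEL_DOWNLOAD_PENDING while the text model is still downloading, and retry later.

// Renderer-side sketch (TypeScript). Assumes the preload bridge shown above.
const { clipImageEmbedding, clipTextEmbedding } = (window as any).electron;

// Score how well a JPEG matches a text query. Both embeddings are already
// L2-normalized by clip-service.ts, so their dot product is the cosine
// similarity between the image and the text.
export const searchScore = async (jpegData: Uint8Array, query: string) => {
    const imageEmbedding: Float32Array = await clipImageEmbedding(jpegData);
    const textEmbedding: Float32Array = await clipTextEmbedding(query);
    let score = 0;
    for (let i = 0; i < imageEmbedding.length; i++)
        score += imageEmbedding[i] * textEmbedding[i];
    return score;
};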