diff --git a/src/services/clipService.ts b/src/services/clipService.ts
index ebe254119..2090082e4 100644
--- a/src/services/clipService.ts
+++ b/src/services/clipService.ts
@@ -1,27 +1,68 @@
 import * as log from 'electron-log';
+import util from 'util';
 import { logErrorSentry } from './sentry';
 import { isDev } from '../utils/common';
 import { app } from 'electron';
 import path from 'path';
 import { existsSync } from 'fs';
 import fs from 'fs/promises';
+const shellescape = require('any-shell-escape');
+const execAsync = util.promisify(require('child_process').exec);
 import fetch from 'node-fetch';
 import { writeNodeStream } from './fs';
+import { getPlatform } from '../utils/common/platform';
 import { CustomErrors } from '../constants/errors';
+
+const CLIP_MODEL_PATH_PLACEHOLDER = 'CLIP_MODEL';
+const GGMLCLIP_PATH_PLACEHOLDER = 'GGML_PATH';
+const INPUT_PATH_PLACEHOLDER = 'INPUT';
+
+const IMAGE_EMBEDDING_EXTRACT_CMD: string[] = [
+    GGMLCLIP_PATH_PLACEHOLDER,
+    '-mv',
+    CLIP_MODEL_PATH_PLACEHOLDER,
+    '--image',
+    INPUT_PATH_PLACEHOLDER,
+];
+
+const TEXT_EMBEDDING_EXTRACT_CMD: string[] = [
+    GGMLCLIP_PATH_PLACEHOLDER,
+    '-mt',
+    CLIP_MODEL_PATH_PLACEHOLDER,
+    '--text',
+    INPUT_PATH_PLACEHOLDER,
+];
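+// Illustrative only: after placeholder substitution, the image command has
+// the shape (binary name and paths are hypothetical):
+//   ggmlclip-mac -mv models/clip-vit-base-patch32_ggml-vision-model-f16.gguf --image photo.jpg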
 const ort = require('onnxruntime-node');
 const { encode } = require('gpt-3-encoder');
 const { createCanvas, Image } = require('canvas');

-const TEXT_MODEL_DOWNLOAD_URL =
-    'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-text-vit-32-float32-int32.onnx';
-const IMAGE_MODEL_DOWNLOAD_URL =
-    'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-image-vit-32-float32.onnx';
+const TEXT_MODEL_DOWNLOAD_URL = {
+    ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf',
+    onnx: 'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-text-vit-32-float32-int32.onnx',
+};
+const IMAGE_MODEL_DOWNLOAD_URL = {
+    ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-vision-model-f16.gguf',
+    onnx: 'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-image-vit-32-float32.onnx',
+};

-const TEXT_MODEL_NAME = 'clip-text-vit-32-float32-int32.onnx';
-const IMAGE_MODEL_NAME = 'clip-image-vit-32-float32.onnx';
+const TEXT_MODEL_NAME = {
+    ggml: 'clip-vit-base-patch32_ggml-text-model-f16.gguf',
+    onnx: 'clip-text-vit-32-float32-int32.onnx',
+};
+const IMAGE_MODEL_NAME = {
+    ggml: 'clip-vit-base-patch32_ggml-vision-model-f16.gguf',
+    onnx: 'clip-image-vit-32-float32.onnx',
+};
+
+const IMAGE_MODEL_SIZE_IN_BYTES = {
+    ggml: 175957504, // 167.8 MB
+    onnx: 351468764, // 335.2 MB
+};
+const TEXT_MODEL_SIZE_IN_BYTES = {
+    ggml: 127853440, // 121.9 MB
+    onnx: 254069585, // 242.3 MB
+};

-const IMAGE_MODEL_SIZE_IN_BYTES = 351468764; // 335.2 MB
-const TEXT_MODEL_SIZE_IN_BYTES = 254069585; // 242.3 MB
 const MODEL_SAVE_FOLDER = 'models';

 function getModelSavePath(modelName: string) {
@@ -49,9 +90,9 @@ async function downloadModel(saveLocation: string, url: string) {

 let imageModelDownloadInProgress: Promise<void> = null;

-export async function getClipImageModelPath() {
+export async function getClipImageModelPath(type: 'ggml' | 'onnx') {
     try {
-        const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME);
+        const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME[type]);
         if (imageModelDownloadInProgress) {
             log.info('waiting for image model download to finish');
             await imageModelDownloadInProgress;
@@ -60,19 +101,19 @@ export async function getClipImageModelPath() {
             log.info('clip image model not found, downloading');
             imageModelDownloadInProgress = downloadModel(
                 modelSavePath,
-                IMAGE_MODEL_DOWNLOAD_URL
+                IMAGE_MODEL_DOWNLOAD_URL[type]
             );
             await imageModelDownloadInProgress;
         } else {
             const localFileSize = (await fs.stat(modelSavePath)).size;
-            if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES) {
+            if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES[type]) {
                 log.info(
                     'clip image model size mismatch, downloading again, got:',
                     localFileSize
                 );
                 imageModelDownloadInProgress = downloadModel(
                     modelSavePath,
-                    IMAGE_MODEL_DOWNLOAD_URL
+                    IMAGE_MODEL_DOWNLOAD_URL[type]
                 );
                 await imageModelDownloadInProgress;
             }
@@ -86,15 +127,15 @@ export async function getClipImageModelPath() {

 let textModelDownloadInProgress: boolean = false;

-export async function getClipTextModelPath() {
-    const modelSavePath = getModelSavePath(TEXT_MODEL_NAME);
+export async function getClipTextModelPath(type: 'ggml' | 'onnx') {
+    const modelSavePath = getModelSavePath(TEXT_MODEL_NAME[type]);
     if (textModelDownloadInProgress) {
         throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
     } else {
         if (!existsSync(modelSavePath)) {
             log.info('clip text model not found, downloading');
             textModelDownloadInProgress = true;
-            downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL)
+            downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
                 .catch(() => {
                     // ignore
                 })
@@ -104,13 +145,13 @@ export async function getClipTextModelPath() {
             throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
         } else {
             const localFileSize = (await fs.stat(modelSavePath)).size;
-            if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES) {
+            if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES[type]) {
                 log.info(
                     'clip text model size mismatch, downloading again',
                     localFileSize
                 );
                 textModelDownloadInProgress = true;
-                downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL)
+                downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
                     .catch(() => {
                         // ignore
                     })
@@ -124,6 +165,12 @@ export async function getClipTextModelPath() {
     return modelSavePath;
 }
 
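+// Assumed packaging: in dev builds the per-platform ggmlclip binary is read
+// from `./build`, while packaged apps ship it in the Electron resources
+// directory.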
+function getGGMLClipPath() {
+    return isDev
+        ? path.join('./build', `ggmlclip-${getPlatform()}`)
+        : path.join(process.resourcesPath, `ggmlclip-${getPlatform()}`);
+}
+
 async function createOnnxSession(modelPath: string) {
     return await ort.InferenceSession.create(modelPath, {
         intraOpNumThreads: 1,
@@ -135,7 +182,7 @@ let onnxImageSession: any = null;

 async function getOnnxImageSession() {
     if (!onnxImageSession) {
-        const clipModelPath = await getClipImageModelPath();
+        const clipModelPath = await getClipImageModelPath('onnx');
         onnxImageSession = createOnnxSession(clipModelPath);
     }
     return onnxImageSession;
@@ -145,7 +192,7 @@ let onnxTextSession: any = null;

 async function getOnnxTextSession() {
     if (!onnxTextSession) {
-        const clipModelPath = await getClipTextModelPath();
+        const clipModelPath = await getClipTextModelPath('onnx');
         onnxTextSession = createOnnxSession(clipModelPath);
     }
     return onnxTextSession;
@@ -153,6 +200,55 @@ async function getOnnxTextSession() {
 
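+// The wrappers below run both backends and log the cosine similarity of
+// their outputs as a parity check while the GGML path is evaluated; the
+// ONNX embedding remains the returned value.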
 export async function computeImageEmbedding(
     inputFilePath: string
+): Promise<Float32Array> {
+    const ggmlImageEmbedding = await computeGGMLImageEmbedding(inputFilePath);
+    const onnxImageEmbedding = await computeONNXImageEmbedding(inputFilePath);
+    const score = await computeClipMatchScore(
+        ggmlImageEmbedding,
+        onnxImageEmbedding
+    );
+    log.info('imageEmbeddingScore', score);
+    return onnxImageEmbedding;
+}
+
+export async function computeGGMLImageEmbedding(
+    inputFilePath: string
+): Promise<Float32Array> {
+    try {
+        const clipModelPath = await getClipImageModelPath('ggml');
+        const ggmlclipPath = getGGMLClipPath();
+        const cmd = IMAGE_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
+            if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
+                return ggmlclipPath;
+            } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
+                return clipModelPath;
+            } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
+                return inputFilePath;
+            } else {
+                return cmdPart;
+            }
+        });
+
+        const escapedCmd = shellescape(cmd);
+        log.info('running clip command', escapedCmd);
+        const startTime = Date.now();
+        const { stdout } = await execAsync(escapedCmd);
+        log.info('clip command execution time ', Date.now() - startTime);
+        // the embedding is printed as a JSON array on the last line of
+        // stdout; trim to guard against a trailing newline
+        const lines = stdout.trim().split('\n');
+        const lastLine = lines[lines.length - 1];
+        const embedding = JSON.parse(lastLine);
+        const embeddingArray = new Float32Array(embedding);
+        return embeddingArray;
+    } catch (err) {
+        logErrorSentry(err, 'Error in computeGGMLImageEmbedding');
+        throw err;
+    }
+}
+
+export async function computeONNXImageEmbedding(
+    inputFilePath: string
 ): Promise<Float32Array> {
     try {
         const imageSession = await getOnnxImageSession();
@@ -170,6 +266,59 @@ export async function computeImageEmbedding(
 }

 export async function computeTextEmbedding(
+    text: string
+): Promise<Float32Array> {
+    const ggmlTextEmbedding = await computeGGMLTextEmbedding(text);
+    const onnxTextEmbedding = await computeONNXTextEmbedding(text);
+    const score = await computeClipMatchScore(
+        ggmlTextEmbedding,
+        onnxTextEmbedding
+    );
+    log.info('textEmbeddingScore', score);
+    return onnxTextEmbedding;
+}
+
+export async function computeGGMLTextEmbedding(
+    text: string
+): Promise<Float32Array> {
+    try {
+        const clipModelPath = await getClipTextModelPath('ggml');
+        const ggmlclipPath = getGGMLClipPath();
+        const cmd = TEXT_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
+            if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
+                return ggmlclipPath;
+            } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
+                return clipModelPath;
+            } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
+                return text;
+            } else {
+                return cmdPart;
+            }
+        });
+
+        const escapedCmd = shellescape(cmd);
+        log.info('running clip command', escapedCmd);
+        const startTime = Date.now();
+        const { stdout } = await execAsync(escapedCmd);
+        log.info('clip command execution time ', Date.now() - startTime);
+        // the embedding is printed as a JSON array on the last line of
+        // stdout; trim to guard against a trailing newline
+        const lines = stdout.trim().split('\n');
+        const lastLine = lines[lines.length - 1];
+        const embedding = JSON.parse(lastLine);
+        const embeddingArray = new Float32Array(embedding);
+        return embeddingArray;
+    } catch (err) {
+        if (err.message === CustomErrors.MODEL_DOWNLOAD_PENDING) {
+            log.info(CustomErrors.MODEL_DOWNLOAD_PENDING);
+        } else {
+            logErrorSentry(err, 'Error in computeGGMLTextEmbedding');
+        }
+        throw err;
+    }
+}
+
+export async function computeONNXTextEmbedding(
     text: string
 ): Promise<Float32Array> {
     try {
@@ -239,3 +388,30 @@ async function getRgbData(inputFilePath: string) {
     }
     return Float32Array.from(rgbData.flat().flat());
 }
+
+export const computeClipMatchScore = async (
+    imageEmbedding: Float32Array,
+    textEmbedding: Float32Array
+) => {
+    if (imageEmbedding.length !== textEmbedding.length) {
+        throw Error('imageEmbedding and textEmbedding length mismatch');
+    }
+    // cosine similarity, computed without mutating the input embeddings
+    let score = 0;
+    let imageNormalization = 0;
+    let textNormalization = 0;
+    for (let index = 0; index < imageEmbedding.length; index++) {
+        imageNormalization += imageEmbedding[index] * imageEmbedding[index];
+        textNormalization += textEmbedding[index] * textEmbedding[index];
+        score += imageEmbedding[index] * textEmbedding[index];
+    }
+    return (
+        score / (Math.sqrt(imageNormalization) * Math.sqrt(textNormalization))
+    );
+};
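+
+// Usage sketch (values illustrative): a score near 1.0 means the GGML and
+// ONNX embeddings of the same input agree, e.g.
+//   const score = await computeClipMatchScore(ggmlEmbedding, onnxEmbedding);
+//   // score ≈ 0.999 when the two backends match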