
add back ggml clip logic

Abhinav · 1 year ago
commit aa52b769d8
1 changed file with 196 additions and 20 deletions
src/services/clipService.ts  +196 -20

@@ -1,27 +1,68 @@
 import * as log from 'electron-log';
+import util from 'util';
 import { logErrorSentry } from './sentry';
 import { isDev } from '../utils/common';
 import { app } from 'electron';
 import path from 'path';
 import { existsSync } from 'fs';
 import fs from 'fs/promises';
+const shellescape = require('any-shell-escape');
+const execAsync = util.promisify(require('child_process').exec);
 import fetch from 'node-fetch';
 import { writeNodeStream } from './fs';
+import { getPlatform } from '../utils/common/platform';
 import { CustomErrors } from '../constants/errors';
+
+const CLIP_MODEL_PATH_PLACEHOLDER = 'CLIP_MODEL';
+const GGMLCLIP_PATH_PLACEHOLDER = 'GGML_PATH';
+const INPUT_PATH_PLACEHOLDER = 'INPUT';
+
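+// argv templates for the ggml CLIP binary; the placeholders above are
+// substituted with real paths at run time, and the resulting command is
+// shell-escaped before being executed.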
+const IMAGE_EMBEDDING_EXTRACT_CMD: string[] = [
+    GGMLCLIP_PATH_PLACEHOLDER,
+    '-mv',
+    CLIP_MODEL_PATH_PLACEHOLDER,
+    '--image',
+    INPUT_PATH_PLACEHOLDER,
+];
+
+const TEXT_EMBEDDING_EXTRACT_CMD: string[] = [
+    GGMLCLIP_PATH_PLACEHOLDER,
+    '-mt',
+    CLIP_MODEL_PATH_PLACEHOLDER,
+    '--text',
+    INPUT_PATH_PLACEHOLDER,
+];
 const ort = require('onnxruntime-node');
 const { encode } = require('gpt-3-encoder');
 const { createCanvas, Image } = require('canvas');
 
-const TEXT_MODEL_DOWNLOAD_URL =
-    'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-text-vit-32-float32-int32.onnx';
-const IMAGE_MODEL_DOWNLOAD_URL =
-    'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-image-vit-32-float32.onnx';
+const TEXT_MODEL_DOWNLOAD_URL = {
+    ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf',
+    onnx: 'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-text-vit-32-float32-int32.onnx',
+};
+const IMAGE_MODEL_DOWNLOAD_URL = {
+    ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-vision-model-f16.gguf',
+    onnx: 'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-image-vit-32-float32.onnx',
+};
+
+const TEXT_MODEL_NAME = {
+    ggml: 'clip-vit-base-patch32_ggml-text-model-f16.gguf',
+    onnx: 'clip-text-vit-32-float32-int32.onnx',
+};
+const IMAGE_MODEL_NAME = {
+    ggml: 'clip-vit-base-patch32_ggml-vision-model-f16.gguf',
+    onnx: 'clip-image-vit-32-float32.onnx',
+};
 
-const TEXT_MODEL_NAME = 'clip-text-vit-32-float32-int32.onnx';
-const IMAGE_MODEL_NAME = 'clip-image-vit-32-float32.onnx';
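+// Expected on-disk sizes. These double as a cheap integrity check: a size
+// mismatch triggers a re-download (see getClipImageModelPath below).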
+const IMAGE_MODEL_SIZE_IN_BYTES = {
+    ggml: 175957504, // 167.8 MB
+    onnx: 351468764, // 335.2 MB
+};
+const TEXT_MODEL_SIZE_IN_BYTES = {
+    ggml: 127853440, // 121.9 MB
+    onnx: 254069585, // 242.3 MB
+};
 
-const IMAGE_MODEL_SIZE_IN_BYTES = 351468764; // 335.2 MB
-const TEXT_MODEL_SIZE_IN_BYTES = 254069585; // 242.3 MB
 const MODEL_SAVE_FOLDER = 'models';
 
 function getModelSavePath(modelName: string) {
@@ -49,9 +90,9 @@ async function downloadModel(saveLocation: string, url: string) {
 
 let imageModelDownloadInProgress: Promise<void> = null;
 
-export async function getClipImageModelPath() {
+export async function getClipImageModelPath(type: 'ggml' | 'onnx') {
     try {
-        const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME);
+        const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME[type]);
         if (imageModelDownloadInProgress) {
             log.info('waiting for image model download to finish');
             await imageModelDownloadInProgress;
@@ -60,19 +101,19 @@ export async function getClipImageModelPath() {
                 log.info('clip image model not found, downloading');
                 imageModelDownloadInProgress = downloadModel(
                     modelSavePath,
-                    IMAGE_MODEL_DOWNLOAD_URL
+                    IMAGE_MODEL_DOWNLOAD_URL[type]
                 );
                 await imageModelDownloadInProgress;
             } else {
                 const localFileSize = (await fs.stat(modelSavePath)).size;
-                if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES) {
+                if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES[type]) {
                     log.info(
                         'clip image model size mismatch, downloading again got:',
                         localFileSize
                     );
                     imageModelDownloadInProgress = downloadModel(
                         modelSavePath,
-                        IMAGE_MODEL_DOWNLOAD_URL
+                        IMAGE_MODEL_DOWNLOAD_URL[type]
                     );
                     await imageModelDownloadInProgress;
                 }
@@ -86,15 +127,15 @@ export async function getClipImageModelPath() {
 
 let textModelDownloadInProgress: boolean = false;
 
-export async function getClipTextModelPath() {
-    const modelSavePath = getModelSavePath(TEXT_MODEL_NAME);
+export async function getClipTextModelPath(type: 'ggml' | 'onnx') {
+    const modelSavePath = getModelSavePath(TEXT_MODEL_NAME[type]);
     if (textModelDownloadInProgress) {
         throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
     } else {
         if (!existsSync(modelSavePath)) {
             log.info('clip text model not found, downloading');
             textModelDownloadInProgress = true;
-            downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL)
+            downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
                 .catch(() => {
                     // ignore
                 })
@@ -104,13 +145,13 @@ export async function getClipTextModelPath() {
             throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
         } else {
             const localFileSize = (await fs.stat(modelSavePath)).size;
-            if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES) {
+            if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES[type]) {
                 log.info(
                     'clip text model size mismatch, downloading again',
                     localFileSize
                 );
                 textModelDownloadInProgress = true;
-                downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL)
+                downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
                     .catch(() => {
                         // ignore
                     })
@@ -124,6 +165,12 @@ export async function getClipTextModelPath() {
     return modelSavePath;
 }
 
+// Resolve the platform-specific ggml CLIP binary: from `build` in dev
+// builds, from the packaged app's resources directory otherwise.
+function getGGMLClipPath() {
+    return isDev
+        ? path.join('./build', `ggmlclip-${getPlatform()}`)
+        : path.join(process.resourcesPath, `ggmlclip-${getPlatform()}`);
+}
+
 async function createOnnxSession(modelPath: string) {
     return await ort.InferenceSession.create(modelPath, {
         intraOpNumThreads: 1,
@@ -135,7 +182,7 @@ let onnxImageSession: any = null;
 
 async function getOnnxImageSession() {
     if (!onnxImageSession) {
-        const clipModelPath = await getClipImageModelPath();
+        const clipModelPath = await getClipImageModelPath('onnx');
         onnxImageSession = createOnnxSession(clipModelPath);
     }
     return onnxImageSession;
@@ -145,7 +192,7 @@ let onnxTextSession: any = null;
 
 async function getOnnxTextSession() {
     if (!onnxTextSession) {
-        const clipModelPath = await getClipTextModelPath();
+        const clipModelPath = await getClipTextModelPath('onnx');
         onnxTextSession = createOnnxSession(clipModelPath);
     }
     return onnxTextSession;
@@ -153,6 +200,55 @@ async function getOnnxTextSession() {
 
 export async function computeImageEmbedding(
     inputFilePath: string
+): Promise<Float32Array> {
+    const ggmlImageEmbedding = await computeGGMLImageEmbedding(inputFilePath);
+    const onnxImageEmbedding = await computeONNXImageEmbedding(inputFilePath);
+    const score = await computeClipMatchScore(
+        ggmlImageEmbedding,
+        onnxImageEmbedding
+    );
+    log.info('imageEmbeddingScore', score);
+    return onnxImageEmbedding;
+}
+
+export async function computeGGMLImageEmbedding(
+    inputFilePath: string
+): Promise<Float32Array> {
+    try {
+        const clipModelPath = await getClipImageModelPath('ggml');
+        const ggmlclipPath = getGGMLClipPath();
+        const cmd = IMAGE_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
+            if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
+                return ggmlclipPath;
+            } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
+                return clipModelPath;
+            } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
+                return inputFilePath;
+            } else {
+                return cmdPart;
+            }
+        });
+
+        const escapedCmd = shellescape(cmd);
+        log.info('running clip command', escapedCmd);
+        const startTime = Date.now();
+        const { stdout } = await execAsync(escapedCmd);
+        log.info('clip command execution time ', Date.now() - startTime);
+        // parse the last line of stdout as the JSON-encoded embedding
+        // (trim first so a trailing newline does not yield an empty line)
+        const lines = stdout.trim().split('\n');
+        const lastLine = lines[lines.length - 1];
+        const embedding = JSON.parse(lastLine);
+        const embeddingArray = new Float32Array(embedding);
+        return embeddingArray;
+    } catch (err) {
+        logErrorSentry(err, 'Error in computeGGMLImageEmbedding');
+        throw err;
+    }
+}
+
+export async function computeONNXImageEmbedding(
+    inputFilePath: string
 ): Promise<Float32Array> {
     try {
         const imageSession = await getOnnxImageSession();
@@ -170,6 +266,59 @@ export async function computeImageEmbedding(
 }
 
 export async function computeTextEmbedding(
+    text: string
+): Promise<Float32Array> {
+    const ggmlTextEmbedding = await computeGGMLTextEmbedding(text);
+    const onnxTextEmbedding = await computeONNXTextEmbedding(text);
+    const score = await computeClipMatchScore(
+        ggmlTextEmbedding,
+        onnxTextEmbedding
+    );
+    log.info('textEmbeddingScore', score);
+    return onnxTextEmbedding;
+}
+
+export async function computeGGMLTextEmbedding(
+    text: string
+): Promise<Float32Array> {
+    try {
+        const clipModelPath = await getClipTextModelPath('ggml');
+        const ggmlclipPath = getGGMLClipPath();
+        const cmd = TEXT_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
+            if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
+                return ggmlclipPath;
+            } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
+                return clipModelPath;
+            } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
+                return text;
+            } else {
+                return cmdPart;
+            }
+        });
+
+        const escapedCmd = shellescape(cmd);
+        log.info('running clip command', escapedCmd);
+        const startTime = Date.now();
+        const { stdout } = await execAsync(escapedCmd);
+        log.info('clip command execution time ', Date.now() - startTime);
+        // parse the last line of stdout as the JSON-encoded embedding
+        // (trim first so a trailing newline does not yield an empty line)
+        const lines = stdout.trim().split('\n');
+        const lastLine = lines[lines.length - 1];
+        const embedding = JSON.parse(lastLine);
+        const embeddingArray = new Float32Array(embedding);
+        return embeddingArray;
+    } catch (err) {
+        if (err.message === CustomErrors.MODEL_DOWNLOAD_PENDING) {
+            log.info(CustomErrors.MODEL_DOWNLOAD_PENDING);
+        } else {
+            logErrorSentry(err, 'Error in computeGGMLTextEmbedding');
+        }
+        throw err;
+    }
+}
+
+export async function computeONNXTextEmbedding(
     text: string
 ): Promise<Float32Array> {
     try {
@@ -239,3 +388,30 @@ async function getRgbData(inputFilePath: string) {
     }
     return Float32Array.from(rgbData.flat().flat());
 }
+
+// Cosine similarity between the two embeddings (also used above to compare
+// the ggml and onnx embeddings of the same input). Computed without
+// normalizing the input arrays in place, so the callers' embeddings are
+// left untouched.
+export const computeClipMatchScore = async (
+    imageEmbedding: Float32Array,
+    textEmbedding: Float32Array
+) => {
+    if (imageEmbedding.length !== textEmbedding.length) {
+        throw Error('imageEmbedding and textEmbedding length mismatch');
+    }
+    let dotProduct = 0;
+    let imageNormalization = 0;
+    let textNormalization = 0;
+
+    for (let index = 0; index < imageEmbedding.length; index++) {
+        dotProduct += imageEmbedding[index] * textEmbedding[index];
+        imageNormalization += imageEmbedding[index] * imageEmbedding[index];
+        textNormalization += textEmbedding[index] * textEmbedding[index];
+    }
+    return (
+        dotProduct /
+        (Math.sqrt(imageNormalization) * Math.sqrt(textNormalization))
+    );
+};
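
For reference, the score logged by computeImageEmbedding and computeTextEmbedding is plain cosine similarity, so values close to 1 mean the ggml and onnx backends produced near-identical embeddings for the same input. A minimal standalone sketch of the same computation on toy vectors (independent of the service code above; the function name is illustrative):

    // Cosine similarity: dot(a, b) / (|a| * |b|), in [-1, 1] for real vectors.
    function cosineSimilarity(a: Float32Array, b: Float32Array): number {
        if (a.length !== b.length) throw Error('length mismatch');
        let dot = 0;
        let normA = 0;
        let normB = 0;
        for (let i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    // Parallel vectors score 1, orthogonal vectors score 0.
    console.log(cosineSimilarity(new Float32Array([1, 0]), new Float32Array([2, 0]))); // 1
    console.log(cosineSimilarity(new Float32Array([1, 0]), new Float32Array([0, 5]))); // 0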