add back ggml clip logic

Abhinav 2024-01-04 16:17:15 +05:30
parent 6d7d513702
commit aa52b769d8


@@ -1,27 +1,68 @@
import * as log from 'electron-log';
+import util from 'util';
import { logErrorSentry } from './sentry';
+import { isDev } from '../utils/common';
import { app } from 'electron';
import path from 'path';
import { existsSync } from 'fs';
import fs from 'fs/promises';
+const shellescape = require('any-shell-escape');
+const execAsync = util.promisify(require('child_process').exec);
import fetch from 'node-fetch';
import { writeNodeStream } from './fs';
+import { getPlatform } from '../utils/common/platform';
import { CustomErrors } from '../constants/errors';
+const CLIP_MODEL_PATH_PLACEHOLDER = 'CLIP_MODEL';
+const GGMLCLIP_PATH_PLACEHOLDER = 'GGML_PATH';
+const INPUT_PATH_PLACEHOLDER = 'INPUT';
+const IMAGE_EMBEDDING_EXTRACT_CMD: string[] = [
+    GGMLCLIP_PATH_PLACEHOLDER,
+    '-mv',
+    CLIP_MODEL_PATH_PLACEHOLDER,
+    '--image',
+    INPUT_PATH_PLACEHOLDER,
+];
+const TEXT_EMBEDDING_EXTRACT_CMD: string[] = [
+    GGMLCLIP_PATH_PLACEHOLDER,
+    '-mt',
+    CLIP_MODEL_PATH_PLACEHOLDER,
+    '--text',
+    INPUT_PATH_PLACEHOLDER,
+];
const ort = require('onnxruntime-node');
const { encode } = require('gpt-3-encoder');
const { createCanvas, Image } = require('canvas');
-const TEXT_MODEL_DOWNLOAD_URL =
-    'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-text-vit-32-float32-int32.onnx';
-const IMAGE_MODEL_DOWNLOAD_URL =
-    'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-image-vit-32-float32.onnx';
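+// each model is published in two formats: a gguf file for the ggml CLI binary, an onnx file for onnxruntime-node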
+const TEXT_MODEL_DOWNLOAD_URL = {
+    ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf',
+    onnx: 'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-text-vit-32-float32-int32.onnx',
+};
+const IMAGE_MODEL_DOWNLOAD_URL = {
+    ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-vision-model-f16.gguf',
+    onnx: 'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-image-vit-32-float32.onnx',
+};
-const TEXT_MODEL_NAME = 'clip-text-vit-32-float32-int32.onnx';
-const IMAGE_MODEL_NAME = 'clip-image-vit-32-float32.onnx';
+const TEXT_MODEL_NAME = {
+    ggml: 'clip-vit-base-patch32_ggml-text-model-f16.gguf',
+    onnx: 'clip-text-vit-32-float32-int32.onnx',
+};
+const IMAGE_MODEL_NAME = {
+    ggml: 'clip-vit-base-patch32_ggml-vision-model-f16.gguf',
+    onnx: 'clip-image-vit-32-float32.onnx',
+};
+const IMAGE_MODEL_SIZE_IN_BYTES = {
+    ggml: 175957504, // 167.8 MB
+    onnx: 351468764, // 335.2 MB
+};
+const TEXT_MODEL_SIZE_IN_BYTES = {
+    ggml: 127853440, // 121.9 MB
+    onnx: 254069585, // 242.3 MB
+};
-const IMAGE_MODEL_SIZE_IN_BYTES = 351468764; // 335.2 MB
-const TEXT_MODEL_SIZE_IN_BYTES = 254069585; // 242.3 MB
const MODEL_SAVE_FOLDER = 'models';
function getModelSavePath(modelName: string) {
@@ -49,9 +90,9 @@ async function downloadModel(saveLocation: string, url: string) {
let imageModelDownloadInProgress: Promise<void> = null;
-export async function getClipImageModelPath() {
+export async function getClipImageModelPath(type: 'ggml' | 'onnx') {
    try {
-        const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME);
+        const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME[type]);
        if (imageModelDownloadInProgress) {
            log.info('waiting for image model download to finish');
            await imageModelDownloadInProgress;
@@ -60,19 +101,19 @@ export async function getClipImageModelPath() {
                log.info('clip image model not found, downloading');
                imageModelDownloadInProgress = downloadModel(
                    modelSavePath,
-                    IMAGE_MODEL_DOWNLOAD_URL
+                    IMAGE_MODEL_DOWNLOAD_URL[type]
                );
                await imageModelDownloadInProgress;
            } else {
                const localFileSize = (await fs.stat(modelSavePath)).size;
-                if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES) {
+                if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES[type]) {
                    log.info(
                        'clip image model size mismatch, downloading again got:',
                        localFileSize
                    );
                    imageModelDownloadInProgress = downloadModel(
                        modelSavePath,
-                        IMAGE_MODEL_DOWNLOAD_URL
+                        IMAGE_MODEL_DOWNLOAD_URL[type]
                    );
                    await imageModelDownloadInProgress;
                }
@@ -86,15 +127,15 @@ export async function getClipImageModelPath() {
let textModelDownloadInProgress: boolean = false;
-export async function getClipTextModelPath() {
-    const modelSavePath = getModelSavePath(TEXT_MODEL_NAME);
+export async function getClipTextModelPath(type: 'ggml' | 'onnx') {
+    const modelSavePath = getModelSavePath(TEXT_MODEL_NAME[type]);
    if (textModelDownloadInProgress) {
        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
    } else {
        if (!existsSync(modelSavePath)) {
            log.info('clip text model not found, downloading');
            textModelDownloadInProgress = true;
-            downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL)
+            downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
                .catch(() => {
                    // ignore
                })
@@ -104,13 +145,13 @@ export async function getClipTextModelPath() {
            throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
        } else {
            const localFileSize = (await fs.stat(modelSavePath)).size;
-            if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES) {
+            if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES[type]) {
                log.info(
                    'clip text model size mismatch, downloading again',
                    localFileSize
                );
                textModelDownloadInProgress = true;
-                downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL)
+                downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
                    .catch(() => {
                        // ignore
                    })
@@ -124,6 +165,12 @@ export async function getClipTextModelPath() {
    return modelSavePath;
}
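+// resolve the bundled ggmlclip binary: ./build during development, the app's resources directory when packaged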
+function getGGMLClipPath() {
+    return isDev
+        ? path.join('./build', `ggmlclip-${getPlatform()}`)
+        : path.join(process.resourcesPath, `ggmlclip-${getPlatform()}`);
+}
async function createOnnxSession(modelPath: string) {
    return await ort.InferenceSession.create(modelPath, {
        intraOpNumThreads: 1,
@@ -135,7 +182,7 @@ let onnxImageSession: any = null;
async function getOnnxImageSession() {
    if (!onnxImageSession) {
-        const clipModelPath = await getClipImageModelPath();
+        const clipModelPath = await getClipImageModelPath('onnx');
        onnxImageSession = createOnnxSession(clipModelPath);
    }
    return onnxImageSession;
@@ -145,7 +192,7 @@ let onnxTextSession: any = null;
async function getOnnxTextSession() {
    if (!onnxTextSession) {
-        const clipModelPath = await getClipTextModelPath();
+        const clipModelPath = await getClipTextModelPath('onnx');
        onnxTextSession = createOnnxSession(clipModelPath);
    }
    return onnxTextSession;
@@ -153,6 +200,55 @@ async function getOnnxTextSession() {
export async function computeImageEmbedding(
    inputFilePath: string
): Promise<Float32Array> {
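+    // compute the embedding with both backends and log how closely they agree; the ONNX result is what gets returned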
+    const ggmlImageEmbedding = await computeGGMLImageEmbedding(inputFilePath);
+    const onnxImageEmbedding = await computeONNXImageEmbedding(inputFilePath);
+    const score = await computeClipMatchScore(
+        ggmlImageEmbedding,
+        onnxImageEmbedding
+    );
+    console.log('imageEmbeddingScore', score);
+    return onnxImageEmbedding;
+}
+export async function computeGGMLImageEmbedding(
+    inputFilePath: string
+): Promise<Float32Array> {
+    try {
+        const clipModelPath = await getClipImageModelPath('ggml');
+        const ggmlclipPath = getGGMLClipPath();
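+        // substitute the placeholders in the command template with the binary, model, and input paths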
+        const cmd = IMAGE_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
+            if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
+                return ggmlclipPath;
+            } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
+                return clipModelPath;
+            } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
+                return inputFilePath;
+            } else {
+                return cmdPart;
+            }
+        });
+        const escapedCmd = shellescape(cmd);
+        log.info('running clip command', escapedCmd);
+        const startTime = Date.now();
+        const { stdout } = await execAsync(escapedCmd);
+        log.info('clip command execution time ', Date.now() - startTime);
+        // parse stdout and return the embedding,
+        // which is printed on the last line of stdout
+        const lines = stdout.split('\n');
+        const lastLine = lines[lines.length - 1];
+        const embedding = JSON.parse(lastLine);
+        const embeddingArray = new Float32Array(embedding);
+        return embeddingArray;
+    } catch (err) {
+        logErrorSentry(err, 'Error in computeGGMLImageEmbedding');
+        throw err;
+    }
+}
+export async function computeONNXImageEmbedding(
+    inputFilePath: string
+): Promise<Float32Array> {
    try {
        const imageSession = await getOnnxImageSession();
@@ -170,6 +266,59 @@
}
export async function computeTextEmbedding(
    inputFilePath: string
): Promise<Float32Array> {
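+    // note: despite the parameter name, inputFilePath carries the query text for the text model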
+    const ggmlTextEmbedding = await computeGGMLTextEmbedding(inputFilePath);
+    const onnxTextEmbedding = await computeONNXTextEmbedding(inputFilePath);
+    const score = await computeClipMatchScore(
+        ggmlTextEmbedding,
+        onnxTextEmbedding
+    );
+    console.log('textEmbeddingScore', score);
+    return onnxTextEmbedding;
+}
+export async function computeGGMLTextEmbedding(
+    text: string
+): Promise<Float32Array> {
+    try {
+        const clipModelPath = await getClipTextModelPath('ggml');
+        const ggmlclipPath = getGGMLClipPath();
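+        // build the CLI invocation; the shell escaping below matters since the query text is user supplied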
+        const cmd = TEXT_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
+            if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
+                return ggmlclipPath;
+            } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
+                return clipModelPath;
+            } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
+                return text;
+            } else {
+                return cmdPart;
+            }
+        });
+        const escapedCmd = shellescape(cmd);
+        log.info('running clip command', escapedCmd);
+        const startTime = Date.now();
+        const { stdout } = await execAsync(escapedCmd);
+        log.info('clip command execution time ', Date.now() - startTime);
+        // parse stdout and return the embedding,
+        // which is printed on the last line of stdout
+        const lines = stdout.split('\n');
+        const lastLine = lines[lines.length - 1];
+        const embedding = JSON.parse(lastLine);
+        const embeddingArray = new Float32Array(embedding);
+        return embeddingArray;
+    } catch (err) {
+        if (err.message === CustomErrors.MODEL_DOWNLOAD_PENDING) {
+            log.info(CustomErrors.MODEL_DOWNLOAD_PENDING);
+        } else {
+            logErrorSentry(err, 'Error in computeGGMLTextEmbedding');
+        }
+        throw err;
+    }
+}
+export async function computeONNXTextEmbedding(
+    text: string
+): Promise<Float32Array> {
    try {
@@ -239,3 +388,30 @@ async function getRgbData(inputFilePath: string) {
    }
    return Float32Array.from(rgbData.flat().flat());
}
+export const computeClipMatchScore = async (
+    imageEmbedding: Float32Array,
+    textEmbedding: Float32Array
+) => {
+    if (imageEmbedding.length !== textEmbedding.length) {
+        throw Error('imageEmbedding and textEmbedding length mismatch');
+    }
+    let score = 0;
+    let imageNormalization = 0;
+    let textNormalization = 0;
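+    // accumulate the squared L2 norm of each embedding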
+    for (let index = 0; index < imageEmbedding.length; index++) {
+        imageNormalization += imageEmbedding[index] * imageEmbedding[index];
+        textNormalization += textEmbedding[index] * textEmbedding[index];
+    }
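+    // scale both embeddings to unit length (note that this mutates the caller's arrays in place)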
+    for (let index = 0; index < imageEmbedding.length; index++) {
+        imageEmbedding[index] =
+            imageEmbedding[index] / Math.sqrt(imageNormalization);
+        textEmbedding[index] =
+            textEmbedding[index] / Math.sqrt(textNormalization);
+    }
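+    // dot product of the unit vectors, i.e. the cosine similarity in [-1, 1]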
+    for (let index = 0; index < imageEmbedding.length; index++) {
+        score += imageEmbedding[index] * textEmbedding[index];
+    }
+    return score;
+};
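
For context, a minimal sketch of how the exports in this commit might be wired together to score one image against a text query. The caller and import path are hypothetical, not part of this commit; note that computeClipMatchScore normalizes, and therefore mutates, the arrays passed to it:

import {
    computeImageEmbedding,
    computeTextEmbedding,
    computeClipMatchScore,
} from './clipService'; // hypothetical import path

async function scoreImageAgainstQuery(imagePath: string, query: string) {
    // each call also logs the ggml-vs-onnx agreement score as a side effect
    const imageEmbedding = await computeImageEmbedding(imagePath);
    const textEmbedding = await computeTextEmbedding(query);
    // cosine similarity between the (ONNX) image and text embeddings
    return await computeClipMatchScore(imageEmbedding, textEmbedding);
}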