add back ggml clip logic

Abhinav 2024-01-04 16:17:15 +05:30
parent 6d7d513702
commit aa52b769d8


@@ -1,27 +1,68 @@
 import * as log from 'electron-log';
+import util from 'util';
 import { logErrorSentry } from './sentry';
 import { isDev } from '../utils/common';
 import { app } from 'electron';
 import path from 'path';
 import { existsSync } from 'fs';
 import fs from 'fs/promises';
+const shellescape = require('any-shell-escape');
+const execAsync = util.promisify(require('child_process').exec);
 import fetch from 'node-fetch';
 import { writeNodeStream } from './fs';
+import { getPlatform } from '../utils/common/platform';
 import { CustomErrors } from '../constants/errors';
 
+const CLIP_MODEL_PATH_PLACEHOLDER = 'CLIP_MODEL';
+const GGMLCLIP_PATH_PLACEHOLDER = 'GGML_PATH';
+const INPUT_PATH_PLACEHOLDER = 'INPUT';
+
+const IMAGE_EMBEDDING_EXTRACT_CMD: string[] = [
+    GGMLCLIP_PATH_PLACEHOLDER,
+    '-mv',
+    CLIP_MODEL_PATH_PLACEHOLDER,
+    '--image',
+    INPUT_PATH_PLACEHOLDER,
+];
+
+const TEXT_EMBEDDING_EXTRACT_CMD: string[] = [
+    GGMLCLIP_PATH_PLACEHOLDER,
+    '-mt',
+    CLIP_MODEL_PATH_PLACEHOLDER,
+    '--text',
+    INPUT_PATH_PLACEHOLDER,
+];
+
 const ort = require('onnxruntime-node');
 const { encode } = require('gpt-3-encoder');
 const { createCanvas, Image } = require('canvas');
-const TEXT_MODEL_DOWNLOAD_URL =
-    'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-text-vit-32-float32-int32.onnx';
-const IMAGE_MODEL_DOWNLOAD_URL =
-    'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-image-vit-32-float32.onnx';
-const TEXT_MODEL_NAME = 'clip-text-vit-32-float32-int32.onnx';
-const IMAGE_MODEL_NAME = 'clip-image-vit-32-float32.onnx';
-const IMAGE_MODEL_SIZE_IN_BYTES = 351468764; // 335.2 MB
-const TEXT_MODEL_SIZE_IN_BYTES = 254069585; // 242.3 MB
+const TEXT_MODEL_DOWNLOAD_URL = {
+    ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf',
+    onnx: 'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-text-vit-32-float32-int32.onnx',
+};
+const IMAGE_MODEL_DOWNLOAD_URL = {
+    ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-vision-model-f16.gguf',
+    onnx: 'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-image-vit-32-float32.onnx',
+};
+const TEXT_MODEL_NAME = {
+    ggml: 'clip-vit-base-patch32_ggml-text-model-f16.gguf',
+    onnx: 'clip-text-vit-32-float32-int32.onnx',
+};
+const IMAGE_MODEL_NAME = {
+    ggml: 'clip-vit-base-patch32_ggml-vision-model-f16.gguf',
+    onnx: 'clip-image-vit-32-float32.onnx',
+};
+const IMAGE_MODEL_SIZE_IN_BYTES = {
+    ggml: 175957504, // 167.8 MB
+    onnx: 351468764, // 335.2 MB
+};
+const TEXT_MODEL_SIZE_IN_BYTES = {
+    ggml: 127853440, // 121.9 MB
+    onnx: 254069585, // 242.3 MB
+};
 const MODEL_SAVE_FOLDER = 'models';
 
 function getModelSavePath(modelName: string) {
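
Note on the hunk above: the `*_EMBEDDING_EXTRACT_CMD` arrays are command templates. The `computeGGMLImageEmbedding`/`computeGGMLTextEmbedding` functions added further down substitute the three placeholder tokens and shell-escape the result before handing it to `execAsync`. A minimal sketch of that expansion, with made-up paths (the real values come from `getGGMLClipPath()` and `getClipImageModelPath('ggml')` at runtime):

    // Sketch only: both paths below are hypothetical examples.
    const ggmlclipPath = '/Applications/ente.app/Contents/Resources/ggmlclip-mac';
    const clipModelPath = '/tmp/models/clip-vit-base-patch32_ggml-vision-model-f16.gguf';
    const cmd = IMAGE_EMBEDDING_EXTRACT_CMD.map((part) =>
        part === GGMLCLIP_PATH_PLACEHOLDER
            ? ggmlclipPath
            : part === CLIP_MODEL_PATH_PLACEHOLDER
              ? clipModelPath
              : part === INPUT_PATH_PLACEHOLDER
                ? '/tmp/photo with spaces.jpeg'
                : part
    );
    // any-shell-escape quotes each argument, so the input path with spaces
    // survives as a single --image argument in the final command string.
    const escapedCmd = shellescape(cmd);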
@@ -49,9 +90,9 @@ async function downloadModel(saveLocation: string, url: string) {
 let imageModelDownloadInProgress: Promise<void> = null;
 
-export async function getClipImageModelPath() {
+export async function getClipImageModelPath(type: 'ggml' | 'onnx') {
     try {
-        const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME);
+        const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME[type]);
         if (imageModelDownloadInProgress) {
             log.info('waiting for image model download to finish');
             await imageModelDownloadInProgress;
@@ -60,19 +101,19 @@ export async function getClipImageModelPath() {
             log.info('clip image model not found, downloading');
             imageModelDownloadInProgress = downloadModel(
                 modelSavePath,
-                IMAGE_MODEL_DOWNLOAD_URL
+                IMAGE_MODEL_DOWNLOAD_URL[type]
             );
             await imageModelDownloadInProgress;
         } else {
             const localFileSize = (await fs.stat(modelSavePath)).size;
-            if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES) {
+            if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES[type]) {
                 log.info(
                     'clip image model size mismatch, downloading again got:',
                     localFileSize
                 );
                 imageModelDownloadInProgress = downloadModel(
                     modelSavePath,
-                    IMAGE_MODEL_DOWNLOAD_URL
+                    IMAGE_MODEL_DOWNLOAD_URL[type]
                 );
                 await imageModelDownloadInProgress;
             }
@@ -86,15 +127,15 @@ export async function getClipImageModelPath() {
 let textModelDownloadInProgress: boolean = false;
 
-export async function getClipTextModelPath() {
-    const modelSavePath = getModelSavePath(TEXT_MODEL_NAME);
+export async function getClipTextModelPath(type: 'ggml' | 'onnx') {
+    const modelSavePath = getModelSavePath(TEXT_MODEL_NAME[type]);
     if (textModelDownloadInProgress) {
         throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
     } else {
         if (!existsSync(modelSavePath)) {
             log.info('clip text model not found, downloading');
             textModelDownloadInProgress = true;
-            downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL)
+            downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
                 .catch(() => {
                     // ignore
                 })
@@ -104,13 +145,13 @@ export async function getClipTextModelPath() {
             throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
         } else {
             const localFileSize = (await fs.stat(modelSavePath)).size;
-            if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES) {
+            if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES[type]) {
                 log.info(
                     'clip text model size mismatch, downloading again',
                     localFileSize
                 );
                 textModelDownloadInProgress = true;
-                downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL)
+                downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
                     .catch(() => {
                         // ignore
                     })
@@ -124,6 +165,12 @@ export async function getClipTextModelPath() {
     return modelSavePath;
 }
 
+function getGGMLClipPath() {
+    return isDev
+        ? path.join('./build', `ggmlclip-${getPlatform()}`)
+        : path.join(process.resourcesPath, `ggmlclip-${getPlatform()}`);
+}
+
 async function createOnnxSession(modelPath: string) {
     return await ort.InferenceSession.create(modelPath, {
         intraOpNumThreads: 1,
@@ -135,7 +182,7 @@ let onnxImageSession: any = null;
 
 async function getOnnxImageSession() {
     if (!onnxImageSession) {
-        const clipModelPath = await getClipImageModelPath();
+        const clipModelPath = await getClipImageModelPath('onnx');
         onnxImageSession = createOnnxSession(clipModelPath);
     }
     return onnxImageSession;
@@ -145,7 +192,7 @@ let onnxTextSession: any = null;
 
 async function getOnnxTextSession() {
     if (!onnxTextSession) {
-        const clipModelPath = await getClipTextModelPath();
+        const clipModelPath = await getClipTextModelPath('onnx');
         onnxTextSession = createOnnxSession(clipModelPath);
     }
     return onnxTextSession;
@@ -153,6 +200,55 @@ async function getOnnxTextSession() {
 
 export async function computeImageEmbedding(
     inputFilePath: string
+): Promise<Float32Array> {
+    const ggmlImageEmbedding = await computeGGMLImageEmbedding(inputFilePath);
+    const onnxImageEmbedding = await computeONNXImageEmbedding(inputFilePath);
+    const score = await computeClipMatchScore(
+        ggmlImageEmbedding,
+        onnxImageEmbedding
+    );
+    console.log('imageEmbeddingScore', score);
+    return onnxImageEmbedding;
+}
+
+export async function computeGGMLImageEmbedding(
+    inputFilePath: string
+): Promise<Float32Array> {
+    try {
+        const clipModelPath = await getClipImageModelPath('ggml');
+        const ggmlclipPath = getGGMLClipPath();
+        const cmd = IMAGE_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
+            if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
+                return ggmlclipPath;
+            } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
+                return clipModelPath;
+            } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
+                return inputFilePath;
+            } else {
+                return cmdPart;
+            }
+        });
+
+        const escapedCmd = shellescape(cmd);
+        log.info('running clip command', escapedCmd);
+        const startTime = Date.now();
+        const { stdout } = await execAsync(escapedCmd);
+        log.info('clip command execution time ', Date.now() - startTime);
+        // parse stdout and return embedding
+        // get the last line of stdout
+        const lines = stdout.split('\n');
+        const lastLine = lines[lines.length - 1];
+        const embedding = JSON.parse(lastLine);
+        const embeddingArray = new Float32Array(embedding);
+        return embeddingArray;
+    } catch (err) {
+        logErrorSentry(err, 'Error in computeImageEmbedding');
+        throw err;
+    }
+}
+
+export async function computeONNXImageEmbedding(
+    inputFilePath: string
 ): Promise<Float32Array> {
     try {
         const imageSession = await getOnnxImageSession();
@@ -170,6 +266,59 @@ export async function computeImageEmbedding(
 }
 
 export async function computeTextEmbedding(
+    inputFilePath: string
+): Promise<Float32Array> {
+    const ggmlImageEmbedding = await computeGGMLTextEmbedding(inputFilePath);
+    const onnxImageEmbedding = await computeONNXTextEmbedding(inputFilePath);
+    const score = await computeClipMatchScore(
+        ggmlImageEmbedding,
+        onnxImageEmbedding
+    );
+    console.log('textEmbeddingScore', score);
+    return onnxImageEmbedding;
+}
+
+export async function computeGGMLTextEmbedding(
+    text: string
+): Promise<Float32Array> {
+    try {
+        const clipModelPath = await getClipTextModelPath('ggml');
+        const ggmlclipPath = getGGMLClipPath();
+        const cmd = TEXT_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
+            if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
+                return ggmlclipPath;
+            } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
+                return clipModelPath;
+            } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
+                return text;
+            } else {
+                return cmdPart;
+            }
+        });
+
+        const escapedCmd = shellescape(cmd);
+        log.info('running clip command', escapedCmd);
+        const startTime = Date.now();
+        const { stdout } = await execAsync(escapedCmd);
+        log.info('clip command execution time ', Date.now() - startTime);
+        // parse stdout and return embedding
+        // get the last line of stdout
+        const lines = stdout.split('\n');
+        const lastLine = lines[lines.length - 1];
+        const embedding = JSON.parse(lastLine);
+        const embeddingArray = new Float32Array(embedding);
+        return embeddingArray;
+    } catch (err) {
+        if (err.message === CustomErrors.MODEL_DOWNLOAD_PENDING) {
+            log.info(CustomErrors.MODEL_DOWNLOAD_PENDING);
+        } else {
+            logErrorSentry(err, 'Error in computeTextEmbedding');
+        }
+        throw err;
+    }
+}
+
+export async function computeONNXTextEmbedding(
     text: string
 ): Promise<Float32Array> {
     try {
@@ -239,3 +388,30 @@ async function getRgbData(inputFilePath: string) {
     }
     return Float32Array.from(rgbData.flat().flat());
 }
+
+export const computeClipMatchScore = async (
+    imageEmbedding: Float32Array,
+    textEmbedding: Float32Array
+) => {
+    if (imageEmbedding.length !== textEmbedding.length) {
+        throw Error('imageEmbedding and textEmbedding length mismatch');
+    }
+    let score = 0;
+    let imageNormalization = 0;
+    let textNormalization = 0;
+
+    for (let index = 0; index < imageEmbedding.length; index++) {
+        imageNormalization += imageEmbedding[index] * imageEmbedding[index];
+        textNormalization += textEmbedding[index] * textEmbedding[index];
+    }
+    for (let index = 0; index < imageEmbedding.length; index++) {
+        imageEmbedding[index] =
+            imageEmbedding[index] / Math.sqrt(imageNormalization);
+        textEmbedding[index] =
+            textEmbedding[index] / Math.sqrt(textNormalization);
+    }
+    for (let index = 0; index < imageEmbedding.length; index++) {
+        score += imageEmbedding[index] * textEmbedding[index];
+    }
+    return score;
+};
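
`computeClipMatchScore` is cosine similarity: it L2-normalizes both vectors and then takes their dot product, which is how the `imageEmbeddingScore` and `textEmbeddingScore` logs compare the ggml and onnx embeddings of the same input (a score near 1 means the two backends agree). Note that the normalization loop writes back into the arrays it was given, so the caller's embeddings are mutated. A minimal non-mutating equivalent, for reference:

    // Sketch: cosine similarity without mutating the input arrays.
    const cosineSimilarity = (a: Float32Array, b: Float32Array): number => {
        if (a.length !== b.length) throw Error('embedding length mismatch');
        let dot = 0;
        let normA = 0;
        let normB = 0;
        for (let i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    };
    // e.g. cosineSimilarity(ggmlImageEmbedding, onnxImageEmbedding) yields
    // the same score as computeClipMatchScore, with no side effects.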