add back ggml clip logic
parent 6d7d513702
commit aa52b769d8
1 changed file with 196 additions and 20 deletions
@@ -1,27 +1,68 @@
import * as log from 'electron-log';
import util from 'util';
import { logErrorSentry } from './sentry';
import { isDev } from '../utils/common';
import { app } from 'electron';
import path from 'path';
import { existsSync } from 'fs';
import fs from 'fs/promises';
const shellescape = require('any-shell-escape');
const execAsync = util.promisify(require('child_process').exec);
import fetch from 'node-fetch';
import { writeNodeStream } from './fs';
import { getPlatform } from '../utils/common/platform';
import { CustomErrors } from '../constants/errors';

const CLIP_MODEL_PATH_PLACEHOLDER = 'CLIP_MODEL';
const GGMLCLIP_PATH_PLACEHOLDER = 'GGML_PATH';
const INPUT_PATH_PLACEHOLDER = 'INPUT';

const IMAGE_EMBEDDING_EXTRACT_CMD: string[] = [
    GGMLCLIP_PATH_PLACEHOLDER,
    '-mv',
    CLIP_MODEL_PATH_PLACEHOLDER,
    '--image',
    INPUT_PATH_PLACEHOLDER,
];

const TEXT_EMBEDDING_EXTRACT_CMD: string[] = [
    GGMLCLIP_PATH_PLACEHOLDER,
    '-mt',
    CLIP_MODEL_PATH_PLACEHOLDER,
    '--text',
    INPUT_PATH_PLACEHOLDER,
];
const ort = require('onnxruntime-node');
const { encode } = require('gpt-3-encoder');
const { createCanvas, Image } = require('canvas');

const TEXT_MODEL_DOWNLOAD_URL =
    'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-text-vit-32-float32-int32.onnx';
const IMAGE_MODEL_DOWNLOAD_URL =
    'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-image-vit-32-float32.onnx';
const TEXT_MODEL_DOWNLOAD_URL = {
    ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf',
    onnx: 'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-text-vit-32-float32-int32.onnx',
};
const IMAGE_MODEL_DOWNLOAD_URL = {
    ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-vision-model-f16.gguf',
    onnx: 'https://huggingface.co/rocca/openai-clip-js/resolve/main/clip-image-vit-32-float32.onnx',
};

const TEXT_MODEL_NAME = 'clip-text-vit-32-float32-int32.onnx';
const IMAGE_MODEL_NAME = 'clip-image-vit-32-float32.onnx';
const TEXT_MODEL_NAME = {
    ggml: 'clip-vit-base-patch32_ggml-text-model-f16.gguf',
    onnx: 'clip-text-vit-32-float32-int32.onnx',
};
const IMAGE_MODEL_NAME = {
    ggml: 'clip-vit-base-patch32_ggml-vision-model-f16.gguf',
    onnx: 'clip-image-vit-32-float32.onnx',
};

const IMAGE_MODEL_SIZE_IN_BYTES = {
    ggml: 175957504, // 167.8 MB
    onnx: 351468764, // 335.2 MB
};
const TEXT_MODEL_SIZE_IN_BYTES = {
    ggml: 127853440, // 121.9 MB,
    onnx: 254069585, // 242.3 MB
};

const IMAGE_MODEL_SIZE_IN_BYTES = 351468764; // 335.2 MB
const TEXT_MODEL_SIZE_IN_BYTES = 254069585; // 242.3 MB
const MODEL_SAVE_FOLDER = 'models';

function getModelSavePath(modelName: string) {
@@ -49,9 +90,9 @@ async function downloadModel(saveLocation: string, url: string) {

let imageModelDownloadInProgress: Promise<void> = null;

export async function getClipImageModelPath() {
export async function getClipImageModelPath(type: 'ggml' | 'onnx') {
    try {
        const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME);
        const modelSavePath = getModelSavePath(IMAGE_MODEL_NAME[type]);
        if (imageModelDownloadInProgress) {
            log.info('waiting for image model download to finish');
            await imageModelDownloadInProgress;
@@ -60,19 +101,19 @@ export async function getClipImageModelPath() {
                log.info('clip image model not found, downloading');
                imageModelDownloadInProgress = downloadModel(
                    modelSavePath,
                    IMAGE_MODEL_DOWNLOAD_URL
                    IMAGE_MODEL_DOWNLOAD_URL[type]
                );
                await imageModelDownloadInProgress;
            } else {
                const localFileSize = (await fs.stat(modelSavePath)).size;
                if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES) {
                if (localFileSize !== IMAGE_MODEL_SIZE_IN_BYTES[type]) {
                    log.info(
                        'clip image model size mismatch, downloading again got:',
                        localFileSize
                    );
                    imageModelDownloadInProgress = downloadModel(
                        modelSavePath,
                        IMAGE_MODEL_DOWNLOAD_URL
                        IMAGE_MODEL_DOWNLOAD_URL[type]
                    );
                    await imageModelDownloadInProgress;
                }
@@ -86,15 +127,15 @@ export async function getClipImageModelPath() {

let textModelDownloadInProgress: boolean = false;

export async function getClipTextModelPath() {
    const modelSavePath = getModelSavePath(TEXT_MODEL_NAME);
export async function getClipTextModelPath(type: 'ggml' | 'onnx') {
    const modelSavePath = getModelSavePath(TEXT_MODEL_NAME[type]);
    if (textModelDownloadInProgress) {
        throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
    } else {
        if (!existsSync(modelSavePath)) {
            log.info('clip text model not found, downloading');
            textModelDownloadInProgress = true;
            downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL)
            downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
                .catch(() => {
                    // ignore
                })
@@ -104,13 +145,13 @@ export async function getClipTextModelPath() {
            throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
        } else {
            const localFileSize = (await fs.stat(modelSavePath)).size;
            if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES) {
            if (localFileSize !== TEXT_MODEL_SIZE_IN_BYTES[type]) {
                log.info(
                    'clip text model size mismatch, downloading again',
                    localFileSize
                );
                textModelDownloadInProgress = true;
                downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL)
                downloadModel(modelSavePath, TEXT_MODEL_DOWNLOAD_URL[type])
                    .catch(() => {
                        // ignore
                    })
@@ -124,6 +165,12 @@ export async function getClipTextModelPath() {
    return modelSavePath;
}

function getGGMLClipPath() {
    return isDev
        ? path.join('./build', `ggmlclip-${getPlatform()}`)
        : path.join(process.resourcesPath, `ggmlclip-${getPlatform()}`);
}

async function createOnnxSession(modelPath: string) {
    return await ort.InferenceSession.create(modelPath, {
        intraOpNumThreads: 1,
@@ -135,7 +182,7 @@ let onnxImageSession: any = null;

async function getOnnxImageSession() {
    if (!onnxImageSession) {
        const clipModelPath = await getClipImageModelPath();
        const clipModelPath = await getClipImageModelPath('onnx');
        onnxImageSession = createOnnxSession(clipModelPath);
    }
    return onnxImageSession;
@@ -145,7 +192,7 @@ let onnxTextSession: any = null;

async function getOnnxTextSession() {
    if (!onnxTextSession) {
        const clipModelPath = await getClipTextModelPath();
        const clipModelPath = await getClipTextModelPath('onnx');
        onnxTextSession = createOnnxSession(clipModelPath);
    }
    return onnxTextSession;
@@ -153,6 +200,55 @@ async function getOnnxTextSession() {

export async function computeImageEmbedding(
    inputFilePath: string
): Promise<Float32Array> {
    const ggmlImageEmbedding = await computeGGMLImageEmbedding(inputFilePath);
    const onnxImageEmbedding = await computeONNXImageEmbedding(inputFilePath);
    const score = await computeClipMatchScore(
        ggmlImageEmbedding,
        onnxImageEmbedding
    );
    console.log('imageEmbeddingScore', score);
    return onnxImageEmbedding;
}

export async function computeGGMLImageEmbedding(
    inputFilePath: string
): Promise<Float32Array> {
    try {
        const clipModelPath = await getClipImageModelPath('ggml');
        const ggmlclipPath = getGGMLClipPath();
        const cmd = IMAGE_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
            if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
                return ggmlclipPath;
            } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
                return clipModelPath;
            } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
                return inputFilePath;
            } else {
                return cmdPart;
            }
        });

        const escapedCmd = shellescape(cmd);
        log.info('running clip command', escapedCmd);
        const startTime = Date.now();
        const { stdout } = await execAsync(escapedCmd);
        log.info('clip command execution time ', Date.now() - startTime);
        // parse stdout and return embedding
        // get the last line of stdout
        const lines = stdout.split('\n');
        const lastLine = lines[lines.length - 1];
        const embedding = JSON.parse(lastLine);
        const embeddingArray = new Float32Array(embedding);
        return embeddingArray;
    } catch (err) {
        logErrorSentry(err, 'Error in computeImageEmbedding');
        throw err;
    }
}

export async function computeONNXImageEmbedding(
    inputFilePath: string
): Promise<Float32Array> {
    try {
        const imageSession = await getOnnxImageSession();
@@ -170,6 +266,59 @@ export async function computeImageEmbedding(
}

export async function computeTextEmbedding(
    inputFilePath: string
): Promise<Float32Array> {
    const ggmlImageEmbedding = await computeGGMLTextEmbedding(inputFilePath);
    const onnxImageEmbedding = await computeONNXTextEmbedding(inputFilePath);
    const score = await computeClipMatchScore(
        ggmlImageEmbedding,
        onnxImageEmbedding
    );
    console.log('textEmbeddingScore', score);
    return onnxImageEmbedding;
}

export async function computeGGMLTextEmbedding(
    text: string
): Promise<Float32Array> {
    try {
        const clipModelPath = await getClipTextModelPath('ggml');
        const ggmlclipPath = getGGMLClipPath();
        const cmd = TEXT_EMBEDDING_EXTRACT_CMD.map((cmdPart) => {
            if (cmdPart === GGMLCLIP_PATH_PLACEHOLDER) {
                return ggmlclipPath;
            } else if (cmdPart === CLIP_MODEL_PATH_PLACEHOLDER) {
                return clipModelPath;
            } else if (cmdPart === INPUT_PATH_PLACEHOLDER) {
                return text;
            } else {
                return cmdPart;
            }
        });

        const escapedCmd = shellescape(cmd);
        log.info('running clip command', escapedCmd);
        const startTime = Date.now();
        const { stdout } = await execAsync(escapedCmd);
        log.info('clip command execution time ', Date.now() - startTime);
        // parse stdout and return embedding
        // get the last line of stdout
        const lines = stdout.split('\n');
        const lastLine = lines[lines.length - 1];
        const embedding = JSON.parse(lastLine);
        const embeddingArray = new Float32Array(embedding);
        return embeddingArray;
    } catch (err) {
        if (err.message === CustomErrors.MODEL_DOWNLOAD_PENDING) {
            log.info(CustomErrors.MODEL_DOWNLOAD_PENDING);
        } else {
            logErrorSentry(err, 'Error in computeTextEmbedding');
        }
        throw err;
    }
}

export async function computeONNXTextEmbedding(
    text: string
): Promise<Float32Array> {
    try {
@@ -239,3 +388,30 @@ async function getRgbData(inputFilePath: string) {
    }
    return Float32Array.from(rgbData.flat().flat());
}

export const computeClipMatchScore = async (
    imageEmbedding: Float32Array,
    textEmbedding: Float32Array
) => {
    if (imageEmbedding.length !== textEmbedding.length) {
        throw Error('imageEmbedding and textEmbedding length mismatch');
    }
    let score = 0;
    let imageNormalization = 0;
    let textNormalization = 0;

    for (let index = 0; index < imageEmbedding.length; index++) {
        imageNormalization += imageEmbedding[index] * imageEmbedding[index];
        textNormalization += textEmbedding[index] * textEmbedding[index];
    }
    for (let index = 0; index < imageEmbedding.length; index++) {
        imageEmbedding[index] =
            imageEmbedding[index] / Math.sqrt(imageNormalization);
        textEmbedding[index] =
            textEmbedding[index] / Math.sqrt(textNormalization);
    }
    for (let index = 0; index < imageEmbedding.length; index++) {
        score += imageEmbedding[index] * textEmbedding[index];
    }
    return score;
};