Split
This commit is contained in:
parent
d3eb85be8d
commit
2b6047a979
3 changed files with 404 additions and 0 deletions
248
desktop/src/main/services/ml-clip.ts
Normal file
248
desktop/src/main/services/ml-clip.ts
Normal file
|
@@ -0,0 +1,248 @@
|
|||
/**
|
||||
* @file Compute CLIP embeddings for images and text.
|
||||
*
|
||||
* The embeddings are computed using ONNX runtime, with CLIP as the model.
|
||||
*
|
||||
* @see `web/apps/photos/src/services/clip-service.ts` for more details.
|
||||
*/
|
||||
import { existsSync } from "fs";
|
||||
import jpeg from "jpeg-js";
|
||||
import fs from "node:fs/promises";
|
||||
import * as ort from "onnxruntime-node";
|
||||
import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
|
||||
import { CustomErrors } from "../../types/ipc";
|
||||
import { writeStream } from "../fs";
|
||||
import log from "../log";
|
||||
import { generateTempFilePath } from "../temp";
|
||||
import { deleteTempFile } from "./ffmpeg";
|
||||
import {
|
||||
createInferenceSession,
|
||||
downloadModel,
|
||||
modelPathDownloadingIfNeeded,
|
||||
modelSavePath,
|
||||
} from "./ml";
|
||||
|
||||
// CLIP text encoder model (ViT-B/32, quantized to uint8).
const textModelName = "clip-text-vit-32-uint8.onnx";
const textModelByteSize = 64173509; // 61.2 MB

// CLIP image encoder model (ViT-B/32, float32).
const imageModelName = "clip-image-vit-32-float32.onnx";
const imageModelByteSize = 351468764; // 335.2 MB
|
||||
|
||||
let activeImageModelDownload: Promise<string> | undefined;
|
||||
|
||||
const imageModelPathDownloadingIfNeeded = async () => {
|
||||
try {
|
||||
if (activeImageModelDownload) {
|
||||
log.info("Waiting for CLIP image model download to finish");
|
||||
await activeImageModelDownload;
|
||||
} else {
|
||||
activeImageModelDownload = modelPathDownloadingIfNeeded(
|
||||
imageModelName,
|
||||
imageModelByteSize,
|
||||
);
|
||||
return await activeImageModelDownload;
|
||||
}
|
||||
} finally {
|
||||
activeImageModelDownload = undefined;
|
||||
}
|
||||
};
|
||||
|
||||
let textModelDownloadInProgress = false;
|
||||
|
||||
/* TODO(MR): use the generic method. Then we can remove the exports for the
|
||||
internal details functions that we use here */
|
||||
const textModelPathDownloadingIfNeeded = async () => {
|
||||
if (textModelDownloadInProgress)
|
||||
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
|
||||
|
||||
const modelPath = modelSavePath(textModelName);
|
||||
if (!existsSync(modelPath)) {
|
||||
log.info("CLIP text model not found, downloading");
|
||||
textModelDownloadInProgress = true;
|
||||
downloadModel(modelPath, textModelName)
|
||||
.catch((e) => {
|
||||
// log but otherwise ignore
|
||||
log.error("CLIP text model download failed", e);
|
||||
})
|
||||
.finally(() => {
|
||||
textModelDownloadInProgress = false;
|
||||
});
|
||||
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
|
||||
} else {
|
||||
const localFileSize = (await fs.stat(modelPath)).size;
|
||||
if (localFileSize !== textModelByteSize) {
|
||||
log.error(
|
||||
`CLIP text model size ${localFileSize} does not match the expected size, downloading again`,
|
||||
);
|
||||
textModelDownloadInProgress = true;
|
||||
downloadModel(modelPath, textModelName)
|
||||
.catch((e) => {
|
||||
// log but otherwise ignore
|
||||
log.error("CLIP text model download failed", e);
|
||||
})
|
||||
.finally(() => {
|
||||
textModelDownloadInProgress = false;
|
||||
});
|
||||
throw Error(CustomErrors.MODEL_DOWNLOAD_PENDING);
|
||||
}
|
||||
}
|
||||
|
||||
return modelPath;
|
||||
};
|
||||
|
||||
let imageSessionPromise: Promise<any> | undefined;
|
||||
|
||||
const onnxImageSession = async () => {
|
||||
if (!imageSessionPromise) {
|
||||
imageSessionPromise = (async () => {
|
||||
const modelPath = await imageModelPathDownloadingIfNeeded();
|
||||
return createInferenceSession(modelPath);
|
||||
})();
|
||||
}
|
||||
return imageSessionPromise;
|
||||
};
|
||||
|
||||
let _textSession: any = null;
|
||||
|
||||
const onnxTextSession = async () => {
|
||||
if (!_textSession) {
|
||||
const modelPath = await textModelPathDownloadingIfNeeded();
|
||||
_textSession = await createInferenceSession(modelPath);
|
||||
}
|
||||
return _textSession;
|
||||
};
|
||||
|
||||
export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
|
||||
const tempFilePath = await generateTempFilePath("");
|
||||
const imageStream = new Response(jpegImageData.buffer).body;
|
||||
await writeStream(tempFilePath, imageStream);
|
||||
try {
|
||||
return await clipImageEmbedding_(tempFilePath);
|
||||
} finally {
|
||||
await deleteTempFile(tempFilePath);
|
||||
}
|
||||
};
|
||||
|
||||
const clipImageEmbedding_ = async (jpegFilePath: string) => {
|
||||
const imageSession = await onnxImageSession();
|
||||
const t1 = Date.now();
|
||||
const rgbData = await getRGBData(jpegFilePath);
|
||||
const feeds = {
|
||||
input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]),
|
||||
};
|
||||
const t2 = Date.now();
|
||||
const results = await imageSession.run(feeds);
|
||||
log.debug(
|
||||
() =>
|
||||
`CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
|
||||
);
|
||||
const imageEmbedding = results["output"].data; // Float32Array
|
||||
return normalizeEmbedding(imageEmbedding);
|
||||
};
|
||||
|
||||
/**
 * Read the JPEG at {@link jpegFilePath} and convert it into the flat RGB
 * float array expected by the CLIP image model.
 *
 * The image is decoded, resized preserving aspect ratio so that its longer
 * side becomes 224 px (bilinear interpolation), and normalized with the CLIP
 * per-channel mean/std. The result has 3 * 224 * 224 entries laid out one
 * channel after another (all R, then G, then B); entries outside the resized
 * region remain 0.
 */
const getRGBData = async (jpegFilePath: string) => {
    const jpegData = await fs.readFile(jpegFilePath);
    // Decode into interleaved RGB bytes (no alpha channel).
    const rawImageData = jpeg.decode(jpegData, {
        useTArray: true,
        formatAsRGBA: false,
    });

    // Source dimensions and interleaved RGB pixel data.
    const nx: number = rawImageData.width;
    const ny: number = rawImageData.height;
    const inputImage: Uint8Array = rawImageData.data;

    // Target dimensions expected by the CLIP image encoder.
    const nx2: number = 224;
    const ny2: number = 224;
    const totalSize: number = 3 * nx2 * ny2;

    // Zero-filled output; pixels not covered by the resized image stay 0.
    const result: number[] = Array(totalSize).fill(0);
    // Scale factor mapping the longer source side onto 224 px.
    const scale: number = Math.max(nx, ny) / 224;

    // Dimensions of the resized image (the shorter side ends up <= 224).
    const nx3: number = Math.round(nx / scale);
    const ny3: number = Math.round(ny / scale);

    // CLIP's per-channel normalization constants.
    const mean: number[] = [0.48145466, 0.4578275, 0.40821073];
    const std: number[] = [0.26862954, 0.26130258, 0.27577711];

    for (let y = 0; y < ny3; y++) {
        for (let x = 0; x < nx3; x++) {
            for (let c = 0; c < 3; c++) {
                // Linear interpolation
                // Map the destination pixel center back into source space.
                const sx: number = (x + 0.5) * scale - 0.5;
                const sy: number = (y + 0.5) * scale - 0.5;

                const x0: number = Math.max(0, Math.floor(sx));
                const y0: number = Math.max(0, Math.floor(sy));

                // Clamp the second sample to the image edge.
                const x1: number = Math.min(x0 + 1, nx - 1);
                const y1: number = Math.min(y0 + 1, ny - 1);

                // Fractional offsets used as interpolation weights.
                const dx: number = sx - x0;
                const dy: number = sy - y0;

                // Indices of the four neighboring source samples
                // (interleaved RGB layout, hence the factor 3).
                const j00: number = 3 * (y0 * nx + x0) + c;
                const j01: number = 3 * (y0 * nx + x1) + c;
                const j10: number = 3 * (y1 * nx + x0) + c;
                const j11: number = 3 * (y1 * nx + x1) + c;

                const v00: number = inputImage[j00];
                const v01: number = inputImage[j01];
                const v10: number = inputImage[j10];
                const v11: number = inputImage[j11];

                // Bilinear blend: first along x, then along y.
                const v0: number = v00 * (1 - dx) + v01 * dx;
                const v1: number = v10 * (1 - dx) + v11 * dx;

                const v: number = v0 * (1 - dy) + v1 * dy;

                // Round and clamp back to a byte value before normalizing.
                const v2: number = Math.min(Math.max(Math.round(v), 0), 255);

                // createTensorWithDataList is dumb compared to reshape and
                // hence has to be given with one channel after another
                // NOTE(review): the row stride here is nx3 (the resized
                // width), not 224, so rows are packed at the resized width
                // within each 224*224 channel plane — confirm this matches
                // the layout the exported model expects.
                const i: number = y * nx3 + x + (c % 3) * 224 * 224;

                result[i] = (v2 / 255 - mean[c]) / std[c];
            }
        }
    }

    return result;
};
|
||||
|
||||
const normalizeEmbedding = (embedding: Float32Array) => {
|
||||
let normalization = 0;
|
||||
for (let index = 0; index < embedding.length; index++) {
|
||||
normalization += embedding[index] * embedding[index];
|
||||
}
|
||||
const sqrtNormalization = Math.sqrt(normalization);
|
||||
for (let index = 0; index < embedding.length; index++) {
|
||||
embedding[index] = embedding[index] / sqrtNormalization;
|
||||
}
|
||||
return embedding;
|
||||
};
|
||||
|
||||
let _tokenizer: Tokenizer = null;
|
||||
const getTokenizer = () => {
|
||||
if (!_tokenizer) {
|
||||
_tokenizer = new Tokenizer();
|
||||
}
|
||||
return _tokenizer;
|
||||
};
|
||||
|
||||
export const clipTextEmbedding = async (text: string) => {
|
||||
const imageSession = await onnxTextSession();
|
||||
const t1 = Date.now();
|
||||
const tokenizer = getTokenizer();
|
||||
const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
|
||||
const feeds = {
|
||||
input: new ort.Tensor("int32", tokenizedText, [1, 77]),
|
||||
};
|
||||
const t2 = Date.now();
|
||||
const results = await imageSession.run(feeds);
|
||||
log.debug(
|
||||
() =>
|
||||
`CLIP text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
|
||||
);
|
||||
const textEmbedding = results["output"].data;
|
||||
return normalizeEmbedding(textEmbedding);
|
||||
};
|
77
desktop/src/main/services/ml-face.ts
Normal file
77
desktop/src/main/services/ml-face.ts
Normal file
|
@@ -0,0 +1,77 @@
|
|||
/**
|
||||
* @file Various face recognition related tasks.
|
||||
*
|
||||
* - Face detection with the YOLO model.
|
||||
* - Face embedding with the mobilefacenet model.
|
||||
*
|
||||
* The runtime used is ONNX.
|
||||
*/
|
||||
import * as ort from "onnxruntime-node";
|
||||
import log from "../log";
|
||||
import { createInferenceSession, modelPathDownloadingIfNeeded } from "./ml";
|
||||
|
||||
// YOLO v5s face detection model with dynamic 640x640 input dimensions.
const faceDetectionModelName = "yolov5s_face_640_640_dynamic.onnx";
const faceDetectionModelByteSize = 30762872; // 29.3 MB

// MobileFaceNet face embedding model (ONNX opset 15).
const faceEmbeddingModelName = "mobilefacenet_opset15.onnx";
const faceEmbeddingModelByteSize = 5286998; // 5 MB
|
||||
|
||||
let activeFaceDetectionModelDownload: Promise<string> | undefined;
|
||||
|
||||
const faceDetectionModelPathDownloadingIfNeeded = async () => {
|
||||
try {
|
||||
if (activeFaceDetectionModelDownload) {
|
||||
log.info("Waiting for face detection model download to finish");
|
||||
await activeFaceDetectionModelDownload;
|
||||
} else {
|
||||
activeFaceDetectionModelDownload = modelPathDownloadingIfNeeded(
|
||||
faceDetectionModelName,
|
||||
faceDetectionModelByteSize,
|
||||
);
|
||||
return await activeFaceDetectionModelDownload;
|
||||
}
|
||||
} finally {
|
||||
activeFaceDetectionModelDownload = undefined;
|
||||
}
|
||||
};
|
||||
|
||||
let _faceDetectionSession: Promise<ort.InferenceSession> | undefined;
|
||||
|
||||
const faceDetectionSession = async () => {
|
||||
if (!_faceDetectionSession) {
|
||||
_faceDetectionSession =
|
||||
faceDetectionModelPathDownloadingIfNeeded().then((modelPath) =>
|
||||
createInferenceSession(modelPath),
|
||||
);
|
||||
}
|
||||
return _faceDetectionSession;
|
||||
};
|
||||
|
||||
|
||||
// export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
|
||||
// const tempFilePath = await generateTempFilePath("");
|
||||
// const imageStream = new Response(jpegImageData.buffer).body;
|
||||
// await writeStream(tempFilePath, imageStream);
|
||||
// try {
|
||||
// return await clipImageEmbedding_(tempFilePath);
|
||||
// } finally {
|
||||
// await deleteTempFile(tempFilePath);
|
||||
// }
|
||||
// };
|
||||
|
||||
// const clipImageEmbedding_ = async (jpegFilePath: string) => {
|
||||
// const imageSession = await onnxImageSession();
|
||||
// const t1 = Date.now();
|
||||
// const rgbData = await getRGBData(jpegFilePath);
|
||||
// const feeds = {
|
||||
// input: new ort.Tensor("float32", rgbData, [1, 3, 224, 224]),
|
||||
// };
|
||||
// const t2 = Date.now();
|
||||
// const results = await imageSession.run(feeds);
|
||||
// log.debug(
|
||||
// () =>
|
||||
// `CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
|
||||
// );
|
||||
// const imageEmbedding = results["output"].data; // Float32Array
|
||||
// return normalizeEmbedding(imageEmbedding);
|
||||
// };
|
79
desktop/src/main/services/ml.ts
Normal file
79
desktop/src/main/services/ml.ts
Normal file
|
@@ -0,0 +1,79 @@
|
|||
/**
|
||||
* @file AI/ML related functionality.
|
||||
*
|
||||
* @see also `ml-clip.ts`, `ml-face.ts`.
|
||||
*
|
||||
* The ML runtime we use for inference is [ONNX](https://onnxruntime.ai). Models
|
||||
* for various tasks are not shipped with the app but are downloaded on demand.
|
||||
*
|
||||
* The primary reason for doing these tasks in the Node.js layer is so that we
|
||||
* can use the binary ONNX runtime which is 10-20x faster than the WASM based
|
||||
* web one.
|
||||
*/
|
||||
import { app, net } from "electron/main";
|
||||
import { existsSync } from "fs";
|
||||
import fs from "node:fs/promises";
|
||||
import path from "node:path";
|
||||
import * as ort from "onnxruntime-node";
|
||||
import { writeStream } from "../fs";
|
||||
import log from "../log";
|
||||
|
||||
/**
|
||||
* Download the model named {@link modelName} if we don't already have it.
|
||||
*
|
||||
* Also verify that the size of the model we get matches {@expectedByteSize} (if
|
||||
* not, redownload it).
|
||||
*
|
||||
* @returns the path to the model on the local machine.
|
||||
*/
|
||||
export const modelPathDownloadingIfNeeded = async (
|
||||
modelName: string,
|
||||
expectedByteSize: number,
|
||||
) => {
|
||||
const modelPath = modelSavePath(modelName);
|
||||
|
||||
if (!existsSync(modelPath)) {
|
||||
log.info("CLIP image model not found, downloading");
|
||||
await downloadModel(modelPath, modelName);
|
||||
} else {
|
||||
const size = (await fs.stat(modelPath)).size;
|
||||
if (size !== expectedByteSize) {
|
||||
log.error(
|
||||
`The size ${size} of model ${modelName} does not match the expected size, downloading again`,
|
||||
);
|
||||
await downloadModel(modelPath, modelName);
|
||||
}
|
||||
}
|
||||
|
||||
return modelPath;
|
||||
};
|
||||
|
||||
/** Return the path where the given {@link modelName} is meant to be saved */
|
||||
export const modelSavePath = (modelName: string) =>
|
||||
path.join(app.getPath("userData"), "models", modelName);
|
||||
|
||||
export const downloadModel = async (saveLocation: string, name: string) => {
|
||||
// `mkdir -p` the directory where we want to save the model.
|
||||
const saveDir = path.dirname(saveLocation);
|
||||
await fs.mkdir(saveDir, { recursive: true });
|
||||
// Download
|
||||
log.info(`Downloading ML model from ${name}`);
|
||||
const url = `https://models.ente.io/${name}`;
|
||||
const res = await net.fetch(url);
|
||||
if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`);
|
||||
// Save
|
||||
await writeStream(saveLocation, res.body);
|
||||
log.info(`Downloaded CLIP model ${name}`);
|
||||
};
|
||||
|
||||
/**
|
||||
* Crete an ONNX {@link InferenceSession} with some defaults.
|
||||
*/
|
||||
export const createInferenceSession = async (modelPath: string) => {
|
||||
return await ort.InferenceSession.create(modelPath, {
|
||||
// Restrict the number of threads to 1
|
||||
intraOpNumThreads: 1,
|
||||
// Be more conservative with RAM usage
|
||||
enableCpuMemArena: false,
|
||||
});
|
||||
};
|
Loading…
Add table
Reference in a new issue