diff --git a/desktop/src/main/services/ml.ts b/desktop/src/main/services/ml.ts index 8292596a2..6b38bc74d 100644 --- a/desktop/src/main/services/ml.ts +++ b/desktop/src/main/services/ml.ts @@ -34,6 +34,7 @@ import { writeStream } from "../stream"; * actively trigger a download until the returned function is called. * * @param modelName The name of the model to download. + * * @param modelByteSize The size in bytes that we expect the model to have. If * the size of the downloaded model does not match the expected size, then we * will redownload it. @@ -99,13 +100,15 @@ const downloadModel = async (saveLocation: string, name: string) => { // `mkdir -p` the directory where we want to save the model. const saveDir = path.dirname(saveLocation); await fs.mkdir(saveDir, { recursive: true }); - // Download + // Download. log.info(`Downloading ML model from ${name}`); const url = `https://models.ente.io/${name}`; const res = await net.fetch(url); if (!res.ok) throw new Error(`Failed to fetch ${url}: HTTP ${res.status}`); - // Save - await writeStream(saveLocation, res.body); + const body = res.body; + if (!body) throw new Error(`Received an null response for ${url}`); + // Save. + await writeStream(saveLocation, body); log.info(`Downloaded CLIP model ${name}`); }; @@ -114,9 +117,9 @@ const downloadModel = async (saveLocation: string, name: string) => { */ const createInferenceSession = async (modelPath: string) => { return await ort.InferenceSession.create(modelPath, { - // Restrict the number of threads to 1 + // Restrict the number of threads to 1. intraOpNumThreads: 1, - // Be more conservative with RAM usage + // Be more conservative with RAM usage. enableCpuMemArena: false, }); };