[web] ML cleanup - Part 7/x (#1771)

This commit is contained in:
Manav Rathi 2024-05-19 20:19:59 +05:30 committed by GitHub
commit 69f06f753c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 566 additions and 855 deletions

View file

@ -9,7 +9,7 @@ import { useCallback, useContext, useEffect, useRef, useState } from "react";
import { components } from "react-select"; import { components } from "react-select";
import AsyncSelect from "react-select/async"; import AsyncSelect from "react-select/async";
import { InputActionMeta } from "react-select/src/types"; import { InputActionMeta } from "react-select/src/types";
import { Person } from "services/face/types"; import type { Person } from "services/face/people";
import { City } from "services/locationSearchService"; import { City } from "services/locationSearchService";
import { import {
getAutoCompleteSuggestions, getAutoCompleteSuggestions,

View file

@ -1,10 +1,11 @@
import { blobCache } from "@/next/blob-cache";
import log from "@/next/log"; import log from "@/next/log";
import { Skeleton, styled } from "@mui/material"; import { Skeleton, styled } from "@mui/material";
import { Legend } from "components/PhotoViewer/styledComponents/Legend"; import { Legend } from "components/PhotoViewer/styledComponents/Legend";
import { t } from "i18next"; import { t } from "i18next";
import React, { useEffect, useState } from "react"; import React, { useEffect, useState } from "react";
import mlIDbStorage from "services/face/db"; import mlIDbStorage from "services/face/db";
import { Face, Person, type MlFileData } from "services/face/types"; import type { Person } from "services/face/people";
import { EnteFile } from "types/file"; import { EnteFile } from "types/file";
const FaceChipContainer = styled("div")` const FaceChipContainer = styled("div")`
@ -57,10 +58,7 @@ export const PeopleList = React.memo((props: PeopleListProps) => {
props.onSelect && props.onSelect(person, index) props.onSelect && props.onSelect(person, index)
} }
> >
<FaceCropImageView <FaceCropImageView faceID={person.displayFaceId} />
faceID={person.displayFaceId}
cacheKey={person.faceCropCacheKey}
/>
</FaceChip> </FaceChip>
))} ))}
</FaceChipContainer> </FaceChipContainer>
@ -108,7 +106,7 @@ export function UnidentifiedFaces(props: {
file: EnteFile; file: EnteFile;
updateMLDataIndex: number; updateMLDataIndex: number;
}) { }) {
const [faces, setFaces] = useState<Array<Face>>([]); const [faces, setFaces] = useState<{ id: string }[]>([]);
useEffect(() => { useEffect(() => {
let didCancel = false; let didCancel = false;
@ -136,10 +134,7 @@ export function UnidentifiedFaces(props: {
{faces && {faces &&
faces.map((face, index) => ( faces.map((face, index) => (
<FaceChip key={index}> <FaceChip key={index}>
<FaceCropImageView <FaceCropImageView faceID={face.id} />
faceID={face.id}
cacheKey={face.crop?.cacheKey}
/>
</FaceChip> </FaceChip>
))} ))}
</FaceChipContainer> </FaceChipContainer>
@ -149,13 +144,9 @@ export function UnidentifiedFaces(props: {
interface FaceCropImageViewProps { interface FaceCropImageViewProps {
faceID: string; faceID: string;
cacheKey?: string;
} }
const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({ const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({ faceID }) => {
faceID,
cacheKey,
}) => {
const [objectURL, setObjectURL] = useState<string | undefined>(); const [objectURL, setObjectURL] = useState<string | undefined>();
useEffect(() => { useEffect(() => {
@ -165,12 +156,16 @@ const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({
if (faceID && electron) { if (faceID && electron) {
electron electron
.legacyFaceCrop(faceID) .legacyFaceCrop(faceID)
.then(async (data) => {
if (data) return data;
/* /*
TODO(MR): regen if needed and get this to work on web too.
cachedOrNew("face-crops", cacheKey, async () => { cachedOrNew("face-crops", cacheKey, async () => {
return machineLearningService.regenerateFaceCrop( return regenerateFaceCrop(faceId);
faceId,
);
})*/ })*/
const cache = await blobCache("face-crops");
return await cache.get(faceID);
})
.then((data) => { .then((data) => {
if (data) { if (data) {
const blob = new Blob([data]); const blob = new Blob([data]);
@ -183,7 +178,7 @@ const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({
didCancel = true; didCancel = true;
if (objectURL) URL.revokeObjectURL(objectURL); if (objectURL) URL.revokeObjectURL(objectURL);
}; };
}, [faceID, cacheKey]); }, [faceID]);
return objectURL ? ( return objectURL ? (
<img src={objectURL} /> <img src={objectURL} />
@ -192,9 +187,9 @@ const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({
); );
}; };
async function getPeopleList(file: EnteFile): Promise<Array<Person>> { async function getPeopleList(file: EnteFile): Promise<Person[]> {
let startTime = Date.now(); let startTime = Date.now();
const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id); const mlFileData = await mlIDbStorage.getFile(file.id);
log.info( log.info(
"getPeopleList:mlFilesStore:getItem", "getPeopleList:mlFilesStore:getItem",
Date.now() - startTime, Date.now() - startTime,
@ -226,8 +221,8 @@ async function getPeopleList(file: EnteFile): Promise<Array<Person>> {
return peopleList; return peopleList;
} }
async function getUnidentifiedFaces(file: EnteFile): Promise<Array<Face>> { async function getUnidentifiedFaces(file: EnteFile): Promise<{ id: string }[]> {
const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id); const mlFileData = await mlIDbStorage.getFile(file.id);
return mlFileData?.faces?.filter( return mlFileData?.faces?.filter(
(f) => f.personId === null || f.personId === undefined, (f) => f.personId === null || f.personId === undefined,

View file

@ -1,6 +1,6 @@
import { FILE_TYPE } from "@/media/file-type"; import { FILE_TYPE } from "@/media/file-type";
import { decodeLivePhoto } from "@/media/live-photo"; import { decodeLivePhoto } from "@/media/live-photo";
import { openCache, type BlobCache } from "@/next/blob-cache"; import { blobCache, type BlobCache } from "@/next/blob-cache";
import log from "@/next/log"; import log from "@/next/log";
import { APPS } from "@ente/shared/apps/constants"; import { APPS } from "@ente/shared/apps/constants";
import ComlinkCryptoWorker from "@ente/shared/crypto"; import ComlinkCryptoWorker from "@ente/shared/crypto";
@ -91,7 +91,7 @@ class DownloadManagerImpl {
} }
this.downloadClient = createDownloadClient(app, tokens); this.downloadClient = createDownloadClient(app, tokens);
try { try {
this.thumbnailCache = await openCache("thumbs"); this.thumbnailCache = await blobCache("thumbs");
} catch (e) { } catch (e) {
log.error( log.error(
"Failed to open thumbnail cache, will continue without it", "Failed to open thumbnail cache, will continue without it",
@ -100,7 +100,7 @@ class DownloadManagerImpl {
} }
// TODO (MR): Revisit full file caching cf disk space usage // TODO (MR): Revisit full file caching cf disk space usage
// try { // try {
// if (isElectron()) this.fileCache = await openCache("files"); // if (isElectron()) this.fileCache = await cache("files");
// } catch (e) { // } catch (e) {
// log.error("Failed to open file cache, will continue without it", e); // log.error("Failed to open file cache, will continue without it", e);
// } // }

View file

@ -7,7 +7,7 @@ import HTTPService from "@ente/shared/network/HTTPService";
import { getEndpoint } from "@ente/shared/network/api"; import { getEndpoint } from "@ente/shared/network/api";
import localForage from "@ente/shared/storage/localForage"; import localForage from "@ente/shared/storage/localForage";
import { getToken } from "@ente/shared/storage/localStorage/helpers"; import { getToken } from "@ente/shared/storage/localStorage/helpers";
import { FileML } from "services/machineLearning/machineLearningService"; import { FileML } from "services/face/remote";
import type { import type {
Embedding, Embedding,
EmbeddingModel, EmbeddingModel,

View file

@ -1,8 +1,9 @@
import { Hdbscan, type DebugInfo } from "hdbscan"; import { Hdbscan, type DebugInfo } from "hdbscan";
import { type Cluster } from "services/face/types";
export type Cluster = number[];
export interface ClusterFacesResult { export interface ClusterFacesResult {
clusters: Array<Cluster>; clusters: Cluster[];
noise: Cluster; noise: Cluster;
debugInfo?: DebugInfo; debugInfo?: DebugInfo;
} }

View file

@ -9,7 +9,8 @@ import {
openDB, openDB,
} from "idb"; } from "idb";
import isElectron from "is-electron"; import isElectron from "is-electron";
import { Face, MlFileData, Person } from "services/face/types"; import type { Person } from "services/face/people";
import type { MlFileData } from "services/face/types";
import { import {
DEFAULT_ML_SEARCH_CONFIG, DEFAULT_ML_SEARCH_CONFIG,
MAX_ML_SYNC_ERROR_COUNT, MAX_ML_SYNC_ERROR_COUNT,
@ -23,6 +24,18 @@ export interface IndexStatus {
peopleIndexSynced: boolean; peopleIndexSynced: boolean;
} }
/**
* TODO(MR): Transient type with an intersection of values that both existing
* and new types during the migration will have. Eventually we'll store the the
* server ML data shape here exactly.
*/
export interface MinimalPersistedFileData {
fileId: number;
mlVersion: number;
errorCount: number;
faces?: { personId?: number; id: string }[];
}
interface Config {} interface Config {}
export const ML_SEARCH_CONFIG_NAME = "ml-search"; export const ML_SEARCH_CONFIG_NAME = "ml-search";
@ -31,7 +44,7 @@ const MLDATA_DB_NAME = "mldata";
interface MLDb extends DBSchema { interface MLDb extends DBSchema {
files: { files: {
key: number; key: number;
value: MlFileData; value: MinimalPersistedFileData;
indexes: { mlVersion: [number, number] }; indexes: { mlVersion: [number, number] };
}; };
people: { people: {
@ -211,38 +224,6 @@ class MLIDbStorage {
await this.db; await this.db;
} }
public async getAllFileIds() {
const db = await this.db;
return db.getAllKeys("files");
}
public async putAllFilesInTx(mlFiles: Array<MlFileData>) {
const db = await this.db;
const tx = db.transaction("files", "readwrite");
await Promise.all(mlFiles.map((mlFile) => tx.store.put(mlFile)));
await tx.done;
}
public async removeAllFilesInTx(fileIds: Array<number>) {
const db = await this.db;
const tx = db.transaction("files", "readwrite");
await Promise.all(fileIds.map((fileId) => tx.store.delete(fileId)));
await tx.done;
}
public async newTransaction<
Name extends StoreNames<MLDb>,
Mode extends IDBTransactionMode = "readonly",
>(storeNames: Name, mode?: Mode) {
const db = await this.db;
return db.transaction(storeNames, mode);
}
public async commit(tx: IDBPTransaction<MLDb>) {
return tx.done;
}
public async getAllFileIdsForUpdate( public async getAllFileIdsForUpdate(
tx: IDBPTransaction<MLDb, ["files"], "readwrite">, tx: IDBPTransaction<MLDb, ["files"], "readwrite">,
) { ) {
@ -276,16 +257,11 @@ class MLIDbStorage {
return fileIds; return fileIds;
} }
public async getFile(fileId: number) { public async getFile(fileId: number): Promise<MinimalPersistedFileData> {
const db = await this.db; const db = await this.db;
return db.get("files", fileId); return db.get("files", fileId);
} }
public async getAllFiles() {
const db = await this.db;
return db.getAll("files");
}
public async putFile(mlFile: MlFileData) { public async putFile(mlFile: MlFileData) {
const db = await this.db; const db = await this.db;
return db.put("files", mlFile); return db.put("files", mlFile);
@ -293,7 +269,7 @@ class MLIDbStorage {
public async upsertFileInTx( public async upsertFileInTx(
fileId: number, fileId: number,
upsert: (mlFile: MlFileData) => MlFileData, upsert: (mlFile: MinimalPersistedFileData) => MinimalPersistedFileData,
) { ) {
const db = await this.db; const db = await this.db;
const tx = db.transaction("files", "readwrite"); const tx = db.transaction("files", "readwrite");
@ -306,7 +282,7 @@ class MLIDbStorage {
} }
public async putAllFiles( public async putAllFiles(
mlFiles: Array<MlFileData>, mlFiles: MinimalPersistedFileData[],
tx: IDBPTransaction<MLDb, ["files"], "readwrite">, tx: IDBPTransaction<MLDb, ["files"], "readwrite">,
) { ) {
await Promise.all(mlFiles.map((mlFile) => tx.store.put(mlFile))); await Promise.all(mlFiles.map((mlFile) => tx.store.put(mlFile)));
@ -319,44 +295,6 @@ class MLIDbStorage {
await Promise.all(fileIds.map((fileId) => tx.store.delete(fileId))); await Promise.all(fileIds.map((fileId) => tx.store.delete(fileId)));
} }
public async getFace(fileID: number, faceId: string) {
const file = await this.getFile(fileID);
const face = file.faces.filter((f) => f.id === faceId);
return face[0];
}
public async getAllFacesMap() {
const startTime = Date.now();
const db = await this.db;
const allFiles = await db.getAll("files");
const allFacesMap = new Map<number, Array<Face>>();
allFiles.forEach(
(mlFileData) =>
mlFileData.faces &&
allFacesMap.set(mlFileData.fileId, mlFileData.faces),
);
log.info("getAllFacesMap", Date.now() - startTime, "ms");
return allFacesMap;
}
public async updateFaces(allFacesMap: Map<number, Face[]>) {
const startTime = Date.now();
const db = await this.db;
const tx = db.transaction("files", "readwrite");
let cursor = await tx.store.openCursor();
while (cursor) {
if (allFacesMap.has(cursor.key)) {
const mlFileData = { ...cursor.value };
mlFileData.faces = allFacesMap.get(cursor.key);
cursor.update(mlFileData);
}
cursor = await cursor.continue();
}
await tx.done;
log.info("updateFaces", Date.now() - startTime, "ms");
}
public async getPerson(id: number) { public async getPerson(id: number) {
const db = await this.db; const db = await this.db;
return db.get("people", id); return db.get("people", id);
@ -367,21 +305,6 @@ class MLIDbStorage {
return db.getAll("people"); return db.getAll("people");
} }
public async putPerson(person: Person) {
const db = await this.db;
return db.put("people", person);
}
public async clearAllPeople() {
const db = await this.db;
return db.clear("people");
}
public async getIndexVersion(index: string) {
const db = await this.db;
return db.get("versions", index);
}
public async incrementIndexVersion(index: StoreNames<MLDb>) { public async incrementIndexVersion(index: StoreNames<MLDb>) {
if (index === "versions") { if (index === "versions") {
throw new Error("versions store can not be versioned"); throw new Error("versions store can not be versioned");
@ -396,11 +319,6 @@ class MLIDbStorage {
return version; return version;
} }
public async setIndexVersion(index: string, version: number) {
const db = await this.db;
return db.put("versions", version, index);
}
public async getConfig<T extends Config>(name: string, def: T) { public async getConfig<T extends Config>(name: string, def: T) {
const db = await this.db; const db = await this.db;
const tx = db.transaction("configs", "readwrite"); const tx = db.transaction("configs", "readwrite");
@ -464,66 +382,6 @@ class MLIDbStorage {
peopleIndexVersion === filesIndexVersion, peopleIndexVersion === filesIndexVersion,
}; };
} }
// for debug purpose
public async getAllMLData() {
const db = await this.db;
const tx = db.transaction(db.objectStoreNames, "readonly");
const allMLData: any = {};
for (const store of tx.objectStoreNames) {
const keys = await tx.objectStore(store).getAllKeys();
const data = await tx.objectStore(store).getAll();
allMLData[store] = {};
for (let i = 0; i < keys.length; i++) {
allMLData[store][keys[i]] = data[i];
}
}
await tx.done;
const files = allMLData["files"];
for (const fileId of Object.keys(files)) {
const fileData = files[fileId];
fileData.faces?.forEach(
(f) => (f.embedding = Array.from(f.embedding)),
);
}
return allMLData;
}
// for debug purpose, this will overwrite all data
public async putAllMLData(allMLData: Map<string, any>) {
const db = await this.db;
const tx = db.transaction(db.objectStoreNames, "readwrite");
for (const store of tx.objectStoreNames) {
const records = allMLData[store];
if (!records) {
continue;
}
const txStore = tx.objectStore(store);
if (store === "files") {
const files = records;
for (const fileId of Object.keys(files)) {
const fileData = files[fileId];
fileData.faces?.forEach(
(f) => (f.embedding = Float32Array.from(f.embedding)),
);
}
}
await txStore.clear();
for (const key of Object.keys(records)) {
if (txStore.keyPath) {
txStore.put(records[key]);
} else {
txStore.put(records[key], key);
}
}
}
await tx.done;
}
} }
export default new MLIDbStorage(); export default new MLIDbStorage();

View file

@ -1,26 +1,29 @@
import { FILE_TYPE } from "@/media/file-type"; import { FILE_TYPE } from "@/media/file-type";
import { openCache } from "@/next/blob-cache"; import { blobCache } from "@/next/blob-cache";
import log from "@/next/log"; import log from "@/next/log";
import { workerBridge } from "@/next/worker/worker-bridge"; import { workerBridge } from "@/next/worker/worker-bridge";
import { euclidean } from "hdbscan"; import { euclidean } from "hdbscan";
import { Matrix } from "ml-matrix"; import { Matrix } from "ml-matrix";
import { Box, Dimensions, Point, enlargeBox } from "services/face/geom";
import { import {
Box,
Dimensions,
Point,
enlargeBox,
roundBox,
} from "services/face/geom";
import type {
Face, Face,
FaceAlignment, FaceAlignment,
FaceCrop,
FaceDetection, FaceDetection,
FaceEmbedding, MlFileData,
type MlFileData,
} from "services/face/types"; } from "services/face/types";
import { defaultMLVersion } from "services/machineLearning/machineLearningService"; import { defaultMLVersion } from "services/machineLearning/machineLearningService";
import { getSimilarityTransformation } from "similarity-transformation"; import { getSimilarityTransformation } from "similarity-transformation";
import type { EnteFile } from "types/file"; import type { EnteFile } from "types/file";
import { fetchImageBitmap, getLocalFileImageBitmap } from "./file";
import { import {
clamp, clamp,
createGrayscaleIntMatrixFromNormalized2List, grayscaleIntMatrixFromNormalized2List,
fetchImageBitmap,
getLocalFileImageBitmap,
pixelRGBBilinear, pixelRGBBilinear,
warpAffineFloat32List, warpAffineFloat32List,
} from "./image"; } from "./image";
@ -85,47 +88,49 @@ const fetchOrCreateImageBitmap = async (
const indexFaces_ = async (enteFile: EnteFile, imageBitmap: ImageBitmap) => { const indexFaces_ = async (enteFile: EnteFile, imageBitmap: ImageBitmap) => {
const fileID = enteFile.id; const fileID = enteFile.id;
const { width, height } = imageBitmap; const imageDimensions: Dimensions = imageBitmap;
const mlFile: MlFileData = { const mlFile: MlFileData = {
fileId: fileID, fileId: fileID,
mlVersion: defaultMLVersion, mlVersion: defaultMLVersion,
imageDimensions: { width, height }, imageDimensions,
errorCount: 0, errorCount: 0,
}; };
const faceDetections = await detectFaces(imageBitmap); const faceDetections = await detectFaces(imageBitmap);
const detectedFaces = faceDetections.map((detection) => ({ const detectedFaces = faceDetections.map((detection) => ({
id: makeFaceID(fileID, detection, mlFile.imageDimensions), id: makeFaceID(fileID, detection, imageDimensions),
fileId: fileID, fileId: fileID,
detection, detection,
})); }));
mlFile.faces = detectedFaces; mlFile.faces = detectedFaces;
if (detectedFaces.length > 0) { if (detectedFaces.length > 0) {
await Promise.all( const alignments: FaceAlignment[] = [];
detectedFaces.map((face) => saveFaceCrop(imageBitmap, face)),
);
// Execute the face alignment calculations
for (const face of mlFile.faces) { for (const face of mlFile.faces) {
face.alignment = faceAlignment(face.detection); const alignment = faceAlignment(face.detection);
face.alignment = alignment;
alignments.push(alignment);
await saveFaceCrop(imageBitmap, face);
} }
// Extract face images and convert to Float32Array const alignedFacesData = convertToMobileFaceNetInput(
const faceAlignments = mlFile.faces.map((f) => f.alignment);
const alignedFacesData = await extractFaceImagesToFloat32(
faceAlignments,
mobileFaceNetFaceSize,
imageBitmap, imageBitmap,
alignments,
); );
const blurValues = detectBlur(alignedFacesData, mlFile.faces); const blurValues = detectBlur(alignedFacesData, mlFile.faces);
mlFile.faces.forEach((f, i) => (f.blurValue = blurValues[i])); mlFile.faces.forEach((f, i) => (f.blurValue = blurValues[i]));
const embeddings = await faceEmbeddings(alignedFacesData); const embeddings = await computeEmbeddings(alignedFacesData);
mlFile.faces.forEach((f, i) => (f.embedding = embeddings[i])); mlFile.faces.forEach((f, i) => (f.embedding = embeddings[i]));
convertFaceDetectionsToRelative(mlFile); // TODO-ML: Skip if somehow already relative. But why would it be?
// if (face.detection.box.x + face.detection.box.width < 2) continue;
mlFile.faces.forEach((face) => {
face.detection = relativeDetection(face.detection, imageDimensions);
});
} }
return mlFile; return mlFile;
@ -170,8 +175,7 @@ const convertToYOLOInputFloat32ChannelsFirst = (imageBitmap: ImageBitmap) => {
const requiredWidth = 640; const requiredWidth = 640;
const requiredHeight = 640; const requiredHeight = 640;
const width = imageBitmap.width; const { width, height } = imageBitmap;
const height = imageBitmap.height;
// Create an OffscreenCanvas and set its size. // Create an OffscreenCanvas and set its size.
const offscreenCanvas = new OffscreenCanvas(width, height); const offscreenCanvas = new OffscreenCanvas(width, height);
@ -405,37 +409,39 @@ const faceAlignmentUsingSimilarityTransform = (
return { affineMatrix, center, size, rotation }; return { affineMatrix, center, size, rotation };
}; };
async function extractFaceImagesToFloat32( const convertToMobileFaceNetInput = (
faceAlignments: Array<FaceAlignment>, imageBitmap: ImageBitmap,
faceSize: number, faceAlignments: FaceAlignment[],
image: ImageBitmap, ): Float32Array => {
): Promise<Float32Array> { const faceSize = mobileFaceNetFaceSize;
const faceData = new Float32Array( const faceData = new Float32Array(
faceAlignments.length * faceSize * faceSize * 3, faceAlignments.length * faceSize * faceSize * 3,
); );
for (let i = 0; i < faceAlignments.length; i++) { for (let i = 0; i < faceAlignments.length; i++) {
const alignedFace = faceAlignments[i]; const { affineMatrix } = faceAlignments[i];
const faceDataOffset = i * faceSize * faceSize * 3; const faceDataOffset = i * faceSize * faceSize * 3;
warpAffineFloat32List( warpAffineFloat32List(
image, imageBitmap,
alignedFace, affineMatrix,
faceSize, faceSize,
faceData, faceData,
faceDataOffset, faceDataOffset,
); );
} }
return faceData; return faceData;
} };
/** /**
* Laplacian blur detection. * Laplacian blur detection.
* *
* Return an array of detected blur values, one for each face in {@link faces}. * Return an array of detected blur values, one for each face in {@link faces}.
* The face data is taken from the slice of {@link alignedFacesData}
* corresponding to each face of {@link faces}.
*/ */
const detectBlur = (alignedFaces: Float32Array, faces: Face[]): number[] => const detectBlur = (alignedFacesData: Float32Array, faces: Face[]): number[] =>
faces.map((face, i) => { faces.map((face, i) => {
const faceImage = createGrayscaleIntMatrixFromNormalized2List( const faceImage = grayscaleIntMatrixFromNormalized2List(
alignedFaces, alignedFacesData,
i, i,
mobileFaceNetFaceSize, mobileFaceNetFaceSize,
mobileFaceNetFaceSize, mobileFaceNetFaceSize,
@ -609,19 +615,20 @@ const matrixVariance = (matrix: number[][]): number => {
}; };
const mobileFaceNetFaceSize = 112; const mobileFaceNetFaceSize = 112;
const mobileFaceNetEmbeddingSize = 192;
/** /**
* Compute embeddings for the given {@link faceData}. * Compute embeddings for the given {@link faceData}.
* *
* The model used is MobileFaceNet, running in an ONNX runtime. * The model used is MobileFaceNet, running in an ONNX runtime.
*/ */
const faceEmbeddings = async ( const computeEmbeddings = async (
faceData: Float32Array, faceData: Float32Array,
): Promise<Array<FaceEmbedding>> => { ): Promise<Float32Array[]> => {
const outputData = await workerBridge.faceEmbeddings(faceData); const outputData = await workerBridge.faceEmbeddings(faceData);
const embeddingSize = 192; const embeddingSize = mobileFaceNetEmbeddingSize;
const embeddings = new Array<FaceEmbedding>( const embeddings = new Array<Float32Array>(
outputData.length / embeddingSize, outputData.length / embeddingSize,
); );
for (let i = 0; i < embeddings.length; i++) { for (let i = 0; i < embeddings.length; i++) {
@ -632,18 +639,9 @@ const faceEmbeddings = async (
return embeddings; return embeddings;
}; };
const convertFaceDetectionsToRelative = (mlFile: MlFileData) => { /**
for (let i = 0; i < mlFile.faces.length; i++) { * Convert the coordinates to between 0-1, normalized by the image's dimensions.
const face = mlFile.faces[i]; */
// Skip if somehow already relative.
if (face.detection.box.x + face.detection.box.width < 2) continue;
face.detection = relativeDetection(
face.detection,
mlFile.imageDimensions,
);
}
};
const relativeDetection = ( const relativeDetection = (
faceDetection: FaceDetection, faceDetection: FaceDetection,
{ width, height }: Dimensions, { width, height }: Dimensions,
@ -663,15 +661,13 @@ const relativeDetection = (
}; };
export const saveFaceCrop = async (imageBitmap: ImageBitmap, face: Face) => { export const saveFaceCrop = async (imageBitmap: ImageBitmap, face: Face) => {
const faceCrop = getFaceCrop(imageBitmap, face.detection); const faceCrop = extractFaceCrop(imageBitmap, face.alignment);
const blob = await imageBitmapToBlob(faceCrop);
faceCrop.close();
const blob = await imageBitmapToBlob(faceCrop.image); const cache = await blobCache("face-crops");
const cache = await openCache("face-crops");
await cache.put(face.id, blob); await cache.put(face.id, blob);
faceCrop.image.close();
return blob; return blob;
}; };
@ -681,67 +677,43 @@ const imageBitmapToBlob = (imageBitmap: ImageBitmap) => {
return canvas.convertToBlob({ type: "image/jpeg", quality: 0.8 }); return canvas.convertToBlob({ type: "image/jpeg", quality: 0.8 });
}; };
const getFaceCrop = ( const extractFaceCrop = (
imageBitmap: ImageBitmap, imageBitmap: ImageBitmap,
faceDetection: FaceDetection, alignment: FaceAlignment,
): FaceCrop => { ): ImageBitmap => {
const alignment = faceAlignment(faceDetection); // TODO-ML: Do we need to round twice?
const alignmentBox = roundBox(
const padding = 0.25; new Box({
const maxSize = 256;
const alignmentBox = new Box({
x: alignment.center.x - alignment.size / 2, x: alignment.center.x - alignment.size / 2,
y: alignment.center.y - alignment.size / 2, y: alignment.center.y - alignment.size / 2,
width: alignment.size, width: alignment.size,
height: alignment.size, height: alignment.size,
}).round(); }),
);
const padding = 0.25;
const scaleForPadding = 1 + padding * 2; const scaleForPadding = 1 + padding * 2;
const paddedBox = enlargeBox(alignmentBox, scaleForPadding).round(); const paddedBox = roundBox(enlargeBox(alignmentBox, scaleForPadding));
const faceImageBitmap = cropWithRotation(imageBitmap, paddedBox, 0, {
width: maxSize,
height: maxSize,
});
return { // TODO-ML: The rotation doesn't seem to be used? it's set to 0.
image: faceImageBitmap, return cropWithRotation(imageBitmap, paddedBox, 0, 256);
imageBox: paddedBox,
};
}; };
export function cropWithRotation( const cropWithRotation = (
imageBitmap: ImageBitmap, imageBitmap: ImageBitmap,
cropBox: Box, cropBox: Box,
rotation?: number, rotation: number,
maxSize?: Dimensions, maxDimension: number,
minSize?: Dimensions, ) => {
) { const box = roundBox(cropBox);
const box = cropBox.round();
const outputSize = { width: box.width, height: box.height }; const outputSize = { width: box.width, height: box.height };
if (maxSize) {
const minScale = Math.min(
maxSize.width / box.width,
maxSize.height / box.height,
);
if (minScale < 1) {
outputSize.width = Math.round(minScale * box.width);
outputSize.height = Math.round(minScale * box.height);
}
}
if (minSize) { const scale = Math.min(maxDimension / box.width, maxDimension / box.height);
const maxScale = Math.max( if (scale < 1) {
minSize.width / box.width, outputSize.width = Math.round(scale * box.width);
minSize.height / box.height, outputSize.height = Math.round(scale * box.height);
);
if (maxScale > 1) {
outputSize.width = Math.round(maxScale * box.width);
outputSize.height = Math.round(maxScale * box.height);
} }
}
// log.info({ imageBitmap, box, outputSize });
const offscreen = new OffscreenCanvas(outputSize.width, outputSize.height); const offscreen = new OffscreenCanvas(outputSize.width, outputSize.height);
const offscreenCtx = offscreen.getContext("2d"); const offscreenCtx = offscreen.getContext("2d");
@ -773,4 +745,4 @@ export function cropWithRotation(
); );
return offscreen.transferToImageBitmap(); return offscreen.transferToImageBitmap();
} };

View file

@ -22,11 +22,6 @@ export class DedicatedMLWorker {
await downloadManager.init(APPS.PHOTOS, { token }); await downloadManager.init(APPS.PHOTOS, { token });
return mlService.sync(token, userID); return mlService.sync(token, userID);
} }
public async regenerateFaceCrop(token: string, faceID: string) {
await downloadManager.init(APPS.PHOTOS, { token });
return mlService.regenerateFaceCrop(faceID);
}
} }
expose(DedicatedMLWorker, self); expose(DedicatedMLWorker, self);

View file

@ -0,0 +1,37 @@
import { FILE_TYPE } from "@/media/file-type";
import { decodeLivePhoto } from "@/media/live-photo";
import DownloadManager from "services/download";
import { getLocalFiles } from "services/fileService";
import { EnteFile } from "types/file";
import { getRenderableImage } from "utils/file";
export async function getLocalFile(fileId: number) {
const localFiles = await getLocalFiles();
return localFiles.find((f) => f.id === fileId);
}
export const fetchImageBitmap = async (file: EnteFile) =>
fetchRenderableBlob(file).then(createImageBitmap);
async function fetchRenderableBlob(file: EnteFile) {
const fileStream = await DownloadManager.getFile(file);
const fileBlob = await new Response(fileStream).blob();
if (file.metadata.fileType === FILE_TYPE.IMAGE) {
return await getRenderableImage(file.metadata.title, fileBlob);
} else {
const { imageFileName, imageData } = await decodeLivePhoto(
file.metadata.title,
fileBlob,
);
return await getRenderableImage(imageFileName, new Blob([imageData]));
}
}
export async function getLocalFileImageBitmap(
enteFile: EnteFile,
localFile: globalThis.File,
) {
let fileBlob = localFile as Blob;
fileBlob = await getRenderableImage(enteFile.metadata.title, fileBlob);
return createImageBitmap(fileBlob);
}

View file

@ -13,13 +13,6 @@ export interface Dimensions {
height: number; height: number;
} }
export interface IBoundingBox {
left: number;
top: number;
right: number;
bottom: number;
}
export interface IRect { export interface IRect {
x: number; x: number;
y: number; y: number;
@ -27,20 +20,6 @@ export interface IRect {
height: number; height: number;
} }
export const boxFromBoundingBox = ({
left,
top,
right,
bottom,
}: IBoundingBox) => {
return new Box({
x: left,
y: top,
width: right - left,
height: bottom - top,
});
};
export class Box implements IRect { export class Box implements IRect {
public x: number; public x: number;
public y: number; public y: number;
@ -53,36 +32,26 @@ export class Box implements IRect {
this.width = width; this.width = width;
this.height = height; this.height = height;
} }
public get topLeft(): Point {
return new Point(this.x, this.y);
} }
public get bottomRight(): Point { /** Round all the components of the box. */
return new Point(this.x + this.width, this.y + this.height); export const roundBox = (box: Box): Box => {
} const [x, y, width, height] = [box.x, box.y, box.width, box.height].map(
(val) => Math.round(val),
public round(): Box { );
const [x, y, width, height] = [
this.x,
this.y,
this.width,
this.height,
].map((val) => Math.round(val));
return new Box({ x, y, width, height }); return new Box({ x, y, width, height });
} };
}
export function enlargeBox(box: Box, factor: number = 1.5) { /** Increase the size of the given {@link box} by {@link factor}. */
export const enlargeBox = (box: Box, factor: number) => {
const center = new Point(box.x + box.width / 2, box.y + box.height / 2); const center = new Point(box.x + box.width / 2, box.y + box.height / 2);
const newWidth = factor * box.width;
const newHeight = factor * box.height;
const size = new Point(box.width, box.height); return new Box({
const newHalfSize = new Point((factor * size.x) / 2, (factor * size.y) / 2); x: center.x - newWidth / 2,
y: center.y - newHeight / 2,
return boxFromBoundingBox({ width: newWidth,
left: center.x - newHalfSize.x, height: newHeight,
top: center.y - newHalfSize.y,
right: center.x + newHalfSize.x,
bottom: center.y + newHalfSize.y,
}); });
} };

View file

@ -1,11 +1,4 @@
import { FILE_TYPE } from "@/media/file-type";
import { decodeLivePhoto } from "@/media/live-photo";
import { Matrix, inverse } from "ml-matrix"; import { Matrix, inverse } from "ml-matrix";
import DownloadManager from "services/download";
import { FaceAlignment } from "services/face/types";
import { getLocalFiles } from "services/fileService";
import { EnteFile } from "types/file";
import { getRenderableImage } from "utils/file";
/** /**
* Clamp {@link value} to between {@link min} and {@link max}, inclusive. * Clamp {@link value} to between {@link min} and {@link max}, inclusive.
@ -13,42 +6,11 @@ import { getRenderableImage } from "utils/file";
export const clamp = (value: number, min: number, max: number) => export const clamp = (value: number, min: number, max: number) =>
Math.min(max, Math.max(min, value)); Math.min(max, Math.max(min, value));
export async function getLocalFile(fileId: number) {
const localFiles = await getLocalFiles();
return localFiles.find((f) => f.id === fileId);
}
export const fetchImageBitmap = async (file: EnteFile) =>
fetchRenderableBlob(file).then(createImageBitmap);
async function fetchRenderableBlob(file: EnteFile) {
const fileStream = await DownloadManager.getFile(file);
const fileBlob = await new Response(fileStream).blob();
if (file.metadata.fileType === FILE_TYPE.IMAGE) {
return await getRenderableImage(file.metadata.title, fileBlob);
} else {
const { imageFileName, imageData } = await decodeLivePhoto(
file.metadata.title,
fileBlob,
);
return await getRenderableImage(imageFileName, new Blob([imageData]));
}
}
export async function getLocalFileImageBitmap(
enteFile: EnteFile,
localFile: globalThis.File,
) {
let fileBlob = localFile as Blob;
fileBlob = await getRenderableImage(enteFile.metadata.title, fileBlob);
return createImageBitmap(fileBlob);
}
/** /**
* Returns the pixel value (RGB) at the given coordinates ({@link fx}, * Returns the pixel value (RGB) at the given coordinates ({@link fx},
* {@link fy}) using bicubic interpolation. * {@link fy}) using bilinear interpolation.
*/ */
export function pixelRGBBicubic( export function pixelRGBBilinear(
fx: number, fx: number,
fy: number, fy: number,
imageData: Uint8ClampedArray, imageData: Uint8ClampedArray,
@ -59,6 +21,72 @@ export function pixelRGBBicubic(
fx = clamp(fx, 0, imageWidth - 1); fx = clamp(fx, 0, imageWidth - 1);
fy = clamp(fy, 0, imageHeight - 1); fy = clamp(fy, 0, imageHeight - 1);
// Get the surrounding coordinates and their weights.
const x0 = Math.floor(fx);
const x1 = Math.ceil(fx);
const y0 = Math.floor(fy);
const y1 = Math.ceil(fy);
const dx = fx - x0;
const dy = fy - y0;
const dx1 = 1.0 - dx;
const dy1 = 1.0 - dy;
// Get the original pixels.
const pixel1 = pixelRGBA(imageData, imageWidth, imageHeight, x0, y0);
const pixel2 = pixelRGBA(imageData, imageWidth, imageHeight, x1, y0);
const pixel3 = pixelRGBA(imageData, imageWidth, imageHeight, x0, y1);
const pixel4 = pixelRGBA(imageData, imageWidth, imageHeight, x1, y1);
const bilinear = (val1: number, val2: number, val3: number, val4: number) =>
Math.round(
val1 * dx1 * dy1 +
val2 * dx * dy1 +
val3 * dx1 * dy +
val4 * dx * dy,
);
// Return interpolated pixel colors.
return {
r: bilinear(pixel1.r, pixel2.r, pixel3.r, pixel4.r),
g: bilinear(pixel1.g, pixel2.g, pixel3.g, pixel4.g),
b: bilinear(pixel1.b, pixel2.b, pixel3.b, pixel4.b),
};
}
const pixelRGBA = (
imageData: Uint8ClampedArray,
width: number,
height: number,
x: number,
y: number,
) => {
if (x < 0 || x >= width || y < 0 || y >= height) {
return { r: 0, g: 0, b: 0, a: 0 };
}
const index = (y * width + x) * 4;
return {
r: imageData[index],
g: imageData[index + 1],
b: imageData[index + 2],
a: imageData[index + 3],
};
};
/**
* Returns the pixel value (RGB) at the given coordinates ({@link fx},
* {@link fy}) using bicubic interpolation.
*/
const pixelRGBBicubic = (
fx: number,
fy: number,
imageData: Uint8ClampedArray,
imageWidth: number,
imageHeight: number,
) => {
// Clamp to image boundaries.
fx = clamp(fx, 0, imageWidth - 1);
fy = clamp(fy, 0, imageHeight - 1);
const x = Math.trunc(fx) - (fx >= 0.0 ? 0 : 1); const x = Math.trunc(fx) - (fx >= 0.0 ? 0 : 1);
const px = x - 1; const px = x - 1;
const nx = x + 1; const nx = x + 1;
@ -171,97 +199,28 @@ export function pixelRGBBicubic(
// const c3 = cubic(dy, ip3, ic3, in3, ia3); // const c3 = cubic(dy, ip3, ic3, in3, ia3);
return { r: c0, g: c1, b: c2 }; return { r: c0, g: c1, b: c2 };
}
const pixelRGBA = (
imageData: Uint8ClampedArray,
width: number,
height: number,
x: number,
y: number,
) => {
if (x < 0 || x >= width || y < 0 || y >= height) {
return { r: 0, g: 0, b: 0, a: 0 };
}
const index = (y * width + x) * 4;
return {
r: imageData[index],
g: imageData[index + 1],
b: imageData[index + 2],
a: imageData[index + 3],
};
}; };
/** /**
* Returns the pixel value (RGB) at the given coordinates ({@link fx}, * Transform {@link inputData} starting at {@link inputStartIndex}.
* {@link fy}) using bilinear interpolation.
*/ */
export function pixelRGBBilinear( export const warpAffineFloat32List = (
fx: number,
fy: number,
imageData: Uint8ClampedArray,
imageWidth: number,
imageHeight: number,
) {
// Clamp to image boundaries.
fx = clamp(fx, 0, imageWidth - 1);
fy = clamp(fy, 0, imageHeight - 1);
// Get the surrounding coordinates and their weights.
const x0 = Math.floor(fx);
const x1 = Math.ceil(fx);
const y0 = Math.floor(fy);
const y1 = Math.ceil(fy);
const dx = fx - x0;
const dy = fy - y0;
const dx1 = 1.0 - dx;
const dy1 = 1.0 - dy;
// Get the original pixels.
const pixel1 = pixelRGBA(imageData, imageWidth, imageHeight, x0, y0);
const pixel2 = pixelRGBA(imageData, imageWidth, imageHeight, x1, y0);
const pixel3 = pixelRGBA(imageData, imageWidth, imageHeight, x0, y1);
const pixel4 = pixelRGBA(imageData, imageWidth, imageHeight, x1, y1);
const bilinear = (val1: number, val2: number, val3: number, val4: number) =>
Math.round(
val1 * dx1 * dy1 +
val2 * dx * dy1 +
val3 * dx1 * dy +
val4 * dx * dy,
);
// Return interpolated pixel colors.
return {
r: bilinear(pixel1.r, pixel2.r, pixel3.r, pixel4.r),
g: bilinear(pixel1.g, pixel2.g, pixel3.g, pixel4.g),
b: bilinear(pixel1.b, pixel2.b, pixel3.b, pixel4.b),
};
}
export function warpAffineFloat32List(
imageBitmap: ImageBitmap, imageBitmap: ImageBitmap,
faceAlignment: FaceAlignment, faceAlignmentAffineMatrix: number[][],
faceSize: number, faceSize: number,
inputData: Float32Array, inputData: Float32Array,
inputStartIndex: number, inputStartIndex: number,
): void { ): void => {
const { width, height } = imageBitmap;
// Get the pixel data. // Get the pixel data.
const offscreenCanvas = new OffscreenCanvas( const offscreenCanvas = new OffscreenCanvas(width, height);
imageBitmap.width,
imageBitmap.height,
);
const ctx = offscreenCanvas.getContext("2d"); const ctx = offscreenCanvas.getContext("2d");
ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height); ctx.drawImage(imageBitmap, 0, 0, width, height);
const imageData = ctx.getImageData( const imageData = ctx.getImageData(0, 0, width, height);
0,
0,
imageBitmap.width,
imageBitmap.height,
);
const pixelData = imageData.data; const pixelData = imageData.data;
const transformationMatrix = faceAlignment.affineMatrix.map((row) => const transformationMatrix = faceAlignmentAffineMatrix.map((row) =>
row.map((val) => (val != 1.0 ? val * faceSize : 1.0)), row.map((val) => (val != 1.0 ? val * faceSize : 1.0)),
); // 3x3 ); // 3x3
@ -280,7 +239,7 @@ export function warpAffineFloat32List(
for (let yTrans = 0; yTrans < faceSize; ++yTrans) { for (let yTrans = 0; yTrans < faceSize; ++yTrans) {
for (let xTrans = 0; xTrans < faceSize; ++xTrans) { for (let xTrans = 0; xTrans < faceSize; ++xTrans) {
// Perform inverse affine transformation // Perform inverse affine transformation.
const xOrigin = const xOrigin =
a00Prime * (xTrans - b00) + a01Prime * (yTrans - b10); a00Prime * (xTrans - b00) + a01Prime * (yTrans - b10);
const yOrigin = const yOrigin =
@ -291,34 +250,32 @@ export function warpAffineFloat32List(
xOrigin, xOrigin,
yOrigin, yOrigin,
pixelData, pixelData,
imageBitmap.width, width,
imageBitmap.height, height,
); );
// Set the pixel in the input data // Set the pixel in the input data.
const index = (yTrans * faceSize + xTrans) * 3; const index = (yTrans * faceSize + xTrans) * 3;
inputData[inputStartIndex + index] = inputData[inputStartIndex + index] = rgbToBipolarFloat(r);
normalizePixelBetweenMinus1And1(r); inputData[inputStartIndex + index + 1] = rgbToBipolarFloat(g);
inputData[inputStartIndex + index + 1] = inputData[inputStartIndex + index + 2] = rgbToBipolarFloat(b);
normalizePixelBetweenMinus1And1(g);
inputData[inputStartIndex + index + 2] =
normalizePixelBetweenMinus1And1(b);
}
} }
} }
};
const normalizePixelBetweenMinus1And1 = (pixelValue: number) => /** Convert a RGB component 0-255 to a floating point value between -1 and 1. */
pixelValue / 127.5 - 1.0; const rgbToBipolarFloat = (pixelValue: number) => pixelValue / 127.5 - 1.0;
const unnormalizePixelFromBetweenMinus1And1 = (pixelValue: number) => /** Convert a floating point value between -1 and 1 to a RGB component 0-255. */
const bipolarFloatToRGB = (pixelValue: number) =>
clamp(Math.round((pixelValue + 1.0) * 127.5), 0, 255); clamp(Math.round((pixelValue + 1.0) * 127.5), 0, 255);
export function createGrayscaleIntMatrixFromNormalized2List( export const grayscaleIntMatrixFromNormalized2List = (
imageList: Float32Array, imageList: Float32Array,
faceNumber: number, faceNumber: number,
width: number, width: number,
height: number, height: number,
): number[][] { ): number[][] => {
const startIndex = faceNumber * width * height * 3; const startIndex = faceNumber * width * height * 3;
return Array.from({ length: height }, (_, y) => return Array.from({ length: height }, (_, y) =>
Array.from({ length: width }, (_, x) => { Array.from({ length: width }, (_, x) => {
@ -326,22 +283,13 @@ export function createGrayscaleIntMatrixFromNormalized2List(
const pixelIndex = startIndex + 3 * (y * width + x); const pixelIndex = startIndex + 3 * (y * width + x);
return clamp( return clamp(
Math.round( Math.round(
0.299 * 0.299 * bipolarFloatToRGB(imageList[pixelIndex]) +
unnormalizePixelFromBetweenMinus1And1( 0.587 * bipolarFloatToRGB(imageList[pixelIndex + 1]) +
imageList[pixelIndex], 0.114 * bipolarFloatToRGB(imageList[pixelIndex + 2]),
) +
0.587 *
unnormalizePixelFromBetweenMinus1And1(
imageList[pixelIndex + 1],
) +
0.114 *
unnormalizePixelFromBetweenMinus1And1(
imageList[pixelIndex + 2],
),
), ),
0, 0,
255, 255,
); );
}), }),
); );
} };

View file

@ -1,17 +1,19 @@
import log from "@/next/log"; export interface Person {
import mlIDbStorage from "services/face/db"; id: number;
import { Person } from "services/face/types"; name?: string;
import { clusterFaces } from "./cluster"; files: Array<number>;
import { saveFaceCrop } from "./f-index"; displayFaceId?: string;
import { fetchImageBitmap, getLocalFile } from "./image"; }
export const syncPeopleIndex = async () => {
// TODO-ML(MR): Forced disable clustering. It doesn't currently work, // TODO-ML(MR): Forced disable clustering. It doesn't currently work,
// need to finalize it before we move out of beta. // need to finalize it before we move out of beta.
// //
// > Error: Failed to execute 'transferToImageBitmap' on // > Error: Failed to execute 'transferToImageBitmap' on
// > 'OffscreenCanvas': ImageBitmap construction failed // > 'OffscreenCanvas': ImageBitmap construction failed
/* /*
export const syncPeopleIndex = async () => {
if ( if (
syncContext.outOfSyncFiles.length <= 0 || syncContext.outOfSyncFiles.length <= 0 ||
(syncContext.nSyncedFiles === batchSize && Math.random() < 0) (syncContext.nSyncedFiles === batchSize && Math.random() < 0)
@ -32,16 +34,16 @@ export const syncPeopleIndex = async () => {
if (filesVersion <= (await mlIDbStorage.getIndexVersion("people"))) { if (filesVersion <= (await mlIDbStorage.getIndexVersion("people"))) {
return; return;
} }
*/
// TODO: have faces addresable through fileId + faceId // TODO: have faces addresable through fileId + faceId
// to avoid index based addressing, which is prone to wrong results // to avoid index based addressing, which is prone to wrong results
// one way could be to match nearest face within threshold in the file // one way could be to match nearest face within threshold in the file
/*
const allFacesMap = const allFacesMap =
syncContext.allSyncedFacesMap ?? syncContext.allSyncedFacesMap ??
(syncContext.allSyncedFacesMap = await mlIDbStorage.getAllFacesMap()); (syncContext.allSyncedFacesMap = await mlIDbStorage.getAllFacesMap());
*/
// await this.init(); // await this.init();
@ -83,17 +85,18 @@ export const syncPeopleIndex = async () => {
: best, : best,
); );
if (personFace && !personFace.crop?.cacheKey) { if (personFace && !personFace.crop?.cacheKey) {
const file = await getLocalFile(personFace.fileId); const file = await getLocalFile(personFace.fileId);
const imageBitmap = await fetchImageBitmap(file); const imageBitmap = await fetchImageBitmap(file);
await saveFaceCrop(imageBitmap, personFace); await saveFaceCrop(imageBitmap, personFace);
} }
const person: Person = { const person: Person = {
id: index, id: index,
files: faces.map((f) => f.fileId), files: faces.map((f) => f.fileId),
displayFaceId: personFace?.id, displayFaceId: personFace?.id,
faceCropCacheKey: personFace?.crop?.cacheKey,
}; };
await mlIDbStorage.putPerson(person); await mlIDbStorage.putPerson(person);
@ -108,3 +111,21 @@ export const syncPeopleIndex = async () => {
// await mlIDbStorage.setIndexVersion("people", filesVersion); // await mlIDbStorage.setIndexVersion("people", filesVersion);
}; };
public async regenerateFaceCrop(token: string, faceID: string) {
await downloadManager.init(APPS.PHOTOS, { token });
return mlService.regenerateFaceCrop(faceID);
}
export const regenerateFaceCrop = async (faceID: string) => {
const fileID = Number(faceID.split("-")[0]);
const personFace = await mlIDbStorage.getFace(fileID, faceID);
if (!personFace) {
throw Error("Face not found");
}
const file = await getLocalFile(personFace.fileId);
const imageBitmap = await fetchImageBitmap(file);
return await saveFaceCrop(imageBitmap, personFace);
};
*/

View file

@ -0,0 +1,160 @@
import log from "@/next/log";
import ComlinkCryptoWorker from "@ente/shared/crypto";
import { putEmbedding } from "services/embeddingService";
import type { EnteFile } from "types/file";
import type { Point } from "./geom";
import type { Face, FaceDetection, MlFileData } from "./types";
export const putFaceEmbedding = async (
enteFile: EnteFile,
mlFileData: MlFileData,
) => {
const serverMl = LocalFileMlDataToServerFileMl(mlFileData);
log.debug(() => ({ t: "Local ML file data", mlFileData }));
log.debug(() => ({
t: "Uploaded ML file data",
d: JSON.stringify(serverMl),
}));
const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance();
const { file: encryptedEmbeddingData } =
await comlinkCryptoWorker.encryptMetadata(serverMl, enteFile.key);
log.info(
`putEmbedding embedding to server for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
);
const res = await putEmbedding({
fileID: enteFile.id,
encryptedEmbedding: encryptedEmbeddingData.encryptedData,
decryptionHeader: encryptedEmbeddingData.decryptionHeader,
model: "file-ml-clip-face",
});
log.info("putEmbedding response: ", res);
};
export interface FileML extends ServerFileMl {
updatedAt: number;
}
class ServerFileMl {
public fileID: number;
public height?: number;
public width?: number;
public faceEmbedding: ServerFaceEmbeddings;
public constructor(
fileID: number,
faceEmbedding: ServerFaceEmbeddings,
height?: number,
width?: number,
) {
this.fileID = fileID;
this.height = height;
this.width = width;
this.faceEmbedding = faceEmbedding;
}
}
class ServerFaceEmbeddings {
public faces: ServerFace[];
public version: number;
/* TODO
public client?: string;
public error?: boolean;
*/
public constructor(faces: ServerFace[], version: number) {
this.faces = faces;
this.version = version;
}
}
class ServerFace {
public faceID: string;
// TODO-ML: singular?
public embeddings: number[];
public detection: ServerDetection;
public score: number;
public blur: number;
public constructor(
faceID: string,
embeddings: number[],
detection: ServerDetection,
score: number,
blur: number,
) {
this.faceID = faceID;
this.embeddings = embeddings;
this.detection = detection;
this.score = score;
this.blur = blur;
}
}
class ServerDetection {
public box: ServerFaceBox;
public landmarks: Point[];
public constructor(box: ServerFaceBox, landmarks: Point[]) {
this.box = box;
this.landmarks = landmarks;
}
}
class ServerFaceBox {
public xMin: number;
public yMin: number;
public width: number;
public height: number;
public constructor(
xMin: number,
yMin: number,
width: number,
height: number,
) {
this.xMin = xMin;
this.yMin = yMin;
this.width = width;
this.height = height;
}
}
function LocalFileMlDataToServerFileMl(
localFileMlData: MlFileData,
): ServerFileMl {
if (localFileMlData.errorCount > 0) {
return null;
}
const imageDimensions = localFileMlData.imageDimensions;
const faces: ServerFace[] = [];
for (let i = 0; i < localFileMlData.faces.length; i++) {
const face: Face = localFileMlData.faces[i];
const faceID = face.id;
const embedding = face.embedding;
const score = face.detection.probability;
const blur = face.blurValue;
const detection: FaceDetection = face.detection;
const box = detection.box;
const landmarks = detection.landmarks;
const newBox = new ServerFaceBox(box.x, box.y, box.width, box.height);
// TODO-ML: Add client UA and version
const newFaceObject = new ServerFace(
faceID,
Array.from(embedding),
new ServerDetection(newBox, landmarks),
score,
blur,
);
faces.push(newFaceObject);
}
const faceEmbeddings = new ServerFaceEmbeddings(faces, 1);
return new ServerFileMl(
localFileMlData.fileId,
faceEmbeddings,
imageDimensions.height,
imageDimensions.width,
);
}

View file

@ -1,6 +1,6 @@
import { Box, Point, boxFromBoundingBox } from "services/face/geom"; import { Box, Point } from "services/face/geom";
import { FaceDetection } from "services/face/types"; import type { FaceDetection } from "services/face/types";
// TODO-ML(MR): Do we need two separate Matrix libraries? // TODO-ML: Do we need two separate Matrix libraries?
// //
// Keeping this in a separate file so that we can audit this. If these can be // Keeping this in a separate file so that we can audit this. If these can be
// expressed using ml-matrix, then we can move the code to f-index. // expressed using ml-matrix, then we can move the code to f-index.
@ -22,43 +22,36 @@ export const transformFaceDetections = (
inBox: Box, inBox: Box,
toBox: Box, toBox: Box,
): FaceDetection[] => { ): FaceDetection[] => {
const transform = computeTransformToBox(inBox, toBox); const transform = boxTransformationMatrix(inBox, toBox);
return faceDetections.map((f) => { return faceDetections.map((f) => ({
const box = transformBox(f.box, transform); box: transformBox(f.box, transform),
const normLandmarks = f.landmarks; landmarks: f.landmarks.map((p) => transformPoint(p, transform)),
const landmarks = transformPoints(normLandmarks, transform); probability: f.probability,
return { }));
box,
landmarks,
probability: f.probability as number,
} as FaceDetection;
});
}; };
function computeTransformToBox(inBox: Box, toBox: Box): Matrix { const boxTransformationMatrix = (inBox: Box, toBox: Box): Matrix =>
return compose( compose(
translate(toBox.x, toBox.y), translate(toBox.x, toBox.y),
scale(toBox.width / inBox.width, toBox.height / inBox.height), scale(toBox.width / inBox.width, toBox.height / inBox.height),
); );
}
function transformPoint(point: Point, transform: Matrix) { const transformPoint = (point: Point, transform: Matrix) => {
const txdPoint = applyToPoint(transform, point); const txdPoint = applyToPoint(transform, point);
return new Point(txdPoint.x, txdPoint.y); return new Point(txdPoint.x, txdPoint.y);
} };
function transformPoints(points: Point[], transform: Matrix) { const transformBox = (box: Box, transform: Matrix) => {
return points?.map((p) => transformPoint(p, transform)); const topLeft = transformPoint(new Point(box.x, box.y), transform);
} const bottomRight = transformPoint(
new Point(box.x + box.width, box.y + box.height),
transform,
);
function transformBox(box: Box, transform: Matrix) { return new Box({
const topLeft = transformPoint(box.topLeft, transform); x: topLeft.x,
const bottomRight = transformPoint(box.bottomRight, transform); y: topLeft.y,
width: bottomRight.x - topLeft.x,
return boxFromBoundingBox({ height: bottomRight.y - topLeft.y,
left: topLeft.x,
top: topLeft.y,
right: bottomRight.x,
bottom: bottomRight.y,
}); });
} };

View file

@ -1,73 +1,35 @@
import { Box, Dimensions, Point } from "services/face/geom"; import { Box, Dimensions, Point } from "services/face/geom";
export declare type Cluster = Array<number>;
export declare type Landmark = Point;
export interface FaceDetection { export interface FaceDetection {
// box and landmarks is relative to image dimentions stored at mlFileData // box and landmarks is relative to image dimentions stored at mlFileData
box: Box; box: Box;
landmarks?: Array<Landmark>; landmarks?: Point[];
probability?: number; probability?: number;
} }
export interface DetectedFace {
fileId: number;
detection: FaceDetection;
}
export interface DetectedFaceWithId extends DetectedFace {
id: string;
}
export interface FaceCrop {
image: ImageBitmap;
// imageBox is relative to image dimentions stored at mlFileData
imageBox: Box;
}
export interface StoredFaceCrop {
cacheKey: string;
imageBox: Box;
}
export interface CroppedFace extends DetectedFaceWithId {
crop?: StoredFaceCrop;
}
export interface FaceAlignment { export interface FaceAlignment {
// TODO: remove affine matrix as rotation, size and center // TODO-ML: remove affine matrix as rotation, size and center
// are simple to store and use, affine matrix adds complexity while getting crop // are simple to store and use, affine matrix adds complexity while getting crop
affineMatrix: Array<Array<number>>; affineMatrix: number[][];
rotation: number; rotation: number;
// size and center is relative to image dimentions stored at mlFileData // size and center is relative to image dimentions stored at mlFileData
size: number; size: number;
center: Point; center: Point;
} }
export interface AlignedFace extends CroppedFace { export interface Face {
fileId: number;
detection: FaceDetection;
id: string;
alignment?: FaceAlignment; alignment?: FaceAlignment;
blurValue?: number; blurValue?: number;
}
export declare type FaceEmbedding = Float32Array; embedding?: Float32Array;
export interface FaceWithEmbedding extends AlignedFace {
embedding?: FaceEmbedding;
}
export interface Face extends FaceWithEmbedding {
personId?: number; personId?: number;
} }
export interface Person {
id: number;
name?: string;
files: Array<number>;
displayFaceId?: string;
faceCropCacheKey?: string;
}
export interface MlFileData { export interface MlFileData {
fileId: number; fileId: number;
faces?: Face[]; faces?: Face[];

View file

@ -1,20 +1,15 @@
import { haveWindow } from "@/next/env";
import log from "@/next/log"; import log from "@/next/log";
import { ComlinkWorker } from "@/next/worker/comlink-worker";
import ComlinkCryptoWorker, {
getDedicatedCryptoWorker,
} from "@ente/shared/crypto";
import { DedicatedCryptoWorker } from "@ente/shared/crypto/internal/crypto.worker";
import { CustomError, parseUploadErrorCodes } from "@ente/shared/error"; import { CustomError, parseUploadErrorCodes } from "@ente/shared/error";
import PQueue from "p-queue"; import PQueue from "p-queue";
import { putEmbedding } from "services/embeddingService"; import mlIDbStorage, {
import mlIDbStorage, { ML_SEARCH_CONFIG_NAME } from "services/face/db"; ML_SEARCH_CONFIG_NAME,
import { fetchImageBitmap, getLocalFile } from "services/face/image"; type MinimalPersistedFileData,
import { Face, FaceDetection, Landmark, MlFileData } from "services/face/types"; } from "services/face/db";
import { putFaceEmbedding } from "services/face/remote";
import { getLocalFiles } from "services/fileService"; import { getLocalFiles } from "services/fileService";
import { EnteFile } from "types/file"; import { EnteFile } from "types/file";
import { isInternalUserForML } from "utils/user"; import { isInternalUserForML } from "utils/user";
import { indexFaces, saveFaceCrop } from "../face/f-index"; import { indexFaces } from "../face/f-index";
/** /**
* TODO-ML(MR): What and why. * TODO-ML(MR): What and why.
@ -57,63 +52,30 @@ class MLSyncContext {
public localFilesMap: Map<number, EnteFile>; public localFilesMap: Map<number, EnteFile>;
public outOfSyncFiles: EnteFile[]; public outOfSyncFiles: EnteFile[];
public nSyncedFiles: number; public nSyncedFiles: number;
public error?: Error; public error?: Error;
public syncQueue: PQueue; public syncQueue: PQueue;
// TODO: wheather to limit concurrent downloads
// private downloadQueue: PQueue;
private concurrency: number; constructor(token: string, userID: number) {
private comlinkCryptoWorker: Array<
ComlinkWorker<typeof DedicatedCryptoWorker>
>;
private enteWorkers: Array<any>;
constructor(token: string, userID: number, concurrency?: number) {
this.token = token; this.token = token;
this.userID = userID; this.userID = userID;
this.outOfSyncFiles = []; this.outOfSyncFiles = [];
this.nSyncedFiles = 0; this.nSyncedFiles = 0;
this.concurrency = concurrency ?? getConcurrency(); const concurrency = getConcurrency();
this.syncQueue = new PQueue({ concurrency });
log.info("Using concurrency: ", this.concurrency);
// timeout is added on downloads
// timeout on queue will keep the operation open till worker is terminated
this.syncQueue = new PQueue({ concurrency: this.concurrency });
logQueueStats(this.syncQueue, "sync");
// this.downloadQueue = new PQueue({ concurrency: 1 });
// logQueueStats(this.downloadQueue, 'download');
this.comlinkCryptoWorker = new Array(this.concurrency);
this.enteWorkers = new Array(this.concurrency);
}
public async getEnteWorker(id: number): Promise<any> {
const wid = id % this.enteWorkers.length;
console.log("getEnteWorker: ", id, wid);
if (!this.enteWorkers[wid]) {
this.comlinkCryptoWorker[wid] = getDedicatedCryptoWorker();
this.enteWorkers[wid] = await this.comlinkCryptoWorker[wid].remote;
}
return this.enteWorkers[wid];
} }
public async dispose() { public async dispose() {
this.localFilesMap = undefined; this.localFilesMap = undefined;
await this.syncQueue.onIdle(); await this.syncQueue.onIdle();
this.syncQueue.removeAllListeners(); this.syncQueue.removeAllListeners();
for (const enteComlinkWorker of this.comlinkCryptoWorker) {
enteComlinkWorker?.terminate();
}
} }
} }
export const getConcurrency = () => const getConcurrency = () =>
haveWindow() && Math.max(2, Math.ceil(navigator.hardwareConcurrency / 2)); Math.max(2, Math.ceil(navigator.hardwareConcurrency / 2));
class MachineLearningService { class MachineLearningService {
private localSyncContext: Promise<MLSyncContext>; private localSyncContext: Promise<MLSyncContext>;
@ -139,16 +101,12 @@ class MachineLearningService {
return !error && nOutOfSyncFiles > 0; return !error && nOutOfSyncFiles > 0;
} }
public async regenerateFaceCrop(faceID: string) {
return regenerateFaceCrop(faceID);
}
private newMlData(fileId: number) { private newMlData(fileId: number) {
return { return {
fileId, fileId,
mlVersion: 0, mlVersion: 0,
errorCount: 0, errorCount: 0,
} as MlFileData; } as MinimalPersistedFileData;
} }
private async getLocalFilesMap(syncContext: MLSyncContext) { private async getLocalFilesMap(syncContext: MLSyncContext) {
@ -327,9 +285,6 @@ class MachineLearningService {
localFile?: globalThis.File, localFile?: globalThis.File,
) { ) {
try { try {
console.log(
`Indexing ${enteFile.title ?? "<untitled>"} ${enteFile.id}`,
);
const mlFileData = await this.syncFile(enteFile, localFile); const mlFileData = await this.syncFile(enteFile, localFile);
syncContext.nSyncedFiles += 1; syncContext.nSyncedFiles += 1;
return mlFileData; return mlFileData;
@ -363,44 +318,17 @@ class MachineLearningService {
} }
private async syncFile(enteFile: EnteFile, localFile?: globalThis.File) { private async syncFile(enteFile: EnteFile, localFile?: globalThis.File) {
const oldMlFile = await this.getMLFileData(enteFile.id); const oldMlFile = await mlIDbStorage.getFile(enteFile.id);
if (oldMlFile && oldMlFile.mlVersion) { if (oldMlFile && oldMlFile.mlVersion) {
return oldMlFile; return oldMlFile;
} }
const newMlFile = await indexFaces(enteFile, localFile); const newMlFile = await indexFaces(enteFile, localFile);
await this.persistOnServer(newMlFile, enteFile); await putFaceEmbedding(enteFile, newMlFile);
await mlIDbStorage.putFile(newMlFile); await mlIDbStorage.putFile(newMlFile);
return newMlFile; return newMlFile;
} }
private async persistOnServer(mlFileData: MlFileData, enteFile: EnteFile) {
const serverMl = LocalFileMlDataToServerFileMl(mlFileData);
log.debug(() => ({ t: "Local ML file data", mlFileData }));
log.debug(() => ({
t: "Uploaded ML file data",
d: JSON.stringify(serverMl),
}));
const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance();
const { file: encryptedEmbeddingData } =
await comlinkCryptoWorker.encryptMetadata(serverMl, enteFile.key);
log.info(
`putEmbedding embedding to server for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
);
const res = await putEmbedding({
fileID: enteFile.id,
encryptedEmbedding: encryptedEmbeddingData.encryptedData,
decryptionHeader: encryptedEmbeddingData.decryptionHeader,
model: "file-ml-clip-face",
});
log.info("putEmbedding response: ", res);
}
private async getMLFileData(fileId: number) {
return mlIDbStorage.getFile(fileId);
}
private async persistMLFileSyncError(enteFile: EnteFile, e: Error) { private async persistMLFileSyncError(enteFile: EnteFile, e: Error) {
try { try {
await mlIDbStorage.upsertFileInTx(enteFile.id, (mlFileData) => { await mlIDbStorage.upsertFileInTx(enteFile.id, (mlFileData) => {
@ -420,161 +348,3 @@ class MachineLearningService {
} }
export default new MachineLearningService(); export default new MachineLearningService();
export interface FileML extends ServerFileMl {
updatedAt: number;
}
class ServerFileMl {
public fileID: number;
public height?: number;
public width?: number;
public faceEmbedding: ServerFaceEmbeddings;
public constructor(
fileID: number,
faceEmbedding: ServerFaceEmbeddings,
height?: number,
width?: number,
) {
this.fileID = fileID;
this.height = height;
this.width = width;
this.faceEmbedding = faceEmbedding;
}
}
class ServerFaceEmbeddings {
public faces: ServerFace[];
public version: number;
/* TODO
public client?: string;
public error?: boolean;
*/
public constructor(faces: ServerFace[], version: number) {
this.faces = faces;
this.version = version;
}
}
class ServerFace {
public faceID: string;
public embeddings: number[];
public detection: ServerDetection;
public score: number;
public blur: number;
public constructor(
faceID: string,
embeddings: number[],
detection: ServerDetection,
score: number,
blur: number,
) {
this.faceID = faceID;
this.embeddings = embeddings;
this.detection = detection;
this.score = score;
this.blur = blur;
}
}
class ServerDetection {
public box: ServerFaceBox;
public landmarks: Landmark[];
public constructor(box: ServerFaceBox, landmarks: Landmark[]) {
this.box = box;
this.landmarks = landmarks;
}
}
class ServerFaceBox {
public xMin: number;
public yMin: number;
public width: number;
public height: number;
public constructor(
xMin: number,
yMin: number,
width: number,
height: number,
) {
this.xMin = xMin;
this.yMin = yMin;
this.width = width;
this.height = height;
}
}
function LocalFileMlDataToServerFileMl(
localFileMlData: MlFileData,
): ServerFileMl {
if (localFileMlData.errorCount > 0) {
return null;
}
const imageDimensions = localFileMlData.imageDimensions;
const faces: ServerFace[] = [];
for (let i = 0; i < localFileMlData.faces.length; i++) {
const face: Face = localFileMlData.faces[i];
const faceID = face.id;
const embedding = face.embedding;
const score = face.detection.probability;
const blur = face.blurValue;
const detection: FaceDetection = face.detection;
const box = detection.box;
const landmarks = detection.landmarks;
const newBox = new ServerFaceBox(box.x, box.y, box.width, box.height);
const newLandmarks: Landmark[] = [];
for (let j = 0; j < landmarks.length; j++) {
newLandmarks.push({
x: landmarks[j].x,
y: landmarks[j].y,
} as Landmark);
}
// TODO: Add client UA and version
const newFaceObject = new ServerFace(
faceID,
Array.from(embedding),
new ServerDetection(newBox, newLandmarks),
score,
blur,
);
faces.push(newFaceObject);
}
const faceEmbeddings = new ServerFaceEmbeddings(faces, 1);
return new ServerFileMl(
localFileMlData.fileId,
faceEmbeddings,
imageDimensions.height,
imageDimensions.width,
);
}
export function logQueueStats(queue: PQueue, name: string) {
queue.on("active", () =>
log.info(
`queuestats: ${name}: Active, Size: ${queue.size} Pending: ${queue.pending}`,
),
);
queue.on("idle", () => log.info(`queuestats: ${name}: Idle`));
queue.on("error", (error) =>
console.error(`queuestats: ${name}: Error, `, error),
);
}
export const regenerateFaceCrop = async (faceID: string) => {
const fileID = Number(faceID.split("-")[0]);
const personFace = await mlIDbStorage.getFace(fileID, faceID);
if (!personFace) {
throw Error("Face not found");
}
const file = await getLocalFile(personFace.fileId);
const imageBitmap = await fetchImageBitmap(file);
return await saveFaceCrop(imageBitmap, personFace);
};

View file

@ -9,7 +9,6 @@ import { createFaceComlinkWorker } from "services/face";
import mlIDbStorage from "services/face/db"; import mlIDbStorage from "services/face/db";
import type { DedicatedMLWorker } from "services/face/face.worker"; import type { DedicatedMLWorker } from "services/face/face.worker";
import { EnteFile } from "types/file"; import { EnteFile } from "types/file";
import { logQueueStats } from "./machineLearningService";
export type JobState = "Scheduled" | "Running" | "NotScheduled"; export type JobState = "Scheduled" | "Running" | "NotScheduled";
@ -309,3 +308,15 @@ class MLWorkManager {
} }
export default new MLWorkManager(); export default new MLWorkManager();
export function logQueueStats(queue: PQueue, name: string) {
queue.on("active", () =>
log.info(
`queuestats: ${name}: Active, Size: ${queue.size} Pending: ${queue.pending}`,
),
);
queue.on("idle", () => log.info(`queuestats: ${name}: Idle`));
queue.on("error", (error) =>
console.error(`queuestats: ${name}: Error, `, error),
);
}

View file

@ -3,7 +3,7 @@ import log from "@/next/log";
import * as chrono from "chrono-node"; import * as chrono from "chrono-node";
import { t } from "i18next"; import { t } from "i18next";
import mlIDbStorage from "services/face/db"; import mlIDbStorage from "services/face/db";
import { Person } from "services/face/types"; import type { Person } from "services/face/people";
import { defaultMLVersion } from "services/machineLearning/machineLearningService"; import { defaultMLVersion } from "services/machineLearning/machineLearningService";
import { Collection } from "types/collection"; import { Collection } from "types/collection";
import { EntityType, LocationTag, LocationTagData } from "types/entity"; import { EntityType, LocationTag, LocationTagData } from "types/entity";

View file

@ -1,6 +1,6 @@
import { FILE_TYPE } from "@/media/file-type"; import { FILE_TYPE } from "@/media/file-type";
import { IndexStatus } from "services/face/db"; import { IndexStatus } from "services/face/db";
import { Person } from "services/face/types"; import type { Person } from "services/face/people";
import { City } from "services/locationSearchService"; import { City } from "services/locationSearchService";
import { LocationTagData } from "types/entity"; import { LocationTagData } from "types/entity";
import { EnteFile } from "types/file"; import { EnteFile } from "types/file";

View file

@ -1,4 +1,4 @@
import { clearCaches } from "@/next/blob-cache"; import { clearBlobCaches } from "@/next/blob-cache";
import log from "@/next/log"; import log from "@/next/log";
import InMemoryStore from "@ente/shared/storage/InMemoryStore"; import InMemoryStore from "@ente/shared/storage/InMemoryStore";
import localForage from "@ente/shared/storage/localForage"; import localForage from "@ente/shared/storage/localForage";
@ -43,7 +43,7 @@ export const accountLogout = async () => {
log.error("Ignoring error during logout (local forage)", e); log.error("Ignoring error during logout (local forage)", e);
} }
try { try {
await clearCaches(); await clearBlobCaches();
} catch (e) { } catch (e) {
log.error("Ignoring error during logout (cache)", e); log.error("Ignoring error during logout (cache)", e);
} }

View file

@ -20,8 +20,8 @@ export type BlobCacheNamespace = (typeof blobCacheNames)[number];
* *
* This cache is suitable for storing large amounts of data (entire files). * This cache is suitable for storing large amounts of data (entire files).
* *
* To obtain a cache for a given namespace, use {@link openCache}. To clear all * To obtain a cache for a given namespace, use {@link openBlobCache}. To clear all
* cached data (e.g. during logout), use {@link clearCaches}. * cached data (e.g. during logout), use {@link clearBlobCaches}.
* *
* [Note: Caching files] * [Note: Caching files]
* *
@ -69,14 +69,31 @@ export interface BlobCache {
delete: (key: string) => Promise<boolean>; delete: (key: string) => Promise<boolean>;
} }
const cachedCaches = new Map<BlobCacheNamespace, BlobCache>();
/** /**
* Return the {@link BlobCache} corresponding to the given {@link name}. * Return the {@link BlobCache} corresponding to the given {@link name}.
* *
* This is a wrapper over {@link openBlobCache} that caches (pun intended) the
* cache and returns the same one each time it is called with the same name.
* It'll open the cache lazily the first time it is invoked.
*/
export const blobCache = async (
name: BlobCacheNamespace,
): Promise<BlobCache> => {
let c = cachedCaches.get(name);
if (!c) cachedCaches.set(name, (c = await openBlobCache(name)));
return c;
};
/**
* Create a new {@link BlobCache} corresponding to the given {@link name}.
*
* @param name One of the arbitrary but predefined namespaces of type * @param name One of the arbitrary but predefined namespaces of type
* {@link BlobCacheNamespace} which group related data and allow us to use the * {@link BlobCacheNamespace} which group related data and allow us to use the
* same key across namespaces. * same key across namespaces.
*/ */
export const openCache = async ( export const openBlobCache = async (
name: BlobCacheNamespace, name: BlobCacheNamespace,
): Promise<BlobCache> => ): Promise<BlobCache> =>
isElectron() ? openOPFSCacheWeb(name) : openWebCache(name); isElectron() ? openOPFSCacheWeb(name) : openWebCache(name);
@ -194,7 +211,7 @@ export const cachedOrNew = async (
key: string, key: string,
get: () => Promise<Blob>, get: () => Promise<Blob>,
): Promise<Blob> => { ): Promise<Blob> => {
const cache = await openCache(cacheName); const cache = await openBlobCache(cacheName);
const cachedBlob = await cache.get(key); const cachedBlob = await cache.get(key);
if (cachedBlob) return cachedBlob; if (cachedBlob) return cachedBlob;
@ -204,15 +221,17 @@ export const cachedOrNew = async (
}; };
/** /**
* Delete all cached data. * Delete all cached data, including cached caches.
* *
* Meant for use during logout, to reset the state of the user's account. * Meant for use during logout, to reset the state of the user's account.
*/ */
export const clearCaches = async () => export const clearBlobCaches = async () => {
isElectron() ? clearOPFSCaches() : clearWebCaches(); cachedCaches.clear();
return isElectron() ? clearOPFSCaches() : clearWebCaches();
};
const clearWebCaches = async () => { const clearWebCaches = async () => {
await Promise.all(blobCacheNames.map((name) => caches.delete(name))); await Promise.allSettled(blobCacheNames.map((name) => caches.delete(name)));
}; };
const clearOPFSCaches = async () => { const clearOPFSCaches = async () => {