[web] ML cleanup - Part 7/x (#1771)

This commit is contained in:
Manav Rathi 2024-05-19 20:19:59 +05:30 committed by GitHub
commit 69f06f753c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 566 additions and 855 deletions

View file

@ -9,7 +9,7 @@ import { useCallback, useContext, useEffect, useRef, useState } from "react";
import { components } from "react-select";
import AsyncSelect from "react-select/async";
import { InputActionMeta } from "react-select/src/types";
import { Person } from "services/face/types";
import type { Person } from "services/face/people";
import { City } from "services/locationSearchService";
import {
getAutoCompleteSuggestions,

View file

@ -1,10 +1,11 @@
import { blobCache } from "@/next/blob-cache";
import log from "@/next/log";
import { Skeleton, styled } from "@mui/material";
import { Legend } from "components/PhotoViewer/styledComponents/Legend";
import { t } from "i18next";
import React, { useEffect, useState } from "react";
import mlIDbStorage from "services/face/db";
import { Face, Person, type MlFileData } from "services/face/types";
import type { Person } from "services/face/people";
import { EnteFile } from "types/file";
const FaceChipContainer = styled("div")`
@ -57,10 +58,7 @@ export const PeopleList = React.memo((props: PeopleListProps) => {
props.onSelect && props.onSelect(person, index)
}
>
<FaceCropImageView
faceID={person.displayFaceId}
cacheKey={person.faceCropCacheKey}
/>
<FaceCropImageView faceID={person.displayFaceId} />
</FaceChip>
))}
</FaceChipContainer>
@ -108,7 +106,7 @@ export function UnidentifiedFaces(props: {
file: EnteFile;
updateMLDataIndex: number;
}) {
const [faces, setFaces] = useState<Array<Face>>([]);
const [faces, setFaces] = useState<{ id: string }[]>([]);
useEffect(() => {
let didCancel = false;
@ -136,10 +134,7 @@ export function UnidentifiedFaces(props: {
{faces &&
faces.map((face, index) => (
<FaceChip key={index}>
<FaceCropImageView
faceID={face.id}
cacheKey={face.crop?.cacheKey}
/>
<FaceCropImageView faceID={face.id} />
</FaceChip>
))}
</FaceChipContainer>
@ -149,13 +144,9 @@ export function UnidentifiedFaces(props: {
interface FaceCropImageViewProps {
faceID: string;
cacheKey?: string;
}
const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({
faceID,
cacheKey,
}) => {
const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({ faceID }) => {
const [objectURL, setObjectURL] = useState<string | undefined>();
useEffect(() => {
@ -165,12 +156,16 @@ const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({
if (faceID && electron) {
electron
.legacyFaceCrop(faceID)
/*
cachedOrNew("face-crops", cacheKey, async () => {
return machineLearningService.regenerateFaceCrop(
faceId,
);
})*/
.then(async (data) => {
if (data) return data;
/*
TODO(MR): regen if needed and get this to work on web too.
cachedOrNew("face-crops", cacheKey, async () => {
return regenerateFaceCrop(faceId);
})*/
const cache = await blobCache("face-crops");
return await cache.get(faceID);
})
.then((data) => {
if (data) {
const blob = new Blob([data]);
@ -183,7 +178,7 @@ const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({
didCancel = true;
if (objectURL) URL.revokeObjectURL(objectURL);
};
}, [faceID, cacheKey]);
}, [faceID]);
return objectURL ? (
<img src={objectURL} />
@ -192,9 +187,9 @@ const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({
);
};
async function getPeopleList(file: EnteFile): Promise<Array<Person>> {
async function getPeopleList(file: EnteFile): Promise<Person[]> {
let startTime = Date.now();
const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id);
const mlFileData = await mlIDbStorage.getFile(file.id);
log.info(
"getPeopleList:mlFilesStore:getItem",
Date.now() - startTime,
@ -226,8 +221,8 @@ async function getPeopleList(file: EnteFile): Promise<Array<Person>> {
return peopleList;
}
async function getUnidentifiedFaces(file: EnteFile): Promise<Array<Face>> {
const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id);
async function getUnidentifiedFaces(file: EnteFile): Promise<{ id: string }[]> {
const mlFileData = await mlIDbStorage.getFile(file.id);
return mlFileData?.faces?.filter(
(f) => f.personId === null || f.personId === undefined,

View file

@ -1,6 +1,6 @@
import { FILE_TYPE } from "@/media/file-type";
import { decodeLivePhoto } from "@/media/live-photo";
import { openCache, type BlobCache } from "@/next/blob-cache";
import { blobCache, type BlobCache } from "@/next/blob-cache";
import log from "@/next/log";
import { APPS } from "@ente/shared/apps/constants";
import ComlinkCryptoWorker from "@ente/shared/crypto";
@ -91,7 +91,7 @@ class DownloadManagerImpl {
}
this.downloadClient = createDownloadClient(app, tokens);
try {
this.thumbnailCache = await openCache("thumbs");
this.thumbnailCache = await blobCache("thumbs");
} catch (e) {
log.error(
"Failed to open thumbnail cache, will continue without it",
@ -100,7 +100,7 @@ class DownloadManagerImpl {
}
// TODO (MR): Revisit full file caching cf disk space usage
// try {
// if (isElectron()) this.fileCache = await openCache("files");
// if (isElectron()) this.fileCache = await cache("files");
// } catch (e) {
// log.error("Failed to open file cache, will continue without it", e);
// }

View file

@ -7,7 +7,7 @@ import HTTPService from "@ente/shared/network/HTTPService";
import { getEndpoint } from "@ente/shared/network/api";
import localForage from "@ente/shared/storage/localForage";
import { getToken } from "@ente/shared/storage/localStorage/helpers";
import { FileML } from "services/machineLearning/machineLearningService";
import { FileML } from "services/face/remote";
import type {
Embedding,
EmbeddingModel,

View file

@ -1,8 +1,9 @@
import { Hdbscan, type DebugInfo } from "hdbscan";
import { type Cluster } from "services/face/types";
export type Cluster = number[];
export interface ClusterFacesResult {
clusters: Array<Cluster>;
clusters: Cluster[];
noise: Cluster;
debugInfo?: DebugInfo;
}

View file

@ -9,7 +9,8 @@ import {
openDB,
} from "idb";
import isElectron from "is-electron";
import { Face, MlFileData, Person } from "services/face/types";
import type { Person } from "services/face/people";
import type { MlFileData } from "services/face/types";
import {
DEFAULT_ML_SEARCH_CONFIG,
MAX_ML_SYNC_ERROR_COUNT,
@ -23,6 +24,18 @@ export interface IndexStatus {
peopleIndexSynced: boolean;
}
/**
* TODO(MR): Transient type with an intersection of values that both existing
* and new types during the migration will have. Eventually we'll store the the
* server ML data shape here exactly.
*/
export interface MinimalPersistedFileData {
    /** ID of the {@link EnteFile} this ML data belongs to. */
    fileId: number;
    /** Version of the ML sync/indexing pass that produced this data. */
    mlVersion: number;
    /** Count of errors encountered while processing this file — TODO confirm semantics. */
    errorCount: number;
    /**
     * Faces detected in the file. A face whose `personId` is null or
     * undefined is treated as unidentified (see `getUnidentifiedFaces`).
     */
    faces?: { personId?: number; id: string }[];
}
interface Config {}
export const ML_SEARCH_CONFIG_NAME = "ml-search";
@ -31,7 +44,7 @@ const MLDATA_DB_NAME = "mldata";
interface MLDb extends DBSchema {
files: {
key: number;
value: MlFileData;
value: MinimalPersistedFileData;
indexes: { mlVersion: [number, number] };
};
people: {
@ -211,38 +224,6 @@ class MLIDbStorage {
await this.db;
}
public async getAllFileIds() {
const db = await this.db;
return db.getAllKeys("files");
}
public async putAllFilesInTx(mlFiles: Array<MlFileData>) {
const db = await this.db;
const tx = db.transaction("files", "readwrite");
await Promise.all(mlFiles.map((mlFile) => tx.store.put(mlFile)));
await tx.done;
}
public async removeAllFilesInTx(fileIds: Array<number>) {
const db = await this.db;
const tx = db.transaction("files", "readwrite");
await Promise.all(fileIds.map((fileId) => tx.store.delete(fileId)));
await tx.done;
}
public async newTransaction<
Name extends StoreNames<MLDb>,
Mode extends IDBTransactionMode = "readonly",
>(storeNames: Name, mode?: Mode) {
const db = await this.db;
return db.transaction(storeNames, mode);
}
public async commit(tx: IDBPTransaction<MLDb>) {
return tx.done;
}
public async getAllFileIdsForUpdate(
tx: IDBPTransaction<MLDb, ["files"], "readwrite">,
) {
@ -276,16 +257,11 @@ class MLIDbStorage {
return fileIds;
}
public async getFile(fileId: number) {
public async getFile(fileId: number): Promise<MinimalPersistedFileData> {
const db = await this.db;
return db.get("files", fileId);
}
public async getAllFiles() {
const db = await this.db;
return db.getAll("files");
}
public async putFile(mlFile: MlFileData) {
const db = await this.db;
return db.put("files", mlFile);
@ -293,7 +269,7 @@ class MLIDbStorage {
public async upsertFileInTx(
fileId: number,
upsert: (mlFile: MlFileData) => MlFileData,
upsert: (mlFile: MinimalPersistedFileData) => MinimalPersistedFileData,
) {
const db = await this.db;
const tx = db.transaction("files", "readwrite");
@ -306,7 +282,7 @@ class MLIDbStorage {
}
public async putAllFiles(
mlFiles: Array<MlFileData>,
mlFiles: MinimalPersistedFileData[],
tx: IDBPTransaction<MLDb, ["files"], "readwrite">,
) {
await Promise.all(mlFiles.map((mlFile) => tx.store.put(mlFile)));
@ -319,44 +295,6 @@ class MLIDbStorage {
await Promise.all(fileIds.map((fileId) => tx.store.delete(fileId)));
}
public async getFace(fileID: number, faceId: string) {
const file = await this.getFile(fileID);
const face = file.faces.filter((f) => f.id === faceId);
return face[0];
}
public async getAllFacesMap() {
const startTime = Date.now();
const db = await this.db;
const allFiles = await db.getAll("files");
const allFacesMap = new Map<number, Array<Face>>();
allFiles.forEach(
(mlFileData) =>
mlFileData.faces &&
allFacesMap.set(mlFileData.fileId, mlFileData.faces),
);
log.info("getAllFacesMap", Date.now() - startTime, "ms");
return allFacesMap;
}
public async updateFaces(allFacesMap: Map<number, Face[]>) {
const startTime = Date.now();
const db = await this.db;
const tx = db.transaction("files", "readwrite");
let cursor = await tx.store.openCursor();
while (cursor) {
if (allFacesMap.has(cursor.key)) {
const mlFileData = { ...cursor.value };
mlFileData.faces = allFacesMap.get(cursor.key);
cursor.update(mlFileData);
}
cursor = await cursor.continue();
}
await tx.done;
log.info("updateFaces", Date.now() - startTime, "ms");
}
public async getPerson(id: number) {
const db = await this.db;
return db.get("people", id);
@ -367,21 +305,6 @@ class MLIDbStorage {
return db.getAll("people");
}
public async putPerson(person: Person) {
const db = await this.db;
return db.put("people", person);
}
public async clearAllPeople() {
const db = await this.db;
return db.clear("people");
}
public async getIndexVersion(index: string) {
const db = await this.db;
return db.get("versions", index);
}
public async incrementIndexVersion(index: StoreNames<MLDb>) {
if (index === "versions") {
throw new Error("versions store can not be versioned");
@ -396,11 +319,6 @@ class MLIDbStorage {
return version;
}
public async setIndexVersion(index: string, version: number) {
const db = await this.db;
return db.put("versions", version, index);
}
public async getConfig<T extends Config>(name: string, def: T) {
const db = await this.db;
const tx = db.transaction("configs", "readwrite");
@ -464,66 +382,6 @@ class MLIDbStorage {
peopleIndexVersion === filesIndexVersion,
};
}
// for debug purpose
public async getAllMLData() {
const db = await this.db;
const tx = db.transaction(db.objectStoreNames, "readonly");
const allMLData: any = {};
for (const store of tx.objectStoreNames) {
const keys = await tx.objectStore(store).getAllKeys();
const data = await tx.objectStore(store).getAll();
allMLData[store] = {};
for (let i = 0; i < keys.length; i++) {
allMLData[store][keys[i]] = data[i];
}
}
await tx.done;
const files = allMLData["files"];
for (const fileId of Object.keys(files)) {
const fileData = files[fileId];
fileData.faces?.forEach(
(f) => (f.embedding = Array.from(f.embedding)),
);
}
return allMLData;
}
// for debug purpose, this will overwrite all data
public async putAllMLData(allMLData: Map<string, any>) {
const db = await this.db;
const tx = db.transaction(db.objectStoreNames, "readwrite");
for (const store of tx.objectStoreNames) {
const records = allMLData[store];
if (!records) {
continue;
}
const txStore = tx.objectStore(store);
if (store === "files") {
const files = records;
for (const fileId of Object.keys(files)) {
const fileData = files[fileId];
fileData.faces?.forEach(
(f) => (f.embedding = Float32Array.from(f.embedding)),
);
}
}
await txStore.clear();
for (const key of Object.keys(records)) {
if (txStore.keyPath) {
txStore.put(records[key]);
} else {
txStore.put(records[key], key);
}
}
}
await tx.done;
}
}
export default new MLIDbStorage();

View file

@ -1,26 +1,29 @@
import { FILE_TYPE } from "@/media/file-type";
import { openCache } from "@/next/blob-cache";
import { blobCache } from "@/next/blob-cache";
import log from "@/next/log";
import { workerBridge } from "@/next/worker/worker-bridge";
import { euclidean } from "hdbscan";
import { Matrix } from "ml-matrix";
import { Box, Dimensions, Point, enlargeBox } from "services/face/geom";
import {
Box,
Dimensions,
Point,
enlargeBox,
roundBox,
} from "services/face/geom";
import type {
Face,
FaceAlignment,
FaceCrop,
FaceDetection,
FaceEmbedding,
type MlFileData,
MlFileData,
} from "services/face/types";
import { defaultMLVersion } from "services/machineLearning/machineLearningService";
import { getSimilarityTransformation } from "similarity-transformation";
import type { EnteFile } from "types/file";
import { fetchImageBitmap, getLocalFileImageBitmap } from "./file";
import {
clamp,
createGrayscaleIntMatrixFromNormalized2List,
fetchImageBitmap,
getLocalFileImageBitmap,
grayscaleIntMatrixFromNormalized2List,
pixelRGBBilinear,
warpAffineFloat32List,
} from "./image";
@ -85,47 +88,49 @@ const fetchOrCreateImageBitmap = async (
const indexFaces_ = async (enteFile: EnteFile, imageBitmap: ImageBitmap) => {
const fileID = enteFile.id;
const { width, height } = imageBitmap;
const imageDimensions: Dimensions = imageBitmap;
const mlFile: MlFileData = {
fileId: fileID,
mlVersion: defaultMLVersion,
imageDimensions: { width, height },
imageDimensions,
errorCount: 0,
};
const faceDetections = await detectFaces(imageBitmap);
const detectedFaces = faceDetections.map((detection) => ({
id: makeFaceID(fileID, detection, mlFile.imageDimensions),
id: makeFaceID(fileID, detection, imageDimensions),
fileId: fileID,
detection,
}));
mlFile.faces = detectedFaces;
if (detectedFaces.length > 0) {
await Promise.all(
detectedFaces.map((face) => saveFaceCrop(imageBitmap, face)),
);
const alignments: FaceAlignment[] = [];
// Execute the face alignment calculations
for (const face of mlFile.faces) {
face.alignment = faceAlignment(face.detection);
const alignment = faceAlignment(face.detection);
face.alignment = alignment;
alignments.push(alignment);
await saveFaceCrop(imageBitmap, face);
}
// Extract face images and convert to Float32Array
const faceAlignments = mlFile.faces.map((f) => f.alignment);
const alignedFacesData = await extractFaceImagesToFloat32(
faceAlignments,
mobileFaceNetFaceSize,
const alignedFacesData = convertToMobileFaceNetInput(
imageBitmap,
alignments,
);
const blurValues = detectBlur(alignedFacesData, mlFile.faces);
mlFile.faces.forEach((f, i) => (f.blurValue = blurValues[i]));
const embeddings = await faceEmbeddings(alignedFacesData);
const embeddings = await computeEmbeddings(alignedFacesData);
mlFile.faces.forEach((f, i) => (f.embedding = embeddings[i]));
convertFaceDetectionsToRelative(mlFile);
// TODO-ML: Skip if somehow already relative. But why would it be?
// if (face.detection.box.x + face.detection.box.width < 2) continue;
mlFile.faces.forEach((face) => {
face.detection = relativeDetection(face.detection, imageDimensions);
});
}
return mlFile;
@ -170,8 +175,7 @@ const convertToYOLOInputFloat32ChannelsFirst = (imageBitmap: ImageBitmap) => {
const requiredWidth = 640;
const requiredHeight = 640;
const width = imageBitmap.width;
const height = imageBitmap.height;
const { width, height } = imageBitmap;
// Create an OffscreenCanvas and set its size.
const offscreenCanvas = new OffscreenCanvas(width, height);
@ -405,37 +409,39 @@ const faceAlignmentUsingSimilarityTransform = (
return { affineMatrix, center, size, rotation };
};
async function extractFaceImagesToFloat32(
faceAlignments: Array<FaceAlignment>,
faceSize: number,
image: ImageBitmap,
): Promise<Float32Array> {
const convertToMobileFaceNetInput = (
imageBitmap: ImageBitmap,
faceAlignments: FaceAlignment[],
): Float32Array => {
const faceSize = mobileFaceNetFaceSize;
const faceData = new Float32Array(
faceAlignments.length * faceSize * faceSize * 3,
);
for (let i = 0; i < faceAlignments.length; i++) {
const alignedFace = faceAlignments[i];
const { affineMatrix } = faceAlignments[i];
const faceDataOffset = i * faceSize * faceSize * 3;
warpAffineFloat32List(
image,
alignedFace,
imageBitmap,
affineMatrix,
faceSize,
faceData,
faceDataOffset,
);
}
return faceData;
}
};
/**
* Laplacian blur detection.
*
* Return an array of detected blur values, one for each face in {@link faces}.
* The face data is taken from the slice of {@link alignedFacesData}
* corresponding to each face of {@link faces}.
*/
const detectBlur = (alignedFaces: Float32Array, faces: Face[]): number[] =>
const detectBlur = (alignedFacesData: Float32Array, faces: Face[]): number[] =>
faces.map((face, i) => {
const faceImage = createGrayscaleIntMatrixFromNormalized2List(
alignedFaces,
const faceImage = grayscaleIntMatrixFromNormalized2List(
alignedFacesData,
i,
mobileFaceNetFaceSize,
mobileFaceNetFaceSize,
@ -609,19 +615,20 @@ const matrixVariance = (matrix: number[][]): number => {
};
const mobileFaceNetFaceSize = 112;
const mobileFaceNetEmbeddingSize = 192;
/**
* Compute embeddings for the given {@link faceData}.
*
* The model used is MobileFaceNet, running in an ONNX runtime.
*/
const faceEmbeddings = async (
const computeEmbeddings = async (
faceData: Float32Array,
): Promise<Array<FaceEmbedding>> => {
): Promise<Float32Array[]> => {
const outputData = await workerBridge.faceEmbeddings(faceData);
const embeddingSize = 192;
const embeddings = new Array<FaceEmbedding>(
const embeddingSize = mobileFaceNetEmbeddingSize;
const embeddings = new Array<Float32Array>(
outputData.length / embeddingSize,
);
for (let i = 0; i < embeddings.length; i++) {
@ -632,18 +639,9 @@ const faceEmbeddings = async (
return embeddings;
};
const convertFaceDetectionsToRelative = (mlFile: MlFileData) => {
for (let i = 0; i < mlFile.faces.length; i++) {
const face = mlFile.faces[i];
// Skip if somehow already relative.
if (face.detection.box.x + face.detection.box.width < 2) continue;
face.detection = relativeDetection(
face.detection,
mlFile.imageDimensions,
);
}
};
/**
* Convert the coordinates to between 0-1, normalized by the image's dimensions.
*/
const relativeDetection = (
faceDetection: FaceDetection,
{ width, height }: Dimensions,
@ -663,15 +661,13 @@ const relativeDetection = (
};
export const saveFaceCrop = async (imageBitmap: ImageBitmap, face: Face) => {
const faceCrop = getFaceCrop(imageBitmap, face.detection);
const faceCrop = extractFaceCrop(imageBitmap, face.alignment);
const blob = await imageBitmapToBlob(faceCrop);
faceCrop.close();
const blob = await imageBitmapToBlob(faceCrop.image);
const cache = await openCache("face-crops");
const cache = await blobCache("face-crops");
await cache.put(face.id, blob);
faceCrop.image.close();
return blob;
};
@ -681,68 +677,44 @@ const imageBitmapToBlob = (imageBitmap: ImageBitmap) => {
return canvas.convertToBlob({ type: "image/jpeg", quality: 0.8 });
};
const getFaceCrop = (
const extractFaceCrop = (
imageBitmap: ImageBitmap,
faceDetection: FaceDetection,
): FaceCrop => {
const alignment = faceAlignment(faceDetection);
alignment: FaceAlignment,
): ImageBitmap => {
// TODO-ML: Do we need to round twice?
const alignmentBox = roundBox(
new Box({
x: alignment.center.x - alignment.size / 2,
y: alignment.center.y - alignment.size / 2,
width: alignment.size,
height: alignment.size,
}),
);
const padding = 0.25;
const maxSize = 256;
const alignmentBox = new Box({
x: alignment.center.x - alignment.size / 2,
y: alignment.center.y - alignment.size / 2,
width: alignment.size,
height: alignment.size,
}).round();
const scaleForPadding = 1 + padding * 2;
const paddedBox = enlargeBox(alignmentBox, scaleForPadding).round();
const faceImageBitmap = cropWithRotation(imageBitmap, paddedBox, 0, {
width: maxSize,
height: maxSize,
});
const paddedBox = roundBox(enlargeBox(alignmentBox, scaleForPadding));
return {
image: faceImageBitmap,
imageBox: paddedBox,
};
// TODO-ML: The rotation doesn't seem to be used? it's set to 0.
return cropWithRotation(imageBitmap, paddedBox, 0, 256);
};
export function cropWithRotation(
const cropWithRotation = (
imageBitmap: ImageBitmap,
cropBox: Box,
rotation?: number,
maxSize?: Dimensions,
minSize?: Dimensions,
) {
const box = cropBox.round();
rotation: number,
maxDimension: number,
) => {
const box = roundBox(cropBox);
const outputSize = { width: box.width, height: box.height };
if (maxSize) {
const minScale = Math.min(
maxSize.width / box.width,
maxSize.height / box.height,
);
if (minScale < 1) {
outputSize.width = Math.round(minScale * box.width);
outputSize.height = Math.round(minScale * box.height);
}
}
if (minSize) {
const maxScale = Math.max(
minSize.width / box.width,
minSize.height / box.height,
);
if (maxScale > 1) {
outputSize.width = Math.round(maxScale * box.width);
outputSize.height = Math.round(maxScale * box.height);
}
const scale = Math.min(maxDimension / box.width, maxDimension / box.height);
if (scale < 1) {
outputSize.width = Math.round(scale * box.width);
outputSize.height = Math.round(scale * box.height);
}
// log.info({ imageBitmap, box, outputSize });
const offscreen = new OffscreenCanvas(outputSize.width, outputSize.height);
const offscreenCtx = offscreen.getContext("2d");
offscreenCtx.imageSmoothingQuality = "high";
@ -773,4 +745,4 @@ export function cropWithRotation(
);
return offscreen.transferToImageBitmap();
}
};

View file

@ -22,11 +22,6 @@ export class DedicatedMLWorker {
await downloadManager.init(APPS.PHOTOS, { token });
return mlService.sync(token, userID);
}
public async regenerateFaceCrop(token: string, faceID: string) {
await downloadManager.init(APPS.PHOTOS, { token });
return mlService.regenerateFaceCrop(faceID);
}
}
expose(DedicatedMLWorker, self);

View file

@ -0,0 +1,37 @@
import { FILE_TYPE } from "@/media/file-type";
import { decodeLivePhoto } from "@/media/live-photo";
import DownloadManager from "services/download";
import { getLocalFiles } from "services/fileService";
import { EnteFile } from "types/file";
import { getRenderableImage } from "utils/file";
export async function getLocalFile(fileId: number) {
const localFiles = await getLocalFiles();
return localFiles.find((f) => f.id === fileId);
}
export const fetchImageBitmap = async (file: EnteFile) =>
fetchRenderableBlob(file).then(createImageBitmap);
/**
 * Download {@link file} and convert it into a blob that the browser can
 * render. For live photos, the still image component is extracted first.
 */
async function fetchRenderableBlob(file: EnteFile) {
    const stream = await DownloadManager.getFile(file);
    const blob = await new Response(stream).blob();
    // Plain images can be converted directly.
    if (file.metadata.fileType === FILE_TYPE.IMAGE)
        return await getRenderableImage(file.metadata.title, blob);
    // Otherwise assume a live photo: pull out its image component.
    const { imageFileName, imageData } = await decodeLivePhoto(
        file.metadata.title,
        blob,
    );
    return await getRenderableImage(imageFileName, new Blob([imageData]));
}
/**
 * Decode an ImageBitmap from an on-disk {@link localFile} corresponding to
 * the given {@link enteFile}, converting to a renderable format if needed.
 */
export async function getLocalFileImageBitmap(
    enteFile: EnteFile,
    localFile: globalThis.File,
) {
    // A File is a Blob; convert it into something the browser can render.
    const renderable = await getRenderableImage(
        enteFile.metadata.title,
        localFile,
    );
    return createImageBitmap(renderable);
}

View file

@ -13,13 +13,6 @@ export interface Dimensions {
height: number;
}
export interface IBoundingBox {
left: number;
top: number;
right: number;
bottom: number;
}
export interface IRect {
x: number;
y: number;
@ -27,20 +20,6 @@ export interface IRect {
height: number;
}
export const boxFromBoundingBox = ({
left,
top,
right,
bottom,
}: IBoundingBox) => {
return new Box({
x: left,
y: top,
width: right - left,
height: bottom - top,
});
};
export class Box implements IRect {
public x: number;
public y: number;
@ -53,36 +32,26 @@ export class Box implements IRect {
this.width = width;
this.height = height;
}
public get topLeft(): Point {
return new Point(this.x, this.y);
}
public get bottomRight(): Point {
return new Point(this.x + this.width, this.y + this.height);
}
public round(): Box {
const [x, y, width, height] = [
this.x,
this.y,
this.width,
this.height,
].map((val) => Math.round(val));
return new Box({ x, y, width, height });
}
}
export function enlargeBox(box: Box, factor: number = 1.5) {
/** Round all the components of the box. */
export const roundBox = (box: Box): Box =>
    new Box({
        x: Math.round(box.x),
        y: Math.round(box.y),
        width: Math.round(box.width),
        height: Math.round(box.height),
    });
/** Increase the size of the given {@link box} by {@link factor}. */
export const enlargeBox = (box: Box, factor: number) => {
const center = new Point(box.x + box.width / 2, box.y + box.height / 2);
const newWidth = factor * box.width;
const newHeight = factor * box.height;
const size = new Point(box.width, box.height);
const newHalfSize = new Point((factor * size.x) / 2, (factor * size.y) / 2);
return boxFromBoundingBox({
left: center.x - newHalfSize.x,
top: center.y - newHalfSize.y,
right: center.x + newHalfSize.x,
bottom: center.y + newHalfSize.y,
return new Box({
x: center.x - newWidth / 2,
y: center.y - newHeight / 2,
width: newWidth,
height: newHeight,
});
}
};

View file

@ -1,11 +1,4 @@
import { FILE_TYPE } from "@/media/file-type";
import { decodeLivePhoto } from "@/media/live-photo";
import { Matrix, inverse } from "ml-matrix";
import DownloadManager from "services/download";
import { FaceAlignment } from "services/face/types";
import { getLocalFiles } from "services/fileService";
import { EnteFile } from "types/file";
import { getRenderableImage } from "utils/file";
/**
* Clamp {@link value} to between {@link min} and {@link max}, inclusive.
@ -13,42 +6,11 @@ import { getRenderableImage } from "utils/file";
/** Clamp {@link value} to between {@link min} and {@link max}, inclusive. */
export const clamp = (value: number, min: number, max: number) => {
    // Raise to the floor first, then cap at the ceiling.
    const flooredAtMin = Math.max(min, value);
    return Math.min(max, flooredAtMin);
};
export async function getLocalFile(fileId: number) {
const localFiles = await getLocalFiles();
return localFiles.find((f) => f.id === fileId);
}
export const fetchImageBitmap = async (file: EnteFile) =>
fetchRenderableBlob(file).then(createImageBitmap);
async function fetchRenderableBlob(file: EnteFile) {
const fileStream = await DownloadManager.getFile(file);
const fileBlob = await new Response(fileStream).blob();
if (file.metadata.fileType === FILE_TYPE.IMAGE) {
return await getRenderableImage(file.metadata.title, fileBlob);
} else {
const { imageFileName, imageData } = await decodeLivePhoto(
file.metadata.title,
fileBlob,
);
return await getRenderableImage(imageFileName, new Blob([imageData]));
}
}
export async function getLocalFileImageBitmap(
enteFile: EnteFile,
localFile: globalThis.File,
) {
let fileBlob = localFile as Blob;
fileBlob = await getRenderableImage(enteFile.metadata.title, fileBlob);
return createImageBitmap(fileBlob);
}
/**
* Returns the pixel value (RGB) at the given coordinates ({@link fx},
* {@link fy}) using bicubic interpolation.
* {@link fy}) using bilinear interpolation.
*/
export function pixelRGBBicubic(
export function pixelRGBBilinear(
fx: number,
fy: number,
imageData: Uint8ClampedArray,
@ -59,6 +21,72 @@ export function pixelRGBBicubic(
fx = clamp(fx, 0, imageWidth - 1);
fy = clamp(fy, 0, imageHeight - 1);
// Get the surrounding coordinates and their weights.
const x0 = Math.floor(fx);
const x1 = Math.ceil(fx);
const y0 = Math.floor(fy);
const y1 = Math.ceil(fy);
const dx = fx - x0;
const dy = fy - y0;
const dx1 = 1.0 - dx;
const dy1 = 1.0 - dy;
// Get the original pixels.
const pixel1 = pixelRGBA(imageData, imageWidth, imageHeight, x0, y0);
const pixel2 = pixelRGBA(imageData, imageWidth, imageHeight, x1, y0);
const pixel3 = pixelRGBA(imageData, imageWidth, imageHeight, x0, y1);
const pixel4 = pixelRGBA(imageData, imageWidth, imageHeight, x1, y1);
const bilinear = (val1: number, val2: number, val3: number, val4: number) =>
Math.round(
val1 * dx1 * dy1 +
val2 * dx * dy1 +
val3 * dx1 * dy +
val4 * dx * dy,
);
// Return interpolated pixel colors.
return {
r: bilinear(pixel1.r, pixel2.r, pixel3.r, pixel4.r),
g: bilinear(pixel1.g, pixel2.g, pixel3.g, pixel4.g),
b: bilinear(pixel1.b, pixel2.b, pixel3.b, pixel4.b),
};
}
/**
 * Read the RGBA components of the pixel at ({@link x}, {@link y}), returning
 * transparent black when the coordinates fall outside the image bounds.
 */
const pixelRGBA = (
    imageData: Uint8ClampedArray,
    width: number,
    height: number,
    x: number,
    y: number,
) => {
    const inBounds = x >= 0 && x < width && y >= 0 && y < height;
    if (!inBounds) return { r: 0, g: 0, b: 0, a: 0 };
    // 4 bytes (RGBA) per pixel, row-major layout.
    const i = 4 * (y * width + x);
    return {
        r: imageData[i],
        g: imageData[i + 1],
        b: imageData[i + 2],
        a: imageData[i + 3],
    };
};
/**
* Returns the pixel value (RGB) at the given coordinates ({@link fx},
* {@link fy}) using bicubic interpolation.
*/
const pixelRGBBicubic = (
fx: number,
fy: number,
imageData: Uint8ClampedArray,
imageWidth: number,
imageHeight: number,
) => {
// Clamp to image boundaries.
fx = clamp(fx, 0, imageWidth - 1);
fy = clamp(fy, 0, imageHeight - 1);
const x = Math.trunc(fx) - (fx >= 0.0 ? 0 : 1);
const px = x - 1;
const nx = x + 1;
@ -171,97 +199,28 @@ export function pixelRGBBicubic(
// const c3 = cubic(dy, ip3, ic3, in3, ia3);
return { r: c0, g: c1, b: c2 };
}
/**
 * Return the RGBA components of the pixel at ({@link x}, {@link y}), or
 * transparent black when the coordinates lie outside the image.
 */
const pixelRGBA = (
    imageData: Uint8ClampedArray,
    width: number,
    height: number,
    x: number,
    y: number,
) => {
    // Out-of-bounds reads yield a fully transparent black pixel.
    if (x < 0 || x >= width || y < 0 || y >= height) {
        return { r: 0, g: 0, b: 0, a: 0 };
    }
    // 4 channels (RGBA) per pixel, row-major layout.
    const index = (y * width + x) * 4;
    return {
        r: imageData[index],
        g: imageData[index + 1],
        b: imageData[index + 2],
        a: imageData[index + 3],
    };
};
/**
* Returns the pixel value (RGB) at the given coordinates ({@link fx},
* {@link fy}) using bilinear interpolation.
* Transform {@link inputData} starting at {@link inputStartIndex}.
*/
/**
 * Returns the pixel value (RGB) at fractional coordinates ({@link fx},
 * {@link fy}) using bilinear interpolation over the 4 surrounding pixels.
 */
export function pixelRGBBilinear(
    fx: number,
    fy: number,
    imageData: Uint8ClampedArray,
    imageWidth: number,
    imageHeight: number,
) {
    // Clamp to image boundaries.
    fx = clamp(fx, 0, imageWidth - 1);
    fy = clamp(fy, 0, imageHeight - 1);

    // Get the surrounding coordinates and their weights.
    const x0 = Math.floor(fx);
    const x1 = Math.ceil(fx);
    const y0 = Math.floor(fy);
    const y1 = Math.ceil(fy);
    // Fractional offsets within the cell; each weight pair sums to 1.
    const dx = fx - x0;
    const dy = fy - y0;
    const dx1 = 1.0 - dx;
    const dy1 = 1.0 - dy;

    // Get the original pixels (pixelRGBA returns transparent black for
    // out-of-bounds coordinates).
    const pixel1 = pixelRGBA(imageData, imageWidth, imageHeight, x0, y0);
    const pixel2 = pixelRGBA(imageData, imageWidth, imageHeight, x1, y0);
    const pixel3 = pixelRGBA(imageData, imageWidth, imageHeight, x0, y1);
    const pixel4 = pixelRGBA(imageData, imageWidth, imageHeight, x1, y1);

    // Weighted average of the 4 neighbors, rounded to an integer component.
    const bilinear = (val1: number, val2: number, val3: number, val4: number) =>
        Math.round(
            val1 * dx1 * dy1 +
                val2 * dx * dy1 +
                val3 * dx1 * dy +
                val4 * dx * dy,
        );

    // Return interpolated pixel colors (alpha is intentionally dropped).
    return {
        r: bilinear(pixel1.r, pixel2.r, pixel3.r, pixel4.r),
        g: bilinear(pixel1.g, pixel2.g, pixel3.g, pixel4.g),
        b: bilinear(pixel1.b, pixel2.b, pixel3.b, pixel4.b),
    };
}
export function warpAffineFloat32List(
export const warpAffineFloat32List = (
imageBitmap: ImageBitmap,
faceAlignment: FaceAlignment,
faceAlignmentAffineMatrix: number[][],
faceSize: number,
inputData: Float32Array,
inputStartIndex: number,
): void {
): void => {
const { width, height } = imageBitmap;
// Get the pixel data.
const offscreenCanvas = new OffscreenCanvas(
imageBitmap.width,
imageBitmap.height,
);
const offscreenCanvas = new OffscreenCanvas(width, height);
const ctx = offscreenCanvas.getContext("2d");
ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height);
const imageData = ctx.getImageData(
0,
0,
imageBitmap.width,
imageBitmap.height,
);
ctx.drawImage(imageBitmap, 0, 0, width, height);
const imageData = ctx.getImageData(0, 0, width, height);
const pixelData = imageData.data;
const transformationMatrix = faceAlignment.affineMatrix.map((row) =>
const transformationMatrix = faceAlignmentAffineMatrix.map((row) =>
row.map((val) => (val != 1.0 ? val * faceSize : 1.0)),
); // 3x3
@ -280,7 +239,7 @@ export function warpAffineFloat32List(
for (let yTrans = 0; yTrans < faceSize; ++yTrans) {
for (let xTrans = 0; xTrans < faceSize; ++xTrans) {
// Perform inverse affine transformation
// Perform inverse affine transformation.
const xOrigin =
a00Prime * (xTrans - b00) + a01Prime * (yTrans - b10);
const yOrigin =
@ -291,34 +250,32 @@ export function warpAffineFloat32List(
xOrigin,
yOrigin,
pixelData,
imageBitmap.width,
imageBitmap.height,
width,
height,
);
// Set the pixel in the input data
// Set the pixel in the input data.
const index = (yTrans * faceSize + xTrans) * 3;
inputData[inputStartIndex + index] =
normalizePixelBetweenMinus1And1(r);
inputData[inputStartIndex + index + 1] =
normalizePixelBetweenMinus1And1(g);
inputData[inputStartIndex + index + 2] =
normalizePixelBetweenMinus1And1(b);
inputData[inputStartIndex + index] = rgbToBipolarFloat(r);
inputData[inputStartIndex + index + 1] = rgbToBipolarFloat(g);
inputData[inputStartIndex + index + 2] = rgbToBipolarFloat(b);
}
}
}
};
const normalizePixelBetweenMinus1And1 = (pixelValue: number) =>
pixelValue / 127.5 - 1.0;
/** Map an RGB component value in 0-255 to a float in the range [-1, 1]. */
const rgbToBipolarFloat = (pixelValue: number): number => {
    return pixelValue / 127.5 - 1.0;
};
const unnormalizePixelFromBetweenMinus1And1 = (pixelValue: number) =>
/** Map a float in the range [-1, 1] back to an RGB component value 0-255. */
const bipolarFloatToRGB = (pixelValue: number): number => {
    const rescaled = Math.round((pixelValue + 1.0) * 127.5);
    return clamp(rescaled, 0, 255);
};
export function createGrayscaleIntMatrixFromNormalized2List(
export const grayscaleIntMatrixFromNormalized2List = (
imageList: Float32Array,
faceNumber: number,
width: number,
height: number,
): number[][] {
): number[][] => {
const startIndex = faceNumber * width * height * 3;
return Array.from({ length: height }, (_, y) =>
Array.from({ length: width }, (_, x) => {
@ -326,22 +283,13 @@ export function createGrayscaleIntMatrixFromNormalized2List(
const pixelIndex = startIndex + 3 * (y * width + x);
return clamp(
Math.round(
0.299 *
unnormalizePixelFromBetweenMinus1And1(
imageList[pixelIndex],
) +
0.587 *
unnormalizePixelFromBetweenMinus1And1(
imageList[pixelIndex + 1],
) +
0.114 *
unnormalizePixelFromBetweenMinus1And1(
imageList[pixelIndex + 2],
),
0.299 * bipolarFloatToRGB(imageList[pixelIndex]) +
0.587 * bipolarFloatToRGB(imageList[pixelIndex + 1]) +
0.114 * bipolarFloatToRGB(imageList[pixelIndex + 2]),
),
0,
255,
);
}),
);
}
};

View file

@ -1,17 +1,19 @@
import log from "@/next/log";
import mlIDbStorage from "services/face/db";
import { Person } from "services/face/types";
import { clusterFaces } from "./cluster";
import { saveFaceCrop } from "./f-index";
import { fetchImageBitmap, getLocalFile } from "./image";
/**
 * A group of faces that the clustering step considers as belonging to the
 * same person, along with the files those faces occur in.
 */
export interface Person {
    /** A locally assigned numeric identifier for this person. */
    id: number;
    /** The person's name, if the user has labelled them. */
    name?: string;
    /** The IDs of the files in which this person's face occurs. */
    files: number[];
    /** The ID of the face to show as the cover face for this person. */
    displayFaceId?: string;
}
// TODO-ML(MR): Forced disable clustering. It doesn't currently work,
// need to finalize it before we move out of beta.
//
// > Error: Failed to execute 'transferToImageBitmap' on
// > 'OffscreenCanvas': ImageBitmap construction failed
/*
export const syncPeopleIndex = async () => {
// TODO-ML(MR): Forced disable clustering. It doesn't currently work,
// need to finalize it before we move out of beta.
//
// > Error: Failed to execute 'transferToImageBitmap' on
// > 'OffscreenCanvas': ImageBitmap construction failed
/*
if (
syncContext.outOfSyncFiles.length <= 0 ||
(syncContext.nSyncedFiles === batchSize && Math.random() < 0)
@ -32,16 +34,16 @@ export const syncPeopleIndex = async () => {
if (filesVersion <= (await mlIDbStorage.getIndexVersion("people"))) {
return;
}
*/
// TODO: have faces addressable through fileId + faceId
// to avoid index based addressing, which is prone to wrong results
// one way could be to match nearest face within threshold in the file
/*
const allFacesMap =
syncContext.allSyncedFacesMap ??
(syncContext.allSyncedFacesMap = await mlIDbStorage.getAllFacesMap());
*/
// await this.init();
@ -83,17 +85,18 @@ export const syncPeopleIndex = async () => {
: best,
);
if (personFace && !personFace.crop?.cacheKey) {
const file = await getLocalFile(personFace.fileId);
const imageBitmap = await fetchImageBitmap(file);
await saveFaceCrop(imageBitmap, personFace);
}
const person: Person = {
id: index,
files: faces.map((f) => f.fileId),
displayFaceId: personFace?.id,
faceCropCacheKey: personFace?.crop?.cacheKey,
};
await mlIDbStorage.putPerson(person);
@ -108,3 +111,21 @@ export const syncPeopleIndex = async () => {
// await mlIDbStorage.setIndexVersion("people", filesVersion);
};
public async regenerateFaceCrop(token: string, faceID: string) {
await downloadManager.init(APPS.PHOTOS, { token });
return mlService.regenerateFaceCrop(faceID);
}
export const regenerateFaceCrop = async (faceID: string) => {
const fileID = Number(faceID.split("-")[0]);
const personFace = await mlIDbStorage.getFace(fileID, faceID);
if (!personFace) {
throw Error("Face not found");
}
const file = await getLocalFile(personFace.fileId);
const imageBitmap = await fetchImageBitmap(file);
return await saveFaceCrop(imageBitmap, personFace);
};
*/

View file

@ -0,0 +1,160 @@
import log from "@/next/log";
import ComlinkCryptoWorker from "@ente/shared/crypto";
import { putEmbedding } from "services/embeddingService";
import type { EnteFile } from "types/file";
import type { Point } from "./geom";
import type { Face, FaceDetection, MlFileData } from "./types";
/**
 * Upload the ML data for a file to remote as an encrypted embedding.
 *
 * The local {@link MlFileData} is converted to the wire format
 * ({@link ServerFileMl}), encrypted with the file's key, and PUT to remote
 * under the "file-ml-clip-face" model.
 *
 * @param enteFile The file whose ML data is being uploaded.
 * @param mlFileData The locally indexed ML data for {@link enteFile}.
 */
export const putFaceEmbedding = async (
    enteFile: EnteFile,
    mlFileData: MlFileData,
) => {
    const serverMl = LocalFileMlDataToServerFileMl(mlFileData);
    // LocalFileMlDataToServerFileMl returns null when the local data has a
    // non-zero error count; there is nothing meaningful to upload then.
    if (!serverMl) return;

    log.debug(() => ({ t: "Local ML file data", mlFileData }));
    log.debug(() => ({
        t: "Uploaded ML file data",
        d: JSON.stringify(serverMl),
    }));

    const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance();
    const { file: encryptedEmbeddingData } =
        await comlinkCryptoWorker.encryptMetadata(serverMl, enteFile.key);
    log.info(
        `putEmbedding embedding to server for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
    );
    const res = await putEmbedding({
        fileID: enteFile.id,
        encryptedEmbedding: encryptedEmbeddingData.encryptedData,
        decryptionHeader: encryptedEmbeddingData.decryptionHeader,
        model: "file-ml-clip-face",
    });
    log.info("putEmbedding response: ", res);
};
/**
 * A {@link ServerFileMl} as present on remote, along with the time it was
 * last updated there.
 */
export interface FileML extends ServerFileMl {
    // Epoch timestamp of the last remote update. NOTE(review): the units
    // (ms vs µs) are not visible here — confirm against embeddingService.
    updatedAt: number;
}
/**
 * The remote wire-format representation of the ML data for a single file.
 *
 * Instances get JSON serialized, encrypted, and uploaded as the
 * "file-ml-clip-face" embedding (see {@link putFaceEmbedding}); keep field
 * names in sync with what remote expects.
 */
class ServerFileMl {
    /** The ID of the file this ML data belongs to. */
    public fileID: number;
    /** The height of the image, if available. */
    public height?: number;
    /** The width of the image, if available. */
    public width?: number;
    /** The detected faces (and their embeddings) in this file. */
    public faceEmbedding: ServerFaceEmbeddings;

    public constructor(
        fileID: number,
        faceEmbedding: ServerFaceEmbeddings,
        height?: number,
        width?: number,
    ) {
        this.fileID = fileID;
        this.height = height;
        this.width = width;
        this.faceEmbedding = faceEmbedding;
    }
}
/**
 * The set of faces detected in a file, alongside the version of the indexing
 * pipeline that produced them.
 */
class ServerFaceEmbeddings {
    /** The faces detected in the file. */
    public faces: ServerFace[];
    /** Pipeline version; the caller in this file currently always passes 1. */
    public version: number;
    /* TODO
    public client?: string;
    public error?: boolean;
    */

    public constructor(faces: ServerFace[], version: number) {
        this.faces = faces;
        this.version = version;
    }
}
/** A single detected face, its embedding, and its quality metrics. */
class ServerFace {
    // NOTE(review): face IDs appear to be of the form "<fileID>-…" (they are
    // parsed that way elsewhere) — confirm.
    public faceID: string;
    // TODO-ML: singular?
    /** The face embedding vector, as a plain number array. */
    public embeddings: number[];
    /** The bounding box and landmarks of the detection. */
    public detection: ServerDetection;
    /** The detection probability of this face. */
    public score: number;
    /** The blur value computed for this face during indexing. */
    public blur: number;

    public constructor(
        faceID: string,
        embeddings: number[],
        detection: ServerDetection,
        score: number,
        blur: number,
    ) {
        this.faceID = faceID;
        this.embeddings = embeddings;
        this.detection = detection;
        this.score = score;
        this.blur = blur;
    }
}
/** The bounding box and landmark points of a face detection. */
class ServerDetection {
    /** The bounding box of the detected face. */
    public box: ServerFaceBox;
    // NOTE(review): the coordinate space of these points (absolute px vs
    // normalized) is not visible here — confirm.
    public landmarks: Point[];

    public constructor(box: ServerFaceBox, landmarks: Point[]) {
        this.box = box;
        this.landmarks = landmarks;
    }
}
/**
 * A face bounding box in the remote wire format.
 *
 * Constructed from the local detection box's x/y/width/height; the coordinate
 * space (absolute px vs normalized) is whatever the local box uses —
 * NOTE(review): confirm.
 */
class ServerFaceBox {
    /** The left edge of the box. */
    public xMin: number;
    /** The top edge of the box. */
    public yMin: number;
    /** The width of the box. */
    public width: number;
    /** The height of the box. */
    public height: number;

    public constructor(
        xMin: number,
        yMin: number,
        width: number,
        height: number,
    ) {
        this.xMin = xMin;
        this.yMin = yMin;
        this.width = width;
        this.height = height;
    }
}
function LocalFileMlDataToServerFileMl(
localFileMlData: MlFileData,
): ServerFileMl {
if (localFileMlData.errorCount > 0) {
return null;
}
const imageDimensions = localFileMlData.imageDimensions;
const faces: ServerFace[] = [];
for (let i = 0; i < localFileMlData.faces.length; i++) {
const face: Face = localFileMlData.faces[i];
const faceID = face.id;
const embedding = face.embedding;
const score = face.detection.probability;
const blur = face.blurValue;
const detection: FaceDetection = face.detection;
const box = detection.box;
const landmarks = detection.landmarks;
const newBox = new ServerFaceBox(box.x, box.y, box.width, box.height);
// TODO-ML: Add client UA and version
const newFaceObject = new ServerFace(
faceID,
Array.from(embedding),
new ServerDetection(newBox, landmarks),
score,
blur,
);
faces.push(newFaceObject);
}
const faceEmbeddings = new ServerFaceEmbeddings(faces, 1);
return new ServerFileMl(
localFileMlData.fileId,
faceEmbeddings,
imageDimensions.height,
imageDimensions.width,
);
}

View file

@ -1,6 +1,6 @@
import { Box, Point, boxFromBoundingBox } from "services/face/geom";
import { FaceDetection } from "services/face/types";
// TODO-ML(MR): Do we need two separate Matrix libraries?
import { Box, Point } from "services/face/geom";
import type { FaceDetection } from "services/face/types";
// TODO-ML: Do we need two separate Matrix libraries?
//
// Keeping this in a separate file so that we can audit this. If these can be
// expressed using ml-matrix, then we can move the code to f-index.
@ -22,43 +22,36 @@ export const transformFaceDetections = (
inBox: Box,
toBox: Box,
): FaceDetection[] => {
const transform = computeTransformToBox(inBox, toBox);
return faceDetections.map((f) => {
const box = transformBox(f.box, transform);
const normLandmarks = f.landmarks;
const landmarks = transformPoints(normLandmarks, transform);
return {
box,
landmarks,
probability: f.probability as number,
} as FaceDetection;
});
const transform = boxTransformationMatrix(inBox, toBox);
return faceDetections.map((f) => ({
box: transformBox(f.box, transform),
landmarks: f.landmarks.map((p) => transformPoint(p, transform)),
probability: f.probability,
}));
};
function computeTransformToBox(inBox: Box, toBox: Box): Matrix {
return compose(
const boxTransformationMatrix = (inBox: Box, toBox: Box): Matrix =>
compose(
translate(toBox.x, toBox.y),
scale(toBox.width / inBox.width, toBox.height / inBox.height),
);
}
function transformPoint(point: Point, transform: Matrix) {
const transformPoint = (point: Point, transform: Matrix) => {
const txdPoint = applyToPoint(transform, point);
return new Point(txdPoint.x, txdPoint.y);
}
};
function transformPoints(points: Point[], transform: Matrix) {
return points?.map((p) => transformPoint(p, transform));
}
const transformBox = (box: Box, transform: Matrix) => {
const topLeft = transformPoint(new Point(box.x, box.y), transform);
const bottomRight = transformPoint(
new Point(box.x + box.width, box.y + box.height),
transform,
);
function transformBox(box: Box, transform: Matrix) {
const topLeft = transformPoint(box.topLeft, transform);
const bottomRight = transformPoint(box.bottomRight, transform);
return boxFromBoundingBox({
left: topLeft.x,
top: topLeft.y,
right: bottomRight.x,
bottom: bottomRight.y,
return new Box({
x: topLeft.x,
y: topLeft.y,
width: bottomRight.x - topLeft.x,
height: bottomRight.y - topLeft.y,
});
}
};

View file

@ -1,73 +1,35 @@
import { Box, Dimensions, Point } from "services/face/geom";
export declare type Cluster = Array<number>;
export declare type Landmark = Point;
export interface FaceDetection {
// box and landmarks are relative to the image dimensions stored at mlFileData
box: Box;
landmarks?: Array<Landmark>;
landmarks?: Point[];
probability?: number;
}
export interface DetectedFace {
fileId: number;
detection: FaceDetection;
}
export interface DetectedFaceWithId extends DetectedFace {
id: string;
}
export interface FaceCrop {
image: ImageBitmap;
// imageBox is relative to the image dimensions stored at mlFileData
imageBox: Box;
}
export interface StoredFaceCrop {
cacheKey: string;
imageBox: Box;
}
export interface CroppedFace extends DetectedFaceWithId {
crop?: StoredFaceCrop;
}
export interface FaceAlignment {
// TODO: remove affine matrix as rotation, size and center
// TODO-ML: remove affine matrix as rotation, size and center
// are simple to store and use, affine matrix adds complexity while getting crop
affineMatrix: Array<Array<number>>;
affineMatrix: number[][];
rotation: number;
// size and center are relative to the image dimensions stored at mlFileData
size: number;
center: Point;
}
export interface AlignedFace extends CroppedFace {
export interface Face {
fileId: number;
detection: FaceDetection;
id: string;
alignment?: FaceAlignment;
blurValue?: number;
}
export declare type FaceEmbedding = Float32Array;
embedding?: Float32Array;
export interface FaceWithEmbedding extends AlignedFace {
embedding?: FaceEmbedding;
}
export interface Face extends FaceWithEmbedding {
personId?: number;
}
export interface Person {
id: number;
name?: string;
files: Array<number>;
displayFaceId?: string;
faceCropCacheKey?: string;
}
export interface MlFileData {
fileId: number;
faces?: Face[];

View file

@ -1,20 +1,15 @@
import { haveWindow } from "@/next/env";
import log from "@/next/log";
import { ComlinkWorker } from "@/next/worker/comlink-worker";
import ComlinkCryptoWorker, {
getDedicatedCryptoWorker,
} from "@ente/shared/crypto";
import { DedicatedCryptoWorker } from "@ente/shared/crypto/internal/crypto.worker";
import { CustomError, parseUploadErrorCodes } from "@ente/shared/error";
import PQueue from "p-queue";
import { putEmbedding } from "services/embeddingService";
import mlIDbStorage, { ML_SEARCH_CONFIG_NAME } from "services/face/db";
import { fetchImageBitmap, getLocalFile } from "services/face/image";
import { Face, FaceDetection, Landmark, MlFileData } from "services/face/types";
import mlIDbStorage, {
ML_SEARCH_CONFIG_NAME,
type MinimalPersistedFileData,
} from "services/face/db";
import { putFaceEmbedding } from "services/face/remote";
import { getLocalFiles } from "services/fileService";
import { EnteFile } from "types/file";
import { isInternalUserForML } from "utils/user";
import { indexFaces, saveFaceCrop } from "../face/f-index";
import { indexFaces } from "../face/f-index";
/**
* TODO-ML(MR): What and why.
@ -57,63 +52,30 @@ class MLSyncContext {
public localFilesMap: Map<number, EnteFile>;
public outOfSyncFiles: EnteFile[];
public nSyncedFiles: number;
public error?: Error;
public syncQueue: PQueue;
// TODO: wheather to limit concurrent downloads
// private downloadQueue: PQueue;
private concurrency: number;
private comlinkCryptoWorker: Array<
ComlinkWorker<typeof DedicatedCryptoWorker>
>;
private enteWorkers: Array<any>;
constructor(token: string, userID: number, concurrency?: number) {
constructor(token: string, userID: number) {
this.token = token;
this.userID = userID;
this.outOfSyncFiles = [];
this.nSyncedFiles = 0;
this.concurrency = concurrency ?? getConcurrency();
log.info("Using concurrency: ", this.concurrency);
// timeout is added on downloads
// timeout on queue will keep the operation open till worker is terminated
this.syncQueue = new PQueue({ concurrency: this.concurrency });
logQueueStats(this.syncQueue, "sync");
// this.downloadQueue = new PQueue({ concurrency: 1 });
// logQueueStats(this.downloadQueue, 'download');
this.comlinkCryptoWorker = new Array(this.concurrency);
this.enteWorkers = new Array(this.concurrency);
}
public async getEnteWorker(id: number): Promise<any> {
const wid = id % this.enteWorkers.length;
console.log("getEnteWorker: ", id, wid);
if (!this.enteWorkers[wid]) {
this.comlinkCryptoWorker[wid] = getDedicatedCryptoWorker();
this.enteWorkers[wid] = await this.comlinkCryptoWorker[wid].remote;
}
return this.enteWorkers[wid];
const concurrency = getConcurrency();
this.syncQueue = new PQueue({ concurrency });
}
public async dispose() {
this.localFilesMap = undefined;
await this.syncQueue.onIdle();
this.syncQueue.removeAllListeners();
for (const enteComlinkWorker of this.comlinkCryptoWorker) {
enteComlinkWorker?.terminate();
}
}
}
export const getConcurrency = () =>
haveWindow() && Math.max(2, Math.ceil(navigator.hardwareConcurrency / 2));
const getConcurrency = () =>
Math.max(2, Math.ceil(navigator.hardwareConcurrency / 2));
class MachineLearningService {
private localSyncContext: Promise<MLSyncContext>;
@ -139,16 +101,12 @@ class MachineLearningService {
return !error && nOutOfSyncFiles > 0;
}
public async regenerateFaceCrop(faceID: string) {
return regenerateFaceCrop(faceID);
}
private newMlData(fileId: number) {
return {
fileId,
mlVersion: 0,
errorCount: 0,
} as MlFileData;
} as MinimalPersistedFileData;
}
private async getLocalFilesMap(syncContext: MLSyncContext) {
@ -327,9 +285,6 @@ class MachineLearningService {
localFile?: globalThis.File,
) {
try {
console.log(
`Indexing ${enteFile.title ?? "<untitled>"} ${enteFile.id}`,
);
const mlFileData = await this.syncFile(enteFile, localFile);
syncContext.nSyncedFiles += 1;
return mlFileData;
@ -363,44 +318,17 @@ class MachineLearningService {
}
private async syncFile(enteFile: EnteFile, localFile?: globalThis.File) {
const oldMlFile = await this.getMLFileData(enteFile.id);
const oldMlFile = await mlIDbStorage.getFile(enteFile.id);
if (oldMlFile && oldMlFile.mlVersion) {
return oldMlFile;
}
const newMlFile = await indexFaces(enteFile, localFile);
await this.persistOnServer(newMlFile, enteFile);
await putFaceEmbedding(enteFile, newMlFile);
await mlIDbStorage.putFile(newMlFile);
return newMlFile;
}
private async persistOnServer(mlFileData: MlFileData, enteFile: EnteFile) {
const serverMl = LocalFileMlDataToServerFileMl(mlFileData);
log.debug(() => ({ t: "Local ML file data", mlFileData }));
log.debug(() => ({
t: "Uploaded ML file data",
d: JSON.stringify(serverMl),
}));
const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance();
const { file: encryptedEmbeddingData } =
await comlinkCryptoWorker.encryptMetadata(serverMl, enteFile.key);
log.info(
`putEmbedding embedding to server for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
);
const res = await putEmbedding({
fileID: enteFile.id,
encryptedEmbedding: encryptedEmbeddingData.encryptedData,
decryptionHeader: encryptedEmbeddingData.decryptionHeader,
model: "file-ml-clip-face",
});
log.info("putEmbedding response: ", res);
}
private async getMLFileData(fileId: number) {
return mlIDbStorage.getFile(fileId);
}
private async persistMLFileSyncError(enteFile: EnteFile, e: Error) {
try {
await mlIDbStorage.upsertFileInTx(enteFile.id, (mlFileData) => {
@ -420,161 +348,3 @@ class MachineLearningService {
}
export default new MachineLearningService();
export interface FileML extends ServerFileMl {
updatedAt: number;
}
class ServerFileMl {
public fileID: number;
public height?: number;
public width?: number;
public faceEmbedding: ServerFaceEmbeddings;
public constructor(
fileID: number,
faceEmbedding: ServerFaceEmbeddings,
height?: number,
width?: number,
) {
this.fileID = fileID;
this.height = height;
this.width = width;
this.faceEmbedding = faceEmbedding;
}
}
class ServerFaceEmbeddings {
public faces: ServerFace[];
public version: number;
/* TODO
public client?: string;
public error?: boolean;
*/
public constructor(faces: ServerFace[], version: number) {
this.faces = faces;
this.version = version;
}
}
class ServerFace {
public faceID: string;
public embeddings: number[];
public detection: ServerDetection;
public score: number;
public blur: number;
public constructor(
faceID: string,
embeddings: number[],
detection: ServerDetection,
score: number,
blur: number,
) {
this.faceID = faceID;
this.embeddings = embeddings;
this.detection = detection;
this.score = score;
this.blur = blur;
}
}
class ServerDetection {
public box: ServerFaceBox;
public landmarks: Landmark[];
public constructor(box: ServerFaceBox, landmarks: Landmark[]) {
this.box = box;
this.landmarks = landmarks;
}
}
class ServerFaceBox {
public xMin: number;
public yMin: number;
public width: number;
public height: number;
public constructor(
xMin: number,
yMin: number,
width: number,
height: number,
) {
this.xMin = xMin;
this.yMin = yMin;
this.width = width;
this.height = height;
}
}
function LocalFileMlDataToServerFileMl(
localFileMlData: MlFileData,
): ServerFileMl {
if (localFileMlData.errorCount > 0) {
return null;
}
const imageDimensions = localFileMlData.imageDimensions;
const faces: ServerFace[] = [];
for (let i = 0; i < localFileMlData.faces.length; i++) {
const face: Face = localFileMlData.faces[i];
const faceID = face.id;
const embedding = face.embedding;
const score = face.detection.probability;
const blur = face.blurValue;
const detection: FaceDetection = face.detection;
const box = detection.box;
const landmarks = detection.landmarks;
const newBox = new ServerFaceBox(box.x, box.y, box.width, box.height);
const newLandmarks: Landmark[] = [];
for (let j = 0; j < landmarks.length; j++) {
newLandmarks.push({
x: landmarks[j].x,
y: landmarks[j].y,
} as Landmark);
}
// TODO: Add client UA and version
const newFaceObject = new ServerFace(
faceID,
Array.from(embedding),
new ServerDetection(newBox, newLandmarks),
score,
blur,
);
faces.push(newFaceObject);
}
const faceEmbeddings = new ServerFaceEmbeddings(faces, 1);
return new ServerFileMl(
localFileMlData.fileId,
faceEmbeddings,
imageDimensions.height,
imageDimensions.width,
);
}
export function logQueueStats(queue: PQueue, name: string) {
queue.on("active", () =>
log.info(
`queuestats: ${name}: Active, Size: ${queue.size} Pending: ${queue.pending}`,
),
);
queue.on("idle", () => log.info(`queuestats: ${name}: Idle`));
queue.on("error", (error) =>
console.error(`queuestats: ${name}: Error, `, error),
);
}
export const regenerateFaceCrop = async (faceID: string) => {
const fileID = Number(faceID.split("-")[0]);
const personFace = await mlIDbStorage.getFace(fileID, faceID);
if (!personFace) {
throw Error("Face not found");
}
const file = await getLocalFile(personFace.fileId);
const imageBitmap = await fetchImageBitmap(file);
return await saveFaceCrop(imageBitmap, personFace);
};

View file

@ -9,7 +9,6 @@ import { createFaceComlinkWorker } from "services/face";
import mlIDbStorage from "services/face/db";
import type { DedicatedMLWorker } from "services/face/face.worker";
import { EnteFile } from "types/file";
import { logQueueStats } from "./machineLearningService";
export type JobState = "Scheduled" | "Running" | "NotScheduled";
@ -309,3 +308,15 @@ class MLWorkManager {
}
export default new MLWorkManager();
/**
 * Attach listeners that log the activity of the given {@link queue}.
 *
 * @param queue The {@link PQueue} whose lifecycle events should be logged.
 * @param name A human readable label identifying the queue in the log lines.
 */
export function logQueueStats(queue: PQueue, name: string) {
    queue.on("active", () => {
        log.info(
            `queuestats: ${name}: Active, Size: ${queue.size} Pending: ${queue.pending}`,
        );
    });
    queue.on("idle", () => {
        log.info(`queuestats: ${name}: Idle`);
    });
    queue.on("error", (error) => {
        console.error(`queuestats: ${name}: Error, `, error);
    });
}

View file

@ -3,7 +3,7 @@ import log from "@/next/log";
import * as chrono from "chrono-node";
import { t } from "i18next";
import mlIDbStorage from "services/face/db";
import { Person } from "services/face/types";
import type { Person } from "services/face/people";
import { defaultMLVersion } from "services/machineLearning/machineLearningService";
import { Collection } from "types/collection";
import { EntityType, LocationTag, LocationTagData } from "types/entity";

View file

@ -1,6 +1,6 @@
import { FILE_TYPE } from "@/media/file-type";
import { IndexStatus } from "services/face/db";
import { Person } from "services/face/types";
import type { Person } from "services/face/people";
import { City } from "services/locationSearchService";
import { LocationTagData } from "types/entity";
import { EnteFile } from "types/file";

View file

@ -1,4 +1,4 @@
import { clearCaches } from "@/next/blob-cache";
import { clearBlobCaches } from "@/next/blob-cache";
import log from "@/next/log";
import InMemoryStore from "@ente/shared/storage/InMemoryStore";
import localForage from "@ente/shared/storage/localForage";
@ -43,7 +43,7 @@ export const accountLogout = async () => {
log.error("Ignoring error during logout (local forage)", e);
}
try {
await clearCaches();
await clearBlobCaches();
} catch (e) {
log.error("Ignoring error during logout (cache)", e);
}

View file

@ -20,8 +20,8 @@ export type BlobCacheNamespace = (typeof blobCacheNames)[number];
*
* This cache is suitable for storing large amounts of data (entire files).
*
* To obtain a cache for a given namespace, use {@link openCache}. To clear all
* cached data (e.g. during logout), use {@link clearCaches}.
* To obtain a cache for a given namespace, use {@link openBlobCache}. To clear all
* cached data (e.g. during logout), use {@link clearBlobCaches}.
*
* [Note: Caching files]
*
@ -69,14 +69,31 @@ export interface BlobCache {
delete: (key: string) => Promise<boolean>;
}
const cachedCaches = new Map<BlobCacheNamespace, BlobCache>();
/**
 * Lazily obtain and memoize the {@link BlobCache} for the given {@link name}.
 *
 * The first call for a given namespace opens the underlying cache via
 * {@link openBlobCache}; subsequent calls return the memoized instance.
 */
export const blobCache = async (
    name: BlobCacheNamespace,
): Promise<BlobCache> => {
    const cached = cachedCaches.get(name);
    if (cached) return cached;
    // NOTE(review): concurrent first calls may each open the cache and the
    // last one wins — same as before; appears benign for caches.
    const opened = await openBlobCache(name);
    cachedCaches.set(name, opened);
    return opened;
};
/**
* Create a new {@link BlobCache} corresponding to the given {@link name}.
*
* @param name One of the arbitrary but predefined namespaces of type
* {@link BlobCacheNamespace} which group related data and allow us to use the
* same key across namespaces.
*/
export const openCache = async (
export const openBlobCache = async (
name: BlobCacheNamespace,
): Promise<BlobCache> =>
isElectron() ? openOPFSCacheWeb(name) : openWebCache(name);
@ -194,7 +211,7 @@ export const cachedOrNew = async (
key: string,
get: () => Promise<Blob>,
): Promise<Blob> => {
const cache = await openCache(cacheName);
const cache = await openBlobCache(cacheName);
const cachedBlob = await cache.get(key);
if (cachedBlob) return cachedBlob;
@ -204,15 +221,17 @@ export const cachedOrNew = async (
};
/**
* Delete all cached data.
* Delete all cached data, including cached caches.
*
* Meant for use during logout, to reset the state of the user's account.
*/
export const clearCaches = async () =>
isElectron() ? clearOPFSCaches() : clearWebCaches();
export const clearBlobCaches = async () => {
    // Drop the memoized instances so that subsequent opens start afresh.
    cachedCaches.clear();
    if (isElectron()) return clearOPFSCaches();
    return clearWebCaches();
};
/**
 * Delete the web Cache Storage entries for all blob cache namespaces.
 *
 * Uses allSettled (not all) so that one failing deletion does not abort the
 * remaining ones; this is a best-effort cleanup during logout.
 */
const clearWebCaches = async () => {
    await Promise.allSettled(blobCacheNames.map((name) => caches.delete(name)));
};
const clearOPFSCaches = async () => {