This commit is contained in:
Manav Rathi 2024-04-11 11:37:16 +05:30
parent 03df858dcc
commit 006ea085fe
No known key found for this signature in database
4 changed files with 1 additions and 251 deletions

View file

@@ -1,108 +0,0 @@
import log from "@/next/log";
import * as tfjsConverter from "@tensorflow/tfjs-converter";
import * as tf from "@tensorflow/tfjs-core";
import { SCENE_DETECTION_IMAGE_SIZE } from "constants/mlConfig";
import {
ObjectDetection,
SceneDetectionMethod,
SceneDetectionService,
Versioned,
} from "types/machineLearning";
import { resizeToSquare } from "utils/image";
class ImageScene implements SceneDetectionService {
    method: Versioned<SceneDetectionMethod>;
    // Lazily loaded TF.js graph model; populated on first use by init().
    private model: tfjsConverter.GraphModel;
    // Maps stringified class indices ("0", "1", …) to scene labels, fetched
    // alongside the model.
    private sceneMap: { [key: string]: string };
    // In-flight (or settled) initialization promise shared by concurrent
    // callers; cleared on failure so a later call can retry.
    private ready: Promise<void>;
    // Random tag used to correlate log lines from this worker instance.
    private workerID: number;

    public constructor() {
        this.method = {
            value: "ImageScene",
            version: 1,
        };
        this.workerID = Math.round(Math.random() * 1000);
    }

    /**
     * Fetch the scene label map and load the TF.js graph model, then warm
     * the model up with a dummy prediction so the first real inference
     * doesn't pay the one-time kernel compilation cost.
     */
    private async init() {
        log.info(`[${this.workerID}]`, "ImageScene init called");
        if (this.model) {
            return;
        }

        this.sceneMap = await (
            await fetch("/models/imagescene/sceneMap.json")
        ).json();

        this.model = await tfjsConverter.loadGraphModel(
            "/models/imagescene/model.json",
        );
        log.info(
            `[${this.workerID}]`,
            "loaded ImageScene model",
            tf.getBackend(),
        );

        tf.tidy(() => {
            const zeroTensor = tf.zeros([1, 224, 224, 3]);
            // warmup the model
            this.model.predict(zeroTensor) as tf.Tensor;
        });
    }

    /**
     * Return the loaded model, initializing it on first call. Concurrent
     * callers share a single in-flight {@link init}. If initialization
     * fails, {@link ready} is reset so that a subsequent call retries
     * instead of rejecting forever on (say) a transient network error.
     */
    private async getImageSceneModel() {
        log.info(`[${this.workerID}]`, "ImageScene getImageSceneModel called");
        if (!this.ready) {
            this.ready = this.init().catch((e) => {
                // Clear the settled-rejected promise so the next caller
                // gets a fresh initialization attempt.
                this.ready = undefined;
                throw e;
            });
        }
        await this.ready;
        return this.model;
    }

    /**
     * Run scene classification on {@link image} and return the scenes whose
     * confidence is at least {@link minScore}.
     *
     * Each result is reported as an {@link ObjectDetection} whose bounding
     * box covers the entire (original, un-resized) image.
     */
    async detectScenes(image: ImageBitmap, minScore: number) {
        const resized = resizeToSquare(image, SCENE_DETECTION_IMAGE_SIZE);
        const model = await this.getImageSceneModel();

        const output = tf.tidy(() => {
            const tfImage = tf.browser.fromPixels(resized.image);
            const input = tf.expandDims(tf.cast(tfImage, "float32"));
            const output = model.predict(input) as tf.Tensor;
            return output;
        });

        const data = (await output.data()) as Float32Array;
        // The output tensor is created outside tidy(), so dispose manually.
        output.dispose();

        const scenes = this.parseSceneDetectionResult(
            data,
            minScore,
            image.width,
            image.height,
        );

        return scenes;
    }

    /**
     * Convert the raw per-class score vector into a list of detections,
     * keeping only classes scoring at least {@link minScore}. The bbox for
     * every scene is the full image ([0, 0, width, height]).
     */
    private parseSceneDetectionResult(
        outputData: Float32Array,
        minScore: number,
        width: number,
        height: number,
    ): ObjectDetection[] {
        const scenes: ObjectDetection[] = [];
        for (let i = 0; i < outputData.length; i++) {
            if (outputData[i] >= minScore) {
                scenes.push({
                    // sceneMap is keyed by the stringified class index.
                    class: this.sceneMap[i.toString()],
                    score: outputData[i],
                    bbox: [0, 0, width, height],
                });
            }
        }
        return scenes;
    }
}

export default new ImageScene();

View file

@@ -22,15 +22,12 @@ import {
MLLibraryData,
MLSyncConfig,
MLSyncContext,
SceneDetectionMethod,
SceneDetectionService,
} from "types/machineLearning";
import { logQueueStats } from "utils/machineLearning";
import arcfaceAlignmentService from "./arcfaceAlignmentService";
import arcfaceCropService from "./arcfaceCropService";
import dbscanClusteringService from "./dbscanClusteringService";
import hdbscanClusteringService from "./hdbscanClusteringService";
import imageSceneService from "./imageSceneService";
import laplacianBlurDetectionService from "./laplacianBlurDetectionService";
import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService";
import yoloFaceDetectionService from "./yoloFaceDetectionService";
@@ -46,16 +43,6 @@ export class MLFactory {
throw Error("Unknon face detection method: " + method);
}
public static getSceneDetectionService(
method: SceneDetectionMethod,
): SceneDetectionService {
if (method === "ImageScene") {
return imageSceneService;
}
throw Error("Unknown scene detection method: " + method);
}
public static getFaceCropService(method: FaceCropMethod) {
if (method === "ArcFace") {
return arcfaceCropService;

View file

@@ -2,9 +2,6 @@ import log from "@/next/log";
import { APPS } from "@ente/shared/apps/constants";
import ComlinkCryptoWorker from "@ente/shared/crypto";
import { CustomError, parseUploadErrorCodes } from "@ente/shared/error";
import "@tensorflow/tfjs-backend-cpu";
import "@tensorflow/tfjs-backend-webgl";
import * as tf from "@tensorflow/tfjs-core";
import { MAX_ML_SYNC_ERROR_COUNT } from "constants/mlConfig";
import downloadManager from "services/download";
import { putEmbedding } from "services/embeddingService";
@@ -21,7 +18,6 @@ import { LocalFileMlDataToServerFileMl } from "utils/machineLearning/mldataMappe
import mlIDbStorage from "utils/storage/mlIDbStorage";
import FaceService from "./faceService";
import { MLFactory } from "./machineLearningFactory";
import ObjectService from "./objectService";
import PeopleService from "./peopleService";
import ReaderService from "./readerService";
@@ -58,12 +54,6 @@ class MachineLearningService {
await this.syncIndex(syncContext);
}
// tf.engine().endScope();
// if (syncContext.config.tsne) {
// await this.runTSNE(syncContext);
// }
const mlSyncResult: MLSyncResult = {
nOutOfSyncFiles: syncContext.outOfSyncFiles.length,
nSyncedFiles: syncContext.nSyncedFiles,
@@ -78,9 +68,6 @@
};
// log.info('[MLService] sync results: ', mlSyncResult);
// await syncContext.dispose();
log.info("Final TF Memory stats: ", JSON.stringify(tf.memory()));
return mlSyncResult;
}
@@ -183,50 +170,6 @@ class MachineLearningService {
log.info("getOutOfSyncFiles", Date.now() - startTime, "ms");
}
// TODO: optimize, use indexdb indexes, move facecrops to cache to reduce io
// remove, already done
private async getUniqueOutOfSyncFilesNoIdx(
syncContext: MLSyncContext,
files: EnteFile[],
) {
const limit = syncContext.config.batchSize;
const mlVersion = syncContext.config.mlVersion;
const uniqueFiles: Map<number, EnteFile> = new Map<number, EnteFile>();
for (let i = 0; uniqueFiles.size < limit && i < files.length; i++) {
const mlFileData = await this.getMLFileData(files[i].id);
const mlFileVersion = mlFileData?.mlVersion || 0;
if (
!uniqueFiles.has(files[i].id) &&
(!mlFileData?.errorCount || mlFileData.errorCount < 2) &&
(mlFileVersion < mlVersion ||
syncContext.config.imageSource !== mlFileData.imageSource)
) {
uniqueFiles.set(files[i].id, files[i]);
}
}
return [...uniqueFiles.values()];
}
private async getOutOfSyncFilesNoIdx(syncContext: MLSyncContext) {
const existingFilesMap = await this.getLocalFilesMap(syncContext);
// existingFiles.sort(
// (a, b) => b.metadata.creationTime - a.metadata.creationTime
// );
console.time("getUniqueOutOfSyncFiles");
syncContext.outOfSyncFiles = await this.getUniqueOutOfSyncFilesNoIdx(
syncContext,
[...existingFilesMap.values()],
);
log.info("getUniqueOutOfSyncFiles");
log.info(
"Got unique outOfSyncFiles: ",
syncContext.outOfSyncFiles.length,
"for batchSize: ",
syncContext.config.batchSize,
);
}
private async syncFiles(syncContext: MLSyncContext) {
try {
const functions = syncContext.outOfSyncFiles.map(
@@ -295,7 +238,6 @@
userID: number,
enteFile: EnteFile,
localFile?: globalThis.File,
textDetectionTimeoutIndex?: number,
): Promise<MlFileData | Error> {
const syncContext = await this.getLocalSyncContext(token, userID);
@@ -304,7 +246,6 @@
syncContext,
enteFile,
localFile,
textDetectionTimeoutIndex,
);
if (syncContext.nSyncedFiles >= syncContext.config.batchSize) {
@@ -322,19 +263,15 @@
syncContext: MLSyncContext,
enteFile: EnteFile,
localFile?: globalThis.File,
textDetectionTimeoutIndex?: number,
): Promise<MlFileData> {
try {
console.log(
"Start index for ",
enteFile.title ?? "no title",
enteFile.id,
`Indexing ${enteFile.title ?? "<untitled>"} ${enteFile.id}`,
);
const mlFileData = await this.syncFile(
syncContext,
enteFile,
localFile,
textDetectionTimeoutIndex,
);
syncContext.nSyncedFaces += mlFileData.faces?.length || 0;
syncContext.nSyncedFiles += 1;
@@ -365,14 +302,6 @@
await this.persistMLFileSyncError(syncContext, enteFile, error);
syncContext.nSyncedFiles += 1;
} finally {
console.log(
"done index for ",
enteFile.title ?? "no title",
enteFile.id,
);
// addLogLine('TF Memory stats: ', JSON.stringify(tf.memory()));
log.info("TF Memory stats: ", JSON.stringify(tf.memory()));
}
}
@@ -380,8 +309,6 @@
syncContext: MLSyncContext,
enteFile: EnteFile,
localFile?: globalThis.File,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
textDetectionTimeoutIndex?: number,
) {
console.log("Syncing for file" + enteFile.title);
const fileContext: MLSyncFileContext = { enteFile, localFile };
@@ -406,15 +333,6 @@
await ReaderService.getImageBitmap(syncContext, fileContext);
await Promise.all([
this.syncFileAnalyzeFaces(syncContext, fileContext),
// ObjectService.syncFileObjectDetections(
// syncContext,
// fileContext
// ),
// TextService.syncFileTextDetections(
// syncContext,
// fileContext,
// textDetectionTimeoutIndex
// ),
]);
newMlFile.errorCount = 0;
newMlFile.lastErrorMessage = undefined;
@@ -425,15 +343,7 @@
newMlFile.mlVersion = oldMlFile.mlVersion;
throw e;
} finally {
fileContext.tfImage && fileContext.tfImage.dispose();
fileContext.imageBitmap && fileContext.imageBitmap.close();
// log.info('8 TF Memory stats: ',JSON.stringify(tf.memory()));
// TODO: enable once faceId changes go in
// await removeOldFaceCrops(
// fileContext.oldMlFile,
// fileContext.newMlFile
// );
}
return newMlFile;

View file

@@ -61,10 +61,6 @@ export declare type ImageType = "Original" | "Preview";
export declare type FaceDetectionMethod = "FaceApiSSD" | "YoloFace";
export declare type ObjectDetectionMethod = "SSDMobileNetV2";
export declare type SceneDetectionMethod = "ImageScene";
export declare type FaceCropMethod = "ArcFace";
export declare type FaceAlignmentMethod =
@@ -172,8 +168,6 @@ export interface MlFileData {
faceCropMethod?: Versioned<FaceCropMethod>;
faceAlignmentMethod?: Versioned<FaceAlignmentMethod>;
faceEmbeddingMethod?: Versioned<FaceEmbeddingMethod>;
objectDetectionMethod?: Versioned<ObjectDetectionMethod>;
sceneDetectionMethod?: Versioned<SceneDetectionMethod>;
mlVersion: number;
errorCount: number;
lastErrorMessage?: string;
@@ -183,17 +177,6 @@ export interface FaceDetectionConfig {
method: FaceDetectionMethod;
}
export interface ObjectDetectionConfig {
method: ObjectDetectionMethod;
maxNumBoxes: number;
minScore: number;
}
export interface SceneDetectionConfig {
method: SceneDetectionMethod;
minScore: number;
}
export interface FaceCropConfig {
enabled: boolean;
method: FaceCropMethod;
@@ -265,7 +248,6 @@ export interface MLSyncContext {
faceEmbeddingService: FaceEmbeddingService;
blurDetectionService: BlurDetectionService;
faceClusteringService: ClusteringService;
sceneDetectionService: SceneDetectionService;
localFilesMap: Map<number, EnteFile>;
outOfSyncFiles: EnteFile[];
@@ -292,7 +274,6 @@ export interface MLSyncFileContext {
oldMlFile?: MlFileData;
newMlFile?: MlFileData;
tfImage?: tf.Tensor3D;
imageBitmap?: ImageBitmap;
newDetection?: boolean;
@@ -318,26 +299,6 @@ export interface FaceDetectionService {
dispose(): Promise<void>;
}
export interface ObjectDetectionService {
method: Versioned<ObjectDetectionMethod>;
// init(): Promise<void>;
detectObjects(
image: ImageBitmap,
maxNumBoxes: number,
minScore: number,
): Promise<ObjectDetection[]>;
dispose(): Promise<void>;
}
export interface SceneDetectionService {
method: Versioned<SceneDetectionMethod>;
// init(): Promise<void>;
detectScenes(
image: ImageBitmap,
minScore: number,
): Promise<ObjectDetection[]>;
}
export interface FaceCropService {
method: Versioned<FaceCropMethod>;