Prune

parent 03df858dcc
commit 006ea085fe

4 changed files with 1 addition and 251 deletions
@@ -1,108 +0,0 @@
import log from "@/next/log";
import * as tfjsConverter from "@tensorflow/tfjs-converter";
import * as tf from "@tensorflow/tfjs-core";
import { SCENE_DETECTION_IMAGE_SIZE } from "constants/mlConfig";
import {
    ObjectDetection,
    SceneDetectionMethod,
    SceneDetectionService,
    Versioned,
} from "types/machineLearning";
import { resizeToSquare } from "utils/image";

class ImageScene implements SceneDetectionService {
    method: Versioned<SceneDetectionMethod>;
    private model: tfjsConverter.GraphModel;
    private sceneMap: { [key: string]: string };
    private ready: Promise<void>;
    private workerID: number;

    public constructor() {
        this.method = {
            value: "ImageScene",
            version: 1,
        };
        this.workerID = Math.round(Math.random() * 1000);
    }

    private async init() {
        log.info(`[${this.workerID}]`, "ImageScene init called");
        if (this.model) {
            return;
        }

        this.sceneMap = await (
            await fetch("/models/imagescene/sceneMap.json")
        ).json();

        this.model = await tfjsConverter.loadGraphModel(
            "/models/imagescene/model.json",
        );
        log.info(
            `[${this.workerID}]`,
            "loaded ImageScene model",
            tf.getBackend(),
        );

        tf.tidy(() => {
            const zeroTensor = tf.zeros([1, 224, 224, 3]);
            // warmup the model
            this.model.predict(zeroTensor) as tf.Tensor;
        });
    }

    private async getImageSceneModel() {
        log.info(`[${this.workerID}]`, "ImageScene getImageSceneModel called");
        if (!this.ready) {
            this.ready = this.init();
        }
        await this.ready;
        return this.model;
    }

    async detectScenes(image: ImageBitmap, minScore: number) {
        const resized = resizeToSquare(image, SCENE_DETECTION_IMAGE_SIZE);

        const model = await this.getImageSceneModel();

        const output = tf.tidy(() => {
            const tfImage = tf.browser.fromPixels(resized.image);
            const input = tf.expandDims(tf.cast(tfImage, "float32"));
            const output = model.predict(input) as tf.Tensor;
            return output;
        });

        const data = (await output.data()) as Float32Array;
        output.dispose();

        const scenes = this.parseSceneDetectionResult(
            data,
            minScore,
            image.width,
            image.height,
        );

        return scenes;
    }

    private parseSceneDetectionResult(
        outputData: Float32Array,
        minScore: number,
        width: number,
        height: number,
    ): ObjectDetection[] {
        const scenes = [];
        for (let i = 0; i < outputData.length; i++) {
            if (outputData[i] >= minScore) {
                scenes.push({
                    class: this.sceneMap[i.toString()],
                    score: outputData[i],
                    bbox: [0, 0, width, height],
                });
            }
        }
        return scenes;
    }
}

export default new ImageScene();

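The deleted service loads its model lazily: the first caller stores a single in-flight init() promise in this.ready, every later caller awaits that same promise, and a zero-tensor predict inside tf.tidy warms the graph up once. A minimal sketch of that memoized-init pattern on its own, with illustrative names (LazyModel, load) that are not part of the original code:

// Sketch of the memoized lazy-init pattern used by ImageScene above: load()
// runs at most once, concurrent callers share the same in-flight promise,
// and get() only resolves after loading (and any warmup inside load) finishes.
class LazyModel<T> {
    private ready: Promise<T> | undefined;

    constructor(private readonly load: () => Promise<T>) {}

    // Returns the same promise to every caller, so load() runs at most once.
    get(): Promise<T> {
        if (!this.ready) {
            this.ready = this.load();
        }
        return this.ready;
    }
}

// Usage, analogous to getImageSceneModel() (url is illustrative):
// const sceneModel = new LazyModel(() => tfjsConverter.loadGraphModel(url));
// const model = await sceneModel.get();
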
@@ -22,15 +22,12 @@ import {
    MLLibraryData,
    MLSyncConfig,
    MLSyncContext,
    SceneDetectionMethod,
    SceneDetectionService,
} from "types/machineLearning";
import { logQueueStats } from "utils/machineLearning";
import arcfaceAlignmentService from "./arcfaceAlignmentService";
import arcfaceCropService from "./arcfaceCropService";
import dbscanClusteringService from "./dbscanClusteringService";
import hdbscanClusteringService from "./hdbscanClusteringService";
import imageSceneService from "./imageSceneService";
import laplacianBlurDetectionService from "./laplacianBlurDetectionService";
import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService";
import yoloFaceDetectionService from "./yoloFaceDetectionService";

@@ -46,16 +43,6 @@ export class MLFactory {
        throw Error("Unknon face detection method: " + method);
    }

    public static getSceneDetectionService(
        method: SceneDetectionMethod,
    ): SceneDetectionService {
        if (method === "ImageScene") {
            return imageSceneService;
        }

        throw Error("Unknown scene detection method: " + method);
    }

    public static getFaceCropService(method: FaceCropMethod) {
        if (method === "ArcFace") {
            return arcfaceCropService;

@@ -2,9 +2,6 @@ import log from "@/next/log";
import { APPS } from "@ente/shared/apps/constants";
import ComlinkCryptoWorker from "@ente/shared/crypto";
import { CustomError, parseUploadErrorCodes } from "@ente/shared/error";
import "@tensorflow/tfjs-backend-cpu";
import "@tensorflow/tfjs-backend-webgl";
import * as tf from "@tensorflow/tfjs-core";
import { MAX_ML_SYNC_ERROR_COUNT } from "constants/mlConfig";
import downloadManager from "services/download";
import { putEmbedding } from "services/embeddingService";

@@ -21,7 +18,6 @@ import { LocalFileMlDataToServerFileMl } from "utils/machineLearning/mldataMappe
import mlIDbStorage from "utils/storage/mlIDbStorage";
import FaceService from "./faceService";
import { MLFactory } from "./machineLearningFactory";
import ObjectService from "./objectService";
import PeopleService from "./peopleService";
import ReaderService from "./readerService";

@@ -58,12 +54,6 @@ class MachineLearningService {
            await this.syncIndex(syncContext);
        }

        // tf.engine().endScope();

        // if (syncContext.config.tsne) {
        //     await this.runTSNE(syncContext);
        // }

        const mlSyncResult: MLSyncResult = {
            nOutOfSyncFiles: syncContext.outOfSyncFiles.length,
            nSyncedFiles: syncContext.nSyncedFiles,

@@ -78,9 +68,6 @@ class MachineLearningService {
        };
        // log.info('[MLService] sync results: ', mlSyncResult);

        // await syncContext.dispose();
        log.info("Final TF Memory stats: ", JSON.stringify(tf.memory()));

        return mlSyncResult;
    }

@@ -183,50 +170,6 @@ class MachineLearningService {
        log.info("getOutOfSyncFiles", Date.now() - startTime, "ms");
    }

    // TODO: optimize, use indexdb indexes, move facecrops to cache to reduce io
    // remove, already done
    private async getUniqueOutOfSyncFilesNoIdx(
        syncContext: MLSyncContext,
        files: EnteFile[],
    ) {
        const limit = syncContext.config.batchSize;
        const mlVersion = syncContext.config.mlVersion;
        const uniqueFiles: Map<number, EnteFile> = new Map<number, EnteFile>();
        for (let i = 0; uniqueFiles.size < limit && i < files.length; i++) {
            const mlFileData = await this.getMLFileData(files[i].id);
            const mlFileVersion = mlFileData?.mlVersion || 0;
            if (
                !uniqueFiles.has(files[i].id) &&
                (!mlFileData?.errorCount || mlFileData.errorCount < 2) &&
                (mlFileVersion < mlVersion ||
                    syncContext.config.imageSource !== mlFileData.imageSource)
            ) {
                uniqueFiles.set(files[i].id, files[i]);
            }
        }

        return [...uniqueFiles.values()];
    }

    private async getOutOfSyncFilesNoIdx(syncContext: MLSyncContext) {
        const existingFilesMap = await this.getLocalFilesMap(syncContext);
        // existingFiles.sort(
        //     (a, b) => b.metadata.creationTime - a.metadata.creationTime
        // );
        console.time("getUniqueOutOfSyncFiles");
        syncContext.outOfSyncFiles = await this.getUniqueOutOfSyncFilesNoIdx(
            syncContext,
            [...existingFilesMap.values()],
        );
        log.info("getUniqueOutOfSyncFiles");
        log.info(
            "Got unique outOfSyncFiles: ",
            syncContext.outOfSyncFiles.length,
            "for batchSize: ",
            syncContext.config.batchSize,
        );
    }

    private async syncFiles(syncContext: MLSyncContext) {
        try {
            const functions = syncContext.outOfSyncFiles.map(

@@ -295,7 +238,6 @@ class MachineLearningService {
        userID: number,
        enteFile: EnteFile,
        localFile?: globalThis.File,
        textDetectionTimeoutIndex?: number,
    ): Promise<MlFileData | Error> {
        const syncContext = await this.getLocalSyncContext(token, userID);

@@ -304,7 +246,6 @@ class MachineLearningService {
            syncContext,
            enteFile,
            localFile,
            textDetectionTimeoutIndex,
        );

        if (syncContext.nSyncedFiles >= syncContext.config.batchSize) {

@@ -322,19 +263,15 @@ class MachineLearningService {
        syncContext: MLSyncContext,
        enteFile: EnteFile,
        localFile?: globalThis.File,
        textDetectionTimeoutIndex?: number,
    ): Promise<MlFileData> {
        try {
            console.log(
                "Start index for ",
                enteFile.title ?? "no title",
                enteFile.id,
                `Indexing ${enteFile.title ?? "<untitled>"} ${enteFile.id}`,
            );
            const mlFileData = await this.syncFile(
                syncContext,
                enteFile,
                localFile,
                textDetectionTimeoutIndex,
            );
            syncContext.nSyncedFaces += mlFileData.faces?.length || 0;
            syncContext.nSyncedFiles += 1;

@@ -365,14 +302,6 @@ class MachineLearningService {

            await this.persistMLFileSyncError(syncContext, enteFile, error);
            syncContext.nSyncedFiles += 1;
        } finally {
            console.log(
                "done index for ",
                enteFile.title ?? "no title",
                enteFile.id,
            );
            // addLogLine('TF Memory stats: ', JSON.stringify(tf.memory()));
            log.info("TF Memory stats: ", JSON.stringify(tf.memory()));
        }
    }

@@ -380,8 +309,6 @@ class MachineLearningService {
        syncContext: MLSyncContext,
        enteFile: EnteFile,
        localFile?: globalThis.File,
        // eslint-disable-next-line @typescript-eslint/no-unused-vars
        textDetectionTimeoutIndex?: number,
    ) {
        console.log("Syncing for file" + enteFile.title);
        const fileContext: MLSyncFileContext = { enteFile, localFile };

@@ -406,15 +333,6 @@ class MachineLearningService {
        await ReaderService.getImageBitmap(syncContext, fileContext);
        await Promise.all([
            this.syncFileAnalyzeFaces(syncContext, fileContext),
            // ObjectService.syncFileObjectDetections(
            //     syncContext,
            //     fileContext
            // ),
            // TextService.syncFileTextDetections(
            //     syncContext,
            //     fileContext,
            //     textDetectionTimeoutIndex
            // ),
        ]);
        newMlFile.errorCount = 0;
        newMlFile.lastErrorMessage = undefined;

@@ -425,15 +343,7 @@ class MachineLearningService {
            newMlFile.mlVersion = oldMlFile.mlVersion;
            throw e;
        } finally {
            fileContext.tfImage && fileContext.tfImage.dispose();
            fileContext.imageBitmap && fileContext.imageBitmap.close();
            // log.info('8 TF Memory stats: ',JSON.stringify(tf.memory()));

            // TODO: enable once faceId changes go in
            // await removeOldFaceCrops(
            //     fileContext.oldMlFile,
            //     fileContext.newMlFile
            // );
        }

        return newMlFile;

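The getUniqueOutOfSyncFilesNoIdx hunk above also documents the rule the pruned, non-indexed path used to decide which files still need indexing. Distilled into a standalone predicate (a sketch only; the IndexedFileState shape is an assumption, since the diff references mlFileData.imageSource but does not declare it):

// Sketch of the per-file selection rule from getUniqueOutOfSyncFilesNoIdx:
// index a file if it has not already failed twice, and it is either behind
// the configured mlVersion or was last indexed from a different image source.
interface IndexedFileState {
    mlVersion?: number;
    errorCount?: number;
    imageSource?: string; // assumed string; referenced but not declared in this diff
}

function needsReindex(
    state: IndexedFileState | undefined,
    configMlVersion: number,
    configImageSource: string,
): boolean {
    const fileVersion = state?.mlVersion || 0;
    return (
        (!state?.errorCount || state.errorCount < 2) &&
        (fileVersion < configMlVersion ||
            configImageSource !== state?.imageSource)
    );
}
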
@@ -61,10 +61,6 @@ export declare type ImageType = "Original" | "Preview";

export declare type FaceDetectionMethod = "FaceApiSSD" | "YoloFace";

export declare type ObjectDetectionMethod = "SSDMobileNetV2";

export declare type SceneDetectionMethod = "ImageScene";

export declare type FaceCropMethod = "ArcFace";

export declare type FaceAlignmentMethod =

@@ -172,8 +168,6 @@ export interface MlFileData {
    faceCropMethod?: Versioned<FaceCropMethod>;
    faceAlignmentMethod?: Versioned<FaceAlignmentMethod>;
    faceEmbeddingMethod?: Versioned<FaceEmbeddingMethod>;
    objectDetectionMethod?: Versioned<ObjectDetectionMethod>;
    sceneDetectionMethod?: Versioned<SceneDetectionMethod>;
    mlVersion: number;
    errorCount: number;
    lastErrorMessage?: string;

@@ -183,17 +177,6 @@ export interface FaceDetectionConfig {
    method: FaceDetectionMethod;
}

export interface ObjectDetectionConfig {
    method: ObjectDetectionMethod;
    maxNumBoxes: number;
    minScore: number;
}

export interface SceneDetectionConfig {
    method: SceneDetectionMethod;
    minScore: number;
}

export interface FaceCropConfig {
    enabled: boolean;
    method: FaceCropMethod;

@@ -265,7 +248,6 @@ export interface MLSyncContext {
    faceEmbeddingService: FaceEmbeddingService;
    blurDetectionService: BlurDetectionService;
    faceClusteringService: ClusteringService;
    sceneDetectionService: SceneDetectionService;

    localFilesMap: Map<number, EnteFile>;
    outOfSyncFiles: EnteFile[];

@@ -292,7 +274,6 @@ export interface MLSyncFileContext {
    oldMlFile?: MlFileData;
    newMlFile?: MlFileData;

    tfImage?: tf.Tensor3D;
    imageBitmap?: ImageBitmap;

    newDetection?: boolean;

@@ -318,26 +299,6 @@ export interface FaceDetectionService {
    dispose(): Promise<void>;
}

export interface ObjectDetectionService {
    method: Versioned<ObjectDetectionMethod>;
    // init(): Promise<void>;
    detectObjects(
        image: ImageBitmap,
        maxNumBoxes: number,
        minScore: number,
    ): Promise<ObjectDetection[]>;
    dispose(): Promise<void>;
}

export interface SceneDetectionService {
    method: Versioned<SceneDetectionMethod>;
    // init(): Promise<void>;
    detectScenes(
        image: ImageBitmap,
        minScore: number,
    ): Promise<ObjectDetection[]>;
}

export interface FaceCropService {
    method: Versioned<FaceCropMethod>;

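Taken together, the removed declarations describe the scene-detection contract this commit retires: a SceneDetectionConfig picks the method and threshold, MLFactory.getSceneDetectionService resolves it to the ImageScene singleton, and detectScenes returns whole-frame ObjectDetection entries. A compact restatement for reference, in which the import paths, the helper name, and the 0.1 threshold are assumptions rather than part of the diff:

// Compact restatement of the pruned contract; every type and method name comes
// from the hunks above, but the import paths and threshold are assumptions.
import { MLFactory } from "services/machineLearning/machineLearningFactory";
import {
    ObjectDetection,
    SceneDetectionConfig,
} from "types/machineLearning";

async function detectScenesFor(image: ImageBitmap): Promise<ObjectDetection[]> {
    const config: SceneDetectionConfig = { method: "ImageScene", minScore: 0.1 };
    const service = MLFactory.getSceneDetectionService(config.method);
    // Each result reuses ObjectDetection, with bbox spanning the full frame.
    return service.detectScenes(image, config.minScore);
}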