diff --git a/web/apps/photos/src/components/Sidebar/AdvancedSettings.tsx b/web/apps/photos/src/components/Sidebar/AdvancedSettings.tsx index 8acd58eca..817aecb2b 100644 --- a/web/apps/photos/src/components/Sidebar/AdvancedSettings.tsx +++ b/web/apps/photos/src/components/Sidebar/AdvancedSettings.tsx @@ -14,7 +14,7 @@ import { EnteMenuItem } from "components/Menu/EnteMenuItem"; import { MenuItemGroup } from "components/Menu/MenuItemGroup"; import isElectron from "is-electron"; import { AppContext } from "pages/_app"; -import { ClipExtractionStatus, clipService } from "services/clip-service"; +import { CLIPIndexingStatus, clipService } from "services/clip-service"; import { formatNumber } from "utils/number/format"; export default function AdvancedSettings({ open, onClose, onRootClose }) { @@ -44,17 +44,15 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) { log.error("toggleFasterUpload failed", e); } }; - const [indexingStatus, setIndexingStatus] = useState({ + const [indexingStatus, setIndexingStatus] = useState({ indexed: 0, pending: 0, }); useEffect(() => { - const main = async () => { - setIndexingStatus(await clipService.getIndexingStatus()); - clipService.setOnUpdateHandler(setIndexingStatus); - }; - main(); + clipService.setOnUpdateHandler(setIndexingStatus); + clipService.getIndexingStatus().then((st) => setIndexingStatus(st)); + return () => clipService.setOnUpdateHandler(undefined); }, []); return ( diff --git a/web/apps/photos/src/services/clip-service.ts b/web/apps/photos/src/services/clip-service.ts index a2f2300d4..a1f3cd1a6 100644 --- a/web/apps/photos/src/services/clip-service.ts +++ b/web/apps/photos/src/services/clip-service.ts @@ -14,10 +14,11 @@ import downloadManager from "./download"; import { getLocalEmbeddings, putEmbedding } from "./embeddingService"; import { getAllLocalFiles, getLocalFiles } from "./fileService"; -const CLIP_EMBEDDING_LENGTH = 512; - -export interface ClipExtractionStatus { +/** Status of CLIP indexing on the images in the user's local library. */ +export interface CLIPIndexingStatus { + /** Number of items pending indexing. */ pending: number; + /** Number of items that have already been indexed. */ indexed: number; } @@ -62,15 +63,14 @@ export interface ClipExtractionStatus { * * Both these currently have one (and only one) associated model. */ -class ClipServiceImpl { +class ClipService { private embeddingExtractionInProgress: AbortController | null = null; private reRunNeeded = false; - private clipExtractionStatus: ClipExtractionStatus = { + private indexingStatus: CLIPIndexingStatus = { pending: 0, indexed: 0, }; - private onUpdateHandler: ((status: ClipExtractionStatus) => void) | null = - null; + private onUpdateHandler: ((status: CLIPIndexingStatus) => void) | undefined; private liveEmbeddingExtractionQueue: PQueue; private onFileUploadedHandler: | ((arg: { enteFile: EnteFile; localFile: globalThis.File }) => void) @@ -137,28 +137,23 @@ class ClipServiceImpl { }; getIndexingStatus = async () => { - try { - if ( - !this.clipExtractionStatus || - (this.clipExtractionStatus.pending === 0 && - this.clipExtractionStatus.indexed === 0) - ) { - this.clipExtractionStatus = await getClipExtractionStatus(); - } - return this.clipExtractionStatus; - } catch (e) { - log.error("failed to get clip indexing status", e); + if ( + this.indexingStatus.pending === 0 && + this.indexingStatus.indexed === 0 + ) { + this.indexingStatus = await initialIndexingStatus(); } + return this.indexingStatus; }; - setOnUpdateHandler = (handler: (status: ClipExtractionStatus) => void) => { + /** + * Set the {@link handler} to invoke whenever our indexing status changes. + */ + setOnUpdateHandler = (handler?: (status: CLIPIndexingStatus) => void) => { this.onUpdateHandler = handler; - handler(this.clipExtractionStatus); }; - scheduleImageEmbeddingExtraction = async ( - model: Model = Model.ONNX_CLIP, - ) => { + scheduleImageEmbeddingExtraction = async () => { try { if (this.embeddingExtractionInProgress) { log.info( @@ -174,7 +169,7 @@ class ClipServiceImpl { const canceller = new AbortController(); this.embeddingExtractionInProgress = canceller; try { - await this.runClipEmbeddingExtraction(canceller, model); + await this.runClipEmbeddingExtraction(canceller); } finally { this.embeddingExtractionInProgress = null; if (!canceller.signal.aborted && this.reRunNeeded) { @@ -193,12 +188,9 @@ class ClipServiceImpl { } }; - getTextEmbedding = async ( - text: string, - model: Model = Model.ONNX_CLIP, - ): Promise => { + getTextEmbedding = async (text: string): Promise => { try { - return ensureElectron().computeTextEmbedding(model, text); + return ensureElectron().computeTextEmbedding(Model.ONNX_CLIP, text); } catch (e) { if (e?.message?.includes(CustomError.UNSUPPORTED_PLATFORM)) { this.unsupportedPlatform = true; @@ -208,10 +200,7 @@ class ClipServiceImpl { } }; - private runClipEmbeddingExtraction = async ( - canceller: AbortController, - model: Model, - ) => { + private runClipEmbeddingExtraction = async (canceller: AbortController) => { try { if (this.unsupportedPlatform) { log.info( @@ -224,12 +213,12 @@ class ClipServiceImpl { return; } const localFiles = getPersonalFiles(await getAllLocalFiles(), user); - const existingEmbeddings = await getLocalEmbeddings(model); + const existingEmbeddings = await getLocalEmbeddings(); const pendingFiles = await getNonClipEmbeddingExtractedFiles( localFiles, existingEmbeddings, ); - this.updateClipEmbeddingExtractionStatus({ + this.updateIndexingStatus({ indexed: existingEmbeddings.length, pending: pendingFiles.length, }); @@ -249,15 +238,11 @@ class ClipServiceImpl { throw Error(CustomError.REQUEST_CANCELLED); } const embeddingData = - await this.extractFileClipImageEmbedding(model, file); + await this.extractFileClipImageEmbedding(file); log.info( `successfully extracted clip embedding for file: ${file.metadata.title} fileID: ${file.id} embedding length: ${embeddingData?.length}`, ); - await this.encryptAndUploadEmbedding( - model, - file, - embeddingData, - ); + await this.encryptAndUploadEmbedding(file, embeddingData); this.onSuccessStatusUpdater(); log.info( `successfully put clip embedding to server for file: ${file.metadata.title} fileID: ${file.id}`, @@ -290,13 +275,10 @@ class ClipServiceImpl { } }; - private async runLocalFileClipExtraction( - arg: { - enteFile: EnteFile; - localFile: globalThis.File; - }, - model: Model = Model.ONNX_CLIP, - ) { + private async runLocalFileClipExtraction(arg: { + enteFile: EnteFile; + localFile: globalThis.File; + }) { const { enteFile, localFile } = arg; log.info( `clip embedding extraction onFileUploadedHandler file: ${enteFile.metadata.title} fileID: ${enteFile.id}`, @@ -320,15 +302,9 @@ class ClipServiceImpl { ); try { await this.liveEmbeddingExtractionQueue.add(async () => { - const embedding = await this.extractLocalFileClipImageEmbedding( - model, - localFile, - ); - await this.encryptAndUploadEmbedding( - model, - enteFile, - embedding, - ); + const embedding = + await this.extractLocalFileClipImageEmbedding(localFile); + await this.encryptAndUploadEmbedding(enteFile, embedding); }); log.info( `successfully extracted clip embedding for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`, @@ -338,26 +314,22 @@ class ClipServiceImpl { } } - private extractLocalFileClipImageEmbedding = async ( - model: Model, - localFile: File, - ) => { + private extractLocalFileClipImageEmbedding = async (localFile: File) => { const file = await localFile .arrayBuffer() .then((buffer) => new Uint8Array(buffer)); const embedding = await ensureElectron().computeImageEmbedding( - model, + Model.ONNX_CLIP, file, ); return embedding; }; private encryptAndUploadEmbedding = async ( - model: Model, file: EnteFile, embeddingData: Float32Array, ) => { - if (embeddingData?.length !== CLIP_EMBEDDING_LENGTH) { + if (embeddingData?.length !== 512) { throw Error( `invalid length embedding data length: ${embeddingData?.length}`, ); @@ -368,6 +340,7 @@ class ClipServiceImpl { log.info( `putting clip embedding to server for file: ${file.metadata.title} fileID: ${file.id}`, ); + const model = Model.ONNX_CLIP; await putEmbedding({ fileID: file.id, encryptedEmbedding: encryptedEmbeddingData.encryptedData, @@ -376,34 +349,30 @@ class ClipServiceImpl { }); }; - updateClipEmbeddingExtractionStatus = (status: ClipExtractionStatus) => { - this.clipExtractionStatus = status; - if (this.onUpdateHandler) { - this.onUpdateHandler(status); - } + private updateIndexingStatus = (status: CLIPIndexingStatus) => { + this.indexingStatus = status; + const handler = this.onUpdateHandler; + if (handler) handler(status); }; - private extractFileClipImageEmbedding = async ( - model: Model, - file: EnteFile, - ) => { + private extractFileClipImageEmbedding = async (file: EnteFile) => { const thumb = await downloadManager.getThumbnail(file); const embedding = await ensureElectron().computeImageEmbedding( - model, + Model.ONNX_CLIP, thumb, ); return embedding; }; private onSuccessStatusUpdater = () => { - this.updateClipEmbeddingExtractionStatus({ - pending: this.clipExtractionStatus.pending - 1, - indexed: this.clipExtractionStatus.indexed + 1, + this.updateIndexingStatus({ + pending: this.indexingStatus.pending - 1, + indexed: this.indexingStatus.indexed + 1, }); }; } -export const clipService = new ClipServiceImpl(); +export const clipService = new ClipService(); const getNonClipEmbeddingExtractedFiles = async ( files: EnteFile[], @@ -453,14 +422,10 @@ export const computeClipMatchScore = async ( return score; }; -const getClipExtractionStatus = async ( - model: Model = Model.ONNX_CLIP, -): Promise => { +const initialIndexingStatus = async (): Promise => { const user = getData(LS_KEYS.USER); - if (!user) { - return; - } - const allEmbeddings = await getLocalEmbeddings(model); + if (!user) throw new Error("Orphan CLIP indexing without a login"); + const allEmbeddings = await getLocalEmbeddings(); const localFiles = getPersonalFiles(await getLocalFiles(), user); const pendingFiles = await getNonClipEmbeddingExtractedFiles( localFiles, diff --git a/web/apps/photos/src/services/embeddingService.ts b/web/apps/photos/src/services/embeddingService.ts index 882cdd16c..79fa5ef7e 100644 --- a/web/apps/photos/src/services/embeddingService.ts +++ b/web/apps/photos/src/services/embeddingService.ts @@ -38,9 +38,9 @@ export const getAllLocalEmbeddings = async () => { return embeddings; }; -export const getLocalEmbeddings = async (model: Model) => { +export const getLocalEmbeddings = async () => { const embeddings = await getAllLocalEmbeddings(); - return embeddings.filter((embedding) => embedding.model === model); + return embeddings.filter((embedding) => embedding.model === Model.ONNX_CLIP); }; const getModelEmbeddingSyncTime = async (model: Model) => {