From eea1fb83ae0c7fa66bef8673ac3a215b5a5e268b Mon Sep 17 00:00:00 2001
From: mertalev <101130780+mertalev@users.noreply.github.com>
Date: Mon, 4 Dec 2023 18:59:01 -0500
Subject: [PATCH] handle dim size update outside of jobs

Move the CLIP embedding dimension check out of the per-asset upsert path.

- Add ISmartInfoRepository.init(). It reads the configured CLIP model name,
  looks up its dimension size in a static table (smart-info.constant.ts),
  compares it with the dimension of smart_search.embedding, and calls
  updateDimSize() when the two differ.
- Add SmartInfoService.init(). It pauses the CLIP encoding queue, polls the
  queue status once per second until it is no longer active, runs the
  repository init, and resumes the queue.
- SystemConfigService.updateConfig() calls smartInfoRepository.init() when
  the CLIP model name changed.
- Remove the cached curDimSize field. getDimSize() now queries the database
  on every call and returns the value instead of storing it.
---
Review notes (ignored by git am):

- cleanModelName() replaces every ':' (regex with the g flag). A string
  pattern only replaces the first match, so 'ViT-B-32::openai' would have
  become 'ViT-B-32_:openai' and missed the 'ViT-B-32__openai' table key.
- SmartInfoService.init() resumes the queue in a finally block so a failing
  status poll or repository init cannot leave the queue paused.
- SmartInfoRepository.init() compares dimension sizes with !==.
- getDimSize() reads the row with optional chaining so an empty result
  reaches the explicit "Could not retrieve CLIP dimension size" error
  (assumes isValidInteger rejects undefined - TODO confirm).
- NOTE(review): generic type arguments on context lines (searchCLIP,
  searchFaces, upsert, updateConfig) and the leading whitespace of all
  context lines were reconstructed; confirm against the target tree before
  applying. The blob hashes on the index lines predate the fixes above.
- NOTE(review): updateConfig() calls the repository init directly and
  therefore does not pause the CLIP encoding queue; left unchanged here.

 .../repositories/smart-info.repository.ts     |   1 +
 .../domain/smart-info/smart-info.constant.ts  | 107 ++++++++++++++++++
 .../domain/smart-info/smart-info.service.ts   |  22 +++-
 .../system-config/system-config.service.ts    |  17 ++-
 .../repositories/smart-info.repository.ts     |  63 +++++++----
 5 files changed, 182 insertions(+), 28 deletions(-)
 create mode 100644 server/src/domain/smart-info/smart-info.constant.ts

diff --git a/server/src/domain/repositories/smart-info.repository.ts b/server/src/domain/repositories/smart-info.repository.ts
index c41cdd7f6..834cde6ed 100644
--- a/server/src/domain/repositories/smart-info.repository.ts
+++ b/server/src/domain/repositories/smart-info.repository.ts
@@ -12,6 +12,7 @@ export interface EmbeddingSearch {
 }
 
 export interface ISmartInfoRepository {
+  init(): Promise<void>;
   searchCLIP(search: EmbeddingSearch): Promise<AssetEntity[]>;
   searchFaces(search: EmbeddingSearch): Promise<AssetFaceEntity[]>;
   upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void>;
diff --git a/server/src/domain/smart-info/smart-info.constant.ts b/server/src/domain/smart-info/smart-info.constant.ts
new file mode 100644
index 000000000..5710c637e
--- /dev/null
+++ b/server/src/domain/smart-info/smart-info.constant.ts
@@ -0,0 +1,107 @@
+export type ModelInfo = {
+  dimSize: number;
+};
+
+export const CLIP_MODEL_INFO: Record<string, ModelInfo> = {
+  RN50__openai: {
+    dimSize: 1024,
+  },
+  RN50__yfcc15m: {
+    dimSize: 1024,
+  },
+  RN50__cc12m: {
+    dimSize: 1024,
+  },
+  RN101__openai: {
+    dimSize: 512,
+  },
+  RN101__yfcc15m: {
+    dimSize: 512,
+  },
+  RN50x4__openai: {
+    dimSize: 640,
+  },
+  RN50x16__openai: {
+    dimSize: 768,
+  },
+  RN50x64__openai: {
+    dimSize: 1024,
+  },
+  'ViT-B-32__openai': {
+    dimSize: 512,
+  },
+  'ViT-B-32__laion2b_e16': {
+    dimSize: 512,
+  },
+  'ViT-B-32__laion400m_e31': {
+    dimSize: 512,
+  },
+  'ViT-B-32__laion400m_e32': {
+    dimSize: 512,
+  },
+  'ViT-B-32__laion2b-s34b-b79k': {
+    dimSize: 512,
+  },
+  'ViT-B-16__openai': {
+    dimSize: 512,
+  },
+  'ViT-B-16__laion400m_e31': {
+    dimSize: 512,
+  },
+  'ViT-B-16__laion400m_e32': {
+    dimSize: 512,
+  },
+  'ViT-B-16-plus-240__laion400m_e31': {
+    dimSize: 640,
+  },
+  'ViT-B-16-plus-240__laion400m_e32': {
+    dimSize: 640,
+  },
+  'ViT-L-14__openai': {
+    dimSize: 768,
+  },
+  'ViT-L-14__laion400m_e31': {
+    dimSize: 768,
+  },
+  'ViT-L-14__laion400m_e32': {
+    dimSize: 768,
+  },
+  'ViT-L-14__laion2b-s32b-b82k': {
+    dimSize: 768,
+  },
+  'ViT-L-14-336__openai': {
+    dimSize: 768,
+  },
+  'ViT-H-14__laion2b-s32b-b79k': {
+    dimSize: 1024,
+  },
+  'ViT-g-14__laion2b-s12b-b42k': {
+    dimSize: 1024,
+  },
+  'LABSE-Vit-L-14': {
+    dimSize: 768,
+  },
+  'XLM-Roberta-Large-Vit-B-32': {
+    dimSize: 512,
+  },
+  'XLM-Roberta-Large-Vit-B-16Plus': {
+    dimSize: 640,
+  },
+  'XLM-Roberta-Large-Vit-L-14': {
+    dimSize: 768,
+  },
+};
+
+function cleanModelName(modelName: string): string {
+  const tokens = modelName.split('/');
+  return tokens[tokens.length - 1].replace(/:/g, '_');
+}
+
+export function getCLIPModelInfo(modelName: string): ModelInfo {
+  const modelInfo = CLIP_MODEL_INFO[cleanModelName(modelName)];
+  if (!modelInfo) {
+    throw new Error(`Unknown CLIP model: ${modelName}`);
+  }
+
+  return modelInfo;
+}
diff --git a/server/src/domain/smart-info/smart-info.service.ts b/server/src/domain/smart-info/smart-info.service.ts
index 8af3bc7b7..7962aa155 100644
--- a/server/src/domain/smart-info/smart-info.service.ts
+++ b/server/src/domain/smart-info/smart-info.service.ts
@@ -1,6 +1,7 @@
-import { Inject, Injectable } from '@nestjs/common';
+import { Inject, Injectable, Logger } from '@nestjs/common';
+import { setTimeout } from 'timers/promises';
 import { usePagination } from '../domain.util';
-import { IBaseJob, IEntityJob, JOBS_ASSET_PAGINATION_SIZE, JobName } from '../job';
+import { IBaseJob, IEntityJob, JOBS_ASSET_PAGINATION_SIZE, JobName, QueueName } from '../job';
 import {
   IAssetRepository,
   IJobRepository,
@@ -14,6 +15,7 @@ import { SystemConfigCore } from '../system-config';
 @Injectable()
 export class SmartInfoService {
   private configCore: SystemConfigCore;
+  private logger = new Logger(SmartInfoService.name);
 
   constructor(
     @Inject(IAssetRepository) private assetRepository: IAssetRepository,
@@ -25,6 +27,22 @@ export class SmartInfoService {
     this.configCore = SystemConfigCore.create(configRepository);
   }
 
+  async init() {
+    await this.jobRepository.pause(QueueName.CLIP_ENCODING);
+
+    try {
+      let { isActive } = await this.jobRepository.getQueueStatus(QueueName.CLIP_ENCODING);
+      while (isActive) {
+        this.logger.verbose('Waiting for CLIP encoding queue to stop...');
+        await setTimeout(1000);
+        ({ isActive } = await this.jobRepository.getQueueStatus(QueueName.CLIP_ENCODING));
+      }
+      await this.repository.init();
+    } finally {
+      await this.jobRepository.resume(QueueName.CLIP_ENCODING);
+    }
+  }
+
   async handleQueueObjectTagging({ force }: IBaseJob) {
     const { machineLearning } = await this.configCore.getConfig();
     if (!machineLearning.enabled || !machineLearning.classification.enabled) {
diff --git a/server/src/domain/system-config/system-config.service.ts b/server/src/domain/system-config/system-config.service.ts
index c81c462e8..402a3cf0f 100644
--- a/server/src/domain/system-config/system-config.service.ts
+++ b/server/src/domain/system-config/system-config.service.ts
@@ -1,6 +1,12 @@
 import { Inject, Injectable } from '@nestjs/common';
 import { JobName } from '../job';
-import { CommunicationEvent, ICommunicationRepository, IJobRepository, ISystemConfigRepository } from '../repositories';
+import {
+  CommunicationEvent,
+  ICommunicationRepository,
+  IJobRepository,
+  ISmartInfoRepository,
+  ISystemConfigRepository,
+} from '../repositories';
 import { SystemConfigDto, mapConfig } from './dto/system-config.dto';
 import { SystemConfigTemplateStorageOptionDto } from './response-dto/system-config-template-storage-option.dto';
 import {
@@ -22,6 +28,7 @@ export class SystemConfigService {
     @Inject(ISystemConfigRepository) private repository: ISystemConfigRepository,
     @Inject(ICommunicationRepository) private communicationRepository: ICommunicationRepository,
     @Inject(IJobRepository) private jobRepository: IJobRepository,
+    @Inject(ISmartInfoRepository) private smartInfoRepository: ISmartInfoRepository,
   ) {
     this.core = SystemConfigCore.create(repository);
   }
@@ -41,10 +48,14 @@ export class SystemConfigService {
   }
 
   async updateConfig(dto: SystemConfigDto): Promise<SystemConfigDto> {
-    const config = await this.core.updateConfig(dto);
+    const oldConfig = await this.core.getConfig();
+    const newConfig = await this.core.updateConfig(dto);
     await this.jobRepository.queue({ name: JobName.SYSTEM_CONFIG_CHANGE });
     this.communicationRepository.broadcast(CommunicationEvent.CONFIG_UPDATE, {});
-    return mapConfig(config);
+    if (oldConfig.machineLearning.clip.modelName !== newConfig.machineLearning.clip.modelName) {
+      await this.smartInfoRepository.init();
+    }
+    return mapConfig(newConfig);
   }
 
   async refreshConfig() {
diff --git a/server/src/infra/repositories/smart-info.repository.ts b/server/src/infra/repositories/smart-info.repository.ts
index d8df4054e..df43cf5ca 100644
--- a/server/src/infra/repositories/smart-info.repository.ts
+++ b/server/src/infra/repositories/smart-info.repository.ts
@@ -1,7 +1,14 @@
-import { Embedding, EmbeddingSearch, ISmartInfoRepository } from '@app/domain';
+import {
+  Embedding,
+  EmbeddingSearch,
+  ISmartInfoRepository,
+  ISystemConfigRepository,
+  SystemConfigCore,
+} from '@app/domain';
+import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant';
 import { DatabaseLock, RequireLock, asyncLock } from '@app/infra';
 import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities';
-import { Injectable, Logger } from '@nestjs/common';
+import { Inject, Injectable, Logger } from '@nestjs/common';
 import { InjectRepository } from '@nestjs/typeorm';
 import { Repository } from 'typeorm';
 import { asVector, isValidInteger } from '../infra.utils';
@@ -9,21 +16,40 @@ import { asVector, isValidInteger } from '../infra.utils';
 @Injectable()
 export class SmartInfoRepository implements ISmartInfoRepository {
   private logger = new Logger(SmartInfoRepository.name);
+  private configCore: SystemConfigCore;
   private readonly faceColumns: string[];
-  private curDimSize: number | undefined;
 
   constructor(
     @InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>,
     @InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
     @InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
     @InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>,
+    @Inject(ISystemConfigRepository) private configRepository: ISystemConfigRepository,
   ) {
+    this.configCore = SystemConfigCore.create(configRepository);
     this.faceColumns = this.assetFaceRepository.manager.connection
       .getMetadata(AssetFaceEntity)
       .ownColumns.map((column) => column.propertyName)
       .filter((propertyName) => propertyName !== 'embedding');
   }
 
+  async init(): Promise<void> {
+    const { machineLearning } = await this.configCore.getConfig();
+    const modelName = machineLearning.clip.modelName;
+    const { dimSize } = getCLIPModelInfo(modelName);
+    if (dimSize == null) {
+      throw new Error(`Invalid CLIP model name: ${modelName}`);
+    }
+
+    const curDimSize = await this.getDimSize();
+    this.logger.verbose(`Current database CLIP dimension size is ${curDimSize}`);
+
+    if (dimSize !== curDimSize) {
+      this.logger.log(`Dimension size of model ${modelName} is ${dimSize}, but database expects ${curDimSize}.`);
+      await this.updateDimSize(dimSize);
+    }
+  }
+
   async searchCLIP({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
     if (!isValidInteger(numResults, { min: 1 })) {
       throw new Error(`Invalid value for 'numResults': ${numResults}`);
@@ -92,14 +118,6 @@ export class SmartInfoRepository implements ISmartInfoRepository {
       await asyncLock.acquire(DatabaseLock[DatabaseLock.CLIPDimSize], () => {});
     }
 
-    if (this.curDimSize == null) {
-      await this.getDimSize();
-    }
-
-    if (this.curDimSize !== embedding.length) {
-      await this.updateDimSize(embedding.length);
-    }
-
     await this.smartSearchRepository.upsert(
       { assetId, embedding: () => asVector(embedding, true) },
       { conflictPaths: ['assetId'] },
@@ -112,11 +130,12 @@ export class SmartInfoRepository implements ISmartInfoRepository {
       throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
     }
 
-    if (this.curDimSize === dimSize) {
+    const curDimSize = await this.getDimSize();
+    if (curDimSize === dimSize) {
       return;
     }
 
-    this.logger.log(`Current dimension size is ${this.curDimSize}. Updating CLIP dimension size to ${dimSize}.`);
+    this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`);
 
     await this.smartSearchRepository.manager.transaction(async (manager) => {
       await manager.query(`DROP TABLE smart_search`);
@@ -135,16 +154,10 @@ export class SmartInfoRepository implements ISmartInfoRepository {
           $$)`);
     });
 
-    this.logger.log(`Successfully updated CLIP dimension size from ${this.curDimSize} to ${dimSize}.`);
-    this.curDimSize = dimSize;
+    this.logger.log(`Successfully updated database CLIP dimension size from ${curDimSize} to ${dimSize}.`);
   }
 
-  @RequireLock(DatabaseLock.CLIPDimSize)
-  private async getDimSize(): Promise<void> {
-    if (this.curDimSize != null) {
-      return;
-    }
-
+  private async getDimSize(): Promise<number> {
     const res = await this.smartSearchRepository.manager.query(`
       SELECT atttypmod as dimsize
       FROM pg_attribute f
@@ -153,7 +166,11 @@ export class SmartInfoRepository implements ISmartInfoRepository {
         AND f.attnum > 0
         AND c.relname = 'smart_search'
         AND f.attname = 'embedding'`);
-    this.curDimSize = res?.[0]?.['dimsize'] ?? 512;
-    this.logger.verbose(`CLIP dimension size is ${this.curDimSize}`);
+
+    const dimSize = res?.[0]?.['dimsize'];
+    if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
+      throw new Error(`Could not retrieve CLIP dimension size`);
+    }
+    return dimSize;
   }
 }