handle dim size update outside of jobs

This commit is contained in:
mertalev 2023-12-04 18:59:01 -05:00
parent 1fe9075d11
commit eea1fb83ae
No known key found for this signature in database
GPG key ID: 9181CD92C0A1C5E3
5 changed files with 182 additions and 28 deletions

View file

@ -12,6 +12,7 @@ export interface EmbeddingSearch {
} }
export interface ISmartInfoRepository { export interface ISmartInfoRepository {
init(): Promise<void>;
searchCLIP(search: EmbeddingSearch): Promise<AssetEntity[]>; searchCLIP(search: EmbeddingSearch): Promise<AssetEntity[]>;
searchFaces(search: EmbeddingSearch): Promise<AssetFaceEntity[]>; searchFaces(search: EmbeddingSearch): Promise<AssetFaceEntity[]>;
upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void>; upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void>;

View file

@ -0,0 +1,107 @@
/** Static metadata recorded for a single CLIP model in `CLIP_MODEL_INFO`. */
export type ModelInfo = {
/** Dimension size of the model; compared against the database's CLIP dimension size on init. */
dimSize: number;
};
/**
 * Lookup table from CLIP model name to its `ModelInfo`.
 *
 * `getCLIPModelInfo` indexes this table with the output of `cleanModelName`, so keys
 * must be written in that cleaned form (no `/` path prefix, `_` in place of `:`).
 * A model that is missing here is rejected by `getCLIPModelInfo` as unknown.
 */
export const CLIP_MODEL_INFO: Record<string, ModelInfo> = {
RN50__openai: {
dimSize: 1024,
},
RN50__yfcc15m: {
dimSize: 1024,
},
RN50__cc12m: {
dimSize: 1024,
},
RN101__openai: {
dimSize: 512,
},
RN101__yfcc15m: {
dimSize: 512,
},
RN50x4__openai: {
dimSize: 640,
},
RN50x16__openai: {
dimSize: 768,
},
RN50x64__openai: {
dimSize: 1024,
},
'ViT-B-32__openai': {
dimSize: 512,
},
'ViT-B-32__laion2b_e16': {
dimSize: 512,
},
'ViT-B-32__laion400m_e31': {
dimSize: 512,
},
'ViT-B-32__laion400m_e32': {
dimSize: 512,
},
'ViT-B-32__laion2b-s34b-b79k': {
dimSize: 512,
},
'ViT-B-16__openai': {
dimSize: 512,
},
'ViT-B-16__laion400m_e31': {
dimSize: 512,
},
'ViT-B-16__laion400m_e32': {
dimSize: 512,
},
'ViT-B-16-plus-240__laion400m_e31': {
dimSize: 640,
},
'ViT-B-16-plus-240__laion400m_e32': {
dimSize: 640,
},
'ViT-L-14__openai': {
dimSize: 768,
},
'ViT-L-14__laion400m_e31': {
dimSize: 768,
},
'ViT-L-14__laion400m_e32': {
dimSize: 768,
},
'ViT-L-14__laion2b-s32b-b82k': {
dimSize: 768,
},
'ViT-L-14-336__openai': {
dimSize: 768,
},
'ViT-H-14__laion2b-s32b-b79k': {
dimSize: 1024,
},
'ViT-g-14__laion2b-s12b-b42k': {
dimSize: 1024,
},
'LABSE-Vit-L-14': {
dimSize: 768,
},
'XLM-Roberta-Large-Vit-B-32': {
dimSize: 512,
},
'XLM-Roberta-Large-Vit-B-16Plus': {
dimSize: 640,
},
'XLM-Roberta-Large-Vit-L-14': {
dimSize: 768,
},
};
/**
 * Normalizes a model name into the key format used by `CLIP_MODEL_INFO`.
 *
 * Keeps only the segment after the last `/` and replaces every `:` in that segment
 * with `_`, so that `ViT-B-32::openai` becomes `ViT-B-32__openai`.
 *
 * @param modelName - The model name to normalize.
 * @returns The cleaned model name.
 */
function cleanModelName(modelName: string): string {
  const tokens = modelName.split('/');
  // A global regex is required: `replace` with a string pattern only swaps the first
  // `:`, which would turn `a::b` into `a_:b` instead of `a__b`.
  return tokens[tokens.length - 1].replace(/:/g, '_');
}
/**
 * Looks up the `ModelInfo` registered for a CLIP model.
 *
 * The name is normalized with `cleanModelName` before it is used as a key into
 * `CLIP_MODEL_INFO`.
 *
 * @param modelName - The model name to look up.
 * @returns The matching entry from `CLIP_MODEL_INFO`.
 * @throws Error when the lookup yields no entry; the message carries the name as given.
 */
export function getCLIPModelInfo(modelName: string): ModelInfo {
  const lookupKey = cleanModelName(modelName);
  const entry = CLIP_MODEL_INFO[lookupKey];
  if (entry) {
    return entry;
  }

  throw new Error(`Unknown CLIP model: ${modelName}`);
}

View file

@ -1,6 +1,7 @@
import { Inject, Injectable } from '@nestjs/common'; import { Inject, Injectable, Logger } from '@nestjs/common';
import { setTimeout } from 'timers/promises';
import { usePagination } from '../domain.util'; import { usePagination } from '../domain.util';
import { IBaseJob, IEntityJob, JOBS_ASSET_PAGINATION_SIZE, JobName } from '../job'; import { IBaseJob, IEntityJob, JOBS_ASSET_PAGINATION_SIZE, JobName, QueueName } from '../job';
import { import {
IAssetRepository, IAssetRepository,
IJobRepository, IJobRepository,
@ -14,6 +15,7 @@ import { SystemConfigCore } from '../system-config';
@Injectable() @Injectable()
export class SmartInfoService { export class SmartInfoService {
private configCore: SystemConfigCore; private configCore: SystemConfigCore;
private logger = new Logger(SmartInfoService.name);
constructor( constructor(
@Inject(IAssetRepository) private assetRepository: IAssetRepository, @Inject(IAssetRepository) private assetRepository: IAssetRepository,
@ -25,6 +27,22 @@ export class SmartInfoService {
this.configCore = SystemConfigCore.create(configRepository); this.configCore = SystemConfigCore.create(configRepository);
} }
/**
 * Runs the smart info repository initialization while the CLIP encoding queue is paused.
 *
 * The queue is paused first, then its status is polled once per second until it is no
 * longer `isActive`, and only then is `repository.init()` invoked. The queue is resumed
 * afterwards even when polling or initialization throws, so a failure cannot leave the
 * queue paused; the error still propagates to the caller.
 */
async init() {
  await this.jobRepository.pause(QueueName.CLIP_ENCODING);

  try {
    let { isActive } = await this.jobRepository.getQueueStatus(QueueName.CLIP_ENCODING);
    while (isActive) {
      this.logger.verbose('Waiting for CLIP encoding queue to stop...');
      // `setTimeout` is the promise-based variant from 'timers/promises'.
      await setTimeout(1000);
      ({ isActive } = await this.jobRepository.getQueueStatus(QueueName.CLIP_ENCODING));
    }

    await this.repository.init();
  } finally {
    // Always undo the pause taken above, including on the error path.
    await this.jobRepository.resume(QueueName.CLIP_ENCODING);
  }
}
async handleQueueObjectTagging({ force }: IBaseJob) { async handleQueueObjectTagging({ force }: IBaseJob) {
const { machineLearning } = await this.configCore.getConfig(); const { machineLearning } = await this.configCore.getConfig();
if (!machineLearning.enabled || !machineLearning.classification.enabled) { if (!machineLearning.enabled || !machineLearning.classification.enabled) {

View file

@ -1,6 +1,12 @@
import { Inject, Injectable } from '@nestjs/common'; import { Inject, Injectable } from '@nestjs/common';
import { JobName } from '../job'; import { JobName } from '../job';
import { CommunicationEvent, ICommunicationRepository, IJobRepository, ISystemConfigRepository } from '../repositories'; import {
CommunicationEvent,
ICommunicationRepository,
IJobRepository,
ISmartInfoRepository,
ISystemConfigRepository,
} from '../repositories';
import { SystemConfigDto, mapConfig } from './dto/system-config.dto'; import { SystemConfigDto, mapConfig } from './dto/system-config.dto';
import { SystemConfigTemplateStorageOptionDto } from './response-dto/system-config-template-storage-option.dto'; import { SystemConfigTemplateStorageOptionDto } from './response-dto/system-config-template-storage-option.dto';
import { import {
@ -22,6 +28,7 @@ export class SystemConfigService {
@Inject(ISystemConfigRepository) private repository: ISystemConfigRepository, @Inject(ISystemConfigRepository) private repository: ISystemConfigRepository,
@Inject(ICommunicationRepository) private communicationRepository: ICommunicationRepository, @Inject(ICommunicationRepository) private communicationRepository: ICommunicationRepository,
@Inject(IJobRepository) private jobRepository: IJobRepository, @Inject(IJobRepository) private jobRepository: IJobRepository,
@Inject(ISmartInfoRepository) private smartInfoRepository: ISmartInfoRepository,
) { ) {
this.core = SystemConfigCore.create(repository); this.core = SystemConfigCore.create(repository);
} }
@ -41,10 +48,14 @@ export class SystemConfigService {
} }
async updateConfig(dto: SystemConfigDto): Promise<SystemConfigDto> { async updateConfig(dto: SystemConfigDto): Promise<SystemConfigDto> {
const config = await this.core.updateConfig(dto); const oldConfig = await this.core.getConfig();
const newConfig = await this.core.updateConfig(dto);
await this.jobRepository.queue({ name: JobName.SYSTEM_CONFIG_CHANGE }); await this.jobRepository.queue({ name: JobName.SYSTEM_CONFIG_CHANGE });
this.communicationRepository.broadcast(CommunicationEvent.CONFIG_UPDATE, {}); this.communicationRepository.broadcast(CommunicationEvent.CONFIG_UPDATE, {});
return mapConfig(config); if (oldConfig.machineLearning.clip.modelName !== newConfig.machineLearning.clip.modelName) {
await this.smartInfoRepository.init();
}
return mapConfig(newConfig);
} }
async refreshConfig() { async refreshConfig() {

View file

@ -1,7 +1,14 @@
import { Embedding, EmbeddingSearch, ISmartInfoRepository } from '@app/domain'; import {
Embedding,
EmbeddingSearch,
ISmartInfoRepository,
ISystemConfigRepository,
SystemConfigCore,
} from '@app/domain';
import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant';
import { DatabaseLock, RequireLock, asyncLock } from '@app/infra'; import { DatabaseLock, RequireLock, asyncLock } from '@app/infra';
import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities'; import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities';
import { Injectable, Logger } from '@nestjs/common'; import { Inject, Injectable, Logger } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm'; import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm'; import { Repository } from 'typeorm';
import { asVector, isValidInteger } from '../infra.utils'; import { asVector, isValidInteger } from '../infra.utils';
@ -9,21 +16,40 @@ import { asVector, isValidInteger } from '../infra.utils';
@Injectable() @Injectable()
export class SmartInfoRepository implements ISmartInfoRepository { export class SmartInfoRepository implements ISmartInfoRepository {
private logger = new Logger(SmartInfoRepository.name); private logger = new Logger(SmartInfoRepository.name);
private configCore: SystemConfigCore;
private readonly faceColumns: string[]; private readonly faceColumns: string[];
private curDimSize: number | undefined;
constructor( constructor(
@InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>, @InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>,
@InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>, @InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
@InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>, @InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
@InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>, @InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>,
@Inject(ISystemConfigRepository) private configRepository: ISystemConfigRepository,
) { ) {
this.configCore = SystemConfigCore.create(configRepository);
this.faceColumns = this.assetFaceRepository.manager.connection this.faceColumns = this.assetFaceRepository.manager.connection
.getMetadata(AssetFaceEntity) .getMetadata(AssetFaceEntity)
.ownColumns.map((column) => column.propertyName) .ownColumns.map((column) => column.propertyName)
.filter((propertyName) => propertyName !== 'embedding'); .filter((propertyName) => propertyName !== 'embedding');
} }
/**
 * Ensures the database's CLIP dimension size matches the configured CLIP model.
 *
 * Reads the configured CLIP model name, looks up its dimension size, compares it with
 * the value returned by `getDimSize()`, and calls `updateDimSize()` when they differ.
 *
 * @throws Error when `getCLIPModelInfo` rejects the model name or its entry has no
 *   dimension size.
 */
async init(): Promise<void> {
  const { machineLearning } = await this.configCore.getConfig();
  const modelName = machineLearning.clip.modelName;

  const { dimSize } = getCLIPModelInfo(modelName);
  // Intentional `== null`: rejects both `null` and `undefined`.
  if (dimSize == null) {
    throw new Error(`Invalid CLIP model name: ${modelName}`);
  }

  const curDimSize = await this.getDimSize();
  this.logger.verbose(`Current database CLIP dimension size is ${curDimSize}`);

  // Strict comparison, consistent with the check performed inside `updateDimSize`.
  if (dimSize === curDimSize) {
    return;
  }

  this.logger.log(`Dimension size of model ${modelName} is ${dimSize}, but database expects ${curDimSize}.`);
  await this.updateDimSize(dimSize);
}
async searchCLIP({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> { async searchCLIP({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
if (!isValidInteger(numResults, { min: 1 })) { if (!isValidInteger(numResults, { min: 1 })) {
throw new Error(`Invalid value for 'numResults': ${numResults}`); throw new Error(`Invalid value for 'numResults': ${numResults}`);
@ -92,14 +118,6 @@ export class SmartInfoRepository implements ISmartInfoRepository {
await asyncLock.acquire(DatabaseLock[DatabaseLock.CLIPDimSize], () => {}); await asyncLock.acquire(DatabaseLock[DatabaseLock.CLIPDimSize], () => {});
} }
if (this.curDimSize == null) {
await this.getDimSize();
}
if (this.curDimSize !== embedding.length) {
await this.updateDimSize(embedding.length);
}
await this.smartSearchRepository.upsert( await this.smartSearchRepository.upsert(
{ assetId, embedding: () => asVector(embedding, true) }, { assetId, embedding: () => asVector(embedding, true) },
{ conflictPaths: ['assetId'] }, { conflictPaths: ['assetId'] },
@ -112,11 +130,12 @@ export class SmartInfoRepository implements ISmartInfoRepository {
throw new Error(`Invalid CLIP dimension size: ${dimSize}`); throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
} }
if (this.curDimSize === dimSize) { const curDimSize = await this.getDimSize();
if (curDimSize === dimSize) {
return; return;
} }
this.logger.log(`Current dimension size is ${this.curDimSize}. Updating CLIP dimension size to ${dimSize}.`); this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`);
await this.smartSearchRepository.manager.transaction(async (manager) => { await this.smartSearchRepository.manager.transaction(async (manager) => {
await manager.query(`DROP TABLE smart_search`); await manager.query(`DROP TABLE smart_search`);
@ -135,16 +154,10 @@ export class SmartInfoRepository implements ISmartInfoRepository {
$$)`); $$)`);
}); });
this.logger.log(`Successfully updated CLIP dimension size from ${this.curDimSize} to ${dimSize}.`); this.logger.log(`Successfully updated database CLIP dimension size from ${curDimSize} to ${dimSize}.`);
this.curDimSize = dimSize;
} }
@RequireLock(DatabaseLock.CLIPDimSize) private async getDimSize(): Promise<number> {
private async getDimSize(): Promise<void> {
if (this.curDimSize != null) {
return;
}
const res = await this.smartSearchRepository.manager.query(` const res = await this.smartSearchRepository.manager.query(`
SELECT atttypmod as dimsize SELECT atttypmod as dimsize
FROM pg_attribute f FROM pg_attribute f
@ -153,7 +166,11 @@ export class SmartInfoRepository implements ISmartInfoRepository {
AND f.attnum > 0 AND f.attnum > 0
AND c.relname = 'smart_search' AND c.relname = 'smart_search'
AND f.attname = 'embedding'`); AND f.attname = 'embedding'`);
this.curDimSize = res?.[0]?.['dimsize'] ?? 512;
this.logger.verbose(`CLIP dimension size is ${this.curDimSize}`); const dimSize = res[0]['dimsize'];
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
throw new Error(`Could not retrieve CLIP dimension size`);
}
return dimSize;
} }
} }