handle dim size update outside of jobs
This commit is contained in:
parent
1fe9075d11
commit
eea1fb83ae
5 changed files with 182 additions and 28 deletions
|
@ -12,6 +12,7 @@ export interface EmbeddingSearch {
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ISmartInfoRepository {
|
export interface ISmartInfoRepository {
|
||||||
|
init(): Promise<void>;
|
||||||
searchCLIP(search: EmbeddingSearch): Promise<AssetEntity[]>;
|
searchCLIP(search: EmbeddingSearch): Promise<AssetEntity[]>;
|
||||||
searchFaces(search: EmbeddingSearch): Promise<AssetFaceEntity[]>;
|
searchFaces(search: EmbeddingSearch): Promise<AssetFaceEntity[]>;
|
||||||
upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void>;
|
upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void>;
|
||||||
|
|
107
server/src/domain/smart-info/smart-info.constant.ts
Normal file
107
server/src/domain/smart-info/smart-info.constant.ts
Normal file
|
@ -0,0 +1,107 @@
|
||||||
|
export type ModelInfo = {
|
||||||
|
dimSize: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const CLIP_MODEL_INFO: Record<string, ModelInfo> = {
|
||||||
|
RN50__openai: {
|
||||||
|
dimSize: 1024,
|
||||||
|
},
|
||||||
|
RN50__yfcc15m: {
|
||||||
|
dimSize: 1024,
|
||||||
|
},
|
||||||
|
RN50__cc12m: {
|
||||||
|
dimSize: 1024,
|
||||||
|
},
|
||||||
|
RN101__openai: {
|
||||||
|
dimSize: 512,
|
||||||
|
},
|
||||||
|
RN101__yfcc15m: {
|
||||||
|
dimSize: 512,
|
||||||
|
},
|
||||||
|
RN50x4__openai: {
|
||||||
|
dimSize: 640,
|
||||||
|
},
|
||||||
|
RN50x16__openai: {
|
||||||
|
dimSize: 768,
|
||||||
|
},
|
||||||
|
RN50x64__openai: {
|
||||||
|
dimSize: 1024,
|
||||||
|
},
|
||||||
|
'ViT-B-32__openai': {
|
||||||
|
dimSize: 512,
|
||||||
|
},
|
||||||
|
'ViT-B-32__laion2b_e16': {
|
||||||
|
dimSize: 512,
|
||||||
|
},
|
||||||
|
'ViT-B-32__laion400m_e31': {
|
||||||
|
dimSize: 512,
|
||||||
|
},
|
||||||
|
'ViT-B-32__laion400m_e32': {
|
||||||
|
dimSize: 512,
|
||||||
|
},
|
||||||
|
'ViT-B-32__laion2b-s34b-b79k': {
|
||||||
|
dimSize: 512,
|
||||||
|
},
|
||||||
|
'ViT-B-16__openai': {
|
||||||
|
dimSize: 512,
|
||||||
|
},
|
||||||
|
'ViT-B-16__laion400m_e31': {
|
||||||
|
dimSize: 512,
|
||||||
|
},
|
||||||
|
'ViT-B-16__laion400m_e32': {
|
||||||
|
dimSize: 512,
|
||||||
|
},
|
||||||
|
'ViT-B-16-plus-240__laion400m_e31': {
|
||||||
|
dimSize: 640,
|
||||||
|
},
|
||||||
|
'ViT-B-16-plus-240__laion400m_e32': {
|
||||||
|
dimSize: 640,
|
||||||
|
},
|
||||||
|
'ViT-L-14__openai': {
|
||||||
|
dimSize: 768,
|
||||||
|
},
|
||||||
|
'ViT-L-14__laion400m_e31': {
|
||||||
|
dimSize: 768,
|
||||||
|
},
|
||||||
|
'ViT-L-14__laion400m_e32': {
|
||||||
|
dimSize: 768,
|
||||||
|
},
|
||||||
|
'ViT-L-14__laion2b-s32b-b82k': {
|
||||||
|
dimSize: 768,
|
||||||
|
},
|
||||||
|
'ViT-L-14-336__openai': {
|
||||||
|
dimSize: 768,
|
||||||
|
},
|
||||||
|
'ViT-H-14__laion2b-s32b-b79k': {
|
||||||
|
dimSize: 1024,
|
||||||
|
},
|
||||||
|
'ViT-g-14__laion2b-s12b-b42k': {
|
||||||
|
dimSize: 1024,
|
||||||
|
},
|
||||||
|
'LABSE-Vit-L-14': {
|
||||||
|
dimSize: 768,
|
||||||
|
},
|
||||||
|
'XLM-Roberta-Large-Vit-B-32': {
|
||||||
|
dimSize: 512,
|
||||||
|
},
|
||||||
|
'XLM-Roberta-Large-Vit-B-16Plus': {
|
||||||
|
dimSize: 640,
|
||||||
|
},
|
||||||
|
'XLM-Roberta-Large-Vit-L-14': {
|
||||||
|
dimSize: 768,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
function cleanModelName(modelName: string): string {
|
||||||
|
const tokens = modelName.split('/');
|
||||||
|
return tokens[tokens.length - 1].replace(':', '_');
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getCLIPModelInfo(modelName: string): ModelInfo {
|
||||||
|
const modelInfo = CLIP_MODEL_INFO[cleanModelName(modelName)];
|
||||||
|
if (!modelInfo) {
|
||||||
|
throw new Error(`Unknown CLIP model: ${modelName}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelInfo;
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
import { Inject, Injectable } from '@nestjs/common';
|
import { Inject, Injectable, Logger } from '@nestjs/common';
|
||||||
|
import { setTimeout } from 'timers/promises';
|
||||||
import { usePagination } from '../domain.util';
|
import { usePagination } from '../domain.util';
|
||||||
import { IBaseJob, IEntityJob, JOBS_ASSET_PAGINATION_SIZE, JobName } from '../job';
|
import { IBaseJob, IEntityJob, JOBS_ASSET_PAGINATION_SIZE, JobName, QueueName } from '../job';
|
||||||
import {
|
import {
|
||||||
IAssetRepository,
|
IAssetRepository,
|
||||||
IJobRepository,
|
IJobRepository,
|
||||||
|
@ -14,6 +15,7 @@ import { SystemConfigCore } from '../system-config';
|
||||||
@Injectable()
|
@Injectable()
|
||||||
export class SmartInfoService {
|
export class SmartInfoService {
|
||||||
private configCore: SystemConfigCore;
|
private configCore: SystemConfigCore;
|
||||||
|
private logger = new Logger(SmartInfoService.name);
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
@Inject(IAssetRepository) private assetRepository: IAssetRepository,
|
@Inject(IAssetRepository) private assetRepository: IAssetRepository,
|
||||||
|
@ -25,6 +27,22 @@ export class SmartInfoService {
|
||||||
this.configCore = SystemConfigCore.create(configRepository);
|
this.configCore = SystemConfigCore.create(configRepository);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async init() {
|
||||||
|
await this.jobRepository.pause(QueueName.CLIP_ENCODING);
|
||||||
|
|
||||||
|
let { isActive } = await this.jobRepository.getQueueStatus(QueueName.CLIP_ENCODING);
|
||||||
|
while (isActive) {
|
||||||
|
this.logger.verbose('Waiting for CLIP encoding queue to stop...');
|
||||||
|
await setTimeout(1000).then(async () => {
|
||||||
|
({ isActive } = await this.jobRepository.getQueueStatus(QueueName.CLIP_ENCODING));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
await this.repository.init();
|
||||||
|
|
||||||
|
await this.jobRepository.resume(QueueName.CLIP_ENCODING);
|
||||||
|
}
|
||||||
|
|
||||||
async handleQueueObjectTagging({ force }: IBaseJob) {
|
async handleQueueObjectTagging({ force }: IBaseJob) {
|
||||||
const { machineLearning } = await this.configCore.getConfig();
|
const { machineLearning } = await this.configCore.getConfig();
|
||||||
if (!machineLearning.enabled || !machineLearning.classification.enabled) {
|
if (!machineLearning.enabled || !machineLearning.classification.enabled) {
|
||||||
|
|
|
@ -1,6 +1,12 @@
|
||||||
import { Inject, Injectable } from '@nestjs/common';
|
import { Inject, Injectable } from '@nestjs/common';
|
||||||
import { JobName } from '../job';
|
import { JobName } from '../job';
|
||||||
import { CommunicationEvent, ICommunicationRepository, IJobRepository, ISystemConfigRepository } from '../repositories';
|
import {
|
||||||
|
CommunicationEvent,
|
||||||
|
ICommunicationRepository,
|
||||||
|
IJobRepository,
|
||||||
|
ISmartInfoRepository,
|
||||||
|
ISystemConfigRepository,
|
||||||
|
} from '../repositories';
|
||||||
import { SystemConfigDto, mapConfig } from './dto/system-config.dto';
|
import { SystemConfigDto, mapConfig } from './dto/system-config.dto';
|
||||||
import { SystemConfigTemplateStorageOptionDto } from './response-dto/system-config-template-storage-option.dto';
|
import { SystemConfigTemplateStorageOptionDto } from './response-dto/system-config-template-storage-option.dto';
|
||||||
import {
|
import {
|
||||||
|
@ -22,6 +28,7 @@ export class SystemConfigService {
|
||||||
@Inject(ISystemConfigRepository) private repository: ISystemConfigRepository,
|
@Inject(ISystemConfigRepository) private repository: ISystemConfigRepository,
|
||||||
@Inject(ICommunicationRepository) private communicationRepository: ICommunicationRepository,
|
@Inject(ICommunicationRepository) private communicationRepository: ICommunicationRepository,
|
||||||
@Inject(IJobRepository) private jobRepository: IJobRepository,
|
@Inject(IJobRepository) private jobRepository: IJobRepository,
|
||||||
|
@Inject(ISmartInfoRepository) private smartInfoRepository: ISmartInfoRepository,
|
||||||
) {
|
) {
|
||||||
this.core = SystemConfigCore.create(repository);
|
this.core = SystemConfigCore.create(repository);
|
||||||
}
|
}
|
||||||
|
@ -41,10 +48,14 @@ export class SystemConfigService {
|
||||||
}
|
}
|
||||||
|
|
||||||
async updateConfig(dto: SystemConfigDto): Promise<SystemConfigDto> {
|
async updateConfig(dto: SystemConfigDto): Promise<SystemConfigDto> {
|
||||||
const config = await this.core.updateConfig(dto);
|
const oldConfig = await this.core.getConfig();
|
||||||
|
const newConfig = await this.core.updateConfig(dto);
|
||||||
await this.jobRepository.queue({ name: JobName.SYSTEM_CONFIG_CHANGE });
|
await this.jobRepository.queue({ name: JobName.SYSTEM_CONFIG_CHANGE });
|
||||||
this.communicationRepository.broadcast(CommunicationEvent.CONFIG_UPDATE, {});
|
this.communicationRepository.broadcast(CommunicationEvent.CONFIG_UPDATE, {});
|
||||||
return mapConfig(config);
|
if (oldConfig.machineLearning.clip.modelName !== newConfig.machineLearning.clip.modelName) {
|
||||||
|
await this.smartInfoRepository.init();
|
||||||
|
}
|
||||||
|
return mapConfig(newConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
async refreshConfig() {
|
async refreshConfig() {
|
||||||
|
|
|
@ -1,7 +1,14 @@
|
||||||
import { Embedding, EmbeddingSearch, ISmartInfoRepository } from '@app/domain';
|
import {
|
||||||
|
Embedding,
|
||||||
|
EmbeddingSearch,
|
||||||
|
ISmartInfoRepository,
|
||||||
|
ISystemConfigRepository,
|
||||||
|
SystemConfigCore,
|
||||||
|
} from '@app/domain';
|
||||||
|
import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant';
|
||||||
import { DatabaseLock, RequireLock, asyncLock } from '@app/infra';
|
import { DatabaseLock, RequireLock, asyncLock } from '@app/infra';
|
||||||
import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities';
|
import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities';
|
||||||
import { Injectable, Logger } from '@nestjs/common';
|
import { Inject, Injectable, Logger } from '@nestjs/common';
|
||||||
import { InjectRepository } from '@nestjs/typeorm';
|
import { InjectRepository } from '@nestjs/typeorm';
|
||||||
import { Repository } from 'typeorm';
|
import { Repository } from 'typeorm';
|
||||||
import { asVector, isValidInteger } from '../infra.utils';
|
import { asVector, isValidInteger } from '../infra.utils';
|
||||||
|
@ -9,21 +16,40 @@ import { asVector, isValidInteger } from '../infra.utils';
|
||||||
@Injectable()
|
@Injectable()
|
||||||
export class SmartInfoRepository implements ISmartInfoRepository {
|
export class SmartInfoRepository implements ISmartInfoRepository {
|
||||||
private logger = new Logger(SmartInfoRepository.name);
|
private logger = new Logger(SmartInfoRepository.name);
|
||||||
|
private configCore: SystemConfigCore;
|
||||||
private readonly faceColumns: string[];
|
private readonly faceColumns: string[];
|
||||||
private curDimSize: number | undefined;
|
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
@InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>,
|
@InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>,
|
||||||
@InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
|
@InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
|
||||||
@InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
|
@InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
|
||||||
@InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>,
|
@InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>,
|
||||||
|
@Inject(ISystemConfigRepository) private configRepository: ISystemConfigRepository,
|
||||||
) {
|
) {
|
||||||
|
this.configCore = SystemConfigCore.create(configRepository);
|
||||||
this.faceColumns = this.assetFaceRepository.manager.connection
|
this.faceColumns = this.assetFaceRepository.manager.connection
|
||||||
.getMetadata(AssetFaceEntity)
|
.getMetadata(AssetFaceEntity)
|
||||||
.ownColumns.map((column) => column.propertyName)
|
.ownColumns.map((column) => column.propertyName)
|
||||||
.filter((propertyName) => propertyName !== 'embedding');
|
.filter((propertyName) => propertyName !== 'embedding');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async init(): Promise<void> {
|
||||||
|
const { machineLearning } = await this.configCore.getConfig();
|
||||||
|
const modelName = machineLearning.clip.modelName;
|
||||||
|
const { dimSize } = getCLIPModelInfo(modelName);
|
||||||
|
if (dimSize == null) {
|
||||||
|
throw new Error(`Invalid CLIP model name: ${modelName}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const curDimSize = await this.getDimSize();
|
||||||
|
this.logger.verbose(`Current database CLIP dimension size is ${curDimSize}`);
|
||||||
|
|
||||||
|
if (dimSize != curDimSize) {
|
||||||
|
this.logger.log(`Dimension size of model ${modelName} is ${dimSize}, but database expects ${curDimSize}.`);
|
||||||
|
await this.updateDimSize(dimSize);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async searchCLIP({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
|
async searchCLIP({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
|
||||||
if (!isValidInteger(numResults, { min: 1 })) {
|
if (!isValidInteger(numResults, { min: 1 })) {
|
||||||
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
||||||
|
@ -92,14 +118,6 @@ export class SmartInfoRepository implements ISmartInfoRepository {
|
||||||
await asyncLock.acquire(DatabaseLock[DatabaseLock.CLIPDimSize], () => {});
|
await asyncLock.acquire(DatabaseLock[DatabaseLock.CLIPDimSize], () => {});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this.curDimSize == null) {
|
|
||||||
await this.getDimSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.curDimSize !== embedding.length) {
|
|
||||||
await this.updateDimSize(embedding.length);
|
|
||||||
}
|
|
||||||
|
|
||||||
await this.smartSearchRepository.upsert(
|
await this.smartSearchRepository.upsert(
|
||||||
{ assetId, embedding: () => asVector(embedding, true) },
|
{ assetId, embedding: () => asVector(embedding, true) },
|
||||||
{ conflictPaths: ['assetId'] },
|
{ conflictPaths: ['assetId'] },
|
||||||
|
@ -112,11 +130,12 @@ export class SmartInfoRepository implements ISmartInfoRepository {
|
||||||
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
|
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this.curDimSize === dimSize) {
|
const curDimSize = await this.getDimSize();
|
||||||
|
if (curDimSize === dimSize) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
this.logger.log(`Current dimension size is ${this.curDimSize}. Updating CLIP dimension size to ${dimSize}.`);
|
this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`);
|
||||||
|
|
||||||
await this.smartSearchRepository.manager.transaction(async (manager) => {
|
await this.smartSearchRepository.manager.transaction(async (manager) => {
|
||||||
await manager.query(`DROP TABLE smart_search`);
|
await manager.query(`DROP TABLE smart_search`);
|
||||||
|
@ -135,16 +154,10 @@ export class SmartInfoRepository implements ISmartInfoRepository {
|
||||||
$$)`);
|
$$)`);
|
||||||
});
|
});
|
||||||
|
|
||||||
this.logger.log(`Successfully updated CLIP dimension size from ${this.curDimSize} to ${dimSize}.`);
|
this.logger.log(`Successfully updated database CLIP dimension size from ${curDimSize} to ${dimSize}.`);
|
||||||
this.curDimSize = dimSize;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@RequireLock(DatabaseLock.CLIPDimSize)
|
private async getDimSize(): Promise<number> {
|
||||||
private async getDimSize(): Promise<void> {
|
|
||||||
if (this.curDimSize != null) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const res = await this.smartSearchRepository.manager.query(`
|
const res = await this.smartSearchRepository.manager.query(`
|
||||||
SELECT atttypmod as dimsize
|
SELECT atttypmod as dimsize
|
||||||
FROM pg_attribute f
|
FROM pg_attribute f
|
||||||
|
@ -153,7 +166,11 @@ export class SmartInfoRepository implements ISmartInfoRepository {
|
||||||
AND f.attnum > 0
|
AND f.attnum > 0
|
||||||
AND c.relname = 'smart_search'
|
AND c.relname = 'smart_search'
|
||||||
AND f.attname = 'embedding'`);
|
AND f.attname = 'embedding'`);
|
||||||
this.curDimSize = res?.[0]?.['dimsize'] ?? 512;
|
|
||||||
this.logger.verbose(`CLIP dimension size is ${this.curDimSize}`);
|
const dimSize = res[0]['dimsize'];
|
||||||
|
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
|
||||||
|
throw new Error(`Could not retrieve CLIP dimension size`);
|
||||||
|
}
|
||||||
|
return dimSize;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue