Explorar o código

handle dim size update outside of jobs

mertalev hai 1 ano
pai
achega
eea1fb83ae

+ 1 - 0
server/src/domain/repositories/smart-info.repository.ts

@@ -12,6 +12,7 @@ export interface EmbeddingSearch {
 }
 }
 
 
 export interface ISmartInfoRepository {
 export interface ISmartInfoRepository {
+  init(): Promise<void>;
   searchCLIP(search: EmbeddingSearch): Promise<AssetEntity[]>;
   searchCLIP(search: EmbeddingSearch): Promise<AssetEntity[]>;
   searchFaces(search: EmbeddingSearch): Promise<AssetFaceEntity[]>;
   searchFaces(search: EmbeddingSearch): Promise<AssetFaceEntity[]>;
   upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void>;
   upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void>;

+ 107 - 0
server/src/domain/smart-info/smart-info.constant.ts

@@ -0,0 +1,107 @@
+export type ModelInfo = {
+  dimSize: number;
+};
+
+export const CLIP_MODEL_INFO: Record<string, ModelInfo> = {
+  RN50__openai: {
+    dimSize: 1024,
+  },
+  RN50__yfcc15m: {
+    dimSize: 1024,
+  },
+  RN50__cc12m: {
+    dimSize: 1024,
+  },
+  RN101__openai: {
+    dimSize: 512,
+  },
+  RN101__yfcc15m: {
+    dimSize: 512,
+  },
+  RN50x4__openai: {
+    dimSize: 640,
+  },
+  RN50x16__openai: {
+    dimSize: 768,
+  },
+  RN50x64__openai: {
+    dimSize: 1024,
+  },
+  'ViT-B-32__openai': {
+    dimSize: 512,
+  },
+  'ViT-B-32__laion2b_e16': {
+    dimSize: 512,
+  },
+  'ViT-B-32__laion400m_e31': {
+    dimSize: 512,
+  },
+  'ViT-B-32__laion400m_e32': {
+    dimSize: 512,
+  },
+  'ViT-B-32__laion2b-s34b-b79k': {
+    dimSize: 512,
+  },
+  'ViT-B-16__openai': {
+    dimSize: 512,
+  },
+  'ViT-B-16__laion400m_e31': {
+    dimSize: 512,
+  },
+  'ViT-B-16__laion400m_e32': {
+    dimSize: 512,
+  },
+  'ViT-B-16-plus-240__laion400m_e31': {
+    dimSize: 640,
+  },
+  'ViT-B-16-plus-240__laion400m_e32': {
+    dimSize: 640,
+  },
+  'ViT-L-14__openai': {
+    dimSize: 768,
+  },
+  'ViT-L-14__laion400m_e31': {
+    dimSize: 768,
+  },
+  'ViT-L-14__laion400m_e32': {
+    dimSize: 768,
+  },
+  'ViT-L-14__laion2b-s32b-b82k': {
+    dimSize: 768,
+  },
+  'ViT-L-14-336__openai': {
+    dimSize: 768,
+  },
+  'ViT-H-14__laion2b-s32b-b79k': {
+    dimSize: 1024,
+  },
+  'ViT-g-14__laion2b-s12b-b42k': {
+    dimSize: 1024,
+  },
+  'LABSE-Vit-L-14': {
+    dimSize: 768,
+  },
+  'XLM-Roberta-Large-Vit-B-32': {
+    dimSize: 512,
+  },
+  'XLM-Roberta-Large-Vit-B-16Plus': {
+    dimSize: 640,
+  },
+  'XLM-Roberta-Large-Vit-L-14': {
+    dimSize: 768,
+  },
+};
+
+function cleanModelName(modelName: string): string {
+  const tokens = modelName.split('/');
+  return tokens[tokens.length - 1].replace(':', '_');
+}
+
+export function getCLIPModelInfo(modelName: string): ModelInfo {
+  const modelInfo = CLIP_MODEL_INFO[cleanModelName(modelName)];
+  if (!modelInfo) {
+    throw new Error(`Unknown CLIP model: ${modelName}`);
+  }
+
+  return modelInfo;
+}

+ 20 - 2
server/src/domain/smart-info/smart-info.service.ts

@@ -1,6 +1,7 @@
-import { Inject, Injectable } from '@nestjs/common';
+import { Inject, Injectable, Logger } from '@nestjs/common';
+import { setTimeout } from 'timers/promises';
 import { usePagination } from '../domain.util';
 import { usePagination } from '../domain.util';
-import { IBaseJob, IEntityJob, JOBS_ASSET_PAGINATION_SIZE, JobName } from '../job';
+import { IBaseJob, IEntityJob, JOBS_ASSET_PAGINATION_SIZE, JobName, QueueName } from '../job';
 import {
 import {
   IAssetRepository,
   IAssetRepository,
   IJobRepository,
   IJobRepository,
@@ -14,6 +15,7 @@ import { SystemConfigCore } from '../system-config';
 @Injectable()
 @Injectable()
 export class SmartInfoService {
 export class SmartInfoService {
   private configCore: SystemConfigCore;
   private configCore: SystemConfigCore;
+  private logger = new Logger(SmartInfoService.name);
 
 
   constructor(
   constructor(
     @Inject(IAssetRepository) private assetRepository: IAssetRepository,
     @Inject(IAssetRepository) private assetRepository: IAssetRepository,
@@ -25,6 +27,22 @@ export class SmartInfoService {
     this.configCore = SystemConfigCore.create(configRepository);
     this.configCore = SystemConfigCore.create(configRepository);
   }
   }
 
 
+  async init() {
+    await this.jobRepository.pause(QueueName.CLIP_ENCODING);
+
+    let { isActive } = await this.jobRepository.getQueueStatus(QueueName.CLIP_ENCODING);
+    while (isActive) {
+      this.logger.verbose('Waiting for CLIP encoding queue to stop...');
+      await setTimeout(1000).then(async () => {
+        ({ isActive } = await this.jobRepository.getQueueStatus(QueueName.CLIP_ENCODING));
+      });
+    }
+
+    await this.repository.init();
+
+    await this.jobRepository.resume(QueueName.CLIP_ENCODING);
+  }
+
   async handleQueueObjectTagging({ force }: IBaseJob) {
   async handleQueueObjectTagging({ force }: IBaseJob) {
     const { machineLearning } = await this.configCore.getConfig();
     const { machineLearning } = await this.configCore.getConfig();
     if (!machineLearning.enabled || !machineLearning.classification.enabled) {
     if (!machineLearning.enabled || !machineLearning.classification.enabled) {

+ 14 - 3
server/src/domain/system-config/system-config.service.ts

@@ -1,6 +1,12 @@
 import { Inject, Injectable } from '@nestjs/common';
 import { Inject, Injectable } from '@nestjs/common';
 import { JobName } from '../job';
 import { JobName } from '../job';
-import { CommunicationEvent, ICommunicationRepository, IJobRepository, ISystemConfigRepository } from '../repositories';
+import {
+  CommunicationEvent,
+  ICommunicationRepository,
+  IJobRepository,
+  ISmartInfoRepository,
+  ISystemConfigRepository,
+} from '../repositories';
 import { SystemConfigDto, mapConfig } from './dto/system-config.dto';
 import { SystemConfigDto, mapConfig } from './dto/system-config.dto';
 import { SystemConfigTemplateStorageOptionDto } from './response-dto/system-config-template-storage-option.dto';
 import { SystemConfigTemplateStorageOptionDto } from './response-dto/system-config-template-storage-option.dto';
 import {
 import {
@@ -22,6 +28,7 @@ export class SystemConfigService {
     @Inject(ISystemConfigRepository) private repository: ISystemConfigRepository,
     @Inject(ISystemConfigRepository) private repository: ISystemConfigRepository,
     @Inject(ICommunicationRepository) private communicationRepository: ICommunicationRepository,
     @Inject(ICommunicationRepository) private communicationRepository: ICommunicationRepository,
     @Inject(IJobRepository) private jobRepository: IJobRepository,
     @Inject(IJobRepository) private jobRepository: IJobRepository,
+    @Inject(ISmartInfoRepository) private smartInfoRepository: ISmartInfoRepository,
   ) {
   ) {
     this.core = SystemConfigCore.create(repository);
     this.core = SystemConfigCore.create(repository);
   }
   }
@@ -41,10 +48,14 @@ export class SystemConfigService {
   }
   }
 
 
   async updateConfig(dto: SystemConfigDto): Promise<SystemConfigDto> {
   async updateConfig(dto: SystemConfigDto): Promise<SystemConfigDto> {
-    const config = await this.core.updateConfig(dto);
+    const oldConfig = await this.core.getConfig();
+    const newConfig = await this.core.updateConfig(dto);
     await this.jobRepository.queue({ name: JobName.SYSTEM_CONFIG_CHANGE });
     await this.jobRepository.queue({ name: JobName.SYSTEM_CONFIG_CHANGE });
     this.communicationRepository.broadcast(CommunicationEvent.CONFIG_UPDATE, {});
     this.communicationRepository.broadcast(CommunicationEvent.CONFIG_UPDATE, {});
-    return mapConfig(config);
+    if (oldConfig.machineLearning.clip.modelName !== newConfig.machineLearning.clip.modelName) {
+      await this.smartInfoRepository.init();
+    }
+    return mapConfig(newConfig);
   }
   }
 
 
   async refreshConfig() {
   async refreshConfig() {

+ 40 - 23
server/src/infra/repositories/smart-info.repository.ts

@@ -1,7 +1,14 @@
-import { Embedding, EmbeddingSearch, ISmartInfoRepository } from '@app/domain';
+import {
+  Embedding,
+  EmbeddingSearch,
+  ISmartInfoRepository,
+  ISystemConfigRepository,
+  SystemConfigCore,
+} from '@app/domain';
+import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant';
 import { DatabaseLock, RequireLock, asyncLock } from '@app/infra';
 import { DatabaseLock, RequireLock, asyncLock } from '@app/infra';
 import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities';
 import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities';
-import { Injectable, Logger } from '@nestjs/common';
+import { Inject, Injectable, Logger } from '@nestjs/common';
 import { InjectRepository } from '@nestjs/typeorm';
 import { InjectRepository } from '@nestjs/typeorm';
 import { Repository } from 'typeorm';
 import { Repository } from 'typeorm';
 import { asVector, isValidInteger } from '../infra.utils';
 import { asVector, isValidInteger } from '../infra.utils';
@@ -9,21 +16,40 @@ import { asVector, isValidInteger } from '../infra.utils';
 @Injectable()
 @Injectable()
 export class SmartInfoRepository implements ISmartInfoRepository {
 export class SmartInfoRepository implements ISmartInfoRepository {
   private logger = new Logger(SmartInfoRepository.name);
   private logger = new Logger(SmartInfoRepository.name);
+  private configCore: SystemConfigCore;
   private readonly faceColumns: string[];
   private readonly faceColumns: string[];
-  private curDimSize: number | undefined;
 
 
   constructor(
   constructor(
     @InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>,
     @InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>,
     @InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
     @InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
     @InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
     @InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
     @InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>,
     @InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>,
+    @Inject(ISystemConfigRepository) private configRepository: ISystemConfigRepository,
   ) {
   ) {
+    this.configCore = SystemConfigCore.create(configRepository);
     this.faceColumns = this.assetFaceRepository.manager.connection
     this.faceColumns = this.assetFaceRepository.manager.connection
       .getMetadata(AssetFaceEntity)
       .getMetadata(AssetFaceEntity)
       .ownColumns.map((column) => column.propertyName)
       .ownColumns.map((column) => column.propertyName)
       .filter((propertyName) => propertyName !== 'embedding');
       .filter((propertyName) => propertyName !== 'embedding');
   }
   }
 
 
+  async init(): Promise<void> {
+    const { machineLearning } = await this.configCore.getConfig();
+    const modelName = machineLearning.clip.modelName;
+    const { dimSize } = getCLIPModelInfo(modelName);
+    if (dimSize == null) {
+      throw new Error(`Invalid CLIP model name: ${modelName}`);
+    }
+
+    const curDimSize = await this.getDimSize();
+    this.logger.verbose(`Current database CLIP dimension size is ${curDimSize}`);
+
+    if (dimSize != curDimSize) {
+      this.logger.log(`Dimension size of model ${modelName} is ${dimSize}, but database expects ${curDimSize}.`);
+      await this.updateDimSize(dimSize);
+    }
+  }
+
   async searchCLIP({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
   async searchCLIP({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
     if (!isValidInteger(numResults, { min: 1 })) {
     if (!isValidInteger(numResults, { min: 1 })) {
       throw new Error(`Invalid value for 'numResults': ${numResults}`);
       throw new Error(`Invalid value for 'numResults': ${numResults}`);
@@ -92,14 +118,6 @@ export class SmartInfoRepository implements ISmartInfoRepository {
       await asyncLock.acquire(DatabaseLock[DatabaseLock.CLIPDimSize], () => {});
       await asyncLock.acquire(DatabaseLock[DatabaseLock.CLIPDimSize], () => {});
     }
     }
 
 
-    if (this.curDimSize == null) {
-      await this.getDimSize();
-    }
-
-    if (this.curDimSize !== embedding.length) {
-      await this.updateDimSize(embedding.length);
-    }
-
     await this.smartSearchRepository.upsert(
     await this.smartSearchRepository.upsert(
       { assetId, embedding: () => asVector(embedding, true) },
       { assetId, embedding: () => asVector(embedding, true) },
       { conflictPaths: ['assetId'] },
       { conflictPaths: ['assetId'] },
@@ -112,11 +130,12 @@ export class SmartInfoRepository implements ISmartInfoRepository {
       throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
       throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
     }
     }
 
 
-    if (this.curDimSize === dimSize) {
+    const curDimSize = await this.getDimSize();
+    if (curDimSize === dimSize) {
       return;
       return;
     }
     }
 
 
-    this.logger.log(`Current dimension size is ${this.curDimSize}. Updating CLIP dimension size to ${dimSize}.`);
+    this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`);
 
 
     await this.smartSearchRepository.manager.transaction(async (manager) => {
     await this.smartSearchRepository.manager.transaction(async (manager) => {
       await manager.query(`DROP TABLE smart_search`);
       await manager.query(`DROP TABLE smart_search`);
@@ -135,16 +154,10 @@ export class SmartInfoRepository implements ISmartInfoRepository {
           $$)`);
           $$)`);
     });
     });
 
 
-    this.logger.log(`Successfully updated CLIP dimension size from ${this.curDimSize} to ${dimSize}.`);
-    this.curDimSize = dimSize;
+    this.logger.log(`Successfully updated database CLIP dimension size from ${curDimSize} to ${dimSize}.`);
   }
   }
 
 
-  @RequireLock(DatabaseLock.CLIPDimSize)
-  private async getDimSize(): Promise<void> {
-    if (this.curDimSize != null) {
-      return;
-    }
-
+  private async getDimSize(): Promise<number> {
     const res = await this.smartSearchRepository.manager.query(`
     const res = await this.smartSearchRepository.manager.query(`
       SELECT atttypmod as dimsize
       SELECT atttypmod as dimsize
       FROM pg_attribute f
       FROM pg_attribute f
@@ -153,7 +166,11 @@ export class SmartInfoRepository implements ISmartInfoRepository {
         AND f.attnum > 0
         AND f.attnum > 0
         AND c.relname = 'smart_search'
         AND c.relname = 'smart_search'
         AND f.attname = 'embedding'`);
         AND f.attname = 'embedding'`);
-    this.curDimSize = res?.[0]?.['dimsize'] ?? 512;
-    this.logger.verbose(`CLIP dimension size is ${this.curDimSize}`);
+
+    const dimSize = res[0]['dimsize'];
+    if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
+      throw new Error(`Could not retrieve CLIP dimension size`);
+    }
+    return dimSize;
   }
   }
 }
 }