|
@@ -1,7 +1,14 @@
|
|
|
-import { Embedding, EmbeddingSearch, ISmartInfoRepository } from '@app/domain';
|
|
|
+import {
|
|
|
+ Embedding,
|
|
|
+ EmbeddingSearch,
|
|
|
+ ISmartInfoRepository,
|
|
|
+ ISystemConfigRepository,
|
|
|
+ SystemConfigCore,
|
|
|
+} from '@app/domain';
|
|
|
+import { getCLIPModelInfo } from '@app/domain/smart-info/smart-info.constant';
|
|
|
import { DatabaseLock, RequireLock, asyncLock } from '@app/infra';
|
|
|
import { AssetEntity, AssetFaceEntity, SmartInfoEntity, SmartSearchEntity } from '@app/infra/entities';
|
|
|
-import { Injectable, Logger } from '@nestjs/common';
|
|
|
+import { Inject, Injectable, Logger } from '@nestjs/common';
|
|
|
import { InjectRepository } from '@nestjs/typeorm';
|
|
|
import { Repository } from 'typeorm';
|
|
|
import { asVector, isValidInteger } from '../infra.utils';
|
|
@@ -9,21 +16,40 @@ import { asVector, isValidInteger } from '../infra.utils';
|
|
|
@Injectable()
|
|
|
export class SmartInfoRepository implements ISmartInfoRepository {
|
|
|
private logger = new Logger(SmartInfoRepository.name);
|
|
|
+ private configCore: SystemConfigCore;
|
|
|
private readonly faceColumns: string[];
|
|
|
- private curDimSize: number | undefined;
|
|
|
|
|
|
constructor(
|
|
|
@InjectRepository(SmartInfoEntity) private repository: Repository<SmartInfoEntity>,
|
|
|
@InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
|
|
|
@InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
|
|
|
@InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>,
|
|
|
+ @Inject(ISystemConfigRepository) private configRepository: ISystemConfigRepository,
|
|
|
) {
|
|
|
+ this.configCore = SystemConfigCore.create(configRepository);
|
|
|
this.faceColumns = this.assetFaceRepository.manager.connection
|
|
|
.getMetadata(AssetFaceEntity)
|
|
|
.ownColumns.map((column) => column.propertyName)
|
|
|
.filter((propertyName) => propertyName !== 'embedding');
|
|
|
}
|
|
|
|
|
|
+ async init(): Promise<void> {
|
|
|
+ const { machineLearning } = await this.configCore.getConfig();
|
|
|
+ const modelName = machineLearning.clip.modelName;
|
|
|
+ const { dimSize } = getCLIPModelInfo(modelName);
|
|
|
+ if (dimSize == null) {
|
|
|
+ throw new Error(`Invalid CLIP model name: ${modelName}`);
|
|
|
+ }
|
|
|
+
|
|
|
+ const curDimSize = await this.getDimSize();
|
|
|
+ this.logger.verbose(`Current database CLIP dimension size is ${curDimSize}`);
|
|
|
+
|
|
|
+ if (dimSize != curDimSize) {
|
|
|
+ this.logger.log(`Dimension size of model ${modelName} is ${dimSize}, but database expects ${curDimSize}.`);
|
|
|
+ await this.updateDimSize(dimSize);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
async searchCLIP({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
|
|
|
if (!isValidInteger(numResults, { min: 1 })) {
|
|
|
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
|
@@ -92,14 +118,6 @@ export class SmartInfoRepository implements ISmartInfoRepository {
|
|
|
await asyncLock.acquire(DatabaseLock[DatabaseLock.CLIPDimSize], () => {});
|
|
|
}
|
|
|
|
|
|
- if (this.curDimSize == null) {
|
|
|
- await this.getDimSize();
|
|
|
- }
|
|
|
-
|
|
|
- if (this.curDimSize !== embedding.length) {
|
|
|
- await this.updateDimSize(embedding.length);
|
|
|
- }
|
|
|
-
|
|
|
await this.smartSearchRepository.upsert(
|
|
|
{ assetId, embedding: () => asVector(embedding, true) },
|
|
|
{ conflictPaths: ['assetId'] },
|
|
@@ -112,11 +130,12 @@ export class SmartInfoRepository implements ISmartInfoRepository {
|
|
|
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
|
|
|
}
|
|
|
|
|
|
- if (this.curDimSize === dimSize) {
|
|
|
+ const curDimSize = await this.getDimSize();
|
|
|
+ if (curDimSize === dimSize) {
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
- this.logger.log(`Current dimension size is ${this.curDimSize}. Updating CLIP dimension size to ${dimSize}.`);
|
|
|
+ this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`);
|
|
|
|
|
|
await this.smartSearchRepository.manager.transaction(async (manager) => {
|
|
|
await manager.query(`DROP TABLE smart_search`);
|
|
@@ -135,16 +154,10 @@ export class SmartInfoRepository implements ISmartInfoRepository {
|
|
|
$$)`);
|
|
|
});
|
|
|
|
|
|
- this.logger.log(`Successfully updated CLIP dimension size from ${this.curDimSize} to ${dimSize}.`);
|
|
|
- this.curDimSize = dimSize;
|
|
|
+ this.logger.log(`Successfully updated database CLIP dimension size from ${curDimSize} to ${dimSize}.`);
|
|
|
}
|
|
|
|
|
|
- @RequireLock(DatabaseLock.CLIPDimSize)
|
|
|
- private async getDimSize(): Promise<void> {
|
|
|
- if (this.curDimSize != null) {
|
|
|
- return;
|
|
|
- }
|
|
|
-
|
|
|
+ private async getDimSize(): Promise<number> {
|
|
|
const res = await this.smartSearchRepository.manager.query(`
|
|
|
SELECT atttypmod as dimsize
|
|
|
FROM pg_attribute f
|
|
@@ -153,7 +166,11 @@ export class SmartInfoRepository implements ISmartInfoRepository {
|
|
|
AND f.attnum > 0
|
|
|
AND c.relname = 'smart_search'
|
|
|
AND f.attname = 'embedding'`);
|
|
|
- this.curDimSize = res?.[0]?.['dimsize'] ?? 512;
|
|
|
- this.logger.verbose(`CLIP dimension size is ${this.curDimSize}`);
|
|
|
+
|
|
|
+ const dimSize = res[0]['dimsize'];
|
|
|
+ if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
|
|
|
+ throw new Error(`Could not retrieve CLIP dimension size`);
|
|
|
+ }
|
|
|
+ return dimSize;
|
|
|
}
|
|
|
}
|