fixed parameterization

This commit is contained in:
mertalev 2023-11-21 23:27:34 -05:00
parent c1175be1d0
commit 6894c89644
No known key found for this signature in database
GPG key ID: 9181CD92C0A1C5E3
3 changed files with 65 additions and 39 deletions

View file

@ -33,5 +33,10 @@ export async function paginate<Entity extends ObjectLiteral>(
return { items, hasNextPage }; return { items, hasNextPage };
} }
export const asVector = (embedding: number[], escape = false) => export const asVector = (embedding: number[], quote = false) =>
escape ? `'[${embedding.join(',')}]'` : `[${embedding.join(',')}]`; quote ? `'[${embedding.join(',')}]'` : `[${embedding.join(',')}]`;
export const isValidInteger = (value: number, options: {min?: number, max?: number}): boolean => {
const { min = Number.MIN_SAFE_INTEGER, max = Number.MAX_SAFE_INTEGER } = options;
return Number.isInteger(value) && value >= min && value <= max;
}

View file

@ -9,10 +9,10 @@ import {
} from '@app/domain'; } from '@app/domain';
import { InjectRepository } from '@nestjs/typeorm'; import { InjectRepository } from '@nestjs/typeorm';
import { In, Repository } from 'typeorm'; import { In, Repository } from 'typeorm';
import { dataSource } from '..';
import { AssetEntity, AssetFaceEntity, PersonEntity } from '../entities'; import { AssetEntity, AssetFaceEntity, PersonEntity } from '../entities';
import { DummyValue, GenerateSql } from '../infra.util'; import { DummyValue, GenerateSql } from '../infra.util';
import { asVector } from '../infra.utils'; import { asVector, isValidInteger } from '../infra.utils';
import { dataSource } from '..';
export class PersonRepository implements IPersonRepository { export class PersonRepository implements IPersonRepository {
private readonly faceColumns: string[]; private readonly faceColumns: string[];
@ -22,8 +22,8 @@ export class PersonRepository implements IPersonRepository {
@InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>, @InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
) { ) {
this.faceColumns = this.assetFaceRepository.manager.connection this.faceColumns = this.assetFaceRepository.manager.connection
.getMetadata(AssetFaceEntity).ownColumns .getMetadata(AssetFaceEntity)
.map((column) => column.propertyName) .ownColumns.map((column) => column.propertyName)
.filter((propertyName) => propertyName !== 'embedding'); .filter((propertyName) => propertyName !== 'embedding');
} }
@ -248,24 +248,39 @@ export class PersonRepository implements IPersonRepository {
return this.assetFaceRepository.findOneBy({ personId }); return this.assetFaceRepository.findOneBy({ personId });
} }
async searchByEmbedding({ ownerId, embedding, numResults, maxDistance }: EmbeddingSearch): Promise<AssetFaceEntity[]> { async searchByEmbedding({
const cte = this.assetFaceRepository.createQueryBuilder('faces') ownerId,
.select('1 + (faces.embedding <=> :embedding)', 'distance') embedding,
.innerJoin('faces.asset', 'asset') numResults,
.where('asset.ownerId = :ownerId') maxDistance,
.orderBy(`faces.embedding <=> :embedding`) }: EmbeddingSearch): Promise<AssetFaceEntity[]> {
.setParameters({ownerId, embedding: asVector(embedding)}) if (!isValidInteger(numResults, { min: 1 })) {
.limit(numResults); throw new Error(`Invalid value for 'numResults': ${numResults}`);
}
this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`)); let results: AssetFaceEntity[] = [];
this.assetRepository.manager.transaction(async (manager) => {
await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
const cte = manager
.createQueryBuilder(AssetFaceEntity, 'faces')
.select('1 + (faces.embedding <=> :embedding)', 'distance')
.innerJoin('faces.asset', 'asset')
.where('asset.ownerId = :ownerId')
.orderBy(`faces.embedding <=> :embedding`)
.setParameters({ ownerId, embedding: asVector(embedding) })
.limit(numResults);
const res = await dataSource.createQueryBuilder() this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`));
.select('res.*')
.addCommonTableExpression(cte, 'cte')
.from('cte', 'res')
.where('res.distance <= :maxDistance', { maxDistance })
.getRawMany();
return this.assetFaceRepository.create(res); results = await dataSource
.createQueryBuilder()
.select('res.*')
.addCommonTableExpression(cte, 'cte')
.from('cte', 'res')
.where('res.distance <= :maxDistance', { maxDistance })
.getRawMany();
});
return this.assetFaceRepository.create(results);
} }
} }

View file

@ -4,7 +4,7 @@ import { InjectRepository } from '@nestjs/typeorm';
import AsyncLock from 'async-lock'; import AsyncLock from 'async-lock';
import { Repository } from 'typeorm'; import { Repository } from 'typeorm';
import { AssetEntity, SmartInfoEntity, SmartSearchEntity } from '../entities'; import { AssetEntity, SmartInfoEntity, SmartSearchEntity } from '../entities';
import { asVector } from '../infra.utils'; import { asVector, isValidInteger } from '../infra.utils';
@Injectable() @Injectable()
export class SmartInfoRepository implements ISmartInfoRepository { export class SmartInfoRepository implements ISmartInfoRepository {
@ -21,23 +21,25 @@ export class SmartInfoRepository implements ISmartInfoRepository {
} }
async searchByEmbedding({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> { async searchByEmbedding({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
const query: string = this.assetRepository if (!isValidInteger(numResults, { min: 1 })) {
.createQueryBuilder('a') throw new Error(`Invalid value for 'numResults': ${numResults}`);
.innerJoin('a.smartSearch', 's') }
.where('a.ownerId = :ownerId')
.leftJoinAndSelect('a.exifInfo', 'e')
.orderBy('s.embedding <=> :embedding')
.setParameters({ embedding: asVector(embedding), ownerId })
.limit(numResults)
.getSql();
const queryWithK = ` let results: AssetEntity[] = [];
BEGIN; this.assetRepository.manager.transaction(async (manager) => {
SET LOCAL vectors.k = ${numResults}; await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
${query}; results = await manager
COMMIT; .createQueryBuilder(AssetEntity, 'a')
`; .innerJoin('a.smartSearch', 's')
return this.assetRepository.create(await this.assetRepository.manager.query(queryWithK)); .where('a.ownerId = :ownerId')
.leftJoinAndSelect('a.exifInfo', 'e')
.orderBy('s.embedding <=> :embedding')
.setParameters({ ownerId, embedding: asVector(embedding) })
.limit(numResults)
.getMany();
});
return results;
} }
async upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void> { async upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void> {
@ -70,6 +72,10 @@ export class SmartInfoRepository implements ISmartInfoRepository {
* this does not parameterize the query because it is not possible to parameterize the column type * this does not parameterize the query because it is not possible to parameterize the column type
*/ */
private async updateDimSize(dimSize: number): Promise<void> { private async updateDimSize(dimSize: number): Promise<void> {
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
}
await this.lock.acquire('updateDimSizeLock', async () => { await this.lock.acquire('updateDimSizeLock', async () => {
if (this.curDimSize === dimSize) return; if (this.curDimSize === dimSize) return;