fixed parameterization
This commit is contained in:
parent
c1175be1d0
commit
6894c89644
3 changed files with 65 additions and 39 deletions
|
@ -33,5 +33,10 @@ export async function paginate<Entity extends ObjectLiteral>(
|
||||||
return { items, hasNextPage };
|
return { items, hasNextPage };
|
||||||
}
|
}
|
||||||
|
|
||||||
export const asVector = (embedding: number[], escape = false) =>
|
export const asVector = (embedding: number[], quote = false) =>
|
||||||
escape ? `'[${embedding.join(',')}]'` : `[${embedding.join(',')}]`;
|
quote ? `'[${embedding.join(',')}]'` : `[${embedding.join(',')}]`;
|
||||||
|
|
||||||
|
export const isValidInteger = (value: number, options: {min?: number, max?: number}): boolean => {
|
||||||
|
const { min = Number.MIN_SAFE_INTEGER, max = Number.MAX_SAFE_INTEGER } = options;
|
||||||
|
return Number.isInteger(value) && value >= min && value <= max;
|
||||||
|
}
|
||||||
|
|
|
@ -9,10 +9,10 @@ import {
|
||||||
} from '@app/domain';
|
} from '@app/domain';
|
||||||
import { InjectRepository } from '@nestjs/typeorm';
|
import { InjectRepository } from '@nestjs/typeorm';
|
||||||
import { In, Repository } from 'typeorm';
|
import { In, Repository } from 'typeorm';
|
||||||
|
import { dataSource } from '..';
|
||||||
import { AssetEntity, AssetFaceEntity, PersonEntity } from '../entities';
|
import { AssetEntity, AssetFaceEntity, PersonEntity } from '../entities';
|
||||||
import { DummyValue, GenerateSql } from '../infra.util';
|
import { DummyValue, GenerateSql } from '../infra.util';
|
||||||
import { asVector } from '../infra.utils';
|
import { asVector, isValidInteger } from '../infra.utils';
|
||||||
import { dataSource } from '..';
|
|
||||||
|
|
||||||
export class PersonRepository implements IPersonRepository {
|
export class PersonRepository implements IPersonRepository {
|
||||||
private readonly faceColumns: string[];
|
private readonly faceColumns: string[];
|
||||||
|
@ -22,8 +22,8 @@ export class PersonRepository implements IPersonRepository {
|
||||||
@InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
|
@InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
|
||||||
) {
|
) {
|
||||||
this.faceColumns = this.assetFaceRepository.manager.connection
|
this.faceColumns = this.assetFaceRepository.manager.connection
|
||||||
.getMetadata(AssetFaceEntity).ownColumns
|
.getMetadata(AssetFaceEntity)
|
||||||
.map((column) => column.propertyName)
|
.ownColumns.map((column) => column.propertyName)
|
||||||
.filter((propertyName) => propertyName !== 'embedding');
|
.filter((propertyName) => propertyName !== 'embedding');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -248,24 +248,39 @@ export class PersonRepository implements IPersonRepository {
|
||||||
return this.assetFaceRepository.findOneBy({ personId });
|
return this.assetFaceRepository.findOneBy({ personId });
|
||||||
}
|
}
|
||||||
|
|
||||||
async searchByEmbedding({ ownerId, embedding, numResults, maxDistance }: EmbeddingSearch): Promise<AssetFaceEntity[]> {
|
async searchByEmbedding({
|
||||||
const cte = this.assetFaceRepository.createQueryBuilder('faces')
|
ownerId,
|
||||||
.select('1 + (faces.embedding <=> :embedding)', 'distance')
|
embedding,
|
||||||
.innerJoin('faces.asset', 'asset')
|
numResults,
|
||||||
.where('asset.ownerId = :ownerId')
|
maxDistance,
|
||||||
.orderBy(`faces.embedding <=> :embedding`)
|
}: EmbeddingSearch): Promise<AssetFaceEntity[]> {
|
||||||
.setParameters({ownerId, embedding: asVector(embedding)})
|
if (!isValidInteger(numResults, { min: 1 })) {
|
||||||
.limit(numResults);
|
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
||||||
|
}
|
||||||
|
|
||||||
this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`));
|
let results: AssetFaceEntity[] = [];
|
||||||
|
this.assetRepository.manager.transaction(async (manager) => {
|
||||||
|
await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
|
||||||
|
const cte = manager
|
||||||
|
.createQueryBuilder(AssetFaceEntity, 'faces')
|
||||||
|
.select('1 + (faces.embedding <=> :embedding)', 'distance')
|
||||||
|
.innerJoin('faces.asset', 'asset')
|
||||||
|
.where('asset.ownerId = :ownerId')
|
||||||
|
.orderBy(`faces.embedding <=> :embedding`)
|
||||||
|
.setParameters({ ownerId, embedding: asVector(embedding) })
|
||||||
|
.limit(numResults);
|
||||||
|
|
||||||
const res = await dataSource.createQueryBuilder()
|
this.faceColumns.forEach((col) => cte.addSelect(`faces.${col} AS "${col}"`));
|
||||||
.select('res.*')
|
|
||||||
.addCommonTableExpression(cte, 'cte')
|
|
||||||
.from('cte', 'res')
|
|
||||||
.where('res.distance <= :maxDistance', { maxDistance })
|
|
||||||
.getRawMany();
|
|
||||||
|
|
||||||
return this.assetFaceRepository.create(res);
|
results = await dataSource
|
||||||
|
.createQueryBuilder()
|
||||||
|
.select('res.*')
|
||||||
|
.addCommonTableExpression(cte, 'cte')
|
||||||
|
.from('cte', 'res')
|
||||||
|
.where('res.distance <= :maxDistance', { maxDistance })
|
||||||
|
.getRawMany();
|
||||||
|
});
|
||||||
|
|
||||||
|
return this.assetFaceRepository.create(results);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@ import { InjectRepository } from '@nestjs/typeorm';
|
||||||
import AsyncLock from 'async-lock';
|
import AsyncLock from 'async-lock';
|
||||||
import { Repository } from 'typeorm';
|
import { Repository } from 'typeorm';
|
||||||
import { AssetEntity, SmartInfoEntity, SmartSearchEntity } from '../entities';
|
import { AssetEntity, SmartInfoEntity, SmartSearchEntity } from '../entities';
|
||||||
import { asVector } from '../infra.utils';
|
import { asVector, isValidInteger } from '../infra.utils';
|
||||||
|
|
||||||
@Injectable()
|
@Injectable()
|
||||||
export class SmartInfoRepository implements ISmartInfoRepository {
|
export class SmartInfoRepository implements ISmartInfoRepository {
|
||||||
|
@ -21,23 +21,25 @@ export class SmartInfoRepository implements ISmartInfoRepository {
|
||||||
}
|
}
|
||||||
|
|
||||||
async searchByEmbedding({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
|
async searchByEmbedding({ ownerId, embedding, numResults }: EmbeddingSearch): Promise<AssetEntity[]> {
|
||||||
const query: string = this.assetRepository
|
if (!isValidInteger(numResults, { min: 1 })) {
|
||||||
.createQueryBuilder('a')
|
throw new Error(`Invalid value for 'numResults': ${numResults}`);
|
||||||
.innerJoin('a.smartSearch', 's')
|
}
|
||||||
.where('a.ownerId = :ownerId')
|
|
||||||
.leftJoinAndSelect('a.exifInfo', 'e')
|
|
||||||
.orderBy('s.embedding <=> :embedding')
|
|
||||||
.setParameters({ embedding: asVector(embedding), ownerId })
|
|
||||||
.limit(numResults)
|
|
||||||
.getSql();
|
|
||||||
|
|
||||||
const queryWithK = `
|
let results: AssetEntity[] = [];
|
||||||
BEGIN;
|
this.assetRepository.manager.transaction(async (manager) => {
|
||||||
SET LOCAL vectors.k = ${numResults};
|
await manager.query(`SET LOCAL vectors.k = '${numResults}'`);
|
||||||
${query};
|
results = await manager
|
||||||
COMMIT;
|
.createQueryBuilder(AssetEntity, 'a')
|
||||||
`;
|
.innerJoin('a.smartSearch', 's')
|
||||||
return this.assetRepository.create(await this.assetRepository.manager.query(queryWithK));
|
.where('a.ownerId = :ownerId')
|
||||||
|
.leftJoinAndSelect('a.exifInfo', 'e')
|
||||||
|
.orderBy('s.embedding <=> :embedding')
|
||||||
|
.setParameters({ ownerId, embedding: asVector(embedding) })
|
||||||
|
.limit(numResults)
|
||||||
|
.getMany();
|
||||||
|
});
|
||||||
|
|
||||||
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
async upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void> {
|
async upsert(smartInfo: Partial<SmartInfoEntity>, embedding?: Embedding): Promise<void> {
|
||||||
|
@ -70,6 +72,10 @@ export class SmartInfoRepository implements ISmartInfoRepository {
|
||||||
* this does not parameterize the query because it is not possible to parameterize the column type
|
* this does not parameterize the query because it is not possible to parameterize the column type
|
||||||
*/
|
*/
|
||||||
private async updateDimSize(dimSize: number): Promise<void> {
|
private async updateDimSize(dimSize: number): Promise<void> {
|
||||||
|
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
|
||||||
|
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
|
||||||
|
}
|
||||||
|
|
||||||
await this.lock.acquire('updateDimSizeLock', async () => {
|
await this.lock.acquire('updateDimSizeLock', async () => {
|
||||||
if (this.curDimSize === dimSize) return;
|
if (this.curDimSize === dimSize) return;
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue