working clip search (kinda)

mertalev, 1 year ago
commit ab7fbba5d4

+ 1 - 1
machine-learning/app/config.py

@@ -56,7 +56,7 @@ log_settings = LogSettings()
 
 class CustomRichHandler(RichHandler):
     def __init__(self) -> None:
-        console = Console(no_color=log_settings.no_color, force_interactive=True, stderr=True)
+        console = Console(no_color=log_settings.no_color)
         super().__init__(
             show_path=False,
             omit_repeated_times=False,
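
The `Console` no longer forces `force_interactive=True` or `stderr=True`, so rich falls back to its defaults: output goes to stdout and interactivity is auto-detected from the terminal rather than forced on. A minimal sketch of the resulting handler, with `no_color` hard-coded here where the real code reads it from `LogSettings`:

```python
from rich.console import Console
from rich.logging import RichHandler

# Sketch only: rich now auto-detects interactivity and writes to stdout by
# default instead of forcing interactive rendering on stderr.
console = Console(no_color=False)  # the real code passes log_settings.no_color
handler = RichHandler(console=console, show_path=False, omit_repeated_times=False)
```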

+ 4 - 5
machine-learning/app/main.py

@@ -23,9 +23,7 @@ from .schemas import (
     TextResponse,
 )
 
-# import rich.pretty
 
-# rich.pretty.install()
 MultiPartParser.max_file_size = 2**24  # spools to disk if payload is 16 MiB or larger
 app = FastAPI()
 
@@ -118,17 +116,18 @@ async def pipeline(
 
     outputs = await _predict(model_name, model_type, inputs, **kwargs)
     if index_name is not None:
+        expanded = np.expand_dims(outputs, 0)
         if k is not None:
             if k < 1:
                 raise HTTPException(400, f"k must be a positive integer; got {k}")
             if index_name not in vector_stores:
                 raise HTTPException(404, f"Index '{index_name}' not found")
-            outputs = await run(vector_stores[index_name].search, outputs, k)
+            outputs = await run(vector_stores[index_name].search, expanded, k)
         if embedding_id is not None:
             if index_name not in vector_stores:
-                await create(index_name, [embedding_id], outputs)
+                await create(index_name, [embedding_id], expanded)
             else:
-                await add(index_name, [embedding_id], outputs)
+                await add(index_name, [embedding_id], expanded)
     return ORJSONResponse(outputs)
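
Here `_predict` returns a single embedding vector for CLIP requests (see the `outputs[0][0]` change in clip.py below), while the vector index operates on 2-D batches of shape (n, d), as faiss-style indexes do. `np.expand_dims(outputs, 0)` turns the (d,) vector into a (1, d) batch before it is searched or inserted. A shape-only sketch, with a 512-dimensional embedding assumed for illustration:

```python
import numpy as np

# Hypothetical shapes: a single CLIP embedding vs. the batched form the index expects.
embedding = np.random.rand(512).astype(np.float32)  # shape (512,)
batched = np.expand_dims(embedding, 0)              # shape (1, 512)
assert batched.shape == (1, embedding.shape[0])
```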
 
 

+ 1 - 1
machine-learning/app/models/clip.py

@@ -102,7 +102,7 @@ class CLIPEncoder(InferenceModel):
             case _:
                 raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
 
-        return outputs[0]
+        return outputs[0][0]
 
     def _get_jina_model_name(self, model_name: str) -> str:
         if model_name in _MODELS:
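
Returning `outputs[0][0]` instead of `outputs[0]` unwraps the batch dimension: assuming `outputs` is the list of arrays returned by an ONNX Runtime session for a batch of one input, `outputs[0]` is the (1, d) embedding array and `outputs[0][0]` is the bare (d,) vector, which the pipeline then re-batches with `np.expand_dims` before indexing. Illustrative shapes only:

```python
import numpy as np

# Assumed shapes for a batch-of-one ONNX Runtime call; 512 dims is illustrative.
outputs = [np.zeros((1, 512), dtype=np.float32)]  # list of output tensors

outputs[0].shape     # (1, 512): embedding array with its batch dimension
outputs[0][0].shape  # (512,): the single embedding vector now returned
```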

+ 1 - 9
server/src/domain/search/search.service.ts

@@ -1,23 +1,15 @@
 import { AlbumEntity, AssetEntity, AssetFaceEntity } from '@app/infra/entities';
 import { Inject, Injectable, Logger } from '@nestjs/common';
-import { mapAlbumWithAssets } from '../album';
-import { IAlbumRepository } from '../album/album.repository';
 import { AssetResponseDto, mapAsset } from '../asset';
 import { IAssetRepository } from '../asset/asset.repository';
 import { AuthUserDto } from '../auth';
-import { usePagination } from '../domain.util';
-import { AssetFaceId, IFaceRepository } from '../facial-recognition';
-import { IAssetFaceJob, IBulkEntityJob, IJobRepository, JOBS_ASSET_PAGINATION_SIZE, JobName } from '../job';
+import { JobName } from '../job';
 import { IMachineLearningRepository } from '../smart-info';
 import { FeatureFlag, ISystemConfigRepository, SystemConfigCore } from '../system-config';
 import { SearchDto } from './dto';
 import { SearchResponseDto } from './response-dto';
 import {
-  ISearchRepository,
-  OwnedFaceEntity,
-  SearchCollection,
   SearchExploreItem,
-  SearchResult,
   SearchStrategy,
 } from './search.repository';