diff --git a/machine-learning/app/main.py b/machine-learning/app/main.py
index bb917f02f..dda4183f8 100644
--- a/machine-learning/app/main.py
+++ b/machine-learning/app/main.py
@@ -23,26 +23,35 @@ from .schemas import (
     TextResponse,
 )
 
+# import rich.pretty
+
+# rich.pretty.install()
 MultiPartParser.max_file_size = 2**24  # spools to disk if payload is 16 MiB or larger
 
 app = FastAPI()
-vector_stores: dict[str, faiss.IndexIDMap2] = {}
 
 class VectorStore:
-    def __init__(self, dims: int, index_cls: Type[faiss.Index] = faiss.IndexHNSWFlat, **kwargs: Any) -> None:
-        self.index = index_cls(dims, **kwargs)
-        self.id_to_key: dict[int, Any] = {}
+    def __init__(self, dims: int, index: str = "HNSW") -> None:
+        self.index = faiss.index_factory(dims, index)
+        self.key_to_id: dict[int, Any] = {}
 
     def search(self, embeddings: np.ndarray[int, np.dtype[Any]], k: int) -> list[Any]:
-        ids = self.index.assign(embeddings, k)  # type: ignore
-        return [self.id_to_key[idx] for row in ids.tolist() for idx in row if not idx == -1]
+        keys: np.ndarray[int, np.dtype[np.int64]] = self.index.assign(embeddings, k)
+        return [self.key_to_id[idx] for row in keys.tolist() for idx in row if not idx == -1]
 
     def add_with_ids(self, embeddings: np.ndarray[int, np.dtype[Any]], embedding_ids: list[Any]) -> None:
-        self.id_to_key |= {
-            id: key for id, key in zip(embedding_ids, range(self.index.ntotal, self.index.ntotal + len(embedding_ids)))
+        self.key_to_id |= {
+            key: id for key, id in zip(range(self.index.ntotal, self.index.ntotal + len(embedding_ids)), embedding_ids)
         }
         self.index.add(embeddings)  # type: ignore
 
+    @property
+    def dims(self) -> int:
+        return self.index.d
+
+
+vector_stores: dict[str, VectorStore] = {}
+
 
 def validate_embeddings(embeddings: list[float]) -> Any:
     embeddings = np.array(embeddings)
@@ -119,7 +128,7 @@ async def pipeline(
         if index_name not in vector_stores:
             await create(index_name, [embedding_id], outputs)
         else:
-            await run(vector_stores[index_name].add, [embedding_id], outputs)
+            await add(index_name, [embedding_id], outputs)
 
     return ORJSONResponse(outputs)
 
@@ -148,7 +157,7 @@ async def predict(
 
 @app.post("/index/{index_name}/search", response_class=ORJSONResponse)
 async def search(index_name: str, embeddings: Any = Depends(validate_embeddings), k: int = 10) -> ORJSONResponse:
-    if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
+    if index_name not in vector_stores or vector_stores[index_name].dims != embeddings.shape[1]:
         raise HTTPException(404, f"Index '{index_name}' not found")
     outputs: np.ndarray[int, np.dtype[Any]] = await run(vector_stores[index_name].search, embeddings, k)
     return ORJSONResponse(outputs)
@@ -160,10 +169,11 @@ async def add(
     embedding_ids: list[str],
     embeddings: Any = Depends(validate_embeddings),
 ) -> None:
-    if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
+    if index_name not in vector_stores or vector_stores[index_name].dims != embeddings.shape[1]:
         await create(index_name, embedding_ids, embeddings)
     else:
-        await run(vector_stores[index_name].add_with_ids, embeddings, embedding_ids)
+        log.info(f"Adding {len(embedding_ids)} embeddings to index '{index_name}'")
+        await run(_add, vector_stores[index_name], embedding_ids, embeddings)
 
 
 @app.post("/index/{index_name}/create")
@@ -173,11 +183,15 @@ async def create(
     embeddings: Any = Depends(validate_embeddings),
 ) -> None:
     if embeddings.shape[0] != len(embedding_ids):
-        raise HTTPException(400, "Number of embedding IDs must match number of embeddings")
+        raise HTTPException(
+            400,
+            f"Number of embedding IDs must match number of embeddings; got {len(embedding_ids)} ID(s) and {embeddings.shape[0]} embedding(s)",
+        )
     if index_name in vector_stores:
         log.warn(f"Index '{index_name}' already exists. Overwriting.")
+    log.info(f"Creating new index '{index_name}'")
 
-    vector_stores[index_name] = await run(_create)
+    vector_stores[index_name] = await run(_create, embedding_ids, embeddings)
 
 
 async def run(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
@@ -223,7 +237,14 @@ def _create(
     embeddings: np.ndarray[int, np.dtype[np.float32]],
 ) -> VectorStore:
     index = VectorStore(embeddings.shape[1])
+    _add(index, embedding_ids, embeddings)
+    return index
+
 
+def _add(
+    index: VectorStore,
+    embedding_ids: list[str],
+    embeddings: np.ndarray[int, np.dtype[np.float32]],
+) -> None:
     with app.state.index_lock:
         index.add_with_ids(embeddings, embedding_ids)  # type: ignore
-    return index
diff --git a/machine-learning/app/models/clip.py b/machine-learning/app/models/clip.py
index 5247cdcf0..b05fba8af 100644
--- a/machine-learning/app/models/clip.py
+++ b/machine-learning/app/models/clip.py
@@ -102,7 +102,7 @@ class CLIPEncoder(InferenceModel):
             case _:
                 raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
 
-        return outputs[0][0]
+        return outputs[0]
 
     def _get_jina_model_name(self, model_name: str) -> str:
         if model_name in _MODELS: