mertalev 1 year ago
parent
commit
07c4e039b5
1 changed file with 60 additions and 2 deletions

+ 60 - 2
machine-learning/app/main.py

@@ -1,11 +1,13 @@
 import asyncio
 import threading
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any
 from zipfile import BadZipFile
 
+import faiss
+import numpy as np
 import orjson
-from fastapi import FastAPI, Form, HTTPException, UploadFile
+from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile
 from fastapi.responses import ORJSONResponse
 from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile  # type: ignore
 from starlette.formparsers import MultiPartParser
@@ -22,6 +24,18 @@ from .schemas import (
 
 MultiPartParser.max_file_size = 2**24  # spools to disk if payload is 16 MiB or larger
 app = FastAPI()
+vector_stores: dict[str, faiss.IndexIDMap2] = {}
+
+
+def validate_embeddings(embeddings: list[float] | np.ndarray[int, np.dtype[Any]]) -> np.ndarray[int, np.dtype[np.float32]]:
+    embeddings = np.array(embeddings, dtype=np.float32)  # faiss operates on float32 vectors
+    if len(embeddings.shape) == 1:
+        embeddings = np.expand_dims(embeddings, 0)
+    elif len(embeddings.shape) != 2:
+        raise HTTPException(400, f"Expected one or two axes for embeddings; got {len(embeddings.shape)}")
+    if embeddings.shape[1] < 10:
+        raise HTTPException(400, f"Dimension size must be at least 10; got {embeddings.shape[1]}")
+    return embeddings
 
 
 def init_state() -> None:
@@ -34,7 +48,8 @@ def init_state() -> None:
     )
     # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
     app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
-    app.state.locks = {model_type: threading.Lock() for model_type in ModelType}
+    app.state.model_locks = {model_type: threading.Lock() for model_type in ModelType}
+    app.state.index_lock = threading.Lock()
     log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
 
 
@@ -78,6 +93,49 @@ async def predict(
     return ORJSONResponse(outputs)
 
 
+@app.post("/index/{index_name}/search")
+async def search(
+    index_name: str, embeddings: np.ndarray[int, np.dtype[np.float32]] = Depends(validate_embeddings), k: int = 10
+) -> Any:
+    if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
+        raise HTTPException(404, f"Index '{index_name}' not found")
+    return vector_stores[index_name].search(embeddings, k)[1].tolist()  # search() returns (distances, ids); return the ids as a JSON-serializable list
+
+
+@app.patch("/index/{index_name}/add")
+async def add(
+    index_name: str,
+    embedding_ids: list[str],
+    embeddings: np.ndarray[int, np.dtype[np.float32]] = Depends(validate_embeddings),
+) -> None:
+    if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
+        await create(index_name, embedding_ids, embeddings)
+    else:
+        vector_stores[index_name].add_with_ids(embeddings, embedding_ids)  # type: ignore
+
+
+@app.post("/index/{index_name}/create")
+async def create(
+    index_name: str,
+    embedding_ids: list[str],
+    embeddings: np.ndarray[int, np.dtype[np.float32]] = Depends(validate_embeddings),
+) -> None:
+    if embeddings.shape[0] != len(embedding_ids):
+        raise HTTPException(400, "Number of embedding IDs must match number of embeddings")
+    if index_name in vector_stores:
+        log.warning(f"Index '{index_name}' already exists. Overwriting.")
+
+    hnsw_index = faiss.IndexHNSWFlat(embeddings.shape[1], 32)  # 32 = HNSW connectivity (M), a common default
+    mapped_index = faiss.IndexIDMap2(hnsw_index)
+
+    def _create() -> faiss.IndexIDMap2:
+        with app.state.index_lock:
+            mapped_index.add_with_ids(embeddings, embedding_ids)  # type: ignore
+        return mapped_index
+
+    vector_stores[index_name] = await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, _create)
+
+
 async def run(model: InferenceModel, inputs: Any) -> Any:
     if app.state.thread_pool is None:
         return model.predict(inputs)
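
Note on ids: the endpoints above accept embedding_ids as a list of strings, but faiss's IndexIDMap2.add_with_ids only stores int64 ids, so string ids (such as asset UUIDs) would have to be translated to integers before they reach faiss and translated back after a search. The sketch below shows one hypothetical way to do this with a stable hash; FaissIdMapper and the "asset-{i}" id format are illustrative assumptions, not part of this commit.

import hashlib

import faiss
import numpy as np


class FaissIdMapper:
    """Derives a stable int64 id for each string id and remembers the reverse mapping."""

    def __init__(self) -> None:
        self._int_to_str: dict[int, str] = {}

    def to_int64(self, str_ids: list[str]) -> np.ndarray:
        int_ids = []
        for str_id in str_ids:
            # blake2b is deterministic across processes; shift right so the id fits in a positive int64
            int_id = int.from_bytes(hashlib.blake2b(str_id.encode(), digest_size=8).digest(), "big") >> 1
            self._int_to_str[int_id] = str_id
            int_ids.append(int_id)
        return np.array(int_ids, dtype=np.int64)

    def to_str(self, int_ids: np.ndarray) -> list[list[str | None]]:
        # faiss pads missing neighbors with -1, which maps to None here
        return [[self._int_to_str.get(int(i)) for i in row] for row in int_ids]


# Usage against a bare index, mirroring what create() and search() above do:
mapper = FaissIdMapper()
embeddings = np.random.rand(100, 512).astype(np.float32)
index = faiss.IndexIDMap2(faiss.IndexHNSWFlat(512, 32))
index.add_with_ids(embeddings, mapper.to_int64([f"asset-{i}" for i in range(100)]))
_, neighbor_ids = index.search(embeddings[:1], 10)
print(mapper.to_str(neighbor_ids))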