added faiss

This commit is contained in:
mertalev 2023-09-10 17:43:41 -04:00
parent 0a9b632e48
commit 07c4e039b5
No known key found for this signature in database
GPG key ID: 9181CD92C0A1C5E3

View file

@ -1,11 +1,13 @@
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from typing import Any
from zipfile import BadZipFile
import faiss
import orjson
from fastapi import FastAPI, Form, HTTPException, UploadFile
from fastapi import FastAPI, Form, HTTPException, UploadFile, Depends
from fastapi.responses import ORJSONResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile # type: ignore
from starlette.formparsers import MultiPartParser
@ -22,6 +24,18 @@ from .schemas import (
MultiPartParser.max_file_size = 2**24 # spools to disk if payload is 16 MiB or larger
app = FastAPI()
# In-memory registry of FAISS indices, keyed by the `index_name` path parameter
# of the /index/... endpoints. `create` stores a freshly built index here;
# `search` and `add` look indices up by name.
vector_stores: dict[str, faiss.IndexIDMap2] = {}
def validate_embeddings(embeddings: list[float] | np.ndarray[int, np.dtype[Any]]) -> np.ndarray[int, np.dtype[Any]]:
embeddings = np.array(embeddings)
if len(embeddings.shape) == 1:
embeddings = np.expand_dims(embeddings, 0)
elif len(embeddings.shape) != 2:
raise HTTPException(400, f"Expected one or two axes for embeddings; got {len(embeddings.shape)}")
if embeddings.shape[1] < 10:
raise HTTPException(400, f"Dimension size must be at least 10; got {embeddings.shape[1]}")
return embeddings
def init_state() -> None:
@ -34,7 +48,8 @@ def init_state() -> None:
)
# asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
app.state.locks = {model_type: threading.Lock() for model_type in ModelType}
app.state.model_locks = {model_type: threading.Lock() for model_type in ModelType}
app.state.index_lock = threading.Lock()
log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
@ -78,6 +93,49 @@ async def predict(
return ORJSONResponse(outputs)
@app.post("/index/{index_name}/search")
async def search(
    index_name: str, embeddings: np.ndarray[int, np.dtype[np.float32]] = Depends(validate_embeddings), k: int = 10
) -> Any:
    """Return the IDs of the ``k`` nearest stored vectors for each query embedding.

    Args:
        index_name: Name of the index to query.
        embeddings: Query vectors, already normalized by ``validate_embeddings``.
        k: Number of neighbours to return per query vector.

    Returns:
        A JSON response holding one list of neighbour IDs per query vector.

    Raises:
        HTTPException: 400 if ``k`` is not positive; 404 if the index does not
            exist or its dimensionality differs from the query embeddings.
    """
    if k < 1:
        raise HTTPException(400, f"k must be a positive integer; got {k}")

    index = vector_stores.get(index_name)
    if index is None or index.d != embeddings.shape[1]:
        raise HTTPException(404, f"Index '{index_name}' not found")

    # `search` returns (distances, ids); only the ids are exposed to clients.
    labels = index.search(embeddings, k)[1]  # type: ignore
    # A raw ndarray is not JSON-serializable by the default response path, so
    # convert to nested lists and send it explicitly.
    return ORJSONResponse(labels.tolist())
@app.patch("/index/{index_name}/add")
async def add(
    index_name: str,
    embedding_ids: list[str],
    embeddings: np.ndarray[int, np.dtype[np.float32]] = Depends(validate_embeddings),
) -> None:
    """Add embeddings to an existing index, creating the index when needed.

    If the index does not exist, or exists with a different dimensionality,
    this delegates to ``create`` (which replaces any index of that name).

    Args:
        index_name: Name of the index to add to.
        embedding_ids: One ID per embedding row.
        embeddings: Vectors to add, already normalized by ``validate_embeddings``.

    Raises:
        HTTPException: 400 if the number of IDs differs from the number of
            embeddings.
    """
    # Same precondition `create` enforces; check it here so the existing-index
    # path cannot hand mismatched inputs to FAISS.
    if embeddings.shape[0] != len(embedding_ids):
        raise HTTPException(400, "Number of embedding IDs must match number of embeddings")

    index = vector_stores.get(index_name)
    if index is None or index.d != embeddings.shape[1]:
        await create(index_name, embedding_ids, embeddings)
        return

    # NOTE(review): FAISS ID maps store integer ids; confirm that the string
    # `embedding_ids` received here are convertible as expected.
    def _add() -> None:
        # Serialize index mutation and keep the blocking call off the event loop,
        # mirroring how `create` populates an index.
        with app.state.index_lock:
            index.add_with_ids(embeddings, embedding_ids)  # type: ignore

    await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, _add)
@app.post("/index/{index_name}/create")
async def create(
    index_name: str,
    embedding_ids: list[str],
    embeddings: np.ndarray[int, np.dtype[np.float32]] = Depends(validate_embeddings),
) -> None:
    """Build a new HNSW index from the given embeddings and register it by name.

    An existing index with the same name is replaced once the new one has been
    populated.

    Args:
        index_name: Name under which the index is stored in ``vector_stores``.
        embedding_ids: One ID per embedding row.
        embeddings: Vectors to index, already normalized by ``validate_embeddings``.

    Raises:
        HTTPException: 400 if the number of IDs differs from the number of
            embeddings.
    """
    if embeddings.shape[0] != len(embedding_ids):
        raise HTTPException(400, "Number of embedding IDs must match number of embeddings")
    if index_name in vector_stores:
        # `warn` is a deprecated alias of `warning`.
        log.warning(f"Index '{index_name}' already exists. Overwriting.")

    # The index dimensionality is fixed by the width of the first batch.
    hnsw_index = faiss.IndexHNSWFlat(embeddings.shape[1])
    mapped_index = faiss.IndexIDMap2(hnsw_index)

    def _create() -> faiss.IndexIDMap2:
        # Populate the index under the shared index lock, off the event loop.
        with app.state.index_lock:
            mapped_index.add_with_ids(embeddings, embedding_ids)  # type: ignore
        return mapped_index

    # Publish the index only after it has been fully populated.
    vector_stores[index_name] = await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, _create)
async def run(model: InferenceModel, inputs: Any) -> Any:
if app.state.thread_pool is None:
return model.predict(inputs)