|
@@ -1,7 +1,8 @@
|
|
import asyncio
|
|
import asyncio
|
|
|
|
+from functools import partial
|
|
import threading
|
|
import threading
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
-from typing import Any
|
|
|
|
|
|
+from typing import Any, Callable
|
|
from zipfile import BadZipFile
|
|
from zipfile import BadZipFile
|
|
|
|
|
|
import faiss
|
|
import faiss
|
|
@@ -68,14 +69,47 @@ def ping() -> str:
|
|
return "pong"
|
|
return "pong"
|
|
|
|
|
|
|
|
|
|
-@app.post("/predict")
|
|
|
|
|
|
+@app.post("/pipeline", response_class=ORJSONResponse)
|
|
|
|
+async def pipeline(
|
|
|
|
+ model_name: str = Form(alias="modelName"),
|
|
|
|
+ model_type: ModelType = Form(alias="modelType"),
|
|
|
|
+ options: str = Form(default="{}"),
|
|
|
|
+ text: str | None = Form(default=None),
|
|
|
|
+ image: UploadFile | None = None,
|
|
|
|
+ index_name: str | None = Form(default=None),
|
|
|
|
+ embedding_id: str | None = Form(default=None),
|
|
|
|
+ k: int | None = Form(default=None),
|
|
|
|
+) -> ORJSONResponse:
|
|
|
|
+ if image is not None:
|
|
|
|
+ inputs: str | bytes = await image.read()
|
|
|
|
+ elif text is not None:
|
|
|
|
+ inputs = text
|
|
|
|
+ else:
|
|
|
|
+ raise HTTPException(400, "Either image or text must be provided")
|
|
|
|
+ try:
|
|
|
|
+ kwargs = orjson.loads(options)
|
|
|
|
+ except orjson.JSONDecodeError:
|
|
|
|
+ raise HTTPException(400, f"Invalid options JSON: {options}")
|
|
|
|
+
|
|
|
|
+ outputs = await run(_predict, model_name, model_type, inputs, **kwargs)
|
|
|
|
+ if index_name is not None:
|
|
|
|
+ if k is not None:
|
|
|
|
+ if k < 1:
|
|
|
|
+ raise HTTPException(400, f"k must be a positive integer; got {k}")
|
|
|
|
+ outputs = await run(_search, index_name, outputs, k)
|
|
|
|
+ if embedding_id is not None:
|
|
|
|
+ await run(_add, index_name, [embedding_id], outputs)
|
|
|
|
+ return ORJSONResponse(outputs)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.post("/predict", response_class=ORJSONResponse)
|
|
async def predict(
|
|
async def predict(
|
|
model_name: str = Form(alias="modelName"),
|
|
model_name: str = Form(alias="modelName"),
|
|
model_type: ModelType = Form(alias="modelType"),
|
|
model_type: ModelType = Form(alias="modelType"),
|
|
options: str = Form(default="{}"),
|
|
options: str = Form(default="{}"),
|
|
text: str | None = Form(default=None),
|
|
text: str | None = Form(default=None),
|
|
image: UploadFile | None = None,
|
|
image: UploadFile | None = None,
|
|
-) -> Any:
|
|
|
|
|
|
+) -> ORJSONResponse:
|
|
if image is not None:
|
|
if image is not None:
|
|
inputs: str | bytes = await image.read()
|
|
inputs: str | bytes = await image.read()
|
|
elif text is not None:
|
|
elif text is not None:
|
|
@@ -87,22 +121,21 @@ async def predict(
|
|
except orjson.JSONDecodeError:
|
|
except orjson.JSONDecodeError:
|
|
raise HTTPException(400, f"Invalid options JSON: {options}")
|
|
raise HTTPException(400, f"Invalid options JSON: {options}")
|
|
|
|
|
|
- model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
|
|
|
|
- model.configure(**kwargs)
|
|
|
|
- outputs = await run(model, inputs)
|
|
|
|
|
|
+ outputs = await run(_predict, model_name, model_type, inputs, **kwargs)
|
|
return ORJSONResponse(outputs)
|
|
return ORJSONResponse(outputs)
|
|
|
|
|
|
|
|
|
|
-@app.post("/index/{index_name}/search")
|
|
|
|
|
|
+@app.post("/index/{index_name}/search", response_class=ORJSONResponse)
|
|
async def search(
|
|
async def search(
|
|
index_name: str, embeddings: np.ndarray[int, np.dtype[np.float32]] = Depends(validate_embeddings), k: int = 10
|
|
index_name: str, embeddings: np.ndarray[int, np.dtype[np.float32]] = Depends(validate_embeddings), k: int = 10
|
|
-) -> None:
|
|
|
|
|
|
+) -> ORJSONResponse:
|
|
if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
|
|
if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
|
|
raise HTTPException(404, f"Index '{index_name}' not found")
|
|
raise HTTPException(404, f"Index '{index_name}' not found")
|
|
- return vector_stores[index_name].search(embeddings, k)[1] # type: ignore
|
|
|
|
|
|
+ outputs: np.ndarray[int, np.dtype[Any]] = await run(_search, index_name, embeddings, k)
|
|
|
|
+ return ORJSONResponse(outputs)
|
|
|
|
|
|
|
|
|
|
-@app.patch("/index/{index_name}/add")
|
|
|
|
|
|
+@app.post("/index/{index_name}/add")
|
|
async def add(
|
|
async def add(
|
|
index_name: str,
|
|
index_name: str,
|
|
embedding_ids: list[str],
|
|
embedding_ids: list[str],
|
|
@@ -111,7 +144,7 @@ async def add(
|
|
if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
|
|
if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
|
|
await create(index_name, embedding_ids, embeddings)
|
|
await create(index_name, embedding_ids, embeddings)
|
|
else:
|
|
else:
|
|
- vector_stores[index_name].add_with_ids(embeddings, embedding_ids) # type: ignore
|
|
|
|
|
|
+ await run(_add, index_name, embedding_ids, embeddings)
|
|
|
|
|
|
|
|
|
|
@app.post("/index/{index_name}/create")
|
|
@app.post("/index/{index_name}/create")
|
|
@@ -125,39 +158,25 @@ async def create(
|
|
if index_name in vector_stores:
|
|
if index_name in vector_stores:
|
|
log.warn(f"Index '{index_name}' already exists. Overwriting.")
|
|
log.warn(f"Index '{index_name}' already exists. Overwriting.")
|
|
|
|
|
|
- hnsw_index = faiss.IndexHNSWFlat(embeddings.shape[1])
|
|
|
|
- mapped_index = faiss.IndexIDMap2(hnsw_index)
|
|
|
|
-
|
|
|
|
- def _create() -> faiss.IndexIDMap2:
|
|
|
|
- with app.state.index_lock:
|
|
|
|
- mapped_index.add_with_ids(embeddings, embedding_ids) # type: ignore
|
|
|
|
- return mapped_index
|
|
|
|
|
|
+ vector_stores[index_name] = await run(_create)
|
|
|
|
|
|
- vector_stores[index_name] = await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, _create)
|
|
|
|
|
|
|
|
-
|
|
|
|
-async def run(model: InferenceModel, inputs: Any) -> Any:
|
|
|
|
|
|
+async def run(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
|
|
if app.state.thread_pool is None:
|
|
if app.state.thread_pool is None:
|
|
- return model.predict(inputs)
|
|
|
|
-
|
|
|
|
- return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
|
|
|
|
|
|
+ return func(*args, **kwargs)
|
|
|
|
+ if kwargs:
|
|
|
|
+ func = partial(func, **kwargs)
|
|
|
|
+ return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, func, *args)
|
|
|
|
|
|
|
|
|
|
-async def load(model: InferenceModel) -> InferenceModel:
|
|
|
|
|
|
+async def _load(model: InferenceModel) -> InferenceModel:
|
|
if model.loaded:
|
|
if model.loaded:
|
|
return model
|
|
return model
|
|
|
|
|
|
- def _load() -> None:
|
|
|
|
- with app.state.locks[model.model_type]:
|
|
|
|
- model.load()
|
|
|
|
-
|
|
|
|
- loop = asyncio.get_running_loop()
|
|
|
|
try:
|
|
try:
|
|
- if app.state.thread_pool is None:
|
|
|
|
- model.load()
|
|
|
|
- else:
|
|
|
|
- await loop.run_in_executor(app.state.thread_pool, _load)
|
|
|
|
- return model
|
|
|
|
|
|
+ with app.state.model_locks[model.model_type]:
|
|
|
|
+ if not model.loaded:
|
|
|
|
+ model.load()
|
|
except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
|
|
except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
|
|
log.warn(
|
|
log.warn(
|
|
(
|
|
(
|
|
@@ -166,8 +185,35 @@ async def load(model: InferenceModel) -> InferenceModel:
|
|
)
|
|
)
|
|
)
|
|
)
|
|
model.clear_cache()
|
|
model.clear_cache()
|
|
- if app.state.thread_pool is None:
|
|
|
|
- model.load()
|
|
|
|
- else:
|
|
|
|
- await loop.run_in_executor(app.state.thread_pool, _load)
|
|
|
|
- return model
|
|
|
|
|
|
+ model.load()
|
|
|
|
+ return model
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+async def _add(index_name: str, embedding_ids: list[str], embeddings: np.ndarray[int, np.dtype[np.float32]]) -> None:
|
|
|
|
+ return await vector_stores[index_name].add_with_ids(embeddings, embedding_ids) # type: ignore
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+async def _search(
|
|
|
|
+ index_name: str, embeddings: np.ndarray[int, np.dtype[np.float32]], k: int
|
|
|
|
+) -> np.ndarray[int, np.dtype[Any]]:
|
|
|
|
+ return await vector_stores[index_name].assign(embeddings, k) # type: ignore
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+async def _predict(
|
|
|
|
+ model_name: str, model_type: ModelType, inputs: Any, **options: Any
|
|
|
|
+) -> np.ndarray[int, np.dtype[np.float32]]:
|
|
|
|
+ model = await _load(await app.state.model_cache.get(model_name, model_type, **options))
|
|
|
|
+ model.configure(**options)
|
|
|
|
+ return await run(model.predict, inputs)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+async def _create(
|
|
|
|
+ embedding_ids: list[str],
|
|
|
|
+ embeddings: np.ndarray[int, np.dtype[np.float32]],
|
|
|
|
+) -> faiss.IndexIDMap2:
|
|
|
|
+ hnsw_index = faiss.IndexHNSWFlat(embeddings.shape[1])
|
|
|
|
+ mapped_index = faiss.IndexIDMap2(hnsw_index)
|
|
|
|
+
|
|
|
|
+ with app.state.index_lock:
|
|
|
|
+ mapped_index.add_with_ids(embeddings, embedding_ids) # type: ignore
|
|
|
|
+ return mapped_index
|