浏览代码

added pipeline endpoint

mertalev 1 年之前
父节点
当前提交
b26b4042cf
共有 2 个文件被更改,包括 87 次插入和 41 次删除
  1. 86 40
      machine-learning/app/main.py
  2. 1 1
      machine-learning/app/models/clip.py

+ 86 - 40
machine-learning/app/main.py

@@ -1,7 +1,8 @@
 import asyncio
+from functools import partial
 import threading
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any
+from typing import Any, Callable
 from zipfile import BadZipFile

 import faiss
@@ -68,14 +69,47 @@ def ping() -> str:
     return "pong"


-@app.post("/predict")
+@app.post("/pipeline", response_class=ORJSONResponse)
+async def pipeline(
+    model_name: str = Form(alias="modelName"),
+    model_type: ModelType = Form(alias="modelType"),
+    options: str = Form(default="{}"),
+    text: str | None = Form(default=None),
+    image: UploadFile | None = None,
+    index_name: str | None = Form(default=None),
+    embedding_id: str | None = Form(default=None),
+    k: int | None = Form(default=None),
+) -> ORJSONResponse:
+    if image is not None:
+        inputs: str | bytes = await image.read()
+    elif text is not None:
+        inputs = text
+    else:
+        raise HTTPException(400, "Either image or text must be provided")
+    try:
+        kwargs = orjson.loads(options)
+    except orjson.JSONDecodeError:
+        raise HTTPException(400, f"Invalid options JSON: {options}")
+
+    outputs = await run(_predict, model_name, model_type, inputs, **kwargs)
+    if index_name is not None:
+        if k is not None:
+            if k < 1:
+                raise HTTPException(400, f"k must be a positive integer; got {k}")
+            outputs = await run(_search, index_name, outputs, k)
+        if embedding_id is not None:
+            await run(_add, index_name, [embedding_id], outputs)
+    return ORJSONResponse(outputs)
+
+
+@app.post("/predict", response_class=ORJSONResponse)
 async def predict(
     model_name: str = Form(alias="modelName"),
     model_type: ModelType = Form(alias="modelType"),
     options: str = Form(default="{}"),
     text: str | None = Form(default=None),
     image: UploadFile | None = None,
-) -> Any:
+) -> ORJSONResponse:
     if image is not None:
         inputs: str | bytes = await image.read()
     elif text is not None:
@@ -87,22 +121,21 @@ async def predict(
     except orjson.JSONDecodeError:
         raise HTTPException(400, f"Invalid options JSON: {options}")

-    model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
-    model.configure(**kwargs)
-    outputs = await run(model, inputs)
+    outputs = await run(_predict, model_name, model_type, inputs, **kwargs)
     return ORJSONResponse(outputs)


-@app.post("/index/{index_name}/search")
+@app.post("/index/{index_name}/search", response_class=ORJSONResponse)
 async def search(
     index_name: str, embeddings: np.ndarray[int, np.dtype[np.float32]] = Depends(validate_embeddings), k: int = 10
-) -> None:
+) -> ORJSONResponse:
     if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
         raise HTTPException(404, f"Index '{index_name}' not found")
-    return vector_stores[index_name].search(embeddings, k)[1]  # type: ignore
+    outputs: np.ndarray[int, np.dtype[Any]] = await run(_search, index_name, embeddings, k)
+    return ORJSONResponse(outputs)


-@app.patch("/index/{index_name}/add")
+@app.post("/index/{index_name}/add")
 async def add(
     index_name: str,
     embedding_ids: list[str],
@@ -111,7 +144,7 @@ async def add(
     if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
         await create(index_name, embedding_ids, embeddings)
     else:
-        vector_stores[index_name].add_with_ids(embeddings, embedding_ids)  # type: ignore
+        await run(_add, index_name, embedding_ids, embeddings)


 @app.post("/index/{index_name}/create")
@@ -125,39 +158,25 @@ async def create(
     if index_name in vector_stores:
         log.warn(f"Index '{index_name}' already exists. Overwriting.")

-    hnsw_index = faiss.IndexHNSWFlat(embeddings.shape[1])
-    mapped_index = faiss.IndexIDMap2(hnsw_index)
-
-    def _create() -> faiss.IndexIDMap2:
-        with app.state.index_lock:
-            mapped_index.add_with_ids(embeddings, embedding_ids)  # type: ignore
-        return mapped_index
+    vector_stores[index_name] = await run(_create)

-    vector_stores[index_name] = await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, _create)

-
-async def run(model: InferenceModel, inputs: Any) -> Any:
+async def run(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
     if app.state.thread_pool is None:
-        return model.predict(inputs)
-
-    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
+        return func(*args, **kwargs)
+    if kwargs:
+        func = partial(func, **kwargs)
+    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, func, *args)


-async def load(model: InferenceModel) -> InferenceModel:
+async def _load(model: InferenceModel) -> InferenceModel:
     if model.loaded:
         return model

-    def _load() -> None:
-        with app.state.locks[model.model_type]:
-            model.load()
-
-    loop = asyncio.get_running_loop()
     try:
-        if app.state.thread_pool is None:
-            model.load()
-        else:
-            await loop.run_in_executor(app.state.thread_pool, _load)
-        return model
+        with app.state.model_locks[model.model_type]:
+            if not model.loaded:
+                model.load()
     except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
         log.warn(
             (
@@ -166,8 +185,35 @@ async def load(model: InferenceModel) -> InferenceModel:
             )
         )
         model.clear_cache()
-        if app.state.thread_pool is None:
-            model.load()
-        else:
-            await loop.run_in_executor(app.state.thread_pool, _load)
-        return model
+        model.load()
+    return model
+
+
+async def _add(index_name: str, embedding_ids: list[str], embeddings: np.ndarray[int, np.dtype[np.float32]]) -> None:
+    return await vector_stores[index_name].add_with_ids(embeddings, embedding_ids)  # type: ignore
+
+
+async def _search(
+    index_name: str, embeddings: np.ndarray[int, np.dtype[np.float32]], k: int
+) -> np.ndarray[int, np.dtype[Any]]:
+    return await vector_stores[index_name].assign(embeddings, k)  # type: ignore
+
+
+async def _predict(
+    model_name: str, model_type: ModelType, inputs: Any, **options: Any
+) -> np.ndarray[int, np.dtype[np.float32]]:
+    model = await _load(await app.state.model_cache.get(model_name, model_type, **options))
+    model.configure(**options)
+    return await run(model.predict, inputs)
+
+
+async def _create(
+    embedding_ids: list[str],
+    embeddings: np.ndarray[int, np.dtype[np.float32]],
+) -> faiss.IndexIDMap2:
+    hnsw_index = faiss.IndexHNSWFlat(embeddings.shape[1])
+    mapped_index = faiss.IndexIDMap2(hnsw_index)
+
+    with app.state.index_lock:
+        mapped_index.add_with_ids(embeddings, embedding_ids)  # type: ignore
+    return mapped_index

+ 1 - 1
machine-learning/app/models/clip.py

@@ -102,7 +102,7 @@ class CLIPEncoder(InferenceModel):
             case _:
                 raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")

-        return outputs[0][0].tolist()
+        return outputs[0][0]

     def _get_jina_model_name(self, model_name: str) -> str:
         if model_name in _MODELS: