Browse Source

working pipeline endpoint

mertalev 1 year ago
parent
commit
ca5f0c7bbd
2 changed files with 37 additions and 16 deletions
  1. 36 15
      machine-learning/app/main.py
  2. 1 1
      machine-learning/app/models/clip.py

+ 36 - 15
machine-learning/app/main.py

@@ -23,26 +23,35 @@ from .schemas import (
     TextResponse,
 )
 
+# import rich.pretty
+
+# rich.pretty.install()
 MultiPartParser.max_file_size = 2**24  # spools to disk if payload is 16 MiB or larger
 app = FastAPI()
-vector_stores: dict[str, faiss.IndexIDMap2] = {}
 
 
 class VectorStore:
-    def __init__(self, dims: int, index_cls: Type[faiss.Index] = faiss.IndexHNSWFlat, **kwargs: Any) -> None:
-        self.index = index_cls(dims, **kwargs)
-        self.id_to_key: dict[int, Any] = {}
+    def __init__(self, dims: int, index: str = "HNSW") -> None:
+        self.index = faiss.index_factory(dims, index)
+        self.key_to_id: dict[int, Any] = {}
 
     def search(self, embeddings: np.ndarray[int, np.dtype[Any]], k: int) -> list[Any]:
-        ids = self.index.assign(embeddings, k)  # type: ignore
-        return [self.id_to_key[idx] for row in ids.tolist() for idx in row if not idx == -1]
+        keys: np.ndarray[int, np.dtype[np.int64]] = self.index.assign(embeddings, k)
+        return [self.key_to_id[idx] for row in keys.tolist() for idx in row if not idx == -1]
 
     def add_with_ids(self, embeddings: np.ndarray[int, np.dtype[Any]], embedding_ids: list[Any]) -> None:
-        self.id_to_key |= {
-            id: key for id, key in zip(embedding_ids, range(self.index.ntotal, self.index.ntotal + len(embedding_ids)))
+        self.key_to_id |= {
+            key: id for key, id in zip(range(self.index.ntotal, self.index.ntotal + len(embedding_ids)), embedding_ids)
         }
         self.index.add(embeddings)  # type: ignore
 
+    @property
+    def dims(self) -> int:
+        return self.index.d
+
+
+vector_stores: dict[str, VectorStore] = {}
+
 
 def validate_embeddings(embeddings: list[float]) -> Any:
     embeddings = np.array(embeddings)
@@ -119,7 +128,7 @@ async def pipeline(
             if index_name not in vector_stores:
                 await create(index_name, [embedding_id], outputs)
             else:
-                await run(vector_stores[index_name].add, [embedding_id], outputs)
+                await add(index_name, [embedding_id], outputs)
     return ORJSONResponse(outputs)
 
 
@@ -148,7 +157,7 @@ async def predict(
 
 @app.post("/index/{index_name}/search", response_class=ORJSONResponse)
 async def search(index_name: str, embeddings: Any = Depends(validate_embeddings), k: int = 10) -> ORJSONResponse:
-    if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
+    if index_name not in vector_stores or vector_stores[index_name].dims != embeddings.shape[1]:
         raise HTTPException(404, f"Index '{index_name}' not found")
     outputs: np.ndarray[int, np.dtype[Any]] = await run(vector_stores[index_name].search, embeddings, k)
     return ORJSONResponse(outputs)
@@ -160,10 +169,11 @@ async def add(
     embedding_ids: list[str],
     embeddings: Any = Depends(validate_embeddings),
 ) -> None:
-    if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
+    if index_name not in vector_stores or vector_stores[index_name].dims != embeddings.shape[1]:
         await create(index_name, embedding_ids, embeddings)
     else:
-        await run(vector_stores[index_name].add_with_ids, embeddings, embedding_ids)
+        log.info(f"Adding {len(embedding_ids)} embeddings to index '{index_name}'")
+        await run(_add, vector_stores[index_name], embedding_ids, embeddings)
 
 
 @app.post("/index/{index_name}/create")
@@ -173,11 +183,15 @@ async def create(
     embeddings: Any = Depends(validate_embeddings),
 ) -> None:
     if embeddings.shape[0] != len(embedding_ids):
-        raise HTTPException(400, "Number of embedding IDs must match number of embeddings")
+        raise HTTPException(
+            400,
+            f"Number of embedding IDs must match number of embeddings; got {len(embedding_ids)} ID(s) and {embeddings.shape[0]} embedding(s)",
+        )
     if index_name in vector_stores:
         log.warn(f"Index '{index_name}' already exists. Overwriting.")
+    log.info(f"Creating new index '{index_name}'")
 
-    vector_stores[index_name] = await run(_create)
+    vector_stores[index_name] = await run(_create, embedding_ids, embeddings)
 
 
 async def run(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
@@ -223,7 +237,14 @@ def _create(
     embeddings: np.ndarray[int, np.dtype[np.float32]],
 ) -> VectorStore:
     index = VectorStore(embeddings.shape[1])
+    _add(index, embedding_ids, embeddings)
+    return index
 
+
+def _add(
+    index: VectorStore,
+    embedding_ids: list[str],
+    embeddings: np.ndarray[int, np.dtype[np.float32]],
+) -> None:
     with app.state.index_lock:
         index.add_with_ids(embeddings, embedding_ids)  # type: ignore
-    return index

+ 1 - 1
machine-learning/app/models/clip.py

@@ -102,7 +102,7 @@ class CLIPEncoder(InferenceModel):
             case _:
                 raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
 
-        return outputs[0][0]
+        return outputs[0]
 
     def _get_jina_model_name(self, model_name: str) -> str:
         if model_name in _MODELS: