|
@@ -23,26 +23,35 @@ from .schemas import (
|
|
|
TextResponse,
|
|
|
)
|
|
|
|
|
|
+# import rich.pretty
|
|
|
+
|
|
|
+# rich.pretty.install()
|
|
|
# Assigned on the class itself, so the threshold applies to every parser instance.
# 2**24 bytes == 16 MiB.
MultiPartParser.max_file_size = 2**24  # spools to disk if payload is 16 MiB or larger

# Application object; the route decorators (``@app.post``) and the shared
# ``app.state`` used elsewhere in this file hang off this instance.
app = FastAPI()
|
|
|
-vector_stores: dict[str, faiss.IndexIDMap2] = {}
|
|
|
|
|
|
|
|
|
class VectorStore:
    """A FAISS index paired with a lookup from FAISS labels to caller IDs.

    Vectors are added with ``index.add``, and this class assumes the index
    labels them sequentially by insertion position (``ntotal`` before the add,
    then counting up). ``key_to_id`` records which caller-supplied ID belongs
    to each of those integer labels.
    """

    def __init__(self, dims: int, index: str = "HNSW") -> None:
        """Create an empty store.

        Args:
            dims: Dimensionality of the vectors the index will hold.
            index: Description string passed to ``faiss.index_factory``.
        """
        self.index = faiss.index_factory(dims, index)
        # Integer label reported by the index -> ID supplied by the caller.
        self.key_to_id: dict[int, Any] = {}

    def search(self, embeddings: np.ndarray[int, np.dtype[Any]], k: int) -> list[Any]:
        """Return the caller IDs of the ``k`` nearest stored vectors.

        Results for all query rows are flattened into a single list, in the
        order the index returns them.

        Args:
            embeddings: Query vectors, one per row.
            k: Number of neighbours requested per query row.

        Returns:
            The IDs mapped from every returned label other than ``-1``.
        """
        labels: np.ndarray[int, np.dtype[np.int64]] = self.index.assign(embeddings, k)
        # -1 labels are skipped; presumably FAISS's marker for "fewer than k
        # neighbours available" -- they have no entry in key_to_id.
        return [self.key_to_id[label] for row in labels.tolist() for label in row if label != -1]

    def add_with_ids(self, embeddings: np.ndarray[int, np.dtype[Any]], embedding_ids: list[Any]) -> None:
        """Add vectors to the index and remember the ID belonging to each.

        Args:
            embeddings: Vectors to add, one per row.
            embedding_ids: Caller IDs, one per row of ``embeddings``.

        Raises:
            ValueError: If the number of IDs differs from the number of
                vectors; accepting that would desynchronise the label
                mapping from the index contents.
        """
        if len(embeddings) != len(embedding_ids):
            raise ValueError(
                f"Expected one ID per embedding; got {len(embedding_ids)} ID(s) "
                f"and {len(embeddings)} embedding(s)"
            )
        first_label = self.index.ntotal
        # Add to the index first: if it rejects the vectors, the mapping must
        # not be left pointing at labels that were never created.
        self.index.add(embeddings)  # type: ignore
        self.key_to_id |= dict(zip(range(first_label, first_label + len(embedding_ids)), embedding_ids))

    @property
    def dims(self) -> int:
        """Dimensionality of the wrapped index (its ``d`` attribute)."""
        return self.index.d
|
|
|
+
|
|
|
+
|
|
|
# Module-level registry of vector stores, keyed by the index name that the
# request handlers receive as ``index_name``.
vector_stores: dict[str, VectorStore] = {}
|
|
|
+
|
|
|
|
|
|
def validate_embeddings(embeddings: list[float]) -> Any:
|
|
|
embeddings = np.array(embeddings)
|
|
@@ -119,7 +128,7 @@ async def pipeline(
|
|
|
if index_name not in vector_stores:
|
|
|
await create(index_name, [embedding_id], outputs)
|
|
|
else:
|
|
|
- await run(vector_stores[index_name].add, [embedding_id], outputs)
|
|
|
+ await add(index_name, [embedding_id], outputs)
|
|
|
return ORJSONResponse(outputs)
|
|
|
|
|
|
|
|
@@ -148,7 +157,7 @@ async def predict(
|
|
|
|
|
|
@app.post("/index/{index_name}/search", response_class=ORJSONResponse)
async def search(index_name: str, embeddings: Any = Depends(validate_embeddings), k: int = 10) -> ORJSONResponse:
    """Search a named index for the nearest neighbours of the given embeddings.

    Args:
        index_name: Name of the index, taken from the request path.
        embeddings: Query vectors as produced by ``validate_embeddings``;
            ``shape[1]`` is compared against the index dimensionality.
        k: Number of neighbours requested per query vector.

    Returns:
        An ``ORJSONResponse`` wrapping the IDs returned by the store's search.

    Raises:
        HTTPException: 404 if no index with this name exists; 400 if the
            query dimensionality does not match the index.
    """
    # Look the store up once so every check and the search itself use the
    # same object even if the registry entry is replaced while awaiting.
    store = vector_stores.get(index_name)
    if store is None:
        raise HTTPException(404, f"Index '{index_name}' not found")
    if store.dims != embeddings.shape[1]:
        # The index exists; the request is malformed, so do not report "not found".
        raise HTTPException(
            400,
            f"Index '{index_name}' holds {store.dims}-dimensional embeddings; got {embeddings.shape[1]} dimension(s)",
        )
    outputs: list[Any] = await run(store.search, embeddings, k)
    return ORJSONResponse(outputs)
|
|
@@ -160,10 +169,11 @@ async def add(
|
|
|
embedding_ids: list[str],
|
|
|
embeddings: Any = Depends(validate_embeddings),
|
|
|
) -> None:
|
|
|
- if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
|
|
|
+ if index_name not in vector_stores or vector_stores[index_name].dims != embeddings.shape[1]:
|
|
|
await create(index_name, embedding_ids, embeddings)
|
|
|
else:
|
|
|
- await run(vector_stores[index_name].add_with_ids, embeddings, embedding_ids)
|
|
|
+ log.info(f"Adding {len(embedding_ids)} embeddings to index '{index_name}'")
|
|
|
+ await run(_add, vector_stores[index_name], embedding_ids, embeddings)
|
|
|
|
|
|
|
|
|
@app.post("/index/{index_name}/create")
|
|
@@ -173,11 +183,15 @@ async def create(
|
|
|
embeddings: Any = Depends(validate_embeddings),
|
|
|
) -> None:
|
|
|
if embeddings.shape[0] != len(embedding_ids):
|
|
|
- raise HTTPException(400, "Number of embedding IDs must match number of embeddings")
|
|
|
+ raise HTTPException(
|
|
|
+ 400,
|
|
|
+ f"Number of embedding IDs must match number of embeddings; got {len(embedding_ids)} ID(s) and {embeddings.shape[0]} embedding(s)",
|
|
|
+ )
|
|
|
if index_name in vector_stores:
|
|
|
log.warn(f"Index '{index_name}' already exists. Overwriting.")
|
|
|
+ log.info(f"Creating new index '{index_name}'")
|
|
|
|
|
|
- vector_stores[index_name] = await run(_create)
|
|
|
+ vector_stores[index_name] = await run(_create, embedding_ids, embeddings)
|
|
|
|
|
|
|
|
|
async def run(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
|
|
@@ -223,7 +237,14 @@ def _create(
|
|
|
embeddings: np.ndarray[int, np.dtype[np.float32]],
|
|
|
) -> VectorStore:
|
|
|
index = VectorStore(embeddings.shape[1])
|
|
|
+ _add(index, embedding_ids, embeddings)
|
|
|
+ return index
|
|
|
|
|
|
+
|
|
|
def _add(
    index: VectorStore,
    embedding_ids: list[str],
    embeddings: np.ndarray[int, np.dtype[np.float32]],
) -> None:
    """Insert embeddings and their IDs into ``index`` while holding the app's index lock.

    Args:
        index: Store that receives the vectors.
        embedding_ids: IDs to associate with the rows of ``embeddings``.
        embeddings: Vectors to insert.
    """
    index_lock = app.state.index_lock
    # The lock is held for the whole insertion and released on exit,
    # including when the store raises.
    with index_lock:
        index.add_with_ids(embeddings, embedding_ids)  # type: ignore
|
|
|
- return index
|