working pipeline endpoint
This commit is contained in:
parent
3c2265ecf4
commit
ca5f0c7bbd
2 changed files with 37 additions and 16 deletions
|
@ -23,26 +23,35 @@ from .schemas import (
|
|||
TextResponse,
|
||||
)
|
||||
|
||||
# import rich.pretty
|
||||
|
||||
# rich.pretty.install()
|
||||
MultiPartParser.max_file_size = 2**24 # spools to disk if payload is 16 MiB or larger
|
||||
app = FastAPI()
|
||||
vector_stores: dict[str, faiss.IndexIDMap2] = {}
|
||||
|
||||
|
||||
class VectorStore:
|
||||
def __init__(self, dims: int, index_cls: Type[faiss.Index] = faiss.IndexHNSWFlat, **kwargs: Any) -> None:
|
||||
self.index = index_cls(dims, **kwargs)
|
||||
self.id_to_key: dict[int, Any] = {}
|
||||
def __init__(self, dims: int, index: str = "HNSW") -> None:
|
||||
self.index = faiss.index_factory(dims, index)
|
||||
self.key_to_id: dict[int, Any] = {}
|
||||
|
||||
def search(self, embeddings: np.ndarray[int, np.dtype[Any]], k: int) -> list[Any]:
|
||||
ids = self.index.assign(embeddings, k) # type: ignore
|
||||
return [self.id_to_key[idx] for row in ids.tolist() for idx in row if not idx == -1]
|
||||
keys: np.ndarray[int, np.dtype[np.int64]] = self.index.assign(embeddings, k)
|
||||
return [self.key_to_id[idx] for row in keys.tolist() for idx in row if not idx == -1]
|
||||
|
||||
def add_with_ids(self, embeddings: np.ndarray[int, np.dtype[Any]], embedding_ids: list[Any]) -> None:
|
||||
self.id_to_key |= {
|
||||
id: key for id, key in zip(embedding_ids, range(self.index.ntotal, self.index.ntotal + len(embedding_ids)))
|
||||
self.key_to_id |= {
|
||||
key: id for key, id in zip(range(self.index.ntotal, self.index.ntotal + len(embedding_ids)), embedding_ids)
|
||||
}
|
||||
self.index.add(embeddings) # type: ignore
|
||||
|
||||
@property
|
||||
def dims(self) -> int:
|
||||
return self.index.d
|
||||
|
||||
|
||||
vector_stores: dict[str, VectorStore] = {}
|
||||
|
||||
|
||||
def validate_embeddings(embeddings: list[float]) -> Any:
|
||||
embeddings = np.array(embeddings)
|
||||
|
@ -119,7 +128,7 @@ async def pipeline(
|
|||
if index_name not in vector_stores:
|
||||
await create(index_name, [embedding_id], outputs)
|
||||
else:
|
||||
await run(vector_stores[index_name].add, [embedding_id], outputs)
|
||||
await add(index_name, [embedding_id], outputs)
|
||||
return ORJSONResponse(outputs)
|
||||
|
||||
|
||||
|
@ -148,7 +157,7 @@ async def predict(
|
|||
|
||||
@app.post("/index/{index_name}/search", response_class=ORJSONResponse)
|
||||
async def search(index_name: str, embeddings: Any = Depends(validate_embeddings), k: int = 10) -> ORJSONResponse:
|
||||
if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
|
||||
if index_name not in vector_stores or vector_stores[index_name].dims != embeddings.shape[1]:
|
||||
raise HTTPException(404, f"Index '{index_name}' not found")
|
||||
outputs: np.ndarray[int, np.dtype[Any]] = await run(vector_stores[index_name].search, embeddings, k)
|
||||
return ORJSONResponse(outputs)
|
||||
|
@ -160,10 +169,11 @@ async def add(
|
|||
embedding_ids: list[str],
|
||||
embeddings: Any = Depends(validate_embeddings),
|
||||
) -> None:
|
||||
if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
|
||||
if index_name not in vector_stores or vector_stores[index_name].dims != embeddings.shape[1]:
|
||||
await create(index_name, embedding_ids, embeddings)
|
||||
else:
|
||||
await run(vector_stores[index_name].add_with_ids, embeddings, embedding_ids)
|
||||
log.info(f"Adding {len(embedding_ids)} embeddings to index '{index_name}'")
|
||||
await run(_add, vector_stores[index_name], embedding_ids, embeddings)
|
||||
|
||||
|
||||
@app.post("/index/{index_name}/create")
|
||||
|
@ -173,11 +183,15 @@ async def create(
|
|||
embeddings: Any = Depends(validate_embeddings),
|
||||
) -> None:
|
||||
if embeddings.shape[0] != len(embedding_ids):
|
||||
raise HTTPException(400, "Number of embedding IDs must match number of embeddings")
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"Number of embedding IDs must match number of embeddings; got {len(embedding_ids)} ID(s) and {embeddings.shape[0]} embedding(s)",
|
||||
)
|
||||
if index_name in vector_stores:
|
||||
log.warn(f"Index '{index_name}' already exists. Overwriting.")
|
||||
log.info(f"Creating new index '{index_name}'")
|
||||
|
||||
vector_stores[index_name] = await run(_create)
|
||||
vector_stores[index_name] = await run(_create, embedding_ids, embeddings)
|
||||
|
||||
|
||||
async def run(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
|
||||
|
@ -223,7 +237,14 @@ def _create(
|
|||
embeddings: np.ndarray[int, np.dtype[np.float32]],
|
||||
) -> VectorStore:
|
||||
index = VectorStore(embeddings.shape[1])
|
||||
_add(index, embedding_ids, embeddings)
|
||||
return index
|
||||
|
||||
|
||||
def _add(
|
||||
index: VectorStore,
|
||||
embedding_ids: list[str],
|
||||
embeddings: np.ndarray[int, np.dtype[np.float32]],
|
||||
) -> None:
|
||||
with app.state.index_lock:
|
||||
index.add_with_ids(embeddings, embedding_ids) # type: ignore
|
||||
return index
|
||||
|
|
|
@ -102,7 +102,7 @@ class CLIPEncoder(InferenceModel):
|
|||
case _:
|
||||
raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
|
||||
|
||||
return outputs[0][0]
|
||||
return outputs[0]
|
||||
|
||||
def _get_jina_model_name(self, model_name: str) -> str:
|
||||
if model_name in _MODELS:
|
||||
|
|
Loading…
Add table
Reference in a new issue