working pipeline endpoint

mertalev 2023-09-16 00:51:47 -04:00
parent 3c2265ecf4
commit ca5f0c7bbd
2 changed files with 37 additions and 16 deletions


@@ -23,26 +23,35 @@ from .schemas import (
     TextResponse,
 )
 
 # import rich.pretty
 # rich.pretty.install()
 MultiPartParser.max_file_size = 2**24  # spools to disk if payload is 16 MiB or larger
 
 app = FastAPI()
-vector_stores: dict[str, faiss.IndexIDMap2] = {}
 
 
 class VectorStore:
-    def __init__(self, dims: int, index_cls: Type[faiss.Index] = faiss.IndexHNSWFlat, **kwargs: Any) -> None:
-        self.index = index_cls(dims, **kwargs)
-        self.id_to_key: dict[int, Any] = {}
+    def __init__(self, dims: int, index: str = "HNSW") -> None:
+        self.index = faiss.index_factory(dims, index)
+        self.key_to_id: dict[int, Any] = {}
 
     def search(self, embeddings: np.ndarray[int, np.dtype[Any]], k: int) -> list[Any]:
-        ids = self.index.assign(embeddings, k)  # type: ignore
-        return [self.id_to_key[idx] for row in ids.tolist() for idx in row if not idx == -1]
+        keys: np.ndarray[int, np.dtype[np.int64]] = self.index.assign(embeddings, k)
+        return [self.key_to_id[idx] for row in keys.tolist() for idx in row if not idx == -1]
 
     def add_with_ids(self, embeddings: np.ndarray[int, np.dtype[Any]], embedding_ids: list[Any]) -> None:
-        self.id_to_key |= {
-            id: key for id, key in zip(embedding_ids, range(self.index.ntotal, self.index.ntotal + len(embedding_ids)))
+        self.key_to_id |= {
+            key: id for key, id in zip(range(self.index.ntotal, self.index.ntotal + len(embedding_ids)), embedding_ids)
         }
         self.index.add(embeddings)  # type: ignore
 
+    @property
+    def dims(self) -> int:
+        return self.index.d
+
+
+vector_stores: dict[str, VectorStore] = {}
 
 
 def validate_embeddings(embeddings: list[float]) -> Any:
     embeddings = np.array(embeddings)
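The key_to_id bookkeeping above exists because a FAISS index only returns the sequential positions of its stored vectors, not caller-supplied IDs. A minimal standalone sketch of the same idea, with invented 4-dimensional vectors and "asset-*" IDs (not from the commit):

import faiss
import numpy as np

# Build an HNSW index through the string factory, as VectorStore.__init__ does.
index = faiss.index_factory(4, "HNSW")
key_to_id: dict[int, str] = {}

embeddings = np.random.rand(3, 4).astype(np.float32)
embedding_ids = ["asset-a", "asset-b", "asset-c"]

# FAISS numbers vectors sequentially from index.ntotal, so record the mapping
# from those positions back to the caller-supplied IDs before adding.
key_to_id |= dict(zip(range(index.ntotal, index.ntotal + len(embedding_ids)), embedding_ids))
index.add(embeddings)

# assign() returns the positions of the k nearest vectors; -1 marks empty slots.
positions = index.assign(embeddings[:1], k=2)
print([key_to_id[p] for row in positions.tolist() for p in row if p != -1])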
@@ -119,7 +128,7 @@ async def pipeline(
     if index_name not in vector_stores:
         await create(index_name, [embedding_id], outputs)
     else:
-        await run(vector_stores[index_name].add, [embedding_id], outputs)
+        await add(index_name, [embedding_id], outputs)
 
     return ORJSONResponse(outputs)
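Awaiting add() directly works because FastAPI route handlers are ordinary coroutines; calling one bypasses HTTP but runs the same logic. Note that Depends defaults are not resolved on a direct call, so arguments must be passed explicitly, as the pipeline endpoint does. A tiny illustration with invented names:

from fastapi import FastAPI

app = FastAPI()

@app.post("/ping")
async def ping() -> dict[str, bool]:
    return {"ok": True}

async def caller() -> None:
    # The decorated function is still awaitable like any other coroutine.
    result = await ping()
    assert result == {"ok": True}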
@@ -148,7 +157,7 @@ async def predict(
 
 @app.post("/index/{index_name}/search", response_class=ORJSONResponse)
 async def search(index_name: str, embeddings: Any = Depends(validate_embeddings), k: int = 10) -> ORJSONResponse:
-    if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
+    if index_name not in vector_stores or vector_stores[index_name].dims != embeddings.shape[1]:
         raise HTTPException(404, f"Index '{index_name}' not found")
     outputs: np.ndarray[int, np.dtype[Any]] = await run(vector_stores[index_name].search, embeddings, k)
     return ORJSONResponse(outputs)
@@ -160,10 +169,11 @@ async def add(
     embedding_ids: list[str],
     embeddings: Any = Depends(validate_embeddings),
 ) -> None:
-    if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
+    if index_name not in vector_stores or vector_stores[index_name].dims != embeddings.shape[1]:
         await create(index_name, embedding_ids, embeddings)
     else:
-        await run(vector_stores[index_name].add_with_ids, embeddings, embedding_ids)
+        log.info(f"Adding {len(embedding_ids)} embeddings to index '{index_name}'")
+        await run(_add, vector_stores[index_name], embedding_ids, embeddings)
 
 
 @app.post("/index/{index_name}/create")
@@ -173,11 +183,15 @@ async def create(
     embeddings: Any = Depends(validate_embeddings),
 ) -> None:
     if embeddings.shape[0] != len(embedding_ids):
-        raise HTTPException(400, "Number of embedding IDs must match number of embeddings")
+        raise HTTPException(
+            400,
+            f"Number of embedding IDs must match number of embeddings; got {len(embedding_ids)} ID(s) and {embeddings.shape[0]} embedding(s)",
+        )
     if index_name in vector_stores:
         log.warn(f"Index '{index_name}' already exists. Overwriting.")
-    vector_stores[index_name] = await run(_create)
+    log.info(f"Creating new index '{index_name}'")
+    vector_stores[index_name] = await run(_create, embedding_ids, embeddings)
 
 
 async def run(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
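run() itself is unchanged in this commit and its body lies outside the hunk; the endpoints above use it to keep blocking FAISS calls off the event loop. One plausible shape for such a helper, offered purely as an assumption rather than the commit's actual code:

import asyncio
from functools import partial
from typing import Any, Callable

async def run(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    # Hypothetical implementation: dispatch the blocking callable to the default
    # thread-pool executor so FAISS work does not stall the FastAPI event loop.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, partial(func, *args, **kwargs))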
@@ -223,7 +237,14 @@ def _create(
     embeddings: np.ndarray[int, np.dtype[np.float32]],
 ) -> VectorStore:
     index = VectorStore(embeddings.shape[1])
+    _add(index, embedding_ids, embeddings)
+    return index
+
+
+def _add(
+    index: VectorStore,
+    embedding_ids: list[str],
+    embeddings: np.ndarray[int, np.dtype[np.float32]],
+) -> None:
     with app.state.index_lock:
         index.add_with_ids(embeddings, embedding_ids)  # type: ignore
-    return index
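_add acquires app.state.index_lock because run() may execute it on a worker thread, and concurrent writes to the same FAISS index are not safe. How the lock is created is not shown in this diff; a plausible setup, stated only as an assumption:

import threading
from fastapi import FastAPI

app = FastAPI()

@app.on_event("startup")
def _init_state() -> None:
    # Assumed setup: a process-wide lock so concurrent /add and /create requests
    # cannot interleave add_with_ids() calls on the same FAISS index.
    app.state.index_lock = threading.Lock()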


@@ -102,7 +102,7 @@ class CLIPEncoder(InferenceModel):
             case _:
                 raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
 
-        return outputs[0][0]
+        return outputs[0]
 
     def _get_jina_model_name(self, model_name: str) -> str:
         if model_name in _MODELS:
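The second file's change drops one level of indexing on the model outputs. Assuming the encoder wraps an ONNX Runtime session (the inference call itself is outside this diff, so this is an assumption), the difference is roughly:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("clip_visual.onnx")  # hypothetical model path
# Hypothetical input name and shape; session.run() returns one array per model output.
outputs = session.run(None, {"image": np.zeros((1, 3, 224, 224), dtype=np.float32)})

embedding_batch = outputs[0]      # first model output: shape (batch, dims)
single_embedding = outputs[0][0]  # first embedding in that batch only: shape (dims,)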