|
@@ -23,9 +23,7 @@ from .schemas import (
|
|
|
TextResponse,
|
|
|
)
|
|
|
|
|
|
-# import rich.pretty
|
|
|
|
|
|
-# rich.pretty.install()
|
|
|
MultiPartParser.max_file_size = 2**24 # spools to disk if payload is 16 MiB or larger
|
|
|
app = FastAPI()
|
|
|
|
|
@@ -118,17 +116,18 @@ async def pipeline(
|
|
|
|
|
|
outputs = await _predict(model_name, model_type, inputs, **kwargs)
|
|
|
if index_name is not None:
|
|
|
+ expanded = np.expand_dims(outputs, 0)
|
|
|
if k is not None:
|
|
|
if k < 1:
|
|
|
raise HTTPException(400, f"k must be a positive integer; got {k}")
|
|
|
if index_name not in vector_stores:
|
|
|
raise HTTPException(404, f"Index '{index_name}' not found")
|
|
|
- outputs = await run(vector_stores[index_name].search, outputs, k)
|
|
|
+ outputs = await run(vector_stores[index_name].search, expanded, k)
|
|
|
if embedding_id is not None:
|
|
|
if index_name not in vector_stores:
|
|
|
- await create(index_name, [embedding_id], outputs)
|
|
|
+ await create(index_name, [embedding_id], expanded)
|
|
|
else:
|
|
|
- await add(index_name, [embedding_id], outputs)
|
|
|
+ await add(index_name, [embedding_id], expanded)
|
|
|
return ORJSONResponse(outputs)
|
|
|
|
|
|
|