we have dynamic batching at home
parent d4c23c8df8
commit 57506aa1fe

10 changed files with 280 additions and 76 deletions
@@ -21,6 +21,8 @@ class Settings(BaseSettings):
     request_threads: int = os.cpu_count() or 4
     model_inter_op_threads: int = 1
     model_intra_op_threads: int = 2
+    max_batch_size: int = 8
+    batch_timeout_s: float = 0.005

     class Config:
         env_prefix = "MACHINE_LEARNING_"
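The two new settings inherit the existing `MACHINE_LEARNING_` prefix from `Config`, so they can presumably be overridden via `MACHINE_LEARNING_MAX_BATCH_SIZE` and `MACHINE_LEARNING_BATCH_TIMEOUT_S`; those variable names are inferred from the pydantic prefix, not shown elsewhere in the diff. A minimal standalone sketch of that behavior, trimmed to just the new fields and assuming a pydantic-v1-style `BaseSettings`:

import os

from pydantic import BaseSettings


class Settings(BaseSettings):
    max_batch_size: int = 8
    batch_timeout_s: float = 0.005

    class Config:
        env_prefix = "MACHINE_LEARNING_"


# the prefixed environment variable overrides the default
os.environ["MACHINE_LEARNING_MAX_BATCH_SIZE"] = "16"
print(Settings().max_batch_size)  # -> 16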
@@ -11,6 +11,7 @@ from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchF
 from starlette.formparsers import MultiPartParser

 from app.models.base import InferenceModel
+from app.models.batcher import Batcher, ModelBatcher

 from .config import log, settings
 from .models.cache import ModelCache

@@ -26,6 +27,7 @@ app = FastAPI()


 def init_state() -> None:
     app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+    app.state.model_batcher = ModelBatcher(max_size=settings.max_batch_size, timeout_s=settings.batch_timeout_s)
     log.info(
         (
             "Created in-memory cache with unloading "

@@ -62,9 +64,9 @@ async def predict(
     image: UploadFile | None = None,
 ) -> Any:
     if image is not None:
-        inputs: str | bytes = await image.read()
+        element: str | bytes = await image.read()
     elif text is not None:
-        inputs = text
+        element = text
     else:
         raise HTTPException(400, "Either image or text must be provided")
     try:

@@ -74,15 +76,16 @@ async def predict(

     model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
     model.configure(**kwargs)
-    outputs = await run(model, inputs)
+    batcher: Batcher = app.state.model_batcher.get(model_name, model_type, **kwargs)
+    outputs = await batcher.batch_process(element, run, model)
     return ORJSONResponse(outputs)


-async def run(model: InferenceModel, inputs: Any) -> Any:
+async def run(model: InferenceModel, elements: list[Any]) -> Any:
     if app.state.thread_pool is None:
-        return model.predict(inputs)
+        return model.predict_batch(elements)

-    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
+    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict_batch, elements)


 async def load(model: InferenceModel) -> InferenceModel:
@@ -12,6 +12,10 @@ from ..config import get_cache_dir, log, settings
 from ..schemas import ModelType


+def get_model_key(model_name: str, model_type: ModelType, mode: str = "") -> str:
+    return f"{model_name}{model_type.value}{mode}"
+
+
 class InferenceModel(ABC):
     _model_type: ModelType

@@ -65,14 +69,15 @@ class InferenceModel(ABC):
         self._load()
         self.loaded = True

-    def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
+    def predict(self, element: Any) -> Any:
+        return self.predict_batch([element])[0]
+
+    def predict_batch(self, inputs: list[Any]) -> list[Any]:
         self.load()
-        if model_kwargs:
-            self.configure(**model_kwargs)
-        return self._predict(inputs)
+        return self._predict_batch(inputs)

     @abstractmethod
-    def _predict(self, inputs: Any) -> Any:
+    def _predict_batch(self, inputs: list[Any]) -> Any:
         ...

     def configure(self, **model_kwargs: Any) -> None:
machine-learning/app/models/batcher.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+import asyncio
+import time
+from typing import Any, Awaitable, Callable, TypeVar
+
+from app.schemas import ModelType
+
+from .base import get_model_key, log
+
+
+F = TypeVar("F")
+P = TypeVar("P")
+R = TypeVar("R")
+
+
+class Batcher:
+    def __init__(self, max_size: int = 16, timeout_s: float = 0.005) -> None:
+        self.max_size = max_size
+        self.timeout_s = timeout_s
+        self.queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=max_size)
+        self.lock = asyncio.Lock()
+        self.processing: dict[int, Any] = {}
+        self.processed: dict[int, Any] = {}
+        self.element_id = 0
+
+    async def batch_process(self, element: Any, func: Callable[[list[P]], Awaitable[list[R]]], *args: Any, **kwargs: Any) -> Any:
+        cur_idx = self.element_id
+        self.element_id += 1
+        self.processing[cur_idx] = element
+        await self.queue.put(cur_idx)
+        try:
+            async with self.lock:
+                await self._batch(cur_idx)
+                if cur_idx not in self.processed:
+                    await self._process(func, *args, **kwargs)
+                return self.processed.pop(cur_idx)
+        finally:
+            self.processing.pop(cur_idx, None)
+            self.processed.pop(cur_idx, None)
+
+    async def _batch(self, idx: int) -> list[Any]:
+        if idx not in self.processed:
+            start = time.monotonic()
+            while self.queue.qsize() < self.max_size and time.monotonic() - start < self.timeout_s:
+                await asyncio.sleep(0)
+
+    async def _process(self, func: Callable[[list[P]], Awaitable[list[R]]], *args, **kwargs) -> None:
+        batch_ids = [self.queue.get_nowait() for _ in range(self.queue.qsize())]
+        batch = [self.processing.pop(id) for id in batch_ids]
+        outputs = await func(*args, batch, **kwargs)
+        for id, output in zip(batch_ids, outputs):
+            self.processed[id] = output
+
+
+class ModelBatcher:
+    def __init__(self, max_size: int = 16, timeout_s: float = 0.005) -> None:
+        self.batchers = {}
+        self.max_size = max_size
+        self.timeout_s = timeout_s
+
+    def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any):
+        key = get_model_key(model_name, model_type, model_kwargs.get("mode", ""))
+        if key not in self.batchers:
+            self.batchers[key] = Batcher(max_size=self.max_size, timeout_s=self.timeout_s)
+
+        return self.batchers[key]
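A rough usage sketch of the new `Batcher` (hypothetical, not part of the commit): each concurrent caller hands in one element plus a function that accepts the whole batch, mirroring how the predict endpoint above passes `run` and the loaded model. The first caller to take the lock waits up to `timeout_s` for the queue to fill, `_process` runs the function once, and every caller pops back its own result.

import asyncio

from app.models.batcher import Batcher  # the module added in this commit


async def square_batch(elements: list[int]) -> list[int]:
    # stands in for run(model, elements): a single call serves the whole batch
    return [x * x for x in elements]


async def main() -> None:
    batcher = Batcher(max_size=8, timeout_s=0.005)
    # three concurrent callers are coalesced into (at most) one batched call
    results = await asyncio.gather(*(batcher.batch_process(i, square_batch) for i in range(3)))
    print(results)  # -> [0, 1, 4]


asyncio.run(main())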
@@ -5,7 +5,7 @@ from aiocache.lock import OptimisticLock
 from aiocache.plugins import BasePlugin, TimingPlugin

 from ..schemas import ModelType
-from .base import InferenceModel
+from .base import InferenceModel, get_model_key


 class ModelCache:

@@ -46,7 +46,7 @@ class ModelCache:
             model: The requested model.
         """
-        key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
+        key = get_model_key(model_name, model_type, model_kwargs.get("mode", ""))
         async with OptimisticLock(self.cache, key) as lock:
             model = await self.cache.get(key)
             if model is None:
@@ -1,8 +1,9 @@
 import os
 import zipfile
 from io import BytesIO
-from typing import Any, Literal
+from typing import Any, Literal, TypeGuard

+import numpy as np
 import onnxruntime as ort
 import torch
 from clip_server.model.clip import BICUBIC, _convert_image_to_rgb

@@ -17,6 +18,14 @@ from ..schemas import ModelType
 from .base import InferenceModel


+def is_image_list(images: list[Any]) -> TypeGuard[list[Image.Image | bytes]]:
+    return any(isinstance(image, (Image.Image, bytes)) for image in images)
+
+
+def is_text_list(texts: list[Any]) -> TypeGuard[list[str]]:
+    return any(isinstance(text, str) for text in texts)
+
+
 class CLIPEncoder(InferenceModel):
     _model_type = ModelType.CLIP

@@ -70,31 +79,42 @@ class CLIPEncoder(InferenceModel):
         image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)]
         self.transform = _transform_pil_image(image_size)

-    def _predict(self, image_or_text: Image.Image | str) -> list[float]:
-        if isinstance(image_or_text, bytes):
-            image_or_text = Image.open(BytesIO(image_or_text))
-
-        match image_or_text:
-            case Image.Image():
-                if self.mode == "text":
-                    raise TypeError("Cannot encode image as text-only model")
-                pixel_values = self.transform(image_or_text)
-                assert isinstance(pixel_values, torch.Tensor)
-                pixel_values = torch.unsqueeze(pixel_values, 0).numpy()
-                outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
-            case str():
-                if self.mode == "vision":
-                    raise TypeError("Cannot encode text as vision-only model")
-                text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text)
-                inputs = {
-                    "input_ids": text_inputs["input_ids"].int().numpy(),
-                    "attention_mask": text_inputs["attention_mask"].int().numpy(),
-                }
-                outputs = self.text_model.run(self.text_outputs, inputs)
-            case _:
-                raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
-
-        return outputs[0][0].tolist()
+    def _predict_batch(self, images_or_text: list[Image.Image | bytes] | list[str]) -> list[list[float]]:
+        if not images_or_text:
+            return []
+
+        if is_image_list(images_or_text):
+            outputs = self._predict_images(images_or_text)
+        elif is_text_list(images_or_text):
+            outputs = self._predict_text(images_or_text)
+        else:
+            raise TypeError(f"Expected list of images or text, but got: {type(images_or_text[0])}")
+
+        return outputs
+
+    def _predict_images(self, images: list[Image.Image | bytes]) -> list[list[float]]:
+        if not images:
+            return []
+
+        for i, element in enumerate(images):
+            if isinstance(element, bytes):
+                images[i] = Image.open(BytesIO(element))
+
+        pixel_values = torch.stack([self.transform(image) for image in images]).numpy()
+        outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
+        return outputs[0].tolist()
+
+    def _predict_text(self, texts: list[str]) -> list[list[float]]:
+        if not texts:
+            return []
+
+        text_inputs: dict[str, torch.Tensor] = self.tokenizer(texts)
+        inputs = {
+            "input_ids": text_inputs["input_ids"].int().numpy(),
+            "attention_mask": text_inputs["attention_mask"].int().numpy(),
+        }
+        outputs = self.text_model.run(self.text_outputs, inputs)
+        return outputs[0].tolist()

     def _download_model(self, model_name: str, model_md5: str) -> bool:
         # downloading logic is adapted from clip-server's CLIPOnnxModel class
@@ -9,9 +9,12 @@ from insightface.model_zoo import ArcFaceONNX, RetinaFace
 from insightface.utils.face_align import norm_crop
 from insightface.utils.storage import BASE_REPO_URL, download_file

-from ..schemas import ModelType
+from ..schemas import ModelType, ndarray
 from .base import InferenceModel

+import onnx
+from onnx.tools import update_model_dims
+

 class FaceRecognizer(InferenceModel):
     _model_type = ModelType.FACIAL_RECOGNITION

@@ -36,6 +39,8 @@ class FaceRecognizer(InferenceModel):
             zip.extractall(self.cache_dir, members=[det_file, rec_file])
         zip_file.unlink()

+        self._add_batch_dimension(self.cache_dir / rec_file)
+
     def _load(self) -> None:
         try:
             det_file = next(self.cache_dir.glob("det_*.onnx"))

@@ -43,29 +48,36 @@ class FaceRecognizer(InferenceModel):
         except StopIteration:
             raise FileNotFoundError("Facial recognition models not found in cache directory")

-        self.det_model = RetinaFace(
-            session=ort.InferenceSession(
-                det_file.as_posix(),
-                sess_options=self.sess_options,
-                providers=self.providers,
-                provider_options=self.provider_options,
-            ),
-        )
-        self.rec_model = ArcFaceONNX(
-            rec_file.as_posix(),
-            session=ort.InferenceSession(
-                rec_file.as_posix(),
-                sess_options=self.sess_options,
-                providers=self.providers,
-                provider_options=self.provider_options,
-            ),
-        )
+        det_session = ort.InferenceSession(
+            det_file.as_posix(),
+            sess_options=self.sess_options,
+            providers=self.providers,
+            provider_options=self.provider_options,
+        )
+        self.det_model = RetinaFace(session=det_session)
+
         self.det_model.prepare(
             ctx_id=0,
             det_thresh=self.min_score,
             input_size=(640, 640),
         )
+
+        rec_session = ort.InferenceSession(
+            rec_file.as_posix(),
+            sess_options=self.sess_options,
+            providers=self.providers,
+            provider_options=self.provider_options,
+        )
+        print(rec_session.get_inputs())
+        if rec_session.get_inputs()[0].shape[0] != "batch":
+            del rec_session
+            self._add_batch_dimension(rec_file)
+            rec_session = ort.InferenceSession(
+                rec_file.as_posix(),
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            )
+        self.rec_model = ArcFaceONNX(rec_file.as_posix(), session=rec_session)
         self.rec_model.prepare(ctx_id=0)

     def _predict(self, image: np.ndarray[int, np.dtype[Any]] | bytes) -> list[dict[str, Any]]:
@@ -100,6 +112,90 @@ class FaceRecognizer(InferenceModel):
         )
         return results

+    def _predict_batch(self, images: list[cv2.Mat]) -> list[list[dict[str, Any]]]:
+        for i, image in enumerate(images):
+            if isinstance(image, bytes):
+                images[i] = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
+
+        batch_det, batch_kpss = self._detect(images)
+        batch_cropped_images, batch_offsets = self._preprocess(images, batch_kpss)
+        if batch_cropped_images:
+            batch_embeddings = self.rec_model.get_feat(images)
+            results = self._postprocess(images, batch_det, batch_embeddings, batch_offsets)
+        else:
+            results = self._postprocess(images, batch_det)
+        return results
+
+    def _detect(self, images: list[cv2.Mat]) -> tuple[list[ndarray], ...]:
+        batch_det: list[ndarray] = []
+        batch_kpss: list[ndarray] = []
+        # detection model doesn't support batching, but recognition model does
+        for image in images:
+            bboxes, kpss = self.det_model.detect(image)
+            batch_det.append(bboxes)
+            batch_kpss.append(kpss)
+        return batch_det, batch_kpss
+
+    def _preprocess(self, images: list[cv2.Mat], batch_kpss: list[ndarray]) -> tuple[list[cv2.Mat], list[int]]:
+        batch_cropped_images = []
+        batch_offsets = []
+        total_faces = 0
+        for i, image in enumerate(images):
+            kpss = batch_kpss[i]
+            total_faces += kpss.shape[0]
+            batch_offsets.append(total_faces)
+            for kps in kpss:
+                batch_cropped_images.append(norm_crop(image, kps))
+        return batch_cropped_images, batch_offsets
+
+    def _postprocess(
+        self,
+        images: list[cv2.Mat],
+        batch_det: list[ndarray],
+        batch_embeddings: ndarray | None = None,
+        batch_offsets: list[int] | None = None,
+    ) -> list[list[dict[str, Any]]]:
+        if batch_embeddings is not None and batch_offsets is not None:
+            image_embeddings: list[ndarray] | None = np.array_split(batch_embeddings, batch_offsets)
+        else:
+            image_embeddings = None
+
+        batch_faces: list[list[dict[str, Any]]] = []
+        for i, image in enumerate(images):
+            faces: list[dict[str, Any]] = []
+            batch_faces.append(faces)
+            if image_embeddings is None or image_embeddings[i].shape[0] == 0:
+                continue
+
+            height, width, _ = image.shape
+
+            embeddings = image_embeddings[i].tolist()
+            bboxes = batch_det[i][:, :4].round().tolist()
+            det_scores = batch_det[i][:, 4].tolist()
+            for (x1, y1, x2, y2), embedding, det_score in zip(bboxes, embeddings, det_scores):
+                face = {
+                    "imageWidth": width,
+                    "imageHeight": height,
+                    "boundingBox": {
+                        "x1": x1,
+                        "y1": y1,
+                        "x2": x2,
+                        "y2": y2,
+                    },
+                    "score": det_score,
+                    "embedding": embedding,
+                }
+
+                faces.append(face)
+        return batch_faces
+
+    def _add_batch_dimension(self, model_path: Path) -> None:
+        rec_proto = onnx.load(model_path.as_posix())
+        inputs = {input.name: ['batch'] + [shape.dim_value for shape in input.type.tensor_type.shape.dim[1:]] for input in rec_proto.graph.input}
+        outputs = {output.name: ['batch'] + [shape.dim_value for shape in output.type.tensor_type.shape.dim[1:]] for output in rec_proto.graph.output}
+        rec_proto = update_model_dims.update_inputs_outputs_dims(rec_proto, inputs, outputs)
+        onnx.save(rec_proto, model_path.open("wb"))
+
     @property
     def cached(self) -> bool:
         return self.cache_dir.is_dir() and any(self.cache_dir.glob("*.onnx"))
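The `_add_batch_dimension` rewrite above leans on `onnx.tools.update_model_dims` renaming the leading dimension of every input and output to the symbolic name "batch", which is what the `_load` check compares against. A hedged, standalone sketch (placeholder path, not a file name from this repo) of reading that dimension back from the saved proto:

import onnx

# placeholder path; in the service this would be the cached rec_*.onnx file
proto = onnx.load("path/to/rec_model.onnx")
dim0 = proto.graph.input[0].type.tensor_type.shape.dim[0]
print(dim0.dim_param or dim0.dim_value)  # expected to print "batch" after the rewrite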
@@ -63,13 +63,18 @@ class ImageClassifier(InferenceModel):
             feature_extractor=processor,
         )

-    def _predict(self, image: Image.Image | bytes) -> list[str]:
-        if isinstance(image, bytes):
-            image = Image.open(BytesIO(image))
-        predictions: list[dict[str, Any]] = self.model(image)  # type: ignore
-        tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
-
-        return tags
+    def _predict_batch(self, images: list[Image.Image | bytes]) -> list[list[str]]:
+        for i, image in enumerate(images):
+            if isinstance(image, bytes):
+                images[i] = Image.open(BytesIO(image))
+
+        batch_predictions: list[list[dict[str, Any]]] = self.model(images)
+        results = [self._postprocess(predictions) for predictions in batch_predictions]
+
+        return results
+
+    def _postprocess(self, predictions: list[dict[str, Any]]) -> list[str]:
+        return [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]

     def configure(self, **model_kwargs: Any) -> None:
         self.min_score = model_kwargs.pop("minScore", self.min_score)
@@ -1,5 +1,7 @@
 from enum import StrEnum
+from typing import Any, TypeAlias

+import numpy as np
 from pydantic import BaseModel

@@ -31,3 +33,6 @@ class ModelType(StrEnum):
     IMAGE_CLASSIFICATION = "image-classification"
     CLIP = "clip"
     FACIAL_RECOGNITION = "facial-recognition"
+
+
+ndarray: TypeAlias = np.ndarray[int, np.dtype[Any]]
@@ -22,16 +22,17 @@ from .schemas import ModelType
 ndarray: TypeAlias = np.ndarray[int, np.dtype[np.float32]]


+@pytest.mark.asyncio
 class TestImageClassifier:
-    classifier_preds = [
+    classifier_preds = [[
         {"label": "that's an image alright", "score": 0.8},
         {"label": "well it ends with .jpg", "score": 0.1},
         {"label": "idk, im just seeing bytes", "score": 0.05},
         {"label": "not sure", "score": 0.04},
         {"label": "probably a virus", "score": 0.01},
-    ]
+    ]]

-    def test_min_score(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
+    async def test_min_score(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
         mocker.patch.object(ImageClassifier, "load")
         classifier = ImageClassifier("test_model_name", min_score=0.0)
         assert classifier.min_score == 0.0

@@ -39,9 +40,9 @@ class TestImageClassifier:
         classifier.model = mock.Mock()
         classifier.model.return_value = self.classifier_preds

-        all_labels = classifier.predict(pil_image)
+        all_labels = await classifier.predict(pil_image)
         classifier.min_score = 0.5
-        filtered_labels = classifier.predict(pil_image)
+        filtered_labels = await classifier.predict(pil_image)

         assert all_labels == [
             "that's an image alright",

@@ -54,29 +55,30 @@ class TestImageClassifier:
         assert filtered_labels == ["that's an image alright"]


+@pytest.mark.asyncio
 class TestCLIP:
-    embedding = np.random.rand(512).astype(np.float32)
+    embedding = np.random.rand(1, 512).astype(np.float32)

-    def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
+    async def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
         mocker.patch.object(CLIPEncoder, "download")
         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
-        mocked.return_value.run.return_value = [[self.embedding]]
+        mocked.return_value.run.return_value = [self.embedding]
         clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="vision")
         assert clip_encoder.mode == "vision"
-        embedding = clip_encoder.predict(pil_image)
+        embedding = await clip_encoder.predict(pil_image)

         assert isinstance(embedding, list)
         assert len(embedding) == 512
         assert all([isinstance(num, float) for num in embedding])
         clip_encoder.vision_model.run.assert_called_once()

-    def test_basic_text(self, mocker: MockerFixture) -> None:
+    async def test_basic_text(self, mocker: MockerFixture) -> None:
         mocker.patch.object(CLIPEncoder, "download")
         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
-        mocked.return_value.run.return_value = [[self.embedding]]
+        mocked.return_value.run.return_value = [self.embedding]
         clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="text")
         assert clip_encoder.mode == "text"
-        embedding = clip_encoder.predict("test search query")
+        embedding = await clip_encoder.predict("test search query")

         assert isinstance(embedding, list)
         assert len(embedding) == 512

@@ -91,7 +93,8 @@ class TestFaceRecognition:

         assert face_recognizer.min_score == 0.5

-    def test_basic(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
+    @pytest.mark.asyncio
+    async def test_basic(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
         mocker.patch.object(FaceRecognizer, "load")
         face_recognizer = FaceRecognizer("test_model_name", min_score=0.0, cache_dir="test_cache")

@@ -108,7 +111,7 @@ class TestFaceRecognition:
         rec_model.get_feat.return_value = embedding
         face_recognizer.rec_model = rec_model

-        faces = face_recognizer.predict(cv_image)
+        faces = await face_recognizer.predict(cv_image)

         assert len(faces) == num_faces
         for face in faces:

@@ -119,7 +122,7 @@ class TestFaceRecognition:
         assert all([isinstance(num, float) for num in face["embedding"]])

         det_model.detect.assert_called_once()
-        assert rec_model.get_feat.call_count == num_faces
+        rec_model.get_feat.assert_called_once()


 @pytest.mark.asyncio