From 57506aa1fe1d0654f4a3d674adf1ba2d9bc96135 Mon Sep 17 00:00:00 2001
From: mertalev <101130780+mertalev@users.noreply.github.com>
Date: Sat, 21 Oct 2023 22:33:28 -0400
Subject: [PATCH] we have dynamic batching at home

---
 machine-learning/app/config.py                |   2 +
 machine-learning/app/main.py                  |  15 +-
 machine-learning/app/models/base.py           |  15 +-
 machine-learning/app/models/batcher.py        |  65 +++++++++
 machine-learning/app/models/cache.py          |   4 +-
 machine-learning/app/models/clip.py           |  68 +++++----
 .../app/models/facial_recognition.py          | 132 +++++++++++++++---
 .../app/models/image_classification.py        |  17 ++-
 machine-learning/app/schemas.py               |   5 +
 machine-learning/app/test_main.py             |  33 +++--
 10 files changed, 280 insertions(+), 76 deletions(-)
 create mode 100644 machine-learning/app/models/batcher.py

diff --git a/machine-learning/app/config.py b/machine-learning/app/config.py
index f3b41d22d..74fbaeea6 100644
--- a/machine-learning/app/config.py
+++ b/machine-learning/app/config.py
@@ -21,6 +21,8 @@ class Settings(BaseSettings):
     request_threads: int = os.cpu_count() or 4
     model_inter_op_threads: int = 1
     model_intra_op_threads: int = 2
+    max_batch_size: int = 8
+    batch_timeout_s: float = 0.005
 
     class Config:
         env_prefix = "MACHINE_LEARNING_"
diff --git a/machine-learning/app/main.py b/machine-learning/app/main.py
index 375c14a9e..e9bf0ada3 100644
--- a/machine-learning/app/main.py
+++ b/machine-learning/app/main.py
@@ -11,6 +11,7 @@ from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchF
 from starlette.formparsers import MultiPartParser
 
 from app.models.base import InferenceModel
+from app.models.batcher import Batcher, ModelBatcher
 
 from .config import log, settings
 from .models.cache import ModelCache
@@ -26,6 +27,7 @@ app = FastAPI()
 
 def init_state() -> None:
     app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+    app.state.model_batcher = ModelBatcher(max_size=settings.max_batch_size, timeout_s=settings.batch_timeout_s)
     log.info(
         (
             "Created in-memory cache with unloading "
@@ -62,9 +64,9 @@ async def predict(
     image: UploadFile | None = None,
 ) -> Any:
     if image is not None:
-        inputs: str | bytes = await image.read()
+        element: str | bytes = await image.read()
     elif text is not None:
-        inputs = text
+        element = text
     else:
         raise HTTPException(400, "Either image or text must be provided")
     try:
@@ -74,15 +76,16 @@
     model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
     model.configure(**kwargs)
 
-    outputs = await run(model, inputs)
+    batcher: Batcher = app.state.model_batcher.get(model_name, model_type, **kwargs)
+    outputs = await batcher.batch_process(element, run, model)
     return ORJSONResponse(outputs)
 
 
-async def run(model: InferenceModel, inputs: Any) -> Any:
+async def run(model: InferenceModel, elements: list[Any]) -> Any:
     if app.state.thread_pool is None:
-        return model.predict(inputs)
+        return model.predict_batch(elements)
 
-    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
+    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict_batch, elements)
 
 
 async def load(model: InferenceModel) -> InferenceModel:
diff --git a/machine-learning/app/models/base.py b/machine-learning/app/models/base.py
index 5f7f7c6e0..f963caa97 100644
--- a/machine-learning/app/models/base.py
+++ b/machine-learning/app/models/base.py
@@ -12,6 +12,10 @@ from ..config import get_cache_dir, log, settings
 from ..schemas import ModelType
 
 
+def get_model_key(model_name: str, model_type: ModelType, mode: str = "") -> str:
+    return f"{model_name}{model_type.value}{mode}"
+
+
 class InferenceModel(ABC):
     _model_type: ModelType
 
@@ -65,14 +69,15 @@ class InferenceModel(ABC):
             self._load()
             self.loaded = True
 
-    def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
+    async def predict(self, element: Any) -> Any:
+        return self.predict_batch([element])[0]
+
+    def predict_batch(self, inputs: list[Any]) -> list[Any]:
         self.load()
-        if model_kwargs:
-            self.configure(**model_kwargs)
-        return self._predict(inputs)
+        return self._predict_batch(inputs)
 
     @abstractmethod
-    def _predict(self, inputs: Any) -> Any:
+    def _predict_batch(self, inputs: list[Any]) -> list[Any]:
         ...
 
     def configure(self, **model_kwargs: Any) -> None:
diff --git a/machine-learning/app/models/batcher.py b/machine-learning/app/models/batcher.py
new file mode 100644
index 000000000..34965b161
--- /dev/null
+++ b/machine-learning/app/models/batcher.py
@@ -0,0 +1,65 @@
+import asyncio
+import time
+from typing import Any, Awaitable, Callable, TypeVar
+
+from app.schemas import ModelType
+
+from .base import get_model_key
+
+R = TypeVar("R")
+
+
+class Batcher:
+    def __init__(self, max_size: int = 16, timeout_s: float = 0.005) -> None:
+        self.max_size = max_size
+        self.timeout_s = timeout_s
+        self.queue: asyncio.Queue[int] = asyncio.Queue(maxsize=max_size)
+        self.lock = asyncio.Lock()
+        self.processing: dict[int, Any] = {}
+        self.processed: dict[int, Any] = {}
+        self.element_id = 0
+
+    async def batch_process(
+        self, element: Any, func: Callable[..., Awaitable[list[R]]], *args: Any, **kwargs: Any
+    ) -> R:
+        cur_idx = self.element_id
+        self.element_id += 1
+        self.processing[cur_idx] = element
+        await self.queue.put(cur_idx)
+        try:
+            async with self.lock:
+                await self._batch(cur_idx)
+                if cur_idx not in self.processed:
+                    await self._process(func, *args, **kwargs)
+                return self.processed.pop(cur_idx)
+        finally:
+            self.processing.pop(cur_idx, None)
+            self.processed.pop(cur_idx, None)
+
+    async def _batch(self, idx: int) -> None:
+        # Another lock holder may have already served this element; don't wait again.
+        if idx in self.processed:
+            return
+        # Yield to the event loop so other requests can enqueue, until the
+        # batch is full or the timeout elapses.
+        start = time.monotonic()
+        while self.queue.qsize() < self.max_size and time.monotonic() - start < self.timeout_s:
+            await asyncio.sleep(0)
+
+    async def _process(self, func: Callable[..., Awaitable[list[R]]], *args: Any, **kwargs: Any) -> None:
+        # Drain whatever is queued, run one batched call, and fan results back out by id.
+        batch_ids = [self.queue.get_nowait() for _ in range(self.queue.qsize())]
+        batch = [self.processing.pop(id) for id in batch_ids]
+        outputs = await func(*args, batch, **kwargs)
+        for id, output in zip(batch_ids, outputs):
+            self.processed[id] = output
+
+
+class ModelBatcher:
+    def __init__(self, max_size: int = 16, timeout_s: float = 0.005) -> None:
+        self.batchers: dict[str, Batcher] = {}
+        self.max_size = max_size
+        self.timeout_s = timeout_s
+
+    def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> Batcher:
+        key = get_model_key(model_name, model_type, model_kwargs.get("mode", ""))
+        if key not in self.batchers:
+            self.batchers[key] = Batcher(max_size=self.max_size, timeout_s=self.timeout_s)
+        return self.batchers[key]
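
Aside: a minimal sketch of how Batcher coalesces concurrent callers (standalone;
square_batch is a toy stand-in for main.py's run, and results come back in caller
order whether or not the calls share a batch):

    import asyncio

    from app.models.batcher import Batcher

    async def demo() -> None:
        batcher = Batcher(max_size=4, timeout_s=0.005)

        async def square_batch(batch: list[int]) -> list[int]:
            # Called once per drained batch, matching run()'s list-in/list-out contract.
            return [x * x for x in batch]

        # The first caller to take the lock spins until the queue fills or the
        # timeout lapses, then serves everything it drained in one call.
        results = await asyncio.gather(*(batcher.batch_process(i, square_batch) for i in range(4)))
        assert results == [0, 1, 4, 9]

    asyncio.run(demo())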
diff --git a/machine-learning/app/models/cache.py b/machine-learning/app/models/cache.py
index 0e37c4aa6..61c3ea2d1 100644
--- a/machine-learning/app/models/cache.py
+++ b/machine-learning/app/models/cache.py
@@ -5,7 +5,7 @@ from aiocache.lock import OptimisticLock
 from aiocache.plugins import BasePlugin, TimingPlugin
 
 from ..schemas import ModelType
-from .base import InferenceModel
+from .base import InferenceModel, get_model_key
 
 
 class ModelCache:
@@ -46,7 +46,7 @@ class ModelCache:
             model: The requested model.
         """
 
-        key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
+        key = get_model_key(model_name, model_type, model_kwargs.get("mode", ""))
         async with OptimisticLock(self.cache, key) as lock:
             model = await self.cache.get(key)
             if model is None:
diff --git a/machine-learning/app/models/clip.py b/machine-learning/app/models/clip.py
index c1ce1801b..606e95743 100644
--- a/machine-learning/app/models/clip.py
+++ b/machine-learning/app/models/clip.py
@@ -1,8 +1,9 @@
 import os
 import zipfile
 from io import BytesIO
-from typing import Any, Literal
+from typing import Any, Literal, TypeGuard
 
 import onnxruntime as ort
 import torch
 from clip_server.model.clip import BICUBIC, _convert_image_to_rgb
@@ -17,6 +18,14 @@ from ..schemas import ModelType
 from .base import InferenceModel
 
 
+def is_image_list(images: list[Any]) -> TypeGuard[list[Image.Image | bytes]]:
+    return all(isinstance(image, (Image.Image, bytes)) for image in images)
+
+
+def is_text_list(texts: list[Any]) -> TypeGuard[list[str]]:
+    return all(isinstance(text, str) for text in texts)
+
+
 class CLIPEncoder(InferenceModel):
     _model_type = ModelType.CLIP
 
@@ -70,31 +79,42 @@ class CLIPEncoder(InferenceModel):
             image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)]
             self.transform = _transform_pil_image(image_size)
 
-    def _predict(self, image_or_text: Image.Image | str) -> list[float]:
-        if isinstance(image_or_text, bytes):
-            image_or_text = Image.open(BytesIO(image_or_text))
+    def _predict_batch(self, images_or_text: list[Image.Image | bytes] | list[str]) -> list[list[float]]:
+        if not images_or_text:
+            return []
 
-        match image_or_text:
-            case Image.Image():
-                if self.mode == "text":
-                    raise TypeError("Cannot encode image as text-only model")
-                pixel_values = self.transform(image_or_text)
-                assert isinstance(pixel_values, torch.Tensor)
-                pixel_values = torch.unsqueeze(pixel_values, 0).numpy()
-                outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
-            case str():
-                if self.mode == "vision":
-                    raise TypeError("Cannot encode text as vision-only model")
-                text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text)
-                inputs = {
-                    "input_ids": text_inputs["input_ids"].int().numpy(),
-                    "attention_mask": text_inputs["attention_mask"].int().numpy(),
-                }
-                outputs = self.text_model.run(self.text_outputs, inputs)
-            case _:
-                raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
+        if is_image_list(images_or_text):
+            if self.mode == "text":
+                raise TypeError("Cannot encode image as text-only model")
+            outputs = self._predict_images(images_or_text)
+        elif is_text_list(images_or_text):
+            if self.mode == "vision":
+                raise TypeError("Cannot encode text as vision-only model")
+            outputs = self._predict_text(images_or_text)
+        else:
+            raise TypeError(f"Expected list of images or text, but got: {type(images_or_text[0])}")
 
-        return outputs[0][0].tolist()
+        return outputs
+
+    def _predict_images(self, images: list[Image.Image | bytes]) -> list[list[float]]:
+        if not images:
+            return []
+
+        # Decode any raw bytes into PIL images before batching the transform.
+        decoded = [Image.open(BytesIO(image)) if isinstance(image, bytes) else image for image in images]
+        pixel_values = torch.stack([self.transform(image) for image in decoded]).numpy()
+        outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
+        return outputs[0].tolist()
+
+    def _predict_text(self, texts: list[str]) -> list[list[float]]:
+        if not texts:
+            return []
+
+        text_inputs: dict[str, torch.Tensor] = self.tokenizer(texts)
+        inputs = {
+            "input_ids": text_inputs["input_ids"].int().numpy(),
+            "attention_mask": text_inputs["attention_mask"].int().numpy(),
+        }
+        outputs = self.text_model.run(self.text_outputs, inputs)
+        return outputs[0].tolist()
 
     def _download_model(self, model_name: str, model_md5: str) -> bool:
         # downloading logic is adapted from clip-server's CLIPOnnxModel class
diff --git a/machine-learning/app/models/facial_recognition.py b/machine-learning/app/models/facial_recognition.py
index cd4be2308..7ed023a59 100644
--- a/machine-learning/app/models/facial_recognition.py
+++ b/machine-learning/app/models/facial_recognition.py
@@ -9,9 +9,12 @@ from insightface.model_zoo import ArcFaceONNX, RetinaFace
 from insightface.utils.face_align import norm_crop
 from insightface.utils.storage import BASE_REPO_URL, download_file
+import onnx
+from onnx.tools import update_model_dims
 
-from ..schemas import ModelType
+from ..schemas import ModelType, ndarray
 from .base import InferenceModel
 
 
 class FaceRecognizer(InferenceModel):
     _model_type = ModelType.FACIAL_RECOGNITION
@@ -36,6 +39,8 @@ class FaceRecognizer(InferenceModel):
             zip.extractall(self.cache_dir, members=[det_file, rec_file])
         zip_file.unlink()
 
+        self._add_batch_dimension(self.cache_dir / rec_file)
+
     def _load(self) -> None:
         try:
             det_file = next(self.cache_dir.glob("det_*.onnx"))
@@ -43,29 +48,36 @@ class FaceRecognizer(InferenceModel):
         except StopIteration:
             raise FileNotFoundError("Facial recognition models not found in cache directory")
 
-        self.det_model = RetinaFace(
-            session=ort.InferenceSession(
-                det_file.as_posix(),
-                sess_options=self.sess_options,
-                providers=self.providers,
-                provider_options=self.provider_options,
-            ),
+        det_session = ort.InferenceSession(
+            det_file.as_posix(),
+            sess_options=self.sess_options,
+            providers=self.providers,
+            provider_options=self.provider_options,
         )
-        self.rec_model = ArcFaceONNX(
-            rec_file.as_posix(),
-            session=ort.InferenceSession(
-                rec_file.as_posix(),
-                sess_options=self.sess_options,
-                providers=self.providers,
-                provider_options=self.provider_options,
-            ),
-        )
-
+        self.det_model = RetinaFace(session=det_session)
         self.det_model.prepare(
             ctx_id=0,
             det_thresh=self.min_score,
             input_size=(640, 640),
         )
+
+        rec_session = ort.InferenceSession(
+            rec_file.as_posix(),
+            sess_options=self.sess_options,
+            providers=self.providers,
+            provider_options=self.provider_options,
+        )
+        # Models cached by older versions may still have a fixed batch dim;
+        # rewrite them once and reload.
+        if rec_session.get_inputs()[0].shape[0] != "batch":
+            del rec_session
+            self._add_batch_dimension(rec_file)
+            rec_session = ort.InferenceSession(
+                rec_file.as_posix(),
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            )
+        self.rec_model = ArcFaceONNX(rec_file.as_posix(), session=rec_session)
         self.rec_model.prepare(ctx_id=0)
 
     def _predict(self, image: np.ndarray[int, np.dtype[Any]] | bytes) -> list[dict[str, Any]]:
@@ -100,6 +112,90 @@
         )
         return results
 
+    def _predict_batch(self, images: list[cv2.Mat | bytes]) -> list[list[dict[str, Any]]]:
+        # Decode byte payloads into BGR mats before running the pipeline.
+        decoded = [
+            cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR) if isinstance(image, bytes) else image
+            for image in images
+        ]
+
+        batch_det, batch_kpss = self._detect(decoded)
+        batch_cropped_images, batch_offsets = self._preprocess(decoded, batch_kpss)
+        if batch_cropped_images:
+            batch_embeddings = self.rec_model.get_feat(batch_cropped_images)
+            results = self._postprocess(decoded, batch_det, batch_embeddings, batch_offsets)
+        else:
+            results = self._postprocess(decoded, batch_det)
+        return results
+
+    def _detect(self, images: list[cv2.Mat]) -> tuple[list[ndarray], list[ndarray]]:
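+        """Return per-image detections: an (N, 5) array of boxes+scores and an
+        (N, 5, 2) array of landmarks for each input frame."""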
+        batch_det: list[ndarray] = []
+        batch_kpss: list[ndarray] = []
+        # detection model doesn't support batching, but recognition model does
+        for image in images:
+            bboxes, kpss = self.det_model.detect(image)
+            batch_det.append(bboxes)
+            batch_kpss.append(kpss)
+        return batch_det, batch_kpss
+
+    def _preprocess(self, images: list[cv2.Mat], batch_kpss: list[ndarray]) -> tuple[list[cv2.Mat], list[int]]:
+        batch_cropped_images: list[cv2.Mat] = []
+        batch_offsets: list[int] = []
+        total_faces = 0
+        for i, image in enumerate(images):
+            kpss = batch_kpss[i]
+            total_faces += kpss.shape[0]
+            # Cumulative face counts mark where each image's faces end in the flat batch.
+            batch_offsets.append(total_faces)
+            for kps in kpss:
+                batch_cropped_images.append(norm_crop(image, kps))
+        return batch_cropped_images, batch_offsets
+
+    def _postprocess(
+        self,
+        images: list[cv2.Mat],
+        batch_det: list[ndarray],
+        batch_embeddings: ndarray | None = None,
+        batch_offsets: list[int] | None = None,
+    ) -> list[list[dict[str, Any]]]:
+        if batch_embeddings is not None and batch_offsets is not None:
+            # Splitting at the cumulative offsets yields one embedding matrix per
+            # image (plus an empty trailing split, which is never indexed).
+            image_embeddings: list[ndarray] | None = np.array_split(batch_embeddings, batch_offsets)
+        else:
+            image_embeddings = None
+
+        batch_faces: list[list[dict[str, Any]]] = []
+        for i, image in enumerate(images):
+            faces: list[dict[str, Any]] = []
+            batch_faces.append(faces)
+            if image_embeddings is None or image_embeddings[i].shape[0] == 0:
+                continue
+
+            height, width, _ = image.shape
+
+            embeddings = image_embeddings[i].tolist()
+            bboxes = batch_det[i][:, :4].round().tolist()
+            det_scores = batch_det[i][:, 4].tolist()
+            for (x1, y1, x2, y2), embedding, det_score in zip(bboxes, embeddings, det_scores):
+                face = {
+                    "imageWidth": width,
+                    "imageHeight": height,
+                    "boundingBox": {
+                        "x1": x1,
+                        "y1": y1,
+                        "x2": x2,
+                        "y2": y2,
+                    },
+                    "score": det_score,
+                    "embedding": embedding,
+                }
+                faces.append(face)
+        return batch_faces
+
+    def _add_batch_dimension(self, model_path: Path) -> None:
+        # Rewrite the ONNX graph so the leading dimension of every input and
+        # output is the symbolic "batch" instead of a fixed size.
+        rec_proto = onnx.load(model_path.as_posix())
+        inputs = {
+            input.name: ["batch"] + [dim.dim_value for dim in input.type.tensor_type.shape.dim[1:]]
+            for input in rec_proto.graph.input
+        }
+        outputs = {
+            output.name: ["batch"] + [dim.dim_value for dim in output.type.tensor_type.shape.dim[1:]]
+            for output in rec_proto.graph.output
+        }
+        rec_proto = update_model_dims.update_inputs_outputs_dims(rec_proto, inputs, outputs)
+        onnx.save(rec_proto, model_path.as_posix())
+
     @property
     def cached(self) -> bool:
         return self.cache_dir.is_dir() and any(self.cache_dir.glob("*.onnx"))
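
Aside: the offset bookkeeping above, worked on toy shapes. np.array_split at the
cumulative counts from _preprocess rebuilds the per-image groups, plus one empty
trailing split that _postprocess never indexes (a standalone sketch):

    import numpy as np

    # Two faces in image 0, one in image 1, none in image 2.
    batch_embeddings = np.zeros((3, 512))
    batch_offsets = [2, 3, 3]  # cumulative face counts, one entry per image
    groups = np.array_split(batch_embeddings, batch_offsets)
    assert [g.shape[0] for g in groups] == [2, 1, 0, 0]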
diff --git a/machine-learning/app/models/image_classification.py b/machine-learning/app/models/image_classification.py
index cbf784e5a..387056e94 100644
--- a/machine-learning/app/models/image_classification.py
+++ b/machine-learning/app/models/image_classification.py
@@ -63,13 +63,18 @@ class ImageClassifier(InferenceModel):
             feature_extractor=processor,
         )
 
-    def _predict(self, image: Image.Image | bytes) -> list[str]:
-        if isinstance(image, bytes):
-            image = Image.open(BytesIO(image))
-        predictions: list[dict[str, Any]] = self.model(image)  # type: ignore
-        tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
+    def _predict_batch(self, images: list[Image.Image | bytes]) -> list[list[str]]:
+        # Decode raw bytes into PIL images; the pipeline accepts a list and
+        # returns one prediction list per image.
+        decoded = [Image.open(BytesIO(image)) if isinstance(image, bytes) else image for image in images]
+        batch_predictions: list[list[dict[str, Any]]] = self.model(decoded)  # type: ignore
+        return [self._postprocess(predictions) for predictions in batch_predictions]
 
-        return tags
+    def _postprocess(self, predictions: list[dict[str, Any]]) -> list[str]:
+        return [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
 
     def configure(self, **model_kwargs: Any) -> None:
         self.min_score = model_kwargs.pop("minScore", self.min_score)
diff --git a/machine-learning/app/schemas.py b/machine-learning/app/schemas.py
index cad279687..a54ae7be6 100644
--- a/machine-learning/app/schemas.py
+++ b/machine-learning/app/schemas.py
@@ -1,5 +1,7 @@
 from enum import StrEnum
+from typing import Any, TypeAlias
 
+import numpy as np
 from pydantic import BaseModel
 
 
@@ -31,3 +33,6 @@ class ModelType(StrEnum):
     IMAGE_CLASSIFICATION = "image-classification"
     CLIP = "clip"
     FACIAL_RECOGNITION = "facial-recognition"
+
+
+ndarray: TypeAlias = np.ndarray[int, np.dtype[Any]]
diff --git a/machine-learning/app/test_main.py b/machine-learning/app/test_main.py
index 3e21389da..e46509304 100644
--- a/machine-learning/app/test_main.py
+++ b/machine-learning/app/test_main.py
@@ -22,16 +22,17 @@ from .schemas import ModelType
 ndarray: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
 
 
+@pytest.mark.asyncio
 class TestImageClassifier:
-    classifier_preds = [
+    classifier_preds = [[
         {"label": "that's an image alright", "score": 0.8},
         {"label": "well it ends with .jpg", "score": 0.1},
         {"label": "idk, im just seeing bytes", "score": 0.05},
         {"label": "not sure", "score": 0.04},
         {"label": "probably a virus", "score": 0.01},
-    ]
+    ]]
 
-    def test_min_score(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
+    async def test_min_score(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
         mocker.patch.object(ImageClassifier, "load")
         classifier = ImageClassifier("test_model_name", min_score=0.0)
         assert classifier.min_score == 0.0
@@ -39,9 +40,9 @@ class TestImageClassifier:
         classifier.model = mock.Mock()
         classifier.model.return_value = self.classifier_preds
 
-        all_labels = classifier.predict(pil_image)
+        all_labels = await classifier.predict(pil_image)
         classifier.min_score = 0.5
-        filtered_labels = classifier.predict(pil_image)
+        filtered_labels = await classifier.predict(pil_image)
 
         assert all_labels == [
             "that's an image alright",
@@ -54,29 +55,30 @@ class TestImageClassifier:
         assert filtered_labels == ["that's an image alright"]
 
 
+@pytest.mark.asyncio
 class TestCLIP:
-    embedding = np.random.rand(512).astype(np.float32)
+    embedding = np.random.rand(1, 512).astype(np.float32)
 
-    def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
+    async def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
         mocker.patch.object(CLIPEncoder, "download")
         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
-        mocked.return_value.run.return_value = [[self.embedding]]
+        mocked.return_value.run.return_value = [self.embedding]
 
         clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="vision")
         assert clip_encoder.mode == "vision"
-        embedding = clip_encoder.predict(pil_image)
+        embedding = await clip_encoder.predict(pil_image)
 
         assert isinstance(embedding, list)
         assert len(embedding) == 512
         assert all([isinstance(num, float) for num in embedding])
         clip_encoder.vision_model.run.assert_called_once()
 
-    def test_basic_text(self, mocker: MockerFixture) -> None:
+    async def test_basic_text(self, mocker: MockerFixture) -> None:
         mocker.patch.object(CLIPEncoder, "download")
         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
-        mocked.return_value.run.return_value = [[self.embedding]]
+        mocked.return_value.run.return_value = [self.embedding]
 
         clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="text")
         assert clip_encoder.mode == "text"
-        embedding = clip_encoder.predict("test search query")
+        embedding = await clip_encoder.predict("test search query")
 
         assert isinstance(embedding, list)
         assert len(embedding) == 512
@@ -91,7 +93,8 @@ class TestFaceRecognition:
 
         assert face_recognizer.min_score == 0.5
 
-    def test_basic(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
+    @pytest.mark.asyncio
+    async def test_basic(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
         mocker.patch.object(FaceRecognizer, "load")
         face_recognizer = FaceRecognizer("test_model_name", min_score=0.0, cache_dir="test_cache")
 
@@ -108,7 +111,7 @@ class TestFaceRecognition:
         rec_model.get_feat.return_value = embedding
         face_recognizer.rec_model = rec_model
 
-        faces = face_recognizer.predict(cv_image)
+        faces = await face_recognizer.predict(cv_image)
 
         assert len(faces) == num_faces
         for face in faces:
@@ -119,7 +122,7 @@ class TestFaceRecognition:
             assert all([isinstance(num, float) for num in face["embedding"]])
 
         det_model.detect.assert_called_once()
-        assert rec_model.get_feat.call_count == num_faces
+        rec_model.get_feat.assert_called_once()
 
 
 @pytest.mark.asyncio
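
Aside: the new knobs are ordinary pydantic settings behind config.py's
MACHINE_LEARNING_ env prefix; a quick sketch (assumes app.config is importable):

    import os

    os.environ["MACHINE_LEARNING_MAX_BATCH_SIZE"] = "16"
    os.environ["MACHINE_LEARNING_BATCH_TIMEOUT_S"] = "0.01"

    from app.config import Settings

    # Bigger batches trade up to batch_timeout_s of extra latency for throughput.
    assert Settings().max_batch_size == 16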