we have dynamic batching at home

mertalev 2023-10-21 22:33:28 -04:00
parent d4c23c8df8
commit 57506aa1fe
10 changed files with 280 additions and 76 deletions

View file

@@ -21,6 +21,8 @@ class Settings(BaseSettings):
     request_threads: int = os.cpu_count() or 4
     model_inter_op_threads: int = 1
     model_intra_op_threads: int = 2
+    max_batch_size: int = 8
+    batch_timeout_s: float = 0.005

     class Config:
         env_prefix = "MACHINE_LEARNING_"

View file

@@ -11,6 +11,7 @@ from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchF
 from starlette.formparsers import MultiPartParser

 from app.models.base import InferenceModel
+from app.models.batcher import Batcher, ModelBatcher

 from .config import log, settings
 from .models.cache import ModelCache
@@ -26,6 +27,7 @@ app = FastAPI()

 def init_state() -> None:
     app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+    app.state.model_batcher = ModelBatcher(max_size=settings.max_batch_size, timeout_s=settings.batch_timeout_s)
     log.info(
         (
             "Created in-memory cache with unloading "
@@ -62,9 +64,9 @@ async def predict(
     image: UploadFile | None = None,
 ) -> Any:
     if image is not None:
-        inputs: str | bytes = await image.read()
+        element: str | bytes = await image.read()
     elif text is not None:
-        inputs = text
+        element = text
     else:
         raise HTTPException(400, "Either image or text must be provided")
     try:
@@ -74,15 +76,16 @@ async def predict(
     model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
     model.configure(**kwargs)
-    outputs = await run(model, inputs)
+    batcher: Batcher = app.state.model_batcher.get(model_name, model_type, **kwargs)
+    outputs = await batcher.batch_process(element, run, model)
     return ORJSONResponse(outputs)


-async def run(model: InferenceModel, inputs: Any) -> Any:
+async def run(model: InferenceModel, elements: list[Any]) -> Any:
     if app.state.thread_pool is None:
-        return model.predict(inputs)
+        return model.predict_batch(elements)
-    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
+    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict_batch, elements)


 async def load(model: InferenceModel) -> InferenceModel:
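For reference, a standalone sketch of the new run() contract: the batcher hands run() a list of elements, and run() executes the model's predict_batch either inline or in the shared thread pool. FakeModel and pool are illustrative stand-ins, not part of the codebase:

# Standalone sketch of the thread-pool dispatch used by run().
import asyncio
from concurrent.futures import ThreadPoolExecutor


class FakeModel:
    def predict_batch(self, elements: list[str]) -> list[str]:
        return [e[::-1] for e in elements]


async def run(model: FakeModel, elements: list[str], pool: ThreadPoolExecutor | None) -> list[str]:
    if pool is None:
        return model.predict_batch(elements)
    return await asyncio.get_running_loop().run_in_executor(pool, model.predict_batch, elements)


async def main() -> None:
    with ThreadPoolExecutor(max_workers=2) as pool:
        print(await run(FakeModel(), ["abc", "def"], pool))  # ['cba', 'fed']


asyncio.run(main())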

View file

@@ -12,6 +12,10 @@ from ..config import get_cache_dir, log, settings
 from ..schemas import ModelType


+def get_model_key(model_name: str, model_type: ModelType, mode: str = "") -> str:
+    return f"{model_name}{model_type.value}{mode}"
+
+
 class InferenceModel(ABC):
     _model_type: ModelType
@@ -65,14 +69,15 @@ class InferenceModel(ABC):
         self._load()
         self.loaded = True

-    def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
+    def predict(self, element: Any) -> Any:
+        return self.predict_batch([element])[0]
+
+    def predict_batch(self, inputs: list[Any]) -> list[Any]:
         self.load()
-        if model_kwargs:
-            self.configure(**model_kwargs)
-        return self._predict(inputs)
+        return self._predict_batch(inputs)

     @abstractmethod
-    def _predict(self, inputs: Any) -> Any:
+    def _predict_batch(self, inputs: list[Any]) -> Any:
         ...

     def configure(self, **model_kwargs: Any) -> None:
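The single-item predict() is now a thin wrapper over predict_batch(). A toy sketch of the contract, using a standalone class rather than the real InferenceModel (whose constructor and other abstract methods are not shown in this diff):

# Illustrative only: subclasses implement the batch path and get the
# single-item path for free.
from typing import Any


class FakeModel:
    def predict(self, element: Any) -> Any:
        return self.predict_batch([element])[0]

    def predict_batch(self, inputs: list[Any]) -> list[Any]:
        return self._predict_batch(inputs)

    def _predict_batch(self, inputs: list[Any]) -> list[Any]:
        # one output per input, in order
        return [f"echo:{item}" for item in inputs]


model = FakeModel()
assert model.predict("a") == "echo:a"
assert model.predict_batch(["a", "b"]) == ["echo:a", "echo:b"]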

View file

@@ -0,0 +1,65 @@
+import asyncio
+import time
+from typing import Any, Awaitable, Callable, TypeVar
+
+from app.schemas import ModelType
+
+from .base import get_model_key, log
+
+F = TypeVar("F")
+P = TypeVar("P")
+R = TypeVar("R")
+
+
+class Batcher:
+    def __init__(self, max_size: int = 16, timeout_s: float = 0.005) -> None:
+        self.max_size = max_size
+        self.timeout_s = timeout_s
+        self.queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=max_size)
+        self.lock = asyncio.Lock()
+        self.processing: dict[int, Any] = {}
+        self.processed: dict[int, Any] = {}
+        self.element_id = 0
+
+    async def batch_process(
+        self, element: Any, func: Callable[[list[P]], Awaitable[list[R]]], *args: Any, **kwargs: Any
+    ) -> Any:
+        cur_idx = self.element_id
+        self.element_id += 1
+        self.processing[cur_idx] = element
+        await self.queue.put(cur_idx)
+        try:
+            async with self.lock:
+                await self._batch(cur_idx)
+                if cur_idx not in self.processed:
+                    await self._process(func, *args, **kwargs)
+            return self.processed.pop(cur_idx)
+        finally:
+            self.processing.pop(cur_idx, None)
+            self.processed.pop(cur_idx, None)
+
+    async def _batch(self, idx: int) -> None:
+        # wait (up to timeout_s) for more elements to arrive before the queue is flushed
+        if idx not in self.processed:
+            start = time.monotonic()
+            while self.queue.qsize() < self.max_size and time.monotonic() - start < self.timeout_s:
+                await asyncio.sleep(0)
+
+    async def _process(self, func: Callable[[list[P]], Awaitable[list[R]]], *args: Any, **kwargs: Any) -> None:
+        batch_ids = [self.queue.get_nowait() for _ in range(self.queue.qsize())]
+        batch = [self.processing.pop(id) for id in batch_ids]
+        outputs = await func(*args, batch, **kwargs)
+        for id, output in zip(batch_ids, outputs):
+            self.processed[id] = output
+
+
+class ModelBatcher:
+    def __init__(self, max_size: int = 16, timeout_s: float = 0.005) -> None:
+        self.batchers: dict[str, Batcher] = {}
+        self.max_size = max_size
+        self.timeout_s = timeout_s
+
+    def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> Batcher:
+        key = get_model_key(model_name, model_type, model_kwargs.get("mode", ""))
+        if key not in self.batchers:
+            self.batchers[key] = Batcher(max_size=self.max_size, timeout_s=self.timeout_s)
+        return self.batchers[key]
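A self-contained usage sketch: several coroutines each submit one element, and whatever arrives within timeout_s is flushed as a single call to the supplied function. The fake_run coroutine stands in for main.py's run():

# Standalone sketch of how concurrent callers are coalesced by the Batcher.
import asyncio

from app.models.batcher import Batcher


async def fake_run(elements: list[str]) -> list[str]:
    print(f"processing a batch of {len(elements)}")
    return [e.upper() for e in elements]


async def main() -> None:
    batcher = Batcher(max_size=8, timeout_s=0.01)
    results = await asyncio.gather(*(batcher.batch_process(e, fake_run) for e in "abcde"))
    print(results)  # ['A', 'B', 'C', 'D', 'E'], likely from a single fake_run call


asyncio.run(main())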

View file

@@ -5,7 +5,7 @@ from aiocache.lock import OptimisticLock
 from aiocache.plugins import BasePlugin, TimingPlugin

 from ..schemas import ModelType
-from .base import InferenceModel
+from .base import InferenceModel, get_model_key


 class ModelCache:
@@ -46,7 +46,7 @@ class ModelCache:
             model: The requested model.
         """
-        key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
+        key = get_model_key(model_name, model_type, model_kwargs.get("mode", ""))
         async with OptimisticLock(self.cache, key) as lock:
             model = await self.cache.get(key)
             if model is None:
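With this change the cache and the batcher derive their keys from the same helper, so a given (name, type, mode) triple maps to exactly one cached model and one batcher. A small sketch of the key format:

# Sketch: get_model_key simply concatenates name, type value, and mode.
from app.models.base import get_model_key
from app.schemas import ModelType

key = get_model_key("ViT-B-32::openai", ModelType.CLIP, "vision")
print(key)  # "ViT-B-32::openaiclipvision"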

View file

@@ -1,8 +1,9 @@
 import os
 import zipfile
 from io import BytesIO
-from typing import Any, Literal
+from typing import Any, Literal, TypeGuard

+import numpy as np
 import onnxruntime as ort
 import torch
 from clip_server.model.clip import BICUBIC, _convert_image_to_rgb
@@ -17,6 +18,14 @@ from ..schemas import ModelType
 from .base import InferenceModel


+def is_image_list(images: list[Any]) -> TypeGuard[list[Image.Image | bytes]]:
+    return any(isinstance(image, (Image.Image, bytes)) for image in images)
+
+
+def is_text_list(texts: list[Any]) -> TypeGuard[list[str]]:
+    return any(isinstance(text, str) for text in texts)
+
+
 class CLIPEncoder(InferenceModel):
     _model_type = ModelType.CLIP
@@ -70,31 +79,42 @@ class CLIPEncoder(InferenceModel):
             image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)]
             self.transform = _transform_pil_image(image_size)

-    def _predict(self, image_or_text: Image.Image | str) -> list[float]:
-        if isinstance(image_or_text, bytes):
-            image_or_text = Image.open(BytesIO(image_or_text))
+    def _predict_batch(self, images_or_text: list[Image.Image | bytes] | list[str]) -> list[list[float]]:
+        if not images_or_text:
+            return []

-        match image_or_text:
-            case Image.Image():
-                if self.mode == "text":
-                    raise TypeError("Cannot encode image as text-only model")
-                pixel_values = self.transform(image_or_text)
-                assert isinstance(pixel_values, torch.Tensor)
-                pixel_values = torch.unsqueeze(pixel_values, 0).numpy()
-                outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
-            case str():
-                if self.mode == "vision":
-                    raise TypeError("Cannot encode text as vision-only model")
-                text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text)
-                inputs = {
-                    "input_ids": text_inputs["input_ids"].int().numpy(),
-                    "attention_mask": text_inputs["attention_mask"].int().numpy(),
-                }
-                outputs = self.text_model.run(self.text_outputs, inputs)
-            case _:
-                raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
+        if is_image_list(images_or_text):
+            outputs = self._predict_images(images_or_text)
+        elif is_text_list(images_or_text):
+            outputs = self._predict_text(images_or_text)
+        else:
+            raise TypeError(f"Expected list of images or text, but got: {type(images_or_text[0])}")

-        return outputs[0][0].tolist()
+        return outputs
+
+    def _predict_images(self, images: list[Image.Image | bytes]) -> list[list[float]]:
+        if not images:
+            return []
+
+        for i, element in enumerate(images):
+            if isinstance(element, bytes):
+                images[i] = Image.open(BytesIO(element))
+
+        pixel_values = torch.stack([self.transform(image) for image in images]).numpy()
+        outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
+        return outputs[0].tolist()
+
+    def _predict_text(self, texts: list[str]) -> list[list[float]]:
+        if not texts:
+            return []
+
+        text_inputs: dict[str, torch.Tensor] = self.tokenizer(texts)
+        inputs = {
+            "input_ids": text_inputs["input_ids"].int().numpy(),
+            "attention_mask": text_inputs["attention_mask"].int().numpy(),
+        }
+        outputs = self.text_model.run(self.text_outputs, inputs)
+        return outputs[0].tolist()

     def _download_model(self, model_name: str, model_md5: str) -> bool:
         # downloading logic is adapted from clip-server's CLIPOnnxModel class
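The shape change behind the batched path, with illustrative values: torch.stack replaces the single-image unsqueeze(0), so the ONNX session sees (N, 3, H, W) and outputs[0] holds one embedding per input:

# Illustrative shapes only; the session output is a made-up stand-in.
import numpy as np
import torch

transformed = [torch.rand(3, 224, 224) for _ in range(4)]  # stand-in for self.transform(image)
pixel_values = torch.stack(transformed).numpy()
print(pixel_values.shape)  # (4, 3, 224, 224)

fake_session_output = [np.random.rand(4, 512).astype(np.float32)]  # stand-in for vision_model.run(...)
embeddings = fake_session_output[0].tolist()
print(len(embeddings), len(embeddings[0]))  # 4 512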

View file

@@ -9,9 +9,12 @@ from insightface.model_zoo import ArcFaceONNX, RetinaFace
 from insightface.utils.face_align import norm_crop
 from insightface.utils.storage import BASE_REPO_URL, download_file

-from ..schemas import ModelType
+from ..schemas import ModelType, ndarray
 from .base import InferenceModel
+
+import onnx
+from onnx.tools import update_model_dims


 class FaceRecognizer(InferenceModel):
     _model_type = ModelType.FACIAL_RECOGNITION
@@ -36,6 +39,8 @@ class FaceRecognizer(InferenceModel):
             zip.extractall(self.cache_dir, members=[det_file, rec_file])
         zip_file.unlink()

+        self._add_batch_dimension(self.cache_dir / rec_file)
+
     def _load(self) -> None:
         try:
             det_file = next(self.cache_dir.glob("det_*.onnx"))
@@ -43,29 +48,36 @@ class FaceRecognizer(InferenceModel):
         except StopIteration:
             raise FileNotFoundError("Facial recognition models not found in cache directory")

-        self.det_model = RetinaFace(
-            session=ort.InferenceSession(
-                det_file.as_posix(),
-                sess_options=self.sess_options,
-                providers=self.providers,
-                provider_options=self.provider_options,
-            ),
-        )
-        self.rec_model = ArcFaceONNX(
-            rec_file.as_posix(),
-            session=ort.InferenceSession(
-                rec_file.as_posix(),
-                sess_options=self.sess_options,
-                providers=self.providers,
-                provider_options=self.provider_options,
-            ),
-        )
-
+        det_session = ort.InferenceSession(
+            det_file.as_posix(),
+            sess_options=self.sess_options,
+            providers=self.providers,
+            provider_options=self.provider_options,
+        )
+        self.det_model = RetinaFace(session=det_session)
         self.det_model.prepare(
             ctx_id=0,
             det_thresh=self.min_score,
             input_size=(640, 640),
         )
+
+        rec_session = ort.InferenceSession(
+            rec_file.as_posix(),
+            sess_options=self.sess_options,
+            providers=self.providers,
+            provider_options=self.provider_options,
+        )
+        if rec_session.get_inputs()[0].shape[0] != "batch":
+            # the model still has a fixed batch dimension: rewrite it, then reload the session
+            del rec_session
+            self._add_batch_dimension(rec_file)
+            rec_session = ort.InferenceSession(
+                rec_file.as_posix(),
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            )
+        self.rec_model = ArcFaceONNX(rec_file.as_posix(), session=rec_session)
         self.rec_model.prepare(ctx_id=0)

     def _predict(self, image: np.ndarray[int, np.dtype[Any]] | bytes) -> list[dict[str, Any]]:
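For context, a toy sketch of what the "batch" check and _add_batch_dimension amount to: an ONNX model exported with a fixed leading dimension can have that dimension rewritten to a symbolic name with onnx.tools.update_model_dims. The tiny Identity graph below is made up for illustration:

# Illustrative only: build a toy graph with batch size 1, then make the batch symbolic.
import onnx
from onnx import TensorProto, helper
from onnx.tools import update_model_dims

node = helper.make_node("Identity", ["x"], ["y"])
graph = helper.make_graph(
    [node],
    "toy",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 112, 112])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 112, 112])],
)
model = helper.make_model(graph)

model = update_model_dims.update_inputs_outputs_dims(
    model, {"x": ["batch", 3, 112, 112]}, {"y": ["batch", 3, 112, 112]}
)
print([d.dim_param or d.dim_value for d in model.graph.input[0].type.tensor_type.shape.dim])
# ['batch', 3, 112, 112]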
@@ -100,6 +112,90 @@ class FaceRecognizer(InferenceModel):
             )
         return results

+    def _predict_batch(self, images: list[cv2.Mat]) -> list[list[dict[str, Any]]]:
+        for i, image in enumerate(images):
+            if isinstance(image, bytes):
+                images[i] = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
+
+        batch_det, batch_kpss = self._detect(images)
+        batch_cropped_images, batch_offsets = self._preprocess(images, batch_kpss)
+        if batch_cropped_images:
+            batch_embeddings = self.rec_model.get_feat(batch_cropped_images)
+            results = self._postprocess(images, batch_det, batch_embeddings, batch_offsets)
+        else:
+            results = self._postprocess(images, batch_det)
+        return results
+
+    def _detect(self, images: list[cv2.Mat]) -> tuple[list[ndarray], ...]:
+        batch_det: list[ndarray] = []
+        batch_kpss: list[ndarray] = []
+        # the detection model doesn't support batching, but the recognition model does
+        for image in images:
+            bboxes, kpss = self.det_model.detect(image)
+            batch_det.append(bboxes)
+            batch_kpss.append(kpss)
+        return batch_det, batch_kpss
+
+    def _preprocess(self, images: list[cv2.Mat], batch_kpss: list[ndarray]) -> tuple[list[cv2.Mat], list[int]]:
+        batch_cropped_images = []
+        batch_offsets = []
+        total_faces = 0
+        for i, image in enumerate(images):
+            kpss = batch_kpss[i]
+            total_faces += kpss.shape[0]
+            batch_offsets.append(total_faces)
+            for kps in kpss:
+                batch_cropped_images.append(norm_crop(image, kps))
+        return batch_cropped_images, batch_offsets
+
+    def _postprocess(
+        self,
+        images: list[cv2.Mat],
+        batch_det: list[ndarray],
+        batch_embeddings: ndarray | None = None,
+        batch_offsets: list[int] | None = None,
+    ) -> list[list[dict[str, Any]]]:
+        if batch_embeddings is not None and batch_offsets is not None:
+            image_embeddings: list[ndarray] | None = np.array_split(batch_embeddings, batch_offsets)
+        else:
+            image_embeddings = None
+
+        batch_faces: list[list[dict[str, Any]]] = []
+        for i, image in enumerate(images):
+            faces: list[dict[str, Any]] = []
+            batch_faces.append(faces)
+            if image_embeddings is None or image_embeddings[i].shape[0] == 0:
+                continue
+
+            height, width, _ = image.shape
+            embeddings = image_embeddings[i].tolist()
+            bboxes = batch_det[i][:, :4].round().tolist()
+            det_scores = batch_det[i][:, 4].tolist()
+            for (x1, y1, x2, y2), embedding, det_score in zip(bboxes, embeddings, det_scores):
+                face = {
+                    "imageWidth": width,
+                    "imageHeight": height,
+                    "boundingBox": {
+                        "x1": x1,
+                        "y1": y1,
+                        "x2": x2,
+                        "y2": y2,
+                    },
+                    "score": det_score,
+                    "embedding": embedding,
+                }
+                faces.append(face)
+        return batch_faces
+
+    def _add_batch_dimension(self, model_path: Path) -> None:
+        # rewrite the fixed leading dimension of every graph input/output to a symbolic "batch" dim
+        proto = onnx.load(model_path.as_posix())
+        inputs = {
+            input.name: ["batch"] + [dim.dim_value for dim in input.type.tensor_type.shape.dim[1:]]
+            for input in proto.graph.input
+        }
+        outputs = {
+            output.name: ["batch"] + [dim.dim_value for dim in output.type.tensor_type.shape.dim[1:]]
+            for output in proto.graph.output
+        }
+        proto = update_model_dims.update_inputs_outputs_dims(proto, inputs, outputs)
+        onnx.save(proto, model_path.as_posix())
+
     @property
     def cached(self) -> bool:
         return self.cache_dir.is_dir() and any(self.cache_dir.glob("*.onnx"))
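The per-image grouping relies on cumulative offsets: get_feat returns one flat array of embeddings for every face in the batch, and np.array_split cuts it back apart at the offsets recorded in _preprocess. A small illustration with made-up face counts:

# Illustrative only: 3 images with 2, 0, and 3 detected faces respectively.
import numpy as np

face_counts = [2, 0, 3]
offsets = list(np.cumsum(face_counts))            # [2, 2, 5], as built by _preprocess
flat_embeddings = np.arange(5 * 4).reshape(5, 4)  # stand-in for rec_model.get_feat(...) output

per_image = np.array_split(flat_embeddings, offsets)
print([chunk.shape[0] for chunk in per_image])  # [2, 0, 3, 0]; the trailing empty split is never indexed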

View file

@@ -63,13 +63,18 @@ class ImageClassifier(InferenceModel):
             feature_extractor=processor,
         )

-    def _predict(self, image: Image.Image | bytes) -> list[str]:
-        if isinstance(image, bytes):
-            image = Image.open(BytesIO(image))
-        predictions: list[dict[str, Any]] = self.model(image)  # type: ignore
-        tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
-        return tags
+    def _predict_batch(self, images: list[Image.Image | bytes]) -> list[list[str]]:
+        for i, image in enumerate(images):
+            if isinstance(image, bytes):
+                images[i] = Image.open(BytesIO(image))
+
+        batch_predictions: list[list[dict[str, Any]]] = self.model(images)
+        results = [self._postprocess(predictions) for predictions in batch_predictions]
+        return results
+
+    def _postprocess(self, predictions: list[dict[str, Any]]) -> list[str]:
+        return [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]

     def configure(self, **model_kwargs: Any) -> None:
         self.min_score = model_kwargs.pop("minScore", self.min_score)
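For reference, the tag flattening in _postprocess behaves as below; the labels and scores are made up, labels with multiple comma-separated tags are split apart, and predictions below min_score are dropped:

# Illustrative predictions for a single image, shaped like the classification pipeline's output.
predictions = [
    {"label": "tabby, tabby cat", "score": 0.72},
    {"label": "tiger cat", "score": 0.20},
    {"label": "remote control", "score": 0.03},
]
min_score = 0.1

tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= min_score]
print(tags)  # ['tabby', 'tabby cat', 'tiger cat']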

View file

@@ -1,5 +1,7 @@
 from enum import StrEnum
+from typing import Any, TypeAlias

+import numpy as np
 from pydantic import BaseModel
@@ -31,3 +33,6 @@ class ModelType(StrEnum):
     IMAGE_CLASSIFICATION = "image-classification"
     CLIP = "clip"
     FACIAL_RECOGNITION = "facial-recognition"
+
+
+ndarray: TypeAlias = np.ndarray[int, np.dtype[Any]]

View file

@@ -22,16 +22,17 @@ from .schemas import ModelType
 ndarray: TypeAlias = np.ndarray[int, np.dtype[np.float32]]


+@pytest.mark.asyncio
 class TestImageClassifier:
-    classifier_preds = [
+    classifier_preds = [[
         {"label": "that's an image alright", "score": 0.8},
         {"label": "well it ends with .jpg", "score": 0.1},
         {"label": "idk, im just seeing bytes", "score": 0.05},
         {"label": "not sure", "score": 0.04},
         {"label": "probably a virus", "score": 0.01},
-    ]
+    ]]

-    def test_min_score(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
+    async def test_min_score(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
         mocker.patch.object(ImageClassifier, "load")
         classifier = ImageClassifier("test_model_name", min_score=0.0)
         assert classifier.min_score == 0.0
@@ -39,9 +40,9 @@ class TestImageClassifier:
         classifier.model = mock.Mock()
         classifier.model.return_value = self.classifier_preds

-        all_labels = classifier.predict(pil_image)
+        all_labels = await classifier.predict(pil_image)
         classifier.min_score = 0.5
-        filtered_labels = classifier.predict(pil_image)
+        filtered_labels = await classifier.predict(pil_image)

         assert all_labels == [
             "that's an image alright",
@@ -54,29 +55,30 @@ class TestImageClassifier:
         assert filtered_labels == ["that's an image alright"]


+@pytest.mark.asyncio
 class TestCLIP:
-    embedding = np.random.rand(512).astype(np.float32)
+    embedding = np.random.rand(1, 512).astype(np.float32)

-    def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
+    async def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
         mocker.patch.object(CLIPEncoder, "download")
         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
-        mocked.return_value.run.return_value = [[self.embedding]]
+        mocked.return_value.run.return_value = [self.embedding]

         clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="vision")
         assert clip_encoder.mode == "vision"
-        embedding = clip_encoder.predict(pil_image)
+        embedding = await clip_encoder.predict(pil_image)

         assert isinstance(embedding, list)
         assert len(embedding) == 512
         assert all([isinstance(num, float) for num in embedding])
         clip_encoder.vision_model.run.assert_called_once()

-    def test_basic_text(self, mocker: MockerFixture) -> None:
+    async def test_basic_text(self, mocker: MockerFixture) -> None:
         mocker.patch.object(CLIPEncoder, "download")
         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
-        mocked.return_value.run.return_value = [[self.embedding]]
+        mocked.return_value.run.return_value = [self.embedding]

         clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="text")
         assert clip_encoder.mode == "text"
-        embedding = clip_encoder.predict("test search query")
+        embedding = await clip_encoder.predict("test search query")

         assert isinstance(embedding, list)
         assert len(embedding) == 512
@@ -91,7 +93,8 @@ class TestFaceRecognition:
         assert face_recognizer.min_score == 0.5

-    def test_basic(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
+    @pytest.mark.asyncio
+    async def test_basic(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
         mocker.patch.object(FaceRecognizer, "load")

         face_recognizer = FaceRecognizer("test_model_name", min_score=0.0, cache_dir="test_cache")
@@ -108,7 +111,7 @@ class TestFaceRecognition:
         rec_model.get_feat.return_value = embedding
         face_recognizer.rec_model = rec_model

-        faces = face_recognizer.predict(cv_image)
+        faces = await face_recognizer.predict(cv_image)

         assert len(faces) == num_faces
         for face in faces:
@@ -119,7 +122,7 @@ class TestFaceRecognition:
             assert all([isinstance(num, float) for num in face["embedding"]])

         det_model.detect.assert_called_once()
-        assert rec_model.get_feat.call_count == num_faces
+        rec_model.get_feat.assert_called_once()


 @pytest.mark.asyncio