Compare: main...feat/dynam (2 commits)

Commits: bdf8c9f1a9, 57506aa1fe

10 changed files with 285 additions and 76 deletions
machine-learning/app/config.py

@@ -21,6 +21,8 @@ class Settings(BaseSettings):
     request_threads: int = os.cpu_count() or 4
     model_inter_op_threads: int = 1
     model_intra_op_threads: int = 2
+    max_batch_size: int = 1
+    batch_timeout_s: float = 0.005

     class Config:
         env_prefix = "MACHINE_LEARNING_"
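The two new knobs ride on pydantic's BaseSettings, so deployments can tune batching without code changes. A minimal sketch of how they would be set; the environment variable names are inferred from env_prefix plus the field names, not quoted from any docs:

# Hedged sketch: configure the new batching settings via the environment.
import os

os.environ["MACHINE_LEARNING_MAX_BATCH_SIZE"] = "16"
os.environ["MACHINE_LEARNING_BATCH_TIMEOUT_S"] = "0.01"

from app.config import Settings

settings = Settings()  # pydantic parses the MACHINE_LEARNING_* strings
assert settings.max_batch_size == 16
assert settings.batch_timeout_s == 0.01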
machine-learning/app/main.py

@@ -11,6 +11,7 @@ from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchF
 from starlette.formparsers import MultiPartParser

 from app.models.base import InferenceModel
+from app.models.batcher import Batcher, ModelBatcher

 from .config import log, settings
 from .models.cache import ModelCache
@@ -26,6 +27,7 @@ app = FastAPI()

 def init_state() -> None:
     app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+    app.state.model_batcher = ModelBatcher(max_size=settings.max_batch_size, timeout_s=settings.batch_timeout_s)
     log.info(
         (
             "Created in-memory cache with unloading "
@@ -62,9 +64,9 @@ async def predict(
     image: UploadFile | None = None,
 ) -> Any:
     if image is not None:
-        inputs: str | bytes = await image.read()
+        element: str | bytes = await image.read()
     elif text is not None:
-        inputs = text
+        element = text
     else:
         raise HTTPException(400, "Either image or text must be provided")
     try:
@@ -74,15 +76,21 @@ async def predict(

     model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
     model.configure(**kwargs)
-    outputs = await run(model, inputs)
+
+    if settings.max_batch_size > 1:
+        batcher: Batcher = app.state.model_batcher.get(model_name, model_type, **kwargs)
+        outputs = await batcher.batch_process(element, run, model)
+    else:
+        outputs = await run(model, [element])

     return ORJSONResponse(outputs)


-async def run(model: InferenceModel, inputs: Any) -> Any:
+async def run(model: InferenceModel, elements: list[Any]) -> Any:
     if app.state.thread_pool is None:
-        return model.predict(inputs)
+        return model.predict_batch(elements)

-    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
+    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict_batch, elements)


 async def load(model: InferenceModel) -> InferenceModel:
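To see what the max_batch_size > 1 branch buys: several concurrent /predict requests for the same model each contribute one element, and the shared Batcher invokes run once over the collected list. A rough fan-in sketch; the fake model and element values are illustrative only, not from the PR:

# Hedged sketch of the request fan-in set up by the hunk above.
import asyncio

from app.models.batcher import Batcher

async def fake_run(model: object, elements: list[str]) -> list[str]:
    # Stands in for run(): one inference call services the whole batch.
    return [f"{e}-embedding" for e in elements]

async def main() -> None:
    batcher = Batcher(max_size=8, timeout_s=0.005)
    # Ten "requests" arrive together; batch_process hands each caller
    # back only its own output, in order.
    outputs = await asyncio.gather(*(batcher.batch_process(f"img{i}", fake_run, None) for i in range(10)))
    print(outputs)  # ['img0-embedding', ..., 'img9-embedding']

asyncio.run(main())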
machine-learning/app/models/base.py

@@ -12,6 +12,10 @@ from ..config import get_cache_dir, log, settings
 from ..schemas import ModelType


+def get_model_key(model_name: str, model_type: ModelType, mode: str = "") -> str:
+    return f"{model_name}{model_type.value}{mode}"
+
+
 class InferenceModel(ABC):
     _model_type: ModelType

@@ -65,14 +69,15 @@ class InferenceModel(ABC):
         self._load()
         self.loaded = True

-    def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
+    def predict(self, element: Any) -> Any:
+        return self.predict_batch([element])[0]
+
+    def predict_batch(self, inputs: list[Any]) -> list[Any]:
         self.load()
-        if model_kwargs:
-            self.configure(**model_kwargs)
-        return self._predict(inputs)
+        return self._predict_batch(inputs)

     @abstractmethod
-    def _predict(self, inputs: Any) -> Any:
+    def _predict_batch(self, inputs: list[Any]) -> Any:
         ...

     def configure(self, **model_kwargs: Any) -> None:
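Single-item predict now funnels through predict_batch, so subclasses only implement _predict_batch and the one-element case reuses the batched path. A stripped-down sketch of that contract; the echo model is hypothetical:

# Hedged sketch of the base-class contract after this hunk.
from typing import Any

class EchoModel:  # stands in for an InferenceModel subclass
    def predict(self, element: Any) -> Any:
        return self.predict_batch([element])[0]

    def predict_batch(self, inputs: list[Any]) -> list[Any]:
        return self._predict_batch(inputs)

    def _predict_batch(self, inputs: list[Any]) -> list[Any]:
        return [f"processed:{x}" for x in inputs]

assert EchoModel().predict("a") == "processed:a"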
machine-learning/app/models/batcher.py (new file, +65)

@@ -0,0 +1,65 @@
+import asyncio
+import time
+from typing import Any, Awaitable, Callable, TypeVar
+
+from app.schemas import ModelType
+
+from .base import get_model_key, log
+
+
+F = TypeVar("F")
+P = TypeVar("P")
+R = TypeVar("R")
+
+
+class Batcher:
+    def __init__(self, max_size: int = 16, timeout_s: float = 0.005) -> None:
+        self.max_size = max_size
+        self.timeout_s = timeout_s
+        self.queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=max_size)
+        self.lock = asyncio.Lock()
+        self.processing: dict[int, Any] = {}
+        self.processed: dict[int, Any] = {}
+        self.element_id = 0
+
+    async def batch_process(self, element: Any, func: Callable[[list[P]], Awaitable[list[R]]], *args: Any, **kwargs: Any) -> Any:
+        cur_idx = self.element_id
+        self.element_id += 1
+        self.processing[cur_idx] = element
+        await self.queue.put(cur_idx)
+        try:
+            async with self.lock:
+                await self._batch(cur_idx)
+                if cur_idx not in self.processed:
+                    await self._process(func, *args, **kwargs)
+            return self.processed.pop(cur_idx)
+        finally:
+            self.processing.pop(cur_idx, None)
+            self.processed.pop(cur_idx, None)
+
+    async def _batch(self, idx: int) -> None:
+        if idx not in self.processed:
+            start = time.monotonic()
+            while self.queue.qsize() < self.max_size and time.monotonic() - start < self.timeout_s:
+                await asyncio.sleep(0)
+
+    async def _process(self, func: Callable[[list[P]], Awaitable[list[R]]], *args: Any, **kwargs: Any) -> None:
+        batch_ids = [self.queue.get_nowait() for _ in range(self.queue.qsize())]
+        batch = [self.processing.pop(id) for id in batch_ids]
+        outputs = await func(*args, batch, **kwargs)
+        for id, output in zip(batch_ids, outputs):
+            self.processed[id] = output
+
+
+class ModelBatcher:
+    def __init__(self, max_size: int = 16, timeout_s: float = 0.005) -> None:
+        self.batchers: dict[str, Batcher] = {}
+        self.max_size = max_size
+        self.timeout_s = timeout_s
+
+    def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> Batcher:
+        key = get_model_key(model_name, model_type, model_kwargs.get("mode", ""))
+        if key not in self.batchers:
+            self.batchers[key] = Batcher(max_size=self.max_size, timeout_s=self.timeout_s)
+
+        return self.batchers[key]
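The coordination protocol: every caller registers its element under a fresh id and enqueues the id; the first caller to win the lock spins via asyncio.sleep(0) until the queue reaches max_size or timeout_s elapses, drains the queue, runs func once over the whole batch, and files each output under its id; later lock-holders find their output already in processed and simply pop it. The busy-wait trades a little CPU for low latency; an asyncio.Condition would be the gentler alternative. A self-contained demo of the class above, with a made-up doubling "model":

# Hedged demo: exercises Batcher as written in the new file.
import asyncio

from app.models.batcher import Batcher

async def double_all(batch: list[int]) -> list[int]:
    print(f"ran one batch of {len(batch)}")  # typically prints once here
    return [x * 2 for x in batch]

async def main() -> None:
    batcher = Batcher(max_size=4, timeout_s=0.005)
    # Four concurrent callers fill the queue, so one inference call serves all.
    results = await asyncio.gather(*(batcher.batch_process(i, double_all) for i in range(4)))
    assert results == [0, 2, 4, 6]  # each caller got its own element's output

asyncio.run(main())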
machine-learning/app/models/cache.py

@@ -5,7 +5,7 @@ from aiocache.lock import OptimisticLock
 from aiocache.plugins import BasePlugin, TimingPlugin

 from ..schemas import ModelType
-from .base import InferenceModel
+from .base import InferenceModel, get_model_key


 class ModelCache:

@@ -46,7 +46,7 @@ class ModelCache:
             model: The requested model.
         """

-        key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
+        key = get_model_key(model_name, model_type, model_kwargs.get("mode", ""))
         async with OptimisticLock(self.cache, key) as lock:
             model = await self.cache.get(key)
             if model is None:
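Centralizing the key format in get_model_key keeps the cache and the batcher registry pointing at the same entry for a given model. A worked example of the key it produces, with illustrative values; note the parts are concatenated with no separator, which the codebase evidently accepts:

from app.models.base import get_model_key
from app.schemas import ModelType

key = get_model_key("ViT-B-32::openai", ModelType.CLIP, "vision")
assert key == "ViT-B-32::openaiclipvision"  # name + type value + mode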
machine-learning/app/models/clip.py

@@ -1,8 +1,9 @@
 import os
 import zipfile
 from io import BytesIO
-from typing import Any, Literal
+from typing import Any, Literal, TypeGuard

 import numpy as np
 import onnxruntime as ort
 import torch
 from clip_server.model.clip import BICUBIC, _convert_image_to_rgb

@@ -17,6 +18,14 @@ from ..schemas import ModelType
 from .base import InferenceModel


+def is_image_list(images: list[Any]) -> TypeGuard[list[Image.Image | bytes]]:
+    return any(isinstance(image, (Image.Image, bytes)) for image in images)
+
+
+def is_text_list(texts: list[Any]) -> TypeGuard[list[str]]:
+    return any(isinstance(text, str) for text in texts)
+
+
 class CLIPEncoder(InferenceModel):
     _model_type = ModelType.CLIP

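TypeGuard lets the type checker narrow list[Any] to the concrete element type after the check. Note both guards use any(), so they only sample the list rather than prove it is homogeneous; a stricter hypothetical variant would use all(), at the cost of a full scan:

from typing import Any, TypeGuard

from PIL import Image

def is_image_list_strict(images: list[Any]) -> TypeGuard[list[Image.Image | bytes]]:
    # all() rejects mixed lists, unlike the any() guards in this diff
    return all(isinstance(image, (Image.Image, bytes)) for image in images)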
@@ -70,31 +79,42 @@ class CLIPEncoder(InferenceModel):
         image_size = _VISUAL_MODEL_IMAGE_SIZE[CLIPOnnxModel.get_model_name(self.model_name)]
         self.transform = _transform_pil_image(image_size)

-    def _predict(self, image_or_text: Image.Image | str) -> list[float]:
-        if isinstance(image_or_text, bytes):
-            image_or_text = Image.open(BytesIO(image_or_text))
-
-        match image_or_text:
-            case Image.Image():
-                if self.mode == "text":
-                    raise TypeError("Cannot encode image as text-only model")
-                pixel_values = self.transform(image_or_text)
-                assert isinstance(pixel_values, torch.Tensor)
-                pixel_values = torch.unsqueeze(pixel_values, 0).numpy()
-                outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
-            case str():
-                if self.mode == "vision":
-                    raise TypeError("Cannot encode text as vision-only model")
-                text_inputs: dict[str, torch.Tensor] = self.tokenizer(image_or_text)
-                inputs = {
-                    "input_ids": text_inputs["input_ids"].int().numpy(),
-                    "attention_mask": text_inputs["attention_mask"].int().numpy(),
-                }
-                outputs = self.text_model.run(self.text_outputs, inputs)
-            case _:
-                raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
-
-        return outputs[0][0].tolist()
+    def _predict_batch(self, images_or_text: list[Image.Image | bytes] | list[str]) -> list[list[float]]:
+        if not images_or_text:
+            return []
+
+        if is_image_list(images_or_text):
+            outputs = self._predict_images(images_or_text)
+        elif is_text_list(images_or_text):
+            outputs = self._predict_text(images_or_text)
+        else:
+            raise TypeError(f"Expected list of images or text, but got: {type(images_or_text[0])}")
+
+        return outputs
+
+    def _predict_images(self, images: list[Image.Image | bytes]) -> list[list[float]]:
+        if not images:
+            return []
+
+        for i, element in enumerate(images):
+            if isinstance(element, bytes):
+                images[i] = Image.open(BytesIO(element))
+
+        pixel_values = torch.stack([self.transform(image) for image in images]).numpy()
+        outputs = self.vision_model.run(self.vision_outputs, {"pixel_values": pixel_values})
+        return outputs[0].tolist()
+
+    def _predict_text(self, texts: list[str]) -> list[list[float]]:
+        if not texts:
+            return []
+
+        text_inputs: dict[str, torch.Tensor] = self.tokenizer(texts)
+        inputs = {
+            "input_ids": text_inputs["input_ids"].int().numpy(),
+            "attention_mask": text_inputs["attention_mask"].int().numpy(),
+        }
+        outputs = self.text_model.run(self.text_outputs, inputs)
+        return outputs[0].tolist()

     def _download_model(self, model_name: str, model_md5: str) -> bool:
         # downloading logic is adapted from clip-server's CLIPOnnxModel class
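The batched path leans on ONNX Runtime's output convention: session.run returns one array per model output, so outputs[0] is the embedding matrix of shape (batch, 512) and .tolist() yields one embedding list per input. A small shape check; the dimensions are assumed from the tests further down:

import numpy as np

# outputs as returned by InferenceSession.run: a list of output arrays
outputs = [np.random.rand(3, 512).astype(np.float32)]
embeddings = outputs[0].tolist()
assert len(embeddings) == 3 and len(embeddings[0]) == 512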
machine-learning/app/models/facial_recognition.py

@@ -9,9 +9,12 @@ from insightface.model_zoo import ArcFaceONNX, RetinaFace
 from insightface.utils.face_align import norm_crop
 from insightface.utils.storage import BASE_REPO_URL, download_file

-from ..schemas import ModelType
+from ..schemas import ModelType, ndarray
 from .base import InferenceModel

+import onnx
+from onnx.tools import update_model_dims
+

 class FaceRecognizer(InferenceModel):
     _model_type = ModelType.FACIAL_RECOGNITION
@@ -36,6 +39,8 @@ class FaceRecognizer(InferenceModel):
             zip.extractall(self.cache_dir, members=[det_file, rec_file])
         zip_file.unlink()

+        self._add_batch_dimension(self.cache_dir / rec_file)
+
     def _load(self) -> None:
         try:
             det_file = next(self.cache_dir.glob("det_*.onnx"))
@@ -43,29 +48,36 @@ class FaceRecognizer(InferenceModel):
         except StopIteration:
             raise FileNotFoundError("Facial recognition models not found in cache directory")

-        self.det_model = RetinaFace(
-            session=ort.InferenceSession(
-                det_file.as_posix(),
-                sess_options=self.sess_options,
-                providers=self.providers,
-                provider_options=self.provider_options,
-            ),
+        det_session = ort.InferenceSession(
+            det_file.as_posix(),
+            sess_options=self.sess_options,
+            providers=self.providers,
+            provider_options=self.provider_options,
         )
+
+        self.det_model = RetinaFace(session=det_session)
         self.det_model.prepare(
             ctx_id=0,
             det_thresh=self.min_score,
             input_size=(640, 640),
         )
-        self.rec_model = ArcFaceONNX(
-            rec_file.as_posix(),
-            session=ort.InferenceSession(
-                rec_file.as_posix(),
-                sess_options=self.sess_options,
-                providers=self.providers,
-                provider_options=self.provider_options,
-            ),
-        )
+
+        rec_session = ort.InferenceSession(
+            rec_file.as_posix(),
+            sess_options=self.sess_options,
+            providers=self.providers,
+            provider_options=self.provider_options,
+        )
+        if rec_session.get_inputs()[0].shape[0] != "batch":
+            del rec_session
+            self._add_batch_dimension(rec_file)
+            rec_session = ort.InferenceSession(
+                rec_file.as_posix(),
+                sess_options=self.sess_options,
+                providers=self.providers,
+                provider_options=self.provider_options,
+            )
+        self.rec_model = ArcFaceONNX(rec_file.as_posix(), session=rec_session)
         self.rec_model.prepare(ctx_id=0)

     def _predict(self, image: np.ndarray[int, np.dtype[Any]] | bytes) -> list[dict[str, Any]]:
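The shape[0] != "batch" probe works because update_inputs_outputs_dims rewrites the first dimension of every input and output to the symbolic name "batch"; a reloaded session then reports that string instead of a fixed integer. A sketch of the check in isolation, with a hypothetical model path:

import onnxruntime as ort

session = ort.InferenceSession("rec_model.onnx")  # hypothetical path
dim0 = session.get_inputs()[0].shape[0]
# A fixed-size model reports an int (e.g. 1); after the rewrite, "batch".
needs_rewrite = dim0 != "batch"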
@@ -100,6 +112,90 @@ class FaceRecognizer(InferenceModel):
             )
         return results

+    def _predict_batch(self, images: list[cv2.Mat]) -> list[list[dict[str, Any]]]:
+        for i, image in enumerate(images):
+            if isinstance(image, bytes):
+                images[i] = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
+
+        batch_det, batch_kpss = self._detect(images)
+        batch_cropped_images, batch_offsets = self._preprocess(images, batch_kpss)
+        if batch_cropped_images:
+            batch_embeddings = self.rec_model.get_feat(batch_cropped_images)
+            results = self._postprocess(images, batch_det, batch_embeddings, batch_offsets)
+        else:
+            results = self._postprocess(images, batch_det)
+        return results
+
+    def _detect(self, images: list[cv2.Mat]) -> tuple[list[ndarray], ...]:
+        batch_det: list[ndarray] = []
+        batch_kpss: list[ndarray] = []
+        # detection model doesn't support batching, but recognition model does
+        for image in images:
+            bboxes, kpss = self.det_model.detect(image)
+            batch_det.append(bboxes)
+            batch_kpss.append(kpss)
+        return batch_det, batch_kpss
+
+    def _preprocess(self, images: list[cv2.Mat], batch_kpss: list[ndarray]) -> tuple[list[cv2.Mat], list[int]]:
+        batch_cropped_images = []
+        batch_offsets = []
+        total_faces = 0
+        for i, image in enumerate(images):
+            kpss = batch_kpss[i]
+            total_faces += kpss.shape[0]
+            batch_offsets.append(total_faces)
+            for kps in kpss:
+                batch_cropped_images.append(norm_crop(image, kps))
+        return batch_cropped_images, batch_offsets
+
+    def _postprocess(
+        self,
+        images: list[cv2.Mat],
+        batch_det: list[ndarray],
+        batch_embeddings: ndarray | None = None,
+        batch_offsets: list[int] | None = None,
+    ) -> list[list[dict[str, Any]]]:
+        if batch_embeddings is not None and batch_offsets is not None:
+            image_embeddings: list[ndarray] | None = np.array_split(batch_embeddings, batch_offsets)
+        else:
+            image_embeddings = None
+
+        batch_faces: list[list[dict[str, Any]]] = []
+        for i, image in enumerate(images):
+            faces: list[dict[str, Any]] = []
+            batch_faces.append(faces)
+            if image_embeddings is None or image_embeddings[i].shape[0] == 0:
+                continue
+
+            height, width, _ = image.shape
+
+            embeddings = image_embeddings[i].tolist()
+            bboxes = batch_det[i][:, :4].round().tolist()
+            det_scores = batch_det[i][:, 4].tolist()
+            for (x1, y1, x2, y2), embedding, det_score in zip(bboxes, embeddings, det_scores):
+                face = {
+                    "imageWidth": width,
+                    "imageHeight": height,
+                    "boundingBox": {
+                        "x1": x1,
+                        "y1": y1,
+                        "x2": x2,
+                        "y2": y2,
+                    },
+                    "score": det_score,
+                    "embedding": embedding,
+                }
+
+                faces.append(face)
+        return batch_faces
+
+    def _add_batch_dimension(self, model_path: Path) -> None:
+        rec_proto = onnx.load(model_path.as_posix())
+        inputs = {input.name: ["batch"] + [shape.dim_value for shape in input.type.tensor_type.shape.dim[1:]] for input in rec_proto.graph.input}
+        outputs = {output.name: ["batch"] + [shape.dim_value for shape in output.type.tensor_type.shape.dim[1:]] for output in rec_proto.graph.output}
+        rec_proto = update_model_dims.update_inputs_outputs_dims(rec_proto, inputs, outputs)
+        onnx.save(rec_proto, model_path.as_posix())
+
     @property
     def cached(self) -> bool:
         return self.cache_dir.is_dir() and any(self.cache_dir.glob("*.onnx"))
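_preprocess records cumulative face counts, which is exactly the split-point format np.array_split expects, so _postprocess can carve the flat embedding matrix back into per-image groups. A worked example with made-up counts:

import numpy as np

# Three images with 2, 0 and 3 detected faces -> cumulative offsets [2, 2, 5]
batch_offsets = [2, 2, 5]
batch_embeddings = np.arange(5 * 4).reshape(5, 4)  # 5 faces, toy 4-dim embeddings

image_embeddings = np.array_split(batch_embeddings, batch_offsets)
# -> rows [0:2], [2:2] (empty), [2:5], plus a harmless trailing empty split [5:]
assert [e.shape[0] for e in image_embeddings] == [2, 0, 3, 0]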
machine-learning/app/models/image_classification.py

@@ -63,13 +63,18 @@ class ImageClassifier(InferenceModel):
             feature_extractor=processor,
         )

-    def _predict(self, image: Image.Image | bytes) -> list[str]:
-        if isinstance(image, bytes):
-            image = Image.open(BytesIO(image))
-        predictions: list[dict[str, Any]] = self.model(image)  # type: ignore
-        tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
+    def _predict_batch(self, images: list[Image.Image | bytes]) -> list[list[str]]:
+        for i, image in enumerate(images):
+            if isinstance(image, bytes):
+                images[i] = Image.open(BytesIO(image))

-        return tags
+        batch_predictions: list[list[dict[str, Any]]] = self.model(images)
+        results = [self._postprocess(predictions) for predictions in batch_predictions]
+
+        return results
+
+    def _postprocess(self, predictions: list[dict[str, Any]]) -> list[str]:
+        return [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]

     def configure(self, **model_kwargs: Any) -> None:
         self.min_score = model_kwargs.pop("minScore", self.min_score)
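Passing a list to a Hugging Face pipeline object makes it return one prediction list per input, which is what lets _predict_batch stay a thin wrapper around a single model call. A minimal sketch; the model name and file paths are illustrative, not from this PR:

from transformers import pipeline

classifier = pipeline("image-classification", model="microsoft/resnet-50")
batch_predictions = classifier(["cat.jpg", "dog.jpg"])  # hypothetical local files
# -> [[{"label": ..., "score": ...}, ...], [{"label": ..., "score": ...}, ...]]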
machine-learning/app/schemas.py

@@ -1,5 +1,7 @@
 from enum import StrEnum
+from typing import Any, TypeAlias

+import numpy as np
 from pydantic import BaseModel


@@ -31,3 +33,6 @@ class ModelType(StrEnum):
     IMAGE_CLASSIFICATION = "image-classification"
     CLIP = "clip"
     FACIAL_RECOGNITION = "facial-recognition"
+
+
+ndarray: TypeAlias = np.ndarray[int, np.dtype[Any]]
machine-learning/app/test_main.py

@@ -22,16 +22,17 @@ from .schemas import ModelType
 ndarray: TypeAlias = np.ndarray[int, np.dtype[np.float32]]


+@pytest.mark.asyncio
 class TestImageClassifier:
-    classifier_preds = [
+    classifier_preds = [[
         {"label": "that's an image alright", "score": 0.8},
         {"label": "well it ends with .jpg", "score": 0.1},
         {"label": "idk, im just seeing bytes", "score": 0.05},
         {"label": "not sure", "score": 0.04},
         {"label": "probably a virus", "score": 0.01},
-    ]
+    ]]

-    def test_min_score(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
+    async def test_min_score(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
         mocker.patch.object(ImageClassifier, "load")
         classifier = ImageClassifier("test_model_name", min_score=0.0)
         assert classifier.min_score == 0.0

@@ -39,9 +40,9 @@ class TestImageClassifier:
         classifier.model = mock.Mock()
         classifier.model.return_value = self.classifier_preds

-        all_labels = classifier.predict(pil_image)
+        all_labels = await classifier.predict(pil_image)
         classifier.min_score = 0.5
-        filtered_labels = classifier.predict(pil_image)
+        filtered_labels = await classifier.predict(pil_image)

         assert all_labels == [
             "that's an image alright",

@@ -54,29 +55,30 @@ class TestImageClassifier:
         assert filtered_labels == ["that's an image alright"]


+@pytest.mark.asyncio
 class TestCLIP:
-    embedding = np.random.rand(512).astype(np.float32)
+    embedding = np.random.rand(1, 512).astype(np.float32)

-    def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
+    async def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
         mocker.patch.object(CLIPEncoder, "download")
         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
-        mocked.return_value.run.return_value = [[self.embedding]]
+        mocked.return_value.run.return_value = [self.embedding]
         clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="vision")
         assert clip_encoder.mode == "vision"
-        embedding = clip_encoder.predict(pil_image)
+        embedding = await clip_encoder.predict(pil_image)

         assert isinstance(embedding, list)
         assert len(embedding) == 512
         assert all([isinstance(num, float) for num in embedding])
         clip_encoder.vision_model.run.assert_called_once()

-    def test_basic_text(self, mocker: MockerFixture) -> None:
+    async def test_basic_text(self, mocker: MockerFixture) -> None:
         mocker.patch.object(CLIPEncoder, "download")
         mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
-        mocked.return_value.run.return_value = [[self.embedding]]
+        mocked.return_value.run.return_value = [self.embedding]
         clip_encoder = CLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="text")
         assert clip_encoder.mode == "text"
-        embedding = clip_encoder.predict("test search query")
+        embedding = await clip_encoder.predict("test search query")

         assert isinstance(embedding, list)
         assert len(embedding) == 512

@@ -91,7 +93,8 @@ class TestFaceRecognition:

         assert face_recognizer.min_score == 0.5

-    def test_basic(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
+    @pytest.mark.asyncio
+    async def test_basic(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
         mocker.patch.object(FaceRecognizer, "load")
         face_recognizer = FaceRecognizer("test_model_name", min_score=0.0, cache_dir="test_cache")

@@ -108,7 +111,7 @@ class TestFaceRecognition:
         rec_model.get_feat.return_value = embedding
         face_recognizer.rec_model = rec_model

-        faces = face_recognizer.predict(cv_image)
+        faces = await face_recognizer.predict(cv_image)

         assert len(faces) == num_faces
         for face in faces:

@@ -119,7 +122,7 @@ class TestFaceRecognition:
         assert all([isinstance(num, float) for num in face["embedding"]])

         det_model.detect.assert_called_once()
-        assert rec_model.get_feat.call_count == num_faces
+        rec_model.get_feat.assert_called_once()


 @pytest.mark.asyncio
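With pytest-asyncio, the class-level @pytest.mark.asyncio mark applies to every test coroutine in the class, which is why the tests above switch from def to async def. A hedged smoke test for the new Batcher in the same style; the echo function and values are made up:

import pytest

from app.models.batcher import Batcher


@pytest.mark.asyncio
class TestBatcherSmoke:
    async def test_single_element_roundtrip(self) -> None:
        async def echo(batch: list[int]) -> list[int]:
            return batch

        batcher = Batcher(max_size=4, timeout_s=0.001)
        # A lone caller waits out the timeout, then gets its own element back.
        assert await batcher.batch_process(3, echo) == 3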