conftest.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. from types import SimpleNamespace
  2. from typing import Any, Iterator, TypeAlias
  3. from unittest import mock
  4. import numpy as np
  5. import pytest
  6. from fastapi.testclient import TestClient
  7. from PIL import Image
  8. from .main import app, init_state
  9. ndarray: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
  10. @pytest.fixture
  11. def pil_image() -> Image.Image:
  12. return Image.new("RGB", (600, 800))
  13. @pytest.fixture
  14. def cv_image(pil_image: Image.Image) -> ndarray:
  15. return np.asarray(pil_image)[:, :, ::-1] # PIL uses RGB while cv2 uses BGR
  16. @pytest.fixture
  17. def mock_classifier_pipeline() -> Iterator[mock.Mock]:
  18. with mock.patch("app.models.image_classification.pipeline") as model:
  19. classifier_preds = [
  20. {"label": "that's an image alright", "score": 0.8},
  21. {"label": "well it ends with .jpg", "score": 0.1},
  22. {"label": "idk, im just seeing bytes", "score": 0.05},
  23. {"label": "not sure", "score": 0.04},
  24. {"label": "probably a virus", "score": 0.01},
  25. ]
  26. def forward(
  27. inputs: Image.Image | list[Image.Image], **kwargs: Any
  28. ) -> list[dict[str, Any]] | list[list[dict[str, Any]]]:
  29. if isinstance(inputs, list) and not all([isinstance(img, Image.Image) for img in inputs]):
  30. raise TypeError
  31. elif not isinstance(inputs, Image.Image):
  32. raise TypeError
  33. if isinstance(inputs, list):
  34. return [classifier_preds] * len(inputs)
  35. return classifier_preds
  36. model.return_value = forward
  37. yield model
  38. @pytest.fixture
  39. def mock_st() -> Iterator[mock.Mock]:
  40. with mock.patch("app.models.clip.SentenceTransformer") as model:
  41. embedding = np.random.rand(512).astype(np.float32)
  42. def encode(inputs: Image.Image | list[Image.Image], **kwargs: Any) -> ndarray | list[ndarray]:
  43. # mypy complains unless isinstance(inputs, list) is used explicitly
  44. img_batch = isinstance(inputs, list) and all([isinstance(inst, Image.Image) for inst in inputs])
  45. text_batch = isinstance(inputs, list) and all([isinstance(inst, str) for inst in inputs])
  46. if isinstance(inputs, list) and not any([img_batch, text_batch]):
  47. raise TypeError
  48. if isinstance(inputs, list):
  49. return np.stack([embedding] * len(inputs))
  50. return embedding
  51. mocked = mock.Mock()
  52. mocked.encode = encode
  53. model.return_value = mocked
  54. yield model
  55. @pytest.fixture
  56. def mock_faceanalysis() -> Iterator[mock.Mock]:
  57. with mock.patch("app.models.facial_recognition.FaceAnalysis") as model:
  58. face_preds = [
  59. SimpleNamespace( # this is so these fields can be accessed through dot notation
  60. **{
  61. "bbox": np.random.rand(4).astype(np.float32),
  62. "kps": np.random.rand(5, 2).astype(np.float32),
  63. "det_score": np.array([0.67]).astype(np.float32),
  64. "normed_embedding": np.random.rand(512).astype(np.float32),
  65. }
  66. ),
  67. SimpleNamespace(
  68. **{
  69. "bbox": np.random.rand(4).astype(np.float32),
  70. "kps": np.random.rand(5, 2).astype(np.float32),
  71. "det_score": np.array([0.4]).astype(np.float32),
  72. "normed_embedding": np.random.rand(512).astype(np.float32),
  73. }
  74. ),
  75. ]
  76. def get(image: np.ndarray[int, np.dtype[np.float32]], **kwargs: Any) -> list[SimpleNamespace]:
  77. if not isinstance(image, np.ndarray):
  78. raise TypeError
  79. return face_preds
  80. mocked = mock.Mock()
  81. mocked.get = get
  82. model.return_value = mocked
  83. yield model
  84. @pytest.fixture
  85. def mock_get_model() -> Iterator[mock.Mock]:
  86. with mock.patch("app.models.cache.InferenceModel.from_model_type", autospec=True) as mocked:
  87. yield mocked
  88. @pytest.fixture(scope="session")
  89. def deployed_app() -> TestClient:
  90. init_state()
  91. return TestClient(app)