# test_main.py (11 KB)

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. from io import BytesIO
  2. from typing import TypeAlias
  3. from unittest import mock
  4. import cv2
  5. import numpy as np
  6. import pytest
  7. from fastapi.testclient import TestClient
  8. from PIL import Image
  9. from pytest_mock import MockerFixture
  10. from .config import settings
  11. from .models.cache import ModelCache
  12. from .models.clip import CLIPSTEncoder
  13. from .models.facial_recognition import FaceRecognizer
  14. from .models.image_classification import ImageClassifier
  15. from .schemas import ModelType
  16. ndarray: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
  17. class TestImageClassifier:
  18. classifier_preds = [
  19. {"label": "that's an image alright", "score": 0.8},
  20. {"label": "well it ends with .jpg", "score": 0.1},
  21. {"label": "idk, im just seeing bytes", "score": 0.05},
  22. {"label": "not sure", "score": 0.04},
  23. {"label": "probably a virus", "score": 0.01},
  24. ]
  25. def test_eager_init(self, mocker: MockerFixture) -> None:
  26. mocker.patch.object(ImageClassifier, "download")
  27. mock_load = mocker.patch.object(ImageClassifier, "load")
  28. classifier = ImageClassifier("test_model_name", cache_dir="test_cache", eager=True, test_arg="test_arg")
  29. assert classifier.model_name == "test_model_name"
  30. mock_load.assert_called_once_with(test_arg="test_arg")
  31. def test_lazy_init(self, mocker: MockerFixture) -> None:
  32. mock_download = mocker.patch.object(ImageClassifier, "download")
  33. mock_load = mocker.patch.object(ImageClassifier, "load")
  34. face_model = ImageClassifier("test_model_name", cache_dir="test_cache", eager=False, test_arg="test_arg")
  35. assert face_model.model_name == "test_model_name"
  36. mock_download.assert_called_once_with(test_arg="test_arg")
  37. mock_load.assert_not_called()
  38. def test_min_score(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
  39. mocker.patch.object(ImageClassifier, "load")
  40. classifier = ImageClassifier("test_model_name", min_score=0.0)
  41. assert classifier.min_score == 0.0
  42. classifier.model = mock.Mock()
  43. classifier.model.return_value = self.classifier_preds
  44. all_labels = classifier.predict(pil_image)
  45. classifier.min_score = 0.5
  46. filtered_labels = classifier.predict(pil_image)
  47. assert all_labels == [
  48. "that's an image alright",
  49. "well it ends with .jpg",
  50. "idk",
  51. "im just seeing bytes",
  52. "not sure",
  53. "probably a virus",
  54. ]
  55. assert filtered_labels == ["that's an image alright"]
  56. class TestCLIP:
  57. embedding = np.random.rand(512).astype(np.float32)
  58. def test_eager_init(self, mocker: MockerFixture) -> None:
  59. mocker.patch.object(CLIPSTEncoder, "download")
  60. mock_load = mocker.patch.object(CLIPSTEncoder, "load")
  61. clip_model = CLIPSTEncoder("test_model_name", cache_dir="test_cache", eager=True, test_arg="test_arg")
  62. assert clip_model.model_name == "test_model_name"
  63. mock_load.assert_called_once_with(test_arg="test_arg")
  64. def test_lazy_init(self, mocker: MockerFixture) -> None:
  65. mock_download = mocker.patch.object(CLIPSTEncoder, "download")
  66. mock_load = mocker.patch.object(CLIPSTEncoder, "load")
  67. clip_model = CLIPSTEncoder("test_model_name", cache_dir="test_cache", eager=False, test_arg="test_arg")
  68. assert clip_model.model_name == "test_model_name"
  69. mock_download.assert_called_once_with(test_arg="test_arg")
  70. mock_load.assert_not_called()
  71. def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
  72. mocker.patch.object(CLIPSTEncoder, "load")
  73. clip_encoder = CLIPSTEncoder("test_model_name", cache_dir="test_cache")
  74. clip_encoder.model = mock.Mock()
  75. clip_encoder.model.encode.return_value = self.embedding
  76. embedding = clip_encoder.predict(pil_image)
  77. assert isinstance(embedding, list)
  78. assert len(embedding) == 512
  79. assert all([isinstance(num, float) for num in embedding])
  80. clip_encoder.model.encode.assert_called_once()
  81. def test_basic_text(self, mocker: MockerFixture) -> None:
  82. mocker.patch.object(CLIPSTEncoder, "load")
  83. clip_encoder = CLIPSTEncoder("test_model_name", cache_dir="test_cache")
  84. clip_encoder.model = mock.Mock()
  85. clip_encoder.model.encode.return_value = self.embedding
  86. embedding = clip_encoder.predict("test search query")
  87. assert isinstance(embedding, list)
  88. assert len(embedding) == 512
  89. assert all([isinstance(num, float) for num in embedding])
  90. clip_encoder.model.encode.assert_called_once()
  91. class TestFaceRecognition:
  92. def test_eager_init(self, mocker: MockerFixture) -> None:
  93. mocker.patch.object(FaceRecognizer, "download")
  94. mock_load = mocker.patch.object(FaceRecognizer, "load")
  95. face_model = FaceRecognizer("test_model_name", cache_dir="test_cache", eager=True, test_arg="test_arg")
  96. assert face_model.model_name == "test_model_name"
  97. mock_load.assert_called_once_with(test_arg="test_arg")
  98. def test_lazy_init(self, mocker: MockerFixture) -> None:
  99. mock_download = mocker.patch.object(FaceRecognizer, "download")
  100. mock_load = mocker.patch.object(FaceRecognizer, "load")
  101. face_model = FaceRecognizer("test_model_name", cache_dir="test_cache", eager=False, test_arg="test_arg")
  102. assert face_model.model_name == "test_model_name"
  103. mock_download.assert_called_once_with(test_arg="test_arg")
  104. mock_load.assert_not_called()
  105. def test_set_min_score(self, mocker: MockerFixture) -> None:
  106. mocker.patch.object(FaceRecognizer, "load")
  107. face_recognizer = FaceRecognizer("test_model_name", cache_dir="test_cache", min_score=0.5)
  108. assert face_recognizer.min_score == 0.5
  109. def test_basic(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
  110. mocker.patch.object(FaceRecognizer, "load")
  111. face_recognizer = FaceRecognizer("test_model_name", min_score=0.0, cache_dir="test_cache")
  112. det_model = mock.Mock()
  113. num_faces = 2
  114. bbox = np.random.rand(num_faces, 4).astype(np.float32)
  115. score = np.array([[0.67]] * num_faces).astype(np.float32)
  116. kpss = np.random.rand(num_faces, 5, 2).astype(np.float32)
  117. det_model.detect.return_value = (np.concatenate([bbox, score], axis=-1), kpss)
  118. face_recognizer.det_model = det_model
  119. rec_model = mock.Mock()
  120. embedding = np.random.rand(num_faces, 512).astype(np.float32)
  121. rec_model.get_feat.return_value = embedding
  122. face_recognizer.rec_model = rec_model
  123. faces = face_recognizer.predict(cv_image)
  124. assert len(faces) == num_faces
  125. for face in faces:
  126. assert face["imageHeight"] == 800
  127. assert face["imageWidth"] == 600
  128. assert isinstance(face["embedding"], list)
  129. assert len(face["embedding"]) == 512
  130. assert all([isinstance(num, float) for num in face["embedding"]])
  131. det_model.detect.assert_called_once()
  132. assert rec_model.get_feat.call_count == num_faces
  133. @pytest.mark.asyncio
  134. class TestCache:
  135. async def test_caches(self, mock_get_model: mock.Mock) -> None:
  136. model_cache = ModelCache()
  137. await model_cache.get("test_model_name", ModelType.IMAGE_CLASSIFICATION)
  138. await model_cache.get("test_model_name", ModelType.IMAGE_CLASSIFICATION)
  139. assert len(model_cache.cache._cache) == 1
  140. mock_get_model.assert_called_once()
  141. async def test_kwargs_used(self, mock_get_model: mock.Mock) -> None:
  142. model_cache = ModelCache()
  143. await model_cache.get("test_model_name", ModelType.IMAGE_CLASSIFICATION, cache_dir="test_cache")
  144. mock_get_model.assert_called_once_with(
  145. ModelType.IMAGE_CLASSIFICATION, "test_model_name", cache_dir="test_cache"
  146. )
  147. async def test_different_clip(self, mock_get_model: mock.Mock) -> None:
  148. model_cache = ModelCache()
  149. await model_cache.get("test_image_model_name", ModelType.CLIP)
  150. await model_cache.get("test_text_model_name", ModelType.CLIP)
  151. mock_get_model.assert_has_calls(
  152. [
  153. mock.call(ModelType.CLIP, "test_image_model_name"),
  154. mock.call(ModelType.CLIP, "test_text_model_name"),
  155. ]
  156. )
  157. assert len(model_cache.cache._cache) == 2
  158. @mock.patch("app.models.cache.OptimisticLock", autospec=True)
  159. async def test_model_ttl(self, mock_lock_cls: mock.Mock, mock_get_model: mock.Mock) -> None:
  160. model_cache = ModelCache(ttl=100)
  161. await model_cache.get("test_model_name", ModelType.IMAGE_CLASSIFICATION)
  162. mock_lock_cls.return_value.__aenter__.return_value.cas.assert_called_with(mock.ANY, ttl=100)
  163. @mock.patch("app.models.cache.SimpleMemoryCache.expire")
  164. async def test_revalidate(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None:
  165. model_cache = ModelCache(ttl=100, revalidate=True)
  166. await model_cache.get("test_model_name", ModelType.IMAGE_CLASSIFICATION)
  167. await model_cache.get("test_model_name", ModelType.IMAGE_CLASSIFICATION)
  168. mock_cache_expire.assert_called_once_with(mock.ANY, 100)
  169. @pytest.mark.skipif(
  170. not settings.test_full,
  171. reason="More time-consuming since it deploys the app and loads models.",
  172. )
  173. class TestEndpoints:
  174. def test_tagging_endpoint(self, pil_image: Image.Image, deployed_app: TestClient) -> None:
  175. byte_image = BytesIO()
  176. pil_image.save(byte_image, format="jpeg")
  177. headers = {"Content-Type": "image/jpg"}
  178. response = deployed_app.post(
  179. "http://localhost:3003/image-classifier/tag-image",
  180. content=byte_image.getvalue(),
  181. headers=headers,
  182. )
  183. assert response.status_code == 200
  184. def test_clip_image_endpoint(self, pil_image: Image.Image, deployed_app: TestClient) -> None:
  185. byte_image = BytesIO()
  186. pil_image.save(byte_image, format="jpeg")
  187. headers = {"Content-Type": "image/jpg"}
  188. response = deployed_app.post(
  189. "http://localhost:3003/sentence-transformer/encode-image",
  190. content=byte_image.getvalue(),
  191. headers=headers,
  192. )
  193. assert response.status_code == 200
  194. def test_clip_text_endpoint(self, deployed_app: TestClient) -> None:
  195. response = deployed_app.post(
  196. "http://localhost:3003/sentence-transformer/encode-text",
  197. json={"text": "test search query"},
  198. )
  199. assert response.status_code == 200
  200. def test_face_endpoint(self, pil_image: Image.Image, deployed_app: TestClient) -> None:
  201. byte_image = BytesIO()
  202. pil_image.save(byte_image, format="jpeg")
  203. headers = {"Content-Type": "image/jpg"}
  204. response = deployed_app.post(
  205. "http://localhost:3003/facial-recognition/detect-faces",
  206. content=byte_image.getvalue(),
  207. headers=headers,
  208. )
  209. assert response.status_code == 200