# test_main.py 7.2 KB
  1. from io import BytesIO
  2. from pathlib import Path
  3. from unittest import mock
  4. import cv2
  5. import pytest
  6. from fastapi.testclient import TestClient
  7. from PIL import Image
  8. from .config import settings
  9. from .models.cache import ModelCache
  10. from .models.clip import CLIPSTEncoder
  11. from .models.facial_recognition import FaceRecognizer
  12. from .models.image_classification import ImageClassifier
  13. from .schemas import ModelType
  14. class TestImageClassifier:
  15. def test_init(self, mock_classifier_pipeline: mock.Mock) -> None:
  16. cache_dir = Path("test_cache")
  17. classifier = ImageClassifier("test_model_name", 0.5, cache_dir=cache_dir)
  18. assert classifier.min_score == 0.5
  19. mock_classifier_pipeline.assert_called_once_with(
  20. "image-classification",
  21. "test_model_name",
  22. model_kwargs={"cache_dir": cache_dir},
  23. )
  24. def test_min_score(self, pil_image: Image.Image, mock_classifier_pipeline: mock.Mock) -> None:
  25. classifier = ImageClassifier("test_model_name", min_score=0.0)
  26. classifier.min_score = 0.0
  27. all_labels = classifier.predict(pil_image)
  28. classifier.min_score = 0.5
  29. filtered_labels = classifier.predict(pil_image)
  30. assert all_labels == [
  31. "that's an image alright",
  32. "well it ends with .jpg",
  33. "idk",
  34. "im just seeing bytes",
  35. "not sure",
  36. "probably a virus",
  37. ]
  38. assert filtered_labels == ["that's an image alright"]
  39. class TestCLIP:
  40. def test_init(self, mock_st: mock.Mock) -> None:
  41. CLIPSTEncoder("test_model_name", cache_dir="test_cache")
  42. mock_st.assert_called_once_with("test_model_name", cache_folder="test_cache")
  43. def test_basic_image(self, pil_image: Image.Image, mock_st: mock.Mock) -> None:
  44. clip_encoder = CLIPSTEncoder("test_model_name", cache_dir="test_cache")
  45. embedding = clip_encoder.predict(pil_image)
  46. assert isinstance(embedding, list)
  47. assert len(embedding) == 512
  48. assert all([isinstance(num, float) for num in embedding])
  49. mock_st.assert_called_once()
  50. def test_basic_text(self, mock_st: mock.Mock) -> None:
  51. clip_encoder = CLIPSTEncoder("test_model_name", cache_dir="test_cache")
  52. embedding = clip_encoder.predict("test search query")
  53. assert isinstance(embedding, list)
  54. assert len(embedding) == 512
  55. assert all([isinstance(num, float) for num in embedding])
  56. mock_st.assert_called_once()
  57. class TestFaceRecognition:
  58. def test_init(self, mock_faceanalysis: mock.Mock) -> None:
  59. FaceRecognizer("test_model_name", cache_dir="test_cache")
  60. mock_faceanalysis.assert_called_once_with(
  61. name="test_model_name",
  62. root="test_cache",
  63. allowed_modules=["detection", "recognition"],
  64. )
  65. def test_basic(self, cv_image: cv2.Mat, mock_faceanalysis: mock.Mock) -> None:
  66. face_recognizer = FaceRecognizer("test_model_name", min_score=0.0, cache_dir="test_cache")
  67. faces = face_recognizer.predict(cv_image)
  68. assert len(faces) == 2
  69. for face in faces:
  70. assert face["imageHeight"] == 800
  71. assert face["imageWidth"] == 600
  72. assert isinstance(face["embedding"], list)
  73. assert len(face["embedding"]) == 512
  74. assert all([isinstance(num, float) for num in face["embedding"]])
  75. mock_faceanalysis.assert_called_once()
  76. @pytest.mark.asyncio
  77. class TestCache:
  78. async def test_caches(self, mock_get_model: mock.Mock) -> None:
  79. model_cache = ModelCache()
  80. await model_cache.get("test_model_name", ModelType.IMAGE_CLASSIFICATION)
  81. await model_cache.get("test_model_name", ModelType.IMAGE_CLASSIFICATION)
  82. assert len(model_cache.cache._cache) == 1
  83. mock_get_model.assert_called_once()
  84. async def test_kwargs_used(self, mock_get_model: mock.Mock) -> None:
  85. model_cache = ModelCache()
  86. await model_cache.get("test_model_name", ModelType.IMAGE_CLASSIFICATION, cache_dir="test_cache")
  87. mock_get_model.assert_called_once_with(
  88. ModelType.IMAGE_CLASSIFICATION, "test_model_name", cache_dir="test_cache"
  89. )
  90. async def test_different_clip(self, mock_get_model: mock.Mock) -> None:
  91. model_cache = ModelCache()
  92. await model_cache.get("test_image_model_name", ModelType.CLIP)
  93. await model_cache.get("test_text_model_name", ModelType.CLIP)
  94. mock_get_model.assert_has_calls(
  95. [
  96. mock.call(ModelType.CLIP, "test_image_model_name"),
  97. mock.call(ModelType.CLIP, "test_text_model_name"),
  98. ]
  99. )
  100. assert len(model_cache.cache._cache) == 2
  101. @mock.patch("app.models.cache.OptimisticLock", autospec=True)
  102. async def test_model_ttl(self, mock_lock_cls: mock.Mock, mock_get_model: mock.Mock) -> None:
  103. model_cache = ModelCache(ttl=100)
  104. await model_cache.get("test_model_name", ModelType.IMAGE_CLASSIFICATION)
  105. mock_lock_cls.return_value.__aenter__.return_value.cas.assert_called_with(mock.ANY, ttl=100)
  106. @mock.patch("app.models.cache.SimpleMemoryCache.expire")
  107. async def test_revalidate(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None:
  108. model_cache = ModelCache(ttl=100, revalidate=True)
  109. await model_cache.get("test_model_name", ModelType.IMAGE_CLASSIFICATION)
  110. await model_cache.get("test_model_name", ModelType.IMAGE_CLASSIFICATION)
  111. mock_cache_expire.assert_called_once_with(mock.ANY, 100)
  112. @pytest.mark.skipif(
  113. not settings.test_full,
  114. reason="More time-consuming since it deploys the app and loads models.",
  115. )
  116. class TestEndpoints:
  117. def test_tagging_endpoint(self, pil_image: Image.Image, deployed_app: TestClient) -> None:
  118. byte_image = BytesIO()
  119. pil_image.save(byte_image, format="jpeg")
  120. headers = {"Content-Type": "image/jpg"}
  121. response = deployed_app.post(
  122. "http://localhost:3003/image-classifier/tag-image",
  123. content=byte_image.getvalue(),
  124. headers=headers,
  125. )
  126. assert response.status_code == 200
  127. def test_clip_image_endpoint(self, pil_image: Image.Image, deployed_app: TestClient) -> None:
  128. byte_image = BytesIO()
  129. pil_image.save(byte_image, format="jpeg")
  130. headers = {"Content-Type": "image/jpg"}
  131. response = deployed_app.post(
  132. "http://localhost:3003/sentence-transformer/encode-image",
  133. content=byte_image.getvalue(),
  134. headers=headers,
  135. )
  136. assert response.status_code == 200
  137. def test_clip_text_endpoint(self, deployed_app: TestClient) -> None:
  138. response = deployed_app.post(
  139. "http://localhost:3003/sentence-transformer/encode-text",
  140. json={"text": "test search query"},
  141. )
  142. assert response.status_code == 200
  143. def test_face_endpoint(self, pil_image: Image.Image, deployed_app: TestClient) -> None:
  144. byte_image = BytesIO()
  145. pil_image.save(byte_image, format="jpeg")
  146. headers = {"Content-Type": "image/jpg"}
  147. response = deployed_app.post(
  148. "http://localhost:3003/facial-recognition/detect-faces",
  149. content=byte_image.getvalue(),
  150. headers=headers,
  151. )
  152. assert response.status_code == 200