Browse Source

fixed tests (#5017)

Mert 1 year ago
parent
commit
291159e7fc
1 changed files with 9 additions and 9 deletions
  1. 9 9
      machine-learning/app/test_main.py

+ 9 - 9
machine-learning/app/test_main.py

@@ -75,9 +75,9 @@ class TestCLIP:
         embedding = clip_encoder.predict(pil_image)
 
         assert clip_encoder.mode == "vision"
-        assert isinstance(embedding, list)
-        assert len(embedding) == clip_model_cfg["embed_dim"]
-        assert all([isinstance(num, float) for num in embedding])
+        assert isinstance(embedding, np.ndarray)
+        assert embedding.shape[0] == clip_model_cfg["embed_dim"]
+        assert embedding.dtype == np.float32
         clip_encoder.vision_model.run.assert_called_once()
 
     def test_basic_text(
@@ -97,9 +97,9 @@ class TestCLIP:
         embedding = clip_encoder.predict("test search query")
 
         assert clip_encoder.mode == "text"
-        assert isinstance(embedding, list)
-        assert len(embedding) == clip_model_cfg["embed_dim"]
-        assert all([isinstance(num, float) for num in embedding])
+        assert isinstance(embedding, np.ndarray)
+        assert embedding.shape[0] == clip_model_cfg["embed_dim"]
+        assert embedding.dtype == np.float32
         clip_encoder.text_model.run.assert_called_once()
 
 
@@ -133,9 +133,9 @@ class TestFaceRecognition:
         for face in faces:
             assert face["imageHeight"] == 800
             assert face["imageWidth"] == 600
-            assert isinstance(face["embedding"], list)
-            assert len(face["embedding"]) == 512
-            assert all([isinstance(num, float) for num in face["embedding"]])
+            assert isinstance(face["embedding"], np.ndarray)
+            assert face["embedding"].shape[0] == 512
+            assert face["embedding"].dtype == np.float32
 
         det_model.detect.assert_called_once()
         assert rec_model.get_feat.call_count == num_faces