Browse Source

chore(ml): use strict mypy (#5001)

* improved typing

* improved export typing

* strict mypy & check export folder

* formatting

* add formatting checks for export folder

* re-added init call
Mert 1 year ago
parent
commit
935f471ccb

+ 3 - 3
.github/workflows/test.yml

@@ -168,13 +168,13 @@ jobs:
           poetry install --with dev
       - name: Lint with ruff
         run: |
-          poetry run ruff check --format=github app
+          poetry run ruff check --format=github app export
       - name: Check black formatting
         run: |
-          poetry run black --check app
+          poetry run black --check app export
       - name: Run mypy type checking
         run: |
-          poetry run mypy --install-types --non-interactive app/
+          poetry run mypy --install-types --non-interactive --strict app/ export/
       - name: Run tests and coverage
         run: |
           poetry run pytest --cov app

+ 2 - 1
machine-learning/app/conftest.py

@@ -36,7 +36,8 @@ def deployed_app() -> TestClient:
 
 @pytest.fixture(scope="session")
 def responses() -> dict[str, Any]:
-    return json.load(open("responses.json", "r"))
+    responses: dict[str, Any] = json.load(open("responses.json", "r"))
+    return responses
 
 
 @pytest.fixture(scope="session")

+ 1 - 1
machine-learning/app/main.py

@@ -7,7 +7,7 @@ from zipfile import BadZipFile
 import orjson
 from fastapi import FastAPI, Form, HTTPException, UploadFile
 from fastapi.responses import ORJSONResponse
-from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile  # type: ignore
+from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
 from starlette.formparsers import MultiPartParser
 
 from app.models.base import InferenceModel

+ 6 - 4
machine-learning/app/models/base.py

@@ -8,6 +8,7 @@ from typing import Any
 
 import onnxruntime as ort
 from huggingface_hub import snapshot_download
+from typing_extensions import Buffer
 
 from ..config import get_cache_dir, get_hf_model_name, log, settings
 from ..schemas import ModelType
@@ -139,11 +140,12 @@ class InferenceModel(ABC):
 
 
 # HF deep copies configs, so we need to make session options picklable
-class PicklableSessionOptions(ort.SessionOptions):
+class PicklableSessionOptions(ort.SessionOptions):  # type: ignore[misc]
     def __getstate__(self) -> bytes:
         return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
 
-    def __setstate__(self, state: Any) -> None:
-        self.__init__()  # type: ignore
-        for attr, val in pickle.loads(state):
+    def __setstate__(self, state: Buffer) -> None:
+        self.__init__()  # type: ignore[misc]
+        attrs: list[tuple[str, Any]] = pickle.loads(state)
+        for attr, val in attrs:
             setattr(self, attr, val)

+ 5 - 5
machine-learning/app/models/cache.py

@@ -6,7 +6,7 @@ from aiocache.plugins import BasePlugin, TimingPlugin
 
 from app.models import from_model_type
 
-from ..schemas import ModelType
+from ..schemas import ModelType, has_profiling
 from .base import InferenceModel
 
 
@@ -50,20 +50,20 @@ class ModelCache:
 
         key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
         async with OptimisticLock(self.cache, key) as lock:
-            model = await self.cache.get(key)
+            model: InferenceModel | None = await self.cache.get(key)
             if model is None:
                 model = from_model_type(model_type, model_name, **model_kwargs)
                 await lock.cas(model, ttl=self.ttl)
         return model
 
     async def get_profiling(self) -> dict[str, float] | None:
-        if not hasattr(self.cache, "profiling"):
+        if not has_profiling(self.cache):
             return None
 
-        return self.cache.profiling  # type: ignore
+        return self.cache.profiling
 
 
-class RevalidationPlugin(BasePlugin):
+class RevalidationPlugin(BasePlugin):  # type: ignore[misc]
     """Revalidates cache item's TTL after cache hit."""
 
     async def post_get(

+ 8 - 6
machine-learning/app/models/clip.py

@@ -51,7 +51,7 @@ class BaseCLIPEncoder(InferenceModel):
                 provider_options=self.provider_options,
             )
 
-    def _predict(self, image_or_text: Image.Image | str) -> list[float]:
+    def _predict(self, image_or_text: Image.Image | str) -> ndarray_f32:
         if isinstance(image_or_text, bytes):
             image_or_text = Image.open(BytesIO(image_or_text))
 
@@ -60,16 +60,16 @@ class BaseCLIPEncoder(InferenceModel):
                 if self.mode == "text":
                     raise TypeError("Cannot encode image as text-only model")
 
-                outputs = self.vision_model.run(None, self.transform(image_or_text))
+                outputs: ndarray_f32 = self.vision_model.run(None, self.transform(image_or_text))[0][0]
             case str():
                 if self.mode == "vision":
                     raise TypeError("Cannot encode text as vision-only model")
 
-                outputs = self.text_model.run(None, self.tokenize(image_or_text))
+                outputs = self.text_model.run(None, self.tokenize(image_or_text))[0][0]
             case _:
                 raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
 
-        return outputs[0][0].tolist()
+        return outputs
 
     @abstractmethod
     def tokenize(self, text: str) -> dict[str, ndarray_i32]:
@@ -151,11 +151,13 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
 
     @cached_property
     def model_cfg(self) -> dict[str, Any]:
-        return json.load(self.model_cfg_path.open())
+        model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
+        return model_cfg
 
     @cached_property
     def preprocess_cfg(self) -> dict[str, Any]:
-        return json.load(self.preprocess_cfg_path.open())
+        preprocess_cfg: dict[str, Any] = json.load(self.preprocess_cfg_path.open())
+        return preprocess_cfg
 
 
 class MCLIPEncoder(OpenCLIPEncoder):

+ 16 - 17
machine-learning/app/models/facial_recognition.py

@@ -8,7 +8,7 @@ from insightface.model_zoo import ArcFaceONNX, RetinaFace
 from insightface.utils.face_align import norm_crop
 
 from app.config import clean_name
-from app.schemas import ModelType, ndarray_f32
+from app.schemas import BoundingBox, Face, ModelType, ndarray_f32
 
 from .base import InferenceModel
 
@@ -52,7 +52,7 @@ class FaceRecognizer(InferenceModel):
         )
         self.rec_model.prepare(ctx_id=0)
 
-    def _predict(self, image: ndarray_f32 | bytes) -> list[dict[str, Any]]:
+    def _predict(self, image: ndarray_f32 | bytes) -> list[Face]:
         if isinstance(image, bytes):
             image = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
         bboxes, kpss = self.det_model.detect(image)
@@ -67,21 +67,20 @@ class FaceRecognizer(InferenceModel):
         height, width, _ = image.shape
         for (x1, y1, x2, y2), score, kps in zip(bboxes, scores, kpss):
             cropped_img = norm_crop(image, kps)
-            embedding = self.rec_model.get_feat(cropped_img)[0].tolist()
-            results.append(
-                {
-                    "imageWidth": width,
-                    "imageHeight": height,
-                    "boundingBox": {
-                        "x1": x1,
-                        "y1": y1,
-                        "x2": x2,
-                        "y2": y2,
-                    },
-                    "score": score,
-                    "embedding": embedding,
-                }
-            )
+            embedding: ndarray_f32 = self.rec_model.get_feat(cropped_img)[0]
+            face: Face = {
+                "imageWidth": width,
+                "imageHeight": height,
+                "boundingBox": {
+                    "x1": x1,
+                    "y1": y1,
+                    "x2": x2,
+                    "y2": y2,
+                },
+                "score": score,
+                "embedding": embedding,
+            }
+            results.append(face)
         return results
 
     @property

+ 1 - 1
machine-learning/app/models/image_classification.py

@@ -66,7 +66,7 @@ class ImageClassifier(InferenceModel):
     def _predict(self, image: Image.Image | bytes) -> list[str]:
         if isinstance(image, bytes):
             image = Image.open(BytesIO(image))
-        predictions: list[dict[str, Any]] = self.model(image)  # type: ignore
+        predictions: list[dict[str, Any]] = self.model(image)
         tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
 
         return tags

+ 19 - 13
machine-learning/app/schemas.py

@@ -1,17 +1,12 @@
 from enum import StrEnum
-from typing import TypeAlias
+from typing import Any, Protocol, TypeAlias, TypedDict, TypeGuard
 
 import numpy as np
 from pydantic import BaseModel
 
-
-def to_lower_camel(string: str) -> str:
-    tokens = [token.capitalize() if i > 0 else token for i, token in enumerate(string.split("_"))]
-    return "".join(tokens)
-
-
-class TextModelRequest(BaseModel):
-    text: str
+ndarray_f32: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
+ndarray_i64: TypeAlias = np.ndarray[int, np.dtype[np.int64]]
+ndarray_i32: TypeAlias = np.ndarray[int, np.dtype[np.int32]]
 
 
 class TextResponse(BaseModel):
@@ -22,7 +17,7 @@ class MessageResponse(BaseModel):
     message: str
 
 
-class BoundingBox(BaseModel):
+class BoundingBox(TypedDict):
     x1: int
     y1: int
     x2: int
@@ -35,6 +30,17 @@ class ModelType(StrEnum):
     FACIAL_RECOGNITION = "facial-recognition"
 
 
-ndarray_f32: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
-ndarray_i64: TypeAlias = np.ndarray[int, np.dtype[np.int64]]
-ndarray_i32: TypeAlias = np.ndarray[int, np.dtype[np.int32]]
+class HasProfiling(Protocol):
+    profiling: dict[str, float]
+
+
+class Face(TypedDict):
+    boundingBox: BoundingBox
+    embedding: ndarray_f32
+    imageWidth: int
+    imageHeight: int
+    score: float
+
+
+def has_profiling(obj: Any) -> TypeGuard[HasProfiling]:
+    return hasattr(obj, "profiling") and type(obj.profiling) == dict

+ 9 - 4
machine-learning/export/models/openclip.py

@@ -1,6 +1,7 @@
 import tempfile
 import warnings
 from dataclasses import dataclass, field
+from math import e
 from pathlib import Path
 
 import open_clip
@@ -69,10 +70,12 @@ def export_image_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig,
     output_path = Path(output_path)
 
     def encode_image(image: torch.Tensor) -> torch.Tensor:
-        return model.encode_image(image, normalize=True)
+        output = model.encode_image(image, normalize=True)
+        assert isinstance(output, torch.Tensor)
+        return output
 
     args = (torch.randn(1, 3, model_cfg.image_size, model_cfg.image_size),)
-    traced = torch.jit.trace(encode_image, args)
+    traced = torch.jit.trace(encode_image, args)  # type: ignore[no-untyped-call]
 
     with warnings.catch_warnings():
         warnings.simplefilter("ignore", UserWarning)
@@ -91,10 +94,12 @@ def export_text_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, o
     output_path = Path(output_path)
 
     def encode_text(text: torch.Tensor) -> torch.Tensor:
-        return model.encode_text(text, normalize=True)
+        output = model.encode_text(text, normalize=True)
+        assert isinstance(output, torch.Tensor)
+        return output
 
     args = (torch.ones(1, model_cfg.sequence_length, dtype=torch.int32),)
-    traced = torch.jit.trace(encode_text, args)
+    traced = torch.jit.trace(encode_text, args)  # type: ignore[no-untyped-call]
 
     with warnings.catch_warnings():
         warnings.simplefilter("ignore", UserWarning)