facial_recognition.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import zipfile
  2. from pathlib import Path
  3. from typing import Any
  4. import cv2
  5. import numpy as np
  6. import onnxruntime as ort
  7. from insightface.model_zoo import ArcFaceONNX, RetinaFace
  8. from insightface.utils.face_align import norm_crop
  9. from insightface.utils.storage import BASE_REPO_URL, download_file
  10. from ..schemas import ModelType
  11. from .base import InferenceModel
  12. class FaceRecognizer(InferenceModel):
  13. _model_type = ModelType.FACIAL_RECOGNITION
  14. def __init__(
  15. self,
  16. model_name: str,
  17. min_score: float = 0.7,
  18. cache_dir: Path | str | None = None,
  19. **model_kwargs: Any,
  20. ) -> None:
  21. self.min_score = model_kwargs.pop("minScore", min_score)
  22. super().__init__(model_name, cache_dir, **model_kwargs)
  23. def _download(self) -> None:
  24. zip_file = self.cache_dir / f"{self.model_name}.zip"
  25. download_file(f"{BASE_REPO_URL}/{self.model_name}.zip", zip_file)
  26. with zipfile.ZipFile(zip_file, "r") as zip:
  27. members = zip.namelist()
  28. det_file = next(model for model in members if model.startswith("det_"))
  29. rec_file = next(model for model in members if model.startswith("w600k_"))
  30. zip.extractall(self.cache_dir, members=[det_file, rec_file])
  31. zip_file.unlink()
  32. def _load(self) -> None:
  33. try:
  34. det_file = next(self.cache_dir.glob("det_*.onnx"))
  35. rec_file = next(self.cache_dir.glob("w600k_*.onnx"))
  36. except StopIteration:
  37. raise FileNotFoundError("Facial recognition models not found in cache directory")
  38. self.det_model = RetinaFace(
  39. session=ort.InferenceSession(
  40. det_file.as_posix(),
  41. sess_options=self.sess_options,
  42. providers=self.providers,
  43. provider_options=self.provider_options,
  44. ),
  45. )
  46. self.rec_model = ArcFaceONNX(
  47. rec_file.as_posix(),
  48. session=ort.InferenceSession(
  49. rec_file.as_posix(),
  50. sess_options=self.sess_options,
  51. providers=self.providers,
  52. provider_options=self.provider_options,
  53. ),
  54. )
  55. self.det_model.prepare(
  56. ctx_id=0,
  57. det_thresh=self.min_score,
  58. input_size=(640, 640),
  59. )
  60. self.rec_model.prepare(ctx_id=0)
  61. def _predict(self, image: np.ndarray[int, np.dtype[Any]] | bytes) -> list[dict[str, Any]]:
  62. if isinstance(image, bytes):
  63. image = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
  64. bboxes, kpss = self.det_model.detect(image)
  65. if bboxes.size == 0:
  66. return []
  67. assert isinstance(image, np.ndarray) and isinstance(kpss, np.ndarray)
  68. scores = bboxes[:, 4].tolist()
  69. bboxes = bboxes[:, :4].round().tolist()
  70. results = []
  71. height, width, _ = image.shape
  72. for (x1, y1, x2, y2), score, kps in zip(bboxes, scores, kpss):
  73. cropped_img = norm_crop(image, kps)
  74. embedding = self.rec_model.get_feat(cropped_img)[0].tolist()
  75. results.append(
  76. {
  77. "imageWidth": width,
  78. "imageHeight": height,
  79. "boundingBox": {
  80. "x1": x1,
  81. "y1": y1,
  82. "x2": x2,
  83. "y2": y2,
  84. },
  85. "score": score,
  86. "embedding": embedding,
  87. }
  88. )
  89. return results
  90. @property
  91. def cached(self) -> bool:
  92. return self.cache_dir.is_dir() and any(self.cache_dir.glob("*.onnx"))
  93. def configure(self, **model_kwargs: Any) -> None:
  94. self.det_model.det_thresh = model_kwargs.pop("minScore", self.det_model.det_thresh)