facial_recognition.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import zipfile
  2. from pathlib import Path
  3. from typing import Any
  4. import cv2
  5. import numpy as np
  6. from insightface.model_zoo import ArcFaceONNX, RetinaFace
  7. from insightface.utils.face_align import norm_crop
  8. from insightface.utils.storage import BASE_REPO_URL, download_file
  9. from ..config import settings
  10. from ..schemas import ModelType
  11. from .base import InferenceModel
  12. class FaceRecognizer(InferenceModel):
  13. _model_type = ModelType.FACIAL_RECOGNITION
  14. def __init__(
  15. self,
  16. model_name: str,
  17. min_score: float = settings.min_face_score,
  18. cache_dir: Path | str | None = None,
  19. **model_kwargs: Any,
  20. ) -> None:
  21. self.min_score = min_score
  22. super().__init__(model_name, cache_dir, **model_kwargs)
  23. def _download(self, **model_kwargs: Any) -> None:
  24. zip_file = self.cache_dir / f"{self.model_name}.zip"
  25. download_file(f"{BASE_REPO_URL}/{self.model_name}.zip", zip_file)
  26. with zipfile.ZipFile(zip_file, "r") as zip:
  27. members = zip.namelist()
  28. det_file = next(model for model in members if model.startswith("det_"))
  29. rec_file = next(model for model in members if model.startswith("w600k_"))
  30. zip.extractall(self.cache_dir, members=[det_file, rec_file])
  31. zip_file.unlink()
  32. def _load(self, **model_kwargs: Any) -> None:
  33. try:
  34. det_file = next(self.cache_dir.glob("det_*.onnx"))
  35. rec_file = next(self.cache_dir.glob("w600k_*.onnx"))
  36. except StopIteration:
  37. raise FileNotFoundError("Facial recognition models not found in cache directory")
  38. self.det_model = RetinaFace(det_file.as_posix())
  39. self.rec_model = ArcFaceONNX(rec_file.as_posix())
  40. self.det_model.prepare(
  41. ctx_id=-1,
  42. det_thresh=self.min_score,
  43. input_size=(640, 640),
  44. )
  45. self.rec_model.prepare(ctx_id=-1)
  46. def _predict(self, image: cv2.Mat) -> list[dict[str, Any]]:
  47. bboxes, kpss = self.det_model.detect(image)
  48. if bboxes.size == 0:
  49. return []
  50. assert isinstance(kpss, np.ndarray)
  51. scores = bboxes[:, 4].tolist()
  52. bboxes = bboxes[:, :4].round().tolist()
  53. results = []
  54. height, width, _ = image.shape
  55. for (x1, y1, x2, y2), score, kps in zip(bboxes, scores, kpss):
  56. cropped_img = norm_crop(image, kps)
  57. embedding = self.rec_model.get_feat(cropped_img)[0].tolist()
  58. results.append(
  59. {
  60. "imageWidth": width,
  61. "imageHeight": height,
  62. "boundingBox": {
  63. "x1": x1,
  64. "y1": y1,
  65. "x2": x2,
  66. "y2": y2,
  67. },
  68. "score": score,
  69. "embedding": embedding,
  70. }
  71. )
  72. return results
  73. @property
  74. def cached(self) -> bool:
  75. return self.cache_dir.is_dir() and any(self.cache_dir.glob("*.onnx"))