# __init__.py (1004 B)
  1. from typing import Any
  2. from app.schemas import ModelType
  3. from .base import InferenceModel
  4. from .clip import MCLIPEncoder, OpenCLIPEncoder, is_mclip, is_openclip
  5. from .facial_recognition import FaceRecognizer
  6. from .image_classification import ImageClassifier
  7. def from_model_type(model_type: ModelType, model_name: str, **model_kwargs: Any) -> InferenceModel:
  8. match model_type:
  9. case ModelType.CLIP:
  10. if is_openclip(model_name):
  11. return OpenCLIPEncoder(model_name, **model_kwargs)
  12. elif is_mclip(model_name):
  13. return MCLIPEncoder(model_name, **model_kwargs)
  14. else:
  15. raise ValueError(f"Unknown CLIP model {model_name}")
  16. case ModelType.FACIAL_RECOGNITION:
  17. return FaceRecognizer(model_name, **model_kwargs)
  18. case ModelType.IMAGE_CLASSIFICATION:
  19. return ImageClassifier(model_name, **model_kwargs)
  20. case _:
  21. raise ValueError(f"Unknown model type {model_type}")