config.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import logging
  2. import os
  3. from pathlib import Path
  4. import gunicorn
  5. import starlette
  6. from pydantic import BaseSettings
  7. from rich.console import Console
  8. from rich.logging import RichHandler
  9. from .schemas import ModelType
  10. class Settings(BaseSettings):
  11. cache_folder: str = "/cache"
  12. model_ttl: int = 0
  13. host: str = "0.0.0.0"
  14. port: int = 3003
  15. workers: int = 1
  16. test_full: bool = False
  17. request_threads: int = os.cpu_count() or 4
  18. model_inter_op_threads: int = 1
  19. model_intra_op_threads: int = 2
  20. class Config:
  21. env_prefix = "MACHINE_LEARNING_"
  22. case_sensitive = False
  23. class LogSettings(BaseSettings):
  24. log_level: str = "info"
  25. no_color: bool = False
  26. class Config:
  27. case_sensitive = False
  28. _clean_name = str.maketrans(":\\/", "___", ".")
  29. def get_cache_dir(model_name: str, model_type: ModelType) -> Path:
  30. return Path(settings.cache_folder) / model_type.value / model_name.translate(_clean_name)
  31. LOG_LEVELS: dict[str, int] = {
  32. "critical": logging.ERROR,
  33. "error": logging.ERROR,
  34. "warning": logging.WARNING,
  35. "warn": logging.WARNING,
  36. "info": logging.INFO,
  37. "log": logging.INFO,
  38. "debug": logging.DEBUG,
  39. "verbose": logging.DEBUG,
  40. }
  41. settings = Settings()
  42. log_settings = LogSettings()
  43. class CustomRichHandler(RichHandler):
  44. def __init__(self) -> None:
  45. console = Console(color_system="standard", no_color=log_settings.no_color)
  46. super().__init__(
  47. show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[gunicorn, starlette]
  48. )
  49. log = logging.getLogger("gunicorn.access")
  50. log.setLevel(LOG_LEVELS.get(log_settings.log_level.lower(), logging.INFO))