config.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import logging
  2. import os
  3. from pathlib import Path
  4. from pydantic import BaseSettings
  5. from rich.console import Console
  6. from rich.logging import RichHandler
  7. from .schemas import ModelType
  8. class Settings(BaseSettings):
  9. cache_folder: str = "/cache"
  10. model_ttl: int = 0
  11. host: str = "0.0.0.0"
  12. port: int = 3003
  13. workers: int = 1
  14. test_full: bool = False
  15. request_threads: int = os.cpu_count() or 4
  16. model_inter_op_threads: int = 1
  17. model_intra_op_threads: int = 2
  18. class Config:
  19. env_prefix = "MACHINE_LEARNING_"
  20. case_sensitive = False
  21. class LogSettings(BaseSettings):
  22. log_level: str = "info"
  23. no_color: bool = False
  24. class Config:
  25. case_sensitive = False
  26. _clean_name = str.maketrans(":\\/", "___", ".")
  27. def get_cache_dir(model_name: str, model_type: ModelType) -> Path:
  28. return Path(settings.cache_folder) / model_type.value / model_name.translate(_clean_name)
  29. LOG_LEVELS: dict[str, int] = {
  30. "critical": logging.ERROR,
  31. "error": logging.ERROR,
  32. "warning": logging.WARNING,
  33. "warn": logging.WARNING,
  34. "info": logging.INFO,
  35. "log": logging.INFO,
  36. "debug": logging.DEBUG,
  37. "verbose": logging.DEBUG,
  38. }
  39. settings = Settings()
  40. log_settings = LogSettings()
  41. class CustomRichHandler(RichHandler):
  42. def __init__(self) -> None:
  43. console = Console(no_color=log_settings.no_color)
  44. super().__init__(
  45. show_path=False,
  46. omit_repeated_times=False,
  47. console=console,
  48. tracebacks_width=100,
  49. )
  50. log = logging.getLogger("gunicorn.access")
  51. log.setLevel(LOG_LEVELS.get(log_settings.log_level.lower(), logging.INFO))