config.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import logging
  2. import os
  3. from pathlib import Path
  4. import starlette
  5. from pydantic import BaseSettings
  6. from rich.console import Console
  7. from rich.logging import RichHandler
  8. from .schemas import ModelType
class Settings(BaseSettings):
    """Runtime configuration for the machine-learning service.

    Each field can be overridden by an environment variable named
    ``MACHINE_LEARNING_<FIELD>`` (names matched case-insensitively), as set by
    the ``Config`` options below.
    """

    # Root directory for cached models; get_cache_dir() builds paths beneath it.
    cache_folder: str = "/cache"
    # NOTE(review): presumably loads models at startup instead of on first request — confirm.
    eager_startup: bool = False
    # NOTE(review): presumably an idle time-to-live for loaded models; what 0 means is not visible here — confirm.
    model_ttl: int = 0
    # NOTE(review): presumably the server bind address, port and worker count — confirm against the entry point.
    host: str = "0.0.0.0"
    port: int = 3003
    workers: int = 1
    # NOTE(review): semantics not visible in this file — confirm.
    test_full: bool = False
    # CPU count, or 4 when os.cpu_count() returns None; evaluated once, when the class body runs at import.
    request_threads: int = os.cpu_count() or 4
    # NOTE(review): presumably forwarded to the inference runtime as inter-/intra-op thread counts — confirm.
    model_inter_op_threads: int = 1
    model_intra_op_threads: int = 2

    class Config:
        # Environment variable names are MACHINE_LEARNING_ + field name.
        env_prefix = "MACHINE_LEARNING_"
        # Environment variable names are matched without regard to case.
        case_sensitive = False
class LogSettings(BaseSettings):
    """Logging options read from un-prefixed environment variables.

    No ``env_prefix`` is configured, so the fields correspond to ``LOG_LEVEL``
    and ``NO_COLOR`` (names matched case-insensitively).
    """

    # Level name; it is lowercased and looked up in LOG_LEVELS, and unknown names fall back to INFO.
    log_level: str = "info"
    # Passed as no_color to the rich Console that renders log output.
    no_color: bool = False

    class Config:
        # Environment variable names are matched without regard to case.
        case_sensitive = False
  28. _clean_name = str.maketrans(":\\/", "___", ".")
  29. def get_cache_dir(model_name: str, model_type: ModelType) -> Path:
  30. return Path(settings.cache_folder) / model_type.value / model_name.translate(_clean_name)
  31. LOG_LEVELS: dict[str, int] = {
  32. "critical": logging.ERROR,
  33. "error": logging.ERROR,
  34. "warning": logging.WARNING,
  35. "warn": logging.WARNING,
  36. "info": logging.INFO,
  37. "log": logging.INFO,
  38. "debug": logging.DEBUG,
  39. "verbose": logging.DEBUG,
  40. }
# Settings are read from the environment once, at import time. get_cache_dir()
# reads the module-global `settings`, so it must only be called after this line runs.
settings = Settings()
log_settings = LogSettings()

# Console shared by the log handler; colour can be disabled via log_settings.no_color.
console = Console(color_system="standard", no_color=log_settings.no_color)
# Attach a RichHandler to the root logger; starlette frames are suppressed in rich tracebacks.
# NOTE: logging.basicConfig does nothing if the root logger already has handlers, and it
# sets no level here, so the root logger keeps its default level.
logging.basicConfig(
    format="%(message)s",
    handlers=[
        RichHandler(show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[starlette])
    ],
)
# The configured level applies to the "uvicorn" logger only; unrecognized names fall back to INFO.
log = logging.getLogger("uvicorn")
log.setLevel(LOG_LEVELS.get(log_settings.log_level.lower(), logging.INFO))