config.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import logging
  2. import os
  3. from pathlib import Path
  4. import gunicorn
  5. import starlette
  6. from pydantic import BaseSettings
  7. from rich.console import Console
  8. from rich.logging import RichHandler
  9. from .schemas import ModelType
  class Settings(BaseSettings):
      """Service configuration, overridable through the environment.

      Every field can be set with an environment variable named
      ``MACHINE_LEARNING_<FIELD>``; because ``case_sensitive`` is off the
      name is matched case-insensitively (e.g. ``MACHINE_LEARNING_PORT``).
      """

      # Root directory for per-model cache directories (used by get_cache_dir).
      cache_folder: str = "/cache"
      # Model lifetime and its polling interval; the "_s" suffix suggests
      # seconds — TODO confirm the unit of model_ttl with its consumer.
      model_ttl: int = 300
      model_ttl_poll_s: int = 10
      # Presumably the server bind address / port / process count.
      host: str = "0.0.0.0"
      port: int = 3003
      workers: int = 1
      test_full: bool = False
      # Defaults to the CPU count; os.cpu_count() may return None, hence the 4.
      request_threads: int = os.cpu_count() or 4
      # Thread counts presumably forwarded to the model runtime — TODO confirm.
      model_inter_op_threads: int = 1
      model_intra_op_threads: int = 2

      class Config:
          env_prefix = "MACHINE_LEARNING_"
          case_sensitive = False
  class LogSettings(BaseSettings):
      """Logging options, read from unprefixed environment variables.

      ``log_level`` is lower-cased and looked up in ``LOG_LEVELS`` when the
      module logger is configured. ``no_color`` is passed to the rich
      ``Console`` used by ``CustomRichHandler``.
      """

      log_level: str = "info"
      no_color: bool = False

      class Config:
          case_sensitive = False
  # Translation table used by clean_name: ':', '\\' and '/' each become '_'
  # and '.' is deleted.
  _clean_name = str.maketrans(":\\/", "___", ".")


  def clean_name(model_name: str) -> str:
      """Normalise a model name into a filesystem/repo-safe identifier.

      Only the component after the last "/" is kept (so "org/model" becomes
      "model"). In what remains, ":" and backslashes are replaced with "_"
      and every "." is removed.
      """
      base_name = model_name.rsplit("/", 1)[-1]
      return base_name.translate(_clean_name)
  def get_cache_dir(model_name: str, model_type: ModelType) -> Path:
      """Build the cache directory path for a model.

      The layout is ``<settings.cache_folder>/<model_type.value>/<clean_name(model_name)>``.
      Only the path is computed; nothing is created on disk.
      """
      cache_root = Path(settings.cache_folder)
      return cache_root.joinpath(model_type.value, clean_name(model_name))
  def get_hf_model_name(model_name: str) -> str:
      """Return the "immich-app/"-namespaced repo id for *model_name*.

      Presumably a Hugging Face Hub repository id. The name is normalised
      with clean_name first, e.g. "openai/ViT-B-32" -> "immich-app/ViT-B-32".
      """
      return "immich-app/" + clean_name(model_name)
  # Lower-case level names accepted for LOG_LEVEL, mapped to stdlib logging
  # levels. "warn", "log" and "verbose" are aliases for WARNING, INFO and DEBUG.
  # Each canonical name maps to its same-named logging level; "critical"
  # previously mapped to ERROR, making it indistinguishable from "error".
  LOG_LEVELS: dict[str, int] = {
      "critical": logging.CRITICAL,
      "error": logging.ERROR,
      "warning": logging.WARNING,
      "warn": logging.WARNING,
      "info": logging.INFO,
      "log": logging.INFO,
      "debug": logging.DEBUG,
      "verbose": logging.DEBUG,
  }
  # Module-wide configuration objects, built once at import time; values come
  # from the environment according to each class's Config.
  settings, log_settings = Settings(), LogSettings()
  class CustomRichHandler(RichHandler):
      """RichHandler preconfigured for this service's console logging.

      The handler is set up as follows:

      - Writes through a Console using the "standard" colour system, with
        colour turned off when ``log_settings.no_color`` is true.
      - Hides the source path column.
      - Prints the timestamp on every line, even when it is repeated.
      - Suppresses gunicorn and starlette frames in rendered tracebacks.
      """

      def __init__(self) -> None:
          super().__init__(
              console=Console(color_system="standard", no_color=log_settings.no_color),
              show_path=False,
              omit_repeated_times=False,
              tracebacks_suppress=[gunicorn, starlette],
          )
  # Module logger (the "gunicorn.access" logger). Its level is taken from
  # log_settings.log_level via LOG_LEVELS; unrecognised names fall back to INFO.
  # Surrounding whitespace is stripped first so a value such as "debug " (e.g.
  # from an env file) is honoured instead of silently becoming INFO.
  log = logging.getLogger("gunicorn.access")
  log.setLevel(LOG_LEVELS.get(log_settings.log_level.strip().lower(), logging.INFO))