
chore(ml): improved logging (#3918)

* fixed `minScore` not being set correctly

* apply to init

* don't send `enabled`

* fix eslint warning

* added logger

* added logging

* refinements

* enable access log for info level

* formatting

* merged strings

---------

Co-authored-by: Alex <alex.tran1502@gmail.com>
Mert, 1 year ago
commit 54b2779b79

+ 34 - 0
machine-learning/app/config.py

@@ -1,7 +1,11 @@
+import logging
 import os
 from pathlib import Path
 
+import starlette
 from pydantic import BaseSettings
+from rich.console import Console
+from rich.logging import RichHandler
 
 from .schemas import ModelType
 
@@ -23,6 +27,14 @@ class Settings(BaseSettings):
         case_sensitive = False
 
 
+class LogSettings(BaseSettings):
+    log_level: str = "info"
+    no_color: bool = False
+
+    class Config:
+        case_sensitive = False
+
+
 _clean_name = str.maketrans(":\\/", "___", ".")
 
 
@@ -30,4 +42,26 @@ def get_cache_dir(model_name: str, model_type: ModelType) -> Path:
     return Path(settings.cache_folder) / model_type.value / model_name.translate(_clean_name)
 
 
+LOG_LEVELS: dict[str, int] = {
+    "critical": logging.ERROR,
+    "error": logging.ERROR,
+    "warning": logging.WARNING,
+    "warn": logging.WARNING,
+    "info": logging.INFO,
+    "log": logging.INFO,
+    "debug": logging.DEBUG,
+    "verbose": logging.DEBUG,
+}
+
 settings = Settings()
+log_settings = LogSettings()
+
+console = Console(color_system="standard", no_color=log_settings.no_color)
+logging.basicConfig(
+    format="%(message)s",
+    handlers=[
+        RichHandler(show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[starlette])
+    ],
+)
+log = logging.getLogger("uvicorn")
+log.setLevel(LOG_LEVELS.get(log_settings.log_level.lower(), logging.INFO))
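
A quick usage sketch (not part of the diff; the variable names are inferred from pydantic's `BaseSettings` convention of matching field names case-insensitively): `LOG_LEVEL` and `NO_COLOR` should populate `LogSettings`, and unrecognized level names fall back to INFO via `LOG_LEVELS.get`.

import os

# Illustrative only: the variables must be set before app.config is imported,
# since that module configures logging at import time.
os.environ["LOG_LEVEL"] = "verbose"  # alias mapped to logging.DEBUG in LOG_LEVELS
os.environ["NO_COLOR"] = "1"         # Rich console renders without color

from app.config import log

log.debug("visible because 'verbose' resolves to DEBUG")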

+ 11 - 2
machine-learning/app/main.py

@@ -1,4 +1,5 @@
 import asyncio
+import logging
 import os
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any
@@ -11,7 +12,7 @@ from starlette.formparsers import MultiPartParser
 
 from app.models.base import InferenceModel
 
-from .config import settings
+from .config import log, settings
 from .models.cache import ModelCache
 from .schemas import (
     MessageResponse,
@@ -20,14 +21,20 @@ from .schemas import (
 )
 
 MultiPartParser.max_file_size = 2**24  # spools to disk if payload is 16 MiB or larger
-
 app = FastAPI()
 
 
 def init_state() -> None:
     app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+    log.info(
+        "Created in-memory cache with unloading "
+        f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
+    )
     # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
     app.state.thread_pool = ThreadPoolExecutor(settings.request_threads)
+    log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
 
 
 @app.on_event("startup")
@@ -77,4 +84,6 @@ if __name__ == "__main__":
         port=settings.port,
         reload=is_dev,
         workers=settings.workers,
+        log_config=None,
+        access_log=log.isEnabledFor(logging.INFO),
     )
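
For context, a minimal standalone sketch of what the two new `uvicorn.run` arguments do (the app path below is a placeholder): `log_config=None` keeps uvicorn from installing its own logging configuration, so the `RichHandler` set up in `config.py` stays in effect, and `access_log=log.isEnabledFor(logging.INFO)` emits per-request access lines only when the effective level is INFO or lower.

import logging

import uvicorn

log = logging.getLogger("uvicorn")
log.setLevel(logging.WARNING)  # e.g. LOG_LEVEL=warning

uvicorn.run(
    "app.main:app",  # placeholder module path
    log_config=None,  # keep existing handlers instead of uvicorn's defaults
    access_log=log.isEnabledFor(logging.INFO),  # False at WARNING, so access lines are suppressed
)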

+ 32 - 4
machine-learning/app/models/base.py

@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import os
 import pickle
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -11,7 +10,7 @@ from zipfile import BadZipFile
 import onnxruntime as ort
 from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf  # type: ignore
 
-from ..config import get_cache_dir, settings
+from ..config import get_cache_dir, log, settings
 from ..schemas import ModelType
 
 
@@ -37,22 +36,41 @@ class InferenceModel(ABC):
         self.provider_options = model_kwargs.pop(
             "provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers)
         )
+        log.debug(
+            (
+                f"Setting '{self.model_name}' execution providers to {self.providers}"
+                "in descending order of preference"
+            ),
+        )
+        log.debug(f"Setting execution provider options to {self.provider_options}")
         self.sess_options = PicklableSessionOptions()
         # avoid thread contention between models
         if inter_op_num_threads > 1:
             self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
+
+        log.debug(f"Setting execution_mode to {self.sess_options.execution_mode.name}")
+        log.debug(f"Setting inter_op_num_threads to {inter_op_num_threads}")
+        log.debug(f"Setting intra_op_num_threads to {intra_op_num_threads}")
         self.sess_options.inter_op_num_threads = inter_op_num_threads
         self.sess_options.intra_op_num_threads = intra_op_num_threads
 
         try:
             loader(**model_kwargs)
         except (OSError, InvalidProtobuf, BadZipFile):
+            log.warn(
+                (
+                    f"Failed to load {self.model_type.replace('_', ' ')} model '{self.model_name}'."
+                    "Clearing cache and retrying."
+                )
+            )
             self.clear_cache()
             loader(**model_kwargs)
 
     def download(self, **model_kwargs: Any) -> None:
         if not self.cached:
-            print(f"Downloading {self.model_type.value.replace('_', ' ')} model. This may take a while...")
+            log.info(
+                (f"Downloading {self.model_type.replace('_', ' ')} model '{self.model_name}'." "This may take a while.")
+            )
             self._download(**model_kwargs)
 
     def load(self, **model_kwargs: Any) -> None:
@@ -62,7 +80,7 @@ class InferenceModel(ABC):
 
     def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
         if not self._loaded:
-            print(f"Loading {self.model_type.value.replace('_', ' ')} model...")
+            log.info(f"Loading {self.model_type.replace('_', ' ')} model '{self.model_name}'")
             self.load()
         if model_kwargs:
             self.configure(**model_kwargs)
@@ -109,13 +127,23 @@ class InferenceModel(ABC):
 
     def clear_cache(self) -> None:
         if not self.cache_dir.exists():
+            log.warn(
+                f"Attempted to clear cache for model '{self.model_name}' but cache directory does not exist.",
+            )
             return
         if not rmtree.avoids_symlink_attacks:
             raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform.")
 
         if self.cache_dir.is_dir():
+            log.info(f"Cleared cache directory for model '{self.model_name}'.")
             rmtree(self.cache_dir)
         else:
+            log.warn(
+                (
+                    f"Encountered file instead of directory at cache path "
+                    f"for '{self.model_name}'. Removing file and replacing with a directory."
+                ),
+            )
             self.cache_dir.unlink()
         self.cache_dir.mkdir(parents=True, exist_ok=True)
 

+ 6 - 3
machine-learning/app/models/clip.py

@@ -12,6 +12,7 @@ from clip_server.model.tokenization import Tokenizer
 from PIL import Image
 from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
 
+from ..config import log
 from ..schemas import ModelType
 from .base import InferenceModel
 
@@ -105,9 +106,11 @@ class CLIPEncoder(InferenceModel):
         if model_name in _MODELS:
             return model_name
         elif model_name in _ST_TO_JINA_MODEL_NAME:
-            print(
-                (f"Warning: Sentence-Transformer model names such as '{model_name}' are no longer supported."),
-                (f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."),
+            log.warn(
+                (
+                    f"Sentence-Transformer models like '{model_name}' are not supported."
+                    f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."
+                ),
             )
             return _ST_TO_JINA_MODEL_NAME[model_name]
         else:

+ 9 - 2
machine-learning/app/models/image_classification.py

@@ -8,6 +8,7 @@ from optimum.pipelines import pipeline
 from PIL import Image
 from transformers import AutoImageProcessor
 
+from ..config import log
 from ..schemas import ModelType
 from .base import InferenceModel
 
@@ -35,19 +36,25 @@ class ImageClassifier(InferenceModel):
         )
 
     def _load(self, **model_kwargs: Any) -> None:
-        processor = AutoImageProcessor.from_pretrained(self.cache_dir)
+        processor = AutoImageProcessor.from_pretrained(self.cache_dir, cache_dir=self.cache_dir)
+        model_path = self.cache_dir / "model.onnx"
         model_kwargs |= {
             "cache_dir": self.cache_dir,
             "provider": self.providers[0],
             "provider_options": self.provider_options[0],
             "session_options": self.sess_options,
         }
-        model_path = self.cache_dir / "model.onnx"
 
         if model_path.exists():
             model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
             self.model = pipeline(self.model_type.value, model, feature_extractor=processor)
         else:
+            log.info(
+                (
+                    f"ONNX model not found in cache directory for '{self.model_name}'."
+                    "Exporting optimized model for future use."
+                ),
+            )
             self.sess_options.optimized_model_filepath = model_path.as_posix()
             self.model = pipeline(
                 self.model_type.value,