123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101 |
- import asyncio
- from typing import Any
- from aiocache.backends.memory import SimpleMemoryCache
- from aiocache.lock import OptimisticLock
- from aiocache.plugins import BasePlugin, TimingPlugin
- from ..schemas import ModelType
- from .base import InferenceModel
class ModelCache:
    """Fetches a model from an in-memory cache, instantiating it if it's missing."""

    def __init__(
        self,
        ttl: float | None = None,
        revalidate: bool = False,
        timeout: float | None = None,
        profiling: bool = False,
    ):
        """
        Args:
            ttl: Evicts a cached model after this many seconds. Disabled if None. Defaults to None.
            revalidate: Resets TTL on cache hit. Useful to keep models in memory while active.
                Defaults to False.
            timeout: Per-operation timeout (seconds) for the underlying cache's own calls
                (get/set/...). It does NOT bound model loading, which runs in an executor
                outside any cache call. Disabled if None. Defaults to None.
            profiling: Collects metrics for cache operations, adding slight overhead.
                Defaults to False.
        """
        self.ttl = ttl
        plugins: list[BasePlugin] = []
        if revalidate:
            plugins.append(RevalidationPlugin())
        if profiling:
            plugins.append(TimingPlugin())
        # namespace=None: keys are namespaced explicitly in `get` via build_key(name, type).
        self.cache = SimpleMemoryCache(ttl=ttl, timeout=timeout, plugins=plugins, namespace=None)

    async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel:
        """
        Return the cached model for ``(model_name, model_type)``, loading it on a miss.

        Args:
            model_name: Name of model in the model hub used for the task.
            model_type: Model type or task, which determines which model zoo is used.
            **model_kwargs: Forwarded to ``InferenceModel.from_model_type`` when the model
                has to be loaded. They are NOT part of the cache key, so a cache hit returns
                the instance built with the kwargs of the first load.

        Returns:
            model: The requested model.
        """
        key = self.cache.build_key(model_name, model_type.value)
        model: InferenceModel | None = await self.cache.get(key)
        if model is None:
            async with OptimisticLock(self.cache, key) as lock:
                # Construct the model in the loop's default executor so the
                # (synchronous) load does not run on the event loop itself.
                model = await asyncio.get_running_loop().run_in_executor(
                    None,
                    lambda: InferenceModel.from_model_type(model_type, model_name, **model_kwargs),
                )
                # NOTE(review): OptimisticLock does not make other coroutines wait, so
                # concurrent misses for the same key may each load the model; per the
                # aiocache docs, `cas` raises OptimisticLockError if the stored value
                # changed after the lock was taken — confirm this is acceptable.
                await lock.cas(model, ttl=self.ttl)
        return model

    async def get_profiling(self) -> dict[str, float] | None:
        """
        Return the cache's profiling metrics, or None if none are available.

        The ``profiling`` attribute is presumably populated by ``TimingPlugin`` once a
        cache operation has run; it is absent when profiling is disabled (and possibly
        before the first operation) — TODO confirm against the aiocache version in use.
        """
        if not hasattr(self.cache, "profiling"):
            return None
        return self.cache.profiling  # type: ignore
class RevalidationPlugin(BasePlugin):
    """Cache plugin that pushes back an entry's expiry every time a read returns it."""

    @staticmethod
    async def _refresh_ttl(client: SimpleMemoryCache, key: str, namespace: str | None) -> None:
        """Reset ``key``'s TTL to ``client.ttl`` if the client holds an expiry handler for it."""
        full_key = key if namespace is None else client.build_key(key, namespace)
        # Only keys with a registered expiry handler have a TTL that can be reset.
        if full_key in client._handlers:
            await client.expire(full_key, client.ttl)

    async def post_get(
        self,
        client: SimpleMemoryCache,
        key: str,
        ret: Any | None = None,
        namespace: str | None = None,
        **kwargs: Any,
    ) -> None:
        """After a single-key read, refresh the TTL when the read was a hit."""
        if ret is not None:
            await self._refresh_ttl(client, key, namespace)

    async def post_multi_get(
        self,
        client: SimpleMemoryCache,
        keys: list[str],
        ret: list[Any] | None = None,
        namespace: str | None = None,
        **kwargs: Any,
    ) -> None:
        """After a multi-key read, refresh the TTL of every key whose value was found."""
        if ret is None:
            return
        for each_key, each_value in zip(keys, ret):
            if each_value is not None:
                await self._refresh_ttl(client, each_key, namespace)
|