a2f5674bbb
* basic refactor and styling * removed batching * module entrypoint * removed unused imports * model superclass, model cache now in app state * fixed cache dir and enforced abstract method --------- Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
92 lines
3 KiB
Python
92 lines
3 KiB
Python
import asyncio
|
|
|
|
from aiocache.backends.memory import SimpleMemoryCache
|
|
from aiocache.lock import OptimisticLock
|
|
from aiocache.plugins import BasePlugin, TimingPlugin
|
|
|
|
from ..schemas import ModelType
|
|
from .base import InferenceModel
|
|
|
|
|
|
class ModelCache:
|
|
"""Fetches a model from an in-memory cache, instantiating it if it's missing."""
|
|
|
|
def __init__(
|
|
self,
|
|
ttl: float | None = None,
|
|
revalidate: bool = False,
|
|
timeout: int | None = None,
|
|
profiling: bool = False,
|
|
):
|
|
"""
|
|
Args:
|
|
ttl: Unloads model after this duration. Disabled if None. Defaults to None.
|
|
revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
|
|
timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
|
|
profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
|
|
"""
|
|
|
|
self.ttl = ttl
|
|
plugins = []
|
|
|
|
if revalidate:
|
|
plugins.append(RevalidationPlugin())
|
|
if profiling:
|
|
plugins.append(TimingPlugin())
|
|
|
|
self.cache = SimpleMemoryCache(
|
|
ttl=ttl, timeout=timeout, plugins=plugins, namespace=None
|
|
)
|
|
|
|
async def get(
|
|
self, model_name: str, model_type: ModelType, **model_kwargs
|
|
) -> InferenceModel:
|
|
"""
|
|
Args:
|
|
model_name: Name of model in the model hub used for the task.
|
|
model_type: Model type or task, which determines which model zoo is used.
|
|
|
|
Returns:
|
|
model: The requested model.
|
|
"""
|
|
|
|
key = self.cache.build_key(model_name, model_type.value)
|
|
model = await self.cache.get(key)
|
|
if model is None:
|
|
async with OptimisticLock(self.cache, key) as lock:
|
|
model = await asyncio.get_running_loop().run_in_executor(
|
|
None,
|
|
lambda: InferenceModel.from_model_type(
|
|
model_type, model_name, **model_kwargs
|
|
),
|
|
)
|
|
await lock.cas(model, ttl=self.ttl)
|
|
return model
|
|
|
|
async def get_profiling(self) -> dict[str, float] | None:
|
|
if not hasattr(self.cache, "profiling"):
|
|
return None
|
|
|
|
return self.cache.profiling # type: ignore
|
|
|
|
|
|
class RevalidationPlugin(BasePlugin):
|
|
"""Revalidates cache item's TTL after cache hit."""
|
|
|
|
async def post_get(self, client, key, ret=None, namespace=None, **kwargs):
|
|
if ret is None:
|
|
return
|
|
if namespace is not None:
|
|
key = client.build_key(key, namespace)
|
|
if key in client._handlers:
|
|
await client.expire(key, client.ttl)
|
|
|
|
async def post_multi_get(self, client, keys, ret=None, namespace=None, **kwargs):
|
|
if ret is None:
|
|
return
|
|
|
|
for key, val in zip(keys, ret):
|
|
if namespace is not None:
|
|
key = client.build_key(key, namespace)
|
|
if val is not None and key in client._handlers:
|
|
await client.expire(key, client.ttl)
|