cache.py
import asyncio

from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock
from aiocache.plugins import BasePlugin, TimingPlugin

from ..schemas import ModelType
from .base import InferenceModel


class ModelCache:
    """Fetches a model from an in-memory cache, instantiating it if it's missing."""

    def __init__(
        self,
        ttl: float | None = None,
        revalidate: bool = False,
        timeout: int | None = None,
        profiling: bool = False,
    ):
        """
        Args:
            ttl: Unloads model after this duration. Disabled if None. Defaults to None.
            revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
            timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
            profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
        """
        self.ttl = ttl

        plugins = []
        if revalidate:
            plugins.append(RevalidationPlugin())
        if profiling:
            plugins.append(TimingPlugin())

        self.cache = SimpleMemoryCache(ttl=ttl, timeout=timeout, plugins=plugins, namespace=None)

    async def get(self, model_name: str, model_type: ModelType, **model_kwargs) -> InferenceModel:
        """
        Args:
            model_name: Name of model in the model hub used for the task.
            model_type: Model type or task, which determines which model zoo is used.

        Returns:
            model: The requested model.
        """
        key = self.cache.build_key(model_name, model_type.value)
        model = await self.cache.get(key)
        if model is None:
            # Cache miss: build the model in the default thread executor so the event
            # loop is not blocked. The OptimisticLock's compare-and-swap write fails
            # if another task has stored the same key in the meantime.
            async with OptimisticLock(self.cache, key) as lock:
                model = await asyncio.get_running_loop().run_in_executor(
                    None,
                    lambda: InferenceModel.from_model_type(model_type, model_name, **model_kwargs),
                )
                await lock.cas(model, ttl=self.ttl)
        return model

    async def get_profiling(self) -> dict[str, float] | None:
        # Timing data is only present when the TimingPlugin is enabled.
        if not hasattr(self.cache, "profiling"):
            return None

        return self.cache.profiling  # type: ignore


class RevalidationPlugin(BasePlugin):
    """Revalidates cache item's TTL after cache hit."""

    async def post_get(self, client, key, ret=None, namespace=None, **kwargs):
        if ret is None:
            return
        if namespace is not None:
            key = client.build_key(key, namespace)
        # Only refresh keys that have an expiration handler registered, i.e. a TTL.
        if key in client._handlers:
            await client.expire(key, client.ttl)

    async def post_multi_get(self, client, keys, ret=None, namespace=None, **kwargs):
        if ret is None:
            return
        for key, val in zip(keys, ret):
            if namespace is not None:
                key = client.build_key(key, namespace)
            if val is not None and key in client._handlers:
                await client.expire(key, client.ttl)