cache.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. from aiocache.plugins import TimingPlugin, BasePlugin
  2. from aiocache.backends.memory import SimpleMemoryCache
  3. from aiocache.lock import OptimisticLock
  4. from typing import Any
  5. from models import get_model
  6. class ModelCache:
  7. """Fetches a model from an in-memory cache, instantiating it if it's missing."""
  8. def __init__(
  9. self,
  10. ttl: int | None = None,
  11. revalidate: bool = False,
  12. timeout: int | None = None,
  13. profiling: bool = False,
  14. ):
  15. """
  16. Args:
  17. ttl: Unloads model after this duration. Disabled if None. Defaults to None.
  18. revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
  19. timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
  20. profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
  21. """
  22. self.ttl = ttl
  23. plugins = []
  24. if revalidate:
  25. plugins.append(RevalidationPlugin())
  26. if profiling:
  27. plugins.append(TimingPlugin())
  28. self.cache = SimpleMemoryCache(
  29. ttl=ttl, timeout=timeout, plugins=plugins, namespace=None
  30. )
  31. async def get_cached_model(
  32. self, model_name: str, model_type: str, **model_kwargs
  33. ) -> Any:
  34. """
  35. Args:
  36. model_name: Name of model in the model hub used for the task.
  37. model_type: Model type or task, which determines which model zoo is used.
  38. Returns:
  39. model: The requested model.
  40. """
  41. key = self.cache.build_key(model_name, model_type)
  42. model = await self.cache.get(key)
  43. if model is None:
  44. async with OptimisticLock(self.cache, key) as lock:
  45. model = get_model(model_name, model_type, **model_kwargs)
  46. await lock.cas(model, ttl=self.ttl)
  47. return model
  48. async def get_profiling(self) -> dict[str, float] | None:
  49. if not hasattr(self.cache, "profiling"):
  50. return None
  51. return self.cache.profiling # type: ignore
  52. class RevalidationPlugin(BasePlugin):
  53. """Revalidates cache item's TTL after cache hit."""
  54. async def post_get(self, client, key, ret=None, namespace=None, **kwargs):
  55. if ret is None:
  56. return
  57. if namespace is not None:
  58. key = client.build_key(key, namespace)
  59. if key in client._handlers:
  60. await client.expire(key, client.ttl)
  61. async def post_multi_get(self, client, keys, ret=None, namespace=None, **kwargs):
  62. if ret is None:
  63. return
  64. for key, val in zip(keys, ret):
  65. if namespace is not None:
  66. key = client.build_key(key, namespace)
  67. if val is not None and key in client._handlers:
  68. await client.expire(key, client.ttl)