cache.py

import asyncio
from typing import Any

from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock
from aiocache.plugins import BasePlugin, TimingPlugin

from ..schemas import ModelType
from .base import InferenceModel


class ModelCache:
    """Fetches a model from an in-memory cache, instantiating it if it's missing."""

    def __init__(
        self,
        ttl: float | None = None,
        revalidate: bool = False,
        timeout: int | None = None,
        profiling: bool = False,
    ):
        """
        Args:
            ttl: Unloads the model after this duration. Disabled if None. Defaults to None.
            revalidate: Resets the TTL on each cache hit. Useful to keep models in memory while active. Defaults to False.
            timeout: Maximum allowed time for the model to load. Disabled if None. Defaults to None.
            profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
        """
        self.ttl = ttl

        plugins = []
        if revalidate:
            plugins.append(RevalidationPlugin())
        if profiling:
            plugins.append(TimingPlugin())

        self.cache = SimpleMemoryCache(ttl=ttl, timeout=timeout, plugins=plugins, namespace=None)

    async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel:
        """
        Args:
            model_name: Name of the model in the model hub used for the task.
            model_type: Model type or task, which determines which model zoo is used.

        Returns:
            model: The requested model.
        """
        key = self.cache.build_key(model_name, model_type.value)
        model = await self.cache.get(key)
        if model is None:
            async with OptimisticLock(self.cache, key) as lock:
                # Loading a model is blocking CPU/IO work, so run it in the default
                # executor to keep the event loop responsive.
                model = await asyncio.get_running_loop().run_in_executor(
                    None,
                    lambda: InferenceModel.from_model_type(model_type, model_name, **model_kwargs),
                )
                # cas() stores the model only if the key is still unset, so concurrent
                # misses don't overwrite each other's freshly loaded model.
                await lock.cas(model, ttl=self.ttl)
        return model

    async def get_profiling(self) -> dict[str, float] | None:
        # TimingPlugin, when enabled, records per-operation timings on the cache's
        # `profiling` attribute; without it, there is nothing to report.
        if not hasattr(self.cache, "profiling"):
            return None
        return self.cache.profiling  # type: ignore
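

# A minimal usage sketch, assuming a `ModelType.CLIP` member in ..schemas and a
# model name recognized by InferenceModel.from_model_type; both are illustrative
# and not defined in this file.
#
#     async def main() -> None:
#         cache = ModelCache(ttl=300, revalidate=True, profiling=True)
#         model = await cache.get("ViT-B-32", ModelType.CLIP)
#         print(await cache.get_profiling())
#
#     asyncio.run(main())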


class RevalidationPlugin(BasePlugin):
    """Revalidates a cache item's TTL after a cache hit."""

    async def post_get(
        self,
        client: SimpleMemoryCache,
        key: str,
        ret: Any | None = None,
        namespace: str | None = None,
        **kwargs: Any,
    ) -> None:
        if ret is None:
            return
        if namespace is not None:
            key = client.build_key(key, namespace)
        # SimpleMemoryCache tracks pending expirations in `_handlers`, so only
        # keys that actually have a TTL are renewed.
        if key in client._handlers:
            await client.expire(key, client.ttl)

    async def post_multi_get(
        self,
        client: SimpleMemoryCache,
        keys: list[str],
        ret: list[Any] | None = None,
        namespace: str | None = None,
        **kwargs: Any,
    ) -> None:
        if ret is None:
            return

        for key, val in zip(keys, ret):
            if namespace is not None:
                key = client.build_key(key, namespace)
            if val is not None and key in client._handlers:
                await client.expire(key, client.ttl)
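

# A minimal sketch of the revalidation behavior in isolation, using aiocache's
# SimpleMemoryCache directly; the key and timings below are illustrative.
#
#     async def demo() -> None:
#         cache = SimpleMemoryCache(ttl=2, plugins=[RevalidationPlugin()])
#         await cache.set("greeting", "hello")
#         await asyncio.sleep(1.5)
#         await cache.get("greeting")  # hit: post_get resets the 2s TTL
#         await asyncio.sleep(1.5)     # 3s after set, but only 1.5s after the hit
#         assert await cache.get("greeting") == "hello"
#
#     asyncio.run(demo())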