cache.py

from typing import Any

from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock
from aiocache.plugins import BasePlugin, TimingPlugin

from ..schemas import ModelType
from .base import InferenceModel


class ModelCache:
    """Fetches a model from an in-memory cache, instantiating it if it's missing."""

    def __init__(
        self,
        ttl: float | None = None,
        revalidate: bool = False,
        timeout: int | None = None,
        profiling: bool = False,
    ):
        """
        Args:
            ttl: Unloads model after this duration. Disabled if None. Defaults to None.
            revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
            timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
            profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
        """
        self.ttl = ttl
        plugins = []

        if revalidate:
            plugins.append(RevalidationPlugin())
        if profiling:
            plugins.append(TimingPlugin())

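        # Plugins hook into cache get/set operations; namespace=None leaves keys unprefixed.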
        self.cache = SimpleMemoryCache(ttl=ttl, timeout=timeout, plugins=plugins, namespace=None)

    async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel:
        """
        Args:
            model_name: Name of model in the model hub used for the task.
            model_type: Model type or task, which determines which model zoo is used.

        Returns:
            model: The requested model.
        """
        key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
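        # The optimistic lock avoids loading the same model twice under concurrent requests:
        # cas() only writes if the key was untouched between entering the lock and the store.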
        async with OptimisticLock(self.cache, key) as lock:
            model = await self.cache.get(key)
            if model is None:
                model = InferenceModel.from_model_type(model_type, model_name, **model_kwargs)
                await lock.cas(model, ttl=self.ttl)
        return model

    async def get_profiling(self) -> dict[str, float] | None:
        if not hasattr(self.cache, "profiling"):
            return None

        return self.cache.profiling  # type: ignore


class RevalidationPlugin(BasePlugin):
    """Revalidates cache item's TTL after cache hit."""

    async def post_get(
        self,
        client: SimpleMemoryCache,
        key: str,
        ret: Any | None = None,
        namespace: str | None = None,
        **kwargs: Any,
    ) -> None:
        if ret is None:
            return
        if namespace is not None:
            key = client.build_key(key, namespace)
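        # SimpleMemoryCache tracks expiry timers in _handlers, so only keys that have a TTL are refreshed.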
        if key in client._handlers:
            await client.expire(key, client.ttl)

    async def post_multi_get(
        self,
        client: SimpleMemoryCache,
        keys: list[str],
        ret: list[Any] | None = None,
        namespace: str | None = None,
        **kwargs: Any,
    ) -> None:
        if ret is None:
            return

        for key, val in zip(keys, ret):
            if namespace is not None:
                key = client.build_key(key, namespace)
            if val is not None and key in client._handlers:
                await client.expire(key, client.ttl)
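
# Usage sketch (illustrative only; the model name, ModelType member, and kwargs below are
# hypothetical and depend on what InferenceModel.from_model_type() accepts):
#
#     cache = ModelCache(ttl=300, revalidate=True)
#     model = await cache.get("some-model", ModelType.CLIP, mode="vision")
#
# Extra keyword arguments are forwarded to InferenceModel.from_model_type(), and the same
# (name, type, mode) combination always maps back to the same cache entry.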