# cache.py
from typing import Any

from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock
from aiocache.plugins import BasePlugin, TimingPlugin

from app.models import from_model_type

from ..schemas import ModelType
from .base import InferenceModel
  8. class ModelCache:
  9. """Fetches a model from an in-memory cache, instantiating it if it's missing."""
  10. def __init__(
  11. self,
  12. ttl: float | None = None,
  13. revalidate: bool = False,
  14. timeout: int | None = None,
  15. profiling: bool = False,
  16. ) -> None:
  17. """
  18. Args:
  19. ttl: Unloads model after this duration. Disabled if None. Defaults to None.
  20. revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
  21. timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
  22. profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
  23. """
  24. self.ttl = ttl
  25. plugins = []
  26. if revalidate:
  27. plugins.append(RevalidationPlugin())
  28. if profiling:
  29. plugins.append(TimingPlugin())
  30. self.cache = SimpleMemoryCache(ttl=ttl, timeout=timeout, plugins=plugins, namespace=None)
  31. async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel:
  32. """
  33. Args:
  34. model_name: Name of model in the model hub used for the task.
  35. model_type: Model type or task, which determines which model zoo is used.
  36. Returns:
  37. model: The requested model.
  38. """
  39. key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
  40. async with OptimisticLock(self.cache, key) as lock:
  41. model = await self.cache.get(key)
  42. if model is None:
  43. model = from_model_type(model_type, model_name, **model_kwargs)
  44. await lock.cas(model, ttl=self.ttl)
  45. return model
  46. async def get_profiling(self) -> dict[str, float] | None:
  47. if not hasattr(self.cache, "profiling"):
  48. return None
  49. return self.cache.profiling # type: ignore
  50. class RevalidationPlugin(BasePlugin):
  51. """Revalidates cache item's TTL after cache hit."""
  52. async def post_get(
  53. self,
  54. client: SimpleMemoryCache,
  55. key: str,
  56. ret: Any | None = None,
  57. namespace: str | None = None,
  58. **kwargs: Any,
  59. ) -> None:
  60. if ret is None:
  61. return
  62. if namespace is not None:
  63. key = client.build_key(key, namespace)
  64. if key in client._handlers:
  65. await client.expire(key, client.ttl)
  66. async def post_multi_get(
  67. self,
  68. client: SimpleMemoryCache,
  69. keys: list[str],
  70. ret: list[Any] | None = None,
  71. namespace: str | None = None,
  72. **kwargs: Any,
  73. ) -> None:
  74. if ret is None:
  75. return
  76. for key, val in zip(keys, ret):
  77. if namespace is not None:
  78. key = client.build_key(key, namespace)
  79. if val is not None and key in client._handlers:
  80. await client.expire(key, client.ttl)