main.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. import asyncio
  2. from functools import partial
  3. import threading
  4. from concurrent.futures import ThreadPoolExecutor
  5. from typing import Any, Callable, Type
  6. from zipfile import BadZipFile
  7. import faiss
  8. import numpy as np
  9. import orjson
  10. from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile
  11. from fastapi.responses import ORJSONResponse
  12. from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile # type: ignore
  13. from starlette.formparsers import MultiPartParser
  14. from app.models.base import InferenceModel
  15. from .config import log, settings
  16. from .models.cache import ModelCache
  17. from .schemas import (
  18. MessageResponse,
  19. ModelType,
  20. TextResponse,
  21. )
  22. # import rich.pretty
  23. # rich.pretty.install()
# Set on the parser class imported from starlette.formparsers, so the limit is not per-request.
MultiPartParser.max_file_size = 2**24  # spools to disk if payload is 16 MiB or larger
# Application instance that the route and event decorators below attach to.
app = FastAPI()
  26. class VectorStore:
  27. def __init__(self, dims: int, index: str = "HNSW") -> None:
  28. self.index = faiss.index_factory(dims, index)
  29. self.key_to_id: dict[int, Any] = {}
  30. def search(self, embeddings: np.ndarray[int, np.dtype[Any]], k: int) -> list[Any]:
  31. keys: np.ndarray[int, np.dtype[np.int64]] = self.index.assign(embeddings, k)
  32. return [self.key_to_id[idx] for row in keys.tolist() for idx in row if not idx == -1]
  33. def add_with_ids(self, embeddings: np.ndarray[int, np.dtype[Any]], embedding_ids: list[Any]) -> None:
  34. self.key_to_id |= {
  35. key: id for key, id in zip(range(self.index.ntotal, self.index.ntotal + len(embedding_ids)), embedding_ids)
  36. }
  37. self.index.add(embeddings) # type: ignore
  38. @property
  39. def dims(self) -> int:
  40. return self.index.d
# Process-wide registry of vector indexes, keyed by the index name used in the routes.
vector_stores: dict[str, VectorStore] = {}
  42. def validate_embeddings(embeddings: list[float]) -> Any:
  43. embeddings = np.array(embeddings)
  44. if len(embeddings.shape) == 1:
  45. embeddings = np.expand_dims(embeddings, 0)
  46. elif len(embeddings.shape) != 2:
  47. raise HTTPException(400, f"Expected one or two axes for embeddings; got {len(embeddings.shape)}")
  48. if embeddings.shape[1] < 10:
  49. raise HTTPException(400, f"Dimension size must be at least 10; got {embeddings.shape[1]}")
  50. return embeddings
  51. def init_state() -> None:
  52. app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
  53. log.info(
  54. (
  55. "Created in-memory cache with unloading "
  56. f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
  57. )
  58. )
  59. # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
  60. app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
  61. app.state.model_locks = {model_type: threading.Lock() for model_type in ModelType}
  62. app.state.index_lock = threading.Lock()
  63. log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
@app.on_event("startup")
async def startup_event() -> None:
    """Startup hook: initialise application state by calling ``init_state``."""
    init_state()
@app.get("/", response_model=MessageResponse)
async def root() -> dict[str, str]:
    """Return a fixed message identifying the service."""
    return {"message": "Immich ML"}
@app.get("/ping", response_model=TextResponse)
def ping() -> str:
    """Liveness check; always answers with the literal string ``pong``."""
    return "pong"
  73. @app.post("/pipeline", response_class=ORJSONResponse)
  74. async def pipeline(
  75. model_name: str = Form(alias="modelName"),
  76. model_type: ModelType = Form(alias="modelType"),
  77. options: str = Form(default="{}"),
  78. text: str | None = Form(default=None),
  79. image: UploadFile | None = None,
  80. index_name: str | None = Form(default=None),
  81. embedding_id: str | None = Form(default=None),
  82. k: int | None = Form(default=None),
  83. ) -> ORJSONResponse:
  84. if image is not None:
  85. inputs: str | bytes = await image.read()
  86. elif text is not None:
  87. inputs = text
  88. else:
  89. raise HTTPException(400, "Either image or text must be provided")
  90. try:
  91. kwargs = orjson.loads(options)
  92. except orjson.JSONDecodeError:
  93. raise HTTPException(400, f"Invalid options JSON: {options}")
  94. outputs = await _predict(model_name, model_type, inputs, **kwargs)
  95. if index_name is not None:
  96. if k is not None:
  97. if k < 1:
  98. raise HTTPException(400, f"k must be a positive integer; got {k}")
  99. if index_name not in vector_stores:
  100. raise HTTPException(404, f"Index '{index_name}' not found")
  101. outputs = await run(vector_stores[index_name].search, outputs, k)
  102. if embedding_id is not None:
  103. if index_name not in vector_stores:
  104. await create(index_name, [embedding_id], outputs)
  105. else:
  106. await add(index_name, [embedding_id], outputs)
  107. return ORJSONResponse(outputs)
  108. @app.post("/predict", response_class=ORJSONResponse)
  109. async def predict(
  110. model_name: str = Form(alias="modelName"),
  111. model_type: ModelType = Form(alias="modelType"),
  112. options: str = Form(default="{}"),
  113. text: str | None = Form(default=None),
  114. image: UploadFile | None = None,
  115. ) -> ORJSONResponse:
  116. if image is not None:
  117. inputs: str | bytes = await image.read()
  118. elif text is not None:
  119. inputs = text
  120. else:
  121. raise HTTPException(400, "Either image or text must be provided")
  122. try:
  123. kwargs = orjson.loads(options)
  124. except orjson.JSONDecodeError:
  125. raise HTTPException(400, f"Invalid options JSON: {options}")
  126. outputs = await _predict(model_name, model_type, inputs, **kwargs)
  127. return ORJSONResponse(outputs)
  128. @app.post("/index/{index_name}/search", response_class=ORJSONResponse)
  129. async def search(index_name: str, embeddings: Any = Depends(validate_embeddings), k: int = 10) -> ORJSONResponse:
  130. if index_name not in vector_stores or vector_stores[index_name].dims != embeddings.shape[1]:
  131. raise HTTPException(404, f"Index '{index_name}' not found")
  132. outputs: np.ndarray[int, np.dtype[Any]] = await run(vector_stores[index_name].search, embeddings, k)
  133. return ORJSONResponse(outputs)
  134. @app.post("/index/{index_name}/add")
  135. async def add(
  136. index_name: str,
  137. embedding_ids: list[str],
  138. embeddings: Any = Depends(validate_embeddings),
  139. ) -> None:
  140. if index_name not in vector_stores or vector_stores[index_name].dims != embeddings.shape[1]:
  141. await create(index_name, embedding_ids, embeddings)
  142. else:
  143. log.info(f"Adding {len(embedding_ids)} embeddings to index '{index_name}'")
  144. await run(_add, vector_stores[index_name], embedding_ids, embeddings)
  145. @app.post("/index/{index_name}/create")
  146. async def create(
  147. index_name: str,
  148. embedding_ids: list[str],
  149. embeddings: Any = Depends(validate_embeddings),
  150. ) -> None:
  151. if embeddings.shape[0] != len(embedding_ids):
  152. raise HTTPException(
  153. 400,
  154. f"Number of embedding IDs must match number of embeddings; got {len(embedding_ids)} ID(s) and {embeddings.shape[0]} embedding(s)",
  155. )
  156. if index_name in vector_stores:
  157. log.warn(f"Index '{index_name}' already exists. Overwriting.")
  158. log.info(f"Creating new index '{index_name}'")
  159. vector_stores[index_name] = await run(_create, embedding_ids, embeddings)
  160. async def run(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
  161. if app.state.thread_pool is None:
  162. return func(*args, **kwargs)
  163. if kwargs:
  164. func = partial(func, **kwargs)
  165. return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, func, *args)
  166. def _load(model: InferenceModel) -> InferenceModel:
  167. if model.loaded:
  168. return model
  169. try:
  170. with app.state.model_locks[model.model_type]:
  171. if not model.loaded:
  172. model.load()
  173. except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
  174. log.warn(
  175. (
  176. f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'."
  177. "Clearing cache and retrying."
  178. )
  179. )
  180. model.clear_cache()
  181. model.load()
  182. return model
  183. async def _predict(
  184. model_name: str, model_type: ModelType, inputs: Any, **options: Any
  185. ) -> np.ndarray[int, np.dtype[np.float32]]:
  186. model = await app.state.model_cache.get(model_name, model_type, **options)
  187. if not model.loaded:
  188. await run(_load, model)
  189. model.configure(**options)
  190. return await run(model.predict, inputs)
  191. def _create(
  192. embedding_ids: list[str],
  193. embeddings: np.ndarray[int, np.dtype[np.float32]],
  194. ) -> VectorStore:
  195. index = VectorStore(embeddings.shape[1])
  196. _add(index, embedding_ids, embeddings)
  197. return index
  198. def _add(
  199. index: VectorStore,
  200. embedding_ids: list[str],
  201. embeddings: np.ndarray[int, np.dtype[np.float32]],
  202. ) -> None:
  203. with app.state.index_lock:
  204. index.add_with_ids(embeddings, embedding_ids) # type: ignore