import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any, Callable
from zipfile import BadZipFile

import faiss
import numpy as np
import orjson
from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile
from fastapi.responses import ORJSONResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile  # type: ignore
from starlette.formparsers import MultiPartParser

from app.models.base import InferenceModel

from .config import log, settings
from .models.cache import ModelCache
from .schemas import (
    MessageResponse,
    ModelType,
    TextResponse,
)

MultiPartParser.max_file_size = 2**24  # spools to disk if the payload is 16 MiB or larger

app = FastAPI()


class VectorStore:
    def __init__(self, dims: int, index: str = "HNSW") -> None:
        self.index = faiss.index_factory(dims, index)
        # faiss's HNSW index does not support add_with_ids, so sequential faiss
        # row keys are mapped back to the caller-supplied embedding IDs.
        self.key_to_id: dict[int, Any] = {}

    def search(self, embeddings: np.ndarray[int, np.dtype[Any]], k: int) -> list[Any]:
        keys: np.ndarray[int, np.dtype[np.int64]] = self.index.assign(embeddings, k)
        return [self.key_to_id[idx] for row in keys.tolist() for idx in row if idx != -1]

    def add_with_ids(self, embeddings: np.ndarray[int, np.dtype[Any]], embedding_ids: list[Any]) -> None:
        self.key_to_id |= {
            key: id
            for key, id in zip(range(self.index.ntotal, self.index.ntotal + len(embedding_ids)), embedding_ids)
        }
        self.index.add(embeddings)  # type: ignore

    @property
    def dims(self) -> int:
        return self.index.d


vector_stores: dict[str, VectorStore] = {}
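
# A minimal usage sketch of VectorStore (hypothetical IDs and data, not executed
# as part of the request flow):
#
#   store = VectorStore(dims=512)
#   vecs = np.random.rand(2, 512).astype(np.float32)
#   store.add_with_ids(vecs, ["asset-a", "asset-b"])
#   store.search(vecs[:1], k=1)  # -> ["asset-a"], assuming the neighbor is found

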
def validate_embeddings(embeddings: list[float]) -> Any:
    embeddings = np.array(embeddings)
    if len(embeddings.shape) == 1:
        embeddings = np.expand_dims(embeddings, 0)
    elif len(embeddings.shape) != 2:
        raise HTTPException(400, f"Expected one or two axes for embeddings; got {len(embeddings.shape)}")
    if embeddings.shape[1] < 10:
        raise HTTPException(400, f"Dimension size must be at least 10; got {embeddings.shape[1]}")
    return embeddings


def init_state() -> None:
    app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
    log.info(
        "Created in-memory cache with unloading "
        f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
    )
    # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
    app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
    app.state.model_locks = {model_type: threading.Lock() for model_type in ModelType}
    app.state.index_lock = threading.Lock()
    log.info(f"Initialized request thread pool with {settings.request_threads} threads.")


@app.on_event("startup")
async def startup_event() -> None:
    init_state()


@app.get("/", response_model=MessageResponse)
async def root() -> dict[str, str]:
    return {"message": "Immich ML"}


@app.get("/ping", response_model=TextResponse)
def ping() -> str:
    return "pong"


@app.post("/pipeline", response_class=ORJSONResponse)
async def pipeline(
    model_name: str = Form(alias="modelName"),
    model_type: ModelType = Form(alias="modelType"),
    options: str = Form(default="{}"),
    text: str | None = Form(default=None),
    image: UploadFile | None = None,
    index_name: str | None = Form(default=None),
    embedding_id: str | None = Form(default=None),
    k: int | None = Form(default=None),
) -> ORJSONResponse:
    if image is not None:
        inputs: str | bytes = await image.read()
    elif text is not None:
        inputs = text
    else:
        raise HTTPException(400, "Either image or text must be provided")
    try:
        kwargs = orjson.loads(options)
    except orjson.JSONDecodeError:
        raise HTTPException(400, f"Invalid options JSON: {options}")

    outputs = await _predict(model_name, model_type, inputs, **kwargs)
    if index_name is not None:
        expanded = np.expand_dims(outputs, 0)
        if k is not None:
            if k < 1:
                raise HTTPException(400, f"k must be a positive integer; got {k}")
            if index_name not in vector_stores:
                raise HTTPException(404, f"Index '{index_name}' not found")
            outputs = await run(vector_stores[index_name].search, expanded, k)
        if embedding_id is not None:
            if index_name not in vector_stores:
                await create(index_name, [embedding_id], expanded)
            else:
                await add(index_name, [embedding_id], expanded)
    return ORJSONResponse(outputs)
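
# Example /pipeline call that embeds a text query and searches an existing index
# (hypothetical model and index names):
#
#   curl -X POST http://localhost:3003/pipeline \
#     -F modelName=ViT-B-32__openai -F modelType=clip \
#     -F text="a dog on the beach" -F index_name=assets -F k=10

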
@app.post("/predict", response_class=ORJSONResponse)
async def predict(
    model_name: str = Form(alias="modelName"),
    model_type: ModelType = Form(alias="modelType"),
    options: str = Form(default="{}"),
    text: str | None = Form(default=None),
    image: UploadFile | None = None,
) -> ORJSONResponse:
    if image is not None:
        inputs: str | bytes = await image.read()
    elif text is not None:
        inputs = text
    else:
        raise HTTPException(400, "Either image or text must be provided")
    try:
        kwargs = orjson.loads(options)
    except orjson.JSONDecodeError:
        raise HTTPException(400, f"Invalid options JSON: {options}")
    outputs = await _predict(model_name, model_type, inputs, **kwargs)
    return ORJSONResponse(outputs)
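
# Example /predict call with an image payload (hypothetical model and file names):
#
#   curl -X POST http://localhost:3003/predict \
#     -F modelName=buffalo_l -F modelType=facial-recognition \
#     -F image=@photo.jpg

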
@app.post("/index/{index_name}/search", response_class=ORJSONResponse)
async def search(index_name: str, embeddings: Any = Depends(validate_embeddings), k: int = 10) -> ORJSONResponse:
    if index_name not in vector_stores or vector_stores[index_name].dims != embeddings.shape[1]:
        raise HTTPException(404, f"Index '{index_name}' not found")
    outputs: list[Any] = await run(vector_stores[index_name].search, embeddings, k)
    return ORJSONResponse(outputs)
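
# Note: validate_embeddings runs as a FastAPI dependency, so embeddings is read
# from the JSON request body (likely the bare array, as it is the only body field
# here), while k arrives as a query parameter.

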
@app.post("/index/{index_name}/add")
async def add(
    index_name: str,
    embedding_ids: list[str],
    embeddings: Any = Depends(validate_embeddings),
) -> None:
    if index_name not in vector_stores or vector_stores[index_name].dims != embeddings.shape[1]:
        await create(index_name, embedding_ids, embeddings)
    else:
        log.info(f"Adding {len(embedding_ids)} embeddings to index '{index_name}'")
        await run(_add, vector_stores[index_name], embedding_ids, embeddings)


@app.post("/index/{index_name}/create")
async def create(
    index_name: str,
    embedding_ids: list[str],
    embeddings: Any = Depends(validate_embeddings),
) -> None:
    if embeddings.shape[0] != len(embedding_ids):
        raise HTTPException(
            400,
            f"Number of embedding IDs must match number of embeddings; "
            f"got {len(embedding_ids)} ID(s) and {embeddings.shape[0]} embedding(s)",
        )
    if index_name in vector_stores:
        log.warning(f"Index '{index_name}' already exists. Overwriting.")
    log.info(f"Creating new index '{index_name}'")
    vector_stores[index_name] = await run(_create, embedding_ids, embeddings)
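
# For /add and /create, embedding_ids is a second body field, so the expected
# request body is a JSON object (hypothetical values):
#
#   {"embedding_ids": ["asset-a", "asset-b"], "embeddings": [[0.1, ...], [0.2, ...]]}

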
async def run(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    # Offload blocking work to the thread pool so it doesn't stall the event loop;
    # fall back to running inline when no pool is configured.
    if app.state.thread_pool is None:
        return func(*args, **kwargs)
    if kwargs:
        func = partial(func, **kwargs)
    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, func, *args)


def _load(model: InferenceModel) -> InferenceModel:
    if model.loaded:
        return model
    try:
        # Double-checked locking: the per-model-type lock serializes loading, and
        # re-checking loaded avoids redundant loads by threads that queued on it.
        with app.state.model_locks[model.model_type]:
            if not model.loaded:
                model.load()
    except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
        log.warning(
            f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'. "
            "Clearing cache and retrying."
        )
        model.clear_cache()
        model.load()
    return model


async def _predict(
    model_name: str, model_type: ModelType, inputs: Any, **options: Any
) -> np.ndarray[int, np.dtype[np.float32]]:
    model = await app.state.model_cache.get(model_name, model_type, **options)
    if not model.loaded:
        await run(_load, model)
    model.configure(**options)
    return await run(model.predict, inputs)


def _create(
    embedding_ids: list[str],
    embeddings: np.ndarray[int, np.dtype[np.float32]],
) -> VectorStore:
    index = VectorStore(embeddings.shape[1])
    _add(index, embedding_ids, embeddings)
    return index


def _add(
    index: VectorStore,
    embedding_ids: list[str],
    embeddings: np.ndarray[int, np.dtype[np.float32]],
) -> None:
    with app.state.index_lock:
        index.add_with_ids(embeddings, embedding_ids)  # type: ignore
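
# The app is typically served by an ASGI server; an assumed invocation:
#
#   uvicorn app.main:app --host 0.0.0.0 --port 3003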