main.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. import asyncio
  2. from functools import partial
  3. import threading
  4. from concurrent.futures import ThreadPoolExecutor
  5. from typing import Any, Callable
  6. from zipfile import BadZipFile
  7. import faiss
  8. import numpy as np
  9. import orjson
  10. from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile
  11. from fastapi.responses import ORJSONResponse
  12. from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile # type: ignore
  13. from starlette.formparsers import MultiPartParser
  14. from app.models.base import InferenceModel
  15. from .config import log, settings
  16. from .models.cache import ModelCache
  17. from .schemas import (
  18. MessageResponse,
  19. ModelType,
  20. TextResponse,
  21. )
  22. MultiPartParser.max_file_size = 2**24 # spools to disk if payload is 16 MiB or larger
  23. app = FastAPI()
  24. vector_stores: dict[str, faiss.IndexIDMap2] = {}
  25. def validate_embeddings(embeddings: list[float] | np.ndarray[int, np.dtype[Any]]) -> np.ndarray[int, np.dtype[Any]]:
  26. embeddings = np.array(embeddings)
  27. if len(embeddings.shape) == 1:
  28. embeddings = np.expand_dims(embeddings, 0)
  29. elif len(embeddings.shape) != 2:
  30. raise HTTPException(400, f"Expected one or two axes for embeddings; got {len(embeddings.shape)}")
  31. if embeddings.shape[1] < 10:
  32. raise HTTPException(400, f"Dimension size must be at least 10; got {embeddings.shape[1]}")
  33. return embeddings
  34. def init_state() -> None:
  35. app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
  36. log.info(
  37. (
  38. "Created in-memory cache with unloading "
  39. f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
  40. )
  41. )
  42. # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
  43. app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
  44. app.state.model_locks = {model_type: threading.Lock() for model_type in ModelType}
  45. app.state.index_lock = threading.Lock()
  46. log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
  47. @app.on_event("startup")
  48. async def startup_event() -> None:
  49. init_state()
  50. @app.get("/", response_model=MessageResponse)
  51. async def root() -> dict[str, str]:
  52. return {"message": "Immich ML"}
  53. @app.get("/ping", response_model=TextResponse)
  54. def ping() -> str:
  55. return "pong"
  56. @app.post("/pipeline", response_class=ORJSONResponse)
  57. async def pipeline(
  58. model_name: str = Form(alias="modelName"),
  59. model_type: ModelType = Form(alias="modelType"),
  60. options: str = Form(default="{}"),
  61. text: str | None = Form(default=None),
  62. image: UploadFile | None = None,
  63. index_name: str | None = Form(default=None),
  64. embedding_id: str | None = Form(default=None),
  65. k: int | None = Form(default=None),
  66. ) -> ORJSONResponse:
  67. if image is not None:
  68. inputs: str | bytes = await image.read()
  69. elif text is not None:
  70. inputs = text
  71. else:
  72. raise HTTPException(400, "Either image or text must be provided")
  73. try:
  74. kwargs = orjson.loads(options)
  75. except orjson.JSONDecodeError:
  76. raise HTTPException(400, f"Invalid options JSON: {options}")
  77. outputs = await run(_predict, model_name, model_type, inputs, **kwargs)
  78. if index_name is not None:
  79. if k is not None:
  80. if k < 1:
  81. raise HTTPException(400, f"k must be a positive integer; got {k}")
  82. outputs = await run(_search, index_name, outputs, k)
  83. if embedding_id is not None:
  84. await run(_add, index_name, [embedding_id], outputs)
  85. return ORJSONResponse(outputs)
  86. @app.post("/predict", response_class=ORJSONResponse)
  87. async def predict(
  88. model_name: str = Form(alias="modelName"),
  89. model_type: ModelType = Form(alias="modelType"),
  90. options: str = Form(default="{}"),
  91. text: str | None = Form(default=None),
  92. image: UploadFile | None = None,
  93. ) -> ORJSONResponse:
  94. if image is not None:
  95. inputs: str | bytes = await image.read()
  96. elif text is not None:
  97. inputs = text
  98. else:
  99. raise HTTPException(400, "Either image or text must be provided")
  100. try:
  101. kwargs = orjson.loads(options)
  102. except orjson.JSONDecodeError:
  103. raise HTTPException(400, f"Invalid options JSON: {options}")
  104. outputs = await run(_predict, model_name, model_type, inputs, **kwargs)
  105. return ORJSONResponse(outputs)
  106. @app.post("/index/{index_name}/search", response_class=ORJSONResponse)
  107. async def search(
  108. index_name: str, embeddings: np.ndarray[int, np.dtype[np.float32]] = Depends(validate_embeddings), k: int = 10
  109. ) -> ORJSONResponse:
  110. if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
  111. raise HTTPException(404, f"Index '{index_name}' not found")
  112. outputs: np.ndarray[int, np.dtype[Any]] = await run(_search, index_name, embeddings, k)
  113. return ORJSONResponse(outputs)
  114. @app.post("/index/{index_name}/add")
  115. async def add(
  116. index_name: str,
  117. embedding_ids: list[str],
  118. embeddings: np.ndarray[int, np.dtype[np.float32]] = Depends(validate_embeddings),
  119. ) -> None:
  120. if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
  121. await create(index_name, embedding_ids, embeddings)
  122. else:
  123. await run(_add, index_name, embedding_ids, embeddings)
  124. @app.post("/index/{index_name}/create")
  125. async def create(
  126. index_name: str,
  127. embedding_ids: list[str],
  128. embeddings: np.ndarray[int, np.dtype[np.float32]] = Depends(validate_embeddings),
  129. ) -> None:
  130. if embeddings.shape[0] != len(embedding_ids):
  131. raise HTTPException(400, "Number of embedding IDs must match number of embeddings")
  132. if index_name in vector_stores:
  133. log.warn(f"Index '{index_name}' already exists. Overwriting.")
  134. vector_stores[index_name] = await run(_create)
  135. async def run(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
  136. if app.state.thread_pool is None:
  137. return func(*args, **kwargs)
  138. if kwargs:
  139. func = partial(func, **kwargs)
  140. return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, func, *args)
  141. async def _load(model: InferenceModel) -> InferenceModel:
  142. if model.loaded:
  143. return model
  144. try:
  145. with app.state.model_locks[model.model_type]:
  146. if not model.loaded:
  147. model.load()
  148. except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
  149. log.warn(
  150. (
  151. f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'."
  152. "Clearing cache and retrying."
  153. )
  154. )
  155. model.clear_cache()
  156. model.load()
  157. return model
  158. async def _add(index_name: str, embedding_ids: list[str], embeddings: np.ndarray[int, np.dtype[np.float32]]) -> None:
  159. return await vector_stores[index_name].add_with_ids(embeddings, embedding_ids) # type: ignore
  160. async def _search(
  161. index_name: str, embeddings: np.ndarray[int, np.dtype[np.float32]], k: int
  162. ) -> np.ndarray[int, np.dtype[Any]]:
  163. return await vector_stores[index_name].assign(embeddings, k) # type: ignore
  164. async def _predict(
  165. model_name: str, model_type: ModelType, inputs: Any, **options: Any
  166. ) -> np.ndarray[int, np.dtype[np.float32]]:
  167. model = await _load(await app.state.model_cache.get(model_name, model_type, **options))
  168. model.configure(**options)
  169. return await run(model.predict, inputs)
  170. async def _create(
  171. embedding_ids: list[str],
  172. embeddings: np.ndarray[int, np.dtype[np.float32]],
  173. ) -> faiss.IndexIDMap2:
  174. hnsw_index = faiss.IndexHNSWFlat(embeddings.shape[1])
  175. mapped_index = faiss.IndexIDMap2(hnsw_index)
  176. with app.state.index_lock:
  177. mapped_index.add_with_ids(embeddings, embedding_ids) # type: ignore
  178. return mapped_index