main.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. import asyncio
  2. import threading
  3. from concurrent.futures import ThreadPoolExecutor
  4. import numpy as np
  5. from typing import Any
  6. from zipfile import BadZipFile
  7. import faiss
  8. import orjson
  9. from fastapi import FastAPI, Form, HTTPException, UploadFile, Depends
  10. from fastapi.responses import ORJSONResponse
  11. from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile # type: ignore
  12. from starlette.formparsers import MultiPartParser
  13. from app.models.base import InferenceModel
  14. from .config import log, settings
  15. from .models.cache import ModelCache
  16. from .schemas import (
  17. MessageResponse,
  18. ModelType,
  19. TextResponse,
  20. )
  21. MultiPartParser.max_file_size = 2**24 # spools to disk if payload is 16 MiB or larger
  22. app = FastAPI()
  23. vector_stores: dict[str, faiss.IndexIDMap2] = {}
  24. def validate_embeddings(embeddings: list[float] | np.ndarray[int, np.dtype[Any]]) -> np.ndarray[int, np.dtype[Any]]:
  25. embeddings = np.array(embeddings)
  26. if len(embeddings.shape) == 1:
  27. embeddings = np.expand_dims(embeddings, 0)
  28. elif len(embeddings.shape) != 2:
  29. raise HTTPException(400, f"Expected one or two axes for embeddings; got {len(embeddings.shape)}")
  30. if embeddings.shape[1] < 10:
  31. raise HTTPException(400, f"Dimension size must be at least 10; got {embeddings.shape[1]}")
  32. return embeddings
  33. def init_state() -> None:
  34. app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
  35. log.info(
  36. (
  37. "Created in-memory cache with unloading "
  38. f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
  39. )
  40. )
  41. # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
  42. app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
  43. app.state.model_locks = {model_type: threading.Lock() for model_type in ModelType}
  44. app.state.index_lock = threading.Lock()
  45. log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
  46. @app.on_event("startup")
  47. async def startup_event() -> None:
  48. init_state()
  49. @app.get("/", response_model=MessageResponse)
  50. async def root() -> dict[str, str]:
  51. return {"message": "Immich ML"}
  52. @app.get("/ping", response_model=TextResponse)
  53. def ping() -> str:
  54. return "pong"
  55. @app.post("/predict")
  56. async def predict(
  57. model_name: str = Form(alias="modelName"),
  58. model_type: ModelType = Form(alias="modelType"),
  59. options: str = Form(default="{}"),
  60. text: str | None = Form(default=None),
  61. image: UploadFile | None = None,
  62. ) -> Any:
  63. if image is not None:
  64. inputs: str | bytes = await image.read()
  65. elif text is not None:
  66. inputs = text
  67. else:
  68. raise HTTPException(400, "Either image or text must be provided")
  69. try:
  70. kwargs = orjson.loads(options)
  71. except orjson.JSONDecodeError:
  72. raise HTTPException(400, f"Invalid options JSON: {options}")
  73. model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
  74. model.configure(**kwargs)
  75. outputs = await run(model, inputs)
  76. return ORJSONResponse(outputs)
  77. @app.post("/index/{index_name}/search")
  78. async def search(
  79. index_name: str, embeddings: np.ndarray[int, np.dtype[np.float32]] = Depends(validate_embeddings), k: int = 10
  80. ) -> None:
  81. if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
  82. raise HTTPException(404, f"Index '{index_name}' not found")
  83. return vector_stores[index_name].search(embeddings, k)[1] # type: ignore
  84. @app.patch("/index/{index_name}/add")
  85. async def add(
  86. index_name: str,
  87. embedding_ids: list[str],
  88. embeddings: np.ndarray[int, np.dtype[np.float32]] = Depends(validate_embeddings),
  89. ) -> None:
  90. if index_name not in vector_stores or vector_stores[index_name].d != embeddings.shape[1]:
  91. await create(index_name, embedding_ids, embeddings)
  92. else:
  93. vector_stores[index_name].add_with_ids(embeddings, embedding_ids) # type: ignore
  94. @app.post("/index/{index_name}/create")
  95. async def create(
  96. index_name: str,
  97. embedding_ids: list[str],
  98. embeddings: np.ndarray[int, np.dtype[np.float32]] = Depends(validate_embeddings),
  99. ) -> None:
  100. if embeddings.shape[0] != len(embedding_ids):
  101. raise HTTPException(400, "Number of embedding IDs must match number of embeddings")
  102. if index_name in vector_stores:
  103. log.warn(f"Index '{index_name}' already exists. Overwriting.")
  104. hnsw_index = faiss.IndexHNSWFlat(embeddings.shape[1])
  105. mapped_index = faiss.IndexIDMap2(hnsw_index)
  106. def _create() -> faiss.IndexIDMap2:
  107. with app.state.index_lock:
  108. mapped_index.add_with_ids(embeddings, embedding_ids) # type: ignore
  109. return mapped_index
  110. vector_stores[index_name] = await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, _create)
  111. async def run(model: InferenceModel, inputs: Any) -> Any:
  112. if app.state.thread_pool is None:
  113. return model.predict(inputs)
  114. return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
  115. async def load(model: InferenceModel) -> InferenceModel:
  116. if model.loaded:
  117. return model
  118. def _load() -> None:
  119. with app.state.locks[model.model_type]:
  120. model.load()
  121. loop = asyncio.get_running_loop()
  122. try:
  123. if app.state.thread_pool is None:
  124. model.load()
  125. else:
  126. await loop.run_in_executor(app.state.thread_pool, _load)
  127. return model
  128. except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
  129. log.warn(
  130. (
  131. f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'."
  132. "Clearing cache and retrying."
  133. )
  134. )
  135. model.clear_cache()
  136. if app.state.thread_pool is None:
  137. model.load()
  138. else:
  139. await loop.run_in_executor(app.state.thread_pool, _load)
  140. return model