@@ -27,13 +27,10 @@ app = FastAPI()
 @app.on_event("startup")
 async def startup_event() -> None:
     app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
-    same_clip = settings.clip_image_model == settings.clip_text_model
-    app.state.clip_vision_type = ModelType.CLIP if same_clip else ModelType.CLIP_VISION
-    app.state.clip_text_type = ModelType.CLIP if same_clip else ModelType.CLIP_TEXT
     models = [
         (settings.classification_model, ModelType.IMAGE_CLASSIFICATION),
-        (settings.clip_image_model, app.state.clip_vision_type),
-        (settings.clip_text_model, app.state.clip_text_type),
+        (settings.clip_image_model, ModelType.CLIP),
+        (settings.clip_text_model, ModelType.CLIP),
         (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION),
     ]
 
@@ -87,9 +84,7 @@ async def image_classification(
 async def clip_encode_image(
     image: Image.Image = Depends(dep_pil_image),
 ) -> list[float]:
-    model = await app.state.model_cache.get(
-        settings.clip_image_model, app.state.clip_vision_type
-    )
+    model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP)
     embedding = model.predict(image)
     return embedding
 
@@ -100,9 +95,7 @@ async def clip_encode_image(
     status_code=200,
 )
 async def clip_encode_text(payload: TextModelRequest) -> list[float]:
-    model = await app.state.model_cache.get(
-        settings.clip_text_model, app.state.clip_text_type
-    )
+    model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP)
     embedding = model.predict(payload.text)
     return embedding
 