main.py 1023 B

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. from pydantic import BaseModel
  2. from fastapi import FastAPI
  3. from .object_detection import object_detection
  4. from .image_classifier import image_classifier
  5. from tf2_yolov4.anchors import YOLOV4_ANCHORS
  6. from tf2_yolov4.model import YOLOv4
  7. HEIGHT, WIDTH = (640, 960)
  8. # Warm up model
  9. image_classifier.warm_up()
  10. app = FastAPI()
  11. class TagImagePayload(BaseModel):
  12. thumbnail_path: str
  13. @app.post("/tagImage")
  14. async def post_root(payload: TagImagePayload):
  15. image_path = payload.thumbnail_path
  16. if image_path[0] == '.':
  17. image_path = image_path[2:]
  18. return image_classifier.classify_image(image_path=image_path)
  19. @app.get("/")
  20. async def test():
  21. object_detection.run_detection()
  22. # image = tf.io.read_file("./app/cars.jpg")
  23. # image = tf.image.decode_image(image)
  24. # image = tf.image.resize(image, (HEIGHT, WIDTH))
  25. # images = tf.expand_dims(image, axis=0) / 255.0
  26. # model = YOLOv4(
  27. # (HEIGHT, WIDTH, 3),
  28. # 80,
  29. # YOLOV4_ANCHORS,
  30. # "darknet",
  31. # )