util.py 581 B

123456789101112131415161718192021
  1. import json
  2. from enum import Enum
  3. from pathlib import Path
  4. from typing import Any
  class ModelType(Enum):
      """Enumerate the serialized-model formats this module knows about.

      Each member's value is the lowercase format name, which is also used
      as the file extension when building a model file path.
      """

      # Open Neural Network Exchange format.
      ONNX = "onnx"
      # TensorFlow Lite flatbuffer format.
      TFLITE = "tflite"
  def get_model_path(output_dir: Path | str, model_type: ModelType = ModelType.ONNX) -> Path:
      """Build the path of the model file for ``model_type`` under ``output_dir``.

      Side effect: ``output_dir`` and any missing parent directories are
      created (an already-existing directory is fine). The model file itself
      is not created.

      Args:
          output_dir: Directory that will hold the model file.
          model_type: Format of the model; its value becomes the extension.

      Returns:
          ``<output_dir>/model.<model_type.value>`` as a ``Path``.
      """
      directory = Path(output_dir)
      directory.mkdir(parents=True, exist_ok=True)
      filename = f"model.{model_type.value}"
      return directory.joinpath(filename)
  12. def save_config(config: Any, output_path: Path | str) -> None:
  13. output_path = Path(output_path)
  14. output_path.parent.mkdir(parents=True, exist_ok=True)
  15. json.dump(config, output_path.open("w"))