# optimize.py — ONNX model optimization helpers
  1. from pathlib import Path
  2. import onnx
  3. import onnxruntime as ort
  4. import onnxsim
  5. def optimize_onnxsim(model_path: Path | str, output_path: Path | str) -> None:
  6. model_path = Path(model_path)
  7. output_path = Path(output_path)
  8. model = onnx.load(model_path.as_posix())
  9. model, check = onnxsim.simplify(model, skip_shape_inference=True)
  10. assert check, "Simplified ONNX model could not be validated"
  11. onnx.save(model, output_path.as_posix())
  12. def optimize_ort(
  13. model_path: Path | str,
  14. output_path: Path | str,
  15. level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
  16. ) -> None:
  17. model_path = Path(model_path)
  18. output_path = Path(output_path)
  19. sess_options = ort.SessionOptions()
  20. sess_options.graph_optimization_level = level
  21. sess_options.optimized_model_filepath = output_path.as_posix()
  22. ort.InferenceSession(model_path.as_posix(), providers=["CPUExecutionProvider"], sess_options=sess_options)
  23. def optimize(model_path: Path | str) -> None:
  24. model_path = Path(model_path)
  25. optimize_ort(model_path, model_path)
  26. # onnxsim serializes large models as a blob, which uses much more memory when loading the model at runtime
  27. if not any(file.name.startswith("Constant") for file in model_path.parent.iterdir()):
  28. optimize_onnxsim(model_path, model_path)