transforms.py 1.1 KB

1234567891011121314151617181920212223242526272829303132333435
  1. import numpy as np
  2. from PIL import Image
  3. from app.schemas import ndarray_f32
  4. _PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling}
  5. def resize(img: Image.Image, size: int) -> Image.Image:
  6. if img.width < img.height:
  7. return img.resize((size, int((img.height / img.width) * size)), resample=Image.BICUBIC)
  8. else:
  9. return img.resize((int((img.width / img.height) * size), size), resample=Image.BICUBIC)
  10. # https://stackoverflow.com/a/60883103
  11. def crop(img: Image.Image, size: int) -> Image.Image:
  12. left = int((img.size[0] / 2) - (size / 2))
  13. upper = int((img.size[1] / 2) - (size / 2))
  14. right = left + size
  15. lower = upper + size
  16. return img.crop((left, upper, right, lower))
  17. def to_numpy(img: Image.Image) -> ndarray_f32:
  18. return np.asarray(img.convert("RGB")).astype(np.float32) / 255.0
  19. def normalize(img: ndarray_f32, mean: float | ndarray_f32, std: float | ndarray_f32) -> ndarray_f32:
  20. return (img - mean) / std
  21. def get_pil_resampling(resample: str) -> Image.Resampling:
  22. return _PIL_RESAMPLING_METHODS[resample.lower()]