from typing import List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
import torchvision

# Output formats accepted by ``postprocess_image``.
_OUTPUT_TYPES = ("latent", "pt", "np", "pil")


def denormalize(
    images: Union[torch.Tensor, np.ndarray],
) -> Union[torch.Tensor, np.ndarray]:
    """Map pixel values from [-1, 1] to [0, 1].

    Computes ``images / 2 + 0.5`` and clamps the result to [0, 1].

    Args:
        images: A tensor or NumPy array of pixel values.

    Returns:
        The same kind of object as ``images`` (tensor in, tensor out;
        array in, array out) with values clamped to [0, 1].
    """
    shifted = images / 2 + 0.5
    if isinstance(shifted, np.ndarray):
        # NumPy arrays have no ``.clamp``; ``np.clip`` is the equivalent.
        return np.clip(shifted, 0, 1)
    return shifted.clamp(0, 1)


def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
    """Convert a batched channel-first tensor to a channel-last NumPy array.

    Args:
        images: A 4-D tensor whose axes are permuted as (0, 2, 3, 1), i.e.
            presumably (batch, channels, height, width).

    Returns:
        A float32 array laid out as (batch, height, width, channels).
    """
    # ``.numpy()`` refuses tensors that require grad; detaching first makes
    # this safe to call on outputs that are still attached to a graph.
    return images.detach().cpu().permute(0, 2, 3, 1).float().numpy()


def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
    """Convert a NumPy image, or a batch of images, to a list of PIL images.

    Args:
        images: Float pixel values expected in [0, 1], shaped (H, W, C) for
            a single image or (N, H, W, C) for a batch. A trailing channel
            dimension of 1 is treated as grayscale.

    Returns:
        A list of PIL images, one per batch entry (a single image still
        yields a one-element list).
    """
    if images.ndim == 3:
        images = images[None, ...]
    # Saturate before the cast: converting out-of-range floats to uint8 is
    # not clamping in NumPy (the result is platform-dependent, typically
    # wrapped), so e.g. un-denormalized negatives would turn bright.
    images = np.clip((images * 255).round(), 0, 255).astype("uint8")
    if images.shape[-1] == 1:
        # Grayscale: drop only the channel axis (``squeeze`` would also
        # collapse a height or width of 1). Pillow infers mode "L" for a
        # 2-D uint8 array.
        return [PIL.Image.fromarray(image[..., 0]) for image in images]
    return [PIL.Image.fromarray(image) for image in images]


def postprocess_image(
    image: torch.Tensor,
    output_type: str = "pil",
    do_denormalize: Optional[List[bool]] = None,
) -> Union[torch.Tensor, np.ndarray, List[PIL.Image.Image]]:
    """Convert a batch of image tensors to the requested output format.

    Args:
        image: A batched image tensor, indexed along dim 0. For "np"/"pil"
            it is converted with ``pt_to_numpy``, so presumably
            (N, C, H, W).
        output_type: "latent" returns ``image`` untouched; "pt" returns the
            (optionally denormalized) tensor; "np" returns a channel-last
            float32 array; "pil" returns a list of PIL images.
        do_denormalize: Per-image flags selecting which batch entries are
            mapped from [-1, 1] to [0, 1]. Defaults to denormalizing every
            image.

    Returns:
        The images in the format selected by ``output_type``.

    Raises:
        ValueError: If ``image`` is not a tensor, or ``output_type`` is not
            one of "latent", "pt", "np", "pil".
    """
    if not isinstance(image, torch.Tensor):
        raise ValueError(
            f"Input for postprocessing is in incorrect format: {type(image)}. "
            "We only support pytorch tensor"
        )
    # Reject unknown formats up front instead of silently returning None.
    if output_type not in _OUTPUT_TYPES:
        raise ValueError(
            f"output_type must be one of {_OUTPUT_TYPES}, got {output_type!r}"
        )
    if output_type == "latent":
        return image

    if do_denormalize is None:
        do_denormalize = [True] * image.shape[0]

    image = torch.stack(
        [
            denormalize(single) if do_denormalize[i] else single
            for i, single in enumerate(image)
        ]
    )
    if output_type == "pt":
        return image

    image = pt_to_numpy(image)
    if output_type == "np":
        return image
    return numpy_to_pil(image)


def process_image(
    image_pil: PIL.Image.Image, range: Tuple[float, float] = (-1, 1)
) -> Tuple[torch.Tensor, PIL.Image.Image]:
    """Convert a PIL image to a batched tensor linearly rescaled to ``range``.

    ``ToTensor`` produces a (C, H, W) float tensor (values in [0, 1] for
    8-bit images), which is then mapped onto [range[0], range[1]].

    Args:
        image_pil: Source image.
        range: Target (min, max) values. The name shadows the builtin but is
            kept so existing keyword callers keep working.

    Returns:
        A tuple of the (1, C, H, W) tensor and the unchanged input image.
    """
    image = torchvision.transforms.ToTensor()(image_pil)
    r_min, r_max = range[0], range[1]
    image = image * (r_max - r_min) + r_min
    return image[None, ...], image_pil


def pil2tensor(
    image_pil: PIL.Image.Image,
    *,
    size: Optional[Tuple[int, int]] = None,
) -> torch.Tensor:
    """Convert a PIL image to a (1, C, H, W) float16 tensor.

    Values are rescaled by ``process_image`` to its default (-1, 1) range.

    Args:
        image_pil: Source image.
        size: Optional (height, width) to bilinearly resize to. By default
            the image keeps its own size, so no interpolation is performed.

    Returns:
        A float16 tensor with a leading batch dimension of 1.
    """
    image, _ = process_image(image_pil)
    # Interpolate (in float32, before the half-precision cast) only when a
    # different target size is requested; resizing to the image's own size
    # is an identity and just wastes a kernel launch.
    if size is not None and tuple(size) != (image_pil.height, image_pil.width):
        image = torch.nn.functional.interpolate(
            image, size=tuple(size), mode="bilinear"
        )
    return image.to(torch.float16)