hail75's picture
7
0564946
import functools
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import PILToTensor, ToPILImage

# The SRFlow package lives under models/ (see the default conf_path), so that
# directory must be on sys.path before SRFlow can be imported.
sys.path.append('models')
from SRFlow.code import imread, impad, load_model, t, rgb  # noqa: E402
@functools.lru_cache(maxsize=1)
def _load_srflow_model(conf_path):
    """Return ``load_model(conf_path)``, reusing the result for repeated calls.

    Reloading the model on every call is wasteful when inference runs many times
    (e.g. once per batch element). maxsize=1 keeps at most one loaded model in
    memory. The returned ``opt`` is shared between calls and must not be mutated.
    """
    return load_model(conf_path)


def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=None):
    """
    Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.
    Args:
    - lr: PIL Image. Images whose mode is not 'RGB' (e.g. RGBA, grayscale,
      palette) are converted to RGB first.
    - conf_path (str): Configuration file path for the SRFlow model. Default is SRFlow_DF2K_4X.yml.
    - heat (float, optional): Heat value passed to ``model.get_sr``. When None
      (the default), ``opt['heat']`` from the loaded configuration is used.
    Returns:
    - sr: PIL Image with the super-resolved result, cropped to ``opt['scale']``
      times the input height and width.
    """
    model, opt = _load_srflow_model(conf_path)
    scale = opt['scale']
    if heat is None:
        heat = opt['heat']

    # The pipeline expects 3-channel input. PILToTensor keeps the image's own
    # channel count (4 for RGBA, 1 for 'L'), so normalise the mode first.
    if lr.mode != 'RGB':
        lr = lr.convert('RGB')
    lr = PILToTensor()(lr).permute(1, 2, 0).numpy()  # CHW tensor -> HWC array

    # Pad bottom/right so height and width are multiples of pad_factor
    # (presumably required by the flow network; confirm). The padding is
    # cropped off the upscaled output below.
    pad_factor = 2
    h, w, _ = lr.shape
    lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
               right=int(np.ceil(w / pad_factor) * pad_factor - w))

    lr_t = t(lr)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        sr_t = model.get_sr(lq=lr_t, heat=heat)
    sr = rgb(torch.clamp(sr_t, 0, 1))
    sr = sr[:h * scale, :w * scale]  # drop the upscaled padding
    return Image.fromarray(sr.astype('uint8'))
def return_SRFlow_result_from_tensor(lr_tensor, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml'):
    """
    Apply Super-Resolution using SRFlow model to every image in a batched BCHW tensor.
    Args:
    - lr_tensor: Batched BCHW tensor. Each image goes through ToPILImage, so a
      uint8 tensor is read as 0-255 values and a float tensor as 0-1 values.
    - conf_path (str): Configuration file path forwarded to return_SRFlow_result.
    Returns:
    - sr_tensor: Batched BCHW tensor built from PILToTensor outputs (CPU, uint8
      for the RGB images produced by return_SRFlow_result).
    Raises:
    - ValueError: if lr_tensor is not 4-dimensional or the batch is empty.
    """
    # Validate up front. A 3-D CHW tensor would otherwise be iterated
    # channel-by-channel, and an empty batch would fail inside torch's
    # concatenation with an unhelpful message.
    if lr_tensor.dim() != 4:
        raise ValueError(f'expected a 4-D BCHW tensor, got shape {tuple(lr_tensor.shape)}')
    if lr_tensor.shape[0] == 0:
        raise ValueError('expected a non-empty batch')

    to_pil = ToPILImage()
    to_tensor = PILToTensor()
    # Iterating a tensor yields its CHW slices along the batch dimension.
    sr_images = [
        to_tensor(return_SRFlow_result(to_pil(image), conf_path=conf_path))
        for image in lr_tensor
    ]
    return torch.stack(sr_images, dim=0)
if __name__ == '__main__':
    # Demo: super-resolve images/demo.png and display the result.
    # Convert to RGB so an RGBA/palette PNG still yields a 3-channel batch.
    # The context manager closes the file once convert() has read the pixels.
    with Image.open('images/demo.png') as img:
        lr = img.convert('RGB')
    lr_tensor = PILToTensor()(lr).unsqueeze(0)  # CHW -> 1xCxHxW
    print(lr_tensor.shape)
    sr = return_SRFlow_result_from_tensor(lr_tensor)
    print(sr.shape)
    # Show SR image of the first one in the batch (matplotlib expects HWC).
    plt.imshow(sr[0].permute(1, 2, 0).cpu().numpy())
    plt.show()