Spaces:
Running
Running
import numpy as np | |
import torch | |
import sys | |
sys.path.append('models') | |
from SRFlow.code import imread, impad, load_model, t, rgb | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
from torchvision.transforms import PILToTensor, ToPILImage | |
# Cache of (model, opt) pairs keyed by configuration path. Repeated calls, for
# example one per batch element in return_SRFlow_result_from_tensor, reuse the
# already-loaded network instead of calling load_model again each time.
# Entries live for the lifetime of the process.
_SRFLOW_MODEL_CACHE = {}


def _get_srflow_model(conf_path):
    """Return the (model, opt) pair for *conf_path*, loading it on first use."""
    cached = _SRFLOW_MODEL_CACHE.get(conf_path)
    if cached is None:
        cached = load_model(conf_path)
        _SRFLOW_MODEL_CACHE[conf_path] = cached
    return cached


def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml',
                         heat=None):
    """
    Apply Super-Resolution using the SRFlow model to the input LR (low-resolution) image.

    Args:
    - lr: PIL Image. It is converted to 3-channel RGB before inference, so
      RGBA and grayscale images are accepted.
    - conf_path (str): Configuration file path for the SRFlow model.
      Default is SRFlow_DF2K_4X.yml. The loaded model is cached per path.
    - heat (float or None): Heat value passed to ``model.get_sr``. When None
      (the default), ``opt['heat']`` from the loaded configuration is used.

    Returns:
    - sr: PIL Image, cropped to at most (h * scale) x (w * scale) pixels, where
      h, w are the input height/width and scale is ``opt['scale']``.
    """
    model, opt = _get_srflow_model(conf_path)
    scale = opt['scale']
    if heat is None:
        heat = opt['heat']

    # HWC numpy array. Forcing RGB guarantees exactly 3 channels.
    lr_np = PILToTensor()(lr.convert('RGB')).permute(1, 2, 0).numpy()
    h, w = lr_np.shape[:2]

    # Pad bottom/right up to the next multiple of pad_factor. Presumably the
    # network requires even spatial sizes (TODO confirm). The padding is
    # cropped away again after inference.
    pad_factor = 2
    lr_np = impad(lr_np,
                  bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
                  right=int(np.ceil(w / pad_factor) * pad_factor - w))

    sr_t = model.get_sr(lq=t(lr_np), heat=heat)
    sr = rgb(torch.clamp(sr_t, 0, 1))
    # Drop the rows/columns that correspond to the padding added above.
    sr = sr[:h * scale, :w * scale]
    return Image.fromarray(sr.astype('uint8'))
def return_SRFlow_result_from_tensor(lr_tensor):
    """
    Run SRFlow super-resolution independently on every image of a BCHW batch.

    Each batch element is turned into a PIL image with ToPILImage, upscaled by
    return_SRFlow_result using its default configuration, and turned back into
    a tensor with PILToTensor. The per-image results are then stacked along a
    new leading batch dimension. All images must therefore produce the same
    output size.

    Args:
    - lr_tensor: Batched BCHW tensor. Float inputs are interpreted by
      ToPILImage as values in [0, 1].
    Returns:
    - sr_tensor: Batched BCHW tensor of the upscaled images, as produced by
      PILToTensor (integer pixel values, not normalized).
    """
    to_pil = ToPILImage()
    to_tensor = PILToTensor()
    upscaled = [
        to_tensor(return_SRFlow_result(to_pil(lr_tensor[idx])))
        for idx in range(lr_tensor.shape[0])
    ]
    return torch.stack(upscaled, dim=0)
if __name__ == '__main__':
    # Demo: upscale a single image through the batched-tensor entry point.
    # Converting inside the `with` loads the pixel data and closes the file
    # handle. It also guarantees a 3-channel (RGB) image.
    with Image.open('images/demo.png') as img:
        lr = img.convert('RGB')
    lr_tensor = PILToTensor()(lr).unsqueeze(0)  # add a batch dim: 1 x C x H x W
    print(lr_tensor.shape)
    sr = return_SRFlow_result_from_tensor(lr_tensor)
    print(sr.shape)
    # Show SR image of the first one in the batch (CHW -> HWC for matplotlib).
    plt.imshow(np.transpose(sr[0].cpu().detach().numpy(), (1, 2, 0)))
    plt.show()