Spaces:
Running
Running
File size: 2,238 Bytes
5b83793 95110bc 405a22d 5b83793 0564946 5b83793 95110bc 5b83793 95110bc 5b83793 d89aac0 95110bc 96e29c0 5b83793 d89aac0 5b83793 6e95c2f 95110bc 0564946 12f4dcf 405a22d 12f4dcf 405a22d 12f4dcf 405a22d 12f4dcf 405a22d 12f4dcf 405a22d 0564946 405a22d dc49371 405a22d 12f4dcf 405a22d dc49371 d0f6f9b 12f4dcf 95110bc 405a22d 0564946 82f1d3f 6e95c2f 0564946 6e95c2f 45b47cc 82f1d3f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import functools
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import PILToTensor, ToPILImage

sys.path.append('models')
from SRFlow.code import imread, impad, load_model, t, rgb
@functools.lru_cache(maxsize=4)
def _load_srflow_model(conf_path):
    """Load the SRFlow ``(model, opt)`` pair for ``conf_path`` and memoise it.

    ``load_model`` runs at most once per distinct ``conf_path`` (up to four
    configurations are kept). Repeated per-image calls, such as one call per
    element of a batch, therefore do not reload the network each time.
    """
    return load_model(conf_path)


def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=None):
    """
    Apply Super-Resolution using the SRFlow model to the input LR (low-resolution) image.

    Args:
    - lr: PIL Image. It is converted to RGB first, so RGBA, grayscale and
      palette images are accepted. Any alpha channel is discarded.
    - conf_path (str): Configuration file path for the SRFlow model. Default is
      SRFlow_DF2K_4X.yml. The loaded model is cached per path.
    - heat (float or None): Heat value passed to ``model.get_sr``. When None
      (the default), ``opt['heat']`` from the loaded configuration is used.

    Returns:
    - sr: PIL Image cropped to (H * scale, W * scale), where H, W are the input
      height/width and ``scale`` is ``opt['scale']``.
    """
    model, opt = _load_srflow_model(conf_path)
    scale = opt['scale']
    if heat is None:
        heat = opt['heat']

    # PIL -> CHW uint8 tensor -> HWC numpy array. Convert to RGB so the array
    # always has exactly 3 channels.
    lr = PILToTensor()(lr.convert('RGB')).permute(1, 2, 0).numpy()
    h, w, _ = lr.shape

    # Pad the bottom/right edges so H and W become multiples of pad_factor.
    # The padded region is cropped away again after upscaling.
    pad_factor = 2
    lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
               right=int(np.ceil(w / pad_factor) * pad_factor - w))

    lr_t = t(lr)
    sr_t = model.get_sr(lq=lr_t, heat=heat)
    sr = rgb(torch.clamp(sr_t, 0, 1))

    # Drop the part of the output that corresponds to the padding.
    sr = sr[:h * scale, :w * scale]
    return Image.fromarray(sr.astype('uint8'))
def return_SRFlow_result_from_tensor(lr_tensor, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml'):
    """
    Apply Super-Resolution using the SRFlow model to every image of a batched BCHW tensor.

    Each image ``lr_tensor[b]`` goes through ``ToPILImage``, is upscaled
    independently by ``return_SRFlow_result``, and is converted back with
    ``PILToTensor``.

    Args:
    - lr_tensor: Batched BCHW tensor. Each element must be accepted by
      ``torchvision.transforms.ToPILImage``.
    - conf_path (str): SRFlow configuration file forwarded to
      ``return_SRFlow_result``. Default is SRFlow_DF2K_4X.yml.

    Returns:
    - sr_tensor: Batched BCHW uint8 tensor (the output dtype of ``PILToTensor``),
      placed on the same device as ``lr_tensor``.

    Raises:
    - RuntimeError: from ``torch.stack`` when the batch is empty.
    """
    to_pil = ToPILImage()
    to_tensor = PILToTensor()

    # Iterating a tensor walks its first (batch) dimension.
    sr_images = [
        to_tensor(return_SRFlow_result(to_pil(lr_image), conf_path=conf_path))
        for lr_image in lr_tensor
    ]

    # PILToTensor yields CPU tensors, so move the stacked batch back to the
    # caller's device.
    return torch.stack(sr_images, dim=0).to(lr_tensor.device)
if __name__ == '__main__':
    # Demo: upscale one image from disk and display the result.
    # convert('RGB') drops any alpha channel so the batch has 3 channels.
    # The context manager closes the underlying file.
    with Image.open('images/demo.png') as img:
        lr = img.convert('RGB')
    lr_tensor = PILToTensor()(lr).unsqueeze(0)  # add a batch dim: 1 x C x H x W
    print(lr_tensor.shape)

    sr = return_SRFlow_result_from_tensor(lr_tensor)
    print(sr.shape)

    # Show the SR image of the first element in the batch (CHW -> HWC for matplotlib).
    plt.imshow(sr[0].permute(1, 2, 0).cpu().numpy())
    plt.show()