File size: 2,238 Bytes
5b83793
 
 
 
 
95110bc
405a22d
 
5b83793
0564946
5b83793
 
 
 
95110bc
5b83793
 
 
 
95110bc
5b83793
 
d89aac0
95110bc
96e29c0
5b83793
 
 
 
 
 
 
d89aac0
5b83793
 
 
 
 
 
 
6e95c2f
 
95110bc
0564946
12f4dcf
405a22d
12f4dcf
 
405a22d
12f4dcf
 
405a22d
12f4dcf
405a22d
 
12f4dcf
405a22d
 
0564946
405a22d
dc49371
405a22d
12f4dcf
405a22d
dc49371
d0f6f9b
12f4dcf
95110bc
405a22d
 
0564946
82f1d3f
6e95c2f
0564946
6e95c2f
45b47cc
82f1d3f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import numpy as np
import torch
import sys
sys.path.append('models')
from SRFlow.code import imread, impad, load_model, t, rgb
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.transforms import PILToTensor, ToPILImage

# Memoized (model, opt) pairs keyed by configuration path.
_SRFLOW_MODEL_CACHE = {}


def _load_srflow_model(conf_path):
    """Return ``load_model(conf_path)``, loading each configuration only once.

    ``return_SRFlow_result`` is called once per image by the batch helper, so
    without memoization the model would be rebuilt for every single image.
    """
    if conf_path not in _SRFLOW_MODEL_CACHE:
        _SRFLOW_MODEL_CACHE[conf_path] = load_model(conf_path)
    return _SRFLOW_MODEL_CACHE[conf_path]


def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=None):
    """
    Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.

    Args:
    - lr: PIL Image. Images in a mode other than RGB (e.g. RGBA, L, P) are
      converted to RGB first so the array fed to the model has 3 channels
      (NOTE(review): assumes the configured SRFlow model is an RGB model).
    - conf_path (str): Configuration file path for the SRFlow model.
      Default is SRFlow_DF2K_4X.yml. The loaded model is cached per path.
    - heat (float | None): Heat (sampling temperature) passed to
      ``model.get_sr``. If None (default), ``opt['heat']`` from the
      configuration is used.

    Returns:
    - sr: PIL Image of size (scale * width, scale * height) of the input,
      where ``scale`` is ``opt['scale']``.
    """
    model, opt = _load_srflow_model(conf_path)

    if lr.mode != 'RGB':
        lr = lr.convert('RGB')
    # PILToTensor yields a uint8 CHW tensor; impad/t are given an HWC array.
    lr = PILToTensor()(lr).permute(1, 2, 0).numpy()

    scale = opt['scale']
    pad_factor = 2

    # Pad bottom/right so height and width are multiples of pad_factor;
    # the padded region is cropped away from the SR output below.
    h, w, _ = lr.shape
    lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
               right=int(np.ceil(w / pad_factor) * pad_factor - w))

    lr_t = t(lr)
    if heat is None:
        heat = opt['heat']

    sr_t = model.get_sr(lq=lr_t, heat=heat)

    sr = rgb(torch.clamp(sr_t, 0, 1))
    sr = sr[:h * scale, :w * scale]

    sr_img = Image.fromarray(sr.astype('uint8'))
    return sr_img

def return_SRFlow_result_from_tensor(lr_tensor, **kwargs):
    """
    Apply Super-Resolution using SRFlow model to every image of a batched BCHW tensor.

    Each image is converted to a PIL image with ``ToPILImage``, upscaled
    independently by ``return_SRFlow_result`` and converted back with
    ``PILToTensor``. Float inputs are therefore interpreted as values in
    [0, 1] (``ToPILImage`` scales them by 255), while uint8 inputs are used
    as-is. The result is built from PIL images, so it is a uint8 tensor on
    the CPU regardless of the input's dtype or device.

    Args:
    - lr_tensor: Batched BCHW tensor with at least one image.
    - **kwargs: Forwarded to ``return_SRFlow_result`` (e.g. ``conf_path``).

    Returns:
    - sr_tensor: uint8 BCHW tensor holding the super-resolved images.

    Raises:
    - ValueError: if ``lr_tensor`` is not 4-D or its batch dimension is empty.
    """
    if lr_tensor.dim() != 4:
        raise ValueError(
            f"expected a batched BCHW tensor, got shape {tuple(lr_tensor.shape)}"
        )
    if lr_tensor.shape[0] == 0:
        raise ValueError("cannot super-resolve an empty batch")

    to_pil = ToPILImage()
    to_tensor = PILToTensor()
    # Iterating a tensor yields its dim-0 slices, i.e. one CHW image at a time.
    sr_images = [
        to_tensor(return_SRFlow_result(to_pil(lr_image), **kwargs))
        for lr_image in lr_tensor
    ]
    return torch.stack(sr_images, dim=0)

if __name__ == '__main__':
    # Demo: super-resolve one image through the batched tensor API and show it.
    # convert('RGB') drops any alpha/palette channel so the tensor has 3 channels;
    # the context manager closes the underlying file handle.
    with Image.open('images/demo.png') as img:
        lr = img.convert('RGB')

    lr_tensor = PILToTensor()(lr).unsqueeze(0)  # 1 x C x H x W, uint8
    print(lr_tensor.shape)
    sr = return_SRFlow_result_from_tensor(lr_tensor)
    print(sr.shape)

    # Show the SR image of the first (only) item in the batch; imshow expects HWC.
    plt.imshow(sr[0].permute(1, 2, 0).cpu().numpy())
    plt.show()