# RCAN (Residual Channel Attention Network) — single-image super-resolution model.
import os
import torch
from torch import nn
from PIL import Image
from torchvision.transforms import ToTensor

# Module-wide defaults for the model classes below.
NUM_RESIDUAL_GROUPS = 8   # residual groups stacked inside ResidualInResidual
NUM_RESIDUAL_BLOCKS = 16  # channel-attention blocks per ResidualGroup
KERNEL_SIZE = 3           # spatial kernel size of the 3x3 convolutions
REDUCTION_RATIO = 16      # channel squeeze factor in the attention bottleneck
NUM_CHANNELS = 64         # feature-map width used throughout the trunk
UPSCALE_FACTOR = 4        # PixelShuffle factor: trades channels (/16) for resolution (x4)

class ResidualChannelAttentionBlock(nn.Module):
    """Residual block with channel attention (RCAB).

    Two same-size convolutions extract a residual, a squeeze-and-excitation
    style branch produces one (0, 1) scale per channel, and the rescaled
    residual is added back onto the block input. Input and output have the
    same shape.
    """

    def __init__(self, num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super(ResidualChannelAttentionBlock, self).__init__()

        # Residual branch: conv -> ReLU -> conv; padding keeps spatial size.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size//2),
            nn.ReLU(),
            nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size//2)
        )

        # Channel attention: global average pool to (B, C, 1, 1), a 1x1-conv
        # bottleneck (C -> C/r -> C), then a sigmoid gate per channel.
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_channels, num_channels//reduction_ratio, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv2d(num_channels//reduction_ratio, num_channels, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x + attention(residual) * residual``; shape unchanged."""
        residual = self.feature_extractor(x)        # feature extraction
        rescale = self.channel_attention(residual)  # (B, C, 1, 1) gate, broadcast over H, W

        # No clone needed: x is never mutated in place and `+` allocates a
        # fresh tensor, so copying x here only wasted memory per block.
        return x + residual * rescale

class ResidualGroup(nn.Module):
    """Residual group: a stack of channel-attention blocks plus a trailing
    convolution, wrapped in a long skip connection around the whole group.
    Input and output have the same shape.
    """

    def __init__(self, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
                 num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super(ResidualGroup, self).__init__()

        # num_residual_blocks RCABs applied in sequence.
        self.residual_blocks = nn.Sequential(
            *[ResidualChannelAttentionBlock(num_channels=num_channels, reduction_ratio=reduction_ratio, kernel_size=kernel_size)
              for _ in range(num_residual_blocks)]
        )

        # One conv on the group output before the skip-add; padding keeps size.
        self.final_conv = nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size//2)

    def forward(self, x):
        """Return ``x + final_conv(residual_blocks(x))``; shape unchanged."""
        residual = self.residual_blocks(x)   # stacked attention blocks
        residual = self.final_conv(residual) # final convolution

        # No clone needed: x is not mutated in place, so the skip connection
        # can add onto it directly without an extra full-tensor copy.
        return x + residual

class ResidualInResidual(nn.Module):
    """Residual-in-Residual trunk: a stack of residual groups plus a trailing
    convolution, wrapped in one long skip connection from the shallow
    features to the deep features. Input and output have the same shape.
    """

    def __init__(self, num_residual_groups=NUM_RESIDUAL_GROUPS, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
                 num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super(ResidualInResidual, self).__init__()

        # num_residual_groups ResidualGroups applied in sequence.
        self.residual_groups = nn.Sequential(
            *[ResidualGroup(num_residual_blocks=num_residual_blocks,
                            num_channels=num_channels, reduction_ratio=reduction_ratio, kernel_size=kernel_size)
              for _ in range(num_residual_groups)]
        )

        # One conv on the trunk output before the long skip-add.
        self.final_conv = nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size//2)

    def forward(self, x):
        """Return ``x + final_conv(residual_groups(x))``; shape unchanged."""
        residual = self.residual_groups(x)   # residual groups
        residual = self.final_conv(residual) # final convolution

        # No clone needed: x is not mutated in place, so the long skip
        # connection can add onto it directly without copying.
        return x + residual

class RCAN(nn.Module):
    """Residual Channel Attention Network for 4x single-image super-resolution.

    Pipeline: shallow conv (3 -> C channels) -> Residual-in-Residual trunk ->
    PixelShuffle upscaling -> reconstruction conv back to 3 channels.
    """

    def __init__(self, num_residual_groups=NUM_RESIDUAL_GROUPS, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
                 num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super(RCAN, self).__init__()

        # RGB input -> shallow feature map of num_channels.
        self.shallow_conv = nn.Conv2d(3, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size//2)
        self.residual_in_residual = ResidualInResidual(num_residual_groups=num_residual_groups, num_residual_blocks=num_residual_blocks,
                                                       num_channels=num_channels, reduction_ratio=reduction_ratio, kernel_size=kernel_size)
        # NOTE(review): PixelShuffle alone divides channels by UPSCALE_FACTOR**2
        # (e.g. 64 -> 4) rather than expanding them with a conv first; the
        # reconstruction conv below is sized to match. Confirm this lightweight
        # tail is intentional (it differs from the usual conv-then-shuffle tail).
        self.upscaling_module = nn.PixelShuffle(upscale_factor=UPSCALE_FACTOR)
        self.reconstruction_conv = nn.Conv2d(num_channels // (UPSCALE_FACTOR ** 2), 3, kernel_size=kernel_size, stride=1, padding=kernel_size//2)

    def forward(self, x):
        """Map a (B, 3, H, W) image to a (B, 3, 4H, 4W) image clamped to [0, 1]."""
        shallow_feature = self.shallow_conv(x)                    # initial convolution
        deep_feature = self.residual_in_residual(shallow_feature) # Residual in Residual trunk
        upscaled_image = self.upscaling_module(deep_feature)      # sub-pixel upscaling
        reconstructed_image = self.reconstruction_conv(upscaled_image) # reconstruction

        return reconstructed_image.clamp(0, 1)

    def inference(self, x):
        """Super-resolve a single PIL image and return the result as a PIL image.

        Puts the model in eval mode and runs without gradients.
        """
        self.eval()
        with torch.no_grad():
            tensor = ToTensor()(x).unsqueeze(0)
            # Run on whatever device the model's weights live on; ToTensor
            # always produces a CPU tensor.
            tensor = tensor.to(next(self.parameters()).device)
            output = self.forward(tensor)
            # .cpu() is required before .numpy() when the model runs on a GPU;
            # .detach() is redundant under no_grad and has been dropped.
            array = (output.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype('uint8')
        return Image.fromarray(array)

if __name__ == '__main__':
    # Resolve all paths relative to this file so the script works from any
    # working directory (the original mixed script-relative and cwd-relative
    # paths, and built the checkpoint path with string concatenation).
    current_dir = os.path.dirname(os.path.realpath(__file__))

    model = RCAN()
    model.load_state_dict(torch.load(os.path.join(current_dir, 'rcan_checkpoint.pth'),
                                     map_location=torch.device('cpu')))
    model.eval()
    with torch.no_grad():
        input_image = Image.open(os.path.join(current_dir, 'images', 'demo.png'))
        output_image = model.inference(input_image)
        # Persist the result — previously it was computed and then discarded.
        output_image.save(os.path.join(current_dir, 'images', 'demo_sr.png'))