# NOTE: the lines below are file-viewer chrome captured during scraping
# (status badges, commit hashes, file size, and a line-number gutter),
# not source code; kept as comments for provenance.
# Spaces: Running / Running
# File size: 5,344 Bytes
# 2235ef5 fc75fbd 2235ef5 fc75fbd 8744832 54770f1 8744832 54770f1 2235ef5 54770f1
import os
import torch
from torch import nn
from PIL import Image
from torchvision.transforms import ToTensor
# --- Model hyperparameters (RCAN paper defaults) ---
NUM_RESIDUAL_GROUPS = 8    # residual groups in the residual-in-residual trunk
NUM_RESIDUAL_BLOCKS = 16   # channel-attention blocks per residual group
KERNEL_SIZE = 3            # spatial conv kernel size
REDUCTION_RATIO = 16       # channel-attention bottleneck reduction
NUM_CHANNELS = 64          # feature channels throughout the trunk
UPSCALE_FACTOR = 4         # super-resolution scale (PixelShuffle factor)


class ResidualChannelAttentionBlock(nn.Module):
    """Residual block with channel attention (RCAB).

    Extracts features with conv-ReLU-conv, rescales them per channel with a
    squeeze-and-excitation style attention gate, and adds the result back
    onto the block input: ``x + attention(f(x)) * f(x)``.
    """

    def __init__(self, num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super().__init__()
        # Two same-padding convolutions produce the residual features.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size // 2),
            nn.ReLU(),
            nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size // 2),
        )
        # Global average pool -> bottleneck -> sigmoid: a per-channel gate in (0, 1).
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_channels, num_channels // reduction_ratio, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv2d(num_channels // reduction_ratio, num_channels, kernel_size=1, stride=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the block output; same shape as ``x``."""
        residual = self.feature_extractor(x)        # feature extraction
        rescale = self.channel_attention(residual)  # per-channel rescaling vector
        # '+' allocates a fresh tensor, so no defensive x.clone() is needed.
        return x + residual * rescale
class ResidualGroup(nn.Module):
    """A stack of RCABs followed by a conv, wrapped in a group-level skip.

    The group output is ``x + conv(rcab_stack(x))``.
    """

    def __init__(self, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
                 num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super().__init__()
        self.residual_blocks = nn.Sequential(
            *[ResidualChannelAttentionBlock(num_channels=num_channels,
                                            reduction_ratio=reduction_ratio,
                                            kernel_size=kernel_size)
              for _ in range(num_residual_blocks)]
        )
        self.final_conv = nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size,
                                    stride=1, padding=kernel_size // 2)

    def forward(self, x):
        """Return the group output; same shape as ``x``."""
        residual = self.residual_blocks(x)    # stacked residual channel-attention blocks
        residual = self.final_conv(residual)  # trailing convolution
        # '+' allocates a fresh tensor, so no defensive x.clone() is needed.
        return x + residual
class ResidualInResidual(nn.Module):
    """Residual-in-residual trunk: residual groups plus a conv, with a long skip.

    Maps a shallow feature map to a deep one of the same shape:
    ``x + conv(groups(x))``.
    """

    def __init__(self, num_residual_groups=NUM_RESIDUAL_GROUPS, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
                 num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super().__init__()
        self.residual_groups = nn.Sequential(
            *[ResidualGroup(num_residual_blocks=num_residual_blocks,
                            num_channels=num_channels,
                            reduction_ratio=reduction_ratio,
                            kernel_size=kernel_size)
              for _ in range(num_residual_groups)]
        )
        self.final_conv = nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size,
                                    stride=1, padding=kernel_size // 2)

    def forward(self, x):
        """Return the deep feature map; same shape as ``x``."""
        residual = self.residual_groups(x)    # stacked residual groups
        residual = self.final_conv(residual)  # trailing convolution
        # '+' allocates a fresh tensor, so no defensive x.clone() is needed.
        return x + residual
class RCAN(nn.Module):
    """Residual Channel Attention Network for single-image super-resolution.

    Pipeline: shallow conv -> residual-in-residual trunk -> PixelShuffle
    upscaling -> reconstruction conv, with the output clamped to [0, 1].
    """

    def __init__(self, num_residual_groups=NUM_RESIDUAL_GROUPS, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
                 num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super().__init__()
        # RGB input -> trunk feature channels.
        self.shallow_conv = nn.Conv2d(3, num_channels, kernel_size=kernel_size,
                                      stride=1, padding=kernel_size // 2)
        self.residual_in_residual = ResidualInResidual(
            num_residual_groups=num_residual_groups, num_residual_blocks=num_residual_blocks,
            num_channels=num_channels, reduction_ratio=reduction_ratio, kernel_size=kernel_size)
        # PixelShuffle trades channels for resolution: it divides the channel
        # count by UPSCALE_FACTOR**2, hence the reduced in_channels below.
        self.upscaling_module = nn.PixelShuffle(upscale_factor=UPSCALE_FACTOR)
        self.reconstruction_conv = nn.Conv2d(num_channels // (UPSCALE_FACTOR ** 2), 3,
                                             kernel_size=kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, x):
        """Map an (N, 3, H, W) batch in [0, 1] to (N, 3, H*s, W*s), s = UPSCALE_FACTOR."""
        shallow_feature = self.shallow_conv(x)                     # initial convolution
        deep_feature = self.residual_in_residual(shallow_feature)  # residual-in-residual trunk
        upscaled_image = self.upscaling_module(deep_feature)       # sub-pixel upscaling
        reconstructed_image = self.reconstruction_conv(upscaled_image)  # back to RGB
        return reconstructed_image.clamp(0, 1)

    def inference(self, x):
        """Super-resolve a PIL image and return the result as a PIL image.

        Side effect: switches the module to eval mode.
        """
        self.eval()
        with torch.no_grad():
            tensor = ToTensor()(x).unsqueeze(0)       # (1, 3, H, W), values in [0, 1]
            output = self.forward(tensor).squeeze(0)  # (3, H*s, W*s); .detach() is redundant under no_grad
            # Round rather than truncate so e.g. 0.999 * 255 maps to 255, not 254.
            array = (output.permute(1, 2, 0).numpy() * 255).round().astype('uint8')
            return Image.fromarray(array)
if __name__ == '__main__':
    # Demo entry point: load the pretrained checkpoint next to this file and
    # super-resolve a sample image.
    current_dir = os.path.dirname(os.path.realpath(__file__))
    model = RCAN()
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from a trusted source (or pass weights_only=True on
    # recent PyTorch versions).
    checkpoint_path = os.path.join(current_dir, 'rcan_checkpoint.pth')
    model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
    # inference() already switches to eval mode and disables gradients; the
    # explicit guards here are redundant but harmless and self-documenting.
    model.eval()
    with torch.no_grad():
        input_image = Image.open('images/demo.png')
        output_image = model.inference(input_image)