# Source: Hugging Face file viewer export.
# Author: vincent-doan — commit 8744832: "Added clamping to delete noise in output"
# (original page chrome: raw / history / blame, 5.34 kB)
import os
import torch
from torch import nn
from PIL import Image
from torchvision.transforms import ToTensor
NUM_RESIDUAL_GROUPS = 8
NUM_RESIDUAL_BLOCKS = 16
KERNEL_SIZE = 3
REDUCTION_RATIO = 16
NUM_CHANNELS = 64
UPSCALE_FACTOR = 4
class ResidualChannelAttentionBlock(nn.Module):
    """Residual block with channel attention (RCAB).

    Two same-padding convolutions extract a residual; a squeeze-and-excitation
    style branch (global average pool -> 1x1 conv bottleneck -> sigmoid)
    produces per-channel scales for that residual, which is then added back
    to the block input.

    Args:
        num_channels: feature channels in and out of the block.
        reduction_ratio: channel squeeze factor of the attention bottleneck.
        kernel_size: spatial kernel size of the feature-extraction convs.
    """

    def __init__(self, num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size // 2),
            nn.ReLU(),
            nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size // 2),
        )
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_channels, num_channels // reduction_ratio, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv2d(num_channels // reduction_ratio, num_channels, kernel_size=1, stride=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x + attention(features(x)) * features(x)`` (same shape as ``x``)."""
        residual = self.feature_extractor(x)   # feature extraction
        rescale = self.channel_attention(residual)  # per-channel rescaling vector
        # All ops above are out-of-place, so the input needs no defensive clone
        # (the original cloned x here, copying the whole tensor every forward).
        return x + residual * rescale
class ResidualGroup(nn.Module):
    """Group of RCAB blocks with a trailing conv and a long skip connection.

    Args:
        num_residual_blocks: number of ``ResidualChannelAttentionBlock``s.
        num_channels: feature channels in and out of the group.
        reduction_ratio: attention squeeze factor forwarded to each block.
        kernel_size: conv kernel size forwarded to each block and the tail conv.
    """

    def __init__(self, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
                 num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super().__init__()
        self.residual_blocks = nn.Sequential(
            *[ResidualChannelAttentionBlock(num_channels=num_channels,
                                            reduction_ratio=reduction_ratio,
                                            kernel_size=kernel_size)
              for _ in range(num_residual_blocks)]
        )
        self.final_conv = nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size,
                                    stride=1, padding=kernel_size // 2)

    def forward(self, x):
        """Return ``x + conv(blocks(x))`` (same shape as ``x``)."""
        residual = self.residual_blocks(x)  # stacked RCAB blocks
        residual = self.final_conv(residual)  # trailing convolution
        # Out-of-place addition: no clone of the input is needed
        # (the original cloned x here, copying the whole tensor every forward).
        return x + residual
class ResidualInResidual(nn.Module):
    """Residual-in-residual trunk: stacked ResidualGroups plus a global skip.

    Args:
        num_residual_groups: number of ``ResidualGroup``s in the trunk.
        num_residual_blocks: RCAB blocks per group.
        num_channels: feature channels in and out of the trunk.
        reduction_ratio: attention squeeze factor forwarded down.
        kernel_size: conv kernel size forwarded down and used by the tail conv.
    """

    def __init__(self, num_residual_groups=NUM_RESIDUAL_GROUPS, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
                 num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super().__init__()
        self.residual_groups = nn.Sequential(
            *[ResidualGroup(num_residual_blocks=num_residual_blocks,
                            num_channels=num_channels,
                            reduction_ratio=reduction_ratio,
                            kernel_size=kernel_size)
              for _ in range(num_residual_groups)]
        )
        self.final_conv = nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size,
                                    stride=1, padding=kernel_size // 2)

    def forward(self, x):
        """Return deep features: ``x + conv(groups(x))`` (same shape as ``x``)."""
        residual = self.residual_groups(x)  # residual groups
        residual = self.final_conv(residual)  # trailing convolution
        # Out-of-place addition: no clone of the shallow feature is needed
        # (the original cloned x here, copying the whole tensor every forward).
        return x + residual
class RCAN(nn.Module):
    """Residual Channel Attention Network for single-image super-resolution.

    Pipeline: shallow 3->num_channels conv, residual-in-residual trunk,
    PixelShuffle upscaling, and a reconstruction conv back to 3 channels.
    The output is clamped to [0, 1] to suppress out-of-range noise.

    Args:
        num_residual_groups: groups in the residual-in-residual trunk.
        num_residual_blocks: RCAB blocks per group.
        num_channels: trunk feature channels; must be divisible by
            ``upscale_factor ** 2`` for PixelShuffle.
        reduction_ratio: attention squeeze factor.
        kernel_size: conv kernel size used throughout.
        upscale_factor: super-resolution scale (generalized from the
            previously hard-coded module constant; default unchanged).
    """

    def __init__(self, num_residual_groups=NUM_RESIDUAL_GROUPS, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
                 num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE,
                 upscale_factor=UPSCALE_FACTOR):
        super().__init__()
        if num_channels % (upscale_factor ** 2) != 0:
            # PixelShuffle consumes upscale_factor**2 channels per output channel.
            raise ValueError(
                f"num_channels ({num_channels}) must be divisible by "
                f"upscale_factor**2 ({upscale_factor ** 2})"
            )
        self.shallow_conv = nn.Conv2d(3, num_channels, kernel_size=kernel_size,
                                      stride=1, padding=kernel_size // 2)
        self.residual_in_residual = ResidualInResidual(num_residual_groups=num_residual_groups,
                                                       num_residual_blocks=num_residual_blocks,
                                                       num_channels=num_channels,
                                                       reduction_ratio=reduction_ratio,
                                                       kernel_size=kernel_size)
        self.upscaling_module = nn.PixelShuffle(upscale_factor=upscale_factor)
        self.reconstruction_conv = nn.Conv2d(num_channels // (upscale_factor ** 2), 3,
                                             kernel_size=kernel_size, stride=1,
                                             padding=kernel_size // 2)

    def forward(self, x):
        """Map a (N, 3, H, W) tensor to (N, 3, H*scale, W*scale), clamped to [0, 1]."""
        shallow_feature = self.shallow_conv(x)  # initial convolution
        deep_feature = self.residual_in_residual(shallow_feature)  # trunk
        upscaled_image = self.upscaling_module(deep_feature)  # sub-pixel upscaling
        reconstructed_image = self.reconstruction_conv(upscaled_image)  # back to RGB
        # Clamp to the valid image range to remove out-of-range noise.
        return reconstructed_image.clamp(0, 1)

    def inference(self, x):
        """Super-resolve a PIL image and return the result as a PIL image.

        Switches the model to eval mode and disables gradient tracking.
        NOTE(review): runs on whatever device the model is on; assumes CPU
        tensors when converting back to PIL — confirm for GPU use.
        """
        self.eval()
        with torch.no_grad():
            x = ToTensor()(x).unsqueeze(0)  # (3, H, W) in [0, 1] -> (1, 3, H, W)
            x = self.forward(x)
            # forward() already clamps to [0, 1]; .detach() is unnecessary
            # under no_grad, so convert straight to uint8 HWC for PIL.
            x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).numpy() * 255).astype('uint8'))
        return x
if __name__ == '__main__':
    # Demo: load the checkpoint next to this file and super-resolve a sample image.
    current_dir = os.path.dirname(os.path.realpath(__file__))
    model = RCAN()
    model.load_state_dict(torch.load(os.path.join(current_dir, 'rcan_checkpoint.pth'),
                                     map_location=torch.device('cpu')))
    # inference() already switches to eval mode and disables gradients,
    # so the redundant model.eval() / torch.no_grad() wrappers are dropped.
    input_image = Image.open('images/demo.png')
    output_image = model.inference(input_image)
    # The original discarded the result; persist it so the demo is observable.
    output_image.save('images/demo_upscaled.png')