# RCAN — Residual Channel Attention Network for 4x single-image super-resolution.
# (The original "Spaces: Running" header lines were Hugging Face Space page residue.)
import os
from pathlib import Path

import torch
from torch import nn
from PIL import Image
from torchvision.transforms import ToTensor
# --- RCAN hyperparameters (paper defaults) ---
NUM_RESIDUAL_GROUPS = 8    # residual groups in the residual-in-residual trunk
NUM_RESIDUAL_BLOCKS = 16   # RCAB blocks per residual group
KERNEL_SIZE = 3            # conv kernel size used throughout the network
REDUCTION_RATIO = 16       # channel squeeze factor in the attention bottleneck
NUM_CHANNELS = 64          # feature-map width of the trunk
UPSCALE_FACTOR = 4         # PixelShuffle upscaling factor
class ResidualChannelAttentionBlock(nn.Module):
    """Residual Channel Attention Block (RCAB).

    Two same-padding convolutions extract a residual; a squeeze-and-excitation
    style branch turns it into per-channel gates in (0, 1); the gated residual
    is added back onto the block input. Input and output shapes are identical.
    """

    def __init__(self, num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super().__init__()
        # Residual feature extraction: conv -> ReLU -> conv; padding keeps H and W.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size // 2),
            nn.ReLU(),
            nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size, stride=1, padding=kernel_size // 2),
        )
        # Channel attention: global average pool -> bottleneck convs -> sigmoid gate.
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_channels, num_channels // reduction_ratio, kernel_size=1, stride=1),
            nn.ReLU(),
            nn.Conv2d(num_channels // reduction_ratio, num_channels, kernel_size=1, stride=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return x + attention-rescaled residual (same shape as x)."""
        residual = self.feature_extractor(x)        # (N, C, H, W)
        rescale = self.channel_attention(residual)  # (N, C, 1, 1) gates, broadcast over H, W
        # '+' allocates a fresh tensor, so the original x.clone() was a wasted copy.
        return x + residual * rescale
class ResidualGroup(nn.Module):
    """A stack of RCAB blocks plus a trailing conv, wrapped in a skip connection."""

    def __init__(self, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
                 num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super().__init__()
        self.residual_blocks = nn.Sequential(*[
            ResidualChannelAttentionBlock(num_channels=num_channels,
                                          reduction_ratio=reduction_ratio,
                                          kernel_size=kernel_size)
            for _ in range(num_residual_blocks)
        ])
        # Same-padding conv applied to the group's residual before the skip add.
        self.final_conv = nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size,
                                    stride=1, padding=kernel_size // 2)

    def forward(self, x):
        """Return x + final_conv(RCAB stack(x)); shape is unchanged."""
        residual = self.final_conv(self.residual_blocks(x))
        # Out-of-place '+' makes the original x.clone() unnecessary.
        return x + residual
class ResidualInResidual(nn.Module):
    """Residual-in-residual trunk: residual groups plus a conv, under a long skip.

    Maps shallow features to deep features of the same (N, C, H, W) shape.
    """

    def __init__(self, num_residual_groups=NUM_RESIDUAL_GROUPS, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
                 num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super().__init__()
        self.residual_groups = nn.Sequential(*[
            ResidualGroup(num_residual_blocks=num_residual_blocks,
                          num_channels=num_channels,
                          reduction_ratio=reduction_ratio,
                          kernel_size=kernel_size)
            for _ in range(num_residual_groups)
        ])
        self.final_conv = nn.Conv2d(num_channels, num_channels, kernel_size=kernel_size,
                                    stride=1, padding=kernel_size // 2)

    def forward(self, x):
        """Return shallow features x plus the trunk's residual (long skip)."""
        residual = self.final_conv(self.residual_groups(x))
        # Out-of-place '+' makes the original x.clone() unnecessary.
        return x + residual
class RCAN(nn.Module):
    """RCAN super-resolution network: shallow conv -> RIR trunk -> upscale -> reconstruct.

    forward() maps a (N, 3, H, W) image in [0, 1] to a (N, 3, rH, rW) image
    clamped to [0, 1], where r = UPSCALE_FACTOR.
    """

    def __init__(self, num_residual_groups=NUM_RESIDUAL_GROUPS, num_residual_blocks=NUM_RESIDUAL_BLOCKS,
                 num_channels=NUM_CHANNELS, reduction_ratio=REDUCTION_RATIO, kernel_size=KERNEL_SIZE):
        super().__init__()
        self.shallow_conv = nn.Conv2d(3, num_channels, kernel_size=kernel_size,
                                      stride=1, padding=kernel_size // 2)
        self.residual_in_residual = ResidualInResidual(
            num_residual_groups=num_residual_groups, num_residual_blocks=num_residual_blocks,
            num_channels=num_channels, reduction_ratio=reduction_ratio, kernel_size=kernel_size)
        # PixelShuffle trades channels for resolution: C -> C / r^2, (H, W) -> (rH, rW).
        self.upscaling_module = nn.PixelShuffle(upscale_factor=UPSCALE_FACTOR)
        self.reconstruction_conv = nn.Conv2d(num_channels // (UPSCALE_FACTOR ** 2), 3,
                                             kernel_size=kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, x):
        """Super-resolve a batch of images; output is clamped to [0, 1]."""
        shallow_feature = self.shallow_conv(x)
        deep_feature = self.residual_in_residual(shallow_feature)
        upscaled = self.upscaling_module(deep_feature)
        reconstructed = self.reconstruction_conv(upscaled)
        return reconstructed.clamp(0, 1)

    def inference(self, x):
        """Super-resolve a single PIL image and return a PIL image.

        Any PIL mode is accepted (converted to RGB first); forward() already
        clamps to [0, 1], so the uint8 conversion cannot wrap around.
        """
        self.eval()
        with torch.no_grad():
            # convert('RGB') guards against grayscale / RGBA inputs, which would
            # otherwise break the 3-channel shallow conv.
            tensor = ToTensor()(x.convert('RGB')).unsqueeze(0)  # (1, 3, H, W) in [0, 1]
            output = self.forward(tensor)
            # .cpu() keeps the numpy conversion safe if parameters were moved to
            # an accelerator; no .detach() needed inside no_grad().
            array = (output.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype('uint8')
        return Image.fromarray(array)
if __name__ == '__main__':
    # Resolve every path relative to this script, so the demo works from any CWD
    # (the original opened 'images/demo.png' relative to the CWD while loading
    # the checkpoint relative to the script — inconsistent).
    current_dir = Path(__file__).resolve().parent
    model = RCAN()
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    model.load_state_dict(torch.load(current_dir / 'rcan_checkpoint.pth',
                                     map_location=torch.device('cpu')))
    # inference() calls eval() and wraps itself in no_grad(), so no extra guards needed.
    input_image = Image.open(current_dir / 'images' / 'demo.png')
    output_image = model.inference(input_image)
    # The original discarded the result; persist it next to the script.
    output_image.save(current_dir / 'demo_upscaled.png')