thiemcun203's picture
VDSR added (#6)
5bdf4bb verified
import torch
import torch.nn as nn
from torchvision.transforms import ToTensor
from PIL import Image
import os
from math import sqrt
import torch.nn.functional as F
#define class Block contain conv and relu layer
class Block(nn.Module):
def __init__(self):
super(Block, self).__init__()
self.conv = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.relu(self.conv(x))
class VDSR(nn.Module):
def __init__(self, in_channels=3, out_channels=3, num_blocks=18):
super(VDSR, self).__init__()
self.residual_layer = self.make_layer(Block, num_blocks)
self.input = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
self.output = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.relu = nn.ReLU(inplace=True)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, sqrt(2. / n))
def make_layer(self, block, num_layers):
layers=[]
for _ in range(num_layers):
layers.append(block())
return nn.Sequential(*layers)
def forward(self, x):
residual = x
out = self.relu(self.input(x))
out = self.residual_layer(out)
out = self.output(out)
out = torch.add(residual, out)
return out
def inference(self, x):
"""
x is a PIL image
"""
self.eval()
with torch.no_grad():
x = ToTensor()(x).unsqueeze(0)
x = F.interpolate(x, scale_factor=4, mode='bicubic', align_corners=False).clamp(0, 1)
x = self.forward(x).clamp(0, 1)
x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype('uint8'))
return x
if __name__ == '__main__':
current_dir = os.path.dirname(os.path.realpath(__file__))
model = torch.load(current_dir + '/vdsr_checkpoint.pth', map_location=torch.device('cpu'))
model.eval()
with torch.no_grad():
input_image = Image.open('images/demo.png')
output_image = model.inference(input_image)
print(input_image.size, output_image.size)