Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from torchvision.transforms import ToTensor | |
from PIL import Image | |
import os | |
from math import sqrt | |
import torch.nn.functional as F | |
class Block(nn.Module):
    """Conv + ReLU unit stacked by ``VDSR.make_layer`` to form the model body.

    A bias-free 3x3 convolution mapping 64 channels to 64 channels
    (stride 1, padding 1, so height/width are preserved), followed by ReLU.
    The attribute names ``conv`` and ``relu`` determine the ``state_dict``
    keys and must stay stable for saved checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=False)
        self.conv = nn.Conv2d(64, 64, **conv_kwargs)
        # In-place is safe: the ReLU only overwrites the fresh conv output,
        # never the tensor passed into forward().
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv(x)
        activated = self.relu(features)
        return activated
class VDSR(nn.Module):
    """VDSR-style super-resolution network with a global residual connection.

    The network predicts a correction that is added back onto its input, so
    ``forward`` expects an image that is already at the target resolution
    (``inference`` produces one with bicubic interpolation).

    With the default ``num_blocks=18`` there are 20 convolutions in total:
    the input conv, 18 conv+ReLU ``Block``s and the output conv. Every conv
    is 3x3, stride 1, padding 1 and bias-free, so spatial size is preserved.

    The attribute names ``input``, ``residual_layer``, ``output`` and ``relu``
    define the ``state_dict`` keys; keep them unchanged so existing
    checkpoints still load.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels produced by the output conv; it
            should equal ``in_channels`` so the residual addition in
            ``forward`` lines up channel-for-channel.
        num_blocks: number of conv+ReLU blocks in the body.
    """

    def __init__(self, in_channels=3, out_channels=3, num_blocks=18):
        super().__init__()
        # Registration order matches the original: the init loop below walks
        # self.modules() in this order, which fixes the sequence of random
        # draws for a given seed.
        self.residual_layer = self.make_layer(Block, num_blocks)
        self.input = nn.Conv2d(in_channels=in_channels, out_channels=64,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.output = nn.Conv2d(in_channels=64, out_channels=out_channels,
                                kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style normal init using fan_out = kH * kW * out_channels,
                # i.e. std = sqrt(2 / fan_out). nn.init.normal_ performs the
                # same draw as the former `weight.data.normal_` without
                # touching `.data`.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=sqrt(2.0 / fan_out))

    def make_layer(self, block, num_layers):
        """Return an ``nn.Sequential`` of ``num_layers`` fresh ``block()`` instances."""
        return nn.Sequential(*(block() for _ in range(num_layers)))

    def forward(self, x):
        """Return ``x`` plus the correction predicted by the network.

        ``x`` is presumably a (N, C, H, W) float tensor already upsampled to
        the target size -- confirm with callers; the output has the same
        shape when ``out_channels == in_channels``.
        """
        out = self.relu(self.input(x))
        out = self.residual_layer(out)
        out = self.output(out)
        # Global residual connection: the body only has to model the
        # difference between the input and the desired output.
        return x + out

    def inference(self, x, scale_factor=4):
        """Super-resolve a PIL image and return the result as a PIL image.

        The image is converted to a [0, 1] tensor, upsampled by
        ``scale_factor`` with bicubic interpolation, refined by the network,
        clamped to [0, 1] and rounded back to 8-bit.

        Args:
            x: input ``PIL.Image``. If the model expects 3 (or 1) input
                channels and the image is not already RGB (or L), it is
                converted first, so RGBA / palette / grayscale PNGs do not
                cause a channel mismatch.
            scale_factor: upsampling factor applied before the network
                (default 4, the previously hard-coded value).

        Returns:
            ``PIL.Image`` whose width and height are the input's multiplied
            by ``scale_factor`` (floored by ``F.interpolate``). Single-channel
            output is returned as a mode-'L' image.

        The module is put in eval mode for the call and its previous
        train/eval mode is restored afterwards.
        """
        expected_channels = self.input.in_channels
        if expected_channels == 3 and x.mode != 'RGB':
            x = x.convert('RGB')
        elif expected_channels == 1 and x.mode != 'L':
            x = x.convert('L')

        # Run on whichever device the weights live on (CPU or GPU).
        device = next(self.parameters()).device
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                t = ToTensor()(x).unsqueeze(0).to(device)
                t = F.interpolate(t, scale_factor=scale_factor, mode='bicubic',
                                  align_corners=False).clamp(0, 1)
                t = self.forward(t).clamp(0, 1)
                # Round to the nearest 8-bit level; a bare astype('uint8')
                # truncates and biases every pixel downward.
                t = t.squeeze(0).mul(255).round().to(torch.uint8).cpu()
        finally:
            self.train(was_training)

        if t.shape[0] == 1:
            # PIL needs a 2-D (H, W) array for a single-channel image.
            array = t[0].numpy()
        else:
            array = t.permute(1, 2, 0).contiguous().numpy()
        return Image.fromarray(array)
if __name__ == '__main__':
    current_dir = os.path.dirname(os.path.realpath(__file__))
    checkpoint_path = os.path.join(current_dir, 'vdsr_checkpoint.pth')
    # The checkpoint holds a whole pickled model object (it is used directly
    # as a module below), so weights_only=False is required: since PyTorch
    # 2.6 torch.load defaults to weights_only=True, which rejects pickled
    # modules. Unpickling can execute arbitrary code -- only load
    # checkpoints from a trusted source. (Keyword available since torch 1.13.)
    model = torch.load(checkpoint_path, map_location=torch.device('cpu'),
                       weights_only=False)
    model.eval()
    # NOTE(review): the demo image path is relative to the current working
    # directory, unlike the checkpoint path -- confirm this is intended.
    with torch.no_grad(), Image.open('images/demo.png') as input_image:
        output_image = model.inference(input_image)
        print(input_image.size, output_image.size)