import os

import torch
from torch import nn
from PIL import Image
from torchvision.transforms import ToTensor, functional as TF


class ResidualBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers with a PReLU in between."""

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            # 0.8 is passed positionally, so it sets BatchNorm2d's eps (kept as in the original code).
            nn.BatchNorm2d(in_features, 0.8),
            nn.PReLU(),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, 0.8),
        )

    def forward(self, x):
        # Skip connection: add the block's output to its input.
        return x + self.conv_block(x)


class GeneratorResnet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super(GeneratorResnet, self).__init__()
        # First layer: large 9x9 receptive field, no normalization.
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4), nn.PReLU())
        # Residual blocks
        res_blocks = []
        for _ in range(n_residual_blocks):
            res_blocks.append(ResidualBlock(64))
        self.res_blocks = nn.Sequential(*res_blocks)
        # Second conv layer after the residual blocks
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64, 0.8))
        # Upsampling: each stage expands to 256 channels, then PixelShuffle(2) trades channels
        # for resolution (256 -> 64 channels, 2x spatial size). Two stages give 4x upscaling.
        upsampling = []
        for _ in range(2):
            upsampling += [
                nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(256),
                nn.PixelShuffle(upscale_factor=2),
                nn.PReLU(),
            ]
        self.upsampling = nn.Sequential(*upsampling)
        # Output layer
        self.conv3 = nn.Sequential(nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4), nn.Tanh())

    def forward(self, x):
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        # Global skip connection around the residual blocks.
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)
        # Clamp the Tanh output into [0, 1], the range expected by the PIL conversion in inference().
        return out.clamp(0, 1)

    def inference(self, x):
        """
        x is a PIL image; returns the 4x-upscaled result as a PIL image.
        """
        self.eval()
        with torch.no_grad():
            x = ToTensor()(x).unsqueeze(0)
            x = self.forward(x)
            # Convert the [0, 1] tensor back to an 8-bit PIL image.
            x = Image.fromarray((x.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype('uint8'))
        # Slight brightness boost on the final image.
        return TF.adjust_brightness(x, 1.1)


if __name__ == '__main__':
    current_dir = os.path.dirname(os.path.realpath(__file__))
    model = GeneratorResnet()
    # The checkpoint may hold either a full pickled model or a bare state dict;
    # handle both so the GeneratorResnet instantiated above is actually used.
    checkpoint = torch.load(os.path.join(current_dir, 'srgan_checkpoint.pth'),
                            map_location=torch.device('cpu'))
    if isinstance(checkpoint, nn.Module):
        model = checkpoint
    else:
        model.load_state_dict(checkpoint)
    model.eval()
    with torch.no_grad():
        input_image = Image.open('images/demo.png')
        output_image = model.inference(input_image)
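        # A minimal follow-up sketch, not part of the original script: save the result and
        # report the size change. The output filename 'demo_sr.png' is an assumed placeholder.
        output_image.save('demo_sr.png')
        print(f'Upscaled {input_image.size} -> {output_image.size}')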