vincent-doan
Minor change
2ee69f4
raw
history blame
2.75 kB
import os
import torch
from torch import nn
from PIL import Image
from torchvision.transforms import ToTensor, functional as TF
class ResidualBlock(nn.Module):
    """Residual unit: ``x + F(x)`` where ``F`` is conv-BN-PReLU-conv-BN.

    Both convolutions are 3x3, stride 1, padding 1 and keep the channel count
    at ``in_features``, so the output has the same shape as the input.

    Args:
        in_features: number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()

        # NOTE(review): the second positional argument of ``nn.BatchNorm2d``
        # is ``eps`` (not ``momentum``), so the original ``BatchNorm2d(c, 0.8)``
        # sets eps=0.8 and leaves momentum at its default. It is spelled out
        # here as a keyword and deliberately NOT changed: eps takes part in the
        # normalisation itself, so altering it would change the output of any
        # weights trained with this definition.
        layers = [
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, eps=0.8),
            nn.PReLU(),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, eps=0.8),
        ]
        # Attribute name and layer order determine the state_dict keys
        # (``conv_block.0.weight`` ...), so both are kept as they were.
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        residual = self.conv_block(x)
        return x + residual
class GeneratorResnet(nn.Module):
    """ResNet-style generator that enlarges an image 4x in each spatial dimension.

    Pipeline: 9x9 conv + PReLU -> ``n_residual_blocks`` residual blocks ->
    3x3 conv + BatchNorm (added to the first layer's output as a long skip
    connection) -> two (conv, BatchNorm, PixelShuffle(2), PReLU) stages ->
    9x9 conv + Tanh.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor.
        n_residual_blocks: how many ``ResidualBlock`` modules to stack.
    """

    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super(GeneratorResnet, self).__init__()

        # First layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
        )

        # Residual blocks
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(n_residual_blocks)]
        )

        # Second conv layer after the residual blocks.
        # NOTE(review): 0.8 was originally passed positionally, which makes it
        # ``eps`` (not ``momentum``). Kept as-is, only made explicit, because
        # changing eps would change the output of existing trained weights.
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
        )

        # Two x2 stages: the conv expands 64 -> 256 channels and
        # PixelShuffle(2) folds them back to 64 while doubling height/width.
        upsampling = []
        for _ in range(2):
            upsampling += [
                nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(256),
                nn.PixelShuffle(upscale_factor=2),
                nn.PReLU(),
            ]
        self.upsampling = nn.Sequential(*upsampling)

        # Output layer
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the generator on a batch tensor.

        Args:
            x: tensor of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, 4*H, 4*W) with values in [0, 1].
            The Tanh output is clamped to [0, 1], so negative activations
            become 0.
        """
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        out = torch.add(out1, out2)  # long skip connection around the residual stack
        out = self.upsampling(out)
        out = self.conv3(out)
        return out.clamp(0, 1)

    def inference(self, x):
        """Upscale a single PIL image and return the result as a PIL image.

        Side effect: the module is switched to eval mode and left there.

        Args:
            x: a PIL image. If its number of bands does not match the number
                of channels the first convolution expects (e.g. an RGBA or
                palette PNG fed to a 3-channel model), it is converted to the
                matching mode ('L', 'RGB' or 'RGBA') first. Images that
                already have the right number of bands are used unchanged.

        Returns:
            The upscaled PIL image with brightness multiplied by 1.1.
        """
        self.eval()

        # Read the expected channel count from the layer itself rather than
        # from an attribute set in __init__: models restored from a
        # whole-module pickle do not re-run __init__.
        expected_channels = self.conv1[0].in_channels
        if len(x.getbands()) != expected_channels:
            target_mode = {1: 'L', 3: 'RGB', 4: 'RGBA'}.get(expected_channels)
            if target_mode is not None:
                x = x.convert(target_mode)

        # Run on whatever device the weights live on (CPU or GPU).
        device = next(self.parameters()).device

        with torch.no_grad():
            x = ToTensor()(x).unsqueeze(0).to(device)
            x = self.forward(x)

        # CHW -> HWC and back to host memory before handing over to NumPy/PIL.
        pixels = x.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
        x = Image.fromarray((pixels * 255).astype('uint8'))
        return TF.adjust_brightness(x, 1.1)
if __name__ == '__main__':
    # Demo: load the checkpoint stored next to this file and upscale one image.
    current_dir = os.path.dirname(os.path.realpath(__file__))
    checkpoint_path = os.path.join(current_dir, 'srgan_checkpoint.pth')

    # The loaded object is used directly as a model (.eval(), .inference()),
    # so the checkpoint is expected to be a whole pickled module rather than a
    # state_dict. No GeneratorResnet() needs to be constructed beforehand.
    #
    # ``weights_only=False`` is required for whole-module checkpoints on
    # PyTorch versions whose ``torch.load`` defaults to ``weights_only=True``.
    # SECURITY: this unpickles arbitrary objects and can execute code embedded
    # in the file -- only load checkpoints from a trusted source.
    model = torch.load(
        checkpoint_path,
        map_location=torch.device('cpu'),
        weights_only=False,
    )
    model.eval()

    # NOTE(review): the image path is relative to the current working
    # directory, while the checkpoint path is relative to this file -- confirm
    # that is intended.
    with torch.no_grad(), Image.open('images/demo.png') as input_image:
        output_image = model.inference(input_image)