ESRGAN-MANGA / inference.py
0x90e's picture
Use 4x_NMKD-Siax_200k for general upscale 4x
227e514
raw
history blame
1.78 kB
import sys
import cv2
import numpy as np
import torch
import ESRGAN.architecture as esrgan
import ESRGAN_plus.architecture as esrgan_plus
from run_cmd import run_cmd
from ESRGANer import ESRGANer
def is_cuda() -> bool:
    """Return True when PyTorch reports an available CUDA device, else False."""
    # torch.cuda.is_available() already yields a bool, so no branching is needed.
    return torch.cuda.is_available()
# Command-line contract: argv[1] is the image path, argv[2] is the model type.
# Fail with a usage message instead of a bare IndexError when one is missing.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} <image_path> <model_type>")

model_type = sys.argv[2]

# Pick the weights file for the requested model type. This has to be a single
# if/elif/else chain: written as two independent `if` statements, the trailing
# `else` also ran for "Anime" and silently replaced its path with the general
# model, so the anime weights could never be selected.
if model_type == "Anime":
    model_path = "models/4x-AnimeSharp.pth"
elif model_type == "Photo":
    model_path = "models/4x_Valar_v1.pth"
else:
    # Any other model type falls back to the general-purpose 4x model.
    model_path = "models/4x_NMKD-Siax_200k.pth"

# NOTE: despite its name, this path is first read as the input image and is
# then overwritten with the upscaled result further down in the script.
OUTPUT_PATH = sys.argv[1]
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if is_cuda() else 'cpu')

# "Photo" weights are built with the ESRGAN_plus network; every other model
# type uses the plain ESRGAN network. Both are constructed with identical
# arguments, so only the module providing RRDB_Net differs.
architecture = esrgan_plus if model_type == "Photo" else esrgan
model = architecture.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', mode='CNA', res_scale=1, upsample_mode='upconv')

# NOTE(review): the trailing characters in these messages look like a
# mis-encoded emoji; they are left untouched here because they are runtime
# output -- confirm the intended text before changing them.
if device.type == 'cuda':
    print("Using GPU πŸ₯Ά")
else:
    print("Using CPU πŸ˜’")

# Always deserialize the checkpoint onto the CPU: the model itself still lives
# on the CPU at this point, and this avoids a failure when the checkpoint was
# saved from a CUDA device index that does not exist on this machine. The
# weights are moved to the target device together with the model below.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict, strict=True)

# Inference only: switch to eval mode and turn off gradient tracking on every
# parameter.
model.eval()
for param in model.parameters():
    param.requires_grad = False

model = model.to(device)
# Read image
img = cv2.imread(OUTPUT_PATH, cv2.IMREAD_COLOR)
if img is None:
    # cv2.imread signals failure by returning None rather than raising, which
    # would otherwise surface as an opaque TypeError on the next line.
    raise OSError(f"Could not read an image from {OUTPUT_PATH!r}")

# Scale pixel values from [0, 255] to [0, 1].
img = img * 1.0 / 255
# Reverse the channel order (index [2, 1, 0]) and move channels first:
# (H, W, C) -> (C, H, W).
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
# Add a leading batch dimension and move the tensor to the model's device.
img_LR = img.unsqueeze(0)
img_LR = img_LR.to(device)

# Run the upscaler on the prepared tensor.
upsampler = ESRGANer(model=model)
output = upsampler.enhance(img_LR)

# Drop singleton dimensions, bring the result back to the CPU and clamp it to
# the valid [0, 1] range before converting to a NumPy array.
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
# Undo the preprocessing: restore the original channel order and go back to
# channels-last (C, H, W) -> (H, W, C).
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
# Values are already clamped to [0, 1], so after scaling and rounding they fit
# in uint8; convert explicitly instead of handing a float array to imwrite.
output = (output * 255.0).round().astype(np.uint8)

# Overwrite the input file with the upscaled image. imwrite reports failure
# through its return value, so check it rather than failing silently.
if not cv2.imwrite(OUTPUT_PATH, output, [int(cv2.IMWRITE_PNG_COMPRESSION), 5]):
    raise OSError(f"Could not write the upscaled image to {OUTPUT_PATH!r}")