ESRGAN-MANGA / inference_manga_v2.py
0x90e's picture
better image file handling
5e08d25
raw
history blame
1.32 kB
import sys
import cv2
import numpy as np
import torch
import ESRGAN.architecture as arch
from ESRGANer import ESRGANer
def is_cuda():
    """Return True when PyTorch reports a usable CUDA device, else False."""
    # torch.cuda.is_available() already returns a bool; no if/else needed.
    return torch.cuda.is_available()
# Upscale one grayscale image 4x with the ESRGAN manga checkpoint below and
# write the result back to the SAME path.
#
# Usage: python <this script> <image_path>
#
# NOTE: the path given on the command line is used both as input and output,
# so the original image is overwritten by the upscaled version.
model_path = 'models/4x_eula_digimanga_bw_v2_nc1_307k.pth'

if len(sys.argv) < 2:
    # sys.exit with a string prints it to stderr and exits with status 1.
    sys.exit(f"usage: {sys.argv[0]} <image_path>")
OUTPUT_PATH = sys.argv[1]

# Query CUDA once and reuse the answer for device selection and logging.
use_cuda = is_cuda()
device = torch.device('cuda' if use_cuda else 'cpu')

# The architecture must match the checkpoint exactly (strict=True below).
# The leading positional args are presumably (in_nc, out_nc, nf, nb) of
# ESRGAN's RRDB_Net -- confirm against ESRGAN/architecture.py.
model = arch.RRDB_Net(
    1, 1, 64, 23,
    gc=32, upscale=4, norm_type=None, act_type='leakyrelu',
    mode='CNA', res_scale=1, upsample_mode='upconv',
)

print("Using GPU 🥶" if use_cuda else "Using CPU 😒")
# The model is still on the CPU here, so load the weights onto the CPU as
# well (this also works for GPU-saved checkpoints on CPU-only machines);
# the whole model is moved to the target device afterwards.
model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
model.eval()
# Inference only: freeze every parameter.
for param in model.parameters():
    param.requires_grad = False
model = model.to(device)

# Read image as a single channel. cv2.imread does not raise on a missing or
# undecodable file; it returns None, which must be checked explicitly.
img = cv2.imread(OUTPUT_PATH, cv2.IMREAD_GRAYSCALE)
if img is None:
    sys.exit(f"error: could not read image '{OUTPUT_PATH}'")

# Scale 0..255 to 0..1 and build a (1, 1, H, W) float32 tensor.
img = img * 1.0 / 255
img = torch.from_numpy(img[np.newaxis, :, :]).float()
img_LR = img.unsqueeze(0)
img_LR = img_LR.to(device)

upsampler = ESRGANer(model=model)
output = upsampler.enhance(img_LR)

# Assumes enhance() returns a (1, C, H, W) tensor -- TODO confirm in ESRGANer.
# Convert to an (H, W, C) uint8 array in [0, 255] for OpenCV; values are
# clamped and rounded first, so the uint8 cast is exact.
output = output.squeeze(dim=0).float().cpu().clamp_(0, 1).numpy()
output = np.transpose(output, (1, 2, 0))
output = (output * 255.0).round().astype(np.uint8)

# cv2.imwrite signals failure through its return value, not an exception.
if not cv2.imwrite(OUTPUT_PATH, output, [int(cv2.IMWRITE_PNG_COMPRESSION), 5]):
    sys.exit(f"error: could not write image '{OUTPUT_PATH}'")