Spaces:
Runtime error
Runtime error
import numpy as np | |
import PIL | |
from PIL import Image | |
import torch | |
from diffusion_arch import ILVRUNetModel, ConditionalUNetModel | |
from guided_diffusion.script_util import create_gaussian_diffusion | |
import torch.nn.functional as F | |
import torchvision.transforms.functional as TF | |
from torchvision.utils import make_grid | |
def preprocess_image(image):
    """Turn a PIL image into a batched, channel-first float tensor.

    Each side of the image is rounded down to the nearest multiple of 32 and
    the image is Lanczos-resampled to that size. The pixel array is then
    divided by 255 and mapped through ``2 * x - 1`` (``[-1, 1]`` for 8-bit
    input).

    Args:
        image: PIL image whose array form has a trailing channel axis; the
            ``transpose(2, 0, 1)`` below requires exactly three dimensions.

    Returns:
        A ``float32`` tensor laid out as ``(1, channels, height, width)``.
    """
    width, height = image.size
    # Round each side down to an integer multiple of 32.
    width -= width % 32
    height -= height % 32
    resized = image.resize((width, height), resample=PIL.Image.LANCZOS)

    pixels = np.array(resized).astype(np.float32) / 255.0
    channel_first = pixels.transpose(2, 0, 1)
    batch = torch.from_numpy(channel_first).unsqueeze(0)
    return 2.0 * batch - 1.0
def preprocess_mask(mask):
    """Turn a PIL mask into a binary, three-channel, batched float tensor.

    The mask is converted to single-channel ("L") mode, each side is rounded
    down to the nearest multiple of 32, and the mask is resized with
    nearest-neighbour sampling. The single channel is then repeated three
    times and every strictly positive value is set to 1.

    Args:
        mask: PIL image used as a mask; any mode accepted by ``convert("L")``.

    Returns:
        A ``float32`` tensor laid out as ``(1, 3, height, width)`` holding
        only zeros and ones.
    """
    gray = mask.convert("L")
    width, height = gray.size
    # Round each side down to an integer multiple of 32.
    width -= width % 32
    height -= height % 32
    gray = gray.resize((width, height), resample=PIL.Image.NEAREST)

    values = np.array(gray).astype(np.float32) / 255.0
    stacked = np.repeat(values[None, ...], 3, axis=0)
    tensor = torch.from_numpy(stacked).unsqueeze(0)
    # Binarise: anything above zero counts as masked.
    tensor[tensor > 0] = 1
    return tensor
class DiffusionPipeline():
    """Two-model diffusion sampler that yields intermediate previews.

    Holds two U-Nets, both moved to ``device`` and put in eval mode:

    * ``diffusion_model``: an ``ILVRUNetModel`` loaded from ``./ffhq_10m.pt``,
      called without conditioning.
    * ``diffusion_restoration_model``: a ``ConditionalUNetModel`` loaded from
      the ``'params'`` entry of ``./net_g_250000.pth``, called with the
      low-quality image passed as the ``lq`` keyword.
    """

    def __init__(self, device):
        """Build both networks on ``device`` and load their checkpoints.

        Args:
            device: torch device (or device string) the models and all
                sampling tensors are placed on.
        """
        super().__init__()
        self.device = device

        diffusion_model = ILVRUNetModel(
            in_channels=3,
            model_channels=128,
            out_channels=6,
            num_res_blocks=1,
            attention_resolutions=[16],
            channel_mult=(1, 1, 2, 2, 4, 4),
            num_classes=None,
            use_checkpoint=False,
            use_fp16=False,
            num_heads=4,
            num_head_channels=64,
            num_heads_upsample=-1,
            use_scale_shift_norm=True,
            resblock_updown=True,
            use_new_attention_order=False
        )
        diffusion_model = diffusion_model.to(device)
        diffusion_model = diffusion_model.eval()
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ilvr_pretraining = torch.load('./ffhq_10m.pt', map_location='cpu')
        diffusion_model.load_state_dict(ilvr_pretraining)
        self.diffusion_model = diffusion_model

        diffusion_restoration_model = ConditionalUNetModel(
            in_channels=3,
            model_channels=128,
            out_channels=6,
            num_res_blocks=1,
            attention_resolutions=[16],
            dropout=0.0,
            channel_mult=(1, 1, 2, 2, 4, 4),
            num_classes=None,
            use_checkpoint=False,
            use_fp16=False,
            num_heads=4,
            num_head_channels=64,
            num_heads_upsample=-1,
            use_scale_shift_norm=True,
            resblock_updown=True,
            use_new_attention_order=False
        )
        diffusion_restoration_model = diffusion_restoration_model.to(device)
        diffusion_restoration_model = diffusion_restoration_model.eval()
        # This checkpoint nests the weights under the 'params' key.
        state_dict = torch.load('./net_g_250000.pth', map_location='cpu')
        diffusion_restoration_model.load_state_dict(state_dict['params'])
        self.diffusion_restoration_model = diffusion_restoration_model

    @staticmethod
    def _to_pil_grid(batch):
        """Tile a batch of images into one grid and return it as a PIL image.

        Values are mapped with ``x / 2 + 0.5``, clamped to ``[0, 1]`` and
        scaled to 8-bit.
        """
        grid = (make_grid(batch.detach()) / 2 + 0.5).clamp(0, 1)
        return Image.fromarray(np.uint8(grid.cpu().numpy().transpose(1, 2, 0) * 255.))

    def __call__(self, lq, diffusion_step, binoising_step, grid_size):
        """Sample from both models for ``lq`` and yield preview image pairs.

        This is a generator. Two sample batches are advanced side by side
        from independent Gaussian noise:

        * ``s_img`` is stepped with the restoration model conditioned on the
          low-quality image.
        * ``img`` is stepped with the unconditional model. For timestep
          indices strictly below ``binoising_step`` it is first passed
          through the restoration model: the predicted ``x_start`` is clamped
          to ``[-1, 1]`` and re-noised to the current timestep with
          ``q_sample`` before the unconditional step.

        All model evaluation runs under ``torch.no_grad()``. The no-grad
        scope is closed before every ``yield`` so the caller's grad mode is
        never altered while the generator is suspended.

        Args:
            lq: PIL image; converted to RGB and resized to 256x256.
            diffusion_step: Number of respaced sampling steps (cast to int).
            binoising_step: Timestep indices below this value get the extra
                restoration pass described above.
            grid_size: Number of copies of ``lq`` sampled in one batch
                (cast to int).

        Yields:
            ``[restoration_preview, prior_preview]`` as PIL images. On every
            even timestep index these show the predicted ``x_start`` of each
            batch; one last pair after the loop shows the final samples.
        """
        lq = lq.convert("RGB").resize((256, 256), resample=Image.LANCZOS)
        diffusion = create_gaussian_diffusion(
            steps=1000,
            learn_sigma=True,
            noise_schedule='linear',
            use_kl=False,
            timestep_respacing=str(int(diffusion_step)),
            predict_xstart=False,
            rescale_timesteps=False,
            rescale_learned_sigmas=False,
        )

        # Sliders commonly deliver floats; repeat() needs an int.
        grid_size = int(grid_size)
        lq_img_th = preprocess_image(lq).to(self.device)
        lq_img_th = lq_img_th.repeat([grid_size, 1, 1, 1])

        # Draw order (img, then s_img) is kept so seeded runs are unchanged.
        img = torch.randn_like(lq_img_th, device=self.device)
        s_img = torch.randn_like(lq_img_th, device=self.device)

        batch_size = lq_img_th.size(0)
        channels = img.shape[1]
        # Loop-invariant: the wrapper translating respaced timestep indices.
        wrapped_restorer = diffusion._wrap_model(self.diffusion_restoration_model)

        # NOTE(review): the source this was recovered from had lost its
        # indentation; the extent of the `i < binoising_step` branch below is
        # the only nesting under which `img_pred` is always defined before the
        # preview yield -- confirm against the upstream file.
        for i in reversed(range(diffusion.num_timesteps)):
            with torch.no_grad():
                t = torch.tensor([i] * batch_size, device=self.device)
                # No noise is added on the final step (t == 0).
                nonzero_mask = (t != 0).float().view(-1, *([1] * (img.dim() - 1)))

                # Restoration batch: one ancestral step conditioned on lq.
                out = diffusion.p_mean_variance(
                    self.diffusion_restoration_model, s_img, t,
                    model_kwargs={'lq': lq_img_th},
                )
                s_img = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * torch.randn_like(img, device=self.device)
                s_img_pred = out["pred_xstart"]

                if i < binoising_step:
                    model_output = wrapped_restorer(img, t, lq=lq_img_th)
                    # The model emits 2*C channels; the first C are the eps
                    # prediction, the rest (variance values) are unused here.
                    eps, _ = torch.split(model_output, channels, dim=1)
                    pred_xstart = diffusion._predict_xstart_from_eps(img, t, eps).clamp(-1, 1)
                    img = diffusion.q_sample(pred_xstart, t)

                # Prior batch: one ancestral step with the unconditional model.
                out = diffusion.p_mean_variance(self.diffusion_model, img, t)
                img = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * torch.randn_like(img, device=self.device)
                img_pred = out["pred_xstart"]

            if i % 2 == 0:
                yield [self._to_pil_grid(s_img_pred), self._to_pil_grid(img_pred)]

        yield [self._to_pil_grid(s_img), self._to_pil_grid(img)]