# --- HuggingFace Spaces page-scrape metadata (not part of the module source) ---
# Space status: Runtime error
# File size: 5,794 bytes, revision 6eb1be9
import numpy as np
import PIL
from PIL import Image
import torch
from diffusion_arch import ILVRUNetModel, ConditionalUNetModel
from guided_diffusion.script_util import create_gaussian_diffusion
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid
def preprocess_image(image):
    """Convert a PIL image to a model-ready tensor in [-1, 1].

    The image is forced to RGB, each side is rounded down to the nearest
    multiple of 32 (a U-Net downsampling requirement), pixel values are
    scaled to [0, 1], the layout becomes (1, 3, H, W), and finally the
    range is mapped to [-1, 1].

    Args:
        image: PIL.Image.Image in any mode. Grayscale / RGBA inputs are
            converted to RGB; RGB inputs pass through unchanged.

    Returns:
        torch.FloatTensor of shape (1, 3, H, W) with values in [-1, 1].
    """
    # Force 3 channels so the transpose below always sees an (H, W, 3)
    # array — without this a grayscale ("L") input crashes on
    # transpose(2, 0, 1) and an RGBA input yields 4 channels.
    image = image.convert("RGB")
    w, h = image.size
    # Round each side down to a multiple of 32.
    w, h = map(lambda x: x - x % 32, (w, h))
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0)
    # Map [0, 1] -> [-1, 1], the range the diffusion models expect.
    return 2.0 * image - 1.0
def preprocess_mask(mask):
    """Turn a PIL mask into a binary float tensor of shape (1, 3, H, W).

    The mask is converted to grayscale, snapped down to dimensions that
    are multiples of 32, normalized, replicated across 3 channels, and
    binarized: every nonzero pixel becomes 1.0.
    """
    gray = mask.convert("L")
    width, height = gray.size
    # Snap both dimensions down to multiples of 32 for the U-Net.
    width -= width % 32
    height -= height % 32
    # NEAREST keeps mask values crisp (no interpolated gray edges).
    gray = gray.resize((width, height), resample=PIL.Image.NEAREST)
    arr = np.array(gray).astype(np.float32) / 255.0
    # Replicate the single channel three times, then add a batch axis.
    stacked = np.repeat(arr[None, ...], 3, axis=0)
    tensor = torch.from_numpy(stacked).unsqueeze(0)
    # Binarize: any nonzero pixel counts as masked.
    tensor[tensor > 0] = 1
    return tensor
class DiffusionPipeline():
    """Dual-branch diffusion sampler for face restoration (Spaces demo).

    Holds two 256x256-style U-Nets:

    * ``diffusion_model`` — an unconditional ILVR model, weights loaded
      from ``./ffhq_10m.pt`` (presumably FFHQ-pretrained — the filename
      suggests it, but the training config is not visible here).
    * ``diffusion_restoration_model`` — a conditional model, weights
      from ``./net_g_250000.pth``, which receives the low-quality image
      through the ``lq`` keyword argument at sampling time.

    ``__call__`` is a generator that runs both reverse chains side by
    side and yields pairs of PIL preview grids as sampling progresses.
    """

    def __init__(self, device):
        """Build both U-Nets, load their checkpoints, and move them to *device*.

        Args:
            device: torch device (or device string) the models run on.
        """
        super().__init__()
        self.device = device
        # Unconditional ILVR U-Net. These hyperparameters look like the
        # guided-diffusion 256x256 architecture family — TODO confirm
        # against the checkpoint's training configuration.
        diffusion_model = ILVRUNetModel(
            in_channels=3,
            model_channels=128,
            # 6 = 3 epsilon channels + 3 variance channels; consistent
            # with learn_sigma=True and the split in __call__ below.
            out_channels=6,
            num_res_blocks=1,
            attention_resolutions=[16],
            channel_mult=(1, 1, 2, 2, 4, 4),
            num_classes=None,
            use_checkpoint=False,
            use_fp16=False,
            num_heads=4,
            num_head_channels=64,
            num_heads_upsample=-1,
            use_scale_shift_norm=True,
            resblock_updown=True,
            use_new_attention_order=False
        )
        diffusion_model = diffusion_model.to(device)
        diffusion_model = diffusion_model.eval()
        # Load weights on CPU first; the model is already on `device`,
        # so load_state_dict copies parameters across.
        ilvr_pretraining = torch.load('./ffhq_10m.pt', map_location='cpu')
        diffusion_model.load_state_dict(ilvr_pretraining)
        self.diffusion_model = diffusion_model
        # Conditional restoration U-Net: same trunk configuration, but
        # conditioned on the low-quality image (`lq` kwarg) at call time.
        diffusion_restoration_model = ConditionalUNetModel(
            in_channels=3,
            model_channels=128,
            out_channels=6,
            num_res_blocks=1,
            attention_resolutions=[16],
            dropout=0.0,
            channel_mult=(1, 1, 2, 2, 4, 4),
            num_classes=None,
            use_checkpoint=False,
            use_fp16=False,
            num_heads=4,
            num_head_channels=64,
            num_heads_upsample=-1,
            use_scale_shift_norm=True,
            resblock_updown=True,
            use_new_attention_order=False
        )
        diffusion_restoration_model = diffusion_restoration_model.to(device)
        diffusion_restoration_model = diffusion_restoration_model.eval()
        # This checkpoint stores the weights under a 'params' key
        # (looks like the BasicSR convention — verify if loading fails).
        state_dict = torch.load('./net_g_250000.pth', map_location='cpu')
        diffusion_restoration_model.load_state_dict(state_dict['params'])
        self.diffusion_restoration_model = diffusion_restoration_model

    @torch.no_grad()
    def __call__(self, lq, diffusion_step, binoising_step, grid_size):
        """Generator: restore *lq*, yielding intermediate preview grids.

        Two reverse chains are run in lockstep over the same timesteps:
        ``s_img`` is denoised by the conditional restoration model, and
        ``img`` by the unconditional ILVR model. Once the timestep drops
        below *binoising_step*, ``img`` is re-anchored each step: the
        restoration model predicts x0 from ``img``, which is re-noised
        back to level t via q_sample before the unconditional step.

        Args:
            lq: low-quality input image (PIL); resized to 256x256 RGB.
            diffusion_step: number of respaced reverse-diffusion steps.
            binoising_step: timestep threshold below which the
                unconditional chain is coupled to the restoration model.
            grid_size: number of parallel samples rendered per grid.

        Yields:
            Two-element lists [restoration_preview, unconditional_preview]
            of PIL images — every other step during sampling, plus one
            final pair of the fully denoised results.
        """
        lq = lq.convert("RGB").resize((256, 256), resample=Image.LANCZOS)
        # Respaced diffusion: 1000 training steps subsampled down to
        # `diffusion_step` inference steps.
        eval_gaussian_diffusion = create_gaussian_diffusion(
            steps=1000,
            learn_sigma=True,
            noise_schedule='linear',
            use_kl=False,
            timestep_respacing=str(int(diffusion_step)),
            predict_xstart=False,
            rescale_timesteps=False,
            rescale_learned_sigmas=False,
        )
        ow, oh = lq.size
        # preprocess image
        lq_img_th = preprocess_image(lq).to(self.device)
        # Duplicate the conditioning image so `grid_size` samples are
        # drawn in one batch.
        lq_img_th = lq_img_th.repeat([grid_size, 1, 1, 1])
        # Independent Gaussian starting points for the two chains.
        img = torch.randn_like(lq_img_th, device=self.device)
        s_img = torch.randn_like(lq_img_th, device=self.device)
        # Reverse-time iteration: T-1 down to 0.
        indices = list(range(eval_gaussian_diffusion.num_timesteps))[::-1]
        for i in indices:
            t = torch.tensor([i] * lq_img_th.size(0), device=self.device)
            # --- Restoration chain: conditional p(x_{t-1} | x_t, lq) ---
            out = eval_gaussian_diffusion.p_mean_variance(self.diffusion_restoration_model, s_img, t, model_kwargs={'lq': lq_img_th})
            nonzero_mask = (
                (t != 0).float().view(-1, *([1] * (len(img.shape) - 1)))
            )  # no noise when t == 0
            s_img = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * torch.randn_like(img, device=self.device)
            s_img_pred = out["pred_xstart"]
            # --- Coupling ("binoising"): below the threshold, re-anchor
            # the unconditional chain to the restoration model's x0
            # prediction, re-noised back to the current level t. ---
            if i < binoising_step:
                model_output = eval_gaussian_diffusion._wrap_model(self.diffusion_restoration_model)(img, t, lq=lq_img_th)
                B, C = img.shape[:2]
                # First C channels are epsilon; the rest parameterize the
                # learned variance (discarded here).
                model_output, model_var_values = torch.split(model_output, C, dim=1)
                pred_xstart = eval_gaussian_diffusion._predict_xstart_from_eps(img, t, model_output).clamp(-1, 1)
                img = eval_gaussian_diffusion.q_sample(pred_xstart, t)
            # --- Unconditional ILVR chain: p(x_{t-1} | x_t) ---
            out = eval_gaussian_diffusion.p_mean_variance(self.diffusion_model, img, t)
            nonzero_mask = (
                (t != 0).float().view(-1, *([1] * (len(img.shape) - 1)))
            )  # no noise when t == 0
            img = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * torch.randn_like(img, device=self.device)
            img_pred = out["pred_xstart"]
            # Emit previews (x0 predictions) every other step: grids are
            # rescaled from [-1, 1] to [0, 255] uint8 RGB.
            if i % 2 == 0:
                yield [Image.fromarray(np.uint8((make_grid(s_img_pred) / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1,2,0) * 255.)), Image.fromarray(np.uint8((make_grid(img_pred) / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1,2,0) * 255.))]
        # Final yield: the fully denoised samples from both chains.
        yield [Image.fromarray(np.uint8((make_grid(s_img) / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1,2,0) * 255.)), Image.fromarray(np.uint8((make_grid(img) / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1,2,0) * 255.))]
|