import numpy as np
import torch as th
from skimage.transform import resize as imresize
from PIL import Image
from torchvision.utils import make_grid  # explicit import; gen_image's star import may also provide it
from decomp_diffusion.model_and_diffusion_util import *
from decomp_diffusion.diffusion.respace import SpacedDiffusion
from decomp_diffusion.gen_image import *
from download import download_model
from upsampling import get_pipeline, upscale_image
import gradio as gr
# from huggingface_hub import login
# fix randomness
th.manual_seed(0)
np.random.seed(0)
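# note: GPU sampling can still be nondeterministic unless cuDNN is pinned,
# e.g. th.backends.cudnn.deterministic = True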
def get_pil_im(im, resolution=64):
    """Resize an (H, W, C) image array to `resolution` and return a (1, 3, H, W) float tensor."""
    im = imresize(im, (resolution, resolution))[:, :, :3]  # skimage resize also rescales to floats in [0, 1]
    im = th.Tensor(im).permute(2, 0, 1)[None, :, :, :].contiguous()
    return im
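# e.g. get_pil_im(np.zeros((128, 128, 3))).shape == th.Size([1, 3, 64, 64])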
# generate image components and reconstruction
def gen_image_and_components(model, gd, im, num_components=4, sample_method='ddim', batch_size=1, image_size=64, device='cuda', num_images=1):
"""Generate row of orig image, individual components, and reconstructed image"""
orig_img = get_pil_im(im, resolution=image_size).to(device)
latent = model.encode_latent(orig_img)
model_kwargs = {'latent': latent}
assert sample_method in ('ddpm', 'ddim')
sample_loop_func = gd.p_sample_loop if sample_method == 'ddpm' else gd.ddim_sample_loop
if sample_method == 'ddim':
model = gd._wrap_model(model)
    # each generated row is [orig, components..., reconstruction]
    all_samples = []
    for i in range(num_images):
        row = [orig_img]
        # sample each individual component by indexing into the latent
        for j in range(num_components):
            model_kwargs['latent_index'] = j
            sample = sample_loop_func(
                model,
                (batch_size, 3, image_size, image_size),
                device=device,
                clip_denoised=True,
                progress=True,
                model_kwargs=model_kwargs,
                cond_fn=None,
            )[:batch_size]
            row.append(sample)
        # reconstruction from all components
        model_kwargs['latent_index'] = None
        sample = sample_loop_func(
            model,
            (batch_size, 3, image_size, image_size),
            device=device,
            clip_denoised=True,
            progress=True,
            model_kwargs=model_kwargs,
            cond_fn=None,
        )[:batch_size]
        row.append(sample)
        all_samples.extend(row)
    samples = th.cat(all_samples, dim=0).cpu()
    grid = make_grid(samples, nrow=num_components + 2, padding=0)
    return grid
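# Usage sketch (models and GD are defined below; the path is one of the bundled examples):
# im = np.array(Image.open('sample_images/clevr_im_10.png'))
# grid = gen_image_and_components(clevr_model, GD['ddim'], im, device=device)
# grid is a (3, 64, 64 * (num_components + 2)) tensor ready for display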
# load diffusion
GD = {} # diffusion objects for ddim and ddpm
diffusion_kwargs = diffusion_defaults()
gd = create_gaussian_diffusion(**diffusion_kwargs)
GD['ddpm'] = gd
# set up ddim sampling
desired_timesteps = 50
num_timesteps = diffusion_kwargs['steps']
spacing = num_timesteps // desired_timesteps
spaced_ts = list(range(0, num_timesteps + 1, spacing))
betas = get_named_beta_schedule(diffusion_kwargs['noise_schedule'], num_timesteps)
diffusion_kwargs['betas'] = betas
del diffusion_kwargs['steps'], diffusion_kwargs['noise_schedule']
gd = SpacedDiffusion(spaced_ts, rescale_timesteps=True, original_num_steps=num_timesteps, **diffusion_kwargs)
GD['ddim'] = gd
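# with the default 1000 steps, spaced_ts is [0, 20, ..., 1000]; assuming a
# guided-diffusion-style SpacedDiffusion, entries past the last valid timestep
# index (999) are never matched, so the desired 50 DDIM steps remain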
def combine_components_slice(model, gd, im1, im2, indices=None, sample_method='ddim', device='cuda', num_images=4, model_kwargs=None, desc='', save_dir='', dataset='clevr', image_size=64):
    """Combine two images by slicing their latents: where `indices` is 1 the
    component comes from im1, otherwise from im2."""
    assert sample_method in ('ddpm', 'ddim')
    if model_kwargs is None:  # avoid a shared mutable default
        model_kwargs = {}
    im1 = get_pil_im(im1, resolution=image_size).to(device)
    im2 = get_pil_im(im2, resolution=image_size).to(device)
    latent1 = model.encode_latent(im1)
    latent2 = model.encode_latent(im2)
    num_comps = model.num_components
    # build a boolean (num_comps, 1) mask selecting which components come from im1
    if indices is None:
        half = num_comps // 2
        indices = [1] * half + [0] * half  # first half from im1, second half from im2
        indices = th.Tensor(indices) == 1
        indices = indices.reshape(num_comps, 1)
    elif isinstance(indices, str):
        indices = [int(ind) for ind in indices.split(',')]  # e.g. '1,0,1,0'
        indices = th.Tensor(indices).reshape(-1, 1) == 1
        assert len(indices) == num_comps
    indices = indices.to(device)
latent1 = latent1.reshape(num_comps, -1).to(device)
latent2 = latent2.reshape(num_comps, -1).to(device)
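    # take each latent row from latent1 where the mask is True, else from latent2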
combined_latent = th.where(indices, latent1, latent2)
combined_latent = combined_latent.reshape(1, -1)
model_kwargs['latent'] = combined_latent
sample_loop_func = gd.p_sample_loop if sample_method == 'ddpm' else gd.ddim_sample_loop
if sample_method == 'ddim':
model = gd._wrap_model(model)
# sampling loop
sample = sample_loop_func(
model,
(1, 3, image_size, image_size),
device=device,
clip_denoised=True,
progress=True,
model_kwargs=model_kwargs,
cond_fn=None,
)[:1]
return sample[0].cpu()
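# Usage sketch: take components 0 and 2 from im1 and 1 and 3 from im2 (4-component model):
# out = combine_components_slice(clevr_model, GD['ddim'], im1, im2, indices='1,0,1,0', device=device)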
def decompose_image_demo(im, model):
sample_method = 'ddim'
result = gen_image_and_components(MODELS[model], GD[sample_method], im, sample_method=sample_method, num_images=1, device=device)
# result = Image.fromarray(result.permute(1, 2, 0).numpy())
return result.permute(1, 2, 0).numpy()
def combine_images_demo(im1, im2, model):
sample_method = 'ddim'
result = combine_components_slice(MODELS[model], GD[sample_method], im1, im2, indices='1,0,1,0', sample_method=sample_method, num_images=1, device=device)
result = result.permute(1, 2, 0).numpy()
# result = Image.fromarray(result.permute(1, 2, 0).numpy())
# if model == 'CelebA-HQ':
# return upscale_image(result, pipe)
return result
def load_model(dataset, extra_kwargs=None, device='cuda'):
    ckpt_path = download_model(dataset)
    model_kwargs = unet_model_defaults()
    # override default model parameters
    if extra_kwargs:
        model_kwargs.update(extra_kwargs)
model = create_diffusion_model(**model_kwargs)
model.eval()
model.to(device)
print(f'loading from {ckpt_path}')
checkpoint = th.load(ckpt_path, map_location='cpu')
model.load_state_dict(checkpoint)
return model
device = 'cuda' if th.cuda.is_available() else 'cpu'
clevr_model = load_model('clevr', extra_kwargs=dict(emb_dim=64, enc_channels=128), device=device)
celeb_model = load_model('celebahq', extra_kwargs=dict(enc_channels=128), device=device)
MODELS = {
'CLEVR': clevr_model,
'CelebA-HQ': celeb_model
}
# pipe = get_pipeline()
with gr.Blocks() as demo:
gr.Markdown(
"""<h1 style="text-align: center;"><b>Unsupervised Compositional Image Decomposition with Diffusion Models
</b> - <a href="https://jsu27.github.io/decomp-diffusion-web/">Project Page</a></h1>""")
    gr.Markdown(
        """<p style="font-size: 18px;">We introduce Decomp Diffusion, an unsupervised approach that discovers compositional concepts from images, each represented by a diffusion model.
        </p>""")
gr.Markdown(
"""<br> <h4>Decomposition and reconstruction of images</h4>""")
with gr.Row():
with gr.Column():
with gr.Row():
decomp_input = gr.Image(type='numpy', label='Input')
with gr.Row():
decomp_model = gr.Radio(
['CLEVR', 'CelebA-HQ'], type="value", label='Model',
value='CLEVR')
with gr.Row():
decomp_examples = [['sample_images/clevr_im_10.png', 'CLEVR'],
['sample_images/celebahq_im_15.jpg', 'CelebA-HQ']]
decomp_img_examples = gr.Examples(
examples=decomp_examples,
inputs=[decomp_input, decomp_model]
)
with gr.Column():
decomp_output = gr.Image(type='numpy')
decomp_button = gr.Button("Generate")
gr.Markdown(
"""<br> <h4>Combination of images</h4>""")
    with gr.Row(equal_height=True):  # .style() was removed in Gradio 4; equal_height is a constructor arg
with gr.Column(scale=2):
with gr.Row():
with gr.Column():
comb_input1 = gr.Image(type='numpy', label='Input 1')
with gr.Column():
comb_input2 = gr.Image(type='numpy', label='Input 2')
with gr.Row():
comb_model = gr.Radio(
['CLEVR', 'CelebA-HQ'], type="value", label='Model',
value='CLEVR')
with gr.Row():
comb_examples = [['sample_images/clevr_im_10.png', 'sample_images/clevr_im_25.png', 'CLEVR'],
['sample_images/celebahq_im_15.jpg', 'sample_images/celebahq_im_21.jpg', 'CelebA-HQ']]
comb_img_examples = gr.Examples(
examples=comb_examples,
inputs=[comb_input1, comb_input2, comb_model]
)
with gr.Column(scale=1):
comb_output = gr.Image(type='numpy')
comb_button = gr.Button("Generate")
decomp_button.click(decompose_image_demo,
inputs=[decomp_input, decomp_model],
outputs=decomp_output)
comb_button.click(combine_images_demo,
inputs=[comb_input1, comb_input2, comb_model],
outputs=comb_output)
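# for long-running GPU sampling, consider demo.queue().launch(debug=True) instead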
demo.launch(debug=True)