import os
import numpy as np
import torch as th
from imageio import imread
from skimage.transform import resize as imresize
from PIL import Image
from decomp_diffusion.model_and_diffusion_util import *
from decomp_diffusion.diffusion.respace import SpacedDiffusion
from decomp_diffusion.gen_image import *
from download import download_model
from upsampling import get_pipeline, upscale_image
import gradio as gr
from huggingface_hub import HfApi

HF_TOKEN = os.getenv('HF_TOKEN')
hf_api = HfApi(token=HF_TOKEN)

# fix randomness so repeated runs of the demo are reproducible
th.manual_seed(0)
np.random.seed(0)


def get_pil_im(im, resolution=64):
    """Convert a numpy image into a (1, 3, resolution, resolution) float tensor.

    The image is resized with skimage's ``resize`` (which, by default, rescales
    integer inputs to floats in [0, 1]); any channels beyond the first three
    (e.g. alpha) are dropped. Grayscale inputs (H x W or H x W x 1) are
    replicated to three channels.
    """
    im = np.asarray(im)
    if im.ndim == 2:
        im = im[:, :, None]
    if im.shape[-1] == 1:
        # grayscale: replicate so the RGB slice below yields three channels
        im = np.repeat(im, 3, axis=-1)
    im = imresize(im, (resolution, resolution))[:, :, :3]
    return th.Tensor(im).permute(2, 0, 1)[None, :, :, :].contiguous()


def _get_sample_loop(model, gd, sample_method):
    """Return ``(model, sampling_loop)`` for ``sample_method`` ('ddpm' or 'ddim').

    For DDIM the model is wrapped with ``gd._wrap_model`` and sampled with
    ``gd.ddim_sample_loop``; DDPM uses the bare model with ``gd.p_sample_loop``.

    Raises:
        ValueError: if ``sample_method`` is not 'ddpm' or 'ddim'.
    """
    if sample_method not in ('ddpm', 'ddim'):
        raise ValueError(f"sample_method must be 'ddpm' or 'ddim', got {sample_method!r}")
    if sample_method == 'ddpm':
        return model, gd.p_sample_loop
    return gd._wrap_model(model), gd.ddim_sample_loop


# generate image components and reconstruction
def gen_image_and_components(model, gd, im, num_components=4, sample_method='ddim', batch_size=1,
                             image_size=64, device='cuda', num_images=1):
    """Generate rows of [original image, each component, reconstruction].

    The input is encoded once with ``model.encode_latent``. Each component image
    is sampled with ``latent_index=j`` and the reconstruction with
    ``latent_index=None``. Each of the ``num_images`` rows is sampled
    independently; all rows are stacked into one grid (one row per sample run).

    Args:
        model: diffusion model exposing ``encode_latent``.
        gd: diffusion object providing ``p_sample_loop`` / ``ddim_sample_loop``.
        im: input image as a numpy array (see ``get_pil_im``).
        num_components: how many latent components to render.
        sample_method: 'ddpm' or 'ddim'.
        batch_size: samples drawn per sampling call.
        image_size: side length the input is resized to and samples are drawn at.
        device: torch device for encoding and sampling.
        num_images: number of independently sampled rows (must be >= 1).

    Returns:
        A CPU tensor grid produced by ``make_grid`` (C x H x W).

    Raises:
        ValueError: on an unknown ``sample_method`` or ``num_images < 1``.
    """
    if num_images < 1:
        raise ValueError(f'num_images must be >= 1, got {num_images}')

    orig_img = get_pil_im(im, resolution=image_size).to(device)
    # encode before any sampler wrapping, with the raw model
    latent = model.encode_latent(orig_img)
    sampler_model, sample_loop_func = _get_sample_loop(model, gd, sample_method)
    shape = (batch_size, 3, image_size, image_size)

    def _sample(latent_index):
        # latent_index=j renders component j alone; None uses all components (reconstruction)
        return sample_loop_func(
            sampler_model,
            shape,
            device=device,
            clip_denoised=True,
            progress=True,
            model_kwargs={'latent': latent, 'latent_index': latent_index},
            cond_fn=None,
        )[:batch_size]

    rows = []
    for _ in range(num_images):
        row = [orig_img]
        for j in range(num_components):
            row.append(_sample(j))
        row.append(_sample(None))
        rows.append(th.cat(row, dim=0))

    samples = th.cat(rows, dim=0).cpu()
    # one grid row per sampled row
    return make_grid(samples, nrow=rows[0].shape[0], padding=0)


# load diffusion
GD = {}  # diffusion objects for ddim and ddpm
diffusion_kwargs = diffusion_defaults()
gd = create_gaussian_diffusion(**diffusion_kwargs)
GD['ddpm'] = gd

# set up ddim sampling
desired_timesteps = 50
num_timesteps = diffusion_kwargs['steps']
spacing = num_timesteps // desired_timesteps
# NOTE(review): the range end is num_timesteps + 1, so num_timesteps itself is in the
# list although betas below cover only num_timesteps steps; confirm SpacedDiffusion
# ignores it (changing it could alter the respaced schedule, so it is left as-is).
spaced_ts = list(range(0, num_timesteps + 1, spacing))
betas = get_named_beta_schedule(diffusion_kwargs['noise_schedule'], num_timesteps)
diffusion_kwargs['betas'] = betas
# SpacedDiffusion takes explicit betas instead of steps/noise_schedule
del diffusion_kwargs['steps'], diffusion_kwargs['noise_schedule']
gd = SpacedDiffusion(spaced_ts, rescale_timesteps=True, original_num_steps=num_timesteps,
                     **diffusion_kwargs)
GD['ddim'] = gd


def _component_mask(indices, num_comps):
    """Build a (num_comps, 1) bool mask; True selects that component from the first image.

    ``indices`` may be None (first half from image 1, the rest from image 2),
    a comma-separated string of 0/1 flags such as '1,0,1,0', or any sequence /
    tensor of 0/1 (or bool) flags.

    Raises:
        ValueError: if the number of flags differs from ``num_comps``.
    """
    if indices is None:
        half = num_comps // 2
        # first half 1, remainder 0 (the remainder absorbs the middle one when odd)
        flags = [1] * half + [0] * (num_comps - half)
    elif isinstance(indices, str):
        flags = [int(ind) for ind in indices.split(',')]
    elif isinstance(indices, th.Tensor):
        flags = indices.flatten().tolist()
    else:
        flags = list(indices)
    if len(flags) != num_comps:
        raise ValueError(f'expected {num_comps} component flags, got {len(flags)}')
    return th.tensor([float(flag) for flag in flags]).reshape(-1, 1) == 1


def combine_components_slice(model, gd, im1, im2, indices=None, sample_method='ddim', device='cuda',
                             num_images=4, model_kwargs=None, desc='', save_dir='', dataset='clevr',
                             image_size=64):
    """Combine two images by mixing their per-component latents, then sample once.

    Both images are encoded with ``model.encode_latent``; the latents are
    reshaped to (model.num_components, -1) and each component is taken from
    ``im1`` where the mask from ``indices`` is True and from ``im2`` otherwise.

    Args:
        model: diffusion model exposing ``encode_latent`` and ``num_components``.
        gd: diffusion object providing the sampling loops.
        im1, im2: input images as numpy arrays (see ``get_pil_im``).
        indices: component selection; see ``_component_mask``.
        sample_method: 'ddpm' or 'ddim'.
        device: torch device for encoding and sampling.
        model_kwargs: extra kwargs for the sampler; copied, never mutated.
        num_images, desc, save_dir, dataset: accepted for compatibility; unused here.
        image_size: side length the inputs are resized to and the sample is drawn at.

    Returns:
        A CPU tensor of shape (3, image_size, image_size).

    Raises:
        ValueError: on an unknown ``sample_method`` or a mask of the wrong length.
    """
    # copy so neither the caller's dict nor a shared default gets mutated
    model_kwargs = {} if model_kwargs is None else dict(model_kwargs)

    im1 = get_pil_im(im1, resolution=image_size).to(device)
    im2 = get_pil_im(im2, resolution=image_size).to(device)
    latent1 = model.encode_latent(im1)
    latent2 = model.encode_latent(im2)
    num_comps = model.num_components

    mask = _component_mask(indices, num_comps).to(device)
    latent1 = latent1.reshape(num_comps, -1).to(device)
    latent2 = latent2.reshape(num_comps, -1).to(device)
    model_kwargs['latent'] = th.where(mask, latent1, latent2).reshape(1, -1)

    sampler_model, sample_loop_func = _get_sample_loop(model, gd, sample_method)
    sample = sample_loop_func(
        sampler_model,
        (1, 3, image_size, image_size),
        device=device,
        clip_denoised=True,
        progress=True,
        model_kwargs=model_kwargs,
        cond_fn=None,
    )[:1]
    return sample[0].cpu()


def decompose_image_demo(im, model):
    """Gradio callback: decompose ``im`` with the model named ``model``.

    Returns an H x W x C numpy array of [input, components..., reconstruction].
    """
    if im is None:
        raise gr.Error('Please provide an input image.')
    sample_method = 'ddim'
    selected = MODELS[model]
    result = gen_image_and_components(selected, GD[sample_method], im,
                                      num_components=selected.num_components,
                                      sample_method=sample_method, num_images=1, device=device)
    return result.permute(1, 2, 0).numpy()


def combine_images_demo(im1, im2, model):
    """Gradio callback: combine alternating components of ``im1`` and ``im2``.

    Components alternate between the inputs ('1,0,1,0,...'). CelebA-HQ results
    are upscaled with ``upscale_image``.
    """
    if im1 is None or im2 is None:
        raise gr.Error('Please provide both input images.')
    sample_method = 'ddim'
    selected = MODELS[model]
    indices = ','.join('1' if i % 2 == 0 else '0' for i in range(selected.num_components))
    result = combine_components_slice(selected, GD[sample_method], im1, im2, indices=indices,
                                      sample_method=sample_method, num_images=1, device=device)
    # NOTE(review): PIL's fromarray needs an integer (e.g. uint8) array for RGB images;
    # confirm the sampler's output dtype/range matches.
    result = Image.fromarray(result.permute(1, 2, 0).numpy())
    if model == 'CelebA-HQ':
        return upscale_image(result, pipe)
    return result


def load_model(dataset, extra_kwargs=None, device='cuda'):
    """Download the checkpoint for ``dataset``, build the model and load its weights.

    ``extra_kwargs`` overrides entries of ``unet_model_defaults()``. The model is
    put in eval mode and moved to ``device``; weights are loaded via the CPU.
    """
    ckpt_path = download_model(dataset)
    model_kwargs = unet_model_defaults()  # model parameters
    if extra_kwargs:
        model_kwargs.update(extra_kwargs)
    model = create_diffusion_model(**model_kwargs)
    model.eval()
    model.to(device)
    print(f'loading from {ckpt_path}')
    checkpoint = th.load(ckpt_path, map_location='cpu')
    model.load_state_dict(checkpoint)
    return model


device = 'cuda' if th.cuda.is_available() else 'cpu'
clevr_model = load_model('clevr', extra_kwargs=dict(emb_dim=64, enc_channels=128), device=device)
celeb_model = load_model('celebahq', extra_kwargs=dict(enc_channels=128), device=device)
MODELS = {
    'CLEVR': clevr_model,
    'CelebA-HQ': celeb_model,
}
pipe = get_pipeline()

with gr.Blocks() as demo:
    gr.Markdown(
        """

Unsupervised Compositional Image Decomposition with Diffusion Models - Project Page

""")
    gr.Markdown(
        """

We introduce Decomp Diffusion, an unsupervised approach that discovers compositional concepts from images, represented by diffusion models.

""")
    gr.Markdown(
        """

Decomposition and reconstruction of images

""")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                decomp_input = gr.Image(type='numpy', label='Input')
            with gr.Row():
                decomp_model = gr.Radio(
                    ['CLEVR', 'CelebA-HQ'], type="value", label='Model', value='CLEVR')
            with gr.Row():
                decomp_examples = [['sample_images/clevr_im_10.png', 'CLEVR'],
                                   ['sample_images/celebahq_im_15.jpg', 'CelebA-HQ']]
                decomp_img_examples = gr.Examples(
                    examples=decomp_examples,
                    inputs=[decomp_input, decomp_model],
                )
        with gr.Column():
            decomp_output = gr.Image(type='numpy')
            decomp_button = gr.Button("Generate")
    gr.Markdown(
        """

Combination of images

""")
    with gr.Row().style(equal_height=True):
        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column():
                    comb_input1 = gr.Image(type='numpy', label='Input 1')
                with gr.Column():
                    comb_input2 = gr.Image(type='numpy', label='Input 2')
            with gr.Row():
                comb_model = gr.Radio(
                    ['CLEVR', 'CelebA-HQ'], type="value", label='Model', value='CLEVR')
            with gr.Row():
                comb_examples = [['sample_images/clevr_im_10.png', 'sample_images/clevr_im_25.png', 'CLEVR'],
                                 ['sample_images/celebahq_im_15.jpg', 'sample_images/celebahq_im_21.jpg', 'CelebA-HQ']]
                comb_img_examples = gr.Examples(
                    examples=comb_examples,
                    inputs=[comb_input1, comb_input2, comb_model],
                )
        with gr.Column(scale=1):
            comb_output = gr.Image(type='pil')
            comb_button = gr.Button("Generate")

    decomp_button.click(decompose_image_demo, inputs=[decomp_input, decomp_model], outputs=decomp_output)
    comb_button.click(combine_images_demo, inputs=[comb_input1, comb_input2, comb_model], outputs=comb_output)

demo.launch(debug=True)