"""Gradio demo for 3DTopia-XL: single image -> 3D PBR asset.

At import time this script runs ``install.sh``, downloads the DiT and VAE
checkpoints from the Hugging Face Hub, builds the networks and launches a
Gradio app whose ``process`` callback renders videos and exports a GLB mesh.
"""
import os
import imageio
import numpy as np

# Run before the remaining imports; presumably those depend on packages that
# install.sh sets up (NOTE(review): confirm against install.sh).
os.system("bash install.sh")

from omegaconf import OmegaConf
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import rembg
import gradio as gr

from dva.io import load_from_config
from dva.ray_marcher import RayMarcher
from dva.visualize import visualize_primvolume, visualize_video_primvolume
from inference import remove_background, resize_foreground, extract_texmesh
from models.diffusion import create_diffusion
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="frozenburning/3DTopia-XL", filename="model_sview_dit_fp16.pt")
vae_ckpt_path = hf_hub_download(repo_id="frozenburning/3DTopia-XL", filename="model_vae_fp16.pt")

GRADIO_PRIM_VIDEO_PATH = 'prim.mp4'
GRADIO_RGB_VIDEO_PATH = 'rgb.mp4'
GRADIO_MAT_VIDEO_PATH = 'mat.mp4'
GRADIO_GLB_PATH = 'pbr_mesh.glb'
CONFIG_PATH = "./configs/inference_dit.yml"

config = OmegaConf.load(CONFIG_PATH)
config.checkpoint_path = ckpt_path
config.model.vae_checkpoint_path = vae_ckpt_path

# model
model = load_from_config(config.model.generator)
state_dict = torch.load(config.checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict['ema'])
vae = load_from_config(config.model.vae)
vae_state_dict = torch.load(config.model.vae_checkpoint_path, map_location='cpu')
vae.load_state_dict(vae_state_dict['model_state_dict'])
conditioner = load_from_config(config.model.conditioner)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vae = vae.to(device)
conditioner = conditioner.to(device)
model = model.to(device)
# Every network here is used for inference only, so switch all of them (not
# just the generator) out of train mode.
model.eval()
vae.eval()
conditioner.eval()

amp = True
precision_dtype = torch.float16

rm = RayMarcher(
    config.image_height,
    config.image_width,
    **config.rm,
).to(device)

# `process` needs latent_nf, but the key is popped from config.model below,
# so capture it while it still exists.
latent_nf = config.model.latent_nf

perchannel_norm = False
if "latent_mean" in config.model:
    latent_mean = torch.Tensor(config.model.latent_mean)[None, None, :].to(device)
    latent_std = torch.Tensor(config.model.latent_std)[None, None, :].to(device)
    if latent_mean.shape[-1] != config.model.generator.in_channels:
        raise ValueError(
            "latent_mean has {} channels but generator.in_channels is {}".format(
                latent_mean.shape[-1], config.model.generator.in_channels
            )
        )
    perchannel_norm = True

# Strip the diffusion/latent-specific entries so the remaining config.model can
# be used to build model_primx (the model used for GLB export in `process`).
# Defaults keep this working when optional keys (latent_mean/std) are absent.
config.diffusion.pop("timestep_respacing", None)  # set per request in `process`
for _key in ("vae", "vae_checkpoint_path", "conditioner", "generator",
             "latent_nf", "latent_mean", "latent_std"):
    config.model.pop(_key, None)
model_primx = load_from_config(config.model)

# load rembg
rembg_session = rembg.new_session()


# process function
def process(input_image, input_num_steps=25, input_seed=42, input_cfg=6.0):
    """Generate videos and a GLB mesh from a single conditioning image.

    Args:
        input_image: PIL image from the ``gr.Image(type='pil')`` input.
            ``None`` (text-only conditioning) is not supported.
        input_num_steps: number of DDIM sampling steps.
        input_seed: seed passed to ``torch.manual_seed``.
        input_cfg: guidance scale, forwarded as ``cfg_scale`` when >= 0.

    Returns:
        ``(rgb_video_path, prim_video_path, mat_video_path, glb_path)``, all
        joined under ``config.output_dir``.

    Raises:
        NotImplementedError: if ``input_image`` is ``None``.
    """
    # seed
    torch.manual_seed(input_seed)

    os.makedirs(config.output_dir, exist_ok=True)
    output_rgb_video_path = os.path.join(config.output_dir, GRADIO_RGB_VIDEO_PATH)
    output_prim_video_path = os.path.join(config.output_dir, GRADIO_PRIM_VIDEO_PATH)
    output_mat_video_path = os.path.join(config.output_dir, GRADIO_MAT_VIDEO_PATH)
    output_glb_path = os.path.join(config.output_dir, GRADIO_GLB_PATH)
    denoise_param_path = os.path.join(config.output_dir, 'denoised.pt')

    # The slider value may arrive as a float; "ddim25.0" is not a valid
    # respacing spec, so force an integer step count.
    respacing = "ddim{}".format(int(input_num_steps))
    diffusion = create_diffusion(timestep_respacing=respacing, **config.diffusion)
    sample_fn = diffusion.ddim_sample_loop_progressive
    fwd_fn = model.forward_with_cfg

    # text-conditioned
    if input_image is None:
        raise NotImplementedError(
            "Text-conditioned generation is not supported yet; please provide an input image."
        )

    # image-conditioned: cut out the foreground and zero the background pixels
    input_image = remove_background(input_image, rembg_session)
    input_image = resize_foreground(input_image, 0.85)
    raw_image = np.array(input_image)
    # Last channel (presumably the alpha produced by remove_background) > 0 marks foreground.
    mask = (raw_image[..., -1][..., None] > 0) * 1
    raw_image = raw_image[..., :3] * mask
    input_cond = torch.from_numpy(raw_image[None, ...]).to(device)

    with torch.no_grad():
        # Only latent.shape[-4:] == (1, 4, 4, 4) is used (to reshape VAE latents),
        # but the draw advances the torch RNG, so keep it to preserve per-seed results.
        latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4)
        batch = {}
        inf_bs = 1
        inf_x = torch.randn(inf_bs, config.model.num_prims, 68).to(device)
        y = conditioner.encoder(input_cond)
        model_kwargs = dict(y=y[:inf_bs, ...], precision_dtype=precision_dtype, enable_amp=amp)
        if input_cfg >= 0:
            model_kwargs['cfg_scale'] = input_cfg

        # Progressive sampler: keep only the last yielded state.
        for samples in sample_fn(fwd_fn, inf_x.shape, inf_x, clip_denoised=False,
                                 model_kwargs=model_kwargs, progress=True, device=device):
            final_samples = samples

        recon_param = final_samples["sample"].reshape(inf_bs, config.model.num_prims, -1)
        if perchannel_norm:
            recon_param = recon_param / latent_nf * latent_std + latent_mean
        recon_srt_param = recon_param[:, :, 0:4]
        # (inf_bs, num_prims, 64) given the 68-channel noise drawn above
        recon_feat_param = recon_param[:, :, 4:]

        # Decode one batch element at a time to avoid OOM.
        recon_feat_param_list = []
        for inf_bidx in range(inf_bs):
            feat = recon_feat_param[inf_bidx, ...].reshape(
                1 * config.model.num_prims, *latent.shape[-4:]
            )
            if not perchannel_norm:
                feat = feat / latent_nf
            decoded = vae.decode(feat)
            recon_feat_param_list.append(decoded.detach())
        recon_feat_param = torch.concat(recon_feat_param_list, dim=0)

        # invert normalization
        if not perchannel_norm:
            recon_srt_param[:, :, 0:1] = (recon_srt_param[:, :, 0:1] / 10) + 0.05
        recon_feat_param[:, 0:1, ...] /= 5.
        recon_feat_param[:, 1:, ...] = (recon_feat_param[:, 1:, ...] + 1) / 2.
        recon_feat_param = recon_feat_param.reshape(inf_bs, config.model.num_prims, -1)
        recon_param = torch.concat([recon_srt_param, recon_feat_param], dim=-1)

        visualize_video_primvolume(config.output_dir, batch, recon_param, 60, rm, device)
        prim_params = {
            'srt_param': recon_srt_param[0].detach().cpu(),
            'feat_param': recon_feat_param[0].detach().cpu(),
        }
        torch.save({'model_state_dict': prim_params}, denoise_param_path)

    # exporting GLB mesh
    primx_ckpt_weight = torch.load(denoise_param_path, map_location='cpu')['model_state_dict']
    # Previously referenced an undefined name (`ckpt_weight`) -> NameError.
    model_primx.load_state_dict(primx_ckpt_weight)
    model_primx.to(device)
    model_primx.eval()
    with torch.no_grad():
        # In-place edit of a parameter must happen without autograd tracking.
        # NOTE(review): channels 1:4 of srt_param are scaled by 0.85 before
        # meshing; their exact meaning is not visible here.
        model_primx.srt_param[:, 1:4] *= 0.85
    extract_texmesh(config.inference, model_primx, output_glb_path, device)

    return output_rgb_video_path, output_prim_video_path, output_mat_video_path, output_glb_path


# gradio UI
_TITLE = '''3DTopia-XL'''

_DESCRIPTION = '''
* Now we offer 1) single image conditioned model, we will release 2) multiview images conditioned model and 3) pure text conditioned model in the future!
* If you find the output unsatisfying, try using different seeds!
'''

block = gr.Blocks(title=_TITLE).queue()
with block:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + _TITLE)
            gr.Markdown(_DESCRIPTION)

    with gr.Row(variant='panel'):
        with gr.Column(scale=1):
            # input image
            input_image = gr.Image(label="image", type='pil')
            # inference steps
            input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=25)
            # classifier-free guidance scale
            input_cfg = gr.Slider(label="CFG scale", minimum=0, maximum=15, step=1, value=6)
            # random seed
            input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=42)
            # gen button
            button_gen = gr.Button("Generate")

        with gr.Column(scale=1):
            with gr.Tab("Video"):
                # final video results (order matches process() return: rgb, prim, mat)
                output_rgb_video = gr.Video(label="video")
                output_prim_video = gr.Video(label="video")
                output_mat_video = gr.Video(label="video")
            with gr.Tab("GLB"):
                # glb file
                output_glb = gr.File(label="glb")

        # Input order must match process(input_image, input_num_steps, input_seed, input_cfg).
        button_gen.click(
            process,
            inputs=[input_image, input_num_steps, input_seed, input_cfg],
            outputs=[output_rgb_video, output_prim_video, output_mat_video, output_glb],
        )

    gr.Examples(
        examples=[
            "assets/examples/fruit_elephant.jpg",
            "assets/examples/mei_ling_panda.png",
            "assets/examples/shuai_panda_notail.png",
        ],
        inputs=[input_image],
        outputs=[output_rgb_video, output_prim_video, output_mat_video, output_glb],
        fn=lambda x: process(input_image=x),
        cache_examples=False,
        label='Single Image to 3D PBR Asset',
    )

block.launch(server_name="0.0.0.0", share=True)