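# DI-PCG / scripts/sample_diffusion.py
# Hugging Face file-viewer header (not part of the program):
#   uploader: thuzhaowang | commit: "init" (b6a9b6d) | size: 5.26 kB
#   viewer links: raw, history, blame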
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Sample new images from a pre-trained DiT.
"""
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import argparse
import yaml
import json
import numpy as np
from pathlib import Path
import gin
import importlib
import logging
import cv2
from huggingface_hub import hf_hub_download
logging.basicConfig(
format="[%(asctime)s.%(msecs)03d] [%(module)s] [%(levelname)s] | %(message)s",
datefmt="%H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from core.diffusion import create_diffusion
from core.models import DiT_models
from core.utils.train_utils import load_model
from core.utils.math_utils import unnormalize_params
from scripts.prepare_data import generate
from core.utils.dinov2 import Dinov2Model
def main(cfg, generator):
    """Sample generator parameters for each condition image and build the asset.

    Every entry of ``cfg["condition_img_dir"]`` (in sorted order) is read with
    OpenCV, resized to 256x256, encoded by ``Dinov2Model`` and used as the
    condition ``y`` of the diffusion sampling loop. The sampled vector is
    passed through ``unnormalize_params`` together with
    ``generator.params_dict``, written as JSON to
    ``<save_dir>/<name>_params.txt`` and finally handed to ``generate``.

    Entries that OpenCV cannot decode (sub-directories, non-image files) are
    reported and skipped instead of aborting the whole run.

    Args:
        cfg: Configuration mapping. Keys read here: ``"seed"``,
            ``"num_params"``, ``"model"``, ``"ckpt_path"``,
            ``"num_sampling_steps"``, ``"condition_img_dir"``, ``"save_dir"``,
            ``"r_cam_dists"``, ``"r_cam_elevations"``, ``"r_cam_azimuths"``
            and ``"r_zoff"``.
        generator: Procedural generator instance. Its ``params_dict``
            attribute is read, and the object itself is forwarded to
            ``generate``.
    """
    # Setup PyTorch:
    torch.manual_seed(cfg["seed"])
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model:
    latent_size = cfg["num_params"]
    model = DiT_models[cfg["model"]](input_size=latent_size).to(device)

    # Load a custom DiT checkpoint from train.py; download it if not found.
    ckpt_path = cfg["ckpt_path"]
    if not os.path.exists(ckpt_path):
        # A bare file name has an empty dirname; fall back to the current
        # directory so that os.makedirs does not fail on "".
        model_dir = os.path.dirname(ckpt_path) or "."
        model_name = os.path.basename(ckpt_path)
        os.makedirs(model_dir, exist_ok=True)
        # Announce the download before it starts, not after it has finished.
        print("Downloading checkpoint {} from Hugging Face Hub...".format(model_name))
        hf_hub_download(repo_id="TencentARC/DI-PCG",
                        local_dir=model_dir, filename=model_name)
    print("Loading model from {}".format(ckpt_path))
    state_dict = load_model(ckpt_path)
    model.load_state_dict(state_dict)
    model.eval()  # important!
    diffusion = create_diffusion(str(cfg["num_sampling_steps"]))

    # Feature model used to encode the condition images.
    feature_model = Dinov2Model()

    img_dir = cfg["condition_img_dir"]
    save_dir = cfg["save_dir"]
    for name in sorted(os.listdir(img_dir)):
        img_path = os.path.join(img_dir, name)

        # cv2.imread returns None (it does not raise) when the path is not a
        # decodable image; skip such entries with a clear message.
        bgr = cv2.imread(img_path)
        if bgr is None:
            print("Skipping {}: not a readable image".format(img_path))
            continue

        # Load condition image and extract features.
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        # Pre-process: resize to 256x256.
        img = cv2.resize(img, (256, 256))
        img = np.array(img).astype(np.uint8)
        img_feat = feature_model.encode_batch_imgs([img], global_feat=False)
        if len(img_feat.shape) == 2:
            # Insert an extra axis at position 1 when the encoder returns a
            # 2-D tensor.
            img_feat = img_feat.unsqueeze(1)

        # Create sampling noise:
        z = torch.randn(1, 1, latent_size, device=device)
        # No classifier-free guidance:
        model_kwargs = dict(y=img_feat)

        # Sample target params:
        samples = diffusion.p_sample_loop(
            model.forward, z.shape, z, clip_denoised=False,
            model_kwargs=model_kwargs, progress=True, device=device
        )
        samples = samples[0].squeeze(0).cpu().numpy()

        # Unnormalize params.
        params_dict = generator.params_dict
        params_original = unnormalize_params(samples, params_dict)

        # Save params; the context manager guarantees the file is flushed and
        # closed before the generator runs.
        params_path = os.path.join(save_dir, "{}_params.txt".format(name))
        with open(params_path, "w") as params_file:
            json.dump(params_original, params_file, default=str)

        # Generate 3D using sampled params.
        generate(generator, params_original, seed=cfg["seed"], save_dir=save_dir, save_name=name,
                 save_blend=True, save_img=True, save_untexture_img=True, save_gif=False, save_mesh=True,
                 cam_dists=cfg["r_cam_dists"], cam_elevations=cfg["r_cam_elevations"],
                 cam_azimuths=cfg["r_cam_azimuths"], zoff=cfg["r_zoff"],
                 resolution='720x720', sample=200)
        print("Generating model using sampled parameters. Saved in {}".format(save_dir))
def parse_bool_flag(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    because any non-empty string is truthy. This parser maps the usual
    spellings explicitly instead.

    Args:
        value: The raw token from the command line, or an actual ``bool``
            (argparse passes ``const``/``default`` values through unchanged).

    Returns:
        ``True`` for "true", "t", "yes", "y", "1"; ``False`` for "false",
        "f", "no", "n", "0" and the empty string (case-insensitive,
        surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value)
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    # Accepts `--remove_bg True/False` as before (now parsed correctly) and
    # additionally a bare `--remove_bg`, which means True.
    parser.add_argument("--remove_bg", type=parse_bool_flag, nargs="?",
                        const=True, default=False)
    args = parser.parse_args()

    with open(args.config) as f:
        # NOTE(review): FullLoader is not intended for untrusted input; if
        # configs can come from outside the project, consider yaml.safe_load.
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    cfg["remove_bg"] = args.remove_bg

    # Load the Blender procedural generator.
    OBJECTS_PATH = Path(cfg["generator_root"])
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not OBJECTS_PATH.exists():
        raise FileNotFoundError(
            "generator_root does not exist: {}".format(OBJECTS_PATH)
        )

    # Import each entry of the generator root as a `core.assets.<name>` module
    # and stop at the first one that defines the requested generator class.
    generator = None
    for subdir in sorted(OBJECTS_PATH.iterdir()):
        clsname = subdir.name.split(".")[0].strip()
        with gin.unlock_config():
            module = importlib.import_module(f"core.assets.{clsname}")
        if hasattr(module, cfg["generator"]):
            generator = getattr(module, cfg["generator"])
            logger.info("Found {} in {}".format(cfg["generator"], subdir))
            break
        logger.debug("{} not found in {}".format(cfg["generator"], subdir))
    if generator is None:
        raise ModuleNotFoundError("{} not Found.".format(cfg["generator"]))
    gen = generator(cfg["seed"])

    # Create the output directory.
    os.makedirs(cfg["save_dir"], exist_ok=True)
    main(cfg, gen)