Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Sample procedural-generator parameters from a pre-trained DiT conditioned on
input images, then generate the corresponding 3D assets.
""" | |
import os | |
import sys | |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
import argparse | |
import yaml | |
import json | |
import numpy as np | |
from pathlib import Path | |
import gin | |
import importlib | |
import logging | |
import cv2 | |
from huggingface_hub import hf_hub_download | |
# Log lines carry a millisecond timestamp, the emitting module and the level.
_LOG_FORMAT = "[%(asctime)s.%(msecs)03d] [%(module)s] [%(levelname)s] | %(message)s"
logging.basicConfig(level=logging.INFO, datefmt="%H:%M:%S", format=_LOG_FORMAT)
logger = logging.getLogger(__name__)

import torch

# Opt in to TF32 for both cuDNN ops and CUDA matmuls.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
from core.diffusion import create_diffusion | |
from core.models import DiT_models | |
from core.utils.train_utils import load_model | |
from core.utils.math_utils import unnormalize_params | |
from scripts.prepare_data import generate | |
from core.utils.dinov2 import Dinov2Model | |
def _ensure_checkpoint(ckpt_path):
    """Fetch the checkpoint from the Hugging Face Hub when it is missing locally.

    Args:
        ckpt_path: Path where the checkpoint file is expected on disk.
    """
    if os.path.exists(ckpt_path):
        return
    model_dir, model_name = os.path.split(ckpt_path)
    # A bare file name has an empty directory part, which os.makedirs rejects;
    # fall back to the current directory in that case.
    model_dir = model_dir or "."
    os.makedirs(model_dir, exist_ok=True)
    # Announce the download before it starts, not after it has finished.
    print("Downloading checkpoint {} from Hugging Face Hub...".format(model_name))
    hf_hub_download(repo_id="TencentARC/DI-PCG", local_dir=model_dir, filename=model_name)


def _load_condition_image(img_path):
    """Read an image file as a 256x256 RGB uint8 array.

    Args:
        img_path: Path of the image file to read.

    Returns:
        The resized RGB image, or None when OpenCV could not decode the file.
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        return None
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # pre-process: resize to 256x256
    rgb = cv2.resize(rgb, (256, 256))
    return np.array(rgb).astype(np.uint8)


def main(cfg, generator):
    """Sample generator parameters for every condition image and build the assets.

    For each file in ``cfg["condition_img_dir"]`` the image is encoded with
    ``Dinov2Model``, a parameter vector is sampled with the diffusion model
    conditioned on those features, the vector is un-normalized with
    ``generator.params_dict``, written to ``<save_dir>/<name>_params.txt`` and
    passed to ``generate`` to produce the asset files.

    Args:
        cfg: Configuration mapping. Keys read here: ``seed``, ``num_params``,
            ``model``, ``ckpt_path``, ``num_sampling_steps``,
            ``condition_img_dir``, ``save_dir``, ``r_cam_dists``,
            ``r_cam_elevations``, ``r_cam_azimuths`` and ``r_zoff``.
        generator: Procedural generator instance; must expose ``params_dict``.
    """
    # Setup PyTorch:
    torch.manual_seed(cfg["seed"])
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model:
    latent_size = cfg["num_params"]
    model = DiT_models[cfg["model"]](input_size=latent_size).to(device)
    # Load a custom DiT checkpoint, downloading it first if not found.
    _ensure_checkpoint(cfg["ckpt_path"])
    print("Loading model from {}".format(cfg["ckpt_path"]))
    state_dict = load_model(cfg["ckpt_path"])
    model.load_state_dict(state_dict)
    model.eval()  # important!
    diffusion = create_diffusion(str(cfg["num_sampling_steps"]))
    # feature model
    feature_model = Dinov2Model()

    # Make sure the output directory exists even when main() is called directly.
    os.makedirs(cfg["save_dir"], exist_ok=True)

    for name in sorted(os.listdir(cfg["condition_img_dir"])):
        img_path = os.path.join(cfg["condition_img_dir"], name)
        # Load condition image and extract features
        img = _load_condition_image(img_path)
        if img is None:
            # Directory entries that are not decodable images are skipped
            # instead of aborting the whole run.
            logger.warning("Skipping %s: not a readable image.", img_path)
            continue
        img_feat = feature_model.encode_batch_imgs([img], global_feat=False)
        if len(img_feat.shape) == 2:
            # Insert an axis at position 1 when the encoder returns a 2-D result.
            img_feat = img_feat.unsqueeze(1)

        # Create sampling noise:
        z = torch.randn(1, 1, latent_size, device=device)
        # No classifier-free guidance:
        model_kwargs = dict(y=img_feat)

        # Sample target params:
        samples = diffusion.p_sample_loop(
            model.forward, z.shape, z, clip_denoised=False,
            model_kwargs=model_kwargs, progress=True, device=device,
        )
        samples = samples[0].squeeze(0).cpu().numpy()

        # unnormalize params
        params_original = unnormalize_params(samples, generator.params_dict)

        # save params; the context manager closes the file deterministically
        params_path = "{}/{}_params.txt".format(cfg["save_dir"], name)
        with open(params_path, "w") as params_file:
            json.dump(params_original, params_file, default=str)

        # generate 3D using sampled params
        generate(
            generator, params_original, seed=cfg["seed"],
            save_dir=cfg["save_dir"], save_name=name,
            save_blend=True, save_img=True, save_untexture_img=True,
            save_gif=False, save_mesh=True,
            cam_dists=cfg["r_cam_dists"], cam_elevations=cfg["r_cam_elevations"],
            cam_azimuths=cfg["r_cam_azimuths"], zoff=cfg["r_zoff"],
            resolution='720x720', sample=200,
        )
        print("Generating model using sampled parameters. Saved in {}".format(cfg["save_dir"]))
def str2bool(value):
    """Convert a command-line token into a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True; this parser interprets the text instead.

    Args:
        value: A bool, or a string such as "true"/"false", "yes"/"no", "1"/"0".

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays False, as it was with type=bool.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value, got {!r}.".format(value))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    # `--remove_bg` alone enables the flag; `--remove_bg False` now really
    # yields False.
    parser.add_argument("--remove_bg", type=str2bool, nargs="?", const=True, default=False)
    args = parser.parse_args()
    with open(args.config) as f:
        # NOTE(review): FullLoader is kept so existing configs keep loading;
        # only point --config at trusted files, or switch to yaml.safe_load
        # once the configs are confirmed to use plain YAML types.
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    cfg["remove_bg"] = args.remove_bg

    # load the Blender procedural generator
    OBJECTS_PATH = Path(cfg["generator_root"])
    if not OBJECTS_PATH.exists():
        # Explicit raise: an assert would be stripped under `python -O`.
        raise FileNotFoundError(OBJECTS_PATH)

    generator = None
    for subdir in sorted(OBJECTS_PATH.iterdir()):
        # The entry name up to its first "." is used as the module name.
        clsname = subdir.name.split(".")[0].strip()
        with gin.unlock_config():
            module = importlib.import_module(f"core.assets.{clsname}")
        if hasattr(module, cfg["generator"]):
            generator = getattr(module, cfg["generator"])
            logger.info("Found {} in {}".format(cfg["generator"], subdir))
            break
        logger.debug("{} not found in {}".format(cfg["generator"], subdir))
    if generator is None:
        raise ModuleNotFoundError("{} not Found.".format(cfg["generator"]))
    gen = generator(cfg["seed"])

    # create visualize dir
    os.makedirs(cfg["save_dir"], exist_ok=True)
    main(cfg, gen)