# Source: dreamgaussian2 / main2.py (commit 12b7f59 "init", uploaded by jiawei011, 23.7 kB)
import os
import cv2
import time
import tqdm
import numpy as np
import dearpygui.dearpygui as dpg
import torch
import torch.nn.functional as F
import trimesh
import rembg
from cam_utils import orbit_camera, OrbitCamera
from mesh_renderer import Renderer
# from kiui.lpips import LPIPS
class GUI:
    """Viewer and trainer that refines a mesh held by ``Renderer``.

    With ``opt.gui`` enabled, a Dear PyGui window shows the current render and
    exposes controls (seed, input image, prompts, training, render mode, FoV).
    Without it, ``train`` runs a fixed number of steps and exports the mesh.
    """

    def __init__(self, opt):
        """Build the camera, renderer and (optionally) the Dear PyGui UI.

        Args:
            opt: configuration object. The attributes read in this class are
                ``gui``, ``W``, ``H``, ``radius``, ``fovy``, ``input``,
                ``prompt``, ``elevation``, ``lambda_sd``, ``lambda_zero123``,
                ``ref_size``, ``iters_refine``, ``batch_size``, ``outdir``,
                ``save_path`` and ``mesh_format``.
        """
        self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.gui = opt.gui  # enable gui
        self.W = opt.W
        self.H = opt.H
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)

        self.mode = "image"
        self.seed = "random"
        self.last_seed = None  # set by seed_everything()

        # Placeholder frame, laid out rows-first like the frames produced in
        # test_step ([H, W, 3]).
        self.buffer_image = np.ones((self.H, self.W, 3), dtype=np.float32)
        self.need_update = True  # update buffer_image

        # models
        self.device = torch.device("cuda")
        self.bg_remover = None

        self.guidance_sd = None
        self.guidance_zero123 = None

        self.enable_sd = False
        self.enable_zero123 = False

        # renderer
        self.renderer = Renderer(opt).to(self.device)

        # input image
        self.input_img = None
        self.input_mask = None
        self.input_img_torch = None
        self.input_mask_torch = None
        self.overlay_input_img = False
        self.overlay_input_img_ratio = 0.5

        # input text
        self.prompt = ""
        self.negative_prompt = ""

        # training stuff
        self.training = False
        self.optimizer = None
        self.step = 0
        self.train_steps = 1  # steps per rendering loop
        # self.lpips_loss = LPIPS(net='vgg').to(self.device)

        # load input data from cmdline
        if self.opt.input is not None:
            self.load_input(self.opt.input)

        # override prompt from cmdline
        if self.opt.prompt is not None:
            self.prompt = self.opt.prompt

        if self.gui:
            dpg.create_context()
            self.register_dpg()
            self.test_step()

    def __del__(self):
        # __init__ may have failed before self.gui was assigned, so do not
        # assume the attribute exists.
        if getattr(self, "gui", False):
            dpg.destroy_context()

    def seed_everything(self):
        """Seed Python hashing, NumPy and torch from ``self.seed``.

        ``self.seed`` is parsed as an integer; anything that cannot be parsed
        (for example the default ``"random"``) selects a random seed in
        ``[0, 1000000)``. The seed actually used is stored in
        ``self.last_seed``.
        """
        try:
            seed = int(self.seed)
        except (TypeError, ValueError):
            seed = np.random.randint(0, 1000000)

        os.environ["PYTHONHASHSEED"] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        # Benchmark mode lets cuDNN choose algorithms by timing them, which can
        # differ between runs; disable it so that seeding is reproducible.
        torch.backends.cudnn.benchmark = False

        self.last_seed = seed

    def prepare_train(self):
        """Reset the step counter and set up optimizer, guidance and inputs."""
        self.step = 0

        # setup training
        self.optimizer = torch.optim.Adam(self.renderer.get_params())

        # default camera
        pose = orbit_camera(self.opt.elevation, 0, self.opt.radius)
        self.fixed_cam = (pose, self.cam.perspective)

        self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != ""
        self.enable_zero123 = self.opt.lambda_zero123 > 0 and self.input_img is not None

        # lazy load guidance model
        if self.guidance_sd is None and self.enable_sd:
            print("[INFO] loading SD...")
            from guidance.sd_utils import StableDiffusion
            self.guidance_sd = StableDiffusion(self.device)
            print("[INFO] loaded SD!")

        if self.guidance_zero123 is None and self.enable_zero123:
            print("[INFO] loading zero123...")
            from guidance.zero123_utils import Zero123
            self.guidance_zero123 = Zero123(self.device)
            print("[INFO] loaded zero123!")

        # input image: convert to [1, C, ref_size, ref_size] tensors on device
        if self.input_img is not None:
            ref_hw = (self.opt.ref_size, self.opt.ref_size)

            self.input_img_torch = torch.from_numpy(self.input_img).permute(2, 0, 1).unsqueeze(0).to(self.device)
            self.input_img_torch = F.interpolate(
                self.input_img_torch, ref_hw, mode="bilinear", align_corners=False
            )

            self.input_mask_torch = torch.from_numpy(self.input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device)
            self.input_mask_torch = F.interpolate(
                self.input_mask_torch, ref_hw, mode="bilinear", align_corners=False
            )

            # channel-last copy, compared against the renderer output in train_step
            self.input_img_torch_channel_last = self.input_img_torch[0].permute(1, 2, 0).contiguous()

        # prepare embeddings
        with torch.no_grad():
            if self.enable_sd:
                self.guidance_sd.get_text_embeds([self.prompt], [self.negative_prompt])

            if self.enable_zero123:
                self.guidance_zero123.get_img_embeds(self.input_img_torch)

    def train_step(self):
        """Run ``self.train_steps`` optimisation steps and refresh the log.

        Each step adds an MSE loss against the reference image (when one is
        loaded) and, for every enabled guidance model, an MSE loss between
        randomly posed renders and their refined versions.
        """
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        loss = 0
        for _ in range(self.train_steps):
            self.step += 1
            step_ratio = min(1, self.step / self.opt.iters_refine)

            loss = 0

            ### known view
            if self.input_img_torch is not None:
                ssaa = min(2.0, max(0.125, 2 * np.random.random()))
                out = self.renderer.render(*self.fixed_cam, self.opt.ref_size, self.opt.ref_size, ssaa=ssaa)

                # rgb loss
                image = out["image"]  # [H, W, 3] in [0, 1]
                valid_mask = ((out["alpha"] > 0) & (out["viewcos"] > 0.5)).detach()
                loss = loss + F.mse_loss(image * valid_mask, self.input_img_torch_channel_last * valid_mask)

            ### novel view (manual batch)
            render_resolution = 512
            images = []
            vers, hors, radii = [], [], []
            # avoid too large elevation (> 80 or < -80), and make sure it always cover [-30, 30]
            min_ver = max(min(-30, -30 - self.opt.elevation), -80 - self.opt.elevation)
            max_ver = min(max(30, 30 - self.opt.elevation), 80 - self.opt.elevation)
            for _ in range(self.opt.batch_size):
                # render random view
                ver = np.random.randint(min_ver, max_ver)
                hor = np.random.randint(-180, 180)
                radius = 0

                vers.append(ver)
                hors.append(hor)
                radii.append(radius)

                pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius)

                # random render resolution
                ssaa = min(2.0, max(0.125, 2 * np.random.random()))
                out = self.renderer.render(pose, self.cam.perspective, render_resolution, render_resolution, ssaa=ssaa)

                image = out["image"]  # [H, W, 3] in [0, 1]
                image = image.permute(2, 0, 1).contiguous().unsqueeze(0)  # [1, 3, H, W] in [0, 1]

                images.append(image)

            images = torch.cat(images, dim=0)

            # import kiui
            # kiui.lo(hor, ver)
            # kiui.vis.plot_image(image)

            # guidance loss
            if self.enable_sd:
                # loss = loss + self.opt.lambda_sd * self.guidance_sd.train_step(images, step_ratio)
                refined_images = self.guidance_sd.refine(images, strength=0.6).float()
                refined_images = F.interpolate(refined_images, (render_resolution, render_resolution), mode="bilinear", align_corners=False)
                loss = loss + self.opt.lambda_sd * F.mse_loss(images, refined_images)

            if self.enable_zero123:
                # loss = loss + self.opt.lambda_zero123 * self.guidance_zero123.train_step(images, vers, hors, radii, step_ratio)
                refined_images = self.guidance_zero123.refine(images, vers, hors, radii, strength=0.6).float()
                refined_images = F.interpolate(refined_images, (render_resolution, render_resolution), mode="bilinear", align_corners=False)
                loss = loss + self.opt.lambda_zero123 * F.mse_loss(images, refined_images)
                # loss = loss + self.opt.lambda_zero123 * self.lpips_loss(images, refined_images)

            # optimize step
            # With no reference image and no guidance enabled the loss is still
            # the plain number 0, which has no backward(); skip the update.
            if torch.is_tensor(loss):
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        self.need_update = True

        if self.gui:
            dpg.set_value("_log_train_time", f"{t:.4f}ms")
            dpg.set_value(
                "_log_train_log",
                f"step = {self.step: 5d} (+{self.train_steps: 2d}) loss = {float(loss):.4f}",
            )

        # dynamic train steps (no need for now)
        # max allowed train time per-frame is 500 ms
        # full_t = t / self.train_steps * 16
        # train_steps = min(16, max(4, int(16 * 500 / full_t)))
        # if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
        #     self.train_steps = train_steps

    @torch.no_grad()
    def test_step(self):
        """Re-render the current view into ``self.buffer_image`` if needed."""
        # ignore if no need to update
        if not self.need_update:
            return

        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        # render image
        out = self.renderer.render(self.cam.pose, self.cam.perspective, self.H, self.W)

        buffer_image = out[self.mode]  # [H, W, 3]

        if self.mode in ['depth', 'alpha']:
            # expand the last axis to three channels for display
            buffer_image = buffer_image.repeat(1, 1, 3)
            if self.mode == 'depth':
                # min-max normalise depth into [0, 1]
                buffer_image = (buffer_image - buffer_image.min()) / (buffer_image.max() - buffer_image.min() + 1e-20)

        self.buffer_image = buffer_image.contiguous().clamp(0, 1).detach().cpu().numpy()

        # display input_image
        if self.overlay_input_img and self.input_img is not None:
            self.buffer_image = (
                self.buffer_image * (1 - self.overlay_input_img_ratio)
                + self.input_img * self.overlay_input_img_ratio
            )

        self.need_update = False

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        if self.gui:
            # guard against a zero elapsed time when deriving the frame rate
            fps = int(1000 / t) if t > 0 else 0
            dpg.set_value("_log_infer_time", f"{t:.4f}ms ({fps} FPS)")
            dpg.set_value(
                "_texture", self.buffer_image
            )  # buffer must be contiguous, else seg fault!

    @staticmethod
    def _caption_path_for(file):
        """Return the caption path paired with ``file``, or ``None``.

        The caption path is ``file`` with ``"_rgba.png"`` replaced by
        ``"_caption.txt"``. When the replacement changes nothing there is no
        paired caption, and ``None`` is returned so that the image itself is
        never mistaken for its caption file.
        """
        candidate = file.replace("_rgba.png", "_caption.txt")
        return None if candidate == file else candidate

    def load_input(self, file):
        """Load the reference image (and an optional caption) from ``file``.

        Sets ``self.input_img`` (RGB on white background, float32 in [0, 1])
        and ``self.input_mask`` (the alpha channel). Three-channel images are
        passed through ``rembg`` to obtain an alpha channel. If a paired
        ``*_caption.txt`` exists, its content becomes ``self.prompt``.

        Raises:
            ValueError: if OpenCV cannot read ``file``.
        """
        # load image
        print(f'[INFO] load image from {file}...')
        img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising
            raise ValueError(f"Cannot read image from {file}")
        if img.ndim == 2:
            # single-channel image: expand so the channel test below is valid
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        if img.shape[-1] == 3:
            if self.bg_remover is None:
                self.bg_remover = rembg.new_session()
            img = rembg.remove(img, session=self.bg_remover)

        img = cv2.resize(
            img, (self.W, self.H), interpolation=cv2.INTER_AREA
        )
        img = img.astype(np.float32) / 255.0

        self.input_mask = img[..., 3:]
        # white bg
        self.input_img = img[..., :3] * self.input_mask + (
            1 - self.input_mask
        )
        # bgr to rgb
        self.input_img = self.input_img[..., ::-1].copy()

        # load prompt
        file_prompt = self._caption_path_for(file)
        if file_prompt is not None and os.path.exists(file_prompt):
            print(f'[INFO] load prompt from {file_prompt}...')
            with open(file_prompt, "r") as f:
                self.prompt = f.read().strip()

    def save_model(self):
        """Export the mesh to ``<outdir>/<save_path>.<mesh_format>``."""
        os.makedirs(self.opt.outdir, exist_ok=True)
        path = os.path.join(self.opt.outdir, self.opt.save_path + '.' + self.opt.mesh_format)
        self.renderer.export_mesh(path)

        print(f"[INFO] save model to {path}.")

    def register_dpg(self):
        """Create all Dear PyGui items: texture, windows, handlers, viewport."""
        ### register texture

        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(
                self.W,
                self.H,
                self.buffer_image,
                format=dpg.mvFormat_Float_rgb,
                tag="_texture",
            )

        ### register window

        # the rendered image, as the primary window
        with dpg.window(
            tag="_primary_window",
            width=self.W,
            height=self.H,
            pos=[0, 0],
            no_move=True,
            no_title_bar=True,
            no_scrollbar=True,
        ):
            # add the texture
            dpg.add_image("_texture")

        # dpg.set_primary_window("_primary_window", True)

        # control window
        with dpg.window(
            label="Control",
            tag="_control_window",
            width=600,
            height=self.H,
            pos=[self.W, 0],
            no_move=True,
            no_title_bar=True,
        ):
            # button theme
            with dpg.theme() as theme_button:
                with dpg.theme_component(dpg.mvButton):
                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)

            # timer stuff
            with dpg.group(horizontal=True):
                dpg.add_text("Infer time: ")
                dpg.add_text("no data", tag="_log_infer_time")

            def callback_setattr(sender, app_data, user_data):
                setattr(self, user_data, app_data)

            def callback_set_save_path(sender, app_data):
                # save_model() reads the name from self.opt, so write it there
                self.opt.save_path = app_data

            # init stuff
            with dpg.collapsing_header(label="Initialize", default_open=True):

                # seed stuff
                def callback_set_seed(sender, app_data):
                    self.seed = app_data
                    self.seed_everything()

                dpg.add_input_text(
                    label="seed",
                    default_value=self.seed,
                    on_enter=True,
                    callback=callback_set_seed,
                )

                # input stuff
                def callback_select_input(sender, app_data):
                    # only one item
                    for k, v in app_data["selections"].items():
                        dpg.set_value("_log_input", k)
                        self.load_input(v)

                    self.need_update = True

                with dpg.file_dialog(
                    directory_selector=False,
                    show=False,
                    callback=callback_select_input,
                    file_count=1,
                    tag="file_dialog_tag",
                    width=700,
                    height=400,
                ):
                    dpg.add_file_extension("Images{.jpg,.jpeg,.png}")

                with dpg.group(horizontal=True):
                    dpg.add_button(
                        label="input",
                        callback=lambda: dpg.show_item("file_dialog_tag"),
                    )
                    dpg.add_text("", tag="_log_input")

                # overlay stuff
                with dpg.group(horizontal=True):

                    def callback_toggle_overlay_input_img(sender, app_data):
                        self.overlay_input_img = not self.overlay_input_img
                        self.need_update = True

                    dpg.add_checkbox(
                        label="overlay image",
                        default_value=self.overlay_input_img,
                        callback=callback_toggle_overlay_input_img,
                    )

                    def callback_set_overlay_input_img_ratio(sender, app_data):
                        self.overlay_input_img_ratio = app_data
                        self.need_update = True

                    dpg.add_slider_float(
                        label="ratio",
                        min_value=0,
                        max_value=1,
                        format="%.1f",
                        default_value=self.overlay_input_img_ratio,
                        callback=callback_set_overlay_input_img_ratio,
                    )

                # prompt stuff
                dpg.add_input_text(
                    label="prompt",
                    default_value=self.prompt,
                    callback=callback_setattr,
                    user_data="prompt",
                )

                dpg.add_input_text(
                    label="negative",
                    default_value=self.negative_prompt,
                    callback=callback_setattr,
                    user_data="negative_prompt",
                )

                # save current model
                with dpg.group(horizontal=True):
                    dpg.add_text("Save: ")

                    dpg.add_button(
                        label="model",
                        tag="_button_save_model",
                        callback=self.save_model,
                    )
                    dpg.bind_item_theme("_button_save_model", theme_button)

                    dpg.add_input_text(
                        label="",
                        default_value=self.opt.save_path,
                        callback=callback_set_save_path,
                    )

            # training stuff
            with dpg.collapsing_header(label="Train", default_open=True):
                # lr and train button
                with dpg.group(horizontal=True):
                    dpg.add_text("Train: ")

                    def callback_train(sender, app_data):
                        if self.training:
                            self.training = False
                            dpg.configure_item("_button_train", label="start")
                        else:
                            self.prepare_train()
                            self.training = True
                            dpg.configure_item("_button_train", label="stop")

                    # dpg.add_button(
                    #     label="init", tag="_button_init", callback=self.prepare_train
                    # )
                    # dpg.bind_item_theme("_button_init", theme_button)

                    dpg.add_button(
                        label="start", tag="_button_train", callback=callback_train
                    )
                    dpg.bind_item_theme("_button_train", theme_button)

                with dpg.group(horizontal=True):
                    dpg.add_text("", tag="_log_train_time")
                    dpg.add_text("", tag="_log_train_log")

            # rendering options
            with dpg.collapsing_header(label="Rendering", default_open=True):
                # mode combo
                def callback_change_mode(sender, app_data):
                    self.mode = app_data
                    self.need_update = True

                dpg.add_combo(
                    ("image", "depth", "alpha", "normal"),
                    label="mode",
                    default_value=self.mode,
                    callback=callback_change_mode,
                )

                # fov slider
                def callback_set_fovy(sender, app_data):
                    self.cam.fovy = np.deg2rad(app_data)
                    self.need_update = True

                dpg.add_slider_int(
                    label="FoV (vertical)",
                    min_value=1,
                    max_value=120,
                    format="%d deg",
                    # integer slider: hand it a plain int, not a numpy float
                    default_value=int(round(np.rad2deg(self.cam.fovy))),
                    callback=callback_set_fovy,
                )

        ### register camera handler

        def callback_camera_drag_rotate_or_draw_mask(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.orbit(dx, dy)
            self.need_update = True

        def callback_camera_wheel_scale(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            delta = app_data

            self.cam.scale(delta)
            self.need_update = True

        def callback_camera_drag_pan(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.pan(dx, dy)
            self.need_update = True

        with dpg.handler_registry():
            # for camera moving
            dpg.add_mouse_drag_handler(
                button=dpg.mvMouseButton_Left,
                callback=callback_camera_drag_rotate_or_draw_mask,
            )
            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
            dpg.add_mouse_drag_handler(
                button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan
            )

        dpg.create_viewport(
            title="Gaussian3D",
            width=self.W + 600,
            height=self.H + (45 if os.name == "nt" else 0),
            resizable=False,
        )

        ### global theme
        with dpg.theme() as theme_no_padding:
            with dpg.theme_component(dpg.mvAll):
                # set all padding to 0 to avoid scroll bar
                dpg.add_theme_style(
                    dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core
                )
                dpg.add_theme_style(
                    dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core
                )
                dpg.add_theme_style(
                    dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core
                )

        dpg.bind_item_theme("_primary_window", theme_no_padding)

        dpg.setup_dearpygui()

        ### register a larger font
        # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf
        if os.path.exists("LXGWWenKai-Regular.ttf"):
            with dpg.font_registry():
                with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font:
                    dpg.bind_font(default_font)

        # dpg.show_metrics()

        dpg.show_viewport()

    def render(self):
        """Run the GUI loop: train (if active), redraw, present a frame."""
        assert self.gui
        while dpg.is_dearpygui_running():
            # update texture every frame
            if self.training:
                self.train_step()
            self.test_step()
            dpg.render_dearpygui_frame()

    # no gui mode
    def train(self, iters=500):
        """Run ``iters`` training steps without a GUI, then export the mesh.

        Args:
            iters: number of ``train_step`` calls; when not positive, training
                is skipped and only the export runs.
        """
        if iters > 0:
            self.prepare_train()
            for i in tqdm.trange(iters):
                self.train_step()
        # save
        self.save_model()
if __name__ == "__main__":
    import argparse
    from omegaconf import OmegaConf

    cli = argparse.ArgumentParser()
    cli.add_argument("--config", required=True, help="path to the yaml config file")
    known, overrides = cli.parse_known_args()

    # Values given on the command line take precedence over the YAML file.
    file_cfg = OmegaConf.load(known.config)
    cli_cfg = OmegaConf.from_cli(overrides)
    opt = OmegaConf.merge(file_cfg, cli_cfg)

    # No mesh given: fall back to the stage 1 output at its default location.
    if opt.mesh is None:
        stage1_mesh = os.path.join(opt.outdir, opt.save_path + '_mesh.' + opt.mesh_format)
        if not os.path.exists(stage1_mesh):
            raise ValueError(f"Cannot find mesh from {stage1_mesh}, must specify --mesh explicitly!")
        opt.mesh = stage1_mesh

    gui = GUI(opt)

    # Interactive loop with the GUI, otherwise a fixed-length headless run.
    if opt.gui:
        gui.render()
    else:
        gui.train(opt.iters_refine)