# Page header captured along with this file ("Spaces: Sleeping / Sleeping"); not part of the module.
from typing import * | |
import numpy as np | |
import torch | |
import utils3d | |
import nvdiffrast.torch as dr | |
from tqdm import tqdm | |
import trimesh | |
import trimesh.visual | |
import xatlas | |
import pyvista as pv | |
from pymeshfix import _meshfix | |
import igraph | |
import cv2 | |
from PIL import Image | |
from .random_utils import sphere_hammersley_sequence | |
from .render_utils import render_multiview | |
from ..representations import Strivec, Gaussian, MeshExtractResult | |
def _fill_holes(
    verts,
    faces,
    max_hole_size=0.04,
    max_hole_nbe=32,
    resolution=128,
    num_views=500,
    debug=False,
    verbose=False
):
    """
    Rasterize a mesh from multiple views, remove hidden interior faces, and fill small holes.

    Steps:
        1. Render the mesh from `num_views` cameras placed on a sphere of radius 2 around the
           origin and record, for every face, the fraction of views in which it is visible.
        2. Per connected component, mark faces as "outer" when their visibility exceeds the
           component's 75th percentile clamped to [0.25, 0.5]; mark faces never seen as "inner".
        3. Run an s-t min-cut on the face adjacency (dual) graph, weighted by shared edge
           length, separating inner faces (source) from outer faces (target).
        4. Keep a cut component for removal only if its median visibility is at most 0.25 and
           every new boundary loop created by the cut has an area of at most `max_hole_size`.
        5. Remove the accepted faces and close small boundaries with pymeshfix.

    If no face is completely invisible, the inputs are returned unchanged (step 5 is skipped too).

    Args:
        verts (torch.Tensor): Vertices of the mesh. Shape (V, 3).
        faces (torch.Tensor): Faces of the mesh. Shape (F, 3).
        max_hole_size (float): Maximum area of a cutting loop for its component to be removed.
        max_hole_nbe (int): Maximum number of boundary edges of a hole to fill (passed to pymeshfix).
        resolution (int): Resolution of the rasterization.
        num_views (int): Number of views to rasterize the mesh.
        debug (bool): Whether to print extra diagnostics and write 'dbg_dual.ply' / 'dbg_cut.ply'
            to the current working directory.
        verbose (bool): Whether to print progress.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The processed (verts, faces). After hole filling they
        are float32 / int32 tensors on the CUDA device.
    """
    # Construct cameras
    yaws = []
    pitchs = []
    for i in range(num_views):
        y, p = sphere_hammersley_sequence(i, num_views)
        yaws.append(y)
        pitchs.append(p)
    yaws = torch.tensor(yaws).cuda()
    pitchs = torch.tensor(pitchs).cuda()
    radius = 2.0
    fov = torch.deg2rad(torch.tensor(40)).cuda()
    projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
    # Camera positions on the sphere, computed for all views at once.
    origins = torch.stack([
        torch.sin(yaws) * torch.cos(pitchs),
        torch.cos(yaws) * torch.cos(pitchs),
        torch.sin(pitchs),
    ], dim=-1).float() * radius
    # Look-at target and up vector are the same for every view.
    target = torch.tensor([0, 0, 0]).float().cuda()
    up = torch.tensor([0, 0, 1]).float().cuda()
    views = torch.stack([utils3d.torch.view_look_at(orig, target, up) for orig in origins], dim=0)

    # Rasterize
    visibility = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device)
    rastctx = utils3d.torch.RastContext(backend='cuda')
    for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'):
        buffers = utils3d.torch.rasterize_triangle_faces(
            rastctx, verts[None], faces, resolution, resolution, view=views[i], projection=projection
        )
        # NOTE(review): the -1 suggests face ids in the buffer are 1-based - confirm against utils3d.
        face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1
        face_id = torch.unique(face_id).long()
        visibility[face_id] += 1
    visibility = visibility.float() / num_views

    # Mincut
    ## construct outer faces
    edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
    boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
    connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge)
    outer_face_mask = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device)
    for component in connected_components:
        component_visibility = visibility[component]
        threshold = min(max(component_visibility.quantile(0.75).item(), 0.25), 0.5)
        outer_face_mask[component] = component_visibility > threshold
    outer_face_indices = outer_face_mask.nonzero().reshape(-1)

    ## construct inner faces
    inner_face_indices = torch.nonzero(visibility == 0).reshape(-1)
    if verbose:
        tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces')
    if inner_face_indices.shape[0] == 0:
        return verts, faces

    ## Construct dual graph (faces as nodes, edges as edges)
    dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge)
    dual_edge2edge = edges[dual_edge2edge]
    dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1)
    if verbose:
        tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges')

    ## solve mincut problem
    ### construct main graph
    g = igraph.Graph()
    g.add_vertices(faces.shape[0])
    g.add_edges(dual_edges.cpu().numpy())
    g.es['weight'] = dual_edges_weights.cpu().numpy()

    ### source and target
    g.add_vertex('s')
    g.add_vertex('t')

    ### connect invisible faces to source
    # .tolist() moves the indices to the host once instead of syncing per element.
    inner_ids = inner_face_indices.tolist()
    g.add_edges([(f, 's') for f in inner_ids], attributes={'weight': np.ones(len(inner_ids), dtype=np.float32)})

    ### connect outer faces to target
    outer_ids = outer_face_indices.tolist()
    g.add_edges([(f, 't') for f in outer_ids], attributes={'weight': np.ones(len(outer_ids), dtype=np.float32)})

    ### solve mincut
    cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist())
    # Drop the two terminal vertices, whose ids are >= the number of faces.
    remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device)
    if verbose:
        tqdm.write('Mincut solved, start checking the cut')

    ### check if the cut is valid with each connected component
    to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices])
    if debug:
        tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}')
    valid_remove_cc = []
    cutting_edges = []
    for cc in to_remove_cc:
        #### check if the connected component has low visibility
        visibility_median = visibility[remove_face_indices[cc]].median()
        if debug:
            tqdm.write(f'visblity_median: {visibility_median}')
        if visibility_median > 0.25:
            continue

        #### check if the cutting loop is small enough
        cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True)
        cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
        # Only edges that become boundaries because of the cut (not already open in the input).
        cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)]
        if len(cc_new_boundary_edge_indices) > 0:
            new_boundary_edges = edges[cc_new_boundary_edge_indices]
            loops = utils3d.torch.compute_edge_connected_components(new_boundary_edges)
            loop_areas = []
            for loop in loops:
                loop_edges = new_boundary_edges[loop]
                # Area of the triangle fan spanned between the loop's centroid and its edges.
                center = verts[loop_edges].mean(dim=1).mean(dim=0)
                _e1 = verts[loop_edges[:, 0]] - center
                _e2 = verts[loop_edges[:, 1]] - center
                loop_areas.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5)
            if debug:
                cutting_edges.append(cc_new_boundary_edge_indices)
                tqdm.write(f'Area of the cutting loop: {loop_areas}')
            if any(area > max_hole_size for area in loop_areas):
                continue

        valid_remove_cc.append(cc)

    if debug:
        face_v = verts[faces].mean(dim=1).cpu().numpy()
        vis_dual_edges = dual_edges.cpu().numpy()
        vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8)
        vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255]
        vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0]
        vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255]
        if len(valid_remove_cc) > 0:
            vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0]
        utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors)

        # torch.cat raises on an empty list, so only dump the cut when there is one.
        if cutting_edges:
            vis_verts = verts.cpu().numpy()
            vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy()
            utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges)

    if len(valid_remove_cc) > 0:
        remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)]
        mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device)
        mask[remove_face_indices] = False
        faces = faces[mask]
        faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts)
        if verbose:
            tqdm.write(f'Removed {(~mask).sum()} faces by mincut')
    else:
        if verbose:
            tqdm.write('Removed 0 faces by mincut')

    # Close the remaining small boundaries with pymeshfix.
    mesh = _meshfix.PyTMesh()
    mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy())
    mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
    verts, faces = mesh.return_arrays()
    verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32)

    return verts, faces
def postprocess_mesh(
    vertices: np.ndarray,
    faces: np.ndarray,
    simplify: bool = True,
    simplify_ratio: float = 0.9,
    fill_holes: bool = True,
    fill_holes_max_hole_size: float = 0.04,
    fill_holes_max_hole_nbe: int = 32,
    fill_holes_resolution: int = 1024,
    fill_holes_num_views: int = 1000,
    debug: bool = False,
    verbose: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Postprocess a mesh by simplifying it and then removing invisible faces / filling small holes.

    Args:
        vertices (np.ndarray): Vertices of the mesh. Shape (V, 3).
        faces (np.ndarray): Faces of the mesh. Shape (F, 3).
        simplify (bool): Whether to simplify the mesh with pyvista's `decimate`.
        simplify_ratio (float): Reduction target handed to `decimate`, i.e. the ratio of faces
            to remove (0.9 asks for roughly 10% of the faces to remain). Values <= 0 skip
            simplification.
        fill_holes (bool): Whether to run `_fill_holes` (requires CUDA).
        fill_holes_max_hole_size (float): Maximum area of a hole to fill.
        fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill.
        fill_holes_resolution (int): Resolution of the rasterization.
        fill_holes_num_views (int): Number of views to rasterize the mesh.
        debug (bool): Forwarded to `_fill_holes` to enable its debug output.
        verbose (bool): Whether to print progress.

    Returns:
        Tuple[np.ndarray, np.ndarray]: The processed (vertices, faces). When both steps are
        disabled the inputs are returned as-is.
    """
    if verbose:
        tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces')

    # Simplify
    if simplify and simplify_ratio > 0:
        # pyvista stores faces as flat [n, i0, i1, ...] records; every face here is a triangle.
        mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1))
        mesh = mesh.decimate(simplify_ratio, progress_bar=verbose)
        vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:]
        if verbose:
            tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces')

    # Remove invisible faces
    if fill_holes:
        vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda()
        vertices, faces = _fill_holes(
            vertices, faces,
            max_hole_size=fill_holes_max_hole_size,
            max_hole_nbe=fill_holes_max_hole_nbe,
            resolution=fill_holes_resolution,
            num_views=fill_holes_num_views,
            debug=debug,
            verbose=verbose,
        )
        vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()
        if verbose:
            tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces')

    return vertices, faces
def parametrize_mesh(vertices: np.ndarray, faces: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Parametrize a mesh to a texture space, using xatlas.

    Args:
        vertices (np.ndarray): Vertices of the mesh. Shape (V, 3).
        faces (np.ndarray): Faces of the mesh. Shape (F, 3).

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: (vertices, faces, uvs), where `vertices` is
        the input re-indexed by the vertex mapping xatlas returns, `faces` is the index buffer
        xatlas returns, and `uvs` are the UV coordinates xatlas returns.
    """
    vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
    return vertices[vmapping], indices, uvs
def bake_texture(
    vertices: np.ndarray,
    faces: np.ndarray,
    uvs: np.ndarray,
    observations: List[np.ndarray],
    masks: List[np.ndarray],
    extrinsics: List[np.ndarray],
    intrinsics: List[np.ndarray],
    texture_size: int = 2048,
    near: float = 0.1,
    far: float = 10.0,
    mode: Literal['fast', 'opt'] = 'opt',
    lambda_tv: float = 1e-2,
    verbose: bool = False,
) -> np.ndarray:
    """
    Bake texture to a mesh from multiple observations.

    Args:
        vertices (np.ndarray): Vertices of the mesh. Shape (V, 3).
        faces (np.ndarray): Faces of the mesh. Shape (F, 3).
        uvs (np.ndarray): UV coordinates of the mesh. Shape (V, 2).
        observations (List[np.ndarray]): List of observations. Each observation is a 2D image
            with values in [0, 255]. Shape (H, W, 3).
        masks (List[np.ndarray]): List of masks, one per observation. Pixels > 0 are used.
            Shape (H, W).
        extrinsics (List[np.ndarray]): List of extrinsics. Shape (4, 4).
        intrinsics (List[np.ndarray]): List of intrinsics. Shape (3, 3).
        texture_size (int): Size of the texture.
        near (float): Near plane of the camera.
        far (float): Far plane of the camera.
        mode (Literal['fast', 'opt']): Mode of texture baking. 'fast' splats every observed
            pixel into its nearest texel and averages; 'opt' optimizes the texture with Adam
            for 2500 steps against randomly selected views.
        lambda_tv (float): Weight of total variation loss in optimization.
        verbose (bool): Whether to print progress.

    Returns:
        np.ndarray: The baked texture as uint8, shape (texture_size, texture_size, 3), with
        uncovered texels inpainted.

    Raises:
        ValueError: If `mode` is neither 'fast' nor 'opt'.
    """
    # Validate before any GPU work so a bad mode fails immediately.
    if mode not in ('fast', 'opt'):
        raise ValueError(f'Unknown mode: {mode}')

    vertices = torch.tensor(vertices).cuda()
    faces = torch.tensor(faces.astype(np.int32)).cuda()
    uvs = torch.tensor(uvs).cuda()
    observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations]
    masks = [torch.tensor(m > 0).bool().cuda() for m in masks]
    views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics]
    projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) for intr in intrinsics]

    if mode == 'fast':
        texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda()
        texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda()
        rastctx = utils3d.torch.RastContext(backend='cuda')
        for observation, obs_mask, view, projection in tqdm(zip(observations, masks, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
                )
                # The raster buffers are flipped vertically to line up with the observation
                # rows; the UV map and its coverage mask must be flipped together, and each
                # view is gated by its own observation mask.
                uv_map = rast['uv'][0].detach().flip(0)
                mask = rast['mask'][0].detach().bool().flip(0) & obs_mask

            # nearest neighbor interpolation (clamped so uv == 1.0 stays inside the texture)
            uv_map = (uv_map * texture_size).floor().long().clamp_(0, texture_size - 1)
            obs = observation[mask]
            uv_map = uv_map[mask]
            idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
            texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs)
            texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device))

        # Average the accumulated colors of every texel that received at least one sample.
        mask = texture_weights > 0
        texture[mask] /= texture_weights[mask][:, None]
        texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8)

        # inpaint
        mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size)
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)

    else:  # mode == 'opt'
        rastctx = utils3d.torch.RastContext(backend='cuda')
        # Flip the targets once so they match the row order of the raster buffers.
        observations = [obs.flip(0) for obs in observations]
        masks = [m.flip(0) for m in masks]
        _uv = []
        _uv_dr = []
        for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
                )
                _uv.append(rast['uv'].detach())
                _uv_dr.append(rast['uv_dr'].detach())

        texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda())
        optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)

        def cosine_anealing(step, total_steps, start_lr, end_lr):
            return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))

        def tv_loss(texture):
            return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \
                torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :])

        total_steps = 2500
        with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar:
            for step in range(total_steps):
                optimizer.zero_grad()
                selected = np.random.randint(0, len(views))
                uv, uv_dr, observation, mask = _uv[selected], _uv_dr[selected], observations[selected], masks[selected]
                render = dr.texture(texture, uv, uv_dr)[0]
                loss = torch.nn.functional.l1_loss(render[mask], observation[mask])
                if lambda_tv > 0:
                    loss = loss + lambda_tv * tv_loss(texture)
                loss.backward()
                optimizer.step()
                # annealing
                optimizer.param_groups[0]['lr'] = cosine_anealing(step, total_steps, 1e-2, 1e-5)
                pbar.set_postfix({'loss': loss.item()})
                pbar.update()
        texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
        # Texels not covered by any UV triangle are filled by inpainting.
        mask = 1 - utils3d.torch.rasterize_triangle_faces(
            rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size
        )['mask'][0].detach().cpu().numpy().astype(np.uint8)
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)

    return texture
def to_glb(
    app_rep: Union[Strivec, Gaussian],
    mesh: MeshExtractResult,
    simplify: float = 0.95,
    fill_holes: bool = True,
    fill_holes_max_size: float = 0.04,
    texture_size: int = 1024,
    debug: bool = False,
    verbose: bool = True,
) -> trimesh.Trimesh:
    """
    Convert a generated asset to a textured mesh ready for glb export.

    The extracted mesh is postprocessed, UV-unwrapped, textured by baking 100 renders of the
    appearance representation, and rotated from z-up to y-up. Nothing is written to disk;
    the caller exports the returned mesh.

    Args:
        app_rep (Union[Strivec, Gaussian]): Appearance representation.
        mesh (MeshExtractResult): Extracted mesh.
        simplify (float): Ratio of faces to remove in simplification.
        fill_holes (bool): Whether to fill holes in the mesh.
        fill_holes_max_size (float): Maximum area of a hole to fill.
        texture_size (int): Size of the texture.
        debug (bool): Whether to print debug information.
        verbose (bool): Whether to print progress.

    Returns:
        trimesh.Trimesh: The mesh with `TextureVisuals` carrying the UVs and the baked texture.
    """
    vertices = mesh.vertices.cpu().numpy()
    faces = mesh.faces.cpu().numpy()

    # mesh postprocess
    vertices, faces = postprocess_mesh(
        vertices, faces,
        simplify=simplify > 0,
        simplify_ratio=simplify,
        fill_holes=fill_holes,
        fill_holes_max_hole_size=fill_holes_max_size,
        # The boundary-edge budget shrinks as more faces are removed by simplification.
        fill_holes_max_hole_nbe=int(250 * np.sqrt(1 - simplify)),
        fill_holes_resolution=1024,
        fill_holes_num_views=1000,
        debug=debug,
        verbose=verbose,
    )

    # parametrize mesh
    vertices, faces, uvs = parametrize_mesh(vertices, faces)

    # bake texture
    observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
    # A pixel counts as foreground when any of its channels is non-zero.
    masks = [np.any(observation > 0, axis=-1) for observation in observations]
    extrinsics = [extr.cpu().numpy() for extr in extrinsics]
    intrinsics = [intr.cpu().numpy() for intr in intrinsics]
    texture = bake_texture(
        vertices, faces, uvs,
        observations, masks, extrinsics, intrinsics,
        texture_size=texture_size, mode='opt',
        lambda_tv=0.01,
        verbose=verbose,  # honor the caller's setting instead of always printing progress
    )
    texture = Image.fromarray(texture)

    # rotate mesh (from z-up to y-up)
    vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
    mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture))
    return mesh