import os
from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import threestudio
from threestudio.models.geometry.base import (
    BaseExplicitGeometry,
    BaseGeometry,
    contract_to_unisphere,
)
from threestudio.models.geometry.implicit_sdf import ImplicitSDF
from threestudio.models.geometry.implicit_volume import ImplicitVolume
from threestudio.models.isosurface import MarchingTetrahedraHelper
from threestudio.models.mesh import Mesh
from threestudio.models.networks import get_encoding, get_mlp
from threestudio.utils.misc import broadcast
from threestudio.utils.ops import scale_tensor
from threestudio.utils.typing import *


def _copy_feature_networks(dst: "TetrahedraSDFGrid", src: BaseGeometry) -> None:
    """Load the encoding and feature-network weights of ``src`` into ``dst``.

    Both objects must expose ``encoding`` and ``feature_network`` modules with
    compatible state dicts; ``load_state_dict`` raises otherwise.
    """
    dst.encoding.load_state_dict(src.encoding.state_dict())
    dst.feature_network.load_state_dict(src.feature_network.state_dict())


def _extract_source_mesh(source: BaseGeometry, resolution: int) -> Mesh:
    """Extract an isosurface mesh from an implicit ``source`` geometry.

    The source configuration is modified in place (with a warning) so that it
    uses the ``"mt"`` isosurface method at the requested ``resolution`` before
    ``source.isosurface()`` is called.
    """
    if source.cfg.isosurface_method != "mt":
        source.cfg.isosurface_method = "mt"
        threestudio.warn("Override isosurface_method of the source geometry to 'mt'")
    if source.cfg.isosurface_resolution != resolution:
        source.cfg.isosurface_resolution = resolution
        threestudio.warn(
            f"Override isosurface_resolution of the source geometry to {resolution}"
        )
    return source.isosurface()


@threestudio.register("tetrahedra-sdf-grid")
class TetrahedraSDFGrid(BaseExplicitGeometry):
    """Explicit geometry storing one SDF value per tetrahedral-grid vertex.

    The surface is extracted with ``MarchingTetrahedraHelper``. An optional
    per-vertex deformation tensor is passed to the helper alongside the SDF,
    and an optional encoding + MLP pair produces per-point features.
    """

    @dataclass
    class Config(BaseExplicitGeometry.Config):
        # Grid resolution; also selects the tets file
        # ``load/tets/{isosurface_resolution}_tets.npz``.
        isosurface_resolution: int = 128
        # Whether a per-vertex deformation tensor is created.
        isosurface_deformable_grid: bool = True
        # Whether ``isosurface()`` calls ``Mesh.remove_outlier`` on the result.
        isosurface_remove_outliers: bool = False
        # Threshold forwarded to ``Mesh.remove_outlier``.
        isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01

        # Input / output sizes of the feature field.
        n_input_dims: int = 3
        n_feature_dims: int = 3
        pos_encoding_config: dict = field(
            default_factory=lambda: {
                "otype": "HashGrid",
                "n_levels": 16,
                "n_features_per_level": 2,
                "log2_hashmap_size": 19,
                "base_resolution": 16,
                "per_level_scale": 1.447269237440378,
            }
        )
        mlp_network_config: dict = field(
            default_factory=lambda: {
                "otype": "VanillaMLP",
                "activation": "ReLU",
                "output_activation": "none",
                "n_neurons": 64,
                "n_hidden_layers": 1,
            }
        )

        # One of None, "ellipsoid", "sphere" or "mesh:<path>".
        shape_init: Optional[str] = None
        # Ellipsoid: three sizes; sphere: float radius; mesh: float scale.
        shape_init_params: Optional[Any] = None
        # Axis of the input mesh treated as up / front, e.g. "+z", "-x".
        shape_init_mesh_up: str = "+z"
        shape_init_mesh_front: str = "+x"
        # Run shape initialization even when ``weights`` is set.
        force_shape_init: bool = False

        # When True, no encoding / feature network is built.
        geometry_only: bool = False
        # When True, sdf / deformation are buffers (not parameters) and the
        # extracted mesh is cached.
        fix_geometry: bool = False

    cfg: Config

    def configure(self) -> None:
        """Create the isosurface helper, SDF / deformation tensors and networks."""
        super().configure()

        # this should be saved to state_dict, register as buffer
        self.isosurface_bbox: Float[Tensor, "2 3"]
        self.register_buffer("isosurface_bbox", self.bbox.clone())

        self.isosurface_helper = MarchingTetrahedraHelper(
            self.cfg.isosurface_resolution,
            f"load/tets/{self.cfg.isosurface_resolution}_tets.npz",
        )

        self.sdf: Float[Tensor, "Nv 1"]
        self.deformation: Optional[Float[Tensor, "Nv 3"]]

        # Build the zero-initialized tensors once; only the registration kind
        # (trainable parameter vs. fixed buffer) depends on ``fix_geometry``.
        grid_vertices = self.isosurface_helper.grid_vertices
        sdf_init = torch.zeros((grid_vertices.shape[0], 1), dtype=torch.float32)
        deformation_init = (
            torch.zeros_like(grid_vertices)
            if self.cfg.isosurface_deformable_grid
            else None
        )

        if not self.cfg.fix_geometry:
            self.register_parameter("sdf", nn.Parameter(sdf_init))
            if deformation_init is not None:
                self.register_parameter(
                    "deformation", nn.Parameter(deformation_init)
                )
            else:
                self.deformation = None
        else:
            self.register_buffer("sdf", sdf_init)
            if deformation_init is not None:
                self.register_buffer("deformation", deformation_init)
            else:
                self.deformation = None

        if not self.cfg.geometry_only:
            self.encoding = get_encoding(
                self.cfg.n_input_dims, self.cfg.pos_encoding_config
            )
            self.feature_network = get_mlp(
                self.encoding.n_output_dims,
                self.cfg.n_feature_dims,
                self.cfg.mlp_network_config,
            )

        # Cache used by ``isosurface()`` when ``fix_geometry`` is set.
        self.mesh: Optional[Mesh] = None

    def initialize_shape(self) -> None:
        """Fill ``self.sdf`` with the signed distance of the configured shape.

        Does nothing when ``shape_init`` is None or when ``weights`` is set,
        unless ``force_shape_init`` is True. Supported ``shape_init`` values
        are ``"ellipsoid"``, ``"sphere"`` and ``"mesh:<path>"``.

        Raises:
            AssertionError: if ``shape_init`` is not a string (this includes
                ``force_shape_init`` with ``shape_init=None``) or
                ``shape_init_params`` has the wrong type for the chosen shape.
            ValueError: for an unknown ``shape_init``, a missing mesh file, or
                invalid / non-orthogonal mesh up and front directions.
        """
        if self.cfg.shape_init is None and not self.cfg.force_shape_init:
            return

        # do not initialize shape if weights are provided
        if self.cfg.weights is not None and not self.cfg.force_shape_init:
            return

        get_gt_sdf: Callable[[Float[Tensor, "N 3"]], Float[Tensor, "N 1"]]
        assert isinstance(self.cfg.shape_init, str)
        if self.cfg.shape_init == "ellipsoid":
            assert (
                isinstance(self.cfg.shape_init_params, Sized)
                and len(self.cfg.shape_init_params) == 3
            )
            size = torch.as_tensor(self.cfg.shape_init_params).to(self.device)

            def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
                return ((points_rand / size) ** 2).sum(
                    dim=-1, keepdim=True
                ).sqrt() - 1.0  # pseudo signed distance of an ellipsoid

            get_gt_sdf = func
        elif self.cfg.shape_init == "sphere":
            assert isinstance(self.cfg.shape_init_params, float)
            radius = self.cfg.shape_init_params

            def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
                return (points_rand**2).sum(dim=-1, keepdim=True).sqrt() - radius

            get_gt_sdf = func
        elif self.cfg.shape_init.startswith("mesh:"):
            assert isinstance(self.cfg.shape_init_params, float)
            mesh_path = self.cfg.shape_init[5:]  # strip the "mesh:" prefix
            if not os.path.exists(mesh_path):
                raise ValueError(f"Mesh file {mesh_path} does not exist.")

            # Imported lazily so trimesh is only required for mesh init.
            import trimesh

            mesh = trimesh.load(mesh_path)

            # move to center
            centroid = mesh.vertices.mean(0)
            mesh.vertices = mesh.vertices - centroid

            # align to up-z and front-x
            dirs = ["+x", "+y", "+z", "-x", "-y", "-z"]
            dir2vec = {
                "+x": np.array([1, 0, 0]),
                "+y": np.array([0, 1, 0]),
                "+z": np.array([0, 0, 1]),
                "-x": np.array([-1, 0, 0]),
                "-y": np.array([0, -1, 0]),
                "-z": np.array([0, 0, -1]),
            }
            if (
                self.cfg.shape_init_mesh_up not in dirs
                or self.cfg.shape_init_mesh_front not in dirs
            ):
                raise ValueError(
                    f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}."
                )
            # Same axis letter (second character) means the two directions
            # are parallel or anti-parallel.
            if self.cfg.shape_init_mesh_up[1] == self.cfg.shape_init_mesh_front[1]:
                raise ValueError(
                    "shape_init_mesh_up and shape_init_mesh_front must be orthogonal."
                )
            z_, x_ = (
                dir2vec[self.cfg.shape_init_mesh_up],
                dir2vec[self.cfg.shape_init_mesh_front],
            )
            y_ = np.cross(z_, x_)
            std2mesh = np.stack([x_, y_, z_], axis=0).T
            mesh2std = np.linalg.inv(std2mesh)

            # scaling
            scale = np.abs(mesh.vertices).max()
            mesh.vertices = mesh.vertices / scale * self.cfg.shape_init_params
            mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T

            # Imported lazily so pysdf is only required for mesh init.
            from pysdf import SDF

            sdf = SDF(mesh.vertices, mesh.faces)

            def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
                # add a negative signed here
                # as in pysdf the inside of the shape has positive signed distance
                return torch.from_numpy(-sdf(points_rand.cpu().numpy())).to(
                    points_rand
                )[..., None]

            get_gt_sdf = func

        else:
            raise ValueError(
                f"Unknown shape initialization type: {self.cfg.shape_init}"
            )

        # Evaluate the target SDF at the grid vertices mapped from the
        # helper's point range into the isosurface bounding box.
        sdf_gt = get_gt_sdf(
            scale_tensor(
                self.isosurface_helper.grid_vertices,
                self.isosurface_helper.points_range,
                self.isosurface_bbox,
            )
        )
        self.sdf.data = sdf_gt

        # explicit broadcast to ensure param consistency across ranks
        for param in self.parameters():
            broadcast(param, src=0)

    def isosurface(self) -> Mesh:
        """Extract the surface mesh from the current SDF and deformation.

        Vertex positions are rescaled from the helper's point range to
        ``isosurface_bbox``. The result is stored in ``self.mesh``; when
        ``fix_geometry`` is True a previously stored mesh is returned as is.
        """
        # return cached mesh if fix_geometry is True to save computation
        if self.cfg.fix_geometry and self.mesh is not None:
            return self.mesh
        mesh = self.isosurface_helper(self.sdf, self.deformation)
        mesh.v_pos = scale_tensor(
            mesh.v_pos, self.isosurface_helper.points_range, self.isosurface_bbox
        )
        if self.cfg.isosurface_remove_outliers:
            mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold)
        self.mesh = mesh
        return mesh

    def forward(
        self, points: Float[Tensor, "*N Di"], output_normal: bool = False
    ) -> Dict[str, Float[Tensor, "..."]]:
        """Evaluate the feature field at ``points``.

        Args:
            points: query points; the last dimension must be ``n_input_dims``.
            output_normal: must be False; normals are not supported.

        Returns:
            ``{"features": ...}`` with the leading shape of ``points`` and a
            last dimension of ``n_feature_dims``, or an empty dict when
            ``geometry_only`` is set.

        Raises:
            AssertionError: if ``output_normal`` is True (only checked when
                ``geometry_only`` is False).
        """
        if self.cfg.geometry_only:
            return {}
        assert (
            not output_normal
        ), f"Normal output is not supported for {self.__class__.__name__}"
        points = contract_to_unisphere(points, self.bbox)  # points normalized to (0, 1)
        # reshape (rather than view) so non-contiguous inputs are accepted,
        # matching ``export``.
        enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
        features = self.feature_network(enc).view(
            *points.shape[:-1], self.cfg.n_feature_dims
        )
        return {"features": features}

    @staticmethod
    @torch.no_grad()
    def create_from(
        other: BaseGeometry,
        cfg: Optional[Union[dict, DictConfig]] = None,
        copy_net: bool = True,
        **kwargs,
    ) -> "TetrahedraSDFGrid":
        """Build a ``TetrahedraSDFGrid`` initialized from another geometry.

        Args:
            other: source geometry; a ``TetrahedraSDFGrid``, ``ImplicitVolume``
                or ``ImplicitSDF``. For the two implicit types the source
                config may be modified in place to use the ``"mt"`` isosurface
                method at this instance's resolution.
            cfg: configuration for the new instance.
            copy_net: whether to also copy the encoding and feature network
                (skipped when the new instance is ``geometry_only``).
            **kwargs: forwarded to the ``TetrahedraSDFGrid`` constructor.

        Returns:
            The newly created instance.

        Raises:
            TypeError: if ``other`` is of an unsupported type.
            AssertionError: if ``other`` is a ``TetrahedraSDFGrid`` with a
                different ``isosurface_resolution``.
        """
        if isinstance(other, TetrahedraSDFGrid):
            instance = TetrahedraSDFGrid(cfg, **kwargs)
            assert instance.cfg.isosurface_resolution == other.cfg.isosurface_resolution
            instance.isosurface_bbox = other.isosurface_bbox.clone()
            instance.sdf.data = other.sdf.data.clone()
            if (
                instance.cfg.isosurface_deformable_grid
                and other.cfg.isosurface_deformable_grid
            ):
                assert (
                    instance.deformation is not None and other.deformation is not None
                )
                instance.deformation.data = other.deformation.data.clone()
            if (
                not instance.cfg.geometry_only
                and not other.cfg.geometry_only
                and copy_net
            ):
                _copy_feature_networks(instance, other)
            return instance
        elif isinstance(other, ImplicitVolume):
            instance = TetrahedraSDFGrid(cfg, **kwargs)
            mesh = _extract_source_mesh(other, instance.cfg.isosurface_resolution)
            instance.isosurface_bbox = mesh.extras["bbox"]
            # Grid levels from a volume source are clamped to [-1, 1].
            instance.sdf.data = (
                mesh.extras["grid_level"].to(instance.sdf.data).clamp(-1, 1)
            )
            if not instance.cfg.geometry_only and copy_net:
                _copy_feature_networks(instance, other)
            return instance
        elif isinstance(other, ImplicitSDF):
            instance = TetrahedraSDFGrid(cfg, **kwargs)
            mesh = _extract_source_mesh(other, instance.cfg.isosurface_resolution)
            instance.isosurface_bbox = mesh.extras["bbox"]
            instance.sdf.data = mesh.extras["grid_level"].to(instance.sdf.data)
            if (
                instance.cfg.isosurface_deformable_grid
                and other.cfg.isosurface_deformable_grid
            ):
                assert instance.deformation is not None
                instance.deformation.data = mesh.extras["grid_deformation"].to(
                    instance.deformation.data
                )
            if not instance.cfg.geometry_only and copy_net:
                _copy_feature_networks(instance, other)
            return instance
        else:
            raise TypeError(
                f"Cannot create {TetrahedraSDFGrid.__name__} from {other.__class__.__name__}"
            )

    def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
        """Return the features at ``points`` for export.

        Returns an empty dict when ``geometry_only`` is set or
        ``n_feature_dims`` is 0; otherwise ``{"features": ...}`` with the
        leading shape of ``points`` and a last dimension of ``n_feature_dims``.
        Extra keyword arguments are accepted and ignored.
        """
        out: Dict[str, Any] = {}
        if self.cfg.geometry_only or self.cfg.n_feature_dims == 0:
            return out
        points = contract_to_unisphere(points, self.bbox)
        enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
        features = self.feature_network(enc).view(
            *points.shape[:-1], self.cfg.n_feature_dims
        )
        out.update(
            {
                "features": features,
            }
        )
        return out