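# NOTE: non-code page header captured together with this file; kept as
# comments so that the module parses:
#   thewhole's picture
#   Upload 245 files
#   2fa4776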
import os
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import threestudio
from threestudio.models.geometry.base import (
BaseExplicitGeometry,
BaseGeometry,
contract_to_unisphere,
)
from threestudio.models.mesh import Mesh
from threestudio.models.networks import get_encoding, get_mlp
from threestudio.utils.ops import scale_tensor
from threestudio.utils.typing import *
@threestudio.register("custom-mesh")
class CustomMesh(BaseExplicitGeometry):
    """Explicit geometry whose surface is a triangle mesh loaded from a file.

    ``configure`` loads the mesh named by ``cfg.shape_init``
    (``"mesh:<path>"``), centers it on its vertex centroid, rescales it so
    the largest absolute vertex coordinate equals ``cfg.shape_init_params``
    and rotates it so the configured up/front axes become +z/+x.  The result
    is kept both as ``self.mesh`` and as the buffers ``v_buffer`` /
    ``t_buffer``.

    Point queries (``forward`` / ``export``) return features computed by the
    configured encoding followed by the configured MLP.
    """

    @dataclass
    class Config(BaseExplicitGeometry.Config):
        # Size of the last dimension of the query points fed to the encoding.
        n_input_dims: int = 3
        # Number of feature channels produced per query point.
        n_feature_dims: int = 3
        # Passed verbatim to `get_encoding`.
        pos_encoding_config: dict = field(
            default_factory=lambda: {
                "otype": "HashGrid",
                "n_levels": 16,
                "n_features_per_level": 2,
                "log2_hashmap_size": 19,
                "base_resolution": 16,
                "per_level_scale": 1.447269237440378,
            }
        )
        # Passed verbatim to `get_mlp`.
        mlp_network_config: dict = field(
            default_factory=lambda: {
                "otype": "VanillaMLP",
                "activation": "ReLU",
                "output_activation": "none",
                "n_neurons": 64,
                "n_hidden_layers": 1,
            }
        )
        # Must have the form "mesh:<path to mesh file>"; anything else is
        # rejected by `configure`.
        shape_init: str = ""
        # Must be a float: the largest absolute vertex coordinate of the mesh
        # after normalization.
        shape_init_params: Optional[Any] = None
        # Axis of the mesh file that points "up" / "front"; each must be one
        # of "+x", "+y", "+z", "-x", "-y", "-z" and the two must lie on
        # different axes.
        shape_init_mesh_up: str = "+z"
        shape_init_mesh_front: str = "+x"

    cfg: Config

    def configure(self) -> None:
        """Build the feature networks and load/normalize the initial mesh.

        Raises:
            ValueError: if ``cfg.shape_init`` does not start with ``"mesh:"``,
                the mesh file is missing or of an unknown type, the up/front
                axes are invalid, or the mesh has no usable extent.
            AssertionError: if ``cfg.shape_init_params`` is not a float.
        """
        super().configure()
        self.encoding = get_encoding(
            self.cfg.n_input_dims, self.cfg.pos_encoding_config
        )
        self.feature_network = get_mlp(
            self.encoding.n_output_dims,
            self.cfg.n_feature_dims,
            self.cfg.mlp_network_config,
        )

        # Initialize custom mesh
        if not self.cfg.shape_init.startswith("mesh:"):
            raise ValueError(
                f"Unknown shape initialization type: {self.cfg.shape_init}"
            )
        assert isinstance(self.cfg.shape_init_params, float)

        mesh_path = self.cfg.shape_init[5:]  # strip the "mesh:" prefix
        mesh = self._load_trimesh(mesh_path)

        # Work on a plain ndarray; the loaded trimesh object itself is only
        # needed for its faces from here on.
        vertices = np.asarray(mesh.vertices, dtype=np.float64)
        if vertices.shape[0] == 0:
            raise ValueError(f"Mesh file {mesh_path} contains no vertices.")

        # move to center
        vertices = vertices - vertices.mean(0)

        # align to up-z and front-x
        mesh2std = self._mesh_to_std_rotation(
            self.cfg.shape_init_mesh_up, self.cfg.shape_init_mesh_front
        )

        # scaling: the largest absolute coordinate becomes shape_init_params
        scale = np.abs(vertices).max()
        if not (np.isfinite(scale) and scale > 0):
            # Dividing by this extent would fill the mesh with NaN/inf.
            raise ValueError(
                f"Mesh file {mesh_path} has a degenerate extent ({scale}); "
                "cannot normalize its scale."
            )
        vertices = vertices / scale * self.cfg.shape_init_params
        vertices = np.ascontiguousarray(np.dot(mesh2std, vertices.T).T)

        v_pos = torch.tensor(vertices, dtype=torch.float32).to(self.device)
        t_pos_idx = torch.tensor(mesh.faces, dtype=torch.int64).to(self.device)
        self.mesh = Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx)
        # Also keep the geometry as buffers so `isosurface` can rebuild the
        # Mesh from them when the `mesh` attribute is not present.
        self.register_buffer(
            "v_buffer",
            v_pos,
        )
        self.register_buffer(
            "t_buffer",
            t_pos_idx,
        )

    @staticmethod
    def _load_trimesh(mesh_path: str):
        """Load ``mesh_path`` with trimesh and return a single ``Trimesh``.

        A loaded ``Scene`` is flattened by concatenating all of its
        geometries into one mesh.

        Raises:
            ValueError: if the file does not exist or loads as an
                unsupported type.
        """
        if not os.path.exists(mesh_path):
            raise ValueError(f"Mesh file {mesh_path} does not exist.")

        # Imported here so trimesh is only required when a mesh is loaded.
        import trimesh

        scene = trimesh.load(mesh_path)
        if isinstance(scene, trimesh.Trimesh):
            return scene
        if isinstance(scene, trimesh.scene.Scene):
            mesh = trimesh.Trimesh()
            for obj in scene.geometry.values():
                mesh = trimesh.util.concatenate([mesh, obj])
            return mesh
        raise ValueError(f"Unknown mesh type at {mesh_path}.")

    @staticmethod
    def _mesh_to_std_rotation(up: str, front: str) -> np.ndarray:
        """Return the 3x3 matrix mapping mesh axes to up=+z, front=+x.

        Raises:
            ValueError: if either direction is not a signed axis name or the
                two directions lie on the same axis.
        """
        dirs = ["+x", "+y", "+z", "-x", "-y", "-z"]
        dir2vec = {
            "+x": np.array([1, 0, 0]),
            "+y": np.array([0, 1, 0]),
            "+z": np.array([0, 0, 1]),
            "-x": np.array([-1, 0, 0]),
            "-y": np.array([0, -1, 0]),
            "-z": np.array([0, 0, -1]),
        }
        if up not in dirs or front not in dirs:
            raise ValueError(
                f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}."
            )
        # Same axis letter (regardless of sign) means the two are parallel.
        if up[1] == front[1]:
            raise ValueError(
                "shape_init_mesh_up and shape_init_mesh_front must be orthogonal."
            )
        z_, x_ = dir2vec[up], dir2vec[front]
        y_ = np.cross(z_, x_)
        # Columns of std2mesh are the standard x/y/z axes expressed in mesh
        # coordinates; its inverse maps mesh coordinates to the standard frame.
        std2mesh = np.stack([x_, y_, z_], axis=0).T
        return np.linalg.inv(std2mesh)

    def isosurface(self) -> Mesh:
        """Return the stored mesh, rebuilding it from the buffers if needed.

        Raises:
            ValueError: if neither ``self.mesh`` nor ``self.v_buffer`` exists.
        """
        if hasattr(self, "mesh"):
            return self.mesh
        if hasattr(self, "v_buffer"):
            self.mesh = Mesh(v_pos=self.v_buffer, t_pos_idx=self.t_buffer)
            return self.mesh
        raise ValueError("custom mesh is not initialized")

    def _query_features(
        self, points: Float[Tensor, "*N Di"]
    ) -> Float[Tensor, "..."]:
        """Encode ``points`` and run the feature MLP.

        The result keeps the leading dimensions of ``points`` and has
        ``cfg.n_feature_dims`` channels in the last dimension.
        """
        points = contract_to_unisphere(points, self.bbox)  # points normalized to (0, 1)
        # reshape (not view) so non-contiguous inputs are accepted as well.
        enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
        return self.feature_network(enc).view(
            *points.shape[:-1], self.cfg.n_feature_dims
        )

    def forward(
        self, points: Float[Tensor, "*N Di"], output_normal: bool = False
    ) -> Dict[str, Float[Tensor, "..."]]:
        """Return ``{"features": ...}`` for the query ``points``.

        Args:
            points: query positions; the last dimension has size
                ``cfg.n_input_dims``.
            output_normal: must be False; normals are not supported.

        Raises:
            AssertionError: if ``output_normal`` is truthy.
        """
        assert (
            not output_normal
        ), f"Normal output is not supported for {self.__class__.__name__}"
        return {"features": self._query_features(points)}

    def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
        """Return exportable per-point outputs for ``points``.

        Yields an empty dict when ``cfg.n_feature_dims`` is 0, otherwise
        ``{"features": ...}``.  Extra keyword arguments are accepted and
        ignored.
        """
        out: Dict[str, Any] = {}
        if self.cfg.n_feature_dims == 0:
            return out
        out.update(
            {
                "features": self._query_features(points),
            }
        )
        return out