salad-demo / salad /utils /meshutil.py
DveloperY0115's picture
init repo
801501a
raw
history blame
7.24 kB
from enum import Enum
import numpy as np
import torch
import trimesh
from salad.utils import thutil
def write_obj(name: str, vertices: np.ndarray, faces: np.ndarray):
    """Write a triangle mesh to a Wavefront OBJ file.

    Args:
        name: Output filename (overwritten if it exists).
        vertices: (V, 3) vertex positions. Passed through ``thutil.th2np``
            first (presumably so torch tensors are accepted too).
        faces: (F, 3) zero-based vertex indices, cast to ``uint32``. Assumes a
            triangle mesh: only the first three indices of each row are written.
    """
    vertices = thutil.th2np(vertices)
    faces = thutil.th2np(faces).astype(np.uint32)
    # The context manager closes the file even if a write fails midway.
    with open(name, "w") as fout:
        for vert in vertices:
            fout.write(
                "v " + str(vert[0]) + " " + str(vert[1]) + " " + str(vert[2]) + "\n"
            )
        for face in faces:
            # OBJ face indices are 1-based.
            fout.write(
                "f "
                + str(face[0] + 1)
                + " "
                + str(face[1] + 1)
                + " "
                + str(face[2] + 1)
                + "\n"
            )
def write_obj_triangle(name: str, vertices: np.ndarray, triangles: np.ndarray):
    """Write vertices and triangles to a Wavefront OBJ file.

    Args:
        name: Output filename (overwritten if it exists).
        vertices: (V, 3) indexable vertex positions.
        triangles: (F, 3) zero-based vertex indices; only the first three
            entries of each row are written.
    """
    # The context manager closes the file even if a write fails midway.
    with open(name, "w") as fout:
        for vert in vertices:
            fout.write(
                "v " + str(vert[0]) + " " + str(vert[1]) + " " + str(vert[2]) + "\n"
            )
        for tri in triangles:
            # OBJ face indices are 1-based.
            fout.write(
                "f "
                + str(tri[0] + 1)
                + " "
                + str(tri[1] + 1)
                + " "
                + str(tri[2] + 1)
                + "\n"
            )
def write_obj_polygon(name: str, vertices: np.ndarray, polygons: np.ndarray):
    """Write vertices and arbitrary-size polygons to a Wavefront OBJ file.

    Args:
        name: Output filename (overwritten if it exists).
        vertices: (V, 3) indexable vertex positions.
        polygons: Sequence of zero-based index sequences; rows may have
            different lengths (e.g. a list of lists mixing quads and triangles).
    """
    # The context manager closes the file even if a write fails midway.
    with open(name, "w") as fout:
        for vert in vertices:
            fout.write(
                "v " + str(vert[0]) + " " + str(vert[1]) + " " + str(vert[2]) + "\n"
            )
        for poly in polygons:
            # OBJ face indices are 1-based.
            fout.write("f" + "".join(" " + str(idx + 1) for idx in poly) + "\n")
def read_obj(name: str):
    """Read vertex positions and triangle faces from a Wavefront OBJ file.

    Only ``v`` and ``f`` records are parsed; all other records (``vn``,
    ``vt``, groups, materials, comments) are ignored.

    Args:
        name: Path to the OBJ file.

    Returns:
        Tuple ``(v, f)``:
        - v: (V, 3) float32 array of vertex positions (extra components such
          as a ``w`` coordinate are dropped).
        - f: (F, 3) int32 array of zero-based vertex indices. Only the first
          three vertices of each face record are kept, so polygons with more
          than three vertices are truncated to their first triangle.
        A file without vertex or face records yields an empty (0, 3) array
        for that output instead of raising.
    """
    verts = []
    faces = []
    with open(name, "r") as fin:
        for line in fin:
            line = line.rstrip()
            if line.startswith("v "):
                verts.append(np.float32(line.split()[1:4]))
            elif line.startswith("f "):
                # "f 1/2/3 ..." -> keep only the vertex index before the first '/'.
                idx = np.int32([item.split("/")[0] for item in line.split()[1:4]])
                # OBJ allows negative indices relative to the vertices read so
                # far (-1 is the latest one); map them to 1-based absolute ones.
                idx = np.where(idx < 0, idx + len(verts) + 1, idx)
                faces.append(idx)
    v = np.vstack(verts) if verts else np.zeros((0, 3), dtype=np.float32)
    f = np.vstack(faces) - 1 if faces else np.zeros((0, 3), dtype=np.int32)
    return v, f
def scene_as_mesh(scene_or_mesh):
    """Collapse a ``trimesh.Scene`` into a single triangle mesh.

    Inputs that are not a ``trimesh.Scene`` are returned unchanged. For a
    scene, each geometry with an (N, 3) ``faces`` array is rebuilt as a plain
    ``trimesh.Trimesh`` from its vertices and faces only, and the results are
    concatenated. Geometries without ``faces`` (e.g. point clouds or paths)
    or with non-triangular faces are skipped.

    NOTE(review): geometries are merged in their own local frames; no
    scene-graph transform is applied here.

    Returns:
        The merged mesh, the input itself when it is not a scene, or ``None``
        when the scene holds no triangle geometry.
    """
    if not isinstance(scene_or_mesh, trimesh.Scene):
        return scene_or_mesh
    triangle_meshes = [
        trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
        for g in scene_or_mesh.geometry.values()
        # getattr: non-mesh geometries have no `faces` and used to raise here.
        if getattr(g, "faces", None) is not None
        and len(g.faces.shape) == 2
        and g.faces.shape[1] == 3
    ]
    if not triangle_meshes:
        # Covers both an empty scene and one whose geometries were all skipped.
        return None
    return trimesh.util.concatenate(tuple(triangle_meshes))
def get_center(verts):
    """Return the midpoint of the axis-aligned bounding box of ``verts``.

    ``verts`` is reduced along axis 0, so for an (N, D) array the result is a
    length-D vector halfway between the per-axis minimum and maximum.
    """
    lo, hi = verts.min(0), verts.max(0)
    return (hi + lo) / 2
def to_center(verts):
    """Translate ``verts`` in place so its bounding-box center is the origin.

    The input array is modified and also returned for convenience.
    """
    bbox_mid = (verts.max(0) + verts.min(0)) / 2
    verts -= bbox_mid[None, :]
    return verts
def get_offset_and_scale(verts, radius=1.0):
    """Compute the translation and scale that fit ``verts`` into a ball.

    Applying ``(verts - offset) * scale`` moves the bounding-box center to the
    origin and puts the farthest vertex at distance ``radius``.

    Args:
        verts: (N, D) vertex positions. Passed through ``thutil.th2np`` first
            (presumably so torch tensors are accepted too). Not modified.
        radius: Target distance of the farthest vertex from the origin.

    Returns:
        Tuple ``(offset, scale)``: ``offset`` is the (1, D) bounding-box
        center and ``scale`` a scalar factor.
        NOTE(review): if all vertices coincide, the largest norm is 0 and
        ``scale`` becomes ``inf``.
    """
    verts = thutil.th2np(verts)
    offset = get_center(verts)[None, :]
    # Out-of-place subtraction never touches the caller's array. Unlike the
    # former `copy()` + `-=`, it also works for integer vertex arrays, where
    # the float offset is promoted instead of raising a casting error.
    centered = verts - offset
    scale = 1 / np.linalg.norm(centered, axis=1).max() * radius
    return offset, scale
def normalize_mesh(mesh: trimesh.Trimesh):
    """Shift and scale a mesh into the unit cube.

    The bounding-box minimum is moved to the origin and the mesh is divided
    by its bounding-box diagonal length. Every extent is then at most 1, so
    the mesh lies inside [0, 1]^3.
    NOTE(review): a degenerate mesh (all vertices equal) has diagonal 0 and
    yields non-finite coordinates.

    Returns:
        dict with ``mesh`` (new ``trimesh.Trimesh`` built with
        ``process=False``), ``offset`` (per-axis minimum that was subtracted)
        and ``scale`` (diagonal length that was divided by).
    """
    vertices = np.array(mesh.vertices)
    faces = np.array(mesh.faces)
    lower = np.min(vertices, 0)
    upper = np.max(vertices, 0)
    diagonal = np.sqrt(((upper - lower) ** 2).sum())
    shifted = (vertices - lower) / diagonal
    return {
        "mesh": trimesh.Trimesh(vertices=shifted, faces=faces, process=False),
        "offset": lower,
        "scale": diagonal,
    }
def normalize_scene(scene: trimesh.Scene):
    """Normalize every geometry of a scene with one shared transform.

    The offset and scale come from ``normalize_mesh`` applied to the merged
    mesh from ``scene_as_mesh``. The same transform is then applied to each
    geometry, so the parts keep their relative placement.

    Returns:
        A new ``trimesh.Scene`` of normalized meshes, or an empty
        ``trimesh.Scene`` when there is nothing to normalize.
    """
    mesh_merged = scene_as_mesh(scene)
    if mesh_merged is None:
        # Empty scene: previously this crashed inside normalize_mesh(None).
        return trimesh.Scene()
    normalization = normalize_mesh(mesh_merged)
    offset = normalization["offset"]
    scale = normalization["scale"]
    normalized_parts = []
    for submesh in scene.geometry.values():
        v = (np.array(submesh.vertices) - offset) / scale
        f = np.array(submesh.faces)
        normalized_parts.append(trimesh.Trimesh(v, f))
    return trimesh.Scene(normalized_parts)
class SampleBy(Enum):
    """Face-selection strategy used by ``sample_on_mesh``."""

    # Faces drawn with probability proportional to their area.
    AREAS = 0
    # Faces drawn uniformly, regardless of area.
    FACES = 1
    # Hybrid: the sample budget is split between AREAS and FACES.
    HYB = 2
def get_faces_normals(mesh):
    """Compute unnormalized per-face normals of a triangle mesh.

    Args:
        mesh: Either a ``(vs, faces)`` pair (vertex tensor (V, D) and an
            integer index tensor (F, 3)), or a tensor of already gathered face
            corners of shape (F, 3, D). D may be 2 (planar input, padded with
            z = 0) or 3.

    Returns:
        (F, 3) tensor of cross products ``(v1 - v0) x (v2 - v1)``. These are
        not normalized; their length is twice the triangle area.
    """
    if isinstance(mesh, torch.Tensor):
        vs_faces = mesh
    else:
        vs, faces = mesh
        vs_faces = vs[faces]
    if vs_faces.shape[-1] == 2:
        # Lift planar triangles into the z = 0 plane so the cross product exists.
        zeros = torch.zeros(
            *vs_faces.shape[:2], 1, dtype=vs_faces.dtype, device=vs_faces.device
        )
        vs_faces = torch.cat((vs_faces, zeros), dim=2)
    edge_a = vs_faces[:, 1, :] - vs_faces[:, 0, :]
    edge_b = vs_faces[:, 2, :] - vs_faces[:, 1, :]
    # Explicit dim: without it torch.cross picks the FIRST dimension of size 3,
    # which is the face axis whenever the mesh has exactly 3 faces.
    return torch.cross(edge_a, edge_b, dim=-1)
def compute_face_areas(mesh):
    """Return per-face areas and unit normals.

    Args:
        mesh: Anything accepted by ``get_faces_normals``.

    Returns:
        ``(face_areas, face_normals)``: (F,) triangle areas and (F, 3) unit
        normals. Zero-area faces keep an all-zero normal instead of NaN.
    """
    raw_normals = get_faces_normals(mesh)
    lengths = torch.norm(raw_normals, p=2, dim=1)
    # Divide zero-length normals by 1 instead of 0 to avoid NaNs.
    safe_lengths = lengths.masked_fill(lengths == 0, 1)
    unit_normals = raw_normals / safe_lengths.unsqueeze(1)
    return lengths * 0.5, unit_normals
def sample_uvw(shape, device):
    """Draw random barycentric coordinates, uniform over a triangle.

    Args:
        shape: Batch shape (sequence of ints) of the samples.
        device: Torch device for the generated tensors.

    Returns:
        Tensor of shape ``(*shape, 3)`` whose last axis is ``(u, v, w)`` with
        ``u + v + w == 1``.
    """
    u = torch.rand(*shape, device=device)
    v = torch.rand(*shape, device=device)
    # Pairs with u + v > 1 lie outside the triangle; reflect them back inside.
    outside = (u + v) > 1
    u = torch.where(outside, 1 - u, u)
    v = torch.where(outside, 1 - v, v)
    w = 1 - (u + v)
    return torch.stack((u, v, w), dim=len(shape))
def sample_on_mesh(mesh, num_samples: int, face_areas=None, sample_s=SampleBy.HYB):
    """Sample points on a triangle mesh, or subsample a point cloud.

    Args:
        mesh: ``(vs, faces)`` pair. ``vs`` is a (V, D) vertex tensor and
            ``faces`` an (F, 3) integer tensor, or ``None`` to treat ``vs`` as
            a point cloud.
        num_samples: Number of points to draw (positive). For a point cloud
            with fewer than ``num_samples`` points, all points are returned.
        face_areas: Optional precomputed (F,) face areas. Computed with
            ``compute_face_areas`` when needed and not given. Never modified.
        sample_s: Face-selection strategy. ``AREAS`` picks faces with
            probability proportional to area, ``FACES`` uniformly, and ``HYB``
            splits ``num_samples`` between the two (area-weighted first).

    Returns:
        ``(samples, chosen_faces_inds, uvw)``:
        - samples: sampled points, (num_samples, D) for a mesh.
        - chosen_faces_inds: chosen face indices (point indices for a point
          cloud). A tensor for ``HYB`` and point clouds; for ``AREAS`` and
          ``FACES`` a one-element list holding the index tensor (kept as-is
          for backward compatibility).
        - uvw: (num_samples, 3) barycentric weights, or ``None`` for a point
          cloud.
    """
    vs, faces = mesh
    if faces is None:  # sample from pc
        uvw = None
        if vs.shape[0] < num_samples:
            chosen_faces_inds = torch.arange(vs.shape[0])
        else:
            chosen_faces_inds = torch.argsort(torch.rand(vs.shape[0]))[:num_samples]
        samples = vs[chosen_faces_inds]
        return samples, chosen_faces_inds, uvw

    weighted_p = []
    if sample_s == SampleBy.AREAS or sample_s == SampleBy.HYB:
        if face_areas is None:
            face_areas, _ = compute_face_areas(mesh)
        # Zero out NaN areas out-of-place so a caller-supplied tensor is untouched.
        face_areas = torch.where(
            torch.isnan(face_areas), torch.zeros_like(face_areas), face_areas
        )
        weighted_p.append(face_areas / face_areas.sum())
    if sample_s == SampleBy.FACES or sample_s == SampleBy.HYB:
        weighted_p.append(torch.ones(faces.shape[0], device=vs.device))

    # Split the budget so the per-strategy counts add up to num_samples exactly.
    # Plain `num_samples // len(weighted_p)` lost one sample for odd counts
    # with HYB, and the einsum below then failed against uvw's batch size.
    base, extra = divmod(num_samples, len(weighted_p))
    counts = [base + (1 if i < extra else 0) for i in range(len(weighted_p))]
    chosen_faces_inds = [
        torch.multinomial(weights, count, replacement=True)
        for weights, count in zip(weighted_p, counts)
        if count > 0  # multinomial rejects zero draws (HYB with num_samples == 1)
    ]
    if sample_s == SampleBy.HYB:
        chosen_faces_inds = torch.cat(chosen_faces_inds, dim=0)
        face_index = chosen_faces_inds
    else:
        face_index = chosen_faces_inds[0]
    chosen_faces = faces[face_index]
    uvw = sample_uvw([num_samples], vs.device)
    # Blend each chosen face's three corners with its barycentric weights.
    samples = torch.einsum("sf,sfd->sd", uvw, vs[chosen_faces])
    return samples, chosen_faces_inds, uvw
def repair_normals(v, f):
    """Run ``trimesh.repair.fix_normals`` on the mesh ``(v, f)``.

    Returns:
        ``(vertices, faces)`` of the repaired mesh; faces as a plain ndarray.
        NOTE(review): the mesh is built with trimesh's default constructor
        arguments, which may merge or reorder vertices. Confirm callers do not
        rely on the input vertex order.
    """
    repaired = trimesh.Trimesh(v, f)
    trimesh.repair.fix_normals(repaired)
    return repaired.vertices, np.asarray(repaired.faces)