Spaces:
Sleeping
Sleeping
from typing import Optional | |
import numpy as np | |
import torch as th | |
import torch.nn.functional as F | |
import torch.nn as nn | |
from sklearn.neighbors import KDTree | |
import logging | |
logger = logging.getLogger(__name__) | |
# NOTE: we need pytorch3d primarily for UV rasterization things | |
from pytorch3d.renderer.mesh.rasterize_meshes import rasterize_meshes | |
from pytorch3d.structures import Meshes | |
from typing import Union, Optional, Tuple | |
import trimesh | |
from trimesh import Trimesh | |
from trimesh.triangles import points_to_barycentric | |
try:
    # pyre-fixme[21]: Could not find module `igl`.
    from igl import point_mesh_squared_distance  # @manual

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def closest_point(mesh, points):
        """Closest points on ``mesh`` to ``points``, computed with IGL.

        Drop-in replacement for ``trimesh.proximity.closest_point`` that uses
        IGL for faster queries. It returns ``(closest, distance, triangle_id)``
        in the same order and with the same meaning as trimesh.

        IGL reports *squared* distances. They are square-rooted here so both
        code paths agree on what the second return value means.
        """
        sq_dist, face_idxs, p = point_mesh_squared_distance(
            points, mesh.vertices, mesh.faces
        )
        return p, np.sqrt(sq_dist), face_idxs

except ImportError:
    # Fallback: trimesh's (slower) implementation with the same signature.
    from trimesh.proximity import closest_point
def closest_point_barycentrics(v, vi, points):
    """Project query points onto a triangle mesh and express them barycentrically.

    Args:
        v: np.array (float)
            [V, 3] mesh vertices
        vi: np.array (int)
            [F, 3] mesh triangle indices
        points: np.array (float)
            [M, 3] query points
    Returns:
        Tuple[approx, barys, interp_idxs, face_idxs]
            approx: [M, 3] closest points on the mesh, rebuilt from ``barys``
            barys: [M, 3] barycentric weights of each closest point in its triangle
            interp_idxs: [M, 3] vertex indices of that triangle (``vi[face_idxs]``)
            face_idxs: [M] index of the triangle holding each closest point
    """
    n_query = points.shape[0]
    mesh = Trimesh(vertices=v, faces=vi, process=False)

    nearest, _, face_idxs = closest_point(mesh, points)
    nearest = nearest.reshape((n_query, 3))
    face_idxs = face_idxs.reshape((n_query,))

    barys = points_to_barycentric(mesh.triangles[face_idxs], nearest)
    interp_idxs = vi[face_idxs]

    # Re-synthesize each surface point from its weights; each weight is an
    # [M, 1] column so it broadcasts over xyz.
    w0, w1, w2 = barys[:, 0:1], barys[:, 1:2], barys[:, 2:3]
    approx = (
        w0 * v[interp_idxs[:, 0]]
        + w1 * v[interp_idxs[:, 1]]
        + w2 * v[interp_idxs[:, 2]]
    )
    return approx, barys, interp_idxs, face_idxs
def make_uv_face_index(
    vt: th.Tensor,
    vti: th.Tensor,
    uv_shape: Union[Tuple[int, int], int],
    flip_uv: bool = True,
    device: Optional[Union[str, th.device]] = None,
):
    """Rasterize the UV layout into a per-texel face index map.

    Each texel receives the index of the mesh face whose UV triangle covers
    it, or -1 when no triangle covers it. Rasterization runs through
    pytorch3d on a CUDA device (``device`` must be a CUDA device if given).

    Args:
        vt: UV coordinates, indexed by ``vti``.
        vti: per-face UV indices.
        uv_shape: output size, either an int (square) or an (H, W) pair.
        flip_uv: additionally mirror the second UV coordinate before rasterizing.
        device: CUDA device to rasterize on; defaults to ``"cuda"``.
    """
    if isinstance(uv_shape, int):
        uv_shape = (uv_shape, uv_shape)

    # Stretch the shorter axis so non-square maps keep texel aspect ratio.
    long_axis = uv_shape.index(max(uv_shape))
    short_axis = uv_shape.index(min(uv_shape))
    aspect = uv_shape[long_axis] / uv_shape[short_axis]

    if device is None:
        dev = th.device("cuda")
    else:
        dev = th.device(device) if isinstance(device, str) else device
        assert dev.type == "cuda"

    # Both UV axes are mirrored before mapping into rasterizer coordinates.
    coords = 1.0 - vt.clone()
    if flip_uv:
        coords = coords.clone()
        coords[:, 1] = 1 - coords[:, 1]

    # [0, 1] -> [-1, 1], plus a constant z = 1 so every triangle is in front.
    ndc = 2.0 * coords.to(dev) - 1.0
    ndc = th.cat([ndc, th.ones_like(ndc[:, 0:1])], dim=1)
    ndc[:, short_axis] *= aspect

    meshes = Meshes(ndc[np.newaxis], vti[np.newaxis].to(dev))
    with th.no_grad():
        pix_to_face, _, _, _ = rasterize_meshes(
            meshes, uv_shape, faces_per_pixel=1, z_clip_value=0.0, bin_size=0
        )
    # Drop the batch axis and the single faces-per-pixel slot.
    return pix_to_face[0, ..., 0]
def make_uv_vert_index(
    vt: th.Tensor,
    vi: th.Tensor,
    vti: th.Tensor,
    uv_shape: Union[Tuple[int, int], int],
    flip_uv: bool = True,
):
    """Map every texel to the three mesh vertex indices of its covering triangle.

    Texels not covered by any UV triangle get -1 in all three slots. The
    result is returned as an int64 tensor.
    """
    face_map = make_uv_face_index(vt, vti, uv_shape, flip_uv)
    # Uncovered texels hold face -1: look up face 0 for them so the gather
    # stays in range, then overwrite those rows with -1.
    corner_verts = vi[face_map.clamp(min=0)]
    corner_verts[face_map < 0] = -1
    return corner_verts.long()
def bary_coords(points: th.Tensor, triangles: th.Tensor, eps: float = 1.0e-6):
    """Barycentric coordinates of 2D points with respect to their triangles.

    Args:
        points: [P, 2] query points.
        triangles: [3, P, 2] triangle corners; ``triangles[k, i]`` is corner k
            of the triangle associated with point i.
        eps: minimum magnitude of the triangle determinant, to avoid dividing
            by zero on degenerate triangles (its sign is preserved).

    Returns:
        [3, P] tensor of weights (b0, b1, b2), with b2 = 1 - b0 - b1.
    """
    corner0, corner1, corner2 = triangles[0], triangles[1], triangles[2]

    # Express everything relative to corner 2.
    dx = points[:, 0] - corner2[:, 0]
    dy = points[:, 1] - corner2[:, 1]
    e0x = corner0[:, 0] - corner2[:, 0]
    e0y = corner0[:, 1] - corner2[:, 1]
    e1x = corner1[:, 0] - corner2[:, 0]
    e1y = corner1[:, 1] - corner2[:, 1]

    det = e1y * e0x - e0y * e1x
    # Keep the determinant's sign but push its magnitude up to at least eps.
    det = th.where(det >= 0, det.clamp(min=eps), det.clamp(max=-eps))

    w0 = (e1y * dx - e1x * dy) / det
    w1 = (e0x * dy - e0y * dx) / det
    w2 = 1.0 - w0 - w1
    return th.stack((w0, w1, w2))
def make_uv_barys(
    vt: th.Tensor,
    vti: th.Tensor,
    uv_shape: Union[Tuple[int, int], int],
    flip_uv: bool = True,
):
    """Rasterize the UV layout into per-texel face indices and barycentrics.

    Every texel gets the barycentric coordinates of its center inside the
    UV triangle that covers it. Texels not covered by any triangle get
    face index -1 and all-zero barycentrics.

    Args:
        vt: UV coordinates, indexed by ``vti``.
        vti: per-face UV indices.
        uv_shape: output size, either an int (square) or an (H, W) pair.
        flip_uv: mirror the second UV coordinate first.

    Returns:
        (face_index_map, bary_map): the [H, W] face index map and the
        [H, W, 3] barycentric map.
    """
    if isinstance(uv_shape, int):
        uv_shape = (uv_shape, uv_shape)

    if flip_uv:
        # Some topology files store texture coordinates in OpenGL convention
        # (v = 0 at the bottom of the texture), unlike numpy/torch arrays.
        vt = vt.clone()
        vt[:, 1] = 1 - vt[:, 1]

    # The flip (if requested) was already applied above; don't apply it twice.
    face_index_map = make_uv_face_index(vt, vti, uv_shape, flip_uv=False)
    # UV indices of the covering triangle per texel. Uncovered texels (-1)
    # are clamped to face 0 here and zeroed out at the end.
    corner_uv_idx = vti.long()[face_index_map.clamp(min=0)]

    long_axis = uv_shape.index(max(uv_shape))
    short_axis = uv_shape.index(min(uv_shape))
    aspect = uv_shape[long_axis] / uv_shape[short_axis]

    vt_ndc = vt * 2 - 1
    vt_ndc[:, short_axis] *= aspect
    tri_uvs = vt_ndc[corner_uv_idx].permute(2, 0, 1, 3)  # [3, H, W, 2]

    # Texel centers, normalized to (0, 1) along each image axis.
    rows = th.linspace(0.5, uv_shape[0] - 0.5, uv_shape[0]) / uv_shape[0]
    cols = th.linspace(0.5, uv_shape[1] - 0.5, uv_shape[1]) / uv_shape[1]
    grid_r, grid_c = th.meshgrid(rows, cols)
    # (column, row) ordering so the last axis lines up with (u, v).
    texel_xy = th.stack((grid_c, grid_r), dim=2).to(tri_uvs)
    texel_xy = texel_xy * 2 - 1
    texel_xy[..., short_axis] *= aspect

    bary_map = bary_coords(texel_xy.view(-1, 2), tri_uvs.view(3, -1, 2))
    bary_map = bary_map.permute(1, 0).view(uv_shape[0], uv_shape[1], 3)
    bary_map[face_index_map < 0] = 0
    return face_index_map, bary_map
def index_image_impaint(
    index_image: th.Tensor,
    bary_image: Optional[th.Tensor] = None,
    distance_threshold=100.0,
):
    """Fill empty texels of a UV index image from their nearest filled texel.

    A texel counts as filled when its index is not -1. For a [H, W, C]
    image, a texel is filled when at least one channel is not -1. Each empty
    texel whose nearest filled texel is closer than ``distance_threshold``
    (Euclidean distance in texels) copies that texel's entry. Farther empty
    texels are left as they are.

    Args:
        index_image: [H, W] or [H, W, C] index image, -1 marking empty texels.
        bary_image: optional image indexed by the same (row, col) positions
            (e.g. barycentrics). It receives the same copies.
        distance_threshold: maximum fill distance, in texels.

    Returns:
        The filled index image, or ``(index_image, bary_image)`` when
        ``bary_image`` is given. The inputs are never modified.

    Raises:
        ValueError: if ``index_image`` is neither 2-D nor 3-D.
    """
    if len(index_image.shape) == 3:
        valid_index = (index_image != -1).any(dim=-1)
    elif len(index_image.shape) == 2:
        valid_index = index_image != -1
    else:
        raise ValueError("`index_image` should be a [H,W] or [H,W,C] image")

    index_image_imp = index_image.clone()
    bary_image_imp = bary_image.clone() if bary_image is not None else None

    device = index_image.device
    valid_ij = th.stack(th.where(valid_index), dim=-1)
    invalid_ij = th.stack(th.where(~valid_index), dim=-1)

    # A KDTree cannot be built from, or queried with, zero points. Nothing
    # needs filling when every texel is covered, and there is nothing to
    # copy from when none is.
    if len(valid_ij) > 0 and len(invalid_ij) > 0:
        lookup_valid = KDTree(valid_ij.cpu().numpy())
        dists, idxs = lookup_valid.query(invalid_ij.cpu().numpy())
        # query() returns [M, k] arrays with k=1; keep the single nearest hit.
        idxs = th.as_tensor(idxs, device=device)[..., 0]
        dists = th.as_tensor(dists, device=device)[..., 0]

        close_enough = dists < distance_threshold
        src_ij = valid_ij[idxs][close_enough]
        dst_ij = invalid_ij[close_enough]

        index_image_imp[dst_ij[:, 0], dst_ij[:, 1]] = index_image[
            src_ij[:, 0], src_ij[:, 1]
        ]
        if bary_image_imp is not None:
            bary_image_imp[dst_ij[:, 0], dst_ij[:, 1]] = bary_image[
                src_ij[:, 0], src_ij[:, 1]
            ]

    if bary_image_imp is not None:
        return index_image_imp, bary_image_imp
    return index_image_imp
class GeometryModule(nn.Module):
    """Mesh topology plus precomputed UV-space lookup images.

    The mesh arrays are stored as buffers. At construction time the UV
    layout is rasterized into three images used to move per-vertex data
    into UV space:

    * ``index_image``: per-texel vertex indices of the covering triangle
      (-1 where no triangle covers the texel),
    * ``face_index_image``: per-texel face index (-1 where uncovered),
    * ``bary_image``: per-texel barycentric weights (0 where uncovered).

    With ``impaint=True``, uncovered texels near the UV islands are filled
    from their nearest covered texel (see ``index_image_impaint``).
    """

    def __init__(
        self,
        v,
        vi,
        vt,
        vti,
        uv_size,
        v2uv: Optional[th.Tensor] = None,
        flip_uv=False,
        impaint=False,
        impaint_threshold=100.0,
    ):
        super().__init__()

        self.register_buffer("v", th.as_tensor(v))
        self.register_buffer("vi", th.as_tensor(vi))
        self.register_buffer("vt", th.as_tensor(vt))
        self.register_buffer("vti", th.as_tensor(vti))
        if v2uv is not None:
            self.register_buffer("v2uv", th.as_tensor(v2uv, dtype=th.int64))

        # TODO: should we just pass topology here?
        self.n_verts = vi.max() + 1

        self.uv_size = uv_size

        # TODO: can't we just index face_index?
        index_image = make_uv_vert_index(
            self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv
        ).cpu()
        face_index, bary_image = make_uv_barys(
            self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv
        )
        if impaint:
            # uv_size may be a single int (square map) or an (H, W) pair;
            # min() only works on the pair form.
            uv_shape = (uv_size, uv_size) if isinstance(uv_size, int) else uv_size
            if min(uv_shape) >= 1024:
                logger.info(
                    "impainting index image might take a while for sizes >= 1024"
                )

            index_image, bary_image = index_image_impaint(
                index_image, bary_image, impaint_threshold
            )
            # TODO: we can avoid doing this 2x
            face_index = index_image_impaint(
                face_index, distance_threshold=impaint_threshold
            )

        self.register_buffer("index_image", index_image.cpu())
        self.register_buffer("bary_image", bary_image.cpu())
        self.register_buffer("face_index_image", face_index.cpu())

    def render_index_images(self, uv_size, flip_uv=False, impaint=False):
        """Re-rasterize (index, face, barycentric) images at another UV size."""
        index_image = make_uv_vert_index(
            self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv
        )
        face_image, bary_image = make_uv_barys(
            self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv
        )
        if impaint:
            index_image, bary_image = index_image_impaint(
                index_image,
                bary_image,
            )
        return index_image, face_image, bary_image

    def vn(self, verts):
        """Per-vertex unit normals for a batch of vertex positions ``verts``."""
        # face_normals gathers with ``v[:, vi]`` and expects vi of shape
        # [F, 3]. Do not add a leading batch axis here: the gather would
        # become 5-D and the corner indexing would select faces instead of
        # triangle corners.
        return vert_normals(verts, self.vi.to(th.long))

    def to_uv(self, values):
        """Splat per-vertex ``values`` into UV space via the index/bary images."""
        return values_to_uv(values, self.index_image, self.bary_image)

    def from_uv(self, values_uv):
        """Sample UV-space ``values_uv`` back at the mesh UV coordinates.

        Requires ``v2uv`` to have been passed to the constructor.
        """
        # TODO: we need to sample this
        return sample_uv(values_uv, self.vt, self.v2uv.to(th.long))

    def rand_sample_3d_uv(self, count, uv_img):
        """
        Sample a set of 3D points on the surface of mesh, return corresponding interpolated values in UV space.
        Args:
            count - num of 3D points to be sampled
            uv_img - the image in uv space to be sampled, e.g., texture
        """
        _mesh = Trimesh(
            vertices=self.v.detach().cpu().numpy(),
            faces=self.vi.detach().cpu().numpy(),
            process=False,
        )
        points, _ = trimesh.sample.sample_surface(_mesh, count)
        return self.sample_uv_from_3dpts(points, uv_img)

    def sample_uv_from_3dpts(self, points, uv_img):
        """Sample ``uv_img`` at the mesh-surface points closest to ``points``.

        Args:
            points: [N, 3] numpy array of 3D query points.
            uv_img: [H, W, C] tensor in UV space (e.g. a texture).

        Returns:
            ([N, C] numpy array of sampled values, ``points`` unchanged)
        """
        num_pts = points.shape[0]
        _, barys, _, face_idxs = closest_point_barycentrics(
            self.v.detach().cpu().numpy(), self.vi.detach().cpu().numpy(), points
        )
        # Barycentric weights refer to the face's corners. The matching UV
        # coordinates come from the face's *texture* indices (vti), not its
        # vertex indices (vi); the two differ wherever the UV layout has seams.
        face_idxs_t = th.as_tensor(face_idxs, dtype=th.long, device=self.vti.device)
        uv_idxs = self.vti.long()[face_idxs_t]  # [N, 3]
        interp_uv_coords = self.vt[uv_idxs]  # [N, 3, 2]
        barys_t = th.as_tensor(barys, device=self.vt.device)
        # Bary interpolation first, to get each point's UV coordinate.
        target_uv_coords = (
            th.sum(interp_uv_coords * barys_t[..., None], dim=1)
            .float()
            .to(uv_img.device)
        )
        # Then sample the UV image directly at those coordinates.
        sampled_values = sample_uv(
            values_uv=uv_img.permute(2, 0, 1)[None, ...], uv_coords=target_uv_coords
        )  # [1, N, C]
        approx_values = sampled_values[0].reshape(num_pts, uv_img.shape[2])
        return approx_values.detach().cpu().numpy(), points

    def vert_sample_uv(self, uv_img):
        """Sample ``uv_img`` at every mesh vertex; returns a [V, C] numpy array."""
        points = self.v.detach().cpu().numpy()
        approx_values, _ = self.sample_uv_from_3dpts(points, uv_img)
        return approx_values
def sample_uv(
    values_uv,
    uv_coords,
    v2uv: Optional[th.Tensor] = None,
    mode: str = "bilinear",
    align_corners: bool = True,
    flip_uvs: bool = False,
):
    """Sample a UV-space image at a set of UV coordinates.

    Args:
        values_uv: [B, C, H, W] image to sample.
        uv_coords: [N, 2] UV coordinates in [0, 1].
        v2uv: optional [M, K] index table; when given, the K samples listed
            for each row are averaged, giving [B, M, C].
        mode: ``F.grid_sample`` interpolation mode.
        align_corners: forwarded to ``F.grid_sample``.
        flip_uvs: use ``1 - v`` instead of ``v`` (input is not modified).

    Returns:
        [B, N, C] sampled values, or [B, M, C] when ``v2uv`` is given.
    """
    n_batch = values_uv.shape[0]

    if flip_uvs:
        uv_coords = uv_coords.clone()
        uv_coords[:, 1] = 1.0 - uv_coords[:, 1]

    # grid_sample wants a [B, H_out, W_out, 2] grid in [-1, 1]; the N query
    # points form an N x 1 "image".
    grid = 2.0 * uv_coords - 1.0
    grid = grid.unsqueeze(0).unsqueeze(2).expand(n_batch, -1, -1, -1)

    sampled = F.grid_sample(values_uv, grid, align_corners=align_corners, mode=mode)
    sampled = sampled.squeeze(-1).permute((0, 2, 1))  # [B, N, C]

    if v2uv is None:
        return sampled
    return sampled[:, v2uv].mean(2)
def values_to_uv(values, index_img, bary_img):
    """Splat per-vertex values into a UV image by barycentric interpolation.

    Args:
        values: [B, V, C] per-vertex values.
        index_img: [H, W, 3] vertex indices per texel (-1 where uncovered).
        bary_img: [H, W, 3] barycentric weights per texel.

    Returns:
        [B, C, H, W] tensor (same dtype/device as ``values``); texels whose
        index entries are all -1 are left at zero.
    """
    height, width = index_img.shape[0], index_img.shape[1]
    covered = (index_img != -1).all(dim=-1)

    tri_idx = index_img[covered].to(th.int64)  # [M, 3]
    tri_bary = bary_img[covered].to(th.float32)  # [M, 3]

    # values[:, tri_idx] is [B, M, 3, C]; move C ahead of (M, 3) so the
    # barycentric weights broadcast over the corner axis.
    corner_vals = values[:, tri_idx].permute(0, 3, 1, 2)
    texels = (corner_vals * tri_bary).sum(dim=-1)  # [B, C, M]

    out = values.new_zeros((values.shape[0], values.shape[-1], height, width))
    out[:, :, covered] = texels
    return out
def face_normals(v, vi, eps: float = 1e-5):
    """Unit normal of every triangle, batched over the first axis of ``v``.

    The normal is the normalized cross product (p1 - p0) x (p2 - p0).
    Triangles whose cross product is shorter than ``eps`` keep their
    (near-zero) unnormalized vector instead of being divided by ~0.

    Args:
        v: [B, V, 3] vertex positions.
        vi: triangle vertex indices; assumed [F, 3] -- TODO confirm for callers.
    """
    tri = v[:, vi]  # [B, F, 3 corners, 3 coords] for vi of shape [F, 3]
    edge_a = tri[:, :, 1] - tri[:, :, 0]
    edge_b = tri[:, :, 2] - tri[:, :, 0]
    normals = th.cross(edge_a, edge_b, dim=-1)

    lengths = th.norm(normals, dim=-1, keepdim=True)
    lengths[lengths < eps] = 1
    normals /= lengths
    return normals
def vert_normals(v, vi, eps: float = 1.0e-5):
    """Per-vertex unit normals from the unit normals of incident faces.

    Each face adds its unit normal once for every corner it shares with a
    vertex. The accumulated vector is then normalized. Vertices whose sum
    is shorter than ``eps`` (e.g. unreferenced vertices) are left
    unnormalized.

    Args:
        v: [B, V, 3] vertex positions.
        vi: triangle vertex indices; assumed [F, 3] -- TODO confirm for callers.
    """
    batch = v.shape[0]
    fn = face_normals(v, vi)

    # One copy of each face normal per corner, flattened in the same order
    # as vi.view(-1), so entry k pairs with vertex vi.view(-1)[k].
    corner_normals = fn.unsqueeze(2).expand(-1, -1, 3, -1).reshape(fn.shape[0], -1, 3)
    corner_vidx = vi.view(1, -1).expand(batch, -1)

    acc = th.zeros_like(v)
    for axis in range(3):
        acc[..., axis].scatter_add_(1, corner_vidx, corner_normals[..., axis])

    lengths = th.norm(acc, dim=-1, keepdim=True)
    lengths[lengths < eps] = 1
    acc /= lengths
    return acc
def compute_view_cos(verts, faces, camera_pos):
    """Cosine between each vertex normal and the camera-to-vertex direction.

    Args:
        verts: [B, V, 3] vertex positions.
        faces: triangle indices, passed through to ``vert_normals``.
        camera_pos: [B, 3] camera positions.

    Returns:
        [B, V] cosines. The direction used is ``verts - camera_pos``, i.e.
        from the camera towards the vertex.
    """
    normals = F.normalize(vert_normals(verts, faces), dim=-1)
    cam_to_vert = F.normalize(verts - camera_pos[:, np.newaxis], dim=-1)
    return th.einsum("bnd,bnd->bn", normals, cam_to_vert)
def compute_tbn(geom, vt, vi, vti):
    """Computes per-face tangent, bitangent, and normal vectors given a mesh.

    Args:
        geom: [N, n_verts, 3] th.Tensor
            Vertex positions.
        vt: [n_uv_coords, 2] th.Tensor
            UV coordinates.
        vi: [..., 3] th.Tensor
            Face vertex indices, e.g. [F, 3] faces or an [H, W, 3] index image.
        vti: [..., 3] th.Tensor
            Face UV indices, same shape as ``vi``.

    Returns:
        (T, B, N): unit [N, ..., 3] th.Tensors, where ``...`` is
        ``vi.shape[:-1]``.
    """
    v0 = geom[:, vi[..., 0]]
    v1 = geom[:, vi[..., 1]]
    v2 = geom[:, vi[..., 2]]
    vt0 = vt[vti[..., 0]]
    vt1 = vt[vti[..., 1]]
    vt2 = vt[vti[..., 2]]

    v01 = v1 - v0
    v02 = v2 - v0
    vt01 = vt1 - vt0
    vt02 = vt2 - vt0
    # Inverse determinant of the 2x2 UV edge matrix (inf for degenerate UVs).
    f = 1.0 / (
        vt01[None, ..., 0] * vt02[None, ..., 1]
        - vt01[None, ..., 1] * vt02[None, ..., 0]
    )
    tangent = f[..., None] * th.stack(
        [
            v01[..., 0] * vt02[None, ..., 1] - v02[..., 0] * vt01[None, ..., 1],
            v01[..., 1] * vt02[None, ..., 1] - v02[..., 1] * vt01[None, ..., 1],
            v01[..., 2] * vt02[None, ..., 1] - v02[..., 2] * vt01[None, ..., 1],
        ],
        dim=-1,
    )
    tangent = F.normalize(tangent, dim=-1)
    # Cross over the last (xyz) axis, so any number of leading index dims works.
    normal = F.normalize(th.cross(v01, v02, dim=-1), dim=-1)
    bitangent = F.normalize(th.cross(tangent, normal, dim=-1), dim=-1)

    return tangent, bitangent, normal
def compute_v2uv(n_verts, vi, vti, n_max=4):
    """Computes mapping from vertex indices to texture indices.

    Args:
        n_verts: int, number of mesh vertices; every vertex must be
            referenced by at least one triangle.
        vi: [F, 3], triangles
        vti: [F, 3], texture triangles
        n_max: int, max number of texture locations per vertex

    Returns:
        [n_verts, n_max] int32 array. Row i holds the sorted distinct texture
        indices used by vertex i, padded on the right with the smallest one.
    """
    uv_sets = {}
    for v_idx, uv_idx in zip(vi.reshape(-1), vti.reshape(-1)):
        uv_sets.setdefault(v_idx, set()).add(uv_idx)
    assert len(uv_sets) == n_verts

    table = np.zeros((n_verts, n_max), dtype=np.int32)
    for row in range(n_verts):
        uvs = sorted(uv_sets[row])
        # Pad with the smallest index, then overwrite the leading slots.
        table[row, :] = uvs[0]
        table[row, : len(uvs)] = np.array(uvs)
    return table
def compute_neighbours(n_verts, vi, n_max_values=10):
    """Computes first-ring neighbours given vertices and faces.

    Args:
        n_verts: number of vertices.
        vi: [F, 3] triangle vertex indices (numpy array, CPU tensor or nested list).
        n_max_values: number of neighbour slots per vertex. Vertices with
            more neighbours keep only the first ``n_max_values`` (in set
            iteration order).

    Returns:
        (nbs_idxs, nbs_weights):
            nbs_idxs: [n_verts, n_max_values] neighbour indices; unused slots
                hold the vertex's own index.
            nbs_weights: [n_verts, n_max_values] float32, ``-1 / n`` for each
                of the ``n`` used slots and 0 elsewhere. Vertices without
                neighbours (not referenced by any face) get all-zero weights.
    """
    faces = np.asarray(vi)
    adj = {i: set() for i in range(n_verts)}
    for face in faces:
        corners = face.tolist()
        corner_set = set(corners)
        for idx in corners:
            adj[idx] |= corner_set - {idx}

    nbs_idxs = np.tile(np.arange(n_verts)[:, np.newaxis], (1, n_max_values))
    nbs_weights = np.zeros((n_verts, n_max_values), dtype=np.float32)

    for idx in range(n_verts):
        n_values = min(len(adj[idx]), n_max_values)
        if n_values == 0:
            # Isolated vertex: keep self-indices and zero weights (avoids 1/0).
            continue
        nbs_idxs[idx, :n_values] = np.array(list(adj[idx]))[:n_values]
        nbs_weights[idx, :n_values] = -1.0 / n_values

    return nbs_idxs, nbs_weights
def make_postex(v, idxim, barim):
    """Build a position texture by interpolating vertex positions per texel.

    Args:
        v: [B, V, 3] vertex positions.
        idxim: [H, W, 3] vertex indices per texel.
        barim: [H, W, 3] barycentric weights per texel.

    Returns:
        [B, 3, H, W] interpolated positions (channels first).
    """
    corner_pos = [v[:, idxim[:, :, k]] for k in range(3)]  # each [B, H, W, 3]
    weights = [barim[None, :, :, k, None] for k in range(3)]  # each [1, H, W, 1]
    positions = (
        weights[0] * corner_pos[0]
        + weights[1] * corner_pos[1]
        + weights[2] * corner_pos[2]
    )
    return positions.permute(0, 3, 1, 2)
def matrix_to_axisangle(r, eps: float = 1e-8):
    """Convert rotation matrices to (angle, axis) form.

    Args:
        r: [..., 3, 3] rotation matrices.
        eps: lower bound applied to sin(angle) before dividing. Near the
            identity, the skew part is ~0, so this yields a ~zero axis
            instead of NaN.

    Returns:
        theta: [..., 1] rotation angle in radians, in [0, pi].
        vec: [..., 3] rotation axis recovered from the skew-symmetric part.
            It is unreliable as theta approaches pi, where that part vanishes.
    """
    # Named `theta`, not `th`: assigning `th` here would shadow the torch
    # module for the whole function body.
    cos_theta = 0.5 * (r[..., 0, 0] + r[..., 1, 1] + r[..., 2, 2] - 1.0)
    # Rounding can push the trace-derived cosine just outside [-1, 1].
    theta = th.arccos(cos_theta.clamp(-1.0, 1.0))[..., None]
    skew = th.stack(
        [
            r[..., 2, 1] - r[..., 1, 2],
            r[..., 0, 2] - r[..., 2, 0],
            r[..., 1, 0] - r[..., 0, 1],
        ],
        dim=-1,
    )
    # sin(theta) >= 0 on [0, pi]; keep it away from zero.
    vec = 0.5 * skew / th.sin(theta).clamp(min=eps)
    return theta, vec
def axisangle_to_matrix(rvec):
    """Convert rotation vectors (axis * angle) to rotation matrices (Rodrigues).

    Args:
        rvec: [..., 3] rotation vectors.

    Returns:
        [..., 3, 3] rotation matrices. A small 1e-5 term inside the square
        root keeps the zero vector finite (it maps to roughly the identity).
    """
    angle = th.sqrt(1e-5 + th.sum(rvec**2, dim=-1))
    axis = rvec / angle[..., None]
    c = th.cos(angle)
    s = th.sin(angle)
    omc = 1.0 - c
    x, y, z = axis[..., 0], axis[..., 1], axis[..., 2]

    row0 = th.stack(
        (x**2 + (1.0 - x**2) * c, x * y * omc - z * s, x * z * omc + y * s),
        dim=-1,
    )
    row1 = th.stack(
        (x * y * omc + z * s, y**2 + (1.0 - y**2) * c, y * z * omc - x * s),
        dim=-1,
    )
    row2 = th.stack(
        (x * z * omc - y * s, y * z * omc + x * s, z**2 + (1.0 - z**2) * c),
        dim=-1,
    )
    return th.stack((row0, row1, row2), dim=-2)
def rotation_interp(r0, r1, alpha):
    """Interpolate between rotations along the geodesic from ``r0`` to ``r1``.

    Computes ``r0 @ exp(alpha * log(r0^T r1))``. ``alpha = 0`` gives ``r0``
    and ``alpha = 1`` gives (approximately) ``r1``.

    Args:
        r0, r1: [..., 3, 3] rotation matrices of the same shape.
        alpha: interpolation factor; a scalar or a tensor broadcastable to [..., 1].

    Returns:
        [..., 3, 3] interpolated rotation matrices.
    """
    r0a = r0.reshape(-1, 3, 3)
    r1a = r1.reshape(-1, 3, 3)
    # Relative rotation taking r0 to r1.
    r_rel = th.bmm(r0a.permute(0, 2, 1), r1a).view_as(r0)
    # Named `theta`, not `th`: rebinding `th` would shadow the torch module.
    theta, axis = matrix_to_axisangle(r_rel)
    r_step = axisangle_to_matrix(axis * (alpha * theta))
    return th.bmm(r0a, r_step.reshape(-1, 3, 3)).view_as(r0)
def convert_camera_parameters(Rt, K):
    """Split batched extrinsics/intrinsics into named camera parameters.

    Args:
        Rt: [B, 3, 4] world-to-camera extrinsics [R | t].
        K: [B, 3, 3] intrinsics.

    Returns:
        dict with ``campos`` ([B, 3], camera center ``-R^T t``), ``camrot``
        ([B, 3, 3] = R), ``focal`` ([B, 2, 2] upper-left block of K) and
        ``princpt`` ([B, 2] principal point from K's last column).
    """
    rot = Rt[:, :3, :3]
    trans = Rt[:, :3, 3]
    cam_center = -th.bmm(rot.permute(0, 2, 1), trans.unsqueeze(2))[..., 0]
    return {
        "campos": cam_center,
        "camrot": rot,
        "focal": K[:, :2, :2],
        "princpt": K[:, :2, 2],
    }
def project_points_multi(p, Rt, K, normalize=False, size=None):
    """Project a set of 3D points into multiple cameras with a pinhole model.

    Args:
        p: [B, N, 3], input 3D points in world coordinates
        Rt: [B, NC, 3, 4], extrinsics (where NC is the number of cameras to project to)
        K: [B, NC, 3, 3], intrinsics
        normalize: bool, whether to normalize pixel coordinates to [-1.0, 1.0]
        size: (h, w) image size; required when ``normalize`` is True

    Returns:
        tuple:
            - [B, NC, N, 2] - projected pixel coordinates (x, y)
            - [B, NC, N] - their depths in each camera
    """
    B, N = p.shape[:2]
    NC = Rt.shape[1]
    extr = Rt.reshape(B * NC, 3, 4)
    intr = K.reshape(B * NC, 3, 3)

    # Replicate the point set once per camera: [B, N, 3] -> [B * NC, N, 3].
    pts = p.unsqueeze(1).expand(-1, NC, -1, -1).reshape(B * NC, -1, 3)

    cam = pts @ extr[:, :3, :3].transpose(-2, -1) + extr[:, :3, 3].unsqueeze(1)
    hom = cam @ intr.transpose(-2, -1)
    depth = hom[:, :, 2:]
    pix = (hom[..., :2] / depth).reshape(B, NC, N, 2)
    depth = depth.reshape(B, NC, N)

    if normalize:
        assert size is not None
        h, w = size
        extent = th.as_tensor([w, h], dtype=th.float32, device=pts.device)
        pix = 2.0 * pix / extent - 1.0

    return pix, depth