# Source: FrozenBurning, commit 81ecb2b ("single view to 3D init release"), 11.4 kB.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, InterpolationMode
import open_clip
from dva.io import load_from_config
def sample_orbit_traj(radius, height, start_theta, end_theta, num_points, world_up=torch.Tensor([0, 1, 0])):
    """Sample random camera poses on a horizontal orbit.

    Azimuths are drawn uniformly from ``[start_theta, end_theta)`` with torch's
    global RNG (``torch.rand``) and turned into extrinsics by
    ``get_pose_on_orbit``.

    Args:
        radius: orbit radius.
        height: y coordinate of the orbit.
        start_theta: lower bound of the sampled azimuth range (radians).
        end_theta: upper bound of the sampled azimuth range (radians).
        num_points: number of poses to sample.
        world_up: world up vector forwarded to ``get_pose_on_orbit``.

    Returns:
        Tensor of shape ``[num_points, 3, 4]``.
    """
    span = end_theta - start_theta
    thetas = start_theta + span * torch.rand((num_points, ))
    return get_pose_on_orbit(radius=radius, height=height, angles=thetas, world_up=world_up)
def get_pose_on_orbit(radius, height, angles, world_up=None):
    """Build ``[R | t]`` extrinsics for cameras on a horizontal orbit.

    For each azimuth ``a`` the point ``(radius*cos a, height, radius*sin a)``
    is normalised into ``forward``. The rotation rows are ``[right, up, forward]``:
    ``right = forward x world_up`` and ``up = forward x right``, both normalised.
    Every pose gets the translation ``(0, 0, radius)``.

    Args:
        radius: orbit radius; also the z component of the translation.
        height: y coordinate of the orbit points.
        angles: 1-D tensor of azimuths in radians, shape ``[num_points]``.
        world_up: world up vector of shape ``[3]``. ``None`` means ``(0, 1, 0)``.

    Returns:
        Tensor of shape ``[num_points, 3, 4]`` with the device/dtype of ``angles``.
    """
    if world_up is None:
        world_up = torch.tensor([0.0, 1.0, 0.0])
    num_points = angles.shape[0]
    x = radius * torch.cos(angles)
    h = torch.ones_like(x) * height
    z = radius * torch.sin(angles)
    position = torch.stack([x, h, z], dim=-1)
    forward = position / torch.norm(position, p=2, dim=-1, keepdim=True)
    # Match forward's device/dtype so non-float32 or GPU angles work.
    up_vec = world_up.to(forward)[None, :].expand_as(forward)
    # Always pass dim=-1. Without it, torch.cross uses the FIRST size-3
    # dimension, which is the batch dimension when num_points == 3.
    right = -torch.cross(up_vec, forward, dim=-1)
    right = right / torch.norm(right, p=2, dim=-1, keepdim=True)
    up = torch.cross(forward, right, dim=-1)
    up = up / torch.norm(up, p=2, dim=-1, keepdim=True)
    rotation = torch.stack([right, up, forward], dim=1)
    translation = forward.new_tensor([0.0, 0.0, radius])[None, :, None].repeat(num_points, 1, 1)
    return torch.cat([rotation, translation], dim=2)
class DummyImageConditioner(nn.Module):
    """Conditioner that skips rendering and returns precomputed conditions.

    It accepts the same constructor arguments as the rendering conditioners in
    this file but never builds an encoder. ``encoder_config`` is ignored and
    ``self.encoder`` is ``None``. ``forward`` returns ``batch['cond']`` unchanged.
    """

    def __init__(
        self,
        num_prims,
        dim_feat,
        prim_shape,
        encoder_config,
        sample_view=False,
        sample_start=torch.pi*0.25,
        sample_end=torch.pi*0.75,
    ):
        super().__init__()
        # Kept only as attributes; nothing in this class reads them.
        self.num_prims, self.dim_feat, self.prim_shape = num_prims, dim_feat, prim_shape
        self.sample_view = sample_view
        self.sample_start, self.sample_end = sample_start, sample_end
        self.encoder = None

    @torch.no_grad()
    def forward(self, batch, rm, amp, precision_dtype=torch.float32):
        """Return ``batch['cond']``; ``rm``, ``amp`` and ``precision_dtype`` are unused."""
        precomputed = batch['cond']
        return precomputed
class ImageConditioner(nn.Module):
    """Render packed primitives once with a ray marcher and encode the render.

    ``forward`` proceeds in four steps:

    1. Unpack ``batch['input_param']`` into per-primitive scale, position,
       SDF and RGB voxels.
    2. Render them with ``rm``, from a fixed camera or, when ``sample_view``
       is true, a random orbit camera.
    3. Take the RGB channels of the render.
    4. Pass them to ``self.encoder``.

    Args:
        num_prims: number of primitives per sample.
        dim_feat: stored; not used by this class.
        prim_shape: voxel resolution S of each primitive (S^3 voxels).
        encoder_config: config given to ``load_from_config`` to build the encoder.
        sample_view: render from a randomly sampled orbit camera instead of the
            fixed camera.
        sample_start: lower bound (radians) of the sampled azimuth range.
        sample_end: upper bound (radians) of the sampled azimuth range.
    """

    def __init__(
        self,
        num_prims,
        dim_feat,
        prim_shape,
        encoder_config,
        sample_view=False,
        sample_start=torch.pi*0.25,
        sample_end=torch.pi*0.75,
    ):
        super().__init__()
        self.num_prims = num_prims
        self.dim_feat = dim_feat
        self.prim_shape = prim_shape
        self.sample_view = sample_view
        self.sample_start = sample_start
        self.sample_end = sample_end
        self.encoder = load_from_config(encoder_config)

    def sdf2alpha(self, sdf):
        """Map SDF values to opacity ``exp(-(sdf / 0.005)^2)`` (1 on the zero level set)."""
        return torch.exp(-(sdf / 0.005) ** 2)

    def _unpack_prims(self, prim_volume, volradius):
        """Split packed primitive parameters into ray-marcher inputs.

        The last axis of ``prim_volume`` is read as
        ``[scale (1), position (3), SDF (S^3), RGB (3*S^3, channel-major)]``
        with ``S = prim_shape``. Positions are multiplied by ``volradius``,
        opacity and colour by 255, and the scale is inverted.
        """
        bs = prim_volume.shape[0]
        s = self.prim_shape
        n_vox = s ** 3
        geo_start = 4
        geo_end = geo_start + n_vox  # non-inclusive
        tex_end = geo_end + 3 * n_vox  # non-inclusive
        prim_alpha = self.sdf2alpha(prim_volume[:, :, geo_start:geo_end]).reshape(bs, self.num_prims, 1, s, s, s) * 255
        prim_rgb = prim_volume[:, :, geo_end:tex_end].reshape(bs, self.num_prims, 3, s, s, s) * 255
        prim_pos = prim_volume[:, :, 1:4].reshape(bs, self.num_prims, 3) * volradius
        scale = prim_volume[:, :, 0:1]
        return {
            'prim_rgba': torch.cat([prim_rgb, prim_alpha], dim=2),
            'prim_pos': prim_pos,
            # Identity rotation for every primitive.
            'prim_rot': torch.eye(3).to(prim_pos)[None, None, ...].repeat(bs, self.num_prims, 1, 1),
            'prim_scale': 1 / scale.reshape(bs, self.num_prims, 1).repeat(1, 1, 3),
        }

    def _camera_extrinsics(self, rm, bs, like):
        """Return ``[bs, 3, 4]`` extrinsics on ``like``'s device/dtype."""
        if self.sample_view:
            return sample_orbit_traj(
                radius=5 * rm.volradius,
                height=0,
                start_theta=self.sample_start,
                end_theta=self.sample_end,
                num_points=bs,
            ).to(like)
        # Fixed camera: rotation diag(1, -1, -1), translation (0, 0, 5 * volradius).
        fixed_rt = torch.Tensor([
            [1.0, 0.0, 0.0, 0.0 * rm.volradius],
            [0.0, -1.0, 0.0, 0.0 * rm.volradius],
            [0.0, 0.0, -1.0, 5 * rm.volradius],
        ]).to(like)
        return fixed_rt[None, ...].repeat(bs, 1, 1)

    def _intrinsics(self, rm, bs, like):
        """Return ``[bs, 3, 3]`` intrinsics rescaled from 1024x1024 to the render size."""
        K = torch.Tensor([
            [2084.9526697685183, 0.0, 512.0],
            [0.0, 2084.9526697685183, 512.0],
            [0.0, 0.0, 1.0],
        ]).to(like)[None, ...].repeat(bs, 1, 1)
        # Row 0 (fx, cx) is scaled by the width ratio and row 1 (fy, cy) by the
        # height ratio. This follows the standard pinhole convention, where u
        # runs along the image width. NOTE(review): confirm against the ray
        # marcher. The previous code swapped the two ratios, which only
        # mattered for non-square renders.
        K[:, 0:1, :] *= rm.image_width / 1024.
        K[:, 1:2, :] *= rm.image_height / 1024.
        return K

    @torch.no_grad()
    def forward(self, batch, rm, amp, precision_dtype=torch.float32):
        """Render ``batch['input_param']`` with ``rm`` and encode the RGB render.

        Args:
            batch: dict that must contain ``'input_param'`` (packed primitives,
                ``[B, num_prims, ...]``).
            rm: ray marcher. It must be callable with ``prim_rgba``, ``prim_pos``,
                ``prim_scale``, ``prim_rot``, ``RT`` and ``K``, and must expose
                ``volradius``, ``image_height`` and ``image_width``.
            amp: enable CUDA autocast around the encoder.
            precision_dtype: autocast dtype.

        Returns:
            The output of ``self.encoder`` on the ``[B, H, W, 3]`` render.
        """
        # TODO: replace with real rendering process in primsdf
        assert 'input_param' in batch, "No parameters in current batch for rendering image conditions"
        prim_volume = batch['input_param']
        bs = prim_volume.shape[0]
        preds = self._unpack_prims(prim_volume, rm.volradius)
        preds['Rt'] = self._camera_extrinsics(rm, bs, prim_volume)
        preds['K'] = self._intrinsics(rm, bs, prim_volume)
        rm_preds = rm(
            prim_rgba=preds["prim_rgba"],
            prim_pos=preds["prim_pos"],
            prim_scale=preds["prim_scale"],
            prim_rot=preds["prim_rot"],
            RT=preds["Rt"],
            K=preds["K"],
        )
        # rgba_image is presumably [B, C, H, W]; move channels last and keep RGB.
        rendered_image = rm_preds['rgba_image'].permute(0, 2, 3, 1)[..., :3].contiguous()
        with torch.autocast(device_type='cuda', dtype=precision_dtype, enabled=amp):
            results = self.encoder(rendered_image)
        return results
class ImageMultiViewConditioner(nn.Module):
    """Render packed primitives from several fixed orbit views and encode each.

    Uses ``view_counts`` azimuths, evenly spaced over a full turn starting at
    pi/2. Each view is rendered with ``rm`` and encoded by ``self.encoder``,
    and the per-view results are concatenated along dim 1.

    Args:
        num_prims: number of primitives per sample.
        dim_feat: stored; not used by this class.
        prim_shape: voxel resolution S of each primitive (S^3 voxels).
        encoder_config: config given to ``load_from_config`` to build the encoder.
        sample_view: accepted but unused.
        view_counts: number of views rendered per sample.
    """

    def __init__(
        self,
        num_prims,
        dim_feat,
        prim_shape,
        encoder_config,
        sample_view=False,
        view_counts=4,
    ):
        super().__init__()
        self.num_prims = num_prims
        self.dim_feat = dim_feat
        self.prim_shape = prim_shape
        self.view_counts = view_counts
        view_angles = torch.linspace(0.5, 2.5, self.view_counts + 1) * torch.pi
        # Drop the last angle: 2.5*pi coincides with 0.5*pi modulo 2*pi.
        self.view_angles = view_angles[:-1]
        self.encoder = load_from_config(encoder_config)

    def sdf2alpha(self, sdf):
        """Map SDF values to opacity ``exp(-(sdf / 0.005)^2)`` (1 on the zero level set)."""
        return torch.exp(-(sdf / 0.005) ** 2)

    def _unpack_prims(self, prim_volume, volradius):
        """Split packed primitive parameters into ray-marcher inputs.

        The last axis of ``prim_volume`` is read as
        ``[scale (1), position (3), SDF (S^3), RGB (3*S^3, channel-major)]``
        with ``S = prim_shape``. Positions are multiplied by ``volradius``,
        opacity and colour by 255, and the scale is inverted.
        """
        bs = prim_volume.shape[0]
        s = self.prim_shape
        n_vox = s ** 3
        geo_start = 4
        geo_end = geo_start + n_vox  # non-inclusive
        tex_end = geo_end + 3 * n_vox  # non-inclusive
        prim_alpha = self.sdf2alpha(prim_volume[:, :, geo_start:geo_end]).reshape(bs, self.num_prims, 1, s, s, s) * 255
        prim_rgb = prim_volume[:, :, geo_end:tex_end].reshape(bs, self.num_prims, 3, s, s, s) * 255
        prim_pos = prim_volume[:, :, 1:4].reshape(bs, self.num_prims, 3) * volradius
        scale = prim_volume[:, :, 0:1]
        return {
            'prim_rgba': torch.cat([prim_rgb, prim_alpha], dim=2),
            'prim_pos': prim_pos,
            # Identity rotation for every primitive.
            'prim_rot': torch.eye(3).to(prim_pos)[None, None, ...].repeat(bs, self.num_prims, 1, 1),
            'prim_scale': 1 / scale.reshape(bs, self.num_prims, 1).repeat(1, 1, 3),
        }

    def _intrinsics(self, rm, bs, like):
        """Return ``[bs, 3, 3]`` intrinsics rescaled from 1024x1024 to the render size."""
        K = torch.Tensor([
            [2084.9526697685183, 0.0, 512.0],
            [0.0, 2084.9526697685183, 512.0],
            [0.0, 0.0, 1.0],
        ]).to(like)[None, ...].repeat(bs, 1, 1)
        # Row 0 (fx, cx) is scaled by the width ratio and row 1 (fy, cy) by the
        # height ratio. This follows the standard pinhole convention, where u
        # runs along the image width. NOTE(review): confirm against the ray
        # marcher. The previous code swapped the two ratios, which only
        # mattered for non-square renders.
        K[:, 0:1, :] *= rm.image_width / 1024.
        K[:, 1:2, :] *= rm.image_height / 1024.
        return K

    @torch.no_grad()
    def forward(self, batch, rm, amp, precision_dtype=torch.float32):
        """Render ``batch['input_param']`` from every view and encode each render.

        Args:
            batch: dict that must contain ``'input_param'`` (packed primitives,
                ``[B, num_prims, ...]``).
            rm: ray marcher. It must be callable with ``prim_rgba``, ``prim_pos``,
                ``prim_scale``, ``prim_rot``, ``RT`` and ``K``, and must expose
                ``volradius``, ``image_height`` and ``image_width``.
            amp: enable CUDA autocast around the encoder.
            precision_dtype: autocast dtype.

        Returns:
            The per-view encoder outputs concatenated along dim 1.
        """
        # TODO: replace with real rendering process in primsdf
        assert 'input_param' in batch, "No parameters in current batch for rendering image conditions"
        prim_volume = batch['input_param']
        bs = prim_volume.shape[0]
        preds = self._unpack_prims(prim_volume, rm.volradius)
        preds['K'] = self._intrinsics(rm, bs, prim_volume)
        # we sample view according to view_counts
        cond_list = []
        for view_ang in self.view_angles:
            bs_view_ang = view_ang.repeat(bs)
            preds['Rt'] = get_pose_on_orbit(radius=5 * rm.volradius, height=0, angles=bs_view_ang).to(prim_volume)
            rm_preds = rm(
                prim_rgba=preds["prim_rgba"],
                prim_pos=preds["prim_pos"],
                prim_scale=preds["prim_scale"],
                prim_rot=preds["prim_rot"],
                RT=preds["Rt"],
                K=preds["K"],
            )
            # rgba_image is presumably [B, C, H, W]; move channels last and keep RGB.
            rendered_image = rm_preds['rgba_image'].permute(0, 2, 3, 1)[..., :3].contiguous()
            with torch.autocast(device_type='cuda', dtype=precision_dtype, enabled=amp):
                results = self.encoder(rendered_image)
            cond_list.append(results)
        return torch.cat(cond_list, dim=1)
class CLIPImageEncoder(nn.Module):
    """Encode channels-last RGB images into L2-normalised OpenCLIP embeddings.

    Args:
        pretrained_path: forwarded as ``pretrained`` to
            ``open_clip.create_model_and_transforms``.
        model_spec: OpenCLIP model name.
    """

    _CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    _CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(
        self,
        pretrained_path: str,
        model_spec: str = 'ViT-L-14',
    ):
        super().__init__()
        created = open_clip.create_model_and_transforms(model_spec, pretrained=pretrained_path)
        self.model = created[0]
        self.model_resolution = self.model.visual.image_size
        res = self.model_resolution
        # Built here instead of using open_clip's returned transforms because
        # the inputs are tensors that are already scaled to [0, 1] in forward().
        self.preprocess = Compose([
            Resize(res, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(res),
            Normalize(self._CLIP_MEAN, self._CLIP_STD),
        ])
        self.model.eval()

    @torch.no_grad()
    def forward(self, img):
        """Embed a batch of images.

        Args:
            img: ``[B, H, W, 3]`` tensor. It is divided by 255, so values are
                presumably in 0-255.

        Returns:
            The output of ``encode_image``, divided by its L2 norm over the last dim.
        """
        assert img.shape[-1] == 3
        as_chw = img.permute(0, 3, 1, 2) / 255.
        feats = self.model.encode_image(self.preprocess(as_chw))
        feats /= feats.norm(dim=-1, keepdim=True)
        return feats
class CLIPImageTokenEncoder(nn.Module):
    """Encode channels-last RGB images into L2-normalised OpenCLIP token features.

    It sets ``visual.output_tokens = True`` so that ``encode_image`` returns a
    pair, and keeps the second element (the per-token features).

    Args:
        pretrained_path: forwarded as ``pretrained`` to
            ``open_clip.create_model_and_transforms``.
        model_spec: OpenCLIP model name.
    """

    _CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    _CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(
        self,
        pretrained_path: str,
        model_spec: str = 'ViT-L-14',
    ):
        super().__init__()
        created = open_clip.create_model_and_transforms(model_spec, pretrained=pretrained_path)
        self.model = created[0]
        self.model.visual.output_tokens = True
        self.model_resolution = self.model.visual.image_size
        res = self.model_resolution
        self.preprocess = Compose([
            Resize(res, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(res),
            Normalize(self._CLIP_MEAN, self._CLIP_STD),
        ])
        self.model.eval()

    @torch.no_grad()
    def forward(self, img):
        """Return token features for a batch of images.

        Args:
            img: ``[B, H, W, 3]`` tensor. It is divided by 255, so values are
                presumably in 0-255.

        Returns:
            The token tensor from ``encode_image``, divided by its L2 norm over
            the last dim (presumably ``[B, 256, 1024]`` for ViT-L-14; confirm).
        """
        assert img.shape[-1] == 3
        as_chw = img.permute(0, 3, 1, 2) / 255.
        encoded = self.model.encode_image(self.preprocess(as_chw))
        tokens = encoded[1]
        tokens /= tokens.norm(dim=-1, keepdim=True)
        return tokens