from ..options import Options
from . import models_utils, transformer
from .. import constants
from ..custom_types import *
from torch import distributions
import math


def dot(x, y, dim=3):
    return torch.sum(x * y, dim=dim)


def remove_projection(v_1, v_2):
    # subtract the projection of v_1 onto v_2
    proj = dot(v_1, v_2) / dot(v_2, v_2)
    return v_1 - proj[:, :, :, None] * v_2


def get_p_direct(splitted: TS) -> T:
    # Gram-Schmidt: orthogonalize the DIM raw direction vectors, then normalize,
    # yielding one orthonormal frame (rotation-like matrix) per Gaussian.
    raw_base = []
    for i in range(constants.DIM):
        u = splitted[i]
        for j in range(i):
            u = remove_projection(u, raw_base[j])
        raw_base.append(u)
    p = torch.stack(raw_base, dim=3)
    p = p / torch.norm(p, p=2, dim=4)[:, :, :, :, None]
    return p


def split_gm(splitted: TS) -> TS:
    p = get_p_direct(splitted)
    # eigenvalues (squared to keep them positive)
    eigen = splitted[-3] ** 2 + constants.EPSILON
    mu = splitted[-2]
    phi = splitted[-1].squeeze(3)
    return mu, p, phi, eigen


class DecompositionNetwork(nn.Module):

    def forward_bottom(self, x):
        return self.l1(x).view(-1, self.bottom_width, self.embed_dim)

    def forward_upper(self, x):
        return self.to_zb(x)

    def forward(self, x):
        x = self.forward_bottom(x)
        x = self.forward_upper(x)
        return x

    def __init__(self, opt: Options, act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm):
        super(DecompositionNetwork, self).__init__()
        self.bottom_width = opt.num_gaussians
        self.embed_dim = opt.dim_h
        self.l1 = nn.Linear(opt.dim_z, self.bottom_width * opt.dim_h)
        if opt.decomposition_network == 'mlp':
            self.to_zb = models_utils.MLP((opt.dim_h, *([2 * opt.dim_h] * opt.decomposition_num_layers), opt.dim_h))
        else:
            self.to_zb = transformer.Transformer(opt.dim_h, opt.num_heads, opt.num_layers,
                                                 act=act, norm_layer=norm_layer)


class OccupancyMlP(nn.Module):
    # based on DeepSDF https://github.com/facebookresearch/DeepSDF

    def forward(self, x, z):
        x_ = x = torch.cat((x, z), dim=-1)
        for i, layer in enumerate(self.layers):
            if i == self.latent_in:
                # skip connection: re-inject the input at the widened layer
                x = torch.cat([x, x_], 2)
            x = layer(x)
            if i < len(self.layers) - 2:
                x = self.relu(x)
                # x = self.dropout(self.relu(x))
        return x

    def __init__(self, opt: Options):
        super(OccupancyMlP, self).__init__()
        dim_in = 2 * (opt.pos_dim + constants.DIM)
        dims = [dim_in] + opt.head_occ_size * [dim_in] + [1]
        self.latent_in = opt.head_occ_size // 2 + opt.head_occ_size % 2
        dims[self.latent_in] += dims[0]
        self.dropout = nn.Dropout(.2)
        self.relu = nn.ReLU(True)
        layers = []
        for i in range(0, len(dims) - 1):
            layers.append(nn.utils.weight_norm(nn.Linear(dims[i], dims[i + 1])))
        self.layers = nn.ModuleList(layers)


class OccupancyNetwork(nn.Module):

    def get_pos(self, coords: T):
        pos = self.pos_encoder(coords)
        pos = torch.cat((coords, pos), dim=2)
        return pos

    def forward_attention(self, coords: T, zh: T, mask: Optional[T] = None, alpha: TN = None) -> TS:
        pos = self.get_pos(coords)
        _, attn = self.occ_transformer.forward_with_attention(pos, zh, mask, alpha)
        return attn

    def forward(self, coords: T, zh: T, mask: TN = None, alpha: TN = None) -> T:
        pos = self.get_pos(coords)
        x = self.occ_transformer(pos, zh, mask, alpha)
        out = self.occ_mlp(pos, x)
        if out.shape[-1] == 1:
            out = out.squeeze(-1)
        return out

    def __init__(self, opt: Options):
        super(OccupancyNetwork, self).__init__()
        self.pos_encoder = models_utils.SineLayer(constants.DIM, opt.pos_dim, is_first=True)
        if hasattr(opt, 'head_occ_type') and opt.head_occ_type == 'skip':
            self.occ_mlp = OccupancyMlP(opt)
        else:
            self.occ_mlp = models_utils.MLP([opt.pos_dim + constants.DIM] + [opt.dim_h] * opt.head_occ_size + [1])
        self.occ_transformer = transformer.Transformer(opt.pos_dim + constants.DIM, opt.num_heads_head,
                                                       opt.num_layers_head, dim_ref=opt.dim_h)
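# Illustrative sanity check (kept as a comment so nothing runs on import): `get_p_direct`
# implements Gram-Schmidt, so each per-Gaussian frame p should satisfy p @ p^T = I.
# The shapes follow the (batch, parts, gaussians, DIM) layout assumed by `split_gm`;
# the snippet is a sketch, not part of the model.
#
#     vecs = torch.randn(2, 1, 16, constants.DIM, constants.DIM)
#     frames = get_p_direct(torch.unbind(vecs, dim=3))           # (2, 1, 16, DIM, DIM)
#     eye = torch.einsum('bpgij,bpgkj->bpgik', frames, frames)   # p @ p^T per Gaussian
#     assert torch.allclose(eye, torch.eye(constants.DIM).expand_as(eye), atol=1e-4)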
class DecompositionControl(models_utils.Model):

    def forward_bottom(self, x):
        z_bottom = self.decomposition.forward_bottom(x)
        return z_bottom

    def forward_upper(self, x):
        x = self.decomposition.forward_upper(x)
        return x

    def forward_split(self, x: T) -> Tuple[T, TS]:
        b = x.shape[0]
        raw_gmm = self.to_gmm(x).unsqueeze(1)
        gmms = split_gm(torch.split(raw_gmm, self.split_shape, dim=3))
        zh = self.to_s(x)
        zh = zh.view(b, -1, zh.shape[-1])
        return zh, gmms

    @staticmethod
    def apply_gmm_affine(gmms: TS, affine: T):
        mu, p, phi, eigen = gmms
        if affine.dim() == 2:
            affine = affine.unsqueeze(0).expand(mu.shape[0], *affine.shape)
        mu_r = torch.einsum('bad, bpnd->bpna', affine, mu)
        p_r = torch.einsum('bad, bpncd->bpnca', affine, p)
        return mu_r, p_r, phi, eigen

    @staticmethod
    def concat_gmm(gmm_a: TS, gmm_b: TS):
        out = []
        num_gaussians = gmm_a[0].shape[2] // 2
        for element_a, element_b in zip(gmm_a, gmm_b):
            out.append(torch.cat((element_a[:, :, :num_gaussians], element_b[:, :, :num_gaussians]), dim=2))
        return out

    def forward_mid(self, zs) -> Tuple[T, TS]:
        zh, gmms = self.forward_split(zs)
        if self.reflect is not None:
            # symmetry: mirror the first half of the Gaussians and keep both halves
            gmms_r = self.apply_gmm_affine(gmms, self.reflect)
            gmms = self.concat_gmm(gmms, gmms_r)
        return zh, gmms

    def forward_low(self, z_init):
        zs = self.decomposition(z_init)
        return zs

    def forward(self, z_init) -> Tuple[T, TS]:
        zs = self.forward_low(z_init)
        zh, gmms = self.forward_mid(zs)
        return zh, gmms

    @staticmethod
    def get_reflection(reflect_axes: Tuple[bool, ...]):
        reflect = torch.eye(constants.DIM)
        for i in range(constants.DIM):
            if reflect_axes[i]:
                reflect[i, i] = -1
        return reflect

    def __init__(self, opt: Options):
        super(DecompositionControl, self).__init__()
        if sum(opt.symmetric) > 0:
            reflect = self.get_reflection(opt.symmetric)
            self.register_buffer("reflect", reflect)
        else:
            self.reflect = None
        self.split_shape = tuple((constants.DIM + 2) * [constants.DIM] + [1])
        self.decomposition = DecompositionNetwork(opt)
        self.to_gmm = nn.Linear(opt.dim_h, sum(self.split_shape))
        self.to_s = nn.Linear(opt.dim_h, opt.dim_h)
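# Shape sketch (illustrative, assuming constants.DIM == 3): `split_shape` is (3, 3, 3, 3, 3, 1),
# i.e. 16 scalars per Gaussian -- three raw frame vectors, raw eigenvalues, the mean mu and the
# mixing weight phi. With a reflection axis enabled, `forward_mid` keeps the first half of the
# predicted Gaussians and appends their mirrored copies. Hypothetical usage, where `opt` is a
# populated Options instance:
#
#     control = DecompositionControl(opt)
#     zh, (mu, p, phi, eigen) = control(torch.randn(4, opt.dim_z))
#     # mu: (4, 1, opt.num_gaussians, 3), p: (4, 1, opt.num_gaussians, 3, 3)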
class Spaghetti(models_utils.Model):

    def get_z(self, item: T):
        return self.z(item)

    @staticmethod
    def interpolate_(z, num_between: Optional[int] = None):
        if num_between is None:
            num_between = z.shape[0]
        alphas = torch.linspace(0, 1, num_between, device=z.device)
        while alphas.dim() != z.dim():
            alphas.unsqueeze_(-1)
        z_between = alphas * z[1:2] + (1 - alphas) * z[:1]
        return z_between

    def interpolate_higher(self, z: T, num_between: Optional[int] = None):
        z_between = self.interpolate_(z, num_between)
        zh, gmms = self.decomposition_control.forward_split(self.decomposition_control.forward_upper(z_between))
        return zh, gmms

    def interpolate(self, item_a: int, item_b: int, num_between: int):
        items = torch.tensor((item_a, item_b), dtype=torch.int64, device=self.device)
        z = self.get_z(items)
        z_between = self.interpolate_(z, num_between)
        zh, gmms = self.decomposition_control(z_between)
        return zh, gmms

    def get_disentanglement(self, items: T):
        z_a = self.get_z(items)
        z_b = self.decomposition_control.forward_bottom(z_a)
        zh, gmms = self.decomposition_control.forward_split(self.decomposition_control.forward_upper(z_b))
        return z_a, z_b, zh, gmms

    def get_embeddings(self, item: T):
        z = self.get_z(item)
        zh, gmms = self.decomposition_control(z)
        return zh, z, gmms

    def merge_zh_step_a(self, zh, gmms):
        b, gp, g, _ = gmms[0].shape
        mu, p, phi, eigen = [item.view(b, gp * g, *item.shape[3:]) for item in gmms]
        p = p.reshape(*p.shape[:2], -1)
        z_gmm = torch.cat((mu, p, phi.unsqueeze(-1), eigen), dim=2).detach()
        z_gmm = self.from_gmm(z_gmm)
        zh_ = zh + z_gmm
        return zh_

    def merge_zh(self, zh, gmms, mask: Optional[T] = None) -> TNS:
        zh_ = self.merge_zh_step_a(zh, gmms)
        zh_, attn = self.mixing_network.forward_with_attention(zh_, mask=mask)
        return zh_, attn

    def forward_b(self, x, zh, gmms, mask: Optional[T] = None) -> T:
        zh, _ = self.merge_zh(zh, gmms, mask)
        return self.occupancy_network(x, zh, mask)

    def forward_a(self, item: T):
        zh, z, gmms = self.get_embeddings(item)
        return zh, z, gmms

    def get_attention(self, x, item) -> TS:
        zh, z, gmms = self.forward_a(item)
        zh, _ = self.merge_zh(zh, gmms)
        return self.occupancy_network.forward_attention(x, zh)

    def forward(self, x, item: T) -> Tuple[T, T, TS, T]:
        zh, z, gmms = self.forward_a(item)
        return self.forward_b(x, zh, gmms), z, gmms, zh

    def forward_mid(self, x: T, zh: T) -> Tuple[T, TS]:
        zh, gmms = self.decomposition_control.forward_mid(zh)
        return self.forward_b(x, zh, gmms), gmms

    def get_random_embeddings(self, num_items: int):
        if self.dist is None:
            # fit a Gaussian to the learned shape embeddings once, then sample from it
            weights = self.z.weight.clone().detach()
            mean = weights.mean(0)
            weights = weights - mean[None, :]
            cov = torch.einsum('nd,nc->dc', weights, weights) / (weights.shape[0] - 1)
            self.dist = distributions.multivariate_normal.MultivariateNormal(mean, covariance_matrix=cov)
        z_init = self.dist.sample((num_items,))
        return z_init

    def random_samples(self, num_items: int):
        z_init = self.get_random_embeddings(num_items)
        zh, gmms = self.decomposition_control(z_init)
        return zh, gmms

    def __init__(self, opt: Options):
        super(Spaghetti, self).__init__()
        self.device = opt.device
        self.opt = opt
        self.z = nn.Embedding(opt.dataset_size, opt.dim_z)
        torch.nn.init.normal_(
            self.z.weight.data,
            0.0,
            1. / math.sqrt(opt.dim_z),
        )
        self.decomposition_control = DecompositionControl(opt)
        self.occupancy_network = OccupancyNetwork(opt)
        self.from_gmm = nn.Linear(sum(self.decomposition_control.split_shape), opt.dim_h)
        if opt.use_encoder:
            self.mixing_network = transformer.Transformer(opt.dim_h, opt.num_heads, opt.num_layers,
                                                          act=nnf.relu, norm_layer=nn.LayerNorm)
        else:
            self.mixing_network = transformer.DummyTransformer()
        self.dist = None
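# Minimal end-to-end sketch, kept as a comment so nothing runs on import. Assumptions:
# `opt` is an Options instance whose fields match those read above, and query coordinates
# live in the normalized cube (the exact normalization is not fixed here).
#
#     model = Spaghetti(opt)
#     items = torch.arange(4, device=opt.device)                         # training-shape indices
#     coords = torch.rand(4, 2048, constants.DIM, device=opt.device) * 2 - 1
#     occ, z, gmms, zh = model(coords, items)                            # occ: (4, 2048) occupancy scores
#
#     # unconditional sampling via the Gaussian fitted to the learned embeddings
#     zh_rand, gmms_rand = model.random_samples(8)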