from dataclasses import dataclass
from enum import Enum
from typing import NamedTuple, Tuple

import torch
from torch import Tensor, nn
from torch.nn.functional import silu

from .latentnet import *
from .unet import *
from choices import *


@dataclass
class BeatGANsAutoencConfig(BeatGANsUNetConfig):
    """Configuration for :class:`BeatGANsAutoencModel`.

    Extends the UNet config with settings for the semantic encoder. Encoder
    fields left as ``None`` fall back to the corresponding UNet setting when
    the model is built.
    """
    # number of style channels (width of the encoder output)
    enc_out_channels: int = 512
    # encoder attention resolutions; None -> use ``attention_resolutions``
    enc_attn_resolutions: Tuple[int] = None
    # passed to the encoder as ``pool``
    enc_pool: str = 'depthconv'
    # number of residual blocks per level in the encoder
    enc_num_res_block: int = 2
    # encoder channel multipliers; None -> use ``channel_mult``
    enc_channel_mult: Tuple[int] = None
    # OR-ed with ``use_checkpoint`` to enable checkpointing in the encoder
    enc_grad_checkpoint: bool = False
    # optional latent network config; built into ``model.latent_net`` if set
    latent_net_conf: MLPSkipNetConfig = None

    def make_model(self):
        """Build a :class:`BeatGANsAutoencModel` from this config."""
        return BeatGANsAutoencModel(self)


class BeatGANsAutoencModel(BeatGANsUNetModel):
    """UNet whose residual blocks are conditioned on a vector ``cond``
    produced by an internal encoder applied to ``x_start``."""

    def __init__(self, conf: BeatGANsAutoencConfig):
        super().__init__(conf)
        self.conf = conf

        # having only time, cond
        self.time_embed = TimeStyleSeperateEmbed(
            time_channels=conf.model_channels,
            time_out_channels=conf.embed_channels,
        )

        # The encoder is built without time conditioning; encoder-specific
        # settings fall back to the UNet's own settings when unset.
        self.encoder = BeatGANsEncoderConfig(
            image_size=conf.image_size,
            in_channels=conf.in_channels,
            model_channels=conf.model_channels,
            out_hid_channels=conf.enc_out_channels,
            out_channels=conf.enc_out_channels,
            num_res_blocks=conf.enc_num_res_block,
            attention_resolutions=(conf.enc_attn_resolutions
                                   or conf.attention_resolutions),
            dropout=conf.dropout,
            channel_mult=conf.enc_channel_mult or conf.channel_mult,
            use_time_condition=False,
            conv_resample=conf.conv_resample,
            dims=conf.dims,
            use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
            num_heads=conf.num_heads,
            num_head_channels=conf.num_head_channels,
            resblock_updown=conf.resblock_updown,
            use_new_attention_order=conf.use_new_attention_order,
            pool=conf.enc_pool,
        ).make_model()

        if conf.latent_net_conf is not None:
            self.latent_net = conf.latent_net_conf.make_model()

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from N(0,1).

        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D];
            the standard deviation is ``exp(0.5 * logvar)``
        :return: (Tensor) [B x D]
        """
        assert self.conf.is_stochastic
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample_z(self, n: int, device):
        """Draw ``n`` standard-normal vectors of size ``enc_out_channels``."""
        assert self.conf.is_stochastic
        return torch.randn(n, self.conf.enc_out_channels, device=device)

    def noise_to_cond(self, noise: Tensor):
        """Predict ``cond`` from noise. Not supported by this model."""
        raise NotImplementedError('noise_to_cond is not supported')

    def encode(self, x):
        """Run the encoder on ``x`` and return ``{'cond': <encoder output>}``."""
        cond = self.encoder.forward(x)
        return {'cond': cond}

    def _iter_style_resblocks(self):
        """Yield every ResBlock in the input, middle and output blocks, in order."""
        modules = (list(self.input_blocks.modules())
                   + list(self.middle_block.modules())
                   + list(self.output_blocks.modules()))
        for module in modules:
            if isinstance(module, ResBlock):
                yield module

    @property
    def stylespace_sizes(self):
        """Output width of each ResBlock's ``cond_emb_layers[-1]``, in block order."""
        return [
            block.cond_emb_layers[-1].weight.shape[0]
            for block in self._iter_style_resblocks()
        ]

    def encode_stylespace(self, x, return_vector: bool = True):
        """
        encode to style space

        Encodes ``x`` and passes the result through every ResBlock's
        ``cond_emb_layers``. Returns the per-block outputs concatenated along
        dim 1 when ``return_vector`` is True, otherwise as a list.
        """
        # (n, c)
        cond = self.encoder.forward(x)
        # each entry: (n, c')
        S = [
            block.cond_emb_layers.forward(cond)
            for block in self._iter_style_resblocks()
        ]
        if return_vector:
            # (n, sum_c)
            return torch.cat(S, dim=1)
        return S

    def forward(self,
                x,
                t,
                y=None,
                x_start=None,
                cond=None,
                style=None,
                noise=None,
                t_cond=None,
                **kwargs):
        """
        Apply the model to an input batch.

        Args:
            x: the UNet input (presumably the noised image -- confirm). If
                None, no lateral connections are produced (autoencoder-only
                training).
            t: timesteps for the time embedding; may be None
            y: class labels; class-conditional models are not implemented
            x_start: the original image to encode
            cond: output of the encoder; computed from ``x_start`` if None
            style: optional style override (NOTE(review): not consumed
                further in this method)
            noise: random noise (to predict the cond); ``noise_to_cond``
                currently raises NotImplementedError
            t_cond: timesteps for the cond embedding; defaults to ``t``

        Returns:
            AutoencReturn with ``pred`` (output of ``self.out``) and ``cond``.
        """
        if t_cond is None:
            t_cond = t

        if noise is not None:
            # if the noise is given, we predict the cond from noise
            cond = self.noise_to_cond(noise)

        if cond is None:
            if x is not None:
                assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
            tmp = self.encode(x_start)
            cond = tmp['cond']

        if t is not None:
            _t_emb = timestep_embedding(t, self.conf.model_channels)
            _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
        else:
            # this happens when training only autoenc
            _t_emb = None
            _t_cond_emb = None

        if not self.conf.resnet_two_cond:
            raise NotImplementedError()

        res = self.time_embed.forward(
            time_emb=_t_emb,
            cond=cond,
            time_cond_emb=_t_cond_emb,
        )
        # two cond: first = time emb, second = cond_emb
        emb = res.time_emb
        cond_emb = res.emb

        # Override the style only if none was given. An explicit None check is
        # required: ``style or res.style`` calls bool() on a tensor, which
        # raises for tensors with more than one element.
        if style is None:
            style = res.style

        assert (y is not None) == (
            self.conf.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        if self.conf.num_classes is not None:
            raise NotImplementedError()

        # where in the model to supply time conditions
        enc_time_emb = emb
        mid_time_emb = emb
        dec_time_emb = emb
        # where in the model to supply style conditions
        enc_cond_emb = cond_emb
        mid_cond_emb = cond_emb
        dec_cond_emb = cond_emb

        # one stack of lateral activations per resolution level
        hs = [[] for _ in range(len(self.conf.channel_mult))]

        if x is not None:
            h = x.type(self.dtype)

            # input blocks
            k = 0
            for i in range(len(self.input_num_blocks)):
                for _ in range(self.input_num_blocks[i]):
                    h = self.input_blocks[k](h,
                                             emb=enc_time_emb,
                                             cond=enc_cond_emb)
                    hs[i].append(h)
                    k += 1
            assert k == len(self.input_blocks)

            # middle blocks
            h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
        else:
            # no lateral connections
            # happens when training only the autoencoder
            h = None

        # output blocks
        k = 0
        for i in range(len(self.output_num_blocks)):
            for _ in range(self.output_num_blocks[i]):
                # take the lateral connection from the same level (in reverse);
                # once that level's stack is exhausted, use None
                try:
                    lateral = hs[-i - 1].pop()
                except IndexError:
                    lateral = None
                h = self.output_blocks[k](h,
                                          emb=dec_time_emb,
                                          cond=dec_cond_emb,
                                          lateral=lateral)
                k += 1

        pred = self.out(h)
        return AutoencReturn(pred=pred, cond=cond)


class AutoencReturn(NamedTuple):
    """Result of :meth:`BeatGANsAutoencModel.forward`."""
    # output of the model's final ``out`` layer
    pred: Tensor
    # the conditioning vector that was used
    cond: Tensor = None


class EmbedReturn(NamedTuple):
    """Embeddings produced by :class:`TimeStyleSeperateEmbed`."""
    # style and time
    emb: Tensor = None
    # time only
    time_emb: Tensor = None
    # style only (but could depend on time)
    style: Tensor = None


class TimeStyleSeperateEmbed(nn.Module):
    """Embed the timestep with an MLP and pass the style through unchanged."""

    def __init__(self, time_channels, time_out_channels):
        super().__init__()
        self.time_embed = nn.Sequential(
            linear(time_channels, time_out_channels),
            nn.SiLU(),
            linear(time_out_channels, time_out_channels),
        )
        self.style = nn.Identity()

    def forward(self, time_emb=None, cond=None, **kwargs):
        """Return EmbedReturn(emb=style, time_emb=<MLP(time_emb) or None>, style=style).

        ``time_emb`` is None in autoencoder-only training and is then
        returned as None.
        """
        if time_emb is not None:
            time_emb = self.time_embed(time_emb)
        style = self.style(cond)
        return EmbedReturn(emb=style, time_emb=time_emb, style=style)