# Anitalker / model/latentnet.py
# Delik's picture
# Upload 32 files
# f1ea451 verified
import math
from dataclasses import dataclass
from enum import Enum
from typing import NamedTuple, Tuple
import torch
from choices import *
from config_base import BaseConfig
from torch import nn
from torch.nn import init
from .blocks import *
from .nn import timestep_embedding
from .unet import *
class LatentNetType(Enum):
    """Selector for the kind of latent network architecture."""

    # presumably selects "no latent network" (not referenced in this file)
    none = "none"
    # MLP that injects its inputs into the hidden layers
    skip = "skip"
class LatentNetReturn(NamedTuple):
    """Container for the output of a latent network's ``forward``."""

    # the network's prediction; defaults to ``None`` when not supplied
    pred: torch.Tensor = None
@dataclass
class MLPSkipNetConfig(BaseConfig):
    """
    default MLP for the latent DPM in the paper!

    Configuration consumed by ``MLPSkipNet``; call ``make_model`` to build
    the network.
    """
    # width of the network input/output and of the time-conditioning vector
    num_channels: int
    # indices of main-stack layers that also receive the raw input ``x``
    # concatenated to their input; any number of indices, hence ``...``
    # (``Tuple[int]`` would mean a tuple of exactly one int)
    skip_layers: Tuple[int, ...]
    # width of the hidden layers of the main stack
    num_hid_channels: int
    # number of MLPLNAct layers in the main stack
    num_layers: int
    # width of the timestep embedding fed into the time MLP
    num_time_emb_channels: int = 64
    # activation for hidden layers and between time-MLP layers
    activation: Activation = Activation.silu
    # whether non-final main-stack layers use LayerNorm
    use_norm: bool = True
    # constant added to the projected condition before it scales the features
    condition_bias: float = 1
    # dropout probability for non-final main-stack layers (0 disables it)
    dropout: float = 0
    # activation applied to the final network output
    last_act: Activation = Activation.none
    # number of linear layers in the time MLP
    num_time_layers: int = 2
    # if True, ``activation`` is also applied after the last time-MLP layer
    time_last_act: bool = False

    def make_model(self):
        """Build and return an ``MLPSkipNet`` configured by this object."""
        return MLPSkipNet(self)
class MLPSkipNet(nn.Module):
    """
    concat x to hidden layers
    default MLP for the latent DPM in the paper!

    Two parts are built from ``conf``:

    * ``time_embed``: ``conf.num_time_layers`` linear layers mapping the
      timestep embedding (``conf.num_time_emb_channels`` wide) to
      ``conf.num_channels``. An activation follows every layer except the
      last one, unless ``conf.time_last_act`` is set. Its output is the
      condition fed to every conditioned layer.
    * ``layers``: ``conf.num_layers`` ``MLPLNAct`` blocks. The first maps
      ``num_channels -> num_hid_channels``, middle ones
      ``num_hid_channels -> num_hid_channels``, and the last
      ``num_hid_channels -> num_channels`` with no activation, norm,
      condition or dropout. Layers whose index is in ``conf.skip_layers``
      additionally get the original input ``x`` concatenated to their input.
    """

    def __init__(self, conf: MLPSkipNetConfig):
        super().__init__()
        self.conf = conf

        # time MLP: num_time_emb_channels -> num_channels -> ... -> num_channels
        layers = []
        for i in range(conf.num_time_layers):
            in_ch = conf.num_time_emb_channels if i == 0 else conf.num_channels
            layers.append(nn.Linear(in_ch, conf.num_channels))
            # activation between layers; after the last only if time_last_act
            if i < conf.num_time_layers - 1 or conf.time_last_act:
                layers.append(conf.activation.get_act())
        self.time_embed = nn.Sequential(*layers)

        self.layers = nn.ModuleList([])
        for i in range(conf.num_layers):
            # NOTE(review): when num_layers == 1 this first branch wins, so the
            # only layer outputs num_hid_channels rather than num_channels.
            if i == 0:
                # first layer: input -> hidden
                act = conf.activation
                norm = conf.use_norm
                cond = True
                a, b = conf.num_channels, conf.num_hid_channels
                dropout = conf.dropout
            elif i == conf.num_layers - 1:
                # last layer: hidden -> output, no act/norm/condition/dropout
                act = Activation.none
                norm = False
                cond = False
                a, b = conf.num_hid_channels, conf.num_channels
                dropout = 0
            else:
                # middle layers: hidden -> hidden
                act = conf.activation
                norm = conf.use_norm
                cond = True
                a, b = conf.num_hid_channels, conf.num_hid_channels
                dropout = conf.dropout

            if i in conf.skip_layers:
                # forward() concatenates the raw input x onto this layer's input
                a += conf.num_channels

            self.layers.append(
                MLPLNAct(
                    a,
                    b,
                    norm=norm,
                    activation=act,
                    cond_channels=conf.num_channels,
                    use_cond=cond,
                    condition_bias=conf.condition_bias,
                    dropout=dropout,
                ))
        self.last_act = conf.last_act.get_act()

    def forward(self, x, t, **kwargs):
        """
        Run the MLP on ``x`` conditioned on timesteps ``t``.

        Args:
            x: input tensor; concatenated with hidden states along dim 1, so
               presumably shaped (batch, num_channels) — TODO confirm.
            t: timesteps passed to ``timestep_embedding`` — presumably one
               value per batch element; verify against callers.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            ``LatentNetReturn`` whose ``pred`` is the output after ``last_act``.
        """
        t = timestep_embedding(t, self.conf.num_time_emb_channels)
        cond = self.time_embed(t)
        skip_layers = self.conf.skip_layers
        h = x
        for i, layer in enumerate(self.layers):
            if i in skip_layers:
                # injecting the original input into the hidden layers
                h = torch.cat([h, x], dim=1)
            # call the module itself (not .forward) so registered hooks run
            h = layer(h, cond=cond)
        h = self.last_act(h)
        return LatentNetReturn(h)
class MLPLNAct(nn.Module):
    """
    One MLP layer: linear -> (optional condition scaling) -> norm -> act -> dropout.

    When ``use_cond`` is True the condition is passed through the layer's
    activation and a linear projection to ``out_channels``. It then scales the
    features multiplicatively as ``h * (condition_bias + proj(cond))``.
    ``norm`` selects ``nn.LayerNorm`` vs identity. ``dropout > 0`` enables
    ``nn.Dropout``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm: bool,
        use_cond: bool,
        activation: Activation,
        cond_channels: int,
        condition_bias: float = 0,
        dropout: float = 0,
    ):
        super().__init__()
        self.use_cond = use_cond
        self.condition_bias = condition_bias
        self.activation = activation

        # module creation order is kept fixed: it determines parameter
        # initialization order and state_dict layout
        self.linear = nn.Linear(in_channels, out_channels)
        self.act = activation.get_act()
        if self.use_cond:
            self.linear_emb = nn.Linear(cond_channels, out_channels)
            # condition path: activation first, then project to out_channels
            self.cond_layers = nn.Sequential(self.act, self.linear_emb)
        self.norm = nn.LayerNorm(out_channels) if norm else nn.Identity()
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()

        self.init_weights()

    def init_weights(self):
        """
        Re-initialize every ``nn.Linear`` weight with ``kaiming_normal_``.

        relu and silu use the relu gain, and lrelu uses ``a=0.2`` with the
        leaky_relu gain. Any other activation keeps PyTorch's default init.
        Biases are left untouched.
        """
        if self.activation == Activation.relu:
            kaiming_kwargs = {'a': 0, 'nonlinearity': 'relu'}
        elif self.activation == Activation.lrelu:
            kaiming_kwargs = {'a': 0.2, 'nonlinearity': 'leaky_relu'}
        elif self.activation == Activation.silu:
            # silu has no dedicated gain; the relu gain is used instead
            kaiming_kwargs = {'a': 0, 'nonlinearity': 'relu'}
        else:
            # leave it as default
            return

        for submodule in self.modules():
            if isinstance(submodule, nn.Linear):
                init.kaiming_normal_(submodule.weight, **kaiming_kwargs)

    def forward(self, x, cond=None):
        """
        Apply the layer to ``x``.

        Args:
            x: features whose last dimension is ``in_channels``.
            cond: conditioning tensor with last dimension ``cond_channels``;
                required when ``use_cond`` is True, ignored otherwise.

        Returns:
            Features with ``out_channels`` in the last dimension.
        """
        out = self.linear(x)
        if self.use_cond:
            # scale only (no shift term), then normalize below
            scale = self.cond_layers(cond)
            out = out * (self.condition_bias + scale)
        out = self.norm(out)
        out = self.act(out)
        return self.dropout(out)