|
|
from functools import reduce, partial |
|
from packaging import version |
|
|
|
from einops import rearrange, repeat |
|
from einops.layers.torch import Rearrange |
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn, einsum |
|
from torch.cuda.amp import autocast |
|
from typing import Callable, Literal |
|
|
|
try: |
|
from flash_attn import flash_attn_func, flash_attn_kvpacked_func |
|
except ImportError as e: |
|
print(e) |
|
print('flash_attn not installed, disabling Flash Attention') |
|
flash_attn_kvpacked_func = None |
|
flash_attn_func = None |
|
|
|
try: |
|
import natten |
|
except ImportError: |
|
natten = None |
|
|
|
|
|
def checkpoint(function, *args, **kwargs): |
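    """Thin wrapper around torch.utils.checkpoint.checkpoint that defaults to the
    non-reentrant implementation (use_reentrant=False)."""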
|
kwargs.setdefault("use_reentrant", False) |
|
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
def create_causal_mask(i, j, device): |
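    """Boolean (i, j) mask where True marks key positions a query may NOT attend to.
    The triu offset of j - i + 1 keeps the causal diagonal aligned with the last i
    queries when there are more keys than queries (e.g. with a kv cache / prefix)."""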
|
return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1) |
|
|
|
|
|
def or_reduce(masks): |
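    """Element-wise logical OR over a list of boolean masks."""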
|
head, *body = masks |
|
for rest in body: |
|
head = head | rest |
|
return head |
|
|
|
|
|
|
|
|
|
class AbsolutePositionalEmbedding(nn.Module): |
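    """Learned absolute positional embedding, scaled by dim ** -0.5.

    `pos` can override the default arange positions, and `seq_start_pos` shifts
    positions for left-padded sequences (clamped at 0)."""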
|
def __init__(self, dim, max_seq_len): |
|
super().__init__() |
|
self.scale = dim ** -0.5 |
|
self.max_seq_len = max_seq_len |
|
self.emb = nn.Embedding(max_seq_len, dim) |
|
|
|
def forward(self, x, pos=None, seq_start_pos=None): |
|
seq_len, device = x.shape[1], x.device |
|
assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' |
|
|
|
if pos is None: |
|
pos = torch.arange(seq_len, device=device) |
|
|
|
if seq_start_pos is not None: |
|
pos = (pos - seq_start_pos[..., None]).clamp(min=0) |
|
|
|
pos_emb = self.emb(pos) |
|
pos_emb = pos_emb * self.scale |
|
return pos_emb |
|
|
|
|
|
class ScaledSinusoidalEmbedding(nn.Module): |
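    """Fixed sinusoidal positional embedding with a single learned scale
    (initialised to dim ** -0.5); `theta` sets the frequency base."""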
|
def __init__(self, dim, theta=10000): |
|
super().__init__() |
|
assert (dim % 2) == 0, 'dimension must be divisible by 2' |
|
self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) |
|
|
|
half_dim = dim // 2 |
|
freq_seq = torch.arange(half_dim).float() / half_dim |
|
inv_freq = theta ** -freq_seq |
|
self.register_buffer('inv_freq', inv_freq, persistent=False) |
|
|
|
def forward(self, x, pos=None, seq_start_pos=None): |
|
seq_len, device = x.shape[1], x.device |
|
|
|
if pos is None: |
|
pos = torch.arange(seq_len, device=device) |
|
|
|
if seq_start_pos is not None: |
|
pos = pos - seq_start_pos[..., None] |
|
|
|
emb = einsum('i, j -> i j', pos, self.inv_freq) |
|
emb = torch.cat((emb.sin(), emb.cos()), dim=-1) |
|
return emb * self.scale |
|
|
|
|
|
class RotaryEmbedding(nn.Module): |
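    """Rotary positional embedding (RoPE).

    `base_rescale_factor` rescales the rotary base (NTK-style) so the embedding can be
    stretched to longer sequences without fine-tuning, `interpolation_factor` >= 1 applies
    positional interpolation by compressing positions, and `use_xpos` additionally returns
    xpos-style decay scales (controlled by `scale_base`) alongside the frequencies."""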
|
def __init__( |
|
self, |
|
dim, |
|
use_xpos=False, |
|
scale_base=512, |
|
interpolation_factor=1., |
|
base=10000, |
|
base_rescale_factor=1. |
|
): |
|
super().__init__() |
|
|
|
|
|
|
|
base *= base_rescale_factor ** (dim / (dim - 2)) |
|
|
|
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) |
|
self.register_buffer('inv_freq', inv_freq) |
|
|
|
assert interpolation_factor >= 1. |
|
self.interpolation_factor = interpolation_factor |
|
|
|
if not use_xpos: |
|
self.register_buffer('scale', None) |
|
return |
|
|
|
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) |
|
|
|
self.scale_base = scale_base |
|
self.register_buffer('scale', scale) |
|
|
|
def forward_from_seq_len(self, seq_len): |
|
device = self.inv_freq.device |
|
|
|
t = torch.arange(seq_len, device=device) |
|
return self.forward(t) |
|
|
|
@autocast(enabled=False) |
|
def forward(self, t): |
|
device = self.inv_freq.device |
|
|
|
t = t.to(torch.float32) |
|
|
|
t = t / self.interpolation_factor |
|
|
|
freqs = torch.einsum('i , j -> i j', t, self.inv_freq) |
|
freqs = torch.cat((freqs, freqs), dim=-1) |
|
|
|
if self.scale is None: |
|
return freqs, 1. |
|
|
|
        seq_len = t.shape[-1]
        power = (torch.arange(seq_len, device=device) - (seq_len // 2)) / self.scale_base
|
scale = self.scale ** rearrange(power, 'n -> n 1') |
|
scale = torch.cat((scale, scale), dim=-1) |
|
|
|
return freqs, scale |
|
|
|
|
|
def rotate_half(x): |
|
x = rearrange(x, '... (j d) -> ... j d', j=2) |
|
x1, x2 = x.unbind(dim=-2) |
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
|
@autocast(enabled=False) |
|
def apply_rotary_pos_emb(t, freqs, scale=1): |
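    """Rotate the first `rot_dim` channels of `t` by `freqs`, passing any remaining
    channels through unrotated. Work is done in (at least) float32 and the result is
    cast back to the input dtype; only the trailing seq_len rows of `freqs` are used."""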
|
out_dtype = t.dtype |
|
|
|
|
|
dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32)) |
|
rot_dim, seq_len = freqs.shape[-1], t.shape[-2] |
|
freqs, t = freqs.to(dtype), t.to(dtype) |
|
freqs = freqs[-seq_len:, :] |
|
|
|
if t.ndim == 4 and freqs.ndim == 3: |
|
freqs = rearrange(freqs, 'b n d -> b 1 n d') |
|
|
|
|
|
t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] |
|
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) |
|
|
|
t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype) |
|
|
|
return torch.cat((t, t_unrotated), dim=-1) |
|
|
|
|
|
|
|
class LayerNorm(nn.Module): |
|
def __init__(self, dim, bias=False, fix_scale=False): |
|
""" |
|
        Bias-less LayerNorm has been shown to be more stable; most newer models have moved toward RMSNorm, which is also bias-less.
|
""" |
|
super().__init__() |
|
|
|
if fix_scale: |
|
self.register_buffer("gamma", torch.ones(dim)) |
|
else: |
|
self.gamma = nn.Parameter(torch.ones(dim)) |
|
|
|
if bias: |
|
self.beta = nn.Parameter(torch.zeros(dim)) |
|
else: |
|
self.register_buffer("beta", torch.zeros(dim)) |
|
|
|
def forward(self, x): |
|
return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta) |
|
|
|
|
|
|
|
|
|
class GLU(nn.Module): |
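    """Gated linear unit: a single projection to 2 * dim_out, split into value and gate,
    with the gate passed through `activation`. A Conv1d projection is used instead of a
    Linear when `use_conv` is set."""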
|
def __init__( |
|
self, |
|
dim_in, |
|
dim_out, |
|
activation: Callable, |
|
use_conv=False, |
|
conv_kernel_size=3, |
|
): |
|
super().__init__() |
|
self.act = activation |
|
        self.proj = (
            nn.Linear(dim_in, dim_out * 2) if not use_conv
            else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding=conv_kernel_size // 2)
        )
|
self.use_conv = use_conv |
|
|
|
def forward(self, x): |
|
if self.use_conv: |
|
x = rearrange(x, 'b n d -> b d n') |
|
x = self.proj(x) |
|
x = rearrange(x, 'b d n -> b n d') |
|
else: |
|
x = self.proj(x) |
|
|
|
x, gate = x.chunk(2, dim=-1) |
|
return x * self.act(gate) |
|
|
|
|
|
class FeedForward(nn.Module): |
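    """Transformer feed-forward block: SiLU activation, optional GLU gating (`glu`),
    optional Conv1d projections (`use_conv`), inner width `dim * mult`, and an optional
    zero-initialised output projection (`zero_init_output`)."""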
|
def __init__( |
|
self, |
|
dim, |
|
dim_out=None, |
|
mult=4, |
|
no_bias=False, |
|
glu=True, |
|
use_conv=False, |
|
conv_kernel_size=3, |
|
zero_init_output=True, |
|
): |
|
super().__init__() |
|
inner_dim = int(dim * mult) |
|
|
|
|
|
|
|
activation = nn.SiLU() |
|
|
|
dim_out = dim if dim_out is None else dim_out |
|
|
|
if glu: |
|
linear_in = GLU(dim, inner_dim, activation) |
|
else: |
|
linear_in = nn.Sequential( |
|
Rearrange('b n d -> b d n') if use_conv else nn.Identity(), |
|
                nn.Linear(dim, inner_dim, bias=not no_bias) if not use_conv
                else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding=conv_kernel_size // 2, bias=not no_bias),
|
Rearrange('b n d -> b d n') if use_conv else nn.Identity(), |
|
activation |
|
) |
|
|
|
        linear_out = (
            nn.Linear(inner_dim, dim_out, bias=not no_bias) if not use_conv
            else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding=conv_kernel_size // 2, bias=not no_bias)
        )
|
|
|
|
|
if zero_init_output: |
|
nn.init.zeros_(linear_out.weight) |
|
if not no_bias: |
|
nn.init.zeros_(linear_out.bias) |
|
|
|
self.ff = nn.Sequential( |
|
linear_in, |
|
Rearrange('b d n -> b n d') if use_conv else nn.Identity(), |
|
linear_out, |
|
Rearrange('b n d -> b d n') if use_conv else nn.Identity(), |
|
) |
|
|
|
def forward(self, x): |
|
return self.ff(x) |
|
|
|
|
|
class Attention(nn.Module): |
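    """Multi-head attention for self- or cross-attention (via `dim_context`).

    Key/value heads are derived from the context dimension, so fewer kv heads than query
    heads gives grouped-query attention (kv heads are repeated to match). Supports optional
    L2 or LayerNorm query/key normalisation (`qk_norm`) and several backends: NATTEN 1-D
    neighborhood attention (`natten_kernel_size`), Flash Attention 2 via the `flash_attn`
    package (currently force-disabled in __init__), PyTorch's scaled_dot_product_attention,
    or a plain einsum fallback."""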
|
def __init__( |
|
self, |
|
dim, |
|
dim_heads=64, |
|
dim_context=None, |
|
causal=False, |
|
zero_init_output=True, |
|
qk_norm: Literal['l2', 'ln', 'none'] = 'none', |
|
natten_kernel_size=None |
|
): |
|
super().__init__() |
|
self.dim = dim |
|
self.dim_heads = dim_heads |
|
self.causal = causal |
|
|
|
dim_kv = dim_context if dim_context is not None else dim |
|
|
|
self.num_heads = dim // dim_heads |
|
self.kv_heads = dim_kv // dim_heads |
|
|
|
if dim_context is not None: |
|
self.to_q = nn.Linear(dim, dim, bias=False) |
|
self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False) |
|
else: |
|
self.to_qkv = nn.Linear(dim, dim * 3, bias=False) |
|
|
|
self.to_out = nn.Linear(dim, dim, bias=False) |
|
|
|
if zero_init_output: |
|
nn.init.zeros_(self.to_out.weight) |
|
|
|
self.qk_norm = qk_norm |
|
|
|
if self.qk_norm == "ln": |
|
self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) |
|
self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) |
|
|
|
|
|
self.natten_kernel_size = natten_kernel_size |
|
if natten_kernel_size is not None: |
|
return |
|
|
|
self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') |
|
|
|
        # Flash Attention 2 via the flash_attn package is force-disabled here; attention falls
        # back to PyTorch's scaled_dot_product_attention (use_pt_flash) or the einsum path.
        self.use_fa_flash = False
|
|
|
self.sdp_kwargs = dict( |
|
enable_flash=True, |
|
enable_math=True, |
|
enable_mem_efficient=True |
|
) |
|
|
|
def flash_attn( |
|
self, |
|
q, |
|
k, |
|
v, |
|
mask=None, |
|
causal=None |
|
): |
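        """Attention via F.scaled_dot_product_attention: repeats grouped kv heads to match
        query heads, folds the causal constraint into an explicit boolean mask when a padding
        mask is present or when keys outnumber queries (kv-cache case), and zeroes output rows
        whose mask is entirely False."""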
|
batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device |
|
kv_heads = k.shape[1] |
|
|
|
|
|
|
|
if heads != kv_heads: |
|
|
|
heads_per_kv_head = heads // kv_heads |
|
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v)) |
|
|
|
if k.ndim == 3: |
|
k = rearrange(k, 'b ... -> b 1 ...').expand_as(q) |
|
|
|
if v.ndim == 3: |
|
v = rearrange(v, 'b ... -> b 1 ...').expand_as(q) |
|
|
|
causal = self.causal if causal is None else causal |
|
|
|
if q_len == 1 and causal: |
|
causal = False |
|
|
|
if mask is not None: |
|
assert mask.ndim == 4 |
|
mask = mask.expand(batch, heads, q_len, k_len) |
|
|
|
|
|
|
if k_len > q_len and causal: |
|
causal_mask = create_causal_mask(q_len, k_len, device=device) |
|
if mask is None: |
|
mask = ~causal_mask |
|
else: |
|
mask = mask & ~causal_mask |
|
causal = False |
|
|
|
|
|
|
|
row_is_entirely_masked = None |
|
|
|
if mask is not None and causal: |
|
causal_mask = create_causal_mask(q_len, k_len, device=device) |
|
mask = mask & ~causal_mask |
|
|
|
|
|
|
|
row_is_entirely_masked = ~mask.any(dim=-1) |
|
mask[..., 0] = mask[..., 0] | row_is_entirely_masked |
|
|
|
causal = False |
|
|
|
with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs): |
|
out = F.scaled_dot_product_attention( |
|
q, k, v, |
|
attn_mask=mask, |
|
is_causal=causal |
|
) |
|
|
|
|
|
|
|
if row_is_entirely_masked is not None: |
|
out = out.masked_fill(row_is_entirely_masked[..., None], 0.) |
|
|
|
return out |
|
|
|
def forward( |
|
self, |
|
x, |
|
context=None, |
|
mask=None, |
|
context_mask=None, |
|
rotary_pos_emb=None, |
|
causal=None |
|
): |
|
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None |
|
|
|
kv_input = context if has_context else x |
|
|
|
if hasattr(self, 'to_q'): |
|
|
|
q = self.to_q(x) |
|
q = rearrange(q, 'b n (h d) -> b h n d', h=h) |
|
|
|
k, v = self.to_kv(kv_input).chunk(2, dim=-1) |
|
|
|
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=kv_h), (k, v)) |
|
else: |
|
|
|
q, k, v = self.to_qkv(x).chunk(3, dim=-1) |
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) |
|
|
|
|
|
if self.qk_norm == "l2": |
|
q = F.normalize(q, dim=-1) |
|
k = F.normalize(k, dim=-1) |
|
elif self.qk_norm == "ln": |
|
q = self.q_norm(q) |
|
k = self.k_norm(k) |
|
|
|
if rotary_pos_emb is not None and not has_context: |
|
freqs, _ = rotary_pos_emb |
|
|
|
q_dtype = q.dtype |
|
k_dtype = k.dtype |
|
|
|
q = q.to(torch.float32) |
|
k = k.to(torch.float32) |
|
freqs = freqs.to(torch.float32) |
|
|
|
q = apply_rotary_pos_emb(q, freqs) |
|
k = apply_rotary_pos_emb(k, freqs) |
|
|
|
q = q.to(q_dtype) |
|
k = k.to(k_dtype) |
|
|
|
input_mask = context_mask |
|
|
|
if input_mask is None and not has_context: |
|
input_mask = mask |
|
|
|
|
|
masks = [] |
|
final_attn_mask = None |
|
|
|
if input_mask is not None: |
|
input_mask = rearrange(input_mask, 'b j -> b 1 1 j') |
|
masks.append(~input_mask) |
|
|
|
|
|
|
|
if len(masks) > 0: |
|
final_attn_mask = ~or_reduce(masks) |
|
|
|
n, device = q.shape[-2], q.device |
|
|
|
causal = self.causal if causal is None else causal |
|
|
|
if n == 1 and causal: |
|
causal = False |
|
|
|
if self.natten_kernel_size is not None: |
|
if natten is None: |
|
raise ImportError('natten not installed, please install natten to use neighborhood attention') |
|
|
|
dtype_in = q.dtype |
|
q, k, v = map(lambda t: t.to(torch.float32), (q, k, v)) |
|
|
|
attn = natten.functional.natten1dqk(q, k, kernel_size=self.natten_kernel_size, dilation=1) |
|
|
|
if final_attn_mask is not None: |
|
attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max) |
|
|
|
attn = F.softmax(attn, dim=-1, dtype=torch.float32) |
|
|
|
out = natten.functional.natten1dav(attn, v, kernel_size=self.natten_kernel_size, dilation=1).to(dtype_in) |
|
|
|
|
|
elif self.use_fa_flash: |
|
assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2' |
|
|
|
fa_dtype_in = q.dtype |
|
q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v)) |
|
|
|
out = flash_attn_func(q, k, v, causal=causal) |
|
|
|
out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d') |
|
|
|
|
|
elif self.use_pt_flash: |
|
|
|
|
|
            out = self.flash_attn(q, k, v, causal=causal, mask=final_attn_mask)
|
|
|
else: |
|
|
|
|
|
if h != kv_h: |
|
|
|
heads_per_kv_head = h // kv_h |
|
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v)) |
|
|
|
scale = 1. / (q.shape[-1] ** 0.5) |
|
|
|
kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d' |
|
|
|
dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale |
|
|
|
i, j, dtype = *dots.shape[-2:], dots.dtype |
|
|
|
mask_value = -torch.finfo(dots.dtype).max |
|
|
|
if final_attn_mask is not None: |
|
dots = dots.masked_fill(~final_attn_mask, mask_value) |
|
|
|
if causal: |
|
causal_mask = create_causal_mask(i, j, device=device) |
|
dots = dots.masked_fill(causal_mask, mask_value) |
|
|
|
attn = F.softmax(dots, dim=-1, dtype=torch.float32) |
|
attn = attn.type(dtype) |
|
|
|
out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v) |
|
|
|
|
|
        out = rearrange(out, 'b h n d -> b n (h d)')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
out = self.to_out(out) |
|
|
|
if mask is not None: |
|
mask = rearrange(mask, 'b n -> b n 1') |
|
out = out.masked_fill(~mask, 0.) |
|
|
|
return out |
|
|
|
|
|
class ConformerModule(nn.Module): |
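    """Conformer-style convolution module operating on (batch, seq, dim) inputs:
    LayerNorm -> pointwise Conv1d -> GLU -> depthwise Conv1d (kernel 17) -> LayerNorm
    -> SiLU -> pointwise Conv1d."""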
|
def __init__( |
|
self, |
|
dim, |
|
norm_kwargs={}, |
|
): |
|
super().__init__() |
|
|
|
self.dim = dim |
|
|
|
self.in_norm = LayerNorm(dim, **norm_kwargs) |
|
self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False) |
|
self.glu = GLU(dim, dim, nn.SiLU()) |
|
self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False) |
|
        self.mid_norm = LayerNorm(dim, **norm_kwargs)
|
self.swish = nn.SiLU() |
|
self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False) |
|
|
|
def forward(self, x): |
|
x = self.in_norm(x) |
|
x = rearrange(x, 'b n d -> b d n') |
|
x = self.pointwise_conv(x) |
|
x = rearrange(x, 'b d n -> b n d') |
|
x = self.glu(x) |
|
x = rearrange(x, 'b n d -> b d n') |
|
x = self.depthwise_conv(x) |
|
x = rearrange(x, 'b d n -> b n d') |
|
x = self.mid_norm(x) |
|
x = self.swish(x) |
|
x = rearrange(x, 'b n d -> b d n') |
|
x = self.pointwise_conv_2(x) |
|
x = rearrange(x, 'b d n -> b n d') |
|
|
|
return x |
|
|
|
|
|
class TransformerBlock(nn.Module): |
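    """Pre-norm transformer block: self-attention, optional cross-attention (`cross_attend`),
    an optional Conformer convolution module, and a feed-forward network. When
    `global_cond_dim` is set, a global conditioning vector is projected (zero-initialised)
    to per-branch scale, shift and gate values that modulate the self-attention and
    feed-forward branches, adaLN-style."""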
|
def __init__( |
|
self, |
|
dim, |
|
dim_heads=64, |
|
cross_attend=False, |
|
dim_context=None, |
|
global_cond_dim=None, |
|
causal=False, |
|
zero_init_branch_outputs=True, |
|
conformer=False, |
|
layer_ix=-1, |
|
remove_norms=False, |
|
attn_kwargs={}, |
|
ff_kwargs={}, |
|
norm_kwargs={} |
|
): |
|
|
|
super().__init__() |
|
self.dim = dim |
|
self.dim_heads = dim_heads |
|
self.cross_attend = cross_attend |
|
self.dim_context = dim_context |
|
self.causal = causal |
|
|
|
self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() |
|
|
|
self.self_attn = Attention( |
|
dim, |
|
dim_heads=dim_heads, |
|
causal=causal, |
|
zero_init_output=zero_init_branch_outputs, |
|
**attn_kwargs |
|
) |
|
|
|
if cross_attend: |
|
self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() |
|
self.cross_attn = Attention( |
|
dim, |
|
dim_heads=dim_heads, |
|
dim_context=dim_context, |
|
causal=causal, |
|
zero_init_output=zero_init_branch_outputs, |
|
**attn_kwargs |
|
) |
|
|
|
self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() |
|
self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs) |
|
|
|
self.layer_ix = layer_ix |
|
|
|
self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None |
|
|
|
self.global_cond_dim = global_cond_dim |
|
|
|
if global_cond_dim is not None: |
|
self.to_scale_shift_gate = nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear(global_cond_dim, dim * 6, bias=False) |
|
) |
|
|
|
nn.init.zeros_(self.to_scale_shift_gate[1].weight) |
|
|
|
|
|
def forward( |
|
self, |
|
x, |
|
context=None, |
|
global_cond=None, |
|
mask=None, |
|
context_mask=None, |
|
rotary_pos_emb=None |
|
): |
|
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: |
|
|
|
            scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (
                self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim=-1)
            )
|
|
|
|
|
residual = x |
|
x = self.pre_norm(x) |
|
x = x * (1 + scale_self) + shift_self |
|
x = self.self_attn(x, mask=mask, rotary_pos_emb=rotary_pos_emb) |
|
x = x * torch.sigmoid(1 - gate_self) |
|
x = x + residual |
|
|
|
if context is not None: |
|
x = x + self.cross_attn(self.cross_attend_norm(x), context=context, context_mask=context_mask) |
|
|
|
if self.conformer is not None: |
|
x = x + self.conformer(x) |
|
|
|
|
|
residual = x |
|
x = self.ff_norm(x) |
|
x = x * (1 + scale_ff) + shift_ff |
|
x = self.ff(x) |
|
x = x * torch.sigmoid(1 - gate_ff) |
|
x = x + residual |
|
|
|
else: |
|
x = x + self.self_attn(self.pre_norm(x), mask=mask, rotary_pos_emb=rotary_pos_emb) |
|
|
|
if context is not None: |
|
x = x + self.cross_attn(self.cross_attend_norm(x), context=context, context_mask=context_mask) |
|
|
|
if self.conformer is not None: |
|
x = x + self.conformer(x) |
|
|
|
x = x + self.ff(self.ff_norm(x)) |
|
|
|
return x |
|
|
|
|
|
class ContinuousTransformer(nn.Module): |
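    """Stack of TransformerBlocks over continuous (non-token) inputs, with optional input/output
    projections, rotary and/or sinusoidal or learned absolute positional embeddings, optional
    cross-attention conditioning (`cond_token_dim`), global conditioning (`global_cond_dim`),
    prepended embeddings, and gradient checkpointing of every layer."""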
|
def __init__( |
|
self, |
|
dim, |
|
depth, |
|
*, |
|
dim_in=None, |
|
dim_out=None, |
|
dim_heads=64, |
|
cross_attend=False, |
|
cond_token_dim=None, |
|
global_cond_dim=None, |
|
causal=False, |
|
rotary_pos_emb=True, |
|
zero_init_branch_outputs=True, |
|
conformer=False, |
|
use_sinusoidal_emb=False, |
|
use_abs_pos_emb=False, |
|
abs_pos_emb_max_length=10000, |
|
**kwargs |
|
): |
|
|
|
super().__init__() |
|
|
|
self.dim = dim |
|
self.depth = depth |
|
self.causal = causal |
|
self.layers = nn.ModuleList([]) |
|
|
|
self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity() |
|
self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity() |
|
|
|
if rotary_pos_emb: |
|
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32)) |
|
else: |
|
self.rotary_pos_emb = None |
|
|
|
self.use_sinusoidal_emb = use_sinusoidal_emb |
|
if use_sinusoidal_emb: |
|
self.pos_emb = ScaledSinusoidalEmbedding(dim) |
|
|
|
self.use_abs_pos_emb = use_abs_pos_emb |
|
if use_abs_pos_emb: |
|
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length) |
|
|
|
for i in range(depth): |
|
self.layers.append( |
|
TransformerBlock( |
|
dim, |
|
dim_heads=dim_heads, |
|
cross_attend=cross_attend, |
|
dim_context=cond_token_dim, |
|
global_cond_dim=global_cond_dim, |
|
causal=causal, |
|
zero_init_branch_outputs=zero_init_branch_outputs, |
|
conformer=conformer, |
|
layer_ix=i, |
|
**kwargs |
|
) |
|
) |
|
|
|
def forward( |
|
self, |
|
x, |
|
mask=None, |
|
prepend_embeds=None, |
|
prepend_mask=None, |
|
global_cond=None, |
|
return_info=False, |
|
**kwargs |
|
): |
|
batch, seq, device = *x.shape[:2], x.device |
|
|
|
info = { |
|
"hidden_states": [], |
|
} |
|
|
|
x = self.project_in(x) |
|
if prepend_embeds is not None: |
|
prepend_length, prepend_dim = prepend_embeds.shape[1:] |
|
|
|
            assert prepend_dim == x.shape[-1], 'prepend embedding dimension must match the model embedding dimension'
|
|
|
x = torch.cat((prepend_embeds, x), dim=-2) |
|
|
|
if prepend_mask is not None or mask is not None: |
|
mask = mask if mask is not None else torch.ones((batch, seq), device=device, dtype=torch.bool) |
|
                prepend_mask = prepend_mask if prepend_mask is not None else torch.ones(
                    (batch, prepend_length), device=device, dtype=torch.bool)
|
|
|
mask = torch.cat((prepend_mask, mask), dim=-1) |
|
|
|
|
|
|
|
if self.rotary_pos_emb is not None: |
|
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) |
|
else: |
|
rotary_pos_emb = None |
|
|
|
if self.use_sinusoidal_emb or self.use_abs_pos_emb: |
|
x = x + self.pos_emb(x) |
|
|
|
|
|
mask = self.refine_mask(mask) |
|
for layer in self.layers: |
|
|
|
|
|
            x = checkpoint(
                layer, x,
                mask=mask.bool() if mask is not None else None,
                rotary_pos_emb=rotary_pos_emb,
                global_cond=global_cond,
                **kwargs
            )
|
|
|
if return_info: |
|
info["hidden_states"].append(x) |
|
|
|
x = self.project_out(x) |
|
|
|
if return_info: |
|
return x, info |
|
|
|
return x |
|
|
|
def refine_mask(self, mask): |
|
return mask |
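

# Minimal usage sketch (illustrative only; the dimensions, depth and shapes below are
# arbitrary assumptions, not values from any particular configuration):
#
#   model = ContinuousTransformer(dim=512, depth=8, dim_heads=64, dim_in=64, dim_out=64)
#   x = torch.randn(2, 128, 64)                      # (batch, seq, dim_in)
#   mask = torch.ones(2, 128, dtype=torch.bool)      # True = position is attended to
#   out = model(x, mask=mask)                        # -> (2, 128, 64)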
|
|
|
|
|
|
|
|