import typing as tp

from einops import rearrange
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from xformers import ops

_efficient_attention_backend: str = 'torch'


def _get_attention_time_dimension(memory_efficient: bool) -> int:
    if _efficient_attention_backend == 'torch' and memory_efficient:
        return 2
    else:
        return 1


def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Create sinusoidal positional embedding, with shape `[B, T, C]`.

    Args:
        positions (torch.Tensor): LongTensor of positions.
        dim (int): Dimension of the embedding.
        max_period (float): Maximum period of the cosine/sine functions.
        dtype (torch.dtype or str): dtype to use to generate the embedding.
    Returns:
        torch.Tensor: Sinusoidal positional embedding.
    """
    # We aim for BTC format
    assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(dtype)
    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)  # avoid sync point
    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)


def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep), as in xlformers."""
    if n_rep == 1:
        return x
    if _efficient_attention_backend == 'torch' and memory_efficient:
        bs, n_kv_heads, slen, head_dim = x.shape
        return (
            x[:, :, None, :, :]
            .expand(bs, n_kv_heads, n_rep, slen, head_dim)
            .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
        )
    else:
        bs, slen, n_kv_heads, head_dim = x.shape
        return (
            x[:, :, :, None, :]
            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
            .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
        )
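
# Illustrative sketch, not part of the original module: a quick shape check for the two
# helpers above. The function name and all tensor sizes below are made up for the example;
# `memory_efficient=True` selects the [B, heads, T, head_dim] layout used by the 'torch'
# attention backend.
def _example_pos_embedding_and_kv_expand():
    positions = torch.arange(4).view(1, -1, 1)                         # [1, 4, 1], BTC-style positions
    pos_emb = create_sin_embedding(positions, dim=8)                   # -> [1, 4, 8]
    kv = torch.randn(2, 6, 4, 64)                                      # [B, n_kv_heads, T, head_dim]
    expanded = expand_repeated_kv(kv, n_rep=2, memory_efficient=True)  # -> [2, 12, 4, 64]
    return pos_emb.shape, expanded.shape
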
class StreamingMultiheadAttention(nn.Module):

    def __init__(self, embed_dim, num_heads, dropout: float = 0.0, bias: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 cross_attention: bool = False,
                 kv_repeat: int = 1,
                 device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        if past_context is not None:
            assert causal
        self.embed_dim = embed_dim
        # k/v from the previous tokens seen in the current generation - only for self-attention.
        # Cleaned up in the LM after generation finishes: each of the 1..47 MHA layers keeps
        # its own k/v history.
        self.k_history = None
        self.v_history = None
        self.memory_efficient = memory_efficient
        self.cross_attention = cross_attention
        self.num_heads = num_heads
        self.dropout = dropout
        self.kv_repeat = kv_repeat
        self.custom = True  # _is_custom(custom, memory_efficient)
        if not self.custom:
            print(f'{self.custom}')
        if self.custom:
            out_dim = embed_dim
            assert num_heads % kv_repeat == 0
            assert not cross_attention or kv_repeat == 1
            num_kv = num_heads // kv_repeat
            kv_dim = (embed_dim // num_heads) * num_kv
            out_dim += 2 * kv_dim
            in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
            # We try to follow the default PyTorch MHA convention, to easily compare results.
            self.in_proj_weight = in_proj.weight
            self.in_proj_bias = in_proj.bias
            if bias:
                self.in_proj_bias.data.zero_()  # Following PyTorch convention
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
            if bias:
                self.out_proj.bias.data.zero_()
        else:
            assert kv_repeat == 1
            self.mha = nn.MultiheadAttention(
                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
                **factory_kwargs)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if not self.custom:
            # Support compat with regular MHA
            keys = [n for n, _ in self.mha.named_parameters()]
            for key in keys:
                if prefix + key in state_dict:
                    state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(self, query,
                key=None,    # key/value are ignored unless self.cross_attention
                value=None):
        # time_dim = _get_attention_time_dimension(self.memory_efficient)
        # if time_dim == 2:
        layout = "b h t d"
        # else:
        #     layout = "b t h d"
        # dtype = query.dtype
        if self.custom:
            if self.cross_attention:
                # Different queries, keys, values: we have to split the weights manually
                # before applying the linear.
                dim = self.in_proj_weight.shape[0] // 3
                if self.in_proj_bias is None:
                    bias_q, bias_k, bias_v = None, None, None
                else:
                    bias_q = self.in_proj_bias[:dim]
                    bias_k = self.in_proj_bias[dim: 2 * dim]
                    bias_v = self.in_proj_bias[2 * dim:]
                q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
                # todo: when streaming, we could actually save k, v and check that the shapes match.
                k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
                v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
                q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
                print(q.shape, k.shape, v.shape, q.sum(), k.sum(), v.sum(), 'CROSS A5')
            else:
                # Self-attention: (1) project the new token(s) to q/k/v, then (2) concatenate
                # the new k/v onto the history kept by this layer (each transformer layer has
                # its own history).
                projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
                if self.kv_repeat == 1:
                    # Only kv_repeat == 1 is handled in this stripped-down path.
                    # if time_dim == 2:
                    bound_layout = "b h p t d"
                    # else:
                    #     bound_layout = "b t p h d"
                    packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
                    q, k, v = ops.unbind(packed, dim=2)

                if self.k_history is not None:
                    # e.g. k_history [2, 24, 3, 64] and new k [2, 24, 1, 64] are concatenated
                    # along the time dimension (dim=2): 24 heads, 64 = head_dim, 3 = tokens seen so far.
                    self.k_history = torch.cat([self.k_history, k], 2)
                    self.v_history = torch.cat([self.v_history, v], 2)
                else:
                    # Initialise on the 1st token (for all 47 transformer layers).
                    print(f'else skip')
                    self.k_history = k
                    self.v_history = v

                k = self.k_history
                v = self.v_history

                # KV completion only happens on self-attention.
                print('KV5', self.k_history.sum(), self.v_history.sum(),
                      self.k_history.shape, self.v_history.shape)

            if self.memory_efficient:
                # Only the memory-efficient 'torch' backend path is implemented here.
                p = self.dropout if self.training else 0
                if _efficient_attention_backend == 'torch':
                    x = torch.nn.functional.scaled_dot_product_attention(
                        q, k, v, is_causal=False, dropout_p=p)

            x = x.to(q.dtype)
            x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
            x = self.out_proj(x)
            return x
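
# Illustrative sketch, not part of the original module: how the per-layer k/v history of
# StreamingMultiheadAttention behaves during incremental decoding. The function name and
# dimensions are toy values; the real reset of k_history/v_history is done by the LM after
# a generation finishes. `memory_efficient=True` is required because only the
# scaled_dot_product_attention branch is implemented above.
def _example_streaming_self_attention():
    mha = StreamingMultiheadAttention(embed_dim=64, num_heads=4,
                                      memory_efficient=True, custom=True)
    mha.eval()
    y0 = mha(torch.randn(1, 1, 64))                # 1st token: history is initialised
    y1 = mha(torch.randn(1, 1, 64))                # 2nd token: history grows along the time dim
    assert mha.k_history.shape == (1, 4, 2, 16)    # [B, heads, tokens seen, head_dim]
    mha.k_history = None                           # reset before starting a new generation
    mha.v_history = None
    return y0.shape, y1.shape
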
class StreamingTransformerLayer(nn.Module):  # previously nn.TransformerEncoderLayer; now a plain nn.Module

    def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 bias_ff: bool = True, bias_attn: bool = True, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 cross_attention: bool = False,
                 attention_dropout: tp.Optional[float] = None,
                 kv_repeat: int = 1,
                 norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
        super().__init__()
        # No variables are inherited from nn.TransformerEncoderLayer any more
        # (previously only its _sa_block function was reused).
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Redefine self_attn to our streaming multi-head attention
        attn_kwargs: tp.Dict[str, tp.Any] = {
            'embed_dim': d_model,
            'num_heads': num_heads,
            'dropout': dropout if attention_dropout is None else attention_dropout,
            'bias': bias_attn,
            'custom': custom,
            'memory_efficient': memory_efficient,
            'attention_as_float32': attention_as_float32,
        }
        self.self_attn = StreamingMultiheadAttention(
            kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs)  # type: ignore
        # Redefine feedforward layers to expose bias parameter
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)

        self.cross_attention = None
        if cross_attention:
            self.cross_attention = StreamingMultiheadAttention(
                cross_attention=True, **attn_kwargs, **factory_kwargs)
            self.dropout_cross = nn.Dropout(dropout)
            self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)

    def forward(self,
                src,
                cross_attention_src=None):  # text conditioning
        # Note: the checkpoint stores float16 weights - should we cast src to float16?
        x = src
        x = x + self.self_attn(self.norm1(x))
        if cross_attention_src is not None:
            x = x + self.cross_attention(
                query=self.norm_cross(x),
                key=cross_attention_src,
                value=cross_attention_src)  # text conditioning
        x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
        return x
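
# Illustrative sketch, not part of the original module: one pre-norm layer step with text
# conditioning via cross-attention. The function name and sizes are toy values;
# `cross_attention_src` stands in for conditioning states produced elsewhere.
def _example_layer_step():
    layer = StreamingTransformerLayer(d_model=64, num_heads=4, dim_feedforward=128,
                                      memory_efficient=True, cross_attention=True)
    layer.eval()
    tok = torch.randn(1, 1, 64)                   # one new token embedding
    cond = torch.randn(1, 5, 64)                  # e.g. text-conditioning states
    out = layer(tok, cross_attention_src=cond)    # -> [1, 1, 64]
    return out.shape
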
class StreamingTransformer(nn.Module):

    def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
                 dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
                 custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
                 cross_attention: bool = False,
                 positional_embedding: str = 'sin', max_period: float = 10_000,
                 layer_class=StreamingTransformerLayer,
                 checkpointing: str = 'none',
                 device=None, dtype=None, **kwargs):
        super().__init__()
        assert d_model % num_heads == 0

        self.positional_embedding = positional_embedding
        self.max_period = max_period
        # self._stream_off = 0  # the LM should reinitialise this at every generate()
        self.checkpointing = checkpointing

        self.layers = nn.ModuleList()
        for idx in range(num_layers):
            self.layers.append(
                layer_class(
                    d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
                    dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
                    custom=custom,
                    memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
                    cross_attention=cross_attention,
                    device=device, dtype=dtype, **kwargs))

        if self.checkpointing != 'none':
            for layer in self.layers:
                # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
                # backward hook inside of FSDP...
                layer._magma_checkpointed = True  # type: ignore

    def forward(self, x: torch.Tensor, *args, **kwargs):
        B, T, C = x.shape

        if self.positional_embedding in ['sin', 'sin_rope']:
            positions = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = positions + kwargs['token_count']  # offsets.view(-1, 1, 1)
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + pos_emb

        for j, lay in enumerate(self.layers):
            print(f'5_________________________{j} {pos_emb.sum()=} {pos_emb.shape=} {x.shape=}___________________')
            # Text conditioning via cross-attention; each layer's MHA keeps a history of its
            # own k/v for all tokens seen so far.
            x = lay(x, cross_attention_src=kwargs["cross_attention_src"])
        return x
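
# Illustrative smoke test, not part of the original module: run the streaming transformer
# token by token with a dummy conditioning tensor. All sizes are toy values; `token_count`
# and `cross_attention_src` are the kwargs the forward above expects, and
# `memory_efficient=True` is required by the attention implementation.
if __name__ == '__main__':
    model = StreamingTransformer(d_model=64, num_heads=4, num_layers=2, dim_feedforward=128,
                                 cross_attention=True, memory_efficient=True, custom=True)
    model.eval()
    cond = torch.randn(1, 5, 64)              # e.g. text-conditioning states
    with torch.no_grad():
        for step in range(3):
            token = torch.randn(1, 1, 64)     # one new token embedding per step
            out = model(token, token_count=step, cross_attention_src=cond)
            print('step', step, 'out', out.shape)
    # Each StreamingMultiheadAttention keeps its own k_history/v_history, so they must be
    # reset (set back to None) before starting a new generation.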