import typing as tp
from einops import rearrange
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from xformers import ops


_efficient_attention_backend: str = 'torch'


def _get_attention_time_dimension(memory_efficient: bool) -> int:
    if _efficient_attention_backend == 'torch' and memory_efficient:
        return 2
    else:
        return 1
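# Note (illustrative): a return value of 2 (the 'torch' backend with memory-efficient
# attention) means attention tensors use the [B, H, T, D] layout; a return value of 1
# corresponds to [B, T, H, D].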



def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Create sinusoidal positional embedding, with shape `[B, T, C]`.

    Args:
        positions (torch.Tensor): LongTensor of positions.
        dim (int): Dimension of the embedding.
        max_period (float): Maximum period of the cosine/sine functions.
        dtype (torch.dtype or str): dtype to use to generate the embedding.
    Returns:
        torch.Tensor: Sinusoidal positional embedding.
    """
    # We aim for BTC format
    assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(dtype)
    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)  # avoid sync point
    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
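# Usage sketch (illustrative only):
#   positions = torch.arange(4).view(1, -1, 1)          # [B=1, T=4, 1]
#   pos_emb = create_sin_embedding(positions, dim=8)    # -> [1, 4, 8]
#   # pos_emb[..., :4] holds the cosine terms, pos_emb[..., 4:] the sine terms.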


def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
    if n_rep == 1:
        return x
    if _efficient_attention_backend == 'torch' and memory_efficient:
        bs, n_kv_heads, slen, head_dim = x.shape
        return (
            x[:, :, None, :, :]
            .expand(bs, n_kv_heads, n_rep, slen, head_dim)
            .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
        )
    else:
        bs, slen, n_kv_heads, head_dim = x.shape
        return (
            x[:, :, :, None, :]
            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
            .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
        )
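# Usage sketch (illustrative only): with grouped-query attention each of the
# n_kv_heads key/value heads is shared by n_rep query heads.
#   kv = torch.randn(2, 4, 10, 64)                                     # [B, n_kv_heads, T, D], 'torch' layout
#   expanded = expand_repeated_kv(kv, n_rep=3, memory_efficient=True)  # -> [2, 12, 10, 64]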





class StreamingMultiheadAttention(nn.Module):

    def __init__(self, 
                 embed_dim, 
                 num_heads, dropout: float = 0.0, bias: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 cross_attention: bool = False,
                 kv_repeat: int = 1,
                 device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        if past_context is not None:
            assert causal

        self.embed_dim = embed_dim

        # KV cache for self-attention only: keys/values of the tokens already seen
        # in the current generation. The LM must reset both after a generation is
        # finished; each attention layer keeps its own history.
        self.k_history = None
        self.v_history = None

        self.memory_efficient = memory_efficient
        self.cross_attention = cross_attention
        self.num_heads = num_heads
        self.dropout = dropout
        self.kv_repeat = kv_repeat

        # The custom (non nn.MultiheadAttention) path is always taken here.
        self.custom = True  # _is_custom(custom, memory_efficient)
        if self.custom:
            out_dim = embed_dim
            assert num_heads % kv_repeat == 0
            assert not cross_attention or kv_repeat == 1
            num_kv = num_heads // kv_repeat
            kv_dim = (embed_dim // num_heads) * num_kv
            out_dim += 2 * kv_dim
            in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
            # We try to follow the default PyTorch MHA convention, to easily compare results.
            self.in_proj_weight = in_proj.weight
            self.in_proj_bias = in_proj.bias
            if bias:
                self.in_proj_bias.data.zero_()  # Following PyTorch convention
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
            if bias:
                self.out_proj.bias.data.zero_()
        else:
            assert kv_repeat == 1
            self.mha = nn.MultiheadAttention(
                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
                **factory_kwargs)
        

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if not self.custom:
            # Support compat with regular MHA
            keys = [n for n, _ in self.mha.named_parameters()]
            for key in keys:
                if prefix + key in state_dict:
                    state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(self,
                query,
                key=None,     # key/value are only used for cross-attention;
                value=None):  # they are ignored for self-attention.
        # With the 'torch' backend and memory-efficient attention, tensors are
        # laid out as [B, H, T, D] (time on dimension 2).
        layout = "b h t d"

        if self.custom:

            if self.cross_attention:
                # Different queries, keys and values: we have to split the in_proj
                # weights manually before applying the linear layers.
                dim = self.in_proj_weight.shape[0] // 3
                if self.in_proj_bias is None:
                    bias_q, bias_k, bias_v = None, None, None
                else:
                    bias_q = self.in_proj_bias[:dim]
                    bias_k = self.in_proj_bias[dim: 2 * dim]
                    bias_v = self.in_proj_bias[2 * dim:]
                q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
                # todo: when streaming, we could actually save k, v and check that the shapes match.
                k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
                v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)

                q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
            else:
                # Self-attention: project the new tokens to q/k/v with a single
                # linear, then extend this layer's KV history with the fresh k/v.
                projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
                if self.kv_repeat == 1:
                    bound_layout = "b h p t d"
                    packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
                    q, k, v = ops.unbind(packed, dim=2)
                else:
                    raise NotImplementedError("kv_repeat > 1 is not supported in this streaming path.")

                if self.k_history is not None:
                    # Append the new keys/values along the time dimension, e.g.
                    # k_history [B, H, T_past, D] + k [B, H, 1, D] -> [B, H, T_past + 1, D].
                    self.k_history = torch.cat([self.k_history, k], 2)
                    self.v_history = torch.cat([self.v_history, v], 2)
                else:
                    # First token of the generation: initialize this layer's cache.
                    self.k_history = k
                    self.v_history = v

                # KV completion only happens for self-attention.
                k = self.k_history
                v = self.v_history
                
            
            if self.memory_efficient:
                p = self.dropout if self.training else 0
                if _efficient_attention_backend == 'torch':
                    # No causal mask is applied here: in self-attention the query
                    # attends to the cached history (past plus current tokens), and
                    # cross-attention attends to the full conditioning sequence.
                    x = torch.nn.functional.scaled_dot_product_attention(
                        q, k, v, is_causal=False, dropout_p=p
                    )
            else:
                raise NotImplementedError("Only the memory-efficient 'torch' backend is handled here.")

            x = x.to(q.dtype)
            x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
            x = self.out_proj(x)
        return x
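# Usage sketch (illustrative only; sizes are arbitrary):
#   mha = StreamingMultiheadAttention(embed_dim=64, num_heads=4, memory_efficient=True)
#   x = torch.randn(1, 1, 64)              # one new token, [B, T, C]
#   out = mha(x)                           # self-attention; k/v are appended to the layer's history
#   mha.k_history = mha.v_history = None   # reset the cache before a new generation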


class StreamingTransformerLayer(nn.Module):
    # Pre-norm transformer block built around StreamingMultiheadAttention.
    # It no longer inherits from nn.TransformerEncoderLayer.

    def __init__(self, 
                 d_model: int, 
                 num_heads: int, 
                 dim_feedforward: int = 2048, 
                 dropout: float = 0.1,
                 bias_ff: bool = True, 
                 bias_attn: bool = True, 
                 custom: bool = False,
                 memory_efficient: bool = False, 
                 attention_as_float32: bool = False,
                 cross_attention: bool = False, 
                 attention_dropout: tp.Optional[float] = None,
                 kv_repeat: int = 1,
                 norm: str = 'layer_norm', 
                 device=None,
                 dtype=None, 
                 **kwargs):
        super().__init__()
        # Nothing is inherited from nn.TransformerEncoderLayer anymore;
        # every submodule is defined below.
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Redefine self_attn to our streaming multi-head attention
        attn_kwargs: tp.Dict[str, tp.Any] = {
            'embed_dim': d_model,
            'num_heads': num_heads,
            'dropout': dropout if attention_dropout is None else attention_dropout,
            'bias': bias_attn,
            'custom': custom,
            'memory_efficient': memory_efficient,
            'attention_as_float32': attention_as_float32,
        }
        self.self_attn = StreamingMultiheadAttention(
            kv_repeat=kv_repeat, 
            **attn_kwargs, 
            **factory_kwargs)  # type: ignore
        # Redefine feedforward layers to expose bias parameter
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
        self.cross_attention = None
        if cross_attention:
            self.cross_attention = StreamingMultiheadAttention(
                cross_attention=True,
                **attn_kwargs,
                **factory_kwargs)
            self.dropout_cross = nn.Dropout(dropout)
            self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)


    def forward(self,
                src,
                cross_attention_src=None):  # text conditioning
        """Pre-norm block: self-attention, optional cross-attention, then the MLP.

        Note: the checkpoint weights are saved in float16; whether `src` should be
        cast to float16 here is still an open question.
        """
        x = src
        x = x + self.self_attn(self.norm1(x))
        if cross_attention_src is not None:
            x = x + self.cross_attention(
                query=self.norm_cross(x),
                key=cross_attention_src,
                value=cross_attention_src)  # text conditioning
        x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
        return x
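# Usage sketch (illustrative only; sizes are arbitrary):
#   layer = StreamingTransformerLayer(d_model=64, num_heads=4, dim_feedforward=256,
#                                     memory_efficient=True, cross_attention=True)
#   x = torch.randn(1, 1, 64)         # current token(s), [B, T, C]
#   cond = torch.randn(1, 5, 64)      # e.g. text conditioning, [B, S, C]
#   y = layer(x, cross_attention_src=cond)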


class StreamingTransformer(nn.Module):

    def __init__(self, d_model: int, 
                 num_heads: int, 
                 num_layers: int, 
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1, 
                 bias_ff: bool = True, 
                 bias_attn: bool = True,
                 custom: bool = False, 
                 memory_efficient: bool = False, 
                 attention_as_float32: bool = False,
                 cross_attention: bool = False,
                 positional_embedding: str = 'sin', 
                 max_period: float = 10_000,
                 layer_class=StreamingTransformerLayer,
                 checkpointing: str = 'none', 
                 device=None, 
                 dtype=None, 
                 **kwargs):
        super().__init__()
        assert d_model % num_heads == 0

        self.positional_embedding = positional_embedding
        self.max_period = max_period
        self.checkpointing = checkpointing

        self.layers = nn.ModuleList()
        for idx in range(num_layers):
            self.layers.append(
                layer_class(
                    d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
                    dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
                    custom=custom,
                    memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
                    cross_attention=cross_attention,
                    device=device, dtype=dtype, **kwargs))

        if self.checkpointing != 'none':
            for layer in self.layers:
                # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
                # backward hook inside of FSDP...
                layer._magma_checkpointed = True  # type: ignore

    

    def forward(self, x: torch.Tensor, *args, **kwargs):
        B, T, C = x.shape

        if self.positional_embedding in ['sin', 'sin_rope']:
            # Offset the positions by the number of tokens generated so far
            # (kwargs['token_count']) so streaming steps keep consistent phases.
            positions = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = positions + kwargs['token_count']
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + pos_emb

        for layer in self.layers:
            # Each layer's self-attention keeps its own k/v history across tokens.
            x = layer(x, cross_attention_src=kwargs["cross_attention_src"])  # text conditioning
        return x
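# Minimal smoke test (illustrative only): the sizes below are arbitrary, and running
# this assumes xformers plus a PyTorch build that provides scaled_dot_product_attention.
if __name__ == "__main__":
    model = StreamingTransformer(
        d_model=64, num_heads=4, num_layers=2, dim_feedforward=256,
        memory_efficient=True, cross_attention=True)
    text_cond = torch.randn(1, 5, 64)       # [B, S, C] conditioning
    token = torch.randn(1, 1, 64)           # one token per step, [B, T, C]
    for step in range(3):
        out = model(token, token_count=step, cross_attention_src=text_cond)
    print(out.shape)                        # expected: torch.Size([1, 1, 64])
    # Reset the per-layer KV caches before starting a new generation.
    for layer in model.layers:
        layer.self_attn.k_history = None
        layer.self_attn.v_history = None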