import math
import torch
import triton
import triton.language as tl
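
# Triton forward/backward kernels for EVA-style chunk-level KV preparation:
# each CHUNK_SIZE-long run of keys/values is compressed into one "RFA" key and
# value via per-chunk softmax-weighted averaging. Host-side wrappers and the
# autograd.Function binding live at the bottom of the file.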

@triton.heuristics(
    {
        "EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_eva_prep_kv_kernel(
    K, # [b, h, n, d]
    V, # [b, h, n, d]
    PARAM_MU, # [1, h, 1, 1, d]
    PARAM_PHI,  # [1, h, 1, 1, d]
    Mask, # [b, 1, n, 1]; True marks masked-out positions
    Out_RFA_K, # [b, h, c, d]
    Out_RFA_V, # [b, h, c, d]
    softmax_scale,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_mu_h,
    stride_phi_h,
    stride_mb, stride_mn,
    stride_ok_b, stride_ok_h, stride_ok_c,
    stride_ov_b, stride_ov_h, stride_ov_c,
    nheads,
    seqlen,
    nchunks,
    headdim,
    CHUNKS_PER_BLOCK: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_n = tl.program_id(0)
    offs_bh = tl.program_id(1)
    offs_h = offs_bh % nheads
    offs_b = offs_bh // nheads
    # initialize offsets
    # we load BLOCK_N keys and values at a time and reshape
    # the sequence dimension to [CHUNKS_PER_BLOCK, CHUNK_SIZE]
    offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
    offs_m = tl.arange(0, CHUNK_SIZE)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    k_ptrs = (
        K +
        offs_b * stride_kb +
        offs_h * stride_kh +
        (
            (
                start_n * BLOCK_N + 
                offs_c[:, None, None] * CHUNK_SIZE + 
                offs_m[None, :, None]
            ) * stride_kn + 
            offs_d[None, None, :]
        )
    )
    v_ptrs = (
        V +
        offs_b * stride_vb +
        offs_h * stride_vh +
        (
            (
                start_n * BLOCK_N +
                offs_c[:, None, None] * CHUNK_SIZE + 
                offs_m[None, :, None]
            ) * stride_vn + 
            offs_d[None, None, :]
        )
    )
    param_mu_ptrs = (
        PARAM_MU +
        offs_h * stride_mu_h +
        offs_d[None, None, :]
    )
    param_phi_ptrs = (
        PARAM_PHI +
        offs_h * stride_phi_h +
        offs_d[None, None, :]
    )
    log2e = 1.4426950408889634  # log2(e); scores are pre-scaled so tl.exp2 computes a natural exp
    if MASK_TYPE == 1:
        m_ptrs = (
            Mask +
            offs_b * stride_mb +
            (
                (
                    start_n * BLOCK_N +
                    offs_c[:, None] * CHUNK_SIZE + 
                    offs_m[None, :]
                ) * stride_mn
            )
        )
    if EVEN_N:
        if EVEN_HEADDIM:
            k = tl.load(
                k_ptrs
            )
        else:
            k = tl.load(
                k_ptrs,
                mask=offs_d[None, None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            k = tl.load(
                k_ptrs,
                mask=(
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE + 
                        offs_m[None, :, None]
                    ) < seqlen,
                other=0.0
            )
        else:
            k = tl.load(
                k_ptrs,
                mask=(
                        (
                            start_n * BLOCK_N +
                            offs_c[:, None, None] * CHUNK_SIZE + 
                            offs_m[None, :, None]
                        ) < seqlen
                    ) & (offs_d[None, None, :] < headdim),
                other=0.0
            )
    
    # RFA-K weights: per-chunk softmax over <k, param_mu>
    param_mu = tl.load(param_mu_ptrs).to(k.dtype)
    rfa_k_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    rfa_k_c_w += tl.sum(k * param_mu, axis=-1)
    rfa_k_c_w *= log2e
    if MASK_TYPE == 1:
        if EVEN_N:
            mask = tl.load(
                m_ptrs
            )
        else:
            mask = tl.load(
                m_ptrs,
                mask=(
                        start_n * BLOCK_N +
                        offs_c[:, None] * CHUNK_SIZE + 
                        offs_m[None, :]
                    ) < seqlen,
                other=1,
            )
        rfa_k_c_w = tl.where(mask, float("-inf"), rfa_k_c_w)
        
    # Guard fully masked rows: clamp the -inf row max to 0 and a zero
    # denominator to 1 so exp2/division stay finite (such rows produce
    # all-zero weights and hence a zero RFA output).
    m_rfa_k_c_w = tl.max(rfa_k_c_w, axis=-1)
    masked_out_rows_rfa_k = (m_rfa_k_c_w == float("-inf"))
    m_rfa_k_c_w_masked = tl.where(masked_out_rows_rfa_k, 0, m_rfa_k_c_w)
    rfa_k_c_w = tl.exp2(rfa_k_c_w - m_rfa_k_c_w_masked[:, None])
    denom_k = tl.sum(rfa_k_c_w, axis=-1)
    denom_k = tl.where(denom_k == 0.0, 1.0, denom_k)
    rfa_k_c_w = rfa_k_c_w / denom_k[:, None]
    rfa_k_c = tl.sum(k * rfa_k_c_w[:, :, None].to(k.dtype), axis=-2)
    # rematerialize output offsets to save registers
    offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
    out_rfa_k_ptrs = (
        Out_RFA_K +
        offs_b * stride_ok_b +
        offs_h * stride_ok_h +
        (offs_out_c[:, None] * stride_ok_c + offs_d[None, :])
    )

    if EVEN_N:
        if EVEN_HEADDIM:
            tl.store(
                out_rfa_k_ptrs, rfa_k_c
            )
        else:
            tl.store(
                out_rfa_k_ptrs, rfa_k_c,
                mask=offs_d[None, :] < headdim
            )
    else:
        if EVEN_HEADDIM:
            tl.store(
                out_rfa_k_ptrs, rfa_k_c,
                mask=offs_out_c[:, None] < nchunks
            )
        else:
            tl.store(
                out_rfa_k_ptrs, rfa_k_c,
                mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim)
            )


    # RFA-V weights: per-chunk softmax over softmax_scale * (<k, param_phi> - ||k||^2 / 2)
    param_phi = tl.load(param_phi_ptrs).to(k.dtype)
    rfa_v_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    rfa_v_c_w += tl.sum(k * param_phi, axis=-1)
    rfa_v_c_w -= (0.5 * tl.sum(k * k, axis=-1))
    rfa_v_c_w *= log2e * softmax_scale
    if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
        rfa_v_c_w += tl.where(
            (
                start_n * BLOCK_N +
                offs_c[:, None] * CHUNK_SIZE + 
                offs_m[None, :]
            ) < seqlen, 
            0, 
            float("-inf")
        )

    if MASK_TYPE == 1:
        rfa_v_c_w = tl.where(mask, float("-inf"), rfa_v_c_w)

    if EVEN_N:
        if EVEN_HEADDIM:
            v = tl.load(
                v_ptrs
            )
        else:
            v = tl.load(
                v_ptrs,
                mask=offs_d[None, None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            v = tl.load(
                v_ptrs,
                mask=(
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE + 
                        offs_m[None, :, None]
                    ) < seqlen,
                other=0.0
            )
        else:
            v = tl.load(
                v_ptrs,
                mask=(
                        (
                            start_n * BLOCK_N +
                            offs_c[:, None, None] * CHUNK_SIZE + 
                            offs_m[None, :, None]
                        ) < seqlen
                    ) & (offs_d[None, None, :] < headdim),
                other=0.0
            )
    

    # same fully-masked-row guard as for the RFA-K weights above
    m_rfa_v_c_w = tl.max(rfa_v_c_w, axis=-1)
    masked_out_rows_rfa_v = (m_rfa_v_c_w == float("-inf"))
    m_rfa_v_c_w_masked = tl.where(masked_out_rows_rfa_v, 0, m_rfa_v_c_w)
    rfa_v_c_w = tl.exp2(rfa_v_c_w - m_rfa_v_c_w_masked[:, None])
    denom_v = tl.sum(rfa_v_c_w, axis=-1)
    denom_v = tl.where(denom_v == 0.0, 1.0, denom_v)
    rfa_v_c_w = rfa_v_c_w / denom_v[:, None]
    rfa_v_c = tl.sum(v * rfa_v_c_w[:, :, None].to(v.dtype), axis=-2)

    offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
    out_rfa_v_ptrs = (
        Out_RFA_V +
        offs_b * stride_ov_b +
        offs_h * stride_ov_h +
        (offs_out_c[:, None] * stride_ov_c + offs_d[None, :])
    )
    if EVEN_N:
        if EVEN_HEADDIM:
            tl.store(
                out_rfa_v_ptrs, rfa_v_c
            )
        else:
            tl.store(
                out_rfa_v_ptrs, rfa_v_c,
                mask=offs_d[None, :] < headdim
            )
    else:
        if EVEN_HEADDIM:
            tl.store(
                out_rfa_v_ptrs, rfa_v_c,
                mask=offs_out_c[:, None] < nchunks
            )
        else:
            tl.store(
                out_rfa_v_ptrs, rfa_v_c,
                mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim)
            )

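# A pure-PyTorch sketch of what _fwd_eva_prep_kv_kernel computes, kept as an
# executable reference for testing. It assumes no boolean mask and a seqlen
# that is a multiple of chunksize; the `_reference` helper is ours, not part
# of the kernel API.
def _eva_prep_kv_fwd_reference(k, v, param_mu, param_phi, softmax_scale, chunksize):
    b, h, n, d = k.shape
    nchunks = n // chunksize
    k_c = k.reshape(b, h, nchunks, chunksize, d)
    v_c = v.reshape(b, h, nchunks, chunksize, d)
    # RFA-K: per-chunk softmax over <k, mu>, then a weighted sum of the keys
    w_k = torch.softmax((k_c * param_mu).sum(-1), dim=-1)
    rfa_k = (w_k.unsqueeze(-1) * k_c).sum(-2)
    # RFA-V: per-chunk softmax over scale * (<k, phi> - ||k||^2 / 2), then a
    # weighted sum of the values
    s_v = ((k_c * param_phi).sum(-1) - 0.5 * (k_c * k_c).sum(-1)) * softmax_scale
    rfa_v = (torch.softmax(s_v, dim=-1).unsqueeze(-1) * v_c).sum(-2)
    return rfa_k, rfa_v
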


@triton.heuristics(
    {
        "EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_eva_prep_kv_kernel(
    RFA_K, # [b, h, c, d]
    RFA_V, # [b, h, c, d]
    K, # [b, h, n, d]
    V, # [b, h, n, d]
    PARAM_MU, # [1, h, 1, 1, d]
    PARAM_PHI,  # [1, h, 1, 1, d]
    Mask, # [b, 1, n, 1]; True marks masked-out positions
    D_RFA_K, # [b, h, c, d]
    D_RFA_V, # [b, h, c, d]
    D_K, # [b, h, n, d]
    D_V, # [b, h, n, d]
    D_PARAM_MU_PARTIAL, # [b, h, g, d]
    D_PARAM_PHI_PARTIAL, # [b, h, g, d]
    softmax_scale,
    stride_rfa_k_b, stride_rfa_k_h, stride_rfa_k_c,
    stride_rfa_v_b, stride_rfa_v_h, stride_rfa_v_c,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_mu_h,
    stride_phi_h,
    stride_mb, stride_mn,
    stride_d_rfa_k_b, stride_d_rfa_k_h, stride_d_rfa_k_c,
    stride_d_rfa_v_b, stride_d_rfa_v_h, stride_d_rfa_v_c,
    stride_d_k_b, stride_d_k_h, stride_d_k_n,
    stride_d_v_b, stride_d_v_h, stride_d_v_n,
    stride_d_mu_b, stride_d_mu_h, stride_d_mu_g,
    stride_d_phi_b, stride_d_phi_h, stride_d_phi_g,
    nheads,
    seqlen,
    nchunks,
    headdim,
    CHUNKS_PER_BLOCK: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_n = tl.program_id(0)
    offs_bh = tl.program_id(1)
    offs_h = offs_bh % nheads
    offs_b = offs_bh // nheads
    # initialize offsets
    # we load BLOCK_N keys and values at a time and reshape
    # the sequence dimension to [CHUNKS_PER_BLOCK, CHUNK_SIZE]
    offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
    offs_m = tl.arange(0, CHUNK_SIZE)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    offs_rfa_c = start_n * CHUNKS_PER_BLOCK + offs_c

    k_ptrs = (
        K +
        offs_b * stride_kb +
        offs_h * stride_kh +
        (
            (
                start_n * BLOCK_N + 
                offs_c[:, None, None] * CHUNK_SIZE + 
                offs_m[None, :, None]
            ) * stride_kn + 
            offs_d[None, None, :]
        )
    )
    rfa_k_ptrs = (
        RFA_K +
        offs_b * stride_rfa_k_b +
        offs_h * stride_rfa_k_h +
        (offs_rfa_c[:, None] * stride_rfa_k_c + offs_d[None, :])
    )
    rfa_v_ptrs = (
        RFA_V +
        offs_b * stride_rfa_v_b +
        offs_h * stride_rfa_v_h +
        (offs_rfa_c[:, None] * stride_rfa_v_c + offs_d[None, :])
    )

    d_rfa_k_ptrs = (
        D_RFA_K +
        offs_b * stride_d_rfa_k_b +
        offs_h * stride_d_rfa_k_h +
        (offs_rfa_c[:, None] * stride_d_rfa_k_c + offs_d[None, :])
    )
    d_rfa_v_ptrs = (
        D_RFA_V +
        offs_b * stride_d_rfa_v_b +
        offs_h * stride_d_rfa_v_h +
        (offs_rfa_c[:, None] * stride_d_rfa_v_c + offs_d[None, :])
    )

    param_mu_ptrs = (
        PARAM_MU +
        offs_h * stride_mu_h +
        offs_d[None, None, :]
    )
    param_phi_ptrs = (
        PARAM_PHI +
        offs_h * stride_phi_h +
        offs_d[None, None, :]
    )
    
    log2e = 1.4426950408889634
    if MASK_TYPE == 1:
        m_ptrs = (
            Mask +
            offs_b * stride_mb +
            (
                (
                    start_n * BLOCK_N +
                    offs_c[:, None] * CHUNK_SIZE + 
                    offs_m[None, :]
                ) * stride_mn
            )
        )
    if EVEN_N:
        if EVEN_HEADDIM:
            k = tl.load(
                k_ptrs
            )
        else:
            k = tl.load(
                k_ptrs,
                mask=offs_d[None, None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            k = tl.load(
                k_ptrs,
                mask=(
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE + 
                        offs_m[None, :, None]
                    ) < seqlen,
                other=0.0
            )
        else:
            k = tl.load(
                k_ptrs,
                mask=(
                        (
                            start_n * BLOCK_N +
                            offs_c[:, None, None] * CHUNK_SIZE + 
                            offs_m[None, :, None]
                        ) < seqlen
                    ) & (offs_d[None, None, :] < headdim),
                other=0.0
            )

    if EVEN_N:
        if EVEN_HEADDIM:
            rfa_k = tl.load(
                rfa_k_ptrs
            )
        else:
            rfa_k = tl.load(
                rfa_k_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            rfa_k = tl.load(
                rfa_k_ptrs,
                mask=offs_rfa_c[:, None] < nchunks,
                other=0.0
            )
        else:
            rfa_k = tl.load(
                rfa_k_ptrs,
                mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
                other=0.0
            )
    
    if EVEN_N:
        if EVEN_HEADDIM:
            d_rfa_k = tl.load(
                d_rfa_k_ptrs
            )
        else:
            d_rfa_k = tl.load(
                d_rfa_k_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            d_rfa_k = tl.load(
                d_rfa_k_ptrs,
                mask=offs_rfa_c[:, None] < nchunks,
                other=0.0
            )
        else:
            d_rfa_k = tl.load(
                d_rfa_k_ptrs,
                mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
                other=0.0
            )
    
    # recompute the RFA-K (mu) weights exactly as in the forward kernel
    param_mu = tl.load(param_mu_ptrs).to(k.dtype)
    mu_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    mu_c_w += tl.sum(k * param_mu, axis=-1)
    mu_c_w *= log2e

    if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
        mu_c_w += tl.where(
            (
                start_n * BLOCK_N +
                offs_c[:, None] * CHUNK_SIZE + 
                offs_m[None, :]
            ) < seqlen, 
            0, 
            float("-inf")
        )

    if MASK_TYPE == 1:
        if EVEN_N:
            mask = tl.load(
                m_ptrs
            )
        else:
            mask = tl.load(
                m_ptrs,
                mask=(
                        start_n * BLOCK_N +
                        offs_c[:, None] * CHUNK_SIZE + 
                        offs_m[None, :]
                    ) < seqlen,
                other=1,
            )
        mu_c_w = tl.where(mask, float("-inf"), mu_c_w)

    # [c, w]
    m_mu_c_w = tl.max(mu_c_w, axis=-1)
    masked_out_rows_mu = (m_mu_c_w == float("-inf"))
    m_mu_c_w_masked = tl.where(masked_out_rows_mu, 0, m_mu_c_w)
    mu_c_w = tl.exp2(mu_c_w - m_mu_c_w_masked[:, None])
    denom_mu = tl.sum(mu_c_w, axis=-1)
    denom_mu = tl.where(denom_mu == 0.0, 1.0, denom_mu)
    mu_tilde_c_w = mu_c_w / denom_mu[:, None]
    mu_tilde_c_w = mu_tilde_c_w.to(k.dtype)
    # softmax backward: ds = (dy - <dy, y>) * y, applied per chunk row
    # [c, d] x [c, w, d] -> [c, w]
    d_mu_tilde_c_w = tl.sum(d_rfa_k[:, None, :] * k, axis=-1)
    # [c, d] x [c, d] -> [c, 1]
    d_out_rfa_k_t_rfa_k = tl.sum(d_rfa_k * rfa_k, axis=-1)[:, None]
    d_mu_c_w = (d_mu_tilde_c_w - d_out_rfa_k_t_rfa_k) * mu_tilde_c_w

    # [c, w] [c, w, d] -> [d]
    d_param_mu = tl.sum(tl.sum(d_mu_c_w[:, :, None] * k, axis=0), axis=0)
    # [c, w] [c, d] + [c, w] [1, 1, d] -> [c, w, d]
    d_k = mu_tilde_c_w[:, :, None] * d_rfa_k[:, None, :] + d_mu_c_w[:, :, None] * param_mu

    # each (batch*head, block-group) program writes its own d_param_mu partial;
    # the host wrapper reduces the partials after the launch
    d_param_mu_partial_ptrs = (
        D_PARAM_MU_PARTIAL +
        offs_b * stride_d_mu_b +
        offs_h * stride_d_mu_h +
        start_n * stride_d_mu_g +
        offs_d
    )
    if EVEN_HEADDIM:
        tl.store(
            d_param_mu_partial_ptrs, d_param_mu
        )
    else:
        tl.store(
            d_param_mu_partial_ptrs, d_param_mu,
            mask=offs_d < headdim
        )


    v_ptrs = (
        V +
        offs_b * stride_vb +
        offs_h * stride_vh +
        (
            (
                start_n * BLOCK_N +
                offs_c[:, None, None] * CHUNK_SIZE + 
                offs_m[None, :, None]
            ) * stride_vn + 
            offs_d[None, None, :]
        )
    )
    if EVEN_N:
        if EVEN_HEADDIM:
            v = tl.load(
                v_ptrs
            )
        else:
            v = tl.load(
                v_ptrs,
                mask=offs_d[None, None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            v = tl.load(
                v_ptrs,
                mask=(
                        start_n * BLOCK_N +
                        offs_c[:, None, None] * CHUNK_SIZE + 
                        offs_m[None, :, None]
                    ) < seqlen,
                other=0.0
            )
        else:
            v = tl.load(
                v_ptrs,
                mask=(
                        (
                            start_n * BLOCK_N +
                            offs_c[:, None, None] * CHUNK_SIZE + 
                            offs_m[None, :, None]
                        ) < seqlen
                    ) & (offs_d[None, None, :] < headdim),
                other=0.0
            )


    if EVEN_N:
        if EVEN_HEADDIM:
            rfa_v = tl.load(
                rfa_v_ptrs
            )
        else:
            rfa_v = tl.load(
                rfa_v_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            rfa_v = tl.load(
                rfa_v_ptrs,
                mask=offs_rfa_c[:, None] < nchunks,
                other=0.0
            )
        else:
            rfa_v = tl.load(
                rfa_v_ptrs,
                mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
                other=0.0
            )
    
    if EVEN_N:
        if EVEN_HEADDIM:
            d_rfa_v = tl.load(
                d_rfa_v_ptrs
            )
        else:
            d_rfa_v = tl.load(
                d_rfa_v_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            d_rfa_v = tl.load(
                d_rfa_v_ptrs,
                mask=offs_rfa_c[:, None] < nchunks,
                other=0.0
            )
        else:
            d_rfa_v = tl.load(
                d_rfa_v_ptrs,
                mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
                other=0.0
            )
    
    # recompute the RFA-V (phi) weights exactly as in the forward kernel
    param_phi = tl.load(param_phi_ptrs).to(k.dtype)
    phi_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    phi_c_w += tl.sum(k * param_phi, axis=-1)
    phi_c_w -= (0.5 * tl.sum(k * k, axis=-1))
    phi_c_w *= log2e * softmax_scale
    if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
        phi_c_w += tl.where(
            (
                start_n * BLOCK_N +
                offs_c[:, None] * CHUNK_SIZE + 
                offs_m[None, :]
            ) < seqlen, 
            0, 
            float("-inf")
        )

    if MASK_TYPE == 1:
        phi_c_w = tl.where(mask, float("-inf"), phi_c_w)


    m_phi_c_w = tl.max(phi_c_w, axis=-1)
    masked_out_rows_phi = (m_phi_c_w == float("-inf"))
    m_phi_c_w_masked = tl.where(masked_out_rows_phi, 0, m_phi_c_w)
    phi_c_w = tl.exp2(phi_c_w - m_phi_c_w_masked[:, None])
    denom_phi = tl.sum(phi_c_w, axis=-1)
    denom_phi = tl.where(denom_phi == 0.0, 1.0, denom_phi)
    phi_tilde_c_w = phi_c_w / denom_phi[:, None]
    phi_tilde_c_w = phi_tilde_c_w.to(k.dtype)
    d_phi_tilde_c_w = tl.sum(d_rfa_v[:, None, :] * v, axis=-1)
    d_out_rfa_v_t_rfa_v = tl.sum(d_rfa_v * rfa_v, axis=-1)[:, None]
    d_phi_c_w = (d_phi_tilde_c_w.to(tl.float32) - d_out_rfa_v_t_rfa_v.to(tl.float32)) * phi_tilde_c_w

    d_param_phi = tl.sum(tl.sum(d_phi_c_w[:, :, None] * k * softmax_scale, axis=0), axis=0)
    d_v = phi_tilde_c_w[:, :, None] * d_rfa_v[:, None, :]
    # d(score)/d(k) = softmax_scale * (param_phi - k); broadcast [c, w, 1] * [c, w, d]
    d_k = d_k + softmax_scale * d_phi_c_w[:, :, None] * (param_phi - k)

    d_k_ptrs = (
        D_K +
        offs_b * stride_d_k_b +
        offs_h * stride_d_k_h +
        (
            (
                start_n * BLOCK_N +
                offs_c[:, None, None] * CHUNK_SIZE + 
                offs_m[None, :, None]
            ) * stride_d_k_n + 
            offs_d[None, None, :]
        )
    )
    d_v_ptrs = (
        D_V +
        offs_b * stride_d_v_b +
        offs_h * stride_d_v_h +
        (
            (
                start_n * BLOCK_N +
                offs_c[:, None, None] * CHUNK_SIZE + 
                offs_m[None, :, None]
            ) * stride_d_v_n + 
            offs_d[None, None, :]
        )
    )
    if EVEN_N:
        if EVEN_HEADDIM:
            tl.store(
                d_k_ptrs, d_k
            )
            tl.store(
                d_v_ptrs, d_v
            )
        else:
            tl.store(
                d_k_ptrs, d_k,
                mask=offs_d[None, None, :] < headdim
            )
            tl.store(
                d_v_ptrs, d_v,
                mask=offs_d[None, None, :] < headdim
            )
    else:
        if EVEN_HEADDIM:
            tl.store(
                d_k_ptrs, d_k,
                mask=(
                        (
                            start_n * BLOCK_N +
                            offs_c[:, None, None] * CHUNK_SIZE + 
                            offs_m[None, :, None]
                        ) < seqlen
                    ),
            )
            tl.store(
                d_v_ptrs, d_v,
                mask=(
                        (
                            start_n * BLOCK_N +
                            offs_c[:, None, None] * CHUNK_SIZE + 
                            offs_m[None, :, None]
                        ) < seqlen
                    ),
            )
        else:
            tl.store(
                d_k_ptrs, d_k,
                mask=(
                        (
                            start_n * BLOCK_N +
                            offs_c[:, None, None] * CHUNK_SIZE + 
                            offs_m[None, :, None]
                        ) < seqlen
                    ) & (offs_d[None, None, :] < headdim),
            )
            tl.store(
                d_v_ptrs, d_v,
                mask=(
                        (
                            start_n * BLOCK_N +
                            offs_c[:, None, None] * CHUNK_SIZE + 
                            offs_m[None, :, None]
                        ) < seqlen
                    ) & (offs_d[None, None, :] < headdim),
            )
    d_param_phi_partial_ptrs = (
        D_PARAM_PHI_PARTIAL +
        offs_b * stride_d_phi_b +
        offs_h * stride_d_phi_h +
        start_n * stride_d_phi_g +
        offs_d
    )
    if EVEN_HEADDIM:
        tl.store(
            d_param_phi_partial_ptrs, d_param_phi
        )
    else:
        tl.store(
            d_param_phi_partial_ptrs, d_param_phi,
            mask=offs_d < headdim
        )

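# The two softmax backward passes above use the standard Jacobian identity:
# for y = softmax(s) and upstream gradient dy, ds = (dy - <dy, y>) * y, where
# <dy, y> shows up as d_rfa_k . rfa_k (resp. d_rfa_v . rfa_v). A minimal
# reference of the identity (the helper name is ours), checkable against
# torch.autograd:
def _softmax_grad_reference(s, dy):
    y = torch.softmax(s, dim=-1)
    return (dy - (dy * y).sum(-1, keepdim=True)) * y
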
def triton_eva_prep_kv_fwd(k, v, param_mu, param_phi, mask, softmax_scale, chunksize):
    k, v, param_mu, param_phi = [
        x if x.stride(-1) == 1 else x.contiguous() 
        for x in [k, v, param_mu, param_phi]
    ]

    # shape constraints
    batch, nheads, seqlen, head_dim = k.shape
    assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
    nchunks = seqlen // chunksize
    assert k.shape == (batch, nheads, seqlen, head_dim)
    assert v.shape == (batch, nheads, seqlen, head_dim)
    assert param_mu.shape == (1, nheads, 1, 1, head_dim)
    assert param_phi.shape == (1, nheads, 1, 1, head_dim)
    assert head_dim <= 128, "We only test head dimensions up to 128"
    assert k.dtype == v.dtype == param_mu.dtype == param_phi.dtype, "All tensors must have the same dtype"
    assert k.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
    assert k.is_cuda and v.is_cuda
    softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)

    mask_type = 0
    if mask is not None:
        mask_type = 1
        assert mask.dtype == torch.bool
        assert mask.is_cuda
        assert mask.dim() == 4
        assert mask.shape == (batch, 1, seqlen, 1)
        if mask.stride(-1) != 1:
            mask = mask.contiguous()
    mask_strides = (
        (mask.stride(0), mask.stride(2)) 
        if mask_type == 1 else 
        (0, 0)
    )
    out_rfa_k = torch.empty((batch, nheads, nchunks, head_dim), dtype=k.dtype, device=k.device)
    out_rfa_v = torch.empty((batch, nheads, nchunks, head_dim), dtype=v.dtype, device=v.device)

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
    BLOCK = 128
    num_warps = 4 if head_dim <= 64 else 8
    
    assert BLOCK >= chunksize and BLOCK % chunksize == 0, "BLOCK must be divisible by chunksize"
    chunks_per_block = BLOCK // chunksize

    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_N"]), batch * nheads)
    _fwd_eva_prep_kv_kernel[grid](
        k,
        v,
        param_mu,
        param_phi,
        mask,
        out_rfa_k,
        out_rfa_v,
        softmax_scale,
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        param_mu.stride(1),
        param_phi.stride(1),
        mask_strides[0], mask_strides[1],
        out_rfa_k.stride(0), out_rfa_k.stride(1), out_rfa_k.stride(2),
        out_rfa_v.stride(0), out_rfa_v.stride(1), out_rfa_v.stride(2),
        nheads,
        seqlen,
        nchunks,
        head_dim,
        chunks_per_block,
        chunksize,
        mask_type,
        BLOCK_HEADDIM,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return out_rfa_k, out_rfa_v
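
# A hedged correctness-check sketch against _eva_prep_kv_fwd_reference above;
# the helper name and tolerances are ours, untested:
def _check_fwd_against_reference(k, v, param_mu, param_phi, chunksize, atol=2e-2, rtol=2e-2):
    scale = k.shape[-1] ** -0.5
    rfa_k, rfa_v = triton_eva_prep_kv_fwd(k, v, param_mu, param_phi, None, scale, chunksize)
    ref_k, ref_v = _eva_prep_kv_fwd_reference(
        k.float(), v.float(), param_mu.float(), param_phi.float(), scale, chunksize
    )
    torch.testing.assert_close(rfa_k.float(), ref_k, atol=atol, rtol=rtol)
    torch.testing.assert_close(rfa_v.float(), ref_v, atol=atol, rtol=rtol)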

def triton_eva_prep_kv_bwd(
        d_rfa_k, d_rfa_v,
        k, v, param_mu, param_phi, 
        mask, 
        rfa_k, rfa_v,
        d_k, d_v, d_param_mu, d_param_phi,
        softmax_scale, 
        mask_type,
        chunksize
    ):
    # the kernels index the last dimension with unit stride, so make every
    # tensor they read contiguous in that dimension (the saved k/v may be
    # non-contiguous even when the forward wrapper copied them)
    d_rfa_k, d_rfa_v, k, v, param_mu, param_phi = [
        x if x.stride(-1) == 1 else x.contiguous()
        for x in [d_rfa_k, d_rfa_v, k, v, param_mu, param_phi]
    ]

    # shape constraints
    batch, nheads, seqlen, head_dim = k.shape
    assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
    nchunks = seqlen // chunksize
    softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)

    mask_strides = (
        (mask.stride(0), mask.stride(2))
        if mask_type == 1 else 
        (0, 0)
    )

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
    BLOCK = 128
    num_warps = 4 if head_dim <= 64 else 8
    
    assert BLOCK >= chunksize and BLOCK % chunksize == 0, "BLOCK must be divisible by chunksize"
    chunks_per_block = BLOCK // chunksize

    partial_groups = triton.cdiv(seqlen, BLOCK)
    d_param_mu_partial = torch.zeros((batch, nheads, partial_groups, head_dim), dtype=torch.float32, device=d_rfa_k.device)
    d_param_phi_partial = torch.zeros((batch, nheads, partial_groups, head_dim), dtype=torch.float32, device=d_rfa_k.device)
    grid = lambda META: (partial_groups, batch * nheads)
    _bwd_eva_prep_kv_kernel[grid](
        rfa_k, # [b, h, c, d]
        rfa_v, # [b, h, c, d]
        k, # [b, h, n, d]
        v, # [b, h, n, d]
        param_mu, # [1, h, 1, 1, d]
        param_phi,  # [1, h, 1, 1, d]
        mask, # [b, 1, n, 1]
        d_rfa_k, # [b, h, c, d]
        d_rfa_v, # [b, h, c, d]
        d_k, # [b, h, n, d]
        d_v, # [b, h, n, d]
        d_param_mu_partial, # [b, h, g, d]
        d_param_phi_partial, # [b, h, g, d]
        softmax_scale,
        rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2),
        rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        param_mu.stride(1),
        param_phi.stride(1),
        mask_strides[0], mask_strides[1],
        d_rfa_k.stride(0), d_rfa_k.stride(1), d_rfa_k.stride(2),
        d_rfa_v.stride(0), d_rfa_v.stride(1), d_rfa_v.stride(2),
        d_k.stride(0), d_k.stride(1), d_k.stride(2),
        d_v.stride(0), d_v.stride(1), d_v.stride(2),
        d_param_mu_partial.stride(0), d_param_mu_partial.stride(1), d_param_mu_partial.stride(2),
        d_param_phi_partial.stride(0), d_param_phi_partial.stride(1), d_param_phi_partial.stride(2),
        nheads,
        seqlen,
        nchunks,
        head_dim,
        chunks_per_block,
        chunksize,
        mask_type,
        BLOCK_HEADDIM,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    # reduce the per-(batch, block-group) partial sums back to the
    # [1, h, 1, 1, d] parameter shape
    d_param_mu.copy_(d_param_mu_partial.sum(dim=(0, -2), keepdim=True).unsqueeze(-2).to(d_param_mu.dtype))
    d_param_phi.copy_(d_param_phi_partial.sum(dim=(0, -2), keepdim=True).unsqueeze(-2).to(d_param_phi.dtype))



class EvaPrepKVFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, k, v, param_mu, param_phi, mask, softmax_scale=None, chunksize=None):
        if mask is not None:
            mask_type = 1
        else:
            mask_type = 0
        rfa_k, rfa_v = triton_eva_prep_kv_fwd(
            k, v, param_mu, param_phi, mask, softmax_scale, chunksize
        )
        ctx.save_for_backward(k, v, param_mu, param_phi, mask, rfa_k, rfa_v)
        ctx.softmax_scale = softmax_scale
        ctx.chunksize = chunksize
        ctx.mask_type = mask_type
        return rfa_k, rfa_v

    @staticmethod
    def backward(ctx, d_rfa_k, d_rfa_v):
        k, v, param_mu, param_phi, mask, rfa_k, rfa_v = ctx.saved_tensors
        d_k = torch.empty_like(k)
        d_v = torch.empty_like(v)
        d_param_mu = torch.empty_like(param_mu)
        d_param_phi = torch.empty_like(param_phi)
        triton_eva_prep_kv_bwd(
            d_rfa_k, d_rfa_v,
            k, v, param_mu, param_phi, 
            mask, 
            rfa_k, rfa_v,
            d_k, d_v, d_param_mu, d_param_phi,
            ctx.softmax_scale, 
            ctx.mask_type,
            ctx.chunksize
        )
        return d_k, d_v, d_param_mu, d_param_phi, None, None, None

def eva_prep_kv_func_triton(
        k, v, 
        param_mu, param_phi,
        mask, 
        softmax_scale=None, chunksize=None
    ):
    return EvaPrepKVFunc.apply(
        k, v, 
        param_mu, param_phi,
        mask, 
        softmax_scale, chunksize
    )
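
# Minimal smoke test; the sizes below are our own choices (any shapes
# satisfying the asserts in triton_eva_prep_kv_fwd work; CUDA required):
if __name__ == "__main__":
    b, h, n, d, chunk = 2, 4, 256, 64, 64
    k = torch.randn(b, h, n, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    v = torch.randn(b, h, n, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    param_mu = torch.randn(1, h, 1, 1, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    param_phi = torch.randn(1, h, 1, 1, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    rfa_k, rfa_v = eva_prep_kv_func_triton(k, v, param_mu, param_phi, None, None, chunk)
    (rfa_k.float().sum() + rfa_v.float().sum()).backward()
    print(rfa_k.shape, rfa_v.shape)  # each (b, h, n // chunk, d)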