File size: 7,203 Bytes
74a8b24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import math
from typing import Optional, Tuple, TypeVar
import torch.nn as nn
import torch
import triton
from functools import lru_cache
from .triton_flash_blocksparse_attn import get_local_strided_sparse_attention_op, _get_sparse_attn_mask, blocksparse_flash_attn_padded_fwd, blocksparse_flash_attn_varlen_fwd
# Type alias for a sparse attention layout: a pair of integer index tensors.
# NOTE(review): presumably the CSR-style (crow_indices, col_indices) of the
# block mask produced by `_get_sparse_attn_mask` — confirm against that helper.
Layout = Tuple[torch.LongTensor, torch.LongTensor]
def create_sparse_attn_mask(
    n_heads: int,
    max_seq_len: int,
    max_seq_len_k: int,
    dtype: torch.dtype,
    device: torch.device,
    BLOCK: int,
    local_blocks: int,
    vert_stride: int,
    homo_head: bool,
    return_dense: bool
) -> Tuple[Layout, torch.Tensor]:
    """Build the block-sparse attention layout and block mask.

    Thin wrapper around ``_get_sparse_attn_mask`` that forwards every argument
    and drops the helper's third return value.

    Args:
        n_heads: Number of attention heads, forwarded as ``n_heads``.
        max_seq_len: Maximum query length, forwarded as ``q_len``.
        max_seq_len_k: Maximum key length, forwarded as ``N_CTX``.
        dtype: Forwarded to the helper unchanged.
        device: Forwarded to the helper unchanged.
        BLOCK: Sparse block size, forwarded unchanged.
        local_blocks: Forwarded unchanged.
            NOTE(review): presumably the number of local blocks each query
            block attends to — confirm against ``_get_sparse_attn_mask``.
        vert_stride: Forwarded unchanged.
            NOTE(review): presumably the stride of the vertical (strided)
            blocks — confirm against ``_get_sparse_attn_mask``.
        homo_head: Forwarded unchanged.
            NOTE(review): presumably whether all heads share one pattern.
        return_dense: Forwarded unchanged. The helper's third return value is
            discarded here, so this flag does not change what this function
            returns.

    Returns:
        A ``(layout, block_sparse_pattern)`` pair: the first two values
        returned by ``_get_sparse_attn_mask``.
    """
    layout, block_sparse_pattern, _ = _get_sparse_attn_mask(
        n_heads=n_heads,
        q_len=max_seq_len,
        N_CTX=max_seq_len_k,
        dtype=dtype,
        device=device,
        BLOCK=BLOCK,
        local_blocks=local_blocks,
        vert_stride=vert_stride,
        homo_head=homo_head,
        return_dense=return_dense
    )
    # Only two values are returned; the annotation above reflects that.
    return layout, block_sparse_pattern
class BlockSparseAttentionLayer(nn.Module):
    """Block-sparse attention layer dispatching to Triton kernels.

    ``forward`` picks one of three code paths:

    * no padding / sequence-length information given: the op returned by
      ``get_local_strided_sparse_attention_op``;
    * 3-dim ``k``: ``blocksparse_flash_attn_varlen_fwd`` (variable-length
      inference);
    * 4-dim ``k``: ``blocksparse_flash_attn_padded_fwd`` (batched inference).

    The sparse layout used by the two inference paths is built lazily on the
    first inference call and rebuilt whenever the dtype or device of ``q``
    changes.
    """

    def __init__(
        self,
        n_heads: int,
        max_seq_len: int,
        sparse_block_size: int,
        local_blocks: int,
        vert_stride: int,
        kernel_block_size: Optional[int] = None,
        homo_head: bool = False,
        active_head_range: Optional[Tuple[int, int]] = None
    ) -> None:
        """Store the sparsity configuration; no tensors are allocated here.

        Args:
            n_heads: Number of attention heads.
            max_seq_len: Maximum sequence length; used for both the query and
                key extents when the sparse mask is built.
            sparse_block_size: Block size of the sparsity pattern.
            local_blocks: Forwarded to the mask builder / training op.
            vert_stride: Forwarded to the mask builder / training op.
            kernel_block_size: Block size handed to the kernels. Falls back to
                ``sparse_block_size`` when ``None`` (or otherwise falsy).
            homo_head: Forwarded to the mask builder / training op. When true,
                ``active_head_range`` is ignored.
            active_head_range: Optional ``(start, end)`` head indices; when
                set (and ``homo_head`` is false) the layout is sliced to
                ``[start:end]`` along its first dimension.
        """
        super().__init__()
        self.n_heads = n_heads
        self.max_seq_len = max_seq_len
        self.sparse_block_size = sparse_block_size
        self.kernel_block_size = kernel_block_size or sparse_block_size
        self.local_blocks = local_blocks
        self.vert_stride = vert_stride
        self.homo_head = homo_head
        self.active_head_range = active_head_range

        # Internal Parameters used by the layer
        # (populated lazily by `_initialize_internals`).
        self._sparse_block_mask = None
        self._sparse_layout = None
        self._dtype = None
        self._device = None

        # TODO(bapatra): Ideally, I'd want to keep all the code for
        # forward to be handled here, and not branch for training and inference.
        # However, that refactor would need a lot of testing. For now, using the
        # training op as is, and will refactor again later.

    def prune_blocksparse_layout_to_heads(self, h_start: int, h_end: int) -> None:
        """Restrict the cached mask and layout to heads ``[h_start, h_end)``.

        Slices the first dimension of ``_sparse_block_mask`` and of both
        layout tensors.
        """
        self._sparse_block_mask = self._sparse_block_mask[h_start: h_end]
        # `Layout` is declared as a tuple, so item assignment would raise
        # TypeError; build a fresh pair instead of mutating in place.
        layout = self._sparse_layout
        self._sparse_layout = (
            layout[0][h_start: h_end],
            layout[1][h_start: h_end],
        )

    def _initialize_internals(
        self,
        dtype: torch.dtype,
        device: torch.device
    ) -> None:
        """Build and cache the sparse layout / block mask for ``dtype``/``device``.

        Raises:
            AssertionError: If the block sizes or ``active_head_range`` are
                invalid. Validation happens before any cached state is
                touched, so a failed call leaves the layer uninitialised.
        """
        prune_heads = (not self.homo_head) and (self.active_head_range is not None)

        # Validate the configuration up front so that a bad configuration
        # cannot leave half-initialised state behind (which would make the
        # next `forward` skip initialisation entirely).
        if prune_heads:
            assert len(self.active_head_range) == 2, "\"active_head_range\" should be a tuple of start/end index of the heads."
        assert self.sparse_block_size % self.kernel_block_size == 0, f"The sparse block size must be a multiple of {self.kernel_block_size}. Found {self.sparse_block_size}."
        assert self.kernel_block_size >=16 and math.log2(self.kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {self.kernel_block_size} is given"

        self._dtype, self._device = dtype, device
        self._sparse_layout, self._sparse_block_mask = create_sparse_attn_mask(
            n_heads=self.n_heads,
            max_seq_len=self.max_seq_len,
            max_seq_len_k=self.max_seq_len,
            dtype=dtype,
            device=device,
            BLOCK=self.sparse_block_size,
            local_blocks=self.local_blocks,
            vert_stride=self.vert_stride,
            homo_head=self.homo_head,
            return_dense=False,
        )
        if prune_heads:
            h_start, h_end = self.active_head_range
            self.prune_blocksparse_layout_to_heads(h_start=h_start, h_end=h_end)

        _mul = self.sparse_block_size // self.kernel_block_size
        if _mul > 1:
            # Expand every sparse block into a (_mul x _mul) tile of kernel blocks.
            # need to consider if block_m and block_n are different
            self._sparse_block_mask = torch.kron(self._sparse_block_mask, self._sparse_block_mask.new_ones(_mul, _mul))

        # Zero out blocks above the diagonal (block-level causal mask). The
        # index tensor is created directly on the mask's device.
        block_mask = self._sparse_block_mask
        num_sparse_blocks = block_mask.size(-1)
        block_idx = torch.arange(0, num_sparse_blocks, device=block_mask.device)
        block_causal_mask = block_idx[:, None] >= block_idx[None]
        block_mask *= block_causal_mask.to(block_mask.dtype)
        self._sparse_block_mask = block_mask

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sm_scale: float,
        *,
        # Arguments Related to Block Attention Inference
        left_paddings: Optional[torch.LongTensor] = None,
        seqlens: Optional[torch.LongTensor] = None,
        # Arguments Related to Variable Length Inference
        cu_seqlens_k: Optional[torch.LongTensor] = None,
        cu_seqlens_q: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Run block-sparse attention.

        Args:
            q, k, v: Query / key / value tensors. ``k`` must be 3-dim for the
                variable-length path or 4-dim for the padded batched path.
            sm_scale: Softmax scale forwarded to the kernel.
            left_paddings: Forwarded to the padded kernel (4-dim ``k``).
            seqlens: Forwarded to the padded kernel (4-dim ``k``).
            cu_seqlens_k: Required for the variable-length path (3-dim ``k``).
            cu_seqlens_q: Forwarded to the variable-length kernel.

        Returns:
            The attention output produced by the selected kernel.

        Raises:
            AssertionError: If an inference path is used while autograd is
                enabled, or if the arguments required by the selected path
                are missing.
            ValueError: If ``k`` is neither 3- nor 4-dimensional.
        """
        if left_paddings is None and seqlens is None and cu_seqlens_k is None and cu_seqlens_q is None:
            # No padding / length information: use the op shared with training.
            blocksparse_op = get_local_strided_sparse_attention_op(
                n_heads=self.n_heads,
                max_seq_len=self.max_seq_len,
                sparse_block_size=self.sparse_block_size,
                kernel_block_size=self.kernel_block_size,
                local_blocks=self.local_blocks,
                vert_stride=self.vert_stride,
                homo_head=self.homo_head,
                device=q.device,
                inference=not self.training
            )
            return blocksparse_op(q, k, v, sm_scale)

        assert not torch.is_grad_enabled(), "Variable Length Inference / Batched inference is not supported during training. Please run it in a torch.no_grad() context"

        # First set internals if they have not been set
        if self._sparse_block_mask is None or (self._dtype != q.dtype) or (self._device != q.device):
            self._initialize_internals(dtype=q.dtype, device=q.device)

        if k.dim() == 3:
            assert cu_seqlens_k is not None
            return blocksparse_flash_attn_varlen_fwd(
                q=q,
                k=k,
                v=v,
                cu_seqlens_k=cu_seqlens_k,
                cu_seqlens_q=cu_seqlens_q,
                sm_scale=sm_scale,
                sparse_layout=self._sparse_layout,
                block_size=self.kernel_block_size,
                max_seqlen=self.max_seq_len,
            )
        if k.dim() == 4:
            assert not (left_paddings is None and seqlens is None), "Either left_paddings or seqlens must be provided for batched inference."
            return blocksparse_flash_attn_padded_fwd(
                q=q,
                k=k,
                v=v,
                sm_scale=sm_scale,
                sparse_layout=self._sparse_layout,
                left_paddings=left_paddings,
                seqlens=seqlens,
                block_size=self.kernel_block_size,
                max_seqlen=self.max_seq_len,
            )
        raise ValueError('q/k/v must be either 3 dim for variable-length input or 4 dim for fixed-length.')
|