Triton-InternViT-6B-448px-V1-5 / flash_attention.py
radna's picture
Update flash_attention.py
6e45b59 verified
raw
history blame
3.64 kB
import torch
import torch.nn as nn
from einops import rearrange
try:
    from .triton_flash_atn import _attention
    from .triton_bert_pading import pad_input, unpad_input
except Exception:
    # Optional backend: keep this module importable when the Triton kernels
    # (or their dependencies) are unavailable. ``Exception`` rather than a
    # bare ``except:`` so KeyboardInterrupt / SystemExit still propagate.
    # Any name that failed to import stays undefined, so later use of it
    # raises NameError.
    print("FlashAttention is not installed.")
class FlashAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Wraps the Triton ``_attention`` autograd function. Padded
    ``(B, S, 3, H, D)`` inputs are packed into the variable-length layout
    (flattened tokens plus ``cu_seqlens`` offsets) before calling the kernel.

    Arguments
    ---------
    softmax_scale: The temperature to use for the softmax attention.
        ``None`` is forwarded to the kernel unchanged; presumably the kernel
        then uses 1/sqrt(d_keys) with d_keys computed at runtime -- TODO
        confirm against ``triton_flash_atn``.
    attention_dropout: The dropout rate to apply to the attention
        (default: 0.0). Only applied while the module is in training mode.
    device, dtype: Accepted for constructor compatibility; currently unused.
    """

    def __init__(
        self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
    ):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def _apply_attention(self, qkv, cu_seqlens, max_s, causal):
        """Run the Triton kernel on an already-packed qkv tensor.

        Dropout is forced to 0.0 outside training mode.
        """
        dropout_p = self.dropout_p if self.training else 0.0
        return _attention.apply(
            qkv, cu_seqlens, max_s, dropout_p, self.softmax_scale, causal
        )

    def forward(
        self,
        qkv,
        key_padding_mask=None,
        causal=False,
        cu_seqlens=None,
        max_s=None,
        need_weights=False,
    ):
        """Implements the multihead softmax attention.

        Arguments
        ---------
        qkv: The tensor containing the query, key, and value. Must be a CUDA
            tensor of dtype float16 or bfloat16.
            (B, S, 3, H, D) when ``cu_seqlens`` is None (with or without
            ``key_padding_mask``); already unpadded (nnz, 3, H, D) when
            ``cu_seqlens`` is given.
        key_padding_mask: a bool tensor of shape (B, S). Only consulted when
            ``cu_seqlens`` is None. Presumably True marks tokens to keep --
            confirm against ``unpad_input``.
        causal: Whether to apply causal masking; forwarded to the kernel.
        cu_seqlens: Cumulative sequence-length offsets of a pre-packed batch.
            When this method builds them itself they are int32.
        max_s: Maximum sequence length. Required when ``cu_seqlens`` is
            given; otherwise computed here and any passed value is ignored.
        need_weights: Must be False; attention weights are never returned.

        Returns
        -------
        ``(output, None)``. With ``key_padding_mask`` the output is
        (B, S, H, D). Without a mask it is (B, S, ...) with the kernel's
        trailing dims (presumably H, D). With ``cu_seqlens`` it is the
        kernel's packed output unchanged (presumably (nnz, H, D)).
        """
        assert not need_weights, "need_weights=True is not supported"
        assert qkv.dtype in [torch.float16, torch.bfloat16], (
            f"qkv must be float16 or bfloat16, got {qkv.dtype}"
        )
        assert qkv.is_cuda, "qkv must be a CUDA tensor"

        if cu_seqlens is not None:
            # Caller already packed the batch; pass straight to the kernel.
            assert max_s is not None, "max_s is required when cu_seqlens is given"
            return self._apply_attention(qkv, cu_seqlens, max_s, causal), None

        batch_size = qkv.shape[0]
        seqlen = qkv.shape[1]

        if key_padding_mask is None:
            # Every sequence has full length S: flatten batch and sequence
            # axes and use uniform offsets 0, S, 2S, ..., B*S.
            packed = rearrange(qkv, "b s ... -> (b s) ...")
            cu_seqlens = torch.arange(
                0,
                (batch_size + 1) * seqlen,
                step=seqlen,
                dtype=torch.int32,
                device=packed.device,
            )
            output = self._apply_attention(packed, cu_seqlens, seqlen, causal)
            output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
            return output, None

        # Padded batch: fold (3, H, D) into one trailing axis, drop padded
        # tokens, then restore the (nnz, 3, H, D) layout for the kernel.
        nheads = qkv.shape[-2]
        x = rearrange(qkv, "b s three h d -> b s (three h d)")
        x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
        x_unpad = rearrange(
            x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
        )
        output_unpad = self._apply_attention(x_unpad, cu_seqlens, max_s, causal)
        # Scatter the per-token results back into the padded (B, S, H, D) grid.
        output = rearrange(
            pad_input(
                rearrange(output_unpad, "nnz h d -> nnz (h d)"),
                indices,
                batch_size,
                seqlen,
            ),
            "b s (h d) -> b s h d",
            h=nheads,
        )
        return output, None