File size: 4,386 Bytes
575c4e9 1b59f63 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# This module provides ApplyRoPE and RMSNorm kernels written in OpenAI Triton.
# Feel free to contact the contributors if you have any questions or issues regarding this code.
# Contributors: Shangming Cai, Zihan Wang
# Contacts: [email protected], [email protected]
from typing import Any, Callable, Dict, Hashable, Tuple
import torch
import triton
import triton.language as tl
from triton.compiler import CompiledKernel
from triton.runtime import JITFunction
try:
import triton.language.math as tlmath # Triton 2.1
except ImportError:
import triton.language.libdevice as tlmath # Triton 2.0
class TritonKernel:
    """Launcher for a ``@triton.jit`` function that caches compiled kernels.

    The first launch for a given cache key goes through the ``JITFunction``
    (which compiles if necessary) and the ``CompiledKernel`` it returns is
    stored. Later launches with the same key call the cached kernel directly.

    Args:
        kernel_fn: The ``@triton.jit`` function to launch.
        grid_fn: Maps the tuple of positional launch arguments to a grid.
    """

    def __init__(
        self,
        # String annotations are resolved lazily, so defining the class does
        # not require Triton's runtime types.
        kernel_fn: "JITFunction",
        grid_fn: Callable[[Tuple[Any, ...]], Tuple[int, int, int]],
    ) -> None:
        self.kernel_fn_ = kernel_fn
        self.grid_fn_ = grid_fn
        # Maps a launch key (see ``run``) to the compiled kernel for that key.
        self.kernel_cache_: Dict[Hashable, "CompiledKernel"] = {}

    @staticmethod
    def _arg_key(arg: Any) -> Hashable:
        """Summarize the properties of one positional argument that select a binary.

        Triton compiles a separate binary per pointer element type and 16-byte
        alignment, and per integer "divisible by 16", "equal to 1" and
        32- vs 64-bit class. A cached binary may only be reused when all of
        these match; otherwise it would run with the wrong assumptions.
        """
        if isinstance(arg, torch.Tensor):
            return (arg.dtype, arg.data_ptr() % 16 == 0)
        if isinstance(arg, bool):
            return (bool, arg)
        if isinstance(arg, int):
            return (
                int,
                arg % 16 == 0,
                arg == 1,
                -(2**31) <= arg < 2**31,
                arg >= 2**63,
            )
        # Floats and other scalars are not value-specialized; their type suffices.
        return type(arg)

    def run(self, *args, **kwargs):
        """Launch the kernel on the CUDA device holding ``args[0]``.

        Positional ``args`` are the kernel's runtime arguments. ``kwargs`` are
        passed only on the compiling (cache-miss) launch and are part of the
        cache key, so they should be compile-time parameters (e.g.
        ``tl.constexpr`` values). The current CUDA device is switched to
        ``args[0]``'s device for the launch and restored afterwards, even if
        the launch raises.
        """
        input_device = args[0].device
        grid = self.grid_fn_(args)
        # The key must cover everything the compiled binary was specialized
        # on: device, per-argument specializations and compile-time kwargs.
        # Sorting makes it independent of keyword order.
        kernel_key = (
            input_device,
            tuple(self._arg_key(arg) for arg in args),
            tuple(sorted(kwargs.items())),
        )
        with torch.cuda.device(input_device):
            kernel = self.kernel_cache_.get(kernel_key)
            if kernel is None:
                # Compile (or hit Triton's own cache), launch, and remember it.
                self.kernel_cache_[kernel_key] = self.kernel_fn_[grid](*args, **kwargs)
            else:
                kernel[grid](*args)
@triton.jit
def _apply_rope_fwd_kernel(X, Cos, Sin, Y, HEAD_DIM: tl.constexpr):
    """Rotary position embedding (RoPE) for one head vector of one token.

    Launched on a 3-D grid whose program ids (0, 1, 2) are used as
    (batch, token, head); the grid sizes of axes 1 and 2 are read back as
    seq_len and num_heads. Over the HEAD_DIM elements of its vector, each
    program computes

        y = x * cos + rotate_half(x) * sin,
        rotate_half(x) = concat(-x[HEAD_DIM // 2:], x[:HEAD_DIM // 2]).

    Memory layout implied by the index arithmetic (nothing here checks it):
      X        - the vector for (b, t, h) starts at
                 ((b * seq_len + t) * num_heads * 3 + h) * HEAD_DIM, i.e.
                 consecutive tokens are 3 * num_heads * HEAD_DIM elements
                 apart. NOTE(review): presumably X is the query or key third
                 of a packed QKV tensor; confirm at the caller.
      Cos, Sin - element d of token t is at t * HEAD_DIM + d; shared by all
                 batch entries and heads.
      Y        - densely packed (batch, seq_len, num_heads, HEAD_DIM).
    Loads and stores are unmasked and HEAD_DIM is a tl.arange length, so it
    must be a power of two equal to the real head size.
    """
    batch_idx, tok_idx, head_idx = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    seq_len, num_heads = tl.num_programs(1), tl.num_programs(2)
    block_idx = tl.arange(0, HEAD_DIM)
    # Start of this (batch, token, head) vector in X; the factor 3 is the
    # per-token stride assumption described in the docstring.
    x_base_idx = ((batch_idx * seq_len + tok_idx) * num_heads * 3 + head_idx) * HEAD_DIM
    x = tl.load(X + x_base_idx + block_idx)
    freq_idx = tok_idx * HEAD_DIM + block_idx
    cos = tl.load(Cos + freq_idx)
    # Partner element half a head away (wrapping around) for rotate_half(x).
    rot_idx = (HEAD_DIM // 2 + block_idx) % HEAD_DIM
    x_rot = tl.load(X + x_base_idx + rot_idx)
    # First half gets -x[d + HEAD_DIM/2]; second half gets +x[d - HEAD_DIM/2].
    x_rot = tl.where(block_idx >= HEAD_DIM // 2, x_rot, -x_rot)
    sin = tl.load(Sin + freq_idx)
    # Output offset uses num_heads without the factor 3: Y is densely packed.
    y_idx = (
        (batch_idx * seq_len + tok_idx) * num_heads + head_idx
    ) * HEAD_DIM + block_idx
    y = x * cos + x_rot * sin
    # Cast to Y's element type before storing.
    tl.store(Y + y_idx, y.to(Y.dtype.element_ty))
# Cached launcher for _apply_rope_fwd_kernel. The grid is the first three
# dimensions of the first positional argument (x): one program per
# (batch, token, head).
apply_rope_fwd_kernel = TritonKernel(
    kernel_fn=_apply_rope_fwd_kernel,
    grid_fn=lambda launch_args: tuple(launch_args[0].size()[:3]),
)
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Apply rotary position embedding to ``x`` using the Triton kernel.

    Computes ``x * cos + rotate_half(x) * sin`` for every head vector, where
    ``rotate_half(v) = concat(-v[D // 2:], v[:D // 2])`` and ``D = head_dim``.

    Args:
        x: 4-D tensor ``(batch, seq_len, num_heads, head_dim)`` whose tokens
            are ``3 * num_heads * head_dim`` elements apart. This is the
            layout the kernel hard-codes. NOTE(review): presumably ``x`` is the
            query or key view of a packed QKV projection; confirm at call site.
        cos: Table read as ``cos[t * head_dim + d]`` for token ``t`` and shared
            across batch and heads. It must be contiguous, hold at least
            ``seq_len * head_dim`` elements and live on ``x``'s device.
        sin: Same requirements as ``cos``.

    Returns:
        A new contiguous tensor with ``x``'s shape and dtype.

    Raises:
        ValueError: If ``x`` is not 4-D, ``head_dim`` is not a power of two,
            ``x``'s strides differ from the kernel's layout, or ``cos``/``sin``
            are on another device, non-contiguous or too small. The kernel does
            no bounds checking, so these inputs would otherwise read wrong or
            out-of-bounds memory.
    """
    if x.dim() != 4:
        raise ValueError(
            f"apply_rotary_emb expects a 4-D tensor, got shape {tuple(x.shape)}"
        )
    _, seq_len, num_heads, head_dim = x.shape
    y = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    if x.numel() == 0:
        # Nothing to rotate; skip compiling/launching for an empty grid.
        return y
    # head_dim is a tl.arange length inside the kernel: must be a power of two.
    if head_dim & (head_dim - 1):
        raise ValueError(f"head_dim must be a power of two, got {head_dim}")
    token_stride = 3 * num_heads * head_dim
    expected_strides = (seq_len * token_stride, token_stride, head_dim, 1)
    for dim, (size, stride, expected) in enumerate(
        zip(x.shape, x.stride(), expected_strides)
    ):
        # Only index 0 is used along a size-1 dimension, so its stride is irrelevant.
        if size > 1 and stride != expected:
            raise ValueError(
                f"x has stride {stride} along dim {dim}, but the kernel assumes "
                f"{expected} (strides {expected_strides})"
            )
    for name, table in (("cos", cos), ("sin", sin)):
        if table.device != x.device:
            raise ValueError(f"{name} is on {table.device}, but x is on {x.device}")
        if not table.is_contiguous() or table.numel() < seq_len * head_dim:
            raise ValueError(
                f"{name} must be contiguous with at least {seq_len * head_dim} "
                f"elements, got shape {tuple(table.shape)}"
            )
    apply_rope_fwd_kernel.run(x, cos, sin, y, HEAD_DIM=head_dim)
    return y
@triton.jit
def _rms_norm_fwd_kernel(X, W, Y, eps, hidden_dim, BLOCK_SIZE: tl.constexpr):
    """RMSNorm of one row: y = x * rsqrt(mean(x ** 2) + eps) * w.

    Program id 0 selects the row. The row is read from
    X[tok_idx * hidden_dim : (tok_idx + 1) * hidden_dim] and written to the
    same offsets of Y, so rows are assumed packed back-to-back (contiguous)
    in both X and Y. W supplies hidden_dim per-channel weights.
    The row is processed in chunks of BLOCK_SIZE (a tl.arange length, so a
    power of two) with the tail chunk masked. When BLOCK_SIZE >= hidden_dim > 0
    each loop runs exactly once.
    """
    tok_idx = tl.program_id(0)
    # Per-lane partial sums of x^2 / hidden_dim, accumulated in float32.
    mean_sq = tl.zeros([BLOCK_SIZE], tl.float32)
    for offset in range(0, hidden_dim, BLOCK_SIZE):
        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
        # Lanes past hidden_dim load 0 and therefore add nothing.
        x = tl.load(
            X + tok_idx * hidden_dim + dim_idx, mask=dim_idx < hidden_dim, other=0
        ).to(tl.float32)
        # Dividing each term by hidden_dim makes the final sum the mean.
        mean_sq += x * x / hidden_dim
    rrms = tlmath.rsqrt(tl.sum(mean_sq, 0) + eps)
    # Second pass: re-read x (stored dtype; no explicit float32 cast here),
    # scale by 1/rms and the weight, and store the masked chunk.
    for offset in range(0, hidden_dim, BLOCK_SIZE):
        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
        dim_mask = dim_idx < hidden_dim
        hidden_idx = tok_idx * hidden_dim + dim_idx
        x = tl.load(X + hidden_idx, mask=dim_mask, other=0)
        w = tl.load(W + dim_idx, mask=dim_mask, other=0)
        y = x * rrms * w
        tl.store(Y + hidden_idx, y.to(Y.dtype.element_ty), mask=dim_mask)
# Cached launcher for _rms_norm_fwd_kernel: a 1-D grid with one program per
# row, i.e. the product of every dimension of the first argument but the last.
rms_norm_fwd_kernel = TritonKernel(
    kernel_fn=_rms_norm_fwd_kernel,
    grid_fn=lambda launch_args: (launch_args[0].size()[:-1].numel(), 1, 1),
)
def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float):
    """Apply RMSNorm over the last dimension of ``x`` using the Triton kernel.

    Computes ``x * rsqrt(mean(x ** 2, dim=-1) + eps) * weight`` row by row.

    Args:
        x: Input tensor, normalized over its last dimension.
        weight: Per-channel scale with ``x.size(-1)`` elements.
        eps: Added to the mean of squares before the reciprocal square root.

    Returns:
        A new contiguous tensor with ``x``'s shape and dtype.

    Raises:
        ValueError: If ``weight`` does not have ``x.size(-1)`` elements.
    """
    hidden_dim = x.size(-1)
    if weight.numel() != hidden_dim:
        raise ValueError(
            f"weight has {weight.numel()} elements, expected {hidden_dim} (x.size(-1))"
        )
    # The kernel addresses rows as X + row * hidden_dim and reads W densely,
    # so both must be contiguous (.contiguous() is a no-op when they already are).
    x = x.contiguous()
    weight = weight.contiguous()
    y = torch.empty_like(x)
    if x.numel() == 0:
        # Nothing to normalize. This also avoids BLOCK_SIZE == 0 when
        # hidden_dim == 0, which tl.arange cannot handle.
        return y
    rms_norm_fwd_kernel.run(
        x, weight, y, eps, hidden_dim, BLOCK_SIZE=triton.next_power_of_2(hidden_dim)
    )
    return y
|