# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This module provides ApplyRoPE and RMSNorm kernels written in OpenAI Triton.
# Feel free to contact the contributors if you have any questions or issues regarding this code.
# Contributors: Shangming Cai, Zihan Wang
# Contacts: [email protected], [email protected]

from typing import Any, Callable, Dict, Hashable, Tuple

import torch
import triton
import triton.language as tl
from triton.compiler import CompiledKernel
from triton.runtime import JITFunction

try:
    import triton.language.math as tlmath  # Triton 2.1
except ImportError:
    import triton.language.libdevice as tlmath  # Triton 2.0


class TritonKernel:
    """Pairs a Triton JIT function with a grid function and caches compiled
    kernels per (device, constexpr kwargs) combination."""

    def __init__(
        self,
        kernel_fn: JITFunction,
        grid_fn: Callable[[Tuple[Any, ...]], Tuple[int, int, int]],
    ) -> None:
        self.kernel_fn_ = kernel_fn
        self.grid_fn_ = grid_fn
        self.kernel_cache_: Dict[Hashable, CompiledKernel] = {}

    def run(self, *args, **kwargs):
        # Switch to the input tensor's device if it differs from the current
        # one, remembering the previous index so it can be restored below
        input_device = args[0].device
        prev_dev_idx, cur_dev_idx = -1, torch.cuda.current_device()
        if input_device.index != cur_dev_idx:
            prev_dev_idx = cur_dev_idx
            torch.cuda.set_device(input_device.index)

        # Compute grid
        grid = self.grid_fn_(args)

        # Reuse the compiled kernel cached for this device and set of
        # constexpr kwargs, if any
        kernel_key = (input_device,) + tuple(kwargs.items())
        if kernel_key in self.kernel_cache_:
            kernel = self.kernel_cache_[kernel_key]
            kernel[grid](*args)
        else:
            # Compile, launch, and cache a new kernel specialization
            kernel = self.kernel_fn_[grid](*args, **kwargs)
            self.kernel_cache_[kernel_key] = kernel

        # Restore the previous device; torch.cuda.set_device is a no-op for
        # negative indices, so nothing happens when no switch occurred above
        torch.cuda.set_device(prev_dev_idx)
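

# Illustrative usage sketch for the TritonKernel wrapper above (not part of
# this module; `_scale_kernel` is a hypothetical @triton.jit function). The
# wrapper pairs a JIT function with a grid function that maps the positional
# args to a launch grid, and caches the compiled binary per device and
# constexpr kwargs:
#
#     scale_kernel = TritonKernel(
#         _scale_kernel,
#         lambda args: (triton.cdiv(args[0].numel(), 1024), 1, 1),
#     )
#     scale_kernel.run(x, y, x.numel(), BLOCK_SIZE=1024)
#
# run() expects the first positional argument to be a CUDA tensor, since its
# device drives both the context switch and the cache key.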


@triton.jit
def _apply_rope_fwd_kernel(X, Cos, Sin, Y, HEAD_DIM: tl.constexpr):
    # One program per (batch, token, head) triple
    batch_idx, tok_idx, head_idx = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    seq_len, num_heads = tl.num_programs(1), tl.num_programs(2)
    block_idx = tl.arange(0, HEAD_DIM)
    # X is indexed as a slice of a fused QKV projection: its head axis is
    # strided by 3 * num_heads heads of HEAD_DIM elements each
    x_base_idx = ((batch_idx * seq_len + tok_idx) * num_heads * 3 + head_idx) * HEAD_DIM
    x = tl.load(X + x_base_idx + block_idx)
    # Cos/Sin are contiguous (seq_len, HEAD_DIM) frequency tables
    freq_idx = tok_idx * HEAD_DIM + block_idx
    cos = tl.load(Cos + freq_idx)
    # rotate_half: element i pairs with element (i + HEAD_DIM // 2) % HEAD_DIM,
    # and the partner value is negated for the first half of the output
    rot_idx = (HEAD_DIM // 2 + block_idx) % HEAD_DIM
    x_rot = tl.load(X + x_base_idx + rot_idx)
    x_rot = tl.where(block_idx >= HEAD_DIM // 2, x_rot, -x_rot)
    sin = tl.load(Sin + freq_idx)
    # Y is written contiguously with shape (batch, seq_len, num_heads, HEAD_DIM)
    y_idx = (
        (batch_idx * seq_len + tok_idx) * num_heads + head_idx
    ) * HEAD_DIM + block_idx
    # y = x * cos + rotate_half(x) * sin
    y = x * cos + x_rot * sin
    tl.store(Y + y_idx, y.to(Y.dtype.element_ty))


# Launch one program per (batch, token, head): grid = x.shape[:3]
apply_rope_fwd_kernel = TritonKernel(
    _apply_rope_fwd_kernel, lambda args: tuple(args[0].shape[:3])
)


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # x: (batch, seq_len, num_heads, head_dim), addressed by the kernel as a
    # slice of a fused QKV tensor; cos/sin: contiguous (seq_len, head_dim)
    y = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    apply_rope_fwd_kernel.run(x, cos, sin, y, HEAD_DIM=x.size(-1))
    return y
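

# Illustrative call pattern for apply_rotary_emb (a sketch based on this
# module's indexing assumptions, not taken from the original source): x is a
# (batch, seq_len, num_heads, head_dim) slice of a fused QKV projection (the
# kernel strides the head axis by 3 * num_heads), and cos/sin are contiguous
# (seq_len, head_dim) tables.
#
#     qkv = qkv_proj(hidden)  # hypothetical fused projection
#     qkv = qkv.view(batch, seq_len, 3 * num_heads, head_dim)
#     q = apply_rotary_emb(qkv[:, :, :num_heads], cos, sin)
#     k = apply_rotary_emb(qkv[:, :, num_heads : 2 * num_heads], cos, sin)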


@triton.jit
def _rms_norm_fwd_kernel(X, W, Y, eps, hidden_dim, BLOCK_SIZE: tl.constexpr):
    # One program per token (row of the flattened (num_tokens, hidden_dim) input)
    tok_idx = tl.program_id(0)

    # First pass: accumulate mean(x^2) in float32 for numerical stability
    mean_sq = tl.zeros([BLOCK_SIZE], tl.float32)
    for offset in range(0, hidden_dim, BLOCK_SIZE):
        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
        x = tl.load(
            X + tok_idx * hidden_dim + dim_idx, mask=dim_idx < hidden_dim, other=0
        ).to(tl.float32)
        mean_sq += x * x / hidden_dim
    # rrms = 1 / sqrt(mean(x^2) + eps)
    rrms = tlmath.rsqrt(tl.sum(mean_sq, 0) + eps)

    # Second pass: normalize each element and scale by the learned weight
    for offset in range(0, hidden_dim, BLOCK_SIZE):
        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
        dim_mask = dim_idx < hidden_dim
        hidden_idx = tok_idx * hidden_dim + dim_idx
        x = tl.load(X + hidden_idx, mask=dim_mask, other=0)
        w = tl.load(W + dim_idx, mask=dim_mask, other=0)
        y = x * rrms * w
        tl.store(Y + hidden_idx, y.to(Y.dtype.element_ty), mask=dim_mask)


# Launch one program per token: grid = (number of rows, 1, 1)
rms_norm_fwd_kernel = TritonKernel(
    _rms_norm_fwd_kernel, lambda args: (args[0].shape[:-1].numel(), 1, 1)
)


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float):
    # x: (..., hidden_dim); weight: (hidden_dim,). BLOCK_SIZE must be a power
    # of two for tl.arange, hence next_power_of_2
    y = torch.empty_like(x)
    hidden_dim = x.size(-1)
    rms_norm_fwd_kernel.run(
        x, weight, y, eps, hidden_dim, BLOCK_SIZE=triton.next_power_of_2(hidden_dim)
    )
    return y
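

# Illustrative usage of rms_norm (a sketch, not from the original source),
# with an equivalent pure-PyTorch reference for checking the kernel output:
#
#     x = torch.randn(4, 128, 2048, dtype=torch.float16, device="cuda")
#     w = torch.ones(2048, dtype=torch.float16, device="cuda")
#     y = rms_norm(x, w, eps=1e-6)
#
#     variance = x.float().pow(2).mean(-1, keepdim=True)
#     y_ref = (x.float() * torch.rsqrt(variance + 1e-6) * w.float()).to(x.dtype)
#     torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3)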