# Copyright (c) 2023, Tri Dao.
# Copyright 2024 CATIE. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modifications to the original file
# - support for torch.compile

import math
from typing import Tuple

import torch
import triton
import triton.language as tl


@triton.jit
def _rmsnorm_fwd_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    IS_EVEN_N: tl.constexpr,
):
    # One program handles one row:
    #   rstd = 1 / sqrt(mean(x^2) + eps);  y = x * rstd * w
    # The whole row is processed as a single BLOCK_N-wide tile (BLOCK_N >= N is
    # enforced by the host wrapper).
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row

    # Mean of squares (RMSNorm: no mean subtraction).
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
    xbar = tl.where(mask, x, 0.0)
    var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # rstd is kept so the backward pass does not have to recompute it.
    tl.store(Rstd + row, rstd)

    # Normalize and apply the per-column weight.
    if IS_EVEN_N:
        w = tl.load(W + cols).to(tl.float32)
    else:
        # other=0.0 keeps padded lanes defined (they are never stored anyway).
        w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
    x_hat = x * rstd
    y = x_hat * w

    # Write output.
    if IS_EVEN_N:
        tl.store(Y + cols, y)
    else:
        tl.store(Y + cols, y, mask=mask)


@triton.jit
def _rmsnorm_bwd_kernel(
    X,  # pointer to the input
    W,  # pointer to the weights
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DW,  # pointer to the partial sum of weights gradient
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_dy_row,
    stride_dx_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero (unused here; rstd is precomputed)
    rows_per_program,
    BLOCK_N: tl.constexpr,
    IS_EVEN_N: tl.constexpr,
):
    # Each program walks a contiguous run of `rows_per_program` rows, writing
    #   dx = rstd * (w*dy - xhat * mean(xhat * w*dy))
    # for every row, and accumulating its own partial dW = sum_rows(dy * xhat),
    # which is written to row `row_block_id` of DW and summed on the host side.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    X += row_start * stride_x_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    # other=0.0 is required: padded lanes of w feed `wdy` below, and an
    # undefined (possibly NaN/inf) value there would poison the row sum c1.
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        if IS_EVEN_N:
            x = tl.load(X + cols).to(tl.float32)
            dy = tl.load(DY + cols).to(tl.float32)
        else:
            x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
            dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = x * rstd
        if not IS_EVEN_N:
            xhat = tl.where(mask, xhat, 0.0)
        wdy = w * dy
        dw += dy * xhat
        c1 = tl.sum(xhat * wdy, axis=0) / N
        dx = (wdy - xhat * c1) * rstd
        tl.store(DX + cols, dx, mask=mask)
        X += stride_x_row
        DY += stride_dy_row
        DX += stride_dx_row
    # Stored unconditionally, so programs that own no rows write zeros.
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)


@torch.library.custom_op("flasht5::rmsnorm_triton_fwd", mutates_args=(), device_types="cuda")
def rmsnorm_triton_fwd(
    X: torch.Tensor, weight: torch.Tensor, eps: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    """RMSNorm forward on a 2-D tensor via a Triton kernel.

    Args:
        X: (M, N) input with unit stride along the last dim.
        weight: (N,) per-column scale with unit stride.
        eps: added to the mean of squares before the square root.

    Returns:
        (Y, rstd): Y has X's shape/dtype; rstd is a float32 (M,) tensor holding
        1 / sqrt(mean(x^2) + eps) per row, consumed by the backward op.

    Raises:
        RuntimeError: if a row does not fit in one 64KB tile.
    """
    M, N = X.shape
    assert X.stride(-1) == 1
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    # allocate output
    Y = torch.empty_like(X)
    assert Y.stride(-1) == 1
    rstd = torch.empty((M,), dtype=torch.float32, device=X.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // X.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    # Explicit raise (not assert): under `python -O` an assert would vanish and
    # the kernel would silently normalize only the first BLOCK_N columns.
    if N > BLOCK_N:
        raise RuntimeError("This RMSNorm kernel doesn't support feature dim >= 64KB.")
    # One program per row; num_warps is left at the Triton default.
    with torch.cuda.device(X.device.index):
        _rmsnorm_fwd_kernel[(M,)](
            X,
            Y,
            weight,
            rstd,
            X.stride(0),
            Y.stride(0),
            N,
            eps,
            BLOCK_N,
            (N % BLOCK_N == 0),
        )
    return Y, rstd


@torch.library.register_fake("flasht5::rmsnorm_triton_fwd")
def rmsnorm_triton_fwd_abstract(X, weight, eps):
    """Fake (meta) implementation: output metadata matching rmsnorm_triton_fwd."""
    M, N = X.shape
    Y = torch.empty_like(X)
    rstd = torch.empty((M,), dtype=torch.float32, device=X.device)
    return Y, rstd


@torch.library.custom_op("flasht5::rmsnorm_triton_bwd", mutates_args=(), device_types="cuda")
def rmsnorm_triton_bwd(
    dy: torch.Tensor, x: torch.Tensor, weight: torch.Tensor, rstd: torch.Tensor, eps: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    """RMSNorm backward on a 2-D tensor via a Triton kernel.

    Args:
        dy: (M, N) gradient of the output, unit stride along the last dim.
        x: (M, N) forward input, unit stride along the last dim.
        weight: (N,) per-column scale with unit stride.
        rstd: (M,) float32 per-row 1/std produced by rmsnorm_triton_fwd.
        eps: unused by the kernel; kept for signature symmetry.

    Returns:
        (dx, dw): dx has x's shape/dtype; dw has shape (N,) in weight's dtype.

    Raises:
        RuntimeError: if a row does not fit in one 64KB tile.
    """
    M, N = x.shape
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    # allocate output
    dx = torch.empty_like(x)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This RMSNorm kernel doesn't support feature dim >= 64KB.")
    # Launch one program per SM; each owns `rows_per_program` consecutive rows
    # and writes one row of float32 partial dW sums into _dw.
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    # torch.empty is safe: every program stores its full row of _dw.
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        _rmsnorm_bwd_kernel[grid](
            x,
            weight,
            dy,
            dx,
            _dw,
            rstd,
            x.stride(0),
            dy.stride(0),
            dx.stride(0),
            M,
            N,
            eps,
            rows_per_program,
            BLOCK_N,
            (N % BLOCK_N == 0),
        )
    # Reduce the per-program partials into the final weight gradient.
    dw = _dw.sum(0).to(weight.dtype)
    return dx, dw


@torch.library.register_fake("flasht5::rmsnorm_triton_bwd")
def rmsnorm_triton_bwd_abstract(dy, x, weight, rstd, eps):
    """Fake (meta) implementation: output metadata matching rmsnorm_triton_bwd."""
    M, N = x.shape
    dx = torch.empty_like(x)
    # Must mirror the real op's `_dw.sum(0).to(weight.dtype)`: shape (N,) in
    # weight's dtype (previously (1, N) float32, which mis-described the output
    # to torch.compile).
    dw = torch.empty((N,), dtype=weight.dtype, device=weight.device)
    return dx, dw


class Fast_RMS_Layernorm(torch.autograd.Function):
    """Autograd wrapper around the flasht5 RMSNorm custom ops.

    Accepts X of any rank; all leading dims are flattened into rows and the
    last dim is normalized.
    """

    @staticmethod
    def forward(ctx, X, W, eps=1e-6):
        X_orig_shape = X.shape
        X = X.reshape(-1, X.shape[-1])
        # The custom ops require unit stride along the normalized dim.
        if X.stride(-1) != 1:
            X = X.contiguous()
        if W.stride(-1) != 1:
            W = W.contiguous()
        y, rstd = torch.ops.flasht5.rmsnorm_triton_fwd(X, W, eps)
        y = y.reshape(X_orig_shape)
        # y is not saved: the backward only needs X, W and the per-row rstd.
        ctx.save_for_backward(X, W, rstd)
        ctx.x_shape_og = X_orig_shape
        ctx.eps = eps
        return y

    @staticmethod
    def backward(ctx, dY):
        X, weight, rstd = ctx.saved_tensors
        dY = dY.reshape(-1, dY.shape[-1])
        # Incoming gradients may be non-contiguous (e.g. the stride-0 expanded
        # gradient produced by `out.sum().backward()`), which the op rejects.
        if dY.stride(-1) != 1:
            dY = dY.contiguous()
        assert dY.shape == X.shape
        dx, dw = torch.ops.flasht5.rmsnorm_triton_bwd(
            dY, X, weight, rstd, ctx.eps
        )
        # No gradient for eps.
        return dx.reshape(ctx.x_shape_og), dw, None


def fast_rms_layernorm(X, W, eps=1e-6):
    """Apply RMSNorm over the last dim of X with weight W (Triton kernels).

    Args:
        X: input tensor of any rank; the last dim is normalized.
        W: per-feature scale of length X.shape[-1].
        eps: added to the mean of squares; defaults to Fast_RMS_Layernorm's 1e-6.

    Returns:
        Tensor with X's shape.
    """
    out = Fast_RMS_Layernorm.apply(X, W, eps)
    return out