# -*- coding: utf-8 -*- # Code adapted from # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/fused_linear_cross_entropy.py from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F import triton import triton.language as tl from fla.ops.utils import logsumexp_fwd from fla.utils import contiguous # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling # The optimal maximum block size depends on your hardware, your kernel, and your dtype MAX_FUSED_SIZE = 65536 // 2 @triton.jit def cross_entropy_kernel( logits, lse, target, loss, total, ignore_index, label_smoothing: tl.constexpr, logit_scale: tl.constexpr, reduction: tl.constexpr, V: tl.constexpr, BV: tl.constexpr ): """ This kernel computes both cross entropy loss and the gradient of the input. We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math. Args: logits: Pointer to logits tensor. lse: Pointer to logsumexp tensor. target: Pointer to target tensor. loss: Pointer to tensor to store the loss. V (int): The number of columns in the input tensor. total (int): The number of non-ignored classes. ignore_index (int): The index to ignore in the target. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. reduction (str): The string for the reduction to apply BV (int): The block size for vocab. """ # https://github.com/triton-lang/triton/issues/1058 # If B*T*V is too large, i_n * stride will overflow out of int32, so we convert to int64 i_n = tl.program_id(0).to(tl.int64) NV = tl.cdiv(V, BV) # 1. Load target first because if the target is ignore_index, we can return right away b_y = tl.load(target + i_n) # 2. locate the start index logits += i_n * V if b_y == ignore_index: # set all x as 0 for i in range(0, V, BV): o_v = i + tl.arange(0, BV) tl.store(logits + o_v, 0.0, mask=o_v < V) return # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 # 3. [Online softmax] first pass: compute logsumexp # we did this in anouter kernel b_l = tl.load(logits + b_y) * logit_scale b_lse = tl.load(lse + i_n) # 4. Calculate the loss # loss = lse - logits_l b_loss = b_lse - b_l # Label smoothing is a general case of normal cross entropy # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310 b_z = 0.0 eps = label_smoothing / V # We need tl.debug_barrier() as mentioned in # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34 tl.debug_barrier() # 5. [Online Softmax] Second pass: compute gradients # For 'mean' reduction, gradients are normalized by number of non-ignored elements # dx_y = (softmax(x_y) - 1) / N # dx_i = softmax(x_i) / N, i != y # For label smoothing: # dx_i = (softmax(x_y) - label_smoothing / V) / N, i != y # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N # = dx_i - (1 - label_smoothing) / N for iv in range(0, NV): o_v = iv * BV + tl.arange(0, BV) b_logits = tl.load(logits + o_v, mask=o_v < V, other=float('-inf')) * logit_scale if label_smoothing > 0: # scale X beforehand to avoid overflow b_z += tl.sum(tl.where(o_v < V, -eps * b_logits, 0.0)) b_p = (tl.exp(b_logits - b_lse) - eps) * logit_scale if reduction == "mean": b_p = b_p / total tl.store(logits + o_v, b_p, mask=o_v < V) tl.debug_barrier() # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd)) # Refer to H(q', p) in section 7 of the paper: # https://arxiv.org/pdf/1512.00567 # pytorch: # https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 if label_smoothing > 0: b_loss = b_loss * (1 - label_smoothing) + (b_z + label_smoothing * b_lse) # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` b_l = tl.load(logits + b_y) # Normalize the loss by the number of non-ignored elements if reduction is "mean" if reduction == 'mean': b_loss = b_loss / total b_l += (label_smoothing - 1) / total * logit_scale else: b_l += (label_smoothing - 1) * logit_scale tl.store(loss + i_n, b_loss) tl.store(logits + b_y, b_l) @triton.jit def elementwise_mul_kernel( x, g, N: tl.constexpr, B: tl.constexpr ): """ This function multiplies each element of the tensor pointed by x with the value pointed by g. The multiplication is performed in-place on the tensor pointed by x. Parameters: x: Pointer to the input tensor. g: Pointer to the gradient output value. N (int): The number of columns in the input tensor. B (int): The block size for Triton operations. """ # Get the program ID and convert it to int64 to avoid overflow i_x = tl.program_id(0).to(tl.int64) o_x = i_x * B + tl.arange(0, B) # Load the gradient output value b_g = tl.load(g) b_x = tl.load(x + o_x, mask=o_x < N) tl.store(x + o_x, b_x * b_g, mask=o_x < N) def fused_linear_cross_entropy_forward( x: torch.Tensor, target: torch.LongTensor, weight: torch.Tensor, bias: torch.Tensor = None, ignore_index: int = -100, label_smoothing: float = 0.0, logit_scale: float = 1.0, num_chunks: int = 8, reduction: str = "mean" ): device = x.device # inputs have shape: [N, H] # materialized activations will have shape: [N, V] # the increase in memory = [N, V] # reduction can be achieved by partitioning the number of tokens N into smaller chunks. # ideally, we would like to achieve the same memory consumption as [N, H], # so the expected chunk size should be: # NC = ceil(V / H) # C = ceil(N / NC) # for ex: N = 4096*4, V = 32000, H = 4096 ==> NC = 8, C = ceil(N / NC) = 2048 N, H, V = *x.shape, weight.shape[0] BV = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) # TODO: in real cases, we may need to limit the number of chunks NC to # ensure the precisions of accumulated gradients NC = min(num_chunks, triton.cdiv(V, H)) C = triton.next_power_of_2(triton.cdiv(N, NC)) NC = triton.cdiv(N, C) dx = torch.zeros_like(x, device=device) dw = torch.zeros_like(weight, device=device) if weight is not None else None db = torch.zeros_like(bias, device=device) if bias is not None else None # we use fp32 for loss accumulator loss = torch.zeros(N, dtype=torch.float32, device=device) total = target.ne(ignore_index).sum().item() for ic in range(NC): start, end = ic * C, min((ic + 1) * C, N) # [C, N] c_x = x[start:end] # when doing matmul, use the original precision # [C, V] c_logits = F.linear(c_x, weight, bias) c_target = target[start:end] # [C] # keep lse in fp32 to maintain precision c_lse = logsumexp_fwd(c_logits, scale=logit_scale, dtype=torch.float) # unreduced loss c_loss = loss[start:end] # Here we calculate the gradient of c_logits in place so we can save memory. cross_entropy_kernel[(c_logits.shape[0],)]( logits=c_logits, lse=c_lse, target=c_target, loss=c_loss, total=total, ignore_index=ignore_index, label_smoothing=label_smoothing, logit_scale=logit_scale, reduction=reduction, V=V, BV=BV, num_warps=32 ) # gradient of logits is computed in-place by the above triton kernel and is of shape: C x V # thus dx should be of shape: C x H dx[start:end] = torch.mm(c_logits, weight) # keep dw in fp32 to maintain precision if weight is not None: dw += c_logits.t() @ c_x if bias is not None: torch.add(input=db, other=c_logits.sum(0), out=db) loss = loss.sum() if dw is not None: dw = dw.to(weight) if db is not None: db = db.to(bias) return loss, dx, dw, db def fused_linear_cross_entropy_backward( do: torch.Tensor, dx: torch.Tensor, dw: torch.Tensor, db: torch.Tensor ): # If cross entropy is the last layer, do is 1.0. Skip the mul to save time if torch.ne(do, torch.tensor(1.0, device=do.device)): # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. N, H = dx.shape B = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) elementwise_mul_kernel[(triton.cdiv(N * H, B),)]( x=dx, g=do, N=N*H, B=B, num_warps=32, ) # handle dw if dw is not None: V, H = dw.shape elementwise_mul_kernel[(triton.cdiv(V * H, B),)]( x=dw, g=do, N=V*H, B=B, num_warps=32, ) if db is not None: V = db.shape[0] elementwise_mul_kernel[(triton.cdiv(V, B),)]( x=db, g=do, N=V, B=B, num_warps=32, ) return dx, dw, db class FusedLinearCrossEntropyFunction(torch.autograd.Function): @staticmethod @contiguous def forward( ctx, x: torch.Tensor, target: torch.LongTensor, weight: torch.Tensor, bias: torch.Tensor = None, ignore_index: int = -100, label_smoothing: float = 0.0, logit_scale: float = 1.0, num_chunks: int = 8, reduction: str = "mean" ): """ Fusing the last linear layer with cross-entropy loss Reference: https://github.com/mgmalek/efficient_cross_entropy Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can compute the gradient at the forward pass. By doing so, we don't have to store the x and target for the backward pass. x (torch.Tensor): [batch_size * seq_len, hidden_size] target (torch.LongTensor): [batch_size * seq_len] where each value is in [0, vocab_size). weight (torch.Tensor): [vocab_size, hidden_size] where `vocab_size` is the number of classes. bias (Optional[torch.Tensor]): [vocab_size] where `vocab_size` is the number of classes. ignore_index: the index to ignore in the target. label_smoothing: the amount of smoothing when computing the loss, where 0.0 means no smoothing. logit_scale: float = 1.0, A scaling factor applied to the logits. Default: 1.0 num_chunks: int The number of chunks to split the input tensor into for processing. This can help optimize memory usage and computation speed. Default: 8 reduction: Specifies the reduction to apply to the output: 'mean' | 'sum'. 'mean': the weighted mean of the output is taken, 'sum': the output will be summed. Default: 'mean'. """ loss, dx, dw, db = fused_linear_cross_entropy_forward( x, target, weight, bias, ignore_index, label_smoothing, logit_scale, num_chunks, reduction ) # downcast to dtype and store for backward ctx.save_for_backward( dx.detach(), dw.detach() if weight is not None else None, db.detach() if bias is not None else None, ) return loss @staticmethod @contiguous def backward(ctx, do): dx, dw, db = ctx.saved_tensors dx, dw, db = fused_linear_cross_entropy_backward(do, dx, dw, db) return dx, None, dw, db, None, None, None, None, None def fused_linear_cross_entropy_loss( x: torch.Tensor, target: torch.LongTensor, weight: torch.Tensor, bias: torch.Tensor = None, ignore_index: int = -100, label_smoothing: float = 0.0, logit_scale: float = 1.0, num_chunks: int = 8, reduction: str = "mean" ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: x (torch.Tensor): [batch_size * seq_len, hidden_size] target (torch.LongTensor): [batch_size * seq_len] where each value is in [0, vocab_size). weight (torch.Tensor): [vocab_size, hidden_size] where `vocab_size` is the number of classes. bias (Optional[torch.Tensor]): [vocab_size] where `vocab_size` is the number of classes. ignore_index: int. If target == ignore_index, the loss is set to 0.0. label_smoothing: float logit_scale: float A scaling factor applied to the logits. Default: 1.0 num_chunks: int The number of chunks to split the input tensor into for processing. This can help optimize memory usage and computation speed. Default: 8 reduction: Specifies the reduction to apply to the output: 'mean' | 'sum'. 'mean': the weighted mean of the output is taken, 'sum': the output will be summed. Default: 'mean'. Returns: losses: [batch,], float """ return FusedLinearCrossEntropyFunction.apply( x, target, weight, bias, ignore_index, label_smoothing, logit_scale, num_chunks, reduction ) class FusedLinearCrossEntropyLoss(nn.Module): def __init__( self, ignore_index: int = -100, label_smoothing: float = 0.0, logit_scale: float = 1.0, num_chunks: int = 8, reduction: str = "mean" ): """ Args: ignore_index: int. If target == ignore_index, the loss is set to 0.0. label_smoothing: float logit_scale: float A scaling factor applied to the logits. Default: 1.0 num_chunks: int The number of chunks to split the input tensor into for processing. This can help optimize memory usage and computation speed. Default: 8 reduction: Specifies the reduction to apply to the output: 'mean' | 'sum'. 'mean': the weighted mean of the output is taken, 'sum': the output will be summed. Default: 'mean'. """ super().__init__() assert reduction in ["none", "mean", "sum"], f"reduction: {reduction} is not supported" self.ignore_index = ignore_index self.label_smoothing = label_smoothing self.logit_scale = logit_scale self.num_chunks = num_chunks self.reduction = reduction def forward( self, x: torch.Tensor, target: torch.LongTensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None ): """ Args: x (torch.Tensor): [batch_size * seq_len, hidden_size] target (torch.LongTensor): [batch_size * seq_len] where each value is in [0, V). weight (torch.Tensor): [vocab_size, hidden_size] where `vocab_size` is the number of classes. bias (Optional[torch.Tensor]): [vocab_size] where `vocab_size` is the number of classes. Returns: loss """ loss = fused_linear_cross_entropy_loss( x, target, weight=weight, bias=bias, ignore_index=self.ignore_index, label_smoothing=self.label_smoothing, logit_scale=self.logit_scale, num_chunks=self.num_chunks, reduction=self.reduction ) return loss