# -*- coding: utf-8 -*-
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from fla.utils import contiguous
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting the limit to 65536, as in the LayerNorm tutorial, is faster because of less register spilling
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
MAX_FUSED_SIZE = 65536 // 2
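# Worked example (illustrative): for a 32000-entry vocabulary, triton.next_power_of_2(32000) = 32768,
# which equals MAX_FUSED_SIZE, so an entire row of logits is covered by a single block of the kernels below.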
@triton.jit
def kl_div_kernel(
logits,
target_logits,
loss,
s_logits,
s_loss,
reduction: tl.constexpr,
N: tl.constexpr,
V: tl.constexpr,
BV: tl.constexpr
):
# https://github.com/triton-lang/triton/issues/1058
# If N*V is too large, i_n * stride will overflow out of int32, so we convert to int64
i_n = tl.program_id(0).to(tl.int64)
logits += i_n * s_logits
target_logits += i_n * s_logits
# m is the max value. use the notation from the paper
sm, tm = float('-inf'), float('-inf')
# d is the sum. use the notation from the paper
sd, td = 0.0, 0.0
NV = tl.cdiv(V, BV)
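    # Online-softmax recurrence used by the loop below (a sketch of the math, not extra computation):
    #   m_new = max(m_old, max(block));  d_new = d_old * exp(m_old - m_new) + sum(exp(block - m_new))
    # so that after all blocks have been visited, logsumexp(logits) = m + log(d).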
for iv in range(0, NV):
o_x = iv * BV + tl.arange(0, BV)
# for student
b_sl = tl.load(logits + o_x, mask=o_x < V, other=float('-inf'))
b_sm = tl.max(b_sl)
m_new = tl.maximum(sm, b_sm)
sd = sd * tl.exp(sm - m_new) + tl.sum(tl.exp(b_sl - m_new))
sm = m_new
# for teacher
b_tl = tl.load(target_logits + o_x, mask=o_x < V, other=float('-inf'))
b_tm = tl.max(b_tl)
m_new = tl.maximum(tm, b_tm)
td = td * tl.exp(tm - m_new) + tl.sum(tl.exp(b_tl - m_new))
tm = m_new
b_loss = 0.
    # KL(teacher || student): with log-probs lp_t (teacher) and lp_s (student), each summand is exp(lp_t) * (lp_t - lp_s)
for iv in range(0, NV):
o_x = iv * BV + tl.arange(0, BV)
b_sl = tl.load(logits + o_x, mask=o_x < V, other=float('-inf'))
b_tl = tl.load(target_logits + o_x, mask=o_x < V, other=float('-inf'))
b_sp_log = b_sl - sm - tl.log(sd)
b_tp_log = b_tl - tm - tl.log(td)
b_sp = tl.exp(b_sp_log)
b_tp = tl.exp(b_tp_log)
b_kl = tl.where(o_x < V, b_tp * (b_tp_log - b_sp_log), 0)
b_dl = -b_tp + b_sp
b_loss += tl.sum(b_kl)
if reduction == 'batchmean':
b_dl = b_dl / N
tl.store(logits + o_x, b_dl, mask=o_x < V)
    # Normalize the loss by the number of rows N if reduction is 'batchmean'
if reduction == 'batchmean':
b_loss = b_loss / N
tl.store(loss + i_n * s_loss, b_loss)
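# Illustrative only (not part of the original kernels): a minimal PyTorch sketch of what `kl_div_kernel`
# computes for each row, assuming fp32 logits. The kernel additionally overwrites `logits` in place with
# d(loss)/d(logits) = softmax(student) - softmax(teacher) (divided by N for 'batchmean'); the sketch
# returns that gradient explicitly. The helper name is hypothetical.
def _kl_div_reference_row(student_logits, teacher_logits, n_rows, reduction='batchmean'):
    sp_log = torch.log_softmax(student_logits.float(), dim=-1)
    tp_log = torch.log_softmax(teacher_logits.float(), dim=-1)
    tp = tp_log.exp()
    # forward KL(teacher || student), summed over the vocabulary dimension
    loss = (tp * (tp_log - sp_log)).sum(-1)
    # gradient of the per-row loss w.r.t. the student logits
    grad = sp_log.exp() - tp
    if reduction == 'batchmean':
        loss, grad = loss / n_rows, grad / n_rows
    return loss, grad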
@triton.jit
def elementwise_mul_kernel(
x,
g,
N: tl.constexpr,
B: tl.constexpr
):
"""
    This function multiplies each element of the tensor pointed to by x with the value pointed to by g.
    The multiplication is performed in-place on the tensor pointed to by x.
Parameters:
x:
Pointer to the input tensor.
g:
Pointer to the gradient output value.
    N (int):
        The total number of elements in the input tensor.
B (int):
The block size for Triton operations.
"""
# Get the program ID and convert it to int64 to avoid overflow
i_x = tl.program_id(0).to(tl.int64)
o_x = i_x * B + tl.arange(0, B)
# Load the gradient output value
b_g = tl.load(g)
b_x = tl.load(x + o_x, mask=o_x < N)
tl.store(x + o_x, b_x * b_g, mask=o_x < N)
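# Illustrative only: for a flattened tensor `t` and a 0-dim tensor `g`, a launch such as
#   B = min(MAX_FUSED_SIZE, triton.next_power_of_2(t.numel()))
#   elementwise_mul_kernel[(triton.cdiv(t.numel(), B),)](x=t, g=g, N=t.numel(), B=B)
# has the same effect as `t.mul_(g)`; the Triton path is preferred here for the autograd reasons
# noted in `fused_kl_div_backward` below.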
def fused_kl_div_forward(
x: torch.Tensor,
target_x: torch.Tensor,
weight: torch.Tensor,
target_weight: torch.Tensor,
reduction: str = 'batchmean'
):
device = x.device
# ideally, we would like to achieve the same memory consumption as [N, H],
# so the expected chunk size should be:
# NC = ceil(V / H)
# C = ceil(N / NC)
# for ex: N = 4096*4, V = 32000, H = 4096 ==> NC = 8, C = ceil(N / NC) = 2048
N, H, V = *x.shape, weight.shape[0]
BV = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    # TODO: in real cases, we may need to limit the number of chunks NC
    # to ensure the precision of the accumulated gradients
NC = min(8, triton.cdiv(V, H))
C = triton.next_power_of_2(triton.cdiv(N, NC))
NC = triton.cdiv(N, C)
dx = torch.zeros_like(x, device=device)
dw = torch.zeros_like(weight, device=device) if weight is not None else None
    # we use fp32 for the loss accumulator
loss = torch.zeros(N, dtype=torch.float32, device=device)
for ic in range(NC):
start, end = ic * C, min((ic + 1) * C, N)
        # [C, H]
c_sx = x[start:end]
c_tx = target_x[start:end]
# when doing matmul, use the original precision
# [C, V]
c_sl = F.linear(c_sx, weight)
c_tl = F.linear(c_tx, target_weight)
# unreduced loss
c_loss = loss[start:end]
        # the kernel below overwrites c_sl in place with the gradient w.r.t. the logits, so no extra [C, V] buffer is needed
kl_div_kernel[(c_sx.shape[0],)](
logits=c_sl,
target_logits=c_tl,
loss=c_loss,
s_logits=c_sl.stride(-2),
s_loss=c_loss.stride(-1),
reduction=reduction,
N=N,
V=V,
BV=BV,
num_warps=32
)
        # the gradient w.r.t. the logits has been written in place into c_sl by the Triton kernel above, giving a [C, V] tensor;
        # propagating it through the student projection (c_sl @ weight) yields the gradient w.r.t. this chunk of inputs
# [C, H]
dx[start:end] = torch.mm(c_sl, weight)
if weight is not None:
torch.addmm(input=dw, mat1=c_sl.t(), mat2=c_sx, out=dw)
loss = loss.sum()
return loss, dx, dw
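# Illustrative only (not part of the original module): an unfused reference that materializes the full
# [N, V] logits, useful for numerically checking `fused_kl_div_forward`. The helper name is hypothetical.
def _unfused_kl_div_reference(x, target_x, weight, target_weight, reduction='batchmean'):
    sp_log = torch.log_softmax(F.linear(x, weight).float(), dim=-1)
    tp_log = torch.log_softmax(F.linear(target_x, target_weight).float(), dim=-1)
    # `log_target=True` tells F.kl_div that the target is already given as log-probabilities
    return F.kl_div(sp_log, tp_log, reduction=reduction, log_target=True)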
def fused_kl_div_backward(
do: torch.Tensor,
dx: torch.Tensor,
dw: torch.Tensor
):
    # If the KL-divergence loss is the last layer, do is 1.0. Skip the mul to save time
if torch.ne(do, torch.tensor(1.0, device=do.device)):
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
N, H = dx.shape
B = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
elementwise_mul_kernel[(triton.cdiv(N * H, B),)](
x=dx,
g=do,
N=N*H,
B=B,
num_warps=32,
)
# handle dw
if dw is not None:
V, H = dw.shape
elementwise_mul_kernel[(triton.cdiv(V * H, B),)](
x=dw,
g=do,
N=V*H,
B=B,
num_warps=32,
)
return dx, dw
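# Illustrative only: outside of autograd, the two launches above amount to `dx.mul_(do)` and
# `dw.mul_(do)`; the Triton kernel is used to sidestep the in-place/backward anomalies noted above.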
class FusedKLDivLossFunction(torch.autograd.Function):
@staticmethod
@contiguous
def forward(
ctx,
x: torch.Tensor,
target_x: torch.Tensor,
weight: torch.Tensor,
target_weight: torch.Tensor,
reduction: str
):
loss, dx, dw = fused_kl_div_forward(
x=x,
target_x=target_x,
weight=weight,
target_weight=target_weight,
reduction=reduction
)
ctx.save_for_backward(dx, dw)
return loss
@staticmethod
@contiguous
def backward(ctx, do):
dx, dw = ctx.saved_tensors
dx, dw = fused_kl_div_backward(do, dx, dw)
return dx, None, dw, None, None
def fused_kl_div_loss(
x: torch.Tensor,
target_x: torch.Tensor,
weight: torch.Tensor,
target_weight: torch.Tensor,
reduction: str = 'batchmean'
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x (torch.Tensor): [batch_size * seq_len, hidden_size]
target_x (torch.Tensor): [batch_size * seq_len, hidden_size]
weight (torch.Tensor): [vocab_size, hidden_size]
where `vocab_size` is the number of classes.
target_weight (torch.Tensor): [vocab_size, hidden_size]
where `vocab_size` is the number of classes.
reduction:
            Specifies the reduction to apply to the output. Only 'batchmean' is currently supported. Default: 'batchmean'.
Returns:
loss
"""
return FusedKLDivLossFunction.apply(
x,
target_x,
weight,
target_weight,
reduction
)
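# Illustrative usage (hypothetical shapes), e.g. distilling a student LM head against a teacher head:
#   x  = torch.randn(8 * 1024, 4096, device='cuda', dtype=torch.bfloat16, requires_grad=True)
#   tx = torch.randn(8 * 1024, 4096, device='cuda', dtype=torch.bfloat16)
#   w  = torch.randn(32000, 4096, device='cuda', dtype=torch.bfloat16, requires_grad=True)
#   tw = torch.randn(32000, 4096, device='cuda', dtype=torch.bfloat16)
#   loss = fused_kl_div_loss(x, tx, w, tw)
#   loss.backward()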
class FusedKLDivLoss(nn.Module):
def __init__(
self,
reduction: str = 'batchmean'
):
"""
Args:
reduction:
                Specifies the reduction to apply to the output. Only 'batchmean' is currently supported. Default: 'batchmean'.
"""
super().__init__()
assert reduction in ['batchmean'], f"reduction: {reduction} is not supported"
self.reduction = reduction
def forward(
self,
x: torch.Tensor,
target_x: torch.Tensor,
weight: torch.Tensor,
target_weight: torch.Tensor
):
"""
Args:
x (torch.Tensor): [batch_size * seq_len, hidden_size]
target_x (torch.Tensor): [batch_size * seq_len, hidden_size]
weight (torch.Tensor): [vocab_size, hidden_size]
where `vocab_size` is the number of classes.
target_weight (torch.Tensor): [vocab_size, hidden_size]
where `vocab_size` is the number of classes.
Returns:
loss
"""
loss = fused_kl_div_loss(
x=x,
target_x=target_x,
weight=weight,
target_weight=target_weight,
reduction=self.reduction
)
return loss
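# Illustrative self-test (not part of the original module); it only runs when this file is executed
# directly and assumes a CUDA device with Triton available. Shapes are arbitrary.
if __name__ == '__main__':
    criterion = FusedKLDivLoss(reduction='batchmean')
    x = torch.randn(16, 128, device='cuda', requires_grad=True)
    target_x = torch.randn(16, 128, device='cuda')
    weight = torch.randn(1000, 128, device='cuda', requires_grad=True)
    target_weight = torch.randn(1000, 128, device='cuda')
    loss = criterion(x, target_x, weight, target_weight)
    loss.backward()
    print(loss.item(), x.grad.shape, weight.grad.shape)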