import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl

from fla.utils import contiguous

# Upper bound on the Triton block size used by the kernels below;
# 65536 // 2 = 32768 elements per block keeps register pressure manageable.
MAX_FUSED_SIZE = 65536 // 2
|
|
@triton.jit
def kl_div_kernel(
    logits,
    target_logits,
    loss,
    s_logits,
    s_loss,
    reduction: tl.constexpr,
    N: tl.constexpr,
    V: tl.constexpr,
    BV: tl.constexpr
):
    # Each program handles one row of logits.
    # Cast to int64 so that i_n * stride does not overflow int32 for large N * V.
    i_n = tl.program_id(0).to(tl.int64)

    logits += i_n * s_logits
    target_logits += i_n * s_logits

    # Running maxima (sm, tm) and running sums of exponentials (sd, td)
    # for the student and target rows, maintained in online-softmax fashion.
    sm, tm = float('-inf'), float('-inf')
    sd, td = 0.0, 0.0

    NV = tl.cdiv(V, BV)
    # First pass: compute the log-sum-exp statistics of both rows block by block.
    for iv in range(0, NV):
        o_x = iv * BV + tl.arange(0, BV)

        b_sl = tl.load(logits + o_x, mask=o_x < V, other=float('-inf'))
        b_sm = tl.max(b_sl)
        m_new = tl.maximum(sm, b_sm)
        sd = sd * tl.exp(sm - m_new) + tl.sum(tl.exp(b_sl - m_new))
        sm = m_new

        b_tl = tl.load(target_logits + o_x, mask=o_x < V, other=float('-inf'))
        b_tm = tl.max(b_tl)
        m_new = tl.maximum(tm, b_tm)
        td = td * tl.exp(tm - m_new) + tl.sum(tl.exp(b_tl - m_new))
        tm = m_new

    b_loss = 0.

    # Second pass: accumulate KL(target || student) and write the gradient
    # w.r.t. the student logits in place of `logits`.
    for iv in range(0, NV):
        o_x = iv * BV + tl.arange(0, BV)
        b_sl = tl.load(logits + o_x, mask=o_x < V, other=float('-inf'))
        b_tl = tl.load(target_logits + o_x, mask=o_x < V, other=float('-inf'))
        # log-probabilities of the student and target distributions
        b_sp_log = b_sl - sm - tl.log(sd)
        b_tp_log = b_tl - tm - tl.log(td)
        b_sp = tl.exp(b_sp_log)
        b_tp = tl.exp(b_tp_log)
        b_kl = tl.where(o_x < V, b_tp * (b_tp_log - b_sp_log), 0)
        # d loss / d logits = softmax(logits) - softmax(target_logits)
        b_dl = -b_tp + b_sp
        b_loss += tl.sum(b_kl)
        if reduction == 'batchmean':
            b_dl = b_dl / N
        tl.store(logits + o_x, b_dl, mask=o_x < V)

    if reduction == 'batchmean':
        b_loss = b_loss / N

    tl.store(loss + i_n * s_loss, b_loss)
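

# For reference, a minimal unfused PyTorch sketch of what `kl_div_kernel` computes
# for a full batch of rows. This helper is illustrative only and is not used by the
# fused path; the name `_kl_div_reference` is not part of the original module.
def _kl_div_reference(logits: torch.Tensor, target_logits: torch.Tensor, reduction: str = 'batchmean'):
    sp_log = F.log_softmax(logits, dim=-1)         # student log-probs
    tp_log = F.log_softmax(target_logits, dim=-1)  # target log-probs
    tp = tp_log.exp()
    # per-row KL(target || student); the fused path sums these rows afterwards
    kl = (tp * (tp_log - sp_log)).sum(-1)
    # gradient w.r.t. the student logits, which the kernel writes back into `logits`
    d_logits = sp_log.exp() - tp
    if reduction == 'batchmean':
        # the kernel divides by the global row count N; here that is logits.shape[0]
        kl = kl / logits.shape[0]
        d_logits = d_logits / logits.shape[0]
    return kl, d_logits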
|
|
@triton.jit
def elementwise_mul_kernel(
    x,
    g,
    N: tl.constexpr,
    B: tl.constexpr
):
    """
    Multiplies each element of the tensor pointed to by `x` by the scalar value pointed
    to by `g`. The multiplication is performed in place on `x`.

    Parameters:
    x:
        Pointer to the input tensor.
    g:
        Pointer to the gradient output value.
    N (int):
        The total number of elements in the input tensor.
    B (int):
        The block size for Triton operations.
    """
    # Each program scales one block of B elements.
    i_x = tl.program_id(0).to(tl.int64)
    o_x = i_x * B + tl.arange(0, B)

    # Load the scalar gradient once, then scale the block in place.
    b_g = tl.load(g)
    b_x = tl.load(x + o_x, mask=o_x < N)
    tl.store(x + o_x, b_x * b_g, mask=o_x < N)
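

# Example launch (illustrative names): scale all `numel` elements of a flattened
# tensor `buf` in place by a scalar gradient `grad`:
#   elementwise_mul_kernel[(triton.cdiv(numel, B),)](x=buf, g=grad, N=numel, B=B)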
|
|
def fused_kl_div_forward(
    x: torch.Tensor,
    target_x: torch.Tensor,
    weight: torch.Tensor,
    target_weight: torch.Tensor,
    reduction: str = 'batchmean'
):
    device = x.device

    # Inputs are [N, H]; the projection weights are [V, H].
    # The [N, V] logits are materialized chunk by chunk so that each [C, V] chunk
    # is roughly the size of the [N, H] inputs rather than the full logits.
    N, H, V = *x.shape, weight.shape[0]
    BV = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

    # Number of chunks NC and chunk length C along the N dimension.
    NC = min(8, triton.cdiv(V, H))
    C = triton.next_power_of_2(triton.cdiv(N, NC))
    NC = triton.cdiv(N, C)

    dx = torch.zeros_like(x, device=device)
    dw = torch.zeros_like(weight, device=device) if weight is not None else None

    # fp32 accumulator for the per-row losses
    loss = torch.zeros(N, dtype=torch.float32, device=device)

    for ic in range(NC):
        start, end = ic * C, min((ic + 1) * C, N)

        # [C, H] chunks of the student and target inputs
        c_sx = x[start:end]
        c_tx = target_x[start:end]

        # [C, V] chunks of the student and target logits
        c_sl = F.linear(c_sx, weight)
        c_tl = F.linear(c_tx, target_weight)

        # per-row losses for this chunk
        c_loss = loss[start:end]

        # The kernel accumulates the loss and overwrites c_sl with the gradient
        # of the loss w.r.t. the student logits.
        kl_div_kernel[(c_sx.shape[0],)](
            logits=c_sl,
            target_logits=c_tl,
            loss=c_loss,
            s_logits=c_sl.stride(-2),
            s_loss=c_loss.stride(-1),
            reduction=reduction,
            N=N,
            V=V,
            BV=BV,
            num_warps=32
        )

        # Backpropagate the logit gradients through the linear projection:
        # dx[start:end] = d_logits @ weight, dw += d_logits^T @ c_sx
        dx[start:end] = torch.mm(c_sl, weight)

        if weight is not None:
            torch.addmm(input=dw, mat1=c_sl.t(), mat2=c_sx, out=dw)

    loss = loss.sum()
    return loss, dx, dw
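

# Worked example of the chunking above (illustrative sizes): with N = 16384 rows,
# H = 4096 and V = 32000, NC = min(8, cdiv(32000, 4096)) = 8, the chunk length is
# C = next_power_of_2(cdiv(16384, 8)) = 2048, and NC becomes cdiv(16384, 2048) = 8,
# so each [2048, 32000] logits chunk is freed before the next one is built.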
|
|
def fused_kl_div_backward(
    do: torch.Tensor,
    dx: torch.Tensor,
    dw: torch.Tensor
):
    # If this loss is the last layer, the incoming gradient `do` is exactly 1.0
    # and the rescaling below can be skipped entirely.
    if torch.ne(do, torch.tensor(1.0, device=do.device)):
        # Scale the precomputed input gradients by the incoming gradient, in place.
        N, H = dx.shape
        B = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))

        elementwise_mul_kernel[(triton.cdiv(N * H, B),)](
            x=dx,
            g=do,
            N=N*H,
            B=B,
            num_warps=32,
        )

        # Scale the weight gradients as well, if they were computed.
        if dw is not None:
            V, H = dw.shape
            elementwise_mul_kernel[(triton.cdiv(V * H, B),)](
                x=dw,
                g=do,
                N=V*H,
                B=B,
                num_warps=32,
            )

    return dx, dw
|
|
class FusedKLDivLossFunction(torch.autograd.Function):

    @staticmethod
    @contiguous
    def forward(
        ctx,
        x: torch.Tensor,
        target_x: torch.Tensor,
        weight: torch.Tensor,
        target_weight: torch.Tensor,
        reduction: str
    ):
        loss, dx, dw = fused_kl_div_forward(
            x=x,
            target_x=target_x,
            weight=weight,
            target_weight=target_weight,
            reduction=reduction
        )
        # The gradients are computed during the forward pass and only rescaled in backward.
        ctx.save_for_backward(dx, dw)
        return loss

    @staticmethod
    @contiguous
    def backward(ctx, do):
        dx, dw = ctx.saved_tensors
        dx, dw = fused_kl_div_backward(do, dx, dw)
        return dx, None, dw, None, None
|
|
def fused_kl_div_loss(
    x: torch.Tensor,
    target_x: torch.Tensor,
    weight: torch.Tensor,
    target_weight: torch.Tensor,
    reduction: str = 'batchmean'
) -> torch.Tensor:
    """
    Args:
        x (torch.Tensor): [batch_size * seq_len, hidden_size]
        target_x (torch.Tensor): [batch_size * seq_len, hidden_size]
        weight (torch.Tensor): [vocab_size, hidden_size]
            where `vocab_size` is the number of classes.
        target_weight (torch.Tensor): [vocab_size, hidden_size]
            where `vocab_size` is the number of classes.
        reduction:
            Specifies the reduction to apply to the output: 'batchmean'. Default: 'batchmean'.
    Returns:
        loss
    """
    return FusedKLDivLossFunction.apply(
        x,
        target_x,
        weight,
        target_weight,
        reduction
    )
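

# Example (illustrative shapes, assuming a CUDA device): compare a student head
# against a frozen teacher head without materializing the full logits at once.
#   x        = torch.randn(8 * 128, 4096, device='cuda', requires_grad=True)
#   target_x = torch.randn(8 * 128, 4096, device='cuda')
#   weight        = torch.randn(32000, 4096, device='cuda', requires_grad=True)
#   target_weight = torch.randn(32000, 4096, device='cuda')
#   loss = fused_kl_div_loss(x, target_x, weight, target_weight, reduction='batchmean')
#   loss.backward()  # populates x.grad and weight.grad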
|
|
class FusedKLDivLoss(nn.Module):

    def __init__(
        self,
        reduction: str = 'batchmean'
    ):
        """
        Args:
            reduction:
                Specifies the reduction to apply to the output: 'batchmean'. Default: 'batchmean'.
        """
        super().__init__()

        assert reduction in ['batchmean'], f"reduction: {reduction} is not supported"

        self.reduction = reduction

    def forward(
        self,
        x: torch.Tensor,
        target_x: torch.Tensor,
        weight: torch.Tensor,
        target_weight: torch.Tensor
    ):
        """
        Args:
            x (torch.Tensor): [batch_size * seq_len, hidden_size]
            target_x (torch.Tensor): [batch_size * seq_len, hidden_size]
            weight (torch.Tensor): [vocab_size, hidden_size]
                where `vocab_size` is the number of classes.
            target_weight (torch.Tensor): [vocab_size, hidden_size]
                where `vocab_size` is the number of classes.
        Returns:
            loss
        """
        loss = fused_kl_div_loss(
            x=x,
            target_x=target_x,
            weight=weight,
            target_weight=target_weight,
            reduction=self.reduction
        )
        return loss
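

# A minimal usage sketch (illustrative sizes, not part of the original module): run this
# file directly on a CUDA device with Triton installed to sanity-check the fused loss.
if __name__ == '__main__':
    criterion = FusedKLDivLoss(reduction='batchmean')
    x = torch.randn(1024, 512, device='cuda', requires_grad=True)
    target_x = torch.randn(1024, 512, device='cuda')
    weight = torch.randn(2048, 512, device='cuda', requires_grad=True)
    target_weight = torch.randn(2048, 512, device='cuda')
    loss = criterion(x, target_x, weight, target_weight)
    loss.backward()
    print(loss.item(), x.grad.shape, weight.grad.shape)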
|
|