from typing import Tuple, Optional

import torch
import torch.nn.functional as F

import triton
import triton.language as tl


# torch.distributed.all_gather_into_tensor only exists in newer PyTorch releases;
# fall back to the older private _all_gather_base when it is missing.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base


@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_fwd_kernel(
    loss_ptr,
    lse_ptr,
    z_loss_ptr,
    logits_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,
    n_cols,
    logits_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
    SPLIT: tl.constexpr,
    PRECOMPUTED_LSE: tl.constexpr,
):
    # One program per row: compute the row's logsumexp (unless precomputed) and its loss.
    row_idx = tl.program_id(0)
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    sum_logits = 0.0  # only used when HAS_SMOOTHING
    if not PRECOMPUTED_LSE:
        # Online (streaming) logsumexp over the row, one BLOCK_SIZE chunk at a time.
        m_i = -float("inf")
        l_i = 0.0
        for col_offset in range(0, n_cols, BLOCK_SIZE):
            cols = col_offset + tl.arange(0, BLOCK_SIZE)
            logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
                tl.float32
            ) * logit_scale
            if HAS_SMOOTHING:
                sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
            m_i_new = tl.maximum(m_i, tl.max(logits))
            l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
            m_i = m_i_new
        lse = tl.log(l_i) + m_i
        tl.store(lse_ptr + row_idx, lse)
    else:
        lse = tl.load(lse_ptr + row_idx)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx == ignore_index:
        loss = 0.0
        z_loss = 0.0
    else:
        label_idx -= class_start_idx
        if label_idx >= 0 and label_idx < n_cols:
            logits_label = tl.load(logits_ptr + label_idx) * logit_scale
            if HAS_SMOOTHING:
                loss = (
                    (lse if not SPLIT else 0.0)
                    - smoothing * sum_logits / total_classes
                    - (1 - smoothing) * logits_label
                )
            else:
                loss = (lse if not SPLIT else 0.0) - logits_label
        else:
            # The label lives in another vocab partition (tensor-parallel case).
            if HAS_SMOOTHING:
                loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
            else:
                loss = 0.0
        if not SPLIT:
            z_loss = lse_square_scale * lse * lse
            loss += z_loss
        else:
            z_loss = 0.0
    tl.store(loss_ptr + row_idx, loss)
    if not SPLIT:
        tl.store(z_loss_ptr + row_idx, z_loss)
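

# For reference only: a plain PyTorch sketch of the per-row loss that the forward kernel
# computes in the single-partition (SPLIT=False), non-precomputed-LSE case. This helper
# and its name `_reference_cross_entropy_fwd` are illustrative additions, not part of the
# original kernels.
def _reference_cross_entropy_fwd(
    logits, labels, smoothing=0.0, logit_scale=1.0, lse_square_scale=0.0, ignore_index=-100
):
    scaled = logits.float() * logit_scale
    lse = torch.logsumexp(scaled, dim=-1)
    nll = lse - scaled.gather(-1, labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)
    if smoothing > 0.0:
        # Matches the kernel: (1 - smoothing) * nll + smoothing * (lse - mean(scaled logits)).
        nll = (1.0 - smoothing) * nll + smoothing * (lse - scaled.mean(dim=-1))
    loss = nll + lse_square_scale * lse.square()
    return loss.masked_fill(labels == ignore_index, 0.0)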


@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_bwd_kernel(
    dlogits_ptr,
    dloss_ptr,
    logits_ptr,
    lse_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,
    n_cols,
    logits_row_stride,
    dlogits_row_stride,
    dloss_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
):
    # 2D grid: one program per (row, column block) of the dense gradient.
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx != ignore_index:
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
    else:
        dloss = 0.0
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    ) * logit_scale
    lse = tl.load(lse_ptr + row_idx)
    probs = tl.exp(logits - lse)
    # Gradient contribution of the z-loss term lse_square_scale * lse^2.
    probs += 2.0 * lse_square_scale * lse * probs
    label_idx -= class_start_idx
    if HAS_SMOOTHING:
        smooth_positive = 1.0 - smoothing
        smooth_negative = smoothing / total_classes
        probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative
    else:
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
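

# For reference only: the dense gradient that cross_entropy_bwd_kernel materialises, sketched
# in plain PyTorch for the single-partition case. This helper and its name
# `_reference_cross_entropy_bwd` are illustrative additions, not part of the original kernels.
def _reference_cross_entropy_bwd(
    dloss, logits, labels, smoothing=0.0, logit_scale=1.0, lse_square_scale=0.0, ignore_index=-100
):
    scaled = logits.float() * logit_scale
    lse = torch.logsumexp(scaled, dim=-1, keepdim=True)
    # Softmax probabilities, plus the z-loss contribution 2 * lse_square_scale * lse * probs.
    probs = torch.exp(scaled - lse) * (1.0 + 2.0 * lse_square_scale * lse)
    one_hot = F.one_hot(labels.clamp(min=0), num_classes=logits.shape[-1]).to(probs.dtype)
    if smoothing > 0.0:
        target = (1.0 - smoothing) * one_hot + smoothing / logits.shape[-1]
    else:
        target = one_hot
    dloss = dloss.masked_fill(labels == ignore_index, 0.0).unsqueeze(-1)
    return dloss * logit_scale * (probs - target)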


@torch.library.custom_op("flasht5::cross_entropy_triton_fwd", mutates_args=(), device_types="cuda")
def cross_entropy_triton_fwd(
    logits: torch.Tensor,
    labels: torch.Tensor,
    precomputed_lse: Optional[torch.Tensor],
    use_precomputed_lse: bool,
    split: bool,
    smoothing: float,
    logit_scale: float,
    lse_square_scale: float,
    ignore_index: int,
    total_classes: int,
    class_start_idx: int,
    n_cols: int,
    n_rows: int,
    BLOCK_SIZE: int,
    num_warps: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if logits.stride(-1) != 1:
        logits = logits.contiguous()

    losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
    if use_precomputed_lse:
        assert precomputed_lse.shape == (n_rows,)
        lse = precomputed_lse.contiguous()
    else:
        lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)

    z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)

    with torch.cuda.device(logits.device.index):
        cross_entropy_fwd_kernel[(n_rows,)](
            losses,
            lse,
            z_losses,
            logits,
            labels,
            smoothing,
            logit_scale,
            lse_square_scale,
            ignore_index,
            total_classes,
            class_start_idx,
            n_cols,
            logits.stride(0),
            BLOCK_SIZE=BLOCK_SIZE,
            SPLIT=split,
            PRECOMPUTED_LSE=use_precomputed_lse,
            num_warps=num_warps,
        )

    return losses, z_losses, lse


@torch.library.register_fake("flasht5::cross_entropy_triton_fwd")
def cross_entropy_triton_fwd_abstract(
    logits, labels, precomputed_lse, use_precomputed_lse, split, smoothing, logit_scale,
    lse_square_scale, ignore_index, total_classes, class_start_idx, n_cols, n_rows,
    BLOCK_SIZE, num_warps,
):
    losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
    z_losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
    logsumexp = torch.empty(n_rows, dtype=torch.float32, device=logits.device)

    return losses, z_losses, logsumexp


@torch.library.custom_op("flasht5::cross_entropy_triton_bwd", mutates_args={"logits"}, device_types="cuda")
def cross_entropy_triton_bwd(
    dlosses: torch.Tensor,
    logits: torch.Tensor,
    lse: torch.Tensor,
    labels: torch.Tensor,
    inplace_backward: bool,
    smoothing: float,
    logit_scale: float,
    lse_square_scale: float,
    ignore_index: int,
    total_classes: int,
    class_start_idx: int,
    n_cols: int,
    n_rows: int,
    BLOCK_SIZE: int,
    num_warps: int,
) -> torch.Tensor:
    dlogits = logits if inplace_backward else torch.empty_like(logits)

    grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"]))

    with torch.cuda.device(logits.device.index):
        cross_entropy_bwd_kernel[grid](
            dlogits,
            dlosses,
            logits,
            lse,
            labels,
            smoothing,
            logit_scale,
            lse_square_scale,
            ignore_index,
            total_classes,
            class_start_idx,
            n_cols,
            logits.stride(0),
            dlogits.stride(0),
            dlosses.stride(0),
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    return dlogits if not inplace_backward else None


@torch.library.register_fake("flasht5::cross_entropy_triton_bwd")
def cross_entropy_triton_bwd_abstract(
    dlosses, logits, lse, labels, inplace_backward, smoothing, logit_scale,
    lse_square_scale, ignore_index, total_classes, class_start_idx, n_cols, n_rows,
    BLOCK_SIZE, num_warps,
):
    return torch.empty_like(logits)


class CrossEntropyLoss(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        precomputed_lse=None,
        smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        ignore_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        # Triton can mis-load int64 labels whose address is not 16-byte aligned,
        # so re-materialize them at an aligned address if needed.
        if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:
            labels = F.pad(labels, (0, 1))[..., :-1]
            assert labels.data_ptr() % 16 == 0

        n_rows, n_cols = logits.shape
        assert labels.shape == (n_rows,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        total_classes = world_size * n_cols
        rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
        class_start_idx = rank * n_cols
        use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0

        MAX_BLOCK_SIZE = 16 * 1024
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
        num_warps = (
            4
            if BLOCK_SIZE < 2048
            else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
        )

        losses, z_losses, lse = torch.ops.flasht5.cross_entropy_triton_fwd(
            logits, labels, precomputed_lse, use_precomputed_lse,
            world_size > 1, smoothing, logit_scale, lse_square_scale,
            ignore_index, total_classes, class_start_idx,
            n_cols, n_rows, BLOCK_SIZE, num_warps,
        )

        if world_size > 1:
            # With SPLIT=True each rank only computed its partial (negative) logit terms,
            # so gather the per-rank lse, sum the partial losses across ranks, then add
            # the global lse and the optional z-loss here.
            lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
            torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
            lse = torch.logsumexp(lse_allgather, dim=0)
            handle_losses.wait()
            losses += lse
            if lse_square_scale != 0.0:
                z_losses = lse_square_scale * lse.square()
                z_losses.masked_fill_(labels == ignore_index, 0.0)
                losses += z_losses
            else:
                z_losses = torch.zeros_like(losses)
            losses.masked_fill_(labels == ignore_index, 0.0)

        ctx.save_for_backward(logits, lse, labels)
        ctx.mark_non_differentiable(z_losses)
        ctx.smoothing = smoothing
        ctx.logit_scale = logit_scale
        ctx.lse_square_scale = lse_square_scale
        ctx.ignore_index = ignore_index
        ctx.total_classes = total_classes
        ctx.class_start_idx = class_start_idx
        ctx.inplace_backward = inplace_backward

        return losses, z_losses

    @staticmethod
    def backward(ctx, grad_losses, grad_z_losses):
        del grad_z_losses  # z_losses are marked non-differentiable

        logits, lse, labels = ctx.saved_tensors

        n_rows, n_cols = logits.shape
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)

        dlogits = torch.ops.flasht5.cross_entropy_triton_bwd(
            grad_losses, logits, lse, labels,
            ctx.inplace_backward, ctx.smoothing, ctx.logit_scale,
            ctx.lse_square_scale, ctx.ignore_index, ctx.total_classes,
            ctx.class_start_idx, n_cols, n_rows, BLOCK_SIZE, num_warps,
        )

        if ctx.inplace_backward:
            dlogits = logits

        # One gradient per forward input: logits, labels, precomputed_lse, smoothing,
        # logit_scale, lse_square_scale, ignore_index, inplace_backward, process_group.
        return dlogits, None, None, None, None, None, None, None, None


def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    precomputed_lse: Optional[torch.Tensor] = None,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    lse_square_scale: float = 0.0,
    ignore_index=-100,
    inplace_backward: bool = False,
    process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        logits: (batch, vocab_size)
        labels: (batch,)
        precomputed_lse: (batch,), optional. Precomputed logsumexp of the logits; only used
            when logit_scale == 1.0 and label_smoothing == 0.0.
        label_smoothing: float
        logit_scale: float. Multiply logits by this scale before calculating the loss.
        lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
            This is also referred to as "z-loss".
        ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
        inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
            This saves memory.
        process_group: if not None, we're doing Tensor Parallel: each process is responsible for
            one part of the vocab. The loss will be aggregated across processes.
    Returns:
        losses: (batch,), float
        z_losses: (batch,), float
    """
    return CrossEntropyLoss.apply(
        logits.view(-1, logits.shape[-1]),
        labels.view(-1),
        precomputed_lse,
        label_smoothing,
        logit_scale,
        lse_square_scale,
        ignore_index,
        inplace_backward,
        process_group,
    )
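

# Example usage and a quick parity check against torch.nn.functional.cross_entropy.
# This is an illustrative sketch, not part of the original module; it assumes a CUDA
# device with Triton available, and the batch size and vocab size are arbitrary.
if __name__ == "__main__":
    if torch.cuda.is_available():
        torch.manual_seed(0)
        logits = torch.randn(8, 32128, device="cuda", requires_grad=True)
        labels = torch.randint(0, 32128, (8,), device="cuda")
        labels[0] = -100  # one ignored position

        losses, z_losses = cross_entropy_loss(logits, labels)
        ref = F.cross_entropy(logits, labels, reduction="none", ignore_index=-100)
        torch.testing.assert_close(losses, ref, rtol=1e-4, atol=1e-4)

        losses.mean().backward()
        print("max |dlogits|:", logits.grad.abs().max().item())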