# Copyright (c) 2023, Tri Dao.
# Copyright 2024 CATIE. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modification to the original version from Tri Dao:
# - support for torch.compile
from typing import Tuple, Optional
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
# `all_gather_into_tensor` is the new name of `_all_gather_base` and is only available in
# recent versions of PyTorch. The following two lines keep backward compatibility with
# older PyTorch releases.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_fwd_kernel(
    loss_ptr,  # data ptrs
    lse_ptr,
    z_loss_ptr,
    logits_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    logits_row_stride,  # strides
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
    # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
    SPLIT: tl.constexpr,
    PRECOMPUTED_LSE: tl.constexpr,  # If LSE is already computed (also no smoothing and logit_scale == 1.0)
):
    row_idx = tl.program_id(0)
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    sum_logits = 0.0  # For smoothing
    if not PRECOMPUTED_LSE:
        # Statistics for online softmax
        m_i = -float("inf")
        l_i = 0.0
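        # Online logsumexp over column blocks: m_i tracks the running max and l_i the
        # running sum of exp(logits - m_i), so lse = log(l_i) + m_i is obtained without
        # materializing the whole row at once.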
        for col_offset in range(0, n_cols, BLOCK_SIZE):
            cols = col_offset + tl.arange(0, BLOCK_SIZE)
            logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
                tl.float32
            ) * logit_scale
            if HAS_SMOOTHING:
                sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
            m_i_new = tl.maximum(m_i, tl.max(logits))
            l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
            m_i = m_i_new
        lse = tl.log(l_i) + m_i
        tl.store(lse_ptr + row_idx, lse)
    else:
        lse = tl.load(lse_ptr + row_idx)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx == ignore_index:
        loss = 0.0
        z_loss = 0.0
    else:
        label_idx -= class_start_idx
        if label_idx >= 0 and label_idx < n_cols:
            logits_label = tl.load(logits_ptr + label_idx) * logit_scale
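            # Per-row (smoothed) cross entropy:
            #   lse - (1 - smoothing) * logit_label - smoothing * sum(logits) / total_classes;
            # when SPLIT, the lse term is left out and added back after the cross-rank reduction.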
            if HAS_SMOOTHING:
                loss = (
                    (lse if not SPLIT else 0.0)
                    - smoothing * sum_logits / total_classes
                    - (1 - smoothing) * logits_label
                )
            else:
                loss = (lse if not SPLIT else 0.0) - logits_label
        else:
            # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
            if HAS_SMOOTHING:
                loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
            else:
                loss = 0.0
        if not SPLIT:
            z_loss = lse_square_scale * lse * lse
            loss += z_loss
        else:
            z_loss = 0.0
    tl.store(loss_ptr + row_idx, loss)
    if not SPLIT:
        tl.store(z_loss_ptr + row_idx, z_loss)
@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_bwd_kernel(
    dlogits_ptr,  # data ptrs
    dloss_ptr,
    logits_ptr,
    lse_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    logits_row_stride,  # strides
    dlogits_row_stride,
    dloss_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx != ignore_index:
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
    else:
        dloss = 0.0
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    ) * logit_scale
    lse = tl.load(lse_ptr + row_idx)
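    # Gradient of the per-row loss: softmax(logits) minus the (smoothed) one-hot target,
    # plus 2 * lse_square_scale * lse * probs for the z-loss term lse_square_scale * lse^2;
    # the result is scaled by dloss * logit_scale when stored below.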
    probs = tl.exp(logits - lse)
    probs += 2.0 * lse_square_scale * lse * probs
    label_idx -= class_start_idx
    if HAS_SMOOTHING:
        smooth_positive = 1.0 - smoothing
        smooth_negative = smoothing / total_classes
        probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative
    else:
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
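# The Triton launches below are wrapped in torch.library custom ops, with fake (meta)
# implementations registered for each, so that torch.compile can trace models using this
# loss while treating the kernels as opaque operations (the "support for torch.compile"
# modification mentioned in the header).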
@torch.library.custom_op("flasht5::cross_entropy_triton_fwd", mutates_args=(), device_types="cuda")
def cross_entropy_triton_fwd(
    logits: torch.Tensor,
    labels: torch.Tensor,
    precomputed_lse: torch.Tensor,
    use_precomputed_lse: bool,
    split: bool,
    smoothing: float,
    logit_scale: float,
    lse_square_scale: float,
    ignore_index: int,
    total_classes: int,
    class_start_idx: int,
    n_cols: int,
    n_rows: int,
    BLOCK_SIZE: int,
    num_warps: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if logits.stride(-1) != 1:
        logits = logits.contiguous()
    losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
    if use_precomputed_lse:
        assert precomputed_lse.shape == (n_rows,)
        lse = precomputed_lse.contiguous()
    else:
        lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
    z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
    # Need this, otherwise Triton tries to launch from cuda:0 and we get
    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
    with torch.cuda.device(logits.device.index):
        cross_entropy_fwd_kernel[(n_rows,)](
            losses,  # data ptrs
            lse,
            z_losses,
            logits,
            labels,
            smoothing,
            logit_scale,
            lse_square_scale,
            ignore_index,
            total_classes,
            class_start_idx,
            n_cols,  # shapes
            logits.stride(0),  # strides
            BLOCK_SIZE=BLOCK_SIZE,  # constants
            SPLIT=split,
            PRECOMPUTED_LSE=use_precomputed_lse,
            num_warps=num_warps,
        )
    return losses, z_losses, lse
@torch.library.register_fake("flasht5::cross_entropy_triton_fwd")
def cross_entropy_triton_fwd_abstract(logits, labels, precomputed_lse, use_precomputed_lse, split, smoothing, logit_scale, lse_square_scale, ignore_index, total_classes, class_start_idx, n_cols, n_rows, BLOCK_SIZE, num_warps):
    losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
    z_losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
    logsumexp = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
    return losses, z_losses, logsumexp
@torch.library.custom_op("flasht5::cross_entropy_triton_bwd", mutates_args={"logits"}, device_types="cuda")
def cross_entropy_triton_bwd(
    dlosses: torch.Tensor,
    logits: torch.Tensor,
    lse: torch.Tensor,
    labels: torch.Tensor,
    inplace_backward: bool,
    smoothing: float,
    logit_scale: float,
    lse_square_scale: float,
    ignore_index: int,
    total_classes: int,
    class_start_idx: int,
    n_cols: int,
    n_rows: int,
    BLOCK_SIZE: int,
    num_warps: int
) -> torch.Tensor:
    dlogits = logits if inplace_backward else torch.empty_like(logits)
    grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"]))  # noqa
    # Need this, otherwise Triton tries to launch from cuda:0 and we get
    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
    with torch.cuda.device(logits.device.index):
        cross_entropy_bwd_kernel[grid](
            dlogits,  # data ptrs
            dlosses,
            logits,
            lse,
            labels,
            smoothing,
            logit_scale,
            lse_square_scale,
            ignore_index,
            total_classes,
            class_start_idx,
            n_cols,  # shapes
            logits.stride(0),  # strides
            dlogits.stride(0),
            dlosses.stride(0),
            BLOCK_SIZE=BLOCK_SIZE,  # constants
            num_warps=num_warps,
        )
    return dlogits if not inplace_backward else None
@torch.library.register_fake("flasht5::cross_entropy_triton_bwd")
def cross_entropy_triton_bwd_abstract(dlosses, logits, lse, labels, inplace_backward, smoothing, logit_scale, lse_square_scale, ignore_index, total_classes, class_start_idx, n_cols, n_rows, BLOCK_SIZE, num_warps):
    return torch.empty_like(logits)
class CrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        precomputed_lse=None,
        smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        ignore_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        # For some reason Triton generates wrong code when labels has dtype long and its address
        # is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index.
        if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:
            labels = F.pad(labels, (0, 1))[..., :-1]
            assert labels.data_ptr() % 16 == 0
        n_rows, n_cols = logits.shape
        assert labels.shape == (n_rows,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        total_classes = world_size * n_cols
        rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
        class_start_idx = rank * n_cols
        use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0
        MAX_BLOCK_SIZE = 16 * 1024
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
        num_warps = (
            4
            if BLOCK_SIZE < 2048
            else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
        )
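        # The forward kernel uses one program per row and loops over column blocks of size
        # BLOCK_SIZE internally, so vocabularies larger than MAX_BLOCK_SIZE are still
        # covered by the online logsumexp.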
        losses, z_losses, lse = torch.ops.flasht5.cross_entropy_triton_fwd(
            logits, labels, precomputed_lse, use_precomputed_lse,
            world_size > 1, smoothing, logit_scale, lse_square_scale,
            ignore_index, total_classes, class_start_idx,
            n_cols, n_rows, BLOCK_SIZE, num_warps
        )
        if world_size > 1:
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
            # For labels not in the vocab of this partition, losses contains
            # -0.1 * sum logit / total_classes.
            lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
            torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
            lse = torch.logsumexp(lse_allgather, dim=0)
            handle_losses.wait()
            # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
            # we just have to add the (global) lse.
            # If there's smoothing=0.1, the total losses are
            # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
            # Again, we just have to add the (global) lse.
            losses += lse
            if lse_square_scale != 0.0:
                z_losses = lse_square_scale * lse.square()
                z_losses.masked_fill_(labels == ignore_index, 0.0)
                losses += z_losses
            else:
                z_losses = torch.zeros_like(losses)
            losses.masked_fill_(labels == ignore_index, 0.0)
        ctx.save_for_backward(logits, lse, labels)
        ctx.mark_non_differentiable(z_losses)
        ctx.smoothing = smoothing
        ctx.logit_scale = logit_scale
        ctx.lse_square_scale = lse_square_scale
        ctx.ignore_index = ignore_index
        ctx.total_classes = total_classes
        ctx.class_start_idx = class_start_idx
        ctx.inplace_backward = inplace_backward
        return losses, z_losses
    @staticmethod
    def backward(ctx, grad_losses, grad_z_losses):
        del grad_z_losses  # z_losses are only for logging.
        logits, lse, labels = ctx.saved_tensors
        n_rows, n_cols = logits.shape
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
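        # Unlike the forward pass, the backward kernel is launched on a 2D grid
        # (rows x column blocks), so BLOCK_SIZE stays capped at 4K even for large vocabularies.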
        dlogits = torch.ops.flasht5.cross_entropy_triton_bwd(
            grad_losses, logits, lse, labels,
            ctx.inplace_backward, ctx.smoothing, ctx.logit_scale,
            ctx.lse_square_scale, ctx.ignore_index, ctx.total_classes,
            ctx.class_start_idx, n_cols, n_rows, BLOCK_SIZE, num_warps
        )
        if ctx.inplace_backward:
            dlogits = logits
        return dlogits, None, None, None, None, None, None, None, None, None
def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    precomputed_lse: Optional[torch.Tensor] = None,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    lse_square_scale: float = 0.0,
    ignore_index=-100,
    inplace_backward: bool = False,
    process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
logits: (batch, vocab_size)
labels: (batch,)
label_smoothing: float
logit_scale: float. Multiply logits by this scale before calculating the loss.
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
This is also referred to as "z-loss".
ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
This saves memory.
process_group: if not None, we're doing Tensor Parallel: each process is responsible for
one part of the vocab. The loss will be aggregated across processes.
Returns:
losses: (batch,), float
z_losses: (batch,), float
"""
    return CrossEntropyLoss.apply(
        logits.view(-1, logits.shape[-1]),
        labels.view(-1),
        precomputed_lse,
        label_smoothing,
        logit_scale,
        lse_square_scale,
        ignore_index,
        inplace_backward,
        process_group,
    )
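# Minimal usage sketch (illustrative, not part of the original module): the number of tokens,
# the vocabulary size and the hyper-parameter values below are arbitrary assumptions, and a
# CUDA device is required since the Triton kernels only run on GPU.
if __name__ == "__main__":
    if torch.cuda.is_available():
        vocab_size, n_tokens = 32128, 16
        logits = torch.randn(n_tokens, vocab_size, device="cuda", requires_grad=True)
        labels = torch.randint(0, vocab_size, (n_tokens,), device="cuda")
        # Per-token losses and z-losses; positions equal to ignore_index would come back as 0.0.
        losses, z_losses = cross_entropy_loss(
            logits,
            labels,
            label_smoothing=0.1,
            lse_square_scale=1e-4,
        )
        losses.mean().backward()
        print(losses.mean().item(), z_losses.mean().item())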