# FAT5-small / adamw_scaled.py
import torch
import math
from torch.optim import Optimizer
from torch.optim.optimizer import _default_to_fused_or_foreach
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
from typing import Iterable, Optional, Tuple
from torch import nn, Tensor
class AdamWScale(Optimizer):
"""
This AdamW implementation is copied from Huggingface.
We modified it with Adagrad scaling by rms of a weight tensor
Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
Regularization](https://arxiv.org/abs/1711.05101).
Parameters:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*, defaults to 1e-3):
The learning rate to use.
betas (`Tuple[float,float]`, *optional*, defaults to (0.9, 0.999)):
Adam's betas parameters (b1, b2).
eps (`float`, *optional*, defaults to 1e-6):
Adam's epsilon for numerical stability.
weight_decay (`float`, *optional*, defaults to 0.0):
Decoupled weight decay to apply.
kahan_sum (`bool`, *optional*, defaults to False):
Whether to use Kahan summation for updating parameters.
foreach (`bool`, *optional*, defaults to False):
Whether to use the foreach implementation.
correct_bias (`bool`, *optional*, defaults to True):
Whether to correct bias in Adam.
use_state_dtype (`torch.dtype`, *optional*, defaults to None):
The dtype to use for optimizer state. If None, use the default dtype.
"""
def __init__(
self,
params: Iterable[nn.parameter.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-6,
weight_decay: float = 0.0,
kahan_sum: bool = False,
foreach: bool = False,
correct_bias: bool = True,
        use_state_dtype: Optional[torch.dtype] = None
):
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
assert not (foreach and use_state_dtype is not None), "foreach is not supported with use_state_dtype"
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, foreach=foreach, \
kahan_sum=kahan_sum, correct_bias=correct_bias, use_state_dtype=use_state_dtype)
super().__init__(params, defaults)
@staticmethod
def _rms(tensor):
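        # Root-mean-square of a tensor: ||t||_2 / sqrt(numel(t)), used to scale the step
        # size relative to the magnitude of the parameter being updated.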
return tensor.norm(2) / (tensor.numel() ** 0.5)
@torch.no_grad()
def step(self, closure=None):
"""
Performs a single optimization step.
Arguments:
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
params, grads, exp_avgs, exp_avg_sqs, steps, kahan_comps = [], [], [], [], [], []
# Initialization
for p in group['params']:
if p.grad is None:
continue
params.append(p)
if p.grad.is_sparse:
raise RuntimeError('AdamWScale does not support sparse gradients')
grads.append(p.grad)
state = self.state[p]
# State initialization
if "kahan_comp" not in state:
state['step'] = torch.tensor(0, dtype=torch.int32, device=p.device)
if group["use_state_dtype"] in [torch.float16, torch.bfloat16]:
state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=group["use_state_dtype"])
state['exp_avg_sq'] = torch.zeros_like(p, device=p.device, dtype=group["use_state_dtype"])
else:
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
if group["kahan_sum"] and p.dtype in [torch.float16, torch.bfloat16]:
state["kahan_comp"] = torch.zeros_like(p, memory_format=torch.preserve_format)
else:
state["kahan_comp"] = None
group["kahan_sum"] = False
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
kahan_comps.append(state["kahan_comp"])
steps.append(state["step"])
torch._foreach_add_(steps, 1)
# AdamW step
if group["foreach"] and _default_to_fused_or_foreach(params, False, False):
self._foreach_adamwscaled(params,
grads,
exp_avgs,
exp_avg_sqs,
steps,
kahan_comps,
group["lr"],
group["betas"][0],
group["betas"][1],
group["weight_decay"],
group["eps"],
group["kahan_sum"],
group["correct_bias"])
else:
self._adamwscaled(params,
grads,
exp_avgs,
exp_avg_sqs,
steps,
kahan_comps,
group["lr"],
group["betas"][0],
group["betas"][1],
group["weight_decay"],
group["eps"],
group["kahan_sum"],
group["correct_bias"])
return loss
def _adamwscaled(self,
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
steps: list[Tensor],
kahan_comps: list[Tensor],
lr: float,
beta1: float,
beta2: float,
weight_decay: float,
eps: float,
do_kahan_sum: bool,
correct_bias: bool):
for i, p in enumerate(params):
exp_avg, exp_avg_sq, grad, step, kahan_comp = exp_avgs[i], exp_avg_sqs[i], grads[i], steps[i], kahan_comps[i]
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1.0 - beta2))
denom = exp_avg_sq.sqrt().add_(eps)
step_size = lr
if correct_bias: # No bias correction for Bert
bias_correction1 = 1.0 - beta1 ** step
bias_correction2 = 1.0 - beta2 ** step
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
# Adapt Step from Adafactor
step_size = step_size * max(1e-3, self._rms(p.data))
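            # Effective step: p <- p - lr * [sqrt(1 - beta2^t) / (1 - beta1^t)] * max(1e-3, RMS(p)) * m / (sqrt(v) + eps),
            # where the bracketed bias correction is only applied when correct_bias=True.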
if do_kahan_sum:
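                # Kahan (compensated) summation: carry the low-order bits lost when updating
                # float16/bfloat16 parameters into the next step instead of discarding them.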
# Adam step
kahan_comp.addcdiv_(exp_avg, denom, value=-step_size)
                # update weights with kahan compensation, using grad as a temp buffer
grad.copy_(p)
p.add_(kahan_comp)
# save error back to kahan compensation for next iteration
grad.sub_(p, alpha=1)
kahan_comp.add_(grad, alpha=1)
else:
p.addcdiv_(exp_avg, denom, value=-step_size)
# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
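            # i.e. p <- p - lr * weight_decay * p, decoupled from the m/v moment estimates.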
if weight_decay > 0.0:
p.add_(p, alpha=(-lr * weight_decay))
def _foreach_adamwscaled(self,
params: list[Tensor],
grads: list[Tensor],
exp_avgs: list[Tensor],
exp_avg_sqs: list[Tensor],
steps: list[Tensor],
kahan_comps: list[Tensor],
lr: float,
beta1: float,
beta2: float,
weight_decay: float,
eps: float,
do_kahan_sum: bool,
correct_bias: bool):
        # Group tensors by device and dtype for the foreach kernels; keep the original indices
        # so that per-parameter state (e.g. step counters) can be looked up correctly.
        grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, kahan_comps], with_indices=True)
        for (_, dtype), ((dev_params, dev_grads, dev_exp_avgs, dev_exp_avg_sqs, dev_kahan_comps), indices) in grouped_tensors.items():
# Foreach implementation
torch._foreach_mul_(dev_exp_avgs, beta1)
torch._foreach_add_(dev_exp_avgs, dev_grads, alpha=1 - beta1)
torch._foreach_mul_(dev_exp_avg_sqs, beta2)
torch._foreach_addcmul_(dev_exp_avg_sqs, dev_grads, dev_grads, 1 - beta2)
# Compute denominator
torch._foreach_copy_(dev_grads, dev_exp_avg_sqs)
torch._foreach_sqrt_(dev_grads)
torch._foreach_add_(dev_grads, eps)
step_size = [torch.tensor(lr, dtype=torch.float32, device=p.device) for p in dev_params]
            if correct_bias:
                # Use the original parameter indices to fetch the matching step counters.
                torch._foreach_mul_(step_size,
                                    [torch.tensor(math.sqrt(1 - beta2 ** steps[indices[i]].item()) / (1 - beta1 ** steps[indices[i]].item()), dtype=torch.float32, device=p.device)
                                     for i, p in enumerate(dev_params)])
# Adapt step size using RMS of parameters
rms_p = torch._foreach_norm(dev_params)
numel = [torch.tensor(math.sqrt(p.numel())) for p in dev_params]
torch._foreach_div_(rms_p, numel)
torch._foreach_maximum_(rms_p, 1e-3)
torch._foreach_mul_(step_size, rms_p)
torch._foreach_div_(dev_grads, step_size)
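            # dev_grads now holds (sqrt(v) + eps) / step_size, so a single addcdiv with
            # value=-1 below applies both the Adam update and the scaled step size.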
# explicitly delete tensors when not used
del rms_p
del numel
del step_size
# Update parameters
if do_kahan_sum:
# Adam step
torch._foreach_addcdiv_(dev_kahan_comps, dev_exp_avgs, dev_grads, value=-1)
# update weights with kahan compensation using dev_grads as temp buffer
torch._foreach_copy_(dev_grads, dev_params)
torch._foreach_add_(dev_params, dev_kahan_comps, alpha=1)
# save error back to kahan compensation for next iteration
torch._foreach_sub_(dev_grads, dev_params, alpha=1)
torch._foreach_add_(dev_kahan_comps, dev_grads, alpha=1)
else:
torch._foreach_addcdiv_(dev_params, dev_exp_avgs, dev_grads, value=-1)
# Weight decay
if weight_decay > 0.0:
torch._foreach_add_(dev_params, dev_params, alpha=-weight_decay * lr)
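

# Illustrative usage sketch: AdamWScale is a drop-in replacement for torch.optim.AdamW.
# The toy model, data and hyper-parameters below are placeholders for demonstration only,
# not values taken from the FAT5 training configuration.
if __name__ == "__main__":
    model = nn.Linear(16, 4)
    optimizer = AdamWScale(model.parameters(), lr=1e-2, weight_decay=0.01)
    inputs, targets = torch.randn(8, 16), torch.randn(8, 4)
    for _ in range(3):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()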