import torch
import torch.nn as nn
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.utils.data import Sampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch import distributed as dist
from torch.utils.tensorboard import SummaryWriter

import logging
import os
import re
import glob
import collections
import argparse
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from datetime import datetime

# Use duck typing for LRScheduler since we have different possibilities, see
# our class LRScheduler.
LRSchedulerType = object

Pathlike = Union[str, Path]

def average_state_dict(
    state_dict_1: Dict[str, Tensor],
    state_dict_2: Dict[str, Tensor],
    weight_1: float,
    weight_2: float,
    scaling_factor: float = 1.0,
) -> Dict[str, Tensor]:
    """Average two state_dicts with the given weights:
        state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 * weight_2)
                       * scaling_factor
    It is an in-place operation on state_dict_1 itself.
    """
    # Identify shared parameters. Two parameters are said to be shared
    # if they have the same data_ptr.
    uniqued: Dict[int, str] = dict()
    for k, v in state_dict_1.items():
        v_data_ptr = v.data_ptr()
        if v_data_ptr in uniqued:
            continue
        uniqued[v_data_ptr] = k

    uniqued_names = list(uniqued.values())
    for k in uniqued_names:
        v = state_dict_1[k]
        if torch.is_floating_point(v):
            v *= weight_1
            v += state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
            v *= scaling_factor

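# Example (illustrative): averaging two state_dicts with equal weight.
# `m1` and `m2` are hypothetical models of the same architecture; after the
# call, every floating-point tensor in `sd1` holds 0.5 * sd1 + 0.5 * sd2.
#
#   m1, m2 = nn.Linear(4, 4), nn.Linear(4, 4)
#   sd1, sd2 = m1.state_dict(), m2.state_dict()
#   average_state_dict(sd1, sd2, weight_1=0.5, weight_2=0.5)
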
def load_checkpoint(
    filename: Path,
    model: nn.Module,
    model_avg: Optional[nn.Module] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[Sampler] = None,
    strict: bool = False,
) -> Dict[str, Any]:
    """Load a checkpoint saved by :func:`save_checkpoint_impl`.

    The `state_dict` of `model` is always restored; those of `model_avg`,
    `optimizer`, `scheduler`, `scaler`, and `sampler` are restored if the
    objects are given and the checkpoint contains them. The remaining
    entries of the checkpoint (e.g., user-defined params such as epoch
    and loss) are returned as a dict.
    """
    logging.info(f"Loading checkpoint from {filename}")
    checkpoint = torch.load(filename, map_location="cpu")
    if next(iter(checkpoint["model"])).startswith("module."):
        logging.info("Loading checkpoint saved by DDP")

        dst_state_dict = model.state_dict()
        src_state_dict = checkpoint["model"]
        for key in dst_state_dict.keys():
            src_key = "{}.{}".format("module", key)
            dst_state_dict[key] = src_state_dict.pop(src_key)
        assert len(src_state_dict) == 0
        model.load_state_dict(dst_state_dict, strict=strict)
    else:
        model.load_state_dict(checkpoint["model"], strict=strict)

    checkpoint.pop("model")

    if model_avg is not None and "model_avg" in checkpoint:
        logging.info("Loading averaged model")
        model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
        checkpoint.pop("model_avg")

    def load(name, obj):
        s = checkpoint.get(name, None)
        if obj and s:
            obj.load_state_dict(s)
            checkpoint.pop(name)

    load("optimizer", optimizer)
    load("scheduler", scheduler)
    load("grad_scaler", scaler)
    load("sampler", sampler)

    return checkpoint

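# Example (illustrative): resuming training from a hypothetical checkpoint
# path. The returned dict carries whatever extra keys were saved via
# `params`, e.g. the epoch counter.
#
#   extra = load_checkpoint(
#       Path("exp/checkpoint-4000.pt"),
#       model=model,
#       optimizer=optimizer,
#       scheduler=scheduler,
#   )
#   start_epoch = extra.get("epoch", 0)
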
def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
    """Find all available checkpoints in a directory.

    The checkpoint filenames have the form: `checkpoint-xxx.pt`
    where xxx is a numerical value.

    Assume you have the following checkpoints in the folder `foo`:

        - checkpoint-1.pt
        - checkpoint-20.pt
        - checkpoint-300.pt
        - checkpoint-4000.pt

    Case 1 (Return all checkpoints)::

        find_checkpoints(out_dir='foo')

    Case 2 (Return checkpoints with iteration number >= 20, i.e.,
    checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt)::

        find_checkpoints(out_dir='foo', iteration=20)

    Case 3 (Return checkpoints with iteration number <= 20, i.e.,
    checkpoint-20.pt and checkpoint-1.pt)::

        find_checkpoints(out_dir='foo', iteration=-20)

    Args:
      out_dir:
        The directory where to search for checkpoints.
      iteration:
        If it is 0, return all available checkpoints.
        If it is positive, return the checkpoints whose iteration number is
        greater than or equal to `iteration`.
        If it is negative, return the checkpoints whose iteration number is
        less than or equal to `-iteration`.
    Returns:
      Return a list of checkpoint filenames, sorted in descending
      order by the numerical value in the filename.
    """
    checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
    pattern = re.compile(r"checkpoint-([0-9]+)\.pt")
    iter_checkpoints = []
    for c in checkpoints:
        result = pattern.search(c)
        if not result:
            logging.warning(f"Invalid checkpoint filename {c}")
            continue

        iter_checkpoints.append((int(result.group(1)), c))

    # iter_checkpoints is a list of tuples. Each tuple contains
    # two elements: (iteration_number, checkpoint-iteration_number.pt)
    iter_checkpoints = sorted(iter_checkpoints, reverse=True, key=lambda x: x[0])
    if iteration >= 0:
        ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
    else:
        ans = [ic[1] for ic in iter_checkpoints if ic[0] <= -iteration]

    return ans

def remove_checkpoints(
    out_dir: Path,
    topk: int,
):
    """Remove checkpoints from the given directory.

    We assume that checkpoint filenames have the form `checkpoint-xxx.pt`
    where xxx is a number, representing the number of processed batches
    when saving that checkpoint. We sort checkpoints by that number and
    keep only the `topk` checkpoints with the highest `xxx`.

    Args:
      out_dir:
        The directory containing checkpoints to be removed.
      topk:
        Number of checkpoints to keep.
    """
    assert topk >= 1, topk
    checkpoints = find_checkpoints(out_dir)
    if len(checkpoints) == 0:
        logging.warning(f"No checkpoints found in {out_dir}")
        return

    if len(checkpoints) <= topk:
        return

    to_remove = checkpoints[topk:]
    for c in to_remove:
        os.remove(c)

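# Example (illustrative): keep only the 5 most recent checkpoints in a
# hypothetical experiment directory; older `checkpoint-*.pt` files are deleted.
#
#   remove_checkpoints(out_dir=Path("exp"), topk=5)
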
def save_checkpoint_impl(
    filename: Path,
    model: Union[nn.Module, DDP],
    model_avg: Optional[nn.Module] = None,
    params: Optional[Dict[str, Any]] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[Sampler] = None,
    rank: int = 0,
) -> None:
    """Save training information to a file.

    Args:
      filename:
        The checkpoint filename.
      model:
        The model to be saved. We only save its `state_dict()`.
      model_avg:
        The stored model averaged from the start of training.
      params:
        User defined parameters, e.g., epoch, loss.
      optimizer:
        The optimizer to be saved. We only save its `state_dict()`.
      scheduler:
        The scheduler to be saved. We only save its `state_dict()`.
      scaler:
        The GradScaler to be saved. We only save its `state_dict()`.
      sampler:
        The sampler used in the training dataset. We only save its
        `state_dict()`.
      rank:
        Used in DDP. We save checkpoint only for the node whose rank is 0.
    Returns:
      Return None.
    """
    if rank != 0:
        return

    logging.info(f"Saving checkpoint to {filename}")

    if isinstance(model, DDP):
        model = model.module

    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "grad_scaler": scaler.state_dict() if scaler is not None else None,
        "sampler": sampler.state_dict() if sampler is not None else None,
    }

    if model_avg is not None:
        checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()

    if params:
        for k, v in params.items():
            assert k not in checkpoint
            checkpoint[k] = v

    torch.save(checkpoint, filename)

def save_checkpoint_with_global_batch_idx(
    out_dir: Path,
    global_batch_idx: int,
    model: nn.Module,
    model_avg: Optional[nn.Module] = None,
    params: Optional[Dict[str, Any]] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[Sampler] = None,
    rank: int = 0,
):
    """Save training info after processing given number of batches.

    Args:
      out_dir:
        The directory to save the checkpoint.
      global_batch_idx:
        The number of batches processed so far from the very start of the
        training. The saved checkpoint will have the following filename:
            f'out_dir / checkpoint-{global_batch_idx}.pt'
      model:
        The neural network model whose `state_dict` will be saved in the
        checkpoint.
      model_avg:
        The stored model averaged from the start of training.
      params:
        A dict of training configurations to be saved.
      optimizer:
        The optimizer used in the training. Its `state_dict` will be saved.
      scheduler:
        The learning rate scheduler used in the training. Its `state_dict` will
        be saved.
      scaler:
        The scaler used for mixed precision training. Its `state_dict` will
        be saved.
      sampler:
        The sampler used in the training dataset.
      rank:
        The rank ID used in DDP training of the current node. Set it to 0
        if DDP is not used.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    logging.info(f"Saving checkpoint for global batch index {global_batch_idx}")
    filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
    save_checkpoint_impl(
        filename=filename,
        model=model,
        model_avg=model_avg,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        sampler=sampler,
        rank=rank,
    )

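# Example (illustrative): saving a checkpoint every 10k batches inside a
# hypothetical training loop. This writes e.g. `exp/checkpoint-30000.pt`.
#
#   if global_batch_idx % 10000 == 0:
#       save_checkpoint_with_global_batch_idx(
#           out_dir=Path("exp"),
#           global_batch_idx=global_batch_idx,
#           model=model,
#           optimizer=optimizer,
#           rank=rank,
#       )
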
def update_averaged_model(
    params: Dict[str, Tensor],
    model_cur: nn.Module,
    model_avg: nn.Module,
) -> None:
    """Update the averaged model:
        model_avg = model_cur * (average_period / batch_idx_train)
            + model_avg * ((batch_idx_train - average_period) / batch_idx_train)

    Args:
      params:
        User defined parameters, e.g., epoch, loss. It is expected to
        provide `params.average_period` and `params.batch_idx_train`.
      model_cur:
        The current model.
      model_avg:
        The averaged model to be updated.
    """
    weight_cur = params.average_period / params.batch_idx_train
    weight_avg = 1 - weight_cur

    cur = model_cur.state_dict()
    avg = model_avg.state_dict()

    average_state_dict(
        state_dict_1=avg,
        state_dict_2=cur,
        weight_1=weight_avg,
        weight_2=weight_cur,
    )

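# Worked example (illustrative): with average_period=100 and
# batch_idx_train=500, the current model contributes
# weight_cur = 100 / 500 = 0.2 and the running average keeps
# weight_avg = 0.8, i.e.
#     model_avg = 0.8 * model_avg + 0.2 * model_cur
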
def cleanup_dist():
    dist.destroy_process_group()


def setup_dist(
    rank, world_size, master_port=None, use_ddp_launch=False, master_addr=None
):
    """
    rank and world_size are used only if use_ddp_launch is False.
    """
    if "MASTER_ADDR" not in os.environ:
        os.environ["MASTER_ADDR"] = (
            "localhost" if master_addr is None else str(master_addr)
        )

    if "MASTER_PORT" not in os.environ:
        os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)

    if use_ddp_launch is False:
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)
    else:
        dist.init_process_group("nccl")

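# Example (illustrative): launching one process per GPU with
# torch.multiprocessing.spawn; `run` is a hypothetical per-rank entry point.
#
#   import torch.multiprocessing as mp
#
#   def run(rank, world_size):
#       setup_dist(rank, world_size)
#       ...  # build the model, wrap it in DDP, train
#       cleanup_dist()
#
#   world_size = torch.cuda.device_count()
#   mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
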
def register_inf_check_hooks(model: nn.Module) -> None:
    """Register a forward hook and a backward hook on each module to check
    whether its output tensors are finite.

    Args:
      model:
        The model to be analyzed.
    """
    for name, module in model.named_modules():
        if name == "":
            name = "<top-level>"

        # The default param _name is a way to capture the current value
        # of the variable "name".
        def forward_hook(_module, _input, _output, _name=name):
            if isinstance(_output, Tensor):
                if not torch.isfinite(_output.to(torch.float32).sum()):
                    raise ValueError(
                        f"The sum of {_name}.output is not finite: {_output}"
                    )
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    if isinstance(o, tuple):
                        o = o[0]
                    if not isinstance(o, Tensor):
                        continue
                    if not torch.isfinite(o.to(torch.float32).sum()):
                        raise ValueError(
                            f"The sum of {_name}.output[{i}] is not finite: {_output}"
                        )

        # The default param _name is a way to capture the current value
        # of the variable "name".
        def backward_hook(_module, _input, _output, _name=name):
            if isinstance(_output, Tensor):
                if not torch.isfinite(_output.to(torch.float32).sum()):
                    logging.warning(
                        f"The sum of {_name}.grad is not finite"  # ": {_output}"
                    )
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    if isinstance(o, tuple):
                        o = o[0]
                    if not isinstance(o, Tensor):
                        continue
                    if not torch.isfinite(o.to(torch.float32).sum()):
                        logging.warning(f"The sum of {_name}.grad[{i}] is not finite")

        module.register_forward_hook(forward_hook)
        module.register_backward_hook(backward_hook)

    for name, parameter in model.named_parameters():

        def param_backward_hook(grad, _name=name):
            if not torch.isfinite(grad.to(torch.float32).sum()):
                logging.warning(f"The sum of {_name}.param_grad is not finite")

        parameter.register_hook(param_backward_hook)

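# Example (illustrative): attaching the hooks to a hypothetical model before
# training, so that any non-finite activation raises and any non-finite
# gradient is logged.
#
#   model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
#   register_inf_check_hooks(model)
#   loss = model(torch.randn(4, 8)).sum()
#   loss.backward()
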
class AttributeDict(dict):
    def __getattr__(self, key):
        if key in self:
            return self[key]
        raise AttributeError(f"No such attribute '{key}'")

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        if key in self:
            del self[key]
            return
        raise AttributeError(f"No such attribute '{key}'")

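# Example (illustrative): AttributeDict lets training params be accessed
# either as dict keys or as attributes, as update_averaged_model expects.
#
#   params = AttributeDict({"average_period": 100, "batch_idx_train": 0})
#   params.batch_idx_train += 1
#   assert params["batch_idx_train"] == 1
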
class MetricsTracker(collections.defaultdict):
    def __init__(self):
        # Passing the type 'int' to the base-class constructor
        # makes undefined items default to int() which is zero.
        # This class will play a role as metrics tracker.
        # It can record many metrics, including but not limited to loss.
        super(MetricsTracker, self).__init__(int)

    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v
        for k, v in other.items():
            ans[k] = ans[k] + v
        return ans

    def __mul__(self, alpha: float) -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v * alpha
        return ans

    def __str__(self) -> str:
        ans_frames = ""
        ans_utterances = ""
        for k, v in self.norm_items():
            norm_value = "%.4g" % v
            if "utt_" not in k:
                ans_frames += str(k) + "=" + str(norm_value) + ", "
            else:
                ans_utterances += str(k) + "=" + str(norm_value)
                if k == "utt_duration":
                    ans_utterances += " frames, "
                elif k == "utt_pad_proportion":
                    ans_utterances += ", "
                else:
                    raise ValueError(f"Unexpected key: {k}")
        frames = "%.2f" % self["frames"]
        ans_frames += "over " + str(frames) + " frames. "
        if ans_utterances != "":
            utterances = "%.2f" % self["utterances"]
            ans_utterances += "over " + str(utterances) + " utterances."

        return ans_frames + ans_utterances

    def norm_items(self) -> List[Tuple[str, float]]:
        """
        Returns a list of pairs, like:
            [('ctc_loss', 0.1), ('att_loss', 0.07)]
        """
        num_frames = self["frames"] if "frames" in self else 1
        num_utterances = self["utterances"] if "utterances" in self else 1
        ans = []
        for k, v in self.items():
            if k == "frames" or k == "utterances":
                continue
            norm_value = (
                float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
            )
            ans.append((k, norm_value))
        return ans

    def write_summary(
        self,
        tb_writer: SummaryWriter,
        prefix: str,
        batch_idx: int,
    ) -> None:
        """Add logging information to a TensorBoard writer.

        Args:
          tb_writer: a TensorBoard writer
          prefix: a prefix for the name of the loss, e.g. "train/valid_",
            or "train/current_"
          batch_idx: The current batch index, used as the x-axis of the plot.
        """
        for k, v in self.norm_items():
            tb_writer.add_scalar(prefix + k, v, batch_idx)

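# Example (illustrative): accumulating per-batch statistics and writing the
# normalized values to TensorBoard. The metric names are hypothetical;
# `tb_writer` and `batch_idx` come from the surrounding training loop.
#
#   totals = MetricsTracker()
#   info = MetricsTracker()
#   info["frames"] = 1000
#   info["ctc_loss"] = 123.4
#   totals = totals + info        # __add__ merges two trackers
#   print(totals)                 # e.g. "ctc_loss=0.1234, over 1000.00 frames. "
#   totals.write_summary(tb_writer, "train/current_", batch_idx)
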
def setup_logger(
    log_filename: Pathlike,
    log_level: str = "info",
    use_console: bool = True,
) -> None:
    """Setup log level.

    Args:
      log_filename:
        The filename to save the log.
      log_level:
        The log level to use, e.g., "debug", "info", "warning", "error",
        "critical"
      use_console:
        True to also print logs to console.
    """
    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
        log_filename = f"{log_filename}-{date_time}-{rank}"
    else:
        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
        log_filename = f"{log_filename}-{date_time}"

    os.makedirs(os.path.dirname(log_filename), exist_ok=True)

    level = logging.ERROR
    if log_level == "debug":
        level = logging.DEBUG
    elif log_level == "info":
        level = logging.INFO
    elif log_level == "warning":
        level = logging.WARNING
    elif log_level == "critical":
        level = logging.CRITICAL

    logging.basicConfig(
        filename=log_filename,
        format=formatter,
        level=level,
        filemode="w",
    )
    if use_console:
        console = logging.StreamHandler()
        console.setLevel(level)
        console.setFormatter(logging.Formatter(formatter))
        logging.getLogger("").addHandler(console)

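# Example (illustrative): writing logs to a hypothetical experiment directory
# and mirroring them to the console. A timestamp (and a rank suffix under DDP)
# is appended to the filename, e.g. "exp/log/log-train-2024-01-01-00-00-00".
#
#   setup_logger("exp/log/log-train", log_level="info", use_console=True)
#   logging.info("Training started")
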
def str2bool(v):
    """Used in argparse.ArgumentParser.add_argument to indicate
    that a type is a bool type and user can enter

        - yes, true, t, y, 1, to represent True
        - no, false, f, n, 0, to represent False

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse  # noqa
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")
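
# Example (illustrative): using str2bool as the `type` of a hypothetical
# boolean flag so that "--use-fp16 true" and "--use-fp16 0" both parse.
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--use-fp16", type=str2bool, default=False)
#   args = parser.parse_args(["--use-fp16", "true"])
#   assert args.use_fp16 is True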