import copy
import logging
import os.path as osp
import warnings
from abc import ABCMeta, abstractmethod

import torch
from torch.optim import Optimizer

import annotator.uniformer.mmcv as mmcv
from ..parallel import is_module_wrapper
from .checkpoint import load_checkpoint
from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook
from .log_buffer import LogBuffer
from .priority import Priority, get_priority
from .utils import get_time_str


class BaseRunner(metaclass=ABCMeta):
    """The base class of Runner, a training helper for PyTorch.

    All subclasses should implement the following APIs:

    - ``run()``
    - ``train()``
    - ``val()``
    - ``save_checkpoint()``

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        batch_processor (callable): A callable method that processes a data
            batch. The interface of this method should be
            ``batch_processor(model, data, train_mode) -> dict``.
        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
            optimizer (in most cases) or a dict of optimizers (in models that
            require more than one optimizer, e.g., GAN).
        work_dir (str, optional): The working directory to save checkpoints
            and logs. Defaults to None.
        logger (:obj:`logging.Logger`): Logger used during training.
            Defaults to None. (The default value is just for backward
            compatibility.)
        meta (dict | None): A dict that records some important information,
            such as environment info and seed, which will be logged in the
            logger hook. Defaults to None.
        max_epochs (int, optional): Total training epochs.
        max_iters (int, optional): Total training iterations.
    """

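    # Usage sketch (illustrative only): ``BaseRunner`` is abstract, so a
    # concrete subclass such as ``EpochBasedRunner`` from this package is
    # constructed in practice. ``build_model`` and the config values below
    # are placeholders, not part of this module.
    #
    #   model = build_model(...)  # any nn.Module implementing train_step()
    #   optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    #   logger = logging.getLogger(__name__)
    #   runner = EpochBasedRunner(
    #       model,
    #       optimizer=optimizer,
    #       work_dir='./work_dir',
    #       logger=logger,
    #       max_epochs=12)
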
    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 max_epochs=None):
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError('batch_processor must be callable, '
                                f'but got {type(batch_processor)}')
            warnings.warn('batch_processor is deprecated, please implement '
                          'train_step() and val_step() in the model instead.')
            # batch_processor and model.train_step()/val_step() are mutually
            # exclusive ways of running a data batch
            if is_module_wrapper(model):
                _model = model.module
            else:
                _model = model
            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
                raise RuntimeError(
                    'batch_processor and model.train_step()/model.val_step() '
                    'cannot be both available.')
        else:
            assert hasattr(model, 'train_step')

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f'optimizer must be a dict of torch.optim.Optimizers, '
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            raise TypeError(
                f'optimizer must be a torch.optim.Optimizer object '
                f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger`
        if not isinstance(logger, logging.Logger):
            raise TypeError(f'logger must be a logging.Logger object, '
                            f'but got {type(logger)}')

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta

        # create the working directory if a path is given
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get the model name from the (possibly wrapped) model class
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0

        if max_epochs is not None and max_iters is not None:
            raise ValueError(
                'Only one of `max_epochs` or `max_iters` can be set.')

        self._max_epochs = max_epochs
        self._max_iters = max_iters

        self.log_buffer = LogBuffer()

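    # Note on the ``optimizer`` argument (illustrative sketch, not taken from
    # the original code): a single ``torch.optim.Optimizer`` is the common
    # case, while a dict of named optimizers covers multi-optimizer models
    # such as GANs, e.g.
    #
    #   optimizer = dict(
    #       generator=torch.optim.Adam(gen.parameters(), lr=1e-4),
    #       discriminator=torch.optim.Adam(disc.parameters(), lr=4e-4))
    #
    # where ``gen`` and ``disc`` are placeholder modules.
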
    @property
    def model_name(self):
        """str: Name of the model, usually the module class name."""
        return self._model_name

    @property
    def rank(self):
        """int: Rank of current process. (distributed training)"""
        return self._rank

    @property
    def world_size(self):
        """int: Number of processes participating in the job.
        (distributed training)"""
        return self._world_size

    @property
    def hooks(self):
        """list[:obj:`Hook`]: A list of registered hooks."""
        return self._hooks

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    @property
    def inner_iter(self):
        """int: Iteration in an epoch."""
        return self._inner_iter

    @property
    def max_epochs(self):
        """int: Maximum training epochs."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Maximum training iterations."""
        return self._max_iters

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def val(self):
        pass

    @abstractmethod
    def run(self, data_loaders, workflow, **kwargs):
        pass

    @abstractmethod
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl,
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        pass

    def current_lr(self):
        """Get current learning rates.

        Returns:
            list[float] | dict[str, list[float]]: Current learning rates of
                all param groups. If the runner has a dict of optimizers,
                this method will return a dict.
        """
        if isinstance(self.optimizer, torch.optim.Optimizer):
            lr = [group['lr'] for group in self.optimizer.param_groups]
        elif isinstance(self.optimizer, dict):
            lr = dict()
            for name, optim in self.optimizer.items():
                lr[name] = [group['lr'] for group in optim.param_groups]
        else:
            raise RuntimeError(
                'lr is not applicable because optimizer does not exist.')
        return lr

    def current_momentum(self):
        """Get current momentums.

        Returns:
            list[float] | dict[str, list[float]]: Current momentums of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.
        """

        def _get_momentum(optimizer):
            momentums = []
            for group in optimizer.param_groups:
                if 'momentum' in group.keys():
                    momentums.append(group['momentum'])
                elif 'betas' in group.keys():
                    momentums.append(group['betas'][0])
                else:
                    momentums.append(0)
            return momentums

        if self.optimizer is None:
            raise RuntimeError(
                'momentum is not applicable because optimizer does not exist.')
        elif isinstance(self.optimizer, torch.optim.Optimizer):
            momentums = _get_momentum(self.optimizer)
        elif isinstance(self.optimizer, dict):
            momentums = dict()
            for name, optim in self.optimizer.items():
                momentums[name] = _get_momentum(optim)
        return momentums

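    # Return-shape sketch for the two accessors above (the numbers are
    # illustrative, not from the original code):
    #
    #   runner.current_lr()        # single optimizer -> [0.01, 0.001]
    #   runner.current_momentum()  # single optimizer -> [0.9, 0.9]
    #
    #   # with optimizer=dict(generator=..., discriminator=...):
    #   runner.current_lr()  # -> {'generator': [1e-4], 'discriminator': [4e-4]}
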
    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.

        The hook will be inserted into a priority queue, with the specified
        priority (see :class:`Priority` for details of priorities).
        Hooks with the same priority are triggered in the order in which
        they are registered.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook into the list, keeping it sorted by priority value
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)

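    # Ordering sketch (hypothetical hook classes, for illustration only): a
    # hook registered with priority='VERY_HIGH' (value 10) ends up before a
    # hook registered earlier with priority='NORMAL' (value 50), because
    # ``_hooks`` is kept sorted by ascending priority value.
    #
    #   runner.register_hook(LoggingHookA(), priority='NORMAL')
    #   runner.register_hook(LrSchedulerHookB(), priority='VERY_HIGH')
    #   # runner.hooks -> [LrSchedulerHookB, LoggingHookA]
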
    def register_hook_from_cfg(self, hook_cfg):
        """Register a hook from its cfg.

        Args:
            hook_cfg (dict): Hook config. It should have at least keys 'type'
                and 'priority' indicating its type and priority.

        Note:
            The specific hook class to register should not use 'type' and
            'priority' arguments during initialization.
        """
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        self.register_hook(hook, priority=priority)

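    # Config-style registration sketch. ``MyEvalHook`` is a hypothetical hook
    # assumed to be registered in the HOOKS registry (e.g. via
    # ``@HOOKS.register_module()``); only 'priority' is popped here, the rest
    # of the dict is passed to the hook constructor by build_from_cfg.
    #
    #   runner.register_hook_from_cfg(
    #       dict(type='MyEvalHook', interval=1, priority='LOW'))
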
    def call_hook(self, fn_name):
        """Call all hooks.

        Args:
            fn_name (str): The function name in each hook to be called, such
                as "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def get_hook_info(self):
        """Return a human-readable summary of registered hooks per stage."""
        # group the registered hooks by the stages that trigger them
        stage_hook_map = {stage: [] for stage in Hook.stages}
        for hook in self.hooks:
            try:
                priority = Priority(hook.priority).name
            except ValueError:
                priority = hook.priority
            classname = hook.__class__.__name__
            hook_info = f'({priority:<12}) {classname:<35}'
            for trigger_stage in hook.get_triggered_stages():
                stage_hook_map[trigger_stage].append(hook_info)

        stage_hook_infos = []
        for stage in Hook.stages:
            hook_infos = stage_hook_map[stage]
            if len(hook_infos) > 0:
                info = f'{stage}:\n'
                info += '\n'.join(hook_infos)
                info += '\n -------------------- '
                stage_hook_infos.append(info)
        return '\n'.join(stage_hook_infos)

    def load_checkpoint(self,
                        filename,
                        map_location='cpu',
                        strict=False,
                        revise_keys=[(r'^module\.', '')]):
        return load_checkpoint(
            self.model,
            filename,
            map_location,
            strict,
            self.logger,
            revise_keys=revise_keys)

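    # ``revise_keys`` sketch: each (pattern, replacement) pair is applied to
    # the state-dict keys before loading; the default strips the 'module.'
    # prefix added by DataParallel/DistributedDataParallel wrappers. The path
    # below is illustrative only:
    #
    #   runner.load_checkpoint('work_dir/latest.pth',
    #                          revise_keys=[(r'^module\.', '')])
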
    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if self.meta is None:
            self.meta = {}
        self.meta.setdefault('hook_msgs', {})
        # restore hook messages (e.g. last checkpoint info) saved in the meta
        self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))

        # re-calculate the number of iterations when resuming a model trained
        # with a different number of GPUs
        if 'config' in checkpoint['meta']:
            config = mmcv.Config.fromstring(
                checkpoint['meta']['config'], file_format='.py')
            previous_gpu_ids = config.get('gpu_ids', None)
            if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
                    previous_gpu_ids) != self.world_size:
                self._iter = int(self._iter * len(previous_gpu_ids) /
                                 self.world_size)
                self.logger.info('the iteration number is changed due to '
                                 'change of GPU number')

        # resume meta information
        self.meta = checkpoint['meta']

        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

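    # Resume sketch (illustrative paths): restores the epoch/iter counters,
    # the meta dict and, by default, the optimizer state from a checkpoint
    # produced by ``save_checkpoint()``.
    #
    #   runner.resume('work_dir/latest.pth')
    #   runner.resume('work_dir/epoch_8.pth', resume_optimizer=False,
    #                 map_location='cpu')
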
    def register_lr_hook(self, lr_config):
        if lr_config is None:
            return
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            policy_type = lr_config.pop('policy')
            # If the policy name is given all in lower case (e.g. 'step'),
            # capitalize it to build the hook class name; a policy that
            # already contains capital letters (e.g. 'CosineAnnealing') is
            # used as-is.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            lr_config['type'] = hook_type
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook, priority='VERY_HIGH')

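    # lr_config sketch: 'policy' selects the updater hook class by name, so
    # (assuming the corresponding hooks are registered in HOOKS, as in
    # upstream mmcv) the following builds a ``StepLrUpdaterHook``:
    #
    #   runner.register_lr_hook(dict(policy='step', step=[8, 11]))
    #
    # while dict(policy='CosineAnnealing', min_lr=1e-5) would map to
    # ``CosineAnnealingLrUpdaterHook``.
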
    def register_momentum_hook(self, momentum_config):
        if momentum_config is None:
            return
        if isinstance(momentum_config, dict):
            assert 'policy' in momentum_config
            policy_type = momentum_config.pop('policy')
            # as for the LR hook above, an all-lowercase policy name is
            # capitalized before building the hook class name
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'MomentumUpdaterHook'
            momentum_config['type'] = hook_type
            hook = mmcv.build_from_cfg(momentum_config, HOOKS)
        else:
            hook = momentum_config
        self.register_hook(hook, priority='HIGH')

    def register_optimizer_hook(self, optimizer_config):
        if optimizer_config is None:
            return
        if isinstance(optimizer_config, dict):
            optimizer_config.setdefault('type', 'OptimizerHook')
            hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
        else:
            hook = optimizer_config
        self.register_hook(hook, priority='ABOVE_NORMAL')

    def register_checkpoint_hook(self, checkpoint_config):
        if checkpoint_config is None:
            return
        if isinstance(checkpoint_config, dict):
            checkpoint_config.setdefault('type', 'CheckpointHook')
            hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
        else:
            hook = checkpoint_config
        self.register_hook(hook, priority='NORMAL')

    def register_logger_hooks(self, log_config):
        if log_config is None:
            return
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = mmcv.build_from_cfg(
                info, HOOKS, default_args=dict(interval=log_interval))
            self.register_hook(logger_hook, priority='VERY_LOW')

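    # log_config sketch: 'interval' is shared by every logger hook unless a
    # hook config sets its own value. ``TextLoggerHook`` ships with upstream
    # mmcv; the exact hook set below is an assumption.
    #
    #   runner.register_logger_hooks(
    #       dict(interval=50, hooks=[dict(type='TextLoggerHook')]))
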
    def register_timer_hook(self, timer_config):
        if timer_config is None:
            return
        if isinstance(timer_config, dict):
            timer_config_ = copy.deepcopy(timer_config)
            hook = mmcv.build_from_cfg(timer_config_, HOOKS)
        else:
            hook = timer_config
        self.register_hook(hook, priority='LOW')

    def register_custom_hooks(self, custom_config):
        if custom_config is None:
            return

        if not isinstance(custom_config, list):
            custom_config = [custom_config]

        for item in custom_config:
            if isinstance(item, dict):
                self.register_hook_from_cfg(item)
            else:
                self.register_hook(item, priority='NORMAL')

    def register_profiler_hook(self, profiler_config):
        if profiler_config is None:
            return
        if isinstance(profiler_config, dict):
            profiler_config.setdefault('type', 'ProfilerHook')
            hook = mmcv.build_from_cfg(profiler_config, HOOKS)
        else:
            hook = profiler_config
        self.register_hook(hook)

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                timer_config=dict(type='IterTimerHook'),
                                custom_hooks_config=None):
        """Register default and custom hooks for training.

        Default and custom hooks include:

        +----------------------+-------------------------+
        | Hooks                | Priority                |
        +======================+=========================+
        | LrUpdaterHook        | VERY_HIGH (10)          |
        +----------------------+-------------------------+
        | MomentumUpdaterHook  | HIGH (30)               |
        +----------------------+-------------------------+
        | OptimizerHook        | ABOVE_NORMAL (40)       |
        +----------------------+-------------------------+
        | CheckpointHook       | NORMAL (50)             |
        +----------------------+-------------------------+
        | IterTimerHook        | LOW (70)                |
        +----------------------+-------------------------+
        | LoggerHook(s)        | VERY_LOW (90)           |
        +----------------------+-------------------------+
        | CustomHook(s)        | defaults to NORMAL (50) |
        +----------------------+-------------------------+

        If a custom hook has the same priority as a default hook, the custom
        hook will be triggered after the default hook.
        """
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_timer_hook(timer_config)
        self.register_logger_hooks(log_config)
        self.register_custom_hooks(custom_hooks_config)
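
# End-to-end registration sketch (illustrative config values; assumes the
# standard mmcv hook classes are available in the HOOKS registry and that
# ``data_loaders`` is a placeholder list of PyTorch DataLoaders):
#
#   runner.register_training_hooks(
#       lr_config=dict(policy='step', step=[8, 11]),
#       optimizer_config=dict(grad_clip=None),
#       checkpoint_config=dict(interval=1),
#       log_config=dict(interval=50, hooks=[dict(type='TextLoggerHook')]),
#       momentum_config=None)
#   runner.run(data_loaders, workflow=[('train', 1)])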