|
|
|
import copy
|
|
import inspect
|
|
|
|
import torch
|
|
|
|
from ...utils import Registry, build_from_cfg
|
|
|
|
# Two module-level registries, created in this order:
#   * OPTIMIZERS - holds optimizer classes; register_torch_optimizers() fills
#     it with the Optimizer subclasses found in ``torch.optim``.
#   * OPTIMIZER_BUILDERS - holds optimizer *constructors*;
#     build_optimizer_constructor() builds from it, and build_optimizer()
#     picks an entry via the config's 'constructor' key
#     (default 'DefaultOptimizerConstructor').
OPTIMIZERS, OPTIMIZER_BUILDERS = (
    Registry('optimizer'),
    Registry('optimizer builder'),
)
|
|
|
|
|
|
def register_torch_optimizers():
    """Register the optimizer classes exposed by ``torch.optim``.

    Every attribute of ``torch.optim`` that passes all three checks is
    registered into ``OPTIMIZERS``:

    * its name does not start with a double underscore,
    * it is a class,
    * it derives from ``torch.optim.Optimizer``.

    Because ``issubclass`` is reflexive, the ``torch.optim.Optimizer`` base
    class itself is also registered.

    Returns:
        list[str]: The ``torch.optim`` attribute names that were registered,
        in the order produced by ``dir(torch.optim)``.
    """
    registered_names = []
    base_cls = torch.optim.Optimizer
    for attr_name in dir(torch.optim):
        # Skip dunder attributes such as ``__name__`` or ``__path__``.
        if attr_name.startswith('__'):
            continue
        candidate = getattr(torch.optim, attr_name)
        # Ignore non-class attributes (e.g. submodules like ``lr_scheduler``).
        if not inspect.isclass(candidate):
            continue
        if not issubclass(candidate, base_cls):
            continue
        OPTIMIZERS.register_module()(candidate)
        registered_names.append(attr_name)
    return registered_names


# Populate OPTIMIZERS at import time and keep the names that were registered.
TORCH_OPTIMIZERS = register_torch_optimizers()
|
|
|
|
|
|
def build_optimizer_constructor(cfg):
    """Instantiate an optimizer constructor from the ``OPTIMIZER_BUILDERS`` registry.

    Args:
        cfg: Config handed unchanged to ``build_from_cfg``. It presumably
            carries a ``'type'`` key naming a registered constructor; see
            build_optimizer for the keys it supplies. TODO confirm against
            ``build_from_cfg``.

    Returns:
        Whatever ``build_from_cfg`` builds from the registry.
    """
    constructor = build_from_cfg(cfg, OPTIMIZER_BUILDERS)
    return constructor
|
|
|
|
|
|
def build_optimizer(model, cfg):
    """Build an optimizer for ``model`` from an optimizer config.

    Two meta keys are removed from a deep copy of ``cfg``:

    * ``'constructor'`` - name of the optimizer constructor to look up in
      ``OPTIMIZER_BUILDERS`` (default ``'DefaultOptimizerConstructor'``).
    * ``'paramwise_cfg'`` - forwarded to the constructor (default ``None``).

    Every remaining key is forwarded to the constructor as ``optimizer_cfg``.

    Args:
        model: Passed to the built constructor; presumably a
            ``torch.nn.Module`` (TODO confirm against the constructors).
        cfg: Mapping supporting ``pop``. It is not mutated, because all
            edits happen on a deep copy.

    Returns:
        The result of calling the constructor with ``model``, presumably a
        ``torch.optim.Optimizer`` instance.
    """
    # Deep-copy first so the ``pop`` calls below never touch the caller's config.
    remaining_cfg = copy.deepcopy(cfg)
    constructor_name = remaining_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    paramwise_options = remaining_cfg.pop('paramwise_cfg', None)

    constructor_spec = {
        'type': constructor_name,
        'optimizer_cfg': remaining_cfg,
        'paramwise_cfg': paramwise_options,
    }
    return build_optimizer_constructor(constructor_spec)(model)
|
|
|