from collections import OrderedDict

import torch


@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())

    for name, param in model_params.items():
        # Strip the DistributedDataParallel prefix so names match the unwrapped EMA model.
        name = name.replace("module.", "")
        # In-place EMA update on the EMA model's own device: ema = decay * ema + (1 - decay) * param.
        ema_params[name].mul_(decay).add_(param.detach(), alpha=1 - decay)


@torch.no_grad()
def update_cpu_ema(ema_model, model, decay=0.9999):
    """
    Step a CPU-resident EMA model towards the current (possibly GPU-resident) model.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())

    for name, param in model_params.items():
        # Strip the DistributedDataParallel prefix so names match the unwrapped EMA model.
        name = name.replace("module.", "")
        # Copy the parameter to CPU before the in-place EMA update.
        ema_params[name].mul_(decay).add_(param.detach().cpu(), alpha=1 - decay)


def requires_grad(model, flag=True):
    """
    Set requires_grad flag for all parameters in a model.
    """
    for p in model.parameters():
        p.requires_grad = flag
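
# Minimal usage sketch (not part of the utilities above): the toy Linear model, loss,
# and loop below are hypothetical placeholders showing how these helpers are typically
# wired together with a frozen EMA copy of the model.
if __name__ == "__main__":
    from copy import deepcopy

    model = torch.nn.Linear(16, 16)   # stand-in for a real model
    ema = deepcopy(model)             # EMA starts as an exact copy of the model
    requires_grad(ema, False)         # EMA weights are never trained directly

    opt = torch.optim.SGD(model.parameters(), lr=1e-2)
    for _ in range(10):
        x = torch.randn(8, 16)
        loss = model(x).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        update_ema(ema, model, decay=0.9999)  # step EMA towards the updated weights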