import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter

from annotator.uniformer.mmcv.cnn import NORM_LAYERS
from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', [
    'sync_bn_forward_mean', 'sync_bn_forward_var', 'sync_bn_forward_output',
    'sync_bn_backward_param', 'sync_bn_backward_data'
])


class SyncBatchNormFunction(Function):

    @staticmethod
    def symbolic(g, input, running_mean, running_var, weight, bias, momentum,
                 eps, group, group_size, stats_mode):
        return g.op(
            'mmcv::MMCVSyncBatchNorm',
            input,
            running_mean,
            running_var,
            weight,
            bias,
            momentum_f=momentum,
            eps_f=eps,
            group_i=group,
            group_size_i=group_size,
            stats_mode_s=stats_mode)

    @staticmethod
    def forward(self, input, running_mean, running_var, weight, bias,
                momentum, eps, group, group_size, stats_mode):
        self.momentum = momentum
        self.eps = eps
        self.group = group
        self.group_size = group_size
        self.stats_mode = stats_mode

        assert isinstance(
            input, (torch.HalfTensor, torch.FloatTensor,
                    torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
            f'only support Half or Float Tensor, but {input.type()}'
        output = torch.zeros_like(input)
        input3d = input.flatten(start_dim=2)
        output3d = output.view_as(input3d)
        num_channels = input3d.size(1)

        # the statistics buffers are kept in float so that half-precision
        # inputs accumulate without overflow
        mean = torch.zeros(
            num_channels, dtype=torch.float, device=input3d.device)
        var = torch.zeros(
            num_channels, dtype=torch.float, device=input3d.device)
        norm = torch.zeros_like(
            input3d, dtype=torch.float, device=input3d.device)
        std = torch.zeros(
            num_channels, dtype=torch.float, device=input3d.device)

        batch_size = input3d.size(0)
        if batch_size > 0:
            ext_module.sync_bn_forward_mean(input3d, mean)
            batch_flag = torch.ones([1], device=mean.device, dtype=mean.dtype)
        else:
            # skip updating the mean and leave it as zeros when the input is
            # empty
            batch_flag = torch.zeros([1], device=mean.device, dtype=mean.dtype)

        # synchronize the mean and the batch flag in a single all-reduce
        vec = torch.cat([mean, batch_flag])
        if self.stats_mode == 'N':
            vec *= batch_size
        if self.group_size > 1:
            dist.all_reduce(vec, group=self.group)
        total_batch = vec[-1].detach()
        mean = vec[:num_channels]

        if self.stats_mode == 'default':
            mean = mean / self.group_size
        elif self.stats_mode == 'N':
            mean = mean / total_batch.clamp(min=1)
        else:
            raise NotImplementedError

        # leave var as zeros when the input is empty
        if batch_size > 0:
            ext_module.sync_bn_forward_var(input3d, mean, var)

        if self.stats_mode == 'N':
            var *= batch_size
        if self.group_size > 1:
            dist.all_reduce(var, group=self.group)

        if self.stats_mode == 'default':
            var /= self.group_size
        elif self.stats_mode == 'N':
            var /= total_batch.clamp(min=1)
        else:
            raise NotImplementedError

        # if the total batch size over all ranks is zero, the running
        # statistics should not be updated in the current iteration
        update_flag = total_batch.clamp(max=1)
        momentum = update_flag * self.momentum
        ext_module.sync_bn_forward_output(
            input3d,
            mean,
            var,
            weight,
            bias,
            running_mean,
            running_var,
            norm,
            std,
            output3d,
            eps=self.eps,
            momentum=momentum,
            group_size=self.group_size)
        self.save_for_backward(norm, std, weight)
        return output

    @staticmethod
    @once_differentiable
    def backward(self, grad_output):
        norm, std, weight = self.saved_tensors
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(weight)
        grad_input = torch.zeros_like(grad_output)
        grad_output3d = grad_output.flatten(start_dim=2)
        grad_input3d = grad_input.view_as(grad_output3d)

        batch_size = grad_input3d.size(0)
        if batch_size > 0:
            ext_module.sync_bn_backward_param(grad_output3d, norm,
                                              grad_weight, grad_bias)

        # all-reduce the parameter gradients across the group
        if self.group_size > 1:
            dist.all_reduce(grad_weight, group=self.group)
            dist.all_reduce(grad_bias, group=self.group)
            grad_weight /= self.group_size
            grad_bias /= self.group_size

        if batch_size > 0:
            ext_module.sync_bn_backward_data(grad_output3d, weight,
                                             grad_weight, grad_bias, norm,
                                             std, grad_input3d)

        return grad_input, None, None, grad_weight, grad_bias, \
            None, None, None, None, None


@NORM_LAYERS.register_module(name='MMSyncBN')
class SyncBatchNorm(Module):
    """Synchronized Batch Normalization.

    Args:
        num_features (int): number of features/channels of the input tensor
        eps (float, optional): a value added to the denominator for numerical
            stability. Defaults to 1e-5.
        momentum (float, optional): the value used for the running_mean and
            running_var computation. Defaults to 0.1.
        affine (bool, optional): whether to use learnable affine parameters.
            Defaults to True.
        track_running_stats (bool, optional): whether to track the running
            mean and variance during training. When set to False, this
            module does not track such statistics, and initializes statistics
            buffers ``running_mean`` and ``running_var`` as ``None``. When
            these buffers are ``None``, this module always uses batch
            statistics in both training and eval modes. Defaults to True.
        group (int, optional): synchronization of stats happens within each
            process group individually. By default it is synchronized across
            the whole world. Defaults to None.
        stats_mode (str, optional): The statistical mode. Available options
            include ``'default'`` and ``'N'``. Defaults to 'default'.
            When ``stats_mode=='default'``, it computes the overall statistics
            using those from each worker with equal weight, i.e., the
            statistics are synchronized and simply divided by ``group``. This
            mode will produce inaccurate statistics when empty tensors occur.
            When ``stats_mode=='N'``, it computes the overall statistics using
            the total number of batches in each worker, i.e., the statistics
            are synchronized and then divided by the total batch size ``N``.
            This mode is beneficial when empty tensors occur during training,
            as it averages the total mean by the real number of batches.
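
    Example:
        A minimal usage sketch (assumes a distributed process group has
        already been initialized, e.g. via ``dist.init_process_group``, and
        that the compiled ``_ext`` CUDA ops are available; shapes here are
        illustrative only)::

            >>> sync_bn = SyncBatchNorm(8).cuda()
            >>> x = torch.rand(4, 8, 32, 32).cuda()
            >>> out = sync_bn(x)  # stats are synchronized across all ranks
            >>> out.shape
            torch.Size([4, 8, 32, 32])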
"""

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True,
                 group=None,
                 stats_mode='default'):
        super(SyncBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        group = dist.group.WORLD if group is None else group
        self.group = group
        self.group_size = dist.get_world_size(group)
        assert stats_mode in ['default', 'N'], \
            f'"stats_mode" only accepts "default" and "N", got "{stats_mode}"'
        self.stats_mode = stats_mode
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long))
        else:
            self.register_buffer('running_mean', None)
            self.register_buffer('running_var', None)
            self.register_buffer('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            self.weight.data.uniform_()
            self.bias.data.zero_()

    def forward(self, input):
        if input.dim() < 2:
            raise ValueError(
                f'expected at least 2D input, got {input.dim()}D input')
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(
                        self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        if self.training or not self.track_running_stats:
            return SyncBatchNormFunction.apply(
                input, self.running_mean, self.running_var, self.weight,
                self.bias, exponential_average_factor, self.eps, self.group,
                self.group_size, self.stats_mode)
        else:
            return F.batch_norm(input, self.running_mean, self.running_var,
                                self.weight, self.bias, False,
                                exponential_average_factor, self.eps)

    def __repr__(self):
        s = self.__class__.__name__
        s += f'({self.num_features}, '
        s += f'eps={self.eps}, '
        s += f'momentum={self.momentum}, '
        s += f'affine={self.affine}, '
        s += f'track_running_stats={self.track_running_stats}, '
        s += f'group_size={self.group_size}, '
        s += f'stats_mode={self.stats_mode})'
        return s
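

# A minimal smoke-test sketch, not part of the original module. It assumes a
# CUDA-enabled build of the `_ext` ops; the backend, address, and port below
# are illustrative assumptions. With world_size=1 the group size is 1, so no
# all_reduce is issued and the layer behaves like ordinary BatchNorm.
if __name__ == '__main__':
    import os

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group(backend='gloo', rank=0, world_size=1)

    bn = SyncBatchNorm(8).cuda()          # 8-channel sync BN on the GPU
    x = torch.rand(4, 8, 16, 16).cuda()   # float input passes the type check
    y = bn(x)
    assert y.shape == x.shape

    dist.destroy_process_group()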