import torch
import torch.distributed as dist
import torch.nn as nn
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)

from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
from .registry import MODULE_WRAPPERS
from .scatter_gather import scatter_kwargs


@MODULE_WRAPPERS.register_module()
class MMDistributedDataParallel(nn.Module):
    """A distributed data-parallel wrapper around an ``nn.Module``.

    On construction, parameters (and optionally buffers) are broadcast from
    rank 0 to all other ranks in coalesced buckets. At call time, inputs are
    scattered to the current CUDA device before being passed to the wrapped
    module.

    Args:
        module (nn.Module): Module to be wrapped.
        dim (int): Dimension along which inputs are scattered. Default: 0.
        broadcast_buffers (bool): Whether to broadcast module buffers from
            rank 0 during initialization. Default: True.
        bucket_cap_mb (int): Bucket size (in MB) used to coalesce tensors
            for broadcasting. Default: 25.
    """

    def __init__(self,
                 module,
                 dim=0,
                 broadcast_buffers=True,
                 bucket_cap_mb=25):
        super(MMDistributedDataParallel, self).__init__()
        self.module = module
        self.dim = dim
        self.broadcast_buffers = broadcast_buffers

        self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
        self._sync_params()

    def _dist_broadcast_coalesced(self, tensors, buffer_size):
        """Broadcast tensors from rank 0 in coalesced buckets."""
        for bucket in _take_tensors(tensors, buffer_size):
            # Flatten the bucket into one contiguous tensor, broadcast it
            # once, then copy the synced values back into the originals.
            flat_tensors = _flatten_dense_tensors(bucket)
            dist.broadcast(flat_tensors, 0)
            for tensor, synced in zip(
                    bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
                tensor.copy_(synced)

    def _sync_params(self):
        """Synchronize parameters and (optionally) buffers from rank 0."""
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._dist_broadcast_coalesced(module_states,
                                           self.broadcast_bucket_size)
        if self.broadcast_buffers:
            # PyTorch < 1.0 exposed buffers via the private ``_all_buffers``
            # method instead of ``buffers``.
            if (TORCH_VERSION != 'parrots'
                    and digit_version(TORCH_VERSION) < digit_version('1.0')):
                buffers = [b.data for b in self.module._all_buffers()]
            else:
                buffers = [b.data for b in self.module.buffers()]
            if len(buffers) > 0:
                self._dist_broadcast_coalesced(buffers,
                                               self.broadcast_bucket_size)

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter inputs and kwargs to ``device_ids`` along ``self.dim``."""
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        return self.module(*inputs[0], **kwargs[0])

    def train_step(self, *inputs, **kwargs):
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.train_step(*inputs[0], **kwargs[0])
        return output

    def val_step(self, *inputs, **kwargs):
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.val_step(*inputs[0], **kwargs[0])
        return output
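

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative addition, not part of the original
# module). It assumes a single-process run with one CUDA device; the gloo
# backend and the rendezvous address/port below are placeholder assumptions
# chosen only so the example is self-contained.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import os

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group(backend='gloo', rank=0, world_size=1)

    # Wrapping the module broadcasts its state from rank 0 to all ranks.
    model = nn.Linear(4, 2).cuda()
    ddp_model = MMDistributedDataParallel(model, broadcast_buffers=True)

    # Inputs are scattered to the current CUDA device before the forward call.
    out = ddp_model(torch.randn(8, 4).cuda())
    print(out.shape)

    dist.destroy_process_group()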