|
|
|
from itertools import chain
|
|
|
|
from torch.nn.parallel import DataParallel
|
|
|
|
from .scatter_gather import scatter_kwargs
|
|
|
|
|
|
class MMDataParallel(DataParallel):
    """The DataParallel module that supports DataContainer.

    MMDataParallel has two main differences with PyTorch DataParallel:

    - It supports a custom type :class:`DataContainer` which allows more
      flexible control of input data during both GPU and CPU inference.
    - It implements two more APIs ``train_step()`` and ``val_step()``.

    Args:
        module (:class:`nn.Module`): Module to be encapsulated.
        device_ids (list[int]): Device IDS of modules to be scattered to.
            Defaults to None when GPU is not available.
        output_device (str | int): Device ID for output. Defaults to None.
        dim (int): Dimension used to scatter the data. Defaults to 0.
    """

    def __init__(self, *args, dim=0, **kwargs):
        super().__init__(*args, dim=dim, **kwargs)
        # Stored explicitly so that ``scatter`` can use it; DataParallel
        # also keeps ``dim`` but we do not rely on that implementation
        # detail.
        self.dim = dim

    def forward(self, *inputs, **kwargs):
        """Override the original forward function.

        The main difference lies in the CPU inference where the data in
        :class:`DataContainers` will still be gathered.
        """
        if not self.device_ids:
            # CPU inference: scatter with device id -1, which only unwraps
            # DataContainers instead of moving data to a GPU.
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module(*inputs[0], **kwargs[0])
        else:
            return super().forward(*inputs, **kwargs)

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter inputs/kwargs to devices, handling DataContainers."""
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def train_step(self, *inputs, **kwargs):
        """Scatter the inputs and run ``module.train_step`` on them.

        Only CPU execution and single-GPU execution are supported; for
        multi-GPU training use ``MMDistributedDataParallel`` instead.

        Raises:
            RuntimeError: If any module parameter or buffer is not on
                ``device_ids[0]``.
        """
        if not self.device_ids:
            # CPU path, mirroring :meth:`forward`.
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module.train_step(*inputs[0], **kwargs[0])

        assert len(self.device_ids) == 1, \
            ('MMDataParallel only supports single GPU training, if you need to'
             ' train with multiple GPUs, please use MMDistributedDataParallel'
             ' instead.')

        self._check_device_placement()
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        return self.module.train_step(*inputs[0], **kwargs[0])

    def val_step(self, *inputs, **kwargs):
        """Scatter the inputs and run ``module.val_step`` on them.

        Only CPU execution and single-GPU execution are supported; for
        multi-GPU validation use ``MMDistributedDataParallel`` instead.

        Raises:
            RuntimeError: If any module parameter or buffer is not on
                ``device_ids[0]``.
        """
        if not self.device_ids:
            # CPU path, mirroring :meth:`forward`.
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module.val_step(*inputs[0], **kwargs[0])

        assert len(self.device_ids) == 1, \
            ('MMDataParallel only supports single GPU training, if you need to'
             ' train with multiple GPUs, please use MMDistributedDataParallel'
             ' instead.')

        self._check_device_placement()
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        return self.module.val_step(*inputs[0], **kwargs[0])

    def _check_device_placement(self):
        """Raise ``RuntimeError`` if any parameter or buffer of the wrapped
        module is not on ``device_ids[0]`` (``self.src_device_obj``)."""
        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError(
                    'module must have its parameters and buffers '
                    f'on device {self.src_device_obj} (device_ids[0]) but '
                    f'found one of them on device: {t.device}')