import torch
from torch.nn.parallel._functions import Scatter as OrigScatter

from ._functions import Scatter
from .data_container import DataContainer


def scatter(inputs, target_gpus, dim=0):
    """Scatter inputs to target gpus.

    The only difference from the original :func:`scatter` is that it adds
    support for :type:`~mmcv.parallel.DataContainer`.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            if target_gpus != [-1]:
                return OrigScatter.apply(target_gpus, None, dim, obj)
            else:
                return Scatter.forward(target_gpus, obj)
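        # A DataContainer marked `cpu_only` (typically metadata) is returned
        # as-is and stays on CPU; otherwise its underlying data is scattered
        # with mmcv's own Scatter.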
        if isinstance(obj, DataContainer):
            if obj.cpu_only:
                return obj.data
            else:
                return Scatter.forward(target_gpus, obj.data)
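        # Containers are scattered recursively and the result is transposed,
        # so the i-th output holds the i-th chunk of every element.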
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            out = list(map(list, zip(*map(scatter_map, obj))))
            return out
        if isinstance(obj, dict) and len(obj) > 0:
            out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
            return out
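        # Any other object is replicated by reference, once per target.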
        return [obj for targets in target_gpus]
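    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell.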
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    """Scatter with support for kwargs dictionary."""
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
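    # Pad the shorter of the two with empty tuples/dicts so that every target
    # device receives both an args tuple and a kwargs dict.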
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
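# Illustrative usage sketch (hypothetical values, not part of the module):
# callers such as MMDataParallel scatter a batch before dispatching
# forward() to each replica, e.g.
#
#   inputs, kwargs = scatter_kwargs((batch,), {'return_loss': True},
#                                   target_gpus=[0, 1], dim=0)
#   # inputs[i] and kwargs[i] hold the slice for device target_gpus[i];
#   # the kwarg name above is only an example.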