import numpy as np
from copy import deepcopy
import torch
from torch.backends import cudnn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import Identity

from nnunet.network_architecture.generic_UNet import Upsample
from nnunet.network_architecture.generic_modular_UNet import PlainConvUNetDecoder, get_default_network_config
from nnunet.network_architecture.neural_network import SegmentationNetwork
from nnunet.training.loss_functions.dice_loss import DC_and_CE_loss
from torch import nn
from torch.optim import SGD


class BasicPreActResidualBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, props, stride=None):
        """
        Pre-activation residual block: norm -> nonlin -> conv -> (dropout) -> norm -> nonlin -> conv,
        plus an additive skip connection. If the stride is not 1 or the number of channels changes,
        the skip path is projected with a 1x1 convolution.

        :param in_planes: number of input channels
        :param out_planes: number of output channels
        :param kernel_size: kernel size of both convolutions
        :param props: network property dict (conv_op, norm_op, nonlin, dropout_op and their kwargs)
        :param stride: stride of the first convolution (and of the skip projection); None means 1
        """
        super().__init__()

        self.kernel_size = kernel_size
        props['conv_op_kwargs']['stride'] = 1

        self.stride = stride
        self.props = props
        self.out_planes = out_planes
        self.in_planes = in_planes

        if stride is not None:
            kwargs_conv1 = deepcopy(props['conv_op_kwargs'])
            kwargs_conv1['stride'] = stride
        else:
            kwargs_conv1 = props['conv_op_kwargs']

        self.norm1 = props['norm_op'](in_planes, **props['norm_op_kwargs'])
        self.nonlin1 = props['nonlin'](**props['nonlin_kwargs'])
        self.conv1 = props['conv_op'](in_planes, out_planes, kernel_size, padding=[(i - 1) // 2 for i in kernel_size],
                                      **kwargs_conv1)

        if props['dropout_op_kwargs']['p'] != 0:
            self.dropout = props['dropout_op'](**props['dropout_op_kwargs'])
        else:
            self.dropout = Identity()

        self.norm2 = props['norm_op'](out_planes, **props['norm_op_kwargs'])
        self.nonlin2 = props['nonlin'](**props['nonlin_kwargs'])
        self.conv2 = props['conv_op'](out_planes, out_planes, kernel_size, padding=[(i - 1) // 2 for i in kernel_size],
                                      **props['conv_op_kwargs'])

        if (self.stride is not None and any(i != 1 for i in self.stride)) or (in_planes != out_planes):
            stride_here = stride if stride is not None else 1
            self.downsample_skip = nn.Sequential(props['conv_op'](in_planes, out_planes, 1, stride_here, bias=False))
        else:
            self.downsample_skip = None

    def forward(self, x):
        residual = x

        out = self.nonlin1(self.norm1(x))

        # the skip projection operates on the pre-activated input so that both paths see the same features
        if self.downsample_skip is not None:
            residual = self.downsample_skip(out)

        out = self.conv1(out)
        out = self.dropout(out)
        out = self.conv2(self.nonlin2(self.norm2(out)))

        out += residual
        return out
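

# A minimal usage sketch (comment example; assumes the props dict produced by
# nnunet's get_default_network_config, the concrete shapes are illustrative only):
#
#   props = get_default_network_config(3, dropout_p=None)
#   block = BasicPreActResidualBlock(32, 64, (3, 3, 3), props, stride=(2, 2, 2))
#   y = block(torch.rand(2, 32, 16, 16, 16))  # -> shape (2, 64, 8, 8, 8)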


class PreActResidualLayer(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, network_props, num_blocks, first_stride=None):
        super().__init__()

        network_props = deepcopy(network_props)  # props is a mutable dict, don't let the blocks modify the caller's copy

        self.convs = nn.Sequential(
            BasicPreActResidualBlock(input_channels, output_channels, kernel_size, network_props, first_stride),
            *[BasicPreActResidualBlock(output_channels, output_channels, kernel_size, network_props) for _ in
              range(num_blocks - 1)]
        )

    def forward(self, x):
        return self.convs(x)


class PreActResidualUNetEncoder(nn.Module):
    def __init__(self, input_channels, base_num_features, num_blocks_per_stage, feat_map_mul_on_downscale,
                 pool_op_kernel_sizes, conv_kernel_sizes, props, default_return_skips=True,
                 max_num_features=480, pool_type: str = 'conv'):
        """
        Modular pre-activation residual UNet encoder. Matching UNet building blocks (such as a decoder) can be
        attached via the properties this class exposes (stage_output_features, stage_pool_kernel_size,
        stage_conv_op_kernel_size).

        Note that this encoder includes the bottleneck layer!

        :param input_channels: number of input channels/modalities
        :param base_num_features: number of feature maps of the first stage
        :param num_blocks_per_stage: int or list/tuple with one entry per stage
        :param feat_map_mul_on_downscale: factor by which the number of feature maps grows per downsampling
        :param pool_op_kernel_sizes: one pooling kernel size per stage
        :param conv_kernel_sizes: one convolution kernel size per stage
        :param props: network property dict (see get_default_network_config)
        """
        super(PreActResidualUNetEncoder, self).__init__()

        self.default_return_skips = default_return_skips
        self.props = props

        pool_op = self._handle_pool(pool_type)

        self.stages = []
        self.stage_output_features = []
        self.stage_pool_kernel_size = []
        self.stage_conv_op_kernel_size = []

        assert len(pool_op_kernel_sizes) == len(conv_kernel_sizes)

        num_stages = len(conv_kernel_sizes)

        if not isinstance(num_blocks_per_stage, (list, tuple)):
            num_blocks_per_stage = [num_blocks_per_stage] * num_stages
        else:
            assert len(num_blocks_per_stage) == num_stages

        self.num_blocks_per_stage = num_blocks_per_stage

        self.initial_conv = props['conv_op'](input_channels, base_num_features, 3, padding=1,
                                             **props['conv_op_kwargs'])

        current_input_features = base_num_features
        for stage in range(num_stages):
            current_output_features = min(base_num_features * feat_map_mul_on_downscale ** stage, max_num_features)
            current_kernel_size = conv_kernel_sizes[stage]

            current_pool_kernel_size = pool_op_kernel_sizes[stage]
            if pool_op is not None:
                # downsampling is done by a dedicated pooling op, so the blocks run with stride 1
                pool_kernel_size_for_conv = [1] * len(current_pool_kernel_size)
            else:
                # pool_type 'conv': downsampling is done by the strided first conv of each stage
                pool_kernel_size_for_conv = current_pool_kernel_size

            current_stage = PreActResidualLayer(current_input_features, current_output_features, current_kernel_size,
                                                props, self.num_blocks_per_stage[stage], pool_kernel_size_for_conv)
            if pool_op is not None:
                current_stage = nn.Sequential(pool_op(current_pool_kernel_size), current_stage)

            self.stages.append(current_stage)
            self.stage_output_features.append(current_output_features)
            self.stage_conv_op_kernel_size.append(current_kernel_size)
            self.stage_pool_kernel_size.append(current_pool_kernel_size)

            current_input_features = current_output_features

        self.stages = nn.ModuleList(self.stages)
        self.output_features = current_input_features

    def _handle_pool(self, pool_type):
        assert pool_type in ['conv', 'avg', 'max']
        if pool_type == 'avg':
            if self.props['conv_op'] == nn.Conv2d:
                pool_op = nn.AvgPool2d
            elif self.props['conv_op'] == nn.Conv3d:
                pool_op = nn.AvgPool3d
            else:
                raise NotImplementedError
        elif pool_type == 'max':
            if self.props['conv_op'] == nn.Conv2d:
                pool_op = nn.MaxPool2d
            elif self.props['conv_op'] == nn.Conv3d:
                pool_op = nn.MaxPool3d
            else:
                raise NotImplementedError
        else:
            # 'conv': no separate pooling op, the strided convolutions downsample instead
            pool_op = None
        return pool_op

    def forward(self, x, return_skips=None):
        """
        :param x:
        :param return_skips: if None then self.default_return_skips is used
        :return: list of skip connections (one per stage) if return_skips, else the bottleneck output
        """
        if return_skips is None:
            return_skips = self.default_return_skips

        skips = []

        x = self.initial_conv(x)
        for s in self.stages:
            x = s(x)
            if return_skips:
                skips.append(x)

        if return_skips:
            return skips
        else:
            return x

    @staticmethod
    def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                        num_modalities, pool_op_kernel_sizes, num_conv_per_stage_encoder,
                                        feat_map_mul_on_downscale, batch_size):
        npool = len(pool_op_kernel_sizes) - 1

        current_shape = np.array(patch_size)

        # each residual block stores two conv activations plus one for the skip, hence the (n * 2 + 1)
        tmp = (num_conv_per_stage_encoder[0] * 2 + 1) * np.prod(current_shape) * base_num_features \
              + num_modalities * np.prod(current_shape)

        num_feat = base_num_features

        for p in range(1, npool + 1):
            current_shape = current_shape / np.array(pool_op_kernel_sizes[p])
            num_feat = min(num_feat * feat_map_mul_on_downscale, max_num_features)
            num_convs = num_conv_per_stage_encoder[p] * 2 + 1
            print(p, num_feat, num_convs, current_shape)
            tmp += num_convs * np.prod(current_shape) * num_feat
        return tmp * batch_size
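

# Worked example (illustrative arithmetic only): for patch_size=(128, 128, 128),
# base_num_features=32, one block in the first stage and 4 modalities, the first
# stage contributes (1 * 2 + 1) * 128**3 * 32 + 4 * 128**3 ~= 2.1e8 to the total,
# which is then scaled by batch_size. This is a relative measure of activation
# volume, not actual bytes of VRAM.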


class PreActResidualUNetDecoder(nn.Module):
    def __init__(self, previous, num_classes, num_blocks_per_stage=None, network_props=None, deep_supervision=False,
                 upscale_logits=False):
        super(PreActResidualUNetDecoder, self).__init__()
        self.num_classes = num_classes
        self.deep_supervision = deep_supervision
        # we assume the bottleneck is part of the encoder, so we can start with upsample -> concat here
        previous_stages = previous.stages
        previous_stage_output_features = previous.stage_output_features
        previous_stage_pool_kernel_size = previous.stage_pool_kernel_size
        previous_stage_conv_op_kernel_size = previous.stage_conv_op_kernel_size

        if network_props is None:
            self.props = previous.props
        else:
            self.props = network_props

        if self.props['conv_op'] == nn.Conv2d:
            transpconv = nn.ConvTranspose2d
            upsample_mode = "bilinear"
        elif self.props['conv_op'] == nn.Conv3d:
            transpconv = nn.ConvTranspose3d
            upsample_mode = "trilinear"
        else:
            raise ValueError("unknown convolution dimensionality, conv op: %s" % str(self.props['conv_op']))

        if num_blocks_per_stage is None:
            num_blocks_per_stage = previous.num_blocks_per_stage[:-1][::-1]

        assert len(num_blocks_per_stage) == len(previous.num_blocks_per_stage) - 1

        self.stage_pool_kernel_size = previous_stage_pool_kernel_size
        self.stage_output_features = previous_stage_output_features
        self.stage_conv_op_kernel_size = previous_stage_conv_op_kernel_size

        # one stage less than the encoder because the bottleneck stays on the encoder side
        num_stages = len(previous_stages) - 1

        self.tus = []
        self.stages = []
        self.deep_supervision_outputs = []

        # only used for upscale_logits
        cum_upsample = np.cumprod(np.vstack(self.stage_pool_kernel_size), axis=0).astype(int)

        for i, s in enumerate(np.arange(num_stages)[::-1]):
            features_below = previous_stage_output_features[s + 1]
            features_skip = previous_stage_output_features[s]

            self.tus.append(transpconv(features_below, features_skip, previous_stage_pool_kernel_size[s + 1],
                                       previous_stage_pool_kernel_size[s + 1], bias=False))

            # after upsampling we concatenate the skip, so the stage sees 2x the skip features
            self.stages.append(PreActResidualLayer(2 * features_skip, features_skip,
                                                   previous_stage_conv_op_kernel_size[s], self.props,
                                                   num_blocks_per_stage[i], None))

            if deep_supervision and s != 0:
                norm = self.props['norm_op'](features_skip, **self.props['norm_op_kwargs'])
                nonlin = self.props['nonlin'](**self.props['nonlin_kwargs'])
                seg_layer = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, bias=True)
                if upscale_logits:
                    upsample = Upsample(scale_factor=cum_upsample[s], mode=upsample_mode)
                    self.deep_supervision_outputs.append(nn.Sequential(norm, nonlin, seg_layer, upsample))
                else:
                    self.deep_supervision_outputs.append(nn.Sequential(norm, nonlin, seg_layer))

        # the stages are pre-activation blocks, so the final 1x1 segmentation conv needs norm + nonlin in front
        self.segmentation_conv_norm = self.props['norm_op'](features_skip, **self.props['norm_op_kwargs'])
        self.segmentation_conv_nonlin = self.props['nonlin'](**self.props['nonlin_kwargs'])
        self.segmentation_output = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, bias=True)
        self.segmentation_output = nn.Sequential(self.segmentation_conv_norm, self.segmentation_conv_nonlin,
                                                 self.segmentation_output)

        self.tus = nn.ModuleList(self.tus)
        self.stages = nn.ModuleList(self.stages)
        self.deep_supervision_outputs = nn.ModuleList(self.deep_supervision_outputs)

    def forward(self, skips):
        # skips come from the encoder ordered shallow to deep, we need them deepest first
        skips = skips[::-1]
        seg_outputs = []

        x = skips[0]  # this is the bottleneck

        for i in range(len(self.tus)):
            x = self.tus[i](x)
            x = torch.cat((x, skips[i + 1]), dim=1)
            x = self.stages[i](x)
            if self.deep_supervision and (i != len(self.tus) - 1):
                seg_outputs.append(self.deep_supervision_outputs[i](x))

        segmentation = self.segmentation_output(x)

        if self.deep_supervision:
            seg_outputs.append(segmentation)
            return seg_outputs[::-1]  # ordered so that the highest resolution segmentation comes first
        else:
            return segmentation

    @staticmethod
    def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                        num_classes, pool_op_kernel_sizes, num_blocks_per_stage_decoder,
                                        feat_map_mul_on_downscale, batch_size):
        """
        This only applies to the default decoder configuration (convolutional upsampling).
        Not real vram consumption, just a constant term to which the vram consumption will be approximately
        proportional (+ offset for parameter storage).

        :param patch_size:
        :param base_num_features:
        :param max_num_features:
        :param num_classes:
        :param pool_op_kernel_sizes:
        :param num_blocks_per_stage_decoder:
        :param feat_map_mul_on_downscale:
        :param batch_size:
        :return:
        """
        npool = len(pool_op_kernel_sizes) - 1

        current_shape = np.array(patch_size)
        tmp = (num_blocks_per_stage_decoder[-1] * 2 + 1) * np.prod(current_shape) * base_num_features \
              + num_classes * np.prod(current_shape)

        num_feat = base_num_features

        for p in range(1, npool):
            current_shape = current_shape / np.array(pool_op_kernel_sizes[p])
            num_feat = min(num_feat * feat_map_mul_on_downscale, max_num_features)
            # the extra + 1 accounts for the transposed convolution output of each decoder stage
            num_convs = num_blocks_per_stage_decoder[-(p + 1)] * 2 + 1 + 1
            print(p, num_feat, num_convs, current_shape)
            tmp += num_convs * np.prod(current_shape) * num_feat

        return tmp * batch_size
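

# Output contract (sketch): with deep_supervision=True, forward returns a list of
# logits ordered highest resolution first; with deep_supervision=False it returns
# a single logits tensor at the resolution of the first encoder stage. E.g.:
#
#   outputs = decoder(encoder(x))  # list if deep supervision, else one tensor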


class PreActResidualUNet(SegmentationNetwork):
    use_this_for_batch_size_computation_2D = 858931200.0
    use_this_for_batch_size_computation_3D = 727842816.0

    def __init__(self, input_channels, base_num_features, num_blocks_per_stage_encoder, feat_map_mul_on_downscale,
                 pool_op_kernel_sizes, conv_kernel_sizes, props, num_classes, num_blocks_per_stage_decoder,
                 deep_supervision=False, upscale_logits=False, max_features=512, initializer=None):
        super(PreActResidualUNet, self).__init__()
        self.conv_op = props['conv_op']
        self.num_classes = num_classes

        self.encoder = PreActResidualUNetEncoder(input_channels, base_num_features, num_blocks_per_stage_encoder,
                                                 feat_map_mul_on_downscale, pool_op_kernel_sizes, conv_kernel_sizes,
                                                 props, default_return_skips=True, max_num_features=max_features)
        self.decoder = PreActResidualUNetDecoder(self.encoder, num_classes, num_blocks_per_stage_decoder, props,
                                                 deep_supervision, upscale_logits)
        if initializer is not None:
            self.apply(initializer)

    def forward(self, x):
        skips = self.encoder(x)
        return self.decoder(skips)

    @staticmethod
    def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                        num_modalities, num_classes, pool_op_kernel_sizes, num_conv_per_stage_encoder,
                                        num_conv_per_stage_decoder, feat_map_mul_on_downscale, batch_size):
        enc = PreActResidualUNetEncoder.compute_approx_vram_consumption(patch_size, base_num_features,
                                                                        max_num_features, num_modalities,
                                                                        pool_op_kernel_sizes,
                                                                        num_conv_per_stage_encoder,
                                                                        feat_map_mul_on_downscale, batch_size)
        dec = PreActResidualUNetDecoder.compute_approx_vram_consumption(patch_size, base_num_features,
                                                                        max_num_features, num_classes,
                                                                        pool_op_kernel_sizes,
                                                                        num_conv_per_stage_decoder,
                                                                        feat_map_mul_on_downscale, batch_size)

        return enc + dec

    @staticmethod
    def compute_reference_for_vram_consumption_3d():
        patch_size = (128, 128, 128)
        pool_op_kernel_sizes = ((1, 1, 1),
                                (2, 2, 2),
                                (2, 2, 2),
                                (2, 2, 2),
                                (2, 2, 2),
                                (2, 2, 2))
        blocks_per_stage_encoder = (1, 1, 1, 1, 1, 1)
        blocks_per_stage_decoder = (1, 1, 1, 1, 1)

        return PreActResidualUNet.compute_approx_vram_consumption(patch_size, 20, 512, 4, 3, pool_op_kernel_sizes,
                                                                  blocks_per_stage_encoder, blocks_per_stage_decoder,
                                                                  2, 2)

    @staticmethod
    def compute_reference_for_vram_consumption_2d():
        patch_size = (256, 256)
        pool_op_kernel_sizes = ((1, 1),
                                (2, 2),
                                (2, 2),
                                (2, 2),
                                (2, 2),
                                (2, 2),
                                (2, 2))
        blocks_per_stage_encoder = (1, 1, 1, 1, 1, 1, 1)
        blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1)

        return PreActResidualUNet.compute_approx_vram_consumption(patch_size, 20, 512, 4, 3, pool_op_kernel_sizes,
                                                                  blocks_per_stage_encoder, blocks_per_stage_decoder,
                                                                  2, 50)
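

# A minimal construction sketch (assumes 4 input modalities and 3 classes; the
# props dict comes from nnunet's get_default_network_config, everything else is
# illustrative):
#
#   props = get_default_network_config(3, dropout_p=None)
#   net = PreActResidualUNet(4, 20, (1, 1, 1, 1, 1, 1), 2,
#                            ((1, 1, 1),) + ((2, 2, 2),) * 5, ((3, 3, 3),) * 6,
#                            props, 3, (1, 1, 1, 1, 1), deep_supervision=False)
#   logits = net(torch.rand(2, 4, 128, 128, 128))  # -> (2, 3, 128, 128, 128)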


class FabiansPreActUNet(SegmentationNetwork):
    use_this_for_2D_configuration = 1792460800
    use_this_for_3D_configuration = 1318592512
    default_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6)
    default_blocks_per_stage_decoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
    default_min_batch_size = 2

    def __init__(self, input_channels, base_num_features, num_blocks_per_stage_encoder, feat_map_mul_on_downscale,
                 pool_op_kernel_sizes, conv_kernel_sizes, props, num_classes, num_blocks_per_stage_decoder,
                 deep_supervision=False, upscale_logits=False, max_features=512, initializer=None):
        super().__init__()
        self.conv_op = props['conv_op']
        self.num_classes = num_classes

        self.encoder = PreActResidualUNetEncoder(input_channels, base_num_features, num_blocks_per_stage_encoder,
                                                 feat_map_mul_on_downscale, pool_op_kernel_sizes, conv_kernel_sizes,
                                                 props, default_return_skips=True, max_num_features=max_features)
        # the plain conv decoder runs without dropout
        props['dropout_op_kwargs']['p'] = 0
        self.decoder = PlainConvUNetDecoder(self.encoder, num_classes, num_blocks_per_stage_decoder, props,
                                            deep_supervision, upscale_logits)

        # the encoder stages are pre-activation blocks, so their outputs are not yet normalized/activated.
        # The skips therefore need norm + nonlin before they can be fed into the plain conv decoder. The
        # channel counts must mirror the encoder's stage_output_features computation.
        expected_num_skips = len(conv_kernel_sizes) - 1
        num_features_skips = [min(max_features, base_num_features * feat_map_mul_on_downscale ** i)
                              for i in range(expected_num_skips)]
        norm_nonlins = []
        for i in range(expected_num_skips):
            norm_nonlins.append(nn.Sequential(props['norm_op'](num_features_skips[i], **props['norm_op_kwargs']),
                                              props['nonlin'](**props['nonlin_kwargs'])))
        self.norm_nonlins = nn.ModuleList(norm_nonlins)

        if initializer is not None:
            self.apply(initializer)

    def forward(self, x, gt=None, loss=None):
        skips = self.encoder(x)
        for i, op in enumerate(self.norm_nonlins):
            skips[i] = op(skips[i])
        return self.decoder(skips, gt, loss)

    @staticmethod
    def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                        num_modalities, num_classes, pool_op_kernel_sizes,
                                        num_blocks_per_stage_encoder, num_blocks_per_stage_decoder,
                                        feat_map_mul_on_downscale, batch_size):
        enc = PreActResidualUNetEncoder.compute_approx_vram_consumption(patch_size, base_num_features,
                                                                        max_num_features, num_modalities,
                                                                        pool_op_kernel_sizes,
                                                                        num_blocks_per_stage_encoder,
                                                                        feat_map_mul_on_downscale, batch_size)
        dec = PlainConvUNetDecoder.compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                                                   num_classes, pool_op_kernel_sizes,
                                                                   num_blocks_per_stage_decoder,
                                                                   feat_map_mul_on_downscale, batch_size)

        return enc + dec
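

# Usage sketch for the gt/loss pathway (an assumption based on the decoder call
# above: PlainConvUNetDecoder accepts gt and loss positionally and, as in
# nnunet's generic_modular_UNet, presumably computes the loss on the fly to
# save memory under deep supervision):
#
#   l = unet(data, gt=target, loss=loss_fn)  # loss computed inside the decoder
#   pred = unet(data)                        # plain logits when gt/loss are None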


def find_3d_configuration():
    cudnn.benchmark = True
    cudnn.deterministic = False

    conv_op_kernel_sizes = ((3, 3, 3),
                            (3, 3, 3),
                            (3, 3, 3),
                            (3, 3, 3),
                            (3, 3, 3),
                            (3, 3, 3))
    pool_op_kernel_sizes = ((1, 1, 1),
                            (2, 2, 2),
                            (2, 2, 2),
                            (2, 2, 2),
                            (2, 2, 2),
                            (2, 2, 2))

    patch_size = (128, 128, 128)
    base_num_features = 32
    input_modalities = 4
    blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6)
    blocks_per_stage_decoder = (2, 2, 2, 2, 2)
    feat_map_mult_on_downscale = 2
    num_classes = 5
    max_features = 320
    batch_size = 2

    unet = FabiansPreActUNet(input_modalities, base_num_features, blocks_per_stage_encoder,
                             feat_map_mult_on_downscale, pool_op_kernel_sizes, conv_op_kernel_sizes,
                             get_default_network_config(3, dropout_p=None), num_classes, blocks_per_stage_decoder,
                             True, False, max_features=max_features).cuda()

    scaler = GradScaler()
    optimizer = SGD(unet.parameters(), lr=0.1, momentum=0.95)

    print(unet.compute_approx_vram_consumption(patch_size, base_num_features, max_features, input_modalities,
                                               num_classes, pool_op_kernel_sizes, blocks_per_stage_encoder,
                                               blocks_per_stage_decoder, feat_map_mult_on_downscale, batch_size))

    loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'do_bg': False}, {})

    dummy_input = torch.rand((batch_size, input_modalities, *patch_size)).cuda()
    dummy_gt = (torch.rand((batch_size, 1, *patch_size)) * num_classes).round().clamp_(0, num_classes - 1).cuda().long()

    # a few dummy iterations with mixed precision to probe memory consumption
    for _ in range(10):
        optimizer.zero_grad()

        with autocast():
            skips = unet.encoder(dummy_input)
            print([i.shape for i in skips])
            output = unet.decoder(skips)[0]

            l = loss(output, dummy_gt)
            print(l.item())
        scaler.scale(l).backward()
        scaler.step(optimizer)
        scaler.update()

    with autocast():
        import hiddenlayer as hl
        g = hl.build_graph(unet, dummy_input, transforms=None)
        g.save("/home/fabian/test_arch.pdf")


def find_2d_configuration():
    cudnn.benchmark = True
    cudnn.deterministic = False

    conv_op_kernel_sizes = ((3, 3),
                            (3, 3),
                            (3, 3),
                            (3, 3),
                            (3, 3),
                            (3, 3),
                            (3, 3))
    pool_op_kernel_sizes = ((1, 1),
                            (2, 2),
                            (2, 2),
                            (2, 2),
                            (2, 2),
                            (2, 2),
                            (2, 2))

    patch_size = (256, 256)
    base_num_features = 32
    input_modalities = 4
    blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6)
    blocks_per_stage_decoder = (2, 2, 2, 2, 2, 2)
    feat_map_mult_on_downscale = 2
    num_classes = 5
    max_features = 512
    batch_size = 50

    unet = FabiansPreActUNet(input_modalities, base_num_features, blocks_per_stage_encoder,
                             feat_map_mult_on_downscale, pool_op_kernel_sizes, conv_op_kernel_sizes,
                             get_default_network_config(2, dropout_p=None), num_classes, blocks_per_stage_decoder,
                             True, False, max_features=max_features).cuda()

    scaler = GradScaler()
    optimizer = SGD(unet.parameters(), lr=0.1, momentum=0.95)

    print(unet.compute_approx_vram_consumption(patch_size, base_num_features, max_features, input_modalities,
                                               num_classes, pool_op_kernel_sizes, blocks_per_stage_encoder,
                                               blocks_per_stage_decoder, feat_map_mult_on_downscale, batch_size))

    loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'do_bg': False}, {})

    dummy_input = torch.rand((batch_size, input_modalities, *patch_size)).cuda()
    dummy_gt = (torch.rand((batch_size, 1, *patch_size)) * num_classes).round().clamp_(0, num_classes - 1).cuda().long()

    # a few dummy iterations with mixed precision to probe memory consumption
    for _ in range(10):
        optimizer.zero_grad()

        with autocast():
            skips = unet.encoder(dummy_input)
            print([i.shape for i in skips])
            output = unet.decoder(skips)[0]

            l = loss(output, dummy_gt)
            print(l.item())
        scaler.scale(l).backward()
        scaler.step(optimizer)
        scaler.update()

    with autocast():
        import hiddenlayer as hl
        g = hl.build_graph(unet, dummy_input, transforms=None)
        g.save("/home/fabian/test_arch.pdf")
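

# A minimal entry point sketch (assumes a CUDA device is available; hiddenlayer
# is only needed for the graph export at the end of each function):
#
# if __name__ == "__main__":
#     find_3d_configuration()
#     find_2d_configuration()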