File size: 27,608 Bytes
c642393 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 |
import numpy as np
from copy import deepcopy
import torch
from torch.backends import cudnn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import Identity
from nnunet.network_architecture.generic_UNet import Upsample
from nnunet.network_architecture.generic_modular_UNet import PlainConvUNetDecoder, get_default_network_config
from nnunet.network_architecture.neural_network import SegmentationNetwork
from nnunet.training.loss_functions.dice_loss import DC_and_CE_loss
from torch import nn
from torch.optim import SGD
class BasicPreActResidualBlock(nn.Module):
    """
    Pre-activation residual block: norm -> nonlin -> conv -> (dropout) -> norm -> nonlin -> conv, plus a skip
    connection that is added to the output of the second convolution.
    """

    def __init__(self, in_planes, out_planes, kernel_size, props, stride=None):
        """
        :param in_planes: number of input channels
        :param out_planes: number of output channels
        :param kernel_size: kernel size of both convolutions, one entry per spatial axis (must be iterable; the
        padding of each axis is (k - 1) // 2)
        :param props: network property dict with the keys 'conv_op', 'conv_op_kwargs', 'norm_op', 'norm_op_kwargs',
        'nonlin', 'nonlin_kwargs', 'dropout_op' and 'dropout_op_kwargs'. A private deep copy is used, so the
        caller's dict is never modified.
        :param stride: stride of the first convolution and of the skip projection (one entry per spatial axis).
        None means the first convolution uses stride 1.
        """
        super().__init__()
        # Work on a private copy: the default stride is forced to 1 below and that must not leak into the
        # caller's dict (which may be shared with other layers).
        props = deepcopy(props)
        props['conv_op_kwargs']['stride'] = 1

        self.kernel_size = kernel_size
        self.stride = stride
        self.props = props
        self.out_planes = out_planes
        self.in_planes = in_planes

        # only the first convolution carries the (optional) stride
        if stride is not None:
            kwargs_conv1 = deepcopy(props['conv_op_kwargs'])
            kwargs_conv1['stride'] = stride
        else:
            kwargs_conv1 = props['conv_op_kwargs']
        padding = [(i - 1) // 2 for i in kernel_size]

        self.norm1 = props['norm_op'](in_planes, **props['norm_op_kwargs'])
        self.nonlin1 = props['nonlin'](**props['nonlin_kwargs'])
        self.conv1 = props['conv_op'](in_planes, out_planes, kernel_size, padding=padding, **kwargs_conv1)

        if props['dropout_op_kwargs']['p'] != 0:
            self.dropout = props['dropout_op'](**props['dropout_op_kwargs'])
        else:
            self.dropout = Identity()

        self.norm2 = props['norm_op'](out_planes, **props['norm_op_kwargs'])
        self.nonlin2 = props['nonlin'](**props['nonlin_kwargs'])
        self.conv2 = props['conv_op'](out_planes, out_planes, kernel_size, padding=padding,
                                      **props['conv_op_kwargs'])

        # the skip path needs a projection when the resolution (stride != 1) or the channel count changes
        if (self.stride is not None and any(i != 1 for i in self.stride)) or (in_planes != out_planes):
            stride_here = stride if stride is not None else 1
            self.downsample_skip = nn.Sequential(props['conv_op'](in_planes, out_planes, 1, stride_here, bias=False))
        else:
            self.downsample_skip = None

    def forward(self, x):
        residual = x
        out = self.nonlin1(self.norm1(x))
        # the projected skip is computed from the pre-activated input, not from the raw input
        if self.downsample_skip is not None:
            residual = self.downsample_skip(out)

        # norm nonlin conv
        out = self.conv1(out)
        out = self.dropout(out)  # Identity if props['dropout_op_kwargs']['p'] == 0

        # norm nonlin conv
        out = self.conv2(self.nonlin2(self.norm2(out)))

        out += residual
        return out
class PreActResidualLayer(nn.Module):
    """
    A sequence of `num_blocks` pre-activation residual blocks. Only the first block changes the channel count
    (input_channels -> output_channels) and applies `first_stride`; the remaining ones keep output_channels.
    """

    def __init__(self, input_channels, output_channels, kernel_size, network_props, num_blocks, first_stride=None):
        super().__init__()
        # the property dict is mutable, so every block of this layer gets to see a private copy only
        layer_props = deepcopy(network_props)

        block_list = [BasicPreActResidualBlock(input_channels, output_channels, kernel_size, layer_props,
                                               first_stride)]
        for _ in range(num_blocks - 1):
            block_list.append(BasicPreActResidualBlock(output_channels, output_channels, kernel_size, layer_props))
        self.convs = nn.Sequential(*block_list)

    def forward(self, x):
        """Run x through all blocks in order."""
        out = self.convs(x)
        return out
class PreActResidualUNetEncoder(nn.Module):
    def __init__(self, input_channels, base_num_features, num_blocks_per_stage, feat_map_mul_on_downscale,
                 pool_op_kernel_sizes, conv_kernel_sizes, props, default_return_skips=True,
                 max_num_features=480, pool_type: str = 'conv'):
        """
        Encoder of the pre-activation residual UNet. This one includes the bottleneck layer (the last stage).

        Following UNet building blocks can be added by utilizing the properties this class exposes (TODO)

        :param input_channels: number of input channels
        :param base_num_features: feature maps of the initial convolution and of the first stage
        :param num_blocks_per_stage: residual blocks per stage; a single int is used for every stage
        :param feat_map_mul_on_downscale: factor by which the number of feature maps grows from stage to stage
        :param pool_op_kernel_sizes: per-stage downsampling factors; same length as conv_kernel_sizes
        :param conv_kernel_sizes: per-stage convolution kernel sizes
        :param props: network property dict (conv_op, norm_op, nonlin, dropout_op and their kwargs)
        :param default_return_skips: what forward returns when called with return_skips=None
        :param max_num_features: upper bound on the number of feature maps of any stage
        :param pool_type: 'conv' downsamples via the stride of the first conv of each stage; 'avg' / 'max'
        place an explicit pooling layer in front of each stage instead
        """
        super().__init__()

        self.default_return_skips = default_return_skips
        self.props = props

        pool_op = self._handle_pool(pool_type)

        self.stages = []
        self.stage_output_features = []
        self.stage_pool_kernel_size = []
        self.stage_conv_op_kernel_size = []

        assert len(pool_op_kernel_sizes) == len(conv_kernel_sizes)

        num_stages = len(conv_kernel_sizes)

        if not isinstance(num_blocks_per_stage, (list, tuple)):
            num_blocks_per_stage = [num_blocks_per_stage] * num_stages
        else:
            assert len(num_blocks_per_stage) == num_stages

        self.num_blocks_per_stage = num_blocks_per_stage  # decoder may need this

        self.initial_conv = props['conv_op'](input_channels, base_num_features, 3, padding=1,
                                             **props['conv_op_kwargs'])

        current_input_features = base_num_features
        for stage in range(num_stages):
            current_output_features = min(base_num_features * feat_map_mul_on_downscale ** stage, max_num_features)
            current_kernel_size = conv_kernel_sizes[stage]
            current_pool_kernel_size = pool_op_kernel_sizes[stage]

            # with an explicit pooling op the convs must not downsample a second time
            if pool_op is not None:
                pool_kernel_size_for_conv = [1 for _ in current_pool_kernel_size]
            else:
                pool_kernel_size_for_conv = current_pool_kernel_size

            current_stage = PreActResidualLayer(current_input_features, current_output_features, current_kernel_size,
                                                props, self.num_blocks_per_stage[stage], pool_kernel_size_for_conv)
            if pool_op is not None:
                current_stage = nn.Sequential(pool_op(current_pool_kernel_size), current_stage)

            self.stages.append(current_stage)
            self.stage_output_features.append(current_output_features)
            self.stage_conv_op_kernel_size.append(current_kernel_size)
            self.stage_pool_kernel_size.append(current_pool_kernel_size)

            # update current_input_features
            current_input_features = current_output_features

        self.stages = nn.ModuleList(self.stages)
        self.output_features = current_input_features

    def _handle_pool(self, pool_type):
        """
        Map pool_type to a pooling class matching the dimensionality of self.props['conv_op'].

        :return: nn.AvgPool2d/3d, nn.MaxPool2d/3d, or None for 'conv' (strided convolutions downsample instead)
        :raises NotImplementedError: if conv_op is neither nn.Conv2d nor nn.Conv3d for 'avg' / 'max'
        """
        assert pool_type in ['conv', 'avg', 'max']
        if pool_type == 'avg':
            if self.props['conv_op'] == nn.Conv2d:
                pool_op = nn.AvgPool2d
            elif self.props['conv_op'] == nn.Conv3d:
                pool_op = nn.AvgPool3d
            else:
                raise NotImplementedError
        elif pool_type == 'max':
            if self.props['conv_op'] == nn.Conv2d:
                pool_op = nn.MaxPool2d
            elif self.props['conv_op'] == nn.Conv3d:
                pool_op = nn.MaxPool3d
            else:
                raise NotImplementedError
        elif pool_type == 'conv':
            pool_op = None
        else:
            raise ValueError
        return pool_op

    def forward(self, x, return_skips=None):
        """
        :param x: network input
        :param return_skips: if None then self.default_return_skips is used
        :return: the list of all stage outputs (bottleneck last) if return_skips, else the bottleneck output
        """
        # resolve the flag first: collecting skips must follow the argument, not only the default
        if return_skips is None:
            return_skips = self.default_return_skips

        skips = []
        x = self.initial_conv(x)
        for s in self.stages:
            x = s(x)
            if return_skips:
                skips.append(x)

        if return_skips:
            return skips
        else:
            return x

    @staticmethod
    def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                        num_modalities, pool_op_kernel_sizes, num_conv_per_stage_encoder,
                                        feat_map_mul_on_downscale, batch_size):
        """
        Not the real VRAM consumption; a term to which the encoder's activation memory is approximately
        proportional. Each stage counts (2 * blocks + 1) feature maps of its resolution and width, and the
        input itself is counted once.

        :return: the estimate multiplied by batch_size
        """
        npool = len(pool_op_kernel_sizes) - 1

        current_shape = np.array(patch_size)

        tmp = (num_conv_per_stage_encoder[0] * 2 + 1) * np.prod(current_shape) * base_num_features \
              + num_modalities * np.prod(current_shape)

        num_feat = base_num_features

        for p in range(1, npool + 1):
            current_shape = current_shape / np.array(pool_op_kernel_sizes[p])
            num_feat = min(num_feat * feat_map_mul_on_downscale, max_num_features)
            num_convs = num_conv_per_stage_encoder[p] * 2 + 1  # + 1 for conv in skip in first block
            tmp += num_convs * np.prod(current_shape) * num_feat
        return tmp * batch_size
class PreActResidualUNetDecoder(nn.Module):
    def __init__(self, previous, num_classes, num_blocks_per_stage=None, network_props=None, deep_supervision=False,
                 upscale_logits=False):
        """
        Decoder of the pre-activation residual UNet.

        We assume the bottleneck is part of the encoder, so we can start with upsample -> concat here.

        :param previous: the encoder; its stages, stage_output_features, stage_pool_kernel_size,
        stage_conv_op_kernel_size, num_blocks_per_stage and props are read
        :param num_classes: number of output channels of the segmentation heads
        :param num_blocks_per_stage: residual blocks per decoder stage, ordered from the stage just above the
        bottleneck up to the highest-resolution stage. Defaults to the encoder's values (without the bottleneck)
        in reverse order.
        :param network_props: property dict; defaults to previous.props
        :param deep_supervision: if True, forward returns one segmentation per decoder stage instead of only the
        final one
        :param upscale_logits: with deep_supervision, upsample the auxiliary segmentations by the cumulative pooling
        factor of their stage
        """
        super().__init__()
        self.num_classes = num_classes
        self.deep_supervision = deep_supervision

        previous_stages = previous.stages
        previous_stage_output_features = previous.stage_output_features
        previous_stage_pool_kernel_size = previous.stage_pool_kernel_size
        previous_stage_conv_op_kernel_size = previous.stage_conv_op_kernel_size

        if network_props is None:
            self.props = previous.props
        else:
            self.props = network_props

        if self.props['conv_op'] == nn.Conv2d:
            transpconv = nn.ConvTranspose2d
            upsample_mode = "bilinear"
        elif self.props['conv_op'] == nn.Conv3d:
            transpconv = nn.ConvTranspose3d
            upsample_mode = "trilinear"
        else:
            raise ValueError("unknown convolution dimensionality, conv op: %s" % str(self.props['conv_op']))

        if num_blocks_per_stage is None:
            num_blocks_per_stage = previous.num_blocks_per_stage[:-1][::-1]

        assert len(num_blocks_per_stage) == len(previous.num_blocks_per_stage) - 1

        self.stage_pool_kernel_size = previous_stage_pool_kernel_size
        self.stage_output_features = previous_stage_output_features
        self.stage_conv_op_kernel_size = previous_stage_conv_op_kernel_size

        # we have one less as the first stage here is what comes after the bottleneck
        num_stages = len(previous_stages) - 1

        self.tus = []
        self.stages = []
        self.deep_supervision_outputs = []

        # only used for upscale_logits: cumulative downsampling factor of every encoder stage
        cum_upsample = np.cumprod(np.vstack(self.stage_pool_kernel_size), axis=0).astype(int)

        # build from the bottleneck upwards: i indexes decoder stages, s is the encoder stage being mirrored
        for i, s in enumerate(np.arange(num_stages)[::-1]):
            features_below = previous_stage_output_features[s + 1]
            features_skip = previous_stage_output_features[s]

            self.tus.append(transpconv(features_below, features_skip, previous_stage_pool_kernel_size[s + 1],
                                       previous_stage_pool_kernel_size[s + 1], bias=False))
            # after we tu we concat features so now we have 2xfeatures_skip
            self.stages.append(PreActResidualLayer(2 * features_skip, features_skip,
                                                   previous_stage_conv_op_kernel_size[s],
                                                   self.props, num_blocks_per_stage[i], None))

            # auxiliary heads for every stage except the top one (s == 0 uses segmentation_output below)
            if deep_supervision and s != 0:
                norm = self.props['norm_op'](features_skip, **self.props['norm_op_kwargs'])
                nonlin = self.props['nonlin'](**self.props['nonlin_kwargs'])
                seg_layer = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, bias=True)
                if upscale_logits:
                    upsample = Upsample(scale_factor=cum_upsample[s], mode=upsample_mode)
                    self.deep_supervision_outputs.append(nn.Sequential(norm, nonlin, seg_layer, upsample))
                else:
                    self.deep_supervision_outputs.append(nn.Sequential(norm, nonlin, seg_layer))

        # features_skip now holds the width of the highest-resolution stage (last loop iteration)
        self.segmentation_conv_norm = self.props['norm_op'](features_skip, **self.props['norm_op_kwargs'])
        self.segmentation_conv_nonlin = self.props['nonlin'](**self.props['nonlin_kwargs'])
        self.segmentation_output = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, bias=True)
        self.segmentation_output = nn.Sequential(self.segmentation_conv_norm, self.segmentation_conv_nonlin,
                                                 self.segmentation_output)

        self.tus = nn.ModuleList(self.tus)
        self.stages = nn.ModuleList(self.stages)
        self.deep_supervision_outputs = nn.ModuleList(self.deep_supervision_outputs)

    def forward(self, skips):
        """
        :param skips: encoder outputs, sorted so that the bottleneck is last in the list
        :return: the segmentation; with deep supervision a list of segmentations ordered so that the one from the
        highest-resolution stage is first
        """
        # the TUs and stages here are sorted the other way around (bottleneck first), so reverse the skips
        skips = skips[::-1]
        seg_outputs = []

        x = skips[0]  # this is the bottleneck

        for i in range(len(self.tus)):
            x = self.tus[i](x)
            x = torch.cat((x, skips[i + 1]), dim=1)
            x = self.stages[i](x)
            if self.deep_supervision and (i != len(self.tus) - 1):
                seg_outputs.append(self.deep_supervision_outputs[i](x))

        segmentation = self.segmentation_output(x)

        if self.deep_supervision:
            seg_outputs.append(segmentation)
            # seg_outputs are ordered so that the seg from the highest layer is first, the seg from
            # the bottleneck of the UNet last
            return seg_outputs[::-1]
        else:
            return segmentation

    @staticmethod
    def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                        num_classes, pool_op_kernel_sizes, num_blocks_per_stage_decoder,
                                        feat_map_mul_on_downscale, batch_size):
        """
        This only applies for num_conv_per_stage and convolutional_upsampling=True
        not real vram consumption. just a constant term to which the vram consumption will be approx proportional
        (+ offset for parameter storage)

        :param patch_size: spatial size of the network input
        :param base_num_features: feature maps at the highest resolution
        :param max_num_features: upper bound on the number of feature maps
        :param num_classes: number of output channels
        :param pool_op_kernel_sizes: per-stage downsampling factors of the encoder
        :param num_blocks_per_stage_decoder: residual blocks per decoder stage (highest-resolution stage last)
        :param feat_map_mul_on_downscale: feature map growth factor per stage
        :param batch_size: the estimate is multiplied by this
        :return: the estimate
        """
        npool = len(pool_op_kernel_sizes) - 1

        current_shape = np.array(patch_size)
        tmp = (num_blocks_per_stage_decoder[-1] * 2 + 1) * np.prod(current_shape) * base_num_features \
              + num_classes * np.prod(current_shape)

        num_feat = base_num_features

        # the bottleneck (p == npool) belongs to the encoder and is not counted here
        for p in range(1, npool):
            current_shape = current_shape / np.array(pool_op_kernel_sizes[p])
            num_feat = min(num_feat * feat_map_mul_on_downscale, max_num_features)
            num_convs = num_blocks_per_stage_decoder[-(p + 1)] * 2 + 1 + 1  # +1 for transpconv and +1 for conv in skip
            tmp += num_convs * np.prod(current_shape) * num_feat

        return tmp * batch_size
class PreActResidualUNet(SegmentationNetwork):
    """
    Segmentation network made of a PreActResidualUNetEncoder (which includes the bottleneck) followed by a
    PreActResidualUNetDecoder.
    """
    use_this_for_batch_size_computation_2D = 858931200.0  # 1167982592.0
    use_this_for_batch_size_computation_3D = 727842816.0  # 1152286720.0

    def __init__(self, input_channels, base_num_features, num_blocks_per_stage_encoder, feat_map_mul_on_downscale,
                 pool_op_kernel_sizes, conv_kernel_sizes, props, num_classes, num_blocks_per_stage_decoder,
                 deep_supervision=False, upscale_logits=False, max_features=512, initializer=None):
        super().__init__()
        self.conv_op = props['conv_op']
        self.num_classes = num_classes

        self.encoder = PreActResidualUNetEncoder(
            input_channels,
            base_num_features,
            num_blocks_per_stage_encoder,
            feat_map_mul_on_downscale,
            pool_op_kernel_sizes,
            conv_kernel_sizes,
            props,
            default_return_skips=True,
            max_num_features=max_features,
        )
        self.decoder = PreActResidualUNetDecoder(
            self.encoder,
            num_classes,
            num_blocks_per_stage_decoder,
            props,
            deep_supervision,
            upscale_logits,
        )

        if initializer is not None:
            self.apply(initializer)

    def forward(self, x):
        """Encode x into its list of skips, then decode them into the segmentation output(s)."""
        return self.decoder(self.encoder(x))

    @staticmethod
    def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                        num_modalities, num_classes, pool_op_kernel_sizes, num_conv_per_stage_encoder,
                                        num_conv_per_stage_decoder, feat_map_mul_on_downscale, batch_size):
        """Sum of the encoder's and the decoder's approximate VRAM terms."""
        encoder_term = PreActResidualUNetEncoder.compute_approx_vram_consumption(
            patch_size, base_num_features, max_num_features, num_modalities, pool_op_kernel_sizes,
            num_conv_per_stage_encoder, feat_map_mul_on_downscale, batch_size)
        decoder_term = PreActResidualUNetDecoder.compute_approx_vram_consumption(
            patch_size, base_num_features, max_num_features, num_classes, pool_op_kernel_sizes,
            num_conv_per_stage_decoder, feat_map_mul_on_downscale, batch_size)
        return encoder_term + decoder_term

    @staticmethod
    def compute_reference_for_vram_consumption_3d():
        """Reference estimate: 128^3 patch, six stages, one block per stage, batch size 2."""
        pools = ((1, 1, 1),) + ((2, 2, 2),) * 5
        encoder_blocks = (1,) * 6
        decoder_blocks = (1,) * 5
        return PreActResidualUNet.compute_approx_vram_consumption((128, 128, 128), 20, 512, 4, 3, pools,
                                                                  encoder_blocks, decoder_blocks, 2, 2)

    @staticmethod
    def compute_reference_for_vram_consumption_2d():
        """Reference estimate: 256^2 patch, seven stages (256 down to 4), one block per stage, batch size 50."""
        pools = ((1, 1),) + ((2, 2),) * 6
        encoder_blocks = (1,) * 7
        decoder_blocks = (1,) * 6
        return PreActResidualUNet.compute_approx_vram_consumption((256, 256), 20, 512, 4, 3, pools,
                                                                  encoder_blocks, decoder_blocks, 2, 50)
class FabiansPreActUNet(SegmentationNetwork):
    """
    Pre-activation residual encoder (PreActResidualUNetEncoder) combined with a PlainConvUNetDecoder.

    Every encoder output except the bottleneck passes through its own norm + nonlin before it is handed to the
    decoder. The pre-activation blocks end with conv + residual add, so their raw outputs are not normalized.
    """
    use_this_for_2D_configuration = 1792460800
    use_this_for_3D_configuration = 1318592512
    default_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6)
    default_blocks_per_stage_decoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
    default_min_batch_size = 2  # this is what works with the numbers above

    def __init__(self, input_channels, base_num_features, num_blocks_per_stage_encoder, feat_map_mul_on_downscale,
                 pool_op_kernel_sizes, conv_kernel_sizes, props, num_classes, num_blocks_per_stage_decoder,
                 deep_supervision=False, upscale_logits=False, max_features=512, initializer=None):
        super().__init__()
        self.conv_op = props['conv_op']
        self.num_classes = num_classes

        self.encoder = PreActResidualUNetEncoder(input_channels, base_num_features, num_blocks_per_stage_encoder,
                                                 feat_map_mul_on_downscale, pool_op_kernel_sizes, conv_kernel_sizes,
                                                 props, default_return_skips=True, max_num_features=max_features)

        # The decoder runs without dropout. Use a copy so that neither the caller's dict nor the encoder (which
        # keeps a reference to `props`) sees the change.
        decoder_props = deepcopy(props)
        decoder_props['dropout_op_kwargs']['p'] = 0
        self.decoder = PlainConvUNetDecoder(self.encoder, num_classes, num_blocks_per_stage_decoder, decoder_props,
                                            deep_supervision, upscale_logits)

        # One norm + nonlin per skip, excluding the bottleneck (last encoder stage). The channel counts come from
        # the encoder itself, so they stay correct for any feat_map_mul_on_downscale (not only 2).
        expected_num_skips = len(conv_kernel_sizes) - 1
        num_features_skips = self.encoder.stage_output_features[:expected_num_skips]
        self.norm_nonlins = nn.ModuleList([
            nn.Sequential(props['norm_op'](num_features, **props['norm_op_kwargs']),
                          props['nonlin'](**props['nonlin_kwargs']))
            for num_features in num_features_skips
        ])

        if initializer is not None:
            self.apply(initializer)

    def forward(self, x, gt=None, loss=None):
        """
        :param x: network input
        :param gt: forwarded unchanged to the decoder
        :param loss: forwarded unchanged to the decoder
        :return: whatever the decoder returns
        """
        skips = self.encoder(x)
        for i, norm_nonlin in enumerate(self.norm_nonlins):
            skips[i] = norm_nonlin(skips[i])
        return self.decoder(skips, gt, loss)

    @staticmethod
    def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                        num_modalities, num_classes, pool_op_kernel_sizes, num_blocks_per_stage_encoder,
                                        num_blocks_per_stage_decoder, feat_map_mul_on_downscale, batch_size):
        """Sum of the encoder's and the PlainConvUNetDecoder's approximate VRAM terms."""
        enc = PreActResidualUNetEncoder.compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                                                        num_modalities, pool_op_kernel_sizes,
                                                                        num_blocks_per_stage_encoder,
                                                                        feat_map_mul_on_downscale, batch_size)
        dec = PlainConvUNetDecoder.compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                                                   num_classes, pool_op_kernel_sizes,
                                                                   num_blocks_per_stage_decoder,
                                                                   feat_map_mul_on_downscale, batch_size)

        return enc + dec
def find_3d_configuration():
    """
    Developer probe for the 3D configuration of FabiansPreActUNet.

    Prints the approximate VRAM estimate and then runs 10 mixed-precision SGD steps on random data, printing the
    loss of each step. Afterwards it renders the network graph with hiddenlayer to a hard-coded PDF path.
    Requires a CUDA device and the hiddenlayer package.
    """
    cudnn.benchmark = True
    cudnn.deterministic = False

    conv_op_kernel_sizes = ((3, 3, 3),
                            (3, 3, 3),
                            (3, 3, 3),
                            (3, 3, 3),
                            (3, 3, 3),
                            (3, 3, 3))
    pool_op_kernel_sizes = ((1, 1, 1),
                            (2, 2, 2),
                            (2, 2, 2),
                            (2, 2, 2),
                            (2, 2, 2),
                            (2, 2, 2))

    patch_size = (128, 128, 128)
    base_num_features = 32
    input_modalities = 4
    blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6)
    blocks_per_stage_decoder = (2, 2, 2, 2, 2)
    feat_map_mult_on_downscale = 2
    num_classes = 5
    max_features = 320
    batch_size = 2

    unet = FabiansPreActUNet(input_modalities, base_num_features, blocks_per_stage_encoder, feat_map_mult_on_downscale,
                             pool_op_kernel_sizes, conv_op_kernel_sizes, get_default_network_config(3, dropout_p=None),
                             num_classes, blocks_per_stage_decoder, True, False, max_features=max_features).cuda()

    scaler = GradScaler()
    optimizer = SGD(unet.parameters(), lr=0.1, momentum=0.95)

    print(unet.compute_approx_vram_consumption(patch_size, base_num_features, max_features, input_modalities,
                                               num_classes, pool_op_kernel_sizes, blocks_per_stage_encoder,
                                               blocks_per_stage_decoder, feat_map_mult_on_downscale, batch_size))

    loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'do_bg': False}, {})

    dummy_input = torch.rand((batch_size, input_modalities, *patch_size)).cuda()
    dummy_gt = (torch.rand((batch_size, 1, *patch_size)) * num_classes).round().clamp_(0, num_classes - 1).cuda().long()

    for _ in range(10):
        optimizer.zero_grad()

        with autocast():
            # Go through unet.forward: calling unet.encoder + unet.decoder directly skipped the skip
            # norm/nonlin layers, so the probe measured a different graph than the real network.
            output = unet(dummy_input)[0]
            train_loss = loss(output, dummy_gt)

        print(train_loss.item())
        scaler.scale(train_loss).backward()
        scaler.step(optimizer)
        scaler.update()

    with autocast():
        import hiddenlayer as hl
        g = hl.build_graph(unet, dummy_input, transforms=None)
        # NOTE(review): hard-coded developer path
        g.save("/home/fabian/test_arch.pdf")
def find_2d_configuration():
    """
    Developer probe for the 2D configuration of FabiansPreActUNet.

    Prints the approximate VRAM estimate and then runs 10 mixed-precision SGD steps on random data, printing the
    loss of each step. Afterwards it renders the network graph with hiddenlayer to a hard-coded PDF path.
    Requires a CUDA device and the hiddenlayer package.
    """
    cudnn.benchmark = True
    cudnn.deterministic = False

    conv_op_kernel_sizes = ((3, 3),
                            (3, 3),
                            (3, 3),
                            (3, 3),
                            (3, 3),
                            (3, 3),
                            (3, 3))
    pool_op_kernel_sizes = ((1, 1),
                            (2, 2),
                            (2, 2),
                            (2, 2),
                            (2, 2),
                            (2, 2),
                            (2, 2))

    patch_size = (256, 256)
    base_num_features = 32
    input_modalities = 4
    blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6)
    blocks_per_stage_decoder = (2, 2, 2, 2, 2, 2)
    feat_map_mult_on_downscale = 2
    num_classes = 5
    max_features = 512
    batch_size = 50

    unet = FabiansPreActUNet(input_modalities, base_num_features, blocks_per_stage_encoder, feat_map_mult_on_downscale,
                             pool_op_kernel_sizes, conv_op_kernel_sizes, get_default_network_config(2, dropout_p=None),
                             num_classes, blocks_per_stage_decoder, True, False, max_features=max_features).cuda()

    scaler = GradScaler()
    optimizer = SGD(unet.parameters(), lr=0.1, momentum=0.95)

    print(unet.compute_approx_vram_consumption(patch_size, base_num_features, max_features, input_modalities,
                                               num_classes, pool_op_kernel_sizes, blocks_per_stage_encoder,
                                               blocks_per_stage_decoder, feat_map_mult_on_downscale, batch_size))

    loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'do_bg': False}, {})

    dummy_input = torch.rand((batch_size, input_modalities, *patch_size)).cuda()
    dummy_gt = (torch.rand((batch_size, 1, *patch_size)) * num_classes).round().clamp_(0, num_classes - 1).cuda().long()

    for _ in range(10):
        optimizer.zero_grad()

        with autocast():
            # Go through unet.forward: calling unet.encoder + unet.decoder directly skipped the skip
            # norm/nonlin layers, so the probe measured a different graph than the real network.
            output = unet(dummy_input)[0]
            train_loss = loss(output, dummy_gt)

        print(train_loss.item())
        scaler.scale(train_loss).backward()
        scaler.step(optimizer)
        scaler.update()

    with autocast():
        import hiddenlayer as hl
        g = hl.build_graph(unet, dummy_input, transforms=None)
        # NOTE(review): hard-coded developer path
        g.save("/home/fabian/test_arch.pdf")
|