Spaces:
Runtime error
Runtime error
File size: 27,282 Bytes
a983ebc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 |
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_layers.ipynb.
# %% ../nbs/01_layers.ipynb 2
from __future__ import annotations
from .imports import *
from .torch_imports import *
from .torch_core import *
from torch.nn.utils import weight_norm, spectral_norm
# %% auto 0
# Public names exported by `from <this module> import *`. This file is autogenerated from
# ../nbs/01_layers.ipynb (see header), so change the notebook rather than editing this list by hand.
__all__ = ['NormType', 'inplace_relu', 'module', 'Identity', 'Lambda', 'PartialLambda', 'Flatten', 'ToTensorBase', 'View',
           'ResizeBatch', 'Debugger', 'sigmoid_range', 'SigmoidRange', 'AdaptiveConcatPool1d', 'AdaptiveConcatPool2d',
           'PoolType', 'adaptive_pool', 'PoolFlatten', 'BatchNorm', 'InstanceNorm', 'BatchNorm1dFlat', 'LinBnDrop',
           'sigmoid', 'sigmoid_', 'vleaky_relu', 'init_default', 'init_linear', 'ConvLayer', 'AdaptiveAvgPool',
           'MaxPool', 'AvgPool', 'trunc_normal_', 'Embedding', 'SelfAttention', 'PooledSelfAttention2d',
           'SimpleSelfAttention', 'icnr_init', 'PixelShuffle_ICNR', 'sequential', 'SequentialEx', 'MergeLayer', 'Cat',
           'SimpleCNN', 'ProdLayer', 'SEModule', 'ResBlock', 'SEBlock', 'SEResNeXtBlock', 'SeparableBlock',
           'TimeDistributed', 'swish', 'Swish', 'MishJitAutoFn', 'mish', 'Mish', 'ParameterModule',
           'children_and_parameters', 'has_children', 'flatten_model', 'NoneReduce', 'in_channels']
# %% ../nbs/01_layers.ipynb 6
def module(*flds, **defaults):
    "Decorator to create an `nn.Module` using `f` as `forward` method"
    # Fields without defaults come first, then fields with defaults; all of them are exposed
    # as POSITIONAL_OR_KEYWORD parameters in the generated class's signature.
    params = [inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD) for name in flds]
    params += [inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD, default=val)
               for name, val in defaults.items()]
    field_names = [*flds, *defaults]

    def _wrap(f):
        class c(nn.Module):
            def __init__(self, *args, **kwargs):
                super().__init__()
                # Positional arguments fill the fields in declaration order (too many -> IndexError).
                for pos, val in enumerate(args):
                    kwargs[field_names[pos]] = val
                # `merge` (fastcore) combines `defaults` with the explicit arguments; every
                # resulting entry becomes an instance attribute.
                for name, val in merge(defaults, kwargs).items():
                    setattr(self, name, val)
            __repr__ = basic_repr(field_names)
            forward = f
        c.__signature__ = inspect.Signature(params)
        c.__name__ = c.__qualname__ = f.__name__
        c.__doc__ = f.__doc__
        return c
    return _wrap
# %% ../nbs/01_layers.ipynb 7
@module()
def Identity(self, x):
    "Layer that returns its input unchanged"
    # Hand back the very same object that was received.
    return x
# %% ../nbs/01_layers.ipynb 9
@module('func')
def Lambda(self, x):
    "Wrap a plain callable `func` so it can be used as a pytorch layer"
    fn = self.func  # stored by the generated `__init__`
    return fn(x)
# %% ../nbs/01_layers.ipynb 11
class PartialLambda(Lambda):
    "Layer that applies `partial(func, **kwargs)`, remembering the bound arguments for its repr"
    def __init__(self, func, **kwargs):
        bound = partial(func, **kwargs)
        super().__init__(bound)
        # Human-readable description of the wrapped call, used by `__repr__`.
        self.repr = f'{func.__name__}, {kwargs}'

    def forward(self, x):
        return self.func(x)

    def __repr__(self):
        return f'{type(self).__name__}({self.repr})'
# %% ../nbs/01_layers.ipynb 13
@module(full=False)
def Flatten(self, x):
    "Flatten `x` to one dimension when `full`, otherwise to `(batch, -1)`, e.g. at end of a model"
    # The result is a plain `view`; no cast to TensorBase is applied.
    if self.full:
        return x.view(-1)
    return x.view(x.size(0), -1)
# %% ../nbs/01_layers.ipynb 15
@module(tensor_cls=TensorBase)
def ToTensorBase(self, x):
    "Wrap `x` in `tensor_cls` (`TensorBase` unless configured otherwise)"
    cls = self.tensor_cls
    return cls(x)
# %% ../nbs/01_layers.ipynb 17
class View(Module):
    "Reshape `x` to the fixed shape `size`"
    def __init__(self, *size):
        # Kept as a tuple and handed verbatim to `Tensor.view`.
        self.size = size

    def forward(self, x):
        target = self.size
        return x.view(target)
# %% ../nbs/01_layers.ipynb 19
class ResizeBatch(Module):
    "Reshape `x` to `(batch, *size)`, keeping the batch dimension unchanged"
    def __init__(self, *size):
        self.size = size

    def forward(self, x):
        bs = x.size(0)
        return x.view((bs, *self.size))
# %% ../nbs/01_layers.ipynb 21
@module()
def Debugger(self, x):
    "Stop in the debugger (`set_trace`) during the forward pass, then pass `x` through unchanged"
    set_trace()  # inspect `self` / `x` interactively here
    return x
# %% ../nbs/01_layers.ipynb 22
def sigmoid_range(x, low, high):
    "Sigmoid of `x` rescaled to the open interval `(low, high)`"
    span = high - low
    return torch.sigmoid(x) * span + low
# %% ../nbs/01_layers.ipynb 24
@module('low','high')
def SigmoidRange(self, x):
    "Module form of `sigmoid_range`, squashing `x` into `(low, high)`"
    lo, hi = self.low, self.high
    return sigmoid_range(x, lo, hi)
# %% ../nbs/01_layers.ipynb 27
class AdaptiveConcatPool1d(Module):
    "Concatenate adaptive max- and average-pooling (1d) along the channel dimension"
    def __init__(self, size=None):
        # Any falsy `size` (None, 0) means an output length of 1.
        self.size = size or 1
        self.ap, self.mp = nn.AdaptiveAvgPool1d(self.size), nn.AdaptiveMaxPool1d(self.size)

    def forward(self, x):
        # Max-pooled features first, then average-pooled, joined on dim 1.
        pooled = (self.mp(x), self.ap(x))
        return torch.cat(pooled, 1)
# %% ../nbs/01_layers.ipynb 28
class AdaptiveConcatPool2d(Module):
    "Concatenate adaptive max- and average-pooling (2d) along the channel dimension"
    def __init__(self, size=None):
        # Any falsy `size` (None, 0) means a 1x1 output.
        self.size = size or 1
        self.ap, self.mp = nn.AdaptiveAvgPool2d(self.size), nn.AdaptiveMaxPool2d(self.size)

    def forward(self, x):
        # Max-pooled features first, then average-pooled, joined on dim 1.
        pooled = (self.mp(x), self.ap(x))
        return torch.cat(pooled, 1)
# %% ../nbs/01_layers.ipynb 31
class PoolType:
    "String constants naming the pooling kinds understood by `adaptive_pool`"
    Avg = 'Avg'
    Max = 'Max'
    Cat = 'Cat'
# %% ../nbs/01_layers.ipynb 32
def adaptive_pool(pool_type):
    "Return the adaptive 2d pooling class for `pool_type`: 'Avg', 'Max', anything else -> concat pooling"
    if pool_type == 'Avg':
        return nn.AdaptiveAvgPool2d
    if pool_type == 'Max':
        return nn.AdaptiveMaxPool2d
    return AdaptiveConcatPool2d
# %% ../nbs/01_layers.ipynb 33
class PoolFlatten(nn.Sequential):
    "Adaptive pooling to 1x1 (kind chosen by `pool_type`, average by default) followed by `Flatten`."
    def __init__(self, pool_type=PoolType.Avg):
        pool_cls = adaptive_pool(pool_type)
        super().__init__(pool_cls(1), Flatten())
# %% ../nbs/01_layers.ipynb 36
# Normalization choices for the layers below; the "Zero" variants zero-initialize the norm weight.
NormType = Enum('NormType', ['Batch', 'BatchZero', 'Weight', 'Spectral', 'Instance', 'InstanceZero'])
# %% ../nbs/01_layers.ipynb 37
def _get_norm(prefix, nf, ndim=2, zero=False, **kwargs):
    "Build `nn.{prefix}{ndim}d(nf, **kwargs)`; if affine, set bias to 1e-3 and weight to 0 (`zero`) or 1."
    assert 1 <= ndim <= 3
    layer_cls = getattr(nn, f"{prefix}{ndim}d")
    norm = layer_cls(nf, **kwargs)
    if not norm.affine:
        return norm
    norm.bias.data.fill_(1e-3)
    norm.weight.data.fill_(0. if zero else 1.)
    return norm
# %% ../nbs/01_layers.ipynb 38
@delegates(nn.BatchNorm2d)
def BatchNorm(nf, ndim=2, norm_type=NormType.Batch, **kwargs):
    "BatchNorm layer with `nf` features and `ndim`; `NormType.BatchZero` zero-initializes the weight."
    zero_init = norm_type == NormType.BatchZero
    return _get_norm('BatchNorm', nf, ndim, zero=zero_init, **kwargs)
# %% ../nbs/01_layers.ipynb 39
@delegates(nn.InstanceNorm2d)
def InstanceNorm(nf, ndim=2, norm_type=NormType.Instance, affine=True, **kwargs):
    "InstanceNorm layer with `nf` features and `ndim` (affine by default); `NormType.InstanceZero` zero-initializes the weight."
    zero_init = norm_type == NormType.InstanceZero
    return _get_norm('InstanceNorm', nf, ndim, zero=zero_init, affine=affine, **kwargs)
# %% ../nbs/01_layers.ipynb 45
class BatchNorm1dFlat(nn.BatchNorm1d):
    "`nn.BatchNorm1d` over the last dimension, treating all leading dimensions as batch"
    def forward(self, x):
        # Already (N, C): defer to the stock implementation.
        if x.dim() == 2:
            return super().forward(x)
        # Merge the leading dims into one batch dim, normalise, then restore the original shape.
        *lead, n_feat = x.shape
        flat = x.contiguous().view(-1, n_feat)
        return super().forward(flat).view(*lead, n_feat)
# %% ../nbs/01_layers.ipynb 47
class LinBnDrop(nn.Sequential):
    "`BatchNorm1d` (if `bn`) and `Dropout` (if `p`) placed before the `Linear`(+`act`), or after it when `lin_first`"
    def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=False):
        # Norm/dropout sit on the input side by default, so the norm width follows `lin_first`.
        extras = []
        if bn:
            extras.append(BatchNorm(n_out if lin_first else n_in, ndim=1))
        if p != 0:
            extras.append(nn.Dropout(p))
        # The Linear layer has no bias whenever `bn` is set.
        linear = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None:
            linear.append(act)
        ordered = linear + extras if lin_first else extras + linear
        super().__init__(*ordered)
# %% ../nbs/01_layers.ipynb 51
def sigmoid(input, eps=1e-7):
    "Same as `torch.sigmoid`, plus clamping to `(eps, 1-eps)`"
    upper = 1 - eps
    return input.sigmoid().clamp(eps, upper)
# %% ../nbs/01_layers.ipynb 52
def sigmoid_(input, eps=1e-7):
    "In-place version of `sigmoid`: `torch.sigmoid_` followed by an in-place clamp to `(eps, 1-eps)`"
    upper = 1 - eps
    return input.sigmoid_().clamp_(eps, upper)
# %% ../nbs/01_layers.ipynb 53
from torch.nn.init import kaiming_uniform_,uniform_,xavier_uniform_,normal_
# %% ../nbs/01_layers.ipynb 54
def vleaky_relu(input, inplace=True):
    "Very leaky ReLU: `F.leaky_relu` with a 0.3 negative slope (in place by default)"
    slope = 0.3
    return F.leaky_relu(input, negative_slope=slope, inplace=inplace)
# %% ../nbs/01_layers.ipynb 55
# Mark ReLU-family activations (functional and module forms) as defaulting to Kaiming-uniform
# weight init; `init_linear` reads this `__default_init__` attribute when `init='auto'`.
for o in F.relu,nn.ReLU,F.relu6,nn.ReLU6,F.leaky_relu,nn.LeakyReLU:
    o.__default_init__ = kaiming_uniform_
# %% ../nbs/01_layers.ipynb 56
# Sigmoid/tanh activations (including this module's clamped `sigmoid`/`sigmoid_`) default to
# Xavier-uniform weight init via the same `__default_init__` attribute read by `init_linear`.
for o in F.sigmoid,nn.Sigmoid,F.tanh,nn.Tanh,sigmoid,sigmoid_:
    o.__default_init__ = xavier_uniform_
# %% ../nbs/01_layers.ipynb 57
def init_default(m, func=nn.init.kaiming_normal_):
    "Initialize `m.weight` with `func` and zero `m.bias`; returns `m`."
    # Weight init is skipped when `func` is falsy or `m` has no `weight` attribute.
    if func and hasattr(m, 'weight'):
        func(m.weight)
    # Zero the bias outside autograd. `nested_callable` (fastcore) presumably resolves to a
    # no-op when `m` has no usable `bias` — confirm against fastcore.
    with torch.no_grad():
        zero_bias = nested_callable(m, 'bias.fill_')
        zero_bias(0.)
    return m
# %% ../nbs/01_layers.ipynb 58
def init_linear(m, act_func=None, init='auto', bias_std=0.01):
    "Init `m`: bias ~ N(0, `bias_std`) (zeroed if 0, untouched if None); weight via `init`, where 'auto' derives it from `act_func`."
    has_bias = getattr(m, 'bias', None) is not None
    if has_bias and bias_std is not None:
        if bias_std == 0:
            m.bias.data.zero_()
        else:
            normal_(m.bias, 0, bias_std)
    if init == 'auto':
        if act_func in (F.relu_, F.leaky_relu_):
            init = kaiming_uniform_
        else:
            # Prefer a `__default_init__` on the activation's class, then on the activation itself.
            init = nested_callable(act_func, '__class__.__default_init__')
            if init == noop:
                init = getcallable(act_func, '__default_init__')
    if callable(init):
        init(m.weight)
# %% ../nbs/01_layers.ipynb 60
def _conv_func(ndim=2, transpose=False):
    "Return the `nn` convolution class for `ndim` (1-3), the transposed variant if `transpose`."
    assert 1 <= ndim <= 3
    kind = 'ConvTranspose' if transpose else 'Conv'
    return getattr(nn, f'{kind}{ndim}d')
# %% ../nbs/01_layers.ipynb 62
# Default activation class for the layers below. Functions/classes defined later read it as a
# default-argument value (e.g. `act_cls=defaults.activation` in `ConvLayer`), which is evaluated
# once at definition time, so reassigning it afterwards does not change those defaults.
defaults.activation=nn.ReLU
# %% ../nbs/01_layers.ipynb 63
class ConvLayer(nn.Sequential):
    "Create a sequence of convolution (`ni` to `nf`), optional activation (`act_cls`) and `norm_type` layers."
    @delegates(nn.Conv2d)
    def __init__(self, ni, nf, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=NormType.Batch, bn_1st=True,
                 act_cls=defaults.activation, transpose=False, init='auto', xtra=None, bias_std=0.01, **kwargs):
        # Default padding (ks-1)//2 preserves the spatial size for odd `ks` with stride 1 and no dilation;
        # transposed convs default to 0.
        if padding is None: padding = ((ks-1)//2 if not transpose else 0)
        bn = norm_type in (NormType.Batch, NormType.BatchZero)
        inn = norm_type in (NormType.Instance, NormType.InstanceZero)
        # Unless set explicitly, the conv only gets a bias when no batch/instance norm follows it.
        if bias is None: bias = not (bn or inn)
        conv_func = _conv_func(ndim, transpose=transpose)
        conv = conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding, **kwargs)
        act = None if act_cls is None else act_cls()
        # Initialise the raw conv (init choice driven by `act`) BEFORE any weight/spectral-norm
        # wrapper reparametrises its weight.
        init_linear(conv, act, init=init, bias_std=bias_std)
        if norm_type==NormType.Weight: conv = weight_norm(conv)
        elif norm_type==NormType.Spectral: conv = spectral_norm(conv)
        layers = [conv]
        act_bn = []
        if act is not None: act_bn.append(act)
        if bn: act_bn.append(BatchNorm(nf, norm_type=norm_type, ndim=ndim))
        if inn: act_bn.append(InstanceNorm(nf, norm_type=norm_type, ndim=ndim))
        # Built as [act, norm]; with `bn_1st` it is flipped so the norm runs before the activation.
        if bn_1st: act_bn.reverse()
        layers += act_bn
        # Optional extra module appended at the very end.
        if xtra: layers.append(xtra)
        super().__init__(*layers)
# %% ../nbs/01_layers.ipynb 77
def AdaptiveAvgPool(sz=1, ndim=2):
    "Build an `nn.AdaptiveAvgPool{ndim}d` with output size `sz` (`ndim` in 1..3)"
    assert 1 <= ndim <= 3
    pool_cls = getattr(nn, f"AdaptiveAvgPool{ndim}d")
    return pool_cls(sz)
# %% ../nbs/01_layers.ipynb 78
def MaxPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    "Build an `nn.MaxPool{ndim}d` (`ndim` in 1..3) with kernel `ks`, `stride`, `padding` and `ceil_mode`"
    assert 1 <= ndim <= 3
    # `ceil_mode` must be forwarded: callers such as `ResBlock` request `ceil_mode=True` so the
    # pooled identity path matches the strided conv path on odd spatial sizes (it was silently dropped).
    return getattr(nn, f"MaxPool{ndim}d")(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)
# %% ../nbs/01_layers.ipynb 79
def AvgPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    "Build an `nn.AvgPool{ndim}d` (`ndim` in 1..3) with kernel `ks`, `stride`, `padding` and `ceil_mode`"
    assert 1 <= ndim <= 3
    pool_cls = getattr(nn, f"AvgPool{ndim}d")
    return pool_cls(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)
# %% ../nbs/01_layers.ipynb 81
def trunc_normal_(x, mean=0., std=1.):
    "Approximate truncated-normal init of `x`, in place; returns `x`"
    # From https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/12
    # Sample N(0, 1), fold anything beyond +/-2 back with fmod, then scale by `std` and shift by `mean`.
    x.normal_()
    x.fmod_(2)
    x.mul_(std)
    return x.add_(mean)
# %% ../nbs/01_layers.ipynb 82
class Embedding(nn.Embedding):
    "`nn.Embedding` (`ni` entries of size `nf`) re-initialised with an approximate truncated normal of scale `std`"
    def __init__(self, ni, nf, std=0.01):
        super().__init__(ni, nf)
        weights = self.weight.data
        trunc_normal_(weights, std=std)
# %% ../nbs/01_layers.ipynb 86
class SelfAttention(Module):
    "Self attention layer for `n_channels`."
    def __init__(self, n_channels):
        # 1x1 spectral-normalised 1d convs: query/key reduce to n_channels//8, value keeps n_channels.
        qk_channels = n_channels // 8
        self.query = self._conv(n_channels, qk_channels)
        self.key = self._conv(n_channels, qk_channels)
        self.value = self._conv(n_channels, n_channels)
        # Weight of the attention branch in the residual sum; starts at 0.
        self.gamma = nn.Parameter(tensor([0.]))

    def _conv(self, n_in, n_out):
        return ConvLayer(n_in, n_out, ks=1, ndim=1, norm_type=NormType.Spectral, act_cls=None, bias=False)

    def forward(self, x):
        # Notation (f, g, h) from the paper.
        shape = x.size()
        flat = x.view(*shape[:2], -1)  # keep the first two dims, flatten the rest
        f, g, h = self.query(flat), self.key(flat), self.value(flat)
        beta = F.softmax(torch.bmm(f.transpose(1, 2), g), dim=1)
        out = self.gamma * torch.bmm(h, beta) + flat
        return out.view(*shape).contiguous()
# %% ../nbs/01_layers.ipynb 95
class PooledSelfAttention2d(Module):
    "Self attention for 2d feature maps, with keys/values max-pooled 2x2 to cut the cost."
    def __init__(self, n_channels):
        self.n_channels = n_channels
        qk_ch, v_ch = n_channels // 8, n_channels // 2
        self.query, self.key, self.value = [self._conv(n_channels, c) for c in (qk_ch, qk_ch, v_ch)]
        self.out = self._conv(v_ch, n_channels)
        # Weight of the attention branch in the residual sum; starts at 0.
        self.gamma = nn.Parameter(tensor([0.]))

    def _conv(self, n_in, n_out):
        return ConvLayer(n_in, n_out, ks=1, norm_type=NormType.Spectral, act_cls=None, bias=False)

    def forward(self, x):
        nc, height, width = self.n_channels, x.shape[2], x.shape[3]
        n_ftrs = height * width
        f = self.query(x).view(-1, nc//8, n_ftrs)
        # Keys/values are 2x2 max-pooled, leaving a quarter of the positions.
        g = F.max_pool2d(self.key(x), [2, 2]).view(-1, nc//8, n_ftrs//4)
        h = F.max_pool2d(self.value(x), [2, 2]).view(-1, nc//2, n_ftrs//4)
        beta = F.softmax(torch.bmm(f.transpose(1, 2), g), -1)
        attended = torch.bmm(h, beta.transpose(1, 2)).view(-1, nc//2, height, width)
        return self.gamma * self.out(attended) + x
# %% ../nbs/01_layers.ipynb 97
def _conv1d_spect(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False):
    "`nn.Conv1d` from `ni` to `no` channels, Kaiming-normal weights, zero bias, wrapped in spectral normalization."
    layer = nn.Conv1d(ni, no, kernel_size=ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(layer.weight)
    if bias:
        layer.bias.data.zero_()
    return spectral_norm(layer)
# %% ../nbs/01_layers.ipynb 98
class SimpleSelfAttention(Module):
    "Simplified self attention: `gamma * (x @ x^T) @ conv(x) + x`, with `x` flattened past its first two dims."
    def __init__(self, n_in:int, ks=1, sym=False):
        self.sym,self.n_in = sym,n_in
        # Spectral-normalised Conv1d n_in -> n_in; `padding=ks//2` keeps the length for odd `ks`.
        self.conv = _conv1d_spect(n_in, n_in, ks, padding=ks//2, bias=False)
        # Residual scale; starts at 0 so the layer initially returns its input unchanged.
        self.gamma = nn.Parameter(tensor([0.]))
    def forward(self,x):
        if self.sym:
            # Symmetrise the conv weight. The views assume an (n_in, n_in, 1) weight, i.e. ks=1.
            # NOTE(review): with `spectral_norm`, `weight` appears to be recomputed from `weight_orig`
            # by a forward pre-hook when `self.conv(x)` runs, which would overwrite this — confirm.
            c = self.conv.weight.view(self.n_in,self.n_in)
            c = (c + c.t())/2
            self.conv.weight = c.view(self.n_in,self.n_in,1)
        size = x.size()
        # Keep the first two dims, flatten the rest (dim 1 presumably has n_in channels).
        x = x.view(*size[:2],-1)
        convx = self.conv(x)
        # Batched product of x with its transpose over the flattened positions.
        xxT = torch.bmm(x,x.permute(0,2,1).contiguous())
        o = torch.bmm(xxT, convx)
        o = self.gamma * o + x
        return o.view(*size).contiguous()
# %% ../nbs/01_layers.ipynb 101
def icnr_init(x, scale=2, init=nn.init.kaiming_normal_):
    "ICNR init of `x`, with `scale` and `init` function"
    ni, nf, h, w = x.shape
    n_sub = int(ni / (scale ** 2))
    # Initialise a kernel with ni/scale**2 rows via `init`, repeat it `scale**2` times along the
    # flattened trailing axis, then reshape back to the layout of `x`.
    base = init(x.new_zeros([n_sub, nf, h, w])).transpose(0, 1)
    tiled = base.contiguous().view(n_sub, nf, -1).repeat(1, 1, scale ** 2)
    return tiled.contiguous().view([nf, ni, h, w]).transpose(0, 1)
# %% ../nbs/01_layers.ipynb 104
class PixelShuffle_ICNR(nn.Sequential):
    "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`."
    def __init__(self, ni, nf=None, scale=2, blur=False, norm_type=NormType.Weight, act_cls=defaults.activation):
        # `nn.Sequential.__init__` is called once, below, with the finished layer list (the former
        # extra `super().__init__()` at the top was redundant and immediately reset).
        nf = ifnone(nf, ni)
        # 1x1 conv to nf*scale**2 channels, which PixelShuffle rearranges into a `scale`x larger map.
        layers = [ConvLayer(ni, nf*(scale**2), ks=1, norm_type=norm_type, act_cls=act_cls, bias_std=0),
                  nn.PixelShuffle(scale)]
        conv = layers[0][0]
        if norm_type == NormType.Weight:
            # Weight norm splits the weight into `weight_v` (direction) and `weight_g` (magnitude):
            # ICNR-init the direction, then set each output filter's magnitude to its L2 norm.
            conv.weight_v.data.copy_(icnr_init(conv.weight_v.data))
            conv.weight_g.data.copy_(((conv.weight_v.data**2).sum(dim=[1,2,3])**0.5)[:,None,None,None])
        else:
            conv.weight.data.copy_(icnr_init(conv.weight.data))
        # Optional blur: replicate-pad top/left by one and average with a 2x2 stride-1 window.
        if blur: layers += [nn.ReplicationPad2d((1,0,1,0)), nn.AvgPool2d(2, stride=1)]
        super().__init__(*layers)
# %% ../nbs/01_layers.ipynb 110
def sequential(*args):
    "Create an `nn.Sequential`, wrapping every non-`nn.Module` item in `Lambda`"
    # A single OrderedDict is passed through untouched (nn.Sequential's named-module form).
    if len(args) == 1 and isinstance(args[0], OrderedDict):
        return nn.Sequential(*args)
    items = [a if isinstance(a, nn.Module) else Lambda(a) for a in args]
    return nn.Sequential(*items)
# %% ../nbs/01_layers.ipynb 111
class SequentialEx(Module):
    "Like `nn.Sequential`, but with ModuleList semantics, and can access module input"
    def __init__(self, *layers):
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        out = x
        for layer in self.layers:
            # Expose the block's original input to layers such as `MergeLayer`/`ProdLayer`.
            out.orig = x
            new_out = layer(out)
            # Clear the back-references again to avoid hanging refs and therefore memory leaks.
            out.orig = None
            new_out.orig = None
            out = new_out
        return out

    def __getitem__(self, i):
        return self.layers[i]

    def append(self, l):
        return self.layers.append(l)

    def extend(self, l):
        return self.layers.extend(l)

    def insert(self, i, l):
        return self.layers.insert(i, l)
# %% ../nbs/01_layers.ipynb 113
class MergeLayer(Module):
    "Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`."
    def __init__(self, dense:bool=False):
        self.dense = dense

    def forward(self, x):
        # `x.orig` is the block input attached by `SequentialEx`.
        if self.dense:
            return torch.cat([x, x.orig], dim=1)
        return x + x.orig
# %% ../nbs/01_layers.ipynb 118
class Cat(nn.ModuleList):
    "Apply every layer to the same input and concatenate their outputs along `dim`"
    def __init__(self, layers, dim=1):
        self.dim = dim
        super().__init__(layers)

    def forward(self, x):
        outputs = [layer(x) for layer in self]
        return torch.cat(outputs, dim=self.dim)
# %% ../nbs/01_layers.ipynb 121
class SimpleCNN(nn.Sequential):
    "Stack of `ConvLayer`s through the channel counts in `filters`, ending with `PoolFlatten`."
    def __init__(self, filters, kernel_szs=None, strides=None, bn=True):
        n_layers = len(filters) - 1
        kernel_szs = ifnone(kernel_szs, [3] * n_layers)
        strides = ifnone(strides, [2] * n_layers)
        layers = []
        for i in range(n_layers):
            # BatchNorm on every conv except the last one (and only when `bn`).
            norm = NormType.Batch if bn and i < n_layers - 1 else None
            layers.append(ConvLayer(filters[i], filters[i+1], kernel_szs[i], stride=strides[i], norm_type=norm))
        layers.append(PoolFlatten())
        super().__init__(*layers)
# %% ../nbs/01_layers.ipynb 128
class ProdLayer(Module):
    "Merge a shortcut with the result of the module by multiplying them."
    def forward(self, x):
        # `x.orig` is the block input attached by `SequentialEx`.
        shortcut = x.orig
        return x * shortcut
# %% ../nbs/01_layers.ipynb 129
# `nn.ReLU` constructor preset to `inplace=True`; call it to obtain a module instance.
inplace_relu = partial(nn.ReLU, inplace=True)
# %% ../nbs/01_layers.ipynb 130
def SEModule(ch, reduction, act_cls=defaults.activation):
    "Squeeze-and-excitation: pool to 1x1, 1x1 conv to a bottleneck, 1x1 conv back to `ch` with sigmoid, multiply with the input."
    # Bottleneck width: ch//reduction rounded up to a multiple of 8.
    nf = math.ceil(ch // reduction / 8) * 8
    squeeze = nn.AdaptiveAvgPool2d(1)
    excite = [ConvLayer(ch, nf, ks=1, norm_type=None, act_cls=act_cls),
              ConvLayer(nf, ch, ks=1, norm_type=None, act_cls=nn.Sigmoid)]
    return SequentialEx(squeeze, *excite, ProdLayer())
# %% ../nbs/01_layers.ipynb 131
class ResBlock(Module):
    "Resnet block from `ni*expansion` to `nf*expansion` channels with `stride`"
    @delegates(ConvLayer.__init__)
    def __init__(self, expansion, ni, nf, stride=1, groups=1, reduction=None, nh1=None, nh2=None, dw=False, g2=1,
                 sa=False, sym=False, norm_type=NormType.Batch, act_cls=defaults.activation, ndim=2, ks=3,
                 pool=AvgPool, pool_first=True, **kwargs):
        # The last conv of the main path uses the "Zero" variant of the norm, so its norm weight starts at 0.
        norm2 = (NormType.BatchZero if norm_type==NormType.Batch else
                 NormType.InstanceZero if norm_type==NormType.Instance else norm_type)
        # Hidden widths: nh2 defaults to nf, nh1 (bottleneck only) defaults to nh2.
        if nh2 is None: nh2 = nf
        if nh1 is None: nh1 = nh2
        # Input/output channels are scaled by `expansion`; the hidden widths are not.
        nf,ni = nf*expansion,ni*expansion
        k0 = dict(norm_type=norm_type, act_cls=act_cls, ndim=ndim, **kwargs)
        # Final conv of the path: no activation (it is applied after the residual add instead).
        k1 = dict(norm_type=norm2, act_cls=None, ndim=ndim, **kwargs)
        # expansion==1: two `ks` convs; otherwise a 1x1 -> ks -> 1x1 bottleneck. `dw` sets the
        # strided `ks` conv's groups to its input width (depthwise).
        convpath  = [ConvLayer(ni,  nh2, ks, stride=stride, groups=ni if dw else groups, **k0),
                     ConvLayer(nh2,  nf, ks, groups=g2, **k1)
        ] if expansion == 1 else [
                     ConvLayer(ni,  nh1, 1, **k0),
                     ConvLayer(nh1, nh2, ks, stride=stride, groups=nh1 if dw else groups, **k0),
                     ConvLayer(nh2,  nf, 1, groups=g2, **k1)]
        if reduction: convpath.append(SEModule(nf, reduction=reduction, act_cls=act_cls))
        if sa: convpath.append(SimpleSelfAttention(nf,ks=1,sym=sym))
        self.convpath = nn.Sequential(*convpath)
        idpath = []
        # Identity path: 1x1 conv when channel counts differ. It does not receive `norm_type`, so it
        # uses ConvLayer's default norm — NOTE(review): confirm that is intended for instance norm.
        if ni!=nf: idpath.append(ConvLayer(ni, nf, 1, act_cls=None, ndim=ndim, **kwargs))
        # When downsampling, add `pool` (ceil_mode so its size matches the strided conv); it goes
        # before the 1x1 conv if `pool_first` (index 0), otherwise after it (index 1).
        if stride!=1: idpath.insert((1,0)[pool_first], pool(stride, ndim=ndim, ceil_mode=True))
        self.idpath = nn.Sequential(*idpath)
        # The default activation is built in-place; any other `act_cls` is instantiated with no arguments.
        self.act = defaults.activation(inplace=True) if act_cls is defaults.activation else act_cls()
    def forward(self, x): return self.act(self.convpath(x) + self.idpath(x))
# %% ../nbs/01_layers.ipynb 133
def SEBlock(expansion, ni, nf, groups=1, reduction=16, stride=1, **kwargs):
    "`ResBlock` with squeeze-and-excitation (`reduction`) and hidden widths nh1=2*nf, nh2=nf*expansion."
    hidden1, hidden2 = nf * 2, nf * expansion
    return ResBlock(expansion, ni, nf, stride=stride, groups=groups, reduction=reduction,
                    nh1=hidden1, nh2=hidden2, **kwargs)
# %% ../nbs/01_layers.ipynb 134
def SEResNeXtBlock(expansion, ni, nf, groups=32, reduction=16, stride=1, base_width=4, **kwargs):
    "Grouped-conv `ResBlock` with squeeze-and-excitation; hidden width `floor(nf * base_width/64) * groups`."
    width = math.floor(nf * (base_width / 64)) * groups
    return ResBlock(expansion, ni, nf, stride=stride, groups=groups, reduction=reduction,
                    nh2=width, **kwargs)
# %% ../nbs/01_layers.ipynb 135
def SeparableBlock(expansion, ni, nf, reduction=16, stride=1, base_width=4, **kwargs):
    "Depthwise (`dw=True`) `ResBlock` with squeeze-and-excitation and hidden width 2*nf; `base_width` is accepted but unused."
    hidden = nf * 2
    return ResBlock(expansion, ni, nf, stride=stride, reduction=reduction,
                    nh2=hidden, dw=True, **kwargs)
# %% ../nbs/01_layers.ipynb 138
def _stack_tups(tuples, stack_dim=1):
    "Transpose-and-stack: the i-th output stacks the i-th tensor of every tuple along `stack_dim`"
    positions = range_of(tuples[0])  # number of outputs taken from the first tuple
    return tuple(torch.stack([tup[i] for tup in tuples], dim=stack_dim) for i in positions)
# %% ../nbs/01_layers.ipynb 139
class TimeDistributed(Module):
    "Applies `module` over `tdim` identically for each step, use `low_mem` to compute one at a time."
    def __init__(self, module, low_mem=False, tdim=1):
        store_attr()

    def forward(self, *tensors, **kwargs):
        "input x with shape:(bs,seq_len,channels,width,height)"
        # The fast path folds the time dim into the batch dim, which is only valid for tdim=1.
        if self.low_mem or self.tdim!=1:
            return self.low_mem_forward(*tensors, **kwargs)
        inp_shape = tensors[0].shape
        bs, seq_len = inp_shape[0], inp_shape[1]
        out = self.module(*[x.view(bs*seq_len, *x.shape[2:]) for x in tensors], **kwargs)
        return self.format_output(out, bs, seq_len)

    def low_mem_forward(self, *tensors, **kwargs):
        "input x with shape:(bs,seq_len,channels,width,height)"
        seq_len = tensors[0].shape[self.tdim]
        args_split = [torch.unbind(x, dim=self.tdim) for x in tensors]
        # `kwargs` belong to the wrapped module, as in `forward`. (They were previously handed to
        # `list.append`, which raised TypeError whenever any keyword argument was supplied.)
        out = [self.module(*[args[i] for args in args_split], **kwargs) for i in range(seq_len)]
        if isinstance(out[0], tuple):
            return _stack_tups(out, stack_dim=self.tdim)
        return torch.stack(out, dim=self.tdim)

    def format_output(self, out, bs, seq_len):
        "unstack from batchsize outputs"
        if isinstance(out, tuple):
            return tuple(out_i.view(bs, seq_len, *out_i.shape[1:]) for out_i in out)
        return out.view(bs, seq_len,*out.shape[1:])

    def __repr__(self):
        return f'TimeDistributed({self.module})'
# %% ../nbs/01_layers.ipynb 158
from torch.jit import script
# %% ../nbs/01_layers.ipynb 159
@script
def _swish_jit_fwd(x): return x.mul(torch.sigmoid(x))

@script
def _swish_jit_bwd(x, grad_output):
    # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))

class _SwishJitAutoFn(torch.autograd.Function):
    "Swish (`x * sigmoid(x)`) autograd function: saves only `x` and recomputes the sigmoid in `backward`."
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return _swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        # `saved_tensors` is the supported accessor; `saved_variables` is a deprecated alias.
        x, = ctx.saved_tensors
        return _swish_jit_bwd(x, grad_output)
# %% ../nbs/01_layers.ipynb 160
def swish(x, inplace=False):
    "Swish activation `x * sigmoid(x)`; `inplace` is accepted for API compatibility and ignored."
    result = _SwishJitAutoFn.apply(x)
    return result
# %% ../nbs/01_layers.ipynb 161
class Swish(Module):
    "Module form of the swish activation (`x * sigmoid(x)`)"
    def forward(self, x):
        return _SwishJitAutoFn.apply(x)
# %% ../nbs/01_layers.ipynb 162
@script
def _mish_jit_fwd(x): return x.mul(torch.tanh(F.softplus(x)))

@script
def _mish_jit_bwd(x, grad_output):
    # d/dx [x * tanh(softplus(x))] = tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2), sp = softplus(x)
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))

class MishJitAutoFn(torch.autograd.Function):
    "Mish (`x * tanh(softplus(x))`) autograd function: saves only `x` and recomputes in `backward`."
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return _mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        # `saved_tensors` is the supported accessor; `saved_variables` is a deprecated alias.
        x, = ctx.saved_tensors
        return _mish_jit_bwd(x, grad_output)
# %% ../nbs/01_layers.ipynb 163
def mish(x):
    "Mish activation `x * tanh(softplus(x))`; uses the native `F.mish` when this PyTorch provides it."
    # Detect the feature rather than comparing `torch.__version__` against a string: a string
    # comparison is only correct when `__version__` is a version-aware object, and feature
    # detection does not depend on that.
    return F.mish(x) if hasattr(F, 'mish') else MishJitAutoFn.apply(x)
# %% ../nbs/01_layers.ipynb 164
class Mish(Module):
    "Module form of the Mish activation, backed by `MishJitAutoFn`"
    def forward(self, x):
        return MishJitAutoFn.apply(x)
# %% ../nbs/01_layers.ipynb 165
# When `ismin_torch('1.9')` holds, rebind the public `Mish` name to PyTorch's built-in `nn.Mish`,
# replacing the JIT-based class defined just above.
if ismin_torch('1.9'): Mish = nn.Mish
# %% ../nbs/01_layers.ipynb 166
# Register Kaiming-uniform as the `__default_init__` (read by `init_linear` with init='auto') for the
# swish/mish activations. If `Mish` was rebound to `nn.Mish` above, this sets the attribute on that class.
for o in swish,Swish,mish,Mish: o.__default_init__ = kaiming_uniform_
# %% ../nbs/01_layers.ipynb 169
class ParameterModule(Module):
    "Wrap a lone parameter `p` in a module (stored as `val`) so it can be treated like a child layer."
    def __init__(self, p):
        # When `p` is an `nn.Parameter`, this assignment registers it with the module.
        self.val = p

    def forward(self, x):
        return x  # identity: the module only exists to hold `val`
# %% ../nbs/01_layers.ipynb 170
def children_and_parameters(m):
    "Return the children of `m` and its direct parameters not registered in modules."
    children = list(m.children())
    # ids of every parameter owned (at any depth) by a child. A set gives O(1) membership tests,
    # replacing the quadratic `sum(list_of_lists, [])` plus a linear scan per parameter.
    child_param_ids = {id(p) for c in children for p in c.parameters()}
    for p in m.parameters():
        # Parameters held directly by `m` are wrapped so they appear as leaf "children".
        if id(p) not in child_param_ids: children.append(ParameterModule(p))
    return children
# %% ../nbs/01_layers.ipynb 172
def has_children(m):
    "Return True if module `m` has at least one direct child module."
    # Children are modules, never None, so a None sentinel means the iterator was empty.
    return next(m.children(), None) is not None
# %% ../nbs/01_layers.ipynb 174
def flatten_model(m):
    "Return the list of all leaf submodules of `m`, with loose parameters wrapped as `ParameterModule`"
    if not has_children(m):
        return [m]
    flat = []
    for child in children_and_parameters(m):
        flat.extend(flatten_model(child))
    return flat
# %% ../nbs/01_layers.ipynb 176
class NoneReduce:
    "A context manager to evaluate `loss_func` with none reduce."
    def __init__(self, loss_func):
        self.loss_func, self.old_red = loss_func, None

    def __enter__(self):
        if not hasattr(self.loss_func, 'reduction'):
            # Plain functions: bind `reduction='none'` into every call instead.
            return partial(self.loss_func, reduction='none')
        # Loss objects: switch the attribute temporarily; `__exit__` restores it.
        self.old_red = self.loss_func.reduction
        self.loss_func.reduction = 'none'
        return self.loss_func

    def __exit__(self, type, value, traceback):
        if self.old_red is not None:
            self.loss_func.reduction = self.old_red
# %% ../nbs/01_layers.ipynb 178
def in_channels(m):
    "Return dim 1 (input channels) of the first 4-d `weight` found among the flattened layers of `m`."
    try:
        first = next(l for l in flatten_model(m) if nested_attr(l, 'weight.ndim', -1) == 4)
    except StopIteration as e:
        e.args = ["No weight layer"]
        raise
    return first.weight.shape[1]
|