# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_layers.ipynb.
# %% ../nbs/01_layers.ipynb 2
from __future__ import annotations
from .imports import *
from .torch_imports import *
from .torch_core import *
from torch.nn.utils import weight_norm, spectral_norm
# %% auto 0
__all__ = ['NormType', 'inplace_relu', 'module', 'Identity', 'Lambda', 'PartialLambda', 'Flatten', 'ToTensorBase', 'View',
           'ResizeBatch', 'Debugger', 'sigmoid_range', 'SigmoidRange', 'AdaptiveConcatPool1d', 'AdaptiveConcatPool2d',
           'PoolType', 'adaptive_pool', 'PoolFlatten', 'BatchNorm', 'InstanceNorm', 'BatchNorm1dFlat', 'LinBnDrop',
           'sigmoid', 'sigmoid_', 'vleaky_relu', 'init_default', 'init_linear', 'ConvLayer', 'AdaptiveAvgPool',
           'MaxPool', 'AvgPool', 'trunc_normal_', 'Embedding', 'SelfAttention', 'PooledSelfAttention2d',
           'SimpleSelfAttention', 'icnr_init', 'PixelShuffle_ICNR', 'sequential', 'SequentialEx', 'MergeLayer', 'Cat',
           'SimpleCNN', 'ProdLayer', 'SEModule', 'ResBlock', 'SEBlock', 'SEResNeXtBlock', 'SeparableBlock',
           'TimeDistributed', 'swish', 'Swish', 'MishJitAutoFn', 'mish', 'Mish', 'ParameterModule',
           'children_and_parameters', 'has_children', 'flatten_model', 'NoneReduce', 'in_channels']
# %% ../nbs/01_layers.ipynb 6
def module(*flds, **defaults):
    "Decorator to create an `nn.Module` using `f` as `forward` method"
    pa = [inspect.Parameter(o, inspect.Parameter.POSITIONAL_OR_KEYWORD) for o in flds]
    pb = [inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD, default=v)
          for k,v in defaults.items()]
    params = pa+pb
    all_flds = [*flds,*defaults.keys()]
    def _f(f):
        class c(nn.Module):
            def __init__(self, *args, **kwargs):
                super().__init__()
                for i,o in enumerate(args): kwargs[all_flds[i]] = o
                kwargs = merge(defaults,kwargs)
                for k,v in kwargs.items(): setattr(self,k,v)
            __repr__ = basic_repr(all_flds)
            forward = f
        c.__signature__ = inspect.Signature(params)
        c.__name__ = c.__qualname__ = f.__name__
        c.__doc__ = f.__doc__
        return c
    return _f
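# Usage sketch (illustrative comment only; `ScaleShift` is a hypothetical name, not exported):
# `@module` builds an `nn.Module` subclass whose `forward` is the decorated function and whose
# stored fields come from the decorator arguments:
#
#   @module('scale', bias=0.)
#   def ScaleShift(self, x): return x * self.scale + self.bias
#
#   layer = ScaleShift(2., bias=1.)   # layer(x) == x*2. + 1.; repr shows scale=2.0, bias=1.0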
# %% ../nbs/01_layers.ipynb 7
@module()
def Identity(self, x):
    "Do nothing at all"
    return x
# %% ../nbs/01_layers.ipynb 9
@module('func')
def Lambda(self, x):
    "An easy way to create a pytorch layer for a simple `func`"
    return self.func(x)
# %% ../nbs/01_layers.ipynb 11
class PartialLambda(Lambda):
    "Layer that applies `partial(func, **kwargs)`"
    def __init__(self, func, **kwargs):
        super().__init__(partial(func, **kwargs))
        self.repr = f'{func.__name__}, {kwargs}'
    def forward(self, x): return self.func(x)
    def __repr__(self): return f'{self.__class__.__name__}({self.repr})'
# %% ../nbs/01_layers.ipynb 13
@module(full=False)
def Flatten(self, x):
    "Flatten `x` to a single dimension, e.g. at end of a model. `full` for rank-1 tensor"
    return x.view(-1) if self.full else x.view(x.size(0), -1)  # removed cast to TensorBase
# %% ../nbs/01_layers.ipynb 15
@module(tensor_cls=TensorBase)
def ToTensorBase(self, x):
    "Convert `x` to `TensorBase` class"
    return self.tensor_cls(x)
# %% ../nbs/01_layers.ipynb 17
class View(Module):
    "Reshape `x` to `size`"
    def __init__(self, *size): self.size = size
    def forward(self, x): return x.view(self.size)
# %% ../nbs/01_layers.ipynb 19
class ResizeBatch(Module):
    "Reshape `x` to `size`, keeping batch dim the same size"
    def __init__(self, *size): self.size = size
    def forward(self, x): return x.view((x.size(0),) + self.size)
# %% ../nbs/01_layers.ipynb 21
@module()
def Debugger(self,x):
    "A module to debug inside a model."
    set_trace()
    return x
# %% ../nbs/01_layers.ipynb 22
def sigmoid_range(x, low, high):
    "Sigmoid function with range `(low, high)`"
    return torch.sigmoid(x) * (high - low) + low
# %% ../nbs/01_layers.ipynb 24
@module('low','high')
def SigmoidRange(self, x):
    "Sigmoid module with range `(low, high)`"
    return sigmoid_range(x, self.low, self.high)
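# Usage sketch (illustrative comment, hypothetical values): `SigmoidRange` is the usual head
# for bounded regression, e.g. squashing logits into a rating range of (0, 5.5):
#
#   head = SigmoidRange(0, 5.5)
#   preds = head(torch.randn(16, 1))   # every value lies strictly in (0, 5.5)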
# %% ../nbs/01_layers.ipynb 27
class AdaptiveConcatPool1d(Module):
    "Layer that concats `AdaptiveAvgPool1d` and `AdaptiveMaxPool1d`"
    def __init__(self, size=None):
        self.size = size or 1
        self.ap = nn.AdaptiveAvgPool1d(self.size)
        self.mp = nn.AdaptiveMaxPool1d(self.size)
    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)
# %% ../nbs/01_layers.ipynb 28
class AdaptiveConcatPool2d(Module):
    "Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`"
    def __init__(self, size=None):
        self.size = size or 1
        self.ap = nn.AdaptiveAvgPool2d(self.size)
        self.mp = nn.AdaptiveMaxPool2d(self.size)
    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)
# %% ../nbs/01_layers.ipynb 31
class PoolType: Avg,Max,Cat = 'Avg','Max','Cat'
# %% ../nbs/01_layers.ipynb 32
def adaptive_pool(pool_type):
    return nn.AdaptiveAvgPool2d if pool_type=='Avg' else nn.AdaptiveMaxPool2d if pool_type=='Max' else AdaptiveConcatPool2d
# %% ../nbs/01_layers.ipynb 33
class PoolFlatten(nn.Sequential):
    "Combine `nn.AdaptiveAvgPool2d` and `Flatten`."
    def __init__(self, pool_type=PoolType.Avg): super().__init__(adaptive_pool(pool_type)(1), Flatten())
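# Shape sketch (illustrative comment): `AdaptiveConcatPool2d` doubles the channel dimension
# because it concatenates max- and average-pooled features, while `PoolFlatten` collapses
# the spatial dims entirely:
#
#   x = torch.randn(8, 32, 7, 7)
#   AdaptiveConcatPool2d()(x).shape   # torch.Size([8, 64, 1, 1])
#   PoolFlatten()(x).shape            # torch.Size([8, 32])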
# %% ../nbs/01_layers.ipynb 36
NormType = Enum('NormType', 'Batch BatchZero Weight Spectral Instance InstanceZero')
# %% ../nbs/01_layers.ipynb 37
def _get_norm(prefix, nf, ndim=2, zero=False, **kwargs):
    "Norm layer with `nf` features and `ndim` initialized depending on `norm_type`."
    assert 1 <= ndim <= 3
    bn = getattr(nn, f"{prefix}{ndim}d")(nf, **kwargs)
    if bn.affine:
        bn.bias.data.fill_(1e-3)
        bn.weight.data.fill_(0. if zero else 1.)
    return bn
# %% ../nbs/01_layers.ipynb 38
def BatchNorm(nf, ndim=2, norm_type=NormType.Batch, **kwargs):
    "BatchNorm layer with `nf` features and `ndim` initialized depending on `norm_type`."
    return _get_norm('BatchNorm', nf, ndim, zero=norm_type==NormType.BatchZero, **kwargs)
# %% ../nbs/01_layers.ipynb 39
def InstanceNorm(nf, ndim=2, norm_type=NormType.Instance, affine=True, **kwargs):
    "InstanceNorm layer with `nf` features and `ndim` initialized depending on `norm_type`."
    return _get_norm('InstanceNorm', nf, ndim, zero=norm_type==NormType.InstanceZero, affine=affine, **kwargs)
# %% ../nbs/01_layers.ipynb 45
class BatchNorm1dFlat(nn.BatchNorm1d):
    "`nn.BatchNorm1d`, but first flattens leading dimensions"
    def forward(self, x):
        if x.dim()==2: return super().forward(x)
        *f,l = x.shape
        x = x.contiguous().view(-1,l)
        return super().forward(x).view(*f,l)
# %% ../nbs/01_layers.ipynb 47
class LinBnDrop(nn.Sequential):
    "Module grouping `BatchNorm1d`, `Dropout` and `Linear` layers"
    def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=False):
        layers = [BatchNorm(n_out if lin_first else n_in, ndim=1)] if bn else []
        if p != 0: layers.append(nn.Dropout(p))
        lin = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None: lin.append(act)
        layers = lin+layers if lin_first else layers+lin
        super().__init__(*layers)
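# Ordering sketch (illustrative comment): with the defaults the block runs
# BatchNorm1d -> Dropout -> Linear, and the Linear drops its bias because the preceding
# BatchNorm already provides one; `lin_first=True` flips the order and the BatchNorm is
# then sized on `n_out`:
#
#   LinBnDrop(512, 256, p=0.5, act=nn.ReLU())
#   # BatchNorm1d(512), Dropout(0.5), Linear(512, 256, bias=False), ReLU()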
# %% ../nbs/01_layers.ipynb 51
def sigmoid(input, eps=1e-7):
    "Same as `torch.sigmoid`, plus clamping to `(eps,1-eps)`"
    return input.sigmoid().clamp(eps,1-eps)
# %% ../nbs/01_layers.ipynb 52
def sigmoid_(input, eps=1e-7):
    "Same as `torch.sigmoid_`, plus clamping to `(eps,1-eps)`"
    return input.sigmoid_().clamp_(eps,1-eps)
# %% ../nbs/01_layers.ipynb 53
from torch.nn.init import kaiming_uniform_,uniform_,xavier_uniform_,normal_
# %% ../nbs/01_layers.ipynb 54
def vleaky_relu(input, inplace=True):
    "`F.leaky_relu` with 0.3 slope"
    return F.leaky_relu(input, negative_slope=0.3, inplace=inplace)
# %% ../nbs/01_layers.ipynb 55
for o in F.relu,nn.ReLU,F.relu6,nn.ReLU6,F.leaky_relu,nn.LeakyReLU:
    o.__default_init__ = kaiming_uniform_
# %% ../nbs/01_layers.ipynb 56
for o in F.sigmoid,nn.Sigmoid,F.tanh,nn.Tanh,sigmoid,sigmoid_:
    o.__default_init__ = xavier_uniform_
# %% ../nbs/01_layers.ipynb 57
def init_default(m, func=nn.init.kaiming_normal_):
    "Initialize `m` weights with `func` and set `bias` to 0."
    if func and hasattr(m, 'weight'): func(m.weight)
    with torch.no_grad(): nested_callable(m, 'bias.fill_')(0.)
    return m
# %% ../nbs/01_layers.ipynb 58
def init_linear(m, act_func=None, init='auto', bias_std=0.01):
    if getattr(m,'bias',None) is not None and bias_std is not None:
        if bias_std != 0: normal_(m.bias, 0, bias_std)
        else: m.bias.data.zero_()
    if init=='auto':
        if act_func in (F.relu_,F.leaky_relu_): init = kaiming_uniform_
        else: init = nested_callable(act_func, '__class__.__default_init__')
        if init == noop: init = getcallable(act_func, '__default_init__')
    if callable(init): init(m.weight)
# %% ../nbs/01_layers.ipynb 60
def _conv_func(ndim=2, transpose=False):
    "Return the proper conv `ndim` function, potentially `transposed`."
    assert 1 <= ndim <=3
    return getattr(nn, f'Conv{"Transpose" if transpose else ""}{ndim}d')
# %% ../nbs/01_layers.ipynb 62
defaults.activation=nn.ReLU
# %% ../nbs/01_layers.ipynb 63
class ConvLayer(nn.Sequential):
    "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and `norm_type` layers."
    def __init__(self, ni, nf, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=NormType.Batch, bn_1st=True,
                 act_cls=defaults.activation, transpose=False, init='auto', xtra=None, bias_std=0.01, **kwargs):
        if padding is None: padding = ((ks-1)//2 if not transpose else 0)
        bn = norm_type in (NormType.Batch, NormType.BatchZero)
        inn = norm_type in (NormType.Instance, NormType.InstanceZero)
        if bias is None: bias = not (bn or inn)
        conv_func = _conv_func(ndim, transpose=transpose)
        conv = conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding, **kwargs)
        act = None if act_cls is None else act_cls()
        init_linear(conv, act, init=init, bias_std=bias_std)
        if norm_type==NormType.Weight: conv = weight_norm(conv)
        elif norm_type==NormType.Spectral: conv = spectral_norm(conv)
        layers = [conv]
        act_bn = []
        if act is not None: act_bn.append(act)
        if bn: act_bn.append(BatchNorm(nf, norm_type=norm_type, ndim=ndim))
        if inn: act_bn.append(InstanceNorm(nf, norm_type=norm_type, ndim=ndim))
        if bn_1st: act_bn.reverse()
        layers += act_bn
        if xtra: layers.append(xtra)
        super().__init__(*layers)
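# Structure sketch (illustrative comment): with the defaults (`bn_1st=True`) a ConvLayer is
# Conv2d -> BatchNorm2d -> ReLU, and the conv bias is dropped because batch norm supplies one:
#
#   ConvLayer(3, 64, ks=3, stride=2)
#   # Conv2d(3, 64, 3, stride=2, padding=1, bias=False), BatchNorm2d(64), ReLU()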
# %% ../nbs/01_layers.ipynb 77
def AdaptiveAvgPool(sz=1, ndim=2):
    "nn.AdaptiveAvgPool layer for `ndim`"
    assert 1 <= ndim <= 3
    return getattr(nn, f"AdaptiveAvgPool{ndim}d")(sz)
# %% ../nbs/01_layers.ipynb 78
def MaxPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    "nn.MaxPool layer for `ndim`"
    assert 1 <= ndim <= 3
    return getattr(nn, f"MaxPool{ndim}d")(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)
# %% ../nbs/01_layers.ipynb 79
def AvgPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    "nn.AvgPool layer for `ndim`"
    assert 1 <= ndim <= 3
    return getattr(nn, f"AvgPool{ndim}d")(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)
# %% ../nbs/01_layers.ipynb 81
def trunc_normal_(x, mean=0., std=1.):
    "Truncated normal initialization (approximation)"
    # From https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/12
    return x.normal_().fmod_(2).mul_(std).add_(mean)
# %% ../nbs/01_layers.ipynb 82
class Embedding(nn.Embedding):
    "Embedding layer with truncated normal initialization"
    def __init__(self, ni, nf, std=0.01):
        super().__init__(ni, nf)
        trunc_normal_(self.weight.data, std=std)
# %% ../nbs/01_layers.ipynb 86
class SelfAttention(Module):
    "Self attention layer for `n_channels`."
    def __init__(self, n_channels):
        self.query,self.key,self.value = [self._conv(n_channels, c) for c in (n_channels//8,n_channels//8,n_channels)]
        self.gamma = nn.Parameter(tensor([0.]))
    def _conv(self,n_in,n_out):
        return ConvLayer(n_in, n_out, ks=1, ndim=1, norm_type=NormType.Spectral, act_cls=None, bias=False)
    def forward(self, x):
        #Notation from the paper.
        size = x.size()
        x = x.view(*size[:2],-1)
        f,g,h = self.query(x),self.key(x),self.value(x)
        beta = F.softmax(torch.bmm(f.transpose(1,2), g), dim=1)
        o = self.gamma * torch.bmm(h, beta) + x
        return o.view(*size).contiguous()
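# Shape sketch (illustrative comment): SelfAttention preserves the input shape, and since
# `gamma` starts at 0 the layer is the identity at initialization:
#
#   sa = SelfAttention(64)
#   sa(torch.randn(2, 64, 16, 16)).shape   # torch.Size([2, 64, 16, 16])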
# %% ../nbs/01_layers.ipynb 95
class PooledSelfAttention2d(Module):
    "Pooled self attention layer for 2d."
    def __init__(self, n_channels):
        self.n_channels = n_channels
        self.query,self.key,self.value = [self._conv(n_channels, c) for c in (n_channels//8,n_channels//8,n_channels//2)]
        self.out = self._conv(n_channels//2, n_channels)
        self.gamma = nn.Parameter(tensor([0.]))
    def _conv(self,n_in,n_out):
        return ConvLayer(n_in, n_out, ks=1, norm_type=NormType.Spectral, act_cls=None, bias=False)
    def forward(self, x):
        n_ftrs = x.shape[2]*x.shape[3]
        f = self.query(x).view(-1, self.n_channels//8, n_ftrs)
        g = F.max_pool2d(self.key(x), [2,2]).view(-1, self.n_channels//8, n_ftrs//4)
        h = F.max_pool2d(self.value(x), [2,2]).view(-1, self.n_channels//2, n_ftrs//4)
        beta = F.softmax(torch.bmm(f.transpose(1, 2), g), -1)
        o = self.out(torch.bmm(h, beta.transpose(1,2)).view(-1, self.n_channels//2, x.shape[2], x.shape[3]))
        return self.gamma * o + x
# %% ../nbs/01_layers.ipynb 97
def _conv1d_spect(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False):
    "Create and initialize a `nn.Conv1d` layer with spectral normalization."
    conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(conv.weight)
    if bias: conv.bias.data.zero_()
    return spectral_norm(conv)
# %% ../nbs/01_layers.ipynb 98
class SimpleSelfAttention(Module):
    def __init__(self, n_in:int, ks=1, sym=False):
        self.sym,self.n_in = sym,n_in
        self.conv = _conv1d_spect(n_in, n_in, ks, padding=ks//2, bias=False)
        self.gamma = nn.Parameter(tensor([0.]))
    def forward(self,x):
        if self.sym:
            c = self.conv.weight.view(self.n_in,self.n_in)
            c = (c + c.t())/2
            self.conv.weight = c.view(self.n_in,self.n_in,1)
        size = x.size()
        x = x.view(*size[:2],-1)
        convx = self.conv(x)
        xxT = torch.bmm(x,x.permute(0,2,1).contiguous())
        o = torch.bmm(xxT, convx)
        o = self.gamma * o + x
        return o.view(*size).contiguous()
# %% ../nbs/01_layers.ipynb 101
def icnr_init(x, scale=2, init=nn.init.kaiming_normal_):
    "ICNR init of `x`, with `scale` and `init` function"
    ni,nf,h,w = x.shape
    ni2 = int(ni/(scale**2))
    k = init(x.new_zeros([ni2,nf,h,w])).transpose(0, 1)
    k = k.contiguous().view(ni2, nf, -1)
    k = k.repeat(1, 1, scale**2)
    return k.contiguous().view([nf,ni,h,w]).transpose(0, 1)
# %% ../nbs/01_layers.ipynb 104
class PixelShuffle_ICNR(nn.Sequential):
    "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`."
    def __init__(self, ni, nf=None, scale=2, blur=False, norm_type=NormType.Weight, act_cls=defaults.activation):
        super().__init__()
        nf = ifnone(nf, ni)
        layers = [ConvLayer(ni, nf*(scale**2), ks=1, norm_type=norm_type, act_cls=act_cls, bias_std=0),
                  nn.PixelShuffle(scale)]
        if norm_type == NormType.Weight:
            layers[0][0].weight_v.data.copy_(icnr_init(layers[0][0].weight_v.data))
            layers[0][0].weight_g.data.copy_(((layers[0][0].weight_v.data**2).sum(dim=[1,2,3])**0.5)[:,None,None,None])
        else:
            layers[0][0].weight.data.copy_(icnr_init(layers[0][0].weight.data))
        if blur: layers += [nn.ReplicationPad2d((1,0,1,0)), nn.AvgPool2d(2, stride=1)]
        super().__init__(*layers)
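# Shape sketch (illustrative comment): PixelShuffle_ICNR trades channels for spatial
# resolution; the ICNR init makes the sub-pixel conv behave like a nearest-neighbour
# upsample at initialization, which reduces checkerboard artifacts:
#
#   up = PixelShuffle_ICNR(64, 32, scale=2)
#   up(torch.randn(1, 64, 8, 8)).shape   # torch.Size([1, 32, 16, 16])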
# %% ../nbs/01_layers.ipynb 110
def sequential(*args):
    "Create an `nn.Sequential`, wrapping items with `Lambda` if needed"
    if len(args) != 1 or not isinstance(args[0], OrderedDict):
        args = list(args)
        for i,o in enumerate(args):
            if not isinstance(o,nn.Module): args[i] = Lambda(o)
    return nn.Sequential(*args)
# %% ../nbs/01_layers.ipynb 111
class SequentialEx(Module):
    "Like `nn.Sequential`, but with ModuleList semantics, and can access module input"
    def __init__(self, *layers): self.layers = nn.ModuleList(layers)
    def forward(self, x):
        res = x
        for l in self.layers:
            res.orig = x
            nres = l(res)
            # We have to remove res.orig to avoid hanging refs and therefore memory leaks
            res.orig, nres.orig = None, None
            res = nres
        return res
    def __getitem__(self,i): return self.layers[i]
    def append(self,l): return self.layers.append(l)
    def extend(self,l): return self.layers.extend(l)
    def insert(self,i,l): return self.layers.insert(i,l)
# %% ../nbs/01_layers.ipynb 113
class MergeLayer(Module):
    "Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`."
    def __init__(self, dense:bool=False): self.dense=dense
    def forward(self, x): return torch.cat([x,x.orig], dim=1) if self.dense else (x+x.orig)
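# Usage sketch (illustrative comment): SequentialEx stores the block input on `x.orig`, which
# is what MergeLayer/ProdLayer read, so a residual branch can be written as:
#
#   block = SequentialEx(ConvLayer(32, 32), ConvLayer(32, 32), MergeLayer())
#   block(torch.randn(4, 32, 8, 8)).shape   # torch.Size([4, 32, 8, 8]), i.e. out + input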
# %% ../nbs/01_layers.ipynb 118
class Cat(nn.ModuleList):
    "Concatenate layers outputs over a given dim"
    def __init__(self, layers, dim=1):
        self.dim=dim
        super().__init__(layers)
    def forward(self, x): return torch.cat([l(x) for l in self], dim=self.dim)
# %% ../nbs/01_layers.ipynb 121
class SimpleCNN(nn.Sequential):
    "Create a simple CNN with `filters`."
    def __init__(self, filters, kernel_szs=None, strides=None, bn=True):
        nl = len(filters)-1
        kernel_szs = ifnone(kernel_szs, [3]*nl)
        strides    = ifnone(strides   , [2]*nl)
        layers = [ConvLayer(filters[i], filters[i+1], kernel_szs[i], stride=strides[i],
                  norm_type=(NormType.Batch if bn and i<nl-1 else None)) for i in range(nl)]
        layers.append(PoolFlatten())
        super().__init__(*layers)
# %% ../nbs/01_layers.ipynb 128
class ProdLayer(Module):
    "Merge a shortcut with the result of the module by multiplying them."
    def forward(self, x): return x * x.orig
# %% ../nbs/01_layers.ipynb 129
inplace_relu = partial(nn.ReLU, inplace=True)
# %% ../nbs/01_layers.ipynb 130
def SEModule(ch, reduction, act_cls=defaults.activation):
    nf = math.ceil(ch//reduction/8)*8
    return SequentialEx(nn.AdaptiveAvgPool2d(1),
                        ConvLayer(ch, nf, ks=1, norm_type=None, act_cls=act_cls),
                        ConvLayer(nf, ch, ks=1, norm_type=None, act_cls=nn.Sigmoid),
                        ProdLayer())
# %% ../nbs/01_layers.ipynb 131
class ResBlock(Module):
    "Resnet block from `ni` to `nh` with `stride`"
    def __init__(self, expansion, ni, nf, stride=1, groups=1, reduction=None, nh1=None, nh2=None, dw=False, g2=1,
                 sa=False, sym=False, norm_type=NormType.Batch, act_cls=defaults.activation, ndim=2, ks=3,
                 pool=AvgPool, pool_first=True, **kwargs):
        norm2 = (NormType.BatchZero if norm_type==NormType.Batch else
                 NormType.InstanceZero if norm_type==NormType.Instance else norm_type)
        if nh2 is None: nh2 = nf
        if nh1 is None: nh1 = nh2
        nf,ni = nf*expansion,ni*expansion
        k0 = dict(norm_type=norm_type, act_cls=act_cls, ndim=ndim, **kwargs)
        k1 = dict(norm_type=norm2, act_cls=None, ndim=ndim, **kwargs)
        convpath = [ConvLayer(ni, nh2, ks, stride=stride, groups=ni if dw else groups, **k0),
                    ConvLayer(nh2, nf, ks, groups=g2, **k1)
        ] if expansion == 1 else [
                    ConvLayer(ni, nh1, 1, **k0),
                    ConvLayer(nh1, nh2, ks, stride=stride, groups=nh1 if dw else groups, **k0),
                    ConvLayer(nh2, nf, 1, groups=g2, **k1)]
        if reduction: convpath.append(SEModule(nf, reduction=reduction, act_cls=act_cls))
        if sa: convpath.append(SimpleSelfAttention(nf,ks=1,sym=sym))
        self.convpath = nn.Sequential(*convpath)
        idpath = []
        if ni!=nf: idpath.append(ConvLayer(ni, nf, 1, act_cls=None, ndim=ndim, **kwargs))
        if stride!=1: idpath.insert((1,0)[pool_first], pool(stride, ndim=ndim, ceil_mode=True))
        self.idpath = nn.Sequential(*idpath)
        self.act = defaults.activation(inplace=True) if act_cls is defaults.activation else act_cls()
    def forward(self, x): return self.act(self.convpath(x) + self.idpath(x))
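# Usage sketch (illustrative comment, hypothetical channel sizes): expansion=1 gives the basic
# two-conv block, expansion=4 the bottleneck variant; `ni`/`nf` are multiplied by `expansion`
# internally, so ResBlock(4, 16, 16) maps 64 -> 64 channels. For example:
#
#   ResBlock(1, 32, 32, stride=2)(torch.randn(2, 32, 16, 16)).shape   # torch.Size([2, 32, 8, 8])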
# %% ../nbs/01_layers.ipynb 133
def SEBlock(expansion, ni, nf, groups=1, reduction=16, stride=1, **kwargs):
    return ResBlock(expansion, ni, nf, stride=stride, groups=groups, reduction=reduction, nh1=nf*2, nh2=nf*expansion, **kwargs)
# %% ../nbs/01_layers.ipynb 134
def SEResNeXtBlock(expansion, ni, nf, groups=32, reduction=16, stride=1, base_width=4, **kwargs):
    w = math.floor(nf * (base_width / 64)) * groups
    return ResBlock(expansion, ni, nf, stride=stride, groups=groups, reduction=reduction, nh2=w, **kwargs)
# %% ../nbs/01_layers.ipynb 135
def SeparableBlock(expansion, ni, nf, reduction=16, stride=1, base_width=4, **kwargs):
    return ResBlock(expansion, ni, nf, stride=stride, reduction=reduction, nh2=nf*2, dw=True, **kwargs)
# %% ../nbs/01_layers.ipynb 138
def _stack_tups(tuples, stack_dim=1):
    "Stack tuple of tensors along `stack_dim`"
    return tuple(torch.stack([t[i] for t in tuples], dim=stack_dim) for i in range_of(tuples[0]))
# %% ../nbs/01_layers.ipynb 139
class TimeDistributed(Module):
    "Applies `module` over `tdim` identically for each step, use `low_mem` to compute one at a time."
    def __init__(self, module, low_mem=False, tdim=1):
        store_attr()
    def forward(self, *tensors, **kwargs):
        "input x with shape:(bs,seq_len,channels,width,height)"
        if self.low_mem or self.tdim!=1:
            return self.low_mem_forward(*tensors, **kwargs)
        else:
            #only support tdim=1
            inp_shape = tensors[0].shape
            bs, seq_len = inp_shape[0], inp_shape[1]
            out = self.module(*[x.view(bs*seq_len, *x.shape[2:]) for x in tensors], **kwargs)
        return self.format_output(out, bs, seq_len)
    def low_mem_forward(self, *tensors, **kwargs):
        "input x with shape:(bs,seq_len,channels,width,height)"
        seq_len = tensors[0].shape[self.tdim]
        args_split = [torch.unbind(x, dim=self.tdim) for x in tensors]
        out = []
        for i in range(seq_len):
            # kwargs must be passed to the wrapped module, not to `list.append`
            out.append(self.module(*[args[i] for args in args_split], **kwargs))
        if isinstance(out[0], tuple):
            return _stack_tups(out, stack_dim=self.tdim)
        return torch.stack(out, dim=self.tdim)
    def format_output(self, out, bs, seq_len):
        "unstack from batchsize outputs"
        if isinstance(out, tuple):
            return tuple(out_i.view(bs, seq_len, *out_i.shape[1:]) for out_i in out)
        return out.view(bs, seq_len,*out.shape[1:])
    def __repr__(self):
        return f'TimeDistributed({self.module})'
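# Usage sketch (illustrative comment): wrap any module that expects (bs, ...) so it runs over a
# (bs, seq_len, ...) batch; `low_mem=True` loops over the time steps instead of folding them
# into the batch dimension:
#
#   td = TimeDistributed(nn.Linear(10, 4))
#   td(torch.randn(2, 5, 10)).shape   # torch.Size([2, 5, 4])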
# %% ../nbs/01_layers.ipynb 158
from torch.jit import script
# %% ../nbs/01_layers.ipynb 159
@script
def _swish_jit_fwd(x): return x.mul(torch.sigmoid(x))
@script
def _swish_jit_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
class _SwishJitAutoFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return _swish_jit_fwd(x)
    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return _swish_jit_bwd(x, grad_output)
# %% ../nbs/01_layers.ipynb 160
def swish(x, inplace=False): return _SwishJitAutoFn.apply(x)
# %% ../nbs/01_layers.ipynb 161
class Swish(Module):
    def forward(self, x): return _SwishJitAutoFn.apply(x)
# %% ../nbs/01_layers.ipynb 162
@script
def _mish_jit_fwd(x): return x.mul(torch.tanh(F.softplus(x)))
@script
def _mish_jit_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
class MishJitAutoFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return _mish_jit_fwd(x)
    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return _mish_jit_bwd(x, grad_output)
# %% ../nbs/01_layers.ipynb 163
def mish(x): return F.mish(x) if ismin_torch('1.9') else MishJitAutoFn.apply(x)
# %% ../nbs/01_layers.ipynb 164
class Mish(Module):
    def forward(self, x): return MishJitAutoFn.apply(x)
# %% ../nbs/01_layers.ipynb 165
if ismin_torch('1.9'): Mish = nn.Mish
# %% ../nbs/01_layers.ipynb 166
for o in swish,Swish,mish,Mish: o.__default_init__ = kaiming_uniform_
# %% ../nbs/01_layers.ipynb 169
class ParameterModule(Module):
    "Register a lone parameter `p` in a module."
    def __init__(self, p): self.val = p
    def forward(self, x): return x
# %% ../nbs/01_layers.ipynb 170
def children_and_parameters(m):
    "Return the children of `m` and its direct parameters not registered in modules."
    children = list(m.children())
    children_p = sum([[id(p) for p in c.parameters()] for c in m.children()],[])
    for p in m.parameters():
        if id(p) not in children_p: children.append(ParameterModule(p))
    return children
# %% ../nbs/01_layers.ipynb 172
def has_children(m):
    try: next(m.children())
    except StopIteration: return False
    return True
# %% ../nbs/01_layers.ipynb 174
def flatten_model(m):
    "Return the list of all submodules and parameters of `m`"
    return sum(map(flatten_model,children_and_parameters(m)),[]) if has_children(m) else [m]
# %% ../nbs/01_layers.ipynb 176
class NoneReduce():
    "A context manager to evaluate `loss_func` with none reduce."
    def __init__(self, loss_func): self.loss_func,self.old_red = loss_func,None
    def __enter__(self):
        if hasattr(self.loss_func, 'reduction'):
            self.old_red = self.loss_func.reduction
            self.loss_func.reduction = 'none'
            return self.loss_func
        else: return partial(self.loss_func, reduction='none')
    def __exit__(self, type, value, traceback):
        if self.old_red is not None: self.loss_func.reduction = self.old_red
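# Usage sketch (illustrative comment): NoneReduce temporarily switches a loss to per-item
# ('none') reduction, e.g. to weight or inspect individual losses, and restores the original
# reduction on exit:
#
#   loss_func = nn.CrossEntropyLoss()
#   with NoneReduce(loss_func) as lf:
#       per_item = lf(torch.randn(8, 10), torch.randint(0, 10, (8,)))   # shape (8,)
#   loss_func.reduction                                                 # back to 'mean'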
# %% ../nbs/01_layers.ipynb 178
def in_channels(m):
    "Return the shape of the first weight layer in `m`."
    try: return next(l.weight.shape[1] for l in flatten_model(m) if nested_attr(l,'weight.ndim',-1)==4)
    except StopIteration as e: e.args = ["No weight layer"]; raise
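# Usage sketch (illustrative comment): in_channels inspects the first 4-d weight it finds
# (i.e. the first conv), which is how the number of input channels a body expects can be inferred:
#
#   in_channels(nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16)))   # 3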