# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_layers.ipynb.

# %% ../nbs/01_layers.ipynb 2
from __future__ import annotations
from .imports import *
from .torch_imports import *
from .torch_core import *
from torch.nn.utils import weight_norm, spectral_norm

# %% auto 0
__all__ = ['NormType', 'inplace_relu', 'module', 'Identity', 'Lambda', 'PartialLambda', 'Flatten', 'ToTensorBase', 'View',
           'ResizeBatch', 'Debugger', 'sigmoid_range', 'SigmoidRange', 'AdaptiveConcatPool1d', 'AdaptiveConcatPool2d',
           'PoolType', 'adaptive_pool', 'PoolFlatten', 'BatchNorm', 'InstanceNorm', 'BatchNorm1dFlat', 'LinBnDrop',
           'sigmoid', 'sigmoid_', 'vleaky_relu', 'init_default', 'init_linear', 'ConvLayer', 'AdaptiveAvgPool',
           'MaxPool', 'AvgPool', 'trunc_normal_', 'Embedding', 'SelfAttention', 'PooledSelfAttention2d',
           'SimpleSelfAttention', 'icnr_init', 'PixelShuffle_ICNR', 'sequential', 'SequentialEx', 'MergeLayer', 'Cat',
           'SimpleCNN', 'ProdLayer', 'SEModule', 'ResBlock', 'SEBlock', 'SEResNeXtBlock', 'SeparableBlock',
           'TimeDistributed', 'swish', 'Swish', 'MishJitAutoFn', 'mish', 'Mish', 'ParameterModule',
           'children_and_parameters', 'has_children', 'flatten_model', 'NoneReduce', 'in_channels']

# %% ../nbs/01_layers.ipynb 6
def module(*flds, **defaults):
    "Decorator to create an `nn.Module` using `f` as `forward` method"
    pa = [inspect.Parameter(o, inspect.Parameter.POSITIONAL_OR_KEYWORD) for o in flds]
    pb = [inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD, default=v)
          for k,v in defaults.items()]
    params = pa+pb
    all_flds = [*flds,*defaults.keys()]

    def _f(f):
        class c(nn.Module):
            def __init__(self, *args, **kwargs):
                super().__init__()
                for i,o in enumerate(args): kwargs[all_flds[i]] = o
                kwargs = merge(defaults,kwargs)
                for k,v in kwargs.items(): setattr(self,k,v)
            __repr__ = basic_repr(all_flds)
            forward = f
        c.__signature__ = inspect.Signature(params)
        c.__name__ = c.__qualname__ = f.__name__
        c.__doc__  = f.__doc__
        return c
    return _f

# %% ../nbs/01_layers.ipynb 7
@module()
def Identity(self, x):
    "Do nothing at all"
    return x

# %% ../nbs/01_layers.ipynb 9
@module('func')
def Lambda(self, x):
    "An easy way to create a pytorch layer for a simple `func`"
    return self.func(x)

# %% ../nbs/01_layers.ipynb 11
class PartialLambda(Lambda):
    "Layer that applies `partial(func, **kwargs)`"
    def __init__(self, func, **kwargs):
        super().__init__(partial(func, **kwargs))
        self.repr = f'{func.__name__}, {kwargs}'

    def forward(self, x): return self.func(x)
    def __repr__(self): return f'{self.__class__.__name__}({self.repr})'

# %% ../nbs/01_layers.ipynb 13
@module(full=False)
def Flatten(self, x):
    "Flatten `x` to a single dimension, e.g. at end of a model. `full` for rank-1 tensor"
    return x.view(-1) if self.full else x.view(x.size(0), -1)  # Removed cast to TensorBase

# %% ../nbs/01_layers.ipynb 15
@module(tensor_cls=TensorBase)
def ToTensorBase(self, x):
    "Convert `x` to `tensor_cls` (`TensorBase` by default)"
    return self.tensor_cls(x)

# %% ../nbs/01_layers.ipynb 17
class View(Module):
    "Reshape `x` to `size`"
    def __init__(self, *size): self.size = size
    def forward(self, x): return x.view(self.size)

# %% ../nbs/01_layers.ipynb 19
class ResizeBatch(Module):
    "Reshape `x` to `size`, keeping batch dim the same size"
    def __init__(self, *size): self.size = size
    def forward(self, x): return x.view((x.size(0),) + self.size)

# %% ../nbs/01_layers.ipynb 21
@module()
def Debugger(self,x):
    "A module to debug inside a model."
    set_trace()
    return x
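
# A minimal usage sketch of the `module` decorator and the layers above (`Lambda`,
# `Flatten`, `View`). Illustrative only and guarded so it never runs on import;
# the shapes below are arbitrary example values.
if __name__ == '__main__':
    import torch
    x = torch.randn(4, 3, 8, 8)                        # (batch, channels, h, w)
    assert Lambda(lambda o: o * 2)(x).shape == x.shape
    assert Flatten()(x).shape == (4, 3*8*8)            # keeps the batch dimension
    assert Flatten(full=True)(x).shape == (4*3*8*8,)   # `full` flattens to rank 1
    assert View(4, 3, 64)(x).shape == (4, 3, 64)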
# %% ../nbs/01_layers.ipynb 22
def sigmoid_range(x, low, high):
    "Sigmoid function with range `(low, high)`"
    return torch.sigmoid(x) * (high - low) + low

# %% ../nbs/01_layers.ipynb 24
@module('low','high')
def SigmoidRange(self, x):
    "Sigmoid module with range `(low, high)`"
    return sigmoid_range(x, self.low, self.high)

# %% ../nbs/01_layers.ipynb 27
class AdaptiveConcatPool1d(Module):
    "Layer that concats `AdaptiveAvgPool1d` and `AdaptiveMaxPool1d`"
    def __init__(self, size=None):
        self.size = size or 1
        self.ap = nn.AdaptiveAvgPool1d(self.size)
        self.mp = nn.AdaptiveMaxPool1d(self.size)
    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)

# %% ../nbs/01_layers.ipynb 28
class AdaptiveConcatPool2d(Module):
    "Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`"
    def __init__(self, size=None):
        self.size = size or 1
        self.ap = nn.AdaptiveAvgPool2d(self.size)
        self.mp = nn.AdaptiveMaxPool2d(self.size)
    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)

# %% ../nbs/01_layers.ipynb 31
class PoolType: Avg,Max,Cat = 'Avg','Max','Cat'

# %% ../nbs/01_layers.ipynb 32
def adaptive_pool(pool_type):
    return nn.AdaptiveAvgPool2d if pool_type=='Avg' else nn.AdaptiveMaxPool2d if pool_type=='Max' else AdaptiveConcatPool2d

# %% ../nbs/01_layers.ipynb 33
class PoolFlatten(nn.Sequential):
    "Combine `nn.AdaptiveAvgPool2d` and `Flatten`."
    def __init__(self, pool_type=PoolType.Avg): super().__init__(adaptive_pool(pool_type)(1), Flatten())

# %% ../nbs/01_layers.ipynb 36
NormType = Enum('NormType', 'Batch BatchZero Weight Spectral Instance InstanceZero')

# %% ../nbs/01_layers.ipynb 37
def _get_norm(prefix, nf, ndim=2, zero=False, **kwargs):
    "Norm layer with `nf` features and `ndim` initialized depending on `norm_type`."
    assert 1 <= ndim <= 3
    bn = getattr(nn, f"{prefix}{ndim}d")(nf, **kwargs)
    if bn.affine:
        bn.bias.data.fill_(1e-3)
        bn.weight.data.fill_(0. if zero else 1.)
    return bn

# %% ../nbs/01_layers.ipynb 38
@delegates(nn.BatchNorm2d)
def BatchNorm(nf, ndim=2, norm_type=NormType.Batch, **kwargs):
    "BatchNorm layer with `nf` features and `ndim` initialized depending on `norm_type`."
    return _get_norm('BatchNorm', nf, ndim, zero=norm_type==NormType.BatchZero, **kwargs)

# %% ../nbs/01_layers.ipynb 39
@delegates(nn.InstanceNorm2d)
def InstanceNorm(nf, ndim=2, norm_type=NormType.Instance, affine=True, **kwargs):
    "InstanceNorm layer with `nf` features and `ndim` initialized depending on `norm_type`."
    return _get_norm('InstanceNorm', nf, ndim, zero=norm_type==NormType.InstanceZero, affine=affine, **kwargs)
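
# A quick sketch of `sigmoid_range` and of the pooling/norm constructors above.
# Illustrative only and guarded so it never runs on import; the sizes are arbitrary.
if __name__ == '__main__':
    import torch
    t = torch.randn(2, 8, 4, 4)
    assert AdaptiveConcatPool2d()(t).shape == (2, 16, 1, 1)   # avg+max concat doubles the channels
    y = sigmoid_range(torch.zeros(3), low=-1, high=1)
    assert torch.allclose(y, torch.zeros(3))                  # sigmoid(0)=0.5 maps to the middle of (low,high)
    bn = BatchNorm(8, norm_type=NormType.BatchZero)
    assert (bn.weight == 0).all()                             # BatchZero initializes the scale to 0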
# %% ../nbs/01_layers.ipynb 45
class BatchNorm1dFlat(nn.BatchNorm1d):
    "`nn.BatchNorm1d`, but first flattens leading dimensions"
    def forward(self, x):
        if x.dim()==2: return super().forward(x)
        *f,l = x.shape
        x = x.contiguous().view(-1,l)
        return super().forward(x).view(*f,l)

# %% ../nbs/01_layers.ipynb 47
class LinBnDrop(nn.Sequential):
    "Module grouping `BatchNorm1d`, `Dropout` and `Linear` layers"
    def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=False):
        layers = [BatchNorm(n_out if lin_first else n_in, ndim=1)] if bn else []
        if p != 0: layers.append(nn.Dropout(p))
        lin = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None: lin.append(act)
        layers = lin+layers if lin_first else layers+lin
        super().__init__(*layers)

# %% ../nbs/01_layers.ipynb 51
def sigmoid(input, eps=1e-7):
    "Same as `torch.sigmoid`, plus clamping to `(eps,1-eps)`"
    return input.sigmoid().clamp(eps,1-eps)

# %% ../nbs/01_layers.ipynb 52
def sigmoid_(input, eps=1e-7):
    "Same as `torch.sigmoid_`, plus clamping to `(eps,1-eps)`"
    return input.sigmoid_().clamp_(eps,1-eps)

# %% ../nbs/01_layers.ipynb 53
from torch.nn.init import kaiming_uniform_,uniform_,xavier_uniform_,normal_

# %% ../nbs/01_layers.ipynb 54
def vleaky_relu(input, inplace=True):
    "`F.leaky_relu` with 0.3 slope"
    return F.leaky_relu(input, negative_slope=0.3, inplace=inplace)

# %% ../nbs/01_layers.ipynb 55
for o in F.relu,nn.ReLU,F.relu6,nn.ReLU6,F.leaky_relu,nn.LeakyReLU:
    o.__default_init__ = kaiming_uniform_

# %% ../nbs/01_layers.ipynb 56
for o in F.sigmoid,nn.Sigmoid,F.tanh,nn.Tanh,sigmoid,sigmoid_:
    o.__default_init__ = xavier_uniform_

# %% ../nbs/01_layers.ipynb 57
def init_default(m, func=nn.init.kaiming_normal_):
    "Initialize `m` weights with `func` and set `bias` to 0."
    if func and hasattr(m, 'weight'): func(m.weight)
    with torch.no_grad():
        nested_callable(m, 'bias.fill_')(0.)
    return m

# %% ../nbs/01_layers.ipynb 58
def init_linear(m, act_func=None, init='auto', bias_std=0.01):
    if getattr(m,'bias',None) is not None and bias_std is not None:
        if bias_std != 0: normal_(m.bias, 0, bias_std)
        else: m.bias.data.zero_()
    if init=='auto':
        if act_func in (F.relu_,F.leaky_relu_): init = kaiming_uniform_
        else: init = nested_callable(act_func, '__class__.__default_init__')
        if init == noop: init = getcallable(act_func, '__default_init__')
    if callable(init): init(m.weight)

# %% ../nbs/01_layers.ipynb 60
def _conv_func(ndim=2, transpose=False):
    "Return the proper conv `ndim` function, potentially `transposed`."
    assert 1 <= ndim <=3
    return getattr(nn, f'Conv{"Transpose" if transpose else ""}{ndim}d')

# %% ../nbs/01_layers.ipynb 62
defaults.activation=nn.ReLU
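
# A small sketch of `LinBnDrop` and `init_default`. Illustrative only and guarded so it
# never runs on import; the layer sizes are arbitrary example values.
if __name__ == '__main__':
    import torch
    head = LinBnDrop(512, 10, bn=True, p=0.25, act=nn.ReLU())
    # default order is BatchNorm1d -> Dropout -> Linear (-> act); use lin_first=True to put Linear first
    assert isinstance(head[0], nn.BatchNorm1d) and isinstance(head[-2], nn.Linear)
    lin = init_default(nn.Linear(4, 2), nn.init.kaiming_normal_)
    assert (lin.bias == 0).all()                              # init_default zeroes the bias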
# %% ../nbs/01_layers.ipynb 63
class ConvLayer(nn.Sequential):
    "Create a sequence of convolutional (`ni` to `nf`), activation (if `act_cls`) and `norm_type` layers."
    @delegates(nn.Conv2d)
    def __init__(self, ni, nf, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=NormType.Batch, bn_1st=True,
                 act_cls=defaults.activation, transpose=False, init='auto', xtra=None, bias_std=0.01, **kwargs):
        if padding is None: padding = ((ks-1)//2 if not transpose else 0)
        bn = norm_type in (NormType.Batch, NormType.BatchZero)
        inn = norm_type in (NormType.Instance, NormType.InstanceZero)
        if bias is None: bias = not (bn or inn)
        conv_func = _conv_func(ndim, transpose=transpose)
        conv = conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding, **kwargs)
        act = None if act_cls is None else act_cls()
        init_linear(conv, act, init=init, bias_std=bias_std)
        if   norm_type==NormType.Weight:   conv = weight_norm(conv)
        elif norm_type==NormType.Spectral: conv = spectral_norm(conv)
        layers = [conv]
        act_bn = []
        if act is not None: act_bn.append(act)
        if bn: act_bn.append(BatchNorm(nf, norm_type=norm_type, ndim=ndim))
        if inn: act_bn.append(InstanceNorm(nf, norm_type=norm_type, ndim=ndim))
        if bn_1st: act_bn.reverse()
        layers += act_bn
        if xtra: layers.append(xtra)
        super().__init__(*layers)

# %% ../nbs/01_layers.ipynb 77
def AdaptiveAvgPool(sz=1, ndim=2):
    "nn.AdaptiveAvgPool layer for `ndim`"
    assert 1 <= ndim <= 3
    return getattr(nn, f"AdaptiveAvgPool{ndim}d")(sz)

# %% ../nbs/01_layers.ipynb 78
def MaxPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    "nn.MaxPool layer for `ndim`"
    assert 1 <= ndim <= 3
    return getattr(nn, f"MaxPool{ndim}d")(ks, stride=stride, padding=padding)

# %% ../nbs/01_layers.ipynb 79
def AvgPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):
    "nn.AvgPool layer for `ndim`"
    assert 1 <= ndim <= 3
    return getattr(nn, f"AvgPool{ndim}d")(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)

# %% ../nbs/01_layers.ipynb 81
def trunc_normal_(x, mean=0., std=1.):
    "Truncated normal initialization (approximation)"
    # From https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/12
    return x.normal_().fmod_(2).mul_(std).add_(mean)

# %% ../nbs/01_layers.ipynb 82
class Embedding(nn.Embedding):
    "Embedding layer with truncated normal initialization"
    def __init__(self, ni, nf, std=0.01):
        super().__init__(ni, nf)
        trunc_normal_(self.weight.data, std=std)

# %% ../nbs/01_layers.ipynb 86
class SelfAttention(Module):
    "Self attention layer for `n_channels`."
    def __init__(self, n_channels):
        self.query,self.key,self.value = [self._conv(n_channels, c) for c in (n_channels//8,n_channels//8,n_channels)]
        self.gamma = nn.Parameter(tensor([0.]))

    def _conv(self,n_in,n_out):
        return ConvLayer(n_in, n_out, ks=1, ndim=1, norm_type=NormType.Spectral, act_cls=None, bias=False)

    def forward(self, x):
        # Notation from the paper.
        size = x.size()
        x = x.view(*size[:2],-1)
        f,g,h = self.query(x),self.key(x),self.value(x)
        beta = F.softmax(torch.bmm(f.transpose(1,2), g), dim=1)
        o = self.gamma * torch.bmm(h, beta) + x
        return o.view(*size).contiguous()
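
# A minimal sketch of `ConvLayer` and `Embedding`. Illustrative only and guarded so it
# never runs on import; the channel counts are arbitrary.
if __name__ == '__main__':
    import torch
    layer = ConvLayer(3, 16, ks=3, stride=2)             # Conv2d -> BatchNorm -> ReLU (bn_1st=True by default)
    assert layer(torch.randn(2, 3, 32, 32)).shape == (2, 16, 16, 16)
    emb = Embedding(100, 8)                              # truncated-normal init with std=0.01
    assert emb.weight.std().item() < 0.1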
# %% ../nbs/01_layers.ipynb 95
class PooledSelfAttention2d(Module):
    "Pooled self attention layer for 2d."
    def __init__(self, n_channels):
        self.n_channels = n_channels
        self.query,self.key,self.value = [self._conv(n_channels, c) for c in (n_channels//8,n_channels//8,n_channels//2)]
        self.out   = self._conv(n_channels//2, n_channels)
        self.gamma = nn.Parameter(tensor([0.]))

    def _conv(self,n_in,n_out):
        return ConvLayer(n_in, n_out, ks=1, norm_type=NormType.Spectral, act_cls=None, bias=False)

    def forward(self, x):
        n_ftrs = x.shape[2]*x.shape[3]
        f = self.query(x).view(-1, self.n_channels//8, n_ftrs)
        g = F.max_pool2d(self.key(x),   [2,2]).view(-1, self.n_channels//8, n_ftrs//4)
        h = F.max_pool2d(self.value(x), [2,2]).view(-1, self.n_channels//2, n_ftrs//4)
        beta = F.softmax(torch.bmm(f.transpose(1, 2), g), -1)
        o = self.out(torch.bmm(h, beta.transpose(1,2)).view(-1, self.n_channels//2, x.shape[2], x.shape[3]))
        return self.gamma * o + x

# %% ../nbs/01_layers.ipynb 97
def _conv1d_spect(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False):
    "Create and initialize a `nn.Conv1d` layer with spectral normalization."
    conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(conv.weight)
    if bias: conv.bias.data.zero_()
    return spectral_norm(conv)

# %% ../nbs/01_layers.ipynb 98
class SimpleSelfAttention(Module):
    def __init__(self, n_in:int, ks=1, sym=False):
        self.sym,self.n_in = sym,n_in
        self.conv = _conv1d_spect(n_in, n_in, ks, padding=ks//2, bias=False)
        self.gamma = nn.Parameter(tensor([0.]))

    def forward(self,x):
        if self.sym:
            c = self.conv.weight.view(self.n_in,self.n_in)
            c = (c + c.t())/2
            self.conv.weight = c.view(self.n_in,self.n_in,1)

        size = x.size()
        x = x.view(*size[:2],-1)

        convx = self.conv(x)
        xxT = torch.bmm(x,x.permute(0,2,1).contiguous())
        o = torch.bmm(xxT, convx)
        o = self.gamma * o + x
        return o.view(*size).contiguous()

# %% ../nbs/01_layers.ipynb 101
def icnr_init(x, scale=2, init=nn.init.kaiming_normal_):
    "ICNR init of `x`, with `scale` and `init` function"
    ni,nf,h,w = x.shape
    ni2 = int(ni/(scale**2))
    k = init(x.new_zeros([ni2,nf,h,w])).transpose(0, 1)
    k = k.contiguous().view(ni2, nf, -1)
    k = k.repeat(1, 1, scale**2)
    return k.contiguous().view([nf,ni,h,w]).transpose(0, 1)
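
# A shape-checking sketch for `SimpleSelfAttention` and `icnr_init`. Illustrative only
# and guarded so it never runs on import.
if __name__ == '__main__':
    import torch
    sa = SimpleSelfAttention(16)
    x = torch.randn(2, 16, 8, 8)
    assert sa(x).shape == x.shape                       # attention output keeps the input shape
    w = icnr_init(torch.zeros(64, 16, 3, 3), scale=2)   # 64 = 16 * scale**2 output filters, ready for PixelShuffle
    assert w.shape == (64, 16, 3, 3)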
# %% ../nbs/01_layers.ipynb 104
class PixelShuffle_ICNR(nn.Sequential):
    "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`."
    def __init__(self, ni, nf=None, scale=2, blur=False, norm_type=NormType.Weight, act_cls=defaults.activation):
        super().__init__()
        nf = ifnone(nf, ni)
        layers = [ConvLayer(ni, nf*(scale**2), ks=1, norm_type=norm_type, act_cls=act_cls, bias_std=0),
                  nn.PixelShuffle(scale)]
        if norm_type == NormType.Weight:
            layers[0][0].weight_v.data.copy_(icnr_init(layers[0][0].weight_v.data))
            layers[0][0].weight_g.data.copy_(((layers[0][0].weight_v.data**2).sum(dim=[1,2,3])**0.5)[:,None,None,None])
        else:
            layers[0][0].weight.data.copy_(icnr_init(layers[0][0].weight.data))
        if blur: layers += [nn.ReplicationPad2d((1,0,1,0)), nn.AvgPool2d(2, stride=1)]
        super().__init__(*layers)

# %% ../nbs/01_layers.ipynb 110
def sequential(*args):
    "Create an `nn.Sequential`, wrapping items with `Lambda` if needed"
    if len(args) != 1 or not isinstance(args[0], OrderedDict):
        args = list(args)
        for i,o in enumerate(args):
            if not isinstance(o,nn.Module): args[i] = Lambda(o)
    return nn.Sequential(*args)

# %% ../nbs/01_layers.ipynb 111
class SequentialEx(Module):
    "Like `nn.Sequential`, but with ModuleList semantics, and can access module input"
    def __init__(self, *layers): self.layers = nn.ModuleList(layers)

    def forward(self, x):
        res = x
        for l in self.layers:
            res.orig = x
            nres = l(res)
            # We have to remove res.orig to avoid hanging refs and therefore memory leaks
            res.orig, nres.orig = None, None
            res = nres
        return res

    def __getitem__(self,i): return self.layers[i]
    def append(self,l):      return self.layers.append(l)
    def extend(self,l):      return self.layers.extend(l)
    def insert(self,i,l):    return self.layers.insert(i,l)

# %% ../nbs/01_layers.ipynb 113
class MergeLayer(Module):
    "Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`."
    def __init__(self, dense:bool=False): self.dense=dense
    def forward(self, x): return torch.cat([x,x.orig], dim=1) if self.dense else (x+x.orig)

# %% ../nbs/01_layers.ipynb 118
class Cat(nn.ModuleList):
    "Concatenate layers outputs over a given dim"
    def __init__(self, layers, dim=1):
        self.dim=dim
        super().__init__(layers)
    def forward(self, x): return torch.cat([l(x) for l in self], dim=self.dim)

# %% ../nbs/01_layers.ipynb 121
class SimpleCNN(nn.Sequential):
    "Create a simple CNN with `filters`."
    def __init__(self, filters, kernel_szs=None, strides=None, bn=True):
        nl = len(filters)-1
        kernel_szs = ifnone(kernel_szs, [3]*nl)
        strides    = ifnone(strides   , [2]*nl)
        layers = [ConvLayer(filters[i], filters[i+1], kernel_szs[i], stride=strides[i],
                  norm_type=(NormType.Batch if bn and i<nl-1 else None)) for i in range(nl)]
        layers.append(PoolFlatten())
        super().__init__(*layers)

def mish(x): return F.mish(x) if ismin_torch('1.9') else MishJitAutoFn.apply(x)

# %% ../nbs/01_layers.ipynb 164
class Mish(Module):
    def forward(self, x): return MishJitAutoFn.apply(x)

# %% ../nbs/01_layers.ipynb 165
if ismin_torch('1.9'): Mish = nn.Mish

# %% ../nbs/01_layers.ipynb 166
for o in swish,Swish,mish,Mish: o.__default_init__ = kaiming_uniform_

# %% ../nbs/01_layers.ipynb 169
class ParameterModule(Module):
    "Register a lone parameter `p` in a module."
    def __init__(self, p): self.val = p
    def forward(self, x): return x
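
# A brief sketch of `SequentialEx` with `MergeLayer` (a residual shortcut) and of
# `SimpleCNN`. Illustrative only and guarded so it never runs on import; the input is
# wrapped in `TensorBase` because `MergeLayer` reads the `.orig` attribute that
# `SequentialEx` attaches to its input.
if __name__ == '__main__':
    import torch
    block = SequentialEx(ConvLayer(8, 8), ConvLayer(8, 8), MergeLayer())  # adds the block input back in
    x = TensorBase(torch.randn(2, 8, 16, 16))
    assert block(x).shape == x.shape
    cnn = SimpleCNN((3, 16, 32))                 # two stride-2 ConvLayers followed by PoolFlatten
    assert cnn(torch.randn(2, 3, 32, 32)).shape == (2, 32)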
# %% ../nbs/01_layers.ipynb 170
def children_and_parameters(m):
    "Return the children of `m` and its direct parameters not registered in modules."
    children = list(m.children())
    children_p = sum([[id(p) for p in c.parameters()] for c in m.children()],[])
    for p in m.parameters():
        if id(p) not in children_p: children.append(ParameterModule(p))
    return children

# %% ../nbs/01_layers.ipynb 172
def has_children(m):
    try: next(m.children())
    except StopIteration: return False
    return True

# %% ../nbs/01_layers.ipynb 174
def flatten_model(m):
    "Return the list of all submodules and parameters of `m`"
    return sum(map(flatten_model,children_and_parameters(m)),[]) if has_children(m) else [m]

# %% ../nbs/01_layers.ipynb 176
class NoneReduce():
    "A context manager to evaluate `loss_func` with none reduce."
    def __init__(self, loss_func): self.loss_func,self.old_red = loss_func,None

    def __enter__(self):
        if hasattr(self.loss_func, 'reduction'):
            self.old_red = self.loss_func.reduction
            self.loss_func.reduction = 'none'
            return self.loss_func
        else: return partial(self.loss_func, reduction='none')

    def __exit__(self, type, value, traceback):
        if self.old_red is not None: self.loss_func.reduction = self.old_red

# %% ../nbs/01_layers.ipynb 178
def in_channels(m):
    "Return the number of input channels of the first weight layer in `m`."
    try: return next(l.weight.shape[1] for l in flatten_model(m) if nested_attr(l,'weight.ndim',-1)==4)
    except StopIteration as e: e.args = ["No weight layer"]; raise
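
# A short sketch of `NoneReduce`, `flatten_model` and `in_channels`. Illustrative only
# and guarded so it never runs on import.
if __name__ == '__main__':
    import torch
    model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.ReLU(), nn.BatchNorm2d(4))
    assert in_channels(model) == 3                      # first 4-dim weight is the Conv2d kernel
    assert len(flatten_model(model)) >= 3               # leaves: conv, relu, bn (plus any lone parameters)
    loss_func = nn.CrossEntropyLoss()
    with NoneReduce(loss_func) as lf:
        losses = lf(torch.randn(8, 5), torch.randint(0, 5, (8,)))
    assert losses.shape == (8,)                         # per-item losses while inside the context
    assert loss_func.reduction == 'mean'                # original reduction restored on exit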