|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torch import nn |
|
|
|
|
|
class InitWeights_He(object):
    """Callable that applies Kaiming-normal initialization to conv layers.

    Instances accept a single module, so they can be handed to
    ``nn.Module.apply`` (e.g. ``model.apply(InitWeights_He())``).

    Only ``nn.Conv2d``, ``nn.Conv3d``, ``nn.ConvTranspose2d`` and
    ``nn.ConvTranspose3d`` modules are touched; any other module is
    left unchanged.

    Args:
        neg_slope: Value forwarded as ``a`` to ``nn.init.kaiming_normal_``.
    """

    def __init__(self, neg_slope=1e-2):
        self.neg_slope = neg_slope

    def __call__(self, module):
        """Initialize ``module`` in place if it is a supported conv layer.

        The weight is drawn with ``nn.init.kaiming_normal_`` and the bias,
        when the layer has one, is set to zero. Returns ``None``.
        """
        conv_types = (
            nn.Conv3d,
            nn.Conv2d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
        )
        if not isinstance(module, conv_types):
            return

        # The initializer works in place and returns the tensor it was
        # given; the result is assigned back to the attribute.
        module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)

        if module.bias is None:
            return
        module.bias = nn.init.constant_(module.bias, 0)
|
|
|
|
|
class InitWeights_XavierUniform(object):
    """Callable that applies Xavier-uniform initialization to conv layers.

    Instances accept a single module, so they can be handed to
    ``nn.Module.apply`` (e.g. ``model.apply(InitWeights_XavierUniform())``).

    Only ``nn.Conv2d``, ``nn.Conv3d``, ``nn.ConvTranspose2d`` and
    ``nn.ConvTranspose3d`` modules are touched; any other module is
    left unchanged.

    Args:
        gain: Scaling factor forwarded to ``nn.init.xavier_uniform_``.
    """

    def __init__(self, gain=1):
        self.gain = gain

    def __call__(self, module):
        """Initialize ``module`` in place if it is a supported conv layer.

        The weight is drawn with ``nn.init.xavier_uniform_`` and the bias,
        when the layer has one, is set to zero. Returns ``None``.
        """
        conv_types = (
            nn.Conv3d,
            nn.Conv2d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
        )
        if not isinstance(module, conv_types):
            return

        # The initializer works in place and returns the tensor it was
        # given; the result is assigned back to the attribute.
        module.weight = nn.init.xavier_uniform_(module.weight, self.gain)

        if module.bias is None:
            return
        module.bias = nn.init.constant_(module.bias, 0)
|
|