|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Callable, List, Optional, Tuple, Type, Union

import torch
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
# NOTE(review): torchvision.models.utils is deprecated/removed in newer
# torchvision releases (torch.hub.load_state_dict_from_url is the maintained
# location) — confirm the pinned torchvision version.
from torchvision.models.utils import load_state_dict_from_url
|
|
|
|
|
# Public names exported by ``from <this module> import *``.
# NOTE(review): of these, only ``ResNet`` (imported), ``resnet50`` and
# ``resnext50_32x4d`` are visible in this module's source; the remaining
# entries make a star-import raise AttributeError unless they are defined
# elsewhere in the file — confirm.
__all__ = [
    "ResNet",
    # Plain ResNets.
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    # ResNeXt variants.
    "resnext50_32x4d", "resnext101_32x8d",
    # Wide ResNets.
    "wide_resnet50_2", "wide_resnet101_2",
]
|
|
|
|
|
# Checkpoint download URL for each architecture name; ``_resnet`` looks up
# ``model_urls[arch]`` when pretrained weights are requested.
model_urls = dict(
    resnet18="https://download.pytorch.org/models/resnet18-5c106cde.pth",
    resnet34="https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    resnet50="https://download.pytorch.org/models/resnet50-19c8e357.pth",
    resnet101="https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    resnet152="https://download.pytorch.org/models/resnet152-b121ed2d.pth",
    resnext50_32x4d="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
    resnext101_32x8d="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
    wide_resnet50_2="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
    wide_resnet101_2="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
)
|
|
|
|
|
class ResNet_mine(ResNet):
    """torchvision ``ResNet`` whose forward pass also returns the layer4 feature map.

    Args:
        block: Residual block class, forwarded to ``ResNet.__init__``.
        layers: Number of blocks per stage, forwarded to ``ResNet.__init__``.
        classifier_run: If True (default), the pooled features are passed
            through ``self.fc``. If False, ``self.fc`` is skipped and the
            flattened pooled features are returned in its place.
        **kwargs: Remaining keyword arguments, forwarded to ``ResNet.__init__``.
    """

    def __init__(self, block, layers, classifier_run: bool = True, **kwargs):
        super().__init__(block, layers, **kwargs)
        self.classifier_run = classifier_run

    def _forward_impl(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the backbone and return ``(head_output, layer4_features)``.

        ``head_output`` is the average-pooled ``layer4`` output flattened from
        dim 1 onward, passed through ``self.fc`` only when
        ``self.classifier_run`` is True. ``layer4_features`` is the unpooled
        output of ``self.layer4``.
        """
        # Stem.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Residual stages; the last stage's spatial output is returned as-is.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        feat_map = self.layer4(x)

        out = self.avgpool(feat_map)
        # Keep dim 0 (presumably the batch) and flatten everything else.
        out = torch.flatten(out, 1)
        if self.classifier_run:
            out = self.fc(out)

        return out, feat_map

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Delegate to ``_forward_impl`` and return its ``(head_output, layer4_features)`` pair."""
        return self._forward_impl(x)
|
|
|
|
|
def pnorm(weights, p):
    """Divide ``weights`` by the p-th power of its L2 norm taken over dim 1.

    For a 2-D ``weights`` of shape ``(N, D)``, row ``i`` is divided by
    ``||weights[i]||_2 ** p``. So ``p == 1`` yields unit-norm rows and
    ``p == 0`` returns an unchanged copy. For higher-rank inputs the norm is
    reduced over dim 1 only and broadcast back across that dim.

    Args:
        weights: Tensor with at least two dimensions.
        p: Exponent applied to the norms.

    Returns:
        A new tensor with the same shape as ``weights``; the input is not
        modified. Rows whose norm is zero produce inf/nan when ``p > 0``.
    """
    norms = torch.norm(weights, 2, 1)
    # Re-insert the reduced dim so each slice along dim 0 is divided by its
    # own norm in one broadcasted op, instead of one tensor op per row.
    return weights / torch.pow(norms, p).unsqueeze(1)
|
|
|
|
|
def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> ResNet:
    """Build a ``ResNet_mine`` and optionally load a downloaded checkpoint into it.

    Args:
        arch: Key into ``model_urls`` selecting the checkpoint URL; only read
            when ``pretrained`` is True.
        block: Residual block class passed to ``ResNet_mine``.
        layers: Per-stage block counts passed to ``ResNet_mine``.
        pretrained: If True, download the checkpoint for ``arch`` and load it.
        progress: Forwarded to ``load_state_dict_from_url`` to toggle its
            download progress bar.
        **kwargs: Extra keyword arguments for ``ResNet_mine``
            (e.g. ``classifier_run``).

    Returns:
        The constructed model.

    Raises:
        KeyError: If ``pretrained`` is True and ``arch`` is not in ``model_urls``.
    """
    net = ResNet_mine(block, layers, **kwargs)
    if not pretrained:
        return net

    print("Inside resnet function, using ImageNet pretrained from model url!")
    checkpoint = load_state_dict_from_url(model_urls[arch], progress=progress)
    # Default strict loading: parameter names/shapes must match the checkpoint.
    net.load_state_dict(checkpoint)
    return net
|
|
|
|
|
def resnext50_32x4d(
    pretrained: bool = False, progress: bool = True, **kwargs: Any
) -> ResNet:
    r"""Build a ResNeXt-50 32x4d model.

    Architecture from `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    Args:
        pretrained (bool): If True, load ImageNet-pretrained weights.
        progress (bool): If True, show a download progress bar on stderr.
        **kwargs: Forwarded to the model constructor. Any ``groups`` or
            ``width_per_group`` given here is overridden with 32 and 4.
    """
    # 32 groups x 4 channels per group; always overrides caller-supplied values.
    kwargs.update(groups=32, width_per_group=4)
    stage_depths = [3, 4, 6, 3]
    return _resnet(
        "resnext50_32x4d", Bottleneck, stage_depths, pretrained, progress, **kwargs
    )
|
|
|
|
|
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Build a ResNet-50 model.

    Architecture from `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, load ImageNet-pretrained weights.
        progress (bool): If True, show a download progress bar on stderr.
        **kwargs: Forwarded to the model constructor (e.g. ``classifier_run``).
    """
    stage_depths = [3, 4, 6, 3]  # Bottleneck blocks per stage
    return _resnet("resnet50", Bottleneck, stage_depths, pretrained, progress, **kwargs)
|
|