FredZhang7
commited on
Commit
·
e81feeb
1
Parent(s):
55a56ca
Delete modeling_customefficientnetv2.py
Browse files
modeling_customefficientnetv2.py
DELETED
@@ -1,27 +0,0 @@
|
|
1 |
-
from transformers import PretrainedModel
|
2 |
-
from configuration_customefficientnetv2 import CustomEfficientNetV2Config
|
3 |
-
import torch
|
4 |
-
|
5 |
-
class CustomEfficientNetV2(PretrainedModel):
|
6 |
-
config_class = CustomEfficientNetV2Config
|
7 |
-
|
8 |
-
def __init__(self, config):
|
9 |
-
super().__init__(config)
|
10 |
-
|
11 |
-
self.url = config.url
|
12 |
-
file_name = self.url.split('/')[-1]
|
13 |
-
self.model = torch.load(file_name)
|
14 |
-
|
15 |
-
self.input_size = config.input_size
|
16 |
-
shape = [2] + self.input_size
|
17 |
-
example_inputs = torch.randn(shape)
|
18 |
-
example_inputs = (example_inputs - example_inputs.min()) / (example_inputs.max() - example_inputs.min())
|
19 |
-
|
20 |
-
self.num_classes = config.num_classes
|
21 |
-
if self.num_classes != 1000:
|
22 |
-
self.model.classifier = torch.nn.Linear(in_features=1984, out_features=self.num_classes, bias=True)
|
23 |
-
|
24 |
-
traced_model = torch.jit.trace(self.model, example_inputs)
|
25 |
-
traced_model.save(file_name)
|
26 |
-
|
27 |
-
self.model = torch.jit.load(file_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|