FredZhang7 commited on
Commit
e81feeb
·
1 Parent(s): 55a56ca

Delete modeling_customefficientnetv2.py

Browse files
Files changed (1) hide show
  1. modeling_customefficientnetv2.py +0 -27
modeling_customefficientnetv2.py DELETED
@@ -1,27 +0,0 @@
1
- from transformers import PretrainedModel
2
- from configuration_customefficientnetv2 import CustomEfficientNetV2Config
3
- import torch
4
-
5
- class CustomEfficientNetV2(PretrainedModel):
6
- config_class = CustomEfficientNetV2Config
7
-
8
- def __init__(self, config):
9
- super().__init__(config)
10
-
11
- self.url = config.url
12
- file_name = self.url.split('/')[-1]
13
- self.model = torch.load(file_name)
14
-
15
- self.input_size = config.input_size
16
- shape = [2] + self.input_size
17
- example_inputs = torch.randn(shape)
18
- example_inputs = (example_inputs - example_inputs.min()) / (example_inputs.max() - example_inputs.min())
19
-
20
- self.num_classes = config.num_classes
21
- if self.num_classes != 1000:
22
- self.model.classifier = torch.nn.Linear(in_features=1984, out_features=self.num_classes, bias=True)
23
-
24
- traced_model = torch.jit.trace(self.model, example_inputs)
25
- traced_model.save(file_name)
26
-
27
- self.model = torch.jit.load(file_name)