admin commited on
Commit
fabd5cb
·
1 Parent(s): 9c6afa6
Files changed (2) hide show
  1. app.py +3 -3
  2. model.py +1 -7
app.py CHANGED
@@ -87,14 +87,14 @@ if __name__ == "__main__":
87
  gr.Interface(
88
  fn=infer,
89
  inputs=[
90
- gr.Image(label="上传图片 Upload an image", type="filepath"),
91
  gr.Dropdown(
92
- label="选择权重 Select a model",
93
  choices=models,
94
  value=models[0],
95
  ),
96
  ],
97
- outputs=gr.Textbox(label="识别结果 Recognition result", show_copy_button=True),
98
  examples=samples,
99
  flagging_mode="never",
100
  cache_examples=False,
 
87
  gr.Interface(
88
  fn=infer,
89
  inputs=[
90
+ gr.Image(label="Upload an image", type="filepath"),
91
  gr.Dropdown(
92
+ label="Select a model",
93
  choices=models,
94
  value=models[0],
95
  ),
96
  ],
97
+ outputs=gr.Textbox(label="Recognition result", show_copy_button=True),
98
  examples=samples,
99
  flagging_mode="never",
100
  cache_examples=False,
model.py CHANGED
@@ -7,7 +7,6 @@ import torch.nn as nn
7
 
8
  class Model(torch.jit.ScriptModule):
9
  CHECKPOINT_FILENAME_PATTERN = "model-{}.pth"
10
-
11
  __constants__ = [
12
  "_hidden1",
13
  "_hidden2",
@@ -31,7 +30,6 @@ class Model(torch.jit.ScriptModule):
31
 
32
  def __init__(self):
33
  super(Model, self).__init__()
34
-
35
  self._hidden1 = nn.Sequential(
36
  nn.Conv2d(in_channels=3, out_channels=48, kernel_size=5, padding=2),
37
  nn.BatchNorm2d(num_features=48),
@@ -90,7 +88,6 @@ class Model(torch.jit.ScriptModule):
90
  )
91
  self._hidden9 = nn.Sequential(nn.Linear(192 * 7 * 7, 3072), nn.ReLU())
92
  self._hidden10 = nn.Sequential(nn.Linear(3072, 3072), nn.ReLU())
93
-
94
  self._digit_length = nn.Sequential(nn.Linear(3072, 7))
95
  self._digit1 = nn.Sequential(nn.Linear(3072, 11))
96
  self._digit2 = nn.Sequential(nn.Linear(3072, 11))
@@ -111,14 +108,12 @@ class Model(torch.jit.ScriptModule):
111
  x = x.view(x.size(0), 192 * 7 * 7)
112
  x = self._hidden9(x)
113
  x = self._hidden10(x)
114
-
115
  length_logits = self._digit_length(x)
116
  digit1_logits = self._digit1(x)
117
  digit2_logits = self._digit2(x)
118
  digit3_logits = self._digit3(x)
119
  digit4_logits = self._digit4(x)
120
  digit5_logits = self._digit5(x)
121
-
122
  return (
123
  length_logits,
124
  digit1_logits,
@@ -154,5 +149,4 @@ class Model(torch.jit.ScriptModule):
154
  self.load_state_dict(
155
  torch.load(path_to_checkpoint_file, map_location=torch.device("cpu"))
156
  )
157
- step = int(path_to_checkpoint_file.split("model-")[-1][:-4])
158
- return step
 
7
 
8
  class Model(torch.jit.ScriptModule):
9
  CHECKPOINT_FILENAME_PATTERN = "model-{}.pth"
 
10
  __constants__ = [
11
  "_hidden1",
12
  "_hidden2",
 
30
 
31
  def __init__(self):
32
  super(Model, self).__init__()
 
33
  self._hidden1 = nn.Sequential(
34
  nn.Conv2d(in_channels=3, out_channels=48, kernel_size=5, padding=2),
35
  nn.BatchNorm2d(num_features=48),
 
88
  )
89
  self._hidden9 = nn.Sequential(nn.Linear(192 * 7 * 7, 3072), nn.ReLU())
90
  self._hidden10 = nn.Sequential(nn.Linear(3072, 3072), nn.ReLU())
 
91
  self._digit_length = nn.Sequential(nn.Linear(3072, 7))
92
  self._digit1 = nn.Sequential(nn.Linear(3072, 11))
93
  self._digit2 = nn.Sequential(nn.Linear(3072, 11))
 
108
  x = x.view(x.size(0), 192 * 7 * 7)
109
  x = self._hidden9(x)
110
  x = self._hidden10(x)
 
111
  length_logits = self._digit_length(x)
112
  digit1_logits = self._digit1(x)
113
  digit2_logits = self._digit2(x)
114
  digit3_logits = self._digit3(x)
115
  digit4_logits = self._digit4(x)
116
  digit5_logits = self._digit5(x)
 
117
  return (
118
  length_logits,
119
  digit1_logits,
 
149
  self.load_state_dict(
150
  torch.load(path_to_checkpoint_file, map_location=torch.device("cpu"))
151
  )
152
+ return int(path_to_checkpoint_file.split("model-")[-1][:-4])