gyrojeff committed
Commit 29efa50 · Parent: e82af47

feat: add gradio demo

Files changed (1)
  1. demo.py +179 -0
demo.py ADDED
@@ -0,0 +1,179 @@
+ import argparse
+ import os
+ import torch  # torch.device / no_grad / topk / compile are used directly below
+ import gradio as gr
+ from PIL import Image
+ from torchvision import transforms
+ from detector.model import *
+ from detector import config
+ from font_dataset.font import load_fonts, load_font_with_exclusion
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+     "-d",
+     "--device",
+     type=int,
+     default=0,
+     help="GPU device to use (default: 0), -1 for CPU",
+ )
+ parser.add_argument(
+     "-c",
+     "--checkpoint",
+     type=str,
+     default=None,
+     help="Trainer checkpoint path (default: None)",
+ )
+ parser.add_argument(
+     "-m",
+     "--model",
+     type=str,
+     default="resnet18",
+     choices=["resnet18", "resnet34", "resnet50", "resnet101", "deepfont"],
+     help="Model to use (default: resnet18)",
+ )
+ parser.add_argument(
+     "-f",
+     "--font-classification-only",
+     action="store_true",
+     help="Font classification only (default: False)",
+ )
+ parser.add_argument(
+     "-z",
+     "--size",
+     type=int,
+     default=512,
+     help="Model feature image input size (default: 512)",
+ )
+ parser.add_argument(
+     "-s",
+     "--share",
+     action="store_true",
+     help="Get public link via Gradio (default: False)",
+ )
+
+ args = parser.parse_args()
+
+ config.INPUT_SIZE = args.size
+ device = torch.device("cpu") if args.device == -1 else torch.device("cuda", args.device)
+
+ regression_use_tanh = False
+
+ if args.model == "resnet18":
+     model = ResNet18Regressor(regression_use_tanh=regression_use_tanh)
+ elif args.model == "resnet34":
+     model = ResNet34Regressor(regression_use_tanh=regression_use_tanh)
+ elif args.model == "resnet50":
+     model = ResNet50Regressor(regression_use_tanh=regression_use_tanh)
+ elif args.model == "resnet101":
+     model = ResNet101Regressor(regression_use_tanh=regression_use_tanh)
+ elif args.model == "deepfont":
+     # the DeepFont baseline only supports 105x105 inputs and font classification
+     assert args.size == 105
+     assert args.font_classification_only is True
+     model = DeepFontBaseline()
+ else:
+     raise NotImplementedError()
+
+ # torch.compile requires PyTorch >= 2.0 and is not supported on Windows
+ if torch.__version__ >= "2.0" and os.name == "posix":
+     model = torch.compile(model)
+
+ # The training hyperparameters (lr, betas, warmup, ...) are placeholders;
+ # they only satisfy the constructor and are unused at inference time.
+ detector = FontDetector(
+     model=model,
+     lambda_font=1,
+     lambda_direction=1,
+     lambda_regression=1,
+     font_classification_only=args.font_classification_only,
+     lr=1,
+     betas=(1, 1),
+     num_warmup_iters=1,
+     num_iters=1e9,
+     num_epochs=1e9,
+ )
+ # load_from_checkpoint constructs a fresh FontDetector around the same
+ # `model` instance, so the checkpoint weights land in the model used below.
+ detector.load_from_checkpoint(
+     args.checkpoint,
+     map_location=device,
+     model=model,
+     lambda_font=1,
+     lambda_direction=1,
+     lambda_regression=1,
+     font_classification_only=args.font_classification_only,
+     lr=1,
+     betas=(1, 1),
+     num_warmup_iters=1,
+     num_iters=1e9,
+     num_epochs=1e9,
+ )
+ detector = detector.to(device)
+ detector.eval()
+
+
+ transform = transforms.Compose(
+     [
+         # resize to the model's input size (--size, default 512)
+         transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
+         transforms.ToTensor(),
+     ]
+ )
+
+ print("Preparing fonts ...")
+ font_list, exclusion_rule = load_fonts()
+
+ font_list = list(filter(lambda x: not exclusion_rule(x), font_list))
+ font_list.sort(key=lambda x: x.path)
+
+ for i in range(len(font_list)):
+     font_list[i].path = font_list[i].path[18:]  # remove "./dataset/fonts/./" prefix (18 chars)
+
+ # pre-rendered sample image per font, indexed by position in font_list
+ font_demo_images = []
+
+ for i in range(len(font_list)):
+     font_demo_images.append(Image.open(f"demo_fonts/{i}.jpg").convert("RGB"))
+
+
+ def recognize_font(image):
+     transformed_image = transform(image)
+     with torch.no_grad():
+         transformed_image = transformed_image.to(device)
+         output = detector(transformed_image.unsqueeze(0))
+         # the first FONT_COUNT logits are the font classes; the remaining
+         # entries belong to the direction/regression heads
+         prob = output[0][: config.FONT_COUNT].softmax(dim=0)
+
+     indices = torch.topk(prob, 9)[1]
+
+     return [
+         {font_list[i].path: float(prob[i]) for i in range(config.FONT_COUNT)},
+         *[gr.Image.update(value=font_demo_images[indices[i]]) for i in range(9)],
+         *[
+             gr.Markdown.update(value=f"**Font Name**: {font_list[indices[i]].path}")
+             for i in range(9)
+         ],
+     ]
+
+
+ def generate_grid(num_columns, num_rows):
+     # Lay out a num_rows x num_columns grid of (label, image) cells and
+     # return the components so they can be wired up as click outputs.
+     ret_images, ret_labels = [], []
+     with gr.Column():
+         for _ in range(num_rows):
+             with gr.Row():
+                 for _ in range(num_columns):
+                     with gr.Column():
+                         ret_labels.append(gr.Markdown("**Font Name**"))
+                         ret_images.append(gr.Image())
+     return ret_images, ret_labels
+
+
+ with gr.Blocks() as demo:
+     with gr.Column():
+         with gr.Row():
+             inp = gr.Image(type="pil", label="Input Image")
+             out = gr.Label(num_top_classes=9, label="Output Font")
+         font_demo_images_blocks, font_demo_labels_blocks = generate_grid(3, 3)
+
+         submit_button = gr.Button("Submit")
+         submit_button.click(
+             fn=recognize_font,
+             inputs=inp,
+             outputs=[out, *font_demo_images_blocks, *font_demo_labels_blocks],
+         )
+
+
+ demo.launch(share=args.share)
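
For a sense of what recognize_font computes before Gradio renders it, here is a minimal, self-contained sketch of the same softmax/top-k post-processing, using a random tensor in place of real model output (FONT_COUNT and EXTRA_OUTPUTS below are illustrative stand-ins, not the repo's actual values):

import torch

FONT_COUNT = 20      # illustrative; the real value comes from detector.config
EXTRA_OUTPUTS = 12   # stand-in for the direction/regression head outputs

logits = torch.randn(FONT_COUNT + EXTRA_OUTPUTS)   # fake output for one image
prob = logits[:FONT_COUNT].softmax(dim=0)          # font logits -> probabilities
top_prob, top_idx = torch.topk(prob, 9)            # nine best candidates

for rank, (p, i) in enumerate(zip(top_prob, top_idx), start=1):
    print(f"#{rank}: font index {int(i)}, probability {float(p):.3f}")

The demo itself is launched from the command line, e.g. python demo.py -m resnet18 -c <checkpoint path>, with -s to request a public Gradio link.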