pierreguillou committed on
Commit d0e0e62 · 1 Parent(s): 9361e03

Create app.py

Files changed (1)
  app.py  +313 -0
app.py ADDED
@@ -0,0 +1,313 @@
+ import gradio as gr
+ from PIL import Image, ImageDraw, ImageFont
+ import random
+ import pandas as pd
+ import numpy as np
+ from datasets import concatenate_datasets
+ from operator import itemgetter
+ import collections
+
+ # download datasets
+ from datasets import load_dataset
+
+ dataset_small = load_dataset("pierreguillou/DocLayNet-small")
+ dataset_base = load_dataset("pierreguillou/DocLayNet-base")
+
+ id2label = {idx: label for idx, label in enumerate(dataset_small["train"].features["categories"].feature.names)}
+ labels = [label for idx, label in id2label.items()]
+
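+ # Note: `id2label` maps each class id in the dataset to one of the 11 DocLayNet labels
+ # (Caption, Footnote, Formula, List-item, Page-footer, Page-header, Picture, Section-header,
+ # Table, Text, Title), and `labels` is the corresponding list of label names.
+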
+ # need to change the coordinates format
+ def convert_box(box):
+     x, y, w, h = tuple(box)  # the row comes in (left, top, width, height) format
+     actual_box = [x, y, x+w, y+h]  # we turn it into (left, top, left+width, top+height) to get the actual box
+     return actual_box
+
+ # get back the original size
+ def original_box(box, original_width, original_height, coco_width, coco_height):
+     return [
+         int(original_width * (box[0] / coco_width)),
+         int(original_height * (box[1] / coco_height)),
+         int(original_width * (box[2] / coco_width)),
+         int(original_height * (box[3] / coco_height)),
+     ]
+
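+ # Illustrative check (not used by the app): a COCO-format box [x, y, w, h] = [10, 20, 100, 50]
+ # becomes [10, 20, 110, 70] with convert_box, and rescaling that box from a 1025x1025 page
+ # to an 850x1100 page with original_box gives [8, 21, 91, 75].
+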
+ # function to sort bounding boxes
+ def get_sorted_boxes(bboxes):
+
+     # sort by y from page top to bottom
+     bboxes = sorted(bboxes, key=itemgetter(1), reverse=False)
+     y_list = [bbox[1] for bbox in bboxes]
+
+     # sort by x from page left to right when boxes share the same y
+     if len(list(set(y_list))) != len(y_list):
+         y_list_duplicates_indexes = dict()
+         y_list_duplicates = [item for item, count in collections.Counter(y_list).items() if count > 1]
+         for item in y_list_duplicates:
+             y_list_duplicates_indexes[item] = [i for i, e in enumerate(y_list) if e == item]
+             bbox_list_y_duplicates = sorted(np.array(bboxes)[y_list_duplicates_indexes[item]].tolist(), key=itemgetter(0), reverse=False)
+             np_array_bboxes = np.array(bboxes)
+             np_array_bboxes[y_list_duplicates_indexes[item]] = np.array(bbox_list_y_duplicates)
+             bboxes = np_array_bboxes.tolist()
+
+     return bboxes
+
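+ # Illustrative check (not used by the app): boxes sharing the same top coordinate are re-ordered
+ # from left to right, e.g. get_sorted_boxes([[300, 50, 400, 80], [10, 50, 200, 80], [10, 10, 500, 40]])
+ # returns [[10, 10, 500, 40], [10, 50, 200, 80], [300, 50, 400, 80]].
+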
+ # category colors
+ label2color = {
+     'Caption': 'brown',
+     'Footnote': 'orange',
+     'Formula': 'gray',
+     'List-item': 'yellow',
+     'Page-footer': 'red',
+     'Page-header': 'red',
+     'Picture': 'violet',
+     'Section-header': 'orange',
+     'Table': 'green',
+     'Text': 'blue',
+     'Title': 'pink'
+ }
+
+ # image without content
+ examples_dir = 'samples/'
+ images_wo_content = examples_dir + "wo_content.png"
+
+ df_paragraphs_wo_content, df_lines_wo_content = pd.DataFrame(), pd.DataFrame()
+
+ df_paragraphs_wo_content["paragraphs"] = [0]
+ df_paragraphs_wo_content["categories"] = ["no content"]
+ df_paragraphs_wo_content["texts"] = ["no content"]
+ df_paragraphs_wo_content["bounding boxes"] = ["no content"]
+
+ df_lines_wo_content["lines"] = [0]
+ df_lines_wo_content["categories"] = ["no content"]
+ df_lines_wo_content["texts"] = ["no content"]
+ df_lines_wo_content["bounding boxes"] = ["no content"]
+
+ # default font & lists of UI parameters
+ font = ImageFont.load_default()
+
+ dataset_names = ["small", "base"]
+ splits = ["all", "train", "validation", "test"]
+ domains = ["all", "Financial Reports", "Manuals", "Scientific Articles", "Laws & Regulations", "Patents", "Government Tenders"]
+ domains_names = [domain_name.lower().replace(" ", "_") for domain_name in domains]
+ categories = labels + ["all"]
+
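+ # e.g. "Financial Reports" -> "financial_reports": domains_names holds the lower-case, underscored
+ # form that the code below compares with the dataset's "doc_category" column.
+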
+ # function to get a random image and all its data from DocLayNet
+ def generate_annotated_image(dataset_name, split, domain, category):
+
+     def get_dataset(dataset_name, split, domain, category):
+
+         # error message
+         msg_error = ""
+
+         # get dataset
+         if dataset_name == "small": example = dataset_small
+         else: example = dataset_base
+
+         # get split
+         if split == "all":
+             example = concatenate_datasets([example["train"], example["validation"], example["test"]])
+         else:
+             example = example[split]
+
+         # get domain
+         domain_name = domains_names[domains.index(domain)]
+         if domain_name != "all":
+             example = example.filter(lambda example: example["doc_category"] == domain_name)
+             if len(example) == 0:
+                 msg_error = f'There is no image with at least one annotated bounding box that matches your parameters ("{domain}" domain / "DocLayNet {dataset_name}" dataset / "{split}" split).'
+                 example = dict()
+                 return example, msg_error
+
+         # get category (the "categories" column stores label ids, so compare against the id of the selected label;
+         # enumerate gives positional indices, which is what .select() expects)
+         idx_list = list()
+         if category != "all":
+             category_id = labels.index(category)
+             for idx, categories_list in enumerate(example["categories"]):
+                 if category_id in categories_list:
+                     idx_list.append(idx)
+             example = example.select(idx_list)
+             if len(example) == 0:
+                 msg_error = f'There is no image with at least one annotated bounding box that matches your parameters (category: "{category}" / domain: "{domain}" / dataset: "DocLayNet {dataset_name}" / split: "{split}").'
+                 example = dict()
+                 return example, msg_error
+
+         return example, msg_error
+
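+     # For instance, get_dataset("small", "train", "Manuals", "Table") is expected to return the subset
+     # of the DocLayNet-small train split whose pages come from manuals and contain at least one "Table"
+     # annotation (or an empty dict plus an error message when nothing matches).
+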
+     # get results
+     example, msg_error = get_dataset(dataset_name, split, domain, category)
+
+     if len(msg_error) > 0:
+         return msg_error, images_wo_content, images_wo_content, df_paragraphs_wo_content, df_lines_wo_content
+
+     else:
+         # get random image & PDF data
+         image_files = example["image"]
+         index = random.randint(0, len(image_files) - 1)  # randint is inclusive on both ends
+         image = image_files[index]  # original image
+         coco_width, coco_height = example[index]["coco_width"], example[index]["coco_height"]
+         original_width, original_height = example[index]["original_width"], example[index]["original_height"]
+         original_filename = example[index]["original_filename"]
+         page_no = example[index]["page_no"]
+         num_pages = example[index]["num_pages"]
+
+         # resize the image to its original size
+         image = image.resize((original_width, original_height))
+
+         # get corresponding annotations
+         texts = example[index]["texts"]
+         bboxes_block = example[index]["bboxes_block"]
+         bboxes_line = example[index]["bboxes_line"]
+         categories = example[index]["categories"]
+         domain = example[index]["doc_category"]
+
+         # get the list of categories present on the page
+         categories_unique = sorted(set(categories))
+         categories_unique = [id2label[idx] for idx in categories_unique]
+
+         # convert boxes to the original page size
+         original_bboxes_block = [original_box(convert_box(box), original_width, original_height, coco_width, coco_height) for box in bboxes_block]
+         original_bboxes_line = [original_box(convert_box(box), original_width, original_height, coco_width, coco_height) for box in bboxes_line]
+         original_bboxes = [original_bboxes_block, original_bboxes_line]
+
+         ##### block boxes #####
+
+         # get list of unique block boxes (lines of the same block are assumed to be consecutive in the annotations)
+         original_blocks = dict()
+         original_bboxes_block_list = list()
+         original_bbox_block_prec = list()
+         for count_block, original_bbox_block in enumerate(original_bboxes_block):
+             if original_bbox_block != original_bbox_block_prec:
+                 original_bbox_block_indexes = [i for i, original_bbox in enumerate(original_bboxes_block) if original_bbox == original_bbox_block]
+                 original_blocks[count_block] = original_bbox_block_indexes
+                 original_bboxes_block_list.append(original_bbox_block)
+             original_bbox_block_prec = original_bbox_block
+
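+         # Illustrative example: if original_bboxes_block were [A, A, B] (two lines in block A followed
+         # by one line in block B), original_blocks would be {0: [0, 1], 2: [2]} and
+         # original_bboxes_block_list would be [A, B]; the texts of lines 0 and 1 are merged into the
+         # single text of block A below.
+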
+         # get list of categories and texts by unique block boxes
+         category_block_list, text_block_list = list(), list()
+         for original_bbox_block in original_bboxes_block_list:
+             count_block = original_bboxes_block.index(original_bbox_block)
+             original_bbox_block_indexes = original_blocks[count_block]
+             category_block = categories[original_bbox_block_indexes[0]]
+             category_block_list.append(category_block)
+             if id2label[category_block] in ("Text", "Caption", "Footnote"):
+                 # running text: join the block's lines with spaces
+                 text_block = ' '.join(np.array(texts)[original_bbox_block_indexes].tolist())
+             else:
+                 # Section-header, Title, Picture, Formula, List-item, Table, Page-header, Page-footer: keep one line per row
+                 text_block = '\n'.join(np.array(texts)[original_bbox_block_indexes].tolist())
+             text_block_list.append(text_block)
+
+         # sort data from y = 0 to the end of the page (and, for boxes with the same y, from x = 0 to the right)
+         sorted_original_bboxes_block_list = get_sorted_boxes(original_bboxes_block_list)
+         sorted_original_bboxes_block_list_indexes = [original_bboxes_block_list.index(item) for item in sorted_original_bboxes_block_list]
+         sorted_category_block_list = np.array(category_block_list)[sorted_original_bboxes_block_list_indexes].tolist()
+         sorted_text_block_list = np.array(text_block_list)[sorted_original_bboxes_block_list_indexes].tolist()
+
+         ##### line boxes #####
+
+         # sort data from y = 0 to the end of the page (and, for boxes with the same y, from x = 0 to the right)
+         original_bboxes_line_list = original_bboxes_line
+         category_line_list = categories
+         text_line_list = texts
+         sorted_original_bboxes_line_list = get_sorted_boxes(original_bboxes_line_list)
+         sorted_original_bboxes_line_list_indexes = [original_bboxes_line_list.index(item) for item in sorted_original_bboxes_line_list]
+         sorted_category_line_list = np.array(category_line_list)[sorted_original_bboxes_line_list_indexes].tolist()
+         sorted_text_line_list = np.array(text_line_list)[sorted_original_bboxes_line_list_indexes].tolist()
+
+         # set up images & PDF data
+         columns = 2
+         images = [image.copy(), image.copy()]
+         num_imgs = len(images)
+
+         imgs, df_paragraphs, df_lines = dict(), pd.DataFrame(), pd.DataFrame()
+         for i, img in enumerate(images):
+
+             draw = ImageDraw.Draw(img)
+
+             for box, label_idx, text in zip(original_bboxes[i], categories, texts):
+                 label = id2label[label_idx]
+                 color = label2color[label]
+                 draw.rectangle(box, outline=color)
+                 text = text.encode('latin-1', 'replace').decode('latin-1')  # https://stackoverflow.com/questions/56761449/unicodeencodeerror-latin-1-codec-cant-encode-character-u2013-writing-to
+                 draw.text((box[0] + 10, box[1] - 10), text=label, fill=color, font=font)
+
+             if i == 0:
+                 imgs["paragraphs"] = img
+
+                 df_paragraphs["paragraphs"] = list(range(len(sorted_original_bboxes_block_list)))
+                 df_paragraphs["categories"] = [id2label[label_idx] for label_idx in sorted_category_block_list]
+                 df_paragraphs["texts"] = sorted_text_block_list
+                 df_paragraphs["bounding boxes"] = [str(bbox) for bbox in sorted_original_bboxes_block_list]
+
+             else:
+                 imgs["lines"] = img
+
+                 df_lines["lines"] = list(range(len(sorted_original_bboxes_line_list)))
+                 df_lines["categories"] = [id2label[label_idx] for label_idx in sorted_category_line_list]
+                 df_lines["texts"] = sorted_text_line_list
+                 df_lines["bounding boxes"] = [str(bbox) for bbox in sorted_original_bboxes_line_list]
+
+         msg = f'Page {page_no} of the PDF "{original_filename}" (domain "{domain}") matches your parameters.'
+
+         return msg, imgs["paragraphs"], imgs["lines"], df_paragraphs, df_lines
+
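+ # Illustrative call (what the button and the cached example below trigger):
+ # msg, img_paragraphs, img_lines, df_paragraphs, df_lines = generate_annotated_image("small", "all", "all", "all")
+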
+ # gradio app
+ with gr.Blocks(title="DocLayNet image viewer", css=".gradio-container") as demo:
+     gr.HTML("""
+     <div style="font-family:'Times New Roman', 'Serif'; font-size:26pt; font-weight:bold; text-align:center;"><h1>DocLayNet image viewer</h1></div>
+     <div style="margin-top: 20px"><p>(01/29/2023) This app is an image viewer for the DocLayNet dataset.</p></div>
+     <div><p>It uses the datasets <a href="https://huggingface.co/datasets/pierreguillou/DocLayNet-small" target="_blank">DocLayNet small</a> and <a href="https://huggingface.co/datasets/pierreguillou/DocLayNet-base" target="_blank">DocLayNet base</a>.</p></div>
+     <div><p>Select your parameters and the output will show two images of a randomly selected PDF page with annotated bounding boxes (one for paragraphs, one for lines) and tables of the corresponding texts with their labels.</p></div>
+     """)
+     with gr.Row():
+         with gr.Column():
+             dataset_name_gr = gr.Radio(dataset_names, value="small", label="DocLayNet dataset")
+         with gr.Column():
+             split_gr = gr.Dropdown(splits, value="all", label="Split")
+         with gr.Column():
+             domain_gr = gr.Dropdown(domains, value="all", label="Domain")
+         with gr.Column():
+             category_gr = gr.Dropdown(categories, value="all", label="Category")
+     btn = gr.Button("Display PDF image")
+     with gr.Row():
+         output_msg = gr.Textbox(label="Results")
+     with gr.Row():
+         # with gr.Column():
+         #     json = gr.JSON(label="JSON")
+         with gr.Column():
+             img_paragraphs = gr.Image(type="pil", label="Bounding boxes of paragraphs")
+         with gr.Column():
+             img_lines = gr.Image(type="pil", label="Bounding boxes of lines")
+
+     with gr.Row():
+         with gr.Column():
+             df_paragraphs = gr.Dataframe(
+                 headers=["paragraphs", "categories", "texts", "bounding boxes"],
+                 datatype=["number", "str", "str", "str"],
+                 # row_count='dynamic',
+                 col_count=(4, "fixed"),
+                 interactive=False,
+                 label="Paragraphs data",
+                 type="pandas",
+                 wrap=True
+             )
+         with gr.Column():
+             df_lines = gr.Dataframe(
+                 headers=["lines", "categories", "texts", "bounding boxes"],
+                 datatype=["number", "str", "str", "str"],
+                 # row_count='dynamic',
+                 col_count=(4, "fixed"),
+                 interactive=False,
+                 label="Lines data",
+                 type="pandas",
+                 wrap=True
+             )
+     btn.click(generate_annotated_image, inputs=[dataset_name_gr, split_gr, domain_gr, category_gr], outputs=[output_msg, img_paragraphs, img_lines, df_paragraphs, df_lines])
+
+     gr.Markdown("## Example")
+     gr.Examples(
+         [["small", "all", "all", "all"]],
+         [dataset_name_gr, split_gr, domain_gr, category_gr],
+         [output_msg, img_paragraphs, img_lines, df_paragraphs, df_lines],
+         fn=generate_annotated_image,
+         cache_examples=True,
+     )
+
+ demo.launch()