to-be committed on
Commit 3ce6ec3 · 1 Parent(s): 08dbced

Update app.py

Files changed (1)
  1. app.py +47 -35
app.py CHANGED
@@ -1,44 +1,56 @@
-"""
-Donut
-Copyright (c) 2022-present NAVER Corp.
-MIT License
-https://github.com/clovaai/donut
-"""
+import re
 import gradio as gr
+
 import torch
-from PIL import Image
 from transformers import DonutProcessor, VisionEncoderDecoderModel
 
-from donut import DonutModel
-
-def demo_process(input_img):
-    global pretrained_model, task_prompt, task_name
-    # input_img = Image.fromarray(input_img)
-    output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
-    return output
-
-task_prompt = f"<s_cord-v2>"
-
-
-#pretrained_model = DonutModel.from_pretrained("to-be/ID_cards_v1",revision="main")
-pretrained_model = VisionEncoderDecoderModel.from_pretrained("to-be/ID_cards_v1")
-pretrained_model.eval()
-
-
+processor = DonutProcessor.from_pretrained("to-be/ID_cards_v1")
+model = VisionEncoderDecoderModel.from_pretrained("to-be/ID_cards_v1")
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model.to(device)
+
+def process_document(image):
+    # prepare encoder inputs
+    pixel_values = processor(image, return_tensors="pt").pixel_values
+
+    # prepare decoder inputs
+    task_prompt = "<s_cord-v2>"
+    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
+
+    # generate answer
+    outputs = model.generate(
+        pixel_values.to(device),
+        decoder_input_ids=decoder_input_ids.to(device),
+        max_length=model.decoder.config.max_position_embeddings,
+        early_stopping=True,
+        pad_token_id=processor.tokenizer.pad_token_id,
+        eos_token_id=processor.tokenizer.eos_token_id,
+        use_cache=True,
+        num_beams=1,
+        bad_words_ids=[[processor.tokenizer.unk_token_id]],
+        return_dict_in_generate=True,
+    )
+
+    # postprocess
+    sequence = processor.batch_decode(outputs.sequences)[0]
+    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
+
+    return processor.token2json(sequence)
+
+description = "Gradio Demo for Donut, an instance of `VisionEncoderDecoderModel` fine-tuned on CORD (document parsing). To use it, simply upload your image and click 'submit', or click one of the examples to load them. Read more at the links below."
+article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
 
 demo = gr.Interface(
-    fn=demo_process,
-    inputs= gr.inputs.Image(type="pil"),
+    fn=process_document,
+    inputs="image",
     outputs="json",
-    title=f"Donut 🍩 demonstration for `cord-v2` task",
-    description="""This model is trained with 800 Indonesian receipt images of CORD dataset. <br>
-    Demonstrations for other types of documents/tasks are available at https://github.com/clovaai/donut <br>
-    More CORD receipt images are available at https://huggingface.co/datasets/naver-clova-ix/cord-v2
-    More details are available at:
-    - Paper: https://arxiv.org/abs/2111.15664
-    - GitHub: https://github.com/clovaai/donut""",
-    examples=[["3001.jpg"]],
-    cache_examples=False,
-)
+    title="Demo: Donut 🍩 for Document Parsing",
+    description=description,
+    article=article,
+    enable_queue=True,
+    examples=[["3001.png"]],
+    cache_examples=False)
 
 demo.launch()
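
For reference, here is a minimal sketch of how the new `process_document` function could be smoke-tested locally outside the Gradio interface. This is an illustration, not part of the commit; it assumes the module-level code of app.py above has already run (so `processor`, `model`, `device`, and `process_document` are in scope) and that the example image `3001.png` referenced in `examples` is present in the working directory:

```python
# Sketch: exercise the new inference path directly, without launching Gradio.
# Assumes app.py's module-level objects (processor, model, device,
# process_document) are defined as in the diff above.
from PIL import Image

image = Image.open("3001.png").convert("RGB")
result = process_document(image)  # dict produced by processor.token2json
print(result)
```

`processor.token2json` converts Donut's XML-style output tokens into nested JSON; for example, a decoded sequence such as `<s_nm>LATTE</s_nm><s_cnt>1</s_cnt>` would come back as `{"nm": "LATTE", "cnt": "1"}`. Note also that `enable_queue=True` and the `inputs="image"` shortcut are Gradio 3.x-era API; current Gradio releases enable queuing via `demo.queue()` instead.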