ewfian committed on
Commit
e5b8f4d
·
1 Parent(s): 7a1acae

Add application file

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
4
+ import gradio as gr
5
+ import torch
6
+ from PIL import Image
7
+
8
# Single checkpoint that supplies both the processor and the model weights.
_checkpoint = "ewfian/donut_cn_invoice"

print("test1")
processor = DonutProcessor.from_pretrained(_checkpoint)
print("test2")
model = VisionEncoderDecoderModel.from_pretrained(_checkpoint)
print("test3")

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# Tag handed to the decoder as its starting input; it is tokenized once here
# and the resulting ids are reused by every call to process_document.
task_prompt = "<s_totalAmountInWords>"
_prompt_encoding = processor.tokenizer(
    task_prompt,
    add_special_tokens=False,
    return_tensors="pt",
)
decoder_input_ids = _prompt_encoding.input_ids

print("test")
print(decoder_input_ids.shape)
22
+
23
def process_document(image):
    """Run the Donut model on an image and return the parsed fields.

    Args:
        image: The document image supplied by the Gradio ``"image"`` input;
            it is handed to ``processor`` unchanged.

    Returns:
        Whatever ``processor.token2json`` produces for the generated
        sequence once the EOS/pad tokens and the first tag are removed.

    Raises:
        ValueError: If ``image`` is ``None``, as can happen when no image
            was supplied.
    """
    # Fail early with a readable message instead of an opaque error from
    # inside the processor.
    if image is None:
        raise ValueError("No image provided: please upload an invoice image.")

    pixel_values = processor(image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        # Keep the unknown token out of the generated sequence.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    tokenizer = processor.tokenizer
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    # count=1: strip only the first tag (the task start token); the remaining
    # tags are needed by token2json.
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)
46
+
47
+ # t = process_document(test_sample)
48
+ # print(t)
49
+
50
# Web UI: takes one image and shows the output of process_document as JSON.
demo = gr.Interface(
    fn=process_document,
    inputs="image",
    outputs="json",
    title="Demo: Donut 🍩 for Invoice Parsing",
    cache_examples=False,
)

# Listen on all network interfaces rather than only on localhost.
demo.launch(server_name="0.0.0.0")