saylee-m committed
Commit a106da8 · 1 Parent(s): 656f678
app.py CHANGED
@@ -1,7 +1,217 @@
 
 
-def greet(name):
-    return "Hello " + name + "!!"
-
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+from io import BytesIO
+from PIL import Image
 import gradio as gr
+import re
+import torch
+from transformers import DonutProcessor, VisionEncoderDecoderModel
+from transformers import AutoProcessor, PaliGemmaProcessor, PaliGemmaForConditionalGeneration
+from transformers import AutoModelForVision2Seq
+from huggingface_hub import InferenceClient
+import base64
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+model_choices = [
+    "idefics2",
+    "paligemma",
+    "donut"
+]
+
+
+def load_donut_model():
+    processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
+    model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
+    model.to(device)
+    return model, processor
+
+def load_paligemma_docvqa():
+    model_id = "google/paligemma-3b-ft-docvqa-896"
+    # model_id = "google/paligemma-3b-mix-448"
+    processor = AutoProcessor.from_pretrained(model_id)
+    model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
+    model.to(device)
+    return model, processor
+
+def load_idefics_docvqa():
+    model_id = "HuggingFaceM4/idefics2-8b"
+    processor = AutoProcessor.from_pretrained(model_id)
+    model = AutoModelForVision2Seq.from_pretrained(model_id)
+    model.to(device)
+    return model, processor
+
+def load_models():
+    # load donut
+    donut_model, donut_processor = load_donut_model()
+    print("donut downloaded")
+    # load paligemma
+    pg_model, pg_processor = load_paligemma_docvqa()
+    print("paligemma downloaded")
+
+    return {"donut": [donut_model, donut_processor],
+            # "idefics": [idf_model, idf_processor],
+            "paligemma": [pg_model, pg_processor]}
+
+# load both local models once at startup; process_document_donut and
+# process_document_pg read from this dict, so it must not stay commented out
+loaded_models = load_models()
+
+def base64_encoded_image(image_array):
+    im = Image.fromarray(image_array)
+    buffered = BytesIO()
+    im.save(buffered, format="PNG")
+    image_bytes = buffered.getvalue()
+    image_base64 = base64.b64encode(image_bytes).decode('ascii')
+    return image_base64
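+# the encoded string is consumed below as a data URI of the form
+# "data:image/png;base64,iVBORw0KGgo..." (PNG payloads always start this way)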
+
+
+def inference_calling_idefics(image_array, question):
+    model_id = "HuggingFaceM4/idefics2-8b"
+    client = InferenceClient(model=model_id)
+    image_base64 = base64_encoded_image(image_array)
+    image_info = f"data:image/png;base64,{image_base64}"
+    prompt = f"![]({image_info}){question}\n\n"
+    response = client.text_generation(prompt)
+    return response
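+# note: this leans on the idefics convention of inlining the image as a markdown
+# data URI inside a plain text_generation prompt, rather than a dedicated VQA endpoint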
+
+
+def process_document_donut(image_array, question):
+    model, processor = loaded_models.get("donut")
+
+    # prepare encoder inputs
+    pixel_values = processor(image_array, return_tensors="pt").pixel_values
+
+    # prepare decoder inputs
+    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
+    prompt = task_prompt.replace("{user_input}", question)
+    decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
+
+    # generate answer
+    outputs = model.generate(
+        pixel_values.to(device),
+        decoder_input_ids=decoder_input_ids.to(device),
+        max_length=model.decoder.config.max_position_embeddings,
+        early_stopping=True,
+        pad_token_id=processor.tokenizer.pad_token_id,
+        eos_token_id=processor.tokenizer.eos_token_id,
+        use_cache=True,
+        num_beams=1,
+        bad_words_ids=[[processor.tokenizer.unk_token_id]],
+        return_dict_in_generate=True,
+    )
+
+    # postprocess
+    sequence = processor.batch_decode(outputs.sequences)[0]
+    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
+    op = processor.token2json(sequence)
+    op = op.get("answer", str(op))
+
+    return op
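+# Donut frames DocVQA as structured generation: the question sits between
+# <s_question>...</s_question>, the model continues after <s_answer>, and
+# token2json converts the tagged output back into a dict with an "answer" key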
+
+def process_document_pg(image_array, question):
+    model, processor = loaded_models.get("paligemma")
+
+    inputs = processor(images=image_array, text=question, return_tensors="pt").to(device)
+    predictions = model.generate(**inputs, max_new_tokens=100)
+    # decode the single generated sequence (batch_decode on predictions[0] would
+    # return a list of per-token strings), then strip the echoed prompt
+    return processor.decode(predictions[0], skip_special_tokens=True)[len(question):].lstrip("\n")
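+# PaliGemma's generate() output begins with the prompt tokens, so the decoded
+# string starts with the question itself; the slice removes that echoed prefix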
+
+# unused local path: idefics2 currently runs through the hosted Inference API
+# (see inference_calling_idefics), and load_models() does not register "idefics"
+def process_document_idf(image_array, question):
+    model, processor = loaded_models.get("idefics")
+
+    inputs = processor(images=image_array, text=question, return_tensors="pt")  # .to(device)
+    predictions = model.generate(**inputs, max_new_tokens=100)
+    return processor.decode(predictions[0], skip_special_tokens=True)[len(question):].lstrip("\n")
+
+
+def generate_answer_donut(image_array, question):
+    try:
+        answer = process_document_donut(image_array, question)
+        print(answer)
+        return answer
+    except Exception as e:
+        print(e)
+        gr.Warning("There was an issue, please try again later.")
+        return "sorry :("
+
+def generate_answer_idefics(image_array, question):
+    try:
+        # answer = process_document_idf(image_array, question)
+        answer = inference_calling_idefics(image_array, question)
+        print(answer)
+        return answer
+    except Exception as e:
+        print(e)
+        gr.Warning("There was an issue, please try again later.")
+        return "sorry :("
+
+def generate_answer_paligemma(image_array, question):
+    try:
+        answer = process_document_pg(image_array, question)
+        print(answer)
+        return answer
+    except Exception as e:
+        print(e)
+        gr.Warning("There was an issue, please try again later.")
+        return "sorry :("
+
+def generate_answers(image_path, question, selected_model=model_choices[0]):
+    try:
+        if selected_model == "donut":
+            answer = generate_answer_donut(image_path, question)
+        elif selected_model == "paligemma":
+            answer = generate_answer_paligemma(image_path, question)
+        else:
+            answer = generate_answer_idefics(image_path, question)
+
+        return [answer]  # [donut_answer, pg_answer, idf_answer]
+    except Exception as e:
+        print(e)
+        gr.Warning("There was an issue, please try again later.")
+        return ["sorry :("]
+
+
+# leftover from the original template app; not wired to the UI
+def greet(name, shame, game):
+    return "Hello " + shame + "!!"
+
+INTRO_TEXT = """## VQA demo\n\n
+A comparison of models on visual document question answering. \n\n
+**Note: the app currently runs on CPU, so you might get an error if you run multiple models back to back. Please reload the app to get the output.**
+"""
+
+with gr.Blocks(css="style.css") as demo:
+    gr.Markdown(INTRO_TEXT)
+    # with gr.Tab("Text Generation"):
+    with gr.Column():
+        image = gr.Image(label="Input Image")
+        question = gr.Text(label="Question")
+        selected_model = gr.Radio(model_choices, label="Model", info="Select the model you want to run")
+
+        outputs_answer = gr.Text(label="Answer generated by the selected model")
+        run_button = gr.Button()
+
+    inputs = [
+        image,
+        question,
+        selected_model,
+    ]
+    outputs = [
+        outputs_answer,
+    ]
+    run_button.click(
+        fn=generate_answers,
+        inputs=inputs,
+        outputs=outputs,
+    )
+
+    examples = [
+        ["images/sample_vendor_contract.png", "Who is the agreement between?"],
+        ["images/apple-10k-form.png", "What are EMEA revenues in 2017?"],
+        ["images/bel-infographic.png", "What is the total turnover?"],
+    ]
+    # each example supplies an image and a question only, so bind just those two inputs
+    gr.Examples(
+        examples=examples,
+        inputs=[image, question],
+    )
+
+
+if __name__ == "__main__":
+    demo.queue(max_size=10).launch(debug=True)
images/apple-10k-form.png ADDED
images/sample_vendor_contract.png ADDED
requirements.txt ADDED
@@ -0,0 +1,4 @@
+gradio
+torch
+transformers
+sentencepiece