Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -126,6 +126,7 @@ def evaluate(
|
|
126 |
table,
|
127 |
question,
|
128 |
llm="alpaca-lora",
|
|
|
129 |
input=None,
|
130 |
temperature=0.1,
|
131 |
top_p=0.75,
|
@@ -134,9 +135,13 @@ def evaluate(
|
|
134 |
max_new_tokens=128,
|
135 |
**kwargs,
|
136 |
):
|
|
|
137 |
prompt = _TEMPLATE + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
|
138 |
if llm == "alpaca-lora":
|
139 |
-
|
|
|
|
|
|
|
140 |
input_ids = inputs["input_ids"].to(device)
|
141 |
generation_config = GenerationConfig(
|
142 |
temperature=temperature,
|
@@ -155,15 +160,15 @@ def evaluate(
|
|
155 |
)
|
156 |
s = generation_output.sequences[0]
|
157 |
output = tokenizer.decode(s)
|
158 |
-
# output = query({
|
159 |
-
# "inputs": prompt
|
160 |
-
# })
|
161 |
elif llm == "flan-ul2":
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
|
|
|
|
|
|
167 |
else:
|
168 |
RuntimeError(f"No such LLM: {llm}")
|
169 |
|
@@ -200,6 +205,11 @@ demo = gr.Interface(
|
|
200 |
),
|
201 |
"image",
|
202 |
"text"],
|
|
|
|
|
|
|
|
|
|
|
203 |
outputs=[
|
204 |
gr.inputs.Textbox(
|
205 |
lines=8,
|
@@ -214,11 +224,11 @@ demo = gr.Interface(
|
|
214 |
description=description,
|
215 |
article=article,
|
216 |
enable_queue=True,
|
217 |
-
examples=[["alpaca-lora", "deplot_case_study_m1.png", "What is the sum of numbers of Indonesia and Ireland? Remember to think step by step."],
|
218 |
-
["alpaca-lora", "deplot_case_study_m1.png", "Summarise the chart for me please."],
|
219 |
-
["alpaca-lora", "deplot_case_study_3.png", "By how much did China's growth rate drop? Think step by step."],
|
220 |
-
["alpaca-lora", "deplot_case_study_4.png", "How many papers are submitted in 2020?"],
|
221 |
-
["alpaca-lora", "deplot_case_study_x2.png", "Summarise the chart for me please."]],
|
222 |
cache_examples=True)
|
223 |
|
224 |
demo.launch(debug=True)
|
|
|
126 |
table,
|
127 |
question,
|
128 |
llm="alpaca-lora",
|
129 |
+
shot="1-shot",
|
130 |
input=None,
|
131 |
temperature=0.1,
|
132 |
top_p=0.75,
|
|
|
135 |
max_new_tokens=128,
|
136 |
**kwargs,
|
137 |
):
|
138 |
+
prompt_0shot = _INSTRUCTION + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
|
139 |
prompt = _TEMPLATE + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
|
140 |
if llm == "alpaca-lora":
|
141 |
+
if shot == "1-shot":
|
142 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
143 |
+
else:
|
144 |
+
inputs = tokenizer(prompt_0shot, return_tensors="pt")
|
145 |
input_ids = inputs["input_ids"].to(device)
|
146 |
generation_config = GenerationConfig(
|
147 |
temperature=temperature,
|
|
|
160 |
)
|
161 |
s = generation_output.sequences[0]
|
162 |
output = tokenizer.decode(s)
|
|
|
|
|
|
|
163 |
elif llm == "flan-ul2":
|
164 |
+
if shot == "1-shot":
|
165 |
+
output = query({
|
166 |
+
"inputs": prompt
|
167 |
+
})[0]["generated_text"]
|
168 |
+
else:
|
169 |
+
output = query({
|
170 |
+
"inputs": prompt_0shot
|
171 |
+
})[0]["generated_text"]
|
172 |
else:
|
173 |
RuntimeError(f"No such LLM: {llm}")
|
174 |
|
|
|
205 |
),
|
206 |
"image",
|
207 |
"text"],
|
208 |
+
gr.Dropdown(
|
209 |
+
["0-shot", "1-shot"], label="#shots", info="How many example tables in the prompt?"
|
210 |
+
),
|
211 |
+
"image",
|
212 |
+
"text"],
|
213 |
outputs=[
|
214 |
gr.inputs.Textbox(
|
215 |
lines=8,
|
|
|
224 |
description=description,
|
225 |
article=article,
|
226 |
enable_queue=True,
|
227 |
+
examples=[["alpaca-lora", "1-shot", "deplot_case_study_m1.png", "What is the sum of numbers of Indonesia and Ireland? Remember to think step by step."],
|
228 |
+
["alpaca-lora", "1-shot", "deplot_case_study_m1.png", "Summarise the chart for me please."],
|
229 |
+
["alpaca-lora", "1-shot", "deplot_case_study_3.png", "By how much did China's growth rate drop? Think step by step."],
|
230 |
+
["alpaca-lora", "1-shot", "deplot_case_study_4.png", "How many papers are submitted in 2020?"],
|
231 |
+
["alpaca-lora", "1-shot", "deplot_case_study_x2.png", "Summarise the chart for me please."]],
|
232 |
cache_examples=True)
|
233 |
|
234 |
demo.launch(debug=True)
|