Spaces:
Runtime error
Runtime error
File size: 4,588 Bytes
b27b0a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import gradio as gr
from models.vsa_model import VisionSearchAssistant
from models.vsa_prompt import COCO_CLASSES
# Bundled demo inputs, keyed by image path. Each value is a
# (text prompt, comma-separated grounding classes) pair.
SAMPLES = {
    "images/iclr.jpg": (
        "What prize did this paper win in 2024?",
        ", ".join(COCO_CLASSES),
    ),
    "images/tesla.jpg": (
        "What's the income of this company?",
        "car",
    ),
    "images/xiaomi.jpg": (
        "Provide information about the new products of this brand.",
        ", ".join(COCO_CLASSES),
    ),
    "images/leshi.jpg": (
        "Provide information about new products of this brand of potato chips in 2024.",
        ", ".join(COCO_CLASSES),
    ),
}

# Parallel lists derived from SAMPLES, all in the dict's insertion order.
SAMPLE_IMAGES = [*SAMPLES]
SAMPLE_TEXTS = [prompt for prompt, _ in SAMPLES.values()]
SAMPLE_CLASSES = [classes for _, classes in SAMPLES.values()]
def process_inputs(image, text, ground_classes):
    """Run the assistant on one request and stream partial results.

    This is a generator: it yields once for every recognised stage reported
    by ``vsa.app_run`` so the caller can fill in its outputs progressively.

    Args:
        image: Input image, forwarded unchanged to ``vsa.app_run``.
        text: Text prompt, forwarded unchanged to ``vsa.app_run``.
        ground_classes: Comma-separated class names. ``None``, an empty
            string, or a string holding only separators/whitespace is
            forwarded as ``None``.

    Yields:
        A 4-tuple ``(ground_output, query_output, search_output,
        answer_output)``. Entries for stages not reached yet are ``None``.
    """
    # Split on bare commas and strip each name so "car,dog" and "car, dog"
    # behave the same. `or ''` tolerates None, which select_sample_inputs can
    # hand to the class textbox.
    # NOTE(review): whether the textbox actually delivers None here depends
    # on the Gradio version - confirm.
    class_names = [name.strip() for name in (ground_classes or '').split(',')]
    class_names = [name for name in class_names if name]
    ground_classes = class_names or None

    ground_output = query_output = search_output = answer_output = None
    for output, output_type in vsa.app_run(image, text, ground_classes=ground_classes):
        if output_type == 'ground':
            ground_output = output
        elif output_type == 'query':
            # One numbered line per query.
            query_output = ''.join(
                f'[Area {qid}] {query}\n' for qid, query in enumerate(output)
            )
        elif output_type == 'search':
            # One numbered line per retrieved context.
            search_output = ''.join(
                f'[Context {cid}] {context}\n' for cid, context in enumerate(output)
            )
        elif output_type == 'answer':
            answer_output = output
        else:
            # Unrecognised stage: nothing changed, so do not emit an update.
            continue
        yield ground_output, query_output, search_output, answer_output
def select_sample_inputs(sample):
    """Return the preview values for the sample chosen in the dropdown.

    Args:
        sample: Key into ``SAMPLES`` (an image path). ``None``, an empty
            value, the literal ``'none'``, or an unknown key clears the
            preview instead of raising ``KeyError``.

    Returns:
        A 3-tuple ``(image, text, classes)``; all three are ``None`` when
        there is no matching sample.
    """
    # Check the "nothing selected" cases before touching SAMPLES.
    if not sample or sample == 'none':
        return None, None, None
    entry = SAMPLES.get(sample)
    if entry is None:
        return None, None, None
    text, classes = entry
    # The sample key doubles as the image value.
    return sample, text, classes
def confirm_sample_inputs(image, text, classes):
    """Pass the previewed sample values through unchanged.

    Returns:
        The tuple ``(image, text, classes)`` exactly as received.
    """
    chosen = (image, text, classes)
    return chosen
# Create a Blocks interface with a "Run" tab for inference and a "Samples"
# tab offering ready-made inputs.
# NOTE(review): the nesting of the layout containers below is reconstructed;
# confirm it against the intended layout.
with gr.Blocks() as app:
    with gr.Tab("Run"):
        with gr.Row():
            # Inputs, grounding preview and the final answer.
            with gr.Column():
                with gr.Row():
                    image_input = gr.Image(label="Input Image", height=300, width=300)
                    ground_output = gr.Image(label="Grounding Output", height=300, width=300, interactive=False)
                prompt_input = gr.Textbox(label="Input Text Prompt", lines=1, max_lines=1)
                ground_class_input = gr.Textbox(
                    label="Ground Classes",
                    placeholder="Defaultly, the model will use COCO classes.",
                    lines=1, max_lines=1
                )
                submit_button = gr.Button("Submit")
                answer_output = gr.Textbox(label="Answer Output", lines=4, max_lines=4, interactive=False)
            # Intermediate results: generated queries and search contexts.
            with gr.Column():
                query_output = gr.Textbox(label='Query Output', lines=14, max_lines=14, interactive=False)
                search_output = gr.Textbox(label="Search Output", lines=14, max_lines=14, interactive=False)
    with gr.Tab("Samples"):
        # Start with the first sample selected so the dropdown agrees with
        # the preview widgets, which are pre-filled with that same sample.
        sample_input = gr.Dropdown(
            label="Select One Sample",
            choices=SAMPLE_IMAGES,
            value=SAMPLE_IMAGES[0],
        )
        with gr.Row():
            sample_image = gr.Image(label="Sample Input Image", height=300, interactive=False, value=SAMPLE_IMAGES[0])
            with gr.Column():
                sample_text = gr.Textbox(label="Sample Input Text", lines=4, max_lines=4, interactive=False, value=SAMPLE_TEXTS[0])
                sample_classes = gr.Textbox(label="Sample Input Classes", lines=4, max_lines=4, interactive=False, value=SAMPLE_CLASSES[0])
        sample_button = gr.Button("Select This Sample")

    # Processing action: process_inputs is a generator, so the four outputs
    # are updated each time it yields.
    submit_button.click(
        fn=process_inputs,
        inputs=[image_input, prompt_input, ground_class_input],
        outputs=[ground_output, query_output, search_output, answer_output],
        show_progress=True,
    )
    # Refresh the preview widgets whenever a different sample is picked.
    sample_input.change(
        fn=select_sample_inputs,
        inputs=[sample_input],
        outputs=[sample_image, sample_text, sample_classes]
    )
    # Copy the previewed sample into the "Run" tab inputs.
    sample_button.click(
        fn=confirm_sample_inputs,
        inputs=[sample_image, sample_text, sample_classes],
        outputs=[image_input, prompt_input, ground_class_input],
    )
# Instantiate the assistant that process_inputs looks up as the module-level
# name `vsa` when a request arrives. Without this assignment every submit
# fails with NameError.
# NOTE(review): the device strings assume a CUDA GPU at index 0 - confirm the
# deployment target before enabling this on CPU-only hardware.
vsa = VisionSearchAssistant(
    ground_device="cuda:0",
    vlm_device="cuda:0",
    vlm_load_4bit=True,
)

# Launch the app
app.launch()