File size: 4,588 Bytes
b27b0a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gradio as gr
from models.vsa_model import VisionSearchAssistant
from models.vsa_prompt import COCO_CLASSES


# Comma-joined COCO class list, shared by the samples that ground on all classes.
_COCO_CLASS_STRING = ", ".join(COCO_CLASSES)

# Built-in demo samples, keyed by image path.
# Each value is a (text prompt, ground classes string) pair.
SAMPLES = {
    "images/iclr.jpg": (
        "What prize did this paper win in 2024?",
        _COCO_CLASS_STRING,
    ),
    "images/tesla.jpg": (
        "What's the income of this company?",
        "car",
    ),
    "images/xiaomi.jpg": (
        "Provide information about the new products of this brand.",
        _COCO_CLASS_STRING,
    ),
    "images/leshi.jpg": (
        "Provide information about new products of this brand of potato chips in 2024.",
        _COCO_CLASS_STRING,
    ),
}
# Parallel lists (same order as SAMPLES) used to populate the "Samples" tab.
SAMPLE_IMAGES = list(SAMPLES)
SAMPLE_TEXTS = [prompt for prompt, _classes in SAMPLES.values()]
SAMPLE_CLASSES = [classes for _prompt, classes in SAMPLES.values()]


def process_inputs(image, text, ground_classes):
    """Run the Vision Search Assistant and stream its intermediate results to the UI.

    Args:
        image: Value of the "Input Image" component, forwarded to ``vsa.app_run``.
        text: The user's text prompt, forwarded to ``vsa.app_run``.
        ground_classes: Comma-separated class names for the grounding stage.
            An empty, blank or ``None`` value passes ``ground_classes=None``
            to ``vsa.app_run``.

    Yields:
        ``(ground_output, query_output, search_output, answer_output)`` tuples,
        one per recognised stage emitted by ``vsa.app_run``. Stages that have
        not been reached yet are ``None``.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Fail early with a readable UI message instead of an obscure model error.
    if image is None:
        raise gr.Error("Please provide an input image.")

    # Split on commas and strip whitespace so "a,b", "a, b" and "a ,b" all parse
    # the same; blank entries (e.g. from a trailing comma) are dropped.
    classes = [name.strip() for name in (ground_classes or '').split(',')]
    classes = [name for name in classes if name]
    ground_classes = classes or None

    ground_output = query_output = search_output = answer_output = None
    for output, output_type in vsa.app_run(image, text, ground_classes=ground_classes):
        if output_type == 'ground':
            ground_output = output
        elif output_type == 'query':
            query_output = ''.join(
                f'[Area {qid}] {query}\n' for qid, query in enumerate(output)
            )
        elif output_type == 'search':
            search_output = ''.join(
                f'[Context {cid}] {context}\n' for cid, context in enumerate(output)
            )
        elif output_type == 'answer':
            answer_output = output
        else:
            # Unrecognised stage types are skipped without emitting a UI update.
            continue
        yield ground_output, query_output, search_output, answer_output


def select_sample_inputs(sample):
    """Look up the preview values for the sample chosen in the dropdown.

    Args:
        sample: The selected ``SAMPLES`` key (an image path). ``None`` (may be
            produced when the dropdown is cleared), the legacy ``'none'``
            sentinel, or any other unknown value clears the preview.

    Returns:
        ``(image, text, classes)`` for the sample preview components, or
        ``(None, None, None)`` when no known sample is selected.
    """
    # Guard on membership rather than a 'none' sentinel: the dropdown's choices
    # never include 'none', and indexing SAMPLES with None/unknown keys would
    # raise KeyError.
    if sample not in SAMPLES:
        return None, None, None
    text, classes = SAMPLES[sample]
    return sample, text, classes

def confirm_sample_inputs(image, text, classes):
    """Copy the previewed sample values into the "Run" tab inputs unchanged.

    Returns:
        The ``(image, text, classes)`` triple exactly as received.
    """
    selected = (image, text, classes)
    return selected


# Build the UI: a "Run" tab for free-form queries and a "Samples" tab whose
# selection can be copied into the "Run" tab's inputs.
with gr.Blocks() as app:
    with gr.Tab("Run"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    image_input = gr.Image(label="Input Image", height=300, width=300)
                    ground_output = gr.Image(label="Grounding Output", height=300, width=300, interactive=False)
                prompt_input = gr.Textbox(label="Input Text Prompt", lines=1, max_lines=1)
                ground_class_input = gr.Textbox(
                    label="Ground Classes",
                    placeholder="By default, the model will use COCO classes.",
                    lines=1, max_lines=1
                )
                submit_button = gr.Button("Submit")
                answer_output = gr.Textbox(label="Answer Output", lines=4, max_lines=4, interactive=False)
            with gr.Column():
                query_output = gr.Textbox(label='Query Output', lines=14, max_lines=14, interactive=False)
                search_output = gr.Textbox(label="Search Output", lines=14, max_lines=14, interactive=False)
    with gr.Tab("Samples"):
        # Start on the first sample so the dropdown agrees with the preview
        # components below, which are pre-filled with that same sample.
        sample_input = gr.Dropdown(label="Select One Sample", choices=SAMPLE_IMAGES, value=SAMPLE_IMAGES[0])
        with gr.Row():
            sample_image = gr.Image(label="Sample Input Image", height=300, interactive=False, value=SAMPLE_IMAGES[0])
            with gr.Column():
                sample_text = gr.Textbox(label="Sample Input Text", lines=4, max_lines=4, interactive=False, value=SAMPLE_TEXTS[0])
                sample_classes = gr.Textbox(label="Sample Input Classes", lines=4, max_lines=4, interactive=False, value=SAMPLE_CLASSES[0])
        sample_button = gr.Button("Select This Sample")

    # Run the pipeline; process_inputs is a generator, so outputs stream in
    # stage by stage.
    submit_button.click(
        fn=process_inputs,
        inputs=[image_input, prompt_input, ground_class_input],
        outputs=[ground_output, query_output, search_output, answer_output],
        show_progress=True,
    )
    # Refresh the sample preview whenever the dropdown selection changes.
    sample_input.change(
        fn=select_sample_inputs,
        inputs=[sample_input],
        outputs=[sample_image, sample_text, sample_classes]
    )
    # Copy the previewed sample into the "Run" tab inputs.
    sample_button.click(
        fn=confirm_sample_inputs,
        inputs=[sample_image, sample_text, sample_classes],
        outputs=[image_input, prompt_input, ground_class_input],
    )


# Module-level model instance read by process_inputs(); without it every Submit
# fails with NameError. Device and 4-bit settings are this script's configured
# values; adjust them for your hardware.
vsa = VisionSearchAssistant(
    ground_device="cuda:0",
    vlm_device="cuda:0",
    vlm_load_4bit=True,
)

if __name__ == "__main__":
    # Generator (streaming) handlers require the queue on older Gradio
    # releases; newer releases enable it by default, so this is harmless there.
    app.queue()
    # Launch the app
    app.launch()