File size: 4,497 Bytes
a3a16bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import ast
import base64
import os
import uuid
from datetime import datetime

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

# Function to draw a point on the image
def draw_point(image_input, point=None, radius=5, color="red"):
    """Return an image with a filled circle drawn at a relative point.

    Args:
        image_input: Either a filesystem path (``str``) to an image file, or
            array-like pixel data that is cast with ``np.uint8`` and passed to
            ``Image.fromarray``.
        point: Optional sequence whose first two items are the ``x`` and ``y``
            coordinates relative to the image size (0..1 maps to the full
            width/height). When falsy (``None`` or empty), nothing is drawn.
        radius: Marker radius in pixels.
        color: Fill colour for the marker, in any form ``ImageDraw`` accepts.
            Defaults to ``"red"`` (the previous hard-coded colour).

    Returns:
        A ``PIL.Image.Image``. Images that are not already RGB/RGBA are
        converted to RGB (or RGBA if they carry an alpha band) first.
    """
    if isinstance(image_input, str):
        # The context manager closes the file handle; load() reads the pixels
        # before that happens and copy() detaches the result from the file.
        with Image.open(image_input) as opened:
            opened.load()
            image = opened.copy()
    else:
        image = Image.fromarray(np.uint8(image_input))

    # In greyscale ("L") or bilevel ("1") mode a colour fill is reduced to a
    # grey level / black-or-white, so normalise everything to RGB(A) to keep
    # the marker visibly coloured. Preserve an existing alpha band.
    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGBA" if "A" in image.getbands() else "RGB")

    if point:
        x, y = point[0] * image.width, point[1] * image.height
        ImageDraw.Draw(image).ellipse(
            (x - radius, y - radius, x + radius, y + radius), fill=color
        )
    return image

# Function to save the uploaded image and return its path
def array_to_image_path(image_array):
    """Save an image array as a uniquely named PNG in the working directory.

    Args:
        image_array: Array-like pixel data; it is cast with ``np.uint8`` and
            converted via ``Image.fromarray`` before saving.

    Returns:
        The absolute path of the written PNG file.

    Raises:
        ValueError: If ``image_array`` is ``None`` (no image was uploaded).
    """
    if image_array is None:
        raise ValueError("No image provided. Please upload an image before submitting.")
    img = Image.fromarray(np.uint8(image_array))
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # A second-resolution timestamp alone collides when two uploads arrive in
    # the same second, letting one request overwrite another's file before the
    # model reads it; the random suffix keeps every filename unique.
    filename = f"image_{timestamp}_{uuid.uuid4().hex}.png"
    img.save(filename)
    return os.path.abspath(filename)

# Load the model. The checkpoint location can be overridden with the
# SHOWUI_MODEL_PATH environment variable (for example a local "./showui-2b"
# directory or the Hub id "showlab/ShowUI-2B"); when it is unset, the original
# local path is used, so existing deployments behave exactly as before.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    os.environ.get("SHOWUI_MODEL_PATH", "/users/difei/siyuan/showui_demo/showui-2b"),
    torch_dtype=torch.bfloat16,
    # Let transformers decide where the weights are placed.
    device_map="auto",
)

# Lower/upper pixel budget for each screenshot, written as a count times 28*28.
# The Qwen2-VL model card describes these bounds as a visual-token range with
# one token per 28x28 area (presumably the same for ShowUI — confirm), i.e.
# 256 to 1344 tokens here.
min_pixels = 28 * 28 * 256
max_pixels = 28 * 28 * 1344

# Processor (chat template + image preprocessing) from the base Qwen2-VL-2B
# Instruct repository, constrained to the module-level pixel budget; the same
# bounds are also attached to each image entry in the chat messages.
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)

# Markdown shown beneath the logo; links to the model card on Hugging Face.
DESCRIPTION = "[ShowUI-2B Demo](https://huggingface.co/showlab/ShowUI-2B)"

# Instruction placed before every user query: asks the model to answer with a
# clickable [x, y] location expressed in relative (0..1) screenshot coordinates.
_SYSTEM = (
    "Based on the screenshot of the page, I give a text description and you "
    "give its corresponding location. The coordinate represents a clickable "
    "location [x, y] for an element, which is a relative coordinate on the "
    "screenshot, scaled from 0 to 1."
)

# Parse and validate the model's textual answer
def _parse_click_point(output_text):
    """Parse the model's text reply into an ``[x, y]`` click point.

    Args:
        output_text: Decoded model output, expected to be a Python-style
            literal such as ``"[0.42, 0.17]"``.

    Returns:
        The parsed list/tuple exactly as ``ast.literal_eval`` produced it, so
        ``str()`` of the result renders the same as the unvalidated value did.

    Raises:
        ValueError: If the text is not a Python literal, or is not a
            list/tuple whose first two items are real numbers.
    """
    try:
        # literal_eval only accepts literals, so model text can never run code.
        parsed = ast.literal_eval(output_text.strip())
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
        raise ValueError(f"Model output is not a coordinate literal: {output_text!r}") from err
    is_point = (
        isinstance(parsed, (list, tuple))
        and len(parsed) >= 2
        # bool is a subclass of int, but True/False are not coordinates.
        and all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in parsed[:2])
    )
    if not is_point:
        raise ValueError(f"Model output is not an [x, y] coordinate: {output_text!r}")
    return parsed

# Define the main function for inference
def run_showui(image, query):
    """Locate the UI element described by ``query`` on a screenshot.

    Args:
        image: Screenshot pixel array from the Gradio image input.
        query: Text description of the element to locate.

    Returns:
        Tuple ``(annotated_image, coords_text)``: the screenshot with the
        predicted point drawn on it, and ``str()`` of the parsed ``[x, y]``
        prediction.

    Raises:
        ValueError: If no image was provided (raised by ``array_to_image_path``).
        gr.Error: If the query is empty, or the model's reply cannot be parsed
            as an ``[x, y]`` coordinate; Gradio shows the message to the user.
    """
    image_path = array_to_image_path(image)
    if not query or not query.strip():
        raise gr.Error("Please enter a query describing the element to locate.")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": _SYSTEM},
                {"type": "image", "image": image_path, "min_pixels": min_pixels, "max_pixels": max_pixels},
                {"type": "text", "text": query}
            ],
        }
    ]

    # Prepare inputs for the model
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow wherever the model's weights were placed (device_map="auto")
    # instead of assuming CUDA, so CPU-only hosts do not crash here.
    inputs = inputs.to(model.device)

    # Generate, then drop the prompt tokens so only the new reply is decoded.
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    # Parse the output into coordinates; an unparsable reply becomes a
    # user-visible error instead of a bare SyntaxError/IndexError.
    try:
        click_xy = _parse_click_point(output_text)
    except ValueError as err:
        raise gr.Error(str(err)) from err

    # Draw the point on the image
    result_image = draw_point(image_path, click_xy, radius=10)
    return result_image, str(click_xy)

# Read the logo and base64-encode it so the HTML header can embed it inline as
# a data URI. The path is resolved against this script's directory rather than
# the current working directory, so launching from elsewhere still finds it.
_LOGO_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "showui.png")
with open(_LOGO_PATH, "rb") as image_file:
    base64_image = base64.b64encode(image_file.read()).decode("utf-8")

# Gradio UI
# Layout: logo header, Markdown description, then one tab with the inputs in a
# left column and the results in a right column. Gradio renders components in
# the order they are created inside these context managers, so the statement
# order below *is* the page layout.
with gr.Blocks() as demo:
    # Logo linking to the GitHub repository; the image is embedded inline as a
    # base64 PNG data URI built from `base64_image`.
    gr.HTML(
        f"""
        <div style="text-align: center; margin-bottom: 20px;">
            <a href="https://github.com/showlab/ShowUI" target="_blank">
                <img src="data:image/png;base64,{base64_image}" alt="ShowUI Logo" style="width: 200px; height: auto;"/>
            </a>
        </div>
        """
    )
    
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="ShowUI-2B Input"):
        with gr.Row():
            # Left column: user inputs. gr.Image is used with default settings,
            # presumably yielding a numpy array (Gradio's default type) as
            # array_to_image_path expects — confirm against the installed version.
            with gr.Column():
                input_img = gr.Image(label="Input Screenshot")
                text_input = gr.Textbox(label="Query (e.g., 'Click Nahant')")
                submit_btn = gr.Button(value="Submit")
            # Right column: model outputs (annotated screenshot and coordinate text).
            with gr.Column():
                output_img = gr.Image(label="Output Image")
                output_coords = gr.Textbox(label="Clickable Coordinates")

        # On click: run_showui(screenshot, query) -> (annotated image, coordinate string).
        submit_btn.click(run_showui, [input_img, text_input], [output_img, output_coords])

# Route events through Gradio's request queue; api_open=False closes the
# direct REST routes so callers cannot bypass the queue.
demo.queue(api_open=False)

# Start the server only when executed as a script (e.g. `python app.py`), so
# importing this module does not block inside launch().
if __name__ == "__main__":
    demo.launch()