maxiw committed · Commit 7670816 · verified · 1 Parent(s): a84cec7

Upload app.py

Files changed (1):
  app.py (+118, -0)

app.py ADDED
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import ImageDraw


# Run on GPU with fp16 when available, otherwise fall back to CPU fp32.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Model and processor registries keyed by Hugging Face model id.
models = {
    "AskUI/PTA-1": AutoModelForCausalLM.from_pretrained("AskUI/PTA-1", trust_remote_code=True),
}

processors = {
    "AskUI/PTA-1": AutoProcessor.from_pretrained("AskUI/PTA-1", trust_remote_code=True)
}


def draw_bounding_boxes(image, bounding_boxes, outline_color="red", line_width=3):
    # Draw each (xmin, ymin, xmax, ymax) box onto the image in place and return it.
    draw = ImageDraw.Draw(image)
    for box in bounding_boxes:
        xmin, ymin, xmax, ymax = box
        draw.rectangle([xmin, ymin, xmax, ymax], outline=outline_color, width=line_width)
    return image


def florence_output_to_box(output):
    # Convert Florence-2-style detection output into a single [xmin, ymin, xmax, ymax] box.
    try:
        # Prefer polygons: take the first polygon and use its first and third vertices
        # (top-left and bottom-right corners) as the bounding box.
        if "polygons" in output and len(output["polygons"]) > 0:
            polygons = output["polygons"]
            target_polygon = polygons[0][0]
            target_polygon = [int(el) for el in target_polygon]
            return [
                target_polygon[0],
                target_polygon[1],
                target_polygon[4],
                target_polygon[5],
            ]
        # Otherwise fall back to the first predicted bounding box, if any.
        if "bboxes" in output and len(output["bboxes"]) > 0:
            bboxes = output["bboxes"]
            target_bbox = bboxes[0]
            target_bbox = [int(el) for el in target_bbox]
            return target_bbox
    except Exception as e:
        print(f"Error: {e}")
    return None


def run_example(image, text_input, model_id="AskUI/PTA-1"):
    model = models[model_id].to(device, torch_dtype)
    processor = processors[model_id]
    # Open-vocabulary detection: the task token is prepended to the user's text prompt.
    task_prompt = "<OPEN_VOCABULARY_DETECTION>"
    prompt = task_prompt + text_input

    image = image.convert("RGB")

    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)

    # Greedy beam-search generation of the detection string.
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(generated_text, task="<OPEN_VOCABULARY_DETECTION>", image_size=(image.width, image.height))
    target_box = florence_output_to_box(parsed_answer["<OPEN_VOCABULARY_DETECTION>"])
    # Guard against the case where no box could be parsed from the model output.
    annotated = draw_bounding_boxes(image, [target_box]) if target_box else image
    return target_box, annotated


css = """
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
        <div style="display: flex; justify-content: space-between; align-items: center; background-color: #baff49; padding: 10px;">
            <h1 style="margin: 0; color: #101828;">PTA-1: Controlling Computers with Small Models</h1>
            <img src="https://cdn.prod.website-files.com/6627a15f6d261b8bf852c0a1/670529b583d3638f72db5614_askui-logo-primary-filled.svg" alt="Logo" style="height: 50px;">
        </div>
        """)
    gr.Markdown("Check out the model [AskUI/PTA-1](https://huggingface.co/AskUI/PTA-1).")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="Input Image", type="pil")
            model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="AskUI/PTA-1")
            text_input = gr.Textbox(label="User Prompt")
            submit_btn = gr.Button(value="Submit")
        with gr.Column():
            model_output_text = gr.Textbox(label="Model Output Text")
            annotated_image = gr.Image(label="Annotated Image")

    gr.Examples(
        examples=[
            ["assets/sample.png", "search box"],
            ["assets/sample.png", "Query Service"],
            ["assets/ipad.png", "App Store icon"],
            ["assets/ipad.png", 'colorful icon with letter "S"'],
            ["assets/phone.jpg", "password field"],
            ["assets/phone.jpg", "back arrow icon"],
            ["assets/windows.jpg", "icon with letter S"],
            ["assets/windows.jpg", "Settings"],
        ],
        inputs=[input_img, text_input],
        outputs=[model_output_text, annotated_image],
        fn=run_example,
        cache_examples=False,
        label="Try examples"
    )

    submit_btn.click(run_example, [input_img, text_input, model_selector], [model_output_text, annotated_image])

demo.launch(debug=False)
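
For reference, run_example can also be exercised without the Gradio UI. A minimal sketch, assuming the lines below run in the same session as the code above (e.g. appended before demo.launch) and that "my_screenshot.png" is a hypothetical local screenshot path:

from PIL import Image

# Hypothetical standalone call; the image path and prompt are assumptions.
screenshot = Image.open("my_screenshot.png")
box, annotated = run_example(screenshot, "search box", model_id="AskUI/PTA-1")
print(box)  # [xmin, ymin, xmax, ymax], or None if no element was detected
annotated.save("annotated.png")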