taesiri commited on
Commit
7fd29ef
·
1 Parent(s): 678410c
Files changed (2) hide show
  1. app.py +237 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import MllamaForConditionalGeneration, AutoProcessor
6
+ from huggingface_hub import login
7
+ import spaces
8
+ import json
9
+ import matplotlib.pyplot as plt
10
+ import io
11
+ import base64
12
+
13
+
14
+ def check_environment():
15
+ required_vars = ["HF_TOKEN"]
16
+ missing_vars = [var for var in required_vars if var not in os.environ]
17
+
18
+ if missing_vars:
19
+ raise ValueError(
20
+ f"Missing required environment variables: {', '.join(missing_vars)}\n"
21
+ "Please set the HF_TOKEN environment variable with your Hugging Face token"
22
+ )
23
+
24
+
25
+ # Login to Hugging Face
26
+ check_environment()
27
+ login(token=os.environ["HF_TOKEN"], add_to_git_credential=True)
28
+
29
+ # Load model and processor (do this outside the inference function to avoid reloading)
30
+ base_model_path = (
31
+ "taesiri/BugsBunny-LLama-3.2-11B-Vision-BaseCaptioner-XLarge-FullModel"
32
+ )
33
+
34
+ processor = AutoProcessor.from_pretrained(base_model_path)
35
+ model = MllamaForConditionalGeneration.from_pretrained(
36
+ base_model_path,
37
+ torch_dtype=torch.bfloat16,
38
+ device_map="cuda",
39
+ )
40
+ # model = PeftModel.from_pretrained(model, lora_weights_path)
41
+ model.tie_weights()
42
+
43
+
44
+ def describe_image_in_JSON(json_string):
45
+ try:
46
+ # First JSON decode
47
+ first_decode = json.loads(json_string)
48
+
49
+ # Second JSON decode - parse the actual data
50
+ final_data = json.loads(first_decode)
51
+
52
+ return final_data
53
+
54
+ except json.JSONDecodeError as e:
55
+ return f"Error parsing JSON: {str(e)}"
56
+
57
+
58
+ def create_color_palette_image(colors):
59
+ if not colors or not isinstance(colors, list):
60
+ return None
61
+
62
+ try:
63
+ # Validate color format
64
+ for color in colors:
65
+ if not isinstance(color, str) or not color.startswith("#"):
66
+ return None
67
+
68
+ # Create figure and axis
69
+ fig, ax = plt.subplots(figsize=(10, 2))
70
+
71
+ # Create rectangles for each color
72
+ for i, color in enumerate(colors):
73
+ ax.add_patch(plt.Rectangle((i, 0), 1, 1, facecolor=color))
74
+
75
+ # Set the view limits and aspect ratio
76
+ ax.set_xlim(0, len(colors))
77
+ ax.set_ylim(0, 1)
78
+ ax.set_xticks([])
79
+ ax.set_yticks([])
80
+
81
+ return fig # Return the matplotlib figure directly
82
+ except Exception as e:
83
+ print(f"Error creating color palette: {e}")
84
+ return None
85
+
86
+
87
+ @spaces.GPU
88
+ def inference(image):
89
+ if image is None:
90
+ return ["Please provide an image"] * 8
91
+
92
+ if not isinstance(image, Image.Image):
93
+ try:
94
+ image = Image.fromarray(image)
95
+ except Exception as e:
96
+ print(f"Image conversion error: {e}")
97
+ return ["Invalid image format"] * 8
98
+
99
+ # Prepare input
100
+ messages = [
101
+ {
102
+ "role": "user",
103
+ "content": [
104
+ {"type": "image"},
105
+ {"type": "text", "text": "Describe the image in JSON"},
106
+ ],
107
+ }
108
+ ]
109
+ input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
110
+ try:
111
+ # Move inputs to the correct device
112
+ inputs = processor(
113
+ image, input_text, add_special_tokens=False, return_tensors="pt"
114
+ ).to(model.device)
115
+
116
+ # Clear CUDA cache after inference
117
+ with torch.no_grad():
118
+ output = model.generate(**inputs, max_new_tokens=2048)
119
+ if torch.cuda.is_available():
120
+ torch.cuda.empty_cache()
121
+
122
+ except Exception as e:
123
+ print(f"Inference error: {e}")
124
+ return ["Error during inference"] * 8
125
+
126
+ # Decode output
127
+ result = processor.decode(output[0], skip_special_tokens=True)
128
+ print("DEBUG: Full decoded output:", result)
129
+
130
+ try:
131
+ json_str = result.strip().split("assistant\n")[1].strip()
132
+ print("DEBUG: Extracted JSON string after split:", json_str)
133
+ except Exception as e:
134
+ print("DEBUG: Error splitting response:", e)
135
+ return ["Error extracting JSON from response"] * 8 + [
136
+ "Failed to extract JSON",
137
+ "Error",
138
+ ]
139
+
140
+ parsed_json = describe_image_in_JSON(json_str)
141
+ if parsed_json:
142
+ # Create color palette visualization
143
+ colors = parsed_json.get("color_palette", [])
144
+ color_image = create_color_palette_image(colors)
145
+
146
+ # Convert lists to proper format for Gradio JSON components
147
+ character_list = json.dumps(parsed_json.get("character_list", []))
148
+ object_list = json.dumps(parsed_json.get("object_list", []))
149
+ texture_details = json.dumps(parsed_json.get("texture_details", []))
150
+
151
+ return (
152
+ parsed_json.get("description", "Not available"),
153
+ parsed_json.get("scene_description", "Not available"),
154
+ character_list,
155
+ object_list,
156
+ texture_details,
157
+ parsed_json.get("lighting_details", "Not available"),
158
+ color_image,
159
+ json_str,
160
+ "", # Error box
161
+ "Analysis complete", # Status
162
+ )
163
+ return ["Error parsing response"] * 8 + ["Failed to parse JSON", "Error"]
164
+
165
+
166
+ # Update Gradio interface
167
+ with gr.Blocks() as demo:
168
+ gr.Markdown("# BugsBunny-LLama-3.2-11B-Base-XLarge Demo")
169
+
170
+ with gr.Row():
171
+ with gr.Column(scale=1):
172
+ image_input = gr.Image(
173
+ type="pil",
174
+ label="Upload Image",
175
+ elem_id="large-image",
176
+ )
177
+ submit_btn = gr.Button("Analyze Image", variant="primary")
178
+
179
+ with gr.Tabs():
180
+ with gr.Tab("Structured Results"):
181
+ with gr.Column(scale=1):
182
+ description_output = gr.Textbox(
183
+ label="Description",
184
+ lines=4,
185
+ )
186
+ scene_output = gr.Textbox(
187
+ label="Scene Description",
188
+ lines=2,
189
+ )
190
+ characters_output = gr.JSON(
191
+ label="Characters",
192
+ )
193
+ objects_output = gr.JSON(
194
+ label="Objects",
195
+ )
196
+ textures_output = gr.JSON(
197
+ label="Texture Details",
198
+ )
199
+ lighting_output = gr.Textbox(
200
+ label="Lighting Details",
201
+ lines=2,
202
+ )
203
+ color_palette_output = gr.Plot(
204
+ label="Color Palette",
205
+ )
206
+
207
+ with gr.Tab("Raw Output"):
208
+ raw_output = gr.Textbox(
209
+ label="Raw JSON Response",
210
+ lines=25,
211
+ max_lines=30,
212
+ )
213
+
214
+ error_box = gr.Textbox(label="Error Messages", visible=False)
215
+
216
+ with gr.Row():
217
+ status_text = gr.Textbox(label="Status", value="Ready", interactive=False)
218
+
219
+ submit_btn.click(
220
+ fn=inference,
221
+ inputs=[image_input],
222
+ outputs=[
223
+ description_output,
224
+ scene_output,
225
+ characters_output,
226
+ objects_output,
227
+ textures_output,
228
+ lighting_output,
229
+ color_palette_output,
230
+ raw_output,
231
+ error_box,
232
+ status_text,
233
+ ],
234
+ api_name="analyze",
235
+ )
236
+
237
+ demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ datasets
4
+ git+https://github.com/huggingface/transformers.git
5
+ accelerate
6
+ pillow
7
+ gradio
8
+ matplotlib