Update app.py
app.py
CHANGED
@@ -1,295 +1,368 @@
-import gradio as gr
-import cv2
-import time
 import openai
-import base64
-import pytz
-import uuid
-from threading import Thread
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from datetime import datetime
-import json
 import os
-
-from moviepy.editor import ImageSequenceClip
-import ffmpeg
-import
 
 api_key = os.getenv("OPEN_AI_KEY")
-user_name = os.getenv("USER_NAME")
-password = os.getenv("PASSWORD")
 
-
-
-
-
-
 """
-AVATARS = (
-    "https://assets-global.website-files.com/63d6dca820934a77a340f31e/63dfb7a21b4c08282d524010_pyramid.png",
-    "https://media.roboflow.com/spaces/openai-white-logomark.png"
-)
-
-# Set your OpenAI API key
-openai.api_key = api_key
-MODEL = "gpt-4o"
-client = openai.OpenAI(api_key=api_key)
-
-# Global variable to stop the video capture loop
-stop_capture = False
-alerts_mode = True
-
-def clip_video_segment_2(input_video_path, start_time, duration):
-    os.makedirs('videos', exist_ok=True)
-    output_video_path = f"videos/{uuid.uuid4()}.mp4"
-
-    # Use ffmpeg-python to clip the video
-    try:
-        (
-            ffmpeg
-            .input(input_video_path, ss=start_time)  # Seek to start_time
-            .output(output_video_path, t=duration, c='copy')  # Set the duration
-            .run(overwrite_output=True)
-        )
-        print('input_video_path', input_video_path, output_video_path)
-        return output_video_path
-    except ffmpeg.Error as e:
-        print(f"Error clipping video: {e}")
-        return None
 
-
-
-
 
-  [old lines 64-73 are not rendered in this diff view]
-    video_clip_path = f"videos/{uuid.uuid4()}.mp4"
-
-    # Get frame size
-    height, width, layers = frames[0].shape
-    size = (width, height)
-
-    # Define the codec and create VideoWriter object
-    fourcc = cv2.VideoWriter_fourcc(*'h264')  # You can also try 'XVID', 'MJPG', etc.
-    out = cv2.VideoWriter(video_clip_path, fourcc, fps, size)
-
-    for frame in frames:
-        out.write(frame)
-
-    out.release()
-
-    return video_clip_path
-
-
-def encode_to_video(frames, fps):
-    os.makedirs('videos', exist_ok=True)
-    video_clip_path = f"videos/{uuid.uuid4()}.mp4"
-
-    # Create a video clip from the frames using moviepy
-    clip = ImageSequenceClip([frame[:, :, ::-1] for frame in frames], fps=fps)  # Convert from BGR to RGB
-    clip.write_videofile(video_clip_path, codec="libx264")
-
-    # Convert the video file to base64
-    with open(video_clip_path, "rb") as video_file:
-        video_data = base64.b64encode(video_file.read()).decode('utf-8')
-
-    return video_clip_path
-
-# Function to process video frames using GPT-4 API
-def process_frames(frames, frames_to_skip=1):
-    os.makedirs('saved_frames', exist_ok=True)
-    curr_frame = 0
-    base64Frames = []
-    while curr_frame < len(frames) - 1:
-        _, buffer = cv2.imencode(".jpg", frames[curr_frame])
-        base64Frames.append(base64.b64encode(buffer).decode("utf-8"))
-        curr_frame += frames_to_skip
-    return base64Frames
-
-# Function to check condition using GPT-4 API
-def check_condition(prompt, base64Frames):
-    start_time = time.time()
-    print('checking condition for frames:', len(base64Frames))
-
-    # Save frames as images
-
-
-    messages = [
-        {"role": "system", "content": """You are analyzing video to check if the user's condition is met.
-Please respond with a JSON object in the following format:
-{"condition_met": true/false, "details": "optional details or summary. in the summary DON'T mention the words: image, images, frame, or frames. Instead, make it look like you were provided with video input and avoid referring to individual images or frames explicitly."}"""},
-        {"role": "user", "content": [prompt, *map(lambda x: {"type": "image_url", "image_url": {"url": f'data:image/jpg;base64,{x}', "detail": "low"}}, base64Frames)]}
-    ]
 
     response = client.chat.completions.create(
         model="gpt-4o",
         messages=messages,
-
-
     )
 
-
-
-
-
     try:
-  [old lines 144-262 are not rendered in this diff view]
-        yield result
-    return chatbot
-
-
-# Function to stop video capture
-def stop_capture_func():
-    global stop_capture
-    stop_capture = True
-
-# Gradio interface
-with gr.Blocks(title="Conntour", fill_height=True) as demo:
-    with gr.Tab("Analyze"):
-        with gr.Row():
-            video = gr.Video(label="Video Source")
-            with gr.Column():
-                chatbot = gr.Chatbot(label="Events", bubble_full_width=False, avatar_images=AVATARS)
-                prompt = gr.Textbox(label="Enter your prompt alert")
-                start_btn = gr.Button("Start")
-                stop_btn = gr.Button("Stop")
-        start_btn.click(analyze_video_file, inputs=[prompt, video, chatbot], outputs=[chatbot], queue=True)
-        stop_btn.click(stop_capture_func)
-    with gr.Tab("Alerts"):
-        with gr.Row():
-            stream = gr.Textbox(label="Video Source", value="https://streamapi2.eu.loclx.io/video_feed/101 OR rtsp://admin:[email protected]:5678/Streaming/Channels/101")
-            with gr.Column():
-                chatbot = gr.Chatbot(label="Events", bubble_full_width=False, avatar_images=AVATARS)
-                prompt = gr.Textbox(label="Enter your prompt alert")
-                start_btn = gr.Button("Start")
-                stop_btn = gr.Button("Stop")
-        start_btn.click(analyze_stream, inputs=[prompt, stream, chatbot], outputs=[chatbot], queue=True)
-        stop_btn.click(stop_capture_func)
-
-demo.launch(favicon_path='favicon.ico', auth=(user_name, password))
 
 import openai
 import os
+import gradio as gr
+from enum import Enum
+from dataclasses import dataclass, asdict, field
+from typing import List, Optional, Union, Dict, Any
+import json
 
 api_key = os.getenv("OPEN_AI_KEY")
 
+# Define the COCOClass enum
+class COCOClass(Enum):
+    person = 0
+    bicycle = 1
+    car = 2
+    motorcycle = 3
+    airplane = 4
+    bus = 5
+    train = 6
+    truck = 7
+    boat = 8
+    traffic_light = 9
+    fire_hydrant = 10
+    stop_sign = 11
+    parking_meter = 12
+    bench = 13
+    bird = 14
+    cat = 15
+    dog = 16
+    horse = 17
+    sheep = 18
+    cow = 19
+    elephant = 20
+    bear = 21
+    zebra = 22
+    giraffe = 23
+    backpack = 24
+    umbrella = 25
+    handbag = 26
+    tie = 27
+    suitcase = 28
+    frisbee = 29
+    skis = 30
+    snowboard = 31
+    sports_ball = 32
+    kite = 33
+    baseball_bat = 34
+    baseball_glove = 35
+    skateboard = 36
+    surfboard = 37
+    tennis_racket = 38
+    bottle = 39
+    wine_glass = 40
+    cup = 41
+    fork = 42
+    knife = 43
+    spoon = 44
+    bowl = 45
+    banana = 46
+    apple = 47
+    sandwich = 48
+    orange = 49
+    broccoli = 50
+    carrot = 51
+    hot_dog = 52
+    pizza = 53
+    donut = 54
+    cake = 55
+    chair = 56
+    couch = 57
+    potted_plant = 58
+    bed = 59
+    dining_table = 60
+    toilet = 61
+    tv = 62
+    laptop = 63
+    mouse = 64
+    remote = 65
+    keyboard = 66
+    cell_phone = 67
+    microwave = 68
+    oven = 69
+    toaster = 70
+    sink = 71
+    refrigerator = 72
+    book = 73
+    clock = 74
+    vase = 75
+    scissors = 76
+    teddy_bear = 77
+    hair_drier = 78
+    toothbrush = 79
+
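
A quick note on how this enum gets used downstream: parse_investigation (below) indexes it by name, so the mapping works in both directions. A minimal sketch:

# Sketch: COCOClass supports name lookup (what parse_investigation relies on)
# as well as value round-tripping.
assert COCOClass["car"] is COCOClass.car   # index by name
assert COCOClass.car.value == 2            # COCO class id
assert COCOClass(2) is COCOClass.car       # construct from id
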
+# Define data classes
+@dataclass
+class VehicleProps:
+    brand: Optional[str] = None
+    type: Optional[COCOClass] = None  # Should be a vehicle class
+    plate: Optional[str] = None
+
+@dataclass
+class PersonProps:
+    face_images: Optional[List[str]] = field(default_factory=list)
+    age: Optional[int] = None
+    race: Optional[str] = None  # Should be one of the specified races
+    gender: Optional[str] = None  # Male or Female
+    top_color: Optional[str] = None  # Changed from shirt_color
+    bottom_color: Optional[str] = None
+
+@dataclass
+class Activity:
+    prompt: Optional[str] = None
+    type: Optional[str] = None  # "full_screen" or "square"
+
+@dataclass
+class Investigation:
+    target: COCOClass
+    images: List[str]
+    activity: Optional[Activity] = None
+    complex_appearance: Optional[str] = None
+    props: Optional[Union[VehicleProps, PersonProps]] = None
+    primary_color: Optional[str] = None
+    secondary_color: Optional[str] = None
+
+# Default system message (moved to a global variable)
+DEFAULT_SYSTEM_MESSAGE = f"""
+You are a helpful assistant that extracts structured information from text descriptions.
+
+Your task is to parse the following text prompt and extract information to populate an Investigation JSON object as per the definitions provided.
+
+Definitions:
+
+Investigation:
+{{
+    "target": A COCO class name (from the COCOClass enum),
+    "images": List of image URLs,
+    "activity": {{
+        "prompt": A description of an activity, e.g., "crossing the street", "crossing red light", "holding a gun",
+        "type": Either "full_screen" or "square"
+            - "full_screen": When the activity requires the full scene for context (e.g., "seeing a movie").
+            - "square": When the activity context can be understood from a close-up image (e.g., "holding a cat").
+    }},
+    "complex_appearance": Description of appearance details that do not fit into other fields, e.g., "has a hat with Nike logo" or "Tattoo on left arm",
+    "props": Either VehicleProps or PersonProps (only if the target is vehicle or person),
+    "primary_color": Primary color mentioned in the prompt,
+    "secondary_color": Secondary color mentioned in the prompt
+}}
+
+VehicleProps:
+{{
+    "brand": Vehicle brand, e.g., "Mercedes",
+    "type": COCO class name of vehicles (e.g., "truck"),
+    "plate": License plate number, e.g., "123AB"
+}}
+
+PersonProps:
+{{
+    "face_images": List of face image URLs,
+    "age": Age as a number,
+    "race": Race or ethnicity (one of: asian, white, middle eastern, indian, latino, black),
+    "gender": Gender (Male or Female),
+    "top_color": Color of the top garment (e.g., shirt, blouse),  # Changed from shirt_color
+    "bottom_color": Color of the bottom garment (pants, skirt, etc.)
+}}
+
+COCOClass Enum:
+{{
+    {', '.join([f'"{member.name}"' for member in COCOClass])}
+}}
+
+Important Notes:
+
+1. The output JSON should be as minimal as possible. Do not include fields like 'primary_color' or 'secondary_color' if they are not mentioned in the prompt.
+
+2. Be especially careful with the 'activity' and 'complex_appearance' fields. Use them only if the prompt has data that does not fit elsewhere in the JSON. For example:
+    - "a guy with red shirt" -> Map 'red shirt' to 'top_color' in PersonProps.
+    - "a guy with a black hat" -> Since there isn't any field for 'hat', include "black hat" in 'complex_appearance'.
+
+3. Avoid using 'complex_appearance' and 'activity' fields unless absolutely necessary.
+
+4. Do not include undefined fields or fields not mentioned in the prompt.
+
+5. Use the COCOClass enum for the target class name.
+
+Now, process the following prompt:
+
+'''prompt_text'''
+
+Provide the Investigation JSON object, including only the relevant fields based on the prompt. Do not include any explanations.
 """
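
For a concrete sense of the format this prompt asks for, here is a minimal hand-written illustration (not captured model output) of what the model should return for the first default prompt used below:

# Illustration only: expected shape of the model's JSON for
# "A red sports car with a license plate reading 'FAST123'."
expected = {
    "target": "car",
    "props": {"plate": "FAST123"},  # VehicleProps, kept minimal per note 1
    "primary_color": "red",
}
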
 
+# Function to process the prompt
+def process_prompt(prompt_text: str, images: List[str], face_images: List[str],
+                   system_message: Optional[str] = None, user_message: Optional[str] = None,
+                   temperature: float = 0.0) -> Optional[Dict[str, Any]]:
+    client = openai.OpenAI(api_key=api_key)
 
+    # Fall back to the default messages when none are provided
+    if not system_message:
+        system_message = DEFAULT_SYSTEM_MESSAGE
+    if not user_message:
+        user_message = ""
+
+    # Prepare messages for the API
+    messages = []
+    if system_message.strip():
+        messages.append({"role": "system", "content": system_message.replace("prompt_text", prompt_text)})
+    if user_message.strip():
+        messages.append({"role": "user", "content": user_message.replace("prompt_text", prompt_text)})
 
     response = client.chat.completions.create(
         model="gpt-4o",
         messages=messages,
+        response_format={"type": "json_object"},
+        temperature=temperature,
+        max_tokens=1000,
     )
 
+    # Extract the content
+    content = response.choices[0].message.content
+
+    # Parse the JSON output
+    try:
+        investigation_data = json.loads(content)
+    except json.JSONDecodeError as e:
+        print("Error parsing JSON:", e)
+        print("OpenAI response:", content)
+        return None
+
+    # Construct the Investigation object
+    investigation = parse_investigation(investigation_data, images, face_images)
+
+    # Convert the Investigation object to a dictionary
+    if investigation:
+        investigation_dict = asdict(investigation)
+        # Convert enums to their names
+        investigation_dict['target'] = investigation.target.name
+        if investigation.props:
+            if isinstance(investigation.props, VehicleProps) and investigation.props.type:
+                investigation_dict['props']['type'] = investigation.props.type.name
+            elif isinstance(investigation.props, PersonProps):
+                pass  # No enums in PersonProps
+        return investigation_dict
+    else:
+        return None
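
As a rough usage sketch (assuming OPEN_AI_KEY is set in the environment), process_prompt can be exercised directly:

# Sketch: one-off call with the default system message; the image URLs are
# placeholders, mirroring the ones hard-coded in gradio_app below.
result = process_prompt(
    prompt_text="A red sports car with a license plate reading 'FAST123'.",
    images=["http://example.com/image1.jpg"],
    face_images=[],
)
if result:
    print(json.dumps(result, indent=4))
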
+
+# Function to parse the Investigation data
+def parse_investigation(data: Dict[str, Any], images: List[str], face_images: List[str]) -> Optional[Investigation]:
+    # Parse target
+    target_name = data.get('target')
     try:
+        target_enum = COCOClass[target_name]
+    except KeyError:
+        print(f"Invalid COCO class name: {target_name}")
+        return None
+
+    # Parse activity
+    activity_data = data.get('activity')
+    if activity_data:
+        activity = Activity(
+            prompt=activity_data.get('prompt'),
+            type=activity_data.get('type')
+        )
+    else:
+        activity = None
+
+    # Parse props
+    props_data = data.get('props')
+    props = None
+    if props_data:
+        if 'face_images' in props_data:
+            # PersonProps
+            props = PersonProps(
+                face_images=face_images,
+                age=props_data.get('age'),
+                race=props_data.get('race'),
+                gender=props_data.get('gender'),
+                top_color=props_data.get('top_color'),  # Changed from shirt_color
+                bottom_color=props_data.get('bottom_color')
+            )
+        elif 'brand' in props_data:
+            # VehicleProps
+            vehicle_type_name = props_data.get('type')
+            if vehicle_type_name:
+                try:
+                    vehicle_type_enum = COCOClass[vehicle_type_name]
+                except KeyError:
+                    print(f"Invalid vehicle type: {vehicle_type_name}")
+                    vehicle_type_enum = None
+            else:
+                vehicle_type_enum = None
+
+            props = VehicleProps(
+                brand=props_data.get('brand'),
+                type=vehicle_type_enum,
+                plate=props_data.get('plate')
+            )
+
+    # Construct the Investigation object
+    investigation = Investigation(
+        target=target_enum,
+        images=images,
+        activity=activity,
+        complex_appearance=data.get('complex_appearance'),
+        props=props,
+        primary_color=data.get('primary_color'),
+        secondary_color=data.get('secondary_color')
+    )
+
+    return investigation
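
Because parse_investigation is pure Python, it can be sanity-checked offline with a hand-written dict (no API call); note that the person/vehicle branch is keyed on the presence of 'face_images' or 'brand':

# Sketch: offline round-trip through parse_investigation. The 'brand' key is
# what routes this dict into the VehicleProps branch.
sample = {
    "target": "car",
    "props": {"brand": "Ferrari", "type": "car", "plate": "FAST123"},
    "primary_color": "red",
}
inv = parse_investigation(sample, images=["http://example.com/image1.jpg"], face_images=[])
assert inv is not None and inv.target is COCOClass.car
assert isinstance(inv.props, VehicleProps) and inv.props.type is COCOClass.car
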
+
+# Gradio app
+def gradio_app(prompts_text, system_message, user_message, temperature):
+    # Split prompts by commas and strip whitespace
+    prompts = [p.strip() for p in prompts_text.split(',') if p.strip()]
+    images = ["http://example.com/image1.jpg", "http://example.com/image2.jpg"]
+    face_images = ["http://example.com/face1.jpg"]
+
+    results = []
+    for p in prompts:
+        investigation_dict = process_prompt(
+            prompt_text=p,
+            images=images,
+            face_images=face_images,
+            system_message=system_message if system_message else None,
+            user_message=user_message if user_message else None,
+            temperature=temperature if temperature else 0.0
+        )
+        if investigation_dict:
+            results.append(json.dumps(investigation_dict, indent=4))
+        else:
+            results.append("Failed to process prompt.")
+
+    return "\n\n".join(results)
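
The whole pipeline can also be driven without the UI; an empty system message falls back to DEFAULT_SYSTEM_MESSAGE inside process_prompt:

# Sketch: two comma-separated prompts, processed sequentially; prints one
# pretty-printed JSON object per prompt.
print(gradio_app(
    prompts_text="A red sports car, A woman in a blue swimsuit",
    system_message="",
    user_message="",
    temperature=0.0,
))
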
+
+if __name__ == "__main__":
+    # Default values
+    default_prompts = ", ".join([
+        "A red sports car with a license plate reading 'FAST123'.",
+        "An elderly woman wearing a green dress and a pearl necklace.",
+        "A cyclist in a yellow jersey riding a blue bicycle.",
+        "A group of people playing frisbee in the park.",
+        "A man with a large tattoo of a dragon on his right arm.",
+        "A black and white cat sitting on a red couch.",
+        "A delivery truck with the 'FedEx' logo on the side.",
+        "A child holding a red balloon shaped like a dog.",
+        "A person wearing a hoodie with the text 'OpenAI' on it.",
+        "A woman in a blue swimsuit swimming in the ocean."
+    ])
+
+    default_system_message = DEFAULT_SYSTEM_MESSAGE.replace("{{prompt_text}}", "{prompt_text}")  # Prepare for formatting
+    default_user_message = ""  # Optional user message
+    default_temperature = 0.0  # Default temperature
+
+    # Create Gradio interface
+    iface = gr.Interface(
+        fn=gradio_app,
+        inputs=[
+            gr.Textbox(lines=5, label="List of Prompts (comma-separated)", value=default_prompts),
+            gr.Textbox(lines=20, label="System Message (optional)", value=default_system_message),
+            gr.Textbox(lines=5, label="User Message (optional)", value=default_user_message),
+            gr.Slider(minimum=0, maximum=1, step=0.1, label="Temperature", value=default_temperature)
+        ],
+        outputs="text",
+        title="OpenAI Prompt Engineering Tester",
+        description="Test different prompts and messages with the OpenAI API."
+    )
+
+    # Launch the app
+    iface.launch()