|
import openai |
|
import os |
|
import gradio as gr |
|
from enum import Enum |
|
from dataclasses import dataclass, asdict, field |
|
from typing import List, Optional, Union, Dict, Any |
|
import json |
|
|
|
# OpenAI API key, read once at import time from the OPEN_AI_KEY environment
# variable; None when the variable is not set.
api_key = os.environ.get("OPEN_AI_KEY")
|
|
|
|
|
class COCOClass(Enum):
    """Object classes an investigation can target.

    Members are looked up by name (``COCOClass["car"]``) and carry
    consecutive integer values from 0 to 79.
    """

    # People and vehicles
    person = 0
    bicycle = 1
    car = 2
    motorcycle = 3
    airplane = 4
    bus = 5
    train = 6
    truck = 7
    boat = 8

    # Street objects
    traffic_light = 9
    fire_hydrant = 10
    stop_sign = 11
    parking_meter = 12
    bench = 13

    # Animals
    bird = 14
    cat = 15
    dog = 16
    horse = 17
    sheep = 18
    cow = 19
    elephant = 20
    bear = 21
    zebra = 22
    giraffe = 23

    # Carried items and accessories
    backpack = 24
    umbrella = 25
    handbag = 26
    tie = 27
    suitcase = 28

    # Sports equipment
    frisbee = 29
    skis = 30
    snowboard = 31
    sports_ball = 32
    kite = 33
    baseball_bat = 34
    baseball_glove = 35
    skateboard = 36
    surfboard = 37
    tennis_racket = 38

    # Tableware
    bottle = 39
    wine_glass = 40
    cup = 41
    fork = 42
    knife = 43
    spoon = 44
    bowl = 45

    # Food
    banana = 46
    apple = 47
    sandwich = 48
    orange = 49
    broccoli = 50
    carrot = 51
    hot_dog = 52
    pizza = 53
    donut = 54
    cake = 55

    # Furniture and fixtures
    chair = 56
    couch = 57
    potted_plant = 58
    bed = 59
    dining_table = 60
    toilet = 61

    # Electronics
    tv = 62
    laptop = 63
    mouse = 64
    remote = 65
    keyboard = 66
    cell_phone = 67

    # Appliances
    microwave = 68
    oven = 69
    toaster = 70
    sink = 71
    refrigerator = 72

    # Miscellaneous household items
    book = 73
    clock = 74
    vase = 75
    scissors = 76
    teddy_bear = 77
    hair_drier = 78
    toothbrush = 79
|
|
|
|
|
@dataclass
class VehicleProps:
    """Optional attributes describing a vehicle target; every field defaults to None."""

    # Manufacturer name, e.g. "Mercedes".
    brand: Optional[str] = None
    # Vehicle kind expressed as a COCOClass member, e.g. COCOClass.truck.
    type: Optional[COCOClass] = None
    # License plate text, e.g. "123AB".
    plate: Optional[str] = None
|
|
|
@dataclass
class PersonProps:
    """Optional attributes describing a person target."""

    # Face image URLs; each instance gets its own fresh empty list by default.
    face_images: Optional[List[str]] = field(default_factory=list)
    # Age as a number.
    age: Optional[int] = None
    # Race or ethnicity label.
    race: Optional[str] = None
    # Gender label.
    gender: Optional[str] = None
    # Colour of the top garment (shirt, blouse, ...).
    top_color: Optional[str] = None
    # Colour of the bottom garment (pants, skirt, ...).
    bottom_color: Optional[str] = None
|
|
|
@dataclass
class Activity:
    """An activity the target is performing; both fields default to None."""

    # Free-text description of the activity, e.g. "crossing the street".
    prompt: Optional[str] = None
    # How much of the frame is needed for context: "full_screen" or "square".
    type: Optional[str] = None
|
|
|
@dataclass
class Investigation:
    """Structured description of what to look for, built from a text prompt.

    ``target`` and ``images`` are required; every other field defaults to None.
    """

    # Class of the object being searched for.
    target: COCOClass
    # Image URLs associated with the investigation.
    images: List[str]
    # What the target is doing, if the prompt describes an activity.
    activity: Optional[Activity] = None
    # Appearance details that fit no other field, e.g. "Tattoo on left arm".
    complex_appearance: Optional[str] = None
    # Vehicle- or person-specific attributes.
    props: Optional[Union[VehicleProps, PersonProps]] = None
    # Primary colour mentioned in the prompt.
    primary_color: Optional[str] = None
    # Secondary colour mentioned in the prompt.
    secondary_color: Optional[str] = None
|
|
|
|
|
# Comma-separated, double-quoted COCOClass member names, embedded in the system
# message so the model knows which target names are valid.
_COCO_CLASS_NAMES = ", ".join(f'"{member.name}"' for member in COCOClass)

# Default system message sent to the model.
#
# This must be an f-string: the doubled braces render as literal JSON braces and
# the COCO class list is interpolated. The bare token prompt_text (inside the
# triple single quotes near the end) is deliberately left as-is; process_prompt
# substitutes the user's prompt for it with str.replace at call time.
DEFAULT_SYSTEM_MESSAGE = f"""
You are a helpful assistant that extracts structured information from text descriptions.

Your task is to parse the following text prompt and extract information to populate an Investigation JSON object as per the definitions provided.

Definitions:

Investigation:
{{
    "target": A COCO class name (from the COCOClass enum),
    "images": List of image URLs,
    "activity": {{
        "prompt": A description of an activity, e.g., "crossing the street", "crossing red light", "holding a gun",
        "type": Either "full_screen" or "square"
            - "full_screen": When the activity requires the full scene for context (e.g., "seeing a movie").
            - "square": When the activity context can be understood from a close-up image (e.g., "holding a cat").
    }},
    "complex_appearance": Description of appearance details that do not fit into other fields, e.g., "has a hat with Nike logo" or "Tattoo on left arm",
    "props": Either VehicleProps or PersonProps (only if the target is vehicle or person),
    "primary_color": Primary color mentioned in the prompt,
    "secondary_color": Secondary color mentioned in the prompt
}}

VehicleProps:
{{
    "brand": Vehicle brand, e.g., "Mercedes",
    "type": COCO class name of vehicles (e.g., "truck"),
    "plate": License plate number, e.g., "123AB"
}}

PersonProps:
{{
    "face_images": List of face image URLs,
    "age": Age as a number,
    "race": Race or ethnicity (one of: asian, white, middle eastern, indian, latino, black),
    "gender": Gender (Male or Female),
    "top_color": Color of the top garment (e.g., shirt, blouse),
    "bottom_color": Color of the bottom garment (pants, skirt, etc.)
}}

COCOClass Enum:
{{
    {_COCO_CLASS_NAMES}
}}

Important Notes:

1. The output JSON should be as minimal as possible. Do not include fields like 'primary_color' or 'secondary_color' if they are not mentioned in the prompt.

2. Be especially careful with 'activity' and 'complex_appearance' fields. Use them only if the prompt has data that does not fit elsewhere in the JSON. For example:
    - "a guy with red shirt" -> Map 'red shirt' to 'top_color' in PersonProps.
    - "a guy with a black hat" -> Since there isn't any field for 'hat', include "black hat" in 'complex_appearance'.

3. Avoid using 'complex_appearance' and 'activity' fields unless absolutely necessary.

4. Do not include undefined fields or fields not mentioned in the prompt.

5. Use the COCOClass enum for the target class name.

Now, process the following prompt:

'''prompt_text'''

Provide the Investigation JSON object, including only the relevant fields based on the prompt. Do not include any explanations.
"""
|
|
|
|
|
def process_prompt(prompt_text: str, images: List[str], face_images: List[str],
                   system_message: Optional[str] = None, user_message: Optional[str] = None,
                   temperature: float = 0.0) -> Optional[Dict[str, Any]]:
    """Ask the model to turn *prompt_text* into an Investigation and return it as a dict.

    Every occurrence of the literal token ``prompt_text`` in *system_message*
    and *user_message* is replaced with the actual prompt before sending.

    Args:
        prompt_text: Free-text description to analyse.
        images: Image URLs stored on the resulting investigation.
        face_images: Face image URLs stored on person props, if any are built.
        system_message: System message template; None or blank means no
            system message is sent.
        user_message: User message template; None or blank means no user
            message is sent.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The investigation as a JSON-serialisable dict (enum members replaced by
        their names), or None when there is nothing to send, the reply is not a
        JSON object, or the reply cannot be turned into an Investigation.
    """
    # Treat a missing template exactly like an empty one. Previously only
    # user_message was normalised, so the default system_message=None crashed
    # on .strip().
    system_message = system_message or ""
    user_message = user_message or ""

    messages = []
    if system_message.strip():
        messages.append({"role": "system", "content": system_message.replace("prompt_text", prompt_text)})
    if user_message.strip():
        messages.append({"role": "user", "content": user_message.replace("prompt_text", prompt_text)})

    if not messages:
        # An empty message list would only be rejected by the API; fail locally instead.
        print("No system or user message provided; nothing to send.")
        return None

    client = openai.OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=messages,
        response_format={"type": "json_object"},
        temperature=temperature,
        max_tokens=1000,
    )
    content = response.choices[0].message.content

    try:
        investigation_data = json.loads(content)
    except (json.JSONDecodeError, TypeError) as e:
        # TypeError covers a reply whose content is None rather than a string.
        print("Error parsing JSON:", e)
        print("OpenAI response:", content)
        return None

    if not isinstance(investigation_data, dict):
        print("Error parsing JSON: expected a JSON object")
        print("OpenAI response:", content)
        return None

    investigation = parse_investigation(investigation_data, images, face_images)
    if investigation is None:
        return None

    investigation_dict = asdict(investigation)

    # asdict leaves enum members in place; swap them for their names so the
    # result can be passed to json.dumps.
    investigation_dict['target'] = investigation.target.name
    if isinstance(investigation.props, VehicleProps) and investigation.props.type is not None:
        investigation_dict['props']['type'] = investigation.props.type.name

    return investigation_dict
|
|
|
|
|
# Keys that identify a props object as describing a person or a vehicle.
_PERSON_PROP_KEYS = ('face_images', 'age', 'race', 'gender', 'top_color', 'bottom_color')
_VEHICLE_PROP_KEYS = ('brand', 'type', 'plate')


def _lookup_coco_class(name: Any) -> Optional[COCOClass]:
    """Return the COCOClass member named *name*, or None if there is no such member."""
    if not isinstance(name, str):
        return None
    # Member names are lowercase with underscores, so this normalisation never
    # changes a name that was already valid; it only rescues variants such as
    # "Traffic Light".
    normalized = name.strip().lower().replace(' ', '_')
    return COCOClass.__members__.get(normalized)


def _build_props(props_data: Any, face_images: List[str]) -> Optional[Union[VehicleProps, PersonProps]]:
    """Build PersonProps or VehicleProps from the model's ``props`` object, or None."""
    if not isinstance(props_data, dict) or not props_data:
        return None

    has_person_keys = any(key in props_data for key in _PERSON_PROP_KEYS)
    has_vehicle_keys = any(key in props_data for key in _VEHICLE_PROP_KEYS)

    # 'face_images' still wins outright and 'brand' still selects a vehicle
    # when 'face_images' is absent. Beyond that, any person or vehicle key is
    # enough, because the model is told to omit fields the prompt does not
    # mention.
    if 'face_images' in props_data or (has_person_keys and 'brand' not in props_data):
        return PersonProps(
            # The caller-supplied URLs are used, not whatever the model returned.
            face_images=face_images,
            age=props_data.get('age'),
            race=props_data.get('race'),
            gender=props_data.get('gender'),
            top_color=props_data.get('top_color'),
            bottom_color=props_data.get('bottom_color'),
        )

    if has_vehicle_keys:
        vehicle_type_name = props_data.get('type')
        vehicle_type_enum = None
        if vehicle_type_name:
            vehicle_type_enum = _lookup_coco_class(vehicle_type_name)
            if vehicle_type_enum is None:
                print(f"Invalid vehicle type: {vehicle_type_name}")
        return VehicleProps(
            brand=props_data.get('brand'),
            type=vehicle_type_enum,
            plate=props_data.get('plate'),
        )

    return None


def parse_investigation(data: Dict[str, Any], images: List[str], face_images: List[str]) -> Optional[Investigation]:
    """Convert the model's decoded JSON reply into an Investigation.

    Args:
        data: Decoded JSON object returned by the model.
        images: Image URLs to store on the investigation.
        face_images: Face image URLs to store on PersonProps when the props
            describe a person.

    Returns:
        The populated Investigation, or None when *data* is not a dict or its
        ``target`` is not a COCOClass member name.
    """
    if not isinstance(data, dict):
        print(f"Invalid investigation data: {data!r}")
        return None

    target_name = data.get('target')
    target_enum = _lookup_coco_class(target_name)
    if target_enum is None:
        print(f"Invalid COCO class name: {target_name}")
        return None

    activity_data = data.get('activity')
    activity = None
    # Anything other than a non-empty JSON object is treated as "no activity".
    if isinstance(activity_data, dict) and activity_data:
        activity = Activity(
            prompt=activity_data.get('prompt'),
            type=activity_data.get('type'),
        )

    return Investigation(
        target=target_enum,
        images=images,
        activity=activity,
        complex_appearance=data.get('complex_appearance'),
        props=_build_props(data.get('props'), face_images),
        primary_color=data.get('primary_color'),
        secondary_color=data.get('secondary_color'),
    )
|
|
|
|
|
def gradio_app(prompts_text, system_message, user_message, temperature):
    """Run every prompt in *prompts_text* through process_prompt and format the results.

    Args:
        prompts_text: Prompts to process, one per line; blank lines are skipped.
        system_message: System message template (may be empty).
        user_message: User message template (may be empty).
        temperature: Sampling temperature; a falsy value is sent as 0.0.

    Returns:
        A single string containing, for each prompt, the prompt itself followed
        by the investigation as indented JSON or "Failed to process prompt.".
    """
    prompts = [line.strip() for line in prompts_text.split('\n') if line.strip()]

    # Placeholder URLs; the tester has no real images to attach.
    images = ["http://example.com/image1.jpg", "http://example.com/image2.jpg"]
    face_images = ["http://example.com/face1.jpg"]

    results = []
    for p in prompts:
        investigation_dict = process_prompt(
            prompt_text=p,
            images=images,
            face_images=face_images,
            # Pass empty strings rather than None: process_prompt calls
            # .strip() on the system message, so None crashed when the
            # textbox was cleared.
            system_message=system_message or "",
            user_message=user_message or "",
            temperature=temperature or 0.0,
        )
        results.append(f'{p}\n')
        if investigation_dict:
            results.append(json.dumps(investigation_dict, indent=4))
        else:
            results.append("Failed to process prompt.")

    return "\n\n".join(results)
|
|
|
if __name__ == "__main__":
    # Sample prompts pre-filled in the UI, one per line (gradio_app splits on newlines).
    default_prompts = "\n".join([
        "A red sports car with a license plate reading 'FAST123'.",
        "An elderly woman wearing a green dress and a pearl necklace.",
        "A cyclist in a yellow jersey riding a blue bicycle.",
        "A group of people playing frisbee in the park.",
        "A man with a large tattoo of a dragon on his right arm.",
        "A black and white cat sitting on a red couch.",
        "A delivery truck with the 'FedEx' logo on the side.",
        "A child holding a red balloon shaped like a dog.",
        "A person wearing a hoodie with the text 'OpenAI' on it.",
        "A woman in a blue swimsuit swimming in the ocean."
    ])

    # The previous .replace("{{prompt_text}}", "{prompt_text}") never matched
    # anything in the message, so the default is used unchanged.
    default_system_message = DEFAULT_SYSTEM_MESSAGE
    default_user_message = ""
    default_temperature = 0.0

    iface = gr.Interface(
        fn=gradio_app,
        inputs=[
            # The label now says "one per line" to match how gradio_app splits the text.
            gr.Textbox(lines=5, label="List of Prompts (one per line)", value=default_prompts),
            gr.Textbox(lines=20, label="System Message (optional)", value=default_system_message),
            gr.Textbox(lines=5, label="User Message (optional)", value=default_user_message),
            gr.Slider(minimum=0, maximum=1, step=0.1, label="Temperature", value=default_temperature)
        ],
        outputs="text",
        title="OpenAI Prompt Engineering Tester",
        description="Test different prompts and messages with the OpenAI API."
    )

    iface.launch()
|
|