Spaces:
Running
Running
import json | |
import gradio as gr | |
from collections import UserList | |
from flow import full_flow | |
def _weather_request_schema():
    """Return a fresh JSON-schema object describing a forecast request."""
    return {
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The name of the location for which the weather forecast is requested.",
            },
            "date": {
                "type": "string",
                "format": "date",
                "description": "The date for which the weather forecast is requested, in YYYY-MM-DD format.",
            },
        },
        "required": ["location", "date"],
    }


def _weather_response_schema():
    """Return a fresh JSON-schema object describing a forecast response."""
    return {
        "type": "object",
        "properties": {
            "temperature": {
                "type": "number",
                "description": "The forecasted temperature in degrees Celsius.",
            },
            "condition": {
                "type": "string",
                "description": "A brief description of the weather condition (e.g., sunny, cloudy, rainy).",
            },
            "humidity": {
                "type": "number",
                "description": "The forecasted humidity percentage.",
            },
            "wind_speed": {
                "type": "number",
                "description": "The forecasted wind speed in kilometers per hour.",
            },
        },
        "required": ["temperature", "condition", "humidity", "wind_speed"],
    }


# Task definition for the weather-forecast demo. The task itself and the single
# tool it exposes share the same request/response shape; each gets its own copy
# of the sub-schema so editing one never affects the other.
schema = {
    "input": _weather_request_schema(),
    "output": _weather_response_schema(),
    "description": "Alice requests a weather forecast for a specific location and date from Bob's weather service.",
    "examples": [
        {"location": "New York", "date": "2023-10-15"},
        {"location": "London", "date": "2023-11-01"},
    ],
    "tools": [
        {
            "name": "WeatherForecastAPI",
            "description": "An API that provides weather forecasts for a given location and date.",
            "input": _weather_request_schema(),
            "output": _weather_response_schema(),
            # Canned responses listed alongside the tool definition.
            "dummy_outputs": [
                {"temperature": 18, "condition": "Sunny", "humidity": 55, "wind_speed": 10},
                {"temperature": 12, "condition": "Cloudy", "humidity": 80, "wind_speed": 15},
            ],
        },
    ],
}

# Task name -> task definition; the names populate the "Schema" dropdown.
SCHEMAS = {
    "weather_forecast": schema,
    # Placeholder entry that only defines an "input" field.
    "other": {"input": "PIPPO"},
}
def parse_raw_messages(messages_raw):
    """Convert raw protocol messages into two chat-ready transcripts.

    Each raw message is a mapping with a ``'role'`` key plus either a
    ``'body'`` key or, when ``'status'`` equals ``'error'``, a ``'message'``
    key.

    Returns a ``(clean, agora)`` pair of lists of ``{'role', 'content'}``
    dicts:

    * ``clean`` holds the human-readable text (the body, or
      ``"Error: <message>"`` for error messages);
    * ``agora`` holds the full message minus its role, pretty-printed as JSON
      inside a Markdown code fence.

    The input messages are not modified.
    """
    clean, agora = [], []
    for raw in messages_raw:
        speaker = raw['role']

        # Everything except the role goes into the fenced JSON dump.
        payload = {key: value for key, value in raw.items() if key != 'role'}
        fenced = '```\n' + json.dumps(payload, indent=2) + '\n```'
        agora.append({'role': speaker, 'content': fenced})

        failed = raw.get('status') == 'error'
        text = f"Error: {raw['message']}" if failed else raw['body']
        clean.append({'role': speaker, 'content': text})
    return clean, agora
def _schema_from_state(state):
    """Rebuild a schema dict from the text held in the custom-task state.

    ``state`` maps field names to the strings shown in the "Custom Task"
    editors, plus the bookkeeping key ``'chosen_task'``, which is skipped.
    String values that parse as JSON are decoded; anything else (for example
    a free-text description) is kept verbatim.
    """
    result = {}
    for key, value in state.items():
        if key == 'chosen_task':
            # Records which preset is loaded; it is not a schema field.
            continue
        if isinstance(value, str):
            try:
                value = json.loads(value)
            except json.JSONDecodeError:
                # Best effort: text that is not valid JSON is passed as-is.
                pass
        result[key] = value
    return result


def main():
    """Build the Agora demo UI with Gradio Blocks and launch it."""
    with gr.Blocks() as demo:
        gr.Markdown("### Agora Demo")
        gr.Markdown("We will create a new Agora channel and offer it to Alice as a tool.")

        chosen_task = gr.Dropdown(choices=list(SCHEMAS.keys()), label="Schema", value="weather_forecast")
        custom_task = gr.Checkbox(label="Custom Task")

        # Holds the currently loaded preset name ('chosen_task') and the text
        # of each editable schema field. Created once per main() call, so all
        # callbacks below share it.
        # NOTE(review): presumably this is also shared between concurrent
        # browser sessions; consider gr.State if per-user isolation matters.
        STATE_TRACKER = {}

        def make_field_updater(key):
            """Return a callback that stores an editor's text under ``key``."""
            def update(text):
                STATE_TRACKER[key] = text
            return update

        # gr.render registers the function as a dynamic section that is
        # rebuilt from the current values of the listed input components.
        @gr.render(inputs=[chosen_task, custom_task])
        def render(chosen_task, custom_task):
            if STATE_TRACKER.get('chosen_task') != chosen_task:
                # Drop the previous preset's fields so they cannot leak into
                # a preset that does not define them.
                STATE_TRACKER.clear()
                STATE_TRACKER['chosen_task'] = chosen_task
                for key, value in SCHEMAS[chosen_task].items():
                    if isinstance(value, str):
                        STATE_TRACKER[key] = value
                    else:
                        STATE_TRACKER[key] = json.dumps(value, indent=2)

            if not custom_task:
                return

            editable_fields = [
                (gr.Text, "Description", "description"),
                (gr.TextArea, "Input Schema", "input"),
                (gr.TextArea, "Output Schema", "output"),
                (gr.TextArea, "Tools", "tools"),
                (gr.TextArea, "Examples", "examples"),
            ]
            for component_cls, label, key in editable_fields:
                # Presets may omit a field (e.g. "other"); show it empty
                # instead of failing with KeyError.
                field = component_cls(label=label, value=STATE_TRACKER.get(key, ""), interactive=True)
                # The editor must be listed as its own input, otherwise the
                # callback is invoked without the new text.
                field.change(make_field_updater(key), inputs=[field])

        model_options = [
            ('GPT 4o (Camel AI)', 'gpt-4o'),
            ('GPT 4o-mini (Camel AI)', 'gpt-4o-mini'),
            ('Claude 3 Sonnet (LangChain)', 'claude-3-sonnet'),
            ('Gemini 1.5 Pro (Google GenAI)', 'gemini-1.5-pro'),
            ('Llama3 405B (Sambanova + LangChain)', 'llama3-405b')
        ]

        fallback_image = ''

        chatgpt_icon = 'https://uxwing.com/wp-content/themes/uxwing/download/brands-and-social-media/chatgpt-icon.png'
        claude_icon = 'https://play-lh.googleusercontent.com/4S1nfdKsH_1tJodkHrBHimqlCTE6qx6z22zpMyPaMc_Rlr1EdSFDI1I6UEVMnokG5zI'

        # Avatar per model id. 'claude-3-sonnet' is the id actually offered in
        # model_options; the other Claude ids are kept for compatibility.
        images = {
            'gpt-4o': chatgpt_icon,
            'gpt-4o-mini': chatgpt_icon,
            'claude-3-sonnet': claude_icon,
            'claude-3-5-sonnet-latest': claude_icon,
            'claude-3-5-haiku-latest': claude_icon,
            'gemini-1.5-pro': 'https://uxwing.com/wp-content/themes/uxwing/download/brands-and-social-media/google-gemini-icon.png',
            'llama3-405b': 'https://www.designstub.com/png-resources/wp-content/uploads/2023/03/meta-icon-social-media-flat-graphic-vector-3-novem.png'
        }

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                alice_model_dd = gr.Dropdown(label="Alice Model", choices=model_options, value="gpt-4o")
            with gr.Column(scale=1):
                bob_model_dd = gr.Dropdown(label="Bob Model", choices=model_options, value="gpt-4o")

        button = gr.Button('Start', elem_id='start_button')
        gr.Markdown('### Natural Language')

        # The chat panes depend on the selected models (for the avatars), so
        # they are built inside a render function driven by the two dropdowns.
        @gr.render(inputs=[alice_model_dd, bob_model_dd])
        def render_with_images(alice_model, bob_model):
            avatar_images = [images.get(alice_model, fallback_image), images.get(bob_model, fallback_image)]
            chatbot_nl = gr.Chatbot(type="messages", avatar_images=avatar_images)

            with gr.Accordion(label="Raw Messages", open=False):
                chatbot_nl_raw = gr.Chatbot(type="messages", avatar_images=avatar_images)

            gr.Markdown('### Negotiation')
            chatbot_negotiation = gr.Chatbot(type="messages", avatar_images=avatar_images)

            gr.Markdown('### Protocol')
            protocol_result = gr.TextArea(interactive=False, label="Protocol")

            gr.Markdown('### Implementation')
            with gr.Row():
                with gr.Column(scale=1):
                    alice_implementation = gr.TextArea(interactive=False, label="Alice Implementation")
                with gr.Column(scale=1):
                    bob_implementation = gr.TextArea(interactive=False, label="Bob Implementation")

            gr.Markdown('### Structured Communication')
            structured_communication = gr.Chatbot(type="messages", avatar_images=avatar_images)

            with gr.Accordion(label="Raw Messages", open=False):
                structured_communication_raw = gr.Chatbot(type="messages", avatar_images=avatar_images)

            def control_updates(interactive):
                """One update per control (button and both model dropdowns)."""
                return tuple(gr.update(interactive=interactive) for _ in range(3))

            def respond(chosen_task, custom_task, alice_model, bob_model):
                """Run the full flow, streaming every intermediate state to the UI.

                Each yielded tuple has 11 items matching the ``outputs`` list
                of the click handler: 3 control updates, then 8 result panes.
                """
                # Lock the controls and clear every result pane.
                yield control_updates(False) + (None,) * 8

                if custom_task:
                    task_schema = _schema_from_state(STATE_TRACKER)
                else:
                    task_schema = SCHEMAS[chosen_task]

                for nl_messages_raw, negotiation_messages, structured_messages_raw, protocol, alice_code, bob_code in full_flow(task_schema, alice_model, bob_model):
                    nl_messages_clean, nl_messages_agora = parse_raw_messages(nl_messages_raw)
                    structured_messages_clean, structured_messages_agora = parse_raw_messages(structured_messages_raw)

                    # Leave the controls untouched while results stream in.
                    yield (
                        gr.update(), gr.update(), gr.update(),
                        nl_messages_clean, nl_messages_agora,
                        negotiation_messages,
                        structured_messages_clean, structured_messages_agora,
                        protocol, alice_code, bob_code,
                    )

                # Unlock the controls; keep the result panes as they are.
                yield control_updates(True) + tuple(gr.update() for _ in range(8))

            button.click(
                respond,
                [chosen_task, custom_task, alice_model_dd, bob_model_dd],
                [
                    button, alice_model_dd, bob_model_dd,
                    chatbot_nl, chatbot_nl_raw,
                    chatbot_negotiation,
                    structured_communication, structured_communication_raw,
                    protocol_result, alice_implementation, bob_implementation,
                ],
            )

    demo.launch()
# Start the demo only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()