# AgentVerse's picture
# bump version to 0.1.8
# 01523b5
import base64
import itertools
import json
from typing import Dict, List, Tuple
import cv2
import gradio as gr
from agentverse import TaskSolving
from agentverse.simulation import Simulation
from agentverse.message import Message
def cover_img(background, img, place: Tuple[int, int]):
    """
    Overlays the specified image to the specified position of the background image.

    The background is modified in place. Only overlay pixels whose alpha value
    (4th channel) is non-zero are copied; the first three channels of those
    pixels replace the corresponding background pixels. Any part of the overlay
    that falls outside the background is clipped instead of raising or wrapping
    around.

    :param background: background image (array indexed as [row, column, channel])
    :param img: the specified image; must have at least 4 channels
    :param place: the top-left (row, column) coordinate of the target location
    """
    back_h, back_w = background.shape[:2]
    height, width = img.shape[:2]
    top, left = place

    # Intersect the overlay rectangle with the background rectangle.
    row_start, col_start = max(top, 0), max(left, 0)
    row_stop = min(top + height, back_h)
    col_stop = min(left + width, back_w)
    if row_start >= row_stop or col_start >= col_stop:
        return  # overlay lies completely outside the background

    patch = img[row_start - top:row_stop - top, col_start - left:col_stop - left]
    opaque = patch[:, :, 3] != 0
    # Basic slicing yields a view, so this assignment writes into `background`.
    target = background[row_start:row_stop, col_start:col_stop]
    target[opaque] = patch[:, :, :3][opaque]
class GUI:
    """
    the UI of frontend

    Wraps an AgentVerse backend (``TaskSolving`` for the
    ``pipeline_brainstorming`` task, ``Simulation`` otherwise) and renders each
    step as an image plus an HTML chat log inside a gradio ``Blocks`` app.
    """

    # (keywords to highlight, index into ``solution_status``) used when
    # rendering db_diag solutions. Each index matches one of the five solution
    # buttons created in ``launch``.
    _SOLUTION_KEYWORDS = (
        (("query", "queries"), 0),
        (("join",), 1),
        (("index",), 2),
        (("system configuration",), 3),
        (("monitor", "Monitor", "Investigate"), 4),
    )

    def __init__(self, task: str, tasks_dir: str):
        """
        init a UI.
        The number of students is the number of backend agents minus one.

        :param task: task name; selects the backend type and the image assets
        :param tasks_dir: directory handed to the backend's ``from_task``
        """
        self.messages = []
        self.task = task
        if task == "pipeline_brainstorming":
            self.backend = TaskSolving.from_task(task, tasks_dir)
        else:
            self.backend = Simulation.from_task(task, tasks_dir)
        self.turns_remain = 0
        # agent name -> position in the backend's agent list
        self.agent_id = {
            agent.name: idx for idx, agent in enumerate(self.backend.agents)
        }
        self.stu_num = len(self.agent_id) - 1
        self.autoplay = False
        self.image_now = None
        self.text_now = None
        self.tot_solutions = 5
        self.solution_status = [False] * self.tot_solutions
        # avatar index -> data URI, so avatars are read and encoded only once
        self._avatar_cache: Dict[int, str] = {}

    # ------------------------------------------------------------------ #
    # asset helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _read_image(path: str, with_alpha: bool = False):
        """
        Read an image with OpenCV and fail loudly when it cannot be loaded.

        :param path: image file path
        :param with_alpha: read with ``cv2.IMREAD_UNCHANGED`` when True
        :return: the loaded image
        :raises gr.Error: if ``cv2.imread`` returns ``None``
        """
        if with_alpha:
            img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        else:
            img = cv2.imread(path)
        if img is None:
            raise gr.Error(f"image file not found or unreadable: {path}")
        return img

    def _task_dir(self) -> str:
        """Return the image directory used by the current task."""
        if self.task == "prisoner_dilemma":
            return "./imgs/prison"
        if self.task == "db_diag":
            return "./imgs/db_diag"
        if "sde" in self.task:
            return "./imgs/sde"
        return "./imgs"

    def _read_background(self):
        """Load the background image of the current task."""
        if self.task == "prisoner_dilemma":
            return self._read_image("./imgs/prison/case_1.png")
        return self._read_image(f"{self._task_dir()}/background.png")

    def _solution_updates(self):
        """
        Build the updates for the five solution buttons followed by the update
        for the box that contains them (visible if any solution is active).
        """
        updates = [gr.Button.update(visible=statu) for statu in self.solution_status]
        updates.append(gr.Box.update(visible=any(self.solution_status)))
        return updates

    def get_avatar(self, idx):
        """
        Return the avatar image ``idx`` as a base64 PNG data URI.

        :param idx: avatar index; ``-1`` selects ``./imgs/db_diag/-1.png``
        :return: ``data:image/png;base64,...`` string
        """
        cached = self._avatar_cache.get(idx)
        if cached is not None:
            return cached
        if idx == -1:
            path = "./imgs/db_diag/-1.png"
        else:
            path = f"{self._task_dir()}/{idx}.png"
        img = self._read_image(path)
        # tobytes() replaces the deprecated ndarray.tostring() alias.
        png_bytes = cv2.imencode(".png", img)[1].tobytes()
        uri = "data:image/png;base64," + base64.b64encode(png_bytes).decode("utf-8")
        self._avatar_cache[idx] = uri
        return uri

    # ------------------------------------------------------------------ #
    # gradio callbacks
    # ------------------------------------------------------------------ #
    def stop_autoplay(self):
        """
        Clear the autoplay flag.

        :return: updates disabling the next, stop-autoplay and start-autoplay
                 buttons
        """
        self.autoplay = False
        return (
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=False),
        )

    def start_autoplay(self):
        """
        Generator callback: keep producing steps while autoplay is on and turns
        remain.

        Each yielded tuple is (image, text, next button, stop button, start
        button, five solution buttons, solutions box).
        """
        self.autoplay = True
        # First update: only the stop button is clickable.
        yield (
            self.image_now,
            self.text_now,
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=True),
            gr.Button.update(interactive=False),
            *self._solution_updates(),
        )
        while self.autoplay and self.turns_remain > 0:
            outputs = self.gen_output()
            self.image_now, self.text_now = outputs
            # Evaluated after the step, so a stop request or the last turn
            # switches the buttons in this same update.
            can_step = not self.autoplay and self.turns_remain > 0
            can_stop = self.autoplay and self.turns_remain > 0
            yield (
                *outputs,
                gr.Button.update(interactive=can_step),
                gr.Button.update(interactive=can_stop),
                gr.Button.update(interactive=can_step),
                *self._solution_updates(),
            )

    def delay_gen_output(self):
        """
        Generator callback for a single step.

        First yields the current image/text with the next and start-autoplay
        buttons disabled, then runs one step and yields the new image/text with
        both buttons enabled only if turns remain.
        """
        yield (
            self.image_now,
            self.text_now,
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=False),
            *self._solution_updates(),
        )
        self.image_now, self.text_now = self.gen_output()
        has_turns = self.turns_remain > 0
        yield (
            self.image_now,
            self.text_now,
            gr.Button.update(interactive=has_turns),
            gr.Button.update(interactive=has_turns),
            *self._solution_updates(),
        )

    def delay_reset(self):
        """
        Stop autoplay, reset the UI state and return the updates for
        (image, text, next button, stop button, start button, five solution
        buttons, solutions box).
        """
        self.autoplay = False
        self.image_now, self.text_now = self.reset()
        return (
            self.image_now,
            self.text_now,
            gr.Button.update(interactive=True),
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=True),
            *self._solution_updates(),
        )

    def reset(self, stu_num=0):
        """
        tell backend the new number of students and generate new empty image

        :param stu_num: validated to be within 0..30 but not yet passed on to
                        the backend (see the To-Do notes below)
        :return: [empty image, empty message]
        :raises gr.Error: if ``stu_num`` is outside 0..30
        """
        if not 0 <= stu_num <= 30:
            raise gr.Error("the number of students must be between 0 and 30.")

        # [To-Do] Need to add a function to assign agent numbers into the backend.
        # self.backend.reset(stu_num)
        # self.stu_num = stu_num

        # [To-Do] Pass the parameters to reset
        self.backend.reset()
        self.turns_remain = self.backend.environment.max_turns

        background = self._read_background()
        back_h, back_w, _ = background.shape
        seats = itertools.product(
            range(800, back_h, 300), range(135, back_w - 200, 200)
        )
        for stu_cnt, (h_begin, w_begin) in enumerate(seats, start=1):
            # Avatars cycle through 1..11; seats beyond stu_num get 'empty'.
            name = (stu_cnt - 1) % 11 + 1 if stu_cnt <= self.stu_num else "empty"
            img = self._read_image(f"./imgs/{name}.png", with_alpha=True)
            cover_img(
                background,
                img,
                (h_begin - 30 if img.shape[0] > 190 else h_begin, w_begin),
            )
        self.messages = []
        self.solution_status = [False] * self.tot_solutions
        return [cv2.cvtColor(background, cv2.COLOR_BGR2RGB), ""]

    # ------------------------------------------------------------------ #
    # rendering
    # ------------------------------------------------------------------ #
    def gen_img(self, data: List[Dict]):
        """
        generate new image with sender rank

        :param data: one ``{"message": ...}`` entry per agent, in agent order
        :return: the new image
        :raises gr.Error: if ``len(data)`` differs from ``stu_num + 1``
        """
        # The following code need to be more general. This one is too task-specific.
        if len(data) != self.stu_num + 1:
            raise gr.Error("data length is not equal to the total number of students.")

        if self.task == "prisoner_dilemma":
            bubble = self._read_image("./imgs/speaking.png", with_alpha=True)
            use_case_1 = (
                len(self.messages) < 2
                or self.messages[-1][0] == 1
                or self.messages[-2][0] == 2
            )
            if use_case_1:
                background = self._read_image("./imgs/prison/case_1.png")
                positions = [(400, 480), (550, 480), (550, 880)]
            else:
                background = self._read_image("./imgs/prison/case_2.png")
                positions = [(400, 880), (550, 480), (550, 880)]
            for item, position in zip(data, positions):
                if item["message"] != "":
                    cover_img(background, bubble, position)
        elif self.task == "db_diag" or "sde" in self.task:
            if self.task == "db_diag":
                positions = [(750, 80), (310, 220), (522, 11)]
            else:
                positions = [(692, 330), (692, 660), (692, 990)]
            background = self._read_background()
            bubble = self._read_image(
                f"{self._task_dir()}/speaking.png", with_alpha=True
            )
            for item, position in zip(data, positions):
                if item["message"] != "":
                    cover_img(background, bubble, position)
        else:
            background = self._read_image("./imgs/background.png")
            back_h, back_w, _ = background.shape
            # Agent 0 gets a fixed speech-bubble position.
            if data[0]["message"] not in ["", "[RaiseHand]"]:
                img = self._read_image("./imgs/speaking.png", with_alpha=True)
                cover_img(background, img, (370, 1250))
            seats = itertools.product(
                range(800, back_h, 300), range(135, back_w - 200, 200)
            )
            for stu_cnt, (h_begin, w_begin) in enumerate(seats, start=1):
                if stu_cnt > self.stu_num:
                    img = self._read_image("./imgs/empty.png", with_alpha=True)
                    cover_img(background, img, (h_begin, w_begin))
                    continue
                img = self._read_image(
                    f"./imgs/{(stu_cnt - 1) % 11 + 1}.png", with_alpha=True
                )
                cover_img(
                    background,
                    img,
                    (h_begin - 30 if img.shape[0] > 190 else h_begin, w_begin),
                )
                text = data[stu_cnt]["message"]
                if "[RaiseHand]" in text:
                    img = self._read_image("./imgs/hand.png", with_alpha=True)
                    cover_img(background, img, (h_begin - 90, w_begin + 10))
                elif text != "":
                    img = self._read_image("./imgs/speaking.png", with_alpha=True)
                    cover_img(background, img, (h_begin - 90, w_begin + 10))
        return cv2.cvtColor(background, cv2.COLOR_BGR2RGB)

    def return_format(self, messages: "List[Message]"):
        """
        Convert backend messages into one entry per agent.

        :param messages: backend messages exposing ``sender`` and ``content``
        :return: list of ``{"message": str, "sender": idx}`` in agent order;
                 agents without a message keep an empty string
        """
        _format = [{"message": "", "sender": idx} for idx in range(len(self.agent_id))]
        for message in messages:
            slot = _format[self.agent_id[message.sender]]
            if self.task == "db_diag":
                # Work on a copy so the backend's own message is not modified.
                content_json: dict = dict(message.content)
                content_json["diagnose"] = f"[{message.sender}]: {content_json['diagnose']}"
                slot["message"] = json.dumps(content_json)
            elif "sde" in self.task and message.sender == "code_tester":
                # First line is free text, the remainder is a JSON document
                # (which may itself span several lines).
                pre_message, payload = message.content.split("\n", 1)
                feedback = json.loads(payload)["feedback"]
                slot["message"] = "[{}]: {}".format(
                    message.sender, "{}\n{}".format(pre_message, feedback)
                )
            else:
                slot["message"] = "[{}]: {}".format(message.sender, message.content)
        return _format

    def gen_output(self):
        """
        generate new image and message of next step

        :return: [new image, new message]
        """
        return_message = self.backend.next()
        data = self.return_format(return_message)

        # [To-Do]; Check the message from the backend: only 1 person can speak
        for item in data:
            if item["message"] not in ["", "[RaiseHand]"]:
                self.messages.append((item["sender"], item["message"]))

        message = self.gen_message()
        self.turns_remain -= 1
        return [self.gen_img(data), message]

    def _render_db_diag(self, raw: str) -> str:
        """
        Render one db_diag agent message (a JSON string with ``diagnose``,
        ``solution`` and ``knowledge``) as HTML, and recompute
        ``solution_status`` from the keywords found in its solutions.
        """
        report = json.loads(raw)
        self.solution_status = [False] * self.tot_solutions
        html = report["diagnose"]
        # NOTE(review): agent-provided text is inserted without HTML escaping,
        # as before; confirm the backend output is trusted.
        solutions = report["solution"]
        if solutions != "":
            if isinstance(solutions, str):
                # Treat a bare string as one solution instead of iterating
                # over its characters.
                solutions = [solutions]
            for solu in solutions:
                for keywords, slot in self._SOLUTION_KEYWORDS:
                    if any(keyword in solu for keyword in keywords):
                        self.solution_status[slot] = True
                        for keyword in keywords:
                            solu = solu.replace(
                                keyword,
                                f'<span style="color:yellow;">{keyword}</span>',
                            )
                html = f"{html}<br>{solu}"
        knowledge = report["knowledge"]
        if knowledge != "":
            html = (
                f'{html}<hr style="margin: 5px 0">'
                f'<span style="font-style: italic">{knowledge}</span>'
            )
        return html

    def gen_message(self):
        """
        Render ``self.messages`` as an HTML chat log, newest message first.

        In the db_diag task, agent messages are rendered from their JSON
        payload; user messages (sender ``-1``) and every message of the other
        tasks are shown as HTML-escaped plain text.

        :return: HTML string
        """
        bubbles = []
        for sender, msg in self.messages:
            if sender in (0, -1):
                avatar = self.get_avatar(sender)
            else:
                avatar = self.get_avatar((sender - 1) % 11 + 1)
            if self.task == "db_diag" and sender != -1:
                msg = self._render_db_diag(msg)
            else:
                # '&' must be escaped first so the entities added below
                # are not escaped again.
                msg = msg.replace("&", "&amp;")
                msg = msg.replace("<", "&lt;")
                msg = msg.replace(">", "&gt;")
            bubbles.append(
                f'<div style="display: flex; align-items: center; margin-bottom: 10px;overflow:auto;">'
                f'<img src="{avatar}" style="width: 5%; height: 5%; border-radius: 25px; margin-right: 10px;">'
                f'<div style="background-color: gray; color: white; padding: 10px; border-radius: 10px;'
                f'max-width: 70%; white-space: pre-wrap">'
                f"{msg}"
                f"</div></div>"
            )
        message = "".join(reversed(bubbles))
        return '<div id="divDetail" style="height:600px;overflow:auto;">' + message + "</div>"

    def submit(self, message: str):
        """
        submit message to backend

        :param message: message
        :return: [new image, new message]
        """
        self.backend.submit(message)
        self.messages.append((-1, f"[User]: {message}"))
        return self.gen_img([{"message": ""}] * len(self.agent_id)), self.gen_message()

    def launch(self, single_agent=False, discussion_mode=False):
        """
        Build the gradio app for the current task and launch it.

        :param single_agent: forwarded to ``backend.iter_run``
                             (``pipeline_brainstorming`` task only)
        :param discussion_mode: forwarded to ``backend.iter_run``
                                (``pipeline_brainstorming`` task only)
        """
        if self.task == "pipeline_brainstorming":
            with gr.Blocks() as demo:
                chatbot = gr.Chatbot(height=800, show_label=False)
                msg = gr.Textbox(label="Input")

                def respond(message, chat_history):
                    chat_history.append((message, None))
                    yield "", chat_history
                    for response in self.backend.iter_run(
                        single_agent=single_agent, discussion_mode=discussion_mode
                    ):
                        print(response)
                        chat_history.append((None, response))
                        yield "", chat_history

                msg.submit(respond, [msg, chatbot], [msg, chatbot])
        else:
            with gr.Blocks() as demo:
                with gr.Row():
                    with gr.Column():
                        image_output = gr.Image()
                        with gr.Row():
                            reset_btn = gr.Button("Reset")
                            next_btn = gr.Button("Next", interactive=False)
                            stop_autoplay_btn = gr.Button(
                                "Stop Autoplay", interactive=False
                            )
                            start_autoplay_btn = gr.Button(
                                "Start Autoplay", interactive=False
                            )
                        with gr.Box(visible=False) as solutions:
                            with gr.Column():
                                gr.HTML("Optimization Solutions:")
                                with gr.Row():
                                    rewrite_slow_query_btn = gr.Button("Rewrite Slow Query", visible=False)
                                    add_query_hints_btn = gr.Button("Add Query Hints", visible=False)
                                    update_indexes_btn = gr.Button("Update Indexes", visible=False)
                                    tune_parameters_btn = gr.Button("Tune Parameters", visible=False)
                                    gather_more_info_btn = gr.Button("Gather More Info", visible=False)
                    text_output = gr.HTML(self.reset()[1])

                # Given a botton to provide student numbers and their inf.
                # stu_num = gr.Number(label="Student Number", precision=0)
                # stu_num = self.stu_num

                if self.task == "db_diag":
                    user_msg = gr.Textbox()
                    submit_btn = gr.Button("Submit", variant="primary")
                    submit_btn.click(
                        fn=self.submit,
                        inputs=user_msg,
                        outputs=[image_output, text_output],
                        show_progress=False,
                    )

                # Components shared by every callback that refreshes the
                # solution buttons; the order matches _solution_updates().
                solution_outputs = [
                    rewrite_slow_query_btn,
                    add_query_hints_btn,
                    update_indexes_btn,
                    tune_parameters_btn,
                    gather_more_info_btn,
                    solutions,
                ]
                playback_outputs = [
                    image_output,
                    text_output,
                    next_btn,
                    stop_autoplay_btn,
                    start_autoplay_btn,
                    *solution_outputs,
                ]

                next_btn.click(
                    fn=self.delay_gen_output,
                    inputs=None,
                    outputs=[
                        image_output,
                        text_output,
                        next_btn,
                        start_autoplay_btn,
                        *solution_outputs,
                    ],
                    show_progress=False,
                )

                # [To-Do] Add botton: re-start (load different people and env)
                reset_btn.click(
                    fn=self.delay_reset,
                    inputs=None,
                    outputs=playback_outputs,
                    show_progress=False,
                )

                stop_autoplay_btn.click(
                    fn=self.stop_autoplay,
                    inputs=None,
                    outputs=[next_btn, stop_autoplay_btn, start_autoplay_btn],
                    show_progress=False,
                )
                start_autoplay_btn.click(
                    fn=self.start_autoplay,
                    inputs=None,
                    outputs=playback_outputs,
                    show_progress=False,
                )

        demo.queue(concurrency_count=5, max_size=20).launch()