Spaces:
Build error
Build error
AgentVerse
committed on
Commit
·
01523b5
1
Parent(s):
7b262a7
bump version to 0.1.8
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- agentverse/__init__.py +24 -0
- agentverse/agents/__init__.py +20 -0
- agentverse/agents/base.py +101 -0
- agentverse/agents/simulation_agent/conversation.py +107 -0
- agentverse/agents/simulation_agent/prisoner_dilemma.py +167 -0
- agentverse/agents/simulation_agent/reflection.py +227 -0
- agentverse/agents/simulation_agent/tool.py +177 -0
- agentverse/agents/tasksolving_agent/__init__.py +6 -0
- agentverse/agents/tasksolving_agent/critic.py +127 -0
- agentverse/agents/tasksolving_agent/evaluator.py +86 -0
- agentverse/agents/tasksolving_agent/executor.py +130 -0
- agentverse/agents/tasksolving_agent/manager.py +116 -0
- agentverse/agents/tasksolving_agent/role_assigner.py +88 -0
- agentverse/agents/tasksolving_agent/solver.py +110 -0
- agentverse/agentverse.py +65 -0
- agentverse/demo.py +487 -0
- agentverse/environments/__init__.py +17 -0
- agentverse/environments/base.py +58 -0
- agentverse/environments/simulation_env/basic.py +101 -0
- agentverse/environments/simulation_env/pokemon.py +222 -0
- agentverse/environments/simulation_env/prisoner_dilemma.py +49 -0
- agentverse/environments/simulation_env/reflection.py +128 -0
- agentverse/environments/simulation_env/rules/__init__.py +1 -0
- agentverse/environments/simulation_env/rules/base.py +98 -0
- agentverse/environments/simulation_env/rules/describer/__init__.py +9 -0
- agentverse/environments/simulation_env/rules/describer/base.py +23 -0
- agentverse/environments/simulation_env/rules/describer/basic.py +16 -0
- agentverse/environments/simulation_env/rules/describer/classroom.py +40 -0
- agentverse/environments/simulation_env/rules/describer/pokemon.py +51 -0
- agentverse/environments/simulation_env/rules/describer/prisoner.py +49 -0
- agentverse/environments/simulation_env/rules/order/__init__.py +11 -0
- agentverse/environments/simulation_env/rules/order/base.py +18 -0
- agentverse/environments/simulation_env/rules/order/classroom.py +100 -0
- agentverse/environments/simulation_env/rules/order/concurrent.py +19 -0
- agentverse/environments/simulation_env/rules/order/prisoner.py +48 -0
- agentverse/environments/simulation_env/rules/order/random.py +21 -0
- agentverse/environments/simulation_env/rules/order/sde_team.py +30 -0
- agentverse/environments/simulation_env/rules/order/sde_team_given_tests.py +35 -0
- agentverse/environments/simulation_env/rules/order/sequential.py +28 -0
- agentverse/environments/simulation_env/rules/selector/__init__.py +10 -0
- agentverse/environments/simulation_env/rules/selector/base.py +30 -0
- agentverse/environments/simulation_env/rules/selector/basic.py +27 -0
- agentverse/environments/simulation_env/rules/selector/classroom.py +47 -0
- agentverse/environments/simulation_env/rules/selector/code_api.py +97 -0
- agentverse/environments/simulation_env/rules/selector/pokemon.py +98 -0
- agentverse/environments/simulation_env/rules/selector/sde_team.py +72 -0
- agentverse/environments/simulation_env/rules/selector/sde_team_given_tests.py +56 -0
- agentverse/environments/simulation_env/rules/updater/__init__.py +9 -0
- agentverse/environments/simulation_env/rules/updater/base.py +27 -0
- agentverse/environments/simulation_env/rules/updater/basic.py +75 -0
agentverse/__init__.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .output_parser import output_parser_registry
|
2 |
+
from .environments import env_registry
|
3 |
+
from .environments.simulation_env.rules.order import order_registry
|
4 |
+
from .environments.simulation_env.rules.describer import describer_registry
|
5 |
+
from .environments.simulation_env.rules.selector import selector_registry
|
6 |
+
from .environments.simulation_env.rules.updater import updater_registry
|
7 |
+
from .environments.simulation_env.rules.visibility import visibility_registry
|
8 |
+
|
9 |
+
|
10 |
+
from .environments.tasksolving_env.rules.decision_maker import decision_maker_registry
|
11 |
+
from .environments.tasksolving_env.rules.evaluator import evaluator_registry
|
12 |
+
from .environments.tasksolving_env.rules.executor import executor_registry
|
13 |
+
from .environments.tasksolving_env.rules.role_assigner import role_assigner_registry
|
14 |
+
|
15 |
+
|
16 |
+
from .simulation import Simulation
|
17 |
+
from .tasksolving import TaskSolving
|
18 |
+
from .initialization import (
|
19 |
+
prepare_task_config,
|
20 |
+
load_agent,
|
21 |
+
load_environment,
|
22 |
+
load_llm,
|
23 |
+
load_memory,
|
24 |
+
)
|
agentverse/agents/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from .agent import Agent
|
2 |
+
from agentverse.registry import Registry
|
3 |
+
|
4 |
+
agent_registry = Registry(name="AgentRegistry")
|
5 |
+
|
6 |
+
|
7 |
+
from .base import BaseAgent
|
8 |
+
from agentverse.agents.simulation_agent.conversation import ConversationAgent
|
9 |
+
from agentverse.agents.simulation_agent.tool import ToolAgent
|
10 |
+
from agentverse.agents.simulation_agent.prisoner_dilemma import (
|
11 |
+
PoliceAgent,
|
12 |
+
PrisonerAgent,
|
13 |
+
)
|
14 |
+
|
15 |
+
from agentverse.agents.tasksolving_agent.role_assigner import RoleAssignerAgent
|
16 |
+
from agentverse.agents.tasksolving_agent.critic import CriticAgent
|
17 |
+
from agentverse.agents.tasksolving_agent.evaluator import EvaluatorAgent
|
18 |
+
from agentverse.agents.tasksolving_agent.solver import SolverAgent
|
19 |
+
from agentverse.agents.tasksolving_agent.manager import ManagerAgent
|
20 |
+
from agentverse.agents.tasksolving_agent.executor import ExecutorAgent
|
agentverse/agents/base.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from abc import abstractmethod
|
3 |
+
from typing import List, NamedTuple, Set, Union
|
4 |
+
from string import Template
|
5 |
+
|
6 |
+
from pydantic import BaseModel, Field
|
7 |
+
|
8 |
+
from agentverse.llms import BaseLLM
|
9 |
+
from agentverse.memory import BaseMemory, ChatHistoryMemory
|
10 |
+
from agentverse.message import Message
|
11 |
+
from agentverse.output_parser import OutputParser
|
12 |
+
from agentverse.memory_manipulator import BaseMemoryManipulator
|
13 |
+
|
14 |
+
|
15 |
+
class BaseAgent(BaseModel):
    """Shared configuration and receiver bookkeeping for agents.

    Concrete agents must implement ``step``, ``astep``, ``reset`` and
    ``add_message_to_memory``; everything else here is ready to use.
    """

    name: str
    llm: BaseLLM
    output_parser: OutputParser
    # The *_template fields are ``string.Template`` sources (``${name}`` syntax).
    prepend_prompt_template: str = Field(default="")
    append_prompt_template: str = Field(default="")
    prompt_template: str = Field(default="")
    role_description: str = Field(default="")
    memory: BaseMemory = Field(default_factory=ChatHistoryMemory)
    memory_manipulator: BaseMemoryManipulator = Field(
        default_factory=BaseMemoryManipulator
    )
    max_retry: int = Field(default=3)
    # A factory gives every agent its own receiver set.
    receiver: Set[str] = Field(default_factory=lambda: {"all"})
    async_mode: bool = Field(default=True)

    @abstractmethod
    def step(self, env_description: str = "") -> Message:
        """Get one step response"""
        pass

    @abstractmethod
    def astep(self, env_description: str = "") -> Message:
        """Asynchronous version of step"""
        pass

    @abstractmethod
    def reset(self) -> None:
        """Reset the agent"""
        pass

    @abstractmethod
    def add_message_to_memory(self, messages: List[Message]) -> None:
        """Add a message to the memory"""
        pass

    def get_spend(self) -> float:
        """Return whatever ``self.llm.get_spend()`` reports."""
        return self.llm.get_spend()

    def get_spend_formatted(self) -> str:
        """Format the spend as dollars.

        Two decimals are used unless that would print ``$0.00``, in which case
        six decimals are used so small amounts stay visible.
        """
        # Read the spend once so both formats describe the same value.
        spend = self.get_spend()
        two_trailing = f"${spend:.2f}"
        if two_trailing == "$0.00":
            return f"${spend:.6f}"
        return two_trailing

    def get_all_prompts(self, **kwargs):
        """Render the prepend and append templates with ``kwargs``.

        ``safe_substitute`` is used, so placeholders missing from ``kwargs``
        are left in the text instead of raising.

        Returns:
            A ``(prepend_prompt, append_prompt)`` tuple of strings.
        """
        prepend_prompt = Template(self.prepend_prompt_template).safe_substitute(
            **kwargs
        )
        append_prompt = Template(self.append_prompt_template).safe_substitute(**kwargs)
        return prepend_prompt, append_prompt

    def get_receiver(self) -> Set[str]:
        """Return the current set of receiver names."""
        return self.receiver

    def set_receiver(self, receiver: Union[Set[str], str]) -> None:
        """Replace the receivers with a single name or a set of names.

        Raises:
            ValueError: if ``receiver`` is neither a ``str`` nor a ``set``.
        """
        if isinstance(receiver, str):
            self.receiver = {receiver}
        elif isinstance(receiver, set):
            self.receiver = receiver
        else:
            raise ValueError(
                "input argument `receiver` must be a string or a set of string"
            )

    def add_receiver(self, receiver: Union[Set[str], str]) -> None:
        """Add one name, or every name in a set, to the receivers.

        Raises:
            ValueError: if ``receiver`` is neither a ``str`` nor a ``set``.
        """
        if isinstance(receiver, str):
            self.receiver.add(receiver)
        elif isinstance(receiver, set):
            self.receiver = self.receiver.union(receiver)
        else:
            raise ValueError(
                "input argument `receiver` must be a string or a set of string"
            )

    def remove_receiver(self, receiver: Union[Set[str], str]) -> None:
        """Remove one name, or every name in a set, from the receivers.

        A missing single name only logs a warning; missing names in a set are
        ignored silently.

        Raises:
            ValueError: if ``receiver`` is neither a ``str`` nor a ``set``.
        """
        if isinstance(receiver, str):
            try:
                self.receiver.remove(receiver)
            except KeyError:
                logging.warning(f"Receiver {receiver} not found.")
        elif isinstance(receiver, set):
            self.receiver = self.receiver.difference(receiver)
        else:
            raise ValueError(
                "input argument `receiver` must be a string or a set of string"
            )
|
agentverse/agents/simulation_agent/conversation.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
from colorama import Fore
|
3 |
+
|
4 |
+
# import logging
|
5 |
+
from agentverse.logging import get_logger
|
6 |
+
import bdb
|
7 |
+
from string import Template
|
8 |
+
from typing import TYPE_CHECKING, List
|
9 |
+
|
10 |
+
from agentverse.message import Message
|
11 |
+
|
12 |
+
#from . import agent_registry
|
13 |
+
#from .base import BaseAgent
|
14 |
+
from agentverse.agents import agent_registry
|
15 |
+
from agentverse.agents.base import BaseAgent
|
16 |
+
|
17 |
+
logger = get_logger()
|
18 |
+
|
19 |
+
|
20 |
+
@agent_registry.register("conversation")
class ConversationAgent(BaseAgent):
    """Agent that fills its prompt template and queries the LLM once per step.

    A failing generation or parse is retried up to ``max_retry`` times. If
    every attempt fails, a message with empty content is returned instead of
    raising.
    """

    def step(self, env_description: str = "") -> Message:
        """Generate one message synchronously.

        Args:
            env_description: text substituted for ``${env_description}``.

        Returns:
            A ``Message`` from this agent to its current receivers.
        """
        prompt = self._fill_prompt_template(env_description)

        parsed_response = None
        for _ in range(self.max_retry):
            try:
                response = self.llm.generate_response(prompt)
                parsed_response = self.output_parser.parse(response)
                break
            except (KeyboardInterrupt, bdb.BdbQuit):
                # BdbQuit subclasses Exception; re-raise it (as astep does) so
                # quitting a debugger is not swallowed by the retry loop.
                raise
            except Exception as e:
                logger.error(e)
                # NOTE(review): astep calls logger.warning here; the project
                # logger's API is not visible, so both calls are kept as written.
                logger.warn("Retrying...")
                continue

        if parsed_response is None:
            logger.error(f"{self.name} failed to generate valid response.")

        return self._make_message(parsed_response)

    async def astep(self, env_description: str = "") -> Message:
        """Asynchronous version of step"""
        prompt = self._fill_prompt_template(env_description)

        parsed_response = None
        for _ in range(self.max_retry):
            try:
                # The prompt is logged on every attempt, not just the first.
                logger.debug(prompt, "Prompt", Fore.CYAN)
                response = await self.llm.agenerate_response(prompt)
                parsed_response = self.output_parser.parse(response)
                break
            except (KeyboardInterrupt, bdb.BdbQuit):
                raise
            except Exception as e:
                logger.error(e)
                logger.warning("Retrying...")
                continue

        if parsed_response is None:
            logger.error(f"{self.name} failed to generate valid response.")

        return self._make_message(parsed_response)

    def _make_message(self, parsed_response) -> Message:
        """Wrap a parsed response (or ``None`` after total failure) in a Message."""
        if parsed_response is None:
            content = ""
        else:
            content = parsed_response.return_values["output"]
        return Message(
            content=content,
            sender=self.name,
            receiver=self.get_receiver(),
        )

    def _fill_prompt_template(self, env_description: str = "") -> str:
        """Fill the placeholders in the prompt template

        In the conversation agent, four placeholders are supported:
        - ${agent_name}: the name of the agent
        - ${env_description}: the description of the environment
        - ${role_description}: the description of the role of the agent
        - ${chat_history}: the chat history of the agent
        """
        input_arguments = {
            "agent_name": self.name,
            "env_description": env_description,
            "role_description": self.role_description,
            "chat_history": self.memory.to_string(add_sender_prefix=True),
        }
        return Template(self.prompt_template).safe_substitute(input_arguments)

    def add_message_to_memory(self, messages: List[Message]) -> None:
        """Forward ``messages`` to ``self.memory.add_message``."""
        self.memory.add_message(messages)

    def reset(self) -> None:
        """Reset the agent"""
        self.memory.reset()
        # TODO: reset receiver
|
agentverse/agents/simulation_agent/prisoner_dilemma.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import logging
|
4 |
+
from string import Template
|
5 |
+
from typing import TYPE_CHECKING, List
|
6 |
+
|
7 |
+
from agentverse.message import Message
|
8 |
+
|
9 |
+
# from . import agent_registry
|
10 |
+
# from .base import BaseAgent
|
11 |
+
from agentverse.agents import agent_registry
|
12 |
+
from agentverse.agents.base import BaseAgent
|
13 |
+
|
14 |
+
if TYPE_CHECKING:
|
15 |
+
from agentverse.environments.base import BaseEnvironment
|
16 |
+
|
17 |
+
|
18 |
+
class PrisonerDilemaAgent(BaseAgent):
    """Base agent for the prisoner's-dilemma scenario.

    It differs from a plain conversation agent in that ``step`` / ``astep``
    take the environment and pass it, together with the agent itself, to the
    output parser.
    """

    def step(
        self,
        environment: BaseEnvironment,
        env_description: str = "",
    ) -> Message:
        """Generate one message synchronously, retrying up to ``max_retry`` times."""
        prompt = self._fill_prompt_template(env_description)

        parsed_response = None
        for _ in range(self.max_retry):
            try:
                response = self.llm.generate_response(prompt)
                parsed_response = self.output_parser.parse(self, environment, response)
            except Exception as e:
                logging.error(e)
                logging.warning("Retrying...")
            else:
                break

        if parsed_response is None:
            logging.error(f"{self.name} failed to generate valid response.")
            content = ""
        else:
            content = parsed_response.return_values["output"]

        return Message(
            content=content,
            sender=self.name,
            receiver=self.get_receiver(),
        )

    async def astep(
        self, environment: BaseEnvironment, env_description: str = ""
    ) -> Message:
        """Asynchronous version of step"""
        prompt = self._fill_prompt_template(env_description)

        parsed_response = None
        for _ in range(self.max_retry):
            try:
                response = await self.llm.agenerate_response(prompt)
                parsed_response = self.output_parser.parse(self, environment, response)
            except Exception as e:
                logging.error(e)
                logging.warning("Retrying...")
            else:
                break

        if parsed_response is None:
            logging.error(f"{self.name} failed to generate valid response.")
            content = ""
        else:
            content = parsed_response.return_values["output"]

        return Message(
            content=content,
            sender=self.name,
            receiver=self.get_receiver(),
        )

    def _fill_prompt_template(self, env_description: str = "") -> str:
        """Render ``prompt_template`` for the current turn.

        Supported placeholders:
        - ${agent_name}: the name of the agent
        - ${env_description}: the description of the environment
        - ${role_description}: the description of the role of the agent
        - ${chat_history}: the chat history of the agent
        """
        history = self.memory.to_string(add_sender_prefix=True)
        return Template(self.prompt_template).safe_substitute(
            {
                "agent_name": self.name,
                "env_description": env_description,
                "role_description": self.role_description,
                "chat_history": history,
            }
        )

    def add_message_to_memory(self, messages: List[Message]) -> None:
        """Forward ``messages`` to ``self.memory.add_message``."""
        self.memory.add_message(messages)

    def reset(self) -> None:
        """Reset the agent"""
        self.memory.reset()
        # TODO: reset receiver
|
102 |
+
|
103 |
+
|
104 |
+
@agent_registry.register("police")
class PoliceAgent(PrisonerDilemaAgent):
    """Prisoner's-dilemma agent registered under the name "police".

    Its role description may contain an ``${interrogating_form}`` placeholder.
    """

    interrogating_form: str

    def _fill_prompt_template(self, env_description: str = "") -> str:
        """Render ``prompt_template`` for the current turn.

        The role description is rendered first, filling
        ``${interrogating_form}``; the result is then used for
        ``${role_description}`` in the prompt template.

        Supported prompt placeholders:
        - ${agent_name}: the name of the agent
        - ${env_description}: the description of the environment
        - ${role_description}: the description of the role of the agent
        - ${chat_history}: the chat history of the agent
        """
        history = self.memory.to_string(add_sender_prefix=True)
        rendered_role = Template(self.role_description).safe_substitute(
            {"interrogating_form": self.interrogating_form}
        )
        return Template(self.prompt_template).safe_substitute(
            {
                "agent_name": self.name,
                "env_description": env_description,
                "role_description": rendered_role,
                "chat_history": history,
            }
        )
|
134 |
+
|
135 |
+
|
136 |
+
@agent_registry.register("prisoner")
class PrisonerAgent(PrisonerDilemaAgent):
    """Prisoner's-dilemma agent registered under the name "prisoner".

    Its role description may contain ``${personality}`` and
    ``${relationship_with_another}`` placeholders.
    """

    personality: str
    relationship_with_another: str

    def _fill_prompt_template(self, env_description: str = "") -> str:
        """Render ``prompt_template`` for the current turn.

        The role description is rendered first, filling ``${personality}`` and
        ``${relationship_with_another}``; the result is then used for
        ``${role_description}`` in the prompt template.

        Supported prompt placeholders:
        - ${agent_name}: the name of the agent
        - ${env_description}: the description of the environment
        - ${role_description}: the description of the role of the agent
        - ${chat_history}: the chat history of the agent
        """
        history = self.memory.to_string(add_sender_prefix=True)
        rendered_role = Template(self.role_description).safe_substitute(
            {
                "personality": self.personality,
                "relationship_with_another": self.relationship_with_another,
            }
        )
        return Template(self.prompt_template).safe_substitute(
            {
                "agent_name": self.name,
                "env_description": env_description,
                "role_description": rendered_role,
                "chat_history": history,
            }
        )
|
agentverse/agents/simulation_agent/reflection.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
"""
|
4 |
+
An agent based upon Observation-Planning-Reflection architecture.
|
5 |
+
"""
|
6 |
+
|
7 |
+
from logging import getLogger
|
8 |
+
|
9 |
+
from abc import abstractmethod
|
10 |
+
from typing import List, Set, Union, NamedTuple, TYPE_CHECKING
|
11 |
+
|
12 |
+
from pydantic import BaseModel, Field, validator
|
13 |
+
|
14 |
+
from agentverse.llms import BaseLLM
|
15 |
+
from agentverse.memory import BaseMemory, ChatHistoryMemory
|
16 |
+
from agentverse.message import Message
|
17 |
+
from agentverse.output_parser import OutputParser
|
18 |
+
|
19 |
+
from agentverse.message import Message
|
20 |
+
from agentverse.agents.base import BaseAgent
|
21 |
+
|
22 |
+
from datetime import datetime as dt
|
23 |
+
import datetime
|
24 |
+
|
25 |
+
#from . import agent_registry
|
26 |
+
from string import Template
|
27 |
+
|
28 |
+
from agentverse.agents import agent_registry
|
29 |
+
from agentverse.agents.base import BaseAgent
|
30 |
+
|
31 |
+
logger = getLogger(__file__)
|
32 |
+
|
33 |
+
if TYPE_CHECKING:
|
34 |
+
from agentverse.environments.base import BaseEnvironment
|
35 |
+
|
36 |
+
|
37 |
+
@agent_registry.register("reflection")
class ReflectionAgent(BaseAgent):
    """Agent that answers each time frame with say(...), act(...) or do_nothing().

    The parsed LLM output is expected to be a call expression naming one of
    those three actions; it is turned into a reaction string and an optional
    target agent name.
    """

    # Plain scalar defaults: the trailing commas previously made these tuples.
    async_mode: bool = True
    current_time: str = None
    environment: BaseEnvironment = None
    step_cnt: int = 0

    manipulated_memory: str = Field(
        default="", description="one fragment used in prompt construction"
    )

    @validator("current_time")
    def convert_str_to_dt(cls, current_time):
        """Parse a ``"%Y-%m-%d %H:%M:%S"`` string into a ``datetime``.

        Raises:
            ValueError: if the value is not a string.
        """
        if not isinstance(current_time, str):
            raise ValueError("current_time should be str")
        return dt.strptime(current_time, "%Y-%m-%d %H:%M:%S")

    def step(self, current_time: dt, env_description: str = "") -> Message:
        """
        Call this method at each time frame
        """
        self.current_time = current_time

        self.manipulated_memory = self.memory_manipulator.manipulate_memory()

        prompt = self._fill_prompt_template(env_description)

        parsed_response, reaction, target = None, None, None
        for _ in range(self.max_retry):
            try:
                # Synchronous path: use the blocking call. The async variant
                # would hand back an un-awaited coroutine here.
                response = self.llm.generate_response(prompt)
                parsed_response = self.output_parser.parse(response)
                reaction, target = self._interpret_output(
                    parsed_response.return_values["output"]
                )
                break
            except Exception as e:
                logger.error(e)
                logger.warning("Retrying...")
                continue

        if parsed_response is None:
            logger.error(f"{self.name} failed to generate valid response.")

        message = self._build_message(reaction, target)
        self.step_cnt += 1
        return message

    async def astep(self, current_time: dt, env_description: str = "") -> Message:
        """Asynchronous version of step"""
        # use environment's time to update agent's time
        self.current_time = current_time
        # Before the agent step, we check current status,
        # TODO add this func after
        # self.check_status_passive()

        self.manipulated_memory = self.memory_manipulator.manipulate_memory()

        prompt = self._fill_prompt_template(env_description)

        parsed_response, reaction, target = None, None, None
        for _ in range(self.max_retry):
            try:
                response = await self.llm.agenerate_response(prompt)
                parsed_response = self.output_parser.parse(response)
                reaction, target = self._interpret_output(
                    parsed_response.return_values["output"]
                )
                break
            except Exception as e:
                logger.error(e)
                logger.warning("Retrying...")
                continue

        if parsed_response is None:
            logger.error(f"{self.name} failed to generate valid response.")

        message = self._build_message(reaction, target)
        self.step_cnt += 1
        return message

    def _interpret_output(self, output: str):
        """Turn the parsed LLM output into a ``(reaction, target)`` pair.

        Raises:
            Exception: if the output names none of say/act/do_nothing.
        """
        if "say(" in output or "act(" in output:
            # SECURITY NOTE(review): this evaluates model-generated text as
            # Python. It is kept to preserve behaviour, but should be replaced
            # by a restricted parser before handling untrusted output.
            return eval("self._" + output.strip())
        if "do_nothing(" in output:
            return None, None
        raise Exception(
            f"no valid parsed_response detected, " f"cur response {output}"
        )

    def _build_message(self, reaction, target) -> Message:
        """Create the outgoing message for a step.

        A missing reaction becomes a "keep going" notice; a missing target
        falls back to the agent's configured receivers.
        """
        if reaction is None:
            reaction = "Keep doing last action ..."
        if target is None:
            receiver = self.get_receiver()
        else:
            receiver = self.get_valid_receiver(target)
        return Message(content=reaction, sender=self.name, receiver=receiver)

    def _act(self, description=None, target=None):
        """Describe an action; returns a ``(reaction_content, target)`` pair."""
        if description is None:
            # Return a pair so the caller's tuple unpacking succeeds.
            return "", target
        if target is None:
            reaction_content = f"{self.name} performs action: '{description}'."
        else:
            reaction_content = (
                f"{self.name} performs action to {target}: '{description}'."
            )
        # self.environment.broadcast_observations(self, target, reaction_content)
        return reaction_content, target

    def _say(self, description, target=None):
        """Describe an utterance; returns a ``(reaction_content, target)`` pair."""
        if description is None:
            # Return a pair so the caller's tuple unpacking succeeds.
            return "", target
        if target is None:
            reaction_content = f"{self.name} says: '{description}'."
        else:
            reaction_content = f"{self.name} says to {target}: '{description}'."
        # self.environment.broadcast_observations(self, target, reaction_content)
        return reaction_content, target

    def get_valid_receiver(self, target: str) -> Set[str]:
        """Return ``{target}`` if it names an agent in the environment, else ``{"all"}``."""
        # A list (not a set) keeps the membership test safe for unhashable targets.
        all_agents_name = [agent.name for agent in self.environment.agents]
        if target not in all_agents_name:
            return {"all"}
        return {target}

    def _fill_prompt_template(self, env_description: str = "") -> str:
        """Fill the placeholders in the prompt template

        In the reflection agent, five placeholders are supported:
        - ${agent_name}: the name of the agent
        - ${role_description}: the description of the role of the agent
        - ${chat_history}: the chat history of the agent
        - ${current_time}: the agent's current time
        - ${env_description}: the description of the environment
        """
        input_arguments = {
            "agent_name": self.name,
            "role_description": self.role_description,
            "chat_history": self.memory.to_string(add_sender_prefix=True),
            "current_time": self.current_time,
            "env_description": env_description,
        }
        return Template(self.prompt_template).safe_substitute(input_arguments)

    def add_message_to_memory(self, messages: List[Message]) -> None:
        """Forward ``messages`` to ``self.memory.add_message``."""
        self.memory.add_message(messages)

    def reset(self, environment: BaseEnvironment) -> None:
        """Reset the agent"""
        self.environment = environment
        self.memory.reset()
        # Point the manipulator back at this agent and its memory.
        self.memory_manipulator.agent = self
        self.memory_manipulator.memory = self.memory
|
agentverse/agents/simulation_agent/tool.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from string import Template
|
3 |
+
from typing import List, NamedTuple, Optional, Union
|
4 |
+
|
5 |
+
from langchain.tools import BaseTool
|
6 |
+
from pydantic import Field
|
7 |
+
|
8 |
+
|
9 |
+
from agentverse.memory import BaseMemory, ChatHistoryMemory
|
10 |
+
from agentverse.message import Message
|
11 |
+
from agentverse.utils import AgentAction, AgentFinish
|
12 |
+
|
13 |
+
#from . import agent_registry
|
14 |
+
#from .base import BaseAgent
|
15 |
+
|
16 |
+
from agentverse.agents import agent_registry
|
17 |
+
from agentverse.agents.base import BaseAgent
|
18 |
+
|
19 |
+
class ToolNotExistError(Exception):
    """Raised when a requested tool name is not among the agent's tools.

    Derives from ``Exception`` (not ``BaseException``) so that ordinary
    ``except Exception`` handlers can deal with it; ``BaseException`` is
    reserved for interpreter-level signals such as ``KeyboardInterrupt``.

    Args:
        tool_name: Name of the tool that could not be found.
    """

    def __init__(self, tool_name=""):
        # Pass the name to the base class so ``args`` is populated.
        super().__init__(tool_name)
        self.tool_name = tool_name

    def __str__(self):
        return f"Tool {self.tool_name} does not exist."
|
27 |
+
|
28 |
+
|
29 |
+
@agent_registry.register("tool")
class ToolAgent(BaseAgent):
    """Agent that can call tools, possibly several times, before answering.

    Within one turn the LLM is prompted in a loop: an ``AgentAction`` reply
    triggers a tool call whose observation is fed into the next prompt, and
    an ``AgentFinish`` reply ends the turn.
    """

    # Tools the agent may call; they are looked up by ``tool.name``.
    tools: List[BaseTool] = Field(default=[])
    # Tool calls and observations remembered from previous turns.
    tool_memory: BaseMemory = Field(default_factory=ChatHistoryMemory)
    # Passed through as ``verbose`` to ``tool.run`` / ``tool.arun``.
    verbose: bool = Field(default=False)

    def step(self, env_description: str = "") -> Message:
        """Run one turn synchronously and return the agent's message.

        Args:
            env_description: Text substituted for ``${env_description}``.

        Returns:
            A ``Message`` whose content is the ``output`` of the final
            ``AgentFinish``, or ``""`` when no valid response was produced.
        """
        parsed_response = None
        tool_observation = [self.tool_memory.to_string()]
        while True:
            prompt = self._fill_prompt_template(env_description, tool_observation)

            # Forget the previous round's result: if every retry of this
            # round fails we must stop instead of looping forever on a
            # stale AgentAction.
            parsed_response = None
            for _ in range(self.max_retry):
                try:
                    response = self.llm.generate_response(prompt)
                    candidate = self.output_parser.parse(response)
                    if isinstance(candidate, AgentAction):
                        observation = self._call_tool(candidate)
                        tool_observation.append(
                            self._format_observation(candidate, observation)
                        )
                    # Only accept the response once the whole round worked.
                    parsed_response = candidate
                    break
                except (Exception, ToolNotExistError) as e:
                    # KeyboardInterrupt / SystemExit deliberately propagate.
                    logging.error(e)
                    logging.warning("Retrying...")
                    continue
            if parsed_response is None or isinstance(parsed_response, AgentFinish):
                break

        return self._finish_turn(parsed_response, tool_observation)

    async def astep(self, env_description: str = "") -> Message:
        """Asynchronous version of :meth:`step`.

        Args:
            env_description: Text substituted for ``${env_description}``.

        Returns:
            A ``Message`` whose content is the ``output`` of the final
            ``AgentFinish``, or ``""`` when no valid response was produced.
        """
        parsed_response = None
        # Initialize the tool_observation with tool_memory
        tool_observation = [self.tool_memory.to_string()]
        while True:
            prompt = self._fill_prompt_template(env_description, tool_observation)

            # See step(): never carry a stale AgentAction into a new round.
            parsed_response = None
            for _ in range(self.max_retry):
                try:
                    response = await self.llm.agenerate_response(prompt)
                    candidate = self.output_parser.parse(response)
                    if isinstance(candidate, AgentAction):
                        # If the response is an action, call the tool
                        # and append the observation to tool_observation
                        observation = await self._acall_tool(candidate)
                        tool_observation.append(
                            self._format_observation(candidate, observation)
                        )
                    parsed_response = candidate
                    break
                except (Exception, ToolNotExistError) as e:
                    # Not catching BaseException lets task cancellation
                    # (a BaseException on Python 3.8+) and Ctrl-C through.
                    logging.error(e)
                    logging.warning("Retrying...")
                    continue
            if parsed_response is None or isinstance(parsed_response, AgentFinish):
                break

        return self._finish_turn(parsed_response, tool_observation)

    def _format_observation(self, action, observation: str) -> str:
        """Render one tool call and its result as a prompt entry."""
        return action.log.strip() + f"\nObservation: {observation.strip()}"

    def _finish_turn(self, parsed_response, tool_observation: List[str]) -> Message:
        """Log a failed turn, persist tool observations, build the reply."""
        if parsed_response is None:
            logging.error(f"{self.name} failed to generate valid response.")

        self._update_tool_memory(tool_observation)

        return Message(
            content=""
            if parsed_response is None
            else parsed_response.return_values["output"],
            sender=self.name,
            receiver=self.get_receiver(),
        )

    def _call_tool(self, response: NamedTuple) -> str:
        """Call the tool named by ``response.tool`` and return its output.

        Raises:
            ToolNotExistError: If no tool with that name is configured.
        """
        name_to_tool = {tool.name: tool for tool in self.tools}
        if response.tool not in name_to_tool:
            raise ToolNotExistError(response.tool)
        tool = name_to_tool[response.tool]
        observation = tool.run(response.tool_input, verbose=self.verbose)
        return observation

    async def _acall_tool(self, response: NamedTuple) -> str:
        """Asynchronously call the tool named by ``response.tool``.

        Raises:
            ToolNotExistError: If no tool with that name is configured.
        """
        name_to_tool = {tool.name: tool for tool in self.tools}
        if response.tool not in name_to_tool:
            raise ToolNotExistError(response.tool)
        tool = name_to_tool[response.tool]
        observation = await tool.arun(response.tool_input, verbose=self.verbose)
        return observation

    def _update_tool_memory(self, tool_observation: List[str]):
        """Store this turn's tool observations in ``tool_memory``.

        The first entry is the tool memory as it was at the start of the
        turn, so only the entries after it are new.
        """
        if len(tool_observation) == 1:
            # If no tool is called this turn, do nothing
            return
        messages = [
            Message(content=observation) for observation in tool_observation[1:]
        ]
        self.tool_memory.add_message(messages)

    def _fill_prompt_template(
        self,
        env_description: str = "",
        tool_observation: Optional[List[str]] = None,
    ) -> str:
        """Fill the placeholders in the prompt template

        In the tool agent, these placeholders are supported:
        - ${agent_name}: the name of the agent
        - ${env_description}: the description of the environment
        - ${role_description}: the description of the role of the agent
        - ${chat_history}: the chat history of the agent
        - ${tools}: the list of tools and their usage
        - ${tool_names}: the list of tool names
        - ${tool_observation}: the observation of the tool in this turn

        Placeholders that are not listed are left untouched
        (``safe_substitute``).
        """
        tools = "\n".join([f"> {tool.name}: {tool.description}" for tool in self.tools])
        # Collapse doubled braces in tool descriptions to single braces.
        tools = tools.replace("{{", "{").replace("}}", "}")
        tool_names = ", ".join([tool.name for tool in self.tools])
        input_arguments = {
            "agent_name": self.name,
            "env_description": env_description,
            "role_description": self.role_description,
            "chat_history": self.memory.to_string(add_sender_prefix=True),
            "tools": tools,
            "tool_names": tool_names,
            "tool_observation": "\n".join(tool_observation or []),
        }
        return Template(self.prompt_template).safe_substitute(input_arguments)

    def add_message_to_memory(self, messages: List[Message]) -> None:
        """Append ``messages`` to the agent's chat memory."""
        self.memory.add_message(messages)

    def reset(self) -> None:
        """Reset the agent"""
        # NOTE(review): only the chat memory is cleared; ``tool_memory`` is
        # left as-is -- confirm whether that is intended.
        self.memory.reset()
        # TODO: reset receiver
|
agentverse/agents/tasksolving_agent/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .critic import CriticAgent
|
2 |
+
from .evaluator import EvaluatorAgent
|
3 |
+
from .executor import ExecutorAgent
|
4 |
+
from .manager import ManagerAgent
|
5 |
+
from .role_assigner import RoleAssignerAgent
|
6 |
+
from .solver import SolverAgent
|
agentverse/agents/tasksolving_agent/critic.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import json
|
4 |
+
from colorama import Fore
|
5 |
+
from agentverse.logging import get_logger
|
6 |
+
import bdb
|
7 |
+
from string import Template
|
8 |
+
from typing import TYPE_CHECKING, List, Union
|
9 |
+
|
10 |
+
from agentverse.message import Message
|
11 |
+
|
12 |
+
from agentverse.agents import agent_registry
|
13 |
+
from agentverse.agents.base import BaseAgent
|
14 |
+
from agentverse.utils import AgentCriticism
|
15 |
+
from agentverse.message import CriticMessage
|
16 |
+
|
17 |
+
logger = get_logger()
|
18 |
+
|
19 |
+
|
20 |
+
@agent_registry.register("critic")
class CriticAgent(BaseAgent):
    """Agent that criticizes a preliminary solution.

    An optional ``tool_config`` keyword argument names a JSON file whose
    ``tools_json`` list is loaded into ``tools``, ``tool_names`` and
    ``tool_descriptions``.
    """

    # Passed as ``start_index=-max_history`` when reading the memory.
    max_history: int = 3
    # Tool specifications loaded from the tool config file (if any).
    tools: List[dict] = []
    tool_names: List[str] = []
    # One "- name: description" line per tool.
    tool_descriptions: str = ""

    def __init__(self, *args, **kwargs):
        """Create the agent, loading the optional tool config file.

        A failure to read or parse the file is logged and otherwise
        ignored, leaving the tool fields at their defaults.
        """
        tool_config_file = kwargs.pop("tool_config", "")
        if tool_config_file != "":
            try:
                with open(tool_config_file, "r") as f:
                    tools_dict = json.load(f)
                tools = tools_dict["tools_json"]
                tool_names = [t["name"] for t in tools]
                tool_descriptions = "\n".join(
                    [f"- {t['name']}: " + t["description"] for t in tools]
                )
                # dict.update takes at most one positional argument, so
                # the values must be passed as keywords.
                kwargs.update(
                    tools=tools,
                    tool_names=tool_names,
                    tool_descriptions=tool_descriptions,
                )
            except Exception as e:
                logger.error(e)
                logger.warn("Failed to load tool config file.")
        super().__init__(
            *args,
            **kwargs,
        )

    def step(self, env_description: str = "") -> CriticMessage:
        """Synchronous step; not implemented (returns ``None``)."""
        pass

    async def astep(
        self,
        preliminary_solution: str,
        advice: str = "No advice yet.",
        task_description: str = "",
        all_roles: str = "",
        **kwargs,
    ) -> CriticMessage:
        """Asynchronously criticize ``preliminary_solution``.

        The LLM is queried up to ``max_retry`` times until its response
        can be parsed.

        Returns:
            A ``CriticMessage`` carrying the parsed criticism and agreement
            flag, or empty content with ``is_agree=False`` when every
            attempt failed.
        """
        logger.debug("", self.name, Fore.MAGENTA)
        prepend_prompt, append_prompt = self.get_all_prompts(
            preliminary_solution=preliminary_solution,
            advice=advice,
            task_description=task_description,
            role_description=self.role_description,
            agent_name=self.name,
            all_roles=all_roles,
            tool_descriptions=self.tool_descriptions,
        )
        history = self.memory.to_messages(self.name, start_index=-self.max_history)
        parsed_response: Union[AgentCriticism, None] = None
        for _ in range(self.max_retry):
            try:
                response = await self.llm.agenerate_response(
                    prepend_prompt, history, append_prompt
                )
                parsed_response = self.output_parser.parse(response)
                break
            except (KeyboardInterrupt, bdb.BdbQuit):
                # Never swallow a user interrupt or a debugger quit.
                raise
            except Exception as e:
                logger.error(e)
                logger.warn("Retrying...")
                continue

        if parsed_response is None:
            logger.error(f"{self.name} failed to generate valid response.")

        message = CriticMessage(
            content=parsed_response.criticism if parsed_response is not None else "",
            sender=self.name,
            sender_agent=self,
            is_agree=parsed_response.is_agree if parsed_response is not None else False,
        )
        return message

    def _fill_prompt_template(
        self, preliminary_solution: str, advice: str, task_description: str
    ) -> str:
        """Fill the placeholders in the prompt template

        In the critic agent, four placeholders are supported:
        - ${role_description}
        - ${task_description}
        - ${preliminary_solution}
        - ${advice}

        Other placeholders are left untouched (``safe_substitute``).
        """
        input_arguments = {
            "role_description": self.role_description,
            "task_description": task_description,
            "preliminary_solution": preliminary_solution,
            "advice": advice,
        }
        return Template(self.prompt_template).safe_substitute(input_arguments)

    def add_message_to_memory(self, messages: List[Message]) -> None:
        """Append ``messages`` to the agent's memory."""
        self.memory.add_message(messages)

    def reset(self) -> None:
        """Reset the agent"""
        self.memory.reset()
        # TODO: reset receiver
|
agentverse/agents/tasksolving_agent/evaluator.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import asyncio
|
4 |
+
from colorama import Fore
|
5 |
+
|
6 |
+
from agentverse.logging import get_logger
|
7 |
+
import bdb
|
8 |
+
from string import Template
|
9 |
+
from typing import TYPE_CHECKING, List, Tuple
|
10 |
+
|
11 |
+
from agentverse.message import EvaluatorMessage, Message
|
12 |
+
|
13 |
+
from agentverse.agents import agent_registry
|
14 |
+
from agentverse.agents.base import BaseAgent
|
15 |
+
|
16 |
+
|
17 |
+
logger = get_logger()
|
18 |
+
|
19 |
+
|
20 |
+
@agent_registry.register("evaluator")
class EvaluatorAgent(BaseAgent):
    """Agent that scores a solution and gives advice on it."""

    def step(
        self,
        solution: str,
        result: str,
        task_description: str,
        all_role_description: str,
    ) -> EvaluatorMessage:
        """Evaluate ``solution`` and return the score and advice.

        The LLM is queried up to ``max_retry`` times until its response
        can be parsed. The parsed response is indexed as
        ``(score, advice)``; when every attempt fails the message carries
        a score of ``0`` and empty advice.
        """
        logger.debug("", self.name, Fore.MAGENTA)
        prompts = self.get_all_prompts(
            solution=solution,
            result=result,
            task_description=task_description,
            all_role_description=all_role_description,
        )
        prepend_prompt, append_prompt = prompts
        history = self.memory.to_messages(self.name)

        parsed_response = None
        for _ in range(self.max_retry):
            try:
                response = self.llm.generate_response(
                    prepend_prompt, history, append_prompt
                )
                parsed_response = self.output_parser.parse(response)
            except (KeyboardInterrupt, bdb.BdbQuit):
                # User interrupts and debugger quits must propagate.
                raise
            except Exception as e:
                logger.error(e)
                logger.warn("Retrying...")
            else:
                break

        if parsed_response is None:
            logger.error(f"{self.name} failed to generate valid response.")
            score, advice = 0, ""
        else:
            score = parsed_response[0]
            advice = parsed_response[1]

        return EvaluatorMessage(
            sender=self.name,
            sender_agent=self,
            score=score,
            advice=advice,
        )

    async def astep(self, solution: str) -> EvaluatorMessage:
        """Asynchronous version of step; not implemented (returns ``None``)."""
        return None

    def _fill_prompt_template(self, solution: str, task_description: str) -> str:
        """Fill the placeholders in the prompt template

        In the evaluator agent, two placeholders are supported:
        - ${task_description}
        - ${solution}

        Other placeholders are left untouched (``safe_substitute``).
        """
        template = Template(self.prompt_template)
        return template.safe_substitute(
            {
                "task_description": task_description,
                "solution": solution,
            }
        )

    def add_message_to_memory(self, messages: List[Message]) -> None:
        """Append ``messages`` to the agent's memory."""
        self.memory.add_message(messages)

    def reset(self) -> None:
        """Reset the agent"""
        self.memory.reset()
        # TODO: reset receiver
|
agentverse/agents/tasksolving_agent/executor.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from agentverse.logging import get_logger
|
4 |
+
from colorama import Fore
|
5 |
+
import bdb
|
6 |
+
from string import Template
|
7 |
+
from typing import TYPE_CHECKING, List, Any
|
8 |
+
|
9 |
+
from agentverse.message import ExecutorMessage, Message, SolverMessage
|
10 |
+
from agentverse.utils import AgentFinish, AgentAction
|
11 |
+
|
12 |
+
from agentverse.agents import agent_registry
|
13 |
+
from agentverse.agents.base import BaseAgent
|
14 |
+
import requests
|
15 |
+
|
16 |
+
logger = get_logger()
|
17 |
+
|
18 |
+
|
19 |
+
@agent_registry.register("executor")
class ExecutorAgent(BaseAgent):
    """Agent that executes a solution, optionally by requesting a tool call."""

    # Passed as ``start_index=-max_history`` when reading the memory.
    max_history: int = 5

    def step(
        self, task_description: str, solution: str, tools: List[dict] = [], **kwargs
    ) -> ExecutorMessage:
        """Run one execution step synchronously.

        Args:
            task_description: Description of the task being solved.
            solution: The solution to execute.
            tools: Tool specifications forwarded to the LLM; this method
                only passes the list on and does not modify it.
            **kwargs: Extra values forwarded to ``get_all_prompts``.

        Returns:
            An ``ExecutorMessage``; when no valid response could be
            generated it describes an empty tool action (same fallback as
            :meth:`astep`).

        Raises:
            ValueError: If the parser returns something other than
                ``AgentFinish`` or ``AgentAction``.
        """
        logger.debug("", self.name, Fore.MAGENTA)
        prepend_prompt, append_prompt = self.get_all_prompts(
            task_description=task_description,
            solution=solution,
            agent_name=self.name,
            **kwargs,
        )

        history = self.memory.to_messages(self.name, start_index=-self.max_history)
        parsed_response = None
        for _ in range(self.max_retry):
            try:
                response = self.llm.generate_response(
                    prepend_prompt, history, append_prompt, tools
                )
                parsed_response = self.output_parser.parse(response)
                break
            except (KeyboardInterrupt, bdb.BdbQuit):
                # Never swallow a user interrupt or a debugger quit.
                raise
            except Exception as e:
                logger.error(e)
                logger.warn("Retrying...")
                continue

        return self._build_message(parsed_response)

    async def astep(
        self, task_description: str, solution: str, tools: List[dict] = [], **kwargs
    ) -> ExecutorMessage:
        """Asynchronous version of :meth:`step`; same arguments and result."""
        logger.debug("", self.name, Fore.MAGENTA)
        prepend_prompt, append_prompt = self.get_all_prompts(
            task_description=task_description,
            solution=solution,
            agent_name=self.name,
            **kwargs,
        )

        history = self.memory.to_messages(self.name, start_index=-self.max_history)
        parsed_response = None
        for _ in range(self.max_retry):
            try:
                response = await self.llm.agenerate_response(
                    prepend_prompt, history, append_prompt, tools
                )
                parsed_response = self.output_parser.parse(response)
                break
            except (KeyboardInterrupt, bdb.BdbQuit):
                raise
            except Exception as e:
                logger.error(e)
                logger.warn("Retrying...")
                continue

        return self._build_message(parsed_response)

    def _build_message(self, parsed_response) -> ExecutorMessage:
        """Convert a parsed LLM response (or ``None``) into a message."""
        if parsed_response is None:
            logger.error(f"{self.name} failed to generate valid response.")
            # Degrade to an empty tool action instead of raising, so the
            # sync and async paths behave the same after failed retries.
            parsed_response = AgentAction(tool="", tool_input="", log="")
        if isinstance(parsed_response, AgentFinish):
            return ExecutorMessage(
                content=parsed_response.return_values["output"],
                sender=self.name,
                sender_agent=self,
            )
        if isinstance(parsed_response, AgentAction):
            return ExecutorMessage(
                content=parsed_response.log,
                sender=self.name,
                sender_agent=self,
                tool_name=parsed_response.tool,
                tool_input=parsed_response.tool_input,
            )
        # Implicit concatenation keeps source indentation out of the text.
        raise ValueError(
            f"Error response type: {type(parsed_response)}. Only support "
            "AgentFinish and AgentAction. Modify your output parser."
        )

    def add_message_to_memory(self, messages: List[Message]) -> None:
        """Append ``messages`` to the agent's memory."""
        self.memory.add_message(messages)

    def reset(self) -> None:
        """Reset the agent"""
        self.memory.reset()
        # TODO: reset receiver
|
agentverse/agents/tasksolving_agent/manager.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import asyncio
|
4 |
+
from colorama import Fore
|
5 |
+
|
6 |
+
from agentverse.logging import get_logger
|
7 |
+
import bdb
|
8 |
+
from string import Template
|
9 |
+
from typing import TYPE_CHECKING, List, Tuple
|
10 |
+
|
11 |
+
from agentverse.message import Message
|
12 |
+
|
13 |
+
from agentverse.agents import agent_registry
|
14 |
+
from agentverse.agents.base import BaseAgent
|
15 |
+
from agentverse.utils import AgentCriticism
|
16 |
+
|
17 |
+
import random
|
18 |
+
from rapidfuzz import fuzz
|
19 |
+
|
20 |
+
|
21 |
+
logger = get_logger()
|
22 |
+
|
23 |
+
|
24 |
+
@agent_registry.register("manager")
class ManagerAgent(BaseAgent):
    """Agent that selects one critic opinion among several candidates."""

    prompt_template: str

    def step(
        self,
        former_solution: str,
        candidate_critic_opinions: List[AgentCriticism],
        advice: str,
        task_description: str = "",
        previous_sentence: str = "",
    ) -> Message:
        """Pick the candidate whose sender best matches the LLM's answer.

        The LLM's reply is compared with each candidate's ``sender`` using
        ``fuzz.ratio``; the first candidate with the highest score wins.

        Args:
            former_solution: The solution the critics commented on.
            candidate_critic_opinions: Candidates to choose from. Each
                item is read through ``sender``, ``sender_agent`` and
                ``content``.
            advice: Advice text for the prompt.
            task_description: Description of the task.
            previous_sentence: Text for ``${previous_sentence}``.

        Returns:
            The selected candidate. If the LLM cannot be queried within
            ``max_retry`` attempts, a randomly chosen candidate is
            returned and an error is logged.

        Raises:
            ValueError: If ``candidate_critic_opinions`` is empty.
        """
        logger.debug("", self.name, Fore.MAGENTA)

        if not candidate_critic_opinions:
            # Nothing to select; fail clearly before spending LLM calls.
            raise ValueError("candidate_critic_opinions must not be empty.")

        prompt = self._fill_prompt_template(
            former_solution,
            candidate_critic_opinions,
            advice,
            task_description,
            previous_sentence,
        )

        logger.debug(f"Prompt:\n{prompt}", "Manager", Fore.CYAN)
        candidate_critic_opinion = None
        for _ in range(self.max_retry):
            try:
                selected_role_description = self.llm.generate_response(prompt).content
                candidate_score_list = [
                    fuzz.ratio(candidate.sender, selected_role_description)
                    for candidate in candidate_critic_opinions
                ]
                # max() returns the first maximal index, i.e. the same
                # candidate as list.index(max(...)).
                selected_index = max(
                    range(len(candidate_score_list)),
                    key=candidate_score_list.__getitem__,
                )
                candidate_critic_opinion = candidate_critic_opinions[selected_index]
                break
            except (KeyboardInterrupt, bdb.BdbQuit):
                # Never swallow a user interrupt or a debugger quit.
                raise
            except Exception as e:
                logger.error(e)
                logger.warn("Retrying...")
                continue

        if candidate_critic_opinion is None:
            # Every attempt failed: fall back to a random pick instead of
            # crashing on an unbound local variable.
            logger.error(f"{self.name} failed to generate valid response.")
            logger.warn("Falling back to a randomly selected candidate.")
            candidate_critic_opinion = random.choice(candidate_critic_opinions)
        return candidate_critic_opinion

    async def astep(self, env_description: str = "") -> Message:
        """Asynchronous version of step; not implemented (returns ``None``)."""
        pass

    def _fill_prompt_template(
        self,
        former_solution: str,
        candidate_critic_opinions: List[AgentCriticism],
        advice: str,
        task_description: str,
        previous_sentence: str,
    ) -> str:
        """Fill the placeholders in the prompt template

        In the manager agent, five placeholders are supported:
        - ${task_description}
        - ${former_solution}
        - ${previous_sentence}
        - ${critic_opinions}
        - ${advice}

        Other placeholders are left untouched (``safe_substitute``).
        """
        input_arguments = {
            "task_description": task_description,
            "former_solution": former_solution,
            "previous_sentence": previous_sentence,
            "critic_opinions": "\n".join(
                [
                    f"Role: {critic.sender}. {critic.sender_agent.role_description} said: {critic.content}"
                    for critic in candidate_critic_opinions
                ]
            ),
            "advice": advice,
        }

        template = Template(self.prompt_template)
        return template.safe_substitute(input_arguments)

    def add_message_to_memory(self, messages: List[Message]) -> None:
        """Append ``messages`` to the agent's memory."""
        self.memory.add_message(messages)

    def reset(self) -> None:
        """Reset the agent"""
        self.memory.reset()
        # TODO: reset receiver
|
agentverse/agents/tasksolving_agent/role_assigner.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import asyncio
|
4 |
+
from colorama import Fore
|
5 |
+
|
6 |
+
from agentverse.logging import get_logger
|
7 |
+
import bdb
|
8 |
+
from string import Template
|
9 |
+
from typing import TYPE_CHECKING, List
|
10 |
+
|
11 |
+
from agentverse.message import RoleAssignerMessage, Message
|
12 |
+
|
13 |
+
from agentverse.agents import agent_registry
|
14 |
+
from agentverse.agents.base import BaseAgent
|
15 |
+
|
16 |
+
|
17 |
+
logger = get_logger()
|
18 |
+
|
19 |
+
|
20 |
+
@agent_registry.register("role_assigner")
class RoleAssignerAgent(BaseAgent):
    """Agent that asks the LLM for the roles of the critic agents."""

    def step(
        self, advice: str, task_description: str, cnt_critic_agents: int
    ) -> RoleAssignerMessage:
        """Generate role descriptions for ``cnt_critic_agents`` critics.

        The LLM is queried up to ``max_retry`` times. An attempt is also
        retried when the parsed response has fewer entries than
        ``cnt_critic_agents``. The last parsed response (or ``None`` if
        nothing could be parsed) becomes the message content.
        """
        logger.debug("", self.name, Fore.MAGENTA)
        prompts = self.get_all_prompts(
            advice=advice,
            task_description=task_description,
            cnt_critic_agents=cnt_critic_agents,
        )
        prepend_prompt, append_prompt = prompts
        history = self.memory.to_messages(self.name)

        parsed_response = None
        for _ in range(self.max_retry):
            try:
                response = self.llm.generate_response(
                    prepend_prompt, history, append_prompt
                )
                parsed_response = self.output_parser.parse(response)
                if not len(parsed_response) < cnt_critic_agents:
                    # Enough roles were generated; stop retrying.
                    break
                logger.warn(
                    f"Number of generate roles ({len(parsed_response)}) and number of group members ({cnt_critic_agents}) do not match."
                )
                logger.warn("Retrying...")
            except (KeyboardInterrupt, bdb.BdbQuit):
                # User interrupts and debugger quits must propagate.
                raise
            except Exception as e:
                logger.error(e)
                logger.warn("Retrying...")

        if parsed_response is None:
            logger.error(f"{self.name} failed to generate valid response.")

        return RoleAssignerMessage(
            content=parsed_response, sender=self.name, sender_agent=self
        )

    async def astep(self, env_description: str = "") -> RoleAssignerMessage:
        """Asynchronous version of step; not implemented (returns ``None``)."""
        return None

    def _fill_prompt_template(
        self, advice, task_description: str, cnt_critic_agents: int
    ) -> str:
        """Fill the placeholders in the prompt template

        In the role_assigner agent, three placeholders are supported:
        - ${task_description}
        - ${cnt_critic_agents}
        - ${advice}

        Other placeholders are left untouched (``safe_substitute``).
        """
        template = Template(self.prompt_template)
        return template.safe_substitute(
            {
                "task_description": task_description,
                "cnt_critic_agents": cnt_critic_agents,
                "advice": advice,
            }
        )

    def add_message_to_memory(self, messages: List[Message]) -> None:
        """Append ``messages`` to the agent's memory."""
        self.memory.add_message(messages)

    def reset(self) -> None:
        """Reset the agent"""
        self.memory.reset()
        # TODO: reset receiver
|
agentverse/agents/tasksolving_agent/solver.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import asyncio
|
4 |
+
from colorama import Fore
|
5 |
+
|
6 |
+
from agentverse.logging import get_logger
|
7 |
+
import bdb
|
8 |
+
from string import Template
|
9 |
+
from typing import TYPE_CHECKING, List, Tuple
|
10 |
+
|
11 |
+
# from agentverse.environments import PipelineEnvironment
|
12 |
+
from agentverse.message import SolverMessage, Message, CriticMessage
|
13 |
+
|
14 |
+
from agentverse.agents import agent_registry
|
15 |
+
from agentverse.agents.base import BaseAgent
|
16 |
+
from agentverse.utils import AgentCriticism
|
17 |
+
|
18 |
+
|
19 |
+
logger = get_logger()
|
20 |
+
|
21 |
+
|
22 |
+
@agent_registry.register("solver")
|
23 |
+
class SolverAgent(BaseAgent):
|
24 |
+
max_history: int = 3
|
25 |
+
|
26 |
+
def step(
|
27 |
+
self, former_solution: str, advice: str, task_description: str = "", **kwargs
|
28 |
+
) -> SolverMessage:
|
29 |
+
logger.debug("", self.name, Fore.MAGENTA)
|
30 |
+
# prompt = self._fill_prompt_template(
|
31 |
+
# former_solution, critic_opinions, advice, task_description
|
32 |
+
# )
|
33 |
+
prepend_prompt, append_prompt = self.get_all_prompts(
|
34 |
+
former_solution=former_solution,
|
35 |
+
task_description=task_description,
|
36 |
+
advice=advice,
|
37 |
+
role_description=self.role_description,
|
38 |
+
**kwargs,
|
39 |
+
)
|
40 |
+
history = self.memory.to_messages(self.name, start_index=-self.max_history)
|
41 |
+
parsed_response = None
|
42 |
+
for i in range(self.max_retry):
|
43 |
+
try:
|
44 |
+
response = self.llm.generate_response(
|
45 |
+
prepend_prompt, history, append_prompt
|
46 |
+
)
|
47 |
+
parsed_response = self.output_parser.parse(response)
|
48 |
+
break
|
49 |
+
except (KeyboardInterrupt, bdb.BdbQuit):
|
50 |
+
raise
|
51 |
+
except Exception as e:
|
52 |
+
logger.error(e)
|
53 |
+
logger.warn("Retrying...")
|
54 |
+
continue
|
55 |
+
|
56 |
+
if parsed_response is None:
|
57 |
+
logger.error(f"{self.name} failed to generate valid response.")
|
58 |
+
|
59 |
+
message = SolverMessage(
|
60 |
+
content=""
|
61 |
+
if parsed_response is None
|
62 |
+
else parsed_response.return_values["output"],
|
63 |
+
sender=self.name,
|
64 |
+
receiver=self.get_receiver(),
|
65 |
+
)
|
66 |
+
return message
|
67 |
+
|
68 |
+
async def astep(self, env_description: str = "") -> SolverMessage:
|
69 |
+
"""Asynchronous version of step"""
|
70 |
+
pass
|
71 |
+
|
72 |
+
def _fill_prompt_template(
    self,
    former_solution: str,
    critic_opinions: List[AgentCriticism],
    advice: str,
    task_description: str,
) -> str:
    """Substitute the runtime values into ``self.prompt_template``.

    Four placeholders are supplied:
        - ${task_description}
        - ${former_solution}
        - ${critic_opinions}: one "<role description> said: <criticism>"
          line per entry of ``critic_opinions``, joined with newlines
        - ${advice}

    The text is produced with ``Template.safe_substitute`` (presumably
    ``string.Template`` -- confirm against the file's imports), so any other
    placeholder in the template should be left as-is instead of raising.
    """
    opinion_lines = [
        f"{opinion.sender_agent.role_description} said: {opinion.criticism}"
        for opinion in critic_opinions
    ]
    substitutions = {
        "task_description": task_description,
        "former_solution": former_solution,
        "critic_opinions": "\n".join(opinion_lines),
        "advice": advice,
    }
    return Template(self.prompt_template).safe_substitute(substitutions)
|
103 |
+
|
104 |
+
def add_message_to_memory(self, messages: List[Message]) -> None:
    """Hand ``messages`` to this agent's memory via ``memory.add_message``."""
    memory = self.memory
    memory.add_message(messages)
|
106 |
+
|
107 |
+
def reset(self) -> None:
    """Reset the agent by resetting its memory.

    Only the memory is cleared here; the receiver is left unchanged.
    """
    memory = self.memory
    memory.reset()
    # TODO: reset receiver
|
agentverse/agentverse.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import logging
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
# from agentverse.agents import Agent
|
6 |
+
from agentverse.agents.conversation_agent import BaseAgent
|
7 |
+
from agentverse.environments import BaseEnvironment
|
8 |
+
from agentverse.initialization import load_agent, load_environment, prepare_task_config
|
9 |
+
|
10 |
+
# Configure the root logger as soon as this module is imported:
# timestamp - level - logger name - message, INFO and above.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Raise the threshold of the "openai" logger so only warnings and errors show.
openai_logger = logging.getLogger("openai")
openai_logger.setLevel(logging.WARNING)
|
18 |
+
|
19 |
+
|
20 |
+
class AgentVerse:
    """Bundle a list of agents with the environment that drives them."""

    def __init__(self, agents: List[BaseAgent], environment: BaseEnvironment):
        self.agents = agents
        self.environment = environment

    @classmethod
    def from_task(cls, task: str, tasks_dir: str):
        """Build an AgentVerse from a task name.

        ``task`` and ``tasks_dir`` are passed to ``prepare_task_config``; the
        resulting config must provide an ``"agents"`` list and an
        ``"environment"`` entry. The task name should correspond to a
        directory in the tasks directory, whose yaml file holds the
        configuration.
        """
        task_config = prepare_task_config(task, tasks_dir)

        # One agent per entry of the "agents" section.
        agents = [load_agent(agent_config) for agent_config in task_config["agents"]]

        # The environment config receives the freshly built agents.
        env_config = task_config["environment"]
        env_config["agents"] = agents

        return cls(agents, load_environment(env_config))

    def run(self):
        """Run the environment from scratch until it is done."""
        environment = self.environment
        environment.reset()
        while not environment.is_done():
            # Each step is an async call driven to completion on its own.
            asyncio.run(environment.step())

    def reset(self):
        """Reset the environment, then every agent."""
        self.environment.reset()
        for agent in self.agents:
            agent.reset()

    def next(self, *args, **kwargs):
        """Run the environment for one step and return the return message."""
        return asyncio.run(self.environment.step(*args, **kwargs))

    def update_state(self, *args, **kwargs):
        """Forward all arguments to the environment's ``update_state``.

        Nothing is returned.
        """
        self.environment.update_state(*args, **kwargs)
|
agentverse/demo.py
ADDED
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import itertools
|
3 |
+
import json
|
4 |
+
from typing import Dict, List, Tuple
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
from agentverse.agentverse import AgentVerse
|
10 |
+
from agentverse.message import Message
|
11 |
+
|
12 |
+
|
13 |
+
def cover_img(background, img, place: Tuple[int, int]):
    """
    Overlay ``img`` onto ``background`` in place, skipping transparent pixels.

    A pixel of ``img`` is copied only when its fourth channel (index 3) is
    non-zero; the first three channels are what gets written.

    The part of the overlay that would fall outside ``background`` is clipped
    instead of raising ``IndexError`` (or wrapping around for negative
    offsets), so an overlay hanging over an edge is drawn partially.

    :param background: background image, indexed as [row, column, channel]
    :param img: overlay image with at least four channels
    :param place: the top-left (row, column) coordinate of the target location
    """
    top, left = place
    height, width = img.shape[:2]
    back_h, back_w = background.shape[:2]

    # Window of the background that the overlay actually lands on.
    row_start, col_start = max(top, 0), max(left, 0)
    row_end, col_end = min(top + height, back_h), min(left + width, back_w)
    if row_start >= row_end or col_start >= col_end:
        return  # overlay lies completely outside the background

    # Matching window of the overlay, then one masked copy instead of a
    # Python-level loop over every pixel.
    patch = img[row_start - top : row_end - top, col_start - left : col_end - left]
    opaque = patch[:, :, 3] != 0
    region = background[row_start:row_end, col_start:col_end]
    region[opaque] = patch[:, :, :3][opaque]
|
25 |
+
|
26 |
+
|
27 |
+
class UI:
    """
    The Gradio front end of an AgentVerse task.

    Keeps the backend ``AgentVerse`` instance, the messages shown so far and
    the autoplay state, draws the scene image with OpenCV and wires the
    buttons to the backend in :meth:`launch`.
    """

    # Keyword groups highlighted in db_diag solutions. Group ``i`` switches on
    # ``solution_status[i]``, i.e. the i-th solution button created in launch().
    _SOLUTION_KEYWORDS = (
        ("query", "queries"),
        ("join",),
        ("index",),
        ("system configuration",),
        ("monitor", "Monitor", "Investigate"),
    )

    def __init__(self, task: str, tasks_dir=None):
        """
        init a UI.

        :param task: task name; also selects the image set that is drawn
        :param tasks_dir: optional tasks directory forwarded to
            ``AgentVerse.from_task``. When omitted, ``from_task`` is called
            with the task name only, which works only if the backend supplies
            its own default for ``tasks_dir``.
        """
        self.messages = []
        self.task = task
        if tasks_dir is None:
            self.backend = AgentVerse.from_task(task)
        else:
            self.backend = AgentVerse.from_task(task, tasks_dir)
        self.turns_remain = 0
        # agent name -> position in the backend's agent list
        self.agent_id = {
            agent.name: idx for idx, agent in enumerate(self.backend.agents)
        }
        # Every agent except one counts as a "student"
        # (presumably agent 0 is the teacher -- confirm against the task configs).
        self.stu_num = len(self.agent_id) - 1
        self.autoplay = False
        self.image_now = None
        self.text_now = None
        self.tot_solutions = 5
        self.solution_status = [False] * self.tot_solutions

    def _solution_updates(self):
        """Visibility updates for the five solution buttons and their box."""
        return (
            *[gr.Button.update(visible=shown) for shown in self.solution_status],
            gr.Box.update(visible=any(self.solution_status)),
        )

    def get_avatar(self, idx):
        """Return the avatar image for ``idx`` as a base64 PNG data URI."""
        if idx == -1:
            path = "./imgs/db_diag/-1.png"
        elif self.task == "prisoner_dilemma":
            path = f"./imgs/prison/{idx}.png"
        elif self.task == "db_diag":
            path = f"./imgs/db_diag/{idx}.png"
        elif "sde" in self.task:
            path = f"./imgs/sde/{idx}.png"
        else:
            path = f"./imgs/{idx}.png"
        img = cv2.imread(path)
        # tobytes() replaces the deprecated ndarray.tostring().
        png_bytes = cv2.imencode(".png", img)[1].tobytes()
        return "data:image/png;base64," + base64.b64encode(png_bytes).decode("utf-8")

    def stop_autoplay(self):
        """Clear the autoplay flag; returns updates for Next / Stop / Start."""
        self.autoplay = False
        return (
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=False),
        )

    def start_autoplay(self):
        """Generator: keep stepping the backend while autoplay is on and turns remain."""
        self.autoplay = True
        # First emit the current frame with only the Stop button enabled.
        yield (
            self.image_now,
            self.text_now,
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=True),
            gr.Button.update(interactive=False),
            *self._solution_updates(),
        )

        while self.autoplay and self.turns_remain > 0:
            outputs = self.gen_output()
            self.image_now, self.text_now = outputs

            manual_allowed = not self.autoplay and self.turns_remain > 0
            can_stop = self.autoplay and self.turns_remain > 0
            yield (
                *outputs,
                gr.Button.update(interactive=manual_allowed),
                gr.Button.update(interactive=can_stop),
                gr.Button.update(interactive=manual_allowed),
                *self._solution_updates(),
            )

    def delay_gen_output(self):
        """Generator for the Next button: disable the buttons, step once, re-enable."""
        yield (
            self.image_now,
            self.text_now,
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=False),
            *self._solution_updates(),
        )

        self.image_now, self.text_now = self.gen_output()

        yield (
            self.image_now,
            self.text_now,
            gr.Button.update(interactive=self.turns_remain > 0),
            gr.Button.update(interactive=self.turns_remain > 0),
            *self._solution_updates(),
        )

    def delay_reset(self):
        """Handler of the Reset button: stop autoplay and start over."""
        self.autoplay = False
        self.image_now, self.text_now = self.reset()
        return (
            self.image_now,
            self.text_now,
            gr.Button.update(interactive=True),
            gr.Button.update(interactive=False),
            gr.Button.update(interactive=True),
            *self._solution_updates(),
        )

    def reset(self, stu_num=0):
        """
        reset the backend and generate a new empty image

        :param stu_num: validated to lie in [0, 30] but not yet forwarded
            to the backend
        :return: [empty image, empty message]
        """
        if not 0 <= stu_num <= 30:
            raise gr.Error("the number of students must be between 0 and 30.")

        # TODO: add a way to assign the number of agents in the backend.
        # self.backend.reset(stu_num)
        # self.stu_num = stu_num

        # TODO: pass the parameters to reset
        self.backend.reset()
        self.turns_remain = self.backend.environment.max_turns

        if self.task == "prisoner_dilemma":
            background = cv2.imread("./imgs/prison/case_1.png")
        elif self.task == "db_diag":
            background = cv2.imread("./imgs/db_diag/background.png")
        elif "sde" in self.task:
            background = cv2.imread("./imgs/sde/background.png")
        else:
            background = cv2.imread("./imgs/background.png")
            back_h, back_w, _ = background.shape
            stu_cnt = 0
            # Fill the seat grid: avatars for existing students, "empty" elsewhere.
            for h_begin, w_begin in itertools.product(
                range(800, back_h, 300), range(135, back_w - 200, 200)
            ):
                stu_cnt += 1
                img = cv2.imread(
                    f"./imgs/{(stu_cnt - 1) % 11 + 1 if stu_cnt <= self.stu_num else 'empty'}.png",
                    cv2.IMREAD_UNCHANGED,
                )
                cover_img(
                    background,
                    img,
                    (h_begin - 30 if img.shape[0] > 190 else h_begin, w_begin),
                )
        self.messages = []
        self.solution_status = [False] * self.tot_solutions
        return [cv2.cvtColor(background, cv2.COLOR_BGR2RGB), ""]

    def gen_img(self, data: List[Dict]):
        """
        generate new image with sender rank

        :param data: one ``{"message": ...}`` entry per agent, in agent order
        :return: the new image
        """
        # The following code need to be more general. This one is too task-specific.
        if len(data) != self.stu_num + 1:
            raise gr.Error("data length is not equal to the total number of students.")
        if self.task == "prisoner_dilemma":
            img = cv2.imread("./imgs/speaking.png", cv2.IMREAD_UNCHANGED)
            if (
                len(self.messages) < 2
                or self.messages[-1][0] == 1
                or self.messages[-2][0] == 2
            ):
                background = cv2.imread("./imgs/prison/case_1.png")
                positions = ((400, 480), (550, 480), (550, 880))
            else:
                background = cv2.imread("./imgs/prison/case_2.png")
                positions = ((400, 880), (550, 480), (550, 880))
            for slot, position in enumerate(positions):
                if data[slot]["message"] != "":
                    cover_img(background, img, position)
        elif self.task == "db_diag":
            background = cv2.imread("./imgs/db_diag/background.png")
            img = cv2.imread("./imgs/db_diag/speaking.png", cv2.IMREAD_UNCHANGED)
            for slot, position in enumerate(((750, 80), (310, 220), (522, 11))):
                if data[slot]["message"] != "":
                    cover_img(background, img, position)
        elif "sde" in self.task:
            background = cv2.imread("./imgs/sde/background.png")
            img = cv2.imread("./imgs/sde/speaking.png", cv2.IMREAD_UNCHANGED)
            for slot, position in enumerate(((692, 330), (692, 660), (692, 990))):
                if data[slot]["message"] != "":
                    cover_img(background, img, position)
        else:
            background = cv2.imread("./imgs/background.png")
            back_h, back_w, _ = background.shape
            stu_cnt = 0
            # Agent 0 gets a speech bubble at a fixed position.
            if data[stu_cnt]["message"] not in ["", "[RaiseHand]"]:
                img = cv2.imread("./imgs/speaking.png", cv2.IMREAD_UNCHANGED)
                cover_img(background, img, (370, 1250))
            for h_begin, w_begin in itertools.product(
                range(800, back_h, 300), range(135, back_w - 200, 200)
            ):
                stu_cnt += 1
                if stu_cnt <= self.stu_num:
                    img = cv2.imread(
                        f"./imgs/{(stu_cnt - 1) % 11 + 1}.png", cv2.IMREAD_UNCHANGED
                    )
                    cover_img(
                        background,
                        img,
                        (h_begin - 30 if img.shape[0] > 190 else h_begin, w_begin),
                    )
                    if "[RaiseHand]" in data[stu_cnt]["message"]:
                        img = cv2.imread("./imgs/hand.png", cv2.IMREAD_UNCHANGED)
                        cover_img(background, img, (h_begin - 90, w_begin + 10))
                    elif data[stu_cnt]["message"] not in ["", "[RaiseHand]"]:
                        img = cv2.imread("./imgs/speaking.png", cv2.IMREAD_UNCHANGED)
                        cover_img(background, img, (h_begin - 90, w_begin + 10))
                else:
                    img = cv2.imread("./imgs/empty.png", cv2.IMREAD_UNCHANGED)
                    cover_img(background, img, (h_begin, w_begin))
        return cv2.cvtColor(background, cv2.COLOR_BGR2RGB)

    def return_format(self, messages: List[Message]):
        """Convert backend messages into one ``{"message", "sender"}`` dict per agent."""
        _format = [{"message": "", "sender": idx} for idx in range(len(self.agent_id))]

        for message in messages:
            slot = _format[self.agent_id[message.sender]]
            if self.task == "db_diag":
                # Work on a copy so the backend's own message content is not
                # rewritten with the "[sender]: " prefix.
                content_json: dict = dict(message.content)
                content_json["diagnose"] = f"[{message.sender}]: {content_json['diagnose']}"
                slot["message"] = json.dumps(content_json)
            elif "sde" in self.task and message.sender == "code_tester":
                # The tester sends "<text>\n<json>"; only the feedback field is shown.
                pre_message, message_ = message.content.split("\n")
                message_ = "{}\n{}".format(pre_message, json.loads(message_)["feedback"])
                slot["message"] = "[{}]: {}".format(message.sender, message_)
            else:
                slot["message"] = "[{}]: {}".format(message.sender, message.content)

        return _format

    def gen_output(self):
        """
        generate new image and message of next step

        :return: [new image, new message]
        """
        return_message = self.backend.next()
        data = self.return_format(return_message)

        # TODO: check the message from the backend: only 1 person can speak

        for item in data:
            if item["message"] not in ["", "[RaiseHand]"]:
                self.messages.append((item["sender"], item["message"]))

        message = self.gen_message()
        self.turns_remain -= 1
        return [self.gen_img(data), message]

    def gen_message(self):
        """Render ``self.messages`` as an HTML chat log, newest message first."""
        message = ""
        for sender, msg in self.messages:
            if sender == 0:
                avatar = self.get_avatar(0)
            elif sender == -1:
                avatar = self.get_avatar(-1)
            else:
                avatar = self.get_avatar((sender - 1) % 11 + 1)
            if self.task == "db_diag":
                msg_json = json.loads(msg)
                # Restarted for every message, so the last one decides the buttons.
                self.solution_status = [False] * self.tot_solutions
                msg = msg_json["diagnose"]
                if msg_json["solution"] != "":
                    solution: List[str] = msg_json["solution"]
                    for solu in solution:
                        for group, keywords in enumerate(self._SOLUTION_KEYWORDS):
                            if any(keyword in solu for keyword in keywords):
                                self.solution_status[group] = True
                                for keyword in keywords:
                                    solu = solu.replace(
                                        keyword,
                                        f'<span style="color:yellow;">{keyword}</span>',
                                    )
                        msg = f"{msg}<br>{solu}"
                if msg_json["knowledge"] != "":
                    msg = f'{msg}<hr style="margin: 5px 0"><span style="font-style: italic">{msg_json["knowledge"]}</span>'
            else:
                # Plain chat text: neutralise angle brackets so it cannot inject markup.
                msg = msg.replace("<", "&lt;")
                msg = msg.replace(">", "&gt;")
            # Prepend, so the most recent message ends up on top.
            message = (
                f'<div style="display: flex; align-items: center; margin-bottom: 10px;overflow:auto;">'
                f'<img src="{avatar}" style="width: 5%; height: 5%; border-radius: 25px; margin-right: 10px;">'
                f'<div style="background-color: gray; color: white; padding: 10px; border-radius: 10px;'
                f'max-width: 70%; white-space: pre-wrap">'
                f"{msg}"
                f"</div></div>" + message
            )
        message = '<div id="divDetail" style="height:600px;overflow:auto;">' + message + "</div>"
        return message

    def submit(self, message: str):
        """
        submit message to backend

        :param message: message
        :return: [new image, new message]
        """
        # NOTE(review): relies on the backend exposing ``submit`` -- confirm
        # that the AgentVerse class in use provides it.
        self.backend.submit(message)
        self.messages.append((-1, f"[User]: {message}"))
        return self.gen_img([{"message": ""}] * len(self.agent_id)), self.gen_message()

    def launch(self):
        """
        start a frontend
        """
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    image_output = gr.Image()
                    with gr.Row():
                        reset_btn = gr.Button("Reset")
                        next_btn = gr.Button("Next", interactive=False)
                        stop_autoplay_btn = gr.Button(
                            "Stop Autoplay", interactive=False
                        )
                        start_autoplay_btn = gr.Button("Start Autoplay", interactive=False)
                    with gr.Box(visible=False) as solutions:
                        with gr.Column():
                            gr.HTML("Optimization Solutions:")
                            with gr.Row():
                                rewrite_slow_query_btn = gr.Button("Rewrite Slow Query", visible=False)
                                add_query_hints_btn = gr.Button("Add Query Hints", visible=False)
                                update_indexes_btn = gr.Button("Update Indexes", visible=False)
                                tune_parameters_btn = gr.Button("Tune Parameters", visible=False)
                                gather_more_info_btn = gr.Button("Gather More Info", visible=False)
                # Building the page also resets the backend once.
                text_output = gr.HTML(self.reset()[1])

            # TODO: offer a control to provide the number of students and their info.
            # stu_num = gr.Number(label="Student Number", precision=0)
            # stu_num = self.stu_num

            # Same order as the tail produced by _solution_updates().
            solution_widgets = [
                rewrite_slow_query_btn,
                add_query_hints_btn,
                update_indexes_btn,
                tune_parameters_btn,
                gather_more_info_btn,
                solutions,
            ]

            if self.task == "db_diag":
                user_msg = gr.Textbox()
                submit_btn = gr.Button("Submit", variant="primary")

                submit_btn.click(
                    fn=self.submit,
                    inputs=user_msg,
                    outputs=[image_output, text_output],
                    show_progress=False,
                )

            next_btn.click(
                fn=self.delay_gen_output,
                inputs=None,
                outputs=[
                    image_output,
                    text_output,
                    next_btn,
                    start_autoplay_btn,
                    *solution_widgets,
                ],
                show_progress=False,
            )

            # TODO: add a re-start button (load different people and env)
            reset_btn.click(
                fn=self.delay_reset,
                inputs=None,
                outputs=[
                    image_output,
                    text_output,
                    next_btn,
                    stop_autoplay_btn,
                    start_autoplay_btn,
                    *solution_widgets,
                ],
                show_progress=False,
            )

            stop_autoplay_btn.click(
                fn=self.stop_autoplay,
                inputs=None,
                outputs=[next_btn, stop_autoplay_btn, start_autoplay_btn],
                show_progress=False,
            )
            start_autoplay_btn.click(
                fn=self.start_autoplay,
                inputs=None,
                outputs=[
                    image_output,
                    text_output,
                    next_btn,
                    stop_autoplay_btn,
                    start_autoplay_btn,
                    *solution_widgets,
                ],
                show_progress=False,
            )

        demo.queue(concurrency_count=5, max_size=20).launch()
|
487 |
+
# demo.launch()
|
agentverse/environments/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
from agentverse.registry import Registry
|
3 |
+
|
4 |
+
|
5 |
+
env_registry = Registry(name="EnvironmentRegistry")
|
6 |
+
|
7 |
+
|
8 |
+
from .base import BaseEnvironment, BaseRule
|
9 |
+
|
10 |
+
# from .basic import PipelineEnvironment
|
11 |
+
from .simulation_env.basic import BasicEnvironment
|
12 |
+
from .simulation_env.pokemon import PokemonEnvironment
|
13 |
+
from .simulation_env.prisoner_dilemma import PrisonerDilemmaEnvironment
|
14 |
+
from .simulation_env.sde_team import SdeTeamEnvironment
|
15 |
+
from .simulation_env.sde_team_given_tests import SdeTeamGivenTestsEnvironment
|
16 |
+
|
17 |
+
from .tasksolving_env.basic import BasicEnvironment
|
agentverse/environments/base.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
from agentverse.logging import logger
|
3 |
+
|
4 |
+
from abc import abstractmethod
|
5 |
+
from typing import TYPE_CHECKING, Any, Dict, List
|
6 |
+
|
7 |
+
from pydantic import BaseModel
|
8 |
+
|
9 |
+
# from agentverse.agents.agent import Agent
|
10 |
+
|
11 |
+
if TYPE_CHECKING:
|
12 |
+
from agentverse.agents.base import BaseAgent
|
13 |
+
from agentverse.message import Message
|
14 |
+
|
15 |
+
|
16 |
+
class BaseRule(BaseModel):
    """Empty pydantic base class for environment rules; declares no fields."""
|
18 |
+
|
19 |
+
|
20 |
+
class BaseEnvironment(BaseModel):
    """
    Base class for environment.

    Subclasses provide ``step`` and ``reset``; this class contributes the
    shared fields, spend reporting and the turn-limit check.

    Args:
        agents: List of agents
        rule: Rule for the environment
        max_turns: Maximum number of turns
        cnt_turn: Current turn number
        last_messages: Messages from last turn
        rule_params: Variables set by the rule
    """

    agents: List[BaseAgent]
    rule: BaseRule
    max_turns: int = 10
    cnt_turn: int = 0
    last_messages: List[Message] = []
    rule_params: Dict = {}

    @abstractmethod
    async def step(self) -> List[Message]:
        """Run one step of the environment"""

    @abstractmethod
    def reset(self) -> None:
        """Reset the environment"""

    def report_metrics(self) -> None:
        """Log the sum of ``get_spend()`` over all agents."""
        total_spent = sum(agent.get_spend() for agent in self.agents)
        logger.info(f"Total spent: ${total_spent}")

    def is_done(self) -> bool:
        """Return True once ``cnt_turn`` has reached ``max_turns``."""
        turns_used, turn_limit = self.cnt_turn, self.max_turns
        return turns_used >= turn_limit
|
agentverse/environments/simulation_env/basic.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
|
3 |
+
# import logging
|
4 |
+
from agentverse.logging import get_logger
|
5 |
+
from typing import Any, Dict, List
|
6 |
+
|
7 |
+
# from agentverse.agents.agent import Agent
|
8 |
+
from agentverse.agents.simulation_agent.conversation import BaseAgent
|
9 |
+
|
10 |
+
# from agentverse.environments.simulation_env.rules.base import Rule
|
11 |
+
from agentverse.environments.simulation_env.rules.base import SimulationRule as Rule
|
12 |
+
from agentverse.message import Message
|
13 |
+
|
14 |
+
logger = get_logger()
|
15 |
+
|
16 |
+
from .. import env_registry as EnvironmentRegistry
|
17 |
+
from ..base import BaseEnvironment
|
18 |
+
|
19 |
+
|
20 |
+
@EnvironmentRegistry.register("sim-basic")
class BasicEnvironment(BaseEnvironment):
    """
    A basic environment implementing the logic of conversation.

    Args:
        agents: List of agents
        rule: Rule for the environment
        max_turns: Maximum number of turns
        cnt_turn: Current turn number
        last_messages: Messages from last turn
        rule_params: Variables set by the rule
    """

    agents: List[BaseAgent]
    rule: Rule
    max_turns: int = 10
    cnt_turn: int = 0
    last_messages: List[Message] = []
    rule_params: Dict = {}

    def __init__(self, rule, **kwargs):
        """Build the ``Rule`` from the ``rule`` config mapping, then init the model.

        Each rule component falls back to a default ``{"type": ...}`` config
        when its key is missing from ``rule``.
        """
        # (config key, default type), in the positional order Rule expects.
        component_defaults = (
            ("order", "sequential"),
            ("visibility", "all"),
            ("selector", "basic"),
            ("updater", "basic"),
            ("describer", "basic"),
        )
        component_configs = [
            rule.get(key, {"type": default_type})
            for key, default_type in component_defaults
        ]
        super().__init__(rule=Rule(*component_configs), **kwargs)

    async def step(self) -> List[Message]:
        """Run one step of the environment"""
        # Which agents act this turn, and what each of them is told.
        speaker_ids = self.rule.get_next_agent_idx(self)
        descriptions = self.rule.get_env_description(self)

        # Let all acting agents produce their message concurrently.
        pending = [self.agents[idx].astep(descriptions[idx]) for idx in speaker_ids]
        replies = await asyncio.gather(*pending)

        # The rule may keep only some of the produced messages.
        chosen = self.rule.select_message(self, replies)
        self.last_messages = chosen
        self.print_messages(chosen)

        # Propagate the turn's result: agent memories first, then visibility.
        self.rule.update_memory(self)
        self.rule.update_visible_agents(self)

        self.cnt_turn += 1

        return chosen

    def print_messages(self, messages: List[Message]) -> None:
        """Log every non-``None`` message as "<sender>: <content>"."""
        for message in messages:
            if message is None:
                continue
            logger.info(f"{message.sender}: {message.content}")

    def reset(self) -> None:
        """Reset the environment"""
        self.cnt_turn = 0
        self.rule.reset()
        for agent in self.agents:
            agent.reset()

    def is_done(self) -> bool:
        """Check if the environment is done"""
        turns_used, turn_limit = self.cnt_turn, self.max_turns
        return turns_used >= turn_limit
|
agentverse/environments/simulation_env/pokemon.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import datetime
|
3 |
+
import logging
|
4 |
+
from typing import Any, Dict, List, Optional, Set
|
5 |
+
|
6 |
+
# from agentverse.agents.agent import Agent
|
7 |
+
from agentverse.agents.simulation_agent.conversation import BaseAgent
|
8 |
+
|
9 |
+
# from agentverse.environments.simulation_env.rules.base import Rule
|
10 |
+
from agentverse.environments.simulation_env.rules.base import SimulationRule as Rule
|
11 |
+
from agentverse.message import Message
|
12 |
+
|
13 |
+
from .. import env_registry as EnvironmentRegistry
|
14 |
+
from ..base import BaseEnvironment
|
15 |
+
|
16 |
+
|
17 |
+
@EnvironmentRegistry.register("pokemon")
class PokemonEnvironment(BaseEnvironment):
    """
    An environment for Pokémon demo.

    Args:
        agents: List of agents
        locations_to_agents: A dict mapping each location name to the set of
            agent names currently inside it
        time: Simulated clock; advanced by five minutes on every routine step
        rule: Rule for the environment
        max_turns: Maximum number of turns
        cnt_turn: Current turn number
        last_messages: Messages from last turn
        rule_params: Variables set by the rule
    """

    agents: List[BaseAgent]
    locations_to_agents: Dict[str, Set[str]]
    time: datetime.datetime = datetime.datetime(2021, 1, 1, 8, 0, 0)
    rule: Rule
    max_turns: int = 10
    cnt_turn: int = 0
    last_messages: List[Message] = []
    rule_params: Dict = {}

    def __init__(self, rule, locations, **kwargs):
        """Build the rule and the location map, then initialise the base class.

        Args:
            rule: Mapping that may contain the keys ``order``, ``visibility``,
                ``selector``, ``updater`` and ``describer``; each missing key
                falls back to the default component config used below.
            locations: Iterable of mappings, each providing ``"name"`` and
                ``"init_agents"``.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        rule_config = rule
        order_config = rule_config.get("order", {"type": "sequential"})
        visibility_config = rule_config.get("visibility", {"type": "all"})
        selector_config = rule_config.get("selector", {"type": "basic"})
        updater_config = rule_config.get("updater", {"type": "basic"})
        describer_config = rule_config.get("describer", {"type": "basic"})
        rule = Rule(
            order_config,
            visibility_config,
            selector_config,
            updater_config,
            describer_config,
        )
        locations_to_agents = {
            loc["name"]: set(loc["init_agents"]) for loc in locations
        }
        super().__init__(
            rule=rule,
            locations_to_agents=locations_to_agents,
            **kwargs,
        )

    async def step(
        self,
        is_player: bool = False,
        player_content: Optional[str] = None,
        receiver: Optional[str] = None,
        receiver_id: Optional[int] = None,
        agent_ids: Optional[List[int]] = None,
    ) -> List[Message]:
        """Run one step of the environment.

        Args:
            is_player: When true, the step is a reply to the human player.
            player_content: Text said by the player (player steps only).
            receiver: Name of the agent the player talks to.
            receiver_id: Id of the agent the player talks to; looked up from
                ``receiver`` when omitted.
            agent_ids: Ids of the agents that act in a routine step. When
                omitted, the rule's speaking order decides.

        Returns:
            The messages produced during this step.
        """
        if is_player:
            return await self._respond_to_player(player_content, receiver, receiver_id)
        return await self._routine_step(agent_ids)

    async def _routine_step(
        self, agent_ids: Optional[List[int]] = None
    ) -> List[Message]:
        """Let the given agents act once and advance the clock by five minutes."""
        self.rule.update_visible_agents(self)

        # `step()` defaults `agent_ids` to None; iterating None would raise a
        # TypeError, so fall back to the order defined by the rule.
        if agent_ids is None:
            agent_ids = self.rule.get_next_agent_idx(self)

        # Generate current environment description
        env_descriptions = self.rule.get_env_description(self)

        # Generate the next message
        messages = await asyncio.gather(
            *[self.agents[i].astep(env_descriptions[i]) for i in agent_ids]
        )

        # Some rules will select certain messages from all the messages
        selected_messages = self.rule.select_message(self, messages)

        # Update the memory of the agents
        self.last_messages = selected_messages
        self.rule.update_memory(self)
        self.print_messages(selected_messages)

        self.cnt_turn += 1
        self.time += datetime.timedelta(minutes=5)

        return selected_messages

    async def _respond_to_player(
        self,
        player_content: Optional[str] = None,
        receiver: Optional[str] = None,
        receiver_id: Optional[int] = None,
    ) -> List[Message]:
        """Let one agent answer a message from the human player.

        Args:
            player_content: Text said by the player.
            receiver: Name of the addressed agent.
            receiver_id: Id of the addressed agent; resolved from ``receiver``
                when omitted.

        Returns:
            The reply messages of the addressed agent.

        Raises:
            ValueError: If ``receiver_id`` is omitted and no agent is named
                ``receiver``.
        """
        if receiver_id is None:
            receiver_id = next(
                (agent.agent_id for agent in self.agents if agent.name == receiver),
                None,
            )
            if receiver_id is None:
                raise ValueError(f"No agent named {receiver!r} in the environment")
        agent_ids = [receiver_id]
        # When only the id was supplied, take the name from the agent itself so
        # the player message is not addressed to `None`.
        agent_name = receiver if receiver is not None else self.agents[receiver_id].name
        # NOTE(review): the pokemon describer spells the player "Brendan" while
        # the sender here is "Brenden" — confirm which spelling consumers
        # expect before unifying.
        player_message = Message(
            sender="Brenden", content=player_content, receiver=[agent_name]
        )

        # Update the set of visible agents for each agent
        self.rule.update_visible_agents(self)

        # Generate current environment description
        env_descriptions = self.rule.get_env_description(self, player_content)

        # Generate the next message
        messages = await asyncio.gather(
            *[self.agents[i].astep(env_descriptions[i]) for i in agent_ids]
        )

        # Message selection is skipped here: every reply is kept, and the
        # player's own message is stored with them for the memory update.
        self.last_messages = [player_message, *messages]
        self.rule.update_memory(self)
        self.print_messages(messages)

        self.cnt_turn += 1

        return messages

    def update_state(self, agent_location: Dict[str, str]):
        """Register each agent in ``agent_location`` at its given location.

        Args:
            agent_location: Mapping from agent name to location name.
        """
        # NOTE(review): the agent is only added to the new location; it is not
        # removed from a previous one here — presumably done elsewhere, confirm.
        for agent_name, location in agent_location.items():
            self.locations_to_agents[location].add(agent_name)

    def get_agent_to_location(self) -> Dict[str, str]:
        """Invert ``locations_to_agents`` into an agent-name -> location map."""
        ret = {}
        for location, agent_names in self.locations_to_agents.items():
            for agent in agent_names:
                ret[agent] = location
        return ret

    def print_messages(self, messages: List[Message]) -> None:
        """Log every non-``None`` message as ``"<sender>: <content>"``."""
        for message in messages:
            if message is not None:
                logging.info(f"{message.sender}: {message.content}")

    def reset(self) -> None:
        """Reset the environment"""
        self.cnt_turn = 0
        self.rule.reset()
        for agent in self.agents:
            agent.reset()

    def is_done(self) -> bool:
        """Check if the environment is done"""
        return self.cnt_turn >= self.max_turns

    def get_test_messages(self) -> List[Message]:
        """Return a fixed list of sample messages (not called within this class)."""
        messages = [
            Message(
                content='{"to": "Birch", "action": "Speak", "text": "Hi!!!"}',
                sender="May",
                receiver={"May", "Birch"},
                tool_response=[],
            ),
            Message(
                content='{"to": "May", "text": "Good morning, May! How is your research going?", "action": "Speak"}',
                sender="Birch",
                receiver={"May", "Birch"},
                tool_response=[],
            ),
            Message(
                content='{"to": "Pokémon Center", "action": "MoveTo"}',
                sender="Steven",
                receiver={"Steven"},
                tool_response=[],
            ),
            Message(
                content='{"to": "Shop", "last_time": "10 minutes", "action": "MoveTo"}',
                sender="Maxie",
                receiver={"Maxie"},
                tool_response=[],
            ),
            Message(
                content='{"to": "Pok\\u00e9mon Center", "action": "MoveTo"}',
                sender="Archie",
                receiver={"Archie"},
                tool_response=[],
            ),
            Message(
                content='{"to": "Shop", "action": "MoveTo"}',
                sender="Joseph",
                receiver={"Joseph"},
                tool_response=[],
            ),
        ]
        return messages
|
agentverse/environments/simulation_env/prisoner_dilemma.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import logging
|
3 |
+
from typing import Any, Dict, List
|
4 |
+
|
5 |
+
# from agentverse.agents.agent import Agent
|
6 |
+
from agentverse.agents.simulation_agent.conversation import BaseAgent
|
7 |
+
|
8 |
+
# from agentverse.environments.simulation_env.rules.base import Rule
|
9 |
+
from agentverse.environments.simulation_env.rules.base import SimulationRule as Rule
|
10 |
+
from agentverse.message import Message
|
11 |
+
|
12 |
+
from .. import env_registry as EnvironmentRegistry
|
13 |
+
from .basic import BasicEnvironment
|
14 |
+
|
15 |
+
|
16 |
+
@EnvironmentRegistry.register("prisoner_dilemma")
class PrisonerDilemmaEnvironment(BasicEnvironment):
    """
    An environment for prisoner dilemma.
    """

    async def step(self) -> List[Message]:
        """Advance the environment by one turn.

        The rule picks the speaking agents and the per-agent descriptions;
        each chosen agent's ``astep`` is awaited with this environment and
        its description. The rule then filters the replies, which are stored,
        logged, and used to update agent memory and visibility.

        Returns:
            The messages kept by the rule's selector for this turn.
        """
        # Who speaks now, and what each agent is told about the environment.
        speaker_ids = self.rule.get_next_agent_idx(self)
        descriptions = self.rule.get_env_description(self)

        # Run all chosen agents concurrently.
        pending = [
            self.agents[idx].astep(self, descriptions[idx]) for idx in speaker_ids
        ]
        replies = await asyncio.gather(*pending)

        # Some rules keep only part of the generated messages.
        chosen = self.rule.select_message(self, replies)
        self.last_messages = chosen
        self.print_messages(chosen)

        # Propagate the turn's outcome to memories and visibility sets.
        self.rule.update_memory(self)
        self.rule.update_visible_agents(self)

        self.cnt_turn += 1

        return chosen
|
agentverse/environments/simulation_env/reflection.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import logging
|
3 |
+
from typing import Any, Dict, List
|
4 |
+
|
5 |
+
from datetime import datetime as dt
|
6 |
+
import datetime
|
7 |
+
|
8 |
+
from pydantic import Field
|
9 |
+
|
10 |
+
from agentverse.agents.simulation_agent.conversation import BaseAgent
|
11 |
+
|
12 |
+
# from agentverse.environments.simulation_env.rules.base import Rule
|
13 |
+
from agentverse.environments.simulation_env.rules.base import SimulationRule as Rule
|
14 |
+
from agentverse.message import Message
|
15 |
+
|
16 |
+
from . import env_registry as EnvironmentRegistry
|
17 |
+
from ..base import BaseEnvironment
|
18 |
+
|
19 |
+
from pydantic import validator
|
20 |
+
|
21 |
+
|
22 |
+
@EnvironmentRegistry.register("reflection")
class ReflectionEnvironment(BaseEnvironment):
    """
    Environment used in Observation-Planning-Reflection agent architecture.

    Args:
        agents: List of agents
        rule: Rule for the environment
        max_turns: Maximum number of turns
        cnt_turn: Current turn number
        last_messages: Messages from last turn
        rule_params: Variables set by the rule
        current_time: Simulated clock passed to every agent on each step;
            defaults to the moment the environment instance is created
        time_delta: time difference between steps, in seconds
    """

    agents: List[BaseAgent]
    rule: Rule
    max_turns: int = 10
    cnt_turn: int = 0
    last_messages: List[Message] = []
    rule_params: Dict = {}
    # A plain `dt.now()` default would be evaluated once, when the class body
    # runs at import time; the factory yields a fresh timestamp per instance.
    current_time: dt = Field(default_factory=dt.now)
    time_delta: int = 120

    def __init__(self, rule, **kwargs):
        """Build the rule from its config and initialise the base class.

        Args:
            rule: Mapping that may contain the keys ``order``, ``visibility``,
                ``selector``, ``updater`` and ``describer``; each missing key
                falls back to the default component config used below.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        rule_config = rule
        order_config = rule_config.get("order", {"type": "sequential"})
        visibility_config = rule_config.get("visibility", {"type": "all"})
        selector_config = rule_config.get("selector", {"type": "basic"})
        updater_config = rule_config.get("updater", {"type": "basic"})
        describer_config = rule_config.get("describer", {"type": "basic"})
        rule = Rule(
            order_config,
            visibility_config,
            selector_config,
            updater_config,
            describer_config,
        )

        super().__init__(rule=rule, **kwargs)

    async def step(self) -> List[Message]:
        """Run one step of the environment and advance the simulated clock.

        Returns:
            The messages kept by the rule's selector for this turn.
        """

        logging.log(logging.INFO, f"Tick tock. Current time: {self.current_time}")

        # Get the next agent index
        agent_ids = self.rule.get_next_agent_idx(self)

        # Generate current environment description
        env_descriptions = self.rule.get_env_description(self)

        # Generate the next message; each agent also receives the current time
        messages = await asyncio.gather(
            *[
                self.agents[i].astep(self.current_time, env_descriptions[i])
                for i in agent_ids
            ]
        )

        # Some rules will select certain messages from all the messages
        selected_messages = self.rule.select_message(self, messages)
        self.last_messages = selected_messages
        self.print_messages(selected_messages)

        # Update the memory of the agents
        self.rule.update_memory(self)

        # Update the set of visible agents for each agent
        self.rule.update_visible_agents(self)

        self.cnt_turn += 1

        # update current_time
        self.tick_tock()

        return selected_messages

    def print_messages(self, messages: List[Message]) -> None:
        """Log every non-``None`` message as ``"<sender>: <content>"``."""
        for message in messages:
            if message is not None:
                logging.info(f"{message.sender}: {message.content}")

    def reset(self) -> None:
        """Reset the environment"""
        self.cnt_turn = 0
        self.rule.reset()
        BaseAgent.update_forward_refs()
        # Each agent is reset with a reference to this environment.
        for agent in self.agents:
            agent.reset(environment=self)

    def is_done(self) -> bool:
        """Check if the environment is done"""
        return self.cnt_turn >= self.max_turns

    def tick_tock(self) -> None:
        """Increment the time by ``time_delta`` seconds"""
        self.current_time = self.current_time + datetime.timedelta(
            seconds=self.time_delta
        )
|
agentverse/environments/simulation_env/rules/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .base import SimulationRule
|
agentverse/environments/simulation_env/rules/base.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from abc import abstractmethod
|
4 |
+
from typing import TYPE_CHECKING, Any, List, Optional
|
5 |
+
|
6 |
+
from agentverse.environments.simulation_env.rules.describer import (
|
7 |
+
BaseDescriber,
|
8 |
+
describer_registry,
|
9 |
+
)
|
10 |
+
from agentverse.environments.simulation_env.rules.order import BaseOrder, order_registry
|
11 |
+
from agentverse.environments.simulation_env.rules.selector import (
|
12 |
+
BaseSelector,
|
13 |
+
selector_registry,
|
14 |
+
)
|
15 |
+
from agentverse.environments.simulation_env.rules.updater import (
|
16 |
+
BaseUpdater,
|
17 |
+
updater_registry,
|
18 |
+
)
|
19 |
+
from agentverse.environments.simulation_env.rules.visibility import (
|
20 |
+
BaseVisibility,
|
21 |
+
visibility_registry,
|
22 |
+
)
|
23 |
+
from agentverse.environments import BaseRule
|
24 |
+
|
25 |
+
if TYPE_CHECKING:
|
26 |
+
from agentverse.environments.base import BaseEnvironment
|
27 |
+
|
28 |
+
from agentverse.message import Message
|
29 |
+
|
30 |
+
|
31 |
+
# class Rule(BaseModel):
|
32 |
+
class SimulationRule(BaseRule):
    """
    Rule for the environment, assembled from five pluggable components.

    Every public method forwards to exactly one component: ``order`` picks
    the next speakers, ``visibility`` maintains which agents can see each
    other, ``selector`` filters generated messages, ``updater`` writes
    messages into agent memory, and ``describer`` produces the per-agent
    environment description.
    """

    order: BaseOrder
    visibility: BaseVisibility
    selector: BaseSelector
    updater: BaseUpdater
    describer: BaseDescriber

    def __init__(
        self,
        order_config,
        visibility_config,
        selector_config,
        updater_config,
        describer_config,
    ):
        """Build each component from its config via the matching registry.

        Each ``*_config`` argument is unpacked as keyword arguments into the
        corresponding registry's ``build``.
        """
        super().__init__(
            order=order_registry.build(**order_config),
            visibility=visibility_registry.build(**visibility_config),
            selector=selector_registry.build(**selector_config),
            updater=updater_registry.build(**updater_config),
            describer=describer_registry.build(**describer_config),
        )

    def get_next_agent_idx(
        self, environment: BaseEnvironment, *args, **kwargs
    ) -> List[int]:
        """Ask the order component which agents speak next."""
        next_ids = self.order.get_next_agent_idx(environment, *args, **kwargs)
        return next_ids

    def update_visible_agents(
        self, environment: BaseEnvironment, *args, **kwargs
    ) -> None:
        """Let the visibility component refresh the visible-agent sets."""
        self.visibility.update_visible_agents(environment, *args, **kwargs)

    def select_message(
        self, environment: BaseEnvironment, messages: List[Message], *args, **kwargs
    ) -> List[Message]:
        """Let the selector pick the valid messages among ``messages``."""
        kept = self.selector.select_message(environment, messages, *args, **kwargs)
        return kept

    def update_memory(self, environment: BaseEnvironment, *args, **kwargs) -> None:
        """Let the updater store the messages in the agents' memories."""
        self.updater.update_memory(environment, *args, **kwargs)

    def get_env_description(
        self, environment: BaseEnvironment, *args, **kwargs
    ) -> List[str]:
        """Ask the describer for the environment description of each agent."""
        per_agent = self.describer.get_env_description(environment, *args, **kwargs)
        return per_agent

    def reset(self) -> None:
        """Reset the five components: order, visibility, selector, updater, describer."""
        components = (
            self.order,
            self.visibility,
            self.selector,
            self.updater,
            self.describer,
        )
        for component in components:
            component.reset()
|
agentverse/environments/simulation_env/rules/describer/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from agentverse.registry import Registry
|
2 |
+
|
3 |
+
describer_registry = Registry(name="DescriberRegistry")
|
4 |
+
|
5 |
+
from .base import BaseDescriber
|
6 |
+
from .basic import BasicDescriber
|
7 |
+
from .classroom import ClassroomDescriber
|
8 |
+
from .pokemon import PokemonDescriber
|
9 |
+
from .prisoner import PrisonerDescriber
|
agentverse/environments/simulation_env/rules/describer/base.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, Any, List
|
4 |
+
|
5 |
+
from pydantic import BaseModel
|
6 |
+
|
7 |
+
from . import describer_registry as DescriberRegistry
|
8 |
+
from abc import abstractmethod
|
9 |
+
|
10 |
+
if TYPE_CHECKING:
|
11 |
+
from agentverse.environments import BaseEnvironment
|
12 |
+
|
13 |
+
|
14 |
+
class BaseDescriber(BaseModel):
    """Common interface of the describers used by the simulation rule."""

    @abstractmethod
    def get_env_description(
        self, environment: BaseEnvironment, *args, **kwargs
    ) -> List[str]:
        """Produce the environment description for each agent.

        Subclasses supply the actual implementation; this base version has
        no body of its own.
        """

    def reset(self) -> None:
        """Hook for clearing describer state; a no-op in the base class."""
        return None
|
agentverse/environments/simulation_env/rules/describer/basic.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, Any, List
|
4 |
+
|
5 |
+
from . import describer_registry as DescriberRegistry
|
6 |
+
from .base import BaseDescriber
|
7 |
+
|
8 |
+
if TYPE_CHECKING:
|
9 |
+
from agentverse.environments import BaseEnvironment
|
10 |
+
|
11 |
+
|
12 |
+
@DescriberRegistry.register("basic")
class BasicDescriber(BaseDescriber):
    """Describer that gives every agent an empty description."""

    def get_env_description(self, environment: BaseEnvironment) -> List[str]:
        """Return one empty string per agent in ``environment.agents``."""
        agent_count = len(environment.agents)
        return [""] * agent_count
|
agentverse/environments/simulation_env/rules/describer/classroom.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, Any, List
|
4 |
+
from string import Template
|
5 |
+
|
6 |
+
from . import describer_registry as DescriberRegistry
|
7 |
+
from .basic import BasicDescriber
|
8 |
+
|
9 |
+
if TYPE_CHECKING:
|
10 |
+
from agentverse.environments import BaseEnvironment
|
11 |
+
|
12 |
+
|
13 |
+
@DescriberRegistry.register("classroom")
class ClassroomDescriber(BasicDescriber):
    """Describer that announces the start and the end of a group discussion.

    ``start_prompt`` is a ``string.Template`` source in which
    ``receiver_name`` is substituted; ``end_prompt`` is used verbatim.
    """

    start_prompt: str
    end_prompt: str

    def get_env_description(self, environment: BaseEnvironment) -> List[str]:
        """Return the per-agent description for the current discussion phase.

        * While ``rule_params["is_grouped"]`` is truthy, the first agent
          (the professor) gets an empty string and every other agent gets
          ``start_prompt`` filled with its comma-joined receivers.
        * Otherwise, if ``rule_params["is_grouped_ended"]`` is truthy, the
          flag is cleared and every agent gets ``end_prompt``.
        * Otherwise the parent's description is returned.
        """
        params = environment.rule_params

        if params.get("is_grouped", False):
            # Professor (index 0) will not participate in group discussion
            return [
                ""
                if position == 0
                else Template(self.start_prompt).safe_substitute(
                    {"receiver_name": ", ".join(member.receiver)}
                )
                for position, member in enumerate(environment.agents)
            ]

        if not params.get("is_grouped_ended", False):
            return super().get_env_description(environment)

        # The group discussion has just ended: announce it once only.
        params["is_grouped_ended"] = False
        return [self.end_prompt] * len(environment.agents)

    def reset(self) -> None:
        """Nothing to reset."""
        pass
|
agentverse/environments/simulation_env/rules/describer/pokemon.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, Any, List, Optional, Dict
|
4 |
+
from copy import deepcopy
|
5 |
+
|
6 |
+
from . import describer_registry as DescriberRegistry
|
7 |
+
from .base import BaseDescriber
|
8 |
+
|
9 |
+
if TYPE_CHECKING:
|
10 |
+
from agentverse.environments.pokemon import PokemonEnvironment
|
11 |
+
|
12 |
+
|
13 |
+
@DescriberRegistry.register("pokemon")
class PokemonDescriber(BaseDescriber):
    """Describer for the Pokémon environment."""

    def get_env_description(
        self,
        environment: PokemonEnvironment,
        player_content: str = "",
    ) -> List[str]:
        """Build the description shown to each agent.

        Args:
            environment: Environment whose ``time``, ``agents`` and location
                map are read.
            player_content: Text said by the player. When it is not the empty
                string, every agent gets the same "Brendan is talking to you"
                description.

        Returns:
            One description per agent. In the non-player case an agent that
            has no current location gets an empty string.
        """
        now = environment.time

        if player_content != "":
            spoken = (
                f"It is now {now}. Brendan is talking to you.\n"
                f"[Brendan]: {player_content}\n"
            )
            return [spoken] * len(environment.agents)

        whereabouts = environment.get_agent_to_location()
        results = []
        for agent in environment.agents:
            if agent.name not in whereabouts:
                # Agent is on the way to a location
                results.append("")
                continue

            place = whereabouts[agent.name]
            # Work on a copy so the environment's own set stays untouched.
            others = deepcopy(environment.locations_to_agents[place])
            others.remove(agent.name)
            others = list(others)

            if len(others) == 0:
                company = " There is no one else here."
            elif len(others) == 1:
                company = f" {others[0]} is also here."
            else:
                joined = ", ".join(others)
                company = f" {joined} are also here."

            results.append(f"It is now {now}. You are at {place}." + company)
        return results
|
agentverse/environments/simulation_env/rules/describer/prisoner.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, Any, List
|
4 |
+
|
5 |
+
from . import describer_registry as DescriberRegistry
|
6 |
+
from .base import BaseDescriber
|
7 |
+
|
8 |
+
if TYPE_CHECKING:
|
9 |
+
from agentverse.environments import BaseEnvironment
|
10 |
+
|
11 |
+
|
12 |
+
@DescriberRegistry.register("prisoner")
class PrisonerDescriber(BaseDescriber):
    """Describer for the prisoner's dilemma environment.

    It tells the police (agent 0) whom they are talking to and rotates that
    target after every police turn according to ``switch_func``.
    """

    # Next police target after the current one has been addressed.
    switch_func = {
        "Both Suspects": "Suspect2",
        "Suspect1": "Suspect2",
        "Suspect2": "Suspect1",
    }
    # Whom the police address on their next turn.
    receiver: str = "Both Suspects"

    def get_env_description(self, environment: BaseEnvironment) -> List[str]:
        """Return the per-agent description and update message receivers.

        On turn 0 the receivers of the first three agents are initialised.
        On every even turn the first agent is told whom it is talking to,
        its receiver is set accordingly, and the target is rotated. All
        other entries are empty strings.
        """
        if environment.cnt_turn == 0:
            environment.agents[0].set_receiver({"all"})
            environment.agents[1].set_receiver({"Police", "Suspect1"})
            environment.agents[2].set_receiver({"Police", "Suspect2"})

        # only police have to choose to talk to suspect1 or suspect2
        description = []
        for i, agent in enumerate(environment.agents):
            if i == 0:
                # police -> suspect1 -> police -> suspect2
                if environment.cnt_turn % 2 == 1:
                    description.append("")
                    continue

                # Police will have to choose talk to which suspect
                description.append(f"You are now talking to {self.receiver}")

                receiver = "all" if self.receiver == "Both Suspects" else self.receiver
                self.receiver = self.switch_func[self.receiver]
                agent.set_receiver({receiver})

            else:
                description.append("")

        return description

    def reset(self) -> None:
        """Restore the initial police target.

        ``get_env_description`` rotates ``receiver`` as turns pass; without
        this, a reset environment would open turn 0 with a stale target.
        """
        self.receiver = "Both Suspects"
|
agentverse/environments/simulation_env/rules/order/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from agentverse.registry import Registry
|
2 |
+
order_registry = Registry(name="OrderRegistry")
|
3 |
+
|
4 |
+
from .base import BaseOrder
|
5 |
+
from .sequential import SequentialOrder
|
6 |
+
from .random import RandomOrder
|
7 |
+
from .concurrent import ConcurrentOrder
|
8 |
+
from .classroom import ClassroomOrder
|
9 |
+
from .prisoner import PrisonerOrder
|
10 |
+
from .sde_team import SdeTeamOrder
|
11 |
+
from .sde_team_given_tests import SdeTeamGivenTestsOrder
|
agentverse/environments/simulation_env/rules/order/base.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from abc import abstractmethod
|
4 |
+
from typing import TYPE_CHECKING, Any, List
|
5 |
+
|
6 |
+
from pydantic import BaseModel
|
7 |
+
|
8 |
+
if TYPE_CHECKING:
|
9 |
+
from agentverse.environments import BaseEnvironment
|
10 |
+
|
11 |
+
|
12 |
+
class BaseOrder(BaseModel):
    """Common interface of the speaking-order components."""

    @abstractmethod
    def get_next_agent_idx(self, environment: BaseEnvironment) -> List[int]:
        """Give the indices of the agents that speak next.

        Subclasses supply the actual implementation; this base version has
        no body of its own.
        """

    def reset(self) -> None:
        """Hook for clearing order state; a no-op in the base class."""
        return None
|
agentverse/environments/simulation_env/rules/order/classroom.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import re
|
5 |
+
from typing import TYPE_CHECKING, Any, List, Optional
|
6 |
+
|
7 |
+
from . import order_registry as OrderRegistry
|
8 |
+
from .base import BaseOrder
|
9 |
+
|
10 |
+
if TYPE_CHECKING:
|
11 |
+
from agentverse.environments import BaseEnvironment
|
12 |
+
|
13 |
+
|
14 |
+
@OrderRegistry.register("classroom")
class ClassroomOrder(BaseOrder):
    """The order for a classroom discussion

    The agents speak in the following order:
    1. The professor speaks first
    2. Then the professor can continue to speak, and the students can raise hands
    3. The professor can call on a student, then the student can speak or ask a question
    4. In the group discussion, the students in the group can speak in turn

    Agent index 0 is the one returned whenever "only the professor" should
    speak.
    """

    def get_next_agent_idx(self, environment: BaseEnvironment) -> List[int]:
        """Return the indices of the agents allowed to act in the next turn.

        Reads two flags from ``environment.rule_params``:
        `is_grouped_ended`: whether the group discussion just ended
        `is_grouped`: whether it is currently in a group discussion
        """
        if environment.rule_params.get("is_grouped_ended", False):
            return [0]
        if environment.rule_params.get("is_grouped", False):
            return self.get_next_agent_idx_grouped(environment)
        else:
            return self.get_next_agent_idx_ungrouped(environment)

    def get_next_agent_idx_ungrouped(self, environment: BaseEnvironment) -> List[int]:
        """Pick the next speakers outside of a group discussion.

        Raises:
            KeyError: the professor called on a name that matches no agent.
            AssertionError: the single last message fits none of the known
                cases (e.g. a ``[CallOn]`` message the regex cannot parse).
        """
        if len(environment.last_messages) == 0:
            # If the class just begins or no one speaks in the last turn, we let only the professor speak
            return [0]
        elif len(environment.last_messages) == 1:
            message = environment.last_messages[0]
            sender = message.sender
            content = message.content
            if sender.startswith("Professor"):
                if content.startswith("[CallOn]"):
                    # 1. professor calls on someone, then the student should speak
                    result = re.search(r"\[CallOn\] Yes, ([sS]tudent )?(\w+)", content)
                    if result is not None:
                        # Keys are the agent names with a "Student "-sized
                        # prefix cut off, so they match the regex capture.
                        name_to_id = {
                            agent.name[len("Student ") :]: i
                            for i, agent in enumerate(environment.agents)
                        }
                        return [name_to_id[result.group(2)]]
                else:
                    # 2. professor normally speaks, then anyone can act
                    return list(range(len(environment.agents)))
            elif sender.startswith("Student"):
                # 3. student ask question after being called on, or
                # 4. only one student raises hand, and the professor happens to listen
                # 5. the group discussion is just over, and there happens to be only a student speaking in the last turn
                return [0]
        else:
            # If len(last_messages) > 1, then
            # 1. there must be at least one student raises hand or speaks.
            # 2. the group discussion is just over.
            return [0]
        # Raised explicitly instead of `assert False`: an assert is removed
        # under `python -O`, which would make this method return None.
        raise AssertionError(
            f"Should not reach here, last_messages: {environment.last_messages}"
        )

    def get_next_agent_idx_grouped(self, environment: BaseEnvironment) -> List[int]:
        """Pick one speaker per group and advance each group's rotation.

        Grouping information is read from ``environment.rule_params``:
        groups: A list of list of agent ids, the i-th list contains
            the agent ids in the i-th group
        group_speaker_mapping: A mapping from group id to the position of
            the current speaker inside that group
        `groups` should be set in the corresponding `visibility`,
        and `group_speaker_mapping` is maintained here.
        """
        if "groups" not in environment.rule_params:
            logging.warning(
                "The environment is grouped, but the grouping information is not provided."
            )
        # Without grouping information, every agent is put into one group.
        groups = environment.rule_params.get(
            "groups", [list(range(len(environment.agents)))]
        )
        group_speaker_mapping = environment.rule_params.get(
            "group_speaker_mapping", {i: 0 for i in range(len(groups))}
        )

        # For grouped environment, we let the students speak in turn within each group
        next_agent_idx = []
        for group_id in range(len(groups)):
            speaker_index = group_speaker_mapping[group_id]
            speaker = groups[group_id][speaker_index]
            next_agent_idx.append(speaker)

        # Maintain the `group_speaker_mapping`: move every group to its next
        # member, wrapping around at the end of the group.
        for k, v in group_speaker_mapping.items():
            group_speaker_mapping[k] = (v + 1) % len(groups[k])
        environment.rule_params["group_speaker_mapping"] = group_speaker_mapping

        return next_agent_idx
|
agentverse/environments/simulation_env/rules/order/concurrent.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, List
|
4 |
+
|
5 |
+
from . import order_registry as OrderRegistry
|
6 |
+
from .base import BaseOrder
|
7 |
+
|
8 |
+
if TYPE_CHECKING:
|
9 |
+
from agentverse.environments import BaseEnvironment
|
10 |
+
|
11 |
+
|
12 |
+
@OrderRegistry.register("concurrent")
class ConcurrentOrder(BaseOrder):
    """
    Order in which every agent acts in every turn.
    """

    def get_next_agent_idx(self, environment: BaseEnvironment) -> List[int]:
        """Return the indices of all agents, in ascending order."""
        agent_count = len(environment.agents)
        return [idx for idx in range(agent_count)]
|
agentverse/environments/simulation_env/rules/order/prisoner.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import re
|
5 |
+
from typing import TYPE_CHECKING, Any, List, Optional
|
6 |
+
|
7 |
+
from . import order_registry as OrderRegistry
|
8 |
+
from .base import BaseOrder
|
9 |
+
|
10 |
+
if TYPE_CHECKING:
|
11 |
+
from agentverse.environments import BaseEnvironment
|
12 |
+
|
13 |
+
|
14 |
+
@OrderRegistry.register("prisoner")
class PrisonerOrder(BaseOrder):
    """The order for the prisoner's dilemma scenario

    The agents speak in the following order:
    1. Agent 0 (the police) speaks first
    2. After a single message from the police, one of the two prisoners
       (agents 1 and 2) speaks; the prisoners are picked alternately
    3. In every other situation the police speaks again
    """

    # try police, prisoner1 prisoner2 first

    # Index of the prisoner that answers the next police message.
    last_prisoner_index: int = 1
    # Maps a prisoner index to the other prisoner's index.
    switch_func: dict = {1: 2, 2: 1}

    def get_next_agent_idx(self, environment: BaseEnvironment) -> List[int]:
        """Return the index of the next agent to speak, wrapped in a list."""
        if len(environment.last_messages) == 1:
            sender = environment.last_messages[0].sender
            if sender.startswith("Police"):
                # The police just spoke: hand the turn to the pending prisoner
                # and queue the other one for the next police message.
                next_prisoner = self.last_prisoner_index
                self.last_prisoner_index = self.switch_func[self.last_prisoner_index]
                return [next_prisoner]
        # All remaining cases go back to the police:
        # - the game just begins (no last message),
        # - one prisoner made his action (sender starts with "Suspect"),
        # - several messages arrived in the last turn,
        # - the single sender is neither police nor suspect (previously this
        #   fell off the end of the method and returned None).
        return [0]
|
agentverse/environments/simulation_env/rules/order/random.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import random
|
4 |
+
from typing import TYPE_CHECKING, List
|
5 |
+
|
6 |
+
from . import order_registry as OrderRegistry
|
7 |
+
from .base import BaseOrder
|
8 |
+
|
9 |
+
if TYPE_CHECKING:
|
10 |
+
from agentverse.environments import BaseEnvironment
|
11 |
+
|
12 |
+
|
13 |
+
@OrderRegistry.register("random")
class RandomOrder(BaseOrder):
    """
    Order for random conversation

    Exactly one agent, drawn uniformly at random, speaks per turn.
    """

    def get_next_agent_idx(self, environment: BaseEnvironment) -> List[int]:
        """Return a single randomly chosen agent index, wrapped in a list."""
        highest_idx = len(environment.agents) - 1
        chosen = random.randint(0, highest_idx)
        return [chosen]
|
agentverse/environments/simulation_env/rules/order/sde_team.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import re
|
5 |
+
import random
|
6 |
+
from typing import TYPE_CHECKING, Any, List, Optional
|
7 |
+
|
8 |
+
from . import order_registry as OrderRegistry
|
9 |
+
from .base import BaseOrder
|
10 |
+
|
11 |
+
if TYPE_CHECKING:
|
12 |
+
from agentverse.environments import BaseEnvironment
|
13 |
+
|
14 |
+
|
15 |
+
@OrderRegistry.register("sde_team")
class SdeTeamOrder(BaseOrder):
    """The order for a code problem solving

    The first call schedules agent 2 five times in the same turn; afterwards
    the turns alternate between agent 0 and agent 1.
    """

    # State of the schedule: the agent index to hand out on the next call.
    next_agent_idx: int = 2

    def get_next_agent_idx(self, environment: BaseEnvironment) -> List[int]:
        """Return the agent indices for the next turn and advance the state.

        Raises:
            ValueError: ``next_agent_idx`` holds a value other than 0, 1 or 2.
        """
        if self.next_agent_idx == 2:
            self.next_agent_idx = 0
            return [2] * 5  # TODO set the number in yaml
        elif self.next_agent_idx == 0:
            self.next_agent_idx = 1
            return [0]
        elif self.next_agent_idx == 1:
            self.next_agent_idx = 0
            return [1]
        else:
            # Fail loudly instead of silently returning None, matching the
            # behaviour of the "sde_team_given_tests" order.
            raise ValueError("Invalid next_agent_idx: {}".format(self.next_agent_idx))
|
agentverse/environments/simulation_env/rules/order/sde_team_given_tests.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import re
|
5 |
+
import random
|
6 |
+
from typing import TYPE_CHECKING, Any, List, Optional
|
7 |
+
|
8 |
+
from . import order_registry as OrderRegistry
|
9 |
+
from .base import BaseOrder
|
10 |
+
|
11 |
+
if TYPE_CHECKING:
|
12 |
+
from agentverse.environments import BaseEnvironment
|
13 |
+
|
14 |
+
|
15 |
+
@OrderRegistry.register("sde_team_given_tests")
class SdeTeamGivenTestsOrder(BaseOrder):
    """The order for a code problem solving given unit tests

    The agents take turns in a fixed cycle:
    0 - code writer
    1 - code tester
    2 - code reviewer
    """

    # Index of the agent that will be returned by the next call.
    next_agent_idx: int = 0

    def get_next_agent_idx(self, environment: BaseEnvironment) -> List[int]:
        """Return the current agent index in a list and step the cycle.

        Raises:
            ValueError: ``next_agent_idx`` is not 0, 1 or 2.
        """
        # state -> (state after this call, value returned by this call)
        transitions = {
            0: (1, [0]),
            1: (2, [1]),
            2: (0, [2]),
        }
        step = transitions.get(self.next_agent_idx)
        if step is None:
            raise ValueError("Invalid next_agent_idx: {}".format(self.next_agent_idx))
        following, speakers = step
        self.next_agent_idx = following
        return speakers
|
agentverse/environments/simulation_env/rules/order/sequential.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, List
|
4 |
+
|
5 |
+
from . import order_registry as OrderRegistry
|
6 |
+
from .base import BaseOrder
|
7 |
+
|
8 |
+
if TYPE_CHECKING:
|
9 |
+
from agentverse.environments import BaseEnvironment
|
10 |
+
|
11 |
+
|
12 |
+
@OrderRegistry.register("sequential")
class SequentialOrder(BaseOrder):
    """
    Order for sequential conversation

    One agent speaks per turn, cycling through the agents by index and
    wrapping around after the last one.
    """

    # Index of the agent whose turn comes next.
    next_agent_idx: int = 0

    def get_next_agent_idx(self, environment: BaseEnvironment) -> List[int]:
        """Return the index of the next agent to speak"""
        current = self.next_agent_idx
        agent_count = len(environment.agents)
        self.next_agent_idx = (current + 1) % agent_count
        return [current]

    def reset(self) -> None:
        """Start the rotation over from the first agent."""
        self.next_agent_idx = 0
|
agentverse/environments/simulation_env/rules/selector/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from agentverse.registry import Registry
|
2 |
+
|
3 |
+
selector_registry = Registry(name="SelectorRegistry")
|
4 |
+
|
5 |
+
from .base import BaseSelector
|
6 |
+
from .basic import BasicSelector
|
7 |
+
from .classroom import ClassroomSelector
|
8 |
+
from .sde_team import SdeTeamSelector
|
9 |
+
from .sde_team_given_tests import SdeTeamGivenTestsSelector
|
10 |
+
from .pokemon import PokemonSelector
|
agentverse/environments/simulation_env/rules/selector/base.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, List
|
4 |
+
|
5 |
+
from pydantic import BaseModel
|
6 |
+
|
7 |
+
from agentverse.message import Message
|
8 |
+
|
9 |
+
from . import selector_registry as SelectorRegistry
|
10 |
+
from abc import abstractmethod
|
11 |
+
|
12 |
+
if TYPE_CHECKING:
|
13 |
+
from agentverse.environments import BaseEnvironment
|
14 |
+
|
15 |
+
|
16 |
+
@SelectorRegistry.register("base")
class BaseSelector(BaseModel):
    """
    Base class for all selectors.

    A selector receives the messages produced in one turn and decides which
    of them are kept.
    """

    @abstractmethod
    def select_message(
        self, environment: BaseEnvironment, messages: List[Message]
    ) -> List[Message]:
        """Selects a set of valid messages from all messages"""
        ...

    def reset(self) -> None:
        """Reset any internal state; the base implementation keeps none."""
        return None
|
agentverse/environments/simulation_env/rules/selector/basic.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, List
|
4 |
+
|
5 |
+
from agentverse.message import Message
|
6 |
+
|
7 |
+
from . import selector_registry as SelectorRegistry
|
8 |
+
from .base import BaseSelector
|
9 |
+
|
10 |
+
if TYPE_CHECKING:
|
11 |
+
from agentverse.environments import BaseEnvironment
|
12 |
+
|
13 |
+
|
14 |
+
@SelectorRegistry.register("basic")
class BasicSelector(BaseSelector):
    """
    Pass-through selector: every message is considered valid.
    """

    def select_message(
        self, environment: BaseEnvironment, messages: List[Message]
    ) -> List[Message]:
        """Return ``messages`` unchanged (the very same list object)."""
        return messages

    def reset(self) -> None:
        """Nothing to reset; this selector is stateless."""
        return None
|
agentverse/environments/simulation_env/rules/selector/classroom.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, List
|
4 |
+
|
5 |
+
from agentverse.message import Message
|
6 |
+
|
7 |
+
from . import selector_registry as SelectorRegistry
|
8 |
+
from .base import BaseSelector
|
9 |
+
|
10 |
+
if TYPE_CHECKING:
|
11 |
+
from agentverse.environments import BaseEnvironment
|
12 |
+
|
13 |
+
|
14 |
+
@SelectorRegistry.register("classroom")
class ClassroomSelector(BaseSelector):
    """Selector for the classroom scenario.

    Decides which of the messages produced in one turn survive, based on the
    sender name prefix ("Student" / "Professor") and on the message tags
    ``[RaiseHand]`` and ``[GroupDiscuss]``.
    """

    def select_message(
        self, environment: BaseEnvironment, messages: List[Message]
    ) -> List[Message]:
        """Filter the turn's messages.

        Note that a student's ``[RaiseHand]`` message is modified in place:
        its content is cut down to the bare tag.
        """
        kept = []
        for msg in messages:
            if msg.sender.startswith("Student"):
                if msg.content.startswith("[RaiseHand]"):
                    # Only the fact that the hand was raised is kept.
                    msg.content = "[RaiseHand]"
                    kept.append(msg)
                elif msg.content != "" or len(msg.tool_response) > 0:
                    kept.append(msg)
            elif msg.sender.startswith("Professor"):
                # If the professor launch a group discussion, then we
                # brutely discard the student's message in this turn
                if msg.content.startswith("[GroupDiscuss]"):
                    return [msg]
                kept.append(msg)

        professor_is_speaking = (
            len(kept) > 1
            and kept[0].sender.startswith("Professor")
            and kept[0].content != ""
        )
        if not professor_is_speaking:
            return kept

        # If some student speak while the professor is speaking, then
        # we brutely discard the student's message in this turn;
        # raised hands are the only thing allowed next to the professor.
        raised_hands = [
            other for other in kept[1:] if other.content.startswith("[RaiseHand]")
        ]
        return [kept[0]] + raised_hands
|
agentverse/environments/simulation_env/rules/selector/code_api.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import sys
|
3 |
+
import ast
|
4 |
+
import json
|
5 |
+
import astunparse
|
6 |
+
import concurrent.futures
|
7 |
+
import traceback
|
8 |
+
|
9 |
+
|
10 |
+
def get_call_str(assert_statement: str) -> str:
    """Return the source text of the left-hand side of an assert's comparison.

    The first statement of ``assert_statement`` is expected to be of the form
    ``assert <left> <op> <right>``; ``<left>`` is unparsed back to source and
    stripped of surrounding whitespace. Any other shape raises from the
    attribute/index access.
    """
    module = ast.parse(assert_statement)
    first_statement = module.body[0]
    left_operand = first_statement.test.left  # type: ignore
    source_text = astunparse.unparse(left_operand)
    return source_text.strip()
|
13 |
+
|
14 |
+
def get_output(func: str, assert_statement: str) -> str:
    """Evaluate the call on the left-hand side of an assert against ``func``.

    Args:
        func: Source code defining the function under test.
        assert_statement: An assert statement; the left operand of its
            comparison is the expression that gets evaluated.

    Returns:
        The value the expression evaluates to (not necessarily a string,
        despite the annotation), the message of the exception raised while
        defining or calling the function, or the literal string
        ``"get_call_str error"`` when the assert statement cannot be analysed.

    SECURITY: ``func`` and the extracted call are run with ``exec``/``eval``
    in this module's global namespace. Only feed this trusted code or run it
    inside a sandbox.
    """
    # Stage 1: extract the call expression. Only ordinary errors are mapped
    # to the sentinel; KeyboardInterrupt/SystemExit are left to propagate
    # instead of being swallowed by a bare `except`.
    try:
        func_call = get_call_str(assert_statement)
    except Exception:
        return "get_call_str error"

    # Stage 2: define the function and evaluate the call.
    try:
        exec(func, globals())
        return eval(func_call)
    except Exception as e:
        return str(e)
|
25 |
+
|
26 |
+
def worker(code, globals=None, locals=None):
    """Execute ``code`` while capturing everything it prints.

    Args:
        code: Python source to execute.
        globals: Not used for execution; handed back unchanged.
        locals: Namespace the code runs in (used as both globals and locals).
            A fresh dict is created when omitted.

    Returns:
        A tuple ``(output, globals, locals)`` where ``output`` is the captured
        stdout text, or ``"Error: "`` followed by the formatted traceback when
        the code raised an ``Exception``.

    ``sys.stdout`` is swapped for the duration of the call and always put
    back afterwards.
    """
    namespace = {} if locals is None else locals
    original_stdout = sys.stdout
    capture = io.StringIO()
    sys.stdout = capture
    try:
        # TODO: exec(code, globals, locals) could be buggy
        # In cases where both import statement and function exits in the code, if the locals are given,
        # the code will not find the imported package.
        # For example,
        # code = "import math\ndef f(x):\n\treturn math.pow(x, 2)\nassert f(2) == 4"
        # It will return "NameError: name 'math' is not defined"
        # Hence the same dict is passed as globals and locals.
        exec(code, namespace, namespace)
    except Exception:
        return f"Error: {traceback.format_exc()}", globals, namespace
    else:
        return capture.getvalue(), globals, namespace
    finally:
        sys.stdout = original_stdout  # restore the original stdout
|
46 |
+
|
47 |
+
def execute_code(code: str) -> str:
    """Execute a snippet of python code and return the output or the error message.

    The snippet runs in a worker thread via ``worker``. If no result is
    available after 5 seconds, the string ``"Timeout"`` is returned.

    The executor is shut down with ``wait=False``. Using the executor as a
    context manager would wait for the worker thread on exit, so a slow or
    non-terminating snippet would block this function in spite of the
    timeout. A timed-out snippet cannot be killed; its thread keeps running
    in the background.
    """
    timeout = 5
    # Remember the real stdout: `worker` redirects sys.stdout and only puts
    # it back once the snippet finishes, which may be long after a timeout.
    saved_stdout = sys.stdout
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(worker, code)
        result, _, _ = future.result(timeout)
        return result
    except concurrent.futures.TimeoutError:
        sys.stdout = saved_stdout
        return "Timeout"
    finally:
        executor.shutdown(wait=False)
|
58 |
+
|
59 |
+
def execute_unit_tests(func_impl: str, tests: str) -> str:
    """Run a python function on a bunch of unit tests and return detailed feedback.

    Args:
        func_impl: Source code of the function under test.
        tests: The test statements; iterated element by element, each element
            being appended to ``func_impl`` and executed on its own.

    Returns:
        A JSON string with the keys ``is_passing`` (true when no test failed)
        and ``feedback`` (a text listing passed and failed tests; failed ones
        are annotated with the observed output).
    """
    # tests = eval(tests)
    # assert type(tests) == list

    passed = []
    failed = []
    for test in tests:
        # Combine function code and assert statement, then run them together
        output = execute_code(f'{func_impl}\n{test}')
        if output == "Timeout":
            failed.append(f"{test} # output: Timeout")
        elif output.startswith("Error: "):
            # Prefer showing what the call actually produced; fall back to the
            # traceback text when the call cannot be extracted from the test.
            func_output = get_output(func_impl, test)
            if func_output == "get_call_str error":
                func_output = output
            failed.append(f"{test} # output: {func_output}")
        else:
            passed.append(test)

    sections = ["Tested passed:\n\n"]
    sections.extend(f"{test}\n\n" for test in passed)
    sections.append("Tests failed:\n\n")
    sections.extend(f"{test}\n\n" for test in failed)
    feedback = "".join(sections)

    return json.dumps({"is_passing": not failed,
                       "feedback": feedback})
|
97 |
+
|
agentverse/environments/simulation_env/rules/selector/pokemon.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, List
|
4 |
+
import numpy as np
|
5 |
+
import json
|
6 |
+
|
7 |
+
from agentverse.message import Message
|
8 |
+
|
9 |
+
from . import selector_registry as SelectorRegistry
|
10 |
+
from .base import BaseSelector
|
11 |
+
|
12 |
+
if TYPE_CHECKING:
|
13 |
+
from agentverse.environments import PokemonEnvironment
|
14 |
+
|
15 |
+
|
16 |
+
@SelectorRegistry.register("pokemon")
class PokemonSelector(BaseSelector):
    """
    Selector for Pokemon environment

    Message contents are expected to be JSON objects with an "action" key
    ("Speak", "MoveTo" or anything else) and, for "Speak"/"MoveTo", a "to"
    key naming the receiving agent or the target location.
    """

    def select_message(
        self, environment: PokemonEnvironment, messages: List[Message]
    ) -> List[Message]:
        """Resolve conflicts between the actions taken in one turn.

        Discarded messages are replaced by an empty ``Message()`` rather than
        removed.

        NOTE(review): the "Shouldn't happen" branch below appends nothing, so
        in that case the result is shorter than ``messages`` - confirm
        whether callers rely on positional alignment.
        """
        talk_matrix = np.zeros((len(environment.agents), len(environment.agents)))
        agent_to_idx = {agent.name: i for i, agent in enumerate(environment.agents)}

        # Pass 1: parse every message once, decide whether it is valid, and
        # record who talks to whom. `contents` and `valid` always get exactly
        # one entry per message.
        contents = []
        valid = []
        for message in messages:
            try:
                content = json.loads(message.content)
            except json.decoder.JSONDecodeError:
                contents.append(None)
                valid.append(0)
                continue
            if not isinstance(content, dict) or "action" not in content:
                # Well-formed JSON that does not describe an action.
                contents.append(None)
                valid.append(0)
                continue
            contents.append(content)
            if content["action"] == "Speak":
                try:
                    sender_idx = agent_to_idx[message.sender]
                    receiver_idx = agent_to_idx[content["to"]]
                except (KeyError, TypeError):
                    # No receiver generated, receiver (or sender) not in the
                    # environment, or an unhashable "to": discard the message
                    valid.append(0)
                    continue
                # TODO: allow talk to a list of agents
                # talk_matrix[i][j] = 1 ==> i talk to j
                talk_matrix[sender_idx][receiver_idx] = 1
                valid.append(1)
            elif content["action"] == "MoveTo":
                # If the agent move to a location that does not exist, then we discard the message
                valid.append(
                    "to" in content and content["to"] in environment.locations_to_agents
                )
            else:
                valid.append(1)

        # Pass 2: resolve conflicts between the valid actions.
        selected_messages = []
        for message, content, is_valid in zip(messages, contents, valid):
            if not is_valid:
                # Checked first so that malformed messages are never parsed
                # or looked up again.
                selected_messages.append(Message())
                continue
            sender_idx = agent_to_idx[message.sender]
            if content["action"] == "MoveTo":
                if np.sum(talk_matrix[:, sender_idx]) > 0:
                    # If someone talk to this agent, then we discard the move action
                    selected_messages.append(Message())
                else:
                    selected_messages.append(message)
            elif content["action"] == "Speak":
                receiver_idx = agent_to_idx[content["to"]]
                if talk_matrix[sender_idx][receiver_idx] == 0:
                    # If this agent talk to someone who also talk to this agent, and we
                    # select the message from this agent, then we discard the message
                    selected_messages.append(Message())
                    continue
                if np.sum(talk_matrix[receiver_idx, :]) > 0:
                    if talk_matrix[receiver_idx][sender_idx] == 1:
                        # If the receiver talk to this agent, then we randomly select one message
                        if sender_idx < receiver_idx:
                            if np.random.random() < 0.5:
                                selected_messages.append(message)
                                talk_matrix[receiver_idx][sender_idx] = 0
                            else:
                                selected_messages.append(Message())
                                talk_matrix[sender_idx][receiver_idx] = 0
                        else:
                            print("Shouldn't happen")
                    else:
                        # If the receiver talk to other agent, we still talk to the receiver (?)
                        selected_messages.append(message)
                else:
                    selected_messages.append(message)
            else:
                selected_messages.append(message)
        return selected_messages
|
agentverse/environments/simulation_env/rules/selector/sde_team.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, List
|
4 |
+
|
5 |
+
from agentverse.message import Message
|
6 |
+
|
7 |
+
from . import selector_registry as SelectorRegistry
|
8 |
+
from .base import BaseSelector
|
9 |
+
|
10 |
+
import json
|
11 |
+
import re
|
12 |
+
|
13 |
+
if TYPE_CHECKING:
|
14 |
+
from agentverse.environments import BaseEnvironment
|
15 |
+
|
16 |
+
def extract(content: str, keyword: str) -> str:
    """Return the text that follows the first line starting with ``keyword``.

    ``content`` is split on newlines. Lines whose stripped form starts with
    ``keyword`` are markers: the first one switches collection on, and every
    marker line itself is left out. Each collected line is terminated with a
    newline. Returns an empty string when no marker line is found.
    """
    collected = []
    marker_seen = False
    for line in content.split("\n"):
        if line.strip().startswith(keyword):
            marker_seen = True
            continue
        if marker_seen:
            collected.append(line + "\n")
    # Joined once at the end instead of growing a string inside the loop.
    return "".join(collected)
|
27 |
+
|
28 |
+
|
29 |
+
@SelectorRegistry.register("sde_team")
class SdeTeamSelector(BaseSelector):
    """Selector for the "sde_team" scenario.

    Post-processes the messages of a turn depending on the sender of the
    first message of the previous turn, and keeps the shared state ("code",
    "unit_tests", "feedback", "end_flag") in ``environment.rule_params``.
    """

    def select_message(self, environment: BaseEnvironment, messages: List[Message]) -> List[Message]:
        """Rewrite the turn's messages and update ``environment.rule_params``."""
        rule_params = environment.rule_params
        previous_sender = environment.last_messages[0].sender
        chosen = messages

        if previous_sender == "unit_test_generator":
            # Collect the distinct unit tests proposed in this turn and
            # replace all messages by one empty placeholder message.
            distinct_tests = list(
                {extract(msg.content, "<unit test>:") for msg in chosen}
            )
            rule_params["unit_tests"] = str(distinct_tests)
            placeholder = Message(
                content="",
                sender="unit_test_generator",
                receiver=[],
            )  # TODO: set the content of the message
            chosen = [placeholder]

        elif previous_sender == "code_writer":
            written_code = extract(chosen[0].content, "<code>:")
            rule_params["code"] = written_code

            from .code_api import execute_unit_tests

            # NOTE(security): eval() turns the stored string back into a list;
            # the string must never come from an untrusted source.
            test_cases = eval(rule_params["unit_tests"])
            report = execute_unit_tests(rule_params["code"], test_cases)

            rule_params["feedback"] = report
            chosen[0].content = f"<current code>:\n\n{written_code}\n\n<unit test feedback>:\n{report}"
            if json.loads(report)["is_passing"]:
                rule_params["end_flag"] = True

        elif previous_sender == "code_reviewer":
            review_text = chosen[0].content
            current_code = rule_params["code"]
            chosen[0].content = f"<current code>:\n\n{current_code}\n\n{review_text}"
            # Re-check the last recorded test run; stop once it passed.
            last_report = rule_params["feedback"]
            if json.loads(last_report)["is_passing"]:
                rule_params["end_flag"] = True

        return chosen
|
agentverse/environments/simulation_env/rules/selector/sde_team_given_tests.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, List
|
4 |
+
|
5 |
+
from agentverse.message import Message
|
6 |
+
|
7 |
+
from . import selector_registry as SelectorRegistry
|
8 |
+
from .base import BaseSelector
|
9 |
+
|
10 |
+
import json
|
11 |
+
import re
|
12 |
+
|
13 |
+
if TYPE_CHECKING:
|
14 |
+
from agentverse.environments import BaseEnvironment
|
15 |
+
|
16 |
+
def extract(content: str, keyword: str) -> str:
    """Return the text that follows the first line starting with ``keyword``.

    ``content`` is split on newlines. Lines whose stripped form starts with
    ``keyword`` are markers: the first one switches collection on, and every
    marker line itself is left out. Each collected line is terminated with a
    newline. Returns an empty string when no marker line is found.
    """
    collected = []
    marker_seen = False
    for line in content.split("\n"):
        if line.strip().startswith(keyword):
            marker_seen = True
            continue
        if marker_seen:
            collected.append(line + "\n")
    # Joined once at the end instead of growing a string inside the loop.
    return "".join(collected)
|
27 |
+
|
28 |
+
|
29 |
+
@SelectorRegistry.register("sde_team_given_tests")
class SdeTeamGivenTestsSelector(BaseSelector):
    """Selector for the "sde_team_given_tests" scenario.

    Rewrites the first message of the turn depending on the sender of the
    first message of the previous turn, and stores "code", "feedback" and
    "end_flag" in ``environment.rule_params``.
    """

    def select_message(self, environment: BaseEnvironment, messages: List[Message]) -> List[Message]:
        """Rewrite the turn's first message and update ``environment.rule_params``."""
        rule_params = environment.rule_params
        previous_sender = environment.last_messages[0].sender
        chosen = messages

        if previous_sender == "code_writer":
            written_code = extract(chosen[0].content, "<code>:")
            rule_params["code"] = written_code
            chosen[0].content = f"<current code>:\n{written_code}"

        elif previous_sender == "code_tester":

            from .code_api import execute_unit_tests

            # NOTE(security): eval() parses the tests stored on the
            # environment; they must never come from an untrusted source.
            given_tests = eval(environment.unit_tests)
            report = execute_unit_tests(rule_params["code"], given_tests)
            rule_params["feedback"] = report
            chosen[0].content = f"<unit test feedback>:\n{report}"

            if json.loads(report)["is_passing"]:
                rule_params["end_flag"] = True

        elif previous_sender == "code_reviewer":
            review_text = chosen[0].content
            # The value is unused, but the lookup is kept: it raises KeyError
            # when no code has been recorded yet.
            current_code = rule_params["code"]
            chosen[0].content = f"{review_text}"

        return chosen
|
agentverse/environments/simulation_env/rules/updater/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from agentverse.registry import Registry
|
2 |
+
|
3 |
+
updater_registry = Registry(name="UpdaterRegistry")
|
4 |
+
|
5 |
+
from .base import BaseUpdater
|
6 |
+
from .basic import BasicUpdater
|
7 |
+
from .classroom import ClassroomUpdater
|
8 |
+
from .sde_team import SdeTeamUpdater
|
9 |
+
from .pokemon import PokemonUpdater
|
agentverse/environments/simulation_env/rules/updater/base.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, List, Tuple
|
4 |
+
|
5 |
+
from pydantic import BaseModel
|
6 |
+
|
7 |
+
# from agentverse.agents import Agent
|
8 |
+
from abc import abstractmethod
|
9 |
+
|
10 |
+
from . import updater_registry as UpdaterRegistry
|
11 |
+
|
12 |
+
if TYPE_CHECKING:
|
13 |
+
from agentverse.environments import BaseEnvironment
|
14 |
+
|
15 |
+
|
16 |
+
@UpdaterRegistry.register("base")
class BaseUpdater(BaseModel):
    """
    Common interface of all updaters.

    Concrete updaters implement `update_memory`, which receives the
    environment.
    """

    @abstractmethod
    def update_memory(self, environment: BaseEnvironment):
        """Update the agents' memories; implemented by subclasses."""
        ...

    def reset(self):
        """Reset any internal state; the base implementation keeps none."""
        return None
|
agentverse/environments/simulation_env/rules/updater/basic.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import TYPE_CHECKING, List, Tuple
|
4 |
+
|
5 |
+
from . import updater_registry as UpdaterRegistry
|
6 |
+
from .base import BaseUpdater
|
7 |
+
from agentverse.message import Message
|
8 |
+
from agentverse.logging import get_logger
|
9 |
+
|
10 |
+
if TYPE_CHECKING:
|
11 |
+
from agentverse.environments import BaseEnvironment
|
12 |
+
from agentverse.agents import BaseAgent
|
13 |
+
|
14 |
+
logger = get_logger()
|
15 |
+
|
16 |
+
|
17 |
+
@UpdaterRegistry.register("basic")
class BasicUpdater(BaseUpdater):
    """
    The basic version of updater.
    The messages will be seen by all the receiver specified in the message.
    """

    def update_memory(self, environment: BaseEnvironment):
        """Distribute the messages of the last turn to the agents' memories.

        Tool responses go to the sending agent's tool memory. Messages with
        empty content are skipped. When no message was distributed at all,
        every agent receives a ``[Silence]`` message instead.
        """
        added = False
        for message in environment.last_messages:
            if len(message.tool_response) > 0:
                self.add_tool_response(
                    message.sender, environment.agents, message.tool_response
                )
            if message.content == "":
                continue
            added |= self.add_message_to_all_agents(environment.agents, message)
        # If no one speaks in this turn. Add an empty message to all agents
        if not added:
            for agent in environment.agents:
                agent.add_message_to_memory([Message(content="[Silence]")])

    def add_tool_response(
        self,
        name: str,
        agents: List[BaseAgent],
        tool_response: List[str],
    ):
        """Store ``tool_response`` in the tool memory of the agent called ``name``.

        Only the first agent with a matching name is considered; nothing is
        stored when that agent has no tool memory.
        """
        for agent in agents:
            if agent.name != name:
                continue
            if agent.tool_memory is not None:
                agent.tool_memory.add_message(tool_response)
            break

    def add_message_to_all_agents(
        self, agents: List[BaseAgent], message: Message
    ) -> bool:
        """Add ``message`` to the memory of each of its receivers.

        Returns:
            Always ``True``, including when some or all of the named
            receivers could not be found (a warning is logged in that case).
        """
        if "all" in message.receiver:
            # If receiver is all, then add the message to all agents
            for agent in agents:
                agent.add_message_to_memory([message])
            return True
        else:
            # If receiver is not all, then add the message to the specified agents.
            # Track the still-unmatched names on a COPY: removing names from
            # `message.receiver` itself would strip the receivers from the
            # very message object that is being stored in the memories.
            pending_receivers = list(message.receiver)
            for agent in agents:
                if agent.name in pending_receivers:
                    agent.add_message_to_memory([message])
                    pending_receivers.remove(agent.name)
            if len(pending_receivers) > 0:
                missing_receiver = ", ".join(pending_receivers)
                # raise ValueError(
                #     "Receiver {} not found. Message discarded".format(missing_receiver)
                # )
                logger.warn(
                    "Receiver {} not found. Message discarded".format(missing_receiver)
                )
            return True
|