|
import copy
import inspect
import warnings
from collections import OrderedDict, UserDict, UserList, abc
from functools import wraps
from itertools import chain, repeat
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union

from lagent.agents.aggregator import DefaultAggregator
from lagent.hooks import Hook, RemovableHandle
from lagent.llms import BaseLLM
from lagent.memory import Memory, MemoryManager
from lagent.prompts.parsers import StrParser
from lagent.prompts.prompt_template import PromptTemplate
from lagent.schema import AgentMessage
from lagent.utils import create_object
|
|
|
|
|
class Agent:
    """Agent is the basic unit of the system. It is responsible for
    communicating with the LLM, managing the memory, and handling the
    message aggregation and parsing. It can also be extended with hooks.

    Args:
        llm (Union[BaseLLM, Dict]): The language model used by the agent, as
            an instance or a config passed to ``create_object``.
        template (Union[PromptTemplate, str, dict, List[dict]]): The template
            used to format the messages.
        memory (Dict): The config used to build the agent's
            ``MemoryManager``. A falsy value disables memory.
        output_format (Optional[Dict]): The output parser (or its config)
            applied to the raw LLM response.
        aggregator (Dict): The aggregator (or its config) that assembles the
            LLM input from memory.
        name (Optional[str]): The name of the agent. Defaults to the class
            name.
        description (Optional[str]): The description of the agent.
        hooks (Optional[Union[List[Dict], Dict]]): The hooks used by the
            agent: a single hook (or hook config) or an iterable of them.

    Returns:
        AgentMessage: The response message.
    """

    def __init__(
        self,
        llm: Union[BaseLLM, Dict] = None,
        template: Union[PromptTemplate, str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Optional[Dict] = None,
        aggregator: Dict = dict(type=DefaultAggregator),
        name: Optional[str] = None,
        description: Optional[str] = None,
        hooks: Optional[Union[List[Dict], Dict]] = None,
    ):
        self.name = name or self.__class__.__name__
        self.llm: BaseLLM = create_object(llm)
        self.memory: MemoryManager = MemoryManager(memory) if memory else None
        self.output_format: StrParser = create_object(output_format)
        self.template = template
        self.description = description
        self.aggregator: DefaultAggregator = create_object(aggregator)
        self._hooks: Dict[int, Hook] = OrderedDict()
        if hooks:
            # Accept a single hook / hook config as well as a collection of
            # them; iterating a config dict directly would yield its keys.
            if isinstance(hooks, Mapping) or not isinstance(hooks, Iterable):
                hooks = [hooks]
            for hook in hooks:
                self.register_hook(create_object(hook))

    def update_memory(self, message, session_id=0):
        """Add ``message`` to the memory of ``session_id`` (no-op without
        memory)."""
        if self.memory:
            self.memory.add(message, session_id=session_id)

    def __call__(
        self,
        *message: Union[str, AgentMessage, List[AgentMessage]],
        session_id=0,
        **kwargs,
    ) -> AgentMessage:
        """Run one agent turn.

        Plain strings are wrapped into user ``AgentMessage``s, other inputs
        are deep-copied. ``before_agent`` hooks may replace the inputs (any
        truthy return value is used). The inputs and the response are stored
        in memory. A non-``AgentMessage`` response from ``forward`` is wrapped
        with this agent as sender. A deep copy of the response is passed
        through the ``after_agent`` hooks and returned.
        """
        message = [
            AgentMessage(sender='user', content=m)
            if isinstance(m, str) else copy.deepcopy(m) for m in message
        ]
        for hook in self._hooks.values():
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
        response_message = self.forward(
            *message, session_id=session_id, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.name,
                content=response_message,
            )
        self.update_memory(response_message, session_id=session_id)
        # Hooks work on a copy so they cannot mutate what was stored in memory.
        response_message = copy.deepcopy(response_message)
        for hook in self._hooks.values():
            result = hook.after_agent(self, response_message, session_id)
            if result:
                response_message = result
        return response_message

    def forward(self,
                *message: AgentMessage,
                session_id=0,
                **kwargs) -> Union[AgentMessage, str]:
        """Query the LLM with the aggregated memory of ``session_id``.

        ``message`` is not read here; the inputs already reach the LLM through
        memory (``__call__`` stores them before calling ``forward``).

        Returns:
            Union[AgentMessage, str]: The raw LLM response, or, when an
            ``output_format`` is set, an ``AgentMessage`` whose ``content`` is
            the raw response and ``formatted`` is the parsed one.
        """
        formatted_messages = self.aggregator.aggregate(
            self.memory.get(session_id),
            self.name,
            self.output_format,
            self.template,
        )
        llm_response = self.llm.chat(formatted_messages, **kwargs)
        if self.output_format:
            parsed_response = self.output_format.parse_response(llm_response)
            return AgentMessage(
                sender=self.name,
                content=llm_response,
                formatted=parsed_response,
            )
        return llm_response

    def __setattr__(self, __name: str, __value: Any) -> None:
        # Agent-valued attributes are registered as sub-agents in ``_agents``
        # so that state_dict/reset/repr can traverse them.
        if isinstance(__value, Agent):
            _agents = getattr(self, '_agents', OrderedDict())
            _agents[__name] = __value
            super().__setattr__('_agents', _agents)
        super().__setattr__(__name, __value)

    def _iter_agent_tree(self):
        """Yield ``(prefix, agent)`` for this agent and every nested sub-agent.

        Traversal is depth-first in ``_agents`` order. ``prefix`` is the dotted
        path to the agent ('' for ``self``, e.g. ``'sub.'`` for a child).
        """
        stack = [('', self)]
        while stack:
            prefix, node = stack.pop()
            yield prefix, node
            if hasattr(node, '_agents'):
                for name, child in reversed(node._agents.items()):
                    stack.append((prefix + name + '.', child))

    def state_dict(self, session_id=0):
        """Collect the saved memory of ``session_id`` for the agent tree.

        Returns:
            dict: Maps ``'<path>memory'`` (e.g. ``'memory'``,
            ``'sub.memory'``) to the result of ``memory.save()``, or ``[]``
            when there is no memory content. Agents without memory are
            skipped. A warning is emitted for agents lacking the session.
        """
        state_dict = {}
        for prefix, node in self._iter_agent_tree():
            if node.memory is None:
                continue
            key = prefix + 'memory'
            if session_id not in node.memory.memory_map:
                warnings.warn(f'No session id {session_id} in {key}')
            memory = node.memory.get(session_id)
            state_dict[key] = memory and memory.save() or []
        return state_dict

    def load_state_dict(self, state_dict: Dict, session_id=0):
        """Load memories produced by :meth:`state_dict` into ``session_id``.

        Keys are resolved with the same traversal that produced them, so
        sub-agents stored only in ``_agents`` (e.g. members of ``Sequential``)
        are found as well.

        Raises:
            KeyError: If an agent with memory has no entry in ``state_dict``.
        """
        nodes = OrderedDict(
            (prefix + 'memory', node)
            for prefix, node in self._iter_agent_tree()
            if node.memory is not None)
        missing_keys = set(nodes) - set(state_dict)
        if missing_keys:
            raise KeyError(f'Missing keys: {missing_keys}')
        extra_keys = set(state_dict) - set(nodes)
        if extra_keys:
            warnings.warn(f'Mismatch keys which are not used: {extra_keys}')
        for key, node in nodes.items():
            if session_id not in node.memory.memory_map:
                node.memory.create_instance(session_id)
            node.memory.memory_map[session_id].load(state_dict[key] or [])

    def register_hook(self, hook: Callable):
        """Register ``hook`` and return a handle that can remove it."""
        handle = RemovableHandle(self._hooks)
        self._hooks[handle.id] = hook
        return handle

    def reset(self,
              session_id=0,
              keypath: Optional[str] = None,
              recursive: bool = False):
        """Reset the memory of ``session_id``.

        Args:
            session_id: The session to reset.
            keypath (Optional[str]): Dotted path of a sub-agent (through
                ``_agents``) whose memory is reset instead of this agent's.
            recursive (bool): Also reset every sub-agent. Cannot be combined
                with ``keypath``.

        Raises:
            KeyError: If ``keypath`` names a missing sub-agent.
        """
        assert not (keypath and
                    recursive), 'keypath and recursive can\'t be used together'
        if keypath:
            keys, agent = keypath.split('.'), self
            for key in keys:
                agents = getattr(agent, '_agents', {})
                if key not in agents:
                    raise KeyError(f'No sub-agent named {key} in {agent}')
                agent = agents[key]
            agent.reset(session_id, recursive=False)
        else:
            if self.memory:
                self.memory.reset(session_id=session_id)
            if recursive:
                for agent in getattr(self, '_agents', {}).values():
                    agent.reset(session_id, recursive=True)

    def __repr__(self):

        def _rcsv_repr(agent, n_indent=1):
            # One line per sub-agent, indented by nesting depth.
            res = agent.__class__.__name__ + (
                f"(name='{agent.name}')" if agent.name else '')
            modules = [
                f"{n_indent * ' '}({name}): {_rcsv_repr(child, n_indent + 1)}"
                for name, child in getattr(agent, '_agents', {}).items()
            ]
            if modules:
                res += '(\n' + '\n'.join(
                    modules) + f'\n{(n_indent - 1) * " "})'
            elif not res.endswith(')'):
                res += '()'
            return res

        return _rcsv_repr(self)
|
|
|
|
|
class AsyncAgent(Agent):
    """Asynchronous variant of ``Agent``: ``__call__`` and ``forward`` are
    coroutines and the LLM's ``chat`` is awaited. Hooks are still called
    synchronously."""

    async def __call__(self,
                       *message: Union[str, AgentMessage, List[AgentMessage]],
                       session_id=0,
                       **kwargs) -> AgentMessage:
        """Run one agent turn asynchronously.

        Plain strings are wrapped into user ``AgentMessage``s, other inputs
        are deep-copied. ``before_agent`` hooks may replace the inputs (any
        truthy return value is used). The inputs and the response are stored
        in memory. A non-``AgentMessage`` response from ``forward`` is wrapped
        with this agent as sender. A deep copy of the response goes through
        the ``after_agent`` hooks and is returned.
        """
        message = [
            AgentMessage(sender='user', content=m)
            if isinstance(m, str) else copy.deepcopy(m) for m in message
        ]
        for hook in self._hooks.values():
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
        response_message = await self.forward(
            *message, session_id=session_id, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.name,
                content=response_message,
            )
        self.update_memory(response_message, session_id=session_id)
        # Hooks work on a copy so they cannot mutate what was stored in memory.
        response_message = copy.deepcopy(response_message)
        for hook in self._hooks.values():
            result = hook.after_agent(self, response_message, session_id)
            if result:
                response_message = result
        return response_message

    async def forward(self,
                      *message: AgentMessage,
                      session_id=0,
                      **kwargs) -> Union[AgentMessage, str]:
        """Await the LLM on the aggregated memory of ``session_id``.

        ``session_id`` is passed positionally as the second argument of the
        LLM's ``chat`` (presumably its session-id parameter; confirm against
        the async LLM implementation).

        Returns:
            Union[AgentMessage, str]: The raw LLM response, or, when an
            ``output_format`` is set, an ``AgentMessage`` carrying both the
            raw (``content``) and parsed (``formatted``) response.
        """
        formatted_messages = self.aggregator.aggregate(
            self.memory.get(session_id),
            self.name,
            self.output_format,
            self.template,
        )
        llm_response = await self.llm.chat(formatted_messages, session_id,
                                           **kwargs)
        if self.output_format:
            parsed_response = self.output_format.parse_response(llm_response)
            return AgentMessage(
                sender=self.name,
                content=llm_response,
                formatted=parsed_response,
            )
        return llm_response
|
|
|
|
|
class Sequential(Agent):
    """Sequential is an agent container that forwards messages to each agent
    in the order they are added.

    Agents can be given as positional arguments, or as a single iterable of
    agents, ``(name, agent)`` tuples, or a ``{name: agent}`` mapping. Agents
    given without a name are keyed by their position.
    """

    def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs):
        super().__init__(**kwargs)
        self._agents = OrderedDict()
        if not agents:
            raise ValueError('At least one agent should be provided')
        if isinstance(agents[0],
                      Iterable) and not isinstance(agents[0], Agent):
            agents = agents[0]
        for key, agent in enumerate(agents):
            if isinstance(agents, Mapping):
                key, agent = agent, agents[agent]
            elif isinstance(agent, tuple):
                key, agent = agent
            self.add_agent(key, agent)
        # Checked after consumption: the truthiness of an iterator/generator
        # says nothing about whether it yields anything.
        if not self._agents:
            raise ValueError('At least one agent should be provided')

    def add_agent(self, name: str, agent: Union[Agent, AsyncAgent]):
        """Append ``agent`` under ``str(name)``."""
        assert isinstance(
            agent, (Agent, AsyncAgent
                    )), f'{type(agent)} is not an Agent or AsyncAgent subclass'
        self._agents[str(name)] = agent

    def forward(self,
                *message: AgentMessage,
                session_id=0,
                exit_at: Optional[int] = None,
                **kwargs) -> AgentMessage:
        """Pipe ``message`` through the agents in order.

        Args:
            exit_at (Optional[int]): 0-based index of the last step to run.
                Defaults to the last agent. Larger values cycle back to the
                first agent.
        """
        assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
        if exit_at is None:
            exit_at = len(self) - 1
        iterator = chain.from_iterable(repeat(self._agents.values()))
        for _ in range(exit_at + 1):
            agent = next(iterator)
            if isinstance(message, AgentMessage):
                message = (message, )
            message = agent(*message, session_id=session_id, **kwargs)
        return message

    def __getitem__(self, key):
        """Look up an agent by name, or by (possibly negative) int index.

        Raises:
            IndexError: For an int index outside the container that is not
                also an agent name. This lets the legacy ``__getitem__``
                iteration protocol terminate.
            KeyError: For an unknown name.
        """
        if isinstance(key, int):
            size = len(self)
            if key < 0:
                if key < -size:
                    raise IndexError('index out of range')
                key += size
            elif key >= size and str(key) not in self._agents:
                raise IndexError('index out of range')
        return self._agents[str(key)]

    def __len__(self):
        return len(self._agents)
|
|
|
|
|
class AsyncSequential(Sequential, AsyncAgent):
    """Asynchronous ``Sequential``. Members may be ``AsyncAgent``s (awaited)
    or plain ``Agent``s (called synchronously)."""

    async def forward(self,
                      *message: AgentMessage,
                      session_id=0,
                      exit_at: Optional[int] = None,
                      **kwargs) -> AgentMessage:
        """Pipe ``message`` through the agents in order.

        Args:
            exit_at (Optional[int]): 0-based index of the last step to run.
                Defaults to the last agent. Larger values cycle back to the
                first agent.
        """
        assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
        if exit_at is None:
            exit_at = len(self) - 1
        iterator = chain.from_iterable(repeat(self._agents.values()))
        for _ in range(exit_at + 1):
            agent = next(iterator)
            if isinstance(message, AgentMessage):
                message = (message, )
            message = agent(*message, session_id=session_id, **kwargs)
            # Sync members return their result directly; awaiting it would
            # raise TypeError, so only await real awaitables.
            if inspect.isawaitable(message):
                message = await message
        return message
|
|
|
|
|
class AgentContainerMixin:
    """Keep ``_agents`` in sync with the ``data`` of a UserList/UserDict.

    When a class mixing this in is created, each mutating container method in
    the list below that the class has is wrapped. After the wrapped call,
    ``self.data`` is validated and ``self._agents`` is rebuilt from it as an
    ``OrderedDict`` keyed by ``str(index)`` / ``str(key)``.

    If validation fails, ``self.data`` is restored to a shallow copy of its
    pre-call contents (or cleared if it did not exist before the call), and
    the error is raised:

    * ``KeyError`` for a non-string mapping key or a string key containing
      ``'.'``;
    * ``TypeError`` for an item that is not an ``Agent``/``AsyncAgent``.
    """

    def __init_subclass__(cls, **class_kwargs):
        # Forward class keywords so this mixin cooperates with other bases
        # that define __init_subclass__ (PEP 487).
        super().__init_subclass__(**class_kwargs)

        def wrap_api(func):

            @wraps(func)
            def wrapped_func(self, *args, **kwargs):
                # Rollback snapshot; None means ``data`` did not exist yet
                # (e.g. during ``__init__``).
                data = self.data.copy() if hasattr(self, 'data') else None

                def _restore(d):
                    if d is None:
                        self.data.clear()
                    else:
                        self.data = d

                ret = func(self, *args, **kwargs)
                is_mapping = isinstance(self.data, abc.Mapping)
                entries = (
                    self.data.items() if is_mapping else enumerate(self.data))
                agents = OrderedDict()
                for k, item in entries:
                    if is_mapping and not isinstance(k, str):
                        _restore(data)
                        raise KeyError(
                            f'agent name should be a string, got {type(k)}')
                    if isinstance(k, str) and '.' in k:
                        _restore(data)
                        raise KeyError(
                            f'agent name can\'t contain ".", got {k}')
                    if not isinstance(item, (Agent, AsyncAgent)):
                        _restore(data)
                        raise TypeError(
                            f'{type(item)} is not an Agent or AsyncAgent subclass'
                        )
                    agents[str(k)] = item
                self._agents = agents
                return ret

            # Marker so subclasses inheriting this wrapper do not wrap it again.
            wrapped_func._agent_container_wrapped = True
            return wrapped_func

        for method in [
                'append', 'sort', 'reverse', 'pop', 'clear', 'update',
                'insert', 'extend', 'remove', '__init__', '__setitem__',
                '__delitem__', '__add__', '__iadd__', '__radd__', '__mul__',
                '__imul__', '__rmul__'
        ]:
            func = getattr(cls, method, None)
            if func is None or getattr(func, '_agent_container_wrapped',
                                       False):
                continue
            setattr(cls, method, wrap_api(func))
|
|
|
|
|
class AgentList(Agent, UserList, AgentContainerMixin):
    """List container of agents.

    ``AgentContainerMixin`` mirrors the items into ``_agents`` under their
    stringified index and rejects non-agent items. The container itself has
    no memory and its ``name`` is ``None``.
    """

    def __init__(self,
                 agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None):
        # Agent part first: no MemoryManager is created for the container.
        Agent.__init__(self, memory=None)
        UserList.__init__(self, agents)
        # Agent.__init__ defaulted the name to the class name; clear it.
        self.name = None
|
|
|
|
|
class AgentDict(Agent, UserDict, AgentContainerMixin):
    """Dict container of agents keyed by name.

    ``AgentContainerMixin`` mirrors the items into ``_agents`` and rejects
    non-string keys, keys containing ``'.'``, and non-agent values. The
    container itself has no memory and its ``name`` is ``None``.

    Args:
        agents (Optional[Mapping[str, Union[Agent, AsyncAgent]]]): Initial
            name -> agent mapping.
        **kwargs: Additional agents given as keyword arguments, e.g.
            ``AgentDict(planner=a, executor=b)``.
    """

    def __init__(self,
                 agents: Optional[Mapping[str, Union[Agent,
                                                     AsyncAgent]]] = None,
                 **kwargs):
        # No MemoryManager is created for the container itself.
        Agent.__init__(self, memory=None)
        UserDict.__init__(self, agents, **kwargs)
        # Agent.__init__ defaulted the name to the class name; clear it.
        self.name = None
|
|