|
from typing import Dict, List, Optional, Union |
|
|
|
from lagent.agents.aggregator.default_aggregator import DefaultAggregator |
|
from lagent.memory.base_memory import Memory |
|
from lagent.prompts.parsers.tool_parser import MixedToolParser, ToolParser, ToolStatusCode |
|
|
|
|
|
class InternLMToolAggregator(DefaultAggregator):
    """Turn agent memory into a chat-message list for a tool-calling model.

    The resulting list is built, in order, from: the optional system
    instruction, the parser's tool instruction, the configured few-shot
    examples, and finally the messages stored in memory.

    Args:
        environment_role: ``role`` value given to messages that come from
            neither the agent itself nor one of ``user_names``.
        environment_begin: Text placed before the content of every
            environment message.
        environment_end: Text placed after the content of every
            environment message.
        user_names: Sender names that are mapped to the ``user`` role.
            Defaults to ``['user']``.
        few_shot: Example conversations; each one is a list of message
            dicts with at least ``role`` and ``content`` keys (``tool``
            messages additionally need ``name``). Defaults to no examples.
    """

    def __init__(self,
                 environment_role='environment',
                 environment_begin='',
                 environment_end='',
                 user_names: Optional[List[str]] = None,
                 few_shot: Optional[List[List[dict]]] = None):
        self.environment_role = environment_role
        self.environment_begin = environment_begin
        self.environment_end = environment_end
        self.user_names = user_names or ['user']
        self.few_shot = few_shot or []

    def aggregate(self,
                  messages: Memory,
                  name: str,
                  parser: Union[ToolParser, MixedToolParser],
                  system_instruction: Optional[str] = None
                  ) -> List[Dict[str, str]]:
        """Build the message list to send to the model.

        Args:
            messages: Memory object; its ``get_memory()`` result is iterated
                and each item is read through ``sender``, ``content`` and
                ``formatted``.
            name: Sender name of the agent itself. Messages from this sender
                become ``assistant`` messages.
            parser: Provides ``format_instruction()`` for the tool prompt and
                ``format_response()`` to render thought/action dicts.
            system_instruction: Optional system prompt, expanded through
                ``aggregate_system_intruction`` when truthy.

        Returns:
            A list of message dicts, each with ``role`` and ``content`` and,
            for some system/environment messages, ``name``.

        Raises:
            KeyError: If a few-shot message has a role that is not handled
                here (including a ``tool`` message that does not directly
                follow a ``thought``/``language`` message).
        """
        _message = []
        messages = messages.get_memory()

        if system_instruction:
            _message.extend(
                self.aggregate_system_intruction(system_instruction))

        # The parser may return a plain string, a single message dict, or an
        # already-built list of message dicts; normalise all three to a list.
        tool_instruction = parser.format_instruction()
        if tool_instruction:
            if isinstance(tool_instruction, str):
                tool_instruction = dict(
                    role='system', content=tool_instruction)
                if parser.tool_type:
                    tool_instruction['name'] = parser.tool_type
            if isinstance(tool_instruction, dict):
                tool_instruction = [tool_instruction]
            _message.extend(tool_instruction)

        for shot in self.few_shot:
            i = 0
            while i < len(shot):
                msg = shot[i]
                role = msg['role']
                if role in ['assistant', 'user', 'system']:
                    _message.append(msg)
                elif role == self.environment_role:
                    content = msg['content']
                    if not content.startswith(self.environment_begin):
                        content = self.environment_begin + content
                    if not content.endswith(self.environment_end):
                        content += self.environment_end
                    # Emit a new dict instead of editing ``msg`` in place so
                    # the examples stored in ``self.few_shot`` are never
                    # modified by calling ``aggregate``.
                    _message.append({**msg, 'content': content})
                elif role in ['thought', 'language']:
                    has_tool_call = (i < len(shot) - 1
                                     and shot[i + 1]['role'] == 'tool')
                    if has_tool_call:
                        tool_msg = shot[i + 1]
                        response = dict(
                            tool_type=tool_msg['name'],
                            thought=msg['content'],
                            action=tool_msg['content'],
                            status=None)
                        # The following ``tool`` message has been merged into
                        # this assistant turn, so skip over it.
                        i += 1
                    else:
                        response = dict(
                            tool_type=None,
                            thought=msg['content'],
                            action=None,
                            status=None)
                    _message.append(
                        dict(
                            role='assistant',
                            content=parser.format_response(response)))
                else:
                    raise KeyError(f'Unkown role: {msg["role"]}')
                i += 1

        # Tool type of the most recent successfully parsed assistant turn;
        # attached as ``name`` to the environment messages that follow it.
        tool_type = None
        for message in messages:
            if message.sender == name:
                if isinstance(message.formatted, dict):
                    parsed = message.formatted
                    # Assistant turns that failed to parse are left out.
                    if parsed['status'] == ToolStatusCode.PARSING_ERROR:
                        continue
                    _message.append(
                        dict(
                            role='assistant',
                            content=parser.format_response(parsed)))
                    tool_type = parsed['tool_type']
                else:
                    _message.append(
                        dict(role='assistant', content=str(message.content)))
            elif message.sender in self.user_names:
                _message.append(dict(role='user', content=message.content))
            else:
                msg = dict(
                    role=self.environment_role,
                    content=self.environment_begin + str(message.content) +
                    self.environment_end)
                if tool_type:
                    msg['name'] = tool_type
                _message.append(msg)
        return _message
|
|