from typing import Dict, List, Optional, Union

from lagent.agents.aggregator.default_aggregator import DefaultAggregator
from lagent.memory.base_memory import Memory
from lagent.prompts.parsers.tool_parser import MixedToolParser, ToolParser, ToolStatusCode


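class InternLMToolAggregator(DefaultAggregator):
    """Aggregate agent memory into chat messages for tool-calling prompts.

    Tool outputs are emitted under ``environment_role``, optionally wrapped in
    ``environment_begin`` and ``environment_end``.
    """
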
    def __init__(self,
                 environment_role='environment',
                 environment_begin='',
                 environment_end='',
                 user_names: Optional[List[str]] = None,
                 few_shot: Optional[List[List[dict]]] = None):
        self.environment_role = environment_role
        self.environment_begin = environment_begin
        self.environment_end = environment_end
        self.user_names = user_names or ['user']
        self.few_shot = few_shot or []

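    def aggregate(self,
                  messages: Memory,
                  name: str,
                  parser: Union[ToolParser, MixedToolParser],
                  system_instruction: Optional[str] = None) -> List[Dict[str, str]]:
        """Build the chat-format message list for the agent called ``name``.

        The result holds, in order: the system instruction, the parser's tool
        instruction, the few-shot examples and the conversation in ``messages``.
        """
        _message = []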
        messages = messages.get_memory()
        if system_instruction:
            _message.extend(
                self.aggregate_system_intruction(system_instruction))
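        # The parser may return a str, a dict or a list of dicts; normalize
        # whatever it returns to a list of message dicts before appending.
        tool_instruction = parser.format_instruction()
        if tool_instruction: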
            if isinstance(tool_instruction, str):
                tool_instruction = dict(
                    role='system', content=tool_instruction)
                if parser.tool_type:
                    tool_instruction['name'] = parser.tool_type
            if isinstance(tool_instruction, dict):
                tool_instruction = [tool_instruction]
            _message.extend(tool_instruction)

        for shot in self.few_shot:
            i = 0
            while i < len(shot):
                msg = shot[i]
                if msg['role'] in ['assistant', 'user', 'system']:
                    _message.append(msg)
                elif msg['role'] == self.environment_role:
                    # Wrap the content in the environment markers on a copy,
                    # so the caller's few-shot examples are not mutated.
                    content = msg['content']
                    if not content.startswith(self.environment_begin):
                        content = self.environment_begin + content
                    if not content.endswith(self.environment_end):
                        content += self.environment_end
                    _message.append(dict(msg, content=content))
                elif msg['role'] in ['thought', 'language']:
                    if i < len(shot) - 1 and shot[i + 1]['role'] == 'tool':
                        _message.append(
                            dict(
                                role='assistant',
                                content=parser.format_response(
                                    dict(
                                        tool_type=shot[i + 1]['name'],
                                        thought=msg['content'],
                                        action=shot[i + 1]['content'],
                                        status=None))))
                        i += 1
                    else:
                        _message.append(
                            dict(
                                role='assistant',
                                content=parser.format_response(
                                    dict(
                                        tool_type=None,
                                        thought=msg['content'],
                                        action=None,
                                        status=None))))
                else:
                    raise KeyError(f'Unknown role: {msg["role"]}')
                i += 1

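        # Remember the tool type of the latest assistant call so the
        # environment reply that follows can be tagged with it.
        tool_type = None
        for message in messages:
            if message.sender == name:
                if isinstance(message.formatted, dict):
                    parsed = message.formatted
                    # Drop assistant turns whose tool call failed to parse.
                    if parsed['status'] == ToolStatusCode.PARSING_ERROR:
                        continue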
                    _message.append(
                        dict(
                            role='assistant',
                            content=parser.format_response(parsed)))
                    tool_type = parsed['tool_type']
                else:
                    _message.append(
                        dict(role='assistant', content=str(message.content)))
            elif message.sender in self.user_names:
                _message.append(dict(role='user', content=message.content))
            else:
                msg = dict(
                    role=self.environment_role,
                    content=self.environment_begin + str(message.content) +
                    self.environment_end)
                if tool_type:
                    msg['name'] = tool_type
                _message.append(msg)
        return _message
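

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the lagent API. The few-shot loop in
# ``aggregate`` folds a 'thought'/'language' message and the 'tool' message
# right after it into a single assistant turn. The helper below reproduces
# only that folding step on plain dicts so it can be read in isolation.
# Assumptions: the names ``_demo_format_response``, ``_fold_thought_and_tool``
# and ``_demo`` are hypothetical, the stand-in formatter's output format is
# made up (a real ToolParser may render responses differently), and the role
# validation and environment wrapping done by ``aggregate`` are left out.
# ---------------------------------------------------------------------------
def _demo_format_response(parsed: dict) -> str:
    """Hypothetical stand-in for ``parser.format_response``."""
    if parsed['action'] is None:
        return parsed['thought']
    return f"{parsed['thought']} [{parsed['tool_type']}] {parsed['action']}"


def _fold_thought_and_tool(shot: list) -> list:
    """Merge each thought with a directly following tool call."""
    folded = []
    i = 0
    while i < len(shot):
        msg = shot[i]
        if msg['role'] in ('thought', 'language'):
            has_tool = i + 1 < len(shot) and shot[i + 1]['role'] == 'tool'
            nxt = shot[i + 1] if has_tool else None
            parsed = dict(
                tool_type=nxt['name'] if nxt else None,
                thought=msg['content'],
                action=nxt['content'] if nxt else None,
                status=None)
            folded.append(
                dict(role='assistant', content=_demo_format_response(parsed)))
            if has_tool:
                i += 1  # the tool message was consumed with the thought
        else:
            folded.append(dict(msg))  # copy so the input shot is not mutated
        i += 1
    return folded


def _demo() -> list:
    """Fold a small hand-written shot; the example data is invented."""
    shot = [
        dict(role='user', content='What is 2 + 3?'),
        dict(role='thought', content='Use the calculator.'),
        dict(role='tool', name='interpreter', content='print(2 + 3)'),
        dict(role='thought', content='The answer is 5.'),
    ]
    return _fold_thought_and_tool(shot)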