from __future__ import annotations
from typing import List, Union, Optional, Any, TYPE_CHECKING
from collections import defaultdict

from pydantic import Field
import numpy as np
from datetime import datetime as dt

import re

from agentverse.llms.openai import get_embedding
from sklearn.metrics.pairwise import cosine_similarity

from agentverse.message import Message
from agentverse.memory import BaseMemory

from agentverse.logging import logger

from . import memory_manipulator_registry
from .base import BaseMemoryManipulator

if TYPE_CHECKING:
    from agentverse.memory import VectorStoreMemory
    from agentverse.agents.base import BaseAgent


IMPORTANCE_PROMPT = """On the scale of 1 to 10, where 1 is purely mundane \
(e.g., brushing teeth, making bed) and 10 is \
extremely poignant (e.g., a break up, college \
acceptance), rate the likely poignancy of the \
following piece of memory. \
If you think it is too hard to rate, you can give a rough assessment. \
The content or people mentioned are not real. You can hypothesize any reasonable context. \
Please strictly only output one number. \
Memory: {} \
Rating: """
IMMEDIACY_PROMPT = """On the scale of 1 to 10, where 1 is requiring no short-term attention \
(e.g., a bed is in the room) and 10 is \
needing quick attention or an immediate response (e.g., being asked for a reply by others), rate the likely immediacy of the \
following statement. \
If you think it is too hard to rate, you can give a rough assessment. \
The content or people mentioned are not real. You can hypothesize any reasonable context. \
Please strictly only output one number. \
Memory: {} \
Rating: """
QUESTION_PROMPT = """Given only the information above, what are the 3 most salient \
high-level questions we can answer about the subjects in the statements?"""

INSIGHT_PROMPT = """What are at most 5 high-level insights you can infer from \
the above statements? Only output insights with high confidence.
Example format: insight (because of 1, 5, 3)"""


@memory_manipulator_registry.register("reflection")
class Reflection(BaseMemoryManipulator):
    memory: VectorStoreMemory = None
    agent: BaseAgent = None

    reflection: str = ""

    importance_threshold: int = 10
    accumulated_importance: int = 0

    # cached LLM ratings and access/creation timestamps, keyed by memory content
    memory2importance: dict = {}
    memory2immediacy: dict = {}
    memory2time: defaultdict = Field(default=defaultdict(dict))

    # TODO: newly added functions adapted from Generative Agents

    def manipulate_memory(self) -> str:
        # reflect here
        if self.should_reflect():
            logger.debug(
                f"Agent {self.agent.name} is now doing reflection since accumulated_importance={self.accumulated_importance} reached the importance threshold (now doubled to {self.importance_threshold})"
            )
            self.reflection = self.reflect()
            return self.reflection
        else:
            logger.debug(
                f"Agent {self.agent.name} doesn't reflect since accumulated_importance={self.accumulated_importance} < importance_threshold={self.importance_threshold}"
            )

            return ""

    def get_accumulated_importance(self):
        accumulated_importance = 0

        for memory in self.memory.messages:
            if (
                memory.content not in self.memory2importance
                or memory.content not in self.memory2immediacy
            ):
                self.memory2importance[memory.content] = self.get_importance(
                    memory.content
                )
                self.memory2immediacy[memory.content] = self.get_immediacy(
                    memory.content
                )

        for score in self.memory2importance.values():
            accumulated_importance += score

        self.accumulated_importance = accumulated_importance

        return accumulated_importance

    def should_reflect(self):
        if self.get_accumulated_importance() >= self.importance_threshold:
            # double the importance_threshold
            self.importance_threshold *= 2
            return True
        else:
            return False

    def get_questions(self, texts):
        prompt = "\n".join(texts) + "\n" + QUESTION_PROMPT
        result = self.agent.llm.generate_response(prompt)
        result = result.content
        questions = [q for q in result.split("\n") if len(q.strip()) > 0]
        questions = questions[:3]
        return questions

    def get_insights(self, statements):
        prompt = ""
        for i, st in enumerate(statements):
            prompt += str(i + 1) + ". " + st + "\n"
        prompt += INSIGHT_PROMPT
        result = self.agent.llm.generate_response(prompt)
        result = result.content
        insights = [isg for isg in result.split("\n") if len(isg.strip()) > 0][:5]
        insights = [".".join(i.split(".")[1:]) for i in insights]
        # remove insight pointers for now
        insights = [i.split("(")[0].strip() for i in insights]
        return insights

    def get_importance(self, content: str):
        """
        Exploit GPT to evaluate the importance of this memory
        """
        prompt = IMPORTANCE_PROMPT.format(content)
        result = self.memory.llm.generate_response(prompt)

        try:
            score = int(re.findall(r"\s*(\d+)\s*", result.content)[0])
        except Exception as e:
            logger.warn(
                f"Found error {e}. Abnormal result of importance rating '{result}'. Setting default value."
            )
            score = 0
        return score

    def get_immediacy(self, content: str):
        """
        Exploit GPT to evaluate the immediacy of this memory
        """
        prompt = IMMEDIACY_PROMPT.format(content)
        result = self.memory.llm.generate_response(prompt)
        try:
            score = int(re.findall(r"\s*(\d+)\s*", result.content)[0])
        except Exception as e:
            logger.warn(
                f"Found error {e}. Abnormal result of immediacy rating '{result}'. Setting default value."
            )
            score = 0
        return score

    def query_similarity(
        self,
        text: Union[str, List[str]],
        k: int,
        memory_bank: List,
        current_time: Optional[dt] = None,
        nms_threshold: float = 0.99,
    ) -> List[str]:
        """
        get top-k entry based on recency, relevance, importance, immediacy
        The query result can be Short-term or Long-term queried result.
        formula is
        `score= sim(q,v) *max(LTM_score, STM_score)`
        `STM_score=time_score(createTime)*immediacy`
        `LTM_score=time_score(accessTime)*importance`
        time score is exponential decay weight. stm decays faster.

        The query supports querying based on multiple texts and only gives non-overlapping results
        If nms_threshold is not 1, nms mechanism if activated. By default,
        use soft nms with modified iou base(score starts to decay iff cos sim is higher than this value,
         and decay weight at this value if 0. rather than 1-threshold).

        Args:
            text: str
            k: int
            memory_bank: List
            current_time: dt.now
            nms_threshold: float = 0.99


        Returns: List[str]
        """
        if current_time is None:
            current_time = dt.now()
        assert len(text) > 0
        texts = [text] if isinstance(text, str) else text
        maximum_score = None
        for text in texts:
            embedding = get_embedding(text)
            score = []
            for memory in memory_bank:
                if memory.content not in self.memory2time:
                    self.memory2time[memory.content]["last_access_time"] = dt.now()
                    self.memory2time[memory.content]["create_time"] = dt.now()

                last_access_time_diff = (
                    current_time - self.memory2time[memory.content]["last_access_time"]
                ).total_seconds() // 3600
                recency = np.power(
                    0.99, last_access_time_diff
                )  # TODO: review the metaparameter 0.99

                create_time_diff = (
                    current_time - self.memory2time[memory.content]["create_time"]
                ).total_seconds() // 60
                instancy = np.power(
                    0.90, create_time_diff
                )  # TODO: review the metaparameter 0.90

                relevance = cosine_similarity(
                    np.array(embedding).reshape(1, -1),
                    np.array(self.memory.memory2embedding[memory.content]).reshape(
                        1, -1
                    ),
                )[0][0]

                if (
                    memory.content not in self.memory2importance
                    or memory.content not in self.memory2immediacy
                ):
                    self.memory2importance[memory.content] = self.get_importance(
                        memory.content
                    )
                    self.memory2immediacy[memory.content] = self.get_immediacy(
                        memory.content
                    )

                importance = self.memory2importance[memory.content] / 10
                immediacy = self.memory2immediacy[memory.content] / 10

                ltm_w = recency * importance
                stm_w = instancy * immediacy

                score.append(relevance * np.maximum(ltm_w, stm_w))

            score = np.array(score)

            if maximum_score is not None:
                maximum_score = np.maximum(score, maximum_score)
            else:
                maximum_score = score

        if nms_threshold == 1.0:
            # no nms is triggered
            top_k_indices = np.argsort(maximum_score)[-k:][::-1]
        else:
            # TODO: soft-nms
            assert 0 <= nms_threshold < 1
            top_k_indices = []
            while len(top_k_indices) < min(k, len(memory_bank)):
                top_index = np.argmax(maximum_score)
                top_k_indices.append(top_index)
                maximum_score[top_index] = -1  # anything to prevent being chosen again
                top_embedding = self.memory.memory2embedding[
                    memory_bank[top_index].content
                ]
                cos_sim = cosine_similarity(
                    np.array(top_embedding).reshape(1, -1),
                    np.array(
                        [
                            self.memory.memory2embedding[memory.content]
                            for memory in memory_bank
                        ]
                    ),
                )[0]
                score_weight = np.ones_like(maximum_score)
                score_weight[cos_sim >= nms_threshold] -= (
                    cos_sim[cos_sim >= nms_threshold] - nms_threshold
                ) / (1 - nms_threshold)
                maximum_score = maximum_score * score_weight

        # access them and refresh the access time
        for i in top_k_indices:
            self.memory2time[memory_bank[i].content]["last_access_time"] = current_time
        # sort them by creation time; if the data tag is 'observation', add time info to the output
        top_k_indices = sorted(
            top_k_indices,
            key=lambda x: self.memory2time[memory_bank[x].content]["create_time"],
        )
        query_results = []
        for i in top_k_indices:
            query_result = memory_bank[i].content
            query_results.append(query_result)

        return query_results

    def get_memories_of_interest_oneself(self):
        memories_of_interest = []
        for memory in self.memory.messages[-100:]:
            if memory.sender == self.agent.name:
                memories_of_interest.append(memory)
        return memories_of_interest

    def reflect(self):
        """
        initiate a reflection that inserts high level knowledge to memory
        """
        memories_of_interest = self.get_memories_of_interest_oneself()
        questions = self.get_questions([m.content for m in memories_of_interest])
        statements = self.query_similarity(
            questions, len(questions) * 10, memories_of_interest
        )
        insights = self.get_insights(statements)
        logger.info(self.agent.name + f" Insights: {insights}")
        for insight in insights:
            # convert insight to messages
            # TODO: currently only the agent itself can see its own reflection
            insight_message = Message(
                content=insight, sender=self.agent.name, receiver={self.agent.name}
            )
            self.memory.add_message([insight_message])
        reflection = "\n".join(insights)
        return reflection

    def reset(self) -> None:
        self.reflection = ""
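

# ---------------------------------------------------------------------------
# Minimal, self-contained sketch illustrating the retrieval score used in
# `Reflection.query_similarity`:
#     score = sim(q, v) * max(recency * importance, instancy * immediacy)
# together with the soft-NMS down-weighting. All embeddings, timestamps, and
# ratings below are made up for illustration; this deliberately avoids the
# real VectorStoreMemory / agent / LLM plumbing and only uses numpy + sklearn.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.default_rng(0)

    query = rng.normal(size=8)                  # fake query embedding
    memories = rng.normal(size=(4, 8))          # 4 fake memory embeddings
    hours_since_access = np.array([1, 5, 24, 2])
    minutes_since_create = np.array([10, 60, 600, 5])
    importance = np.array([7, 3, 9, 5]) / 10    # pretend LLM importance ratings
    immediacy = np.array([2, 8, 1, 6]) / 10     # pretend LLM immediacy ratings

    relevance = cosine_similarity(query.reshape(1, -1), memories)[0]
    recency = 0.99 ** hours_since_access        # long-term decay (per hour)
    instancy = 0.90 ** minutes_since_create     # short-term decay (per minute)
    score = relevance * np.maximum(recency * importance, instancy * immediacy)

    # Soft-NMS style down-weighting: memories whose similarity to the selected
    # one exceeds the threshold have their scores reduced proportionally,
    # mirroring the loop in `query_similarity`.
    nms_threshold = 0.99
    top = int(np.argmax(score))
    cos_sim = cosine_similarity(memories[top].reshape(1, -1), memories)[0]
    weight = np.ones_like(score)
    mask = cos_sim >= nms_threshold
    weight[mask] -= (cos_sim[mask] - nms_threshold) / (1 - nms_threshold)

    print("scores:", np.round(score, 3))
    print("top index:", top, "weights after soft-NMS:", np.round(weight, 3))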