File size: 1,928 Bytes
01523b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from typing import List, Union

from pydantic import Field

from agentverse.message import Message
from agentverse.llms import BaseLLM
from agentverse.llms.openai import get_embedding, OpenAIChat


from . import memory_registry
from .base import BaseMemory



@memory_registry.register("vectorstore")
class VectorStoreMemory(BaseMemory):
    """Message memory that also keeps an embedding for each message's content.

    Unlike a plain chat-history memory, this class additionally maintains a
    two-way mapping between each message's ``content`` string and the
    embedding that ``get_embedding`` returns for that content.

    Attributes:
        messages (List[Message]): Messages in the order they were added.
        embedding2memory (dict): Maps an embedding (used as a hashable key)
            to the ``content`` string it was computed from.
        memory2embedding (dict): Maps a ``content`` string to the embedding
            returned by ``get_embedding`` for it.
        llm (BaseLLM): Model handle stored on the memory. It is not referenced
            by the methods of this class; embeddings come from
            ``get_embedding``.

    Methods:
        add_message: Append messages and record an embedding for each.
        to_string: Render the stored messages as newline-separated text.
        reset: Drop all stored messages together with their embeddings.
    """

    # default_factory builds fresh containers for every instance instead of
    # relying on a shared mutable default object.
    messages: List[Message] = Field(default_factory=list)
    embedding2memory: dict = Field(default_factory=dict)
    memory2embedding: dict = Field(default_factory=dict)
    # NOTE(review): this default is constructed once, when the class body is
    # evaluated at import time.
    llm: BaseLLM = OpenAIChat(model="gpt-4")

    def add_message(self, messages: List[Message]) -> None:
        """Append ``messages`` and index each one's content by its embedding.

        Args:
            messages: Messages to store, in order.
        """
        for message in messages:
            # Compute the embedding before mutating any state, so a failure
            # in get_embedding does not leave ``messages`` and the two maps
            # out of sync with each other.
            memory_embedding = get_embedding(message.content)

            # The embedding is used as a dict key, so it must be hashable.
            # NOTE(review): get_embedding's return type is not visible here;
            # if it is a list/array, fall back to an equivalent tuple key.
            key = memory_embedding
            try:
                hash(key)
            except TypeError:
                key = tuple(key)

            self.messages.append(message)
            self.embedding2memory[key] = message.content
            self.memory2embedding[message.content] = memory_embedding

    def to_string(self, add_sender_prefix: bool = False) -> str:
        """Return the stored messages' contents joined by newlines.

        Args:
            add_sender_prefix: When true, prefix each message with
                ``[sender]: `` unless its ``sender`` is the empty string.

        Returns:
            The newline-joined text of all stored messages.
        """
        if add_sender_prefix:
            return "\n".join(
                f"[{message.sender}]: {message.content}"
                if message.sender != ""
                else message.content
                for message in self.messages
            )
        return "\n".join(message.content for message in self.messages)

    def reset(self) -> None:
        """Drop all stored messages together with their embedding maps."""
        self.messages = []
        # Clear the embedding maps as well; otherwise embeddings of the
        # messages removed here would survive the reset.
        self.embedding2memory = {}
        self.memory2embedding = {}