File size: 4,271 Bytes
43c34cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
from typing import List

from tenacity import retry, stop_after_attempt, wait_random_exponential

from ..message import Message
from .base import IntelligenceBackend

# The Cohere backend is optional: it is only usable when the ``cohere`` package
# can be imported AND the ``COHEREAI_API_KEY`` environment variable is set.
try:
    import cohere
except ImportError:
    is_cohere_available = False
else:
    is_cohere_available = os.environ.get("COHEREAI_API_KEY") is not None

# Default sampling config, taken from the [Cohere documentation](https://cohere-sdk.readthedocs.io/en/latest/cohere.html#cohere.client.Client.chat)
DEFAULT_TEMPERATURE = 0.8
DEFAULT_MAX_TOKENS = 200
DEFAULT_MODEL = "command-xlarge"


class CohereAIChat(IntelligenceBackend):
    """Interface to the Cohere chat API.

    The backend is stateful: it remembers the Cohere session id and the hash of
    the last history message it forwarded, so each ``query`` only sends the
    messages that arrived since the previous successful call. ``reset`` clears
    both so the next call starts a new conversation.
    """

    stateful = True
    type_name = "cohere-chat"

    def __init__(
        self,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        model: str = DEFAULT_MODEL,
        **kwargs,
    ):
        """Store the sampling config and create the Cohere client.

        args:
            temperature: sampling temperature forwarded to ``client.chat``
            max_tokens: token limit forwarded to ``client.chat``
            model: model name stored on the instance
            **kwargs: forwarded to ``IntelligenceBackend.__init__``

        raises:
            AssertionError: if the ``cohere`` package is not installed or the
                ``COHEREAI_API_KEY`` environment variable is not set
        """
        super().__init__(
            temperature=temperature, max_tokens=max_tokens, model=model, **kwargs
        )

        self.temperature = temperature
        self.max_tokens = max_tokens
        # NOTE(review): `self.model` is stored but not currently passed to
        # `client.chat`, so the API's default model is used — confirm intent.
        self.model = model

        assert (
            is_cohere_available
        ), "Cohere package is not installed or the API key is not set"
        self.client = cohere.Client(os.environ.get("COHEREAI_API_KEY"))

        # Stateful variables
        # Session id returned by the last successful Cohere call (None = new session).
        self.session_id = None
        # Hash of the last history message forwarded in the previous call.
        self.last_msg_hash = None

    def reset(self):
        """Forget the current Cohere session and the last-forwarded message."""
        self.session_id = None
        self.last_msg_hash = None

    @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
    def _get_response(self, new_message: str, persona_prompt: str):
        """Send one chat turn to Cohere and return the reply text.

        The call is wrapped in a tenacity retry (at most 6 attempts, randomized
        exponential backoff between 1 and 60 seconds). ``self.session_id`` is
        sent so Cohere continues the same conversation, and it is replaced by
        the session id from the response only once a call succeeds.
        """
        # NOTE(review): `persona_prompt` / `session_id` are the keyword names of
        # the cohere SDK version this was written against; newer SDK versions
        # renamed them — confirm against the pinned cohere version.
        response = self.client.chat(
            new_message,
            persona_prompt=persona_prompt,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            session_id=self.session_id,
        )

        self.session_id = response.session_id  # Update the session id
        return response.reply

    def query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: str = None,
        request_msg: Message = None,
        *args,
        **kwargs,
    ) -> str:
        """
        Format the input and call the Cohere API.

        Only history messages after the one forwarded in the previous successful
        call are sent. Earlier turns are expected to be held by the Cohere
        session tracked in ``self.session_id``.

        args:
            agent_name: the name of the agent
            role_desc: the description of the role of the agent
            history_messages: the history of the conversation, or the observation for the agent
            global_prompt: the description of the environment; omitted from the
                persona prompt when None or empty
            request_msg: an extra request appended after the new history messages

        returns:
            the reply text from the Cohere API

        raises:
            AssertionError: if there are no new history messages and no request_msg
        """
        # Find where the new messages start: just after the last message sent in
        # the previous call. If it is not found (or on the first call) the whole
        # history is treated as new.
        new_message_start_idx = 0
        if self.last_msg_hash is not None:
            for i, message in enumerate(history_messages):
                if message.msg_hash == self.last_msg_hash:
                    new_message_start_idx = i + 1
                    break

        new_messages = history_messages[new_message_start_idx:]
        # An empty slice is fine as long as there is a request to send (e.g. the
        # first turn before any history exists); otherwise there is nothing to send.
        assert new_messages or request_msg, "No new messages found (this should not happen)"

        # Skip the agent's own messages (presumably already part of the Cohere
        # session) and tag each line with its speaker, since there may be
        # more than one other player.
        new_conversations = [
            f"[{message.agent_name}]: {message.content}"
            for message in new_messages
            if message.agent_name != agent_name
        ]
        if request_msg:
            new_conversations.append(
                f"[{request_msg.agent_name}]: {request_msg.content}"
            )

        # Concatenate all new messages into one message because the Cohere API only accepts one message
        new_message = "\n".join(new_conversations)

        persona_prompt = f"Your role:\n{role_desc}"
        if global_prompt:
            # Formatting a None global_prompt directly would put the literal
            # text "None" into the prompt, so the section is added only if set.
            persona_prompt = f"Environment:\n{global_prompt}\n\n{persona_prompt}"

        response = self._get_response(new_message, persona_prompt)

        # Only advance the bookmark after a successful API call, and only when
        # there were new history messages to advance past.
        if new_messages:
            self.last_msg_hash = new_messages[-1].msg_hash

        return response