import os
from typing import List, Optional

from tenacity import retry, stop_after_attempt, wait_random_exponential

from ..message import Message
from .base import IntelligenceBackend

# Try to import the cohere package and check whether the API key is set.
# The backend is only usable when both conditions hold.
try:
    import cohere
except ImportError:
    is_cohere_available = False
else:
    if os.environ.get("COHEREAI_API_KEY") is None:
        is_cohere_available = False
    else:
        is_cohere_available = True

# Default config follows the [Cohere documentation](https://cohere-sdk.readthedocs.io/en/latest/cohere.html#cohere.client.Client.chat)
DEFAULT_TEMPERATURE = 0.8
DEFAULT_MAX_TOKENS = 200
DEFAULT_MODEL = "command-xlarge"


class CohereAIChat(IntelligenceBackend):
    """Interface to the Cohere chat API.

    This backend is stateful. It passes the session id returned by the previous
    call back to Cohere. Each call to :meth:`query` sends only the messages that
    arrived after the last successfully sent one. Presumably the server retains
    the earlier turns of the session, so confirm against the Cohere docs for the
    SDK version in use.
    """

    stateful = True
    type_name = "cohere-chat"

    def __init__(
        self,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        model: str = DEFAULT_MODEL,
        **kwargs,
    ):
        """Create a Cohere chat backend.

        Args:
            temperature: sampling temperature forwarded to ``Client.chat``.
            max_tokens: maximum number of tokens per reply, forwarded to ``Client.chat``.
            model: name of the Cohere model to query, forwarded to ``Client.chat``.
            **kwargs: extra configuration forwarded to ``IntelligenceBackend.__init__``.

        Raises:
            AssertionError: if the ``cohere`` package is not installed or the
                ``COHEREAI_API_KEY`` environment variable is not set.
        """
        super().__init__(
            temperature=temperature, max_tokens=max_tokens, model=model, **kwargs
        )

        self.temperature = temperature
        self.max_tokens = max_tokens
        self.model = model

        assert (
            is_cohere_available
        ), "Cohere package is not installed or the API key is not set"
        self.client = cohere.Client(os.environ.get("COHEREAI_API_KEY"))

        # Stateful variables
        self.session_id = None  # The session id for the last conversation
        # The hash of the last message sent in the last conversation
        self.last_msg_hash = None

    def reset(self):
        """Drop the session id and message bookmark.

        The next :meth:`query` then starts a new session and sends the full
        history.
        """
        self.session_id = None
        self.last_msg_hash = None

    @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
    def _get_response(self, new_message: str, persona_prompt: str) -> str:
        """Send one message to Cohere's chat endpoint and return the reply text.

        Any exception triggers a retry, with at most 6 attempts and randomized
        exponential back-off between 1 and 60 seconds. On success, the session id
        from the response is stored so the next call continues the same session.
        """
        response = self.client.chat(
            new_message,
            persona_prompt=persona_prompt,
            # Previously omitted, which silently ignored the configured model.
            model=self.model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            session_id=self.session_id,
        )
        self.session_id = response.session_id  # Update the session id
        return response.reply

    def query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: Optional[str] = None,
        request_msg: Optional[Message] = None,
        *args,
        **kwargs,
    ) -> str:
        """Format the new part of the conversation and call the Cohere API.

        Args:
            agent_name: the name of the agent. Its own messages are not re-sent.
            role_desc: the description of the role of the agent.
            history_messages: the history of the conversation, or the observation
                for the agent.
            global_prompt: the description of the environment. The "Environment"
                section is left out of the persona prompt when this is None or empty.
            request_msg: an optional request appended after the new messages.

        Returns:
            The reply text returned by Cohere.

        Raises:
            AssertionError: if no history message follows the one sent last time.
        """
        # Find the index just past the last message sent in the previous call.
        new_message_start_idx = 0
        if self.last_msg_hash is not None:
            for i, message in enumerate(history_messages):
                if message.msg_hash == self.last_msg_hash:
                    new_message_start_idx = i + 1
                    break
            # NOTE(review): if the hash is not found (e.g. the history was replaced
            # without calling reset()), the whole history is sent again.

        new_messages = history_messages[new_message_start_idx:]
        assert len(new_messages) > 0, "No new messages found (this should not happen)"

        # Since there may be more than one player, prefix each message with its
        # sender. The agent's own messages are skipped.
        new_conversations = [
            f"[{message.agent_name}]: {message.content}"
            for message in new_messages
            if message.agent_name != agent_name
        ]
        if request_msg:
            new_conversations.append(
                f"[{request_msg.agent_name}]: {request_msg.content}"
            )

        # Concatenate all new messages into one, because the Cohere API only
        # accepts a single message per call.
        new_message = "\n".join(new_conversations)

        # Only include the environment section when there is an environment to
        # describe. Otherwise the prompt would contain the literal text "None".
        persona_prompt = f"Your role:\n{role_desc}"
        if global_prompt:
            persona_prompt = f"Environment:\n{global_prompt}\n\n{persona_prompt}"

        response = self._get_response(new_message, persona_prompt)

        # Only update the last message hash if the API call is successful.
        self.last_msg_hash = new_messages[-1].msg_hash

        return response