# Source: commit 43c34cc by USTC975 ("create the app")
import os
from typing import List
from tenacity import retry, stop_after_attempt, wait_random_exponential
from ..message import Message
from .base import IntelligenceBackend
# The Cohere backend is usable only when the optional `cohere` package imports
# cleanly AND an API key is exported in the environment; record both in one flag.
try:
    import cohere
except ImportError:
    is_cohere_available = False
else:
    is_cohere_available = os.environ.get("COHEREAI_API_KEY") is not None

# Default generation settings, following the Cohere SDK `Client.chat` docs:
# https://cohere-sdk.readthedocs.io/en/latest/cohere.html#cohere.client.Client.chat
DEFAULT_TEMPERATURE = 0.8
DEFAULT_MAX_TOKENS = 200
DEFAULT_MODEL = "command-xlarge"
class CohereAIChat(IntelligenceBackend):
    """Interface to the Cohere chat API.

    The backend is stateful: it keeps the ``session_id`` returned by Cohere and
    the hash of the last history message it has already forwarded, so each call
    to ``query`` only sends the messages that arrived since the previous call.
    """

    stateful = True
    type_name = "cohere-chat"

    def __init__(
        self,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        model: str = DEFAULT_MODEL,
        **kwargs,
    ):
        """Create a Cohere client from the ``COHEREAI_API_KEY`` environment variable.

        args:
            temperature: sampling temperature forwarded to ``client.chat``
            max_tokens: maximum reply length forwarded to ``client.chat``
            model: model name stored on the instance. NOTE(review): it is not
                passed to ``client.chat`` below, so it currently has no effect on
                the request -- confirm whether the SDK call should receive it.
            **kwargs: forwarded to ``IntelligenceBackend.__init__``
        raises:
            AssertionError: if the cohere package is missing or the API key is unset
        """
        super().__init__(
            temperature=temperature, max_tokens=max_tokens, model=model, **kwargs
        )
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.model = model

        assert (
            is_cohere_available
        ), "Cohere package is not installed or the API key is not set"
        self.client = cohere.Client(os.environ.get("COHEREAI_API_KEY"))

        # Stateful variables
        self.session_id = None  # The session id for the last conversation
        # The hash of the last history message already sent in this session
        self.last_msg_hash = None

    def reset(self):
        """Forget the Cohere session and the last forwarded message."""
        self.session_id = None
        self.last_msg_hash = None

    @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
    def _get_response(self, new_message: str, persona_prompt: str):
        """Send one message to Cohere chat and return the reply text.

        Retried up to 6 attempts with random exponential back-off (1-60 s).
        Stores the returned ``session_id`` so the next call continues the session.
        """
        response = self.client.chat(
            new_message,
            persona_prompt=persona_prompt,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            session_id=self.session_id,
        )
        self.session_id = response.session_id  # Update the session id
        return response.reply

    def query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: str = None,
        request_msg: Message = None,
        *args,
        **kwargs,
    ) -> str:
        """
        Format the input and call the Cohere API.

        Only history messages after the last one forwarded in a previous call are
        sent; messages authored by ``agent_name`` itself are skipped
        (presumably because the Cohere session already holds this agent's
        replies -- confirm).

        args:
            agent_name: the name of the agent
            role_desc: the description of the role of the agent
            history_messages: the history of the conversation, or the observation for the agent
            global_prompt: the description of the environment; the "Environment"
                section is omitted from the persona prompt when this is empty or None
            request_msg: the request for the CohereAI
        returns:
            the reply returned by the Cohere API
        raises:
            AssertionError: if there is neither a new history message nor a request to send
        """
        # Find the index just after the last message forwarded in the previous call.
        # If that message is no longer in the history, the whole history is re-sent.
        new_message_start_idx = 0
        if self.last_msg_hash is not None:
            for i, message in enumerate(history_messages):
                if message.msg_hash == self.last_msg_hash:
                    new_message_start_idx = i + 1
                    break
        new_messages = history_messages[new_message_start_idx:]
        # An empty history is acceptable as long as there is a request to send,
        # e.g. when the agent is asked to speak before any message exists.
        assert (
            new_messages or request_msg
        ), "No new messages found (this should not happen)"

        # Prefix every message with its speaker so the model can tell the players apart.
        new_conversations = [
            f"[{message.agent_name}]: {message.content}"
            for message in new_messages
            if message.agent_name != agent_name
        ]
        if request_msg:
            new_conversations.append(
                f"[{request_msg.agent_name}]: {request_msg.content}"
            )

        # Concatenate all new messages into one message because the Cohere API only accepts one message
        new_message = "\n".join(new_conversations)

        # Without an environment description the old template rendered a literal "None".
        if global_prompt:
            persona_prompt = f"Environment:\n{global_prompt}\n\nYour role:\n{role_desc}"
        else:
            persona_prompt = f"Your role:\n{role_desc}"

        response = self._get_response(new_message, persona_prompt)

        # Only update the last message hash if the API call is successful
        # (and there actually was a new history message to record).
        if new_messages:
            self.last_msg_hash = new_messages[-1].msg_hash
        return response