import os
from typing import List

from tenacity import retry, stop_after_attempt, wait_random_exponential

from ..message import Message
from .base import IntelligenceBackend

# Try to import the cohere package and check whether the API key is set
try:
    import cohere
except ImportError:
    is_cohere_available = False
else:
    if os.environ.get("COHEREAI_API_KEY") is None:
        is_cohere_available = False
    else:
        is_cohere_available = True

# Default config follows the [Cohere documentation](https://cohere-sdk.readthedocs.io/en/latest/cohere.html#cohere.client.Client.chat)
DEFAULT_TEMPERATURE = 0.8
DEFAULT_MAX_TOKENS = 200
DEFAULT_MODEL = "command-xlarge"
class CohereAIChat(IntelligenceBackend):
    """Interface to the Cohere API."""

    stateful = True
    type_name = "cohere-chat"

    def __init__(
        self,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        model: str = DEFAULT_MODEL,
        **kwargs,
    ):
        super().__init__(
            temperature=temperature, max_tokens=max_tokens, model=model, **kwargs
        )

        self.temperature = temperature
        self.max_tokens = max_tokens
        self.model = model

        assert (
            is_cohere_available
        ), "Cohere package is not installed or the API key is not set"
        self.client = cohere.Client(os.environ.get("COHEREAI_API_KEY"))

        # Stateful variables
        self.session_id = None  # The session id for the last conversation
        self.last_msg_hash = (
            None  # The hash of the last message of the last conversation
        )
    def reset(self):
        self.session_id = None
        self.last_msg_hash = None
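    # Retry transient Cohere API failures with exponential backoff. The tenacity
    # helpers are already imported above but were not applied anywhere; this
    # particular retry policy (6 attempts, 1-60s random exponential wait) is an
    # assumed, reasonable default rather than a documented requirement.
    @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))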
    def _get_response(self, new_message: str, persona_prompt: str):
        response = self.client.chat(
            new_message,
            persona_prompt=persona_prompt,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            session_id=self.session_id,
        )

        self.session_id = response.session_id  # Update the session id
        return response.reply
    def query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: str = None,
        request_msg: Message = None,
        *args,
        **kwargs,
    ) -> str:
        """
        Format the input and call the Cohere API.

        args:
            agent_name: the name of the agent
            role_desc: the description of the role of the agent
            global_prompt: the global prompt, i.e., the description of the environment
            history_messages: the history of the conversation, or the observation for the agent
            request_msg: the request for the CohereAI
        """
        # Find the index of the last message of the last conversation
        new_message_start_idx = 0
        if self.last_msg_hash is not None:
            for i, message in enumerate(history_messages):
                if message.msg_hash == self.last_msg_hash:
                    new_message_start_idx = i + 1
                    break

        new_messages = history_messages[new_message_start_idx:]
        assert len(new_messages) > 0, "No new messages found (this should not happen)"

        new_conversations = []
        for message in new_messages:
            if message.agent_name != agent_name:
                # Since there is more than one player, we need to distinguish between the players
                new_conversations.append(f"[{message.agent_name}]: {message.content}")

        if request_msg:
            new_conversations.append(
                f"[{request_msg.agent_name}]: {request_msg.content}"
            )
        # Concatenate all new messages into one message because the Cohere API only accepts one message
        new_message = "\n".join(new_conversations)
        persona_prompt = f"Environment:\n{global_prompt}\n\nYour role:\n{role_desc}"

        response = self._get_response(new_message, persona_prompt)

        # Only update the last message hash if the API call is successful
        self.last_msg_hash = new_messages[-1].msg_hash

        return response
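
# Example usage (illustrative sketch, not part of the original module). It assumes
# the package is importable as `chatarena` (adjust to the actual package name), that
# COHEREAI_API_KEY is set in the environment, and that Message accepts the fields shown.
#
#   from chatarena.backends.cohere import CohereAIChat
#   from chatarena.message import Message
#
#   backend = CohereAIChat(temperature=0.5, max_tokens=100)
#   history = [Message(agent_name="Alice", content="Hi Bob, how are you?", turn=1)]
#   reply = backend.query(
#       agent_name="Bob",
#       role_desc="You are Bob, a friendly conversational partner.",
#       history_messages=history,
#       global_prompt="Two agents are having a casual chat.",
#   )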