Spaces:
Sleeping
Sleeping
File size: 4,271 Bytes
43c34cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import os
from typing import List, Optional

from tenacity import retry, stop_after_attempt, wait_random_exponential

from ..message import Message
from .base import IntelligenceBackend
# The Cohere backend is usable only when the `cohere` package imports cleanly
# AND the COHEREAI_API_KEY environment variable is present.
try:
    import cohere
except ImportError:
    is_cohere_available = False
else:
    is_cohere_available = "COHEREAI_API_KEY" in os.environ

# Default generation settings, following the Cohere SDK `Client.chat` docs:
# https://cohere-sdk.readthedocs.io/en/latest/cohere.html#cohere.client.Client.chat
DEFAULT_TEMPERATURE = 0.8
DEFAULT_MAX_TOKENS = 200
DEFAULT_MODEL = "command-xlarge"
class CohereAIChat(IntelligenceBackend):
    """Interface to the Cohere chat API.

    This backend is stateful (``stateful = True``). Between calls it keeps:

    - the ``session_id`` returned by the last successful Cohere call;
    - the ``msg_hash`` of the last history message it has already forwarded.

    As a result, each :meth:`query` only sends messages that appeared after
    that point.
    """

    stateful = True
    type_name = "cohere-chat"

    def __init__(
        self,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        model: str = DEFAULT_MODEL,
        **kwargs,
    ):
        """Configure the backend and create the Cohere client.

        Args:
            temperature: sampling temperature forwarded to ``client.chat``.
            max_tokens: token limit forwarded to ``client.chat``.
            model: model name, stored on the instance and passed to the base
                class. NOTE(review): ``_get_response`` does not currently pass
                it to ``client.chat``. Confirm whether the installed SDK
                accepts a ``model`` argument and forward it if so.
            **kwargs: forwarded unchanged to ``IntelligenceBackend.__init__``.

        Raises:
            AssertionError: if the ``cohere`` package could not be imported or
                the ``COHEREAI_API_KEY`` environment variable is not set.
        """
        super().__init__(
            temperature=temperature, max_tokens=max_tokens, model=model, **kwargs
        )
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.model = model

        assert (
            is_cohere_available
        ), "Cohere package is not installed or the API key is not set"
        self.client = cohere.Client(os.environ.get("COHEREAI_API_KEY"))

        # Conversation state carried between calls to `query`.
        self.session_id = None  # session id returned by the last successful call
        self.last_msg_hash = None  # msg_hash of the last history message already sent

    def reset(self):
        """Clear the conversation state.

        The next query is sent without a session id and forwards the whole
        history.
        """
        self.session_id = None
        self.last_msg_hash = None

    @retry(stop=stop_after_attempt(6), wait=wait_random_exponential(min=1, max=60))
    def _get_response(self, new_message: str, persona_prompt: str):
        """Send one chat turn to Cohere and return ``response.reply``.

        tenacity retries the call up to 6 attempts in total, with randomized
        exponential back-off capped at 60 s. On success, the session id from
        the response is stored so the next call continues the same session.
        """
        response = self.client.chat(
            new_message,
            persona_prompt=persona_prompt,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            session_id=self.session_id,
        )
        self.session_id = response.session_id  # continue this session next time
        return response.reply

    def query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: Optional[str] = None,
        request_msg: Optional[Message] = None,
        *args,
        **kwargs,
    ) -> str:
        """Format the not-yet-sent part of the conversation and call Cohere.

        Args:
            agent_name: name of the agent this backend speaks for. Messages
                authored by this agent are not re-sent, presumably because the
                Cohere session already holds its own replies.
            role_desc: description of the agent's role; placed in the persona
                prompt.
            history_messages: the conversation history, or the observation for
                the agent.
            global_prompt: optional description of the environment. When given
                (non-empty), it is placed in the persona prompt before the role
                description.
            request_msg: optional extra message appended after the new history
                messages.
            *args, **kwargs: accepted for interface compatibility and ignored.

        Returns:
            The reply returned by the Cohere chat call.

        Raises:
            AssertionError: if no history message is newer than the last one
                already sent.
        """
        # Find where the previous call stopped; everything after it is new.
        # If the remembered hash is not found, the whole history is treated as new.
        new_message_start_idx = 0
        if self.last_msg_hash is not None:
            for i, message in enumerate(history_messages):
                if message.msg_hash == self.last_msg_hash:
                    new_message_start_idx = i + 1
                    break
        new_messages = history_messages[new_message_start_idx:]
        assert len(new_messages) > 0, "No new messages found (this should not happen)"

        # There can be several players, so prefix each line with its speaker.
        new_conversations = [
            f"[{message.agent_name}]: {message.content}"
            for message in new_messages
            if message.agent_name != agent_name
        ]
        if request_msg:
            new_conversations.append(
                f"[{request_msg.agent_name}]: {request_msg.content}"
            )
        # The Cohere chat call takes a single message, so join all new lines.
        new_message = "\n".join(new_conversations)

        # Omit the environment section when no global prompt is given, rather
        # than rendering the literal text "None" into the persona prompt.
        if global_prompt:
            persona_prompt = f"Environment:\n{global_prompt}\n\nYour role:\n{role_desc}"
        else:
            persona_prompt = f"Your role:\n{role_desc}"

        response = self._get_response(new_message, persona_prompt)

        # Advance the cursor only after a successful API call, so that a failed
        # call re-sends the same messages next time.
        self.last_msg_hash = new_messages[-1].msg_hash
        return response
|