import json
import os
import re
import time
from pprint import pprint
from random import random
from typing import Any, List, Optional, Tuple, Union

from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.tool import ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda

from toolformers.base import StringParameter, Tool
from toolformers.sambanova.api_gateway import APIGateway
from toolformers.sambanova.utils import get_total_usage, usage_tracker
FUNCTION_CALLING_SYSTEM_PROMPT = """You have access to the following tools:
{tools}
You can call one or more tools by adding a <ToolCalls> section to your message. For example:
<ToolCalls>
```json
[{{
    "tool": <name of the selected tool>,
    "tool_input": <parameters for the selected tool, matching the tool's JSON schema>
}}]
```
</ToolCalls>
Note that you can select multiple tools at once by adding more objects to the list. Do not add \
multiple <ToolCalls> sections to the same message.
You will see the results of the tool invocations in the response.
Think step by step.
Do not call a tool if its input depends on another tool's output that you do not have yet.
Do not try to answer until you have all the tools' outputs; if you do not have an answer yet, you can continue calling tools until you do.
Your answer should be in the same language as the initial query.
"""  # noqa E501

conversational_response = Tool(
    name='ConversationalResponse',
    description='Respond conversationally only if no other tools should be called for a given query, or if you have a final answer. Response must be in the same language as the user query.',
    parameters=[StringParameter(name='response', description='Conversational response to the user. Must be in the same language as the user query.', required=True)],
    function=None,
)

class FunctionCallingLlm:
    """
    Prompt-based function-calling wrapper around a SambaNova LLM.
    """

    def __init__(
        self,
        tools: Optional[Union[Tool, List[Tool]]] = None,
        default_tool: Optional[Tool] = None,
        system_prompt: Optional[str] = None,
        prod_mode: bool = False,
        api: str = 'sncloud',
        coe: bool = False,
        do_sample: bool = False,
        max_tokens_to_generate: Optional[int] = None,
        temperature: float = 0.2,
        select_expert: Optional[str] = None,
    ) -> None:
        """
        Args:
            tools (Optional[Union[Tool, List[Tool]]]): The tools to use.
            default_tool (Optional[Tool]): The default tool to use. Defaults to ConversationalResponse.
            system_prompt (Optional[str]): The system prompt to use. Defaults to FUNCTION_CALLING_SYSTEM_PROMPT.
            prod_mode (bool): Whether to use production mode. Defaults to False.
            api (str): The API to use. Defaults to 'sncloud'.
            coe (bool): Whether to use CoE. Defaults to False.
            do_sample (bool): Whether to sample during generation. Defaults to False.
            max_tokens_to_generate (Optional[int]): The max tokens to generate. If None, the model will
                attempt to use the maximum available tokens.
            temperature (float): The model temperature. Defaults to 0.2.
            select_expert (Optional[str]): The expert to use. Defaults to None.
        """
        self.prod_mode = prod_mode
        sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
        self.api = api
        self.llm = APIGateway.load_llm(
            type=api,
            streaming=True,
            coe=coe,
            do_sample=do_sample,
            max_tokens_to_generate=max_tokens_to_generate,
            temperature=temperature,
            select_expert=select_expert,
            process_prompt=False,
            sambanova_api_key=sambanova_api_key,
        )
        if isinstance(tools, Tool):
            tools = [tools]
        # Guard against tools=None so len() and iteration below do not fail.
        self.tools = tools if tools is not None else []
        if system_prompt is None:
            system_prompt = ''
        # Escape braces so user-supplied prompts do not clash with ChatPromptTemplate placeholders.
        system_prompt = system_prompt.replace('{', '{{').replace('}', '}}')
        if len(self.tools) > 0:
            system_prompt += '\n\n'
            system_prompt += FUNCTION_CALLING_SYSTEM_PROMPT
        self.system_prompt = system_prompt
        if default_tool is None:
            default_tool = conversational_response
        self.default_tool = default_tool

    def execute(self, invoked_tools: List[dict]) -> Tuple[bool, List[str]]:
        """
        Execute the tool calls requested by the LLM, looking each one up by name in tools_map and
        passing it the given input arguments. If there is only one tool call and it is the default
        conversational one, the response is marked as the final response.

        Args:
            invoked_tools (List[dict]): The list of tool executions generated by the LLM.
        """
        if self.tools is not None:
            tools_map = {tool.name.lower(): tool for tool in self.tools}
        else:
            tools_map = {}
        tool_msg = "Tool '{name}' response: {response}"
        tools_msgs = []
        if len(invoked_tools) == 1 and invoked_tools[0]['tool'].lower() == 'conversationalresponse':
            final_answer = True
            return final_answer, [invoked_tools[0]['tool_input']['response']]
        final_answer = False
        for tool in invoked_tools:
            if tool['tool'].lower() == 'invocationerror':
                tools_msgs.append(f'Tool invocation error: {tool["tool_input"]}')
            elif tool['tool'].lower() != 'conversationalresponse':
                print(f"\n\n---\nTool {tool['tool'].lower()} invoked with input {tool['tool_input']}\n")
                if tool['tool'].lower() not in tools_map:
                    tools_msgs.append(f'Tool {tool["tool"]} not found')
                else:
                    response = tools_map[tool['tool'].lower()].call_tool_for_toolformer(**tool['tool_input'])
                    # print(f'Tool response: {str(response)}\n---\n\n')
                    tools_msgs.append(tool_msg.format(name=tool['tool'], response=str(response)))
        return final_answer, tools_msgs
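
    # Illustrative sketch (not part of the original module) of the input `execute` expects,
    # i.e. the list produced by `json_finder`; the 'GetWeather' tool name is hypothetical:
    #
    #     invoked_tools = [{'tool': 'GetWeather', 'tool_input': {'city': 'Paris'}}]
    #     final_answer, tools_msgs = self.execute(invoked_tools)
    #     # final_answer is False and tools_msgs holds "Tool 'GetWeather' response: ..." entries,
    #     # unless the only call was ConversationalResponse, in which case final_answer is True.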

    def json_finder(self, input_string: str) -> List[dict]:
        """
        Find the JSON tool-call structure in an LLM string response. If no <ToolCalls> section is
        found, wrap the raw response in a ConversationalResponse call; if the JSON is malformed,
        return an InvocationError entry describing the parsing failure.

        Args:
            input_string (str): The string to find the JSON structure in.
        """
        # 1. Ideal pattern: correctly surrounded by <ToolCalls> tags
        json_pattern_1 = re.compile(r'<ToolCalls>(.*)</ToolCalls>', re.DOTALL | re.IGNORECASE)
        # 2. Sometimes the closing tag is missing
        json_pattern_2 = re.compile(r'<ToolCalls>(.*)', re.DOTALL | re.IGNORECASE)
        # 3. Sometimes it accidentally uses <ToolCall> instead of <ToolCalls>
        json_pattern_3 = re.compile(r'<ToolCall>(.*)</ToolCall>', re.DOTALL | re.IGNORECASE)
        # 4. Sometimes it accidentally uses <ToolCall> instead of <ToolCalls> and the closing tag is missing
        json_pattern_4 = re.compile(r'<ToolCall>(.*)', re.DOTALL | re.IGNORECASE)

        # Find the first JSON structure in the string
        json_match = (
            json_pattern_1.search(input_string)
            or json_pattern_2.search(input_string)
            or json_pattern_3.search(input_string)
            or json_pattern_4.search(input_string)
        )
        if json_match:
            json_str = json_match.group(1)
            # 1. Outermost list of JSON objects
            call_pattern_1 = re.compile(r'\[.*\]', re.DOTALL)
            # 2. Outermost JSON object
            call_pattern_2 = re.compile(r'\{.*\}', re.DOTALL)
            call_match_1 = call_pattern_1.search(json_str)
            call_match_2 = call_pattern_2.search(json_str)
            if call_match_1:
                json_str = call_match_1.group(0)
                try:
                    return json.loads(json_str)
                except Exception as e:
                    return [{'tool': 'InvocationError', 'tool_input': str(e)}]
            elif call_match_2:
                json_str = call_match_2.group(0)
                try:
                    return [json.loads(json_str)]
                except Exception as e:
                    return [{'tool': 'InvocationError', 'tool_input': str(e)}]
            else:
                return [{'tool': 'InvocationError', 'tool_input': 'Could not find JSON object in the <ToolCalls> section'}]
        else:
            # No tool-call section: treat the whole response as a conversational answer
            return [{'tool': 'ConversationalResponse', 'tool_input': {'response': input_string}}]
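
    # Illustrative examples of `json_finder` behaviour (hedged sketch, not part of the original
    # module); the 'GetWeather' tool name is hypothetical:
    #
    #     json_finder('<ToolCalls>[{"tool": "GetWeather", "tool_input": {"city": "Paris"}}]</ToolCalls>')
    #     # -> [{'tool': 'GetWeather', 'tool_input': {'city': 'Paris'}}]
    #
    #     json_finder('Just a plain answer, no tool calls.')
    #     # -> [{'tool': 'ConversationalResponse', 'tool_input': {'response': 'Just a plain answer, no tool calls.'}}]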

    def msgs_to_llama3_str(self, msgs: list) -> str:
        """
        Convert a list of langchain messages with roles to the expected Llama 3 prompt format.

        Args:
            msgs (list): The list of langchain messages.
        """
        formatted_msgs = []
        for msg in msgs:
            if msg.type == 'system':
                sys_placeholder = '<|begin_of_text|><|start_header_id|>system<|end_header_id|> {msg}'
                formatted_msgs.append(sys_placeholder.format(msg=msg.content))
            elif msg.type == 'human':
                human_placeholder = '<|eot_id|><|start_header_id|>user<|end_header_id|>\nUser: {msg} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant:'  # noqa E501
                formatted_msgs.append(human_placeholder.format(msg=msg.content))
            elif msg.type == 'ai':
                assistant_placeholder = '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant: {msg}'
                formatted_msgs.append(assistant_placeholder.format(msg=msg.content))
            elif msg.type == 'tool':
                tool_placeholder = '<|eot_id|><|start_header_id|>tools<|end_header_id|>\n{msg} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant:'  # noqa E501
                formatted_msgs.append(tool_placeholder.format(msg=msg.content))
            else:
                raise ValueError(f'Invalid message type: {msg.type}')
        return '\n'.join(formatted_msgs)
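
    # For reference (hedged sketch): a system + human history is rendered roughly as
    #
    #     <|begin_of_text|><|start_header_id|>system<|end_header_id|> <system content>
    #     <|eot_id|><|start_header_id|>user<|end_header_id|>
    #     User: <human content> <|eot_id|><|start_header_id|>assistant<|end_header_id|>
    #     Assistant: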

    def msgs_to_sncloud(self, msgs: list) -> str:
        """
        Convert a list of langchain messages with roles to the expected SambaNova Cloud (FastCoE)
        input: a JSON-serialised list of role/content dicts.

        Args:
            msgs (list): The list of langchain messages.
        """
        formatted_msgs = []
        for msg in msgs:
            if msg.type == 'system':
                formatted_msgs.append({'role': 'system', 'content': msg.content})
            elif msg.type == 'human':
                formatted_msgs.append({'role': 'user', 'content': msg.content})
            elif msg.type == 'ai':
                formatted_msgs.append({'role': 'assistant', 'content': msg.content})
            elif msg.type == 'tool':
                formatted_msgs.append({'role': 'tools', 'content': msg.content})
            else:
                raise ValueError(f'Invalid message type: {msg.type}')
        return json.dumps(formatted_msgs)
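
    # For reference (hedged sketch): the same system + human history serialises to a JSON string
    # of role/content pairs, e.g.
    #
    #     '[{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]'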

    def function_call_llm(self, query: str, max_it: int = 10, debug: bool = False) -> Tuple[str, Any]:
        """
        Invocation method for the function-calling workflow.

        Args:
            query (str): The query to execute.
            max_it (int, optional): The maximum number of iterations. Defaults to 10.
            debug (bool, optional): Whether to print debug information. Defaults to False.
        """
        function_calling_chat_template = ChatPromptTemplate.from_messages([('system', self.system_prompt)])
        tools_schemas = [tool.as_llama_schema() for tool in self.tools]

        history = function_calling_chat_template.format_prompt(tools=tools_schemas).to_messages()
        history.append(HumanMessage(query))
        tool_call_id = 0  # identification for each tool call, required to create ToolMessages

        with usage_tracker():
            for i in range(max_it):
                json_parsing_chain = RunnableLambda(self.json_finder)

                if self.api == 'sncloud':
                    prompt = self.msgs_to_sncloud(history)
                else:
                    prompt = self.msgs_to_llama3_str(history)
                # print(f'\n\n---\nCalling function calling LLM with prompt: \n{prompt}\n')

                # Retry with randomised exponential backoff when the API returns a 429 rate limit.
                exponential_backoff_lower = 30
                exponential_backoff_higher = 60
                llm_response = None
                for _ in range(5):
                    try:
                        llm_response = self.llm.invoke(prompt, stream_options={'include_usage': True})
                        break
                    except Exception as e:
                        if '429' in str(e):
                            print('Rate limit exceeded. Waiting with random exponential backoff.')
                            time.sleep(
                                random() * (exponential_backoff_higher - exponential_backoff_lower)
                                + exponential_backoff_lower
                            )
                            exponential_backoff_lower *= 2
                            exponential_backoff_higher *= 2
                        else:
                            raise e
                if llm_response is None:
                    raise Exception('LLM invocation failed after retrying on rate-limit errors')
                print('LLM response:', llm_response)
                # print(f'\nFunction calling LLM response: \n{llm_response}\n---\n')

                parsed_tools_llm_response = json_parsing_chain.invoke(llm_response)
                history.append(AIMessage(llm_response))
                final_answer, tools_msgs = self.execute(parsed_tools_llm_response)
                if final_answer:  # if response was marked as final response in execution
                    final_response = tools_msgs[0]
                    if debug:
                        print('\n\n---\nFinal function calling LLM history: \n')
                        pprint(f'{history}')
                    return final_response, get_total_usage()
                else:
                    history.append(ToolMessage('\n'.join(tools_msgs), tool_call_id=tool_call_id))
                    tool_call_id += 1
        raise Exception('Not a final response yet', history)
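

# Example usage (illustrative sketch, not part of the original module). The Tool and
# StringParameter keyword arguments are inferred from the conversational_response definition
# above, and the 'GetWeather' tool is hypothetical; SAMBANOVA_API_KEY must be set in the
# environment for the underlying LLM call to succeed.
#
#     weather_tool = Tool(
#         name='GetWeather',
#         description='Return the current weather for a given city.',
#         parameters=[StringParameter(name='city', description='City to look up.', required=True)],
#         function=lambda city: f'It is sunny in {city}.',
#     )
#     fc_llm = FunctionCallingLlm(tools=[weather_tool], api='sncloud')
#     answer, usage = fc_llm.function_call_llm('What is the weather in Paris?', max_it=5)
#     print(answer, usage)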