# toolformers/sambanova/function_calling.py
import json
import os
import re
from random import random
from pprint import pprint
import time
from typing import List, Optional, Union
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.tool import ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from toolformers.base import Tool, StringParameter
from toolformers.sambanova.api_gateway import APIGateway
from toolformers.sambanova.utils import get_total_usage, usage_tracker
FUNCTION_CALLING_SYSTEM_PROMPT = """You have access to the following tools:
{tools}
You can call one or more tools by adding a <ToolCalls> section to your message. For example:
<ToolCalls>
```json
[{{
"tool": <name of the selected tool>,
"tool_input": <parameters for the selected tool, matching the tool's JSON schema>
}}]
```
</ToolCalls>
Note that you can select multiple tools at once by adding more objects to the list. Do not add \
multiple <ToolCalls> sections to the same message.
You will see the results of the tool invocations in the response.
Think step by step.
Do not call a tool if its input depends on another tool's output that you do not have yet.
Do not try to answer until you have all the tool outputs; if you do not have an answer yet, continue calling tools until you do.
Your answer should be in the same language as the initial query.
""" # noqa E501
conversational_response = Tool(
name='ConversationalResponse',
description='Respond conversationally only if no other tools should be called for a given query, or if you have a final answer. Response must be in the same language as the user query.',
parameters=[StringParameter(name='response', description='Conversational response to the user. Must be in the same language as the user query.', required=True)],
function=None
)
class FunctionCallingLlm:
"""
function calling llm class
"""
def __init__(
self,
tools: Optional[Union[Tool, List[Tool]]] = None,
default_tool: Optional[Tool] = None,
system_prompt: Optional[str] = None,
prod_mode: bool = False,
api: str = 'sncloud',
coe: bool = False,
do_sample: bool = False,
max_tokens_to_generate: Optional[int] = None,
temperature: float = 0.2,
select_expert: Optional[str] = None,
) -> None:
"""
Args:
tools (Optional[Union[Tool, List[Tool]]]): The tools to use.
default_tool (Optional[Tool]): The default tool to use.
defaults to ConversationalResponse
system_prompt (Optional[str]): The system prompt to use. defaults to FUNCTION_CALLING_SYSTEM_PROMPT
prod_mode (bool): Whether to use production mode. Defaults to False.
api (str): The api to use. Defaults to 'sncloud'.
            coe (bool): Whether to use a CoE (Composition of Experts) endpoint. Defaults to False.
            do_sample (bool): Whether to sample during generation. Defaults to False.
max_tokens_to_generate (Optional[int]): The max tokens to generate. If None, the model will attempt to use the maximum available tokens.
temperature (float): The model temperature. Defaults to 0.2.
select_expert (Optional[str]): The expert to use. Defaults to None.
"""
self.prod_mode = prod_mode
sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
self.api = api
self.llm = APIGateway.load_llm(
type=api,
streaming=True,
coe=coe,
do_sample=do_sample,
max_tokens_to_generate=max_tokens_to_generate,
temperature=temperature,
select_expert=select_expert,
process_prompt=False,
sambanova_api_key=sambanova_api_key,
)
        if isinstance(tools, Tool):
            tools = [tools]
        self.tools = tools if tools is not None else []
        if system_prompt is None:
            system_prompt = ''
        # Escape braces so user-provided text is not interpreted as a prompt-template placeholder
        system_prompt = system_prompt.replace('{', '{{').replace('}', '}}')
        if len(self.tools) > 0:
system_prompt += '\n\n'
system_prompt += FUNCTION_CALLING_SYSTEM_PROMPT
self.system_prompt = system_prompt
        if default_tool is None:
            default_tool = conversational_response
        self.default_tool = default_tool
def execute(self, invoked_tools: List[dict]) -> tuple[bool, List[str]]:
"""
Given a list of tool executions the llm return as required
execute them given the name with the mane in tools_map and the input arguments
if there is only one tool call and it is default conversational one, the response is marked as final response
Args:
invoked_tools (List[dict]): The list of tool executions generated by the LLM.
"""
if self.tools is not None:
tools_map = {tool.name.lower(): tool for tool in self.tools}
else:
tools_map = {}
tool_msg = "Tool '{name}' response: {response}"
tools_msgs = []
if len(invoked_tools) == 1 and invoked_tools[0]['tool'].lower() == 'conversationalresponse':
final_answer = True
return final_answer, [invoked_tools[0]['tool_input']['response']]
final_answer = False
for tool in invoked_tools:
if tool['tool'].lower() == 'invocationerror':
tools_msgs.append(f'Tool invocation error: {tool["tool_input"]}')
elif tool['tool'].lower() != 'conversationalresponse':
print(f"\n\n---\nTool {tool['tool'].lower()} invoked with input {tool['tool_input']}\n")
if tool['tool'].lower() not in tools_map:
tools_msgs.append(f'Tool {tool["tool"]} not found')
else:
response = tools_map[tool['tool'].lower()].call_tool_for_toolformer(**tool['tool_input'])
# print(f'Tool response: {str(response)}\n---\n\n')
tools_msgs.append(tool_msg.format(name=tool['tool'], response=str(response)))
return final_answer, tools_msgs
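    # For illustration, parsed tool calls such as
    #   [{'tool': 'ConversationalResponse', 'tool_input': {'response': 'Hi!'}}]
    # make execute() return (True, ['Hi!']), while any other combination returns
    # (False, ["Tool '<name>' response: <result>", ...]) so the calling loop keeps iterating.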
    def json_finder(self, input_string: str) -> List[dict]:
        """
        Find the tool-call JSON structure in an LLM string response. If no <ToolCalls> section is
        found, the whole response is wrapped as a ConversationalResponse tool call.
        Args:
            input_string (str): The string to find the JSON structure in.
        """
        # 1. Ideal pattern: correctly surrounded by <ToolCalls> tags
        json_pattern_1 = re.compile(r'<ToolCalls>(.*)</ToolCalls>', re.DOTALL | re.IGNORECASE)
        # 2. Sometimes the closing tag is missing
        json_pattern_2 = re.compile(r'<ToolCalls>(.*)', re.DOTALL | re.IGNORECASE)
        # 3. Sometimes it accidentally uses <ToolCall> instead of <ToolCalls>
        json_pattern_3 = re.compile(r'<ToolCall>(.*)</ToolCall>', re.DOTALL | re.IGNORECASE)
        # 4. Sometimes it accidentally uses <ToolCall> instead of <ToolCalls> and the closing tag is missing
        json_pattern_4 = re.compile(r'<ToolCall>(.*)', re.DOTALL | re.IGNORECASE)
        # Find the first JSON structure in the string, trying the strictest pattern first
        json_match = (
            json_pattern_1.search(input_string)
            or json_pattern_2.search(input_string)
            or json_pattern_3.search(input_string)
            or json_pattern_4.search(input_string)
        )
if json_match:
json_str = json_match.group(1)
            # 1. Outermost list of JSON objects
            call_pattern_1 = re.compile(r'\[.*\]', re.DOTALL)
            # 2. Outermost single JSON object
            call_pattern_2 = re.compile(r'\{.*\}', re.DOTALL)
call_match_1 = call_pattern_1.search(json_str)
call_match_2 = call_pattern_2.search(json_str)
if call_match_1:
json_str = call_match_1.group(0)
                try:
                    return json.loads(json_str)
                except Exception as e:
                    return [{'tool': 'InvocationError', 'tool_input': str(e)}]
            elif call_match_2:
                json_str = call_match_2.group(0)
                try:
                    return [json.loads(json_str)]
                except Exception as e:
                    return [{'tool': 'InvocationError', 'tool_input': str(e)}]
            else:
                return [{'tool': 'InvocationError', 'tool_input': 'Could not find JSON object in the <ToolCalls> section'}]
        else:
            # No <ToolCalls> section found: treat the entire message as a conversational response
            return [{'tool': 'ConversationalResponse', 'tool_input': {'response': input_string}}]
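    # For illustration, json_finder on a reply containing
    #   <ToolCalls>```json [{"tool": "Echo", "tool_input": {"text": "hi"}}] ```</ToolCalls>
    # returns [{'tool': 'Echo', 'tool_input': {'text': 'hi'}}], whereas a plain-text reply is
    # wrapped as [{'tool': 'ConversationalResponse', 'tool_input': {'response': <reply>}}].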
    def msgs_to_llama3_str(self, msgs: list) -> str:
        """
        Convert a list of langchain messages with roles to the expected Llama 3 prompt string.
        Args:
            msgs (list): The list of langchain messages.
        """
formatted_msgs = []
for msg in msgs:
if msg.type == 'system':
                sys_placeholder = (
                    '<|begin_of_text|><|start_header_id|>system<|end_header_id|> {msg}'
                )
formatted_msgs.append(sys_placeholder.format(msg=msg.content))
elif msg.type == 'human':
human_placeholder = '<|eot_id|><|start_header_id|>user<|end_header_id|>\nUser: {msg} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant:' # noqa E501
formatted_msgs.append(human_placeholder.format(msg=msg.content))
elif msg.type == 'ai':
assistant_placeholder = '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant: {msg}'
formatted_msgs.append(assistant_placeholder.format(msg=msg.content))
elif msg.type == 'tool':
tool_placeholder = '<|eot_id|><|start_header_id|>tools<|end_header_id|>\n{msg} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant:' # noqa E501
formatted_msgs.append(tool_placeholder.format(msg=msg.content))
else:
raise ValueError(f'Invalid message type: {msg.type}')
return '\n'.join(formatted_msgs)
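    # For illustration, a single HumanMessage('Hi') renders to:
    #   <|eot_id|><|start_header_id|>user<|end_header_id|>\nUser: Hi <|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant: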
    def msgs_to_sncloud(self, msgs: list) -> str:
        """
        Convert a list of langchain messages with roles to the JSON string expected by the
        SambaNova Cloud (FastCoE) API.
        Args:
            msgs (list): The list of langchain messages.
        """
formatted_msgs = []
for msg in msgs:
if msg.type == 'system':
formatted_msgs.append({'role': 'system', 'content': msg.content})
elif msg.type == 'human':
formatted_msgs.append({'role': 'user', 'content': msg.content})
elif msg.type == 'ai':
formatted_msgs.append({'role': 'assistant', 'content': msg.content})
elif msg.type == 'tool':
formatted_msgs.append({'role': 'tools', 'content': msg.content})
else:
raise ValueError(f'Invalid message type: {msg.type}')
return json.dumps(formatted_msgs)
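    # For illustration, [HumanMessage('Hi')] serializes to '[{"role": "user", "content": "Hi"}]'.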
    def function_call_llm(self, query: str, max_it: int = 10, debug: bool = False) -> tuple:
        """
        Invocation method for the function-calling workflow. Returns the final response together
        with the total token usage.
        Args:
            query (str): The query to execute.
            max_it (int, optional): The maximum number of iterations. Defaults to 10.
            debug (bool, optional): Whether to print debug information. Defaults to False.
        """
function_calling_chat_template = ChatPromptTemplate.from_messages([('system', self.system_prompt)])
tools_schemas = [tool.as_llama_schema() for tool in self.tools]
history = function_calling_chat_template.format_prompt(tools=tools_schemas).to_messages()
history.append(HumanMessage(query))
tool_call_id = 0 # identification for each tool calling required to create ToolMessages
with usage_tracker():
for i in range(max_it):
json_parsing_chain = RunnableLambda(self.json_finder)
if self.api == 'sncloud':
prompt = self.msgs_to_sncloud(history)
else:
prompt = self.msgs_to_llama3_str(history)
# print(f'\n\n---\nCalling function calling LLM with prompt: \n{prompt}\n')
exponential_backoff_lower = 30
exponential_backoff_higher = 60
                llm_response = None
                for _ in range(5):
                    try:
                        llm_response = self.llm.invoke(prompt, stream_options={'include_usage': True})
                        break
                    except Exception as e:
                        if '429' in str(e):
                            print('Rate limit exceeded. Waiting with random exponential backoff.')
                            time.sleep(
                                random() * (exponential_backoff_higher - exponential_backoff_lower)
                                + exponential_backoff_lower
                            )
                            exponential_backoff_lower *= 2
                            exponential_backoff_higher *= 2
                        else:
                            raise
                if llm_response is None:
                    raise Exception('LLM invocation failed after repeated rate-limit errors')
print('LLM response:', llm_response)
# print(f'\nFunction calling LLM response: \n{llm_response}\n---\n')
parsed_tools_llm_response = json_parsing_chain.invoke(llm_response)
history.append(AIMessage(llm_response))
final_answer, tools_msgs = self.execute(parsed_tools_llm_response)
if final_answer: # if response was marked as final response in execution
final_response = tools_msgs[0]
if debug:
print('\n\n---\nFinal function calling LLM history: \n')
pprint(f'{history}')
return final_response, get_total_usage()
else:
history.append(ToolMessage('\n'.join(tools_msgs), tool_call_id=tool_call_id))
tool_call_id += 1
            raise Exception('Maximum number of iterations reached without a final response', history)
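

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module): wires up a single hypothetical
    # echo tool and runs one query. Assumes SAMBANOVA_API_KEY is set in the environment, the
    # 'sncloud' endpoint is reachable, and that Tool accepts a plain callable as `function`
    # (invoked through call_tool_for_toolformer); adapt to the actual Tool API if it differs.
    echo_tool = Tool(
        name='Echo',
        description='Echoes back the provided text.',
        parameters=[StringParameter(name='text', description='The text to echo back.', required=True)],
        function=lambda text: text,
    )
    fc_llm = FunctionCallingLlm(tools=[echo_tool], api='sncloud')
    answer, usage = fc_llm.function_call_llm('Use the Echo tool on the word "hello".', debug=True)
    print('Answer:', answer)
    print('Token usage:', usage)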