import json import os import re from random import random from pprint import pprint import time from typing import List, Optional, Union from langchain_core.messages.ai import AIMessage from langchain_core.messages.human import HumanMessage from langchain_core.messages.tool import ToolMessage from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableLambda from toolformers.base import Tool, StringParameter from toolformers.sambanova.api_gateway import APIGateway from toolformers.sambanova.utils import get_total_usage, usage_tracker FUNCTION_CALLING_SYSTEM_PROMPT = """You have access to the following tools: {tools} You can call one or more tools by adding a section to your message. For example: ```json [{{ "tool": , "tool_input": }}] ``` Note that you can select multiple tools at once by adding more objects to the list. Do not add \ multiple sections to the same message. You will see the invocation of the tools in the response. Think step by step Do not call a tool if the input depends on another tool output that you do not have yet. Do not try to answer until you get all the tools output, if you do not have an answer yet, you can continue calling tools until you do. Your answer should be in the same language as the initial query. """ # noqa E501 conversational_response = Tool( name='ConversationalResponse', description='Respond conversationally only if no other tools should be called for a given query, or if you have a final answer. Response must be in the same language as the user query.', parameters=[StringParameter(name='response', description='Conversational response to the user. Must be in the same language as the user query.', required=True)], function=None ) class FunctionCallingLlm: """ function calling llm class """ def __init__( self, tools: Optional[Union[Tool, List[Tool]]] = None, default_tool: Optional[Tool] = None, system_prompt: Optional[str] = None, prod_mode: bool = False, api: str = 'sncloud', coe: bool = False, do_sample: bool = False, max_tokens_to_generate: Optional[int] = None, temperature: float = 0.2, select_expert: Optional[str] = None, ) -> None: """ Args: tools (Optional[Union[Tool, List[Tool]]]): The tools to use. default_tool (Optional[Tool]): The default tool to use. defaults to ConversationalResponse system_prompt (Optional[str]): The system prompt to use. defaults to FUNCTION_CALLING_SYSTEM_PROMPT prod_mode (bool): Whether to use production mode. Defaults to False. api (str): The api to use. Defaults to 'sncloud'. coe (bool): Whether to use coe. Defaults to False. do_sample (bool): Whether to do sample. Defaults to False. max_tokens_to_generate (Optional[int]): The max tokens to generate. If None, the model will attempt to use the maximum available tokens. temperature (float): The model temperature. Defaults to 0.2. select_expert (Optional[str]): The expert to use. Defaults to None. """ self.prod_mode = prod_mode sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY') self.api = api self.llm = APIGateway.load_llm( type=api, streaming=True, coe=coe, do_sample=do_sample, max_tokens_to_generate=max_tokens_to_generate, temperature=temperature, select_expert=select_expert, process_prompt=False, sambanova_api_key=sambanova_api_key, ) if isinstance(tools, Tool): tools = [tools] self.tools = tools if system_prompt is None: system_prompt = '' system_prompt = system_prompt.replace('{','{{').replace('}', '}}') if len(self.tools) > 0: system_prompt += '\n\n' system_prompt += FUNCTION_CALLING_SYSTEM_PROMPT self.system_prompt = system_prompt if default_tool is None: default_tool = conversational_response def execute(self, invoked_tools: List[dict]) -> tuple[bool, List[str]]: """ Given a list of tool executions the llm return as required execute them given the name with the mane in tools_map and the input arguments if there is only one tool call and it is default conversational one, the response is marked as final response Args: invoked_tools (List[dict]): The list of tool executions generated by the LLM. """ if self.tools is not None: tools_map = {tool.name.lower(): tool for tool in self.tools} else: tools_map = {} tool_msg = "Tool '{name}' response: {response}" tools_msgs = [] if len(invoked_tools) == 1 and invoked_tools[0]['tool'].lower() == 'conversationalresponse': final_answer = True return final_answer, [invoked_tools[0]['tool_input']['response']] final_answer = False for tool in invoked_tools: if tool['tool'].lower() == 'invocationerror': tools_msgs.append(f'Tool invocation error: {tool["tool_input"]}') elif tool['tool'].lower() != 'conversationalresponse': print(f"\n\n---\nTool {tool['tool'].lower()} invoked with input {tool['tool_input']}\n") if tool['tool'].lower() not in tools_map: tools_msgs.append(f'Tool {tool["tool"]} not found') else: response = tools_map[tool['tool'].lower()].call_tool_for_toolformer(**tool['tool_input']) # print(f'Tool response: {str(response)}\n---\n\n') tools_msgs.append(tool_msg.format(name=tool['tool'], response=str(response))) return final_answer, tools_msgs def json_finder(self, input_string: str) -> Optional[str]: """ find json structures in an LLM string response, if bad formatted using LLM to correct it Args: input_string (str): The string to find the json structure in. """ # 1. Ideal pattern: correctly surrounded by tags json_pattern_1 = re.compile(r'(.*)', re.DOTALL + re.IGNORECASE) # 2. Sometimes the closing tag is missing json_pattern_2 = re.compile(r'(.*)', re.DOTALL + re.IGNORECASE) # 3. Sometimes it accidentally uses instead of json_pattern_3 = re.compile(r'(.*)', re.DOTALL + re.IGNORECASE) # 4. Sometimes it accidentally uses instead of and the closing tag is missing json_pattern_4 = re.compile(r'(.*)', re.DOTALL + re.IGNORECASE) # Find the first JSON structure in the string json_match = json_pattern_1.search(input_string) or json_pattern_2.search(input_string) or json_pattern_3.search(input_string) or json_pattern_4.search(input_string) if json_match: json_str = json_match.group(1) # 1. Outermost list of JSON object call_pattern_1 = re.compile(r'\[.*\]', re.DOTALL) # 2. Outermost JSON object call_pattern_2 = re.compile(r'\{.*\}', re.DOTALL) call_match_1 = call_pattern_1.search(json_str) call_match_2 = call_pattern_2.search(json_str) if call_match_1: json_str = call_match_1.group(0) try: return json.loads(json_str) except Exception as e: return [{'tool': 'InvocationError', 'tool_input' : str(e)}] elif call_match_2: json_str = call_match_2.group(0) try: return [json.loads(json_str)] except Exception as e: return [{'tool': 'InvocationError', 'tool_input' : str(e)}] else: return [{'tool': 'InvocationError', 'tool_input' : 'Could not find JSON object in the section'}] else: dummy_json_response = [{'tool': 'ConversationalResponse', 'tool_input': {'response': input_string}}] json_str = dummy_json_response return json_str def msgs_to_llama3_str(self, msgs: list) -> str: """ convert a list of langchain messages with roles to expected LLmana 3 input Args: msgs (list): The list of langchain messages. """ formatted_msgs = [] for msg in msgs: if msg.type == 'system': sys_placeholder = ( '<|begin_of_text|><|start_header_id|>system<|end_header_id|>system<|end_header_id|> {msg}' ) formatted_msgs.append(sys_placeholder.format(msg=msg.content)) elif msg.type == 'human': human_placeholder = '<|eot_id|><|start_header_id|>user<|end_header_id|>\nUser: {msg} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant:' # noqa E501 formatted_msgs.append(human_placeholder.format(msg=msg.content)) elif msg.type == 'ai': assistant_placeholder = '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant: {msg}' formatted_msgs.append(assistant_placeholder.format(msg=msg.content)) elif msg.type == 'tool': tool_placeholder = '<|eot_id|><|start_header_id|>tools<|end_header_id|>\n{msg} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\nAssistant:' # noqa E501 formatted_msgs.append(tool_placeholder.format(msg=msg.content)) else: raise ValueError(f'Invalid message type: {msg.type}') return '\n'.join(formatted_msgs) def msgs_to_sncloud(self, msgs: list) -> list: """ convert a list of langchain messages with roles to expected FastCoE input Args: msgs (list): The list of langchain messages. """ formatted_msgs = [] for msg in msgs: if msg.type == 'system': formatted_msgs.append({'role': 'system', 'content': msg.content}) elif msg.type == 'human': formatted_msgs.append({'role': 'user', 'content': msg.content}) elif msg.type == 'ai': formatted_msgs.append({'role': 'assistant', 'content': msg.content}) elif msg.type == 'tool': formatted_msgs.append({'role': 'tools', 'content': msg.content}) else: raise ValueError(f'Invalid message type: {msg.type}') return json.dumps(formatted_msgs) def function_call_llm(self, query: str, max_it: int = 10, debug: bool = False) -> str: """ invocation method for function calling workflow Args: query (str): The query to execute. max_it (int, optional): The maximum number of iterations. Defaults to 5. debug (bool, optional): Whether to print debug information. Defaults to False. """ function_calling_chat_template = ChatPromptTemplate.from_messages([('system', self.system_prompt)]) tools_schemas = [tool.as_llama_schema() for tool in self.tools] history = function_calling_chat_template.format_prompt(tools=tools_schemas).to_messages() history.append(HumanMessage(query)) tool_call_id = 0 # identification for each tool calling required to create ToolMessages with usage_tracker(): for i in range(max_it): json_parsing_chain = RunnableLambda(self.json_finder) if self.api == 'sncloud': prompt = self.msgs_to_sncloud(history) else: prompt = self.msgs_to_llama3_str(history) # print(f'\n\n---\nCalling function calling LLM with prompt: \n{prompt}\n') exponential_backoff_lower = 30 exponential_backoff_higher = 60 llm_response = None for _ in range(5): try: llm_response = self.llm.invoke(prompt, stream_options={'include_usage': True}) break except Exception as e: if '429' in str(e): print('Rate limit exceeded. Waiting with random exponential backoff.') time.sleep(random() * (exponential_backoff_higher - exponential_backoff_lower) + exponential_backoff_lower) exponential_backoff_lower *= 2 exponential_backoff_higher *= 2 else: raise e print('LLM response:', llm_response) # print(f'\nFunction calling LLM response: \n{llm_response}\n---\n') parsed_tools_llm_response = json_parsing_chain.invoke(llm_response) history.append(AIMessage(llm_response)) final_answer, tools_msgs = self.execute(parsed_tools_llm_response) if final_answer: # if response was marked as final response in execution final_response = tools_msgs[0] if debug: print('\n\n---\nFinal function calling LLM history: \n') pprint(f'{history}') return final_response, get_total_usage() else: history.append(ToolMessage('\n'.join(tools_msgs), tool_call_id=tool_call_id)) tool_call_id += 1 raise Exception('Not a final response yet', history)