import datetime
import os
from typing import List

from toolformers.base import Conversation, Toolformer, Tool
from toolformers.sambanova.function_calling import FunctionCallingLlm

# Per-token prices keyed by model name (currency presumably USD — confirm).
# Models missing from this table are treated as unpriced: their conversations
# still work, but no cost is accumulated for them.
COSTS = {
    'llama3-405b': {
        'prompt_tokens': 5e-6,
        'completion_tokens': 10e-6
    }
}


class SambanovaConversation(Conversation):
    """Conversation that forwards user messages to a SambaNova FunctionCallingLlm.

    The estimated cost of every successful call is added to ``total_cost``
    when the model has an entry in ``COSTS``.
    """

    def __init__(self, model_name: str, function_calling_llm: FunctionCallingLlm, category=None):
        """Store the model name, the LLM wrapper and an optional category.

        Args:
            model_name: Name of the model; used to look up prices in ``COSTS``.
            function_calling_llm: Wrapper whose ``function_call_llm`` is
                invoked for each message.
            category: Opaque tag stored on the instance (not used in this class).
        """
        self.model_name = model_name
        self.function_calling_llm = function_calling_llm
        self.category = category
        # Running total of the estimated cost of all chat() calls so far.
        self.total_cost = 0.0

    def _compute_cost(self, usage_data):
        """Return the estimated cost of one call, or None if it cannot be priced.

        A call cannot be priced when the model has no entry in ``COSTS`` or
        when ``usage_data`` is missing one of the priced token counts.
        """
        prices = COSTS.get(self.model_name)
        if prices is None:
            return None
        try:
            # NOTE(review): assumes usage_data is a mapping with integer
            # 'prompt_tokens' / 'completion_tokens' counts — confirm.
            return sum(price * usage_data[token_kind] for token_kind, price in prices.items())
        except (KeyError, TypeError):
            return None

    def chat(self, message, role='user', print_output=True):
        """Send ``message`` to the LLM and return its response.

        Args:
            message: The user message to send.
            role: Must be ``'user'``; other roles are not supported.
            print_output: When true, print the response as well.

        Returns:
            The response returned by ``function_call_llm``.

        Raises:
            ValueError: If ``role`` is not ``'user'``.
        """
        if role != 'user':
            raise ValueError('Role must be "user"')

        response, usage_data = self.function_calling_llm.function_call_llm(message)

        print('Usage data:', usage_data)
        if print_output:
            print(response)

        # Unpriced models or incomplete usage data must not crash the call
        # after the response has already been obtained.
        cost = self._compute_cost(usage_data)
        if cost is not None:
            self.total_cost += cost

        return response


class SambanovaToolformer(Toolformer):
    """Factory for SambanovaConversation objects bound to one model."""

    def __init__(self, model_name: str):
        """Remember the model name passed to every conversation created later."""
        self.model_name = model_name

    def new_conversation(self, system_prompt: str, tools: List[Tool], category=None) -> SambanovaConversation:
        """Create a fresh conversation with its own FunctionCallingLlm.

        Args:
            system_prompt: System prompt handed to the FunctionCallingLlm.
            tools: Tools the LLM may call.
            category: Optional tag stored on the returned conversation.

        Returns:
            A new SambanovaConversation for ``self.model_name``.
        """
        function_calling_llm = FunctionCallingLlm(
            system_prompt=system_prompt,
            tools=tools,
            select_expert=self.model_name,
        )
        return SambanovaConversation(self.model_name, function_calling_llm, category)