import datetime
import os
from typing import List

from toolformers.base import Conversation, Toolformer, Tool
from toolformers.sambanova.function_calling import FunctionCallingLlm
# Approximate per-token costs used to estimate the price of each call.
COSTS = {
    'llama3-405b': {
        'prompt_tokens': 5e-6,
        'completion_tokens': 10e-6
    }
}
class SambanovaConversation(Conversation):
    def __init__(self, model_name, function_calling_llm: FunctionCallingLlm, category=None):
        self.model_name = model_name
        self.function_calling_llm = function_calling_llm
        self.category = category

    def chat(self, message, role='user', print_output=True):
        # Only user messages are supported by the underlying function-calling LLM.
        if role != 'user':
            raise ValueError('Role must be "user"')

        agent_id = os.environ.get('AGENT_ID', None)

        start_time = datetime.datetime.now()
        response, usage_data = self.function_calling_llm.function_call_llm(message)
        end_time = datetime.datetime.now()

        print('Usage data:', usage_data)

        if print_output:
            print(response)

        # Estimate the cost of the call from the reported token usage.
        cost = 0
        for cost_name in ['prompt_tokens', 'completion_tokens']:
            cost += COSTS[self.model_name][cost_name] * usage_data[cost_name]

        return response
class SambanovaToolformer(Toolformer):
    def __init__(self, model_name: str):
        self.model_name = model_name

    def new_conversation(self, system_prompt: str, tools: List[Tool], category=None) -> SambanovaConversation:
        # Each conversation gets its own function-calling LLM bound to the selected model.
        function_calling_llm = FunctionCallingLlm(system_prompt=system_prompt, tools=tools, select_expert=self.model_name)
        return SambanovaConversation(self.model_name, function_calling_llm, category)
#def make_llama_toolformer(model_name, system_prompt: str, tools: List[Tool]):
#    if model_name not in ['llama3-8b', 'llama3-70b', 'llama3-405b']:
#        raise ValueError(f"Unknown model name: {model_name}")
#
#    return SambanovaToolformer(model_name, system_prompt, tools)
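

# Minimal usage sketch: shows how a toolformer and conversation fit together.
# It assumes valid SambaNova credentials are available in whatever environment
# FunctionCallingLlm expects; the system prompt, the empty tool list, and the
# message below are illustrative only.
if __name__ == '__main__':
    toolformer = SambanovaToolformer('llama3-405b')
    conversation = toolformer.new_conversation(
        system_prompt='You are a helpful assistant.',
        tools=[],
    )
    reply = conversation.chat('Hello! Which tools can you call?')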