samuelemarro's picture
Added cost tracking.
c07f594
raw
history blame
1.94 kB
import datetime
import os
from typing import List
from toolformers.base import Conversation, Toolformer, Tool
from toolformers.sambanova.function_calling import FunctionCallingLlm
# Per-token prices, keyed first by model name and then by the usage field
# ('prompt_tokens' / 'completion_tokens') that the price applies to.
# NOTE(review): presumably USD per token (5e-6 == 5 per million tokens) — confirm.
COSTS = {
    'llama3-405b': {
        'prompt_tokens': 5e-6,
        'completion_tokens': 10e-6,
    },
}
class SambanovaConversation(Conversation):
    """Conversation that delegates every message to a ``FunctionCallingLlm``.

    After each call the token usage reported by the LLM is priced with the
    module-level ``COSTS`` table and recorded on the instance.

    Attributes:
        model_name: Key used to look up pricing in ``COSTS``.
        function_calling_llm: Object whose ``function_call_llm`` method
            produces the response and the usage data.
        category: Stored as given; not interpreted by this class.
        last_cost: Cost of the most recent ``chat`` call, or ``None`` if no
            call has been made yet or the model has no entry in ``COSTS``.
        total_cost: Sum of the known costs of all ``chat`` calls so far.
    """

    def __init__(self, model_name, function_calling_llm: FunctionCallingLlm, category=None):
        self.model_name = model_name
        self.function_calling_llm = function_calling_llm
        self.category = category
        self.last_cost = None
        self.total_cost = 0.0

    def chat(self, message, role='user', print_output=True):
        """Send ``message`` to the LLM and return its response.

        Args:
            message: Passed unchanged to ``function_call_llm``.
            role: Must be ``'user'``; no other role is supported.
            print_output: When true, the response is printed as well.

        Returns:
            The response returned by ``function_call_llm``.

        Raises:
            ValueError: If ``role`` is not ``'user'``.
        """
        if role != 'user':
            raise ValueError('Role must be "user"')

        response, usage_data = self.function_calling_llm.function_call_llm(message)

        print('Usage data:', usage_data)
        if print_output:
            print(response)

        # Record the cost instead of discarding it, so callers can read it.
        self.last_cost = self._compute_cost(usage_data)
        if self.last_cost is not None:
            self.total_cost += self.last_cost

        return response

    def _compute_cost(self, usage_data):
        """Price one call's usage, or return None if the model has no pricing."""
        # A model missing from COSTS must not turn a successful LLM call into
        # a KeyError that throws the response away.
        prices = COSTS.get(self.model_name)
        if prices is None:
            return None
        return sum(price * usage_data[name] for name, price in prices.items())
class SambanovaToolformer(Toolformer):
    """Toolformer that opens conversations backed by a ``FunctionCallingLlm``."""

    def __init__(self, model_name: str):
        """Store the model name used for every conversation created later."""
        self.model_name = model_name

    def new_conversation(self, system_prompt: str, tools: List[Tool], category=None) -> SambanovaConversation:
        """Create a conversation with its own ``FunctionCallingLlm``.

        Args:
            system_prompt: Forwarded to ``FunctionCallingLlm`` as ``system_prompt``.
            tools: Forwarded to ``FunctionCallingLlm`` as ``tools``.
            category: Passed through to the conversation unchanged.

        Returns:
            A ``SambanovaConversation`` bound to this toolformer's model name.
        """
        return SambanovaConversation(
            self.model_name,
            FunctionCallingLlm(
                system_prompt=system_prompt,
                tools=tools,
                select_expert=self.model_name,
            ),
            category,
        )
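# NOTE: disabled factory, kept for reference only. It is out of date:
# SambanovaToolformer.__init__ takes only `model_name`; the system prompt and
# tools are passed to `new_conversation` instead.
#def make_llama_toolformer(model_name, system_prompt: str, tools: List[Tool]):
#    if model_name not in ['llama3-8b', 'llama3-70b', 'llama3-405b']:
#        raise ValueError(f"Unknown model name: {model_name}")
#
#    return SambanovaToolformer(model_name, system_prompt, tools)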