import datetime
import os
from random import random
import time
import traceback
from typing import List

from toolformers.base import Conversation, Tool, Toolformer

import google.generativeai as genai
from google.generativeai.generative_models import ChatSession

from utils import register_cost

genai.configure(api_key=os.environ['GOOGLE_API_KEY'])

# Approximate cost in USD per token; only gemini-1.5-pro pricing is tracked here.
COSTS = {
    'gemini-1.5-pro': {
        'prompt_tokens': 1.25e-6,
        'completion_tokens': 5e-6
    }
}


class GeminiConversation(Conversation):
    """A Conversation backed by a Gemini ChatSession; retries transient errors and records token costs."""

    def __init__(self, model_name, chat_agent: ChatSession, category=None):
        self.model_name = model_name
        self.chat_agent = chat_agent
        self.category = category

    def chat(self, message, role='user', print_output=True):
        agent_id = os.environ.get('AGENT_ID', None)
        time_start = datetime.datetime.now()

        exponential_backoff_lower = 30
        exponential_backoff_higher = 60

        # Retry up to 5 times, handling rate limits and transient server errors.
        response = None
        for i in range(5):
            try:
                response = self.chat_agent.send_message({
                    'role': role,
                    'parts': [
                        message
                    ]
                })
                break
            except Exception as e:
                print(e)
                if '429' in str(e):
                    print('Rate limit exceeded. Waiting with random exponential backoff.')
                    if i < 4:
                        time.sleep(random() * (exponential_backoff_higher - exponential_backoff_lower) + exponential_backoff_lower)
                        exponential_backoff_lower *= 2
                        exponential_backoff_higher *= 2
                elif 'candidates[0]' in traceback.format_exc():
                    # When Gemini has nothing to say, it raises an error with this message
                    print('No response')
                    return 'No response'
                elif '500' in str(e):
                    # Sometimes Gemini just decides to return a 500 error for absolutely no reason. Retry.
                    print('500 error')
                    time.sleep(5)
                    traceback.print_exc()
                else:
                    raise e

        if response is None:
            raise RuntimeError('Failed to get a response from Gemini after 5 attempts.')

        time_end = datetime.datetime.now()

        usage_info = {
            'prompt_tokens': response.usage_metadata.prompt_token_count,
            'completion_tokens': response.usage_metadata.candidates_token_count
        }

        total_cost = 0
        for cost_name in ['prompt_tokens', 'completion_tokens']:
            total_cost += COSTS[self.model_name][cost_name] * usage_info[cost_name]

        register_cost(self.category, total_cost)

        #send_usage_to_db(
        #    usage_info,
        #    time_start,
        #    time_end,
        #    agent_id,
        #    self.category,
        #    self.model_name
        #)

        reply = response.text

        if print_output:
            print(reply)

        return reply


class GeminiToolformer(Toolformer):
    """Toolformer that creates Gemini chat sessions with automatic function calling enabled."""

    def __init__(self, model_name):
        self.model_name = model_name

    def new_conversation(self, system_prompt, tools: List[Tool], category=None) -> Conversation:
        print('Tools:')
        print('\n'.join([str(tool.as_openai_info()) for tool in tools]))

        model = genai.GenerativeModel(
            model_name=self.model_name,
            system_instruction=system_prompt,
            tools=[tool.as_gemini_tool() for tool in tools]
        )

        chat = model.start_chat(enable_automatic_function_calling=True)

        return GeminiConversation(self.model_name, chat, category)


def make_gemini_toolformer(model_name):
    if model_name not in ['gemini-1.5-flash', 'gemini-1.5-pro']:
        raise ValueError(f"Unknown model name: {model_name}")

    return GeminiToolformer(model_name)
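

# Example usage: a minimal sketch, not part of the original module. It assumes
# GOOGLE_API_KEY is set and that an empty tool list is acceptable here, since
# Tool construction is defined in toolformers.base and not shown in this file.
if __name__ == '__main__':
    toolformer = make_gemini_toolformer('gemini-1.5-pro')
    conversation = toolformer.new_conversation(
        system_prompt='You are a helpful assistant.',
        tools=[],
        category='demo'
    )

    # chat() prints the reply by default and returns it as a string.
    reply = conversation.chat('Briefly introduce yourself.')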