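"""Gemini-backed implementations of the Conversation and Toolformer interfaces.

Wraps google.generativeai chat sessions with retry/backoff handling for rate
limits and transient errors, and records per-conversation token costs via
utils.register_cost.
"""
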
import datetime
import os
from random import random
import time
import traceback
from typing import List

from toolformers.base import Conversation, Tool, Toolformer

import google.generativeai as genai
from google.generativeai.generative_models import ChatSession

from utils import register_cost

genai.configure(api_key=os.environ['GOOGLE_API_KEY'])

COSTS = {
    'gemini-1.5-pro': {
        'prompt_tokens': 1.25e-6,
        'completion_tokens': 5e-6
    }
}

class GeminiConversation(Conversation):
    def __init__(self, model_name, chat_agent: ChatSession, category=None):
        self.model_name = model_name
        self.chat_agent = chat_agent
        self.category = category

    def chat(self, message, role='user', print_output=True):
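        """Send `message` to the underlying Gemini chat session and return the reply text.

        Retries up to five times, using randomized exponential backoff on 429
        rate-limit errors, and registers the token cost of the exchange under
        this conversation's category. If print_output is True, the reply is
        also printed.
        """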
        agent_id = os.environ.get('AGENT_ID', None)
        time_start = datetime.datetime.now()

        exponential_backoff_lower = 30
        exponential_backoff_higher = 60
        for i in range(5):
            try:
                response = self.chat_agent.send_message({
                    'role': role,
                    'parts': [
                        message
                    ]
                })
                break
            except Exception as e:
                print(e)
                if '429' in str(e):
                    print('Rate limit exceeded. Waiting with random exponential backoff.')
                    if i < 4:
                        time.sleep(random() * (exponential_backoff_higher - exponential_backoff_lower) + exponential_backoff_lower)
                        exponential_backoff_lower *= 2
                        exponential_backoff_higher *= 2
                elif 'candidates[0]' in traceback.format_exc():
                    # When Gemini has nothing to say, it raises an error whose traceback mentions candidates[0]
                    print('No response')
                    return 'No response'
                elif '500' in str(e):
                    # Gemini occasionally returns a transient 500 error. Wait briefly and retry.
                    print('500 error')
                    time.sleep(5)
                    traceback.print_exc()
                else:
                    raise
        else:
            # All attempts failed without producing a response (e.g. persistent rate limiting);
            # fail loudly instead of referencing an unbound `response` below.
            raise RuntimeError(f'Failed to get a response from {self.model_name} after 5 attempts')

        time_end = datetime.datetime.now()

        usage_info = {
            'prompt_tokens': response.usage_metadata.prompt_token_count,
            'completion_tokens': response.usage_metadata.candidates_token_count
        }

        # COSTS only lists gemini-1.5-pro, but make_gemini_toolformer also accepts
        # gemini-1.5-flash; fall back to a zero cost rather than raising a KeyError.
        model_costs = COSTS.get(self.model_name, {})
        total_cost = 0
        for cost_name in ['prompt_tokens', 'completion_tokens']:
            total_cost += model_costs.get(cost_name, 0) * usage_info[cost_name]

        register_cost(self.category, total_cost)

        #send_usage_to_db(
        #    usage_info,
        #    time_start,
        #    time_end,
        #    agent_id,
        #    self.category,
        #    self.model_name
        #)

        reply = response.text

        if print_output:
            print(reply)
        
        return reply

class GeminiToolformer(Toolformer):
    def __init__(self, model_name):
        self.model_name = model_name

    def new_conversation(self, system_prompt, tools: List[Tool], category=None) -> Conversation:
        print('Tools:')
        print('\n'.join([str(tool.as_openai_info()) for tool in tools]))
        model = genai.GenerativeModel(
            model_name=self.model_name,
            system_instruction=system_prompt,
            tools=[tool.as_gemini_tool() for tool in tools]
        )

        chat = model.start_chat(enable_automatic_function_calling=True)

        return GeminiConversation(self.model_name, chat, category)

def make_gemini_toolformer(model_name):
    if model_name not in ['gemini-1.5-flash', 'gemini-1.5-pro']:
        raise ValueError(f"Unknown model name: {model_name}")

    return GeminiToolformer(model_name)
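
# Minimal usage sketch (illustrative, not invoked by the rest of the codebase):
# assumes a valid GOOGLE_API_KEY is set and that, in practice, the tools list
# would contain Tool objects from toolformers.base exposing as_gemini_tool().
if __name__ == '__main__':
    toolformer = make_gemini_toolformer('gemini-1.5-pro')
    conversation = toolformer.new_conversation(
        system_prompt='You are a helpful assistant.',
        tools=[],  # placeholder: real usage would pass Tool instances here
        category='example'
    )
    conversation.chat('Hello! What can you do?')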