File size: 5,378 Bytes
5e9bd47
 
 
 
 
 
 
 
 
 
 
 
 
 
6b368e8
 
 
 
 
5e9bd47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import base64
import json
from openai import AzureOpenAI
import os
import sys
sys.path.append('./rxn/')
import torch
import json
from getReaction import get_reaction



class RXNIM:
    """Chat-completion driver that lets the model call ``get_reaction`` on an image.

    The instance holds an ``AzureOpenAI`` client, the tool schema advertised to
    the model (``self.tools``), and the mapping from tool name to the Python
    callable that implements it (``self.TOOL_MAP``).
    """

    def __init__(self, api_version='2024-06-01', azure_endpoint='https://hkust.azure-api.net'):
        """Create the Azure OpenAI client and register the available tools.

        Args:
            api_version: Azure OpenAI API version passed to the client.
            azure_endpoint: Azure endpoint URL passed to the client.

        Raises:
            ValueError: If no API key is found in the environment.
        """
        # Read the API key from the environment. Environment variable names
        # are case-sensitive on POSIX systems, and the original error message
        # named 'KEY' while the lookup used 'key'; accept either spelling so
        # both remain valid.
        self.API_KEY = os.environ.get('key') or os.environ.get('KEY')
        if not self.API_KEY:
            raise ValueError("Environment variable 'key' (or 'KEY') not set.")

        # Set up client
        self.client = AzureOpenAI(
            api_key=self.API_KEY,
            api_version=api_version,
            azure_endpoint=azure_endpoint,
        )

        # Tool schema advertised to the model.
        self.tools = [
            {
                'type': 'function',
                'function': {
                    'name': 'get_reaction',
                    'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',
                    'parameters': {
                        'type': 'object',
                        'properties': {
                            'image_path': {
                                'type': 'string',
                                'description': 'The path to the reaction image.',
                            },
                        },
                        'required': ['image_path'],
                        'additionalProperties': False,
                    },
                },
            },
        ]

        # Maps the tool name requested by the model to its implementation.
        self.TOOL_MAP = {
            'get_reaction': get_reaction,
        }

    def encode_image(self, image_path: str) -> str:
        """Return the contents of the file at ``image_path`` as a base64 string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    def _run_tool(self, tool_name, image_path: str):
        """Invoke the tool registered under ``tool_name`` on ``image_path``.

        Returns the tool's result, or an ``{'error': ...}`` dict when the tool
        is unknown or raises.
        """
        tool = self.TOOL_MAP.get(tool_name)
        if tool is None:
            return {'error': f"Unknown tool called: {tool_name}"}

        try:
            tool_result = tool(image_path)
            print(f'{tool_name} result: {tool_result}')
        except Exception as e:
            # Best-effort: report the failure to the model instead of aborting.
            return {'error': str(e)}
        return tool_result

    def process(self, image_path: str, prompt_path: str, max_iterations: int = 5):
        """Send the prompt and image to the model and resolve its tool calls.

        Each round sends the conversation to the model. If the reply requests
        tool calls, the tools are run and their results appended before the
        next round; otherwise the reply's content is returned.

        Args:
            image_path: Path of the image file to send and to hand to tools.
            prompt_path: Path of a UTF-8 text file holding the user prompt.
            max_iterations: Maximum number of model round-trips (default 5).

        Returns:
            The ``content`` of the first assistant message that requests no
            tool calls, or a fixed explanatory string when ``max_iterations``
            rounds pass without such a message.
        """
        # Encode image
        base64_image = self.encode_image(image_path)

        # Read prompt; be explicit about the encoding so the result does not
        # depend on the platform's locale default.
        with open(prompt_path, 'r', encoding='utf-8') as prompt_file:
            prompt = prompt_file.read()

        # Build initial messages
        messages = [
            {'role': 'system', 'content': 'You are a helpful assistant. Before providing the final answer, consider if any additional information or tool usage is needed to improve your response.'},
            {
                'role': 'user',
                'content': [
                    {
                        'type': 'text',
                        'text': prompt
                    },
                    {
                        'type': 'image_url',
                        'image_url': {
                            'url': f'data:image/png;base64,{base64_image}'
                        }
                    }
                ]
            },
        ]

        for iteration in range(1, max_iterations + 1):
            print(f'Iteration {iteration}')

            # Call the model
            response = self.client.chat.completions.create(
                model='gpt-4o',
                temperature=0,
                response_format={'type': 'json_object'},
                messages=messages,
                tools=self.tools,
            )

            # Keep the assistant turn in the history so that any tool results
            # appended below refer to a message the model has already produced.
            assistant_message = response.choices[0].message
            messages.append(assistant_message)

            tool_calls = getattr(assistant_message, 'tool_calls', None)
            if not tool_calls:
                # No more tool calls, assume task is completed.
                return assistant_message.content

            results = []
            for tool_call in tool_calls:
                tool_name = tool_call.function.name

                # The model-supplied arguments are deliberately not parsed:
                # every tool is invoked with the caller's ``image_path``, so a
                # malformed argument string must not abort the run.
                tool_result = self._run_tool(tool_name, image_path)

                try:
                    content = json.dumps({
                        'image_path': image_path,
                        f'{tool_name}': tool_result,
                    })
                except (TypeError, ValueError) as e:
                    # The tool returned something json cannot encode; tell the
                    # model rather than crashing mid-conversation.
                    content = json.dumps({
                        'image_path': image_path,
                        f'{tool_name}': {'error': f'Tool result is not JSON serializable: {e}'},
                    })

                results.append({
                    'role': 'tool',
                    'content': content,
                    'tool_call_id': tool_call.id,
                })
                print(results)

            # Add tool results to messages
            messages.extend(results)

        # Exceeded maximum iterations
        return "The assistant could not complete the task within the maximum number of iterations."