import base64
import json
from openai import AzureOpenAI
import os
import sys
sys.path.append('./rxn/')
import torch
import json
from getReaction import get_reaction
class RXNIM:
    """Drives an Azure OpenAI chat model through a reaction-extraction tool loop.

    Sends a reaction image plus a text prompt to the model and lets it invoke
    the local `get_reaction` tool (OpenAI function calling) until it produces
    a final JSON answer or the iteration budget is exhausted.
    """

    def __init__(self, api_version='2024-06-01', azure_endpoint='https://hkust.azure-api.net'):
        """Create the Azure OpenAI client and register the available tools.

        Args:
            api_version: Azure OpenAI REST API version string.
            azure_endpoint: Base URL of the Azure OpenAI deployment.

        Raises:
            ValueError: if neither the 'key' nor 'KEY' environment variable is set.
        """
        # Read the API key from the environment. The original code looked up
        # 'key' but the error message named 'KEY'; accept both spellings so the
        # message and the lookup can no longer disagree on case-sensitive OSes.
        self.API_KEY = os.environ.get('key') or os.environ.get('KEY')
        if not self.API_KEY:
            raise ValueError("Environment variable 'key' (or 'KEY') not set.")
        # Set up client
        self.client = AzureOpenAI(
            api_key=self.API_KEY,
            api_version=api_version,
            azure_endpoint=azure_endpoint,
        )
        # Tool schema advertised to the model (OpenAI function-calling format).
        self.tools = [
            {
                'type': 'function',
                'function': {
                    'name': 'get_reaction',
                    'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',
                    'parameters': {
                        'type': 'object',
                        'properties': {
                            'image_path': {
                                'type': 'string',
                                'description': 'The path to the reaction image.',
                            },
                        },
                        'required': ['image_path'],
                        'additionalProperties': False,
                    },
                },
            },
        ]
        # Map each advertised tool name to the local callable implementing it.
        self.TOOL_MAP = {
            'get_reaction': get_reaction,
        }

    def encode_image(self, image_path: str) -> str:
        '''Return the contents of the image at `image_path` as a base64 string.'''
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    def process(self, image_path: str, prompt_path: str):
        """Run the model/tool-call loop for one image and prompt.

        Args:
            image_path: Path to the reaction image to analyze.
            prompt_path: Path to a text file containing the user prompt.

        Returns:
            The model's final message content (expected to be a JSON string),
            or an error string if the model did not finish within the
            iteration budget.
        """
        # Encode image for inline transmission as a data URL.
        base64_image = self.encode_image(image_path)
        # Read prompt
        with open(prompt_path, 'r', encoding='utf-8') as prompt_file:
            prompt = prompt_file.read()
        # Build initial messages
        messages = [
            {'role': 'system', 'content': 'You are a helpful assistant. Before providing the final answer, consider if any additional information or tool usage is needed to improve your response.'},
            {
                'role': 'user',
                'content': [
                    {
                        'type': 'text',
                        'text': prompt
                    },
                    {
                        'type': 'image_url',
                        'image_url': {
                            'url': f'data:image/png;base64,{base64_image}'
                        }
                    }
                ]
            },
        ]
        MAX_ITERATIONS = 5
        iterations = 0
        while iterations < MAX_ITERATIONS:
            iterations += 1
            print(f'Iteration {iterations}')
            # Call the model
            response = self.client.chat.completions.create(
                model='gpt-4o',
                temperature=0,
                response_format={'type': 'json_object'},
                messages=messages,
                tools=self.tools,
            )
            # Get assistant's message
            assistant_message = response.choices[0].message
            # The assistant message must precede its tool results in history.
            messages.append(assistant_message)
            # Check for tool calls
            if hasattr(assistant_message, 'tool_calls') and assistant_message.tool_calls:
                results = []
                for tool_call in assistant_message.tool_calls:
                    tool_name = tool_call.function.name
                    tool_call_id = tool_call.id
                    if tool_name in self.TOOL_MAP:
                        try:
                            # Deliberately pass the known-good local image_path
                            # rather than the model-generated arguments, which
                            # may reference a path that does not exist.
                            tool_result = self.TOOL_MAP[tool_name](image_path)
                            print(f'{tool_name} result: {tool_result}')
                        except Exception as e:
                            # Surface tool failures to the model instead of
                            # aborting the whole loop.
                            tool_result = {'error': str(e)}
                    else:
                        tool_result = {'error': f"Unknown tool called: {tool_name}"}
                    # Append tool result to messages, keyed by tool_call_id so
                    # the model can match results to its calls.
                    results.append({
                        'role': 'tool',
                        'content': json.dumps({
                            'image_path': image_path,
                            f'{tool_name}': tool_result,
                        }),
                        'tool_call_id': tool_call_id,
                    })
                print(results)
                # Add tool results to messages
                messages.extend(results)
            else:
                # No more tool calls, assume task is completed
                break
        else:
            # Exceeded maximum iterations
            return "The assistant could not complete the task within the maximum number of iterations."
        # Return the final assistant message
        return assistant_message.content