weizechen committed
Commit 1c62b4b · 1 Parent(s): e06526c

fix openai api base bug

Files changed (2)
  1. agentverse/llms/openai.py +24 -39
  2. app.py +0 -1
agentverse/llms/openai.py CHANGED
@@ -5,7 +5,7 @@ import os
 import numpy as np
 from aiohttp import ClientSession
 from typing import Dict, List, Optional, Union
-from tenacity import retry, stop_after_attempt, wait_exponential
+from tenacity import retry, stop_after_attempt, wait_exponential, RetryCallState
 
 from pydantic import BaseModel, Field
 
@@ -37,15 +37,20 @@ else:
     openai.api_version = "2023-05-15"
     is_openai_available = True
 else:
-    logging.warning(
-        "OpenAI API key is not set. Please set the environment variable OPENAI_API_KEY"
-    )
+    logging.warning("OpenAI API key is not set. Please set the environment variable OPENAI_API_KEY")
     is_openai_available = False
 
 
+def log_retry(retry_state: RetryCallState):
+    exception = retry_state.outcome.exception()
+    logger.warn(
+        f"Retrying {retry_state.fn}\nAttempt: {retry_state.attempt_number}\nException: {exception.__class__.__name__} {exception}",
+    )
+
+
 class OpenAIChatArgs(BaseModelArgs):
     model: str = Field(default="gpt-3.5-turbo")
-    deployment_id: str = Field(default="")
+    deployment_id: Optional[str] = Field(default=None)
     max_tokens: int = Field(default=2048)
     temperature: float = Field(default=1.0)
     top_p: int = Field(default=1)
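
The log_retry hook added above leans on tenacity's RetryCallState, which exposes the wrapped function, the attempt number, and the outcome of the last call. A minimal standalone sketch of how such a before_sleep hook fires, using only the documented tenacity API and a hypothetical always-failing flaky() function (not from this repo):

from tenacity import retry, stop_after_attempt, wait_fixed, RetryCallState

def log_retry(retry_state: RetryCallState):
    # outcome holds the result of the last attempt; exception() is None if it succeeded
    exception = retry_state.outcome.exception()
    print(f"Retrying {retry_state.fn}, attempt {retry_state.attempt_number}: {exception!r}")

@retry(stop=stop_after_attempt(3), wait=wait_fixed(0), before_sleep=log_retry, reraise=True)
def flaky():
    raise RuntimeError("transient error")  # hypothetical failure to trigger the hook

try:
    flaky()
except RuntimeError:
    pass  # log_retry ran before each sleep between the three attempts
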
@@ -117,6 +122,7 @@ class OpenAIChat(BaseChatModel):
         stop=stop_after_attempt(20),
         wait=wait_exponential(multiplier=1, min=4, max=10),
         reraise=True,
+        before_sleep=log_retry
     )
     def generate_response(
         self,
@@ -139,13 +145,9 @@ class OpenAIChat(BaseChatModel):
                 self.collect_metrics(response)
                 return LLMResult(
                     content=response["choices"][0]["message"].get("content", ""),
-                    function_name=response["choices"][0]["message"][
-                        "function_call"
-                    ]["name"],
+                    function_name=response["choices"][0]["message"]["function_call"]["name"],
                     function_arguments=ast.literal_eval(
-                        response["choices"][0]["message"]["function_call"][
-                            "arguments"
-                        ]
+                        response["choices"][0]["message"]["function_call"]["arguments"]
                     ),
                     send_tokens=response["usage"]["prompt_tokens"],
                     recv_tokens=response["usage"]["completion_tokens"],
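
For context on the fields generate_response reads here: a function-calling reply from the pre-1.0 openai SDK is a nested dict, and "arguments" arrives as a string that still has to be parsed, which is what ast.literal_eval does in the hunk above. A small sketch against a hand-written mock response (not a real API reply):

import ast

mock_response = {
    "choices": [
        {
            "message": {
                "content": "",
                "function_call": {
                    "name": "search",
                    "arguments": '{"query": "weather in Paris", "top_k": 3}',
                },
            }
        }
    ],
    "usage": {"prompt_tokens": 42, "completion_tokens": 17},
}

message = mock_response["choices"][0]["message"]
function_name = message["function_call"]["name"]
# the argument string parses into a Python dict
function_arguments = ast.literal_eval(message["function_call"]["arguments"])
print(function_name, function_arguments)  # search {'query': 'weather in Paris', 'top_k': 3}
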
@@ -179,6 +181,7 @@ class OpenAIChat(BaseChatModel):
         stop=stop_after_attempt(20),
         wait=wait_exponential(multiplier=1, min=4, max=10),
         reraise=True,
+        before_sleep=log_retry,
     )
     async def agenerate_response(
         self,
@@ -200,9 +203,7 @@ class OpenAIChat(BaseChatModel):
                         **self.args.dict(),
                     )
                 if response["choices"][0]["message"].get("function_call") is not None:
-                    function_name = response["choices"][0]["message"]["function_call"][
-                        "name"
-                    ]
+                    function_name = response["choices"][0]["message"]["function_call"]["name"]
                     valid_function = False
                     if function_name.startswith("function."):
                         function_name = function_name.replace("function.", "")
@@ -220,27 +221,15 @@ class OpenAIChat(BaseChatModel):
                             f"The returned function name {function_name} is not in the list of valid functions."
                         )
                     try:
-                        arguments = ast.literal_eval(
-                            response["choices"][0]["message"]["function_call"][
-                                "arguments"
-                            ]
-                        )
+                        arguments = ast.literal_eval(response["choices"][0]["message"]["function_call"]["arguments"])
                     except:
                         try:
                             arguments = ast.literal_eval(
-                                JsonRepair(
-                                    response["choices"][0]["message"]["function_call"][
-                                        "arguments"
-                                    ]
-                                ).repair()
+                                JsonRepair(response["choices"][0]["message"]["function_call"]["arguments"]).repair()
                             )
                         except:
-                            logger.warn(
-                                "The returned argument in function call is not valid json. Retrying..."
-                            )
-                            raise ValueError(
-                                "The returned argument in function call is not valid json."
-                            )
+                            logger.warn("The returned argument in function call is not valid json. Retrying...")
+                            raise ValueError("The returned argument in function call is not valid json.")
                     self.collect_metrics(response)
                     return LLMResult(
                         function_name=function_name,
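
The hunk above collapses the fallback chain onto single lines: parse the argument string directly, fall back to the project's JsonRepair helper (whose .repair() method is what the diff calls), and only then raise. A condensed sketch of that control flow as a hypothetical helper; the JsonRepair import path is assumed, not shown in this diff:

import ast
from agentverse.llms.utils.jsonrepair import JsonRepair  # assumed import path

def parse_function_arguments(args_str: str) -> dict:
    # literal_eval first, then one repair pass, then fail loudly - same order as above
    try:
        return ast.literal_eval(args_str)
    except Exception:
        try:
            return ast.literal_eval(JsonRepair(args_str).repair())
        except Exception:
            raise ValueError("The returned argument in function call is not valid json.")

# A truncated string such as '{"query": "weather in Paris"' would miss the first parse
# and reach the repair branch; the ValueError only fires if repair also fails.
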
@@ -276,9 +265,7 @@ class OpenAIChat(BaseChatModel):
         except (OpenAIError, KeyboardInterrupt, json.decoder.JSONDecodeError) as error:
             raise
 
-    def construct_messages(
-        self, prepend_prompt: str, history: List[dict], append_prompt: str
-    ):
+    def construct_messages(self, prepend_prompt: str, history: List[dict], append_prompt: str):
         messages = []
         if prepend_prompt != "":
             messages.append({"role": "system", "content": prepend_prompt})
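
construct_messages, now with a one-line signature, assembles the usual chat-format message list: a system message from prepend_prompt, the prior turns from history, and a final user message from append_prompt. Only the system-message branch is visible in this hunk, so the rest of the sketch below is an assumption about the shape it produces, with made-up strings:

prepend_prompt = "You are a helpful agent."
history = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
append_prompt = "Summarize the conversation."

messages = []
if prepend_prompt != "":
    messages.append({"role": "system", "content": prepend_prompt})
messages += history  # assumed: history entries are already {"role": ..., "content": ...} dicts
if append_prompt != "":
    messages.append({"role": "user", "content": append_prompt})
# messages: system prompt, the two history turns, then the closing user prompt
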
@@ -332,13 +319,11 @@ def get_embedding(text: str, attempts=3) -> np.array:
     try:
         text = text.replace("\n", " ")
         if openai.api_type == "azure":
-            embedding = openai.Embedding.create(
-                input=[text], deployment_id="text-embedding-ada-002"
-            )["data"][0]["embedding"]
+            embedding = openai.Embedding.create(input=[text], deployment_id="text-embedding-ada-002")["data"][0][
+                "embedding"
+            ]
         else:
-            embedding = openai.Embedding.create(
-                input=[text], model="text-embedding-ada-002"
-            )["data"][0]["embedding"]
+            embedding = openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"]
         return tuple(embedding)
     except Exception as e:
         attempts += 1
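
get_embedding now issues the request on one line per branch, addressing Azure by deployment_id and the public API by model, and in both cases reads the vector out of ["data"][0]["embedding"]. A trimmed sketch of that branch with the retry bookkeeping left out; embed() is a hypothetical wrapper, and the calls use the pre-1.0 openai SDK shown in the diff:

import numpy as np
import openai

def embed(text: str) -> np.ndarray:
    text = text.replace("\n", " ")
    if openai.api_type == "azure":
        # Azure endpoints are addressed by deployment_id rather than model name
        resp = openai.Embedding.create(input=[text], deployment_id="text-embedding-ada-002")
    else:
        resp = openai.Embedding.create(input=[text], model="text-embedding-ada-002")
    return np.array(resp["data"][0]["embedding"])
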
 
app.py CHANGED
@@ -335,7 +335,6 @@ class GUI:
         """
 
         # data = self.backend.next_data()
-
         return_message = self.backend.next()
 
         data = self.return_format(return_message)
 