LLMBB-Agent/agent/llm/qwen_oai.py
"""Qwen chat model exposed through an OpenAI-compatible API (pre-1.0 openai SDK)."""
import os
from typing import Dict, Iterator, List, Optional

import openai

from agent.llm.base import BaseChatModel

class QwenChatAsOAI(BaseChatModel):

    def __init__(self, model: str):
        super().__init__()
        # The pre-1.0 openai SDK is configured through module-level globals.
        openai.api_base = os.getenv('OPENAI_API_BASE')
        openai.api_key = os.getenv('OPENAI_API_KEY', 'EMPTY')
        # OPENAI_MODEL_NAME, when set, overrides the model name passed by the caller.
        self.model = os.getenv('OPENAI_MODEL_NAME', model)

    def _chat_stream(
        self,
        messages: List[Dict],
        stop: Optional[List[str]] = None,
    ) -> Iterator[str]:
        response = openai.ChatCompletion.create(model=self.model,
                                                messages=messages,
                                                stop=stop,
                                                stream=True)
        # TODO: error handling
        for chunk in response:
            # Some chunks (e.g. the first role-only delta and the final stop
            # chunk) carry no `content`, so guard before yielding.
            if hasattr(chunk.choices[0].delta, 'content'):
                yield chunk.choices[0].delta.content

    def _chat_no_stream(
        self,
        messages: List[Dict],
        stop: Optional[List[str]] = None,
    ) -> str:
        response = openai.ChatCompletion.create(model=self.model,
                                                messages=messages,
                                                stop=stop,
                                                stream=False)
        # TODO: error handling
        return response.choices[0].message.content

    def chat_with_functions(self,
                            messages: List[Dict],
                            functions: Optional[List[Dict]] = None) -> Dict:
        if functions:
            response = openai.ChatCompletion.create(model=self.model,
                                                    messages=messages,
                                                    functions=functions)
        else:
            response = openai.ChatCompletion.create(model=self.model,
                                                    messages=messages)
        # TODO: error handling
        # Return the whole message object so callers can inspect `function_call`.
        return response.choices[0].message
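

# --- Usage sketch ---------------------------------------------------------
# An illustrative addition, not part of the original module: it exercises the
# three entry points above against an OpenAI-compatible Qwen endpoint. The
# OPENAI_API_BASE URL, the 'Qwen' model name, and the example function schema
# are placeholder assumptions, not values taken from this repository.
if __name__ == '__main__':
    os.environ.setdefault('OPENAI_API_BASE', 'http://localhost:8000/v1')
    llm = QwenChatAsOAI(model='Qwen')
    messages = [{'role': 'user', 'content': 'What is the weather in Beijing?'}]

    # Non-streaming completion: returns the reply text in one piece.
    print(llm._chat_no_stream(messages))

    # Streaming completion: print each content delta as it arrives.
    for delta in llm._chat_stream(messages):
        print(delta, end='', flush=True)
    print()

    # Function calling: the returned message may contain a `function_call`
    # payload instead of plain `content`.
    weather_fn = {
        'name': 'get_current_weather',
        'description': 'Get the current weather for a city.',
        'parameters': {
            'type': 'object',
            'properties': {
                'city': {'type': 'string'},
            },
            'required': ['city'],
        },
    }
    print(llm.chat_with_functions(messages, functions=[weather_fn]))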