"""OpenAI chat wrapper.""" from __future__ import annotations from typing import ( Any, AsyncIterator, Iterator, List, Optional, Union, ) from langchain_community.chat_models import ChatOpenAI, AzureChatOpenAI from langchain_community.chat_models.openai import acompletion_with_retry, _convert_delta_to_message_chunk from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain_core.language_models.chat_models import ( agenerate_from_stream, generate_from_stream, ) from langchain_core.messages import ( AIMessageChunk, BaseMessage, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.pydantic_v1 import BaseModel from langchain_community.adapters.openai import ( convert_dict_to_message, ) class H2OBaseChatOpenAI: def _stream( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs, "stream": True} default_chunk_class = AIMessageChunk for chunk in self.completion_with_retry( messages=message_dicts, run_manager=run_manager, **params ): if not isinstance(chunk, dict): chunk = chunk.dict() if len(chunk["choices"]) == 0: continue choice = chunk["choices"][0] chunk = _convert_delta_to_message_chunk( choice["delta"], default_chunk_class ) finish_reason = choice.get("finish_reason") generation_info = ( dict(finish_reason=finish_reason) if finish_reason is not None else None ) default_chunk_class = chunk.__class__ cg_chunk = ChatGenerationChunk( message=chunk, generation_info=generation_info ) cg_chunk = self.mod_cg_chunk(cg_chunk) if run_manager: run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk) yield cg_chunk def mod_cg_chunk(self, cg_chunk: ChatGenerationChunk) -> ChatGenerationChunk: if 'tools' in self.model_kwargs and self.model_kwargs['tools']: if 'tool_calls' in cg_chunk.message.additional_kwargs: cg_chunk.message.content = cg_chunk.text = cg_chunk.message.additional_kwargs['tool_calls'][0]['function']['arguments'] else: cg_chunk.text = '' return cg_chunk def _generate( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, stream: Optional[bool] = None, **kwargs: Any, ) -> ChatResult: should_stream = stream if stream is not None else self.streaming if should_stream: stream_iter = self._stream( messages, stop=stop, run_manager=run_manager, **kwargs ) return generate_from_stream(stream_iter) message_dicts, params = self._create_message_dicts(messages, stop) params = { **params, **({"stream": stream} if stream is not None else {}), **kwargs, } response = self.completion_with_retry( messages=message_dicts, run_manager=run_manager, **params ) return self._create_chat_result(response) def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult: generations = [] if not isinstance(response, dict): response = response.dict() for res in response["choices"]: message = convert_dict_to_message(res["message"]) if 'tools' in self.model_kwargs and self.model_kwargs['tools']: if 'tool_calls' in message.additional_kwargs: message.content = ''.join([x['function']['arguments'] for x in message.additional_kwargs['tool_calls']]) generation_info = dict(finish_reason=res.get("finish_reason")) if "logprobs" in res: generation_info["logprobs"] = res["logprobs"] gen = ChatGeneration( message=message, generation_info=generation_info, ) generations.append(gen) token_usage = response.get("usage", {}) llm_output = { "token_usage": token_usage, "model_name": self.model_name, "system_fingerprint": response.get("system_fingerprint", ""), } return ChatResult(generations=generations, llm_output=llm_output) async def _astream( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs, "stream": True} default_chunk_class = AIMessageChunk async for chunk in await acompletion_with_retry( self, messages=message_dicts, run_manager=run_manager, **params ): if not isinstance(chunk, dict): chunk = chunk.dict() if len(chunk["choices"]) == 0: continue choice = chunk["choices"][0] chunk = _convert_delta_to_message_chunk( choice["delta"], default_chunk_class ) finish_reason = choice.get("finish_reason") generation_info = ( dict(finish_reason=finish_reason) if finish_reason is not None else None ) default_chunk_class = chunk.__class__ cg_chunk = ChatGenerationChunk( message=chunk, generation_info=generation_info ) cg_chunk = self.mod_cg_chunk(cg_chunk) if run_manager: await run_manager.on_llm_new_token(token=cg_chunk.text, chunk=cg_chunk) yield cg_chunk async def _agenerate( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, stream: Optional[bool] = None, **kwargs: Any, ) -> ChatResult: should_stream = stream if stream is not None else self.streaming if should_stream: stream_iter = self._astream( messages, stop=stop, run_manager=run_manager, **kwargs ) return await agenerate_from_stream(stream_iter) message_dicts, params = self._create_message_dicts(messages, stop) params = { **params, **({"stream": stream} if stream is not None else {}), **kwargs, } response = await acompletion_with_retry( self, messages=message_dicts, run_manager=run_manager, **params ) return self._create_chat_result(response) class H2OBaseAzureChatOpenAI(H2OBaseChatOpenAI, AzureChatOpenAI): pass