import typing as t

from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGeneration, LLMResult

from src.config import logger


class LCMessageLoggerAsync(AsyncCallbackHandler):
    """Custom callback to make LangChain logs easy to read."""

    @staticmethod
    def langchain_msg_2_role_content(msg: BaseMessage) -> dict[str, t.Any]:
        """Convert a LangChain message to a plain {"role": ..., "content": ...} dict."""
        return {"role": msg.type, "content": msg.content}

    def __init__(self, log_raw_llm_response: bool = True):
        super().__init__()
        self._log_raw_llm_response = log_raw_llm_response

    def on_chat_model_start(
        self,
        serialized: dict[str, t.Any],
        messages: list[list[BaseMessage]],
        **kwargs: t.Any,
    ) -> t.Any:
        """Run when Chat Model starts running."""
        if len(messages) != 1:
            raise ValueError(f'expected "messages" to have len 1, got: {len(messages)}')

        # The model name lives under different keys depending on the provider:
        # "model_name" for plain chat models, "deployment_name" for Azure-style
        # deployments.
        serialized_kwargs = serialized["kwargs"]
        model_name = serialized_kwargs.get("model_name")
        if not model_name:
            model_name = serialized_kwargs.get("deployment_name")
        if not model_name:
            model_name = "<failed to determine LLM>"

        msgs_list = list(map(self.langchain_msg_2_role_content, messages[0]))
        msgs_str = "\n".join(map(str, msgs_list))
        logger.info(f"call to {model_name} with {len(msgs_list)} messages:\n{msgs_str}")

    def on_llm_end(self, response: LLMResult, **kwargs: t.Any) -> t.Any:
        """Run when LLM ends running."""
        # This handler assumes a single prompt with a single completion;
        # anything else is treated as an error.
        generations = response.generations
        if len(generations) != 1:
            raise ValueError(
                f'expected "generations" to have len 1, got: {len(generations)}'
            )
        if len(generations[0]) != 1:
            raise ValueError(
                f'expected "generations[0]" to have len 1, got: {len(generations[0])}'
            )

        if self._log_raw_llm_response:
            gen: ChatGeneration = generations[0][0]
            ai_msg = gen.message
            logger.info(f'raw LLM response: "{ai_msg.content}"')
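

# Example usage, a minimal sketch: pass the handler via `callbacks` when
# constructing a chat model so every request and raw response is logged.
# Assumes the langchain-openai package is installed; the model name below
# is a placeholder.
#
#   from langchain_openai import ChatOpenAI
#
#   llm = ChatOpenAI(model="gpt-4o-mini", callbacks=[LCMessageLoggerAsync()])
#   llm.invoke("What is a callback handler?")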