# Spaces: Sleeping / Sleeping — Hugging Face Spaces status banner captured
# alongside the source; kept as a comment so the file remains valid Python.
# Standard library
import json
import logging
import os
import pprint
from collections import defaultdict
from operator import itemgetter
from typing import Dict, List

# Third-party
from langchain.agents import (AgentExecutor, AgentType, OpenAIFunctionsAgent,
                              tool)
from langchain.agents.format_scratchpad import format_to_openai_functions
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.callbacks import FileCallbackHandler, get_openai_callback
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import (AIMessage, HumanMessage, StrOutputParser,
                              SystemMessage)
from langchain.schema.agent import AgentActionMessageLog, AgentFinish
from langchain.schema.runnable import (RunnableConfig, RunnableLambda,
                                       RunnablePassthrough)
from langchain.tools.render import format_tool_to_openai_function
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
from langfuse.callback import CallbackHandler
from pydantic import BaseModel, Field

# Local
from cyanite import free_text_search
# os.getenv returns a string (or None) — the original `== True` comparison
# could never hold, so Langfuse tracing was silently always disabled.
# Accept the common truthy spellings of the env var instead.
if (os.getenv("USE_LANGFUSE") or "").strip().lower() in ("1", "true", "yes"):
    handler = CallbackHandler(
        os.getenv("LANGFUSE_PUBLIC"),
        os.getenv("LANGFUSE_PRIVATE"),
        "https://cloud.langfuse.com",
    )
else:
    # Falsy sentinel: downstream code does `[handler] if handler else []`.
    handler = []
# System prompt for the recommendation agent: persona plus ground rules
# (describe the style via the tool, never list songs or artists).
system_message = """You are an agent which recommends songs based on music styles provided by the user.
- A music style could be a combination of instruments, genres or sounds.
- Use get_music_style_description to generate a description of the user's music style.
- The styles might contain pop-culture references (artists, movies, TV-Shows, etc) You should include them when generating descriptions.
- Comment on the description of the style and wish the user to enjoy the recommended songs (he will have received them).
- Do not mention any songs or artists, nor give a list of songs.
Write short responses with a respectful and friendly tone.
"""

# Prompt for the describe_music_style helper chain: expand a style into a
# short comma-separated description, translating pop-culture references.
describe_music_style_message = """You receive a music style and your goal is to describe it further with genres, instruments and sounds.
If it contains pop-culture references (like TV-Shows, films, artists, famous people, etc) you should replace them with music styles that resemble them.
You should return the new music style as a set of words separated by commas.
You always give short answers, with at most 20 words.
"""
# Key under which the per-user chat history is exposed to the prompt.
MEMORY_KEY = "history"

# Main agent prompt: system rules, then the tool scratchpad, then the chat
# history, then the current user turn.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_message),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
        MessagesPlaceholder(variable_name=MEMORY_KEY),
        ("human", "{input}"),
    ]
)
# One sliding-window memory (last 4 exchanges) per user, created lazily on
# first access by the defaultdict factory.
conversation_memories = defaultdict(
    lambda: ConversationBufferWindowMemory(
        memory_key=MEMORY_KEY, return_messages=True, output_key="output", k=4
    )
)

# Module-level registries shared across requests: tracks found per generated
# style description, and accumulated OpenAI cost per user conversation.
music_styles_to_tracks = {}
conversation_costs = defaultdict(float)
def get_music_style_description(music_style: str) -> str:
    """Describe a music style and return the description.

    Side effect: the five tracks found for the generated description are
    cached in the module-level ``music_styles_to_tracks`` dict, keyed by the
    description, so they can be recovered later from the agent's
    intermediate steps.
    """
    description = describe_music_style(music_style)
    matching_tracks = free_text_search(description, 5)
    logging.warning(f"""
music_style = {music_style}
music_style_description = {description}
tracks = {pprint.pformat(matching_tracks)}""")
    # Cache the tracks globally; only the description goes back to the LLM.
    music_styles_to_tracks[description] = matching_tracks
    return description
def describe_music_style(music_style: str) -> str:
    """Expand a raw music style into a short, comma-separated description.

    Runs a small zero-temperature chat chain that rewrites pop-culture
    references into comparable genres, instruments and sounds.
    """
    chain = (
        ChatPromptTemplate.from_messages(
            [
                ("system", describe_music_style_message),
                ("human", "{music_style}"),
            ]
        )
        | ChatOpenAI(temperature=0.0)
        | StrOutputParser()
    )
    return chain.invoke({"music_style": music_style})
# Chat model that powers the agent, with the tool's OpenAI-function schema
# bound so the model can emit function calls for it.
llm = ChatOpenAI(temperature=0.7, request_timeout=30, max_retries=1)
llm_with_tools = llm.bind(
    functions=[format_tool_to_openai_function(get_music_style_description)]
)
def get_agent_executor_from_user_id(user_id) -> AgentExecutor:
    """Build an AgentExecutor wired to the given user's conversation memory.

    Args:
        user_id: Key identifying the conversation; a fresh windowed memory is
            created on first use (see ``conversation_memories``).

    Returns:
        An AgentExecutor running the OpenAI-functions agent with the
        ``get_music_style_description`` tool. Intermediate steps are
        returned so the caller can recover the recommended tracks.
    """
    memory = conversation_memories[user_id]
    # Routine state dump — previously emitted twice at WARNING and ERROR,
    # which polluted the logs; DEBUG is the appropriate severity.
    logging.debug("memory for user %s: %s", user_id, memory)
    agent = (
        {
            "input": lambda x: x["input"],
            # Convert (action, observation) pairs into OpenAI function
            # messages for the scratchpad placeholder.
            "agent_scratchpad": lambda x: format_to_openai_functions(
                x["intermediate_steps"]
            ),
        }
        | RunnablePassthrough.assign(
            history=RunnableLambda(memory.load_memory_variables)
            | itemgetter(MEMORY_KEY)
        )
        | prompt
        | llm_with_tools
        | OpenAIFunctionsAgentOutputParser()
    )
    return AgentExecutor(
        agent=agent,
        tools=[get_music_style_description],
        memory=memory,
        # `handler` is either a Langfuse CallbackHandler or an empty list.
        callbacks=[handler] if handler else [],
        return_intermediate_steps=True,
        max_execution_time=30,
        handle_parsing_errors=True,
        verbose=True,
    )
def get_tracks_from_intermediate_steps(intermediate_steps: List) -> List:
    """Return the tracks cached for the most recent style-description action.

    Args:
        intermediate_steps: (action, observation) pairs from an AgentExecutor
            run; for a ``get_music_style_description`` action the observation
            is the style description used as key into
            ``music_styles_to_tracks``.

    Returns:
        The cached track list for the latest matching action, or [] when no
        such action occurred (or nothing was cached for its description).
    """
    if not intermediate_steps:
        return []
    # Debug output previously went to stdout via print(); use logging so it
    # can be filtered like the rest of the module.
    logging.debug("intermediate steps:\n%s", pprint.pformat(intermediate_steps))
    # Walk backwards so the most recent tool invocation wins. The loop
    # variable is named `description` (not `prompt`) to avoid shadowing the
    # module-level agent prompt.
    for action_message, description in intermediate_steps[::-1]:
        if action_message.tool == 'get_music_style_description':
            # .get() degrades gracefully to [] if the cache entry is missing,
            # instead of raising KeyError mid-response.
            return music_styles_to_tracks.get(description, [])
    # None of the actions was get_music_style_description.
    return []
def llm_inference(message, history, user_id) -> Dict:
    """Run one conversational turn and return response, tracks and costs.

    Called by the API layer. Note that `history` is not read here: the
    server-side memory in `conversation_memories` is the source of truth.

    Returns:
        Dict with keys "output" (agent reply), "tracks" (recommendations
        from the last tool call, if any) and "cost" (total conversation
        cost so far).
    """
    # Rebuild the executor so it picks up this user's stored memory.
    agent_executor = get_agent_executor_from_user_id(user_id)
    with get_openai_callback() as cb:
        answer = agent_executor({"input": message})
        # Accumulate OpenAI spend for this conversation.
        conversation_costs[user_id] += cb.total_cost
        total_conversation_costs = conversation_costs[user_id]
    # Recover the recommended tracks from the intermediate steps, if any.
    tracks = get_tracks_from_intermediate_steps(answer['intermediate_steps'])
    logging.warning(f"step = ${cb.total_cost} total = ${total_conversation_costs}")
    logging.warning(music_styles_to_tracks)
    return {
        "output": answer['output'],
        "tracks": tracks,
        "cost": total_conversation_costs,
    }