File size: 7,132 Bytes
8e786b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
from operator import itemgetter
import pprint
from typing import Dict, List

from langchain.agents import (AgentExecutor, AgentType, OpenAIFunctionsAgent,
                              tool)
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import StrOutputParser, SystemMessage, HumanMessage, AIMessage
from langchain.callbacks import get_openai_callback, FileCallbackHandler
from langchain.schema.agent import AgentActionMessageLog, AgentFinish
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
from langchain.agents.format_scratchpad import format_to_openai_functions
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain.tools.render import format_tool_to_openai_function
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.schema.runnable import RunnableConfig



import logging, os, json
from collections import defaultdict

from pydantic import BaseModel, Field

from cyanite import free_text_search

from langfuse.callback import CallbackHandler

# Optional Langfuse tracing, toggled by the USE_LANGFUSE environment variable.
# BUG FIX: os.getenv returns a string (or None), so the original comparison
# `== True` could never succeed and Langfuse was silently always disabled.
if os.getenv("USE_LANGFUSE", "").strip().lower() in ("1", "true", "yes"):
    handler = CallbackHandler(
        os.getenv("LANGFUSE_PUBLIC"),
        os.getenv("LANGFUSE_PRIVATE"),
        "https://cloud.langfuse.com",
    )
else:
    # Falsy placeholder: `callbacks=[handler] if handler else []` downstream
    # then resolves to "no callbacks".
    handler = []



# System prompt for the recommendation agent: defines its persona and the
# rules it must follow when replying to the user (string bytes unchanged).
system_message = """You are an agent which recommends songs based on music styles provided by the user.
- A music style could be a combination of instruments, genres or sounds.
- Use get_music_style_description to generate a description of the user's music style.
- The styles might contain pop-culture references (artists, movies, TV-Shows, etc) You should include them when generating descriptions.
- Comment on the description of the style and wish the user to enjoy the recommended songs (he will have received them).
- Do not mention any songs or artists, nor give a list of songs.
Write short responses with a respectful and friendly tone.
"""



# System prompt for the style-description helper LLM: expand a style into
# genres/instruments/sounds, de-referencing pop culture (bytes unchanged).
describe_music_style_message = """You receive a music style and your goal is to describe it further with genres, instruments and sounds.
If it contains pop-culture references (like TV-Shows, films, artists, famous people, etc) you should replace them with music styles that resemble them.
You should return the new music style as a set of words separated by commas.
You always give short answers, with at most 20 words.
"""


# Variable name under which the conversation history is injected into the
# prompt and stored in each user's memory buffer.
MEMORY_KEY = "history"

# Prompt layout: system instructions, then the tool scratchpad, then the
# rolling conversation history, and finally the newest user input.
_prompt_messages = [
    ("system", system_message),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
    MessagesPlaceholder(variable_name=MEMORY_KEY),
    ("human", "{input}"),
]
prompt = ChatPromptTemplate.from_messages(_prompt_messages)

# Per-user conversation memory: the defaultdict lazily creates a windowed
# buffer (last k=4 exchanges, returned as message objects) the first time a
# given user_id is seen.
conversation_memories = defaultdict(
    lambda : ConversationBufferWindowMemory(memory_key=MEMORY_KEY, return_messages=True, output_key="output", k = 4)
)

#global dicts to store the tracks and the conversation costs
# music_styles_to_tracks: style description (str) -> tracks returned by the
#   Cyanite search; written by the tool, read back in llm_inference.
# conversation_costs: user_id -> accumulated OpenAI spend (USD, float).
music_styles_to_tracks = {}
conversation_costs = defaultdict(float)


@tool
def get_music_style_description(music_style: str) -> str:
    "A tool which describes a music style and returns a description of it"
    # NOTE: the docstring above doubles as the OpenAI function description
    # (via @tool), so it is kept verbatim.
    music_style_description = describe_music_style(music_style)
    found_tracks = free_text_search(music_style_description, 5)

    logging.warning(f"""
                  music_style = {music_style}
                  music_style_description = {music_style_description}
                  tracks = {pprint.pformat(found_tracks)}""")

    # Cache the tracks keyed by description so llm_inference can recover them
    # from the agent's intermediate steps; only the text goes back to the LLM.
    music_styles_to_tracks[music_style_description] = found_tracks
    return music_style_description

def describe_music_style(music_style: str) -> str:
    """Expand a music style into genres, instruments and sounds via an LLM.

    Pop-culture references in the input are replaced with musically similar
    styles (see describe_music_style_message); the model is instructed to
    answer with a short comma-separated list of at most ~20 words.

    Args:
        music_style: Free-text style provided by the user.

    Returns:
        The expanded description as a plain string.
    """
    # temperature=0.0 keeps the description deterministic, which matters
    # because the description is later used as a key into
    # music_styles_to_tracks.
    llm_describe = ChatOpenAI(temperature=0.0)
    prompt_describe = ChatPromptTemplate.from_messages([
        ("system", describe_music_style_message),
        ("human", "{music_style}"),
    ])
    chain = prompt_describe | llm_describe | StrOutputParser()
    return chain.invoke({"music_style": music_style})


# We instantiate the Chat Model and bind the tool to it.
# temperature=0.7 gives conversational variety; the 30s timeout and single
# retry keep a stuck OpenAI call from hanging an API request indefinitely.
llm = ChatOpenAI(temperature=0.7, request_timeout = 30, max_retries = 1)
# Expose the search tool to the model as an OpenAI function definition.
llm_with_tools = llm.bind(
    functions=[
        format_tool_to_openai_function(get_music_style_description)
    ]
)

def get_agent_executor_from_user_id(user_id) -> AgentExecutor:
    """Build an AgentExecutor wired to the given user's conversation memory.

    Args:
        user_id: Key into conversation_memories; a fresh windowed memory is
            created on first use.

    Returns:
        An AgentExecutor exposing the get_music_style_description tool, with
        intermediate steps returned so callers can recover the tracks.
    """
    memory = conversation_memories[user_id]

    # FIX: the memory object was previously dumped via logging.warning and
    # logging.error — leftover debugging at misleading severities. Use a
    # single lazily-formatted debug record instead.
    logging.debug("conversation memory for user %s: %s", user_id, memory)

    # LCEL pipeline: map the executor inputs, attach the stored history,
    # render the prompt, call the tool-bound model, parse the function call.
    agent = (
        {
            "input": lambda x: x["input"],
            "agent_scratchpad": lambda x: format_to_openai_functions(x["intermediate_steps"]),
        }
        | RunnablePassthrough.assign(
            history=RunnableLambda(memory.load_memory_variables) | itemgetter(MEMORY_KEY)
        )
        | prompt
        | llm_with_tools
        | OpenAIFunctionsAgentOutputParser()
    )

    return AgentExecutor(
        agent=agent,
        tools=[get_music_style_description],
        memory=memory,
        # handler is [] (falsy) when Langfuse tracing is disabled.
        callbacks=[handler] if handler else [],
        return_intermediate_steps=True,
        max_execution_time=30,
        handle_parsing_errors=True,
        verbose=True,
    )



def get_tracks_from_intermediate_steps(intermediate_steps: List) -> List:
    """Return the tracks produced by the most recent tool call, if any.

    Args:
        intermediate_steps: List of (agent_action, observation) pairs as
            produced by AgentExecutor with return_intermediate_steps=True;
            for our tool the observation is the music-style description,
            which is also the key under which the tool cached its tracks.

    Returns:
        The cached track list for the latest get_music_style_description
        call, or [] when no such call (or no cached entry) exists.
    """
    if not intermediate_steps:
        return []

    # FIX: replaced leftover debug print() calls with a lazy debug log.
    logging.debug("intermediate steps: %s", pprint.pformat(intermediate_steps))

    # Walk backwards so the most recent tool invocation wins.
    # FIX: the loop variable was named `prompt`, shadowing the module-level
    # prompt template; renamed to `description`.
    for action, description in reversed(intermediate_steps):
        if action.tool == 'get_music_style_description':
            # .get guards against the (unexpected) case where the tool ran
            # but its result was never cached — previously a KeyError.
            return music_styles_to_tracks.get(description, [])

    # if none of the actions is get_music_style_description, return empty list
    return []


def llm_inference(message, history, user_id) -> Dict:
    """Entry point called by the API layer.

    Runs the agent on `message` for `user_id` and returns a dict with the
    agent's textual output, the recommended tracks (if the tool ran), and
    the user's accumulated conversation cost so far. `history` is accepted
    for API compatibility; memory is kept server-side per user_id.
    """
    # Rebuild the executor around this user's stored conversation memory.
    agent_executor = get_agent_executor_from_user_id(user_id)

    with get_openai_callback() as usage:
        # Run the agent on the new user message.
        response = agent_executor({"input": message})

        # Accumulate this call's OpenAI spend into the per-user running total.
        conversation_costs[user_id] += usage.total_cost
        running_cost = conversation_costs[user_id]

        # Recover the tracks cached by the tool during this run, if any.
        recommended_tracks = get_tracks_from_intermediate_steps(response["intermediate_steps"])

        logging.warning(f"step = ${usage.total_cost} total = ${running_cost}")
        logging.warning(music_styles_to_tracks)

    return {
        "output": response["output"],
        "tracks": recommended_tracks,
        "cost": running_cost,
    }