Spaces:
Sleeping
Sleeping
juanluisrto
commited on
Commit
·
8e786b4
1
Parent(s):
034eb7d
Upload 3 files
Browse files- app.py +24 -104
- cyanite.py +74 -0
- langhcain_agent.py +191 -0
app.py
CHANGED
@@ -1,121 +1,41 @@
|
|
1 |
-
import os, json, random, logging
|
2 |
-
from typing import List
|
3 |
-
from dotenv import load_dotenv
|
4 |
-
|
5 |
-
from langchain.agents import AgentType, initialize_agent
|
6 |
-
from langchain.chat_models import ChatOpenAI
|
7 |
-
from langchain.tools import Tool
|
8 |
-
|
9 |
-
from langchain.schema import SystemMessage
|
10 |
-
from langchain.agents import OpenAIFunctionsAgent
|
11 |
-
from langchain.prompts import MessagesPlaceholder
|
12 |
-
from langchain.agents import AgentExecutor, OpenAIFunctionsAgent
|
13 |
-
from langchain.chains.conversation.memory import ConversationBufferMemory
|
14 |
-
from langchain.chat_models import ChatOpenAI
|
15 |
-
from langchain.agents import tool, AgentExecutor, OpenAIFunctionsAgent, AgentType, Agent
|
16 |
-
from langchain.schema import SystemMessage
|
17 |
-
from langchain.prompts import MessagesPlaceholder
|
18 |
-
from langchain.chains.conversation.memory import ConversationBufferMemory
|
19 |
-
from langchain.chat_models import ChatOpenAI
|
20 |
-
|
21 |
-
from langchain.prompts import ChatPromptTemplate
|
22 |
-
from langchain.schema import StrOutputParser
|
23 |
-
|
24 |
import gradio as gr
|
|
|
|
|
|
|
25 |
|
26 |
-
load_dotenv()
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
llm = ChatOpenAI(temperature=0)
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
from typing import List, Dict
|
35 |
-
|
36 |
-
@tool
|
37 |
-
def describe_popculture_references(references: List) -> Dict:
|
38 |
-
"A tool used to describe pop-culture references as music styles"
|
39 |
-
prompt = ChatPromptTemplate.from_messages([
|
40 |
-
("system", """You receive a list of pop-culture references (like TV-Shows, films, artists, famous people, etc).
|
41 |
-
For each reference, write a few words separated by commas which captures the essence of it. Use music styles, sounds and instruments.
|
42 |
-
Return a dict with the references as keys and music styles as values.
|
43 |
-
"""),
|
44 |
-
("human", "{references_list}"),
|
45 |
-
])
|
46 |
-
runnable = prompt | llm | StrOutputParser()
|
47 |
-
return runnable.invoke({"references_list" : references})
|
48 |
-
|
49 |
-
|
50 |
-
@tool
|
51 |
-
def extract_popculture_references(input_style: str) -> List:
|
52 |
-
"A tool used to extract pop-culture references from a piece of text"
|
53 |
-
prompt = ChatPromptTemplate.from_messages([
|
54 |
-
("system", """You detect elements of the pop-culture (like TV-Shows, films, artists, famous people, etc) in the human's input message.
|
55 |
-
Return a list with these elements only. If there are none, return an empty list.
|
56 |
-
"""),
|
57 |
-
("human", "{input_style}"),
|
58 |
-
])
|
59 |
-
runnable = prompt | llm | StrOutputParser()
|
60 |
-
output = runnable.invoke({"input_style" : input_style})
|
61 |
-
return output
|
62 |
-
|
63 |
-
@tool
|
64 |
-
def call_music_recommendation_api(input : str) -> List[str]:
|
65 |
-
"""
|
66 |
-
Calls the music recommendation API
|
67 |
-
"""
|
68 |
-
print("Calling music recommendation API: ", input)
|
69 |
-
return {"songs" : [input]}
|
70 |
-
|
71 |
-
tools = [describe_popculture_references, extract_popculture_references, call_music_recommendation_api]
|
72 |
-
|
73 |
-
|
74 |
|
75 |
-
|
76 |
-
"""You are an agent which recommends songs based on the style a user gives.
|
77 |
-
You follow the following conversation protocol:
|
78 |
-
- You start the conversation by asking the user what style of music they like
|
79 |
-
- The user responds with a style of music
|
80 |
-
- If there are pop culture references like a movie, a TV show, an artist, a famous person, extract them AND then describe them as music styles.
|
81 |
-
- Ask the user if he is ok with the new generated style
|
82 |
-
- If the user agrees, call the music recommendation API with this style.
|
83 |
-
""")
|
84 |
|
85 |
|
86 |
-
MEMORY_KEY = "chat_history"
|
87 |
-
prompt = OpenAIFunctionsAgent.create_prompt(
|
88 |
-
system_message=system_message,
|
89 |
-
extra_prompt_messages=[MessagesPlaceholder(variable_name=MEMORY_KEY)]
|
90 |
-
)
|
91 |
|
92 |
-
|
93 |
|
|
|
|
|
|
|
|
|
94 |
|
95 |
-
agent = OpenAIFunctionsAgent(
|
96 |
-
llm=llm,
|
97 |
-
tools=tools,
|
98 |
-
prompt=prompt,
|
99 |
-
agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION
|
100 |
-
)
|
101 |
|
102 |
-
agent_executor = AgentExecutor(agent=agent, tools=tools, memory=memory, verbose=True)
|
103 |
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
yield chunk["output"]
|
109 |
-
|
110 |
|
111 |
-
gr.ChatInterface(
|
112 |
-
|
113 |
-
|
|
|
114 |
textbox=gr.Textbox(placeholder="Ask me for music recommendations!", container=False, scale=7),
|
115 |
description="This AI makes song recommendations based on your music style.",
|
|
|
116 |
title="Persona Music song recommender",
|
117 |
-
examples=["Recommend me something in Quentin Tarantino reggae style", "Give me songs with calm and relaxing vibes", "I want to listen to something like the movie Inception", "I want music that sounds like Lebron James eating soup"],
|
118 |
retry_btn="Retry",
|
119 |
clear_btn="Clear",
|
120 |
-
undo_btn = None
|
121 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import logging
|
3 |
+
import uuid
|
4 |
+
from dotenv import load_dotenv
|
5 |
|
6 |
+
load_dotenv(override=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
+
from langhcain_agent import llm_inference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
+
def predict_interface(message, history=None, user_id = None):
|
13 |
|
14 |
+
response = llm_inference(message, history, user_id)
|
15 |
+
logging.error(response)
|
16 |
+
logging.error(user_id)
|
17 |
+
return response['output']
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
|
|
20 |
|
21 |
|
22 |
+
session_id = gr.Textbox(value = str(uuid.uuid4()), type = "text", label = "session_id")
|
23 |
+
example_sentences=["Recommend me something in Quentin Tarantino reggae style", "Give me songs with calm and relaxing vibes", "I want to listen to something like the movie Inception", "I want music that sounds like Lebron James eating soup"]
|
24 |
+
examples = [[example, f"user_{i}"] for i, example in enumerate(example_sentences)]
|
|
|
|
|
25 |
|
26 |
+
chat = gr.ChatInterface(
|
27 |
+
predict_interface,
|
28 |
+
additional_inputs= [session_id],
|
29 |
+
chatbot=gr.Chatbot(height=600),
|
30 |
textbox=gr.Textbox(placeholder="Ask me for music recommendations!", container=False, scale=7),
|
31 |
description="This AI makes song recommendations based on your music style.",
|
32 |
+
examples=examples,
|
33 |
title="Persona Music song recommender",
|
|
|
34 |
retry_btn="Retry",
|
35 |
clear_btn="Clear",
|
36 |
+
undo_btn = None
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
chat.queue().launch()
|
cyanite.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import requests
|
5 |
+
|
6 |
+
CYANITE_API_URL = "https://api.cyanite.ai/graphql"
|
7 |
+
CYANITE_ACCESS_TOKEN = os.getenv("CYANITE_ACCESS_TOKEN")
|
8 |
+
|
9 |
+
def free_text_search(search_text, num_tracks=5):
|
10 |
+
headers = {
|
11 |
+
"Authorization": f"Bearer {CYANITE_ACCESS_TOKEN}",
|
12 |
+
"Content-Type": "application/json"
|
13 |
+
}
|
14 |
+
|
15 |
+
query = '''
|
16 |
+
query FreeTextSearch($searchText: String!, $numTracks: Int!) {
|
17 |
+
freeTextSearch(
|
18 |
+
first: $numTracks
|
19 |
+
target: { library: {} }
|
20 |
+
searchText: $searchText
|
21 |
+
) {
|
22 |
+
... on FreeTextSearchError {
|
23 |
+
message
|
24 |
+
code
|
25 |
+
}
|
26 |
+
... on FreeTextSearchConnection {
|
27 |
+
edges {
|
28 |
+
cursor
|
29 |
+
node {
|
30 |
+
id
|
31 |
+
title
|
32 |
+
}
|
33 |
+
}
|
34 |
+
}
|
35 |
+
}
|
36 |
+
}
|
37 |
+
'''
|
38 |
+
|
39 |
+
variables = {
|
40 |
+
"searchText": search_text,
|
41 |
+
"numTracks": num_tracks
|
42 |
+
}
|
43 |
+
import time
|
44 |
+
|
45 |
+
start_time = time.time()
|
46 |
+
|
47 |
+
response = requests.post(
|
48 |
+
CYANITE_API_URL,
|
49 |
+
headers=headers,
|
50 |
+
json={'query': query, 'variables': variables}
|
51 |
+
)
|
52 |
+
|
53 |
+
end_time = time.time()
|
54 |
+
time_taken = end_time - start_time
|
55 |
+
logging.warning(f"Cyanite API: Time taken: {time_taken} seconds")
|
56 |
+
|
57 |
+
if response.status_code == 200:
|
58 |
+
songs = extract_songs_from_response(response.json())
|
59 |
+
if songs:
|
60 |
+
return songs
|
61 |
+
else:
|
62 |
+
raise Exception("No songs found")
|
63 |
+
else:
|
64 |
+
raise Exception(f"Query failed with status code {response.status_code}")
|
65 |
+
|
66 |
+
def extract_songs_from_response(response_json):
|
67 |
+
try:
|
68 |
+
edges = response_json['data']['freeTextSearch']['edges']
|
69 |
+
if not edges:
|
70 |
+
return None # No songs found
|
71 |
+
songs = [{"id": edge["node"]["id"], "title": edge["node"]["title"]} for edge in edges]
|
72 |
+
return songs
|
73 |
+
except KeyError:
|
74 |
+
raise Exception("Invalid response format")
|
langhcain_agent.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from operator import itemgetter
|
2 |
+
import pprint
|
3 |
+
from typing import Dict, List
|
4 |
+
|
5 |
+
from langchain.agents import (AgentExecutor, AgentType, OpenAIFunctionsAgent,
|
6 |
+
tool)
|
7 |
+
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
|
8 |
+
from langchain.chat_models import ChatOpenAI
|
9 |
+
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
10 |
+
from langchain.schema import StrOutputParser, SystemMessage, HumanMessage, AIMessage
|
11 |
+
from langchain.callbacks import get_openai_callback, FileCallbackHandler
|
12 |
+
from langchain.schema.agent import AgentActionMessageLog, AgentFinish
|
13 |
+
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
|
14 |
+
from langchain.agents.format_scratchpad import format_to_openai_functions
|
15 |
+
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
|
16 |
+
from langchain.tools.render import format_tool_to_openai_function
|
17 |
+
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
|
18 |
+
from langchain.schema.runnable import RunnableConfig
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
import logging, os, json
|
23 |
+
from collections import defaultdict
|
24 |
+
|
25 |
+
from pydantic import BaseModel, Field
|
26 |
+
|
27 |
+
from cyanite import free_text_search
|
28 |
+
|
29 |
+
from langfuse.callback import CallbackHandler
|
30 |
+
|
31 |
+
if os.getenv("USE_LANGFUSE") == True:
|
32 |
+
handler = CallbackHandler(os.getenv("LANGFUSE_PUBLIC"), os.getenv("LANGFUSE_PRIVATE"), "https://cloud.langfuse.com" )
|
33 |
+
else:
|
34 |
+
handler = []
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
system_message = \
|
39 |
+
"""You are an agent which recommends songs based on music styles provided by the user.
|
40 |
+
- A music style could be a combination of instruments, genres or sounds.
|
41 |
+
- Use get_music_style_description to generate a description of the user's music style.
|
42 |
+
- The styles might contain pop-culture references (artists, movies, TV-Shows, etc) You should include them when generating descriptions.
|
43 |
+
- Comment on the description of the style and wish the user to enjoy the recommended songs (he will have received them).
|
44 |
+
- Do not mention any songs or artists, nor give a list of songs.
|
45 |
+
Write short responses with a respectful and friendly tone.
|
46 |
+
"""
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
describe_music_style_message = \
|
51 |
+
"""You receive a music style and your goal is to describe it further with genres, instruments and sounds.
|
52 |
+
If it contains pop-culture references (like TV-Shows, films, artists, famous people, etc) you should replace them with music styles that resemble them.
|
53 |
+
You should return the new music style as a set of words separated by commas.
|
54 |
+
You always give short answers, with at most 20 words.
|
55 |
+
"""
|
56 |
+
|
57 |
+
|
58 |
+
MEMORY_KEY = "history"
|
59 |
+
|
60 |
+
prompt = ChatPromptTemplate.from_messages([
|
61 |
+
("system", system_message),
|
62 |
+
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
63 |
+
MessagesPlaceholder(variable_name=MEMORY_KEY),
|
64 |
+
("human", "{input}"),
|
65 |
+
])
|
66 |
+
|
67 |
+
conversation_memories = defaultdict(
|
68 |
+
lambda : ConversationBufferWindowMemory(memory_key=MEMORY_KEY, return_messages=True, output_key="output", k = 4)
|
69 |
+
)
|
70 |
+
|
71 |
+
#global dicts to store the tracks and the conversation costs
|
72 |
+
music_styles_to_tracks = {}
|
73 |
+
conversation_costs = defaultdict(float)
|
74 |
+
|
75 |
+
|
76 |
+
@tool
|
77 |
+
def get_music_style_description(music_style: str) -> str:
|
78 |
+
"A tool which describes a music style and returns a description of it"
|
79 |
+
description = describe_music_style(music_style)
|
80 |
+
tracks = free_text_search(description, 5)
|
81 |
+
|
82 |
+
logging.warning(f"""
|
83 |
+
music_style = {music_style}
|
84 |
+
music_style_description = {description}
|
85 |
+
tracks = {pprint.pformat(tracks)}""")
|
86 |
+
|
87 |
+
# we store the tracks in a global variable so that we can access them later
|
88 |
+
music_styles_to_tracks[description] = tracks
|
89 |
+
# we return only the description to the user
|
90 |
+
return description
|
91 |
+
|
92 |
+
def describe_music_style(music_style: str) -> str:
|
93 |
+
"A tool used to describe music styles"
|
94 |
+
llm_describe = ChatOpenAI(temperature=0.0)
|
95 |
+
prompt_describe = ChatPromptTemplate.from_messages([
|
96 |
+
("system", describe_music_style_message),
|
97 |
+
("human", "{music_style}"),
|
98 |
+
])
|
99 |
+
runnable = prompt_describe | llm_describe | StrOutputParser()
|
100 |
+
return runnable.invoke({"music_style" : music_style},
|
101 |
+
#RunnableConfig(verbose = True, recursion_limit=1)
|
102 |
+
)
|
103 |
+
|
104 |
+
|
105 |
+
# We instantiate the Chat Model and bind the tool to it.
|
106 |
+
llm = ChatOpenAI(temperature=0.7, request_timeout = 30, max_retries = 1)
|
107 |
+
llm_with_tools = llm.bind(
|
108 |
+
functions=[
|
109 |
+
format_tool_to_openai_function(get_music_style_description)
|
110 |
+
]
|
111 |
+
)
|
112 |
+
|
113 |
+
def get_agent_executor_from_user_id(user_id) -> AgentExecutor:
|
114 |
+
"Returns an agent executor for a given user_id"
|
115 |
+
memory = conversation_memories[user_id]
|
116 |
+
|
117 |
+
logging.warning(memory)
|
118 |
+
|
119 |
+
agent = (
|
120 |
+
{
|
121 |
+
"input": lambda x: x["input"],
|
122 |
+
"agent_scratchpad": lambda x: format_to_openai_functions(x['intermediate_steps'])
|
123 |
+
}
|
124 |
+
| RunnablePassthrough.assign(
|
125 |
+
history = RunnableLambda(memory.load_memory_variables) | itemgetter(MEMORY_KEY)
|
126 |
+
)
|
127 |
+
| prompt
|
128 |
+
| llm_with_tools
|
129 |
+
| OpenAIFunctionsAgentOutputParser()
|
130 |
+
)
|
131 |
+
|
132 |
+
logging.error(memory)
|
133 |
+
return AgentExecutor(
|
134 |
+
agent=agent,
|
135 |
+
tools=[get_music_style_description],
|
136 |
+
memory=memory,
|
137 |
+
callbacks=[handler] if handler else [],
|
138 |
+
return_intermediate_steps=True,
|
139 |
+
max_execution_time= 30,
|
140 |
+
handle_parsing_errors=True,
|
141 |
+
verbose=True
|
142 |
+
)
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
def get_tracks_from_intermediate_steps(intermediate_steps : List) -> List:
|
147 |
+
"Given a list of intermediate steps, returns the tracks from the last get_music_style_description action"
|
148 |
+
if len(intermediate_steps) == 0:
|
149 |
+
return []
|
150 |
+
else:
|
151 |
+
print("INTERMEDIATE STEPS")
|
152 |
+
pprint.pprint(intermediate_steps)
|
153 |
+
print("===================")
|
154 |
+
for action_message, prompt in intermediate_steps[::-1]:
|
155 |
+
if action_message.tool == 'get_music_style_description':
|
156 |
+
tracks = music_styles_to_tracks[prompt]
|
157 |
+
return tracks
|
158 |
+
|
159 |
+
# if none of the actions is get_music_style_description, return empty list
|
160 |
+
return []
|
161 |
+
|
162 |
+
|
163 |
+
def llm_inference(message, history, user_id) -> Dict:
|
164 |
+
"""This function is called by the API and returns the conversation response along with the appropriate tracks and costs of the conversation so far"""
|
165 |
+
|
166 |
+
# it first creates an agent executor with the previous conversation memory of a given user_id
|
167 |
+
agent_executor = get_agent_executor_from_user_id(user_id)
|
168 |
+
|
169 |
+
with get_openai_callback() as cb:
|
170 |
+
|
171 |
+
# We get the Agent response
|
172 |
+
answer = agent_executor({"input": message})
|
173 |
+
|
174 |
+
# We keep track of the costs
|
175 |
+
conversation_costs[user_id] += cb.total_cost
|
176 |
+
total_conversation_costs = conversation_costs[user_id]
|
177 |
+
|
178 |
+
# We get the tracks from the intermediate steps if any
|
179 |
+
tracks = get_tracks_from_intermediate_steps(answer['intermediate_steps'])
|
180 |
+
|
181 |
+
logging.warning(f"step = ${cb.total_cost} total = ${total_conversation_costs}")
|
182 |
+
logging.warning(music_styles_to_tracks)
|
183 |
+
|
184 |
+
return {
|
185 |
+
"output" : answer['output'],
|
186 |
+
"tracks" : tracks,
|
187 |
+
"cost" : total_conversation_costs
|
188 |
+
}
|
189 |
+
|
190 |
+
|
191 |
+
|