Files changed:
- README.md           +2  -0
- copywriter.py       +37 -5
- requirements.txt    +5  -1
- search_agent.py     +10 -8
- search_agent_ui.py  +80 -18
- web_crawler.py      +2  -1
- web_rag.py          +76 -37
README.md  (CHANGED)

@@ -10,6 +10,8 @@ pinned: false
 license: apache-2.0
 ---
 
+⚠️ **This project is a demonstration / proof-of-concept and is not intended for use in production environments. It is provided as-is, without warranty or guarantee of any kind. The code and any accompanying materials are for educational, testing, or evaluation purposes only.**⚠️
+
 # Simple Search Agent
 
 This Python project provides a search agent that can perform web searches, optimize search queries, fetch and process web content, and generate responses using a language model and the retrieved information.
copywriter.py  (CHANGED)

@@ -7,7 +7,6 @@ from langchain.prompts.chat import (
 from langchain.prompts.prompt import PromptTemplate
 
 
-
 def get_comments_prompt(query, draft):
     system_message = SystemMessage(
         content="""

@@ -35,14 +34,11 @@ def get_comments_prompt(query, draft):
     )
     return [system_message, human_message]
 
-
 def generate_comments(chat_llm, query, draft, callbacks=[]):
     messages = get_comments_prompt(query, draft)
     response = chat_llm.invoke(messages, config={"callbacks": callbacks})
     return response.content
 
-
-
 def get_final_text_prompt(query, draft, comments):
     system_message = SystemMessage(
         content="""

@@ -74,4 +70,40 @@ def get_final_text_prompt(query, draft, comments):
 def generate_final_text(chat_llm, query, draft, comments, callbacks=[]):
     messages = get_final_text_prompt(query, draft, comments)
     response = chat_llm.invoke(messages, config={"callbacks": callbacks})
-    return response.content
+    return response.content
+
+
+def get_compare_texts_prompts(query, draft_text, final_text):
+    system_message = SystemMessage(
+        content="""
+        I want you to act as a writing quality evaluator.
+        I will provide you with the original user request and four texts.
+        Your task is to carefully analyze, compare the two texts across the following dimensions and grade each text 0 to 10:
+        1. Grammar and spelling - Which text has fewer grammatical errors and spelling mistakes?
+        2. Clarity and coherence - Which text is easier to understand and has a more logical flow of ideas? Evaluate how well each text conveys its main points.
+        3. Tone and style - Which text has a more appropriate and engaging tone and writing style for its intended purpose and audience?
+        4. Sticking to the request - Which text is more successful responding to the original user request. Consider the request, the style, the length, etc.
+        5. Overall effectiveness - Considering the above factors, which text is more successful overall at communicating its message and achieving its goals?
+
+        After comparing the texts on these criteria, clearly state which text you think is better and summarize the main reasons why.
+        Provide specific examples from each text to support your evaluation.
+        """
+    )
+    human_message = HumanMessage(
+        content=f"""
+        Original query: {query}
+        ------------------------
+        Text 1: {draft_text}
+        ------------------------
+        Text 2: {final_text}
+        ------------------------
+        Summary:
+        """
+    )
+    return [system_message, human_message]
+
+
+def compare_text(chat_llm, query, draft, final, callbacks=[]):
+    messages = get_compare_texts_prompts(query, draft_text=draft, final_text=final)
+    response = chat_llm.invoke(messages, config={"callbacks": callbacks})
+    return response.content
requirements.txt  (CHANGED)

@@ -1,5 +1,7 @@
+anthropic
 boto3
 bs4
+chromedriver-py
 cohere
 docopt
 faiss-cpu

@@ -7,7 +9,7 @@ google-api-python-client
 pdfplumber
 python-dotenv
 langchain
-langchain-
+langchain-aws
 langchain-fireworks
 langchain_core
 langchain_community

@@ -18,6 +20,8 @@ langsmith
 schema
 streamlit
 selenium
+tiktoken
+transformers
 rich
 trafilatura
 watchdog
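
langchain-aws backs the new ChatBedrock import in web_rag.py, chromedriver-py pairs with the existing selenium dependency used by get_selenium_driver, and tiktoken / transformers appear to support the local token counting (chat_llm.get_num_tokens) that the reworked build_rag_prompt relies on. A minimal sketch of that counting, assuming an OPENAI_API_KEY is set (no completion request is made):

    # sketch: the per-model token counting used to size the RAG prompt
    from langchain_openai import ChatOpenAI

    chat = ChatOpenAI(model="gpt-3.5-turbo")
    print(chat.get_num_tokens("How many tokens is this sentence?"))  # counted locally via tiktoken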
search_agent.py  (CHANGED)

@@ -8,6 +8,7 @@ Usage:
         [--temperature=temp]
         [--copywrite]
         [--max_pages=num]
+        [--max_extracts=num]
         [--output=text]
         SEARCH_QUERY
     search_agent.py --version

@@ -21,6 +22,7 @@ Options:
     -p provider --provider=provider   Use a specific LLM (choices: bedrock,openai,groq,ollama,cohere,fireworks) [default: openai]
     -m model --model=model            Use a specific model
     -n num --max_pages=num            Max number of pages to retrieve [default: 10]
+    -e num --max_extracts=num         Max number of page extract to consider [default: 5]
     -o text --output=text             Output format (choices: text, markdown) [default: markdown]
 
 """

@@ -63,8 +65,6 @@ def get_selenium_driver():
     driver = webdriver.Chrome(options=chrome_options)
     return driver
 
-
-
 callbacks = []
 if os.getenv("LANGCHAIN_API_KEY"):
     callbacks.append(

@@ -90,14 +90,16 @@ if __name__ == '__main__':
     temperature = float(arguments["--temperature"])
     domain=arguments["--domain"]
     max_pages=arguments["--max_pages"]
+    max_extract=int(arguments["--max_extracts"])
     output=arguments["--output"]
     query = arguments["SEARCH_QUERY"]
 
     chat, embedding_model = wr.get_models(provider, model, temperature)
-    #console.log(f"Using {chat.model_name} on {provider}")
 
     with console.status(f"[bold green]Optimizing query for search: {query}"):
         optimize_search_query = wr.optimize_search_query(chat, query, callbacks=callbacks)
+    if len(optimize_search_query) < 3:
+        optimize_search_query = query
     console.log(f"Optimized search query: [bold blue]{optimize_search_query}")
 
     with console.status(

@@ -112,11 +114,11 @@ if __name__ == '__main__':
         contents = wc.get_links_contents(sources, get_selenium_driver)
     console.log(f"Managed to extract content from {len(contents)} sources")
 
-    with console.status(f"[bold green]
+    with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
         vector_store = wc.vectorize(contents, embedding_model)
 
-    with console.status("[bold green]
-        draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k =
+    with console.status("[bold green]Writing content", spinner='dots8Bit'):
+        draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = max_extract, callbacks=callbacks)
 
     console.rule(f"[bold green]Response from {provider}")
     if output == "text":

@@ -129,7 +131,7 @@ if __name__ == '__main__':
     with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"):
         comments = cw.generate_comments(chat, query, draft, callbacks=callbacks)
 
-    console.rule(
+    console.rule("[bold green]Response from reviewer")
     if output == "text":
         console.print(comments)
     else:

@@ -139,7 +141,7 @@ if __name__ == '__main__':
     with console.status("[bold green]Writing the final text", spinner="dots8Bit"):
         final_text = cw.generate_final_text(chat, query, draft, comments, callbacks=callbacks)
 
-    console.rule(
+    console.rule("[bold green]Final text")
     if output == "text":
         console.print(final_text)
     else:
search_agent_ui.py  (CHANGED)

@@ -10,6 +10,7 @@ from langsmith.client import Client
 
 import web_rag as wr
 import web_crawler as wc
+import copywriter as cw
 
 dotenv.load_dotenv()
 

@@ -18,7 +19,6 @@ ls_tracer = LangChainTracer(
     client=Client()
 )
 
-
 class StreamHandler(BaseCallbackHandler):
     """Stream handler that appends tokens to container."""
     def __init__(self, container, initial_text=""):

@@ -28,11 +28,36 @@ class StreamHandler(BaseCallbackHandler):
     def on_llm_new_token(self, token: str, **kwargs):
         self.text += token
         self.container.markdown(self.text)
+
 
+def create_links_markdown(sources_list):
+    """
+    Create a markdown string for each source in the provided JSON.
+
+    Args:
+        sources_list (list): A list of dictionaries representing the sources.
+            Each dictionary should have 'title', 'link', and 'snippet' keys.
+
+    Returns:
+        str: A markdown string with a bullet point for each source,
+            including the title linked to the URL and the snippet.
+    """
+    markdown_list = []
+    for source in sources_list:
+        title = source['title']
+        link = source['link']
+        snippet = source['snippet']
+        markdown = f"- [{title}]({link})\n {snippet}"
+        markdown_list.append(markdown)
+    return "\n".join(markdown_list)
+
+st.set_page_config(layout="wide")
 st.title("🔍 Simple Search Agent 💬")
 
 if "providers" not in st.session_state:
     providers = []
+    if os.getenv("FIREWORKS_API_KEY"):
+        providers.append("fireworks")
     if os.getenv("COHERE_API_KEY"):
         providers.append("cohere")
     if os.getenv("OPENAI_API_KEY"):

@@ -41,22 +66,34 @@ if "providers" not in st.session_state:
         providers.append("groq")
     if os.getenv("OLLAMA_API_KEY"):
         providers.append("ollama")
-    if os.getenv("FIREWORKS_API_KEY"):
-        providers.append("fireworks")
     if os.getenv("CREDENTIALS_PROFILE_NAME"):
         providers.append("bedrock")
     st.session_state["providers"] = providers
 
-with st.sidebar:
-    st.
-
-
-
-
+with st.sidebar.expander("Options", expanded=False):
+    model_provider = st.selectbox("Model provider 🧠", st.session_state["providers"])
+    temperature = st.slider("Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
+    max_pages = st.slider("Max pages to retrieve 🔍", 1, 20, 15, help="How many web pages to retrive from the internet")
+    top_k_documents = st.slider("Nbr of doc extracts to consider 📄", 1, 20, 5, help="How many of the top extracts to consider")
+    reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode ✍️", value=False, help="First generate a write, then comments and then rewrite")
+
+with st.sidebar.expander("Links", expanded=False):
+    links_md = st.markdown("")
+
+if reviewer_mode:
+    with st.sidebar.expander("Answer review", expanded=False):
+        st.caption("Draft")
+        draft_md = st.markdown("")
+        st.divider()
+        st.caption("Comments")
+        comments_md = st.markdown("")
+        st.divider()
+        st.caption("Comparaison")
+        comparaison_md = st.markdown("")
 
 if "messages" not in st.session_state:
     st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
-
+
 for message in st.session_state.messages:
     st.chat_message(message["role"]).write(message["content"])
     if message["role"] == "assistant" and 'message_id' in message:

@@ -80,6 +117,7 @@ if prompt := st.chat_input("Enter you instructions..." ):
         st.write(f"I should search the web for: {optimize_search_query}")
 
         sources = wc.get_sources(optimize_search_query, max_pages=max_pages)
+        links_md.markdown(create_links_markdown(sources))
 
         st.write(f"I'll now retrieve the {len(sources)} webpages and documents I found")
         contents = wc.get_links_contents(sources)

@@ -87,18 +125,42 @@ if prompt := st.chat_input("Enter you instructions..." ):
         st.write( f"Reading through the {len(contents)} sources I managed to retrieve")
         vector_store = wc.vectorize(contents, embedding_model=embedding_model)
         st.write(f"I collected {vector_store.index.ntotal} chunk of data and I can now answer")
+
+
+        if reviewer_mode:
+            st.write("Creating a draft")
+            draft_prompt = wr.build_rag_prompt(
+                chat, prompt, optimize_search_query,
+                vector_store, top_k=top_k_documents, callbacks=[ls_tracer])
+            draft = chat.invoke(draft_prompt, stream=False, config={ "callbacks": [ls_tracer]})
+            draft_md.markdown(draft.content)
+            st.write("Sending draft for review")
+            comments = cw.generate_comments(chat, prompt, draft, callbacks=[ls_tracer])
+            comments_md.markdown(comments)
+            st.write("Reviewing comments and generating final answer")
+            rag_prompt = cw.get_final_text_prompt(prompt, draft, comments)
+        else:
+            rag_prompt = wr.build_rag_prompt(
+                chat, prompt, optimize_search_query, vector_store,
+                top_k=top_k_documents, callbacks=[ls_tracer]
+            )
 
-    rag_prompt = wr.build_rag_prompt(prompt, optimize_search_query, vector_store, top_k=5, callbacks=[ls_tracer])
     with st.chat_message("assistant"):
         st_cb = StreamHandler(st.empty())
         result = chat.invoke(rag_prompt, stream=True, config={ "callbacks": [st_cb, ls_tracer]})
         response = result.content.strip()
         message_id = f"{prompt}{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
         st.session_state.messages.append({"role": "assistant", "content": response})
-
-
-
-
-
-
-
+
+    if st.session_state.messages[-1]["role"] == "assistant":
+        st.download_button(
+            label="Download",
+            data=st.session_state.messages[-1]["content"],
+            file_name=f"{message_id}.txt",
+            mime="text/plain"
+        )
+
+    if reviewer_mode:
+        compare_prompt = cw.get_compare_texts_prompts(prompt, draft_text=draft, final_text=response)
+        result = chat.invoke(compare_prompt, stream=False, config={ "callbacks": [ls_tracer]})
+        comparaison_md.markdown(result.content)
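
create_links_markdown renders the Brave results as a markdown bullet list inside the new "Links" sidebar expander. A standalone sketch of what it produces for two hypothetical sources (the helper below simply mirrors the one added to search_agent_ui.py):

    # sketch: the markdown produced for two made-up sources
    sources = [
        {"title": "Marie Curie - Wikipedia",
         "link": "https://en.wikipedia.org/wiki/Marie_Curie",
         "snippet": "Polish and naturalised-French physicist and chemist."},
        {"title": "Nobel Prize",
         "link": "https://www.nobelprize.org/",
         "snippet": "Marie Curie won Nobel Prizes in both Physics and Chemistry."},
    ]

    def create_links_markdown(sources_list):
        # mirrors the helper added to search_agent_ui.py
        return "\n".join(
            f"- [{s['title']}]({s['link']})\n {s['snippet']}" for s in sources_list
        )

    print(create_links_markdown(sources))
    # - [Marie Curie - Wikipedia](https://en.wikipedia.org/wiki/Marie_Curie)
    #  Polish and naturalised-French physicist and chemist.
    # - [Nobel Prize](https://www.nobelprize.org/)
    #  Marie Curie won Nobel Prizes in both Physics and Chemistry.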
web_crawler.py  (CHANGED)

@@ -35,12 +35,13 @@ def get_sources(query, max_pages=10, domain=None):
         json_response = response.json()
 
         if 'web' not in json_response or 'results' not in json_response['web']:
+            print(response.text)
             raise Exception('Invalid API response format')
 
         final_results = [{
             'title': result['title'],
             'link': result['url'],
-            'snippet': result['description'],
+            'snippet': extract(result['description'], output_format='txt', include_tables=False, include_images=False, include_formatting=True),
             'favicon': result.get('profile', {}).get('img', '')
         } for result in json_response['web']['results']]
 
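
The 'snippet' field is now passed through extract(...), presumably trafilatura's extract (trafilatura is already in requirements.txt), so the HTML markup Brave returns in the description is stripped down to plain text. A minimal sketch under that assumption:

    # sketch: cleaning a result description the way the new 'snippet' line does;
    # assumes `extract` is trafilatura.extract
    from trafilatura import extract

    description = ("<p>The <strong>solar system</strong> is the gravitationally bound system "
                   "of the Sun and the objects that orbit it, including the eight planets.</p>")
    snippet = extract(description, output_format='txt', include_tables=False,
                      include_images=False, include_formatting=True)
    print(snippet)  # cleaned plain text, or None if trafilatura deems the fragment too short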
web_rag.py  (CHANGED)

@@ -28,13 +28,14 @@ from langchain.prompts.chat import (
 from langchain.prompts.prompt import PromptTemplate
 from langchain.retrievers.multi_query import MultiQueryRetriever
 
+from langchain_aws import ChatBedrock
 from langchain_cohere.chat_models import ChatCohere
 from langchain_cohere.embeddings import CohereEmbeddings
 from langchain_fireworks.chat_models import ChatFireworks
-from langchain_groq import ChatGroq
+#from langchain_groq import ChatGroq
+from langchain_groq.chat_models import ChatGroq
 from langchain_openai import ChatOpenAI
 from langchain_openai.embeddings import OpenAIEmbeddings
-from langchain_community.chat_models.bedrock import BedrockChat
 from langchain_community.embeddings.bedrock import BedrockEmbeddings
 from langchain_community.chat_models.ollama import ChatOllama
 

@@ -44,15 +45,15 @@ def get_models(provider, model=None, temperature=0.0):
             credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME')
             if model is None:
                 model = "anthropic.claude-3-sonnet-20240229-v1:0"
-            chat_llm =
+            chat_llm = ChatBedrock(
                 credentials_profile_name=credentials_profile_name,
                 model_id=model,
-                model_kwargs={"temperature": temperature,
+                model_kwargs={"temperature": temperature, "max_tokens":4096 },
+            )
+            embedding_model = BedrockEmbeddings(
+                model_id='cohere.embed-multilingual-v3',
+                credentials_profile_name=credentials_profile_name
             )
-            #embedding_model = BedrockEmbeddings(
-            #    model_id='cohere.embed-multilingual-v3',
-            #    credentials_profile_name=credentials_profile_name
-            #)
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'openai':
             if model is None:

@@ -73,14 +74,17 @@ def get_models(provider, model=None, temperature=0.0):
             if model is None:
                 model = 'command-r-plus'
             chat_llm = ChatCohere(model=model, temperature=temperature)
-            embedding_model = CohereEmbeddings(model="embed-english-light-v3.0")
+            #embedding_model = CohereEmbeddings(model="embed-english-light-v3.0")
+            embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'fireworks':
             if model is None:
-                model = 'accounts/fireworks/models/
-
+                #model = 'accounts/fireworks/models/dbrx-instruct'
+                model = 'accounts/fireworks/models/llama-v3-70b-instruct'
+            chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=8192)
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case _:
             raise ValueError(f"Unknown LLM provider {provider}")
+
     return chat_llm, embedding_model
 
 

@@ -96,12 +100,13 @@ def get_optimized_search_messages(query):
     """
     system_message = SystemMessage(
         content="""
-        I want you to act as a prompt optimizer for web search.
+        I want you to act as a prompt optimizer for web search.
+        I will provide you with a chat prompt, and your goal is to optimize it into a search string that will yield the most relevant and useful information from a search engine like Google.
         To optimize the prompt:
-        Identify the key information being requested
-        Arrange the keywords into a concise search string
-        Keep it short, around 1 to 5 words total
-        Put the most important keywords first
+        - Identify the key information being requested
+        - Arrange the keywords into a concise search string
+        - Keep it short, around 1 to 5 words total
+        - Put the most important keywords first
 
         Some tips and things to be sure to remove:
         - Remove any conversational or instructional phrases

@@ -110,44 +115,44 @@ def get_optimized_search_messages(query):
         - Remove style instructions (exmaple: "in the style of", engaging, short, long)
         - Remove lenght instruction (example: essay, article, letter, etc)
 
-
+        You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the query
 
         Example:
         Question: How do I bake chocolate chip cookies from scratch?
-
+        chocolate chip cookies recipe from scratch**
         Example:
         Question: I would like you to show me a timeline of Marie Curie's life. Show results as a markdown table
-
+        Marie Curie timeline**
        Example:
         Question: I would like you to write a long article on NATO vs Russia. Use known geopolitical frameworks.
-
+        geopolitics nato russia**
         Example:
         Question: Write an engaging LinkedIn post about Andrew Ng
-
+        Andrew Ng**
         Example:
         Question: Write a short article about the solar system in the style of Carl Sagan
-
+        solar system**
         Example:
         Question: Should I use Kubernetes? Answer in the style of Gilfoyle from the TV show Silicon Valley
-
+        Kubernetes decision**
         Example:
         Question: Biography of Napoleon. Include a table with the major events.
-
+        napoleon biography events**
         Example:
         Question: Write a short article on the history of the United States. Include a table with the major events.
-
+        united states history events**
         Example:
         Question: Write a short article about the solar system in the style of donald trump
-
+        solar system**
         Exmaple:
         Question: Write a short linkedin about how the "freakeconomics" book previsions didn't pan out
-
+        freakeconomics book predictions failed**
         """
     )
     human_message = HumanMessage(
         content=f"""
         Question: {query}
-
+
         """
     )
     return [system_message, human_message]

@@ -230,15 +235,49 @@ def multi_query_rag(chat_llm, question, search_query, vectorstore, callbacks = [
     response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
     return response.content
 
-
-
-
-
-
-
-
+def get_context_size(chat_llm):
+    if isinstance(chat_llm, ChatOpenAI):
+        if chat_llm.model_name.startswith("gpt-4"):
+            return 128000
+        else:
+            return 16385
+    if isinstance(chat_llm, ChatFireworks):
+        return 8192
+    if isinstance(chat_llm, ChatGroq):
+        return 37862
+    if isinstance(chat_llm, ChatOllama):
+        return 8192
+    if isinstance(chat_llm, ChatCohere):
+        return 128000
+    if isinstance(chat_llm, ChatBedrock):
+        if chat_llm.model_id.startswith("anthropic.claude-3"):
+            return 200000
+        if chat_llm.model_id.startswith("anthropic.claude"):
+            return 100000
+        if chat_llm.model_id.startswith("mistral"):
+            if chat_llm.model_id.startswith("mistral.mixtral-8x7b"):
+                return 4096
+            else:
+                return 8192
+    return 4096
+
+
+def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
+    done = False
+    while not done:
+        unique_docs = vectorstore.similarity_search(
+            search_query, k=top_k, callbacks=callbacks, verbose=True)
+        context = format_docs(unique_docs)
+        prompt = get_rag_prompt_template().format(query=question, context=context)
+        nbr_tokens = chat_llm.get_num_tokens(prompt)
+        if top_k <= 1 or nbr_tokens <= get_context_size(chat_llm) - 768:
+            done = True
+        else:
+            top_k = int(top_k * 0.75)
+
+    return prompt
 
 def query_rag(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
-    prompt = build_rag_prompt(question, search_query, vectorstore, top_k=
+    prompt = build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k=top_k, callbacks = callbacks)
     response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
-    return response.content
+    return response.content