Added Fireworks.ai as provider. Better Streamlit UI.
Files changed:
- README.md +1 -1
- dotenv.sample +1 -0
- requirements.txt +1 -0
- search_agent.py +4 -4
- search_agent_ui.py +32 -14
- web_crawler.py +2 -2
- web_rag.py +21 -3
README.md
CHANGED
@@ -15,7 +15,7 @@ license: apache-2.0
 This Python project provides a search agent that can perform web searches, optimize search queries, fetch and process web content, and generate responses using a language model and the retrieved information.
 Does a bit what [Perplexity AI](https://www.perplexity.ai/) does.
 
-The Streamlit GUI hosted on 🤗
+The Streamlit GUI hosted on 🤗 Spaces is [available to test](https://huggingface.co/spaces/CyranoB/search_agent)
 
 This Python script and Streamli GUI are a basic search agent that utilizes the LangChain library to perform optimized web searches, retrieve relevant content, and generate informative answers to user queries. The script supports multiple language models and providers, including OpenAI, Anthropic, and Groq.
 
dotenv.sample
CHANGED
@@ -6,6 +6,7 @@ LANGCHAIN_ENDPOINT="https://api.smith.langchain.com"
 
 OPENAI_API_KEY=sk-XXXXXXXXXXXXXXXXXXX
 ANTHROPIC_API_KEY=sk-ant-api03-XXXXXXXXXXXXXXXXXXX
+FIREWORKS_API_KEY=XXXXXXXXXXXXXXXXXXX
 GROQ_API_KEY=gsk_XXXXXXXXXXXXXXXXXXX
 CREDENTIALS_PROFILE_NAME=XXXXXXXXXXXXXXXXXXX
 COHERE_API_KEY=XXXXXXXXXXXXXXXXXXX
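A quick sketch of how the new key is picked up at runtime (the scripts already load the environment with python-dotenv, so this only illustrates the existing pattern):

    # sketch: confirm the new Fireworks key is visible once .env is loaded
    import os
    import dotenv

    dotenv.load_dotenv()
    if not os.getenv("FIREWORKS_API_KEY"):
        print("FIREWORKS_API_KEY is not set; the 'fireworks' provider will not be offered in the UI")

Without the key, the UI below simply omits Fireworks from the provider dropdown rather than failing.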
requirements.txt
CHANGED
@@ -8,6 +8,7 @@ pdfplumber
 python-dotenv
 langchain
 langchain-cohere
+langchain-fireworks
 langchain_core
 langchain_community
 langchain_experimental
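The new langchain-fireworks dependency provides the ChatFireworks class imported in web_rag.py further down. A minimal sketch of what it makes available (instantiation assumes FIREWORKS_API_KEY is set; the model id mirrors the default used in this commit):

    # sketch: the chat model class supplied by the new requirement
    from langchain_fireworks.chat_models import ChatFireworks

    chat = ChatFireworks(
        model_name="accounts/fireworks/models/mixtral-8x22b-instruct-preview",  # default used in web_rag.py below
        temperature=0.0,
    )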
search_agent.py
CHANGED
@@ -16,7 +16,7 @@ Options:
     --version                   Show version.
     -d domain --domain=domain   Limit search to a specific domain
     -t temp --temperature=temp  Set the temperature of the LLM [default: 0.0]
-    -p provider --provider=provider Use a specific LLM (choices: bedrock,openai,groq,ollama,cohere) [default: openai]
+    -p provider --provider=provider Use a specific LLM (choices: bedrock,openai,groq,ollama,cohere,fireworks) [default: openai]
     -m model --model=model      Use a specific model
     -n num --max_pages=num      Max number of pages to retrieve [default: 10]
     -o text --output=text       Output format (choices: text, markdown) [default: markdown]
@@ -78,8 +78,8 @@ if __name__ == '__main__':
     output=arguments["--output"]
     query = arguments["SEARCH_QUERY"]
 
-    chat = wr.get_chat_llm(provider, model, temperature)
-    console.log(f"Using {chat.model_name} on {provider}")
+    chat, embedding_model = wr.get_models(provider, model, temperature)
+    #console.log(f"Using {chat.model_name} on {provider}")
 
     with console.status(f"[bold green]Optimizing query for search: {query}"):
         optimize_search_query = wr.optimize_search_query(chat, query, callbacks=callbacks)
@@ -98,7 +98,7 @@ if __name__ == '__main__':
     console.log(f"Managed to extract content from {len(contents)} sources")
 
     with console.status(f"[bold green]Embeddubg {len(contents)} sources for content", spinner="growVertical"):
-        vector_store = wc.vectorize(contents)
+        vector_store = wc.vectorize(contents, embedding_model)
 
     with console.status("[bold green]Querying LLM relevant context", spinner='dots8Bit'):
         respomse = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = 5, callbacks=callbacks)
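Taken together, the CLI change is that a single wr.get_models call now returns both the chat model and the embedding model, and the embedding model is threaded through to wc.vectorize. A condensed sketch of the resulting flow (the query string is illustrative and the tracing callbacks are left empty):

    # sketch: the new CLI wiring, condensed from the hunks above
    import web_rag as wr
    import web_crawler as wc

    provider, model, temperature = "openai", None, 0.0
    chat, embedding_model = wr.get_models(provider, model, temperature)

    query = "best practices for retrieval augmented generation"  # illustrative
    optimized_query = wr.optimize_search_query(chat, query, callbacks=[])
    sources = wc.get_sources(optimized_query, max_pages=10)
    contents = wc.get_links_contents(sources)
    vector_store = wc.vectorize(contents, embedding_model)
    response = wr.query_rag(chat, query, optimized_query, vector_store, top_k=5, callbacks=[])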
search_agent_ui.py
CHANGED
@@ -1,4 +1,5 @@
 import datetime
+import os
 
 import dotenv
 import streamlit as st
@@ -13,11 +14,13 @@ import web_crawler as wc
 dotenv.load_dotenv()
 
 ls_tracer = LangChainTracer(
-    project_name="
+    project_name=os.getenv("LANGSMITH_PROJECT_NAME"),
     client=Client()
 )
 
+
 class StreamHandler(BaseCallbackHandler):
+    """Stream handler that appends tokens to container."""
     def __init__(self, container, initial_text=""):
         self.container = container
         self.text = initial_text
@@ -26,16 +29,34 @@ class StreamHandler(BaseCallbackHandler):
         self.text += token
         self.container.markdown(self.text)
 
-chat = wr.get_chat_llm(provider="cohere")
-
 st.title("🔍 Simple Search Agent 💬")
 
+if "providers" not in st.session_state:
+    providers = []
+    if os.getenv("COHERE_API_KEY"):
+        providers.append("cohere")
+    if os.getenv("OPENAI_API_KEY"):
+        providers.append("openai")
+    if os.getenv("GROQ_API_KEY"):
+        providers.append("groq")
+    if os.getenv("OLLAMA_API_KEY"):
+        providers.append("ollama")
+    if os.getenv("FIREWORKS_API_KEY"):
+        providers.append("fireworks")
+    if os.getenv("CREDENTIALS_PROFILE_NAME"):
+        providers.append("bedrock")
+    st.session_state["providers"] = providers
+
+with st.sidebar:
+    st.write("Options")
+    model_provider = st.selectbox("🧠 Model provider 🧠", st.session_state["providers"])
+    temperature = st.slider("🌡️ Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
+    max_pages = st.slider("🔍 Max pages to retrieve 🔍", 1, 20, 15, help="How many web pages to retrive from the internet")
+    top_k_documents = st.slider("📄 How many document extracts to consider 📄", 1, 20, 5, help="How many of the top extracts to consider")
+
 if "messages" not in st.session_state:
     st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
 
-if "input_disabled" not in st.session_state:
-    st.session_state["input_disabled"] = False
-
 for message in st.session_state.messages:
     st.chat_message(message["role"]).write(message["content"])
     if message["role"] == "assistant" and 'message_id' in message:
@@ -46,26 +67,25 @@ for message in st.session_state.messages:
             mime="text/plain"
         )
 
-if prompt := st.chat_input("Enter you instructions...", disabled=st.session_state["input_disabled"]):
-
-    st.session_state["input_disabled"] = True
-
+if prompt := st.chat_input("Enter you instructions..." ):
     st.chat_message("user").write(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
 
+    chat, embedding_model = wr.get_models(model_provider, temperature=temperature)
+
     with st.status("Thinking", expanded=True):
         st.write("I first need to do some research")
 
         optimize_search_query = wr.optimize_search_query(chat, query=prompt, callbacks=[ls_tracer])
         st.write(f"I should search the web for: {optimize_search_query}")
 
-        sources = wc.get_sources(optimize_search_query, max_pages=
+        sources = wc.get_sources(optimize_search_query, max_pages=max_pages)
 
         st.write(f"I'll now retrieve the {len(sources)} webpages and documents I found")
         contents = wc.get_links_contents(sources)
 
         st.write( f"Reading through the {len(contents)} sources I managed to retrieve")
-        vector_store = wc.vectorize(contents)
+        vector_store = wc.vectorize(contents, embedding_model=embedding_model)
         st.write(f"I collected {vector_store.index.ntotal} chunk of data and I can now answer")
 
         rag_prompt = wr.build_rag_prompt(prompt, optimize_search_query, vector_store, top_k=5, callbacks=[ls_tracer])
@@ -82,5 +102,3 @@ if prompt := st.chat_input("Enter you instructions...", disabled=st.session_stat
             file_name=f"{message_id}.txt",
             mime="text/plain"
         )
-    st.session_state["input_disabled"] = False
-
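The UI now derives its provider dropdown from whichever credentials are present in the environment, so adding FIREWORKS_API_KEY to .env is enough to expose the new provider. A standalone sketch of that detection logic (detect_providers is a hypothetical helper name; the variable-to-provider mapping is taken from the hunk above):

    # sketch: env-driven provider detection, mirroring search_agent_ui.py
    import os

    def detect_providers() -> list[str]:
        """Return providers whose credentials are set (hypothetical helper)."""
        env_to_provider = {
            "COHERE_API_KEY": "cohere",
            "OPENAI_API_KEY": "openai",
            "GROQ_API_KEY": "groq",
            "OLLAMA_API_KEY": "ollama",
            "FIREWORKS_API_KEY": "fireworks",
            "CREDENTIALS_PROFILE_NAME": "bedrock",
        }
        return [provider for var, provider in env_to_provider.items() if os.getenv(var)]

    print(detect_providers())  # e.g. ['openai', 'fireworks'], depending on the environment

The selected provider and temperature from the sidebar are then passed to wr.get_models on each request, and max_pages is forwarded to wc.get_sources.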
web_crawler.py
CHANGED
@@ -124,7 +124,7 @@ def get_links_contents(sources, get_driver_func=None):
             result['page_content'] = main_content
     return results
 
-def vectorize(contents):
+def vectorize(contents, embedding_model):
     documents = []
     for content in contents:
         try:
@@ -135,7 +135,7 @@ def vectorize(contents):
             documents.append(doc)
         except Exception as e:
            print(f"[gray]Error processing content for {content['link']}: {e}")
-    semantic_chunker = SemanticChunker(
+    semantic_chunker = SemanticChunker(embedding_model, breakpoint_threshold_type="percentile")
     docs = semantic_chunker.split_documents(documents)
     embeddings = OpenAIEmbeddings()
     store = FAISS.from_documents(docs, embeddings)
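vectorize now receives the embedding model chosen by the provider and uses it for semantic chunking, which splits documents where embedding similarity between neighbouring sentences drops. A self-contained sketch of that pattern (assumes faiss-cpu and an OpenAI key; the document text is illustrative, and unlike the code above it reuses the same embeddings for the FAISS index):

    # sketch: semantic chunking followed by FAISS indexing
    from langchain_core.documents import Document
    from langchain_experimental.text_splitter import SemanticChunker
    from langchain_openai import OpenAIEmbeddings
    from langchain_community.vectorstores import FAISS

    embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
    documents = [Document(page_content="LangChain is a framework for LLM apps. FAISS is a vector index. Semantic chunking groups related sentences.")]

    chunker = SemanticChunker(embedding_model, breakpoint_threshold_type="percentile")
    chunks = chunker.split_documents(documents)
    store = FAISS.from_documents(chunks, embedding_model)
    print(store.index.ntotal)  # number of indexed chunks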
web_rag.py
CHANGED
@@ -29,40 +29,58 @@ from langchain.prompts.prompt import PromptTemplate
 from langchain.retrievers.multi_query import MultiQueryRetriever
 
 from langchain_cohere.chat_models import ChatCohere
+from langchain_cohere.embeddings import CohereEmbeddings
+from langchain_fireworks.chat_models import ChatFireworks
 from langchain_groq import ChatGroq
 from langchain_openai import ChatOpenAI
+from langchain_openai.embeddings import OpenAIEmbeddings
 from langchain_community.chat_models.bedrock import BedrockChat
+from langchain_community.embeddings.bedrock import BedrockEmbeddings
 from langchain_community.chat_models.ollama import ChatOllama
 
-def get_chat_llm(provider, model=None, temperature=0.0):
+def get_models(provider, model=None, temperature=0.0):
     match provider:
         case 'bedrock':
+            credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME')
             if model is None:
                 model = "anthropic.claude-3-sonnet-20240229-v1:0"
             chat_llm = BedrockChat(
-                credentials_profile_name=
+                credentials_profile_name=credentials_profile_name,
                 model_id=model,
                 model_kwargs={"temperature": temperature },
             )
+            embedding_model = BedrockEmbeddings(
+                model_id='cohere.embed-multilingual-v3',
+                credentials_profile_name=credentials_profile_name
+            )
         case 'openai':
             if model is None:
                 model = "gpt-3.5-turbo"
             chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
+            embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'groq':
             if model is None:
                 model = 'mixtral-8x7b-32768'
             chat_llm = ChatGroq(model_name=model, temperature=temperature)
+            embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'ollama':
             if model is None:
                 model = 'llama2'
             chat_llm = ChatOllama(model=model, temperature=temperature)
+            embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'cohere':
             if model is None:
                 model = 'command-r-plus'
             chat_llm = ChatCohere(model=model, temperature=temperature)
+            embedding_model = CohereEmbeddings(model="embed-english-light-v3.0")
+        case 'fireworks':
+            if model is None:
+                model = 'accounts/fireworks/models/mixtral-8x22b-instruct-preview'
+            chat_llm = ChatFireworks(model_name=model, temperature=temperature)
+            embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case _:
             raise ValueError(f"Unknown LLM provider {provider}")
-    return chat_llm
+    return chat_llm, embedding_model
 
 
 def get_optimized_search_messages(query):
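Every provider branch now pairs the chat model with an embedding model, and callers unpack the tuple. A short sketch of what get_models hands back for a couple of providers (defaults are the ones visible in the diff; the corresponding API keys are assumed to be set):

    # sketch: unpacking the (chat model, embedding model) pair
    import web_rag as wr

    chat, embeddings = wr.get_models("fireworks", temperature=0.1)
    # chat -> ChatFireworks('accounts/fireworks/models/mixtral-8x22b-instruct-preview')
    # embeddings -> OpenAIEmbeddings('text-embedding-3-small')

    chat, embeddings = wr.get_models("cohere")
    # chat -> ChatCohere('command-r-plus')
    # embeddings -> CohereEmbeddings('embed-english-light-v3.0')

    try:
        wr.get_models("mistral")
    except ValueError as err:
        print(err)  # Unknown LLM provider mistral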