Commit 2f49709 · Eddie Pick committed
Parent(s): 52f91c6

fixes
Files changed:
- requirements.txt (+2 −1)
- search_agent_ui.py (+1 −1)
- web_rag.py (+22 −24)
requirements.txt CHANGED

```diff
@@ -12,10 +12,11 @@ langchain
 langchain-aws
 langchain-fireworks
 langchain_core
-
+langchain-cohere
 langchain_community
 langchain_experimental
 langchain_openai
+langchain-ollama
 langchain_groq
 langsmith
 schema
```
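The two added packages back the new imports in web_rag.py below. A minimal smoke test, using the exact import paths from this commit (importing requires no API keys):

```python
# Verify the two newly pinned integration packages are importable.
# These are the import paths used in web_rag.py in this commit.
from langchain_cohere import ChatCohere
from langchain_ollama.chat_models import ChatOllama

print(ChatCohere.__name__, ChatOllama.__name__)  # ChatCohere ChatOllama
```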
search_agent_ui.py CHANGED

```diff
@@ -75,7 +75,7 @@ with st.sidebar.expander("Options", expanded=False):
     temperature = st.slider("Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
     max_pages = st.slider("Max pages to retrieve 🔎", 1, 20, 15, help="How many web pages to retrieve from the internet")
     top_k_documents = st.slider("Nbr of doc extracts to consider 📄", 1, 20, 5, help="How many of the top extracts to consider")
-    reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode ✒️", value=False, help="First generate a …
+    reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode ✒️", value=False, help="First generate a draft, then comments and then rewrite")
 
 with st.sidebar.expander("Links", expanded=False):
     links_md = st.markdown("")
```

Lines marked … here and below are truncated in the source view.
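The diff does not show how reviewer_mode is consumed, but the help text describes a three-pass flow. A rough sketch of how such a draft → comment → rewrite loop is commonly wired with a chat model (function and prompt wording are hypothetical, not from this repo):

```python
# Hypothetical three-pass flow matching the checkbox's help text.
# `chat_llm` is any LangChain chat model; prompts are illustrative only.
def draft_comment_rewrite(chat_llm, question, context):
    draft = chat_llm.invoke(
        f"Context:\n{context}\n\nWrite a draft answer to: {question}"
    ).content
    comments = chat_llm.invoke(
        f"Critique this draft, listing concrete improvements:\n{draft}"
    ).content
    final = chat_llm.invoke(
        f"Rewrite the draft, applying the comments.\n"
        f"Draft:\n{draft}\nComments:\n{comments}"
    ).content
    return final
```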
web_rag.py CHANGED

```diff
@@ -28,16 +28,14 @@ from langchain.prompts.chat import (
 from langchain.prompts.prompt import PromptTemplate
 from langchain.retrievers.multi_query import MultiQueryRetriever
 
-from langchain_aws import …
-from …
-from langchain_cohere …
+from langchain_aws import BedrockEmbeddings
+from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
+from langchain_cohere import ChatCohere
 from langchain_fireworks.chat_models import ChatFireworks
-#from langchain_groq import ChatGroq
 from langchain_groq.chat_models import ChatGroq
 from langchain_openai import ChatOpenAI
 from langchain_openai.embeddings import OpenAIEmbeddings
-from …
-from langchain_community.chat_models.ollama import ChatOllama
+from langchain_ollama.chat_models import ChatOllama
 
 def get_models(provider, model=None, temperature=0.0):
     match provider:
```
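Note the ChatOllama import moves from langchain_community to the dedicated langchain-ollama package added to requirements.txt above; the community chat-model classes have been deprecated in favor of these per-provider packages, and ChatBedrockConverse and ChatCohere likewise come from their own langchain-aws and langchain-cohere packages.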
```diff
@@ -45,10 +43,10 @@ def get_models(provider, model=None, temperature=0.0):
             credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME')
             if model is None:
                 model = "anthropic.claude-3-sonnet-20240229-v1:0"
-            chat_llm = …
+            chat_llm = ChatBedrockConverse(
                 credentials_profile_name=credentials_profile_name,
-                …
-                …
+                model=model,
+                temperature=temperature,
             )
             embedding_model = BedrockEmbeddings(
                 model_id='cohere.embed-multilingual-v3',
```
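ChatBedrockConverse (from langchain-aws) targets Bedrock's Converse API and accepts model and temperature as first-class parameters, as the added lines show; the construction it replaces is truncated in the source view, so the previous setup is not recoverable here.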
```diff
@@ -57,7 +55,7 @@ def get_models(provider, model=None, temperature=0.0):
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'openai':
             if model is None:
-                model = "gpt-…
+                model = "gpt-4o-mini"
             chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'groq':
```
```diff
@@ -67,7 +65,7 @@ def get_models(provider, model=None, temperature=0.0):
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'ollama':
             if model is None:
-                model = '…
+                model = 'llama3.1'
             chat_llm = ChatOllama(model=model, temperature=temperature)
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'cohere':
```
```diff
@@ -78,9 +76,8 @@ def get_models(provider, model=None, temperature=0.0):
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'fireworks':
             if model is None:
-                …
-                …
-            chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=8192)
+                model = 'accounts/fireworks/models/llama-v3p1-8b-instruct'
+            chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=120000)
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case _:
             raise ValueError(f"Unknown LLM provider {provider}")
```
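A hypothetical call of the updated factory, assuming get_models returns the (chat_llm, embedding_model) pair it builds (the return statement falls outside the hunks shown):

```python
# Assumption: get_models returns (chat_llm, embedding_model).
# 'ollama' now defaults to the local 'llama3.1' model per this commit;
# a running Ollama server is required, and OPENAI_API_KEY is still
# needed because embeddings default to OpenAI for this provider.
chat_llm, embedding_model = get_models('ollama', temperature=0.1)
print(chat_llm.invoke("One-word greeting:").content)
```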
```diff
@@ -162,7 +159,7 @@ def optimize_search_query(chat_llm, query, callbacks=[]):
     messages = get_optimized_search_messages(query)
     response = chat_llm.invoke(messages, config={"callbacks": callbacks})
     optimized_search_query = response.content
-    return optimized_search_query.strip('"').split("**", 1)[0]
+    return optimized_search_query.strip('"').split("**", 1)[0].strip()
 
 
 def get_rag_prompt_template():
```
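The added trailing .strip() removes whitespace left behind once surrounding quotes and any "**"-delimited tail are cut off. A pure-Python illustration with a made-up model response:

```python
# Made-up example of a model response the cleanup has to handle.
content = '"best hiking trails near Grenoble **explanation: ...**"'
query = content.strip('"').split("**", 1)[0].strip()
print(repr(query))  # 'best hiking trails near Grenoble'
```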
```diff
@@ -242,23 +239,24 @@ def get_context_size(chat_llm):
     else:
         return 16385
     if isinstance(chat_llm, ChatFireworks):
-        …
+        return 32768
     if isinstance(chat_llm, ChatGroq):
-        return …
+        return 32768
     if isinstance(chat_llm, ChatOllama):
-        return …
+        return 120000
     if isinstance(chat_llm, ChatCohere):
         return 128000
-    if isinstance(chat_llm, …):
+    if isinstance(chat_llm, ChatBedrockConverse):
+        if chat_llm.model_id.startswith("meta.llama3-1"):
+            return 128000
         if chat_llm.model_id.startswith("anthropic.claude-3"):
             return 200000
         if chat_llm.model_id.startswith("anthropic.claude"):
             return 100000
         if chat_llm.model_id.startswith("mistral"):
-            if chat_llm.model_id.startswith("mistral.…
-                return …
-            …
-            return 8192
+            if chat_llm.model_id.startswith("mistral.mistral-large-2407"):
+                return 128000
+            return 32000
     return 4096
 
 
```
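The prefix checks are order-sensitive: "anthropic.claude-3" must come before the broader "anthropic.claude", which would otherwise match first and return 100000 for Claude 3 models. A standalone check of the Bedrock branch logic above:

```python
# Standalone copy of the Bedrock prefix-matching logic for quick testing.
def bedrock_context_size(model_id: str) -> int:
    if model_id.startswith("meta.llama3-1"):
        return 128000
    if model_id.startswith("anthropic.claude-3"):
        return 200000
    if model_id.startswith("anthropic.claude"):
        return 100000
    if model_id.startswith("mistral"):
        if model_id.startswith("mistral.mistral-large-2407"):
            return 128000
        return 32000
    return 4096

assert bedrock_context_size("anthropic.claude-3-sonnet-20240229-v1:0") == 200000
assert bedrock_context_size("anthropic.claude-v2") == 100000
assert bedrock_context_size("mistral.mistral-large-2407-v1:0") == 128000
```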
```diff
@@ -280,4 +278,4 @@ def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10,
 def query_rag(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
     prompt = build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k=top_k, callbacks = callbacks)
     response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
-    return response.content
+    return response.content
```
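For orientation, a hypothetical end-to-end call chaining the functions touched by this commit; `vectorstore` stands for the store built elsewhere in web_rag.py from the fetched pages (not shown in this diff), and get_models is again assumed to return the pair it builds:

```python
# Hypothetical wiring of the pieces changed in this commit.
chat_llm, embedding_model = get_models('openai', temperature=0.1)

question = "What changed in the EU AI Act in 2024?"
search_query = optimize_search_query(chat_llm, question)

# `vectorstore` would be built from the retrieved web pages using
# `embedding_model`; its construction is outside this diff.
answer = query_rag(chat_llm, question, search_query, vectorstore, top_k=5)
print(answer)
```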