Eddie Pick committed
Commit 2f49709 · 1 Parent(s): 52f91c6
Files changed (3)
  1. requirements.txt +2 -1
  2. search_agent_ui.py +1 -1
  3. web_rag.py +22 -24
requirements.txt CHANGED
@@ -12,10 +12,11 @@ langchain
 langchain-aws
 langchain-fireworks
 langchain_core
-langchain_cohere
+langchain-cohere
 langchain_community
 langchain_experimental
 langchain_openai
+langchain-ollama
 langchain_groq
 langsmith
 schema
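The two requirement changes track PyPI distribution names: the Cohere integration is published as langchain-cohere (hyphenated), and langchain-ollama is the standalone Ollama package that replaces the langchain_community chat model used below in web_rag.py. A minimal sketch of what the renamed and added distributions make importable (the module names use underscores even though the distributions are hyphenated):

# Sketch: imports provided by the updated requirements.
from langchain_cohere import ChatCohere               # installed by langchain-cohere
from langchain_ollama.chat_models import ChatOllama   # installed by langchain-ollama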
search_agent_ui.py CHANGED
@@ -75,7 +75,7 @@ with st.sidebar.expander("Options", expanded=False):
     temperature = st.slider("Model temperature 🌑️", 0.0, 1.0, 0.1, help="The higher the more creative")
     max_pages = st.slider("Max pages to retrieve 🔍", 1, 20, 15, help="How many web pages to retrive from the internet")
     top_k_documents = st.slider("Nbr of doc extracts to consider 📄", 1, 20, 5, help="How many of the top extracts to consider")
-    reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode ✍️", value=False, help="First generate a write, then comments and then rewrite")
+    reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode ✍️", value=False, help="First generate a draft, then comments and then rewrite")
 
 with st.sidebar.expander("Links", expanded=False):
     links_md = st.markdown("")
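The only functional change here is the help text: the checkbox drives a draft / comment / rewrite loop instead of a single answer. A rough sketch of how such a flag is typically consumed; the write_draft, comment_draft, and rewrite_draft helpers are hypothetical names, not functions from this repository:

# Hypothetical consumer of the sidebar options above; only reviewer_mode and
# top_k_documents come from this file, query_rag is defined in web_rag.py.
if reviewer_mode:
    draft = write_draft(chat_llm, query)                # hypothetical helper
    comments = comment_draft(chat_llm, draft)           # hypothetical helper
    answer = rewrite_draft(chat_llm, draft, comments)   # hypothetical helper
else:
    answer = query_rag(chat_llm, query, search_query, vectorstore,
                       top_k=top_k_documents)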
web_rag.py CHANGED
@@ -28,16 +28,14 @@ from langchain.prompts.chat import (
 from langchain.prompts.prompt import PromptTemplate
 from langchain.retrievers.multi_query import MultiQueryRetriever
 
-from langchain_aws import ChatBedrock
-from langchain_cohere.chat_models import ChatCohere
-from langchain_cohere.embeddings import CohereEmbeddings
+from langchain_aws import BedrockEmbeddings
+from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse
+from langchain_cohere import ChatCohere
 from langchain_fireworks.chat_models import ChatFireworks
-#from langchain_groq import ChatGroq
 from langchain_groq.chat_models import ChatGroq
 from langchain_openai import ChatOpenAI
 from langchain_openai.embeddings import OpenAIEmbeddings
-from langchain_community.embeddings.bedrock import BedrockEmbeddings
-from langchain_community.chat_models.ollama import ChatOllama
+from langchain_ollama.chat_models import ChatOllama
 
 def get_models(provider, model=None, temperature=0.0):
     match provider:
@@ -45,10 +43,10 @@ def get_models(provider, model=None, temperature=0.0):
             credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME')
             if model is None:
                 model = "anthropic.claude-3-sonnet-20240229-v1:0"
-            chat_llm = ChatBedrock(
+            chat_llm = ChatBedrockConverse(
                 credentials_profile_name=credentials_profile_name,
-                model_id=model,
-                model_kwargs={"temperature": temperature, "max_tokens":4096 },
+                model=model,
+                temperature=temperature,
             )
             embedding_model = BedrockEmbeddings(
                 model_id='cohere.embed-multilingual-v3',
@@ -57,7 +55,7 @@ def get_models(provider, model=None, temperature=0.0):
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'openai':
             if model is None:
-                model = "gpt-3.5-turbo"
+                model = "gpt-4o-mini"
             chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'groq':
@@ -67,7 +65,7 @@ def get_models(provider, model=None, temperature=0.0):
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'ollama':
             if model is None:
-                model = 'llama2'
+                model = 'llama3.1'
             chat_llm = ChatOllama(model=model, temperature=temperature)
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'cohere':
@@ -78,9 +76,8 @@ def get_models(provider, model=None, temperature=0.0):
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'fireworks':
             if model is None:
-                #model = 'accounts/fireworks/models/dbrx-instruct'
-                model = 'accounts/fireworks/models/llama-v3-70b-instruct'
-            chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=8192)
+                model = 'accounts/fireworks/models/llama-v3p1-8b-instruct'
+            chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=120000)
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case _:
             raise ValueError(f"Unknown LLM provider {provider}")
@@ -162,7 +159,7 @@ def optimize_search_query(chat_llm, query, callbacks=[]):
     messages = get_optimized_search_messages(query)
     response = chat_llm.invoke(messages, config={"callbacks": callbacks})
    optimized_search_query = response.content
-    return optimized_search_query.strip('"').split("**", 1)[0]
+    return optimized_search_query.strip('"').split("**", 1)[0].strip()
 
 
 def get_rag_prompt_template():
@@ -242,23 +239,24 @@ def get_context_size(chat_llm):
         else:
             return 16385
     if isinstance(chat_llm, ChatFireworks):
-        return 8192
+        return 32768
     if isinstance(chat_llm, ChatGroq):
-        return 37862
+        return 32768
     if isinstance(chat_llm, ChatOllama):
-        return 8192
+        return 120000
     if isinstance(chat_llm, ChatCohere):
         return 128000
-    if isinstance(chat_llm, ChatBedrock):
+    if isinstance(chat_llm, ChatBedrockConverse):
+        if chat_llm.model_id.startswith("meta.llama3-1"):
+            return 128000
         if chat_llm.model_id.startswith("anthropic.claude-3"):
             return 200000
         if chat_llm.model_id.startswith("anthropic.claude"):
             return 100000
         if chat_llm.model_id.startswith("mistral"):
-            if chat_llm.model_id.startswith("mistral.mixtral-8x7b"):
-                return 4096
-            else:
-                return 8192
+            if chat_llm.model_id.startswith("mistral.mistral-large-2407"):
+                return 128000
+            return 32000
     return 4096
 
 
@@ -280,4 +278,4 @@ def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10,
 def query_rag(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
     prompt = build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k=top_k, callbacks = callbacks)
     response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
-    return response.content
+    return response.content
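Taken together, the web_rag.py changes swap ChatBedrock for ChatBedrockConverse (which takes model and temperature directly instead of model_id and model_kwargs), source ChatOllama from langchain_ollama, refresh the default models, and raise the per-provider context sizes. A usage sketch under two assumptions: that get_models returns a (chat_llm, embedding_model) pair (its return statement sits outside these hunks), and that the relevant API keys or AWS profile are configured in the environment:

# Sketch only: exercising the updated provider factory from this commit.
from web_rag import get_models, get_context_size, optimize_search_query

# Defaults after this commit: 'openai' -> gpt-4o-mini, 'ollama' -> llama3.1,
# 'fireworks' -> accounts/fireworks/models/llama-v3p1-8b-instruct.
chat_llm, embedding_model = get_models('openai', temperature=0.1)

# Context size feeds build_rag_prompt's budget for retrieved extracts.
print(get_context_size(chat_llm))

# The optimized search query is now also stripped of trailing whitespace.
search_query = optimize_search_query(chat_llm, "impact of the EU AI Act on startups")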