bstraehle commited on
Commit
4a5b4c3
·
1 Parent(s): 8485707

Update rag_langchain.py

Browse files
Files changed (1) hide show
  1. rag_langchain.py +5 -5
rag_langchain.py CHANGED
@@ -23,11 +23,11 @@ class LangChainRAG(BaseRAG):
23
  YOUTUBE_DIR = "/data/yt"
24
 
25
  LLM_CHAIN_PROMPT = PromptTemplate(
26
- input_variables = ["query_str"],
27
  template = os.environ["LLM_TEMPLATE"])
28
  RAG_CHAIN_PROMPT = PromptTemplate(
29
- input_variables = ["context_str", "query_str"],
30
- template = os.environ["RAG_TEMPLATE_2"])
31
 
32
  def load_documents(self):
33
  docs = []
@@ -114,7 +114,7 @@ class LangChainRAG(BaseRAG):
114
  )
115
 
116
  with get_openai_callback() as callback:
117
- completion = llm_chain.generate([{"query_str": prompt}])
118
 
119
  return completion, llm_chain, callback
120
 
@@ -130,6 +130,6 @@ class LangChainRAG(BaseRAG):
130
  )
131
 
132
  with get_openai_callback() as callback:
133
- completion = rag_chain({"query_str": prompt})
134
 
135
  return completion, rag_chain, callback
 
23
  YOUTUBE_DIR = "/data/yt"
24
 
25
  LLM_CHAIN_PROMPT = PromptTemplate(
26
+ input_variables = ["question"],
27
  template = os.environ["LLM_TEMPLATE"])
28
  RAG_CHAIN_PROMPT = PromptTemplate(
29
+ input_variables = ["context", "question"],
30
+ template = os.environ["RAG_TEMPLATE"])
31
 
32
  def load_documents(self):
33
  docs = []
 
114
  )
115
 
116
  with get_openai_callback() as callback:
117
+ completion = llm_chain.generate([{"question": prompt}])
118
 
119
  return completion, llm_chain, callback
120
 
 
130
  )
131
 
132
  with get_openai_callback() as callback:
133
+ completion = rag_chain({"query": prompt})
134
 
135
  return completion, rag_chain, callback