Carlosito16 commited on
Commit
aa7bffb
·
1 Parent(s): 3a64521

putt @st .cache_resource on conversational_retriever

Browse files
Files changed (1) hide show
  1. app.py +18 -12
app.py CHANGED
@@ -102,6 +102,23 @@ def load_llm_model():
102
  return llm
103
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def load_retriever(llm, db):
106
  qa_retriever = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff",
107
  retriever=db.as_retriever(),
@@ -213,24 +230,13 @@ embedding_model = load_embedding_model()
213
  vector_database = load_faiss_index()
214
  llm_model = load_llm_model()
215
  qa_retriever = load_retriever(llm= llm_model, db= vector_database)
216
-
217
-
218
  print("all load done")
219
 
220
  #Addional things for Conversation flows
221
- question_generator = LLMChain(llm=llm_model, prompt=CONDENSE_QUESTION_PROMPT)
222
- doc_chain = load_qa_chain(llm_model, chain_type="stuff", prompt = PROMPT)
223
- memory = ConversationBufferWindowMemory(k = 3, memory_key="chat_history", return_messages=True, output_key='answer')
224
 
225
 
226
 
227
- conversational_qa_memory_retriever = ConversationalRetrievalChain(
228
- retriever=vector_database.as_retriever(),
229
- question_generator=question_generator,
230
- combine_docs_chain=doc_chain,
231
- return_source_documents=True,
232
- memory = memory,
233
- get_chat_history=lambda h :h)
234
 
235
 
236
 
 
102
  return llm
103
 
104
 
105
+ @st.cache_resource
106
+ def load_conversational_qa_memory_retriever(llm_model):
107
+ question_generator = LLMChain(llm=llm_model, prompt=CONDENSE_QUESTION_PROMPT)
108
+ doc_chain = load_qa_chain(llm_model, chain_type="stuff", prompt = PROMPT)
109
+ memory = ConversationBufferWindowMemory(k = 3, memory_key="chat_history", return_messages=True, output_key='answer')
110
+
111
+
112
+
113
+ conversational_qa_memory_retriever = ConversationalRetrievalChain(
114
+ retriever=vector_database.as_retriever(),
115
+ question_generator=question_generator,
116
+ combine_docs_chain=doc_chain,
117
+ return_source_documents=True,
118
+ memory = memory,
119
+ get_chat_history=lambda h :h)
120
+ return conversational_qa_memory_retriever
121
+
122
  def load_retriever(llm, db):
123
  qa_retriever = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff",
124
  retriever=db.as_retriever(),
 
230
  vector_database = load_faiss_index()
231
  llm_model = load_llm_model()
232
  qa_retriever = load_retriever(llm= llm_model, db= vector_database)
233
+ conversational_qa_memory_retriever = load_conversational_qa_memory_retriever(llm_model)
 
234
  print("all load done")
235
 
236
  #Addional things for Conversation flows
 
 
 
237
 
238
 
239
 
 
 
 
 
 
 
 
240
 
241
 
242