Commit · aa7bffb
Parent(s): 3a64521
put @st.cache_resource on conversational_retriever
app.py CHANGED
@@ -102,6 +102,23 @@ def load_llm_model():
     return llm
 
 
+@st.cache_resource
+def load_conversational_qa_memory_retriever(llm_model):
+    question_generator = LLMChain(llm=llm_model, prompt=CONDENSE_QUESTION_PROMPT)
+    doc_chain = load_qa_chain(llm_model, chain_type="stuff", prompt = PROMPT)
+    memory = ConversationBufferWindowMemory(k = 3, memory_key="chat_history", return_messages=True, output_key='answer')
+
+
+
+    conversational_qa_memory_retriever = ConversationalRetrievalChain(
+        retriever=vector_database.as_retriever(),
+        question_generator=question_generator,
+        combine_docs_chain=doc_chain,
+        return_source_documents=True,
+        memory = memory,
+        get_chat_history=lambda h :h)
+    return conversational_qa_memory_retriever
+
 def load_retriever(llm, db):
     qa_retriever = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff",
                                                retriever=db.as_retriever(),
@@ -213,24 +230,13 @@ embedding_model = load_embedding_model()
 vector_database = load_faiss_index()
 llm_model = load_llm_model()
 qa_retriever = load_retriever(llm= llm_model, db= vector_database)
-
-
+conversational_qa_memory_retriever = load_conversational_qa_memory_retriever(llm_model)
 print("all load done")
 
 #Addional things for Conversation flows
-question_generator = LLMChain(llm=llm_model, prompt=CONDENSE_QUESTION_PROMPT)
-doc_chain = load_qa_chain(llm_model, chain_type="stuff", prompt = PROMPT)
-memory = ConversationBufferWindowMemory(k = 3, memory_key="chat_history", return_messages=True, output_key='answer')
 
 
 
-conversational_qa_memory_retriever = ConversationalRetrievalChain(
-        retriever=vector_database.as_retriever(),
-        question_generator=question_generator,
-        combine_docs_chain=doc_chain,
-        return_source_documents=True,
-        memory = memory,
-        get_chat_history=lambda h :h)
 
 
 
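For reference, st.cache_resource memoizes the decorated function's return value across Streamlit script reruns, so after this commit the retrieval chain is built once per process rather than on every user interaction. A minimal sketch of that caching behavior (the names build_chain and model_name are illustrative, not from app.py):

import streamlit as st

@st.cache_resource  # memoize the return value across Streamlit reruns
def build_chain(model_name: str):
    # The body runs once per distinct argument tuple per process;
    # subsequent reruns with the same arguments reuse the cached object.
    print(f"building chain for {model_name} ...")
    return {"model": model_name}  # stand-in for an expensive-to-build chain

chain = build_chain("my-llm")  # cheap on every rerun after the first

Streamlit derives the cache key by hashing the function's arguments; parameters whose names begin with an underscore (e.g. _llm_model) are excluded from hashing, which is the usual way to pass unhashable objects such as model handles into a cached loader.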