lfoppiano commited on
Commit
ad304e2
·
1 Parent(s): d67901d

re-implement the conversational memory access

Browse files
document_qa/document_qa_engine.py CHANGED
@@ -4,10 +4,12 @@ from pathlib import Path
4
  from typing import Union, Any
5
 
6
  from grobid_client.grobid_client import GrobidClient
7
- from langchain.chains import create_extraction_chain
8
- from langchain.chains.question_answering import load_qa_chain
 
9
  from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
10
  from langchain.retrievers import MultiQueryRetriever
 
11
  from langchain.text_splitter import RecursiveCharacterTextSplitter
12
  from langchain.vectorstores import Chroma
13
  from tqdm import tqdm
@@ -23,15 +25,28 @@ class DocumentQAEngine:
23
  embeddings_map_from_md5 = {}
24
  embeddings_map_to_md5 = {}
25
 
 
 
 
 
 
 
 
26
  def __init__(self,
27
  llm,
28
  embedding_function,
29
  qa_chain_type="stuff",
30
  embeddings_root_path=None,
31
  grobid_url=None,
 
32
  ):
33
  self.embedding_function = embedding_function
34
  self.llm = llm
 
 
 
 
 
35
  self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
36
 
37
  if embeddings_root_path is not None:
@@ -87,14 +102,14 @@ class DocumentQAEngine:
87
  return self.embeddings_map_from_md5[md5]
88
 
89
  def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
90
- verbose=False, memory=None) -> (
91
  Any, str):
92
  # self.load_embeddings(self.embeddings_root_path)
93
 
94
  if verbose:
95
  print(query)
96
 
97
- response = self._run_query(doc_id, query, context_size=context_size, memory=memory)
98
  response = response['output_text'] if 'output_text' in response else response
99
 
100
  if verbose:
@@ -144,21 +159,21 @@ class DocumentQAEngine:
144
 
145
  return parsed_output
146
 
147
- def _run_query(self, doc_id, query, context_size=4, memory=None):
148
  relevant_documents = self._get_context(doc_id, query, context_size)
149
- if memory:
150
- return self.chain.run(input_documents=relevant_documents,
151
- question=query)
152
- else:
153
- return self.chain.run(input_documents=relevant_documents,
154
- question=query,
155
- memory=memory)
156
- # return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)
157
 
158
  def _get_context(self, doc_id, query, context_size=4):
159
  db = self.embeddings_dict[doc_id]
160
  retriever = db.as_retriever(search_kwargs={"k": context_size})
161
  relevant_documents = retriever.get_relevant_documents(query)
 
 
162
  return relevant_documents
163
 
164
  def get_all_context_by_document(self, doc_id):
@@ -222,11 +237,15 @@ class DocumentQAEngine:
222
  hash = metadata[0]['hash']
223
 
224
  if hash not in self.embeddings_dict.keys():
225
- self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
 
 
226
  collection_name=hash)
227
  else:
228
  self.embeddings_dict[hash].delete(ids=self.embeddings_dict[hash].get()['ids'])
229
- self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
 
 
230
  collection_name=hash)
231
 
232
  self.embeddings_root_path = None
 
4
  from typing import Union, Any
5
 
6
  from grobid_client.grobid_client import GrobidClient
7
+ from langchain.chains import create_extraction_chain, ConversationChain, ConversationalRetrievalChain
8
+ from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
9
+ map_rerank_prompt
10
  from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
11
  from langchain.retrievers import MultiQueryRetriever
12
+ from langchain.schema import Document
13
  from langchain.text_splitter import RecursiveCharacterTextSplitter
14
  from langchain.vectorstores import Chroma
15
  from tqdm import tqdm
 
25
  embeddings_map_from_md5 = {}
26
  embeddings_map_to_md5 = {}
27
 
28
+ default_prompts = {
29
+ 'stuff': stuff_prompt,
30
+ 'refine': refine_prompts,
31
+ "map_reduce": map_reduce_prompt,
32
+ "map_rerank": map_rerank_prompt
33
+ }
34
+
35
  def __init__(self,
36
  llm,
37
  embedding_function,
38
  qa_chain_type="stuff",
39
  embeddings_root_path=None,
40
  grobid_url=None,
41
+ memory=None
42
  ):
43
  self.embedding_function = embedding_function
44
  self.llm = llm
45
+ # if memory:
46
+ # prompt = self.default_prompts[qa_chain_type].PROMPT_SELECTOR.get_prompt(llm)
47
+ # self.chain = load_qa_chain(llm, chain_type=qa_chain_type, prompt=prompt, memory=memory)
48
+ # else:
49
+ self.memory = memory
50
  self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
51
 
52
  if embeddings_root_path is not None:
 
102
  return self.embeddings_map_from_md5[md5]
103
 
104
  def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
105
+ verbose=False) -> (
106
  Any, str):
107
  # self.load_embeddings(self.embeddings_root_path)
108
 
109
  if verbose:
110
  print(query)
111
 
112
+ response = self._run_query(doc_id, query, context_size=context_size)
113
  response = response['output_text'] if 'output_text' in response else response
114
 
115
  if verbose:
 
159
 
160
  return parsed_output
161
 
162
+ def _run_query(self, doc_id, query, context_size=4):
163
  relevant_documents = self._get_context(doc_id, query, context_size)
164
+ response = self.chain.run(input_documents=relevant_documents,
165
+ question=query)
166
+
167
+ if self.memory:
168
+ self.memory.save_context({"input": query}, {"output": response})
169
+ return response
 
 
170
 
171
  def _get_context(self, doc_id, query, context_size=4):
172
  db = self.embeddings_dict[doc_id]
173
  retriever = db.as_retriever(search_kwargs={"k": context_size})
174
  relevant_documents = retriever.get_relevant_documents(query)
175
+ if self.memory and len(self.memory.buffer_as_messages) > 0:
176
+ relevant_documents.append(Document(page_content="Previous conversation:\n{}\n\n".format(self.memory.buffer_as_str)))
177
  return relevant_documents
178
 
179
  def get_all_context_by_document(self, doc_id):
 
237
  hash = metadata[0]['hash']
238
 
239
  if hash not in self.embeddings_dict.keys():
240
+ self.embeddings_dict[hash] = Chroma.from_texts(texts,
241
+ embedding=self.embedding_function,
242
+ metadatas=metadata,
243
  collection_name=hash)
244
  else:
245
  self.embeddings_dict[hash].delete(ids=self.embeddings_dict[hash].get()['ids'])
246
+ self.embeddings_dict[hash] = Chroma.from_texts(texts,
247
+ embedding=self.embedding_function,
248
+ metadatas=metadata,
249
  collection_name=hash)
250
 
251
  self.embeddings_root_path = None
streamlit_app.py CHANGED
@@ -5,6 +5,7 @@ from tempfile import NamedTemporaryFile
5
 
6
  import dotenv
7
  from grobid_quantities.quantities import QuantitiesAPI
 
8
  from langchain.llms.huggingface_hub import HuggingFaceHub
9
  from langchain.memory import ConversationBufferWindowMemory
10
 
@@ -80,6 +81,7 @@ def clear_memory():
80
 
81
  # @st.cache_resource
82
  def init_qa(model, api_key=None):
 
83
  if model == 'chatgpt-3.5-turbo':
84
  if api_key:
85
  chat = ChatOpenAI(model_name="gpt-3.5-turbo",
@@ -108,7 +110,7 @@ def init_qa(model, api_key=None):
108
  st.stop()
109
  return
110
 
111
- return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
112
 
113
 
114
  @st.cache_resource
@@ -315,8 +317,7 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
315
  elif mode == "LLM":
316
  with st.spinner("Generating response..."):
317
  _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
318
- context_size=context_size,
319
- memory=st.session_state.memory)
320
 
321
  if not text_response:
322
  st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
@@ -335,11 +336,11 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
335
  st.write(text_response)
336
  st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
337
 
338
- for id in range(0, len(st.session_state.messages), 2):
339
- question = st.session_state.messages[id]['content']
340
- if len(st.session_state.messages) > id + 1:
341
- answer = st.session_state.messages[id + 1]['content']
342
- st.session_state.memory.save_context({"input": question}, {"output": answer})
343
 
344
  elif st.session_state.loaded_embeddings and st.session_state.doc_id:
345
  play_old_messages()
 
5
 
6
  import dotenv
7
  from grobid_quantities.quantities import QuantitiesAPI
8
+ from langchain.callbacks import PromptLayerCallbackHandler
9
  from langchain.llms.huggingface_hub import HuggingFaceHub
10
  from langchain.memory import ConversationBufferWindowMemory
11
 
 
81
 
82
  # @st.cache_resource
83
  def init_qa(model, api_key=None):
84
+ ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
85
  if model == 'chatgpt-3.5-turbo':
86
  if api_key:
87
  chat = ChatOpenAI(model_name="gpt-3.5-turbo",
 
110
  st.stop()
111
  return
112
 
113
+ return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory'])
114
 
115
 
116
  @st.cache_resource
 
317
  elif mode == "LLM":
318
  with st.spinner("Generating response..."):
319
  _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
320
+ context_size=context_size)
 
321
 
322
  if not text_response:
323
  st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
 
336
  st.write(text_response)
337
  st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
338
 
339
+ # if len(st.session_state.messages) > 1:
340
+ # last_answer = st.session_state.messages[len(st.session_state.messages)-1]
341
+ # if last_answer['role'] == "assistant":
342
+ # last_question = st.session_state.messages[len(st.session_state.messages)-2]
343
+ # st.session_state.memory.save_context({"input": last_question['content']}, {"output": last_answer['content']})
344
 
345
  elif st.session_state.loaded_embeddings and st.session_state.doc_id:
346
  play_old_messages()