Anirudh1993 commited on
Commit
5a7c08c
Β·
verified Β·
1 Parent(s): 3570ead

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +284 -0
app.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.check_pydantic_version import use_pydantic_v1
2
+ use_pydantic_v1() #This function has to be run before importing haystack. as haystack requires pydantic v1 to run
3
+
4
+
5
+ from operator import index
6
+ import streamlit as st
7
+ import logging
8
+ import os
9
+
10
+ from annotated_text import annotation
11
+ from json import JSONDecodeError
12
+ from markdown import markdown
13
+ from utils.config import parser
14
+ from utils.haystack import start_document_store, query, initialize_pipeline, start_preprocessor_node, start_retriever, start_reader
15
+ from utils.ui import reset_results, set_initial_state
16
+ import pandas as pd
17
+ import haystack
18
+
19
+ from datetime import datetime
20
+ import streamlit.components.v1 as components
21
+ import streamlit_authenticator as stauth
22
+ import pickle
23
+
24
+ from streamlit_modal import Modal
25
+ import numpy as np
26
+
27
+
28
+
29
+ names = ['mlreply']
30
+ usernames = ['docwhiz']
31
+ with open('hashed_password.pkl','rb') as f:
32
+ hashed_passwords = pickle.load(f)
33
+
34
+
35
+
36
+ # Whether the file upload should be enabled or not
37
+ DISABLE_FILE_UPLOAD = bool(os.getenv("DISABLE_FILE_UPLOAD"))
38
+
39
+
40
+ def show_documents_list(retrieved_documents):
41
+ data = []
42
+ for i, document in enumerate(retrieved_documents):
43
+ data.append([document.meta['name']])
44
+ df = pd.DataFrame(data, columns=['Uploaded Document Name'])
45
+ df.drop_duplicates(subset=['Uploaded Document Name'], inplace=True)
46
+ df.index = np.arange(1, len(df) + 1)
47
+ return df
48
+
49
+ # Define a function to handle file uploads
50
+ def upload_files():
51
+ uploaded_files = upload_container.file_uploader(
52
+ "upload", type=["pdf", "txt", "docx"], accept_multiple_files=True, label_visibility="hidden", key=1
53
+ )
54
+ return uploaded_files
55
+
56
+
57
+ # Define a function to process a single file
58
+ def process_file(data_file, preprocesor, document_store):
59
+ # read file and add content
60
+ file_contents = data_file.read().decode("utf-8")
61
+ docs = [{
62
+ 'content': str(file_contents),
63
+ 'meta': {'name': str(data_file.name)}
64
+ }]
65
+ try:
66
+ names = [item.meta.get('name') for item in document_store.get_all_documents()]
67
+ #if args.store == 'inmemory':
68
+ # doc = converter.convert(file_path=files, meta=None)
69
+ if data_file.name in names:
70
+ print(f"{data_file.name} already processed")
71
+ else:
72
+ print(f'preprocessing uploaded doc {data_file.name}.......')
73
+ #print(data_file.read().decode("utf-8"))
74
+ preprocessed_docs = preprocesor.process(docs)
75
+ print('writing to document store.......')
76
+ document_store.write_documents(preprocessed_docs)
77
+ print('updating emebdding.......')
78
+ document_store.update_embeddings(retriever)
79
+ except Exception as e:
80
+ print(e)
81
+
82
+
83
+ # Define a function to upload the documents to haystack document store
84
+ def upload_document():
85
+ if data_files is not None:
86
+ for data_file in data_files:
87
+ # Upload file
88
+ if data_file:
89
+ try:
90
+ #raw_json = upload_doc(data_file)
91
+ # Call the process_file function for each uploaded file
92
+ if args.store == 'inmemory':
93
+ processed_data = process_file(data_file, preprocesor, document_store)
94
+ #upload_container.write(str(data_file.name) + "    βœ… ")
95
+ except Exception as e:
96
+ upload_container.write(str(data_file.name) + "    ❌ ")
97
+ upload_container.write("_This file could not be parsed, see the logs for more information._")
98
+
99
+ # Define a function to reset the documents in haystack document store
100
+ def reset_documents():
101
+ print('\nReseting documents list at ' + str(datetime.now()) + '\n')
102
+ st.session_state.data_files = None
103
+ document_store.delete_documents()
104
+
105
+ try:
106
+ args = parser.parse_args()
107
+ preprocesor = start_preprocessor_node()
108
+ document_store = start_document_store(type=args.store)
109
+ document_store.get_all_documents()
110
+ retriever = start_retriever(document_store)
111
+ reader = start_reader()
112
+ st.set_page_config(
113
+ page_title="MLReplySearch",
114
+ layout="centered",
115
+ page_icon=":shark:",
116
+ menu_items={
117
+ 'Get Help': 'https://www.extremelycoolapp.com/help',
118
+ 'Report a bug': "https://www.extremelycoolapp.com/bug",
119
+ 'About': "# This is a header. This is an *extremely* cool app!"
120
+ }
121
+ )
122
+ st.sidebar.image("ml_logo.png", use_column_width=True)
123
+
124
+ authenticator = stauth.Authenticate(names, usernames, hashed_passwords, "document_search", "random_text", cookie_expiry_days=1)
125
+
126
+ name, authentication_status, username = authenticator.login("Login", "main")
127
+
128
+ if authentication_status == False:
129
+ st.error("Username/Password is incorrect")
130
+
131
+ if authentication_status == None:
132
+ st.warning("Please enter your username and password")
133
+
134
+ if authentication_status:
135
+
136
+ # Sidebar for Task Selection
137
+ st.sidebar.header('Options:')
138
+
139
+ # OpenAI Key Input
140
+ openai_key = st.sidebar.text_input("Enter LLM-authorization Key:", type="password")
141
+
142
+ if openai_key:
143
+ task_options = ['Extractive', 'Generative']
144
+ else:
145
+ task_options = ['Extractive']
146
+
147
+ task_selection = st.sidebar.radio('Select the task:', task_options)
148
+
149
+ # Check the task and initialize pipeline accordingly
150
+ if task_selection == 'Extractive':
151
+ pipeline_extractive = initialize_pipeline("extractive", document_store, retriever, reader)
152
+ elif task_selection == 'Generative' and openai_key: # Check for openai_key to ensure user has entered it
153
+ pipeline_rag = initialize_pipeline("rag", document_store, retriever, reader, openai_key=openai_key)
154
+
155
+
156
+ set_initial_state()
157
+
158
+ modal = Modal("Manage Files", key="demo-modal")
159
+ open_modal = st.sidebar.button("Manage Files", use_container_width=True)
160
+ if open_modal:
161
+ modal.open()
162
+
163
+ st.write('# ' + args.name)
164
+ if modal.is_open():
165
+ with modal.container():
166
+ if not DISABLE_FILE_UPLOAD:
167
+ upload_container = st.container()
168
+ data_files = upload_files()
169
+ upload_document()
170
+ st.session_state.sidebar_state = 'collapsed'
171
+ st.table(show_documents_list(document_store.get_all_documents()))
172
+
173
+ # File upload block
174
+ # if not DISABLE_FILE_UPLOAD:
175
+ # upload_container = st.sidebar.container()
176
+ # upload_container.write("## File Upload:")
177
+ # data_files = upload_files()
178
+ # Button to update files in the documentStore
179
+ # upload_container.button('Upload Files', on_click=upload_document, args=())
180
+
181
+ # Button to reset the documents in DocumentStore
182
+ st.sidebar.button("Reset documents", on_click=reset_documents, args=(), use_container_width=True)
183
+
184
+ if "question" not in st.session_state:
185
+ st.session_state.question = ""
186
+ # Search bar
187
+ question = st.text_input("Question", value=st.session_state.question, max_chars=100, on_change=reset_results, label_visibility="hidden")
188
+
189
+ run_pressed = st.button("Run")
190
+
191
+ run_query = (
192
+ run_pressed or question != st.session_state.question #or task_selection != st.session_state.task
193
+ )
194
+
195
+ # Get results for query
196
+ if run_query and question:
197
+ if task_selection == 'Extractive':
198
+ reset_results()
199
+ st.session_state.question = question
200
+ with st.spinner("πŸ”Ž    Running your pipeline"):
201
+ try:
202
+ st.session_state.results_extractive = query(pipeline_extractive, question)
203
+ st.session_state.task = task_selection
204
+ except JSONDecodeError as je:
205
+ st.error(
206
+ "πŸ‘“    An error occurred reading the results. Is the document store working?"
207
+ )
208
+ except Exception as e:
209
+ logging.exception(e)
210
+ st.error("🐞    An error occurred during the request.")
211
+
212
+ elif task_selection == 'Generative':
213
+ reset_results()
214
+ st.session_state.question = question
215
+ with st.spinner("πŸ”Ž    Running your pipeline"):
216
+ try:
217
+ st.session_state.results_generative = query(pipeline_rag, question)
218
+ st.session_state.task = task_selection
219
+ except JSONDecodeError as je:
220
+ st.error(
221
+ "πŸ‘“    An error occurred reading the results. Is the document store working?"
222
+ )
223
+ except Exception as e:
224
+ if "API key is invalid" in str(e):
225
+ logging.exception(e)
226
+ st.error("🐞    incorrect API key provided. You can find your API key at https://platform.openai.com/account/api-keys.")
227
+ else:
228
+ logging.exception(e)
229
+ st.error("🐞    An error occurred during the request.")
230
+ # Display results
231
+ if (st.session_state.results_extractive or st.session_state.results_generative) and run_query:
232
+
233
+ # Handle Extractive Answers
234
+ if task_selection == 'Extractive':
235
+ results = st.session_state.results_extractive
236
+
237
+ st.subheader("Extracted Answers:")
238
+
239
+ if 'answers' in results:
240
+ answers = results['answers']
241
+ treshold = 0.2
242
+ higher_then_treshold = any(ans.score > treshold for ans in answers)
243
+ if not higher_then_treshold:
244
+ st.markdown(f"<span style='color:red'>Please note none of the answers achieved a score higher then {int(treshold) * 100}%. Which probably means that the desired answer is not in the searched documents.</span>", unsafe_allow_html=True)
245
+ for count, answer in enumerate(answers):
246
+ if answer.answer:
247
+ text, context = answer.answer, answer.context
248
+ start_idx = context.find(text)
249
+ end_idx = start_idx + len(text)
250
+ score = round(answer.score, 3)
251
+ st.markdown(f"**Answer {count + 1}:**")
252
+ st.markdown(
253
+ context[:start_idx] + str(annotation(body=text, label=f'SCORE {score}', background='#964448', color='#ffffff')) + context[end_idx:],
254
+ unsafe_allow_html=True,
255
+ )
256
+ else:
257
+ st.info(
258
+ "πŸ€” &nbsp;&nbsp; Haystack is unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
259
+ )
260
+
261
+ # Handle Generative Answers
262
+ elif task_selection == 'Generative':
263
+ results = st.session_state.results_generative
264
+ st.subheader("Generated Answer:")
265
+ if 'results' in results:
266
+ st.markdown("**Answer:**")
267
+ st.write(results['results'][0])
268
+
269
+ # Handle Retrieved Documents
270
+ if 'documents' in results:
271
+ retrieved_documents = results['documents']
272
+ st.subheader("Retriever Results:")
273
+
274
+ data = []
275
+ for i, document in enumerate(retrieved_documents):
276
+ # Truncate the content
277
+ truncated_content = (document.content[:150] + '...') if len(document.content) > 150 else document.content
278
+ data.append([i + 1, document.meta['name'], truncated_content])
279
+
280
+ # Convert data to DataFrame and display using Streamlit
281
+ df = pd.DataFrame(data, columns=['Ranked Context', 'Document Name', 'Content'])
282
+ st.table(df)
283
+ except SystemExit as e:
284
+ os._exit(e.code)