Daniel4343 committed
Commit 3d220e9 · 1 Parent(s): 12a063e

Update app.py

Files changed (1)
  1. app.py +27 -34
app.py CHANGED
@@ -19,14 +19,12 @@ def process_answer(instruction, qa_chain):
     generated_text = qa_chain.run(instruction)
     return generated_text
 
-
 def get_file_size(file):
     file.seek(0, os.SEEK_END)
     file_size = file.tell()
     file.seek(0)
     return file_size
 
-
 @st.cache_resource
 def data_ingestion():
     for root, dirs, files in os.walk("docs"):
@@ -39,15 +37,14 @@ def data_ingestion():
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=500)
     splits = text_splitter.split_documents(documents)
 
-    # create embeddings here
+    # Create embeddings here
     embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
     vectordb = FAISS.from_documents(splits, embeddings)
     vectordb.save_local("faiss_index")
 
-
 @st.cache_resource
 def initialize_qa_chain(selected_model):
-    # Constants
+    # Constants
     CHECKPOINT = selected_model
     TOKENIZER = AutoTokenizer.from_pretrained(CHECKPOINT)
     BASE_MODEL = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT, device_map=torch.device('cpu'), torch_dtype=torch.float32)
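For context, the ingestion step this hunk touches boils down to the following standalone sketch. The PDF loader and the docs/ layout are assumptions, since they sit outside the hunks shown; the splitter, embedding, and FAISS calls mirror the diff.

import os

from langchain.document_loaders import PDFMinerLoader  # assumed loader; not visible in the hunks
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import FAISS

# Walk docs/ and load every PDF, as data_ingestion() does.
documents = []
for root, dirs, files in os.walk("docs"):
    for file in files:
        if file.endswith(".pdf"):
            documents.extend(PDFMinerLoader(os.path.join(root, file)).load())

# Same splitter settings as the commit (chunk_overlap equal to chunk_size is unusual but accepted).
splits = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=500).split_documents(documents)

# Embed the chunks and persist the FAISS index for later retrieval.
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
FAISS.from_documents(splits, embeddings).save_local("faiss_index")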
@@ -67,7 +64,7 @@ def initialize_qa_chain(selected_model):
 
     vectordb = FAISS.load_local("faiss_index", embeddings)
 
-    # Build a QA chain
+    # Create the QA chain
     qa_chain = RetrievalQA.from_chain_type(
         llm=llm,
         chain_type="stuff",
@@ -75,45 +72,42 @@ def initialize_qa_chain(selected_model):
     )
     return qa_chain
 
-
 @st.cache_data
-# function to display the PDF of a given file
+# Function to display the PDF of a given file
 def display_pdf(file):
     try:
-        # Opening file from file path
+        # Open the file from the file path
         with open(file, "rb") as f:
             base64_pdf = base64.b64encode(f.read()).decode('utf-8')
 
-        # Embedding PDF in HTML
+        # Embed the PDF in HTML
         pdf_display = f'<iframe src="data:application/pdf;base64,{base64_pdf}" width="100%" height="600" type="application/pdf"></iframe>'
 
-        # Displaying File
+        # Display the file
         st.markdown(pdf_display, unsafe_allow_html=True)
     except Exception as e:
-        st.error(f"An error occurred while displaying the PDF: {e}")
+        st.error(f"An error occurred while displaying the PDF: {e}")
 
-
-# Display conversation history using Streamlit messages
+# Display the conversation history using Streamlit messages
 def display_conversation(history):
     for i in range(len(history["generated"])):
         message(history["past"][i], is_user=True, key=f"{i}_user")
         message(history["generated"][i], key=str(i))
 
-
 def main():
-    # Add a sidebar for model selection
+    # Add a sidebar for model selection
     model_options = ["MBZUAI/LaMini-T5-738M", "google/flan-t5-base", "google/flan-t5-small"]
-    selected_model = st.sidebar.selectbox("Select Model", model_options)
+    selected_model = st.sidebar.selectbox("Select Model", model_options)
 
-    st.markdown("<h1 style='text-align: center; color: blue;'>Custom PDF Chatbot 🦜📄 </h1>", unsafe_allow_html=True)
-    st.markdown("<h2 style='text-align: center; color:red;'>Upload your PDF, and Ask Questions 👇</h2>", unsafe_allow_html=True)
+    st.markdown("<h1 style='text-align: center; color: blue;'>Custom PDF Chatbot 🦜📄 </h1>", unsafe_allow_html=True)
+    st.markdown("<h2 style='text-align: center; color:red;'>Upload your PDF and ask questions 👇</h2>", unsafe_allow_html=True)
 
     uploaded_file = st.file_uploader("", type=["pdf"])
 
     if uploaded_file is not None:
         file_details = {
-            "Filename": uploaded_file.name,
-            "File size": get_file_size(uploaded_file)
+            "Filename": uploaded_file.name,
+            "File size": get_file_size(uploaded_file)
         }
         os.makedirs("docs", exist_ok=True)
         filepath = os.path.join("docs", uploaded_file.name)
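Aside: the display_pdf changes above touch only comments. The underlying preview trick, shown here as a minimal self-contained sketch, is to inline the PDF bytes as a base64 data URI inside an <iframe>; the helper name show_pdf is illustrative, not from the commit.

import base64

import streamlit as st

def show_pdf(path: str) -> None:
    # Read the PDF and encode it as base64 so it can travel inside a data: URI.
    with open(path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    # Render the browser's built-in PDF viewer in an iframe; requires unsafe_allow_html.
    st.markdown(
        f'<iframe src="data:application/pdf;base64,{b64}" '
        'width="100%" height="600" type="application/pdf"></iframe>',
        unsafe_allow_html=True,
    )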
@@ -123,40 +117,39 @@ def main():
 
             col1, col2 = st.columns([1, 2])
             with col1:
-                st.markdown("<h4 style color:black;'>File details</h4>", unsafe_allow_html=True)
+                st.markdown("<h4 style color:black;'>File details</h4>", unsafe_allow_html=True)
                 st.json(file_details)
-                st.markdown("<h4 style color:black;'>File preview</h4>", unsafe_allow_html=True)
+                st.markdown("<h4 style color:black;'>File preview</h4>", unsafe_allow_html=True)
                 pdf_view = display_pdf(filepath)
 
             with col2:
-                st.success(f'model selected successfully: {selected_model}')
-                with st.spinner('Embeddings are in process...'):
+                st.success(f'Model selected successfully: {selected_model}')
+                with st.spinner('Embeddings are being created...'):
                     ingested_data = data_ingestion()
-                st.success('Embeddings are created successfully!')
-                st.markdown("<h4 style color:black;'>Chat Here</h4>", unsafe_allow_html=True)
+                st.success('Embeddings were created successfully!')
+                st.markdown("<h4 style color:black;'>Chat here</h4>", unsafe_allow_html=True)
 
                 user_input = st.text_input("", key="input")
 
-                # Initialize session state for generated responses and past messages
+                # Initialize session state for generated responses and past messages
                 if "generated" not in st.session_state:
-                    st.session_state["generated"] = ["I am ready to help you"]
+                    st.session_state["generated"] = ["I am ready to help you"]
                 if "past" not in st.session_state:
-                    st.session_state["past"] = ["Hey there!"]
+                    st.session_state["past"] = ["Hello!"]
 
-                # Search the database for a response based on user input and update session state
+                # Search the database for a response based on the user input and update the session state
                 if user_input:
                     answer = process_answer({'query': user_input}, initialize_qa_chain(selected_model))
                     st.session_state["past"].append(user_input)
                     response = answer
                     st.session_state["generated"].append(response)
 
-                # Display conversation history using Streamlit messages
+                # Display the conversation history using Streamlit messages
                 if st.session_state["generated"]:
                     display_conversation(st.session_state)
 
         except Exception as e:
-            st.error(f"An error occurred: {e}")
-
+            st.error(f"An error occurred: {e}")
 
 if __name__ == "__main__":
     main()
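Putting the pieces together, initialize_qa_chain() amounts to the following sketch. The text2text-generation pipeline settings and the flan-t5-base default are assumptions, since only parts of the function appear in the hunks; the model loading, FAISS reload, and RetrievalQA construction mirror the diff.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

from langchain.chains import RetrievalQA
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.vectorstores import FAISS

checkpoint = "google/flan-t5-base"  # any of the sidebar options works
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    checkpoint, device_map=torch.device("cpu"), torch_dtype=torch.float32
)

# Wrap the seq2seq model as a LangChain LLM (pipeline settings are assumed).
pipe = pipeline("text2text-generation", model=base_model, tokenizer=tokenizer, max_length=256)
llm = HuggingFacePipeline(pipeline=pipe)

# Reload the FAISS index saved by data_ingestion(), using the same embedding model.
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
vectordb = FAISS.load_local("faiss_index", embeddings)

# chain_type="stuff" simply concatenates the retrieved chunks into the prompt.
qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=vectordb.as_retriever())
print(qa_chain.run("What is this document about?"))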
 