Spaces:
Daniel4343 committed · commit 3d220e9 · 1 Parent(s): 12a063e
Update app.py
app.py CHANGED
```diff
@@ -19,14 +19,12 @@ def process_answer(instruction, qa_chain):
     generated_text = qa_chain.run(instruction)
     return generated_text
 
-
 def get_file_size(file):
     file.seek(0, os.SEEK_END)
     file_size = file.tell()
     file.seek(0)
     return file_size
 
-
 @st.cache_resource
 def data_ingestion():
     for root, dirs, files in os.walk("docs"):
```
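The first hunk only removes stray blank lines. One detail worth noting about `get_file_size`: it works on the object returned by `st.file_uploader` because Streamlit's UploadedFile is a BytesIO subclass, so `seek` and `tell` behave as they do on a real file. A slightly hardened variant that restores the caller's read position instead of rewinding to zero (a sketch, not what the app ships):

```python
import os

def get_file_size(file):
    # Remember where the caller was reading from
    pos = file.tell()
    # Jump to the end; tell() then reports the total size in bytes
    file.seek(0, os.SEEK_END)
    size = file.tell()
    # Restore the original position rather than forcing it back to 0
    file.seek(pos)
    return size
```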
```diff
@@ -39,15 +37,14 @@ def data_ingestion():
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=500)
     splits = text_splitter.split_documents(documents)
 
-    #
+    # Hier Embeddings erstellen
     embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
     vectordb = FAISS.from_documents(splits, embeddings)
     vectordb.save_local("faiss_index")
 
-
 @st.cache_resource
 def initialize_qa_chain(selected_model):
-    #
+    # Konstanten
     CHECKPOINT = selected_model
     TOKENIZER = AutoTokenizer.from_pretrained(CHECKPOINT)
     BASE_MODEL = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT, device_map=torch.device('cpu'), torch_dtype=torch.float32)
```
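This hunk swaps a truncated English comment for a German one ("Hier Embeddings erstellen": create the embeddings here). More consequential is the line it leaves untouched: `chunk_overlap=500` with `chunk_size=500`. LangChain accepts an overlap equal to the chunk size, but it makes consecutive chunks near-duplicates of each other and bloats the FAISS index. Below is a minimal sketch of the ingestion step with a conventional overlap; `PDFMinerLoader` and the `build_index` helper are assumptions, since the document-loading lines fall outside this hunk:

```python
import os
from langchain.document_loaders import PDFMinerLoader
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

def build_index(docs_dir="docs", index_dir="faiss_index"):
    # Collect every PDF under docs_dir (the loader choice is an assumption)
    documents = []
    for root, _dirs, files in os.walk(docs_dir):
        for name in files:
            if name.endswith(".pdf"):
                documents.extend(PDFMinerLoader(os.path.join(root, name)).load())

    # Overlap well below chunk_size; the committed 500/500 pairing keeps up to
    # a full chunk of trailing text, so each chunk mostly repeats its neighbour
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    splits = splitter.split_documents(documents)

    # Embed the chunks and persist the FAISS index to disk
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    FAISS.from_documents(splits, embeddings).save_local(index_dir)
```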
```diff
@@ -67,7 +64,7 @@ def initialize_qa_chain(selected_model):
 
     vectordb = FAISS.load_local("faiss_index", embeddings)
 
-    #
+    # QA-Kette erstellen
     qa_chain = RetrievalQA.from_chain_type(
         llm=llm,
         chain_type="stuff",
```
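"QA-Kette erstellen" is "create the QA chain". The `llm` handed to `RetrievalQA.from_chain_type` is built in lines this diff does not touch, presumably by wrapping the seq2seq checkpoint in a transformers pipeline. A hypothetical reconstruction of that glue, continuing from the constants defined in `initialize_qa_chain` above (the `"text2text-generation"` task fits the T5-style checkpoints; `max_length` is illustrative):

```python
# Hypothetical sketch of the elided llm setup; the real app.py may differ.
from transformers import pipeline
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA

# TOKENIZER, BASE_MODEL and vectordb are defined in initialize_qa_chain above
pipe = pipeline(
    "text2text-generation",  # LaMini-T5 and Flan-T5 are seq2seq models
    model=BASE_MODEL,
    tokenizer=TOKENIZER,
    max_length=256,          # illustrative generation limit
)
llm = HuggingFacePipeline(pipeline=pipe)

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",      # concatenate all retrieved chunks into one prompt
    retriever=vectordb.as_retriever(),
)
```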
```diff
@@ -75,45 +72,42 @@ def initialize_qa_chain(selected_model):
     )
     return qa_chain
 
-
 @st.cache_data
-#
+# Funktion zum Anzeigen der PDF einer bestimmten Datei
 def display_pdf(file):
     try:
-        #
+        # Datei von Dateipfad öffnen
         with open(file, "rb") as f:
             base64_pdf = base64.b64encode(f.read()).decode('utf-8')
 
-        #
+        # PDF in HTML einbetten
         pdf_display = f'<iframe src="data:application/pdf;base64,{base64_pdf}" width="100%" height="600" type="application/pdf"></iframe>'
 
-        #
+        # Datei anzeigen
         st.markdown(pdf_display, unsafe_allow_html=True)
     except Exception as e:
-        st.error(f"
+        st.error(f"Ein Fehler ist beim Anzeigen der PDF aufgetreten: {e}")
 
-
-# Display conversation history using Streamlit messages
+# Unterhaltungsgeschichte mit Streamlit-Nachrichten anzeigen
 def display_conversation(history):
     for i in range(len(history["generated"])):
         message(history["past"][i], is_user=True, key=f"{i}_user")
         message(history["generated"][i], key=str(i))
 
-
 def main():
-
+    # Sidebar für die Modellauswahl hinzufügen
     model_options = ["MBZUAI/LaMini-T5-738M", "google/flan-t5-base", "google/flan-t5-small"]
-    selected_model = st.sidebar.selectbox("
+    selected_model = st.sidebar.selectbox("Modell auswählen", model_options)
 
-    st.markdown("<h1 style='text-align: center; color: blue;'>
-    st.markdown("<h2 style='text-align: center; color:red;'>
+    st.markdown("<h1 style='text-align: center; color: blue;'>Benutzerdefinierter PDF-Chatbot 🦜📄 </h1>", unsafe_allow_html=True)
+    st.markdown("<h2 style='text-align: center; color:red;'>Laden Sie Ihr PDF hoch und stellen Sie Fragen 👇</h2>", unsafe_allow_html=True)
 
     uploaded_file = st.file_uploader("", type=["pdf"])
 
     if uploaded_file is not None:
         file_details = {
-            "
-            "
+            "Dateiname": uploaded_file.name,
+            "Dateigröße": get_file_size(uploaded_file)
         }
         os.makedirs("docs", exist_ok=True)
         filepath = os.path.join("docs", uploaded_file.name)
```
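The German comments here mirror the removed English ones ("Funktion zum Anzeigen der PDF einer bestimmten Datei": function to display the PDF of a given file; "PDF in HTML einbetten": embed the PDF in HTML), and the error message now reads "Ein Fehler ist beim Anzeigen der PDF aufgetreten" (an error occurred while displaying the PDF). The technique itself: `display_pdf` inlines the whole document as a base64 data URI inside an `<iframe>`, which `st.markdown` renders only because `unsafe_allow_html=True`. Base64 inflates the payload by roughly a third and re-ships it to the browser on every rerun, so a size guard is cheap insurance. A minimal sketch, with an arbitrary 10 MB cap that the app itself does not enforce:

```python
import base64
import streamlit as st

def display_pdf(path, max_bytes=10_000_000):
    with open(path, "rb") as f:
        data = f.read()
    if len(data) > max_bytes:
        st.warning("PDF too large for an inline preview.")
        return
    # base64 inflates the payload by ~33%, all of it sent to the browser
    b64 = base64.b64encode(data).decode("utf-8")
    st.markdown(
        f'<iframe src="data:application/pdf;base64,{b64}" '
        f'width="100%" height="600" type="application/pdf"></iframe>',
        unsafe_allow_html=True,
    )
```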
```diff
@@ -123,40 +117,39 @@ def main():
 
             col1, col2 = st.columns([1, 2])
             with col1:
-                st.markdown("<h4 style color:black;'>
+                st.markdown("<h4 style color:black;'>Dateidetails</h4>", unsafe_allow_html=True)
                 st.json(file_details)
-                st.markdown("<h4 style color:black;'>
+                st.markdown("<h4 style color:black;'>Dateivorschau</h4>", unsafe_allow_html=True)
                 pdf_view = display_pdf(filepath)
 
             with col2:
-                st.success(f'
-                with st.spinner('Embeddings
+                st.success(f'Modell erfolgreich ausgewählt: {selected_model}')
+                with st.spinner('Embeddings werden erstellt...'):
                     ingested_data = data_ingestion()
-                st.success('Embeddings
-                st.markdown("<h4 style color:black;'>
+                st.success('Embeddings wurden erfolgreich erstellt!')
+                st.markdown("<h4 style color:black;'>Hier chatten</h4>", unsafe_allow_html=True)
 
                 user_input = st.text_input("", key="input")
 
-                #
+                # Sitzungszustand für generierte Antworten und vergangene Nachrichten initialisieren
                 if "generated" not in st.session_state:
-                    st.session_state["generated"] = ["
+                    st.session_state["generated"] = ["Ich bin bereit, Ihnen zu helfen"]
                 if "past" not in st.session_state:
-                    st.session_state["past"] = ["
+                    st.session_state["past"] = ["Hallo!"]
 
-                #
+                # In der Datenbank nach einer Antwort basierend auf der Benutzereingabe suchen und den Sitzungszustand aktualisieren
                 if user_input:
                     answer = process_answer({'query': user_input}, initialize_qa_chain(selected_model))
                     st.session_state["past"].append(user_input)
                     response = answer
                     st.session_state["generated"].append(response)
 
-                #
+                # Unterhaltungsgeschichte mit Streamlit-Nachrichten anzeigen
                 if st.session_state["generated"]:
                     display_conversation(st.session_state)
 
         except Exception as e:
-            st.error(f"
-
+            st.error(f"Ein Fehler ist aufgetreten: {e}")
 
 if __name__ == "__main__":
     main()
```
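In `main()`, the added strings localize the UI ("Modell auswählen": select model; "Dateiname"/"Dateigröße": file name/file size; "Hier chatten": chat here; "Ich bin bereit, Ihnen zu helfen": I am ready to help you; "Hallo!": hello). The malformed `<h4 style color:black;'>` attribute is carried over unchanged, so browsers ignore it and those headings render unstyled. The chat flow itself is the common streamlit_chat pattern: two parallel lists in `st.session_state` plus `message()`. Stripped to its bones, with `echo_bot` standing in for the QA chain:

```python
import streamlit as st
from streamlit_chat import message

def echo_bot(text):
    # Placeholder for process_answer(...) / the RetrievalQA chain
    return f"You said: {text}"

# Seed the two parallel histories exactly once per session
if "generated" not in st.session_state:
    st.session_state["generated"] = ["I am ready to help you"]
if "past" not in st.session_state:
    st.session_state["past"] = ["Hello!"]

user_input = st.text_input("Ask something", key="input")
if user_input:
    st.session_state["past"].append(user_input)
    st.session_state["generated"].append(echo_bot(user_input))

# Replay the whole conversation on every rerun; keys keep widgets stable
for i in range(len(st.session_state["generated"])):
    message(st.session_state["past"][i], is_user=True, key=f"{i}_user")
    message(st.session_state["generated"][i], key=str(i))
```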
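One portability note on `initialize_qa_chain`: current langchain-community releases refuse to load a locally pickled FAISS index unless deserialization is explicitly opted into, so the `vectordb = FAISS.load_local("faiss_index", embeddings)` call in this commit only runs as-is on the older LangChain the Space was written against. The newer, version-dependent form:

```python
# Required on recent langchain-community; only safe for indexes you built yourself
vectordb = FAISS.load_local(
    "faiss_index",
    embeddings,
    allow_dangerous_deserialization=True,
)
```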