# Uploaded to Hugging Face Hub by maitykritadhi (app.py, commit 0f39449, verified)
import os
import shutil
import streamlit as st
import chromadb
import config as cf
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate
# from langchain_community.embeddings import SentenceTransformerEmbeddings
from sentence_transformers import SentenceTransformer
from langchain_groq import ChatGroq
from langchain.schema import Document
from source.utils.data_processing import ProcessDocs
from source.utils.store_data import get_vector_store, check_pdfs_chromadb, save_uploaded_files
from source.utils.process_data import get_pdf_text, get_text_chunks
llm = None
def get_conversational_chain(model):
    """Build a "stuff"-type question-answering chain backed by a Groq chat model.

    Args:
        model: Groq model name selected in the UI (e.g. "llama3-8b-8192",
            "llama3-70b-8192", "mixtral-8x7b-32768", "gemma-7b-it").

    Returns:
        A LangChain QA chain that answers a question from supplied documents.
    """
    global llm
    # The original branched on every supported model name, but each branch was
    # identical apart from the string — instantiate directly from the argument.
    # Temperature 0 keeps answers deterministic for retrieval QA.
    llm = ChatGroq(temperature=0, model_name=model)
    chain = load_qa_chain(llm, chain_type="stuff")
    return chain
def user_input(user_question, model):
    """Answer *user_question* from the top-5 most similar chunks in ChromaDB.

    Embeds the question, queries the persistent "Chromadb_pdf" collection,
    feeds the retrieved chunks to the QA chain, and renders the answer plus
    per-PDF source pages via Streamlit.

    Args:
        user_question: The question typed by the user.
        model: Groq model name forwarded to ``get_conversational_chain``.
    """
    embedding_model = SentenceTransformer("all-mpnet-base-v2")
    chain = get_conversational_chain(model)
    query_embedding = embedding_model.encode(user_question).tolist()

    client = chromadb.PersistentClient("chromadb")
    collection = client.get_collection("Chromadb_pdf")
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=5,
        include=['distances', 'metadatas', 'documents'],
    )

    if not results['documents']:
        st.write("No results found.")
        return

    docs = []
    # Track pages per PDF so the citation covers every retrieved chunk.
    # (The original printed only the last chunk's pdf_name, which was wrong
    # whenever results spanned more than one PDF.)
    pages_by_pdf = {}
    for text, metadata in zip(results['documents'][0], results['metadatas'][0]):
        pdf_name = metadata['pdf_name']
        page_number = metadata['page_number']
        docs.append(Document(
            page_content=text,
            metadata={'source': pdf_name, 'page': page_number},
        ))
        pages_by_pdf.setdefault(pdf_name, []).append(str(page_number))

    response = chain(
        {"input_documents": docs, "question": user_question},
        return_only_outputs=False,
    )
    st.write("Reply:", response["output_text"])
    for pdf_name, pages in pages_by_pdf.items():
        st.write("Metadata: ", f"PDF Name: {pdf_name}, Page Numbers: {','.join(pages)}")
def main():
    """Streamlit entry point: model picker, question box, and PDF uploader.

    The sidebar accepts PDF uploads; "Submit & Process" wipes the download
    directory, saves the new files, splits them, and indexes them into the
    vector store. Questions in the main pane are answered via
    ``ProcessDocs.retrieval_qa``.
    """
    st.set_page_config("Chat PDF")
    model = st.selectbox(
        "Select Model",
        ["llama3-8b-8192", "llama3-70b-8192", "mixtral-8x7b-32768", "gemma-7b-it"],
    )
    st.header("Chat with PDF after Uploading")

    user_question = st.text_input("Ask a Question from the PDF Files")
    if user_question:
        db_obj = ProcessDocs(cf.db_collection_name)
        response = db_obj.retrieval_qa(user_question, model)
        st.write("Response:", response)

    with st.sidebar:
        st.title("Menu:")
        pdf_docs = st.file_uploader(
            "Upload your PDF Files and Click on the Submit & Process Button",
            accept_multiple_files=True,
        )
        db_obj = ProcessDocs(cf.db_collection_name)
        if st.button("Submit & Process"):
            # Guard against no upload: file_uploader yields None or [] when
            # nothing has been selected (the original crashed on None).
            if pdf_docs:
                # Start from a clean download directory so stale PDFs from a
                # previous run are not re-processed.
                if os.path.exists(cf.pdf_download_path):
                    shutil.rmtree(cf.pdf_download_path)
                os.makedirs(cf.pdf_download_path)
                save_uploaded_files(pdf_docs, cf.pdf_download_path)
                with st.spinner("Processing..."):
                    new_unique_files = db_obj.identify_new_uploaded_files()
                    loaded_docs = db_obj.create_pdf_docx_loader(new_unique_files, model)
                    splits = db_obj.split_documents(loaded_docs)
                    db_obj.vector_store(splits)
                st.success("Done")
            else:
                st.success("No new files to process")
# Entry point when executed directly (e.g. `streamlit run app.py`).
if __name__ == "__main__":
    main()