#!/usr/bin/env python3
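"""Streamlit front end for a privateGPT-style chatbot over a local Chroma index.

Answers questions about the Tsetlin Machine with a local GPT4All or LlamaCpp
model, HuggingFace embeddings, and a conversational retrieval chain that keeps
chat history in memory. Written against the pre-0.1 langchain API.
"""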
import argparse
import os

import chromadb
import langchain
import streamlit as st
from dotenv import load_dotenv
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import ConversationalRetrievalChain, RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import GPT4All, LlamaCpp
from langchain.memory import ConversationBufferMemory
from langchain.vectorstores import Chroma

from htmlTemplates import css, bot_template, user_template

# Silence langchain's chain-level debug output.
langchain.verbose = False
if not load_dotenv():
    print("Could not load .env file or it is empty. Please check if it exists and is readable.")
    exit(1)
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = int(os.environ.get('MODEL_N_CTX', 1000))  # max_tokens must be an int
model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
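# Illustrative .env for the variables above (model path and names are
# assumptions; adjust to your setup):
#   PERSIST_DIRECTORY=db
#   EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2
#   MODEL_TYPE=GPT4All
#   MODEL_PATH=models/ggml-gpt4all-j-v1.3-groovy.bin
#   MODEL_N_CTX=1000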
# Imported only after load_dotenv() so constants.py can read the .env values.
from constants import CHROMA_SETTINGS
def handle_userinput(user_question):
    response = st.session_state.conversation({'question': user_question})
    st.session_state.chat_history = response['chat_history']
    # History alternates user/assistant turns; render each with its template.
    for i, message in enumerate(st.session_state.chat_history):
        if i % 2 == 0:
            st.write(user_template.replace(
                "{{MSG}}", message.content), unsafe_allow_html=True)
        else:
            st.write(bot_template.replace(
                "{{MSG}}", message.content), unsafe_allow_html=True)
def get_conversation_chain(llm, retriever):
    # Buffer memory stores the full chat history under the key the chain expects.
    memory = ConversationBufferMemory(
        memory_key='chat_history', return_messages=True)
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory
    )
    return conversation_chain
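# Unlike the one-shot RetrievalQA chain built in main(), this chain is
# stateful: buffered history is fed back into the prompt, so follow-up
# questions can refer to earlier answers.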
def main():
    # Parse the command line arguments
    args = parse_arguments()
    st.set_page_config(page_title="Chat with multiple PDFs",
                       page_icon=":books:")
    st.write(css, unsafe_allow_html=True)
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None
    st.header("Tsetlin LLM Powered Chatbot")
    user_question = st.text_input("Ask a question about Tsetlin Machine:")
    if user_question:
        handle_userinput(user_question)
    # Build the retriever and LLM once per session; reloading the model on
    # every Streamlit rerun would be needlessly slow.
    if st.session_state.conversation is None:
        embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
        chroma_client = chromadb.PersistentClient(settings=CHROMA_SETTINGS, path=persist_directory)
        db = Chroma(persist_directory=persist_directory, embedding_function=embeddings,
                    client_settings=CHROMA_SETTINGS, client=chroma_client)
        retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
        # Activate/deactivate the streaming StdOut callback for LLMs
        callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
        # Prepare the LLM. if/elif stands in for the match statement, which
        # requires Python 3.10+.
        if model_type == "LlamaCpp":
            llm = LlamaCpp(model_path=model_path, max_tokens=model_n_ctx,
                           n_batch=model_n_batch, callbacks=callbacks, verbose=False)
        elif model_type == "GPT4All":
            llm = GPT4All(model=model_path, max_tokens=model_n_ctx, backend='gptj',
                          n_batch=model_n_batch, callbacks=callbacks, verbose=False)
        else:
            raise ValueError(f"Model type {model_type} is not supported. "
                             "Use LlamaCpp or GPT4All.")
        # One-shot QA chain; currently unused by the UI, which goes through the
        # conversational chain below instead.
        qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever,
                                         return_source_documents=not args.hide_source)
        st.session_state.conversation = get_conversation_chain(llm, retriever)
def parse_arguments():
    parser = argparse.ArgumentParser(description='privateGPT: Ask questions to your documents without an internet connection, '
                                                 'using the power of LLMs.')
    parser.add_argument("--hide-source", "-S", action='store_true',
                        help='Use this flag to disable printing of source documents used for answers.')
    parser.add_argument("--mute-stream", "-M",
                        action='store_true',
                        help='Use this flag to disable the streaming StdOut callback for LLMs.')
    return parser.parse_args()
if __name__ == "__main__":
    main()
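# Launch with Streamlit; anything after "--" is forwarded to this script's
# argument parser (file name assumed):
#   streamlit run app.py -- --mute-stream --hide-source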