# Spaces:
# Paused
# Paused
import logging
import os

from dotenv import load_dotenv
from langchain.chains import LLMChain, HypotheticalDocumentEmbedder, RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import DirectoryLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
## Setting up Log configuration | |
logging.basicConfig( | |
filename='Logs/chatbot.log', # Name of the log file | |
level=logging.INFO, # Logging level (you can use logging.DEBUG for more detailed logs) | |
format='%(asctime)s - %(levelname)s - %(message)s' | |
) | |
class Jine: | |
def __init__(self, OPENAI_API_KEY, VECTOR_STORE_DIRECTORY, VECTOR_STORE_CHECK, DATA_DIRECTORY, DEBUG,USE_HYDE=False): | |
self.OPENAI_API_KEY = OPENAI_API_KEY | |
self.DATA_DIRECTORY = DATA_DIRECTORY | |
self.VECTOR_STORE_DIRECTORY = VECTOR_STORE_DIRECTORY | |
self.VECTOR_STORE_CHECK = VECTOR_STORE_CHECK | |
# self.DEBUG = DEBUG | |
self.vectorstore = None | |
self.bot = None | |
def create_vectorstore(self): | |
if self.VECTOR_STORE_CHECK: | |
print("Loading Vectorstore") | |
self.load_vectorstore() | |
else: | |
print("Creating Vectorstore") | |
docs = DirectoryLoader(self.DATA_DIRECTORY).load() | |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=10) | |
all_splits = text_splitter.split_documents(docs) | |
self.vectorstore = Chroma.from_documents(documents=docs, embedding=OpenAIEmbeddings(), | |
persist_directory=self.VECTOR_STORE_DIRECTORY) | |
def load_vectorstore(self): | |
self.vectorstore = Chroma(persist_directory=self.VECTOR_STORE_DIRECTORY, embedding_function=OpenAIEmbeddings()) | |
def log(self, user_question, chatbot_reply): | |
# Log the user's question | |
logging.info(f"User: {user_question}") | |
# Log the chatbot's reply | |
logging.info(f"JIN-e: {chatbot_reply}") | |
def load_model(self): | |
self.create_vectorstore() | |
self.create_ensemble_retriever() | |
def chat(self, user_question): | |
result = self.bot({"query": user_question}) | |
response = result["result"] | |
self.log(user_question, response) | |
return response | |
### Adding Ensemble retriver | |
def create_ensemble_retriever(self): | |
template = """ | |
You are an Expert Policy Advisor.These Below are the Documents that are extracted from the different Policies.Your Job | |
is to Provide the Answer to below question based on the text below. | |
Here are few instructions for you to follow when answering a question. | |
- When you didnt find the relevant answers from below text Just Say "I dont know this,Please contact your HRBP for more details." | |
- These are policy Documents, When answering a question Do Not return in response that "This information is At Annex A/B".Provide a Complete response to request. | |
- Try to answer the questions in bullet format if possible. | |
- Use three sentences maximum to Answer the question in very concise manner | |
{context} | |
Question: {question} | |
Helpful Answer: | |
""" | |
QA_CHAIN_PROMPT = PromptTemplate.from_template(template) | |
print("====================="*10) | |
print("Loading Documents for Ensemble Retriver") | |
print("====================="*10) | |
docs = DirectoryLoader(self.DATA_DIRECTORY).load() | |
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=10) | |
# all_splits = text_splitter.split_documents(docs) | |
bm25_retriever = BM25Retriever.from_documents(docs) | |
# GEttting only two relevant documents | |
bm25_retriever.k = 2 | |
ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, | |
self.vectorstore.as_retriever(search_kwargs={"k": 2})], | |
weights=[0.5, 0.5]) | |
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0) | |
self.bot = RetrievalQA.from_chain_type( | |
llm, | |
retriever=ensemble_retriever, | |
chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}) | |
if __name__ == "__main__": | |
# Set your configuration here | |
load_dotenv() | |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") | |
DATA_DIRECTORY = os.getenv("DATA_DIRECTORY") | |
VECTOR_STORE_DIRECTORY = os.getenv("VECTOR_STORE_DIRCTORY") | |
VECTOR_STORE_CHECK = os.getenv("VECTOR_STORE_CHECK") | |
DEBUG = os.getenv("DEBUG") | |
USE_HYDE = os.getenv("USE_HYDE") | |
# Initialize Jine and start chatting | |
jine = Jine(OPENAI_API_KEY, VECTOR_STORE_DIRECTORY, VECTOR_STORE_CHECK, DATA_DIRECTORY, DEBUG) | |
# print(jine.VECTOR_STORE_CHECK) | |
jine.load_model() | |
while True: | |
user_question = input("You: ") | |
if user_question.lower() in ["exit", "quit"]: | |
break | |
response = jine.chat(user_question) | |
print("JIN-e:", response) | |