# jin-e / jine.py
# hamxahbhattii's picture
# added Jine
# 6330947
# raw
# history blame
# 5.49 kB
import logging
import os
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import DirectoryLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from dotenv import load_dotenv
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.llms import OpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.chains import LLMChain, HypotheticalDocumentEmbedder
## Setting up Log configuration
# Ensure the log directory exists first: logging.basicConfig raises
# FileNotFoundError if 'Logs/' is missing when it opens the log file.
os.makedirs('Logs', exist_ok=True)
logging.basicConfig(
    filename='Logs/chatbot.log',  # Name of the log file
    level=logging.INFO,  # Logging level (you can use logging.DEBUG for more detailed logs)
    format='%(asctime)s - %(levelname)s - %(message)s'
)
class Jine:
    """Retrieval-augmented policy chatbot.

    Builds (or loads) a Chroma vector store over documents in
    ``DATA_DIRECTORY`` and answers questions through a RetrievalQA chain
    whose retriever is a BM25 + vector-store ensemble.
    """

    def __init__(self, OPENAI_API_KEY, VECTOR_STORE_DIRECTORY, VECTOR_STORE_CHECK, DATA_DIRECTORY, DEBUG, USE_HYDE=False):
        """Store configuration; no I/O happens until load_model() is called.

        :param OPENAI_API_KEY: OpenAI key (note: the langchain clients read the
            key from the environment; this attribute is kept for reference only).
        :param VECTOR_STORE_DIRECTORY: directory Chroma persists to / loads from.
        :param VECTOR_STORE_CHECK: truthy -> load an existing store,
            falsy -> build a fresh one from DATA_DIRECTORY.
        :param DATA_DIRECTORY: directory of source policy documents.
        :param DEBUG: debug flag (stored but not used internally yet).
        :param USE_HYDE: HyDE toggle (stored but not used internally yet).
        """
        self.OPENAI_API_KEY = OPENAI_API_KEY
        self.DATA_DIRECTORY = DATA_DIRECTORY
        self.VECTOR_STORE_DIRECTORY = VECTOR_STORE_DIRECTORY
        self.VECTOR_STORE_CHECK = VECTOR_STORE_CHECK
        # Previously accepted but silently dropped; keep them so callers can
        # inspect the configuration they passed in.
        self.DEBUG = DEBUG
        self.USE_HYDE = USE_HYDE
        self.vectorstore = None  # set by create_vectorstore()/load_vectorstore()
        self.bot = None  # RetrievalQA chain, set by create_ensemble_retriever()

    def create_vectorstore(self):
        """Load the persisted vector store, or build one from the data directory."""
        if self.VECTOR_STORE_CHECK:
            print("Loading Vectorstore")
            self.load_vectorstore()
        else:
            print("Creating Vectorstore")
            docs = DirectoryLoader(self.DATA_DIRECTORY).load()
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=10)
            all_splits = text_splitter.split_documents(docs)
            # Bug fix: index the chunked splits, not the raw documents.
            # Previously `documents=docs` was passed, which made the text
            # splitter above dead code and embedded whole documents.
            self.vectorstore = Chroma.from_documents(documents=all_splits, embedding=OpenAIEmbeddings(),
                                                     persist_directory=self.VECTOR_STORE_DIRECTORY)

    def load_vectorstore(self):
        """Open the existing Chroma store persisted at VECTOR_STORE_DIRECTORY."""
        self.vectorstore = Chroma(persist_directory=self.VECTOR_STORE_DIRECTORY, embedding_function=OpenAIEmbeddings())

    def log(self, user_question, chatbot_reply):
        """Write one question/answer exchange to the log file."""
        # Log the user's question
        logging.info(f"User: {user_question}")
        # Log the chatbot's reply
        logging.info(f"JIN-e: {chatbot_reply}")

    def load_model(self):
        """Initialise the vector store and the QA chain; call before chat()."""
        self.create_vectorstore()
        self.create_ensemble_retriever()

    def chat(self, user_question):
        """Answer one question via the RetrievalQA chain; logs and returns the reply."""
        result = self.bot({"query": user_question})
        response = result["result"]
        self.log(user_question, response)
        return response

    ### Adding Ensemble retriver
    def create_ensemble_retriever(self):
        """Build the BM25 + vector ensemble retriever and the RetrievalQA chain.

        Requires self.vectorstore to be set (see create_vectorstore).
        Sets self.bot.
        """
        # Runtime prompt text — kept verbatim.
        template = """
        You are an Expert Policy Advisor.These Below are the Documents that are extracted from the different Policies.Your Job
        is to Provide the Answer to below question based on the text below.
        Here are few instructions for you to follow when answering a question.
        - When you didnt find the relevant answers from below text Just Say "I dont know this,Please contact your HRBP for more details."
        - These are policy Documents, When answering a question Do Not return in response that "This information is At Annex A/B".Provide a Complete response to request.
        - Try to answer the questions in bullet format if possible.
        - Use three sentences maximum to Answer the question in very concise manner
        {context}
        Question: {question}
        Helpful Answer:
        """
        QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
        print("====================="*10)
        print("Loading Documents for Ensemble Retriver")
        print("====================="*10)
        # NOTE(review): documents are re-loaded here (already loaded in
        # create_vectorstore when building); acceptable but duplicated work.
        docs = DirectoryLoader(self.DATA_DIRECTORY).load()
        bm25_retriever = BM25Retriever.from_documents(docs)
        # Getting only two relevant documents from each retriever.
        bm25_retriever.k = 2
        ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever,
                                                           self.vectorstore.as_retriever(search_kwargs={"k": 2})],
                                               weights=[0.5, 0.5])
        llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
        self.bot = RetrievalQA.from_chain_type(
            llm,
            retriever=ensemble_retriever,
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT})
if __name__ == "__main__":
    # Set your configuration here (read from a .env file / the environment).
    load_dotenv()
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
    DATA_DIRECTORY = os.getenv("DATA_DIRECTORY")
    # NOTE(review): the env key is spelled "VECTOR_STORE_DIRCTORY" (sic).
    # Kept as-is because existing .env files presumably use this key —
    # confirm before renaming.
    VECTOR_STORE_DIRECTORY = os.getenv("VECTOR_STORE_DIRCTORY")
    # Bug fix: os.getenv returns a string, and ANY non-empty string
    # (including "False" or "0") is truthy, so the "load existing store"
    # branch was always taken once the variable was set. Parse it into a
    # real boolean instead.
    VECTOR_STORE_CHECK = (os.getenv("VECTOR_STORE_CHECK") or "").strip().lower() in ("1", "true", "yes")
    DEBUG = os.getenv("DEBUG")
    USE_HYDE = os.getenv("USE_HYDE")
    # Initialize Jine and start chatting
    jine = Jine(OPENAI_API_KEY, VECTOR_STORE_DIRECTORY, VECTOR_STORE_CHECK, DATA_DIRECTORY, DEBUG)
    jine.load_model()
    # Simple REPL; "exit"/"quit" (any case) ends the session.
    while True:
        user_question = input("You: ")
        if user_question.lower() in ["exit", "quit"]:
            break
        response = jine.chat(user_question)
        print("JIN-e:", response)