# contextual_retrieval/ragchatbot.py
# Uploader: viboognesh
# Commit: "Upload folder using huggingface_hub" (94d6273, verified)
from llm_constants import LLM_MODEL_NAME, MAX_TOKENS, RERANKER_MODEL_NAME, EMBEDDINGS_MODEL_NAME, EMBEDDINGS_TOKENS_COST, INPUT_TOKENS_COST, OUTPUT_TOKENS_COST, COHERE_RERANKER_COST
from prompts import CHAT_PROMPT, TOOLS
import os
from langchain_openai import OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_community.retrievers import BM25Retriever
from typing import List, Dict, Sequence
from pydantic_models import RequestModel, ResponseModel, ChatHistoryItem, VectorStoreDocumentItem
import tiktoken
from dotenv import load_dotenv
load_dotenv()
from langchain_community.vectorstores import FAISS
import anthropic
import cohere
class RAGChatBot:
    """Answer questions with Claude, backed by a two-retriever document search.

    A search phrase is sent to both a FAISS similarity retriever and a BM25
    retriever; the merged candidates are reranked with Cohere and handed to
    Claude as the result of its ``Documents_Retriever`` tool call.
    """

    # Credentials are read from the environment in ``__init__``.
    __cohere_api_key = None
    __anthropic_api_key = None
    __openai_api_key = None
    # Built lazily in ``set_base_retriever`` so that importing this module
    # never constructs an embeddings client before the key checks have run.
    __embedding_function = None
    __base_retriever = None
    __bm25_retriever = None
    anthropic_client = None
    cohere_client = None
    # Number of documents requested from the reranker.
    top_n: int = 3
    # NOTE(review): not read anywhere in this class; presumably meant to cap
    # the chat history that is sent to the model -- confirm.
    chat_history_length: int = 10
    # Number of candidates each retriever returns before reranking.
    candidate_k: int = 25

    def __init__(self, vectorstore_path: str, top_n: int = 3):
        """Validate the environment, then load the retrievers and API clients.

        Args:
            vectorstore_path: Directory passed to ``FAISS.load_local``.
            top_n: Number of documents to request from the reranker.

        Raises:
            ValueError: If one of the three API keys is missing from the
                environment, or if ``top_n`` is not an integer.
        """
        self.__cohere_api_key = os.getenv("COHERE_API_KEY")
        self.__anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
        self.__openai_api_key = os.getenv("OPENAI_API_KEY")
        if self.__cohere_api_key is None:
            raise ValueError("COHERE_API_KEY must be set in the environment")
        if self.__anthropic_api_key is None:
            raise ValueError("ANTHROPIC_API_KEY must be set in the environment")
        if self.__openai_api_key is None:
            raise ValueError("OPENAI_API_KEY must be set in the environment")
        if not isinstance(top_n, int):
            raise ValueError("top_n must be an integer")
        self.top_n = top_n
        self.set_base_retriever(vectorstore_path)
        self.set_anthropic_client()
        self.set_cohere_client()

    def set_base_retriever(self, vectorstore_path: str) -> None:
        """Load the FAISS index and build both retrievers from its documents."""
        if self.__embedding_function is None:
            self.__embedding_function = OpenAIEmbeddings(model=EMBEDDINGS_MODEL_NAME)
        # SECURITY: allow_dangerous_deserialization=True unpickles the stored
        # index; only point this at vector stores from a trusted source.
        db = FAISS.load_local(
            vectorstore_path,
            self.__embedding_function,
            allow_dangerous_deserialization=True,
        )
        self.__base_retriever = db.as_retriever(search_kwargs={"k": self.candidate_k})
        # The BM25 index is built over the same documents the FAISS docstore
        # holds, read from the docstore's private ``_dict`` mapping.
        stored_documents = list(db.docstore._dict.values())
        self.__bm25_retriever = BM25Retriever.from_documents(stored_documents, k=self.candidate_k)

    def set_anthropic_client(self) -> None:
        """Create the Anthropic client from the stored API key."""
        self.anthropic_client = anthropic.Anthropic(api_key=self.__anthropic_api_key)

    def set_cohere_client(self) -> None:
        """Create the Cohere client from the stored API key."""
        self.cohere_client = cohere.Client(self.__cohere_api_key)

    def make_llm_api_call(self, messages: list):
        """Send ``messages`` to Claude with the configured tools and return the reply."""
        return self.anthropic_client.messages.create(
            model=LLM_MODEL_NAME,
            max_tokens=MAX_TOKENS,
            temperature=0,
            messages=messages,
            tools=TOOLS,
        )

    def make_rerank_api_call(self, search_phrase: str, documents: Sequence[str]):
        """Ask Cohere to rerank ``documents`` against ``search_phrase``, keeping ``top_n``."""
        return self.cohere_client.rerank(
            query=search_phrase,
            documents=documents,
            model=RERANKER_MODEL_NAME,
            top_n=self.top_n,
        )

    def retrieve_documents(self, search_phrase: str) -> list:
        """Return the de-duplicated union of BM25 and similarity results.

        BM25 hits come first, followed by similarity hits that are not equal
        to a document already collected.
        """
        similarity_documents = self.__base_retriever.invoke(search_phrase)
        bm25_documents = self.__bm25_retriever.invoke(search_phrase)
        unique_docs = []
        for doc in [*bm25_documents, *similarity_documents]:
            if doc not in unique_docs:
                unique_docs.append(doc)
        return unique_docs

    def retrieve_and_rerank(self, search_phrase: str) -> "List[VectorStoreDocumentItem]":
        """Retrieve candidates and return them in the reranker's order.

        Returns:
            One ``VectorStoreDocumentItem`` per rerank result, or an empty
            list when there is nothing to rerank.
        """
        # Filter once and index this same list below, so the positions the
        # reranker reports always refer to the documents that were sent.
        documents = [
            doc for doc in self.retrieve_documents(search_phrase) if isinstance(doc, Document)
        ]
        if not documents:  # avoid an empty rerank API call
            return []
        api_result = self.make_rerank_api_call(
            search_phrase, [doc.page_content for doc in documents]
        )
        reranked_docs = []
        for res in api_result.results:
            doc = documents[res.index]
            reranked_docs.append(
                VectorStoreDocumentItem(
                    page_content=doc.page_content,
                    filename=doc.metadata['filename'],
                    heading=doc.metadata['heading'],
                    relevance_score=res.relevance_score,
                )
            )
        return reranked_docs

    def get_context_and_docs(self, search_phrase: str) -> tuple:
        """Return ``(context, docs)``: the prompt text built from the reranked docs, and the docs."""
        docs = self.retrieve_and_rerank(search_phrase)
        # NOTE(review): the label reads "Filename" but the value is the
        # document heading -- confirm which one the prompt is meant to show.
        context = "\n\n\n".join([f"Filename:{doc.heading}\n\n{doc.page_content}" for doc in docs])
        return context, docs

    def get_tool_use_assistant_message(self, tool_use_block) -> dict:
        """Wrap the model's content blocks as an assistant message."""
        return {'role': 'assistant', 'content': tool_use_block}

    def get_tool_use_user_message(self, tool_use_id, context) -> dict:
        """Wrap ``context`` as the ``tool_result`` user message for ``tool_use_id``."""
        return {
            'role': 'user',
            'content': [
                {'type': 'tool_result', 'tool_use_id': tool_use_id, 'content': context}
            ],
        }

    def process_tool_call(self, tool_name, tool_input) -> tuple:
        """Run the requested tool.

        Returns:
            ``(sources_list, search_phrase, context)`` for ``Documents_Retriever``.

        Raises:
            ValueError: If ``tool_name`` is not a tool this class implements.
        """
        if tool_name != "Documents_Retriever":
            # Fail with a clear message instead of returning None, which the
            # caller would trip over while unpacking three values.
            raise ValueError(f"Unsupported tool requested: {tool_name!r}")
        search_phrase = tool_input["search_phrase"]
        context, sources_list = self.get_context_and_docs(search_phrase)
        return sources_list, search_phrase, context

    def calculate_cost(self, input_tokens, output_tokens, search_phrase) -> float:
        """Estimate the cost of a chat turn from the configured per-million prices.

        Embedding and reranker costs are added only when ``search_phrase`` is
        non-empty, i.e. when a retrieval took place.
        """
        MILLION = 1_000_000
        input_cost = input_tokens * (INPUT_TOKENS_COST / MILLION)
        output_cost = output_tokens * (OUTPUT_TOKENS_COST / MILLION)
        if not search_phrase:
            return input_cost + output_cost
        enc = tiktoken.get_encoding("cl100k_base")
        embeddings_cost = len(enc.encode(search_phrase)) * (EMBEDDINGS_TOKENS_COST / MILLION)
        return embeddings_cost + COHERE_RERANKER_COST + input_cost + output_cost

    def chat_with_claude(self, user_message_history: list) -> tuple:
        """Run the conversation, serving tool calls until Claude stops asking for them.

        ``user_message_history`` is extended in place with every tool-use
        exchange.

        Returns:
            ``(documents_list, search_phrase, answer, total_cost)``. The
            documents and search phrase are those of the last tool call (an
            empty list and ``""`` when no tool was used); ``answer`` is ``""``
            when the final reply has no text block.
        """
        message = self.make_llm_api_call(user_message_history)
        input_tokens = message.usage.input_tokens
        output_tokens = message.usage.output_tokens
        documents_list = []
        search_phrase = ""
        while message.stop_reason == "tool_use":
            # A single reply may hold several tool_use blocks, and each one
            # needs its own tool_result in the next user message.
            tool_results = []
            for block in message.content:
                if block.type != "tool_use":
                    continue
                documents_list, search_phrase, tool_result = self.process_tool_call(
                    block.name, block.input
                )
                tool_results.extend(
                    self.get_tool_use_user_message(block.id, tool_result)["content"]
                )
            if not tool_results:
                # Nothing to answer; do not send an empty user message.
                break
            user_message_history.append(self.get_tool_use_assistant_message(message.content))
            user_message_history.append({'role': 'user', 'content': tool_results})
            message = self.make_llm_api_call(user_message_history)
            input_tokens += message.usage.input_tokens
            output_tokens += message.usage.output_tokens
        answer = next(
            (block.text for block in message.content if hasattr(block, "text")),
            None,
        )
        if answer is None:
            # No text block in the reply: fall back to an empty answer rather
            # than failing on the substring test below.
            answer = ""
        if "<answer>" in answer:
            answer = answer.split("<answer>")[1].split("</answer>")[0].strip()
        # NOTE(review): only the last search phrase is costed, even when
        # several retrievals ran during the conversation.
        total_cost = self.calculate_cost(input_tokens, output_tokens, search_phrase)
        return (documents_list, search_phrase, answer, total_cost)

    def get_chat_history_text(self, chat_history: "List[ChatHistoryItem]") -> str:
        """Render the history as alternating ``User:`` / ``Assistant:`` lines."""
        return "".join(
            f"User: {chat_message.user_message}\nAssistant: {chat_message.assistant_message}\n"
            for chat_message in chat_history
        ).strip()

    def get_response(self, input: "RequestModel") -> "ResponseModel":
        """Answer ``input.user_question`` and return the answer with the updated history.

        When ``input.use_tool`` is set, the prompt additionally instructs the
        model to use the ``Documents_Retriever`` tool.
        """
        chat_history = self.get_chat_history_text(input.chat_history)
        user_question = input.user_question
        user_prompt = CHAT_PROMPT.format(CHAT_HISTORY=chat_history, USER_QUESTION=user_question)
        if input.use_tool:
            user_prompt = f"{user_prompt}\nUse Documents_Retriever tool in your response."
        messages = [{"role": "user", "content": [{"type": "text", "text": user_prompt}]}]
        sources_list, search_phrase, answer, _ = self.chat_with_claude(messages)
        # Copy so the caller's history list is not modified.
        updated_chat_history = input.chat_history.copy()
        updated_chat_history.append(
            ChatHistoryItem(user_message=user_question, assistant_message=answer)
        )
        return ResponseModel(
            answer=answer,
            sources_documents=sources_list,
            chat_history=updated_chat_history,
            search_phrase=search_phrase,
        )