Spaces:
Sleeping
Sleeping
File size: 8,536 Bytes
94d6273 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
from llm_constants import LLM_MODEL_NAME, MAX_TOKENS, RERANKER_MODEL_NAME, EMBEDDINGS_MODEL_NAME, EMBEDDINGS_TOKENS_COST, INPUT_TOKENS_COST, OUTPUT_TOKENS_COST, COHERE_RERANKER_COST
from prompts import CHAT_PROMPT, TOOLS
import os
from langchain_openai import OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_community.retrievers import BM25Retriever
from typing import List, Dict, Sequence
from pydantic_models import RequestModel, ResponseModel, ChatHistoryItem, VectorStoreDocumentItem
import tiktoken
from dotenv import load_dotenv
load_dotenv()
from langchain_community.vectorstores import FAISS
import anthropic
import cohere
class RAGChatBot:
    """Retrieval-augmented chatbot: Claude answers, optionally calling a document retriever tool.

    Retrieval is hybrid: dense similarity search over a local FAISS vector store
    plus BM25 keyword search over the same documents. The merged candidates are
    reranked with Cohere, and the top results are returned to Claude as the
    result of the ``Documents_Retriever`` tool.
    """

    # API keys are read from the environment in ``__init__`` (``load_dotenv()``
    # runs at import time, so a ``.env`` file has already been loaded by then).
    __cohere_api_key = None
    __anthropic_api_key = None
    __openai_api_key = None
    # Created lazily in ``set_base_retriever`` so that merely importing this
    # module does not build an OpenAI client before the key has been validated.
    __embedding_function = None
    __base_retriever = None
    __bm25_retriever = None
    anthropic_client = None
    cohere_client = None
    # Number of reranked documents kept per search.
    top_n: int = 3
    # NOTE(review): not referenced anywhere in this class; the full chat history
    # is sent to the model. Confirm whether callers are expected to truncate it.
    chat_history_length: int = 10
    # Number of candidates each retriever (FAISS and BM25) returns before reranking.
    retriever_k: int = 25

    def __init__(self, vectorstore_path: str, top_n: int = 3):
        """Validate configuration, load the vector store and create API clients.

        Args:
            vectorstore_path: Directory of a FAISS index saved with ``save_local``.
            top_n: Number of documents kept after Cohere reranking (must be >= 1).

        Raises:
            ValueError: If a required API key is missing or ``top_n`` is invalid.
        """
        self.__cohere_api_key = os.getenv("COHERE_API_KEY")
        self.__anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
        self.__openai_api_key = os.getenv("OPENAI_API_KEY")
        if self.__cohere_api_key is None:
            raise ValueError("COHERE_API_KEY must be set in the environment")
        if self.__anthropic_api_key is None:
            raise ValueError("ANTHROPIC_API_KEY must be set in the environment")
        if self.__openai_api_key is None:
            raise ValueError("OPENAI_API_KEY must be set in the environment")
        if not isinstance(top_n, int):
            raise ValueError("top_n must be an integer")
        if top_n < 1:
            raise ValueError("top_n must be at least 1")
        self.top_n = top_n
        self.set_base_retriever(vectorstore_path)
        self.set_anthropic_client()
        self.set_cohere_client()

    def set_base_retriever(self, vectorstore_path: str):
        """Load the FAISS store and build the dense and BM25 retrievers over its documents."""
        if self.__embedding_function is None:
            self.__embedding_function = OpenAIEmbeddings(model=EMBEDDINGS_MODEL_NAME)
        # FAISS.load_local unpickles stored metadata, hence the explicit opt-in;
        # only load vector stores from trusted paths.
        db = FAISS.load_local(vectorstore_path, self.__embedding_function, allow_dangerous_deserialization=True)
        self.__base_retriever = db.as_retriever(search_kwargs={"k": self.retriever_k})
        # BM25 is built over every document held in the FAISS docstore.
        all_documents = list(db.docstore.__dict__.get('_dict').values())
        self.__bm25_retriever = BM25Retriever.from_documents(all_documents, k=self.retriever_k)

    def set_anthropic_client(self):
        """Create the Anthropic client from the configured API key."""
        self.anthropic_client = anthropic.Anthropic(api_key=self.__anthropic_api_key)

    def set_cohere_client(self):
        """Create the Cohere client from the configured API key."""
        self.cohere_client = cohere.Client(self.__cohere_api_key)

    def make_llm_api_call(self, messages: list):
        """Send ``messages`` to Claude (deterministic, temperature 0) with the retriever tool enabled."""
        return self.anthropic_client.messages.create(
            model=LLM_MODEL_NAME,
            max_tokens=MAX_TOKENS,
            temperature=0,
            messages=messages,
            tools=TOOLS
        )

    def make_rerank_api_call(self, search_phrase: str, documents: Sequence[str]):
        """Rerank ``documents`` against ``search_phrase`` with Cohere, keeping ``self.top_n``."""
        return self.cohere_client.rerank(query=search_phrase, documents=documents, model=RERANKER_MODEL_NAME, top_n=self.top_n)

    def retrieve_documents(self, search_phrase: str):
        """Return BM25 hits followed by similarity hits, without duplicates, in first-seen order."""
        similarity_documents = self.__base_retriever.invoke(search_phrase)
        bm25_documents = self.__bm25_retriever.invoke(search_phrase)
        # List-based dedup: Documents are compared by equality and may not be hashable.
        unique_docs = []
        for doc in [*bm25_documents, *similarity_documents]:
            if doc not in unique_docs:
                unique_docs.append(doc)
        return unique_docs

    def retrieve_and_rerank(self, search_phrase: str):
        """Retrieve candidates for ``search_phrase`` and return the Cohere-reranked top results.

        Returns:
            A list of ``VectorStoreDocumentItem`` in reranker order (empty if
            nothing was retrieved).
        """
        # Filter BEFORE building the rerank payload so that the reranker's
        # ``res.index`` refers to the same list we index into below.
        documents = [doc for doc in self.retrieve_documents(search_phrase) if isinstance(doc, Document)]
        if not documents:  # avoid an empty rerank API call
            return []
        api_result = self.make_rerank_api_call(search_phrase, [doc.page_content for doc in documents])
        reranked_docs = []
        for res in api_result.results:
            doc = documents[res.index]
            reranked_docs.append(
                VectorStoreDocumentItem(
                    page_content=doc.page_content,
                    filename=doc.metadata['filename'],
                    heading=doc.metadata['heading'],
                    relevance_score=res.relevance_score,
                )
            )
        return reranked_docs

    def get_context_and_docs(self, search_phrase: str):
        """Return ``(context, docs)``: the reranked docs and a text block built from them for the LLM."""
        docs = self.retrieve_and_rerank(search_phrase)
        context = "\n\n\n".join(
            f"Filename:{doc.filename}\nHeading:{doc.heading}\n\n{doc.page_content}" for doc in docs
        )
        return context, docs

    def get_tool_use_assistant_message(self, tool_use_block):
        """Wrap the assistant's content (containing the tool_use request) as a message."""
        return {'role': 'assistant',
                'content': tool_use_block
                }

    def get_tool_use_user_message(self, tool_use_id, context):
        """Build a user message carrying a single ``tool_result`` for ``tool_use_id``."""
        return {'role': 'user',
                'content': [{'type': 'tool_result',
                             'tool_use_id': tool_use_id,
                             'content': context}]}

    def process_tool_call(self, tool_name, tool_input):
        """Execute a tool requested by the model.

        Returns:
            ``(sources_list, search_phrase, tool_result_text)``. For an unknown
            tool or a missing ``search_phrase`` the sources are empty, the
            phrase is ``""`` and the text is an error message for the model.
        """
        if tool_name != "Documents_Retriever":
            return [], "", f"Error: unknown tool '{tool_name}'."
        search_phrase = (tool_input or {}).get("search_phrase", "")
        if not search_phrase:
            return [], "", "Error: 'search_phrase' is required for Documents_Retriever."
        context, sources_list = self.get_context_and_docs(search_phrase)
        return sources_list, search_phrase, context

    def calculate_cost(self, input_tokens, output_tokens, search_phrase):
        """Estimate the request cost from token counts and the constants' per-million-token prices.

        When ``search_phrase`` is non-empty, one query embedding (counted with
        the ``cl100k_base`` tokenizer) and one rerank call are added.
        """
        MILLION = 1000000
        if search_phrase:
            enc = tiktoken.get_encoding("cl100k_base")
            query_encode = enc.encode(search_phrase)
            embeddings_cost = len(query_encode) * (EMBEDDINGS_TOKENS_COST/MILLION)
            total_cost = embeddings_cost + COHERE_RERANKER_COST + (input_tokens*(INPUT_TOKENS_COST/MILLION)) + (output_tokens*(OUTPUT_TOKENS_COST/MILLION))
        else:
            total_cost = (input_tokens*(INPUT_TOKENS_COST/MILLION)) + (output_tokens*(OUTPUT_TOKENS_COST/MILLION))
        return total_cost

    def chat_with_claude(self, user_message_history: list):
        """Run the Claude tool-use loop until the model stops requesting tools.

        Args:
            user_message_history: Initial message list; it is copied, not mutated.

        Returns:
            ``(documents_list, search_phrase, answer, total_cost)`` where the
            documents and phrase come from the most recent successful retrieval
            and ``answer`` is the text inside ``<answer>`` tags when present
            (``""`` if the final message has no text block).
        """
        # Work on a copy so the caller's list is not extended with tool turns.
        messages = list(user_message_history)
        message = self.make_llm_api_call(messages)
        input_tokens = message.usage.input_tokens
        output_tokens = message.usage.output_tokens
        documents_list = []
        search_phrase = ""
        while message.stop_reason == "tool_use":
            # The API requires a tool_result for EVERY tool_use block in the turn.
            tool_results = []
            for block in message.content:
                if block.type != "tool_use":
                    continue
                sources, phrase, tool_result = self.process_tool_call(block.name, block.input)
                if phrase:
                    documents_list, search_phrase = sources, phrase
                tool_results.extend(self.get_tool_use_user_message(block.id, tool_result)["content"])
            if not tool_results:
                break
            messages.append(self.get_tool_use_assistant_message(message.content))
            messages.append({'role': 'user', 'content': tool_results})
            message = self.make_llm_api_call(messages)
            input_tokens += message.usage.input_tokens
            output_tokens += message.usage.output_tokens
        answer = next(
            (block.text for block in message.content if hasattr(block, "text")),
            None,
        ) or ""
        if "<answer>" in answer:
            answer = answer.split("<answer>")[1].split("</answer>")[0].strip()
        # NOTE(review): cost counts a single embedding + rerank even when several
        # searches happened in one conversation turn.
        total_cost = self.calculate_cost(input_tokens, output_tokens, search_phrase)
        return (documents_list, search_phrase, answer, total_cost)

    def get_chat_history_text(self, chat_history: List[ChatHistoryItem]):
        """Render chat history as alternating ``User:`` / ``Assistant:`` lines."""
        lines = [
            f"User: {item.user_message}\nAssistant: {item.assistant_message}"
            for item in chat_history
        ]
        return "\n".join(lines).strip()

    def get_response(self, input: RequestModel) -> ResponseModel:
        """Answer ``input.user_question`` and return the answer, sources and updated history.

        ``input`` shadows the builtin but is kept for keyword-argument compatibility.
        """
        chat_history = self.get_chat_history_text(input.chat_history)
        user_question = input.user_question
        user_prompt = CHAT_PROMPT.format(CHAT_HISTORY=chat_history, USER_QUESTION=user_question)
        if input.use_tool:
            user_prompt = f"{user_prompt}\nUse Documents_Retriever tool in your response."
        sources_list, search_phrase, answer, _ = self.chat_with_claude(
            [{"role": "user", "content": [{"type": "text", "text": user_prompt}]}]
        )
        updated_chat_history = input.chat_history.copy()
        updated_chat_history.append(ChatHistoryItem(user_message=user_question, assistant_message=answer))
        return ResponseModel(answer=answer, sources_documents=sources_list, chat_history=updated_chat_history, search_phrase=search_phrase)
|