# Graduation/pipelines/filter.py
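"""
Interactive RAG pipeline for the Graduation chatbot.

For each question it runs: safety check -> intent classification (student
handbook / academy introduction / admissions / out-of-scope) -> semantic-cache
lookup -> hybrid BM25 + dense retrieval over Qdrant -> LLM listwise rerank ->
header-filtered context expansion -> final answer generation.
"""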
import os
import pickle
import time
from typing import List

from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from langchain_core.documents import Document
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.vectorstores import Qdrant
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_groq import ChatGroq
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
    EmbeddingsFilter,
    LLMChainExtractor,
    LLMChainFilter,
    LLMListwiseRerank,
)

# Project-local modules
from semantic_cache.main import SemanticCache
from Router.router import Evaluator
from blueprints.rag_utils import format_docs, translate
from blueprints.prompts import (
    QUERY_PROMPT,
    evaluator_intent,
    basic_template,
    chitchat_prompt,
    safe_prompt,
    cache_prompt,
)
from SafetyChecker import SafetyChecker
from BM25 import BM25SRetriever
# from utils.pipelines.main import get_last_user_message, add_or_update_system_message, pop_system_message
# from database_Routing import DB_Router
# import cohere
# from langchain_cohere import CohereRerank
load_dotenv()

# OPENAI_KEY must be provided via the environment or a .env file; fail fast
# rather than setting OPENAI_API_KEY to None.
openai_key = os.getenv('OPENAI_KEY')
if openai_key is None:
    raise RuntimeError("OPENAI_KEY is not set; add it to the environment or .env file.")
os.environ["OPENAI_API_KEY"] = openai_key
# COHERE_API_KEY is only required if the commented-out CohereRerank path is re-enabled.

# HF_EMBEDDING = HuggingFaceEmbeddings(model_name="dangvantuan/vietnamese-embedding")
HF_EMBEDDING = OpenAIEmbeddings(model='text-embedding-3-small', api_key=openai_key)


class LineListOutputParser(BaseOutputParser[List[str]]):
    """Output parser that splits LLM output into a list of non-empty lines."""

    def parse(self, text: str) -> List[str]:
        lines = text.strip().split("\n")
        return list(filter(None, lines))  # Remove empty lines
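# e.g. LineListOutputParser().parse("query A\n\nquery B\n") -> ["query A", "query B"]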


def add_or_update_system_message(content: str, messages: List[dict]):
    """
    Appends content to an existing leading system message, or inserts a new
    system message at the beginning of the messages list.

    :param content: The system content to add or append.
    :param messages: The list of message dictionaries.
    :return: The updated list of message dictionaries.
    """
    if messages and messages[0].get("role") == "system":
        messages[0]["content"] += f"{content}\n"
    else:
        # Insert at the beginning
        messages.insert(0, {"role": "system", "content": content})
    return messages
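# e.g. add_or_update_system_message("Answer in Vietnamese.", [{"role": "user", "content": "hi"}])
#      -> [{"role": "system", "content": "Answer in Vietnamese."}, {"role": "user", "content": "hi"}]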


def split_context(context):
    """Drop the free-text user question from a prompt, keeping the system
    instructions plus everything from the <context> tag onward."""
    split_index = context.find("User question")
    system_prompt = context[:split_index].strip()
    user_question = context[split_index:].strip()
    user_split_index = user_question.find("<context>")
    f_system_prompt = str(system_prompt) + "\n" + str(user_question[user_split_index:])
    return f_system_prompt


def extract_metadata(docs, headers=('Header_1', 'Header_2', 'Header_3')):
    """Collect each document's non-empty section headers (its header path)."""
    meta_data_docs = []
    for doc in docs:
        meta_data_doc = [doc.metadata[header] for header in headers if doc.metadata.get(header)]
        meta_data_docs.append(meta_data_doc)
    return meta_data_docs
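# e.g. a doc with metadata {"Header_1": "Quy định", "Header_2": "Học phí"}
#      contributes ["Quy định", "Học phí"]; missing header levels are skipped.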


def search_with_filter(query, vector_store, k, headers):
    """Similarity search restricted to documents whose metadata matches the
    given header path (Header_1, Header_2, ... in order)."""
    # Build one match condition per header level.
    conditions = [
        models.FieldCondition(
            key=f"metadata.Header_{i + 1}",
            match=models.MatchValue(value=header),
        )
        for i, header in enumerate(headers)
    ]
    # Run the query with all conditions required to match (AND semantics).
    single_result = vector_store.similarity_search(
        query=query,
        k=k,
        filter=models.Filter(must=conditions),
    )
    return single_result
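# Note: these payload filters are applied at query time; on large collections,
# adding payload indexes on metadata.Header_1/2/3 (client.create_payload_index)
# should make the filtered search faster.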


def get_relevant_documents(documents: List[Document], limit: int) -> List[Document]:
    """Deduplicate documents by page_content, keeping at most `limit` of them."""
    result = []
    seen = set()
    for doc in documents:
        if doc.page_content not in seen:
            result.append(doc)
            seen.add(doc.page_content)
        if len(result) == limit:
            break
    return result
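# e.g. get_relevant_documents([doc_a, doc_a, doc_b, doc_c], limit=2) -> [doc_a, doc_b]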


if __name__ == "__main__":
    client = QdrantClient(url="http://localhost:6333")

    # Dense retrievers: one Qdrant collection per knowledge source.
    stsv = Qdrant(client, collection_name="sotaysinhvien_filter", embeddings=HF_EMBEDDING)
    stsv_db = stsv.as_retriever(search_kwargs={'k': 10})
    gthv = Qdrant(client, collection_name="gioithieuhocvien_filter", embeddings=HF_EMBEDDING)
    gthv_db = gthv.as_retriever(search_kwargs={'k': 10})
    ttts = Qdrant(client, collection_name="thongtintuyensinh_filter", embeddings=HF_EMBEDDING)
    ttts_db = ttts.as_retriever(search_kwargs={'k': 10})

    # Sparse retrievers: BM25S indexes built from the pre-chunked documents.
    with open('data/sotaysinhvien_filter.pkl', 'rb') as f:
        sotaysinhvien = pickle.load(f)
    with open('data/thongtintuyensinh_filter.pkl', 'rb') as f:
        thongtintuyensinh = pickle.load(f)
    with open('data/gioithieuhocvien_filter.pkl', 'rb') as f:
        gioithieuhocvien = pickle.load(f)
    retriever_bm25_tuyensinh = BM25SRetriever.from_documents(thongtintuyensinh, k=10, save_directory="data/bm25s/ttts")
    retriever_bm25_sotay = BM25SRetriever.from_documents(sotaysinhvien, k=10, save_directory="data/bm25s/stsv")
    retriever_bm25_hocvien = BM25SRetriever.from_documents(gioithieuhocvien, k=10, save_directory="data/bm25s/gthv")
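    # save_directory presumably persists each BM25S index so later runs can
    # reload it instead of re-indexing (depends on the project-local BM25SRetriever).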

    # reranker = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
    llm = ChatGroq(model_name="llama3-70b-8192", temperature=0.1, api_key=os.getenv('llm_api_3'))
    llm2 = ChatGroq(model_name="llama-3.1-70b-versatile", temperature=1, api_key=os.getenv('llm_api_8'))
    output_parser = LineListOutputParser()
    llm_chain = QUERY_PROMPT | llm | output_parser
llm_chain = QUERY_PROMPT | llm | output_parser
# messages = [
# {"role": "system", "content": "Dựa vào thông tin sau, trả lời câu hỏi bằng tiếng việt"}
# ]
# ###########################
cache = SemanticCache()
another_chain = ( chitchat_prompt | llm2 | StrOutputParser())
safe_chain = ( safe_prompt | llm2 | StrOutputParser())
cache_chain = ( cache_prompt | llm2 | StrOutputParser())
# def duy_phen():

    # Interactive loop: one question per iteration.
    while True:
        user_message = input("Nhập câu hỏi nào!: ")

        # Step 1: safety check on the translated question; refuse if unsafe.
        checker = SafetyChecker()
        safety_result = checker.check_safety(translate(user_message))
        print("Safety check:", safety_result)
        if safety_result != 'safe':
            print("UNSAFE")
            response = safe_chain.invoke({'meaning': f'{safety_result}'})
            print(response)
            exit()

        # Step 2: classify the question's intent to pick a knowledge base.
        evaluator = Evaluator(llm="llama3-70b", prompt=evaluator_intent)
        output = evaluator.classify_text(user_message)
        print(output.result)
        retriever = None  # or assign a specific default retriever if applicable
        db = None  # initialize db as well if it is used later in the code
        source = None

        # Step 3: semantic-cache lookup; answer from the cache when it hits.
        cache_result = cache.checker(user_message)
        if cache_result is not None:
            print("###Cache hit!###")
            response = cache_chain.invoke({"question": f'{user_message}', "content": f"{cache_result}"})
            print(response)
            continue  # skip retrieval once the cache has answered

        # Step 4: route to the collection matching the classified intent.
        if output and output.result == 'OUT_OF_SCOPE':
            print('OUT OF SCOPE')
            response = another_chain.invoke({"question": f"{user_message}"})
            print(response)
        elif output and output.result == 'ASK_QUYDINH':
            print('SO TAY SINH VIEN DB')
            retriever = stsv_db
            retriever_bm25 = retriever_bm25_sotay
            source = stsv
            # db = sotaysinhvien
        elif output and output.result == 'ASK_HOCVIEN':
            print('GIOI THIEU HOC VIEN DB')
            retriever = gthv_db
            retriever_bm25 = retriever_bm25_hocvien
            source = gthv
            # db = gioithieuhocvien
        elif output and output.result == 'ASK_TUYENSINH':
            print('THONG TIN TUYEN SINH DB')
            retriever = ttts_db
            retriever_bm25 = retriever_bm25_tuyensinh
            source = ttts
            # db = thongtintuyensinh

        # Step 5: hybrid retrieval + rerank + header-filtered expansion.
        if retriever is not None:
            # retriever_multi = MultiQueryRetriever(
            #     retriever=retriever, llm_chain=llm_chain, parser_key="lines"
            # )
            start_time = time.time()
            # Equal-weight fusion of sparse (BM25) and dense retrieval.
            ensemble_retriever = EnsembleRetriever(
                retrievers=[retriever_bm25, retriever], weights=[0.5, 0.5])
            # compressor = LLMChainExtractor.from_llm(llm)
            # _filter = LLMChainFilter.from_llm(llm)
            # embeddings_filter = EmbeddingsFilter(embeddings=HF_EMBEDDING, similarity_threshold=0.5)
            # compression = ContextualCompressionRetriever(
            #     base_compressor=_filter, base_retriever=ensemble_retriever
            # )
            # Listwise LLM rerank keeps the 5 most relevant documents.
            reranker = LLMListwiseRerank.from_llm(llm=llm, top_n=5)
            tailieu = ensemble_retriever.invoke(f"{user_message}")
            docs = reranker.compress_documents(tailieu, user_message)
            end_time = time.time()

            # Filter again here -> pull in more closely related documents by
            # re-querying the collection restricted to the reranked docs' headers.
            # docs = compression.invoke(f"{user_message}")
            meta_data_docs = extract_metadata(docs)
            full_result = []
            for meta_data_doc in meta_data_docs:
                result = search_with_filter(user_message, source, 10, meta_data_doc)
                full_result.extend(result)
            print("Context liên quan" + '\n')
            print(full_result)
            # rag_chain = (
            #     {"context": compression | format_docs, "question": RunnablePassthrough()}
            #     | basic_template | llm2 | StrOutputParser()
            # )
            # Deduplicate, format, and generate the final answer.
            result_final = get_relevant_documents(full_result, 10)
            context = format_docs(result_final)
            best_chain = basic_template | llm2 | StrOutputParser()
            best_result = best_chain.invoke({"question": f'{user_message}', "context": f"{context}"})
            print(f'Câu trả lời tối ưu nhất: {best_result}')
            print(f'TIME USING : {end_time - start_time}')
        else:
            print('Retriever is not defined. Check output results and ensure retriever is assigned correctly.')