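"""Interactive RAG chatbot pipeline.

For each user question the script runs a safety check, classifies the intent to
pick a Qdrant collection, consults a semantic cache, performs hybrid BM25 + dense
retrieval with LLM listwise reranking, expands hits via header metadata, and
generates the final answer with a Groq-hosted LLM.
"""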
import os
import pickle
import time
from typing import List

from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from langchain_core.documents import Document
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.vectorstores import Qdrant
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_groq import ChatGroq
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
    EmbeddingsFilter,
    LLMChainExtractor,
    LLMChainFilter,
    LLMListwiseRerank,
)

from semantic_cache.main import SemanticCache
from Router.router import Evaluator
from SafetyChecker import SafetyChecker
from BM25 import BM25SRetriever
from blueprints.rag_utils import format_docs, translate
from blueprints.prompts import (
    QUERY_PROMPT,
    evaluator_intent,
    basic_template,
    chitchat_prompt,
    safe_prompt,
    cache_prompt,
)

load_dotenv()

os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_KEY') |
|
|
|
os.environ["COHERE_API_KEY"] |
|
|
|
HF_EMBEDDING = OpenAIEmbeddings(model='text-embedding-3-small', api_key = os.getenv('OPENAI_KEY')) |
|
class LineListOutputParser(BaseOutputParser[List[str]]):
    """Output parser for a list of lines."""

    def parse(self, text: str) -> List[str]:
        lines = text.strip().split("\n")
        return list(filter(None, lines))


def add_or_update_system_message(content: str, messages: List[dict]):
    """
    Append ``content`` to an existing leading system message, or insert a new
    system message at the beginning of the list if there is none.

    :param content: The text to add to the system message.
    :param messages: The list of message dictionaries.
    :return: The updated list of message dictionaries.
    """
    if messages and messages[0].get("role") == "system":
        messages[0]["content"] += f"{content}\n"
    else:
        messages.insert(0, {"role": "system", "content": content})
    return messages


def split_context(context):
    """Rebuild a prompt as system prompt + <context> block, dropping the 'User question' segment in between."""
    split_index = context.find("User question")
    system_prompt = context[:split_index].strip()
    user_question = context[split_index:].strip()
    user_split_index = user_question.find("<context>")
    f_system_prompt = str(system_prompt) + "\n" + str(user_question[user_split_index:])
    return f_system_prompt


def extract_metadata(docs, headers=('Header_1', 'Header_2', 'Header_3')):
    """Return, for each document, the list of header values (its section path) present in its metadata."""
    meta_data_docs = []
    for doc in docs:
        meta_data_doc = [doc.metadata[header] for header in headers if doc.metadata.get(header)]
        meta_data_docs.append(meta_data_doc)
    return meta_data_docs


def search_with_filter(query, vector_store, k, headers):
    """Similarity search restricted to documents whose Header_1..Header_3 metadata match ``headers``."""
    # One FieldCondition per provided header level; equivalent to the previous
    # one-branch-per-length version for up to three headers.
    conditions = [
        models.FieldCondition(
            key=f"metadata.Header_{i}",
            match=models.MatchValue(value=header),
        )
        for i, header in enumerate(headers[:3], start=1)
    ]

    single_result = vector_store.similarity_search(
        query=query,
        k=k,
        filter=models.Filter(must=conditions),
    )

    return single_result


def get_relevant_documents(documents: List[Document], limit: int) -> List[Document]:
    """Deduplicate documents by page content, keeping at most ``limit`` unique documents."""
    result = []
    seen = set()
    for doc in documents:
        if doc.page_content not in seen:
            result.append(doc)
            seen.add(doc.page_content)
        if len(result) == limit:
            break
    return result


if __name__ == "__main__":

    client = QdrantClient(url="http://localhost:6333")

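    # Dense retrievers over the three Qdrant collections:
    # sotaysinhvien (student handbook), gioithieuhocvien (institute introduction),
    # thongtintuyensinh (admissions information).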
    stsv = Qdrant(client, collection_name="sotaysinhvien_filter", embeddings=HF_EMBEDDING)
    stsv_db = stsv.as_retriever(search_kwargs={'k': 10})

    gthv = Qdrant(client, collection_name="gioithieuhocvien_filter", embeddings=HF_EMBEDDING)
    gthv_db = gthv.as_retriever(search_kwargs={'k': 10})

    ttts = Qdrant(client, collection_name="thongtintuyensinh_filter", embeddings=HF_EMBEDDING)
    ttts_db = ttts.as_retriever(search_kwargs={'k': 10})

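    # Load the pre-chunked documents and build the sparse (BM25) retriever for each collection.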
    with open('data/sotaysinhvien_filter.pkl', 'rb') as f:
        sotaysinhvien = pickle.load(f)
    with open('data/thongtintuyensinh_filter.pkl', 'rb') as f:
        thongtintuyensinh = pickle.load(f)
    with open('data/gioithieuhocvien_filter.pkl', 'rb') as f:
        gioithieuhocvien = pickle.load(f)

    retriever_bm25_tuyensinh = BM25SRetriever.from_documents(thongtintuyensinh, k=10, save_directory="data/bm25s/ttts")
    retriever_bm25_sotay = BM25SRetriever.from_documents(sotaysinhvien, k=10, save_directory="data/bm25s/stsv")
    retriever_bm25_hocvien = BM25SRetriever.from_documents(gioithieuhocvien, k=10, save_directory="data/bm25s/gthv")

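    # llm handles query generation and reranking; llm2 handles answer generation and chit-chat.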
    llm = ChatGroq(model_name="llama3-70b-8192", temperature=0.1, api_key=os.getenv('llm_api_3'))
    llm2 = ChatGroq(model_name="llama-3.1-70b-versatile", temperature=1, api_key=os.getenv('llm_api_8'))

    # Chain that parses the model's response to QUERY_PROMPT into a list of lines.
    output_parser = LineListOutputParser()
    llm_chain = QUERY_PROMPT | llm | output_parser

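    # Semantic cache plus the answer chains: chit-chat fallback, unsafe-query response, and cache-hit answering.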
    cache = SemanticCache()
    another_chain = chitchat_prompt | llm2 | StrOutputParser()
    safe_chain = safe_prompt | llm2 | StrOutputParser()
    cache_chain = cache_prompt | llm2 | StrOutputParser()

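    # Interactive loop: safety check -> intent routing -> semantic cache -> hybrid retrieval -> rerank -> answer.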
    while True:
        body = {}

        user_message = input("Enter your question!: ")

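        # Safety gate: translate the message, and on any unsafe label answer via safe_chain and stop.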
        checker = SafetyChecker()
        safety_result = checker.check_safety(translate(user_message))
        print("Safety check:", safety_result)
        if safety_result != 'safe':
            print("UNSAFE")
            response = safe_chain.invoke({'meaning': f'{safety_result}'})
            print(response)
            exit()

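        # Intent classification decides which knowledge base (if any) to query.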
        evaluator = Evaluator(llm="llama3-70b", prompt=evaluator_intent)
        output = evaluator.classify_text(user_message)
        print(output.result)

        retriever = None
        db = None
        source = None

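        # Semantic cache lookup: on a hit, answer directly from the cached content.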
        cache_result = cache.checker(user_message)
        if cache_result is not None:
            print("###Cache hit!###")
            response = cache_chain.invoke({"question": f'{user_message}', "content": f"{cache_result}"})
            print(response)

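        # Route the question to the matching collection based on the classified intent.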
        if output and output.result == 'OUT_OF_SCOPE':
            print('OUT OF SCOPE')
            response = another_chain.invoke({"question": f"{user_message}"})
            print(response)

        elif output and output.result == 'ASK_QUYDINH':
            print('SO TAY SINH VIEN DB')
            retriever = stsv_db
            retriever_bm25 = retriever_bm25_sotay
            source = stsv

        elif output and output.result == 'ASK_HOCVIEN':
            print('GIOI THIEU HOC VIEN DB')
            retriever = gthv_db
            retriever_bm25 = retriever_bm25_hocvien
            source = gthv

        elif output and output.result == 'ASK_TUYENSINH':
            print('THONG TIN TUYEN SINH DB')
            retriever = ttts_db
            retriever_bm25 = retriever_bm25_tuyensinh
            source = ttts

        if retriever is not None:
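            # Hybrid retrieval: merge BM25 (sparse) and Qdrant (dense) results with equal weights,
            # then keep the 5 most relevant documents via LLM listwise reranking.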
            start_time = time.time()
            ensemble_retriever = EnsembleRetriever(
                retrievers=[retriever_bm25, retriever], weights=[0.5, 0.5])

            reranker = LLMListwiseRerank.from_llm(llm=llm, top_n=5)
            tailieu = ensemble_retriever.invoke(f"{user_message}")
            docs = reranker.compress_documents(tailieu, user_message)
            end_time = time.time()

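            # Expand the reranked hits: for each document's header path, pull additional
            # chunks from the same section of the source collection.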
            meta_data_docs = extract_metadata(docs)

            full_result = []
            for meta_data_doc in meta_data_docs:
                result = search_with_filter(user_message, source, 10, meta_data_doc)
                for i in result:
                    full_result.append(i)
            print("Relevant context" + '\n')
            print(full_result)

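            # Deduplicate the expanded hits, format the context, and generate the final answer.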
            result_final = get_relevant_documents(full_result, 10)

            context = format_docs(result_final)

            best_chain = basic_template | llm2 | StrOutputParser()
            best_result = best_chain.invoke({"question": f'{user_message}', "context": f"{context}"})

            print(f'Best answer: {best_result}')
            print(f'TIME USED: {end_time - start_time}')
        else:
            print('Retriever is not defined. Check output results and ensure retriever is assigned correctly.')