File size: 5,474 Bytes
74b1bac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import os
import time
from typing import List
from qdrant_client import QdrantClient, models
from langchain_core.documents import Document
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from blueprints.rag_utils import format_docs
from blueprints.prompts import QUERY_PROMPT, evaluator_intent, basic_template, chitchat_prompt, safe_prompt, cache_prompt
from langchain.retrievers import EnsembleRetriever
from BM25 import BM25SRetriever
from langchain_mistralai import ChatMistralAI
# from database_Routing import DB_Router
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers import BaseOutputParser
# from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
# from langchain_groq import ChatGroq
import time
from qdrant_client import QdrantClient
from langchain_community.vectorstores import Qdrant
from langchain.retrievers.document_compressors import LLMChainFilter
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers.document_compressors import LLMListwiseRerank
from langchain_openai import OpenAIEmbeddings
from dotenv import load_dotenv
load_dotenv()
# NOTE(review): despite the HF_ (HuggingFace) prefix, this is an OpenAI embedding
# model — the name is historical and misleading; consider renaming in a follow-up.
HF_EMBEDDING = OpenAIEmbeddings(model='text-embedding-3-small')
class LineListOutputParser(BaseOutputParser[List[str]]):
    """Parse an LLM response into a list of non-empty lines.

    Used downstream by the multi-query retriever, which expects one
    reformulated query per output line.
    """

    def parse(self, text: str) -> List[str]:
        # Strip outer whitespace first, then drop any blank entries left by
        # consecutive newlines inside the response.
        return [line for line in text.strip().split("\n") if line]
# def extract_metadata(docs, headers=('Header_1', 'Header_2', 'Header_3')):
# meta_data_docs = []
# for doc in docs:
# meta_data_doc = [doc.metadata[header] for header in headers if doc.metadata.get(header)]
# meta_data_docs.append(meta_data_doc)
# return meta_data_docs
# def search_with_filter(query, vector_store, k, headers):
# conditions = [
# models.FieldCondition(
# key="metadata.Header_1",
# match=models.MatchValue(
# value=headers[0]
# ),
# ),
# models.FieldCondition(
# key="metadata.Header_2",
# match=models.MatchValue(
# value=headers[1]
# ),
# ),
# ]
# if len(headers) == 3:
# conditions.append(
# models.FieldCondition(
# key="metadata.Header_3",
# match=models.MatchValue(
# value=headers[2]
# ),
# )
# )
# single_result = vector_store.similarity_search(
# query=query,
# k=k,
# filter=models.Filter(
# must=conditions
# ),
# )
# return single_result
if __name__ == "__main__":
    # --- Dense retriever: Qdrant vector store, top-3 nearest chunks ---
    client = QdrantClient(url="http://localhost:6333")
    stsv = Qdrant(client, collection_name="eval_collection2", embeddings=HF_EMBEDDING)
    retriever = stsv.as_retriever(search_kwargs={'k': 3})

    # --- Sparse retriever: BM25 over pre-chunked documents ---
    import pickle
    # Path is configurable via env var; the original hard-coded location is
    # kept as the backward-compatible default.
    docs_path = os.getenv(
        'DOCUMENTS_PKL',
        '/home/justtuananh/AI4TUAN/DOAN2024/offical/pipelines/documents_new.pkl',
    )
    # NOTE(review): pickle.load is unsafe on untrusted files — acceptable here
    # only because this pickle is produced by the project's own pipeline.
    with open(docs_path, 'rb') as f:
        sotaysinhvien = pickle.load(f)
    retriever_bm25 = BM25SRetriever.from_documents(sotaysinhvien, k=5, activate_numba=True)

    llm = ChatMistralAI(
        model="mistral-large-2407",
        temperature=0,
        max_retries=2,
    )

    output_parser = LineListOutputParser()
    # Chain that expands one user query into several reformulations (one per line).
    llm_chain = QUERY_PROMPT | llm | output_parser

    def duy_phen():
        """Prompt the user for one question, run the ensemble-retrieval +
        listwise-rerank pipeline, and print the reranked documents with timing."""
        user_message = input("Nhập câu hỏi của bạn: ")
        start_time = time.time()
        if retriever is not None:
            # Multi-query expansion over the dense retriever; "lines" matches
            # the parser key expected by LineListOutputParser output.
            retriever_multi = MultiQueryRetriever(
                retriever=retriever, llm_chain=llm_chain, parser_key="lines"
            )
            # Blend sparse (BM25) and dense (multi-query) results 50/50.
            ensemble_retriever = EnsembleRetriever(
                retrievers=[retriever_bm25, retriever_multi], weights=[0.5, 0.5]
            )
            # LLM-based listwise reranker keeps the 5 most relevant documents.
            reranker = LLMListwiseRerank.from_llm(llm, top_n=5)
            compression = ContextualCompressionRetriever(
                base_compressor=reranker, base_retriever=ensemble_retriever
            )
            print(compression.invoke(user_message))
            end_time = time.time()
            print(f'TIME USING : {end_time-start_time}')
        else:
            print('Retriever is not defined. Check output results and ensure retriever is assigned correctly.')

    duy_phen()
|