# Hybrid RAG retrieval pipeline: a dense Qdrant retriever with multi-query expansion
# is fused with a BM25 retriever, and the merged results are reranked by an LLM.
import os
import pickle
import time
from typing import List

from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from langchain_core.documents import Document
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_community.vectorstores import Qdrant
from langchain_openai import OpenAIEmbeddings
from langchain_mistralai import ChatMistralAI
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
    EmbeddingsFilter,
    LLMChainExtractor,
    LLMChainFilter,
    LLMListwiseRerank,
)
from blueprints.rag_utils import format_docs
from blueprints.prompts import QUERY_PROMPT, evaluator_intent, basic_template, chitchat_prompt, safe_prompt, cache_prompt
from BM25 import BM25SRetriever
# from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
# from database_Routing import DB_Router
# from langchain_cohere import CohereRerank
# from langchain_groq import ChatGroq

load_dotenv()

# Note: despite the name, the embeddings are OpenAI's text-embedding-3-small.
HF_EMBEDDING = OpenAIEmbeddings(model='text-embedding-3-small')

class LineListOutputParser(BaseOutputParser[List[str]]):
    """Output parser for a list of lines."""

    def parse(self, text: str) -> List[str]:
        lines = text.strip().split("\n")
        return list(filter(None, lines))  # Remove empty lines

# def extract_metadata(docs, headers=('Header_1', 'Header_2', 'Header_3')):
#     meta_data_docs = []
#     for doc in docs:
#         meta_data_doc = [doc.metadata[header] for header in headers if doc.metadata.get(header)]
#         meta_data_docs.append(meta_data_doc)
#     return meta_data_docs

# def search_with_filter(query, vector_store, k, headers):
#     conditions = [
#         models.FieldCondition(
#             key="metadata.Header_1",
#             match=models.MatchValue(
#                 value=headers[0]
#             ),
#         ),
#         models.FieldCondition(
#             key="metadata.Header_2",
#             match=models.MatchValue(
#                 value=headers[1]
#             ),
#         ),
#     ]


#     if len(headers) == 3:
#         conditions.append(
#             models.FieldCondition(
#                 key="metadata.Header_3",
#                 match=models.MatchValue(
#                     value=headers[2]
#                 ),
#             )
#         )


#     single_result = vector_store.similarity_search(
#         query=query,
#         k=k,
#         filter=models.Filter(
#             must=conditions
#         ),
#     )
    
#     return single_result


if __name__ == "__main__": 

    client = QdrantClient(url="http://localhost:6333")

    # Wrap the Qdrant collection as a LangChain vector store backed by the OpenAI embeddings.
    stsv = Qdrant(client, collection_name="eval_collection2", embeddings=HF_EMBEDDING)
    retriever = stsv.as_retriever(search_kwargs={'k': 3})

    # Load the pre-chunked student-handbook documents and build a sparse BM25 retriever.
    with open('/home/justtuananh/AI4TUAN/DOAN2024/offical/pipelines/documents_new.pkl', 'rb') as f:
        sotaysinhvien = pickle.load(f)

    retriever_bm25 = BM25SRetriever.from_documents(sotaysinhvien, k=5, activate_numba=True)

    # reranker = CohereRerank(model = "rerank-multilingual-v3.0", top_n = 5)
    llm = ChatMistralAI(
        model="mistral-large-2407",
        temperature=0,
        max_retries=2,
    )
    # llm = ChatGroq(model_name="llama3-70b-8192", temperature=0.1,api_key= os.getenv('llm_api_4'))
    # llm2 = ChatGroq(model_name="llama3-70b-8192", temperature=0.1,api_key= os.getenv('llm_api_5'))
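
    # Multi-query generation chain: QUERY_PROMPT asks the LLM for several rephrasings
    # of the user question, which the parser splits into a list of non-empty lines.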
    output_parser = LineListOutputParser()
    llm_chain = QUERY_PROMPT | llm | output_parser

    # Conversation scaffold (currently unused); the system message translates to:
    # "Based on the following information, answer the question in Vietnamese."
    messages = [
        {"role": "system", "content": "Dựa vào thông tin sau, trả lời câu hỏi bằng tiếng việt"}
    ]

    def duy_phen():

        user_message = input("Nhập câu hỏi của bạn: ")  # "Enter your question: "
        start_time = time.time()

        if retriever is not None:
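            # Hybrid retrieval: expand the question into multiple queries for the dense
            # retriever, then fuse them with the BM25 results in a weighted ensemble.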
            retriever_multi = MultiQueryRetriever(
                retriever=retriever, llm_chain=llm_chain, parser_key="lines"
            )

            ensemble_retriever = EnsembleRetriever(
                retrievers=[retriever_bm25, retriever_multi], weights=[0.5, 0.5]
            )

            compressor = LLMChainExtractor.from_llm(llm)
            # _filter = LLMChainFilter.from_llm(llm)
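            # LLMListwiseRerank asks the LLM to reorder the fused candidates and keep the top 5;
            # the extractor and embeddings filter created here are alternatives that are not wired in.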
            _filter2 = LLMListwiseRerank.from_llm(llm, top_n=5)
            embeddings_filter = EmbeddingsFilter(embeddings=HF_EMBEDDING, similarity_threshold=0.5)
            compression = ContextualCompressionRetriever(
                base_compressor=_filter2, base_retriever=ensemble_retriever
            )

            # rag_chain = (
            #                 {"context": compression | format_docs, "question": RunnablePassthrough()}
            #                 | basic_template | llm2 | StrOutputParser()
            #             )
            
            print(compression.invoke(user_message))
            end_time = time.time()
            print(f'TIME USED: {end_time - start_time:.2f}s')

        else:
            print('Retriever is not defined. Check the vector store setup and make sure the retriever was assigned correctly.')

    duy_phen()