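"""Interactive RAG chatbot over three Qdrant collections (student handbook,
admissions information, academy introduction), with safety checking, semantic
caching, intent routing, hybrid BM25 + dense retrieval, and LLM reranking."""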
import os
import time
import pickle
from typing import List

from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser, BaseOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.vectorstores import Qdrant
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_openai import OpenAIEmbeddings
from langchain_groq import ChatGroq
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
    LLMChainExtractor,
    LLMChainFilter,
    EmbeddingsFilter,
    LLMListwiseRerank,
)
# from langchain_cohere import CohereRerank

from semantic_cache.main import SemanticCache
from Router.router import Evaluator
from SafetyChecker import SafetyChecker
from BM25 import BM25SRetriever
from blueprints.rag_utils import format_docs, translate
from blueprints.prompts import (
    QUERY_PROMPT,
    evaluator_intent,
    basic_template,
    chitchat_prompt,
    safe_prompt,
    cache_prompt,
)
# from utils.pipelines.main import get_last_user_message, add_or_update_system_message, pop_system_message

load_dotenv()
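
# Fail fast if the key is missing instead of passing None downstream
# (assumes OPENAI_KEY is defined in the .env file, as the code below expects).
if os.getenv('OPENAI_KEY') is None:
    raise RuntimeError("OPENAI_KEY is not set; add it to your .env file.")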

os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_KEY')

# Alternative local embedding model:
# HF_EMBEDDING = HuggingFaceEmbeddings(model_name="dangvantuan/vietnamese-embedding")
HF_EMBEDDING = OpenAIEmbeddings(model='text-embedding-3-small', api_key=os.getenv('OPENAI_KEY'))
class LineListOutputParser(BaseOutputParser[List[str]]):
    """Output parser for a list of lines."""

    def parse(self, text: str) -> List[str]:
        lines = text.strip().split("\n")
        return list(filter(None, lines))  # Remove empty lines
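
# Example: "q1\nq2\n\nq3" -> ["q1", "q2", "q3"]; blank lines are dropped so
# each remaining line can be issued as its own retrieval query.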



def add_or_update_system_message(content: str, messages: List[dict]):
    """
    Adds a system message at the beginning of the messages list, or appends
    the content to an existing leading system message.
    :param content: The message to be added or appended.
    :param messages: The list of message dictionaries.
    :return: The updated list of message dictionaries.
    """
    if messages and messages[0].get("role") == "system":
        messages[0]["content"] += f"{content}\n"
    else:
        # Insert at the beginning
        messages.insert(0, {"role": "system", "content": content})
    return messages
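
# Example:
#   add_or_update_system_message("Answer in Vietnamese.", [{"role": "user", "content": "hi"}])
#   -> [{"role": "system", "content": "Answer in Vietnamese."},
#       {"role": "user", "content": "hi"}]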

def split_context(context):
    """Fold the <context> block of a composed prompt back into the system prompt.

    Assumes `context` contains a "User question" marker followed by a
    "<context>" block, as produced by the prompt templates.
    """
    split_index = context.find("User question")
    if split_index == -1:
        # No "User question" marker: return the prompt unchanged.
        return context
    system_prompt = context[:split_index].strip()
    user_question = context[split_index:].strip()
    user_split_index = user_question.find("<context>")
    return str(system_prompt) + "\n" + str(user_question[user_split_index:])

def extract_metadata(docs, headers=('Header_1', 'Header_2', 'Header_3')):
    """Collect the non-empty header values (Header_1..Header_3) of each document."""
    meta_data_docs = []
    for doc in docs:
        meta_data_doc = [doc.metadata[header] for header in headers if doc.metadata.get(header)]
        meta_data_docs.append(meta_data_doc)
    return meta_data_docs
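
# Example: a doc with metadata {"Header_1": "Chapter 1", "Header_2": "Section A"}
# contributes ["Chapter 1", "Section A"]; missing or empty headers are skipped.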


def search_with_filter(query, vector_store, k, headers):
    """Similarity search restricted to points whose header metadata matches `headers`."""
    conditions = []

    # Build one match condition per header level (Header_1, Header_2, Header_3).
    for i, header in enumerate(headers[:3], start=1):
        conditions.append(
            models.FieldCondition(
                key=f"metadata.Header_{i}",
                match=models.MatchValue(value=header),
            )
        )

    # Run the query with all conditions required to match.
    single_result = vector_store.similarity_search(
        query=query,
        k=k,
        filter=models.Filter(must=conditions),
    )

    return single_result
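
# Note: `models.Filter(must=...)` requires every condition to match (logical
# AND); with an empty `headers` list the filter is empty and the call degrades
# to a plain top-k similarity search.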

def get_relevant_documents(documents: List[Document], limit: int) -> List[Document]:
    """Deduplicate documents by page content, keeping at most `limit` of them."""
    result = []
    seen = set()
    for doc in documents:
        if doc.page_content not in seen:
            result.append(doc)
            seen.add(doc.page_content)
        if len(result) >= limit:
            break
    return result
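
# Example: three docs with page_content "a", "a", "b" and limit=2 -> the
# first "a" and "b" are kept; the duplicate "a" is dropped.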




if __name__ == "__main__":
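    # Per-turn pipeline: safety check -> semantic cache lookup -> intent
    # routing -> hybrid retrieval (BM25 + dense, ensembled 50/50) -> LLM
    # listwise rerank -> metadata-filtered expansion -> dedup/limit ->
    # answer generation.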

    client = QdrantClient(
        url="http://localhost:6333"
    )

    # Dense retrievers over the three Qdrant collections:
    # stsv = student handbook, gthv = academy introduction, ttts = admissions info.
    stsv = Qdrant(client, collection_name="sotaysinhvien_filter", embeddings=HF_EMBEDDING)
    stsv_db = stsv.as_retriever(search_kwargs={'k': 10})

    gthv = Qdrant(client, collection_name="gioithieuhocvien_filter", embeddings=HF_EMBEDDING)
    gthv_db = gthv.as_retriever(search_kwargs={'k': 10})

    ttts = Qdrant(client, collection_name="thongtintuyensinh_filter", embeddings=HF_EMBEDDING)
    ttts_db = ttts.as_retriever(search_kwargs={'k': 10})

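    # The pickle files are assumed to hold the List[Document] corpora that the
    # BM25 retrievers below are built from.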
    with open('data/sotaysinhvien_filter.pkl', 'rb') as f:
        sotaysinhvien = pickle.load(f)
    with open('data/thongtintuyensinh_filter.pkl', 'rb') as f:
        thongtintuyensinh = pickle.load(f)
    with open('data/gioithieuhocvien_filter.pkl', 'rb') as f:
        gioithieuhocvien = pickle.load(f)


    retriever_bm25_tuyensinh = BM25SRetriever.from_documents(thongtintuyensinh, k=10, save_directory="data/bm25s/ttts")
    retriever_bm25_sotay = BM25SRetriever.from_documents(sotaysinhvien, k=10, save_directory="data/bm25s/stsv")
    retriever_bm25_hocvien = BM25SRetriever.from_documents(gioithieuhocvien, k=10, save_directory="data/bm25s/gthv")


    # reranker = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
    llm = ChatGroq(model_name="llama3-70b-8192", temperature=0.1, api_key=os.getenv('llm_api_3'))
    llm2 = ChatGroq(model_name="llama-3.1-70b-versatile", temperature=1, api_key=os.getenv('llm_api_8'))
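    # `llm` (low temperature) handles query generation and reranking;
    # `llm2` (high temperature) generates the user-facing answers.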
    
    output_parser = LineListOutputParser()
    llm_chain = QUERY_PROMPT | llm | output_parser
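    # `llm_chain` expands one user question into several search queries, one
    # per line; it is only consumed by the multi-query retriever variant that
    # is commented out further down.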

    # messages = [
    #     {"role": "system", "content": "Based on the following information, answer the question in Vietnamese"}
    # ]
    cache = SemanticCache()
    another_chain = chitchat_prompt | llm2 | StrOutputParser()
    safe_chain = safe_prompt | llm2 | StrOutputParser()
    cache_chain = cache_prompt | llm2 | StrOutputParser()


    while True:
        user_message = input("Enter a question: ")

        # Safety gate: refuse and explain when the (translated) question is
        # flagged as unsafe, then wait for the next question.
        checker = SafetyChecker()
        safety_result = checker.check_safety(translate(user_message))
        print("Safety check:", safety_result)
        if safety_result != 'safe':
            print("UNSAFE")
            response = safe_chain.invoke({'meaning': f'{safety_result}'})
            print(response)
            continue
        # Intent routing: pick the collection that matches the question.
        evaluator = Evaluator(llm="llama3-70b", prompt=evaluator_intent)
        output = evaluator.classify_text(user_message)
        print(output.result)

        retriever = None  # stays None when no collection matches the intent
        retriever_bm25 = None
        source = None

        # Semantic cache: reuse the stored answer for similar past questions
        # and skip retrieval entirely.
        cache_result = cache.checker(user_message)
        if cache_result is not None:
            print("### Cache hit! ###")
            response = cache_chain.invoke({"question": f'{user_message}', "content": f"{cache_result}"})
            print(response)
            continue

        if output and output.result == 'OUT_OF_SCOPE':
            print('OUT OF SCOPE')
            response = another_chain.invoke({"question": f"{user_message}"})
            print(response)
            continue
        elif output and output.result == 'ASK_QUYDINH':
            print('SO TAY SINH VIEN DB')
            retriever = stsv_db
            retriever_bm25 = retriever_bm25_sotay
            source = stsv
        elif output and output.result == 'ASK_HOCVIEN':
            print('GIOI THIEU HOC VIEN DB')
            retriever = gthv_db
            retriever_bm25 = retriever_bm25_hocvien
            source = gthv
        elif output and output.result == 'ASK_TUYENSINH':
            print('THONG TIN TUYEN SINH DB')
            retriever = ttts_db
            retriever_bm25 = retriever_bm25_tuyensinh
            source = ttts

        
        if retriever is not None:
            # Alternative: multi-query expansion over the dense retriever.
            # retriever_multi = MultiQueryRetriever(
            #     retriever=retriever, llm_chain=llm_chain, parser_key="lines"
            # )
            start_time = time.time()

            # Hybrid retrieval: lexical BM25 and dense vector search, merged 50/50.
            ensemble_retriever = EnsembleRetriever(
                retrievers=[retriever_bm25, retriever], weights=[0.5, 0.5])

            # Alternative compressors/filters that were tried:
            # compressor = LLMChainExtractor.from_llm(llm)
            # _filter = LLMChainFilter.from_llm(llm)
            # embeddings_filter = EmbeddingsFilter(embeddings=HF_EMBEDDING, similarity_threshold=0.5)
            # compression = ContextualCompressionRetriever(
            #     base_compressor=_filter, base_retriever=ensemble_retriever
            # )

            # LLM-based listwise reranking keeps only the top 5 documents.
            reranker = LLMListwiseRerank.from_llm(llm=llm, top_n=5)
            tailieu = ensemble_retriever.invoke(f"{user_message}")
            docs = reranker.compress_documents(tailieu, user_message)
            end_time = time.time()

            # Filter again here -> add more closely related documents.
            # docs = compression.invoke(f"{user_message}")
            # print(docs)
            
            # Expand the context: re-query the collection, restricted to the
            # header metadata of the reranked hits.
            meta_data_docs = extract_metadata(docs)

            full_result = []
            for meta_data_doc in meta_data_docs:
                result = search_with_filter(user_message, source, 10, meta_data_doc)
                full_result.extend(result)

            print("Relevant context" + '\n')
            print(full_result)
            
            # Alternative single-chain formulation:
            # rag_chain = (
            #     {"context": compression | format_docs, "question": RunnablePassthrough()}
            #     | basic_template | llm2 | StrOutputParser()
            # )

            # Deduplicate, cap at 10 documents, and generate the final answer.
            result_final = get_relevant_documents(full_result, 10)
            context = format_docs(result_final)

            best_chain = basic_template | llm2 | StrOutputParser()
            best_result = best_chain.invoke({"question": f'{user_message}', "context": f"{context}"})

            print(f'Best answer: {best_result}')
            print(f'Retrieval + rerank time: {end_time - start_time:.2f}s')
        else:
            print('Retriever is not defined. Check output results and ensure retriever is assigned correctly.')
