File size: 4,784 Bytes
bcd5ded
61d6f57
 
 
4b4e8c6
61d6f57
 
 
 
 
 
 
 
 
 
 
bcd5ded
4896967
1616958
551e06b
1616958
 
8b439aa
4896967
 
 
 
 
 
4a5b4c3
0fa9763
4896967
4a5b4c3
0fa9763
4896967
 
 
61d6f57
4896967
 
 
 
61d6f57
4896967
 
 
 
61d6f57
4896967
 
 
 
 
 
 
 
61d6f57
4896967
61d6f57
4896967
a457c54
 
 
 
61d6f57
4896967
61d6f57
4896967
 
 
a457c54
62a61d6
 
4896967
 
 
 
 
 
 
 
 
62a61d6
 
4896967
7186dc1
4896967
61d6f57
4896967
61d6f57
4896967
 
 
 
 
9767de9
62a61d6
 
4896967
 
 
 
 
 
62a61d6
 
4896967
 
 
 
62a61d6
 
4896967
 
 
 
62a61d6
 
61d6f57
4896967
4a5b4c3
61d6f57
300737c
4896967
 
 
 
 
 
 
62a61d6
4896967
62a61d6
 
61d6f57
4896967
e55f593
61d6f57
300737c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os

from langchain.callbacks import get_openai_callback
from langchain.chains import LLMChain, RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
from langchain.document_loaders.generic import GenericLoader
from langchain.document_loaders.parsers import OpenAIWhisperParser
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.vectorstores import MongoDBAtlasVectorSearch

from pymongo import MongoClient
from rag_base import BaseRAG

# Enable LangChain tracing (LANGCHAIN_TRACING_V2) against the LangSmith
# endpoint, grouping runs under the "openai-llm-rag" project. These are plain
# assignments, so any values already present in the environment are replaced.
os.environ.update({
    "LANGCHAIN_ENDPOINT": "https://api.smith.langchain.com",
    "LANGCHAIN_PROJECT": "openai-llm-rag",
    "LANGCHAIN_TRACING_V2": "true",
})

class LangChainRAG(BaseRAG):
    """Retrieval-augmented generation pipeline built on LangChain.

    Ingestion loads a PDF, a web page and the audio of two YouTube videos
    (parsed with ``OpenAIWhisperParser``). It splits them into chunks, embeds
    the chunks with OpenAI embeddings and stores them in MongoDB Atlas Vector
    Search; a local Chroma store is available as an alternative backend.
    Prompts can be answered either by a plain LLM chain (``llm_chain``) or by
    a retrieval QA chain over the stored chunks (``rag_chain``).

    NOTE(review): ``PDF_URL``, ``WEB_URL``, ``YOUTUBE_URL_1``,
    ``YOUTUBE_URL_2``, ``MONGODB_ATLAS_CLUSTER_URI``,
    ``MONGODB_COLLECTION_NAME`` and ``MONGODB_INDEX_NAME`` are not defined
    here; presumably they come from ``BaseRAG`` — confirm there.
    """

    MONGODB_DB_NAME = "langchain_db"

    CHROMA_DIR  = "/data/db"
    YOUTUBE_DIR = "/data/yt"

    # Both templates are read from the environment when the class body runs,
    # so a missing TEMPLATE / LANGCHAIN_TEMPLATE variable raises KeyError as
    # soon as this module is imported.
    LLM_CHAIN_PROMPT = PromptTemplate(
        input_variables = ["question"],
        template = os.environ["TEMPLATE"])
    RAG_CHAIN_PROMPT = PromptTemplate(
        input_variables = ["context", "question"],
        template = os.environ["LANGCHAIN_TEMPLATE"])

    def _get_embeddings(self):
        """Return the embedding model used for every vector store.

        Centralised so ingestion and retrieval always embed with the same
        configuration. ``disallowed_special = ()`` is passed through to
        ``OpenAIEmbeddings``; presumably this keeps text containing tokenizer
        special-token strings from being rejected — confirm in its docs.
        """
        return OpenAIEmbeddings(disallowed_special = ())

    def load_documents(self):
        """Load every configured source and return the documents as one list.

        Sources are loaded in order: the PDF at ``PDF_URL``, the page at
        ``WEB_URL``, then the two YouTube URLs via ``YoutubeAudioLoader``
        (with ``YOUTUBE_DIR`` as its directory argument) parsed by
        ``OpenAIWhisperParser``.
        """
        docs = []

        # PDF
        loader = PyPDFLoader(self.PDF_URL)
        docs.extend(loader.load())

        # Web
        loader = WebBaseLoader(self.WEB_URL)
        docs.extend(loader.load())

        # YouTube
        loader = GenericLoader(
            YoutubeAudioLoader(
                [self.YOUTUBE_URL_1, self.YOUTUBE_URL_2],
                self.YOUTUBE_DIR),
            OpenAIWhisperParser())
        docs.extend(loader.load())

        return docs

    def split_documents(self, config, docs):
        """Split ``docs`` into chunks with a recursive character splitter.

        Args:
            config: mapping providing ``"chunk_size"`` and ``"chunk_overlap"``.
            docs: documents to split.

        Returns:
            The list of chunk documents produced by the splitter.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_overlap = config["chunk_overlap"],
            chunk_size = config["chunk_size"]
        )

        return text_splitter.split_documents(docs)

    def store_documents_chroma(self, chunks):
        """Embed ``chunks`` and persist them to the Chroma store at ``CHROMA_DIR``."""
        Chroma.from_documents(
            documents = chunks,
            embedding = self._get_embeddings(),
            persist_directory = self.CHROMA_DIR
        )

    def store_documents_mongodb(self, chunks):
        """Embed ``chunks`` and insert them into the MongoDB Atlas collection.

        The target is ``MONGODB_DB_NAME``.``MONGODB_COLLECTION_NAME`` using the
        vector index ``MONGODB_INDEX_NAME``. The client is opened only for the
        duration of the insert and is always closed afterwards.
        """
        # Context manager closes the client's connection pool even if
        # embedding or insertion raises; previously the client was leaked.
        with MongoClient(self.MONGODB_ATLAS_CLUSTER_URI) as client:
            collection = client[self.MONGODB_DB_NAME][self.MONGODB_COLLECTION_NAME]

            MongoDBAtlasVectorSearch.from_documents(
                documents = chunks,
                embedding = self._get_embeddings(),
                collection = collection,
                index_name = self.MONGODB_INDEX_NAME
            )

    def ingestion(self, config):
        """Load, split and store all source documents.

        Args:
            config: mapping with the splitter settings read by
                ``split_documents`` (``"chunk_size"``, ``"chunk_overlap"``).
        """
        docs = self.load_documents()

        chunks = self.split_documents(config, docs)

        # NOTE: MongoDB Atlas is the active backend; call
        # store_documents_chroma(chunks) instead to use the local Chroma store.
        self.store_documents_mongodb(chunks)

    def get_vector_store_chroma(self):
        """Return a Chroma vector store backed by ``CHROMA_DIR``."""
        return Chroma(
            embedding_function = self._get_embeddings(),
            persist_directory = self.CHROMA_DIR
        )

    def get_vector_store_mongodb(self):
        """Return a MongoDB Atlas vector store over the configured collection.

        The namespace passed is ``"<db>.<collection>"``, built from
        ``MONGODB_DB_NAME`` and ``MONGODB_COLLECTION_NAME``.
        """
        return MongoDBAtlasVectorSearch.from_connection_string(
            self.MONGODB_ATLAS_CLUSTER_URI,
            self.MONGODB_DB_NAME + "." + self.MONGODB_COLLECTION_NAME,
            self._get_embeddings(),
            index_name = self.MONGODB_INDEX_NAME
        )

    def get_llm(self, config):
        """Return a ``ChatOpenAI`` model built from ``config``.

        Args:
            config: mapping providing ``"model_name"`` and ``"temperature"``.
        """
        return ChatOpenAI(
            model_name = config["model_name"],
            temperature = config["temperature"]
        )

    def llm_chain(self, config, prompt):
        """Answer ``prompt`` with the LLM alone, without retrieval.

        Args:
            config: model settings, see ``get_llm``.
            prompt: user question, substituted into ``LLM_CHAIN_PROMPT`` as
                ``question``.

        Returns:
            ``(completion, callback)``: the result of ``LLMChain.generate``
            and the ``get_openai_callback`` handler that was active during
            the call.
        """
        chain = LLMChain(
            llm = self.get_llm(config),
            prompt = self.LLM_CHAIN_PROMPT
        )

        with get_openai_callback() as callback:
            completion = chain.generate([{"question": prompt}])

        return completion, callback

    def rag_chain(self, config, prompt):
        """Answer ``prompt`` with retrieval over the MongoDB vector store.

        Args:
            config: model settings (see ``get_llm``) plus ``"k"``, the number
                of chunks the retriever returns.
            prompt: user question, passed to the chain as ``query``.

        Returns:
            ``(completion, callback)``: the chain's output (source documents
            included, since ``return_source_documents=True``) and the
            ``get_openai_callback`` handler that was active during the call.
        """
        # NOTE: swap in get_vector_store_chroma() to query the Chroma store.
        vector_store = self.get_vector_store_mongodb()

        chain = RetrievalQA.from_chain_type(
            self.get_llm(config),
            chain_type_kwargs = {"prompt": self.RAG_CHAIN_PROMPT},
            retriever = vector_store.as_retriever(search_kwargs = {"k": config["k"]}),
            return_source_documents = True
        )

        with get_openai_callback() as callback:
            completion = chain({"query": prompt})

        return completion, callback