File size: 3,529 Bytes
fb72340
5f04412
 
50dc063
3e7b971
5f04412
a0b5dc6
32dedd9
5f04412
 
 
 
c947c47
1342013
5f04412
ebaeae5
 
dce66ba
e946a29
0ddb69a
5f04412
0ddb69a
 
 
 
5f04412
0ddb69a
 
5f04412
0ddb69a
5f04412
0ddb69a
3940450
0ddb69a
 
5f04412
0ddb69a
 
5f04412
0ddb69a
 
 
3940450
0ddb69a
5f04412
0ddb69a
 
3940450
 
0ddb69a
5f04412
0ddb69a
5f04412
c273c9f
 
 
 
0ddb69a
5f04412
e946a29
0ddb69a
3940450
008f2f7
3940450
 
0ddb69a
c273c9f
c4c7ba9
c273c9f
f5c95d9
 
1b00820
c273c9f
 
 
c4c7ba9
c273c9f
 
 
 
 
 
 
 
 
 
 
 
 
 
5f04412
e946a29
ff3ca13
5f04412
ff3ca13
c273c9f
e946a29
0ddb69a
db1e09d
 
5f04412
0ddb69a
bb08f6a
c273c9f
0ddb69a
 
dce66ba
0ddb69a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os, requests

from llama_hub.youtube_transcript import YoutubeTranscriptReader
from llama_index import download_loader, PromptTemplate, ServiceContext
from llama_index.embeddings import OpenAIEmbedding
from llama_index.indices.vector_store.base import VectorStoreIndex
from llama_index.llms import OpenAI
from llama_index.prompts import PromptTemplate
from llama_index.storage.storage_context import StorageContext
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch

from pathlib import Path
from pymongo import MongoClient
from rag_base import BaseRAG

class LlamaIndexRAG(BaseRAG):
    """RAG pipeline built on LlamaIndex with a MongoDB Atlas vector store.

    Source URLs (``PDF_URL``, ``WEB_URL``, ``YOUTUBE_URL_1``,
    ``YOUTUBE_URL_2``) and MongoDB settings (``MONGODB_ATLAS_CLUSTER_URI``,
    ``MONGODB_COLLECTION_NAME``, ``MONGODB_INDEX_NAME``) are read from
    ``self`` but not defined here; presumably they come from ``BaseRAG`` —
    confirm there.

    The ``config`` dicts passed to these methods are read for the keys
    ``model_name``, ``temperature``, ``chunk_size``, ``chunk_overlap`` and
    ``k``.
    """

    MONGODB_DB_NAME = "llamaindex_db"

    # Seconds to wait for the PDF download before giving up.
    PDF_DOWNLOAD_TIMEOUT = 60

    def load_documents(self):
        """Load the PDF, the web page and two YouTube transcripts.

        The PDF is downloaded once to ``data/gpt-4.pdf`` and reused on
        later runs.

        Returns:
            A list of LlamaIndex documents: PDF pages first, then the web
            page, then the YouTube transcripts.

        Raises:
            requests.HTTPError: if the PDF download returns an error status.
        """
        docs = []

        # PDF
        PDFReader = download_loader("PDFReader")
        loader = PDFReader()
        out_dir = Path("data")
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / "gpt-4.pdf"

        if not out_path.exists():
            r = requests.get(self.PDF_URL, timeout=self.PDF_DOWNLOAD_TIMEOUT)
            # Fail before writing anything: otherwise an HTTP error page
            # would be saved as the PDF and, because the file then exists,
            # never re-downloaded.
            r.raise_for_status()
            out_path.write_bytes(r.content)

        docs.extend(loader.load_data(file=out_path))

        # Web
        SimpleWebPageReader = download_loader("SimpleWebPageReader")
        loader = SimpleWebPageReader()
        docs.extend(loader.load_data(urls=[self.WEB_URL]))

        # YouTube
        loader = YoutubeTranscriptReader()
        docs.extend(loader.load_data(ytlinks=[self.YOUTUBE_URL_1,
                                              self.YOUTUBE_URL_2]))

        return docs

    def get_llm(self, config):
        """Return an OpenAI LLM configured from ``config["model_name"]``
        and ``config["temperature"]``."""
        return OpenAI(
            model=config["model_name"],
            temperature=config["temperature"],
        )

    def get_vector_store(self):
        """Return a MongoDB Atlas vector store for this class's database.

        A new ``MongoClient`` is created on every call; it is not closed
        anywhere in this class.
        """
        return MongoDBAtlasVectorSearch(
            MongoClient(self.MONGODB_ATLAS_CLUSTER_URI),
            db_name=self.MONGODB_DB_NAME,
            collection_name=self.MONGODB_COLLECTION_NAME,
            index_name=self.MONGODB_INDEX_NAME,
        )

    def get_service_context(self, config):
        """Return a ServiceContext with the configured chunking, an
        ``OpenAIEmbedding`` embedder and the LLM from ``get_llm``."""
        return ServiceContext.from_defaults(
            chunk_overlap=config["chunk_overlap"],
            chunk_size=config["chunk_size"],
            embed_model=OpenAIEmbedding(),
            llm=self.get_llm(config),
        )

    def get_storage_context(self):
        """Return a StorageContext backed by a fresh vector store."""
        return StorageContext.from_defaults(
            vector_store=self.get_vector_store()
        )

    def store_documents(self, config, docs):
        """Chunk, embed and write ``docs`` into the vector store.

        Args:
            config: settings dict; see ``get_service_context``.
            docs: documents, e.g. as returned by ``load_documents``.
        """
        # Build exactly one storage context (and hence one MongoClient);
        # an extra, unused one used to be created and left open here.
        VectorStoreIndex.from_documents(
            docs,
            service_context=self.get_service_context(config),
            storage_context=self.get_storage_context(),
        )

    def ingestion(self, config):
        """Load all sources and store them in the vector store."""
        docs = self.load_documents()

        self.store_documents(config, docs)

    def retrieval(self, config, prompt):
        """Answer ``prompt`` using the ``config["k"]`` most similar chunks.

        The QA prompt template is taken from the ``RAG_TEMPLATE_2``
        environment variable (``KeyError`` if it is unset).

        Returns:
            The LlamaIndex query-engine response.
        """
        # Build the service context once and give it to the index too, so
        # the index is not left on LlamaIndex's default service context
        # (which may differ from the one used at ingestion time).
        service_context = self.get_service_context(config)

        index = VectorStoreIndex.from_vector_store(
            vector_store=self.get_vector_store(),
            service_context=service_context,
        )

        query_engine = index.as_query_engine(
            text_qa_template=PromptTemplate(os.environ["RAG_TEMPLATE_2"]),
            service_context=service_context,
            similarity_top_k=config["k"],
        )

        return query_engine.query(prompt)