bstraehle commited on
Commit
61d6f57
·
1 Parent(s): 1b1996b

Create rag_langchain.py

Browse files
Files changed (1) hide show
  1. rag_langchain.py +142 -0
rag_langchain.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging, os, sys
2
+
3
+ from langchain.callbacks import get_openai_callback
4
+ from langchain.chains import LLMChain, RetrievalQA
5
+ from langchain.chat_models import ChatOpenAI
6
+ from langchain.document_loaders import PyPDFLoader, WebBaseLoader
7
+ from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
8
+ from langchain.document_loaders.generic import GenericLoader
9
+ from langchain.document_loaders.parsers import OpenAIWhisperParser
10
+ from langchain.embeddings.openai import OpenAIEmbeddings
11
+ from langchain.prompts import PromptTemplate
12
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
13
+ from langchain.vectorstores import Chroma
14
+ from langchain.vectorstores import MongoDBAtlasVectorSearch
15
+
16
+ from pymongo import MongoClient
17
+
18
+ RAG_CHROMA = "Chroma"
19
+ RAG_MONGODB = "MongoDB"
20
+
21
+ PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
22
+ WEB_URL = "https://openai.com/research/gpt-4"
23
+ YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
24
+ YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
25
+
26
+ YOUTUBE_DIR = "/data/yt"
27
+ CHROMA_DIR = "/data/db"
28
+
29
+ MONGODB_ATLAS_CLUSTER_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
30
+ MONGODB_DB_NAME = "langchain_db"
31
+ MONGODB_COLLECTION_NAME = "gpt-4"
32
+ MONGODB_INDEX_NAME = "default"
33
+
34
+ LLM_CHAIN_PROMPT = PromptTemplate(
35
+ input_variables = ["question"],
36
+ template = os.environ["LLM_TEMPLATE"])
37
+ RAG_CHAIN_PROMPT = PromptTemplate(
38
+ input_variables = ["context", "question"],
39
+ template = os.environ["RAG_TEMPLATE"])
40
+
41
+ logging.basicConfig(stream = sys.stdout, level = logging.INFO)
42
+ logging.getLogger().addHandler(logging.StreamHandler(stream = sys.stdout))
43
+
44
+ def load_documents():
45
+ docs = []
46
+
47
+ # PDF
48
+ loader = PyPDFLoader(PDF_URL)
49
+ docs.extend(loader.load())
50
+ #print("docs = " + str(len(docs)))
51
+
52
+ # Web
53
+ loader = WebBaseLoader(WEB_URL)
54
+ docs.extend(loader.load())
55
+ #print("docs = " + str(len(docs)))
56
+
57
+ # YouTube
58
+ loader = GenericLoader(
59
+ YoutubeAudioLoader(
60
+ [YOUTUBE_URL_1, YOUTUBE_URL_2],
61
+ YOUTUBE_DIR),
62
+ OpenAIWhisperParser())
63
+ docs.extend(loader.load())
64
+ #print("docs = " + str(len(docs)))
65
+
66
+ return docs
67
+
68
+ def split_documents(config, docs):
69
+ text_splitter = RecursiveCharacterTextSplitter()
70
+
71
+ return text_splitter.split_documents(docs)
72
+
73
+ def store_chroma(chunks):
74
+ Chroma.from_documents(
75
+ documents = chunks,
76
+ embedding = OpenAIEmbeddings(disallowed_special = ()),
77
+ persist_directory = CHROMA_DIR)
78
+
79
+ def store_mongodb(chunks):
80
+ client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
81
+ collection = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
82
+
83
+ MongoDBAtlasVectorSearch.from_documents(
84
+ documents = chunks,
85
+ embedding = OpenAIEmbeddings(disallowed_special = ()),
86
+ collection = collection,
87
+ index_name = MONGODB_INDEX_NAME)
88
+
89
+ def rag_ingestion(config):
90
+ docs = load_documents()
91
+
92
+ chunks = split_documents(config, docs)
93
+
94
+ store_chroma(chunks)
95
+ store_mongodb(chunks)
96
+
97
+ def retrieve_chroma():
98
+ return Chroma(
99
+ embedding_function = OpenAIEmbeddings(disallowed_special = ()),
100
+ persist_directory = CHROMA_DIR)
101
+
102
+ def retrieve_mongodb():
103
+ return MongoDBAtlasVectorSearch.from_connection_string(
104
+ MONGODB_ATLAS_CLUSTER_URI,
105
+ MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
106
+ OpenAIEmbeddings(disallowed_special = ()),
107
+ index_name = MONGODB_INDEX_NAME)
108
+
109
+ def get_llm(config):
110
+ return ChatOpenAI(
111
+ model_name = config["model_name"],
112
+ temperature = config["temperature"])
113
+
114
+ def llm_chain(config, prompt):
115
+ llm_chain = LLMChain(
116
+ llm = get_llm(config),
117
+ prompt = LLM_CHAIN_PROMPT)
118
+
119
+ with get_openai_callback() as cb:
120
+ completion = llm_chain.generate([{"question": prompt}])
121
+
122
+ return completion, llm_chain, cb
123
+
124
+ def rag_chain(config, rag_option, prompt):
125
+ llm = get_llm(config)
126
+
127
+ if (rag_option == RAG_CHROMA):
128
+ db = retrieve_chroma()
129
+ elif (rag_option == RAG_MONGODB):
130
+ db = retrieve_mongodb()
131
+
132
+ rag_chain = RetrievalQA.from_chain_type(
133
+ llm,
134
+ chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT,
135
+ "verbose": True},
136
+ retriever = db.as_retriever(search_kwargs = {"k": config["k"]}),
137
+ return_source_documents = True)
138
+
139
+ with get_openai_callback() as cb:
140
+ completion = rag_chain({"query": prompt})
141
+
142
+ return completion, rag_chain, cb