File size: 4,784 Bytes
bcd5ded 61d6f57 4b4e8c6 61d6f57 bcd5ded 4896967 1616958 551e06b 1616958 8b439aa 4896967 4a5b4c3 0fa9763 4896967 4a5b4c3 0fa9763 4896967 61d6f57 4896967 61d6f57 4896967 61d6f57 4896967 61d6f57 4896967 61d6f57 4896967 a457c54 61d6f57 4896967 61d6f57 4896967 a457c54 62a61d6 4896967 62a61d6 4896967 7186dc1 4896967 61d6f57 4896967 61d6f57 4896967 9767de9 62a61d6 4896967 62a61d6 4896967 62a61d6 4896967 62a61d6 61d6f57 4896967 4a5b4c3 61d6f57 300737c 4896967 62a61d6 4896967 62a61d6 61d6f57 4896967 e55f593 61d6f57 300737c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import os
from langchain.callbacks import get_openai_callback
from langchain.chains import LLMChain, RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
from langchain.document_loaders.generic import GenericLoader
from langchain.document_loaders.parsers import OpenAIWhisperParser
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.vectorstores import MongoDBAtlasVectorSearch
from pymongo import MongoClient
from rag_base import BaseRAG
# Enable LangChain tracing (v2) against the api.smith.langchain.com endpoint,
# grouping runs under the "openai-llm-rag" project. These assignments run at
# import time and overwrite any values already present in the environment.
os.environ.update({
    "LANGCHAIN_ENDPOINT": "https://api.smith.langchain.com",
    "LANGCHAIN_PROJECT": "openai-llm-rag",
    "LANGCHAIN_TRACING_V2": "true",
})
class LangChainRAG(BaseRAG):
    """Retrieval-augmented generation pipeline built on LangChain.

    Ingestion loads a PDF, a web page and the audio transcripts of two
    YouTube videos. It splits them into chunks, embeds them with OpenAI
    embeddings and stores them in MongoDB Atlas Vector Search. Chroma helpers
    are also provided. Queries are answered either by the LLM alone
    (``llm_chain``) or by a retrieval QA chain over the MongoDB store
    (``rag_chain``).

    NOTE(review): PDF_URL, WEB_URL, YOUTUBE_URL_1, YOUTUBE_URL_2,
    MONGODB_ATLAS_CLUSTER_URI, MONGODB_COLLECTION_NAME and MONGODB_INDEX_NAME
    are used but not defined here; presumably they come from BaseRAG.
    Confirm this.
    """

    MONGODB_DB_NAME = "langchain_db"
    CHROMA_DIR = "/data/db"
    YOUTUBE_DIR = "/data/yt"

    # Both templates are read from the environment while the class body
    # executes, i.e. at import time. A missing TEMPLATE or LANGCHAIN_TEMPLATE
    # variable therefore raises KeyError on import.
    LLM_CHAIN_PROMPT = PromptTemplate(
        input_variables=["question"],
        template=os.environ["TEMPLATE"],
    )
    RAG_CHAIN_PROMPT = PromptTemplate(
        input_variables=["context", "question"],
        template=os.environ["LANGCHAIN_TEMPLATE"],
    )

    def _get_embeddings(self):
        """Return the OpenAI embedding model used by every vector store.

        ``disallowed_special=()`` is passed everywhere so ingestion and
        retrieval embed text identically. It presumably keeps the tokenizer
        from rejecting text that contains special-token strings; confirm
        against the OpenAIEmbeddings docs.
        """
        return OpenAIEmbeddings(disallowed_special=())

    def load_documents(self):
        """Load every configured source as LangChain documents.

        Returns:
            One list holding the documents in this order: PDF, then web
            page, then YouTube transcripts.
        """
        docs = []

        # PDF
        docs.extend(PyPDFLoader(self.PDF_URL).load())

        # Web page
        docs.extend(WebBaseLoader(self.WEB_URL).load())

        # YouTube: YoutubeAudioLoader fetches the audio of both videos into
        # YOUTUBE_DIR, and OpenAIWhisperParser transcribes it.
        youtube_loader = GenericLoader(
            YoutubeAudioLoader(
                [self.YOUTUBE_URL_1, self.YOUTUBE_URL_2],
                self.YOUTUBE_DIR,
            ),
            OpenAIWhisperParser(),
        )
        docs.extend(youtube_loader.load())

        return docs

    def split_documents(self, config, docs):
        """Split documents into overlapping chunks.

        Args:
            config: Mapping that provides ``chunk_size`` and ``chunk_overlap``.
            docs: Documents to split.

        Returns:
            The chunked documents.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_overlap=config["chunk_overlap"],
            chunk_size=config["chunk_size"],
        )
        return text_splitter.split_documents(docs)

    def store_documents_chroma(self, chunks):
        """Embed ``chunks`` and write them to the Chroma store in CHROMA_DIR."""
        Chroma.from_documents(
            documents=chunks,
            embedding=self._get_embeddings(),
            persist_directory=self.CHROMA_DIR,
        )

    def store_documents_mongodb(self, chunks):
        """Embed ``chunks`` and insert them into the MongoDB Atlas collection.

        The client is opened for this call only and is closed once the insert
        finishes, so repeated ingestions do not leak connections.
        """
        with MongoClient(self.MONGODB_ATLAS_CLUSTER_URI) as client:
            collection = client[self.MONGODB_DB_NAME][self.MONGODB_COLLECTION_NAME]
            MongoDBAtlasVectorSearch.from_documents(
                documents=chunks,
                embedding=self._get_embeddings(),
                collection=collection,
                index_name=self.MONGODB_INDEX_NAME,
            )

    def ingestion(self, config):
        """Load, split and store all sources.

        MongoDB Atlas is the active store. ``store_documents_chroma`` is
        available as an alternative.

        Args:
            config: Mapping with the chunking settings read by
                ``split_documents``.
        """
        docs = self.load_documents()
        chunks = self.split_documents(config, docs)
        self.store_documents_mongodb(chunks)

    def get_vector_store_chroma(self):
        """Open the Chroma vector store persisted in CHROMA_DIR."""
        return Chroma(
            embedding_function=self._get_embeddings(),
            persist_directory=self.CHROMA_DIR,
        )

    def get_vector_store_mongodb(self):
        """Open the MongoDB Atlas vector store used for retrieval.

        The namespace is ``"<MONGODB_DB_NAME>.<MONGODB_COLLECTION_NAME>"``.

        NOTE(review): from_connection_string presumably creates its own
        MongoClient on every call, and it is never closed here. Consider
        caching the store if this runs per request.
        """
        return MongoDBAtlasVectorSearch.from_connection_string(
            self.MONGODB_ATLAS_CLUSTER_URI,
            self.MONGODB_DB_NAME + "." + self.MONGODB_COLLECTION_NAME,
            self._get_embeddings(),
            index_name=self.MONGODB_INDEX_NAME,
        )

    def get_llm(self, config):
        """Build the chat model.

        Args:
            config: Mapping that provides ``model_name`` and ``temperature``.
        """
        return ChatOpenAI(
            model_name=config["model_name"],
            temperature=config["temperature"],
        )

    def llm_chain(self, config, prompt):
        """Answer ``prompt`` with the LLM alone, without retrieval.

        Args:
            config: Mapping read by ``get_llm``.
            prompt: Question substituted into LLM_CHAIN_PROMPT.

        Returns:
            A tuple ``(completion, callback)``. ``completion`` is the result
            of ``LLMChain.generate``. ``callback`` is the handler from
            ``get_openai_callback`` that was active during the call.
        """
        chain = LLMChain(
            llm=self.get_llm(config),
            prompt=self.LLM_CHAIN_PROMPT,
        )
        with get_openai_callback() as callback:
            completion = chain.generate([{"question": prompt}])
        return completion, callback

    def rag_chain(self, config, prompt):
        """Answer ``prompt`` with a retrieval QA chain over the MongoDB store.

        Args:
            config: Mapping read by ``get_llm``. It must also provide ``k``,
                the number of chunks to retrieve.
            prompt: Query passed to the chain.

        Returns:
            A tuple ``(completion, callback)``. ``completion`` is the chain's
            output and includes the source documents, because
            ``return_source_documents=True`` is set. ``callback`` is the
            handler from ``get_openai_callback`` active during the call.
        """
        vector_store = self.get_vector_store_mongodb()
        chain = RetrievalQA.from_chain_type(
            self.get_llm(config),
            chain_type_kwargs={"prompt": self.RAG_CHAIN_PROMPT},
            retriever=vector_store.as_retriever(search_kwargs={"k": config["k"]}),
            return_source_documents=True,
        )
        with get_openai_callback() as callback:
            completion = chain({"query": prompt})
        return completion, callback