alexkueck commited on
Commit
063f2a2
·
1 Parent(s): af937f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -4
app.py CHANGED
@@ -1,10 +1,177 @@
1
  import os, sys, json
2
- from openai import OpenAI
3
  import gradio as gr
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  # Schnittstellen hinzubinden und OpenAI Key holen aus den Secrets
8
- client = OpenAI(
9
- api_key=os.getenv("OPENAI_API_KEY"), # this is also the default, it can be omitted
10
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os, sys, json
 
2
  import gradio as gr
3
+ import openai
4
 
5
+ from langchain.chains import LLMChain, RetrievalQA
6
+ from langchain.chat_models import ChatOpenAI
7
+ from langchain.document_loaders import PyPDFLoader, WebBaseLoader
8
+ from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader
9
+ from langchain.document_loaders.generic import GenericLoader
10
+ from langchain.document_loaders.parsers import OpenAIWhisperParser
11
+
12
+ from langchain.embeddings.openai import OpenAIEmbeddings
13
+ from langchain.prompts import PromptTemplate
14
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
15
+ from langchain.vectorstores import Chroma
16
+ from langchain.vectorstores import MongoDBAtlasVectorSearch
17
+
18
+ from pymongo import MongoClient
19
+
20
+ from dotenv import load_dotenv, find_dotenv
21
+ _ = load_dotenv(find_dotenv())
22
 
23
 
24
  # Schnittstellen hinzubinden und OpenAI Key holen aus den Secrets
25
+ #client = OpenAI(
26
+ #api_key=os.getenv("OPENAI_API_KEY"), # this is also the default, it can be omitted
27
+ #)
28
+
29
+
30
+
31
+
32
+
33
+ openai.api_key = os.getenv["OPENAI_API_KEY"]
34
+
35
+ MONGODB_URI = os.environ["MONGODB_ATLAS_CLUSTER_URI"]
36
+ client = MongoClient(MONGODB_URI)
37
+ MONGODB_DB_NAME = "langchain_db"
38
+ MONGODB_COLLECTION_NAME = "gpt-4"
39
+ MONGODB_COLLECTION = client[MONGODB_DB_NAME][MONGODB_COLLECTION_NAME]
40
+ MONGODB_INDEX_NAME = "default"
41
+
42
+ template = """If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible. Always say
43
+ "🧠 Thanks for using the app - Bernd" at the end of the answer. """
44
+
45
+ llm_template = "Answer the question at the end. " + template + "Question: {question} Helpful Answer: "
46
+ rag_template = "Use the following pieces of context to answer the question at the end. " + template + "{context} Question: {question} Helpful Answer: "
47
+
48
+ LLM_CHAIN_PROMPT = PromptTemplate(input_variables = ["question"],
49
+ template = llm_template)
50
+ RAG_CHAIN_PROMPT = PromptTemplate(input_variables = ["context", "question"],
51
+ template = rag_template)
52
+
53
+ CHROMA_DIR = "/data/chroma"
54
+ YOUTUBE_DIR = "/data/youtube"
55
+
56
+ PDF_URL = "https://arxiv.org/pdf/2303.08774.pdf"
57
+ WEB_URL = "https://openai.com/research/gpt-4"
58
+ YOUTUBE_URL_1 = "https://www.youtube.com/watch?v=--khbXchTeE"
59
+ YOUTUBE_URL_2 = "https://www.youtube.com/watch?v=hdhZwyf24mE"
60
+ YOUTUBE_URL_3 = "https://www.youtube.com/watch?v=vw-KWfKwvTQ"
61
+
62
+ MODEL_NAME = "gpt-4"
63
+
64
+ def document_loading_splitting():
65
+ # Document loading
66
+ docs = []
67
+ # Load PDF
68
+ loader = PyPDFLoader(PDF_URL)
69
+ docs.extend(loader.load())
70
+ # Load Web
71
+ loader = WebBaseLoader(WEB_URL)
72
+ docs.extend(loader.load())
73
+ # Load YouTube
74
+ loader = GenericLoader(YoutubeAudioLoader([YOUTUBE_URL_1,
75
+ YOUTUBE_URL_2,
76
+ YOUTUBE_URL_3], YOUTUBE_DIR),
77
+ OpenAIWhisperParser())
78
+ docs.extend(loader.load())
79
+ # Document splitting
80
+ text_splitter = RecursiveCharacterTextSplitter(chunk_overlap = 150,
81
+ chunk_size = 1500)
82
+ splits = text_splitter.split_documents(docs)
83
+ return splits
84
+
85
+ def document_storage_chroma(splits):
86
+ Chroma.from_documents(documents = splits,
87
+ embedding = OpenAIEmbeddings(disallowed_special = ()),
88
+ persist_directory = CHROMA_DIR)
89
+
90
+ def document_storage_mongodb(splits):
91
+ MongoDBAtlasVectorSearch.from_documents(documents = splits,
92
+ embedding = OpenAIEmbeddings(disallowed_special = ()),
93
+ collection = MONGODB_COLLECTION,
94
+ index_name = MONGODB_INDEX_NAME)
95
+
96
+ def document_retrieval_chroma(llm, prompt):
97
+ db = Chroma(embedding_function = OpenAIEmbeddings(),
98
+ persist_directory = CHROMA_DIR)
99
+ return db
100
+
101
+ def document_retrieval_mongodb(llm, prompt):
102
+ db = MongoDBAtlasVectorSearch.from_connection_string(MONGODB_URI,
103
+ MONGODB_DB_NAME + "." + MONGODB_COLLECTION_NAME,
104
+ OpenAIEmbeddings(disallowed_special = ()),
105
+ index_name = MONGODB_INDEX_NAME)
106
+ return db
107
+
108
+ def llm_chain(llm, prompt):
109
+ llm_chain = LLMChain(llm = llm, prompt = LLM_CHAIN_PROMPT)
110
+ result = llm_chain.run({"question": prompt})
111
+ return result
112
+
113
+ def rag_chain(llm, prompt, db):
114
+ rag_chain = RetrievalQA.from_chain_type(llm,
115
+ chain_type_kwargs = {"prompt": RAG_CHAIN_PROMPT},
116
+ retriever = db.as_retriever(search_kwargs = {"k": 3}),
117
+ return_source_documents = True)
118
+ result = rag_chain({"query": prompt})
119
+ return result["result"]
120
+
121
+ def invoke(openai_api_key, rag_option, prompt):
122
+ if (openai_api_key == ""):
123
+ raise gr.Error("OpenAI API Key is required.")
124
+ if (rag_option is None):
125
+ raise gr.Error("Retrieval Augmented Generation is required.")
126
+ if (prompt == ""):
127
+ raise gr.Error("Prompt is required.")
128
+ try:
129
+ llm = ChatOpenAI(model_name = MODEL_NAME,
130
+ openai_api_key = openai_api_key,
131
+ temperature = 0)
132
+ if (rag_option == "Chroma"):
133
+ #splits = document_loading_splitting()
134
+ #document_storage_chroma(splits)
135
+ db = document_retrieval_chroma(llm, prompt)
136
+ result = rag_chain(llm, prompt, db)
137
+ elif (rag_option == "MongoDB"):
138
+ #splits = document_loading_splitting()
139
+ #document_storage_mongodb(splits)
140
+ db = document_retrieval_mongodb(llm, prompt)
141
+ result = rag_chain(llm, prompt, db)
142
+ else:
143
+ result = llm_chain(llm, prompt)
144
+ except Exception as e:
145
+ raise gr.Error(e)
146
+ return result
147
+
148
+ description = """<strong>Overview:</strong> Reasoning application that demonstrates a <strong>Large Language Model (LLM)</strong> with
149
+ <strong>Retrieval Augmented Generation (RAG)</strong> on <strong>external data</strong>.\n\n
150
+ <strong>Instructions:</strong> Enter an OpenAI API key and perform LLM use cases (semantic search, summarization, translation, etc.) on
151
+ <a href='""" + YOUTUBE_URL_1 + """'>YouTube</a>, <a href='""" + PDF_URL + """'>PDF</a>, and <a href='""" + WEB_URL + """'>Web</a>
152
+ data on GPT-4, published after LLM knowledge cutoff.
153
+ <ul style="list-style-type:square;">
154
+ <li>Set "Retrieval Augmented Generation" to "<strong>Off</strong>" and submit prompt "What is GPT-4?" The <strong>LLM without RAG</strong> does not know the answer.</li>
155
+ <li>Set "Retrieval Augmented Generation" to "<strong>Chroma</strong>" or "<strong>MongoDB</strong>" and submit prompt "What is GPT-4?" The <strong>LLM with RAG</strong> knows the answer.</li>
156
+ <li>Experiment with prompts, e.g. "What are GPT-4's media capabilities in 5 emojis and 1 sentence?", "List GPT-4's exam scores and benchmark results.", or "Compare GPT-4 to GPT-3.5 in markdown table format."</li>
157
+ <li>Experiment some more, for example "What is the GPT-4 API's cost and rate limit? Answer in English, Arabic, Chinese, Hindi, and Russian in JSON format." or "Write a Python program that calls the GPT-4 API."</li>
158
+ </ul>\n\n
159
+ <strong>Technology:</strong> <a href='https://www.gradio.app/'>Gradio</a> UI using the <a href='https://openai.com/'>OpenAI</a> API and
160
+ AI-native <a href='https://www.trychroma.com/'>Chroma</a> embedding database /
161
+ <a href='https://www.mongodb.com/blog/post/introducing-atlas-vector-search-build-intelligent-applications-semantic-search-ai'>MongoDB</a> vector search.
162
+ <strong>Speech-to-text</strong> (STT) via <a href='https://openai.com/research/whisper'>whisper-1</a> model, <strong>text embedding</strong> via
163
+ <a href='https://openai.com/blog/new-and-improved-embedding-model'>text-embedding-ada-002</a> model, and <strong>text generation</strong> via
164
+ <a href='""" + WEB_URL + """'>gpt-4</a> model. Implementation via AI-first <a href='https://www.langchain.com/'>LangChain</a> toolkit.\n\n
165
+ In addition to the OpenAI API version, see also the <a href='https://aws.amazon.com/bedrock/'>Amazon Bedrock</a> API and
166
+ <a href='https://cloud.google.com/vertex-ai'>Google Vertex AI</a> API versions on
167
+ <a href='https://github.com/bstraehle/ai-ml-dl/tree/main/hugging-face'>GitHub</a>."""
168
+
169
+ gr.close_all()
170
+ demo = gr.Interface(fn=invoke,
171
+ inputs = [gr.Textbox(label = "OpenAI API Key", value = "sk-", lines = 1),
172
+ gr.Radio(["Off", "Chroma", "MongoDB"], label="Retrieval Augmented Generation", value = "Off"),
173
+ gr.Textbox(label = "Prompt", value = "What is GPT-4?", lines = 1)],
174
+ outputs = [gr.Textbox(label = "Completion", lines = 1)],
175
+ title = "Generative AI - LLM & RAG",
176
+ description = description)
177
+ demo.launch()