ShawnAI commited on
Commit
cca37b5
·
1 Parent(s): 1929137

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -50
app.py CHANGED
@@ -2,6 +2,8 @@ import gradio as gr
2
  import random
3
  import time
4
 
 
 
5
  from langchain.chat_models import ChatOpenAI
6
  from langchain.embeddings import HuggingFaceEmbeddings
7
  from langchain.vectorstores import Pinecone
@@ -15,7 +17,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
15
 
16
  #OPENAI_API_KEY = ""
17
  OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
18
- OPENAI_TEMP = 0
19
 
20
  PINECONE_KEY = os.environ.get("PINECONE_KEY", "")
21
  PINECONE_ENV = os.environ.get("PINECONE_ENV", "asia-northeast1-gcp")
@@ -28,7 +30,7 @@ TOP_K_DEFAULT = 10
28
  TOP_K_MAX = 25
29
 
30
 
31
- BUTTON_MIN_WIDTH = 180
32
 
33
  STATUS_NOK = "404-MODEL UNREADY-critical"
34
  STATUS_OK = "200-MODEL LOADED-9cf"
@@ -57,10 +59,19 @@ MODEL_WARNING = f"Please paste your OpenAI API Key from \
57
  [openai.com](https://platform.openai.com/account/api-keys) and then **{KEY_INIT}**"
58
 
59
 
60
- TAB_1 = "3GPP Chatbot"
61
 
62
  FAVICON = './icon.svg'
63
 
 
 
 
 
 
 
 
 
 
64
  webui_title = """
65
  # OpenAI Chatbot Based on Vector Database
66
  ## Example of 3GPP
@@ -91,20 +102,41 @@ def init_model(api_key, emb_name, db_api_key, db_env, db_index):
91
 
92
  #llm = OpenAI(temperature=OPENAI_TEMP, model_name="gpt-3.5-turbo-0301")
93
 
94
- llm = ChatOpenAI(temperature = OPENAI_TEMP,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  openai_api_key = api_key)
 
 
 
96
 
97
- chain = load_qa_chain(llm, chain_type="stuff")
98
-
 
 
 
 
99
  db = Pinecone.from_existing_index(index_name = db_index,
100
  embedding = embeddings)
101
 
102
- return api_key, MODEL_DONE, chain, db, None
103
  else:
104
- return None,MODEL_NULL,None,None,None
105
  except Exception as e:
106
  print(e)
107
- return None,MODEL_NULL,None,None,None
108
 
109
 
110
  def get_chat_history(inputs) -> str:
@@ -131,14 +163,16 @@ def doc_similarity(query, db, top_k):
131
  def user(user_message, history):
132
  return "", history+[[user_message, None]]
133
 
134
- def bot(box_message, ref_message, chain, db, top_k):
 
 
135
 
136
  # bot_message = random.choice(["Yes", "No"])
137
  # 0 is user question, 1 is bot response
138
  question = box_message[-1][0]
139
  history = box_message[:-1]
140
 
141
- if (not chain) or (not db):
142
  box_message[-1][1] = MODEL_WARNING
143
  return box_message, "", ""
144
 
@@ -149,17 +183,30 @@ def bot(box_message, ref_message, chain, db, top_k):
149
  details = f"Q: {question}\nR: {ref_message}"
150
 
151
 
152
- docs = doc_similarity(ref_message, db, top_k)
153
 
154
- delta_top_k = top_k - len(docs)
 
155
 
156
- if delta_top_k > 0:
157
- docs = doc_similarity(ref_message, db, top_k+delta_top_k)
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  all_output = chain({"input_documents": docs,
160
  "question": question,
161
  "chat_history": get_chat_history(history)})
162
-
163
  bot_message = all_output['output_text']
164
 
165
 
@@ -171,7 +218,7 @@ def bot(box_message, ref_message, chain, db, top_k):
171
  #print(source)
172
 
173
  box_message[-1][1] = bot_message
174
- return box_message, "", [[details, source]]
175
 
176
  #----------------------------------------------------------------------------------------------------------
177
  #----------------------------------------------------------------------------------------------------------
@@ -180,10 +227,11 @@ with gr.Blocks(
180
  title = TAB_1,
181
  theme = "Base",
182
  css = """.bigbox {
183
- min-height:200px;
184
  }
185
  """) as demo:
186
- llm_chain = gr.State()
 
187
  vector_db = gr.State()
188
  gr.Markdown(webui_title)
189
  gr.HTML(dup_link)
@@ -208,13 +256,26 @@ with gr.Blocks(
208
  with gr.Row():
209
  with gr.Column(scale=10):
210
  chatbot = gr.Chatbot(elem_classes="bigbox")
211
-
 
 
 
 
 
 
 
 
 
 
 
212
  with gr.Row():
213
  with gr.Column(scale=10):
214
  query = gr.Textbox(label="Question:",
215
  lines=2)
216
  ref = gr.Textbox(label="Reference(optional):")
 
217
  with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
 
218
  clear = gr.Button(KEY_CLEAR)
219
  submit = gr.Button(KEY_SUBMIT,variant="primary")
220
 
@@ -238,35 +299,38 @@ with gr.Blocks(
238
  lines=1,
239
  interactive=True,
240
  type='email')
241
- with gr.Row():
242
- db_api_textbox = gr.Textbox(
243
- label = "Pinecone API Key",
244
- # show_label = False,
245
- value = PINECONE_KEY,
246
- placeholder = "Paste Your Pinecone API Key (xx-xx-xx-xx-xx) and Hit ENTER",
247
- lines=1,
248
- interactive=True,
249
- type='password')
250
- with gr.Row():
251
- db_env_textbox = gr.Textbox(
252
- label = "Pinecone Environment",
253
- # show_label = False,
254
- value = PINECONE_ENV,
255
- placeholder = "Paste Your Pinecone Environment (xx-xx-xx) and Hit ENTER",
256
- lines=1,
257
- interactive=True,
258
- type='email')
259
- db_index_textbox = gr.Textbox(
260
- label = "Pinecone Index",
261
- # show_label = False,
262
- value = PINECONE_INDEX,
263
- placeholder = "Paste Your Pinecone Index (xxxx) and Hit ENTER",
264
- lines=1,
265
- interactive=True,
266
- type='email')
267
-
268
- init_input = [llm_api_textbox, emb_textbox, db_api_textbox, db_env_textbox, db_index_textbox]
269
- init_output = [llm_api_textbox, model_statusbox, llm_chain, vector_db, chatbot]
 
 
 
270
 
271
  llm_api_textbox.submit(init_model, init_input, init_output)
272
  init.click(init_model, init_input, init_output)
@@ -276,7 +340,9 @@ with gr.Blocks(
276
  [query, chatbot],
277
  queue=False).then(
278
  bot,
279
- [chatbot, ref, llm_chain, vector_db, top_k],
 
 
280
  [chatbot, ref, detail_panel]
281
  )
282
 
 
2
  import random
3
  import time
4
 
5
+ from langchain import PromptTemplate
6
+ from langchain.llms import OpenAI
7
  from langchain.chat_models import ChatOpenAI
8
  from langchain.embeddings import HuggingFaceEmbeddings
9
  from langchain.vectorstores import Pinecone
 
17
 
18
  #OPENAI_API_KEY = ""
19
  OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
20
+ OPENAI_TEMP = 1
21
 
22
  PINECONE_KEY = os.environ.get("PINECONE_KEY", "")
23
  PINECONE_ENV = os.environ.get("PINECONE_ENV", "asia-northeast1-gcp")
 
30
  TOP_K_MAX = 25
31
 
32
 
33
+ BUTTON_MIN_WIDTH = 205
34
 
35
  STATUS_NOK = "404-MODEL UNREADY-critical"
36
  STATUS_OK = "200-MODEL LOADED-9cf"
 
59
  [openai.com](https://platform.openai.com/account/api-keys) and then **{KEY_INIT}**"
60
 
61
 
62
+ TAB_1 = "Chatbot"
63
 
64
  FAVICON = './icon.svg'
65
 
66
+ LLM_LIST = ["gpt-3.5-turbo", "text-davinci-003"]
67
+
68
+
69
+ DOC_1 = '3GPP'
70
+ DOC_2 = 'HTTP2'
71
+
72
+ DOC_SUPPORTED = [DOC_1, DOC_2]
73
+ DOC_DEFAULT = [DOC_1]
74
+
75
  webui_title = """
76
  # OpenAI Chatbot Based on Vector Database
77
  ## Example of 3GPP
 
102
 
103
  #llm = OpenAI(temperature=OPENAI_TEMP, model_name="gpt-3.5-turbo-0301")
104
 
105
+
106
+ llm_dict = {}
107
+ for llm_name in LLM_LIST:
108
+ if llm_name == "gpt-3.5-turbo":
109
+ llm_dict[llm_name] = ChatOpenAI(model_name=llm_name,
110
+ temperature = OPENAI_TEMP,
111
+ openai_api_key = api_key)
112
+ else:
113
+ llm_dict[llm_name] = OpenAI(model_name=llm_name,
114
+ temperature = OPENAI_TEMP,
115
+ openai_api_key = api_key)
116
+
117
+ '''
118
+ ChatOpenAI(model_name="gpt-3.5-turbo",
119
+ temperature = OPENAI_TEMP,
120
  openai_api_key = api_key)
121
+ chain_1 = load_qa_chain(llm, chain_type="stuff")
122
+
123
+ #LLMChain(llm=llm, prompt=condense_question_prompt)
124
 
125
+ chain_2 = LLMChain(llm = llm,
126
+ prompt = PromptTemplate(template='{question}',
127
+ input_variables=['question']),
128
+ output_key = 'output_text')
129
+ '''
130
+
131
  db = Pinecone.from_existing_index(index_name = db_index,
132
  embedding = embeddings)
133
 
134
+ return api_key, MODEL_DONE, llm_dict, None, db, None
135
  else:
136
+ return None,MODEL_NULL,None,None,None,None
137
  except Exception as e:
138
  print(e)
139
+ return None,MODEL_NULL,None,None,None,None
140
 
141
 
142
  def get_chat_history(inputs) -> str:
 
163
  def user(user_message, history):
164
  return "", history+[[user_message, None]]
165
 
166
+ def bot(box_message, ref_message,
167
+ llm_dropdown, llm_dict, doc_list,
168
+ db, top_k):
169
 
170
  # bot_message = random.choice(["Yes", "No"])
171
  # 0 is user question, 1 is bot response
172
  question = box_message[-1][0]
173
  history = box_message[:-1]
174
 
175
+ if (not llm_dict) or (not doc_check) or (not db):
176
  box_message[-1][1] = MODEL_WARNING
177
  return box_message, "", ""
178
 
 
183
  details = f"Q: {question}\nR: {ref_message}"
184
 
185
 
186
+ llm = llm_dict[llm_dropdown]
187
 
188
+ print(llm)
189
+ print(doc_list)
190
 
191
+ if DOC_1 in doc_list:
192
+ chain = load_qa_chain(llm, chain_type="stuff")
193
+ docs = doc_similarity(ref_message, db, top_k)
194
+ delta_top_k = top_k - len(docs)
195
+
196
+ if delta_top_k > 0:
197
+ docs = doc_similarity(ref_message, db, top_k+delta_top_k)
198
+
199
+ else:
200
+ chain = LLMChain(llm = llm,
201
+ prompt = PromptTemplate(template='{question}',
202
+ input_variables=['question']),
203
+ output_key = 'output_text')
204
+ docs = []
205
 
206
  all_output = chain({"input_documents": docs,
207
  "question": question,
208
  "chat_history": get_chat_history(history)})
209
+
210
  bot_message = all_output['output_text']
211
 
212
 
 
218
  #print(source)
219
 
220
  box_message[-1][1] = bot_message
221
+ return box_message, "", [[details, bot_message + source]]
222
 
223
  #----------------------------------------------------------------------------------------------------------
224
  #----------------------------------------------------------------------------------------------------------
 
227
  title = TAB_1,
228
  theme = "Base",
229
  css = """.bigbox {
230
+ min-height:250px;
231
  }
232
  """) as demo:
233
+ llm = gr.State()
234
+ chain_2 = gr.State() # not inuse
235
  vector_db = gr.State()
236
  gr.Markdown(webui_title)
237
  gr.HTML(dup_link)
 
256
  with gr.Row():
257
  with gr.Column(scale=10):
258
  chatbot = gr.Chatbot(elem_classes="bigbox")
259
+ #with gr.Column(scale=1):
260
+ with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
261
+ doc_check = gr.CheckboxGroup(choices = DOC_SUPPORTED,
262
+ value = DOC_DEFAULT,
263
+ label = "Reference Docs",
264
+ interactive=True)
265
+ llm_dropdown = gr.Dropdown(LLM_LIST,
266
+ value=LLM_LIST[0],
267
+ multiselect=False,
268
+ interactive=True,
269
+ label="LLM Selection",
270
+ )
271
  with gr.Row():
272
  with gr.Column(scale=10):
273
  query = gr.Textbox(label="Question:",
274
  lines=2)
275
  ref = gr.Textbox(label="Reference(optional):")
276
+
277
  with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
278
+
279
  clear = gr.Button(KEY_CLEAR)
280
  submit = gr.Button(KEY_SUBMIT,variant="primary")
281
 
 
299
  lines=1,
300
  interactive=True,
301
  type='email')
302
+ with gr.Accordion("Pinecone Database for "+DOC_1):
303
+ with gr.Row():
304
+ db_api_textbox = gr.Textbox(
305
+ label = "Pinecone API Key",
306
+ # show_label = False,
307
+ value = PINECONE_KEY,
308
+ placeholder = "Paste Your Pinecone API Key (xx-xx-xx-xx-xx) and Hit ENTER",
309
+ lines=1,
310
+ interactive=True,
311
+ type='password')
312
+ with gr.Row():
313
+ db_env_textbox = gr.Textbox(
314
+ label = "Pinecone Environment",
315
+ # show_label = False,
316
+ value = PINECONE_ENV,
317
+ placeholder = "Paste Your Pinecone Environment (xx-xx-xx) and Hit ENTER",
318
+ lines=1,
319
+ interactive=True,
320
+ type='email')
321
+ db_index_textbox = gr.Textbox(
322
+ label = "Pinecone Index",
323
+ # show_label = False,
324
+ value = PINECONE_INDEX,
325
+ placeholder = "Paste Your Pinecone Index (xxxx) and Hit ENTER",
326
+ lines=1,
327
+ interactive=True,
328
+ type='email')
329
+
330
+ init_input = [llm_api_textbox, emb_textbox, db_api_textbox, db_env_textbox, db_index_textbox]
331
+ init_output = [llm_api_textbox, model_statusbox,
332
+ llm, chain_2,
333
+ vector_db, chatbot]
334
 
335
  llm_api_textbox.submit(init_model, init_input, init_output)
336
  init.click(init_model, init_input, init_output)
 
340
  [query, chatbot],
341
  queue=False).then(
342
  bot,
343
+ [chatbot, ref,
344
+ llm_dropdown, llm, doc_check,
345
+ vector_db, top_k],
346
  [chatbot, ref, detail_panel]
347
  )
348