arslan-ahmed commited on
Commit
bde0da6
·
1 Parent(s): 12314a0

updated BAM models

Browse files
Files changed (3) hide show
  1. app.py +11 -9
  2. ttyd_consts.py +3 -4
  3. ttyd_functions.py +3 -0
app.py CHANGED
@@ -20,6 +20,7 @@ from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenP
20
  from ibm_watson_machine_learning.foundation_models.utils.enums import DecodingMethods
21
  from ibm_watson_machine_learning.foundation_models import Model
22
  from ibm_watson_machine_learning.foundation_models.extensions.langchain import WatsonxLLM
 
23
 
24
  import genai
25
 
@@ -77,12 +78,13 @@ def setOaiApiKey(creds):
77
  def setBamApiKey(creds):
78
  creds = getBamCreds(creds)
79
  try:
80
- genai.Model.models(credentials=creds['bam_creds'])
 
81
  api_key_st = creds
82
- return 'BAM credentials accepted.', *[x.update(interactive=False) for x in credComps_btn_tb], api_key_st
83
  except Exception as e:
84
  gr.Warning(str(e))
85
- return [x.update() for x in credComps_op]
86
 
87
  def setWxApiKey(key, p_id):
88
  creds = getWxCreds(key, p_id)
@@ -97,7 +99,7 @@ def setWxApiKey(key, p_id):
97
 
98
  # convert user uploaded data to vectorstore
99
  def uiData_vecStore(userFiles, userUrls, api_key_st, vsDict_st={}, progress=gr.Progress()):
100
- opComponents = [data_ingest_btn, upload_fb, urls_tb]
101
  # parse user data
102
  file_paths = []
103
  documents = []
@@ -129,7 +131,7 @@ def uiData_vecStore(userFiles, userUrls, api_key_st, vsDict_st={}, progress=gr.P
129
  src_str = str(src_str[1]) + ' source document(s) successfully loaded in vector store.'+'\n\n' + src_str[0]
130
 
131
  progress(1, 'Data loaded')
132
- return vsDict_st, src_str, *[x.update(interactive=False) for x in [data_ingest_btn, upload_fb]], urls_tb.update(interactive=False, placeholder='')
133
 
134
  # initialize chatbot function sets the QA Chain, and also sets/updates any other components to start chatting. updateQaChain function only updates QA chain and will be called whenever Adv Settings are updated.
135
  def initializeChatbot(temp, k, modelNameDD, stdlQs, api_key_st, vsDict_st, progress=gr.Progress()):
@@ -247,7 +249,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue='orange', secondary_hue='gray
247
  , placeholder=url_tb_ph)
248
  data_ingest_btn = gr.Button("Load Data")
249
  status_tb = gr.TextArea(label='Status Info')
250
- initChatbot_btn = gr.Button("Initialize Chatbot", variant="primary")
251
 
252
  credComps_btn_tb = [oaiKey_tb, oaiKey_btn, bamKey_tb, bamKey_btn, wxKey_tb, wxPid_tb, wxKey_btn]
253
  credComps_op = [status_tb] + credComps_btn_tb + [api_key_state]
@@ -266,7 +268,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue='orange', secondary_hue='gray
266
  temp_sld = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="Temperature", info='Sampling temperature to use when calling LLM. Defaults to 0.7')
267
  k_sld = gr.Slider(minimum=1, maximum=10, step=1, value=mode.k, label="K", info='Number of relavant documents to return from Vector Store. Defaults to 4')
268
  model_dd = gr.Dropdown(label='Model Name'\
269
- , choices=model_dd_choices, allow_custom_value=True\
270
  , info=model_dd_info)
271
  stdlQs_rb = gr.Radio(label='Standalone Question', info=stdlQs_rb_info\
272
  , type='index', value=stdlQs_rb_choices[1]\
@@ -280,7 +282,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue='orange', secondary_hue='gray
280
  oaiKey_tb.submit(**oaiKey_btn_args)
281
 
282
  # BAM API button
283
- bamKey_btn_args = {'fn':setBamApiKey, 'inputs':[bamKey_tb], 'outputs':credComps_op}
284
  bamKey_btn.click(**bamKey_btn_args)
285
  bamKey_tb.submit(**bamKey_btn_args)
286
 
@@ -289,7 +291,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue='orange', secondary_hue='gray
289
  wxKey_btn.click(**wxKey_btn_args)
290
 
291
  # Data Ingest Button
292
- data_ingest_event = data_ingest_btn.click(uiData_vecStore, [upload_fb, urls_tb, api_key_state, chromaVS_state], [chromaVS_state, status_tb, data_ingest_btn, upload_fb, urls_tb])
293
 
294
  # Adv Settings
295
  advSet_args = {'fn':updateQaChain, 'inputs':[temp_sld, k_sld, model_dd, stdlQs_rb, api_key_state, chromaVS_state], 'outputs':[qa_state, model_dd]}
 
20
  from ibm_watson_machine_learning.foundation_models.utils.enums import DecodingMethods
21
  from ibm_watson_machine_learning.foundation_models import Model
22
  from ibm_watson_machine_learning.foundation_models.extensions.langchain import WatsonxLLM
23
+ from ibm_watson_machine_learning.foundation_models.utils.enums import ModelTypes
24
 
25
  import genai
26
 
 
78
  def setBamApiKey(creds):
79
  creds = getBamCreds(creds)
80
  try:
81
+ bam_models = genai.Model.models(credentials=creds['bam_creds'])
82
+ bam_models = sorted(x.id for x in bam_models)
83
  api_key_st = creds
84
+ return 'BAM credentials accepted.', *[x.update(interactive=False) for x in credComps_btn_tb], api_key_st, model_dd.update(choices=getModelChoices(openAi_models, ModelTypes, bam_models))
85
  except Exception as e:
86
  gr.Warning(str(e))
87
+ return *[x.update() for x in credComps_op], model_dd.update()
88
 
89
  def setWxApiKey(key, p_id):
90
  creds = getWxCreds(key, p_id)
 
99
 
100
  # convert user uploaded data to vectorstore
101
  def uiData_vecStore(userFiles, userUrls, api_key_st, vsDict_st={}, progress=gr.Progress()):
102
+ opComponents = [data_ingest_btn, upload_fb, urls_tb, initChatbot_btn]
103
  # parse user data
104
  file_paths = []
105
  documents = []
 
131
  src_str = str(src_str[1]) + ' source document(s) successfully loaded in vector store.'+'\n\n' + src_str[0]
132
 
133
  progress(1, 'Data loaded')
134
+ return vsDict_st, src_str, *[x.update(interactive=False) for x in [data_ingest_btn, upload_fb]], urls_tb.update(interactive=False, placeholder=''), initChatbot_btn.update(interactive=True)
135
 
136
  # initialize chatbot function sets the QA Chain, and also sets/updates any other components to start chatting. updateQaChain function only updates QA chain and will be called whenever Adv Settings are updated.
137
  def initializeChatbot(temp, k, modelNameDD, stdlQs, api_key_st, vsDict_st, progress=gr.Progress()):
 
249
  , placeholder=url_tb_ph)
250
  data_ingest_btn = gr.Button("Load Data")
251
  status_tb = gr.TextArea(label='Status Info')
252
+ initChatbot_btn = gr.Button("Initialize Chatbot", variant="primary", interactive=False)
253
 
254
  credComps_btn_tb = [oaiKey_tb, oaiKey_btn, bamKey_tb, bamKey_btn, wxKey_tb, wxPid_tb, wxKey_btn]
255
  credComps_op = [status_tb] + credComps_btn_tb + [api_key_state]
 
268
  temp_sld = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="Temperature", info='Sampling temperature to use when calling LLM. Defaults to 0.7')
269
  k_sld = gr.Slider(minimum=1, maximum=10, step=1, value=mode.k, label="K", info='Number of relavant documents to return from Vector Store. Defaults to 4')
270
  model_dd = gr.Dropdown(label='Model Name'\
271
+ , choices=getModelChoices(openAi_models, ModelTypes, bam_models_old), allow_custom_value=True\
272
  , info=model_dd_info)
273
  stdlQs_rb = gr.Radio(label='Standalone Question', info=stdlQs_rb_info\
274
  , type='index', value=stdlQs_rb_choices[1]\
 
282
  oaiKey_tb.submit(**oaiKey_btn_args)
283
 
284
  # BAM API button
285
+ bamKey_btn_args = {'fn':setBamApiKey, 'inputs':[bamKey_tb], 'outputs':credComps_op+[model_dd]}
286
  bamKey_btn.click(**bamKey_btn_args)
287
  bamKey_tb.submit(**bamKey_btn_args)
288
 
 
291
  wxKey_btn.click(**wxKey_btn_args)
292
 
293
  # Data Ingest Button
294
+ data_ingest_event = data_ingest_btn.click(uiData_vecStore, [upload_fb, urls_tb, api_key_state, chromaVS_state], [chromaVS_state, status_tb, data_ingest_btn, upload_fb, urls_tb, initChatbot_btn])
295
 
296
  # Adv Settings
297
  advSet_args = {'fn':updateQaChain, 'inputs':[temp_sld, k_sld, model_dd, stdlQs_rb, api_key_state, chromaVS_state], 'outputs':[qa_state, model_dd]}
ttyd_consts.py CHANGED
@@ -1,5 +1,4 @@
1
  from langchain import PromptTemplate
2
- from ibm_watson_machine_learning.foundation_models.utils.enums import ModelTypes
3
  import os
4
  from dotenv import load_dotenv
5
  load_dotenv()
@@ -45,7 +44,7 @@ Question: {question} [/INST]
45
 
46
  promptLlama=PromptTemplate(input_variables=['context', 'question'], template=llamaPromptTemplate)
47
 
48
- bam_models = sorted(['bigscience/bloom',
49
  'salesforce/codegen2-16b',
50
  'codellama/codellama-34b-instruct',
51
  'tiiuae/falcon-40b',
@@ -71,9 +70,9 @@ bam_models = sorted(['bigscience/bloom',
71
  'bigcode/starcoder',
72
  'google/ul2'])
73
 
74
- model_dd_info = 'Make sure your credentials are submitted before changing the model. You can also input any OpenAI model name or Watsonx/BAM model ID.'
75
 
76
- model_dd_choices = ['gpt-3.5-turbo (openai)', 'gpt-3.5-turbo-16k (openai)', 'gpt-4 (openai)', 'text-davinci-003 (Legacy - openai)', 'text-curie-001 (Legacy - openai)', 'babbage-002 (openai)'] + [model.value+' (watsonx)' for model in ModelTypes] + [model + ' (bam)' for model in bam_models]
77
 
78
 
79
  OaiDefaultModel = 'gpt-3.5-turbo (openai)'
 
1
  from langchain import PromptTemplate
 
2
  import os
3
  from dotenv import load_dotenv
4
  load_dotenv()
 
44
 
45
  promptLlama=PromptTemplate(input_variables=['context', 'question'], template=llamaPromptTemplate)
46
 
47
+ bam_models_old = sorted(['bigscience/bloom',
48
  'salesforce/codegen2-16b',
49
  'codellama/codellama-34b-instruct',
50
  'tiiuae/falcon-40b',
 
70
  'bigcode/starcoder',
71
  'google/ul2'])
72
 
73
+ openAi_models = ['gpt-3.5-turbo (openai)', 'gpt-3.5-turbo-16k (openai)', 'gpt-4 (openai)', 'text-davinci-003 (Legacy - openai)', 'text-curie-001 (Legacy - openai)', 'babbage-002 (openai)']
74
 
75
+ model_dd_info = 'Make sure your credentials are submitted before changing the model. You can also input any OpenAI model name or Watsonx/BAM model ID.'
76
 
77
 
78
  OaiDefaultModel = 'gpt-3.5-turbo (openai)'
ttyd_functions.py CHANGED
@@ -372,3 +372,6 @@ def changeModel(oldModel, newModel):
372
  gr.Warning(warning)
373
  time.sleep(1)
374
  return newModel
 
 
 
 
372
  gr.Warning(warning)
373
  time.sleep(1)
374
  return newModel
375
+
376
+ def getModelChoices(openAi_models, wml_models, bam_models):
377
+ return [model for model in openAi_models] + [model.value+' (watsonx)' for model in wml_models] + [model + ' (bam)' for model in bam_models]