yonikremer committed on
Commit
6e4f775
·
1 Parent(s): 474b6f1

limited the choice of models to just 2

Browse files
Files changed (3) hide show
  1. app.py +5 -8
  2. available_models.py +4 -0
  3. on_server_start.py +5 -8
app.py CHANGED
@@ -7,6 +7,7 @@ In the demo, the user can write a prompt
7
  import streamlit as st
8
  from torch.cuda import CudaError
9
 
 
10
  from hanlde_form_submit import on_form_submit
11
  from on_server_start import main as on_server_start_main
12
 
@@ -15,12 +16,10 @@ on_server_start_main()
15
  st.title("Grouped Sampling Demo")
16
 
17
  with st.form("request_form"):
18
- selected_model_name: str = st.text_input(
19
- label="Model name",
20
- value="gpt2",
21
- help="The name of the model to use."
22
- "Supported models are all the models in:"
23
- " https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch",
24
  )
25
 
26
  output_length: int = st.number_input(
@@ -70,5 +69,3 @@ with st.form("request_form"):
70
  with open("user_instructions_hebrew.md", "r") as fh:
71
  long_description = fh.read()
72
  st.markdown(long_description)
73
-
74
- await on_server_start_main()
 
7
  import streamlit as st
8
  from torch.cuda import CudaError
9
 
10
+ from available_models import AVAILABLE_MODELS
11
  from hanlde_form_submit import on_form_submit
12
  from on_server_start import main as on_server_start_main
13
 
 
16
  st.title("Grouped Sampling Demo")
17
 
18
  with st.form("request_form"):
19
+ selected_model_name: str = st.selectbox(
20
+ label="Select a model",
21
+ options=AVAILABLE_MODELS,
22
+ help="opt-iml-max-30b generates better texts but is slower",
 
 
23
  )
24
 
25
  output_length: int = st.number_input(
 
69
  with open("user_instructions_hebrew.md", "r") as fh:
70
  long_description = fh.read()
71
  st.markdown(long_description)
 
 
available_models.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ AVAILABLE_MODELS = (
2
+ "facebook/opt-iml-max-1.3b",
3
+ "facebook/opt-iml-max-30b",
4
+ )
on_server_start.py CHANGED
@@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
6
  from transformers import logging as transformers_logging
7
  from huggingface_hub import logging as huggingface_hub_logging
8
 
 
9
  from download_repo import download_pytorch_model
10
 
11
 
@@ -22,17 +23,13 @@ def download_useful_models():
22
  Downloads the models that are useful for this project.
23
  So that the user doesn't have to wait for the models to download when they first use the app.
24
  """
25
- print("Downloading useful models...")
26
- useful_models = (
27
- "facebook/opt-125m",
28
- "facebook/opt-iml-max-30b",
29
- )
30
  with ThreadPoolExecutor() as executor:
31
- executor.map(download_pytorch_model, useful_models)
32
 
33
 
34
- async def main():
35
- disable_progress_bar()
36
  download_useful_models()
37
 
38
 
 
6
  from transformers import logging as transformers_logging
7
  from huggingface_hub import logging as huggingface_hub_logging
8
 
9
+ from available_models import AVAILABLE_MODELS
10
  from download_repo import download_pytorch_model
11
 
12
 
 
23
  Downloads the models that are useful for this project.
24
  So that the user doesn't have to wait for the models to download when they first use the app.
25
  """
26
+ print("Downloading useful models. It might take a while...")
 
 
 
 
27
  with ThreadPoolExecutor() as executor:
28
+ executor.map(download_pytorch_model, AVAILABLE_MODELS)
29
 
30
 
31
+ def main():
32
+ # disable_progress_bar()
33
  download_useful_models()
34
 
35