yonikremer committed
Commit: 6e4f775 · Parent(s): 474b6f1

limited the choice of models to just 2

Files changed:
- app.py (+5 -8)
- available_models.py (+4 -0)
- on_server_start.py (+5 -8)
app.py CHANGED
@@ -7,6 +7,7 @@ In the demo, the user can write a prompt
 import streamlit as st
 from torch.cuda import CudaError
 
+from available_models import AVAILABLE_MODELS
 from hanlde_form_submit import on_form_submit
 from on_server_start import main as on_server_start_main
 
@@ -15,12 +16,10 @@ on_server_start_main()
 st.title("Grouped Sampling Demo")
 
 with st.form("request_form"):
-    selected_model_name: str = st.
-        label="
-
-        help="
-        "Supported models are all the models in:"
-        " https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch",
+    selected_model_name: str = st.selectbox(
+        label="Select a model",
+        options=AVAILABLE_MODELS,
+        help="opt-iml-max-30b generates better texts but is slower",
     )
 
     output_length: int = st.number_input(
@@ -70,5 +69,3 @@ with st.form("request_form"):
 with open("user_instructions_hebrew.md", "r") as fh:
     long_description = fh.read()
     st.markdown(long_description)
-
-await on_server_start_main()
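For context, the replacement widget is the stock st.selectbox, which can only ever return one of the supplied options, so free-text model names no longer need validating. A minimal standalone sketch of the pattern this commit lands (the form body is trimmed to the model picker, and the submit handling here is illustrative, not the app's actual on_form_submit):

import streamlit as st

# Mirrors the new available_models.py from this commit
AVAILABLE_MODELS = (
    "facebook/opt-iml-max-1.3b",
    "facebook/opt-iml-max-30b",
)

with st.form("request_form"):
    # selectbox can only return a member of AVAILABLE_MODELS
    selected_model_name: str = st.selectbox(
        label="Select a model",
        options=AVAILABLE_MODELS,
        help="opt-iml-max-30b generates better texts but is slower",
    )
    submitted: bool = st.form_submit_button("Submit")

if submitted:
    st.write(f"You picked: {selected_model_name}")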
available_models.py ADDED
@@ -0,0 +1,4 @@
+AVAILABLE_MODELS = (
+    "facebook/opt-iml-max-1.3b",
+    "facebook/opt-iml-max-30b",
+)
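Design note: pulling the model list into its own module gives app.py and on_server_start.py a single source of truth, so the options shown in the form and the models pre-downloaded at server start cannot drift apart; using a tuple rather than a list signals the set is fixed at runtime.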
on_server_start.py CHANGED
@@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
 from transformers import logging as transformers_logging
 from huggingface_hub import logging as huggingface_hub_logging
 
+from available_models import AVAILABLE_MODELS
 from download_repo import download_pytorch_model
 
 
@@ -22,17 +23,13 @@ def download_useful_models():
     Downloads the models that are useful for this project.
     So that the user doesn't have to wait for the models to download when they first use the app.
     """
-    print("Downloading useful models...")
-    useful_models = (
-        "facebook/opt-125m",
-        "facebook/opt-iml-max-30b",
-    )
+    print("Downloading useful models. It might take a while...")
     with ThreadPoolExecutor() as executor:
-        executor.map(download_pytorch_model, useful_models)
+        executor.map(download_pytorch_model, AVAILABLE_MODELS)
 
 
-async def main():
-    disable_progress_bar()
+def main():
+    # disable_progress_bar()
     download_useful_models()
 
 
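The download fan-out is plain concurrent.futures usage: ThreadPoolExecutor.map schedules one download_pytorch_model call per model, and leaving the with block waits for all of them to finish. A self-contained sketch under the assumption that download_pytorch_model (the real helper lives in download_repo.py) takes a single model name; the stub below is hypothetical:

from concurrent.futures import ThreadPoolExecutor

AVAILABLE_MODELS = (
    "facebook/opt-iml-max-1.3b",
    "facebook/opt-iml-max-30b",
)


def download_pytorch_model(model_name: str) -> None:
    # Hypothetical stand-in for the helper imported from download_repo;
    # the real one fetches the model weights from the Hugging Face Hub.
    print(f"downloading {model_name}...")


def download_useful_models() -> None:
    # One worker thread per model, so the downloads overlap
    # instead of running back to back.
    with ThreadPoolExecutor() as executor:
        executor.map(download_pytorch_model, AVAILABLE_MODELS)


def main():
    download_useful_models()


if __name__ == "__main__":
    main()

One caveat worth knowing: executor.map returns a lazy iterator and re-raises worker exceptions only when its results are consumed, so a download that fails in this code path would pass silently.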