yonikremer
commited on
Commit
路
df273ff
1
Parent(s):
25de4e0
downloading models at the start of the app and not at usage time
Browse files- app.py +28 -1
- hanlde_form_submit.py +3 -41
app.py
CHANGED
@@ -4,8 +4,10 @@ In the demo, the user can write a prompt
|
|
4 |
and the model will generate a response using the grouped sampling algorithm.
|
5 |
"""
|
6 |
import os
|
|
|
7 |
|
8 |
import streamlit as st
|
|
|
9 |
from torch.cuda import CudaError
|
10 |
from huggingface_hub import logging as hf_hub_logging
|
11 |
|
@@ -13,6 +15,27 @@ from available_models import AVAILABLE_MODELS
|
|
13 |
from hanlde_form_submit import on_form_submit
|
14 |
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
hf_hub_logging.set_verbosity_error()
|
17 |
|
18 |
st.set_page_config(
|
@@ -20,6 +43,10 @@ st.set_page_config(
|
|
20 |
layout="wide",
|
21 |
)
|
22 |
|
|
|
|
|
|
|
|
|
23 |
with st.form("request_form"):
|
24 |
selected_model_name: str = st.selectbox(
|
25 |
label="讘讞专讜 诪讜讚诇",
|
@@ -50,7 +77,7 @@ with st.form("request_form"):
|
|
50 |
if submitted:
|
51 |
try:
|
52 |
output = on_form_submit(
|
53 |
-
selected_model_name,
|
54 |
output_length,
|
55 |
submitted_prompt,
|
56 |
)
|
|
|
4 |
and the model will generate a response using the grouped sampling algorithm.
|
5 |
"""
|
6 |
import os
|
7 |
+
from time import time
|
8 |
|
9 |
import streamlit as st
|
10 |
+
from grouped_sampling import GroupedSamplingPipeLine
|
11 |
from torch.cuda import CudaError
|
12 |
from huggingface_hub import logging as hf_hub_logging
|
13 |
|
|
|
15 |
from hanlde_form_submit import on_form_submit
|
16 |
|
17 |
|
18 |
+
def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
|
19 |
+
"""
|
20 |
+
Creates a pipeline with the given model name and group size.
|
21 |
+
:param model_name: The name of the model to use.
|
22 |
+
:param group_size: The size of the groups to use.
|
23 |
+
:return: A pipeline with the given model name and group size.
|
24 |
+
"""
|
25 |
+
st.write(f"Starts creating pipeline with model: {model_name}")
|
26 |
+
pipeline_start_time = time()
|
27 |
+
pipeline = GroupedSamplingPipeLine(
|
28 |
+
model_name=model_name,
|
29 |
+
group_size=group_size,
|
30 |
+
end_of_sentence_stop=False,
|
31 |
+
top_k=50,
|
32 |
+
)
|
33 |
+
pipeline_end_time = time()
|
34 |
+
pipeline_time = pipeline_end_time - pipeline_start_time
|
35 |
+
st.write(f"Finished creating pipeline with model: {model_name} in {pipeline_time:,.2f} seconds.")
|
36 |
+
return pipeline
|
37 |
+
|
38 |
+
|
39 |
hf_hub_logging.set_verbosity_error()
|
40 |
|
41 |
st.set_page_config(
|
|
|
43 |
layout="wide",
|
44 |
)
|
45 |
|
46 |
+
pipelines = {
|
47 |
+
model_name: create_pipeline(model_name, 1024) for model_name in AVAILABLE_MODELS[1:]
|
48 |
+
}
|
49 |
+
|
50 |
with st.form("request_form"):
|
51 |
selected_model_name: str = st.selectbox(
|
52 |
label="讘讞专讜 诪讜讚诇",
|
|
|
77 |
if submitted:
|
78 |
try:
|
79 |
output = on_form_submit(
|
80 |
+
pipelines[selected_model_name],
|
81 |
output_length,
|
82 |
submitted_prompt,
|
83 |
)
|
hanlde_form_submit.py
CHANGED
@@ -1,28 +1,7 @@
|
|
1 |
from time import time
|
2 |
|
3 |
import streamlit as st
|
4 |
-
from grouped_sampling import GroupedSamplingPipeLine
|
5 |
-
|
6 |
-
|
7 |
-
def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
|
8 |
-
"""
|
9 |
-
Creates a pipeline with the given model name and group size.
|
10 |
-
:param model_name: The name of the model to use.
|
11 |
-
:param group_size: The size of the groups to use.
|
12 |
-
:return: A pipeline with the given model name and group size.
|
13 |
-
"""
|
14 |
-
st.write(f"Starts creating pipeline with model: {model_name}")
|
15 |
-
pipeline_start_time = time()
|
16 |
-
pipeline = GroupedSamplingPipeLine(
|
17 |
-
model_name=model_name,
|
18 |
-
group_size=group_size,
|
19 |
-
end_of_sentence_stop=False,
|
20 |
-
top_k=50,
|
21 |
-
)
|
22 |
-
pipeline_end_time = time()
|
23 |
-
pipeline_time = pipeline_end_time - pipeline_start_time
|
24 |
-
st.write(f"Finished creating pipeline with model: {model_name} in {pipeline_time:,.2f} seconds.")
|
25 |
-
return pipeline
|
26 |
|
27 |
|
28 |
def generate_text(
|
@@ -46,13 +25,13 @@ def generate_text(
|
|
46 |
|
47 |
|
48 |
def on_form_submit(
|
49 |
-
|
50 |
output_length: int,
|
51 |
prompt: str,
|
52 |
) -> str:
|
53 |
"""
|
54 |
Called when the user submits the form.
|
55 |
-
:param
|
56 |
:param output_length: The size of the groups to use.
|
57 |
:param prompt: The prompt to use.
|
58 |
:return: The output of the model.
|
@@ -64,19 +43,6 @@ def on_form_submit(
|
|
64 |
"""
|
65 |
if len(prompt) == 0:
|
66 |
raise ValueError("The prompt must not be empty.")
|
67 |
-
if not is_supported(model_name):
|
68 |
-
raise UnsupportedModelNameException(model_name)
|
69 |
-
st.write(f"Loading model: {model_name}...")
|
70 |
-
print(f"Loading model: {model_name}...")
|
71 |
-
loading_start_time = time()
|
72 |
-
pipeline = create_pipeline(
|
73 |
-
model_name=model_name,
|
74 |
-
group_size=output_length,
|
75 |
-
)
|
76 |
-
loading_end_time = time()
|
77 |
-
loading_time = loading_end_time - loading_start_time
|
78 |
-
st.write(f"Finished loading model: {model_name} in {loading_time:,.2f} seconds.")
|
79 |
-
print(f"Finished loading model: {model_name} in {loading_time:,} seconds.")
|
80 |
st.write("Generating text...")
|
81 |
print("Generating text...")
|
82 |
generation_start_time = time()
|
@@ -89,8 +55,4 @@ def on_form_submit(
|
|
89 |
generation_time = generation_end_time - generation_start_time
|
90 |
st.write(f"Finished generating text in {generation_time:,.2f} seconds.")
|
91 |
print(f"Finished generating text in {generation_time:,.2f} seconds.")
|
92 |
-
if not isinstance(generated_text, str):
|
93 |
-
raise RuntimeError(f"The model {model_name} did not generate any text.")
|
94 |
-
if len(generated_text) == 0:
|
95 |
-
raise RuntimeError(f"The model {model_name} did not generate any text.")
|
96 |
return generated_text
|
|
|
1 |
from time import time
|
2 |
|
3 |
import streamlit as st
|
4 |
+
from grouped_sampling import GroupedSamplingPipeLine
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
def generate_text(
|
|
|
25 |
|
26 |
|
27 |
def on_form_submit(
|
28 |
+
pipeline: GroupedSamplingPipeLine,
|
29 |
output_length: int,
|
30 |
prompt: str,
|
31 |
) -> str:
|
32 |
"""
|
33 |
Called when the user submits the form.
|
34 |
+
:param pipeline: The pipeline to use. GroupedSamplingPipeLine.
|
35 |
:param output_length: The size of the groups to use.
|
36 |
:param prompt: The prompt to use.
|
37 |
:return: The output of the model.
|
|
|
43 |
"""
|
44 |
if len(prompt) == 0:
|
45 |
raise ValueError("The prompt must not be empty.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
st.write("Generating text...")
|
47 |
print("Generating text...")
|
48 |
generation_start_time = time()
|
|
|
55 |
generation_time = generation_end_time - generation_start_time
|
56 |
st.write(f"Finished generating text in {generation_time:,.2f} seconds.")
|
57 |
print(f"Finished generating text in {generation_time:,.2f} seconds.")
|
|
|
|
|
|
|
|
|
58 |
return generated_text
|