yonikremer
commited on
Commit
·
6dd4824
1
Parent(s):
df273ff
REREVRSE (#6)
Browse files- restoring the last working version (7068818ba6f1ea28a71928976c982c3296645b38)
- fixed yaml (f08a1cf56354e62f271279dca48706ed68051cf3)
- .dockerignore +0 -8
- .gitignore +1 -0
- .streamlit/config.toml +1 -7
- Dockerfile +0 -20
- README.md +5 -5
- app.py +7 -45
- available_models.py +2 -2
- download_repo.py +45 -0
- hanlde_form_submit.py +58 -4
- tests.py +7 -6
.dockerignore
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
tests.py
|
2 |
-
.gitattributes
|
3 |
-
.gitignore
|
4 |
-
start_server.py
|
5 |
-
.git/
|
6 |
-
.idea/
|
7 |
-
.pytest_cache/
|
8 |
-
__pycache__/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
CHANGED
@@ -5505,3 +5505,4 @@ Mercury.modules
|
|
5505 |
!/.streamlit/config.toml
|
5506 |
!/Dockerfile
|
5507 |
!/.dockerignore
|
|
|
|
5505 |
!/.streamlit/config.toml
|
5506 |
!/Dockerfile
|
5507 |
!/.dockerignore
|
5508 |
+
!/download_repo.py
|
.streamlit/config.toml
CHANGED
@@ -1,8 +1,2 @@
|
|
1 |
[browser]
|
2 |
-
gatherUsageStats = false
|
3 |
-
[server]
|
4 |
-
port = 7860
|
5 |
-
[logger]
|
6 |
-
level = "error"
|
7 |
-
[theme]
|
8 |
-
base = "dark"
|
|
|
1 |
[browser]
|
2 |
+
gatherUsageStats = false
|
|
|
|
|
|
|
|
|
|
|
|
Dockerfile
DELETED
@@ -1,20 +0,0 @@
|
|
1 |
-
FROM bitnami/pytorch
|
2 |
-
|
3 |
-
RUN mkdir --mode 777 /app/my_streamlit_app
|
4 |
-
WORKDIR /app/my_stramlit_app
|
5 |
-
|
6 |
-
COPY ./requirements.txt /app/my_streamlit_app/requirements.txt
|
7 |
-
|
8 |
-
RUN pip install --no-cache-dir -r /app/my_streamlit_app/requirements.txt
|
9 |
-
|
10 |
-
RUN mkdir --mode 777 "/app/my_streamlit_app/.cache/"
|
11 |
-
RUN mkdir --mode 777 "/app/my_streamlit_app/.cache/huggingface/"
|
12 |
-
ENV HUGGINGFACE_HUB_CACHE="/app/my_streamlit_app/.cache/huggingface"
|
13 |
-
RUN mkdir --mode 777 "/app/my_streamlit_app/.cache/transformers/"
|
14 |
-
ENV TRANSFORMERS_CACHE="/app/my_streamlit_app/.cache/transformers"
|
15 |
-
|
16 |
-
ENV TOKENIZERS_PARALLELISM=false
|
17 |
-
|
18 |
-
COPY . /app/my_streamlit_app/
|
19 |
-
|
20 |
-
CMD ["streamlit", "run", "--server.port", "7860", "--server.enableCORS", "false", "--server.enableXsrfProtection", "false", "--browser.gatherUsageStats", "false", "--theme.base", "dark", "--server.maxUploadSize", "1000", "/app/my_streamlit_app/app.py"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
@@ -3,12 +3,12 @@ title: Grouped Sampling Demo
|
|
3 |
emoji: 🐠
|
4 |
colorFrom: pink
|
5 |
colorTo: purple
|
6 |
-
sdk:
|
7 |
-
|
8 |
-
|
|
|
9 |
fullWidth: true
|
10 |
-
|
11 |
-
tags: [text-generation, pytorch, transformers, streamlit, docker]
|
12 |
---
|
13 |
|
14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
3 |
emoji: 🐠
|
4 |
colorFrom: pink
|
5 |
colorTo: purple
|
6 |
+
sdk: streamlit
|
7 |
+
sdk_version: 1.17.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
fullWidth: true
|
11 |
+
tags: [text-generation, pytorch, transformers, streamlit]
|
|
|
12 |
---
|
13 |
|
14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
@@ -3,61 +3,27 @@ The Streamlit app for the project demo.
|
|
3 |
In the demo, the user can write a prompt
|
4 |
and the model will generate a response using the grouped sampling algorithm.
|
5 |
"""
|
6 |
-
import os
|
7 |
-
from time import time
|
8 |
|
9 |
import streamlit as st
|
10 |
-
from grouped_sampling import GroupedSamplingPipeLine
|
11 |
from torch.cuda import CudaError
|
12 |
-
from huggingface_hub import logging as hf_hub_logging
|
13 |
|
14 |
from available_models import AVAILABLE_MODELS
|
15 |
from hanlde_form_submit import on_form_submit
|
16 |
|
17 |
|
18 |
-
|
19 |
-
"""
|
20 |
-
Creates a pipeline with the given model name and group size.
|
21 |
-
:param model_name: The name of the model to use.
|
22 |
-
:param group_size: The size of the groups to use.
|
23 |
-
:return: A pipeline with the given model name and group size.
|
24 |
-
"""
|
25 |
-
st.write(f"Starts creating pipeline with model: {model_name}")
|
26 |
-
pipeline_start_time = time()
|
27 |
-
pipeline = GroupedSamplingPipeLine(
|
28 |
-
model_name=model_name,
|
29 |
-
group_size=group_size,
|
30 |
-
end_of_sentence_stop=False,
|
31 |
-
top_k=50,
|
32 |
-
)
|
33 |
-
pipeline_end_time = time()
|
34 |
-
pipeline_time = pipeline_end_time - pipeline_start_time
|
35 |
-
st.write(f"Finished creating pipeline with model: {model_name} in {pipeline_time:,.2f} seconds.")
|
36 |
-
return pipeline
|
37 |
-
|
38 |
-
|
39 |
-
hf_hub_logging.set_verbosity_error()
|
40 |
-
|
41 |
-
st.set_page_config(
|
42 |
-
page_title="דגימה בקבוצות - שימוש יעיל במודלי שפה סיבתיים",
|
43 |
-
layout="wide",
|
44 |
-
)
|
45 |
-
|
46 |
-
pipelines = {
|
47 |
-
model_name: create_pipeline(model_name, 1024) for model_name in AVAILABLE_MODELS[1:]
|
48 |
-
}
|
49 |
|
50 |
with st.form("request_form"):
|
51 |
selected_model_name: str = st.selectbox(
|
52 |
label="בחרו מודל",
|
53 |
options=AVAILABLE_MODELS,
|
54 |
-
help="
|
55 |
)
|
56 |
|
57 |
output_length: int = st.number_input(
|
58 |
-
label="כמות המילים המקסימלית בפלט - בין 1 ל-
|
59 |
min_value=1,
|
60 |
-
max_value=
|
61 |
value=5,
|
62 |
)
|
63 |
|
@@ -65,7 +31,7 @@ with st.form("request_form"):
|
|
65 |
label="הקלט לאלוגריתם (באנגלית בלבד)",
|
66 |
value="Instruction: Answer in yes or no.\n"
|
67 |
"Question: Is the sky blue?\n"
|
68 |
-
"Answer:",
|
69 |
max_chars=2048,
|
70 |
)
|
71 |
|
@@ -77,7 +43,7 @@ with st.form("request_form"):
|
|
77 |
if submitted:
|
78 |
try:
|
79 |
output = on_form_submit(
|
80 |
-
|
81 |
output_length,
|
82 |
submitted_prompt,
|
83 |
)
|
@@ -89,10 +55,6 @@ with st.form("request_form"):
|
|
89 |
st.write(f"Generated text: {output}")
|
90 |
|
91 |
|
92 |
-
|
93 |
-
os.path.dirname(__file__),
|
94 |
-
"user_instructions_hebrew.md",
|
95 |
-
)
|
96 |
-
with open(user_instructions_file, "r") as fh:
|
97 |
long_description = fh.read()
|
98 |
st.markdown(long_description)
|
|
|
3 |
In the demo, the user can write a prompt
|
4 |
and the model will generate a response using the grouped sampling algorithm.
|
5 |
"""
|
|
|
|
|
6 |
|
7 |
import streamlit as st
|
|
|
8 |
from torch.cuda import CudaError
|
|
|
9 |
|
10 |
from available_models import AVAILABLE_MODELS
|
11 |
from hanlde_form_submit import on_form_submit
|
12 |
|
13 |
|
14 |
+
st.title("דגימה בקבוצות - שימוש יעיל במודלי שפה סיבתיים")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
with st.form("request_form"):
|
17 |
selected_model_name: str = st.selectbox(
|
18 |
label="בחרו מודל",
|
19 |
options=AVAILABLE_MODELS,
|
20 |
+
help="opt-iml-max-30b generates better texts but is slower",
|
21 |
)
|
22 |
|
23 |
output_length: int = st.number_input(
|
24 |
+
label="כמות המילים המקסימלית בפלט - בין 1 ל-4096",
|
25 |
min_value=1,
|
26 |
+
max_value=4096,
|
27 |
value=5,
|
28 |
)
|
29 |
|
|
|
31 |
label="הקלט לאלוגריתם (באנגלית בלבד)",
|
32 |
value="Instruction: Answer in yes or no.\n"
|
33 |
"Question: Is the sky blue?\n"
|
34 |
+
"Answer: ",
|
35 |
max_chars=2048,
|
36 |
)
|
37 |
|
|
|
43 |
if submitted:
|
44 |
try:
|
45 |
output = on_form_submit(
|
46 |
+
selected_model_name,
|
47 |
output_length,
|
48 |
submitted_prompt,
|
49 |
)
|
|
|
55 |
st.write(f"Generated text: {output}")
|
56 |
|
57 |
|
58 |
+
with open("user_instructions_hebrew.md", "r") as fh:
|
|
|
|
|
|
|
|
|
59 |
long_description = fh.read()
|
60 |
st.markdown(long_description)
|
available_models.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
AVAILABLE_MODELS = (
|
2 |
-
"
|
3 |
-
"
|
4 |
)
|
|
|
1 |
AVAILABLE_MODELS = (
|
2 |
+
"facebook/opt-iml-max-1.3b",
|
3 |
+
"facebook/opt-iml-max-30b",
|
4 |
)
|
download_repo.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import urllib3
|
2 |
+
|
3 |
+
from huggingface_hub import snapshot_download
|
4 |
+
|
5 |
+
from available_models import AVAILABLE_MODELS
|
6 |
+
|
7 |
+
|
8 |
+
def change_default_timeout(new_timeout: int) -> None:
|
9 |
+
"""
|
10 |
+
Changes the default timeout for downloading repositories from the Hugging Face Hub.
|
11 |
+
Prevents the following errors:
|
12 |
+
urllib3.exceptions.ReadTimeoutError: HTTPSConnectionPool(host='huggingface.co', port=443):
|
13 |
+
Read timed out. (read timeout=10)
|
14 |
+
"""
|
15 |
+
urllib3.util.timeout.DEFAULT_TIMEOUT = new_timeout
|
16 |
+
|
17 |
+
|
18 |
+
def download_pytorch_model(name: str) -> None:
|
19 |
+
"""
|
20 |
+
Downloads a pytorch model and all the small files from the model's repository.
|
21 |
+
Other model formats (tensorflow, tflite, safetensors, msgpack, ot...) are not downloaded.
|
22 |
+
"""
|
23 |
+
number_of_seconds_in_a_year: int = 60 * 60 * 24 * 365
|
24 |
+
change_default_timeout(number_of_seconds_in_a_year)
|
25 |
+
snapshot_download(
|
26 |
+
repo_id=name,
|
27 |
+
etag_timeout=number_of_seconds_in_a_year,
|
28 |
+
resume_download=True,
|
29 |
+
repo_type="model",
|
30 |
+
library_name="pt",
|
31 |
+
# h5, tflite, safetensors, msgpack and ot models files are not needed
|
32 |
+
ignore_patterns=[
|
33 |
+
"*.h5",
|
34 |
+
"*.tflite",
|
35 |
+
"*.safetensors",
|
36 |
+
"*.msgpack",
|
37 |
+
"*.ot",
|
38 |
+
"*.md"
|
39 |
+
],
|
40 |
+
)
|
41 |
+
|
42 |
+
|
43 |
+
if __name__ == "__main__":
|
44 |
+
for model_name in AVAILABLE_MODELS:
|
45 |
+
download_pytorch_model(model_name)
|
hanlde_form_submit.py
CHANGED
@@ -1,8 +1,51 @@
|
|
|
|
1 |
from time import time
|
2 |
|
3 |
import streamlit as st
|
4 |
from grouped_sampling import GroupedSamplingPipeLine
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
def generate_text(
|
8 |
pipeline: GroupedSamplingPipeLine,
|
@@ -25,13 +68,13 @@ def generate_text(
|
|
25 |
|
26 |
|
27 |
def on_form_submit(
|
28 |
-
|
29 |
output_length: int,
|
30 |
prompt: str,
|
31 |
) -> str:
|
32 |
"""
|
33 |
Called when the user submits the form.
|
34 |
-
:param
|
35 |
:param output_length: The size of the groups to use.
|
36 |
:param prompt: The prompt to use.
|
37 |
:return: The output of the model.
|
@@ -43,8 +86,16 @@ def on_form_submit(
|
|
43 |
"""
|
44 |
if len(prompt) == 0:
|
45 |
raise ValueError("The prompt must not be empty.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
st.write("Generating text...")
|
47 |
-
print("Generating text...")
|
48 |
generation_start_time = time()
|
49 |
generated_text = generate_text(
|
50 |
pipeline=pipeline,
|
@@ -54,5 +105,8 @@ def on_form_submit(
|
|
54 |
generation_end_time = time()
|
55 |
generation_time = generation_end_time - generation_start_time
|
56 |
st.write(f"Finished generating text in {generation_time:,.2f} seconds.")
|
57 |
-
|
|
|
|
|
|
|
58 |
return generated_text
|
|
|
1 |
+
import os
|
2 |
from time import time
|
3 |
|
4 |
import streamlit as st
|
5 |
from grouped_sampling import GroupedSamplingPipeLine
|
6 |
|
7 |
+
from download_repo import download_pytorch_model
|
8 |
+
|
9 |
+
|
10 |
+
def is_downloaded(model_name: str) -> bool:
|
11 |
+
"""
|
12 |
+
Checks if the model is downloaded.
|
13 |
+
:param model_name: The name of the model to check.
|
14 |
+
:return: True if the model is downloaded, False otherwise.
|
15 |
+
"""
|
16 |
+
models_dir = "/root/.cache/huggingface/hub"
|
17 |
+
model_dir = os.path.join(models_dir, f"models--{model_name.replace('/', '--')}")
|
18 |
+
return os.path.isdir(model_dir)
|
19 |
+
|
20 |
+
|
21 |
+
def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
|
22 |
+
"""
|
23 |
+
Creates a pipeline with the given model name and group size.
|
24 |
+
:param model_name: The name of the model to use.
|
25 |
+
:param group_size: The size of the groups to use.
|
26 |
+
:return: A pipeline with the given model name and group size.
|
27 |
+
"""
|
28 |
+
if not is_downloaded(model_name):
|
29 |
+
download_repository_start_time = time()
|
30 |
+
st.write(f"Starts downloading model: {model_name} from the internet.")
|
31 |
+
download_pytorch_model(model_name)
|
32 |
+
download_repository_end_time = time()
|
33 |
+
download_time = download_repository_end_time - download_repository_start_time
|
34 |
+
st.write(f"Finished downloading model: {model_name} from the internet in {download_time:,.2f} seconds.")
|
35 |
+
st.write(f"Starts creating pipeline with model: {model_name}")
|
36 |
+
pipeline_start_time = time()
|
37 |
+
pipeline = GroupedSamplingPipeLine(
|
38 |
+
model_name=model_name,
|
39 |
+
group_size=group_size,
|
40 |
+
end_of_sentence_stop=False,
|
41 |
+
top_k=50,
|
42 |
+
load_in_8bit=False,
|
43 |
+
)
|
44 |
+
pipeline_end_time = time()
|
45 |
+
pipeline_time = pipeline_end_time - pipeline_start_time
|
46 |
+
st.write(f"Finished creating pipeline with model: {model_name} in {pipeline_time:,.2f} seconds.")
|
47 |
+
return pipeline
|
48 |
+
|
49 |
|
50 |
def generate_text(
|
51 |
pipeline: GroupedSamplingPipeLine,
|
|
|
68 |
|
69 |
|
70 |
def on_form_submit(
|
71 |
+
model_name: str,
|
72 |
output_length: int,
|
73 |
prompt: str,
|
74 |
) -> str:
|
75 |
"""
|
76 |
Called when the user submits the form.
|
77 |
+
:param model_name: The name of the model to use.
|
78 |
:param output_length: The size of the groups to use.
|
79 |
:param prompt: The prompt to use.
|
80 |
:return: The output of the model.
|
|
|
86 |
"""
|
87 |
if len(prompt) == 0:
|
88 |
raise ValueError("The prompt must not be empty.")
|
89 |
+
st.write(f"Loading model: {model_name}...")
|
90 |
+
loading_start_time = time()
|
91 |
+
pipeline = create_pipeline(
|
92 |
+
model_name=model_name,
|
93 |
+
group_size=output_length,
|
94 |
+
)
|
95 |
+
loading_end_time = time()
|
96 |
+
loading_time = loading_end_time - loading_start_time
|
97 |
+
st.write(f"Finished loading model: {model_name} in {loading_time:,.2f} seconds.")
|
98 |
st.write("Generating text...")
|
|
|
99 |
generation_start_time = time()
|
100 |
generated_text = generate_text(
|
101 |
pipeline=pipeline,
|
|
|
105 |
generation_end_time = time()
|
106 |
generation_time = generation_end_time - generation_start_time
|
107 |
st.write(f"Finished generating text in {generation_time:,.2f} seconds.")
|
108 |
+
if not isinstance(generated_text, str):
|
109 |
+
raise RuntimeError(f"The model {model_name} did not generate any text.")
|
110 |
+
if len(generated_text) == 0:
|
111 |
+
raise RuntimeError(f"The model {model_name} did not generate any text.")
|
112 |
return generated_text
|
tests.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import pytest as pytest
|
2 |
-
from grouped_sampling import GroupedSamplingPipeLine
|
3 |
|
|
|
4 |
from hanlde_form_submit import create_pipeline, on_form_submit
|
5 |
|
6 |
|
@@ -14,13 +15,13 @@ def test_on_form_submit():
|
|
14 |
empty_prompt = ""
|
15 |
with pytest.raises(ValueError):
|
16 |
on_form_submit(model_name, output_length, empty_prompt)
|
17 |
-
unsupported_model_name = "unsupported_model_name"
|
18 |
-
with pytest.raises(UnsupportedModelNameException):
|
19 |
-
on_form_submit(unsupported_model_name, output_length, prompt)
|
20 |
|
21 |
|
22 |
-
|
23 |
-
model_name
|
|
|
|
|
|
|
24 |
pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, 5)
|
25 |
assert pipeline is not None
|
26 |
assert pipeline.model_name == model_name
|
|
|
1 |
import pytest as pytest
|
2 |
+
from grouped_sampling import GroupedSamplingPipeLine
|
3 |
|
4 |
+
from available_models import AVAILABLE_MODELS
|
5 |
from hanlde_form_submit import create_pipeline, on_form_submit
|
6 |
|
7 |
|
|
|
15 |
empty_prompt = ""
|
16 |
with pytest.raises(ValueError):
|
17 |
on_form_submit(model_name, output_length, empty_prompt)
|
|
|
|
|
|
|
18 |
|
19 |
|
20 |
+
@pytest.mark.parametrize(
|
21 |
+
"model_name",
|
22 |
+
AVAILABLE_MODELS,
|
23 |
+
)
|
24 |
+
def test_create_pipeline(model_name: str):
|
25 |
pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, 5)
|
26 |
assert pipeline is not None
|
27 |
assert pipeline.model_name == model_name
|