grouped-sampling-demo / hanlde_form_submit.py
yonikremer's picture
downloading models at the start of the app and not at usage time
df273ff
raw
history blame
1.97 kB
from time import perf_counter, time

import streamlit as st
from grouped_sampling import GroupedSamplingPipeLine
def generate_text(
    pipeline: GroupedSamplingPipeLine,
    prompt: str,
    output_length: int,
) -> str:
    """
    Runs the given pipeline on a single prompt and returns the generated text.
    The pipeline is called with return_text=True and return_full_text=False,
    and the "generated_text" entry of its result is returned.
    :param pipeline: The pipeline to call. GroupedSamplingPipeLine.
    :param prompt: The prompt passed to the pipeline. str.
    :param output_length: The number of new tokens to request from the pipeline
    (passed as max_new_tokens). int > 0.
    :return: The "generated_text" entry of the pipeline output. str.
    """
    pipeline_output = pipeline(
        prompt_s=prompt,
        max_new_tokens=output_length,
        return_text=True,
        return_full_text=False,
    )
    # The pipeline result is indexed by key; only the text entry is needed here.
    return pipeline_output["generated_text"]
def on_form_submit(
    pipeline: GroupedSamplingPipeLine,
    output_length: int,
    prompt: str,
) -> str:
    """
    Called when the user submits the form.
    Announces the start of the generation on the streamlit page and on stdout,
    generates the text, and then reports how long the generation took.
    :param pipeline: The pipeline to use. GroupedSamplingPipeLine.
    :param output_length: The size of the text to generate in tokens.
    :param prompt: The prompt to use. Must not be empty.
    :return: The output of the model.
    :raises ValueError: If the prompt is empty.
        TypeError: If the prompt has no length (for example None).
    Nothing else is validated here; any exception raised while generating
    propagates to the caller unchanged.
    """
    if len(prompt) == 0:
        raise ValueError("The prompt must not be empty.")

    start_message = "Generating text..."
    st.write(start_message)
    print(start_message)

    # perf_counter is monotonic, so the measured duration cannot be distorted
    # (or turn negative) when the system clock is adjusted mid-generation.
    generation_start_time = perf_counter()
    generated_text = generate_text(
        pipeline=pipeline,
        prompt=prompt,
        output_length=output_length,
    )
    generation_time = perf_counter() - generation_start_time

    # Build the message once so the page and the log always show the same text.
    finish_message = f"Finished generating text in {generation_time:,.2f} seconds."
    st.write(finish_message)
    print(finish_message)
    return generated_text