|
from time import time |
|
|
|
import streamlit as st |
|
from grouped_sampling import GroupedSamplingPipeLine |
|
|
|
|
|
def generate_text(
    pipeline: GroupedSamplingPipeLine,
    prompt: str,
    output_length: int,
) -> str:
    """
    Run the pipeline on a single prompt and return the text it produces.

    :param pipeline: The GroupedSamplingPipeLine that performs the generation.
    :param prompt: The text the model should continue. str.
    :param output_length: How many new tokens to generate. int, expected > 0.
    :return: The value stored under "generated_text" in the pipeline output. str.
    """
    generation_kwargs = {
        "prompt_s": prompt,
        "max_new_tokens": output_length,
        "return_text": True,
        # Passed as False so the result presumably holds only the newly
        # generated continuation rather than prompt + continuation — confirm
        # against GroupedSamplingPipeLine's documentation.
        "return_full_text": False,
    }
    pipeline_output = pipeline(**generation_kwargs)
    return pipeline_output["generated_text"]
|
|
|
|
|
def on_form_submit(
    pipeline: GroupedSamplingPipeLine,
    output_length: int,
    prompt: str,
) -> str:
    """
    Called when the user submits the form.

    Validates the inputs, generates text with the pipeline, and reports how
    long the generation took both in the Streamlit page and on stdout.

    :param pipeline: The pipeline to use. GroupedSamplingPipeLine.
    :param output_length: The number of new tokens to generate. int > 0.
    :param prompt: The prompt to use. Non-empty str.
    :return: The text generated by the model.
    :raises ValueError: If the prompt is empty or the output length is <= 0.
    """
    # Monotonic clock: unlike time.time(), it cannot jump backwards or forwards
    # when the system clock is adjusted, so the measured duration is reliable.
    from time import perf_counter

    # Validate before writing anything to the page, so a bad submission does
    # not leave a dangling "Generating text..." message behind.
    if len(prompt) == 0:
        raise ValueError("The prompt must not be empty.")
    if output_length <= 0:
        raise ValueError(
            f"The output length must be a positive number of tokens, got {output_length}."
        )

    start_message = "Generating text..."
    st.write(start_message)
    print(start_message)

    generation_start_time = perf_counter()
    generated_text = generate_text(
        pipeline=pipeline,
        prompt=prompt,
        output_length=output_length,
    )
    generation_time = perf_counter() - generation_start_time

    finish_message = f"Finished generating text in {generation_time:,.2f} seconds."
    st.write(finish_message)
    print(finish_message)

    return generated_text
|
|