"""
The Streamlit app for the project demo.

In the demo, the user can write a prompt and the model will generate a response
using the grouped sampling algorithm.
"""
import streamlit as st
from torch.cuda import CudaError

from hanlde_form_submit import on_form_submit
from on_server_start import main as on_server_start_main

# NOTE(review): this runs at module top level, and Streamlit re-executes the
# whole script on every widget interaction. So it presumably runs on each
# rerun, not only once at server start. Confirm on_server_start_main is cheap
# and idempotent, or guard it (e.g. with a cached resource).
on_server_start_main()

st.title("Grouped Sampling Demo")

# Inside a form, widget changes do not trigger generation.
# Only pressing the submit button does.
with st.form("request_form"):
    selected_model_name: str = st.text_input(
        label="Model name",
        value="gpt2",
        help="The name of the model to use. "
             "Supported models are all the models in:"
             " https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch",
    )
    output_length: int = st.number_input(
        label="Number of word pieces in the generated text, 1-4096 (default: 100)",
        min_value=1,
        max_value=4096,
        value=100,
        help="The length of the output text in tokens (word pieces).",
    )
    submitted_prompt: str = st.text_area(
        label="Input for the model, It is highly recommended to write an English prompt.",
        help="Enter the prompt for the model. The model will generate a response based on this prompt.",
        value="Instruction: Answer in yes or no.\n"
              "Question: Is this a prompt?\n"
              "Answer: ",
        max_chars=2048,
    )
    submitted: bool = st.form_submit_button(
        label="Generate",
        help="Generate the output text.",
        disabled=False,
    )

if submitted:
    try:
        output = on_form_submit(selected_model_name, output_length, submitted_prompt)
    # CudaError must stay before the RuntimeError handler so it gets its
    # dedicated message rather than the generic one.
    # NOTE(review): any CudaError is reported as "out of memory". Confirm this
    # is intended, or catch the OOM-specific exception of the installed torch
    # version instead.
    except CudaError:
        st.error("Out of memory. Please try a smaller model, shorter prompt, or a smaller output length.")
    except (ValueError, TypeError, RuntimeError) as e:
        st.error(e)
    else:
        # Only reached when generation succeeded. Writing this after a handled
        # exception would raise NameError, because `output` would be unbound.
        st.write(f"Generated text: {output}")