# Page residue captured along with this file -- Spaces status: Sleeping
import argparse
import logging
import time

import gradio as gr
import torch
from transformers import pipeline

from utils import postprocess, clear, make_email_link
# Root logger: INFO and above, each record prefixed with a timestamp and level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# True when torch reports a usable CUDA device; read later to pick the
# pipeline device.
use_gpu = torch.cuda.is_available()
def generate_text(
    prompt: str,
    gen_length: int = 64,
    num_beams: int = 4,
    no_repeat_ngram_size: int = 2,
    length_penalty: float = 1.0,
    # perma params (not set by user)
    repetition_penalty: float = 3.5,
    abs_max_length: int = 512,
    verbose: bool = False,
):
    """
    generate_text - complete an email prompt using the module-level `generator` pipeline

    Args:
        prompt (str): the beginning of the email to complete
        gen_length (int, optional): number of tokens allowed on top of the tokenized prompt length. Defaults to 64.
        num_beams (int, optional): number of beams passed to the pipeline. Defaults to 4.
        no_repeat_ngram_size (int, optional): passed through to the pipeline. Defaults to 2.
        length_penalty (float, optional): passed through to the pipeline. Defaults to 1.0.
        repetition_penalty (float, optional): passed through to the pipeline. Defaults to 3.5.
        abs_max_length (int, optional): prompt token count above which a warning is logged; the prompt is still processed. Defaults to 512.
        verbose (bool, optional): log the prompt, parameters, output and timing. Defaults to False.

    Returns:
        tuple: (formatted_email, email_link) - the result of `postprocess` on the generated text, and the result of `make_email_link(body=formatted_email)`
    """
    # `generator` is a module-level name assigned in the `__main__` section;
    # it is only read here, so no `global` declaration is required.

    # Values coming from UI widgets may arrive as floats; token counts and
    # beam settings are coerced to integers before use.
    gen_length = int(gen_length)
    num_beams = int(num_beams)
    no_repeat_ngram_size = int(no_repeat_ngram_size)

    if verbose:
        logging.info("Generating text from prompt:\n\n%s", prompt)
        logging.info(
            "params:\tgen_length=%s, num_beams=%s, no_repeat_ngram_size=%s, "
            "length_penalty=%s, repetition_penalty=%s, abs_max_length=%s",
            gen_length,
            num_beams,
            no_repeat_ngram_size,
            length_penalty,
            repetition_penalty,
            abs_max_length,
        )

    st = time.perf_counter()

    # Length limits below are expressed relative to the prompt's token count.
    input_tokens = generator.tokenizer(prompt)
    input_len = len(input_tokens["input_ids"])
    if input_len > abs_max_length:
        # Not fatal: generation is still attempted, so this is a warning.
        logging.warning(
            "Input too long %s > %s, may cause errors", input_len, abs_max_length
        )

    result = generator(
        prompt,
        max_length=gen_length + input_len,
        min_length=input_len + 4,  # require a few tokens beyond the prompt
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        length_penalty=length_penalty,
        do_sample=False,
        early_stopping=True,
        # tokenizer
        truncation=True,
    )  # generate
    response = result[0]["generated_text"]
    rt = time.perf_counter() - st

    if verbose:
        logging.info("Generated text: %s", response)
        logging.info("Generation time: %.2fs", rt)

    formatted_email = postprocess(response)
    return formatted_email, make_email_link(body=formatted_email)
def get_parser() -> argparse.ArgumentParser:
    """
    get_parser - build the command-line parser for the demo

    Options:
        -m / --model: huggingface model tag to load (default "postbot/distilgpt2-emailgen")
        -v / --verbose: flag, False unless given

    Returns:
        argparse.ArgumentParser: the configured parser (arguments are not parsed here)
    """
    parser = argparse.ArgumentParser(
        description="Text Generation demo for postbot",
    )
    # Both options are optional flags, so `required` is left at its default.
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="postbot/distilgpt2-emailgen",
        help="Pass a different huggingface model tag to use a custom model",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Verbose output",
    )
    return parser
# Initial contents of the prompt textbox; it deliberately stops mid-sentence
# so the model has something to continue.
default_prompt = """
Hello,
Following up on last week's bubblegum shipment, I"""
if __name__ == "__main__":
    # Script entry point: parse CLI args, load the model pipeline, build the
    # Gradio UI, wire the buttons, and launch the app.
    logging.info("\n\n\nStarting new instance of app.py")
    args = get_parser().parse_args()
    logging.info(f"received args:\t{args}")
    model_tag = args.model
    verbose = args.verbose

    logging.info(f"Loading model: {model_tag}, use GPU = {use_gpu}")
    # Module-level name read by generate_text(); device 0 when CUDA is
    # available, otherwise -1.
    generator = pipeline(
        "text-generation",
        model_tag,
        device=0 if use_gpu else -1,
    )

    def run_generation(
        prompt, gen_length, num_beams, no_repeat_ngram_size, length_penalty
    ):
        """Forward the UI values to generate_text, applying the --verbose flag.

        Without this wrapper the parsed `verbose` value never reaches
        generate_text and the command-line flag has no effect.
        """
        return generate_text(
            prompt,
            gen_length,
            num_beams,
            no_repeat_ngram_size,
            length_penalty,
            verbose=verbose,
        )

    demo = gr.Blocks()

    logging.info("launching interface...")

    # NOTE(review): the nesting of the layout containers below was restored
    # from a flattened copy of this file -- confirm against the deployed UI.
    with demo:
        gr.Markdown("# Auto-Complete Emails - Demo")
        gr.Markdown(
            "Enter part of an email, and a text-gen model will complete it! See details below. "
        )
        gr.Markdown("---")

        with gr.Column():
            gr.Markdown("## Generate Text")
            gr.Markdown("Edit the prompt and parameters and press **Generate**!")
            prompt_text = gr.Textbox(
                lines=4,
                label="Email Prompt",
                value=default_prompt,
            )

            with gr.Row():
                clear_button = gr.Button(
                    value="Clear Prompt",
                )
                num_gen_tokens = gr.Slider(
                    label="Generation Tokens",
                    value=64,
                    maximum=128,
                    minimum=32,
                    step=16,
                )

            # Outputs of the generate handler: the completed email and an
            # HTML snippet holding the mailto link.
            generated_email = gr.Textbox(
                label="Generated Result",
                placeholder="The completed email will appear here",
            )
            email_link = gr.HTML("<p><em>A mailto: link will appear here</em></p>")

            generate_button = gr.Button(
                value="Generate!",
                variant="primary",
            )

            gr.Markdown("## Advanced Options")
            gr.Markdown(
                "This demo generates text via beam search. See details about these parameters [here](https://huggingface.co/blog/how-to-generate), otherwise they should be fine as-is."
            )
            num_beams = gr.Radio(
                choices=[4, 8, 16],
                label="Number of Beams",
                value=4,
            )

            with gr.Row():
                no_repeat_ngram_size = gr.Radio(
                    choices=[1, 2, 3, 4],
                    label="no repeat ngram size",
                    value=2,
                )
                length_penalty = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    label="length penalty",
                    value=0.8,
                    step=0.1,
                )
            gr.Markdown("---")

        with gr.Column():
            gr.Markdown("## About")
            gr.Markdown(
                "[This model](https://huggingface.co/postbot/distilgpt2-emailgen) is a fine-tuned version of distilgpt2 on a dataset of 50k emails sourced from the internet, including the classic `aeslc` dataset.\n\nCheck out the model card for details on notebook & command line usage."
            )
            gr.Markdown(
                "The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails from scratch; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements **before** accepting/sending something."
            )
            gr.Markdown("---")

        # Event wiring: `clear` receives the prompt and its return value is
        # written back into the prompt textbox.
        clear_button.click(
            fn=clear,
            inputs=[prompt_text],
            outputs=[prompt_text],
        )
        # The input order matches the positional parameters of run_generation
        # (and therefore of generate_text).
        generate_button.click(
            fn=run_generation,
            inputs=[
                prompt_text,
                num_gen_tokens,
                num_beams,
                no_repeat_ngram_size,
                length_penalty,
            ],
            outputs=[generated_email, email_link],
        )

    # NOTE(review): `enable_queue` is believed to be deprecated in later
    # Gradio releases in favor of `demo.queue()` -- confirm against the
    # Gradio version this app is pinned to before changing it.
    demo.launch(
        enable_queue=True,
        share=True,  # for local testing
    )