import argparse
import logging
import pprint as pp
import time

import gradio as gr
import torch
from transformers import pipeline

from utils import make_mailto_form, postprocess, clear, make_email_link

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

use_gpu = torch.cuda.is_available()

# The active text-generation pipeline. It is set by load_emailgen_model() (at
# startup and from the "Load Model" button) and read by generate_text().
generator = None


def generate_text(
    prompt: str,
    gen_length=64,
    penalty_alpha=0.6,
    top_k=6,
    length_penalty=1.0,
    # perma params (not set by user)
    abs_max_length=512,
    verbose=False,
):
    """
    generate_text - complete an email prompt with the loaded text generation pipeline

    :param str prompt: the prompt (partial email) to complete
    :param int gen_length: the number of new tokens to generate
    :param float penalty_alpha: the penalty alpha for the text generation pipeline (contrastive search)
    :param int top_k: the top k for the text generation pipeline (contrastive search)
    :param float length_penalty: forwarded to the pipeline. NOTE(review): transformers
        documents length_penalty for beam-based generation, so it likely has no
        effect on contrastive search - confirm
    :param int abs_max_length: prompts longer than this many tokens trigger a warning
    :param bool verbose: log the prompt, parameters, generated text and timing
    :raises RuntimeError: if no model has been loaded via load_emailgen_model()
    :return tuple[str, str]: (HTML mailto form for the email, post-processed email text)
    """
    if generator is None:
        raise RuntimeError("No model loaded - call load_emailgen_model() first")

    # Normalize UI-supplied numbers in case they arrive as floats (e.g. 64.0);
    # token counts and top_k must be ints.
    gen_length = int(gen_length)
    top_k = int(top_k)
    penalty_alpha = float(penalty_alpha)
    length_penalty = float(length_penalty)

    if verbose:
        logging.info(f"Generating text from prompt:\n\n{prompt}")
        logging.info(
            "params:\n%s",
            pp.pformat(
                {
                    "max_length": gen_length,
                    "penalty_alpha": penalty_alpha,
                    "top_k": top_k,
                    "length_penalty": length_penalty,
                }
            ),
        )

    st = time.perf_counter()

    input_len = len(generator.tokenizer(prompt)["input_ids"])
    if input_len > abs_max_length:
        logging.warning(
            f"Input too long {input_len} > {abs_max_length}, may cause errors"
        )

    result = generator(
        prompt,
        # old API for generation: max_length/min_length count the prompt tokens
        # too, so offset them by the prompt length to bound the *new* tokens.
        max_length=gen_length + input_len,
        min_length=input_len + 4,
        penalty_alpha=penalty_alpha,
        top_k=top_k,
        length_penalty=length_penalty,
    )
    response = result[0]["generated_text"]
    rt = time.perf_counter() - st

    if verbose:
        logging.info(f"Generated text: {response}")
        logging.info(f"Generation time: {rt:.2f}s")

    formatted_email = postprocess(response)
    return make_mailto_form(body=formatted_email), formatted_email


def load_emailgen_model(model_tag: str):
    """
    load_emailgen_model - load a text generation pipeline for email generation

    The pipeline replaces the module-level ``generator`` used by generate_text().
    It is placed on device 0 when CUDA is available, otherwise on the CPU (-1).

    Args:
        model_tag (str): the huggingface model tag to load

    Returns:
        None (the pipeline is stored in the module-level ``generator``)
    """
    global generator
    logging.info(f"Loading model: {model_tag}, use GPU = {use_gpu}")
    generator = pipeline(
        "text-generation",
        model_tag,
        device=0 if use_gpu else -1,
    )


def get_parser():
    """
    get_parser - build the command-line argument parser for the demo

    Returns:
        argparse.ArgumentParser: parser for --model, --max_length,
        --penalty_alpha, --top_k and --verbose
    """
    parser = argparse.ArgumentParser(
        description="Text Generation demo for postbot",
    )

    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="postbot/distilgpt2-emailgen-V2",
        help="Pass a different huggingface model tag to use a custom model",
    )

    parser.add_argument(
        "-l",
        "--max_length",
        required=False,
        type=int,
        default=40,
        help="default max length of the generated text",
    )

    parser.add_argument(
        "-a",
        "--penalty_alpha",
        type=float,
        default=0.6,
        help="The penalty alpha for the text generation pipeline (contrastive search) - default 0.6",
    )

    parser.add_argument(
        "-k",
        "--top_k",
        type=int,
        default=6,
        help="The top k for the text generation pipeline (contrastive search) - default 6",
    )

    parser.add_argument(
        "-v",
        "--verbose",
        required=False,
        action="store_true",
        help="Verbose output",
    )
    return parser


default_prompt = """
Hello,

Following up on last week's bubblegum shipment, I"""

available_models = [
    "postbot/distilgpt2-emailgen-V2",
    "postbot/distilgpt2-emailgen",
    "postbot/gpt2-medium-emailgen",
    "postbot/pythia-160m-hq-emails",
]

if __name__ == "__main__":
    logging.info("\n\n\nStarting new instance of app.py")
    parser = get_parser()
    args = parser.parse_args()
    logging.info(f"received args:\t{args}")

    model_tag = args.model
    verbose = args.verbose
    max_length = args.max_length
    top_k = args.top_k
    alpha = args.penalty_alpha

    # Validate with parser.error() rather than `assert`, which `python -O` strips.
    if top_k <= 0:
        parser.error("top_k must be greater than 0")
    if not 0.0 <= alpha <= 1.0:
        parser.error("penalty_alpha must be between 0 and 1")

    load_emailgen_model(model_tag)

    # Keep CLI-supplied values selectable in the UI even when they are not
    # among the preset choices.
    top_k_choices = sorted({2, 4, 6, 8, top_k})
    model_choices = (
        available_models
        if model_tag in available_models
        else [model_tag, *available_models]
    )

    def generate_from_ui(prompt_str, n_tokens, alpha_value, k_value, length_pen):
        """Call generate_text() with the command-line --verbose flag applied."""
        return generate_text(
            prompt_str,
            n_tokens,
            alpha_value,
            k_value,
            length_pen,
            verbose=verbose,
        )

    demo = gr.Blocks()

    logging.info("launching interface...")

    with demo:
        gr.Markdown("# Auto-Complete Emails - Demo")
        gr.Markdown(
            "Enter part of an email, and a text-gen model will complete it! See details below. "
        )
        gr.Markdown("---")

        with gr.Column():

            gr.Markdown("## Generate Text")
            gr.Markdown("Edit the prompt and parameters and press **Generate**!")
            prompt_text = gr.Textbox(
                lines=4,
                label="Email Prompt",
                value=default_prompt,
            )

            with gr.Row():
                clear_button = gr.Button(
                    value="Clear Prompt",
                )
                num_gen_tokens = gr.Slider(
                    label="Generation Tokens",
                    value=max_length,
                    maximum=96,
                    minimum=16,
                    step=8,
                )

            generate_button = gr.Button(
                value="Generate!",
                variant="primary",
            )
            gr.Markdown("---")
            gr.Markdown("### Results")
            # read-only textbox for the generated email, plus an HTML slot
            # for the mailto form built from it
            generated_email = gr.Textbox(
                label="Generated Text",
                placeholder="This is where the generated text will appear",
                interactive=False,
            )
            email_mailto_button = gr.HTML(
                "a clickable email button will appear here"
            )

        gr.Markdown("---")
        gr.Markdown("## Advanced Options")
        gr.Markdown(
            "This demo generates text via the new [contrastive search](https://huggingface.co/blog/introducing-csearch). See the csearch blog post for details on the parameters or [here](https://huggingface.co/blog/how-to-generate), for general decoding."
        )
        with gr.Row():
            model_name = gr.Dropdown(
                choices=model_choices,
                label="Choose a model",
                value=model_tag,
            )
            load_model_button = gr.Button(
                "Load Model",
                variant="secondary",
            )
        with gr.Row():
            contrastive_top_k = gr.Radio(
                choices=top_k_choices,
                label="Top K",
                value=top_k,
            )

            penalty_alpha = gr.Slider(
                label="Penalty Alpha",
                value=alpha,
                maximum=1.0,
                minimum=0.0,
                step=0.1,
            )
            length_penalty = gr.Slider(
                minimum=0.5,
                maximum=1.0,
                label="Length Penalty",
                value=1.0,
                step=0.1,
            )
        gr.Markdown("---")

        with gr.Column():
            gr.Markdown("## About")
            gr.Markdown(
                "[This model](https://huggingface.co/postbot/distilgpt2-emailgen) is a fine-tuned version of distilgpt2 on a dataset of 100k emails sourced from the internet, including the classic `aeslc` dataset.\n\nCheck out the model card for details on notebook & command line usage."
            )
            gr.Markdown(
                "The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails from scratch; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements **before** accepting/sending something."
            )
            gr.Markdown("---")

        clear_button.click(
            fn=clear,
            inputs=[prompt_text],
            outputs=[prompt_text],
        )
        generate_button.click(
            fn=generate_from_ui,
            inputs=[
                prompt_text,
                num_gen_tokens,
                penalty_alpha,
                contrastive_top_k,
                length_penalty,
            ],
            outputs=[email_mailto_button, generated_email],
        )

        load_model_button.click(
            fn=load_emailgen_model,
            inputs=[model_name],
            outputs=[],
        )

    demo.launch(
        enable_queue=True,
        share=True,  # NOTE: share=True also requests a public gradio share link
    )