Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import AutoTokenizer | |
from transformers import pipeline | |
from utils import format_moves | |
import pandas as pd | |
import tensorflow as tf | |
import json | |
# Base checkpoint whose tokenizer is reused by the fine-tuned model below.
model_checkpoint = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Load the distilgpt2 model fine-tuned on Pokemon move descriptions and wrap
# it in a text-generation pipeline shared by every decoding tab.
generate = pipeline("text-generation",
                    model="arjunpatel/distilgpt2-finetuned-pokemon-moves",
                    tokenizer=tokenizer)

# Prompt prefix prepended to every move name; the fine-tuning data used the
# same "This move is called" seed phrase before each move.
seed_text = "This move is called "

# Seed Python, NumPy and whichever of PyTorch / TensorFlow is installed, so
# sampling is reproducible regardless of the framework the pipeline runs on
# (tf.random.set_seed alone would miss a PyTorch backend).
from transformers import set_seed

set_seed(0)
def update_history(df, move_name, move_desc, generation, parameters):
    """Append one generated move to the generation-history table.

    Args:
        df: Current history DataFrame (the value of the gr.Dataframe).
        move_name: Move name the user typed.
        move_desc: Generated text; everything before the first newline
            (the seed-phrase line) is dropped.
        generation: Label of the decoding strategy, e.g. "greedy".
        parameters: Decoding parameters; stored as their JSON encoding.

    Returns:
        A new DataFrame with the row appended and a fresh 0..n-1 index.
    """
    # Get rid of the first line, which carries the seed phrase.
    description = "\n".join(move_desc.split("\n")[1:])
    new_row = pd.DataFrame([{"Move Name": move_name,
                             "Move Description": description,
                             "Generation Type": generation,
                             "Parameters": json.dumps(parameters)}])
    # ignore_index keeps row labels unique instead of every row being 0.
    return pd.concat([df, new_row], ignore_index=True)
def create_move(move, history):
    """Produce a move description using the pipeline's default decoding.

    Returns a ``(description, history)`` pair, where ``history`` has a new
    "baseline" row appended.
    """
    prompt = seed_text + move
    decoded = format_moves(generate(prompt, num_return_sequences=1))
    new_history = update_history(history, move, decoded, "baseline", "None")
    return decoded, new_history
def create_greedy_search_move(move, history):
    """Produce a move description with sampling disabled (greedy tab).

    Returns the formatted description together with the history table
    extended by a "greedy" row.
    """
    decoding_options = {"do_sample": False}
    raw_output = generate(seed_text + move, **decoding_options)
    description = format_moves(raw_output)
    return (description,
            update_history(history, move, description, "greedy", "None"))
def create_beam_search_move(move, num_beams, history):
    """Generate a move description with beam search.

    Args:
        move: Move name typed by the user.
        num_beams: Beam width from the UI slider; cast to int because the
            slider value may arrive as a float.
        history: Current history DataFrame.

    Returns:
        (description, updated history). The history row records the beam
        width that was actually used for this generation.
    """
    num_beams = int(num_beams)
    generated_move = format_moves(generate(seed_text + move,
                                           num_beams=num_beams,
                                           num_return_sequences=1,
                                           do_sample=False,
                                           early_stopping=True))
    return generated_move, update_history(history, move, generated_move,
                                          "beam", {"num_beams": num_beams})
def create_sampling_search_move(move, do_sample, temperature, history):
    """Generate a move description with optional sampling and temperature.

    Args:
        move: Move name typed by the user.
        do_sample: Whether to sample from the distribution (checkbox).
        temperature: Temperature from the UI slider; passed as a float.
        history: Current history DataFrame.

    Returns:
        (description, updated history) with a "temperature" row appended.
    """
    generated_move = format_moves(generate(seed_text + move,
                                           do_sample=do_sample,
                                           temperature=float(temperature),
                                           num_return_sequences=1,
                                           # top_k=0 turns off top-k filtering so only the
                                           # temperature reshapes the distribution. (The
                                           # generation argument is spelled `top_k`.)
                                           top_k=0))
    return generated_move, update_history(history, move, generated_move,
                                          "temperature", {"do_sample": do_sample,
                                                          "temperature": temperature})
def create_top_search_move(move, topk, topp, history):
    """Generate a move description with top-k and/or top-p (nucleus) sampling.

    Args:
        move: Move name typed by the user.
        topk: Top-K slider value; 0 disables top-k filtering. Cast to int.
        topp: Top-P slider value; 1 disables nucleus filtering. Cast to float.
        history: Current history DataFrame.

    Returns:
        (description, updated history) with a "top" row appended.
    """
    top_k = int(topk)
    top_p = float(topp)
    # No forced-word constraint here: transformers' `force_words_ids` only
    # works with (constrained) beam search and cannot be combined with
    # do_sample=True.
    generated_move = format_moves(generate(seed_text + move,
                                           do_sample=True,
                                           num_return_sequences=1,
                                           top_k=top_k,
                                           top_p=top_p))
    return generated_move, update_history(history, move, generated_move,
                                          "top", {"top k": top_k,
                                                  "top p": top_p})
# Build the Gradio UI: an introduction, one tab per decoding strategy, a
# shared generation-history table and an "about" section, then wire each
# tab's button to its generator function and launch the app.
# Markdown continuation lines are kept flush-left on purpose: indented lines
# would be rendered by Markdown as code blocks.
demo = gr.Blocks()

with demo:
    gr.Markdown("<h1><center>What's that Pokemon Move?</center></h1>")
    gr.Markdown(
        """This Gradio demo allows you to generate Pokemon Move descriptions given a name, and learn more about text
decoding methods in the process! Each tab aims to explain each generation methodology available for the
model. The dataframe below allows you to keep track of each move generated, to compare!""")
    gr.Markdown("<h3> How does text generation work? </h3>")
    gr.Markdown("""Roughly, text generation models accept an input sequence of words (or parts of words,
known as tokens).
These models then output a corresponding set of words or tokens. Given the input, the model
estimates the probability of another possible word or token appearing right after the given sequence. In
other words, the model estimates conditional probabilities and ranks them in order to generate sequences.""")
    gr.Markdown("Enter a two to three word Pokemon Move name of your imagination below, with each word capitalized!")
    gr.Markdown("<h3> Move Generation </h3>")

    with gr.Tabs():
        with gr.TabItem("Standard Generation"):
            gr.Markdown(
                """The default parameters for distilgpt2 work well to generate moves. Use this tab as
a baseline for your experiments.""")
            with gr.Row():
                text_input_baseline = gr.Textbox(label="Move",
                                                 placeholder="Type a two or three word move name here! Try \"Wonder "
                                                             "Shield\"!")
                text_output_baseline = gr.Textbox(label="Move Description",
                                                  placeholder="Leave this blank!")
            text_button_baseline = gr.Button("Create my move!")

        with gr.TabItem("Greedy Search Decoding"):
            gr.Markdown("""
Greedy search is a decoding method that relies on finding words that have the highest estimated
probability of following the sequence thus far.
Therefore, the model \"greedily\" grabs the highest
probability word and continues generating the sentence.
This has the side effect of finding sequences that are reasonable, but avoids sequences that are
less probable but way more interesting.
Try the other decoding methods to get sentences with more variety!
""")
            with gr.Row():
                text_input_greedy = gr.Textbox(label="Move")
                text_output_greedy = gr.Textbox(label="Move Description")
            text_button_greedy = gr.Button("Create my move!")

        with gr.TabItem("Beam Search"):
            gr.Markdown("""Beam search is an improvement on Greedy Search. Instead of directly grabbing the word that
maximizes probability, we conduct a search with B number of candidates. We then try to find the next word
that would most likely follow each beam, and we grab the top B candidates of that search. This may
eliminate one of the original beams we started with, and that's okay! That is how the algorithm decides
on an optimal candidate. Eventually, the beam sequences terminate or are eliminated due to being too
improbable.
Increasing the number of beams will increase model generation time, but also result in a more thorough
search. Decreasing the number of beams will decrease decoding time, but it may not find an optimal
sentence.
Play around with the num_beams parameter to experiment! """
                        )
            with gr.Row():
                num_beams = gr.Slider(minimum=2, maximum=10, value=2, step=1,
                                      label="Number of Beams")
                text_input_beam = gr.Textbox(label="Move")
                text_output_beam = gr.Textbox(label="Move Description")
            text_button_beam = gr.Button("Create my move!")

        with gr.TabItem("Sampling and Temperature Search"):
            gr.Markdown(
                """Greedy Search and Beam Search were both good at finding sequences that are likely to follow our
input text, but when generating cool move descriptions, we want some more variety!
Instead of choosing the word or token that is most likely to follow a given sequence, we can instead
ask the model to sample across the probability distribution of likely words.
It's kind of like walking into the tall grass and finding a Pokemon encounter.
There are different encounter rates, which allow
for the most common mons to appear (looking at you, Zubat), but also account for surprise, like shinys!
We might even want to go further, though. We can rescale the probability distributions directly
instead, allowing for rare words to temporarily become more frequent. We do this using the
temperature parameter.
Turn the temperature up, and rare tokens become very likely! Cool down, and we approach more sensible
output.
Experiment with turning sampling on and off, and with varying the temperature below!
""")
            with gr.Row():
                temperature = gr.Slider(minimum=0.3, maximum=4.0, value=1.0, step=0.1,
                                        label="Temperature")
                text_input_temp = gr.Textbox(label="Move")
            with gr.Row():
                sample_boolean = gr.Checkbox(label="Enable Sampling?")
                text_output_temp = gr.Textbox(label="Move Description")
            text_button_temp = gr.Button("Create my move!")

        with gr.TabItem("Top K and Top P Sampling"):
            gr.Markdown(
                """When we want more control over the words we get to sample from, we turn to Top K and Top P
decoding methods!
The Top K sampling method selects the K most probable words given a sequence, and then samples from
that subset, rather than the whole vocabulary. This effectively cuts out low probability words.
Top P also reduces the available vocabulary to sample from, but instead of choosing the number of
words or tokens in advance, we sort the vocabulary from most to least likely word, and we
grab the smallest set of words whose probabilities sum to at least P. This allows for the number of words we look at to
change while sampling, instead of being fixed.
We can even use both methods at the same time! To disable Top K, set it to 0 using the slider.
To disable Top P, set it to 1.""")
            with gr.Row():
                topk = gr.Slider(minimum=0, maximum=200, value=0, step=5,
                                 label="Top K")
                text_input_top = gr.Textbox(label="Move")
            with gr.Row():
                topp = gr.Slider(minimum=0.10, maximum=1, value=1, step=0.05,
                                 label="Top P")
                text_output_top = gr.Textbox(label="Move Description")
            text_button_top = gr.Button("Create my move!")

    with gr.Box():
        gr.Markdown("<h3> Generation History </h3>")
        # Displays a dataframe with the history of moves generated, with parameters
        history = gr.Dataframe(headers=["Move Name", "Move Description", "Generation Type", "Parameters"])

    with gr.Box():
        gr.Markdown("<h3>How did you make this?</h3>")
        gr.Markdown("""
Hi! My name is <a href="https://www.linkedin.com/in/arjunkirtipatel/">Arjun Patel</a> and I'm Lead Data Scientist
over at <a href="https://www.speeko.co">Speeko</a>.
Nice to meet you!
I collected the dataset from <a href="https://www.serebii.net">Serebii</a>, a news source and aggregator of
Pokemon info.
I then added a seed phrase "This move is called" just before each move in order to assist the model in
generation.
I then followed HuggingFace's handy language_modeling.ipynb for fine-tuning distilgpt2 on this tiny dataset,
and it surprisingly worked!
I learned all about text generation using the book <a href="https://www.oreilly.com/library/view/natural-language-processing/9781098103231/">Natural Language Processing
with Transformers</a> by Lewis Tunstall, Leandro von Werra and Thomas Wolf, as well as <a href="https://huggingface.co/blog/how-to-generate">this fantastic article</a> by Patrick von Platen. Thanks to all
of these folks for creating these learning materials, and thanks to the Hugging Face team for developing this
product! """)

    # Each button updates its tab's description box and the shared history table.
    text_button_baseline.click(create_move, inputs=[text_input_baseline, history],
                               outputs=[text_output_baseline, history])
    text_button_greedy.click(create_greedy_search_move, inputs=[text_input_greedy, history],
                             outputs=[text_output_greedy, history])
    text_button_temp.click(create_sampling_search_move,
                           inputs=[text_input_temp, sample_boolean, temperature, history],
                           outputs=[text_output_temp, history])
    text_button_beam.click(create_beam_search_move, inputs=[text_input_beam, num_beams, history],
                           outputs=[text_output_beam, history])
    text_button_top.click(create_top_search_move, inputs=[text_input_top, topk, topp, history],
                          outputs=[text_output_top, history])

demo.launch(share=True)