Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from transformers import TrainingArguments, Trainer | |
from datasets import load_dataset | |
import random | |
# Song-lyrics dataset shared by the artist-picking helpers below; it is loaded
# once at import time so every callback reuses the same object.
DATASET_ID = "SpartanCinder/song-lyrics-artist-classifier"
dataset = load_dataset(DATASET_ID)
# Hugging Face model id for each named choice of the "Model Selection" radio.
_MODEL_NAMES = {
    "Custom Gpt2": "SpartanCinder/GPT2-finetuned-lyric-generation",
    "Gpt2-Medium": "gpt2-medium",
    "facebook/bart-base": "facebook/bart-base",
    "Gpt-Neo": "EleutherAI/gpt-neo-1.3B",
}
# Model id used for any other selection (the "Customized Models" placeholder).
_FALLBACK_MODEL_NAME = "customized-models"


def _generation_kwargs(language_model):
    """Return the decoding keyword arguments for ``model.generate``."""
    if language_model == "Custom Gpt2":
        # The fine-tuned model repeats itself under beam search, so sample.
        return {"num_return_sequences": 5, "do_sample": True, "top_p": 0.95}
    if language_model in _MODEL_NAMES:
        # Stock models: beam search, forbidding repeated bigrams.
        return {
            "num_beams": 5,
            "num_return_sequences": 5,
            "do_sample": False,
            "no_repeat_ngram_size": 2,
        }
    # Anything else: nucleus sampling over the top 90% of probability mass,
    # which gives more diverse and less repetitive text.
    return {"num_return_sequences": 5, "do_sample": True, "top_p": 0.9}


def generate_song(state, language_model, generate_song):
    """Generate lyrics in the style of a random artist and set up the quiz.

    Args:
        state: Mutable dict holding the game state. On success its
            'options', 'multiple_choice_check' and 'correct_choice' entries
            are overwritten.
        language_model: Label chosen in the model radio, or a falsy value
            when nothing is selected.
        generate_song: Flag supplied by the click handler; nothing is
            generated when it is falsy.

    Returns:
        A 3-tuple ``(state, song_text, options_text)`` matching the three
        output components wired to the click handler. ``options_text`` is
        the comma-separated list of artist options; it is empty when no
        song was generated.
    """
    # Every exit returns exactly three values: the handler has three outputs.
    if not generate_song:
        return state, "", ""
    if not language_model:
        song_text = "Please select a language model before generating a song."
        return state, song_text, ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenizer and model for the selected choice.
    model_name = _MODEL_NAMES.get(language_model, _FALLBACK_MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The model must live on the same device as the input ids below.
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    # Prompt the model with a random artist from the dataset.
    correct_choice = pick_artist(dataset)
    input_text = f"A Song in the style of {correct_choice}:"

    # Tuning settings
    max_length = 128
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    encoded_output = model.generate(
        input_ids,
        max_length=max_length,
        **_generation_kwargs(language_model),
    )
    # Only the first returned sequence is shown to the player.
    output = tokenizer.decode(encoded_output[0], skip_special_tokens=True)

    # Drop the first line when the text spans several lines.
    if '\n' in output:
        output = '\n'.join(output.split('\n')[1:])
    song_text = output

    # Build the multiple-choice options and remember the answer key.
    options = generate_artist_options(dataset, correct_choice)
    state['options'] = options
    state['multiple_choice_check'] = generate_multiple_choice_check(options, correct_choice)
    state['correct_choice'] = correct_choice

    return state, song_text, ', '.join(options)
def on_submit_answer(state, user_choice):
    """Check the player's guess against the artist stored in the game state.

    Args:
        state: Game-state dict; its 'options' list and 'correct_choice'
            entry are read.
        user_choice: One of 'A', 'B', 'C' or 'D', mapped to positions 0-3 of
            ``state['options']``. May be None when nothing was selected.

    Returns:
        A single-entry dict whose key names the outcome ("CORRECT",
        "INCORRECT", "NO SONG" or "NO SELECTION") and whose value is the
        message for the player.
    """
    choice_to_index = {'A': 0, 'B': 1, 'C': 2, 'D': 3}

    options = state.get('options') or []
    correct_artist = state.get('correct_choice')
    # The initial state has no options or answer until a song is generated.
    if not options or correct_artist is None:
        return {"NO SONG": "Please generate a song before submitting an answer."}

    # Covers both "nothing selected" and a letter beyond the options offered.
    index = choice_to_index.get(user_choice)
    if index is None or index >= len(options):
        return {"NO SELECTION": "Please select one of the available options before submitting."}

    user_artist = options[index]
    if user_artist == correct_artist:
        return {"CORRECT": f"You guessed the right artist: {correct_artist}"}
    return {"INCORRECT": f"You selected {user_choice}, but the correct answer is {correct_artist}"}
def pick_artist(dataset):
    """Return one artist drawn at random from the dataset's train split.

    The 'Artist' column is de-duplicated first, so each distinct artist is
    equally likely regardless of how many rows it has.
    """
    unique_artists = set(dataset['train']['Artist'])
    return random.choice(list(unique_artists))
def generate_artist_options(dataset, correct_artist):
    """Build the shuffled list of answer options for one round.

    Args:
        dataset: Mapping whose ``['train']['Artist']`` column lists the
            artist of every row.
        correct_artist: The artist that must appear among the options.

    Returns:
        A shuffled list containing ``correct_artist`` plus up to three
        distinct other artists. Fewer than three distractors are returned
        when the dataset does not contain enough other artists.
    """
    other_artists = set(dataset['train']['Artist'])
    other_artists.discard(correct_artist)

    # random.sample raises ValueError when asked for more items than exist,
    # so cap the number of distractors at what is available.
    distractor_count = min(3, len(other_artists))
    options = random.sample(list(other_artists), distractor_count) + [correct_artist]
    random.shuffle(options)
    return options
def generate_multiple_choice_check(options, correct_choice):
    """Map every option to True when it equals the correct choice, else False."""
    answer_key = {}
    for candidate in options:
        answer_key[candidate] = candidate == correct_choice
    return answer_key
# Build the game UI: a model selector, a button that generates a song, a
# read-only list of artist options, an A-D answer selector and a result box.
with gr.Blocks(title="Song Generator Guessing Game") as game_interface:
    gr.Markdown(" # Song Generator Guessing Game")
    # Earlier attempts at showing a header image, currently disabled:
    # gr.Markdown("![Image](https://huggingface.co/spaces/SpartanCinder/NLP_Song_Generator_Guessing_Game/raw/main/RobotSinger.png)")
    # gr.HTML("<img src='/NLP_Song_Generator_Guessing_Game/RobotSinger.png'")
    gr.Markdown("""
    ## Instructions
    1. Select a language model from the dropdown.
    2. Click the 'Generate Song' button to generate a song.
    3. Guess the artist of the generated song by selecting an option from the radio buttons.
    4. Click the 'Submit Answer' button to submit your guess.
    """)
    # Game state passed to both callbacks; it starts with an empty options
    # list, and generate_song stores the options and correct artist in it.
    state = gr.State({'options': []})
    language_model = gr.Radio(["Custom Gpt2", "Gpt2-Medium", "facebook/bart-base","Gpt-Neo", "Customized Models"], label="Model Selection", info="Select the language model to generate the song.")
    generate_song_button = gr.Button("Generate Song")
    generated_song = gr.Textbox(label="Generated Song")
    # Receives the comma-separated artist options returned by generate_song.
    artist_choice_display = gr.Textbox(interactive=False, label="Multiple-Choice Options")
    # on_submit_answer maps the letters A-D to positions 0-3 of the options.
    user_choice = gr.Radio(["A", "B", "C", "D"], label="Updated Options", info="Select the artist that you suspect is the correct artist for the song.")
    submit_answer_button = gr.Button("Submit Answer")
    correct_answer = gr.Textbox(label="Results")
    gr.Markdown("""
    ## Developer Notes:
    - The 'Custom Gpt2' model is a custom fine-tuned GPT-2 model for generating song lyrics.
    -- It was trained using a custom dataset of song lyrics from various artists that I created using the Genius API.
    -- It uses beam search to generate text, and is much faster compared to the other moodels
    -- However, it still has trouble producing prompts
    -- I found that artists like "Adele" and "Taylor Swift" are more likely to have coherent lyrics
    - The 'Gpt2-Medium' model is the GPT-2 medium model from the Hugging Face model hub.
    -- It uses beam search to generate text, and is slower compared to the custom GPT-2 model
    -- Without tuning, it is more likely to produce a general response to the prompt
    -- Oddly enough, had a tendency to produce lyrics that were more coherent than the full GPT-2 model
    - The 'facebook/bart-base' model is the BART base model from the Hugging Face model hub.
    -- The model only workd 20% of the time
    - The 'Gpt-Neo' model is the GPT-Neo 1.3B model from the EleutherAI model hub.
    -- It performs well, but is slower compared to the GPT-2 models
    - The 'Customized Models' option is a placeholder for any other custom models that you may have.
    #### Known Issues:
    - The 'facebook/bart-base' model has a tendency to produce empty responses.
    -- This is likely due to the model's architecture and the way it processes the input data.
    - Ocasionaly, the Custom Gpt2 model will produce a result that is just numbers. This has only happened once or twice and both times where when is was generating a song for the Weekend.
    """)
    # The button is listed as an input so that generate_song receives a value
    # for its third parameter; the three outputs match its 3-tuple return.
    generate_song_button.click(
        generate_song,
        [state, language_model, generate_song_button],
        [state, generated_song, artist_choice_display,]
    )
    # The value returned by on_submit_answer is shown in the "Results" textbox.
    submit_answer_button.click(
        on_submit_answer,
        [state, user_choice,],
        [correct_answer]
    )
# Start serving the interface.
game_interface.launch()