Spaces:
Runtime error
Runtime error
File size: 8,286 Bytes
6455751 c5e8d64 394bbaa 6455751 c5e8d64 6455751 c5e8d64 6455751 c5e8d64 394bbaa c5e8d64 6455751 c5e8d64 394bbaa c5e8d64 6455751 394bbaa c5e8d64 394bbaa c5e8d64 394bbaa c5e8d64 394bbaa c5e8d64 394bbaa c5e8d64 394bbaa c5e8d64 394bbaa c5e8d64 394bbaa c5e8d64 394bbaa c5e8d64 394bbaa c5e8d64 394bbaa 6455751 c5e8d64 6455751 394bbaa 6455751 394bbaa 6455751 394bbaa 6455751 394bbaa 6455751 c5e8d64 394bbaa 6455751 394bbaa 6455751 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
import random
# Load the lyrics dataset once, at import time. The helper functions below
# read dataset['train']['Artist'] to choose the target artist and the decoys.
dataset = load_dataset("SpartanCinder/song-lyrics-artist-classifier")
# Debugging helpers for inspecting the dataset, left disabled:
# print(dataset.column_names)
# print(dataset['train']['Artist'])
# artist_list = list(set(dataset['train']['Artist']))
# print(artist_list)
def generate_song(state, language_model, generate_song):
    """Generate lyrics with the selected model and set up a guessing round.

    Args:
        state: Per-session dict. Updated in place with the keys 'options',
            'multiple_choice_check' and 'correct_choice'.
        language_model: Label chosen in the model-selection radio, or a
            falsy value when nothing has been selected.
        generate_song: Value supplied by the triggering button; a falsy
            value skips generation. (Inside this body the name shadows the
            function itself.)

    Returns:
        A 3-tuple ``(state, song_text, options_text)``. Every exit path
        returns exactly three values because the click handler is wired to
        three output components.
    """
    if not generate_song:
        return state, "", ""
    if not language_model:
        song_text = "Please select a language model before generating a song."
        return state, song_text, ""

    # Map the radio label to the checkpoint to load.
    # NOTE(review): the fallback "customized-models" is kept from the original
    # code; confirm it resolves to a real checkpoint.
    model_names = {
        "Custom Gpt2": "SpartanCinder/GPT2-finetuned-lyric-generation",
        "Gpt2-Medium": "gpt2-medium",
        "facebook/bart-base": "facebook/bart-base",
        "Gpt-Neo": "EleutherAI/gpt-neo-1.3B",
    }
    model_name = model_names.get(language_model, "customized-models")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The model must live on the same device as the input ids.
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    # Call for a random artist from the dataset
    correct_choice = pick_artist(dataset)
    input_text = f"Write a song in the style of {correct_choice}:"

    # Tuning settings
    max_length = 128
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Choose the decoding strategy. The original condition
    # `language_model != "customized-models" or "Custom Gpt2"` was always
    # true, so the sampling branches could never run.
    if language_model == "Custom Gpt2":
        # Fine-tuned model: nucleus sampling with a 0.95 probability mass.
        generation_kwargs = {"do_sample": True, "top_p": 0.95}
    elif model_name == "customized-models":
        # Nucleus sampling: sample from the top 90% of the distribution to
        # get more diverse, less repetitive text.
        generation_kwargs = {"do_sample": True, "top_p": 0.9}
    else:
        # Stock models: beam search, forbidding repeated bigrams.
        generation_kwargs = {
            "num_beams": 5,
            "do_sample": False,
            "no_repeat_ngram_size": 2,
        }

    encoded_output = model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=5,
        **generation_kwargs,
    )
    # Only the first returned sequence is shown. Decoding it here, for every
    # strategy, fixes the NameError the original hit when `output` was unset.
    song_text = tokenizer.decode(encoded_output[0], skip_special_tokens=True)

    # Generate the multiple-choice options
    options = generate_artist_options(dataset, correct_choice)
    state['options'] = options
    # Generate the multiple-choice check
    multiple_choice_check = generate_multiple_choice_check(options, correct_choice)
    state['multiple_choice_check'] = multiple_choice_check
    state['correct_choice'] = correct_choice
    return state, song_text, ', '.join(options)
#Check the selected artist and return whether it's correct
# def on_submit_answer(state, correct_choice, user_choice, submit_answer):
# if submit_answer:
# if not user_choice:
# return {"Error": "Please select an artist before submitting an answer."}
# # Check if 'correct_choice' is in the state keys
# if 'correct_choice' in state:
# correct_answer = state['correct_choice']
# if correct_answer == user_choice:
# return {"Result": f"You guessed the right artist: {correct_choice}"}
# else:
# return {"Result": f"You selected {user_choice}, but the correct answer is {correct_choice}"}
# else:
# print("The 'correct_choice' key does not exist in the state.")
# return None
def on_submit_answer(state, user_choice):
    """Check the player's guess against the artist stored in ``state``.

    Args:
        state: Session dict. Reads 'options' (list of artist names in
            display order) and 'correct_choice' (the target artist).
        user_choice: One of 'A', 'B', 'C', 'D', or None when nothing has
            been selected.

    Returns:
        A single-entry dict keyed "CORRECT" or "INCORRECT" with a message,
        or keyed "Error" when no answer was selected or no round has been
        generated yet.
    """
    # Map the user's choice (A, B, C, or D) to an index
    choice_to_index = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
    if user_choice not in choice_to_index:
        return {"Error": "Please select an artist before submitting an answer."}
    index = choice_to_index[user_choice]

    # Before a song is generated the state has no correct answer and an
    # empty option list; report that instead of raising.
    options = state.get('options') or []
    if 'correct_choice' not in state or index >= len(options):
        return {"Error": "Please generate a song before submitting an answer."}

    # Retrieve the user's choice and the correct choice from the state
    user_artist = options[index]
    correct_artist = state['correct_choice']

    # Compare the user's choice with the correct choice
    if user_artist == correct_artist:
        return {"CORRECT": f"You guessed the right artist: {correct_artist}"}
    return {"INCORRECT": f"You selected {user_choice}, but the correct answer is {correct_artist}"}
def pick_artist(dataset):
    """Return one artist name drawn at random from the training split.

    Reads ``dataset['train']['Artist']``; duplicates are collapsed first so
    that each distinct artist is equally likely to be chosen.
    """
    unique_artists = list(set(dataset['train']['Artist']))
    return random.choice(unique_artists)
def generate_artist_options(dataset, correct_artist):
    """Build the shuffled multiple-choice list for one round.

    Args:
        dataset: Mapping whose ``['train']['Artist']`` holds artist names.
        correct_artist: The artist that must appear among the options.

    Returns:
        A shuffled list containing ``correct_artist`` plus up to three
        distinct other artists. Fewer than three decoys are returned when
        the dataset does not contain enough other artists.
    """
    # Distinct artists other than the correct one.
    decoys = list(set(dataset['train']['Artist']) - {correct_artist})
    # Clamp the sample size: random.sample raises ValueError when asked for
    # more items than the population holds.
    incorrect = random.sample(decoys, min(3, len(decoys)))
    options = incorrect + [correct_artist]
    random.shuffle(options)
    return options
def generate_multiple_choice_check(options, correct_choice):
    """Map each option to True if it equals ``correct_choice``, else False."""
    answer_key = {}
    for candidate in options:
        answer_key[candidate] = (candidate == correct_choice)
    return answer_key
def check_correct_choice(user_choice, correct_choice):
    """Return whether ``user_choice`` equals ``correct_choice``."""
    # A single comparison replaces the original redundant `if` branch, which
    # evaluated the same equality twice.
    return user_choice == correct_choice
# Build the game UI: pick a model, generate a song, then guess the artist.
with gr.Blocks(title="Song Generator Guessing Game") as game_interface:
    # Per-session state shared between the two click handlers.
    state = gr.State({'options': []})

    language_model = gr.Radio(
        ["Custom Gpt2", "Gpt2-Medium", "facebook/bart-base", "Gpt-Neo", "Customized Models"],
        label="Model Selection",
        info="Select the language model to generate the song.",
    )
    generate_song_button = gr.Button("Generate Song")
    generated_song = gr.Textbox(label="Generated Song")
    artist_choice_display = gr.Textbox(
        interactive=False,
        label="Multiple-Choice Options",
    )
    user_choice = gr.Radio(
        ["A", "B", "C", "D"],
        label="Updated Options",
        info="Select the artist that you suspect is the correct artist for the song.",
    )
    # timer = gr.HTML("<div id='progress-bar' style='width: 100%; background-color: #f3f3f3; border: 1px solid #bbb;'><div id='progress' style='height: 20px; width: 0%; background-color: #007bff;'></div></div><script>function startTimer() {var time = 30; var timer = setInterval(function() {time--; document.getElementById('progress').style.width = (time / 30 * 100) + '%'; if (time <= 0) {clearInterval(timer);}}, 1000);}</script>", label="Timer")
    submit_answer_button = gr.Button("Submit Answer")
    correct_answer = gr.Textbox(label="Results")

    # Generating a song fills the song text and the option list.
    generate_song_button.click(
        fn=generate_song,
        inputs=[state, language_model, generate_song_button],
        outputs=[state, generated_song, artist_choice_display],
    )
    # Submitting an answer writes the verdict into the results box.
    submit_answer_button.click(
        fn=on_submit_answer,
        inputs=[state, user_choice],
        outputs=[correct_answer],
    )

game_interface.launch()