File size: 8,286 Bytes
6455751
c5e8d64
 
 
394bbaa
 
 
 
 
 
 
 
 
6455751
c5e8d64
6455751
c5e8d64
6455751
c5e8d64
 
 
 
 
 
 
394bbaa
c5e8d64
 
 
 
 
 
 
 
6455751
c5e8d64
 
 
394bbaa
 
 
 
 
 
c5e8d64
 
 
6455751
394bbaa
c5e8d64
 
394bbaa
c5e8d64
394bbaa
c5e8d64
394bbaa
 
 
 
 
 
 
c5e8d64
 
 
 
 
 
 
 
394bbaa
 
 
 
 
 
 
c5e8d64
 
394bbaa
c5e8d64
394bbaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5e8d64
394bbaa
 
 
c5e8d64
394bbaa
 
 
c5e8d64
394bbaa
c5e8d64
394bbaa
 
 
 
 
 
 
 
 
 
 
 
 
 
6455751
c5e8d64
6455751
394bbaa
6455751
394bbaa
6455751
394bbaa
 
6455751
394bbaa
 
6455751
 
c5e8d64
394bbaa
6455751
 
394bbaa
 
6455751
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
import random

# Load the lyrics dataset once, at import time. The helper functions below read
# its 'train' split and 'Artist' column to pick the artists used in the game.
dataset = load_dataset("SpartanCinder/song-lyrics-artist-classifier")
# Debugging aids for inspecting the dataset contents:
# print(dataset.column_names)
# print(dataset['train']['Artist'])
# artist_list = list(set(dataset['train']['Artist']))
# print(artist_list)

def generate_song(state, language_model, generate_song):
    """Generate lyrics with the selected model and set up a guessing round.

    Args:
        state: Mutable dict holding the game data. On success the keys
            'options', 'multiple_choice_check' and 'correct_choice' are
            overwritten with the new round's values.
        language_model: Label chosen in the model-selection radio, or a falsy
            value when nothing has been selected.
        generate_song: Value supplied by the triggering button; generation
            only runs when it is truthy. (The name shadows this function
            inside the body, which is harmless because the function never
            calls itself.)

    Returns:
        A 3-tuple ``(state, song_text, options_text)``. ``song_text`` is the
        generated lyrics, or a user-facing message when nothing could be
        generated. ``options_text`` is the comma-separated list of candidate
        artists, or an empty string when no round was started.
    """
    # Guard clauses first: every path returns exactly three values so the
    # result always lines up with the three output components.
    if not generate_song:
        return state, "", ""
    if not language_model:
        song_text = "Please select a language model before generating a song."
        return state, song_text, ""

    # Map the UI labels to model identifiers.
    model_names = {
        "Custom Gpt2": "SpartanCinder/GPT2-finetuned-lyric-generation",
        "Gpt2-Medium": "gpt2-medium",
        "facebook/bart-base": "facebook/bart-base",
        "Gpt-Neo": "EleutherAI/gpt-neo-1.3B",
    }
    # NOTE(review): "customized-models" is the original placeholder for the
    # "Customized Models" option; it does not look like a real model id --
    # confirm what should be loaded here.
    model_name = model_names.get(language_model, "customized-models")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenizer and model loading. A missing or unreachable model is reported
    # to the user instead of crashing the event handler.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
    except OSError as err:
        return state, f"Could not load model '{model_name}': {err}", ""
    # The inputs are moved to `device` below, so the model must live there too.
    model.to(device)

    # Pick a random artist from the dataset and build the prompt.
    correct_choice = pick_artist(dataset)
    input_text = f"Write a song in the style of {correct_choice}:"

    # Tuning settings
    max_length = 128
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Choose the decoding strategy from the UI label. Only the first returned
    # sequence is ever used, so a single sequence is requested (the default).
    if language_model == "Custom Gpt2":
        # Fine-tuned model: nucleus sampling with a slightly wider nucleus.
        generation_kwargs = {"do_sample": True, "top_p": 0.95}
    elif language_model in model_names:
        # Stock models: beam search; no_repeat_ngram_size limits repetition.
        generation_kwargs = {
            "num_beams": 5,
            "do_sample": False,
            "no_repeat_ngram_size": 2,
        }
    else:
        # Customized models: nucleus sampling. top_p=0.9 samples from the
        # smallest token set covering 90% of the probability mass, which
        # gives more diverse, less repetitive text.
        generation_kwargs = {"do_sample": True, "top_p": 0.9}

    encoded_output = model.generate(
        input_ids, max_length=max_length, **generation_kwargs
    )

    # Decode output
    song_text = tokenizer.decode(encoded_output[0], skip_special_tokens=True)

    # Generate the multiple-choice options and remember the round in state.
    options = generate_artist_options(dataset, correct_choice)
    state['options'] = options
    state['multiple_choice_check'] = generate_multiple_choice_check(options, correct_choice)
    state['correct_choice'] = correct_choice

    return state, song_text, ', '.join(options)

#Check the selected artist and return whether it's correct
# def on_submit_answer(state, correct_choice, user_choice, submit_answer):
#     if submit_answer:
#         if not user_choice:
#             return {"Error": "Please select an artist before submitting an answer."}
#         # Check if 'correct_choice' is in the state keys
#         if 'correct_choice' in state:
#             correct_answer = state['correct_choice']
#             if correct_answer == user_choice:
#                 return {"Result": f"You guessed the right artist: {correct_choice}"}
#             else:
#                 return {"Result": f"You selected {user_choice}, but the correct answer is {correct_choice}"}
#         else:
#             print("The 'correct_choice' key does not exist in the state.")
#             return None

def on_submit_answer(state, user_choice):
    """Check the user's multiple-choice answer against the stored round.

    Args:
        state: Dict holding the current round; 'options' (list of artist
            names) and 'correct_choice' are read from it.
        user_choice: One of 'A', 'B', 'C' or 'D', or None when nothing has
            been selected.

    Returns:
        A single-entry dict: key "CORRECT" or "INCORRECT" with a result
        message, or key "Error" when no answer was selected or no round has
        been generated yet.
    """
    # Map the user's choice (A, B, C, or D) to an index
    choice_to_index = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
    if user_choice not in choice_to_index:
        return {"Error": "Please select an artist before submitting an answer."}
    index = choice_to_index[user_choice]

    # Before the first song is generated the state has no usable round:
    # 'options' is empty and 'correct_choice' is absent.
    options = state.get('options') or []
    if 'correct_choice' not in state or index >= len(options):
        return {"Error": "Please generate a song before submitting an answer."}

    # Retrieve the user's choice and the correct choice from the state
    user_artist = options[index]
    correct_artist = state['correct_choice']

    # Compare the user's choice with the correct choice
    if user_artist == correct_artist:
        return {"CORRECT": f"You guessed the right artist: {correct_artist}"}
    return {"INCORRECT": f"You selected {user_choice}, but the correct answer is {correct_artist}"}

def pick_artist(dataset):
    """Return one artist chosen at random from the training split.

    The 'Artist' column of ``dataset['train']`` is de-duplicated first, so
    every distinct artist is equally likely no matter how many rows it has.
    """
    unique_artists = set(dataset['train']['Artist'])
    return random.choice(list(unique_artists))

def generate_artist_options(dataset, correct_artist, num_incorrect=3):
    """Build a shuffled list of answer options containing the correct artist.

    Args:
        dataset: Mapping whose ``['train']['Artist']`` column lists artist
            names (duplicates allowed).
        correct_artist: The artist that must appear among the options.
        num_incorrect: How many wrong options to add (default 3). When fewer
            distinct other artists exist, all of them are used rather than
            raising an error.

    Returns:
        A list holding ``correct_artist`` plus up to ``num_incorrect``
        distinct other artists, in random order.
    """
    distractor_pool = list(set(dataset['train']['Artist']))
    if correct_artist in distractor_pool:
        distractor_pool.remove(correct_artist)

    # random.sample raises ValueError when asked for more items than the
    # population holds, so cap the request at what is available.
    sample_size = max(0, min(num_incorrect, len(distractor_pool)))
    options = random.sample(distractor_pool, sample_size) + [correct_artist]
    random.shuffle(options)
    return options

def generate_multiple_choice_check(options, correct_choice):
    """Map each option to True if it equals ``correct_choice``, else False."""
    answer_key = {}
    for candidate in options:
        answer_key[candidate] = (candidate == correct_choice)
    return answer_key

def check_correct_choice(user_choice, correct_choice):
    """Return whether ``user_choice`` equals ``correct_choice``.

    The comparison is made once and its result returned directly; the
    previous separate ``if`` branch duplicated the same comparison.
    """
    return user_choice == correct_choice

# Assemble the game UI: model picker, song generation, answer selection and
# result display, then wire the two buttons to their handlers.
with gr.Blocks(title="Song Generator Guessing Game") as game_interface:
    # Game data shared between the two event handlers.
    state = gr.State({'options': []})

    language_model = gr.Radio(
        ["Custom Gpt2", "Gpt2-Medium", "facebook/bart-base", "Gpt-Neo", "Customized Models"],
        label="Model Selection",
        info="Select the language model to generate the song.",
    )
    generate_song_button = gr.Button("Generate Song")
    generated_song = gr.Textbox(label="Generated Song")
    artist_choice_display = gr.Textbox(
        interactive=False,
        label="Multiple-Choice Options",
    )
    user_choice = gr.Radio(
        ["A", "B", "C", "D"],
        label="Updated Options",
        info="Select the artist that you suspect is the correct artist for the song.",
    )
    # timer = gr.HTML("<div id='progress-bar' style='width: 100%; background-color: #f3f3f3; border: 1px solid #bbb;'><div id='progress' style='height: 20px; width: 0%; background-color: #007bff;'></div></div><script>function startTimer() {var time = 30; var timer = setInterval(function() {time--; document.getElementById('progress').style.width = (time / 30 * 100) + '%'; if (time <= 0) {clearInterval(timer);}}, 1000);}</script>", label="Timer")
    submit_answer_button = gr.Button("Submit Answer")
    correct_answer = gr.Textbox(label="Results")

    # "Generate Song" fills the song text and the options display and
    # stores the new round in state.
    generate_song_button.click(
        fn=generate_song,
        inputs=[state, language_model, generate_song_button],
        outputs=[state, generated_song, artist_choice_display],
    )
    # "Submit Answer" checks the selected letter against the stored round.
    submit_answer_button.click(
        fn=on_submit_answer,
        inputs=[state, user_choice],
        outputs=[correct_answer],
    )

game_interface.launch()