Spaces:
Runtime error
Runtime error
chomayouni
commited on
Commit
·
c5e8d64
1
Parent(s):
6455751
Updated with initial logic
Browse files- sgg_app.py +65 -59
sgg_app.py
CHANGED
@@ -1,65 +1,61 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
# if generate_song:
|
5 |
-
# # Generate the song and the options based on the difficulty
|
6 |
-
# # In the actual implementation, you would generate the song text and the options based on the difficulty
|
7 |
-
# if difficulty == "Demo":
|
8 |
-
# song_text = "Generated song text for Demo"
|
9 |
-
# options = ["Artist 1", "Artist 2", "Artist 3", "Artist 4"]
|
10 |
-
# elif difficulty == "Medium":
|
11 |
-
# song_text = "Generated song text for Medium"
|
12 |
-
# options = ["Artist 5", "Artist 6", "Artist 7", "Artist 8"]
|
13 |
-
# else: # Hard
|
14 |
-
# song_text = "Generated song text for Hard"
|
15 |
-
# options = ["Artist 9", "Artist 10", "Artist 11", "Artist 12"]
|
16 |
-
# return {"Generated Song": song_text, "Options": options}
|
17 |
-
# elif submit_answer:
|
18 |
-
# # Check the selected artist and return whether it's correct
|
19 |
-
# correct_answer = "Artist 1" # Placeholder
|
20 |
-
# return {"Correct Answer": correct_answer == artist_choice}
|
21 |
|
22 |
-
|
23 |
-
# fn=game,
|
24 |
-
# inputs=[
|
25 |
-
# gr.Radio(["Demo", "Medium", "Hard"], label="Difficulty",
|
26 |
-
# info="The higher the difficulty makes it so that the options for the artists are more similar to one another?"),
|
27 |
-
# gr.Button("Generate Song"),
|
28 |
-
# gr.Radio(["A", "B", "C", "D"], label="Multi-Choice Options",
|
29 |
-
# info="Select the artist that you suspect is the correct artist for the song."),
|
30 |
-
# gr.Button("Submit Answer"),
|
31 |
-
# ],
|
32 |
-
|
33 |
-
# outputs=[
|
34 |
-
# gr.Textbox(label="Generated Song"),
|
35 |
-
# gr.Textbox(label="Options"),
|
36 |
-
# gr.Textbox(label="Correct Answer"),
|
37 |
-
# ],
|
38 |
|
39 |
-
#
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
def submit_answer(state, artist_choice, submit_answer):
|
65 |
if submit_answer:
|
@@ -69,10 +65,20 @@ def submit_answer(state, artist_choice, submit_answer):
|
|
69 |
# Check the selected artist and return whether it's correct
|
70 |
correct_answer = state['options'][0] # Placeholder
|
71 |
return {"Correct Answer": correct_answer == artist_choice}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
-
with gr.Blocks(title="Song
|
74 |
state = gr.State({'options': []})
|
75 |
-
|
76 |
generate_song_button = gr.Button("Generate Song")
|
77 |
artist_choice_display = gr.Textbox(interactive=False, label="Multiple-Choice Options")
|
78 |
artist_choice = gr.Radio(["A", "B", "C", "D"], label="Updated Options", info="Select the artist that you suspect is the correct artist for the song.")
|
@@ -83,7 +89,7 @@ with gr.Blocks(title="Song Genorator Guessing Game") as game_interface:
|
|
83 |
|
84 |
generate_song_button.click(
|
85 |
generate_song,
|
86 |
-
[state,
|
87 |
[state, generated_song, artist_choice_display, timer]
|
88 |
)
|
89 |
submit_answer_button.click(
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
4 |
+
from transformers import TrainingArguments, Trainer
|
5 |
|
6 |
+
def generate_song(state, language_model, generate_song):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
+
# Logic that takes put the selected language model and generates a song
|
11 |
+
if generate_song:
|
12 |
+
if not language_model:
|
13 |
+
song_text = "Please select a language model before generating a song."
|
14 |
+
return state, song_text, "", ""
|
15 |
+
# Generate the song and the options based on the language_model
|
16 |
+
if language_model == "Custom Gpt2":
|
17 |
+
model_name = "SpartanCinder/GPT2-pretrained-lyric-generation"
|
18 |
+
elif language_model == "Gpt2-Medium":
|
19 |
+
model_name = "gpt2-medium"
|
20 |
+
elif language_model == "facebook/bart-base":
|
21 |
+
model_name = "facebook/bart-base"
|
22 |
+
elif language_model == "Gpt-Neo":
|
23 |
+
model_name = "EleutherAI/gpt-neo-1.3B"
|
24 |
+
else: # Customized Models
|
25 |
+
model_name = "customized-models"
|
26 |
|
27 |
+
#tokenzer and text generation logic
|
28 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
29 |
+
model = AutoModelForCausalLM.from_pretrained(model_name)
|
30 |
+
input_text = pick_artist()
|
31 |
+
max_length = 128
|
32 |
+
input_ids = tokenizer.encode(input_text, return_tensors="pt")
|
33 |
+
input_ids = input_ids.to(device)
|
34 |
|
35 |
+
if language_model != "customized-models":
|
36 |
+
### Using Beam search to generate text###
|
37 |
+
# encoded data
|
38 |
+
output = model.generate(input_ids, max_length=max_length, num_beams=5, num_return_sequences=5, do_sample=False, no_repeat_ngram_size=2) # Generate text
|
39 |
+
# Decode output
|
40 |
+
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
41 |
+
# But this output is repeating, so I need ot adjust this so that it is not repeating.
|
42 |
+
else:
|
43 |
+
### Nucleas Sampling to generate text###
|
44 |
+
# Set the do_sample parameter to True because we are using nucleus sampling is a probabilistic sampling method
|
45 |
+
# top_p is the probability threshold for nucleus sampling
|
46 |
+
# So, we set top_p to 0.9, which means that the model will sample from the top 90% of the probability distribution
|
47 |
+
# This will help to generate more diverse text that is less repetitive
|
48 |
+
encoded_output = model.generate(input_ids, max_length=max_length, num_return_sequences=5, do_sample=True, top_p = 0.9, )
|
49 |
+
|
50 |
+
song_text = tokenizer.decode(encoded_output[0], skip_special_tokens=True)
|
51 |
+
|
52 |
+
# Generate the multiple-choice options
|
53 |
+
options = ["Artist 1", "Artist 2", "Artist 3", "Artist 4"]
|
54 |
+
|
55 |
+
state['options'] = options
|
56 |
+
state['timer_finished'] = False
|
57 |
+
timer_script = "<div id='progress-bar' style='width: 100%; background-color: #f3f3f3; border: 1px solid #bbb;'><div id='progress' style='height: 20px; width: 0%; background-color: #007bff;'></div></div><script>function startTimer() {var time = 30; var timer = setInterval(function() {time--; document.getElementById('progress').style.width = (time / 30 * 100) + '%'; if (time <= 0) {clearInterval(timer);}}, 1000);}</script>"
|
58 |
+
return state, song_text, ', '.join(options), timer_script
|
59 |
|
60 |
def submit_answer(state, artist_choice, submit_answer):
|
61 |
if submit_answer:
|
|
|
65 |
# Check the selected artist and return whether it's correct
|
66 |
correct_answer = state['options'][0] # Placeholder
|
67 |
return {"Correct Answer": correct_answer == artist_choice}
|
68 |
+
|
69 |
+
def pick_artist():
|
70 |
+
|
71 |
+
return "A song in the style of Taylor Swift:"
|
72 |
+
|
73 |
+
def generate_artist_options(correct_artist):
|
74 |
+
# Generate 3 incorrect options
|
75 |
+
options = ["Artist 1", "Artist 2", "Artist 3", "Artist 4"]
|
76 |
+
options.remove(correct_artist)
|
77 |
+
return [correct_artist] + options
|
78 |
|
79 |
+
with gr.Blocks(title="Song Generator Guessing Game") as game_interface:
|
80 |
state = gr.State({'options': []})
|
81 |
+
language_model = gr.Radio(["Custom Gpt2", "Gpt2-Medium", "facebook/bart-base","Gpt-Neo", "Customized Models"], label="Difficulty")
|
82 |
generate_song_button = gr.Button("Generate Song")
|
83 |
artist_choice_display = gr.Textbox(interactive=False, label="Multiple-Choice Options")
|
84 |
artist_choice = gr.Radio(["A", "B", "C", "D"], label="Updated Options", info="Select the artist that you suspect is the correct artist for the song.")
|
|
|
89 |
|
90 |
generate_song_button.click(
|
91 |
generate_song,
|
92 |
+
[state, language_model, generate_song_button],
|
93 |
[state, generated_song, artist_choice_display, timer]
|
94 |
)
|
95 |
submit_answer_button.click(
|