|
import gradio as gr |
|
from load import LoadModel |
|
from generate import GenerateMidiText |
|
from constants import INSTRUMENT_TRANSFER_CLASSES |
|
from decoder import TextDecoder |
|
from utils import get_miditok, index_has_substring |
|
from playback import get_music |
|
from matplotlib import pylab |
|
import sys |
|
import os |
|
import matplotlib |
|
from generation_utils import plot_piano_roll |
|
|
|
matplotlib.use("Agg") |
|
|
|
sys.modules["pylab"] = pylab |
|
|
|
model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53" |
|
n_bar_generated = 8 |
|
|
|
|
|
|
|
model, tokenizer = LoadModel( |
|
model_repo, |
|
from_huggingface=True, |
|
).load_model_and_tokenizer() |
|
|
|
miditok = get_miditok() |
|
decoder = TextDecoder(miditok) |
|
|
|
|
|
def define_prompt(state, genesis): |
|
if len(state) == 0: |
|
input_prompt = "PIECE_START " |
|
else: |
|
input_prompt = genesis.get_whole_piece_from_bar_dict() |
|
return input_prompt |
|
|
|
|
|
def generator( |
|
label, |
|
regenerate, |
|
temp, |
|
density, |
|
instrument, |
|
state, |
|
piece_by_track, |
|
add_bars=False, |
|
add_bar_count=1, |
|
): |
|
|
|
genesis = GenerateMidiText(model, tokenizer, piece_by_track) |
|
track = {"label": label} |
|
inst = next( |
|
( |
|
inst |
|
for inst in INSTRUMENT_TRANSFER_CLASSES |
|
if inst["transfer_to"] == instrument |
|
), |
|
{"family_number": "DRUMS"}, |
|
)["family_number"] |
|
|
|
inst_index = -1 |
|
if state != []: |
|
for index, instrum in enumerate(state): |
|
if instrum["label"] == track["label"]: |
|
inst_index = index |
|
|
|
|
|
if not add_bars: |
|
|
|
if regenerate: |
|
state.pop(inst_index) |
|
genesis.delete_one_track(inst_index) |
|
|
|
generated_text = ( |
|
genesis.get_whole_piece_from_bar_dict() |
|
) |
|
inst_index = -1 |
|
|
|
|
|
input_prompt = define_prompt(state, genesis) |
|
generated_text = genesis.generate_one_new_track( |
|
inst, density, temp, input_prompt=input_prompt |
|
) |
|
|
|
regenerate = True |
|
else: |
|
|
|
genesis.generate_n_more_bars(add_bar_count) |
|
generated_text = genesis.get_whole_piece_from_bar_dict() |
|
|
|
|
|
decoder.get_midi(generated_text, "mixed.mid") |
|
mixed_inst_midi, mixed_audio = get_music("mixed.mid") |
|
|
|
inst_text = genesis.get_selected_track_as_text(inst_index) |
|
|
|
decoder.get_midi(inst_text, f"{instrument}.mid") |
|
_, inst_audio = get_music(f"{instrument}.mid") |
|
|
|
piano_roll = plot_piano_roll(mixed_inst_midi) |
|
track["text"] = inst_text |
|
state.append(track) |
|
|
|
return ( |
|
inst_text, |
|
(44100, inst_audio), |
|
piano_roll, |
|
state, |
|
(44100, mixed_audio), |
|
regenerate, |
|
genesis.piece_by_track, |
|
) |
|
|
|
|
|
def generated_text_from_state(state): |
|
generated_text_from_state = "PIECE_START " |
|
for track in state: |
|
generated_text_from_state += track["text"] |
|
return generated_text_from_state |
|
|
|
|
|
def save_midi_to_folder(state, midi_path): |
|
|
|
if os.path.exists(midi_path): |
|
print(f"The path {midi_path} already exists") |
|
else: |
|
os.mkdir(midi_path) |
|
print(f"Path '{midi_path}' created") |
|
|
|
|
|
if midi_path[-1] != "/": |
|
midi_path += "/" |
|
|
|
generated_text = generated_text_from_state(state) |
|
decoder.get_midi(generated_text, f"{midi_path}your_awesome_midi.mid") |
|
status = f"MIDI file saved to: {midi_path}your_awesome_midi.mid" |
|
|
|
return status |
|
|
|
|
|
def instrument_row(default_inst, row_id): |
|
with gr.Row(): |
|
row = gr.Variable(row_id) |
|
with gr.Column(scale=1, min_width=100): |
|
inst = gr.Dropdown( |
|
sorted([inst["transfer_to"] for inst in INSTRUMENT_TRANSFER_CLASSES]) |
|
+ ["Drums"], |
|
value=default_inst, |
|
label="Instrument", |
|
) |
|
temp = gr.Dropdown( |
|
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1], |
|
value=0.7, |
|
label="Creativity", |
|
) |
|
density = gr.Dropdown([1, 2, 3], value=3, label="Note Density") |
|
|
|
with gr.Column(scale=3): |
|
output_txt = gr.Textbox( |
|
label="output", lines=10, max_lines=10, show_label=False |
|
) |
|
with gr.Column(scale=1, min_width=100): |
|
inst_audio = gr.Audio(label="TRACK Audio", show_label=True) |
|
regenerate = gr.Checkbox(value=False, label="Regenerate", visible=False) |
|
|
|
|
|
gen_btn = gr.Button("Generate") |
|
|
|
gen_btn.click( |
|
fn=generator, |
|
inputs=[row, regenerate, temp, density, inst, state, piece_by_track], |
|
outputs=[ |
|
output_txt, |
|
inst_audio, |
|
piano_roll, |
|
state, |
|
mixed_audio, |
|
regenerate, |
|
piece_by_track, |
|
], |
|
) |
|
|
|
|
|
with gr.Blocks() as demo: |
|
piece_by_track = gr.State([]) |
|
state = gr.State([]) |
|
title = gr.Markdown( |
|
""" # Demo-App of The-Jam-Machine |
|
A Generative AI trained on text transcription of MIDI music """ |
|
) |
|
|
|
track1_md = gr.Markdown(""" ## Mixed Audio and Piano Roll """) |
|
mixed_audio = gr.Audio(label="Mixed Audio") |
|
piano_roll = gr.Plot(label="Piano Roll", show_label=False) |
|
description = gr.Markdown( |
|
""" |
|
For each **TRACK**, choose your **instrument** along with **creativity** (temperature) and **note density**. Then, hit the **Generate** Button! |
|
You can have a look at the generated text; but most importantly, check the **piano roll** and listen to the TRACK audio! |
|
If you don't like the track, hit the generate button to regenerate it! Generate more tracks and listen to the **mixed audio**! |
|
""" |
|
) |
|
track1_md = gr.Markdown(""" ## TRACK 1 """) |
|
instrument_row("Drums", 0) |
|
track1_md = gr.Markdown(""" ## TRACK 2 """) |
|
instrument_row("Synth Bass 1", 1) |
|
track1_md = gr.Markdown(""" ## TRACK 3 """) |
|
instrument_row("Synth Lead Square", 2) |
|
|
|
|
|
track1_md = gr.Markdown(""" ## SAVE AS MIDI file - enter a valid path """) |
|
with gr.Row(): |
|
|
|
saving_path = gr.Textbox( |
|
value="/Users/ChatGPT/Downloads/TheJamMachine/", |
|
show_label=False, |
|
label="Saving Path", |
|
interactive=True, |
|
) |
|
|
|
gen_btn = gr.Button("Save MIDI") |
|
|
|
status = gr.Textbox(show_label=False) |
|
|
|
gen_btn.click(fn=save_midi_to_folder, inputs=[state, saving_path], outputs=status) |
|
|
|
|
|
demo.launch(debug=True) |
|
|
|
""" |
|
TODO: add improvise button |
|
TODO: cleanup input output of generator |
|
TODO: add a way to add bars |
|
""" |
|
|