m41w4r3.exe
commited on
Commit
·
2ec0615
1
Parent(s):
3e2b7ea
initial commit
Browse files- README.md +10 -11
- app.py +154 -0
- constants.py +77 -0
- decoder.py +197 -0
- generate.py +489 -0
- load.py +60 -0
- playback.py +35 -0
- requirements.txt +5 -0
- utils.py +246 -0
README.md
CHANGED
@@ -1,12 +1,11 @@
|
|
1 |
-
|
2 |
-
title: The Jam Machine
|
3 |
-
emoji: 🏃
|
4 |
-
colorFrom: pink
|
5 |
-
colorTo: indigo
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 3.14.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Contributors:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
+
Jean Simonnet:
|
4 |
+
https://github.com/misnaej
|
5 |
+
https://www.linkedin.com/in/jeansimonnet/
|
6 |
+
Louis Demetz:
|
7 |
+
https://github.com/louis-demetz
|
8 |
+
https://www.linkedin.com/in/ldemetz/
|
9 |
+
Halid Bayram:
|
10 |
+
https://github.com/m41w4r3exe
|
11 |
+
https://www.linkedin.com/in/halid-bayram-6b9ba861/
|
app.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from load import LoadModel
|
3 |
+
from generate import GenerateMidiText
|
4 |
+
from constants import INSTRUMENT_CLASSES
|
5 |
+
from encoder import MIDIEncoder
|
6 |
+
from decoder import TextDecoder
|
7 |
+
from utils import get_miditok, index_has_substring
|
8 |
+
from playback import get_music
|
9 |
+
from matplotlib import pylab
|
10 |
+
import sys
|
11 |
+
import matplotlib
|
12 |
+
from generation_utils import plot_piano_roll
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
matplotlib.use("Agg")
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
|
18 |
+
sys.modules["pylab"] = pylab
|
19 |
+
|
20 |
+
model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
|
21 |
+
revision = "ddf00f90d6d27e4cc0cb99c04a22a8f0a16c933e"
|
22 |
+
n_bar_generated = 8
|
23 |
+
# model_repo = "JammyMachina/improved_4bars-mdl"
|
24 |
+
# n_bar_generated = 4
|
25 |
+
|
26 |
+
model, tokenizer = LoadModel(
|
27 |
+
model_repo, from_huggingface=True, revision=revision
|
28 |
+
).load_model_and_tokenizer()
|
29 |
+
genesis = GenerateMidiText(
|
30 |
+
model,
|
31 |
+
tokenizer,
|
32 |
+
)
|
33 |
+
genesis.set_nb_bars_generated(n_bars=n_bar_generated)
|
34 |
+
|
35 |
+
miditok = get_miditok()
|
36 |
+
decoder = TextDecoder(miditok)
|
37 |
+
|
38 |
+
|
39 |
+
def define_prompt(state, genesis):
|
40 |
+
if len(state) == 0:
|
41 |
+
input_prompt = "PIECE_START "
|
42 |
+
else:
|
43 |
+
input_prompt = genesis.get_whole_piece_from_bar_dict()
|
44 |
+
return input_prompt
|
45 |
+
|
46 |
+
|
47 |
+
def generator(
|
48 |
+
regenerate, temp, density, instrument, state, add_bars=False, add_bar_count=1
|
49 |
+
):
|
50 |
+
|
51 |
+
inst = next(
|
52 |
+
(inst for inst in INSTRUMENT_CLASSES if inst["name"] == instrument),
|
53 |
+
{"family_number": "DRUMS"},
|
54 |
+
)["family_number"]
|
55 |
+
|
56 |
+
inst_index = index_has_substring(state, "INST=" + str(inst))
|
57 |
+
|
58 |
+
# Regenerate
|
59 |
+
if regenerate:
|
60 |
+
state.pop(inst_index)
|
61 |
+
genesis.delete_one_track(inst_index)
|
62 |
+
generated_text = (
|
63 |
+
genesis.get_whole_piece_from_bar_dict()
|
64 |
+
) # maybe not useful here
|
65 |
+
inst_index = -1 # reset to last generated
|
66 |
+
|
67 |
+
# Generate
|
68 |
+
if not add_bars:
|
69 |
+
# NEW TRACK
|
70 |
+
input_prompt = define_prompt(state, genesis)
|
71 |
+
generated_text = genesis.generate_one_new_track(
|
72 |
+
inst, density, temp, input_prompt=input_prompt
|
73 |
+
)
|
74 |
+
else:
|
75 |
+
# NEW BARS
|
76 |
+
genesis.generate_n_more_bars(add_bar_count) # for all instruments
|
77 |
+
generated_text = genesis.get_whole_piece_from_bar_dict()
|
78 |
+
|
79 |
+
decoder.get_midi(generated_text, "tmp/mixed.mid")
|
80 |
+
mixed_inst_midi, mixed_audio = get_music("tmp/mixed.mid")
|
81 |
+
|
82 |
+
inst_text = genesis.get_selected_track_as_text(inst_index)
|
83 |
+
inst_midi_name = f"tmp/{instrument}.mid"
|
84 |
+
decoder.get_midi(inst_text, inst_midi_name)
|
85 |
+
_, inst_audio = get_music(inst_midi_name)
|
86 |
+
piano_roll = plot_piano_roll(mixed_inst_midi)
|
87 |
+
state.append(inst_text)
|
88 |
+
|
89 |
+
return inst_text, (44100, inst_audio), piano_roll, state, (44100, mixed_audio)
|
90 |
+
|
91 |
+
|
92 |
+
def instrument_row(default_inst):
|
93 |
+
|
94 |
+
with gr.Row():
|
95 |
+
with gr.Column(scale=1, min_width=50):
|
96 |
+
inst = gr.Dropdown(
|
97 |
+
[inst["name"] for inst in INSTRUMENT_CLASSES] + ["Drums"],
|
98 |
+
value=default_inst,
|
99 |
+
label="Instrument",
|
100 |
+
)
|
101 |
+
temp = gr.Number(value=0.7, label="Creativity")
|
102 |
+
density = gr.Dropdown([0, 1, 2, 3], value=3, label="Density")
|
103 |
+
|
104 |
+
with gr.Column(scale=3):
|
105 |
+
output_txt = gr.Textbox(label="output", lines=10, max_lines=10)
|
106 |
+
with gr.Column(scale=1, min_width=100):
|
107 |
+
inst_audio = gr.Audio(label="Audio")
|
108 |
+
regenerate = gr.Checkbox(value=False, label="Regenerate")
|
109 |
+
# add_bars = gr.Checkbox(value=False, label="Add Bars")
|
110 |
+
# add_bar_count = gr.Dropdown([1, 2, 4, 8], value=1, label="Add Bars")
|
111 |
+
gen_btn = gr.Button("Generate")
|
112 |
+
gen_btn.click(
|
113 |
+
fn=generator,
|
114 |
+
inputs=[
|
115 |
+
regenerate,
|
116 |
+
temp,
|
117 |
+
density,
|
118 |
+
inst,
|
119 |
+
state,
|
120 |
+
],
|
121 |
+
outputs=[output_txt, inst_audio, piano_roll, state, mixed_audio],
|
122 |
+
)
|
123 |
+
|
124 |
+
|
125 |
+
with gr.Blocks(cache_examples=False) as demo:
|
126 |
+
state = gr.State([])
|
127 |
+
mixed_audio = gr.Audio(label="Mixed Audio")
|
128 |
+
piano_roll = gr.Plot(label="Piano Roll")
|
129 |
+
instrument_row("Drums")
|
130 |
+
instrument_row("Bass")
|
131 |
+
instrument_row("Synth Lead")
|
132 |
+
# instrument_row("Piano")
|
133 |
+
|
134 |
+
demo.launch(debug=True)
|
135 |
+
|
136 |
+
"""
|
137 |
+
TODO: DEPLOY
|
138 |
+
TODO: temp file situation
|
139 |
+
TODO: clear cache situation
|
140 |
+
TODO: reset button
|
141 |
+
TODO: instrument mapping business
|
142 |
+
TODO: Y lim axis of piano roll
|
143 |
+
TODO: add a button to save the generated midi
|
144 |
+
TODO: add improvise button
|
145 |
+
TODO: making the piano roll fit on the horizontal scale
|
146 |
+
TODO: set values for temperature as it is done for density
|
147 |
+
TODO: set the color situation to be dark background
|
148 |
+
TODO: make regeration default when an intrument has already been track has already been generated
|
149 |
+
TODO: Add bar should be now set for the whole piece - regenerrate should regenerate the added bars only on all instruments
|
150 |
+
TODO: row height to fix
|
151 |
+
|
152 |
+
TODO: reset state of tick boxes after used maybe (regenerate, add bars) ;
|
153 |
+
TODO: block regenerate if add bar on
|
154 |
+
"""
|
constants.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# fmt: off
|
2 |
+
# Instrument mapping and mapping functions
|
3 |
+
INSTRUMENT_CLASSES = [
|
4 |
+
{"name": "Piano", "program_range": range(0, 8), "family_number": 0},
|
5 |
+
{"name": "Chromatic Percussion", "program_range": range(8, 16), "family_number": 1},
|
6 |
+
{"name": "Organ", "program_range": range(16, 24), "family_number": 2},
|
7 |
+
{"name": "Guitar", "program_range": range(24, 32), "family_number": 3},
|
8 |
+
{"name": "Bass", "program_range": range(32, 40), "family_number": 4},
|
9 |
+
{"name": "Strings", "program_range": range(40, 48), "family_number": 5},
|
10 |
+
{"name": "Ensemble", "program_range": range(48, 56), "family_number": 6},
|
11 |
+
{"name": "Brass", "program_range": range(56, 64), "family_number": 7},
|
12 |
+
{"name": "Reed", "program_range": range(64, 72), "family_number": 8},
|
13 |
+
{"name": "Pipe", "program_range": range(72, 80), "family_number": 9},
|
14 |
+
{"name": "Synth Lead", "program_range": range(80, 88), "family_number": 10},
|
15 |
+
{"name": "Synth Pad", "program_range": range(88, 96), "family_number": 11},
|
16 |
+
{"name": "Synth Effects", "program_range": range(96, 104), "family_number": 12},
|
17 |
+
{"name": "Ethnic", "program_range": range(104, 112), "family_number": 13},
|
18 |
+
{"name": "Percussive", "program_range": range(112, 120), "family_number": 14},
|
19 |
+
{"name": "Sound Effects", "program_range": range(120, 128), "family_number": 15,},
|
20 |
+
]
|
21 |
+
# fmt: on
|
22 |
+
|
23 |
+
# Instrument mapping for decodiing our midi sequence into midi instruments of our choice
|
24 |
+
INSTRUMENT_TRANSFER_CLASSES = [
|
25 |
+
{
|
26 |
+
"name": "Piano",
|
27 |
+
"program_range": [4], # Electric Piano 1
|
28 |
+
"family_number": 0,
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"name": "Chromatic Percussion",
|
32 |
+
"program_range": [11], # Vibraphone
|
33 |
+
"family_number": 1,
|
34 |
+
},
|
35 |
+
{"name": "Organ", "program_range": [17], "family_number": 2}, # Percussive Organ
|
36 |
+
{
|
37 |
+
"name": "Guitar",
|
38 |
+
"program_range": [80], # Synth Lead Square
|
39 |
+
"family_number": 3,
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"name": "Bass",
|
43 |
+
"program_range": [38], # Synth bass 1,
|
44 |
+
"family_number": 4,
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"name": "Strings",
|
48 |
+
"program_range": [50], # Synth Strings 1
|
49 |
+
"family_number": 5,
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"name": "Ensemble",
|
53 |
+
"program_range": [51], # Synth Strings 2
|
54 |
+
"family_number": 6,
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"name": "Brass",
|
58 |
+
"program_range": [63], # 63 Synth Brass 1,
|
59 |
+
"family_number": 7,
|
60 |
+
},
|
61 |
+
{"name": "Reed", "program_range": [64], "family_number": 8}, # Synth Brass 2
|
62 |
+
{"name": "Pipe", "program_range": [82], "family_number": 9}, # Lead 3
|
63 |
+
{
|
64 |
+
"name": "Synth Lead",
|
65 |
+
"program_range": [81], # Synth Lead Sawtooth
|
66 |
+
"family_number": 10,
|
67 |
+
},
|
68 |
+
{"name": "Synth Pad", "program_range": range(88, 96), "family_number": 11},
|
69 |
+
{"name": "Synth Effects", "program_range": range(96, 104), "family_number": 12},
|
70 |
+
{"name": "Ethnic", "program_range": range(104, 112), "family_number": 13},
|
71 |
+
{"name": "Percussive", "program_range": range(112, 120), "family_number": 14},
|
72 |
+
{
|
73 |
+
"name": "Sound Effects",
|
74 |
+
"program_range": range(120, 128),
|
75 |
+
"family_number": 15,
|
76 |
+
},
|
77 |
+
]
|
decoder.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import *
|
2 |
+
from familizer import Familizer
|
3 |
+
from miditok import Event
|
4 |
+
|
5 |
+
|
6 |
+
class TextDecoder:
|
7 |
+
"""Decodes text into:
|
8 |
+
1- List of events
|
9 |
+
2- Then converts these events to midi file via MidiTok and miditoolkit
|
10 |
+
|
11 |
+
:param tokenizer: from MidiTok
|
12 |
+
|
13 |
+
Usage with write_to_midi method:
|
14 |
+
args: text(String) example -> PIECE_START TRACK_START INST=25 DENSITY=2 BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50...BAR_END TRACK_END
|
15 |
+
returns: midi file from miditoolkit
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, tokenizer, familized=True):
|
19 |
+
self.tokenizer = tokenizer
|
20 |
+
self.familized = familized
|
21 |
+
|
22 |
+
def decode(self, text):
|
23 |
+
r"""converts from text to instrument events
|
24 |
+
Args:
|
25 |
+
text (String): example -> PIECE_START TRACK_START INST=25 DENSITY=2 BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50...BAR_END TRACK_END
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
Dict{inst_id: List[Events]}: List of events of Notes with velocities, aggregated Timeshifts, for each instrument
|
29 |
+
"""
|
30 |
+
piece_events = self.text_to_events(text)
|
31 |
+
inst_events = self.piece_to_inst_events(piece_events)
|
32 |
+
events = self.add_timeshifts_for_empty_bars(inst_events)
|
33 |
+
events = self.aggregate_timeshifts(events)
|
34 |
+
events = self.add_velocity(events)
|
35 |
+
return events
|
36 |
+
|
37 |
+
def tokenize(self, events):
|
38 |
+
r"""converts from events to MidiTok tokens
|
39 |
+
Args:
|
40 |
+
events (Dict{inst_id: List[Events]}): List of events for each instrument
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
List[List[Events]]: List of tokens for each instrument
|
44 |
+
"""
|
45 |
+
tokens = []
|
46 |
+
for inst in events.keys():
|
47 |
+
tokens.append(self.tokenizer.events_to_tokens(events[inst]))
|
48 |
+
return tokens
|
49 |
+
|
50 |
+
def get_midi(self, text, filename=None):
|
51 |
+
r"""converts from text to midi
|
52 |
+
Args:
|
53 |
+
text (String): example -> PIECE_START TRACK_START INST=25 DENSITY=2 BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50...BAR_END TRACK_END
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
miditoolkit midi: Returns and writes to midi
|
57 |
+
"""
|
58 |
+
events = self.decode(text)
|
59 |
+
tokens = self.tokenize(events)
|
60 |
+
instruments = self.get_instruments_tuple(events)
|
61 |
+
midi = self.tokenizer.tokens_to_midi(tokens, instruments)
|
62 |
+
|
63 |
+
if filename is not None:
|
64 |
+
midi.dump(f"{filename}")
|
65 |
+
print(f"midi file written: {filename}")
|
66 |
+
|
67 |
+
return midi
|
68 |
+
|
69 |
+
@staticmethod
|
70 |
+
def text_to_events(text):
|
71 |
+
events = []
|
72 |
+
for word in text.split(" "):
|
73 |
+
# TODO: Handle bar and track values with a counter
|
74 |
+
_event = word.split("=")
|
75 |
+
value = _event[1] if len(_event) > 1 else None
|
76 |
+
event = get_event(_event[0], value)
|
77 |
+
if event:
|
78 |
+
events.append(event)
|
79 |
+
return events
|
80 |
+
|
81 |
+
@staticmethod
|
82 |
+
def piece_to_inst_events(piece_events):
|
83 |
+
"""Converts piece events of 8 bars to instrument events for entire song
|
84 |
+
|
85 |
+
Args:
|
86 |
+
piece_events (List[Events]): List of events of Notes, Timeshifts, Bars, Tracks
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
Dict{inst_id: List[Events]}: List of events for each instrument
|
90 |
+
|
91 |
+
"""
|
92 |
+
inst_events = {}
|
93 |
+
current_instrument = -1
|
94 |
+
for event in piece_events:
|
95 |
+
if event.type == "Instrument":
|
96 |
+
current_instrument = event.value
|
97 |
+
if current_instrument not in inst_events:
|
98 |
+
inst_events[current_instrument] = []
|
99 |
+
elif current_instrument != -1:
|
100 |
+
inst_events[current_instrument].append(event)
|
101 |
+
return inst_events
|
102 |
+
|
103 |
+
@staticmethod
|
104 |
+
def add_timeshifts_for_empty_bars(inst_events):
|
105 |
+
"""Adds time shift events instead of consecutive [BAR_START BAR_END] events"""
|
106 |
+
new_inst_events = {}
|
107 |
+
for inst, events in inst_events.items():
|
108 |
+
new_inst_events[inst] = []
|
109 |
+
for index, event in enumerate(events):
|
110 |
+
if event.type == "Bar-End" or event.type == "Bar-Start":
|
111 |
+
if events[index - 1].type == "Bar-Start":
|
112 |
+
new_inst_events[inst].append(Event("Time-Shift", "4.0.8"))
|
113 |
+
else:
|
114 |
+
new_inst_events[inst].append(event)
|
115 |
+
return new_inst_events
|
116 |
+
|
117 |
+
@staticmethod
|
118 |
+
def add_timeshifts(beat_values1, beat_values2):
|
119 |
+
"""Adds two beat values
|
120 |
+
|
121 |
+
Args:
|
122 |
+
beat_values1 (String): like 0.3.8
|
123 |
+
beat_values2 (String): like 1.7.8
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
beat_str (String): added beats like 2.2.8 for example values
|
127 |
+
"""
|
128 |
+
value1 = to_base10(beat_values1)
|
129 |
+
value2 = to_base10(beat_values2)
|
130 |
+
return to_beat_str(value1 + value2)
|
131 |
+
|
132 |
+
def aggregate_timeshifts(self, events):
|
133 |
+
"""Aggregates consecutive time shift events bigger than a bar
|
134 |
+
-> like Timeshift 4.0.8
|
135 |
+
|
136 |
+
Args:
|
137 |
+
events (_type_): _description_
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
_type_: _description_
|
141 |
+
"""
|
142 |
+
new_events = {}
|
143 |
+
for inst, events in events.items():
|
144 |
+
inst_events = []
|
145 |
+
for i, event in enumerate(events):
|
146 |
+
if (
|
147 |
+
event.type == "Time-Shift"
|
148 |
+
and len(inst_events) > 0
|
149 |
+
and inst_events[-1].type == "Time-Shift"
|
150 |
+
):
|
151 |
+
inst_events[-1].value = self.add_timeshifts(
|
152 |
+
inst_events[-1].value, event.value
|
153 |
+
)
|
154 |
+
else:
|
155 |
+
inst_events.append(event)
|
156 |
+
new_events[inst] = inst_events
|
157 |
+
return new_events
|
158 |
+
|
159 |
+
@staticmethod
|
160 |
+
def add_velocity(events):
|
161 |
+
"""Adds default velocity 99 to note events since they are removed from text, needed to generate midi"""
|
162 |
+
new_events = {}
|
163 |
+
for inst, events in events.items():
|
164 |
+
inst_events = []
|
165 |
+
for event in events:
|
166 |
+
inst_events.append(event)
|
167 |
+
if event.type == "Note-On":
|
168 |
+
inst_events.append(Event("Velocity", 99))
|
169 |
+
new_events[inst] = inst_events
|
170 |
+
return new_events
|
171 |
+
|
172 |
+
def get_instruments_tuple(self, events):
|
173 |
+
"""Returns instruments tuple for midi generation"""
|
174 |
+
instruments = []
|
175 |
+
for inst in events.keys():
|
176 |
+
is_drum = 0
|
177 |
+
if inst == "DRUMS":
|
178 |
+
inst = 0
|
179 |
+
is_drum = 1
|
180 |
+
if self.familized:
|
181 |
+
inst = Familizer(arbitrary=True).get_program_number(int(inst)) + 1
|
182 |
+
instruments.append((int(inst), is_drum))
|
183 |
+
return tuple(instruments)
|
184 |
+
|
185 |
+
|
186 |
+
if __name__ == "__main__":
|
187 |
+
|
188 |
+
filename = "midi/generated/misnaej/the-jam-machine-elec-famil/20221209_175750"
|
189 |
+
encoded_json = readFromFile(
|
190 |
+
f"{filename}.json",
|
191 |
+
True,
|
192 |
+
)
|
193 |
+
encoded_text = encoded_json["sequence"]
|
194 |
+
# encoded_text = "PIECE_START TRACK_START INST=25 DENSITY=2 BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=67 NOTE_ON=64 TIME_DELTA=1 NOTE_OFF=67 NOTE_OFF=64 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=67 NOTE_ON=64 TIME_DELTA=1 NOTE_OFF=67 NOTE_OFF=64 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=67 NOTE_ON=64 TIME_DELTA=1 NOTE_OFF=67 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=69 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=69 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=57 TIME_DELTA=1 NOTE_OFF=57 NOTE_ON=56 TIME_DELTA=1 NOTE_OFF=56 NOTE_ON=64 NOTE_ON=60 NOTE_ON=55 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=55 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=67 NOTE_ON=64 TIME_DELTA=1 NOTE_OFF=67 NOTE_OFF=64 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=59 NOTE_ON=55 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=59 NOTE_OFF=50 NOTE_OFF=55 NOTE_OFF=50 BAR_END BAR_START BAR_END TRACK_END"
|
195 |
+
|
196 |
+
miditok = get_miditok()
|
197 |
+
TextDecoder(miditok).get_midi(encoded_text, filename=filename)
|
generate.py
ADDED
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from generation_utils import *
|
2 |
+
from utils import WriteTextMidiToFile, get_miditok
|
3 |
+
from load import LoadModel
|
4 |
+
from constants import INSTRUMENT_CLASSES
|
5 |
+
|
6 |
+
## import for execution
|
7 |
+
from decoder import TextDecoder
|
8 |
+
from playback import get_music, show_piano_roll
|
9 |
+
|
10 |
+
|
11 |
+
class GenerateMidiText:
|
12 |
+
"""Generating music with Class
|
13 |
+
|
14 |
+
LOGIC:
|
15 |
+
|
16 |
+
FOR GENERATING FROM SCRATCH:
|
17 |
+
- self.generate_one_new_track()
|
18 |
+
it calls
|
19 |
+
- self.generate_until_track_end()
|
20 |
+
|
21 |
+
FOR GENERATING NEW BARS:
|
22 |
+
- self.generate_one_more_bar()
|
23 |
+
it calls
|
24 |
+
- self.process_prompt_for_next_bar()
|
25 |
+
- self.generate_until_track_end()"""
|
26 |
+
|
27 |
+
def __init__(self, model, tokenizer):
|
28 |
+
self.model = model
|
29 |
+
self.tokenizer = tokenizer
|
30 |
+
# default initialization
|
31 |
+
self.initialize_default_parameters()
|
32 |
+
self.initialize_dictionaries()
|
33 |
+
|
34 |
+
"""Setters"""
|
35 |
+
|
36 |
+
def initialize_default_parameters(self):
|
37 |
+
self.set_device()
|
38 |
+
self.set_attention_length()
|
39 |
+
self.generate_until = "TRACK_END"
|
40 |
+
self.set_force_sequence_lenth()
|
41 |
+
self.set_nb_bars_generated()
|
42 |
+
self.set_improvisation_level(0)
|
43 |
+
|
44 |
+
def initialize_dictionaries(self):
|
45 |
+
self.piece_by_track = []
|
46 |
+
|
47 |
+
def set_device(self, device="cpu"):
|
48 |
+
self.device = ("cpu",)
|
49 |
+
|
50 |
+
def set_attention_length(self):
|
51 |
+
self.max_length = self.model.config.n_positions
|
52 |
+
print(
|
53 |
+
f"Attention length set to {self.max_length} -> 'model.config.n_positions'"
|
54 |
+
)
|
55 |
+
|
56 |
+
def set_force_sequence_lenth(self, force_sequence_length=True):
|
57 |
+
self.force_sequence_length = force_sequence_length
|
58 |
+
|
59 |
+
def set_improvisation_level(self, improvisation_value):
|
60 |
+
self.no_repeat_ngram_size = improvisation_value
|
61 |
+
print("--------------------")
|
62 |
+
print(f"no_repeat_ngram_size set to {improvisation_value}")
|
63 |
+
print("--------------------")
|
64 |
+
|
65 |
+
def reset_temperatures(self, track_id, temperature):
|
66 |
+
self.piece_by_track[track_id]["temperature"] = temperature
|
67 |
+
|
68 |
+
def set_nb_bars_generated(self, n_bars=8): # default is a 8 bar model
|
69 |
+
self.model_n_bar = n_bars
|
70 |
+
|
71 |
+
""" Generation Tools - Dictionnaries """
|
72 |
+
|
73 |
+
def initiate_track_dict(self, instr, density, temperature):
|
74 |
+
label = len(self.piece_by_track)
|
75 |
+
self.piece_by_track.append(
|
76 |
+
{
|
77 |
+
"label": f"track_{label}",
|
78 |
+
"instrument": instr,
|
79 |
+
"density": density,
|
80 |
+
"temperature": temperature,
|
81 |
+
"bars": [],
|
82 |
+
}
|
83 |
+
)
|
84 |
+
|
85 |
+
def update_track_dict__add_bars(self, bars, track_id):
|
86 |
+
"""Add bars to the track dictionnary"""
|
87 |
+
for bar in self.striping_track_ends(bars).split("BAR_START "):
|
88 |
+
if bar == "": # happens is there is one bar only
|
89 |
+
continue
|
90 |
+
else:
|
91 |
+
if "TRACK_START" in bar:
|
92 |
+
self.piece_by_track[track_id]["bars"].append(bar)
|
93 |
+
else:
|
94 |
+
self.piece_by_track[track_id]["bars"].append("BAR_START " + bar)
|
95 |
+
|
96 |
+
def get_all_instr_bars(self, track_id):
|
97 |
+
return self.piece_by_track[track_id]["bars"]
|
98 |
+
|
99 |
+
def striping_track_ends(self, text):
|
100 |
+
if "TRACK_END" in text:
|
101 |
+
# first get rid of extra space if any
|
102 |
+
# then gets rid of "TRACK_END"
|
103 |
+
text = text.rstrip(" ").rstrip("TRACK_END")
|
104 |
+
return text
|
105 |
+
|
106 |
+
def get_last_generated_track(self, full_piece):
|
107 |
+
track = (
|
108 |
+
"TRACK_START "
|
109 |
+
+ self.striping_track_ends(full_piece.split("TRACK_START ")[-1])
|
110 |
+
+ "TRACK_END "
|
111 |
+
) # forcing the space after track and
|
112 |
+
return track
|
113 |
+
|
114 |
+
def get_selected_track_as_text(self, track_id):
|
115 |
+
text = ""
|
116 |
+
for bar in self.piece_by_track[track_id]["bars"]:
|
117 |
+
text += bar
|
118 |
+
text += "TRACK_END "
|
119 |
+
return text
|
120 |
+
|
121 |
+
@staticmethod
|
122 |
+
def get_newly_generated_text(input_prompt, full_piece):
|
123 |
+
return full_piece[len(input_prompt) :]
|
124 |
+
|
125 |
+
def get_whole_piece_from_bar_dict(self):
|
126 |
+
text = "PIECE_START "
|
127 |
+
for track_id, _ in enumerate(self.piece_by_track):
|
128 |
+
text += self.get_selected_track_as_text(track_id)
|
129 |
+
return text
|
130 |
+
|
131 |
+
def delete_one_track(self, track): # TO BE TESTED
|
132 |
+
self.piece_by_track.pop(track)
|
133 |
+
|
134 |
+
# def update_piece_dict__add_track(self, track_id, track):
|
135 |
+
# self.piece_dict[track_id] = track
|
136 |
+
|
137 |
+
# def update_all_dictionnaries__add_track(self, track):
|
138 |
+
# self.update_piece_dict__add_track(track_id, track)
|
139 |
+
|
140 |
+
"""Basic generation tools"""
|
141 |
+
|
142 |
+
def tokenize_input_prompt(self, input_prompt, verbose=True):
|
143 |
+
"""Tokenizing prompt
|
144 |
+
|
145 |
+
Args:
|
146 |
+
- input_prompt (str): prompt to tokenize
|
147 |
+
|
148 |
+
Returns:
|
149 |
+
- input_prompt_ids (torch.tensor): tokenized prompt
|
150 |
+
"""
|
151 |
+
if verbose:
|
152 |
+
print("Tokenizing input_prompt...")
|
153 |
+
|
154 |
+
return self.tokenizer.encode(input_prompt, return_tensors="pt")
|
155 |
+
|
156 |
+
def generate_sequence_of_token_ids(
|
157 |
+
self,
|
158 |
+
input_prompt_ids,
|
159 |
+
temperature,
|
160 |
+
verbose=True,
|
161 |
+
):
|
162 |
+
"""
|
163 |
+
generate a sequence of token ids based on input_prompt_ids
|
164 |
+
The sequence length depends on the trained model (self.model_n_bar)
|
165 |
+
"""
|
166 |
+
generated_ids = self.model.generate(
|
167 |
+
input_prompt_ids,
|
168 |
+
max_length=self.max_length,
|
169 |
+
do_sample=True,
|
170 |
+
temperature=temperature,
|
171 |
+
no_repeat_ngram_size=self.no_repeat_ngram_size, # default = 0
|
172 |
+
eos_token_id=self.tokenizer.encode(self.generate_until)[0], # good
|
173 |
+
)
|
174 |
+
|
175 |
+
if verbose:
|
176 |
+
print("Generating a token_id sequence...")
|
177 |
+
|
178 |
+
return generated_ids
|
179 |
+
|
180 |
+
def convert_ids_to_text(self, generated_ids, verbose=True):
|
181 |
+
"""converts the token_ids to text"""
|
182 |
+
generated_text = self.tokenizer.decode(generated_ids[0])
|
183 |
+
if verbose:
|
184 |
+
print("Converting token sequence to MidiText...")
|
185 |
+
return generated_text
|
186 |
+
|
187 |
+
def generate_until_track_end(
|
188 |
+
self,
|
189 |
+
input_prompt="PIECE_START ",
|
190 |
+
instrument=None,
|
191 |
+
density=None,
|
192 |
+
temperature=None,
|
193 |
+
verbose=True,
|
194 |
+
expected_length=None,
|
195 |
+
):
|
196 |
+
|
197 |
+
"""generate until the TRACK_END token is reached
|
198 |
+
full_piece = input_prompt + generated"""
|
199 |
+
if expected_length is None:
|
200 |
+
expected_length = self.model_n_bar
|
201 |
+
|
202 |
+
if instrument is not None:
|
203 |
+
input_prompt = f"{input_prompt}TRACK_START INST={str(instrument)} "
|
204 |
+
if density is not None:
|
205 |
+
input_prompt = f"{input_prompt}DENSITY={str(density)} "
|
206 |
+
|
207 |
+
if instrument is None and density is not None:
|
208 |
+
print("Density cannot be defined without an input_prompt instrument #TOFIX")
|
209 |
+
|
210 |
+
if temperature is None:
|
211 |
+
ValueError("Temperature must be defined")
|
212 |
+
|
213 |
+
if verbose:
|
214 |
+
print("--------------------")
|
215 |
+
print(
|
216 |
+
f"Generating {instrument} - Density {density} - temperature {temperature}"
|
217 |
+
)
|
218 |
+
bar_count_checks = False
|
219 |
+
failed = 0
|
220 |
+
while not bar_count_checks: # regenerate until right length
|
221 |
+
input_prompt_ids = self.tokenize_input_prompt(input_prompt, verbose=verbose)
|
222 |
+
generated_tokens = self.generate_sequence_of_token_ids(
|
223 |
+
input_prompt_ids, temperature, verbose=verbose
|
224 |
+
)
|
225 |
+
full_piece = self.convert_ids_to_text(generated_tokens, verbose=verbose)
|
226 |
+
generated = self.get_newly_generated_text(input_prompt, full_piece)
|
227 |
+
# bar_count_checks
|
228 |
+
bar_count_checks, bar_count = bar_count_check(generated, expected_length)
|
229 |
+
|
230 |
+
if not self.force_sequence_length:
|
231 |
+
# set bar_count_checks to true to exist the while loop
|
232 |
+
bar_count_checks = True
|
233 |
+
|
234 |
+
if not bar_count_checks and self.force_sequence_length:
|
235 |
+
# if the generated sequence is not the expected length
|
236 |
+
if failed > 1:
|
237 |
+
full_piece, bar_count_checks = forcing_bar_count(
|
238 |
+
input_prompt,
|
239 |
+
generated,
|
240 |
+
bar_count,
|
241 |
+
expected_length,
|
242 |
+
)
|
243 |
+
else:
|
244 |
+
print('"--- Wrong length - Regenerating ---')
|
245 |
+
if not bar_count_checks:
|
246 |
+
failed += 1
|
247 |
+
if failed > 2:
|
248 |
+
bar_count_checks = True # TOFIX exit the while loop
|
249 |
+
|
250 |
+
return full_piece
|
251 |
+
|
252 |
+
def generate_one_new_track(
|
253 |
+
self,
|
254 |
+
instrument,
|
255 |
+
density,
|
256 |
+
temperature,
|
257 |
+
input_prompt="PIECE_START ",
|
258 |
+
):
|
259 |
+
self.initiate_track_dict(instrument, density, temperature)
|
260 |
+
full_piece = self.generate_until_track_end(
|
261 |
+
input_prompt=input_prompt,
|
262 |
+
instrument=instrument,
|
263 |
+
density=density,
|
264 |
+
temperature=temperature,
|
265 |
+
)
|
266 |
+
|
267 |
+
track = self.get_last_generated_track(full_piece)
|
268 |
+
self.update_track_dict__add_bars(track, -1)
|
269 |
+
full_piece = self.get_whole_piece_from_bar_dict()
|
270 |
+
return full_piece
|
271 |
+
|
272 |
+
""" Piece generation - Basics """
|
273 |
+
|
274 |
+
def generate_piece(self, instrument_list, density_list, temperature_list):
|
275 |
+
"""generate a sequence with mutiple tracks
|
276 |
+
- inst_list sets the list of instruments of the order of generation
|
277 |
+
- density is paired with inst_list
|
278 |
+
Each track/intrument is generated on a prompt which contains the previously generated track/instrument
|
279 |
+
This means that the first instrument is generated with less bias than the next one, and so on.
|
280 |
+
|
281 |
+
'generated_piece' keeps track of the entire piece
|
282 |
+
'generated_piece' is returned by self.generate_until_track_end
|
283 |
+
# it is returned by self.generate_until_track_end"""
|
284 |
+
|
285 |
+
generated_piece = "PIECE_START "
|
286 |
+
for instrument, density, temperature in zip(
|
287 |
+
instrument_list, density_list, temperature_list
|
288 |
+
):
|
289 |
+
generated_piece = self.generate_one_new_track(
|
290 |
+
instrument,
|
291 |
+
density,
|
292 |
+
temperature,
|
293 |
+
input_prompt=generated_piece,
|
294 |
+
)
|
295 |
+
|
296 |
+
# generated_piece = self.get_whole_piece_from_bar_dict()
|
297 |
+
self.check_the_piece_for_errors()
|
298 |
+
return generated_piece
|
299 |
+
|
300 |
+
""" Piece generation - Extra Bars """
|
301 |
+
|
302 |
+
@staticmethod
|
303 |
+
def process_prompt_for_next_bar(self, track_idx):
|
304 |
+
"""Processing the prompt for the model to generate one more bar only.
|
305 |
+
The prompt containts:
|
306 |
+
if not the first bar: the previous, already processed, bars of the track
|
307 |
+
the bar initialization (ex: "TRACK_START INST=DRUMS DENSITY=2 ")
|
308 |
+
the last (self.model_n_bar)-1 bars of the track
|
309 |
+
Args:
|
310 |
+
track_idx (int): the index of the track to be processed
|
311 |
+
|
312 |
+
Returns:
|
313 |
+
the processed prompt for generating the next bar
|
314 |
+
"""
|
315 |
+
track = self.piece_by_track[track_idx]
|
316 |
+
# for bars which are not the bar to prolong
|
317 |
+
pre_promt = "PIECE_START "
|
318 |
+
for i, othertrack in enumerate(self.piece_by_track):
|
319 |
+
if i != track_idx:
|
320 |
+
len_diff = len(othertrack["bars"]) - len(track["bars"])
|
321 |
+
if len_diff > 0:
|
322 |
+
# if other bars are longer, it mean that this one should catch up
|
323 |
+
pre_promt += othertrack["bars"][0]
|
324 |
+
for bar in track["bars"][-self.model_n_bar :]:
|
325 |
+
pre_promt += bar
|
326 |
+
pre_promt += "TRACK_END "
|
327 |
+
elif False: # len_diff <= 0: # THIS GENERATES EMPTINESS
|
328 |
+
# adding an empty bars at the end of the other tracks if they have not been processed yet
|
329 |
+
pre_promt += othertracks["bars"][0]
|
330 |
+
for bar in track["bars"][-(self.model_n_bar - 1) :]:
|
331 |
+
pre_promt += bar
|
332 |
+
for _ in range(abs(len_diff) + 1):
|
333 |
+
pre_promt += "BAR_START BAR_END "
|
334 |
+
pre_promt += "TRACK_END "
|
335 |
+
|
336 |
+
# for the bar to prolong
|
337 |
+
# initialization e.g TRACK_START INST=DRUMS DENSITY=2
|
338 |
+
processed_prompt = track["bars"][0]
|
339 |
+
for bar in track["bars"][-(self.model_n_bar - 1) :]:
|
340 |
+
# adding the "last" bars of the track
|
341 |
+
processed_prompt += bar
|
342 |
+
|
343 |
+
processed_prompt += "BAR_START "
|
344 |
+
print(
|
345 |
+
f"--- prompt length = {len((pre_promt + processed_prompt).split(' '))} ---"
|
346 |
+
)
|
347 |
+
return pre_promt + processed_prompt
|
348 |
+
|
349 |
+
def generate_one_more_bar(self, i):
|
350 |
+
"""Generate one more bar from the input_prompt"""
|
351 |
+
processed_prompt = self.process_prompt_for_next_bar(self, i)
|
352 |
+
prompt_plus_bar = self.generate_until_track_end(
|
353 |
+
input_prompt=processed_prompt,
|
354 |
+
temperature=self.piece_by_track[i]["temperature"],
|
355 |
+
expected_length=1,
|
356 |
+
verbose=False,
|
357 |
+
)
|
358 |
+
added_bar = self.get_newly_generated_bar(prompt_plus_bar)
|
359 |
+
self.update_track_dict__add_bars(added_bar, i)
|
360 |
+
|
361 |
+
def get_newly_generated_bar(self, prompt_plus_bar):
|
362 |
+
return "BAR_START " + self.striping_track_ends(
|
363 |
+
prompt_plus_bar.split("BAR_START ")[-1]
|
364 |
+
)
|
365 |
+
|
366 |
+
def generate_n_more_bars(self, n_bars, only_this_track=None, verbose=True):
|
367 |
+
"""Generate n more bars from the input_prompt"""
|
368 |
+
if only_this_track is None:
|
369 |
+
only_this_track
|
370 |
+
|
371 |
+
print(f"================== ")
|
372 |
+
print(f"Adding {n_bars} more bars to the piece ")
|
373 |
+
for bar_id in range(n_bars):
|
374 |
+
print(f"----- added bar #{bar_id+1} --")
|
375 |
+
for i, track in enumerate(self.piece_by_track):
|
376 |
+
if only_this_track is None or i == only_this_track:
|
377 |
+
print(f"--------- {track['label']}")
|
378 |
+
self.generate_one_more_bar(i)
|
379 |
+
self.check_the_piece_for_errors()
|
380 |
+
|
381 |
+
def check_the_piece_for_errors(self, piece: str = None):
|
382 |
+
|
383 |
+
if piece is None:
|
384 |
+
piece = generate_midi.get_whole_piece_from_bar_dict()
|
385 |
+
errors = []
|
386 |
+
errors.append(
|
387 |
+
[
|
388 |
+
(token, id)
|
389 |
+
for id, token in enumerate(piece.split(" "))
|
390 |
+
if token not in self.tokenizer.vocab or token == "UNK"
|
391 |
+
]
|
392 |
+
)
|
393 |
+
if len(errors) > 0:
|
394 |
+
# print(piece)
|
395 |
+
for er in errors:
|
396 |
+
er
|
397 |
+
print(f"Token not found in the piece at {er[0][1]}: {er[0][0]}")
|
398 |
+
print(piece.split(" ")[er[0][1] - 5 : er[0][1] + 5])
|
399 |
+
|
400 |
+
|
401 |
+
if __name__ == "__main__":
|
402 |
+
|
403 |
+
# worker
|
404 |
+
DEVICE = "cpu"
|
405 |
+
|
406 |
+
# define generation parameters
|
407 |
+
N_FILES_TO_GENERATE = 2
|
408 |
+
Temperatures_to_try = [0.7]
|
409 |
+
|
410 |
+
USE_FAMILIZED_MODEL = True
|
411 |
+
force_sequence_length = True
|
412 |
+
|
413 |
+
if USE_FAMILIZED_MODEL:
|
414 |
+
# model_repo = "misnaej/the-jam-machine-elec-famil"
|
415 |
+
# model_repo = "misnaej/the-jam-machine-elec-famil-ft32"
|
416 |
+
|
417 |
+
# model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
|
418 |
+
# n_bar_generated = 8
|
419 |
+
|
420 |
+
model_repo = "JammyMachina/improved_4bars-mdl"
|
421 |
+
n_bar_generated = 4
|
422 |
+
instrument_promt_list = ["4", "DRUMS", "3"]
|
423 |
+
# DRUMS = drums, 0 = piano, 1 = chromatic percussion, 2 = organ, 3 = guitar, 4 = bass, 5 = strings, 6 = ensemble, 7 = brass, 8 = reed, 9 = pipe, 10 = synth lead, 11 = synth pad, 12 = synth effects, 13 = ethnic, 14 = percussive, 15 = sound effects
|
424 |
+
density_list = [3, 2, 2]
|
425 |
+
# temperature_list = [0.7, 0.7, 0.75]
|
426 |
+
else:
|
427 |
+
model_repo = "misnaej/the-jam-machine"
|
428 |
+
instrument_promt_list = ["30"] # , "DRUMS", "0"]
|
429 |
+
density_list = [3] # , 2, 3]
|
430 |
+
# temperature_list = [0.7, 0.5, 0.75]
|
431 |
+
pass
|
432 |
+
|
433 |
+
# define generation directory
|
434 |
+
generated_sequence_files_path = define_generation_dir(model_repo)
|
435 |
+
|
436 |
+
# load model and tokenizer
|
437 |
+
model, tokenizer = LoadModel(
|
438 |
+
model_repo, from_huggingface=True
|
439 |
+
).load_model_and_tokenizer()
|
440 |
+
|
441 |
+
# does the prompt make sense
|
442 |
+
check_if_prompt_inst_in_tokenizer_vocab(tokenizer, instrument_promt_list)
|
443 |
+
|
444 |
+
for temperature in Temperatures_to_try:
|
445 |
+
print(f"================= TEMPERATURE {temperature} =======================")
|
446 |
+
for _ in range(N_FILES_TO_GENERATE):
|
447 |
+
print(f"========================================")
|
448 |
+
# 1 - instantiate
|
449 |
+
generate_midi = GenerateMidiText(model, tokenizer)
|
450 |
+
# 0 - set the n_bar for this model
|
451 |
+
generate_midi.set_nb_bars_generated(n_bars=n_bar_generated)
|
452 |
+
# 1 - defines the instruments, densities and temperatures
|
453 |
+
# 2- generate the first 8 bars for each instrument
|
454 |
+
generate_midi.set_improvisation_level(30)
|
455 |
+
generate_midi.generate_piece(
|
456 |
+
instrument_promt_list,
|
457 |
+
density_list,
|
458 |
+
[temperature for _ in density_list],
|
459 |
+
)
|
460 |
+
# 3 - force the model to improvise
|
461 |
+
# generate_midi.set_improvisation_level(20)
|
462 |
+
# 4 - generate the next 4 bars for each instrument
|
463 |
+
# generate_midi.generate_n_more_bars(n_bar_generated)
|
464 |
+
# 5 - lower the improvisation level
|
465 |
+
generate_midi.generated_piece = (
|
466 |
+
generate_midi.get_whole_piece_from_bar_dict()
|
467 |
+
)
|
468 |
+
|
469 |
+
# print the generated sequence in terminal
|
470 |
+
print("=========================================")
|
471 |
+
print(generate_midi.generated_piece)
|
472 |
+
print("=========================================")
|
473 |
+
|
474 |
+
# write to JSON file
|
475 |
+
filename = WriteTextMidiToFile(
|
476 |
+
generate_midi,
|
477 |
+
generated_sequence_files_path,
|
478 |
+
).text_midi_to_file()
|
479 |
+
|
480 |
+
# decode the sequence to MIDI """
|
481 |
+
decode_tokenizer = get_miditok()
|
482 |
+
TextDecoder(decode_tokenizer, USE_FAMILIZED_MODEL).get_midi(
|
483 |
+
generate_midi.generated_piece, filename=filename.split(".")[0] + ".mid"
|
484 |
+
)
|
485 |
+
inst_midi, mixed_audio = get_music(filename.split(".")[0] + ".mid")
|
486 |
+
max_time = get_max_time(inst_midi)
|
487 |
+
plot_piano_roll(inst_midi)
|
488 |
+
|
489 |
+
print("Et voilà! Your MIDI file is ready! GO JAM!")
|
load.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import GPT2LMHeadModel
|
2 |
+
from transformers import PreTrainedTokenizerFast
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class LoadModel:
|
7 |
+
"""
|
8 |
+
Example usage:
|
9 |
+
|
10 |
+
# if loading model and tokenizer from Huggingface
|
11 |
+
model_repo = "misnaej/the-jam-machine"
|
12 |
+
model, tokenizer = LoadModel(
|
13 |
+
model_repo, from_huggingface=True
|
14 |
+
).load_model_and_tokenizer()
|
15 |
+
|
16 |
+
# if loading model and tokenizer from a local folder
|
17 |
+
model_path = "models/model_2048_wholedataset"
|
18 |
+
model, tokenizer = LoadModel(
|
19 |
+
model_path, from_huggingface=False
|
20 |
+
).load_model_and_tokenizer()
|
21 |
+
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, path, from_huggingface=True, device="cpu", revision=None):
|
25 |
+
# path is either a relative path on a local/remote machine or a model repo on HuggingFace
|
26 |
+
if not from_huggingface:
|
27 |
+
if not os.path.exists(path):
|
28 |
+
print(path)
|
29 |
+
raise Exception("Model path does not exist")
|
30 |
+
self.from_huggingface = from_huggingface
|
31 |
+
self.path = path
|
32 |
+
self.device = device
|
33 |
+
self.revision = revision
|
34 |
+
|
35 |
+
def load_model_and_tokenizer(self):
|
36 |
+
model = self.load_model()
|
37 |
+
tokenizer = self.load_tokenizer()
|
38 |
+
|
39 |
+
return model, tokenizer
|
40 |
+
|
41 |
+
def load_model(self):
|
42 |
+
if self.revision is None:
|
43 |
+
model = GPT2LMHeadModel.from_pretrained(self.path).to(self.device)
|
44 |
+
else:
|
45 |
+
model = GPT2LMHeadModel.from_pretrained(
|
46 |
+
self.path, revision=self.revision
|
47 |
+
).to(self.device)
|
48 |
+
|
49 |
+
return model
|
50 |
+
|
51 |
+
def load_tokenizer(self):
|
52 |
+
if self.from_huggingface:
|
53 |
+
pass
|
54 |
+
else:
|
55 |
+
if not os.path.exists(f"{self.path}/tokenizer.json"):
|
56 |
+
raise Exception(
|
57 |
+
f"There is no 'tokenizer.json'file in the defined {self.path}"
|
58 |
+
)
|
59 |
+
tokenizer = PreTrainedTokenizerFast.from_pretrained(self.path)
|
60 |
+
return tokenizer
|
playback.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import librosa.display
|
3 |
+
from pretty_midi import PrettyMIDI
|
4 |
+
|
5 |
+
|
6 |
+
# Note: these functions are meant to be played within an interactive Python shell
|
7 |
+
# Please refer to the synth.ipynb for an example of how to use them
|
8 |
+
|
9 |
+
|
10 |
+
def get_music(midi_file):
|
11 |
+
"""
|
12 |
+
Load a midi file and return the PrettyMIDI object and the audio signal
|
13 |
+
"""
|
14 |
+
music = PrettyMIDI(midi_file=midi_file)
|
15 |
+
waveform = music.fluidsynth()
|
16 |
+
return music, waveform
|
17 |
+
|
18 |
+
|
19 |
+
def show_piano_roll(music_notes, fs=100):
|
20 |
+
"""
|
21 |
+
Show the piano roll of a music piece, with all instruments squashed onto a single 128xN matrix
|
22 |
+
:param music_notes: PrettyMIDI object
|
23 |
+
:param fs: sampling frequency
|
24 |
+
"""
|
25 |
+
# get the piano roll
|
26 |
+
piano_roll = music_notes.get_piano_roll(fs)
|
27 |
+
print("Piano roll shape: {}".format(piano_roll.shape))
|
28 |
+
|
29 |
+
# plot the piano roll
|
30 |
+
plt.figure(figsize=(12, 4))
|
31 |
+
librosa.display.specshow(piano_roll, sr=100, x_axis="time", y_axis="cqt_note")
|
32 |
+
plt.colorbar()
|
33 |
+
plt.title("Piano roll")
|
34 |
+
plt.tight_layout()
|
35 |
+
plt.show()
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
matplotlib
|
3 |
+
sys
|
4 |
+
matplotlib
|
5 |
+
numpy
|
utils.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
from miditok import Event, MIDILike
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
from time import perf_counter
|
6 |
+
from joblib import Parallel, delayed
|
7 |
+
from zipfile import ZipFile, ZIP_DEFLATED
|
8 |
+
from scipy.io.wavfile import write
|
9 |
+
import numpy as np
|
10 |
+
from pydub import AudioSegment
|
11 |
+
import shutil
|
12 |
+
|
13 |
+
|
14 |
+
def writeToFile(path, content):
|
15 |
+
if type(content) is dict:
|
16 |
+
with open(f"{path}", "w") as json_file:
|
17 |
+
json.dump(content, json_file)
|
18 |
+
else:
|
19 |
+
if type(content) is not str:
|
20 |
+
content = str(content)
|
21 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
22 |
+
with open(path, "w") as f:
|
23 |
+
f.write(content)
|
24 |
+
|
25 |
+
|
26 |
+
# Function to read from text from txt file:
|
27 |
+
def readFromFile(path, isJSON=False):
|
28 |
+
with open(path, "r") as f:
|
29 |
+
if isJSON:
|
30 |
+
return json.load(f)
|
31 |
+
else:
|
32 |
+
return f.read()
|
33 |
+
|
34 |
+
|
35 |
+
def chain(input, funcs, *params):
|
36 |
+
res = input
|
37 |
+
for func in funcs:
|
38 |
+
try:
|
39 |
+
res = func(res, *params)
|
40 |
+
except TypeError:
|
41 |
+
res = func(res)
|
42 |
+
return res
|
43 |
+
|
44 |
+
|
45 |
+
def to_beat_str(value, beat_res=8):
|
46 |
+
values = [
|
47 |
+
int(int(value * beat_res) / beat_res),
|
48 |
+
int(int(value * beat_res) % beat_res),
|
49 |
+
beat_res,
|
50 |
+
]
|
51 |
+
return ".".join(map(str, values))
|
52 |
+
|
53 |
+
|
54 |
+
def to_base10(beat_str):
|
55 |
+
integer, decimal, base = split_dots(beat_str)
|
56 |
+
return integer + decimal / base
|
57 |
+
|
58 |
+
|
59 |
+
def split_dots(value):
|
60 |
+
return list(map(int, value.split(".")))
|
61 |
+
|
62 |
+
|
63 |
+
def compute_list_average(l):
|
64 |
+
return sum(l) / len(l)
|
65 |
+
|
66 |
+
|
67 |
+
def get_datetime():
|
68 |
+
return datetime.now().strftime("%Y%m%d_%H%M%S")
|
69 |
+
|
70 |
+
|
71 |
+
def get_text(event):
|
72 |
+
match event.type:
|
73 |
+
case "Piece-Start":
|
74 |
+
return "PIECE_START "
|
75 |
+
case "Track-Start":
|
76 |
+
return "TRACK_START "
|
77 |
+
case "Track-End":
|
78 |
+
return "TRACK_END "
|
79 |
+
case "Instrument":
|
80 |
+
return f"INST={event.value} "
|
81 |
+
case "Bar-Start":
|
82 |
+
return "BAR_START "
|
83 |
+
case "Bar-End":
|
84 |
+
return "BAR_END "
|
85 |
+
case "Time-Shift":
|
86 |
+
return f"TIME_SHIFT={event.value} "
|
87 |
+
case "Note-On":
|
88 |
+
return f"NOTE_ON={event.value} "
|
89 |
+
case "Note-Off":
|
90 |
+
return f"NOTE_OFF={event.value} "
|
91 |
+
case _:
|
92 |
+
return ""
|
93 |
+
|
94 |
+
|
95 |
+
def get_event(text, value=None):
|
96 |
+
match text:
|
97 |
+
case "PIECE_START":
|
98 |
+
return Event("Piece-Start", value)
|
99 |
+
case "TRACK_START":
|
100 |
+
return None
|
101 |
+
case "TRACK_END":
|
102 |
+
return None
|
103 |
+
case "INST":
|
104 |
+
return Event("Instrument", value)
|
105 |
+
case "BAR_START":
|
106 |
+
return Event("Bar-Start", value)
|
107 |
+
case "BAR_END":
|
108 |
+
return Event("Bar-End", value)
|
109 |
+
case "TIME_SHIFT":
|
110 |
+
return Event("Time-Shift", value)
|
111 |
+
case "TIME_DELTA":
|
112 |
+
return Event("Time-Shift", to_beat_str(int(value) / 4))
|
113 |
+
case "NOTE_ON":
|
114 |
+
return Event("Note-On", value)
|
115 |
+
case "NOTE_OFF":
|
116 |
+
return Event("Note-Off", value)
|
117 |
+
case _:
|
118 |
+
return None
|
119 |
+
|
120 |
+
|
121 |
+
# TODO: Make this singleton
|
122 |
+
def get_miditok():
|
123 |
+
pitch_range = range(0, 140) # was (21, 109)
|
124 |
+
beat_res = {(0, 400): 8}
|
125 |
+
return MIDILike(pitch_range, beat_res)
|
126 |
+
|
127 |
+
|
128 |
+
class WriteTextMidiToFile: # utils saving to file
|
129 |
+
def __init__(self, generate_midi, output_path):
|
130 |
+
self.generated_midi = generate_midi.generated_piece
|
131 |
+
self.output_path = output_path
|
132 |
+
self.hyperparameter_and_bars = generate_midi.piece_by_track
|
133 |
+
|
134 |
+
def hashing_seq(self):
|
135 |
+
self.current_time = get_datetime()
|
136 |
+
self.output_path_filename = f"{self.output_path}/{self.current_time}.json"
|
137 |
+
|
138 |
+
def wrapping_seq_hyperparameters_in_dict(self):
|
139 |
+
# assert type(self.generated_midi) is str, "error: generate_midi must be a string"
|
140 |
+
# assert (
|
141 |
+
# type(self.hyperparameter_dict) is dict
|
142 |
+
# ), "error: feature_dict must be a dictionnary"
|
143 |
+
return {
|
144 |
+
"generate_midi": self.generated_midi,
|
145 |
+
"hyperparameters_and_bars": self.hyperparameter_and_bars,
|
146 |
+
}
|
147 |
+
|
148 |
+
def text_midi_to_file(self):
|
149 |
+
self.hashing_seq()
|
150 |
+
output_dict = self.wrapping_seq_hyperparameters_in_dict()
|
151 |
+
print(f"Token generate_midi written: {self.output_path_filename}")
|
152 |
+
writeToFile(self.output_path_filename, output_dict)
|
153 |
+
return self.output_path_filename
|
154 |
+
|
155 |
+
|
156 |
+
def get_files(directory, extension, recursive=False):
|
157 |
+
"""
|
158 |
+
Given a directory, get a list of the file paths of all files matching the
|
159 |
+
specified file extension.
|
160 |
+
directory: the directory to search as a Path object
|
161 |
+
extension: the file extension to match as a string
|
162 |
+
recursive: whether to search recursively in the directory or not
|
163 |
+
"""
|
164 |
+
if recursive:
|
165 |
+
return list(directory.rglob(f"*.{extension}"))
|
166 |
+
else:
|
167 |
+
return list(directory.glob(f"*.{extension}"))
|
168 |
+
|
169 |
+
|
170 |
+
def timeit(func):
|
171 |
+
def wrapper(*args, **kwargs):
|
172 |
+
start = perf_counter()
|
173 |
+
result = func(*args, **kwargs)
|
174 |
+
end = perf_counter()
|
175 |
+
print(f"{func.__name__} took {end - start:.2f} seconds to run.")
|
176 |
+
return result
|
177 |
+
|
178 |
+
return wrapper
|
179 |
+
|
180 |
+
|
181 |
+
class FileCompressor:
|
182 |
+
def __init__(self, input_directory, output_directory, n_jobs=-1):
|
183 |
+
self.input_directory = input_directory
|
184 |
+
self.output_directory = output_directory
|
185 |
+
self.n_jobs = n_jobs
|
186 |
+
|
187 |
+
# File compression and decompression
|
188 |
+
def unzip_file(self, file):
|
189 |
+
"""uncompress single zip file"""
|
190 |
+
with ZipFile(file, "r") as zip_ref:
|
191 |
+
zip_ref.extractall(self.output_directory)
|
192 |
+
|
193 |
+
def zip_file(self, file):
|
194 |
+
"""compress a single text file to a new zip file and delete the original"""
|
195 |
+
output_file = self.output_directory / (file.stem + ".zip")
|
196 |
+
with ZipFile(output_file, "w") as zip_ref:
|
197 |
+
zip_ref.write(file, arcname=file.name, compress_type=ZIP_DEFLATED)
|
198 |
+
file.unlink()
|
199 |
+
|
200 |
+
@timeit
|
201 |
+
def unzip(self):
|
202 |
+
"""uncompress all zip files in folder"""
|
203 |
+
files = get_files(self.input_directory, extension="zip")
|
204 |
+
Parallel(n_jobs=self.n_jobs)(delayed(self.unzip_file)(file) for file in files)
|
205 |
+
|
206 |
+
@timeit
|
207 |
+
def zip(self):
|
208 |
+
"""compress all text files in folder to new zip files and remove the text files"""
|
209 |
+
files = get_files(self.output_directory, extension="txt")
|
210 |
+
Parallel(n_jobs=self.n_jobs)(delayed(self.zip_file)(file) for file in files)
|
211 |
+
|
212 |
+
|
213 |
+
def load_jsonl(filepath):
|
214 |
+
"""Load a jsonl file"""
|
215 |
+
with open(filepath, "r") as f:
|
216 |
+
data = [json.loads(line) for line in f]
|
217 |
+
return data
|
218 |
+
|
219 |
+
|
220 |
+
def write_mp3(waveform, output_path, bitrate="92k"):
|
221 |
+
"""
|
222 |
+
Write a waveform to an mp3 file.
|
223 |
+
output_path: Path object for the output mp3 file
|
224 |
+
waveform: numpy array of the waveform
|
225 |
+
bitrate: bitrate of the mp3 file (64k, 92k, 128k, 256k, 312k)
|
226 |
+
"""
|
227 |
+
# write the wav file
|
228 |
+
wav_path = output_path.with_suffix(".wav")
|
229 |
+
write(wav_path, 44100, waveform.astype(np.float32))
|
230 |
+
# compress the wav file as mp3
|
231 |
+
AudioSegment.from_wav(wav_path).export(output_path, format="mp3", bitrate=bitrate)
|
232 |
+
# remove the wav file
|
233 |
+
wav_path.unlink()
|
234 |
+
|
235 |
+
|
236 |
+
def copy_file(input_file, output_dir):
|
237 |
+
"""Copy an input file to the output_dir"""
|
238 |
+
output_file = output_dir / input_file.name
|
239 |
+
shutil.copy(input_file, output_file)
|
240 |
+
|
241 |
+
|
242 |
+
def index_has_substring(list, substring):
|
243 |
+
for i, s in enumerate(list):
|
244 |
+
if substring in s:
|
245 |
+
return i
|
246 |
+
return -1
|