m41w4r3.exe commited on
Commit
2ec0615
·
1 Parent(s): 3e2b7ea

initial commit

Browse files
Files changed (9) hide show
  1. README.md +10 -11
  2. app.py +154 -0
  3. constants.py +77 -0
  4. decoder.py +197 -0
  5. generate.py +489 -0
  6. load.py +60 -0
  7. playback.py +35 -0
  8. requirements.txt +5 -0
  9. utils.py +246 -0
README.md CHANGED
@@ -1,12 +1,11 @@
1
- ---
2
- title: The Jam Machine
3
- emoji: 🏃
4
- colorFrom: pink
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 3.14.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
1
+ Contributors:
 
 
 
 
 
 
 
 
 
2
 
3
+ Jean Simonnet:
4
+ https://github.com/misnaej
5
+ https://www.linkedin.com/in/jeansimonnet/
6
+ Louis Demetz:
7
+ https://github.com/louis-demetz
8
+ https://www.linkedin.com/in/ldemetz/
9
+ Halid Bayram:
10
+ https://github.com/m41w4r3exe
11
+ https://www.linkedin.com/in/halid-bayram-6b9ba861/
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from load import LoadModel
3
+ from generate import GenerateMidiText
4
+ from constants import INSTRUMENT_CLASSES
5
+ from encoder import MIDIEncoder
6
+ from decoder import TextDecoder
7
+ from utils import get_miditok, index_has_substring
8
+ from playback import get_music
9
+ from matplotlib import pylab
10
+ import sys
11
+ import matplotlib
12
+ from generation_utils import plot_piano_roll
13
+ import numpy as np
14
+
15
+ matplotlib.use("Agg")
16
+ import matplotlib.pyplot as plt
17
+
18
+ sys.modules["pylab"] = pylab
19
+
20
+ model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
21
+ revision = "ddf00f90d6d27e4cc0cb99c04a22a8f0a16c933e"
22
+ n_bar_generated = 8
23
+ # model_repo = "JammyMachina/improved_4bars-mdl"
24
+ # n_bar_generated = 4
25
+
26
+ model, tokenizer = LoadModel(
27
+ model_repo, from_huggingface=True, revision=revision
28
+ ).load_model_and_tokenizer()
29
+ genesis = GenerateMidiText(
30
+ model,
31
+ tokenizer,
32
+ )
33
+ genesis.set_nb_bars_generated(n_bars=n_bar_generated)
34
+
35
+ miditok = get_miditok()
36
+ decoder = TextDecoder(miditok)
37
+
38
+
39
+ def define_prompt(state, genesis):
40
+ if len(state) == 0:
41
+ input_prompt = "PIECE_START "
42
+ else:
43
+ input_prompt = genesis.get_whole_piece_from_bar_dict()
44
+ return input_prompt
45
+
46
+
47
+ def generator(
48
+ regenerate, temp, density, instrument, state, add_bars=False, add_bar_count=1
49
+ ):
50
+
51
+ inst = next(
52
+ (inst for inst in INSTRUMENT_CLASSES if inst["name"] == instrument),
53
+ {"family_number": "DRUMS"},
54
+ )["family_number"]
55
+
56
+ inst_index = index_has_substring(state, "INST=" + str(inst))
57
+
58
+ # Regenerate
59
+ if regenerate:
60
+ state.pop(inst_index)
61
+ genesis.delete_one_track(inst_index)
62
+ generated_text = (
63
+ genesis.get_whole_piece_from_bar_dict()
64
+ ) # maybe not useful here
65
+ inst_index = -1 # reset to last generated
66
+
67
+ # Generate
68
+ if not add_bars:
69
+ # NEW TRACK
70
+ input_prompt = define_prompt(state, genesis)
71
+ generated_text = genesis.generate_one_new_track(
72
+ inst, density, temp, input_prompt=input_prompt
73
+ )
74
+ else:
75
+ # NEW BARS
76
+ genesis.generate_n_more_bars(add_bar_count) # for all instruments
77
+ generated_text = genesis.get_whole_piece_from_bar_dict()
78
+
79
+ decoder.get_midi(generated_text, "tmp/mixed.mid")
80
+ mixed_inst_midi, mixed_audio = get_music("tmp/mixed.mid")
81
+
82
+ inst_text = genesis.get_selected_track_as_text(inst_index)
83
+ inst_midi_name = f"tmp/{instrument}.mid"
84
+ decoder.get_midi(inst_text, inst_midi_name)
85
+ _, inst_audio = get_music(inst_midi_name)
86
+ piano_roll = plot_piano_roll(mixed_inst_midi)
87
+ state.append(inst_text)
88
+
89
+ return inst_text, (44100, inst_audio), piano_roll, state, (44100, mixed_audio)
90
+
91
+
92
+ def instrument_row(default_inst):
93
+
94
+ with gr.Row():
95
+ with gr.Column(scale=1, min_width=50):
96
+ inst = gr.Dropdown(
97
+ [inst["name"] for inst in INSTRUMENT_CLASSES] + ["Drums"],
98
+ value=default_inst,
99
+ label="Instrument",
100
+ )
101
+ temp = gr.Number(value=0.7, label="Creativity")
102
+ density = gr.Dropdown([0, 1, 2, 3], value=3, label="Density")
103
+
104
+ with gr.Column(scale=3):
105
+ output_txt = gr.Textbox(label="output", lines=10, max_lines=10)
106
+ with gr.Column(scale=1, min_width=100):
107
+ inst_audio = gr.Audio(label="Audio")
108
+ regenerate = gr.Checkbox(value=False, label="Regenerate")
109
+ # add_bars = gr.Checkbox(value=False, label="Add Bars")
110
+ # add_bar_count = gr.Dropdown([1, 2, 4, 8], value=1, label="Add Bars")
111
+ gen_btn = gr.Button("Generate")
112
+ gen_btn.click(
113
+ fn=generator,
114
+ inputs=[
115
+ regenerate,
116
+ temp,
117
+ density,
118
+ inst,
119
+ state,
120
+ ],
121
+ outputs=[output_txt, inst_audio, piano_roll, state, mixed_audio],
122
+ )
123
+
124
+
125
+ with gr.Blocks(cache_examples=False) as demo:
126
+ state = gr.State([])
127
+ mixed_audio = gr.Audio(label="Mixed Audio")
128
+ piano_roll = gr.Plot(label="Piano Roll")
129
+ instrument_row("Drums")
130
+ instrument_row("Bass")
131
+ instrument_row("Synth Lead")
132
+ # instrument_row("Piano")
133
+
134
+ demo.launch(debug=True)
135
+
136
+ """
137
+ TODO: DEPLOY
138
+ TODO: temp file situation
139
+ TODO: clear cache situation
140
+ TODO: reset button
141
+ TODO: instrument mapping business
142
+ TODO: Y lim axis of piano roll
143
+ TODO: add a button to save the generated midi
144
+ TODO: add improvise button
145
+ TODO: making the piano roll fit on the horizontal scale
146
+ TODO: set values for temperature as it is done for density
147
+ TODO: set the color situation to be dark background
148
+ TODO: make regeration default when an intrument has already been track has already been generated
149
+ TODO: Add bar should be now set for the whole piece - regenerrate should regenerate the added bars only on all instruments
150
+ TODO: row height to fix
151
+
152
+ TODO: reset state of tick boxes after used maybe (regenerate, add bars) ;
153
+ TODO: block regenerate if add bar on
154
+ """
constants.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # fmt: off
2
+ # Instrument mapping and mapping functions
3
+ INSTRUMENT_CLASSES = [
4
+ {"name": "Piano", "program_range": range(0, 8), "family_number": 0},
5
+ {"name": "Chromatic Percussion", "program_range": range(8, 16), "family_number": 1},
6
+ {"name": "Organ", "program_range": range(16, 24), "family_number": 2},
7
+ {"name": "Guitar", "program_range": range(24, 32), "family_number": 3},
8
+ {"name": "Bass", "program_range": range(32, 40), "family_number": 4},
9
+ {"name": "Strings", "program_range": range(40, 48), "family_number": 5},
10
+ {"name": "Ensemble", "program_range": range(48, 56), "family_number": 6},
11
+ {"name": "Brass", "program_range": range(56, 64), "family_number": 7},
12
+ {"name": "Reed", "program_range": range(64, 72), "family_number": 8},
13
+ {"name": "Pipe", "program_range": range(72, 80), "family_number": 9},
14
+ {"name": "Synth Lead", "program_range": range(80, 88), "family_number": 10},
15
+ {"name": "Synth Pad", "program_range": range(88, 96), "family_number": 11},
16
+ {"name": "Synth Effects", "program_range": range(96, 104), "family_number": 12},
17
+ {"name": "Ethnic", "program_range": range(104, 112), "family_number": 13},
18
+ {"name": "Percussive", "program_range": range(112, 120), "family_number": 14},
19
+ {"name": "Sound Effects", "program_range": range(120, 128), "family_number": 15,},
20
+ ]
21
+ # fmt: on
22
+
23
+ # Instrument mapping for decodiing our midi sequence into midi instruments of our choice
24
+ INSTRUMENT_TRANSFER_CLASSES = [
25
+ {
26
+ "name": "Piano",
27
+ "program_range": [4], # Electric Piano 1
28
+ "family_number": 0,
29
+ },
30
+ {
31
+ "name": "Chromatic Percussion",
32
+ "program_range": [11], # Vibraphone
33
+ "family_number": 1,
34
+ },
35
+ {"name": "Organ", "program_range": [17], "family_number": 2}, # Percussive Organ
36
+ {
37
+ "name": "Guitar",
38
+ "program_range": [80], # Synth Lead Square
39
+ "family_number": 3,
40
+ },
41
+ {
42
+ "name": "Bass",
43
+ "program_range": [38], # Synth bass 1,
44
+ "family_number": 4,
45
+ },
46
+ {
47
+ "name": "Strings",
48
+ "program_range": [50], # Synth Strings 1
49
+ "family_number": 5,
50
+ },
51
+ {
52
+ "name": "Ensemble",
53
+ "program_range": [51], # Synth Strings 2
54
+ "family_number": 6,
55
+ },
56
+ {
57
+ "name": "Brass",
58
+ "program_range": [63], # 63 Synth Brass 1,
59
+ "family_number": 7,
60
+ },
61
+ {"name": "Reed", "program_range": [64], "family_number": 8}, # Synth Brass 2
62
+ {"name": "Pipe", "program_range": [82], "family_number": 9}, # Lead 3
63
+ {
64
+ "name": "Synth Lead",
65
+ "program_range": [81], # Synth Lead Sawtooth
66
+ "family_number": 10,
67
+ },
68
+ {"name": "Synth Pad", "program_range": range(88, 96), "family_number": 11},
69
+ {"name": "Synth Effects", "program_range": range(96, 104), "family_number": 12},
70
+ {"name": "Ethnic", "program_range": range(104, 112), "family_number": 13},
71
+ {"name": "Percussive", "program_range": range(112, 120), "family_number": 14},
72
+ {
73
+ "name": "Sound Effects",
74
+ "program_range": range(120, 128),
75
+ "family_number": 15,
76
+ },
77
+ ]
decoder.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import *
2
+ from familizer import Familizer
3
+ from miditok import Event
4
+
5
+
6
+ class TextDecoder:
7
+ """Decodes text into:
8
+ 1- List of events
9
+ 2- Then converts these events to midi file via MidiTok and miditoolkit
10
+
11
+ :param tokenizer: from MidiTok
12
+
13
+ Usage with write_to_midi method:
14
+ args: text(String) example -> PIECE_START TRACK_START INST=25 DENSITY=2 BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50...BAR_END TRACK_END
15
+ returns: midi file from miditoolkit
16
+ """
17
+
18
+ def __init__(self, tokenizer, familized=True):
19
+ self.tokenizer = tokenizer
20
+ self.familized = familized
21
+
22
+ def decode(self, text):
23
+ r"""converts from text to instrument events
24
+ Args:
25
+ text (String): example -> PIECE_START TRACK_START INST=25 DENSITY=2 BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50...BAR_END TRACK_END
26
+
27
+ Returns:
28
+ Dict{inst_id: List[Events]}: List of events of Notes with velocities, aggregated Timeshifts, for each instrument
29
+ """
30
+ piece_events = self.text_to_events(text)
31
+ inst_events = self.piece_to_inst_events(piece_events)
32
+ events = self.add_timeshifts_for_empty_bars(inst_events)
33
+ events = self.aggregate_timeshifts(events)
34
+ events = self.add_velocity(events)
35
+ return events
36
+
37
+ def tokenize(self, events):
38
+ r"""converts from events to MidiTok tokens
39
+ Args:
40
+ events (Dict{inst_id: List[Events]}): List of events for each instrument
41
+
42
+ Returns:
43
+ List[List[Events]]: List of tokens for each instrument
44
+ """
45
+ tokens = []
46
+ for inst in events.keys():
47
+ tokens.append(self.tokenizer.events_to_tokens(events[inst]))
48
+ return tokens
49
+
50
+ def get_midi(self, text, filename=None):
51
+ r"""converts from text to midi
52
+ Args:
53
+ text (String): example -> PIECE_START TRACK_START INST=25 DENSITY=2 BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50...BAR_END TRACK_END
54
+
55
+ Returns:
56
+ miditoolkit midi: Returns and writes to midi
57
+ """
58
+ events = self.decode(text)
59
+ tokens = self.tokenize(events)
60
+ instruments = self.get_instruments_tuple(events)
61
+ midi = self.tokenizer.tokens_to_midi(tokens, instruments)
62
+
63
+ if filename is not None:
64
+ midi.dump(f"{filename}")
65
+ print(f"midi file written: {filename}")
66
+
67
+ return midi
68
+
69
+ @staticmethod
70
+ def text_to_events(text):
71
+ events = []
72
+ for word in text.split(" "):
73
+ # TODO: Handle bar and track values with a counter
74
+ _event = word.split("=")
75
+ value = _event[1] if len(_event) > 1 else None
76
+ event = get_event(_event[0], value)
77
+ if event:
78
+ events.append(event)
79
+ return events
80
+
81
+ @staticmethod
82
+ def piece_to_inst_events(piece_events):
83
+ """Converts piece events of 8 bars to instrument events for entire song
84
+
85
+ Args:
86
+ piece_events (List[Events]): List of events of Notes, Timeshifts, Bars, Tracks
87
+
88
+ Returns:
89
+ Dict{inst_id: List[Events]}: List of events for each instrument
90
+
91
+ """
92
+ inst_events = {}
93
+ current_instrument = -1
94
+ for event in piece_events:
95
+ if event.type == "Instrument":
96
+ current_instrument = event.value
97
+ if current_instrument not in inst_events:
98
+ inst_events[current_instrument] = []
99
+ elif current_instrument != -1:
100
+ inst_events[current_instrument].append(event)
101
+ return inst_events
102
+
103
+ @staticmethod
104
+ def add_timeshifts_for_empty_bars(inst_events):
105
+ """Adds time shift events instead of consecutive [BAR_START BAR_END] events"""
106
+ new_inst_events = {}
107
+ for inst, events in inst_events.items():
108
+ new_inst_events[inst] = []
109
+ for index, event in enumerate(events):
110
+ if event.type == "Bar-End" or event.type == "Bar-Start":
111
+ if events[index - 1].type == "Bar-Start":
112
+ new_inst_events[inst].append(Event("Time-Shift", "4.0.8"))
113
+ else:
114
+ new_inst_events[inst].append(event)
115
+ return new_inst_events
116
+
117
+ @staticmethod
118
+ def add_timeshifts(beat_values1, beat_values2):
119
+ """Adds two beat values
120
+
121
+ Args:
122
+ beat_values1 (String): like 0.3.8
123
+ beat_values2 (String): like 1.7.8
124
+
125
+ Returns:
126
+ beat_str (String): added beats like 2.2.8 for example values
127
+ """
128
+ value1 = to_base10(beat_values1)
129
+ value2 = to_base10(beat_values2)
130
+ return to_beat_str(value1 + value2)
131
+
132
+ def aggregate_timeshifts(self, events):
133
+ """Aggregates consecutive time shift events bigger than a bar
134
+ -> like Timeshift 4.0.8
135
+
136
+ Args:
137
+ events (_type_): _description_
138
+
139
+ Returns:
140
+ _type_: _description_
141
+ """
142
+ new_events = {}
143
+ for inst, events in events.items():
144
+ inst_events = []
145
+ for i, event in enumerate(events):
146
+ if (
147
+ event.type == "Time-Shift"
148
+ and len(inst_events) > 0
149
+ and inst_events[-1].type == "Time-Shift"
150
+ ):
151
+ inst_events[-1].value = self.add_timeshifts(
152
+ inst_events[-1].value, event.value
153
+ )
154
+ else:
155
+ inst_events.append(event)
156
+ new_events[inst] = inst_events
157
+ return new_events
158
+
159
+ @staticmethod
160
+ def add_velocity(events):
161
+ """Adds default velocity 99 to note events since they are removed from text, needed to generate midi"""
162
+ new_events = {}
163
+ for inst, events in events.items():
164
+ inst_events = []
165
+ for event in events:
166
+ inst_events.append(event)
167
+ if event.type == "Note-On":
168
+ inst_events.append(Event("Velocity", 99))
169
+ new_events[inst] = inst_events
170
+ return new_events
171
+
172
+ def get_instruments_tuple(self, events):
173
+ """Returns instruments tuple for midi generation"""
174
+ instruments = []
175
+ for inst in events.keys():
176
+ is_drum = 0
177
+ if inst == "DRUMS":
178
+ inst = 0
179
+ is_drum = 1
180
+ if self.familized:
181
+ inst = Familizer(arbitrary=True).get_program_number(int(inst)) + 1
182
+ instruments.append((int(inst), is_drum))
183
+ return tuple(instruments)
184
+
185
+
186
+ if __name__ == "__main__":
187
+
188
+ filename = "midi/generated/misnaej/the-jam-machine-elec-famil/20221209_175750"
189
+ encoded_json = readFromFile(
190
+ f"{filename}.json",
191
+ True,
192
+ )
193
+ encoded_text = encoded_json["sequence"]
194
+ # encoded_text = "PIECE_START TRACK_START INST=25 DENSITY=2 BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=67 NOTE_ON=64 TIME_DELTA=1 NOTE_OFF=67 NOTE_OFF=64 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=67 NOTE_ON=64 TIME_DELTA=1 NOTE_OFF=67 NOTE_OFF=64 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=67 NOTE_ON=64 TIME_DELTA=1 NOTE_OFF=67 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=69 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=69 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=57 TIME_DELTA=1 NOTE_OFF=57 NOTE_ON=56 TIME_DELTA=1 NOTE_OFF=56 NOTE_ON=64 NOTE_ON=60 NOTE_ON=55 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=55 BAR_END BAR_START NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=66 NOTE_ON=62 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=66 NOTE_OFF=62 NOTE_OFF=50 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=67 NOTE_ON=64 TIME_DELTA=1 NOTE_OFF=67 NOTE_OFF=64 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=50 NOTE_ON=64 NOTE_ON=60 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=64 NOTE_OFF=60 NOTE_OFF=50 NOTE_ON=59 NOTE_ON=55 NOTE_ON=50 TIME_DELTA=1 NOTE_OFF=59 NOTE_OFF=50 NOTE_OFF=55 NOTE_OFF=50 BAR_END BAR_START BAR_END TRACK_END"
195
+
196
+ miditok = get_miditok()
197
+ TextDecoder(miditok).get_midi(encoded_text, filename=filename)
generate.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from generation_utils import *
2
+ from utils import WriteTextMidiToFile, get_miditok
3
+ from load import LoadModel
4
+ from constants import INSTRUMENT_CLASSES
5
+
6
+ ## import for execution
7
+ from decoder import TextDecoder
8
+ from playback import get_music, show_piano_roll
9
+
10
+
11
+ class GenerateMidiText:
12
+ """Generating music with Class
13
+
14
+ LOGIC:
15
+
16
+ FOR GENERATING FROM SCRATCH:
17
+ - self.generate_one_new_track()
18
+ it calls
19
+ - self.generate_until_track_end()
20
+
21
+ FOR GENERATING NEW BARS:
22
+ - self.generate_one_more_bar()
23
+ it calls
24
+ - self.process_prompt_for_next_bar()
25
+ - self.generate_until_track_end()"""
26
+
27
+ def __init__(self, model, tokenizer):
28
+ self.model = model
29
+ self.tokenizer = tokenizer
30
+ # default initialization
31
+ self.initialize_default_parameters()
32
+ self.initialize_dictionaries()
33
+
34
+ """Setters"""
35
+
36
+ def initialize_default_parameters(self):
37
+ self.set_device()
38
+ self.set_attention_length()
39
+ self.generate_until = "TRACK_END"
40
+ self.set_force_sequence_lenth()
41
+ self.set_nb_bars_generated()
42
+ self.set_improvisation_level(0)
43
+
44
+ def initialize_dictionaries(self):
45
+ self.piece_by_track = []
46
+
47
+ def set_device(self, device="cpu"):
48
+ self.device = ("cpu",)
49
+
50
+ def set_attention_length(self):
51
+ self.max_length = self.model.config.n_positions
52
+ print(
53
+ f"Attention length set to {self.max_length} -> 'model.config.n_positions'"
54
+ )
55
+
56
+ def set_force_sequence_lenth(self, force_sequence_length=True):
57
+ self.force_sequence_length = force_sequence_length
58
+
59
+ def set_improvisation_level(self, improvisation_value):
60
+ self.no_repeat_ngram_size = improvisation_value
61
+ print("--------------------")
62
+ print(f"no_repeat_ngram_size set to {improvisation_value}")
63
+ print("--------------------")
64
+
65
+ def reset_temperatures(self, track_id, temperature):
66
+ self.piece_by_track[track_id]["temperature"] = temperature
67
+
68
+ def set_nb_bars_generated(self, n_bars=8): # default is a 8 bar model
69
+ self.model_n_bar = n_bars
70
+
71
+ """ Generation Tools - Dictionnaries """
72
+
73
+ def initiate_track_dict(self, instr, density, temperature):
74
+ label = len(self.piece_by_track)
75
+ self.piece_by_track.append(
76
+ {
77
+ "label": f"track_{label}",
78
+ "instrument": instr,
79
+ "density": density,
80
+ "temperature": temperature,
81
+ "bars": [],
82
+ }
83
+ )
84
+
85
+ def update_track_dict__add_bars(self, bars, track_id):
86
+ """Add bars to the track dictionnary"""
87
+ for bar in self.striping_track_ends(bars).split("BAR_START "):
88
+ if bar == "": # happens is there is one bar only
89
+ continue
90
+ else:
91
+ if "TRACK_START" in bar:
92
+ self.piece_by_track[track_id]["bars"].append(bar)
93
+ else:
94
+ self.piece_by_track[track_id]["bars"].append("BAR_START " + bar)
95
+
96
+ def get_all_instr_bars(self, track_id):
97
+ return self.piece_by_track[track_id]["bars"]
98
+
99
+ def striping_track_ends(self, text):
100
+ if "TRACK_END" in text:
101
+ # first get rid of extra space if any
102
+ # then gets rid of "TRACK_END"
103
+ text = text.rstrip(" ").rstrip("TRACK_END")
104
+ return text
105
+
106
+ def get_last_generated_track(self, full_piece):
107
+ track = (
108
+ "TRACK_START "
109
+ + self.striping_track_ends(full_piece.split("TRACK_START ")[-1])
110
+ + "TRACK_END "
111
+ ) # forcing the space after track and
112
+ return track
113
+
114
+ def get_selected_track_as_text(self, track_id):
115
+ text = ""
116
+ for bar in self.piece_by_track[track_id]["bars"]:
117
+ text += bar
118
+ text += "TRACK_END "
119
+ return text
120
+
121
+ @staticmethod
122
+ def get_newly_generated_text(input_prompt, full_piece):
123
+ return full_piece[len(input_prompt) :]
124
+
125
+ def get_whole_piece_from_bar_dict(self):
126
+ text = "PIECE_START "
127
+ for track_id, _ in enumerate(self.piece_by_track):
128
+ text += self.get_selected_track_as_text(track_id)
129
+ return text
130
+
131
+ def delete_one_track(self, track): # TO BE TESTED
132
+ self.piece_by_track.pop(track)
133
+
134
+ # def update_piece_dict__add_track(self, track_id, track):
135
+ # self.piece_dict[track_id] = track
136
+
137
+ # def update_all_dictionnaries__add_track(self, track):
138
+ # self.update_piece_dict__add_track(track_id, track)
139
+
140
+ """Basic generation tools"""
141
+
142
+ def tokenize_input_prompt(self, input_prompt, verbose=True):
143
+ """Tokenizing prompt
144
+
145
+ Args:
146
+ - input_prompt (str): prompt to tokenize
147
+
148
+ Returns:
149
+ - input_prompt_ids (torch.tensor): tokenized prompt
150
+ """
151
+ if verbose:
152
+ print("Tokenizing input_prompt...")
153
+
154
+ return self.tokenizer.encode(input_prompt, return_tensors="pt")
155
+
156
+ def generate_sequence_of_token_ids(
157
+ self,
158
+ input_prompt_ids,
159
+ temperature,
160
+ verbose=True,
161
+ ):
162
+ """
163
+ generate a sequence of token ids based on input_prompt_ids
164
+ The sequence length depends on the trained model (self.model_n_bar)
165
+ """
166
+ generated_ids = self.model.generate(
167
+ input_prompt_ids,
168
+ max_length=self.max_length,
169
+ do_sample=True,
170
+ temperature=temperature,
171
+ no_repeat_ngram_size=self.no_repeat_ngram_size, # default = 0
172
+ eos_token_id=self.tokenizer.encode(self.generate_until)[0], # good
173
+ )
174
+
175
+ if verbose:
176
+ print("Generating a token_id sequence...")
177
+
178
+ return generated_ids
179
+
180
+ def convert_ids_to_text(self, generated_ids, verbose=True):
181
+ """converts the token_ids to text"""
182
+ generated_text = self.tokenizer.decode(generated_ids[0])
183
+ if verbose:
184
+ print("Converting token sequence to MidiText...")
185
+ return generated_text
186
+
187
+ def generate_until_track_end(
188
+ self,
189
+ input_prompt="PIECE_START ",
190
+ instrument=None,
191
+ density=None,
192
+ temperature=None,
193
+ verbose=True,
194
+ expected_length=None,
195
+ ):
196
+
197
+ """generate until the TRACK_END token is reached
198
+ full_piece = input_prompt + generated"""
199
+ if expected_length is None:
200
+ expected_length = self.model_n_bar
201
+
202
+ if instrument is not None:
203
+ input_prompt = f"{input_prompt}TRACK_START INST={str(instrument)} "
204
+ if density is not None:
205
+ input_prompt = f"{input_prompt}DENSITY={str(density)} "
206
+
207
+ if instrument is None and density is not None:
208
+ print("Density cannot be defined without an input_prompt instrument #TOFIX")
209
+
210
+ if temperature is None:
211
+ ValueError("Temperature must be defined")
212
+
213
+ if verbose:
214
+ print("--------------------")
215
+ print(
216
+ f"Generating {instrument} - Density {density} - temperature {temperature}"
217
+ )
218
+ bar_count_checks = False
219
+ failed = 0
220
+ while not bar_count_checks: # regenerate until right length
221
+ input_prompt_ids = self.tokenize_input_prompt(input_prompt, verbose=verbose)
222
+ generated_tokens = self.generate_sequence_of_token_ids(
223
+ input_prompt_ids, temperature, verbose=verbose
224
+ )
225
+ full_piece = self.convert_ids_to_text(generated_tokens, verbose=verbose)
226
+ generated = self.get_newly_generated_text(input_prompt, full_piece)
227
+ # bar_count_checks
228
+ bar_count_checks, bar_count = bar_count_check(generated, expected_length)
229
+
230
+ if not self.force_sequence_length:
231
+ # set bar_count_checks to true to exist the while loop
232
+ bar_count_checks = True
233
+
234
+ if not bar_count_checks and self.force_sequence_length:
235
+ # if the generated sequence is not the expected length
236
+ if failed > 1:
237
+ full_piece, bar_count_checks = forcing_bar_count(
238
+ input_prompt,
239
+ generated,
240
+ bar_count,
241
+ expected_length,
242
+ )
243
+ else:
244
+ print('"--- Wrong length - Regenerating ---')
245
+ if not bar_count_checks:
246
+ failed += 1
247
+ if failed > 2:
248
+ bar_count_checks = True # TOFIX exit the while loop
249
+
250
+ return full_piece
251
+
252
+ def generate_one_new_track(
253
+ self,
254
+ instrument,
255
+ density,
256
+ temperature,
257
+ input_prompt="PIECE_START ",
258
+ ):
259
+ self.initiate_track_dict(instrument, density, temperature)
260
+ full_piece = self.generate_until_track_end(
261
+ input_prompt=input_prompt,
262
+ instrument=instrument,
263
+ density=density,
264
+ temperature=temperature,
265
+ )
266
+
267
+ track = self.get_last_generated_track(full_piece)
268
+ self.update_track_dict__add_bars(track, -1)
269
+ full_piece = self.get_whole_piece_from_bar_dict()
270
+ return full_piece
271
+
272
+ """ Piece generation - Basics """
273
+
274
+ def generate_piece(self, instrument_list, density_list, temperature_list):
275
+ """generate a sequence with mutiple tracks
276
+ - inst_list sets the list of instruments of the order of generation
277
+ - density is paired with inst_list
278
+ Each track/intrument is generated on a prompt which contains the previously generated track/instrument
279
+ This means that the first instrument is generated with less bias than the next one, and so on.
280
+
281
+ 'generated_piece' keeps track of the entire piece
282
+ 'generated_piece' is returned by self.generate_until_track_end
283
+ # it is returned by self.generate_until_track_end"""
284
+
285
+ generated_piece = "PIECE_START "
286
+ for instrument, density, temperature in zip(
287
+ instrument_list, density_list, temperature_list
288
+ ):
289
+ generated_piece = self.generate_one_new_track(
290
+ instrument,
291
+ density,
292
+ temperature,
293
+ input_prompt=generated_piece,
294
+ )
295
+
296
+ # generated_piece = self.get_whole_piece_from_bar_dict()
297
+ self.check_the_piece_for_errors()
298
+ return generated_piece
299
+
300
+ """ Piece generation - Extra Bars """
301
+
302
+ @staticmethod
303
+ def process_prompt_for_next_bar(self, track_idx):
304
+ """Processing the prompt for the model to generate one more bar only.
305
+ The prompt containts:
306
+ if not the first bar: the previous, already processed, bars of the track
307
+ the bar initialization (ex: "TRACK_START INST=DRUMS DENSITY=2 ")
308
+ the last (self.model_n_bar)-1 bars of the track
309
+ Args:
310
+ track_idx (int): the index of the track to be processed
311
+
312
+ Returns:
313
+ the processed prompt for generating the next bar
314
+ """
315
+ track = self.piece_by_track[track_idx]
316
+ # for bars which are not the bar to prolong
317
+ pre_promt = "PIECE_START "
318
+ for i, othertrack in enumerate(self.piece_by_track):
319
+ if i != track_idx:
320
+ len_diff = len(othertrack["bars"]) - len(track["bars"])
321
+ if len_diff > 0:
322
+ # if other bars are longer, it mean that this one should catch up
323
+ pre_promt += othertrack["bars"][0]
324
+ for bar in track["bars"][-self.model_n_bar :]:
325
+ pre_promt += bar
326
+ pre_promt += "TRACK_END "
327
+ elif False: # len_diff <= 0: # THIS GENERATES EMPTINESS
328
+ # adding an empty bars at the end of the other tracks if they have not been processed yet
329
+ pre_promt += othertracks["bars"][0]
330
+ for bar in track["bars"][-(self.model_n_bar - 1) :]:
331
+ pre_promt += bar
332
+ for _ in range(abs(len_diff) + 1):
333
+ pre_promt += "BAR_START BAR_END "
334
+ pre_promt += "TRACK_END "
335
+
336
+ # for the bar to prolong
337
+ # initialization e.g TRACK_START INST=DRUMS DENSITY=2
338
+ processed_prompt = track["bars"][0]
339
+ for bar in track["bars"][-(self.model_n_bar - 1) :]:
340
+ # adding the "last" bars of the track
341
+ processed_prompt += bar
342
+
343
+ processed_prompt += "BAR_START "
344
+ print(
345
+ f"--- prompt length = {len((pre_promt + processed_prompt).split(' '))} ---"
346
+ )
347
+ return pre_promt + processed_prompt
348
+
349
+ def generate_one_more_bar(self, i):
350
+ """Generate one more bar from the input_prompt"""
351
+ processed_prompt = self.process_prompt_for_next_bar(self, i)
352
+ prompt_plus_bar = self.generate_until_track_end(
353
+ input_prompt=processed_prompt,
354
+ temperature=self.piece_by_track[i]["temperature"],
355
+ expected_length=1,
356
+ verbose=False,
357
+ )
358
+ added_bar = self.get_newly_generated_bar(prompt_plus_bar)
359
+ self.update_track_dict__add_bars(added_bar, i)
360
+
361
+ def get_newly_generated_bar(self, prompt_plus_bar):
362
+ return "BAR_START " + self.striping_track_ends(
363
+ prompt_plus_bar.split("BAR_START ")[-1]
364
+ )
365
+
366
+ def generate_n_more_bars(self, n_bars, only_this_track=None, verbose=True):
367
+ """Generate n more bars from the input_prompt"""
368
+ if only_this_track is None:
369
+ only_this_track
370
+
371
+ print(f"================== ")
372
+ print(f"Adding {n_bars} more bars to the piece ")
373
+ for bar_id in range(n_bars):
374
+ print(f"----- added bar #{bar_id+1} --")
375
+ for i, track in enumerate(self.piece_by_track):
376
+ if only_this_track is None or i == only_this_track:
377
+ print(f"--------- {track['label']}")
378
+ self.generate_one_more_bar(i)
379
+ self.check_the_piece_for_errors()
380
+
381
+ def check_the_piece_for_errors(self, piece: str = None):
382
+
383
+ if piece is None:
384
+ piece = generate_midi.get_whole_piece_from_bar_dict()
385
+ errors = []
386
+ errors.append(
387
+ [
388
+ (token, id)
389
+ for id, token in enumerate(piece.split(" "))
390
+ if token not in self.tokenizer.vocab or token == "UNK"
391
+ ]
392
+ )
393
+ if len(errors) > 0:
394
+ # print(piece)
395
+ for er in errors:
396
+ er
397
+ print(f"Token not found in the piece at {er[0][1]}: {er[0][0]}")
398
+ print(piece.split(" ")[er[0][1] - 5 : er[0][1] + 5])
399
+
400
+
401
+ if __name__ == "__main__":
402
+
403
+ # worker
404
+ DEVICE = "cpu"
405
+
406
+ # define generation parameters
407
+ N_FILES_TO_GENERATE = 2
408
+ Temperatures_to_try = [0.7]
409
+
410
+ USE_FAMILIZED_MODEL = True
411
+ force_sequence_length = True
412
+
413
+ if USE_FAMILIZED_MODEL:
414
+ # model_repo = "misnaej/the-jam-machine-elec-famil"
415
+ # model_repo = "misnaej/the-jam-machine-elec-famil-ft32"
416
+
417
+ # model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
418
+ # n_bar_generated = 8
419
+
420
+ model_repo = "JammyMachina/improved_4bars-mdl"
421
+ n_bar_generated = 4
422
+ instrument_promt_list = ["4", "DRUMS", "3"]
423
+ # DRUMS = drums, 0 = piano, 1 = chromatic percussion, 2 = organ, 3 = guitar, 4 = bass, 5 = strings, 6 = ensemble, 7 = brass, 8 = reed, 9 = pipe, 10 = synth lead, 11 = synth pad, 12 = synth effects, 13 = ethnic, 14 = percussive, 15 = sound effects
424
+ density_list = [3, 2, 2]
425
+ # temperature_list = [0.7, 0.7, 0.75]
426
+ else:
427
+ model_repo = "misnaej/the-jam-machine"
428
+ instrument_promt_list = ["30"] # , "DRUMS", "0"]
429
+ density_list = [3] # , 2, 3]
430
+ # temperature_list = [0.7, 0.5, 0.75]
431
+ pass
432
+
433
+ # define generation directory
434
+ generated_sequence_files_path = define_generation_dir(model_repo)
435
+
436
+ # load model and tokenizer
437
+ model, tokenizer = LoadModel(
438
+ model_repo, from_huggingface=True
439
+ ).load_model_and_tokenizer()
440
+
441
+ # does the prompt make sense
442
+ check_if_prompt_inst_in_tokenizer_vocab(tokenizer, instrument_promt_list)
443
+
444
+ for temperature in Temperatures_to_try:
445
+ print(f"================= TEMPERATURE {temperature} =======================")
446
+ for _ in range(N_FILES_TO_GENERATE):
447
+ print(f"========================================")
448
+ # 1 - instantiate
449
+ generate_midi = GenerateMidiText(model, tokenizer)
450
+ # 0 - set the n_bar for this model
451
+ generate_midi.set_nb_bars_generated(n_bars=n_bar_generated)
452
+ # 1 - defines the instruments, densities and temperatures
453
+ # 2- generate the first 8 bars for each instrument
454
+ generate_midi.set_improvisation_level(30)
455
+ generate_midi.generate_piece(
456
+ instrument_promt_list,
457
+ density_list,
458
+ [temperature for _ in density_list],
459
+ )
460
+ # 3 - force the model to improvise
461
+ # generate_midi.set_improvisation_level(20)
462
+ # 4 - generate the next 4 bars for each instrument
463
+ # generate_midi.generate_n_more_bars(n_bar_generated)
464
+ # 5 - lower the improvisation level
465
+ generate_midi.generated_piece = (
466
+ generate_midi.get_whole_piece_from_bar_dict()
467
+ )
468
+
469
+ # print the generated sequence in terminal
470
+ print("=========================================")
471
+ print(generate_midi.generated_piece)
472
+ print("=========================================")
473
+
474
+ # write to JSON file
475
+ filename = WriteTextMidiToFile(
476
+ generate_midi,
477
+ generated_sequence_files_path,
478
+ ).text_midi_to_file()
479
+
480
+ # decode the sequence to MIDI """
481
+ decode_tokenizer = get_miditok()
482
+ TextDecoder(decode_tokenizer, USE_FAMILIZED_MODEL).get_midi(
483
+ generate_midi.generated_piece, filename=filename.split(".")[0] + ".mid"
484
+ )
485
+ inst_midi, mixed_audio = get_music(filename.split(".")[0] + ".mid")
486
+ max_time = get_max_time(inst_midi)
487
+ plot_piano_roll(inst_midi)
488
+
489
+ print("Et voilà! Your MIDI file is ready! GO JAM!")
load.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2LMHeadModel
2
+ from transformers import PreTrainedTokenizerFast
3
+ import os
4
+
5
+
6
+ class LoadModel:
7
+ """
8
+ Example usage:
9
+
10
+ # if loading model and tokenizer from Huggingface
11
+ model_repo = "misnaej/the-jam-machine"
12
+ model, tokenizer = LoadModel(
13
+ model_repo, from_huggingface=True
14
+ ).load_model_and_tokenizer()
15
+
16
+ # if loading model and tokenizer from a local folder
17
+ model_path = "models/model_2048_wholedataset"
18
+ model, tokenizer = LoadModel(
19
+ model_path, from_huggingface=False
20
+ ).load_model_and_tokenizer()
21
+
22
+ """
23
+
24
+ def __init__(self, path, from_huggingface=True, device="cpu", revision=None):
25
+ # path is either a relative path on a local/remote machine or a model repo on HuggingFace
26
+ if not from_huggingface:
27
+ if not os.path.exists(path):
28
+ print(path)
29
+ raise Exception("Model path does not exist")
30
+ self.from_huggingface = from_huggingface
31
+ self.path = path
32
+ self.device = device
33
+ self.revision = revision
34
+
35
+ def load_model_and_tokenizer(self):
36
+ model = self.load_model()
37
+ tokenizer = self.load_tokenizer()
38
+
39
+ return model, tokenizer
40
+
41
+ def load_model(self):
42
+ if self.revision is None:
43
+ model = GPT2LMHeadModel.from_pretrained(self.path).to(self.device)
44
+ else:
45
+ model = GPT2LMHeadModel.from_pretrained(
46
+ self.path, revision=self.revision
47
+ ).to(self.device)
48
+
49
+ return model
50
+
51
+ def load_tokenizer(self):
52
+ if self.from_huggingface:
53
+ pass
54
+ else:
55
+ if not os.path.exists(f"{self.path}/tokenizer.json"):
56
+ raise Exception(
57
+ f"There is no 'tokenizer.json'file in the defined {self.path}"
58
+ )
59
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(self.path)
60
+ return tokenizer
playback.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import librosa.display
3
+ from pretty_midi import PrettyMIDI
4
+
5
+
6
+ # Note: these functions are meant to be played within an interactive Python shell
7
+ # Please refer to the synth.ipynb for an example of how to use them
8
+
9
+
10
+ def get_music(midi_file):
11
+ """
12
+ Load a midi file and return the PrettyMIDI object and the audio signal
13
+ """
14
+ music = PrettyMIDI(midi_file=midi_file)
15
+ waveform = music.fluidsynth()
16
+ return music, waveform
17
+
18
+
19
+ def show_piano_roll(music_notes, fs=100):
20
+ """
21
+ Show the piano roll of a music piece, with all instruments squashed onto a single 128xN matrix
22
+ :param music_notes: PrettyMIDI object
23
+ :param fs: sampling frequency
24
+ """
25
+ # get the piano roll
26
+ piano_roll = music_notes.get_piano_roll(fs)
27
+ print("Piano roll shape: {}".format(piano_roll.shape))
28
+
29
+ # plot the piano roll
30
+ plt.figure(figsize=(12, 4))
31
+ librosa.display.specshow(piano_roll, sr=100, x_axis="time", y_axis="cqt_note")
32
+ plt.colorbar()
33
+ plt.title("Piano roll")
34
+ plt.tight_layout()
35
+ plt.show()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ matplotlib
3
+ sys
4
+ matplotlib
5
+ numpy
utils.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from miditok import Event, MIDILike
3
+ import os
4
+ import json
5
+ from time import perf_counter
6
+ from joblib import Parallel, delayed
7
+ from zipfile import ZipFile, ZIP_DEFLATED
8
+ from scipy.io.wavfile import write
9
+ import numpy as np
10
+ from pydub import AudioSegment
11
+ import shutil
12
+
13
+
14
+ def writeToFile(path, content):
15
+ if type(content) is dict:
16
+ with open(f"{path}", "w") as json_file:
17
+ json.dump(content, json_file)
18
+ else:
19
+ if type(content) is not str:
20
+ content = str(content)
21
+ os.makedirs(os.path.dirname(path), exist_ok=True)
22
+ with open(path, "w") as f:
23
+ f.write(content)
24
+
25
+
26
+ # Function to read from text from txt file:
27
+ def readFromFile(path, isJSON=False):
28
+ with open(path, "r") as f:
29
+ if isJSON:
30
+ return json.load(f)
31
+ else:
32
+ return f.read()
33
+
34
+
35
+ def chain(input, funcs, *params):
36
+ res = input
37
+ for func in funcs:
38
+ try:
39
+ res = func(res, *params)
40
+ except TypeError:
41
+ res = func(res)
42
+ return res
43
+
44
+
45
+ def to_beat_str(value, beat_res=8):
46
+ values = [
47
+ int(int(value * beat_res) / beat_res),
48
+ int(int(value * beat_res) % beat_res),
49
+ beat_res,
50
+ ]
51
+ return ".".join(map(str, values))
52
+
53
+
54
+ def to_base10(beat_str):
55
+ integer, decimal, base = split_dots(beat_str)
56
+ return integer + decimal / base
57
+
58
+
59
+ def split_dots(value):
60
+ return list(map(int, value.split(".")))
61
+
62
+
63
+ def compute_list_average(l):
64
+ return sum(l) / len(l)
65
+
66
+
67
+ def get_datetime():
68
+ return datetime.now().strftime("%Y%m%d_%H%M%S")
69
+
70
+
71
+ def get_text(event):
72
+ match event.type:
73
+ case "Piece-Start":
74
+ return "PIECE_START "
75
+ case "Track-Start":
76
+ return "TRACK_START "
77
+ case "Track-End":
78
+ return "TRACK_END "
79
+ case "Instrument":
80
+ return f"INST={event.value} "
81
+ case "Bar-Start":
82
+ return "BAR_START "
83
+ case "Bar-End":
84
+ return "BAR_END "
85
+ case "Time-Shift":
86
+ return f"TIME_SHIFT={event.value} "
87
+ case "Note-On":
88
+ return f"NOTE_ON={event.value} "
89
+ case "Note-Off":
90
+ return f"NOTE_OFF={event.value} "
91
+ case _:
92
+ return ""
93
+
94
+
95
+ def get_event(text, value=None):
96
+ match text:
97
+ case "PIECE_START":
98
+ return Event("Piece-Start", value)
99
+ case "TRACK_START":
100
+ return None
101
+ case "TRACK_END":
102
+ return None
103
+ case "INST":
104
+ return Event("Instrument", value)
105
+ case "BAR_START":
106
+ return Event("Bar-Start", value)
107
+ case "BAR_END":
108
+ return Event("Bar-End", value)
109
+ case "TIME_SHIFT":
110
+ return Event("Time-Shift", value)
111
+ case "TIME_DELTA":
112
+ return Event("Time-Shift", to_beat_str(int(value) / 4))
113
+ case "NOTE_ON":
114
+ return Event("Note-On", value)
115
+ case "NOTE_OFF":
116
+ return Event("Note-Off", value)
117
+ case _:
118
+ return None
119
+
120
+
121
+ # TODO: Make this singleton
122
+ def get_miditok():
123
+ pitch_range = range(0, 140) # was (21, 109)
124
+ beat_res = {(0, 400): 8}
125
+ return MIDILike(pitch_range, beat_res)
126
+
127
+
128
+ class WriteTextMidiToFile: # utils saving to file
129
+ def __init__(self, generate_midi, output_path):
130
+ self.generated_midi = generate_midi.generated_piece
131
+ self.output_path = output_path
132
+ self.hyperparameter_and_bars = generate_midi.piece_by_track
133
+
134
+ def hashing_seq(self):
135
+ self.current_time = get_datetime()
136
+ self.output_path_filename = f"{self.output_path}/{self.current_time}.json"
137
+
138
+ def wrapping_seq_hyperparameters_in_dict(self):
139
+ # assert type(self.generated_midi) is str, "error: generate_midi must be a string"
140
+ # assert (
141
+ # type(self.hyperparameter_dict) is dict
142
+ # ), "error: feature_dict must be a dictionnary"
143
+ return {
144
+ "generate_midi": self.generated_midi,
145
+ "hyperparameters_and_bars": self.hyperparameter_and_bars,
146
+ }
147
+
148
+ def text_midi_to_file(self):
149
+ self.hashing_seq()
150
+ output_dict = self.wrapping_seq_hyperparameters_in_dict()
151
+ print(f"Token generate_midi written: {self.output_path_filename}")
152
+ writeToFile(self.output_path_filename, output_dict)
153
+ return self.output_path_filename
154
+
155
+
156
+ def get_files(directory, extension, recursive=False):
157
+ """
158
+ Given a directory, get a list of the file paths of all files matching the
159
+ specified file extension.
160
+ directory: the directory to search as a Path object
161
+ extension: the file extension to match as a string
162
+ recursive: whether to search recursively in the directory or not
163
+ """
164
+ if recursive:
165
+ return list(directory.rglob(f"*.{extension}"))
166
+ else:
167
+ return list(directory.glob(f"*.{extension}"))
168
+
169
+
170
+ def timeit(func):
171
+ def wrapper(*args, **kwargs):
172
+ start = perf_counter()
173
+ result = func(*args, **kwargs)
174
+ end = perf_counter()
175
+ print(f"{func.__name__} took {end - start:.2f} seconds to run.")
176
+ return result
177
+
178
+ return wrapper
179
+
180
+
181
+ class FileCompressor:
182
+ def __init__(self, input_directory, output_directory, n_jobs=-1):
183
+ self.input_directory = input_directory
184
+ self.output_directory = output_directory
185
+ self.n_jobs = n_jobs
186
+
187
+ # File compression and decompression
188
+ def unzip_file(self, file):
189
+ """uncompress single zip file"""
190
+ with ZipFile(file, "r") as zip_ref:
191
+ zip_ref.extractall(self.output_directory)
192
+
193
+ def zip_file(self, file):
194
+ """compress a single text file to a new zip file and delete the original"""
195
+ output_file = self.output_directory / (file.stem + ".zip")
196
+ with ZipFile(output_file, "w") as zip_ref:
197
+ zip_ref.write(file, arcname=file.name, compress_type=ZIP_DEFLATED)
198
+ file.unlink()
199
+
200
+ @timeit
201
+ def unzip(self):
202
+ """uncompress all zip files in folder"""
203
+ files = get_files(self.input_directory, extension="zip")
204
+ Parallel(n_jobs=self.n_jobs)(delayed(self.unzip_file)(file) for file in files)
205
+
206
+ @timeit
207
+ def zip(self):
208
+ """compress all text files in folder to new zip files and remove the text files"""
209
+ files = get_files(self.output_directory, extension="txt")
210
+ Parallel(n_jobs=self.n_jobs)(delayed(self.zip_file)(file) for file in files)
211
+
212
+
213
+ def load_jsonl(filepath):
214
+ """Load a jsonl file"""
215
+ with open(filepath, "r") as f:
216
+ data = [json.loads(line) for line in f]
217
+ return data
218
+
219
+
220
+ def write_mp3(waveform, output_path, bitrate="92k"):
221
+ """
222
+ Write a waveform to an mp3 file.
223
+ output_path: Path object for the output mp3 file
224
+ waveform: numpy array of the waveform
225
+ bitrate: bitrate of the mp3 file (64k, 92k, 128k, 256k, 312k)
226
+ """
227
+ # write the wav file
228
+ wav_path = output_path.with_suffix(".wav")
229
+ write(wav_path, 44100, waveform.astype(np.float32))
230
+ # compress the wav file as mp3
231
+ AudioSegment.from_wav(wav_path).export(output_path, format="mp3", bitrate=bitrate)
232
+ # remove the wav file
233
+ wav_path.unlink()
234
+
235
+
236
+ def copy_file(input_file, output_dir):
237
+ """Copy an input file to the output_dir"""
238
+ output_file = output_dir / input_file.name
239
+ shutil.copy(input_file, output_file)
240
+
241
+
242
+ def index_has_substring(list, substring):
243
+ for i, s in enumerate(list):
244
+ if substring in s:
245
+ return i
246
+ return -1