File size: 9,227 Bytes
d8fad2b
 
 
 
 
 
 
7c7327d
d8fad2b
 
 
 
 
 
7c7327d
 
 
 
a322e01
7c7327d
 
 
d8fad2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a322e01
d8fad2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a322e01
d8fad2b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
from typing import List, Optional, Tuple

import gradio as gr
from matplotlib.figure import Figure
import note_seq
from numpy import ndarray
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from constants import GM_INSTRUMENTS, SAMPLE_RATE
from string_to_notes import token_sequence_to_note_sequence
from model import get_model_and_tokenizer


# Run generation on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer and causal LM are loaded once at import time; every generation
# function below reads these module-level globals.
# NOTE(review): `get_model_and_tokenizer` is imported from `model` but appears
# unused in this file — consider using it here or dropping the import.
tokenizer = AutoTokenizer.from_pretrained("juancopi81/lmd_8bars_tokenizer")
model = AutoModelForCausalLM.from_pretrained(
    "juancopi81/lmd-8bars-2048-epochs20_v3"
).to(device)


def create_seed_string(genre: str = "OTHER") -> str:
    """
    Builds the token prefix used to start generating a brand-new piece.

    Args:
        genre (str, optional): Genre label inserted as the ``GENRE=`` token. Defaults to "OTHER".

    Returns:
        str: The space-separated tokens ``PIECE_START GENRE=<genre> TRACK_START``.
    """
    return " ".join(("PIECE_START", f"GENRE={genre}", "TRACK_START"))


def get_instruments(text_sequence: str) -> List[str]:
    """
    Lists the instrument of every ``INST=`` token in a song string, in order.

    Args:
        text_sequence (str): Whitespace-separated song tokens.

    Returns:
        List[str]: "Drums" for ``INST=DRUMS``; otherwise the ``GM_INSTRUMENTS``
        entry indexed by the integer after ``INST=``.
    """
    prefix = "INST="
    names: List[str] = []
    for token in text_sequence.split():
        if not token.startswith(prefix):
            continue
        value = token[len(prefix):]
        # Drum tracks have no General MIDI program number, so they use a literal label.
        names.append("Drums" if value == "DRUMS" else GM_INSTRUMENTS[int(value)])
    return names


def generate_new_instrument(
    seed: str, temp: float = 0.75, max_attempts: Optional[int] = None
) -> str:
    """
    Generates a new instrument sequence from a given seed and temperature.

    Sampling is repeated until the model emits at least one "NOTE_ON" token
    after the seed, so that the newly generated part is not empty.

    Args:
        seed (str): The seed string for the generation.
        temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.
        max_attempts (Optional[int], optional): Maximum number of sampling attempts. ``None`` (the default)
            retries indefinitely, as before.

    Returns:
        str: The full decoded sequence, i.e. the seed followed by the newly generated tokens.

    Raises:
        RuntimeError: If ``max_attempts`` is given and none of the attempts produced a "NOTE_ON" token.
    """
    # The seed encoding and the stop-token id are identical for every attempt,
    # so compute them once rather than on each retry.
    input_ids = tokenizer.encode(seed, return_tensors="pt").to(model.device)
    seed_length = input_ids.shape[-1]
    eos_token_id = tokenizer.encode("TRACK_END")[0]

    attempt = 0
    while max_attempts is None or attempt < max_attempts:
        attempt += 1
        # Sampling stops at the first TRACK_END token or after 2048 new tokens.
        # NOTE(review): seed_length + 2048 may exceed the model's maximum
        # context for long seeds — confirm against the model config.
        generated_ids = model.generate(
            input_ids,
            max_new_tokens=2048,
            do_sample=True,
            temperature=temp,
            eos_token_id=eos_token_id,
        )
        # Only inspect tokens after the seed: the seed itself may already hold
        # NOTE_ON events from previously generated tracks.
        new_part = tokenizer.decode(generated_ids[0][seed_length:])
        if "NOTE_ON" in new_part:
            return tokenizer.decode(generated_ids[0])

    raise RuntimeError(
        f"No NOTE_ON token was generated after {max_attempts} attempt(s)"
    )


def get_outputs_from_string(
    generated_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str]:
    """
    Renders a token sequence into every artifact shown by the UI.

    The sequence is converted to a NoteSequence at the given tempo, synthesized
    with FluidSynth, plotted, and written out as a MIDI file.

    Args:
        generated_sequence (str): Whitespace-separated song tokens.
        qpm (int, optional): Tempo in quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str]: ``(audio, midi_path, fig, instruments_str, num_tokens)``:
            ``audio`` is whatever ``gr.make_waveform`` returns for the synthesized 16-bit audio
            (NOTE(review): Gradio documents this as a path to a waveform video — confirm against the
            annotation); ``midi_path`` is the MIDI file written to the current working directory;
            ``fig`` is the object returned by ``note_seq.plot_sequence``; ``instruments_str`` has one
            ``"- <instrument>"`` line per track; ``num_tokens`` is the token count as a string.
    """
    midi_path = "midi_ouput.mid"
    listed = "\n".join(f"- {name}" for name in get_instruments(generated_sequence))
    token_count = str(len(generated_sequence.split()))

    sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
    samples = note_seq.fluidsynth(sequence, sample_rate=SAMPLE_RATE)
    pcm16 = note_seq.audio_io.float_samples_to_int16(samples)
    figure = note_seq.plot_sequence(sequence, show_figure=False)
    waveform = gr.make_waveform((SAMPLE_RATE, pcm16))
    # Fixed relative path: each call overwrites the previous MIDI file.
    note_seq.note_sequence_to_midi_file(sequence, midi_path)

    return waveform, midi_path, figure, listed, token_count


def remove_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Removes the last instrument from a song string and returns the various output formats.

    If removing the last track would leave no instrument, a new track is generated
    instead. It is seeded with the song header when one exists, and from scratch
    otherwise. The requested tempo is used in every case.

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
                                                    instruments string, new song string, and number of tokens string.
    """
    # Splitting on 'TRACK_START' gives the song header followed by one chunk per track.
    tracks = text_sequence.split("TRACK_START")

    if len(tracks) == 1:
        # No track at all, so start from an empty sequence.
        # qpm is forwarded so the user's tempo is not silently reset to the default.
        return generate_song(text_sequence="", qpm=qpm)

    # Drop the last track and rejoin, restoring the 'TRACK_START' separators removed by split.
    new_song = "TRACK_START".join(tracks[:-1])

    if len(tracks) == 2:
        # Only one instrument: regenerate from the header alone (e.g. "PIECE_START GENRE=... "),
        # which keeps the song's genre.
        return generate_song(text_sequence=new_song, qpm=qpm)

    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        new_song, qpm
    )
    return audio, midi_file, fig, instruments_str, new_song, num_tokens


def regenerate_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Regenerates the last instrument in a song string and returns the various output formats.

    The song is truncated just after its last ``INST=`` token, and generation
    continues from there, so the last track is re-sampled with the same instrument.

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
                                                    instruments string, new song string, and number of tokens string.
    """
    last_inst_index = text_sequence.rfind("INST=")
    if last_inst_index == -1:
        # No instrument, so start from an empty sequence.
        return generate_song(text_sequence="", qpm=qpm)

    # Keep everything up to and including the last INST=<...> token.
    token_end = text_sequence.find(" ", last_inst_index)
    if token_end == -1:
        # The INST= token is the final token. Slicing with -1 would chop its
        # last character (e.g. "INST=12" -> "INST=1"), so keep the whole string.
        token_end = len(text_sequence)
    new_seed = text_sequence[:token_end]
    return generate_song(text_sequence=new_seed, qpm=qpm)


def change_tempo(
    text_sequence: str, qpm: int
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Re-renders an existing song string at a different tempo.

    The token sequence itself is not modified; it is passed back unchanged in
    the fifth position so the UI keeps the same song text.

    Args:
        text_sequence (str): The song string.
        qpm (int): The new tempo in quarter notes per minute.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
                                                    instruments string, the unchanged text sequence,
                                                    and number of tokens string.
    """
    rendered = get_outputs_from_string(text_sequence, qpm=qpm)
    audio, midi_file, fig, instruments_str, num_tokens = rendered
    return (audio, midi_file, fig, instruments_str, text_sequence, num_tokens)


def generate_song(
    genre: str = "OTHER",
    temp: float = 0.75,
    text_sequence: str = "",
    qpm: int = 120,
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Generates a song given a genre, temperature, initial text sequence, and tempo.

    One new track is generated on top of ``text_sequence``. When that sequence
    is empty, a fresh seed is built from ``genre`` instead.

    Args:
        genre (str, optional): The genre of the song; only used when ``text_sequence`` is empty. Defaults to "OTHER".
        temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.
        text_sequence (str, optional): The initial text sequence for the song. Defaults to "".
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
                                                    instruments string, generated song string, and number of tokens string.
    """
    # Treat None and whitespace-only input like an empty sequence: such text
    # contains no song tokens to continue from.
    if text_sequence is None or not text_sequence.strip():
        seed_string = create_seed_string(genre)
    else:
        seed_string = text_sequence

    generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)
    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        generated_sequence, qpm
    )
    return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens