#================================================================================== # https://huggingface.co/spaces/asigalov61/Ultimate-Chords-Progressions-Transformer #================================================================================== import time as reqtime import datetime from pytz import timezone import torch import spaces import gradio as gr from x_transformer_1_23_2 import * import random import statistics import copy import tqdm from midi_to_colab_audio import midi_to_colab_audio import TMIDIX import matplotlib.pyplot as plt # ================================================================================================= # @spaces.GPU def Generate_Chords(input_midi, input_num_prime_chords, input_num_gen_chords): print('=' * 70) print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT))) start_time = reqtime.time() print('=' * 70) print('Instantiating the model...') SEQ_LEN = 8192 PAD_IDX = 2239 DEVICE = 'cpu' # 'cpu' # instantiate the model model = TransformerWrapper( num_tokens = PAD_IDX+1, max_seq_len = SEQ_LEN, attn_layers = Decoder(dim = 2048, depth = 8, heads = 32, rotary_pos_emb = True, attn_flash = True ) ) model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX) model.to(DEVICE) print('Done!') print('=' * 70) print('Loading model checkpoint...') model.load_state_dict( torch.load('Ultimate_Chords_Progressions_Transformer_Trained_Model_LAX_5858_steps_0.4506_loss_0.8724_acc.pth', map_location=DEVICE)) model.eval() print('Done!') print('=' * 70) if DEVICE == 'cpu': dtype = torch.bfloat16 else: dtype = torch.bfloat16 ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype) print('Done!') print('=' * 70) fn = os.path.basename(input_midi.name) fn1 = fn.split('.')[0] print('=' * 70) print('Input file name:', fn) print('Num prime chords:', input_num_prime_chords) print('Num gen chords:', input_num_gen_chords) print('=' * 70) #=============================================================================== raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name) escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True) if escore_notes: escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0], timings_divider=32, legacy_timings=True) if escore_notes: #======================================================= # PRE-PROCESSING # checking number of instruments in a composition instruments_list = sorted(set([e[6] for e in escore_notes])) instruments_list_without_drums = sorted(set([e[6] for e in escore_notes if e[3] != 9])) main_instruments_list = sorted(set([e[6] for e in escore_notes if e[6] < 80])) comp_times = [e[1] for e in escore_notes if e[6] < 80] comp_dtimes = [max(1, min(127, b-a)) for a, b in zip(comp_times[:-1], comp_times[1:]) if b-a != 0] avg_comp_dtime = max(0, min(127, int(sum(comp_dtimes) / len(comp_dtimes)))) #======================================================= # FINAL PROCESSING #======================================================= # Adjusting avg velocity vels = [e[5] for e in escore_notes] avg_vel = int(sum(vels) / len(vels)) if avg_vel < 60: TMIDIX.adjust_score_velocities(escore_notes, avg_vel * 2) melody_chords = [] melody_chords2 = [] mel_cho = [] #======================================================= # Break between compositions / Intro seq if 128 in instruments_list: drums_present = 1931 # Yes else: drums_present = 1930 # No melody_chords.extend([1929, drums_present]) mel_cho.extend([1929, drums_present]) #======================================================= # Composition patches list melody_chords.extend([i+1932 for i in instruments_list_without_drums]) mel_cho.extend([i+1932 for i in instruments_list_without_drums]) #======================================================= # Composition avg pitch and dtime mode_instruments_pitch = statistics.mode([e[4] for e in escore_notes if e[6] < 80]) melody_chords.extend([2060+mode_instruments_pitch, 2188+avg_comp_dtime]) mel_cho.extend([2060+mode_instruments_pitch, 2188+avg_comp_dtime]) melody_chords2.append(mel_cho) #======================================================= # MAIN PROCESSING CYCLE #======================================================= cscore = TMIDIX.chordify_score([1000, escore_notes]) pc = cscore[0] # Previous chord for i, c in enumerate(cscore): c.sort(key=lambda x: x[6]) # Sorting by patch #======================================================= # Outro seq #if len(cscore) > 256: # if len(cscore) - i == 64: # melody_chords.extend([2236]) #======================================================= # Timings... # Cliping all values... delta_time = max(0, min(127, c[0][1]-pc[0][1])) #======================================================= # Chords... cpitches = sorted([e[4] for e in c if e[3] != 9]) dpitches = [e[4] for e in c if e[3] == 9] tones_chord = sorted(set([p % 12 for p in cpitches])) if tones_chord: if tones_chord not in TMIDIX.ALL_CHORDS_SORTED: tones_chord_tok = 644 tones_chord_tok = TMIDIX.ALL_CHORDS_SORTED.index(TMIDIX.advanced_check_and_fix_tones_chord(tones_chord, cpitches[-1])) else: tones_chord_tok = TMIDIX.ALL_CHORDS_SORTED.index(tones_chord) # 321 if dpitches: if tones_chord_tok == 644: tones_chord_tok = 645 else: tones_chord_tok += 321 else: tones_chord_tok = 643 # Drums-only chord #======================================================= # Writing chord/time... melody_chords.extend([tones_chord_tok, delta_time+646]) mel_cho = [] mel_cho.extend([tones_chord_tok, delta_time+646]) #======================================================= # Notes... pp = -1 for e in c: #======================================================= # Duration dur = max(0, min(63, int(max(0, e[2] // 4) * 2))) # Pitch ptc = max(1, min(127, e[4])) # Octo-velocity vel = max(8, min(127, (max(1, e[5] // 8) * 8))) velocity = round(vel / 15)-1 # Patch pat = max(0, min(128, e[6])) if 7 < pat < 80: ptc += 128 elif 79 < pat < 128: ptc += 256 elif pat == 128: ptc += 384 #======================================================= # FINAL NOTE SEQ # Writing final note asynchronously dur_vel = (8 * dur) + velocity # 512 if pat != pp: melody_chords.extend([pat+774, ptc+904, dur_vel+1416]) # 1928 mel_cho.extend([pat+774, ptc+904, dur_vel+1416]) else: melody_chords.extend([ptc+904, dur_vel+1416]) mel_cho.extend([ptc+904, dur_vel+1416]) pp = pat pc = c melody_chords2.append(mel_cho) #======================================================= #melody_chords.extend([2237]) # EOS #======================================================= # TOTAL DICTIONARY SIZE 2237+1=2238 #======================================================= print('Done!') print('=' * 70) print('Melody chords length:', len(melody_chords)) print('=' * 70) #================================================================== print('=' * 70) print('Sample output events', melody_chords[:12]) print('=' * 70) print('Generating...') output = [] for m in melody_chords2[:input_num_prime_chords]: output.extend(m) for ct in tqdm.tqdm(melody_chords2[input_num_prime_chords:input_num_prime_chords+input_num_gen_chords]): output.extend(ct[:2]) y = 774 while y > 773: x = torch.LongTensor(output).to(DEVICE) with ctx: out = model.generate(x, 1, filter_logits_fn=top_p, filter_kwargs={'thres': 0.96}, temperature=0.9, return_prime=False, verbose=False) y = out.tolist()[0][0] if y > 773: output.append(y) print('=' * 70) print('Done!') print('=' * 70) #=============================================================================== print('Rendering results...') print('=' * 70) print('Sample INTs', output[:12]) print('=' * 70) if len(output) != 0: song = output song_f = [] time = 0 dur = 4 vel = 90 pitch = 60 channel = 0 patches = [0] * 16 patches[9] = 9 for ss in song: if 645 < ss < 774: time += (ss-646) if 773 < ss < 904: pat = (ss - 774) chan = (pat // 8) if 0 <= chan < 9: channel = chan elif 8 < chan < 15: channel = chan + 1 elif chan == 16: channel = 9 if 903 < ss < 1416: pitch = (ss-904) % 128 if 1415 < ss < 1928: dur = (((ss-1416) // 8)+1) * 2 vel = (((ss-1416) % 8)+1) * 15 song_f.append(['note', time, dur, channel, pitch, vel, pat]) song_f, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f) fn1 = "Ultimate-Chords-Progressions-Transformer-Composition" detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f, output_signature = 'Ultimate Chords Progressions Transformer', output_file_name = fn1, track_name='Project Los Angeles', list_of_MIDI_patches=patches, timings_multiplier=32 ) new_fn = fn1+'.mid' audio = midi_to_colab_audio(new_fn, soundfont_path=soundfont, sample_rate=16000, volume_scale=10, output_for_gradio=True ) print('Done!') print('=' * 70) #======================================================== output_midi_title = str(fn1) output_midi_summary = str(song_f[:3]) output_midi = str(new_fn) output_audio = (16000, audio) output_plot = TMIDIX.plot_ms_SONG(song_f, plot_title=output_midi, return_plt=True, timings_multiplier=32) print('Output MIDI file name:', output_midi) print('Output MIDI title:', output_midi_title) print('Output MIDI summary:', '') print('=' * 70) #======================================================== print('=' * 70) print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT))) print('=' * 70) print('Req execution time:', (reqtime.time() - start_time), 'sec') print('*' * 70) return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot # ================================================================================================= if __name__ == "__main__": PDT = timezone('US/Pacific') print('=' * 70) print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT))) print('=' * 70) soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2" app = gr.Blocks() with app: gr.Markdown("