asigalov61 commited on
Commit
0792426
·
verified ·
1 Parent(s): 5bb58b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -44
app.py CHANGED
@@ -57,6 +57,46 @@ NUM_OUT_BATCHES = 8
57
 
58
  #==================================================================================
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def load_midi(input_midi):
61
 
62
  raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
@@ -219,52 +259,8 @@ def generate_music(prime,
219
  model_sampling_top_p
220
  ):
221
 
222
-
223
- #==============================================================================
224
-
225
- print('=' * 70)
226
- print('Instantiating model...')
227
-
228
- device_type = 'cuda'
229
- dtype = 'bfloat16'
230
-
231
- ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
232
- ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
233
-
234
- SEQ_LEN = 8192
235
- PAD_IDX = 19463
236
-
237
- model = TransformerWrapper(
238
- num_tokens = PAD_IDX+1,
239
- max_seq_len = SEQ_LEN,
240
- attn_layers = Decoder(dim = 2048,
241
- depth = 8,
242
- heads = 32,
243
- rotary_pos_emb = True,
244
- attn_flash = True
245
- )
246
- )
247
-
248
- model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
249
-
250
- print('=' * 70)
251
- print('Loading model checkpoint...')
252
-
253
- model_path = 'Giant_Music_Transformer_Medium_Trained_Model_10446_steps_0.7202_loss_0.8233_acc.pth'
254
-
255
- model.load_state_dict(torch.load(model_path))
256
-
257
- print('=' * 70)
258
-
259
  model.cuda()
260
  model.eval()
261
-
262
- print('Done!')
263
- print('=' * 70)
264
- print('Model will use', dtype, 'precision...')
265
- print('=' * 70)
266
-
267
- #==============================================================================
268
 
269
  print('Generating...')
270
 
 
57
 
58
  #==================================================================================
59
 
60
+ print('=' * 70)
61
+ print('Instantiating model...')
62
+
63
+ device_type = 'cuda'
64
+ dtype = 'bfloat16'
65
+
66
+ ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
67
+ ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
68
+
69
+ SEQ_LEN = 8192
70
+ PAD_IDX = 19463
71
+
72
+ model = TransformerWrapper(
73
+ num_tokens = PAD_IDX+1,
74
+ max_seq_len = SEQ_LEN,
75
+ attn_layers = Decoder(dim = 2048,
76
+ depth = 8,
77
+ heads = 32,
78
+ rotary_pos_emb = True,
79
+ attn_flash = True
80
+ )
81
+ )
82
+
83
+ model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
84
+
85
+ print('=' * 70)
86
+ print('Loading model checkpoint...')
87
+
88
+ model_path = 'Giant_Music_Transformer_Medium_Trained_Model_10446_steps_0.7202_loss_0.8233_acc.pth'
89
+
90
+ model.load_state_dict(torch.load(model_path, map_location='cpu'))
91
+
92
+ print('=' * 70)
93
+ print('Done!')
94
+ print('=' * 70)
95
+ print('Model will use', dtype, 'precision...')
96
+ print('=' * 70)
97
+
98
+ #==================================================================================
99
+
100
  def load_midi(input_midi):
101
 
102
  raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
 
259
  model_sampling_top_p
260
  ):
261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  model.cuda()
263
  model.eval()
 
 
 
 
 
 
 
264
 
265
  print('Generating...')
266