asigalov61 commited on
Commit
465603c
·
verified ·
1 Parent(s): 2020c08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -54,6 +54,8 @@ print('=' * 70)
54
 
55
  #==================================================================================
56
 
 
 
57
  SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
58
 
59
  NUM_OUT_BATCHES = 8
@@ -90,9 +92,7 @@ model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
90
  print('=' * 70)
91
  print('Loading model checkpoint...')
92
 
93
- model_path = 'Giant_Music_Transformer_Medium_Trained_Model_20355_steps_0.709_loss_0.812_acc.pth'
94
-
95
- model.load_state_dict(torch.load(model_path, map_location='cpu'))
96
 
97
  print('=' * 70)
98
  print('Done!')
 
54
 
55
  #==================================================================================
56
 
57
+ MODEL_CHECKPOINT = 'Giant_Music_Transformer_Medium_Trained_Model_25603_steps_0.3799_loss_0.8934_acc.pth'
58
+
59
  SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
60
 
61
  NUM_OUT_BATCHES = 8
 
92
  print('=' * 70)
93
  print('Loading model checkpoint...')
94
 
95
+ model.load_state_dict(torch.load(MODEL_CHECKPOINT, map_location='cpu'))
 
 
96
 
97
  print('=' * 70)
98
  print('Done!')