mrfakename commited on
Commit
fde49fc
·
verified ·
1 Parent(s): 2969d5b

Update musiclib.py

Browse files
Files changed (1) hide show
  1. musiclib.py +4 -3
musiclib.py CHANGED
@@ -1,5 +1,5 @@
1
  # apache 2.0 license, modified by mrfakename, from https://github.com/BlinkDL/ChatRWKV/tree/main/music
2
-
3
  import os, sys
4
  import numpy as np
5
  from cached_path import cached_path
@@ -16,8 +16,9 @@ MODEL_FILE = str(cached_path('hf://BlinkDL/rwkv-4-music/RWKV-4-MIDI-120M-v1-2023
16
 
17
  ABC_MODE = ('-ABC-' in MODEL_FILE)
18
  MIDI_MODE = ('-MIDI-' in MODEL_FILE)
19
-
20
- model = RWKV(model=MODEL_FILE, strategy='mps fp32')
 
21
  pipeline = PIPELINE(model, "tokenizer-midi.json")
22
 
23
  tokenizer = pipeline
 
1
  # apache 2.0 license, modified by mrfakename, from https://github.com/BlinkDL/ChatRWKV/tree/main/music
2
+ import torch
3
  import os, sys
4
  import numpy as np
5
  from cached_path import cached_path
 
16
 
17
  ABC_MODE = ('-ABC-' in MODEL_FILE)
18
  MIDI_MODE = ('-MIDI-' in MODEL_FILE)
19
+ device = 'cpu'
20
+ if torch.cuda.is_available(): device = 'cuda'
21
+ model = RWKV(model=MODEL_FILE, strategy=f'{device} fp32')
22
  pipeline = PIPELINE(model, "tokenizer-midi.json")
23
 
24
  tokenizer = pipeline