support cuda
model.py CHANGED
@@ -28,13 +28,12 @@ class AudioMAEConfig(PretrainedConfig):


 class AudioMAEEncoder(VisionTransformer):
-    def __init__(self,
-        super().__init__(
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
         """
         - img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed, as described in the paper
         - AudioMAE accepts a mono-channel (i.e., in_chans=1)
         """
-        self.device = device
         self.MEAN = -4.2677393 # written on the paper
         self.STD = 4.5689974 # written on the paper

@@ -96,13 +95,13 @@ class AudioMAEEncoder(VisionTransformer):
         return mel_spectrogram

     @torch.no_grad()
-    def encode(self, file_path:str):
+    def encode(self, file_path:str, device):
         self.eval()

         waveform = self.load_wav_file(file_path)
         melspec = self.waveform_to_melspec(waveform) # (length, n_freq_bins) = (1024, 128)
         melspec = melspec[None,None,:,:] # (1, 1, length, n_freq_bins) = (1, 1, 1024, 128)
-        z = self.forward_features(melspec.to(self.device)).cpu() # (b, 1+n, d); d=768
+        z = self.forward_features(melspec.to(device)).cpu() # (b, 1+n, d); d=768
         z = z[:,1:,:] # (b n d); remove [CLS], the class token

         b, c, w, h = melspec.shape # w: temporal dim; h: freq dim
@@ -123,7 +122,8 @@ class PretrainedAudioMAEEncoder(PreTrainedModel):

     def __init__(self, config):
         super().__init__(config)
-        self.encoder = AudioMAEEncoder(img_size=config.img_size, in_chans=config.in_chans, num_classes=config.num_classes
+        self.encoder = AudioMAEEncoder(img_size=config.img_size, in_chans=config.in_chans, num_classes=config.num_classes)

     def forward(self, file_path:str):
-        return self.encoder.encode(file_path) # (d h' w')
+        device = self.device
+        return self.encoder.encode(file_path, device) # (d h' w')
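In effect, the device is no longer pinned in AudioMAEEncoder.__init__; PretrainedAudioMAEEncoder.forward reads self.device (which PreTrainedModel derives from the parameters) and threads it into encode(), so a plain .to("cuda") on the model is enough. A minimal usage sketch under that assumption; the checkpoint id and audio path below are placeholders, and it assumes the repo maps the class for AutoModel via auto_map (otherwise call PretrainedAudioMAEEncoder.from_pretrained directly):

import torch
from transformers import AutoModel

# Placeholder repo id; substitute the repo that ships this model.py.
model = AutoModel.from_pretrained("user/audiomae-encoder", trust_remote_code=True)

# PreTrainedModel.device tracks the parameters, so moving the model is all
# that's needed: forward() will pass self.device to encoder.encode(), which
# moves the mel spectrogram onto that device before forward_features().
model.to("cuda" if torch.cuda.is_available() else "cpu")

z = model("example.wav")  # placeholder path; encode() returns a CPU tensor, (d, h', w')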