support cuda
Browse files
model.py
CHANGED
@@ -28,12 +28,13 @@ class AudioMAEConfig(PretrainedConfig):
|
|
28 |
|
29 |
|
30 |
class AudioMAEEncoder(VisionTransformer):
|
31 |
-
def __init__(self,
|
32 |
-
super().__init__(
|
33 |
"""
|
34 |
- img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed, as described in the paper
|
35 |
- AudoMAE accepts a mono-channel (i.e., in_chans=1)
|
36 |
"""
|
|
|
37 |
self.MEAN = -4.2677393 # written on the paper
|
38 |
self.STD = 4.5689974 # written on the paper
|
39 |
|
@@ -101,7 +102,7 @@ class AudioMAEEncoder(VisionTransformer):
|
|
101 |
waveform = self.load_wav_file(file_path)
|
102 |
melspec = self.waveform_to_melspec(waveform) # (length, n_freq_bins) = (1024, 128)
|
103 |
melspec = melspec[None,None,:,:] # (1, 1, length, n_freq_bins) = (1, 1, 1024, 128)
|
104 |
-
z = self.forward_features(melspec) # (b, 1+n, d); d=768
|
105 |
z = z[:,1:,:] # (b n d); remove [CLS], the class token
|
106 |
|
107 |
b, c, w, h = melspec.shape # w: temporal dim; h:freq dim
|
@@ -122,7 +123,7 @@ class PretrainedAudioMAEEncoder(PreTrainedModel):
|
|
122 |
|
123 |
def __init__(self, config):
|
124 |
super().__init__(config)
|
125 |
-
self.encoder = AudioMAEEncoder(img_size=config.img_size, in_chans=config.in_chans, num_classes=config.num_classes)
|
126 |
|
127 |
def forward(self, file_path:str):
|
128 |
return self.encoder.encode(file_path) # (d h' w')
|
|
|
28 |
|
29 |
|
30 |
class AudioMAEEncoder(VisionTransformer):
|
31 |
+
def __init__(self, img_size, in_chans, num_classes, device):
|
32 |
+
super().__init__(img_size, in_chans, num_classes)
|
33 |
"""
|
34 |
- img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed, as described in the paper
|
35 |
- AudoMAE accepts a mono-channel (i.e., in_chans=1)
|
36 |
"""
|
37 |
+
self.device = device
|
38 |
self.MEAN = -4.2677393 # written on the paper
|
39 |
self.STD = 4.5689974 # written on the paper
|
40 |
|
|
|
102 |
waveform = self.load_wav_file(file_path)
|
103 |
melspec = self.waveform_to_melspec(waveform) # (length, n_freq_bins) = (1024, 128)
|
104 |
melspec = melspec[None,None,:,:] # (1, 1, length, n_freq_bins) = (1, 1, 1024, 128)
|
105 |
+
z = self.forward_features(melspec.to(self.device)).cpu() # (b, 1+n, d); d=768
|
106 |
z = z[:,1:,:] # (b n d); remove [CLS], the class token
|
107 |
|
108 |
b, c, w, h = melspec.shape # w: temporal dim; h:freq dim
|
|
|
123 |
|
124 |
def __init__(self, config):
|
125 |
super().__init__(config)
|
126 |
+
self.encoder = AudioMAEEncoder(img_size=config.img_size, in_chans=config.in_chans, num_classes=config.num_classes, device=self.device)
|
127 |
|
128 |
def forward(self, file_path:str):
|
129 |
return self.encoder.encode(file_path) # (d h' w')
|