dslee2601 commited on
Commit
c60507b
·
1 Parent(s): 2fccb59

support cuda

Browse files
Files changed (1) hide show
  1. model.py +5 -4
model.py CHANGED
@@ -28,12 +28,13 @@ class AudioMAEConfig(PretrainedConfig):
28
 
29
 
30
  class AudioMAEEncoder(VisionTransformer):
31
- def __init__(self, *args, **kwargs):
32
- super().__init__(*args, **kwargs)
33
  """
34
  - img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed, as described in the paper
35
  - AudoMAE accepts a mono-channel (i.e., in_chans=1)
36
  """
 
37
  self.MEAN = -4.2677393 # written on the paper
38
  self.STD = 4.5689974 # written on the paper
39
 
@@ -101,7 +102,7 @@ class AudioMAEEncoder(VisionTransformer):
101
  waveform = self.load_wav_file(file_path)
102
  melspec = self.waveform_to_melspec(waveform) # (length, n_freq_bins) = (1024, 128)
103
  melspec = melspec[None,None,:,:] # (1, 1, length, n_freq_bins) = (1, 1, 1024, 128)
104
- z = self.forward_features(melspec) # (b, 1+n, d); d=768
105
  z = z[:,1:,:] # (b n d); remove [CLS], the class token
106
 
107
  b, c, w, h = melspec.shape # w: temporal dim; h:freq dim
@@ -122,7 +123,7 @@ class PretrainedAudioMAEEncoder(PreTrainedModel):
122
 
123
  def __init__(self, config):
124
  super().__init__(config)
125
- self.encoder = AudioMAEEncoder(img_size=config.img_size, in_chans=config.in_chans, num_classes=config.num_classes)
126
 
127
  def forward(self, file_path:str):
128
  return self.encoder.encode(file_path) # (d h' w')
 
28
 
29
 
30
  class AudioMAEEncoder(VisionTransformer):
31
+ def __init__(self, img_size, in_chans, num_classes, device):
32
+ super().__init__(img_size, in_chans, num_classes)
33
  """
34
  - img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed, as described in the paper
35
  - AudoMAE accepts a mono-channel (i.e., in_chans=1)
36
  """
37
+ self.device = device
38
  self.MEAN = -4.2677393 # written on the paper
39
  self.STD = 4.5689974 # written on the paper
40
 
 
102
  waveform = self.load_wav_file(file_path)
103
  melspec = self.waveform_to_melspec(waveform) # (length, n_freq_bins) = (1024, 128)
104
  melspec = melspec[None,None,:,:] # (1, 1, length, n_freq_bins) = (1, 1, 1024, 128)
105
+ z = self.forward_features(melspec.to(self.device)).cpu() # (b, 1+n, d); d=768
106
  z = z[:,1:,:] # (b n d); remove [CLS], the class token
107
 
108
  b, c, w, h = melspec.shape # w: temporal dim; h:freq dim
 
123
 
124
  def __init__(self, config):
125
  super().__init__(config)
126
+ self.encoder = AudioMAEEncoder(img_size=config.img_size, in_chans=config.in_chans, num_classes=config.num_classes, device=self.device)
127
 
128
  def forward(self, file_path:str):
129
  return self.encoder.encode(file_path) # (d h' w')