Commit 1a2cf86 by dslee2601
Parent(s): c60507b

support cuda
Files changed (1): model.py (+7 -7)
model.py CHANGED
@@ -28,13 +28,12 @@ class AudioMAEConfig(PretrainedConfig):
 
 
 class AudioMAEEncoder(VisionTransformer):
-    def __init__(self, img_size, in_chans, num_classes, device):
-        super().__init__(img_size, in_chans, num_classes)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
         """
         - img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed, as described in the paper
         - AudioMAE accepts a mono-channel (i.e., in_chans=1)
         """
-        self.device = device
         self.MEAN = -4.2677393  # written in the paper
         self.STD = 4.5689974  # written in the paper
 
@@ -96,13 +95,13 @@ class AudioMAEEncoder(VisionTransformer):
         return mel_spectrogram
 
     @torch.no_grad()
-    def encode(self, file_path: str):
+    def encode(self, file_path: str, device):
         self.eval()
 
         waveform = self.load_wav_file(file_path)
         melspec = self.waveform_to_melspec(waveform)  # (length, n_freq_bins) = (1024, 128)
         melspec = melspec[None, None, :, :]  # (1, 1, length, n_freq_bins) = (1, 1, 1024, 128)
-        z = self.forward_features(melspec.to(self.device)).cpu()  # (b, 1+n, d); d=768
+        z = self.forward_features(melspec.to(device)).cpu()  # (b, 1+n, d); d=768
         z = z[:, 1:, :]  # (b, n, d); remove [CLS], the class token
 
         b, c, w, h = melspec.shape  # w: temporal dim; h: freq dim
@@ -123,7 +122,8 @@ class PretrainedAudioMAEEncoder(PreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
-        self.encoder = AudioMAEEncoder(img_size=config.img_size, in_chans=config.in_chans, num_classes=config.num_classes, device=self.device)
+        self.encoder = AudioMAEEncoder(img_size=config.img_size, in_chans=config.in_chans, num_classes=config.num_classes)
 
     def forward(self, file_path: str):
-        return self.encoder.encode(file_path)  # (d, h', w')
+        device = self.device
+        return self.encoder.encode(file_path, device)  # (d, h', w')
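
The net effect of the commit: the encoder no longer caches a device at construction time. Instead, forward() reads PreTrainedModel's device property, which tracks wherever the module's parameters currently live, and passes it down to encode(). A minimal usage sketch follows; the import path, the config defaults, and the wav file name are assumptions for illustration, not part of this commit:

    import torch
    from model import AudioMAEConfig, PretrainedAudioMAEEncoder  # assumed import path

    config = AudioMAEConfig()  # assumes the config's defaults supply img_size, in_chans, num_classes
    encoder = PretrainedAudioMAEEncoder(config)
    encoder = encoder.to("cuda" if torch.cuda.is_available() else "cpu")

    # forward() resolves self.device at call time, so encode() moves the mel
    # spectrogram to the same device as the weights before forward_features().
    z = encoder("some_audio.wav")  # hypothetical file; yields a (d, h', w') feature map

Resolving the device per call, rather than storing it in __init__, is what lets .to("cuda") work after construction; the previously cached self.device was fixed when the module was built and would go stale once the module moved.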