|
import warnings

import torch
import torchaudio
import torchaudio.transforms as transforms
from torchaudio.compliance import kaldi

from einops import rearrange

from timm.models.vision_transformer import VisionTransformer

from transformers import PreTrainedModel

from model_config import AudioMAEConfig
|
|
|
|
|
class AudioMAEEncoder(VisionTransformer):
    """Vision Transformer encoder applied to log-mel spectrograms (AudioMAE).

    Construction assumptions:

    - ``img_size=(1024, 128)`` = ``(temporal_length, n_freq_bins)`` is fixed,
      as described in the paper; :meth:`waveform_to_melspec` always produces
      a ``TARGET_FRAMES x NUM_MEL_BINS`` (1024 x 128) spectrogram.
    - AudioMAE accepts a mono-channel input (i.e. ``in_chans=1``).
    """

    # Preprocessing parameters shared by `load_wav_file` and `waveform_to_melspec`.
    SAMPLE_RATE = 16000
    TARGET_FRAMES = 1024
    NUM_MEL_BINS = 128
    MAX_AUDIO_SECONDS = 10.0

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Normalization constants applied in `waveform_to_melspec` as
        # (x - MEAN) / (2 * STD); presumably pre-training feature statistics
        # — NOTE(review): confirm their origin.
        self.MEAN = -4.2677393
        self.STD = 4.5689974

    def load_wav_file(self, file_path: str) -> torch.Tensor:
        """Load an audio file as a mono waveform resampled to 16 kHz.

        To use this, `torchaudio` and `ffmpeg` must be installed:

        - `ffmpeg` version must be >=4.4 and <7.
        - `ffmpeg` installation by `conda install -c conda-forge ffmpeg==6.1.1`

        Args:
            file_path: Path to an audio file readable by ``torchaudio.load``.

        Returns:
            Waveform tensor of shape ``(1, num_samples)`` at ``SAMPLE_RATE`` Hz.
            Multi-channel input is down-mixed by averaging the channels.
        """
        audio, sample_rate = torchaudio.load(file_path)

        # Duration does not depend on the sample rate, so check before resampling.
        # Longer audio is not cut here; `waveform_to_melspec` truncates its frames.
        if audio.shape[-1] / sample_rate > self.MAX_AUDIO_SECONDS:
            warnings.warn(
                '[WARNING] AudioMAE only accepts audio length up to 10s. '
                'The audio frames exceeding 10s will be clipped.',
                stacklevel=2,
            )

        # Down-mix to mono while keeping a leading channel dimension of size 1.
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        if sample_rate != self.SAMPLE_RATE:
            resampler = transforms.Resample(orig_freq=sample_rate, new_freq=self.SAMPLE_RATE)
            audio = resampler(audio)
        return audio

    def waveform_to_melspec(self, waveform: torch.FloatTensor) -> torch.Tensor:
        """Convert a 16 kHz waveform into a normalized, fixed-size fbank.

        Args:
            waveform: ``(channels, num_samples)`` tensor sampled at 16 kHz,
                as returned by :meth:`load_wav_file`.

        Returns:
            Tensor of shape ``(TARGET_FRAMES, NUM_MEL_BINS)``: Kaldi-compatible
            log-mel filterbank features (25 ms window, 10 ms hop). The features are
            truncated or zero-padded at the end to ``TARGET_FRAMES`` frames, then
            normalized as ``(x - MEAN) / (2 * STD)``.
        """
        mel_spectrogram = kaldi.fbank(
            waveform,
            num_mel_bins=self.NUM_MEL_BINS,
            frame_length=25.0,
            frame_shift=10.0,
            htk_compat=True,
            use_energy=False,
            sample_frequency=self.SAMPLE_RATE,
            window_type='hanning',
            dither=0.0,
        )

        # Force exactly TARGET_FRAMES time frames: clip long audio, zero-pad short audio.
        current_frames = mel_spectrogram.shape[0]
        if current_frames > self.TARGET_FRAMES:
            mel_spectrogram = mel_spectrogram[:self.TARGET_FRAMES, :]
        elif current_frames < self.TARGET_FRAMES:
            padding = self.TARGET_FRAMES - current_frames
            # (0, 0) leaves the mel axis alone; (0, padding) appends frames at the end.
            mel_spectrogram = torch.nn.functional.pad(mel_spectrogram, (0, 0, 0, padding))

        # Normalize after padding, so padded frames also get shifted and scaled.
        return (mel_spectrogram - self.MEAN) / (self.STD * 2)

    @torch.no_grad()
    def encode(self, file_path: str) -> torch.Tensor:
        """Encode an audio file into a grid of patch embeddings.

        Switches the module to eval mode as a side effect, then runs
        load -> fbank -> ``forward_features`` without tracking gradients.

        Args:
            file_path: Path to an audio file (see :meth:`load_wav_file`).

        Returns:
            Tensor of shape ``(embed_dim, freq_patches, time_patches)`` on the
            model's device. For example, 16 x 16 patches over a 1024 x 128 input
            give ``(embed_dim, 8, 64)``.
        """
        self.eval()

        waveform = self.load_wav_file(file_path)
        melspec = self.waveform_to_melspec(waveform)

        # (frames, mel_bins) -> (batch=1, channel=1, frames, mel_bins). Match the
        # model's device/dtype so encoding also works on GPU or in reduced precision.
        reference_param = next(self.parameters())
        melspec = melspec[None, None, :, :].to(
            device=reference_param.device, dtype=reference_param.dtype
        )

        z = self.forward_features(melspec)

        # Drop the prefix tokens (class token, plus register tokens in timm
        # versions that support them); older timm lacks the attribute, hence the fallback.
        num_prefix_tokens = getattr(self, 'num_prefix_tokens', 1)
        z = z[:, num_prefix_tokens:, :]

        _, _, n_time, n_freq = melspec.shape
        # Floor division matches the strided patch-embedding convolution's output size.
        time_patches = n_time // self.patch_embed.patch_size[0]
        freq_patches = n_freq // self.patch_embed.patch_size[1]

        # Tokens are assumed row-major over the (time, freq) patch grid, i.e. freq
        # varies fastest. Passing both sizes makes einops check the token count.
        z = rearrange(z, 'b (w h) d -> b d h w', w=time_patches, h=freq_patches)

        # Remove the singleton batch dimension.
        return z[0]
|
|
|
|
|
|
|
class PretrainedAudioMAEEncoder(PreTrainedModel):
    """Expose ``AudioMAEEncoder`` through the transformers ``PreTrainedModel`` API.

    The wrapped encoder is built from the ``img_size``, ``in_chans`` and
    ``num_classes`` fields of an ``AudioMAEConfig``.
    """

    config_class = AudioMAEConfig

    def __init__(self, config):
        super().__init__(config)
        encoder_kwargs = dict(
            img_size=config.img_size,
            in_chans=config.in_chans,
            num_classes=config.num_classes,
        )
        self.encoder = AudioMAEEncoder(**encoder_kwargs)

    def forward(self, file_path: str):
        """Encode the audio file at ``file_path`` via ``AudioMAEEncoder.encode``."""
        embedding = self.encoder.encode(file_path)
        return embedding
|
|