# audiomae/model.py
from typing import Tuple
import torch
import torchaudio
import torchaudio.transforms as transforms
from torchaudio.compliance import kaldi
from transformers import PretrainedConfig
from einops import rearrange
from timm.models.vision_transformer import VisionTransformer
from transformers import PreTrainedModel
# NOTE: it seems that the Config class and the Model class should be located in the same file;
# otherwise, model loading appears to break after pushing the model to the Hugging Face Hub.
class AudioMAEConfig(PretrainedConfig):
model_type = "audiomae"
def __init__(self,
img_size:Tuple[int,int]=(1024,128),
in_chans:int=1,
num_classes:int=0,
**kwargs,):
super().__init__(**kwargs)
self.img_size = img_size
self.in_chans = in_chans
self.num_classes = num_classes
class AudioMAEEncoder(VisionTransformer):
    def __init__(self, *args, **kwargs):
        """
        - img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed, as described in the paper
        - AudioMAE accepts mono-channel input (i.e., in_chans=1)
        """
        super().__init__(*args, **kwargs)
        self.MEAN = -4.2677393  # normalization mean reported in the paper
        self.STD = 4.5689974  # normalization std reported in the paper
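    # A small shape sketch of what the encoder sees. It assumes timm's default ViT patch_size of 16,
    # which this class does not override:
    #   patch grid = (1024 // 16, 128 // 16) = (64, 8)  ->  64 * 8 = 512 patch tokens (+ 1 [CLS] token)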
def load_wav_file(self, file_path:str):
"""
        To use this method, `torchaudio` and `ffmpeg` must be installed.
        - the `ffmpeg` version must be >= 4.4 and < 7
        - `ffmpeg` can be installed with, e.g., `conda install -c conda-forge ffmpeg==6.1.1`
"""
        audio, sample_rate = torchaudio.load(file_path)  # audio: (n_channels, n_samples)
        # warn if the clip is longer than 10 s (the actual trimming happens in `waveform_to_melspec`)
        audio_len = audio.shape[-1] / sample_rate
        if audio_len > 10.0:
            print('[WARNING] AudioMAE only accepts audio up to 10s long. Frames beyond 10s will be clipped.')
        # convert multi-channel audio to mono by averaging across channels
        # (AudioMAE accepts mono input only)
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)
        # resample to 16 kHz (AudioMAE expects 16 kHz input)
        if sample_rate != 16000:
            converter = transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio = converter(audio)
return audio
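    # Example (a sketch; "speech.wav" is a hypothetical file, not shipped with this repo):
    #   waveform = encoder.load_wav_file("speech.wav")
    #   # a 10 s stereo 44.1 kHz recording comes back as a mono 16 kHz tensor of shape (1, 160000)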
def waveform_to_melspec(self, waveform:torch.FloatTensor):
        # Compute the mel spectrogram using Kaldi-compatible features.
        # The parameters follow the AudioMAE paper (Sec. 4.2, implementation details).
mel_spectrogram = kaldi.fbank(
waveform,
num_mel_bins=128,
frame_length=25.0,
frame_shift=10.0,
htk_compat=True,
use_energy=False,
sample_frequency=16000,
window_type='hanning',
dither=0.0
)
        # Ensure the output has 1024 time frames (-> (1024, 128)) by trimming or padding the time dimension
expected_frames = 1024 # as described in the paper
current_frames = mel_spectrogram.shape[0]
if current_frames > expected_frames:
mel_spectrogram = mel_spectrogram[:expected_frames, :]
elif current_frames < expected_frames:
padding = expected_frames - current_frames
            # note: F.pad pads dimensions starting from the last one, so (0, 0) applies to the freq dim
            mel_spectrogram = torch.nn.functional.pad(
                mel_spectrogram,
                (0, 0, 0, padding),  # no freq padding; append `padding` frames at the end of the time dim
            )
# scale
# as in the AudioMAE implementation [REF: https://github.com/facebookresearch/AudioMAE/blob/bd60e29651285f80d32a6405082835ad26e6f19f/dataset.py#L300]
mel_spectrogram = (mel_spectrogram - self.MEAN) / (self.STD * 2) # (length, n_freq_bins) = (1024, 128)
return mel_spectrogram
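    # Frame-count sketch: with a 10 ms frame shift, 16 kHz audio yields ~100 frames per second,
    # so the 1024 expected frames cover ~10.24 s of audio -- consistent with the 10 s clipping
    # warning in `load_wav_file`.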
@torch.no_grad()
def encode(self, file_path:str, device):
self.eval()
waveform = self.load_wav_file(file_path)
melspec = self.waveform_to_melspec(waveform) # (length, n_freq_bins) = (1024, 128)
melspec = melspec[None,None,:,:] # (1, 1, length, n_freq_bins) = (1, 1, 1024, 128)
z = self.forward_features(melspec.to(device)).cpu() # (b, 1+n, d); d=768
z = z[:,1:,:] # (b n d); remove [CLS], the class token
b, c, w, h = melspec.shape # w: temporal dim; h:freq dim
wprime = round(w / self.patch_embed.patch_size[0]) # width in the latent space
hprime = round(h / self.patch_embed.patch_size[1]) # height in the latent space
# reconstruct the temporal and freq dims
z = rearrange(z, 'b (w h) d -> b d h w', h=hprime) # (b d h' w')
# remove the batch dim
z = z[0] # (d h' w')
return z # (d h' w')
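# Output-shape sketch for `encode`: with the default (1024, 128) input and timm's ViT defaults
# (patch_size=16, embed_dim=768; assumed here, since this file does not override them),
# the returned tensor z has shape (d, h', w') = (768, 8, 64).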
class PretrainedAudioMAEEncoder(PreTrainedModel):
config_class = AudioMAEConfig
def __init__(self, config):
super().__init__(config)
self.encoder = AudioMAEEncoder(img_size=config.img_size, in_chans=config.in_chans, num_classes=config.num_classes)
def forward(self, file_path:str):
device = self.device
return self.encoder.encode(file_path, device) # (d h' w')
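

# Usage sketch: load the encoder from the Hugging Face Hub and embed a single clip.
# The checkpoint id and the wav path below are placeholders, not assets defined in this file.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = PretrainedAudioMAEEncoder.from_pretrained("your-username/audiomae-encoder")  # placeholder repo id
    model = model.to(device).eval()
    z = model("example.wav")  # placeholder path; returns a tensor of shape (d, h', w'), e.g. (768, 8, 64)
    print(z.shape)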