from typing import Tuple

import torch
import torchaudio
import torchaudio.transforms as transforms
from torchaudio.compliance import kaldi
from transformers import PretrainedConfig
from einops import rearrange
from timm.models.vision_transformer import VisionTransformer
from transformers import PreTrainedModel


# NOTE: the Config class and the Model class are kept in the same file;
# splitting them seemed to cause an issue in model loading after pushing to HF.


class AudioMAEConfig(PretrainedConfig):
    """Hugging Face configuration holding the AudioMAE encoder hyper-parameters.

    Args:
        img_size: Input spectrogram size as (temporal_length, n_freq_bins).
        in_chans: Number of input channels (AudioMAE is mono, i.e. 1).
        num_classes: Size of the classification head passed to the encoder.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "audiomae"

    def __init__(
        self,
        img_size: Tuple[int, int] = (1024, 128),
        in_chans: int = 1,
        num_classes: int = 0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.num_classes = num_classes


class AudioMAEEncoder(VisionTransformer):
    """Vision-Transformer encoder with AudioMAE audio pre-processing helpers.

    - img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed,
      as described in the paper
    - AudioMAE accepts a mono-channel (i.e., in_chans=1)
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Normalisation statistics of the log-mel features.
        self.MEAN = -4.2677393  # written on the paper
        self.STD = 4.5689974  # written on the paper

    def load_wav_file(self, file_path: str) -> torch.Tensor:
        """Load an audio file as a mono waveform sampled at 16 kHz.

        To use this, `torchaudio` and `ffmpeg` must be installed
        - `ffmpeg` version must be >=4.4 and <7.
        - `ffmpeg` installation by `conda install -c conda-forge ffmpeg==6.1.1`

        Args:
            file_path: Path of the audio file, passed to ``torchaudio.load``.

        Returns:
            The waveform with shape (1, length) for multi-channel input
            (channels are averaged), otherwise whatever channel layout
            ``torchaudio.load`` returned, resampled to 16 kHz if needed.
        """
        audio, sample_rate = torchaudio.load(file_path)  # audio: (n_channels, length)

        # Only a warning is emitted here; the actual trimming happens when
        # the mel spectrogram is cut to a fixed number of frames.
        audio_len = audio.shape[-1] / sample_rate
        if audio_len > 10.0:
            print('[WARNING] AudioMAE only accepts audio length up to 10s. The audio frames exceeding 10s will be clipped.')

        # AudioMAE accepts a mono channel: down-mix by averaging the channels.
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # AudioMAE accepts 16 kHz audio: resample anything else.
        if sample_rate != 16000:
            converter = transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio = converter(audio)
        return audio

    def waveform_to_melspec(self, waveform: torch.FloatTensor) -> torch.Tensor:
        """Convert a waveform into a normalised, fixed-length mel spectrogram.

        Args:
            waveform: Audio tensor handed to ``kaldi.fbank`` (16 kHz expected,
                see ``sample_frequency`` below).

        Returns:
            Tensor of shape (length, n_freq_bins) = (1024, 128).
        """
        # Compute the Mel spectrogram using Kaldi-compatible features.
        # The parameters are chosen as described in the AudioMAE paper
        # (4.2 implementation details).
        mel_spectrogram = kaldi.fbank(
            waveform,
            num_mel_bins=128,
            frame_length=25.0,
            frame_shift=10.0,
            htk_compat=True,
            use_energy=False,
            sample_frequency=16000,
            window_type='hanning',
            dither=0.0,
        )

        # Ensure the time dimension is exactly 1024 frames (as described in
        # the paper) by trimming the tail or zero-padding the end.
        expected_frames = 1024
        current_frames = mel_spectrogram.shape[0]
        if current_frames > expected_frames:
            mel_spectrogram = mel_spectrogram[:expected_frames, :]
        elif current_frames < expected_frames:
            padding = expected_frames - current_frames
            # F.pad lists (before, after) pairs starting from the LAST dim.
            mel_spectrogram = torch.nn.functional.pad(
                mel_spectrogram,
                (0, 0,  # (before, after) for the last dim (freq bins): untouched
                 0, padding),  # (before, after) for the first dim (frames)
            )

        # Scale as in the AudioMAE implementation
        # [REF: https://github.com/facebookresearch/AudioMAE/blob/bd60e29651285f80d32a6405082835ad26e6f19f/dataset.py#L300]
        mel_spectrogram = (mel_spectrogram - self.MEAN) / (self.STD * 2)
        return mel_spectrogram  # (length, n_freq_bins) = (1024, 128)

    @torch.no_grad()
    def encode(self, file_path: str, device) -> torch.Tensor:
        """Encode an audio file into a grid of patch embeddings.

        Switches the module to eval mode, runs the pre-processing pipeline
        and the transformer, drops the class token and restores the 2-D
        patch layout.

        Args:
            file_path: Path of the audio file to encode.
            device: Device the spectrogram is moved to before the forward
                pass; the model itself must already live there.

        Returns:
            CPU tensor of shape (d, h', w'), where h' is the number of
            patches along the frequency axis and w' along the temporal axis.
        """
        self.eval()

        waveform = self.load_wav_file(file_path)
        melspec = self.waveform_to_melspec(waveform)  # (length, n_freq_bins) = (1024, 128)
        melspec = melspec[None, None, :, :]  # (1, 1, length, n_freq_bins)

        z = self.forward_features(melspec.to(device)).cpu()  # (b, 1+n, d)
        z = z[:, 1:, :]  # (b n d); remove [CLS], the class token

        _, _, w, h = melspec.shape  # w: temporal dim; h: freq dim
        # Floor division mirrors how the patch embedding derives its grid
        # size; both extents are given to rearrange so that a mismatch with
        # the token count raises instead of producing a scrambled layout.
        wprime = w // self.patch_embed.patch_size[0]  # width in the latent space
        hprime = h // self.patch_embed.patch_size[1]  # height in the latent space

        # Reconstruct the temporal and freq dims.
        z = rearrange(z, 'b (w h) d -> b d h w', w=wprime, h=hprime)  # (b d h' w')

        # Remove the batch dim.
        return z[0]  # (d h' w')


class PretrainedAudioMAEEncoder(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``AudioMAEEncoder``."""

    config_class = AudioMAEConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = AudioMAEEncoder(
            img_size=config.img_size,
            in_chans=config.in_chans,
            num_classes=config.num_classes,
        )

    def forward(self, file_path: str):
        """Encode the audio file at ``file_path`` on the model's device.

        Returns:
            The encoder output of shape (d, h', w').
        """
        device = self.device
        return self.encoder.encode(file_path, device)  # (d h' w')