from typing import Tuple

import torch
import torchaudio
import torchaudio.transforms as transforms
from torchaudio.compliance import kaldi
from transformers import PretrainedConfig

from einops import rearrange

from timm.models.vision_transformer import VisionTransformer
from transformers import PreTrainedModel


# NOTE: it seems the Config class and the Model class need to live in the same file; otherwise, loading the model after pushing it to the Hugging Face Hub appears to fail.
class AudioMAEConfig(PretrainedConfig):
    model_type = "audiomae"

    def __init__(self,
                 img_size: Tuple[int, int] = (1024, 128),
                 in_chans: int = 1,
                 num_classes: int = 0,
                 **kwargs):
        super().__init__(**kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.num_classes = num_classes


class AudioMAEEncoder(VisionTransformer):
    def __init__(self, *args, **kwargs):
        """
        - img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed, as described in the paper.
        - AudioMAE accepts mono-channel input (i.e., in_chans=1).
        """
        super().__init__(*args, **kwargs)
        self.MEAN = -4.2677393  # as reported in the paper
        self.STD = 4.5689974  # as reported in the paper

    def load_wav_file(self, file_path:str):
        """

        to use this, `torchaudio` and `ffmpeg` must be installed

        - `ffmpeg` version must be >=4.4 and <7.

        - `ffmpeg` installation by `conda install -c conda-forge ffmpeg==6.1.1`

        """
        audio, sample_rate = torchaudio.load(file_path)  # audio: (n_channels, length); 

        # length check
        audio_len = audio.shape[-1] / sample_rate
        if audio_len > 10.0:
            print('[WARNING] AudioMAE only accepts audio up to 10 s. Frames beyond 10 s will be clipped in `waveform_to_melspec`.')

        # Check if the audio has multiple channels
        if audio.shape[0] > 1:
            # Convert stereo audio to mono by taking the mean across channels
            # AudioMAE accepts a mono channel.
            audio = torch.mean(audio, dim=0, keepdim=True)

        # resample the audio to 16 kHz, the sampling rate AudioMAE expects
        if sample_rate != 16000:
            converter = transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio = converter(audio)
        return audio
    
    def waveform_to_melspec(self, waveform:torch.FloatTensor):
        # Compute the Mel spectrogram using Kaldi-compatible features
        # the parameters are chosen as described in the audioMAE paper (4.2 implementation details)
        mel_spectrogram = kaldi.fbank(
            waveform, 
            num_mel_bins=128, 
            frame_length=25.0, 
            frame_shift=10.0, 
            htk_compat=True, 
            use_energy=False,
            sample_frequency=16000, 
            window_type='hanning',
            dither=0.0
        )
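        # kaldi.fbank returns a (num_frames, num_mel_bins) tensor; with a 10 ms frame
        # shift, a 10 s clip yields roughly 1,000 frames, hence the padding/trimming
        # to 1024 frames below.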
        
        # Ensure the output shape matches (1024, 128) by padding or trimming the time dimension
        expected_frames = 1024  # as described in the paper
        current_frames = mel_spectrogram.shape[0]
        if current_frames > expected_frames:
            mel_spectrogram = mel_spectrogram[:expected_frames, :]
        elif current_frames < expected_frames:
            padding = expected_frames - current_frames
            # F.pad pads the last dimension first: (0, 0) applies to the freq dim (no padding)
            # and (0, padding) to the time dim (pad at the end)
            mel_spectrogram = torch.nn.functional.pad(mel_spectrogram, (0, 0, 0, padding))
            
        # scale
        # as in the AudioMAE implementation [REF: https://github.com/facebookresearch/AudioMAE/blob/bd60e29651285f80d32a6405082835ad26e6f19f/dataset.py#L300]
        mel_spectrogram = (mel_spectrogram - self.MEAN) / (self.STD * 2)  # (length, n_freq_bins) = (1024, 128)
        return mel_spectrogram

    @torch.no_grad()
    def encode(self, file_path:str, device):
        self.eval()

        waveform = self.load_wav_file(file_path)
        melspec = self.waveform_to_melspec(waveform)  # (length, n_freq_bins) = (1024, 128)
        melspec = melspec[None,None,:,:]  # (1, 1, length, n_freq_bins) = (1, 1, 1024, 128)
        z = self.forward_features(melspec.to(device)).cpu()  # (b, 1+n, d); d=768
        z = z[:,1:,:]  # (b n d); remove [CLS], the class token

        b, c, w, h = melspec.shape  # w: temporal dim; h:freq dim
        wprime = round(w / self.patch_embed.patch_size[0])  # width in the latent space
        hprime = round(h / self.patch_embed.patch_size[1])  # height in the latent space
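        # e.g., with the default img_size=(1024, 128) and 16x16 patches: wprime = 64, hprime = 8,
        # so the final output below is (768, 8, 64)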

        # reconstruct the temporal and freq dims
        z = rearrange(z, 'b (w h) d -> b d h w', h=hprime)  # (b d h' w')

        # remove the batch dim
        z = z[0]  # (d h' w')
        return z  # (d h' w')



class PretrainedAudioMAEEncoder(PreTrainedModel):
    config_class = AudioMAEConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = AudioMAEEncoder(img_size=config.img_size, in_chans=config.in_chans, num_classes=config.num_classes)
    
    def forward(self, file_path:str):
        device = self.device
        return self.encoder.encode(file_path, device)  # (d h' w')
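

# Minimal usage sketch (illustrative only): build the encoder from the default config
# and encode a wav file. "example.wav" is a placeholder path, and the weights here are
# randomly initialized; in practice one would load a pretrained checkpoint instead,
# e.g. `AutoModel.from_pretrained("<repo_id>", trust_remote_code=True)` after pushing
# this file to the Hub.
if __name__ == "__main__":
    config = AudioMAEConfig()  # img_size=(1024, 128), in_chans=1, num_classes=0
    model = PretrainedAudioMAEEncoder(config)
    z = model("example.wav")  # (d, h', w') = (768, 8, 64) with the default settings
    print(z.shape)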