File size: 2,107 Bytes
7ff2ba3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
from typing import Optional
import torch
import numpy as np
from librosa.filters import mel
from .stft import STFT
class MelSpectrogram(torch.nn.Module):
    """Log-mel spectrogram front end built on the project ``STFT`` module.

    The mel filterbank is computed once with ``librosa.filters.mel`` (HTK
    formula) and stored as the ``mel_basis`` buffer. ``forward`` returns
    ``log(clamp(mel_basis @ magnitude, min=clamp))`` where ``magnitude`` is
    the output of the ``STFT`` submodule, optionally corrected for a key
    shift.

    Args:
        is_half: If true, the mel output is cast to float16 before the log.
        n_mel_channels: Number of mel bands requested from librosa.
        sampling_rate: Audio sampling rate in Hz, passed to librosa.
        win_length: Analysis window length in samples.
        hop_length: Hop size in samples.
        n_fft: FFT size; defaults to ``win_length`` when ``None``.
        mel_fmin: Lowest mel-filter frequency in Hz.
        mel_fmax: Highest mel-filter frequency in Hz; ``None`` is forwarded
            unchanged to librosa.
        clamp: Floor applied before the log so that ``log(0)`` never occurs.
        device: Device on which the STFT submodule and the mel filterbank
            are created.
    """

    def __init__(
        self,
        is_half: bool,
        n_mel_channels: int,
        sampling_rate: int,
        win_length: int,
        hop_length: int,
        n_fft: Optional[int] = None,
        mel_fmin: int = 0,
        mel_fmax: Optional[int] = None,
        clamp: float = 1e-5,
        device=torch.device("cpu"),
    ):
        super().__init__()
        if n_fft is None:
            n_fft = win_length
        mel_basis = mel(
            sr=sampling_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
            htk=True,
        )
        # Place the filterbank on the same device as the STFT created below,
        # so forward()'s matmul does not mix devices when the caller never
        # moves the module as a whole. A later .to() still moves it as usual.
        self.register_buffer(
            "mel_basis", torch.from_numpy(mel_basis).float().to(device)
        )
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.clamp = clamp
        self.is_half = is_half
        # The torch-native STFT path is disabled for "privateuseone" devices.
        # Presumably this is because torch.stft is unsupported on such
        # out-of-tree backends (e.g. DirectML) -- confirm in the STFT module.
        self.stft = STFT(
            filter_length=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window="hann",
            use_torch_stft="privateuseone" not in str(device),
        ).to(device)

    def forward(
        self,
        audio: torch.Tensor,
        keyshift=0,
        speed=1,
        center=True,
    ):
        """Compute the log-mel spectrogram of ``audio``.

        Args:
            audio: Input waveform. Its expected shape is whatever the project
                ``STFT`` accepts (presumably (batch, samples) -- TODO confirm).
            keyshift: Pitch shift in semitones, forwarded to the STFT. When
                non-zero, the magnitude is resized and rescaled (see below).
            speed: Forwarded to the STFT unchanged.
            center: Forwarded to the STFT unchanged.

        Returns:
            ``log(clamp(mel_basis @ magnitude, min=self.clamp))``, computed
            in float16 when ``is_half`` is set.
        """
        magnitude = self.stft(audio, keyshift, speed, center)
        if keyshift != 0:
            # Under a key shift the STFT may return a different number of
            # frequency bins (dim 1) than mel_basis expects. Zero-pad or
            # truncate back to n_fft // 2 + 1 bins.
            size = self.n_fft // 2 + 1
            resize = magnitude.size(1)
            if resize < size:
                magnitude = torch.nn.functional.pad(
                    magnitude, (0, 0, 0, size - resize)
                )
            # Undo the amplitude change caused by the stretched window: scale
            # by the ratio of the nominal to the key-shifted window length.
            factor = 2 ** (keyshift / 12)
            win_length_new = int(np.round(self.win_length * factor))
            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
        mel_output = torch.matmul(self.mel_basis, magnitude)
        if self.is_half:
            mel_output = mel_output.half()
        log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
        return log_mel_spec
|