diff --git "a/beatrice_trainer/__main__.py" "b/beatrice_trainer/__main__.py" --- "a/beatrice_trainer/__main__.py" +++ "b/beatrice_trainer/__main__.py" @@ -10,6 +10,7 @@ import os import shutil import warnings from collections import defaultdict +from contextlib import nullcontext from copy import deepcopy from fractions import Fraction from functools import partial @@ -29,10 +30,17 @@ from torch.utils.tensorboard import SummaryWriter from tqdm.auto import tqdm assert "soundfile" in torchaudio.list_audio_backends() +if not hasattr(torch.amp, "GradScaler"): + + class GradScaler(torch.cuda.amp.GradScaler): + def __init__(self, _, *args, **kwargs): + super().__init__(*args, **kwargs) + + torch.amp.GradScaler = GradScaler # モジュールのバージョンではない -PARAPHERNALIA_VERSION = "2.0.0-alpha.2" +PARAPHERNALIA_VERSION = "2.0.0-beta.1" def is_notebook() -> bool: @@ -52,19 +60,22 @@ def repo_root() -> Path: # 学習データや出力ディレクトリなど、学習ごとに変わるようなものはここに含めない dict_default_hparams = { # train - "learning_rate": 1e-4, - "min_learning_rate": 5e-6, + "learning_rate_g": 2e-4, + "learning_rate_d": 1e-4, + "min_learning_rate_g": 1e-5, + "min_learning_rate_d": 5e-6, "adam_betas": [0.8, 0.99], "adam_eps": 1e-6, "batch_size": 8, "grad_weight_mel": 1.0, # grad_weight は比が同じなら同じ意味になるはず - "grad_weight_adv": 1.0, - "grad_weight_fm": 1.0, + "grad_weight_ap": 2.0, + "grad_weight_adv": 3.0, + "grad_weight_fm": 3.0, "grad_balancer_ema_decay": 0.995, "use_amp": True, "num_workers": 16, - "n_steps": 20000, - "warmup_steps": 10000, + "n_steps": 10000, + "warmup_steps": 2000, "in_sample_rate": 16000, # 変更不可 "out_sample_rate": 24000, # 変更不可 "wav_length": 4 * 24000, # 4s @@ -75,10 +86,14 @@ dict_default_hparams = { "in_ir_wav_dir": "assets/ir", "in_noise_wav_dir": "assets/noise", "in_test_wav_dir": "assets/test", - "pretrained_file": "assets/pretrained/040c_checkpoint_libritts_r_200_02300000.pt", # None も可 + "pretrained_file": "assets/pretrained/079_checkpoint_libritts_r_200_02400000.pt", # None も可 # model "hidden_channels": 256, # ファインチューン時変更不可、変更した場合は推論側の対応必要 "san": False, # ファインチューン時変更不可 + "compile_convnext": False, + "compile_d4c": False, + "compile_discriminator": False, + "profile": False, } if __name__ == "__main__": @@ -102,7 +117,7 @@ if __name__ == "__main__": warnings.warn("dafualt_config.json not found.") -def prepare_training_configs_for_experiment() -> tuple[dict, Path, Path, bool]: +def prepare_training_configs_for_experiment() -> tuple[dict, Path, Path, bool, bool]: import ipynbname from IPython import get_ipython @@ -114,10 +129,11 @@ def prepare_training_configs_for_experiment() -> tuple[dict, Path, Path, bool]: notebook_name = Path(get_ipython().user_ns["__vsc_ipynb_file__"]).name out_dir = repo_root() / "notebooks" / notebook_name.split(".")[0].split("_")[0] resume = False - return h, in_wav_dataset_dir, out_dir, resume + skip_training = False + return h, in_wav_dataset_dir, out_dir, resume, skip_training -def prepare_training_configs() -> tuple[dict, Path, Path, bool]: +def prepare_training_configs() -> tuple[dict, Path, Path, bool, bool]: # data_dir, out_dir は config ファイルでもコマンドライン引数でも指定でき、 # コマンドライン引数が優先される。 # 各種ファイルパスを相対パスで指定した場合、config ファイルでは @@ -173,7 +189,7 @@ def prepare_training_configs() -> tuple[dict, Path, Path, bool]: del h[key] # resume resume = args.resume - return h, in_wav_dataset_dir, out_dir, resume + return h, in_wav_dataset_dir, out_dir, resume, False class AttrDict(dict): @@ -229,6 +245,8 @@ def dump_layer(layer: nn.Module, f: BinaryIO): dump(getattr(layer, f"bias_hh_l{i}")) elif isinstance(layer, 
nn.Embedding): dump(layer.weight) + elif isinstance(layer, nn.Parameter): + dump(layer) elif isinstance(layer, nn.ModuleList): for l in layer: dump_layer(l, f) @@ -271,6 +289,85 @@ class CausalConv1d(nn.Conv1d): return result[:, :, : -self.trim] +class WSConv1d(CausalConv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + delay: int = 0, + ): + super().__init__( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + bias=bias, + delay=delay, + ) + self.weight.data.normal_( + 0.0, math.sqrt(1.0 / (in_channels * kernel_size // groups)) + ) + if bias: + self.bias.data.zero_() + self.gain = nn.Parameter(torch.ones((out_channels, 1, 1))) + + def standardized_weight(self) -> torch.Tensor: + var, mean = torch.var_mean(self.weight, [1, 2], keepdim=True) + scale = ( + self.gain + * ( + self.in_channels * self.kernel_size[0] // self.groups * var + 1e-8 + ).rsqrt() + ) + return scale * (self.weight - mean) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + result = F.conv1d( + input, + self.standardized_weight(), + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + if self.trim == 0: + return result + else: + return result[:, :, : -self.trim] + + def merge_weights(self): + self.weight.data[:] = self.standardized_weight().detach() + self.gain.data.fill_(1.0) + + +class WSLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__(in_features, out_features, bias) + self.weight.data.normal_(0.0, math.sqrt(1.0 / in_features)) + self.bias.data.zero_() + self.gain = nn.Parameter(torch.ones((out_features, 1))) + + def standardized_weight(self) -> torch.Tensor: + var, mean = torch.var_mean(self.weight, 1, keepdim=True) + scale = self.gain * (self.in_features * var + 1e-8).rsqrt() + return scale * (self.weight - mean) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.standardized_weight(), self.bias) + + def merge_weights(self): + self.weight.data[:] = self.standardized_weight().detach() + self.gain.data.fill_(1.0) + + class ConvNeXtBlock(nn.Module): def __init__( self, @@ -278,10 +375,14 @@ class ConvNeXtBlock(nn.Module): intermediate_channels: int, layer_scale_init_value: float, kernel_size: int = 7, - use_weight_norm: bool = False, + use_weight_standardization: bool = False, + enable_scaling: bool = False, + pre_scale: float = 1.0, + post_scale: float = 1.0, ): super().__init__() - self.use_weight_norm = use_weight_norm + self.use_weight_standardization = use_weight_standardization + self.enable_scaling = enable_scaling self.dwconv = CausalConv1d( channels, channels, kernel_size=kernel_size, groups=channels ) @@ -289,42 +390,65 @@ class ConvNeXtBlock(nn.Module): self.pwconv1 = nn.Linear(channels, intermediate_channels) self.pwconv2 = nn.Linear(intermediate_channels, channels) self.gamma = nn.Parameter(torch.full((channels,), layer_scale_init_value)) - if use_weight_norm: + self.dwconv.weight.data.normal_(0.0, math.sqrt(1.0 / kernel_size)) + self.dwconv.bias.data.zero_() + self.pwconv1.weight.data.normal_(0.0, math.sqrt(2.0 / channels)) + self.pwconv1.bias.data.zero_() + self.pwconv2.weight.data.normal_(0.0, math.sqrt(1.0 / intermediate_channels)) + self.pwconv2.bias.data.zero_() + if use_weight_standardization: self.norm = nn.Identity() - self.dwconv = weight_norm(self.dwconv) - self.pwconv1 = 
weight_norm(self.pwconv1) - self.pwconv2 = weight_norm(self.pwconv2) + self.dwconv = WSConv1d(channels, channels, kernel_size, groups=channels) + self.pwconv1 = WSLinear(channels, intermediate_channels) + self.pwconv2 = WSLinear(intermediate_channels, channels) + del self.gamma + if enable_scaling: + self.register_buffer("pre_scale", torch.tensor(pre_scale)) + self.register_buffer("post_scale", torch.tensor(post_scale)) + self.post_scale_weight = nn.Parameter(torch.ones(())) def forward(self, x: torch.Tensor) -> torch.Tensor: identity = x + if self.enable_scaling: + x = x * self.pre_scale x = self.dwconv(x) x = x.transpose(1, 2) x = self.norm(x) x = self.pwconv1(x) x = F.gelu(x, approximate="tanh") x = self.pwconv2(x) - x *= self.gamma + if not self.use_weight_standardization: + x *= self.gamma + if self.enable_scaling: + x *= self.post_scale * self.post_scale_weight x = x.transpose(1, 2) x += identity return x - def remove_weight_norm(self): - if self.use_weight_norm: - remove_weight_norm(self.dwconv) - remove_weight_norm(self.pwconv1) - remove_weight_norm(self.pwconv2) - def merge_weights(self): - if not self.use_weight_norm: + if self.use_weight_standardization: + self.dwconv.merge_weights() + self.pwconv1.merge_weights() + self.pwconv2.merge_weights() + else: self.pwconv1.bias.data += ( self.norm.bias.data[None, :] * self.pwconv1.weight.data ).sum(1) self.pwconv1.weight.data *= self.norm.weight.data[None, :] self.norm.bias.data[:] = 0.0 self.norm.weight.data[:] = 1.0 - self.pwconv2.weight.data *= self.gamma.data[:, None] - self.pwconv2.bias.data *= self.gamma.data - self.gamma.data[:] = 1.0 + self.pwconv2.weight.data *= self.gamma.data[:, None] + self.pwconv2.bias.data *= self.gamma.data + self.gamma.data[:] = 1.0 + if self.enable_scaling: + self.dwconv.weight.data *= self.pre_scale.data + self.pre_scale.data.fill_(1.0) + self.pwconv2.weight.data *= ( + self.post_scale.data * self.post_scale_weight.data + ) + self.pwconv2.bias.data *= self.post_scale.data * self.post_scale_weight.data + self.post_scale.data.fill_(1.0) + self.post_scale_weight.data.fill_(1.0) def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): @@ -349,36 +473,38 @@ class ConvNeXtStack(nn.Module): delay: int, embed_kernel_size: int, kernel_size: int, - use_weight_norm: bool = False, + use_weight_standardization: bool = False, + enable_scaling: bool = False, ): super().__init__() assert delay * 2 + 1 <= embed_kernel_size - self.use_weight_norm = use_weight_norm + self.use_weight_standardization = use_weight_standardization self.embed = CausalConv1d(in_channels, channels, embed_kernel_size, delay=delay) self.norm = nn.LayerNorm(channels) - self.convnext = nn.ModuleList( - [ - ConvNeXtBlock( - channels=channels, - intermediate_channels=intermediate_channels, - layer_scale_init_value=1.0 / n_blocks, - kernel_size=kernel_size, - use_weight_norm=use_weight_norm, - ) - for _ in range(n_blocks) - ] - ) + self.convnext = nn.ModuleList() + for i in range(n_blocks): + pre_scale = 1.0 / math.sqrt(1.0 + i / n_blocks) if enable_scaling else 1.0 + post_scale = 1.0 / math.sqrt(n_blocks) if enable_scaling else 1.0 + block = ConvNeXtBlock( + channels=channels, + intermediate_channels=intermediate_channels, + layer_scale_init_value=1.0 / n_blocks, + kernel_size=kernel_size, + use_weight_standardization=use_weight_standardization, + enable_scaling=enable_scaling, + pre_scale=pre_scale, + post_scale=post_scale, + ) + self.convnext.append(block) self.final_layer_norm = 
nn.LayerNorm(channels) - if use_weight_norm: - self.embed = weight_norm(self.embed) + self.embed.weight.data.normal_( + 0.0, math.sqrt(0.5 / (embed_kernel_size * in_channels)) + ) + self.embed.bias.data.zero_() + if use_weight_standardization: + self.embed = WSConv1d(in_channels, channels, embed_kernel_size, delay=delay) self.norm = nn.Identity() self.final_layer_norm = nn.Identity() - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv1d, nn.Linear)): - nn.init.trunc_normal_(m.weight, std=0.02) - nn.init.constant_(m.bias, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.embed(x) @@ -388,13 +514,9 @@ class ConvNeXtStack(nn.Module): x = self.final_layer_norm(x.transpose(1, 2)).transpose(1, 2) return x - def remove_weight_norm(self): - if self.use_weight_norm: - remove_weight_norm(self.embed) - for conv_block in self.convnext: - conv_block.remove_weight_norm() - def merge_weights(self): + if self.use_weight_standardization: + self.embed.merge_weights() for conv_block in self.convnext: conv_block.merge_weights() @@ -407,10 +529,10 @@ class ConvNeXtStack(nn.Module): raise TypeError dump_layer(self.embed, f) - if not self.use_weight_norm: + if not self.use_weight_standardization: dump_layer(self.norm, f) dump_layer(self.convnext, f) - if not self.use_weight_norm: + if not self.use_weight_standardization: dump_layer(self.final_layer_norm, f) @@ -694,9 +816,7 @@ def extract_pitch_features( # 変換モデルへの入力用のエネルギー energy = ( - y_frames.mul_( - torch.signal.windows.cosine(win_length, device=y.device)[..., None] - ) + (y_frames * torch.signal.windows.cosine(win_length, device=y.device)[..., None]) .square_() .sum(-2, keepdim=True) ) @@ -748,7 +868,7 @@ class PitchEstimator(nn.Module): # [batch_size, input_instfreq_channels, length], # [batch_size, input_corr_channels, length] - with torch.cuda.amp.autocast(False): + with torch.amp.autocast("cuda", enabled=False): instfreq_features, corr_diff, energy = extract_pitch_features( wav.squeeze(1), hop_length=160, @@ -860,20 +980,25 @@ class PitchEstimator(nn.Module): # %% def overlap_add( - ir: torch.Tensor, + ir_amp: torch.Tensor, + ir_phase: torch.Tensor, + window: torch.Tensor, pitch: torch.Tensor, hop_length: int = 240, delay: int = 0, + sr: float = 24000.0, ) -> torch.Tensor: - # print("ir, pitch: ", ir.dtype, pitch.dtype) - batch_size, ir_length, length = ir.size() + batch_size, ir_length, length = ir_amp.size() + ir_length = (ir_length - 1) * 2 + assert ir_phase.size() == ir_amp.size() + assert window.size() == (ir_length,), (window.size(), ir_amp.size()) assert pitch.size() == (batch_size, length * hop_length) assert 0 <= delay < ir_length, (delay, ir_length) - # 位相は [0, 1) で表す - normalized_freq = pitch / 24000.0 - # 初期位相をランダムに設定 + # 正規化角周波数 [2π rad] + normalized_freq = pitch / sr + # 初期位相 [2π rad] をランダムに設定 normalized_freq[:, 0] = torch.rand(batch_size, device=pitch.device) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): phase = (normalized_freq.double().cumsum_(1) % 1.0).float() # 重ねる箇所を求める # [n_pitchmarks], [n_pitchmarks] @@ -883,31 +1008,27 @@ def overlap_add( # [n_pitchmarks] fractional_part = numer / (numer + phase[indices0, indices1 + 1]) # 重ねる値を求める - # [n_pitchmarks, ir_length] - values = ir[indices0, :, indices1 // hop_length] - # 位相を遅らせる - # values が時間領域と仮定 # Complex[n_pitchmarks, ir_length / 2 + 1] - values = torch.fft.rfft(values, n=ir_length, dim=1) - # 位相遅れの量 + ir_amp = ir_amp[indices0, :, indices1 // hop_length] + ir_phase = 
ir_phase[indices0, :, indices1 // hop_length] + # 位相遅れの量 [rad] # [n_pitchmarks, ir_length / 2 + 1] delay_phase = ( torch.arange(ir_length // 2 + 1, device=pitch.device, dtype=torch.float32)[ None, : ] - / -ir_length + * (-math.tau / ir_length) * fractional_part[:, None] ) # Complex[n_pitchmarks, ir_length / 2 + 1] - delay_phase = torch.polar(torch.ones_like(delay_phase), delay_phase * math.tau) - # values *= delay_phase - values = values * delay_phase + spec = torch.polar(ir_amp, ir_phase + delay_phase) # [n_pitchmarks, ir_length] - values = torch.fft.irfft(values, n=ir_length, dim=1) + ir = torch.fft.irfft(spec, n=ir_length, dim=1) + ir *= window # 加算する値をサンプル単位にばらす # [n_pitchmarks * ir_length] - values = values.ravel() + ir = ir.ravel() # Long[n_pitchmarks * ir_length] indices0 = indices0[:, None].expand(-1, ir_length).ravel() # Long[n_pitchmarks * ir_length] @@ -919,15 +1040,15 @@ def overlap_add( overlap_added_signal = torch.zeros( (batch_size, length * hop_length + ir_length), device=pitch.device ) - # print("overlap_added_signal, values: ", overlap_added_signal.dtype, values.dtype) - overlap_added_signal.index_put_((indices0, indices1), values, accumulate=True) + overlap_added_signal.index_put_((indices0, indices1), ir, accumulate=True) overlap_added_signal = overlap_added_signal[:, delay : -ir_length + delay] - # sinc 重ねたものと ir を畳み込んだ方が FFT の回数減らせた気がする return overlap_added_signal -def generate_noise(aperiodicity: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def generate_noise( + aperiodicity: torch.Tensor, delay: int = 0 +) -> tuple[torch.Tensor, torch.Tensor]: # aperiodicity: [batch_size, hop_length, length] batch_size, hop_length, length = aperiodicity.size() excitation = torch.rand( @@ -945,10 +1066,7 @@ def generate_noise(aperiodicity: torch.Tensor) -> tuple[torch.Tensor, torch.Tens center=False, return_complex=True, ) - assert noise.size(2) == aperiodicity.size(2), ( - noise.size(), - aperiodicity.size(), - ) + assert noise.size(2) == aperiodicity.size(2) noise[:, 0, :] = 0.0 noise[:, 1:, :] *= aperiodicity # ハン窓で合成 @@ -963,8 +1081,10 @@ def generate_noise(aperiodicity: torch.Tensor) -> tuple[torch.Tensor, torch.Tens (1, 2 * hop_length), stride=(1, hop_length), ).squeeze_((1, 2)) - noise = noise[:, hop_length // 2 : -hop_length // 2] - excitation = excitation[:, hop_length // 2 : -hop_length // 2] + + assert delay < hop_length + noise = noise[:, delay : -hop_length + delay] + excitation = excitation[:, delay : -hop_length + delay] return noise, excitation # [batch_size, length * hop_length] @@ -986,15 +1106,400 @@ class GradientEqualizerFunction(torch.autograd.Function): return dx -class PseudoDDSPVocoder(nn.Module): +D4C_PREVENT_ZERO_DIVISION = True # False にすると本家の処理 + + +def interp(x: torch.Tensor, y: torch.Tensor, xi: torch.Tensor) -> torch.Tensor: + # x が単調増加で等間隔と仮定 + # 外挿は起こらないと仮定 + x = torch.as_tensor(x) + y = torch.as_tensor(y) + xi = torch.as_tensor(xi) + if xi.ndim < y.ndim: + diff_ndim = y.ndim - xi.ndim + xi = xi.view(tuple([1] * diff_ndim) + xi.size()) + if xi.size()[:-1] != y.size()[:-1]: + xi = xi.expand(y.size()[:-1] + (xi.size(-1),)) + assert (x.min(-1).values == x[..., 0]).all() + assert (x.max(-1).values == x[..., -1]).all() + assert (xi.min(-1).values >= x[..., 0]).all() + assert (xi.max(-1).values <= x[..., -1]).all() + delta_x = (x[..., -1].double() - x[..., 0].double()) / (x.size(-1) - 1.0) + delta_x = delta_x.to(x.dtype) + xi = (xi - x[..., :1]).div_(delta_x[..., None]) + xi_base = xi.floor() + xi_fraction = xi.sub_(xi_base) + xi_base = 
xi_base.long() + delta_y = y.diff(dim=-1, append=y[..., -1:]) + yi = y.gather(-1, xi_base) + delta_y.gather(-1, xi_base) * xi_fraction + return yi + + +def linear_smoothing( + group_delay: torch.Tensor, sr: int, n_fft: int, width: torch.Tensor +) -> torch.Tensor: + group_delay = torch.as_tensor(group_delay) + assert group_delay.size(-1) == n_fft // 2 + 1 + width = torch.as_tensor(width) + boundary = (width.max() * n_fft / sr).long() + 1 + + dtype = group_delay.dtype + device = group_delay.device + fft_resolution = sr / n_fft + mirroring_freq_axis = ( + torch.arange(-boundary, n_fft // 2 + 1 + boundary, dtype=dtype, device=device) + .add(0.5) + .mul(fft_resolution) + ) + if group_delay.ndim == 1: + mirroring_spec = F.pad( + group_delay[None], (boundary, boundary), mode="reflect" + ).squeeze_(0) + elif group_delay.ndim >= 4: + shape = group_delay.size() + mirroring_spec = F.pad( + group_delay.view(math.prod(shape[:-1]), group_delay.size(-1)), + (boundary, boundary), + mode="reflect", + ).view(shape[:-1] + (shape[-1] + 2 * boundary,)) + else: + mirroring_spec = F.pad(group_delay, (boundary, boundary), mode="reflect") + mirroring_segment = mirroring_spec.mul(fft_resolution).cumsum_(-1) + center_freq = torch.arange(n_fft // 2 + 1, dtype=dtype, device=device).mul_( + fft_resolution + ) + low_freq = center_freq - width[..., None] * 0.5 + high_freq = center_freq + width[..., None] * 0.5 + levels = interp( + mirroring_freq_axis, mirroring_segment, torch.cat([low_freq, high_freq], dim=-1) + ) + low_levels, high_levels = levels.split([n_fft // 2 + 1] * 2, dim=-1) + smoothed = (high_levels - low_levels).div_(width[..., None]) + return smoothed + + +def dc_correction( + spec: torch.Tensor, sr: int, n_fft: int, f0: torch.Tensor +) -> torch.Tensor: + spec = torch.as_tensor(spec) + f0 = torch.as_tensor(f0) + dtype = spec.dtype + device = spec.device + + upper_limit = 2 + (f0 * (n_fft / sr)).long() + max_upper_limit = upper_limit.max() + upper_limit_mask = ( + torch.arange(max_upper_limit - 1, device=device) < (upper_limit - 1)[..., None] + ) + low_freq_axis = torch.arange(max_upper_limit + 1, dtype=dtype, device=device) * ( + sr / n_fft + ) + low_freq_replica = interp( + f0[..., None] - low_freq_axis.flip(-1), + spec[..., : max_upper_limit + 1].flip(-1), + low_freq_axis[..., : max_upper_limit - 1] * upper_limit_mask, + ) + output = spec.clone() + output[..., : max_upper_limit - 1] += low_freq_replica * upper_limit_mask + return output + + +def nuttall(n: int, device: torch.types.Device) -> torch.Tensor: + t = torch.linspace(0, math.tau, n, device=device) + coefs = torch.tensor([0.355768, -0.487396, 0.144232, -0.012604], device=device) + terms = torch.tensor([0.0, 1.0, 2.0, 3.0], device=device) + cos_matrix = (terms[:, None] * t).cos_() # [4, n] + window = coefs.matmul(cos_matrix) + return window + + +def get_windowed_waveform( + x: torch.Tensor, + sr: int, + f0: torch.Tensor, + position: torch.Tensor, + half_window_length_ratio: float, + window_type: Literal["hann", "blackman"], + n_fft: int, +) -> tuple[torch.Tensor, torch.Tensor]: + x = torch.as_tensor(x) + f0 = torch.as_tensor(f0) + position = torch.as_tensor(position) + + current_sample = position * sr + # [...] 
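+    # The analysis window below is pitch-synchronous: half_window_length is
+    # half_window_length_ratio fundamental periods (sr / f0 samples each) to
+    # either side of `position`, and the window itself is synthesized from
+    # cosines of the normalized phase (0.5 + 0.5*cos(phase) gives a Hann
+    # window, 0.42 + 0.5*cos(phase) + 0.08*cos(2*phase) a Blackman window).
+    # Indices past +half_window_length are zeroed via base_index_mask.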
+ half_window_length = (half_window_length_ratio * sr / f0).add_(0.5).long() + # [..., fft_size] + base_index = -half_window_length[..., None] + torch.arange(n_fft, device=x.device) + base_index_mask = base_index <= half_window_length[..., None] + # [..., fft_size] + safe_index = ((current_sample + 0.501).long()[..., None] + base_index).clamp_( + 0, x.size(-1) - 1 + ) + # [..., fft_size] + time_axis = base_index.to(x.dtype).div_(half_window_length_ratio) + # [...] + normalized_f0 = math.pi / sr * f0 + # [..., fft_size] + phase = time_axis.mul_(normalized_f0[..., None]) + + if window_type == "hann": + window = phase.cos_().mul_(0.5).add_(0.5) + elif window_type == "blackman": + window = phase.mul(2.0).cos_().mul_(0.08).add_(phase.cos().mul_(0.5)).add_(0.42) + else: + assert False + window *= base_index_mask + + prefix_shape = tuple( + max(x_size, i_size) for x_size, i_size in zip(x.size(), safe_index.size()) + )[:-1] + waveform = ( + x.expand(prefix_shape + (-1,)) + .gather(-1, safe_index.expand(prefix_shape + (-1,))) + .mul_(window) + ) + if not D4C_PREVENT_ZERO_DIVISION: + waveform += torch.randn_like(window).mul_(1e-12) + waveform *= base_index_mask + waveform -= window * waveform.sum(-1, keepdim=True).div_( + window.sum(-1, keepdim=True) + ) + return waveform, window + + +def get_centroid(x: torch.Tensor, n_fft: int) -> torch.Tensor: + x = torch.as_tensor(x) + if D4C_PREVENT_ZERO_DIVISION: + x = x / x.norm(dim=-1, keepdim=True).clamp(min=6e-8) + else: + x = x / x.norm(dim=-1, keepdim=True) + spec0 = torch.fft.rfft(x, n_fft) + spec1 = torch.fft.rfft( + x * torch.arange(1, x.size(-1) + 1, dtype=x.dtype, device=x.device).div_(n_fft), + n_fft, + ) + centroid = spec0.real * spec1.real + spec0.imag * spec1.imag + return centroid + + +def get_static_centroid( + x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, n_fft: int +) -> torch.Tensor: + """First step: calculation of temporally static parameters on basis of group delay""" + x1, _ = get_windowed_waveform( + x, sr, f0, position + 0.25 / f0, 2.0, "blackman", n_fft + ) + x2, _ = get_windowed_waveform( + x, sr, f0, position - 0.25 / f0, 2.0, "blackman", n_fft + ) + centroid1 = get_centroid(x1, n_fft) + centroid2 = get_centroid(x2, n_fft) + return dc_correction(centroid1 + centroid2, sr, n_fft, f0) + + +def get_smoothed_power_spec( + x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, n_fft: int +) -> tuple[torch.Tensor, torch.Tensor]: + x = torch.as_tensor(x) + f0 = torch.as_tensor(f0) + x, window = get_windowed_waveform(x, sr, f0, position, 2.0, "hann", n_fft) + window_weight = window.square().sum(-1, keepdim=True) + rms = x.square().sum(-1, keepdim=True).div_(window_weight).sqrt_() + if D4C_PREVENT_ZERO_DIVISION: + x = x / (rms * math.sqrt(n_fft)).clamp_(min=6e-8) + smoothed_power_spec = torch.fft.rfft(x, n_fft).abs().square_() + smoothed_power_spec = dc_correction(smoothed_power_spec, sr, n_fft, f0) + smoothed_power_spec = linear_smoothing(smoothed_power_spec, sr, n_fft, f0) + return smoothed_power_spec, rms.detach().squeeze(-1) + + +def get_static_group_delay( + static_centroid: torch.Tensor, + smoothed_power_spec: torch.Tensor, + sr: int, + f0: torch.Tensor, + n_fft: int, +) -> torch.Tensor: + """Second step: calculation of parameter shaping""" + if D4C_PREVENT_ZERO_DIVISION: + smoothed_power_spec = smoothed_power_spec.clamp(min=6e-8) + static_group_delay = static_centroid / smoothed_power_spec # t_g + static_group_delay = linear_smoothing( + static_group_delay, sr, n_fft, f0 * 0.5 + ) # t_gs + 
smoothed_group_delay = linear_smoothing(static_group_delay, sr, n_fft, f0) # t_gb + static_group_delay = static_group_delay - smoothed_group_delay # t_D + return static_group_delay + + +def get_coarse_aperiodicity( + group_delay: torch.Tensor, + sr: int, + n_fft: int, + freq_interval: int, + n_aperiodicities: int, + window: torch.Tensor, +) -> torch.Tensor: + """Third step: estimation of band-aperiodicity""" + group_delay = torch.as_tensor(group_delay) + window = torch.as_tensor(window) + boundary = int(round(n_fft * 8 / window.size(-1))) + half_window_length = window.size(-1) // 2 + coarse_aperiodicity = torch.empty( + group_delay.size()[:-1] + (n_aperiodicities,), + dtype=group_delay.dtype, + device=group_delay.device, + ) + for i in range(n_aperiodicities): + center = freq_interval * (i + 1) * n_fft // sr + segment = ( + group_delay[ + ..., center - half_window_length : center + half_window_length + 1 + ] + * window + ) + power_spec: torch.Tensor = torch.fft.rfft(segment, n_fft).abs().square_() + cumulative_power_spec = power_spec.sort(-1).values.cumsum_(-1) + if D4C_PREVENT_ZERO_DIVISION: + cumulative_power_spec.clamp_(min=6e-8) + coarse_aperiodicity[..., i] = ( + cumulative_power_spec[..., n_fft // 2 - boundary - 1] + / cumulative_power_spec[..., -1] + ) + coarse_aperiodicity.log10_().mul_(10.0) + return coarse_aperiodicity + + +def d4c_love_train( + x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, threshold: float +) -> int: + x = torch.as_tensor(x) + position = torch.as_tensor(position) + f0: torch.Tensor = torch.as_tensor(f0) + vuv = f0 != 0 + lowest_f0 = 40 + f0 = f0.clamp(min=lowest_f0) + n_fft = 1 << (3 * sr // lowest_f0).bit_length() + boundary0 = (100 * n_fft - 1) // sr + 1 + boundary1 = (4000 * n_fft - 1) // sr + 1 + boundary2 = (7900 * n_fft - 1) // sr + 1 + waveform, _ = get_windowed_waveform(x, sr, f0, position, 1.5, "blackman", n_fft) + power_spec = torch.fft.rfft(waveform, n_fft).abs().square_() + power_spec[..., : boundary0 + 1] = 0.0 + cumulative_spec = power_spec.cumsum_(-1) + vuv = vuv & ( + cumulative_spec[..., boundary1] > threshold * cumulative_spec[..., boundary2] + ) + return vuv + + +def d4c_general_body( + x: torch.Tensor, + sr: int, + f0: torch.Tensor, + freq_interval: int, + position: torch.Tensor, + n_fft: int, + n_aperiodicities: int, + window: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + static_centroid = get_static_centroid(x, sr, f0, position, n_fft) + smoothed_power_spec, rms = get_smoothed_power_spec(x, sr, f0, position, n_fft) + static_group_delay = get_static_group_delay( + static_centroid, smoothed_power_spec, sr, f0, n_fft + ) + coarse_aperiodicity = get_coarse_aperiodicity( + static_group_delay, sr, n_fft, freq_interval, n_aperiodicities, window + ) + coarse_aperiodicity.add_((f0[..., None] - 100.0).div_(50.0)).clamp_(max=0.0) + return coarse_aperiodicity, rms + + +def d4c( + x: torch.Tensor, + f0: torch.Tensor, + t: torch.Tensor, + sr: int, + threshold: float = 0.85, + n_fft_spec: Optional[int] = None, + coarse_only: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """Adapted from https://github.com/tuanad121/Python-WORLD/blob/master/world/d4c.py""" + FLOOR_F0 = 71 + FLOOR_F0_D4C = 47 + UPPER_LIMIT = 15000 + FREQ_INTERVAL = 3000 + + assert sr == int(sr) + sr = int(sr) + assert sr % 2 == 0 + x = torch.as_tensor(x) + f0 = torch.as_tensor(f0) + temporal_positions = torch.as_tensor(t) + + n_fft_d4c = 1 << (4 * sr // FLOOR_F0_D4C).bit_length() + if n_fft_spec is None: + n_fft_spec = 1 << (3 * sr // 
FLOOR_F0).bit_length() + n_aperiodicities = min(UPPER_LIMIT, sr // 2 - FREQ_INTERVAL) // FREQ_INTERVAL + assert n_aperiodicities >= 1 + window_length = FREQ_INTERVAL * n_fft_d4c // sr * 2 + 1 + window = nuttall(window_length, device=x.device) + freq_axis = torch.arange(n_fft_spec // 2 + 1, device=x.device) * (sr / n_fft_spec) + + coarse_aperiodicity, rms = d4c_general_body( + x[..., None, :], + sr, + f0.clamp(min=FLOOR_F0_D4C), + FREQ_INTERVAL, + temporal_positions, + n_fft_d4c, + n_aperiodicities, + window, + ) + if coarse_only: + return coarse_aperiodicity, rms + + even_coarse_axis = ( + torch.arange(n_aperiodicities + 3, device=x.device) * FREQ_INTERVAL + ) + assert even_coarse_axis[-2] <= sr // 2 < even_coarse_axis[-1], sr + coarse_axis_low = ( + torch.arange(n_aperiodicities + 1, dtype=torch.float, device=x.device) + * FREQ_INTERVAL + ) + aperiodicity_low = interp( + coarse_axis_low, + F.pad(coarse_aperiodicity, (1, 0), value=-60.0), + freq_axis[freq_axis < n_aperiodicities * FREQ_INTERVAL], + ) + coarse_axis_high = torch.tensor( + [n_aperiodicities * FREQ_INTERVAL, sr * 0.5], device=x.device + ) + aperiodicity_high = interp( + coarse_axis_high, + F.pad(coarse_aperiodicity[..., -1:], (0, 1), value=-1e-12), + freq_axis[freq_axis >= n_aperiodicities * FREQ_INTERVAL], + ) + aperiodicity = torch.cat([aperiodicity_low, aperiodicity_high], -1) + aperiodicity = 10.0 ** (aperiodicity / 20.0) + vuv = d4c_love_train(x[..., None, :], sr, f0, temporal_positions, threshold) + aperiodicity = torch.where(vuv[..., None], aperiodicity, 1 - 1e-12) + + return aperiodicity, coarse_aperiodicity + + +class Vocoder(nn.Module): def __init__( self, channels: int, hop_length: int = 240, n_pre_blocks: int = 4, + out_sample_rate: float = 24000.0, ): super().__init__() self.hop_length = hop_length + self.out_sample_rate = out_sample_rate self.prenet = ConvNeXtStack( in_channels=channels, @@ -1004,6 +1509,7 @@ class PseudoDDSPVocoder(nn.Module): delay=2, # 20ms 遅延 embed_kernel_size=7, kernel_size=33, + enable_scaling=True, ) self.ir_generator = ConvNeXtStack( in_channels=channels, @@ -1013,48 +1519,126 @@ class PseudoDDSPVocoder(nn.Module): delay=0, embed_kernel_size=3, kernel_size=33, - use_weight_norm=True, + use_weight_standardization=True, + enable_scaling=True, ) - self.ir_generator_post = weight_norm(nn.Conv1d(channels, 512, 1, bias=False)) + self.ir_generator_post = WSConv1d(channels, 512, 1) + self.register_buffer("ir_scale", torch.tensor(1.0)) + self.ir_window = nn.Parameter(torch.ones(512)) self.aperiodicity_generator = ConvNeXtStack( in_channels=channels, channels=channels, intermediate_channels=channels * 3, - n_blocks=2, + n_blocks=1, delay=0, embed_kernel_size=3, kernel_size=33, - use_weight_norm=True, + use_weight_standardization=True, + enable_scaling=True, ) - self.aperiodicity_generator_post = weight_norm( - nn.Conv1d(channels, hop_length, 1, bias=False) + self.aperiodicity_generator_post = WSConv1d(channels, hop_length, 1, bias=False) + self.register_buffer("aperiodicity_scale", torch.tensor(0.005)) + self.post_filter_generator = ConvNeXtStack( + in_channels=channels, + channels=channels, + intermediate_channels=channels * 3, + n_blocks=1, + delay=0, + embed_kernel_size=3, + kernel_size=33, + use_weight_standardization=True, + enable_scaling=True, ) + self.post_filter_generator_post = WSConv1d(channels, 512, 1, bias=False) + self.register_buffer("post_filter_scale", torch.tensor(0.01)) def forward( self, x: torch.Tensor, pitch: torch.Tensor ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: # x: 
[batch_size, channels, length] # pitch: [batch_size, length] + batch_size, _, length = x.size() x = self.prenet(x) ir = self.ir_generator(x) - ir = F.elu(ir, inplace=True) + ir = F.silu(ir, inplace=True) # [batch_size, 512, length] ir = self.ir_generator_post(ir) + ir *= self.ir_scale + ir_amp = ir[:, : ir.size(1) // 2 + 1, :].exp() + ir_phase = F.pad(ir[:, ir.size(1) // 2 + 1 :, :], (0, 0, 1, 1)) + ir_phase[:, 1::2, :] += math.pi + # TODO: 直流成分が正の値しか取れないのを修正する # 最近傍補間 # [batch_size, length * hop_length] pitch = torch.repeat_interleave(pitch, self.hop_length, dim=1) # [batch_size, length * hop_length] - periodic_signal = overlap_add(ir, pitch, self.hop_length, delay=120) + periodic_signal = overlap_add( + ir_amp, + ir_phase, + self.ir_window, + pitch, + self.hop_length, + delay=0, + sr=self.out_sample_rate, + ) aperiodicity = self.aperiodicity_generator(x) - aperiodicity = F.elu(aperiodicity, inplace=True) + aperiodicity = F.silu(aperiodicity, inplace=True) # [batch_size, hop_length, length] aperiodicity = self.aperiodicity_generator_post(aperiodicity) + aperiodicity *= self.aperiodicity_scale # [batch_size, length * hop_length], [batch_size, length * hop_length] - aperiodic_signal, noise_excitation = generate_noise(aperiodicity) + aperiodic_signal, noise_excitation = generate_noise(aperiodicity, delay=0) + + post_filter = self.post_filter_generator(x) + post_filter = F.silu(post_filter, inplace=True) + # [batch_size, 512, length] + post_filter = self.post_filter_generator_post(post_filter) + post_filter *= self.post_filter_scale + post_filter[:, 0, :] += 1.0 + # [batch_size, length, 512] + post_filter = post_filter.transpose(1, 2) + with torch.amp.autocast("cuda", enabled=False): + periodic_signal = periodic_signal.float() + aperiodic_signal = aperiodic_signal.float() + post_filter = post_filter.float() + post_filter = torch.fft.rfft(post_filter, n=768) + + # [batch_size, length, 768] + periodic_signal = torch.fft.irfft( + torch.fft.rfft( + periodic_signal.view(batch_size, length, self.hop_length), n=768 + ) + * post_filter, + n=768, + ) + aperiodic_signal = torch.fft.irfft( + torch.fft.rfft( + aperiodic_signal.view(batch_size, length, self.hop_length), n=768 + ) + * post_filter, + n=768, + ) + periodic_signal = F.fold( + periodic_signal.transpose(1, 2), + (1, (length - 1) * self.hop_length + 768), + (1, 768), + stride=(1, self.hop_length), + ).squeeze_((1, 2)) + aperiodic_signal = F.fold( + aperiodic_signal.transpose(1, 2), + (1, (length - 1) * self.hop_length + 768), + (1, 768), + stride=(1, self.hop_length), + ).squeeze_((1, 2)) + periodic_signal = periodic_signal[:, 120 : 120 + length * self.hop_length] + aperiodic_signal = aperiodic_signal[:, 120 : 120 + length * self.hop_length] + noise_excitation = noise_excitation[:, 120:] + + # TODO: compensation の正確さが怪しくなってくる。今も本当に必要なのか? 
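+        # The block above applies the learned post filter in the frequency
+        # domain: each hop_length(=240)-sample frame is convolved with the
+        # 512-tap filter via length-768 FFTs (240 + 512 - 1 = 751 <= 768, so
+        # the circular convolution equals the linear one), and the frames are
+        # overlap-added back with F.fold. The [:, 120:...] slices re-apply the
+        # hop_length // 2 delay that overlap_add / generate_noise previously
+        # applied themselves (both are now called with delay=0).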
# [batch_size, 1, length * hop_length] y_g_hat = (periodic_signal + aperiodic_signal)[:, None, :] @@ -1067,17 +1651,21 @@ class PseudoDDSPVocoder(nn.Module): "noise_excitation": noise_excitation.detach(), } - def remove_weight_norm(self): - self.prenet.remove_weight_norm() - self.ir_generator.remove_weight_norm() - remove_weight_norm(self.ir_generator_post) - self.aperiodicity_generator.remove_weight_norm() - remove_weight_norm(self.aperiodicity_generator_post) - def merge_weights(self): self.prenet.merge_weights() self.ir_generator.merge_weights() + self.ir_generator_post.merge_weights() self.aperiodicity_generator.merge_weights() + self.aperiodicity_generator_post.merge_weights() + self.ir_generator_post.weight.data *= self.ir_scale + self.ir_generator_post.bias.data *= self.ir_scale + self.ir_scale.fill_(1.0) + self.aperiodicity_generator_post.weight.data *= self.aperiodicity_scale + self.aperiodicity_scale.fill_(1.0) + self.post_filter_generator.merge_weights() + self.post_filter_generator_post.merge_weights() + self.post_filter_generator_post.weight.data *= self.post_filter_scale + self.post_filter_scale.fill_(1.0) def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): @@ -1090,8 +1678,68 @@ class PseudoDDSPVocoder(nn.Module): dump_layer(self.prenet, f) dump_layer(self.ir_generator, f) dump_layer(self.ir_generator_post, f) + dump_layer(self.ir_window, f) dump_layer(self.aperiodicity_generator, f) dump_layer(self.aperiodicity_generator_post, f) + dump_layer(self.post_filter_generator, f) + dump_layer(self.post_filter_generator_post, f) + + +def compute_loudness( + x: torch.Tensor, sr: int, win_lengths: list[int] +) -> list[torch.Tensor]: + # x: [batch_size, wav_length] + assert x.ndim == 2 + n_fft = 2048 + chunk_length = n_fft // 2 + n_taps = chunk_length + 1 + + results = [] + with torch.amp.autocast("cuda", enabled=False): + if not hasattr(compute_loudness, "filter"): + compute_loudness.filter = {} + if sr not in compute_loudness.filter: + ir = torch.zeros(n_taps, device=x.device, dtype=torch.double) + ir[0] = 0.5 + ir = torchaudio.functional.treble_biquad( + ir, sr, 4.0, 1500.0, 1.0 / math.sqrt(2) + ) + ir = torchaudio.functional.highpass_biquad(ir, sr, 38.0, 0.5) + ir *= 2.0 + compute_loudness.filter[sr] = torch.fft.rfft(ir, n=n_fft).to( + torch.complex64 + ) + + x = x.float() + wav_length = x.size(-1) + if wav_length % chunk_length != 0: + x = F.pad(x, (0, chunk_length - wav_length % chunk_length)) + padded_wav_length = x.size(-1) + x = x.view(x.size()[:-1] + (padded_wav_length // chunk_length, chunk_length)) + x = torch.fft.irfft( + torch.fft.rfft(x, n=n_fft) * compute_loudness.filter[sr], + n=n_fft, + ) + x = F.fold( + x.transpose(-2, -1), + (1, padded_wav_length + chunk_length), + (1, n_fft), + stride=(1, chunk_length), + ).squeeze_((-3, -2))[..., :wav_length] + + x.square_() + for win_length in win_lengths: + hop_length = win_length // 4 + # [..., n_frames] + energy = ( + x.unfold(-1, win_length, hop_length) + .matmul(torch.hann_window(win_length, device=x.device)) + .add_(win_length / 4.0 * 1e-5) + .log10_() + ) + # フィルタリング後の波形が振幅 1 の正弦波なら大体 log10(win_length/4), 1 の差は 10dB の差 + results.append(energy) + return results def slice_segments( @@ -1120,7 +1768,10 @@ class ConverterNetwork(nn.Module): "phone_extractor": phone_extractor.eval().requires_grad_(False), "pitch_estimator": pitch_estimator.eval().requires_grad_(False), } + self.out_sample_rate = out_sample_rate = 24000 self.embed_phone = nn.Conv1d(256, hidden_channels, 1) + 
self.embed_phone.weight.data.normal_(0.0, math.sqrt(2.0 / (256 * 5))) + self.embed_phone.bias.data.zero_() self.embed_quantized_pitch = nn.Embedding(384, hidden_channels) phase = ( torch.arange(384, dtype=torch.float)[:, None] @@ -1131,25 +1782,43 @@ class ConverterNetwork(nn.Module): ) self.embed_quantized_pitch.weight.data[:, 0::2] = phase.sin() self.embed_quantized_pitch.weight.data[:, 1::2] = phase.cos_() + self.embed_quantized_pitch.weight.data *= math.sqrt(4.0 / 5.0) self.embed_quantized_pitch.weight.requires_grad_(False) self.embed_pitch_features = nn.Conv1d(4, hidden_channels, 1) + self.embed_pitch_features.weight.data.normal_(0.0, math.sqrt(2.0 / (4 * 5))) + self.embed_pitch_features.bias.data.zero_() self.embed_speaker = nn.Embedding(n_speakers, hidden_channels) + self.embed_speaker.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0)) self.embed_formant_shift = nn.Embedding(9, hidden_channels) - self.vocoder = PseudoDDSPVocoder( + self.embed_formant_shift.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0)) + self.vocoder = Vocoder( channels=hidden_channels, - hop_length=240, + hop_length=out_sample_rate // 100, n_pre_blocks=4, + out_sample_rate=out_sample_rate, ) - self.melspectrogram = torchaudio.transforms.MelSpectrogram( - sample_rate=24000, - n_fft=1024, - win_length=720, - hop_length=128, - n_mels=80, - power=2, # 不安定さの原因になっているかも - norm="slaney", - mel_scale="slaney", - ) + self.melspectrograms = nn.ModuleList() + for win_length, n_mels in [ + (32, 5), + (64, 10), + (128, 20), + (256, 40), + (512, 80), + (1024, 160), + (2048, 320), + ]: + self.melspectrograms.append( + torchaudio.transforms.MelSpectrogram( + sample_rate=out_sample_rate, + n_fft=win_length, + win_length=win_length, + hop_length=win_length // 4, + n_mels=n_mels, + power=2, + norm="slaney", + mel_scale="slaney", + ) + ) def _get_resampler( self, orig_freq, new_freq, device, cache={} @@ -1157,7 +1826,9 @@ class ConverterNetwork(nn.Module): key = orig_freq, new_freq if key in cache: return cache[key] - resampler = torchaudio.transforms.Resample(orig_freq, new_freq).to(device) + resampler = torchaudio.transforms.Resample(orig_freq, new_freq).to( + device, non_blocking=True + ) cache[key] = resampler return resampler @@ -1207,6 +1878,7 @@ class ConverterNetwork(nn.Module): ) concatenated_shifted_x = [] offsets = [0] + torch.backends.cudnn.benchmark = False for i in range(batch_size): shift_ratio_i = shift_ratio[i] shift_ratio_fraction_i = Fraction.from_float( @@ -1223,7 +1895,7 @@ class ConverterNetwork(nn.Module): shift.append(shift_i) shift_ratio[i] = shift_ratio_i # [1, 1, wav_length / shift_ratio] - with torch.cuda.amp.autocast(False): + with torch.amp.autocast("cuda", enabled=False): shifted_x_i = self._get_resampler( shift_numer_i, shift_denom_i, x.device )(x[i])[None] @@ -1289,6 +1961,7 @@ class ConverterNetwork(nn.Module): :, 1:shift_i, :length ] energy[i : i + 1, :, :length] = energy_i[:, :, :length] + torch.backends.cudnn.benchmark = True # [batch_size, pitch_channels, length] -> Long[batch_size, length], [batch_size, 3, length] quantized_pitch, pitch_features = pitch_estimator.sample_pitch( @@ -1302,7 +1975,7 @@ class ConverterNetwork(nn.Module): quantized_pitch + ( pitch_shift_semitone[:, None] - * (pitch_estimator.bins_per_octave / 12) + * (pitch_estimator.bins_per_octave / 12.0) ) .round_() .long() @@ -1346,13 +2019,14 @@ class ConverterNetwork(nn.Module): x = F.silu(x, inplace=True) # [batch_size, hidden_channels, segment_length] -> [batch_size, 1, segment_length * 240] y_g_hat, stats = self.vocoder(x, pitch) + 
stats["pitch"] = pitch if return_stats: return y_g_hat, stats else: return y_g_hat def _normalize_melsp(self, x): - return x.log().mul(0.5).clamp_(min=math.log(1e-5)) + return x.clamp(min=1e-10).log_().mul_(0.5) def forward_and_compute_loss( self, @@ -1362,6 +2036,7 @@ class ConverterNetwork(nn.Module): slice_start_indices: torch.Tensor, slice_segment_length: int, y_all: torch.Tensor, + enable_loss_ap: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # noisy_wavs_16k: [batch_size, 1, wav_length] # target_speaker_id: Long[batch_size] @@ -1370,63 +2045,99 @@ class ConverterNetwork(nn.Module): # slice_segment_length: int # y_all: [batch_size, 1, wav_length] + stats = {} + loss_mel = 0.0 + # [batch_size, 1, wav_length] -> [batch_size, 1, wav_length * 240] - y_hat_all, stats = self( + y_hat_all, intermediates = self( noisy_wavs_16k, target_speaker_id, formant_shift_semitone, return_stats=True, ) - with torch.cuda.amp.autocast(False): - melsp_periodic_signal = self.melspectrogram( - stats["periodic_signal"].float() - ) - melsp_aperiodic_signal = self.melspectrogram( - stats["aperiodic_signal"].float() - ) - melsp_noise_excitation = self.melspectrogram( - stats["noise_excitation"].float() - ) - # [1, n_mels, 1] - # 1/6 ... [-0.5, 0.5] の一様乱数の平均パワー - # 3/8 ... ハン窓をかけた時のパワー減衰 - # 0.5 ... 謎 - reference_melsp = self.melspectrogram.mel_scale( - torch.full( - (1, self.melspectrogram.n_fft // 2 + 1, 1), - (1 / 6) * (3 / 8) * 0.5 * self.melspectrogram.win_length, - device=noisy_wavs_16k.device, + with torch.amp.autocast("cuda", enabled=False): + periodic_signal = intermediates["periodic_signal"].float() + aperiodic_signal = intermediates["aperiodic_signal"].float() + noise_excitation = intermediates["noise_excitation"].float() + periodic_signal = periodic_signal[:, : noise_excitation.size(1)] + aperiodic_signal = aperiodic_signal[:, : noise_excitation.size(1)] + y_hat_all = y_hat_all.float() + y_hat_all_truncated = y_hat_all.squeeze(1)[:, : periodic_signal.size(1)] + y_all_truncated = y_all.squeeze(1)[:, : periodic_signal.size(1)] + + for melspectrogram in self.melspectrograms: + melsp_periodic_signal = melspectrogram(periodic_signal) + melsp_aperiodic_signal = melspectrogram(aperiodic_signal) + melsp_noise_excitation = melspectrogram(noise_excitation) + # [1, n_mels, 1] + # 1/6 ... [-0.5, 0.5] の一様乱数の平均パワー + # 3/8 ... ハン窓をかけた時のパワー減衰 + # 0.5 ... 
謎 + reference_melsp = melspectrogram.mel_scale( + torch.full( + (1, melspectrogram.n_fft // 2 + 1, 1), + (1 / 6) * (3 / 8) * 0.5 * melspectrogram.win_length, + device=noisy_wavs_16k.device, + ) ) - ) - aperiodic_ratio = melsp_aperiodic_signal / ( - melsp_periodic_signal + melsp_aperiodic_signal + 1e-5 - ) - compensation_ratio = reference_melsp / (melsp_noise_excitation + 1e-5) - - melsp_y_hat = self.melspectrogram(y_hat_all.float().squeeze(1)) - melsp_y_hat = melsp_y_hat * ( - (1.0 - aperiodic_ratio) + aperiodic_ratio * compensation_ratio - ) + aperiodic_ratio = melsp_aperiodic_signal / ( + melsp_periodic_signal + melsp_aperiodic_signal + 1e-5 + ) + compensation_ratio = reference_melsp / (melsp_noise_excitation + 1e-5) - y_hat_mel = self._normalize_melsp(melsp_y_hat) - # [batch_size, 1, wav_length] -> [batch_size, 1, wav_length * 240] - y_hat = slice_segments( - y_hat_all, slice_start_indices * 240, slice_segment_length * 240 - ) + melsp_y_hat = melspectrogram(y_hat_all_truncated) + melsp_y_hat = melsp_y_hat * ( + (1.0 - aperiodic_ratio) + aperiodic_ratio * compensation_ratio + ) + y_hat_mel = self._normalize_melsp(melsp_y_hat) - y_mel = self._normalize_melsp(self.melspectrogram(y_all.squeeze(1))) - # [batch_size, 1, wav_length] -> [batch_size, 1, wav_length * 240] - y = slice_segments( - y_all, slice_start_indices * 240, slice_segment_length * 240 - ) + y_mel = self._normalize_melsp(melspectrogram(y_all_truncated)) + loss_mel_i = F.l1_loss(y_hat_mel, y_mel) + loss_mel += loss_mel_i + stats[ + f"loss_mel_{melspectrogram.win_length}_{melspectrogram.n_mels}" + ] = loss_mel_i.item() - loss_mel = F.l1_loss(y_hat_mel, y_mel) + loss_mel /= len(self.melspectrograms) - return y, y_hat, y_hat_all, loss_mel + if enable_loss_ap: + t = ( + torch.arange(intermediates["pitch"].size(1), device=y_all.device) + * 0.01 + ) + y_coarse_aperiodicity, y_rms = d4c( + y_all.squeeze(1), + intermediates["pitch"], + t, + self.vocoder.out_sample_rate, + coarse_only=True, + ) + y_coarse_aperiodicity = 10.0 ** (y_coarse_aperiodicity / 10.0) + y_hat_coarse_aperiodicity, y_hat_rms = d4c( + y_hat_all.squeeze(1), + intermediates["pitch"], + t, + self.vocoder.out_sample_rate, + coarse_only=True, + ) + y_hat_coarse_aperiodicity = 10.0 ** (y_hat_coarse_aperiodicity / 10.0) + rms = torch.maximum(y_rms, y_hat_rms) + loss_ap = F.mse_loss( + y_hat_coarse_aperiodicity, y_coarse_aperiodicity, reduction="none" + ) + loss_ap *= (rms / (rms + 1e-3))[:, :, None] + loss_ap = loss_ap.mean() + else: + loss_ap = torch.tensor(0.0) - def remove_weight_norm(self): - self.vocoder.remove_weight_norm() + # [batch_size, 1, wav_length] -> [batch_size, 1, slice_segment_length * 240] + y_hat = slice_segments( + y_hat_all, slice_start_indices * 240, slice_segment_length * 240 + ) + # [batch_size, 1, wav_length] -> [batch_size, 1, slice_segment_length * 240] + y = slice_segments(y_all, slice_start_indices * 240, slice_segment_length * 240) + return y, y_hat, y_hat_all, loss_mel, loss_ap, stats def merge_weights(self): self.vocoder.merge_weights() @@ -1624,8 +2335,7 @@ class DiscriminatorR(nn.Module): ]: fmap = [] - x = self._spectrogram(x) - x.unsqueeze_(1) + x = self._spectrogram(x).unsqueeze(1) for l in self.convs: x = l(x) x = F.silu(x, inplace=True) @@ -1650,9 +2360,8 @@ class DiscriminatorR(nn.Module): n_fft, hop_length, win_length = self.resolution x = F.pad( x, ((n_fft - hop_length) // 2, (n_fft - hop_length) // 2), mode="reflect" - ) - x.squeeze_(1) - with torch.cuda.amp.autocast(False): + ).squeeze(1) + with torch.amp.autocast("cuda", 
enabled=False): mag = torch.stft( x.float(), n_fft=n_fft, @@ -1716,57 +2425,56 @@ class MultiPeriodDiscriminator(nn.Module): y_d_gs.append(y_d_g) fmap_rs.append(fmap_r) fmap_gs.append(fmap_g) - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - def forward_and_compute_discriminator_loss( + def forward_and_compute_loss( self, y: torch.Tensor, y_hat: torch.Tensor - ) -> tuple[torch.Tensor, dict[str, float]]: - y_d_rs, y_d_gs, _, _ = self(y, y_hat, flg_san_train=self.san) - loss = 0.0 + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float]]: + y_d_rs, y_d_gs, fmap_rs, fmap_gs = self(y, y_hat, flg_san_train=self.san) stats = {} assert len(y_d_gs) == len(y_d_rs) == len(self.discriminators) - for dr, dg, name in zip(y_d_rs, y_d_gs, self.discriminator_names): - if self.san: - dr_fun, dr_dir = map(lambda x: x.float(), dr) - dg_fun, dg_dir = map(lambda x: x.float(), dg) - r_loss_fun = F.softplus(1.0 - dr_fun).square().mean() - g_loss_fun = F.softplus(dg_fun).square().mean() - r_loss_dir = F.softplus(1.0 - dr_dir).square().mean() - g_loss_dir = -F.softplus(1.0 - dg_dir).square().mean() - r_loss = r_loss_fun + r_loss_dir - g_loss = g_loss_fun + g_loss_dir - else: - dr = dr.float() + with torch.amp.autocast("cuda", enabled=False): + # discriminator loss + d_loss = 0.0 + for dr, dg, name in zip(y_d_rs, y_d_gs, self.discriminator_names): + if self.san: + dr_fun, dr_dir = map(lambda x: x.float(), dr) + dg_fun, dg_dir = map(lambda x: x.float(), dg) + r_loss_fun = F.softplus(1.0 - dr_fun).square().mean() + g_loss_fun = F.softplus(dg_fun).square().mean() + r_loss_dir = F.softplus(1.0 - dr_dir).square().mean() + g_loss_dir = -F.softplus(1.0 - dg_dir).square().mean() + r_loss = r_loss_fun + r_loss_dir + g_loss = g_loss_fun + g_loss_dir + else: + dr = dr.float() + dg = dg.float() + r_loss = (1.0 - dr).square().mean() + g_loss = dg.square().mean() + stats[f"{name}_dr_loss"] = r_loss.item() + stats[f"{name}_dg_loss"] = g_loss.item() + d_loss += r_loss + g_loss + # adversarial loss + adv_loss = 0.0 + for dg, name in zip(y_d_gs, self.discriminator_names): dg = dg.float() - r_loss = (1.0 - dr).square().mean() - g_loss = dg.square().mean() - stats[f"{name}_dr_loss"] = r_loss.item() - stats[f"{name}_dg_loss"] = g_loss.item() - loss += r_loss + g_loss - return loss, stats - - def forward_and_compute_generator_loss( - self, y: torch.Tensor, y_hat: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]: - _, y_d_gs, fmap_rs, fmap_gs = self(y, y_hat, flg_san_train=False) - stats = {} - # adversarial loss - adv_loss = 0.0 - for dg, name in zip(y_d_gs, self.discriminator_names): - dg = dg.float() - if self.san: - g_loss = F.softplus(1.0 - dg).square().mean() - else: - g_loss = (1.0 - dg).square().mean() - stats[f"{name}_gg_loss"] = g_loss.item() - adv_loss += g_loss - # feature mathcing loss - fm_loss = 0.0 - for fr, fg in zip(fmap_rs, fmap_gs): - for r, g in zip(fr, fg): - fm_loss += (r.detach() - g).abs().mean() - return adv_loss, fm_loss, stats + if self.san: + g_loss = F.softplus(1.0 - dg).square().mean() + else: + g_loss = (1.0 - dg).square().mean() + stats[f"{name}_gg_loss"] = g_loss.item() + adv_loss += g_loss + # feature mathcing loss + fm_loss = 0.0 + for fr, fg, name in zip(fmap_rs, fmap_gs, self.discriminator_names): + fm_loss_i = 0.0 + for j, (r, g) in enumerate(zip(fr, fg)): + fm_loss_ij = (r.detach().float() - g.float()).abs().mean() + stats[f"~{name}_fm_loss_{j}"] = fm_loss_ij.item() + fm_loss_i += fm_loss_ij + stats[f"{name}_fm_loss"] = fm_loss_i.item() + fm_loss += fm_loss_i + 
return d_loss, adv_loss, fm_loss, stats # %% [markdown] @@ -1798,7 +2506,7 @@ class GradBalancer: self, losses: dict[str, torch.Tensor], input: torch.Tensor, - scaler: Optional[torch.cuda.amp.GradScaler] = None, + scaler: Optional[torch.amp.GradScaler] = None, skip_update_ema: bool = False, ) -> dict[str, float]: stats = {} @@ -1813,7 +2521,6 @@ class GradBalancer: if scaler is not None: loss = scaler.scale(loss) (grad,) = torch.autograd.grad(loss, [input], retain_graph=True) - if not grad.isfinite().all(): input.backward(grad) return {} @@ -1854,20 +2561,19 @@ class GradBalancer: if scaler is not None: loss = scaler.scale(loss) if skip_update_ema: - loss.backward() - else: - input.backward(loss) + (loss,) = torch.autograd.grad(loss, [input]) + input.backward(loss) return stats - def state_dict(self): + def state_dict(self) -> dict[str, dict[str, float]]: return { - "ema_total": self.ema_total, - "ema_fix": self.ema_fix, + "ema_total": dict(self.ema_total), + "ema_fix": dict(self.ema_fix), } def load_state_dict(self, state_dict): - self.ema_total = state_dict["ema_total"] - self.ema_fix = state_dict["ema_fix"] + self.ema_total = defaultdict(float, state_dict["ema_total"]) + self.ema_fix = defaultdict(float, state_dict["ema_fix"]) class QualityTester(nn.Module): @@ -2114,6 +2820,9 @@ class WavDataset(torch.utils.data.Dataset): def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, int, int]: file, speaker_id = self.audio_files[index] clean_wav, sample_rate = torchaudio.load(file, backend="soundfile") + if clean_wav.size(0) != 1: + ch = torch.randint(0, clean_wav.size(0), ()) + clean_wav = clean_wav[ch : ch + 1] formant_shift_candidates = [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0] formant_shift = formant_shift_candidates[ @@ -2238,7 +2947,7 @@ def prepare_training(): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - (h, in_wav_dataset_dir, out_dir, resume) = ( + (h, in_wav_dataset_dir, out_dir, resume, skip_training) = ( prepare_training_configs_for_experiment if is_notebook() else prepare_training_configs @@ -2385,6 +3094,7 @@ def prepare_training(): batch_size=h.batch_size, pin_memory=True, drop_last=True, + persistent_workers=True, ) print("Computing mean F0s of target speakers...", end="") @@ -2405,7 +3115,7 @@ def prepare_training(): for i, (file, target_ids) in enumerate(tqdm(test_filelist)): source_f0 = compute_mean_f0([file], method="harvest") source_f0s.append(source_f0) - if source_f0 != source_f0: + if math.isnan(source_f0): test_pitch_shifts.append([0] * len(target_ids)) continue pitch_shifts = [] @@ -2414,7 +3124,7 @@ def prepare_training(): if target_f0 != target_f0: pitch_shift = 0 else: - pitch_shift = int(round(12 * math.log2(target_f0 / source_f0))) + pitch_shift = int(round(12.0 * math.log2(target_f0 / source_f0))) pitch_shifts.append(pitch_shift) test_pitch_shifts.append(pitch_shifts) print("Done.") @@ -2423,7 +3133,7 @@ def prepare_training(): phone_extractor = PhoneExtractor().to(device).eval().requires_grad_(False) phone_extractor_checkpoint = torch.load( - repo_root() / h.phone_extractor_file, map_location="cpu" + repo_root() / h.phone_extractor_file, map_location="cpu", weights_only=True ) print( phone_extractor.load_state_dict(phone_extractor_checkpoint["phone_extractor"]) @@ -2432,7 +3142,7 @@ def prepare_training(): pitch_estimator = PitchEstimator().to(device).eval().requires_grad_(False) pitch_estimator_checkpoint = torch.load( - repo_root() / h.pitch_estimator_file, map_location="cpu" + repo_root() / 
h.pitch_estimator_file, map_location="cpu", weights_only=True ) print( pitch_estimator.load_state_dict(pitch_estimator_checkpoint["pitch_estimator"]) @@ -2449,24 +3159,25 @@ def prepare_training(): optim_g = torch.optim.AdamW( net_g.parameters(), - h.learning_rate, + h.learning_rate_g, betas=h.adam_betas, eps=h.adam_eps, ) optim_d = torch.optim.AdamW( net_d.parameters(), - h.learning_rate, + h.learning_rate_d, betas=h.adam_betas, eps=h.adam_eps, ) - grad_scaler = torch.cuda.amp.GradScaler(enabled=h.use_amp) + grad_scaler = torch.amp.GradScaler("cuda", enabled=h.use_amp) grad_balancer = GradBalancer( weights={ "loss_mel": h.grad_weight_mel, "loss_adv": h.grad_weight_adv, "loss_fm": h.grad_weight_fm, - }, + } + | ({"loss_ap": h.grad_weight_ap} if h.grad_weight_ap else {}), ema_decay=h.grad_balancer_ema_decay, ) resample_to_in_sample_rate = torchaudio.transforms.Resample( @@ -2483,15 +3194,13 @@ def prepare_training(): else: checkpoint_file = None if checkpoint_file is not None: - checkpoint = torch.load(checkpoint_file, map_location="cpu") - if not resume: # ファインチューニング + checkpoint = torch.load(checkpoint_file, map_location="cpu", weights_only=True) + if not resume and not skip_training: # ファインチューニング checkpoint_n_speakers = len(checkpoint["net_g"]["embed_speaker.weight"]) - initial_speaker_embedding = checkpoint["net_g"]["embed_speaker.weight"][:1] - # initial_speaker_embedding = checkpoint["net_g"]["embed_speaker.weight"].mean( - # 0, keepdim=True - # ) + initial_speaker_embedding = checkpoint["net_g"][ + "embed_speaker.weight" + ].mean(0, keepdim=True) if True: - # 0 とかランダムとかの方が良いかもしれない checkpoint["net_g"]["embed_speaker.weight"] = initial_speaker_embedding[ [0] * n_speakers ] @@ -2509,7 +3218,7 @@ def prepare_training(): ] = initial_speaker_embedding print(net_g.load_state_dict(checkpoint["net_g"], strict=False)) print(net_d.load_state_dict(checkpoint["net_d"], strict=False)) - if resume: + if resume or skip_training: optim_g.load_state_dict(checkpoint["optim_g"]) optim_d.load_state_dict(checkpoint["optim_d"]) initial_iteration = checkpoint["iteration"] @@ -2540,14 +3249,19 @@ def prepare_training(): return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) scheduler_g = get_cosine_annealing_warmup_scheduler( - optim_g, h.warmup_steps, h.n_steps, h.min_learning_rate + optim_g, h.warmup_steps, h.n_steps, h.min_learning_rate_g ) scheduler_d = get_cosine_annealing_warmup_scheduler( - optim_d, h.warmup_steps, h.n_steps, h.min_learning_rate + optim_d, h.warmup_steps, h.n_steps, h.min_learning_rate_d ) - for _ in range(initial_iteration + 1): - scheduler_g.step() - scheduler_d.step() + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=r"Detected call of `lr_scheduler\.step\(\)` before `optimizer\.step\(\)`\.", + ) + for _ in range(initial_iteration + 1): + scheduler_g.step() + scheduler_d.step() net_g.train() net_d.train() @@ -2556,12 +3270,15 @@ def prepare_training(): dict_scalars = defaultdict(list) quality_tester = QualityTester().eval().to(device) - writer = SummaryWriter(out_dir) - writer.add_text( - "log", - f"start training w/ {torch.cuda.get_device_name(device) if torch.cuda.is_available() else 'cpu'}.", - initial_iteration, - ) + if skip_training: + writer = None + else: + writer = SummaryWriter(out_dir) + writer.add_text( + "log", + f"start training w/ {torch.cuda.get_device_name(device) if torch.cuda.is_available() else 'cpu'}.", + initial_iteration, + ) if not resume: with open(out_dir / "config.json", "w", encoding="utf-8") as f: 
json.dump(dict(h), f, indent=4) @@ -2624,272 +3341,410 @@ if __name__ == "__main__": writer, ) = prepare_training() - # 学習 - - for iteration in tqdm(range(initial_iteration, h.n_steps)): - # === 1. データ前処理 === - try: - batch = next(data_iter) - except: - data_iter = iter(training_loader) - batch = next(data_iter) - ( - clean_wavs, - noisy_wavs_16k, - slice_starts, - speaker_ids, - formant_shift_semitone, - ) = map(lambda x: x.to(device, non_blocking=True), batch) - - # === 2.1 Discriminator の学習 === +if __name__ == "__main__" and writer is not None: + if h.compile_convnext: + raw_convnextstack_forward = ConvNeXtStack.forward + compiled_convnextstack_forward = torch.compile( + ConvNeXtStack.forward, mode="reduce-overhead" + ) + if h.compile_d4c: + d4c = torch.compile(d4c, mode="reduce-overhead") + if h.compile_discriminator: + MultiPeriodDiscriminator.forward_and_compute_loss = torch.compile( + MultiPeriodDiscriminator.forward_and_compute_loss, mode="reduce-overhead" + ) - with torch.cuda.amp.autocast(h.use_amp): - # Generator - y, y_hat, y_hat_for_backward, loss_mel = net_g.forward_and_compute_loss( - noisy_wavs_16k[:, None, :], + # 学習 + with ( + torch.profiler.profile( + schedule=torch.profiler.schedule(wait=1500, warmup=10, active=5, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler(out_dir), + record_shapes=True, + with_stack=True, + profile_memory=True, + with_flops=True, + ) + if h.profile + else nullcontext() + ) as profiler: + + for iteration in tqdm(range(initial_iteration, h.n_steps)): + # === 1. データ前処理 === + try: + batch = next(data_iter) + except: + data_iter = iter(training_loader) + batch = next(data_iter) + ( + clean_wavs, + noisy_wavs_16k, + slice_starts, speaker_ids, formant_shift_semitone, - slice_start_indices=slice_starts, - slice_segment_length=h.segment_length, - y_all=clean_wavs[:, None, :], - ) - assert y_hat.isfinite().all() - assert loss_mel.isfinite().all() - - # Discriminator - loss_discriminator, discriminator_d_stats = ( - net_d.forward_and_compute_discriminator_loss(y, y_hat.detach()) + ) = map(lambda x: x.to(device, non_blocking=True), batch) + + # === 2. 
学習 === + with torch.amp.autocast("cuda", enabled=h.use_amp): + # === 2.1 Generator の順伝播 === + if h.compile_convnext: + ConvNeXtStack.forward = compiled_convnextstack_forward + y, y_hat, y_hat_for_backward, loss_mel, loss_ap, generator_stats = ( + net_g.forward_and_compute_loss( + noisy_wavs_16k[:, None, :], + speaker_ids, + formant_shift_semitone, + slice_start_indices=slice_starts, + slice_segment_length=h.segment_length, + y_all=clean_wavs[:, None, :], + enable_loss_ap=h.grad_weight_ap != 0.0, + ) + ) + if h.compile_convnext: + ConvNeXtStack.forward = raw_convnextstack_forward + assert y_hat.isfinite().all() + assert loss_mel.isfinite().all() + assert loss_ap.isfinite().all() + + # === 2.2 Discriminator の順伝播 === + loss_discriminator, loss_adv, loss_fm, discriminator_stats = ( + net_d.forward_and_compute_loss(y, y_hat) + ) + assert loss_discriminator.isfinite().all() + assert loss_adv.isfinite().all() + assert loss_fm.isfinite().all() + + # === 2.3 Discriminator の逆伝播 === + for param in net_d.parameters(): + assert param.grad is None + grad_scaler.scale(loss_discriminator).backward( + retain_graph=True, inputs=list(net_d.parameters()) ) + loss_discriminator = loss_discriminator.item() + grad_scaler.unscale_(optim_d) + if iteration % 5 == 0: + grad_norm_d, d_grad_norm_stats = compute_grad_norm(net_d, True) + else: + grad_norm_d = math.nan + d_grad_norm_stats = {} - optim_d.zero_grad() - grad_scaler.scale(loss_discriminator).backward() - grad_scaler.unscale_(optim_d) - grad_norm_d, d_grad_norm_stats = compute_grad_norm(net_d, True) - grad_scaler.step(optim_d) - - # === 2.2 Generator の学習 === - - with torch.cuda.amp.autocast(h.use_amp): - # Discriminator - loss_adv, loss_fm, discriminator_g_stats = ( - net_d.forward_and_compute_generator_loss(y, y_hat) + # === 2.4 Generator の逆伝播 === + for param in net_g.parameters(): + assert param.grad is None + gradient_balancer_stats = grad_balancer.backward( + { + "loss_mel": loss_mel, + "loss_adv": loss_adv, + "loss_fm": loss_fm, + } + | ({"loss_ap": loss_ap} if h.grad_weight_ap else {}), + y_hat_for_backward, + grad_scaler, + skip_update_ema=iteration > 10 and iteration % 5 != 0, ) - - optim_g.zero_grad() - gradient_balancer_stats = grad_balancer.backward( - { - "loss_mel": loss_mel, - "loss_adv": loss_adv, - "loss_fm": loss_fm, - }, - y_hat_for_backward, - grad_scaler, - skip_update_ema=iteration > 10 and iteration % 5 != 0, - ) - grad_scaler.unscale_(optim_g) - grad_norm_g, g_grad_norm_stats = compute_grad_norm(net_g, True) - grad_scaler.step(optim_g) - grad_scaler.update() - - # === 3. 
ログ === - - dict_scalars["loss_g/loss_mel"].append(loss_mel.item()) - dict_scalars["loss_g/loss_fm"].append(loss_fm.item()) - dict_scalars["loss_g/loss_adv"].append(loss_adv.item()) - dict_scalars["other/grad_scale"].append(grad_scaler.get_scale()) - dict_scalars["loss_d/loss_discriminator"].append(loss_discriminator.item()) - if math.isfinite(grad_norm_d): - dict_scalars["other/gradient_norm_d"].append(grad_norm_d) - for name, value in d_grad_norm_stats.items(): - dict_scalars[f"~gradient_norm_d/{name}"].append(value) - if math.isfinite(grad_norm_g): - dict_scalars["other/gradient_norm_g"].append(grad_norm_g) - for name, value in g_grad_norm_stats.items(): - dict_scalars[f"~gradient_norm_g/{name}"].append(value) - dict_scalars["other/lr_g"].append(scheduler_g.get_last_lr()[0]) - dict_scalars["other/lr_d"].append(scheduler_d.get_last_lr()[0]) - for k, v in discriminator_d_stats.items(): - dict_scalars[f"~loss_discriminator/{k}"].append(v) - for k, v in discriminator_g_stats.items(): - dict_scalars[f"~loss_discriminator/{k}"].append(v) - for k, v in gradient_balancer_stats.items(): - dict_scalars[f"~gradient_balancer/{k}"].append(v) - - if (iteration + 1) % 1000 == 0 or iteration == 0: - for name, scalars in dict_scalars.items(): - if scalars: - writer.add_scalar(name, sum(scalars) / len(scalars), iteration + 1) - scalars.clear() - - # === 4. 検証 === - if (iteration + 1) % 50000 == 0 or iteration + 1 in { - 1, - 5000, - 10000, - 30000, - h.n_steps, - }: - net_g.eval() - torch.cuda.empty_cache() - - dict_qualities_all = defaultdict(list) - n_added_wavs = 0 - with torch.inference_mode(): - for i, ((file, target_ids), pitch_shift_semitones) in enumerate( - zip(test_filelist, test_pitch_shifts) - ): - source_wav, sr = torchaudio.load(file, backend="soundfile") - source_wav = source_wav.to(device) - if sr != h.in_sample_rate: - source_wav = get_resampler(sr, h.in_sample_rate, device)( - source_wav + loss_mel = loss_mel.item() + loss_adv = loss_adv.item() + loss_fm = loss_fm.item() + if h.grad_weight_ap: + loss_ap = loss_ap.item() + grad_scaler.unscale_(optim_g) + if iteration % 5 == 0: + grad_norm_g, g_grad_norm_stats = compute_grad_norm(net_g, True) + else: + grad_norm_g = math.nan + g_grad_norm_stats = {} + + # === 2.5 パラメータの更新 === + grad_scaler.step(optim_g) + optim_g.zero_grad(set_to_none=True) + grad_scaler.step(optim_d) + optim_d.zero_grad(set_to_none=True) + grad_scaler.update() + + # === 3. 
ログ === + dict_scalars["loss_g/loss_mel"].append(loss_mel) + if h.grad_weight_ap: + dict_scalars["loss_g/loss_ap"].append(loss_ap) + dict_scalars["loss_g/loss_fm"].append(loss_fm) + dict_scalars["loss_g/loss_adv"].append(loss_adv) + dict_scalars["other/grad_scale"].append(grad_scaler.get_scale()) + dict_scalars["loss_d/loss_discriminator"].append(loss_discriminator) + if math.isfinite(grad_norm_d): + dict_scalars["other/gradient_norm_d"].append(grad_norm_d) + for name, value in d_grad_norm_stats.items(): + dict_scalars[f"~gradient_norm_d/{name}"].append(value) + if math.isfinite(grad_norm_g): + dict_scalars["other/gradient_norm_g"].append(grad_norm_g) + for name, value in g_grad_norm_stats.items(): + dict_scalars[f"~gradient_norm_g/{name}"].append(value) + dict_scalars["other/lr_g"].append(scheduler_g.get_last_lr()[0]) + dict_scalars["other/lr_d"].append(scheduler_d.get_last_lr()[0]) + for k, v in generator_stats.items(): + dict_scalars[f"~loss_generator/{k}"].append(v) + for k, v in discriminator_stats.items(): + dict_scalars[f"~loss_discriminator/{k}"].append(v) + for k, v in gradient_balancer_stats.items(): + dict_scalars[f"~gradient_balancer/{k}"].append(v) + + if (iteration + 1) % 1000 == 0 or iteration == 0: + for name, scalars in dict_scalars.items(): + if scalars: + writer.add_scalar( + name, sum(scalars) / len(scalars), iteration + 1 ) - source_wav = source_wav.to(device) - original_source_wav_length = source_wav.size(1) - # 長さのパターンを減らしてキャッシュを効かせる - if source_wav.size(1) % h.in_sample_rate == 0: - padded_source_wav = source_wav - else: - padded_source_wav = F.pad( - source_wav, - ( - 0, - h.in_sample_rate - - source_wav.size(1) % h.in_sample_rate, - ), + scalars.clear() + for name, param in net_g.named_parameters(): + writer.add_histogram(f"weight/{name}", param, iteration + 1) + + intermediate_feature_stats = {} + hook_handles = [] + + def get_layer_hook(name): + def compute_stats(module, x, suffix): + if not isinstance(x, torch.Tensor): + return + if x.dtype not in [torch.float32, torch.float16]: + return + if isinstance(module, nn.Identity): + return + x = x.detach().float() + var = x.var().item() + if isinstance(module, (nn.Linear, nn.LayerNorm)): + channel_var, channel_mean = torch.var_mean( + x.reshape(-1, x.size(-1)), 0 + ) + elif isinstance(module, nn.Conv1d): + channel_var, channel_mean = torch.var_mean(x, [0, 2]) + else: + return + average_squared_channel_mean = ( + channel_mean.square().mean().item() ) - converted = net_g( - padded_source_wav[[0] * len(target_ids), None], - torch.tensor(target_ids, device=device), - torch.tensor( - [0.0] * len(target_ids), device=device - ), # フォルマントシフト - torch.tensor( - [float(p) for p in pitch_shift_semitones], device=device - ), - ).squeeze_(1)[:, : original_source_wav_length // 160 * 240] - if i < 12: - if iteration == 0: - writer.add_audio( - f"source/y_{i:02d}", - source_wav, - iteration + 1, - h.in_sample_rate, + average_channel_var = channel_var.mean().item() + + tensor_idx = len(intermediate_feature_stats) // 3 + intermediate_feature_stats[ + f"var/{tensor_idx:02d}_{name}/{suffix}" + ] = var + intermediate_feature_stats[ + f"avg_sq_ch_mean/{tensor_idx:02d}_{name}/{suffix}" + ] = average_squared_channel_mean + intermediate_feature_stats[ + f"avg_ch_var/{tensor_idx:02d}_{name}/{suffix}" + ] = average_channel_var + + def forward_pre_hook(module, input): + for i, input_i in enumerate(input): + compute_stats(module, input_i, f"input_{i}") + + def forward_hook(module, input, output): + if isinstance(output, tuple): + for i, output_i 
in enumerate(output): + compute_stats(module, output_i, f"output_{i}") + else: + compute_stats(module, output, "output") + + return forward_pre_hook, forward_hook + + for name, layer in net_g.named_modules(): + forward_pre_hook, forward_hook = get_layer_hook(name) + hook_handles.append( + layer.register_forward_pre_hook(forward_pre_hook) + ) + hook_handles.append(layer.register_forward_hook(forward_hook)) + with torch.no_grad(), torch.amp.autocast("cuda", enabled=h.use_amp): + net_g.forward_and_compute_loss( + noisy_wavs_16k[:, None, :], + speaker_ids, + formant_shift_semitone, + slice_start_indices=slice_starts, + slice_segment_length=h.segment_length, + y_all=clean_wavs[:, None, :], + enable_loss_ap=h.grad_weight_ap != 0.0, + ) + for handle in hook_handles: + handle.remove() + for name, value in intermediate_feature_stats.items(): + writer.add_scalar( + f"~intermediate_feature_{name}", value, iteration + 1 + ) + + # === 4. 検証 === + if (iteration + 1) % ( + 50000 if h.n_steps > 200000 else 2000 + ) == 0 or iteration + 1 in { + 1, + 30000, + h.n_steps, + }: + torch.backends.cudnn.benchmark = False + net_g.eval() + torch.cuda.empty_cache() + + dict_qualities_all = defaultdict(list) + n_added_wavs = 0 + with torch.inference_mode(): + for i, ((file, target_ids), pitch_shift_semitones) in enumerate( + zip(test_filelist, test_pitch_shifts) + ): + source_wav, sr = torchaudio.load(file, backend="soundfile") + source_wav = source_wav.to(device) + if sr != h.in_sample_rate: + source_wav = get_resampler(sr, h.in_sample_rate, device)( + source_wav ) - for d in range( - min(len(target_ids), 1 + (12 - i - 1) // len(test_filelist)) - ): - idx_in_batch = n_added_wavs % len(target_ids) - writer.add_audio( - f"converted/y_hat_{i:02d}_{target_ids[idx_in_batch]:03d}_{pitch_shift_semitones[idx_in_batch]:+02d}", - converted[idx_in_batch], - iteration + 1, - h.out_sample_rate, + source_wav = source_wav.to(device) + original_source_wav_length = source_wav.size(1) + # 長さのパターンを減らしてキャッシュを効かせる + if source_wav.size(1) % h.in_sample_rate == 0: + padded_source_wav = source_wav + else: + padded_source_wav = F.pad( + source_wav, + ( + 0, + h.in_sample_rate + - source_wav.size(1) % h.in_sample_rate, + ), ) - n_added_wavs += 1 - converted = resample_to_in_sample_rate(converted) - quality = quality_tester.test(converted, source_wav) - for metric_name, values in quality.items(): - dict_qualities_all[metric_name].extend(values) - assert n_added_wavs == min( - 12, len(test_filelist) * len(test_filelist[0][1]) - ), ( - n_added_wavs, - len(test_filelist), - len(speakers), - len(test_filelist[0][1]), - ) - dict_qualities = { - metric_name: sum(values) / len(values) - for metric_name, values in dict_qualities_all.items() - if len(values) - } - for metric_name, value in dict_qualities.items(): - writer.add_scalar(f"validation/{metric_name}", value, iteration + 1) - for metric_name, values in dict_qualities_all.items(): - for i, value in enumerate(values): - writer.add_scalar( - f"~validation_{metric_name}/{i:03d}", value, iteration + 1 + converted = net_g( + padded_source_wav[[0] * len(target_ids), None], + torch.tensor(target_ids, device=device), + torch.tensor( + [0.0] * len(target_ids), device=device + ), # フォルマントシフト + torch.tensor( + [float(p) for p in pitch_shift_semitones], device=device + ), + ).squeeze_(1)[:, : original_source_wav_length // 160 * 240] + if i < 12: + if iteration == 0: + writer.add_audio( + f"source/y_{i:02d}", + source_wav, + iteration + 1, + h.in_sample_rate, + ) + for d in range( + min( + len(target_ids), 
+ 1 + (12 - i - 1) // len(test_filelist), + ) + ): + idx_in_batch = n_added_wavs % len(target_ids) + writer.add_audio( + f"converted/y_hat_{i:02d}_{target_ids[idx_in_batch]:03d}_{pitch_shift_semitones[idx_in_batch]:+02d}", + converted[idx_in_batch], + iteration + 1, + h.out_sample_rate, + ) + n_added_wavs += 1 + converted = resample_to_in_sample_rate(converted) + quality = quality_tester.test(converted, source_wav) + for metric_name, values in quality.items(): + dict_qualities_all[metric_name].extend(values) + assert n_added_wavs == min( + 12, len(test_filelist) * len(test_filelist[0][1]) + ), ( + n_added_wavs, + len(test_filelist), + len(speakers), + len(test_filelist[0][1]), + ) + dict_qualities = { + metric_name: sum(values) / len(values) + for metric_name, values in dict_qualities_all.items() + if len(values) + } + for metric_name, value in dict_qualities.items(): + writer.add_scalar(f"validation/{metric_name}", value, iteration + 1) + for metric_name, values in dict_qualities_all.items(): + for i, value in enumerate(values): + writer.add_scalar( + f"~validation_{metric_name}/{i:03d}", value, iteration + 1 + ) + del dict_qualities, dict_qualities_all + + net_g.train() + torch.backends.cudnn.benchmark = True + gc.collect() + torch.cuda.empty_cache() + + # === 5. 保存 === + if (iteration + 1) % ( + 50000 if h.n_steps > 200000 else 2000 + ) == 0 or iteration + 1 in { + 1, + 30000, + h.n_steps, + }: + # チェックポイント + name = f"{in_wav_dataset_dir.name}_{iteration + 1:08d}" + checkpoint_file_save = out_dir / f"checkpoint_{name}.pt" + if checkpoint_file_save.exists(): + checkpoint_file_save = checkpoint_file_save.with_name( + f"{checkpoint_file_save.name}_{hash(None):x}" ) - del dict_qualities, dict_qualities_all - - gc.collect() - net_g.train() - torch.cuda.empty_cache() - - # === 5. 
保存 === - if (iteration + 1) % 50000 == 0 or iteration + 1 in { - 1, - 5000, - 10000, - 30000, - h.n_steps, - }: - # チェックポイント - name = f"{in_wav_dataset_dir.name}_{iteration + 1:08d}" - checkpoint_file_save = out_dir / f"checkpoint_{name}.pt" - if checkpoint_file_save.exists(): - checkpoint_file_save = checkpoint_file_save.with_name( - f"{checkpoint_file_save.name}_{hash(None):x}" + torch.save( + { + "iteration": iteration + 1, + "net_g": net_g.state_dict(), + "phone_extractor": phone_extractor.state_dict(), + "pitch_estimator": pitch_estimator.state_dict(), + "net_d": net_d.state_dict(), + "optim_g": optim_g.state_dict(), + "optim_d": optim_d.state_dict(), + "grad_balancer": grad_balancer.state_dict(), + "grad_scaler": grad_scaler.state_dict(), + "h": dict(h), + }, + checkpoint_file_save, ) - torch.save( - { - "iteration": iteration + 1, - "net_g": net_g.state_dict(), - "phone_extractor": phone_extractor.state_dict(), - "pitch_estimator": pitch_estimator.state_dict(), - "net_d": net_d.state_dict(), - "optim_g": optim_g.state_dict(), - "optim_d": optim_d.state_dict(), - "grad_balancer": grad_balancer.state_dict(), - "grad_scaler": grad_scaler.state_dict(), - "h": dict(h), - }, - checkpoint_file_save, - ) - shutil.copy(checkpoint_file_save, out_dir / "checkpoint_latest.pt") + shutil.copy(checkpoint_file_save, out_dir / "checkpoint_latest.pt") - # 推論用 - paraphernalia_dir = out_dir / f"paraphernalia_{name}" - if paraphernalia_dir.exists(): - paraphernalia_dir = paraphernalia_dir.with_name( - f"{paraphernalia_dir.name}_{hash(None):x}" + # 推論用 + paraphernalia_dir = out_dir / f"paraphernalia_{name}" + if paraphernalia_dir.exists(): + paraphernalia_dir = paraphernalia_dir.with_name( + f"{paraphernalia_dir.name}_{hash(None):x}" + ) + paraphernalia_dir.mkdir() + phone_extractor_fp16 = PhoneExtractor() + phone_extractor_fp16.load_state_dict(phone_extractor.state_dict()) + phone_extractor_fp16.remove_weight_norm() + phone_extractor_fp16.merge_weights() + phone_extractor_fp16.half() + phone_extractor_fp16.dump(paraphernalia_dir / f"phone_extractor.bin") + del phone_extractor_fp16 + pitch_estimator_fp16 = PitchEstimator() + pitch_estimator_fp16.load_state_dict(pitch_estimator.state_dict()) + pitch_estimator_fp16.merge_weights() + pitch_estimator_fp16.half() + pitch_estimator_fp16.dump(paraphernalia_dir / f"pitch_estimator.bin") + del pitch_estimator_fp16 + net_g_fp16 = ConverterNetwork( + nn.Module(), nn.Module(), len(speakers), h.hidden_channels ) - paraphernalia_dir.mkdir() - phone_extractor_fp16 = PhoneExtractor() - phone_extractor_fp16.load_state_dict(phone_extractor.state_dict()) - phone_extractor_fp16.remove_weight_norm() - phone_extractor_fp16.merge_weights() - phone_extractor_fp16.half() - phone_extractor_fp16.dump(paraphernalia_dir / f"phone_extractor.bin") - del phone_extractor_fp16 - pitch_estimator_fp16 = PitchEstimator() - pitch_estimator_fp16.load_state_dict(pitch_estimator.state_dict()) - pitch_estimator_fp16.merge_weights() - pitch_estimator_fp16.half() - pitch_estimator_fp16.dump(paraphernalia_dir / f"pitch_estimator.bin") - del pitch_estimator_fp16 - net_g_fp16 = ConverterNetwork( - nn.Module(), nn.Module(), len(speakers), h.hidden_channels - ) - net_g_fp16.load_state_dict(net_g.state_dict()) - net_g_fp16.remove_weight_norm() - net_g_fp16.merge_weights() - net_g_fp16.half() - net_g_fp16.dump(paraphernalia_dir / f"waveform_generator.bin") - with open(paraphernalia_dir / f"speaker_embeddings.bin", "wb") as f: - dump_layer(net_g_fp16.embed_speaker, f) - with open(paraphernalia_dir / 
f"formant_shift_embeddings.bin", "wb") as f: - dump_layer(net_g_fp16.embed_formant_shift, f) - del net_g_fp16 - shutil.copy(repo_root() / "assets/images/noimage.png", paraphernalia_dir) - with open( - paraphernalia_dir / f"beatrice_paraphernalia_{name}.toml", - "w", - encoding="utf-8", - ) as f: - f.write( - f'''[model] + net_g_fp16.load_state_dict(net_g.state_dict()) + net_g_fp16.merge_weights() + net_g_fp16.half() + net_g_fp16.dump(paraphernalia_dir / f"waveform_generator.bin") + with open(paraphernalia_dir / f"speaker_embeddings.bin", "wb") as f: + dump_layer(net_g_fp16.embed_speaker, f) + with open( + paraphernalia_dir / f"formant_shift_embeddings.bin", "wb" + ) as f: + dump_layer(net_g_fp16.embed_formant_shift, f) + del net_g_fp16 + shutil.copy( + repo_root() / "assets/images/noimage.png", paraphernalia_dir + ) + with open( + paraphernalia_dir / f"beatrice_paraphernalia_{name}.toml", + "w", + encoding="utf-8", + ) as f: + f.write( + f'''[model] version = "{PARAPHERNALIA_VERSION}" name = "{name}" description = """ @@ -2897,14 +3752,14 @@ No description for this model. このモデルの説明はありません。 """ ''' - ) - for speaker_id, (speaker, speaker_f0) in enumerate( - zip(speakers, speaker_f0s) - ): - average_pitch = 69.0 + 12.0 * math.log2(speaker_f0 / 440.0) - average_pitch = round(average_pitch * 8.0) / 8.0 - f.write( - f''' + ) + for speaker_id, (speaker, speaker_f0) in enumerate( + zip(speakers, speaker_f0s) + ): + average_pitch = 69.0 + 12.0 * math.log2(speaker_f0 / 440.0) + average_pitch = round(average_pitch * 8.0) / 8.0 + f.write( + f''' [voice.{speaker_id}] name = "{speaker}" description = """ @@ -2918,14 +3773,15 @@ path = "noimage.png" description = """ """ ''' - ) - del paraphernalia_dir - - # TODO: phone_extractor, pitch_estimator が既知のモデルであれば dump を省略 + ) + del paraphernalia_dir - # === 6. スケジューラ更新 === - scheduler_g.step() - scheduler_d.step() + # TODO: phone_extractor, pitch_estimator が既知のモデルであれば dump を省略 + # === 6. スケジューラ更新 === + scheduler_g.step() + scheduler_d.step() + if h.profile: + profiler.step() -print("Training finished.") + print("Training finished.")