# %% [markdown] # ## Settings # %% import argparse import gc import json import math import os import shutil import warnings from collections import defaultdict from contextlib import nullcontext from copy import deepcopy from fractions import Fraction from functools import partial from pathlib import Path from pprint import pprint from random import Random from typing import BinaryIO, Literal, Optional, Union import numpy as np import pyworld import torch import torch.nn as nn import torchaudio from torch.nn import functional as F from torch.nn.utils import remove_weight_norm, weight_norm from torch.utils.tensorboard import SummaryWriter from tqdm.auto import tqdm assert "soundfile" in torchaudio.list_audio_backends() if not hasattr(torch.amp, "GradScaler"): class GradScaler(torch.cuda.amp.GradScaler): def __init__(self, _, *args, **kwargs): super().__init__(*args, **kwargs) torch.amp.GradScaler = GradScaler # モジュールのバージョンではない PARAPHERNALIA_VERSION = "2.0.0-beta.1" def is_notebook() -> bool: return "get_ipython" in globals() def repo_root() -> Path: d = Path.cwd() / "dummy" if is_notebook() else Path(__file__) assert d.is_absolute(), d for d in d.parents: if (d / ".git").is_dir(): return d raise RuntimeError("Repository root is not found.") # ハイパーパラメータ # 学習データや出力ディレクトリなど、学習ごとに変わるようなものはここに含めない dict_default_hparams = { # train "learning_rate_g": 2e-4, "learning_rate_d": 1e-4, "min_learning_rate_g": 1e-5, "min_learning_rate_d": 5e-6, "adam_betas": [0.8, 0.99], "adam_eps": 1e-6, "batch_size": 8, "grad_weight_mel": 1.0, # grad_weight は比が同じなら同じ意味になるはず "grad_weight_ap": 2.0, "grad_weight_adv": 3.0, "grad_weight_fm": 3.0, "grad_balancer_ema_decay": 0.995, "use_amp": True, "num_workers": 16, "n_steps": 10000, "warmup_steps": 2000, "in_sample_rate": 16000, # 変更不可 "out_sample_rate": 24000, # 変更不可 "wav_length": 4 * 24000, # 4s "segment_length": 100, # 1s # data "phone_extractor_file": "assets/pretrained/003b_checkpoint_03000000.pt", "pitch_estimator_file": 
"assets/pretrained/008_1_checkpoint_00300000.pt", "in_ir_wav_dir": "assets/ir", "in_noise_wav_dir": "assets/noise", "in_test_wav_dir": "assets/test", "pretrained_file": "assets/pretrained/079_checkpoint_libritts_r_200_02400000.pt", # None も可 # model "hidden_channels": 256, # ファインチューン時変更不可、変更した場合は推論側の対応必要 "san": False, # ファインチューン時変更不可 "compile_convnext": False, "compile_d4c": False, "compile_discriminator": False, "profile": False, } if __name__ == "__main__": # スクリプト内部のデフォルト設定と assets/default_config.json が同期されているか確認 default_config_file = repo_root() / "assets/default_config.json" if default_config_file.is_file(): with open(default_config_file, encoding="utf-8") as f: default_config: dict = json.load(f) for key, value in dict_default_hparams.items(): if key not in default_config: warnings.warn(f"{key} not found in default_config.json.") else: if value != default_config[key]: warnings.warn( f"{key} differs between default_config.json ({default_config[key]}) and internal default hparams ({value})." 
) del default_config[key] for key in default_config: warnings.warn(f"{key} found in default_config.json is unknown.") else: warnings.warn("dafualt_config.json not found.") def prepare_training_configs_for_experiment() -> tuple[dict, Path, Path, bool, bool]: import ipynbname from IPython import get_ipython h = deepcopy(dict_default_hparams) in_wav_dataset_dir = repo_root() / "../../data/processed/libritts_r_200" try: notebook_name = ipynbname.name() except FileNotFoundError: notebook_name = Path(get_ipython().user_ns["__vsc_ipynb_file__"]).name out_dir = repo_root() / "notebooks" / notebook_name.split(".")[0].split("_")[0] resume = False skip_training = False return h, in_wav_dataset_dir, out_dir, resume, skip_training def prepare_training_configs() -> tuple[dict, Path, Path, bool, bool]: # data_dir, out_dir は config ファイルでもコマンドライン引数でも指定でき、 # コマンドライン引数が優先される。 # 各種ファイルパスを相対パスで指定した場合、config ファイルでは # リポジトリルートからの相対パスとなるが、コマンドライン引数では # カレントディレクトリからの相対パスとなる。 parser = argparse.ArgumentParser() # fmt: off parser.add_argument("-d", "--data_dir", type=Path, help="directory containing the training data") parser.add_argument("-o", "--out_dir", type=Path, help="output directory") parser.add_argument("-r", "--resume", action="store_true", help="resume training") parser.add_argument("-c", "--config", type=Path, help="path to the config file") # fmt: on args = parser.parse_args() # config if args.config is None: h = deepcopy(dict_default_hparams) else: with open(args.config, encoding="utf-8") as f: h = json.load(f) for key in dict_default_hparams.keys(): if key not in h: h[key] = dict_default_hparams[key] warnings.warn( f"{key} is not specified in the config file. Using the default value." ) # data_dir if args.data_dir is not None: in_wav_dataset_dir = args.data_dir elif "data_dir" in h: in_wav_dataset_dir = repo_root() / Path(h["data_dir"]) del h["data_dir"] else: raise ValueError( "data_dir must be specified. 
" "For example `python3 beatrice_trainer -d my_training_data_dir -o my_output_dir`." ) # out_dir if args.out_dir is not None: out_dir = args.out_dir elif "out_dir" in h: out_dir = repo_root() / Path(h["out_dir"]) del h["out_dir"] else: raise ValueError( "out_dir must be specified. " "For example `python3 beatrice_trainer -d my_training_data_dir -o my_output_dir`." ) for key in list(h.keys()): if key not in dict_default_hparams: warnings.warn(f"`{key}` specified in the config file will be ignored.") del h[key] # resume resume = args.resume return h, in_wav_dataset_dir, out_dir, resume, False class AttrDict(dict): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.__dict__ = self # %% [markdown] # ## Phone Extractor # %% def dump_params(params: torch.Tensor, f: BinaryIO): if params is None: return if params.dtype == torch.bfloat16: f.write( params.detach() .clone() .float() .view(torch.short) .numpy() .ravel()[1::2] .tobytes() ) else: f.write(params.detach().numpy().ravel().tobytes()) f.flush() def dump_layer(layer: nn.Module, f: BinaryIO): dump = partial(dump_params, f=f) if hasattr(layer, "dump"): layer.dump(f) elif isinstance(layer, (nn.Linear, nn.Conv1d, nn.LayerNorm)): dump(layer.weight) dump(layer.bias) elif isinstance(layer, nn.ConvTranspose1d): dump(layer.weight.transpose(0, 1)) dump(layer.bias) elif isinstance(layer, nn.GRU): dump(layer.weight_ih_l0) dump(layer.bias_ih_l0) dump(layer.weight_hh_l0) dump(layer.bias_hh_l0) for i in range(1, 99999): if not hasattr(layer, f"weight_ih_l{i}"): break dump(getattr(layer, f"weight_ih_l{i}")) dump(getattr(layer, f"bias_ih_l{i}")) dump(getattr(layer, f"weight_hh_l{i}")) dump(getattr(layer, f"bias_hh_l{i}")) elif isinstance(layer, nn.Embedding): dump(layer.weight) elif isinstance(layer, nn.Parameter): dump(layer) elif isinstance(layer, nn.ModuleList): for l in layer: dump_layer(l, f) else: assert False, layer class CausalConv1d(nn.Conv1d): def __init__( self, in_channels: int, out_channels: int, 
kernel_size: int, stride: int = 1, dilation: int = 1, groups: int = 1, bias: bool = True, delay: int = 0, ): padding = (kernel_size - 1) * dilation - delay self.trim = (kernel_size - 1) * dilation - 2 * delay if self.trim < 0: raise ValueError super().__init__( in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, ) def forward(self, input: torch.Tensor) -> torch.Tensor: result = super().forward(input) if self.trim == 0: return result else: return result[:, :, : -self.trim] class WSConv1d(CausalConv1d): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, dilation: int = 1, groups: int = 1, bias: bool = True, delay: int = 0, ): super().__init__( in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, groups=groups, bias=bias, delay=delay, ) self.weight.data.normal_( 0.0, math.sqrt(1.0 / (in_channels * kernel_size // groups)) ) if bias: self.bias.data.zero_() self.gain = nn.Parameter(torch.ones((out_channels, 1, 1))) def standardized_weight(self) -> torch.Tensor: var, mean = torch.var_mean(self.weight, [1, 2], keepdim=True) scale = ( self.gain * ( self.in_channels * self.kernel_size[0] // self.groups * var + 1e-8 ).rsqrt() ) return scale * (self.weight - mean) def forward(self, input: torch.Tensor) -> torch.Tensor: result = F.conv1d( input, self.standardized_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups, ) if self.trim == 0: return result else: return result[:, :, : -self.trim] def merge_weights(self): self.weight.data[:] = self.standardized_weight().detach() self.gain.data.fill_(1.0) class WSLinear(nn.Linear): def __init__(self, in_features: int, out_features: int, bias: bool = True): super().__init__(in_features, out_features, bias) self.weight.data.normal_(0.0, math.sqrt(1.0 / in_features)) self.bias.data.zero_() self.gain = nn.Parameter(torch.ones((out_features, 1))) def 
standardized_weight(self) -> torch.Tensor: var, mean = torch.var_mean(self.weight, 1, keepdim=True) scale = self.gain * (self.in_features * var + 1e-8).rsqrt() return scale * (self.weight - mean) def forward(self, input: torch.Tensor) -> torch.Tensor: return F.linear(input, self.standardized_weight(), self.bias) def merge_weights(self): self.weight.data[:] = self.standardized_weight().detach() self.gain.data.fill_(1.0) class ConvNeXtBlock(nn.Module): def __init__( self, channels: int, intermediate_channels: int, layer_scale_init_value: float, kernel_size: int = 7, use_weight_standardization: bool = False, enable_scaling: bool = False, pre_scale: float = 1.0, post_scale: float = 1.0, ): super().__init__() self.use_weight_standardization = use_weight_standardization self.enable_scaling = enable_scaling self.dwconv = CausalConv1d( channels, channels, kernel_size=kernel_size, groups=channels ) self.norm = nn.LayerNorm(channels) self.pwconv1 = nn.Linear(channels, intermediate_channels) self.pwconv2 = nn.Linear(intermediate_channels, channels) self.gamma = nn.Parameter(torch.full((channels,), layer_scale_init_value)) self.dwconv.weight.data.normal_(0.0, math.sqrt(1.0 / kernel_size)) self.dwconv.bias.data.zero_() self.pwconv1.weight.data.normal_(0.0, math.sqrt(2.0 / channels)) self.pwconv1.bias.data.zero_() self.pwconv2.weight.data.normal_(0.0, math.sqrt(1.0 / intermediate_channels)) self.pwconv2.bias.data.zero_() if use_weight_standardization: self.norm = nn.Identity() self.dwconv = WSConv1d(channels, channels, kernel_size, groups=channels) self.pwconv1 = WSLinear(channels, intermediate_channels) self.pwconv2 = WSLinear(intermediate_channels, channels) del self.gamma if enable_scaling: self.register_buffer("pre_scale", torch.tensor(pre_scale)) self.register_buffer("post_scale", torch.tensor(post_scale)) self.post_scale_weight = nn.Parameter(torch.ones(())) def forward(self, x: torch.Tensor) -> torch.Tensor: identity = x if self.enable_scaling: x = x * self.pre_scale x = 
self.dwconv(x) x = x.transpose(1, 2) x = self.norm(x) x = self.pwconv1(x) x = F.gelu(x, approximate="tanh") x = self.pwconv2(x) if not self.use_weight_standardization: x *= self.gamma if self.enable_scaling: x *= self.post_scale * self.post_scale_weight x = x.transpose(1, 2) x += identity return x def merge_weights(self): if self.use_weight_standardization: self.dwconv.merge_weights() self.pwconv1.merge_weights() self.pwconv2.merge_weights() else: self.pwconv1.bias.data += ( self.norm.bias.data[None, :] * self.pwconv1.weight.data ).sum(1) self.pwconv1.weight.data *= self.norm.weight.data[None, :] self.norm.bias.data[:] = 0.0 self.norm.weight.data[:] = 1.0 self.pwconv2.weight.data *= self.gamma.data[:, None] self.pwconv2.bias.data *= self.gamma.data self.gamma.data[:] = 1.0 if self.enable_scaling: self.dwconv.weight.data *= self.pre_scale.data self.pre_scale.data.fill_(1.0) self.pwconv2.weight.data *= ( self.post_scale.data * self.post_scale_weight.data ) self.pwconv2.bias.data *= self.post_scale.data * self.post_scale_weight.data self.post_scale.data.fill_(1.0) self.post_scale_weight.data.fill_(1.0) def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump(f) return if not hasattr(f, "write"): raise TypeError dump_layer(self.dwconv, f) dump_layer(self.pwconv1, f) dump_layer(self.pwconv2, f) class ConvNeXtStack(nn.Module): def __init__( self, in_channels: int, channels: int, intermediate_channels: int, n_blocks: int, delay: int, embed_kernel_size: int, kernel_size: int, use_weight_standardization: bool = False, enable_scaling: bool = False, ): super().__init__() assert delay * 2 + 1 <= embed_kernel_size self.use_weight_standardization = use_weight_standardization self.embed = CausalConv1d(in_channels, channels, embed_kernel_size, delay=delay) self.norm = nn.LayerNorm(channels) self.convnext = nn.ModuleList() for i in range(n_blocks): pre_scale = 1.0 / math.sqrt(1.0 + i / n_blocks) if 
enable_scaling else 1.0 post_scale = 1.0 / math.sqrt(n_blocks) if enable_scaling else 1.0 block = ConvNeXtBlock( channels=channels, intermediate_channels=intermediate_channels, layer_scale_init_value=1.0 / n_blocks, kernel_size=kernel_size, use_weight_standardization=use_weight_standardization, enable_scaling=enable_scaling, pre_scale=pre_scale, post_scale=post_scale, ) self.convnext.append(block) self.final_layer_norm = nn.LayerNorm(channels) self.embed.weight.data.normal_( 0.0, math.sqrt(0.5 / (embed_kernel_size * in_channels)) ) self.embed.bias.data.zero_() if use_weight_standardization: self.embed = WSConv1d(in_channels, channels, embed_kernel_size, delay=delay) self.norm = nn.Identity() self.final_layer_norm = nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.embed(x) x = self.norm(x.transpose(1, 2)).transpose(1, 2) for conv_block in self.convnext: x = conv_block(x) x = self.final_layer_norm(x.transpose(1, 2)).transpose(1, 2) return x def merge_weights(self): if self.use_weight_standardization: self.embed.merge_weights() for conv_block in self.convnext: conv_block.merge_weights() def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump(f) return if not hasattr(f, "write"): raise TypeError dump_layer(self.embed, f) if not self.use_weight_standardization: dump_layer(self.norm, f) dump_layer(self.convnext, f) if not self.use_weight_standardization: dump_layer(self.final_layer_norm, f) class FeatureExtractor(nn.Module): def __init__(self, hidden_channels: int): super().__init__() # fmt: off self.conv0 = weight_norm(nn.Conv1d(1, hidden_channels // 8, 10, 5, bias=False)) self.conv1 = weight_norm(nn.Conv1d(hidden_channels // 8, hidden_channels // 4, 3, 2, bias=False)) self.conv2 = weight_norm(nn.Conv1d(hidden_channels // 4, hidden_channels // 2, 3, 2, bias=False)) self.conv3 = weight_norm(nn.Conv1d(hidden_channels // 2, hidden_channels, 3, 2, bias=False)) 
self.conv4 = weight_norm(nn.Conv1d(hidden_channels, hidden_channels, 3, 2, bias=False)) self.conv5 = weight_norm(nn.Conv1d(hidden_channels, hidden_channels, 2, 2, bias=False)) # fmt: on def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [batch_size, 1, wav_length] wav_length = x.size(2) if wav_length % 160 != 0: warnings.warn("wav_length % 160 != 0") x = F.pad(x, (40, 40)) x = F.gelu(self.conv0(x), approximate="tanh") x = F.gelu(self.conv1(x), approximate="tanh") x = F.gelu(self.conv2(x), approximate="tanh") x = F.gelu(self.conv3(x), approximate="tanh") x = F.gelu(self.conv4(x), approximate="tanh") x = F.gelu(self.conv5(x), approximate="tanh") # [batch_size, hidden_channels, wav_length / 160] return x def remove_weight_norm(self): remove_weight_norm(self.conv0) remove_weight_norm(self.conv1) remove_weight_norm(self.conv2) remove_weight_norm(self.conv3) remove_weight_norm(self.conv4) remove_weight_norm(self.conv5) def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump(f) return if not hasattr(f, "write"): raise TypeError dump_layer(self.conv0, f) dump_layer(self.conv1, f) dump_layer(self.conv2, f) dump_layer(self.conv3, f) dump_layer(self.conv4, f) dump_layer(self.conv5, f) class FeatureProjection(nn.Module): def __init__(self, in_channels: int, out_channels: int): super().__init__() self.norm = nn.LayerNorm(in_channels) self.projection = nn.Conv1d(in_channels, out_channels, 1) self.dropout = nn.Dropout(0.1) def forward(self, x: torch.Tensor) -> torch.Tensor: # [batch_size, channels, length] x = self.norm(x.transpose(1, 2)).transpose(1, 2) x = self.projection(x) x = self.dropout(x) return x def merge_weights(self): self.projection.bias.data += ( (self.norm.bias.data[None, :, None] * self.projection.weight.data) .sum(1) .squeeze(1) ) self.projection.weight.data *= self.norm.weight.data[None, :, None] self.norm.bias.data[:] = 0.0 self.norm.weight.data[:] = 1.0 def dump(self, f: 
Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump(f) return if not hasattr(f, "write"): raise TypeError dump_layer(self.projection, f) class PhoneExtractor(nn.Module): def __init__( self, phone_channels: int = 256, hidden_channels: int = 256, backbone_embed_kernel_size: int = 7, kernel_size: int = 17, n_blocks: int = 8, ): super().__init__() self.feature_extractor = FeatureExtractor(hidden_channels) self.feature_projection = FeatureProjection(hidden_channels, hidden_channels) self.n_speaker_encoder_layers = 3 self.speaker_encoder = nn.GRU( hidden_channels, hidden_channels, self.n_speaker_encoder_layers, batch_first=True, ) for i in range(self.n_speaker_encoder_layers): for input_char in "ih": self.speaker_encoder = weight_norm( self.speaker_encoder, f"weight_{input_char}h_l{i}" ) self.backbone = ConvNeXtStack( in_channels=hidden_channels, channels=hidden_channels, intermediate_channels=hidden_channels * 3, n_blocks=n_blocks, delay=0, embed_kernel_size=backbone_embed_kernel_size, kernel_size=kernel_size, ) self.head = weight_norm(nn.Conv1d(hidden_channels, phone_channels, 1)) def forward( self, x: torch.Tensor, return_stats: bool = True ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]: # x: [batch_size, 1, wav_length] stats = {} # [batch_size, 1, wav_length] -> [batch_size, feature_extractor_hidden_channels, length] x = self.feature_extractor(x) if return_stats: stats["feature_norm"] = x.detach().norm(dim=1).mean() # [batch_size, feature_extractor_hidden_channels, length] -> [batch_size, hidden_channels, length] x = self.feature_projection(x) # [batch_size, hidden_channels, length] -> [batch_size, length, hidden_channels] g, _ = self.speaker_encoder(x.transpose(1, 2)) if self.training: batch_size, length, _ = g.size() shuffle_sizes_for_each_data = torch.randint( 0, 50, (batch_size,), device=g.device ) max_indices = torch.arange(length, device=g.device)[None, :, None] min_indices 
= ( max_indices - shuffle_sizes_for_each_data[:, None, None] ).clamp_(min=0) with torch.cuda.amp.autocast(False): indices = ( torch.rand(g.size(), device=g.device) * (max_indices - min_indices + 1) ).long() + min_indices assert indices.min() >= 0, indices.min() assert indices.max() < length, (indices.max(), length) g = g.gather(1, indices) # [batch_size, length, hidden_channels] -> [batch_size, hidden_channels, length] g = g.transpose(1, 2).contiguous() # [batch_size, hidden_channels, length] x = self.backbone(x + g) # [batch_size, hidden_channels, length] -> [batch_size, phone_channels, length] phone = self.head(F.gelu(x, approximate="tanh")) results = [phone] if return_stats: stats["code_norm"] = phone.detach().norm(dim=1).mean().item() results.append(stats) if len(results) == 1: return results[0] return tuple(results) @torch.inference_mode() def units(self, x: torch.Tensor) -> torch.Tensor: # x: [batch_size, 1, wav_length] # [batch_size, 1, wav_length] -> [batch_size, phone_channels, length] phone = self.forward(x, return_stats=False) # [batch_size, phone_channels, length] -> [batch_size, length, phone_channels] phone = phone.transpose(1, 2) # [batch_size, length, phone_channels] return phone def remove_weight_norm(self): self.feature_extractor.remove_weight_norm() for i in range(self.n_speaker_encoder_layers): for input_char in "ih": remove_weight_norm(self.speaker_encoder, f"weight_{input_char}h_l{i}") remove_weight_norm(self.head) def merge_weights(self): self.feature_projection.merge_weights() self.backbone.merge_weights() def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump(f) return if not hasattr(f, "write"): raise TypeError dump_layer(self.feature_extractor, f) dump_layer(self.feature_projection, f) dump_layer(self.speaker_encoder, f) dump_layer(self.backbone, f) dump_layer(self.head, f) # %% [markdown] # ## Pitch Estimator # %% def extract_pitch_features( y: 
torch.Tensor, # [..., wav_length] hop_length: int = 160, # 10ms win_length: int = 560, # 35ms max_corr_period: int = 256, # 16ms, 62.5Hz (16000 / 256) corr_win_length: int = 304, # 19ms instfreq_features_cutoff_bin: int = 64, # 1828Hz (16000 * 64 / 560) ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert max_corr_period + corr_win_length == win_length # パディングする padding_length = (win_length - hop_length) // 2 y = F.pad(y, (padding_length, padding_length)) # フレームにする # [..., win_length, n_frames] y_frames = y.unfold(-1, win_length, hop_length).transpose_(-2, -1) # 複素スペクトログラム # Complex[..., (win_length // 2 + 1), n_frames] spec: torch.Tensor = torch.fft.rfft(y_frames, n=win_length, dim=-2) # Complex[..., instfreq_features_cutoff_bin, n_frames] spec = spec[..., :instfreq_features_cutoff_bin, :] # 対数パワースペクトログラム log_power_spec = spec.abs().add_(1e-5).log10_() # 瞬時位相の時間差分 # 時刻 0 の値は 0 delta_spec = spec[..., :, 1:] * spec[..., :, :-1].conj() delta_spec /= delta_spec.abs().add_(1e-5) delta_spec = torch.cat( [torch.zeros_like(delta_spec[..., :, :1]), delta_spec], dim=-1 ) # [..., instfreq_features_cutoff_bin * 3, n_frames] instfreq_features = torch.cat( [log_power_spec, delta_spec.real, delta_spec.imag], dim=-2 ) # 自己相関 # 余裕があったら LPC 残差にするのも試したい # 元々これに 2.0 / corr_win_length を掛けて使おうと思っていたが、 # この値は振幅の 2 乗に比例していて、NN に入力するために良い感じに分散を # 標準化する方法が思いつかなかったのでやめた flipped_y_frames = y_frames.flip((-2,)) a = torch.fft.rfft(flipped_y_frames, n=win_length, dim=-2) b = torch.fft.rfft(y_frames[..., -corr_win_length:, :], n=win_length, dim=-2) # [..., max_corr_period, n_frames] corr = torch.fft.irfft(a * b, n=win_length, dim=-2)[..., corr_win_length:, :] # エネルギー項 energy = flipped_y_frames.square_().cumsum_(-2) energy0 = energy[..., corr_win_length - 1 : corr_win_length, :] energy = energy[..., corr_win_length:, :] - energy[..., :-corr_win_length, :] # Difference function corr_diff = (energy0 + energy).sub_(corr.mul_(2.0)) assert corr_diff.min() >= -1e-3, corr_diff.min() 
corr_diff.clamp_(min=0.0) # 計算誤差対策 # 標準化 corr_diff *= 2.0 / corr_win_length corr_diff.sqrt_() # 変換モデルへの入力用のエネルギー energy = ( (y_frames * torch.signal.windows.cosine(win_length, device=y.device)[..., None]) .square_() .sum(-2, keepdim=True) ) energy.clamp_(min=1e-3).log10_() # >= -3, 振幅 1 の正弦波なら大体 2.15 energy *= 0.5 # >= -1.5, 振幅 1 の正弦波なら大体 1.07, 1 の差は振幅で 20dB の差 return ( instfreq_features, # [..., instfreq_features_cutoff_bin * 3, n_frames] corr_diff, # [..., max_corr_period, n_frames] energy, # [..., 1, n_frames] ) class PitchEstimator(nn.Module): def __init__( self, input_instfreq_channels: int = 192, input_corr_channels: int = 256, pitch_channels: int = 384, channels: int = 192, intermediate_channels: int = 192 * 3, n_blocks: int = 6, delay: int = 1, # 10ms, 特徴抽出と合わせると 22.5ms embed_kernel_size: int = 3, kernel_size: int = 33, bins_per_octave: int = 96, ): super().__init__() self.bins_per_octave = bins_per_octave self.instfreq_embed_0 = nn.Conv1d(input_instfreq_channels, channels, 1) self.instfreq_embed_1 = nn.Conv1d(channels, channels, 1) self.corr_embed_0 = nn.Conv1d(input_corr_channels, channels, 1) self.corr_embed_1 = nn.Conv1d(channels, channels, 1) self.backbone = ConvNeXtStack( channels, channels, intermediate_channels, n_blocks, delay, embed_kernel_size, kernel_size, ) self.head = nn.Conv1d(channels, pitch_channels, 1) def forward(self, wav: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # wav: [batch_size, 1, wav_length] # [batch_size, input_instfreq_channels, length], # [batch_size, input_corr_channels, length] with torch.amp.autocast("cuda", enabled=False): instfreq_features, corr_diff, energy = extract_pitch_features( wav.squeeze(1), hop_length=160, win_length=560, max_corr_period=256, corr_win_length=304, instfreq_features_cutoff_bin=64, ) instfreq_features = F.gelu( self.instfreq_embed_0(instfreq_features), approximate="tanh" ) instfreq_features = self.instfreq_embed_1(instfreq_features) corr_diff = F.gelu(self.corr_embed_0(corr_diff), 
approximate="tanh") corr_diff = self.corr_embed_1(corr_diff) # [batch_size, channels, length] x = instfreq_features + corr_diff # ここ活性化関数忘れてる x = self.backbone(x) # [batch_size, pitch_channels, length] x = self.head(x) return x, energy def sample_pitch( self, pitch: torch.Tensor, band_width: int = 48, return_features: bool = False ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: # pitch: [batch_size, pitch_channels, length] # 返されるピッチの値には 0 は含まれない batch_size, pitch_channels, length = pitch.size() pitch = pitch.softmax(1) if return_features: unvoiced_proba = pitch[:, :1, :].clone() pitch[:, 0, :] = -100.0 pitch = ( pitch.transpose(1, 2) .contiguous() .view(batch_size * length, 1, pitch_channels) ) band_pitch = F.conv1d( pitch, torch.ones((1, 1, 1), device=pitch.device).expand(1, 1, band_width), ) # [batch_size * length, 1, pitch_channels - band_width + 1] -> Long[batch_size * length, 1] quantized_band_pitch = band_pitch.argmax(2) if return_features: # [batch_size * length, 1] band_proba = band_pitch.gather(2, quantized_band_pitch[:, :, None]) # [batch_size * length, 1] half_pitch_band_proba = band_pitch.gather( 2, (quantized_band_pitch - self.bins_per_octave).clamp_(min=1)[:, :, None], ) half_pitch_band_proba[quantized_band_pitch <= self.bins_per_octave] = 0.0 half_pitch_proba = (half_pitch_band_proba / (band_proba + 1e-6)).view( batch_size, 1, length ) # [batch_size * length, 1] double_pitch_band_proba = band_pitch.gather( 2, (quantized_band_pitch + self.bins_per_octave).clamp_( max=pitch_channels - band_width )[:, :, None], ) double_pitch_band_proba[ quantized_band_pitch > pitch_channels - band_width - self.bins_per_octave ] = 0.0 double_pitch_proba = (double_pitch_band_proba / (band_proba + 1e-6)).view( batch_size, 1, length ) # Long[1, pitch_channels] mask = torch.arange(pitch_channels, device=pitch.device)[None, :] # bool[batch_size * length, pitch_channels] mask = (quantized_band_pitch <= mask) & ( mask < quantized_band_pitch + band_width ) # 
Long[batch_size, length] quantized_pitch = (pitch.squeeze(1) * mask).argmax(1).view(batch_size, length) if return_features: features = torch.cat( [unvoiced_proba, half_pitch_proba, double_pitch_proba], dim=1 ) # Long[batch_size, length], [batch_size, 3, length] return quantized_pitch, features else: return quantized_pitch def merge_weights(self): self.backbone.merge_weights() def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump(f) return if not hasattr(f, "write"): raise TypeError dump_layer(self.instfreq_embed_0, f) dump_layer(self.instfreq_embed_1, f) dump_layer(self.corr_embed_0, f) dump_layer(self.corr_embed_1, f) dump_layer(self.backbone, f) dump_layer(self.head, f) # %% [markdown] # ## Vocoder # %% def overlap_add( ir_amp: torch.Tensor, ir_phase: torch.Tensor, window: torch.Tensor, pitch: torch.Tensor, hop_length: int = 240, delay: int = 0, sr: float = 24000.0, ) -> torch.Tensor: batch_size, ir_length, length = ir_amp.size() ir_length = (ir_length - 1) * 2 assert ir_phase.size() == ir_amp.size() assert window.size() == (ir_length,), (window.size(), ir_amp.size()) assert pitch.size() == (batch_size, length * hop_length) assert 0 <= delay < ir_length, (delay, ir_length) # 正規化角周波数 [2π rad] normalized_freq = pitch / sr # 初期位相 [2π rad] をランダムに設定 normalized_freq[:, 0] = torch.rand(batch_size, device=pitch.device) with torch.amp.autocast("cuda", enabled=False): phase = (normalized_freq.double().cumsum_(1) % 1.0).float() # 重ねる箇所を求める # [n_pitchmarks], [n_pitchmarks] indices0, indices1 = torch.nonzero(phase[:, :-1] > phase[:, 1:], as_tuple=True) # 重ねる箇所の小数部分 (位相の遅れ) を求める numer = 1.0 - phase[indices0, indices1] # [n_pitchmarks] fractional_part = numer / (numer + phase[indices0, indices1 + 1]) # 重ねる値を求める # Complex[n_pitchmarks, ir_length / 2 + 1] ir_amp = ir_amp[indices0, :, indices1 // hop_length] ir_phase = ir_phase[indices0, :, indices1 // hop_length] # 位相遅れの量 [rad] # [n_pitchmarks, 
ir_length / 2 + 1] delay_phase = ( torch.arange(ir_length // 2 + 1, device=pitch.device, dtype=torch.float32)[ None, : ] * (-math.tau / ir_length) * fractional_part[:, None] ) # Complex[n_pitchmarks, ir_length / 2 + 1] spec = torch.polar(ir_amp, ir_phase + delay_phase) # [n_pitchmarks, ir_length] ir = torch.fft.irfft(spec, n=ir_length, dim=1) ir *= window # 加算する値をサンプル単位にばらす # [n_pitchmarks * ir_length] ir = ir.ravel() # Long[n_pitchmarks * ir_length] indices0 = indices0[:, None].expand(-1, ir_length).ravel() # Long[n_pitchmarks * ir_length] indices1 = ( indices1[:, None] + torch.arange(ir_length, device=pitch.device) ).ravel() # overlap-add する overlap_added_signal = torch.zeros( (batch_size, length * hop_length + ir_length), device=pitch.device ) overlap_added_signal.index_put_((indices0, indices1), ir, accumulate=True) overlap_added_signal = overlap_added_signal[:, delay : -ir_length + delay] return overlap_added_signal def generate_noise( aperiodicity: torch.Tensor, delay: int = 0 ) -> tuple[torch.Tensor, torch.Tensor]: # aperiodicity: [batch_size, hop_length, length] batch_size, hop_length, length = aperiodicity.size() excitation = torch.rand( batch_size, (length + 1) * hop_length, device=aperiodicity.device ) excitation -= 0.5 n_fft = 2 * hop_length # 矩形窓で分析 # Complex[batch_size, hop_length + 1, length] noise = torch.stft( excitation, n_fft=n_fft, hop_length=hop_length, window=torch.ones(n_fft, device=excitation.device), center=False, return_complex=True, ) assert noise.size(2) == aperiodicity.size(2) noise[:, 0, :] = 0.0 noise[:, 1:, :] *= aperiodicity # ハン窓で合成 # torch.istft は最適合成窓が使われるので使えないことに注意 # [batch_size, 2 * hop_length, length] noise = torch.fft.irfft(noise, n=2 * hop_length, dim=1) noise *= torch.hann_window(2 * hop_length, device=noise.device)[None, :, None] # [batch_size, (length + 1) * hop_length] noise = F.fold( noise, (1, (length + 1) * hop_length), (1, 2 * hop_length), stride=(1, hop_length), ).squeeze_((1, 2)) assert delay < hop_length noise = 
noise[:, delay : -hop_length + delay] excitation = excitation[:, delay : -hop_length + delay] return noise, excitation # [batch_size, length * hop_length] class GradientEqualizerFunction(torch.autograd.Function): """ノルムが小さいほど勾配が大きくなってしまうのを補正する""" @staticmethod def forward(ctx, x: torch.Tensor) -> torch.Tensor: # x: [batch_size, 1, length] rms = x.square().mean(dim=2, keepdim=True).sqrt_() ctx.save_for_backward(rms) return x @staticmethod def backward(ctx, dx: torch.Tensor) -> torch.Tensor: # dx: [batch_size, 1, length] (rms,) = ctx.saved_tensors dx = dx * (math.sqrt(2.0) * rms + 0.1) return dx D4C_PREVENT_ZERO_DIVISION = True # False にすると本家の処理 def interp(x: torch.Tensor, y: torch.Tensor, xi: torch.Tensor) -> torch.Tensor: # x が単調増加で等間隔と仮定 # 外挿は起こらないと仮定 x = torch.as_tensor(x) y = torch.as_tensor(y) xi = torch.as_tensor(xi) if xi.ndim < y.ndim: diff_ndim = y.ndim - xi.ndim xi = xi.view(tuple([1] * diff_ndim) + xi.size()) if xi.size()[:-1] != y.size()[:-1]: xi = xi.expand(y.size()[:-1] + (xi.size(-1),)) assert (x.min(-1).values == x[..., 0]).all() assert (x.max(-1).values == x[..., -1]).all() assert (xi.min(-1).values >= x[..., 0]).all() assert (xi.max(-1).values <= x[..., -1]).all() delta_x = (x[..., -1].double() - x[..., 0].double()) / (x.size(-1) - 1.0) delta_x = delta_x.to(x.dtype) xi = (xi - x[..., :1]).div_(delta_x[..., None]) xi_base = xi.floor() xi_fraction = xi.sub_(xi_base) xi_base = xi_base.long() delta_y = y.diff(dim=-1, append=y[..., -1:]) yi = y.gather(-1, xi_base) + delta_y.gather(-1, xi_base) * xi_fraction return yi def linear_smoothing( group_delay: torch.Tensor, sr: int, n_fft: int, width: torch.Tensor ) -> torch.Tensor: group_delay = torch.as_tensor(group_delay) assert group_delay.size(-1) == n_fft // 2 + 1 width = torch.as_tensor(width) boundary = (width.max() * n_fft / sr).long() + 1 dtype = group_delay.dtype device = group_delay.device fft_resolution = sr / n_fft mirroring_freq_axis = ( torch.arange(-boundary, n_fft // 2 + 1 + boundary, 
def nuttall(n: int, device: torch.types.Device) -> torch.Tensor:
    """Return an ``n``-point window built from a 4-term cosine series.

    The window is ``sum_k c_k * cos(k * t)`` for ``k = 0..3`` with ``t``
    running over ``n`` evenly spaced points from 0 to 2*pi inclusive.
    """
    phase = torch.linspace(0, math.tau, n, device=device)
    coefficients = torch.tensor(
        [0.355768, -0.487396, 0.144232, -0.012604], device=device
    )
    harmonics = torch.arange(4.0, device=device)
    # [4, n]: one row per cosine term
    basis = torch.outer(harmonics, phase).cos_()
    return coefficients @ basis
def get_static_group_delay(
    static_centroid: torch.Tensor,
    smoothed_power_spec: torch.Tensor,
    sr: int,
    f0: torch.Tensor,
    n_fft: int,
) -> torch.Tensor:
    """Second step: calculation of parameter shaping.

    Divides the static centroid by the smoothed power spectrum, smooths the
    quotient over a width of ``f0 / 2``, and then subtracts a further
    ``f0``-wide smoothing of that result.
    """
    power = smoothed_power_spec
    if D4C_PREVENT_ZERO_DIVISION:
        # Floor the denominator so the division below cannot blow up.
        power = power.clamp(min=6e-8)
    t_g = static_centroid / power
    t_gs = linear_smoothing(t_g, sr, n_fft, f0 * 0.5)
    t_gb = linear_smoothing(t_gs, sr, n_fft, f0)
    t_d = t_gs - t_gb
    return t_d
def d4c_love_train(
    x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, threshold: float
) -> torch.Tensor:
    """Decide, per frame, whether the frame is treated as voiced.

    A frame is voiced when its ``f0`` is non-zero AND the spectral energy
    accumulated up to about 4000 Hz exceeds ``threshold`` times the energy
    accumulated up to about 7900 Hz (bins up to about 100 Hz are zeroed first).

    Args:
        x: waveform; forwarded to ``get_windowed_waveform``.
        sr: sampling rate in Hz.
        f0: fundamental frequency per frame; ``0`` marks an unvoiced frame.
        position: frame positions; forwarded to ``get_windowed_waveform``.
        threshold: energy-ratio threshold of the decision.

    Returns:
        Boolean tensor with the voiced/unvoiced flag (``True`` = voiced).
        The previous ``-> int`` annotation was wrong: the function has always
        returned this tensor.
    """
    x = torch.as_tensor(x)
    position = torch.as_tensor(position)
    f0 = torch.as_tensor(f0)
    vuv = f0 != 0
    lowest_f0 = 40
    # Clamp so that the f0-dependent window length stays bounded.
    f0 = f0.clamp(min=lowest_f0)
    # Smallest power of two strictly greater than 3 * sr // lowest_f0.
    n_fft = 1 << (3 * sr // lowest_f0).bit_length()
    # FFT bin indices for 100 Hz, 4000 Hz and 7900 Hz (rounded up).
    boundary0 = (100 * n_fft - 1) // sr + 1
    boundary1 = (4000 * n_fft - 1) // sr + 1
    boundary2 = (7900 * n_fft - 1) // sr + 1
    waveform, _ = get_windowed_waveform(x, sr, f0, position, 1.5, "blackman", n_fft)
    power_spec = torch.fft.rfft(waveform, n_fft).abs().square_()
    # Ignore the lowest bins before accumulating.
    power_spec[..., : boundary0 + 1] = 0.0
    cumulative_spec = power_spec.cumsum_(-1)
    vuv = vuv & (
        cumulative_spec[..., boundary1] > threshold * cumulative_spec[..., boundary2]
    )
    return vuv
def d4c(
    x: torch.Tensor,
    f0: torch.Tensor,
    t: torch.Tensor,
    sr: int,
    threshold: float = 0.85,
    n_fft_spec: Optional[int] = None,
    coarse_only: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Estimate aperiodicity with a torch port of the D4C algorithm.

    Adapted from https://github.com/tuanad121/Python-WORLD/blob/master/world/d4c.py

    Args:
        x: waveform; a singleton axis is inserted before its last axis when it
            is handed to the helpers.
        f0: fundamental frequency per frame (0 marks unvoiced frames for the
            final voiced/unvoiced decision).
        t: temporal positions of the frames.
        sr: sampling rate; must be an even integer value.
        threshold: passed to ``d4c_love_train`` for the voiced/unvoiced decision.
        n_fft_spec: FFT size defining the output frequency axis; derived from
            ``sr`` when ``None``.
        coarse_only: stop after the band-wise estimate.

    Returns:
        ``(coarse_aperiodicity, rms)`` when ``coarse_only`` is true, otherwise
        ``(aperiodicity, coarse_aperiodicity)``. Note that the two modes return
        different pairs.
    """
    FLOOR_F0 = 71
    FLOOR_F0_D4C = 47
    UPPER_LIMIT = 15000
    FREQ_INTERVAL = 3000

    assert sr == int(sr)
    sr = int(sr)
    assert sr % 2 == 0
    x = torch.as_tensor(x)
    f0 = torch.as_tensor(f0)
    temporal_positions = torch.as_tensor(t)

    # `1 << k.bit_length()` is the smallest power of two strictly greater than k.
    n_fft_d4c = 1 << (4 * sr // FLOOR_F0_D4C).bit_length()
    if n_fft_spec is None:
        n_fft_spec = 1 << (3 * sr // FLOOR_F0).bit_length()
    # Number of bands, spaced FREQ_INTERVAL Hz apart.
    n_aperiodicities = min(UPPER_LIMIT, sr // 2 - FREQ_INTERVAL) // FREQ_INTERVAL
    assert n_aperiodicities >= 1
    # Odd-length window covering FREQ_INTERVAL Hz on each side, in FFT bins.
    window_length = FREQ_INTERVAL * n_fft_d4c // sr * 2 + 1
    window = nuttall(window_length, device=x.device)
    # Frequency (Hz) of every bin of the output axis.
    freq_axis = torch.arange(n_fft_spec // 2 + 1, device=x.device) * (sr / n_fft_spec)

    coarse_aperiodicity, rms = d4c_general_body(
        x[..., None, :],
        sr,
        f0.clamp(min=FLOOR_F0_D4C),  # keep f0 away from zero for the analysis
        FREQ_INTERVAL,
        temporal_positions,
        n_fft_d4c,
        n_aperiodicities,
        window,
    )
    if coarse_only:
        return coarse_aperiodicity, rms

    # Only used for the sanity check below: the Nyquist frequency must fall in
    # the band right after the last estimated one.
    even_coarse_axis = (
        torch.arange(n_aperiodicities + 3, device=x.device) * FREQ_INTERVAL
    )
    assert even_coarse_axis[-2] <= sr // 2 < even_coarse_axis[-1], sr
    # Below the last band centre: interpolate over [0, FREQ_INTERVAL, ...],
    # with a -60.0 value prepended for 0 Hz.
    coarse_axis_low = (
        torch.arange(n_aperiodicities + 1, dtype=torch.float, device=x.device)
        * FREQ_INTERVAL
    )
    aperiodicity_low = interp(
        coarse_axis_low,
        F.pad(coarse_aperiodicity, (1, 0), value=-60.0),
        freq_axis[freq_axis < n_aperiodicities * FREQ_INTERVAL],
    )
    # From the last band centre up to Nyquist: interpolate from the last band
    # value towards -1e-12 at sr / 2.
    coarse_axis_high = torch.tensor(
        [n_aperiodicities * FREQ_INTERVAL, sr * 0.5], device=x.device
    )
    aperiodicity_high = interp(
        coarse_axis_high,
        F.pad(coarse_aperiodicity[..., -1:], (0, 1), value=-1e-12),
        freq_axis[freq_axis >= n_aperiodicities * FREQ_INTERVAL],
    )
    aperiodicity = torch.cat([aperiodicity_low, aperiodicity_high], -1)
    # Convert the interpolated values (treated as dB) to a linear scale.
    aperiodicity = 10.0 ** (aperiodicity / 20.0)
    vuv = d4c_love_train(x[..., None, :], sr, f0, temporal_positions, threshold)
    # Frames judged unvoiced get an aperiodicity of (almost) 1 in every bin.
    aperiodicity = torch.where(vuv[..., None], aperiodicity, 1 - 1e-12)
    return aperiodicity, coarse_aperiodicity
* 3, n_blocks=1, delay=0, embed_kernel_size=3, kernel_size=33, use_weight_standardization=True, enable_scaling=True, ) self.post_filter_generator_post = WSConv1d(channels, 512, 1, bias=False) self.register_buffer("post_filter_scale", torch.tensor(0.01)) def forward( self, x: torch.Tensor, pitch: torch.Tensor ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: # x: [batch_size, channels, length] # pitch: [batch_size, length] batch_size, _, length = x.size() x = self.prenet(x) ir = self.ir_generator(x) ir = F.silu(ir, inplace=True) # [batch_size, 512, length] ir = self.ir_generator_post(ir) ir *= self.ir_scale ir_amp = ir[:, : ir.size(1) // 2 + 1, :].exp() ir_phase = F.pad(ir[:, ir.size(1) // 2 + 1 :, :], (0, 0, 1, 1)) ir_phase[:, 1::2, :] += math.pi # TODO: 直流成分が正の値しか取れないのを修正する # 最近傍補間 # [batch_size, length * hop_length] pitch = torch.repeat_interleave(pitch, self.hop_length, dim=1) # [batch_size, length * hop_length] periodic_signal = overlap_add( ir_amp, ir_phase, self.ir_window, pitch, self.hop_length, delay=0, sr=self.out_sample_rate, ) aperiodicity = self.aperiodicity_generator(x) aperiodicity = F.silu(aperiodicity, inplace=True) # [batch_size, hop_length, length] aperiodicity = self.aperiodicity_generator_post(aperiodicity) aperiodicity *= self.aperiodicity_scale # [batch_size, length * hop_length], [batch_size, length * hop_length] aperiodic_signal, noise_excitation = generate_noise(aperiodicity, delay=0) post_filter = self.post_filter_generator(x) post_filter = F.silu(post_filter, inplace=True) # [batch_size, 512, length] post_filter = self.post_filter_generator_post(post_filter) post_filter *= self.post_filter_scale post_filter[:, 0, :] += 1.0 # [batch_size, length, 512] post_filter = post_filter.transpose(1, 2) with torch.amp.autocast("cuda", enabled=False): periodic_signal = periodic_signal.float() aperiodic_signal = aperiodic_signal.float() post_filter = post_filter.float() post_filter = torch.fft.rfft(post_filter, n=768) # [batch_size, length, 768] 
def merge_weights(self):
    """Merge sub-module weights and fold the output scales into the post layers.

    After this call each ``*_scale`` buffer is 1.0 and the corresponding
    ``*_post`` layer carries the scale in its weight (and, for the IR branch,
    its bias).
    """
    first_stage = (
        self.prenet,
        self.ir_generator,
        self.ir_generator_post,
        self.aperiodicity_generator,
        self.aperiodicity_generator_post,
    )
    for module in first_stage:
        module.merge_weights()

    # IR branch: the post layer has a bias, so scale it together with the weight.
    ir_post = self.ir_generator_post
    ir_post.weight.data *= self.ir_scale
    ir_post.bias.data *= self.ir_scale
    self.ir_scale.fill_(1.0)

    # Aperiodicity branch: only the weight is scaled.
    ap_post = self.aperiodicity_generator_post
    ap_post.weight.data *= self.aperiodicity_scale
    self.aperiodicity_scale.fill_(1.0)

    # Post-filter branch: merge, then scale the weight.
    for module in (self.post_filter_generator, self.post_filter_generator_post):
        module.merge_weights()
    pf_post = self.post_filter_generator_post
    pf_post.weight.data *= self.post_filter_scale
    self.post_filter_scale.fill_(1.0)
def slice_segments(
    x: torch.Tensor, start_indices: torch.Tensor, segment_length: int
) -> torch.Tensor:
    """Cut one fixed-length window per batch item out of the last axis of ``x``.

    Args:
        x: [batch_size, channels, length]
        start_indices: Long[batch_size], first index of each window.
        segment_length: number of samples per window.

    Returns:
        [batch_size, channels, segment_length]
    """
    n_batch, n_channels, _ = x.size()
    offsets = torch.arange(segment_length, device=start_indices.device)
    # [batch_size, 1, segment_length]: absolute positions along the last axis
    positions = start_indices.unsqueeze(1).unsqueeze(2) + offsets
    # [batch_size, channels, segment_length]: same positions for every channel
    positions = positions.expand(n_batch, n_channels, segment_length)
    return torch.gather(x, 2, positions)
def _get_resampler(
    self, orig_freq, new_freq, device, cache={}
) -> torchaudio.transforms.Resample:
    """Return a memoized ``Resample(orig_freq, new_freq)`` living on ``device``.

    The mutable default ``cache`` is intentional: it is created once when the
    function is defined and therefore acts as a memo shared by every call
    that does not pass its own dict.

    The cache key includes ``device``. Keying only on the frequency pair
    would hand back a resampler that was moved to whichever device asked
    first, even when a later caller needs it on a different device.
    """
    key = (orig_freq, new_freq, device)
    try:
        return cache[key]
    except KeyError:
        pass
    resampler = torchaudio.transforms.Resample(orig_freq, new_freq).to(
        device, non_blocking=True
    )
    cache[key] = resampler
    return resampler
= torch.zeros( (1, 1, interval_length * 160), device=x.device ) concatenated_shifted_x = [] offsets = [0] torch.backends.cudnn.benchmark = False for i in range(batch_size): shift_ratio_i = shift_ratio[i] shift_ratio_fraction_i = Fraction.from_float( shift_ratio_i ).limit_denominator(30) shift_numer_i = shift_ratio_fraction_i.numerator shift_denom_i = shift_ratio_fraction_i.denominator shift_ratio_i = shift_numer_i / shift_denom_i shift_i = int( round( math.log2(shift_ratio_i) * pitch_estimator.bins_per_octave ) ) shift.append(shift_i) shift_ratio[i] = shift_ratio_i # [1, 1, wav_length / shift_ratio] with torch.amp.autocast("cuda", enabled=False): shifted_x_i = self._get_resampler( shift_numer_i, shift_denom_i, x.device )(x[i])[None] if shifted_x_i.size(2) % 160 != 0: shifted_x_i = F.pad( shifted_x_i, (0, 160 - shifted_x_i.size(2) % 160), mode="reflect", ) assert shifted_x_i.size(2) % 160 == 0 offsets.append( offsets[-1] + interval_length + shifted_x_i.size(2) // 160 ) concatenated_shifted_x.extend([interval_zeros, shifted_x_i]) if offsets[-1] % 256 != 0: # 長さが同じ方が何かのキャッシュが効いて早くなるようなので # 適当に 256 の倍数になるようにパディングして長さのパターン数を減らす concatenated_shifted_x.append( torch.zeros( (1, 1, (256 - offsets[-1] % 256) * 160), device=x.device ) ) # [batch_size, 1, sum(wav_length) + batch_size * 16000] concatenated_shifted_x = torch.cat(concatenated_shifted_x, dim=2) assert concatenated_shifted_x.size(2) % (256 * 160) == 0 # [1, pitch_channels, length / shift_ratio], [1, 1, length / shift_ratio] concatenated_pitch, concatenated_energy = pitch_estimator( concatenated_shifted_x ) for i in range(batch_size): shift_i = shift[i] shift_ratio_i = shift_ratio[i] left = offsets[i] + interval_length right = offsets[i + 1] pitch_i = concatenated_pitch[:, :, left:right] energy_i = concatenated_energy[:, :, left:right] pitch_i = F.interpolate( pitch_i, scale_factor=shift_ratio_i, mode="linear", align_corners=False, ) energy_i = F.interpolate( energy_i, scale_factor=shift_ratio_i, mode="linear", 
align_corners=False, ) assert pitch_i.size(2) == energy_i.size(2) assert abs(pitch_i.size(2) - pitch.size(2)) <= 10 length = min(pitch_i.size(2), pitch.size(2)) if shift_i > 0: pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length] pitch[i : i + 1, 1:-shift_i, :length] = pitch_i[ :, 1 + shift_i :, :length ] pitch[i : i + 1, -shift_i:, :length] = -10.0 elif shift_i < 0: pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length] pitch[i : i + 1, 1 : 1 - shift_i, :length] = -10.0 pitch[i : i + 1, 1 - shift_i :, :length] = pitch_i[ :, 1:shift_i, :length ] energy[i : i + 1, :, :length] = energy_i[:, :, :length] torch.backends.cudnn.benchmark = True # [batch_size, pitch_channels, length] -> Long[batch_size, length], [batch_size, 3, length] quantized_pitch, pitch_features = pitch_estimator.sample_pitch( pitch, return_features=True ) if pitch_shift_semitone is not None: quantized_pitch = torch.where( quantized_pitch == 0, quantized_pitch, ( quantized_pitch + ( pitch_shift_semitone[:, None] * (pitch_estimator.bins_per_octave / 12.0) ) .round_() .long() ).clamp_(1, 383), ) pitch = 55.0 * 2.0 ** ( quantized_pitch.float() / pitch_estimator.bins_per_octave ) # phone が 2.5ms 先読みしているのに対して、 # energy は 12.5ms, pitch_features は 22.5ms 先読みしているので、 # ずらして phone に合わせる energy = F.pad(energy[:, :, :-1], (1, 0), mode="reflect") quantized_pitch = F.pad(quantized_pitch[:, :-2], (2, 0), mode="reflect") pitch_features = F.pad(pitch_features[:, :, :-2], (2, 0), mode="reflect") # [batch_size, 1, length], [batch_size, 3, length] -> [batch_size, 4, length] pitch_features = torch.cat([energy, pitch_features], dim=1) formant_shift_indices = ( ((formant_shift_semitone + 2.0) * 2.0).round_().long() ) phone = phone.clone() quantized_pitch = quantized_pitch.clone() pitch_features = pitch_features.clone() formant_shift_indices = formant_shift_indices.clone() pitch = pitch.clone() # [batch_sise, hidden_channels, length] x = ( self.embed_phone(phone) + self.embed_quantized_pitch(quantized_pitch).transpose(1, 
2) + self.embed_pitch_features(pitch_features) + ( self.embed_speaker(target_speaker_id)[:, :, None] + self.embed_formant_shift(formant_shift_indices)[:, :, None] ) ) if slice_start_indices is not None: assert slice_segment_length is not None # [batch_size, hidden_channels, length] -> [batch_size, hidden_channels, segment_length] x = slice_segments(x, slice_start_indices, slice_segment_length) x = F.silu(x, inplace=True) # [batch_size, hidden_channels, segment_length] -> [batch_size, 1, segment_length * 240] y_g_hat, stats = self.vocoder(x, pitch) stats["pitch"] = pitch if return_stats: return y_g_hat, stats else: return y_g_hat def _normalize_melsp(self, x): return x.clamp(min=1e-10).log_().mul_(0.5) def forward_and_compute_loss( self, noisy_wavs_16k: torch.Tensor, target_speaker_id: torch.Tensor, formant_shift_semitone: torch.Tensor, slice_start_indices: torch.Tensor, slice_segment_length: int, y_all: torch.Tensor, enable_loss_ap: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # noisy_wavs_16k: [batch_size, 1, wav_length] # target_speaker_id: Long[batch_size] # formant_shift_semitone: [batch_size] # slice_start_indices: [batch_size] # slice_segment_length: int # y_all: [batch_size, 1, wav_length] stats = {} loss_mel = 0.0 # [batch_size, 1, wav_length] -> [batch_size, 1, wav_length * 240] y_hat_all, intermediates = self( noisy_wavs_16k, target_speaker_id, formant_shift_semitone, return_stats=True, ) with torch.amp.autocast("cuda", enabled=False): periodic_signal = intermediates["periodic_signal"].float() aperiodic_signal = intermediates["aperiodic_signal"].float() noise_excitation = intermediates["noise_excitation"].float() periodic_signal = periodic_signal[:, : noise_excitation.size(1)] aperiodic_signal = aperiodic_signal[:, : noise_excitation.size(1)] y_hat_all = y_hat_all.float() y_hat_all_truncated = y_hat_all.squeeze(1)[:, : periodic_signal.size(1)] y_all_truncated = y_all.squeeze(1)[:, : periodic_signal.size(1)] for 
melspectrogram in self.melspectrograms: melsp_periodic_signal = melspectrogram(periodic_signal) melsp_aperiodic_signal = melspectrogram(aperiodic_signal) melsp_noise_excitation = melspectrogram(noise_excitation) # [1, n_mels, 1] # 1/6 ... [-0.5, 0.5] の一様乱数の平均パワー # 3/8 ... ハン窓をかけた時のパワー減衰 # 0.5 ... 謎 reference_melsp = melspectrogram.mel_scale( torch.full( (1, melspectrogram.n_fft // 2 + 1, 1), (1 / 6) * (3 / 8) * 0.5 * melspectrogram.win_length, device=noisy_wavs_16k.device, ) ) aperiodic_ratio = melsp_aperiodic_signal / ( melsp_periodic_signal + melsp_aperiodic_signal + 1e-5 ) compensation_ratio = reference_melsp / (melsp_noise_excitation + 1e-5) melsp_y_hat = melspectrogram(y_hat_all_truncated) melsp_y_hat = melsp_y_hat * ( (1.0 - aperiodic_ratio) + aperiodic_ratio * compensation_ratio ) y_hat_mel = self._normalize_melsp(melsp_y_hat) y_mel = self._normalize_melsp(melspectrogram(y_all_truncated)) loss_mel_i = F.l1_loss(y_hat_mel, y_mel) loss_mel += loss_mel_i stats[ f"loss_mel_{melspectrogram.win_length}_{melspectrogram.n_mels}" ] = loss_mel_i.item() loss_mel /= len(self.melspectrograms) if enable_loss_ap: t = ( torch.arange(intermediates["pitch"].size(1), device=y_all.device) * 0.01 ) y_coarse_aperiodicity, y_rms = d4c( y_all.squeeze(1), intermediates["pitch"], t, self.vocoder.out_sample_rate, coarse_only=True, ) y_coarse_aperiodicity = 10.0 ** (y_coarse_aperiodicity / 10.0) y_hat_coarse_aperiodicity, y_hat_rms = d4c( y_hat_all.squeeze(1), intermediates["pitch"], t, self.vocoder.out_sample_rate, coarse_only=True, ) y_hat_coarse_aperiodicity = 10.0 ** (y_hat_coarse_aperiodicity / 10.0) rms = torch.maximum(y_rms, y_hat_rms) loss_ap = F.mse_loss( y_hat_coarse_aperiodicity, y_coarse_aperiodicity, reduction="none" ) loss_ap *= (rms / (rms + 1e-3))[:, :, None] loss_ap = loss_ap.mean() else: loss_ap = torch.tensor(0.0) # [batch_size, 1, wav_length] -> [batch_size, 1, slice_segment_length * 240] y_hat = slice_segments( y_hat_all, slice_start_indices * 240, 
slice_segment_length * 240 ) # [batch_size, 1, wav_length] -> [batch_size, 1, slice_segment_length * 240] y = slice_segments(y_all, slice_start_indices * 240, slice_segment_length * 240) return y, y_hat, y_hat_all, loss_mel, loss_ap, stats def merge_weights(self): self.vocoder.merge_weights() def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): if isinstance(f, (str, bytes, os.PathLike)): with open(f, "wb") as f: self.dump(f) return if not hasattr(f, "write"): raise TypeError dump_layer(self.embed_phone, f) dump_layer(self.embed_quantized_pitch, f) dump_layer(self.embed_pitch_features, f) dump_layer(self.vocoder, f) # Discriminator def _normalize(tensor: torch.Tensor, dim: int) -> torch.Tensor: denom = tensor.norm(p=2.0, dim=dim, keepdim=True).clamp_min(1e-6) return tensor / denom class SANConv2d(nn.Conv2d): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1, bias: bool = True, padding_mode="zeros", device=None, dtype=None, ): super().__init__( in_channels, out_channels, kernel_size, stride, padding=padding, dilation=dilation, groups=1, bias=bias, padding_mode=padding_mode, device=device, dtype=dtype, ) scale = self.weight.norm(p=2.0, dim=[1, 2, 3], keepdim=True).clamp_min(1e-6) self.weight = nn.parameter.Parameter(self.weight / scale.expand_as(self.weight)) self.scale = nn.parameter.Parameter(scale.view(out_channels)) if bias: self.bias = nn.parameter.Parameter( torch.zeros(in_channels, device=device, dtype=dtype) ) else: self.register_parameter("bias", None) def forward( self, input: torch.Tensor, flg_san_train: bool = False ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if self.bias is not None: input = input + self.bias.view(self.in_channels, 1, 1) normalized_weight = self._get_normalized_weight() scale = self.scale.view(self.out_channels, 1, 1) if flg_san_train: out_fun = F.conv2d( input, normalized_weight.detach(), None, self.stride, self.padding, self.dilation, 
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return half of the extent covered by a dilated kernel, rounded down.

    That is ``dilation * (kernel_size - 1) // 2``.
    """
    covered_extent = dilation * (kernel_size - 1)
    return covered_extent // 2
class MultiPeriodDiscriminator(nn.Module):
    """Ensemble of spectrogram (``DiscriminatorR``) and period (``DiscriminatorP``)
    discriminators, evaluated on real and generated audio in a single pass.

    ``discriminator_names`` holds one label per sub-discriminator, in the same
    order as ``discriminators``; the labels are used as prefixes for the
    statistics returned by :meth:`forward_and_compute_loss`.
    """

    def __init__(self, san: bool = False):
        super().__init__()
        resolutions = [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]]
        periods = [2, 3, 5, 7, 11]
        self.discriminators = nn.ModuleList(
            [DiscriminatorR(r, san=san) for r in resolutions]
            + [DiscriminatorP(p, san=san) for p in periods]
        )
        self.discriminator_names = [f"R_{n}_{h}_{w}" for n, h, w in resolutions] + [
            f"P_{p}" for p in periods
        ]
        self.san = san

    def forward(
        self, y: torch.Tensor, y_hat: torch.Tensor, flg_san_train: bool = False
    ) -> tuple[
        list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]],
        list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]],
        list[list[torch.Tensor]],
        list[list[torch.Tensor]],
    ]:
        """Run every sub-discriminator on ``y`` (real) and ``y_hat`` (generated).

        Both inputs are concatenated along dim 0 so each sub-discriminator is
        called once; outputs and feature maps are split back afterwards.

        Returns:
            ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)`` with one entry per
            sub-discriminator. When ``flg_san_train`` is true each score entry
            is a ``(fun, dir)`` pair, otherwise a single tensor.
        """
        batch_size = y.size(0)
        concatenated_y_y_hat = torch.cat([y, y_hat])
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            if flg_san_train:
                (y_d_fun, y_d_dir), fmap = d(
                    concatenated_y_y_hat, flg_san_train=flg_san_train
                )
                y_d_r_fun, y_d_g_fun = torch.split(y_d_fun, batch_size)
                y_d_r_dir, y_d_g_dir = torch.split(y_d_dir, batch_size)
                y_d_r = y_d_r_fun, y_d_r_dir
                y_d_g = y_d_g_fun, y_d_g_dir
            else:
                y_d, fmap = d(concatenated_y_y_hat, flg_san_train=flg_san_train)
                y_d_r, y_d_g = torch.split(y_d, batch_size)
            fmap_r = []
            fmap_g = []
            for fm in fmap:
                fm_r, fm_g = torch.split(fm, batch_size)
                fmap_r.append(fm_r)
                fmap_g.append(fm_g)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

    def forward_and_compute_loss(
        self, y: torch.Tensor, y_hat: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float]]:
        """Compute discriminator, adversarial and feature-matching losses.

        Returns:
            ``(d_loss, adv_loss, fm_loss, stats)`` where ``stats`` maps
            ``"<name>_dr_loss"``, ``"<name>_dg_loss"``, ``"<name>_gg_loss"``,
            ``"<name>_fm_loss"`` and ``"~<name>_fm_loss_<j>"`` to Python floats.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self(y, y_hat, flg_san_train=self.san)
        stats = {}
        assert len(y_d_gs) == len(y_d_rs) == len(self.discriminators)
        # Losses are always accumulated in float32, outside autocast.
        with torch.amp.autocast("cuda", enabled=False):
            # discriminator loss
            d_loss = 0.0
            for dr, dg, name in zip(y_d_rs, y_d_gs, self.discriminator_names):
                if self.san:
                    dr_fun, dr_dir = (t.float() for t in dr)
                    dg_fun, dg_dir = (t.float() for t in dg)
                    r_loss_fun = F.softplus(1.0 - dr_fun).square().mean()
                    g_loss_fun = F.softplus(dg_fun).square().mean()
                    r_loss_dir = F.softplus(1.0 - dr_dir).square().mean()
                    g_loss_dir = -F.softplus(1.0 - dg_dir).square().mean()
                    r_loss = r_loss_fun + r_loss_dir
                    g_loss = g_loss_fun + g_loss_dir
                else:
                    dr = dr.float()
                    dg = dg.float()
                    r_loss = (1.0 - dr).square().mean()
                    g_loss = dg.square().mean()
                stats[f"{name}_dr_loss"] = r_loss.item()
                stats[f"{name}_dg_loss"] = g_loss.item()
                d_loss += r_loss + g_loss
            # adversarial loss
            adv_loss = 0.0
            for dg, name in zip(y_d_gs, self.discriminator_names):
                if self.san:
                    # With SAN each entry is a (fun, dir) pair. Only the "fun"
                    # output is computed from a non-detached input, so it is
                    # the one that can carry gradient back to the generator.
                    dg_fun = dg[0].float()
                    g_loss = F.softplus(1.0 - dg_fun).square().mean()
                else:
                    g_loss = (1.0 - dg.float()).square().mean()
                stats[f"{name}_gg_loss"] = g_loss.item()
                adv_loss += g_loss
            # feature matching loss
            fm_loss = 0.0
            for fr, fg, name in zip(fmap_rs, fmap_gs, self.discriminator_names):
                fm_loss_i = 0.0
                for j, (r, g) in enumerate(zip(fr, fg)):
                    # Real features act as a fixed target (detached).
                    fm_loss_ij = (r.detach().float() - g.float()).abs().mean()
                    stats[f"~{name}_fm_loss_{j}"] = fm_loss_ij.item()
                    fm_loss_i += fm_loss_ij
                stats[f"{name}_fm_loss"] = fm_loss_i.item()
                fm_loss += fm_loss_i
        return d_loss, adv_loss, fm_loss, stats
loss / d input とそのノルムを計算する norms = {} grads = {} for name, loss in losses.items(): if scaler is not None: loss = scaler.scale(loss) (grad,) = torch.autograd.grad(loss, [input], retain_graph=True) if not grad.isfinite().all(): input.backward(grad) return {} grad = grad.detach() / (1.0 if scaler is None else scaler.get_scale()) if self.per_batch_item: dims = tuple(range(1, grad.dim())) ema_norm = grad.norm(dim=dims).mean() else: ema_norm = grad.norm() norms[name] = float(ema_norm) grads[name] = grad # ノルムの移動平均を計算する for key, value in norms.items(): self.ema_total[key] = self.ema_total[key] * self.ema_decay + value self.ema_fix[key] = self.ema_fix[key] * self.ema_decay + 1.0 ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()} # ログを取る total_ema_norm = sum(ema_norms.values()) for k, ema_norm in ema_norms.items(): stats[f"grad_norm_value_{k}"] = ema_norm stats[f"grad_norm_ratio_{k}"] = ema_norm / (total_ema_norm + 1e-12) # loss の係数の比率を計算する if self.rescale_grads: total_weights = sum([self.weights[k] for k in ema_norms]) ratios = {k: w / total_weights for k, w in self.weights.items()} # 勾配を修正する loss = 0.0 for name, ema_norm in ema_norms.items(): if self.rescale_grads: scale = ratios[name] * self.total_norm / (ema_norm + 1e-12) else: scale = self.weights[name] loss += (losses if skip_update_ema else grads)[name] * scale if scaler is not None: loss = scaler.scale(loss) if skip_update_ema: (loss,) = torch.autograd.grad(loss, [input]) input.backward(loss) return stats def state_dict(self) -> dict[str, dict[str, float]]: return { "ema_total": dict(self.ema_total), "ema_fix": dict(self.ema_fix), } def load_state_dict(self, state_dict): self.ema_total = defaultdict(float, state_dict["ema_total"]) self.ema_fix = defaultdict(float, state_dict["ema_fix"]) class QualityTester(nn.Module): def __init__(self): super().__init__() self.utmos = torch.hub.load( "tarepan/SpeechMOS:v1.0.0", "utmos22_strong", trust_repo=True ).eval() @torch.inference_mode() def 
def compute_grad_norm(
    model: nn.Module, return_stats: bool = False
) -> Union[float, tuple[float, dict[str, float]]]:
    """Return the global L2 norm of all gradients currently stored on ``model``.

    Parameters whose ``grad`` is ``None`` are ignored.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
        return_stats: When true, also return the per-parameter norms.

    Returns:
        The total norm as a float, or ``(total_norm, stats)`` when
        ``return_stats`` is true, where ``stats`` maps ``"grad_norm_<name>"``
        to that parameter's gradient norm.
    """
    total_norm = 0.0
    stats = {}
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        param_norm = p.grad.data.norm().item()
        if not math.isfinite(param_norm):
            # Retry in float32; presumably guards against overflow when the
            # gradient is stored in a lower-precision dtype.
            param_norm = p.grad.data.float().norm().item()
        total_norm += param_norm * param_norm
        if return_stats:
            stats[f"grad_norm_{name}"] = param_norm
    total_norm = math.sqrt(total_norm)
    if return_stats:
        return total_norm, stats
    return total_norm
def convolve(signal: torch.Tensor, ir: torch.Tensor) -> torch.Tensor:
    """Convolve ``signal`` with ``ir`` along the last dimension using the FFT.

    The result is truncated to the length of ``signal``, i.e. the tail that the
    impulse response adds beyond the original signal is dropped.
    """
    signal_length = signal.size(-1)
    # Power-of-two FFT size that is at least the full linear-convolution
    # length (len(signal) + len(ir) - 1), so no circular wrap-around occurs.
    fft_size = 1 << (signal_length + ir.size(-1) - 2).bit_length()
    signal_spectrum = torch.fft.rfft(signal, n=fft_size)
    ir_spectrum = torch.fft.rfft(ir, n=fft_size)
    full = torch.fft.irfft(signal_spectrum * ir_spectrum, n=fft_size)
    return full[..., :signal_length]
class WavDataset(torch.utils.data.Dataset):
    """Dataset yielding (clean, noisy) waveform pairs for training.

    Each item is a whole recording: resampled with a random formant shift,
    zero-padded by ``wav_length`` on both sides, augmented with noise, and
    randomly scaled. Cropping to fixed-length training windows is done in
    :meth:`collate`.
    """

    def __init__(
        self,
        audio_files: list[tuple[Path, int]],
        in_sample_rate: int = 16000,
        out_sample_rate: int = 24000,
        wav_length: int = 4 * 24000,  # 4s
        segment_length: int = 100,  # 1s
        noise_files: Optional[list[Union[str, bytes, os.PathLike]]] = None,
        ir_files: Optional[list[Union[str, bytes, os.PathLike]]] = None,
    ):
        """Store the configuration.

        Args:
            audio_files: ``(path, speaker_id)`` pairs.
            in_sample_rate: Sample rate of the noisy waveform.
            out_sample_rate: Sample rate of the clean waveform.
            wav_length: Window length in samples at ``out_sample_rate``.
            segment_length: Slice length in 10 ms frames.
            noise_files: Noise recordings used for augmentation.
            ir_files: Impulse responses used for augmentation.

        Raises:
            ValueError: If exactly one of ``noise_files`` / ``ir_files`` is None.
        """
        self.audio_files = audio_files
        self.in_sample_rate = in_sample_rate
        self.out_sample_rate = out_sample_rate
        self.wav_length = wav_length
        self.segment_length = segment_length
        self.noise_files = noise_files
        self.ir_files = ir_files

        if (noise_files is None) is not (ir_files is None):
            raise ValueError("noise_files and ir_files must be both None or not None")

        self.in_hop_length = in_sample_rate // 100
        self.out_hop_length = out_sample_rate // 100  # 10 ms steps

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, int, float]:
        """Load, resample, augment and rescale one recording.

        Returns:
            ``(clean_wav, noisy_wav_16k, speaker_id, formant_shift)`` where both
            waveforms are 1-D tensors and ``formant_shift`` is in semitones.

        Raises:
            ValueError: If the recording contains only zero-valued samples.
        """
        file, speaker_id = self.audio_files[index]
        clean_wav, sample_rate = torchaudio.load(file, backend="soundfile")
        if clean_wav.size(0) != 1:
            # Multi-channel file: keep one randomly chosen channel.
            ch = torch.randint(0, clean_wav.size(0), ())
            clean_wav = clean_wav[ch : ch + 1]

        formant_shift_candidates = [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]
        formant_shift = formant_shift_candidates[
            torch.randint(0, len(formant_shift_candidates), ()).item()
        ]

        # Fold the formant shift into the resampling ratio; the ratio is
        # approximated by a fraction with a denominator of at most 300.
        resampler_fraction = Fraction(
            sample_rate / self.out_sample_rate * 2.0 ** (formant_shift / 12.0)
        ).limit_denominator(300)
        clean_wav = get_resampler(
            resampler_fraction.numerator, resampler_fraction.denominator
        )(clean_wav)

        assert clean_wav.size(0) == 1
        assert clean_wav.size(1) != 0
        if clean_wav.abs().max().item() == 0.0:
            # An all-zero recording would make the peak normalisation below
            # divide by zero and yield NaN/inf samples; fail with a clear
            # message instead.
            raise ValueError(
                f"{file} contains only silence and cannot be used for training."
            )

        clean_wav = F.pad(clean_wav, (self.wav_length, self.wav_length))

        if self.noise_files is None:
            # The un-augmented path is deliberately disabled by this assert.
            assert False
            noisy_wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)(
                clean_wav
            )
        else:
            clean_wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)(
                clean_wav
            )
            noisy_wav_16k = augment_audio(
                clean_wav_16k, self.in_sample_rate, self.noise_files, self.ir_files
            )

        clean_wav = clean_wav.squeeze_(0)
        noisy_wav_16k = noisy_wav_16k.squeeze_(0)

        # Randomise the volume: scale so the clean peak lies in [0.1, 0.999).
        amplitude = torch.rand(()).item() * 0.899 + 0.1
        factor = amplitude / clean_wav.abs().max()
        clean_wav *= factor
        noisy_wav_16k *= factor
        # Halve both signals until the noisy one no longer reaches full scale.
        while noisy_wav_16k.abs().max() >= 1.0:
            clean_wav *= 0.5
            noisy_wav_16k *= 0.5

        return clean_wav, noisy_wav_16k, speaker_id, formant_shift

    def __len__(self) -> int:
        return len(self.audio_files)

    def collate(
        self, batch: list[tuple[torch.Tensor, torch.Tensor, int, float]]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Crop each item to a fixed-length window and stack the batch.

        For every item a random non-zero sample of the clean waveform is
        picked, a slice of ``segment_length`` frames is centred on it, and a
        window of ``wav_length`` samples containing that slice at a random
        frame offset is cut from both waveforms.

        Returns:
            ``(clean_wavs, noisy_wavs, slice_starts, speaker_ids,
            formant_shifts)``; ``slice_starts`` is the slice position inside
            the window, in frames.
        """
        assert self.wav_length % self.out_hop_length == 0
        length = self.wav_length // self.out_hop_length
        clean_wavs = []
        noisy_wavs = []
        slice_starts = []
        speaker_ids = []
        formant_shifts = []
        for clean_wav, noisy_wav, speaker_id, formant_shift in batch:
            # Pick one random non-zero (voiced) sample.
            (voiced,) = clean_wav.nonzero(as_tuple=True)
            assert voiced.numel() != 0
            center = voiced[torch.randint(0, voiced.numel(), ()).item()].item()
            # Choose the slice so that the picked sample sits in its middle.
            slice_start = center - self.segment_length * self.out_hop_length // 2
            assert slice_start >= 0
            # Cut a wav_length window at a random offset that contains the slice.
            r = torch.randint(0, length - self.segment_length + 1, ()).item()
            offset = slice_start - r * self.out_hop_length
            clean_wavs.append(clean_wav[offset : offset + self.wav_length])
            # Same window expressed at the input sample rate.
            offset_in_sample_rate = int(
                round(offset * self.in_sample_rate / self.out_sample_rate)
            )
            noisy_wavs.append(
                noisy_wav[
                    offset_in_sample_rate : offset_in_sample_rate
                    + length * self.in_hop_length
                ]
            )
            # From here on the slice start is relative to the window, in frames.
            slice_start = r
            slice_starts.append(slice_start)
            speaker_ids.append(speaker_id)
            formant_shifts.append(formant_shift)
        clean_wavs = torch.stack(clean_wavs)
        noisy_wavs = torch.stack(noisy_wavs)
        slice_starts = torch.tensor(slice_starts)
        speaker_ids = torch.tensor(speaker_ids)
        formant_shifts = torch.tensor(formant_shifts)
        return (
            clean_wavs,  # [batch_size, wav_length]
            noisy_wavs,  # [batch_size, wav_length // out_hop_length * in_hop_length]
            slice_starts,  # Long[batch_size]
            speaker_ids,  # Long[batch_size]
            formant_shifts,  # Float[batch_size]
        )
%% [markdown] # ## Train # %% AUDIO_FILE_SUFFIXES = { ".wav", ".aif", ".aiff", ".fla", ".flac", ".oga", ".ogg", ".opus", ".mp3", } def prepare_training(): # 各種準備をする # 副作用として、出力ディレクトリと TensorBoard のログファイルなどが生成される device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"device={device}") torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True (h, in_wav_dataset_dir, out_dir, resume, skip_training) = ( prepare_training_configs_for_experiment if is_notebook() else prepare_training_configs )() print("config:") pprint(h) print() h = AttrDict(h) if not in_wav_dataset_dir.is_dir(): raise ValueError(f"{in_wav_dataset_dir} is not found.") if resume: latest_checkpoint_file = out_dir / "checkpoint_latest.pt" if not latest_checkpoint_file.is_file(): raise ValueError(f"{latest_checkpoint_file} is not found.") else: if out_dir.is_dir(): if (out_dir / "checkpoint_latest.pt").is_file(): raise ValueError( f"{out_dir / 'checkpoint_latest.pt'} already exists. " "Please specify a different output directory, or use --resume option." ) for file in out_dir.iterdir(): if file.suffix == ".pt": raise ValueError( f"{out_dir} already contains model files. " "Please specify a different output directory." 
def get_training_filelist(in_wav_dataset_dir: Path):
    """Collect training audio files, one sub-directory per speaker.

    Every immediate sub-directory of ``in_wav_dataset_dir`` is treated as one
    speaker; it is searched recursively for files whose lower-cased suffix is
    in ``AUDIO_FILE_SUFFIXES``. Speakers are numbered in sorted directory
    order, skipping directories without enough audio files.

    Returns:
        ``(speakers, training_filelist, speaker_audio_files)``: speaker
        directory names, ``(file, speaker_id)`` pairs, and the per-speaker
        file lists indexed by speaker id.
    """
    min_data_per_speaker = 1
    speakers: list[str] = []
    training_filelist: list[tuple[Path, int]] = []
    speaker_audio_files: list[list[Path]] = []

    speaker_dirs = (d for d in sorted(in_wav_dataset_dir.iterdir()) if d.is_dir())
    for speaker_dir in speaker_dirs:
        audio_paths = [
            path
            for path in sorted(speaker_dir.rglob("*"))
            if path.is_file() and path.suffix.lower() in AUDIO_FILE_SUFFIXES
        ]
        if len(audio_paths) < min_data_per_speaker:
            continue
        # The id is the speaker's position in `speakers`.
        speaker_id = len(speakers)
        speakers.append(speaker_dir.name)
        training_filelist.extend((path, speaker_id) for path in audio_paths)
        speaker_audio_files.append(audio_paths)

    return speakers, training_filelist, speaker_audio_files
n_speakers: int ) -> list[tuple[Path, list[int]]]: max_n_test_files = 1000 test_filelist = [] rng = Random(42) def get_target_id_generator(): if n_speakers > 8: while True: order = list(range(n_speakers)) rng.shuffle(order) yield from order else: while True: yield from range(n_speakers) target_id_generator = get_target_id_generator() for file in sorted(in_test_wav_dir.iterdir())[:max_n_test_files]: if file.suffix.lower() not in AUDIO_FILE_SUFFIXES: continue target_ids = [next(target_id_generator) for _ in range(min(8, n_speakers))] test_filelist.append((file, target_ids)) return test_filelist test_filelist = get_test_filelist(in_test_wav_dir, n_speakers) if len(test_filelist) == 0: warnings.warn(f"No audio data found in {test_filelist}.") print(f"{len(test_filelist)=}") for file, target_ids in test_filelist[:12]: print(f" {file}, {target_ids}") if len(test_filelist) > 12: print(" ...") print() # データ training_dataset = WavDataset( training_filelist, in_sample_rate=h.in_sample_rate, out_sample_rate=h.out_sample_rate, wav_length=h.wav_length, segment_length=h.segment_length, noise_files=noise_files, ir_files=ir_files, ) training_loader = torch.utils.data.DataLoader( training_dataset, num_workers=min(h.num_workers, os.cpu_count()), collate_fn=training_dataset.collate, shuffle=True, sampler=None, batch_size=h.batch_size, pin_memory=True, drop_last=True, persistent_workers=True, ) print("Computing mean F0s of target speakers...", end="") speaker_f0s = [] for speaker, files in enumerate(speaker_audio_files): if len(files) > 10: files = Random(42).sample(files, 10) f0 = compute_mean_f0(files) speaker_f0s.append(f0) if speaker % 5 == 0: print() print(f" {speaker:3d}: {f0:.1f}Hz", end=",") print() print("Done.") print("Computing pitch shifts for test files...") test_pitch_shifts = [] source_f0s = [] for i, (file, target_ids) in enumerate(tqdm(test_filelist)): source_f0 = compute_mean_f0([file], method="harvest") source_f0s.append(source_f0) if math.isnan(source_f0): 
test_pitch_shifts.append([0] * len(target_ids)) continue pitch_shifts = [] for target_id in target_ids: target_f0 = speaker_f0s[target_id] if target_f0 != target_f0: pitch_shift = 0 else: pitch_shift = int(round(12.0 * math.log2(target_f0 / source_f0))) pitch_shifts.append(pitch_shift) test_pitch_shifts.append(pitch_shifts) print("Done.") # モデルと最適化 phone_extractor = PhoneExtractor().to(device).eval().requires_grad_(False) phone_extractor_checkpoint = torch.load( repo_root() / h.phone_extractor_file, map_location="cpu", weights_only=True ) print( phone_extractor.load_state_dict(phone_extractor_checkpoint["phone_extractor"]) ) del phone_extractor_checkpoint pitch_estimator = PitchEstimator().to(device).eval().requires_grad_(False) pitch_estimator_checkpoint = torch.load( repo_root() / h.pitch_estimator_file, map_location="cpu", weights_only=True ) print( pitch_estimator.load_state_dict(pitch_estimator_checkpoint["pitch_estimator"]) ) del pitch_estimator_checkpoint net_g = ConverterNetwork( phone_extractor, pitch_estimator, n_speakers, h.hidden_channels, ).to(device) net_d = MultiPeriodDiscriminator(san=h.san).to(device) optim_g = torch.optim.AdamW( net_g.parameters(), h.learning_rate_g, betas=h.adam_betas, eps=h.adam_eps, ) optim_d = torch.optim.AdamW( net_d.parameters(), h.learning_rate_d, betas=h.adam_betas, eps=h.adam_eps, ) grad_scaler = torch.amp.GradScaler("cuda", enabled=h.use_amp) grad_balancer = GradBalancer( weights={ "loss_mel": h.grad_weight_mel, "loss_adv": h.grad_weight_adv, "loss_fm": h.grad_weight_fm, } | ({"loss_ap": h.grad_weight_ap} if h.grad_weight_ap else {}), ema_decay=h.grad_balancer_ema_decay, ) resample_to_in_sample_rate = torchaudio.transforms.Resample( h.out_sample_rate, h.in_sample_rate ).to(device) # チェックポイント読み出し initial_iteration = 0 if resume: checkpoint_file = latest_checkpoint_file elif h.pretrained_file is not None: checkpoint_file = repo_root() / h.pretrained_file else: checkpoint_file = None if checkpoint_file is not None: 
def get_cosine_annealing_warmup_scheduler(
    optimizer: torch.optim.Optimizer,
    warmup_epochs: int,
    total_epochs: int,
    min_learning_rate: float,
) -> torch.optim.lr_scheduler.LambdaLR:
    """Build a linear-warmup + cosine-annealing learning-rate schedule.

    The learning rate rises linearly from 0 to the base rate over
    ``warmup_epochs`` steps, follows a half cosine down to
    ``min_learning_rate`` until ``total_epochs``, and stays at
    ``min_learning_rate`` afterwards.

    Args:
        optimizer: Optimizer to schedule; only its first param group is used
            to determine the base learning rate.
        warmup_epochs: Number of warm-up steps.
        total_epochs: Step at which the minimum learning rate is reached.
        min_learning_rate: Absolute learning rate at the end of the schedule.

    Returns:
        A ``LambdaLR`` scheduler attached to ``optimizer``.
    """
    # LambdaLR multiplies the *initial* learning rate by the lambda's value.
    # If the optimizer state was restored from a checkpoint, "lr" holds the
    # already-decayed rate while "initial_lr" holds the base one, so prefer
    # the latter when it exists.
    group = optimizer.param_groups[0]
    base_learning_rate = group.get("initial_lr", group["lr"])
    lr_ratio = min_learning_rate / base_learning_rate
    m = 0.5 * (1.0 - lr_ratio)
    a = 0.5 * (1.0 + lr_ratio)

    def lr_lambda(current_epoch: int) -> float:
        if current_epoch < warmup_epochs:
            return current_epoch / warmup_epochs
        if current_epoch < total_epochs:
            rate = (current_epoch - warmup_epochs) / (total_epochs - warmup_epochs)
            # Goes from 1.0 (rate == 0) down to lr_ratio (rate == 1).
            return math.cos(rate * math.pi) * m + a
        # The return value is a multiplier, so the floor is the ratio, not
        # the absolute minimum learning rate.
        return lr_ratio

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
get_cosine_annealing_warmup_scheduler( optim_d, h.warmup_steps, h.n_steps, h.min_learning_rate_d ) with warnings.catch_warnings(): warnings.filterwarnings( "ignore", message=r"Detected call of `lr_scheduler\.step\(\)` before `optimizer\.step\(\)`\.", ) for _ in range(initial_iteration + 1): scheduler_g.step() scheduler_d.step() net_g.train() net_d.train() # ログとか dict_scalars = defaultdict(list) quality_tester = QualityTester().eval().to(device) if skip_training: writer = None else: writer = SummaryWriter(out_dir) writer.add_text( "log", f"start training w/ {torch.cuda.get_device_name(device) if torch.cuda.is_available() else 'cpu'}.", initial_iteration, ) if not resume: with open(out_dir / "config.json", "w", encoding="utf-8") as f: json.dump(dict(h), f, indent=4) if not is_notebook(): shutil.copy(__file__, out_dir) return ( device, in_wav_dataset_dir, h, out_dir, speakers, test_filelist, training_loader, speaker_f0s, test_pitch_shifts, phone_extractor, pitch_estimator, net_g, net_d, optim_g, optim_d, grad_scaler, grad_balancer, resample_to_in_sample_rate, initial_iteration, scheduler_g, scheduler_d, dict_scalars, quality_tester, writer, ) if __name__ == "__main__": ( device, in_wav_dataset_dir, h, out_dir, speakers, test_filelist, training_loader, speaker_f0s, test_pitch_shifts, phone_extractor, pitch_estimator, net_g, net_d, optim_g, optim_d, grad_scaler, grad_balancer, resample_to_in_sample_rate, initial_iteration, scheduler_g, scheduler_d, dict_scalars, quality_tester, writer, ) = prepare_training() if __name__ == "__main__" and writer is not None: if h.compile_convnext: raw_convnextstack_forward = ConvNeXtStack.forward compiled_convnextstack_forward = torch.compile( ConvNeXtStack.forward, mode="reduce-overhead" ) if h.compile_d4c: d4c = torch.compile(d4c, mode="reduce-overhead") if h.compile_discriminator: MultiPeriodDiscriminator.forward_and_compute_loss = torch.compile( MultiPeriodDiscriminator.forward_and_compute_loss, mode="reduce-overhead" ) # 学習 with ( 
torch.profiler.profile( schedule=torch.profiler.schedule(wait=1500, warmup=10, active=5, repeat=1), on_trace_ready=torch.profiler.tensorboard_trace_handler(out_dir), record_shapes=True, with_stack=True, profile_memory=True, with_flops=True, ) if h.profile else nullcontext() ) as profiler: for iteration in tqdm(range(initial_iteration, h.n_steps)): # === 1. データ前処理 === try: batch = next(data_iter) except: data_iter = iter(training_loader) batch = next(data_iter) ( clean_wavs, noisy_wavs_16k, slice_starts, speaker_ids, formant_shift_semitone, ) = map(lambda x: x.to(device, non_blocking=True), batch) # === 2. 学習 === with torch.amp.autocast("cuda", enabled=h.use_amp): # === 2.1 Generator の順伝播 === if h.compile_convnext: ConvNeXtStack.forward = compiled_convnextstack_forward y, y_hat, y_hat_for_backward, loss_mel, loss_ap, generator_stats = ( net_g.forward_and_compute_loss( noisy_wavs_16k[:, None, :], speaker_ids, formant_shift_semitone, slice_start_indices=slice_starts, slice_segment_length=h.segment_length, y_all=clean_wavs[:, None, :], enable_loss_ap=h.grad_weight_ap != 0.0, ) ) if h.compile_convnext: ConvNeXtStack.forward = raw_convnextstack_forward assert y_hat.isfinite().all() assert loss_mel.isfinite().all() assert loss_ap.isfinite().all() # === 2.2 Discriminator の順伝播 === loss_discriminator, loss_adv, loss_fm, discriminator_stats = ( net_d.forward_and_compute_loss(y, y_hat) ) assert loss_discriminator.isfinite().all() assert loss_adv.isfinite().all() assert loss_fm.isfinite().all() # === 2.3 Discriminator の逆伝播 === for param in net_d.parameters(): assert param.grad is None grad_scaler.scale(loss_discriminator).backward( retain_graph=True, inputs=list(net_d.parameters()) ) loss_discriminator = loss_discriminator.item() grad_scaler.unscale_(optim_d) if iteration % 5 == 0: grad_norm_d, d_grad_norm_stats = compute_grad_norm(net_d, True) else: grad_norm_d = math.nan d_grad_norm_stats = {} # === 2.4 Generator の逆伝播 === for param in net_g.parameters(): assert param.grad is 
None gradient_balancer_stats = grad_balancer.backward( { "loss_mel": loss_mel, "loss_adv": loss_adv, "loss_fm": loss_fm, } | ({"loss_ap": loss_ap} if h.grad_weight_ap else {}), y_hat_for_backward, grad_scaler, skip_update_ema=iteration > 10 and iteration % 5 != 0, ) loss_mel = loss_mel.item() loss_adv = loss_adv.item() loss_fm = loss_fm.item() if h.grad_weight_ap: loss_ap = loss_ap.item() grad_scaler.unscale_(optim_g) if iteration % 5 == 0: grad_norm_g, g_grad_norm_stats = compute_grad_norm(net_g, True) else: grad_norm_g = math.nan g_grad_norm_stats = {} # === 2.5 パラメータの更新 === grad_scaler.step(optim_g) optim_g.zero_grad(set_to_none=True) grad_scaler.step(optim_d) optim_d.zero_grad(set_to_none=True) grad_scaler.update() # === 3. ログ === dict_scalars["loss_g/loss_mel"].append(loss_mel) if h.grad_weight_ap: dict_scalars["loss_g/loss_ap"].append(loss_ap) dict_scalars["loss_g/loss_fm"].append(loss_fm) dict_scalars["loss_g/loss_adv"].append(loss_adv) dict_scalars["other/grad_scale"].append(grad_scaler.get_scale()) dict_scalars["loss_d/loss_discriminator"].append(loss_discriminator) if math.isfinite(grad_norm_d): dict_scalars["other/gradient_norm_d"].append(grad_norm_d) for name, value in d_grad_norm_stats.items(): dict_scalars[f"~gradient_norm_d/{name}"].append(value) if math.isfinite(grad_norm_g): dict_scalars["other/gradient_norm_g"].append(grad_norm_g) for name, value in g_grad_norm_stats.items(): dict_scalars[f"~gradient_norm_g/{name}"].append(value) dict_scalars["other/lr_g"].append(scheduler_g.get_last_lr()[0]) dict_scalars["other/lr_d"].append(scheduler_d.get_last_lr()[0]) for k, v in generator_stats.items(): dict_scalars[f"~loss_generator/{k}"].append(v) for k, v in discriminator_stats.items(): dict_scalars[f"~loss_discriminator/{k}"].append(v) for k, v in gradient_balancer_stats.items(): dict_scalars[f"~gradient_balancer/{k}"].append(v) if (iteration + 1) % 1000 == 0 or iteration == 0: for name, scalars in dict_scalars.items(): if scalars: writer.add_scalar( 
name, sum(scalars) / len(scalars), iteration + 1 ) scalars.clear() for name, param in net_g.named_parameters(): writer.add_histogram(f"weight/{name}", param, iteration + 1) intermediate_feature_stats = {} hook_handles = [] def get_layer_hook(name): def compute_stats(module, x, suffix): if not isinstance(x, torch.Tensor): return if x.dtype not in [torch.float32, torch.float16]: return if isinstance(module, nn.Identity): return x = x.detach().float() var = x.var().item() if isinstance(module, (nn.Linear, nn.LayerNorm)): channel_var, channel_mean = torch.var_mean( x.reshape(-1, x.size(-1)), 0 ) elif isinstance(module, nn.Conv1d): channel_var, channel_mean = torch.var_mean(x, [0, 2]) else: return average_squared_channel_mean = ( channel_mean.square().mean().item() ) average_channel_var = channel_var.mean().item() tensor_idx = len(intermediate_feature_stats) // 3 intermediate_feature_stats[ f"var/{tensor_idx:02d}_{name}/{suffix}" ] = var intermediate_feature_stats[ f"avg_sq_ch_mean/{tensor_idx:02d}_{name}/{suffix}" ] = average_squared_channel_mean intermediate_feature_stats[ f"avg_ch_var/{tensor_idx:02d}_{name}/{suffix}" ] = average_channel_var def forward_pre_hook(module, input): for i, input_i in enumerate(input): compute_stats(module, input_i, f"input_{i}") def forward_hook(module, input, output): if isinstance(output, tuple): for i, output_i in enumerate(output): compute_stats(module, output_i, f"output_{i}") else: compute_stats(module, output, "output") return forward_pre_hook, forward_hook for name, layer in net_g.named_modules(): forward_pre_hook, forward_hook = get_layer_hook(name) hook_handles.append( layer.register_forward_pre_hook(forward_pre_hook) ) hook_handles.append(layer.register_forward_hook(forward_hook)) with torch.no_grad(), torch.amp.autocast("cuda", enabled=h.use_amp): net_g.forward_and_compute_loss( noisy_wavs_16k[:, None, :], speaker_ids, formant_shift_semitone, slice_start_indices=slice_starts, slice_segment_length=h.segment_length, 
y_all=clean_wavs[:, None, :], enable_loss_ap=h.grad_weight_ap != 0.0, ) for handle in hook_handles: handle.remove() for name, value in intermediate_feature_stats.items(): writer.add_scalar( f"~intermediate_feature_{name}", value, iteration + 1 ) # === 4. 検証 === if (iteration + 1) % ( 50000 if h.n_steps > 200000 else 2000 ) == 0 or iteration + 1 in { 1, 30000, h.n_steps, }: torch.backends.cudnn.benchmark = False net_g.eval() torch.cuda.empty_cache() dict_qualities_all = defaultdict(list) n_added_wavs = 0 with torch.inference_mode(): for i, ((file, target_ids), pitch_shift_semitones) in enumerate( zip(test_filelist, test_pitch_shifts) ): source_wav, sr = torchaudio.load(file, backend="soundfile") source_wav = source_wav.to(device) if sr != h.in_sample_rate: source_wav = get_resampler(sr, h.in_sample_rate, device)( source_wav ) source_wav = source_wav.to(device) original_source_wav_length = source_wav.size(1) # 長さのパターンを減らしてキャッシュを効かせる if source_wav.size(1) % h.in_sample_rate == 0: padded_source_wav = source_wav else: padded_source_wav = F.pad( source_wav, ( 0, h.in_sample_rate - source_wav.size(1) % h.in_sample_rate, ), ) converted = net_g( padded_source_wav[[0] * len(target_ids), None], torch.tensor(target_ids, device=device), torch.tensor( [0.0] * len(target_ids), device=device ), # フォルマントシフト torch.tensor( [float(p) for p in pitch_shift_semitones], device=device ), ).squeeze_(1)[:, : original_source_wav_length // 160 * 240] if i < 12: if iteration == 0: writer.add_audio( f"source/y_{i:02d}", source_wav, iteration + 1, h.in_sample_rate, ) for d in range( min( len(target_ids), 1 + (12 - i - 1) // len(test_filelist), ) ): idx_in_batch = n_added_wavs % len(target_ids) writer.add_audio( f"converted/y_hat_{i:02d}_{target_ids[idx_in_batch]:03d}_{pitch_shift_semitones[idx_in_batch]:+02d}", converted[idx_in_batch], iteration + 1, h.out_sample_rate, ) n_added_wavs += 1 converted = resample_to_in_sample_rate(converted) quality = quality_tester.test(converted, source_wav) for 
metric_name, values in quality.items(): dict_qualities_all[metric_name].extend(values) assert n_added_wavs == min( 12, len(test_filelist) * len(test_filelist[0][1]) ), ( n_added_wavs, len(test_filelist), len(speakers), len(test_filelist[0][1]), ) dict_qualities = { metric_name: sum(values) / len(values) for metric_name, values in dict_qualities_all.items() if len(values) } for metric_name, value in dict_qualities.items(): writer.add_scalar(f"validation/{metric_name}", value, iteration + 1) for metric_name, values in dict_qualities_all.items(): for i, value in enumerate(values): writer.add_scalar( f"~validation_{metric_name}/{i:03d}", value, iteration + 1 ) del dict_qualities, dict_qualities_all net_g.train() torch.backends.cudnn.benchmark = True gc.collect() torch.cuda.empty_cache() # === 5. 保存 === if (iteration + 1) % ( 50000 if h.n_steps > 200000 else 2000 ) == 0 or iteration + 1 in { 1, 30000, h.n_steps, }: # チェックポイント name = f"{in_wav_dataset_dir.name}_{iteration + 1:08d}" checkpoint_file_save = out_dir / f"checkpoint_{name}.pt" if checkpoint_file_save.exists(): checkpoint_file_save = checkpoint_file_save.with_name( f"{checkpoint_file_save.name}_{hash(None):x}" ) torch.save( { "iteration": iteration + 1, "net_g": net_g.state_dict(), "phone_extractor": phone_extractor.state_dict(), "pitch_estimator": pitch_estimator.state_dict(), "net_d": net_d.state_dict(), "optim_g": optim_g.state_dict(), "optim_d": optim_d.state_dict(), "grad_balancer": grad_balancer.state_dict(), "grad_scaler": grad_scaler.state_dict(), "h": dict(h), }, checkpoint_file_save, ) shutil.copy(checkpoint_file_save, out_dir / "checkpoint_latest.pt") # 推論用 paraphernalia_dir = out_dir / f"paraphernalia_{name}" if paraphernalia_dir.exists(): paraphernalia_dir = paraphernalia_dir.with_name( f"{paraphernalia_dir.name}_{hash(None):x}" ) paraphernalia_dir.mkdir() phone_extractor_fp16 = PhoneExtractor() phone_extractor_fp16.load_state_dict(phone_extractor.state_dict()) 
phone_extractor_fp16.remove_weight_norm() phone_extractor_fp16.merge_weights() phone_extractor_fp16.half() phone_extractor_fp16.dump(paraphernalia_dir / f"phone_extractor.bin") del phone_extractor_fp16 pitch_estimator_fp16 = PitchEstimator() pitch_estimator_fp16.load_state_dict(pitch_estimator.state_dict()) pitch_estimator_fp16.merge_weights() pitch_estimator_fp16.half() pitch_estimator_fp16.dump(paraphernalia_dir / f"pitch_estimator.bin") del pitch_estimator_fp16 net_g_fp16 = ConverterNetwork( nn.Module(), nn.Module(), len(speakers), h.hidden_channels ) net_g_fp16.load_state_dict(net_g.state_dict()) net_g_fp16.merge_weights() net_g_fp16.half() net_g_fp16.dump(paraphernalia_dir / f"waveform_generator.bin") with open(paraphernalia_dir / f"speaker_embeddings.bin", "wb") as f: dump_layer(net_g_fp16.embed_speaker, f) with open( paraphernalia_dir / f"formant_shift_embeddings.bin", "wb" ) as f: dump_layer(net_g_fp16.embed_formant_shift, f) del net_g_fp16 shutil.copy( repo_root() / "assets/images/noimage.png", paraphernalia_dir ) with open( paraphernalia_dir / f"beatrice_paraphernalia_{name}.toml", "w", encoding="utf-8", ) as f: f.write( f'''[model] version = "{PARAPHERNALIA_VERSION}" name = "{name}" description = """ No description for this model. このモデルの説明はありません。 """ ''' ) for speaker_id, (speaker, speaker_f0) in enumerate( zip(speakers, speaker_f0s) ): average_pitch = 69.0 + 12.0 * math.log2(speaker_f0 / 440.0) average_pitch = round(average_pitch * 8.0) / 8.0 f.write( f''' [voice.{speaker_id}] name = "{speaker}" description = """ No description for this voice. この声の説明はありません。 """ average_pitch = {average_pitch} [voice.{speaker_id}.portrait] path = "noimage.png" description = """ """ ''' ) del paraphernalia_dir # TODO: phone_extractor, pitch_estimator が既知のモデルであれば dump を省略 # === 6. スケジューラ更新 === scheduler_g.step() scheduler_d.step() if h.profile: profiler.step() print("Training finished.")