Audio-to-Audio
audio
speech
voice-conversion
Project Beatrice
Add 2.0.0-beta.2 features
79c3b57
# %% [markdown]
# ## Settings
# %%
import argparse
import gc
import json
import math
import os
import shutil
import warnings
from collections import defaultdict
from contextlib import nullcontext
from copy import deepcopy
from fractions import Fraction
from functools import partial
from pathlib import Path
from pprint import pprint
from random import Random
from typing import BinaryIO, Literal, Optional, Union
import numpy as np
import pyworld
import torch
import torch.nn as nn
import torchaudio
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
# This script requires torchaudio's "soundfile" backend. Raise explicitly rather
# than using `assert`, so the check is not stripped under `python -O`.
if "soundfile" not in torchaudio.list_audio_backends():
    raise RuntimeError(
        'torchaudio\'s "soundfile" backend is not available '
        "(is the `soundfile` package installed?)."
    )
# Compatibility shim for PyTorch releases that lack torch.amp.GradScaler.
# torch.amp.GradScaler takes the device type as its first argument, whereas the
# legacy torch.cuda.amp.GradScaler has no such parameter, so the shim accepts
# it (positionally or as `device=`; it may also be omitted) and drops it.
if not hasattr(torch.amp, "GradScaler"):

    class GradScaler(torch.cuda.amp.GradScaler):
        def __init__(self, device: str = "cuda", *args, **kwargs):
            # The legacy scaler is CUDA-only; `device` exists only for
            # signature compatibility with torch.amp.GradScaler.
            super().__init__(*args, **kwargs)

    torch.amp.GradScaler = GradScaler
# This is NOT the version of this module. Presumably it versions the
# "paraphernalia" (the artifacts this script produces) — TODO confirm.
# NOTE(review): the file header mentions "2.0.0-beta.2" while this value is
# "2.0.0-beta.1"; confirm the mismatch is intentional.
PARAPHERNALIA_VERSION = "2.0.0-beta.1"
def is_notebook() -> bool:
    """Return True when this code is running inside an IPython/Jupyter session.

    IPython injects ``get_ipython`` into the interactive namespace, so its
    presence among this module's globals signals notebook execution.
    """
    module_namespace = globals()
    return "get_ipython" in module_namespace
def repo_root() -> Path:
    """Return the nearest ancestor directory that contains a ``.git`` directory.

    The search starts from the current working directory when running in a
    notebook, and from this file's location otherwise.

    Raises:
        RuntimeError: If no ancestor contains a ``.git`` directory.
    """
    if is_notebook():
        # Appending a dummy component makes cwd itself the first entry of
        # `.parents`.
        start = Path.cwd() / "dummy"
    else:
        start = Path(__file__)
    assert start.is_absolute(), start
    found = next(
        (ancestor for ancestor in start.parents if (ancestor / ".git").is_dir()),
        None,
    )
    if found is None:
        raise RuntimeError("Repository root is not found.")
    return found
# Hyperparameters.
# Settings that change from one training run to another, such as the training
# data or the output directory, are deliberately NOT included here.
dict_default_hparams = {
    # train
    "learning_rate_g": 2e-4,
    "learning_rate_d": 1e-4,
    "min_learning_rate_g": 1e-5,
    "min_learning_rate_d": 5e-6,
    "adam_betas": [0.8, 0.99],
    "adam_eps": 1e-6,
    "batch_size": 8,
    "grad_weight_mel": 1.0,  # the grad_weight_* values should mean the same thing as long as their ratios are the same
    "grad_weight_ap": 2.0,
    "grad_weight_adv": 3.0,
    "grad_weight_fm": 3.0,
    "grad_balancer_ema_decay": 0.995,
    "use_amp": True,
    "num_workers": 16,
    "n_steps": 10000,
    "warmup_steps": 2000,
    "in_sample_rate": 16000,  # must not be changed
    "out_sample_rate": 24000,  # must not be changed
    "wav_length": 4 * 24000,  # 4s
    "segment_length": 100,  # 1s
    # data
    "phone_extractor_file": "assets/pretrained/003b_checkpoint_03000000.pt",
    "pitch_estimator_file": "assets/pretrained/008_1_checkpoint_00300000.pt",
    "in_ir_wav_dir": "assets/ir",
    "in_noise_wav_dir": "assets/noise",
    "in_test_wav_dir": "assets/test",
    "pretrained_file": "assets/pretrained/079_checkpoint_libritts_r_200_02400000.pt",  # None is also allowed
    # model
    "hidden_channels": 256,  # must not be changed when fine-tuning; changing it requires matching support on the inference side
    "san": False,  # must not be changed when fine-tuning
    "compile_convnext": False,
    "compile_d4c": False,
    "compile_discriminator": False,
    "profile": False,
}
if __name__ == "__main__":
    # When run as a script, check that the built-in defaults above are in sync
    # with assets/default_config.json, warning about every discrepancy.
    default_config_file = repo_root() / "assets/default_config.json"
    if default_config_file.is_file():
        with open(default_config_file, encoding="utf-8") as f:
            default_config: dict = json.load(f)
        for key, value in dict_default_hparams.items():
            if key not in default_config:
                warnings.warn(f"{key} not found in default_config.json.")
                continue
            if value != default_config[key]:
                warnings.warn(
                    f"{key} differs between default_config.json ({default_config[key]}) and internal default hparams ({value})."
                )
            # Remove every checked key, so that whatever remains is unknown.
            del default_config[key]
        for key in default_config:
            warnings.warn(f"{key} found in default_config.json is unknown.")
    else:
        # Fixed typo in the file name ("dafualt" -> "default").
        warnings.warn("default_config.json not found.")
def prepare_training_configs_for_experiment() -> tuple[dict, Path, Path, bool, bool]:
    """Build training settings for interactive (notebook) experiments.

    The output directory is ``<repo_root>/notebooks/<id>``, where ``<id>`` is
    the running notebook's file name truncated at the first "." and then at
    the first "_".

    Returns:
        (hparams, input dataset dir, output dir, resume=False, skip_training=False)
    """
    import ipynbname
    from IPython import get_ipython

    hparams = deepcopy(dict_default_hparams)
    dataset_dir = repo_root() / "../../data/processed/libritts_r_200"
    try:
        notebook_file_name = ipynbname.name()
    except FileNotFoundError:
        # Fallback when ipynbname cannot locate the notebook: read the path
        # from the `__vsc_ipynb_file__` variable in the IPython user namespace
        # (presumably set by VS Code).
        user_namespace = get_ipython().user_ns
        notebook_file_name = Path(user_namespace["__vsc_ipynb_file__"]).name
    experiment_id = notebook_file_name.split(".")[0].split("_")[0]
    output_dir = repo_root() / "notebooks" / experiment_id
    return hparams, dataset_dir, output_dir, False, False
def prepare_training_configs() -> tuple[dict, Path, Path, bool, bool]:
    """Build training settings from command-line arguments and an optional config file.

    ``data_dir`` and ``out_dir`` can be given either in the config file or on
    the command line; the command line takes precedence. Relative paths in the
    config file are resolved against the repository root, whereas relative
    paths given on the command line are relative to the current directory.
    Hyperparameters missing from the config file fall back to
    ``dict_default_hparams`` (with a warning). Unknown keys are dropped (with a
    warning).

    Returns:
        (hparams, input dataset dir, output dir, resume flag, skip_training=False)

    Raises:
        ValueError: If ``data_dir`` or ``out_dir`` is not specified anywhere.
    """
    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument("-d", "--data_dir", type=Path, help="directory containing the training data")
    parser.add_argument("-o", "--out_dir", type=Path, help="output directory")
    parser.add_argument("-r", "--resume", action="store_true", help="resume training")
    parser.add_argument("-c", "--config", type=Path, help="path to the config file")
    # fmt: on
    args = parser.parse_args()

    # config
    if args.config is None:
        h = deepcopy(dict_default_hparams)
    else:
        with open(args.config, encoding="utf-8") as f:
            h = json.load(f)
        for key, default_value in dict_default_hparams.items():
            if key not in h:
                # Copy, so that mutable defaults (e.g. the `adam_betas` list)
                # are never shared with, and mutated through, the returned dict.
                h[key] = deepcopy(default_value)
                warnings.warn(
                    f"{key} is not specified in the config file. Using the default value."
                )

    # data_dir
    if args.data_dir is not None:
        in_wav_dataset_dir = args.data_dir
    elif "data_dir" in h:
        in_wav_dataset_dir = repo_root() / Path(h["data_dir"])
        del h["data_dir"]
    else:
        raise ValueError(
            "data_dir must be specified. "
            "For example `python3 beatrice_trainer -d my_training_data_dir -o my_output_dir`."
        )

    # out_dir
    if args.out_dir is not None:
        out_dir = args.out_dir
    elif "out_dir" in h:
        out_dir = repo_root() / Path(h["out_dir"])
        del h["out_dir"]
    else:
        raise ValueError(
            "out_dir must be specified. "
            "For example `python3 beatrice_trainer -d my_training_data_dir -o my_output_dir`."
        )

    # Drop keys that are not known hyperparameters (including a config-file
    # data_dir/out_dir that was overridden on the command line).
    for key in list(h.keys()):
        if key not in dict_default_hparams:
            warnings.warn(f"`{key}` specified in the config file will be ignored.")
            del h[key]

    # resume
    resume = args.resume
    return h, in_wav_dataset_dir, out_dir, resume, False
class AttrDict(dict):
    """A dict whose items are also accessible as attributes (``d.key``).

    The instance ``__dict__`` is pointed at the dict itself, so attribute
    reads, writes and deletes operate directly on the dict's items. Because
    the instance ``__dict__`` takes precedence over plain methods, an item
    named like a dict method (e.g. "keys") shadows that method on attribute
    access.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Note: this makes the instance reference itself (a reference cycle),
        # so instances are reclaimed by the cyclic garbage collector.
        self.__dict__ = self
# %% [markdown]
# ## Phone Extractor
# %%
def dump_params(params: Optional[torch.Tensor], f: BinaryIO):
    """Append the raw values of ``params`` to the binary stream ``f``.

    Values are written in row-major (C) order, using the tensor's dtype and
    the machine's native byte order. ``bfloat16`` tensors are written as their
    raw 16-bit patterns. ``None`` is ignored, so callers can pass optional
    attributes such as a missing ``bias`` directly.

    Args:
        params: Tensor to write (on any device), or ``None``.
        f: Writable binary file-like object. It is flushed after writing.
    """
    if params is None:
        return
    # Detach from autograd and move to host memory; `.numpy()` only works on
    # CPU tensors.
    data = params.detach().cpu()
    if data.dtype == torch.bfloat16:
        # NumPy has no bfloat16. Reinterpreting the 16-bit patterns as int16
        # is valid for any strides (same element size) and does not depend on
        # the machine being little-endian.
        data = data.view(torch.int16)
    # `ravel` on a NumPy array yields C order even for non-contiguous views
    # (e.g. a transposed weight).
    f.write(data.numpy().ravel().tobytes())
    f.flush()
def dump_layer(layer: nn.Module, f: BinaryIO):
    """Append the parameters of ``layer`` to ``f`` as raw bytes via ``dump_params``.

    Supported inputs, in the order they are checked:

    * any object with a ``dump(f)`` method (it serializes itself);
    * ``nn.Linear`` / ``nn.Conv1d`` / ``nn.LayerNorm``: weight, then bias;
    * ``nn.ConvTranspose1d``: weight with its first two axes swapped, then bias;
    * ``nn.GRU``: for each layer, ``weight_ih``, ``bias_ih``, ``weight_hh``,
      ``bias_hh``. Only the ``*_l{i}`` (forward-direction) tensors are
      written; ``*_reverse`` tensors of a bidirectional GRU are not;
    * ``nn.Embedding``: weight;
    * ``nn.Parameter``: the parameter itself;
    * ``nn.ModuleList``: each element, recursively.

    Raises:
        AssertionError: If the layer type is not supported. It is raised
            explicitly (not via ``assert``), so unsupported layers are still
            rejected under ``python -O`` instead of silently producing a
            truncated file.
    """
    dump = partial(dump_params, f=f)
    if hasattr(layer, "dump"):
        layer.dump(f)
    elif isinstance(layer, (nn.Linear, nn.Conv1d, nn.LayerNorm)):
        dump(layer.weight)
        dump(layer.bias)
    elif isinstance(layer, nn.ConvTranspose1d):
        dump(layer.weight.transpose(0, 1))
        dump(layer.bias)
    elif isinstance(layer, nn.GRU):
        # getattr (rather than named_parameters) also finds the weights that
        # weight_norm re-creates as plain attributes.
        for i in range(layer.num_layers):
            dump(getattr(layer, f"weight_ih_l{i}"))
            dump(getattr(layer, f"bias_ih_l{i}"))
            dump(getattr(layer, f"weight_hh_l{i}"))
            dump(getattr(layer, f"bias_hh_l{i}"))
    elif isinstance(layer, nn.Embedding):
        dump(layer.weight)
    elif isinstance(layer, nn.Parameter):
        dump(layer)
    elif isinstance(layer, nn.ModuleList):
        for sublayer in layer:
            dump_layer(sublayer, f)
    else:
        raise AssertionError(layer)
class CausalConv1d(nn.Conv1d):
    """1-D convolution whose output at time t depends only on inputs up to t + ``delay``.

    nn.Conv1d pads both sides by ``(kernel_size - 1) * dilation - delay``.
    The surplus ``trim`` frames at the end of the output are dropped, so for
    stride 1 the output length equals the input length.

    Raises:
        ValueError: If ``2 * delay`` exceeds ``(kernel_size - 1) * dilation``
            (the trim would be negative).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        delay: int = 0,
    ):
        receptive_span = (kernel_size - 1) * dilation
        padding = receptive_span - delay
        trim = receptive_span - 2 * delay
        if trim < 0:
            raise ValueError(
                f"delay ({delay}) must be at most half of "
                f"(kernel_size - 1) * dilation ({receptive_span})"
            )
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # Number of trailing output frames to drop in forward().
        self.trim = trim

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        result = super().forward(input)
        if self.trim == 0:
            return result
        else:
            return result[:, :, : -self.trim]
class WSConv1d(CausalConv1d):
    """CausalConv1d with weight standardization and a learnable per-channel gain.

    On every forward pass, each output channel's kernel is centred and divided
    by ``sqrt(fan_in * var)`` (``fan_in = in_channels * kernel_size //
    groups``; ``var`` from ``torch.var_mean``), then multiplied by ``gain``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        delay: int = 0,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            delay=delay,
        )
        fan_in = in_channels * kernel_size // groups
        self.weight.data.normal_(0.0, math.sqrt(1.0 / fan_in))
        if bias:
            self.bias.data.zero_()
        self.gain = nn.Parameter(torch.ones((out_channels, 1, 1)))

    def standardized_weight(self) -> torch.Tensor:
        """Return the standardized, gain-scaled kernel ([out, in/groups, k])."""
        var, mean = torch.var_mean(self.weight, [1, 2], keepdim=True)
        fan_in = self.in_channels * self.kernel_size[0] // self.groups
        inv_std = (fan_in * var + 1e-8).rsqrt()
        scale = self.gain * inv_std
        return scale * (self.weight - mean)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = F.conv1d(
            input,
            self.standardized_weight(),
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        # Drop the trailing frames exactly as CausalConv1d does.
        return out[:, :, : -self.trim] if self.trim else out

    def merge_weights(self):
        """Bake the standardization and gain into ``weight`` and reset ``gain`` to 1.

        NOTE(review): forward() re-standardizes ``weight``, so after merging it
        reproduces the pre-merge output only when every gain was +/-1. The
        merged state is presumably meant for export (the raw ``weight`` is
        what gets dumped) — confirm.
        """
        merged = self.standardized_weight().detach()
        self.weight.data[:] = merged
        self.gain.data.fill_(1.0)
class WSLinear(nn.Linear):
    """Linear layer with weight standardization and a learnable per-output gain.

    On every forward pass, each row (output feature) of ``weight`` is centred
    and divided by ``sqrt(in_features * var)`` (``var`` from
    ``torch.var_mean``), then multiplied by ``gain``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias)
        self.weight.data.normal_(0.0, math.sqrt(1.0 / in_features))
        # nn.Linear sets `self.bias` to None when bias=False; guard it so
        # bias-free layers can be constructed (previously an AttributeError).
        if bias:
            self.bias.data.zero_()
        self.gain = nn.Parameter(torch.ones((out_features, 1)))

    def standardized_weight(self) -> torch.Tensor:
        """Return the standardized, gain-scaled weight ([out_features, in_features])."""
        var, mean = torch.var_mean(self.weight, 1, keepdim=True)
        scale = self.gain * (self.in_features * var + 1e-8).rsqrt()
        return scale * (self.weight - mean)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.standardized_weight(), self.bias)

    def merge_weights(self):
        """Bake standardization and gain into ``weight`` and reset ``gain`` to 1.

        After merging, ``F.linear(x, weight, bias)`` equals the pre-merge
        forward output.
        """
        self.weight.data[:] = self.standardized_weight().detach()
        self.gain.data.fill_(1.0)
class ConvNeXtBlock(nn.Module):
    """Residual ConvNeXt-style block on [batch, channels, length] tensors.

    Residual branch: causal depthwise conv -> LayerNorm over channels ->
    pointwise expand (Linear) -> GELU (tanh approximation) -> pointwise
    project (Linear) -> per-channel layer scale ``gamma``; the branch output
    is added to the input.

    With ``use_weight_standardization`` the LayerNorm and ``gamma`` are
    removed and the conv/linears are replaced by weight-standardized versions.
    With ``enable_scaling`` the input to the branch is multiplied by a fixed
    ``pre_scale``, and the branch output by a fixed ``post_scale`` times a
    learnable scalar ``post_scale_weight``.
    """

    def __init__(
        self,
        channels: int,
        intermediate_channels: int,
        layer_scale_init_value: float,
        kernel_size: int = 7,
        use_weight_standardization: bool = False,
        enable_scaling: bool = False,
        pre_scale: float = 1.0,
        post_scale: float = 1.0,
    ):
        super().__init__()
        self.use_weight_standardization = use_weight_standardization
        self.enable_scaling = enable_scaling
        self.dwconv = CausalConv1d(
            channels, channels, kernel_size=kernel_size, groups=channels
        )
        self.norm = nn.LayerNorm(channels)
        self.pwconv1 = nn.Linear(channels, intermediate_channels)
        self.pwconv2 = nn.Linear(intermediate_channels, channels)
        self.gamma = nn.Parameter(torch.full((channels,), layer_scale_init_value))
        # Custom initialization. Note that the statement order determines the
        # RNG stream consumed, and therefore the initial weights.
        self.dwconv.weight.data.normal_(0.0, math.sqrt(1.0 / kernel_size))
        self.dwconv.bias.data.zero_()
        self.pwconv1.weight.data.normal_(0.0, math.sqrt(2.0 / channels))
        self.pwconv1.bias.data.zero_()
        self.pwconv2.weight.data.normal_(0.0, math.sqrt(1.0 / intermediate_channels))
        self.pwconv2.bias.data.zero_()
        if use_weight_standardization:
            # Replace the layers created above by weight-standardized ones.
            # The LayerNorm and the layer scale are not used in this mode.
            self.norm = nn.Identity()
            self.dwconv = WSConv1d(channels, channels, kernel_size, groups=channels)
            self.pwconv1 = WSLinear(channels, intermediate_channels)
            self.pwconv2 = WSLinear(intermediate_channels, channels)
            del self.gamma
        if enable_scaling:
            # The fixed scales are buffers (stored in the state_dict but not
            # trained); only post_scale_weight is learnable.
            self.register_buffer("pre_scale", torch.tensor(pre_scale))
            self.register_buffer("post_scale", torch.tensor(post_scale))
            self.post_scale_weight = nn.Parameter(torch.ones(()))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, channels, length]
        identity = x
        if self.enable_scaling:
            x = x * self.pre_scale
        x = self.dwconv(x)
        # [batch_size, length, channels]: LayerNorm/Linear act on the last axis.
        x = x.transpose(1, 2)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = F.gelu(x, approximate="tanh")
        x = self.pwconv2(x)
        if not self.use_weight_standardization:
            x *= self.gamma
        if self.enable_scaling:
            x *= self.post_scale * self.post_scale_weight
        x = x.transpose(1, 2)
        x += identity
        return x

    def merge_weights(self):
        """Fold normalization/scale parameters into the conv and linear weights.

        Afterwards ``norm`` (when not weight-standardized) has weight 1 and
        bias 0, and ``gamma``, ``pre_scale``, ``post_scale`` and
        ``post_scale_weight`` (where present) are 1.
        """
        if self.use_weight_standardization:
            self.dwconv.merge_weights()
            self.pwconv1.merge_weights()
            self.pwconv2.merge_weights()
        else:
            # pwconv1(g * n + b) = (W * g) n + (W @ b + bias): fold the
            # LayerNorm bias first (it needs the unscaled W), then its scale.
            self.pwconv1.bias.data += (
                self.norm.bias.data[None, :] * self.pwconv1.weight.data
            ).sum(1)
            self.pwconv1.weight.data *= self.norm.weight.data[None, :]
            self.norm.bias.data[:] = 0.0
            self.norm.weight.data[:] = 1.0
            # gamma scales the whole pwconv2 output (weight and bias).
            self.pwconv2.weight.data *= self.gamma.data[:, None]
            self.pwconv2.bias.data *= self.gamma.data
            self.gamma.data[:] = 1.0
        if self.enable_scaling:
            # The input scale only multiplies the dwconv weight (not its bias).
            self.dwconv.weight.data *= self.pre_scale.data
            self.pre_scale.data.fill_(1.0)
            # The output scale multiplies both pwconv2 weight and bias.
            self.pwconv2.weight.data *= (
                self.post_scale.data * self.post_scale_weight.data
            )
            self.pwconv2.bias.data *= self.post_scale.data * self.post_scale_weight.data
            self.post_scale.data.fill_(1.0)
            self.post_scale_weight.data.fill_(1.0)

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write dwconv, pwconv1 and pwconv2 parameters to ``f`` (path or binary stream).

        ``norm``, ``gamma`` and the scaling parameters are not written.
        NOTE(review): this presumably expects merge_weights() to have been
        called first, so those are folded into the written layers — confirm.

        Raises:
            TypeError: If ``f`` is neither a path nor has a ``write`` method.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        dump_layer(self.dwconv, f)
        dump_layer(self.pwconv1, f)
        dump_layer(self.pwconv2, f)
class ConvNeXtStack(nn.Module):
    """Embedding conv, a stack of ConvNeXtBlocks and a final LayerNorm.

    Maps [batch, in_channels, length] to [batch, channels, length]. The
    embedding is a CausalConv1d allowing ``delay`` frames of look-ahead,
    followed by a LayerNorm over channels. With
    ``use_weight_standardization`` the embedding becomes a WSConv1d and both
    LayerNorms become identities.
    """

    def __init__(
        self,
        in_channels: int,
        channels: int,
        intermediate_channels: int,
        n_blocks: int,
        delay: int,
        embed_kernel_size: int,
        kernel_size: int,
        use_weight_standardization: bool = False,
        enable_scaling: bool = False,
    ):
        super().__init__()
        # The CausalConv1d trim (embed_kernel_size - 1 - 2 * delay) must be >= 0.
        assert delay * 2 + 1 <= embed_kernel_size
        self.use_weight_standardization = use_weight_standardization
        self.embed = CausalConv1d(in_channels, channels, embed_kernel_size, delay=delay)
        self.norm = nn.LayerNorm(channels)
        self.convnext = nn.ModuleList()
        for i in range(n_blocks):
            # With scaling enabled, the input of block i is scaled by
            # 1/sqrt(1 + i/n_blocks) and every residual branch by 1/sqrt(n_blocks).
            pre_scale = 1.0 / math.sqrt(1.0 + i / n_blocks) if enable_scaling else 1.0
            post_scale = 1.0 / math.sqrt(n_blocks) if enable_scaling else 1.0
            block = ConvNeXtBlock(
                channels=channels,
                intermediate_channels=intermediate_channels,
                layer_scale_init_value=1.0 / n_blocks,
                kernel_size=kernel_size,
                use_weight_standardization=use_weight_standardization,
                enable_scaling=enable_scaling,
                pre_scale=pre_scale,
                post_scale=post_scale,
            )
            self.convnext.append(block)
        self.final_layer_norm = nn.LayerNorm(channels)
        self.embed.weight.data.normal_(
            0.0, math.sqrt(0.5 / (embed_kernel_size * in_channels))
        )
        self.embed.bias.data.zero_()
        if use_weight_standardization:
            # Replace the embedding and drop both LayerNorms in this mode.
            self.embed = WSConv1d(in_channels, channels, embed_kernel_size, delay=delay)
            self.norm = nn.Identity()
            self.final_layer_norm = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, in_channels, length] -> [batch_size, channels, length]
        x = self.embed(x)
        # LayerNorm normalizes the last axis, so move channels there and back.
        x = self.norm(x.transpose(1, 2)).transpose(1, 2)
        for conv_block in self.convnext:
            x = conv_block(x)
        x = self.final_layer_norm(x.transpose(1, 2)).transpose(1, 2)
        return x

    def merge_weights(self):
        """Merge the WS embedding (if any) and every block's foldable parameters.

        The LayerNorms of this stack are not folded; dump() writes them
        separately.
        """
        if self.use_weight_standardization:
            self.embed.merge_weights()
        for conv_block in self.convnext:
            conv_block.merge_weights()

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write embed, norm (non-WS only), all blocks and final norm (non-WS only) to ``f``.

        Raises:
            TypeError: If ``f`` is neither a path nor has a ``write`` method.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        dump_layer(self.embed, f)
        if not self.use_weight_standardization:
            dump_layer(self.norm, f)
        dump_layer(self.convnext, f)
        if not self.use_weight_standardization:
            dump_layer(self.final_layer_norm, f)
class FeatureExtractor(nn.Module):
    """Strided convolutional front end turning a waveform into frame features.

    Six weight-normalized, bias-free convolutions, each followed by GELU (tanh
    approximation), with a combined stride of 5 * 2**5 = 160 samples:
    [batch_size, 1, wav_length] -> [batch_size, hidden_channels, wav_length / 160].
    """

    def __init__(self, hidden_channels: int):
        super().__init__()
        # fmt: off
        self.conv0 = weight_norm(nn.Conv1d(1, hidden_channels // 8, 10, 5, bias=False))
        self.conv1 = weight_norm(nn.Conv1d(hidden_channels // 8, hidden_channels // 4, 3, 2, bias=False))
        self.conv2 = weight_norm(nn.Conv1d(hidden_channels // 4, hidden_channels // 2, 3, 2, bias=False))
        self.conv3 = weight_norm(nn.Conv1d(hidden_channels // 2, hidden_channels, 3, 2, bias=False))
        self.conv4 = weight_norm(nn.Conv1d(hidden_channels, hidden_channels, 3, 2, bias=False))
        self.conv5 = weight_norm(nn.Conv1d(hidden_channels, hidden_channels, 2, 2, bias=False))
        # fmt: on

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, 1, wav_length]
        if x.size(2) % 160:
            warnings.warn("wav_length % 160 != 0")
        h = F.pad(x, (40, 40))
        for conv in (self.conv0, self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            h = F.gelu(conv(h), approximate="tanh")
        # [batch_size, hidden_channels, wav_length / 160]
        return h

    def remove_weight_norm(self):
        """Remove the weight-norm reparametrization from all six convolutions."""
        # Inside the method body this name resolves to the module-level
        # torch.nn.utils.remove_weight_norm, not to this method.
        for conv in (self.conv0, self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            remove_weight_norm(conv)

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write conv0..conv5 parameters to ``f`` (a path or a writable binary stream).

        Raises:
            TypeError: If ``f`` is neither a path nor has a ``write`` method.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as stream:
                self.dump(stream)
            return
        if not hasattr(f, "write"):
            raise TypeError
        for conv in (self.conv0, self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            dump_layer(conv, f)
class FeatureProjection(nn.Module):
    """LayerNorm over channels, a 1x1 projection, then dropout (p=0.1).

    Operates on [batch_size, channels, length] tensors.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.norm = nn.LayerNorm(in_channels)
        self.projection = nn.Conv1d(in_channels, out_channels, 1)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [batch_size, channels, length]; LayerNorm works on the last axis.
        normalized = self.norm(x.transpose(1, 2)).transpose(1, 2)
        projected = self.projection(normalized)
        return self.dropout(projected)

    def merge_weights(self):
        """Fold the LayerNorm affine parameters into the projection.

        The forward output is unchanged: afterwards ``norm`` has weight 1 and
        bias 0.
        """
        weight = self.projection.weight.data  # [out, in, 1]
        # The bias fold must use the unscaled weight: W @ beta joins the bias.
        bias_shift = (self.norm.bias.data[None, :, None] * weight).sum(1).squeeze(1)
        self.projection.bias.data += bias_shift
        # Scale the input-channel axis of the weight by the LayerNorm gain.
        weight *= self.norm.weight.data[None, :, None]
        self.norm.bias.data[:] = 0.0
        self.norm.weight.data[:] = 1.0

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write the projection parameters to ``f`` (a path or a writable binary stream).

        Raises:
            TypeError: If ``f`` is neither a path nor has a ``write`` method.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as stream:
                self.dump(stream)
            return
        if not hasattr(f, "write"):
            raise TypeError
        dump_layer(self.projection, f)
class PhoneExtractor(nn.Module):
    """Waveform -> frame-level phonetic features ("units").

    Pipeline: FeatureExtractor (160-sample frames) -> FeatureProjection -> a
    3-layer GRU ("speaker encoder") whose output is added to the projected
    features -> ConvNeXtStack backbone -> GELU -> weight-normalized 1x1 head.

    While training, each element (frame, channel) of the GRU output is
    replaced by the value at a random frame in ``[t - s, t]``, with ``s``
    drawn per batch item from 0..49.
    """

    def __init__(
        self,
        phone_channels: int = 256,
        hidden_channels: int = 256,
        backbone_embed_kernel_size: int = 7,
        kernel_size: int = 17,
        n_blocks: int = 8,
    ):
        super().__init__()
        self.feature_extractor = FeatureExtractor(hidden_channels)
        self.feature_projection = FeatureProjection(hidden_channels, hidden_channels)
        self.n_speaker_encoder_layers = 3
        self.speaker_encoder = nn.GRU(
            hidden_channels,
            hidden_channels,
            self.n_speaker_encoder_layers,
            batch_first=True,
        )
        # Weight-normalize the input-hidden and hidden-hidden weights of every
        # GRU layer.
        for i in range(self.n_speaker_encoder_layers):
            for input_char in "ih":
                self.speaker_encoder = weight_norm(
                    self.speaker_encoder, f"weight_{input_char}h_l{i}"
                )
        self.backbone = ConvNeXtStack(
            in_channels=hidden_channels,
            channels=hidden_channels,
            intermediate_channels=hidden_channels * 3,
            n_blocks=n_blocks,
            delay=0,
            embed_kernel_size=backbone_embed_kernel_size,
            kernel_size=kernel_size,
        )
        self.head = weight_norm(nn.Conv1d(hidden_channels, phone_channels, 1))

    def forward(
        self, x: torch.Tensor, return_stats: bool = True
    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]:
        """Compute phone logits, optionally with diagnostic statistics.

        Args:
            x: [batch_size, 1, wav_length] waveform.
            return_stats: Also return a stats dict. Note that "feature_norm"
                is a 0-dim tensor while "code_norm" is a Python float.

        Returns:
            ``phone`` ([batch_size, phone_channels, length]), or
            ``(phone, stats)`` when ``return_stats`` is True.
        """
        # x: [batch_size, 1, wav_length]
        stats = {}
        # [batch_size, 1, wav_length] -> [batch_size, feature_extractor_hidden_channels, length]
        x = self.feature_extractor(x)
        if return_stats:
            stats["feature_norm"] = x.detach().norm(dim=1).mean()
        # [batch_size, feature_extractor_hidden_channels, length] -> [batch_size, hidden_channels, length]
        x = self.feature_projection(x)
        # [batch_size, hidden_channels, length] -> [batch_size, length, hidden_channels]
        g, _ = self.speaker_encoder(x.transpose(1, 2))
        if self.training:
            # Local temporal shuffle of the GRU output (see class docstring).
            batch_size, length, _ = g.size()
            shuffle_sizes_for_each_data = torch.randint(
                0, 50, (batch_size,), device=g.device
            )
            max_indices = torch.arange(length, device=g.device)[None, :, None]
            min_indices = (
                max_indices - shuffle_sizes_for_each_data[:, None, None]
            ).clamp_(min=0)
            # Compute indices with autocast disabled. torch.amp.autocast is
            # used instead of the deprecated torch.cuda.amp.autocast, matching
            # the other autocast blocks in this file.
            with torch.amp.autocast("cuda", enabled=False):
                indices = (
                    torch.rand(g.size(), device=g.device)
                    * (max_indices - min_indices + 1)
                ).long() + min_indices
            assert indices.min() >= 0, indices.min()
            assert indices.max() < length, (indices.max(), length)
            g = g.gather(1, indices)
        # [batch_size, length, hidden_channels] -> [batch_size, hidden_channels, length]
        g = g.transpose(1, 2).contiguous()
        # [batch_size, hidden_channels, length]
        x = self.backbone(x + g)
        # [batch_size, hidden_channels, length] -> [batch_size, phone_channels, length]
        phone = self.head(F.gelu(x, approximate="tanh"))
        results = [phone]
        if return_stats:
            stats["code_norm"] = phone.detach().norm(dim=1).mean().item()
            results.append(stats)
        if len(results) == 1:
            return results[0]
        return tuple(results)

    @torch.inference_mode()
    def units(self, x: torch.Tensor) -> torch.Tensor:
        """Return phone features as [batch_size, length, phone_channels] (inference mode)."""
        # x: [batch_size, 1, wav_length]
        # [batch_size, 1, wav_length] -> [batch_size, phone_channels, length]
        phone = self.forward(x, return_stats=False)
        # [batch_size, phone_channels, length] -> [batch_size, length, phone_channels]
        phone = phone.transpose(1, 2)
        # [batch_size, length, phone_channels]
        return phone

    def remove_weight_norm(self):
        """Remove all weight-norm reparametrizations (front end, GRU, head)."""
        self.feature_extractor.remove_weight_norm()
        for i in range(self.n_speaker_encoder_layers):
            for input_char in "ih":
                remove_weight_norm(self.speaker_encoder, f"weight_{input_char}h_l{i}")
        remove_weight_norm(self.head)

    def merge_weights(self):
        """Fold normalization parameters of the projection and backbone."""
        self.feature_projection.merge_weights()
        self.backbone.merge_weights()

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write all submodules to ``f`` (a path or a writable binary stream).

        Raises:
            TypeError: If ``f`` is neither a path nor has a ``write`` method.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        dump_layer(self.feature_extractor, f)
        dump_layer(self.feature_projection, f)
        dump_layer(self.speaker_encoder, f)
        dump_layer(self.backbone, f)
        dump_layer(self.head, f)
# %% [markdown]
# ## Pitch Estimator
# %%
def extract_pitch_features(
    y: torch.Tensor,  # [..., wav_length]
    hop_length: int = 160,  # 10ms
    win_length: int = 560,  # 35ms
    max_corr_period: int = 256,  # 16ms, 62.5Hz (16000 / 256)
    corr_win_length: int = 304,  # 19ms
    instfreq_features_cutoff_bin: int = 64,  # 1828Hz (16000 * 64 / 560)
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the hand-crafted pitch features consumed by PitchEstimator.

    The waveform is padded on both sides and cut into ``win_length``-sample
    frames every ``hop_length`` samples. Per frame it computes:

    * ``instfreq_features``: for the lowest ``instfreq_features_cutoff_bin``
      FFT bins, log10 magnitude plus the real and imaginary parts of the
      unit-normalized frame-to-frame phase rotation (zero for the first frame);
    * ``corr_diff``: ``sqrt(2 / corr_win_length * d)``, where ``d`` is the
      sum of squared differences between the frame's last
      ``corr_win_length`` samples and the segment ``lag`` samples earlier,
      for lags 1..``max_corr_period``;
    * ``energy``: half the log10 energy of the cosine-windowed frame
      (clamped at 1e-3 before the log).

    Returns:
        (instfreq_features [..., instfreq_features_cutoff_bin * 3, n_frames],
         corr_diff [..., max_corr_period, n_frames],
         energy [..., 1, n_frames])
    """
    assert max_corr_period + corr_win_length == win_length

    # Pad
    padding_length = (win_length - hop_length) // 2
    y = F.pad(y, (padding_length, padding_length))

    # Split into frames
    # [..., win_length, n_frames]
    y_frames = y.unfold(-1, win_length, hop_length).transpose_(-2, -1)

    # Complex spectrogram
    # Complex[..., (win_length // 2 + 1), n_frames]
    spec: torch.Tensor = torch.fft.rfft(y_frames, n=win_length, dim=-2)

    # Complex[..., instfreq_features_cutoff_bin, n_frames]
    spec = spec[..., :instfreq_features_cutoff_bin, :]

    # Log power spectrogram (computed as log10 of the magnitude, i.e. half of
    # log10 power)
    log_power_spec = spec.abs().add_(1e-5).log10_()

    # Time difference of the instantaneous phase, as a unit phasor
    # The value at frame 0 is 0
    delta_spec = spec[..., :, 1:] * spec[..., :, :-1].conj()
    delta_spec /= delta_spec.abs().add_(1e-5)
    delta_spec = torch.cat(
        [torch.zeros_like(delta_spec[..., :, :1]), delta_spec], dim=-1
    )

    # [..., instfreq_features_cutoff_bin * 3, n_frames]
    instfreq_features = torch.cat(
        [log_power_spec, delta_spec.real, delta_spec.imag], dim=-2
    )

    # Autocorrelation
    # If time permits, it would be worth trying the LPC residual instead.
    # Originally the plan was to multiply this by 2.0 / corr_win_length and use
    # it directly, but this value is proportional to the squared amplitude and
    # no good way to standardize its variance for the NN input came to mind,
    # so that idea was dropped.
    flipped_y_frames = y_frames.flip((-2,))
    a = torch.fft.rfft(flipped_y_frames, n=win_length, dim=-2)
    b = torch.fft.rfft(y_frames[..., -corr_win_length:, :], n=win_length, dim=-2)
    # [..., max_corr_period, n_frames]
    corr = torch.fft.irfft(a * b, n=win_length, dim=-2)[..., corr_win_length:, :]

    # Energy terms (cumulative energy from the end of the frame backwards)
    energy = flipped_y_frames.square_().cumsum_(-2)
    energy0 = energy[..., corr_win_length - 1 : corr_win_length, :]
    energy = energy[..., corr_win_length:, :] - energy[..., :-corr_win_length, :]

    # Difference function: energy0 + energy - 2 * corr
    corr_diff = (energy0 + energy).sub_(corr.mul_(2.0))
    assert corr_diff.min() >= -1e-3, corr_diff.min()
    corr_diff.clamp_(min=0.0)  # guard against numerical error

    # Normalize
    corr_diff *= 2.0 / corr_win_length
    corr_diff.sqrt_()

    # Energy used as an input to the conversion model
    energy = (
        (y_frames * torch.signal.windows.cosine(win_length, device=y.device)[..., None])
        .square_()
        .sum(-2, keepdim=True)
    )

    energy.clamp_(min=1e-3).log10_()  # >= -3; roughly 2.15 for a sine wave of amplitude 1
    energy *= 0.5  # >= -1.5; roughly 1.07 for a sine wave of amplitude 1; a difference of 1 equals 20 dB in amplitude

    return (
        instfreq_features,  # [..., instfreq_features_cutoff_bin * 3, n_frames]
        corr_diff,  # [..., max_corr_period, n_frames]
        energy,  # [..., 1, n_frames]
    )
class PitchEstimator(nn.Module):
    """Frame-wise pitch classifier built on ``extract_pitch_features``.

    The instantaneous-frequency features and the difference-function features
    (computed with autocast disabled) are each embedded by two 1x1 convs
    (GELU in between), summed, passed through a ConvNeXtStack and projected
    to ``pitch_channels`` logits per frame. Channel 0 is treated as the
    "unvoiced" class by ``sample_pitch``.
    """

    def __init__(
        self,
        input_instfreq_channels: int = 192,
        input_corr_channels: int = 256,
        pitch_channels: int = 384,
        channels: int = 192,
        intermediate_channels: int = 192 * 3,
        n_blocks: int = 6,
        delay: int = 1,  # 10ms; 22.5ms in total together with feature extraction
        embed_kernel_size: int = 3,
        kernel_size: int = 33,
        bins_per_octave: int = 96,
    ):
        super().__init__()
        self.bins_per_octave = bins_per_octave

        self.instfreq_embed_0 = nn.Conv1d(input_instfreq_channels, channels, 1)
        self.instfreq_embed_1 = nn.Conv1d(channels, channels, 1)
        self.corr_embed_0 = nn.Conv1d(input_corr_channels, channels, 1)
        self.corr_embed_1 = nn.Conv1d(channels, channels, 1)
        self.backbone = ConvNeXtStack(
            channels,
            channels,
            intermediate_channels,
            n_blocks,
            delay,
            embed_kernel_size,
            kernel_size,
        )
        self.head = nn.Conv1d(channels, pitch_channels, 1)

    def forward(self, wav: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(pitch logits [batch, pitch_channels, length], energy [batch, 1, length])``."""
        # wav: [batch_size, 1, wav_length]

        # [batch_size, input_instfreq_channels, length],
        # [batch_size, input_corr_channels, length]
        # FFT-based features are computed without autocast (full precision).
        with torch.amp.autocast("cuda", enabled=False):
            instfreq_features, corr_diff, energy = extract_pitch_features(
                wav.squeeze(1),
                hop_length=160,
                win_length=560,
                max_corr_period=256,
                corr_win_length=304,
                instfreq_features_cutoff_bin=64,
            )
        instfreq_features = F.gelu(
            self.instfreq_embed_0(instfreq_features), approximate="tanh"
        )
        instfreq_features = self.instfreq_embed_1(instfreq_features)
        corr_diff = F.gelu(self.corr_embed_0(corr_diff), approximate="tanh")
        corr_diff = self.corr_embed_1(corr_diff)
        # [batch_size, channels, length]
        x = instfreq_features + corr_diff  # the activation function was forgotten here (NOTE(review): presumably left as-is for checkpoint compatibility)
        x = self.backbone(x)
        # [batch_size, pitch_channels, length]
        x = self.head(x)
        return x, energy

    def sample_pitch(
        self, pitch: torch.Tensor, band_width: int = 48, return_features: bool = False
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Pick one quantized pitch bin per frame from pitch logits.

        The logits are softmaxed and bin 0 (the unvoiced class) is excluded by
        setting it to -100. The ``band_width``-bin window with the largest
        total probability is found, and the most probable bin inside it is
        returned.

        With ``return_features=True``, also returns [batch, 3, length]
        features: the unvoiced probability, and the probability of the band
        one octave (``bins_per_octave`` bins) below / above relative to the
        chosen band's probability (set to 0 near the ends of the pitch range).
        """
        # pitch: [batch_size, pitch_channels, length]
        # The returned pitch values never include 0
        batch_size, pitch_channels, length = pitch.size()
        pitch = pitch.softmax(1)
        if return_features:
            unvoiced_proba = pitch[:, :1, :].clone()
        pitch[:, 0, :] = -100.0
        pitch = (
            pitch.transpose(1, 2)
            .contiguous()
            .view(batch_size * length, 1, pitch_channels)
        )
        # Sum of probabilities over every window of band_width consecutive bins.
        band_pitch = F.conv1d(
            pitch,
            torch.ones((1, 1, 1), device=pitch.device).expand(1, 1, band_width),
        )
        # [batch_size * length, 1, pitch_channels - band_width + 1] -> Long[batch_size * length, 1]
        quantized_band_pitch = band_pitch.argmax(2)
        if return_features:
            # [batch_size * length, 1]
            band_proba = band_pitch.gather(2, quantized_band_pitch[:, :, None])
            # [batch_size * length, 1]
            half_pitch_band_proba = band_pitch.gather(
                2,
                (quantized_band_pitch - self.bins_per_octave).clamp_(min=1)[:, :, None],
            )
            half_pitch_band_proba[quantized_band_pitch <= self.bins_per_octave] = 0.0
            half_pitch_proba = (half_pitch_band_proba / (band_proba + 1e-6)).view(
                batch_size, 1, length
            )
            # [batch_size * length, 1]
            double_pitch_band_proba = band_pitch.gather(
                2,
                (quantized_band_pitch + self.bins_per_octave).clamp_(
                    max=pitch_channels - band_width
                )[:, :, None],
            )
            double_pitch_band_proba[
                quantized_band_pitch
                > pitch_channels - band_width - self.bins_per_octave
            ] = 0.0
            double_pitch_proba = (double_pitch_band_proba / (band_proba + 1e-6)).view(
                batch_size, 1, length
            )
        # Long[1, pitch_channels]
        mask = torch.arange(pitch_channels, device=pitch.device)[None, :]
        # bool[batch_size * length, pitch_channels]
        mask = (quantized_band_pitch <= mask) & (
            mask < quantized_band_pitch + band_width
        )
        # Long[batch_size, length]
        quantized_pitch = (pitch.squeeze(1) * mask).argmax(1).view(batch_size, length)

        if return_features:
            features = torch.cat(
                [unvoiced_proba, half_pitch_proba, double_pitch_proba], dim=1
            )
            # Long[batch_size, length], [batch_size, 3, length]
            return quantized_pitch, features
        else:
            return quantized_pitch

    def merge_weights(self):
        """Fold the backbone's foldable normalization/scale parameters."""
        self.backbone.merge_weights()

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write embeddings, backbone and head to ``f`` (a path or a writable binary stream).

        Raises:
            TypeError: If ``f`` is neither a path nor has a ``write`` method.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        dump_layer(self.instfreq_embed_0, f)
        dump_layer(self.instfreq_embed_1, f)
        dump_layer(self.corr_embed_0, f)
        dump_layer(self.corr_embed_1, f)
        dump_layer(self.backbone, f)
        dump_layer(self.head, f)
# %% [markdown]
# ## Vocoder
# %%
def overlap_add(
    ir_amp: torch.Tensor,
    ir_phase: torch.Tensor,
    window: torch.Tensor,
    pitch: torch.Tensor,
    hop_length: int = 240,
    delay: int = 0,
    sr: float = 24000.0,
) -> torch.Tensor:
    """Pitch-synchronous overlap-add of per-frame impulse responses.

    The phase ``cumsum(pitch / sr)`` (in cycles, with a random initial value)
    is integrated in float64, and a pulse is placed wherever it wraps past 1.
    The sub-sample pulse position comes from linear interpolation of the
    crossing. For each pulse, the impulse response of the frame it falls in
    (frame = sample index // ``hop_length``) is rebuilt from its
    amplitude/phase spectrum, with a linear-phase term that realizes the
    sub-sample offset. It is then multiplied by ``window`` and accumulated
    into the output.

    Args:
        ir_amp, ir_phase: [batch_size, ir_length // 2 + 1, length] spectra.
        window: [ir_length] window applied to every impulse response.
        pitch: [batch_size, length * hop_length] per-sample frequency in the
            same units as ``sr``.
        hop_length: Samples per frame.
        delay: Output is taken starting ``delay`` samples into the
            accumulated signal (0 <= delay < ir_length).
        sr: Sample rate.

    Returns:
        [batch_size, length * hop_length] signal.
    """
    batch_size, ir_length, length = ir_amp.size()
    ir_length = (ir_length - 1) * 2
    assert ir_phase.size() == ir_amp.size()
    assert window.size() == (ir_length,), (window.size(), ir_amp.size())
    assert pitch.size() == (batch_size, length * hop_length)
    assert 0 <= delay < ir_length, (delay, ir_length)
    # Normalized frequency [2π rad] (cycles per sample)
    normalized_freq = pitch / sr
    # Set the initial phase [2π rad] randomly
    normalized_freq[:, 0] = torch.rand(batch_size, device=pitch.device)
    with torch.amp.autocast("cuda", enabled=False):
        phase = (normalized_freq.double().cumsum_(1) % 1.0).float()
    # Find where to overlap (samples after which the phase wraps)
    # [n_pitchmarks], [n_pitchmarks]
    indices0, indices1 = torch.nonzero(phase[:, :-1] > phase[:, 1:], as_tuple=True)
    # Find the fractional part of the overlap position (the phase lag)
    numer = 1.0 - phase[indices0, indices1]
    # [n_pitchmarks]
    fractional_part = numer / (numer + phase[indices0, indices1 + 1])
    # Find the values to overlap
    # Complex[n_pitchmarks, ir_length / 2 + 1]
    ir_amp = ir_amp[indices0, :, indices1 // hop_length]
    ir_phase = ir_phase[indices0, :, indices1 // hop_length]
    # Amount of phase delay [rad]
    # [n_pitchmarks, ir_length / 2 + 1]
    delay_phase = (
        torch.arange(ir_length // 2 + 1, device=pitch.device, dtype=torch.float32)[
            None, :
        ]
        * (-math.tau / ir_length)
        * fractional_part[:, None]
    )
    # Complex[n_pitchmarks, ir_length / 2 + 1]
    spec = torch.polar(ir_amp, ir_phase + delay_phase)
    # [n_pitchmarks, ir_length]
    ir = torch.fft.irfft(spec, n=ir_length, dim=1)
    ir *= window
    # Scatter the values to add down to individual samples
    # [n_pitchmarks * ir_length]
    ir = ir.ravel()
    # Long[n_pitchmarks * ir_length]
    indices0 = indices0[:, None].expand(-1, ir_length).ravel()
    # Long[n_pitchmarks * ir_length]
    indices1 = (
        indices1[:, None] + torch.arange(ir_length, device=pitch.device)
    ).ravel()
    # Overlap-add
    overlap_added_signal = torch.zeros(
        (batch_size, length * hop_length + ir_length), device=pitch.device
    )
    overlap_added_signal.index_put_((indices0, indices1), ir, accumulate=True)
    overlap_added_signal = overlap_added_signal[:, delay : -ir_length + delay]
    return overlap_added_signal
def generate_noise(
    aperiodicity: torch.Tensor, delay: int = 0
) -> tuple[torch.Tensor, torch.Tensor]:
    """Synthesize noise whose spectrum is shaped frame by frame by ``aperiodicity``.

    Uniform white noise in [-0.5, 0.5) is analyzed with a rectangular-window
    STFT (n_fft = 2 * hop_length, no centering). The DC bin is zeroed and
    bins 1..hop_length are multiplied by ``aperiodicity``. The result is then
    resynthesized with a Hann window and 50 % overlap-add.

    Args:
        aperiodicity: [batch_size, hop_length, length] per-bin gains.
        delay: Output offset in samples (must be < hop_length).

    Returns:
        (noise, excitation), each [batch_size, length * hop_length].
    """
    # aperiodicity: [batch_size, hop_length, length]
    batch_size, hop_length, length = aperiodicity.size()
    n_fft = 2 * hop_length
    total_length = (length + 1) * hop_length

    excitation = torch.rand(batch_size, total_length, device=aperiodicity.device) - 0.5

    # Analysis with a rectangular window.
    # Complex[batch_size, hop_length + 1, length]
    spectrum = torch.stft(
        excitation,
        n_fft=n_fft,
        hop_length=hop_length,
        window=torch.ones(n_fft, device=excitation.device),
        center=False,
        return_complex=True,
    )
    assert spectrum.size(2) == aperiodicity.size(2)
    spectrum[:, 0, :] = 0.0
    spectrum[:, 1:, :] *= aperiodicity

    # Synthesis with a Hann window. torch.istft is not usable here because it
    # applies an optimal synthesis window.
    # [batch_size, 2 * hop_length, length]
    frames = torch.fft.irfft(spectrum, n=n_fft, dim=1)
    frames *= torch.hann_window(n_fft, device=frames.device)[None, :, None]
    # [batch_size, (length + 1) * hop_length]
    noise = F.fold(
        frames,
        (1, total_length),
        (1, n_fft),
        stride=(1, hop_length),
    ).squeeze_((1, 2))

    assert delay < hop_length
    stop = -hop_length + delay
    return noise[:, delay:stop], excitation[:, delay:stop]  # [batch_size, length * hop_length]
class GradientEqualizerFunction(torch.autograd.Function):
    """Identity in the forward pass; rescales gradients by the input's RMS.

    Compensates for gradients growing larger as the signal norm gets smaller:
    the incoming gradient is multiplied by ``sqrt(2) * rms + 0.1``. Here
    ``rms`` is the root-mean-square of the forward input along dim 2 (the
    ``length`` axis).
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, 1, length]
        mean_square = x.square().mean(dim=2, keepdim=True)
        ctx.save_for_backward(mean_square.sqrt_())
        return x

    @staticmethod
    def backward(ctx, dx: torch.Tensor) -> torch.Tensor:
        # dx: [batch_size, 1, length]
        rms = ctx.saved_tensors[0]
        gain = math.sqrt(2.0) * rms + 0.1
        return dx * gain
# Set this to False to reproduce the processing of the original (upstream)
# implementation.
# NOTE(review): judging by the name, True enables extra guards against division
# by zero in the D4C code elsewhere in this file — confirm at the use sites.
D4C_PREVENT_ZERO_DIVISION = True
def interp(x: torch.Tensor, y: torch.Tensor, xi: torch.Tensor) -> torch.Tensor:
    """Batched piecewise-linear interpolation along the last axis.

    Assumes ``x`` is monotonically increasing and evenly spaced along its last
    axis, and that every query in ``xi`` lies within ``[x[..., 0], x[..., -1]]``
    (no extrapolation; checked with asserts). ``xi`` is broadcast to the
    leading axes of ``y``.

    Returns:
        Interpolated values with shape ``y.shape[:-1] + (xi.shape[-1],)``.
    """
    x, y, xi = (torch.as_tensor(t) for t in (x, y, xi))
    missing_dims = y.ndim - xi.ndim
    if missing_dims > 0:
        xi = xi.view((1,) * missing_dims + tuple(xi.size()))
    batch_shape = y.size()[:-1]
    if xi.size()[:-1] != batch_shape:
        xi = xi.expand(batch_shape + (xi.size(-1),))

    first, last = x[..., 0], x[..., -1]
    assert (x.min(-1).values == first).all()
    assert (x.max(-1).values == last).all()
    assert (xi.min(-1).values >= first).all()
    assert (xi.max(-1).values <= last).all()

    # Grid spacing is computed in double precision, then cast back.
    step = ((last.double() - first.double()) / (x.size(-1) - 1.0)).to(x.dtype)
    position = (xi - x[..., :1]).div_(step[..., None])
    index = position.floor()
    frac = position.sub_(index)
    index = index.long()
    # Appending the last sample makes the slope 0 at the final grid point,
    # so queries exactly at x[..., -1] stay in range.
    slope = y.diff(dim=-1, append=y[..., -1:])
    return y.gather(-1, index) + slope.gather(-1, index) * frac
def linear_smoothing(
    group_delay: torch.Tensor, sr: int, n_fft: int, width: torch.Tensor
) -> torch.Tensor:
    """Moving-average smoothing of a one-sided spectrum over ``width`` Hz.

    The input is reflect-padded at both ends and integrated with a cumulative
    sum. The integral is then linearly interpolated at ``f - width / 2`` and
    ``f + width / 2`` for every bin frequency ``f``. Their difference divided
    by ``width`` is the average over that band.

    Args:
        group_delay: [..., n_fft // 2 + 1] values to smooth. Despite the name,
            any per-bin quantity is accepted; it is also used for power spectra.
        sr: sample rate in Hz.
        n_fft: FFT size that the bins correspond to.
        width: smoothing width in Hz, broadcast against the leading axes.

    Returns:
        Smoothed values, shape [..., n_fft // 2 + 1] (assuming ``width``
        broadcasts to the leading axes of ``group_delay``).
    """
    group_delay = torch.as_tensor(group_delay)
    assert group_delay.size(-1) == n_fft // 2 + 1
    width = torch.as_tensor(width)
    # Number of padding bins on each side, large enough for the widest window.
    boundary = (width.max() * n_fft / sr).long() + 1
    dtype = group_delay.dtype
    device = group_delay.device
    fft_resolution = sr / n_fft
    # Frequency axis of the padded, integrated spectrum. It is shifted by half
    # a bin, presumably because each cumsum value integrates up to the upper
    # edge of its bin.
    mirroring_freq_axis = (
        torch.arange(-boundary, n_fft // 2 + 1 + boundary, dtype=dtype, device=device)
        .add(0.5)
        .mul(fft_resolution)
    )
    # F.pad's reflect mode only accepts 2-D/3-D inputs for 1-D padding, so 1-D
    # input gets a temporary leading axis and >= 4-D input is flattened.
    if group_delay.ndim == 1:
        mirroring_spec = F.pad(
            group_delay[None], (boundary, boundary), mode="reflect"
        ).squeeze_(0)
    elif group_delay.ndim >= 4:
        shape = group_delay.size()
        mirroring_spec = F.pad(
            group_delay.view(math.prod(shape[:-1]), group_delay.size(-1)),
            (boundary, boundary),
            mode="reflect",
        ).view(shape[:-1] + (shape[-1] + 2 * boundary,))
    else:
        mirroring_spec = F.pad(group_delay, (boundary, boundary), mode="reflect")
    # Running integral of the padded spectrum.
    mirroring_segment = mirroring_spec.mul(fft_resolution).cumsum_(-1)
    center_freq = torch.arange(n_fft // 2 + 1, dtype=dtype, device=device).mul_(
        fft_resolution
    )
    low_freq = center_freq - width[..., None] * 0.5
    high_freq = center_freq + width[..., None] * 0.5
    # Evaluate the integral at both band edges in a single interp call.
    levels = interp(
        mirroring_freq_axis, mirroring_segment, torch.cat([low_freq, high_freq], dim=-1)
    )
    low_levels, high_levels = levels.split([n_fft // 2 + 1] * 2, dim=-1)
    smoothed = (high_levels - low_levels).div_(width[..., None])
    return smoothed
def dc_correction(
    spec: torch.Tensor, sr: int, n_fft: int, f0: torch.Tensor
) -> torch.Tensor:
    """Fold low-frequency content mirrored about F0 into the lowest bins.

    For each frame, let ``upper_limit = 2 + floor(f0 * n_fft / sr)``. For every
    bin ``k < upper_limit - 1`` (frequency ``f = k * sr / n_fft``), the
    spectrum value at the mirrored frequency ``f0 - f`` is linearly
    interpolated and added to bin ``k``. All other bins are unchanged.

    Args:
        spec: [..., n_bins] spectrum.
        sr: sample rate in Hz.
        n_fft: FFT size of ``spec``.
        f0: F0 in Hz, broadcast against the leading axes of ``spec``.

    Returns:
        A corrected copy of ``spec``; the input is not modified.
    """
    spec = torch.as_tensor(spec)
    f0 = torch.as_tensor(f0)
    dtype = spec.dtype
    device = spec.device
    upper_limit = 2 + (f0 * (n_fft / sr)).long()
    # Work with the largest limit in the batch and mask out the extra bins
    # for frames whose own limit is smaller.
    max_upper_limit = upper_limit.max()
    upper_limit_mask = (
        torch.arange(max_upper_limit - 1, device=device) < (upper_limit - 1)[..., None]
    )
    low_freq_axis = torch.arange(max_upper_limit + 1, dtype=dtype, device=device) * (
        sr / n_fft
    )
    # The axis f0 - f is decreasing in f, so both it and the spectrum are
    # flipped to give interp an increasing x axis. Masked query points are
    # set to 0 and their results are zeroed again below.
    low_freq_replica = interp(
        f0[..., None] - low_freq_axis.flip(-1),
        spec[..., : max_upper_limit + 1].flip(-1),
        low_freq_axis[..., : max_upper_limit - 1] * upper_limit_mask,
    )
    output = spec.clone()
    output[..., : max_upper_limit - 1] += low_freq_replica * upper_limit_mask
    return output
def nuttall(n: int, device: torch.types.Device) -> torch.Tensor:
    """Return an ``n``-point symmetric 4-term Nuttall window (float32).

    w[i] = 0.355768 - 0.487396 cos(2πi/(n-1)) + 0.144232 cos(4πi/(n-1))
           - 0.012604 cos(6πi/(n-1))
    """
    angle = torch.linspace(0, math.tau, n, device=device)
    harmonics = torch.arange(4, dtype=torch.float, device=device)
    weights = torch.tensor([0.355768, -0.487396, 0.144232, -0.012604], device=device)
    # Row h holds cos(h * angle); the window is a weighted sum of the rows.
    basis = torch.cos(harmonics[:, None] * angle)  # [4, n]
    return weights @ basis
def get_windowed_waveform(
    x: torch.Tensor,
    sr: int,
    f0: torch.Tensor,
    position: torch.Tensor,
    half_window_length_ratio: float,
    window_type: Literal["hann", "blackman"],
    n_fft: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Extract a pitch-adaptive windowed segment of ``x`` around each position.

    For each analysis point, ``half_window_length = round(ratio * sr / f0)``
    samples on either side of ``position * sr`` are gathered into an
    ``n_fft``-long buffer that starts at offset ``-half_window_length``.
    Sample indices are clamped to ``[0, x.size(-1) - 1]``. The buffer is
    multiplied by a Hann or Blackman window supported on
    ``[-half_window_length, half_window_length]``, and then the
    window-weighted mean is subtracted.

    Args:
        x: waveform, time on the last axis; its leading axes broadcast with
            the index tensor derived from ``f0`` / ``position``.
        sr: sample rate in Hz.
        f0: F0 in Hz. It is used as a divisor, so it must be non-zero; callers
            clamp it.
        position: analysis centre in seconds.
        half_window_length_ratio: half window length in pitch periods.
        window_type: "hann" or "blackman".
        n_fft: length of the returned buffers. Assumes
            2 * half_window_length + 1 <= n_fft; otherwise the window is
            truncated.

    Returns:
        ``(waveform, window)``, both [..., n_fft]. Entries outside the window
        support are zero.
    """
    x = torch.as_tensor(x)
    f0 = torch.as_tensor(f0)
    position = torch.as_tensor(position)
    current_sample = position * sr
    # [...]
    half_window_length = (half_window_length_ratio * sr / f0).add_(0.5).long()
    # [..., fft_size]
    base_index = -half_window_length[..., None] + torch.arange(n_fft, device=x.device)
    base_index_mask = base_index <= half_window_length[..., None]
    # [..., fft_size]; centre sample rounded with a +0.001 bias, then clamped
    safe_index = ((current_sample + 0.501).long()[..., None] + base_index).clamp_(
        0, x.size(-1) - 1
    )
    # [..., fft_size]
    time_axis = base_index.to(x.dtype).div_(half_window_length_ratio)
    # [...]
    normalized_f0 = math.pi / sr * f0
    # [..., fft_size]; reaches ±pi at the window edges
    phase = time_axis.mul_(normalized_f0[..., None])
    if window_type == "hann":
        window = phase.cos_().mul_(0.5).add_(0.5)
    elif window_type == "blackman":
        window = phase.mul(2.0).cos_().mul_(0.08).add_(phase.cos().mul_(0.5)).add_(0.42)
    else:
        assert False
    window *= base_index_mask
    # Common leading shape of x and the per-frame indices (excluding time).
    prefix_shape = tuple(
        max(x_size, i_size) for x_size, i_size in zip(x.size(), safe_index.size())
    )[:-1]
    waveform = (
        x.expand(prefix_shape + (-1,))
        .gather(-1, safe_index.expand(prefix_shape + (-1,)))
        .mul_(window)
    )
    if not D4C_PREVENT_ZERO_DIVISION:
        # Upstream behaviour: tiny noise instead of denominator clamping; the
        # mask keeps the noise inside the window support.
        waveform += torch.randn_like(window).mul_(1e-12)
        waveform *= base_index_mask
    # Remove the window-weighted mean (DC component).
    waveform -= window * waveform.sum(-1, keepdim=True).div_(
        window.sum(-1, keepdim=True)
    )
    return waveform, window
def get_centroid(x: torch.Tensor, n_fft: int) -> torch.Tensor:
    """Per-bin group-delay numerator ("centroid") of the windowed segment ``x``.

    ``x`` is first scaled to unit L2 norm along the last axis (the norm is
    floored at 6e-8 when ``D4C_PREVENT_ZERO_DIVISION``). The result is
    Re(X) Re(Y) + Im(X) Im(Y), where X = rfft(x) and Y = rfft(x * (i + 1) / n_fft).

    Returns:
        [..., n_fft // 2 + 1]
    """
    x = torch.as_tensor(x)
    norm = x.norm(dim=-1, keepdim=True)
    if D4C_PREVENT_ZERO_DIVISION:
        norm = norm.clamp(min=6e-8)
    unit = x / norm
    ramp = torch.arange(1, unit.size(-1) + 1, dtype=unit.dtype, device=unit.device).div_(
        n_fft
    )
    spec_plain = torch.fft.rfft(unit, n_fft)
    spec_ramped = torch.fft.rfft(unit * ramp, n_fft)
    return spec_plain.real * spec_ramped.real + spec_plain.imag * spec_ramped.imag
def get_static_centroid(
    x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, n_fft: int
) -> torch.Tensor:
    """First step: calculation of temporally static parameters on basis of group delay

    Sums the centroids of two Blackman-windowed segments centred a quarter
    period after and before ``position``, then applies ``dc_correction``.
    """
    quarter_period = 0.25 / f0
    centroids = []
    for shifted_position in (position + quarter_period, position - quarter_period):
        segment, _ = get_windowed_waveform(
            x, sr, f0, shifted_position, 2.0, "blackman", n_fft
        )
        centroids.append(get_centroid(segment, n_fft))
    return dc_correction(centroids[0] + centroids[1], sr, n_fft, f0)
def get_smoothed_power_spec(
    x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, n_fft: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Smoothed power spectrum of a Hann-windowed segment, plus its RMS.

    The RMS is the segment energy divided by the window energy, square-rooted.
    When ``D4C_PREVENT_ZERO_DIVISION`` is set, the segment is first divided by
    ``rms * sqrt(n_fft)`` (floored at 6e-8). The power spectrum is then
    DC-corrected and smoothed over ``f0`` Hz.

    Returns:
        ``(smoothed_power_spec [..., n_fft // 2 + 1], rms [...])``; ``rms`` is detached.
    """
    x = torch.as_tensor(x)
    f0 = torch.as_tensor(f0)
    segment, window = get_windowed_waveform(x, sr, f0, position, 2.0, "hann", n_fft)
    window_energy = window.square().sum(-1, keepdim=True)
    segment_energy = segment.square().sum(-1, keepdim=True)
    rms = (segment_energy / window_energy).sqrt_()
    if D4C_PREVENT_ZERO_DIVISION:
        normaliser = (rms * math.sqrt(n_fft)).clamp_(min=6e-8)
        segment = segment / normaliser
    power = torch.fft.rfft(segment, n_fft).abs().square_()
    power = dc_correction(power, sr, n_fft, f0)
    power = linear_smoothing(power, sr, n_fft, f0)
    return power, rms.detach().squeeze(-1)
def get_static_group_delay(
    static_centroid: torch.Tensor,
    smoothed_power_spec: torch.Tensor,
    sr: int,
    f0: torch.Tensor,
    n_fft: int,
) -> torch.Tensor:
    """Second step: calculation of parameter shaping

    Group delay t_g = centroid / power (the power is floored at 6e-8 when
    ``D4C_PREVENT_ZERO_DIVISION``). It is smoothed over f0/2 Hz (t_gs), and
    its f0-Hz-smoothed version (t_gb) is subtracted to give t_D.
    """
    if D4C_PREVENT_ZERO_DIVISION:
        denominator = smoothed_power_spec.clamp(min=6e-8)
    else:
        denominator = smoothed_power_spec
    raw_delay = static_centroid / denominator  # t_g
    fine_delay = linear_smoothing(raw_delay, sr, n_fft, f0 * 0.5)  # t_gs
    coarse_delay = linear_smoothing(fine_delay, sr, n_fft, f0)  # t_gb
    return fine_delay - coarse_delay  # t_D
def get_coarse_aperiodicity(
    group_delay: torch.Tensor,
    sr: int,
    n_fft: int,
    freq_interval: int,
    n_aperiodicities: int,
    window: torch.Tensor,
) -> torch.Tensor:
    """Third step: estimation of band-aperiodicity

    For band ``k`` (centred at ``freq_interval * (k + 1)`` Hz), the group delay
    around the centre bin is windowed, its power spectrum is sorted ascending
    and cumulatively summed. The band value is
    ``10 * log10(cum[n_fft // 2 - boundary - 1] / cum[-1])``, where
    ``boundary = round(8 * n_fft / len(window))``.

    Returns:
        [..., n_aperiodicities] coarse aperiodicity in dB (<= 0 in practice,
        since it is a ratio of partial to total sums).
    """
    group_delay = torch.as_tensor(group_delay)
    window = torch.as_tensor(window)
    window_length = window.size(-1)
    boundary = int(round(n_fft * 8 / window_length))
    half_width = window_length // 2
    split_index = n_fft // 2 - boundary - 1
    result = torch.empty(
        group_delay.size()[:-1] + (n_aperiodicities,),
        dtype=group_delay.dtype,
        device=group_delay.device,
    )
    for band in range(n_aperiodicities):
        center = freq_interval * (band + 1) * n_fft // sr
        lo, hi = center - half_width, center + half_width + 1
        segment = group_delay[..., lo:hi] * window
        power: torch.Tensor = torch.fft.rfft(segment, n_fft).abs().square_()
        cumulative = power.sort(-1).values.cumsum_(-1)
        if D4C_PREVENT_ZERO_DIVISION:
            cumulative.clamp_(min=6e-8)
        result[..., band] = cumulative[..., split_index] / cumulative[..., -1]
    return result.log10_().mul_(10.0)
def d4c_love_train(
    x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, threshold: float
) -> torch.Tensor:
    """Voiced/unvoiced decision used by D4C.

    A frame counts as voiced when its F0 is non-zero and, in the power
    spectrum of a Blackman-windowed segment (half window 1.5 periods, F0
    floored at 40 Hz, bins up to ~100 Hz zeroed), the energy in ~100 Hz to
    4 kHz exceeds ``threshold`` times the energy in ~100 Hz to 7.9 kHz.

    Args:
        x: waveform, time on the last axis.
        sr: sample rate in Hz.
        f0: F0 in Hz per frame; 0 marks an unvoiced frame.
        position: frame centre times in seconds.
        threshold: ratio threshold of the 4 kHz / 7.9 kHz cumulative energies.

    Returns:
        Boolean tensor, True where the frame is voiced.
    """
    x = torch.as_tensor(x)
    position = torch.as_tensor(position)
    f0: torch.Tensor = torch.as_tensor(f0)
    vuv = f0 != 0
    lowest_f0 = 40
    f0 = f0.clamp(min=lowest_f0)
    n_fft = 1 << (3 * sr // lowest_f0).bit_length()
    # Integer ceil(freq * n_fft / sr) for each band edge.
    boundary0 = (100 * n_fft - 1) // sr + 1
    boundary1 = (4000 * n_fft - 1) // sr + 1
    boundary2 = (7900 * n_fft - 1) // sr + 1
    waveform, _ = get_windowed_waveform(x, sr, f0, position, 1.5, "blackman", n_fft)
    power_spec = torch.fft.rfft(waveform, n_fft).abs().square_()
    power_spec[..., : boundary0 + 1] = 0.0
    cumulative_spec = power_spec.cumsum_(-1)
    vuv = vuv & (
        cumulative_spec[..., boundary1] > threshold * cumulative_spec[..., boundary2]
    )
    return vuv
def d4c_general_body(
    x: torch.Tensor,
    sr: int,
    f0: torch.Tensor,
    freq_interval: int,
    position: torch.Tensor,
    n_fft: int,
    n_aperiodicities: int,
    window: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the three D4C steps for every frame.

    Returns:
        ``(coarse_aperiodicity, rms)``. The coarse aperiodicity is in dB,
        shape [..., n_aperiodicities], with an F0-dependent offset of
        (f0 - 100) / 50 dB added and then capped at 0 dB. ``rms`` is the
        per-frame RMS from ``get_smoothed_power_spec``.
    """
    centroid = get_static_centroid(x, sr, f0, position, n_fft)
    power_spec, rms = get_smoothed_power_spec(x, sr, f0, position, n_fft)
    group_delay = get_static_group_delay(centroid, power_spec, sr, f0, n_fft)
    aperiodicity = get_coarse_aperiodicity(
        group_delay, sr, n_fft, freq_interval, n_aperiodicities, window
    )
    f0_offset = (f0[..., None] - 100.0).div_(50.0)
    aperiodicity.add_(f0_offset).clamp_(max=0.0)
    return aperiodicity, rms
def d4c(
    x: torch.Tensor,
    f0: torch.Tensor,
    t: torch.Tensor,
    sr: int,
    threshold: float = 0.85,
    n_fft_spec: Optional[int] = None,
    coarse_only: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Adapted from https://github.com/tuanad121/Python-WORLD/blob/master/world/d4c.py

    Batched band-aperiodicity estimation (D4C).

    Args:
        x: waveform(s), time on the last axis. A frame axis is inserted before
            time, so presumably x: [..., wav_length], f0: [..., n_frames],
            t: [n_frames] (TODO confirm).
        f0: F0 per frame in Hz; 0 marks an unvoiced frame (F0 is floored at
            47 Hz for the analysis itself).
        t: frame centre times in seconds.
        sr: sample rate in Hz; must be an even integer.
        threshold: voicing threshold passed to ``d4c_love_train``.
        n_fft_spec: FFT size of the output frequency axis; derived from ``sr``
            when None.
        coarse_only: return only the per-band values without interpolation
            or voicing.

    Returns:
        If ``coarse_only``: ``(coarse_aperiodicity_db, rms)``.
        Otherwise: ``(aperiodicity, coarse_aperiodicity_db)``. ``aperiodicity``
        is on a linear scale (10 ** (dB / 20)) over ``n_fft_spec // 2 + 1``
        bins, and is ~1 for frames judged unvoiced.
    """
    FLOOR_F0 = 71
    FLOOR_F0_D4C = 47
    UPPER_LIMIT = 15000
    FREQ_INTERVAL = 3000
    assert sr == int(sr)
    sr = int(sr)
    assert sr % 2 == 0
    x = torch.as_tensor(x)
    f0 = torch.as_tensor(f0)
    temporal_positions = torch.as_tensor(t)
    n_fft_d4c = 1 << (4 * sr // FLOOR_F0_D4C).bit_length()
    if n_fft_spec is None:
        n_fft_spec = 1 << (3 * sr // FLOOR_F0).bit_length()
    # Number of FREQ_INTERVAL-spaced bands below min(UPPER_LIMIT, Nyquist - FREQ_INTERVAL).
    n_aperiodicities = min(UPPER_LIMIT, sr // 2 - FREQ_INTERVAL) // FREQ_INTERVAL
    assert n_aperiodicities >= 1
    window_length = FREQ_INTERVAL * n_fft_d4c // sr * 2 + 1
    window = nuttall(window_length, device=x.device)
    freq_axis = torch.arange(n_fft_spec // 2 + 1, device=x.device) * (sr / n_fft_spec)
    coarse_aperiodicity, rms = d4c_general_body(
        x[..., None, :],
        sr,
        f0.clamp(min=FLOOR_F0_D4C),
        FREQ_INTERVAL,
        temporal_positions,
        n_fft_d4c,
        n_aperiodicities,
        window,
    )
    if coarse_only:
        return coarse_aperiodicity, rms
    even_coarse_axis = (
        torch.arange(n_aperiodicities + 3, device=x.device) * FREQ_INTERVAL
    )
    assert even_coarse_axis[-2] <= sr // 2 < even_coarse_axis[-1], sr
    # Below the last band centre: interpolate the band values, anchored at
    # -60 dB at 0 Hz.
    coarse_axis_low = (
        torch.arange(n_aperiodicities + 1, dtype=torch.float, device=x.device)
        * FREQ_INTERVAL
    )
    aperiodicity_low = interp(
        coarse_axis_low,
        F.pad(coarse_aperiodicity, (1, 0), value=-60.0),
        freq_axis[freq_axis < n_aperiodicities * FREQ_INTERVAL],
    )
    # From the last band centre to Nyquist: ramp from the last band value to
    # ~0 dB (-1e-12 dB) at sr / 2.
    coarse_axis_high = torch.tensor(
        [n_aperiodicities * FREQ_INTERVAL, sr * 0.5], device=x.device
    )
    aperiodicity_high = interp(
        coarse_axis_high,
        F.pad(coarse_aperiodicity[..., -1:], (0, 1), value=-1e-12),
        freq_axis[freq_axis >= n_aperiodicities * FREQ_INTERVAL],
    )
    aperiodicity = torch.cat([aperiodicity_low, aperiodicity_high], -1)
    # dB -> linear amplitude ratio
    aperiodicity = 10.0 ** (aperiodicity / 20.0)
    # Frames judged unvoiced are treated as fully aperiodic.
    vuv = d4c_love_train(x[..., None, :], sr, f0, temporal_positions, threshold)
    aperiodicity = torch.where(vuv[..., None], aperiodicity, 1 - 1e-12)
    return aperiodicity, coarse_aperiodicity
class Vocoder(nn.Module):
    """Frame-wise vocoder: periodic IR overlap-add + shaped noise + post filter.

    From per-frame features ``x`` and a per-frame pitch it renders two signals:
      * a periodic signal, built by overlap-adding per-frame impulse responses
        (from predicted log-amplitude and phase spectra) with the external
        ``overlap_add`` helper, driven by the pitch;
      * an aperiodic signal, built with ``generate_noise`` from per-bin gains
        predicted per frame.
    Both are then filtered with a predicted per-frame 512-tap post filter and
    summed.

    Args:
        channels: channel count of ``x``.
        hop_length: output samples per frame.
        n_pre_blocks: number of ConvNeXt blocks in the shared prenet.
        out_sample_rate: output sample rate in Hz (forwarded to overlap_add).
    """

    def __init__(
        self,
        channels: int,
        hop_length: int = 240,
        n_pre_blocks: int = 4,
        out_sample_rate: float = 24000.0,
    ):
        super().__init__()
        self.hop_length = hop_length
        self.out_sample_rate = out_sample_rate
        self.prenet = ConvNeXtStack(
            in_channels=channels,
            channels=channels,
            intermediate_channels=channels * 3,
            n_blocks=n_pre_blocks,
            delay=2,  # 20 ms delay (presumably 2 frames of 10 ms)
            embed_kernel_size=7,
            kernel_size=33,
            enable_scaling=True,
        )
        # Impulse-response branch: 512 channels per frame, split in forward
        # into 257 log-amplitudes and 255 phases.
        self.ir_generator = ConvNeXtStack(
            in_channels=channels,
            channels=channels,
            intermediate_channels=channels * 3,
            n_blocks=2,
            delay=0,
            embed_kernel_size=3,
            kernel_size=33,
            use_weight_standardization=True,
            enable_scaling=True,
        )
        self.ir_generator_post = WSConv1d(channels, 512, 1)
        # Fixed output scales, folded into the conv weights by merge_weights().
        self.register_buffer("ir_scale", torch.tensor(1.0))
        # Learnable window passed to overlap_add (its exact use is defined there).
        self.ir_window = nn.Parameter(torch.ones(512))
        # Aperiodicity branch: hop_length per-bin noise gains per frame.
        self.aperiodicity_generator = ConvNeXtStack(
            in_channels=channels,
            channels=channels,
            intermediate_channels=channels * 3,
            n_blocks=1,
            delay=0,
            embed_kernel_size=3,
            kernel_size=33,
            use_weight_standardization=True,
            enable_scaling=True,
        )
        self.aperiodicity_generator_post = WSConv1d(channels, hop_length, 1, bias=False)
        self.register_buffer("aperiodicity_scale", torch.tensor(0.005))
        # Post-filter branch: 512 filter taps per frame.
        self.post_filter_generator = ConvNeXtStack(
            in_channels=channels,
            channels=channels,
            intermediate_channels=channels * 3,
            n_blocks=1,
            delay=0,
            embed_kernel_size=3,
            kernel_size=33,
            use_weight_standardization=True,
            enable_scaling=True,
        )
        self.post_filter_generator_post = WSConv1d(channels, 512, 1, bias=False)
        self.register_buffer("post_filter_scale", torch.tensor(0.01))

    def forward(
        self, x: torch.Tensor, pitch: torch.Tensor
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Render a waveform from frame features and per-frame pitch.

        Returns:
            ``y_g_hat`` [batch_size, 1, length * hop_length], and a dict of
            detached intermediates: "periodic_signal", "aperiodic_signal"
            (both [batch_size, length * hop_length]) and "noise_excitation"
            (the raw noise used for the aperiodic part, 120 samples shorter).
        """
        # x: [batch_size, channels, length]
        # pitch: [batch_size, length]
        batch_size, _, length = x.size()

        x = self.prenet(x)
        ir = self.ir_generator(x)
        ir = F.silu(ir, inplace=True)
        # [batch_size, 512, length]
        ir = self.ir_generator_post(ir)
        ir *= self.ir_scale
        # First 257 channels: log-amplitude spectrum of the IR.
        ir_amp = ir[:, : ir.size(1) // 2 + 1, :].exp()
        # Remaining 255 channels: phases, zero-padded at DC and Nyquist to 257 bins.
        ir_phase = F.pad(ir[:, ir.size(1) // 2 + 1 :, :], (0, 0, 1, 1))
        # Adding pi to odd bins multiplies them by (-1)^k, i.e. circularly
        # shifts the IR by half its length in time.
        ir_phase[:, 1::2, :] += math.pi
        # TODO: fix the DC component only being able to take positive values

        # Nearest-neighbour interpolation of pitch to sample rate.
        # [batch_size, length * hop_length]
        pitch = torch.repeat_interleave(pitch, self.hop_length, dim=1)

        # [batch_size, length * hop_length]
        periodic_signal = overlap_add(
            ir_amp,
            ir_phase,
            self.ir_window,
            pitch,
            self.hop_length,
            delay=0,
            sr=self.out_sample_rate,
        )

        aperiodicity = self.aperiodicity_generator(x)
        aperiodicity = F.silu(aperiodicity, inplace=True)
        # [batch_size, hop_length, length]
        aperiodicity = self.aperiodicity_generator_post(aperiodicity)
        aperiodicity *= self.aperiodicity_scale
        # [batch_size, length * hop_length], [batch_size, length * hop_length]
        aperiodic_signal, noise_excitation = generate_noise(aperiodicity, delay=0)

        post_filter = self.post_filter_generator(x)
        post_filter = F.silu(post_filter, inplace=True)
        # [batch_size, 512, length]
        post_filter = self.post_filter_generator_post(post_filter)
        post_filter *= self.post_filter_scale
        # Bias tap 0 towards 1 so the filter starts close to identity.
        post_filter[:, 0, :] += 1.0
        # [batch_size, length, 512]
        post_filter = post_filter.transpose(1, 2)

        # Filtering is done in float32 with autocast disabled.
        with torch.amp.autocast("cuda", enabled=False):
            periodic_signal = periodic_signal.float()
            aperiodic_signal = aperiodic_signal.float()
            post_filter = post_filter.float()
            # 768-point FFT: for the default hop_length=240, 240 + 512 - 1 <= 768,
            # so the per-frame convolution does not wrap around.
            post_filter = torch.fft.rfft(post_filter, n=768)

            # [batch_size, length, 768]
            periodic_signal = torch.fft.irfft(
                torch.fft.rfft(
                    periodic_signal.view(batch_size, length, self.hop_length), n=768
                )
                * post_filter,
                n=768,
            )
            aperiodic_signal = torch.fft.irfft(
                torch.fft.rfft(
                    aperiodic_signal.view(batch_size, length, self.hop_length), n=768
                )
                * post_filter,
                n=768,
            )
            # Overlap-add the filtered frames.
            periodic_signal = F.fold(
                periodic_signal.transpose(1, 2),
                (1, (length - 1) * self.hop_length + 768),
                (1, 768),
                stride=(1, self.hop_length),
            ).squeeze_((1, 2))
            aperiodic_signal = F.fold(
                aperiodic_signal.transpose(1, 2),
                (1, (length - 1) * self.hop_length + 768),
                (1, 768),
                stride=(1, self.hop_length),
            ).squeeze_((1, 2))
            # Drop a fixed 120-sample offset (presumably compensating the post
            # filter's delay — TODO confirm) and keep length * hop_length samples.
            periodic_signal = periodic_signal[:, 120 : 120 + length * self.hop_length]
            aperiodic_signal = aperiodic_signal[:, 120 : 120 + length * self.hop_length]
            noise_excitation = noise_excitation[:, 120:]

        # TODO: the accuracy of the compensation becomes questionable. Is it
        # really still needed?

        # [batch_size, 1, length * hop_length]
        y_g_hat = (periodic_signal + aperiodic_signal)[:, None, :]

        y_g_hat = GradientEqualizerFunction.apply(y_g_hat)

        return y_g_hat, {
            "periodic_signal": periodic_signal.detach(),
            "aperiodic_signal": aperiodic_signal.detach(),
            "noise_excitation": noise_excitation.detach(),
        }

    def merge_weights(self):
        """Fold the fixed scale buffers into plain conv weights.

        Calls each sub-network's merge_weights() (defined elsewhere), then
        multiplies the post-conv weights (and the IR conv bias) by their scale
        buffers and resets those buffers to 1.
        """
        self.prenet.merge_weights()
        self.ir_generator.merge_weights()
        self.ir_generator_post.merge_weights()
        self.aperiodicity_generator.merge_weights()
        self.aperiodicity_generator_post.merge_weights()
        self.ir_generator_post.weight.data *= self.ir_scale
        self.ir_generator_post.bias.data *= self.ir_scale
        self.ir_scale.fill_(1.0)
        self.aperiodicity_generator_post.weight.data *= self.aperiodicity_scale
        self.aperiodicity_scale.fill_(1.0)
        self.post_filter_generator.merge_weights()
        self.post_filter_generator_post.merge_weights()
        self.post_filter_generator_post.weight.data *= self.post_filter_scale
        self.post_filter_scale.fill_(1.0)

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write the layers in a fixed order with ``dump_layer`` (defined elsewhere).

        ``f`` is a path (opened in "wb") or a writable binary file object;
        anything else raises TypeError. The scale buffers are not written, so
        presumably merge_weights() is expected to run first — TODO confirm.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError

        dump_layer(self.prenet, f)
        dump_layer(self.ir_generator, f)
        dump_layer(self.ir_generator_post, f)
        dump_layer(self.ir_window, f)
        dump_layer(self.aperiodicity_generator, f)
        dump_layer(self.aperiodicity_generator_post, f)
        dump_layer(self.post_filter_generator, f)
        dump_layer(self.post_filter_generator_post, f)
def compute_loudness(
    x: torch.Tensor, sr: int, win_lengths: list[int]
) -> list[torch.Tensor]:
    """Frame-wise log energy of a perceptually weighted signal.

    ``x`` is filtered with a +4 dB high shelf at 1500 Hz followed by a 38 Hz
    high-pass (parameters that appear to approximate ITU-R BS.1770
    K-weighting), squared, and then summed over Hann-weighted frames. Filtering
    is done chunk-wise by FFT convolution with a truncated (1025-tap) impulse
    response.

    The filter's frequency response is cached per ``sr`` on the function
    attribute ``compute_loudness.filter``. It is moved to ``x``'s device on
    every call, so the cache also works when calls use different devices.

    Args:
        x: [batch_size, wav_length] waveform.
        sr: sample rate in Hz.
        win_lengths: frame lengths in samples; the hop is win_length // 4.

    Returns:
        One tensor per window length, [batch_size, n_frames], holding
        ``log10(energy + win_length / 4 * 1e-5)``. For a filtered sine of
        amplitude 1 this is about log10(win_length / 4); a difference of 1
        corresponds to 10 dB.
    """
    # x: [batch_size, wav_length]
    assert x.ndim == 2
    n_fft = 2048
    chunk_length = n_fft // 2
    n_taps = chunk_length + 1
    results = []
    with torch.amp.autocast("cuda", enabled=False):
        if not hasattr(compute_loudness, "filter"):
            compute_loudness.filter = {}
        if sr not in compute_loudness.filter:
            # Impulse response of the weighting filter, computed in double.
            ir = torch.zeros(n_taps, device=x.device, dtype=torch.double)
            ir[0] = 0.5
            ir = torchaudio.functional.treble_biquad(
                ir, sr, 4.0, 1500.0, 1.0 / math.sqrt(2)
            )
            ir = torchaudio.functional.highpass_biquad(ir, sr, 38.0, 0.5)
            ir *= 2.0
            compute_loudness.filter[sr] = torch.fft.rfft(ir, n=n_fft).to(
                torch.complex64
            )
        # The cached response lives on the device of the call that created it;
        # move it to x's device so later calls on another device still work.
        response = compute_loudness.filter[sr].to(x.device)

        x = x.float()
        wav_length = x.size(-1)
        if wav_length % chunk_length != 0:
            x = F.pad(x, (0, chunk_length - wav_length % chunk_length))
        padded_wav_length = x.size(-1)
        x = x.view(x.size()[:-1] + (padded_wav_length // chunk_length, chunk_length))
        # Each chunk is convolved with the n_taps-long IR; n_fft = 2 * chunk_length
        # leaves room for the linear-convolution tail.
        x = torch.fft.irfft(torch.fft.rfft(x, n=n_fft) * response, n=n_fft)
        # Overlap-add the chunks and trim back to the original length.
        x = F.fold(
            x.transpose(-2, -1),
            (1, padded_wav_length + chunk_length),
            (1, n_fft),
            stride=(1, chunk_length),
        ).squeeze_((-3, -2))[..., :wav_length]
        x.square_()

        for win_length in win_lengths:
            hop_length = win_length // 4
            # [..., n_frames]
            energy = (
                x.unfold(-1, win_length, hop_length)
                .matmul(torch.hann_window(win_length, device=x.device))
                .add_(win_length / 4.0 * 1e-5)
                .log10_()
            )
            results.append(energy)
    return results
def slice_segments(
    x: torch.Tensor, start_indices: torch.Tensor, segment_length: int
) -> torch.Tensor:
    """Cut a ``segment_length``-long window from each batch item along the last axis.

    Args:
        x: [batch_size, channels, length]
        start_indices: Long[batch_size], start offset for each batch item.
        segment_length: number of samples to keep.

    Returns:
        [batch_size, channels, segment_length], where
        ``out[b] = x[b, :, start_indices[b] : start_indices[b] + segment_length]``.
    """
    batch_size, channels, _ = x.size()
    offsets = torch.arange(segment_length, device=start_indices.device)
    # [batch_size, 1, segment_length] -> [batch_size, channels, segment_length]
    gather_index = (start_indices.view(-1, 1, 1) + offsets).expand(
        batch_size, channels, segment_length
    )
    return torch.gather(x, 2, gather_index)
class ConverterNetwork(nn.Module):
    """Voice-conversion network around frozen phone and pitch extractors.

    The frozen ``phone_extractor`` and ``pitch_estimator`` turn the 16 kHz
    input into per-frame phone features, a pitch distribution and energy.
    These are embedded, summed with speaker and formant-shift embeddings, and
    rendered to 24 kHz audio by a ``Vocoder``.
    """

    def __init__(
        self,
        phone_extractor: PhoneExtractor,
        pitch_estimator: PitchEstimator,
        n_speakers: int,
        hidden_channels: int,
    ):
        super().__init__()
        # Kept in a plain dict rather than registered as submodules, so their
        # parameters are not part of this module's parameters()/state_dict().
        self.frozen_modules = {
            "phone_extractor": phone_extractor.eval().requires_grad_(False),
            "pitch_estimator": pitch_estimator.eval().requires_grad_(False),
        }

        self.out_sample_rate = out_sample_rate = 24000

        # The five summed embeddings are each initialised with variance ~2/5
        # (for unit-variance inputs), presumably so their sum has variance ~2.
        self.embed_phone = nn.Conv1d(256, hidden_channels, 1)
        self.embed_phone.weight.data.normal_(0.0, math.sqrt(2.0 / (256 * 5)))
        self.embed_phone.bias.data.zero_()

        # Fixed (non-trainable) sinusoidal embedding of the 384 pitch bins.
        self.embed_quantized_pitch = nn.Embedding(384, hidden_channels)
        phase = (
            torch.arange(384, dtype=torch.float)[:, None]
            * (
                torch.arange(0, hidden_channels, 2, dtype=torch.float)
                * (-math.log(10000.0) / hidden_channels)
            ).exp_()
        )
        self.embed_quantized_pitch.weight.data[:, 0::2] = phase.sin()
        self.embed_quantized_pitch.weight.data[:, 1::2] = phase.cos_()
        self.embed_quantized_pitch.weight.data *= math.sqrt(4.0 / 5.0)
        self.embed_quantized_pitch.weight.requires_grad_(False)

        self.embed_pitch_features = nn.Conv1d(4, hidden_channels, 1)
        self.embed_pitch_features.weight.data.normal_(0.0, math.sqrt(2.0 / (4 * 5)))
        self.embed_pitch_features.bias.data.zero_()

        self.embed_speaker = nn.Embedding(n_speakers, hidden_channels)
        self.embed_speaker.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0))

        # 9 formant-shift steps: index = round((semitone + 2) * 2).
        self.embed_formant_shift = nn.Embedding(9, hidden_channels)
        self.embed_formant_shift.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0))

        self.vocoder = Vocoder(
            channels=hidden_channels,
            hop_length=out_sample_rate // 100,
            n_pre_blocks=4,
            out_sample_rate=out_sample_rate,
        )

        # Multi-resolution mel spectrograms used by the mel loss.
        self.melspectrograms = nn.ModuleList()
        for win_length, n_mels in [
            (32, 5),
            (64, 10),
            (128, 20),
            (256, 40),
            (512, 80),
            (1024, 160),
            (2048, 320),
        ]:
            self.melspectrograms.append(
                torchaudio.transforms.MelSpectrogram(
                    sample_rate=out_sample_rate,
                    n_fft=win_length,
                    win_length=win_length,
                    hop_length=win_length // 4,
                    n_mels=n_mels,
                    power=2,
                    norm="slaney",
                    mel_scale="slaney",
                )
            )

    def _get_resampler(
        self, orig_freq, new_freq, device, cache={}
    ) -> torchaudio.transforms.Resample:
        """Return a cached ``Resample(orig_freq, new_freq)`` that lives on ``device``.

        ``cache`` is a deliberately shared mutable default, so resamplers are
        reused across calls and instances. The device is part of the key, so a
        resampler built for one device is never returned for another.
        """
        key = orig_freq, new_freq, torch.device(device)
        if key in cache:
            return cache[key]
        resampler = torchaudio.transforms.Resample(orig_freq, new_freq).to(
            device, non_blocking=True
        )
        cache[key] = resampler
        return resampler

    def forward(
        self,
        x: torch.Tensor,
        target_speaker_id: torch.Tensor,
        formant_shift_semitone: torch.Tensor,
        pitch_shift_semitone: Optional[torch.Tensor] = None,
        slice_start_indices: Optional[torch.Tensor] = None,
        slice_segment_length: Optional[int] = None,
        return_stats: bool = False,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
        """Convert ``x`` to the target speaker's voice.

        Feature extraction runs under ``torch.inference_mode``, and its outputs
        are cloned afterwards so they can take part in autograd. In training
        mode, the pitch estimate is augmented: each item is resampled towards
        a random target pitch bin, the estimator is run on the shifted audio,
        and its output is shifted back in bin index and stretched back in time
        to replace the original estimate.

        Args:
            x: [batch_size, 1, wav_length] input waveform.
            target_speaker_id: Long[batch_size].
            formant_shift_semitone: [batch_size], mapped to embedding index
                round((s + 2) * 2).
            pitch_shift_semitone: optional [batch_size] shift added to the
                quantized pitch of voiced (non-zero) frames.
            slice_start_indices: optional [batch_size] frame offsets; if given,
                only ``slice_segment_length`` frames per item are vocoded.
            slice_segment_length: required when ``slice_start_indices`` is given.
            return_stats: if True, also return the vocoder intermediates plus
                "pitch".

        Returns:
            [batch_size, 1, n_frames * 240] waveform, or (waveform, stats).
        """
        # x: [batch_size, 1, wav_length]
        # target_speaker_id: Long[batch_size]
        # formant_shift_semitone: [batch_size]
        # pitch_shift_semitone: [batch_size]
        # slice_start_indices: [batch_size]

        batch_size, _, _ = x.size()

        with torch.inference_mode():
            phone_extractor: PhoneExtractor = self.frozen_modules["phone_extractor"]
            pitch_estimator: PitchEstimator = self.frozen_modules["pitch_estimator"]
            # [batch_size, 1, wav_length] -> [batch_size, phone_channels, length]
            phone = phone_extractor.units(x).transpose(1, 2)
            # [batch_size, 1, wav_length] -> [batch_size, pitch_channels, length], [batch_size, 1, length]
            pitch, energy = pitch_estimator(x)
            # augmentation
            if self.training:
                # [batch_size, pitch_channels - 1]
                weights = pitch.softmax(1)[:, 1:, :].mean(2)
                # [batch_size]
                mean_pitch = (
                    weights * torch.arange(1, 384, device=weights.device)
                ).sum(1) / weights.sum(1)
                mean_pitch = mean_pitch.round_().long()
                target_pitch = torch.randint_like(mean_pitch, 64, 257)
                shift = target_pitch - mean_pitch
                shift_ratio = (
                    2.0 ** (shift.float() / pitch_estimator.bins_per_octave)
                ).tolist()
                shift = []
                interval_length = 100  # 1s
                interval_zeros = torch.zeros(
                    (1, 1, interval_length * 160), device=x.device
                )
                concatenated_shifted_x = []
                offsets = [0]
                # Disabled while the variable-length pitch-shift batch is
                # processed, presumably because input lengths vary a lot here.
                # It is switched back on in `finally`, so an exception cannot
                # leave it disabled.
                torch.backends.cudnn.benchmark = False
                try:
                    for i in range(batch_size):
                        shift_ratio_i = shift_ratio[i]
                        # Approximate the ratio by a small fraction so the
                        # resampler kernel stays small.
                        shift_ratio_fraction_i = Fraction.from_float(
                            shift_ratio_i
                        ).limit_denominator(30)
                        shift_numer_i = shift_ratio_fraction_i.numerator
                        shift_denom_i = shift_ratio_fraction_i.denominator
                        shift_ratio_i = shift_numer_i / shift_denom_i
                        shift_i = int(
                            round(
                                math.log2(shift_ratio_i)
                                * pitch_estimator.bins_per_octave
                            )
                        )
                        shift.append(shift_i)
                        shift_ratio[i] = shift_ratio_i
                        # [1, 1, wav_length / shift_ratio]
                        with torch.amp.autocast("cuda", enabled=False):
                            shifted_x_i = self._get_resampler(
                                shift_numer_i, shift_denom_i, x.device
                            )(x[i])[None]
                        if shifted_x_i.size(2) % 160 != 0:
                            shifted_x_i = F.pad(
                                shifted_x_i,
                                (0, 160 - shifted_x_i.size(2) % 160),
                                mode="reflect",
                            )
                        assert shifted_x_i.size(2) % 160 == 0
                        offsets.append(
                            offsets[-1]
                            + interval_length
                            + shifted_x_i.size(2) // 160
                        )
                        # Separate items with 1 s of silence.
                        concatenated_shifted_x.extend([interval_zeros, shifted_x_i])
                    if offsets[-1] % 256 != 0:
                        # Keeping lengths identical seems to make some cache
                        # effective and speeds things up, so pad to a multiple
                        # of 256 frames to reduce the number of distinct lengths.
                        concatenated_shifted_x.append(
                            torch.zeros(
                                (1, 1, (256 - offsets[-1] % 256) * 160),
                                device=x.device,
                            )
                        )
                    # [1, 1, total_length], total_length a multiple of 256 * 160
                    concatenated_shifted_x = torch.cat(concatenated_shifted_x, dim=2)
                    assert concatenated_shifted_x.size(2) % (256 * 160) == 0
                    # [1, pitch_channels, total_frames], [1, 1, total_frames]
                    concatenated_pitch, concatenated_energy = pitch_estimator(
                        concatenated_shifted_x
                    )
                    for i in range(batch_size):
                        shift_i = shift[i]
                        shift_ratio_i = shift_ratio[i]
                        left = offsets[i] + interval_length
                        right = offsets[i + 1]
                        pitch_i = concatenated_pitch[:, :, left:right]
                        energy_i = concatenated_energy[:, :, left:right]
                        # Stretch back to the original time scale.
                        pitch_i = F.interpolate(
                            pitch_i,
                            scale_factor=shift_ratio_i,
                            mode="linear",
                            align_corners=False,
                        )
                        energy_i = F.interpolate(
                            energy_i,
                            scale_factor=shift_ratio_i,
                            mode="linear",
                            align_corners=False,
                        )
                        assert pitch_i.size(2) == energy_i.size(2)
                        assert abs(pitch_i.size(2) - pitch.size(2)) <= 10
                        length = min(pitch_i.size(2), pitch.size(2))

                        # Undo the pitch shift in bin index (bin 0 is kept as
                        # is); vacated bins are filled with -10 logits.
                        if shift_i > 0:
                            pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length]
                            pitch[i : i + 1, 1:-shift_i, :length] = pitch_i[
                                :, 1 + shift_i :, :length
                            ]
                            pitch[i : i + 1, -shift_i:, :length] = -10.0
                        elif shift_i < 0:
                            pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length]
                            pitch[i : i + 1, 1 : 1 - shift_i, :length] = -10.0
                            pitch[i : i + 1, 1 - shift_i :, :length] = pitch_i[
                                :, 1:shift_i, :length
                            ]
                        energy[i : i + 1, :, :length] = energy_i[:, :, :length]
                finally:
                    torch.backends.cudnn.benchmark = True

            # [batch_size, pitch_channels, length] -> Long[batch_size, length], [batch_size, 3, length]
            quantized_pitch, pitch_features = pitch_estimator.sample_pitch(
                pitch, return_features=True
            )
            if pitch_shift_semitone is not None:
                quantized_pitch = torch.where(
                    quantized_pitch == 0,
                    quantized_pitch,
                    (
                        quantized_pitch
                        + (
                            pitch_shift_semitone[:, None]
                            * (pitch_estimator.bins_per_octave / 12.0)
                        )
                        .round_()
                        .long()
                    ).clamp_(1, 383),
                )
            # Pitch in Hz from the bin index (bin 0 maps to 55 Hz).
            pitch = 55.0 * 2.0 ** (
                quantized_pitch.float() / pitch_estimator.bins_per_octave
            )
            # phone looks ahead 2.5 ms, whereas energy looks ahead 12.5 ms and
            # pitch_features 22.5 ms, so shift them to line up with phone.
            energy = F.pad(energy[:, :, :-1], (1, 0), mode="reflect")
            quantized_pitch = F.pad(quantized_pitch[:, :-2], (2, 0), mode="reflect")
            pitch_features = F.pad(pitch_features[:, :, :-2], (2, 0), mode="reflect")
            # [batch_size, 1, length], [batch_size, 3, length] -> [batch_size, 4, length]
            pitch_features = torch.cat([energy, pitch_features], dim=1)
            formant_shift_indices = (
                ((formant_shift_semitone + 2.0) * 2.0).round_().long()
            )

        # Clone outside inference mode so the tensors can be used by autograd.
        phone = phone.clone()
        quantized_pitch = quantized_pitch.clone()
        pitch_features = pitch_features.clone()
        formant_shift_indices = formant_shift_indices.clone()
        pitch = pitch.clone()

        # [batch_size, hidden_channels, length]
        x = (
            self.embed_phone(phone)
            + self.embed_quantized_pitch(quantized_pitch).transpose(1, 2)
            + self.embed_pitch_features(pitch_features)
            + (
                self.embed_speaker(target_speaker_id)[:, :, None]
                + self.embed_formant_shift(formant_shift_indices)[:, :, None]
            )
        )
        if slice_start_indices is not None:
            assert slice_segment_length is not None
            # [batch_size, hidden_channels, length] -> [batch_size, hidden_channels, segment_length]
            x = slice_segments(x, slice_start_indices, slice_segment_length)
        x = F.silu(x, inplace=True)
        # [batch_size, hidden_channels, segment_length] -> [batch_size, 1, segment_length * 240]
        y_g_hat, stats = self.vocoder(x, pitch)
        stats["pitch"] = pitch
        if return_stats:
            return y_g_hat, stats
        else:
            return y_g_hat

    def _normalize_melsp(self, x):
        """Half natural log of a power mel spectrogram, floored at 1e-10."""
        return x.clamp(min=1e-10).log_().mul_(0.5)

    def forward_and_compute_loss(
        self,
        noisy_wavs_16k: torch.Tensor,
        target_speaker_id: torch.Tensor,
        formant_shift_semitone: torch.Tensor,
        slice_start_indices: torch.Tensor,
        slice_segment_length: int,
        y_all: torch.Tensor,
        enable_loss_ap: bool = False,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        dict[str, float],
    ]:
        """Run the generator on full utterances and compute reconstruction losses.

        The mel loss is the mean over resolutions of the L1 distance between
        log-mel spectrograms. The generated mel spectrogram is rescaled in its
        aperiodic share by ``reference / noise_excitation`` power, which
        compensates for the random excitation. If ``enable_loss_ap`` is set,
        a D4C coarse-aperiodicity MSE, weighted by frame RMS, is added.

        Returns:
            ``(y, y_hat, y_hat_all, loss_mel, loss_ap, stats)``: the target
            and generated slices [batch_size, 1, slice_segment_length * 240],
            the full generated waveform, both losses, and per-resolution mel
            losses as floats.
        """
        # noisy_wavs_16k: [batch_size, 1, wav_length]
        # target_speaker_id: Long[batch_size]
        # formant_shift_semitone: [batch_size]
        # slice_start_indices: [batch_size]
        # slice_segment_length: int
        # y_all: [batch_size, 1, wav_length]

        stats = {}
        loss_mel = 0.0

        # [batch_size, 1, wav_length] -> [batch_size, 1, n_frames * 240]
        y_hat_all, intermediates = self(
            noisy_wavs_16k,
            target_speaker_id,
            formant_shift_semitone,
            return_stats=True,
        )

        with torch.amp.autocast("cuda", enabled=False):
            periodic_signal = intermediates["periodic_signal"].float()
            aperiodic_signal = intermediates["aperiodic_signal"].float()
            noise_excitation = intermediates["noise_excitation"].float()
            periodic_signal = periodic_signal[:, : noise_excitation.size(1)]
            aperiodic_signal = aperiodic_signal[:, : noise_excitation.size(1)]
            y_hat_all = y_hat_all.float()
            y_hat_all_truncated = y_hat_all.squeeze(1)[:, : periodic_signal.size(1)]
            y_all_truncated = y_all.squeeze(1)[:, : periodic_signal.size(1)]

            for melspectrogram in self.melspectrograms:
                melsp_periodic_signal = melspectrogram(periodic_signal)
                melsp_aperiodic_signal = melspectrogram(aperiodic_signal)
                melsp_noise_excitation = melspectrogram(noise_excitation)
                # [1, n_mels, 1]
                # Expected per-bin power of the noise excitation:
                # (1/6) * 0.5 = 1/12 ... mean power of uniform noise on [-0.5, 0.5]
                # 3/8 ... mean power of the Hann window
                # win_length ... number of samples per frame
                reference_melsp = melspectrogram.mel_scale(
                    torch.full(
                        (1, melspectrogram.n_fft // 2 + 1, 1),
                        (1 / 6) * (3 / 8) * 0.5 * melspectrogram.win_length,
                        device=noisy_wavs_16k.device,
                    )
                )
                aperiodic_ratio = melsp_aperiodic_signal / (
                    melsp_periodic_signal + melsp_aperiodic_signal + 1e-5
                )
                compensation_ratio = reference_melsp / (melsp_noise_excitation + 1e-5)

                melsp_y_hat = melspectrogram(y_hat_all_truncated)
                melsp_y_hat = melsp_y_hat * (
                    (1.0 - aperiodic_ratio) + aperiodic_ratio * compensation_ratio
                )

                y_hat_mel = self._normalize_melsp(melsp_y_hat)
                y_mel = self._normalize_melsp(melspectrogram(y_all_truncated))
                loss_mel_i = F.l1_loss(y_hat_mel, y_mel)
                loss_mel += loss_mel_i
                stats[
                    f"loss_mel_{melspectrogram.win_length}_{melspectrogram.n_mels}"
                ] = loss_mel_i.item()

            loss_mel /= len(self.melspectrograms)

            if enable_loss_ap:
                # Frame times in seconds (10 ms frames).
                t = (
                    torch.arange(intermediates["pitch"].size(1), device=y_all.device)
                    * 0.01
                )
                y_coarse_aperiodicity, y_rms = d4c(
                    y_all.squeeze(1),
                    intermediates["pitch"],
                    t,
                    self.vocoder.out_sample_rate,
                    coarse_only=True,
                )
                # dB -> power ratio
                y_coarse_aperiodicity = 10.0 ** (y_coarse_aperiodicity / 10.0)
                y_hat_coarse_aperiodicity, y_hat_rms = d4c(
                    y_hat_all.squeeze(1),
                    intermediates["pitch"],
                    t,
                    self.vocoder.out_sample_rate,
                    coarse_only=True,
                )
                y_hat_coarse_aperiodicity = 10.0 ** (y_hat_coarse_aperiodicity / 10.0)
                # Down-weight near-silent frames.
                rms = torch.maximum(y_rms, y_hat_rms)
                loss_ap = F.mse_loss(
                    y_hat_coarse_aperiodicity, y_coarse_aperiodicity, reduction="none"
                )
                loss_ap *= (rms / (rms + 1e-3))[:, :, None]
                loss_ap = loss_ap.mean()
            else:
                loss_ap = torch.tensor(0.0)

        # [batch_size, 1, n_samples] -> [batch_size, 1, slice_segment_length * 240]
        y_hat = slice_segments(
            y_hat_all, slice_start_indices * 240, slice_segment_length * 240
        )

        # [batch_size, 1, n_samples] -> [batch_size, 1, slice_segment_length * 240]
        y = slice_segments(y_all, slice_start_indices * 240, slice_segment_length * 240)

        return y, y_hat, y_hat_all, loss_mel, loss_ap, stats

    def merge_weights(self):
        """Fold the vocoder's scale buffers into its weights (see Vocoder.merge_weights)."""
        self.vocoder.merge_weights()

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write the input embeddings and vocoder with ``dump_layer`` (defined elsewhere).

        ``f`` is a path (opened in "wb") or a writable binary file object;
        anything else raises TypeError. The speaker and formant-shift
        embeddings are not written here.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError

        dump_layer(self.embed_phone, f)
        dump_layer(self.embed_quantized_pitch, f)
        dump_layer(self.embed_pitch_features, f)
        dump_layer(self.vocoder, f)
# Discriminator
def _normalize(tensor: torch.Tensor, dim: int) -> torch.Tensor:
denom = tensor.norm(p=2.0, dim=dim, keepdim=True).clamp_min(1e-6)
return tensor / denom
class SANConv2d(nn.Conv2d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
bias: bool = True,
padding_mode="zeros",
device=None,
dtype=None,
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding=padding,
dilation=dilation,
groups=1,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
scale = self.weight.norm(p=2.0, dim=[1, 2, 3], keepdim=True).clamp_min(1e-6)
self.weight = nn.parameter.Parameter(self.weight / scale.expand_as(self.weight))
self.scale = nn.parameter.Parameter(scale.view(out_channels))
if bias:
self.bias = nn.parameter.Parameter(
torch.zeros(in_channels, device=device, dtype=dtype)
)
else:
self.register_parameter("bias", None)
def forward(
self, input: torch.Tensor, flg_san_train: bool = False
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.bias is not None:
input = input + self.bias.view(self.in_channels, 1, 1)
normalized_weight = self._get_normalized_weight()
scale = self.scale.view(self.out_channels, 1, 1)
if flg_san_train:
out_fun = F.conv2d(
input,
normalized_weight.detach(),
None,
self.stride,
self.padding,
self.dilation,
self.groups,
)
out_dir = F.conv2d(
input.detach(),
normalized_weight,
None,
self.stride,
self.padding,
self.dilation,
self.groups,
)
out = out_fun * scale, out_dir * scale.detach()
else:
out = F.conv2d(
input,
normalized_weight,
None,
self.stride,
self.padding,
self.dilation,
self.groups,
)
out = out * scale
return out
@torch.no_grad()
def normalize_weight(self):
self.weight.data = self._get_normalized_weight()
def _get_normalized_weight(self) -> torch.Tensor:
return _normalize(self.weight, dim=[1, 2, 3])
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return the one-sided padding ``dilation * (kernel_size - 1) // 2``.

    For an odd kernel at stride 1 this keeps the output length equal to the
    input length.
    """
    receptive_span = dilation * (kernel_size - 1)
    return receptive_span // 2
class DiscriminatorP(nn.Module):
    """Period discriminator: folds the waveform into ``period`` columns and
    applies a stack of (kernel_size x 1) 2-D convolutions.

    ``forward`` returns ``(logits, fmap)``. ``logits`` is flattened to
    ``[batch, -1]``, or is a ``(fun, dir)`` pair of such tensors when
    ``flg_san_train`` is true. ``fmap`` collects every intermediate
    activation plus the (fun) output.
    """

    def __init__(
        self, period: int, kernel_size: int = 5, stride: int = 3, san: bool = False
    ):
        super().__init__()
        self.period = period
        self.san = san
        padding = (get_padding(kernel_size, 1), 0)
        widths = [1, 32, 128, 512, 1024, 1024]
        # The last conv layer keeps stride 1; the others downsample along time.
        strides = [stride, stride, stride, stride, 1]
        self.convs = nn.ModuleList(
            weight_norm(
                nn.Conv2d(c_in, c_out, (kernel_size, 1), (s, 1), padding)
            )
            for c_in, c_out, s in zip(widths[:-1], widths[1:], strides)
        )
        if san:
            self.conv_post = SANConv2d(1024, 1, (3, 1), 1, (1, 0))
        else:
            self.conv_post = weight_norm(nn.Conv2d(1024, 1, (3, 1), 1, (1, 0)))

    def forward(
        self, x: torch.Tensor, flg_san_train: bool = False
    ) -> tuple[
        Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor]
    ]:
        features = []
        batch, channels, length = x.shape
        remainder = length % self.period
        if remainder:
            # Reflect-pad so the length becomes a multiple of the period.
            extra = self.period - remainder
            x = F.pad(x, (0, extra), "reflect")
            length += extra
        x = x.view(batch, channels, length // self.period, self.period)
        for layer in self.convs:
            x = F.silu(layer(x), inplace=True)
            features.append(x)
        x = (
            self.conv_post(x, flg_san_train=flg_san_train)
            if self.san
            else self.conv_post(x)
        )
        if not flg_san_train:
            features.append(x)
            return torch.flatten(x, 1, -1), features
        head_fun, head_dir = x
        features.append(head_fun)
        return (
            torch.flatten(head_fun, 1, -1),
            torch.flatten(head_dir, 1, -1),
        ), features
class DiscriminatorR(nn.Module):
    """Resolution discriminator operating on a linear STFT magnitude.

    Args:
        resolution: ``(n_fft, hop_length, win_length)`` of the STFT; must have
            exactly three entries.
        san: if true, the output layer is a ``SANConv2d``.

    ``forward`` returns ``(logits, fmap)``. ``logits`` is flattened to
    ``[batch, -1]``, or is a ``(fun, dir)`` pair of such tensors when
    ``flg_san_train`` is true. ``fmap`` collects every intermediate
    activation plus the (fun) output.
    """

    def __init__(
        self, resolution: Union[list[int], tuple[int, int, int]], san: bool = False
    ):
        super().__init__()
        self.resolution = resolution
        self.san = san
        assert len(self.resolution) == 3
        self.convs = nn.ModuleList(
            [
                weight_norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
                weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))),
                weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))),
                weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))),
                weight_norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
            ]
        )
        if san:
            self.conv_post = SANConv2d(32, 1, (3, 3), padding=(1, 1))
        else:
            self.conv_post = weight_norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def forward(
        self, x: torch.Tensor, flg_san_train: bool = False
    ) -> tuple[
        Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor]
    ]:
        fmap = []
        # [batch, 1, time] -> [batch, 1, freq, frames]
        x = self._spectrogram(x).unsqueeze(1)
        for l in self.convs:
            x = l(x)
            x = F.silu(x, inplace=True)
            fmap.append(x)
        if self.san:
            x = self.conv_post(x, flg_san_train=flg_san_train)
        else:
            x = self.conv_post(x)
        if flg_san_train:
            x_fun, x_dir = x
            fmap.append(x_fun)
            x_fun = torch.flatten(x_fun, 1, -1)
            x_dir = torch.flatten(x_dir, 1, -1)
            x = x_fun, x_dir
        else:
            fmap.append(x)
            x = torch.flatten(x, 1, -1)
        return x, fmap

    def _spectrogram(self, x: torch.Tensor) -> torch.Tensor:
        """Return the STFT magnitude ``[batch, n_fft // 2 + 1, frames]`` of ``x``.

        ``x`` is reflect-padded by ``(n_fft - hop_length) // 2`` on both sides,
        then squeezed on dim 1. A rectangular window is used, and the STFT is
        computed in float32 with CUDA autocast disabled.
        """
        n_fft, hop_length, win_length = self.resolution
        x = F.pad(
            x, ((n_fft - hop_length) // 2, (n_fft - hop_length) // 2), mode="reflect"
        ).squeeze(1)
        with torch.amp.autocast("cuda", enabled=False):
            mag = torch.stft(
                x.float(),
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                window=torch.ones(win_length, device=x.device),
                center=False,
                return_complex=True,
            ).abs()
        return mag
class MultiPeriodDiscriminator(nn.Module):
    """Ensemble of three STFT-resolution and five period discriminators.

    Real and generated waveforms are concatenated along the batch axis, passed
    through every sub-discriminator once, and split back.
    """

    def __init__(self, san: bool = False):
        super().__init__()
        resolutions = [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]]
        periods = [2, 3, 5, 7, 11]
        self.discriminators = nn.ModuleList(
            [DiscriminatorR(r, san=san) for r in resolutions]
            + [DiscriminatorP(p, san=san) for p in periods]
        )
        self.discriminator_names = [f"R_{n}_{h}_{w}" for n, h, w in resolutions] + [
            f"P_{p}" for p in periods
        ]
        self.san = san

    def forward(
        self, y: torch.Tensor, y_hat: torch.Tensor, flg_san_train: bool = False
    ) -> tuple[
        list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]],
        list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]],
        list[list[torch.Tensor]],
        list[list[torch.Tensor]],
    ]:
        """Return ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)``, one entry per sub-discriminator.

        When ``flg_san_train`` is true, each logits entry is a ``(fun, dir)`` pair.
        """
        batch_size = y.size(0)
        concatenated_y_y_hat = torch.cat([y, y_hat])
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            if flg_san_train:
                (y_d_fun, y_d_dir), fmap = d(
                    concatenated_y_y_hat, flg_san_train=flg_san_train
                )
                y_d_r_fun, y_d_g_fun = torch.split(y_d_fun, batch_size)
                y_d_r_dir, y_d_g_dir = torch.split(y_d_dir, batch_size)
                y_d_r = y_d_r_fun, y_d_r_dir
                y_d_g = y_d_g_fun, y_d_g_dir
            else:
                y_d, fmap = d(concatenated_y_y_hat, flg_san_train=flg_san_train)
                y_d_r, y_d_g = torch.split(y_d, batch_size)
            fmap_r = []
            fmap_g = []
            for fm in fmap:
                fm_r, fm_g = torch.split(fm, batch_size)
                fmap_r.append(fm_r)
                fmap_g.append(fm_g)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

    def forward_and_compute_loss(
        self, y: torch.Tensor, y_hat: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float]]:
        """Run all discriminators once and compute the three GAN losses.

        Returns ``(d_loss, adv_loss, fm_loss, stats)``:

        * ``d_loss``: discriminator loss summed over sub-discriminators.
          Without SAN it is ``mean((1 - D(y))^2) + mean(D(y_hat)^2)``. With SAN
          it uses the softplus-squared terms on the (fun, dir) outputs below.
        * ``adv_loss``: generator adversarial loss on ``D(y_hat)``. With SAN this
          uses the ``fun`` output, which is the only part whose gradient reaches
          the input.
        * ``fm_loss``: L1 feature-matching loss with real features detached.
        * ``stats``: per-discriminator scalar values for logging.

        All loss arithmetic runs in float32 with CUDA autocast disabled.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self(y, y_hat, flg_san_train=self.san)
        stats = {}
        assert len(y_d_gs) == len(y_d_rs) == len(self.discriminators)
        with torch.amp.autocast("cuda", enabled=False):
            # discriminator loss
            d_loss = 0.0
            for dr, dg, name in zip(y_d_rs, y_d_gs, self.discriminator_names):
                if self.san:
                    dr_fun, dr_dir = map(lambda x: x.float(), dr)
                    dg_fun, dg_dir = map(lambda x: x.float(), dg)
                    r_loss_fun = F.softplus(1.0 - dr_fun).square().mean()
                    g_loss_fun = F.softplus(dg_fun).square().mean()
                    r_loss_dir = F.softplus(1.0 - dr_dir).square().mean()
                    g_loss_dir = -F.softplus(1.0 - dg_dir).square().mean()
                    r_loss = r_loss_fun + r_loss_dir
                    g_loss = g_loss_fun + g_loss_dir
                else:
                    dr = dr.float()
                    dg = dg.float()
                    r_loss = (1.0 - dr).square().mean()
                    g_loss = dg.square().mean()
                stats[f"{name}_dr_loss"] = r_loss.item()
                stats[f"{name}_dg_loss"] = g_loss.item()
                d_loss += r_loss + g_loss
            # adversarial loss
            adv_loss = 0.0
            for dg, name in zip(y_d_gs, self.discriminator_names):
                if self.san:
                    # With SAN every output is a (fun, dir) pair; calling .float()
                    # on the tuple would raise. The "fun" part equals the plain
                    # output and carries the gradient w.r.t. the input.
                    dg = dg[0]
                dg = dg.float()
                if self.san:
                    g_loss = F.softplus(1.0 - dg).square().mean()
                else:
                    g_loss = (1.0 - dg).square().mean()
                stats[f"{name}_gg_loss"] = g_loss.item()
                adv_loss += g_loss
            # feature matching loss
            fm_loss = 0.0
            for fr, fg, name in zip(fmap_rs, fmap_gs, self.discriminator_names):
                fm_loss_i = 0.0
                for j, (r, g) in enumerate(zip(fr, fg)):
                    fm_loss_ij = (r.detach().float() - g.float()).abs().mean()
                    stats[f"~{name}_fm_loss_{j}"] = fm_loss_ij.item()
                    fm_loss_i += fm_loss_ij
                stats[f"{name}_fm_loss"] = fm_loss_i.item()
                fm_loss += fm_loss_i
        return d_loss, adv_loss, fm_loss, stats
# %% [markdown]
# ## Utilities
# %%
class GradBalancer:
"""Adapted from https://github.com/facebookresearch/encodec/blob/main/encodec/balancer.py"""
def __init__(
self,
weights: dict[str, float],
rescale_grads: bool = True,
total_norm: float = 1.0,
ema_decay: float = 0.999,
per_batch_item: bool = True,
):
self.weights = weights
self.per_batch_item = per_batch_item
self.total_norm = total_norm
self.ema_decay = ema_decay
self.rescale_grads = rescale_grads
self.ema_total: dict[str, float] = defaultdict(float)
self.ema_fix: dict[str, float] = defaultdict(float)
def backward(
self,
losses: dict[str, torch.Tensor],
input: torch.Tensor,
scaler: Optional[torch.amp.GradScaler] = None,
skip_update_ema: bool = False,
) -> dict[str, float]:
stats = {}
if skip_update_ema:
assert len(losses) == len(self.ema_total)
ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()}
else:
# 各 loss に対して d loss / d input とそのノルムを計算する
norms = {}
grads = {}
for name, loss in losses.items():
if scaler is not None:
loss = scaler.scale(loss)
(grad,) = torch.autograd.grad(loss, [input], retain_graph=True)
if not grad.isfinite().all():
input.backward(grad)
return {}
grad = grad.detach() / (1.0 if scaler is None else scaler.get_scale())
if self.per_batch_item:
dims = tuple(range(1, grad.dim()))
ema_norm = grad.norm(dim=dims).mean()
else:
ema_norm = grad.norm()
norms[name] = float(ema_norm)
grads[name] = grad
# ノルムの移動平均を計算する
for key, value in norms.items():
self.ema_total[key] = self.ema_total[key] * self.ema_decay + value
self.ema_fix[key] = self.ema_fix[key] * self.ema_decay + 1.0
ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()}
# ログを取る
total_ema_norm = sum(ema_norms.values())
for k, ema_norm in ema_norms.items():
stats[f"grad_norm_value_{k}"] = ema_norm
stats[f"grad_norm_ratio_{k}"] = ema_norm / (total_ema_norm + 1e-12)
# loss の係数の比率を計算する
if self.rescale_grads:
total_weights = sum([self.weights[k] for k in ema_norms])
ratios = {k: w / total_weights for k, w in self.weights.items()}
# 勾配を修正する
loss = 0.0
for name, ema_norm in ema_norms.items():
if self.rescale_grads:
scale = ratios[name] * self.total_norm / (ema_norm + 1e-12)
else:
scale = self.weights[name]
loss += (losses if skip_update_ema else grads)[name] * scale
if scaler is not None:
loss = scaler.scale(loss)
if skip_update_ema:
(loss,) = torch.autograd.grad(loss, [input])
input.backward(loss)
return stats
def state_dict(self) -> dict[str, dict[str, float]]:
return {
"ema_total": dict(self.ema_total),
"ema_fix": dict(self.ema_fix),
}
def load_state_dict(self, state_dict):
self.ema_total = defaultdict(float, state_dict["ema_total"])
self.ema_fix = defaultdict(float, state_dict["ema_fix"])
class QualityTester(nn.Module):
    """Automatic speech-quality scorer built on the UTMOS22 (strong) predictor.

    The predictor is fetched with ``torch.hub`` from ``tarepan/SpeechMOS:v1.0.0``
    and is called with ``sr=16000``.
    """

    def __init__(self):
        super().__init__()
        predictor = torch.hub.load(
            "tarepan/SpeechMOS:v1.0.0", "utmos22_strong", trust_repo=True
        )
        self.utmos = predictor.eval()

    @torch.inference_mode()
    def compute_mos(self, wav: torch.Tensor) -> dict[str, list[float]]:
        """Score a batch of 16 kHz waveforms; returns ``{"utmos": [score, ...]}``."""
        scores = self.utmos(wav, sr=16000)
        return {"utmos": scores.tolist()}

    def test(
        self, converted_wav: torch.Tensor, source_wav: torch.Tensor
    ) -> dict[str, list[float]]:
        """Return per-item metrics for one batch ``[batch_size, wav_length]``.

        ``source_wav`` is accepted for API symmetry but is not used by any metric yet.
        """
        return dict(self.compute_mos(converted_wav))

    def test_many(
        self, converted_wavs: list[torch.Tensor], source_wavs: list[torch.Tensor]
    ) -> tuple[dict[str, float], dict[str, list[float]]]:
        """Score several batches; returns ``(mean_per_metric, all_values_per_metric)``."""
        assert len(converted_wavs) == len(source_wavs)
        collected = defaultdict(list)
        for batch_pair in zip(converted_wavs, source_wavs):
            for metric_name, values in self.test(*batch_pair).items():
                collected[metric_name].extend(values)
        averages = {
            metric_name: sum(values) / len(values)
            for metric_name, values in collected.items()
        }
        return averages, collected
def compute_grad_norm(
model: nn.Module, return_stats: bool = False
) -> Union[float, dict[str, float]]:
total_norm = 0.0
stats = {}
for name, p in model.named_parameters():
if p.grad is None:
continue
param_norm = p.grad.data.norm().item()
if not math.isfinite(param_norm):
param_norm = p.grad.data.float().norm().item()
total_norm += param_norm * param_norm
if return_stats:
stats[f"grad_norm_{name}"] = param_norm
total_norm = math.sqrt(total_norm)
if return_stats:
return total_norm, stats
else:
return total_norm
def compute_mean_f0(
    files: list[Path], method: Literal["dio", "harvest"] = "dio"
) -> float:
    """Return the geometric mean of voiced F0 over ``files`` (NaN if none voiced).

    Each file is loaded with torchaudio, flattened with ``ravel`` (all channels
    end to end), and analyzed with the chosen pyworld F0 estimator. Frames with
    ``f0 <= 0`` are treated as unvoiced and ignored.

    Raises:
        ValueError: if ``method`` is neither ``"dio"`` nor ``"harvest"``
            (only checked while processing a file).
    """
    log_f0_sum = 0.0
    voiced_count = 0
    for path in files:
        wav, sr = torchaudio.load(path, backend="soundfile")
        signal = wav.ravel().numpy().astype(np.float64)
        if method == "dio":
            f0, _ = pyworld.dio(signal, sr)
        elif method == "harvest":
            f0, _ = pyworld.harvest(signal, sr)
        else:
            raise ValueError(f"Invalid method: {method}")
        voiced = f0[f0 > 0]
        log_f0_sum += float(np.log(voiced).sum())
        voiced_count += len(voiced)
    return math.exp(log_f0_sum / voiced_count) if voiced_count else math.nan
# %% [markdown]
# ## Dataset
# %%
def get_resampler(
    sr_before: int, sr_after: int, device="cpu", cache={}
) -> torchaudio.transforms.Resample:
    """Return a memoized ``Resample(sr_before, sr_after)`` placed on ``device``.

    The mutable default ``cache`` dict is intentional: it persists across
    calls and is keyed by ``(sr_before, sr_after, str(device))``.
    """
    device_key = device if isinstance(device, str) else str(device)
    key = (sr_before, sr_after, device_key)
    try:
        return cache[key]
    except KeyError:
        resampler = torchaudio.transforms.Resample(sr_before, sr_after).to(device_key)
        cache[key] = resampler
        return resampler
def convolve(signal: torch.Tensor, ir: torch.Tensor) -> torch.Tensor:
    """Linearly convolve ``signal`` with ``ir`` along the last axis via FFT.

    Leading axes broadcast. The result is truncated to ``signal``'s length,
    so it is the "same-start" part of the full convolution. The FFT size is
    the smallest power of two >= ``len(signal) + len(ir) - 1``, which avoids
    circular wrap-around.
    """
    out_length = signal.size(-1)
    fft_size = 1 << (out_length + ir.size(-1) - 2).bit_length()
    product = torch.fft.rfft(signal, n=fft_size) * torch.fft.rfft(ir, n=fft_size)
    full = torch.fft.irfft(product, n=fft_size)
    return full[..., :out_length]
def random_filter(audio: torch.Tensor) -> torch.Tensor:
    """Apply an independent random second-order IIR filter to each row of ``audio``.

    ``audio`` must be 2-D ``[channels, time]``. The feedback and feed-forward
    coefficients after the leading 1.0 are drawn uniformly from
    [-0.375, 0.375), and the output is not clamped.
    """
    assert audio.ndim == 2
    coeffs = torch.rand(audio.size(0), 6) * 0.75 - 0.375
    # The leading coefficient of both polynomials is fixed to 1.
    coeffs[:, 0] = 1.0
    coeffs[:, 3] = 1.0
    denominator, numerator = coeffs[:, :3], coeffs[:, 3:]
    return torchaudio.functional.lfilter(audio, denominator, numerator, clamp=False)
def get_noise(
    n_samples: int, sample_rate: float, files: list[Union[str, bytes, os.PathLike]]
) -> torch.Tensor:
    """Build a ``[1, n_samples]`` noise excerpt from randomly chosen mono ``files``.

    The function repeatedly picks a file at random, resamples it to
    ``sample_rate`` times a random factor in {0.9, 0.95, 1.0, 1.05, 1.1},
    applies ``random_filter``, and peak-normalizes it to about 0.99. It stops
    once at least ``n_samples`` have been gathered, then returns a random
    window of the concatenation.
    """
    speed_factors = [0.9, 0.95, 1.0, 1.05, 1.1]
    chunks = []
    gathered = 0
    while gathered < n_samples:
        path = files[torch.randint(0, len(files), ())]
        chunk, file_sr = torchaudio.load(path, backend="soundfile")
        assert chunk.size(0) == 1
        factor = speed_factors[torch.randint(0, len(speed_factors), ())]
        target_sr = int(round(sample_rate * factor))
        chunk = get_resampler(file_sr, target_sr)(chunk)
        chunk = random_filter(chunk)
        chunk *= 0.99 / (chunk.abs().max() + 1e-5)
        chunks.append(chunk)
        gathered += chunk.size(1)
    start = torch.randint(0, gathered - n_samples + 1, ())
    noise = torch.cat(chunks, dim=1)[:, start : start + n_samples]
    assert noise.size() == (1, n_samples), noise.size()
    return noise
def get_butterworth_lpf(
    cutoff_freq: int, sample_rate: int, cache={}
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return memoized ``(b, a)`` biquad coefficients of a 2nd-order Butterworth LPF.

    This is the standard biquad low-pass design with Q = sqrt(0.5),
    normalized so that ``a[0] == 1``. The mutable default ``cache`` is
    intentional: repeated calls with the same arguments return the same tuple.
    """
    key = (cutoff_freq, sample_rate)
    if key in cache:
        return cache[key]
    q = math.sqrt(0.5)
    w0 = math.tau * cutoff_freq / sample_rate
    cos_w0 = math.cos(w0)
    alpha = math.sin(w0) / (2.0 * q)
    a0 = 1.0 + alpha
    b1 = (1.0 - cos_w0) / a0
    b0 = b1 * 0.5
    a1 = -2.0 * cos_w0 / a0
    a2 = (1.0 - alpha) / a0
    coefficients = torch.tensor([b0, b1, b0]), torch.tensor([1.0, a1, a2])
    cache[key] = coefficients
    return coefficients
def augment_audio(
    clean: torch.Tensor,
    sample_rate: int,
    noise_files: list[Union[str, bytes, os.PathLike]],
    ir_files: list[Union[str, bytes, os.PathLike]],
) -> torch.Tensor:
    """Return a degraded copy of ``clean`` (shape ``[1, wav_length]``).

    Pipeline:
    1. Draw noise of the same length.
    2. Apply different random filters to clean and noise.
    3. With probability 0.5, convolve both with a random stereo impulse response.
    4. With probability 0.2, apply the same low-pass filter to both.
    5. Restore clean's original RMS.
    6. Mix in the noise at -20..-45 dB relative to clean, with levels measured
       as the 4th-root of the mean 4th power (which emphasizes peaks).
    """
    assert clean.size(0) == 1
    length = clean.size(1)
    snr_choices = [-20, -25, -30, -35, -40, -45]
    reference_rms = clean.square().mean().sqrt_()

    def pick(options):
        # Uniformly choose one element using torch's RNG.
        return options[torch.randint(0, len(options), ())]

    # Row 0: clean, row 1: noise; different random filters per row.
    stacked = torch.cat([clean, get_noise(length, sample_rate, noise_files)])
    stacked = random_filter(stacked)

    # Reverb: channel i of the IR is applied to row i.
    if torch.rand(()) < 0.5:
        ir, ir_sr = torchaudio.load(pick(ir_files), backend="soundfile")
        assert ir.size() == (2, ir_sr), ir.size()
        assert ir_sr == sample_rate, (ir_sr, sample_rate)
        stacked = convolve(stacked, ir)

    # Shared low-pass filter, after limiting the peak to 0.8.
    if torch.rand(()) < 0.2:
        peak = stacked.abs().max()
        if peak > 0.8:
            stacked /= peak * 1.25
        cutoff = pick([2000, 3000, 4000, 6000])
        b, a = get_butterworth_lpf(cutoff, sample_rate)
        stacked = torchaudio.functional.lfilter(stacked, a, b, clamp=False)

    clean, noise = stacked
    clean *= reference_rms / clean.square().mean().sqrt_()

    clean_level = clean.square().square_().mean().sqrt_().sqrt_()
    noise_level = noise.square().square_().mean().sqrt_().sqrt_()
    snr = pick(snr_choices)
    gain = 10.0 ** (snr / 20.0) * clean_level / (noise_level + 1e-5)
    return clean + noise * gain
class WavDataset(torch.utils.data.Dataset):
    """Training pairs of (clean ``out_sample_rate`` audio, noisy ``in_sample_rate`` audio).

    ``__getitem__``:
    * loads one file and keeps one random channel;
    * applies a random formant shift of -2..+2 semitones in 0.5 steps via resampling;
    * zero-pads ``wav_length`` samples on both sides;
    * builds the noisy input with ``augment_audio``;
    * randomizes the volume.

    ``collate`` crops ``wav_length``-sample windows so that a
    ``segment_length``-frame slice is centered on a random non-zero sample.

    Args:
        audio_files: ``(path, speaker_id)`` pairs.
        noise_files, ir_files: must both be given or both be ``None``.
    """

    def __init__(
        self,
        audio_files: list[tuple[Path, int]],
        in_sample_rate: int = 16000,
        out_sample_rate: int = 24000,
        wav_length: int = 4 * 24000,  # 4s
        segment_length: int = 100,  # 1s
        noise_files: Optional[list[Union[str, bytes, os.PathLike]]] = None,
        ir_files: Optional[list[Union[str, bytes, os.PathLike]]] = None,
    ):
        self.audio_files = audio_files
        self.in_sample_rate = in_sample_rate
        self.out_sample_rate = out_sample_rate
        self.wav_length = wav_length
        self.segment_length = segment_length
        self.noise_files = noise_files
        self.ir_files = ir_files
        if (noise_files is None) is not (ir_files is None):
            raise ValueError("noise_files and ir_files must be both None or not None")
        self.in_hop_length = in_sample_rate // 100
        self.out_hop_length = out_sample_rate // 100  # 10 ms frames

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, int, float]:
        """Return ``(clean_wav, noisy_wav_16k, speaker_id, formant_shift)``.

        Both waveforms are 1-D. ``formant_shift`` is a float number of semitones.
        """
        file, speaker_id = self.audio_files[index]
        clean_wav, sample_rate = torchaudio.load(file, backend="soundfile")
        if clean_wav.size(0) != 1:
            # Multi-channel input: keep one randomly chosen channel.
            ch = torch.randint(0, clean_wav.size(0), ())
            clean_wav = clean_wav[ch : ch + 1]
        formant_shift_candidates = [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]
        formant_shift = formant_shift_candidates[
            torch.randint(0, len(formant_shift_candidates), ()).item()
        ]
        # Resample by out_rate / (rate * 2**(shift/12)). The ratio is
        # approximated by a fraction with denominator <= 300 so that
        # get_resampler receives small integer "rates".
        resampler_fraction = Fraction(
            sample_rate / self.out_sample_rate * 2.0 ** (formant_shift / 12.0)
        ).limit_denominator(300)
        clean_wav = get_resampler(
            resampler_fraction.numerator, resampler_fraction.denominator
        )(clean_wav)
        assert clean_wav.size(0) == 1
        assert clean_wav.size(1) != 0
        clean_wav = F.pad(clean_wav, (self.wav_length, self.wav_length))
        if self.noise_files is None:
            # This path is disabled by `assert False`; the resampling below
            # only runs when assertions are stripped (python -O).
            assert False
            noisy_wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)(
                clean_wav
            )
        else:
            clean_wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)(
                clean_wav
            )
            noisy_wav_16k = augment_audio(
                clean_wav_16k, self.in_sample_rate, self.noise_files, self.ir_files
            )
        clean_wav = clean_wav.squeeze_(0)
        noisy_wav_16k = noisy_wav_16k.squeeze_(0)
        # Randomize the volume: clean's peak goes to [0.1, 0.999), then both
        # are halved until the noisy signal's peak is below 1.
        amplitude = torch.rand(()).item() * 0.899 + 0.1
        factor = amplitude / clean_wav.abs().max()
        clean_wav *= factor
        noisy_wav_16k *= factor
        while noisy_wav_16k.abs().max() >= 1.0:
            clean_wav *= 0.5
            noisy_wav_16k *= 0.5
        return clean_wav, noisy_wav_16k, speaker_id, formant_shift

    def __len__(self) -> int:
        return len(self.audio_files)

    def collate(
        self, batch: list[tuple[torch.Tensor, torch.Tensor, int, float]]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Crop and stack a batch of ``__getitem__`` outputs.

        Returns ``(clean_wavs, noisy_wavs, slice_starts, speaker_ids,
        formant_shifts)``. ``slice_starts`` is the frame index (units of
        ``out_hop_length``) of the ``segment_length``-frame slice inside each
        clean window.
        """
        assert self.wav_length % self.out_hop_length == 0
        length = self.wav_length // self.out_hop_length
        clean_wavs = []
        noisy_wavs = []
        slice_starts = []
        speaker_ids = []
        formant_shifts = []
        for clean_wav, noisy_wav, speaker_id, formant_shift in batch:
            # Pick one non-zero (voiced) sample at random.
            (voiced,) = clean_wav.nonzero(as_tuple=True)
            assert voiced.numel() != 0
            center = voiced[torch.randint(0, voiced.numel(), ()).item()].item()
            # Choose the slice so that the picked sample is at its center.
            slice_start = center - self.segment_length * self.out_hop_length // 2
            assert slice_start >= 0
            # Randomly crop wav_length samples that contain the slice.
            r = torch.randint(0, length - self.segment_length + 1, ()).item()
            offset = slice_start - r * self.out_hop_length
            clean_wavs.append(clean_wav[offset : offset + self.wav_length])
            offset_in_sample_rate = int(
                round(offset * self.in_sample_rate / self.out_sample_rate)
            )
            noisy_wavs.append(
                noisy_wav[
                    offset_in_sample_rate : offset_in_sample_rate
                    + length * self.in_hop_length
                ]
            )
            slice_start = r
            slice_starts.append(slice_start)
            speaker_ids.append(speaker_id)
            formant_shifts.append(formant_shift)
        clean_wavs = torch.stack(clean_wavs)
        noisy_wavs = torch.stack(noisy_wavs)
        slice_starts = torch.tensor(slice_starts)
        speaker_ids = torch.tensor(speaker_ids)
        formant_shifts = torch.tensor(formant_shifts)
        return (
            clean_wavs,  # [batch_size, wav_length]
            noisy_wavs,  # [batch_size, (wav_length // out_hop_length) * in_hop_length]
            slice_starts,  # Long[batch_size]
            speaker_ids,  # Long[batch_size]
            formant_shifts,  # Float[batch_size], semitones
        )
# %% [markdown]
# ## Train
# %%
# Lower-case file extensions (with the leading dot) accepted as audio inputs
# when scanning the training and test directories.
AUDIO_FILE_SUFFIXES = set(".wav .aif .aiff .fla .flac .oga .ogg .opus .mp3".split())
def prepare_training():
    """Set up everything needed for a training (or evaluation-only) run.

    Side effects: may create the output directory, a TensorBoard log,
    ``config.json``, and a copy of this script in the output directory.

    Returns a 24-tuple:
    ``(device, in_wav_dataset_dir, h, out_dir, speakers, test_filelist,
    training_loader, speaker_f0s, test_pitch_shifts, phone_extractor,
    pitch_estimator, net_g, net_d, optim_g, optim_d, grad_scaler,
    grad_balancer, resample_to_in_sample_rate, initial_iteration,
    scheduler_g, scheduler_d, dict_scalars, quality_tester, writer)``.

    Raises:
        ValueError: on missing input directories or audio, or when the output
            directory already holds checkpoints and ``resume`` is not set.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device={device}")

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    (h, in_wav_dataset_dir, out_dir, resume, skip_training) = (
        prepare_training_configs_for_experiment
        if is_notebook()
        else prepare_training_configs
    )()

    print("config:")
    pprint(h)
    print()
    h = AttrDict(h)

    if not in_wav_dataset_dir.is_dir():
        raise ValueError(f"{in_wav_dataset_dir} is not found.")
    if resume:
        latest_checkpoint_file = out_dir / "checkpoint_latest.pt"
        if not latest_checkpoint_file.is_file():
            raise ValueError(f"{latest_checkpoint_file} is not found.")
    else:
        if out_dir.is_dir():
            if (out_dir / "checkpoint_latest.pt").is_file():
                raise ValueError(
                    f"{out_dir / 'checkpoint_latest.pt'} already exists. "
                    "Please specify a different output directory, or use --resume option."
                )
            for file in out_dir.iterdir():
                if file.suffix == ".pt":
                    raise ValueError(
                        f"{out_dir} already contains model files. "
                        "Please specify a different output directory."
                    )
        else:
            out_dir.mkdir(parents=True)

    in_ir_wav_dir = repo_root() / h.in_ir_wav_dir
    in_noise_wav_dir = repo_root() / h.in_noise_wav_dir
    in_test_wav_dir = repo_root() / h.in_test_wav_dir

    assert in_wav_dataset_dir.is_dir(), in_wav_dataset_dir
    assert out_dir.is_dir(), out_dir
    assert in_ir_wav_dir.is_dir(), in_ir_wav_dir
    assert in_noise_wav_dir.is_dir(), in_noise_wav_dir
    assert in_test_wav_dir.is_dir(), in_test_wav_dir

    # Recursively collect *.wav and *.flac files.
    noise_files = sorted(
        list(in_noise_wav_dir.rglob("*.wav")) + list(in_noise_wav_dir.rglob("*.flac"))
    )
    if len(noise_files) == 0:
        raise ValueError(f"No audio data found in {in_noise_wav_dir}.")
    ir_files = sorted(
        list(in_ir_wav_dir.rglob("*.wav")) + list(in_ir_wav_dir.rglob("*.flac"))
    )
    if len(ir_files) == 0:
        raise ValueError(f"No audio data found in {in_ir_wav_dir}.")

    # TODO: silence removal, etc.
    def get_training_filelist(in_wav_dataset_dir: Path):
        # One speaker per sub-directory; speaker ids follow sorted directory order.
        min_data_per_speaker = 1
        speakers: list[str] = []
        training_filelist: list[tuple[Path, int]] = []
        speaker_audio_files: list[list[Path]] = []
        for speaker_dir in sorted(in_wav_dataset_dir.iterdir()):
            if not speaker_dir.is_dir():
                continue
            candidates = []
            for wav_file in sorted(speaker_dir.rglob("*")):
                if (
                    not wav_file.is_file()
                    or wav_file.suffix.lower() not in AUDIO_FILE_SUFFIXES
                ):
                    continue
                candidates.append(wav_file)
            if len(candidates) >= min_data_per_speaker:
                speaker_id = len(speakers)
                speakers.append(speaker_dir.name)
                training_filelist.extend([(file, speaker_id) for file in candidates])
                speaker_audio_files.append(candidates)
        return speakers, training_filelist, speaker_audio_files

    speakers, training_filelist, speaker_audio_files = get_training_filelist(
        in_wav_dataset_dir
    )
    n_speakers = len(speakers)
    if n_speakers == 0:
        raise ValueError(f"No speaker data found in {in_wav_dataset_dir}.")
    print(f"{n_speakers=}")
    for i, speaker in enumerate(speakers):
        print(f" {i:{len(str(n_speakers - 1))}d}: {speaker}")
    print()
    print(f"{len(training_filelist)=}")

    def get_test_filelist(
        in_test_wav_dir: Path, n_speakers: int
    ) -> list[tuple[Path, list[int]]]:
        # Each test file gets min(8, n_speakers) target speaker ids; with more
        # than 8 speakers they cycle through seeded random permutations.
        max_n_test_files = 1000
        test_filelist = []
        rng = Random(42)

        def get_target_id_generator():
            if n_speakers > 8:
                while True:
                    order = list(range(n_speakers))
                    rng.shuffle(order)
                    yield from order
            else:
                while True:
                    yield from range(n_speakers)

        target_id_generator = get_target_id_generator()
        for file in sorted(in_test_wav_dir.iterdir())[:max_n_test_files]:
            if file.suffix.lower() not in AUDIO_FILE_SUFFIXES:
                continue
            target_ids = [next(target_id_generator) for _ in range(min(8, n_speakers))]
            test_filelist.append((file, target_ids))
        return test_filelist

    test_filelist = get_test_filelist(in_test_wav_dir, n_speakers)
    if len(test_filelist) == 0:
        # Report the directory that was searched (the empty list itself is uninformative).
        warnings.warn(f"No audio data found in {in_test_wav_dir}.")
    print(f"{len(test_filelist)=}")
    for file, target_ids in test_filelist[:12]:
        print(f" {file}, {target_ids}")
    if len(test_filelist) > 12:
        print(" ...")
    print()

    # Data
    training_dataset = WavDataset(
        training_filelist,
        in_sample_rate=h.in_sample_rate,
        out_sample_rate=h.out_sample_rate,
        wav_length=h.wav_length,
        segment_length=h.segment_length,
        noise_files=noise_files,
        ir_files=ir_files,
    )
    # os.cpu_count() may return None.
    num_workers = min(h.num_workers, os.cpu_count() or 1)
    training_loader = torch.utils.data.DataLoader(
        training_dataset,
        num_workers=num_workers,
        collate_fn=training_dataset.collate,
        shuffle=True,
        sampler=None,
        batch_size=h.batch_size,
        pin_memory=True,
        drop_last=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=num_workers > 0,
    )

    print("Computing mean F0s of target speakers...", end="")
    speaker_f0s = []
    for speaker, files in enumerate(speaker_audio_files):
        if len(files) > 10:
            files = Random(42).sample(files, 10)
        f0 = compute_mean_f0(files)
        speaker_f0s.append(f0)
        if speaker % 5 == 0:
            print()
        print(f" {speaker:3d}: {f0:.1f}Hz", end=",")
    print()
    print("Done.")
    print("Computing pitch shifts for test files...")
    test_pitch_shifts = []
    source_f0s = []
    for i, (file, target_ids) in enumerate(tqdm(test_filelist)):
        source_f0 = compute_mean_f0([file], method="harvest")
        source_f0s.append(source_f0)
        if math.isnan(source_f0):
            test_pitch_shifts.append([0] * len(target_ids))
            continue
        pitch_shifts = []
        for target_id in target_ids:
            target_f0 = speaker_f0s[target_id]
            if math.isnan(target_f0):
                pitch_shift = 0
            else:
                # Semitone distance between the mean F0s, rounded to an integer.
                pitch_shift = int(round(12.0 * math.log2(target_f0 / source_f0)))
            pitch_shifts.append(pitch_shift)
        test_pitch_shifts.append(pitch_shifts)
    print("Done.")

    # Models and optimizers
    phone_extractor = PhoneExtractor().to(device).eval().requires_grad_(False)
    phone_extractor_checkpoint = torch.load(
        repo_root() / h.phone_extractor_file, map_location="cpu", weights_only=True
    )
    print(
        phone_extractor.load_state_dict(phone_extractor_checkpoint["phone_extractor"])
    )
    del phone_extractor_checkpoint

    pitch_estimator = PitchEstimator().to(device).eval().requires_grad_(False)
    pitch_estimator_checkpoint = torch.load(
        repo_root() / h.pitch_estimator_file, map_location="cpu", weights_only=True
    )
    print(
        pitch_estimator.load_state_dict(pitch_estimator_checkpoint["pitch_estimator"])
    )
    del pitch_estimator_checkpoint

    net_g = ConverterNetwork(
        phone_extractor,
        pitch_estimator,
        n_speakers,
        h.hidden_channels,
    ).to(device)
    net_d = MultiPeriodDiscriminator(san=h.san).to(device)

    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        h.learning_rate_g,
        betas=h.adam_betas,
        eps=h.adam_eps,
    )
    optim_d = torch.optim.AdamW(
        net_d.parameters(),
        h.learning_rate_d,
        betas=h.adam_betas,
        eps=h.adam_eps,
    )

    grad_scaler = torch.amp.GradScaler("cuda", enabled=h.use_amp)
    grad_balancer = GradBalancer(
        weights={
            "loss_mel": h.grad_weight_mel,
            "loss_adv": h.grad_weight_adv,
            "loss_fm": h.grad_weight_fm,
        }
        | ({"loss_ap": h.grad_weight_ap} if h.grad_weight_ap else {}),
        ema_decay=h.grad_balancer_ema_decay,
    )
    resample_to_in_sample_rate = torchaudio.transforms.Resample(
        h.out_sample_rate, h.in_sample_rate
    ).to(device)

    # Load checkpoint
    initial_iteration = 0
    if resume:
        checkpoint_file = latest_checkpoint_file
    elif h.pretrained_file is not None:
        checkpoint_file = repo_root() / h.pretrained_file
    else:
        checkpoint_file = None
    if checkpoint_file is not None:
        checkpoint = torch.load(checkpoint_file, map_location="cpu", weights_only=True)
        if not resume and not skip_training:  # fine-tuning
            checkpoint_n_speakers = len(checkpoint["net_g"]["embed_speaker.weight"])
            initial_speaker_embedding = checkpoint["net_g"][
                "embed_speaker.weight"
            ].mean(0, keepdim=True)
            if True:
                # Every new speaker starts from the mean pretrained embedding.
                checkpoint["net_g"]["embed_speaker.weight"] = initial_speaker_embedding[
                    [0] * n_speakers
                ]
            else:  # for appending speakers to an existing model
                assert n_speakers > checkpoint_n_speakers
                print(
                    f"embed_speaker.weight was padded: {checkpoint_n_speakers} -> {n_speakers}"
                )
                checkpoint["net_g"]["embed_speaker.weight"] = F.pad(
                    checkpoint["net_g"]["embed_speaker.weight"],
                    (0, 0, 0, n_speakers - checkpoint_n_speakers),
                )
                checkpoint["net_g"]["embed_speaker.weight"][
                    checkpoint_n_speakers:
                ] = initial_speaker_embedding
        print(net_g.load_state_dict(checkpoint["net_g"], strict=False))
        print(net_d.load_state_dict(checkpoint["net_d"], strict=False))
        if resume or skip_training:
            optim_g.load_state_dict(checkpoint["optim_g"])
            optim_d.load_state_dict(checkpoint["optim_d"])
            initial_iteration = checkpoint["iteration"]
        grad_balancer.load_state_dict(checkpoint["grad_balancer"])
        grad_scaler.load_state_dict(checkpoint["grad_scaler"])

    # Schedulers
    def get_cosine_annealing_warmup_scheduler(
        optimizer: torch.optim.Optimizer,
        warmup_epochs: int,
        total_epochs: int,
        min_learning_rate: float,
    ) -> torch.optim.lr_scheduler.LambdaLR:
        # Linear warmup from 0 to the base LR, cosine decay to min_learning_rate
        # at total_epochs, then constant min_learning_rate.
        # LambdaLR multiplies the *initial* LR by lr_lambda(epoch). After
        # optimizer.load_state_dict (resume), param_groups[0]["lr"] holds the
        # already-scheduled LR, while "initial_lr" holds the base LR, so prefer
        # the latter when present.
        group = optimizer.param_groups[0]
        base_learning_rate = group.get("initial_lr", group["lr"])
        lr_ratio = min_learning_rate / base_learning_rate
        m = 0.5 * (1.0 - lr_ratio)
        a = 0.5 * (1.0 + lr_ratio)

        def lr_lambda(current_epoch: int) -> float:
            if current_epoch < warmup_epochs:
                return current_epoch / warmup_epochs
            elif current_epoch < total_epochs:
                rate = (current_epoch - warmup_epochs) / (total_epochs - warmup_epochs)
                return math.cos(rate * math.pi) * m + a
            else:
                # A multiplier, not an absolute LR: stay at min_learning_rate.
                return lr_ratio

        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    scheduler_g = get_cosine_annealing_warmup_scheduler(
        optim_g, h.warmup_steps, h.n_steps, h.min_learning_rate_g
    )
    scheduler_d = get_cosine_annealing_warmup_scheduler(
        optim_d, h.warmup_steps, h.n_steps, h.min_learning_rate_d
    )
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message=r"Detected call of `lr_scheduler\.step\(\)` before `optimizer\.step\(\)`\.",
        )
        # Fast-forward the schedules to the iteration being resumed.
        for _ in range(initial_iteration + 1):
            scheduler_g.step()
            scheduler_d.step()

    net_g.train()
    net_d.train()

    # Logging etc.
    dict_scalars = defaultdict(list)
    quality_tester = QualityTester().eval().to(device)
    if skip_training:
        writer = None
    else:
        writer = SummaryWriter(out_dir)
        writer.add_text(
            "log",
            f"start training w/ {torch.cuda.get_device_name(device) if torch.cuda.is_available() else 'cpu'}.",
            initial_iteration,
        )
    if not resume:
        with open(out_dir / "config.json", "w", encoding="utf-8") as f:
            json.dump(dict(h), f, indent=4)
        if not is_notebook():
            shutil.copy(__file__, out_dir)

    return (
        device,
        in_wav_dataset_dir,
        h,
        out_dir,
        speakers,
        test_filelist,
        training_loader,
        speaker_f0s,
        test_pitch_shifts,
        phone_extractor,
        pitch_estimator,
        net_g,
        net_d,
        optim_g,
        optim_d,
        grad_scaler,
        grad_balancer,
        resample_to_in_sample_rate,
        initial_iteration,
        scheduler_g,
        scheduler_d,
        dict_scalars,
        quality_tester,
        writer,
    )
if __name__ == "__main__":
    # Set up everything the training loop needs. The names are bound
    # positionally, so this order must match the tuple returned by
    # prepare_training() exactly. `writer` is None when prepare_training()
    # skipped training; the training loop below is guarded on that.
    (
        device,
        in_wav_dataset_dir,
        h,
        out_dir,
        speakers,
        test_filelist,
        training_loader,
        speaker_f0s,
        test_pitch_shifts,
        phone_extractor,
        pitch_estimator,
        net_g,
        net_d,
        optim_g,
        optim_d,
        grad_scaler,
        grad_balancer,
        resample_to_in_sample_rate,
        initial_iteration,
        scheduler_g,
        scheduler_d,
        dict_scalars,
        quality_tester,
        writer,
    ) = prepare_training()
if __name__ == "__main__" and writer is not None:
    # Optional torch.compile of hot paths. ConvNeXtStack.forward is switched to
    # the compiled version only around the generator training forward pass
    # (step 2.1) and restored right after it, so the other forward passes in
    # this loop (feature-statistics pass, validation) use the eager version.
    if h.compile_convnext:
        raw_convnextstack_forward = ConvNeXtStack.forward
        compiled_convnextstack_forward = torch.compile(
            ConvNeXtStack.forward, mode="reduce-overhead"
        )
    if h.compile_d4c:
        d4c = torch.compile(d4c, mode="reduce-overhead")
    if h.compile_discriminator:
        MultiPeriodDiscriminator.forward_and_compute_loss = torch.compile(
            MultiPeriodDiscriminator.forward_and_compute_loss, mode="reduce-overhead"
        )

    # Training
    # Create the data iterator up front. Inside the loop only StopIteration
    # (loader exhausted) restarts it; any other error -- KeyboardInterrupt,
    # DataLoader worker failures, ... -- propagates instead of being retried.
    data_iter = iter(training_loader)
    with (
        torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=1500, warmup=10, active=5, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(out_dir),
            record_shapes=True,
            with_stack=True,
            profile_memory=True,
            with_flops=True,
        )
        if h.profile
        else nullcontext()
    ) as profiler:
        for iteration in tqdm(range(initial_iteration, h.n_steps)):
            # Validation (4) and saving (5) share one schedule on 1-based step
            # numbers: every 2000 steps (every 50000 when n_steps > 200000),
            # plus steps 1, 30000 and the final step.
            is_eval_step = (iteration + 1) % (
                50000 if h.n_steps > 200000 else 2000
            ) == 0 or iteration + 1 in {1, 30000, h.n_steps}

            # === 1. Data preprocessing ===
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(training_loader)
                batch = next(data_iter)
            (
                clean_wavs,
                noisy_wavs_16k,
                slice_starts,
                speaker_ids,
                formant_shift_semitone,
            ) = map(lambda x: x.to(device, non_blocking=True), batch)

            # === 2. Training ===
            with torch.amp.autocast("cuda", enabled=h.use_amp):
                # === 2.1 Generator forward ===
                if h.compile_convnext:
                    ConvNeXtStack.forward = compiled_convnextstack_forward
                y, y_hat, y_hat_for_backward, loss_mel, loss_ap, generator_stats = (
                    net_g.forward_and_compute_loss(
                        noisy_wavs_16k[:, None, :],
                        speaker_ids,
                        formant_shift_semitone,
                        slice_start_indices=slice_starts,
                        slice_segment_length=h.segment_length,
                        y_all=clean_wavs[:, None, :],
                        enable_loss_ap=h.grad_weight_ap != 0.0,
                    )
                )
                if h.compile_convnext:
                    ConvNeXtStack.forward = raw_convnextstack_forward
                assert y_hat.isfinite().all()
                assert loss_mel.isfinite().all()
                assert loss_ap.isfinite().all()

                # === 2.2 Discriminator forward ===
                loss_discriminator, loss_adv, loss_fm, discriminator_stats = (
                    net_d.forward_and_compute_loss(y, y_hat)
                )
                assert loss_discriminator.isfinite().all()
                assert loss_adv.isfinite().all()
                assert loss_fm.isfinite().all()

            # === 2.3 Discriminator backward ===
            for param in net_d.parameters():
                assert param.grad is None
            # retain_graph: the generator backward below reuses the graph.
            grad_scaler.scale(loss_discriminator).backward(
                retain_graph=True, inputs=list(net_d.parameters())
            )
            loss_discriminator = loss_discriminator.item()
            grad_scaler.unscale_(optim_d)
            # Gradient norms are only computed every 5 iterations.
            if iteration % 5 == 0:
                grad_norm_d, d_grad_norm_stats = compute_grad_norm(net_d, True)
            else:
                grad_norm_d = math.nan
                d_grad_norm_stats = {}

            # === 2.4 Generator backward ===
            for param in net_g.parameters():
                assert param.grad is None
            gradient_balancer_stats = grad_balancer.backward(
                {
                    "loss_mel": loss_mel,
                    "loss_adv": loss_adv,
                    "loss_fm": loss_fm,
                }
                | ({"loss_ap": loss_ap} if h.grad_weight_ap else {}),
                y_hat_for_backward,
                grad_scaler,
                skip_update_ema=iteration > 10 and iteration % 5 != 0,
            )
            loss_mel = loss_mel.item()
            loss_adv = loss_adv.item()
            loss_fm = loss_fm.item()
            if h.grad_weight_ap:
                loss_ap = loss_ap.item()
            grad_scaler.unscale_(optim_g)
            if iteration % 5 == 0:
                grad_norm_g, g_grad_norm_stats = compute_grad_norm(net_g, True)
            else:
                grad_norm_g = math.nan
                g_grad_norm_stats = {}

            # === 2.5 Parameter update ===
            grad_scaler.step(optim_g)
            optim_g.zero_grad(set_to_none=True)
            grad_scaler.step(optim_d)
            optim_d.zero_grad(set_to_none=True)
            grad_scaler.update()

            # === 3. Logging ===
            dict_scalars["loss_g/loss_mel"].append(loss_mel)
            if h.grad_weight_ap:
                dict_scalars["loss_g/loss_ap"].append(loss_ap)
            dict_scalars["loss_g/loss_fm"].append(loss_fm)
            dict_scalars["loss_g/loss_adv"].append(loss_adv)
            dict_scalars["other/grad_scale"].append(grad_scaler.get_scale())
            dict_scalars["loss_d/loss_discriminator"].append(loss_discriminator)
            if math.isfinite(grad_norm_d):
                dict_scalars["other/gradient_norm_d"].append(grad_norm_d)
                for name, value in d_grad_norm_stats.items():
                    dict_scalars[f"~gradient_norm_d/{name}"].append(value)
            if math.isfinite(grad_norm_g):
                dict_scalars["other/gradient_norm_g"].append(grad_norm_g)
                for name, value in g_grad_norm_stats.items():
                    dict_scalars[f"~gradient_norm_g/{name}"].append(value)
            dict_scalars["other/lr_g"].append(scheduler_g.get_last_lr()[0])
            dict_scalars["other/lr_d"].append(scheduler_d.get_last_lr()[0])
            for k, v in generator_stats.items():
                dict_scalars[f"~loss_generator/{k}"].append(v)
            for k, v in discriminator_stats.items():
                dict_scalars[f"~loss_discriminator/{k}"].append(v)
            for k, v in gradient_balancer_stats.items():
                dict_scalars[f"~gradient_balancer/{k}"].append(v)

            if (iteration + 1) % 1000 == 0 or iteration == 0:
                # Flush the averaged accumulated scalars.
                for name, scalars in dict_scalars.items():
                    if scalars:
                        writer.add_scalar(
                            name, sum(scalars) / len(scalars), iteration + 1
                        )
                        scalars.clear()
                for name, param in net_g.named_parameters():
                    writer.add_histogram(f"weight/{name}", param, iteration + 1)

                # Collect statistics of intermediate activations of net_g by
                # re-running its forward pass with temporary hooks attached.
                intermediate_feature_stats = {}
                hook_handles = []

                def get_layer_hook(name):
                    """Return (forward_pre_hook, forward_hook) recording stats for `name`."""

                    def compute_stats(module, x, suffix):
                        # Only float32/float16 tensors produced by Linear,
                        # LayerNorm or Conv1d modules are recorded.
                        if not isinstance(x, torch.Tensor):
                            return
                        if x.dtype not in [torch.float32, torch.float16]:
                            return
                        if isinstance(module, nn.Identity):
                            return
                        x = x.detach().float()
                        var = x.var().item()
                        if isinstance(module, (nn.Linear, nn.LayerNorm)):
                            # Channels on the last axis.
                            channel_var, channel_mean = torch.var_mean(
                                x.reshape(-1, x.size(-1)), 0
                            )
                        elif isinstance(module, nn.Conv1d):
                            # Channels on axis 1; reduce over axes 0 and 2.
                            channel_var, channel_mean = torch.var_mean(x, [0, 2])
                        else:
                            return
                        average_squared_channel_mean = (
                            channel_mean.square().mean().item()
                        )
                        average_channel_var = channel_var.mean().item()
                        # Three entries are added per recorded tensor, so this
                        # is a running index that keeps TensorBoard tags ordered.
                        tensor_idx = len(intermediate_feature_stats) // 3
                        intermediate_feature_stats[
                            f"var/{tensor_idx:02d}_{name}/{suffix}"
                        ] = var
                        intermediate_feature_stats[
                            f"avg_sq_ch_mean/{tensor_idx:02d}_{name}/{suffix}"
                        ] = average_squared_channel_mean
                        intermediate_feature_stats[
                            f"avg_ch_var/{tensor_idx:02d}_{name}/{suffix}"
                        ] = average_channel_var

                    def forward_pre_hook(module, input):
                        for i, input_i in enumerate(input):
                            compute_stats(module, input_i, f"input_{i}")

                    def forward_hook(module, input, output):
                        if isinstance(output, tuple):
                            for i, output_i in enumerate(output):
                                compute_stats(module, output_i, f"output_{i}")
                        else:
                            compute_stats(module, output, "output")

                    return forward_pre_hook, forward_hook

                for name, layer in net_g.named_modules():
                    forward_pre_hook, forward_hook = get_layer_hook(name)
                    hook_handles.append(
                        layer.register_forward_pre_hook(forward_pre_hook)
                    )
                    hook_handles.append(layer.register_forward_hook(forward_hook))
                with torch.no_grad(), torch.amp.autocast("cuda", enabled=h.use_amp):
                    net_g.forward_and_compute_loss(
                        noisy_wavs_16k[:, None, :],
                        speaker_ids,
                        formant_shift_semitone,
                        slice_start_indices=slice_starts,
                        slice_segment_length=h.segment_length,
                        y_all=clean_wavs[:, None, :],
                        enable_loss_ap=h.grad_weight_ap != 0.0,
                    )
                for handle in hook_handles:
                    handle.remove()
                for name, value in intermediate_feature_stats.items():
                    writer.add_scalar(
                        f"~intermediate_feature_{name}", value, iteration + 1
                    )

            # === 4. Validation ===
            if is_eval_step:
                torch.backends.cudnn.benchmark = False
                net_g.eval()
                torch.cuda.empty_cache()
                dict_qualities_all = defaultdict(list)
                n_added_wavs = 0
                with torch.inference_mode():
                    for i, ((file, target_ids), pitch_shift_semitones) in enumerate(
                        zip(test_filelist, test_pitch_shifts)
                    ):
                        source_wav, sr = torchaudio.load(file, backend="soundfile")
                        source_wav = source_wav.to(device)
                        if sr != h.in_sample_rate:
                            source_wav = get_resampler(sr, h.in_sample_rate, device)(
                                source_wav
                            )
                        source_wav = source_wav.to(device)
                        original_source_wav_length = source_wav.size(1)
                        # Pad to a whole number of seconds so fewer distinct
                        # input lengths occur, which lets caches be reused.
                        if source_wav.size(1) % h.in_sample_rate == 0:
                            padded_source_wav = source_wav
                        else:
                            padded_source_wav = F.pad(
                                source_wav,
                                (
                                    0,
                                    h.in_sample_rate
                                    - source_wav.size(1) % h.in_sample_rate,
                                ),
                            )
                        converted = net_g(
                            padded_source_wav[[0] * len(target_ids), None],
                            torch.tensor(target_ids, device=device),
                            torch.tensor(
                                [0.0] * len(target_ids), device=device
                            ),  # formant shift
                            torch.tensor(
                                [float(p) for p in pitch_shift_semitones], device=device
                            ),
                        ).squeeze_(1)[:, : original_source_wav_length // 160 * 240]
                        # Log at most 12 converted clips in total.
                        if i < 12:
                            if iteration == 0:
                                writer.add_audio(
                                    f"source/y_{i:02d}",
                                    source_wav,
                                    iteration + 1,
                                    h.in_sample_rate,
                                )
                            for d in range(
                                min(
                                    len(target_ids),
                                    1 + (12 - i - 1) // len(test_filelist),
                                )
                            ):
                                idx_in_batch = n_added_wavs % len(target_ids)
                                writer.add_audio(
                                    f"converted/y_hat_{i:02d}_{target_ids[idx_in_batch]:03d}_{pitch_shift_semitones[idx_in_batch]:+02d}",
                                    converted[idx_in_batch],
                                    iteration + 1,
                                    h.out_sample_rate,
                                )
                                n_added_wavs += 1
                        converted = resample_to_in_sample_rate(converted)
                        quality = quality_tester.test(converted, source_wav)
                        for metric_name, values in quality.items():
                            dict_qualities_all[metric_name].extend(values)
                # Guarded so an empty test file list does not raise IndexError.
                n_targets_per_file = len(test_filelist[0][1]) if test_filelist else 0
                assert n_added_wavs == min(
                    12, len(test_filelist) * n_targets_per_file
                ), (
                    n_added_wavs,
                    len(test_filelist),
                    len(speakers),
                    n_targets_per_file,
                )
                dict_qualities = {
                    metric_name: sum(values) / len(values)
                    for metric_name, values in dict_qualities_all.items()
                    if len(values)
                }
                for metric_name, value in dict_qualities.items():
                    writer.add_scalar(f"validation/{metric_name}", value, iteration + 1)
                for metric_name, values in dict_qualities_all.items():
                    for i, value in enumerate(values):
                        writer.add_scalar(
                            f"~validation_{metric_name}/{i:03d}", value, iteration + 1
                        )
                del dict_qualities, dict_qualities_all
                net_g.train()
                torch.backends.cudnn.benchmark = True
                gc.collect()
                torch.cuda.empty_cache()

            # === 5. Save ===
            if is_eval_step:
                # Checkpoint
                name = f"{in_wav_dataset_dir.name}_{iteration + 1:08d}"
                checkpoint_file_save = out_dir / f"checkpoint_{name}.pt"
                # NOTE(review): hash(None) is constant within an interpreter
                # process, so this suffix does not guarantee a unique name.
                if checkpoint_file_save.exists():
                    checkpoint_file_save = checkpoint_file_save.with_name(
                        f"{checkpoint_file_save.name}_{hash(None):x}"
                    )
                torch.save(
                    {
                        "iteration": iteration + 1,
                        "net_g": net_g.state_dict(),
                        "phone_extractor": phone_extractor.state_dict(),
                        "pitch_estimator": pitch_estimator.state_dict(),
                        "net_d": net_d.state_dict(),
                        "optim_g": optim_g.state_dict(),
                        "optim_d": optim_d.state_dict(),
                        "grad_balancer": grad_balancer.state_dict(),
                        "grad_scaler": grad_scaler.state_dict(),
                        "h": dict(h),
                    },
                    checkpoint_file_save,
                )
                shutil.copy(checkpoint_file_save, out_dir / "checkpoint_latest.pt")

                # Files for inference (fp16 model dumps + metadata)
                paraphernalia_dir = out_dir / f"paraphernalia_{name}"
                if paraphernalia_dir.exists():
                    paraphernalia_dir = paraphernalia_dir.with_name(
                        f"{paraphernalia_dir.name}_{hash(None):x}"
                    )
                paraphernalia_dir.mkdir()
                phone_extractor_fp16 = PhoneExtractor()
                phone_extractor_fp16.load_state_dict(phone_extractor.state_dict())
                phone_extractor_fp16.remove_weight_norm()
                phone_extractor_fp16.merge_weights()
                phone_extractor_fp16.half()
                phone_extractor_fp16.dump(paraphernalia_dir / "phone_extractor.bin")
                del phone_extractor_fp16
                pitch_estimator_fp16 = PitchEstimator()
                pitch_estimator_fp16.load_state_dict(pitch_estimator.state_dict())
                pitch_estimator_fp16.merge_weights()
                pitch_estimator_fp16.half()
                pitch_estimator_fp16.dump(paraphernalia_dir / "pitch_estimator.bin")
                del pitch_estimator_fp16
                net_g_fp16 = ConverterNetwork(
                    nn.Module(), nn.Module(), len(speakers), h.hidden_channels
                )
                net_g_fp16.load_state_dict(net_g.state_dict())
                net_g_fp16.merge_weights()
                net_g_fp16.half()
                net_g_fp16.dump(paraphernalia_dir / "waveform_generator.bin")
                with open(paraphernalia_dir / "speaker_embeddings.bin", "wb") as f:
                    dump_layer(net_g_fp16.embed_speaker, f)
                with open(
                    paraphernalia_dir / "formant_shift_embeddings.bin", "wb"
                ) as f:
                    dump_layer(net_g_fp16.embed_formant_shift, f)
                del net_g_fp16
                shutil.copy(
                    repo_root() / "assets/images/noimage.png", paraphernalia_dir
                )
                with open(
                    paraphernalia_dir / f"beatrice_paraphernalia_{name}.toml",
                    "w",
                    encoding="utf-8",
                ) as f:
                    f.write(
f'''[model]
version = "{PARAPHERNALIA_VERSION}"
name = "{name}"
description = """
No description for this model.
このモデルの説明はありません。
"""
'''
                    )
                    for speaker_id, (speaker, speaker_f0) in enumerate(
                        zip(speakers, speaker_f0s)
                    ):
                        # MIDI note number of the speaker's F0 (A4 = 440 Hz =
                        # 69), rounded to a multiple of 1/8 semitone.
                        average_pitch = 69.0 + 12.0 * math.log2(speaker_f0 / 440.0)
                        average_pitch = round(average_pitch * 8.0) / 8.0
                        f.write(
f'''
[voice.{speaker_id}]
name = "{speaker}"
description = """
No description for this voice.
この声の説明はありません。
"""
average_pitch = {average_pitch}
[voice.{speaker_id}.portrait]
path = "noimage.png"
description = """
"""
'''
                        )
                del paraphernalia_dir
                # TODO: skip dumping phone_extractor / pitch_estimator when they
                # are known (unchanged) models.

            # === 6. Scheduler update ===
            scheduler_g.step()
            scheduler_d.step()
            if h.profile:
                profiler.step()
    print("Training finished.")