|
|
|
|
|
|
|
|
|
import argparse |
|
import gc |
|
import json |
|
import math |
|
import os |
|
import shutil |
|
import warnings |
|
from collections import defaultdict |
|
from contextlib import nullcontext |
|
from copy import deepcopy |
|
from fractions import Fraction |
|
from functools import partial |
|
from pathlib import Path |
|
from pprint import pprint |
|
from random import Random |
|
from typing import BinaryIO, Literal, Optional, Union |
|
|
|
import numpy as np |
|
import pyworld |
|
import torch |
|
import torch.nn as nn |
|
import torchaudio |
|
from torch.nn import functional as F |
|
from torch.nn.utils import remove_weight_norm, weight_norm |
|
from torch.utils.tensorboard import SummaryWriter |
|
from tqdm.auto import tqdm |
|
|
|
assert "soundfile" in torchaudio.list_audio_backends() |
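# Compatibility shim: older PyTorch builds expose GradScaler only under
# torch.cuda.amp; wrap it so that torch.amp.GradScaler("cuda", ...) calls
# still work (the leading device argument is simply ignored).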
|
if not hasattr(torch.amp, "GradScaler"): |
|
|
|
class GradScaler(torch.cuda.amp.GradScaler): |
|
def __init__(self, _, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
|
|
torch.amp.GradScaler = GradScaler |
|
|
|
|
|
|
|
PARAPHERNALIA_VERSION = "2.0.0-beta.1" |
|
|
|
|
|
def is_notebook() -> bool: |
|
return "get_ipython" in globals() |
|
|
|
|
|
def repo_root() -> Path: |
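    """Return the repository root by walking up from this file (or from the
    current working directory when running in a notebook) until a directory
    containing .git is found."""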
|
d = Path.cwd() / "dummy" if is_notebook() else Path(__file__) |
|
assert d.is_absolute(), d |
|
for d in d.parents: |
|
if (d / ".git").is_dir(): |
|
return d |
|
raise RuntimeError("Repository root is not found.") |
|
|
|
|
|
|
|
|
|
dict_default_hparams = { |
|
|
|
"learning_rate_g": 2e-4, |
|
"learning_rate_d": 1e-4, |
|
"min_learning_rate_g": 1e-5, |
|
"min_learning_rate_d": 5e-6, |
|
"adam_betas": [0.8, 0.99], |
|
"adam_eps": 1e-6, |
|
"batch_size": 8, |
|
"grad_weight_mel": 1.0, |
|
"grad_weight_ap": 2.0, |
|
"grad_weight_adv": 3.0, |
|
"grad_weight_fm": 3.0, |
|
"grad_balancer_ema_decay": 0.995, |
|
"use_amp": True, |
|
"num_workers": 16, |
|
"n_steps": 10000, |
|
"warmup_steps": 2000, |
|
"in_sample_rate": 16000, |
|
"out_sample_rate": 24000, |
|
"wav_length": 4 * 24000, |
|
"segment_length": 100, |
|
|
|
"phone_extractor_file": "assets/pretrained/003b_checkpoint_03000000.pt", |
|
"pitch_estimator_file": "assets/pretrained/008_1_checkpoint_00300000.pt", |
|
"in_ir_wav_dir": "assets/ir", |
|
"in_noise_wav_dir": "assets/noise", |
|
"in_test_wav_dir": "assets/test", |
|
"pretrained_file": "assets/pretrained/079_checkpoint_libritts_r_200_02400000.pt", |
|
|
|
"hidden_channels": 256, |
|
"san": False, |
|
"compile_convnext": False, |
|
"compile_d4c": False, |
|
"compile_discriminator": False, |
|
"profile": False, |
|
} |
|
|
|
if __name__ == "__main__": |
|
|
|
default_config_file = repo_root() / "assets/default_config.json" |
|
if default_config_file.is_file(): |
|
with open(default_config_file, encoding="utf-8") as f: |
|
default_config: dict = json.load(f) |
|
for key, value in dict_default_hparams.items(): |
|
if key not in default_config: |
|
warnings.warn(f"{key} not found in default_config.json.") |
|
else: |
|
if value != default_config[key]: |
|
warnings.warn( |
|
f"{key} differs between default_config.json ({default_config[key]}) and internal default hparams ({value})." |
|
) |
|
del default_config[key] |
|
for key in default_config: |
|
            warnings.warn(f"Unknown key {key} in default_config.json.")
|
else: |
|
        warnings.warn("default_config.json not found.")
|
|
|
|
|
def prepare_training_configs_for_experiment() -> tuple[dict, Path, Path, bool, bool]: |
|
import ipynbname |
|
from IPython import get_ipython |
|
|
|
h = deepcopy(dict_default_hparams) |
|
in_wav_dataset_dir = repo_root() / "../../data/processed/libritts_r_200" |
|
try: |
|
notebook_name = ipynbname.name() |
|
except FileNotFoundError: |
|
notebook_name = Path(get_ipython().user_ns["__vsc_ipynb_file__"]).name |
|
out_dir = repo_root() / "notebooks" / notebook_name.split(".")[0].split("_")[0] |
|
resume = False |
|
skip_training = False |
|
return h, in_wav_dataset_dir, out_dir, resume, skip_training |
|
|
|
|
|
def prepare_training_configs() -> tuple[dict, Path, Path, bool, bool]: |
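    """Build the training configuration from command-line arguments.

    Returns (hparams, data_dir, out_dir, resume, skip_training). The optional
    -c/--config JSON file overrides the built-in defaults, while -d and -o take
    precedence over any data_dir/out_dir entries in that file.
    """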
|
|
|
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser() |
|
|
|
parser.add_argument("-d", "--data_dir", type=Path, help="directory containing the training data") |
|
parser.add_argument("-o", "--out_dir", type=Path, help="output directory") |
|
parser.add_argument("-r", "--resume", action="store_true", help="resume training") |
|
parser.add_argument("-c", "--config", type=Path, help="path to the config file") |
|
|
|
args = parser.parse_args() |
|
|
|
|
|
if args.config is None: |
|
h = deepcopy(dict_default_hparams) |
|
else: |
|
with open(args.config, encoding="utf-8") as f: |
|
h = json.load(f) |
|
for key in dict_default_hparams.keys(): |
|
if key not in h: |
|
h[key] = dict_default_hparams[key] |
|
warnings.warn( |
|
f"{key} is not specified in the config file. Using the default value." |
|
) |
|
|
|
if args.data_dir is not None: |
|
in_wav_dataset_dir = args.data_dir |
|
elif "data_dir" in h: |
|
in_wav_dataset_dir = repo_root() / Path(h["data_dir"]) |
|
del h["data_dir"] |
|
else: |
|
raise ValueError( |
|
"data_dir must be specified. " |
|
"For example `python3 beatrice_trainer -d my_training_data_dir -o my_output_dir`." |
|
) |
|
|
|
if args.out_dir is not None: |
|
out_dir = args.out_dir |
|
elif "out_dir" in h: |
|
out_dir = repo_root() / Path(h["out_dir"]) |
|
del h["out_dir"] |
|
else: |
|
raise ValueError( |
|
"out_dir must be specified. " |
|
"For example `python3 beatrice_trainer -d my_training_data_dir -o my_output_dir`." |
|
) |
|
for key in list(h.keys()): |
|
if key not in dict_default_hparams: |
|
warnings.warn(f"`{key}` specified in the config file will be ignored.") |
|
del h[key] |
|
|
|
resume = args.resume |
|
return h, in_wav_dataset_dir, out_dir, resume, False |
|
|
|
|
|
class AttrDict(dict): |
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
self.__dict__ = self |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dump_params(params: Optional[torch.Tensor], f: BinaryIO):
|
if params is None: |
|
return |
|
if params.dtype == torch.bfloat16: |
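        # bfloat16 is the upper half of float32: view the float32 copy as int16
        # and keep every second element (assumes a little-endian host).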
|
f.write( |
|
params.detach() |
|
.clone() |
|
.float() |
|
.view(torch.short) |
|
.numpy() |
|
.ravel()[1::2] |
|
.tobytes() |
|
) |
|
else: |
|
f.write(params.detach().numpy().ravel().tobytes()) |
|
f.flush() |
|
|
|
|
|
def dump_layer(layer: nn.Module, f: BinaryIO): |
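    """Serialize a layer's parameters to ``f``, recursing into known layer
    types; layers that define their own ``dump`` method are delegated to."""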
|
dump = partial(dump_params, f=f) |
|
if hasattr(layer, "dump"): |
|
layer.dump(f) |
|
elif isinstance(layer, (nn.Linear, nn.Conv1d, nn.LayerNorm)): |
|
dump(layer.weight) |
|
dump(layer.bias) |
|
elif isinstance(layer, nn.ConvTranspose1d): |
|
dump(layer.weight.transpose(0, 1)) |
|
dump(layer.bias) |
|
elif isinstance(layer, nn.GRU): |
|
dump(layer.weight_ih_l0) |
|
dump(layer.bias_ih_l0) |
|
dump(layer.weight_hh_l0) |
|
dump(layer.bias_hh_l0) |
|
for i in range(1, 99999): |
|
if not hasattr(layer, f"weight_ih_l{i}"): |
|
break |
|
dump(getattr(layer, f"weight_ih_l{i}")) |
|
dump(getattr(layer, f"bias_ih_l{i}")) |
|
dump(getattr(layer, f"weight_hh_l{i}")) |
|
dump(getattr(layer, f"bias_hh_l{i}")) |
|
elif isinstance(layer, nn.Embedding): |
|
dump(layer.weight) |
|
elif isinstance(layer, nn.Parameter): |
|
dump(layer) |
|
elif isinstance(layer, nn.ModuleList): |
|
for l in layer: |
|
dump_layer(l, f) |
|
else: |
|
assert False, layer |
|
|
|
|
|
class CausalConv1d(nn.Conv1d): |
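    """Conv1d padded so that each output frame depends only on inputs up to
    ``delay`` frames ahead (``delay=0`` gives a strictly causal convolution)."""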
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
kernel_size: int, |
|
stride: int = 1, |
|
dilation: int = 1, |
|
groups: int = 1, |
|
bias: bool = True, |
|
delay: int = 0, |
|
): |
|
padding = (kernel_size - 1) * dilation - delay |
|
self.trim = (kernel_size - 1) * dilation - 2 * delay |
|
if self.trim < 0: |
|
raise ValueError |
|
super().__init__( |
|
in_channels, |
|
out_channels, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
padding=padding, |
|
dilation=dilation, |
|
groups=groups, |
|
bias=bias, |
|
) |
|
|
|
def forward(self, input: torch.Tensor) -> torch.Tensor: |
|
result = super().forward(input) |
|
if self.trim == 0: |
|
return result |
|
else: |
|
return result[:, :, : -self.trim] |
|
|
|
|
|
class WSConv1d(CausalConv1d): |
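    """CausalConv1d with weight standardization: weights are standardized per
    output channel and rescaled by a learnable gain; ``merge_weights`` folds the
    standardization back into the stored weights for export."""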
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
kernel_size: int, |
|
stride: int = 1, |
|
dilation: int = 1, |
|
groups: int = 1, |
|
bias: bool = True, |
|
delay: int = 0, |
|
): |
|
super().__init__( |
|
in_channels, |
|
out_channels, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
dilation=dilation, |
|
groups=groups, |
|
bias=bias, |
|
delay=delay, |
|
) |
|
self.weight.data.normal_( |
|
0.0, math.sqrt(1.0 / (in_channels * kernel_size // groups)) |
|
) |
|
if bias: |
|
self.bias.data.zero_() |
|
self.gain = nn.Parameter(torch.ones((out_channels, 1, 1))) |
|
|
|
def standardized_weight(self) -> torch.Tensor: |
|
var, mean = torch.var_mean(self.weight, [1, 2], keepdim=True) |
|
scale = ( |
|
self.gain |
|
* ( |
|
self.in_channels * self.kernel_size[0] // self.groups * var + 1e-8 |
|
).rsqrt() |
|
) |
|
return scale * (self.weight - mean) |
|
|
|
def forward(self, input: torch.Tensor) -> torch.Tensor: |
|
result = F.conv1d( |
|
input, |
|
self.standardized_weight(), |
|
self.bias, |
|
self.stride, |
|
self.padding, |
|
self.dilation, |
|
self.groups, |
|
) |
|
if self.trim == 0: |
|
return result |
|
else: |
|
return result[:, :, : -self.trim] |
|
|
|
def merge_weights(self): |
|
self.weight.data[:] = self.standardized_weight().detach() |
|
self.gain.data.fill_(1.0) |
|
|
|
|
|
class WSLinear(nn.Linear): |
|
def __init__(self, in_features: int, out_features: int, bias: bool = True): |
|
super().__init__(in_features, out_features, bias) |
|
self.weight.data.normal_(0.0, math.sqrt(1.0 / in_features)) |
|
self.bias.data.zero_() |
|
self.gain = nn.Parameter(torch.ones((out_features, 1))) |
|
|
|
def standardized_weight(self) -> torch.Tensor: |
|
var, mean = torch.var_mean(self.weight, 1, keepdim=True) |
|
scale = self.gain * (self.in_features * var + 1e-8).rsqrt() |
|
return scale * (self.weight - mean) |
|
|
|
def forward(self, input: torch.Tensor) -> torch.Tensor: |
|
return F.linear(input, self.standardized_weight(), self.bias) |
|
|
|
def merge_weights(self): |
|
self.weight.data[:] = self.standardized_weight().detach() |
|
self.gain.data.fill_(1.0) |
|
|
|
|
|
class ConvNeXtBlock(nn.Module): |
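    """Causal ConvNeXt block: depthwise conv, LayerNorm, pointwise GELU MLP,
    layer scale, and a residual connection. With weight standardization enabled,
    the norm and layer scale are replaced by WS layers; ``enable_scaling`` adds
    fixed pre/post scales on the residual branch (see ConvNeXtStack)."""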
|
def __init__( |
|
self, |
|
channels: int, |
|
intermediate_channels: int, |
|
layer_scale_init_value: float, |
|
kernel_size: int = 7, |
|
use_weight_standardization: bool = False, |
|
enable_scaling: bool = False, |
|
pre_scale: float = 1.0, |
|
post_scale: float = 1.0, |
|
): |
|
super().__init__() |
|
self.use_weight_standardization = use_weight_standardization |
|
self.enable_scaling = enable_scaling |
|
self.dwconv = CausalConv1d( |
|
channels, channels, kernel_size=kernel_size, groups=channels |
|
) |
|
self.norm = nn.LayerNorm(channels) |
|
self.pwconv1 = nn.Linear(channels, intermediate_channels) |
|
self.pwconv2 = nn.Linear(intermediate_channels, channels) |
|
self.gamma = nn.Parameter(torch.full((channels,), layer_scale_init_value)) |
|
self.dwconv.weight.data.normal_(0.0, math.sqrt(1.0 / kernel_size)) |
|
self.dwconv.bias.data.zero_() |
|
self.pwconv1.weight.data.normal_(0.0, math.sqrt(2.0 / channels)) |
|
self.pwconv1.bias.data.zero_() |
|
self.pwconv2.weight.data.normal_(0.0, math.sqrt(1.0 / intermediate_channels)) |
|
self.pwconv2.bias.data.zero_() |
|
if use_weight_standardization: |
|
self.norm = nn.Identity() |
|
self.dwconv = WSConv1d(channels, channels, kernel_size, groups=channels) |
|
self.pwconv1 = WSLinear(channels, intermediate_channels) |
|
self.pwconv2 = WSLinear(intermediate_channels, channels) |
|
del self.gamma |
|
if enable_scaling: |
|
self.register_buffer("pre_scale", torch.tensor(pre_scale)) |
|
self.register_buffer("post_scale", torch.tensor(post_scale)) |
|
self.post_scale_weight = nn.Parameter(torch.ones(())) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
identity = x |
|
if self.enable_scaling: |
|
x = x * self.pre_scale |
|
x = self.dwconv(x) |
|
x = x.transpose(1, 2) |
|
x = self.norm(x) |
|
x = self.pwconv1(x) |
|
x = F.gelu(x, approximate="tanh") |
|
x = self.pwconv2(x) |
|
if not self.use_weight_standardization: |
|
x *= self.gamma |
|
if self.enable_scaling: |
|
x *= self.post_scale * self.post_scale_weight |
|
x = x.transpose(1, 2) |
|
x += identity |
|
return x |
|
|
|
def merge_weights(self): |
|
if self.use_weight_standardization: |
|
self.dwconv.merge_weights() |
|
self.pwconv1.merge_weights() |
|
self.pwconv2.merge_weights() |
|
else: |
|
self.pwconv1.bias.data += ( |
|
self.norm.bias.data[None, :] * self.pwconv1.weight.data |
|
).sum(1) |
|
self.pwconv1.weight.data *= self.norm.weight.data[None, :] |
|
self.norm.bias.data[:] = 0.0 |
|
self.norm.weight.data[:] = 1.0 |
|
self.pwconv2.weight.data *= self.gamma.data[:, None] |
|
self.pwconv2.bias.data *= self.gamma.data |
|
self.gamma.data[:] = 1.0 |
|
if self.enable_scaling: |
|
self.dwconv.weight.data *= self.pre_scale.data |
|
self.pre_scale.data.fill_(1.0) |
|
self.pwconv2.weight.data *= ( |
|
self.post_scale.data * self.post_scale_weight.data |
|
) |
|
self.pwconv2.bias.data *= self.post_scale.data * self.post_scale_weight.data |
|
self.post_scale.data.fill_(1.0) |
|
self.post_scale_weight.data.fill_(1.0) |
|
|
|
def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): |
|
if isinstance(f, (str, bytes, os.PathLike)): |
|
with open(f, "wb") as f: |
|
self.dump(f) |
|
return |
|
if not hasattr(f, "write"): |
|
raise TypeError |
|
|
|
dump_layer(self.dwconv, f) |
|
dump_layer(self.pwconv1, f) |
|
dump_layer(self.pwconv2, f) |
|
|
|
|
|
class ConvNeXtStack(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int, |
|
channels: int, |
|
intermediate_channels: int, |
|
n_blocks: int, |
|
delay: int, |
|
embed_kernel_size: int, |
|
kernel_size: int, |
|
use_weight_standardization: bool = False, |
|
enable_scaling: bool = False, |
|
): |
|
super().__init__() |
|
assert delay * 2 + 1 <= embed_kernel_size |
|
self.use_weight_standardization = use_weight_standardization |
|
self.embed = CausalConv1d(in_channels, channels, embed_kernel_size, delay=delay) |
|
self.norm = nn.LayerNorm(channels) |
|
self.convnext = nn.ModuleList() |
|
for i in range(n_blocks): |
|
pre_scale = 1.0 / math.sqrt(1.0 + i / n_blocks) if enable_scaling else 1.0 |
|
post_scale = 1.0 / math.sqrt(n_blocks) if enable_scaling else 1.0 |
|
block = ConvNeXtBlock( |
|
channels=channels, |
|
intermediate_channels=intermediate_channels, |
|
layer_scale_init_value=1.0 / n_blocks, |
|
kernel_size=kernel_size, |
|
use_weight_standardization=use_weight_standardization, |
|
enable_scaling=enable_scaling, |
|
pre_scale=pre_scale, |
|
post_scale=post_scale, |
|
) |
|
self.convnext.append(block) |
|
self.final_layer_norm = nn.LayerNorm(channels) |
|
self.embed.weight.data.normal_( |
|
0.0, math.sqrt(0.5 / (embed_kernel_size * in_channels)) |
|
) |
|
self.embed.bias.data.zero_() |
|
if use_weight_standardization: |
|
self.embed = WSConv1d(in_channels, channels, embed_kernel_size, delay=delay) |
|
self.norm = nn.Identity() |
|
self.final_layer_norm = nn.Identity() |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = self.embed(x) |
|
x = self.norm(x.transpose(1, 2)).transpose(1, 2) |
|
for conv_block in self.convnext: |
|
x = conv_block(x) |
|
x = self.final_layer_norm(x.transpose(1, 2)).transpose(1, 2) |
|
return x |
|
|
|
def merge_weights(self): |
|
if self.use_weight_standardization: |
|
self.embed.merge_weights() |
|
for conv_block in self.convnext: |
|
conv_block.merge_weights() |
|
|
|
def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): |
|
if isinstance(f, (str, bytes, os.PathLike)): |
|
with open(f, "wb") as f: |
|
self.dump(f) |
|
return |
|
if not hasattr(f, "write"): |
|
raise TypeError |
|
|
|
dump_layer(self.embed, f) |
|
if not self.use_weight_standardization: |
|
dump_layer(self.norm, f) |
|
dump_layer(self.convnext, f) |
|
if not self.use_weight_standardization: |
|
dump_layer(self.final_layer_norm, f) |
|
|
|
|
|
class FeatureExtractor(nn.Module): |
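    """Strided conv feature encoder with an overall stride of 160, mapping
    16 kHz waveform to 100 Hz frame-rate features."""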
|
def __init__(self, hidden_channels: int): |
|
super().__init__() |
|
|
|
self.conv0 = weight_norm(nn.Conv1d(1, hidden_channels // 8, 10, 5, bias=False)) |
|
self.conv1 = weight_norm(nn.Conv1d(hidden_channels // 8, hidden_channels // 4, 3, 2, bias=False)) |
|
self.conv2 = weight_norm(nn.Conv1d(hidden_channels // 4, hidden_channels // 2, 3, 2, bias=False)) |
|
self.conv3 = weight_norm(nn.Conv1d(hidden_channels // 2, hidden_channels, 3, 2, bias=False)) |
|
self.conv4 = weight_norm(nn.Conv1d(hidden_channels, hidden_channels, 3, 2, bias=False)) |
|
self.conv5 = weight_norm(nn.Conv1d(hidden_channels, hidden_channels, 2, 2, bias=False)) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
wav_length = x.size(2) |
|
if wav_length % 160 != 0: |
|
warnings.warn("wav_length % 160 != 0") |
|
x = F.pad(x, (40, 40)) |
|
x = F.gelu(self.conv0(x), approximate="tanh") |
|
x = F.gelu(self.conv1(x), approximate="tanh") |
|
x = F.gelu(self.conv2(x), approximate="tanh") |
|
x = F.gelu(self.conv3(x), approximate="tanh") |
|
x = F.gelu(self.conv4(x), approximate="tanh") |
|
x = F.gelu(self.conv5(x), approximate="tanh") |
|
|
|
return x |
|
|
|
def remove_weight_norm(self): |
|
remove_weight_norm(self.conv0) |
|
remove_weight_norm(self.conv1) |
|
remove_weight_norm(self.conv2) |
|
remove_weight_norm(self.conv3) |
|
remove_weight_norm(self.conv4) |
|
remove_weight_norm(self.conv5) |
|
|
|
def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): |
|
if isinstance(f, (str, bytes, os.PathLike)): |
|
with open(f, "wb") as f: |
|
self.dump(f) |
|
return |
|
if not hasattr(f, "write"): |
|
raise TypeError |
|
|
|
dump_layer(self.conv0, f) |
|
dump_layer(self.conv1, f) |
|
dump_layer(self.conv2, f) |
|
dump_layer(self.conv3, f) |
|
dump_layer(self.conv4, f) |
|
dump_layer(self.conv5, f) |
|
|
|
|
|
class FeatureProjection(nn.Module): |
|
def __init__(self, in_channels: int, out_channels: int): |
|
super().__init__() |
|
self.norm = nn.LayerNorm(in_channels) |
|
self.projection = nn.Conv1d(in_channels, out_channels, 1) |
|
self.dropout = nn.Dropout(0.1) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
x = self.norm(x.transpose(1, 2)).transpose(1, 2) |
|
x = self.projection(x) |
|
x = self.dropout(x) |
|
return x |
|
|
|
def merge_weights(self): |
|
self.projection.bias.data += ( |
|
(self.norm.bias.data[None, :, None] * self.projection.weight.data) |
|
.sum(1) |
|
.squeeze(1) |
|
) |
|
self.projection.weight.data *= self.norm.weight.data[None, :, None] |
|
self.norm.bias.data[:] = 0.0 |
|
self.norm.weight.data[:] = 1.0 |
|
|
|
def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): |
|
if isinstance(f, (str, bytes, os.PathLike)): |
|
with open(f, "wb") as f: |
|
self.dump(f) |
|
return |
|
if not hasattr(f, "write"): |
|
raise TypeError |
|
|
|
dump_layer(self.projection, f) |
|
|
|
|
|
class PhoneExtractor(nn.Module): |
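    """Extracts phone-like unit features from 16 kHz waveform: conv feature
    encoder, projection, a GRU speaker encoder whose outputs are randomly
    time-shuffled during training, a ConvNeXt backbone, and a 1x1 conv head."""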
|
def __init__( |
|
self, |
|
phone_channels: int = 256, |
|
hidden_channels: int = 256, |
|
backbone_embed_kernel_size: int = 7, |
|
kernel_size: int = 17, |
|
n_blocks: int = 8, |
|
): |
|
super().__init__() |
|
self.feature_extractor = FeatureExtractor(hidden_channels) |
|
self.feature_projection = FeatureProjection(hidden_channels, hidden_channels) |
|
self.n_speaker_encoder_layers = 3 |
|
self.speaker_encoder = nn.GRU( |
|
hidden_channels, |
|
hidden_channels, |
|
self.n_speaker_encoder_layers, |
|
batch_first=True, |
|
) |
|
for i in range(self.n_speaker_encoder_layers): |
|
for input_char in "ih": |
|
self.speaker_encoder = weight_norm( |
|
self.speaker_encoder, f"weight_{input_char}h_l{i}" |
|
) |
|
self.backbone = ConvNeXtStack( |
|
in_channels=hidden_channels, |
|
channels=hidden_channels, |
|
intermediate_channels=hidden_channels * 3, |
|
n_blocks=n_blocks, |
|
delay=0, |
|
embed_kernel_size=backbone_embed_kernel_size, |
|
kernel_size=kernel_size, |
|
) |
|
self.head = weight_norm(nn.Conv1d(hidden_channels, phone_channels, 1)) |
|
|
|
def forward( |
|
self, x: torch.Tensor, return_stats: bool = True |
|
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]: |
|
|
|
|
|
stats = {} |
|
|
|
|
|
x = self.feature_extractor(x) |
|
if return_stats: |
|
stats["feature_norm"] = x.detach().norm(dim=1).mean() |
|
|
|
x = self.feature_projection(x) |
|
|
|
g, _ = self.speaker_encoder(x.transpose(1, 2)) |
|
if self.training: |
|
batch_size, length, _ = g.size() |
|
shuffle_sizes_for_each_data = torch.randint( |
|
0, 50, (batch_size,), device=g.device |
|
) |
|
max_indices = torch.arange(length, device=g.device)[None, :, None] |
|
min_indices = ( |
|
max_indices - shuffle_sizes_for_each_data[:, None, None] |
|
).clamp_(min=0) |
|
with torch.cuda.amp.autocast(False): |
|
indices = ( |
|
torch.rand(g.size(), device=g.device) |
|
* (max_indices - min_indices + 1) |
|
).long() + min_indices |
|
assert indices.min() >= 0, indices.min() |
|
assert indices.max() < length, (indices.max(), length) |
|
g = g.gather(1, indices) |
|
|
|
|
|
g = g.transpose(1, 2).contiguous() |
|
|
|
x = self.backbone(x + g) |
|
|
|
phone = self.head(F.gelu(x, approximate="tanh")) |
|
|
|
results = [phone] |
|
if return_stats: |
|
stats["code_norm"] = phone.detach().norm(dim=1).mean().item() |
|
results.append(stats) |
|
|
|
if len(results) == 1: |
|
return results[0] |
|
return tuple(results) |
|
|
|
@torch.inference_mode() |
|
def units(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
|
|
|
|
phone = self.forward(x, return_stats=False) |
|
|
|
phone = phone.transpose(1, 2) |
|
|
|
return phone |
|
|
|
def remove_weight_norm(self): |
|
self.feature_extractor.remove_weight_norm() |
|
for i in range(self.n_speaker_encoder_layers): |
|
for input_char in "ih": |
|
remove_weight_norm(self.speaker_encoder, f"weight_{input_char}h_l{i}") |
|
remove_weight_norm(self.head) |
|
|
|
def merge_weights(self): |
|
self.feature_projection.merge_weights() |
|
self.backbone.merge_weights() |
|
|
|
def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): |
|
if isinstance(f, (str, bytes, os.PathLike)): |
|
with open(f, "wb") as f: |
|
self.dump(f) |
|
return |
|
if not hasattr(f, "write"): |
|
raise TypeError |
|
|
|
dump_layer(self.feature_extractor, f) |
|
dump_layer(self.feature_projection, f) |
|
dump_layer(self.speaker_encoder, f) |
|
dump_layer(self.backbone, f) |
|
dump_layer(self.head, f) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_pitch_features( |
|
y: torch.Tensor, |
|
hop_length: int = 160, |
|
win_length: int = 560, |
|
max_corr_period: int = 256, |
|
corr_win_length: int = 304, |
|
instfreq_features_cutoff_bin: int = 64, |
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
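    """Compute frame-level pitch features from a waveform.

    Returns ``(instfreq_features, corr_diff, energy)``: log power spectrum plus
    frame-to-frame phase-difference (instantaneous-frequency) features for the
    low bins, a normalized difference function over candidate pitch periods
    (smaller values indicate stronger periodicity), and log frame energy.
    """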
|
assert max_corr_period + corr_win_length == win_length |
|
|
|
|
|
padding_length = (win_length - hop_length) // 2 |
|
y = F.pad(y, (padding_length, padding_length)) |
|
|
|
|
|
|
|
y_frames = y.unfold(-1, win_length, hop_length).transpose_(-2, -1) |
|
|
|
|
|
|
|
spec: torch.Tensor = torch.fft.rfft(y_frames, n=win_length, dim=-2) |
|
|
|
|
|
spec = spec[..., :instfreq_features_cutoff_bin, :] |
|
|
|
|
|
log_power_spec = spec.abs().add_(1e-5).log10_() |
|
|
|
|
|
|
|
delta_spec = spec[..., :, 1:] * spec[..., :, :-1].conj() |
|
delta_spec /= delta_spec.abs().add_(1e-5) |
|
delta_spec = torch.cat( |
|
[torch.zeros_like(delta_spec[..., :, :1]), delta_spec], dim=-1 |
|
) |
|
|
|
|
|
instfreq_features = torch.cat( |
|
[log_power_spec, delta_spec.real, delta_spec.imag], dim=-2 |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
flipped_y_frames = y_frames.flip((-2,)) |
|
a = torch.fft.rfft(flipped_y_frames, n=win_length, dim=-2) |
|
b = torch.fft.rfft(y_frames[..., -corr_win_length:, :], n=win_length, dim=-2) |
|
|
|
corr = torch.fft.irfft(a * b, n=win_length, dim=-2)[..., corr_win_length:, :] |
|
|
|
|
|
energy = flipped_y_frames.square_().cumsum_(-2) |
|
energy0 = energy[..., corr_win_length - 1 : corr_win_length, :] |
|
energy = energy[..., corr_win_length:, :] - energy[..., :-corr_win_length, :] |
|
|
|
|
|
corr_diff = (energy0 + energy).sub_(corr.mul_(2.0)) |
|
assert corr_diff.min() >= -1e-3, corr_diff.min() |
|
corr_diff.clamp_(min=0.0) |
|
|
|
|
|
corr_diff *= 2.0 / corr_win_length |
|
corr_diff.sqrt_() |
|
|
|
|
|
energy = ( |
|
(y_frames * torch.signal.windows.cosine(win_length, device=y.device)[..., None]) |
|
.square_() |
|
.sum(-2, keepdim=True) |
|
) |
|
|
|
energy.clamp_(min=1e-3).log10_() |
|
energy *= 0.5 |
|
|
|
return ( |
|
instfreq_features, |
|
corr_diff, |
|
energy, |
|
) |
|
|
|
|
|
class PitchEstimator(nn.Module): |
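    """Predicts a categorical pitch distribution (``pitch_channels`` bins at
    ``bins_per_octave`` resolution, with bin 0 meaning unvoiced) and frame
    energy from instantaneous-frequency and autocorrelation features."""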
|
def __init__( |
|
self, |
|
input_instfreq_channels: int = 192, |
|
input_corr_channels: int = 256, |
|
pitch_channels: int = 384, |
|
channels: int = 192, |
|
intermediate_channels: int = 192 * 3, |
|
n_blocks: int = 6, |
|
delay: int = 1, |
|
embed_kernel_size: int = 3, |
|
kernel_size: int = 33, |
|
bins_per_octave: int = 96, |
|
): |
|
super().__init__() |
|
self.bins_per_octave = bins_per_octave |
|
|
|
self.instfreq_embed_0 = nn.Conv1d(input_instfreq_channels, channels, 1) |
|
self.instfreq_embed_1 = nn.Conv1d(channels, channels, 1) |
|
self.corr_embed_0 = nn.Conv1d(input_corr_channels, channels, 1) |
|
self.corr_embed_1 = nn.Conv1d(channels, channels, 1) |
|
self.backbone = ConvNeXtStack( |
|
channels, |
|
channels, |
|
intermediate_channels, |
|
n_blocks, |
|
delay, |
|
embed_kernel_size, |
|
kernel_size, |
|
) |
|
self.head = nn.Conv1d(channels, pitch_channels, 1) |
|
|
|
def forward(self, wav: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: |
|
|
|
|
|
|
|
|
|
with torch.amp.autocast("cuda", enabled=False): |
|
instfreq_features, corr_diff, energy = extract_pitch_features( |
|
wav.squeeze(1), |
|
hop_length=160, |
|
win_length=560, |
|
max_corr_period=256, |
|
corr_win_length=304, |
|
instfreq_features_cutoff_bin=64, |
|
) |
|
instfreq_features = F.gelu( |
|
self.instfreq_embed_0(instfreq_features), approximate="tanh" |
|
) |
|
instfreq_features = self.instfreq_embed_1(instfreq_features) |
|
corr_diff = F.gelu(self.corr_embed_0(corr_diff), approximate="tanh") |
|
corr_diff = self.corr_embed_1(corr_diff) |
|
|
|
x = instfreq_features + corr_diff |
|
x = self.backbone(x) |
|
|
|
x = self.head(x) |
|
return x, energy |
|
|
|
def sample_pitch( |
|
self, pitch: torch.Tensor, band_width: int = 48, return_features: bool = False |
|
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: |
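        """Pick a pitch bin per frame: softmax over bins, slide a ``band_width``
        window to find the most probable band, then take the argmax within that
        band. With ``return_features`` also returns unvoiced, half-pitch, and
        double-pitch probabilities."""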
|
|
|
|
|
batch_size, pitch_channels, length = pitch.size() |
|
pitch = pitch.softmax(1) |
|
if return_features: |
|
unvoiced_proba = pitch[:, :1, :].clone() |
|
pitch[:, 0, :] = -100.0 |
|
pitch = ( |
|
pitch.transpose(1, 2) |
|
.contiguous() |
|
.view(batch_size * length, 1, pitch_channels) |
|
) |
|
band_pitch = F.conv1d( |
|
pitch, |
|
torch.ones((1, 1, 1), device=pitch.device).expand(1, 1, band_width), |
|
) |
|
|
|
quantized_band_pitch = band_pitch.argmax(2) |
|
if return_features: |
|
|
|
band_proba = band_pitch.gather(2, quantized_band_pitch[:, :, None]) |
|
|
|
half_pitch_band_proba = band_pitch.gather( |
|
2, |
|
(quantized_band_pitch - self.bins_per_octave).clamp_(min=1)[:, :, None], |
|
) |
|
half_pitch_band_proba[quantized_band_pitch <= self.bins_per_octave] = 0.0 |
|
half_pitch_proba = (half_pitch_band_proba / (band_proba + 1e-6)).view( |
|
batch_size, 1, length |
|
) |
|
|
|
double_pitch_band_proba = band_pitch.gather( |
|
2, |
|
(quantized_band_pitch + self.bins_per_octave).clamp_( |
|
max=pitch_channels - band_width |
|
)[:, :, None], |
|
) |
|
double_pitch_band_proba[ |
|
quantized_band_pitch |
|
> pitch_channels - band_width - self.bins_per_octave |
|
] = 0.0 |
|
double_pitch_proba = (double_pitch_band_proba / (band_proba + 1e-6)).view( |
|
batch_size, 1, length |
|
) |
|
|
|
mask = torch.arange(pitch_channels, device=pitch.device)[None, :] |
|
|
|
mask = (quantized_band_pitch <= mask) & ( |
|
mask < quantized_band_pitch + band_width |
|
) |
|
|
|
quantized_pitch = (pitch.squeeze(1) * mask).argmax(1).view(batch_size, length) |
|
|
|
if return_features: |
|
features = torch.cat( |
|
[unvoiced_proba, half_pitch_proba, double_pitch_proba], dim=1 |
|
) |
|
|
|
return quantized_pitch, features |
|
else: |
|
return quantized_pitch |
|
|
|
def merge_weights(self): |
|
self.backbone.merge_weights() |
|
|
|
def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): |
|
if isinstance(f, (str, bytes, os.PathLike)): |
|
with open(f, "wb") as f: |
|
self.dump(f) |
|
return |
|
if not hasattr(f, "write"): |
|
raise TypeError |
|
|
|
dump_layer(self.instfreq_embed_0, f) |
|
dump_layer(self.instfreq_embed_1, f) |
|
dump_layer(self.corr_embed_0, f) |
|
dump_layer(self.corr_embed_1, f) |
|
dump_layer(self.backbone, f) |
|
dump_layer(self.head, f) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def overlap_add( |
|
ir_amp: torch.Tensor, |
|
ir_phase: torch.Tensor, |
|
window: torch.Tensor, |
|
pitch: torch.Tensor, |
|
hop_length: int = 240, |
|
delay: int = 0, |
|
sr: float = 24000.0, |
|
) -> torch.Tensor: |
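    """Synthesize the periodic component: integrate the per-sample pitch to
    find pulse instants (phase wrap points), place a windowed impulse response
    (given as amplitude/phase spectra) at each pulse with a sub-sample delay,
    and sum the responses by overlap-add."""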
|
batch_size, ir_length, length = ir_amp.size() |
|
ir_length = (ir_length - 1) * 2 |
|
assert ir_phase.size() == ir_amp.size() |
|
assert window.size() == (ir_length,), (window.size(), ir_amp.size()) |
|
assert pitch.size() == (batch_size, length * hop_length) |
|
assert 0 <= delay < ir_length, (delay, ir_length) |
|
|
|
normalized_freq = pitch / sr |
|
|
|
normalized_freq[:, 0] = torch.rand(batch_size, device=pitch.device) |
|
with torch.amp.autocast("cuda", enabled=False): |
|
phase = (normalized_freq.double().cumsum_(1) % 1.0).float() |
|
|
|
|
|
indices0, indices1 = torch.nonzero(phase[:, :-1] > phase[:, 1:], as_tuple=True) |
|
|
|
numer = 1.0 - phase[indices0, indices1] |
|
|
|
fractional_part = numer / (numer + phase[indices0, indices1 + 1]) |
|
|
|
|
|
ir_amp = ir_amp[indices0, :, indices1 // hop_length] |
|
ir_phase = ir_phase[indices0, :, indices1 // hop_length] |
|
|
|
|
|
delay_phase = ( |
|
torch.arange(ir_length // 2 + 1, device=pitch.device, dtype=torch.float32)[ |
|
None, : |
|
] |
|
* (-math.tau / ir_length) |
|
* fractional_part[:, None] |
|
) |
|
|
|
spec = torch.polar(ir_amp, ir_phase + delay_phase) |
|
|
|
ir = torch.fft.irfft(spec, n=ir_length, dim=1) |
|
ir *= window |
|
|
|
|
|
|
|
ir = ir.ravel() |
|
|
|
indices0 = indices0[:, None].expand(-1, ir_length).ravel() |
|
|
|
indices1 = ( |
|
indices1[:, None] + torch.arange(ir_length, device=pitch.device) |
|
).ravel() |
|
|
|
|
|
overlap_added_signal = torch.zeros( |
|
(batch_size, length * hop_length + ir_length), device=pitch.device |
|
) |
|
overlap_added_signal.index_put_((indices0, indices1), ir, accumulate=True) |
|
overlap_added_signal = overlap_added_signal[:, delay : -ir_length + delay] |
|
|
|
return overlap_added_signal |
|
|
|
|
|
def generate_noise( |
|
aperiodicity: torch.Tensor, delay: int = 0 |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
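    """Generate the aperiodic component: shape white-noise STFT frames by the
    predicted per-frame aperiodicity spectrum and resynthesize with windowed
    overlap-add. Also returns the raw (unshaped) excitation."""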
|
|
|
batch_size, hop_length, length = aperiodicity.size() |
|
excitation = torch.rand( |
|
batch_size, (length + 1) * hop_length, device=aperiodicity.device |
|
) |
|
excitation -= 0.5 |
|
n_fft = 2 * hop_length |
|
|
|
|
|
noise = torch.stft( |
|
excitation, |
|
n_fft=n_fft, |
|
hop_length=hop_length, |
|
window=torch.ones(n_fft, device=excitation.device), |
|
center=False, |
|
return_complex=True, |
|
) |
|
assert noise.size(2) == aperiodicity.size(2) |
|
noise[:, 0, :] = 0.0 |
|
noise[:, 1:, :] *= aperiodicity |
|
|
|
|
|
|
|
noise = torch.fft.irfft(noise, n=2 * hop_length, dim=1) |
|
noise *= torch.hann_window(2 * hop_length, device=noise.device)[None, :, None] |
|
|
|
noise = F.fold( |
|
noise, |
|
(1, (length + 1) * hop_length), |
|
(1, 2 * hop_length), |
|
stride=(1, hop_length), |
|
).squeeze_((1, 2)) |
|
|
|
assert delay < hop_length |
|
noise = noise[:, delay : -hop_length + delay] |
|
excitation = excitation[:, delay : -hop_length + delay] |
|
return noise, excitation |
|
|
|
|
|
class GradientEqualizerFunction(torch.autograd.Function):
    """Compensate for gradients becoming larger as the input norm gets smaller."""
|
|
|
@staticmethod |
|
def forward(ctx, x: torch.Tensor) -> torch.Tensor: |
|
|
|
rms = x.square().mean(dim=2, keepdim=True).sqrt_() |
|
ctx.save_for_backward(rms) |
|
return x |
|
|
|
@staticmethod |
|
def backward(ctx, dx: torch.Tensor) -> torch.Tensor: |
|
|
|
(rms,) = ctx.saved_tensors |
|
dx = dx * (math.sqrt(2.0) * rms + 0.1) |
|
return dx |
|
|
|
|
|
D4C_PREVENT_ZERO_DIVISION = True |
|
|
|
|
|
def interp(x: torch.Tensor, y: torch.Tensor, xi: torch.Tensor) -> torch.Tensor: |
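    """Batched linear interpolation of ``y`` sampled at ``x`` onto the query
    points ``xi``; assumes ``x`` is uniformly spaced and ascending."""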
|
|
|
|
|
x = torch.as_tensor(x) |
|
y = torch.as_tensor(y) |
|
xi = torch.as_tensor(xi) |
|
if xi.ndim < y.ndim: |
|
diff_ndim = y.ndim - xi.ndim |
|
xi = xi.view(tuple([1] * diff_ndim) + xi.size()) |
|
if xi.size()[:-1] != y.size()[:-1]: |
|
xi = xi.expand(y.size()[:-1] + (xi.size(-1),)) |
|
assert (x.min(-1).values == x[..., 0]).all() |
|
assert (x.max(-1).values == x[..., -1]).all() |
|
assert (xi.min(-1).values >= x[..., 0]).all() |
|
assert (xi.max(-1).values <= x[..., -1]).all() |
|
delta_x = (x[..., -1].double() - x[..., 0].double()) / (x.size(-1) - 1.0) |
|
delta_x = delta_x.to(x.dtype) |
|
xi = (xi - x[..., :1]).div_(delta_x[..., None]) |
|
xi_base = xi.floor() |
|
xi_fraction = xi.sub_(xi_base) |
|
xi_base = xi_base.long() |
|
delta_y = y.diff(dim=-1, append=y[..., -1:]) |
|
yi = y.gather(-1, xi_base) + delta_y.gather(-1, xi_base) * xi_fraction |
|
return yi |
|
|
|
|
|
def linear_smoothing( |
|
group_delay: torch.Tensor, sr: int, n_fft: int, width: torch.Tensor |
|
) -> torch.Tensor: |
|
group_delay = torch.as_tensor(group_delay) |
|
assert group_delay.size(-1) == n_fft // 2 + 1 |
|
width = torch.as_tensor(width) |
|
boundary = (width.max() * n_fft / sr).long() + 1 |
|
|
|
dtype = group_delay.dtype |
|
device = group_delay.device |
|
fft_resolution = sr / n_fft |
|
mirroring_freq_axis = ( |
|
torch.arange(-boundary, n_fft // 2 + 1 + boundary, dtype=dtype, device=device) |
|
.add(0.5) |
|
.mul(fft_resolution) |
|
) |
|
if group_delay.ndim == 1: |
|
mirroring_spec = F.pad( |
|
group_delay[None], (boundary, boundary), mode="reflect" |
|
).squeeze_(0) |
|
elif group_delay.ndim >= 4: |
|
shape = group_delay.size() |
|
mirroring_spec = F.pad( |
|
group_delay.view(math.prod(shape[:-1]), group_delay.size(-1)), |
|
(boundary, boundary), |
|
mode="reflect", |
|
).view(shape[:-1] + (shape[-1] + 2 * boundary,)) |
|
else: |
|
mirroring_spec = F.pad(group_delay, (boundary, boundary), mode="reflect") |
|
mirroring_segment = mirroring_spec.mul(fft_resolution).cumsum_(-1) |
|
center_freq = torch.arange(n_fft // 2 + 1, dtype=dtype, device=device).mul_( |
|
fft_resolution |
|
) |
|
low_freq = center_freq - width[..., None] * 0.5 |
|
high_freq = center_freq + width[..., None] * 0.5 |
|
levels = interp( |
|
mirroring_freq_axis, mirroring_segment, torch.cat([low_freq, high_freq], dim=-1) |
|
) |
|
low_levels, high_levels = levels.split([n_fft // 2 + 1] * 2, dim=-1) |
|
smoothed = (high_levels - low_levels).div_(width[..., None]) |
|
return smoothed |
|
|
|
|
|
def dc_correction( |
|
spec: torch.Tensor, sr: int, n_fft: int, f0: torch.Tensor |
|
) -> torch.Tensor: |
|
spec = torch.as_tensor(spec) |
|
f0 = torch.as_tensor(f0) |
|
dtype = spec.dtype |
|
device = spec.device |
|
|
|
upper_limit = 2 + (f0 * (n_fft / sr)).long() |
|
max_upper_limit = upper_limit.max() |
|
upper_limit_mask = ( |
|
torch.arange(max_upper_limit - 1, device=device) < (upper_limit - 1)[..., None] |
|
) |
|
low_freq_axis = torch.arange(max_upper_limit + 1, dtype=dtype, device=device) * ( |
|
sr / n_fft |
|
) |
|
low_freq_replica = interp( |
|
f0[..., None] - low_freq_axis.flip(-1), |
|
spec[..., : max_upper_limit + 1].flip(-1), |
|
low_freq_axis[..., : max_upper_limit - 1] * upper_limit_mask, |
|
) |
|
output = spec.clone() |
|
output[..., : max_upper_limit - 1] += low_freq_replica * upper_limit_mask |
|
return output |
|
|
|
|
|
def nuttall(n: int, device: torch.types.Device) -> torch.Tensor: |
|
t = torch.linspace(0, math.tau, n, device=device) |
|
coefs = torch.tensor([0.355768, -0.487396, 0.144232, -0.012604], device=device) |
|
terms = torch.tensor([0.0, 1.0, 2.0, 3.0], device=device) |
|
cos_matrix = (terms[:, None] * t).cos_() |
|
window = coefs.matmul(cos_matrix) |
|
return window |
|
|
|
|
|
def get_windowed_waveform( |
|
x: torch.Tensor, |
|
sr: int, |
|
f0: torch.Tensor, |
|
position: torch.Tensor, |
|
half_window_length_ratio: float, |
|
window_type: Literal["hann", "blackman"], |
|
n_fft: int, |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
x = torch.as_tensor(x) |
|
f0 = torch.as_tensor(f0) |
|
position = torch.as_tensor(position) |
|
|
|
current_sample = position * sr |
|
|
|
half_window_length = (half_window_length_ratio * sr / f0).add_(0.5).long() |
|
|
|
base_index = -half_window_length[..., None] + torch.arange(n_fft, device=x.device) |
|
base_index_mask = base_index <= half_window_length[..., None] |
|
|
|
safe_index = ((current_sample + 0.501).long()[..., None] + base_index).clamp_( |
|
0, x.size(-1) - 1 |
|
) |
|
|
|
time_axis = base_index.to(x.dtype).div_(half_window_length_ratio) |
|
|
|
normalized_f0 = math.pi / sr * f0 |
|
|
|
phase = time_axis.mul_(normalized_f0[..., None]) |
|
|
|
if window_type == "hann": |
|
window = phase.cos_().mul_(0.5).add_(0.5) |
|
elif window_type == "blackman": |
|
window = phase.mul(2.0).cos_().mul_(0.08).add_(phase.cos().mul_(0.5)).add_(0.42) |
|
else: |
|
assert False |
|
window *= base_index_mask |
|
|
|
prefix_shape = tuple( |
|
max(x_size, i_size) for x_size, i_size in zip(x.size(), safe_index.size()) |
|
)[:-1] |
|
waveform = ( |
|
x.expand(prefix_shape + (-1,)) |
|
.gather(-1, safe_index.expand(prefix_shape + (-1,))) |
|
.mul_(window) |
|
) |
|
if not D4C_PREVENT_ZERO_DIVISION: |
|
waveform += torch.randn_like(window).mul_(1e-12) |
|
waveform *= base_index_mask |
|
waveform -= window * waveform.sum(-1, keepdim=True).div_( |
|
window.sum(-1, keepdim=True) |
|
) |
|
return waveform, window |
|
|
|
|
|
def get_centroid(x: torch.Tensor, n_fft: int) -> torch.Tensor: |
|
x = torch.as_tensor(x) |
|
if D4C_PREVENT_ZERO_DIVISION: |
|
x = x / x.norm(dim=-1, keepdim=True).clamp(min=6e-8) |
|
else: |
|
x = x / x.norm(dim=-1, keepdim=True) |
|
spec0 = torch.fft.rfft(x, n_fft) |
|
spec1 = torch.fft.rfft( |
|
x * torch.arange(1, x.size(-1) + 1, dtype=x.dtype, device=x.device).div_(n_fft), |
|
n_fft, |
|
) |
|
centroid = spec0.real * spec1.real + spec0.imag * spec1.imag |
|
return centroid |
|
|
|
|
|
def get_static_centroid( |
|
x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, n_fft: int |
|
) -> torch.Tensor: |
|
"""First step: calculation of temporally static parameters on basis of group delay""" |
|
x1, _ = get_windowed_waveform( |
|
x, sr, f0, position + 0.25 / f0, 2.0, "blackman", n_fft |
|
) |
|
x2, _ = get_windowed_waveform( |
|
x, sr, f0, position - 0.25 / f0, 2.0, "blackman", n_fft |
|
) |
|
centroid1 = get_centroid(x1, n_fft) |
|
centroid2 = get_centroid(x2, n_fft) |
|
return dc_correction(centroid1 + centroid2, sr, n_fft, f0) |
|
|
|
|
|
def get_smoothed_power_spec( |
|
x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, n_fft: int |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
x = torch.as_tensor(x) |
|
f0 = torch.as_tensor(f0) |
|
x, window = get_windowed_waveform(x, sr, f0, position, 2.0, "hann", n_fft) |
|
window_weight = window.square().sum(-1, keepdim=True) |
|
rms = x.square().sum(-1, keepdim=True).div_(window_weight).sqrt_() |
|
if D4C_PREVENT_ZERO_DIVISION: |
|
x = x / (rms * math.sqrt(n_fft)).clamp_(min=6e-8) |
|
smoothed_power_spec = torch.fft.rfft(x, n_fft).abs().square_() |
|
smoothed_power_spec = dc_correction(smoothed_power_spec, sr, n_fft, f0) |
|
smoothed_power_spec = linear_smoothing(smoothed_power_spec, sr, n_fft, f0) |
|
return smoothed_power_spec, rms.detach().squeeze(-1) |
|
|
|
|
|
def get_static_group_delay( |
|
static_centroid: torch.Tensor, |
|
smoothed_power_spec: torch.Tensor, |
|
sr: int, |
|
f0: torch.Tensor, |
|
n_fft: int, |
|
) -> torch.Tensor: |
|
"""Second step: calculation of parameter shaping""" |
|
if D4C_PREVENT_ZERO_DIVISION: |
|
smoothed_power_spec = smoothed_power_spec.clamp(min=6e-8) |
|
static_group_delay = static_centroid / smoothed_power_spec |
|
static_group_delay = linear_smoothing( |
|
static_group_delay, sr, n_fft, f0 * 0.5 |
|
) |
|
smoothed_group_delay = linear_smoothing(static_group_delay, sr, n_fft, f0) |
|
static_group_delay = static_group_delay - smoothed_group_delay |
|
return static_group_delay |
|
|
|
|
|
def get_coarse_aperiodicity( |
|
group_delay: torch.Tensor, |
|
sr: int, |
|
n_fft: int, |
|
freq_interval: int, |
|
n_aperiodicities: int, |
|
window: torch.Tensor, |
|
) -> torch.Tensor: |
|
"""Third step: estimation of band-aperiodicity""" |
|
group_delay = torch.as_tensor(group_delay) |
|
window = torch.as_tensor(window) |
|
boundary = int(round(n_fft * 8 / window.size(-1))) |
|
half_window_length = window.size(-1) // 2 |
|
coarse_aperiodicity = torch.empty( |
|
group_delay.size()[:-1] + (n_aperiodicities,), |
|
dtype=group_delay.dtype, |
|
device=group_delay.device, |
|
) |
|
for i in range(n_aperiodicities): |
|
center = freq_interval * (i + 1) * n_fft // sr |
|
segment = ( |
|
group_delay[ |
|
..., center - half_window_length : center + half_window_length + 1 |
|
] |
|
* window |
|
) |
|
power_spec: torch.Tensor = torch.fft.rfft(segment, n_fft).abs().square_() |
|
cumulative_power_spec = power_spec.sort(-1).values.cumsum_(-1) |
|
if D4C_PREVENT_ZERO_DIVISION: |
|
cumulative_power_spec.clamp_(min=6e-8) |
|
coarse_aperiodicity[..., i] = ( |
|
cumulative_power_spec[..., n_fft // 2 - boundary - 1] |
|
/ cumulative_power_spec[..., -1] |
|
) |
|
coarse_aperiodicity.log10_().mul_(10.0) |
|
return coarse_aperiodicity |
|
|
|
|
|
def d4c_love_train( |
|
x: torch.Tensor, sr: int, f0: torch.Tensor, position: torch.Tensor, threshold: float |
|
) -> torch.Tensor:
|
x = torch.as_tensor(x) |
|
position = torch.as_tensor(position) |
|
f0: torch.Tensor = torch.as_tensor(f0) |
|
vuv = f0 != 0 |
|
lowest_f0 = 40 |
|
f0 = f0.clamp(min=lowest_f0) |
|
n_fft = 1 << (3 * sr // lowest_f0).bit_length() |
|
boundary0 = (100 * n_fft - 1) // sr + 1 |
|
boundary1 = (4000 * n_fft - 1) // sr + 1 |
|
boundary2 = (7900 * n_fft - 1) // sr + 1 |
|
waveform, _ = get_windowed_waveform(x, sr, f0, position, 1.5, "blackman", n_fft) |
|
power_spec = torch.fft.rfft(waveform, n_fft).abs().square_() |
|
power_spec[..., : boundary0 + 1] = 0.0 |
|
cumulative_spec = power_spec.cumsum_(-1) |
|
vuv = vuv & ( |
|
cumulative_spec[..., boundary1] > threshold * cumulative_spec[..., boundary2] |
|
) |
|
return vuv |
|
|
|
|
|
def d4c_general_body( |
|
x: torch.Tensor, |
|
sr: int, |
|
f0: torch.Tensor, |
|
freq_interval: int, |
|
position: torch.Tensor, |
|
n_fft: int, |
|
n_aperiodicities: int, |
|
window: torch.Tensor, |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
static_centroid = get_static_centroid(x, sr, f0, position, n_fft) |
|
smoothed_power_spec, rms = get_smoothed_power_spec(x, sr, f0, position, n_fft) |
|
static_group_delay = get_static_group_delay( |
|
static_centroid, smoothed_power_spec, sr, f0, n_fft |
|
) |
|
coarse_aperiodicity = get_coarse_aperiodicity( |
|
static_group_delay, sr, n_fft, freq_interval, n_aperiodicities, window |
|
) |
|
coarse_aperiodicity.add_((f0[..., None] - 100.0).div_(50.0)).clamp_(max=0.0) |
|
return coarse_aperiodicity, rms |
|
|
|
|
|
def d4c( |
|
x: torch.Tensor, |
|
f0: torch.Tensor, |
|
t: torch.Tensor, |
|
sr: int, |
|
threshold: float = 0.85, |
|
n_fft_spec: Optional[int] = None, |
|
coarse_only: bool = False, |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
"""Adapted from https://github.com/tuanad121/Python-WORLD/blob/master/world/d4c.py""" |
|
FLOOR_F0 = 71 |
|
FLOOR_F0_D4C = 47 |
|
UPPER_LIMIT = 15000 |
|
FREQ_INTERVAL = 3000 |
|
|
|
assert sr == int(sr) |
|
sr = int(sr) |
|
assert sr % 2 == 0 |
|
x = torch.as_tensor(x) |
|
f0 = torch.as_tensor(f0) |
|
temporal_positions = torch.as_tensor(t) |
|
|
|
n_fft_d4c = 1 << (4 * sr // FLOOR_F0_D4C).bit_length() |
|
if n_fft_spec is None: |
|
n_fft_spec = 1 << (3 * sr // FLOOR_F0).bit_length() |
|
n_aperiodicities = min(UPPER_LIMIT, sr // 2 - FREQ_INTERVAL) // FREQ_INTERVAL |
|
assert n_aperiodicities >= 1 |
|
window_length = FREQ_INTERVAL * n_fft_d4c // sr * 2 + 1 |
|
window = nuttall(window_length, device=x.device) |
|
freq_axis = torch.arange(n_fft_spec // 2 + 1, device=x.device) * (sr / n_fft_spec) |
|
|
|
coarse_aperiodicity, rms = d4c_general_body( |
|
x[..., None, :], |
|
sr, |
|
f0.clamp(min=FLOOR_F0_D4C), |
|
FREQ_INTERVAL, |
|
temporal_positions, |
|
n_fft_d4c, |
|
n_aperiodicities, |
|
window, |
|
) |
|
if coarse_only: |
|
return coarse_aperiodicity, rms |
|
|
|
even_coarse_axis = ( |
|
torch.arange(n_aperiodicities + 3, device=x.device) * FREQ_INTERVAL |
|
) |
|
assert even_coarse_axis[-2] <= sr // 2 < even_coarse_axis[-1], sr |
|
coarse_axis_low = ( |
|
torch.arange(n_aperiodicities + 1, dtype=torch.float, device=x.device) |
|
* FREQ_INTERVAL |
|
) |
|
aperiodicity_low = interp( |
|
coarse_axis_low, |
|
F.pad(coarse_aperiodicity, (1, 0), value=-60.0), |
|
freq_axis[freq_axis < n_aperiodicities * FREQ_INTERVAL], |
|
) |
|
coarse_axis_high = torch.tensor( |
|
[n_aperiodicities * FREQ_INTERVAL, sr * 0.5], device=x.device |
|
) |
|
aperiodicity_high = interp( |
|
coarse_axis_high, |
|
F.pad(coarse_aperiodicity[..., -1:], (0, 1), value=-1e-12), |
|
freq_axis[freq_axis >= n_aperiodicities * FREQ_INTERVAL], |
|
) |
|
aperiodicity = torch.cat([aperiodicity_low, aperiodicity_high], -1) |
|
aperiodicity = 10.0 ** (aperiodicity / 20.0) |
|
vuv = d4c_love_train(x[..., None, :], sr, f0, temporal_positions, threshold) |
|
aperiodicity = torch.where(vuv[..., None], aperiodicity, 1 - 1e-12) |
|
|
|
return aperiodicity, coarse_aperiodicity |
|
|
|
|
|
class Vocoder(nn.Module): |
|
def __init__( |
|
self, |
|
channels: int, |
|
hop_length: int = 240, |
|
n_pre_blocks: int = 4, |
|
out_sample_rate: float = 24000.0, |
|
): |
|
super().__init__() |
|
self.hop_length = hop_length |
|
self.out_sample_rate = out_sample_rate |
|
|
|
self.prenet = ConvNeXtStack( |
|
in_channels=channels, |
|
channels=channels, |
|
intermediate_channels=channels * 3, |
|
n_blocks=n_pre_blocks, |
|
delay=2, |
|
embed_kernel_size=7, |
|
kernel_size=33, |
|
enable_scaling=True, |
|
) |
|
self.ir_generator = ConvNeXtStack( |
|
in_channels=channels, |
|
channels=channels, |
|
intermediate_channels=channels * 3, |
|
n_blocks=2, |
|
delay=0, |
|
embed_kernel_size=3, |
|
kernel_size=33, |
|
use_weight_standardization=True, |
|
enable_scaling=True, |
|
) |
|
self.ir_generator_post = WSConv1d(channels, 512, 1) |
|
self.register_buffer("ir_scale", torch.tensor(1.0)) |
|
self.ir_window = nn.Parameter(torch.ones(512)) |
|
self.aperiodicity_generator = ConvNeXtStack( |
|
in_channels=channels, |
|
channels=channels, |
|
intermediate_channels=channels * 3, |
|
n_blocks=1, |
|
delay=0, |
|
embed_kernel_size=3, |
|
kernel_size=33, |
|
use_weight_standardization=True, |
|
enable_scaling=True, |
|
) |
|
self.aperiodicity_generator_post = WSConv1d(channels, hop_length, 1, bias=False) |
|
self.register_buffer("aperiodicity_scale", torch.tensor(0.005)) |
|
self.post_filter_generator = ConvNeXtStack( |
|
in_channels=channels, |
|
channels=channels, |
|
intermediate_channels=channels * 3, |
|
n_blocks=1, |
|
delay=0, |
|
embed_kernel_size=3, |
|
kernel_size=33, |
|
use_weight_standardization=True, |
|
enable_scaling=True, |
|
) |
|
self.post_filter_generator_post = WSConv1d(channels, 512, 1, bias=False) |
|
self.register_buffer("post_filter_scale", torch.tensor(0.01)) |
|
|
|
def forward( |
|
self, x: torch.Tensor, pitch: torch.Tensor |
|
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: |
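        """Render waveform from frame-level features ``x`` and per-frame ``pitch``.

        The prenet output drives three branches: an impulse-response generator
        (periodic part, rendered pitch-synchronously via ``overlap_add``), an
        aperiodicity generator (noise part via ``generate_noise``), and a short
        per-frame post filter applied to both parts before they are summed.
        """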
|
|
|
|
|
batch_size, _, length = x.size() |
|
|
|
x = self.prenet(x) |
|
ir = self.ir_generator(x) |
|
ir = F.silu(ir, inplace=True) |
|
|
|
ir = self.ir_generator_post(ir) |
|
ir *= self.ir_scale |
|
ir_amp = ir[:, : ir.size(1) // 2 + 1, :].exp() |
|
ir_phase = F.pad(ir[:, ir.size(1) // 2 + 1 :, :], (0, 0, 1, 1)) |
|
ir_phase[:, 1::2, :] += math.pi |
|
|
|
|
|
|
|
|
|
pitch = torch.repeat_interleave(pitch, self.hop_length, dim=1) |
|
|
|
|
|
periodic_signal = overlap_add( |
|
ir_amp, |
|
ir_phase, |
|
self.ir_window, |
|
pitch, |
|
self.hop_length, |
|
delay=0, |
|
sr=self.out_sample_rate, |
|
) |
|
|
|
aperiodicity = self.aperiodicity_generator(x) |
|
aperiodicity = F.silu(aperiodicity, inplace=True) |
|
|
|
aperiodicity = self.aperiodicity_generator_post(aperiodicity) |
|
aperiodicity *= self.aperiodicity_scale |
|
|
|
aperiodic_signal, noise_excitation = generate_noise(aperiodicity, delay=0) |
|
|
|
post_filter = self.post_filter_generator(x) |
|
post_filter = F.silu(post_filter, inplace=True) |
|
|
|
post_filter = self.post_filter_generator_post(post_filter) |
|
post_filter *= self.post_filter_scale |
|
post_filter[:, 0, :] += 1.0 |
|
|
|
post_filter = post_filter.transpose(1, 2) |
|
with torch.amp.autocast("cuda", enabled=False): |
|
periodic_signal = periodic_signal.float() |
|
aperiodic_signal = aperiodic_signal.float() |
|
post_filter = post_filter.float() |
|
post_filter = torch.fft.rfft(post_filter, n=768) |
|
|
|
|
|
periodic_signal = torch.fft.irfft( |
|
torch.fft.rfft( |
|
periodic_signal.view(batch_size, length, self.hop_length), n=768 |
|
) |
|
* post_filter, |
|
n=768, |
|
) |
|
aperiodic_signal = torch.fft.irfft( |
|
torch.fft.rfft( |
|
aperiodic_signal.view(batch_size, length, self.hop_length), n=768 |
|
) |
|
* post_filter, |
|
n=768, |
|
) |
|
periodic_signal = F.fold( |
|
periodic_signal.transpose(1, 2), |
|
(1, (length - 1) * self.hop_length + 768), |
|
(1, 768), |
|
stride=(1, self.hop_length), |
|
).squeeze_((1, 2)) |
|
aperiodic_signal = F.fold( |
|
aperiodic_signal.transpose(1, 2), |
|
(1, (length - 1) * self.hop_length + 768), |
|
(1, 768), |
|
stride=(1, self.hop_length), |
|
).squeeze_((1, 2)) |
|
periodic_signal = periodic_signal[:, 120 : 120 + length * self.hop_length] |
|
aperiodic_signal = aperiodic_signal[:, 120 : 120 + length * self.hop_length] |
|
noise_excitation = noise_excitation[:, 120:] |
|
|
|
|
|
|
|
|
|
y_g_hat = (periodic_signal + aperiodic_signal)[:, None, :] |
|
|
|
y_g_hat = GradientEqualizerFunction.apply(y_g_hat) |
|
|
|
return y_g_hat, { |
|
"periodic_signal": periodic_signal.detach(), |
|
"aperiodic_signal": aperiodic_signal.detach(), |
|
"noise_excitation": noise_excitation.detach(), |
|
} |
|
|
|
def merge_weights(self): |
|
self.prenet.merge_weights() |
|
self.ir_generator.merge_weights() |
|
self.ir_generator_post.merge_weights() |
|
self.aperiodicity_generator.merge_weights() |
|
self.aperiodicity_generator_post.merge_weights() |
|
self.ir_generator_post.weight.data *= self.ir_scale |
|
self.ir_generator_post.bias.data *= self.ir_scale |
|
self.ir_scale.fill_(1.0) |
|
self.aperiodicity_generator_post.weight.data *= self.aperiodicity_scale |
|
self.aperiodicity_scale.fill_(1.0) |
|
self.post_filter_generator.merge_weights() |
|
self.post_filter_generator_post.merge_weights() |
|
self.post_filter_generator_post.weight.data *= self.post_filter_scale |
|
self.post_filter_scale.fill_(1.0) |
|
|
|
def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): |
|
if isinstance(f, (str, bytes, os.PathLike)): |
|
with open(f, "wb") as f: |
|
self.dump(f) |
|
return |
|
if not hasattr(f, "write"): |
|
raise TypeError |
|
|
|
dump_layer(self.prenet, f) |
|
dump_layer(self.ir_generator, f) |
|
dump_layer(self.ir_generator_post, f) |
|
dump_layer(self.ir_window, f) |
|
dump_layer(self.aperiodicity_generator, f) |
|
dump_layer(self.aperiodicity_generator_post, f) |
|
dump_layer(self.post_filter_generator, f) |
|
dump_layer(self.post_filter_generator_post, f) |
|
|
|
|
|
def compute_loudness( |
|
x: torch.Tensor, sr: int, win_lengths: list[int] |
|
) -> list[torch.Tensor]: |
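    """Return log energy of ``x`` after a K-weighting-style pre-filter
    (treble shelf + high-pass, applied by FFT convolution), computed with a
    Hann window at each requested window length."""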
|
|
|
assert x.ndim == 2 |
|
n_fft = 2048 |
|
chunk_length = n_fft // 2 |
|
n_taps = chunk_length + 1 |
|
|
|
results = [] |
|
with torch.amp.autocast("cuda", enabled=False): |
|
if not hasattr(compute_loudness, "filter"): |
|
compute_loudness.filter = {} |
|
if sr not in compute_loudness.filter: |
|
ir = torch.zeros(n_taps, device=x.device, dtype=torch.double) |
|
ir[0] = 0.5 |
|
ir = torchaudio.functional.treble_biquad( |
|
ir, sr, 4.0, 1500.0, 1.0 / math.sqrt(2) |
|
) |
|
ir = torchaudio.functional.highpass_biquad(ir, sr, 38.0, 0.5) |
|
ir *= 2.0 |
|
compute_loudness.filter[sr] = torch.fft.rfft(ir, n=n_fft).to( |
|
torch.complex64 |
|
) |
|
|
|
x = x.float() |
|
wav_length = x.size(-1) |
|
if wav_length % chunk_length != 0: |
|
x = F.pad(x, (0, chunk_length - wav_length % chunk_length)) |
|
padded_wav_length = x.size(-1) |
|
x = x.view(x.size()[:-1] + (padded_wav_length // chunk_length, chunk_length)) |
|
x = torch.fft.irfft( |
|
torch.fft.rfft(x, n=n_fft) * compute_loudness.filter[sr], |
|
n=n_fft, |
|
) |
|
x = F.fold( |
|
x.transpose(-2, -1), |
|
(1, padded_wav_length + chunk_length), |
|
(1, n_fft), |
|
stride=(1, chunk_length), |
|
).squeeze_((-3, -2))[..., :wav_length] |
|
|
|
x.square_() |
|
for win_length in win_lengths: |
|
hop_length = win_length // 4 |
|
|
|
energy = ( |
|
x.unfold(-1, win_length, hop_length) |
|
.matmul(torch.hann_window(win_length, device=x.device)) |
|
.add_(win_length / 4.0 * 1e-5) |
|
.log10_() |
|
) |
|
|
|
results.append(energy) |
|
return results |
|
|
|
|
|
def slice_segments( |
|
x: torch.Tensor, start_indices: torch.Tensor, segment_length: int |
|
) -> torch.Tensor: |
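    """Gather a fixed-length time segment from each batch element, starting at
    the per-element ``start_indices``."""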
|
batch_size, channels, _ = x.size() |
|
|
|
indices = start_indices[:, None, None] + torch.arange( |
|
segment_length, device=start_indices.device |
|
) |
|
|
|
indices = indices.expand(batch_size, channels, segment_length) |
|
return x.gather(2, indices) |
|
|
|
|
|
class ConverterNetwork(nn.Module): |
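    """Voice conversion model: frozen PhoneExtractor and PitchEstimator provide
    content and pitch features, which are embedded together with target speaker
    and formant-shift IDs and rendered to 24 kHz audio by the Vocoder."""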
|
def __init__( |
|
self, |
|
phone_extractor: PhoneExtractor, |
|
pitch_estimator: PitchEstimator, |
|
n_speakers: int, |
|
hidden_channels: int, |
|
): |
|
super().__init__() |
|
self.frozen_modules = { |
|
"phone_extractor": phone_extractor.eval().requires_grad_(False), |
|
"pitch_estimator": pitch_estimator.eval().requires_grad_(False), |
|
} |
|
self.out_sample_rate = out_sample_rate = 24000 |
|
self.embed_phone = nn.Conv1d(256, hidden_channels, 1) |
|
self.embed_phone.weight.data.normal_(0.0, math.sqrt(2.0 / (256 * 5))) |
|
self.embed_phone.bias.data.zero_() |
|
self.embed_quantized_pitch = nn.Embedding(384, hidden_channels) |
|
phase = ( |
|
torch.arange(384, dtype=torch.float)[:, None] |
|
* ( |
|
torch.arange(0, hidden_channels, 2, dtype=torch.float) |
|
* (-math.log(10000.0) / hidden_channels) |
|
).exp_() |
|
) |
|
self.embed_quantized_pitch.weight.data[:, 0::2] = phase.sin() |
|
self.embed_quantized_pitch.weight.data[:, 1::2] = phase.cos_() |
|
self.embed_quantized_pitch.weight.data *= math.sqrt(4.0 / 5.0) |
|
self.embed_quantized_pitch.weight.requires_grad_(False) |
|
self.embed_pitch_features = nn.Conv1d(4, hidden_channels, 1) |
|
self.embed_pitch_features.weight.data.normal_(0.0, math.sqrt(2.0 / (4 * 5))) |
|
self.embed_pitch_features.bias.data.zero_() |
|
self.embed_speaker = nn.Embedding(n_speakers, hidden_channels) |
|
self.embed_speaker.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0)) |
|
self.embed_formant_shift = nn.Embedding(9, hidden_channels) |
|
self.embed_formant_shift.weight.data.normal_(0.0, math.sqrt(2.0 / 5.0)) |
|
self.vocoder = Vocoder( |
|
channels=hidden_channels, |
|
hop_length=out_sample_rate // 100, |
|
n_pre_blocks=4, |
|
out_sample_rate=out_sample_rate, |
|
) |
|
self.melspectrograms = nn.ModuleList() |
|
for win_length, n_mels in [ |
|
(32, 5), |
|
(64, 10), |
|
(128, 20), |
|
(256, 40), |
|
(512, 80), |
|
(1024, 160), |
|
(2048, 320), |
|
]: |
|
self.melspectrograms.append( |
|
torchaudio.transforms.MelSpectrogram( |
|
sample_rate=out_sample_rate, |
|
n_fft=win_length, |
|
win_length=win_length, |
|
hop_length=win_length // 4, |
|
n_mels=n_mels, |
|
power=2, |
|
norm="slaney", |
|
mel_scale="slaney", |
|
) |
|
) |
|
|
|
def _get_resampler( |
|
self, orig_freq, new_freq, device, cache={} |
|
) -> torchaudio.transforms.Resample: |
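        # Note: the mutable default ``cache`` is intentional; it memoizes
        # Resample modules keyed by (orig_freq, new_freq) across calls.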
|
key = orig_freq, new_freq |
|
if key in cache: |
|
return cache[key] |
|
resampler = torchaudio.transforms.Resample(orig_freq, new_freq).to( |
|
device, non_blocking=True |
|
) |
|
cache[key] = resampler |
|
return resampler |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
target_speaker_id: torch.Tensor, |
|
formant_shift_semitone: torch.Tensor, |
|
pitch_shift_semitone: Optional[torch.Tensor] = None, |
|
slice_start_indices: Optional[torch.Tensor] = None, |
|
slice_segment_length: Optional[int] = None, |
|
return_stats: bool = False, |
|
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]: |
|
|
|
|
|
|
|
|
|
|
|
|
|
batch_size, _, _ = x.size() |
|
|
|
with torch.inference_mode(): |
|
phone_extractor: PhoneExtractor = self.frozen_modules["phone_extractor"] |
|
pitch_estimator: PitchEstimator = self.frozen_modules["pitch_estimator"] |
|
|
|
phone = phone_extractor.units(x).transpose(1, 2) |
|
|
|
pitch, energy = pitch_estimator(x) |
|
|
|
if self.training: |
|
|
|
weights = pitch.softmax(1)[:, 1:, :].mean(2) |
|
|
|
mean_pitch = ( |
|
weights * torch.arange(1, 384, device=weights.device) |
|
).sum(1) / weights.sum(1) |
|
mean_pitch = mean_pitch.round_().long() |
|
target_pitch = torch.randint_like(mean_pitch, 64, 257) |
|
shift = target_pitch - mean_pitch |
|
shift_ratio = ( |
|
2.0 ** (shift.float() / pitch_estimator.bins_per_octave) |
|
).tolist() |
|
shift = [] |
|
interval_length = 100 |
|
interval_zeros = torch.zeros( |
|
(1, 1, interval_length * 160), device=x.device |
|
) |
|
concatenated_shifted_x = [] |
|
offsets = [0] |
|
torch.backends.cudnn.benchmark = False |
|
for i in range(batch_size): |
|
shift_ratio_i = shift_ratio[i] |
|
shift_ratio_fraction_i = Fraction.from_float( |
|
shift_ratio_i |
|
).limit_denominator(30) |
|
shift_numer_i = shift_ratio_fraction_i.numerator |
|
shift_denom_i = shift_ratio_fraction_i.denominator |
|
shift_ratio_i = shift_numer_i / shift_denom_i |
|
shift_i = int( |
|
round( |
|
math.log2(shift_ratio_i) * pitch_estimator.bins_per_octave |
|
) |
|
) |
|
shift.append(shift_i) |
|
shift_ratio[i] = shift_ratio_i |
|
|
|
with torch.amp.autocast("cuda", enabled=False): |
|
shifted_x_i = self._get_resampler( |
|
shift_numer_i, shift_denom_i, x.device |
|
)(x[i])[None] |
|
if shifted_x_i.size(2) % 160 != 0: |
|
shifted_x_i = F.pad( |
|
shifted_x_i, |
|
(0, 160 - shifted_x_i.size(2) % 160), |
|
mode="reflect", |
|
) |
|
assert shifted_x_i.size(2) % 160 == 0 |
|
offsets.append( |
|
offsets[-1] + interval_length + shifted_x_i.size(2) // 160 |
|
) |
|
concatenated_shifted_x.extend([interval_zeros, shifted_x_i]) |
|
if offsets[-1] % 256 != 0: |
|
|
|
|
|
concatenated_shifted_x.append( |
|
torch.zeros( |
|
(1, 1, (256 - offsets[-1] % 256) * 160), device=x.device |
|
) |
|
) |
|
|
|
concatenated_shifted_x = torch.cat(concatenated_shifted_x, dim=2) |
|
assert concatenated_shifted_x.size(2) % (256 * 160) == 0 |
|
|
|
concatenated_pitch, concatenated_energy = pitch_estimator( |
|
concatenated_shifted_x |
|
) |
|
for i in range(batch_size): |
|
shift_i = shift[i] |
|
shift_ratio_i = shift_ratio[i] |
|
left = offsets[i] + interval_length |
|
right = offsets[i + 1] |
|
pitch_i = concatenated_pitch[:, :, left:right] |
|
energy_i = concatenated_energy[:, :, left:right] |
|
pitch_i = F.interpolate( |
|
pitch_i, |
|
scale_factor=shift_ratio_i, |
|
mode="linear", |
|
align_corners=False, |
|
) |
|
energy_i = F.interpolate( |
|
energy_i, |
|
scale_factor=shift_ratio_i, |
|
mode="linear", |
|
align_corners=False, |
|
) |
|
assert pitch_i.size(2) == energy_i.size(2) |
|
assert abs(pitch_i.size(2) - pitch.size(2)) <= 10 |
|
length = min(pitch_i.size(2), pitch.size(2)) |
|
|
|
if shift_i > 0: |
|
pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length] |
|
pitch[i : i + 1, 1:-shift_i, :length] = pitch_i[ |
|
:, 1 + shift_i :, :length |
|
] |
|
pitch[i : i + 1, -shift_i:, :length] = -10.0 |
|
elif shift_i < 0: |
|
pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length] |
|
pitch[i : i + 1, 1 : 1 - shift_i, :length] = -10.0 |
|
pitch[i : i + 1, 1 - shift_i :, :length] = pitch_i[ |
|
:, 1:shift_i, :length |
|
] |
|
energy[i : i + 1, :, :length] = energy_i[:, :, :length] |
|
torch.backends.cudnn.benchmark = True |
|
|
|
|
|
quantized_pitch, pitch_features = pitch_estimator.sample_pitch( |
|
pitch, return_features=True |
|
) |
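# An explicit pitch_shift_semitone is applied directly in quantized-pitch bins
# (bins_per_octave / 12 bins per semitone); bin 0 is left untouched, presumably because
# it marks unvoiced frames. The quantized bin is then converted to Hz with a 55 Hz base:
#   f0 = 55.0 * 2.0 ** (bin / bins_per_octave)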
|
if pitch_shift_semitone is not None: |
|
quantized_pitch = torch.where( |
|
quantized_pitch == 0, |
|
quantized_pitch, |
|
( |
|
quantized_pitch |
|
+ ( |
|
pitch_shift_semitone[:, None] |
|
* (pitch_estimator.bins_per_octave / 12.0) |
|
) |
|
.round_() |
|
.long() |
|
).clamp_(1, 383), |
|
) |
|
pitch = 55.0 * 2.0 ** ( |
|
quantized_pitch.float() / pitch_estimator.bins_per_octave |
|
) |
|
|
|
|
|
|
|
energy = F.pad(energy[:, :, :-1], (1, 0), mode="reflect") |
|
quantized_pitch = F.pad(quantized_pitch[:, :-2], (2, 0), mode="reflect") |
|
pitch_features = F.pad(pitch_features[:, :, :-2], (2, 0), mode="reflect") |
|
|
|
pitch_features = torch.cat([energy, pitch_features], dim=1) |
|
formant_shift_indices = ( |
|
((formant_shift_semitone + 2.0) * 2.0).round_().long() |
|
) |
|
|
|
phone = phone.clone() |
|
quantized_pitch = quantized_pitch.clone() |
|
pitch_features = pitch_features.clone() |
|
formant_shift_indices = formant_shift_indices.clone() |
|
pitch = pitch.clone() |
|
|
|
|
|
x = ( |
|
self.embed_phone(phone) |
|
+ self.embed_quantized_pitch(quantized_pitch).transpose(1, 2) |
|
+ self.embed_pitch_features(pitch_features) |
|
+ ( |
|
self.embed_speaker(target_speaker_id)[:, :, None] |
|
+ self.embed_formant_shift(formant_shift_indices)[:, :, None] |
|
) |
|
) |
|
if slice_start_indices is not None: |
|
assert slice_segment_length is not None |
|
|
|
x = slice_segments(x, slice_start_indices, slice_segment_length) |
|
x = F.silu(x, inplace=True) |
|
|
|
y_g_hat, stats = self.vocoder(x, pitch) |
|
stats["pitch"] = pitch |
|
if return_stats: |
|
return y_g_hat, stats |
|
else: |
|
return y_g_hat |
|
|
|
def _normalize_melsp(self, x): |
|
return x.clamp(min=1e-10).log_().mul_(0.5) |
|
|
|
def forward_and_compute_loss( |
|
self, |
|
noisy_wavs_16k: torch.Tensor, |
|
target_speaker_id: torch.Tensor, |
|
formant_shift_semitone: torch.Tensor, |
|
slice_start_indices: torch.Tensor, |
|
slice_segment_length: int, |
|
y_all: torch.Tensor, |
|
enable_loss_ap: bool = False, |
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float]]:
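# Generator training objective (summary of the code below): synthesize the full
# utterance, compute a multi-resolution mel L1 loss with an aperiodicity-based
# compensation term, optionally compute a D4C coarse-aperiodicity loss, and finally
# slice matching segments of y and y_hat for the discriminator losses computed elsewhere.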
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
stats = {} |
|
loss_mel = 0.0 |
|
|
|
|
|
y_hat_all, intermediates = self( |
|
noisy_wavs_16k, |
|
target_speaker_id, |
|
formant_shift_semitone, |
|
return_stats=True, |
|
) |
|
|
|
with torch.amp.autocast("cuda", enabled=False): |
|
periodic_signal = intermediates["periodic_signal"].float() |
|
aperiodic_signal = intermediates["aperiodic_signal"].float() |
|
noise_excitation = intermediates["noise_excitation"].float() |
|
periodic_signal = periodic_signal[:, : noise_excitation.size(1)] |
|
aperiodic_signal = aperiodic_signal[:, : noise_excitation.size(1)] |
|
y_hat_all = y_hat_all.float() |
|
y_hat_all_truncated = y_hat_all.squeeze(1)[:, : periodic_signal.size(1)] |
|
y_all_truncated = y_all.squeeze(1)[:, : periodic_signal.size(1)] |
|
|
|
for melspectrogram in self.melspectrograms: |
|
melsp_periodic_signal = melspectrogram(periodic_signal) |
|
melsp_aperiodic_signal = melspectrogram(aperiodic_signal) |
|
melsp_noise_excitation = melspectrogram(noise_excitation) |
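# reference_melsp below appears to be the expected mel power of an idealized noise
# excitation under a Hann window (mean(hann^2) = 3/8 times the window length, with the
# remaining constants presumably matching the excitation's variance and normalization).
# It is used to rescale the aperiodic portion of the predicted spectrum relative to the
# actual noise excitation, so the mel loss treats the noise component consistently.
# This interpretation is an assumption based on the constants, not documented behavior.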
|
|
|
|
|
|
|
|
|
reference_melsp = melspectrogram.mel_scale( |
|
torch.full( |
|
(1, melspectrogram.n_fft // 2 + 1, 1), |
|
(1 / 6) * (3 / 8) * 0.5 * melspectrogram.win_length, |
|
device=noisy_wavs_16k.device, |
|
) |
|
) |
|
aperiodic_ratio = melsp_aperiodic_signal / ( |
|
melsp_periodic_signal + melsp_aperiodic_signal + 1e-5 |
|
) |
|
compensation_ratio = reference_melsp / (melsp_noise_excitation + 1e-5) |
|
|
|
melsp_y_hat = melspectrogram(y_hat_all_truncated) |
|
melsp_y_hat = melsp_y_hat * ( |
|
(1.0 - aperiodic_ratio) + aperiodic_ratio * compensation_ratio |
|
) |
|
y_hat_mel = self._normalize_melsp(melsp_y_hat) |
|
|
|
y_mel = self._normalize_melsp(melspectrogram(y_all_truncated)) |
|
loss_mel_i = F.l1_loss(y_hat_mel, y_mel) |
|
loss_mel += loss_mel_i |
|
stats[ |
|
f"loss_mel_{melspectrogram.win_length}_{melspectrogram.n_mels}" |
|
] = loss_mel_i.item() |
|
|
|
loss_mel /= len(self.melspectrograms) |
|
|
|
if enable_loss_ap: |
|
t = ( |
|
torch.arange(intermediates["pitch"].size(1), device=y_all.device) |
|
* 0.01 |
|
) |
|
y_coarse_aperiodicity, y_rms = d4c( |
|
y_all.squeeze(1), |
|
intermediates["pitch"], |
|
t, |
|
self.vocoder.out_sample_rate, |
|
coarse_only=True, |
|
) |
|
y_coarse_aperiodicity = 10.0 ** (y_coarse_aperiodicity / 10.0) |
|
y_hat_coarse_aperiodicity, y_hat_rms = d4c( |
|
y_hat_all.squeeze(1), |
|
intermediates["pitch"], |
|
t, |
|
self.vocoder.out_sample_rate, |
|
coarse_only=True, |
|
) |
|
y_hat_coarse_aperiodicity = 10.0 ** (y_hat_coarse_aperiodicity / 10.0) |
|
rms = torch.maximum(y_rms, y_hat_rms) |
|
loss_ap = F.mse_loss( |
|
y_hat_coarse_aperiodicity, y_coarse_aperiodicity, reduction="none" |
|
) |
|
loss_ap *= (rms / (rms + 1e-3))[:, :, None] |
|
loss_ap = loss_ap.mean() |
|
else: |
|
loss_ap = torch.tensor(0.0) |
|
|
|
|
|
y_hat = slice_segments( |
|
y_hat_all, slice_start_indices * 240, slice_segment_length * 240 |
|
) |
|
|
|
y = slice_segments(y_all, slice_start_indices * 240, slice_segment_length * 240) |
|
return y, y_hat, y_hat_all, loss_mel, loss_ap, stats |
|
|
|
def merge_weights(self): |
|
self.vocoder.merge_weights() |
|
|
|
def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]): |
|
if isinstance(f, (str, bytes, os.PathLike)): |
|
with open(f, "wb") as f: |
|
self.dump(f) |
|
return |
|
if not hasattr(f, "write"): |
|
raise TypeError |
|
|
|
dump_layer(self.embed_phone, f) |
|
dump_layer(self.embed_quantized_pitch, f) |
|
dump_layer(self.embed_pitch_features, f) |
|
dump_layer(self.vocoder, f) |
|
|
|
|
|
|
|
|
|
|
|
def _normalize(tensor: torch.Tensor, dim: int) -> torch.Tensor: |
|
denom = tensor.norm(p=2.0, dim=dim, keepdim=True).clamp_min(1e-6) |
|
return tensor / denom |
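# SANConv2d appears to follow the SAN (Slicing Adversarial Network) formulation for the
# discriminator's last layer: the convolution weight is split into a normalized direction
# and a per-channel scale, and during SAN training two outputs are produced, one with the
# direction detached ("fun") and one with the input detached ("dir"), so the two parts
# receive different terms in the discriminator objective.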
|
|
|
|
|
class SANConv2d(nn.Conv2d): |
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
kernel_size: int, |
|
stride: int = 1, |
|
padding: int = 0, |
|
dilation: int = 1, |
|
bias: bool = True, |
|
padding_mode="zeros", |
|
device=None, |
|
dtype=None, |
|
): |
|
super().__init__( |
|
in_channels, |
|
out_channels, |
|
kernel_size, |
|
stride, |
|
padding=padding, |
|
dilation=dilation, |
|
groups=1, |
|
bias=bias, |
|
padding_mode=padding_mode, |
|
device=device, |
|
dtype=dtype, |
|
) |
|
scale = self.weight.norm(p=2.0, dim=[1, 2, 3], keepdim=True).clamp_min(1e-6) |
|
self.weight = nn.parameter.Parameter(self.weight / scale.expand_as(self.weight)) |
|
self.scale = nn.parameter.Parameter(scale.view(out_channels)) |
|
if bias: |
|
self.bias = nn.parameter.Parameter( |
|
torch.zeros(in_channels, device=device, dtype=dtype) |
|
) |
|
else: |
|
self.register_parameter("bias", None) |
|
|
|
def forward( |
|
self, input: torch.Tensor, flg_san_train: bool = False |
|
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: |
|
if self.bias is not None: |
|
input = input + self.bias.view(self.in_channels, 1, 1) |
|
normalized_weight = self._get_normalized_weight() |
|
scale = self.scale.view(self.out_channels, 1, 1) |
|
if flg_san_train: |
|
out_fun = F.conv2d( |
|
input, |
|
normalized_weight.detach(), |
|
None, |
|
self.stride, |
|
self.padding, |
|
self.dilation, |
|
self.groups, |
|
) |
|
out_dir = F.conv2d( |
|
input.detach(), |
|
normalized_weight, |
|
None, |
|
self.stride, |
|
self.padding, |
|
self.dilation, |
|
self.groups, |
|
) |
|
out = out_fun * scale, out_dir * scale.detach() |
|
else: |
|
out = F.conv2d( |
|
input, |
|
normalized_weight, |
|
None, |
|
self.stride, |
|
self.padding, |
|
self.dilation, |
|
self.groups, |
|
) |
|
out = out * scale |
|
return out |
|
|
|
@torch.no_grad() |
|
def normalize_weight(self): |
|
self.weight.data = self._get_normalized_weight() |
|
|
|
def _get_normalized_weight(self) -> torch.Tensor: |
|
return _normalize(self.weight, dim=[1, 2, 3]) |
|
|
|
|
|
def get_padding(kernel_size: int, dilation: int = 1) -> int: |
|
return (kernel_size * dilation - dilation) // 2 |
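# DiscriminatorP is a period discriminator in the style of HiFi-GAN's multi-period
# discriminator: the waveform is padded to a multiple of `period` and reshaped to
# (batch, 1, time // period, period) so the 2D convolutions see period-aligned samples.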
|
|
|
|
|
class DiscriminatorP(nn.Module): |
|
def __init__( |
|
self, period: int, kernel_size: int = 5, stride: int = 3, san: bool = False |
|
): |
|
super().__init__() |
|
self.period = period |
|
self.san = san |
|
|
|
self.convs = nn.ModuleList([ |
|
weight_norm(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))), |
|
weight_norm(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))), |
|
weight_norm(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))), |
|
weight_norm(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), (get_padding(kernel_size, 1), 0))), |
|
weight_norm(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, (get_padding(kernel_size, 1), 0))), |
|
]) |
|
|
|
if san: |
|
self.conv_post = SANConv2d(1024, 1, (3, 1), 1, (1, 0)) |
|
else: |
|
self.conv_post = weight_norm(nn.Conv2d(1024, 1, (3, 1), 1, (1, 0))) |
|
|
|
def forward( |
|
self, x: torch.Tensor, flg_san_train: bool = False |
|
) -> tuple[ |
|
Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor] |
|
]: |
|
fmap = [] |
|
|
|
b, c, t = x.shape |
|
if t % self.period != 0: |
|
n_pad = self.period - (t % self.period) |
|
x = F.pad(x, (0, n_pad), "reflect") |
|
t = t + n_pad |
|
x = x.view(b, c, t // self.period, self.period) |
|
|
|
for l in self.convs: |
|
x = l(x) |
|
x = F.silu(x, inplace=True) |
|
fmap.append(x) |
|
if self.san: |
|
x = self.conv_post(x, flg_san_train=flg_san_train) |
|
else: |
|
x = self.conv_post(x) |
|
if flg_san_train: |
|
x_fun, x_dir = x |
|
fmap.append(x_fun) |
|
x_fun = torch.flatten(x_fun, 1, -1) |
|
x_dir = torch.flatten(x_dir, 1, -1) |
|
x = x_fun, x_dir |
|
else: |
|
fmap.append(x) |
|
x = torch.flatten(x, 1, -1) |
|
return x, fmap |
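# DiscriminatorR is a multi-resolution spectrogram discriminator: `resolution` is an
# [n_fft, hop_length, win_length] triple, and the network operates on the linear STFT
# magnitude computed in _spectrogram().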
|
|
|
|
|
class DiscriminatorR(nn.Module): |
|
def __init__(self, resolution: list[int], san: bool = False):
|
super().__init__() |
|
self.resolution = resolution |
|
self.san = san |
|
assert len(self.resolution) == 3 |
|
self.convs = nn.ModuleList( |
|
[ |
|
weight_norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))), |
|
weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))), |
|
weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))), |
|
weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))), |
|
weight_norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), |
|
] |
|
) |
|
if san: |
|
self.conv_post = SANConv2d(32, 1, (3, 3), padding=(1, 1)) |
|
else: |
|
self.conv_post = weight_norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) |
|
|
|
def forward( |
|
self, x: torch.Tensor, flg_san_train: bool = False |
|
) -> tuple[ |
|
Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor] |
|
]: |
|
fmap = [] |
|
|
|
x = self._spectrogram(x).unsqueeze(1) |
|
for l in self.convs: |
|
x = l(x) |
|
x = F.silu(x, inplace=True) |
|
fmap.append(x) |
|
if self.san: |
|
x = self.conv_post(x, flg_san_train=flg_san_train) |
|
else: |
|
x = self.conv_post(x) |
|
if flg_san_train: |
|
x_fun, x_dir = x |
|
fmap.append(x_fun) |
|
x_fun = torch.flatten(x_fun, 1, -1) |
|
x_dir = torch.flatten(x_dir, 1, -1) |
|
x = x_fun, x_dir |
|
else: |
|
fmap.append(x) |
|
x = torch.flatten(x, 1, -1) |
|
|
|
return x, fmap |
|
|
|
def _spectrogram(self, x: torch.Tensor) -> torch.Tensor: |
|
n_fft, hop_length, win_length = self.resolution |
|
x = F.pad( |
|
x, ((n_fft - hop_length) // 2, (n_fft - hop_length) // 2), mode="reflect" |
|
).squeeze(1) |
|
with torch.amp.autocast("cuda", enabled=False): |
|
mag = torch.stft( |
|
x.float(), |
|
n_fft=n_fft, |
|
hop_length=hop_length, |
|
win_length=win_length, |
|
window=torch.ones(win_length, device=x.device), |
|
center=False, |
|
return_complex=True, |
|
).abs() |
|
|
|
return mag |
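# Despite its name, MultiPeriodDiscriminator below combines both kinds of
# sub-discriminators: three DiscriminatorR resolutions and five DiscriminatorP periods.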
|
|
|
|
|
class MultiPeriodDiscriminator(nn.Module): |
|
def __init__(self, san: bool = False): |
|
super().__init__() |
|
resolutions = [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]] |
|
periods = [2, 3, 5, 7, 11] |
|
self.discriminators = nn.ModuleList( |
|
[DiscriminatorR(r, san=san) for r in resolutions] |
|
+ [DiscriminatorP(p, san=san) for p in periods] |
|
) |
|
self.discriminator_names = [f"R_{n}_{h}_{w}" for n, h, w in resolutions] + [ |
|
f"P_{p}" for p in periods |
|
] |
|
self.san = san |
|
|
|
def forward( |
|
self, y: torch.Tensor, y_hat: torch.Tensor, flg_san_train: bool = False |
|
) -> tuple[ |
|
list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]], |
|
list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]], |
|
list[list[torch.Tensor]], |
|
list[list[torch.Tensor]], |
|
]: |
|
batch_size = y.size(0) |
|
concatenated_y_y_hat = torch.cat([y, y_hat]) |
|
y_d_rs = [] |
|
y_d_gs = [] |
|
fmap_rs = [] |
|
fmap_gs = [] |
|
for d in self.discriminators: |
|
if flg_san_train: |
|
(y_d_fun, y_d_dir), fmap = d( |
|
concatenated_y_y_hat, flg_san_train=flg_san_train |
|
) |
|
y_d_r_fun, y_d_g_fun = torch.split(y_d_fun, batch_size) |
|
y_d_r_dir, y_d_g_dir = torch.split(y_d_dir, batch_size) |
|
y_d_r = y_d_r_fun, y_d_r_dir |
|
y_d_g = y_d_g_fun, y_d_g_dir |
|
else: |
|
y_d, fmap = d(concatenated_y_y_hat, flg_san_train=flg_san_train) |
|
y_d_r, y_d_g = torch.split(y_d, batch_size) |
|
fmap_r = [] |
|
fmap_g = [] |
|
for fm in fmap: |
|
fm_r, fm_g = torch.split(fm, batch_size) |
|
fmap_r.append(fm_r) |
|
fmap_g.append(fm_g) |
|
y_d_rs.append(y_d_r) |
|
y_d_gs.append(y_d_g) |
|
fmap_rs.append(fmap_r) |
|
fmap_gs.append(fmap_g) |
|
return y_d_rs, y_d_gs, fmap_rs, fmap_gs |
|
|
|
def forward_and_compute_loss( |
|
self, y: torch.Tensor, y_hat: torch.Tensor |
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, float]]: |
|
y_d_rs, y_d_gs, fmap_rs, fmap_gs = self(y, y_hat, flg_san_train=self.san) |
|
stats = {} |
|
assert len(y_d_gs) == len(y_d_rs) == len(self.discriminators) |
|
with torch.amp.autocast("cuda", enabled=False): |
|
|
|
d_loss = 0.0 |
|
for dr, dg, name in zip(y_d_rs, y_d_gs, self.discriminator_names): |
|
if self.san: |
|
dr_fun, dr_dir = map(lambda x: x.float(), dr) |
|
dg_fun, dg_dir = map(lambda x: x.float(), dg) |
|
r_loss_fun = F.softplus(1.0 - dr_fun).square().mean() |
|
g_loss_fun = F.softplus(dg_fun).square().mean() |
|
r_loss_dir = F.softplus(1.0 - dr_dir).square().mean() |
|
g_loss_dir = -F.softplus(1.0 - dg_dir).square().mean() |
|
r_loss = r_loss_fun + r_loss_dir |
|
g_loss = g_loss_fun + g_loss_dir |
|
else: |
|
dr = dr.float() |
|
dg = dg.float() |
|
r_loss = (1.0 - dr).square().mean() |
|
g_loss = dg.square().mean() |
|
stats[f"{name}_dr_loss"] = r_loss.item() |
|
stats[f"{name}_dg_loss"] = g_loss.item() |
|
d_loss += r_loss + g_loss |
|
|
|
adv_loss = 0.0 |
|
for dg, name in zip(y_d_gs, self.discriminator_names): |
|
dg = dg.float() |
|
if self.san: |
|
g_loss = F.softplus(1.0 - dg).square().mean() |
|
else: |
|
g_loss = (1.0 - dg).square().mean() |
|
stats[f"{name}_gg_loss"] = g_loss.item() |
|
adv_loss += g_loss |
|
|
|
fm_loss = 0.0 |
|
for fr, fg, name in zip(fmap_rs, fmap_gs, self.discriminator_names): |
|
fm_loss_i = 0.0 |
|
for j, (r, g) in enumerate(zip(fr, fg)): |
|
fm_loss_ij = (r.detach().float() - g.float()).abs().mean() |
|
stats[f"~{name}_fm_loss_{j}"] = fm_loss_ij.item() |
|
fm_loss_i += fm_loss_ij |
|
stats[f"{name}_fm_loss"] = fm_loss_i.item() |
|
fm_loss += fm_loss_i |
|
return d_loss, adv_loss, fm_loss, stats |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GradBalancer: |
|
"""Adapted from https://github.com/facebookresearch/encodec/blob/main/encodec/balancer.py""" |
|
|
|
def __init__( |
|
self, |
|
weights: dict[str, float], |
|
rescale_grads: bool = True, |
|
total_norm: float = 1.0, |
|
ema_decay: float = 0.999, |
|
per_batch_item: bool = True, |
|
): |
|
self.weights = weights |
|
self.per_batch_item = per_batch_item |
|
self.total_norm = total_norm |
|
self.ema_decay = ema_decay |
|
self.rescale_grads = rescale_grads |
|
|
|
self.ema_total: dict[str, float] = defaultdict(float) |
|
self.ema_fix: dict[str, float] = defaultdict(float) |
|
|
|
def backward( |
|
self, |
|
losses: dict[str, torch.Tensor], |
|
input: torch.Tensor, |
|
scaler: Optional[torch.amp.GradScaler] = None, |
|
skip_update_ema: bool = False, |
|
) -> dict[str, float]: |
|
stats = {} |
|
if skip_update_ema: |
|
assert len(losses) == len(self.ema_total) |
|
ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()} |
|
else: |
|
|
|
norms = {} |
|
grads = {} |
|
for name, loss in losses.items(): |
|
if scaler is not None: |
|
loss = scaler.scale(loss) |
|
(grad,) = torch.autograd.grad(loss, [input], retain_graph=True) |
|
if not grad.isfinite().all(): |
|
input.backward(grad) |
|
return {} |
|
grad = grad.detach() / (1.0 if scaler is None else scaler.get_scale()) |
|
if self.per_batch_item: |
|
dims = tuple(range(1, grad.dim())) |
|
ema_norm = grad.norm(dim=dims).mean() |
|
else: |
|
ema_norm = grad.norm() |
|
norms[name] = float(ema_norm) |
|
grads[name] = grad |
|
|
|
|
|
for key, value in norms.items(): |
|
self.ema_total[key] = self.ema_total[key] * self.ema_decay + value |
|
self.ema_fix[key] = self.ema_fix[key] * self.ema_decay + 1.0 |
|
ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()} |
|
|
|
|
|
total_ema_norm = sum(ema_norms.values()) |
|
for k, ema_norm in ema_norms.items(): |
|
stats[f"grad_norm_value_{k}"] = ema_norm |
|
stats[f"grad_norm_ratio_{k}"] = ema_norm / (total_ema_norm + 1e-12) |
|
|
|
|
|
if self.rescale_grads: |
|
total_weights = sum([self.weights[k] for k in ema_norms]) |
|
ratios = {k: w / total_weights for k, w in self.weights.items()} |
|
|
|
|
|
loss = 0.0 |
|
for name, ema_norm in ema_norms.items(): |
|
if self.rescale_grads: |
|
scale = ratios[name] * self.total_norm / (ema_norm + 1e-12) |
|
else: |
|
scale = self.weights[name] |
|
loss += (losses if skip_update_ema else grads)[name] * scale |
|
if scaler is not None: |
|
loss = scaler.scale(loss) |
|
if skip_update_ema: |
|
(loss,) = torch.autograd.grad(loss, [input]) |
|
input.backward(loss) |
|
return stats |
|
|
|
def state_dict(self) -> dict[str, dict[str, float]]: |
|
return { |
|
"ema_total": dict(self.ema_total), |
|
"ema_fix": dict(self.ema_fix), |
|
} |
|
|
|
def load_state_dict(self, state_dict): |
|
self.ema_total = defaultdict(float, state_dict["ema_total"]) |
|
self.ema_fix = defaultdict(float, state_dict["ema_fix"]) |
|
|
|
|
|
class QualityTester(nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
self.utmos = torch.hub.load( |
|
"tarepan/SpeechMOS:v1.0.0", "utmos22_strong", trust_repo=True |
|
).eval() |
|
|
|
@torch.inference_mode() |
|
def compute_mos(self, wav: torch.Tensor) -> dict[str, list[float]]: |
|
res = {"utmos": self.utmos(wav, sr=16000).tolist()} |
|
return res |
|
|
|
def test( |
|
self, converted_wav: torch.Tensor, source_wav: torch.Tensor |
|
) -> dict[str, list[float]]: |
|
|
|
res = {} |
|
res.update(self.compute_mos(converted_wav)) |
|
return res |
|
|
|
def test_many( |
|
self, converted_wavs: list[torch.Tensor], source_wavs: list[torch.Tensor] |
|
) -> tuple[dict[str, float], dict[str, list[float]]]: |
|
|
|
results = defaultdict(list) |
|
assert len(converted_wavs) == len(source_wavs) |
|
for converted_wav, source_wav in zip(converted_wavs, source_wavs): |
|
res = self.test(converted_wav, source_wav) |
|
for metric_name, value in res.items(): |
|
results[metric_name].extend(value) |
|
return { |
|
metric_name: sum(values) / len(values) |
|
for metric_name, values in results.items() |
|
}, results |
|
|
|
|
|
def compute_grad_norm( |
|
model: nn.Module, return_stats: bool = False |
|
) -> Union[float, dict[str, float]]: |
|
total_norm = 0.0 |
|
stats = {} |
|
for name, p in model.named_parameters(): |
|
if p.grad is None: |
|
continue |
|
param_norm = p.grad.data.norm().item() |
|
if not math.isfinite(param_norm): |
|
param_norm = p.grad.data.float().norm().item() |
|
total_norm += param_norm * param_norm |
|
if return_stats: |
|
stats[f"grad_norm_{name}"] = param_norm |
|
total_norm = math.sqrt(total_norm) |
|
if return_stats: |
|
return total_norm, stats |
|
else: |
|
return total_norm |
|
|
|
|
|
def compute_mean_f0( |
|
files: list[Path], method: Literal["dio", "harvest"] = "dio" |
|
) -> float: |
|
sum_log_f0 = 0.0 |
|
n_frames = 0 |
|
for file in files: |
|
wav, sr = torchaudio.load(file, backend="soundfile") |
|
if method == "dio": |
|
f0, _ = pyworld.dio(wav.ravel().numpy().astype(np.float64), sr) |
|
elif method == "harvest": |
|
f0, _ = pyworld.harvest(wav.ravel().numpy().astype(np.float64), sr) |
|
else: |
|
raise ValueError(f"Invalid method: {method}") |
|
f0 = f0[f0 > 0] |
|
sum_log_f0 += float(np.log(f0).sum()) |
|
n_frames += len(f0) |
|
if n_frames == 0: |
|
return math.nan |
|
mean_log_f0 = sum_log_f0 / n_frames |
|
return math.exp(mean_log_f0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_resampler( |
|
sr_before: int, sr_after: int, device="cpu", cache={} |
|
) -> torchaudio.transforms.Resample: |
|
if not isinstance(device, str): |
|
device = str(device) |
|
if (sr_before, sr_after, device) not in cache: |
|
cache[(sr_before, sr_after, device)] = torchaudio.transforms.Resample( |
|
sr_before, sr_after |
|
).to(device) |
|
return cache[(sr_before, sr_after, device)] |
|
|
|
|
|
def convolve(signal: torch.Tensor, ir: torch.Tensor) -> torch.Tensor: |
|
n = 1 << (signal.size(-1) + ir.size(-1) - 2).bit_length() |
|
res = torch.fft.irfft(torch.fft.rfft(signal, n=n) * torch.fft.rfft(ir, n=n), n=n) |
|
return res[..., : signal.size(-1)] |
|
|
|
|
|
def random_filter(audio: torch.Tensor) -> torch.Tensor: |
|
assert audio.ndim == 2 |
|
ab = torch.rand(audio.size(0), 6) * 0.75 - 0.375 |
|
a, b = ab[:, :3], ab[:, 3:] |
|
a[:, 0] = 1.0 |
|
b[:, 0] = 1.0 |
|
audio = torchaudio.functional.lfilter(audio, a, b, clamp=False) |
|
return audio |
|
|
|
|
|
def get_noise( |
|
n_samples: int, sample_rate: float, files: list[Union[str, bytes, os.PathLike]] |
|
) -> torch.Tensor: |
|
resample_augmentation_candidates = [0.9, 0.95, 1.0, 1.05, 1.1] |
|
wavs = [] |
|
current_length = 0 |
|
while current_length < n_samples: |
|
idx_files = torch.randint(0, len(files), ()) |
|
file = files[idx_files] |
|
wav, sr = torchaudio.load(file, backend="soundfile") |
|
assert wav.size(0) == 1 |
|
augmented_sample_rate = int( |
|
round( |
|
sample_rate |
|
* resample_augmentation_candidates[ |
|
torch.randint(0, len(resample_augmentation_candidates), ()) |
|
] |
|
) |
|
) |
|
resampler = get_resampler(sr, augmented_sample_rate) |
|
wav = resampler(wav) |
|
wav = random_filter(wav) |
|
wav *= 0.99 / (wav.abs().max() + 1e-5) |
|
wavs.append(wav) |
|
current_length += wav.size(1) |
|
start = torch.randint(0, current_length - n_samples + 1, ()) |
|
wav = torch.cat(wavs, dim=1)[:, start : start + n_samples] |
|
assert wav.size() == (1, n_samples), wav.size() |
|
return wav |
|
|
|
|
|
def get_butterworth_lpf( |
|
cutoff_freq: int, sample_rate: int, cache={} |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
if (cutoff_freq, sample_rate) not in cache: |
|
q = math.sqrt(0.5) |
|
omega = math.tau * cutoff_freq / sample_rate |
|
cos_omega = math.cos(omega) |
|
alpha = math.sin(omega) / (2.0 * q) |
|
b1 = (1.0 - cos_omega) / (1.0 + alpha) |
|
b0 = b1 * 0.5 |
|
a1 = -2.0 * cos_omega / (1.0 + alpha) |
|
a2 = (1.0 - alpha) / (1.0 + alpha) |
|
cache[(cutoff_freq, sample_rate)] = torch.tensor([b0, b1, b0]), torch.tensor( |
|
[1.0, a1, a2] |
|
) |
|
return cache[(cutoff_freq, sample_rate)] |
|
|
|
|
|
def augment_audio( |
|
clean: torch.Tensor, |
|
sample_rate: int, |
|
noise_files: list[Union[str, bytes, os.PathLike]], |
|
ir_files: list[Union[str, bytes, os.PathLike]], |
|
) -> torch.Tensor: |
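# Augmentation pipeline (as implemented below): load random background noise, apply a
# shared random IIR filter to both clean and noise, convolve with a random impulse
# response with probability 0.5, apply a random Butterworth low-pass with probability
# 0.2, restore the clean signal's original RMS, then add the noise at a random SNR
# measured with 4th-power (peak-emphasizing) levels.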
|
|
|
assert clean.size(0) == 1 |
|
n_samples = clean.size(1) |
|
|
|
snr_candidates = [-20, -25, -30, -35, -40, -45] |
|
|
|
original_clean_rms = clean.square().mean().sqrt_() |
|
|
|
|
|
noise = get_noise(n_samples, sample_rate, noise_files) |
|
signals = torch.cat([clean, noise]) |
|
|
|
|
|
signals = random_filter(signals) |
|
|
|
|
|
if torch.rand(()) < 0.5: |
|
ir_file = ir_files[torch.randint(0, len(ir_files), ())] |
|
ir, sr = torchaudio.load(ir_file, backend="soundfile") |
|
assert ir.size() == (2, sr), ir.size() |
|
assert sr == sample_rate, (sr, sample_rate) |
|
signals = convolve(signals, ir) |
|
|
|
|
|
if torch.rand(()) < 0.2: |
|
if signals.abs().max() > 0.8: |
|
signals /= signals.abs().max() * 1.25 |
|
cutoff_freq_candidates = [2000, 3000, 4000, 6000] |
|
cutoff_freq = cutoff_freq_candidates[ |
|
torch.randint(0, len(cutoff_freq_candidates), ()) |
|
] |
|
b, a = get_butterworth_lpf(cutoff_freq, sample_rate) |
|
signals = torchaudio.functional.lfilter(signals, a, b, clamp=False) |
|
|
|
|
|
clean, noise = signals |
|
clean_rms = clean.square().mean().sqrt_() |
|
clean *= original_clean_rms / clean_rms |
|
|
|
|
|
clean_level = clean.square().square_().mean().sqrt_().sqrt_() |
|
noise_level = noise.square().square_().mean().sqrt_().sqrt_() |
|
|
|
snr = snr_candidates[torch.randint(0, len(snr_candidates), ())] |
|
|
|
noisy = clean + noise * (10.0 ** (snr / 20.0) * clean_level / (noise_level + 1e-5)) |
|
return noisy |
|
|
|
|
|
class WavDataset(torch.utils.data.Dataset): |
|
def __init__( |
|
self, |
|
audio_files: list[tuple[Path, int]], |
|
in_sample_rate: int = 16000, |
|
out_sample_rate: int = 24000, |
|
wav_length: int = 4 * 24000, |
|
segment_length: int = 100, |
|
noise_files: Optional[list[Union[str, bytes, os.PathLike]]] = None, |
|
ir_files: Optional[list[Union[str, bytes, os.PathLike]]] = None, |
|
): |
|
self.audio_files = audio_files |
|
self.in_sample_rate = in_sample_rate |
|
self.out_sample_rate = out_sample_rate |
|
self.wav_length = wav_length |
|
self.segment_length = segment_length |
|
self.noise_files = noise_files |
|
self.ir_files = ir_files |
|
|
|
if (noise_files is None) is not (ir_files is None): |
|
raise ValueError("noise_files and ir_files must be both None or not None") |
|
|
|
self.in_hop_length = in_sample_rate // 100 |
|
self.out_hop_length = out_sample_rate // 100 |
|
|
|
def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, int, int]: |
|
file, speaker_id = self.audio_files[index] |
|
clean_wav, sample_rate = torchaudio.load(file, backend="soundfile") |
|
if clean_wav.size(0) != 1: |
|
ch = torch.randint(0, clean_wav.size(0), ()) |
|
clean_wav = clean_wav[ch : ch + 1] |
|
|
|
formant_shift_candidates = [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0] |
|
formant_shift = formant_shift_candidates[ |
|
torch.randint(0, len(formant_shift_candidates), ()).item() |
|
] |
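# The formant-shift target is created by resampling: the audio is resampled by an extra
# factor of 2 ** (formant_shift / 12) while still being treated as out_sample_rate audio,
# which rescales the spectral envelope (a speed-change style augmentation). Inferred
# from the resampler_fraction computation below.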
|
|
|
resampler_fraction = Fraction( |
|
sample_rate / self.out_sample_rate * 2.0 ** (formant_shift / 12.0) |
|
).limit_denominator(300) |
|
clean_wav = get_resampler( |
|
resampler_fraction.numerator, resampler_fraction.denominator |
|
)(clean_wav) |
|
|
|
assert clean_wav.size(0) == 1 |
|
assert clean_wav.size(1) != 0 |
|
|
|
clean_wav = F.pad(clean_wav, (self.wav_length, self.wav_length)) |
|
|
|
if self.noise_files is None: |
|
assert False |
|
noisy_wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)( |
|
clean_wav |
|
) |
|
else: |
|
clean_wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)( |
|
clean_wav |
|
) |
|
noisy_wav_16k = augment_audio( |
|
clean_wav_16k, self.in_sample_rate, self.noise_files, self.ir_files |
|
) |
|
|
|
clean_wav = clean_wav.squeeze_(0) |
|
noisy_wav_16k = noisy_wav_16k.squeeze_(0) |
|
|
|
|
|
amplitude = torch.rand(()).item() * 0.899 + 0.1 |
|
factor = amplitude / clean_wav.abs().max() |
|
clean_wav *= factor |
|
noisy_wav_16k *= factor |
|
while noisy_wav_16k.abs().max() >= 1.0: |
|
clean_wav *= 0.5 |
|
noisy_wav_16k *= 0.5 |
|
|
|
return clean_wav, noisy_wav_16k, speaker_id, formant_shift |
|
|
|
def __len__(self) -> int: |
|
return len(self.audio_files) |
|
|
|
def collate( |
|
self, batch: list[tuple[torch.Tensor, torch.Tensor, int, int]] |
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
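# Collation picks, per item, a random voiced sample as the slice center, converts it to a
# frame index (out_hop_length samples per frame), then crops a wav_length window so that
# the slice lies inside it; the returned slice_starts are frame offsets used later by
# slice_segments on the generator's feature sequence.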
|
assert self.wav_length % self.out_hop_length == 0 |
|
length = self.wav_length // self.out_hop_length |
|
clean_wavs = [] |
|
noisy_wavs = [] |
|
slice_starts = [] |
|
speaker_ids = [] |
|
formant_shifts = [] |
|
for clean_wav, noisy_wav, speaker_id, formant_shift in batch: |
|
|
|
(voiced,) = clean_wav.nonzero(as_tuple=True) |
|
assert voiced.numel() != 0 |
|
center = voiced[torch.randint(0, voiced.numel(), ()).item()].item() |
|
|
|
slice_start = center - self.segment_length * self.out_hop_length // 2 |
|
assert slice_start >= 0 |
|
|
|
r = torch.randint(0, length - self.segment_length + 1, ()).item() |
|
offset = slice_start - r * self.out_hop_length |
|
clean_wavs.append(clean_wav[offset : offset + self.wav_length]) |
|
offset_in_sample_rate = int( |
|
round(offset * self.in_sample_rate / self.out_sample_rate) |
|
) |
|
noisy_wavs.append( |
|
noisy_wav[ |
|
offset_in_sample_rate : offset_in_sample_rate |
|
+ length * self.in_hop_length |
|
] |
|
) |
|
slice_start = r |
|
slice_starts.append(slice_start) |
|
speaker_ids.append(speaker_id) |
|
formant_shifts.append(formant_shift) |
|
clean_wavs = torch.stack(clean_wavs) |
|
noisy_wavs = torch.stack(noisy_wavs) |
|
slice_starts = torch.tensor(slice_starts) |
|
speaker_ids = torch.tensor(speaker_ids) |
|
formant_shifts = torch.tensor(formant_shifts) |
|
return ( |
|
clean_wavs, |
|
noisy_wavs, |
|
slice_starts, |
|
speaker_ids, |
|
formant_shifts, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
AUDIO_FILE_SUFFIXES = { |
|
".wav", |
|
".aif", |
|
".aiff", |
|
".fla", |
|
".flac", |
|
".oga", |
|
".ogg", |
|
".opus", |
|
".mp3", |
|
} |
|
|
|
|
|
def prepare_training(): |
|
|
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
print(f"device={device}") |
|
|
|
torch.backends.cudnn.benchmark = True |
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
|
|
(h, in_wav_dataset_dir, out_dir, resume, skip_training) = ( |
|
prepare_training_configs_for_experiment |
|
if is_notebook() |
|
else prepare_training_configs |
|
)() |
|
|
|
print("config:") |
|
pprint(h) |
|
print() |
|
h = AttrDict(h) |
|
|
|
if not in_wav_dataset_dir.is_dir(): |
|
raise ValueError(f"{in_wav_dataset_dir} is not found.") |
|
if resume: |
|
latest_checkpoint_file = out_dir / "checkpoint_latest.pt" |
|
if not latest_checkpoint_file.is_file(): |
|
raise ValueError(f"{latest_checkpoint_file} is not found.") |
|
else: |
|
if out_dir.is_dir(): |
|
if (out_dir / "checkpoint_latest.pt").is_file(): |
|
raise ValueError( |
|
f"{out_dir / 'checkpoint_latest.pt'} already exists. " |
|
"Please specify a different output directory, or use --resume option." |
|
) |
|
for file in out_dir.iterdir(): |
|
if file.suffix == ".pt": |
|
raise ValueError( |
|
f"{out_dir} already contains model files. " |
|
"Please specify a different output directory." |
|
) |
|
else: |
|
out_dir.mkdir(parents=True) |
|
|
|
in_ir_wav_dir = repo_root() / h.in_ir_wav_dir |
|
in_noise_wav_dir = repo_root() / h.in_noise_wav_dir |
|
in_test_wav_dir = repo_root() / h.in_test_wav_dir |
|
|
|
assert in_wav_dataset_dir.is_dir(), in_wav_dataset_dir |
|
assert out_dir.is_dir(), out_dir |
|
assert in_ir_wav_dir.is_dir(), in_ir_wav_dir |
|
assert in_noise_wav_dir.is_dir(), in_noise_wav_dir |
|
assert in_test_wav_dir.is_dir(), in_test_wav_dir |
|
|
|
|
|
noise_files = sorted( |
|
list(in_noise_wav_dir.rglob("*.wav")) + list(in_noise_wav_dir.rglob("*.flac")) |
|
) |
|
if len(noise_files) == 0: |
|
raise ValueError(f"No audio data found in {in_noise_wav_dir}.") |
|
ir_files = sorted( |
|
list(in_ir_wav_dir.rglob("*.wav")) + list(in_ir_wav_dir.rglob("*.flac")) |
|
) |
|
if len(ir_files) == 0: |
|
raise ValueError(f"No audio data found in {in_ir_wav_dir}.") |
|
|
|
|
|
|
|
def get_training_filelist(in_wav_dataset_dir: Path): |
|
min_data_per_speaker = 1 |
|
speakers: list[str] = [] |
|
training_filelist: list[tuple[Path, int]] = [] |
|
speaker_audio_files: list[list[Path]] = [] |
|
for speaker_dir in sorted(in_wav_dataset_dir.iterdir()): |
|
if not speaker_dir.is_dir(): |
|
continue |
|
candidates = [] |
|
for wav_file in sorted(speaker_dir.rglob("*")): |
|
if ( |
|
not wav_file.is_file() |
|
or wav_file.suffix.lower() not in AUDIO_FILE_SUFFIXES |
|
): |
|
continue |
|
candidates.append(wav_file) |
|
if len(candidates) >= min_data_per_speaker: |
|
speaker_id = len(speakers) |
|
speakers.append(speaker_dir.name) |
|
training_filelist.extend([(file, speaker_id) for file in candidates]) |
|
speaker_audio_files.append(candidates) |
|
return speakers, training_filelist, speaker_audio_files |
|
|
|
speakers, training_filelist, speaker_audio_files = get_training_filelist( |
|
in_wav_dataset_dir |
|
) |
|
n_speakers = len(speakers) |
|
if n_speakers == 0: |
|
raise ValueError(f"No speaker data found in {in_wav_dataset_dir}.") |
|
print(f"{n_speakers=}") |
|
for i, speaker in enumerate(speakers): |
|
print(f" {i:{len(str(n_speakers - 1))}d}: {speaker}") |
|
print() |
|
print(f"{len(training_filelist)=}") |
|
|
|
def get_test_filelist( |
|
in_test_wav_dir: Path, n_speakers: int |
|
) -> list[tuple[Path, list[int]]]: |
|
max_n_test_files = 1000 |
|
test_filelist = [] |
|
rng = Random(42) |
|
|
|
def get_target_id_generator(): |
|
if n_speakers > 8: |
|
while True: |
|
order = list(range(n_speakers)) |
|
rng.shuffle(order) |
|
yield from order |
|
else: |
|
while True: |
|
yield from range(n_speakers) |
|
|
|
target_id_generator = get_target_id_generator() |
|
for file in sorted(in_test_wav_dir.iterdir())[:max_n_test_files]: |
|
if file.suffix.lower() not in AUDIO_FILE_SUFFIXES: |
|
continue |
|
target_ids = [next(target_id_generator) for _ in range(min(8, n_speakers))] |
|
test_filelist.append((file, target_ids)) |
|
return test_filelist |
|
|
|
test_filelist = get_test_filelist(in_test_wav_dir, n_speakers) |
|
if len(test_filelist) == 0: |
|
warnings.warn(f"No audio data found in {test_filelist}.") |
|
print(f"{len(test_filelist)=}") |
|
for file, target_ids in test_filelist[:12]: |
|
print(f" {file}, {target_ids}") |
|
if len(test_filelist) > 12: |
|
print(" ...") |
|
print() |
|
|
|
|
|
|
|
training_dataset = WavDataset( |
|
training_filelist, |
|
in_sample_rate=h.in_sample_rate, |
|
out_sample_rate=h.out_sample_rate, |
|
wav_length=h.wav_length, |
|
segment_length=h.segment_length, |
|
noise_files=noise_files, |
|
ir_files=ir_files, |
|
) |
|
training_loader = torch.utils.data.DataLoader( |
|
training_dataset, |
|
num_workers=min(h.num_workers, os.cpu_count()), |
|
collate_fn=training_dataset.collate, |
|
shuffle=True, |
|
sampler=None, |
|
batch_size=h.batch_size, |
|
pin_memory=True, |
|
drop_last=True, |
|
persistent_workers=True, |
|
) |
|
|
|
print("Computing mean F0s of target speakers...", end="") |
|
speaker_f0s = [] |
|
for speaker, files in enumerate(speaker_audio_files): |
|
if len(files) > 10: |
|
files = Random(42).sample(files, 10) |
|
f0 = compute_mean_f0(files) |
|
speaker_f0s.append(f0) |
|
if speaker % 5 == 0: |
|
print() |
|
print(f" {speaker:3d}: {f0:.1f}Hz", end=",") |
|
print() |
|
print("Done.") |
|
print("Computing pitch shifts for test files...") |
|
test_pitch_shifts = [] |
|
source_f0s = [] |
|
for i, (file, target_ids) in enumerate(tqdm(test_filelist)): |
|
source_f0 = compute_mean_f0([file], method="harvest") |
|
source_f0s.append(source_f0) |
|
if math.isnan(source_f0): |
|
test_pitch_shifts.append([0] * len(target_ids)) |
|
continue |
|
pitch_shifts = [] |
|
for target_id in target_ids: |
|
target_f0 = speaker_f0s[target_id] |
|
if math.isnan(target_f0):  # mean F0 of the target speaker could not be estimated
|
pitch_shift = 0 |
|
else: |
|
pitch_shift = int(round(12.0 * math.log2(target_f0 / source_f0))) |
|
pitch_shifts.append(pitch_shift) |
|
test_pitch_shifts.append(pitch_shifts) |
|
print("Done.") |
|
|
|
|
|
|
|
phone_extractor = PhoneExtractor().to(device).eval().requires_grad_(False) |
|
phone_extractor_checkpoint = torch.load( |
|
repo_root() / h.phone_extractor_file, map_location="cpu", weights_only=True |
|
) |
|
print( |
|
phone_extractor.load_state_dict(phone_extractor_checkpoint["phone_extractor"]) |
|
) |
|
del phone_extractor_checkpoint |
|
|
|
pitch_estimator = PitchEstimator().to(device).eval().requires_grad_(False) |
|
pitch_estimator_checkpoint = torch.load( |
|
repo_root() / h.pitch_estimator_file, map_location="cpu", weights_only=True |
|
) |
|
print( |
|
pitch_estimator.load_state_dict(pitch_estimator_checkpoint["pitch_estimator"]) |
|
) |
|
del pitch_estimator_checkpoint |
|
|
|
net_g = ConverterNetwork( |
|
phone_extractor, |
|
pitch_estimator, |
|
n_speakers, |
|
h.hidden_channels, |
|
).to(device) |
|
net_d = MultiPeriodDiscriminator(san=h.san).to(device) |
|
|
|
optim_g = torch.optim.AdamW( |
|
net_g.parameters(), |
|
h.learning_rate_g, |
|
betas=h.adam_betas, |
|
eps=h.adam_eps, |
|
) |
|
optim_d = torch.optim.AdamW( |
|
net_d.parameters(), |
|
h.learning_rate_d, |
|
betas=h.adam_betas, |
|
eps=h.adam_eps, |
|
) |
|
|
|
grad_scaler = torch.amp.GradScaler("cuda", enabled=h.use_amp) |
|
grad_balancer = GradBalancer( |
|
weights={ |
|
"loss_mel": h.grad_weight_mel, |
|
"loss_adv": h.grad_weight_adv, |
|
"loss_fm": h.grad_weight_fm, |
|
} |
|
| ({"loss_ap": h.grad_weight_ap} if h.grad_weight_ap else {}), |
|
ema_decay=h.grad_balancer_ema_decay, |
|
) |
|
resample_to_in_sample_rate = torchaudio.transforms.Resample( |
|
h.out_sample_rate, h.in_sample_rate |
|
).to(device) |
|
|
|
|
|
|
|
initial_iteration = 0 |
|
if resume: |
|
checkpoint_file = latest_checkpoint_file |
|
elif h.pretrained_file is not None: |
|
checkpoint_file = repo_root() / h.pretrained_file |
|
else: |
|
checkpoint_file = None |
|
if checkpoint_file is not None: |
|
checkpoint = torch.load(checkpoint_file, map_location="cpu", weights_only=True) |
|
if not resume and not skip_training: |
|
checkpoint_n_speakers = len(checkpoint["net_g"]["embed_speaker.weight"]) |
|
initial_speaker_embedding = checkpoint["net_g"][ |
|
"embed_speaker.weight" |
|
].mean(0, keepdim=True) |
|
if True: |
|
checkpoint["net_g"]["embed_speaker.weight"] = initial_speaker_embedding[ |
|
[0] * n_speakers |
|
] |
|
else: |
|
assert n_speakers > checkpoint_n_speakers |
|
print( |
|
f"embed_speaker.weight was padded: {checkpoint_n_speakers} -> {n_speakers}" |
|
) |
|
checkpoint["net_g"]["embed_speaker.weight"] = F.pad( |
|
checkpoint["net_g"]["embed_speaker.weight"], |
|
(0, 0, 0, n_speakers - checkpoint_n_speakers), |
|
) |
|
checkpoint["net_g"]["embed_speaker.weight"][ |
|
checkpoint_n_speakers: |
|
] = initial_speaker_embedding |
|
print(net_g.load_state_dict(checkpoint["net_g"], strict=False)) |
|
print(net_d.load_state_dict(checkpoint["net_d"], strict=False)) |
|
if resume or skip_training: |
|
optim_g.load_state_dict(checkpoint["optim_g"]) |
|
optim_d.load_state_dict(checkpoint["optim_d"]) |
|
initial_iteration = checkpoint["iteration"] |
|
grad_balancer.load_state_dict(checkpoint["grad_balancer"]) |
|
grad_scaler.load_state_dict(checkpoint["grad_scaler"]) |
|
|
|
|
|
|
|
def get_cosine_annealing_warmup_scheduler( |
|
optimizer: torch.optim.Optimizer, |
|
warmup_epochs: int, |
|
total_epochs: int, |
|
min_learning_rate: float, |
|
) -> torch.optim.lr_scheduler.LambdaLR: |
|
lr_ratio = min_learning_rate / optimizer.param_groups[0]["lr"] |
|
m = 0.5 * (1.0 - lr_ratio) |
|
a = 0.5 * (1.0 + lr_ratio) |
|
|
|
def lr_lambda(current_epoch: int) -> float: |
|
if current_epoch < warmup_epochs: |
|
return current_epoch / warmup_epochs |
|
elif current_epoch < total_epochs: |
|
rate = (current_epoch - warmup_epochs) / (total_epochs - warmup_epochs) |
|
return math.cos(rate * math.pi) * m + a |
|
else: |
|
return lr_ratio
|
|
|
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) |
|
|
|
scheduler_g = get_cosine_annealing_warmup_scheduler( |
|
optim_g, h.warmup_steps, h.n_steps, h.min_learning_rate_g |
|
) |
|
scheduler_d = get_cosine_annealing_warmup_scheduler( |
|
optim_d, h.warmup_steps, h.n_steps, h.min_learning_rate_d |
|
) |
|
with warnings.catch_warnings(): |
|
warnings.filterwarnings( |
|
"ignore", |
|
message=r"Detected call of `lr_scheduler\.step\(\)` before `optimizer\.step\(\)`\.", |
|
) |
|
for _ in range(initial_iteration + 1): |
|
scheduler_g.step() |
|
scheduler_d.step() |
|
|
|
net_g.train() |
|
net_d.train() |
|
|
|
|
|
|
|
dict_scalars = defaultdict(list) |
|
quality_tester = QualityTester().eval().to(device) |
|
if skip_training: |
|
writer = None |
|
else: |
|
writer = SummaryWriter(out_dir) |
|
writer.add_text( |
|
"log", |
|
f"start training w/ {torch.cuda.get_device_name(device) if torch.cuda.is_available() else 'cpu'}.", |
|
initial_iteration, |
|
) |
|
if not resume: |
|
with open(out_dir / "config.json", "w", encoding="utf-8") as f: |
|
json.dump(dict(h), f, indent=4) |
|
if not is_notebook(): |
|
shutil.copy(__file__, out_dir) |
|
|
|
return ( |
|
device, |
|
in_wav_dataset_dir, |
|
h, |
|
out_dir, |
|
speakers, |
|
test_filelist, |
|
training_loader, |
|
speaker_f0s, |
|
test_pitch_shifts, |
|
phone_extractor, |
|
pitch_estimator, |
|
net_g, |
|
net_d, |
|
optim_g, |
|
optim_d, |
|
grad_scaler, |
|
grad_balancer, |
|
resample_to_in_sample_rate, |
|
initial_iteration, |
|
scheduler_g, |
|
scheduler_d, |
|
dict_scalars, |
|
quality_tester, |
|
writer, |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
( |
|
device, |
|
in_wav_dataset_dir, |
|
h, |
|
out_dir, |
|
speakers, |
|
test_filelist, |
|
training_loader, |
|
speaker_f0s, |
|
test_pitch_shifts, |
|
phone_extractor, |
|
pitch_estimator, |
|
net_g, |
|
net_d, |
|
optim_g, |
|
optim_d, |
|
grad_scaler, |
|
grad_balancer, |
|
resample_to_in_sample_rate, |
|
initial_iteration, |
|
scheduler_g, |
|
scheduler_d, |
|
dict_scalars, |
|
quality_tester, |
|
writer, |
|
) = prepare_training() |
|
|
|
if __name__ == "__main__" and writer is not None: |
|
if h.compile_convnext: |
|
raw_convnextstack_forward = ConvNeXtStack.forward |
|
compiled_convnextstack_forward = torch.compile( |
|
ConvNeXtStack.forward, mode="reduce-overhead" |
|
) |
|
if h.compile_d4c: |
|
d4c = torch.compile(d4c, mode="reduce-overhead") |
|
if h.compile_discriminator: |
|
MultiPeriodDiscriminator.forward_and_compute_loss = torch.compile( |
|
MultiPeriodDiscriminator.forward_and_compute_loss, mode="reduce-overhead" |
|
) |
|
|
|
|
|
with ( |
|
torch.profiler.profile( |
|
schedule=torch.profiler.schedule(wait=1500, warmup=10, active=5, repeat=1), |
|
on_trace_ready=torch.profiler.tensorboard_trace_handler(out_dir), |
|
record_shapes=True, |
|
with_stack=True, |
|
profile_memory=True, |
|
with_flops=True, |
|
) |
|
if h.profile |
|
else nullcontext() |
|
) as profiler: |
|
|
|
for iteration in tqdm(range(initial_iteration, h.n_steps)): |
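# One optimization step (as written below): run the generator under AMP, compute the
# discriminator loss and backpropagate it into net_d only, use the gradient balancer to
# backpropagate the weighted generator losses through y_hat into net_g, and finally step
# both optimizers through the shared GradScaler.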
|
|
|
try: |
|
batch = next(data_iter) |
|
except (NameError, StopIteration):  # first iteration, or the loader is exhausted
|
data_iter = iter(training_loader) |
|
batch = next(data_iter) |
|
( |
|
clean_wavs, |
|
noisy_wavs_16k, |
|
slice_starts, |
|
speaker_ids, |
|
formant_shift_semitone, |
|
) = map(lambda x: x.to(device, non_blocking=True), batch) |
|
|
|
|
|
with torch.amp.autocast("cuda", enabled=h.use_amp): |
|
|
|
if h.compile_convnext: |
|
ConvNeXtStack.forward = compiled_convnextstack_forward |
|
y, y_hat, y_hat_for_backward, loss_mel, loss_ap, generator_stats = ( |
|
net_g.forward_and_compute_loss( |
|
noisy_wavs_16k[:, None, :], |
|
speaker_ids, |
|
formant_shift_semitone, |
|
slice_start_indices=slice_starts, |
|
slice_segment_length=h.segment_length, |
|
y_all=clean_wavs[:, None, :], |
|
enable_loss_ap=h.grad_weight_ap != 0.0, |
|
) |
|
) |
|
if h.compile_convnext: |
|
ConvNeXtStack.forward = raw_convnextstack_forward |
|
assert y_hat.isfinite().all() |
|
assert loss_mel.isfinite().all() |
|
assert loss_ap.isfinite().all() |
|
|
|
|
|
loss_discriminator, loss_adv, loss_fm, discriminator_stats = ( |
|
net_d.forward_and_compute_loss(y, y_hat) |
|
) |
|
assert loss_discriminator.isfinite().all() |
|
assert loss_adv.isfinite().all() |
|
assert loss_fm.isfinite().all() |
|
|
|
|
|
for param in net_d.parameters(): |
|
assert param.grad is None |
|
grad_scaler.scale(loss_discriminator).backward( |
|
retain_graph=True, inputs=list(net_d.parameters()) |
|
) |
|
loss_discriminator = loss_discriminator.item() |
|
grad_scaler.unscale_(optim_d) |
|
if iteration % 5 == 0: |
|
grad_norm_d, d_grad_norm_stats = compute_grad_norm(net_d, True) |
|
else: |
|
grad_norm_d = math.nan |
|
d_grad_norm_stats = {} |
|
|
|
|
|
for param in net_g.parameters(): |
|
assert param.grad is None |
|
gradient_balancer_stats = grad_balancer.backward( |
|
{ |
|
"loss_mel": loss_mel, |
|
"loss_adv": loss_adv, |
|
"loss_fm": loss_fm, |
|
} |
|
| ({"loss_ap": loss_ap} if h.grad_weight_ap else {}), |
|
y_hat_for_backward, |
|
grad_scaler, |
|
skip_update_ema=iteration > 10 and iteration % 5 != 0, |
|
) |
|
loss_mel = loss_mel.item() |
|
loss_adv = loss_adv.item() |
|
loss_fm = loss_fm.item() |
|
if h.grad_weight_ap: |
|
loss_ap = loss_ap.item() |
|
grad_scaler.unscale_(optim_g) |
|
if iteration % 5 == 0: |
|
grad_norm_g, g_grad_norm_stats = compute_grad_norm(net_g, True) |
|
else: |
|
grad_norm_g = math.nan |
|
g_grad_norm_stats = {} |
|
|
|
|
|
grad_scaler.step(optim_g) |
|
optim_g.zero_grad(set_to_none=True) |
|
grad_scaler.step(optim_d) |
|
optim_d.zero_grad(set_to_none=True) |
|
grad_scaler.update() |
|
|
|
|
|
dict_scalars["loss_g/loss_mel"].append(loss_mel) |
|
if h.grad_weight_ap: |
|
dict_scalars["loss_g/loss_ap"].append(loss_ap) |
|
dict_scalars["loss_g/loss_fm"].append(loss_fm) |
|
dict_scalars["loss_g/loss_adv"].append(loss_adv) |
|
dict_scalars["other/grad_scale"].append(grad_scaler.get_scale()) |
|
dict_scalars["loss_d/loss_discriminator"].append(loss_discriminator) |
|
if math.isfinite(grad_norm_d): |
|
dict_scalars["other/gradient_norm_d"].append(grad_norm_d) |
|
for name, value in d_grad_norm_stats.items(): |
|
dict_scalars[f"~gradient_norm_d/{name}"].append(value) |
|
if math.isfinite(grad_norm_g): |
|
dict_scalars["other/gradient_norm_g"].append(grad_norm_g) |
|
for name, value in g_grad_norm_stats.items(): |
|
dict_scalars[f"~gradient_norm_g/{name}"].append(value) |
|
dict_scalars["other/lr_g"].append(scheduler_g.get_last_lr()[0]) |
|
dict_scalars["other/lr_d"].append(scheduler_d.get_last_lr()[0]) |
|
for k, v in generator_stats.items(): |
|
dict_scalars[f"~loss_generator/{k}"].append(v) |
|
for k, v in discriminator_stats.items(): |
|
dict_scalars[f"~loss_discriminator/{k}"].append(v) |
|
for k, v in gradient_balancer_stats.items(): |
|
dict_scalars[f"~gradient_balancer/{k}"].append(v) |
|
|
|
if (iteration + 1) % 1000 == 0 or iteration == 0: |
|
for name, scalars in dict_scalars.items(): |
|
if scalars: |
|
writer.add_scalar( |
|
name, sum(scalars) / len(scalars), iteration + 1 |
|
) |
|
scalars.clear() |
|
for name, param in net_g.named_parameters(): |
|
writer.add_histogram(f"weight/{name}", param, iteration + 1) |
|
|
|
intermediate_feature_stats = {} |
|
hook_handles = [] |
|
|
|
def get_layer_hook(name): |
|
def compute_stats(module, x, suffix): |
|
if not isinstance(x, torch.Tensor): |
|
return |
|
if x.dtype not in [torch.float32, torch.float16]: |
|
return |
|
if isinstance(module, nn.Identity): |
|
return |
|
x = x.detach().float() |
|
var = x.var().item() |
|
if isinstance(module, (nn.Linear, nn.LayerNorm)): |
|
channel_var, channel_mean = torch.var_mean( |
|
x.reshape(-1, x.size(-1)), 0 |
|
) |
|
elif isinstance(module, nn.Conv1d): |
|
channel_var, channel_mean = torch.var_mean(x, [0, 2]) |
|
else: |
|
return |
|
average_squared_channel_mean = ( |
|
channel_mean.square().mean().item() |
|
) |
|
average_channel_var = channel_var.mean().item() |
|
|
|
tensor_idx = len(intermediate_feature_stats) // 3 |
|
intermediate_feature_stats[ |
|
f"var/{tensor_idx:02d}_{name}/{suffix}" |
|
] = var |
|
intermediate_feature_stats[ |
|
f"avg_sq_ch_mean/{tensor_idx:02d}_{name}/{suffix}" |
|
] = average_squared_channel_mean |
|
intermediate_feature_stats[ |
|
f"avg_ch_var/{tensor_idx:02d}_{name}/{suffix}" |
|
] = average_channel_var |
|
|
|
def forward_pre_hook(module, input): |
|
for i, input_i in enumerate(input): |
|
compute_stats(module, input_i, f"input_{i}") |
|
|
|
def forward_hook(module, input, output): |
|
if isinstance(output, tuple): |
|
for i, output_i in enumerate(output): |
|
compute_stats(module, output_i, f"output_{i}") |
|
else: |
|
compute_stats(module, output, "output") |
|
|
|
return forward_pre_hook, forward_hook |
|
|
|
for name, layer in net_g.named_modules(): |
|
forward_pre_hook, forward_hook = get_layer_hook(name) |
|
hook_handles.append( |
|
layer.register_forward_pre_hook(forward_pre_hook) |
|
) |
|
hook_handles.append(layer.register_forward_hook(forward_hook)) |
|
with torch.no_grad(), torch.amp.autocast("cuda", enabled=h.use_amp): |
|
net_g.forward_and_compute_loss( |
|
noisy_wavs_16k[:, None, :], |
|
speaker_ids, |
|
formant_shift_semitone, |
|
slice_start_indices=slice_starts, |
|
slice_segment_length=h.segment_length, |
|
y_all=clean_wavs[:, None, :], |
|
enable_loss_ap=h.grad_weight_ap != 0.0, |
|
) |
|
for handle in hook_handles: |
|
handle.remove() |
|
for name, value in intermediate_feature_stats.items(): |
|
writer.add_scalar( |
|
f"~intermediate_feature_{name}", value, iteration + 1 |
|
) |
|
|
|
|
|
if (iteration + 1) % ( |
|
50000 if h.n_steps > 200000 else 2000 |
|
) == 0 or iteration + 1 in { |
|
1, |
|
30000, |
|
h.n_steps, |
|
}: |
|
torch.backends.cudnn.benchmark = False |
|
net_g.eval() |
|
torch.cuda.empty_cache() |
|
|
|
dict_qualities_all = defaultdict(list) |
|
n_added_wavs = 0 |
|
with torch.inference_mode(): |
|
for i, ((file, target_ids), pitch_shift_semitones) in enumerate( |
|
zip(test_filelist, test_pitch_shifts) |
|
): |
|
source_wav, sr = torchaudio.load(file, backend="soundfile") |
|
source_wav = source_wav.to(device) |
|
if sr != h.in_sample_rate: |
|
source_wav = get_resampler(sr, h.in_sample_rate, device)( |
|
source_wav |
|
) |
|
source_wav = source_wav.to(device) |
|
original_source_wav_length = source_wav.size(1) |
|
|
|
if source_wav.size(1) % h.in_sample_rate == 0: |
|
padded_source_wav = source_wav |
|
else: |
|
padded_source_wav = F.pad( |
|
source_wav, |
|
( |
|
0, |
|
h.in_sample_rate |
|
- source_wav.size(1) % h.in_sample_rate, |
|
), |
|
) |
|
converted = net_g( |
|
padded_source_wav[[0] * len(target_ids), None], |
|
torch.tensor(target_ids, device=device), |
|
torch.tensor( |
|
[0.0] * len(target_ids), device=device |
|
), |
|
torch.tensor( |
|
[float(p) for p in pitch_shift_semitones], device=device |
|
), |
|
).squeeze_(1)[:, : original_source_wav_length // 160 * 240] |
|
if i < 12: |
|
if iteration == 0: |
|
writer.add_audio( |
|
f"source/y_{i:02d}", |
|
source_wav, |
|
iteration + 1, |
|
h.in_sample_rate, |
|
) |
|
for d in range( |
|
min( |
|
len(target_ids), |
|
1 + (12 - i - 1) // len(test_filelist), |
|
) |
|
): |
|
idx_in_batch = n_added_wavs % len(target_ids) |
|
writer.add_audio( |
|
f"converted/y_hat_{i:02d}_{target_ids[idx_in_batch]:03d}_{pitch_shift_semitones[idx_in_batch]:+02d}", |
|
converted[idx_in_batch], |
|
iteration + 1, |
|
h.out_sample_rate, |
|
) |
|
n_added_wavs += 1 |
|
converted = resample_to_in_sample_rate(converted) |
|
quality = quality_tester.test(converted, source_wav) |
|
for metric_name, values in quality.items(): |
|
dict_qualities_all[metric_name].extend(values) |
|
assert n_added_wavs == min( |
|
12, len(test_filelist) * len(test_filelist[0][1]) |
|
), ( |
|
n_added_wavs, |
|
len(test_filelist), |
|
len(speakers), |
|
len(test_filelist[0][1]), |
|
) |
|
dict_qualities = { |
|
metric_name: sum(values) / len(values) |
|
for metric_name, values in dict_qualities_all.items() |
|
if len(values) |
|
} |
|
for metric_name, value in dict_qualities.items(): |
|
writer.add_scalar(f"validation/{metric_name}", value, iteration + 1) |
|
for metric_name, values in dict_qualities_all.items(): |
|
for i, value in enumerate(values): |
|
writer.add_scalar( |
|
f"~validation_{metric_name}/{i:03d}", value, iteration + 1 |
|
) |
|
del dict_qualities, dict_qualities_all |
|
|
|
net_g.train() |
|
torch.backends.cudnn.benchmark = True |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
|
|
|
|
if (iteration + 1) % ( |
|
50000 if h.n_steps > 200000 else 2000 |
|
) == 0 or iteration + 1 in { |
|
1, |
|
30000, |
|
h.n_steps, |
|
}: |
|
|
|
name = f"{in_wav_dataset_dir.name}_{iteration + 1:08d}" |
|
checkpoint_file_save = out_dir / f"checkpoint_{name}.pt" |
|
if checkpoint_file_save.exists(): |
|
checkpoint_file_save = checkpoint_file_save.with_name( |
|
f"{checkpoint_file_save.name}_{hash(None):x}" |
|
) |
|
torch.save( |
|
{ |
|
"iteration": iteration + 1, |
|
"net_g": net_g.state_dict(), |
|
"phone_extractor": phone_extractor.state_dict(), |
|
"pitch_estimator": pitch_estimator.state_dict(), |
|
"net_d": net_d.state_dict(), |
|
"optim_g": optim_g.state_dict(), |
|
"optim_d": optim_d.state_dict(), |
|
"grad_balancer": grad_balancer.state_dict(), |
|
"grad_scaler": grad_scaler.state_dict(), |
|
"h": dict(h), |
|
}, |
|
checkpoint_file_save, |
|
) |
|
shutil.copy(checkpoint_file_save, out_dir / "checkpoint_latest.pt") |
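# In addition to the training checkpoint, a half-precision inference bundle
# ("paraphernalia") is exported: merged, weight-norm-free dumps of the phone extractor,
# pitch estimator, and converter, plus speaker and formant-shift embeddings and a TOML
# manifest describing the model and its voices.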
|
|
|
|
|
paraphernalia_dir = out_dir / f"paraphernalia_{name}" |
|
if paraphernalia_dir.exists(): |
|
paraphernalia_dir = paraphernalia_dir.with_name( |
|
f"{paraphernalia_dir.name}_{hash(None):x}" |
|
) |
|
paraphernalia_dir.mkdir() |
|
phone_extractor_fp16 = PhoneExtractor() |
|
phone_extractor_fp16.load_state_dict(phone_extractor.state_dict()) |
|
phone_extractor_fp16.remove_weight_norm() |
|
phone_extractor_fp16.merge_weights() |
|
phone_extractor_fp16.half() |
|
phone_extractor_fp16.dump(paraphernalia_dir / f"phone_extractor.bin") |
|
del phone_extractor_fp16 |
|
pitch_estimator_fp16 = PitchEstimator() |
|
pitch_estimator_fp16.load_state_dict(pitch_estimator.state_dict()) |
|
pitch_estimator_fp16.merge_weights() |
|
pitch_estimator_fp16.half() |
|
pitch_estimator_fp16.dump(paraphernalia_dir / f"pitch_estimator.bin") |
|
del pitch_estimator_fp16 |
|
net_g_fp16 = ConverterNetwork( |
|
nn.Module(), nn.Module(), len(speakers), h.hidden_channels |
|
) |
|
net_g_fp16.load_state_dict(net_g.state_dict()) |
|
net_g_fp16.merge_weights() |
|
net_g_fp16.half() |
|
net_g_fp16.dump(paraphernalia_dir / f"waveform_generator.bin") |
|
with open(paraphernalia_dir / f"speaker_embeddings.bin", "wb") as f: |
|
dump_layer(net_g_fp16.embed_speaker, f) |
|
with open( |
|
paraphernalia_dir / f"formant_shift_embeddings.bin", "wb" |
|
) as f: |
|
dump_layer(net_g_fp16.embed_formant_shift, f) |
|
del net_g_fp16 |
|
shutil.copy( |
|
repo_root() / "assets/images/noimage.png", paraphernalia_dir |
|
) |
|
with open( |
|
paraphernalia_dir / f"beatrice_paraphernalia_{name}.toml", |
|
"w", |
|
encoding="utf-8", |
|
) as f: |
|
f.write( |
|
f'''[model] |
|
version = "{PARAPHERNALIA_VERSION}" |
|
name = "{name}" |
|
description = """ |
|
No description for this model. |
|
このモデルの説明はありません。 |
|
""" |
|
''' |
|
) |
|
for speaker_id, (speaker, speaker_f0) in enumerate( |
|
zip(speakers, speaker_f0s) |
|
): |
|
average_pitch = 69.0 + 12.0 * math.log2(speaker_f0 / 440.0) |
|
average_pitch = round(average_pitch * 8.0) / 8.0 |
|
f.write( |
|
f''' |
|
[voice.{speaker_id}] |
|
name = "{speaker}" |
|
description = """ |
|
No description for this voice. |
|
この声の説明はありません。 |
|
""" |
|
average_pitch = {average_pitch} |
|
|
|
[voice.{speaker_id}.portrait] |
|
path = "noimage.png" |
|
description = """ |
|
""" |
|
''' |
|
) |
|
del paraphernalia_dir |
|
|
|
|
|
|
|
|
|
scheduler_g.step() |
|
scheduler_d.step() |
|
if h.profile: |
|
profiler.step() |
|
|
|
print("Training finished.") |
|
|