import librosa.display
import matplotlib.pyplot as plt
import json
import torch
import torchaudio

import hifigan


def manual_logging(logger, item, idx, tag, global_step, data_type, config):
    """Log one item from a batch to TensorBoard, either as audio or as a spectrogram image."""
    if data_type == "audio":
        audio = item[idx, ...].detach().cpu().numpy()
        logger.add_audio(
            tag,
            audio,
            global_step,
            sample_rate=config["preprocess"]["sampling_rate"],
        )
    elif data_type == "image":
        image = item[idx, ...].detach().cpu().numpy()
        fig, ax = plt.subplots()
        _ = librosa.display.specshow(
            image,
            x_axis="time",
            y_axis="linear",
            sr=config["preprocess"]["sampling_rate"],
            hop_length=config["preprocess"]["frame_shift"],
            fmax=config["preprocess"]["sampling_rate"] // 2,
            ax=ax,
        )
        logger.add_figure(tag, fig, global_step)
    else:
        raise NotImplementedError(
            "Data type given to logger should be [audio] or [image]"
        )


def load_vocoder(config):
    """Load a pretrained HiFi-GAN generator and freeze its parameters for inference."""
    with open(
        "hifigan/config_{}.json".format(config["general"]["feature_type"]), "r"
    ) as f:
        config_hifigan = hifigan.AttrDict(json.load(f))
    vocoder = hifigan.Generator(config_hifigan)
    vocoder.load_state_dict(torch.load(config["general"]["hifigan_path"])["generator"])
    vocoder.remove_weight_norm()
    for param in vocoder.parameters():
        param.requires_grad = False
    return vocoder


def get_conv_padding(kernel_size, dilation=1):
    """Return the padding that keeps the output length of a dilated convolution unchanged."""
    return int((kernel_size * dilation - dilation) / 2)


def plot_and_save_mels(wav, save_path, config):
    """Compute the log-mel spectrogram of a waveform and save it as an image."""
    spec_module = torchaudio.transforms.MelSpectrogram(
        sample_rate=config["preprocess"]["sampling_rate"],
        n_fft=config["preprocess"]["fft_length"],
        win_length=config["preprocess"]["frame_length"],
        hop_length=config["preprocess"]["frame_shift"],
        f_min=config["preprocess"]["fmin"],
        f_max=config["preprocess"]["fmax"],
        n_mels=config["preprocess"]["n_mels"],
        power=1,
        center=True,
        norm="slaney",
        mel_scale="slaney",
    )
    spec = spec_module(wav.unsqueeze(0))
    log_spec = torch.log(
        torch.clamp_min(spec, config["preprocess"]["min_magnitude"])
        * config["preprocess"]["comp_factor"]
    )
    fig, ax = plt.subplots()
    _ = librosa.display.specshow(
        log_spec.squeeze(0).numpy(),
        x_axis="time",
        y_axis="linear",
        sr=config["preprocess"]["sampling_rate"],
        hop_length=config["preprocess"]["frame_shift"],
        fmax=config["preprocess"]["sampling_rate"] // 2,
        ax=ax,
        cmap="viridis",
    )
    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)


def plot_and_save_mels_all(wavs, keys, save_path, config):
    """Plot log-mel spectrograms for several waveforms on a 3x3 grid and save the figure."""
    spec_module = torchaudio.transforms.MelSpectrogram(
        sample_rate=config["preprocess"]["sampling_rate"],
        n_fft=config["preprocess"]["fft_length"],
        win_length=config["preprocess"]["frame_length"],
        hop_length=config["preprocess"]["frame_shift"],
        f_min=config["preprocess"]["fmin"],
        f_max=config["preprocess"]["fmax"],
        n_mels=config["preprocess"]["n_mels"],
        power=1,
        center=True,
        norm="slaney",
        mel_scale="slaney",
    )
    fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(18, 18))
    for i, key in enumerate(keys):
        wav = wavs[key][0, ...].cpu()
        spec = spec_module(wav.unsqueeze(0))
        log_spec = torch.log(
            torch.clamp_min(spec, config["preprocess"]["min_magnitude"])
            * config["preprocess"]["comp_factor"]
        )
        ax[i // 3, i % 3].set(title=key)
        _ = librosa.display.specshow(
            log_spec.squeeze(0).numpy(),
            x_axis="time",
            y_axis="linear",
            sr=config["preprocess"]["sampling_rate"],
            hop_length=config["preprocess"]["frame_shift"],
            fmax=config["preprocess"]["sampling_rate"] // 2,
            ax=ax[i // 3, i % 3],
            cmap="viridis",
        )
    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)


def configure_args(config, args):
    """Override config entries with any command-line arguments that were explicitly given."""
    for key in ["stage", "corpus_type", "source_path", "aux_path", "preprocessed_path"]:
        if getattr(args, key) is not None:
            config["general"][key] = str(getattr(args, key))
    for key in ["n_train", "n_val", "n_test"]:
        if getattr(args, key) is not None:
            config["preprocess"][key] = getattr(args, key)
    for key in ["alpha", "beta", "learning_rate", "epoch"]:
        if getattr(args, key) is not None:
            config["train"][key] = getattr(args, key)
    for key in ["load_pretrained", "early_stopping"]:
        config["train"][key] = getattr(args, key)
    if args.feature_loss_type is not None:
        config["train"]["feature_loss"]["type"] = args.feature_loss_type
    for key in ["pretrained_path"]:
        if getattr(args, key) is not None:
            config["train"][key] = str(getattr(args, key))
    return config, args
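

# Usage sketch: one way these helpers might be wired together in a script. The
# paths "config.yaml" and "sample.wav" are placeholders, and the config file is
# assumed to be YAML with the "general" and "preprocess" sections referenced above;
# adapt the loading code to however the project actually stores its config.
#
#   import yaml
#   import torchaudio
#
#   with open("config.yaml", "r") as f:
#       config = yaml.safe_load(f)
#
#   vocoder = load_vocoder(config)           # frozen HiFi-GAN generator
#   wav, sr = torchaudio.load("sample.wav")  # mono waveform of shape (1, T)
#   # Resample first if sr != config["preprocess"]["sampling_rate"].
#   plot_and_save_mels(wav.squeeze(0), "sample_mel.png", config)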