|
import os, json, argparse, yaml |
|
import numpy as np |
|
from tqdm import tqdm |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import DataLoader |
|
from torch.cuda.amp import autocast, GradScaler |
|
|
|
from diffusers import DDIMScheduler |
|
|
|
from dataset import VCDecLPCDataset, VCDecLPCBatchCollate, VCDecLPCTest |
|
from models.unet import UNetVC |
|
from modules.BigVGAN.inference import load_model |
|
from utils import save_plot, save_audio |
|
from utils import minmax_norm_diff, reverse_minmax_norm_diff |
|
|
|
|
|
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) becomes ``True``. This converter
    interprets the usual textual spellings instead.

    Args:
        value: The raw argument string (or an actual ``bool``).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# ---- Command-line interface -------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('-config', type=str, default='config/DiffWorld_24k_log.yaml')

# Runtime switches.
parser.add_argument('-seed', type=int, default=98)
parser.add_argument('-amp', type=_str2bool, default=True)
parser.add_argument('-compile', type=_str2bool, default=False)

# Data and vocoder locations.
parser.add_argument('-data_dir', type=str, default='../24k_center/')
parser.add_argument('-lpc_dir', type=str, default='world')
parser.add_argument('-vocoder_dir', type=str, default='modules/BigVGAN/ckpt/bigvgan_base_24khz_100band/g_05000000')

# Batching and optimisation hyper-parameters.
parser.add_argument('-train_frames', type=int, default=128)
parser.add_argument('-batch_size', type=int, default=32)
parser.add_argument('-test_size', type=int, default=1)
parser.add_argument('-num_workers', type=int, default=4)
parser.add_argument('-lr', type=float, default=5e-5)
# weight decay is fractional (default 1e-6), so it must be parsed as a float;
# with type=int a value such as "1e-6" is rejected on the command line.
parser.add_argument('-weight_decay', type=float, default=1e-6)

# Schedule, logging and checkpointing.
parser.add_argument('-epochs', type=int, default=80)
parser.add_argument('-save_every', type=int, default=2)
parser.add_argument('-log_step', type=int, default=200)
parser.add_argument('-log_dir', type=str, default='logs_dec_world_24k')
parser.add_argument('-ckpt_dir', type=str, default='ckpt_world_24k')

args = parser.parse_args()
# Not a CLI option: controls whether the un-generated reference outputs are
# written during validation; it is switched off by the training script later.
args.save_ori = True

# Read the YAML configuration; the context manager closes the file handle.
with open(args.config) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Sub-sections of the configuration used by the rest of the script.
mel_cfg = config['logmel']
ddpm_cfg = config['ddpm']
unet_cfg = config['unet']
f0_type = unet_cfg['pitch_type']
|
|
|
if __name__ == "__main__": |
|
torch.manual_seed(args.seed) |
|
np.random.seed(args.seed) |
|
if torch.cuda.is_available(): |
|
args.device = 'cuda' |
|
torch.cuda.manual_seed(args.seed) |
|
torch.cuda.manual_seed_all(args.seed) |
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
if torch.backends.cudnn.is_available(): |
|
torch.backends.cudnn.deterministic = True |
|
torch.backends.cudnn.benchmark = True |
|
else: |
|
args.device = 'cpu' |
|
|
|
if os.path.exists(args.log_dir) is False: |
|
os.makedirs(args.log_dir) |
|
|
|
if os.path.exists(args.ckpt_dir) is False: |
|
os.makedirs(args.ckpt_dir) |
|
|
|
print('Initializing vocoder...') |
|
hifigan, cfg = load_model(args.vocoder_dir, device=args.device) |
|
|
|
print('Initializing data loaders...') |
|
train_set = VCDecLPCDataset(args.data_dir, subset='train', content_dir=args.lpc_dir, f0_type=f0_type) |
|
collate_fn = VCDecLPCBatchCollate(args.train_frames) |
|
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, |
|
collate_fn=collate_fn, num_workers=args.num_workers, drop_last=True) |
|
|
|
val_set = VCDecLPCTest(args.data_dir, content_dir=args.lpc_dir, f0_type=f0_type) |
|
val_loader = DataLoader(val_set, batch_size=1, shuffle=False) |
|
|
|
print('Initializing and loading models...') |
|
model = UNetVC(**unet_cfg).to(args.device) |
|
print('Number of parameters = %.2fm\n' % (model.nparams / 1e6)) |
|
|
|
|
|
noise_scheduler = DDIMScheduler(num_train_timesteps=ddpm_cfg['num_train_steps']) |
|
|
|
print('Initializing optimizers...') |
|
optimizer = torch.optim.AdamW(params=model.parameters(), lr=args.lr, weight_decay=args.weight_decay) |
|
scaler = GradScaler() |
|
|
|
if args.compile: |
|
model = torch.compile(model) |
|
|
|
print('Start training.') |
|
global_step = 0 |
|
for epoch in range(1, args.epochs + 1): |
|
print(f'Epoch: {epoch} [iteration: {global_step}]') |
|
model.train() |
|
losses = [] |
|
|
|
for step, batch in enumerate(tqdm(train_loader)): |
|
optimizer.zero_grad() |
|
|
|
|
|
mel = batch['mel1'].to(args.device) |
|
mel = minmax_norm_diff(mel, vmax=mel_cfg['max'], vmin=mel_cfg['min']) |
|
|
|
if unet_cfg["use_ref_t"]: |
|
mel_ref = batch['mel2'].to(args.device) |
|
mel_ref = minmax_norm_diff(mel_ref, vmax=mel_cfg['max'], vmin=mel_cfg['min']) |
|
else: |
|
mel_ref = None |
|
|
|
f0 = batch['f0_1'].to(args.device) |
|
|
|
mean = batch['content1'].to(args.device) |
|
mean = minmax_norm_diff(mean, vmax=mel_cfg['max'], vmin=mel_cfg['min']) |
|
|
|
noise = torch.randn(mel.shape).to(args.device) |
|
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, |
|
(args.batch_size,), |
|
device=args.device, ).long() |
|
|
|
noisy_mel = noise_scheduler.add_noise(mel, noise, timesteps) |
|
|
|
if args.amp: |
|
with autocast(): |
|
noise_pred = model(x=noisy_mel, mean=mean, f0=f0, t=timesteps, ref=mel_ref, embed=None) |
|
loss = F.mse_loss(noise_pred, noise) |
|
scaler.scale(loss).backward() |
|
scaler.step(optimizer) |
|
scaler.update() |
|
else: |
|
noise_pred = model(x=noisy_mel, mean=mean, f0=f0, t=timesteps, ref=mel_ref, embed=None) |
|
loss = F.mse_loss(noise_pred, noise) |
|
|
|
loss.backward() |
|
optimizer.step() |
|
|
|
losses.append(loss.item()) |
|
global_step += 1 |
|
|
|
if global_step % args.log_step == 0: |
|
losses = np.asarray(losses) |
|
|
|
msg = '\nEpoch: [{}][{}]\t' \ |
|
'Batch: [{}][{}]\tLoss: {:.6f}\n'.format(epoch, |
|
args.epochs, |
|
step+1, |
|
len(train_loader), |
|
np.mean(losses)) |
|
with open(f'{args.log_dir}/train_dec.log', 'a') as f: |
|
f.write(msg) |
|
losses = [] |
|
|
|
if epoch % args.save_every > 0: |
|
continue |
|
|
|
print('Saving model...\n') |
|
ckpt = model.state_dict() |
|
torch.save(ckpt, f=f"{args.ckpt_dir}/lpc_vc_{epoch}.pt") |
|
|
|
print('Inference...\n') |
|
noise = None |
|
noise_scheduler.set_timesteps(ddpm_cfg['inference_steps']) |
|
model.eval() |
|
with torch.no_grad(): |
|
for i, batch in enumerate(val_loader): |
|
|
|
generator = torch.Generator(device=args.device).manual_seed(args.seed) |
|
|
|
mel = batch['mel1'].to(args.device) |
|
mel = minmax_norm_diff(mel, vmax=mel_cfg['max'], vmin=mel_cfg['min']) |
|
|
|
if unet_cfg["use_ref_t"]: |
|
mel_ref = batch['mel2'].to(args.device) |
|
mel_ref = minmax_norm_diff(mel_ref, vmax=mel_cfg['max'], vmin=mel_cfg['min']) |
|
else: |
|
mel_ref = None |
|
|
|
f0 = batch['f0_1'].to(args.device) |
|
embed = batch['embed'].to(args.device) |
|
|
|
mean = batch['content1'].to(args.device) |
|
mean = minmax_norm_diff(mean, vmax=mel_cfg['max'], vmin=mel_cfg['min']) |
|
|
|
|
|
if noise is None: |
|
noise = torch.randn(mel.shape, |
|
generator=generator, |
|
device=args.device, |
|
) |
|
pred = noise |
|
|
|
for t in noise_scheduler.timesteps: |
|
pred = noise_scheduler.scale_model_input(pred, t) |
|
model_output = model(x=pred, mean=mean, f0=f0, t=t, ref=mel_ref, embed=None) |
|
pred = noise_scheduler.step(model_output=model_output, |
|
timestep=t, |
|
sample=pred, |
|
eta=ddpm_cfg['eta'], generator=generator).prev_sample |
|
|
|
|
|
if os.path.exists(f'{args.log_dir}/audio/{i}/') is False: |
|
os.makedirs(f'{args.log_dir}/audio/{i}/') |
|
os.makedirs(f'{args.log_dir}/pic/{i}/') |
|
|
|
|
|
pred = reverse_minmax_norm_diff(pred, vmax=mel_cfg['max'], vmin=mel_cfg['min']) |
|
save_plot(pred.squeeze().cpu(), f'{args.log_dir}/pic/{i}/{epoch}_pred.png') |
|
audio = hifigan(pred) |
|
save_audio(f'{args.log_dir}/audio/{i}/{epoch}_pred.wav', mel_cfg['sampling_rate'], audio) |
|
|
|
if args.save_ori is True: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mel = reverse_minmax_norm_diff(mel, vmax=mel_cfg['max'], vmin=mel_cfg['min']) |
|
save_plot(mel.squeeze().cpu(), f'{args.log_dir}/pic/{i}/{epoch}_source.png') |
|
audio = hifigan(mel) |
|
save_audio(f'{args.log_dir}/audio/{i}/{epoch}_source.wav', mel_cfg['sampling_rate'], audio) |
|
|
|
|
|
mean = reverse_minmax_norm_diff(mean, vmax=mel_cfg['max'], vmin=mel_cfg['min']) |
|
save_plot(mean.squeeze().cpu(), f'{args.log_dir}/pic/{i}/{epoch}_avg.png') |
|
audio = hifigan(mean) |
|
save_audio(f'{args.log_dir}/audio/{i}/{epoch}_avg.wav', mel_cfg['sampling_rate'], audio) |
|
|
|
args.save_ori = False |
|
|