|
import os, json, argparse, yaml |
|
import numpy as np |
|
from tqdm import tqdm |
|
import librosa |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.data import DataLoader |
|
from torch.cuda.amp import autocast, GradScaler |
|
|
|
from dataset.diffpitch import DiffPitch |
|
from models.transformer import PitchFormer |
|
from utils import minmax_norm_diff, reverse_minmax_norm_diff, save_curve_plot |
|
|
|
|
|
def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for every
    non-empty string (including ``"False"``), so boolean flags need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('-config', type=str, default='config/DiffPitch.yaml')

parser.add_argument('-seed', type=int, default=9811)
# Parsed with _str2bool: with type=bool, "-amp False" would have enabled AMP.
parser.add_argument('-amp', type=_str2bool, default=False)
parser.add_argument('-compile', type=_str2bool, default=False)

parser.add_argument('-data_dir', type=str, default='data/')
parser.add_argument('-content_dir', type=str, default='world')

parser.add_argument('-train_frames', type=int, default=256)
parser.add_argument('-test_frames', type=int, default=256)
parser.add_argument('-batch_size', type=int, default=32)
parser.add_argument('-test_size', type=int, default=1)
parser.add_argument('-num_workers', type=int, default=4)
parser.add_argument('-lr', type=float, default=5e-5)
# Weight decay is fractional; type=int rejected values such as "1e-5".
parser.add_argument('-weight_decay', type=float, default=1e-6)

parser.add_argument('-epochs', type=int, default=1)
parser.add_argument('-save_every', type=int, default=20)
parser.add_argument('-log_step', type=int, default=100)
parser.add_argument('-log_dir', type=str, default='logs_transformer_pitch')
parser.add_argument('-ckpt_dir', type=str, default='ckpt_transformer_pitch')

args = parser.parse_args()
args.save_ori = True

# Read the YAML config and close the file handle afterwards (it was previously leaked).
with open(args.config) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
mel_cfg = config['logmel']
ddpm_cfg = config['ddpm']
|
|
|
|
|
|
|
def RMSE(gen_f0, gt_f0):
    """Return the RMSE between two F0 tracks measured on a log2 (octave) scale.

    Only the first item along the leading (batch) axis of each input is scored, and
    only at positions where *both* tracks are non-zero.

    Args:
        gen_f0: Predicted F0 values; array-like whose first index selects one item.
        gt_f0: Reference F0 values with the same shape as ``gen_f0``.

    Returns:
        The root-mean-square difference of ``log2`` F0 over the positions where both
        tracks are non-zero, or ``0.0`` when there are no such positions.
    """
    gen = np.asarray(gen_f0)[0]
    gt = np.asarray(gt_f0)[0]

    # A boolean mask selects individual elements for inputs of any rank. The previous
    # np.where(...)[0] returned only first-axis indices, so a multi-dimensional item
    # selected whole rows (zeros included) and produced log2(0) = -inf.
    voiced = (gen != 0) & (gt != 0)
    if not voiced.any():
        return 0.0

    diff = np.log2(gen[voiced]) - np.log2(gt[voiced])
    return np.sqrt(np.mean(diff ** 2))
|
|
|
|
|
def _batch_to_device(batch, device):
    """Unpack a ``(mel, midi, f0)`` batch and move each tensor to ``device``."""
    mel, midi, f0 = batch
    return mel.to(device), midi.to(device), f0.to(device)


def _clamp_f0(f0_pred, f0_min, f0_max):
    """Clamp predicted F0 in place and return it.

    Values below ``f0_min`` are set to 0; values above ``f0_max`` are capped at ``f0_max``.
    """
    f0_pred[f0_pred < f0_min] = 0
    f0_pred[f0_pred > f0_max] = f0_max
    return f0_pred


def _save_plot(f0_pred, midi, f0, save_path):
    """Pass prediction, MIDI and reference F0 to ``save_curve_plot``, creating the directory."""
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    save_curve_plot(f0_pred.cpu().squeeze(), midi.cpu().squeeze(), f0.cpu().squeeze(), save_path)


def _evaluate(model, loader, device, log_dir, epoch, num_epochs, plot_tag, log_name,
              f0_min, f0_max, num_plots=5):
    """Score ``model`` on ``loader``; append mean L1 loss and log2-F0 RMSE to a log file.

    Predictions are clamped to ``[f0_min, f0_max]`` (values below ``f0_min`` become 0)
    before scoring. The first ``num_plots`` items are plotted to
    ``{log_dir}/pic/{i}/{epoch}_{plot_tag}.png``. The caller is expected to have put the
    model in eval mode and disabled autograd.
    """
    losses, rmses = [], []
    for i, batch in enumerate(loader):
        mel, midi, f0 = _batch_to_device(batch, device)
        f0_pred = _clamp_f0(model(sp=mel, midi=midi), f0_min, f0_max)

        losses.append(F.l1_loss(f0_pred, f0).item())
        rmses.append(RMSE(f0_pred.cpu().numpy(), f0.cpu().numpy()))

        if i < num_plots:
            _save_plot(f0_pred, midi, f0, f'{log_dir}/pic/{i}/{epoch}_{plot_tag}.png')

    msg = '\nEpoch: [{}][{}]\tLoss: {:.6f}\tRMSE:{:.6f}\n'.format(
        epoch, num_epochs, np.mean(losses), np.mean(rmses))
    with open(f'{log_dir}/{log_name}', 'a') as f:
        f.write(msg)


def _plot_real(model, loader, device, log_dir, epoch, f0_min, f0_max):
    """Plot predictions for every item in ``loader`` to ``{log_dir}/pic/{i}/{epoch}_real.png``.

    Frames where the reference ``f0`` is 0 are zeroed in the prediction before clamping.
    """
    for i, batch in enumerate(loader):
        mel, midi, f0 = _batch_to_device(batch, device)
        f0_pred = model(sp=mel, midi=midi)
        f0_pred[f0 == 0] = 0
        _clamp_f0(f0_pred, f0_min, f0_max)
        _save_plot(f0_pred, midi, f0, f'{log_dir}/pic/{i}/{epoch}_real.png')


if __name__ == "__main__":
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        args.device = 'cuda'
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cuda.matmul.allow_tf32 = True
        if torch.backends.cudnn.is_available():
            torch.backends.cudnn.deterministic = True
            # NOTE(review): benchmark=True lets cuDNN choose kernels by timing, which can
            # make runs non-reproducible despite deterministic=True above.
            torch.backends.cudnn.benchmark = True
    else:
        args.device = 'cpu'

    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.ckpt_dir, exist_ok=True)

    print('Initializing data loaders...')
    # Honour -data_dir (its default, 'data/', matches the previously hard-coded path).
    trainset = DiffPitch(args.data_dir, 'train', args.train_frames, shift=True)
    train_loader = DataLoader(trainset, batch_size=args.batch_size, num_workers=args.num_workers,
                              drop_last=True, shuffle=True)
    val_set = DiffPitch(args.data_dir, 'val', args.test_frames, shift=True)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)
    test_set = DiffPitch(args.data_dir, 'test', args.test_frames, shift=True)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)
    real_set = DiffPitch(args.data_dir, 'real', args.test_frames, shift=False)
    real_loader = DataLoader(real_set, batch_size=1, shuffle=False)

    print('Initializing and loading models...')
    model = PitchFormer(mel_cfg['n_mels'], 512).to(args.device)
    # Training starts from this hard-coded checkpoint; map_location lets a checkpoint
    # saved on GPU load on a CPU-only machine.
    ckpt = torch.load('ckpt_transformer_pitch/transformer_pitch_460.pt', map_location=args.device)
    model.load_state_dict(ckpt)

    print('Initializing optimizers...')
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scaler = GradScaler(enabled=args.amp)

    if args.compile:
        model = torch.compile(model)

    # Clamp bounds applied to predictions at evaluation time (hoisted out of the loops).
    f0_min = librosa.note_to_hz('C2')
    f0_max = librosa.note_to_hz('C6')

    print('Start training.')
    global_step = 0
    for epoch in range(1, args.epochs + 1):
        print(f'Epoch: {epoch} [iteration: {global_step}]')
        model.train()
        losses = []

        for step, batch in enumerate(tqdm(train_loader)):
            optimizer.zero_grad()
            mel, midi, f0 = _batch_to_device(batch, args.device)

            if args.amp:
                # Only the forward pass and loss run under autocast; backward and the
                # optimizer step happen outside it, per PyTorch AMP guidance. The loss is
                # L1 in both branches so the objective does not depend on -amp.
                with autocast():
                    f0_pred = model(sp=mel, midi=midi)
                    loss = F.l1_loss(f0_pred, f0)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                f0_pred = model(sp=mel, midi=midi)
                loss = F.l1_loss(f0_pred, f0)
                loss.backward()
                optimizer.step()

            losses.append(loss.item())
            global_step += 1

            if global_step % args.log_step == 0:
                msg = '\nEpoch: [{}][{}]\t' \
                      'Batch: [{}][{}]\tLoss: {:.6f}\n'.format(epoch,
                                                              args.epochs,
                                                              step + 1,
                                                              len(train_loader),
                                                              np.mean(losses))
                with open(f'{args.log_dir}/train_dec.log', 'a') as f:
                    f.write(msg)
                losses = []

        if epoch % args.save_every > 0:
            continue

        print('Saving model...\n')
        # torch.compile wraps the module and prefixes state_dict keys with '_orig_mod.';
        # save the underlying module so checkpoints load into a plain PitchFormer.
        raw_model = getattr(model, '_orig_mod', model)
        torch.save(raw_model.state_dict(), f=f"{args.ckpt_dir}/transformer_pitch_{epoch}.pt")

        print('Inference...\n')
        model.eval()
        with torch.no_grad():
            _evaluate(model, val_loader, args.device, args.log_dir, epoch, args.epochs,
                      'val', 'eval_dec.log', f0_min, f0_max)
            _evaluate(model, test_loader, args.device, args.log_dir, epoch, args.epochs,
                      'test', 'test_dec.log', f0_min, f0_max)
            _plot_real(model, real_loader, args.device, args.log_dir, epoch, f0_min, f0_max)
|
|
|
|
|
|
|
|
|
|