File size: 977 Bytes
7b918f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import matplotlib.pyplot as plt
import json
import torch
import torchaudio
def configure_args(config, args):
    """Overlay command-line argument values onto a nested configuration dict.

    ``config`` is modified in place. Only the sections ``"general"``,
    ``"preprocess"`` and ``"train"`` are written, and they (plus
    ``config["train"]["feature_loss"]`` when ``args.feature_loss_type`` is
    set) must already exist, otherwise ``KeyError`` is raised.

    Most options are copied only when the matching attribute on ``args`` is
    not ``None``, so unset arguments leave the existing config value alone.
    ``load_pretrained`` and ``early_stopping`` are the exception: they are
    always copied, even when their value is ``None``.

    Args:
        config: Nested dict of configuration sections; updated in place.
        args: Namespace-like object (e.g. an ``argparse.Namespace``) that
            exposes every option attribute read below; a missing attribute
            raises ``AttributeError``.

    Returns:
        The ``(config, args)`` pair; ``config`` is the same object that was
        passed in.
    """

    def _copy_if_set(section, keys, convert=None):
        # Copy each ``args`` attribute that is not None into config[section],
        # optionally passing it through ``convert`` first.
        for key in keys:
            value = getattr(args, key)
            # ``is not None`` rather than ``!= None``: identity is what is
            # meant, and ``!=`` can be overridden (or be elementwise for
            # array-like values, which then cannot be used in an ``if``).
            if value is not None:
                config[section][key] = value if convert is None else convert(value)

    # Values in the "general" section are stored as strings.
    _copy_if_set(
        "general",
        ["stage", "corpus_type", "source_path", "aux_path", "preprocessed_path"],
        convert=str,
    )
    _copy_if_set("preprocess", ["n_train", "n_val", "n_test"])
    _copy_if_set("train", ["alpha", "beta", "learning_rate", "epoch"])

    # These two options are copied unconditionally, so a None on ``args``
    # overwrites whatever the config held.
    for key in ("load_pretrained", "early_stopping"):
        config["train"][key] = getattr(args, key)

    if args.feature_loss_type is not None:
        config["train"]["feature_loss"]["type"] = args.feature_loss_type

    _copy_if_set("train", ["pretrained_path"], convert=str)

    return config, args
|