File size: 5,508 Bytes
45916af 871344c 45916af bdb2571 871344c 45916af 6bd0ee9 45916af 6bd0ee9 45916af 5f5e024 45916af 6bd0ee9 45916af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import argparse
import os
import pytorch_lightning as pl
import soundfile as sf
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.model_summary import summarize
from torch.utils.data import DataLoader
from config import CONFIG
from dataset import TrainDataset, TestLoader, BlindTestLoader
from models.frn import PLCModel, OnnxWrapper
from utils.tblogger import TensorBoardLoggerExpanded
from utils.utils import mkdir_p
# Command-line interface. argparse validates --mode via `choices` (an
# `assert` would be stripped under `python -O`), and the non-training modes
# require --version because they load an existing checkpoint.
parser = argparse.ArgumentParser()
parser.add_argument('--version', default=None,
                    help='version to resume (required for eval/test/onnx)')
parser.add_argument('--mode', default='train',
                    choices=['train', 'eval', 'test', 'onnx'],
                    help='training or testing mode')
args = parser.parse_args()
if args.mode != 'train' and args.version is None:
    # Without this check the loader would look for a 'version_None' directory.
    parser.error("--version is required when --mode is '{}'".format(args.mode))
# Set at import time, before any CUDA work done by the functions below, so the
# device mask takes effect for the whole process.
os.environ["CUDA_VISIBLE_DEVICES"] = str(CONFIG.gpus)
def find_checkpoint(model_dir):
    """Return the path of a ``.ckpt`` file inside ``model_dir``.

    ``os.listdir`` order is unspecified, so candidates are sorted by name to
    make the choice deterministic, and the last one is returned. The epoch
    field in ``train()``'s checkpoint filename template is zero-padded, so
    this is typically the latest epoch when several checkpoints exist.

    Raises:
        FileNotFoundError: if ``model_dir`` does not exist or contains no
            ``.ckpt`` file.
    """
    candidates = sorted(name for name in os.listdir(model_dir) if name.endswith('.ckpt'))
    if not candidates:
        raise FileNotFoundError('No .ckpt file found in {}'.format(model_dir))
    return os.path.join(model_dir, candidates[-1])


def resume(train_dataset, val_dataset, version):
    """Load a PLCModel from ``CONFIG.LOG.log_dir/version_<version>/``.

    The weights come from the ``.ckpt`` file in that version's
    ``checkpoints/`` directory. The hyper-parameters come from its
    ``hparams.yaml``.

    Args:
        train_dataset: dataset forwarded to the model constructor (may be
            None when only evaluating/exporting).
        val_dataset: dataset forwarded to the model constructor (may be None).
        version: version number/name of the run to load.

    Returns:
        The loaded PLCModel.

    Raises:
        FileNotFoundError: if the checkpoint directory is missing or empty.
    """
    print("Version", version)
    version_dir = os.path.join(CONFIG.LOG.log_dir, 'version_{}'.format(version))
    config_path = os.path.join(version_dir, 'hparams.yaml')
    ckpt_path = find_checkpoint(os.path.join(version_dir, 'checkpoints'))
    model = PLCModel.load_from_checkpoint(ckpt_path,
                                          strict=True,
                                          hparams_file=config_path,
                                          train_dataset=train_dataset,
                                          val_dataset=val_dataset,
                                          window_size=CONFIG.DATA.window_size)
    return model
def train():
    """Build (or resume) a PLCModel and fit it on the train/val splits.

    If ``--version`` was given, the model is loaded from that version via
    ``resume``. Otherwise a fresh model is built from ``CONFIG.MODEL``.
    Checkpoints are tracked on ``val_loss`` (lower is better).

    NOTE(review): resuming only reloads weights/hparams; ``trainer.fit`` is
    called without ``ckpt_path``, so optimizer state and the epoch counter
    start fresh. Pass the checkpoint path to ``fit`` if a true resume is wanted.
    """
    train_dataset = TrainDataset('train')
    val_dataset = TrainDataset('val')
    checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min', verbose=True,
                                          filename='frn-{epoch:02d}-{val_loss:.4f}',
                                          save_weights_only=False)
    # str() mirrors the CUDA_VISIBLE_DEVICES assignment at module level, so a
    # non-string CONFIG.gpus (e.g. a single int id) does not crash here.
    gpus = str(CONFIG.gpus).split(',')
    logger = TensorBoardLoggerExpanded(CONFIG.DATA.sr)
    if args.version is not None:
        model = resume(train_dataset, val_dataset, args.version)
    else:
        model = PLCModel(train_dataset,
                         val_dataset,
                         window_size=CONFIG.DATA.window_size,
                         enc_layers=CONFIG.MODEL.enc_layers,
                         enc_in_dim=CONFIG.MODEL.enc_in_dim,
                         enc_dim=CONFIG.MODEL.enc_dim,
                         pred_dim=CONFIG.MODEL.pred_dim,
                         pred_layers=CONFIG.MODEL.pred_layers)
    # accelerator/devices replaces the deprecated `gpus=` argument and matches
    # the Trainer construction used by the eval/test modes of this script.
    trainer = pl.Trainer(logger=logger,
                         gradient_clip_val=CONFIG.TRAIN.clipping_val,
                         accelerator='gpu',
                         devices=len(gpus),
                         max_epochs=CONFIG.TRAIN.epochs,
                         callbacks=[checkpoint_callback])
    print(model.hparams)
    print(
        'Dataset: {}, Train files: {}, Val files {}'.format(CONFIG.DATA.dataset, len(train_dataset), len(val_dataset)))
    trainer.fit(model)
def to_onnx(model, onnx_path):
    """Export ``model`` as an ONNX graph written to ``onnx_path``.

    The model is switched to eval mode in place, then wrapped in
    ``OnnxWrapper``. The wrapper provides the example input (``sample``) and
    the input/output tensor names passed to ``torch.onnx.export``. The export
    uses opset 12 with parameters embedded and constant folding enabled.
    """
    model.eval()
    wrapped = OnnxWrapper(model)
    export_options = dict(
        export_params=True,
        opset_version=12,
        input_names=wrapped.input_names,
        output_names=wrapped.output_names,
        do_constant_folding=True,
        verbose=False,
    )
    torch.onnx.export(wrapped, wrapped.sample, onnx_path, **export_options)
if __name__ == '__main__':
    if args.mode == 'train':
        train()
    else:
        # All non-training modes start from a frozen, eval-mode model loaded
        # from the requested version.
        model = resume(None, None, args.version)
        print(model.hparams)
        print(summarize(model))
        model.eval()
        model.freeze()
        if args.mode == 'eval':
            model.cuda(device=0)
            trainer = pl.Trainer(accelerator='gpu', devices=1, enable_checkpointing=False, logger=False)
            testset = TestLoader()
            test_loader = DataLoader(testset, batch_size=1, num_workers=4)
            trainer.test(model, test_loader)
            print('Version', args.version)
            masking = CONFIG.DATA.EVAL.masking
            if masking == 'real':
                print('Evaluate with real trace')
            else:
                # Only needed for generated traces; computing it for real traces
                # could divide by zero on degenerate transition probabilities.
                # Presumably prob = (P(stay received), P(stay lost)) of a two-state
                # Markov loss model, making this its stationary loss rate — confirm.
                prob = CONFIG.DATA.EVAL.transition_probs[0]
                loss_percent = (1 - prob[0]) / (2 - prob[0] - prob[1]) * 100
                print('Evaluate with generated trace with {:.2f}% packet loss'.format(loss_percent))
        elif args.mode == 'test':
            model.cuda(device=0)
            testset = BlindTestLoader(test_dir=CONFIG.TEST.in_dir)
            test_loader = DataLoader(testset, batch_size=1, num_workers=4)
            trainer = pl.Trainer(accelerator='gpu', devices=1, enable_checkpointing=False, logger=False)
            preds = trainer.predict(model, test_loader, return_predictions=True)
            mkdir_p(CONFIG.TEST.out_dir)
            # One prediction per input file (batch_size=1), written under the
            # input file's basename.
            for idx, path in enumerate(test_loader.dataset.data_list):
                out_path = os.path.join(CONFIG.TEST.out_dir, os.path.basename(path))
                sf.write(out_path, preds[idx], samplerate=CONFIG.DATA.sr, subtype='PCM_16')
        else:
            # Write next to the loaded checkpoints, rooted at CONFIG.LOG.log_dir
            # like resume(), rather than a hard-coded 'lightning_logs' directory.
            onnx_path = os.path.join(CONFIG.LOG.log_dir, 'version_{}'.format(args.version),
                                     'checkpoints', 'frn.onnx')
            to_onnx(model, onnx_path)
            print('ONNX model saved to', onnx_path)
|