Victor Shirasuna committed on
Commit
8abbc76
·
1 Parent(s): 8c39e88

Added restart checkpoint in finetune

Browse files
smi-ted/finetune/args.py CHANGED
@@ -304,6 +304,7 @@ def get_parser(parser=None):
304
  # parser.add_argument("--patience_epochs", type=int, required=True)
305
  parser.add_argument("--model_path", type=str, default="./smi_ted/")
306
  parser.add_argument("--ckpt_filename", type=str, default="smi_ted_Light_40.pt")
 
307
  # parser.add_argument('--n_output', type=int, default=1)
308
  parser.add_argument("--save_every_epoch", type=int, default=0)
309
  parser.add_argument("--save_ckpt", type=int, default=1)
 
304
  # parser.add_argument("--patience_epochs", type=int, required=True)
305
  parser.add_argument("--model_path", type=str, default="./smi_ted/")
306
  parser.add_argument("--ckpt_filename", type=str, default="smi_ted_Light_40.pt")
307
+ parser.add_argument("--restart_filename", type=str, default="")
308
  # parser.add_argument('--n_output', type=int, default=1)
309
  parser.add_argument("--save_every_epoch", type=int, default=0)
310
  parser.add_argument("--save_ckpt", type=int, default=1)
smi-ted/finetune/finetune_classification.py CHANGED
@@ -48,6 +48,7 @@ def main(config):
48
  seed=config.start_seed,
49
  smi_ted_version=config.smi_ted_version,
50
  checkpoints_folder=config.checkpoints_folder,
 
51
  device=device,
52
  save_every_epoch=bool(config.save_every_epoch),
53
  save_ckpt=bool(config.save_ckpt)
 
48
  seed=config.start_seed,
49
  smi_ted_version=config.smi_ted_version,
50
  checkpoints_folder=config.checkpoints_folder,
51
+ restart_filename=config.restart_filename,
52
  device=device,
53
  save_every_epoch=bool(config.save_every_epoch),
54
  save_ckpt=bool(config.save_ckpt)
smi-ted/finetune/finetune_classification_multitask.py CHANGED
@@ -81,6 +81,7 @@ def main(config):
81
  seed=config.start_seed,
82
  smi_ted_version=config.smi_ted_version,
83
  checkpoints_folder=config.checkpoints_folder,
 
84
  device=device,
85
  save_every_epoch=bool(config.save_every_epoch),
86
  save_ckpt=bool(config.save_ckpt)
 
81
  seed=config.start_seed,
82
  smi_ted_version=config.smi_ted_version,
83
  checkpoints_folder=config.checkpoints_folder,
84
+ restart_filename=config.restart_filename,
85
  device=device,
86
  save_every_epoch=bool(config.save_every_epoch),
87
  save_ckpt=bool(config.save_ckpt)
smi-ted/finetune/finetune_regression.py CHANGED
@@ -50,6 +50,7 @@ def main(config):
50
  seed=config.start_seed,
51
  smi_ted_version=config.smi_ted_version,
52
  checkpoints_folder=config.checkpoints_folder,
 
53
  device=device,
54
  save_every_epoch=bool(config.save_every_epoch),
55
  save_ckpt=bool(config.save_ckpt)
 
50
  seed=config.start_seed,
51
  smi_ted_version=config.smi_ted_version,
52
  checkpoints_folder=config.checkpoints_folder,
53
+ restart_filename=config.restart_filename,
54
  device=device,
55
  save_every_epoch=bool(config.save_every_epoch),
56
  save_ckpt=bool(config.save_ckpt)
smi-ted/finetune/trainers.py CHANGED
@@ -26,7 +26,7 @@ from utils import RMSE, sensitivity, specificity
26
  class Trainer:
27
 
28
  def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
29
- target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
30
  # data
31
  self.df_train = raw_data[0]
32
  self.df_valid = raw_data[1]
@@ -42,6 +42,8 @@ class Trainer:
42
  self.seed = seed
43
  self.smi_ted_version = smi_ted_version
44
  self.checkpoints_folder = checkpoints_folder
 
 
45
  self.save_every_epoch = save_every_epoch
46
  self.save_ckpt = save_ckpt
47
  self.device = device
@@ -82,11 +84,14 @@ class Trainer:
82
  self.optimizer = optimizer
83
  self.loss_fn = loss_fn
84
  self._print_configuration()
 
 
 
85
 
86
  def fit(self, max_epochs=500):
87
  best_vloss = float('inf')
88
 
89
- for epoch in range(1, max_epochs+1):
90
  print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')
91
 
92
  # training
@@ -183,6 +188,7 @@ class Trainer:
183
  ckpt_path = os.path.join(self.checkpoints_folder, filename)
184
  ckpt_dict = torch.load(ckpt_path, map_location='cpu')
185
  self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
 
186
 
187
  def _save_checkpoint(self, current_epoch, filename):
188
  if not os.path.exists(self.checkpoints_folder):
@@ -229,9 +235,9 @@ class Trainer:
229
  class TrainerRegressor(Trainer):
230
 
231
  def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
232
- target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
233
  super().__init__(raw_data, dataset_name, target, batch_size, hparams,
234
- target_metric, seed, smi_ted_version, checkpoints_folder, save_every_epoch, save_ckpt, device)
235
 
236
  def _train_one_epoch(self):
237
  running_loss = 0.0
@@ -320,9 +326,9 @@ class TrainerRegressor(Trainer):
320
  class TrainerClassifier(Trainer):
321
 
322
  def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
323
- target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
324
  super().__init__(raw_data, dataset_name, target, batch_size, hparams,
325
- target_metric, seed, smi_ted_version, checkpoints_folder, save_every_epoch, save_ckpt, device)
326
 
327
  def _train_one_epoch(self):
328
  running_loss = 0.0
@@ -427,9 +433,9 @@ class TrainerClassifier(Trainer):
427
  class TrainerClassifierMultitask(Trainer):
428
 
429
  def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
430
- target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
431
  super().__init__(raw_data, dataset_name, target, batch_size, hparams,
432
- target_metric, seed, smi_ted_version, checkpoints_folder, save_every_epoch, save_ckpt, device)
433
 
434
  def _prepare_data(self):
435
  # normalize dataset
 
26
  class Trainer:
27
 
28
  def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
29
+ target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
30
  # data
31
  self.df_train = raw_data[0]
32
  self.df_valid = raw_data[1]
 
42
  self.seed = seed
43
  self.smi_ted_version = smi_ted_version
44
  self.checkpoints_folder = checkpoints_folder
45
+ self.restart_filename = restart_filename
46
+ self.start_epoch = 1
47
  self.save_every_epoch = save_every_epoch
48
  self.save_ckpt = save_ckpt
49
  self.device = device
 
84
  self.optimizer = optimizer
85
  self.loss_fn = loss_fn
86
  self._print_configuration()
87
+ if self.restart_filename:
88
+ self._load_checkpoint(self.restart_filename)
89
+ print('Checkpoint restored!')
90
 
91
  def fit(self, max_epochs=500):
92
  best_vloss = float('inf')
93
 
94
+ for epoch in range(self.start_epoch, max_epochs+1):
95
  print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')
96
 
97
  # training
 
188
  ckpt_path = os.path.join(self.checkpoints_folder, filename)
189
  ckpt_dict = torch.load(ckpt_path, map_location='cpu')
190
  self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
191
+ self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1
192
 
193
  def _save_checkpoint(self, current_epoch, filename):
194
  if not os.path.exists(self.checkpoints_folder):
 
235
  class TrainerRegressor(Trainer):
236
 
237
  def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
238
+ target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
239
  super().__init__(raw_data, dataset_name, target, batch_size, hparams,
240
+ target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
241
 
242
  def _train_one_epoch(self):
243
  running_loss = 0.0
 
326
  class TrainerClassifier(Trainer):
327
 
328
  def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
329
+ target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
330
  super().__init__(raw_data, dataset_name, target, batch_size, hparams,
331
+ target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
332
 
333
  def _train_one_epoch(self):
334
  running_loss = 0.0
 
433
  class TrainerClassifierMultitask(Trainer):
434
 
435
  def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
436
+ target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
437
  super().__init__(raw_data, dataset_name, target, batch_size, hparams,
438
+ target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
439
 
440
  def _prepare_data(self):
441
  # normalize dataset