Victor Shirasuna
committed
Commit 8abbc76 · 1 Parent(s): 8c39e88
Added restart checkpoint in finetune
smi-ted/finetune/args.py
CHANGED
@@ -304,6 +304,7 @@ def get_parser(parser=None):
     # parser.add_argument("--patience_epochs", type=int, required=True)
     parser.add_argument("--model_path", type=str, default="./smi_ted/")
     parser.add_argument("--ckpt_filename", type=str, default="smi_ted_Light_40.pt")
+    parser.add_argument("--restart_filename", type=str, default="")
     # parser.add_argument('--n_output', type=int, default=1)
     parser.add_argument("--save_every_epoch", type=int, default=0)
     parser.add_argument("--save_ckpt", type=int, default=1)
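The new `--restart_filename` flag defaults to an empty string, which is falsy, so existing runs are unaffected. A minimal standalone sketch of the flag's semantics (this rebuilds only the relevant argument, not the project's full parser; the checkpoint filename is hypothetical):

```python
import argparse

# Mirror of the single line added to args.py above.
parser = argparse.ArgumentParser()
parser.add_argument("--restart_filename", type=str, default="")

# Passing a saved finetune checkpoint resumes training from it.
config = parser.parse_args(["--restart_filename", "run1_epoch_12.pt"])  # hypothetical file
assert config.restart_filename  # non-empty -> Trainer will restore it

# Omitting the flag keeps the empty default, i.e. train from scratch.
config = parser.parse_args([])
assert not config.restart_filename
```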
smi-ted/finetune/finetune_classification.py
CHANGED
@@ -48,6 +48,7 @@ def main(config):
         seed=config.start_seed,
         smi_ted_version=config.smi_ted_version,
         checkpoints_folder=config.checkpoints_folder,
+        restart_filename=config.restart_filename,
         device=device,
         save_every_epoch=bool(config.save_every_epoch),
         save_ckpt=bool(config.save_ckpt)
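The same one-line wiring is applied to the multitask and regression entry points below. One detail worth noting: argparse hands the Trainer an empty string when the flag is omitted, while `Trainer.__init__` (in trainers.py below) defaults to `None`. Both are falsy, so `if self.restart_filename:` skips the restore in either case; a tiny sketch of that invariant:

```python
# Both "no restart" defaults behave identically in a truth test:
for default in ("", None):   # argparse default vs. Trainer default
    assert not default       # -> checkpoint restore is skipped
```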
smi-ted/finetune/finetune_classification_multitask.py
CHANGED
@@ -81,6 +81,7 @@ def main(config):
         seed=config.start_seed,
         smi_ted_version=config.smi_ted_version,
         checkpoints_folder=config.checkpoints_folder,
+        restart_filename=config.restart_filename,
         device=device,
         save_every_epoch=bool(config.save_every_epoch),
         save_ckpt=bool(config.save_ckpt)
smi-ted/finetune/finetune_regression.py
CHANGED
@@ -50,6 +50,7 @@ def main(config):
         seed=config.start_seed,
         smi_ted_version=config.smi_ted_version,
         checkpoints_folder=config.checkpoints_folder,
+        restart_filename=config.restart_filename,
         device=device,
         save_every_epoch=bool(config.save_every_epoch),
         save_ckpt=bool(config.save_ckpt)
smi-ted/finetune/trainers.py
CHANGED
@@ -26,7 +26,7 @@ from utils import RMSE, sensitivity, specificity
 class Trainer:
 
     def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
-                 target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
+                 target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
         # data
         self.df_train = raw_data[0]
         self.df_valid = raw_data[1]
@@ -42,6 +42,8 @@ class Trainer:
         self.seed = seed
         self.smi_ted_version = smi_ted_version
         self.checkpoints_folder = checkpoints_folder
+        self.restart_filename = restart_filename
+        self.start_epoch = 1
         self.save_every_epoch = save_every_epoch
         self.save_ckpt = save_ckpt
         self.device = device
@@ -82,11 +84,14 @@ class Trainer:
         self.optimizer = optimizer
         self.loss_fn = loss_fn
         self._print_configuration()
+        if self.restart_filename:
+            self._load_checkpoint(self.restart_filename)
+            print('Checkpoint restored!')
 
     def fit(self, max_epochs=500):
         best_vloss = float('inf')
 
-        for epoch in range(1, max_epochs+1):
+        for epoch in range(self.start_epoch, max_epochs+1):
             print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')
 
             # training
@@ -183,6 +188,7 @@ class Trainer:
         ckpt_path = os.path.join(self.checkpoints_folder, filename)
         ckpt_dict = torch.load(ckpt_path, map_location='cpu')
         self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
+        self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1
 
     def _save_checkpoint(self, current_epoch, filename):
         if not os.path.exists(self.checkpoints_folder):
@@ -229,9 +235,9 @@ class Trainer:
 class TrainerRegressor(Trainer):
 
     def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
-                 target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
+                 target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
         super().__init__(raw_data, dataset_name, target, batch_size, hparams,
-                         target_metric, seed, smi_ted_version, checkpoints_folder, save_every_epoch, save_ckpt, device)
+                         target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
 
     def _train_one_epoch(self):
         running_loss = 0.0
@@ -320,9 +326,9 @@ class TrainerRegressor(Trainer):
 class TrainerClassifier(Trainer):
 
     def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
-                 target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
+                 target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
         super().__init__(raw_data, dataset_name, target, batch_size, hparams,
-                         target_metric, seed, smi_ted_version, checkpoints_folder, save_every_epoch, save_ckpt, device)
+                         target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
 
     def _train_one_epoch(self):
         running_loss = 0.0
@@ -427,9 +433,9 @@ class TrainerClassifier(Trainer):
 class TrainerClassifierMultitask(Trainer):
 
     def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
-                 target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
+                 target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
         super().__init__(raw_data, dataset_name, target, batch_size, hparams,
-                         target_metric, seed, smi_ted_version, checkpoints_folder, save_every_epoch, save_ckpt, device)
+                         target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
 
     def _prepare_data(self):
         # normalize dataset
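Putting the trainers.py pieces together: the constructor records `restart_filename` and a `start_epoch` of 1; after the optimizer and loss are set, a non-empty `restart_filename` triggers `_load_checkpoint`, which restores the model weights and advances `start_epoch` past the last completed epoch; `fit` then resumes from there. A minimal self-contained sketch of that flow (a simplified stand-in, not the real Trainer; the `compile` hook name and the linear model are placeholders, and `EPOCHS_RUN` is the key `_load_checkpoint` reads above, presumably written by `_save_checkpoint`):

```python
import os
import torch
import torch.nn as nn

class MiniTrainer:
    """Condensed sketch of the restart logic added in this commit."""

    def __init__(self, checkpoints_folder='./checkpoints', restart_filename=None):
        self.checkpoints_folder = checkpoints_folder
        self.restart_filename = restart_filename
        self.start_epoch = 1                # fresh runs begin at epoch 1
        self.model = nn.Linear(4, 1)        # placeholder for the fine-tuning model

    def compile(self):                      # hypothetical name for the setup hook
        # Mirrors the check added after self._print_configuration() above.
        if self.restart_filename:
            self._load_checkpoint(self.restart_filename)
            print('Checkpoint restored!')

    def _load_checkpoint(self, filename):
        ckpt_path = os.path.join(self.checkpoints_folder, filename)
        ckpt_dict = torch.load(ckpt_path, map_location='cpu')
        self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
        self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1   # resume after last finished epoch

    def fit(self, max_epochs=500):
        # The loop now starts at start_epoch instead of a hardcoded 1.
        for epoch in range(self.start_epoch, max_epochs + 1):
            print(f'=====Epoch [{epoch}/{max_epochs}]=====')
            break  # training body elided
```

Resuming only works if the saved checkpoint dict actually carries `EPOCHS_RUN` alongside `MODEL_STATE`; the subclasses need no changes beyond forwarding the new argument, which is exactly what the three `super().__init__` edits above do.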