eduardosoares99 vshirasuna commited on
Commit
30e2e3e
·
verified ·
1 Parent(s): 752e501

Update finetune code (#8)

Browse files

- Added changes (db22044aa220f47192e3176972f36806d889b651)
- Fix multitask finetune (6f04789f1c86f548ec85d6b57663e7b6021d8410)
- Fix typo (8c39e885cc0c05402daae7e4f9013493bbf7e241)
- Added restart checkpoint in finetune (8abbc761c751ebb56ae1b476bac06c76a8014946)
- Restore best_vloss in finetune (f6401dcdb8630da3b3480d5f27e55dc9cf22d2b3)


Co-authored-by: Victor Yukio Shirasuna <[email protected]>

smi-ted/finetune/args.py CHANGED
@@ -304,6 +304,7 @@ def get_parser(parser=None):
304
  # parser.add_argument("--patience_epochs", type=int, required=True)
305
  parser.add_argument("--model_path", type=str, default="./smi_ted/")
306
  parser.add_argument("--ckpt_filename", type=str, default="smi_ted_Light_40.pt")
 
307
  # parser.add_argument('--n_output', type=int, default=1)
308
  parser.add_argument("--save_every_epoch", type=int, default=0)
309
  parser.add_argument("--save_ckpt", type=int, default=1)
 
304
  # parser.add_argument("--patience_epochs", type=int, required=True)
305
  parser.add_argument("--model_path", type=str, default="./smi_ted/")
306
  parser.add_argument("--ckpt_filename", type=str, default="smi_ted_Light_40.pt")
307
+ parser.add_argument("--restart_filename", type=str, default="")
308
  # parser.add_argument('--n_output', type=int, default=1)
309
  parser.add_argument("--save_every_epoch", type=int, default=0)
310
  parser.add_argument("--save_ckpt", type=int, default=1)
smi-ted/finetune/finetune_classification.py CHANGED
@@ -28,7 +28,7 @@ def main(config):
28
  elif config.smi_ted_version == 'v2':
29
  from smi_ted_large.load import load_smi_ted
30
 
31
- model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=config.n_output)
32
  model.net.apply(model._init_weights)
33
  print(model.net)
34
 
@@ -46,7 +46,9 @@ def main(config):
46
  hparams=config,
47
  target_metric=config.target_metric,
48
  seed=config.start_seed,
 
49
  checkpoints_folder=config.checkpoints_folder,
 
50
  device=device,
51
  save_every_epoch=bool(config.save_every_epoch),
52
  save_ckpt=bool(config.save_ckpt)
 
28
  elif config.smi_ted_version == 'v2':
29
  from smi_ted_large.load import load_smi_ted
30
 
31
+ model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=config.n_output, eval=False)
32
  model.net.apply(model._init_weights)
33
  print(model.net)
34
 
 
46
  hparams=config,
47
  target_metric=config.target_metric,
48
  seed=config.start_seed,
49
+ smi_ted_version=config.smi_ted_version,
50
  checkpoints_folder=config.checkpoints_folder,
51
+ restart_filename=config.restart_filename,
52
  device=device,
53
  save_every_epoch=bool(config.save_every_epoch),
54
  save_ckpt=bool(config.save_ckpt)
smi-ted/finetune/finetune_classification_multitask.py CHANGED
@@ -48,6 +48,7 @@ def main(config):
48
  'MUV-692', 'MUV-712', 'MUV-713', 'MUV-733', 'MUV-737', 'MUV-810',
49
  'MUV-832', 'MUV-846', 'MUV-852', 'MUV-858', 'MUV-859'
50
  ]
 
51
 
52
  # load dataset
53
  df_train = pd.read_csv(f"{config.data_root}/train.csv")
@@ -60,7 +61,7 @@ def main(config):
60
  elif config.smi_ted_version == 'v2':
61
  from smi_ted_large.load import load_smi_ted
62
 
63
- model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=len(targets))
64
  model.net.apply(model._init_weights)
65
  print(model.net)
66
 
@@ -78,7 +79,9 @@ def main(config):
78
  hparams=config,
79
  target_metric=config.target_metric,
80
  seed=config.start_seed,
 
81
  checkpoints_folder=config.checkpoints_folder,
 
82
  device=device,
83
  save_every_epoch=bool(config.save_every_epoch),
84
  save_ckpt=bool(config.save_ckpt)
 
48
  'MUV-692', 'MUV-712', 'MUV-713', 'MUV-733', 'MUV-737', 'MUV-810',
49
  'MUV-832', 'MUV-846', 'MUV-852', 'MUV-858', 'MUV-859'
50
  ]
51
+ config.n_output = len(targets)
52
 
53
  # load dataset
54
  df_train = pd.read_csv(f"{config.data_root}/train.csv")
 
61
  elif config.smi_ted_version == 'v2':
62
  from smi_ted_large.load import load_smi_ted
63
 
64
+ model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=len(targets), eval=False)
65
  model.net.apply(model._init_weights)
66
  print(model.net)
67
 
 
79
  hparams=config,
80
  target_metric=config.target_metric,
81
  seed=config.start_seed,
82
+ smi_ted_version=config.smi_ted_version,
83
  checkpoints_folder=config.checkpoints_folder,
84
+ restart_filename=config.restart_filename,
85
  device=device,
86
  save_every_epoch=bool(config.save_every_epoch),
87
  save_ckpt=bool(config.save_ckpt)
smi-ted/finetune/finetune_regression.py CHANGED
@@ -28,7 +28,7 @@ def main(config):
28
  elif config.smi_ted_version == 'v2':
29
  from smi_ted_large.load import load_smi_ted
30
 
31
- model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=config.n_output)
32
  model.net.apply(model._init_weights)
33
  print(model.net)
34
 
@@ -48,7 +48,9 @@ def main(config):
48
  hparams=config,
49
  target_metric=config.target_metric,
50
  seed=config.start_seed,
 
51
  checkpoints_folder=config.checkpoints_folder,
 
52
  device=device,
53
  save_every_epoch=bool(config.save_every_epoch),
54
  save_ckpt=bool(config.save_ckpt)
 
28
  elif config.smi_ted_version == 'v2':
29
  from smi_ted_large.load import load_smi_ted
30
 
31
+ model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=config.n_output, eval=False)
32
  model.net.apply(model._init_weights)
33
  print(model.net)
34
 
 
48
  hparams=config,
49
  target_metric=config.target_metric,
50
  seed=config.start_seed,
51
+ smi_ted_version=config.smi_ted_version,
52
  checkpoints_folder=config.checkpoints_folder,
53
+ restart_filename=config.restart_filename,
54
  device=device,
55
  save_every_epoch=bool(config.save_every_epoch),
56
  save_ckpt=bool(config.save_ckpt)
smi-ted/finetune/smi_ted_large/load.py CHANGED
@@ -318,7 +318,7 @@ class Net(nn.Module):
318
 
319
  class MoLEncoder(nn.Module):
320
 
321
- def __init__(self, config, n_vocab):
322
  super(MoLEncoder, self).__init__()
323
 
324
  # embeddings
@@ -337,7 +337,7 @@ class MoLEncoder(nn.Module):
337
  # unless we do deterministic_eval here, we will have random outputs
338
  feature_map=partial(GeneralizedRandomFeatures,
339
  n_dims=config['num_feats'],
340
- deterministic_eval=False),
341
  activation='gelu'
342
  )
343
  self.blocks = builder.get()
@@ -361,7 +361,7 @@ class MoLDecoder(nn.Module):
361
  class Smi_ted(nn.Module):
362
  """materials.smi-ted-Large 738M Parameters"""
363
 
364
- def __init__(self, tokenizer, config=None):
365
  super(Smi_ted, self).__init__()
366
 
367
  # configuration
@@ -373,11 +373,11 @@ class Smi_ted(nn.Module):
373
 
374
  # instantiate modules
375
  if self.config:
376
- self.encoder = MoLEncoder(self.config, self.n_vocab)
377
  self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd'])
378
  self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['dropout'])
379
 
380
- def load_checkpoint(self, ckpt_path, n_output):
381
  # load checkpoint file
382
  checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
383
 
@@ -388,7 +388,7 @@ class Smi_ted(nn.Module):
388
  self._set_seed(self.config['seed'])
389
 
390
  # instantiate modules
391
- self.encoder = MoLEncoder(self.config, self.n_vocab)
392
  self.decoder = MoLDecoder(self.n_vocab, self.max_len, self.n_embd)
393
  self.net = Net(self.n_embd, n_output=self.config['n_output'] if 'n_output' in self.config else n_output, dropout=self.config['dropout'])
394
 
@@ -493,11 +493,12 @@ class Smi_ted(nn.Module):
493
  def load_smi_ted(folder="./smi_ted_large",
494
  ckpt_filename="smi-ted-Large_30.pt",
495
  vocab_filename="bert_vocab_curated.txt",
496
- n_output=1
 
497
  ):
498
  tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename))
499
  model = Smi_ted(tokenizer)
500
- model.load_checkpoint(os.path.join(folder, ckpt_filename), n_output)
501
  print('Vocab size:', len(tokenizer.vocab))
502
  print(f'[FINETUNE MODE - {str(model)}]')
503
  return model
 
318
 
319
  class MoLEncoder(nn.Module):
320
 
321
+ def __init__(self, config, n_vocab, eval=False):
322
  super(MoLEncoder, self).__init__()
323
 
324
  # embeddings
 
337
  # unless we do deterministic_eval here, we will have random outputs
338
  feature_map=partial(GeneralizedRandomFeatures,
339
  n_dims=config['num_feats'],
340
+ deterministic_eval=eval),
341
  activation='gelu'
342
  )
343
  self.blocks = builder.get()
 
361
  class Smi_ted(nn.Module):
362
  """materials.smi-ted-Large 738M Parameters"""
363
 
364
+ def __init__(self, tokenizer, config=None, eval=False):
365
  super(Smi_ted, self).__init__()
366
 
367
  # configuration
 
373
 
374
  # instantiate modules
375
  if self.config:
376
+ self.encoder = MoLEncoder(self.config, self.n_vocab, eval=eval)
377
  self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd'])
378
  self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['dropout'])
379
 
380
+ def load_checkpoint(self, ckpt_path, n_output, eval=False):
381
  # load checkpoint file
382
  checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
383
 
 
388
  self._set_seed(self.config['seed'])
389
 
390
  # instantiate modules
391
+ self.encoder = MoLEncoder(self.config, self.n_vocab, eval=eval)
392
  self.decoder = MoLDecoder(self.n_vocab, self.max_len, self.n_embd)
393
  self.net = Net(self.n_embd, n_output=self.config['n_output'] if 'n_output' in self.config else n_output, dropout=self.config['dropout'])
394
 
 
493
  def load_smi_ted(folder="./smi_ted_large",
494
  ckpt_filename="smi-ted-Large_30.pt",
495
  vocab_filename="bert_vocab_curated.txt",
496
+ n_output=1,
497
+ eval=False
498
  ):
499
  tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename))
500
  model = Smi_ted(tokenizer)
501
+ model.load_checkpoint(os.path.join(folder, ckpt_filename), n_output, eval=eval)
502
  print('Vocab size:', len(tokenizer.vocab))
503
  print(f'[FINETUNE MODE - {str(model)}]')
504
  return model
smi-ted/finetune/smi_ted_light/load.py CHANGED
@@ -318,7 +318,7 @@ class Net(nn.Module):
318
 
319
  class MoLEncoder(nn.Module):
320
 
321
- def __init__(self, config, n_vocab):
322
  super(MoLEncoder, self).__init__()
323
 
324
  # embeddings
@@ -337,7 +337,7 @@ class MoLEncoder(nn.Module):
337
  # unless we do deterministic_eval here, we will have random outputs
338
  feature_map=partial(GeneralizedRandomFeatures,
339
  n_dims=config['num_feats'],
340
- deterministic_eval=False),
341
  activation='gelu'
342
  )
343
  self.blocks = builder.get()
@@ -361,7 +361,7 @@ class MoLDecoder(nn.Module):
361
  class Smi_ted(nn.Module):
362
  """materials.smi-ted-Light 289M Parameters"""
363
 
364
- def __init__(self, tokenizer, config=None):
365
  super(Smi_ted, self).__init__()
366
 
367
  # configuration
@@ -373,11 +373,11 @@ class Smi_ted(nn.Module):
373
 
374
  # instantiate modules
375
  if self.config:
376
- self.encoder = MoLEncoder(self.config, self.n_vocab)
377
  self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd'])
378
  self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['dropout'])
379
 
380
- def load_checkpoint(self, ckpt_path, n_output):
381
  # load checkpoint file
382
  checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
383
 
@@ -388,7 +388,7 @@ class Smi_ted(nn.Module):
388
  self._set_seed(self.config['seed'])
389
 
390
  # instantiate modules
391
- self.encoder = MoLEncoder(self.config, self.n_vocab)
392
  self.decoder = MoLDecoder(self.n_vocab, self.max_len, self.n_embd)
393
  self.net = Net(self.n_embd, n_output=self.config['n_output'] if 'n_output' in self.config else n_output, dropout=self.config['dropout'])
394
 
@@ -493,11 +493,12 @@ class Smi_ted(nn.Module):
493
  def load_smi_ted(folder="./smi_ted_light",
494
  ckpt_filename="smi-ted-Light_40.pt",
495
  vocab_filename="bert_vocab_curated.txt",
496
- n_output=1
 
497
  ):
498
  tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename))
499
  model = Smi_ted(tokenizer)
500
- model.load_checkpoint(os.path.join(folder, ckpt_filename), n_output)
501
  print('Vocab size:', len(tokenizer.vocab))
502
  print(f'[FINETUNE MODE - {str(model)}]')
503
  return model
 
318
 
319
  class MoLEncoder(nn.Module):
320
 
321
+ def __init__(self, config, n_vocab, eval=False):
322
  super(MoLEncoder, self).__init__()
323
 
324
  # embeddings
 
337
  # unless we do deterministic_eval here, we will have random outputs
338
  feature_map=partial(GeneralizedRandomFeatures,
339
  n_dims=config['num_feats'],
340
+ deterministic_eval=eval),
341
  activation='gelu'
342
  )
343
  self.blocks = builder.get()
 
361
  class Smi_ted(nn.Module):
362
  """materials.smi-ted-Light 289M Parameters"""
363
 
364
+ def __init__(self, tokenizer, config=None, eval=False):
365
  super(Smi_ted, self).__init__()
366
 
367
  # configuration
 
373
 
374
  # instantiate modules
375
  if self.config:
376
+ self.encoder = MoLEncoder(self.config, self.n_vocab, eval=eval)
377
  self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd'])
378
  self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['dropout'])
379
 
380
+ def load_checkpoint(self, ckpt_path, n_output, eval=False):
381
  # load checkpoint file
382
  checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
383
 
 
388
  self._set_seed(self.config['seed'])
389
 
390
  # instantiate modules
391
+ self.encoder = MoLEncoder(self.config, self.n_vocab, eval=eval)
392
  self.decoder = MoLDecoder(self.n_vocab, self.max_len, self.n_embd)
393
  self.net = Net(self.n_embd, n_output=self.config['n_output'] if 'n_output' in self.config else n_output, dropout=self.config['dropout'])
394
 
 
493
  def load_smi_ted(folder="./smi_ted_light",
494
  ckpt_filename="smi-ted-Light_40.pt",
495
  vocab_filename="bert_vocab_curated.txt",
496
+ n_output=1,
497
+ eval=False
498
  ):
499
  tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename))
500
  model = Smi_ted(tokenizer)
501
+ model.load_checkpoint(os.path.join(folder, ckpt_filename), n_output, eval=eval)
502
  print('Vocab size:', len(tokenizer.vocab))
503
  print(f'[FINETUNE MODE - {str(model)}]')
504
  return model
smi-ted/finetune/trainers.py CHANGED
@@ -14,6 +14,7 @@ import numpy as np
14
  import random
15
  import args
16
  import os
 
17
  from tqdm import tqdm
18
 
19
  # Machine Learning
@@ -25,7 +26,7 @@ from utils import RMSE, sensitivity, specificity
25
  class Trainer:
26
 
27
  def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
28
- target_metric='rmse', seed=0, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
29
  # data
30
  self.df_train = raw_data[0]
31
  self.df_valid = raw_data[1]
@@ -39,10 +40,15 @@ class Trainer:
39
  # config
40
  self.target_metric = target_metric
41
  self.seed = seed
 
42
  self.checkpoints_folder = checkpoints_folder
 
 
43
  self.save_every_epoch = save_every_epoch
44
  self.save_ckpt = save_ckpt
45
  self.device = device
 
 
46
  self._set_seed(seed)
47
 
48
  def _prepare_data(self):
@@ -80,11 +86,12 @@ class Trainer:
80
  self.optimizer = optimizer
81
  self.loss_fn = loss_fn
82
  self._print_configuration()
 
 
 
83
 
84
  def fit(self, max_epochs=500):
85
- best_vloss = float('inf')
86
-
87
- for epoch in range(1, max_epochs+1):
88
  print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')
89
 
90
  # training
@@ -99,44 +106,68 @@ class Trainer:
99
  print(f"[VALID] Evaluation {m.upper()}: {round(val_metrics[m], 4)}")
100
 
101
  ############################### Save Finetune checkpoint #######################################
102
- if ((val_loss < best_vloss) or self.save_every_epoch) and self.save_ckpt:
103
  # remove old checkpoint
104
- if best_vloss != float('inf') and not self.save_every_epoch:
105
  os.remove(os.path.join(self.checkpoints_folder, self.last_filename))
106
 
107
  # filename
108
  model_name = f'{str(self.model)}-Finetune'
109
- self.last_filename = f"{model_name}_epoch={epoch}_{self.dataset_name}_seed{self.seed}_valloss={round(val_loss, 4)}.pt"
 
 
 
110
 
111
  # save checkpoint
112
  print('Saving checkpoint...')
113
  self._save_checkpoint(epoch, self.last_filename)
114
 
115
- # update best loss
116
- best_vloss = val_loss
117
-
118
- def evaluate(self):
119
- print("\n=====Test Evaluation=====")
120
- self._load_checkpoint(self.last_filename)
121
- self.model.eval()
122
- tst_preds, tst_loss, tst_metrics = self._validate_one_epoch(self.test_loader)
123
-
124
- # show metrics
125
- for m in tst_metrics.keys():
126
- print(f"[TEST] Evaluation {m.upper()}: {round(tst_metrics[m], 4)}")
127
-
128
- # save predictions
129
- pd.DataFrame(tst_preds).to_csv(
130
- os.path.join(
131
- self.checkpoints_folder,
132
- f'{self.dataset_name}_{self.target if isinstance(self.target, str) else self.target[0]}_predict_test_seed{self.seed}.csv'),
133
- index=False
134
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  def _train_one_epoch(self):
137
  raise NotImplementedError
138
 
139
- def _validate_one_epoch(self, data_loader):
140
  raise NotImplementedError
141
 
142
  def _print_configuration(self):
@@ -157,6 +188,8 @@ class Trainer:
157
  ckpt_path = os.path.join(self.checkpoints_folder, filename)
158
  ckpt_dict = torch.load(ckpt_path, map_location='cpu')
159
  self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
 
 
160
 
161
  def _save_checkpoint(self, current_epoch, filename):
162
  if not os.path.exists(self.checkpoints_folder):
@@ -177,6 +210,7 @@ class Trainer:
177
  'train_size': self.df_train.shape[0],
178
  'valid_size': self.df_valid.shape[0],
179
  'test_size': self.df_test.shape[0],
 
180
  },
181
  'seed': self.seed,
182
  }
@@ -203,9 +237,9 @@ class Trainer:
203
  class TrainerRegressor(Trainer):
204
 
205
  def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
206
- target_metric='rmse', seed=0, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
207
  super().__init__(raw_data, dataset_name, target, batch_size, hparams,
208
- target_metric, seed, checkpoints_folder, save_every_epoch, save_ckpt, device)
209
 
210
  def _train_one_epoch(self):
211
  running_loss = 0.0
@@ -239,11 +273,13 @@ class TrainerRegressor(Trainer):
239
 
240
  return running_loss / len(self.train_loader)
241
 
242
- def _validate_one_epoch(self, data_loader):
243
  data_targets = []
244
  data_preds = []
245
  running_loss = 0.0
246
 
 
 
247
  with torch.no_grad():
248
  for idx, data in enumerate(pbar := tqdm(data_loader)):
249
  # Every data instance is an input + label pair
@@ -251,8 +287,8 @@ class TrainerRegressor(Trainer):
251
  targets = targets.clone().detach().to(self.device)
252
 
253
  # Make predictions for this batch
254
- embeddings = self.model.extract_embeddings(smiles).to(self.device)
255
- predictions = self.model.net(embeddings).squeeze()
256
 
257
  # Compute the loss
258
  loss = self.loss_fn(predictions, targets)
@@ -292,9 +328,9 @@ class TrainerRegressor(Trainer):
292
  class TrainerClassifier(Trainer):
293
 
294
  def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
295
- target_metric='roc-auc', seed=0, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
296
  super().__init__(raw_data, dataset_name, target, batch_size, hparams,
297
- target_metric, seed, checkpoints_folder, save_every_epoch, save_ckpt, device)
298
 
299
  def _train_one_epoch(self):
300
  running_loss = 0.0
@@ -328,11 +364,13 @@ class TrainerClassifier(Trainer):
328
 
329
  return running_loss / len(self.train_loader)
330
 
331
- def _validate_one_epoch(self, data_loader):
332
  data_targets = []
333
  data_preds = []
334
  running_loss = 0.0
335
 
 
 
336
  with torch.no_grad():
337
  for idx, data in enumerate(pbar := tqdm(data_loader)):
338
  # Every data instance is an input + label pair
@@ -340,8 +378,8 @@ class TrainerClassifier(Trainer):
340
  targets = targets.clone().detach().to(self.device)
341
 
342
  # Make predictions for this batch
343
- embeddings = self.model.extract_embeddings(smiles).to(self.device)
344
- predictions = self.model.net(embeddings).squeeze()
345
 
346
  # Compute the loss
347
  loss = self.loss_fn(predictions, targets.long())
@@ -397,9 +435,9 @@ class TrainerClassifier(Trainer):
397
  class TrainerClassifierMultitask(Trainer):
398
 
399
  def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
400
- target_metric='roc-auc', seed=0, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
401
  super().__init__(raw_data, dataset_name, target, batch_size, hparams,
402
- target_metric, seed, checkpoints_folder, save_every_epoch, save_ckpt, device)
403
 
404
  def _prepare_data(self):
405
  # normalize dataset
@@ -464,12 +502,14 @@ class TrainerClassifierMultitask(Trainer):
464
 
465
  return running_loss / len(self.train_loader)
466
 
467
- def _validate_one_epoch(self, data_loader):
468
  data_targets = []
469
  data_preds = []
470
  data_masks = []
471
  running_loss = 0.0
472
 
 
 
473
  with torch.no_grad():
474
  for idx, data in enumerate(pbar := tqdm(data_loader)):
475
  # Every data instance is an input + label pair + mask
@@ -477,8 +517,8 @@ class TrainerClassifierMultitask(Trainer):
477
  targets = targets.clone().detach().to(self.device)
478
 
479
  # Make predictions for this batch
480
- embeddings = self.model.extract_embeddings(smiles).to(self.device)
481
- predictions = self.model.net(embeddings, multitask=True).squeeze()
482
  predictions = predictions * target_masks.to(self.device)
483
 
484
  # Compute the loss
@@ -548,4 +588,4 @@ class TrainerClassifierMultitask(Trainer):
548
  'specificity': average_sp.item(),
549
  }
550
 
551
- return preds, running_loss / len(data_loader), metrics
 
14
  import random
15
  import args
16
  import os
17
+ import shutil
18
  from tqdm import tqdm
19
 
20
  # Machine Learning
 
26
  class Trainer:
27
 
28
  def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
29
+ target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
30
  # data
31
  self.df_train = raw_data[0]
32
  self.df_valid = raw_data[1]
 
40
  # config
41
  self.target_metric = target_metric
42
  self.seed = seed
43
+ self.smi_ted_version = smi_ted_version
44
  self.checkpoints_folder = checkpoints_folder
45
+ self.restart_filename = restart_filename
46
+ self.start_epoch = 1
47
  self.save_every_epoch = save_every_epoch
48
  self.save_ckpt = save_ckpt
49
  self.device = device
50
+ self.best_vloss = float('inf')
51
+ self.last_filename = None
52
  self._set_seed(seed)
53
 
54
  def _prepare_data(self):
 
86
  self.optimizer = optimizer
87
  self.loss_fn = loss_fn
88
  self._print_configuration()
89
+ if self.restart_filename:
90
+ self._load_checkpoint(self.restart_filename)
91
+ print('Checkpoint restored!')
92
 
93
  def fit(self, max_epochs=500):
94
+ for epoch in range(self.start_epoch, max_epochs+1):
 
 
95
  print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')
96
 
97
  # training
 
106
  print(f"[VALID] Evaluation {m.upper()}: {round(val_metrics[m], 4)}")
107
 
108
  ############################### Save Finetune checkpoint #######################################
109
+ if ((val_loss < self.best_vloss) or self.save_every_epoch) and self.save_ckpt:
110
  # remove old checkpoint
111
+ if (self.last_filename != None) and (not self.save_every_epoch):
112
  os.remove(os.path.join(self.checkpoints_folder, self.last_filename))
113
 
114
  # filename
115
  model_name = f'{str(self.model)}-Finetune'
116
+ self.last_filename = f"{model_name}_seed{self.seed}_{self.dataset_name}_epoch={epoch}_valloss={round(val_loss, 4)}.pt"
117
+
118
+ # update best loss
119
+ self.best_vloss = val_loss
120
 
121
  # save checkpoint
122
  print('Saving checkpoint...')
123
  self._save_checkpoint(epoch, self.last_filename)
124
 
125
+ def evaluate(self, verbose=True):
126
+ if verbose:
127
+ print("\n=====Test Evaluation=====")
128
+
129
+ if self.smi_ted_version == 'v1':
130
+ import smi_ted_light.load as load
131
+ elif self.smi_ted_version == 'v2':
132
+ import smi_ted_large.load as load
133
+ else:
134
+ raise Exception('Please, specify the SMI-TED version: `v1` or `v2`.')
135
+
136
+ # copy vocabulary to checkpoint folder
137
+ if not os.path.exists(os.path.join(self.checkpoints_folder, 'bert_vocab_curated.txt')):
138
+ smi_ted_path = os.path.dirname(load.__file__)
139
+ shutil.copy(os.path.join(smi_ted_path, 'bert_vocab_curated.txt'), self.checkpoints_folder)
140
+
141
+ # load model for inference
142
+ model_inf = load.load_smi_ted(
143
+ folder=self.checkpoints_folder,
144
+ ckpt_filename=self.last_filename,
145
+ eval=True,
146
+ ).to(self.device)
147
+
148
+ # set model evaluation mode
149
+ model_inf.eval()
150
+
151
+ # evaluate on test set
152
+ tst_preds, tst_loss, tst_metrics = self._validate_one_epoch(self.test_loader, model_inf)
153
+
154
+ if verbose:
155
+ # show metrics
156
+ for m in tst_metrics.keys():
157
+ print(f"[TEST] Evaluation {m.upper()}: {round(tst_metrics[m], 4)}")
158
+
159
+ # save predictions
160
+ pd.DataFrame(tst_preds).to_csv(
161
+ os.path.join(
162
+ self.checkpoints_folder,
163
+ f'{self.dataset_name}_{self.target if isinstance(self.target, str) else self.target[0]}_predict_test_seed{self.seed}.csv'),
164
+ index=False
165
+ )
166
 
167
  def _train_one_epoch(self):
168
  raise NotImplementedError
169
 
170
+ def _validate_one_epoch(self, data_loader, model=None):
171
  raise NotImplementedError
172
 
173
  def _print_configuration(self):
 
188
  ckpt_path = os.path.join(self.checkpoints_folder, filename)
189
  ckpt_dict = torch.load(ckpt_path, map_location='cpu')
190
  self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
191
+ self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1
192
+ self.best_vloss = ckpt_dict['finetune_info']['best_vloss']
193
 
194
  def _save_checkpoint(self, current_epoch, filename):
195
  if not os.path.exists(self.checkpoints_folder):
 
210
  'train_size': self.df_train.shape[0],
211
  'valid_size': self.df_valid.shape[0],
212
  'test_size': self.df_test.shape[0],
213
+ 'best_vloss': self.best_vloss,
214
  },
215
  'seed': self.seed,
216
  }
 
237
  class TrainerRegressor(Trainer):
238
 
239
  def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
240
+ target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
241
  super().__init__(raw_data, dataset_name, target, batch_size, hparams,
242
+ target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
243
 
244
  def _train_one_epoch(self):
245
  running_loss = 0.0
 
273
 
274
  return running_loss / len(self.train_loader)
275
 
276
+ def _validate_one_epoch(self, data_loader, model=None):
277
  data_targets = []
278
  data_preds = []
279
  running_loss = 0.0
280
 
281
+ model = self.model if model is None else model
282
+
283
  with torch.no_grad():
284
  for idx, data in enumerate(pbar := tqdm(data_loader)):
285
  # Every data instance is an input + label pair
 
287
  targets = targets.clone().detach().to(self.device)
288
 
289
  # Make predictions for this batch
290
+ embeddings = model.extract_embeddings(smiles).to(self.device)
291
+ predictions = model.net(embeddings).squeeze()
292
 
293
  # Compute the loss
294
  loss = self.loss_fn(predictions, targets)
 
328
  class TrainerClassifier(Trainer):
329
 
330
  def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
331
+ target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
332
  super().__init__(raw_data, dataset_name, target, batch_size, hparams,
333
+ target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
334
 
335
  def _train_one_epoch(self):
336
  running_loss = 0.0
 
364
 
365
  return running_loss / len(self.train_loader)
366
 
367
+ def _validate_one_epoch(self, data_loader, model=None):
368
  data_targets = []
369
  data_preds = []
370
  running_loss = 0.0
371
 
372
+ model = self.model if model is None else model
373
+
374
  with torch.no_grad():
375
  for idx, data in enumerate(pbar := tqdm(data_loader)):
376
  # Every data instance is an input + label pair
 
378
  targets = targets.clone().detach().to(self.device)
379
 
380
  # Make predictions for this batch
381
+ embeddings = model.extract_embeddings(smiles).to(self.device)
382
+ predictions = model.net(embeddings).squeeze()
383
 
384
  # Compute the loss
385
  loss = self.loss_fn(predictions, targets.long())
 
435
  class TrainerClassifierMultitask(Trainer):
436
 
437
  def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
438
+ target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', restart_filename=None, save_every_epoch=False, save_ckpt=True, device='cpu'):
439
  super().__init__(raw_data, dataset_name, target, batch_size, hparams,
440
+ target_metric, seed, smi_ted_version, checkpoints_folder, restart_filename, save_every_epoch, save_ckpt, device)
441
 
442
  def _prepare_data(self):
443
  # normalize dataset
 
502
 
503
  return running_loss / len(self.train_loader)
504
 
505
+ def _validate_one_epoch(self, data_loader, model=None):
506
  data_targets = []
507
  data_preds = []
508
  data_masks = []
509
  running_loss = 0.0
510
 
511
+ model = self.model if model is None else model
512
+
513
  with torch.no_grad():
514
  for idx, data in enumerate(pbar := tqdm(data_loader)):
515
  # Every data instance is an input + label pair + mask
 
517
  targets = targets.clone().detach().to(self.device)
518
 
519
  # Make predictions for this batch
520
+ embeddings = model.extract_embeddings(smiles).to(self.device)
521
+ predictions = model.net(embeddings, multitask=True).squeeze()
522
  predictions = predictions * target_masks.to(self.device)
523
 
524
  # Compute the loss
 
588
  'specificity': average_sp.item(),
589
  }
590
 
591
+ return preds.cpu().numpy(), running_loss / len(data_loader), metrics