Victor Shirasuna committed
Commit db22044 · 1 Parent(s): 2b992bc

Added changes

smi-ted/finetune/finetune_classification.py CHANGED
@@ -28,7 +28,7 @@ def main(config):
     elif config.smi_ted_version == 'v2':
         from smi_ted_large.load import load_smi_ted
 
-    model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=config.n_output)
+    model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=config.n_output, eval=False)
     model.net.apply(model._init_weights)
     print(model.net)
 
@@ -46,6 +46,7 @@ def main(config):
         hparams=config,
         target_metric=config.target_metric,
         seed=config.start_seed,
+        smi_ted_version=config.smi_ted_version,
         checkpoints_folder=config.checkpoints_folder,
         device=device,
         save_every_epoch=bool(config.save_every_epoch),
smi-ted/finetune/finetune_classification_multitask.py CHANGED
@@ -60,7 +60,7 @@ def main(config):
     elif config.smi_ted_version == 'v2':
         from smi_ted_large.load import load_smi_ted
 
-    model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=len(targets))
+    model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=len(targets), eval=False)
     model.net.apply(model._init_weights)
     print(model.net)
 
@@ -78,6 +78,7 @@ def main(config):
         hparams=config,
         target_metric=config.target_metric,
         seed=config.start_seed,
+        smi_ted_version=config.smi_ted_version,
         checkpoints_folder=config.checkpoints_folder,
         device=device,
         save_every_epoch=bool(config.save_every_epoch),
smi-ted/finetune/finetune_regression.py CHANGED
@@ -28,7 +28,7 @@ def main(config):
     elif config.smi_ted_version == 'v2':
         from smi_ted_large.load import load_smi_ted
 
-    model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=config.n_output)
+    model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=config.n_output, eval=False)
     model.net.apply(model._init_weights)
     print(model.net)
 
@@ -48,6 +48,7 @@ def main(config):
         hparams=config,
         target_metric=config.target_metric,
         seed=config.start_seed,
+        smi_ted_version=config.smi_ted_version,
         checkpoints_folder=config.checkpoints_folder,
         device=device,
         save_every_epoch=bool(config.save_every_epoch),
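In all three finetune entry points the change is the same: the backbone is loaded with `eval=False`, keeping the encoder's random-feature map stochastic during finetuning, and the trainer now receives `smi_ted_version` so it can reload the matching package at evaluation time. A minimal sketch of the resulting wiring, with placeholder values that are not taken from the commit:

    # Hypothetical values for illustration only.
    from smi_ted_light.load import load_smi_ted   # smi_ted_large.load for version 'v2'

    model = load_smi_ted(
        folder='./smi_ted_light',                 # config.model_path
        ckpt_filename='smi-ted-Light_40.pt',      # config.ckpt_filename
        n_output=1,                               # config.n_output (len(targets) in the multitask script)
        eval=False,                               # finetuning: keep the random features non-deterministic
    )
    model.net.apply(model._init_weights)
    # The trainer is then constructed with smi_ted_version=config.smi_ted_version in addition to the
    # existing arguments, so Trainer.evaluate() knows which load module to re-import later.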
smi-ted/finetune/smi_ted_large/load.py CHANGED
@@ -318,7 +318,7 @@ class Net(nn.Module):
 
 class MoLEncoder(nn.Module):
 
-    def __init__(self, config, n_vocab):
+    def __init__(self, config, n_vocab, eval=False):
         super(MoLEncoder, self).__init__()
 
         # embeddings
@@ -337,7 +337,7 @@ class MoLEncoder(nn.Module):
             # unless we do deterministic_eval here, we will have random outputs
             feature_map=partial(GeneralizedRandomFeatures,
                                 n_dims=config['num_feats'],
-                                deterministic_eval=False),
+                                deterministic_eval=eval),
             activation='gelu'
         )
         self.blocks = builder.get()
@@ -361,7 +361,7 @@ class MoLDecoder(nn.Module):
 class Smi_ted(nn.Module):
     """materials.smi-ted-Large 738M Parameters"""
 
-    def __init__(self, tokenizer, config=None):
+    def __init__(self, tokenizer, config=None, eval=False):
         super(Smi_ted, self).__init__()
 
         # configuration
@@ -373,11 +373,11 @@ class Smi_ted(nn.Module):
 
         # instantiate modules
         if self.config:
-            self.encoder = MoLEncoder(self.config, self.n_vocab)
+            self.encoder = MoLEncoder(self.config, self.n_vocab, eval=eval)
             self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd'])
             self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['dropout'])
 
-    def load_checkpoint(self, ckpt_path, n_output):
+    def load_checkpoint(self, ckpt_path, n_output, eval=False):
         # load checkpoint file
         checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
 
@@ -388,7 +388,7 @@ class Smi_ted(nn.Module):
         self._set_seed(self.config['seed'])
 
         # instantiate modules
-        self.encoder = MoLEncoder(self.config, self.n_vocab)
+        self.encoder = MoLEncoder(self.config, self.n_vocab, eval=eval)
         self.decoder = MoLDecoder(self.n_vocab, self.max_len, self.n_embd)
         self.net = Net(self.n_embd, n_output=self.config['n_output'] if 'n_output' in self.config else n_output, dropout=self.config['dropout'])
 
@@ -493,11 +493,12 @@ class Smi_ted(nn.Module):
 def load_smi_ted(folder="./smi_ted_large",
                  ckpt_filename="smi-ted-Large_30.pt",
                  vocab_filename="bert_vocab_curated.txt",
-                 n_output=1
+                 n_output=1,
+                 eval=False
                  ):
     tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename))
     model = Smi_ted(tokenizer)
-    model.load_checkpoint(os.path.join(folder, ckpt_filename), n_output)
+    model.load_checkpoint(os.path.join(folder, ckpt_filename), n_output, eval=eval)
     print('Vocab size:', len(tokenizer.vocab))
    print(f'[FINETUNE MODE - {str(model)}]')
     return model
smi-ted/finetune/smi_ted_light/load.py CHANGED
@@ -318,7 +318,7 @@ class Net(nn.Module):
 
 class MoLEncoder(nn.Module):
 
-    def __init__(self, config, n_vocab):
+    def __init__(self, config, n_vocab, eval=False):
         super(MoLEncoder, self).__init__()
 
         # embeddings
@@ -337,7 +337,7 @@ class MoLEncoder(nn.Module):
             # unless we do deterministic_eval here, we will have random outputs
             feature_map=partial(GeneralizedRandomFeatures,
                                 n_dims=config['num_feats'],
-                                deterministic_eval=False),
+                                deterministic_eval=eval),
             activation='gelu'
         )
         self.blocks = builder.get()
@@ -361,7 +361,7 @@ class MoLDecoder(nn.Module):
 class Smi_ted(nn.Module):
     """materials.smi-ted-Light 289M Parameters"""
 
-    def __init__(self, tokenizer, config=None):
+    def __init__(self, tokenizer, config=None, eval=False):
         super(Smi_ted, self).__init__()
 
         # configuration
@@ -373,11 +373,11 @@ class Smi_ted(nn.Module):
 
         # instantiate modules
         if self.config:
-            self.encoder = MoLEncoder(self.config, self.n_vocab)
+            self.encoder = MoLEncoder(self.config, self.n_vocab, eval=eval)
             self.decoder = MoLDecoder(self.n_vocab, self.config['max_len'], self.config['n_embd'])
             self.net = Net(self.config['n_embd'], n_output=self.config['n_output'], dropout=self.config['dropout'])
 
-    def load_checkpoint(self, ckpt_path, n_output):
+    def load_checkpoint(self, ckpt_path, n_output, eval=False):
         # load checkpoint file
         checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
 
@@ -388,7 +388,7 @@ class Smi_ted(nn.Module):
         self._set_seed(self.config['seed'])
 
         # instantiate modules
-        self.encoder = MoLEncoder(self.config, self.n_vocab)
+        self.encoder = MoLEncoder(self.config, self.n_vocab, eval=eval)
         self.decoder = MoLDecoder(self.n_vocab, self.max_len, self.n_embd)
         self.net = Net(self.n_embd, n_output=self.config['n_output'] if 'n_output' in self.config else n_output, dropout=self.config['dropout'])
 
@@ -493,11 +493,12 @@ class Smi_ted(nn.Module):
 def load_smi_ted(folder="./smi_ted_light",
                  ckpt_filename="smi-ted-Light_40.pt",
                  vocab_filename="bert_vocab_curated.txt",
-                 n_output=1
+                 n_output=1,
+                 eval=False
                  ):
     tokenizer = MolTranBertTokenizer(os.path.join(folder, vocab_filename))
     model = Smi_ted(tokenizer)
-    model.load_checkpoint(os.path.join(folder, ckpt_filename), n_output)
+    model.load_checkpoint(os.path.join(folder, ckpt_filename), n_output, eval=eval)
     print('Vocab size:', len(tokenizer.vocab))
     print(f'[FINETUNE MODE - {str(model)}]')
     return model
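In both load.py variants the new `eval` argument threads from `load_smi_ted` through `load_checkpoint` and `Smi_ted.__init__` into `MoLEncoder`, where it sets `deterministic_eval` on the `GeneralizedRandomFeatures` feature map; with `eval=True` the formerly random encoder outputs become deterministic at inference time. A minimal loading sketch under that assumption (folder and checkpoint name are placeholders, and the `extract_embeddings` call form is assumed from how trainers.py uses it):

    import torch
    from smi_ted_light.load import load_smi_ted

    # eval=True -> deterministic_eval=True in the random-feature map, so repeated
    # embedding extraction of the same SMILES should give identical results.
    model = load_smi_ted(
        folder='./checkpoints',             # placeholder: folder holding the checkpoint and vocab file
        ckpt_filename='my_finetuned.pt',    # placeholder checkpoint name
        eval=True,
    )
    model.eval()

    with torch.no_grad():
        embeddings = model.extract_embeddings(['CCO'])  # argument form assumed, mirroring trainers.py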
smi-ted/finetune/trainers.py CHANGED
@@ -14,6 +14,7 @@ import numpy as np
 import random
 import args
 import os
+import shutil
 from tqdm import tqdm
 
 # Machine Learning
@@ -25,7 +26,7 @@ from utils import RMSE, sensitivity, specificity
 class Trainer:
 
     def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
-                 target_metric='rmse', seed=0, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
+                 target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
         # data
         self.df_train = raw_data[0]
         self.df_valid = raw_data[1]
@@ -39,6 +40,7 @@ class Trainer:
         # config
         self.target_metric = target_metric
         self.seed = seed
+        self.smi_ted_version = smi_ted_version
         self.checkpoints_folder = checkpoints_folder
         self.save_every_epoch = save_every_epoch
         self.save_ckpt = save_ckpt
@@ -115,28 +117,52 @@ class Trainer:
         # update best loss
         best_vloss = val_loss
 
-    def evaluate(self):
-        print("\n=====Test Evaluation=====")
-        self._load_checkpoint(self.last_filename)
-        self.model.eval()
-        tst_preds, tst_loss, tst_metrics = self._validate_one_epoch(self.test_loader)
-
-        # show metrics
-        for m in tst_metrics.keys():
-            print(f"[TEST] Evaluation {m.upper()}: {round(tst_metrics[m], 4)}")
-
-        # save predictions
-        pd.DataFrame(tst_preds).to_csv(
-            os.path.join(
-                self.checkpoints_folder,
-                f'{self.dataset_name}_{self.target if isinstance(self.target, str) else self.target[0]}_predict_test_seed{self.seed}.csv'),
-            index=False
-        )
+    def evaluate(self, verbose=True):
+        if verbose:
+            print("\n=====Test Evaluation=====")
+
+        if self.smi_ted_version == 'v1':
+            import smi_ted_light.load as load
+        elif self.smi_ted_version == 'v2':
+            import smi_ted_large.load as load
+        else:
+            raise Exception('Please, specify the SMI-TED version: `v1` or `v2`.')
+
+        # copy vocabulary to checkpoint folder
+        if not os.path.exists(os.path.join(self.checkpoints_folder, 'bert_vocab_curated.txt')):
+            smi_ted_path = os.path.dirname(load.__file__)
+            shutil.copy(os.path.join(smi_ted_path, 'bert_vocab_curated.txt'), self.checkpoints_folder)
+
+        # load model for inference
+        model_inf = load.load_smi_ted(
+            folder=self.checkpoints_folder,
+            ckpt_filename=self.last_filename,
+            eval=True,
+        ).to(self.device)
+
+        # set model evaluation mode
+        model_inf.eval()
+
+        # evaluate on test set
+        tst_preds, tst_loss, tst_metrics = self._validate_one_epoch(self.test_loader, model_inf)
+
+        if verbose:
+            # show metrics
+            for m in tst_metrics.keys():
+                print(f"[TEST] Evaluation {m.upper()}: {round(tst_metrics[m], 4)}")
+
+        # save predictions
+        pd.DataFrame(tst_preds).to_csv(
+            os.path.join(
+                self.checkpoints_folder,
+                f'{self.dataset_name}_{self.target if isinstance(self.target, str) else self.target[0]}_predict_test_seed{self.seed}.csv'),
+            index=False
+        )
 
     def _train_one_epoch(self):
         raise NotImplementedError
 
-    def _validate_one_epoch(self, data_loader):
+    def _validate_one_epoch(self, data_loader, model=None):
         raise NotImplementedError
 
     def _print_configuration(self):
@@ -203,9 +229,9 @@ class Trainer:
 class TrainerRegressor(Trainer):
 
     def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
-                 target_metric='rmse', seed=0, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
+                 target_metric='rmse', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
         super().__init__(raw_data, dataset_name, target, batch_size, hparams,
-                         target_metric, seed, checkpoints_folder, save_every_epoch, save_ckpt, device)
+                         target_metric, seed, smi_ted_version, checkpoints_folder, save_every_epoch, save_ckpt, device)
 
     def _train_one_epoch(self):
         running_loss = 0.0
@@ -239,11 +265,13 @@ class TrainerRegressor(Trainer):
 
         return running_loss / len(self.train_loader)
 
-    def _validate_one_epoch(self, data_loader):
+    def _validate_one_epoch(self, data_loader, model=None):
         data_targets = []
         data_preds = []
         running_loss = 0.0
 
+        model = self.model if model is None else model
+
         with torch.no_grad():
             for idx, data in enumerate(pbar := tqdm(data_loader)):
                 # Every data instance is an input + label pair
@@ -251,8 +279,8 @@ class TrainerRegressor(Trainer):
                 targets = targets.clone().detach().to(self.device)
 
                 # Make predictions for this batch
-                embeddings = self.model.extract_embeddings(smiles).to(self.device)
-                predictions = self.model.net(embeddings).squeeze()
+                embeddings = model.extract_embeddings(smiles).to(self.device)
+                predictions = model.net(embeddings).squeeze()
 
                 # Compute the loss
                 loss = self.loss_fn(predictions, targets)
@@ -292,9 +320,9 @@ class TrainerRegressor(Trainer):
 class TrainerClassifier(Trainer):
 
     def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
-                 target_metric='roc-auc', seed=0, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
+                 target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
         super().__init__(raw_data, dataset_name, target, batch_size, hparams,
-                         target_metric, seed, checkpoints_folder, save_every_epoch, save_ckpt, device)
+                         target_metric, seed, smi_ted_version, checkpoints_folder, save_every_epoch, save_ckpt, device)
 
     def _train_one_epoch(self):
         running_loss = 0.0
@@ -328,11 +356,13 @@ class TrainerClassifier(Trainer):
 
         return running_loss / len(self.train_loader)
 
-    def _validate_one_epoch(self, data_loader):
+    def _validate_one_epoch(self, data_loader, model=None):
         data_targets = []
         data_preds = []
         running_loss = 0.0
 
+        model = self.model if model is None else model
+
         with torch.no_grad():
             for idx, data in enumerate(pbar := tqdm(data_loader)):
                 # Every data instance is an input + label pair
@@ -340,8 +370,8 @@ class TrainerClassifier(Trainer):
                 targets = targets.clone().detach().to(self.device)
 
                 # Make predictions for this batch
-                embeddings = self.model.extract_embeddings(smiles).to(self.device)
-                predictions = self.model.net(embeddings).squeeze()
+                embeddings = model.extract_embeddings(smiles).to(self.device)
+                predictions = model.net(embeddings).squeeze()
 
                 # Compute the loss
                 loss = self.loss_fn(predictions, targets.long())
@@ -397,9 +427,9 @@ class TrainerClassifier(Trainer):
 class TrainerClassifierMultitask(Trainer):
 
     def __init__(self, raw_data, dataset_name, target, batch_size, hparams,
-                 target_metric='roc-auc', seed=0, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
+                 target_metric='roc-auc', seed=0, smi_ted_version=None, checkpoints_folder='./checkpoints', save_every_epoch=False, save_ckpt=True, device='cpu'):
         super().__init__(raw_data, dataset_name, target, batch_size, hparams,
-                         target_metric, seed, checkpoints_folder, save_every_epoch, save_ckpt, device)
+                         target_metric, seed, smi_ted_version, checkpoints_folder, save_every_epoch, save_ckpt, device)
 
     def _prepare_data(self):
         # normalize dataset
@@ -464,12 +494,14 @@ class TrainerClassifierMultitask(Trainer):
 
         return running_loss / len(self.train_loader)
 
-    def _validate_one_epoch(self, data_loader):
+    def _validate_one_epoch(self, data_loader, model=None):
         data_targets = []
         data_preds = []
         data_masks = []
         running_loss = 0.0
 
+        model = self.model if model is None else model
+
         with torch.no_grad():
             for idx, data in enumerate(pbar := tqdm(data_loader)):
                 # Every data instance is an input + label pair + mask
@@ -477,8 +509,8 @@ class TrainerClassifierMultitask(Trainer):
                 targets = targets.clone().detach().to(self.device)
 
                 # Make predictions for this batch
-                embeddings = self.model.extract_embeddings(smiles).to(self.device)
-                predictions = self.model.net(embeddings, multitask=True).squeeze()
+                embeddings = model.extract_embeddings(smiles).to(self.device)
+                predictions = model.net(embeddings, multitask=True).squeeze()
                 predictions = predictions * target_masks.to(self.device)
 
                 # Compute the loss
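The net effect in trainers.py: `evaluate()` no longer reuses the in-memory training model. It copies the vocabulary next to the saved checkpoints, re-imports the `load` module selected by `smi_ted_version`, reloads the last checkpoint with `eval=True`, and hands that inference model to `_validate_one_epoch`, whose new `model` argument falls back to `self.model` for the validation passes run during training. A hedged usage sketch (assumes a trainer built with `smi_ted_version` set and a finished training run, so `last_filename` points at a saved checkpoint):

    # `trainer` is any TrainerRegressor / TrainerClassifier / TrainerClassifierMultitask instance.
    trainer.evaluate()               # reloads the checkpoint via load_smi_ted(..., eval=True) and prints test metrics
    trainer.evaluate(verbose=False)  # same evaluation without the printed header and metrics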