Victor Shirasuna commited on
Commit
6f04789
·
1 Parent(s): db22044

Fix multitask finetune

Browse files
smi-ted/finetune/finetune_classification_multitask.py CHANGED
@@ -48,6 +48,7 @@ def main(config):
48
  'MUV-692', 'MUV-712', 'MUV-713', 'MUV-733', 'MUV-737', 'MUV-810',
49
  'MUV-832', 'MUV-846', 'MUV-852', 'MUV-858', 'MUV-859'
50
  ]
 
51
 
52
  # load dataset
53
  df_train = pd.read_csv(f"{config.data_root}/train.csv")
 
48
  'MUV-692', 'MUV-712', 'MUV-713', 'MUV-733', 'MUV-737', 'MUV-810',
49
  'MUV-832', 'MUV-846', 'MUV-852', 'MUV-858', 'MUV-859'
50
  ]
51
+ config.n_output = len(targets)
52
 
53
  # load dataset
54
  df_train = pd.read_csv(f"{config.data_root}/train.csv")
smi-ted/finetune/trainers.py CHANGED
@@ -580,4 +580,4 @@ class TrainerClassifierMultitask(Trainer):
580
  'specificity': average_sp.item(),
581
  }
582
 
583
- return preds, running_loss / len(data_loader), metrics
 
580
  'specificity': average_sp.item(),
581
  }
582
 
583
+ return preds.cpu().numpy(), running_loss / len(data_loader), metrics