Victor Shirasuna
commited on
Commit
·
6f04789
1
Parent(s):
db22044
Fix multitask finetune
Browse files
smi-ted/finetune/finetune_classification_multitask.py
CHANGED
@@ -48,6 +48,7 @@ def main(config):
|
|
48 |
'MUV-692', 'MUV-712', 'MUV-713', 'MUV-733', 'MUV-737', 'MUV-810',
|
49 |
'MUV-832', 'MUV-846', 'MUV-852', 'MUV-858', 'MUV-859'
|
50 |
]
|
|
|
51 |
|
52 |
# load dataset
|
53 |
df_train = pd.read_csv(f"{config.data_root}/train.csv")
|
|
|
48 |
'MUV-692', 'MUV-712', 'MUV-713', 'MUV-733', 'MUV-737', 'MUV-810',
|
49 |
'MUV-832', 'MUV-846', 'MUV-852', 'MUV-858', 'MUV-859'
|
50 |
]
|
51 |
+
config.n_output = len(targets)
|
52 |
|
53 |
# load dataset
|
54 |
df_train = pd.read_csv(f"{config.data_root}/train.csv")
|
smi-ted/finetune/trainers.py
CHANGED
@@ -580,4 +580,4 @@ class TrainerClassifierMultitask(Trainer):
|
|
580 |
'specificity': average_sp.item(),
|
581 |
}
|
582 |
|
583 |
-
return preds, running_loss / len(data_loader), metrics
|
|
|
580 |
'specificity': average_sp.item(),
|
581 |
}
|
582 |
|
583 |
+
return preds.cpu().numpy(), running_loss / len(data_loader), metrics
|