Geneformer / geneformer /mtl /optuna_utils.py
ctheodoris's picture
update with 12L and 20L i4096 gc95M models, multitask and quantiz code
933ca80
raw
history blame
822 Bytes
import optuna
from optuna.integration import TensorBoardCallback
def save_trial_callback(study, trial, trials_result_path):
    """Append a one-line summary of a finished Optuna trial to a text file.

    Meant to be bound to ``trials_result_path`` and registered as a
    ``study.optimize`` callback, which Optuna invokes as ``(study, trial)``.

    Args:
        study: The study the trial belongs to. Unused; accepted only so the
            signature matches Optuna's callback convention.
        trial: The trial to record; its ``number``, ``value`` and ``params``
            attributes are written out.
        trials_result_path: Path of the log file. It is opened in append
            mode, so it is created if missing and earlier lines are kept.
    """
    line = (
        f"Trial {trial.number}: Value (F1 Macro): {trial.value}, "
        f"Params: {trial.params}\n"
    )
    # Pin the encoding: the text-mode default is platform-dependent
    # (e.g. cp1252 on Windows), which can fail on or garble non-ASCII
    # parameter values and makes the log unreadable across machines.
    with open(trials_result_path, "a", encoding="utf-8") as f:
        f.write(line)
def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
    """Create a maximizing Optuna study, run it, and return it.

    After each trial finishes, two callbacks run. The first appends a summary
    line to ``trials_result_path`` via ``save_trial_callback``. The second
    logs the trial to TensorBoard under ``tensorboard_log_dir`` with the
    metric name "F1 Macro".

    Args:
        objective: Callable passed straight to ``study.optimize``; the study
            maximizes its return value.
        n_trials: Number of trials to run.
        trials_result_path: Text file that receives one summary line per trial.
        tensorboard_log_dir: Directory handed to ``TensorBoardCallback``.

    Returns:
        The study object produced by ``optuna.create_study``, after
        ``optimize`` has completed.
    """
    def log_trial_to_file(current_study, finished_trial):
        # Adapt Optuna's (study, trial) callback signature to the file logger,
        # closing over the destination path.
        save_trial_callback(current_study, finished_trial, trials_result_path)

    study = optuna.create_study(direction="maximize")
    study.optimize(
        objective,
        n_trials=n_trials,
        callbacks=[
            log_trial_to_file,
            TensorBoardCallback(dirname=tensorboard_log_dir, metric_name="F1 Macro"),
        ],
    )
    return study