materials.smi-ted / smi-ted /finetune /finetune_regression.py
eduardosoares99's picture
Update finetune code (#8)
30e2e3e verified
raw
history blame
2.1 kB
# Deep learning
import torch
import torch.nn as nn
from torch import optim
from trainers import TrainerRegressor
from utils import RMSELoss, get_optim_groups
# Data
import pandas as pd
import numpy as np
# Standard library
import args
import os
def main(config):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# load dataset
df_train = pd.read_csv(f"{config.data_root}/train.csv")
df_valid = pd.read_csv(f"{config.data_root}/valid.csv")
df_test = pd.read_csv(f"{config.data_root}/test.csv")
# load model
if config.smi_ted_version == 'v1':
from smi_ted_light.load import load_smi_ted
elif config.smi_ted_version == 'v2':
from smi_ted_large.load import load_smi_ted
model = load_smi_ted(folder=config.model_path, ckpt_filename=config.ckpt_filename, n_output=config.n_output, eval=False)
model.net.apply(model._init_weights)
print(model.net)
lr = config.lr_start*config.lr_multiplier
optim_groups = get_optim_groups(model, keep_decoder=bool(config.train_decoder))
if config.loss_fn == 'rmse':
loss_function = RMSELoss()
elif config.loss_fn == 'mae':
loss_function = nn.L1Loss()
# init trainer
trainer = TrainerRegressor(
raw_data=(df_train, df_valid, df_test),
dataset_name=config.dataset_name,
target=config.measure_name,
batch_size=config.n_batch,
hparams=config,
target_metric=config.target_metric,
seed=config.start_seed,
smi_ted_version=config.smi_ted_version,
checkpoints_folder=config.checkpoints_folder,
restart_filename=config.restart_filename,
device=device,
save_every_epoch=bool(config.save_every_epoch),
save_ckpt=bool(config.save_ckpt)
)
trainer.compile(
model=model,
optimizer=optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.99)),
loss_fn=loss_function
)
trainer.fit(max_epochs=config.max_epochs)
trainer.evaluate()
if __name__ == '__main__':
parser = args.get_parser()
config = parser.parse_args()
main(config)