# Pip packages ------------------------------------------------------
import importlib
import os
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch import optim
from torch.utils.data import DataLoader

# From local package ------------------------------------------------
from disvae.models.losses import get_loss_f
from disvae.models.vae import init_specific_model
from disvae.training import Trainer
from disvae.utils.modelIO import save_model


# Loss-log parsing helpers ------------------------------------------
def parse_losses(p_model, filename="train_losses.log"):
    """Read the Trainer's loss log and pivot it into one column per loss name."""
    df = pd.read_csv(Path(p_model) / filename)
    losses = df["Loss"].unique()
    rtn = [np.array(df[df["Loss"] == loss_name]["Value"]) for loss_name in losses]
    rtn = pd.DataFrame(np.array(rtn).T, columns=losses)
    return rtn


def get_kl_loss_latent(df):
    """Map latent dimension index -> last logged KL value, sorted descending.

    `df` must already be parsed (see parse_losses)!
    """
    rtn = {int(c.split("_")[-1]): df[c].iloc[-1] for c in df if "kl_loss_" in c}
    rtn = dict(sorted(rtn.items(), key=lambda item: item[1], reverse=True))
    return rtn


def get_kl_dict(p_model):
    """Convenience wrapper: parse the loss log and rank latents by final KL."""
    df = parse_losses(p_model)
    rtn = get_kl_loss_latent(df)
    return rtn


# Dataloader convenience stuff
# def get_dataloader(dataset: torch.utils.data.Dataset, batch_size, num_workers):
#     # This function is fairly involved; doing it in the notebook is quicker.
#     # These pieces are also needed for visualizing the dataset:
#     # p_dataset_module, dataset_class, dataset_args
#     # Import the module
#     # if p_dataset_module not in sys.path:
#     #     sys.path.append(str(Path(p_dataset_module).parent))
#     # Dataset = getattr(
#     #     importlib.import_module(Path(p_dataset_module).stem), dataset_class
#     # )
#     # # From here on, proceed as if it had been imported normally
#     # ds = Dataset(**dataset_args)
#
#     return loader
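

# Minimal runnable sketch of the dynamic import described in the comments
# above, assuming the dataset class lives in a standalone .py file. The
# function name is illustrative and not part of disvae; note that the
# parent *directory* must be on sys.path, not the module file itself.
def load_dataset_from_file(p_dataset_module, dataset_class, dataset_args):
    module_dir = str(Path(p_dataset_module).parent)
    if module_dir not in sys.path:
        sys.path.append(module_dir)
    # Import by module stem, then fetch the dataset class by name.
    module = importlib.import_module(Path(p_dataset_module).stem)
    Dataset = getattr(module, dataset_class)
    return Dataset(**dataset_args)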


def get_export_dir(base_dir: str, folder_name):
    """Create and return `base_dir/folder_name`; default to a timestamped name."""
    if folder_name is None:
        # ISO timestamp without microseconds; ':' is not filesystem-safe.
        folder_name = "Model_" + (
            datetime.now().replace(microsecond=0).isoformat()
        ).replace(":", "-")
    rtn = Path(base_dir) / folder_name
    if not rtn.exists():
        os.makedirs(rtn)
    else:
        raise ValueError("Output directory already exists.")
    return rtn


def train_model(model, data_loader, loss_f, device, lr, epochs, export_dir):
    trainer = Trainer(
        model,
        optim.Adam(model.parameters(), lr=lr),
        loss_f,
        device=device,
        # logger=logger,
        save_dir=export_dir,
        is_progress_bar=True,
        # gif_visualizer=gif_visualizer,
    )
    trainer(data_loader, epochs=epochs, checkpoint_every=10)
    # Final save; the Trainer already checkpoints during training.
    save_model(trainer.model, export_dir)  # , metadata=config)
    # gif_visualizer = GifTraversalsTraining(model, args.dataset, exp_dir)


def train(dataset, config) -> str:
    # TODO: validate config?
    print("1) Set device")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device:\t\t {device}")
    print("2) Get dataloader")
    dataloader = DataLoader(
        dataset,
        batch_size=config["data_params"]["batch_size"],
        shuffle=True,
        pin_memory=torch.cuda.is_available(),  # pin host memory only when a GPU is present
        num_workers=config["data_params"]["num_workers"],
    )
print("3) Build model") | |
img_size = list(dataloader.dataset[0][0].shape) | |
print(f"Image size: \t {img_size}") | |
model = init_specific_model(img_size=img_size, **config["model_params"]) | |
model = model.to(device) # make sure trainer and viz on same device | |
print("4) Build loss function") | |
loss_f = get_loss_f( | |
n_data=len(dataloader.dataset), device=device, **config["loss_params"] | |
) | |
print("5) Parse Export Params") | |
export_dir = get_export_dir(**config["export_params"]) | |
print("6) Training model") | |
train_model( | |
model=model, | |
data_loader=dataloader, | |
loss_f=loss_f, | |
device=device, | |
export_dir=export_dir, | |
**config["trainer_params"], | |
) | |
return export_dir | |
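

# Example config sketch (illustrative values; only the top-level keys and the
# "data_params"/"export_params"/"trainer_params" entries are fixed by the
# functions above -- the contents of "model_params" and "loss_params" must
# match disvae's init_specific_model and get_loss_f signatures):
#
#     config = {
#         "data_params": {"batch_size": 64, "num_workers": 4},
#         "model_params": {...},  # forwarded to init_specific_model
#         "loss_params": {...},   # forwarded to get_loss_f
#         "export_params": {"base_dir": "results", "folder_name": None},
#         "trainer_params": {"lr": 5e-4, "epochs": 100},
#     }
#     export_dir = train(dataset, config)  # `dataset` is any torch Dataset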