import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import wandb
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingWarmRestarts
from tqdm import trange, tqdm

from cfg import get_cfg
from datasets import get_ds
from methods import get_method

# Root directory under which per-run checkpoint directories are created.
ARTIFACTS_DIR = "/mnt/store/wbandar1/projects/ssl-aug-artifacts/"
# Directory handed to wandb.init for its local log files.
WANDB_DIR = os.path.join(ARTIFACTS_DIR, "wandb_logs/")
# Number of optimizer iterations over which the lr is linearly ramped up.
WARMUP_ITERS = 500
# A checkpoint is written every CHECKPOINT_EVERY epochs.
CHECKPOINT_EVERY = 100


def get_scheduler(optimizer, cfg):
    """Build the learning-rate scheduler selected by ``cfg.lr_step``.

    Args:
        optimizer: the optimizer whose param groups the scheduler drives.
        cfg: config object; reads ``lr_step`` and, depending on it,
            ``epoch``/``T0``/``Tmult``/``eta_min`` (cosine) or
            ``epoch``/``drop``/``drop_gamma`` (step).

    Returns:
        ``CosineAnnealingWarmRestarts`` when ``cfg.lr_step == "cos"`` (with
        ``T_0`` falling back to ``cfg.epoch`` when ``cfg.T0`` is None),
        ``MultiStepLR`` with milestones ``cfg.epoch - a`` for each ``a`` in
        ``cfg.drop`` when ``cfg.lr_step == "step"``, otherwise ``None``.
    """
    if cfg.lr_step == "cos":
        return CosineAnnealingWarmRestarts(
            optimizer,
            T_0=cfg.epoch if cfg.T0 is None else cfg.T0,
            T_mult=cfg.Tmult,
            eta_min=cfg.eta_min,
        )
    elif cfg.lr_step == "step":
        # cfg.drop holds "epochs before the end" at which to decay the lr.
        m = [cfg.epoch - a for a in cfg.drop]
        return MultiStepLR(optimizer, milestones=m, gamma=cfg.drop_gamma)
    else:
        return None


def _make_run_dir(run_id):
    """Create (if missing) and return the checkpoint directory for ``run_id``."""
    run_id_dir = os.path.join(ARTIFACTS_DIR, run_id)
    if not os.path.exists(run_id_dir):
        print('Creating directory {}'.format(run_id_dir))
        # makedirs also creates missing parents, unlike the former os.mkdir.
        os.makedirs(run_id_dir, exist_ok=True)
    return run_id_dir


def main():
    """Train the configured SSL method, periodically evaluating and checkpointing."""
    cfg = get_cfg()
    wandb.init(
        project=f"ssl-sota-{cfg.method}-{cfg.dataset}",
        config=cfg,
        dir=WANDB_DIR,
    )
    run_id = wandb.run.id
    run_id_dir = _make_run_dir(run_id)

    ds = get_ds(cfg.dataset)(cfg.bs, cfg, cfg.num_workers)
    model = get_method(cfg.method)(cfg)
    model.cuda().train()
    if cfg.fname is not None:
        model.load_state_dict(torch.load(cfg.fname))

    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.adam_l2)
    scheduler = get_scheduler(optimizer, cfg)

    eval_every = cfg.eval_every
    # Count of warmup iterations done so far; starting at WARMUP_ITERS
    # means warmup is disabled.
    lr_warmup = 0 if cfg.lr_warmup else WARMUP_ITERS
    cudnn.benchmark = True

    for ep in trange(cfg.epoch, position=0):
        loss_ep = []
        iters = len(ds.train)
        for n_iter, (samples, _) in enumerate(tqdm(ds.train, position=1)):
            if lr_warmup < WARMUP_ITERS:
                # Linear warmup: lr goes cfg.lr/WARMUP_ITERS -> cfg.lr.
                lr_scale = (lr_warmup + 1) / WARMUP_ITERS
                for pg in optimizer.param_groups:
                    pg["lr"] = cfg.lr * lr_scale
                lr_warmup += 1

            optimizer.zero_grad()
            loss = model(samples)
            loss.backward()
            optimizer.step()
            loss_ep.append(loss.item())
            model.step(ep / cfg.epoch)
            # The cosine schedule is advanced per iteration with a fractional
            # epoch, but only after warmup so it does not override the ramp.
            if cfg.lr_step == "cos" and lr_warmup >= WARMUP_ITERS:
                scheduler.step(ep + n_iter / iters)

        # The multi-step schedule counts whole epochs.
        if cfg.lr_step == "step":
            scheduler.step()

        # Once the first lr drop is reached, switch to the (denser) eval period.
        if cfg.drop and ep == (cfg.epoch - cfg.drop[0]):
            eval_every = cfg.eval_every_drop

        if (ep + 1) % eval_every == 0:
            acc_knn = model.get_acc_knn(ds.clf, ds.test)
            # Not committed here: it is attached to this epoch's loss row below.
            wandb.log({"acc_knn": acc_knn}, commit=False)

        if (ep + 1) % CHECKPOINT_EVERY == 0:
            fname = os.path.join(run_id_dir, f"{cfg.method}_{cfg.dataset}_{ep}.pt")
            torch.save(model.state_dict(), fname)

        wandb.log({"loss": np.mean(loss_ep), "ep": ep})

    # NOTE(review): acc appears to be indexed by top-k (1 and 5); confirm
    # against model.get_acc.
    acc_knn, acc = model.get_acc(ds.clf, ds.test)
    print('Final linear-acc: {}, knn-acc: {}'.format(acc, acc_knn))
    # Commit explicitly so the final metrics are written as their own row
    # instead of being left as a pending, uncommitted partial row.
    wandb.log({"acc": acc[1], "acc_5": acc[5], "acc_knn": acc_knn})
    wandb.finish()


if __name__ == "__main__":
    main()