import pytorch_lightning as pl
import torch
import torch.nn.functional as F

from salad.data.dataset import SALADDataset
from salad.utils.train_util import PolyDecayScheduler


class BaseModel(pl.LightningModule):
    def __init__(
        self,
        network,
        variance_schedule,
        **kwargs,
    ):
        super().__init__()
        # Extra keyword arguments (lr, batch_size, dataset_kwargs, ...) are
        # captured by save_hyperparameters() and read back via self.hparams.
        self.save_hyperparameters(logger=False)
        self.net = network
        self.var_sched = variance_schedule

    def forward(self, x):
        # Forward pass returns the training loss for a clean batch x.
        return self.get_loss(x)

    def step(self, x, stage: str):
        loss = self(x)
        self.log(
            f"{stage}/loss",
            loss,
            on_step=stage == "train",
            prog_bar=True,
        )
        return loss

    def training_step(self, batch, batch_idx):
        x = batch
        return self.step(x, "train")

    def add_noise(self, x, t):
        """
        Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.

        Input:
            x: [B,D] or [B,G,D]
            t: list of size B
        Output:
            x_noisy: same shape as x
            beta: [B]
            e_rand: same shape as x
        """
        alpha_bar = self.var_sched.alpha_bars[t]
        beta = self.var_sched.betas[t]

        c0 = torch.sqrt(alpha_bar).view(-1, 1)  # [B,1]
        c1 = torch.sqrt(1 - alpha_bar).view(-1, 1)

        e_rand = torch.randn_like(x)
        if e_rand.dim() == 3:
            # Broadcast coefficients over the group dimension G.
            c0 = c0.unsqueeze(1)
            c1 = c1.unsqueeze(1)

        x_noisy = c0 * x + c1 * e_rand

        return x_noisy, beta, e_rand

    def get_loss(
        self,
        x0,
        t=None,
        noisy_in=False,
        beta_in=None,
        e_rand_in=None,
    ):
        if x0.dim() == 2:
            B, D = x0.shape
        else:
            B, G, D = x0.shape
        if not noisy_in:
            # Sample timesteps and noise the clean input.
            if t is None:
                t = self.var_sched.uniform_sample_t(B)
            x_noisy, beta, e_rand = self.add_noise(x0, t)
        else:
            # Caller already provides a noised input with its beta and noise.
            x_noisy = x0
            beta = beta_in
            e_rand = e_rand_in

        # Epsilon-prediction objective: MSE between predicted and true noise.
        e_theta = self.net(x_noisy, beta=beta)
        loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")
        return loss

    @torch.no_grad()
    def sample(
        self,
        batch_size=0,
        return_traj=False,
    ):
        raise NotImplementedError

    def validation_epoch_end(self, outputs):
        if self.hparams.no_run_validation:
            return
        if not self.trainer.sanity_checking:
            # Run the model-specific validation routine every `validation_step` epochs.
            if self.current_epoch % self.hparams.validation_step == 0:
                self.validation()

    def _build_dataset(self, stage):
        if hasattr(self, f"data_{stage}"):
            return getattr(self, f"data_{stage}")
        if stage == "train":
            ds = SALADDataset(**self.hparams.dataset_kwargs)
        else:
            # Validation/test datasets are built without repetition.
            dataset_kwargs = self.hparams.dataset_kwargs.copy()
            dataset_kwargs["repeat"] = 1
            ds = SALADDataset(**dataset_kwargs)
        setattr(self, f"data_{stage}", ds)
        return ds

    def _build_dataloader(self, stage):
        try:
            ds = getattr(self, f"data_{stage}")
        except AttributeError:
            ds = self._build_dataset(stage)
        return torch.utils.data.DataLoader(
            ds,
            batch_size=self.hparams.batch_size,
            shuffle=stage == "train",
            drop_last=stage == "train",
            num_workers=4,
        )

    def train_dataloader(self):
        return self._build_dataloader("train")

    def val_dataloader(self):
        return self._build_dataloader("val")

    def test_dataloader(self):
        return self._build_dataloader("test")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = PolyDecayScheduler(optimizer, self.hparams.lr, power=0.999)
        return [optimizer], [scheduler]

    # TODO: move get_wandb_logger to logutil.py
    def get_wandb_logger(self):
        for logger in self.logger:
            if isinstance(logger, pl.loggers.wandb.WandbLogger):
                return logger
        return None
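

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module). The
# network, variance-schedule, and dataset arguments below are hypothetical
# placeholders; the actual values come from the project's training configs.
# Keyword arguments absorbed by **kwargs (lr, batch_size, dataset_kwargs,
# no_run_validation, validation_step) are stored by save_hyperparameters()
# and accessed through self.hparams elsewhere in this class.
#
#   net = MyDenoisingNetwork(...)               # hypothetical eps-prediction net
#   sched = MyVarianceSchedule(num_steps=1000)  # hypothetical beta schedule
#   model = BaseModel(
#       network=net,
#       variance_schedule=sched,
#       lr=1e-3,
#       batch_size=32,
#       dataset_kwargs={"data_path": "...", "repeat": 1},
#       no_run_validation=True,
#       validation_step=10,
#   )
#   trainer = pl.Trainer(max_epochs=200)
#   trainer.fit(model)
# ---------------------------------------------------------------------------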