test / sgmse /sdes.py
Shokoufehhh's picture
Upload 27 files
b427b58 verified
raw
history blame
11.2 kB
"""
Abstract SDE classes, Reverse SDE, and VE/VP SDEs.
Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py
"""
import abc
import warnings
import numpy as np
from sgmse.util.tensors import batch_broadcast
import torch
from sgmse.util.registry import Registry
SDERegistry = Registry("SDE")
class SDE(abc.ABC):
"""SDE abstract class. Functions are designed for a mini-batch of inputs."""
def __init__(self, N):
"""Construct an SDE.
Args:
N: number of discretization time steps.
"""
super().__init__()
self.N = N
@property
@abc.abstractmethod
def T(self):
"""End time of the SDE."""
pass
@abc.abstractmethod
def sde(self, x, t, *args):
pass
@abc.abstractmethod
def marginal_prob(self, x, t, *args):
"""Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
pass
@abc.abstractmethod
def prior_sampling(self, shape, *args):
"""Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
pass
@abc.abstractmethod
def prior_logp(self, z):
"""Compute log-density of the prior distribution.
Useful for computing the log-likelihood via probability flow ODE.
Args:
z: latent code
Returns:
log probability density
"""
pass
@staticmethod
@abc.abstractmethod
def add_argparse_args(parent_parser):
"""
Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
"""
pass
def discretize(self, x, t, y, stepsize):
"""Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
Useful for reverse diffusion sampling and probabiliy flow sampling.
Defaults to Euler-Maruyama discretization.
Args:
x: a torch tensor
t: a torch float representing the time step (from 0 to `self.T`)
Returns:
f, G
"""
dt = stepsize
drift, diffusion = self.sde(x, t, y)
f = drift * dt
G = diffusion * torch.sqrt(dt)
return f, G
def reverse(oself, score_model, probability_flow=False):
"""Create the reverse-time SDE/ODE.
Args:
score_model: A function that takes x, t and y and returns the score.
probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
"""
N = oself.N
T = oself.T
sde_fn = oself.sde
discretize_fn = oself.discretize
# Build the class for reverse-time SDE.
class RSDE(oself.__class__):
def __init__(self):
self.N = N
self.probability_flow = probability_flow
@property
def T(self):
return T
def sde(self, x, t, *args):
"""Create the drift and diffusion functions for the reverse SDE/ODE."""
rsde_parts = self.rsde_parts(x, t, *args)
total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]
return total_drift, diffusion
def rsde_parts(self, x, t, *args):
sde_drift, sde_diffusion = sde_fn(x, t, *args)
score = score_model(x, t, *args)
score_drift = -sde_diffusion[:, None, None, None]**2 * score * (0.5 if self.probability_flow else 1.)
diffusion = torch.zeros_like(sde_diffusion) if self.probability_flow else sde_diffusion
total_drift = sde_drift + score_drift
return {
'total_drift': total_drift, 'diffusion': diffusion, 'sde_drift': sde_drift,
'sde_diffusion': sde_diffusion, 'score_drift': score_drift, 'score': score,
}
def discretize(self, x, t, y, stepsize):
"""Create discretized iteration rules for the reverse diffusion sampler."""
f, G = discretize_fn(x, t, y, stepsize)
rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y) * (0.5 if self.probability_flow else 1.)
rev_G = torch.zeros_like(G) if self.probability_flow else G
return rev_f, rev_G
return RSDE()
@abc.abstractmethod
def copy(self):
pass
@SDERegistry.register("ouve")
class OUVESDE(SDE):
@staticmethod
def add_argparse_args(parser):
parser.add_argument("--sde-n", type=int, default=1000, help="The number of timesteps in the SDE discretization. 30 by default")
parser.add_argument("--theta", type=float, default=1.5, help="The constant stiffness of the Ornstein-Uhlenbeck process. 1.5 by default.")
parser.add_argument("--sigma-min", type=float, default=0.05, help="The minimum sigma to use. 0.05 by default.")
parser.add_argument("--sigma-max", type=float, default=0.5, help="The maximum sigma to use. 0.5 by default.")
return parser
def __init__(self, theta, sigma_min, sigma_max, N=1000, **ignored_kwargs):
"""Construct an Ornstein-Uhlenbeck Variance Exploding SDE.
Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
to the methods which require it (e.g., `sde` or `marginal_prob`).
dx = -theta (y-x) dt + sigma(t) dw
with
sigma(t) = sigma_min (sigma_max/sigma_min)^t * sqrt(2 log(sigma_max/sigma_min))
Args:
theta: stiffness parameter.
sigma_min: smallest sigma.
sigma_max: largest sigma.
N: number of discretization steps
"""
super().__init__(N)
self.theta = theta
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.logsig = np.log(self.sigma_max / self.sigma_min)
self.N = N
def copy(self):
return OUVESDE(self.theta, self.sigma_min, self.sigma_max, N=self.N)
@property
def T(self):
return 1
def sde(self, x, t, y):
drift = self.theta * (y - x)
# the sqrt(2*logsig) factor is required here so that logsig does not in the end affect the perturbation kernel
# standard deviation. this can be understood from solving the integral of [exp(2s) * g(s)^2] from s=0 to t
# with g(t) = sigma(t) as defined here, and seeing that `logsig` remains in the integral solution
# unless this sqrt(2*logsig) factor is included.
sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
diffusion = sigma * np.sqrt(2 * self.logsig)
return drift, diffusion
def _mean(self, x0, t, y):
theta = self.theta
exp_interp = torch.exp(-theta * t)[:, None, None, None]
return exp_interp * x0 + (1 - exp_interp) * y
def alpha(self, t):
return torch.exp(-self.theta * t)
def _std(self, t):
# This is a full solution to the ODE for P(t) in our derivations, after choosing g(s) as in self.sde()
sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig
# could maybe replace the two torch.exp(... * t) terms here by cached values **t
return torch.sqrt(
(
sigma_min**2
* torch.exp(-2 * theta * t)
* (torch.exp(2 * (theta + logsig) * t) - 1)
* logsig
)
/
(theta + logsig)
)
def marginal_prob(self, x0, t, y):
return self._mean(x0, t, y), self._std(t)
def prior_sampling(self, shape, y):
if shape != y.shape:
warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
std = self._std(torch.ones((y.shape[0],), device=y.device))
x_T = y + torch.randn_like(y) * std[:, None, None, None]
return x_T
def prior_logp(self, z):
raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
@SDERegistry.register("ouvp")
class OUVPSDE(SDE):
# !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!
@staticmethod
def add_argparse_args(parser):
parser.add_argument("--sde-n", type=int, default=1000,
help="The number of timesteps in the SDE discretization. 1000 by default")
parser.add_argument("--beta-min", type=float, required=True,
help="The minimum beta to use.")
parser.add_argument("--beta-max", type=float, required=True,
help="The maximum beta to use.")
parser.add_argument("--stiffness", type=float, default=1,
help="The stiffness factor for the drift, to be multiplied by 0.5*beta(t). 1 by default.")
return parser
def __init__(self, beta_min, beta_max, stiffness=1, N=1000, **ignored_kwargs):
"""
!!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!
Construct an Ornstein-Uhlenbeck Variance Preserving SDE:
dx = -1/2 * beta(t) * stiffness * (y-x) dt + sqrt(beta(t)) * dw
with
beta(t) = beta_min + t(beta_max - beta_min)
Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
to the methods which require it (e.g., `sde` or `marginal_prob`).
Args:
beta_min: smallest sigma.
beta_max: largest sigma.
stiffness: stiffness factor of the drift. 1 by default.
N: number of discretization steps
"""
super().__init__(N)
self.beta_min = beta_min
self.beta_max = beta_max
self.stiffness = stiffness
self.N = N
def copy(self):
return OUVPSDE(self.beta_min, self.beta_max, self.stiffness, N=self.N)
@property
def T(self):
return 1
def _beta(self, t):
return self.beta_min + t * (self.beta_max - self.beta_min)
def sde(self, x, t, y):
drift = 0.5 * self.stiffness * batch_broadcast(self._beta(t), y) * (y - x)
diffusion = torch.sqrt(self._beta(t))
return drift, diffusion
def _mean(self, x0, t, y):
b0, b1, s = self.beta_min, self.beta_max, self.stiffness
x0y_fac = torch.exp(-0.25 * s * t * (t * (b1-b0) + 2 * b0))[:, None, None, None]
return y + x0y_fac * (x0 - y)
def _std(self, t):
b0, b1, s = self.beta_min, self.beta_max, self.stiffness
return (1 - torch.exp(-0.5 * s * t * (t * (b1-b0) + 2 * b0))) / s
def marginal_prob(self, x0, t, y):
return self._mean(x0, t, y), self._std(t)
def prior_sampling(self, shape, y):
if shape != y.shape:
warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
std = self._std(torch.ones((y.shape[0],), device=y.device))
x_T = y + torch.randn_like(y) * std[:, None, None, None]
return x_T
def prior_logp(self, z):
raise NotImplementedError("prior_logp for OU SDE not yet implemented!")