import gradio as gr import numpy as np from PIL import Image import torch.nn as nn import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import torch from typing import Dict import functools import inspect from types import SimpleNamespace class Autoencoder(nn.Module): def __init__(self): super().__init__() # N, 1 512,512 self.encoder = nn.Sequential( # nn.Conv2d(input_channel,16,3,stride=2, padding=1), nn.Conv2d(1,2,3,stride=2, padding=1), # N, 2, 256, 256 nn.ReLU(), nn.Conv2d(2,3,3,stride=2, padding=1), # N, 3, 128, 128 nn.ReLU(), nn.Conv2d(3,4,3,stride=2, padding=1), # N, 4, 64, 64 ) self.decoder = nn.Sequential( nn.ConvTranspose2d(4,3,3,stride=2, padding=1, output_padding=1), nn.ReLU(), nn.ConvTranspose2d(3,2,3,stride=2, padding=1,output_padding=1), nn.ReLU(), nn.ConvTranspose2d(2,1,3,stride=2, padding=1,output_padding=1), nn.Tanh() ) def forward(self,x): encoded = self.encoder(x) decoded = self.decoder(encoded) return decoded def register_to_config(init): r""" Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that shouldn't be registered in the config, use the `ignore_for_config` class variable Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! """ @functools.wraps(init) def inner_init(self, *args, **kwargs): # Ignore private kwargs in the init. init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} ignore = getattr(self, "ignore_for_config", []) # Get positional arguments aligned with kwargs new_kwargs = {} signature = inspect.signature(init) parameters = { name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore } for arg, name in zip(args, parameters.keys()): new_kwargs[name] = arg # Then add all kwargs new_kwargs.update( { k: init_kwargs.get(k, default) for k, default in parameters.items() if k not in ignore and k not in new_kwargs } ) new_kwargs = {**config_init_kwargs, **new_kwargs} getattr(self, "register_to_config")(**new_kwargs) init(self, *args, **init_kwargs) return inner_init def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ def alpha_bar(time_step): return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) return torch.tensor(betas) class DDIMScheduler(): config_name = "scheduler_config.json" _deprecated_kwargs = ["predict_epsilon"] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = False, set_alpha_to_one: bool = True, steps_offset: int = 0, prediction_type: str = "epsilon", **kwargs, ): message = ( "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" " DDIMScheduler.from_pretrained(, prediction_type='epsilon')`." ) predict_epsilon = kwargs.get('predict_epsilon', None) if predict_epsilon is not None: self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = ( torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # setable values self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) def register_to_config(self, **kwargs): if self.config_name is None: raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") # Special case for `kwargs` used in deprecation warning added to schedulers # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument, # or solve in a more general way. kwargs.pop("kwargs", None) for key, value in kwargs.items(): try: setattr(self, key, value) except AttributeError as err: print(f"Can't set {key} with value {value} for {self}") raise err if not hasattr(self, "_internal_dict"): internal_dict = kwargs else: previous_dict = dict(self._internal_dict) internal_dict = {**self._internal_dict, **kwargs} print(f"Updating config from {previous_dict} to {internal_dict}") self._internal_dict = internal_dict @property def config(self): """ Returns the config of the class as a frozen dictionary Returns: `Dict[str, Any]`: Config of the class. """ return SimpleNamespace(**self._internal_dict) def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Args: sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep Returns: `torch.FloatTensor`: scaled input sample """ return sample def _get_variance(self, timestep, prev_timestep): alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) return variance def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ self.num_inference_steps = num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps += self.config.steps_offset def step( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, eta: float = 0.0, use_clipped_model_output: bool = False, generator=None, variance_noise: Optional[torch.FloatTensor] = None, return_dict: bool = True, ) -> Union[Dict, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. eta (`float`): weight of noise for added noise in diffusion step. use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would coincide with the one provided as input and `use_clipped_model_output` will have not effect. generator: random number generator. variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we can directly provide the noise for the variance itself. This is useful for methods such as CycleDiffusion. (https://arxiv.org/abs/2210.05559) return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class Returns: [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read DDIM paper in-detail understanding # Notation ( -> # - pred_noise_t -> e_theta(x_t, t) # - pred_original_sample -> f_theta(x_t, t) or x_0 # - std_dev_t -> sigma_t # - eta -> η # - pred_sample_direction -> "direction pointing to x_t" # - pred_prev_sample -> "x_t-1" # 1. get previous step value (=t-1) prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps # 2. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) elif self.config.prediction_type == "sample": pred_original_sample = model_output elif self.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output # predict V model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" ) # 4. Clip "predicted x_0" if self.config.clip_sample: pred_original_sample = torch.clamp(pred_original_sample, -1, 1) # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) variance = self._get_variance(timestep, prev_timestep) std_dev_t = eta * variance ** (0.5) if use_clipped_model_output: # the model_output is always re-derived from the clipped x_0 in Glide model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction if eta > 0: # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 device = model_output.device if variance_noise is not None and generator is not None: raise ValueError( "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" " `variance_noise` stays `None`." ) if variance_noise is None: if device.type == "mps": # randn does not work reproducibly on mps variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator) variance_noise = variance_noise.to(device) else: variance_noise = torch.randn( model_output.shape, generator=generator, device=device, dtype=model_output.dtype ) variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise prev_sample = prev_sample + variance if not return_dict: return (prev_sample,) return dict(prev_sample=prev_sample, pred_original_sample=pred_original_sample) def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples def get_velocity( self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as sample self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) timesteps = timesteps.to(sample.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(sample.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity def __len__(self): return self.config.num_train_timesteps def dummy_model(img): img_array = np.array(img) return img_array iface = gr.Interface( fn=dummy_model, inputs="image", outputs="image" ) iface.launch()