supundhananjaya committed on
Commit a053391 · verified · 1 Parent(s): 9b0fc58

Update app.py

Files changed (1)
  1. app.py +405 -0
app.py CHANGED
@@ -1,8 +1,413 @@
 import gradio as gr
 import numpy as np
 from PIL import Image
+import torch.nn as nn
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+import torch
+from typing import Dict
+import functools
+import inspect
+from types import SimpleNamespace
 
+class Autoencoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # input: N, 1, 512, 512
+
+        self.encoder = nn.Sequential(
+            # nn.Conv2d(input_channel, 16, 3, stride=2, padding=1),
+            nn.Conv2d(1, 2, 3, stride=2, padding=1),   # N, 2, 256, 256
+            nn.ReLU(),
+            nn.Conv2d(2, 3, 3, stride=2, padding=1),   # N, 3, 128, 128
+            nn.ReLU(),
+            nn.Conv2d(3, 4, 3, stride=2, padding=1),   # N, 4, 64, 64
+        )
+
+        self.decoder = nn.Sequential(
+            nn.ConvTranspose2d(4, 3, 3, stride=2, padding=1, output_padding=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(3, 2, 3, stride=2, padding=1, output_padding=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(2, 1, 3, stride=2, padding=1, output_padding=1),
+            nn.Tanh(),
+        )
+
+    def forward(self, x):
+        encoded = self.encoder(x)
+        decoded = self.decoder(encoded)
+        return decoded
+
+def register_to_config(init):
+    r"""
+    Decorator to apply to the `__init__` of classes inheriting from [`ConfigMixin`], so that all the arguments are
+    automatically sent to `self.register_to_config`. To ignore a specific argument that the `__init__` accepts but
+    that shouldn't be registered in the config, list it in the `ignore_for_config` class variable.
+
+    Warning: once decorated, all private arguments (beginning with an underscore) are discarded and not sent to the
+    `__init__`!
+    """
+
+    @functools.wraps(init)
+    def inner_init(self, *args, **kwargs):
+        # Ignore private kwargs in the init.
+        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
+
+        ignore = getattr(self, "ignore_for_config", [])
+        # Get positional arguments aligned with kwargs
+        new_kwargs = {}
+        signature = inspect.signature(init)
+        parameters = {
+            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
+        }
+        for arg, name in zip(args, parameters.keys()):
+            new_kwargs[name] = arg
+
+        # Then add all kwargs
+        new_kwargs.update(
+            {
+                k: init_kwargs.get(k, default)
+                for k, default in parameters.items()
+                if k not in ignore and k not in new_kwargs
+            }
+        )
+        new_kwargs = {**config_init_kwargs, **new_kwargs}
+        getattr(self, "register_to_config")(**new_kwargs)
+        init(self, *args, **init_kwargs)
+
+    return inner_init
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+    (1-beta) over time from t = [0,1].
+
+    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+    to that part of the diffusion process.
+
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+
+    Returns:
+        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
+    """
+
+    def alpha_bar(time_step):
+        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return torch.tensor(betas)
+
+
+class DDIMScheduler:
+    config_name = "scheduler_config.json"
+    _deprecated_kwargs = ["predict_epsilon"]
+    order = 1
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        clip_sample: bool = False,
+        set_alpha_to_one: bool = True,
+        steps_offset: int = 0,
+        prediction_type: str = "epsilon",
+        **kwargs,
+    ):
+        message = (
+            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+            " DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
+        )
+        predict_epsilon = kwargs.get("predict_epsilon", None)
+        if predict_epsilon is not None:
+            # `predict_epsilon` is deprecated; map it onto `prediction_type` and point users at the new kwarg.
+            print(message)
+            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
+
+        if trained_betas is not None:
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+        elif beta_schedule == "linear":
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = (
+                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+            )
+        elif beta_schedule == "squaredcos_cap_v2":
+            # Glide cosine schedule
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+        # At every step in ddim, we are looking into the previous alphas_cumprod
+        # For the final step, there is no previous alphas_cumprod because we are already at 0
+        # `set_alpha_to_one` decides whether we set this parameter simply to one or
+        # whether we use the final alpha of the "non-previous" one.
+        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+    def register_to_config(self, **kwargs):
+        if self.config_name is None:
+            raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+        # Special case for `kwargs` used in deprecation warning added to schedulers
+        # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
+        # or solve in a more general way.
+        kwargs.pop("kwargs", None)
+        for key, value in kwargs.items():
+            try:
+                setattr(self, key, value)
+            except AttributeError as err:
+                print(f"Can't set {key} with value {value} for {self}")
+                raise err
+
+        if not hasattr(self, "_internal_dict"):
+            internal_dict = kwargs
+        else:
+            previous_dict = dict(self._internal_dict)
+            internal_dict = {**self._internal_dict, **kwargs}
+            print(f"Updating config from {previous_dict} to {internal_dict}")
+
+        self._internal_dict = internal_dict
+
+    @property
+    def config(self):
+        """
+        Returns the config of the class as a namespace.
+
+        Returns:
+            `SimpleNamespace`: Config of the class.
+        """
+        return SimpleNamespace(**self._internal_dict)
+
+    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
+    def _get_variance(self, timestep, prev_timestep):
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+        return variance
+
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+        """
+        self.num_inference_steps = num_inference_steps
+        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+        # creates integer timesteps by multiplying by ratio
+        # casting to int to avoid issues when num_inference_steps is a power of 3
+        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+        self.timesteps = torch.from_numpy(timesteps).to(device)
+        self.timesteps += self.config.steps_offset
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: int,
+        sample: torch.FloatTensor,
+        eta: float = 0.0,
+        use_clipped_model_output: bool = False,
+        generator=None,
+        variance_noise: Optional[torch.FloatTensor] = None,
+        return_dict: bool = True,
+    ) -> Union[Dict, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                current instance of sample being created by diffusion process.
+            eta (`float`): weight of noise for added noise in diffusion step.
+            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
+                predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
+                `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
+                coincide with the one provided as input and `use_clipped_model_output` will have no effect.
+            generator: random number generator.
+            variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
+                can directly provide the noise for the variance itself. This is useful for methods such as
+                CycleDiffusion (https://arxiv.org/abs/2210.05559).
+            return_dict (`bool`): option for returning a dict rather than a tuple.
+
+        Returns:
+            `dict` or `tuple`: a dict with the keys `prev_sample` and `pred_original_sample` if `return_dict` is
+            True; otherwise a tuple whose first element is the sample tensor.
+        """
+        if self.num_inference_steps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+            )
+
+        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+        # Ideally, read the DDIM paper for an in-detail understanding
+
+        # Notation (<variable name> -> <name in paper>)
+        # - pred_noise_t -> e_theta(x_t, t)
+        # - pred_original_sample -> f_theta(x_t, t) or x_0
+        # - std_dev_t -> sigma_t
+        # - eta -> η
+        # - pred_sample_direction -> "direction pointing to x_t"
+        # - pred_prev_sample -> "x_t-1"
+
+        # 1. get previous step value (=t-1)
+        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+        # 2. compute alphas, betas
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+        beta_prod_t = 1 - alpha_prod_t
+
+        # 3. compute predicted original sample from predicted noise, also called
+        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        if self.config.prediction_type == "epsilon":
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        elif self.config.prediction_type == "sample":
+            pred_original_sample = model_output
+        elif self.config.prediction_type == "v_prediction":
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+            # predict V
+            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+                " `v_prediction`"
+            )
+
+        # 4. Clip "predicted x_0"
+        if self.config.clip_sample:
+            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+        # 5. compute variance: "sigma_t(η)" -> see formula (16)
+        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+        variance = self._get_variance(timestep, prev_timestep)
+        std_dev_t = eta * variance ** (0.5)
+
+        if use_clipped_model_output:
+            # the model_output is always re-derived from the clipped x_0 in Glide
+            model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
+
+        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+        if eta > 0:
+            # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
+            device = model_output.device
+            if variance_noise is not None and generator is not None:
+                raise ValueError(
+                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
+                    " `variance_noise` stays `None`."
+                )
+
+            if variance_noise is None:
+                if device.type == "mps":
+                    # randn does not work reproducibly on mps
+                    variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
+                    variance_noise = variance_noise.to(device)
+                else:
+                    variance_noise = torch.randn(
+                        model_output.shape, generator=generator, device=device, dtype=model_output.dtype
+                    )
+            variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
+
+            prev_sample = prev_sample + variance
+
+        if not return_dict:
+            return (prev_sample,)
+
+        return dict(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+    def add_noise(
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.IntTensor,
+    ) -> torch.FloatTensor:
+        # Make sure alphas_cumprod and timesteps have the same device and dtype as original_samples
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        timesteps = timesteps.to(original_samples.device)
+
+        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+        return noisy_samples
+
+    def get_velocity(
+        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+    ) -> torch.FloatTensor:
+        # Make sure alphas_cumprod and timesteps have the same device and dtype as sample
+        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        timesteps = timesteps.to(sample.device)
+
+        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(sample.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+        return velocity
+
+    def __len__(self):
+        return self.config.num_train_timesteps
+
 
 def dummy_model(img):
     img_array = np.array(img)
     return img_array
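
As a quick orientation for the Autoencoder added above: each Conv2d halves the spatial resolution (512 → 256 → 128 → 64) and each ConvTranspose2d doubles it back, so a single-channel 512×512 input round-trips to its own shape. A minimal shape check, assuming the class is importable from app.py (the batch size here is illustrative):

    import torch

    model = Autoencoder()
    x = torch.randn(8, 1, 512, 512)    # N, 1, 512, 512
    with torch.no_grad():
        latent = model.encoder(x)      # N, 4, 64, 64 bottleneck
        recon = model(x)               # N, 1, 512, 512, squashed to [-1, 1] by Tanh
    print(latent.shape, recon.shape)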
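
betas_for_alpha_bar discretizes the Glide cosine schedule: with alpha_bar(t) = cos²(((t + 0.008) / 1.008) · π/2), each beta_i = 1 − alpha_bar(t_{i+1}) / alpha_bar(t_i), so the cumulative product of (1 − beta) telescopes to alpha_bar(t_n) / alpha_bar(0), up to the max_beta clip near t = 1. A sanity check under that reading (the step count is arbitrary):

    import math
    import torch

    betas = betas_for_alpha_bar(1000)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    def alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    n = 499  # away from t = 1, where the max_beta clip kicks in
    print(alphas_cumprod[n].item())                  # product of (1 - beta_i) up to n
    print(alpha_bar((n + 1) / 1000) / alpha_bar(0))  # telescoped closed form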
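
At inference time the scheduler is driven by set_timesteps plus a step loop; with the default eta = 0.0 the update is deterministic DDIM. A sketch of that loop, where denoiser is only a placeholder for whatever epsilon-prediction network the app actually loads:

    import torch

    scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear")
    scheduler.set_timesteps(50)            # sample with 50 of the 1000 training steps

    sample = torch.randn(1, 1, 512, 512) * scheduler.init_noise_sigma

    def denoiser(x, t):                    # placeholder; a trained model goes here
        return torch.zeros_like(x)

    with torch.no_grad():
        for t in scheduler.timesteps:      # 980, 960, ..., 0
            model_output = denoiser(sample, int(t))
            sample = scheduler.step(model_output, int(t), sample)["prev_sample"]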
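
On the training side, add_noise is the closed-form forward process x_t = sqrt(alpha_bar_t) · x_0 + sqrt(1 − alpha_bar_t) · noise, and get_velocity builds the regression target used when prediction_type="v_prediction". A sketch with random stand-in data:

    import torch

    scheduler = DDIMScheduler(num_train_timesteps=1000)

    clean = torch.randn(4, 1, 512, 512)    # stand-in for a batch of training images
    noise = torch.randn_like(clean)
    timesteps = torch.randint(0, len(scheduler), (4,))

    noisy = scheduler.add_noise(clean, noise, timesteps)        # denoiser input
    target_v = scheduler.get_velocity(clean, noise, timesteps)  # v-prediction target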