SunderAli17 commited on
Commit
67c1a11
·
verified ·
1 Parent(s): 2f580fc

Create train_previewer_lora.py

Browse files
Files changed (1) hide show
  1. functions/train_previewer_lora.py +1712 -0
functions/train_previewer_lora.py ADDED
@@ -0,0 +1,1712 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The LCM team and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ import argparse
17
+ import copy
18
+ import functools
19
+ import gc
20
+ import logging
21
+ import pyrallis
22
+ import math
23
+ import os
24
+ import random
25
+ import shutil
26
+ from contextlib import nullcontext
27
+ from pathlib import Path
28
+
29
+ import accelerate
30
+ import numpy as np
31
+ import torch
32
+ import torch.nn.functional as F
33
+ import torch.utils.checkpoint
34
+ import transformers
35
+ from PIL import Image
36
+ from accelerate import Accelerator
37
+ from accelerate.logging import get_logger
38
+ from accelerate.utils import ProjectConfiguration, set_seed
39
+ from datasets import load_dataset
40
+ from huggingface_hub import create_repo, upload_folder
41
+ from packaging import version
42
+ from collections import namedtuple
43
+ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
44
+ from torchvision import transforms
45
+ from torchvision.transforms.functional import crop
46
+ from tqdm.auto import tqdm
47
+ from transformers import (
48
+ AutoTokenizer,
49
+ PretrainedConfig,
50
+ CLIPImageProcessor, CLIPVisionModelWithProjection,
51
+ AutoImageProcessor, AutoModel
52
+ )
53
+
54
+ import diffusers
55
+ from diffusers import (
56
+ AutoencoderKL,
57
+ DDPMScheduler,
58
+ LCMScheduler,
59
+ StableDiffusionXLPipeline,
60
+ UNet2DConditionModel,
61
+ )
62
+ from diffusers.optimization import get_scheduler
63
+ from diffusers.training_utils import cast_training_params, resolve_interpolation_mode
64
+ from diffusers.utils import (
65
+ check_min_version,
66
+ convert_state_dict_to_diffusers,
67
+ convert_unet_state_dict_to_peft,
68
+ is_wandb_available,
69
+ )
70
+ from diffusers.utils.import_utils import is_xformers_available
71
+ from diffusers.utils.torch_utils import is_compiled_module
72
+
73
+ from basicsr.utils.degradation_pipeline import RealESRGANDegradation
74
+ from utils.train_utils import (
75
+ seperate_ip_params_from_unet,
76
+ import_model_class_from_model_name_or_path,
77
+ tensor_to_pil,
78
+ get_train_dataset, prepare_train_dataset, collate_fn,
79
+ encode_prompt, importance_sampling_fn, extract_into_tensor
80
+
81
+ )
82
+ from data.data_config import DataConfig
83
+ from losses.loss_config import LossesConfig
84
+ from losses.losses import *
85
+
86
+ from module.ip_adapter.resampler import Resampler
87
+ from module.ip_adapter.utils import init_adapter_in_unet, prepare_training_image_embeds
88
+
89
+
90
+ if is_wandb_available():
91
+ import wandb
92
+
93
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
94
+
95
+ logger = get_logger(__name__)
96
+
97
+
98
+ def prepare_latents(lq, vae, scheduler, generator, timestep):
99
+ transform = transforms.Compose([
100
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
101
+ transforms.CenterCrop(args.resolution),
102
+ transforms.ToTensor(),
103
+ ])
104
+ lq_pt = [transform(lq_pil.convert("RGB")) for lq_pil in lq]
105
+ img_pt = torch.stack(lq_pt).to(vae.device, dtype=vae.dtype)
106
+ img_pt = img_pt * 2.0 - 1.0
107
+ with torch.no_grad():
108
+ latents = vae.encode(img_pt).latent_dist.sample()
109
+ latents = latents * vae.config.scaling_factor
110
+ noise = torch.randn(latents.shape, generator=generator, device=vae.device, dtype=vae.dtype, layout=torch.strided).to(vae.device)
111
+ bsz = latents.shape[0]
112
+ print(f"init latent at {timestep}")
113
+ timestep = torch.tensor([timestep]*bsz, device=vae.device, dtype=torch.int64)
114
+ latents = scheduler.add_noise(latents, noise, timestep)
115
+ return latents
116
+
117
+
118
+ def log_validation(unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
119
+ scheduler, image_encoder, image_processor,
120
+ args, accelerator, weight_dtype, step, lq_img=None, gt_img=None, is_final_validation=False, log_local=False):
121
+ logger.info("Running validation... ")
122
+
123
+ image_logs = []
124
+
125
+ lq = [Image.open(lq_example) for lq_example in args.validation_image]
126
+
127
+ pipe = StableDiffusionXLPipeline(
128
+ vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
129
+ unet, scheduler, image_encoder, image_processor,
130
+ ).to(accelerator.device)
131
+
132
+ timesteps = [args.num_train_timesteps - 1]
133
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
134
+ latents = prepare_latents(lq, vae, scheduler, generator, timesteps[-1])
135
+ image = pipe(
136
+ prompt=[""]*len(lq),
137
+ ip_adapter_image=[lq],
138
+ num_inference_steps=1,
139
+ timesteps=timesteps,
140
+ generator=generator,
141
+ guidance_scale=1.0,
142
+ height=args.resolution,
143
+ width=args.resolution,
144
+ latents=latents,
145
+ ).images
146
+
147
+ if log_local:
148
+ # for i, img in enumerate(tensor_to_pil(lq_img)):
149
+ # img.save(f"./lq_{i}.png")
150
+ # for i, img in enumerate(tensor_to_pil(gt_img)):
151
+ # img.save(f"./gt_{i}.png")
152
+ for i, img in enumerate(image):
153
+ img.save(f"./lq_IPA_{i}.png")
154
+ return
155
+
156
+ tracker_key = "test" if is_final_validation else "validation"
157
+ for tracker in accelerator.trackers:
158
+ if tracker.name == "tensorboard":
159
+ images = [np.asarray(pil_img) for pil_img in image]
160
+ images = np.stack(images, axis=0)
161
+ if lq_img is not None and gt_img is not None:
162
+ input_lq = lq_img.detach().cpu()
163
+ input_lq = np.asarray(input_lq.add(1).div(2).clamp(0, 1))
164
+ input_gt = gt_img.detach().cpu()
165
+ input_gt = np.asarray(input_gt.add(1).div(2).clamp(0, 1))
166
+ tracker.writer.add_images("lq", input_lq, step, dataformats="NCHW")
167
+ tracker.writer.add_images("gt", input_gt, step, dataformats="NCHW")
168
+ tracker.writer.add_images("rec", images, step, dataformats="NHWC")
169
+ elif tracker.name == "wandb":
170
+ raise NotImplementedError("Wandb logging not implemented for validation.")
171
+ formatted_images = []
172
+
173
+ for log in image_logs:
174
+ images = log["images"]
175
+ validation_prompt = log["validation_prompt"]
176
+ validation_image = log["validation_image"]
177
+
178
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
179
+
180
+ for image in images:
181
+ image = wandb.Image(image, caption=validation_prompt)
182
+ formatted_images.append(image)
183
+
184
+ tracker.log({tracker_key: formatted_images})
185
+ else:
186
+ logger.warning(f"image logging not implemented for {tracker.name}")
187
+
188
+ gc.collect()
189
+ torch.cuda.empty_cache()
190
+
191
+ return image_logs
192
+
193
+
194
+ class DDIMSolver:
195
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
196
+ # DDIM sampling parameters
197
+ step_ratio = timesteps // ddim_timesteps
198
+
199
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
200
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
201
+ self.ddim_alpha_cumprods_prev = np.asarray(
202
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
203
+ )
204
+ # convert to torch tensors
205
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
206
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
207
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
208
+
209
+ def to(self, device):
210
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
211
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
212
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
213
+ return self
214
+
215
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
216
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
217
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
218
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
219
+ return x_prev
220
+
221
+
222
+ def append_dims(x, target_dims):
223
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
224
+ dims_to_append = target_dims - x.ndim
225
+ if dims_to_append < 0:
226
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
227
+ return x[(...,) + (None,) * dims_to_append]
228
+
229
+
230
+ # From LCMScheduler.get_scalings_for_boundary_condition_discrete
231
+ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
232
+ scaled_timestep = timestep_scaling * timestep
233
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
234
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
235
+ return c_skip, c_out
236
+
237
+
238
+ # Compare LCMScheduler.step, Step 4
239
+ def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
240
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
241
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
242
+ if prediction_type == "epsilon":
243
+ pred_x_0 = (sample - sigmas * model_output) / alphas
244
+ elif prediction_type == "sample":
245
+ pred_x_0 = model_output
246
+ elif prediction_type == "v_prediction":
247
+ pred_x_0 = alphas * sample - sigmas * model_output
248
+ else:
249
+ raise ValueError(
250
+ f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
251
+ f" are supported."
252
+ )
253
+
254
+ return pred_x_0
255
+
256
+
257
+ # Based on step 4 in DDIMScheduler.step
258
+ def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
259
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
260
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
261
+ if prediction_type == "epsilon":
262
+ pred_epsilon = model_output
263
+ elif prediction_type == "sample":
264
+ pred_epsilon = (sample - alphas * model_output) / sigmas
265
+ elif prediction_type == "v_prediction":
266
+ pred_epsilon = alphas * model_output + sigmas * sample
267
+ else:
268
+ raise ValueError(
269
+ f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
270
+ f" are supported."
271
+ )
272
+
273
+ return pred_epsilon
274
+
275
+
276
+ def extract_into_tensor(a, t, x_shape):
277
+ b, *_ = t.shape
278
+ out = a.gather(-1, t)
279
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
280
+
281
+
282
+ def parse_args():
283
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
284
+ # ----------Model Checkpoint Loading Arguments----------
285
+ parser.add_argument(
286
+ "--pretrained_model_name_or_path",
287
+ type=str,
288
+ default=None,
289
+ required=True,
290
+ help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
291
+ )
292
+ parser.add_argument(
293
+ "--pretrained_vae_model_name_or_path",
294
+ type=str,
295
+ default=None,
296
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
297
+ )
298
+ parser.add_argument(
299
+ "--teacher_revision",
300
+ type=str,
301
+ default=None,
302
+ required=False,
303
+ help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
304
+ )
305
+ parser.add_argument(
306
+ "--revision",
307
+ type=str,
308
+ default=None,
309
+ required=False,
310
+ help="Revision of pretrained LDM model identifier from huggingface.co/models.",
311
+ )
312
+ parser.add_argument(
313
+ "--pretrained_lcm_lora_path",
314
+ type=str,
315
+ default=None,
316
+ help="Path to LCM lora or model identifier from huggingface.co/models.",
317
+ )
318
+ parser.add_argument(
319
+ "--feature_extractor_path",
320
+ type=str,
321
+ default=None,
322
+ help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.",
323
+ )
324
+ parser.add_argument(
325
+ "--pretrained_adapter_model_path",
326
+ type=str,
327
+ default=None,
328
+ help="Path to IP-Adapter models or model identifier from huggingface.co/models.",
329
+ )
330
+ parser.add_argument(
331
+ "--adapter_tokens",
332
+ type=int,
333
+ default=64,
334
+ help="Number of tokens to use in IP-adapter cross attention mechanism.",
335
+ )
336
+ parser.add_argument(
337
+ "--use_clip_encoder",
338
+ action="store_true",
339
+ help="Whether or not to use DINO as image encoder, else CLIP encoder.",
340
+ )
341
+ parser.add_argument(
342
+ "--image_encoder_hidden_feature",
343
+ action="store_true",
344
+ help="Whether or not to use the penultimate hidden states as image embeddings.",
345
+ )
346
+ # ----------Training Arguments----------
347
+ # ----General Training Arguments----
348
+ parser.add_argument(
349
+ "--output_dir",
350
+ type=str,
351
+ default="lcm-xl-distilled",
352
+ help="The output directory where the model predictions and checkpoints will be written.",
353
+ )
354
+ parser.add_argument(
355
+ "--cache_dir",
356
+ type=str,
357
+ default=None,
358
+ help="The directory where the downloaded models and datasets will be stored.",
359
+ )
360
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
361
+ # ----Logging----
362
+ parser.add_argument(
363
+ "--logging_dir",
364
+ type=str,
365
+ default="logs",
366
+ help=(
367
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
368
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
369
+ ),
370
+ )
371
+ parser.add_argument(
372
+ "--report_to",
373
+ type=str,
374
+ default="tensorboard",
375
+ help=(
376
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
377
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
378
+ ),
379
+ )
380
+ # ----Checkpointing----
381
+ parser.add_argument(
382
+ "--checkpointing_steps",
383
+ type=int,
384
+ default=4000,
385
+ help=(
386
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
387
+ " training using `--resume_from_checkpoint`."
388
+ ),
389
+ )
390
+ parser.add_argument(
391
+ "--checkpoints_total_limit",
392
+ type=int,
393
+ default=5,
394
+ help=("Max number of checkpoints to store."),
395
+ )
396
+ parser.add_argument(
397
+ "--resume_from_checkpoint",
398
+ type=str,
399
+ default=None,
400
+ help=(
401
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
402
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
403
+ ),
404
+ )
405
+ parser.add_argument(
406
+ "--save_only_adapter",
407
+ action="store_true",
408
+ help="Only save extra adapter to save space.",
409
+ )
410
+ # ----Image Processing----
411
+ parser.add_argument(
412
+ "--data_config_path",
413
+ type=str,
414
+ default=None,
415
+ help=("A folder containing the training data. "),
416
+ )
417
+ parser.add_argument(
418
+ "--train_data_dir",
419
+ type=str,
420
+ default=None,
421
+ help=(
422
+ "A folder containing the training data. Folder contents must follow the structure described in"
423
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
424
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
425
+ ),
426
+ )
427
+ parser.add_argument(
428
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
429
+ )
430
+ parser.add_argument(
431
+ "--conditioning_image_column",
432
+ type=str,
433
+ default="conditioning_image",
434
+ help="The column of the dataset containing the controlnet conditioning image.",
435
+ )
436
+ parser.add_argument(
437
+ "--caption_column",
438
+ type=str,
439
+ default="text",
440
+ help="The column of the dataset containing a caption or a list of captions.",
441
+ )
442
+ parser.add_argument(
443
+ "--text_drop_rate",
444
+ type=float,
445
+ default=0,
446
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
447
+ )
448
+ parser.add_argument(
449
+ "--image_drop_rate",
450
+ type=float,
451
+ default=0,
452
+ help="Proportion of IP-Adapter inputs to be dropped. Defaults to 0 (no drop-out).",
453
+ )
454
+ parser.add_argument(
455
+ "--cond_drop_rate",
456
+ type=float,
457
+ default=0,
458
+ help="Proportion of all conditions to be dropped. Defaults to 0 (no drop-out).",
459
+ )
460
+ parser.add_argument(
461
+ "--resolution",
462
+ type=int,
463
+ default=1024,
464
+ help=(
465
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
466
+ " resolution"
467
+ ),
468
+ )
469
+ parser.add_argument(
470
+ "--interpolation_type",
471
+ type=str,
472
+ default="bilinear",
473
+ help=(
474
+ "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
475
+ " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
476
+ ),
477
+ )
478
+ parser.add_argument(
479
+ "--center_crop",
480
+ default=False,
481
+ action="store_true",
482
+ help=(
483
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
484
+ " cropped. The images will be resized to the resolution first before cropping."
485
+ ),
486
+ )
487
+ parser.add_argument(
488
+ "--random_flip",
489
+ action="store_true",
490
+ help="whether to randomly flip images horizontally",
491
+ )
492
+ parser.add_argument(
493
+ "--encode_batch_size",
494
+ type=int,
495
+ default=8,
496
+ help="Batch size to use for VAE encoding of the images for efficient processing.",
497
+ )
498
+ # ----Dataloader----
499
+ parser.add_argument(
500
+ "--dataloader_num_workers",
501
+ type=int,
502
+ default=0,
503
+ help=(
504
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
505
+ ),
506
+ )
507
+ # ----Batch Size and Training Steps----
508
+ parser.add_argument(
509
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
510
+ )
511
+ parser.add_argument("--num_train_epochs", type=int, default=100)
512
+ parser.add_argument(
513
+ "--max_train_steps",
514
+ type=int,
515
+ default=None,
516
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
517
+ )
518
+ parser.add_argument(
519
+ "--max_train_samples",
520
+ type=int,
521
+ default=None,
522
+ help=(
523
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
524
+ "value if set."
525
+ ),
526
+ )
527
+ # ----Learning Rate----
528
+ parser.add_argument(
529
+ "--learning_rate",
530
+ type=float,
531
+ default=1e-6,
532
+ help="Initial learning rate (after the potential warmup period) to use.",
533
+ )
534
+ parser.add_argument(
535
+ "--scale_lr",
536
+ action="store_true",
537
+ default=False,
538
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
539
+ )
540
+ parser.add_argument(
541
+ "--lr_scheduler",
542
+ type=str,
543
+ default="constant",
544
+ help=(
545
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
546
+ ' "constant", "constant_with_warmup"]'
547
+ ),
548
+ )
549
+ parser.add_argument(
550
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
551
+ )
552
+ parser.add_argument(
553
+ "--lr_num_cycles",
554
+ type=int,
555
+ default=1,
556
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
557
+ )
558
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
559
+ parser.add_argument(
560
+ "--gradient_accumulation_steps",
561
+ type=int,
562
+ default=1,
563
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
564
+ )
565
+ # ----Optimizer (Adam)----
566
+ parser.add_argument(
567
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
568
+ )
569
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
570
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
571
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
572
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
573
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
574
+ # ----Diffusion Training Arguments----
575
+ # ----Latent Consistency Distillation (LCD) Specific Arguments----
576
+ parser.add_argument(
577
+ "--w_min",
578
+ type=float,
579
+ default=3.0,
580
+ required=False,
581
+ help=(
582
+ "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
583
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
584
+ " compared to the original paper."
585
+ ),
586
+ )
587
+ parser.add_argument(
588
+ "--w_max",
589
+ type=float,
590
+ default=15.0,
591
+ required=False,
592
+ help=(
593
+ "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
594
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
595
+ " compared to the original paper."
596
+ ),
597
+ )
598
+ parser.add_argument(
599
+ "--num_train_timesteps",
600
+ type=int,
601
+ default=1000,
602
+ help="The number of timesteps to use for DDIM sampling.",
603
+ )
604
+ parser.add_argument(
605
+ "--num_ddim_timesteps",
606
+ type=int,
607
+ default=50,
608
+ help="The number of timesteps to use for DDIM sampling.",
609
+ )
610
+ parser.add_argument(
611
+ "--losses_config_path",
612
+ type=str,
613
+ default='config_files/losses.yaml',
614
+ required=True,
615
+ help=("A yaml file containing losses to use and their weights."),
616
+ )
617
+ parser.add_argument(
618
+ "--loss_type",
619
+ type=str,
620
+ default="l2",
621
+ choices=["l2", "huber"],
622
+ help="The type of loss to use for the LCD loss.",
623
+ )
624
+ parser.add_argument(
625
+ "--huber_c",
626
+ type=float,
627
+ default=0.001,
628
+ help="The huber loss parameter. Only used if `--loss_type=huber`.",
629
+ )
630
+ parser.add_argument(
631
+ "--lora_rank",
632
+ type=int,
633
+ default=64,
634
+ help="The rank of the LoRA projection matrix.",
635
+ )
636
+ parser.add_argument(
637
+ "--lora_alpha",
638
+ type=int,
639
+ default=64,
640
+ help=(
641
+ "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
642
+ " update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
643
+ ),
644
+ )
645
+ parser.add_argument(
646
+ "--lora_dropout",
647
+ type=float,
648
+ default=0.0,
649
+ help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
650
+ )
651
+ parser.add_argument(
652
+ "--lora_target_modules",
653
+ type=str,
654
+ default=None,
655
+ help=(
656
+ "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
657
+ " be used. By default, LoRA will be applied to all conv and linear layers."
658
+ ),
659
+ )
660
+ parser.add_argument(
661
+ "--vae_encode_batch_size",
662
+ type=int,
663
+ default=8,
664
+ required=False,
665
+ help=(
666
+ "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
667
+ " Encoding or decoding the whole batch at once may run into OOM issues."
668
+ ),
669
+ )
670
+ parser.add_argument(
671
+ "--timestep_scaling_factor",
672
+ type=float,
673
+ default=10.0,
674
+ help=(
675
+ "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
676
+ " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
677
+ " suffice."
678
+ ),
679
+ )
680
+ # ----Mixed Precision----
681
+ parser.add_argument(
682
+ "--mixed_precision",
683
+ type=str,
684
+ default=None,
685
+ choices=["no", "fp16", "bf16"],
686
+ help=(
687
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
688
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
689
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
690
+ ),
691
+ )
692
+ parser.add_argument(
693
+ "--allow_tf32",
694
+ action="store_true",
695
+ help=(
696
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
697
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
698
+ ),
699
+ )
700
+ # ----Training Optimizations----
701
+ parser.add_argument(
702
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
703
+ )
704
+ parser.add_argument(
705
+ "--gradient_checkpointing",
706
+ action="store_true",
707
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
708
+ )
709
+ # ----Distributed Training----
710
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
711
+ # ----------Validation Arguments----------
712
+ parser.add_argument(
713
+ "--validation_steps",
714
+ type=int,
715
+ default=3000,
716
+ help="Run validation every X steps.",
717
+ )
718
+ parser.add_argument(
719
+ "--validation_image",
720
+ type=str,
721
+ default=None,
722
+ nargs="+",
723
+ help=(
724
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
725
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
726
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
727
+ " `--validation_image` that will be used with all `--validation_prompt`s."
728
+ ),
729
+ )
730
+ parser.add_argument(
731
+ "--validation_prompt",
732
+ type=str,
733
+ default=None,
734
+ nargs="+",
735
+ help=(
736
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
737
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
738
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
739
+ ),
740
+ )
741
+ parser.add_argument(
742
+ "--sanity_check",
743
+ action="store_true",
744
+ help=(
745
+ "sanity check"
746
+ ),
747
+ )
748
+ # ----------Huggingface Hub Arguments-----------
749
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
750
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
751
+ parser.add_argument(
752
+ "--hub_model_id",
753
+ type=str,
754
+ default=None,
755
+ help="The name of the repository to keep in sync with the local `output_dir`.",
756
+ )
757
+ # ----------Accelerate Arguments----------
758
+ parser.add_argument(
759
+ "--tracker_project_name",
760
+ type=str,
761
+ default="trian",
762
+ help=(
763
+ "The `project_name` argument passed to Accelerator.init_trackers for"
764
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
765
+ ),
766
+ )
767
+
768
+ args = parser.parse_args()
769
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
770
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
771
+ args.local_rank = env_local_rank
772
+
773
+ return args
774
+
775
+
776
+ def main(args):
777
+ if args.report_to == "wandb" and args.hub_token is not None:
778
+ raise ValueError(
779
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
780
+ " Please use `huggingface-cli login` to authenticate with the Hub."
781
+ )
782
+
783
+ logging_dir = Path(args.output_dir, args.logging_dir)
784
+
785
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
786
+
787
+ accelerator = Accelerator(
788
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
789
+ mixed_precision=args.mixed_precision,
790
+ log_with=args.report_to,
791
+ project_config=accelerator_project_config,
792
+ )
793
+
794
+ # Make one log on every process with the configuration for debugging.
795
+ logging.basicConfig(
796
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
797
+ datefmt="%m/%d/%Y %H:%M:%S",
798
+ level=logging.INFO,
799
+ )
800
+ logger.info(accelerator.state, main_process_only=False)
801
+ if accelerator.is_local_main_process:
802
+ transformers.utils.logging.set_verbosity_warning()
803
+ diffusers.utils.logging.set_verbosity_info()
804
+ else:
805
+ transformers.utils.logging.set_verbosity_error()
806
+ diffusers.utils.logging.set_verbosity_error()
807
+
808
+ # If passed along, set the training seed now.
809
+ if args.seed is not None:
810
+ set_seed(args.seed)
811
+
812
+ # Handle the repository creation.
813
+ if accelerator.is_main_process:
814
+ if args.output_dir is not None:
815
+ os.makedirs(args.output_dir, exist_ok=True)
816
+
817
+ # 1. Create the noise scheduler and the desired noise schedule.
818
+ noise_scheduler = DDPMScheduler.from_pretrained(
819
+ args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.teacher_revision
820
+ )
821
+ noise_scheduler.config.num_train_timesteps = args.num_train_timesteps
822
+ lcm_scheduler = LCMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
823
+
824
+ # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
825
+ alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
826
+ sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
827
+ # Initialize the DDIM ODE solver for distillation.
828
+ solver = DDIMSolver(
829
+ noise_scheduler.alphas_cumprod.numpy(),
830
+ timesteps=noise_scheduler.config.num_train_timesteps,
831
+ ddim_timesteps=args.num_ddim_timesteps,
832
+ )
833
+
834
+ # 2. Load tokenizers from SDXL checkpoint.
835
+ tokenizer_one = AutoTokenizer.from_pretrained(
836
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
837
+ )
838
+ tokenizer_two = AutoTokenizer.from_pretrained(
839
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False
840
+ )
841
+
842
+ # 3. Load text encoders from SDXL checkpoint.
843
+ # import correct text encoder classes
844
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
845
+ args.pretrained_model_name_or_path, args.teacher_revision
846
+ )
847
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
848
+ args.pretrained_model_name_or_path, args.teacher_revision, subfolder="text_encoder_2"
849
+ )
850
+
851
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
852
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.teacher_revision
853
+ )
854
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
855
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.teacher_revision
856
+ )
857
+
858
+ if args.use_clip_encoder:
859
+ image_processor = CLIPImageProcessor()
860
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor_path)
861
+ else:
862
+ image_processor = AutoImageProcessor.from_pretrained(args.feature_extractor_path)
863
+ image_encoder = AutoModel.from_pretrained(args.feature_extractor_path)
864
+
865
+ # 4. Load VAE from SDXL checkpoint (or more stable VAE)
866
+ vae_path = (
867
+ args.pretrained_model_name_or_path
868
+ if args.pretrained_vae_model_name_or_path is None
869
+ else args.pretrained_vae_model_name_or_path
870
+ )
871
+ vae = AutoencoderKL.from_pretrained(
872
+ vae_path,
873
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
874
+ revision=args.teacher_revision,
875
+ )
876
+
877
+ # 7. Create online student U-Net.
878
+ unet = UNet2DConditionModel.from_pretrained(
879
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.teacher_revision
880
+ )
881
+
882
+ # Resampler for project model in IP-Adapter
883
+ image_proj_model = Resampler(
884
+ dim=1280,
885
+ depth=4,
886
+ dim_head=64,
887
+ heads=20,
888
+ num_queries=args.adapter_tokens,
889
+ embedding_dim=image_encoder.config.hidden_size,
890
+ output_dim=unet.config.cross_attention_dim,
891
+ ff_mult=4
892
+ )
893
+
894
+ # Load the same adapter in both unet.
895
+ init_adapter_in_unet(
896
+ unet,
897
+ image_proj_model,
898
+ os.path.join(args.pretrained_adapter_model_path, 'adapter_ckpt.pt'),
899
+ adapter_tokens=args.adapter_tokens,
900
+ )
901
+
902
+ # Check that all trainable models are in full precision
903
+ low_precision_error_string = (
904
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
905
+ " doing mixed precision training, copy of the weights should still be float32."
906
+ )
907
+
908
+ def unwrap_model(model):
909
+ model = accelerator.unwrap_model(model)
910
+ model = model._orig_mod if is_compiled_module(model) else model
911
+ return model
912
+
913
+ if unwrap_model(unet).dtype != torch.float32:
914
+ raise ValueError(
915
+ f"Controlnet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}"
916
+ )
917
+
918
+ if args.pretrained_lcm_lora_path is not None:
919
+ lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(args.pretrained_lcm_lora_path)
920
+ unet_state_dict = {
921
+ f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
922
+ }
923
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
924
+ lora_state_dict = dict()
925
+ for k, v in unet_state_dict.items():
926
+ if "ip" in k:
927
+ k = k.replace("attn2", "attn2.processor")
928
+ lora_state_dict[k] = v
929
+ else:
930
+ lora_state_dict[k] = v
931
+ if alpha_dict:
932
+ args.lora_alpha = next(iter(alpha_dict.values()))
933
+ else:
934
+ args.lora_alpha = 1
935
+ # 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
936
+ if args.lora_target_modules is not None:
937
+ lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
938
+ else:
939
+ lora_target_modules = [
940
+ "to_q",
941
+ "to_kv",
942
+ "0.to_out",
943
+ "attn1.to_k",
944
+ "attn1.to_v",
945
+ "to_k_ip",
946
+ "to_v_ip",
947
+ "ln_k_ip.linear",
948
+ "ln_v_ip.linear",
949
+ "to_out.0",
950
+ "proj_in",
951
+ "proj_out",
952
+ "ff.net.0.proj",
953
+ "ff.net.2",
954
+ "conv1",
955
+ "conv2",
956
+ "conv_shortcut",
957
+ "downsamplers.0.conv",
958
+ "upsamplers.0.conv",
959
+ "time_emb_proj",
960
+ ]
961
+ lora_config = LoraConfig(
962
+ r=args.lora_rank,
963
+ target_modules=lora_target_modules,
964
+ lora_alpha=args.lora_alpha,
965
+ lora_dropout=args.lora_dropout,
966
+ )
967
+
968
+ # Legacy
969
+ # for k, v in lcm_pipe.unet.state_dict().items():
970
+ # if "lora" in k or "base_layer" in k:
971
+ # lcm_dict[k.replace("default_0", "default")] = v
972
+
973
+ unet.add_adapter(lora_config)
974
+ if args.pretrained_lcm_lora_path is not None:
975
+ incompatible_keys = set_peft_model_state_dict(unet, lora_state_dict, adapter_name="default")
976
+ if incompatible_keys is not None:
977
+ # check only for unexpected keys
978
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
979
+ if unexpected_keys:
980
+ logger.warning(
981
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
982
+ f" {unexpected_keys}. "
983
+ )
984
+
985
+ # 6. Freeze unet, vae, text_encoders.
986
+ vae.requires_grad_(False)
987
+ text_encoder_one.requires_grad_(False)
988
+ text_encoder_two.requires_grad_(False)
989
+ image_encoder.requires_grad_(False)
990
+ unet.requires_grad_(False)
991
+
992
+ # 10. Handle saving and loading of checkpoints
993
+ # `accelerate` 0.16.0 will have better support for customized saving
994
+ if args.save_only_adapter:
995
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
996
+ def save_model_hook(models, weights, output_dir):
997
+ if accelerator.is_main_process:
998
+ for model in models:
999
+ if isinstance(model, type(unwrap_model(unet))): # save adapter only
1000
+ unet_ = unwrap_model(model)
1001
+ # also save the checkpoints in native `diffusers` format so that it can be easily
1002
+ # be independently loaded via `load_lora_weights()`.
1003
+ state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))
1004
+ StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict, safe_serialization=False)
1005
+
1006
+ weights.pop()
1007
+
1008
+ def load_model_hook(models, input_dir):
1009
+
1010
+ while len(models) > 0:
1011
+ # pop models so that they are not loaded again
1012
+ model = models.pop()
1013
+
1014
+ if isinstance(model, type(unwrap_model(unet))):
1015
+ unet_ = unwrap_model(model)
1016
+ lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
1017
+ unet_state_dict = {
1018
+ f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
1019
+ }
1020
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
1021
+ lora_state_dict = dict()
1022
+ for k, v in unet_state_dict.items():
1023
+ if "ip" in k:
1024
+ k = k.replace("attn2", "attn2.processor")
1025
+ lora_state_dict[k] = v
1026
+ else:
1027
+ lora_state_dict[k] = v
1028
+ incompatible_keys = set_peft_model_state_dict(unet_, lora_state_dict, adapter_name="default")
1029
+ if incompatible_keys is not None:
1030
+ # check only for unexpected keys
1031
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
1032
+ if unexpected_keys:
1033
+ logger.warning(
1034
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
1035
+ f" {unexpected_keys}. "
1036
+ )
1037
+
1038
+ accelerator.register_save_state_pre_hook(save_model_hook)
1039
+ accelerator.register_load_state_pre_hook(load_model_hook)
1040
+
1041
+ # 11. Enable optimizations
1042
+ if args.enable_xformers_memory_efficient_attention:
1043
+ if is_xformers_available():
1044
+ import xformers
1045
+
1046
+ xformers_version = version.parse(xformers.__version__)
1047
+ if xformers_version == version.parse("0.0.16"):
1048
+ logger.warning(
1049
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
1050
+ )
1051
+ unet.enable_xformers_memory_efficient_attention()
1052
+ else:
1053
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
1054
+
1055
+ # Enable TF32 for faster training on Ampere GPUs,
1056
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1057
+ if args.allow_tf32:
1058
+ torch.backends.cuda.matmul.allow_tf32 = True
1059
+
1060
+ if args.gradient_checkpointing:
1061
+ unet.enable_gradient_checkpointing()
1062
+ vae.enable_gradient_checkpointing()
1063
+
1064
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
1065
+ if args.use_8bit_adam:
1066
+ try:
1067
+ import bitsandbytes as bnb
1068
+ except ImportError:
1069
+ raise ImportError(
1070
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
1071
+ )
1072
+
1073
+ optimizer_class = bnb.optim.AdamW8bit
1074
+ else:
1075
+ optimizer_class = torch.optim.AdamW
1076
+
1077
+ # 12. Optimizer creation
1078
+ lora_params, non_lora_params = seperate_lora_params_from_unet(unet)
1079
+ params_to_optimize = lora_params
1080
+ optimizer = optimizer_class(
1081
+ params_to_optimize,
1082
+ lr=args.learning_rate,
1083
+ betas=(args.adam_beta1, args.adam_beta2),
1084
+ weight_decay=args.adam_weight_decay,
1085
+ eps=args.adam_epsilon,
1086
+ )
1087
+
1088
+ # 13. Dataset creation and data processing
1089
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
1090
+ # download the dataset.
1091
+ datasets = []
1092
+ datasets_name = []
1093
+ datasets_weights = []
1094
+ deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution)
1095
+ if args.data_config_path is not None:
1096
+ data_config: DataConfig = pyrallis.load(DataConfig, open(args.data_config_path, "r"))
1097
+ for single_dataset in data_config.datasets:
1098
+ datasets_weights.append(single_dataset.dataset_weight)
1099
+ datasets_name.append(single_dataset.dataset_folder)
1100
+ dataset_dir = os.path.join(args.train_data_dir, single_dataset.dataset_folder)
1101
+ image_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator)
1102
+ image_dataset = prepare_train_dataset(image_dataset, accelerator, deg_pipeline)
1103
+ datasets.append(image_dataset)
1104
+ # TODO: Validation dataset
1105
+ if data_config.val_dataset is not None:
1106
+ val_dataset = get_train_dataset(dataset_name, dataset_dir, args, accelerator)
1107
+ logger.info(f"Datasets mixing: {list(zip(datasets_name, datasets_weights))}")
1108
+
1109
+ # Mix training datasets.
1110
+ sampler_train = None
1111
+ if len(datasets) == 1:
1112
+ train_dataset = datasets[0]
1113
+ else:
1114
+ # Weighted each dataset
1115
+ train_dataset = torch.utils.data.ConcatDataset(datasets)
1116
+ dataset_weights = []
1117
+ for single_dataset, single_weight in zip(datasets, datasets_weights):
1118
+ dataset_weights.extend([len(train_dataset) / len(single_dataset) * single_weight] * len(single_dataset))
1119
+ sampler_train = torch.utils.data.WeightedRandomSampler(
1120
+ weights=dataset_weights,
1121
+ num_samples=len(dataset_weights)
1122
+ )
1123
+
1124
+ # DataLoaders creation:
1125
+ train_dataloader = torch.utils.data.DataLoader(
1126
+ train_dataset,
1127
+ sampler=sampler_train,
1128
+ shuffle=True if sampler_train is None else False,
1129
+ collate_fn=collate_fn,
1130
+ batch_size=args.train_batch_size,
1131
+ num_workers=args.dataloader_num_workers,
1132
+ )
1133
+
1134
+ # 14. Embeddings for the UNet.
1135
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
1136
+ def compute_embeddings(prompt_batch, original_sizes, crop_coords, text_encoders, tokenizers, is_train=True):
1137
+ def compute_time_ids(original_size, crops_coords_top_left):
1138
+ target_size = (args.resolution, args.resolution)
1139
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
1140
+ add_time_ids = torch.tensor([add_time_ids])
1141
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
1142
+ return add_time_ids
1143
+
1144
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(prompt_batch, text_encoders, tokenizers, is_train)
1145
+ add_text_embeds = pooled_prompt_embeds
1146
+
1147
+ add_time_ids = torch.cat([compute_time_ids(s, c) for s, c in zip(original_sizes, crop_coords)])
1148
+
1149
+ prompt_embeds = prompt_embeds.to(accelerator.device)
1150
+ add_text_embeds = add_text_embeds.to(accelerator.device)
1151
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1152
+
1153
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
1154
+
1155
+ text_encoders = [text_encoder_one, text_encoder_two]
1156
+ tokenizers = [tokenizer_one, tokenizer_two]
1157
+
1158
+ compute_embeddings_fn = functools.partial(compute_embeddings, text_encoders=text_encoders, tokenizers=tokenizers)
1159
+
1160
+ # Move pixels into latents.
1161
+ @torch.no_grad()
1162
+ def convert_to_latent(pixels):
1163
+ model_input = vae.encode(pixels).latent_dist.sample()
1164
+ model_input = model_input * vae.config.scaling_factor
1165
+ if args.pretrained_vae_model_name_or_path is None:
1166
+ model_input = model_input.to(weight_dtype)
1167
+ return model_input
1168
+
1169
+ # 15. LR Scheduler creation
1170
+ # Scheduler and math around the number of training steps.
1171
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
1172
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
1173
+ if args.max_train_steps is None:
1174
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
1175
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
1176
+ num_training_steps_for_scheduler = (
1177
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
1178
+ )
1179
+ else:
1180
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
1181
+
1182
+ if args.scale_lr:
1183
+ args.learning_rate = (
1184
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
1185
+ )
1186
+
1187
+ # Make sure the trainable params are in float32.
1188
+ if args.mixed_precision == "fp16":
1189
+ # only upcast trainable parameters (LoRA) into fp32
1190
+ cast_training_params(unet, dtype=torch.float32)
1191
+
1192
+ lr_scheduler = get_scheduler(
1193
+ args.lr_scheduler,
1194
+ optimizer=optimizer,
1195
+ num_warmup_steps=num_warmup_steps_for_scheduler,
1196
+ num_training_steps=num_training_steps_for_scheduler,
1197
+ )
1198
+
1199
+ # 16. Prepare for training
1200
+ # Prepare everything with our `accelerator`.
1201
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1202
+ unet, optimizer, train_dataloader, lr_scheduler
1203
+ )
1204
+
1205
+ # 8. Handle mixed precision and device placement
1206
+ # For mixed precision training we cast all non-trainable weigths to half-precision
1207
+ # as these weights are only used for inference, keeping weights in full precision is not required.
1208
+ weight_dtype = torch.float32
1209
+ if accelerator.mixed_precision == "fp16":
1210
+ weight_dtype = torch.float16
1211
+ elif accelerator.mixed_precision == "bf16":
1212
+ weight_dtype = torch.bfloat16
1213
+
1214
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
1215
+ # The VAE is in float32 to avoid NaN losses.
1216
+ if args.pretrained_vae_model_name_or_path is None:
1217
+ vae.to(accelerator.device, dtype=torch.float32)
1218
+ else:
1219
+ vae.to(accelerator.device, dtype=weight_dtype)
1220
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
1221
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
1222
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
1223
+ for p in non_lora_params:
1224
+ p.data = p.data.to(dtype=weight_dtype)
1225
+ for p in lora_params:
1226
+ p.requires_grad_(True)
1227
+ unet.to(accelerator.device)
1228
+
1229
+ # Also move the alpha and sigma noise schedules to accelerator.device.
1230
+ alpha_schedule = alpha_schedule.to(accelerator.device)
1231
+ sigma_schedule = sigma_schedule.to(accelerator.device)
1232
+ solver = solver.to(accelerator.device)
1233
+
1234
+ # Instantiate Loss.
1235
+ losses_configs: LossesConfig = pyrallis.load(LossesConfig, open(args.losses_config_path, "r"))
1236
+ lcm_losses = list()
1237
+ for loss_config in losses_configs.lcm_losses:
1238
+ logger.info(f"Loading lcm loss: {loss_config.name}")
1239
+ loss = namedtuple("loss", ["loss", "weight"])
1240
+ loss_class = eval(loss_config.name)
1241
+ lcm_losses.append(loss(loss_class(
1242
+ visualize_every_k=loss_config.visualize_every_k,
1243
+ dtype=weight_dtype,
1244
+ accelerator=accelerator,
1245
+ dino_model=image_encoder,
1246
+ dino_preprocess=image_processor,
1247
+ huber_c=args.huber_c,
1248
+ **loss_config.init_params), weight=loss_config.weight))
1249
+
1250
+ # Final check.
1251
+ for n, p in unet.named_parameters():
1252
+ if p.requires_grad:
1253
+ assert "lora" in n, n
1254
+ assert p.dtype == torch.float32, n
1255
+ else:
1256
+ assert "lora" not in n, f"{n}"
1257
+ assert p.dtype == weight_dtype, n
1258
+ if args.sanity_check:
1259
+ if args.resume_from_checkpoint:
1260
+ if args.resume_from_checkpoint != "latest":
1261
+ path = os.path.basename(args.resume_from_checkpoint)
1262
+ else:
1263
+ # Get the most recent checkpoint
1264
+ dirs = os.listdir(args.output_dir)
1265
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1266
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1267
+ path = dirs[-1] if len(dirs) > 0 else None
1268
+
1269
+ if path is None:
1270
+ accelerator.print(
1271
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1272
+ )
1273
+ args.resume_from_checkpoint = None
1274
+ initial_global_step = 0
1275
+ else:
1276
+ accelerator.print(f"Resuming from checkpoint {path}")
1277
+ accelerator.load_state(os.path.join(args.output_dir, path))
1278
+
1279
+ # Check input data
1280
+ batch = next(iter(train_dataloader))
1281
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
1282
+ out_images = log_validation(unwrap_model(unet), vae, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two,
1283
+ lcm_scheduler, image_encoder, image_processor,
1284
+ args, accelerator, weight_dtype, step=0, lq_img=lq_img, gt_img=gt_img, is_final_validation=False, log_local=True)
1285
+ exit()
1286
+
1287
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1288
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1289
+ if args.max_train_steps is None:
1290
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1291
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
1292
+ logger.warning(
1293
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
1294
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
1295
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
1296
+ )
1297
+ # Afterwards we recalculate our number of training epochs
1298
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1299
+
1300
+ # We need to initialize the trackers we use, and also store our configuration.
1301
+ # The trackers initializes automatically on the main process.
1302
+ if accelerator.is_main_process:
1303
+ tracker_config = dict(vars(args))
1304
+
1305
+ # tensorboard cannot handle list types for config
1306
+ tracker_config.pop("validation_prompt")
1307
+ tracker_config.pop("validation_image")
1308
+
1309
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
1310
+
1311
+ # 17. Train!
1312
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1313
+
1314
+ logger.info("***** Running training *****")
1315
+ logger.info(f" Num examples = {len(train_dataset)}")
1316
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1317
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1318
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1319
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1320
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1321
+ global_step = 0
1322
+ first_epoch = 0
1323
+
1324
+ # Potentially load in the weights and states from a previous save
1325
+ if args.resume_from_checkpoint:
1326
+ if args.resume_from_checkpoint != "latest":
1327
+ path = os.path.basename(args.resume_from_checkpoint)
1328
+ else:
1329
+ # Get the most recent checkpoint
1330
+ dirs = os.listdir(args.output_dir)
1331
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1332
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1333
+ path = dirs[-1] if len(dirs) > 0 else None
1334
+
1335
+ if path is None:
1336
+ accelerator.print(
1337
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1338
+ )
1339
+ args.resume_from_checkpoint = None
1340
+ initial_global_step = 0
1341
+ else:
1342
+ accelerator.print(f"Resuming from checkpoint {path}")
1343
+ accelerator.load_state(os.path.join(args.output_dir, path))
1344
+ global_step = int(path.split("-")[1])
1345
+
1346
+ initial_global_step = global_step
1347
+ first_epoch = global_step // num_update_steps_per_epoch
1348
+ else:
1349
+ initial_global_step = 0
1350
+
1351
+ progress_bar = tqdm(
1352
+ range(0, args.max_train_steps),
1353
+ initial=initial_global_step,
1354
+ desc="Steps",
1355
+ # Only show the progress bar once on each machine.
1356
+ disable=not accelerator.is_local_main_process,
1357
+ )
1358
+
1359
+ unet.train()
1360
+ for epoch in range(first_epoch, args.num_train_epochs):
1361
+ for step, batch in enumerate(train_dataloader):
1362
+ with accelerator.accumulate(unet):
1363
+ total_loss = torch.tensor(0.0)
1364
+ bsz = batch["images"].shape[0]
1365
+
1366
+ # Drop conditions.
1367
+ rand_tensor = torch.rand(bsz)
1368
+ drop_image_idx = rand_tensor < args.image_drop_rate
1369
+ drop_text_idx = (rand_tensor >= args.image_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate)
1370
+ drop_both_idx = (rand_tensor >= args.image_drop_rate + args.text_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate + args.cond_drop_rate)
1371
+ drop_image_idx = drop_image_idx | drop_both_idx
1372
+ drop_text_idx = drop_text_idx | drop_both_idx
1373
+
1374
+ with torch.no_grad():
1375
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
1376
+ lq_pt = image_processor(
1377
+ images=lq_img*0.5+0.5,
1378
+ do_rescale=False, return_tensors="pt"
1379
+ ).pixel_values
1380
+ image_embeds = prepare_training_image_embeds(
1381
+ image_encoder, image_processor,
1382
+ ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
1383
+ device=accelerator.device, drop_rate=args.image_drop_rate, output_hidden_state=args.image_encoder_hidden_feature,
1384
+ idx_to_replace=drop_image_idx
1385
+ )
1386
+ uncond_image_embeds = prepare_training_image_embeds(
1387
+ image_encoder, image_processor,
1388
+ ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
1389
+ device=accelerator.device, drop_rate=1.0, output_hidden_state=args.image_encoder_hidden_feature,
1390
+ idx_to_replace=torch.ones_like(drop_image_idx)
1391
+ )
1392
+ # 1. Load and process the image and text conditioning
1393
+ text, orig_size, crop_coords = (
1394
+ batch["text"],
1395
+ batch["original_sizes"],
1396
+ batch["crop_top_lefts"],
1397
+ )
1398
+
1399
+ encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
1400
+ uncond_encoded_text = compute_embeddings_fn([""]*len(text), orig_size, crop_coords)
1401
+
1402
+ # encode pixel values with batch size of at most args.vae_encode_batch_size
1403
+ gt_img = gt_img.to(dtype=vae.dtype)
1404
+ latents = []
1405
+ for i in range(0, gt_img.shape[0], args.vae_encode_batch_size):
1406
+ latents.append(vae.encode(gt_img[i : i + args.vae_encode_batch_size]).latent_dist.sample())
1407
+ latents = torch.cat(latents, dim=0)
1408
+ # latents = convert_to_latent(gt_img)
1409
+
1410
+ latents = latents * vae.config.scaling_factor
1411
+ if args.pretrained_vae_model_name_or_path is None:
1412
+ latents = latents.to(weight_dtype)
1413
+
1414
+ # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
1415
+ # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
1416
+ bsz = latents.shape[0]
1417
+ topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
1418
+ index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
1419
+ start_timesteps = solver.ddim_timesteps[index]
1420
+ timesteps = start_timesteps - topk
1421
+ timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
1422
+
1423
+ # 3. Get boundary scalings for start_timesteps and (end) timesteps.
1424
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(
1425
+ start_timesteps, timestep_scaling=args.timestep_scaling_factor
1426
+ )
1427
+ c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
1428
+ c_skip, c_out = scalings_for_boundary_conditions(
1429
+ timesteps, timestep_scaling=args.timestep_scaling_factor
1430
+ )
1431
+ c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
1432
+
1433
+ # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
1434
+ # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
1435
+ noise = torch.randn_like(latents)
1436
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
1437
+
1438
+ # 5. Sample a random guidance scale w from U[w_min, w_max]
1439
+ # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
1440
+ w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
1441
+ w = w.reshape(bsz, 1, 1, 1)
1442
+ w = w.to(device=latents.device, dtype=latents.dtype)
1443
+
1444
+ # 6. Prepare prompt embeds and unet_added_conditions
1445
+ prompt_embeds = encoded_text.pop("prompt_embeds")
1446
+ encoded_text["image_embeds"] = image_embeds
1447
+ uncond_prompt_embeds = uncond_encoded_text.pop("prompt_embeds")
1448
+ uncond_encoded_text["image_embeds"] = image_embeds
1449
+
1450
+ # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
1451
+ noise_pred = unet(
1452
+ noisy_model_input,
1453
+ start_timesteps,
1454
+ encoder_hidden_states=uncond_prompt_embeds,
1455
+ added_cond_kwargs=uncond_encoded_text,
1456
+ ).sample
1457
+ pred_x_0 = get_predicted_original_sample(
1458
+ noise_pred,
1459
+ start_timesteps,
1460
+ noisy_model_input,
1461
+ noise_scheduler.config.prediction_type,
1462
+ alpha_schedule,
1463
+ sigma_schedule,
1464
+ )
1465
+ model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
1466
+
1467
+ # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
1468
+ # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
1469
+ # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
1470
+ # solver timestep.
1471
+
1472
+ # With the adapters disabled, the `unet` is the regular teacher model.
1473
+ accelerator.unwrap_model(unet).disable_adapters()
1474
+ with torch.no_grad():
1475
+
1476
+ # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
1477
+ teacher_added_cond = dict()
1478
+ for k,v in encoded_text.items():
1479
+ if isinstance(v, torch.Tensor):
1480
+ teacher_added_cond[k] = v.to(weight_dtype)
1481
+ else:
1482
+ teacher_image_embeds = []
1483
+ for img_emb in v:
1484
+ teacher_image_embeds.append(img_emb.to(weight_dtype))
1485
+ teacher_added_cond[k] = teacher_image_embeds
1486
+ cond_teacher_output = unet(
1487
+ noisy_model_input,
1488
+ start_timesteps,
1489
+ encoder_hidden_states=prompt_embeds,
1490
+ added_cond_kwargs=teacher_added_cond,
1491
+ ).sample
1492
+ cond_pred_x0 = get_predicted_original_sample(
1493
+ cond_teacher_output,
1494
+ start_timesteps,
1495
+ noisy_model_input,
1496
+ noise_scheduler.config.prediction_type,
1497
+ alpha_schedule,
1498
+ sigma_schedule,
1499
+ )
1500
+ cond_pred_noise = get_predicted_noise(
1501
+ cond_teacher_output,
1502
+ start_timesteps,
1503
+ noisy_model_input,
1504
+ noise_scheduler.config.prediction_type,
1505
+ alpha_schedule,
1506
+ sigma_schedule,
1507
+ )
1508
+
1509
+ # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
1510
+ teacher_added_uncond = dict()
1511
+ uncond_encoded_text["image_embeds"] = uncond_image_embeds
1512
+ for k,v in uncond_encoded_text.items():
1513
+ if isinstance(v, torch.Tensor):
1514
+ teacher_added_uncond[k] = v.to(weight_dtype)
1515
+ else:
1516
+ teacher_uncond_image_embeds = []
1517
+ for img_emb in v:
1518
+ teacher_uncond_image_embeds.append(img_emb.to(weight_dtype))
1519
+ teacher_added_uncond[k] = teacher_uncond_image_embeds
1520
+ uncond_teacher_output = unet(
1521
+ noisy_model_input,
1522
+ start_timesteps,
1523
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
1524
+ added_cond_kwargs=teacher_added_uncond,
1525
+ ).sample
1526
+ uncond_pred_x0 = get_predicted_original_sample(
1527
+ uncond_teacher_output,
1528
+ start_timesteps,
1529
+ noisy_model_input,
1530
+ noise_scheduler.config.prediction_type,
1531
+ alpha_schedule,
1532
+ sigma_schedule,
1533
+ )
1534
+ uncond_pred_noise = get_predicted_noise(
1535
+ uncond_teacher_output,
1536
+ start_timesteps,
1537
+ noisy_model_input,
1538
+ noise_scheduler.config.prediction_type,
1539
+ alpha_schedule,
1540
+ sigma_schedule,
1541
+ )
1542
+
1543
+ # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
1544
+ # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
1545
+ pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
1546
+ pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
1547
+ # 4. Run one step of the ODE solver to estimate the next point x_prev on the
1548
+ # augmented PF-ODE trajectory (solving backward in time)
1549
+ # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
1550
+ x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(weight_dtype)
1551
+
1552
+ # re-enable unet adapters to turn the `unet` into a student unet.
1553
+ accelerator.unwrap_model(unet).enable_adapters()
1554
+
1555
+ # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
1556
+ # Note that we do not use a separate target network for LCM-LoRA distillation.
1557
+ with torch.no_grad():
1558
+ uncond_encoded_text["image_embeds"] = image_embeds
1559
+ target_added_cond = dict()
1560
+ for k,v in uncond_encoded_text.items():
1561
+ if isinstance(v, torch.Tensor):
1562
+ target_added_cond[k] = v.to(weight_dtype)
1563
+ else:
1564
+ target_image_embeds = []
1565
+ for img_emb in v:
1566
+ target_image_embeds.append(img_emb.to(weight_dtype))
1567
+ target_added_cond[k] = target_image_embeds
1568
+ target_noise_pred = unet(
1569
+ x_prev,
1570
+ timesteps,
1571
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
1572
+ added_cond_kwargs=target_added_cond,
1573
+ ).sample
1574
+ pred_x_0 = get_predicted_original_sample(
1575
+ target_noise_pred,
1576
+ timesteps,
1577
+ x_prev,
1578
+ noise_scheduler.config.prediction_type,
1579
+ alpha_schedule,
1580
+ sigma_schedule,
1581
+ )
1582
+ target = c_skip * x_prev + c_out * pred_x_0
1583
+
1584
+ # 10. Calculate loss
1585
+ lcm_loss_arguments = {
1586
+ "target": target.float(),
1587
+ "predict": model_pred.float(),
1588
+ }
1589
+ loss_dict = dict()
1590
+ # total_loss = total_loss + torch.mean(
1591
+ # torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
1592
+ # )
1593
+ # loss_dict["L2Loss"] = total_loss.item()
1594
+ for loss_config in lcm_losses:
1595
+ if loss_config.loss.__class__.__name__=="DINOLoss":
1596
+ with torch.no_grad():
1597
+ pixel_target = []
1598
+ latent_target = target.to(dtype=vae.dtype)
1599
+ for i in range(0, latent_target.shape[0], args.vae_encode_batch_size):
1600
+ pixel_target.append(
1601
+ vae.decode(
1602
+ latent_target[i : i + args.vae_encode_batch_size] / vae.config.scaling_factor,
1603
+ return_dict=False
1604
+ )[0]
1605
+ )
1606
+ pixel_target = torch.cat(pixel_target, dim=0)
1607
+ pixel_pred = []
1608
+ latent_pred = model_pred.to(dtype=vae.dtype)
1609
+ for i in range(0, latent_pred.shape[0], args.vae_encode_batch_size):
1610
+ pixel_pred.append(
1611
+ vae.decode(
1612
+ latent_pred[i : i + args.vae_encode_batch_size] / vae.config.scaling_factor,
1613
+ return_dict=False
1614
+ )[0]
1615
+ )
1616
+ pixel_pred = torch.cat(pixel_pred, dim=0)
1617
+ dino_loss_arguments = {
1618
+ "target": pixel_target,
1619
+ "predict": pixel_pred,
1620
+ }
1621
+ non_weighted_loss = loss_config.loss(**dino_loss_arguments, accelerator=accelerator)
1622
+ loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item()
1623
+ total_loss = total_loss + non_weighted_loss * loss_config.weight
1624
+ else:
1625
+ non_weighted_loss = loss_config.loss(**lcm_loss_arguments, accelerator=accelerator)
1626
+ total_loss = total_loss + non_weighted_loss * loss_config.weight
1627
+ loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item()
1628
+
1629
+ # 11. Backpropagate on the online student model (`unet`) (only LoRA)
1630
+ accelerator.backward(total_loss)
1631
+ if accelerator.sync_gradients:
1632
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
1633
+ optimizer.step()
1634
+ lr_scheduler.step()
1635
+ optimizer.zero_grad(set_to_none=True)
1636
+
1637
+ # Checks if the accelerator has performed an optimization step behind the scenes
1638
+ if accelerator.sync_gradients:
1639
+ progress_bar.update(1)
1640
+ global_step += 1
1641
+
1642
+ if accelerator.is_main_process:
1643
+ if global_step % args.checkpointing_steps == 0:
1644
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1645
+ if args.checkpoints_total_limit is not None:
1646
+ checkpoints = os.listdir(args.output_dir)
1647
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1648
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1649
+
1650
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1651
+ if len(checkpoints) >= args.checkpoints_total_limit:
1652
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1653
+ removing_checkpoints = checkpoints[0:num_to_remove]
1654
+
1655
+ logger.info(
1656
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1657
+ )
1658
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1659
+
1660
+ for removing_checkpoint in removing_checkpoints:
1661
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1662
+ shutil.rmtree(removing_checkpoint)
1663
+
1664
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1665
+ accelerator.save_state(save_path)
1666
+ logger.info(f"Saved state to {save_path}")
1667
+
1668
+ if global_step % args.validation_steps == 0:
1669
+ out_images = log_validation(unwrap_model(unet), vae, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two,
1670
+ lcm_scheduler, image_encoder, image_processor,
1671
+ args, accelerator, weight_dtype, global_step, lq_img, gt_img, is_final_validation=False, log_local=False)
1672
+
1673
+ logs = dict()
1674
+ # logs.update({"loss": loss.detach().item()})
1675
+ logs.update(loss_dict)
1676
+ logs.update({"lr": lr_scheduler.get_last_lr()[0]})
1677
+ progress_bar.set_postfix(**logs)
1678
+ accelerator.log(logs, step=global_step)
1679
+
1680
+ if global_step >= args.max_train_steps:
1681
+ break
1682
+
1683
+ # Create the pipeline using using the trained modules and save it.
1684
+ accelerator.wait_for_everyone()
1685
+ if accelerator.is_main_process:
1686
+ unet = accelerator.unwrap_model(unet)
1687
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
1688
+ StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict)
1689
+
1690
+ if args.push_to_hub:
1691
+ upload_folder(
1692
+ repo_id=repo_id,
1693
+ folder_path=args.output_dir,
1694
+ commit_message="End of training",
1695
+ ignore_patterns=["step_*", "epoch_*"],
1696
+ )
1697
+
1698
+ del unet
1699
+ torch.cuda.empty_cache()
1700
+
1701
+ # Final inference.
1702
+ if args.validation_steps is not None:
1703
+ log_validation(unwrap_model(unet), vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
1704
+ lcm_scheduler, image_encoder=None, image_processor=None,
1705
+ args=args, accelerator=accelerator, weight_dtype=weight_dtype, step=0, is_final_validation=False, log_local=True)
1706
+
1707
+ accelerator.end_training()
1708
+
1709
+
1710
+ if __name__ == "__main__":
1711
+ args = parse_args()
1712
+ main(args)