SunderAli17 committed (verified)
Commit 1b9f95d · Parent(s): d440b15

Create train_stage2_aggregator.py

Files changed (1)
  1. functions/train_stage2_aggregator.py +1698 -0
functions/train_stage2_aggregator.py ADDED
@@ -0,0 +1,1698 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
+
16
+ import os
17
+ import argparse
18
+ import time
19
+ import gc
20
+ import logging
21
+ import math
22
+ import copy
23
+ import random
24
+ import yaml
25
+ import functools
26
+ import shutil
27
+ import pyrallis
28
+ from pathlib import Path
29
+ from collections import namedtuple, OrderedDict
30
+
31
+ import accelerate
32
+ import numpy as np
33
+ import torch
34
+ from safetensors import safe_open
35
+ import torch.nn.functional as F
36
+ import torch.utils.checkpoint
37
+ import transformers
38
+ from accelerate import Accelerator
39
+ from accelerate.logging import get_logger
40
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
41
+ from datasets import load_dataset
42
+ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
43
+ from huggingface_hub import create_repo, upload_folder
44
+ from packaging import version
45
+ from PIL import Image
46
+ from data.data_config import DataConfig
47
+ from basicsr.utils.degradation_pipeline import RealESRGANDegradation
48
+ from losses.loss_config import LossesConfig
49
+ from losses.losses import *
50
+ from torchvision import transforms
51
+ from torchvision.transforms.functional import crop
52
+ from tqdm.auto import tqdm
53
+ from transformers import (
54
+ AutoTokenizer,
55
+ PretrainedConfig,
56
+ CLIPImageProcessor, CLIPVisionModelWithProjection,
57
+ AutoImageProcessor, AutoModel
58
+ )
59
+
60
+ import diffusers
61
+ from diffusers import (
62
+ AutoencoderKL,
63
+ DDPMScheduler,
64
+ StableDiffusionXLPipeline,
65
+ UNet2DConditionModel,
66
+ )
67
+ from diffusers.optimization import get_scheduler
68
+ from diffusers.utils import (
69
+ check_min_version,
70
+ convert_unet_state_dict_to_peft,
71
+ is_wandb_available,
72
+ )
73
+ from diffusers.utils.import_utils import is_xformers_available
74
+ from diffusers.utils.torch_utils import is_compiled_module
75
+
76
+ from module.aggregator import Aggregator
77
+ from pipelines.lcm_single_step_scheduler import LCMSingleStepScheduler
78
+ from module.ip_adapter.ip_adapter import MultiIPAdapterImageProjection
79
+ from module.ip_adapter.resampler import Resampler
80
+ from module.ip_adapter.utils import init_adapter_in_unet, prepare_training_image_embeds
81
+ from module.ip_adapter.attention_processor import init_attn_proc
82
+ from utils.train_utils import (
83
+ seperate_ip_params_from_unet,
84
+ import_model_class_from_model_name_or_path,
85
+ tensor_to_pil,
86
+ get_train_dataset, prepare_train_dataset, collate_fn,
87
+ encode_prompt, importance_sampling_fn, extract_into_tensor
88
+ )
89
+ from pipelines.sdxl_SAKBIR import SAKBIRPipeline
90
+
91
+
92
+ if is_wandb_available():
93
+ import wandb
94
+
95
+
96
+ logger = get_logger(__name__)
97
+
98
+
99
+ def log_validation(unet, aggregator, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
100
+ scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline,
101
+ args, accelerator, weight_dtype, step, lq_img=None, gt_img=None, is_final_validation=False, log_local=False):
102
+ logger.info("Running validation... ")
103
+
104
+ image_logs = []
105
+
106
+ # validation_batch = batchify_pil(args.validation_image, args.validation_prompt, deg_pipeline, image_processor)
107
+ lq = [Image.open(lq_example).convert("RGB") for lq_example in args.validation_image]
108
+
109
+ pipe = SAKBIRPipeline(
110
+ vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
111
+ unet, scheduler, aggregator, feature_extractor=image_processor, image_encoder=image_encoder,
112
+ ).to(accelerator.device)
113
+
114
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
115
+ if lq_img is not None and gt_img is not None:
116
+ lq_img = lq_img[:len(args.validation_image)]
117
+ lq_pt = image_processor(
118
+ images=lq_img*0.5+0.5,
119
+ do_rescale=False, return_tensors="pt"
120
+ ).pixel_values
121
+ image = pipe(
122
+ prompt=[""]*len(lq_img),
123
+ image=lq_img,
124
+ ip_adapter_image=lq_pt,
125
+ num_inference_steps=20,
126
+ generator=generator,
127
+ controlnet_conditioning_scale=1.0,
128
+ negative_prompt=[""]*len(lq),
129
+ guidance_scale=5.0,
130
+ height=args.resolution,
131
+ width=args.resolution,
132
+ lcm_scheduler=lcm_scheduler,
133
+ ).images
134
+ else:
135
+ image = pipe(
136
+ prompt=[""]*len(lq),
137
+ image=lq,
138
+ ip_adapter_image=lq,
139
+ num_inference_steps=20,
140
+ generator=generator,
141
+ controlnet_conditioning_scale=1.0,
142
+ negative_prompt=[""]*len(lq),
143
+ guidance_scale=5.0,
144
+ height=args.resolution,
145
+ width=args.resolution,
146
+ lcm_scheduler=lcm_scheduler,
147
+ ).images
148
+
149
+ if log_local:
150
+ for i, rec_image in enumerate(image):
151
+ rec_image.save(f"./instantid_{i}.png")
152
+ return
153
+
154
+ tracker_key = "test" if is_final_validation else "validation"
155
+ for tracker in accelerator.trackers:
156
+ if tracker.name == "tensorboard":
157
+ images = [np.asarray(pil_img) for pil_img in image]
158
+ images = np.stack(images, axis=0)
159
+ if lq_img is not None and gt_img is not None:
160
+ input_lq = lq_img.cpu()
161
+ input_lq = np.asarray(input_lq.add(1).div(2).clamp(0, 1))
162
+ input_gt = gt_img.cpu()
163
+ input_gt = np.asarray(input_gt.add(1).div(2).clamp(0, 1))
164
+ tracker.writer.add_images("lq", input_lq, step, dataformats="NCHW")
165
+ tracker.writer.add_images("gt", input_gt, step, dataformats="NCHW")
166
+ tracker.writer.add_images("rec", images, step, dataformats="NHWC")
167
+ elif tracker.name == "wandb":
168
+ raise NotImplementedError("Wandb logging not implemented for validation.")
169
+ formatted_images = []
170
+
171
+ for log in image_logs:
172
+ images = log["images"]
173
+ validation_prompt = log["validation_prompt"]
174
+ validation_image = log["validation_image"]
175
+
176
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
177
+
178
+ for image in images:
179
+ image = wandb.Image(image, caption=validation_prompt)
180
+ formatted_images.append(image)
181
+
182
+ tracker.log({tracker_key: formatted_images})
183
+ else:
184
+ logger.warning(f"image logging not implemented for {tracker.name}")
185
+
186
+ gc.collect()
187
+ torch.cuda.empty_cache()
188
+
189
+ return image_logs
190
+
191
+
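+ # Strip the text cross-attention layers from the aggregator: recursively walk the down/mid/up
+ # blocks (skipping resnets) and set each transformer block's `attn2` and `norm2` to None.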
192
+ def remove_attn2(model):
193
+ def recursive_find_module(name, module):
194
+ if not "up_blocks" in name and not "down_blocks" in name and not "mid_block" in name: return
195
+ elif "resnets" in name: return
196
+ if hasattr(module, "attn2"):
197
+ setattr(module, "attn2", None)
198
+ setattr(module, "norm2", None)
199
+ return
200
+ for sub_name, sub_module in module.named_children():
201
+ recursive_find_module(f"{name}.{sub_name}", sub_module)
202
+
203
+ for name, module in model.named_children():
204
+ recursive_find_module(name, module)
205
+
206
+
207
+ def parse_args(input_args=None):
208
+ parser = argparse.ArgumentParser(description="Simple example of an IP-Adapter training script.")
209
+ parser.add_argument(
210
+ "--pretrained_model_name_or_path",
211
+ type=str,
212
+ default=None,
213
+ required=True,
214
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
215
+ )
216
+ parser.add_argument(
217
+ "--pretrained_vae_model_name_or_path",
218
+ type=str,
219
+ default=None,
220
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
221
+ )
222
+ parser.add_argument(
223
+ "--controlnet_model_name_or_path",
224
+ type=str,
225
+ default=None,
226
+ help="Path to an pretrained controlnet model like tile-controlnet.",
227
+ )
228
+ parser.add_argument(
229
+ "--use_lcm",
230
+ action="store_true",
231
+ help="Whether or not to use lcm unet.",
232
+ )
233
+ parser.add_argument(
234
+ "--pretrained_lcm_lora_path",
235
+ type=str,
236
+ default=None,
237
+ help="Path to LCM lora or model identifier from huggingface.co/models.",
238
+ )
239
+ parser.add_argument(
240
+ "--lora_rank",
241
+ type=int,
242
+ default=64,
243
+ help="The rank of the LoRA projection matrix.",
244
+ )
245
+ parser.add_argument(
246
+ "--lora_alpha",
247
+ type=int,
248
+ default=64,
249
+ help=(
250
+ "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
251
+ " update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
252
+ ),
253
+ )
254
+ parser.add_argument(
255
+ "--lora_dropout",
256
+ type=float,
257
+ default=0.0,
258
+ help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
259
+ )
260
+ parser.add_argument(
261
+ "--lora_target_modules",
262
+ type=str,
263
+ default=None,
264
+ help=(
265
+ "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
266
+ " be used. By default, LoRA will be applied to all conv and linear layers."
267
+ ),
268
+ )
269
+ parser.add_argument(
270
+ "--feature_extractor_path",
271
+ type=str,
272
+ default=None,
273
+ help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.",
274
+ )
275
+ parser.add_argument(
276
+ "--pretrained_adapter_model_path",
277
+ type=str,
278
+ default=None,
279
+ help="Path to IP-Adapter models or model identifier from huggingface.co/models.",
280
+ )
281
+ parser.add_argument(
282
+ "--adapter_tokens",
283
+ type=int,
284
+ default=64,
285
+ help="Number of tokens to use in IP-adapter cross attention mechanism.",
286
+ )
287
+ parser.add_argument(
288
+ "--aggregator_adapter",
289
+ action="store_true",
290
+ help="Whether or not to add adapter on aggregator.",
291
+ )
292
+ parser.add_argument(
293
+ "--optimize_adapter",
294
+ action="store_true",
295
+ help="Whether or not to optimize IP-Adapter.",
296
+ )
297
+ parser.add_argument(
298
+ "--image_encoder_hidden_feature",
299
+ action="store_true",
300
+ help="Whether or not to use the penultimate hidden states as image embeddings.",
301
+ )
302
+ parser.add_argument(
303
+ "--losses_config_path",
304
+ type=str,
305
+ required=True,
306
+ help=("A yaml file containing losses to use and their weights."),
307
+ )
308
+ parser.add_argument(
309
+ "--data_config_path",
310
+ type=str,
311
+ default=None,
312
+ help=("A folder containing the training data. "),
313
+ )
314
+ parser.add_argument(
315
+ "--variant",
316
+ type=str,
317
+ default=None,
318
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
319
+ )
320
+ parser.add_argument(
321
+ "--revision",
322
+ type=str,
323
+ default=None,
324
+ required=False,
325
+ help="Revision of pretrained model identifier from huggingface.co/models.",
326
+ )
327
+ parser.add_argument(
328
+ "--tokenizer_name",
329
+ type=str,
330
+ default=None,
331
+ help="Pretrained tokenizer name or path if not the same as model_name",
332
+ )
333
+ parser.add_argument(
334
+ "--output_dir",
335
+ type=str,
336
+ default="stage1_model",
337
+ help="The output directory where the model predictions and checkpoints will be written.",
338
+ )
339
+ parser.add_argument(
340
+ "--cache_dir",
341
+ type=str,
342
+ default=None,
343
+ help="The directory where the downloaded models and datasets will be stored.",
344
+ )
345
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
346
+ parser.add_argument(
347
+ "--resolution",
348
+ type=int,
349
+ default=512,
350
+ help=(
351
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
352
+ " resolution"
353
+ ),
354
+ )
355
+ parser.add_argument(
356
+ "--crops_coords_top_left_h",
357
+ type=int,
358
+ default=0,
359
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
360
+ )
361
+ parser.add_argument(
362
+ "--crops_coords_top_left_w",
363
+ type=int,
364
+ default=0,
365
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
366
+ )
367
+ parser.add_argument(
368
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
369
+ )
370
+ parser.add_argument("--num_train_epochs", type=int, default=1)
371
+ parser.add_argument(
372
+ "--max_train_steps",
373
+ type=int,
374
+ default=None,
375
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
376
+ )
377
+ parser.add_argument(
378
+ "--checkpointing_steps",
379
+ type=int,
380
+ default=3000,
381
+ help=(
382
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
383
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
384
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
385
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
386
+ "instructions."
387
+ ),
388
+ )
389
+ parser.add_argument(
390
+ "--checkpoints_total_limit",
391
+ type=int,
392
+ default=5,
393
+ help=("Max number of checkpoints to store."),
394
+ )
395
+ parser.add_argument(
396
+ "--previous_ckpt",
397
+ type=str,
398
+ default=None,
399
+ help=(
400
+ "Whether training should be initialized from a previous checkpoint."
401
+ ),
402
+ )
403
+ parser.add_argument(
404
+ "--resume_from_checkpoint",
405
+ type=str,
406
+ default=None,
407
+ help=(
408
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
409
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
410
+ ),
411
+ )
412
+ parser.add_argument(
413
+ "--gradient_accumulation_steps",
414
+ type=int,
415
+ default=1,
416
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
417
+ )
418
+ parser.add_argument(
419
+ "--gradient_checkpointing",
420
+ action="store_true",
421
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
422
+ )
423
+ parser.add_argument(
424
+ "--save_only_adapter",
425
+ action="store_true",
426
+ help="Only save extra adapter to save space.",
427
+ )
428
+ parser.add_argument(
429
+ "--cache_prompt_embeds",
430
+ action="store_true",
431
+ help="Whether or not to cache prompt embeds to save memory.",
432
+ )
433
+ parser.add_argument(
434
+ "--importance_sampling",
435
+ action="store_true",
436
+ help="Whether or not to use importance sampling.",
437
+ )
438
+ parser.add_argument(
439
+ "--CFG_scale",
440
+ type=float,
441
+ default=1.0,
442
+ help="CFG for previewer.",
443
+ )
444
+ parser.add_argument(
445
+ "--learning_rate",
446
+ type=float,
447
+ default=1e-4,
448
+ help="Initial learning rate (after the potential warmup period) to use.",
449
+ )
450
+ parser.add_argument(
451
+ "--scale_lr",
452
+ action="store_true",
453
+ default=False,
454
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
455
+ )
456
+ parser.add_argument(
457
+ "--lr_scheduler",
458
+ type=str,
459
+ default="constant",
460
+ help=(
461
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
462
+ ' "constant", "constant_with_warmup"]'
463
+ ),
464
+ )
465
+ parser.add_argument(
466
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
467
+ )
468
+ parser.add_argument(
469
+ "--lr_num_cycles",
470
+ type=int,
471
+ default=1,
472
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
473
+ )
474
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
475
+ parser.add_argument(
476
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
477
+ )
478
+ parser.add_argument(
479
+ "--dataloader_num_workers",
480
+ type=int,
481
+ default=0,
482
+ help=(
483
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
484
+ ),
485
+ )
486
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
487
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
488
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
489
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
490
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
491
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
492
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
493
+ parser.add_argument(
494
+ "--hub_model_id",
495
+ type=str,
496
+ default=None,
497
+ help="The name of the repository to keep in sync with the local `output_dir`.",
498
+ )
499
+ parser.add_argument(
500
+ "--logging_dir",
501
+ type=str,
502
+ default="logs",
503
+ help=(
504
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
505
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
506
+ ),
507
+ )
508
+ parser.add_argument(
509
+ "--allow_tf32",
510
+ action="store_true",
511
+ help=(
512
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
513
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
514
+ ),
515
+ )
516
+ parser.add_argument(
517
+ "--report_to",
518
+ type=str,
519
+ default="tensorboard",
520
+ help=(
521
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
522
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
523
+ ),
524
+ )
525
+ parser.add_argument(
526
+ "--mixed_precision",
527
+ type=str,
528
+ default=None,
529
+ choices=["no", "fp16", "bf16"],
530
+ help=(
531
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
532
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
533
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
534
+ ),
535
+ )
536
+ parser.add_argument(
537
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
538
+ )
539
+ parser.add_argument(
540
+ "--set_grads_to_none",
541
+ action="store_true",
542
+ help=(
543
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
544
+ " behaviors, so disable this argument if it causes any problems. More info:"
545
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
546
+ ),
547
+ )
548
+ parser.add_argument(
549
+ "--dataset_name",
550
+ type=str,
551
+ default=None,
552
+ help=(
553
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
554
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
555
+ " or to a folder containing files that 🤗 Datasets can understand."
556
+ ),
557
+ )
558
+ parser.add_argument(
559
+ "--dataset_config_name",
560
+ type=str,
561
+ default=None,
562
+ help="The config of the Dataset, leave as None if there's only one config.",
563
+ )
564
+ parser.add_argument(
565
+ "--train_data_dir",
566
+ type=str,
567
+ default=None,
568
+ help=(
569
+ "A folder containing the training data. Folder contents must follow the structure described in"
570
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
571
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
572
+ ),
573
+ )
574
+ parser.add_argument(
575
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
576
+ )
577
+ parser.add_argument(
578
+ "--conditioning_image_column",
579
+ type=str,
580
+ default="conditioning_image",
581
+ help="The column of the dataset containing the controlnet conditioning image.",
582
+ )
583
+ parser.add_argument(
584
+ "--caption_column",
585
+ type=str,
586
+ default="text",
587
+ help="The column of the dataset containing a caption or a list of captions.",
588
+ )
589
+ parser.add_argument(
590
+ "--max_train_samples",
591
+ type=int,
592
+ default=None,
593
+ help=(
594
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
595
+ "value if set."
596
+ ),
597
+ )
598
+ parser.add_argument(
599
+ "--text_drop_rate",
600
+ type=float,
601
+ default=0,
602
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
603
+ )
604
+ parser.add_argument(
605
+ "--image_drop_rate",
606
+ type=float,
607
+ default=0,
608
+ help="Proportion of IP-Adapter inputs to be dropped. Defaults to 0 (no drop-out).",
609
+ )
610
+ parser.add_argument(
611
+ "--cond_drop_rate",
612
+ type=float,
613
+ default=0,
614
+ help="Proportion of all conditions to be dropped. Defaults to 0 (no drop-out).",
615
+ )
616
+ parser.add_argument(
617
+ "--use_ema_adapter",
618
+ action="store_true",
619
+ help=(
620
+ "use ema ip-adapter for LCM preview"
621
+ ),
622
+ )
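+ # NOTE (hedged completion): the code below reads args.use_clip_encoder, args.ema_update_steps
+ # and args.ema_halflife_steps, but the original diff never defines them. The definitions here
+ # are assumptions (names taken from their later use; the defaults are guesses to adjust as needed).
+ parser.add_argument(
+ "--use_clip_encoder",
+ action="store_true",
+ help="Use CLIPVisionModelWithProjection/CLIPImageProcessor as the image encoder; otherwise AutoModel/AutoImageProcessor are loaded from `--feature_extractor_path`.",
+ )
+ parser.add_argument(
+ "--ema_update_steps",
+ type=int,
+ default=1,
+ help="Number of optimizer steps between EMA adapter updates (assumed default).",
+ )
+ parser.add_argument(
+ "--ema_halflife_steps",
+ type=float,
+ default=500,
+ help="Half-life (in update steps) used to derive the EMA decay `ema_beta` (assumed default).",
+ )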
623
+ parser.add_argument(
624
+ "--sanity_check",
625
+ action="store_true",
626
+ help=(
627
+ "sanity check"
628
+ ),
629
+ )
630
+ parser.add_argument(
631
+ "--validation_prompt",
632
+ type=str,
633
+ default=None,
634
+ nargs="+",
635
+ help=(
636
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
637
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
638
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
639
+ ),
640
+ )
641
+ parser.add_argument(
642
+ "--validation_image",
643
+ type=str,
644
+ default=None,
645
+ nargs="+",
646
+ help=(
647
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
648
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
649
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
650
+ " `--validation_image` that will be used with all `--validation_prompt`s."
651
+ ),
652
+ )
653
+ parser.add_argument(
654
+ "--num_validation_images",
655
+ type=int,
656
+ default=4,
657
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
658
+ )
659
+ parser.add_argument(
660
+ "--validation_steps",
661
+ type=int,
662
+ default=4000,
663
+ help=(
664
+ "Run validation every X steps. Validation consists of running the prompt"
665
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
666
+ " and logging the images."
667
+ ),
668
+ )
669
+ parser.add_argument(
670
+ "--tracker_project_name",
671
+ type=str,
672
+ default='train',
673
+ help=(
674
+ "The `project_name` argument passed to Accelerator.init_trackers for"
675
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
676
+ ),
677
+ )
678
+
679
+ if input_args is not None:
680
+ args = parser.parse_args(input_args)
681
+ else:
682
+ args = parser.parse_args()
683
+
684
+ if not args.sanity_check and args.dataset_name is None and args.train_data_dir is None and args.data_config_path is None:
685
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
686
+
687
+ if args.dataset_name is not None and args.train_data_dir is not None:
688
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
689
+
690
+ if args.text_drop_rate < 0 or args.text_drop_rate > 1:
691
+ raise ValueError("`--text_drop_rate` must be in the range [0, 1].")
692
+
693
+ if args.validation_prompt is not None and args.validation_image is None:
694
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
695
+
696
+ if args.validation_prompt is None and args.validation_image is not None:
697
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
698
+
699
+ if (
700
+ args.validation_image is not None
701
+ and args.validation_prompt is not None
702
+ and len(args.validation_image) != 1
703
+ and len(args.validation_prompt) != 1
704
+ and len(args.validation_image) != len(args.validation_prompt)
705
+ ):
706
+ raise ValueError(
707
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
708
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
709
+ )
710
+
711
+ if args.resolution % 8 != 0:
712
+ raise ValueError(
713
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
714
+ )
715
+
716
+ return args
717
+
718
+
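+ # In-place EMA update: ema <- ema_beta * ema + (1 - ema_beta) * model.
+ # torch.Tensor.lerp(end, weight) returns self + weight * (end - self), applied here with
+ # self = detached model param, end = EMA param, weight = ema_beta.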
719
+ def update_ema_model(ema_model, model, ema_beta):
720
+ for ema_param, param in zip(ema_model.parameters(), model.parameters()):
721
+ ema_param.copy_(param.detach().lerp(ema_param, ema_beta))
722
+
723
+
724
+ def copy_dict(dict):
725
+ new_dict = {}
726
+ for key, value in dict.items():
727
+ new_dict[key] = value
728
+ return new_dict
729
+
730
+
731
+ def main(args):
732
+ if args.report_to == "wandb" and args.hub_token is not None:
733
+ raise ValueError(
734
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
735
+ " Please use `huggingface-cli login` to authenticate with the Hub."
736
+ )
737
+
738
+ logging_dir = Path(args.output_dir, args.logging_dir)
739
+
740
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
741
+ # due to pytorch#99272, MPS does not yet support bfloat16.
742
+ raise ValueError(
743
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
744
+ )
745
+
746
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
747
+
748
+ accelerator = Accelerator(
749
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
750
+ mixed_precision=args.mixed_precision,
751
+ log_with=args.report_to,
752
+ project_config=accelerator_project_config,
753
+ )
754
+
755
+ # Make one log on every process with the configuration for debugging.
756
+ logging.basicConfig(
757
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
758
+ datefmt="%m/%d/%Y %H:%M:%S",
759
+ level=logging.INFO,
760
+ )
761
+ logger.info(accelerator.state, main_process_only=False)
762
+ if accelerator.is_local_main_process:
763
+ transformers.utils.logging.set_verbosity_warning()
764
+ diffusers.utils.logging.set_verbosity_info()
765
+ else:
766
+ transformers.utils.logging.set_verbosity_error()
767
+ diffusers.utils.logging.set_verbosity_error()
768
+
769
+ # If passed along, set the training seed now.
770
+ if args.seed is not None:
771
+ set_seed(args.seed)
772
+
773
+ # Handle the repository creation.
774
+ if accelerator.is_main_process:
775
+ if args.output_dir is not None:
776
+ os.makedirs(args.output_dir, exist_ok=True)
777
+
778
+ # Load scheduler and models
779
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
780
+
781
+ # Importance sampling.
782
+ list_of_candidates = np.arange(noise_scheduler.config.num_train_timesteps, dtype='float64')
783
+ prob_dist = importance_sampling_fn(list_of_candidates, noise_scheduler.config.num_train_timesteps, 0.5)
784
+ importance_ratio = prob_dist / prob_dist.sum() * noise_scheduler.config.num_train_timesteps
785
+ importance_ratio = torch.from_numpy(importance_ratio.copy()).float()
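+ # importance_ratio is the sampling distribution rescaled to have mean 1 over the
+ # num_train_timesteps, so it can serve as a unit-mean per-timestep weight.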
786
+
787
+ # Load the tokenizers
788
+ tokenizer = AutoTokenizer.from_pretrained(
789
+ args.pretrained_model_name_or_path,
790
+ subfolder="tokenizer",
791
+ revision=args.revision,
792
+ use_fast=False,
793
+ )
794
+ tokenizer_2 = AutoTokenizer.from_pretrained(
795
+ args.pretrained_model_name_or_path,
796
+ subfolder="tokenizer_2",
797
+ revision=args.revision,
798
+ use_fast=False,
799
+ )
800
+
801
+ # Text encoder and image encoder.
802
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
803
+ args.pretrained_model_name_or_path, args.revision
804
+ )
805
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
806
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
807
+ )
808
+ text_encoder = text_encoder_cls_one.from_pretrained(
809
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
810
+ )
811
+ text_encoder_2 = text_encoder_cls_two.from_pretrained(
812
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
813
+ )
814
+
815
+ # Image processor and image encoder.
816
+ if args.use_clip_encoder:
817
+ image_processor = CLIPImageProcessor()
818
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor_path)
819
+ else:
820
+ image_processor = AutoImageProcessor.from_pretrained(args.feature_extractor_path)
821
+ image_encoder = AutoModel.from_pretrained(args.feature_extractor_path)
822
+
823
+ # VAE.
824
+ vae_path = (
825
+ args.pretrained_model_name_or_path
826
+ if args.pretrained_vae_model_name_or_path is None
827
+ else args.pretrained_vae_model_name_or_path
828
+ )
829
+ vae = AutoencoderKL.from_pretrained(
830
+ vae_path,
831
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
832
+ revision=args.revision,
833
+ variant=args.variant,
834
+ )
835
+
836
+ # UNet.
837
+ unet = UNet2DConditionModel.from_pretrained(
838
+ args.pretrained_model_name_or_path,
839
+ subfolder="unet",
840
+ revision=args.revision,
841
+ variant=args.variant
842
+ )
843
+
844
+ # Aggregator.
845
+ aggregator = Aggregator.from_unet(unet)
846
+ remove_attn2(aggregator)
847
+ if args.controlnet_model_name_or_path:
848
+ logger.info("Loading existing controlnet weights")
849
+ if args.controlnet_model_name_or_path.endswith(".safetensors"):
850
+ pretrained_cn_state_dict = {}
851
+ with safe_open(args.controlnet_model_name_or_path, framework="pt", device='cpu') as f:
852
+ for key in f.keys():
853
+ pretrained_cn_state_dict[key] = f.get_tensor(key)
854
+ else:
855
+ pretrained_cn_state_dict = torch.load(os.path.join(args.controlnet_model_name_or_path, "aggregator_ckpt.pt"), map_location="cpu")
856
+ aggregator.load_state_dict(pretrained_cn_state_dict, strict=True)
857
+ else:
858
+ logger.info("Initializing aggregator weights from unet.")
859
+
860
+ # Create image embedding projector for IP-Adapters.
861
+ if args.pretrained_adapter_model_path is not None:
862
+ if args.pretrained_adapter_model_path.endswith(".safetensors"):
863
+ pretrained_adapter_state_dict = {"image_proj": {}, "ip_adapter": {}}
864
+ with safe_open(args.pretrained_adapter_model_path, framework="pt", device="cpu") as f:
865
+ for key in f.keys():
866
+ if key.startswith("image_proj."):
867
+ pretrained_adapter_state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
868
+ elif key.startswith("ip_adapter."):
869
+ pretrained_adapter_state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
870
+ else:
871
+ pretrained_adapter_state_dict = torch.load(args.pretrained_adapter_model_path, map_location="cpu")
872
+
873
+ # Image embedding Projector.
874
+ image_proj_model = Resampler(
875
+ dim=1280,
876
+ depth=4,
877
+ dim_head=64,
878
+ heads=20,
879
+ num_queries=args.adapter_tokens,
880
+ embedding_dim=image_encoder.config.hidden_size,
881
+ output_dim=unet.config.cross_attention_dim,
882
+ ff_mult=4
883
+ )
884
+
885
+ init_adapter_in_unet(
886
+ unet,
887
+ image_proj_model,
888
+ pretrained_adapter_state_dict,
889
+ adapter_tokens=args.adapter_tokens,
890
+ )
891
+
892
+ # EMA adapter for LCM preview.
893
+ if args.use_ema_adapter:
894
+ assert args.optimize_adapter, "No need for EMA with frozen adapter."
895
+ ema_image_proj_model = Resampler(
896
+ dim=1280,
897
+ depth=4,
898
+ dim_head=64,
899
+ heads=20,
900
+ num_queries=args.adapter_tokens,
901
+ embedding_dim=image_encoder.config.hidden_size,
902
+ output_dim=unet.config.cross_attention_dim,
903
+ ff_mult=4
904
+ )
905
+ orig_encoder_hid_proj = unet.encoder_hid_proj
906
+ ema_encoder_hid_proj = MultiIPAdapterImageProjection([ema_image_proj_model])
907
+ orig_attn_procs = unet.attn_processors
908
+ orig_attn_procs_list = torch.nn.ModuleList(orig_attn_procs.values())
909
+ ema_attn_procs = init_attn_proc(unet, args.adapter_tokens, True, True, False)
910
+ ema_attn_procs_list = torch.nn.ModuleList(ema_attn_procs.values())
911
+ ema_attn_procs_list.requires_grad_(False)
912
+ ema_encoder_hid_proj.requires_grad_(False)
913
+
914
+ # Initialize EMA state.
915
+ ema_beta = 0.5 ** (args.ema_update_steps / max(args.ema_halflife_steps, 1e-8))
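+ # With beta = 0.5 ** (update_interval / half_life), the contribution of the old EMA state
+ # halves every `args.ema_halflife_steps` optimization steps.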
916
+ logger.info(f"Using EMA with beta: {ema_beta}")
917
+ ema_encoder_hid_proj.load_state_dict(orig_encoder_hid_proj.state_dict())
918
+ ema_attn_procs_list.load_state_dict(orig_attn_procs_list.state_dict())
919
+
920
+ # Projector for aggregator.
921
+ if args.aggregator_adapter:
922
+ image_proj_model = Resampler(
923
+ dim=1280,
924
+ depth=4,
925
+ dim_head=64,
926
+ heads=20,
927
+ num_queries=args.adapter_tokens,
928
+ embedding_dim=image_encoder.config.hidden_size,
929
+ output_dim=unet.config.cross_attention_dim,
930
+ ff_mult=4
931
+ )
932
+
933
+ init_adapter_in_unet(
934
+ aggregator,
935
+ image_proj_model,
936
+ pretrained_adapter_state_dict,
937
+ adapter_tokens=args.adapter_tokens,
938
+ )
939
+ del pretrained_adapter_state_dict
940
+
941
+ # Load LCM LoRA into unet.
942
+ if args.pretrained_lcm_lora_path is not None:
943
+ lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(args.pretrained_lcm_lora_path)
944
+ unet_state_dict = {
945
+ f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
946
+ }
947
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
948
+ lora_state_dict = dict()
949
+ for k, v in unet_state_dict.items():
950
+ # IP-Adapter LoRA keys live on the attention processors, so remap "attn2" -> "attn2.processor".
+ if "ip" in k:
951
+ k = k.replace("attn2", "attn2.processor")
952
+ lora_state_dict[k] = v
955
+ if alpha_dict:
956
+ args.lora_alpha = next(iter(alpha_dict.values()))
957
+ else:
958
+ args.lora_alpha = 1
959
+ logger.info(f"Loaded LCM LoRA with alpha: {args.lora_alpha}")
960
+ # Create LoRA config, FIXME: now hard-coded.
961
+ lora_target_modules = [
962
+ "to_q",
963
+ "to_kv",
964
+ "0.to_out",
965
+ "attn1.to_k",
966
+ "attn1.to_v",
967
+ "to_k_ip",
968
+ "to_v_ip",
969
+ "ln_k_ip.linear",
970
+ "ln_v_ip.linear",
971
+ "to_out.0",
972
+ "proj_in",
973
+ "proj_out",
974
+ "ff.net.0.proj",
975
+ "ff.net.2",
976
+ "conv1",
977
+ "conv2",
978
+ "conv_shortcut",
979
+ "downsamplers.0.conv",
980
+ "upsamplers.0.conv",
981
+ "time_emb_proj",
982
+ ]
983
+ lora_config = LoraConfig(
984
+ r=args.lora_rank,
985
+ target_modules=lora_target_modules,
986
+ lora_alpha=args.lora_alpha,
987
+ lora_dropout=args.lora_dropout,
988
+ )
989
+
990
+ unet.add_adapter(lora_config)
991
+ if args.pretrained_lcm_lora_path is not None:
992
+ incompatible_keys = set_peft_model_state_dict(unet, lora_state_dict, adapter_name="default")
993
+ if incompatible_keys is not None:
994
+ # check only for unexpected keys
995
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
996
+ missing_keys = getattr(incompatible_keys, "missing_keys", None)
997
+ if unexpected_keys:
998
+ raise ValueError(
999
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
1000
+ f" {unexpected_keys}. "
1001
+ )
1002
+ for k in missing_keys:
1003
+ if "lora" in k:
1004
+ raise ValueError(
1005
+ f"Loading adapter weights from state_dict led to missing keys: {missing_keys}. "
1006
+ )
1007
+ lcm_scheduler = LCMSingleStepScheduler.from_config(noise_scheduler.config)
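+ # Single-step LCM scheduler built from the same noise-schedule config; it is passed to the
+ # validation pipeline as `lcm_scheduler` for the LCM preview path.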
1008
+
1009
+ # Initialize training state.
1010
+ vae.requires_grad_(False)
1011
+ image_encoder.requires_grad_(False)
1012
+ text_encoder.requires_grad_(False)
1013
+ text_encoder_2.requires_grad_(False)
1014
+ unet.requires_grad_(False)
1015
+ aggregator.requires_grad_(False)
1016
+
1017
+ def unwrap_model(model):
1018
+ model = accelerator.unwrap_model(model)
1019
+ model = model._orig_mod if is_compiled_module(model) else model
1020
+ return model
1021
+
1022
+ # `accelerate` 0.16.0 will have better support for customized saving
1023
+ if args.save_only_adapter:
1024
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
1025
+ def save_model_hook(models, weights, output_dir):
1026
+ if accelerator.is_main_process:
1027
+ for model in models:
1028
+ if isinstance(model, Aggregator):
1029
+ torch.save(model.state_dict(), os.path.join(output_dir, "aggregator_ckpt.pt"))
1030
+ weights.pop()
1031
+
1032
+ def load_model_hook(models, input_dir):
1033
+
1034
+ while len(models) > 0:
1035
+ # pop models so that they are not loaded again
1036
+ model = models.pop()
1037
+
1038
+ if isinstance(model, Aggregator):
1039
+ aggregator_state_dict = torch.load(os.path.join(input_dir, "aggregator_ckpt.pt"), map_location="cpu")
1040
+ model.load_state_dict(aggregator_state_dict)
1041
+
1042
+ accelerator.register_save_state_pre_hook(save_model_hook)
1043
+ accelerator.register_load_state_pre_hook(load_model_hook)
1044
+
1045
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
1046
+ # as these models are only used for inference, keeping weights in full precision is not required.
1047
+ weight_dtype = torch.float32
1048
+ if accelerator.mixed_precision == "fp16":
1049
+ weight_dtype = torch.float16
1050
+ elif accelerator.mixed_precision == "bf16":
1051
+ weight_dtype = torch.bfloat16
1052
+
1053
+ if args.enable_xformers_memory_efficient_attention:
1054
+ if is_xformers_available():
1055
+ import xformers
1056
+
1057
+ xformers_version = version.parse(xformers.__version__)
1058
+ if xformers_version == version.parse("0.0.16"):
1059
+ logger.warning(
1060
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
1061
+ )
1062
+ unet.enable_xformers_memory_efficient_attention()
1063
+ else:
1064
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
1065
+
1066
+ if args.gradient_checkpointing:
1067
+ aggregator.enable_gradient_checkpointing()
1068
+ unet.enable_gradient_checkpointing()
1069
+
1070
+ # Check that all trainable models are in full precision
1071
+ low_precision_error_string = (
1072
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
1073
+ " doing mixed precision training, copy of the weights should still be float32."
1074
+ )
1075
+
1076
+ if unwrap_model(aggregator).dtype != torch.float32:
1077
+ raise ValueError(
1078
+ f"aggregator loaded as datatype {unwrap_model(aggregator).dtype}. {low_precision_error_string}"
1079
+ )
1080
+
1081
+ # Enable TF32 for faster training on Ampere GPUs,
1082
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1083
+ if args.allow_tf32:
1084
+ torch.backends.cuda.matmul.allow_tf32 = True
1085
+
1086
+ if args.scale_lr:
1087
+ args.learning_rate = (
1088
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
1089
+ )
1090
+
1091
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
1092
+ if args.use_8bit_adam:
1093
+ try:
1094
+ import bitsandbytes as bnb
1095
+ except ImportError:
1096
+ raise ImportError(
1097
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
1098
+ )
1099
+
1100
+ optimizer_class = bnb.optim.AdamW8bit
1101
+ else:
1102
+ optimizer_class = torch.optim.AdamW
1103
+
1104
+ # Optimizer creation
1105
+ ip_params, non_ip_params = seperate_ip_params_from_unet(unet)
1106
+ if args.optimize_adapter:
1107
+ unet_params = ip_params
1108
+ unet_frozen_params = non_ip_params
1109
+ else: # freeze all unet params
1110
+ unet_params = []
1111
+ unet_frozen_params = ip_params + non_ip_params
1112
+ assert len(unet_frozen_params) == len(list(unet.parameters()))
1113
+ params_to_optimize = [p for p in aggregator.parameters()]
1114
+ params_to_optimize += unet_params
1115
+ optimizer = optimizer_class(
1116
+ params_to_optimize,
1117
+ lr=args.learning_rate,
1118
+ betas=(args.adam_beta1, args.adam_beta2),
1119
+ weight_decay=args.adam_weight_decay,
1120
+ eps=args.adam_epsilon,
1121
+ )
1122
+
1123
+ # Instantiate Loss.
1124
+ losses_configs: LossesConfig = pyrallis.load(LossesConfig, open(args.losses_config_path, "r"))
1125
+ diffusion_losses = list()
1126
+ lcm_losses = list()
1127
+ for loss_config in losses_configs.diffusion_losses:
1128
+ logger.info(f"Using diffusion loss: {loss_config.name}")
1129
+ loss = namedtuple("loss", ["loss", "weight"])
1130
+ diffusion_losses.append(
1131
+ loss(loss=eval(loss_config.name)(
1132
+ visualize_every_k=loss_config.visualize_every_k,
1133
+ dtype=weight_dtype,
1134
+ accelerator=accelerator,
1135
+ **loss_config.init_params), weight=loss_config.weight)
1136
+ )
1137
+ for loss_config in losses_configs.lcm_losses:
1138
+ logger.info(f"Using lcm loss: {loss_config.name}")
1139
+ loss = namedtuple("loss", ["loss", "weight"])
1140
+ loss_class = eval(loss_config.name)
1141
+ lcm_losses.append(loss(loss=loss_class(visualize_every_k=loss_config.visualize_every_k,
1142
+ dtype=weight_dtype,
1143
+ accelerator=accelerator,
1144
+ dino_model=image_encoder,
1145
+ dino_preprocess=image_processor,
1146
+ **loss_config.init_params), weight=loss_config.weight))
1147
+
1148
+ # SDXL additional condition that will be added to time embedding.
1149
+ def compute_time_ids(original_size, crops_coords_top_left):
1150
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
1151
+ target_size = (args.resolution, args.resolution)
1152
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
1153
+ add_time_ids = torch.tensor([add_time_ids])
1154
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
1155
+ return add_time_ids
1156
+
1157
+ # Text prompt embeddings.
1158
+ @torch.no_grad()
1159
+ def compute_embeddings(batch, text_encoders, tokenizers, proportion_empty_prompts=0.0, drop_idx=None, is_train=True):
1160
+ prompt_batch = batch[args.caption_column]
1161
+ if drop_idx is not None:
1162
+ for i in range(len(prompt_batch)):
1163
+ prompt_batch[i] = "" if drop_idx[i] else prompt_batch[i]
1164
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
1165
+ prompt_batch, text_encoders, tokenizers, is_train
1166
+ )
1167
+
1168
+ add_time_ids = torch.cat(
1169
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
1170
+ )
1171
+
1172
+ prompt_embeds = prompt_embeds.to(accelerator.device)
1173
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
1174
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
1175
+ unet_added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
1176
+
1177
+ return prompt_embeds, unet_added_cond_kwargs
1178
+
1179
+ @torch.no_grad()
1180
+ def get_added_cond(batch, prompt_embeds, pooled_prompt_embeds, proportion_empty_prompts=0.0, drop_idx=None, is_train=True):
1181
+
1182
+ add_time_ids = torch.cat(
1183
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
1184
+ )
1185
+
1186
+ prompt_embeds = prompt_embeds.to(accelerator.device)
1187
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
1188
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
1189
+ unet_added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
1190
+
1191
+ return prompt_embeds, unet_added_cond_kwargs
1192
+
1193
+ # Move pixels into latents.
1194
+ @torch.no_grad()
1195
+ def convert_to_latent(pixels):
1196
+ model_input = vae.encode(pixels).latent_dist.sample()
1197
+ model_input = model_input * vae.config.scaling_factor
1198
+ if args.pretrained_vae_model_name_or_path is None:
1199
+ model_input = model_input.to(weight_dtype)
1200
+ return model_input
1201
+
1202
+ # Helper functions for training loop.
1203
+ # if args.degradation_config_path is not None:
1204
+ # with open(args.degradation_config_path) as stream:
1205
+ # degradation_configs = yaml.safe_load(stream)
1206
+ # deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution)
1207
+ # else:
1208
+ deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution)
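+ # The Real-ESRGAN degradation pipeline synthesizes (LQ, GT) pairs on the fly from clean batch
+ # images and their per-sample blur/sinc kernels (see the sanity-check call below).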
1209
+ compute_embeddings_fn = functools.partial(
1210
+ compute_embeddings,
1211
+ text_encoders=[text_encoder, text_encoder_2],
1212
+ tokenizers=[tokenizer, tokenizer_2],
1213
+ is_train=True,
1214
+ )
1215
+
1216
+ datasets = []
1217
+ datasets_name = []
1218
+ datasets_weights = []
1219
+ if args.data_config_path is not None:
1220
+ data_config: DataConfig = pyrallis.load(DataConfig, open(args.data_config_path, "r"))
1221
+ for single_dataset in data_config.datasets:
1222
+ datasets_weights.append(single_dataset.dataset_weight)
1223
+ datasets_name.append(single_dataset.dataset_folder)
1224
+ dataset_dir = os.path.join(args.train_data_dir, single_dataset.dataset_folder)
1225
+ image_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator)
1226
+ image_dataset = prepare_train_dataset(image_dataset, accelerator, deg_pipeline)
1227
+ datasets.append(image_dataset)
1228
+ # TODO: Validation dataset
1229
+ if data_config.val_dataset is not None:
1230
+ val_dataset = get_train_dataset(dataset_name, dataset_dir, args, accelerator)
1231
+ logger.info(f"Datasets mixing: {list(zip(datasets_name, datasets_weights))}")
1232
+
1233
+ # Mix training datasets.
1234
+ sampler_train = None
1235
+ if len(datasets) == 1:
1236
+ train_dataset = datasets[0]
1237
+ else:
1238
+ # Weight each dataset by its configured weight.
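+ # Each sample is weighted by (len(concat dataset) / len(its dataset)) * dataset_weight, so every
+ # dataset contributes in proportion to its configured weight regardless of its size.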
1239
+ train_dataset = torch.utils.data.ConcatDataset(datasets)
1240
+ dataset_weights = []
1241
+ for single_dataset, single_weight in zip(datasets, datasets_weights):
1242
+ dataset_weights.extend([len(train_dataset) / len(single_dataset) * single_weight] * len(single_dataset))
1243
+ sampler_train = torch.utils.data.WeightedRandomSampler(
1244
+ weights=dataset_weights,
1245
+ num_samples=len(dataset_weights)
1246
+ )
1247
+
1248
+ train_dataloader = torch.utils.data.DataLoader(
1249
+ train_dataset,
1250
+ batch_size=args.train_batch_size,
1251
+ sampler=sampler_train,
1252
+ shuffle=True if sampler_train is None else False,
1253
+ collate_fn=collate_fn,
1254
+ num_workers=args.dataloader_num_workers
1255
+ )
1256
+
1257
+ # We need to initialize the trackers we use, and also store our configuration.
1258
+ # The trackers initialize automatically on the main process.
1259
+ if accelerator.is_main_process:
1260
+ tracker_config = dict(vars(args))
1261
+
1262
+ # tensorboard cannot handle list types for config
1263
+ tracker_config.pop("validation_prompt")
1264
+ tracker_config.pop("validation_image")
1265
+
1266
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
1267
+
1268
+ # Scheduler and math around the number of training steps.
1269
+ overrode_max_train_steps = False
1270
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1271
+ if args.max_train_steps is None:
1272
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1273
+ overrode_max_train_steps = True
1274
+
1275
+ lr_scheduler = get_scheduler(
1276
+ args.lr_scheduler,
1277
+ optimizer=optimizer,
1278
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1279
+ num_training_steps=args.max_train_steps,
1280
+ num_cycles=args.lr_num_cycles,
1281
+ power=args.lr_power,
1282
+ )
1283
+
1284
+ # Prepare everything with our `accelerator`.
1285
+ aggregator, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1286
+ aggregator, unet, optimizer, train_dataloader, lr_scheduler
1287
+ )
1288
+
1289
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
1290
+ text_encoder_2.to(accelerator.device, dtype=weight_dtype)
1291
+
1292
+ # # cache empty prompts and move text encoders to cpu
1293
+ # empty_prompt_embeds, empty_pooled_prompt_embeds = encode_prompt(
1294
+ # [""]*args.train_batch_size, [text_encoder, text_encoder_2], [tokenizer, tokenizer_2], True
1295
+ # )
1296
+ # compute_embeddings_fn = functools.partial(
1297
+ # get_added_cond,
1298
+ # prompt_embeds=empty_prompt_embeds,
1299
+ # pooled_prompt_embeds=empty_pooled_prompt_embeds,
1300
+ # is_train=True,
1301
+ # )
1302
+ # text_encoder.to("cpu")
1303
+ # text_encoder_2.to("cpu")
1304
+
1305
+ # Move vae, unet and text_encoder to device and cast to `weight_dtype`.
1306
+ if args.pretrained_vae_model_name_or_path is None:
1307
+ # The VAE is fp32 to avoid NaN losses.
1308
+ vae.to(accelerator.device, dtype=torch.float32)
1309
+ else:
1310
+ vae.to(accelerator.device, dtype=weight_dtype)
1311
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
1312
+ if args.use_ema_adapter:
1313
+ # FIXME: prepare ema model
1314
+ # ema_encoder_hid_proj, ema_attn_procs_list = accelerator.prepare(ema_encoder_hid_proj, ema_attn_procs_list)
1315
+ ema_encoder_hid_proj.to(accelerator.device)
1316
+ ema_attn_procs_list.to(accelerator.device)
1317
+ for param in unet_frozen_params:
1318
+ param.data = param.data.to(dtype=weight_dtype)
1319
+ for param in unet_params:
1320
+ param.requires_grad_(True)
1321
+ unet.to(accelerator.device)
1322
+ aggregator.requires_grad_(True)
1323
+ aggregator.to(accelerator.device)
1324
+ importance_ratio = importance_ratio.to(accelerator.device)
1325
+
1326
+ # Final sanity check.
1327
+ for n, p in unet.named_parameters():
1329
+ if p.requires_grad:
1330
+ assert p.dtype == torch.float32, n
1331
+ else:
1332
+ assert p.dtype == weight_dtype, n
1333
+ for n, p in aggregator.named_parameters():
1334
+ if p.requires_grad: assert p.dtype == torch.float32, n
1335
+ else:
1336
+ raise RuntimeError(f"All parameters in aggregator should require grad. {n}")
1337
+ assert p.dtype == weight_dtype, n
1338
+ unwrap_model(unet).disable_adapters()
1339
+
1340
+ if args.sanity_check:
1341
+ if args.resume_from_checkpoint:
1342
+ if args.resume_from_checkpoint != "latest":
1343
+ path = os.path.basename(args.resume_from_checkpoint)
1344
+ else:
1345
+ # Get the most recent checkpoint
1346
+ dirs = os.listdir(args.output_dir)
1347
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1348
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1349
+ path = dirs[-1] if len(dirs) > 0 else None
1350
+
1351
+ if path is None:
1352
+ accelerator.print(
1353
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1354
+ )
1355
+ args.resume_from_checkpoint = None
1356
+ initial_global_step = 0
1357
+ else:
1358
+ accelerator.print(f"Resuming from checkpoint {path}")
1359
+ accelerator.load_state(os.path.join(args.output_dir, path))
1360
+
1361
+ if args.use_ema_adapter:
1362
+ unwrap_model(unet).set_attn_processor(ema_attn_procs)
1363
+ unwrap_model(unet).encoder_hid_proj = ema_encoder_hid_proj
1364
+ batch = next(iter(train_dataloader))
1365
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
1366
+ log_validation(
1367
+ unwrap_model(unet), unwrap_model(aggregator), vae,
1368
+ text_encoder, text_encoder_2, tokenizer, tokenizer_2,
1369
+ noise_scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline,
1370
+ args, accelerator, weight_dtype, step=0, lq_img=lq_img, gt_img=gt_img, log_local=True
1371
+ )
1372
+ if args.use_ema_adapter:
1373
+ unwrap_model(unet).set_attn_processor(orig_attn_procs)
1374
+ unwrap_model(unet).encoder_hid_proj = orig_encoder_hid_proj
1375
+ for n, p in unet.named_parameters():
1376
+ if p.requires_grad: assert p.dtype == torch.float32, n
1377
+ else: assert p.dtype == weight_dtype, n
1378
+ for n, p in aggregator.named_parameters():
1379
+ if p.requires_grad: assert p.dtype == torch.float32, n
1380
+ else: assert p.dtype == weight_dtype, n
1381
+ print("PASS")
1382
+ exit()
1383
+
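The "latest checkpoint" lookup above appears twice (sanity check and resume). A minimal sketch of the same logic as a helper, assuming the `checkpoint-<step>` folder naming used by this script; the function name is illustrative.

import os

def find_latest_checkpoint(output_dir):
    # Returns the newest checkpoint-<step> directory, or None if there is none.
    dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    dirs = sorted(dirs, key=lambda d: int(d.split("-")[1]))
    return os.path.join(output_dir, dirs[-1]) if dirs else None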
1384
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1385
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1386
+ if overrode_max_train_steps:
1387
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1388
+ # Afterwards we recalculate our number of training epochs
1389
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1390
+
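A worked example of the bookkeeping above, using hypothetical numbers: 1000 batches per epoch with 4 gradient-accumulation steps gives 250 optimizer updates per epoch, so 10000 total optimization steps correspond to 40 epochs.

import math

num_batches_per_epoch, grad_accum_steps, max_train_steps = 1000, 4, 10_000  # hypothetical values
num_update_steps_per_epoch = math.ceil(num_batches_per_epoch / grad_accum_steps)   # 250
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)         # 40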
1391
+ # Train!
1392
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1393
+
1394
+ logger.info("***** Running training *****")
1395
+ logger.info(f" Num examples = {len(train_dataset)}")
1396
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1397
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1398
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1399
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1400
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1401
+ logger.info(f" Optimization steps per epoch = {num_update_steps_per_epoch}")
1402
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1403
+ global_step = 0
1404
+ first_epoch = 0
1405
+
1406
+ # Potentially load in the weights and states from a previous save
1407
+ if args.resume_from_checkpoint:
1408
+ if args.resume_from_checkpoint != "latest":
1409
+ path = os.path.basename(args.resume_from_checkpoint)
1410
+ else:
1411
+ # Get the most recent checkpoint
1412
+ dirs = os.listdir(args.output_dir)
1413
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1414
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1415
+ path = dirs[-1] if len(dirs) > 0 else None
1416
+
1417
+ if path is None:
1418
+ accelerator.print(
1419
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1420
+ )
1421
+ args.resume_from_checkpoint = None
1422
+ initial_global_step = 0
1423
+ else:
1424
+ accelerator.print(f"Resuming from checkpoint {path}")
1425
+ accelerator.load_state(os.path.join(args.output_dir, path))
1426
+ global_step = int(path.split("-")[1])
1427
+
1428
+ initial_global_step = global_step
1429
+ first_epoch = global_step // num_update_steps_per_epoch
1430
+ else:
1431
+ initial_global_step = 0
1432
+
1433
+ progress_bar = tqdm(
1434
+ range(0, args.max_train_steps),
1435
+ initial=initial_global_step,
1436
+ desc="Steps",
1437
+ # Only show the progress bar once on each machine.
1438
+ disable=not accelerator.is_local_main_process,
1439
+ )
1440
+
1441
+ trainable_models = [aggregator, unet]
1442
+
1443
+ if args.gradient_checkpointing:
1444
+ # TODO: add vae
1445
+ checkpoint_models = []
1446
+ else:
1447
+ checkpoint_models = []
1448
+
1449
+ image_logs = None
1450
+ tic = time.time()
1451
+ for epoch in range(first_epoch, args.num_train_epochs):
1452
+ for step, batch in enumerate(train_dataloader):
1453
+ toc = time.time()
1454
+ io_time = toc - tic
1455
+ tic = time.time()
1456
+ for model in trainable_models + checkpoint_models:
1457
+ model.train()
1458
+ with accelerator.accumulate(*trainable_models):
1459
+ loss = torch.tensor(0.0)
1460
+
1461
+ # Drop conditions.
1462
+ rand_tensor = torch.rand(batch["images"].shape[0])
1463
+ drop_image_idx = rand_tensor < args.image_drop_rate
1464
+ drop_text_idx = (rand_tensor >= args.image_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate)
1465
+ drop_both_idx = (rand_tensor >= args.image_drop_rate + args.text_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate + args.cond_drop_rate)
1466
+ drop_image_idx = drop_image_idx | drop_both_idx
1467
+ drop_text_idx = drop_text_idx | drop_both_idx
1468
+
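A minimal sketch of the condition-dropping rule above: a single uniform draw per sample is split into disjoint "drop image", "drop text" and "drop both" segments, so the three rates never overlap; the default rates here are placeholders, not values from this script.

import torch

def make_drop_masks(batch_size, image_drop_rate=0.05, text_drop_rate=0.05, cond_drop_rate=0.05):
    r = torch.rand(batch_size)
    drop_image = r < image_drop_rate
    drop_text = (r >= image_drop_rate) & (r < image_drop_rate + text_drop_rate)
    drop_both = (r >= image_drop_rate + text_drop_rate) & (
        r < image_drop_rate + text_drop_rate + cond_drop_rate
    )
    # Dropping both conditions implies dropping each one individually as well.
    return drop_image | drop_both, drop_text | drop_both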
1469
+ # Get LQ embeddings
1470
+ with torch.no_grad():
1471
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
1472
+ lq_pt = image_processor(
1473
+ images=lq_img*0.5+0.5,
1474
+ do_rescale=False, return_tensors="pt"
1475
+ ).pixel_values
1476
+
1477
+ # Move inputs to latent space.
1478
+ gt_img = gt_img.to(dtype=vae.dtype)
1479
+ lq_img = lq_img.to(dtype=vae.dtype)
1480
+ model_input = convert_to_latent(gt_img)
1481
+ lq_latent = convert_to_latent(lq_img)
1482
+ if args.pretrained_vae_model_name_or_path is None:
1483
+ model_input = model_input.to(weight_dtype)
1484
+ lq_latent = lq_latent.to(weight_dtype)
1485
+
1486
+ # Process conditions.
1487
+ image_embeds = prepare_training_image_embeds(
1488
+ image_encoder, image_processor,
1489
+ ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
1490
+ device=accelerator.device, drop_rate=args.image_drop_rate, output_hidden_state=args.image_encoder_hidden_feature,
1491
+ idx_to_replace=drop_image_idx
1492
+ )
1493
+ prompt_embeds_input, added_conditions = compute_embeddings_fn(batch, drop_idx=drop_text_idx)
1494
+
1495
+ # Sample noise that we'll add to the latents.
1496
+ noise = torch.randn_like(model_input)
1497
+ bsz = model_input.shape[0]
1498
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device)
1499
+
1500
+ # Add noise to the model input according to the noise magnitude at each timestep
1501
+ # (this is the forward diffusion process)
1502
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
1503
+ loss_weights = extract_into_tensor(importance_ratio, timesteps, noise.shape) if args.importance_sampling else None
1504
+
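A sketch of the per-timestep weighting that `extract_into_tensor` is assumed to perform here (the real helper is defined elsewhere in the repository): gather one weight per sampled timestep, then reshape so it broadcasts over the latent dimensions.

import torch

def gather_timestep_weights(weights_1d, timesteps, broadcast_shape):
    # One weight per sampled timestep, reshaped to (B, 1, 1, 1) so it can be
    # multiplied with a per-element loss over the latent tensor.
    w = weights_1d.to(device=timesteps.device)[timesteps].float()
    return w.reshape(timesteps.shape[0], *([1] * (len(broadcast_shape) - 1)))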
1505
+ if args.CFG_scale > 1.0:
1506
+ # Process negative conditions.
1507
+ drop_idx = torch.ones_like(drop_image_idx)
1508
+ neg_image_embeds = prepare_training_image_embeds(
1509
+ image_encoder, image_processor,
1510
+ ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
1511
+ device=accelerator.device, drop_rate=1.0, output_hidden_state=args.image_encoder_hidden_feature,
1512
+ idx_to_replace=drop_idx
1513
+ )
1514
+ neg_prompt_embeds_input, neg_added_conditions = compute_embeddings_fn(batch, drop_idx=drop_idx)
1515
+ previewer_model_input = torch.cat([noisy_model_input] * 2)
1516
+ previewer_timesteps = torch.cat([timesteps] * 2)
1517
+ previewer_prompt_embeds = torch.cat([neg_prompt_embeds_input, prompt_embeds_input], dim=0)
1518
+ previewer_added_conditions = {}
1519
+ for k in added_conditions.keys():
1520
+ previewer_added_conditions[k] = torch.cat([neg_added_conditions[k], added_conditions[k]], dim=0)
1521
+ previewer_image_embeds = []
1522
+ for neg_image_embed, image_embed in zip(neg_image_embeds, image_embeds):
1523
+ previewer_image_embeds.append(torch.cat([neg_image_embed, image_embed], dim=0))
1524
+ previewer_added_conditions["image_embeds"] = previewer_image_embeds
1525
+ else:
1526
+ previewer_model_input = noisy_model_input
1527
+ previewer_timesteps = timesteps
1528
+ previewer_prompt_embeds = prompt_embeds_input
1529
+ previewer_added_conditions = {}
1530
+ for k in added_conditions.keys():
1531
+ previewer_added_conditions[k] = added_conditions[k]
1532
+ previewer_added_conditions["image_embeds"] = image_embeds
1533
+
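The previewer below applies standard classifier-free guidance. A minimal sketch of the combination step, assuming the unconditional half of the batch comes first (matching the concatenation above); `guidance_scale` may be a scalar or a per-sample tensor shaped for broadcasting.

import torch

def cfg_combine(noise_pred, guidance_scale):
    # noise_pred has shape (2*B, ...), with the unconditional half first.
    eps_uncond, eps_cond = noise_pred.chunk(2)
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)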
1534
+ # Get LCM preview latent
1535
+ if args.use_ema_adapter:
1536
+ orig_encoder_hid_proj = unet.encoder_hid_proj
1537
+ orig_attn_procs = unet.attn_processors
1538
+ _ema_attn_procs = copy_dict(ema_attn_procs)
1539
+ unwrap_model(unet).set_attn_processor(_ema_attn_procs)
1540
+ unwrap_model(unet).encoder_hid_proj = ema_encoder_hid_proj
1541
+ unwrap_model(unet).enable_adapters()
1542
+ preview_noise = unet(
1543
+ previewer_model_input, previewer_timesteps,
1544
+ encoder_hidden_states=previewer_prompt_embeds,
1545
+ added_cond_kwargs=previewer_added_conditions,
1546
+ return_dict=False
1547
+ )[0]
1548
+ if args.CFG_scale > 1.0:
1549
+ preview_noise_uncond, preview_noise_cond = preview_noise.chunk(2)
1550
+ cfg_scale = 1.0 + torch.rand_like(timesteps, dtype=weight_dtype) * (args.CFG_scale-1.0)
1551
+ cfg_scale = cfg_scale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
1552
+ preview_noise = preview_noise_uncond + cfg_scale * (preview_noise_cond - preview_noise_uncond)
1553
+ preview_latents = lcm_scheduler.step(
1554
+ preview_noise,
1555
+ timesteps,
1556
+ noisy_model_input,
1557
+ return_dict=False
1558
+ )[0]
1559
+ unwrap_model(unet).disable_adapters()
1560
+ if args.use_ema_adapter:
1561
+ unwrap_model(unet).set_attn_processor(orig_attn_procs)
1562
+ unwrap_model(unet).encoder_hid_proj = orig_encoder_hid_proj
1563
+ preview_error_latent = F.mse_loss(preview_latents, model_input).cpu().item()
1564
+ preview_error_noise = F.mse_loss(preview_noise, noise).cpu().item()
1565
+
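A sketch of the epsilon-to-x0 conversion that the one-step preview rests on; the actual `lcm_scheduler.step` additionally applies the LCM boundary scalings, so this is an approximation for intuition only: x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t).

import torch

def predict_x0_from_eps(x_t, eps, alphas_cumprod, timesteps):
    a = alphas_cumprod.to(x_t.device)[timesteps].view(-1, 1, 1, 1)
    return (x_t - (1.0 - a).sqrt() * eps) / a.sqrt()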
1566
+ # # Add fresh noise
1567
+ # if args.noisy_encoder_input:
1568
+ # aggregator_noise = torch.randn_like(preview_latents)
1569
+ # aggregator_input = noise_scheduler.add_noise(preview_latents, aggregator_noise, timesteps)
1570
+
1571
+ down_block_res_samples, mid_block_res_sample = aggregator(
1572
+ lq_latent,
1573
+ timesteps,
1574
+ encoder_hidden_states=prompt_embeds_input,
1575
+ added_cond_kwargs=added_conditions,
1576
+ controlnet_cond=preview_latents,
1577
+ conditioning_scale=1.0,
1578
+ return_dict=False,
1579
+ )
1580
+
1581
+ # UNet denoise.
1582
+ added_conditions["image_embeds"] = image_embeds
1583
+ model_pred = unet(
1584
+ noisy_model_input,
1585
+ timesteps,
1586
+ encoder_hidden_states=prompt_embeds_input,
1587
+ added_cond_kwargs=added_conditions,
1588
+ down_block_additional_residuals=[
1589
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
1590
+ ],
1591
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
1592
+ return_dict=False
1593
+ )[0]
1594
+
1595
+ diffusion_loss_arguments = {
1596
+ "target": noise,
1597
+ "predict": model_pred,
1598
+ "prompt_embeddings_input": prompt_embeds_input,
1599
+ "timesteps": timesteps,
1600
+ "weights": loss_weights,
1601
+ }
1602
+
1603
+ loss_dict = dict()
1604
+ for loss_config in diffusion_losses:
1605
+ non_weighted_loss = loss_config.loss(**diffusion_loss_arguments, accelerator=accelerator)
1606
+ loss = loss + non_weighted_loss * loss_config.weight
1607
+ loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item()
1608
+
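A minimal sketch of a loss term compatible with the `target`/`predict`/`weights` arguments assembled above; the actual `diffusion_losses` entries come from the project's loss module and may differ.

import torch.nn.functional as F

def weighted_mse_loss(target, predict, weights=None, **_unused):
    # Per-element MSE, optionally reweighted per timestep, then averaged.
    loss = F.mse_loss(predict.float(), target.float(), reduction="none")
    if weights is not None:
        loss = loss * weights
    return loss.mean()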
1609
+ accelerator.backward(loss)
1610
+ if accelerator.sync_gradients:
1611
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
1612
+ optimizer.step()
1613
+ lr_scheduler.step()
1614
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
1615
+
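A condensed sketch of the update pattern above, assuming it runs inside `accelerator.accumulate(...)` with a prepared optimizer: gradients are clipped only when `sync_gradients` is set, i.e. once per effective optimizer step, and the prepared optimizer skips stepping on intermediate micro-batches.

def optimization_step(accelerator, params_to_optimize, optimizer, lr_scheduler, loss, max_grad_norm):
    accelerator.backward(loss)
    if accelerator.sync_gradients:
        accelerator.clip_grad_norm_(params_to_optimize, max_grad_norm)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()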
1616
+ toc = time.time()
1617
+ forward_time = toc - tic
1618
+ tic = toc
1619
+
1620
+ # Checks if the accelerator has performed an optimization step behind the scenes
1621
+ if accelerator.sync_gradients:
1622
+ progress_bar.update(1)
1623
+ global_step += 1
1624
+
1625
+ if global_step % args.ema_update_steps == 0:
1626
+ if args.use_ema_adapter:
1627
+ update_ema_model(ema_encoder_hid_proj, orig_encoder_hid_proj, ema_beta)
1628
+ update_ema_model(ema_attn_procs_list, orig_attn_procs_list, ema_beta)
1629
+
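A sketch of the exponential-moving-average update that `update_ema_model` is assumed to apply to the adapter copies every `ema_update_steps`: ema <- beta * ema + (1 - beta) * online.

import torch

@torch.no_grad()
def ema_update(ema_module, online_module, beta=0.999):
    for ema_p, p in zip(ema_module.parameters(), online_module.parameters()):
        ema_p.mul_(beta).add_(p.detach().to(ema_p.dtype), alpha=1.0 - beta)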
1630
+ if accelerator.is_main_process:
1631
+ if global_step % args.checkpointing_steps == 0:
1632
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1633
+ if args.checkpoints_total_limit is not None:
1634
+ checkpoints = os.listdir(args.output_dir)
1635
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1636
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1637
+
1638
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1639
+ if len(checkpoints) >= args.checkpoints_total_limit:
1640
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1641
+ removing_checkpoints = checkpoints[0:num_to_remove]
1642
+
1643
+ logger.info(
1644
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1645
+ )
1646
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1647
+
1648
+ for removing_checkpoint in removing_checkpoints:
1649
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1650
+ shutil.rmtree(removing_checkpoint)
1651
+
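A minimal helper equivalent to the rotation above, assuming the same `checkpoint-<step>` layout: keep at most `total_limit - 1` existing checkpoints before the new one is written.

import os
import shutil

def rotate_checkpoints(output_dir, total_limit):
    ckpts = sorted(
        (d for d in os.listdir(output_dir) if d.startswith("checkpoint")),
        key=lambda d: int(d.split("-")[1]),
    )
    num_to_remove = max(len(ckpts) - total_limit + 1, 0)
    for d in ckpts[:num_to_remove]:
        shutil.rmtree(os.path.join(output_dir, d))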
1652
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1653
+ accelerator.save_state(save_path)
1654
+ logger.info(f"Saved state to {save_path}")
1655
+
1656
+ if global_step % args.validation_steps == 0:
1657
+ image_logs = log_validation(
1658
+ unwrap_model(unet), unwrap_model(aggregator), vae,
1659
+ text_encoder, text_encoder_2, tokenizer, tokenizer_2,
1660
+ noise_scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline,
1661
+ args, accelerator, weight_dtype, global_step, lq_img.detach().clone(), gt_img.detach().clone()
1662
+ )
1663
+
1664
+ logs = {}
1665
+ logs.update(loss_dict)
1666
+ logs.update(
1667
+ {"preview_error_latent": preview_error_latent, "preview_error_noise": preview_error_noise,
1668
+ "lr": lr_scheduler.get_last_lr()[0],
1669
+ "forward_time": forward_time, "io_time": io_time}
1670
+ )
1671
+ progress_bar.set_postfix(**logs)
1672
+ accelerator.log(logs, step=global_step)
1673
+ tic = time.time()
1674
+
1675
+ if global_step >= args.max_train_steps:
1676
+ break
1677
+
1678
+ # Save the final state of the trained modules and run a last validation pass.
1679
+ accelerator.wait_for_everyone()
1680
+ if accelerator.is_main_process:
1681
+ accelerator.save_state(os.path.join(args.output_dir, f"checkpoint-{global_step}"), safe_serialization=False)
1682
+ # Run a final round of validation.
1683
+ # The trained aggregator and unet are passed in directly for this final validation pass.
1684
+ image_logs = None
1685
+ if args.validation_image is not None:
1686
+ image_logs = log_validation(
1687
+ unwrap_model(unet), unwrap_model(aggregator), vae,
1688
+ text_encoder, text_encoder_2, tokenizer, tokenizer_2,
1689
+ noise_scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline,
1690
+ args, accelerator, weight_dtype, global_step,
1691
+ )
1692
+
1693
+ accelerator.end_training()
1694
+
1695
+
1696
+ if __name__ == "__main__":
1697
+ args = parse_args()
1698
+ main(args)