SunderAli17 commited on
Commit
6caf646
·
verified ·
1 Parent(s): cd98a68

Create parser.py

Browse files
Files changed (1) hide show
  1. utils/parser.py +452 -0
utils/parser.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ def parse_args(input_args=None):
5
+ parser = argparse.ArgumentParser(description="Train Consistency Encoder.")
6
+ parser.add_argument(
7
+ "--pretrained_model_name_or_path",
8
+ type=str,
9
+ default=None,
10
+ required=True,
11
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
12
+ )
13
+ parser.add_argument(
14
+ "--pretrained_vae_model_name_or_path",
15
+ type=str,
16
+ default=None,
17
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
18
+ )
19
+ parser.add_argument(
20
+ "--revision",
21
+ type=str,
22
+ default=None,
23
+ required=False,
24
+ help="Revision of pretrained model identifier from huggingface.co/models.",
25
+ )
26
+ parser.add_argument(
27
+ "--variant",
28
+ type=str,
29
+ default=None,
30
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
31
+ )
32
+
33
+ # parser.add_argument(
34
+ # "--instance_data_dir",
35
+ # type=str,
36
+ # required=True,
37
+ # help=("A folder containing the training data. "),
38
+ # )
39
+
40
+ parser.add_argument(
41
+ "--data_config_path",
42
+ type=str,
43
+ required=True,
44
+ help=("A folder containing the training data. "),
45
+ )
46
+
47
+ parser.add_argument(
48
+ "--cache_dir",
49
+ type=str,
50
+ default=None,
51
+ help="The directory where the downloaded models and datasets will be stored.",
52
+ )
53
+
54
+ parser.add_argument(
55
+ "--image_column",
56
+ type=str,
57
+ default="image",
58
+ help="The column of the dataset containing the target image. By "
59
+ "default, the standard Image Dataset maps out 'file_name' "
60
+ "to 'image'.",
61
+ )
62
+ parser.add_argument(
63
+ "--caption_column",
64
+ type=str,
65
+ default=None,
66
+ help="The column of the dataset containing the instance prompt for each image",
67
+ )
68
+
69
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
70
+
71
+ parser.add_argument(
72
+ "--instance_prompt",
73
+ type=str,
74
+ default=None,
75
+ required=True,
76
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
77
+ )
78
+
79
+ parser.add_argument(
80
+ "--validation_prompt",
81
+ type=str,
82
+ default=None,
83
+ help="A prompt that is used during validation to verify that the model is learning.",
84
+ )
85
+ parser.add_argument(
86
+ "--num_train_vis_images",
87
+ type=int,
88
+ default=2,
89
+ help="Number of images that should be generated during validation with `validation_prompt`.",
90
+ )
91
+ parser.add_argument(
92
+ "--num_validation_images",
93
+ type=int,
94
+ default=2,
95
+ help="Number of images that should be generated during validation with `validation_prompt`.",
96
+ )
97
+
98
+ parser.add_argument(
99
+ "--validation_vis_steps",
100
+ type=int,
101
+ default=500,
102
+ help=(
103
+ "Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt"
104
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
105
+ ),
106
+ )
107
+
108
+ parser.add_argument(
109
+ "--train_vis_steps",
110
+ type=int,
111
+ default=500,
112
+ help=(
113
+ "Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt"
114
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
115
+ ),
116
+ )
117
+
118
+ parser.add_argument(
119
+ "--vis_lcm",
120
+ type=bool,
121
+ default=True,
122
+ help=(
123
+ "Also log results of LCM inference",
124
+ ),
125
+ )
126
+
127
+ parser.add_argument(
128
+ "--output_dir",
129
+ type=str,
130
+ default="lora-dreambooth-model",
131
+ help="The output directory where the model predictions and checkpoints will be written.",
132
+ )
133
+
134
+ parser.add_argument("--save_only_encoder", action="store_true", help="Only save the encoder and not the full accelerator state")
135
+
136
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
137
+
138
+ parser.add_argument("--freeze_encoder_unet", action="store_true", help="Don't train encoder unet")
139
+ parser.add_argument("--predict_word_embedding", action="store_true", help="Predict word embeddings in addition to KV features")
140
+ parser.add_argument("--ip_adapter_feature_extractor_path", type=str, help="Path to pre-trained feature extractor for IP-adapter")
141
+ parser.add_argument("--ip_adapter_model_path", type=str, help="Path to pre-trained IP-adapter.")
142
+ parser.add_argument("--ip_adapter_tokens", type=int, default=16, help="Number of tokens to use in IP-adapter cross attention mechanism")
143
+ parser.add_argument("--optimize_adapter", action="store_true", help="Optimize IP-adapter parameters (projector + cross-attention layers)")
144
+ parser.add_argument("--adapter_attention_scale", type=float, default=1.0, help="Relative strength of the adapter cross attention layers")
145
+ parser.add_argument("--adapter_lr", type=float, help="Learning rate for the adapter parameters. Defaults to the global LR if not provided")
146
+
147
+ parser.add_argument("--noisy_encoder_input", action="store_true", help="Noise the encoder input to the same step as the decoder?")
148
+
149
+ # related to CFG:
150
+ parser.add_argument("--adapter_drop_chance", type=float, default=0.0, help="Chance to drop adapter condition input during training")
151
+ parser.add_argument("--text_drop_chance", type=float, default=0.0, help="Chance to drop text condition during training")
152
+ parser.add_argument("--kv_drop_chance", type=float, default=0.0, help="Chance to drop KV condition during training")
153
+
154
+
155
+
156
+ parser.add_argument(
157
+ "--resolution",
158
+ type=int,
159
+ default=1024,
160
+ help=(
161
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
162
+ " resolution"
163
+ ),
164
+ )
165
+
166
+ parser.add_argument(
167
+ "--crops_coords_top_left_h",
168
+ type=int,
169
+ default=0,
170
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
171
+ )
172
+
173
+ parser.add_argument(
174
+ "--crops_coords_top_left_w",
175
+ type=int,
176
+ default=0,
177
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
178
+ )
179
+
180
+ parser.add_argument(
181
+ "--center_crop",
182
+ default=False,
183
+ action="store_true",
184
+ help=(
185
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
186
+ " cropped. The images will be resized to the resolution first before cropping."
187
+ ),
188
+ )
189
+
190
+ parser.add_argument(
191
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
192
+ )
193
+
194
+ parser.add_argument("--num_train_epochs", type=int, default=1)
195
+
196
+ parser.add_argument(
197
+ "--max_train_steps",
198
+ type=int,
199
+ default=None,
200
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
201
+ )
202
+
203
+ parser.add_argument(
204
+ "--checkpointing_steps",
205
+ type=int,
206
+ default=500,
207
+ help=(
208
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
209
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
210
+ " training using `--resume_from_checkpoint`."
211
+ ),
212
+ )
213
+
214
+ parser.add_argument(
215
+ "--checkpoints_total_limit",
216
+ type=int,
217
+ default=5,
218
+ help=("Max number of checkpoints to store."),
219
+ )
220
+
221
+ parser.add_argument(
222
+ "--resume_from_checkpoint",
223
+ type=str,
224
+ default=None,
225
+ help=(
226
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
227
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
228
+ ),
229
+ )
230
+
231
+ parser.add_argument("--max_timesteps_for_x0_loss", type=int, default=1001)
232
+
233
+ parser.add_argument(
234
+ "--gradient_accumulation_steps",
235
+ type=int,
236
+ default=1,
237
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
238
+ )
239
+
240
+ parser.add_argument(
241
+ "--gradient_checkpointing",
242
+ action="store_true",
243
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
244
+ )
245
+
246
+ parser.add_argument(
247
+ "--learning_rate",
248
+ type=float,
249
+ default=1e-4,
250
+ help="Initial learning rate (after the potential warmup period) to use.",
251
+ )
252
+
253
+ parser.add_argument(
254
+ "--scale_lr",
255
+ action="store_true",
256
+ default=False,
257
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
258
+ )
259
+
260
+ parser.add_argument(
261
+ "--lr_scheduler",
262
+ type=str,
263
+ default="constant",
264
+ help=(
265
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
266
+ ' "constant", "constant_with_warmup"]'
267
+ ),
268
+ )
269
+
270
+ parser.add_argument(
271
+ "--snr_gamma",
272
+ type=float,
273
+ default=None,
274
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
275
+ "More details here: https://arxiv.org/abs/2303.09556.",
276
+ )
277
+
278
+ parser.add_argument(
279
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
280
+ )
281
+
282
+ parser.add_argument(
283
+ "--lr_num_cycles",
284
+ type=int,
285
+ default=1,
286
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
287
+ )
288
+
289
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
290
+
291
+ parser.add_argument(
292
+ "--dataloader_num_workers",
293
+ type=int,
294
+ default=0,
295
+ help=(
296
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
297
+ ),
298
+ )
299
+
300
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
301
+
302
+ parser.add_argument(
303
+ "--adam_epsilon",
304
+ type=float,
305
+ default=1e-08,
306
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
307
+ )
308
+
309
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
310
+
311
+ parser.add_argument(
312
+ "--logging_dir",
313
+ type=str,
314
+ default="logs",
315
+ help=(
316
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
317
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
318
+ ),
319
+ )
320
+ parser.add_argument(
321
+ "--allow_tf32",
322
+ action="store_true",
323
+ help=(
324
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
325
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
326
+ ),
327
+ )
328
+
329
+ parser.add_argument(
330
+ "--report_to",
331
+ type=str,
332
+ default="wandb",
333
+ help=(
334
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
335
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
336
+ ),
337
+ )
338
+
339
+ parser.add_argument(
340
+ "--mixed_precision",
341
+ type=str,
342
+ default=None,
343
+ choices=["no", "fp16", "bf16"],
344
+ help=(
345
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
346
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
347
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
348
+ ),
349
+ )
350
+
351
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
352
+
353
+ parser.add_argument(
354
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
355
+ )
356
+
357
+ parser.add_argument(
358
+ "--rank",
359
+ type=int,
360
+ default=4,
361
+ help=("The dimension of the LoRA update matrices."),
362
+ )
363
+
364
+ parser.add_argument(
365
+ "--pretrained_lcm_lora_path",
366
+ type=str,
367
+ default="latent-consistency/lcm-lora-sdxl",
368
+ help=("Path for lcm lora pretrained"),
369
+ )
370
+
371
+ parser.add_argument(
372
+ "--losses_config_path",
373
+ type=str,
374
+ required=True,
375
+ help=("A yaml file containing losses to use and their weights."),
376
+ )
377
+
378
+ parser.add_argument(
379
+ "--lcm_every_k_steps",
380
+ type=int,
381
+ default=-1,
382
+ help="How often to run lcm. If -1, lcm is not run."
383
+ )
384
+
385
+ parser.add_argument(
386
+ "--lcm_batch_size",
387
+ type=int,
388
+ default=1,
389
+ help="Batch size for lcm."
390
+ )
391
+ parser.add_argument(
392
+ "--lcm_max_timestep",
393
+ type=int,
394
+ default=1000,
395
+ help="Max timestep to use with LCM."
396
+ )
397
+
398
+ parser.add_argument(
399
+ "--lcm_sample_scale_every_k_steps",
400
+ type=int,
401
+ default=-1,
402
+ help="How often to change lcm scale. If -1, scale is fixed at 1."
403
+ )
404
+
405
+ parser.add_argument(
406
+ "--lcm_min_scale",
407
+ type=float,
408
+ default=0.1,
409
+ help="When sampling lcm scale, the minimum scale to use."
410
+ )
411
+
412
+ parser.add_argument(
413
+ "--scale_lcm_by_max_step",
414
+ action="store_true",
415
+ help="scale LCM lora alpha linearly by the maximal timestep sampled that iteration"
416
+ )
417
+
418
+ parser.add_argument(
419
+ "--lcm_sample_full_lcm_prob",
420
+ type=float,
421
+ default=0.2,
422
+ help="When sampling lcm scale, the probability of using full lcm (scale of 1)."
423
+ )
424
+
425
+ parser.add_argument(
426
+ "--run_on_cpu",
427
+ action="store_true",
428
+ help="whether to run on cpu or not"
429
+ )
430
+
431
+ parser.add_argument(
432
+ "--experiment_name",
433
+ type=str,
434
+ help=("A short description of the experiment to add to the wand run log. "),
435
+ )
436
+ parser.add_argument("--encoder_lora_rank", type=int, default=0, help="Rank of Lora in unet encoder. 0 means no lora")
437
+
438
+ parser.add_argument("--kvcopy_lora_rank", type=int, default=0, help="Rank of lora in the kvcopy modules. 0 means no lora")
439
+
440
+
441
+ if input_args is not None:
442
+ args = parser.parse_args(input_args)
443
+ else:
444
+ args = parser.parse_args()
445
+
446
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
447
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
448
+ args.local_rank = env_local_rank
449
+
450
+ args.optimizer = "AdamW"
451
+
452
+ return args