Spaces:
Running
on
Zero
Running
on
Zero
import argparse | |
import datetime | |
import os | |
import sys | |
import time | |
import types | |
import warnings | |
from pathlib import Path | |
current_file_path = Path(__file__).resolve() | |
sys.path.insert(0, str(current_file_path.parent.parent)) | |
import numpy as np | |
import torch | |
from accelerate import Accelerator, InitProcessGroupKwargs | |
from accelerate.utils import DistributedType | |
from diffusers.models import AutoencoderKL | |
from transformers import T5EncoderModel, T5Tokenizer | |
from mmcv.runner import LogBuffer | |
from PIL import Image | |
from torch.utils.data import RandomSampler | |
from diffusion import IDDPM, DPMS | |
from diffusion.data.builder import build_dataset, build_dataloader, set_data_root | |
from diffusion.model.builder import build_model | |
from diffusion.utils.checkpoint import save_checkpoint, load_checkpoint | |
from diffusion.utils.data_sampler import AspectRatioBatchSampler | |
from diffusion.utils.dist_utils import synchronize, get_world_size, clip_grad_norm_, flush | |
from diffusion.utils.logger import get_root_logger, rename_file_with_creation_time | |
from diffusion.utils.lr_scheduler import build_lr_scheduler | |
from diffusion.utils.misc import set_random_seed, read_config, init_random_seed, DebugUnderflowOverflow | |
from diffusion.utils.optimizer import build_optimizer, auto_scale_lr | |
warnings.filterwarnings("ignore") # ignore warning | |
def set_fsdp_env():
    """Export the FSDP-related environment variables used by this script.

    Four variables are written into ``os.environ``: the flag enabling FSDP,
    the auto-wrap policy, the backward-prefetch mode and the name of the
    transformer block class to wrap. Existing values are overwritten.
    """
    fsdp_settings = (
        ("ACCELERATE_USE_FSDP", 'true'),
        ("FSDP_AUTO_WRAP_POLICY", 'TRANSFORMER_BASED_WRAP'),
        ("FSDP_BACKWARD_PREFETCH", 'BACKWARD_PRE'),
        ("FSDP_TRANSFORMER_CLS_TO_WRAP", 'PixArtBlock'),
    )
    for env_name, env_value in fsdp_settings:
        os.environ[env_name] = env_value
def _sample_validation_images(model, device, vae=None):
    """Sample one image per validation prompt and decode it with the VAE.

    Reads the module-level ``validation_prompts``, ``latent_size``,
    ``max_length``, ``config``, ``accelerator`` and ``logger``. Prompt and
    null-prompt embeddings are loaded from the ``.pth`` files under
    ``output/``. Returns a list of
    ``{"validation_prompt": str, "images": [PIL.Image]}`` dicts.
    """
    # NOTE(review): the conditioning resolution is hard-coded to 1024x1024
    # while the latent size comes from the config -- confirm this is intended
    # for non-1024 training runs.
    hw = torch.tensor([[1024, 1024]], dtype=torch.float, device=device).repeat(1, 1)
    ar = torch.tensor([[1.]], device=device).repeat(1, 1)
    null_y = torch.load(
        f'output/pretrained_models/null_embed_diffusers_{max_length}token.pth', map_location='cpu')
    null_y = null_y['uncond_prompt_embeds'].to(device)

    # Create sampling noise:
    logger.info("Running validation... ")
    latents = []
    for prompt in validation_prompts:
        z = torch.randn(1, 4, latent_size, latent_size, device=device)
        embed = torch.load(f'output/tmp/{prompt}_{max_length}token.pth', map_location='cpu')
        caption_embs, emb_masks = embed['caption_embeds'].to(device), embed['emb_mask'].to(device)
        model_kwargs = dict(data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks)

        dpm_solver = DPMS(model.forward_with_dpmsolver,
                          condition=caption_embs,
                          uncondition=null_y,
                          cfg_scale=4.5,
                          model_kwargs=model_kwargs)
        denoised = dpm_solver.sample(
            z,
            steps=14,
            order=2,
            skip_type="time_uniform",
            method="multistep",
        )
        latents.append(denoised)

    torch.cuda.empty_cache()
    if vae is None:
        # No VAE was supplied by the caller: load one just for decoding.
        vae = AutoencoderKL.from_pretrained(config.vae_pretrained).to(accelerator.device).to(torch.float16)

    image_logs = []
    for prompt, latent in zip(validation_prompts, latents):
        latent = latent.to(torch.float16)
        samples = vae.decode(latent.detach() / vae.config.scaling_factor).sample
        # Map decoder output to 8-bit pixel values and take the single image in the batch.
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()[0]
        image_logs.append({"validation_prompt": prompt, "images": [Image.fromarray(samples)]})
    return image_logs


def _log_images_to_trackers(image_logs, step):
    """Send the validation images to every tracker registered on ``accelerator``."""
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            for log in image_logs:
                formatted_images = np.stack([np.asarray(image) for image in log["images"]])
                tracker.writer.add_images(log["validation_prompt"], formatted_images, step, dataformats="NHWC")
        elif tracker.name == "wandb":
            import wandb
            formatted_images = [
                wandb.Image(image, caption=log["validation_prompt"])
                for log in image_logs
                for image in log["images"]
            ]
            tracker.log({"validation": formatted_images})
        else:
            # ``Logger.warn`` is deprecated; ``warning`` is the supported spelling.
            logger.warning(f"image logging not implemented for {tracker.name}")


def log_validation(model, step, device, vae=None):
    """Generate validation images with the current model and log them.

    Args:
        model: the (possibly wrapped) diffusion model; it is unwrapped with
            ``accelerator.unwrap_model`` before sampling.
        step: global step used as the x-axis value for tensorboard images.
        device: device on which noise and embeddings are placed.
        vae: optional VAE used for decoding; when ``None`` one is loaded from
            ``config.vae_pretrained``.

    Returns:
        A list of ``{"validation_prompt": str, "images": [PIL.Image]}`` dicts.

    Relies on module-level ``accelerator``, ``logger``, ``config``,
    ``validation_prompts``, ``latent_size`` and ``max_length``.
    """
    torch.cuda.empty_cache()
    model = accelerator.unwrap_model(model)
    # Remember the mode so the caller's training loop is not silently left in
    # eval mode after validation.
    was_training = model.training
    model.eval()
    try:
        # Sampling and decoding need no gradients; skip building autograd graphs.
        with torch.no_grad():
            image_logs = _sample_validation_images(model, device, vae=vae)
        _log_images_to_trackers(image_logs, step)
    finally:
        model.train(was_training)

    # Drop this function's reference to the VAE before releasing cached memory.
    del vae
    flush()
    return image_logs
def train():
    """Run the main training loop.

    Takes no arguments; it operates on the module-level objects created in
    the ``__main__`` section (``config``, ``args``, ``accelerator``, ``model``,
    ``optimizer``, ``lr_scheduler``, ``train_dataloader``, ``train_diffusion``,
    ``vae``, ``tokenizer``, ``text_encoder``, ``logger``, ``start_epoch``,
    ``start_step``, ``skip_step``, ``total_steps`` and ``max_length``).

    For each batch it obtains latents (pre-computed or VAE-encoded) and text
    embeddings (pre-computed or T5-encoded), computes the diffusion loss,
    steps the optimizer and scheduler, logs metrics, and periodically saves
    checkpoints and runs validation.
    """
    if config.get('debug_nan', False):
        DebugUnderflowOverflow(model)
        logger.info('NaN debugger registered. Start to detect overflow during training.')
    time_start, last_tic = time.time(), time.time()
    log_buffer = LogBuffer()

    global_step = start_step + 1

    load_vae_feat = getattr(train_dataloader.dataset, 'load_vae_feat', False)
    load_t5_feat = getattr(train_dataloader.dataset, 'load_t5_feat', False)
    use_autocast = config.mixed_precision == 'fp16' or config.mixed_precision == 'bf16'

    # Now you train the model
    for epoch in range(start_epoch + 1, config.num_epochs + 1):
        data_time_start = time.time()
        data_time_all = 0
        for step, batch in enumerate(train_dataloader):
            if step < skip_step:
                global_step += 1
                continue    # skip data in the resumed ckpt

            if load_vae_feat:
                z = batch[0]
            else:
                with torch.no_grad():
                    with torch.cuda.amp.autocast(enabled=use_autocast):
                        posterior = vae.encode(batch[0]).latent_dist
                        if config.sample_posterior:
                            z = posterior.sample()
                        else:
                            z = posterior.mode()

            clean_images = z * config.scale_factor
            data_info = batch[3]

            if load_t5_feat:
                y = batch[1]
                y_mask = batch[2]
            else:
                with torch.no_grad():
                    txt_tokens = tokenizer(
                        batch[1], max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
                    ).to(accelerator.device)
                    y = text_encoder(
                        txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][:, None]
                    y_mask = txt_tokens.attention_mask[:, None, None]

            # Sample a random timestep for each image
            bs = clean_images.shape[0]
            timesteps = torch.randint(0, config.train_sampling_steps, (bs,), device=clean_images.device).long()
            grad_norm = None
            data_time_all += time.time() - data_time_start
            with accelerator.accumulate(model):
                # Predict the noise residual
                loss_term = train_diffusion.training_losses(
                    model, clean_images, timesteps, model_kwargs=dict(y=y, mask=y_mask, data_info=data_info))
                loss = loss_term['loss'].mean()
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.gradient_clip)
                optimizer.step()
                lr_scheduler.step()
                # Clear gradients only AFTER the optimizer step. Zeroing before
                # backward would, on the synchronising step, discard the
                # gradients accumulated over the preceding micro-batches.
                optimizer.zero_grad()

            lr = lr_scheduler.get_last_lr()[0]
            logs = {args.loss_report_name: accelerator.gather(loss).mean().item()}
            if grad_norm is not None:
                logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())
            log_buffer.update(logs)
            if (step + 1) % config.log_interval == 0 or (step + 1) == 1:
                t = (time.time() - last_tic) / config.log_interval
                t_d = data_time_all / config.log_interval
                avg_time = (time.time() - time_start) / (global_step + 1)
                eta = str(datetime.timedelta(seconds=int(avg_time * (total_steps - global_step - 1))))
                eta_epoch = str(datetime.timedelta(seconds=int(avg_time * (len(train_dataloader) - step - 1))))
                log_buffer.average()
                # Unwrap instead of reading ``model.module`` so this also works
                # when the prepared model has no parallel wrapper.
                core_model = accelerator.unwrap_model(model)
                info = f"Step/Epoch [{global_step}/{epoch}][{step + 1}/{len(train_dataloader)}]:total_eta: {eta}, " \
                       f"epoch_eta:{eta_epoch}, time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}, s:({core_model.h}, {core_model.w}), "
                info += ', '.join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()])
                logger.info(info)
                last_tic = time.time()
                log_buffer.clear()
                data_time_all = 0
            logs.update(lr=lr)
            accelerator.log(logs, step=global_step)

            global_step += 1
            data_time_start = time.time()

            if global_step % config.save_model_steps == 0:
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    os.umask(0o000)
                    save_checkpoint(os.path.join(config.work_dir, 'checkpoints'),
                                    epoch=epoch,
                                    step=global_step,
                                    model=accelerator.unwrap_model(model),
                                    optimizer=optimizer,
                                    lr_scheduler=lr_scheduler
                                    )
            if config.visualize and (global_step % config.eval_sampling_steps == 0 or (step + 1) == 1):
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    log_validation(model, global_step, device=accelerator.device, vae=vae)

        if epoch % config.save_model_epochs == 0 or epoch == config.num_epochs:
            accelerator.wait_for_everyone()
            if accelerator.is_main_process:
                os.umask(0o000)
                save_checkpoint(os.path.join(config.work_dir, 'checkpoints'),
                                epoch=epoch,
                                step=global_step,
                                model=accelerator.unwrap_model(model),
                                optimizer=optimizer,
                                lr_scheduler=lr_scheduler
                                )
        accelerator.wait_for_everyone()
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the positional ``config`` path plus the
        optional settings: ``cloud``, ``work_dir``, ``resume_from``,
        ``load_from``, ``local_rank`` (accepted as ``--local-rank`` or
        ``--local_rank``), ``debug``, ``pipeline_load_from``, ``report_to``,
        ``tracker_project_name`` and ``loss_report_name``.
    """
    cli = argparse.ArgumentParser(description="Process some integers.")

    # Config file and run locations.
    cli.add_argument("config", type=str, help="config")
    cli.add_argument("--cloud", action='store_true', default=False, help="cloud or local machine")
    cli.add_argument('--work-dir', help='the dir to save logs and models')
    cli.add_argument('--resume-from', help='the dir to resume the training')
    cli.add_argument('--load-from', default=None, help='the dir to load a ckpt for training')

    # Both spellings are accepted and write to the same ``local_rank`` attribute.
    cli.add_argument('--local-rank', type=int, default=-1)
    cli.add_argument('--local_rank', type=int, default=-1)
    cli.add_argument('--debug', action='store_true')

    pipeline_help = (
        "Download for loading text_encoder, "
        "tokenizer and vae from https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"
    )
    cli.add_argument(
        "--pipeline_load_from",
        default='output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers',
        type=str,
        help=pipeline_help,
    )

    report_help = (
        'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
        ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
    )
    cli.add_argument("--report_to", type=str, default="tensorboard", help=report_help)

    tracker_help = (
        "The `project_name` argument passed to Accelerator.init_trackers for"
        " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
    )
    cli.add_argument("--tracker_project_name", type=str, default="text2image-fine-tune", help=tracker_help)

    cli.add_argument("--loss_report_name", type=str, default="loss")
    return cli.parse_args()
if __name__ == '__main__':
    # Script entry point: read the config, set up Accelerate, build the model,
    # data, optimizer and scheduler, optionally resume, then call ``train()``.
    # The names assigned here are module-level globals that ``train()`` and
    # ``log_validation()`` read directly.
    args = parse_args()
    config = read_config(args.config)
    if args.work_dir is not None:
        config.work_dir = args.work_dir
    if args.resume_from is not None:
        # Resuming takes precedence over loading plain weights.
        config.load_from = None
        config.resume_from = dict(
            checkpoint=args.resume_from,
            load_ema=False,
            resume_optimizer=True,
            resume_lr_scheduler=True)
    if args.debug:
        config.log_interval = 1
        config.train_batch_size = 2

    os.umask(0o000)
    os.makedirs(config.work_dir, exist_ok=True)

    init_handler = InitProcessGroupKwargs()
    init_handler.timeout = datetime.timedelta(seconds=5400)  # change timeout to avoid a strange NCCL bug

    # Initialize accelerator and tensorboard logging
    if config.use_fsdp:
        init_train = 'FSDP'
        from accelerate import FullyShardedDataParallelPlugin
        from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
        set_fsdp_env()
        fsdp_plugin = FullyShardedDataParallelPlugin(
            state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),)
    else:
        init_train = 'DDP'
        fsdp_plugin = None

    even_batches = True
    if config.multi_scale:
        # Must be a plain bool: a trailing comma here would create the tuple
        # ``(False,)``, which is truthy and would leave even batching enabled.
        even_batches = False

    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with=args.report_to,
        project_dir=os.path.join(config.work_dir, "logs"),
        fsdp_plugin=fsdp_plugin,
        even_batches=even_batches,
        kwargs_handlers=[init_handler]
    )

    log_name = 'train_log.log'
    if accelerator.is_main_process:
        # Keep the previous log by renaming it before a new one is started.
        if os.path.exists(os.path.join(config.work_dir, log_name)):
            rename_file_with_creation_time(os.path.join(config.work_dir, log_name))
    logger = get_root_logger(os.path.join(config.work_dir, log_name))

    logger.info(accelerator.state)
    config.seed = init_random_seed(config.get('seed', None))
    set_random_seed(config.seed)

    if accelerator.is_main_process:
        config.dump(os.path.join(config.work_dir, 'config.py'))

    logger.info(f"Config: \n{config.pretty_text}")
    logger.info(f"World_size: {get_world_size()}, seed: {config.seed}")
    logger.info(f"Initializing: {init_train} for training")
    image_size = config.image_size  # @param [256, 512]
    latent_size = int(image_size) // 8
    pred_sigma = getattr(config, 'pred_sigma', True)
    learn_sigma = getattr(config, 'learn_sigma', True) and pred_sigma
    max_length = config.model_max_length
    kv_compress_config = config.kv_compress_config if config.kv_compress else None

    # The VAE is only needed when latents are not pre-computed.
    vae = None
    if not config.data.load_vae_feat:
        vae = AutoencoderKL.from_pretrained(config.vae_pretrained, torch_dtype=torch.float16).to(accelerator.device)
        config.scale_factor = vae.config.scaling_factor

    # The tokenizer / text encoder are only needed when text features are not pre-computed.
    tokenizer = text_encoder = None
    if not config.data.load_t5_feat:
        tokenizer = T5Tokenizer.from_pretrained(args.pipeline_load_from, subfolder="tokenizer")
        text_encoder = T5EncoderModel.from_pretrained(
            args.pipeline_load_from, subfolder="text_encoder", torch_dtype=torch.float16).to(accelerator.device)

    logger.info(f"vae sacle factor: {config.scale_factor}")

    if config.visualize:
        # preparing embeddings for visualization. We put it here for saving GPU memory
        validation_prompts = [
            "dog",
            "portrait photo of a girl, photograph, highly detailed face, depth of field",
            "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
        ]
        null_embed_path = f'output/pretrained_models/null_embed_diffusers_{max_length}token.pth'

        # Skip the (expensive) text-encoder pass only if every cached file already exists.
        skip = True
        for prompt in validation_prompts:
            if not (os.path.exists(f'output/tmp/{prompt}_{max_length}token.pth')
                    and os.path.exists(null_embed_path)):
                skip = False
                logger.info("Preparing Visualization prompt embeddings...")
                break

        if accelerator.is_main_process and not skip:
            if config.data.load_t5_feat and (tokenizer is None or text_encoder is None):
                logger.info(f"Loading text encoder and tokenizer from {args.pipeline_load_from} ...")
                tokenizer = T5Tokenizer.from_pretrained(args.pipeline_load_from, subfolder="tokenizer")
                text_encoder = T5EncoderModel.from_pretrained(
                    args.pipeline_load_from, subfolder="text_encoder", torch_dtype=torch.float16).to(accelerator.device)

            # torch.save does not create parent directories.
            os.makedirs('output/tmp', exist_ok=True)
            os.makedirs('output/pretrained_models', exist_ok=True)

            # The embeddings are only cached to disk; no gradients are needed.
            with torch.no_grad():
                for prompt in validation_prompts:
                    txt_tokens = tokenizer(
                        prompt, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
                    ).to(accelerator.device)
                    caption_emb = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0]
                    torch.save(
                        {'caption_embeds': caption_emb, 'emb_mask': txt_tokens.attention_mask},
                        f'output/tmp/{prompt}_{max_length}token.pth')

                null_tokens = tokenizer(
                    "", max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
                ).to(accelerator.device)
                # Encode the empty prompt with its OWN attention mask, not the
                # mask left over from the last validation prompt.
                null_token_emb = text_encoder(null_tokens.input_ids, attention_mask=null_tokens.attention_mask)[0]
                torch.save(
                    {'uncond_prompt_embeds': null_token_emb, 'uncond_prompt_embeds_mask': null_tokens.attention_mask},
                    null_embed_path)

            if config.data.load_t5_feat:
                # Training uses pre-computed text features, so the tokenizer and
                # text encoder were loaded only for this step: release them.
                # The names stay defined (as None) for later code.
                del txt_tokens
                tokenizer = text_encoder = None
            flush()

    model_kwargs = {"pe_interpolation": config.pe_interpolation, "config": config,
                    "model_max_length": max_length, "qk_norm": config.qk_norm,
                    "kv_compress_config": kv_compress_config, "micro_condition": config.micro_condition}

    # build models
    train_diffusion = IDDPM(str(config.train_sampling_steps), learn_sigma=learn_sigma, pred_sigma=pred_sigma, snr=config.snr_loss)
    model = build_model(config.model,
                        config.grad_checkpointing,
                        config.get('fp32_attention', False),
                        input_size=latent_size,
                        learn_sigma=learn_sigma,
                        pred_sigma=pred_sigma,
                        **model_kwargs).train()
    logger.info(f"{model.__class__.__name__} Model Parameters: {sum(p.numel() for p in model.parameters()):,}")

    if args.load_from is not None:
        config.load_from = args.load_from
    if config.load_from is not None:
        missing, unexpected = load_checkpoint(
            config.load_from, model, load_ema=config.get('load_ema', False), max_length=max_length)
        logger.warning(f'Missing keys: {missing}')
        logger.warning(f'Unexpected keys: {unexpected}')

    # prepare for FSDP clip grad norm calculation
    # NOTE(review): this runs before ``accelerator.prepare(model)``; if
    # ``accelerator._models`` is only populated by ``prepare`` the loop is a
    # no-op -- confirm whether it should move after the prepare call.
    if accelerator.distributed_type == DistributedType.FSDP:
        for m in accelerator._models:
            m.clip_grad_norm_ = types.MethodType(clip_grad_norm_, m)

    # build dataloader
    set_data_root(config.data_root)
    dataset = build_dataset(
        config.data, resolution=image_size, aspect_ratio_type=config.aspect_ratio_type,
        real_prompt_ratio=config.real_prompt_ratio, max_length=max_length, config=config,
    )
    if config.multi_scale:
        batch_sampler = AspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset,
                                                batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio, drop_last=True,
                                                ratio_nums=dataset.ratio_nums, config=config, valid_num=config.valid_num)
        train_dataloader = build_dataloader(dataset, batch_sampler=batch_sampler, num_workers=config.num_workers)
    else:
        train_dataloader = build_dataloader(dataset, num_workers=config.num_workers, batch_size=config.train_batch_size, shuffle=True)

    # build optimizer and lr scheduler
    lr_scale_ratio = 1
    if config.get('auto_lr', None):
        lr_scale_ratio = auto_scale_lr(config.train_batch_size * get_world_size() * config.gradient_accumulation_steps,
                                       config.optimizer, **config.auto_lr)
    optimizer = build_optimizer(model, config.optimizer)
    lr_scheduler = build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio)

    timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())

    if accelerator.is_main_process:
        tracker_config = dict(vars(config))
        try:
            accelerator.init_trackers(args.tracker_project_name, tracker_config)
        except Exception as err:
            # Best-effort fallback: retry without the config payload. Catch
            # ``Exception`` (not a bare except) so Ctrl-C / SystemExit propagate.
            logger.warning(f"init_trackers with config failed ({err!r}); retrying without config")
            accelerator.init_trackers(f"tb_{timestamp}")

    start_epoch = 0
    start_step = 0
    skip_step = config.skip_step
    total_steps = len(train_dataloader) * config.num_epochs

    if config.resume_from is not None and config.resume_from['checkpoint'] is not None:
        resume_path = config.resume_from['checkpoint']
        path = os.path.basename(resume_path)
        # The epoch and step are parsed from the 2nd and 4th "_"-separated
        # fields of the checkpoint file name.
        start_epoch = int(path.replace('.pth', '').split("_")[1]) - 1
        start_step = int(path.replace('.pth', '').split("_")[3])
        _, missing, unexpected = load_checkpoint(**config.resume_from,
                                                 model=model,
                                                 optimizer=optimizer,
                                                 lr_scheduler=lr_scheduler,
                                                 max_length=max_length,
                                                 )

        logger.warning(f'Missing keys: {missing}')
        logger.warning(f'Unexpected keys: {unexpected}')

    # Prepare everything
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model = accelerator.prepare(model)
    optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
    train()