import os
import sys
import types
from pathlib import Path

current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent))

import argparse
import datetime
import time
import warnings
warnings.filterwarnings("ignore")  # ignore warning

import torch
import torch.nn as nn
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.utils import DistributedType
from diffusers.models import AutoencoderKL
from transformers import T5EncoderModel, T5Tokenizer
from torch.utils.data import RandomSampler
from mmcv.runner import LogBuffer
from copy import deepcopy
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
import gc

from diffusion import IDDPM
from diffusion.utils.checkpoint import save_checkpoint, load_checkpoint
from diffusion.utils.dist_utils import synchronize, get_world_size, clip_grad_norm_, flush
from diffusion.data.builder import build_dataset, build_dataloader, set_data_root
from diffusion.model.builder import build_model
from diffusion.utils.logger import get_root_logger
from diffusion.utils.misc import set_random_seed, read_config, init_random_seed, DebugUnderflowOverflow
from diffusion.utils.optimizer import build_optimizer, auto_scale_lr
from diffusion.utils.lr_scheduler import build_lr_scheduler
from diffusion.utils.data_sampler import AspectRatioBatchSampler, BalancedAspectRatioBatchSampler
from diffusion.lcm_scheduler import LCMScheduler
from torchvision.utils import save_image


def set_fsdp_env():
    os.environ["ACCELERATE_USE_FSDP"] = 'true'
    os.environ["FSDP_AUTO_WRAP_POLICY"] = 'TRANSFORMER_BASED_WRAP'
    os.environ["FSDP_BACKWARD_PREFETCH"] = 'BACKWARD_PRE'
    os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = 'PixArtBlock'


def ema_update(model_dest: nn.Module, model_src: nn.Module, rate):
    param_dict_src = dict(model_src.named_parameters())
    for p_name, p_dest in model_dest.named_parameters():
        p_src = param_dict_src[p_name]
        assert p_src is not p_dest
        p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
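

# Note: ema_update applies the exponential moving average
#     p_ema <- rate * p_ema + (1 - rate) * p_student
# in-place on `model_dest`. It is used below to keep `model_ema` (the
# consistency-distillation target network) trailing the online student weights.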


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]


# From LCMScheduler.get_scalings_for_boundary_condition_discrete
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
    c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
    return c_skip, c_out
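

# Note: c_skip and c_out are the consistency-model boundary scalings. As the
# timestep approaches 0, c_skip -> 1 and c_out -> 0, so the parameterization
# reduces to the identity and the boundary condition f(x, 0) = x holds.
# Dividing the timestep by 0.1 is equivalent to multiplying by the default
# timestep_scaling of 10; the `timestep_scaling` argument is unused in this form.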


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


class DDIMSolver:
    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
        # DDIM sampling parameters
        step_ratio = timesteps // ddim_timesteps
        self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
        self.ddim_alpha_cumprods_prev = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
        )
        # convert to torch tensors
        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)

    def to(self, device):
        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
        return self

    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
        return x_prev
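

# Note: ddim_step is the deterministic (eta = 0) DDIM update
#     x_prev = sqrt(alpha_bar_prev) * x0_hat + sqrt(1 - alpha_bar_prev) * eps_hat,
# evaluated on the precomputed ddim_timesteps grid. It advances the teacher's
# estimate one grid step along the probability-flow ODE trajectory.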


def log_validation(model, step, device):
    torch.cuda.empty_cache()
    model = accelerator.unwrap_model(model).eval()
    scheduler = LCMScheduler(beta_start=0.0001, beta_end=0.02, beta_schedule="linear", prediction_type="epsilon")
    scheduler.set_timesteps(4, 50)
    infer_timesteps = scheduler.timesteps

    hw = torch.tensor([[1024, 1024]], dtype=torch.float, device=device).repeat(1, 1)
    ar = torch.tensor([[1.]], device=device).repeat(1, 1)

    # Create sampling noise:
    logger.info("Running validation... ")
    image_logs = []
    latents = []
    for prompt in validation_prompts:
        infer_latents = torch.randn(1, 4, latent_size, latent_size, device=device)
        embed = torch.load(f'output/tmp/{prompt}_{max_length}token.pth', map_location='cpu')
        caption_embs, emb_masks = embed['caption_embeds'].to(device), embed['emb_mask'].to(device)
        model_kwargs = dict(data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks)

        # 7. LCM MultiStep Sampling Loop:
        for i, t in tqdm(list(enumerate(infer_timesteps))):
            ts = torch.full((1,), t, device=device, dtype=torch.long)
            # model prediction (v-prediction, eps, x)
            model_pred = model(infer_latents, ts, caption_embs, **model_kwargs)[:, :4]
            # compute the previous noisy sample x_t -> x_t-1
            infer_latents, denoised = scheduler.step(model_pred, i, t, infer_latents, return_dict=False)
        latents.append(denoised)

    torch.cuda.empty_cache()
    vae = AutoencoderKL.from_pretrained(config.vae_pretrained).cuda()
    for prompt, latent in zip(validation_prompts, latents):
        samples = vae.decode(latent.detach() / vae.config.scaling_factor).sample
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()[0]
        image = Image.fromarray(samples)
        image_logs.append({"validation_prompt": prompt, "images": [image]})

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            for log in image_logs:
                images = log["images"]
                validation_prompt = log["validation_prompt"]
                formatted_images = []
                for image in images:
                    formatted_images.append(np.asarray(image))
                formatted_images = np.stack(formatted_images)
                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
        elif tracker.name == "wandb":
            import wandb
            formatted_images = []
            for log in image_logs:
                images = log["images"]
                validation_prompt = log["validation_prompt"]
                for image in images:
                    image = wandb.Image(image, caption=validation_prompt)
                    formatted_images.append(image)
            tracker.log({"validation": formatted_images})
        else:
            logger.warn(f"image logging not implemented for {tracker.name}")

    gc.collect()
    torch.cuda.empty_cache()
    return image_logs


def train():
    if config.get('debug_nan', False):
        DebugUnderflowOverflow(model)
        logger.info('NaN debugger registered. Start to detect overflow during training.')
    time_start, last_tic = time.time(), time.time()
    log_buffer = LogBuffer()

    start_step = start_epoch * len(train_dataloader)
    global_step = 0
    total_steps = len(train_dataloader) * config.num_epochs

    load_vae_feat = getattr(train_dataloader.dataset, 'load_vae_feat', False)
    load_t5_feat = getattr(train_dataloader.dataset, 'load_t5_feat', False)

    # Create uncond embeds for classifier free guidance
    uncond_prompt_embeds = model.module.y_embedder.y_embedding.repeat(config.train_batch_size, 1, 1, 1)
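    # Note: y_embedder.y_embedding is the model's learned null-caption embedding;
    # repeating it over the batch gives the unconditional prompt used for the
    # CFG-style teacher predictions in the distillation loop below.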

    # Now you train the model
    for epoch in range(start_epoch + 1, config.num_epochs + 1):
        data_time_start = time.time()
        data_time_all = 0
        for step, batch in enumerate(train_dataloader):
            data_time_all += time.time() - data_time_start
            if load_vae_feat:
                z = batch[0]
            else:
                with torch.no_grad():
                    with torch.cuda.amp.autocast(enabled=config.mixed_precision == 'fp16'):
                        posterior = vae.encode(batch[0]).latent_dist
                        if config.sample_posterior:
                            z = posterior.sample()
                        else:
                            z = posterior.mode()
            latents = z * config.scale_factor
            data_info = {'img_hw': batch[3]['img_hw'].to(latents.dtype),
                         'aspect_ratio': batch[3]['aspect_ratio'].to(latents.dtype)}
            if load_t5_feat:
                y = batch[1]
                y_mask = batch[2]
            else:
                with torch.no_grad():
                    txt_tokens = tokenizer(
                        batch[1], max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
                    ).to(accelerator.device)
                    y = text_encoder(
                        txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][:, None]
                    y_mask = txt_tokens.attention_mask[:, None, None]
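            # Note: depending on the dataset config, VAE latents and T5 text features are
            # either loaded pre-extracted from disk (load_vae_feat / load_t5_feat) or
            # computed on the fly; the extra singleton dimensions added above keep the
            # freshly-encoded features in the same layout as the cached ones.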

            # Sample a random timestep for each image
            grad_norm = None
            with accelerator.accumulate(model):
                # Predict the noise residual
                optimizer.zero_grad()

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]

                # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
                topk = config.train_sampling_steps // config.num_ddim_timesteps
                index = torch.randint(0, config.num_ddim_timesteps, (bsz,), device=latents.device).long()
                start_timesteps = solver.ddim_timesteps[index]
                timesteps = start_timesteps - topk
                timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
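                # Note: `index` picks t_{n+k} on the DDIM grid (k = topk skipped teacher
                # steps); the student's target timestep t_n = t_{n+k} - k is clamped at 0.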

                # Get boundary scalings for start_timesteps and (end) timesteps.
                c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
                c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
                c_skip, c_out = scalings_for_boundary_conditions(timesteps)
                c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

                # Sample a random guidance scale w from U[w_min, w_max] and embed it
                # w = (config.w_max - config.w_min) * torch.rand((bsz,)) + config.w_min
                w = config.cfg_scale * torch.ones((bsz,))
                w = w.reshape(bsz, 1, 1, 1)
                w = w.to(device=latents.device, dtype=latents.dtype)
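                # Note: unlike the LCM paper's w ~ U[w_min, w_max] sampling (commented out
                # above), this script distills at a single fixed guidance scale
                # config.cfg_scale, broadcast to (bsz, 1, 1, 1) for the CFG combination below.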

                # Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
                _, pred_x_0, noisy_model_input = train_diffusion.training_losses(
                    model, latents, start_timesteps,
                    model_kwargs=dict(y=y, mask=y_mask, data_info=data_info),
                    noise=noise
                )
                model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
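                # Note: training_losses (built with return_startx=True) also returns the
                # model's x0 prediction and the noised input, so model_pred is the student's
                # consistency parameterization f_theta(z_{t_{n+k}}) = c_skip * z + c_out * x0_hat.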

                # Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
                # noisy_latents with both the conditioning embedding c and unconditional embedding 0
                # Get teacher model prediction on noisy_latents and conditional embedding
                with torch.no_grad():
                    with torch.autocast("cuda"):
                        cond_teacher_output, cond_pred_x0, _ = train_diffusion.training_losses(
                            model_teacher, latents, start_timesteps,
                            model_kwargs=dict(y=y, mask=y_mask, data_info=data_info),
                            noise=noise
                        )

                        # Get teacher model prediction on noisy_latents and unconditional embedding
                        uncond_teacher_output, uncond_pred_x0, _ = train_diffusion.training_losses(
                            model_teacher, latents, start_timesteps,
                            model_kwargs=dict(y=uncond_prompt_embeds, mask=y_mask, data_info=data_info),
                            noise=noise
                        )

                        # Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
                        pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
                        pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
                        x_prev = solver.ddim_step(pred_x0, pred_noise, index)
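                # Note: x_prev is the teacher's one-step ODE estimate of z_{t_n}: the
                # conditional and unconditional predictions are combined with the guidance
                # weight w (LCM's CFG-augmented ODE), then stepped once with the DDIM solver.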

                # Get target LCM prediction on x_prev, w, c, t_n
                with torch.no_grad():
                    with torch.autocast("cuda", enabled=True):
                        _, pred_x_0, _ = train_diffusion.training_losses(
                            model_ema, x_prev.float(), timesteps,
                            model_kwargs=dict(y=y, mask=y_mask, data_info=data_info),
                            skip_noise=True
                        )
                        target = c_skip * x_prev + c_out * pred_x_0
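                # Note: the target is the EMA model's consistency output
                # f_{theta^-}(x_prev, t_n) = c_skip * x_prev + c_out * x0_hat; skip_noise=True
                # appears to tell this repo's training_losses to treat x_prev as the already
                # noised input rather than re-noising it.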

                # Calculate loss
                if config.loss_type == "l2":
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                elif config.loss_type == "huber":
                    loss = torch.mean(
                        torch.sqrt((model_pred.float() - target.float()) ** 2 + config.huber_c**2) - config.huber_c)
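                # Note: the "huber" option is the pseudo-Huber loss
                # sqrt((pred - target)^2 + c^2) - c, which behaves like L2 near zero and
                # like L1 for large residuals, with c = config.huber_c.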

                # Backpropagation on the online student model (`model`)
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.gradient_clip)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            if accelerator.sync_gradients:
                ema_update(model_ema, model, config.ema_decay)

            lr = lr_scheduler.get_last_lr()[0]
            logs = {"loss": accelerator.gather(loss).mean().item()}
            if grad_norm is not None:
                logs.update(grad_norm=accelerator.gather(grad_norm).mean().item())
            log_buffer.update(logs)

            if (step + 1) % config.log_interval == 0 or (step + 1) == 1:
                t = (time.time() - last_tic) / config.log_interval
                t_d = data_time_all / config.log_interval
                avg_time = (time.time() - time_start) / (global_step + 1)
                eta = str(datetime.timedelta(seconds=int(avg_time * (total_steps - start_step - global_step - 1))))
                eta_epoch = str(datetime.timedelta(seconds=int(avg_time * (len(train_dataloader) - step - 1))))
                # avg_loss = sum(loss_buffer) / len(loss_buffer)
                log_buffer.average()
                info = f"Step/Epoch [{(epoch-1)*len(train_dataloader)+step+1}/{epoch}][{step + 1}/{len(train_dataloader)}]:total_eta: {eta}, " \
                       f"epoch_eta:{eta_epoch}, time_all:{t:.3f}, time_data:{t_d:.3f}, lr:{lr:.3e}, s:({data_info['img_hw'][0][0].item()}, {data_info['img_hw'][0][1].item()}), "
                info += ', '.join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()])
                logger.info(info)
                last_tic = time.time()
                log_buffer.clear()
                data_time_all = 0
            logs.update(lr=lr)
            accelerator.log(logs, step=global_step + start_step)

            global_step += 1
            data_time_start = time.time()

            accelerator.wait_for_everyone()
            if accelerator.is_main_process:
                if ((epoch - 1) * len(train_dataloader) + step + 1) % config.save_model_steps == 0:
                    os.umask(0o000)
                    save_checkpoint(os.path.join(config.work_dir, 'checkpoints'),
                                    epoch=epoch,
                                    step=(epoch - 1) * len(train_dataloader) + step + 1,
                                    model=accelerator.unwrap_model(model),
                                    model_ema=accelerator.unwrap_model(model_ema),
                                    optimizer=optimizer,
                                    lr_scheduler=lr_scheduler)
                if ((epoch - 1) * len(train_dataloader) + step + 1) % config.eval_sampling_steps == 0:
                    log_validation(model, global_step, device=accelerator.device)

        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            if epoch % config.save_model_epochs == 0 or epoch == config.num_epochs:
                os.umask(0o000)
                save_checkpoint(os.path.join(config.work_dir, 'checkpoints'),
                                epoch=epoch,
                                step=(epoch - 1) * len(train_dataloader) + step + 1,
                                model=accelerator.unwrap_model(model),
                                model_ema=accelerator.unwrap_model(model_ema),
                                optimizer=optimizer,
                                lr_scheduler=lr_scheduler)
        synchronize()


def parse_args():
    parser = argparse.ArgumentParser(description="Process some integers.")
    parser.add_argument("config", type=str, help="config")
    parser.add_argument("--cloud", action='store_true', default=False, help="cloud or local machine")
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument('--resume-from', help='the dir to resume the training')
    parser.add_argument('--load-from', default=None, help='the dir to load a ckpt for training')
    parser.add_argument('--local-rank', type=int, default=-1)
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument(
        "--pipeline_load_from", default='output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers',
        type=str, help="Download for loading text_encoder, "
                       "tokenizer and vae from https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"
    )
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    config = read_config(args.config)
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        config.work_dir = args.work_dir
    if args.cloud:
        config.data_root = '/data/data'
    if args.resume_from is not None:
        config.load_from = None
        config.resume_from = dict(
            checkpoint=args.resume_from,
            load_ema=False,
            resume_optimizer=True,
            resume_lr_scheduler=True)
    if args.debug:
        config.log_interval = 1
        config.train_batch_size = 2

    os.umask(0o000)
    os.makedirs(config.work_dir, exist_ok=True)

    init_handler = InitProcessGroupKwargs()
    init_handler.timeout = datetime.timedelta(seconds=5400)  # change timeout to avoid a strange NCCL bug

    # Initialize accelerator and tensorboard logging
    if config.use_fsdp:
        init_train = 'FSDP'
        from accelerate import FullyShardedDataParallelPlugin
        from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
        set_fsdp_env()
        fsdp_plugin = FullyShardedDataParallelPlugin(
            state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False))
    else:
        init_train = 'DDP'
        fsdp_plugin = None

    even_batches = True
    if config.multi_scale:
        even_batches = False
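    # Note: with aspect-ratio bucketing (multi_scale), batches drawn from different
    # ratio buckets can be uneven across processes, so Accelerate's even_batches
    # behaviour is disabled in that case.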

    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.work_dir, "logs"),
        fsdp_plugin=fsdp_plugin,
        even_batches=even_batches,
        kwargs_handlers=[init_handler]
    )

    logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))

    config.seed = init_random_seed(config.get('seed', None))
    set_random_seed(config.seed)

    if accelerator.is_main_process:
        config.dump(os.path.join(config.work_dir, 'config.py'))

    logger.info(f"Config: \n{config.pretty_text}")
    logger.info(f"World_size: {get_world_size()}, seed: {config.seed}")
    logger.info(f"Initializing: {init_train} for training")

    image_size = config.image_size  # @param [256, 512]
    latent_size = int(image_size) // 8
    pred_sigma = getattr(config, 'pred_sigma', True)
    learn_sigma = getattr(config, 'learn_sigma', True) and pred_sigma
    max_length = config.model_max_length
    model_kwargs = {"pe_interpolation": config.pe_interpolation, 'config': config, 'model_max_length': max_length}

    # build models
    train_diffusion = IDDPM(str(config.train_sampling_steps), learn_sigma=learn_sigma, pred_sigma=pred_sigma,
                            snr=config.snr_loss, return_startx=True)
    model = build_model(config.model,
                        config.grad_checkpointing,
                        config.get('fp32_attention', False),
                        input_size=latent_size,
                        learn_sigma=learn_sigma,
                        pred_sigma=pred_sigma,
                        **model_kwargs).train()
    logger.info(f"{model.__class__.__name__} Model Parameters: {sum(p.numel() for p in model.parameters()):,}")

    if config.load_from is not None:
        if args.load_from is not None:
            config.load_from = args.load_from
        missing, unexpected = load_checkpoint(
            config.load_from, model, load_ema=config.get('load_ema', False), max_length=max_length)
        logger.warning(f'Missing keys: {missing}')
        logger.warning(f'Unexpected keys: {unexpected}')

    model_ema = deepcopy(model).eval()
    model_teacher = deepcopy(model).eval()
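    # Note: model_ema is the target network updated by ema_update() during training,
    # and model_teacher is a frozen copy of the pretrained weights that provides the
    # ODE-solver predictions for latent consistency distillation.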

    if not config.data.load_vae_feat:
        vae = AutoencoderKL.from_pretrained(config.vae_pretrained).cuda()

    # prepare for FSDP clip grad norm calculation
    if accelerator.distributed_type == DistributedType.FSDP:
        for m in accelerator._models:
            m.clip_grad_norm_ = types.MethodType(clip_grad_norm_, m)

    tokenizer = text_encoder = None
    if not config.data.load_t5_feat:
        tokenizer = T5Tokenizer.from_pretrained(args.pipeline_load_from, subfolder="tokenizer")
        text_encoder = T5EncoderModel.from_pretrained(
            args.pipeline_load_from, subfolder="text_encoder", torch_dtype=torch.float16).to(accelerator.device)

    logger.info(f"vae scale factor: {config.scale_factor}")

    # build dataloader
    set_data_root(config.data_root)
    dataset = build_dataset(config.data, resolution=image_size, aspect_ratio_type=config.aspect_ratio_type)
    if config.multi_scale:
        batch_sampler = AspectRatioBatchSampler(sampler=RandomSampler(dataset), dataset=dataset,
                                                batch_size=config.train_batch_size, aspect_ratios=dataset.aspect_ratio,
                                                drop_last=True, ratio_nums=dataset.ratio_nums, config=config,
                                                valid_num=config.valid_num)
        train_dataloader = build_dataloader(dataset, batch_sampler=batch_sampler, num_workers=config.num_workers)
    else:
        train_dataloader = build_dataloader(dataset, num_workers=config.num_workers, batch_size=config.train_batch_size, shuffle=True)

    # preparing embeddings for visualization. We put it here for saving GPU memory
    validation_prompts = [
        "dog",
        "portrait photo of a girl, photograph, highly detailed face, depth of field",
        "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
        "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
    ]
logger.info("Preparing Visulalization prompt embeddings...") | |
skip = True | |
for prompt in validation_prompts: | |
if not os.path.exists(f'output/tmp/{prompt}_{max_length}token.pth'): | |
skip = False | |
break | |
logger.info("Preparing Visualization prompt embeddings...") | |
    if accelerator.is_main_process and not skip:
        if config.data.load_t5_feat and (tokenizer is None or text_encoder is None):
            logger.info(f"Loading text encoder and tokenizer from {args.pipeline_load_from} ...")
            tokenizer = T5Tokenizer.from_pretrained(args.pipeline_load_from, subfolder="tokenizer")
            text_encoder = T5EncoderModel.from_pretrained(
                args.pipeline_load_from, subfolder="text_encoder", torch_dtype=torch.float16).to(accelerator.device)
        os.makedirs('output/tmp', exist_ok=True)  # make sure the embedding cache directory exists
        for prompt in validation_prompts:
            txt_tokens = tokenizer(
                prompt, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
            ).to(accelerator.device)
            caption_emb = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0]
            torch.save(
                {'caption_embeds': caption_emb, 'emb_mask': txt_tokens.attention_mask},
                f'output/tmp/{prompt}_{max_length}token.pth')
        if config.data.load_t5_feat:
            del tokenizer
            del txt_tokens
        flush()
        time.sleep(5)

    # build optimizer and lr scheduler
    lr_scale_ratio = 1
    if config.get('auto_lr', None):
        lr_scale_ratio = auto_scale_lr(config.train_batch_size * get_world_size() * config.gradient_accumulation_steps,
                                       config.optimizer,
                                       **config.auto_lr)
    optimizer = build_optimizer(model, config.optimizer)
    lr_scheduler = build_lr_scheduler(config, optimizer, train_dataloader, lr_scale_ratio)

    timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())

    if accelerator.is_main_process:
        accelerator.init_trackers(f"tb_{timestamp}")

    start_epoch = 0
    if config.resume_from is not None and config.resume_from['checkpoint'] is not None:
        start_epoch, missing, unexpected = load_checkpoint(**config.resume_from,
                                                           model=model,
                                                           model_ema=model_ema,
                                                           optimizer=optimizer,
                                                           lr_scheduler=lr_scheduler)
        logger.warning(f'Missing keys: {missing}')
        logger.warning(f'Unexpected keys: {unexpected}')

    solver = DDIMSolver(train_diffusion.alphas_cumprod,
                        timesteps=config.train_sampling_steps,
                        ddim_timesteps=config.num_ddim_timesteps)
    solver.to(accelerator.device)

    # Prepare everything
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model, model_ema, model_teacher = accelerator.prepare(model, model_ema, model_teacher)
    optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

    train()