# Commit 94e1b95 by RohitGandikota: "fixing training" (file size 15.3 kB).
# ref:
# - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
# - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py
from typing import List, Optional
import argparse
import ast
from pathlib import Path
import gc
import torch
from tqdm import tqdm
from trainscripts.textsliders.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
import trainscripts.textsliders.train_util as train_util
import trainscripts.textsliders.model_util as model_util
import trainscripts.textsliders.prompt_util as prompt_util
from trainscripts.textsliders.prompt_util import (
PromptEmbedsCache,
PromptEmbedsPair,
PromptSettings,
PromptEmbedsXL,
)
import trainscripts.textsliders.debug_util as debug_util
import trainscripts.textsliders.config_util as config_util
from trainscripts.textsliders.config_util import RootConfig
import wandb
# Number of images per prompt requested from train_util.encode_prompts_xl
# (its num_images_per_prompt argument) when the prompt embeddings are built.
NUM_IMAGES_PER_PROMPT = 1
def flush():
    """Release unused memory: run Python's garbage collector, then empty PyTorch's CUDA cache.

    ``gc.collect()`` runs first so that tensors kept alive only by reference
    cycles are actually freed before ``torch.cuda.empty_cache()`` hands the
    cached blocks they occupied back to the device.
    """
    gc.collect()
    torch.cuda.empty_cache()
def _parse_optimizer_kwargs(optimizer_args: Optional[str]) -> dict:
    """Parse ``config.train.optimizer_args`` into keyword arguments for the optimizer.

    The string holds whitespace-separated ``key=value`` pairs. Each value is
    evaluated with ``ast.literal_eval``, so ``"weight_decay=0.01 betas=(0.9,0.99)"``
    becomes ``{"weight_decay": 0.01, "betas": (0.9, 0.99)}``. ``None`` or an
    empty string yields ``{}``.

    Raises:
        ValueError: if a token contains no ``=``; errors from
            ``ast.literal_eval`` on a malformed value propagate unchanged.
    """
    kwargs = {}
    if not optimizer_args:
        return kwargs
    # str.split() with no argument tolerates repeated, leading and trailing blanks.
    for arg in optimizer_args.split():
        # Split on the first "=" only, so a value may itself contain "=".
        key, sep, value = arg.partition("=")
        if not sep:
            raise ValueError(f"optimizer arg {arg!r} is not of the form key=value")
        kwargs[key] = ast.literal_eval(value)
    return kwargs


def _freeze_text_encoders(text_encoders, device, dtype) -> None:
    """Move each text encoder to *device*/*dtype*, disable its gradients and set eval mode."""
    # Kept in its own function so the loop variable does not hold a reference to
    # an encoder inside train() after the encoders are deleted there.
    for text_encoder in text_encoders:
        text_encoder.to(device, dtype=dtype)
        text_encoder.requires_grad_(False)
        text_encoder.eval()


def _encode_prompt_pairs(prompts, tokenizers, text_encoders, criteria) -> list:
    """Encode every prompt referenced by *prompts* and build one PromptEmbedsPair per setting.

    Each distinct prompt string is encoded only once (results are memoised in a
    PromptEmbedsCache). Encoding runs under ``torch.no_grad()``.
    """
    cache = PromptEmbedsCache()
    prompt_pairs = []
    with torch.no_grad():
        for settings in prompts:
            print(settings)
            for prompt in (
                settings.target,
                settings.positive,
                settings.neutral,
                settings.unconditional,
            ):
                if cache[prompt] is None:
                    text_embeds, pooled_embeds = train_util.encode_prompts_xl(
                        tokenizers,
                        text_encoders,
                        [prompt],
                        num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
                    )
                    cache[prompt] = PromptEmbedsXL(text_embeds, pooled_embeds)
            prompt_pairs.append(
                PromptEmbedsPair(
                    criteria,
                    cache[settings.target],
                    cache[settings.positive],
                    cache[settings.unconditional],
                    cache[settings.neutral],
                    settings,
                )
            )
    return prompt_pairs


def _cfg_inputs(prompt_pair, conditional, add_time_ids) -> dict:
    """Build the conditioning kwargs shared by train_util.diffusion_xl / predict_noise_xl.

    Each entry pairs the unconditional embedding with *conditional* (a
    PromptEmbedsXL) via ``train_util.concat_embeddings`` for
    ``prompt_pair.batch_size``; ``add_time_ids`` is paired with itself.
    """
    batch_size = prompt_pair.batch_size
    return {
        "text_embeddings": train_util.concat_embeddings(
            prompt_pair.unconditional.text_embeds,
            conditional.text_embeds,
            batch_size,
        ),
        "add_text_embeddings": train_util.concat_embeddings(
            prompt_pair.unconditional.pooled_embeds,
            conditional.pooled_embeds,
            batch_size,
        ),
        "add_time_ids": train_util.concat_embeddings(
            add_time_ids, add_time_ids, batch_size
        ),
    }


def train(
    config: RootConfig,
    prompts: list[PromptSettings],
    device,
):
    """Train a LoRA network on an SDXL UNet for the given prompt settings and save it.

    Each iteration does the following:
    1. Sample a prompt pair.
    2. Partially denoise random latents for a random number of steps with the
       LoRA active.
    3. Predict noise for the positive, neutral and unconditional prompts outside
       ``with network:``, without gradients.
    4. Predict noise for the target prompt inside ``with network:``.
    5. Combine the predictions with ``prompt_pair.loss`` and use the result to
       update the LoRA parameters.

    The final weights are written to ``Path(config.save.path) / config.save.name``.

    Args:
        config: full training configuration (pretrained_model, network, train,
            save, logging and other sections are read).
        prompts: prompt settings; one PromptEmbedsPair is built per entry.
        device: torch device the models and tensors are moved to.
    """
    metadata = {
        "prompts": ",".join([prompt.json() for prompt in prompts]),
        "config": config.json(),
    }
    save_path = Path(config.save.path)

    modules = DEFAULT_TARGET_REPLACE
    if config.network.type == "c3lier":
        # NOTE(review): `modules` is never read below. If DEFAULT_TARGET_REPLACE
        # is a list, `+=` extends that imported module-level object in place
        # (again on every call). LoRANetwork may rely on that side effect, so it
        # is kept as-is -- verify in lora.py before changing it.
        modules += UNET_TARGET_REPLACE_MODULE_CONV

    if config.logging.verbose:
        print(metadata)

    if config.logging.use_wandb:
        wandb.init(project=f"LECO_{config.save.name}", config=metadata)

    weight_dtype = config_util.parse_precision(config.train.precision)
    # NOTE(review): the saved weights also use the *training* precision; confirm
    # whether a separate save precision was intended.
    save_weight_dtype = config_util.parse_precision(config.train.precision)

    (
        tokenizers,
        text_encoders,
        unet,
        noise_scheduler,
    ) = model_util.load_models_xl(
        config.pretrained_model.name_or_path,
        scheduler_name=config.train.noise_scheduler,
    )

    _freeze_text_encoders(text_encoders, device, weight_dtype)

    unet.to(device, dtype=weight_dtype)
    if config.other.use_xformers:
        unet.enable_xformers_memory_efficient_attention()
    unet.requires_grad_(False)
    unet.eval()

    network = LoRANetwork(
        unet,
        rank=config.network.rank,
        multiplier=1.0,
        alpha=config.network.alpha,
        train_method=config.network.training_method,
    ).to(device, dtype=weight_dtype)

    optimizer_module = train_util.get_optimizer(config.train.optimizer)
    optimizer_kwargs = _parse_optimizer_kwargs(config.train.optimizer_args)
    optimizer = optimizer_module(
        network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs
    )
    lr_scheduler = train_util.get_lr_scheduler(
        config.train.lr_scheduler,
        optimizer,
        max_iterations=config.train.iterations,
        lr_min=config.train.lr / 100,
    )
    criteria = torch.nn.MSELoss()

    print("Prompts")
    for settings in prompts:
        print(settings)

    # debug
    debug_util.check_requires_grad(network)
    debug_util.check_training_mode(network)

    prompt_pairs: list[PromptEmbedsPair] = _encode_prompt_pairs(
        prompts, tokenizers, text_encoders, criteria
    )

    # The tokenizers / text encoders are only needed to build the embeddings
    # above. Deleting the lists themselves (not just loop variables) drops the
    # last references held here, so flush() can release their memory.
    del tokenizers, text_encoders
    flush()

    pbar = tqdm(range(config.train.iterations))
    loss = None

    for i in pbar:
        with torch.no_grad():
            noise_scheduler.set_timesteps(
                config.train.max_denoising_steps, device=device
            )

            optimizer.zero_grad()

            prompt_pair: PromptEmbedsPair = prompt_pairs[
                torch.randint(0, len(prompt_pairs), (1,)).item()
            ]

            # Random step count in [1, max_denoising_steps - 1]
            # (1 ~ 49 when max_denoising_steps is 50).
            timesteps_to = torch.randint(
                1, config.train.max_denoising_steps, (1,)
            ).item()

            height, width = prompt_pair.resolution, prompt_pair.resolution
            if prompt_pair.dynamic_resolution:
                height, width = train_util.get_random_resolution_in_bucket(
                    prompt_pair.resolution
                )

            if config.logging.verbose:
                print("gudance_scale:", prompt_pair.guidance_scale)
                print("resolution:", prompt_pair.resolution)
                print("dynamic_resolution:", prompt_pair.dynamic_resolution)
                if prompt_pair.dynamic_resolution:
                    print("bucketed resolution:", (height, width))
                print("batch_size:", prompt_pair.batch_size)
                print("dynamic_crops:", prompt_pair.dynamic_crops)

            latents = train_util.get_initial_latents(
                noise_scheduler, prompt_pair.batch_size, height, width, 1
            ).to(device, dtype=weight_dtype)

            add_time_ids = train_util.get_add_time_ids(
                height,
                width,
                dynamic_crops=prompt_pair.dynamic_crops,
                dtype=weight_dtype,
            ).to(device, dtype=weight_dtype)

            with network:
                # Returns latents that have been denoised a little
                # (timesteps_to steps), starting from plain noise latents.
                denoised_latents = train_util.diffusion_xl(
                    unet,
                    noise_scheduler,
                    latents,  # plain noise latents
                    **_cfg_inputs(prompt_pair, prompt_pair.target, add_time_ids),
                    start_timesteps=0,
                    total_timesteps=timesteps_to,
                    guidance_scale=3,
                )

            noise_scheduler.set_timesteps(1000)

            current_timestep = noise_scheduler.timesteps[
                int(timesteps_to * 1000 / config.train.max_denoising_steps)
            ]

            # Outside `with network:` only the empty LoRA is in effect.
            positive_latents = train_util.predict_noise_xl(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents,
                **_cfg_inputs(prompt_pair, prompt_pair.positive, add_time_ids),
                guidance_scale=1,
            ).to(device, dtype=weight_dtype)

            neutral_latents = train_util.predict_noise_xl(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents,
                **_cfg_inputs(prompt_pair, prompt_pair.neutral, add_time_ids),
                guidance_scale=1,
            ).to(device, dtype=weight_dtype)

            unconditional_latents = train_util.predict_noise_xl(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents,
                **_cfg_inputs(prompt_pair, prompt_pair.unconditional, add_time_ids),
                guidance_scale=1,
            ).to(device, dtype=weight_dtype)

            if config.logging.verbose:
                print("positive_latents:", positive_latents[0, 0, :5, :5])
                print("neutral_latents:", neutral_latents[0, 0, :5, :5])
                print("unconditional_latents:", unconditional_latents[0, 0, :5, :5])

        # The target prediction runs with gradients enabled so the loss can
        # backpropagate into the LoRA parameters.
        with network:
            target_latents = train_util.predict_noise_xl(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents,
                **_cfg_inputs(prompt_pair, prompt_pair.target, add_time_ids),
                guidance_scale=1,
            ).to(device, dtype=weight_dtype)

            if config.logging.verbose:
                print("target_latents:", target_latents[0, 0, :5, :5])

        positive_latents.requires_grad = False
        neutral_latents.requires_grad = False
        unconditional_latents.requires_grad = False

        loss = prompt_pair.loss(
            target_latents=target_latents,
            positive_latents=positive_latents,
            neutral_latents=neutral_latents,
            unconditional_latents=unconditional_latents,
        )

        # Shown multiplied by 1000; otherwise it reads as 0.000... and is not
        # informative to look at.
        loss_value = loss.item()
        pbar.set_description(f"Loss*1k: {loss_value*1000:.4f}")
        if config.logging.use_wandb:
            # Log a plain float rather than the graph-attached tensor.
            wandb.log(
                {"loss": loss_value, "iteration": i, "lr": lr_scheduler.get_last_lr()[0]}
            )

        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        del (
            positive_latents,
            neutral_latents,
            unconditional_latents,
            target_latents,
            latents,
        )
        flush()

        # Periodic checkpointing (every config.save.per_steps iterations) is
        # disabled; only the final weights below are written.

    print("Saving...")
    save_path.mkdir(parents=True, exist_ok=True)
    network.save_weights(
        save_path / f"{config.save.name}",
        dtype=save_weight_dtype,
    )

    if config.logging.use_wandb:
        # Close the run opened by wandb.init above.
        wandb.finish()

    del (
        unet,
        noise_scheduler,
        loss,
        optimizer,
        network,
    )

    flush()

    print("Done.")
# def main(args):
# config_file = args.config_file
# config = config_util.load_config_from_yaml(config_file)
# if args.name is not None:
# config.save.name = args.name
# attributes = []
# if args.attributes is not None:
# attributes = args.attributes.split(',')
# attributes = [a.strip() for a in attributes]
# config.network.alpha = args.alpha
# config.network.rank = args.rank
# config.save.name += f'_alpha{args.alpha}'
# config.save.name += f'_rank{config.network.rank }'
# config.save.name += f'_{config.network.training_method}'
# config.save.path += f'/{config.save.name}'
# prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)
# device = torch.device(f"cuda:{args.device}")
# train(config, prompts, device)
def train_xl(target, positive, negative, lr, iterations, config_file, rank, device, attributes, save_name):
    """Configure and launch SDXL slider training from simple scalar/string arguments.

    Loads the YAML config at *config_file*, then overrides the output name and
    path, learning rate, iteration count and LoRA rank. The LoRA alpha is fixed
    at 1.0. It builds the prompt settings for *target* / *positive* / *negative*
    and finally calls ``train``.

    Args:
        target: target prompt, forwarded to prompt_util.load_prompts_from_yaml.
        positive: positive prompt, forwarded likewise.
        negative: negative prompt, forwarded likewise.
        lr: learning rate; converted with ``float()``.
        iterations: number of training iterations; converted with ``int()``.
        config_file: path of the YAML training config.
        rank: LoRA rank; converted with ``int()``.
        device: anything accepted by ``torch.device`` (e.g. ``"cuda:0"``).
        attributes: comma-separated attribute string, or ``None``. Entries are
            stripped, and blank entries are ignored.
        save_name: becomes ``config.save.name`` (the saved weights' file name)
            and is also appended to ``config.save.path`` as a sub-directory.
    """
    config = config_util.load_config_from_yaml(config_file)

    config.save.name = save_name
    config.train.lr = float(lr)
    config.train.iterations = int(iterations)

    if attributes is None:
        attributes = []
    else:
        # Drop blank entries so that "" (or "a, ,b") does not yield empty attributes.
        attributes = [a.strip() for a in attributes.split(',') if a.strip()]

    config.network.alpha = 1.0
    config.network.rank = int(rank)

    config.save.path += f'/{config.save.name}'

    prompts = prompt_util.load_prompts_from_yaml(
        path=config.prompts_file,
        target=target,
        positive=positive,
        negative=negative,
        attributes=attributes,
    )

    device = torch.device(device)

    train(config, prompts, device)