"""Training script for TiTok. | |
Copyright (2024) Bytedance Ltd. and/or its affiliates | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
Reference: | |
https://github.com/huggingface/open-muse | |
""" | |
import math
import os
from pathlib import Path

from accelerate.utils import set_seed
from accelerate import Accelerator

import torch
from omegaconf import OmegaConf

from utils.logger import setup_logger
from utils.train_utils import (
    get_config, create_pretrained_tokenizer,
    create_model_and_loss_module,
    create_optimizer, create_lr_scheduler, create_dataloader,
    create_evaluator, auto_resume, save_checkpoint,
    train_one_epoch)


def main():
    workspace = os.environ.get('WORKSPACE', '')
    if workspace:
        torch.hub.set_dir(workspace + "/models/hub")

    config = get_config()

    # Enable TF32 on Ampere GPUs.
    if config.training.enable_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
    output_dir = config.experiment.output_dir
    os.makedirs(output_dir, exist_ok=True)
    config.experiment.logging_dir = os.path.join(output_dir, "logs")

    # Whether to log to Wandb or TensorBoard.
    tracker = "tensorboard"
    if config.training.enable_wandb:
        tracker = "wandb"

    accelerator = Accelerator(
        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
        mixed_precision=config.training.mixed_precision,
        log_with=tracker,
        project_dir=config.experiment.logging_dir,
        split_batches=False,
    )
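    # With split_batches=False, `per_gpu_batch_size` is the batch size seen by
    # each individual process, so the effective global batch size is
    # per_gpu_batch_size * num_processes * gradient_accumulation_steps
    # (the same quantity logged below before training starts).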
    logger = setup_logger(name="TiTok", log_level="INFO",
                          output_file=f"{output_dir}/log{accelerator.process_index}.txt")

    # We need to initialize the trackers we use and also store our configuration.
    # The trackers are initialized automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers(config.experiment.name)
        config_path = Path(output_dir) / "config.yaml"
        logger.info(f"Saving config to {config_path}")
        OmegaConf.save(config, config_path)
        logger.info(f"Config:\n{OmegaConf.to_yaml(config)}")
    # If passed along, set the training seed now.
    if config.training.seed is not None:
        set_seed(config.training.seed, device_specific=True)

    if accelerator.local_process_index == 0:
        # Download the pretrained MaskGIT-VQGAN tokenizer weights.
        from huggingface_hub import hf_hub_download
        hf_hub_download(repo_id="fun-research/TiTok", filename=f"{config.model.vq_model.pretrained_tokenizer_weight}", local_dir="./")
    accelerator.wait_for_everyone()
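    # The download runs on local rank 0 only; wait_for_everyone() above acts as
    # a barrier so the other ranks do not try to load the weights before the
    # file exists on disk.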
    pretrained_tokenizer = create_pretrained_tokenizer(config,
                                                       accelerator)

    model, ema_model, loss_module = create_model_and_loss_module(
        config, logger, accelerator, model_type="titok")

    optimizer, discriminator_optimizer = create_optimizer(config, logger, model, loss_module)

    lr_scheduler, discriminator_lr_scheduler = create_lr_scheduler(
        config, logger, accelerator, optimizer, discriminator_optimizer)

    train_dataloader, eval_dataloader = create_dataloader(config, logger, accelerator)

    # Set up evaluator.
    evaluator = create_evaluator(config, logger, accelerator)
    # Prepare everything with accelerator.
    logger.info("Preparing model, optimizer and dataloaders")
    # The dataloaders are already aware of distributed training, so we don't need to prepare them.
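    # When the decoder is being finetuned, the loss module carries a trainable
    # discriminator, so the loss module and the discriminator optimizer/scheduler
    # also have to be wrapped by accelerate; otherwise only the generator path
    # needs to be prepared.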
    if config.model.vq_model.finetune_decoder:
        model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler = accelerator.prepare(
            model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler
        )
    else:
        model, optimizer, lr_scheduler = accelerator.prepare(
            model, optimizer, lr_scheduler
        )

    if config.training.use_ema:
        ema_model.to(accelerator.device)
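    # One "epoch" below is defined by config.experiment.max_train_examples
    # rather than by the dataset length: each non-accumulated step consumes
    # per_gpu_batch_size * num_processes examples, and an optimizer update
    # happens every gradient_accumulation_steps such steps.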
    total_batch_size_without_accum = config.training.per_gpu_batch_size * accelerator.num_processes
    num_batches = math.ceil(
        config.experiment.max_train_examples / total_batch_size_without_accum)
    num_update_steps_per_epoch = math.ceil(num_batches / config.training.gradient_accumulation_steps)
    num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch)

    # Start training.
    logger.info("***** Running training *****")
    logger.info(f" Num training steps = {config.training.max_train_steps}")
    logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}")
    logger.info(f" Instantaneous batch size per gpu = {config.training.per_gpu_batch_size}")
    logger.info(f""" Total train batch size (w. parallel, distributed & accumulation) = {(
        config.training.per_gpu_batch_size *
        accelerator.num_processes *
        config.training.gradient_accumulation_steps)}""")
    global_step = 0
    first_epoch = 0
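    # Resume from the most recent checkpoint under output_dir, if one exists,
    # restoring the global step and the epoch to start from.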
    global_step, first_epoch = auto_resume(
        config, logger, accelerator, ema_model, num_update_steps_per_epoch,
        strict=True)

    for current_epoch in range(first_epoch, num_train_epochs):
        accelerator.print(f"Epoch {current_epoch}/{num_train_epochs-1} started.")
        global_step = train_one_epoch(config, logger, accelerator,
                                      model, ema_model, loss_module,
                                      optimizer, discriminator_optimizer,
                                      lr_scheduler, discriminator_lr_scheduler,
                                      train_dataloader, eval_dataloader,
                                      evaluator,
                                      global_step,
                                      pretrained_tokenizer=pretrained_tokenizer)
        # Stop training if max steps is reached.
        if global_step >= config.training.max_train_steps:
            accelerator.print(
                f"Finishing training: Global step is >= Max train steps: {global_step} >= {config.training.max_train_steps}"
            )
            break
    accelerator.wait_for_everyone()
    # Save checkpoint at the end of training.
    save_checkpoint(model, output_dir, accelerator, global_step, logger=logger)
    # Save the final trained checkpoint.
    if accelerator.is_main_process:
        model = accelerator.unwrap_model(model)
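        # If EMA is enabled, export the averaged (EMA) weights instead of the
        # last raw online weights.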
        if config.training.use_ema:
            ema_model.copy_to(model.parameters())
        model.save_pretrained_weight(output_dir)
    accelerator.end_training()


if __name__ == "__main__":
    main()
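

# Launch sketch (hypothetical paths and flags; the exact command-line format
# depends on how get_config() in utils/train_utils.py parses arguments and on
# your accelerate configuration):
#
#   accelerate launch --num_processes=8 train_titok.py \
#       config=configs/training/stage1/titok_l32.yaml \
#       experiment.output_dir=./titok_l32_stage1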