Open-Sora-Plan-v1.1.0 / opensora /train /train_causalvae.py
LinB203
update
bab971b
raw
history blame
5.01 kB
import argparse
import os
import random
import sys
from dataclasses import dataclass, field, asdict
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed as dist
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
from transformers import HfArgumentParser

# Make the current working directory importable so the local `opensora`
# package below resolves when the script is launched from the repo root.
sys.path.append(".")

from opensora.models.ae.videobase import (
    CausalVAEModel,
)
from opensora.models.ae.videobase.dataset_videobase import VideoDataset
@dataclass
class TrainingArguments:
    """Command-line options for CausalVAE training.

    Parsed from the CLI by ``HfArgumentParser`` in the ``__main__`` block and
    passed to ``train``. Each field's ``help`` metadata becomes its CLI help.
    """

    exp_name: str = field(
        default="causalvae", metadata={"help": "The name of the experiment."}
    )
    batch_size: int = field(
        default=1, metadata={"help": "The number of samples per training iteration."}
    )
    # Forwarded verbatim to ``pl.Trainer(precision=...)``.
    precision: str = field(
        default="bf16",
        metadata={"help": "The precision type used for training."},
    )
    max_steps: int = field(
        default=100000,
        metadata={"help": "The maximum number of steps for the training process."},
    )
    # Used as ``ModelCheckpoint(every_n_train_steps=...)``.
    save_steps: int = field(
        default=2000,
        metadata={"help": "The interval at which to save the model during training."},
    )
    output_dir: str = field(
        default="results/causalvae",
        metadata={"help": "The directory where training results are saved."},
    )
    video_path: str = field(
        default="/remote-home1/dataset/data_split_tt",
        metadata={"help": "The path where the video data is stored."},
    )
    # Passed to VideoDataset as ``sequence_length``.
    video_num_frames: int = field(
        default=17, metadata={"help": "The number of frames per video."}
    )
    sample_rate: int = field(
        default=1,
        metadata={
            "help": "The sampling interval."
        },
    )
    dynamic_sample: bool = field(
        default=False, metadata={"help": "Whether to use dynamic sampling."}
    )
    # Only used when ``load_from_checkpoint`` is None (fresh model from config).
    model_config: str = field(
        default="scripts/causalvae/288.yaml",
        metadata={"help": "The path to the model configuration file."},
    )
    n_nodes: int = field(
        default=1, metadata={"help": "The number of nodes used for training."}
    )
    devices: int = field(
        default=8, metadata={"help": "The number of devices used for training."}
    )
    resolution: int = field(
        default=256, metadata={"help": "The resolution of the videos."}
    )
    num_workers: int = field(
        default=8,
        metadata={"help": "The number of subprocesses used for data handling."},
    )
    # Lightning checkpoint to resume the full training state from; None = fresh run.
    resume_from_checkpoint: Optional[str] = field(
        default=None, metadata={"help": "Resume training from a specified checkpoint."}
    )
    # Pretrained model (``CausalVAEModel.from_pretrained``) to start from; None = use model_config.
    load_from_checkpoint: Optional[str] = field(
        default=None, metadata={"help": "Load the model from a specified checkpoint."}
    )
def set_seed(seed=1006):
    """Seed the Python, NumPy and PyTorch global random number generators.

    Args:
        seed: Value used for all three generators (default 1006).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
def load_callbacks_and_logger(args):
    """Create the Lightning callbacks and the Weights & Biases logger.

    Args:
        args: Training options; reads ``output_dir``, ``save_steps`` and
            ``exp_name``.

    Returns:
        A ``(callbacks, logger)`` pair. ``callbacks`` is
        ``[ModelCheckpoint, LearningRateMonitor]`` and ``logger`` is a
        ``WandbLogger``.
    """
    callbacks = [
        # Step-based checkpointing into output_dir; save_top_k=-1 keeps every
        # checkpoint rather than pruning to the best k.
        ModelCheckpoint(
            dirpath=args.output_dir,
            filename="model-{epoch:02d}-{step}",
            every_n_train_steps=args.save_steps,
            save_top_k=-1,
            save_on_train_epoch_end=False,
        ),
        LearningRateMonitor(logging_interval="step"),
    ]
    wandb_logger = WandbLogger(name=args.exp_name, log_model=False)
    return callbacks, wandb_logger
def _is_rank_zero():
    """Return True if this process should emit once-per-job console output.

    Uses torch.distributed when it is initialised. Otherwise it falls back to
    the RANK / LOCAL_RANK environment variables and defaults to rank 0 when
    neither is set. NOTE(review): the fallback assumes the launcher
    (torchrun or Lightning's subprocess launcher) exports these before the
    process group exists; confirm for the Lightning version in use.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    rank = os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0"))
    return int(rank) == 0


def train(args):
    """Train a CausalVAE with PyTorch Lightning and export it in HF format.

    The model is built from ``args.load_from_checkpoint`` (pretrained weights)
    when given, otherwise from the ``args.model_config`` YAML. It is fitted on
    a ``VideoDataset`` read from ``args.video_path``, and finally saved with
    ``save_pretrained`` to ``<output_dir>/hf``.

    Args:
        args: A ``TrainingArguments`` instance.
    """
    set_seed()

    # Load Config: build the model exactly once, either from pretrained
    # weights or from the YAML config.
    if args.load_from_checkpoint is not None:
        model = CausalVAEModel.from_pretrained(args.load_from_checkpoint)
    else:
        model = CausalVAEModel.from_config(args.model_config)
    if _is_rank_zero():
        print(model)

    # Load Dataset
    dataset = VideoDataset(
        args.video_path,
        sequence_length=args.video_num_frames,
        resolution=args.resolution,
        sample_rate=args.sample_rate,
        dynamic_sample=args.dynamic_sample,
    )
    train_loader = DataLoader(
        dataset,
        shuffle=True,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        pin_memory=True,
    )

    # Load Callbacks and Logger
    callbacks, logger = load_callbacks_and_logger(args)

    # Load Trainer
    trainer = pl.Trainer(
        accelerator="cuda",
        devices=args.devices,
        num_nodes=args.n_nodes,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=5,
        precision=args.precision,
        max_steps=args.max_steps,
        strategy="ddp_find_unused_parameters_true",
    )

    # A None or empty resume path means "start fresh". Otherwise Lightning
    # restores the full training state from the given checkpoint.
    trainer.fit(model, train_loader, ckpt_path=args.resume_from_checkpoint or None)

    # Save Huggingface Model. Every DDP process runs this line after fit(),
    # so only the global-zero process writes, to avoid concurrent writes to
    # the same files.
    if trainer.is_global_zero:
        model.save_pretrained(os.path.join(args.output_dir, "hf"))
if __name__ == "__main__":
    # Parse the command line into a TrainingArguments instance and start training.
    cli_parser = HfArgumentParser(TrainingArguments)
    parsed_dataclasses = cli_parser.parse_args_into_dataclasses()
    training_args = parsed_dataclasses[0]
    train(training_args)