"""Train a CausalVAE video autoencoder with PyTorch Lightning.

Arguments are parsed from the command line via HfArgumentParser using the
TrainingArguments dataclass below; see the example invocation at the bottom
of this file.
"""
import sys

sys.path.append(".")

import os
import random
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
from transformers import HfArgumentParser

from opensora.models.ae.videobase import CausalVAEModel
from opensora.models.ae.videobase.dataset_videobase import VideoDataset


@dataclass
class TrainingArguments:
    exp_name: str = field(
        default="causalvae", metadata={"help": "The name of the experiment."}
    )
    batch_size: int = field(
        default=1, metadata={"help": "The number of samples per training iteration."}
    )
    precision: str = field(
        default="bf16",
        metadata={"help": "The precision type used for training."},
    )
    max_steps: int = field(
        default=100000,
        metadata={"help": "The maximum number of steps for the training process."},
    )
    save_steps: int = field(
        default=2000,
        metadata={"help": "The interval at which to save the model during training."},
    )
    output_dir: str = field(
        default="results/causalvae",
        metadata={"help": "The directory where training results are saved."},
    )
    video_path: str = field(
        default="/remote-home1/dataset/data_split_tt",
        metadata={"help": "The path where the video data is stored."},
    )
    video_num_frames: int = field(
        default=17, metadata={"help": "The number of frames per video."}
    )
    sample_rate: int = field(
        default=1,
        metadata={"help": "The temporal stride between sampled frames (1 keeps every frame)."},
    )
    dynamic_sample: bool = field(
        default=False, metadata={"help": "Whether to use dynamic sampling."}
    )
    model_config: str = field(
        default="scripts/causalvae/288.yaml",
        metadata={"help": "The path to the model configuration file."},
    )
    n_nodes: int = field(
        default=1, metadata={"help": "The number of nodes used for training."}
    )
    devices: int = field(
        default=8, metadata={"help": "The number of devices used for training."}
    )
    resolution: int = field(
        default=256, metadata={"help": "The resolution of the videos."}
    )
    num_workers: int = field(
        default=8,
        metadata={"help": "The number of subprocesses used for data handling."},
    )
    resume_from_checkpoint: Optional[str] = field(
        default=None, metadata={"help": "Resume training from a specified checkpoint."}
    )
    load_from_checkpoint: Optional[str] = field(
        default=None, metadata={"help": "Load model weights from a specified checkpoint."}
    )

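# Note: HfArgumentParser turns each field above into a CLI flag of the same
# name, e.g. `--batch_size 4` or `--video_num_frames 17`.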

def set_seed(seed: int = 1006):
    # Seed Python, NumPy, and PyTorch RNGs for reproducible runs.
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

def load_callbacks_and_logger(args):
    # Keep every checkpoint (save_top_k=-1), saving every `save_steps` training steps
    # rather than at epoch boundaries.
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.output_dir,
        filename="model-{epoch:02d}-{step}",
        every_n_train_steps=args.save_steps,
        save_top_k=-1,
        save_on_train_epoch_end=False,
    )
    lr_monitor = LearningRateMonitor(logging_interval="step")
    logger = WandbLogger(name=args.exp_name, log_model=False)
    return [checkpoint_callback, lr_monitor], logger


def train(args):
    set_seed()
    # Load the model: resume full weights from a checkpoint if given,
    # otherwise build a fresh model from the YAML config.
    if args.load_from_checkpoint is not None:
        model = CausalVAEModel.from_pretrained(args.load_from_checkpoint)
    else:
        model = CausalVAEModel.from_config(args.model_config)

    # Print the architecture on rank 0 only, to avoid duplicated output under DDP.
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(model)

    # Load Dataset
    dataset = VideoDataset(
        args.video_path,
        sequence_length=args.video_num_frames,
        resolution=args.resolution,
        sample_rate=args.sample_rate,
        dynamic_sample=args.dynamic_sample,
    )
    train_loader = DataLoader(
        dataset,
        shuffle=True,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        pin_memory=True,
    )
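    # Under DDP, Lightning replaces this loader's sampler with a
    # DistributedSampler by default, so shuffle=True remains valid
    # across multiple devices.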
    # Load Callbacks and Logger
    callbacks, logger = load_callbacks_and_logger(args)
    # Load Trainer
    trainer = pl.Trainer(
        accelerator="cuda",
        devices=args.devices,
        num_nodes=args.n_nodes,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=5,
        precision=args.precision,
        max_steps=args.max_steps,
        strategy="ddp_find_unused_parameters_true",
    )
    trainer_kwargs = {}
    if args.resume_from_checkpoint:
        trainer_kwargs["ckpt_path"] = args.resume_from_checkpoint

    trainer.fit(model, train_loader, **trainer_kwargs)
    # Export the final weights in Hugging Face format alongside the Lightning checkpoints.
    model.save_pretrained(os.path.join(args.output_dir, "hf"))


if __name__ == "__main__":
    parser = HfArgumentParser(TrainingArguments)
    # parse_args_into_dataclasses returns a tuple with one entry per dataclass.
    (args,) = parser.parse_args_into_dataclasses()
    train(args)
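

# Example invocation (a sketch: the script name, paths, and values below are
# illustrative placeholders, not a documented interface of this repo):
#
#   python train_causalvae.py \
#       --exp_name causalvae \
#       --model_config scripts/causalvae/288.yaml \
#       --video_path /path/to/videos \
#       --devices 8 --precision bf16 --max_steps 100000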