|
|
|
|
|
import os |
|
import argparse |
|
|
|
import torch |
|
import torchvision |
|
import numpy as np |
|
from PIL import Image |
|
import pytorch_lightning as pl |
|
from omegaconf import OmegaConf |
|
from librosa.util import normalize |
|
from ldm.util import instantiate_from_config |
|
from pytorch_lightning.trainer import Trainer |
|
from torch.utils.data import DataLoader, Dataset |
|
from datasets import load_from_disk, load_dataset |
|
from pytorch_lightning.callbacks import Callback, ModelCheckpoint |
|
from pytorch_lightning.utilities.distributed import rank_zero_only |
|
|
|
from audiodiffusion.mel import Mel |
|
from audiodiffusion.utils import convert_ldm_to_hf_vae |
|
|
|
|
|
class AudioDiffusion(Dataset):
    """Torch dataset serving images from the 'train' split of a HF dataset.

    Every item is a dict with a single ``'image'`` entry holding a
    channels-last NumPy array of shape ``(height, width, channels)`` whose
    ``uint8`` pixel values have been rescaled from ``[0, 255]`` to ``[-1, 1]``.
    """

    # PIL mode that yields exactly ``channels`` bytes per pixel. Channel counts
    # that are not listed are passed through without any mode conversion.
    _PIL_MODES = {1: 'L', 3: 'RGB'}

    def __init__(self, model_id, channels=3):
        """Load the dataset.

        Args:
            model_id: If this is an existing filesystem path the dataset is
                read with ``load_from_disk``; otherwise it is handed to
                ``load_dataset``.
            channels: Number of channels each returned image must have.
        """
        super().__init__()
        self.channels = channels
        loader = load_from_disk if os.path.exists(model_id) else load_dataset
        self.hf_dataset = loader(model_id)['train']

    def __len__(self):
        """Return the number of examples in the underlying split."""
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        """Return example ``idx`` as ``{'image': ndarray}`` scaled to [-1, 1].

        The stored image is first converted to the PIL mode matching
        ``self.channels`` so that the raw byte buffer always has the size the
        reshape below expects (previously only the 3-channel case was
        converted, so a colour image requested as 1 channel failed to reshape).
        """
        image = self.hf_dataset[idx]['image']
        mode = self._PIL_MODES.get(self.channels)
        if mode is not None:
            image = image.convert(mode)
        pixels = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
            (image.height, image.width, self.channels))
        # uint8 [0, 255] -> float [-1, 1]
        return {'image': (pixels / 255) * 2 - 1}
|
|
|
|
|
class AudioDiffusionDataModule(pl.LightningDataModule):
    """Lightning data module exposing an ``AudioDiffusion`` training set."""

    def __init__(self, model_id, batch_size, channels, num_workers=1):
        """Build the underlying dataset.

        Args:
            model_id: Dataset path or identifier, forwarded to
                ``AudioDiffusion``.
            batch_size: Number of examples per training batch.
            channels: Number of image channels, forwarded to
                ``AudioDiffusion``.
            num_workers: Number of ``DataLoader`` worker processes. Defaults
                to 1, the value that used to be hard-coded.
        """
        super().__init__()
        self.batch_size = batch_size
        self.dataset = AudioDiffusion(model_id=model_id, channels=channels)
        self.num_workers = num_workers

    def train_dataloader(self):
        """Return the training ``DataLoader``.

        No ``shuffle`` or sampler argument is passed, so the loader's default
        ordering applies.
        """
        return DataLoader(self.dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers)
|
|
|
|
|
class ImageLogger(Callback):
    """Callback that periodically logs model images and audio decoded from them.

    Every ``every`` training batches the module's ``log_images`` output is
    written to ``pl_module.logger.experiment`` as an image grid, and each
    individual image is additionally turned into audio through ``Mel`` and
    logged with ``add_audio``.
    """

    def __init__(self, every=1000, hop_length=512):
        """Configure the logger.

        Args:
            every: Log once per this many training batches.
            hop_length: Forwarded to ``Mel`` when converting images to audio.
        """
        super().__init__()
        self.every = every
        self.hop_length = hop_length

    @rank_zero_only
    def log_images_and_audios(self, pl_module, batch):
        """Log image grids and per-image audio for ``batch``.

        Args:
            pl_module: Module providing ``log_images``, ``logger`` and
                ``global_step``.
            batch: Batch handed to ``pl_module.log_images``.
        """
        # Run log_images in eval mode without gradients, then put the module
        # back into whatever mode it was in -- even if log_images raises.
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                images = pl_module.log_images(batch, split='train')
        finally:
            pl_module.train(was_training)

        # Nothing to log; also avoids StopIteration from next() below.
        if not images:
            return

        # Index 1 is the channel axis (see the transpose to channels-last
        # further down); indices 2 and 3 are the two spatial axes.
        image_shape = next(iter(images.values())).shape
        channels = image_shape[1]
        # NOTE(review): dims 2/3 are passed as x_res/y_res in that order;
        # confirm this matches Mel's convention for non-square images.
        mel = Mel(x_res=image_shape[2],
                  y_res=image_shape[3],
                  hop_length=self.hop_length)

        experiment = pl_module.logger.experiment
        step = pl_module.global_step

        for k, tensor in images.items():
            # [-1, 1] -> [0, 1]
            tensor = torch.clamp(tensor.detach().cpu(), -1., 1.)
            tensor = (tensor + 1.0) / 2.0

            tag = f"train/{k}"
            grid = torchvision.utils.make_grid(tensor)
            experiment.add_image(tag, grid, global_step=step)

            # [0, 1] floats -> uint8, channels-first -> channels-last.
            frames = (tensor.numpy() *
                      255).round().astype("uint8").transpose(0, 2, 3, 1)
            for index, frame in enumerate(frames):
                if channels == 3:
                    spectrogram = Image.fromarray(frame,
                                                  mode='RGB').convert('L')
                else:
                    # Any other channel count: only channel 0 is used.
                    spectrogram = Image.fromarray(frame[:, :, 0])
                audio = mel.image_to_audio(spectrogram)
                experiment.add_audio(f"{tag}/{index}",
                                     normalize(audio),
                                     global_step=step,
                                     sample_rate=mel.get_sample_rate())

    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
                           batch_idx):
        """Trigger logging on every ``self.every``-th batch (1-based count)."""
        if (batch_idx + 1) % self.every == 0:
            self.log_images_and_audios(pl_module, batch)
|
|
|
|
|
class HFModelCheckpoint(ModelCheckpoint):
    """``ModelCheckpoint`` that also converts each epoch's checkpoint.

    After the parent class has handled the end of a training epoch, the ldm
    checkpoint for that epoch is passed to ``convert_ldm_to_hf_vae`` together
    with the ldm config and the target location.
    """

    def __init__(self, ldm_config, hf_checkpoint, *args, **kwargs):
        """Store conversion settings; remaining arguments go to the parent.

        Args:
            ldm_config: Config object handed to ``convert_ldm_to_hf_vae``.
            hf_checkpoint: Destination handed to ``convert_ldm_to_hf_vae``.
            *args: Positional arguments for ``ModelCheckpoint``.
            **kwargs: Keyword arguments for ``ModelCheckpoint``.
        """
        super().__init__(*args, **kwargs)
        self.hf_checkpoint = hf_checkpoint
        self.ldm_config = ldm_config

    def on_train_epoch_end(self, trainer, pl_module):
        """Run the parent hook, then convert the epoch's checkpoint."""
        # The checkpoint name is resolved *before* the parent hook runs and
        # this order must be kept.
        # NOTE(review): presumably the resolved name depends on which files
        # already exist, so resolving it afterwards could differ -- confirm.
        epoch_metrics = {'epoch': trainer.current_epoch}
        checkpoint_path = self._get_metric_interpolated_filepath_name(
            epoch_metrics, trainer)

        super().on_train_epoch_end(trainer, pl_module)

        convert_ldm_to_hf_vae(checkpoint_path, self.ldm_config,
                              self.hf_checkpoint)
|
|
|
|
|
if __name__ == "__main__":
    # ---- Command line -----------------------------------------------------
    parser = argparse.ArgumentParser(description="Train VAE using ldm.")
    # Required: without a dataset the script used to fail later with an
    # opaque TypeError from os.path.exists(None) inside the dataset class.
    parser.add_argument("-d", "--dataset_name", type=str, required=True)
    parser.add_argument("-b", "--batch_size", type=int, default=1)
    parser.add_argument("-c",
                        "--ldm_config_file",
                        type=str,
                        default="config/ldm_autoencoder_kl.yaml")
    parser.add_argument("--ldm_checkpoint_dir",
                        type=str,
                        default="models/ldm-autoencoder-kl")
    parser.add_argument("--hf_checkpoint_dir",
                        type=str,
                        default="models/autoencoder-kl")
    parser.add_argument("-r",
                        "--resume_from_checkpoint",
                        type=str,
                        default=None)
    parser.add_argument("-g",
                        "--gradient_accumulation_steps",
                        type=int,
                        default=1)
    parser.add_argument("--hop_length", type=int, default=512)
    parser.add_argument("--save_images_batches", type=int, default=1000)
    parser.add_argument("--max_epochs", type=int, default=100)
    args = parser.parse_args()

    # ---- Model and data ---------------------------------------------------
    config = OmegaConf.load(args.ldm_config_file)
    model = instantiate_from_config(config.model)
    model.learning_rate = config.model.base_learning_rate
    data = AudioDiffusionDataModule(
        model_id=args.dataset_name,
        batch_size=args.batch_size,
        # Image channel count follows the autoencoder's input channels.
        channels=config.model.params.ddconfig.in_channels)

    # ---- Trainer ----------------------------------------------------------
    # Trainer options come from the optional "lightning.trainer" section of
    # the config file; gradient accumulation is overridden from the CLI.
    lightning_config = config.pop("lightning", OmegaConf.create())
    trainer_config = lightning_config.get("trainer", OmegaConf.create())
    trainer_config.accumulate_grad_batches = args.gradient_accumulation_steps
    trainer_opt = argparse.Namespace(**trainer_config)
    trainer = Trainer.from_argparse_args(
        trainer_opt,
        max_epochs=args.max_epochs,
        resume_from_checkpoint=args.resume_from_checkpoint,
        callbacks=[
            ImageLogger(every=args.save_images_batches,
                        hop_length=args.hop_length),
            HFModelCheckpoint(ldm_config=config,
                              hf_checkpoint=args.hf_checkpoint_dir,
                              dirpath=args.ldm_checkpoint_dir,
                              filename='{epoch:06}',
                              verbose=True,
                              save_last=True)
        ])
    trainer.fit(model, data)
|
|