# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Sample from a trained DiT checkpoint and render the sampling trajectory.

A conditioning window (phone, speaker id, phone kind) is read from the dataset
configured in the YAML config, the model is sampled with classifier-free
guidance, and the returned intermediate tensors are rendered to an MP4.
"""

import argparse
import os

import numpy as np
import torch
import yaml
from tqdm import tqdm

from diffusion import create_diffusion
from models import DiT_models

# Permit TensorFloat-32 in CUDA matmuls and cuDNN ops (faster, lower precision).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


def find_model(model_name):
    """Load model weights from a checkpoint file.

    Prefers the ``"ema"`` entry, then the ``"model"`` entry; a checkpoint with
    neither key is assumed to already be a bare state dict and is returned
    unchanged.

    Args:
        model_name: Path to a checkpoint written with ``torch.save``.

    Returns:
        The state dict to pass to ``model.load_state_dict``.

    Raises:
        FileNotFoundError: If ``model_name`` is not an existing file.
    """
    # Explicit raise instead of assert: asserts are stripped under ``python -O``.
    if not os.path.isfile(model_name):
        raise FileNotFoundError(f"Could not find DiT checkpoint at {model_name}")
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    # The map_location lambda keeps every tensor on the CPU storage it was read into.
    checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
    if "ema" in checkpoint:  # supports checkpoints from train.py
        print("Using EMA model")
        return checkpoint["ema"]
    print("Using model")
    if "model" in checkpoint:
        return checkpoint["model"]
    return checkpoint  # bare state dict


def get_batch(
    step, batch_size, seq_len, DEVICE, data_file, data_dim, data_mean, data_std
):
    """Read a seeded random batch of windows from the memmapped dataset.

    ``data_file`` is a flat float16 array viewed as rows of ``data_dim + 3``
    columns: ``data_dim`` feature values, then phone, speaker id and phone kind.
    ``step`` seeds the generator, so the same step always yields the same batch.

    Args:
        step: Seed for the window-selection generator.
        batch_size: Number of windows (drawn without replacement).
        seq_len: Rows per window.
        DEVICE: Device the returned tensors are moved to.
        data_file: Path of the raw float16 data file.
        data_dim: Number of feature columns per row.
        data_mean, data_std: Normalization applied as ``(x - mean) / std``.

    Returns:
        ``(x, speaker_id, phone, phone_kind)``. ``x`` is built from float16
        data with shape ``(batch_size, data_dim, seq_len)`` and is normalized;
        the other three are int32 tensors of shape ``(batch_size, seq_len)``.
    """
    row_width = data_dim + 3
    # Open once flat to learn the element count, then reopen as a 2-D table.
    flat = np.memmap(data_file, dtype=np.float16, mode="r")
    arr = np.memmap(
        data_file,
        dtype=np.float16,
        mode="r",
        shape=(flat.shape[0] // row_width, row_width),
    )

    rng = np.random.Generator(np.random.PCG64(seed=step))
    # NOTE(review): choice(k) draws from range(k), so the last valid start index
    # (arr.shape[0] - seq_len) is never picked; kept as-is so seeded batches
    # stay reproducible.
    start_indices = rng.choice(
        arr.shape[0] - seq_len, size=batch_size, replace=False
    ).astype(np.int64)

    # Copy windows one at a time so only the selected rows are read from disk.
    batch_data = np.zeros((batch_size, seq_len, row_width), dtype=np.float16)
    for i, start_idx in enumerate(start_indices):
        batch_data[i] = arr[start_idx : start_idx + seq_len]

    # Features become channel-first: (batch, data_dim, seq_len).
    x = np.moveaxis(batch_data[:, :, :data_dim].astype(np.float16), 1, 2)
    phone = batch_data[:, :, data_dim].astype(np.int32)
    speaker_id = batch_data[:, :, data_dim + 1].astype(np.int32)
    phone_kind = batch_data[:, :, data_dim + 2].astype(np.int32)

    x = torch.from_numpy(x).to(DEVICE)
    x = (x - data_mean) / data_std
    phone = torch.from_numpy(phone).to(DEVICE)
    speaker_id = torch.from_numpy(speaker_id).to(DEVICE)
    phone_kind = torch.from_numpy(phone_kind).to(DEVICE)
    return x, speaker_id, phone, phone_kind


def get_data(config_path, seed=0):
    """Load a single conditioning window sized to the model's ``input_size``.

    Args:
        config_path: YAML config with ``data`` and ``model`` sections.
        seed: Passed to ``get_batch`` as its ``step`` (window-selection seed).

    Returns:
        ``(x, speaker_id, phone, phone_kind)`` from ``get_batch`` with batch
        size 1, on the CUDA device.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    data_config = config["data"]
    model_config = config["model"]
    device = "cuda"  # if torch.cuda.is_available() else "cpu"

    return get_batch(
        seed,
        1,
        seq_len=model_config["input_size"],
        DEVICE=device,
        data_file=data_config["data_path"],
        data_dim=data_config["data_dim"],
        data_mean=data_config["data_mean"],
        data_std=data_config["data_std"],
    )


def plot_samples(samples, x, output_path="animation.mp4", fps=60):
    """Render one video frame per entry of ``samples`` and save it as an MP4.

    Multi-channel entries are drawn as an image (channels on the y axis).
    Single-channel entries are drawn as a line together with the reference
    ``x``. Only batch element 0 is drawn.

    Args:
        samples: Sequence of tensors, presumably ``(batch, channels, time)``
            each -- TODO confirm against ``sample``'s output.
        x: Reference tensor; only plotted for single-channel samples.
        output_path: Destination video file.
        fps: Frame rate of the saved video (and the preview interval).

    Raises:
        ValueError: If a sample has an empty channel dimension.
    """
    # matplotlib is only needed here; importing lazily keeps ``sample`` usable
    # without it. (plt/animation were previously used without any import.)
    import matplotlib.animation as animation
    import matplotlib.pyplot as plt

    num_frames = len(samples)
    fig, ax = plt.subplots(figsize=(20, 4))
    plt.tight_layout()

    def update(frame):
        ax.clear()
        label = ax.text(
            0.02,
            0.98,
            f"{frame + 1} / {num_frames}",
            transform=ax.transAxes,
            verticalalignment="top",
            color="black",
        )
        # .float(): numpy has no bfloat16, and matplotlib prefers float32.
        frame_data = samples[frame].float().cpu().numpy()
        num_channels = frame_data.shape[1]
        if num_channels > 1:
            im = ax.imshow(
                frame_data[0],
                origin="lower",
                aspect="auto",
                interpolation="none",
                vmin=-5,
                vmax=5,
            )
            return [im, label]
        if num_channels == 1:
            line1 = ax.plot(frame_data[0, 0])[0]
            line2 = ax.plot(x.float().cpu().numpy()[0, 0])[0]
            ax.set_ylim(-10, 10)
            return [line1, line2, label]
        raise ValueError(f"frame {frame} has an empty channel dimension")

    anim = animation.FuncAnimation(
        fig,
        update,
        frames=tqdm(range(num_frames), desc="Generating animation"),
        interval=1000 / fps,
        blit=True,
    )
    anim.save(output_path, fps=fps, extra_args=["-vcodec", "libx264"])
    plt.close(fig)


# ckpt_path -> loaded model, reused across sample() calls.
# NOTE(review): keyed by checkpoint only, so a different config with the same
# checkpoint path reuses the first model built.
model_cache = {}


def sample(
    config_path,
    ckpt_path,
    cfg_scale=4.0,
    num_sampling_steps=1000,
    seed=0,
    speaker_id=None,
    phone=None,
    phone_kind=None,
):
    """Sample from a DiT checkpoint with classifier-free guidance.

    Args:
        config_path: YAML config with ``data`` and ``model`` sections.
        ckpt_path: Checkpoint path; the loaded model is cached in ``model_cache``.
        cfg_scale: Guidance scale forwarded to ``model.forward_with_cfg``.
        num_sampling_steps: Passed (as a string) to ``create_diffusion``.
        seed: Seed for ``torch.manual_seed`` (global RNG, used for the noise).
        speaker_id: Integer tensor ``(n, T)`` on CUDA. Also defines the
            attention mask: positions attend only to the same speaker.
        phone, phone_kind: Conditioning tensors, presumably ``(n, T)`` like
            ``speaker_id`` -- TODO confirm.

    Returns:
        A list with one tensor per element yielded by ``p_sample_loop``
        (presumably one per step), each holding only the conditional half
        (first ``n`` rows) of the guided batch.

    Raises:
        ValueError: If any conditioning tensor is missing.
    """
    if speaker_id is None or phone is None or phone_kind is None:
        raise ValueError("speaker_id, phone and phone_kind must all be provided")

    torch.manual_seed(seed)
    device = "cuda"  # if torch.cuda.is_available() else "cpu"

    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    data_config = config["data"]
    model_config = config["model"]

    model = model_cache.get(ckpt_path)
    if model is None:
        model = DiT_models[model_config["name"]](
            input_size=model_config["input_size"],
            embedding_vocab_size=model_config["embedding_vocab_size"],
            learn_sigma=model_config["learn_sigma"],
            in_channels=data_config["data_dim"],
        ).to(device)
        model.load_state_dict(find_model(ckpt_path))
        model.eval()  # important!
        model_cache[ckpt_path] = model

    diffusion = create_diffusion(str(num_sampling_steps))

    # One noise sequence per conditioning sequence (previously fixed to 1).
    n = speaker_id.shape[0]
    z = torch.randn(n, data_config["data_dim"], speaker_id.shape[1], device=device)

    # (n, 1, T, T) mask: True where the two positions share a speaker id.
    # Built from the conditional speaker ids, then duplicated for the null half.
    attn_mask = (speaker_id[:, None, :] == speaker_id[:, :, None]).unsqueeze(1)
    attn_mask = torch.cat([attn_mask, attn_mask], 0)

    # Classifier-free guidance: second half of the batch gets null conditioning.
    z = torch.cat([z, z], 0)
    null_value = model.y_embedder.unconditional_value
    phone = torch.cat([phone, torch.full_like(phone, null_value)], 0)
    speaker_id = torch.cat([speaker_id, torch.full_like(speaker_id, null_value)], 0)
    phone_kind = torch.cat([phone_kind, torch.full_like(phone_kind, null_value)], 0)

    model_kwargs = dict(
        phone=phone,
        speaker_id=speaker_id,
        phone_kind=phone_kind,
        cfg_scale=cfg_scale,
        attn_mask=attn_mask,
    )
    # Local no_grad instead of a global torch.set_grad_enabled(False), so the
    # caller's autograd state is left untouched.
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        samples = diffusion.p_sample_loop(
            model.forward_with_cfg,
            z.shape,
            z,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            progress=True,
            device=device,
        )
    return [s.chunk(2, dim=0)[0] for s in samples]  # Remove null class samples


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--ckpt", type=str, required=True)
    parser.add_argument("--cfg-scale", type=float, default=4.0)
    parser.add_argument("--num-sampling-steps", type=int, default=1000)
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()

    x, speaker_id, phone, phone_kind = get_data(args.config, args.seed)
    samples = sample(
        args.config,
        args.ckpt,
        cfg_scale=args.cfg_scale,
        num_sampling_steps=args.num_sampling_steps,
        seed=args.seed,
        speaker_id=speaker_id,
        phone=phone,
        phone_kind=phone_kind,
    )
    plot_samples(samples, x)