ntt123's picture
use latest torch
efb6748
raw
history blame
7.55 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
# Enable TensorFloat-32 for CUDA matmuls and cuDNN ops. This trades some
# float32 precision for speed on GPUs that support TF32. It is set at import
# time, so it applies process-wide to everything this script runs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import argparse
import os
import numpy as np
# NOTE(review): torch is already imported above; this second import is a no-op.
import torch
import yaml
from tqdm import tqdm
# Project-local modules: diffusion process factory and the DiT model registry.
from diffusion import create_diffusion
from models import DiT_models
def find_model(model_name):
    """Load a DiT checkpoint from disk and return the state dict to load.

    Selection order:
      * ``checkpoint["ema"]`` if present (checkpoints written by train.py);
      * otherwise ``checkpoint["model"]`` if present;
      * otherwise the loaded object itself, treated as a bare state dict.

    Args:
        model_name: Path to the checkpoint file.

    Returns:
        The selected state dict, with all tensors on CPU.

    Raises:
        FileNotFoundError: If ``model_name`` is not an existing file.
    """
    # An explicit check instead of ``assert``, because asserts are stripped under -O.
    if not os.path.isfile(model_name):
        raise FileNotFoundError(f"Could not find DiT checkpoint at {model_name}")
    # Returning the storage unchanged keeps every tensor on the CPU, whatever
    # device it was saved from. load_state_dict later copies the values into
    # the model's parameters.
    checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
    if "ema" in checkpoint:  # supports checkpoints from train.py
        print("Using EMA model")
        return checkpoint["ema"]
    if "model" in checkpoint:
        print("Using model")
        return checkpoint["model"]
    # Neither wrapper key is present, so assume the file is already a plain state dict.
    print("Using raw state dict")
    return checkpoint
def get_batch(
    step, batch_size, seq_len, DEVICE, data_file, data_dim, data_mean, data_std
):
    """Draw ``batch_size`` random windows of ``seq_len`` rows from a memmap file.

    The file is read as float16 rows ``data_dim + 3`` values wide. Each row has
    ``data_dim`` feature columns followed by the phone, speaker-id and
    phone-kind columns. Window starts are drawn without replacement from a
    PCG64 generator seeded with ``step``, so a given step always produces the
    same batch.

    Returns:
        ``(x, speaker_id, phone, phone_kind)``, all on ``DEVICE``.
        ``x`` has shape (batch_size, data_dim, seq_len) and is normalized as
        ``(x - data_mean) / data_std``. It stays float16 when the mean and std
        are plain scalars (assumed; confirm against the config). The three
        label tensors are int32 with shape (batch_size, seq_len).
    """
    row_width = data_dim + 3
    flat = np.memmap(data_file, dtype=np.float16, mode="r")
    num_rows = flat.shape[0] // row_width
    # Drop any trailing partial row, then view the file as a (rows, width) table.
    table = flat[: num_rows * row_width].reshape(num_rows, row_width)

    generator = np.random.Generator(np.random.PCG64(seed=step))
    # NOTE(review): starts are drawn from [0, num_rows - seq_len). The final
    # valid start (num_rows - seq_len) is never chosen. This is kept so
    # existing seeds keep producing the same batches.
    starts = generator.choice(
        num_rows - seq_len, size=batch_size, replace=False
    ).astype(np.int64)

    # Gather all windows in one fancy-indexing read: row [i, t] = starts[i] + t.
    window_rows = starts[:, np.newaxis] + np.arange(seq_len)
    windows = np.asarray(table[window_rows], dtype=np.float16)

    # Features become (batch, data_dim, seq_len); the label columns are cast to int.
    features = np.moveaxis(windows[:, :, :data_dim].astype(np.float16), 1, 2)
    phone = windows[:, :, data_dim].astype(np.int32)
    speaker_id = windows[:, :, data_dim + 1].astype(np.int32)
    phone_kind = windows[:, :, data_dim + 2].astype(np.int32)

    x = (torch.from_numpy(features).to(DEVICE) - data_mean) / data_std
    phone, speaker_id, phone_kind = (
        torch.from_numpy(column).to(DEVICE)
        for column in (phone, speaker_id, phone_kind)
    )
    return x, speaker_id, phone, phone_kind
def get_data(config_path, seed=0, device=None):
    """Load one reference window from the dataset named in a YAML config.

    Reads the ``data`` and ``model`` sections of the config. It then draws a
    single window of ``model.input_size`` rows via ``get_batch``, using
    ``seed`` as the batch step.

    Args:
        config_path: Path to the YAML config file.
        seed: Forwarded to ``get_batch`` as its ``step`` (RNG seed).
        device: Device for the returned tensors. ``None`` (the default)
            selects "cuda" when available, otherwise "cpu".

    Returns:
        ``(x, speaker_id, phone, phone_kind)`` from ``get_batch`` with batch size 1.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    data_config = config["data"]
    model_config = config["model"]
    if device is None:
        # Previously hard-coded to "cuda", which failed on machines without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return get_batch(
        seed,
        1,
        seq_len=model_config["input_size"],
        DEVICE=device,
        data_file=data_config["data_path"],
        data_dim=data_config["data_dim"],
        data_mean=data_config["data_mean"],
        data_std=data_config["data_std"],
    )
def plot_samples(samples, x, output_path="animation.mp4", fps=60):
    """Render a sequence of samples as an MP4 animation, one frame per sample.

    If a sample's dim 1 has more than one entry, item 0 is drawn as an image
    with the colour scale fixed to [-5, 5]. Otherwise channel 0 of the sample
    and channel 0 of the reference ``x`` are drawn as two lines, with the
    y-limits fixed to [-10, 10].

    Args:
        samples: Sequence of tensors, presumably laid out as
            (batch, channels, frames). TODO confirm against
            diffusion.p_sample_loop.
        x: Reference tensor with the same layout. It is only drawn in the
            single-channel case.
        output_path: Path of the video file to write.
        fps: Frame rate of the saved video and the animation interval.
    """
    # These names are not imported at module level in this script, so import
    # them where they are used.
    import matplotlib.animation as animation
    import matplotlib.pyplot as plt

    def to_numpy(tensor):
        # Cast to float first: numpy has no bfloat16, and sampling runs under
        # bfloat16 autocast.
        return tensor.detach().float().cpu().numpy()

    num_frames = len(samples)
    fig, ax = plt.subplots(figsize=(20, 4))
    fig.tight_layout()

    def update(frame):
        """Redraw the axes for ``frame`` and return the drawn artists."""
        ax.clear()
        ax.text(
            0.02,
            0.98,
            # The label used to be hard-coded to "/ 1000" regardless of the step count.
            f"{frame + 1} / {num_frames}",
            transform=ax.transAxes,
            verticalalignment="top",
            color="black",
        )
        current = samples[frame]
        if current.shape[1] > 1:
            im = ax.imshow(
                to_numpy(current)[0],
                origin="lower",
                aspect="auto",
                interpolation="none",
                vmin=-5,
                vmax=5,
            )
            return [im]
        (sample_line,) = ax.plot(to_numpy(current)[0, 0])
        (reference_line,) = ax.plot(to_numpy(x)[0, 0])
        ax.set_ylim(-10, 10)
        return [sample_line, reference_line]

    anim = animation.FuncAnimation(
        fig,
        update,
        frames=tqdm(range(num_frames), desc="Generating animation"),
        interval=1000 / fps,  # milliseconds per frame
        blit=True,
    )
    anim.save(output_path, fps=fps, extra_args=["-vcodec", "libx264"])
    plt.close(fig)
# Models already built by sample(), keyed by checkpoint path. Repeated calls
# in the same process reuse the loaded weights instead of reloading them.
model_cache = {}


def _load_model(ckpt_path, model_config, data_config, device):
    """Return the cached DiT for ``ckpt_path``, building and loading it on first use."""
    model = model_cache.get(ckpt_path)
    if model is not None:
        return model
    model = DiT_models[model_config["name"]](
        input_size=model_config["input_size"],
        embedding_vocab_size=model_config["embedding_vocab_size"],
        learn_sigma=model_config["learn_sigma"],
        in_channels=data_config["data_dim"],
    ).to(device)
    # The original chained ``.bfloat16`` here without calling it. That bound
    # the method object to ``model`` and broke load_state_dict. Cast after
    # loading instead.
    model.load_state_dict(find_model(ckpt_path))
    model.eval()  # important!
    model = model.bfloat16()
    model_cache[ckpt_path] = model
    return model


def sample(
    config_path,
    ckpt_path,
    cfg_scale=4.0,
    num_sampling_steps=1000,
    seed=0,
    speaker_id=None,
    phone=None,
    phone_kind=None,
):
    """Generate features by classifier-free-guided diffusion sampling.

    Args:
        config_path: YAML config with ``data`` and ``model`` sections.
        ckpt_path: Checkpoint read through ``find_model``. The built model is
            cached in ``model_cache`` under this path.
        cfg_scale: Guidance scale, forwarded to ``model.forward_with_cfg``.
        num_sampling_steps: Passed as a string to ``create_diffusion``.
        seed: Seed for torch's global RNG, set before sampling.
        speaker_id, phone, phone_kind: Integer label tensors indexed
            [batch, frame], as produced by ``get_batch``. They are required
            despite the ``None`` defaults, and are moved to the sampling device.

    Returns:
        A list holding the conditional (first) half of every element returned
        by ``diffusion.p_sample_loop``. Presumably that is one tensor per
        denoising step, shaped (batch, data_dim, frames). TODO confirm.

    Raises:
        ValueError: If any of the label tensors is missing.
    """
    if speaker_id is None or phone is None or phone_kind is None:
        raise ValueError("sample() requires speaker_id, phone and phone_kind tensors")
    torch.manual_seed(seed)
    # NOTE: this disables autograd process-wide, not just for this call. It is
    # kept because callers may rely on it.
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    data_config = config["data"]
    model_config = config["model"]

    model = _load_model(ckpt_path, model_config, data_config, device)
    diffusion = create_diffusion(str(num_sampling_steps))

    speaker_id = speaker_id.to(device)
    phone = phone.to(device)
    phone_kind = phone_kind.to(device)

    # One noise sequence per input item. This was hard-coded to 1, which broke batches > 1.
    batch_size, num_frames = speaker_id.shape[0], speaker_id.shape[1]
    z = torch.randn(batch_size, data_config["data_dim"], num_frames, device=device)

    # Pairwise mask of shape (batch, 1, frames, frames). It is True where both
    # frames carry the same speaker id (presumably "may attend"; confirm
    # against the model's attention).
    attn_mask = (speaker_id[:, None, :] == speaker_id[:, :, None]).unsqueeze(1)

    # Classifier-free guidance: the batch is doubled. The second half has
    # every label replaced by the model's unconditional value.
    unconditional_value = model.y_embedder.unconditional_value

    def with_null(labels):
        return torch.cat([labels, torch.full_like(labels, unconditional_value)], 0)

    z = torch.cat([z, z], 0)
    model_kwargs = dict(
        phone=with_null(phone),
        speaker_id=with_null(speaker_id),
        phone_kind=with_null(phone_kind),
        cfg_scale=cfg_scale,
        attn_mask=torch.cat([attn_mask, attn_mask], 0),
    )
    with torch.no_grad():
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            samples = diffusion.p_sample_loop(
                model.forward_with_cfg,
                z.shape,
                z,
                clip_denoised=False,
                model_kwargs=model_kwargs,
                progress=True,
                device=device,
            )
    # Keep only the conditional half of each sample; drop the null-label half.
    return [s.chunk(2, dim=0)[0] for s in samples]
if __name__ == "__main__":
    # Command-line entry point. Pick a reference window from the dataset,
    # sample new features conditioned on its labels, and save an animation of
    # the sampling trajectory.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    cli.add_argument("--ckpt", type=str, required=True)
    cli.add_argument("--cfg-scale", type=float, default=4.0)
    cli.add_argument("--num-sampling-steps", type=int, default=1000)
    cli.add_argument("--seed", type=int, default=0)
    opts = cli.parse_args()

    reference, spk, phones, kinds = get_data(opts.config, opts.seed)
    trajectory = sample(
        opts.config,
        opts.ckpt,
        cfg_scale=opts.cfg_scale,
        num_sampling_steps=opts.num_sampling_steps,
        seed=opts.seed,
        speaker_id=spk,
        phone=phones,
        phone_kind=kinds,
    )
    plot_samples(trajectory, reference)