# Spaces: Running on Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os

import numpy as np
import torch
import yaml
from tqdm import tqdm

from diffusion import create_diffusion
from models import DiT_models

# Allow TensorFloat-32 math for CUDA matmuls and cuDNN convolutions: faster on
# GPUs that support TF32, at slightly lower precision than strict FP32.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def find_model(model_name):
    """Load a checkpoint file and return the state dict to pass to ``load_state_dict``.

    If the checkpoint has an ``"ema"`` entry, that entry is returned.
    Otherwise the ``"model"`` entry is returned. A checkpoint with neither
    key is assumed to already be a bare state dict and is returned unchanged.

    Args:
        model_name: Path to a checkpoint file readable by ``torch.load``.

    Returns:
        The selected state dict.

    Raises:
        FileNotFoundError: if ``model_name`` is not an existing file.
    """
    # Explicit check instead of ``assert`` so the validation survives ``python -O``.
    if not os.path.isfile(model_name):
        raise FileNotFoundError(f"Could not find DiT checkpoint at {model_name}")
    # Map every storage to CPU. NOTE(review): torch.load unpickles the file, so
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
    if "ema" in checkpoint:  # supports checkpoints from train.py
        print("Using EMA model")
        return checkpoint["ema"]
    if "model" in checkpoint:
        print("Using model")
        return checkpoint["model"]
    # Neither wrapper key present: treat the file as a plain state dict.
    print("Using raw state dict")
    return checkpoint
def get_batch(
    step, batch_size, seq_len, DEVICE, data_file, data_dim, data_mean, data_std
):
    """Sample a seeded batch of fixed-length windows from a float16 memmap file.

    The file is read as a flat float16 array viewed as rows of
    ``data_dim + 3`` values. Each row holds ``data_dim`` feature columns,
    then a phone column, a speaker-id column and a phone-kind column.

    Args:
        step: Seed for the PCG64 generator; the same step yields the same windows.
        batch_size: Number of windows to sample (distinct start rows).
        seq_len: Number of consecutive rows per window.
        DEVICE: Torch device the returned tensors are moved to.
        data_file: Path to the raw float16 memmap file.
        data_dim: Number of feature columns per row.
        data_mean: Subtracted from the features before scaling.
        data_std: Divisor applied to the mean-shifted features.

    Returns:
        ``(x, speaker_id, phone, phone_kind)``.
        ``x`` is float16 with shape ``(batch_size, data_dim, seq_len)``,
        normalized as ``(x - data_mean) / data_std``.
        The other three are int32 with shape ``(batch_size, seq_len)``.

    Raises:
        ValueError: if the file holds fewer than ``seq_len`` rows, or fewer
            than ``batch_size`` distinct windows.
    """
    row_width = data_dim + 3
    # Map the file once as flat float16 and view it as rows; a trailing partial
    # row, if any, is ignored.
    flat = np.memmap(data_file, dtype=np.float16, mode="r")
    num_rows = flat.shape[0] // row_width
    arr = flat[: num_rows * row_width].reshape(num_rows, row_width)

    # A window may start at any row 0 .. num_rows - seq_len (inclusive).
    num_windows = num_rows - seq_len + 1
    if num_windows < 1:
        raise ValueError(
            f"{data_file} holds {num_rows} rows; need at least seq_len={seq_len}"
        )

    rng = np.random.Generator(np.random.PCG64(seed=step))
    # Distinct start rows; rng.choice raises ValueError if batch_size > num_windows.
    start_indices = rng.choice(num_windows, size=batch_size, replace=False).astype(
        np.int64
    )
    # Single vectorized gather: row_idx[b, t] = start_indices[b] + t.
    row_idx = start_indices[:, None] + np.arange(seq_len, dtype=np.int64)
    batch_data = np.asarray(arr[row_idx], dtype=np.float16)

    # (batch, seq_len, data_dim) -> (batch, data_dim, seq_len)
    x = np.moveaxis(batch_data[:, :, :data_dim], 1, 2)
    phone = batch_data[:, :, data_dim].astype(np.int32)
    speaker_id = batch_data[:, :, data_dim + 1].astype(np.int32)
    phone_kind = batch_data[:, :, data_dim + 2].astype(np.int32)

    x = torch.from_numpy(x).to(DEVICE)
    x = (x - data_mean) / data_std
    phone = torch.from_numpy(phone).to(DEVICE)
    speaker_id = torch.from_numpy(speaker_id).to(DEVICE)
    phone_kind = torch.from_numpy(phone_kind).to(DEVICE)
    return x, speaker_id, phone, phone_kind
def get_data(config_path, seed=0):
    """Read the YAML config and draw one seeded window from its dataset.

    The window length comes from ``model.input_size``. The data file and
    normalization stats come from the ``data`` section. ``seed`` is used as
    the ``step`` seed of ``get_batch``, with a batch size of 1.

    Returns:
        The ``(x, speaker_id, phone, phone_kind)`` tuple produced by ``get_batch``.
    """
    with open(config_path, "r") as config_file:
        settings = yaml.safe_load(config_file)
    data_cfg, model_cfg = settings["data"], settings["model"]
    return get_batch(
        seed,
        1,
        seq_len=model_cfg["input_size"],
        # Fixed to CUDA; the torch.cuda.is_available() fallback is disabled here
        # (TODO confirm the hosting runtime always provides a GPU).
        DEVICE="cuda",
        data_file=data_cfg["data_path"],
        data_dim=data_cfg["data_dim"],
        data_mean=data_cfg["data_mean"],
        data_std=data_cfg["data_std"],
    )
def plot_samples(samples, x, output_path="animation.mp4", fps=60):
    """Render every entry of ``samples`` as one frame of an MP4 animation.

    How a frame is drawn depends on ``shape[1]``:

    * More than one channel: element ``[0]`` is drawn as a heat map with the
      colour range fixed to [-5, 5].
    * Exactly one channel: element ``[0, 0]`` is drawn as a curve, overlaid
      with ``x[0, 0]`` and with y-limits of [-10, 10].

    NOTE(review): ``plt`` and ``animation`` are not imported anywhere in this
    file. This function needs ``import matplotlib.pyplot as plt`` and
    ``from matplotlib import animation`` at module level — confirm and add them.

    Args:
        samples: Sequence of tensors, one per frame (e.g. the list returned by
            ``sample``). Presumably shaped (batch, channels, time) — TODO confirm.
        x: Reference tensor; only drawn for single-channel frames.
        output_path: File the animation is written to.
        fps: Frame rate of the saved video; also sets the preview interval.

    Raises:
        ValueError: if a frame has zero channels (nothing to draw).
    """
    fig, ax = plt.subplots(figsize=(20, 4))
    fig.tight_layout()
    num_frames = len(samples)

    def update(frame):
        ax.clear()
        # The counter uses the real frame count, so it stays correct for any
        # number of sampling steps.
        counter = ax.text(
            0.02,
            0.98,
            f"{frame + 1} / {num_frames}",
            transform=ax.transAxes,
            verticalalignment="top",
            color="black",
        )
        current = samples[frame]
        channels = current.shape[1]
        # With blit=True only the returned artists are redrawn, so the counter
        # is returned along with the plot artists.
        if channels > 1:
            im = ax.imshow(
                current.cpu().numpy()[0],
                origin="lower",
                aspect="auto",
                interpolation="none",
                vmin=-5,
                vmax=5,
            )
            return [im, counter]
        if channels == 1:
            sample_line = ax.plot(current.cpu().numpy()[0, 0])[0]
            reference_line = ax.plot(x.cpu().numpy()[0, 0])[0]
            ax.set_ylim(-10, 10)
            return [sample_line, reference_line, counter]
        raise ValueError(f"Frame {frame} has no channels to plot")

    # tqdm wraps the frame indices to show progress while frames are rendered.
    anim = animation.FuncAnimation(
        fig,
        update,
        frames=tqdm(range(num_frames), desc="Generating animation"),
        interval=1000 / fps,  # milliseconds between frames
        blit=True,
    )
    anim.save(output_path, fps=fps, extra_args=["-vcodec", "libx264"])
    plt.close(fig)
# Built models keyed by checkpoint path, reused across sample() calls.
# NOTE(review): keyed only by ckpt_path, so a later call with a different
# config_path but the same checkpoint reuses the first model.
model_cache = {}


def sample(
    config_path,
    ckpt_path,
    cfg_scale=4.0,
    num_sampling_steps=1000,
    seed=0,
    speaker_id=None,
    phone=None,
    phone_kind=None,
):
    """Run classifier-free-guided diffusion sampling conditioned on phone/speaker ids.

    Args:
        config_path: YAML config. The ``data`` section must provide
            ``data_dim``. The ``model`` section must provide ``name``,
            ``input_size``, ``embedding_vocab_size`` and ``learn_sigma``.
        ckpt_path: Checkpoint loaded via ``find_model``. The built model is
            cached in ``model_cache`` under this path.
        cfg_scale: Guidance scale forwarded to ``model.forward_with_cfg``.
        num_sampling_steps: Step count passed to ``create_diffusion``.
        seed: Passed to ``torch.manual_seed`` immediately before the noise draw.
        speaker_id: Required integer conditioning tensor indexed as
            (batch, time), e.g. as returned by ``get_data``.
        phone: Required integer tensor with the same layout as ``speaker_id``.
        phone_kind: Required integer tensor with the same layout as ``speaker_id``.

    Returns:
        A list with one tensor per element of the ``p_sample_loop`` result.
        Each tensor keeps only the conditional half of the guidance batch.

    Raises:
        ValueError: if ``speaker_id``, ``phone`` or ``phone_kind`` is missing.

    Side effects:
        * Turns autograd off for the calling thread via
          ``torch.set_grad_enabled(False)`` and leaves it off on return;
          callers may depend on this.
        * Reseeds the global torch RNG.
    """
    if speaker_id is None or phone is None or phone_kind is None:
        raise ValueError("speaker_id, phone and phone_kind must all be provided")
    torch.set_grad_enabled(False)
    device = "cuda"  # if torch.cuda.is_available() else "cpu"
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    data_config = config["data"]
    model_config = config["model"]

    model = model_cache.get(ckpt_path)
    if model is None:
        model = DiT_models[model_config["name"]](
            input_size=model_config["input_size"],
            embedding_vocab_size=model_config["embedding_vocab_size"],
            learn_sigma=model_config["learn_sigma"],
            in_channels=data_config["data_dim"],
        ).to(device)
        model.load_state_dict(find_model(ckpt_path))
        model.eval()  # important!
        model_cache[ckpt_path] = model
    diffusion = create_diffusion(str(num_sampling_steps))

    # Seed immediately before the random draws, so a given seed cannot yield
    # different noise depending on whether building the model above consumed
    # RNG state (cache miss) or not (cache hit).
    torch.manual_seed(seed)
    n = speaker_id.shape[0]
    z = torch.randn(n, data_config["data_dim"], speaker_id.shape[1], device=device)

    # Boolean (n, 1, T, T) mask that is True where frames i and j share a
    # speaker id, duplicated for the two guidance halves. Built from the real
    # ids, before the null ids are appended.
    attn_mask = (speaker_id[:, None, :] == speaker_id[:, :, None]).unsqueeze(1)
    attn_mask = torch.cat([attn_mask, attn_mask], 0)

    # Classifier-free guidance: the second half of the batch repeats the same
    # noise with every conditioning id replaced by the unconditional value.
    z = torch.cat([z, z], 0)
    null_value = model.y_embedder.unconditional_value

    def _with_null(ids):
        # Append an all-unconditional copy of ``ids`` along the batch axis.
        return torch.cat([ids, torch.full_like(ids, null_value)], 0)

    model_kwargs = dict(
        phone=_with_null(phone),
        speaker_id=_with_null(speaker_id),
        phone_kind=_with_null(phone_kind),
        cfg_scale=cfg_scale,
        attn_mask=attn_mask,
    )
    samples = diffusion.p_sample_loop(
        model.forward_with_cfg,
        z.shape,
        z,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device=device,
    )
    # Remove null class samples: keep the conditional half of every entry.
    return [s.chunk(2, dim=0)[0] for s in samples]
if __name__ == "__main__":
    # CLI: config + checkpoint are required; guidance, step count and seed are optional.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    cli.add_argument("--ckpt", type=str, required=True)
    cli.add_argument("--cfg-scale", type=float, default=4.0)
    cli.add_argument("--num-sampling-steps", type=int, default=1000)
    cli.add_argument("--seed", type=int, default=0)
    opts = cli.parse_args()

    # Draw a reference window and its conditioning ids from the dataset.
    reference, speaker_ids, phones, phone_kinds = get_data(opts.config, opts.seed)
    # Sample conditioned on those ids, then render the result to a video.
    trajectory = sample(
        opts.config,
        opts.ckpt,
        cfg_scale=opts.cfg_scale,
        num_sampling_steps=opts.num_sampling_steps,
        seed=opts.seed,
        speaker_id=speaker_ids,
        phone=phones,
        phone_kind=phone_kinds,
    )
    plot_samples(trajectory, reference)