# Spaces: Running on Zero
from diffusers import ( | |
DPMSolverMultistepScheduler, | |
DDPMScheduler, | |
DDIMScheduler, | |
PNDMScheduler, | |
DEISMultistepScheduler, | |
) | |
import torch | |
import yaml | |
import math | |
import tqdm | |
import time | |
class DiffusePipeline(object):
    """Sampling pipeline that pairs a denoising model with a configurable scheduler.

    The scheduler class and its keyword arguments are read from a YAML file
    and looked up by ``diffuser_name``; the training-time settings found on
    ``opt`` (``diffusion_steps``, ``beta_schedule``, ``prediction_type``)
    override the corresponding YAML entries.

    The model is expected to provide ``input_feats``, ``encode_text(caption,
    device)`` and to be callable as ``model(sample, t, enc_text=...)``.  When
    it has a positive ``cond_mask_prob`` attribute, ``forward_with_cfg`` is
    called instead of the plain forward pass.
    """

    # Location of the scheduler definitions, relative to the working directory.
    PARAMS_PATH = "config/diffuser_params.yaml"

    # Batches run before this many cumulative sampling steps are treated as
    # GPU warm-up and left out of the reported average generation time.
    WARMUP_STEPS = 50

    def __init__(
        self,
        opt,
        model,
        diffuser_name,
        num_inference_steps,
        device,
        torch_dtype=torch.float16,
    ):
        """Move the model to ``device`` and build the requested scheduler.

        Args:
            opt: Options object providing ``diffusion_steps``,
                ``beta_schedule`` and ``prediction_type``.
            model: Denoising model; converted with ``.half()`` when
                ``torch_dtype`` is ``torch.float16``.
            diffuser_name: Key of the scheduler entry in the YAML file.
            num_inference_steps: Number of sampling steps per generation.
            device: Device the model and the sampled tensors live on.
            torch_dtype: dtype of the initial noise (default ``float16``).

        Raises:
            ValueError: If ``diffuser_name`` is not present in the YAML file
                or the configured scheduler class name is not a module-level
                name of this file.
        """
        self.device = device
        self.torch_dtype = torch_dtype
        self.diffuser_name = diffuser_name
        self.num_inference_steps = num_inference_steps
        if self.torch_dtype == torch.float16:
            model = model.half()
        self.model = model.to(device)
        self.opt = opt

        # Load scheduler definitions from the YAML file (an empty file parses
        # to None, which is treated as "no schedulers defined").
        with open(self.PARAMS_PATH, "r") as yaml_file:
            diffuser_params = yaml.safe_load(yaml_file) or {}

        if diffuser_name not in diffuser_params:
            raise ValueError(f"Unsupported diffuser_name: {diffuser_name}")

        params = diffuser_params[diffuser_name]
        scheduler_class_name = params["scheduler_class"]
        # Work on a copy so the overrides below never leak into the parsed YAML.
        scheduler_kwargs = dict(params.get("additional_params") or {})

        # Align the scheduler with the parameters used during training.
        scheduler_kwargs["num_train_timesteps"] = opt.diffusion_steps
        scheduler_kwargs["beta_schedule"] = opt.beta_schedule
        scheduler_kwargs["prediction_type"] = opt.prediction_type

        # The class is resolved by name among this module's globals, i.e. the
        # scheduler classes imported at the top of the file.
        try:
            scheduler_class = globals()[scheduler_class_name]
        except KeyError:
            raise ValueError(
                f"Class '{scheduler_class_name}' not found."
            ) from None
        self.scheduler = scheduler_class(**scheduler_kwargs)

    @staticmethod
    def _mean_batch_time(timed_sum, timed_count, total_sum, total_count):
        """Return the average per-batch time, preferring post-warm-up batches.

        Falls back to the average over all batches when no batch ran after
        the warm-up window, and to ``0.0`` when no batch ran at all.
        """
        if timed_count > 0:
            return timed_sum / timed_count
        if total_count > 0:
            return total_sum / total_count
        return 0.0

    def _synchronize(self):
        """Wait for pending CUDA work so wall-clock timings are meaningful.

        Does nothing when the pipeline does not run on a CUDA device, where
        ``torch.cuda.synchronize()`` would raise.
        """
        if torch.device(self.device).type == "cuda":
            torch.cuda.synchronize()

    def generate_batch(self, caption, m_lens):
        """Run the full sampling loop for one batch of captions.

        Args:
            caption: Sequence of text prompts; its length is the batch size.
            m_lens: Per-sample lengths; must support ``.max()``
                (assumed to be a tensor or array -- confirm against callers).

        Returns:
            The final sample of shape ``(len(caption), m_lens.max(),
            model.input_feats)``; it is not cropped to individual lengths.
        """
        batch = len(caption)
        # Every sample is generated at the longest length in the batch.
        max_len = int(m_lens.max())
        shape = (batch, max_len, self.model.input_feats)

        # Random sampling noise x_T.
        sample = torch.randn(shape, device=self.device, dtype=self.torch_dtype)

        self.scheduler.set_timesteps(self.num_inference_steps, self.device)

        # The text embedding does not depend on the timestep: compute it once.
        enc_text = self.model.encode_text(caption, self.device)
        use_cfg = getattr(self.model, "cond_mask_prob", 0) > 0

        for t in self.scheduler.timesteps:
            # One copy of the current timestep per sample in the batch.
            t_batch = torch.full(
                (batch,), int(t), dtype=torch.long, device=self.device
            )

            # 1. Model prediction.
            with torch.no_grad():
                if use_cfg:
                    predict = self.model.forward_with_cfg(
                        sample, t_batch, enc_text=enc_text
                    )
                else:
                    predict = self.model(sample, t_batch, enc_text=enc_text)

            # 2. Compute the less noisy sample: x_t -> x_t-1.
            sample = self.scheduler.step(predict, t_batch[0], sample).prev_sample

        return sample

    def generate(self, caption, m_lens, batch_size=32):
        """Generate outputs for all captions, batch by batch, and time the batches.

        Args:
            caption: Sequence of text prompts.
            m_lens: Per-sample lengths, indexable and sliceable in step with
                ``caption``.
            batch_size: Maximum number of captions per batch.

        Returns:
            A tuple ``(all_output, t_eval)``: ``all_output`` is a list with
            one entry per caption, each cropped to its length from
            ``m_lens``; ``t_eval`` is the average wall-clock time of a batch
            in seconds, computed over the batches that ran after the warm-up
            window (or over all batches if none did, ``0.0`` for no input).
        """
        total = len(caption)
        infer_mode = ""
        if getattr(self.model, "cond_mask_prob", 0) > 0:
            infer_mode = "classifier-free-guidance"
        print(
            f"\nUsing {self.diffuser_name} diffusion scheduler to {infer_mode} generate {total} motions, sampling {self.num_inference_steps} steps."
        )
        self.model.eval()

        all_output = []
        timed_sum = 0.0
        timed_count = 0
        total_sum = 0.0
        num_batches = math.ceil(total / batch_size)

        for batch_idx in tqdm.tqdm(range(num_batches)):
            # Slices clamp at the end of the sequence, so the last (possibly
            # shorter) batch needs no special case.
            start = batch_idx * batch_size
            stop = start + batch_size
            batch_caption = caption[start:stop]
            batch_m_lens = m_lens[start:stop]

            self._synchronize()
            start_time = time.time()
            output = self.generate_batch(batch_caption, batch_m_lens)
            self._synchronize()
            elapsed = time.time() - start_time

            total_sum += elapsed
            # Only batches that finish after the warm-up window count towards
            # the reported average.
            if (batch_idx + 1) * self.num_inference_steps >= self.WARMUP_STEPS:
                timed_sum += elapsed
                timed_count += 1

            # Crop each sample to its own (ground-truth or predicted) length.
            for i in range(output.shape[0]):
                all_output.append(output[i, : batch_m_lens[i]])

        # Average over the batches that were actually timed.
        t_eval = self._mean_batch_time(
            timed_sum, timed_count, total_sum, num_batches
        )
        print(
            "The average generation time of a batch motion (bs=%d) is %f seconds"
            % (batch_size, t_eval)
        )
        return all_output, t_eval