File size: 4,631 Bytes
b887ad8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from diffusers import (
    DPMSolverMultistepScheduler,
    DDPMScheduler,
    DDIMScheduler,
    PNDMScheduler,
    DEISMultistepScheduler,
)
import torch
import yaml
import math
import tqdm
import time


class DiffusePipeline(object):

    def __init__(
        self,
        opt,
        model,
        diffuser_name,
        num_inference_steps,
        device,
        torch_dtype=torch.float16,
    ):
        self.device = device
        self.torch_dtype = torch_dtype
        self.diffuser_name = diffuser_name
        self.num_inference_steps = num_inference_steps
        if self.torch_dtype == torch.float16:
            model = model.half()
        self.model = model.to(device)
        self.opt = opt

        # Load parameters from YAML file
        with open("config/diffuser_params.yaml", "r") as yaml_file:
            diffuser_params = yaml.safe_load(yaml_file)

        # Select diffusion'parameters based on diffuser_name
        if diffuser_name in diffuser_params:
            params = diffuser_params[diffuser_name]
            scheduler_class_name = params["scheduler_class"]
            additional_params = params["additional_params"]

            # align training parameters
            additional_params["num_train_timesteps"] = opt.diffusion_steps
            additional_params["beta_schedule"] = opt.beta_schedule
            additional_params["prediction_type"] = opt.prediction_type

            try:
                scheduler_class = globals()[scheduler_class_name]
            except KeyError:
                raise ValueError(f"Class '{scheduler_class_name}' not found.")

            self.scheduler = scheduler_class(**additional_params)
        else:
            raise ValueError(f"Unsupported diffuser_name: {diffuser_name}")

    def generate_batch(self, caption, m_lens):
        B = len(caption)
        T = m_lens.max()
        shape = (B, T, self.model.input_feats)

        # random sampling noise x_T
        sample = torch.randn(shape, device=self.device, dtype=self.torch_dtype)

        # set timesteps
        self.scheduler.set_timesteps(self.num_inference_steps, self.device)
        timesteps = [
            torch.tensor([t] * B, device=self.device).long()
            for t in self.scheduler.timesteps
        ]

        # cache text_embedded
        enc_text = self.model.encode_text(caption, self.device)

        for i, t in enumerate(timesteps):
            # 1. model predict
            with torch.no_grad():
                if getattr(self.model, "cond_mask_prob", 0) > 0:
                    predict = self.model.forward_with_cfg(sample, t, enc_text=enc_text)
                else:

                    predict = self.model(sample, t, enc_text=enc_text)

            # 2. compute less noisy motion and set x_t -> x_t-1
            sample = self.scheduler.step(predict, t[0], sample).prev_sample

        return sample

    def generate(self, caption, m_lens, batch_size=32):
        N = len(caption)
        infer_mode = ""
        if getattr(self.model, "cond_mask_prob", 0) > 0:
            infer_mode = "classifier-free-guidance"
        print(
            f"\nUsing {self.diffuser_name} diffusion scheduler to {infer_mode} generate {N} motions, sampling {self.num_inference_steps} steps."
        )
        self.model.eval()

        all_output = []
        t_sum = 0
        cur_idx = 0
        for bacth_idx in tqdm.tqdm(range(math.ceil(N / batch_size))):
            if cur_idx + batch_size >= N:
                batch_caption = caption[cur_idx:]
                batch_m_lens = m_lens[cur_idx:]
            else:
                batch_caption = caption[cur_idx : cur_idx + batch_size]
                batch_m_lens = m_lens[cur_idx : cur_idx + batch_size]
            torch.cuda.synchronize()
            start_time = time.time()
            output = self.generate_batch(batch_caption, batch_m_lens)
            torch.cuda.synchronize()
            now_time = time.time()

            # The average inference time is calculated after GPU warm-up in the first 50 steps.
            if (bacth_idx + 1) * self.num_inference_steps >= 50:
                t_sum += now_time - start_time

            # Crop motion with gt/predicted motion length
            B = output.shape[0]
            for i in range(B):
                all_output.append(output[i, : batch_m_lens[i]])

            cur_idx += batch_size

        # calcalate average inference time
        t_eval = t_sum / (bacth_idx - 1)
        print(
            "The average generation time of a batch motion (bs=%d) is %f seconds"
            % (batch_size, t_eval)
        )
        return all_output, t_eval