File size: 6,784 Bytes
7bc83e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b1ccbd
7bc83e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import torch
import numpy as np

class GaussianMixtureModel:
    """Gaussian mixture with ``K`` components in ``d`` dimensions.

    Besides sampling, log-density and score evaluation, the class knows how
    the mixture evolves under the variance-preserving diffusion used in
    ``forward_diffusion`` (means shrink by ``exp(-t/2)``, covariances relax
    towards the identity) and can integrate the matching probability-flow
    ODE with explicit Euler steps.

    Attributes:
        K: number of mixture components.
        d: dimensionality of each component.
        mu: component means, shape ``(K, d, 1)``.
        Sigma: component covariances, shape ``(K, d, d)``.
        pi: mixture weights, shape ``(K, 1, 1)``.
    """

    def __init__(self, mu, Sigma, pi):
        """Validate and store the mixture parameters.

        Args:
            mu: float32 tensor of shape ``(K, d)`` or ``(K, d, 1)``.
            Sigma: float32 tensor of shape ``(K, d, d)``.
            pi: float32 tensor of shape ``(K,)``, ``(K, 1)`` or ``(K, 1, 1)``
                whose entries sum to 1.

        Raises:
            AssertionError: if a type, dtype, shape or the sum-to-one check
                fails. (Kept as assertions so existing callers that catch
                ``AssertionError`` keep working.)
        """
        # Ensure all inputs are torch tensors
        assert isinstance(mu, torch.Tensor), "mu must be a torch.Tensor."
        assert isinstance(Sigma, torch.Tensor), "Sigma must be a torch.Tensor."
        assert isinstance(pi, torch.Tensor), "pi must be a torch.Tensor."

        # Ensure the dtype is torch.float32
        assert mu.dtype == torch.float32, "mu must have dtype torch.float32."
        assert Sigma.dtype == torch.float32, "Sigma must have dtype torch.float32."
        assert pi.dtype == torch.float32, "pi must have dtype torch.float32."

        self.K, self.d = mu.shape[:2]

        # Promote (K, d) means to column vectors (K, d, 1)
        if mu.shape == (self.K, self.d):
            mu = mu.unsqueeze(-1)
        assert mu.shape == (self.K, self.d, 1), "mu must have shape (K, d, 1)."

        # Check the shape of Sigma
        assert Sigma.shape == (self.K, self.d, self.d), "Sigma must have shape (K, d, d)."

        # Check the shape of pi and fix it if necessary
        if pi.shape == (self.K,):
            pi = pi.view(self.K, 1, 1)
        elif pi.shape == (self.K, 1):
            pi = pi.unsqueeze(-1)
        assert pi.shape == (self.K, 1, 1), "pi must have shape (K, 1, 1)."

        # Ensure pi sums to 1. The reference value is created on pi's own
        # device so the check does not mix devices.
        one = torch.ones((), dtype=pi.dtype, device=pi.device)
        assert torch.isclose(torch.sum(pi), one), "Mixture weights must sum to 1."

        self.mu = mu
        self.Sigma = Sigma
        self.pi = pi

    def _weighted_log_densities(self, x):
        """Return per-component weighted log densities and precision-scaled offsets.

        Args:
            x: tensor reshapeable to ``(B, d)``.

        Returns:
            A pair ``(log_joint, scaled_diff)`` where ``log_joint[b, k]`` is
            ``log(pi_k) + log N(x_b; mu_k, Sigma_k)`` with shape ``(B, K)``
            and ``scaled_diff[b, k]`` is ``Sigma_k^{-1} (x_b - mu_k)`` with
            shape ``(B, K, d)``.
        """
        diff = x.reshape(-1, 1, self.d, 1) - self.mu.unsqueeze(0)  # (B, K, d, 1)
        inv_Sigma = torch.inverse(self.Sigma)  # (K, d, d)
        scaled_diff = torch.matmul(inv_Sigma, diff)  # (B, K, d, 1)
        mahalanobis = (diff * scaled_diff).sum(dim=(-2, -1))  # (B, K)

        # log det(2*pi*Sigma) = d*log(2*pi) + logdet(Sigma); working in log
        # space avoids over/underflow of the raw determinant.
        log_two_pi = float(np.log(2.0 * np.pi))
        log_norm = 0.5 * (self.d * log_two_pi + torch.logdet(self.Sigma))  # (K,)

        log_weights = torch.log(self.pi.reshape(1, self.K))  # (1, K)
        log_joint = log_weights - 0.5 * mahalanobis - log_norm  # (B, K)
        return log_joint, scaled_diff.squeeze(-1)

    def sample(self, n_samples):
        """Draw ``n_samples`` points from the mixture.

        Component indices are drawn from ``pi`` in one call and the Gaussian
        noise is generated in one batch, so the random stream differs from a
        per-sample loop while the sampled distribution is the same.

        Args:
            n_samples: number of points to draw; ``0`` yields an empty tensor.

        Returns:
            Tensor of shape ``(n_samples, d)``.
        """
        if n_samples == 0:
            return self.mu.new_empty((0, self.d))

        # squeeze(-1) (not squeeze()) keeps the d axis when d == 1
        loc = self.mu.squeeze(-1)  # (K, d)
        components = torch.distributions.MultivariateNormal(loc, covariance_matrix=self.Sigma)
        chol = components.scale_tril  # (K, d, d) Cholesky factors of Sigma

        # Choose a component for every sample based on the mixture weights
        ks = torch.multinomial(self.pi.reshape(self.K), n_samples, replacement=True)  # (n,)
        noise = torch.randn(n_samples, self.d, 1, dtype=loc.dtype, device=loc.device)
        return loc[ks] + torch.matmul(chol[ks], noise).squeeze(-1)

    def log_prob(self, x):
        """Compute the mixture log density.

        Args:
            x: a single point with ``d`` elements (any shape reshapeable to
                ``(d,)``) or a batch reshapeable to ``(B, d)``.

        Returns:
            A 0-dim tensor when ``x`` holds exactly ``d`` elements, otherwise
            a tensor of shape ``(B,)``.
        """
        single_point = x.numel() == self.d
        log_joint, _ = self._weighted_log_densities(x)
        log_density = torch.logsumexp(log_joint, dim=1)  # (B,)
        return log_density[0] if single_point else log_density

    def score(self, x):
        """Compute the score (gradient of the log density) for a batch.

        The component responsibilities are obtained with a softmax over the
        weighted log densities, which stays finite even when every component
        density underflows to zero.

        Args:
            x: tensor whose first axis is the batch and which is reshapeable
                to ``(B, d)``.

        Returns:
            Tensor of shape ``(B, d)``.
        """
        B = x.shape[0]
        log_joint, scaled_diff = self._weighted_log_densities(x.reshape(B, self.d))
        responsibilities = torch.softmax(log_joint, dim=1)  # (B, K)
        # score = -sum_k r_k * Sigma_k^{-1} (x - mu_k)
        return -(responsibilities.unsqueeze(-1) * scaled_diff).sum(dim=1)  # (B, d)

    def forward_diffusion(self, t):
        """Return the mixture after diffusing for time ``t``.

        Means are scaled by ``exp(-t/2)`` and each covariance becomes
        ``Sigma * exp(-t) + I * (1 - exp(-t))``; the weights are unchanged.

        Args:
            t: diffusion time.

        Returns:
            A new ``GaussianMixtureModel`` with the evolved parameters.
        """
        device = self.mu.device  # build helper tensors next to the parameters
        mean_scale = torch.tensor(np.exp(-0.5 * t), dtype=torch.float32, device=device)
        exp_neg_t = torch.tensor(np.exp(-t), dtype=torch.float32, device=device)
        identity = torch.eye(self.d, dtype=torch.float32, device=device)

        mu_t = self.mu * mean_scale  # Shape: (K, d, 1)
        Sigma_t = self.Sigma * exp_neg_t + identity * (1 - exp_neg_t)  # Shape: (K, d, d)
        return GaussianMixtureModel(mu_t, Sigma_t, self.pi)

    def flow(self, x_t, t, dt, num_steps):
        """Integrate the probability-flow ODE with ``num_steps`` Euler steps.

        Args:
            x_t: batch of points, shape ``(B, d)``.
            t: start time.
            dt: step size (negative to integrate backwards in time).
            num_steps: number of Euler steps.

        Returns:
            The batch after integration, shape ``(B, d)``.
        """
        for step in range(num_steps):
            # t + step * dt avoids mutating the caller's ``t`` in place and
            # does not accumulate rounding error across steps.
            x_t = self.probability_flow_ode(x_t, t + step * dt, dt)
        return x_t

    def flow_gmm_to_normal(self, x_0, T=5, N=32):
        """Flow points from time 0 to time ``T`` in ``N`` Euler steps."""
        dt = T / N
        return self.flow(x_0, 0, dt, N)

    def flow_normal_to_gmm(self, x_T, T=5, N=32):
        """Flow points from time ``T`` back to time 0 in ``N`` Euler steps."""
        dt = -T / N
        return self.flow(x_T, T, dt, N)

    def probability_flow_ode(self, x_t, t, dt):
        """Advance ``x_t`` by one Euler step of the probability-flow ODE.

        Args:
            x_t: batch of points at time ``t``, shape ``(B, d)``.
            t: current time.
            dt: step size.

        Returns:
            The batch at time ``t + dt``, shape ``(B, d)``.
        """
        # Score of the diffused mixture at time t, evaluated at x_t
        score = self.forward_diffusion(t).score(x_t)

        # Drift of the probability-flow ODE
        drift = -0.5 * x_t - 0.5 * score

        # Euler update
        return x_t + drift * dt

# Demonstration: build a two-component mixture, diffuse it, and run both flows
if __name__ == "__main__":
    # Two unit-covariance components centred at (0, 0) and (1, 1), equal weights
    component_means = torch.tensor([[0, 0], [1, 1]], dtype=torch.float32)  # (K, d)
    component_covs = torch.eye(2, dtype=torch.float32).repeat(2, 1, 1)  # (K, d, d)
    weights = torch.tensor([0.5, 0.5], dtype=torch.float32)  # (K,)
    gmm = GaussianMixtureModel(component_means, component_covs, weights)

    # Diffuse the mixture for one unit of time
    gmm_t = gmm.forward_diffusion(1.0)

    # Draw a handful of samples from the diffused mixture
    print("Samples after forward diffusion:\n", gmm_t.sample(10))

    # Evaluate log density (single point) and score (batch of one) at (0.5, 0.5)
    query = torch.tensor([0.5, 0.5], dtype=torch.float32)
    print("Log Probability after forward diffusion:", gmm_t.log_prob(query))
    print("Score after forward diffusion:", gmm_t.score(query.unsqueeze(0)))

    # Push a data-space point forward to the (approximately) normal end
    start_point = torch.tensor([[0.5, 0.5]], dtype=torch.float32)
    print(
        "x_T_normal after flowing from GMM to normal:",
        gmm.flow_gmm_to_normal(start_point),
    )

    # Pull a point from the normal end back to data space
    end_point = torch.tensor([[0.0, 0.0]], dtype=torch.float32)
    print(
        "x_0_gmm after flowing from normal to GMM:",
        gmm.flow_normal_to_gmm(end_point),
    )