Jonas Becker committed
Commit · 7f19394
1 Parent(s): 2478e2a
1st try

Browse files:
- .gitattributes +1 -0
- .gitignore +2 -0
- app.bat +2 -0
- app.py +50 -0
- disvae/__init__.py +6 -0
- disvae/evaluate.py +317 -0
- disvae/main.py +145 -0
- disvae/models/__init__.py +0 -0
- disvae/models/decoders.py +84 -0
- disvae/models/discriminator.py +73 -0
- disvae/models/encoders.py +89 -0
- disvae/models/losses.py +544 -0
- disvae/models/vae.py +101 -0
- disvae/training.py +212 -0
- disvae/utils/__init__.py +0 -0
- disvae/utils/initialization.py +61 -0
- disvae/utils/math.py +73 -0
- disvae/utils/modelIO.py +200 -0
- model/drilling_ds_btcvae/model-0.pt +3 -0
- model/drilling_ds_btcvae/model-10.pt +3 -0
- model/drilling_ds_btcvae/model-20.pt +3 -0
- model/drilling_ds_btcvae/model.pt +3 -0
- model/drilling_ds_btcvae/onnx/decoder.onnx +3 -0
- model/drilling_ds_btcvae/onnx/encoder.onnx +3 -0
- model/drilling_ds_btcvae/specs.json +9 -0
- model/drilling_ds_btcvae/train_losses.log +481 -0
- requirements.txt +11 -0
- transforms.py +168 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,2 @@
+venv
+__pycache__
app.bat
ADDED
@@ -0,0 +1,2 @@
+call venv\scripts\activate.bat
+call streamlit run app.py
app.py
ADDED
@@ -0,0 +1,50 @@
+import numpy as np
+import streamlit as st
+import torch
+import pandas as pd
+
+import disvae
+import transforms as trans
+
+P_MODEL = "model/drilling_ds_btcvae"
+
+# Decode function ---------------------------------------------------
+sorter = trans.LatentSorter(disvae.get_kl_dict(P_MODEL))
+vae = disvae.load_model(P_MODEL)
+scaler = trans.MinMaxScaler(_min=torch.tensor([1.3]), _max=torch.tensor([4.0]), min_norm=0.3, max_norm=0.6)
+imaging = trans.SumField()
+
+_dec = trans.sequential_function(
+    sorter.inv,
+    vae.decoder,
+    scaler.inv
+)
+
+def decode(latent):
+    with torch.no_grad():
+        return trans.np_sample(_dec)(latent)
+
+img2ts = trans.np_sample(imaging.inv)
+
+# GUI ----------------------------------------------------------------
+
+latent_vector = np.array([st.slider(f"L{l}", min_value=-3.0, max_value=3.0, value=0.0) for l in range(3)])
+latent_vector = np.concatenate([latent_vector, np.zeros(7)], axis=0)
+
+value = decode(latent_vector)
+
+ts = img2ts(value)
+
+df = pd.DataFrame({
+    "x": np.arange(len(ts)),
+    "y": ts.ravel()
+})
+
+st.line_chart(df, x="x", y="y")
+st.write(ts)
+# st.write(value)
+# st.image(value, use_column_width="always")
+
+# x = st.slider("Select a value")
+# st.write(x, "squared is", x * x)
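For orientation: `trans.sequential_function` chains the inverse latent sort, the decoder, and the inverse scaling into one callable. transforms.py is added by this commit but its diff is not shown above, so the following is only a minimal sketch of such a left-to-right composition helper, not the actual implementation.

from functools import reduce

def sequential_function(*funcs):
    """Compose callables left to right: sequential_function(f, g)(x) == g(f(x)) (assumed semantics)."""
    return lambda x: reduce(lambda acc, f: f(acc), funcs, x)

# Tiny check of the composition order.
pipeline = sequential_function(lambda z: z * 2, lambda z: z + 1)
assert pipeline(3) == 7  # (3 * 2) + 1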
disvae/__init__.py
ADDED
@@ -0,0 +1,6 @@
+from disvae.evaluate import Evaluator
+from disvae.main import get_kl_dict
+from disvae.training import Trainer
+from disvae.utils.modelIO import load_model, save_model
+
+# from disvae.models.vae import init_specific_model  # necessary?
disvae/evaluate.py
ADDED
@@ -0,0 +1,317 @@
+import os
+import logging
+import math
+from functools import reduce
+from collections import defaultdict
+import json
+from timeit import default_timer
+
+from tqdm import trange, tqdm
+import numpy as np
+import torch
+
+from disvae.models.losses import get_loss_f
+from disvae.utils.math import log_density_gaussian
+from disvae.utils.modelIO import save_metadata
+
+TEST_LOSSES_FILE = "test_losses.log"
+METRICS_FILENAME = "metrics.log"
+METRIC_HELPERS_FILE = "metric_helpers.pth"
+
+
+class Evaluator:
+    """
+    Class to handle evaluation of a model.
+
+    Parameters
+    ----------
+    model: disvae.vae.VAE
+
+    loss_f: disvae.models.BaseLoss
+        Loss function.
+
+    device: torch.device, optional
+        Device on which to run the code.
+
+    logger: logging.Logger, optional
+        Logger.
+
+    save_dir : str, optional
+        Directory for saving logs.
+
+    is_progress_bar: bool, optional
+        Whether to use a progress bar for evaluation.
+    """
+
+    def __init__(self, model, loss_f,
+                 device=torch.device("cpu"),
+                 logger=logging.getLogger(__name__),
+                 save_dir="results",
+                 is_progress_bar=True):
+
+        self.device = device
+        self.loss_f = loss_f
+        self.model = model.to(self.device)
+        self.logger = logger
+        self.save_dir = save_dir
+        self.is_progress_bar = is_progress_bar
+        self.logger.info("Testing Device: {}".format(self.device))
+
+    def __call__(self, data_loader, is_metrics=False, is_losses=True):
+        """Compute all test losses.
+
+        Parameters
+        ----------
+        data_loader: torch.utils.data.DataLoader
+
+        is_metrics: bool, optional
+            Whether to compute and store the disentangling metrics.
+
+        is_losses: bool, optional
+            Whether to compute and store the test losses.
+        """
+        start = default_timer()
+        is_still_training = self.model.training
+        self.model.eval()
+
+        metrics, losses = None, None
+        if is_metrics:
+            self.logger.info('Computing metrics...')
+            metrics = self.compute_metrics(data_loader)
+            self.logger.info('Metrics: {}'.format(metrics))
+            save_metadata(metrics, self.save_dir, filename=METRICS_FILENAME)
+
+        if is_losses:
+            self.logger.info('Computing losses...')
+            losses = self.compute_losses(data_loader)
+            self.logger.info('Losses: {}'.format(losses))
+            save_metadata(losses, self.save_dir, filename=TEST_LOSSES_FILE)
+
+        if is_still_training:
+            self.model.train()
+
+        self.logger.info('Finished evaluating after {:.1f} min.'.format((default_timer() - start) / 60))
+
+        return metrics, losses
+
+    def compute_losses(self, dataloader):
+        """Compute all test losses.
+
+        Parameters
+        ----------
+        dataloader: torch.utils.data.DataLoader
+        """
+        storer = defaultdict(list)
+        for data, _ in tqdm(dataloader, leave=False, disable=not self.is_progress_bar):
+            data = data.to(self.device)
+
+            try:
+                recon_batch, latent_dist, latent_sample = self.model(data)
+                _ = self.loss_f(data, recon_batch, latent_dist, self.model.training,
+                                storer, latent_sample=latent_sample)
+            except ValueError:
+                # for losses that use multiple optimizers (e.g. Factor)
+                _ = self.loss_f.call_optimize(data, self.model, None, storer)
+
+        losses = {k: sum(v) / len(dataloader) for k, v in storer.items()}
+        return losses
+
+    def compute_metrics(self, dataloader):
+        """Compute all the metrics.
+
+        Parameters
+        ----------
+        dataloader: torch.utils.data.DataLoader
+        """
+        try:
+            lat_sizes = dataloader.dataset.lat_sizes
+            lat_names = dataloader.dataset.lat_names
+        except AttributeError:
+            raise ValueError("Dataset needs to have known true factors of variation to compute the metric. This does not seem to be the case for {}".format(type(dataloader.__dict__["dataset"]).__name__))
+
+        self.logger.info("Computing the empirical distribution q(z|x).")
+        samples_zCx, params_zCx = self._compute_q_zCx(dataloader)
+        len_dataset, latent_dim = samples_zCx.shape
+
+        self.logger.info("Estimating the marginal entropy.")
+        # marginal entropy H(z_j)
+        H_z = self._estimate_latent_entropies(samples_zCx, params_zCx)
+
+        # conditional entropy H(z|v)
+        samples_zCx = samples_zCx.view(*lat_sizes, latent_dim)
+        params_zCx = tuple(p.view(*lat_sizes, latent_dim) for p in params_zCx)
+        H_zCv = self._estimate_H_zCv(samples_zCx, params_zCx, lat_sizes, lat_names)
+
+        H_z = H_z.cpu()
+        H_zCv = H_zCv.cpu()
+
+        # I[z_j;v_k] = E[log \sum_x q(z_j|x)p(x|v_k)] + H[z_j] = - H[z_j|v_k] + H[z_j]
+        mut_info = - H_zCv + H_z
+        sorted_mut_info = torch.sort(mut_info, dim=1, descending=True)[0].clamp(min=0)
+
+        metric_helpers = {'marginal_entropies': H_z, 'cond_entropies': H_zCv}
+        mig = self._mutual_information_gap(sorted_mut_info, lat_sizes, storer=metric_helpers)
+        aam = self._axis_aligned_metric(sorted_mut_info, storer=metric_helpers)
+
+        metrics = {'MIG': mig.item(), 'AAM': aam.item()}
+        torch.save(metric_helpers, os.path.join(self.save_dir, METRIC_HELPERS_FILE))
+
+        return metrics
+
+    def _mutual_information_gap(self, sorted_mut_info, lat_sizes, storer=None):
+        """Compute the mutual information gap as in [1].
+
+        References
+        ----------
+        [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
+            autoencoders." Advances in Neural Information Processing Systems. 2018.
+        """
+        # difference between the largest and second largest mutual info
+        delta_mut_info = sorted_mut_info[:, 0] - sorted_mut_info[:, 1]
+        # NOTE: currently only works if the dataset is balanced for every factor of variation,
+        # because then H(v_k) = - |V_k|/|V_k| log(1/|V_k|) = log(|V_k|)
+        H_v = torch.from_numpy(lat_sizes).float().log()
+        mig_k = delta_mut_info / H_v
+        mig = mig_k.mean()  # mean over factors of variation
+
+        if storer is not None:
+            storer["mig_k"] = mig_k
+            storer["mig"] = mig
+
+        return mig
+
+    def _axis_aligned_metric(self, sorted_mut_info, storer=None):
+        """Compute the proposed axis-aligned metric."""
+        numerator = (sorted_mut_info[:, 0] - sorted_mut_info[:, 1:].sum(dim=1)).clamp(min=0)
+        aam_k = numerator / sorted_mut_info[:, 0]
+        aam_k[torch.isnan(aam_k)] = 0
+        aam = aam_k.mean()  # mean over factors of variation
+
+        if storer is not None:
+            storer["aam_k"] = aam_k
+            storer["aam"] = aam
+
+        return aam
+
+    def _compute_q_zCx(self, dataloader):
+        """Compute the empirical distribution of q(z|x).
+
+        Parameters
+        ----------
+        dataloader: torch.utils.data.DataLoader
+            Batch data iterator.
+
+        Returns
+        -------
+        samples_zCx: torch.tensor
+            Tensor of shape (len_dataset, latent_dim) containing a sample of
+            q(z|x) for every x in the dataset.
+
+        params_zCX: tuple of torch.Tensor
+            Sufficient statistics of q(z|x) for each training example. E.g. for
+            a gaussian, (mean, log_var) each of shape (len_dataset, latent_dim).
+        """
+        len_dataset = len(dataloader.dataset)
+        latent_dim = self.model.latent_dim
+        n_suff_stat = 2
+
+        q_zCx = torch.zeros(len_dataset, latent_dim, n_suff_stat, device=self.device)
+
+        n = 0
+        with torch.no_grad():
+            for x, label in dataloader:
+                batch_size = x.size(0)
+                idcs = slice(n, n + batch_size)
+                q_zCx[idcs, :, 0], q_zCx[idcs, :, 1] = self.model.encoder(x.to(self.device))
+                n += batch_size
+
+        params_zCX = q_zCx.unbind(-1)
+        samples_zCx = self.model.reparameterize(*params_zCX)
+
+        return samples_zCx, params_zCX
+
+    def _estimate_latent_entropies(self, samples_zCx, params_zCX,
+                                   n_samples=10000):
+        r"""Estimate :math:`H(z_j) = E_{q(z_j)} [-log q(z_j)] = E_{p(x)} E_{q(z_j|x)} [-log q(z_j)]`
+        using the empirical distribution of :math:`p(x)`.
+
+        Note
+        ----
+        - the expectation over the empirical distribution is: :math:`q(z) = 1/N \sum_{n=1}^N q(z|x_n)`.
+        - we assume that q(z|x) is factorial, i.e. :math:`q(z|x) = \prod_j q(z_j|x)`.
+        - computes the numerically stable NLL: :math:`- log q(z) = log N - logsumexp_{n=1}^N log q(z|x_n)`.
+
+        Parameters
+        ----------
+        samples_zCx: torch.tensor
+            Tensor of shape (len_dataset, latent_dim) containing a sample of
+            q(z|x) for every x in the dataset.
+
+        params_zCX: tuple of torch.Tensor
+            Sufficient statistics of q(z|x) for each training example. E.g. for
+            a gaussian, (mean, log_var) each of shape (len_dataset, latent_dim).
+
+        n_samples: int, optional
+            Number of samples to use to estimate the entropies.
+
+        Returns
+        -------
+        H_z: torch.Tensor
+            Tensor of shape (latent_dim) containing the marginal entropies H(z_j)
+        """
+        len_dataset, latent_dim = samples_zCx.shape
+        device = samples_zCx.device
+        H_z = torch.zeros(latent_dim, device=device)
+
+        # sample from p(x)
+        samples_x = torch.randperm(len_dataset, device=device)[:n_samples]
+        # sample from p(z|x)
+        samples_zCx = samples_zCx.index_select(0, samples_x).view(latent_dim, n_samples)
+
+        mini_batch_size = 10
+        samples_zCx = samples_zCx.expand(len_dataset, latent_dim, n_samples)
+        mean = params_zCX[0].unsqueeze(-1).expand(len_dataset, latent_dim, n_samples)
+        log_var = params_zCX[1].unsqueeze(-1).expand(len_dataset, latent_dim, n_samples)
+        log_N = math.log(len_dataset)
+        with trange(n_samples, leave=False, disable=not self.is_progress_bar) as t:
+            for k in range(0, n_samples, mini_batch_size):
+                # log q(z_j|x) for n_samples
+                idcs = slice(k, k + mini_batch_size)
+                log_q_zCx = log_density_gaussian(samples_zCx[..., idcs],
+                                                 mean[..., idcs],
+                                                 log_var[..., idcs])
+                # numerically stable log q(z_j) for n_samples:
+                # log q(z_j) = -log N + logsumexp_{n=1}^N log q(z_j|x_n)
+                # As we don't know q(z), we approximate it with the Monte Carlo
+                # expectation of q(z_j|x_n) over x, i.e. fix a single z and look at
+                # the probability for every x to generate it. n_samples is not used here!
+                log_q_z = -log_N + torch.logsumexp(log_q_zCx, dim=0, keepdim=False)
+                # H(z_j) = E_{z_j}[- log q(z_j)]
+                # mean over n_samples (i.e. dimension 1, because already summed over 0).
+                H_z += (-log_q_z).sum(1)
+
+                t.update(mini_batch_size)
+
+        H_z /= n_samples
+
+        return H_z
+
+    def _estimate_H_zCv(self, samples_zCx, params_zCx, lat_sizes, lat_names):
+        """Estimate conditional entropies :math:`H[z|v]`."""
+        latent_dim = samples_zCx.size(-1)
+        len_dataset = reduce((lambda x, y: x * y), lat_sizes)
+        H_zCv = torch.zeros(len(lat_sizes), latent_dim, device=self.device)
+        for i_fac_var, (lat_size, lat_name) in enumerate(zip(lat_sizes, lat_names)):
+            idcs = [slice(None)] * len(lat_sizes)
+            for i in range(lat_size):
+                self.logger.info("Estimating conditional entropies for the {}th value of {}.".format(i, lat_name))
+                idcs[i_fac_var] = i
+                # samples from q(z,x|v)
+                samples_zxCv = samples_zCx[idcs].contiguous().view(len_dataset // lat_size,
+                                                                   latent_dim)
+                params_zxCv = tuple(p[idcs].contiguous().view(len_dataset // lat_size, latent_dim)
+                                    for p in params_zCx)
+
+                H_zCv[i_fac_var] += self._estimate_latent_entropies(samples_zxCv, params_zxCv
+                                                                    ) / lat_size
+        return H_zCv
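A hedged usage sketch (not part of this commit): the Evaluator wraps an already-built model and loss function. Computing MIG/AAM via is_metrics=True additionally requires the dataset to expose lat_sizes and lat_names; plain test losses do not.

import torch
from disvae.evaluate import Evaluator

# `model`, `loss_f` and `dataloader` are assumed to have been built as in
# disvae.main.train(); they are placeholders here, not values from the commit.
evaluator = Evaluator(model, loss_f, device=torch.device("cpu"), save_dir="results")
metrics, losses = evaluator(dataloader, is_metrics=False, is_losses=True)
print(losses)  # averaged test losses, e.g. {"loss": ..., "recon_loss": ..., "kl_loss": ...}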
disvae/main.py
ADDED
@@ -0,0 +1,145 @@
+# Pip packages ------------------------------------------------------
+import importlib
+import os
+import sys
+from datetime import datetime
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import torch
+from torch import optim
+from torch.utils.data import DataLoader
+
+# From local package ------------------------------------------------
+from disvae.models.losses import get_loss_f
+from disvae.models.vae import init_specific_model
+from disvae.training import Trainer
+from disvae.utils.modelIO import save_model
+
+# Loss stuff:
+
+
+def parse_losses(p_model, filename="train_losses.log"):
+    df = pd.read_csv(Path(p_model) / filename)
+
+    losses = df["Loss"].unique()
+
+    rtn = [np.array(df[df["Loss"] == l]["Value"]) for l in losses]
+    rtn = pd.DataFrame(np.array(rtn).T, columns=losses)
+
+    return rtn
+
+
+def get_kl_loss_latent(df):
+    """df must already be parsed!"""
+    rtn = {int(c.split("_")[-1]): df[c].iloc[-1] for c in df if "kl_loss_" in c}
+    rtn = dict(sorted(rtn.items(), key=lambda item: item[1], reverse=True))
+    return rtn
+
+
+def get_kl_dict(p_model):
+    df = parse_losses(p_model)
+    rtn = get_kl_loss_latent(df)
+    return rtn
+
+
+# Dataloader convenience stuff
+
+
+# def get_dataloader(dataset: torch.data.Dataset, batch_size, num_workers):
+#     # This function is fairly involved; it is quicker to do in the notebook.
+#     # These things are also needed for visualizing the dataset.
+
+#     # p_dataset_module, dataset_class, dataset_args
+#     # Import module
+#     # if p_dataset_module not in sys.path:
+#     #     sys.path.append(str(Path(p_dataset_module).parent))
+#     # Dataset = getattr(
+#     #     importlib.import_module(Path(p_dataset_module).stem), dataset_class
+#     # )
+
+#     # # From here on as if it had been imported normally
+#     # ds = Dataset(**dataset_args)
+
+#
+
+#     return loader
+
+
+def get_export_dir(base_dir: str, folder_name):
+    if folder_name is None:
+        folder_name = "Model_" + (
+            datetime.now().replace(microsecond=0).isoformat()
+        ).replace(" ", "_").replace(":", "-")
+
+    rtn = Path(base_dir) / folder_name
+
+    if not rtn.exists():
+        os.makedirs(rtn)
+    else:
+        raise ValueError("Output directory already exists.")
+
+    return rtn
+
+
+def train_model(model, data_loader, loss_f, device, lr, epochs, export_dir):
+    trainer = Trainer(
+        model,
+        optim.Adam(model.parameters(), lr=lr),
+        loss_f,
+        device=device,
+        # logger=logger,
+        save_dir=export_dir,
+        is_progress_bar=True,
+    )  # ,
+    # gif_visualizer=gif_visualizer)
+    trainer(data_loader, epochs=epochs, checkpoint_every=10)
+
+    save_model(trainer.model, export_dir)
+    # , metadata=config)  # saving already happens earlier anyway
+
+    # gif_visualizer = GifTraversalsTraining(model, args.dataset, exp_dir)
+
+
+def train(dataset, config) -> str:
+    # Validate config?
+
+    print("1) Set device")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Device:\t\t {device}")
+
+    print("2) Get dataloader")
+    dataloader = DataLoader(
+        dataset,
+        batch_size=config["data_params"]["batch_size"],
+        shuffle=True,
+        pin_memory=torch.cuda.is_available(),  # call the function, not pass it
+        num_workers=config["data_params"]["num_workers"],
+    )
+
+    print("3) Build model")
+    img_size = list(dataloader.dataset[0][0].shape)
+    print(f"Image size: \t {img_size}")
+    model = init_specific_model(img_size=img_size, **config["model_params"])
+    model = model.to(device)  # make sure trainer and viz are on the same device
+
+    print("4) Build loss function")
+    loss_f = get_loss_f(
+        n_data=len(dataloader.dataset), device=device, **config["loss_params"]
+    )
+
+    print("5) Parse export params")
+    export_dir = get_export_dir(**config["export_params"])
+
+    print("6) Training model")
+    train_model(
+        model=model,
+        data_loader=dataloader,
+        loss_f=loss_f,
+        device=device,
+        export_dir=export_dir,
+        **config["trainer_params"],
+    )
+
+    return export_dir
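Reading train() above, the config is a nested dict with five groups. A hedged sketch of that dict: the key paths are read off the code, the concrete values are purely illustrative and not shipped with this commit.

# `dataset` is any torch Dataset yielding (image, label) pairs of size 32x32 or 64x64.
config = {
    "data_params": {"batch_size": 64, "num_workers": 0},
    "model_params": {"model_type": "Burgess", "latent_dim": 10},
    "loss_params": {
        "loss_name": "btcvae",
        "rec_dist": "bernoulli",
        "reg_anneal": 0,
        "btcvae_A": 1.0,   # alpha: mutual information weight
        "btcvae_B": 6.0,   # beta: total correlation weight
        "btcvae_G": 1.0,   # gamma: dimension-wise KL weight
    },
    "export_params": {"base_dir": "model", "folder_name": None},
    "trainer_params": {"lr": 5e-4, "epochs": 100},
}

export_dir = train(dataset, config)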
disvae/models/__init__.py
ADDED
File without changes
disvae/models/decoders.py
ADDED
@@ -0,0 +1,84 @@
+"""
+Module containing the decoders.
+"""
+import numpy as np
+
+import torch
+from torch import nn
+
+
+# ALL decoders should be called Decoder<Model>
+def get_decoder(model_type):
+    model_type = model_type.lower().capitalize()
+    return eval("Decoder{}".format(model_type))
+
+
+class DecoderBurgess(nn.Module):
+    def __init__(self, img_size,
+                 latent_dim=10):
+        r"""Decoder of the model proposed in [1].
+
+        Parameters
+        ----------
+        img_size : tuple of ints
+            Size of images. E.g. (1, 32, 32) or (3, 64, 64).
+
+        latent_dim : int
+            Dimensionality of latent output.
+
+        Model Architecture (transposed for decoder)
+        ------------
+        - 4 convolutional layers (each with 32 channels), (4 x 4 kernel), (stride of 2)
+        - 2 fully connected layers (each of 256 units)
+        - Latent distribution:
+            - 1 fully connected layer of 20 units (log variance and mean for 10 Gaussians)
+
+        References:
+            [1] Burgess, Christopher P., et al. "Understanding disentangling in
+            $\beta$-VAE." arXiv preprint arXiv:1804.03599 (2018).
+        """
+        super(DecoderBurgess, self).__init__()
+
+        # Layer parameters
+        hid_channels = 32
+        kernel_size = 4
+        hidden_dim = 256
+        self.img_size = img_size
+        # Shape required to start transpose convs
+        self.reshape = (hid_channels, kernel_size, kernel_size)
+        n_chan = self.img_size[0]
+        self.img_size = img_size
+
+        # Fully connected layers
+        self.lin1 = nn.Linear(latent_dim, hidden_dim)
+        self.lin2 = nn.Linear(hidden_dim, hidden_dim)
+        self.lin3 = nn.Linear(hidden_dim, np.product(self.reshape))
+
+        # Convolutional layers
+        cnn_kwargs = dict(stride=2, padding=1)
+        # If input image is 64x64 do fourth convolution
+        if self.img_size[1] == self.img_size[2] == 64:
+            self.convT_64 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
+
+        self.convT1 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
+        self.convT2 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
+        self.convT3 = nn.ConvTranspose2d(hid_channels, n_chan, kernel_size, **cnn_kwargs)
+
+    def forward(self, z):
+        batch_size = z.size(0)
+
+        # Fully connected layers with ReLU activations
+        x = torch.relu(self.lin1(z))
+        x = torch.relu(self.lin2(x))
+        x = torch.relu(self.lin3(x))
+        x = x.view(batch_size, *self.reshape)
+
+        # Convolutional layers with ReLU activations
+        if self.img_size[1] == self.img_size[2] == 64:
+            x = torch.relu(self.convT_64(x))
+        x = torch.relu(self.convT1(x))
+        x = torch.relu(self.convT2(x))
+        # Sigmoid activation for final conv layer
+        x = torch.sigmoid(self.convT3(x))
+
+        return x
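A quick shape check (illustrative, not part of the commit): with the Burgess architecture above, a latent batch decodes back to the configured image size.

import torch
from disvae.models.decoders import get_decoder

decoder = get_decoder("Burgess")(img_size=(1, 64, 64), latent_dim=10)
z = torch.randn(8, 10)  # batch of 8 latent codes
x = decoder(z)
print(x.shape)          # torch.Size([8, 1, 64, 64]), values in (0, 1) from the sigmoid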
disvae/models/discriminator.py
ADDED
@@ -0,0 +1,73 @@
+"""
+Module containing the discriminator for FactorVAE.
+"""
+from torch import nn
+
+from disvae.utils.initialization import weights_init
+
+
+class Discriminator(nn.Module):
+    def __init__(self,
+                 neg_slope=0.2,
+                 latent_dim=10,
+                 hidden_units=1000):
+        """Discriminator proposed in [1].
+
+        Parameters
+        ----------
+        neg_slope: float
+            Hyperparameter for the Leaky ReLU
+
+        latent_dim : int
+            Dimensionality of latent variables.
+
+        hidden_units: int
+            Number of hidden units in the MLP
+
+        Model Architecture
+        ------------
+        - 6 layer multi-layer perceptron, each with 1000 hidden units
+        - Leaky ReLU activations
+        - Output 2 logits
+
+        References:
+            [1] Kim, Hyunjik, and Andriy Mnih. "Disentangling by factorising."
+            arXiv preprint arXiv:1802.05983 (2018).
+
+        """
+        super(Discriminator, self).__init__()
+
+        # Activation parameters
+        self.neg_slope = neg_slope
+        self.leaky_relu = nn.LeakyReLU(self.neg_slope, True)
+
+        # Layer parameters
+        self.z_dim = latent_dim
+        self.hidden_units = hidden_units
+        # theoretically 1 with sigmoid, but that gives bad results => use 2 and softmax
+        out_units = 2
+
+        # Fully connected layers
+        self.lin1 = nn.Linear(self.z_dim, hidden_units)
+        self.lin2 = nn.Linear(hidden_units, hidden_units)
+        self.lin3 = nn.Linear(hidden_units, hidden_units)
+        self.lin4 = nn.Linear(hidden_units, hidden_units)
+        self.lin5 = nn.Linear(hidden_units, hidden_units)
+        self.lin6 = nn.Linear(hidden_units, out_units)
+
+        self.reset_parameters()
+
+    def forward(self, z):
+
+        # Fully connected layers with leaky ReLU activations
+        z = self.leaky_relu(self.lin1(z))
+        z = self.leaky_relu(self.lin2(z))
+        z = self.leaky_relu(self.lin3(z))
+        z = self.leaky_relu(self.lin4(z))
+        z = self.leaky_relu(self.lin5(z))
+        z = self.lin6(z)
+
+        return z
+
+    def reset_parameters(self):
+        self.apply(weights_init)
disvae/models/encoders.py
ADDED
@@ -0,0 +1,89 @@
+"""
+Module containing the encoders.
+"""
+import numpy as np
+
+import torch
+from torch import nn
+
+
+# ALL encoders should be called Encoder<Model>
+def get_encoder(model_type):
+    model_type = model_type.lower().capitalize()
+    return eval("Encoder{}".format(model_type))
+
+
+class EncoderBurgess(nn.Module):
+    def __init__(self, img_size,
+                 latent_dim=10):
+        r"""Encoder of the model proposed in [1].
+
+        Parameters
+        ----------
+        img_size : tuple of ints
+            Size of images. E.g. (1, 32, 32) or (3, 64, 64).
+
+        latent_dim : int
+            Dimensionality of latent output.
+
+        Model Architecture (transposed for decoder)
+        ------------
+        - 4 convolutional layers (each with 32 channels), (4 x 4 kernel), (stride of 2)
+        - 2 fully connected layers (each of 256 units)
+        - Latent distribution:
+            - 1 fully connected layer of 20 units (log variance and mean for 10 Gaussians)
+
+        References:
+            [1] Burgess, Christopher P., et al. "Understanding disentangling in
+            $\beta$-VAE." arXiv preprint arXiv:1804.03599 (2018).
+        """
+        super(EncoderBurgess, self).__init__()
+
+        # Layer parameters
+        hid_channels = 32
+        kernel_size = 4
+        hidden_dim = 256
+        self.latent_dim = latent_dim
+        self.img_size = img_size
+        # Shape required to start transpose convs
+        self.reshape = (hid_channels, kernel_size, kernel_size)
+        n_chan = self.img_size[0]
+
+        # Convolutional layers
+        cnn_kwargs = dict(stride=2, padding=1)
+        self.conv1 = nn.Conv2d(n_chan, hid_channels, kernel_size, **cnn_kwargs)
+        self.conv2 = nn.Conv2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
+        self.conv3 = nn.Conv2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
+
+        # If input image is 64x64 do fourth convolution
+        if self.img_size[1] == self.img_size[2] == 64:
+            self.conv_64 = nn.Conv2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
+
+        # Fully connected layers
+        self.lin1 = nn.Linear(np.product(self.reshape), hidden_dim)
+        self.lin2 = nn.Linear(hidden_dim, hidden_dim)
+
+        # Fully connected layers for mean and variance
+        self.mu_logvar_gen = nn.Linear(hidden_dim, self.latent_dim * 2)
+
+    def forward(self, x):
+        batch_size = x.size(0)
+
+        # Convolutional layers with ReLU activations
+        x = torch.relu(self.conv1(x))
+        x = torch.relu(self.conv2(x))
+        x = torch.relu(self.conv3(x))
+        if self.img_size[1] == self.img_size[2] == 64:
+            x = torch.relu(self.conv_64(x))
+
+        # Fully connected layers with ReLU activations
+        x = x.view((batch_size, -1))
+        x = torch.relu(self.lin1(x))
+        x = torch.relu(self.lin2(x))
+
+        # Fully connected layer for log variance and mean
+        # Log std-dev in paper (bear in mind)
+        mu_logvar = self.mu_logvar_gen(x)
+        mu, logvar = mu_logvar.view(-1, self.latent_dim, 2).unbind(-1)
+
+        return mu, logvar
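The mirror check for the encoder (again illustrative, not part of the commit): it returns the Gaussian sufficient statistics per latent dimension.

import torch
from disvae.models.encoders import get_encoder

encoder = get_encoder("Burgess")(img_size=(1, 64, 64), latent_dim=10)
x = torch.rand(8, 1, 64, 64)   # batch of 8 single-channel images
mu, logvar = encoder(x)
print(mu.shape, logvar.shape)  # torch.Size([8, 10]) torch.Size([8, 10])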
disvae/models/losses.py
ADDED
@@ -0,0 +1,544 @@
+"""
+Module containing all vae losses.
+"""
+import abc
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch import optim
+
+from .discriminator import Discriminator
+from disvae.utils.math import (log_density_gaussian, log_importance_weight_matrix,
+                               matrix_log_density_gaussian)
+
+
+LOSSES = ["VAE", "betaH", "betaB", "factor", "btcvae"]
+RECON_DIST = ["bernoulli", "laplace", "gaussian"]
+
+
+# TO-DO: clean n_data and device
+def get_loss_f(loss_name, **kwargs_parse):
+    """Return the correct loss function given the argparse arguments."""
+    kwargs_all = dict(rec_dist=kwargs_parse["rec_dist"],
+                      steps_anneal=kwargs_parse["reg_anneal"])
+    if loss_name == "betaH":
+        return BetaHLoss(beta=kwargs_parse["betaH_B"], **kwargs_all)
+    elif loss_name == "VAE":
+        return BetaHLoss(beta=1, **kwargs_all)
+    elif loss_name == "betaB":
+        return BetaBLoss(C_init=kwargs_parse["betaB_initC"],
+                         C_fin=kwargs_parse["betaB_finC"],
+                         gamma=kwargs_parse["betaB_G"],
+                         **kwargs_all)
+    elif loss_name == "factor":
+        return FactorKLoss(kwargs_parse["device"],
+                           gamma=kwargs_parse["factor_G"],
+                           disc_kwargs=dict(latent_dim=kwargs_parse["latent_dim"]),
+                           optim_kwargs=dict(lr=kwargs_parse["lr_disc"], betas=(0.5, 0.9)),
+                           **kwargs_all)
+    elif loss_name == "btcvae":
+        return BtcvaeLoss(kwargs_parse["n_data"],
+                          alpha=kwargs_parse["btcvae_A"],
+                          beta=kwargs_parse["btcvae_B"],
+                          gamma=kwargs_parse["btcvae_G"],
+                          **kwargs_all)
+    else:
+        assert loss_name not in LOSSES
+        raise ValueError("Unknown loss : {}".format(loss_name))
+
+
+class BaseLoss(abc.ABC):
+    """
+    Base class for losses.
+
+    Parameters
+    ----------
+    record_loss_every: int, optional
+        Every how many steps to record the loss.
+
+    rec_dist: {"bernoulli", "gaussian", "laplace"}, optional
+        Reconstruction distribution of the likelihood on each pixel.
+        Implicitly defines the reconstruction loss. Bernoulli corresponds to a
+        binary cross entropy (bce), Gaussian corresponds to MSE, Laplace
+        corresponds to L1.
+
+    steps_anneal: int, optional
+        Number of annealing steps during which the regularisation is gradually added.
+    """
+
+    def __init__(self, record_loss_every=50, rec_dist="bernoulli", steps_anneal=0):
+        self.n_train_steps = 0
+        self.record_loss_every = record_loss_every
+        self.rec_dist = rec_dist
+        self.steps_anneal = steps_anneal
+
+    @abc.abstractmethod
+    def __call__(self, data, recon_data, latent_dist, is_train, storer, **kwargs):
+        """
+        Calculates loss for a batch of data.
+
+        Parameters
+        ----------
+        data : torch.Tensor
+            Input data (e.g. batch of images). Shape : (batch_size, n_chan,
+            height, width).
+
+        recon_data : torch.Tensor
+            Reconstructed data. Shape : (batch_size, n_chan, height, width).
+
+        latent_dist : tuple of torch.tensor
+            Sufficient statistics of the latent dimension. E.g. for gaussian
+            (mean, log_var) each of shape : (batch_size, latent_dim).
+
+        is_train : bool
+            Whether currently in train mode.
+
+        storer : dict
+            Dictionary in which to store important variables for visualisation.
+
+        kwargs:
+            Loss specific arguments
+        """
+
+    def _pre_call(self, is_train, storer):
+        if is_train:
+            self.n_train_steps += 1
+
+        if not is_train or self.n_train_steps % self.record_loss_every == 1:
+            storer = storer
+        else:
+            storer = None
+
+        return storer
+
+
+class BetaHLoss(BaseLoss):
+    """
+    Compute the Beta-VAE loss as in [1]
+
+    Parameters
+    ----------
+    beta : float, optional
+        Weight of the kl divergence.
+
+    kwargs:
+        Additional arguments for `BaseLoss`, e.g. `rec_dist`.
+
+    References
+    ----------
+    [1] Higgins, Irina, et al. "beta-vae: Learning basic visual concepts with
+        a constrained variational framework." (2016).
+    """
+
+    def __init__(self, beta=4, **kwargs):
+        super().__init__(**kwargs)
+        self.beta = beta
+
+    def __call__(self, data, recon_data, latent_dist, is_train, storer, **kwargs):
+        storer = self._pre_call(is_train, storer)
+
+        rec_loss = _reconstruction_loss(data, recon_data,
+                                        storer=storer,
+                                        distribution=self.rec_dist)
+        kl_loss = _kl_normal_loss(*latent_dist, storer)
+        anneal_reg = (linear_annealing(0, 1, self.n_train_steps, self.steps_anneal)
+                      if is_train else 1)
+        loss = rec_loss + anneal_reg * (self.beta * kl_loss)
+
+        if storer is not None:
+            storer['loss'].append(loss.item())
+
+        return loss
+
+
+class BetaBLoss(BaseLoss):
+    """
+    Compute the Beta-VAE loss as in [1]
+
+    Parameters
+    ----------
+    C_init : float, optional
+        Starting annealed capacity C.
+
+    C_fin : float, optional
+        Final annealed capacity C.
+
+    gamma : float, optional
+        Weight of the KL divergence term.
+
+    kwargs:
+        Additional arguments for `BaseLoss`, e.g. `rec_dist`.
+
+    References
+    ----------
+    [1] Burgess, Christopher P., et al. "Understanding disentangling in
+        $\beta$-VAE." arXiv preprint arXiv:1804.03599 (2018).
+    """
+
+    def __init__(self, C_init=0., C_fin=20., gamma=100., **kwargs):
+        super().__init__(**kwargs)
+        self.gamma = gamma
+        self.C_init = C_init
+        self.C_fin = C_fin
+
+    def __call__(self, data, recon_data, latent_dist, is_train, storer, **kwargs):
+        storer = self._pre_call(is_train, storer)
+
+        rec_loss = _reconstruction_loss(data, recon_data,
+                                        storer=storer,
+                                        distribution=self.rec_dist)
+        kl_loss = _kl_normal_loss(*latent_dist, storer)
+
+        C = (linear_annealing(self.C_init, self.C_fin, self.n_train_steps, self.steps_anneal)
+             if is_train else self.C_fin)
+
+        loss = rec_loss + self.gamma * (kl_loss - C).abs()
+
+        if storer is not None:
+            storer['loss'].append(loss.item())
+
+        return loss
+
+
+class FactorKLoss(BaseLoss):
+    """
+    Compute the FactorVAE loss as per Algorithm 2 of [1]
+
+    Parameters
+    ----------
+    device : torch.device
+
+    gamma : float, optional
+        Weight of the TC loss term. `gamma` in the paper.
+
+    discriminator : disvae.discriminator.Discriminator
+
+    optimizer_d : torch.optim
+
+    kwargs:
+        Additional arguments for `BaseLoss`, e.g. `rec_dist`.
+
+    References
+    ----------
+    [1] Kim, Hyunjik, and Andriy Mnih. "Disentangling by factorising."
+        arXiv preprint arXiv:1802.05983 (2018).
+    """
+
+    def __init__(self, device,
+                 gamma=10.,
+                 disc_kwargs={},
+                 optim_kwargs=dict(lr=5e-5, betas=(0.5, 0.9)),
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.gamma = gamma
+        self.device = device
+        self.discriminator = Discriminator(**disc_kwargs).to(self.device)
+        self.optimizer_d = optim.Adam(self.discriminator.parameters(), **optim_kwargs)
+
+    def __call__(self, *args, **kwargs):
+        raise ValueError("Use `call_optimize` to also train the discriminator")
+
+    def call_optimize(self, data, model, optimizer, storer):
+        storer = self._pre_call(model.training, storer)
+
+        # FactorVAE splits the data into two batches. In the paper they sample 2 batches.
+        batch_size = data.size(dim=0)
+        half_batch_size = batch_size // 2
+        data = data.split(half_batch_size)
+        data1 = data[0]
+        data2 = data[1]
+
+        # Factor VAE Loss
+        recon_batch, latent_dist, latent_sample1 = model(data1)
+        rec_loss = _reconstruction_loss(data1, recon_batch,
+                                        storer=storer,
+                                        distribution=self.rec_dist)
+
+        kl_loss = _kl_normal_loss(*latent_dist, storer)
+
+        d_z = self.discriminator(latent_sample1)
+        # We want log(p_true/p_false). When using a softmax rather than logistic
+        # regression, p_true = exp(logit_true) / Z and p_false = exp(logit_false) / Z,
+        # so log(p_true/p_false) = logit_true - logit_false
+        tc_loss = (d_z[:, 0] - d_z[:, 1]).mean()
+        # with sigmoid (not good results) it should be `tc_loss = (2 * d_z.flatten()).mean()`
+
+        anneal_reg = (linear_annealing(0, 1, self.n_train_steps, self.steps_anneal)
+                      if model.training else 1)
+        vae_loss = rec_loss + kl_loss + anneal_reg * self.gamma * tc_loss
+
+        if storer is not None:
+            storer['loss'].append(vae_loss.item())
+            storer['tc_loss'].append(tc_loss.item())
+
+        if not model.training:
+            # don't backprop if evaluating
+            return vae_loss
+
+        # Compute VAE gradients
+        optimizer.zero_grad()
+        vae_loss.backward(retain_graph=True)
+
+        # Discriminator Loss
+        # Get second sample of latent distribution
+        latent_sample2 = model.sample_latent(data2)
+        z_perm = _permute_dims(latent_sample2).detach()
+        d_z_perm = self.discriminator(z_perm)
+
+        # Calculate total correlation loss
+        # for cross entropy the target is the index => needs to be long; the first
+        # output is the target for d_z and the second for the permuted sample
+        ones = torch.ones(half_batch_size, dtype=torch.long, device=self.device)
+        zeros = torch.zeros_like(ones)
+        d_tc_loss = 0.5 * (F.cross_entropy(d_z, zeros) + F.cross_entropy(d_z_perm, ones))
+        # with sigmoid it would be:
+        # d_tc_loss = 0.5 * (self.bce(d_z.flatten(), ones) + self.bce(d_z_perm.flatten(), 1 - ones))
+
+        # TO-DO: check whether the discriminator should also be annealed so it does not become too good ???
+        # d_tc_loss = anneal_reg * d_tc_loss
+
+        # Compute discriminator gradients
+        self.optimizer_d.zero_grad()
+        d_tc_loss.backward()
+
+        # Update at the end (since pytorch 1.5 complains if updating before)
+        optimizer.step()
+        self.optimizer_d.step()
+
+        if storer is not None:
+            storer['discrim_loss'].append(d_tc_loss.item())
+
+        return vae_loss
+
+
+class BtcvaeLoss(BaseLoss):
+    """
+    Compute the decomposed KL loss with either minibatch weighted sampling or
+    minibatch stratified sampling according to [1]
+
+    Parameters
+    ----------
+    n_data: int
+        Number of data points in the training set
+
+    alpha : float
+        Weight of the mutual information term.
+
+    beta : float
+        Weight of the total correlation term.
+
+    gamma : float
+        Weight of the dimension-wise KL term.
+
+    is_mss : bool
+        Whether to use minibatch stratified sampling instead of minibatch
+        weighted sampling.
+
+    kwargs:
+        Additional arguments for `BaseLoss`, e.g. `rec_dist`.
+
+    References
+    ----------
+    [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
+        autoencoders." Advances in Neural Information Processing Systems. 2018.
+    """
+
+    def __init__(self, n_data, alpha=1., beta=6., gamma=1., is_mss=True, **kwargs):
+        super().__init__(**kwargs)
+        self.n_data = n_data
+        self.beta = beta
+        self.alpha = alpha
+        self.gamma = gamma
+        self.is_mss = is_mss  # minibatch stratified sampling
+
+    def __call__(self, data, recon_batch, latent_dist, is_train, storer,
+                 latent_sample=None):
+        storer = self._pre_call(is_train, storer)
+        batch_size, latent_dim = latent_sample.shape
+
+        rec_loss = _reconstruction_loss(data, recon_batch,
+                                        storer=storer,
+                                        distribution=self.rec_dist)
+        log_pz, log_qz, log_prod_qzi, log_q_zCx = _get_log_pz_qz_prodzi_qzCx(latent_sample,
+                                                                             latent_dist,
+                                                                             self.n_data,
+                                                                             is_mss=self.is_mss)
+        # I[z;x] = KL[q(z,x)||q(x)q(z)] = E_x[KL[q(z|x)||q(z)]]
+        mi_loss = (log_q_zCx - log_qz).mean()
+        # TC[z] = KL[q(z)||\prod_i z_i]
+        tc_loss = (log_qz - log_prod_qzi).mean()
+        # dw_kl_loss is KL[q(z)||p(z)] instead of the usual KL[q(z|x)||p(z)]
+        dw_kl_loss = (log_prod_qzi - log_pz).mean()
+
+        anneal_reg = (linear_annealing(0, 1, self.n_train_steps, self.steps_anneal)
+                      if is_train else 1)
+
+        # total loss
+        loss = rec_loss + (self.alpha * mi_loss +
+                           self.beta * tc_loss +
+                           anneal_reg * self.gamma * dw_kl_loss)
+
+        if storer is not None:
+            storer['loss'].append(loss.item())
+            storer['mi_loss'].append(mi_loss.item())
+            storer['tc_loss'].append(tc_loss.item())
+            storer['dw_kl_loss'].append(dw_kl_loss.item())
+            # computing this for storing and comparison purposes
+            _ = _kl_normal_loss(*latent_dist, storer)
+
+        return loss
+
+
+def _reconstruction_loss(data, recon_data, distribution="bernoulli", storer=None):
+    """
+    Calculates the per-image reconstruction loss for a batch of data, i.e. the
+    negative log likelihood.
+
+    Parameters
+    ----------
+    data : torch.Tensor
+        Input data (e.g. batch of images). Shape : (batch_size, n_chan,
+        height, width).
+
+    recon_data : torch.Tensor
+        Reconstructed data. Shape : (batch_size, n_chan, height, width).
+
+    distribution : {"bernoulli", "gaussian", "laplace"}
+        Distribution of the likelihood on each pixel. Implicitly defines the
+        loss: Bernoulli corresponds to a binary cross entropy (bce) loss and is
+        the most commonly used. It has the issue that it doesn't penalize
+        (0.1, 0.2) and (0.4, 0.5) the same way, which might not be optimal. The
+        Gaussian distribution corresponds to MSE and is sometimes used, but is
+        hard to train because it ends up focusing on only a few pixels that are
+        very wrong. The Laplace distribution corresponds to L1 and partially
+        solves the issue of MSE.
+
+    storer : dict
+        Dictionary in which to store important variables for visualisation.
+
+    Returns
+    -------
+    loss : torch.Tensor
+        Per-image cross entropy (i.e. normalized per batch but not per pixel
+        and channel)
+    """
+    batch_size, n_chan, height, width = recon_data.size()
+    is_colored = n_chan == 3
+
+    if distribution == "bernoulli":
+        loss = F.binary_cross_entropy(recon_data, data, reduction="sum")
+    elif distribution == "gaussian":
+        # loss in [0,255] space but normalized by 255 to not be too big
+        loss = F.mse_loss(recon_data * 255, data * 255, reduction="sum") / 255
+    elif distribution == "laplace":
+        # loss in [0,255] space but normalized by 255 to not be too big; for L1,
+        # multiplying by 255 and dividing by 255 is the same as doing nothing
+        loss = F.l1_loss(recon_data, data, reduction="sum")
+        loss = loss * 3  # empirical value to give values similar to bernoulli => can use the same hyperparameters
+        loss = loss * (loss != 0)  # masking to avoid nan
+    else:
+        assert distribution not in RECON_DIST
+        raise ValueError("Unknown distribution: {}".format(distribution))
+
+    loss = loss / batch_size
+
+    if storer is not None:
+        storer['recon_loss'].append(loss.item())
+
+    return loss
+
+
+def _kl_normal_loss(mean, logvar, storer=None):
+    """
+    Calculates the KL divergence between a normal distribution
+    with diagonal covariance and a unit normal distribution.
+
+    Parameters
+    ----------
+    mean : torch.Tensor
+        Mean of the normal distribution. Shape (batch_size, latent_dim).
+
+    logvar : torch.Tensor
+        Diagonal log variance of the normal distribution. Shape (batch_size,
+        latent_dim)
+
+    storer : dict
+        Dictionary in which to store important variables for visualisation.
+    """
+    latent_dim = mean.size(1)
+    # batch mean of kl for each latent dimension
+    latent_kl = 0.5 * (-1 - logvar + mean.pow(2) + logvar.exp()).mean(dim=0)
+    total_kl = latent_kl.sum()
+
+    if storer is not None:
+        storer['kl_loss'].append(total_kl.item())
+        for i in range(latent_dim):
+            storer['kl_loss_' + str(i)].append(latent_kl[i].item())
+
+    return total_kl
+
+
+def _permute_dims(latent_sample):
+    """
+    Implementation of Algorithm 1 in ref [1]. Randomly permutes the sample from
+    q(z) (latent_dist) across the batch for each of the latent dimensions (mean
+    and log_var).
+
+    Parameters
+    ----------
+    latent_sample: torch.Tensor
+        Sample from the latent dimension using the reparameterisation trick,
+        shape : (batch_size, latent_dim).
+
+    References
+    ----------
+    [1] Kim, Hyunjik, and Andriy Mnih. "Disentangling by factorising."
+        arXiv preprint arXiv:1802.05983 (2018).
+
+    """
+    perm = torch.zeros_like(latent_sample)
+    batch_size, dim_z = perm.size()
+
+    for z in range(dim_z):
+        pi = torch.randperm(batch_size).to(latent_sample.device)
+        perm[:, z] = latent_sample[pi, z]
+
+    return perm
+
+
+def linear_annealing(init, fin, step, annealing_steps):
+    """Linear annealing of a parameter."""
+    if annealing_steps == 0:
+        return fin
+    assert fin > init
+    delta = fin - init
+    annealed = min(init + delta * step / annealing_steps, fin)
+    return annealed
+
+
+# Batch TC specific
+# TO-DO: test if mss is better!
+def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=True):
+    batch_size, hidden_dim = latent_sample.shape
+
+    # calculate log q(z|x)
+    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)
+
+    # calculate log p(z)
+    # mean and log var is 0
+    zeros = torch.zeros_like(latent_sample)
+    log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)
+
+    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)
+
+    if is_mss:
+        # use stratification
+        log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)
+        mat_log_qz = mat_log_qz + log_iw_mat.view(batch_size, batch_size, 1)
+
+    log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False)
+    log_prod_qzi = torch.logsumexp(mat_log_qz, dim=1, keepdim=False).sum(1)
+
+    return log_pz, log_qz, log_prod_qzi, log_q_zCx
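The FactorVAE total-correlation term above relies on the identity that, with a two-logit softmax discriminator, log(p_true/p_false) equals the logit difference because the partition function cancels. A tiny numeric check:

import torch

logits = torch.tensor([2.0, -1.0])
p = torch.softmax(logits, dim=0)
# log(p0 / p1) == logit0 - logit1, independent of the normalizer Z
assert torch.isclose(torch.log(p[0] / p[1]), logits[0] - logits[1])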
disvae/models/vae.py
ADDED
@@ -0,0 +1,101 @@
"""
Module containing the main VAE class.
"""
import torch
from torch import nn, optim
from torch.nn import functional as F

from disvae.utils.initialization import weights_init
from .encoders import get_encoder
from .decoders import get_decoder

MODELS = ["Burgess"]


def init_specific_model(model_type, img_size, latent_dim):
    """Return an instance of a VAE with encoder and decoder from `model_type`."""
    model_type = model_type.lower().capitalize()
    if model_type not in MODELS:
        err = "Unknown model_type={}. Possible values: {}"
        raise ValueError(err.format(model_type, MODELS))

    encoder = get_encoder(model_type)
    decoder = get_decoder(model_type)
    model = VAE(img_size, encoder, decoder, latent_dim)
    model.model_type = model_type  # store to help reloading
    return model


class VAE(nn.Module):
    def __init__(self, img_size, encoder, decoder, latent_dim):
        """
        Class which defines model and forward pass.

        Parameters
        ----------
        img_size : tuple of ints
            Size of images. E.g. (1, 32, 32) or (3, 64, 64).
        """
        super(VAE, self).__init__()

        if list(img_size[1:]) not in [[32, 32], [64, 64]]:
            raise RuntimeError(
                "{} sized images not supported. Only (None, 32, 32) and "
                "(None, 64, 64) supported. Build your own architecture or "
                "reshape images!".format(img_size)
            )

        self.latent_dim = latent_dim
        self.img_size = img_size
        self.num_pixels = self.img_size[1] * self.img_size[2]
        self.encoder = encoder(img_size, self.latent_dim)
        self.decoder = decoder(img_size, self.latent_dim)

        self.reset_parameters()

    def reparameterize(self, mean, logvar):
        """
        Samples from a normal distribution using the reparameterization trick.

        Parameters
        ----------
        mean : torch.Tensor
            Mean of the normal distribution. Shape (batch_size, latent_dim)

        logvar : torch.Tensor
            Diagonal log variance of the normal distribution. Shape (batch_size,
            latent_dim)
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mean + std * eps
        else:
            # Reconstruction mode
            return mean

    def forward(self, x):
        """
        Forward pass of model.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)
        """
        latent_dist = self.encoder(x)
        latent_sample = self.reparameterize(*latent_dist)
        reconstruct = self.decoder(latent_sample)
        return reconstruct, latent_dist, latent_sample

    def reset_parameters(self):
        self.apply(weights_init)

    def sample_latent(self, x):
        """
        Returns a sample from the latent distribution.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)
        """
        latent_dist = self.encoder(x)
        latent_sample = self.reparameterize(*latent_dist)
        return latent_sample
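Short usage sketch (editorial), matching the (1, 64, 64) image size and latent_dim=10 recorded in model/drilling_ds_btcvae/specs.json; the (mean, logvar) unpacking follows the `reparameterize(*latent_dist)` call above.

import torch
from disvae.models.vae import init_specific_model

model = init_specific_model("Burgess", img_size=(1, 64, 64), latent_dim=10)
model.eval()  # in eval mode reparameterize() returns the mean

x = torch.rand(4, 1, 64, 64)             # dummy batch
recon, (mean, logvar), z = model(x)
print(recon.shape, mean.shape, z.shape)  # expected: (4,1,64,64), (4,10), (4,10)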
disvae/training.py
ADDED
@@ -0,0 +1,212 @@
# import imageio
import logging
import os
from collections import defaultdict
from timeit import default_timer

import torch
from torch.nn import functional as F
from tqdm import trange

from disvae.utils.modelIO import save_model

TRAIN_LOSSES_LOGFILE = "train_losses.log"


class Trainer:
    """
    Class to handle training of model.

    Parameters
    ----------
    model: disvae.vae.VAE

    optimizer: torch.optim.Optimizer

    loss_f: disvae.models.BaseLoss
        Loss function.

    device: torch.device, optional
        Device on which to run the code.

    logger: logging.Logger, optional
        Logger.

    save_dir : str, optional
        Directory for saving logs.

    gif_visualizer : viz.Visualizer, optional
        Gif visualizer that should return samples at every epoch.

    is_progress_bar: bool, optional
        Whether to use a progress bar for training.
    """

    def __init__(
        self,
        model,
        optimizer,
        loss_f,
        device=torch.device("cpu"),
        logger=logging.getLogger(__name__),
        save_dir="results",
        gif_visualizer=None,
        is_progress_bar=True,
    ):
        self.device = device
        self.model = model.to(self.device)
        self.loss_f = loss_f
        self.optimizer = optimizer
        self.save_dir = save_dir
        self.is_progress_bar = is_progress_bar
        self.logger = logger
        self.losses_logger = LossesLogger(
            os.path.join(self.save_dir, TRAIN_LOSSES_LOGFILE)
        )
        self.gif_visualizer = gif_visualizer
        self.logger.info("Training Device: {}".format(self.device))

    def __call__(self, data_loader, epochs=10, checkpoint_every=10):
        """
        Trains the model.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader

        epochs: int, optional
            Number of epochs to train the model for.

        checkpoint_every: int, optional
            Save a checkpoint of the trained model every n epochs.
        """
        start = default_timer()
        self.model.train()
        for epoch in range(epochs):
            storer = defaultdict(list)
            mean_epoch_loss = self._train_epoch(data_loader, storer, epoch)
            self.logger.info(
                "Epoch: {} Average loss per image: {:.2f}".format(
                    epoch + 1, mean_epoch_loss
                )
            )
            self.losses_logger.log(epoch, storer)

            if self.gif_visualizer is not None:
                self.gif_visualizer()

            if epoch % checkpoint_every == 0:
                save_model(
                    self.model, self.save_dir, filename="model-{}.pt".format(epoch)
                )

        if self.gif_visualizer is not None:
            self.gif_visualizer.save_reset()

        self.model.eval()

        delta_time = (default_timer() - start) / 60
        self.logger.info("Finished training after {:.1f} min.".format(delta_time))

    def _train_epoch(self, data_loader, storer, epoch):
        """
        Trains the model for one epoch.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader

        storer: dict
            Dictionary in which to store important variables for visualisation.

        epoch: int
            Epoch number

        Return
        ------
        mean_epoch_loss: float
            Mean loss per image
        """
        epoch_loss = 0.0
        kwargs = dict(
            desc="Epoch {}".format(epoch + 1),
            leave=False,
            disable=not self.is_progress_bar,
        )
        with trange(len(data_loader), **kwargs) as t:
            for _, (data, _) in enumerate(data_loader):
                iter_loss = self._train_iteration(data, storer)
                epoch_loss += iter_loss

                t.set_postfix(loss=iter_loss)
                t.update()

        mean_epoch_loss = epoch_loss / len(data_loader)
        return mean_epoch_loss

    def _train_iteration(self, data, storer):
        """
        Trains the model for one iteration on a batch of data.

        Parameters
        ----------
        data: torch.Tensor
            A batch of data. Shape : (batch_size, channel, height, width).

        storer: dict
            Dictionary in which to store important variables for visualisation.
        """
        batch_size, channel, height, width = data.size()
        data = data.to(self.device)

        try:
            recon_batch, latent_dist, latent_sample = self.model(data)
            loss = self.loss_f(
                data,
                recon_batch,
                latent_dist,
                self.model.training,
                storer,
                latent_sample=latent_sample,
            )
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        except ValueError:
            # for losses that use multiple optimizers (e.g. Factor)
            loss = self.loss_f.call_optimize(data, self.model, self.optimizer, storer)

        return loss.item()


class LossesLogger(object):
    """Class definition for objects to write data to log files in a
    form that is easy to plot.
    """

    def __init__(self, file_path_name):
        """Create a logger to store information for plotting."""
        if os.path.isfile(file_path_name):
            os.remove(file_path_name)

        self.logger = logging.getLogger("losses_logger")
        self.logger.setLevel(1)  # always store
        file_handler = logging.FileHandler(file_path_name)
        file_handler.setLevel(1)
        self.logger.addHandler(file_handler)

        header = ",".join(["Epoch", "Loss", "Value"])
        self.logger.debug(header)

    def log(self, epoch, losses_storer):
        """Write to the log file"""
        for k, v in losses_storer.items():
            log_string = ",".join(str(item) for item in [epoch, k, mean(v)])
            self.logger.debug(log_string)


# HELPERS
def mean(l):
    """Compute the mean of a list"""
    return sum(l) / len(l)
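Editorial sketch of how the Trainer would be wired up for the checked-in model. `BtcvaeLoss` and `train_loader` are assumptions (any callable with the signature used in `_train_iteration` works as `loss_f`); the checkpoints model-0.pt, model-10.pt and model-20.pt in this commit are consistent with epochs=30, checkpoint_every=10.

import torch
from disvae.models.vae import init_specific_model
from disvae.training import Trainer
from disvae.models.losses import BtcvaeLoss  # assumed to exist in losses.py

model = init_specific_model("Burgess", (1, 64, 64), latent_dim=10)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
loss_f = BtcvaeLoss(n_data=len(train_loader.dataset))  # hypothetical signature

trainer = Trainer(model, optimizer, loss_f, save_dir="model/drilling_ds_btcvae")
trainer(train_loader, epochs=30, checkpoint_every=10)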
disvae/utils/__init__.py
ADDED
File without changes
disvae/utils/initialization.py
ADDED
@@ -0,0 +1,61 @@
import torch
from torch import nn


def get_activation_name(activation):
    """Given a string or a `torch.nn.modules.activation` return the name of the activation."""
    if isinstance(activation, str):
        return activation

    mapper = {nn.LeakyReLU: "leaky_relu", nn.ReLU: "relu", nn.Tanh: "tanh",
              nn.Sigmoid: "sigmoid", nn.Softmax: "sigmoid"}
    for k, v in mapper.items():
        if isinstance(activation, k):
            return v  # return the name string, not the module class

    raise ValueError("Unknown given activation type : {}".format(activation))


def get_gain(activation):
    """Given an object of `torch.nn.modules.activation` or an activation name
    return the correct gain."""
    if activation is None:
        return 1

    activation_name = get_activation_name(activation)

    param = None if activation_name != "leaky_relu" else activation.negative_slope
    gain = nn.init.calculate_gain(activation_name, param)

    return gain


def linear_init(layer, activation="relu"):
    """Initialize a linear layer.

    Args:
        layer (nn.Linear): parameters to initialize.
        activation (`torch.nn.modules.activation` or str, optional): activation that
            will be used on the `layer`.
    """
    x = layer.weight

    if activation is None:
        return nn.init.xavier_uniform_(x)

    activation_name = get_activation_name(activation)

    if activation_name == "leaky_relu":
        a = 0 if isinstance(activation, str) else activation.negative_slope
        return nn.init.kaiming_uniform_(x, a=a, nonlinearity='leaky_relu')
    elif activation_name == "relu":
        return nn.init.kaiming_uniform_(x, nonlinearity='relu')
    elif activation_name in ["sigmoid", "tanh"]:
        return nn.init.xavier_uniform_(x, gain=get_gain(activation))


def weights_init(module):
    if isinstance(module, torch.nn.modules.conv._ConvNd):
        # TO-DO: check literature
        linear_init(module)
    elif isinstance(module, nn.Linear):
        linear_init(module)
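`weights_init` is meant to be passed to `nn.Module.apply`, which is exactly what `VAE.reset_parameters()` does. A minimal editorial sketch (the layer shapes are arbitrary; no forward pass is run, only the initialization is exercised):

import torch
from torch import nn
from disvae.utils.initialization import weights_init

net = nn.Sequential(nn.Conv2d(1, 8, kernel_size=3), nn.ReLU(), nn.Linear(16, 2))
net.apply(weights_init)  # kaiming-uniform init for the conv and linear weights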
disvae/utils/math.py
ADDED
@@ -0,0 +1,73 @@
import math

from tqdm import trange, tqdm
import torch


def matrix_log_density_gaussian(x, mu, logvar):
    """Calculates the log density of a Gaussian for all combinations of batch
    pairs of `x` and `mu`, i.e. returns a tensor of shape
    `(batch_size, batch_size, dim)` instead of the usual `(batch_size, dim)`.

    Parameters
    ----------
    x: torch.Tensor
        Value at which to compute the density. Shape: (batch_size, dim).

    mu: torch.Tensor
        Mean. Shape: (batch_size, dim).

    logvar: torch.Tensor
        Log variance. Shape: (batch_size, dim).
    """
    batch_size, dim = x.shape
    x = x.view(batch_size, 1, dim)
    mu = mu.view(1, batch_size, dim)
    logvar = logvar.view(1, batch_size, dim)
    return log_density_gaussian(x, mu, logvar)


def log_density_gaussian(x, mu, logvar):
    """Calculates the log density of a Gaussian.

    Parameters
    ----------
    x: torch.Tensor or np.ndarray or float
        Value at which to compute the density.

    mu: torch.Tensor or np.ndarray or float
        Mean.

    logvar: torch.Tensor or np.ndarray or float
        Log variance.
    """
    normalization = -0.5 * (math.log(2 * math.pi) + logvar)
    inv_var = torch.exp(-logvar)
    log_density = normalization - 0.5 * ((x - mu) ** 2 * inv_var)
    return log_density


def log_importance_weight_matrix(batch_size, dataset_size):
    """
    Calculates a log importance weight matrix.

    Parameters
    ----------
    batch_size: int
        number of training images in the batch

    dataset_size: int
        number of training images in the dataset
    """
    N = dataset_size
    M = batch_size - 1
    strat_weight = (N - M) / (N * M)
    W = torch.Tensor(batch_size, batch_size).fill_(1 / M)
    W.view(-1)[::M + 1] = 1 / N
    W.view(-1)[1::M + 1] = strat_weight
    W[M - 1, 0] = strat_weight
    return W.log()
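Editorial sanity check (not part of the commit): `log_density_gaussian` evaluates log N(x; mu, exp(logvar)) element-wise, so it should agree with `torch.distributions.Normal` given scale exp(0.5 * logvar).

import torch
from disvae.utils.math import log_density_gaussian

x, mu, logvar = torch.randn(4, 3), torch.randn(4, 3), torch.randn(4, 3)
ref = torch.distributions.Normal(mu, torch.exp(0.5 * logvar)).log_prob(x)
assert torch.allclose(log_density_gaussian(x, mu, logvar), ref, atol=1e-6)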
disvae/utils/modelIO.py
ADDED
@@ -0,0 +1,200 @@
import json
import os
import re
from pathlib import Path

import numpy as np
import torch

from disvae.models.vae import init_specific_model

MODEL_FILENAME = "model.pt"
META_FILENAME = "specs.json"


def vae2onnx(vae, p_out: str) -> None:
    if isinstance(p_out, str):  # accept both str and Path for p_out
        p_out = Path(p_out)
    if not p_out.exists():
        p_out.mkdir()

    device = next(vae.parameters()).device
    vae.cpu()

    # Encoder
    vae.encoder.eval()
    dummy_input_im = torch.zeros(tuple(np.concatenate([[1], vae.img_size])))
    torch.onnx.export(vae.encoder, dummy_input_im, p_out / "encoder.onnx", verbose=True)

    # Decoder
    vae.decoder.eval()
    dummy_input_latent = torch.zeros((1, vae.latent_dim))
    torch.onnx.export(
        vae.decoder, dummy_input_latent, p_out / "decoder.onnx", verbose=True
    )

    vae.to(device)  # restore device


def save_model(model, directory, metadata=None, filename=MODEL_FILENAME):
    """
    Save a model and corresponding metadata.

    Parameters
    ----------
    model : nn.Module
        Model.

    directory : str
        Path to the directory where to save the data.

    metadata : dict
        Metadata to save.
    """
    device = next(model.parameters()).device
    model.cpu()

    if metadata is None:
        # save the minimum required for loading
        metadata = dict(
            img_size=model.img_size,
            latent_dim=model.latent_dim,
            model_type=model.model_type,
        )

    save_metadata(metadata, directory)

    path_to_model = os.path.join(directory, filename)
    torch.save(model.state_dict(), path_to_model)

    model.to(device)  # restore device


def load_metadata(directory, filename=META_FILENAME):
    """Load the metadata of a training directory.

    Parameters
    ----------
    directory : string
        Path to folder where model is saved. For example './experiments/mnist'.
    """
    path_to_metadata = os.path.join(directory, filename)

    with open(path_to_metadata) as metadata_file:
        metadata = json.load(metadata_file)

    return metadata


def save_metadata(metadata, directory, filename=META_FILENAME, **kwargs):
    """Save the metadata of a training directory.

    Parameters
    ----------
    metadata:
        Object to save

    directory: string
        Path to folder where to save the model. For example './experiments/mnist'.

    kwargs:
        Additional arguments to `json.dump`
    """
    path_to_metadata = os.path.join(directory, filename)

    with open(path_to_metadata, "w") as f:
        json.dump(metadata, f, indent=4, sort_keys=True, **kwargs)


def load_model(directory, is_gpu=True, filename=MODEL_FILENAME):
    """Load a trained model.

    Parameters
    ----------
    directory : string
        Path to folder where model is saved. For example './experiments/mnist'.

    is_gpu : bool
        Whether to load on GPU if available.
    """
    device = torch.device("cuda" if torch.cuda.is_available() and is_gpu else "cpu")

    metadata = load_metadata(directory)
    img_size = metadata["img_size"]
    latent_dim = metadata["latent_dim"]
    model_type = metadata["model_type"]

    path_to_model = os.path.join(directory, filename)
    model = _get_model(model_type, img_size, latent_dim, device, path_to_model)
    return model


def load_checkpoints(directory, is_gpu=True):
    """Load all checkpointed models.

    Parameters
    ----------
    directory : string
        Path to folder where model is saved. For example './experiments/mnist'.

    is_gpu : bool
        Whether to load on GPU.
    """
    checkpoints = []
    for root, _, filenames in os.walk(directory):
        for filename in filenames:
            results = re.search(r".*?-([0-9].*?).pt", filename)
            if results is not None:
                epoch_idx = int(results.group(1))
                model = load_model(root, is_gpu=is_gpu, filename=filename)
                checkpoints.append((epoch_idx, model))

    return checkpoints


def _get_model(model_type, img_size, latent_dim, device, path_to_model):
    """Load a single model.

    Parameters
    ----------
    model_type : str
        The name of the model to load. For example Burgess.
    img_size : tuple
        Tuple of the number of pixels in the image width and height.
        For example (32, 32) or (64, 64).
    latent_dim : int
        The number of latent dimensions in the bottleneck.

    device : str
        Either 'cuda' or 'cpu'
    path_to_model : str
        Full path to the saved model on the device.
    """
    model = init_specific_model(model_type, img_size, latent_dim).to(device)
    # works with state_dict to make it independent of the file structure
    model.load_state_dict(torch.load(path_to_model), strict=False)
    model.eval()

    return model


def numpy_serialize(obj):
    if type(obj).__module__ == np.__name__:
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return obj.item()
    raise TypeError("Unknown type:", type(obj))


def save_np_arrays(arrays, directory, filename):
    """Save a dictionary of arrays in a json file."""
    save_metadata(arrays, directory, filename=filename, default=numpy_serialize)


def load_np_arrays(directory, filename):
    """Load a dictionary of arrays from a json file."""
    arrays = load_metadata(directory, filename=filename)
    return {k: np.array(v) for k, v in arrays.items()}
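Editorial sketch: round-tripping a model through `save_model`/`load_model` and exporting it to ONNX produces exactly the kinds of files checked in under model/drilling_ds_btcvae/ in this commit.

from disvae.models.vae import init_specific_model
from disvae.utils.modelIO import save_model, load_model, vae2onnx

model = init_specific_model("Burgess", (1, 64, 64), latent_dim=10)
save_model(model, "model/drilling_ds_btcvae")        # writes model.pt + specs.json
restored = load_model("model/drilling_ds_btcvae", is_gpu=False)
vae2onnx(restored, "model/drilling_ds_btcvae/onnx")  # writes encoder/decoder.onnx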
model/drilling_ds_btcvae/model-0.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6bd1cfaf8d1124dc405fc753e77da693d74220b72e95971660d5a21ef2986081
size 2016145
model/drilling_ds_btcvae/model-10.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:745b798b9252ebde56455f80caf95f948d1515cdd36fb28ee5d3738e619c1f89
size 2016687
model/drilling_ds_btcvae/model-20.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0a5d11e24ceea210dc9792497940cc84996f785f8f821fc4dd9b51bad6607f8f
size 2016687
model/drilling_ds_btcvae/model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a8288bc5d3e7d160054afe19afc6cfa4ea884106fe3377f6e00a6e4c384f4d5a
size 2014933
model/drilling_ds_btcvae/onnx/decoder.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7212ac0302bbf4072e3d083b393c24355d1d88c671b9cb7f037e91ea5312d745
size 1046902
model/drilling_ds_btcvae/onnx/encoder.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a3299e3b39f5181946675ba3da2f17be960ccfb9aa3aaed56af173724720e657
size 1074967
model/drilling_ds_btcvae/specs.json
ADDED
@@ -0,0 +1,9 @@
{
    "img_size": [
        1,
        64,
        64
    ],
    "latent_dim": 10,
    "model_type": "Burgess"
}
model/drilling_ds_btcvae/train_losses.log
ADDED
@@ -0,0 +1,481 @@
Epoch,Loss,Value
0,recon_loss,2249.039325420673
0,loss,2050.8551119290864
0,mi_loss,39.089650080754204
0,tc_loss,-37.15008985079252
0,dw_kl_loss,14.773665877488943
0,kl_loss,16.61438555900867
0,kl_loss_0,1.3352613651122038
0,kl_loss_1,2.678740862183846
0,kl_loss_2,1.5543635350007277
0,kl_loss_3,2.3668181953521876
0,kl_loss_4,1.1609044877382426
0,kl_loss_5,1.2264407391731555
0,kl_loss_6,1.9435587427937067
0,kl_loss_7,1.346766499372629
0,kl_loss_8,1.6203752496781259
0,kl_loss_9,1.3811558622580309
1,recon_loss,2217.2862955729165
1,loss,2020.4783630371094
1,mi_loss,39.90034135182699
1,tc_loss,-37.161014238993324
1,dw_kl_loss,12.359588861465454
1,kl_loss,15.169561068216959
1,kl_loss_0,1.0467030107975006
1,kl_loss_1,2.6103671193122864
1,kl_loss_2,1.0980331550041835
1,kl_loss_3,2.823746303717295
1,kl_loss_4,0.8646406829357147
1,kl_loss_5,0.937546119093895
1,kl_loss_6,1.9802785913149517
1,kl_loss_7,1.117189645767212
1,kl_loss_8,1.3413666437069576
1,kl_loss_9,1.3496898810068767
2,recon_loss,2214.081355168269
2,loss,2017.7317645733174
2,mi_loss,39.97995347243089
2,tc_loss,-37.11946810208834
2,dw_kl_loss,8.114381826840914
2,kl_loss,10.967073367192196
2,kl_loss_0,0.6580769190421472
2,kl_loss_1,2.3679186747624326
2,kl_loss_2,0.6067795891028184
2,kl_loss_3,2.3443003251002383
2,kl_loss_4,0.5523027869371268
2,kl_loss_5,0.630363792181015
2,kl_loss_6,1.140236680324261
2,kl_loss_7,0.6584483155837426
2,kl_loss_8,0.7375089365702409
2,kl_loss_9,1.271137265058664
3,recon_loss,2212.1290079752603
3,loss,2015.6785278320312
3,mi_loss,40.029717445373535
3,tc_loss,-37.09218883514404
3,dw_kl_loss,4.2465185324351
3,kl_loss,7.1519254843393965
3,kl_loss_0,0.3417855277657509
3,kl_loss_1,2.362027664979299
3,kl_loss_2,0.29086008047064144
3,kl_loss_3,1.2119526664415996
3,kl_loss_4,0.275665458291769
3,kl_loss_5,0.34522855281829834
3,kl_loss_6,0.5345357780655225
3,kl_loss_7,0.33813271423180896
3,kl_loss_8,0.3666527792811394
3,kl_loss_9,1.0850842048724492
4,recon_loss,2200.101787860577
4,loss,2002.8201528695913
4,mi_loss,39.938018798828125
4,tc_loss,-37.15127005943885
4,dw_kl_loss,1.9809786998308623
4,kl_loss,4.73990275309636
4,kl_loss_0,0.1199311871941273
4,kl_loss_1,2.463518949655386
4,kl_loss_2,0.10289078320448215
4,kl_loss_3,0.4072182568219992
4,kl_loss_4,0.10970824899581763
4,kl_loss_5,0.17061764001846313
4,kl_loss_6,0.20478195181259742
4,kl_loss_7,0.12073192258293812
4,kl_loss_8,0.1500035455593696
4,kl_loss_9,0.8905002612334031
5,recon_loss,2204.4312947591147
5,loss,2007.2495625813801
5,mi_loss,40.11853218078613
5,tc_loss,-37.12883758544922
5,dw_kl_loss,0.9564686765273412
5,kl_loss,3.8624242544174194
5,kl_loss_0,0.04401067225262523
5,kl_loss_1,2.557211220264435
5,kl_loss_2,0.04017933504655957
5,kl_loss_3,0.1053344514220953
5,kl_loss_4,0.04450849971423546
5,kl_loss_5,0.06218112260103226
5,kl_loss_6,0.0690454790989558
5,kl_loss_7,0.04647901297236482
5,kl_loss_8,0.05923667146513859
5,kl_loss_9,0.8342377990484238
6,recon_loss,2205.8294771634614
6,loss,2008.6192157451924
6,mi_loss,40.10805218036358
6,tc_loss,-37.120982536902794
6,dw_kl_loss,0.6391923519281241
6,kl_loss,3.6411730509537916
6,kl_loss_0,0.019102195349450294
6,kl_loss_1,2.5998667020064135
6,kl_loss_2,0.019974250418062393
6,kl_loss_3,0.03957459602791529
6,kl_loss_4,0.02344304831841817
6,kl_loss_5,0.028779690225537006
6,kl_loss_6,0.02773726975115446
6,kl_loss_7,0.0207805114869888
6,kl_loss_8,0.027905162853690293
6,kl_loss_9,0.8340096381994394
7,recon_loss,2207.6748657226562
7,loss,2010.4278869628906
7,mi_loss,40.222665786743164
7,tc_loss,-37.143791834513344
7,dw_kl_loss,0.5344960068662962
7,kl_loss,3.6037667989730835
7,kl_loss_0,0.012062065963012477
7,kl_loss_1,2.7152191599210105
7,kl_loss_2,0.011340752360410988
7,kl_loss_3,0.020317877798030775
7,kl_loss_4,0.010189585232486328
7,kl_loss_5,0.01701245631556958
7,kl_loss_6,0.013870434602722526
7,kl_loss_7,0.012214566503340999
7,kl_loss_8,0.014409278597061833
7,kl_loss_9,0.7771305988232294
8,recon_loss,2204.4822904146636
8,loss,2007.0154465895432
8,mi_loss,40.23030794583834
8,tc_loss,-37.17468848595252
8,dw_kl_loss,0.4160432998950665
8,kl_loss,3.491810395167424
8,kl_loss_0,0.007323983161208721
8,kl_loss_1,2.6830281294309177
8,kl_loss_2,0.006681273798816479
8,kl_loss_3,0.010057092895014929
8,kl_loss_4,0.007608905219687865
8,kl_loss_5,0.008022861555218697
8,kl_loss_6,0.0056697913588812715
8,kl_loss_7,0.008376527505998429
8,kl_loss_8,0.00837366422638297
8,kl_loss_9,0.7466680911871103
9,recon_loss,2206.011433919271
9,loss,2008.6884663899739
9,mi_loss,40.23820718129476
9,tc_loss,-37.16220887502035
9,dw_kl_loss,0.46936574081579846
9,kl_loss,3.4854011138280234
9,kl_loss_0,0.004957272865188618
9,kl_loss_1,2.687060753504435
9,kl_loss_2,0.004417084312687318
9,kl_loss_3,0.007213231680604319
9,kl_loss_4,0.006509214756079018
9,kl_loss_5,0.00647372849440823
9,kl_loss_6,0.005934650117220978
9,kl_loss_7,0.006229850230738521
9,kl_loss_8,0.006342786713503301
9,kl_loss_9,0.7502625584602356
10,recon_loss,2196.9954740084136
10,loss,1999.4449462890625
10,mi_loss,40.27904305091271
10,tc_loss,-37.19911399254432
10,dw_kl_loss,0.37395678689846623
10,kl_loss,3.4625553901378927
10,kl_loss_0,0.004101881040976598
10,kl_loss_1,2.713090548148522
10,kl_loss_2,0.0036793027359705707
10,kl_loss_3,0.004992271200395548
10,kl_loss_4,0.004465263444357193
10,kl_loss_5,0.004884598884158409
10,kl_loss_6,0.0037293383636726784
10,kl_loss_7,0.004201724802931914
10,kl_loss_8,0.00413606484205677
10,kl_loss_9,0.7152743293688848
11,recon_loss,2205.8016967773438
11,loss,2008.5977172851562
11,mi_loss,40.24677817026774
11,tc_loss,-37.13152503967285
11,dw_kl_loss,0.2668171264231205
11,kl_loss,3.3814604083697
11,kl_loss_0,0.002733308278645078
11,kl_loss_1,2.5890026092529297
11,kl_loss_2,0.003619925118982792
11,kl_loss_3,0.0036296656665702662
11,kl_loss_4,0.0030502101484065256
11,kl_loss_5,0.003691234936316808
11,kl_loss_6,0.0037010414137815437
11,kl_loss_7,0.0035850899294018745
11,kl_loss_8,0.0033201781722406545
11,kl_loss_9,0.7651271522045135
12,recon_loss,2204.1350911458335
12,loss,2006.7399800618489
12,mi_loss,40.24555047353109
12,tc_loss,-37.16728210449219
12,dw_kl_loss,0.29467178384462994
12,kl_loss,3.4273709058761597
12,kl_loss_0,0.00253635470289737
12,kl_loss_1,2.64704296986262
12,kl_loss_2,0.0027701943569506207
12,kl_loss_3,0.0031717634992673993
12,kl_loss_4,0.002853672835044563
12,kl_loss_5,0.0033753060658151903
12,kl_loss_6,0.002695254709882041
12,kl_loss_7,0.0025047556264325976
12,kl_loss_8,0.0035389424689734974
12,kl_loss_9,0.7568817337354025
13,recon_loss,2205.216759314904
13,loss,2007.8895357572114
13,mi_loss,40.38474479088416
13,tc_loss,-37.186089735764725
13,dw_kl_loss,0.33352065086364746
13,kl_loss,3.4887316043560324
13,kl_loss_0,0.0025462846343333903
13,kl_loss_1,2.7032171212709866
13,kl_loss_2,0.0023540596549327555
13,kl_loss_3,0.0035233735106885433
13,kl_loss_4,0.0030828057430111445
13,kl_loss_5,0.002101871149184612
13,kl_loss_6,0.0022684050580629935
13,kl_loss_7,0.0018005223514942022
13,kl_loss_8,0.002609655655060823
13,kl_loss_9,0.7652274553592389
14,recon_loss,2203.2179158528647
14,loss,2006.1795145670574
14,mi_loss,40.190184911092125
14,tc_loss,-37.12630271911621
14,dw_kl_loss,0.42094290008147556
14,kl_loss,3.521983802318573
14,kl_loss_0,0.0017787318599099915
14,kl_loss_1,2.733785887559255
14,kl_loss_2,0.0020149560490002236
14,kl_loss_3,0.0026113978043819466
14,kl_loss_4,0.0019087268350025017
14,kl_loss_5,0.0026779077791919312
14,kl_loss_6,0.002250582134972016
14,kl_loss_7,0.0023648399704446397
14,kl_loss_8,0.002240404215020438
14,kl_loss_9,0.770350361863772
15,recon_loss,2211.151329627404
15,loss,2014.0557016225962
15,mi_loss,40.30680495042067
15,tc_loss,-37.14148154625526
15,dw_kl_loss,0.3123072420175259
15,kl_loss,3.4762388192690334
15,kl_loss_0,0.002717837368926177
15,kl_loss_1,2.6843277307657094
15,kl_loss_2,0.002006038987579254
15,kl_loss_3,0.0024746029207912777
15,kl_loss_4,0.0028183436594330347
15,kl_loss_5,0.0020857790413384256
15,kl_loss_6,0.0016617546431147135
15,kl_loss_7,0.0018722561474602956
15,kl_loss_8,0.0026360321789979935
15,kl_loss_9,0.7736383722378657
16,recon_loss,2205.65625
16,loss,2008.7182006835938
16,mi_loss,40.34294160207113
16,tc_loss,-37.122435569763184
16,dw_kl_loss,0.30258239308993023
16,kl_loss,3.5443766514460244
16,kl_loss_0,0.0019371571640173595
16,kl_loss_1,2.7254956364631653
16,kl_loss_2,0.0018169079363966982
16,kl_loss_3,0.0018087425269186497
16,kl_loss_4,0.0017845627153292298
16,kl_loss_5,0.002286287684304019
16,kl_loss_6,0.002490955676573018
16,kl_loss_7,0.0018719588794435065
16,kl_loss_8,0.0028695606160908937
16,kl_loss_9,0.802014946937561
17,recon_loss,2198.288348858173
17,loss,2001.154766376202
17,mi_loss,40.29867348304162
17,tc_loss,-37.145969977745644
17,dw_kl_loss,0.3019469002118477
17,kl_loss,3.4419017755068264
17,kl_loss_0,0.0018457891777730905
17,kl_loss_1,2.6779668514545145
17,kl_loss_2,0.0019515727718289082
17,kl_loss_3,0.002290374062095697
17,kl_loss_4,0.0018018389550539164
17,kl_loss_5,0.0016674650474809683
17,kl_loss_6,0.001166574119661863
17,kl_loss_7,0.0020146659002281153
17,kl_loss_8,0.001355440105096652
17,kl_loss_9,0.7498412315662091
18,recon_loss,2198.6248575846353
18,loss,2001.4028625488281
18,mi_loss,40.322779973347984
18,tc_loss,-37.165176709493004
18,dw_kl_loss,0.31236464778582257
18,kl_loss,3.5249797304471335
18,kl_loss_0,0.0014744151073197524
18,kl_loss_1,2.756214439868927
18,kl_loss_2,0.0018141739613686998
18,kl_loss_3,0.0015735261452694733
18,kl_loss_4,0.0010972448314229648
18,kl_loss_5,0.0018187712412327528
18,kl_loss_6,0.0031091769536336264
18,kl_loss_7,0.0022688430811588964
18,kl_loss_8,0.0020122514882435403
18,kl_loss_9,0.7535968770583471
19,recon_loss,2212.218937800481
19,loss,2015.4007662259614
19,mi_loss,40.31675866933969
19,tc_loss,-37.111326951246994
19,dw_kl_loss,0.37757790088653564
19,kl_loss,3.59769160930927
19,kl_loss_0,0.0016564357882508864
19,kl_loss_1,2.8344204976008487
19,kl_loss_2,0.003070324229506346
19,kl_loss_3,0.0013594922179786058
19,kl_loss_4,0.001366000407590316
19,kl_loss_5,0.002101327758282423
19,kl_loss_6,0.0017111848036830241
19,kl_loss_7,0.0018114753497334628
19,kl_loss_8,0.002678456107297769
19,kl_loss_9,0.7475164211713351
20,recon_loss,2202.2599487304688
20,loss,2004.8538614908855
20,mi_loss,40.45914777119955
20,tc_loss,-37.21652317047119
20,dw_kl_loss,0.3205146963397662
20,kl_loss,3.6001903414726257
20,kl_loss_0,0.001861263260555764
20,kl_loss_1,2.8560718297958374
20,kl_loss_2,0.0016067677255099018
20,kl_loss_3,0.0021211198375870786
20,kl_loss_4,0.0015188050456345081
20,kl_loss_5,0.0015412049445634086
20,kl_loss_6,0.0016792030461753409
20,kl_loss_7,0.0019624157963941493
20,kl_loss_8,0.0014063233975321054
20,kl_loss_9,0.7304214636484782
21,recon_loss,2207.904334435096
21,loss,2010.744882436899
21,mi_loss,40.43276625413161
21,tc_loss,-37.171888204721306
21,dw_kl_loss,0.3078639725079903
21,kl_loss,3.561537137398353
21,kl_loss_0,0.0012945552858022542
21,kl_loss_1,2.7875716319450965
21,kl_loss_2,0.001239637235322824
21,kl_loss_3,0.0011019167275382923
21,kl_loss_4,0.0019378469755443244
21,kl_loss_5,0.0017049070447683334
21,kl_loss_6,0.0016337902141878237
21,kl_loss_7,0.0016330647855423964
21,kl_loss_8,0.001772170993857659
21,kl_loss_9,0.7616475820541382
22,recon_loss,2210.4578247070312
22,loss,2013.5369771321614
22,mi_loss,40.30475012461344
22,tc_loss,-37.11816851298014
22,dw_kl_loss,0.3306787001589934
22,kl_loss,3.547394037246704
22,kl_loss_0,0.0017359372383604448
22,kl_loss_1,2.824240207672119
22,kl_loss_2,0.0011897988927861054
22,kl_loss_3,0.0013395212202643354
22,kl_loss_4,0.001481865649111569
22,kl_loss_5,0.001497445589241882
22,kl_loss_6,0.001183633382121722
22,kl_loss_7,0.0013544160562256973
22,kl_loss_8,0.001382134932403763
22,kl_loss_9,0.7119891196489334
23,recon_loss,2210.9183443509614
23,loss,2013.5601994441106
23,mi_loss,40.38095914400541
23,tc_loss,-37.20592850905199
23,dw_kl_loss,0.3788297451459445
23,kl_loss,3.5697782589839053
23,kl_loss_0,0.001240100496663497
23,kl_loss_1,2.8392400558178243
23,kl_loss_2,0.0012557490442234736
23,kl_loss_3,0.0016619229259399267
23,kl_loss_4,0.0014921167435554357
23,kl_loss_5,0.0014155316524780714
23,kl_loss_6,0.0015437594041801416
23,kl_loss_7,0.0013436333706172614
23,kl_loss_8,0.0015257975946252162
23,kl_loss_9,0.7190595681850727
24,recon_loss,2205.2496948242188
24,loss,2008.2829081217449
24,mi_loss,40.33272743225098
24,tc_loss,-37.13726011912028
24,dw_kl_loss,0.3789539597928524
24,kl_loss,3.5514416495958963
24,kl_loss_0,0.0012495704383278887
24,kl_loss_1,2.788417716821035
24,kl_loss_2,0.0018070755759254098
24,kl_loss_3,0.0012991854067270954
24,kl_loss_4,0.0010839081757391493
24,kl_loss_5,0.0019774677930399776
24,kl_loss_6,0.0014922072102005284
24,kl_loss_7,0.0012786747732510169
24,kl_loss_8,0.001147995547701915
24,kl_loss_9,0.7516879439353943
25,recon_loss,2208.270467122396
25,loss,2010.8847351074219
25,mi_loss,40.34546057383219
25,tc_loss,-37.20619010925293
25,dw_kl_loss,0.38843652978539467
25,kl_loss,3.6115296880404153
25,kl_loss_0,0.0015405131659160058
25,kl_loss_1,2.9013171394666037
25,kl_loss_2,0.0011097188883771498
25,kl_loss_3,0.0018861019440616171
25,kl_loss_4,0.0015740771389876802
25,kl_loss_5,0.0015496667086457212
25,kl_loss_6,0.001531413911531369
25,kl_loss_7,0.0019122938392683864
25,kl_loss_8,0.0011389451489473383
25,kl_loss_9,0.6979698638121287
26,recon_loss,2203.443903996394
26,loss,2006.6229717548076
26,mi_loss,40.34916833730844
26,tc_loss,-37.11746098445012
26,dw_kl_loss,0.3816586262904681
26,kl_loss,3.5773275265326867
26,kl_loss_0,0.001173520998026316
26,kl_loss_1,2.8474206374241757
26,kl_loss_2,0.0016174610489262985
26,kl_loss_3,0.0015147336257191806
26,kl_loss_4,0.0010703427788729852
26,kl_loss_5,0.0011983873107685493
26,kl_loss_6,0.00126259820535779
26,kl_loss_7,0.001042055252652902
26,kl_loss_8,0.0008434637879522947
26,kl_loss_9,0.7201842207175034
27,recon_loss,2206.4358520507812
27,loss,2009.3228352864583
27,mi_loss,40.37657833099365
27,tc_loss,-37.16412862141927
27,dw_kl_loss,0.360832375784715
27,kl_loss,3.5843074520428977
27,kl_loss_0,0.0011982040014117956
27,kl_loss_1,2.846610724925995
27,kl_loss_2,0.0011326035019010305
27,kl_loss_3,0.001651634695008397
27,kl_loss_4,0.0012949489755555987
27,kl_loss_5,0.0014977601046363513
27,kl_loss_6,0.001088831612529854
27,kl_loss_7,0.0012824649068837364
27,kl_loss_8,0.001276915582517783
27,kl_loss_9,0.7272733698288599
28,recon_loss,2204.0506497896636
28,loss,2007.2717942457932
28,mi_loss,40.36799122737004
28,tc_loss,-37.11823654174805
28,dw_kl_loss,0.4098718739472903
28,kl_loss,3.5976003866929274
28,kl_loss_0,0.0012191503595274228
28,kl_loss_1,2.8608819337991567
28,kl_loss_2,0.0008682911642468893
28,kl_loss_3,0.0012883888557553291
28,kl_loss_4,0.001379269605072645
28,kl_loss_5,0.0015075570688797878
28,kl_loss_6,0.0012084581316090547
28,kl_loss_7,0.0009197160028494322
28,kl_loss_8,0.0013607487154121583
28,kl_loss_9,0.726966903759883
29,recon_loss,2199.2364501953125
29,loss,2002.3850402832031
29,mi_loss,40.270185470581055
29,tc_loss,-37.105968157450356
29,dw_kl_loss,0.3566128934423129
29,kl_loss,3.578150769074758
29,kl_loss_0,0.0011238552397117019
29,kl_loss_1,2.8515902956326804
29,kl_loss_2,0.0014782638754695654
29,kl_loss_3,0.0011515693040564656
29,kl_loss_4,0.0014677915023639798
29,kl_loss_5,0.0013430318019042413
29,kl_loss_6,0.0013744460884481668
29,kl_loss_7,0.0012495889095589519
29,kl_loss_8,0.00138730447118481
29,kl_loss_9,0.715984657406807
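Editorial sketch (not part of the commit): the log above is plain CSV with the header "Epoch,Loss,Value", written by LossesLogger, so it can be pivoted into one column per loss for plotting. Assumes pandas and the path used in this repo.

import pandas as pd

log = pd.read_csv("model/drilling_ds_btcvae/train_losses.log")
wide = log.pivot(index="Epoch", columns="Loss", values="Value")
print(wide[["recon_loss", "kl_loss", "tc_loss"]].head())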
requirements.txt
ADDED
@@ -0,0 +1,11 @@
numpy
pandas
#plotly
# scipy
tqdm
# pillow

# ipywidgets
# jupyterlab

torch
transforms.py
ADDED
@@ -0,0 +1,168 @@
"""
All transforms operate on batches!
VAE transforms are invertible.
"""
import pickle
from dataclasses import dataclass
from functools import partial, reduce, wraps

import numpy as np
import torch

# General functions -------------------------------------------------------------
# Transformations are easiest to do in PyTorch.


def load(p):
    with open(p, "rb") as stream:
        return pickle.load(stream)


def save(obj, p):
    with open(p, "wb") as stream:
        pickle.dump(obj, stream)


def sequential_function(*functions):
    return lambda x: reduce(lambda res, func: func(res), functions, x)


def np_sample(func):
    rtn = sequential_function(
        lambda x: torch.from_numpy(x).float(),
        lambda x: torch.unsqueeze(x, 0),
        func,
        lambda x: x[0].numpy(),
    )
    return rtn


# Invertible
class SequentialInversable(torch.nn.Sequential):
    def __init__(self, *functions):
        super().__init__(*functions)

        self.inv_funcs = [f.inv for f in functions]
        self.inv_funcs.reverse()

    # def forward(self, x):
    #     return sequential_function(*self.functions)(x)

    def inv(self, x):
        return sequential_function(*self.inv_funcs)(x)


class LatentSelector(torch.nn.Module):
    """Handles tensors and numpy arrays."""

    def __init__(self, ldim: int, selectdim: int):
        super().__init__()
        self.ldim = ldim
        self.selectdim = selectdim

    def forward(self, x: torch.Tensor):
        return x[:, : self.selectdim]

    def inv(self, x: torch.Tensor):
        rtn = torch.cat(
            [x, torch.zeros((x.shape[0], self.ldim - x.shape[1]), device=x.device)],
            dim=1,
        )
        return rtn


class MinMaxScaler(torch.nn.Module):
    #! With several signals, be careful with broadcasting.
    def __init__(
        self,
        _min: torch.Tensor,
        _max: torch.Tensor,
        min_norm: float = 0.0,
        max_norm: float = 1.0,
    ):
        super().__init__()
        self._min = _min
        self._max = _max
        self.min_norm = min_norm
        self.max_norm = max_norm

    def forward(self, ts):
        """(None, no_signals)"""
        std = (ts - self._min) / (self._max - self._min)
        rtn = std * (self.max_norm - self.min_norm) + self.min_norm
        return rtn

    def inv(self, ts):
        std = (ts - self.min_norm) / (self.max_norm - self.min_norm)
        rtn = std * (self._max - self._min) + self._min
        return rtn

    @classmethod
    def from_array(cls, arr: torch.Tensor):
        _min = torch.min(arr, axis=0).values
        _max = torch.max(arr, axis=0).values

        return cls(_min, _max)


class LatentSorter(torch.nn.Module):
    def __init__(self, kl_dict: dict):
        super().__init__()
        self.kl_dict = kl_dict

    def forward(self, latent):
        """
        unsorted -> sorted
        latent: (None, latent_dim)
        """
        return latent[:, list(self.kl_dict.keys())]

    def inv(self, latent):
        keys = np.array(list(self.kl_dict.keys()))
        return latent[:, torch.from_numpy(keys.argsort())]

    @property
    def names(self):
        rtn = ["{} KL{:.2f}".format(k, v) for k, v in self.kl_dict.items()]
        return rtn


def apply_along_axis(function, x, axis: int = 0):
    return torch.stack([function(x_i) for x_i in torch.unbind(x, dim=axis)], dim=axis)


# Input shapes are preserved!
class SumField(torch.nn.Module):
    """
    time series: [idx, time_step, signal]
    image: [idx, signal, time_step, time_step]
    """

    def forward(self, ts: torch.Tensor):
        """ts2img"""

        samples = ts.shape[0]
        time = ts.shape[1]
        channels = ts.shape[2]

        ts = torch.swapaxes(ts, 1, 2)  # move the time axis to the end
        ts = torch.reshape(
            ts, (samples * channels, time)
        )  # merge channel and sample index
        #! TODO: replace the loop with something better
        rtn = apply_along_axis(self._mtf_forward, ts, 0)
        rtn = torch.reshape(rtn, (samples, channels, time, time))

        return rtn

    def inv(self, img: torch.Tensor):
        """img2ts"""
        rtn = torch.diagonal(img, dim1=2, dim2=3)
        rtn = torch.swapaxes(rtn, 1, 2)  # swap channel and time axes

        return rtn

    @staticmethod
    def _mtf_forward(ts):
        """For a one-dimensional time series ts"""
        return torch.add(*torch.meshgrid(ts, ts, indexing="ij")) / 2
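Editorial sketch: SumField places the pairwise means (ts_i + ts_j) / 2 in an image, so the image diagonal is the original signal and inv(forward(ts)) recovers ts exactly.

import torch
from transforms import SumField

imaging = SumField()
ts = torch.rand(2, 64, 1)   # [idx, time_step, signal]
img = imaging(ts)           # [idx, signal, time_step, time_step]
assert img.shape == (2, 1, 64, 64)
assert torch.allclose(imaging.inv(img), ts)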