Jonas Becker commited on
Commit
7f19394
·
1 Parent(s): 2478e2a
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ venv
2
+ __pycache__
app.bat ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ call venv\scripts\activate.bat
2
+ call streamlit run app.py
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import streamlit as st
3
+ import torch
4
+ import pandas as pd
5
+
6
+ import disvae
7
+ import transforms as trans
8
+
9
+ P_MODEL = "model/drilling_ds_btcvae"
10
+
11
+ # Decode Funktion --------------------------------------------------
12
+ sorter = trans.LatentSorter(disvae.get_kl_dict(P_MODEL))
13
+ vae = disvae.load_model(P_MODEL)
14
+ scaler = trans.MinMaxScaler(_min=torch.tensor([1.3]),_max=torch.tensor([4.0]),min_norm=0.3,max_norm=0.6)
15
+ imaging = trans.SumField()
16
+
17
+ _dec = trans.sequential_function(
18
+ sorter.inv,
19
+ vae.decoder,
20
+ scaler.inv
21
+ )
22
+
23
+ def decode(latent):
24
+ with torch.no_grad():
25
+ return trans.np_sample(_dec)(latent)
26
+
27
+ img2ts = trans.np_sample(imaging.inv)
28
+
29
+ # GUI -----------------------------------------------------------
30
+
31
+ latent_vector = np.array([st.slider(f"L{l}",min_value=-3.0,max_value=3.0,value=0.0) for l in range(3)])
32
+ latent_vector = np.concatenate([latent_vector,np.zeros(7)],axis=0)
33
+
34
+ value = decode(latent_vector)
35
+
36
+ ts = img2ts(value)
37
+
38
+ df = pd.DataFrame({
39
+ "x":np.arange(len(ts)),
40
+ "y":ts.ravel()
41
+ }
42
+ )
43
+
44
+ st.line_chart(df,x="x",y="y")
45
+ st.write(ts)
46
+ # st.write(value)
47
+ # st.image(value, use_column_width="always")
48
+
49
+ # x = st.slider("Select a value")
50
+ # st.write(x, "squared is", x * x)
disvae/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from disvae.evaluate import Evaluator
2
+ from disvae.main import get_kl_dict
3
+ from disvae.training import Trainer
4
+ from disvae.utils.modelIO import load_model, save_model
5
+
6
+ # from disvae.models.vae import init_specific_model # notwendig?
disvae/evaluate.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import math
4
+ from functools import reduce
5
+ from collections import defaultdict
6
+ import json
7
+ from timeit import default_timer
8
+
9
+ from tqdm import trange, tqdm
10
+ import numpy as np
11
+ import torch
12
+
13
+ from disvae.models.losses import get_loss_f
14
+ from disvae.utils.math import log_density_gaussian
15
+ from disvae.utils.modelIO import save_metadata
16
+
17
+ TEST_LOSSES_FILE = "test_losses.log"
18
+ METRICS_FILENAME = "metrics.log"
19
+ METRIC_HELPERS_FILE = "metric_helpers.pth"
20
+
21
+
22
+ class Evaluator:
23
+ """
24
+ Class to handle training of model.
25
+
26
+ Parameters
27
+ ----------
28
+ model: disvae.vae.VAE
29
+
30
+ loss_f: disvae.models.BaseLoss
31
+ Loss function.
32
+
33
+ device: torch.device, optional
34
+ Device on which to run the code.
35
+
36
+ logger: logging.Logger, optional
37
+ Logger.
38
+
39
+ save_dir : str, optional
40
+ Directory for saving logs.
41
+
42
+ is_progress_bar: bool, optional
43
+ Whether to use a progress bar for training.
44
+ """
45
+
46
+ def __init__(self, model, loss_f,
47
+ device=torch.device("cpu"),
48
+ logger=logging.getLogger(__name__),
49
+ save_dir="results",
50
+ is_progress_bar=True):
51
+
52
+ self.device = device
53
+ self.loss_f = loss_f
54
+ self.model = model.to(self.device)
55
+ self.logger = logger
56
+ self.save_dir = save_dir
57
+ self.is_progress_bar = is_progress_bar
58
+ self.logger.info("Testing Device: {}".format(self.device))
59
+
60
+ def __call__(self, data_loader, is_metrics=False, is_losses=True):
61
+ """Compute all test losses.
62
+
63
+ Parameters
64
+ ----------
65
+ data_loader: torch.utils.data.DataLoader
66
+
67
+ is_metrics: bool, optional
68
+ Whether to compute and store the disentangling metrics.
69
+
70
+ is_losses: bool, optional
71
+ Whether to compute and store the test losses.
72
+ """
73
+ start = default_timer()
74
+ is_still_training = self.model.training
75
+ self.model.eval()
76
+
77
+ metric, losses = None, None
78
+ if is_metrics:
79
+ self.logger.info('Computing metrics...')
80
+ metrics = self.compute_metrics(data_loader)
81
+ self.logger.info('Losses: {}'.format(metrics))
82
+ save_metadata(metrics, self.save_dir, filename=METRICS_FILENAME)
83
+
84
+ if is_losses:
85
+ self.logger.info('Computing losses...')
86
+ losses = self.compute_losses(data_loader)
87
+ self.logger.info('Losses: {}'.format(losses))
88
+ save_metadata(losses, self.save_dir, filename=TEST_LOSSES_FILE)
89
+
90
+ if is_still_training:
91
+ self.model.train()
92
+
93
+ self.logger.info('Finished evaluating after {:.1f} min.'.format((default_timer() - start) / 60))
94
+
95
+ return metric, losses
96
+
97
+ def compute_losses(self, dataloader):
98
+ """Compute all test losses.
99
+
100
+ Parameters
101
+ ----------
102
+ data_loader: torch.utils.data.DataLoader
103
+ """
104
+ storer = defaultdict(list)
105
+ for data, _ in tqdm(dataloader, leave=False, disable=not self.is_progress_bar):
106
+ data = data.to(self.device)
107
+
108
+ try:
109
+ recon_batch, latent_dist, latent_sample = self.model(data)
110
+ _ = self.loss_f(data, recon_batch, latent_dist, self.model.training,
111
+ storer, latent_sample=latent_sample)
112
+ except ValueError:
113
+ # for losses that use multiple optimizers (e.g. Factor)
114
+ _ = self.loss_f.call_optimize(data, self.model, None, storer)
115
+
116
+ losses = {k: sum(v) / len(dataloader) for k, v in storer.items()}
117
+ return losses
118
+
119
+ def compute_metrics(self, dataloader):
120
+ """Compute all the metrics.
121
+
122
+ Parameters
123
+ ----------
124
+ data_loader: torch.utils.data.DataLoader
125
+ """
126
+ try:
127
+ lat_sizes = dataloader.dataset.lat_sizes
128
+ lat_names = dataloader.dataset.lat_names
129
+ except AttributeError:
130
+ raise ValueError("Dataset needs to have known true factors of variations to compute the metric. This does not seem to be the case for {}".format(type(dataloader.__dict__["dataset"]).__name__))
131
+
132
+ self.logger.info("Computing the empirical distribution q(z|x).")
133
+ samples_zCx, params_zCx = self._compute_q_zCx(dataloader)
134
+ len_dataset, latent_dim = samples_zCx.shape
135
+
136
+ self.logger.info("Estimating the marginal entropy.")
137
+ # marginal entropy H(z_j)
138
+ H_z = self._estimate_latent_entropies(samples_zCx, params_zCx)
139
+
140
+ # conditional entropy H(z|v)
141
+ samples_zCx = samples_zCx.view(*lat_sizes, latent_dim)
142
+ params_zCx = tuple(p.view(*lat_sizes, latent_dim) for p in params_zCx)
143
+ H_zCv = self._estimate_H_zCv(samples_zCx, params_zCx, lat_sizes, lat_names)
144
+
145
+ H_z = H_z.cpu()
146
+ H_zCv = H_zCv.cpu()
147
+
148
+ # I[z_j;v_k] = E[log \sum_x q(z_j|x)p(x|v_k)] + H[z_j] = - H[z_j|v_k] + H[z_j]
149
+ mut_info = - H_zCv + H_z
150
+ sorted_mut_info = torch.sort(mut_info, dim=1, descending=True)[0].clamp(min=0)
151
+
152
+ metric_helpers = {'marginal_entropies': H_z, 'cond_entropies': H_zCv}
153
+ mig = self._mutual_information_gap(sorted_mut_info, lat_sizes, storer=metric_helpers)
154
+ aam = self._axis_aligned_metric(sorted_mut_info, storer=metric_helpers)
155
+
156
+ metrics = {'MIG': mig.item(), 'AAM': aam.item()}
157
+ torch.save(metric_helpers, os.path.join(self.save_dir, METRIC_HELPERS_FILE))
158
+
159
+ return metrics
160
+
161
+ def _mutual_information_gap(self, sorted_mut_info, lat_sizes, storer=None):
162
+ """Compute the mutual information gap as in [1].
163
+
164
+ References
165
+ ----------
166
+ [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
167
+ autoencoders." Advances in Neural Information Processing Systems. 2018.
168
+ """
169
+ # difference between the largest and second largest mutual info
170
+ delta_mut_info = sorted_mut_info[:, 0] - sorted_mut_info[:, 1]
171
+ # NOTE: currently only works if balanced dataset for every factor of variation
172
+ # then H(v_k) = - |V_k|/|V_k| log(1/|V_k|) = log(|V_k|)
173
+ H_v = torch.from_numpy(lat_sizes).float().log()
174
+ mig_k = delta_mut_info / H_v
175
+ mig = mig_k.mean() # mean over factor of variations
176
+
177
+ if storer is not None:
178
+ storer["mig_k"] = mig_k
179
+ storer["mig"] = mig
180
+
181
+ return mig
182
+
183
+ def _axis_aligned_metric(self, sorted_mut_info, storer=None):
184
+ """Compute the proposed axis aligned metrics."""
185
+ numerator = (sorted_mut_info[:, 0] - sorted_mut_info[:, 1:].sum(dim=1)).clamp(min=0)
186
+ aam_k = numerator / sorted_mut_info[:, 0]
187
+ aam_k[torch.isnan(aam_k)] = 0
188
+ aam = aam_k.mean() # mean over factor of variations
189
+
190
+ if storer is not None:
191
+ storer["aam_k"] = aam_k
192
+ storer["aam"] = aam
193
+
194
+ return aam
195
+
196
+ def _compute_q_zCx(self, dataloader):
197
+ """Compute the empiricall disitribution of q(z|x).
198
+
199
+ Parameter
200
+ ---------
201
+ dataloader: torch.utils.data.DataLoader
202
+ Batch data iterator.
203
+
204
+ Return
205
+ ------
206
+ samples_zCx: torch.tensor
207
+ Tensor of shape (len_dataset, latent_dim) containing a sample of
208
+ q(z|x) for every x in the dataset.
209
+
210
+ params_zCX: tuple of torch.Tensor
211
+ Sufficient statistics q(z|x) for each training example. E.g. for
212
+ gaussian (mean, log_var) each of shape : (len_dataset, latent_dim).
213
+ """
214
+ len_dataset = len(dataloader.dataset)
215
+ latent_dim = self.model.latent_dim
216
+ n_suff_stat = 2
217
+
218
+ q_zCx = torch.zeros(len_dataset, latent_dim, n_suff_stat, device=self.device)
219
+
220
+ n = 0
221
+ with torch.no_grad():
222
+ for x, label in dataloader:
223
+ batch_size = x.size(0)
224
+ idcs = slice(n, n + batch_size)
225
+ q_zCx[idcs, :, 0], q_zCx[idcs, :, 1] = self.model.encoder(x.to(self.device))
226
+ n += batch_size
227
+
228
+ params_zCX = q_zCx.unbind(-1)
229
+ samples_zCx = self.model.reparameterize(*params_zCX)
230
+
231
+ return samples_zCx, params_zCX
232
+
233
+ def _estimate_latent_entropies(self, samples_zCx, params_zCX,
234
+ n_samples=10000):
235
+ r"""Estimate :math:`H(z_j) = E_{q(z_j)} [-log q(z_j)] = E_{p(x)} E_{q(z_j|x)} [-log q(z_j)]`
236
+ using the emperical distribution of :math:`p(x)`.
237
+
238
+ Note
239
+ ----
240
+ - the expectation over the emperical distributio is: :math:`q(z) = 1/N sum_{n=1}^N q(z|x_n)`.
241
+ - we assume that q(z|x) is factorial i.e. :math:`q(z|x) = \prod_j q(z_j|x)`.
242
+ - computes numerically stable NLL: :math:`- log q(z) = log N - logsumexp_n=1^N log q(z|x_n)`.
243
+
244
+ Parameters
245
+ ----------
246
+ samples_zCx: torch.tensor
247
+ Tensor of shape (len_dataset, latent_dim) containing a sample of
248
+ q(z|x) for every x in the dataset.
249
+
250
+ params_zCX: tuple of torch.Tensor
251
+ Sufficient statistics q(z|x) for each training example. E.g. for
252
+ gaussian (mean, log_var) each of shape : (len_dataset, latent_dim).
253
+
254
+ n_samples: int, optional
255
+ Number of samples to use to estimate the entropies.
256
+
257
+ Return
258
+ ------
259
+ H_z: torch.Tensor
260
+ Tensor of shape (latent_dim) containing the marginal entropies H(z_j)
261
+ """
262
+ len_dataset, latent_dim = samples_zCx.shape
263
+ device = samples_zCx.device
264
+ H_z = torch.zeros(latent_dim, device=device)
265
+
266
+ # sample from p(x)
267
+ samples_x = torch.randperm(len_dataset, device=device)[:n_samples]
268
+ # sample from p(z|x)
269
+ samples_zCx = samples_zCx.index_select(0, samples_x).view(latent_dim, n_samples)
270
+
271
+ mini_batch_size = 10
272
+ samples_zCx = samples_zCx.expand(len_dataset, latent_dim, n_samples)
273
+ mean = params_zCX[0].unsqueeze(-1).expand(len_dataset, latent_dim, n_samples)
274
+ log_var = params_zCX[1].unsqueeze(-1).expand(len_dataset, latent_dim, n_samples)
275
+ log_N = math.log(len_dataset)
276
+ with trange(n_samples, leave=False, disable=self.is_progress_bar) as t:
277
+ for k in range(0, n_samples, mini_batch_size):
278
+ # log q(z_j|x) for n_samples
279
+ idcs = slice(k, k + mini_batch_size)
280
+ log_q_zCx = log_density_gaussian(samples_zCx[..., idcs],
281
+ mean[..., idcs],
282
+ log_var[..., idcs])
283
+ # numerically stable log q(z_j) for n_samples:
284
+ # log q(z_j) = -log N + logsumexp_{n=1}^N log q(z_j|x_n)
285
+ # As we don't know q(z) we appoximate it with the monte carlo
286
+ # expectation of q(z_j|x_n) over x. => fix a single z and look at
287
+ # proba for every x to generate it. n_samples is not used here !
288
+ log_q_z = -log_N + torch.logsumexp(log_q_zCx, dim=0, keepdim=False)
289
+ # H(z_j) = E_{z_j}[- log q(z_j)]
290
+ # mean over n_samples (i.e. dimesnion 1 because already summed over 0).
291
+ H_z += (-log_q_z).sum(1)
292
+
293
+ t.update(mini_batch_size)
294
+
295
+ H_z /= n_samples
296
+
297
+ return H_z
298
+
299
+ def _estimate_H_zCv(self, samples_zCx, params_zCx, lat_sizes, lat_names):
300
+ """Estimate conditional entropies :math:`H[z|v]`."""
301
+ latent_dim = samples_zCx.size(-1)
302
+ len_dataset = reduce((lambda x, y: x * y), lat_sizes)
303
+ H_zCv = torch.zeros(len(lat_sizes), latent_dim, device=self.device)
304
+ for i_fac_var, (lat_size, lat_name) in enumerate(zip(lat_sizes, lat_names)):
305
+ idcs = [slice(None)] * len(lat_sizes)
306
+ for i in range(lat_size):
307
+ self.logger.info("Estimating conditional entropies for the {}th value of {}.".format(i, lat_name))
308
+ idcs[i_fac_var] = i
309
+ # samples from q(z,x|v)
310
+ samples_zxCv = samples_zCx[idcs].contiguous().view(len_dataset // lat_size,
311
+ latent_dim)
312
+ params_zxCv = tuple(p[idcs].contiguous().view(len_dataset // lat_size, latent_dim)
313
+ for p in params_zCx)
314
+
315
+ H_zCv[i_fac_var] += self._estimate_latent_entropies(samples_zxCv, params_zxCv
316
+ ) / lat_size
317
+ return H_zCv
disvae/main.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Pip-Packages -----------------------------------------------------
2
+ import importlib
3
+ import os
4
+ import sys
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ from torch import optim
12
+ from torch.utils.data import DataLoader
13
+
14
+ # From local package -----------------------------------------------
15
+ from disvae.models.losses import get_loss_f
16
+ from disvae.models.vae import init_specific_model
17
+ from disvae.training import Trainer
18
+ from disvae.utils.modelIO import save_model
19
+
20
+ # Loss stuff:
21
+
22
+
23
+ def parse_losses(p_model, filename="train_losses.log"):
24
+ df = pd.read_csv(Path(p_model) / filename)
25
+
26
+ losses = df["Loss"].unique()
27
+
28
+ rtn = [np.array(df[df["Loss"] == l]["Value"]) for l in losses]
29
+ rtn = pd.DataFrame(np.array(rtn).T, columns=losses)
30
+
31
+ return rtn
32
+
33
+
34
+ def get_kl_loss_latent(df):
35
+ """df muss bereits geparsed sein!"""
36
+ rtn = {int(c.split("_")[-1]): df[c].iloc[-1] for c in df if "kl_loss_" in c}
37
+ rtn = dict(sorted(rtn.items(), key=lambda item: item[1], reverse=True))
38
+ return rtn
39
+
40
+
41
+ def get_kl_dict(p_model):
42
+ df = parse_losses(p_model)
43
+ rtn = get_kl_loss_latent(df)
44
+ return rtn
45
+
46
+
47
+ # Datalaader convinience stuff
48
+
49
+
50
+ # def get_dataloader(dataset: torch.data.Dataset, batch_size, num_workers):
51
+ # # Funktion ist recht kompliziert. Das geht im Notebook schnell
52
+ # # Diese Dinge werden auch zur Visualisierung des Datasets benötigt
53
+
54
+ # # p_dataset_module, dataset_class, dataset_args
55
+ # # Import module
56
+ # # if p_dataset_module not in sys.path:
57
+ # # sys.path.append(str(Path(p_dataset_module).parent))
58
+ # # Dataset = getattr(
59
+ # # importlib.import_module(Path(p_dataset_module).stem), dataset_class
60
+ # # )
61
+
62
+ # # # Ab hier an, wenn das normal importiert würde
63
+ # # ds = Dataset(**dataset_args)
64
+
65
+ #
66
+
67
+ # return loader
68
+
69
+
70
+ def get_export_dir(base_dir: str, folder_name):
71
+ if folder_name is None:
72
+ folder_name = "Model_" + (
73
+ datetime.now().replace(microsecond=0).isoformat()
74
+ ).replace(" ", "_").replace(":", "-")
75
+
76
+ rtn = Path(base_dir) / folder_name
77
+
78
+ if not rtn.exists():
79
+ os.makedirs(rtn)
80
+ else:
81
+ raise ValueError("Output directory already exists.")
82
+
83
+ return rtn
84
+
85
+
86
+ def train_model(model, data_loader, loss_f, device, lr, epochs, export_dir):
87
+ trainer = Trainer(
88
+ model,
89
+ optim.Adam(model.parameters(), lr=lr),
90
+ loss_f,
91
+ device=device,
92
+ # logger=logger,
93
+ save_dir=export_dir,
94
+ is_progress_bar=True,
95
+ ) # ,
96
+ # gif_visualizer=gif_visualizer)
97
+ trainer(data_loader, epochs=epochs, checkpoint_every=10)
98
+
99
+ save_model(trainer.model, export_dir)
100
+ # , metadata=config) # Speichern passiert auch schon vorher
101
+
102
+ # gif_visualizer = GifTraversalsTraining(model, args.dataset, exp_dir)
103
+
104
+
105
+ def train(dataset, config) -> str:
106
+ # Validate Config?
107
+
108
+ print("1) Set device")
109
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
110
+ print(f"Device:\t\t {device}")
111
+
112
+ print("2) Get dataloader")
113
+ dataloader = DataLoader(
114
+ dataset,
115
+ batch_size=config["data_params"]["batch_size"],
116
+ shuffle=True,
117
+ pin_memory=torch.cuda.is_available,
118
+ num_workers=config["data_params"]["num_workers"],
119
+ )
120
+
121
+ print("3) Build model")
122
+ img_size = list(dataloader.dataset[0][0].shape)
123
+ print(f"Image size: \t {img_size}")
124
+ model = init_specific_model(img_size=img_size, **config["model_params"])
125
+ model = model.to(device) # make sure trainer and viz on same device
126
+
127
+ print("4) Build loss function")
128
+ loss_f = get_loss_f(
129
+ n_data=len(dataloader.dataset), device=device, **config["loss_params"]
130
+ )
131
+
132
+ print("5) Parse Export Params")
133
+ export_dir = get_export_dir(**config["export_params"])
134
+
135
+ print("6) Training model")
136
+ train_model(
137
+ model=model,
138
+ data_loader=dataloader,
139
+ loss_f=loss_f,
140
+ device=device,
141
+ export_dir=export_dir,
142
+ **config["trainer_params"],
143
+ )
144
+
145
+ return export_dir
disvae/models/__init__.py ADDED
File without changes
disvae/models/decoders.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module containing the decoders.
3
+ """
4
+ import numpy as np
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+
10
+ # ALL decoders should be called Decoder<Model>
11
+ def get_decoder(model_type):
12
+ model_type = model_type.lower().capitalize()
13
+ return eval("Decoder{}".format(model_type))
14
+
15
+
16
+ class DecoderBurgess(nn.Module):
17
+ def __init__(self, img_size,
18
+ latent_dim=10):
19
+ r"""Decoder of the model proposed in [1].
20
+
21
+ Parameters
22
+ ----------
23
+ img_size : tuple of ints
24
+ Size of images. E.g. (1, 32, 32) or (3, 64, 64).
25
+
26
+ latent_dim : int
27
+ Dimensionality of latent output.
28
+
29
+ Model Architecture (transposed for decoder)
30
+ ------------
31
+ - 4 convolutional layers (each with 32 channels), (4 x 4 kernel), (stride of 2)
32
+ - 2 fully connected layers (each of 256 units)
33
+ - Latent distribution:
34
+ - 1 fully connected layer of 20 units (log variance and mean for 10 Gaussians)
35
+
36
+ References:
37
+ [1] Burgess, Christopher P., et al. "Understanding disentangling in
38
+ $\beta$-VAE." arXiv preprint arXiv:1804.03599 (2018).
39
+ """
40
+ super(DecoderBurgess, self).__init__()
41
+
42
+ # Layer parameters
43
+ hid_channels = 32
44
+ kernel_size = 4
45
+ hidden_dim = 256
46
+ self.img_size = img_size
47
+ # Shape required to start transpose convs
48
+ self.reshape = (hid_channels, kernel_size, kernel_size)
49
+ n_chan = self.img_size[0]
50
+ self.img_size = img_size
51
+
52
+ # Fully connected layers
53
+ self.lin1 = nn.Linear(latent_dim, hidden_dim)
54
+ self.lin2 = nn.Linear(hidden_dim, hidden_dim)
55
+ self.lin3 = nn.Linear(hidden_dim, np.product(self.reshape))
56
+
57
+ # Convolutional layers
58
+ cnn_kwargs = dict(stride=2, padding=1)
59
+ # If input image is 64x64 do fourth convolution
60
+ if self.img_size[1] == self.img_size[2] == 64:
61
+ self.convT_64 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
62
+
63
+ self.convT1 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
64
+ self.convT2 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
65
+ self.convT3 = nn.ConvTranspose2d(hid_channels, n_chan, kernel_size, **cnn_kwargs)
66
+
67
+ def forward(self, z):
68
+ batch_size = z.size(0)
69
+
70
+ # Fully connected layers with ReLu activations
71
+ x = torch.relu(self.lin1(z))
72
+ x = torch.relu(self.lin2(x))
73
+ x = torch.relu(self.lin3(x))
74
+ x = x.view(batch_size, *self.reshape)
75
+
76
+ # Convolutional layers with ReLu activations
77
+ if self.img_size[1] == self.img_size[2] == 64:
78
+ x = torch.relu(self.convT_64(x))
79
+ x = torch.relu(self.convT1(x))
80
+ x = torch.relu(self.convT2(x))
81
+ # Sigmoid activation for final conv layer
82
+ x = torch.sigmoid(self.convT3(x))
83
+
84
+ return x
disvae/models/discriminator.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module containing discriminator for FactorVAE.
3
+ """
4
+ from torch import nn
5
+
6
+ from disvae.utils.initialization import weights_init
7
+
8
+
9
+ class Discriminator(nn.Module):
10
+ def __init__(self,
11
+ neg_slope=0.2,
12
+ latent_dim=10,
13
+ hidden_units=1000):
14
+ """Discriminator proposed in [1].
15
+
16
+ Parameters
17
+ ----------
18
+ neg_slope: float
19
+ Hyperparameter for the Leaky ReLu
20
+
21
+ latent_dim : int
22
+ Dimensionality of latent variables.
23
+
24
+ hidden_units: int
25
+ Number of hidden units in the MLP
26
+
27
+ Model Architecture
28
+ ------------
29
+ - 6 layer multi-layer perceptron, each with 1000 hidden units
30
+ - Leaky ReLu activations
31
+ - Output 2 logits
32
+
33
+ References:
34
+ [1] Kim, Hyunjik, and Andriy Mnih. "Disentangling by factorising."
35
+ arXiv preprint arXiv:1802.05983 (2018).
36
+
37
+ """
38
+ super(Discriminator, self).__init__()
39
+
40
+ # Activation parameters
41
+ self.neg_slope = neg_slope
42
+ self.leaky_relu = nn.LeakyReLU(self.neg_slope, True)
43
+
44
+ # Layer parameters
45
+ self.z_dim = latent_dim
46
+ self.hidden_units = hidden_units
47
+ # theoretically 1 with sigmoid but gives bad results => use 2 and softmax
48
+ out_units = 2
49
+
50
+ # Fully connected layers
51
+ self.lin1 = nn.Linear(self.z_dim, hidden_units)
52
+ self.lin2 = nn.Linear(hidden_units, hidden_units)
53
+ self.lin3 = nn.Linear(hidden_units, hidden_units)
54
+ self.lin4 = nn.Linear(hidden_units, hidden_units)
55
+ self.lin5 = nn.Linear(hidden_units, hidden_units)
56
+ self.lin6 = nn.Linear(hidden_units, out_units)
57
+
58
+ self.reset_parameters()
59
+
60
+ def forward(self, z):
61
+
62
+ # Fully connected layers with leaky ReLu activations
63
+ z = self.leaky_relu(self.lin1(z))
64
+ z = self.leaky_relu(self.lin2(z))
65
+ z = self.leaky_relu(self.lin3(z))
66
+ z = self.leaky_relu(self.lin4(z))
67
+ z = self.leaky_relu(self.lin5(z))
68
+ z = self.lin6(z)
69
+
70
+ return z
71
+
72
+ def reset_parameters(self):
73
+ self.apply(weights_init)
disvae/models/encoders.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module containing the encoders.
3
+ """
4
+ import numpy as np
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+
10
+ # ALL encoders should be called Enccoder<Model>
11
+ def get_encoder(model_type):
12
+ model_type = model_type.lower().capitalize()
13
+ return eval("Encoder{}".format(model_type))
14
+
15
+
16
+ class EncoderBurgess(nn.Module):
17
+ def __init__(self, img_size,
18
+ latent_dim=10):
19
+ r"""Encoder of the model proposed in [1].
20
+
21
+ Parameters
22
+ ----------
23
+ img_size : tuple of ints
24
+ Size of images. E.g. (1, 32, 32) or (3, 64, 64).
25
+
26
+ latent_dim : int
27
+ Dimensionality of latent output.
28
+
29
+ Model Architecture (transposed for decoder)
30
+ ------------
31
+ - 4 convolutional layers (each with 32 channels), (4 x 4 kernel), (stride of 2)
32
+ - 2 fully connected layers (each of 256 units)
33
+ - Latent distribution:
34
+ - 1 fully connected layer of 20 units (log variance and mean for 10 Gaussians)
35
+
36
+ References:
37
+ [1] Burgess, Christopher P., et al. "Understanding disentangling in
38
+ $\beta$-VAE." arXiv preprint arXiv:1804.03599 (2018).
39
+ """
40
+ super(EncoderBurgess, self).__init__()
41
+
42
+ # Layer parameters
43
+ hid_channels = 32
44
+ kernel_size = 4
45
+ hidden_dim = 256
46
+ self.latent_dim = latent_dim
47
+ self.img_size = img_size
48
+ # Shape required to start transpose convs
49
+ self.reshape = (hid_channels, kernel_size, kernel_size)
50
+ n_chan = self.img_size[0]
51
+
52
+ # Convolutional layers
53
+ cnn_kwargs = dict(stride=2, padding=1)
54
+ self.conv1 = nn.Conv2d(n_chan, hid_channels, kernel_size, **cnn_kwargs)
55
+ self.conv2 = nn.Conv2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
56
+ self.conv3 = nn.Conv2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
57
+
58
+ # If input image is 64x64 do fourth convolution
59
+ if self.img_size[1] == self.img_size[2] == 64:
60
+ self.conv_64 = nn.Conv2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
61
+
62
+ # Fully connected layers
63
+ self.lin1 = nn.Linear(np.product(self.reshape), hidden_dim)
64
+ self.lin2 = nn.Linear(hidden_dim, hidden_dim)
65
+
66
+ # Fully connected layers for mean and variance
67
+ self.mu_logvar_gen = nn.Linear(hidden_dim, self.latent_dim * 2)
68
+
69
+ def forward(self, x):
70
+ batch_size = x.size(0)
71
+
72
+ # Convolutional layers with ReLu activations
73
+ x = torch.relu(self.conv1(x))
74
+ x = torch.relu(self.conv2(x))
75
+ x = torch.relu(self.conv3(x))
76
+ if self.img_size[1] == self.img_size[2] == 64:
77
+ x = torch.relu(self.conv_64(x))
78
+
79
+ # Fully connected layers with ReLu activations
80
+ x = x.view((batch_size, -1))
81
+ x = torch.relu(self.lin1(x))
82
+ x = torch.relu(self.lin2(x))
83
+
84
+ # Fully connected layer for log variance and mean
85
+ # Log std-dev in paper (bear in mind)
86
+ mu_logvar = self.mu_logvar_gen(x)
87
+ mu, logvar = mu_logvar.view(-1, self.latent_dim, 2).unbind(-1)
88
+
89
+ return mu, logvar
disvae/models/losses.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module containing all vae losses.
3
+ """
4
+ import abc
5
+ import math
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+ from torch import optim
11
+
12
+ from .discriminator import Discriminator
13
+ from disvae.utils.math import (log_density_gaussian, log_importance_weight_matrix,
14
+ matrix_log_density_gaussian)
15
+
16
+
17
+ LOSSES = ["VAE", "betaH", "betaB", "factor", "btcvae"]
18
+ RECON_DIST = ["bernoulli", "laplace", "gaussian"]
19
+
20
+
21
+ # TO-DO: clean n_data and device
22
+ def get_loss_f(loss_name, **kwargs_parse):
23
+ """Return the correct loss function given the argparse arguments."""
24
+ kwargs_all = dict(rec_dist=kwargs_parse["rec_dist"],
25
+ steps_anneal=kwargs_parse["reg_anneal"])
26
+ if loss_name == "betaH":
27
+ return BetaHLoss(beta=kwargs_parse["betaH_B"], **kwargs_all)
28
+ elif loss_name == "VAE":
29
+ return BetaHLoss(beta=1, **kwargs_all)
30
+ elif loss_name == "betaB":
31
+ return BetaBLoss(C_init=kwargs_parse["betaB_initC"],
32
+ C_fin=kwargs_parse["betaB_finC"],
33
+ gamma=kwargs_parse["betaB_G"],
34
+ **kwargs_all)
35
+ elif loss_name == "factor":
36
+ return FactorKLoss(kwargs_parse["device"],
37
+ gamma=kwargs_parse["factor_G"],
38
+ disc_kwargs=dict(latent_dim=kwargs_parse["latent_dim"]),
39
+ optim_kwargs=dict(lr=kwargs_parse["lr_disc"], betas=(0.5, 0.9)),
40
+ **kwargs_all)
41
+ elif loss_name == "btcvae":
42
+ return BtcvaeLoss(kwargs_parse["n_data"],
43
+ alpha=kwargs_parse["btcvae_A"],
44
+ beta=kwargs_parse["btcvae_B"],
45
+ gamma=kwargs_parse["btcvae_G"],
46
+ **kwargs_all)
47
+ else:
48
+ assert loss_name not in LOSSES
49
+ raise ValueError("Uknown loss : {}".format(loss_name))
50
+
51
+
52
+ class BaseLoss(abc.ABC):
53
+ """
54
+ Base class for losses.
55
+
56
+ Parameters
57
+ ----------
58
+ record_loss_every: int, optional
59
+ Every how many steps to recorsd the loss.
60
+
61
+ rec_dist: {"bernoulli", "gaussian", "laplace"}, optional
62
+ Reconstruction distribution istribution of the likelihood on the each pixel.
63
+ Implicitely defines the reconstruction loss. Bernoulli corresponds to a
64
+ binary cross entropy (bse), Gaussian corresponds to MSE, Laplace
65
+ corresponds to L1.
66
+
67
+ steps_anneal: nool, optional
68
+ Number of annealing steps where gradually adding the regularisation.
69
+ """
70
+
71
+ def __init__(self, record_loss_every=50, rec_dist="bernoulli", steps_anneal=0):
72
+ self.n_train_steps = 0
73
+ self.record_loss_every = record_loss_every
74
+ self.rec_dist = rec_dist
75
+ self.steps_anneal = steps_anneal
76
+
77
+ @abc.abstractmethod
78
+ def __call__(self, data, recon_data, latent_dist, is_train, storer, **kwargs):
79
+ """
80
+ Calculates loss for a batch of data.
81
+
82
+ Parameters
83
+ ----------
84
+ data : torch.Tensor
85
+ Input data (e.g. batch of images). Shape : (batch_size, n_chan,
86
+ height, width).
87
+
88
+ recon_data : torch.Tensor
89
+ Reconstructed data. Shape : (batch_size, n_chan, height, width).
90
+
91
+ latent_dist : tuple of torch.tensor
92
+ sufficient statistics of the latent dimension. E.g. for gaussian
93
+ (mean, log_var) each of shape : (batch_size, latent_dim).
94
+
95
+ is_train : bool
96
+ Whether currently in train mode.
97
+
98
+ storer : dict
99
+ Dictionary in which to store important variables for vizualisation.
100
+
101
+ kwargs:
102
+ Loss specific arguments
103
+ """
104
+
105
+ def _pre_call(self, is_train, storer):
106
+ if is_train:
107
+ self.n_train_steps += 1
108
+
109
+ if not is_train or self.n_train_steps % self.record_loss_every == 1:
110
+ storer = storer
111
+ else:
112
+ storer = None
113
+
114
+ return storer
115
+
116
+
117
+ class BetaHLoss(BaseLoss):
118
+ """
119
+ Compute the Beta-VAE loss as in [1]
120
+
121
+ Parameters
122
+ ----------
123
+ beta : float, optional
124
+ Weight of the kl divergence.
125
+
126
+ kwargs:
127
+ Additional arguments for `BaseLoss`, e.g. rec_dist`.
128
+
129
+ References
130
+ ----------
131
+ [1] Higgins, Irina, et al. "beta-vae: Learning basic visual concepts with
132
+ a constrained variational framework." (2016).
133
+ """
134
+
135
+ def __init__(self, beta=4, **kwargs):
136
+ super().__init__(**kwargs)
137
+ self.beta = beta
138
+
139
+ def __call__(self, data, recon_data, latent_dist, is_train, storer, **kwargs):
140
+ storer = self._pre_call(is_train, storer)
141
+
142
+ rec_loss = _reconstruction_loss(data, recon_data,
143
+ storer=storer,
144
+ distribution=self.rec_dist)
145
+ kl_loss = _kl_normal_loss(*latent_dist, storer)
146
+ anneal_reg = (linear_annealing(0, 1, self.n_train_steps, self.steps_anneal)
147
+ if is_train else 1)
148
+ loss = rec_loss + anneal_reg * (self.beta * kl_loss)
149
+
150
+ if storer is not None:
151
+ storer['loss'].append(loss.item())
152
+
153
+ return loss
154
+
155
+
156
+ class BetaBLoss(BaseLoss):
157
+ """
158
+ Compute the Beta-VAE loss as in [1]
159
+
160
+ Parameters
161
+ ----------
162
+ C_init : float, optional
163
+ Starting annealed capacity C.
164
+
165
+ C_fin : float, optional
166
+ Final annealed capacity C.
167
+
168
+ gamma : float, optional
169
+ Weight of the KL divergence term.
170
+
171
+ kwargs:
172
+ Additional arguments for `BaseLoss`, e.g. rec_dist`.
173
+
174
+ References
175
+ ----------
176
+ [1] Burgess, Christopher P., et al. "Understanding disentangling in
177
+ $\beta$-VAE." arXiv preprint arXiv:1804.03599 (2018).
178
+ """
179
+
180
+ def __init__(self, C_init=0., C_fin=20., gamma=100., **kwargs):
181
+ super().__init__(**kwargs)
182
+ self.gamma = gamma
183
+ self.C_init = C_init
184
+ self.C_fin = C_fin
185
+
186
+ def __call__(self, data, recon_data, latent_dist, is_train, storer, **kwargs):
187
+ storer = self._pre_call(is_train, storer)
188
+
189
+ rec_loss = _reconstruction_loss(data, recon_data,
190
+ storer=storer,
191
+ distribution=self.rec_dist)
192
+ kl_loss = _kl_normal_loss(*latent_dist, storer)
193
+
194
+ C = (linear_annealing(self.C_init, self.C_fin, self.n_train_steps, self.steps_anneal)
195
+ if is_train else self.C_fin)
196
+
197
+ loss = rec_loss + self.gamma * (kl_loss - C).abs()
198
+
199
+ if storer is not None:
200
+ storer['loss'].append(loss.item())
201
+
202
+ return loss
203
+
204
+
205
+ class FactorKLoss(BaseLoss):
206
+ """
207
+ Compute the Factor-VAE loss as per Algorithm 2 of [1]
208
+
209
+ Parameters
210
+ ----------
211
+ device : torch.device
212
+
213
+ gamma : float, optional
214
+ Weight of the TC loss term. `gamma` in the paper.
215
+
216
+ discriminator : disvae.discriminator.Discriminator
217
+
218
+ optimizer_d : torch.optim
219
+
220
+ kwargs:
221
+ Additional arguments for `BaseLoss`, e.g. rec_dist`.
222
+
223
+ References
224
+ ----------
225
+ [1] Kim, Hyunjik, and Andriy Mnih. "Disentangling by factorising."
226
+ arXiv preprint arXiv:1802.05983 (2018).
227
+ """
228
+
229
+ def __init__(self, device,
230
+ gamma=10.,
231
+ disc_kwargs={},
232
+ optim_kwargs=dict(lr=5e-5, betas=(0.5, 0.9)),
233
+ **kwargs):
234
+ super().__init__(**kwargs)
235
+ self.gamma = gamma
236
+ self.device = device
237
+ self.discriminator = Discriminator(**disc_kwargs).to(self.device)
238
+ self.optimizer_d = optim.Adam(self.discriminator.parameters(), **optim_kwargs)
239
+
240
+ def __call__(self, *args, **kwargs):
241
+ raise ValueError("Use `call_optimize` to also train the discriminator")
242
+
243
+ def call_optimize(self, data, model, optimizer, storer):
244
+ storer = self._pre_call(model.training, storer)
245
+
246
+ # factor-vae split data into two batches. In the paper they sample 2 batches
247
+ batch_size = data.size(dim=0)
248
+ half_batch_size = batch_size // 2
249
+ data = data.split(half_batch_size)
250
+ data1 = data[0]
251
+ data2 = data[1]
252
+
253
+ # Factor VAE Loss
254
+ recon_batch, latent_dist, latent_sample1 = model(data1)
255
+ rec_loss = _reconstruction_loss(data1, recon_batch,
256
+ storer=storer,
257
+ distribution=self.rec_dist)
258
+
259
+ kl_loss = _kl_normal_loss(*latent_dist, storer)
260
+
261
+ d_z = self.discriminator(latent_sample1)
262
+ # We want log(p_true/p_false). If not using logisitc regression but softmax
263
+ # then p_true = exp(logit_true) / Z; p_false = exp(logit_false) / Z
264
+ # so log(p_true/p_false) = logit_true - logit_false
265
+ tc_loss = (d_z[:, 0] - d_z[:, 1]).mean()
266
+ # with sigmoid (not good results) should be `tc_loss = (2 * d_z.flatten()).mean()`
267
+
268
+ anneal_reg = (linear_annealing(0, 1, self.n_train_steps, self.steps_anneal)
269
+ if model.training else 1)
270
+ vae_loss = rec_loss + kl_loss + anneal_reg * self.gamma * tc_loss
271
+
272
+ if storer is not None:
273
+ storer['loss'].append(vae_loss.item())
274
+ storer['tc_loss'].append(tc_loss.item())
275
+
276
+ if not model.training:
277
+ # don't backprop if evaluating
278
+ return vae_loss
279
+
280
+ # Compute VAE gradients
281
+ optimizer.zero_grad()
282
+ vae_loss.backward(retain_graph=True)
283
+
284
+ # Discriminator Loss
285
+ # Get second sample of latent distribution
286
+ latent_sample2 = model.sample_latent(data2)
287
+ z_perm = _permute_dims(latent_sample2).detach()
288
+ d_z_perm = self.discriminator(z_perm)
289
+
290
+ # Calculate total correlation loss
291
+ # for cross entropy the target is the index => need to be long and says
292
+ # that it's first output for d_z and second for perm
293
+ ones = torch.ones(half_batch_size, dtype=torch.long, device=self.device)
294
+ zeros = torch.zeros_like(ones)
295
+ d_tc_loss = 0.5 * (F.cross_entropy(d_z, zeros) + F.cross_entropy(d_z_perm, ones))
296
+ # with sigmoid would be :
297
+ # d_tc_loss = 0.5 * (self.bce(d_z.flatten(), ones) + self.bce(d_z_perm.flatten(), 1 - ones))
298
+
299
+ # TO-DO: check ifshould also anneals discriminator if not becomes too good ???
300
+ #d_tc_loss = anneal_reg * d_tc_loss
301
+
302
+ # Compute discriminator gradients
303
+ self.optimizer_d.zero_grad()
304
+ d_tc_loss.backward()
305
+
306
+ # Update at the end (since pytorch 1.5. complains if update before)
307
+ optimizer.step()
308
+ self.optimizer_d.step()
309
+
310
+ if storer is not None:
311
+ storer['discrim_loss'].append(d_tc_loss.item())
312
+
313
+ return vae_loss
314
+
315
+
316
+ class BtcvaeLoss(BaseLoss):
317
+ """
318
+ Compute the decomposed KL loss with either minibatch weighted sampling or
319
+ minibatch stratified sampling according to [1]
320
+
321
+ Parameters
322
+ ----------
323
+ n_data: int
324
+ Number of data in the training set
325
+
326
+ alpha : float
327
+ Weight of the mutual information term.
328
+
329
+ beta : float
330
+ Weight of the total correlation term.
331
+
332
+ gamma : float
333
+ Weight of the dimension-wise KL term.
334
+
335
+ is_mss : bool
336
+ Whether to use minibatch stratified sampling instead of minibatch
337
+ weighted sampling.
338
+
339
+ kwargs:
340
+ Additional arguments for `BaseLoss`, e.g. rec_dist`.
341
+
342
+ References
343
+ ----------
344
+ [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
345
+ autoencoders." Advances in Neural Information Processing Systems. 2018.
346
+ """
347
+
348
+ def __init__(self, n_data, alpha=1., beta=6., gamma=1., is_mss=True, **kwargs):
349
+ super().__init__(**kwargs)
350
+ self.n_data = n_data
351
+ self.beta = beta
352
+ self.alpha = alpha
353
+ self.gamma = gamma
354
+ self.is_mss = is_mss # minibatch stratified sampling
355
+
356
+ def __call__(self, data, recon_batch, latent_dist, is_train, storer,
357
+ latent_sample=None):
358
+ storer = self._pre_call(is_train, storer)
359
+ batch_size, latent_dim = latent_sample.shape
360
+
361
+ rec_loss = _reconstruction_loss(data, recon_batch,
362
+ storer=storer,
363
+ distribution=self.rec_dist)
364
+ log_pz, log_qz, log_prod_qzi, log_q_zCx = _get_log_pz_qz_prodzi_qzCx(latent_sample,
365
+ latent_dist,
366
+ self.n_data,
367
+ is_mss=self.is_mss)
368
+ # I[z;x] = KL[q(z,x)||q(x)q(z)] = E_x[KL[q(z|x)||q(z)]]
369
+ mi_loss = (log_q_zCx - log_qz).mean()
370
+ # TC[z] = KL[q(z)||\prod_i z_i]
371
+ tc_loss = (log_qz - log_prod_qzi).mean()
372
+ # dw_kl_loss is KL[q(z)||p(z)] instead of usual KL[q(z|x)||p(z))]
373
+ dw_kl_loss = (log_prod_qzi - log_pz).mean()
374
+
375
+ anneal_reg = (linear_annealing(0, 1, self.n_train_steps, self.steps_anneal)
376
+ if is_train else 1)
377
+
378
+ # total loss
379
+ loss = rec_loss + (self.alpha * mi_loss +
380
+ self.beta * tc_loss +
381
+ anneal_reg * self.gamma * dw_kl_loss)
382
+
383
+ if storer is not None:
384
+ storer['loss'].append(loss.item())
385
+ storer['mi_loss'].append(mi_loss.item())
386
+ storer['tc_loss'].append(tc_loss.item())
387
+ storer['dw_kl_loss'].append(dw_kl_loss.item())
388
+ # computing this for storing and comparaison purposes
389
+ _ = _kl_normal_loss(*latent_dist, storer)
390
+
391
+ return loss
392
+
393
+
394
+ def _reconstruction_loss(data, recon_data, distribution="bernoulli", storer=None):
395
+ """
396
+ Calculates the per image reconstruction loss for a batch of data. I.e. negative
397
+ log likelihood.
398
+
399
+ Parameters
400
+ ----------
401
+ data : torch.Tensor
402
+ Input data (e.g. batch of images). Shape : (batch_size, n_chan,
403
+ height, width).
404
+
405
+ recon_data : torch.Tensor
406
+ Reconstructed data. Shape : (batch_size, n_chan, height, width).
407
+
408
+ distribution : {"bernoulli", "gaussian", "laplace"}
409
+ Distribution of the likelihood on the each pixel. Implicitely defines the
410
+ loss Bernoulli corresponds to a binary cross entropy (bse) loss and is the
411
+ most commonly used. It has the issue that it doesn't penalize the same
412
+ way (0.1,0.2) and (0.4,0.5), which might not be optimal. Gaussian
413
+ distribution corresponds to MSE, and is sometimes used, but hard to train
414
+ ecause it ends up focusing only a few pixels that are very wrong. Laplace
415
+ distribution corresponds to L1 solves partially the issue of MSE.
416
+
417
+ storer : dict
418
+ Dictionary in which to store important variables for vizualisation.
419
+
420
+ Returns
421
+ -------
422
+ loss : torch.Tensor
423
+ Per image cross entropy (i.e. normalized per batch but not pixel and
424
+ channel)
425
+ """
426
+ batch_size, n_chan, height, width = recon_data.size()
427
+ is_colored = n_chan == 3
428
+
429
+ if distribution == "bernoulli":
430
+ loss = F.binary_cross_entropy(recon_data, data, reduction="sum")
431
+ elif distribution == "gaussian":
432
+ # loss in [0,255] space but normalized by 255 to not be too big
433
+ loss = F.mse_loss(recon_data * 255, data * 255, reduction="sum") / 255
434
+ elif distribution == "laplace":
435
+ # loss in [0,255] space but normalized by 255 to not be too big but
436
+ # multiply by 255 and divide 255, is the same as not doing anything for L1
437
+ loss = F.l1_loss(recon_data, data, reduction="sum")
438
+ loss = loss * 3 # emperical value to give similar values than bernoulli => use same hyperparam
439
+ loss = loss * (loss != 0) # masking to avoid nan
440
+ else:
441
+ assert distribution not in RECON_DIST
442
+ raise ValueError("Unkown distribution: {}".format(distribution))
443
+
444
+ loss = loss / batch_size
445
+
446
+ if storer is not None:
447
+ storer['recon_loss'].append(loss.item())
448
+
449
+ return loss
450
+
451
+
452
+ def _kl_normal_loss(mean, logvar, storer=None):
453
+ """
454
+ Calculates the KL divergence between a normal distribution
455
+ with diagonal covariance and a unit normal distribution.
456
+
457
+ Parameters
458
+ ----------
459
+ mean : torch.Tensor
460
+ Mean of the normal distribution. Shape (batch_size, latent_dim) where
461
+ D is dimension of distribution.
462
+
463
+ logvar : torch.Tensor
464
+ Diagonal log variance of the normal distribution. Shape (batch_size,
465
+ latent_dim)
466
+
467
+ storer : dict
468
+ Dictionary in which to store important variables for vizualisation.
469
+ """
470
+ latent_dim = mean.size(1)
471
+ # batch mean of kl for each latent dimension
472
+ latent_kl = 0.5 * (-1 - logvar + mean.pow(2) + logvar.exp()).mean(dim=0)
473
+ total_kl = latent_kl.sum()
474
+
475
+ if storer is not None:
476
+ storer['kl_loss'].append(total_kl.item())
477
+ for i in range(latent_dim):
478
+ storer['kl_loss_' + str(i)].append(latent_kl[i].item())
479
+
480
+ return total_kl
481
+
482
+
483
+ def _permute_dims(latent_sample):
484
+ """
485
+ Implementation of Algorithm 1 in ref [1]. Randomly permutes the sample from
486
+ q(z) (latent_dist) across the batch for each of the latent dimensions (mean
487
+ and log_var).
488
+
489
+ Parameters
490
+ ----------
491
+ latent_sample: torch.Tensor
492
+ sample from the latent dimension using the reparameterisation trick
493
+ shape : (batch_size, latent_dim).
494
+
495
+ References
496
+ ----------
497
+ [1] Kim, Hyunjik, and Andriy Mnih. "Disentangling by factorising."
498
+ arXiv preprint arXiv:1802.05983 (2018).
499
+
500
+ """
501
+ perm = torch.zeros_like(latent_sample)
502
+ batch_size, dim_z = perm.size()
503
+
504
+ for z in range(dim_z):
505
+ pi = torch.randperm(batch_size).to(latent_sample.device)
506
+ perm[:, z] = latent_sample[pi, z]
507
+
508
+ return perm
509
+
510
+
511
+ def linear_annealing(init, fin, step, annealing_steps):
512
+ """Linear annealing of a parameter."""
513
+ if annealing_steps == 0:
514
+ return fin
515
+ assert fin > init
516
+ delta = fin - init
517
+ annealed = min(init + delta * step / annealing_steps, fin)
518
+ return annealed
519
+
520
+
521
+ # Batch TC specific
522
+ # TO-DO: test if mss is better!
523
+ def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=True):
524
+ batch_size, hidden_dim = latent_sample.shape
525
+
526
+ # calculate log q(z|x)
527
+ log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)
528
+
529
+ # calculate log p(z)
530
+ # mean and log var is 0
531
+ zeros = torch.zeros_like(latent_sample)
532
+ log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)
533
+
534
+ mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)
535
+
536
+ if is_mss:
537
+ # use stratification
538
+ log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)
539
+ mat_log_qz = mat_log_qz + log_iw_mat.view(batch_size, batch_size, 1)
540
+
541
+ log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False)
542
+ log_prod_qzi = torch.logsumexp(mat_log_qz, dim=1, keepdim=False).sum(1)
543
+
544
+ return log_pz, log_qz, log_prod_qzi, log_q_zCx
disvae/models/vae.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module containing the main VAE class.
3
+ """
4
+ import torch
5
+ from torch import nn, optim
6
+ from torch.nn import functional as F
7
+
8
+ from disvae.utils.initialization import weights_init
9
+ from .encoders import get_encoder
10
+ from .decoders import get_decoder
11
+
12
+ MODELS = ["Burgess"]
13
+
14
+
15
+ def init_specific_model(model_type, img_size, latent_dim):
16
+ """Return an instance of a VAE with encoder and decoder from `model_type`."""
17
+ model_type = model_type.lower().capitalize()
18
+ if model_type not in MODELS:
19
+ err = "Unkown model_type={}. Possible values: {}"
20
+ raise ValueError(err.format(model_type, MODELS))
21
+
22
+ encoder = get_encoder(model_type)
23
+ decoder = get_decoder(model_type)
24
+ model = VAE(img_size, encoder, decoder, latent_dim)
25
+ model.model_type = model_type # store to help reloading
26
+ return model
27
+
28
+
29
+ class VAE(nn.Module):
30
+ def __init__(self, img_size, encoder, decoder, latent_dim):
31
+ """
32
+ Class which defines model and forward pass.
33
+
34
+ Parameters
35
+ ----------
36
+ img_size : tuple of ints
37
+ Size of images. E.g. (1, 32, 32) or (3, 64, 64).
38
+ """
39
+ super(VAE, self).__init__()
40
+
41
+ if list(img_size[1:]) not in [[32, 32], [64, 64]]:
42
+ raise RuntimeError("{} sized images not supported. Only (None, 32, 32) and (None, 64, 64) supported. Build your own architecture or reshape images!".format(img_size))
43
+
44
+ self.latent_dim = latent_dim
45
+ self.img_size = img_size
46
+ self.num_pixels = self.img_size[1] * self.img_size[2]
47
+ self.encoder = encoder(img_size, self.latent_dim)
48
+ self.decoder = decoder(img_size, self.latent_dim)
49
+
50
+ self.reset_parameters()
51
+
52
+ def reparameterize(self, mean, logvar):
53
+ """
54
+ Samples from a normal distribution using the reparameterization trick.
55
+
56
+ Parameters
57
+ ----------
58
+ mean : torch.Tensor
59
+ Mean of the normal distribution. Shape (batch_size, latent_dim)
60
+
61
+ logvar : torch.Tensor
62
+ Diagonal log variance of the normal distribution. Shape (batch_size,
63
+ latent_dim)
64
+ """
65
+ if self.training:
66
+ std = torch.exp(0.5 * logvar)
67
+ eps = torch.randn_like(std)
68
+ return mean + std * eps
69
+ else:
70
+ # Reconstruction mode
71
+ return mean
72
+
73
+ def forward(self, x):
74
+ """
75
+ Forward pass of model.
76
+
77
+ Parameters
78
+ ----------
79
+ x : torch.Tensor
80
+ Batch of data. Shape (batch_size, n_chan, height, width)
81
+ """
82
+ latent_dist = self.encoder(x)
83
+ latent_sample = self.reparameterize(*latent_dist)
84
+ reconstruct = self.decoder(latent_sample)
85
+ return reconstruct, latent_dist, latent_sample
86
+
87
+ def reset_parameters(self):
88
+ self.apply(weights_init)
89
+
90
+ def sample_latent(self, x):
91
+ """
92
+ Returns a sample from the latent distribution.
93
+
94
+ Parameters
95
+ ----------
96
+ x : torch.Tensor
97
+ Batch of data. Shape (batch_size, n_chan, height, width)
98
+ """
99
+ latent_dist = self.encoder(x)
100
+ latent_sample = self.reparameterize(*latent_dist)
101
+ return latent_sample
disvae/training.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import imageio
2
+ import logging
3
+ import os
4
+ from collections import defaultdict
5
+ from timeit import default_timer
6
+
7
+ import torch
8
+ from torch.nn import functional as F
9
+ from tqdm import trange
10
+
11
+ from disvae.utils.modelIO import save_model
12
+
13
+ TRAIN_LOSSES_LOGFILE = "train_losses.log"
14
+
15
+
16
+ class Trainer:
17
+ """
18
+ Class to handle training of model.
19
+
20
+ Parameters
21
+ ----------
22
+ model: disvae.vae.VAE
23
+
24
+ optimizer: torch.optim.Optimizer
25
+
26
+ loss_f: disvae.models.BaseLoss
27
+ Loss function.
28
+
29
+ device: torch.device, optional
30
+ Device on which to run the code.
31
+
32
+ logger: logging.Logger, optional
33
+ Logger.
34
+
35
+ save_dir : str, optional
36
+ Directory for saving logs.
37
+
38
+ gif_visualizer : viz.Visualizer, optional
39
+ Gif Visualizer that should return samples at every epochs.
40
+
41
+ is_progress_bar: bool, optional
42
+ Whether to use a progress bar for training.
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ model,
48
+ optimizer,
49
+ loss_f,
50
+ device=torch.device("cpu"),
51
+ logger=logging.getLogger(__name__),
52
+ save_dir="results",
53
+ gif_visualizer=None,
54
+ is_progress_bar=True,
55
+ ):
56
+ self.device = device
57
+ self.model = model.to(self.device)
58
+ self.loss_f = loss_f
59
+ self.optimizer = optimizer
60
+ self.save_dir = save_dir
61
+ self.is_progress_bar = is_progress_bar
62
+ self.logger = logger
63
+ self.losses_logger = LossesLogger(
64
+ os.path.join(self.save_dir, TRAIN_LOSSES_LOGFILE)
65
+ )
66
+ self.gif_visualizer = gif_visualizer
67
+ self.logger.info("Training Device: {}".format(self.device))
68
+
69
+ def __call__(self, data_loader, epochs=10, checkpoint_every=10):
70
+ """
71
+ Trains the model.
72
+
73
+ Parameters
74
+ ----------
75
+ data_loader: torch.utils.data.DataLoader
76
+
77
+ epochs: int, optional
78
+ Number of epochs to train the model for.
79
+
80
+ checkpoint_every: int, optional
81
+ Save a checkpoint of the trained model every n epoch.
82
+ """
83
+ start = default_timer()
84
+ self.model.train()
85
+ for epoch in range(epochs):
86
+ storer = defaultdict(list)
87
+ mean_epoch_loss = self._train_epoch(data_loader, storer, epoch)
88
+ self.logger.info(
89
+ "Epoch: {} Average loss per image: {:.2f}".format(
90
+ epoch + 1, mean_epoch_loss
91
+ )
92
+ )
93
+ self.losses_logger.log(epoch, storer)
94
+
95
+ if self.gif_visualizer is not None:
96
+ self.gif_visualizer()
97
+
98
+ if epoch % checkpoint_every == 0:
99
+ save_model(
100
+ self.model, self.save_dir, filename="model-{}.pt".format(epoch)
101
+ )
102
+
103
+ if self.gif_visualizer is not None:
104
+ self.gif_visualizer.save_reset()
105
+
106
+ self.model.eval()
107
+
108
+ delta_time = (default_timer() - start) / 60
109
+ self.logger.info("Finished training after {:.1f} min.".format(delta_time))
110
+
111
+ def _train_epoch(self, data_loader, storer, epoch):
112
+ """
113
+ Trains the model for one epoch.
114
+
115
+ Parameters
116
+ ----------
117
+ data_loader: torch.utils.data.DataLoader
118
+
119
+ storer: dict
120
+ Dictionary in which to store important variables for vizualisation.
121
+
122
+ epoch: int
123
+ Epoch number
124
+
125
+ Return
126
+ ------
127
+ mean_epoch_loss: float
128
+ Mean loss per image
129
+ """
130
+ epoch_loss = 0.0
131
+ kwargs = dict(
132
+ desc="Epoch {}".format(epoch + 1),
133
+ leave=False,
134
+ disable=not self.is_progress_bar,
135
+ )
136
+ with trange(len(data_loader), **kwargs) as t:
137
+ for _, (data, _) in enumerate(data_loader):
138
+ iter_loss = self._train_iteration(data, storer)
139
+ epoch_loss += iter_loss
140
+
141
+ t.set_postfix(loss=iter_loss)
142
+ t.update()
143
+
144
+ mean_epoch_loss = epoch_loss / len(data_loader)
145
+ return mean_epoch_loss
146
+
147
+ def _train_iteration(self, data, storer):
148
+ """
149
+ Trains the model for one iteration on a batch of data.
150
+
151
+ Parameters
152
+ ----------
153
+ data: torch.Tensor
154
+ A batch of data. Shape : (batch_size, channel, height, width).
155
+
156
+ storer: dict
157
+ Dictionary in which to store important variables for vizualisation.
158
+ """
159
+ batch_size, channel, height, width = data.size()
160
+ data = data.to(self.device)
161
+
162
+ try:
163
+ recon_batch, latent_dist, latent_sample = self.model(data)
164
+ loss = self.loss_f(
165
+ data,
166
+ recon_batch,
167
+ latent_dist,
168
+ self.model.training,
169
+ storer,
170
+ latent_sample=latent_sample,
171
+ )
172
+ self.optimizer.zero_grad()
173
+ loss.backward()
174
+ self.optimizer.step()
175
+
176
+ except ValueError:
177
+ # for losses that use multiple optimizers (e.g. Factor)
178
+ loss = self.loss_f.call_optimize(data, self.model, self.optimizer, storer)
179
+
180
+ return loss.item()
181
+
182
+
183
+ class LossesLogger(object):
184
+ """Class definition for objects to write data to log files in a
185
+ form which is then easy to be plotted.
186
+ """
187
+
188
+ def __init__(self, file_path_name):
189
+ """Create a logger to store information for plotting."""
190
+ if os.path.isfile(file_path_name):
191
+ os.remove(file_path_name)
192
+
193
+ self.logger = logging.getLogger("losses_logger")
194
+ self.logger.setLevel(1) # always store
195
+ file_handler = logging.FileHandler(file_path_name)
196
+ file_handler.setLevel(1)
197
+ self.logger.addHandler(file_handler)
198
+
199
+ header = ",".join(["Epoch", "Loss", "Value"])
200
+ self.logger.debug(header)
201
+
202
+ def log(self, epoch, losses_storer):
203
+ """Write to the log file"""
204
+ for k, v in losses_storer.items():
205
+ log_string = ",".join(str(item) for item in [epoch, k, mean(v)])
206
+ self.logger.debug(log_string)
207
+
208
+
209
+ # HELPERS
210
+ def mean(l):
211
+ """Compute the mean of a list"""
212
+ return sum(l) / len(l)
disvae/utils/__init__.py ADDED
File without changes
disvae/utils/initialization.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ def get_activation_name(activation):
6
+ """Given a string or a `torch.nn.modules.activation` return the name of the activation."""
7
+ if isinstance(activation, str):
8
+ return activation
9
+
10
+ mapper = {nn.LeakyReLU: "leaky_relu", nn.ReLU: "relu", nn.Tanh: "tanh",
11
+ nn.Sigmoid: "sigmoid", nn.Softmax: "sigmoid"}
12
+ for k, v in mapper.items():
13
+ if isinstance(activation, k):
14
+ return k
15
+
16
+ raise ValueError("Unkown given activation type : {}".format(activation))
17
+
18
+
19
+ def get_gain(activation):
20
+ """Given an object of `torch.nn.modules.activation` or an activation name
21
+ return the correct gain."""
22
+ if activation is None:
23
+ return 1
24
+
25
+ activation_name = get_activation_name(activation)
26
+
27
+ param = None if activation_name != "leaky_relu" else activation.negative_slope
28
+ gain = nn.init.calculate_gain(activation_name, param)
29
+
30
+ return gain
31
+
32
+
33
+ def linear_init(layer, activation="relu"):
34
+ """Initialize a linear layer.
35
+ Args:
36
+ layer (nn.Linear): parameters to initialize.
37
+ activation (`torch.nn.modules.activation` or str, optional) activation that
38
+ will be used on the `layer`.
39
+ """
40
+ x = layer.weight
41
+
42
+ if activation is None:
43
+ return nn.init.xavier_uniform_(x)
44
+
45
+ activation_name = get_activation_name(activation)
46
+
47
+ if activation_name == "leaky_relu":
48
+ a = 0 if isinstance(activation, str) else activation.negative_slope
49
+ return nn.init.kaiming_uniform_(x, a=a, nonlinearity='leaky_relu')
50
+ elif activation_name == "relu":
51
+ return nn.init.kaiming_uniform_(x, nonlinearity='relu')
52
+ elif activation_name in ["sigmoid", "tanh"]:
53
+ return nn.init.xavier_uniform_(x, gain=get_gain(activation))
54
+
55
+
56
+ def weights_init(module):
57
+ if isinstance(module, torch.nn.modules.conv._ConvNd):
58
+ # TO-DO: check litterature
59
+ linear_init(module)
60
+ elif isinstance(module, nn.Linear):
61
+ linear_init(module)
disvae/utils/math.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import math
3
+
4
+ from tqdm import trange, tqdm
5
+ import torch
6
+
7
+
8
+ def matrix_log_density_gaussian(x, mu, logvar):
9
+ """Calculates log density of a Gaussian for all combination of bacth pairs of
10
+ `x` and `mu`. I.e. return tensor of shape `(batch_size, batch_size, dim)`
11
+ instead of (batch_size, dim) in the usual log density.
12
+
13
+ Parameters
14
+ ----------
15
+ x: torch.Tensor
16
+ Value at which to compute the density. Shape: (batch_size, dim).
17
+
18
+ mu: torch.Tensor
19
+ Mean. Shape: (batch_size, dim).
20
+
21
+ logvar: torch.Tensor
22
+ Log variance. Shape: (batch_size, dim).
23
+
24
+ batch_size: int
25
+ number of training images in the batch
26
+ """
27
+ batch_size, dim = x.shape
28
+ x = x.view(batch_size, 1, dim)
29
+ mu = mu.view(1, batch_size, dim)
30
+ logvar = logvar.view(1, batch_size, dim)
31
+ return log_density_gaussian(x, mu, logvar)
32
+
33
+
34
+ def log_density_gaussian(x, mu, logvar):
35
+ """Calculates log density of a Gaussian.
36
+
37
+ Parameters
38
+ ----------
39
+ x: torch.Tensor or np.ndarray or float
40
+ Value at which to compute the density.
41
+
42
+ mu: torch.Tensor or np.ndarray or float
43
+ Mean.
44
+
45
+ logvar: torch.Tensor or np.ndarray or float
46
+ Log variance.
47
+ """
48
+ normalization = - 0.5 * (math.log(2 * math.pi) + logvar)
49
+ inv_var = torch.exp(-logvar)
50
+ log_density = normalization - 0.5 * ((x - mu)**2 * inv_var)
51
+ return log_density
52
+
53
+
54
+ def log_importance_weight_matrix(batch_size, dataset_size):
55
+ """
56
+ Calculates a log importance weight matrix
57
+
58
+ Parameters
59
+ ----------
60
+ batch_size: int
61
+ number of training images in the batch
62
+
63
+ dataset_size: int
64
+ number of training images in the dataset
65
+ """
66
+ N = dataset_size
67
+ M = batch_size - 1
68
+ strat_weight = (N - M) / (N * M)
69
+ W = torch.Tensor(batch_size, batch_size).fill_(1 / M)
70
+ W.view(-1)[::M + 1] = 1 / N
71
+ W.view(-1)[1::M + 1] = strat_weight
72
+ W[M - 1, 0] = strat_weight
73
+ return W.log()
disvae/utils/modelIO.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import re
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from disvae.models.vae import init_specific_model
10
+
11
+ MODEL_FILENAME = "model.pt"
12
+ META_FILENAME = "specs.json"
13
+
14
+
15
+ def vae2onnx(vae, p_out: str) -> None:
16
+ if isinstance(vae, str):
17
+ p_out = Path(p_out)
18
+ if not p_out.exists():
19
+ p_out.mkdir()
20
+
21
+ device = next(vae.parameters()).device
22
+ vae.cpu()
23
+
24
+ # Encoder
25
+ vae.encoder.eval()
26
+ dummy_input_im = torch.zeros(tuple(np.concatenate([[1], vae.img_size])))
27
+ torch.onnx.export(vae.encoder, dummy_input_im, p_out / "encoder.onnx", verbose=True)
28
+
29
+ # Decoder
30
+ vae.decoder.eval()
31
+ dummy_input_latent = torch.zeros((1, vae.latent_dim))
32
+ torch.onnx.export(
33
+ vae.decoder, dummy_input_latent, p_out / "decoder.onnx", verbose=True
34
+ )
35
+
36
+ vae.to(device) # restore device
37
+
38
+
39
+ def save_model(model, directory, metadata=None, filename=MODEL_FILENAME):
40
+ """
41
+ Save a model and corresponding metadata.
42
+
43
+ Parameters
44
+ ----------
45
+ model : nn.Module
46
+ Model.
47
+
48
+ directory : str
49
+ Path to the directory where to save the data.
50
+
51
+ metadata : dict
52
+ Metadata to save.
53
+ """
54
+ device = next(model.parameters()).device
55
+ model.cpu()
56
+
57
+ if metadata is None:
58
+ # save the minimum required for loading
59
+ metadata = dict(
60
+ img_size=model.img_size,
61
+ latent_dim=model.latent_dim,
62
+ model_type=model.model_type,
63
+ )
64
+
65
+ save_metadata(metadata, directory)
66
+
67
+ path_to_model = os.path.join(directory, filename)
68
+ torch.save(model.state_dict(), path_to_model)
69
+
70
+ model.to(device) # restore device
71
+
72
+
73
+ def load_metadata(directory, filename=META_FILENAME):
74
+ """Load the metadata of a training directory.
75
+
76
+ Parameters
77
+ ----------
78
+ directory : string
79
+ Path to folder where model is saved. For example './experiments/mnist'.
80
+ """
81
+ path_to_metadata = os.path.join(directory, filename)
82
+
83
+ with open(path_to_metadata) as metadata_file:
84
+ metadata = json.load(metadata_file)
85
+
86
+ return metadata
87
+
88
+
89
+ def save_metadata(metadata, directory, filename=META_FILENAME, **kwargs):
90
+ """Load the metadata of a training directory.
91
+
92
+ Parameters
93
+ ----------
94
+ metadata:
95
+ Object to save
96
+
97
+ directory: string
98
+ Path to folder where to save model. For example './experiments/mnist'.
99
+
100
+ kwargs:
101
+ Additional arguments to `json.dump`
102
+ """
103
+ path_to_metadata = os.path.join(directory, filename)
104
+
105
+ with open(path_to_metadata, "w") as f:
106
+ json.dump(metadata, f, indent=4, sort_keys=True, **kwargs)
107
+
108
+
109
+ def load_model(directory, is_gpu=True, filename=MODEL_FILENAME):
110
+ """Load a trained model.
111
+
112
+ Parameters
113
+ ----------
114
+ directory : string
115
+ Path to folder where model is saved. For example './experiments/mnist'.
116
+
117
+ is_gpu : bool
118
+ Whether to load on GPU is available.
119
+ """
120
+ device = torch.device("cuda" if torch.cuda.is_available() and is_gpu else "cpu")
121
+
122
+ path_to_model = os.path.join(directory, MODEL_FILENAME)
123
+
124
+ metadata = load_metadata(directory)
125
+ img_size = metadata["img_size"]
126
+ latent_dim = metadata["latent_dim"]
127
+ model_type = metadata["model_type"]
128
+
129
+ path_to_model = os.path.join(directory, filename)
130
+ model = _get_model(model_type, img_size, latent_dim, device, path_to_model)
131
+ return model
132
+
133
+
134
+ def load_checkpoints(directory, is_gpu=True):
135
+ """Load all chechpointed models.
136
+
137
+ Parameters
138
+ ----------
139
+ directory : string
140
+ Path to folder where model is saved. For example './experiments/mnist'.
141
+
142
+ is_gpu : bool
143
+ Whether to load on GPU .
144
+ """
145
+ checkpoints = []
146
+ for root, _, filenames in os.walk(directory):
147
+ for filename in filenames:
148
+ results = re.search(r".*?-([0-9].*?).pt", filename)
149
+ if results is not None:
150
+ epoch_idx = int(results.group(1))
151
+ model = load_model(root, is_gpu=is_gpu, filename=filename)
152
+ checkpoints.append((epoch_idx, model))
153
+
154
+ return checkpoints
155
+
156
+
157
+ def _get_model(model_type, img_size, latent_dim, device, path_to_model):
158
+ """Load a single model.
159
+
160
+ Parameters
161
+ ----------
162
+ model_type : str
163
+ The name of the model to load. For example Burgess.
164
+ img_size : tuple
165
+ Tuple of the number of pixels in the image width and height.
166
+ For example (32, 32) or (64, 64).
167
+ latent_dim : int
168
+ The number of latent dimensions in the bottleneck.
169
+
170
+ device : str
171
+ Either 'cuda' or 'cpu'
172
+ path_to_device : str
173
+ Full path to the saved model on the device.
174
+ """
175
+ model = init_specific_model(model_type, img_size, latent_dim).to(device)
176
+ # works with state_dict to make it independent of the file structure
177
+ model.load_state_dict(torch.load(path_to_model), strict=False)
178
+ model.eval()
179
+
180
+ return model
181
+
182
+
183
+ def numpy_serialize(obj):
184
+ if type(obj).__module__ == np.__name__:
185
+ if isinstance(obj, np.ndarray):
186
+ return obj.tolist()
187
+ else:
188
+ return obj.item()
189
+ raise TypeError("Unknown type:", type(obj))
190
+
191
+
192
+ def save_np_arrays(arrays, directory, filename):
193
+ """Save dictionary of arrays in json file."""
194
+ save_metadata(arrays, directory, filename=filename, default=numpy_serialize)
195
+
196
+
197
+ def load_np_arrays(directory, filename):
198
+ """Load dictionary of arrays from json file."""
199
+ arrays = load_metadata(directory, filename=filename)
200
+ return {k: np.array(v) for k, v in arrays.items()}
model/drilling_ds_btcvae/model-0.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6bd1cfaf8d1124dc405fc753e77da693d74220b72e95971660d5a21ef2986081
3
+ size 2016145
model/drilling_ds_btcvae/model-10.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:745b798b9252ebde56455f80caf95f948d1515cdd36fb28ee5d3738e619c1f89
3
+ size 2016687
model/drilling_ds_btcvae/model-20.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a5d11e24ceea210dc9792497940cc84996f785f8f821fc4dd9b51bad6607f8f
3
+ size 2016687
model/drilling_ds_btcvae/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8288bc5d3e7d160054afe19afc6cfa4ea884106fe3377f6e00a6e4c384f4d5a
3
+ size 2014933
model/drilling_ds_btcvae/onnx/decoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7212ac0302bbf4072e3d083b393c24355d1d88c671b9cb7f037e91ea5312d745
3
+ size 1046902
model/drilling_ds_btcvae/onnx/encoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3299e3b39f5181946675ba3da2f17be960ccfb9aa3aaed56af173724720e657
3
+ size 1074967
model/drilling_ds_btcvae/specs.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "img_size": [
3
+ 1,
4
+ 64,
5
+ 64
6
+ ],
7
+ "latent_dim": 10,
8
+ "model_type": "Burgess"
9
+ }
model/drilling_ds_btcvae/train_losses.log ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Epoch,Loss,Value
2
+ 0,recon_loss,2249.039325420673
3
+ 0,loss,2050.8551119290864
4
+ 0,mi_loss,39.089650080754204
5
+ 0,tc_loss,-37.15008985079252
6
+ 0,dw_kl_loss,14.773665877488943
7
+ 0,kl_loss,16.61438555900867
8
+ 0,kl_loss_0,1.3352613651122038
9
+ 0,kl_loss_1,2.678740862183846
10
+ 0,kl_loss_2,1.5543635350007277
11
+ 0,kl_loss_3,2.3668181953521876
12
+ 0,kl_loss_4,1.1609044877382426
13
+ 0,kl_loss_5,1.2264407391731555
14
+ 0,kl_loss_6,1.9435587427937067
15
+ 0,kl_loss_7,1.346766499372629
16
+ 0,kl_loss_8,1.6203752496781259
17
+ 0,kl_loss_9,1.3811558622580309
18
+ 1,recon_loss,2217.2862955729165
19
+ 1,loss,2020.4783630371094
20
+ 1,mi_loss,39.90034135182699
21
+ 1,tc_loss,-37.161014238993324
22
+ 1,dw_kl_loss,12.359588861465454
23
+ 1,kl_loss,15.169561068216959
24
+ 1,kl_loss_0,1.0467030107975006
25
+ 1,kl_loss_1,2.6103671193122864
26
+ 1,kl_loss_2,1.0980331550041835
27
+ 1,kl_loss_3,2.823746303717295
28
+ 1,kl_loss_4,0.8646406829357147
29
+ 1,kl_loss_5,0.937546119093895
30
+ 1,kl_loss_6,1.9802785913149517
31
+ 1,kl_loss_7,1.117189645767212
32
+ 1,kl_loss_8,1.3413666437069576
33
+ 1,kl_loss_9,1.3496898810068767
34
+ 2,recon_loss,2214.081355168269
35
+ 2,loss,2017.7317645733174
36
+ 2,mi_loss,39.97995347243089
37
+ 2,tc_loss,-37.11946810208834
38
+ 2,dw_kl_loss,8.114381826840914
39
+ 2,kl_loss,10.967073367192196
40
+ 2,kl_loss_0,0.6580769190421472
41
+ 2,kl_loss_1,2.3679186747624326
42
+ 2,kl_loss_2,0.6067795891028184
43
+ 2,kl_loss_3,2.3443003251002383
44
+ 2,kl_loss_4,0.5523027869371268
45
+ 2,kl_loss_5,0.630363792181015
46
+ 2,kl_loss_6,1.140236680324261
47
+ 2,kl_loss_7,0.6584483155837426
48
+ 2,kl_loss_8,0.7375089365702409
49
+ 2,kl_loss_9,1.271137265058664
50
+ 3,recon_loss,2212.1290079752603
51
+ 3,loss,2015.6785278320312
52
+ 3,mi_loss,40.029717445373535
53
+ 3,tc_loss,-37.09218883514404
54
+ 3,dw_kl_loss,4.2465185324351
55
+ 3,kl_loss,7.1519254843393965
56
+ 3,kl_loss_0,0.3417855277657509
57
+ 3,kl_loss_1,2.362027664979299
58
+ 3,kl_loss_2,0.29086008047064144
59
+ 3,kl_loss_3,1.2119526664415996
60
+ 3,kl_loss_4,0.275665458291769
61
+ 3,kl_loss_5,0.34522855281829834
62
+ 3,kl_loss_6,0.5345357780655225
63
+ 3,kl_loss_7,0.33813271423180896
64
+ 3,kl_loss_8,0.3666527792811394
65
+ 3,kl_loss_9,1.0850842048724492
66
+ 4,recon_loss,2200.101787860577
67
+ 4,loss,2002.8201528695913
68
+ 4,mi_loss,39.938018798828125
69
+ 4,tc_loss,-37.15127005943885
70
+ 4,dw_kl_loss,1.9809786998308623
71
+ 4,kl_loss,4.73990275309636
72
+ 4,kl_loss_0,0.1199311871941273
73
+ 4,kl_loss_1,2.463518949655386
74
+ 4,kl_loss_2,0.10289078320448215
75
+ 4,kl_loss_3,0.4072182568219992
76
+ 4,kl_loss_4,0.10970824899581763
77
+ 4,kl_loss_5,0.17061764001846313
78
+ 4,kl_loss_6,0.20478195181259742
79
+ 4,kl_loss_7,0.12073192258293812
80
+ 4,kl_loss_8,0.1500035455593696
81
+ 4,kl_loss_9,0.8905002612334031
82
+ 5,recon_loss,2204.4312947591147
83
+ 5,loss,2007.2495625813801
84
+ 5,mi_loss,40.11853218078613
85
+ 5,tc_loss,-37.12883758544922
86
+ 5,dw_kl_loss,0.9564686765273412
87
+ 5,kl_loss,3.8624242544174194
88
+ 5,kl_loss_0,0.04401067225262523
89
+ 5,kl_loss_1,2.557211220264435
90
+ 5,kl_loss_2,0.04017933504655957
91
+ 5,kl_loss_3,0.1053344514220953
92
+ 5,kl_loss_4,0.04450849971423546
93
+ 5,kl_loss_5,0.06218112260103226
94
+ 5,kl_loss_6,0.0690454790989558
95
+ 5,kl_loss_7,0.04647901297236482
96
+ 5,kl_loss_8,0.05923667146513859
97
+ 5,kl_loss_9,0.8342377990484238
98
+ 6,recon_loss,2205.8294771634614
99
+ 6,loss,2008.6192157451924
100
+ 6,mi_loss,40.10805218036358
101
+ 6,tc_loss,-37.120982536902794
102
+ 6,dw_kl_loss,0.6391923519281241
103
+ 6,kl_loss,3.6411730509537916
104
+ 6,kl_loss_0,0.019102195349450294
105
+ 6,kl_loss_1,2.5998667020064135
106
+ 6,kl_loss_2,0.019974250418062393
107
+ 6,kl_loss_3,0.03957459602791529
108
+ 6,kl_loss_4,0.02344304831841817
109
+ 6,kl_loss_5,0.028779690225537006
110
+ 6,kl_loss_6,0.02773726975115446
111
+ 6,kl_loss_7,0.0207805114869888
112
+ 6,kl_loss_8,0.027905162853690293
113
+ 6,kl_loss_9,0.8340096381994394
114
+ 7,recon_loss,2207.6748657226562
115
+ 7,loss,2010.4278869628906
116
+ 7,mi_loss,40.222665786743164
117
+ 7,tc_loss,-37.143791834513344
118
+ 7,dw_kl_loss,0.5344960068662962
119
+ 7,kl_loss,3.6037667989730835
120
+ 7,kl_loss_0,0.012062065963012477
121
+ 7,kl_loss_1,2.7152191599210105
122
+ 7,kl_loss_2,0.011340752360410988
123
+ 7,kl_loss_3,0.020317877798030775
124
+ 7,kl_loss_4,0.010189585232486328
125
+ 7,kl_loss_5,0.01701245631556958
126
+ 7,kl_loss_6,0.013870434602722526
127
+ 7,kl_loss_7,0.012214566503340999
128
+ 7,kl_loss_8,0.014409278597061833
129
+ 7,kl_loss_9,0.7771305988232294
130
+ 8,recon_loss,2204.4822904146636
131
+ 8,loss,2007.0154465895432
132
+ 8,mi_loss,40.23030794583834
133
+ 8,tc_loss,-37.17468848595252
134
+ 8,dw_kl_loss,0.4160432998950665
135
+ 8,kl_loss,3.491810395167424
136
+ 8,kl_loss_0,0.007323983161208721
137
+ 8,kl_loss_1,2.6830281294309177
138
+ 8,kl_loss_2,0.006681273798816479
139
+ 8,kl_loss_3,0.010057092895014929
140
+ 8,kl_loss_4,0.007608905219687865
141
+ 8,kl_loss_5,0.008022861555218697
142
+ 8,kl_loss_6,0.0056697913588812715
143
+ 8,kl_loss_7,0.008376527505998429
144
+ 8,kl_loss_8,0.00837366422638297
145
+ 8,kl_loss_9,0.7466680911871103
146
+ 9,recon_loss,2206.011433919271
147
+ 9,loss,2008.6884663899739
148
+ 9,mi_loss,40.23820718129476
149
+ 9,tc_loss,-37.16220887502035
150
+ 9,dw_kl_loss,0.46936574081579846
151
+ 9,kl_loss,3.4854011138280234
152
+ 9,kl_loss_0,0.004957272865188618
153
+ 9,kl_loss_1,2.687060753504435
154
+ 9,kl_loss_2,0.004417084312687318
155
+ 9,kl_loss_3,0.007213231680604319
156
+ 9,kl_loss_4,0.006509214756079018
157
+ 9,kl_loss_5,0.00647372849440823
158
+ 9,kl_loss_6,0.005934650117220978
159
+ 9,kl_loss_7,0.006229850230738521
160
+ 9,kl_loss_8,0.006342786713503301
161
+ 9,kl_loss_9,0.7502625584602356
162
+ 10,recon_loss,2196.9954740084136
163
+ 10,loss,1999.4449462890625
164
+ 10,mi_loss,40.27904305091271
165
+ 10,tc_loss,-37.19911399254432
166
+ 10,dw_kl_loss,0.37395678689846623
167
+ 10,kl_loss,3.4625553901378927
168
+ 10,kl_loss_0,0.004101881040976598
169
+ 10,kl_loss_1,2.713090548148522
170
+ 10,kl_loss_2,0.0036793027359705707
171
+ 10,kl_loss_3,0.004992271200395548
172
+ 10,kl_loss_4,0.004465263444357193
173
+ 10,kl_loss_5,0.004884598884158409
174
+ 10,kl_loss_6,0.0037293383636726784
175
+ 10,kl_loss_7,0.004201724802931914
176
+ 10,kl_loss_8,0.00413606484205677
177
+ 10,kl_loss_9,0.7152743293688848
178
+ 11,recon_loss,2205.8016967773438
179
+ 11,loss,2008.5977172851562
180
+ 11,mi_loss,40.24677817026774
181
+ 11,tc_loss,-37.13152503967285
182
+ 11,dw_kl_loss,0.2668171264231205
183
+ 11,kl_loss,3.3814604083697
184
+ 11,kl_loss_0,0.002733308278645078
185
+ 11,kl_loss_1,2.5890026092529297
186
+ 11,kl_loss_2,0.003619925118982792
187
+ 11,kl_loss_3,0.0036296656665702662
188
+ 11,kl_loss_4,0.0030502101484065256
189
+ 11,kl_loss_5,0.003691234936316808
190
+ 11,kl_loss_6,0.0037010414137815437
191
+ 11,kl_loss_7,0.0035850899294018745
192
+ 11,kl_loss_8,0.0033201781722406545
193
+ 11,kl_loss_9,0.7651271522045135
194
+ 12,recon_loss,2204.1350911458335
195
+ 12,loss,2006.7399800618489
196
+ 12,mi_loss,40.24555047353109
197
+ 12,tc_loss,-37.16728210449219
198
+ 12,dw_kl_loss,0.29467178384462994
199
+ 12,kl_loss,3.4273709058761597
200
+ 12,kl_loss_0,0.00253635470289737
201
+ 12,kl_loss_1,2.64704296986262
202
+ 12,kl_loss_2,0.0027701943569506207
203
+ 12,kl_loss_3,0.0031717634992673993
204
+ 12,kl_loss_4,0.002853672835044563
205
+ 12,kl_loss_5,0.0033753060658151903
206
+ 12,kl_loss_6,0.002695254709882041
207
+ 12,kl_loss_7,0.0025047556264325976
208
+ 12,kl_loss_8,0.0035389424689734974
209
+ 12,kl_loss_9,0.7568817337354025
210
+ 13,recon_loss,2205.216759314904
211
+ 13,loss,2007.8895357572114
212
+ 13,mi_loss,40.38474479088416
213
+ 13,tc_loss,-37.186089735764725
214
+ 13,dw_kl_loss,0.33352065086364746
215
+ 13,kl_loss,3.4887316043560324
216
+ 13,kl_loss_0,0.0025462846343333903
217
+ 13,kl_loss_1,2.7032171212709866
218
+ 13,kl_loss_2,0.0023540596549327555
219
+ 13,kl_loss_3,0.0035233735106885433
220
+ 13,kl_loss_4,0.0030828057430111445
221
+ 13,kl_loss_5,0.002101871149184612
222
+ 13,kl_loss_6,0.0022684050580629935
223
+ 13,kl_loss_7,0.0018005223514942022
224
+ 13,kl_loss_8,0.002609655655060823
225
+ 13,kl_loss_9,0.7652274553592389
226
+ 14,recon_loss,2203.2179158528647
227
+ 14,loss,2006.1795145670574
228
+ 14,mi_loss,40.190184911092125
229
+ 14,tc_loss,-37.12630271911621
230
+ 14,dw_kl_loss,0.42094290008147556
231
+ 14,kl_loss,3.521983802318573
232
+ 14,kl_loss_0,0.0017787318599099915
233
+ 14,kl_loss_1,2.733785887559255
234
+ 14,kl_loss_2,0.0020149560490002236
235
+ 14,kl_loss_3,0.0026113978043819466
236
+ 14,kl_loss_4,0.0019087268350025017
237
+ 14,kl_loss_5,0.0026779077791919312
238
+ 14,kl_loss_6,0.002250582134972016
239
+ 14,kl_loss_7,0.0023648399704446397
240
+ 14,kl_loss_8,0.002240404215020438
241
+ 14,kl_loss_9,0.770350361863772
242
+ 15,recon_loss,2211.151329627404
243
+ 15,loss,2014.0557016225962
244
+ 15,mi_loss,40.30680495042067
245
+ 15,tc_loss,-37.14148154625526
246
+ 15,dw_kl_loss,0.3123072420175259
247
+ 15,kl_loss,3.4762388192690334
248
+ 15,kl_loss_0,0.002717837368926177
249
+ 15,kl_loss_1,2.6843277307657094
250
+ 15,kl_loss_2,0.002006038987579254
251
+ 15,kl_loss_3,0.0024746029207912777
252
+ 15,kl_loss_4,0.0028183436594330347
253
+ 15,kl_loss_5,0.0020857790413384256
254
+ 15,kl_loss_6,0.0016617546431147135
255
+ 15,kl_loss_7,0.0018722561474602956
256
+ 15,kl_loss_8,0.0026360321789979935
257
+ 15,kl_loss_9,0.7736383722378657
258
+ 16,recon_loss,2205.65625
259
+ 16,loss,2008.7182006835938
260
+ 16,mi_loss,40.34294160207113
261
+ 16,tc_loss,-37.122435569763184
262
+ 16,dw_kl_loss,0.30258239308993023
263
+ 16,kl_loss,3.5443766514460244
264
+ 16,kl_loss_0,0.0019371571640173595
265
+ 16,kl_loss_1,2.7254956364631653
266
+ 16,kl_loss_2,0.0018169079363966982
267
+ 16,kl_loss_3,0.0018087425269186497
268
+ 16,kl_loss_4,0.0017845627153292298
269
+ 16,kl_loss_5,0.002286287684304019
270
+ 16,kl_loss_6,0.002490955676573018
271
+ 16,kl_loss_7,0.0018719588794435065
272
+ 16,kl_loss_8,0.0028695606160908937
273
+ 16,kl_loss_9,0.802014946937561
274
+ 17,recon_loss,2198.288348858173
275
+ 17,loss,2001.154766376202
276
+ 17,mi_loss,40.29867348304162
277
+ 17,tc_loss,-37.145969977745644
278
+ 17,dw_kl_loss,0.3019469002118477
279
+ 17,kl_loss,3.4419017755068264
280
+ 17,kl_loss_0,0.0018457891777730905
281
+ 17,kl_loss_1,2.6779668514545145
282
+ 17,kl_loss_2,0.0019515727718289082
283
+ 17,kl_loss_3,0.002290374062095697
284
+ 17,kl_loss_4,0.0018018389550539164
285
+ 17,kl_loss_5,0.0016674650474809683
286
+ 17,kl_loss_6,0.001166574119661863
287
+ 17,kl_loss_7,0.0020146659002281153
288
+ 17,kl_loss_8,0.001355440105096652
289
+ 17,kl_loss_9,0.7498412315662091
290
+ 18,recon_loss,2198.6248575846353
291
+ 18,loss,2001.4028625488281
292
+ 18,mi_loss,40.322779973347984
293
+ 18,tc_loss,-37.165176709493004
294
+ 18,dw_kl_loss,0.31236464778582257
295
+ 18,kl_loss,3.5249797304471335
296
+ 18,kl_loss_0,0.0014744151073197524
297
+ 18,kl_loss_1,2.756214439868927
298
+ 18,kl_loss_2,0.0018141739613686998
299
+ 18,kl_loss_3,0.0015735261452694733
300
+ 18,kl_loss_4,0.0010972448314229648
301
+ 18,kl_loss_5,0.0018187712412327528
302
+ 18,kl_loss_6,0.0031091769536336264
303
+ 18,kl_loss_7,0.0022688430811588964
304
+ 18,kl_loss_8,0.0020122514882435403
305
+ 18,kl_loss_9,0.7535968770583471
306
+ 19,recon_loss,2212.218937800481
307
+ 19,loss,2015.4007662259614
308
+ 19,mi_loss,40.31675866933969
309
+ 19,tc_loss,-37.111326951246994
310
+ 19,dw_kl_loss,0.37757790088653564
311
+ 19,kl_loss,3.59769160930927
312
+ 19,kl_loss_0,0.0016564357882508864
313
+ 19,kl_loss_1,2.8344204976008487
314
+ 19,kl_loss_2,0.003070324229506346
315
+ 19,kl_loss_3,0.0013594922179786058
316
+ 19,kl_loss_4,0.001366000407590316
317
+ 19,kl_loss_5,0.002101327758282423
318
+ 19,kl_loss_6,0.0017111848036830241
319
+ 19,kl_loss_7,0.0018114753497334628
320
+ 19,kl_loss_8,0.002678456107297769
321
+ 19,kl_loss_9,0.7475164211713351
322
+ 20,recon_loss,2202.2599487304688
323
+ 20,loss,2004.8538614908855
324
+ 20,mi_loss,40.45914777119955
325
+ 20,tc_loss,-37.21652317047119
326
+ 20,dw_kl_loss,0.3205146963397662
327
+ 20,kl_loss,3.6001903414726257
328
+ 20,kl_loss_0,0.001861263260555764
329
+ 20,kl_loss_1,2.8560718297958374
330
+ 20,kl_loss_2,0.0016067677255099018
331
+ 20,kl_loss_3,0.0021211198375870786
332
+ 20,kl_loss_4,0.0015188050456345081
333
+ 20,kl_loss_5,0.0015412049445634086
334
+ 20,kl_loss_6,0.0016792030461753409
335
+ 20,kl_loss_7,0.0019624157963941493
336
+ 20,kl_loss_8,0.0014063233975321054
337
+ 20,kl_loss_9,0.7304214636484782
338
+ 21,recon_loss,2207.904334435096
339
+ 21,loss,2010.744882436899
340
+ 21,mi_loss,40.43276625413161
341
+ 21,tc_loss,-37.171888204721306
342
+ 21,dw_kl_loss,0.3078639725079903
343
+ 21,kl_loss,3.561537137398353
344
+ 21,kl_loss_0,0.0012945552858022542
345
+ 21,kl_loss_1,2.7875716319450965
346
+ 21,kl_loss_2,0.001239637235322824
347
+ 21,kl_loss_3,0.0011019167275382923
348
+ 21,kl_loss_4,0.0019378469755443244
349
+ 21,kl_loss_5,0.0017049070447683334
350
+ 21,kl_loss_6,0.0016337902141878237
351
+ 21,kl_loss_7,0.0016330647855423964
352
+ 21,kl_loss_8,0.001772170993857659
353
+ 21,kl_loss_9,0.7616475820541382
354
+ 22,recon_loss,2210.4578247070312
355
+ 22,loss,2013.5369771321614
356
+ 22,mi_loss,40.30475012461344
357
+ 22,tc_loss,-37.11816851298014
358
+ 22,dw_kl_loss,0.3306787001589934
359
+ 22,kl_loss,3.547394037246704
360
+ 22,kl_loss_0,0.0017359372383604448
361
+ 22,kl_loss_1,2.824240207672119
362
+ 22,kl_loss_2,0.0011897988927861054
363
+ 22,kl_loss_3,0.0013395212202643354
364
+ 22,kl_loss_4,0.001481865649111569
365
+ 22,kl_loss_5,0.001497445589241882
366
+ 22,kl_loss_6,0.001183633382121722
367
+ 22,kl_loss_7,0.0013544160562256973
368
+ 22,kl_loss_8,0.001382134932403763
369
+ 22,kl_loss_9,0.7119891196489334
370
+ 23,recon_loss,2210.9183443509614
371
+ 23,loss,2013.5601994441106
372
+ 23,mi_loss,40.38095914400541
373
+ 23,tc_loss,-37.20592850905199
374
+ 23,dw_kl_loss,0.3788297451459445
375
+ 23,kl_loss,3.5697782589839053
376
+ 23,kl_loss_0,0.001240100496663497
377
+ 23,kl_loss_1,2.8392400558178243
378
+ 23,kl_loss_2,0.0012557490442234736
379
+ 23,kl_loss_3,0.0016619229259399267
380
+ 23,kl_loss_4,0.0014921167435554357
381
+ 23,kl_loss_5,0.0014155316524780714
382
+ 23,kl_loss_6,0.0015437594041801416
383
+ 23,kl_loss_7,0.0013436333706172614
384
+ 23,kl_loss_8,0.0015257975946252162
385
+ 23,kl_loss_9,0.7190595681850727
386
+ 24,recon_loss,2205.2496948242188
387
+ 24,loss,2008.2829081217449
388
+ 24,mi_loss,40.33272743225098
389
+ 24,tc_loss,-37.13726011912028
390
+ 24,dw_kl_loss,0.3789539597928524
391
+ 24,kl_loss,3.5514416495958963
392
+ 24,kl_loss_0,0.0012495704383278887
393
+ 24,kl_loss_1,2.788417716821035
394
+ 24,kl_loss_2,0.0018070755759254098
395
+ 24,kl_loss_3,0.0012991854067270954
396
+ 24,kl_loss_4,0.0010839081757391493
397
+ 24,kl_loss_5,0.0019774677930399776
398
+ 24,kl_loss_6,0.0014922072102005284
399
+ 24,kl_loss_7,0.0012786747732510169
400
+ 24,kl_loss_8,0.001147995547701915
401
+ 24,kl_loss_9,0.7516879439353943
402
+ 25,recon_loss,2208.270467122396
403
+ 25,loss,2010.8847351074219
404
+ 25,mi_loss,40.34546057383219
405
+ 25,tc_loss,-37.20619010925293
406
+ 25,dw_kl_loss,0.38843652978539467
407
+ 25,kl_loss,3.6115296880404153
408
+ 25,kl_loss_0,0.0015405131659160058
409
+ 25,kl_loss_1,2.9013171394666037
410
+ 25,kl_loss_2,0.0011097188883771498
411
+ 25,kl_loss_3,0.0018861019440616171
412
+ 25,kl_loss_4,0.0015740771389876802
413
+ 25,kl_loss_5,0.0015496667086457212
414
+ 25,kl_loss_6,0.001531413911531369
415
+ 25,kl_loss_7,0.0019122938392683864
416
+ 25,kl_loss_8,0.0011389451489473383
417
+ 25,kl_loss_9,0.6979698638121287
418
+ 26,recon_loss,2203.443903996394
419
+ 26,loss,2006.6229717548076
420
+ 26,mi_loss,40.34916833730844
421
+ 26,tc_loss,-37.11746098445012
422
+ 26,dw_kl_loss,0.3816586262904681
423
+ 26,kl_loss,3.5773275265326867
424
+ 26,kl_loss_0,0.001173520998026316
425
+ 26,kl_loss_1,2.8474206374241757
426
+ 26,kl_loss_2,0.0016174610489262985
427
+ 26,kl_loss_3,0.0015147336257191806
428
+ 26,kl_loss_4,0.0010703427788729852
429
+ 26,kl_loss_5,0.0011983873107685493
430
+ 26,kl_loss_6,0.00126259820535779
431
+ 26,kl_loss_7,0.001042055252652902
432
+ 26,kl_loss_8,0.0008434637879522947
433
+ 26,kl_loss_9,0.7201842207175034
434
+ 27,recon_loss,2206.4358520507812
435
+ 27,loss,2009.3228352864583
436
+ 27,mi_loss,40.37657833099365
437
+ 27,tc_loss,-37.16412862141927
438
+ 27,dw_kl_loss,0.360832375784715
439
+ 27,kl_loss,3.5843074520428977
440
+ 27,kl_loss_0,0.0011982040014117956
441
+ 27,kl_loss_1,2.846610724925995
442
+ 27,kl_loss_2,0.0011326035019010305
443
+ 27,kl_loss_3,0.001651634695008397
444
+ 27,kl_loss_4,0.0012949489755555987
445
+ 27,kl_loss_5,0.0014977601046363513
446
+ 27,kl_loss_6,0.001088831612529854
447
+ 27,kl_loss_7,0.0012824649068837364
448
+ 27,kl_loss_8,0.001276915582517783
449
+ 27,kl_loss_9,0.7272733698288599
450
+ 28,recon_loss,2204.0506497896636
451
+ 28,loss,2007.2717942457932
452
+ 28,mi_loss,40.36799122737004
453
+ 28,tc_loss,-37.11823654174805
454
+ 28,dw_kl_loss,0.4098718739472903
455
+ 28,kl_loss,3.5976003866929274
456
+ 28,kl_loss_0,0.0012191503595274228
457
+ 28,kl_loss_1,2.8608819337991567
458
+ 28,kl_loss_2,0.0008682911642468893
459
+ 28,kl_loss_3,0.0012883888557553291
460
+ 28,kl_loss_4,0.001379269605072645
461
+ 28,kl_loss_5,0.0015075570688797878
462
+ 28,kl_loss_6,0.0012084581316090547
463
+ 28,kl_loss_7,0.0009197160028494322
464
+ 28,kl_loss_8,0.0013607487154121583
465
+ 28,kl_loss_9,0.726966903759883
466
+ 29,recon_loss,2199.2364501953125
467
+ 29,loss,2002.3850402832031
468
+ 29,mi_loss,40.270185470581055
469
+ 29,tc_loss,-37.105968157450356
470
+ 29,dw_kl_loss,0.3566128934423129
471
+ 29,kl_loss,3.578150769074758
472
+ 29,kl_loss_0,0.0011238552397117019
473
+ 29,kl_loss_1,2.8515902956326804
474
+ 29,kl_loss_2,0.0014782638754695654
475
+ 29,kl_loss_3,0.0011515693040564656
476
+ 29,kl_loss_4,0.0014677915023639798
477
+ 29,kl_loss_5,0.0013430318019042413
478
+ 29,kl_loss_6,0.0013744460884481668
479
+ 29,kl_loss_7,0.0012495889095589519
480
+ 29,kl_loss_8,0.00138730447118481
481
+ 29,kl_loss_9,0.715984657406807
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ pandas
3
+ #plotly
4
+ # scipy
5
+ tqdm
6
+ # pillow
7
+
8
+ # ipywidgets
9
+ # jupyterlab
10
+
11
+ torch
transforms.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Alle transforms sind grundsätzlich auf batches bezogen!
3
+ Vae transforms sind invertierbar
4
+ """
5
+ import pickle
6
+ from dataclasses import dataclass
7
+ from functools import partial, reduce, wraps
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ # Allgemeine Funktionen -------------------------------------------------------------
13
+ # Transformations in Pytorch sind am einfachsten.
14
+
15
+
16
+ def load(p):
17
+ with open(p, "rb") as stream:
18
+ return pickle.load(stream)
19
+
20
+
21
+ def save(obj, p):
22
+ with open(p, "wb") as stream:
23
+ pickle.dump(obj, stream)
24
+
25
+
26
+ def sequential_function(*functions):
27
+ return lambda x: reduce(lambda res, func: func(res), functions, x)
28
+
29
+
30
+ def np_sample(func):
31
+ rtn = sequential_function(
32
+ lambda x: torch.from_numpy(x).float(),
33
+ lambda x: torch.unsqueeze(x, 0),
34
+ func,
35
+ lambda x: x[0].numpy(),
36
+ )
37
+ return rtn
38
+
39
+
40
+ # Inverseabvle
41
+ class SequentialInversable(torch.nn.Sequential):
42
+ def __init__(self, *functions):
43
+ super().__init__(*functions)
44
+
45
+ self.inv_funcs = [f.inv for f in functions]
46
+ self.inv_funcs.reverse()
47
+
48
+ # def forward(self, x):
49
+ # return sequential_function(*self.functions)(x)
50
+
51
+ def inv(self, x):
52
+ return sequential_function(*self.inv_funcs)(x)
53
+
54
+
55
+ class LatentSelector(torch.nn.Module):
56
+ """Verarbeitet Tensoren und numpy arrays"""
57
+
58
+ def __init__(self, ldim: int, selectdim: int):
59
+ super().__init__()
60
+ self.ldim = ldim
61
+ self.selectdim = selectdim
62
+
63
+ def forward(self, x: torch.Tensor):
64
+ return x[:, : self.selectdim]
65
+
66
+ def inv(self, x: torch.Tensor):
67
+ rtn = torch.cat(
68
+ [x, torch.zeros((x.shape[0], self.ldim - x.shape[1]), device=x.device)],
69
+ dim=1,
70
+ )
71
+ return rtn
72
+
73
+
74
+ class MinMaxScaler(torch.nn.Module):
75
+ #! Bei mehreren Signalen vorsicht mit dem Broadcasting.
76
+ def __init__(
77
+ self,
78
+ _min: torch.Tensor,
79
+ _max: torch.Tensor,
80
+ min_norm: float = 0.0,
81
+ max_norm: float = 1.0,
82
+ ):
83
+ super().__init__()
84
+ self._min = _min
85
+ self._max = _max
86
+ self.min_norm = min_norm
87
+ self.max_norm = max_norm
88
+
89
+ def forward(self, ts):
90
+ """None, no_signals"""
91
+ std = (ts - self._min) / (self._max - self._min)
92
+ rtn = std * (self.max_norm - self.min_norm) + self.min_norm
93
+ return rtn
94
+
95
+ def inv(self, ts):
96
+ std = (ts - self.min_norm) / (self.max_norm - self.min_norm)
97
+ rtn = std * (self._max - self._min) + self._min
98
+ return rtn
99
+
100
+ @classmethod
101
+ def from_array(cls, arr: torch.Tensor):
102
+ _min = torch.min(arr, axis=0).values
103
+ _max = torch.max(arr, axis=0).values
104
+
105
+ return cls(_min, _max)
106
+
107
+
108
+ class LatentSorter(torch.nn.Module):
109
+ def __init__(self, kl_dict: dict):
110
+ super().__init__()
111
+ self.kl_dict = kl_dict
112
+
113
+ def forward(self, latent):
114
+ """
115
+ unsorted -> sorted
116
+ latent: (None, latent_dim)
117
+ """
118
+ return latent[:, list(self.kl_dict.keys())]
119
+
120
+ def inv(self, latent):
121
+ keys = np.array(list(self.kl_dict.keys()))
122
+ return latent[:, torch.from_numpy(keys.argsort())]
123
+
124
+ @property
125
+ def names(self):
126
+ rtn = ["{} KL{:.2f}".format(k, v) for k, v in self.kl_dict.items()]
127
+ return rtn
128
+
129
+
130
+ def apply_along_axis(function, x, axis: int = 0):
131
+ return torch.stack([function(x_i) for x_i in torch.unbind(x, dim=axis)], dim=axis)
132
+
133
+
134
+ # Eingangsshapes bleiben wie sie sind!
135
+ class SumField(torch.nn.Module):
136
+ """
137
+ time series: [idx, time_step, signal]
138
+ image: [idx, signal, time_step, time_step]
139
+ """
140
+
141
+ def forward(self, ts: torch.Tensor):
142
+ """ts2img"""
143
+
144
+ samples = ts.shape[0]
145
+ time = ts.shape[1]
146
+ channels = ts.shape[2]
147
+
148
+ ts = torch.swapaxes(ts, 1, 2) # Zeitachse ans Ende
149
+ ts = torch.reshape(
150
+ ts, (samples * channels, time)
151
+ ) # Zusammenfassen von Channel + idx
152
+ #! TODO: Schleife besser lösen
153
+ rtn = apply_along_axis(self._mtf_forward, ts, 0)
154
+ rtn = torch.reshape(rtn, (samples, channels, time, time))
155
+
156
+ return rtn
157
+
158
+ def inv(self, img: torch.Tensor):
159
+ """img2ts"""
160
+ rtn = torch.diagonal(img, dim1=2, dim2=3)
161
+ rtn = torch.swapaxes(rtn, 1, 2) # Channel und Zeitachse tauschen
162
+
163
+ return rtn
164
+
165
+ @staticmethod
166
+ def _mtf_forward(ts):
167
+ """For one dimensional time series ts"""
168
+ return torch.add(*torch.meshgrid(ts, ts, indexing="ij")) / 2