Commit c53ddec · Jonas Becker committed · Parent(s): e8305d9

1st try
Files changed:
- .gitignore +2 -0
- app.bat +2 -0
- app.py +38 -0
- disvae/__init__.py +6 -0
- disvae/evaluate.py +317 -0
- disvae/main.py +145 -0
- disvae/models/__init__.py +0 -0
- disvae/models/decoders.py +84 -0
- disvae/models/discriminator.py +73 -0
- disvae/models/encoders.py +89 -0
- disvae/models/losses.py +544 -0
- disvae/models/vae.py +101 -0
- disvae/training.py +212 -0
- disvae/utils/__init__.py +0 -0
- disvae/utils/initialization.py +61 -0
- disvae/utils/math.py +73 -0
- disvae/utils/modelIO.py +200 -0
- requirements.txt +11 -0
- transforms.py +168 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
+venv
+__pycache__
app.bat
ADDED
@@ -0,0 +1,2 @@
+call venv\scripts\activate.bat
+call streamlit run app.py
app.py
ADDED
@@ -0,0 +1,38 @@
+import numpy as np
+import streamlit as st
+import torch
+
+import disvae
+import transforms as trans
+
+P_MODEL = "models/btcvae_celeba"
+
+# Decode function ---------------------------------------------------
+sorter = trans.LatentSorter(disvae.get_kl_dict(P_MODEL))
+vae = disvae.load_model(P_MODEL)
+scaler = trans.MinMaxScaler(_min=torch.tensor([1.3]), _max=torch.tensor([4.0]), min_norm=0.3, max_norm=0.6)
+imaging = trans.SumField()
+
+_dec = trans.sequential_function(
+    sorter.inv,
+    vae.decoder
+)
+
+def decode(latent):
+    with torch.no_grad():
+        return trans.np_sample(_dec)(latent)
+
+# GUI ---------------------------------------------------------------
+
+latent_vector = np.array([st.slider(f"L{l}", min_value=-3.0, max_value=3.0, value=0.0) for l in range(3)])
+latent_vector = np.concatenate([latent_vector, np.zeros(7)], axis=0)
+
+value = decode(latent_vector)
+
+value = np.swapaxes(np.swapaxes(value, 0, 2), 0, 1)  # * 255
+
+# st.write(value)
+st.image(value, use_column_width="always")
+
+# x = st.slider("Select a value")
+# st.write(x, "squared is", x * x)
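
Note: transforms.py is added in this commit (+168 lines) but its diff is not shown on this page, so the actual bodies of trans.sequential_function, trans.LatentSorter, and trans.np_sample are not visible here. Purely as a hypothetical sketch of the composition pattern app.py relies on (sorter.inv applied first, then vae.decoder):

# Hypothetical sketch only; the real transforms.py may differ.
from functools import reduce

def sequential_function(*funcs):
    """Compose callables so the first one given is applied first."""
    return lambda x: reduce(lambda acc, f: f(acc), funcs, x)

# Example: (2 + 1) * 10 == 30
composed = sequential_function(lambda x: x + 1, lambda x: x * 10)
assert composed(2) == 30
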
disvae/__init__.py
ADDED
@@ -0,0 +1,6 @@
+from disvae.evaluate import Evaluator
+from disvae.main import get_kl_dict
+from disvae.training import Trainer
+from disvae.utils.modelIO import load_model, save_model
+
+# from disvae.models.vae import init_specific_model  # necessary?
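
A minimal usage sketch of the API re-exported here, assuming a trained model directory such as "models/btcvae_celeba" (the path app.py uses; modelIO.py is part of this commit but its diff is not shown):

import torch
import disvae

vae = disvae.load_model("models/btcvae_celeba")       # restore weights + metadata
kl_dict = disvae.get_kl_dict("models/btcvae_celeba")  # latent index -> final KL loss, sorted
with torch.no_grad():
    img = vae.decoder(torch.zeros(1, vae.latent_dim)) # decode the origin of latent space
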
disvae/evaluate.py
ADDED
@@ -0,0 +1,317 @@
+import os
+import logging
+import math
+from functools import reduce
+from collections import defaultdict
+import json
+from timeit import default_timer
+
+from tqdm import trange, tqdm
+import numpy as np
+import torch
+
+from disvae.models.losses import get_loss_f
+from disvae.utils.math import log_density_gaussian
+from disvae.utils.modelIO import save_metadata
+
+TEST_LOSSES_FILE = "test_losses.log"
+METRICS_FILENAME = "metrics.log"
+METRIC_HELPERS_FILE = "metric_helpers.pth"
+
+
+class Evaluator:
+    """
+    Class to handle evaluation of a model.
+
+    Parameters
+    ----------
+    model: disvae.vae.VAE
+
+    loss_f: disvae.models.BaseLoss
+        Loss function.
+
+    device: torch.device, optional
+        Device on which to run the code.
+
+    logger: logging.Logger, optional
+        Logger.
+
+    save_dir : str, optional
+        Directory for saving logs.
+
+    is_progress_bar: bool, optional
+        Whether to use a progress bar for evaluation.
+    """
+
+    def __init__(self, model, loss_f,
+                 device=torch.device("cpu"),
+                 logger=logging.getLogger(__name__),
+                 save_dir="results",
+                 is_progress_bar=True):
+
+        self.device = device
+        self.loss_f = loss_f
+        self.model = model.to(self.device)
+        self.logger = logger
+        self.save_dir = save_dir
+        self.is_progress_bar = is_progress_bar
+        self.logger.info("Testing Device: {}".format(self.device))
+
+    def __call__(self, data_loader, is_metrics=False, is_losses=True):
+        """Compute all test losses.
+
+        Parameters
+        ----------
+        data_loader: torch.utils.data.DataLoader
+
+        is_metrics: bool, optional
+            Whether to compute and store the disentangling metrics.
+
+        is_losses: bool, optional
+            Whether to compute and store the test losses.
+        """
+        start = default_timer()
+        is_still_training = self.model.training
+        self.model.eval()
+
+        metrics, losses = None, None
+        if is_metrics:
+            self.logger.info('Computing metrics...')
+            metrics = self.compute_metrics(data_loader)
+            self.logger.info('Metrics: {}'.format(metrics))
+            save_metadata(metrics, self.save_dir, filename=METRICS_FILENAME)
+
+        if is_losses:
+            self.logger.info('Computing losses...')
+            losses = self.compute_losses(data_loader)
+            self.logger.info('Losses: {}'.format(losses))
+            save_metadata(losses, self.save_dir, filename=TEST_LOSSES_FILE)
+
+        if is_still_training:
+            self.model.train()
+
+        self.logger.info('Finished evaluating after {:.1f} min.'.format((default_timer() - start) / 60))
+
+        return metrics, losses
+
+    def compute_losses(self, dataloader):
+        """Compute all test losses.
+
+        Parameters
+        ----------
+        dataloader: torch.utils.data.DataLoader
+        """
+        storer = defaultdict(list)
+        for data, _ in tqdm(dataloader, leave=False, disable=not self.is_progress_bar):
+            data = data.to(self.device)
+
+            try:
+                recon_batch, latent_dist, latent_sample = self.model(data)
+                _ = self.loss_f(data, recon_batch, latent_dist, self.model.training,
+                                storer, latent_sample=latent_sample)
+            except ValueError:
+                # for losses that use multiple optimizers (e.g. Factor)
+                _ = self.loss_f.call_optimize(data, self.model, None, storer)
+
+        losses = {k: sum(v) / len(dataloader) for k, v in storer.items()}
+        return losses
+
+    def compute_metrics(self, dataloader):
+        """Compute all the metrics.
+
+        Parameters
+        ----------
+        dataloader: torch.utils.data.DataLoader
+        """
+        try:
+            lat_sizes = dataloader.dataset.lat_sizes
+            lat_names = dataloader.dataset.lat_names
+        except AttributeError:
+            raise ValueError("Dataset needs to have known true factors of variation to compute the metric. This does not seem to be the case for {}".format(type(dataloader.__dict__["dataset"]).__name__))
+
+        self.logger.info("Computing the empirical distribution q(z|x).")
+        samples_zCx, params_zCx = self._compute_q_zCx(dataloader)
+        len_dataset, latent_dim = samples_zCx.shape
+
+        self.logger.info("Estimating the marginal entropy.")
+        # marginal entropy H(z_j)
+        H_z = self._estimate_latent_entropies(samples_zCx, params_zCx)
+
+        # conditional entropy H(z|v)
+        samples_zCx = samples_zCx.view(*lat_sizes, latent_dim)
+        params_zCx = tuple(p.view(*lat_sizes, latent_dim) for p in params_zCx)
+        H_zCv = self._estimate_H_zCv(samples_zCx, params_zCx, lat_sizes, lat_names)
+
+        H_z = H_z.cpu()
+        H_zCv = H_zCv.cpu()
+
+        # I[z_j;v_k] = E[log \sum_x q(z_j|x)p(x|v_k)] + H[z_j] = - H[z_j|v_k] + H[z_j]
+        mut_info = - H_zCv + H_z
+        sorted_mut_info = torch.sort(mut_info, dim=1, descending=True)[0].clamp(min=0)
+
+        metric_helpers = {'marginal_entropies': H_z, 'cond_entropies': H_zCv}
+        mig = self._mutual_information_gap(sorted_mut_info, lat_sizes, storer=metric_helpers)
+        aam = self._axis_aligned_metric(sorted_mut_info, storer=metric_helpers)
+
+        metrics = {'MIG': mig.item(), 'AAM': aam.item()}
+        torch.save(metric_helpers, os.path.join(self.save_dir, METRIC_HELPERS_FILE))
+
+        return metrics
+
+    def _mutual_information_gap(self, sorted_mut_info, lat_sizes, storer=None):
+        """Compute the mutual information gap as in [1].
+
+        References
+        ----------
+        [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
+            autoencoders." Advances in Neural Information Processing Systems. 2018.
+        """
+        # difference between the largest and second largest mutual info
+        delta_mut_info = sorted_mut_info[:, 0] - sorted_mut_info[:, 1]
+        # NOTE: currently only works if the dataset is balanced for every factor of variation;
+        # then H(v_k) = - |V_k| * 1/|V_k| * log(1/|V_k|) = log(|V_k|)
+        H_v = torch.from_numpy(lat_sizes).float().log()
+        mig_k = delta_mut_info / H_v
+        mig = mig_k.mean()  # mean over factors of variation
+
+        if storer is not None:
+            storer["mig_k"] = mig_k
+            storer["mig"] = mig
+
+        return mig
+
+    def _axis_aligned_metric(self, sorted_mut_info, storer=None):
+        """Compute the proposed axis-aligned metric."""
+        numerator = (sorted_mut_info[:, 0] - sorted_mut_info[:, 1:].sum(dim=1)).clamp(min=0)
+        aam_k = numerator / sorted_mut_info[:, 0]
+        aam_k[torch.isnan(aam_k)] = 0
+        aam = aam_k.mean()  # mean over factors of variation
+
+        if storer is not None:
+            storer["aam_k"] = aam_k
+            storer["aam"] = aam
+
+        return aam
+
+    def _compute_q_zCx(self, dataloader):
+        """Compute the empirical distribution of q(z|x).
+
+        Parameters
+        ----------
+        dataloader: torch.utils.data.DataLoader
+            Batch data iterator.
+
+        Returns
+        -------
+        samples_zCx: torch.tensor
+            Tensor of shape (len_dataset, latent_dim) containing a sample of
+            q(z|x) for every x in the dataset.
+
+        params_zCX: tuple of torch.Tensor
+            Sufficient statistics of q(z|x) for each training example. E.g. for
+            a gaussian, (mean, log_var) each of shape (len_dataset, latent_dim).
+        """
+        len_dataset = len(dataloader.dataset)
+        latent_dim = self.model.latent_dim
+        n_suff_stat = 2
+
+        q_zCx = torch.zeros(len_dataset, latent_dim, n_suff_stat, device=self.device)
+
+        n = 0
+        with torch.no_grad():
+            for x, label in dataloader:
+                batch_size = x.size(0)
+                idcs = slice(n, n + batch_size)
+                q_zCx[idcs, :, 0], q_zCx[idcs, :, 1] = self.model.encoder(x.to(self.device))
+                n += batch_size
+
+        params_zCX = q_zCx.unbind(-1)
+        samples_zCx = self.model.reparameterize(*params_zCX)
+
+        return samples_zCx, params_zCX
+
+    def _estimate_latent_entropies(self, samples_zCx, params_zCX,
+                                   n_samples=10000):
+        r"""Estimate :math:`H(z_j) = E_{q(z_j)} [-log q(z_j)] = E_{p(x)} E_{q(z_j|x)} [-log q(z_j)]`
+        using the empirical distribution of :math:`p(x)`.
+
+        Note
+        ----
+        - the expectation over the empirical distribution is: :math:`q(z) = 1/N \sum_{n=1}^N q(z|x_n)`.
+        - we assume that q(z|x) is factorial, i.e. :math:`q(z|x) = \prod_j q(z_j|x)`.
+        - computes numerically stable NLL: :math:`- log q(z) = log N - logsumexp_{n=1}^N log q(z|x_n)`.
+
+        Parameters
+        ----------
+        samples_zCx: torch.tensor
+            Tensor of shape (len_dataset, latent_dim) containing a sample of
+            q(z|x) for every x in the dataset.
+
+        params_zCX: tuple of torch.Tensor
+            Sufficient statistics of q(z|x) for each training example. E.g. for
+            a gaussian, (mean, log_var) each of shape (len_dataset, latent_dim).
+
+        n_samples: int, optional
+            Number of samples to use to estimate the entropies.
+
+        Returns
+        -------
+        H_z: torch.Tensor
+            Tensor of shape (latent_dim) containing the marginal entropies H(z_j).
+        """
+        len_dataset, latent_dim = samples_zCx.shape
+        device = samples_zCx.device
+        H_z = torch.zeros(latent_dim, device=device)
+
+        # sample from p(x)
+        samples_x = torch.randperm(len_dataset, device=device)[:n_samples]
+        # sample from p(z|x)
+        samples_zCx = samples_zCx.index_select(0, samples_x).view(latent_dim, n_samples)
+
+        mini_batch_size = 10
+        samples_zCx = samples_zCx.expand(len_dataset, latent_dim, n_samples)
+        mean = params_zCX[0].unsqueeze(-1).expand(len_dataset, latent_dim, n_samples)
+        log_var = params_zCX[1].unsqueeze(-1).expand(len_dataset, latent_dim, n_samples)
+        log_N = math.log(len_dataset)
+        with trange(n_samples, leave=False, disable=not self.is_progress_bar) as t:
+            for k in range(0, n_samples, mini_batch_size):
+                # log q(z_j|x) for n_samples
+                idcs = slice(k, k + mini_batch_size)
+                log_q_zCx = log_density_gaussian(samples_zCx[..., idcs],
+                                                 mean[..., idcs],
+                                                 log_var[..., idcs])
+                # numerically stable log q(z_j) for n_samples:
+                # log q(z_j) = -log N + logsumexp_{n=1}^N log q(z_j|x_n)
+                # As we don't know q(z), we approximate it with the monte carlo
+                # expectation of q(z_j|x_n) over x. => fix a single z and look at
+                # the probability for every x to generate it. n_samples is not used here!
+                log_q_z = -log_N + torch.logsumexp(log_q_zCx, dim=0, keepdim=False)
+                # H(z_j) = E_{z_j}[- log q(z_j)]
+                # mean over n_samples (i.e. dimension 1, because already summed over 0).
+                H_z += (-log_q_z).sum(1)
+
+                t.update(mini_batch_size)
+
+        H_z /= n_samples
+
+        return H_z
+
+    def _estimate_H_zCv(self, samples_zCx, params_zCx, lat_sizes, lat_names):
+        """Estimate conditional entropies :math:`H[z|v]`."""
+        latent_dim = samples_zCx.size(-1)
+        len_dataset = reduce((lambda x, y: x * y), lat_sizes)
+        H_zCv = torch.zeros(len(lat_sizes), latent_dim, device=self.device)
+        for i_fac_var, (lat_size, lat_name) in enumerate(zip(lat_sizes, lat_names)):
+            idcs = [slice(None)] * len(lat_sizes)
+            for i in range(lat_size):
+                self.logger.info("Estimating conditional entropies for the {}th value of {}.".format(i, lat_name))
+                idcs[i_fac_var] = i
+                # samples from q(z,x|v)
+                samples_zxCv = samples_zCx[idcs].contiguous().view(len_dataset // lat_size,
+                                                                   latent_dim)
+                params_zxCv = tuple(p[idcs].contiguous().view(len_dataset // lat_size, latent_dim)
+                                    for p in params_zCx)
+
+                H_zCv[i_fac_var] += self._estimate_latent_entropies(samples_zxCv, params_zxCv
+                                                                    ) / lat_size
+        return H_zCv
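
For reference, the MIG computation above reduces to a few tensor operations; this standalone sketch (toy numbers, not from the commit) mirrors the core of _mutual_information_gap: sort each factor's mutual-information row descending, take the gap between the two largest entries, normalize by the factor entropy, and average.

import torch

mut_info = torch.tensor([[1.2, 0.1, 0.0],   # factor 0 captured mostly by latent 0
                         [0.1, 0.9, 0.8]])  # factor 1 split over latents 1 and 2
lat_sizes = torch.tensor([10.0, 6.0])        # values per factor (balanced dataset)

sorted_mi = torch.sort(mut_info, dim=1, descending=True)[0].clamp(min=0)
H_v = lat_sizes.log()                        # H(v_k) = log|V_k| for balanced factors
mig_k = (sorted_mi[:, 0] - sorted_mi[:, 1]) / H_v
print(mig_k.mean())                          # the entangled second factor drags MIG down
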
disvae/main.py
ADDED
@@ -0,0 +1,145 @@
+# Pip packages ------------------------------------------------------
+import importlib
+import os
+import sys
+from datetime import datetime
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import torch
+from torch import optim
+from torch.utils.data import DataLoader
+
+# From local package ------------------------------------------------
+from disvae.models.losses import get_loss_f
+from disvae.models.vae import init_specific_model
+from disvae.training import Trainer
+from disvae.utils.modelIO import save_model
+
+# Loss stuff:
+
+
+def parse_losses(p_model, filename="train_losses.log"):
+    df = pd.read_csv(Path(p_model) / filename)
+
+    losses = df["Loss"].unique()
+
+    rtn = [np.array(df[df["Loss"] == l]["Value"]) for l in losses]
+    rtn = pd.DataFrame(np.array(rtn).T, columns=losses)
+
+    return rtn
+
+
+def get_kl_loss_latent(df):
+    """df must already be parsed!"""
+    rtn = {int(c.split("_")[-1]): df[c].iloc[-1] for c in df if "kl_loss_" in c}
+    rtn = dict(sorted(rtn.items(), key=lambda item: item[1], reverse=True))
+    return rtn
+
+
+def get_kl_dict(p_model):
+    df = parse_losses(p_model)
+    rtn = get_kl_loss_latent(df)
+    return rtn
+
+
+# Dataloader convenience stuff
+
+
+# def get_dataloader(dataset: torch.data.Dataset, batch_size, num_workers):
+#     # This function is rather complicated. It is quick to do in a notebook,
+#     # and these things are also needed to visualize the dataset.
+
+#     # p_dataset_module, dataset_class, dataset_args
+#     # Import module
+#     # if p_dataset_module not in sys.path:
+#     #     sys.path.append(str(Path(p_dataset_module).parent))
+#     # Dataset = getattr(
+#     #     importlib.import_module(Path(p_dataset_module).stem), dataset_class
+#     # )
+
+#     # # From here on, as if it were imported normally
+#     # ds = Dataset(**dataset_args)
+
+#
+
+#     return loader
+
+
+def get_export_dir(base_dir: str, folder_name):
+    if folder_name is None:
+        folder_name = "Model_" + (
+            datetime.now().replace(microsecond=0).isoformat()
+        ).replace(" ", "_").replace(":", "-")
+
+    rtn = Path(base_dir) / folder_name
+
+    if not rtn.exists():
+        os.makedirs(rtn)
+    else:
+        raise ValueError("Output directory already exists.")
+
+    return rtn
+
+
+def train_model(model, data_loader, loss_f, device, lr, epochs, export_dir):
+    trainer = Trainer(
+        model,
+        optim.Adam(model.parameters(), lr=lr),
+        loss_f,
+        device=device,
+        # logger=logger,
+        save_dir=export_dir,
+        is_progress_bar=True,
+    )  # ,
+    # gif_visualizer=gif_visualizer)
+    trainer(data_loader, epochs=epochs, checkpoint_every=10)
+
+    save_model(trainer.model, export_dir)
+    # , metadata=config)  # saving also already happens earlier
+
+    # gif_visualizer = GifTraversalsTraining(model, args.dataset, exp_dir)
+
+
+def train(dataset, config) -> str:
+    # Validate config?
+
+    print("1) Set device")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Device:\t\t {device}")
+
+    print("2) Get dataloader")
+    dataloader = DataLoader(
+        dataset,
+        batch_size=config["data_params"]["batch_size"],
+        shuffle=True,
+        pin_memory=torch.cuda.is_available(),
+        num_workers=config["data_params"]["num_workers"],
+    )
+
+    print("3) Build model")
+    img_size = list(dataloader.dataset[0][0].shape)
+    print(f"Image size: \t {img_size}")
+    model = init_specific_model(img_size=img_size, **config["model_params"])
+    model = model.to(device)  # make sure trainer and viz are on the same device
+
+    print("4) Build loss function")
+    loss_f = get_loss_f(
+        n_data=len(dataloader.dataset), device=device, **config["loss_params"]
+    )
+
+    print("5) Parse export params")
+    export_dir = get_export_dir(**config["export_params"])
+
+    print("6) Training model")
+    train_model(
+        model=model,
+        data_loader=dataloader,
+        loss_f=loss_f,
+        device=device,
+        export_dir=export_dir,
+        **config["trainer_params"],
+    )
+
+    return export_dir
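
For orientation, these are the config keys train() above actually reads; the values below are hypothetical and the exact model/loss keys depend on init_specific_model and get_loss_f (see disvae/models/vae.py and losses.py in this commit):

# Sketch of a config dict accepted by train(); values are illustrative.
config = {
    "data_params": {"batch_size": 64, "num_workers": 4},
    "model_params": {"model_type": "Burgess", "latent_dim": 10},
    "loss_params": {"loss_name": "btcvae", "rec_dist": "bernoulli",
                    "reg_anneal": 0,
                    "btcvae_A": 1, "btcvae_B": 6, "btcvae_G": 1},
    "export_params": {"base_dir": "models", "folder_name": None},
    "trainer_params": {"lr": 5e-4, "epochs": 100},
}
# export_dir = train(dataset, config)
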
disvae/models/__init__.py
ADDED
File without changes
disvae/models/decoders.py
ADDED
@@ -0,0 +1,84 @@
+"""
+Module containing the decoders.
+"""
+import numpy as np
+
+import torch
+from torch import nn
+
+
+# ALL decoders should be called Decoder<Model>
+def get_decoder(model_type):
+    model_type = model_type.lower().capitalize()
+    return eval("Decoder{}".format(model_type))
+
+
+class DecoderBurgess(nn.Module):
+    def __init__(self, img_size,
+                 latent_dim=10):
+        r"""Decoder of the model proposed in [1].
+
+        Parameters
+        ----------
+        img_size : tuple of ints
+            Size of images. E.g. (1, 32, 32) or (3, 64, 64).
+
+        latent_dim : int
+            Dimensionality of latent output.
+
+        Model Architecture (transposed for decoder)
+        ------------
+        - 4 convolutional layers (each with 32 channels), (4 x 4 kernel), (stride of 2)
+        - 2 fully connected layers (each of 256 units)
+        - Latent distribution:
+            - 1 fully connected layer of 20 units (log variance and mean for 10 Gaussians)
+
+        References:
+            [1] Burgess, Christopher P., et al. "Understanding disentangling in
+            $\beta$-VAE." arXiv preprint arXiv:1804.03599 (2018).
+        """
+        super(DecoderBurgess, self).__init__()
+
+        # Layer parameters
+        hid_channels = 32
+        kernel_size = 4
+        hidden_dim = 256
+        self.img_size = img_size
+        # Shape required to start transpose convs
+        self.reshape = (hid_channels, kernel_size, kernel_size)
+        n_chan = self.img_size[0]
+
+        # Fully connected layers
+        self.lin1 = nn.Linear(latent_dim, hidden_dim)
+        self.lin2 = nn.Linear(hidden_dim, hidden_dim)
+        self.lin3 = nn.Linear(hidden_dim, np.product(self.reshape))
+
+        # Convolutional layers
+        cnn_kwargs = dict(stride=2, padding=1)
+        # If input image is 64x64, do a fourth convolution
+        if self.img_size[1] == self.img_size[2] == 64:
+            self.convT_64 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
+
+        self.convT1 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
+        self.convT2 = nn.ConvTranspose2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
+        self.convT3 = nn.ConvTranspose2d(hid_channels, n_chan, kernel_size, **cnn_kwargs)
+
+    def forward(self, z):
+        batch_size = z.size(0)
+
+        # Fully connected layers with ReLU activations
+        x = torch.relu(self.lin1(z))
+        x = torch.relu(self.lin2(x))
+        x = torch.relu(self.lin3(x))
+        x = x.view(batch_size, *self.reshape)
+
+        # Convolutional layers with ReLU activations
+        if self.img_size[1] == self.img_size[2] == 64:
+            x = torch.relu(self.convT_64(x))
+        x = torch.relu(self.convT1(x))
+        x = torch.relu(self.convT2(x))
+        # Sigmoid activation for final conv layer
+        x = torch.sigmoid(self.convT3(x))
+
+        return x
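
A quick shape check of the decoder above (a sketch, not part of the commit): a (batch, latent_dim) code is lifted through two 256-unit layers, reshaped to (32, 4, 4), then upsampled by four stride-2 transposed convolutions to 64x64.

import torch
from disvae.models.decoders import DecoderBurgess

decoder = DecoderBurgess(img_size=(3, 64, 64), latent_dim=10)
z = torch.randn(2, 10)
out = decoder(z)
print(out.shape)  # torch.Size([2, 3, 64, 64]); values in (0, 1) from the sigmoid
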
disvae/models/discriminator.py
ADDED
@@ -0,0 +1,73 @@
+"""
+Module containing the discriminator for FactorVAE.
+"""
+from torch import nn
+
+from disvae.utils.initialization import weights_init
+
+
+class Discriminator(nn.Module):
+    def __init__(self,
+                 neg_slope=0.2,
+                 latent_dim=10,
+                 hidden_units=1000):
+        """Discriminator proposed in [1].
+
+        Parameters
+        ----------
+        neg_slope: float
+            Hyperparameter for the Leaky ReLU
+
+        latent_dim : int
+            Dimensionality of latent variables.
+
+        hidden_units: int
+            Number of hidden units in the MLP
+
+        Model Architecture
+        ------------
+        - 6 layer multi-layer perceptron, each with 1000 hidden units
+        - Leaky ReLU activations
+        - Output 2 logits
+
+        References:
+            [1] Kim, Hyunjik, and Andriy Mnih. "Disentangling by factorising."
+            arXiv preprint arXiv:1802.05983 (2018).
+        """
+        super(Discriminator, self).__init__()
+
+        # Activation parameters
+        self.neg_slope = neg_slope
+        self.leaky_relu = nn.LeakyReLU(self.neg_slope, True)
+
+        # Layer parameters
+        self.z_dim = latent_dim
+        self.hidden_units = hidden_units
+        # theoretically 1 with sigmoid but gives bad results => use 2 and softmax
+        out_units = 2
+
+        # Fully connected layers
+        self.lin1 = nn.Linear(self.z_dim, hidden_units)
+        self.lin2 = nn.Linear(hidden_units, hidden_units)
+        self.lin3 = nn.Linear(hidden_units, hidden_units)
+        self.lin4 = nn.Linear(hidden_units, hidden_units)
+        self.lin5 = nn.Linear(hidden_units, hidden_units)
+        self.lin6 = nn.Linear(hidden_units, out_units)
+
+        self.reset_parameters()
+
+    def forward(self, z):
+
+        # Fully connected layers with leaky ReLU activations
+        z = self.leaky_relu(self.lin1(z))
+        z = self.leaky_relu(self.lin2(z))
+        z = self.leaky_relu(self.lin3(z))
+        z = self.leaky_relu(self.lin4(z))
+        z = self.leaky_relu(self.lin5(z))
+        z = self.lin6(z)
+
+        return z
+
+    def reset_parameters(self):
+        self.apply(weights_init)
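
Side note on the two-logit design above: with a softmax over two logits, log(p_true / p_false) reduces to logit_true - logit_false, which is exactly the density-ratio estimate FactorKLoss.call_optimize (in losses.py below) uses as its total-correlation term. A tiny standalone check:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, -1.0]])     # [logit_true, logit_false]
p = F.softmax(logits, dim=1)
ratio = torch.log(p[:, 0] / p[:, 1])     # log(p_true / p_false)
diff = logits[:, 0] - logits[:, 1]       # logit difference
assert torch.allclose(ratio, diff)       # identical: the partition function cancels
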
disvae/models/encoders.py
ADDED
@@ -0,0 +1,89 @@
+"""
+Module containing the encoders.
+"""
+import numpy as np
+
+import torch
+from torch import nn
+
+
+# ALL encoders should be called Encoder<Model>
+def get_encoder(model_type):
+    model_type = model_type.lower().capitalize()
+    return eval("Encoder{}".format(model_type))
+
+
+class EncoderBurgess(nn.Module):
+    def __init__(self, img_size,
+                 latent_dim=10):
+        r"""Encoder of the model proposed in [1].
+
+        Parameters
+        ----------
+        img_size : tuple of ints
+            Size of images. E.g. (1, 32, 32) or (3, 64, 64).
+
+        latent_dim : int
+            Dimensionality of latent output.
+
+        Model Architecture (transposed for decoder)
+        ------------
+        - 4 convolutional layers (each with 32 channels), (4 x 4 kernel), (stride of 2)
+        - 2 fully connected layers (each of 256 units)
+        - Latent distribution:
+            - 1 fully connected layer of 20 units (log variance and mean for 10 Gaussians)
+
+        References:
+            [1] Burgess, Christopher P., et al. "Understanding disentangling in
+            $\beta$-VAE." arXiv preprint arXiv:1804.03599 (2018).
+        """
+        super(EncoderBurgess, self).__init__()
+
+        # Layer parameters
+        hid_channels = 32
+        kernel_size = 4
+        hidden_dim = 256
+        self.latent_dim = latent_dim
+        self.img_size = img_size
+        # Shape required to start transpose convs
+        self.reshape = (hid_channels, kernel_size, kernel_size)
+        n_chan = self.img_size[0]
+
+        # Convolutional layers
+        cnn_kwargs = dict(stride=2, padding=1)
+        self.conv1 = nn.Conv2d(n_chan, hid_channels, kernel_size, **cnn_kwargs)
+        self.conv2 = nn.Conv2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
+        self.conv3 = nn.Conv2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
+
+        # If input image is 64x64, do a fourth convolution
+        if self.img_size[1] == self.img_size[2] == 64:
+            self.conv_64 = nn.Conv2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
+
+        # Fully connected layers
+        self.lin1 = nn.Linear(np.product(self.reshape), hidden_dim)
+        self.lin2 = nn.Linear(hidden_dim, hidden_dim)
+
+        # Fully connected layers for mean and variance
+        self.mu_logvar_gen = nn.Linear(hidden_dim, self.latent_dim * 2)
+
+    def forward(self, x):
+        batch_size = x.size(0)
+
+        # Convolutional layers with ReLU activations
+        x = torch.relu(self.conv1(x))
+        x = torch.relu(self.conv2(x))
+        x = torch.relu(self.conv3(x))
+        if self.img_size[1] == self.img_size[2] == 64:
+            x = torch.relu(self.conv_64(x))
+
+        # Fully connected layers with ReLU activations
+        x = x.view((batch_size, -1))
+        x = torch.relu(self.lin1(x))
+        x = torch.relu(self.lin2(x))
+
+        # Fully connected layer for log variance and mean
+        # Log std-dev in paper (bear in mind)
+        mu_logvar = self.mu_logvar_gen(x)
+        mu, logvar = mu_logvar.view(-1, self.latent_dim, 2).unbind(-1)
+
+        return mu, logvar
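
The mirror-image shape check for the encoder above (sketch only): a 64x64 RGB batch is downsampled by four stride-2 convolutions to (32, 4, 4), flattened to 512 units, and mapped to 2 * latent_dim outputs split into mean and log variance.

import torch
from disvae.models.encoders import EncoderBurgess

encoder = EncoderBurgess(img_size=(3, 64, 64), latent_dim=10)
x = torch.rand(2, 3, 64, 64)
mu, logvar = encoder(x)
print(mu.shape, logvar.shape)  # torch.Size([2, 10]) torch.Size([2, 10])
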
disvae/models/losses.py
ADDED
@@ -0,0 +1,544 @@
+"""
+Module containing all vae losses.
+"""
+import abc
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch import optim
+
+from .discriminator import Discriminator
+from disvae.utils.math import (log_density_gaussian, log_importance_weight_matrix,
+                               matrix_log_density_gaussian)
+
+
+LOSSES = ["VAE", "betaH", "betaB", "factor", "btcvae"]
+RECON_DIST = ["bernoulli", "laplace", "gaussian"]
+
+
+# TO-DO: clean n_data and device
+def get_loss_f(loss_name, **kwargs_parse):
+    """Return the correct loss function given the argparse arguments."""
+    kwargs_all = dict(rec_dist=kwargs_parse["rec_dist"],
+                      steps_anneal=kwargs_parse["reg_anneal"])
+    if loss_name == "betaH":
+        return BetaHLoss(beta=kwargs_parse["betaH_B"], **kwargs_all)
+    elif loss_name == "VAE":
+        return BetaHLoss(beta=1, **kwargs_all)
+    elif loss_name == "betaB":
+        return BetaBLoss(C_init=kwargs_parse["betaB_initC"],
+                         C_fin=kwargs_parse["betaB_finC"],
+                         gamma=kwargs_parse["betaB_G"],
+                         **kwargs_all)
+    elif loss_name == "factor":
+        return FactorKLoss(kwargs_parse["device"],
+                           gamma=kwargs_parse["factor_G"],
+                           disc_kwargs=dict(latent_dim=kwargs_parse["latent_dim"]),
+                           optim_kwargs=dict(lr=kwargs_parse["lr_disc"], betas=(0.5, 0.9)),
+                           **kwargs_all)
+    elif loss_name == "btcvae":
+        return BtcvaeLoss(kwargs_parse["n_data"],
+                          alpha=kwargs_parse["btcvae_A"],
+                          beta=kwargs_parse["btcvae_B"],
+                          gamma=kwargs_parse["btcvae_G"],
+                          **kwargs_all)
+    else:
+        assert loss_name not in LOSSES
+        raise ValueError("Unknown loss : {}".format(loss_name))
+
+
+class BaseLoss(abc.ABC):
+    """
+    Base class for losses.
+
+    Parameters
+    ----------
+    record_loss_every: int, optional
+        Every how many steps to record the loss.
+
+    rec_dist: {"bernoulli", "gaussian", "laplace"}, optional
+        Reconstruction distribution of the likelihood on each pixel.
+        Implicitly defines the reconstruction loss. Bernoulli corresponds to a
+        binary cross entropy (BCE), Gaussian corresponds to MSE, Laplace
+        corresponds to L1.
+
+    steps_anneal: int, optional
+        Number of annealing steps during which the regularisation is gradually added.
+    """
+
+    def __init__(self, record_loss_every=50, rec_dist="bernoulli", steps_anneal=0):
+        self.n_train_steps = 0
+        self.record_loss_every = record_loss_every
+        self.rec_dist = rec_dist
+        self.steps_anneal = steps_anneal
+
+    @abc.abstractmethod
+    def __call__(self, data, recon_data, latent_dist, is_train, storer, **kwargs):
+        """
+        Calculates loss for a batch of data.
+
+        Parameters
+        ----------
+        data : torch.Tensor
+            Input data (e.g. batch of images). Shape : (batch_size, n_chan,
+            height, width).
+
+        recon_data : torch.Tensor
+            Reconstructed data. Shape : (batch_size, n_chan, height, width).
+
+        latent_dist : tuple of torch.tensor
+            Sufficient statistics of the latent dimension. E.g. for gaussian
+            (mean, log_var) each of shape : (batch_size, latent_dim).
+
+        is_train : bool
+            Whether currently in train mode.
+
+        storer : dict
+            Dictionary in which to store important variables for vizualisation.
+
+        kwargs:
+            Loss-specific arguments
+        """
+
+    def _pre_call(self, is_train, storer):
+        if is_train:
+            self.n_train_steps += 1
+
+        if not is_train or self.n_train_steps % self.record_loss_every == 1:
+            storer = storer
+        else:
+            storer = None
+
+        return storer
+
+
+class BetaHLoss(BaseLoss):
+    """
+    Compute the Beta-VAE loss as in [1]
+
+    Parameters
+    ----------
+    beta : float, optional
+        Weight of the KL divergence.
+
+    kwargs:
+        Additional arguments for `BaseLoss`, e.g. `rec_dist`.
+
+    References
+    ----------
+    [1] Higgins, Irina, et al. "beta-VAE: Learning basic visual concepts with
+        a constrained variational framework." (2016).
+    """
+
+    def __init__(self, beta=4, **kwargs):
+        super().__init__(**kwargs)
+        self.beta = beta
+
+    def __call__(self, data, recon_data, latent_dist, is_train, storer, **kwargs):
+        storer = self._pre_call(is_train, storer)
+
+        rec_loss = _reconstruction_loss(data, recon_data,
+                                        storer=storer,
+                                        distribution=self.rec_dist)
+        kl_loss = _kl_normal_loss(*latent_dist, storer)
+        anneal_reg = (linear_annealing(0, 1, self.n_train_steps, self.steps_anneal)
+                      if is_train else 1)
+        loss = rec_loss + anneal_reg * (self.beta * kl_loss)
+
+        if storer is not None:
+            storer['loss'].append(loss.item())
+
+        return loss
+
+
+class BetaBLoss(BaseLoss):
+    """
+    Compute the Beta-VAE loss as in [1]
+
+    Parameters
+    ----------
+    C_init : float, optional
+        Starting annealed capacity C.
+
+    C_fin : float, optional
+        Final annealed capacity C.
+
+    gamma : float, optional
+        Weight of the KL divergence term.
+
+    kwargs:
+        Additional arguments for `BaseLoss`, e.g. `rec_dist`.
+
+    References
+    ----------
+    [1] Burgess, Christopher P., et al. "Understanding disentangling in
+        $\beta$-VAE." arXiv preprint arXiv:1804.03599 (2018).
+    """
+
+    def __init__(self, C_init=0., C_fin=20., gamma=100., **kwargs):
+        super().__init__(**kwargs)
+        self.gamma = gamma
+        self.C_init = C_init
+        self.C_fin = C_fin
+
+    def __call__(self, data, recon_data, latent_dist, is_train, storer, **kwargs):
+        storer = self._pre_call(is_train, storer)
+
+        rec_loss = _reconstruction_loss(data, recon_data,
+                                        storer=storer,
+                                        distribution=self.rec_dist)
+        kl_loss = _kl_normal_loss(*latent_dist, storer)
+
+        C = (linear_annealing(self.C_init, self.C_fin, self.n_train_steps, self.steps_anneal)
+             if is_train else self.C_fin)
+
+        loss = rec_loss + self.gamma * (kl_loss - C).abs()
+
+        if storer is not None:
+            storer['loss'].append(loss.item())
+
+        return loss
+
+
+class FactorKLoss(BaseLoss):
+    """
+    Compute the Factor-VAE loss as per Algorithm 2 of [1]
+
+    Parameters
+    ----------
+    device : torch.device
+
+    gamma : float, optional
+        Weight of the TC loss term. `gamma` in the paper.
+
+    discriminator : disvae.discriminator.Discriminator
+
+    optimizer_d : torch.optim
+
+    kwargs:
+        Additional arguments for `BaseLoss`, e.g. `rec_dist`.
+
+    References
+    ----------
+    [1] Kim, Hyunjik, and Andriy Mnih. "Disentangling by factorising."
+        arXiv preprint arXiv:1802.05983 (2018).
+    """
+
+    def __init__(self, device,
+                 gamma=10.,
+                 disc_kwargs={},
+                 optim_kwargs=dict(lr=5e-5, betas=(0.5, 0.9)),
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.gamma = gamma
+        self.device = device
+        self.discriminator = Discriminator(**disc_kwargs).to(self.device)
+        self.optimizer_d = optim.Adam(self.discriminator.parameters(), **optim_kwargs)
+
+    def __call__(self, *args, **kwargs):
+        raise ValueError("Use `call_optimize` to also train the discriminator")
+
+    def call_optimize(self, data, model, optimizer, storer):
+        storer = self._pre_call(model.training, storer)
+
+        # factor-vae splits data into two batches; in the paper they sample 2 batches
+        batch_size = data.size(dim=0)
+        half_batch_size = batch_size // 2
+        data = data.split(half_batch_size)
+        data1 = data[0]
+        data2 = data[1]
+
+        # Factor VAE Loss
+        recon_batch, latent_dist, latent_sample1 = model(data1)
+        rec_loss = _reconstruction_loss(data1, recon_batch,
+                                        storer=storer,
+                                        distribution=self.rec_dist)
+
+        kl_loss = _kl_normal_loss(*latent_dist, storer)
+
+        d_z = self.discriminator(latent_sample1)
+        # We want log(p_true/p_false). If not using logistic regression but softmax,
+        # then p_true = exp(logit_true) / Z; p_false = exp(logit_false) / Z
+        # so log(p_true/p_false) = logit_true - logit_false
+        tc_loss = (d_z[:, 0] - d_z[:, 1]).mean()
+        # with sigmoid (not good results) it should be `tc_loss = (2 * d_z.flatten()).mean()`
+
+        anneal_reg = (linear_annealing(0, 1, self.n_train_steps, self.steps_anneal)
+                      if model.training else 1)
+        vae_loss = rec_loss + kl_loss + anneal_reg * self.gamma * tc_loss
+
+        if storer is not None:
+            storer['loss'].append(vae_loss.item())
+            storer['tc_loss'].append(tc_loss.item())
+
+        if not model.training:
+            # don't backprop if evaluating
+            return vae_loss
+
+        # Compute VAE gradients
+        optimizer.zero_grad()
+        vae_loss.backward(retain_graph=True)
+
+        # Discriminator Loss
+        # Get second sample of latent distribution
+        latent_sample2 = model.sample_latent(data2)
+        z_perm = _permute_dims(latent_sample2).detach()
+        d_z_perm = self.discriminator(z_perm)
+
+        # Calculate total correlation loss
+        # for cross entropy the target is the index => needs to be long, and says
+        # that it's the first output for d_z and the second for perm
+        ones = torch.ones(half_batch_size, dtype=torch.long, device=self.device)
+        zeros = torch.zeros_like(ones)
+        d_tc_loss = 0.5 * (F.cross_entropy(d_z, zeros) + F.cross_entropy(d_z_perm, ones))
+        # with sigmoid it would be:
+        # d_tc_loss = 0.5 * (self.bce(d_z.flatten(), ones) + self.bce(d_z_perm.flatten(), 1 - ones))
+
+        # TO-DO: check if we should also anneal the discriminator so it does not become too good ???
+        # d_tc_loss = anneal_reg * d_tc_loss
+
+        # Compute discriminator gradients
+        self.optimizer_d.zero_grad()
+        d_tc_loss.backward()
+
+        # Update at the end (since pytorch 1.5 complains if updating before)
+        optimizer.step()
+        self.optimizer_d.step()
+
+        if storer is not None:
+            storer['discrim_loss'].append(d_tc_loss.item())
+
+        return vae_loss
+
+
+class BtcvaeLoss(BaseLoss):
+    """
+    Compute the decomposed KL loss with either minibatch weighted sampling or
+    minibatch stratified sampling according to [1]
+
+    Parameters
+    ----------
+    n_data: int
+        Number of data points in the training set
+
+    alpha : float
+        Weight of the mutual information term.
+
+    beta : float
+        Weight of the total correlation term.
+
+    gamma : float
+        Weight of the dimension-wise KL term.
+
+    is_mss : bool
+        Whether to use minibatch stratified sampling instead of minibatch
+        weighted sampling.
+
+    kwargs:
+        Additional arguments for `BaseLoss`, e.g. `rec_dist`.
+
+    References
+    ----------
+    [1] Chen, Tian Qi, et al. "Isolating sources of disentanglement in variational
+        autoencoders." Advances in Neural Information Processing Systems. 2018.
+    """
+
+    def __init__(self, n_data, alpha=1., beta=6., gamma=1., is_mss=True, **kwargs):
+        super().__init__(**kwargs)
+        self.n_data = n_data
+        self.beta = beta
+        self.alpha = alpha
+        self.gamma = gamma
+        self.is_mss = is_mss  # minibatch stratified sampling
+
+    def __call__(self, data, recon_batch, latent_dist, is_train, storer,
+                 latent_sample=None):
+        storer = self._pre_call(is_train, storer)
+        batch_size, latent_dim = latent_sample.shape
+
+        rec_loss = _reconstruction_loss(data, recon_batch,
+                                        storer=storer,
+                                        distribution=self.rec_dist)
+        log_pz, log_qz, log_prod_qzi, log_q_zCx = _get_log_pz_qz_prodzi_qzCx(latent_sample,
+                                                                             latent_dist,
+                                                                             self.n_data,
+                                                                             is_mss=self.is_mss)
+        # I[z;x] = KL[q(z,x)||q(x)q(z)] = E_x[KL[q(z|x)||q(z)]]
+        mi_loss = (log_q_zCx - log_qz).mean()
+        # TC[z] = KL[q(z)||\prod_i z_i]
+        tc_loss = (log_qz - log_prod_qzi).mean()
+        # dw_kl_loss is KL[q(z)||p(z)] instead of the usual KL[q(z|x)||p(z)]
+        dw_kl_loss = (log_prod_qzi - log_pz).mean()
+
+        anneal_reg = (linear_annealing(0, 1, self.n_train_steps, self.steps_anneal)
+                      if is_train else 1)
+
+        # total loss
+        loss = rec_loss + (self.alpha * mi_loss +
+                           self.beta * tc_loss +
+                           anneal_reg * self.gamma * dw_kl_loss)
+
+        if storer is not None:
+            storer['loss'].append(loss.item())
+            storer['mi_loss'].append(mi_loss.item())
+            storer['tc_loss'].append(tc_loss.item())
+            storer['dw_kl_loss'].append(dw_kl_loss.item())
+            # computing this for storing and comparison purposes
+            _ = _kl_normal_loss(*latent_dist, storer)
+
+        return loss
+
+
+def _reconstruction_loss(data, recon_data, distribution="bernoulli", storer=None):
+    """
+    Calculates the per-image reconstruction loss for a batch of data, i.e. the
+    negative log likelihood.
+
+    Parameters
+    ----------
+    data : torch.Tensor
+        Input data (e.g. batch of images). Shape : (batch_size, n_chan,
+        height, width).
+
+    recon_data : torch.Tensor
+        Reconstructed data. Shape : (batch_size, n_chan, height, width).
+
+    distribution : {"bernoulli", "gaussian", "laplace"}
+        Distribution of the likelihood on each pixel. Implicitly defines the
+        loss. Bernoulli corresponds to a binary cross entropy (BCE) loss and is the
+        most commonly used. It has the issue that it doesn't penalize
+        (0.1, 0.2) and (0.4, 0.5) the same way, which might not be optimal. Gaussian
+        distribution corresponds to MSE, and is sometimes used, but is hard to train
+        because it ends up focusing on only a few pixels that are very wrong. Laplace
+        distribution corresponds to L1 and partially solves the issue of MSE.
+
+    storer : dict
+        Dictionary in which to store important variables for vizualisation.
+
+    Returns
+    -------
+    loss : torch.Tensor
+        Per-image cross entropy (i.e. normalized per batch but not per pixel and
+        channel)
+    """
+    batch_size, n_chan, height, width = recon_data.size()
+    is_colored = n_chan == 3
+
+    if distribution == "bernoulli":
+        loss = F.binary_cross_entropy(recon_data, data, reduction="sum")
+    elif distribution == "gaussian":
+        # loss in [0,255] space but normalized by 255 to not be too big
+        loss = F.mse_loss(recon_data * 255, data * 255, reduction="sum") / 255
+    elif distribution == "laplace":
+        # loss in [0,255] space but normalized by 255 to not be too big, but
+        # multiplying by 255 and dividing by 255 is the same as doing nothing for L1
+        loss = F.l1_loss(recon_data, data, reduction="sum")
+        loss = loss * 3  # empirical value to give values similar to bernoulli => use same hyperparam
+        loss = loss * (loss != 0)  # masking to avoid nan
+    else:
+        assert distribution not in RECON_DIST
+        raise ValueError("Unknown distribution: {}".format(distribution))
+
+    loss = loss / batch_size
+
+    if storer is not None:
+        storer['recon_loss'].append(loss.item())
+
+    return loss
+
+
+def _kl_normal_loss(mean, logvar, storer=None):
+    """
+    Calculates the KL divergence between a normal distribution
+    with diagonal covariance and a unit normal distribution.
+
+    Parameters
+    ----------
+    mean : torch.Tensor
+        Mean of the normal distribution. Shape (batch_size, latent_dim).
+
+    logvar : torch.Tensor
+        Diagonal log variance of the normal distribution. Shape (batch_size,
+        latent_dim)
+
+    storer : dict
+        Dictionary in which to store important variables for vizualisation.
+    """
+    latent_dim = mean.size(1)
+    # batch mean of kl for each latent dimension
+    latent_kl = 0.5 * (-1 - logvar + mean.pow(2) + logvar.exp()).mean(dim=0)
+    total_kl = latent_kl.sum()
+
+    if storer is not None:
+        storer['kl_loss'].append(total_kl.item())
+        for i in range(latent_dim):
+            storer['kl_loss_' + str(i)].append(latent_kl[i].item())
+
+    return total_kl
+
+
+def _permute_dims(latent_sample):
+    """
+    Implementation of Algorithm 1 in ref [1]. Randomly permutes the sample from
+    q(z) (latent_dist) across the batch for each of the latent dimensions (mean
+    and log_var).
+
+    Parameters
+    ----------
+    latent_sample: torch.Tensor
+        sample from the latent dimension using the reparameterisation trick
+        shape : (batch_size, latent_dim).
+
+    References
+    ----------
+    [1] Kim, Hyunjik, and Andriy Mnih. "Disentangling by factorising."
+        arXiv preprint arXiv:1802.05983 (2018).
+    """
+    perm = torch.zeros_like(latent_sample)
+    batch_size, dim_z = perm.size()
+
+    for z in range(dim_z):
+        pi = torch.randperm(batch_size).to(latent_sample.device)
+        perm[:, z] = latent_sample[pi, z]
+
+    return perm
+
+
+def linear_annealing(init, fin, step, annealing_steps):
+    """Linear annealing of a parameter."""
+    if annealing_steps == 0:
+        return fin
+    assert fin > init
+    delta = fin - init
+    annealed = min(init + delta * step / annealing_steps, fin)
+    return annealed
+
+
+# Batch TC specific
+# TO-DO: test if mss is better!
+def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=True):
+    batch_size, hidden_dim = latent_sample.shape
+
+    # calculate log q(z|x)
+    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)
+
+    # calculate log p(z)
+    # mean and log var is 0
+    zeros = torch.zeros_like(latent_sample)
+    log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)
+
+    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)
+
+    if is_mss:
+        # use stratification
+        log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)
+        mat_log_qz = mat_log_qz + log_iw_mat.view(batch_size, batch_size, 1)
+
+    log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False)
+    log_prod_qzi = torch.logsumexp(mat_log_qz, dim=1, keepdim=False).sum(1)
+
+    return log_pz, log_qz, log_prod_qzi, log_q_zCx
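
A construction sketch for the "btcvae" branch of get_loss_f above; every keyword below is read by that branch (other branches need other keys, e.g. "factor" also reads device, latent_dim, and lr_disc). The concrete values are illustrative:

from disvae.models.losses import get_loss_f

loss_f = get_loss_f(
    "btcvae",
    n_data=200000,                       # size of the training set
    rec_dist="bernoulli",                # reconstruction term: binary cross entropy
    reg_anneal=0,                        # no annealing of the regularizer
    btcvae_A=1, btcvae_B=6, btcvae_G=1,  # alpha / beta / gamma weights
)
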
disvae/models/vae.py
ADDED
@@ -0,0 +1,101 @@
"""
Module containing the main VAE class.
"""
import torch
from torch import nn, optim
from torch.nn import functional as F

from disvae.utils.initialization import weights_init
from .encoders import get_encoder
from .decoders import get_decoder

MODELS = ["Burgess"]


def init_specific_model(model_type, img_size, latent_dim):
    """Return an instance of a VAE with encoder and decoder from `model_type`."""
    model_type = model_type.lower().capitalize()
    if model_type not in MODELS:
        err = "Unknown model_type={}. Possible values: {}"
        raise ValueError(err.format(model_type, MODELS))

    encoder = get_encoder(model_type)
    decoder = get_decoder(model_type)
    model = VAE(img_size, encoder, decoder, latent_dim)
    model.model_type = model_type  # store to help reloading
    return model


class VAE(nn.Module):
    def __init__(self, img_size, encoder, decoder, latent_dim):
        """
        Class which defines the model and forward pass.

        Parameters
        ----------
        img_size : tuple of ints
            Size of images. E.g. (1, 32, 32) or (3, 64, 64).
        """
        super(VAE, self).__init__()

        if list(img_size[1:]) not in [[32, 32], [64, 64]]:
            raise RuntimeError(
                "{} sized images not supported. Only (None, 32, 32) and "
                "(None, 64, 64) supported. Build your own architecture or "
                "reshape images!".format(img_size)
            )

        self.latent_dim = latent_dim
        self.img_size = img_size
        self.num_pixels = self.img_size[1] * self.img_size[2]
        self.encoder = encoder(img_size, self.latent_dim)
        self.decoder = decoder(img_size, self.latent_dim)

        self.reset_parameters()

    def reparameterize(self, mean, logvar):
        """
        Samples from a normal distribution using the reparameterization trick.

        Parameters
        ----------
        mean : torch.Tensor
            Mean of the normal distribution. Shape (batch_size, latent_dim)

        logvar : torch.Tensor
            Diagonal log variance of the normal distribution. Shape
            (batch_size, latent_dim)
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mean + std * eps
        else:
            # Reconstruction mode
            return mean

    def forward(self, x):
        """
        Forward pass of the model.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)
        """
        latent_dist = self.encoder(x)
        latent_sample = self.reparameterize(*latent_dist)
        reconstruct = self.decoder(latent_sample)
        return reconstruct, latent_dist, latent_sample

    def reset_parameters(self):
        self.apply(weights_init)

    def sample_latent(self, x):
        """
        Returns a sample from the latent distribution.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)
        """
        latent_dist = self.encoder(x)
        latent_sample = self.reparameterize(*latent_dist)
        return latent_sample
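A minimal smoke-test sketch for the class above (the (3, 64, 64) input size and the latent dimension are illustrative assumptions, not part of the commit):

import torch
from disvae.models.vae import init_specific_model

model = init_specific_model("Burgess", img_size=(3, 64, 64), latent_dim=10)
x = torch.randn(8, 3, 64, 64)        # dummy batch
recon, (mean, logvar), z = model(x)  # forward returns all three
assert z.shape == (8, 10)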
disvae/training.py
ADDED
@@ -0,0 +1,212 @@
# import imageio
import logging
import os
from collections import defaultdict
from timeit import default_timer

import torch
from torch.nn import functional as F
from tqdm import trange

from disvae.utils.modelIO import save_model

TRAIN_LOSSES_LOGFILE = "train_losses.log"


class Trainer:
    """
    Class to handle training of the model.

    Parameters
    ----------
    model: disvae.vae.VAE

    optimizer: torch.optim.Optimizer

    loss_f: disvae.models.BaseLoss
        Loss function.

    device: torch.device, optional
        Device on which to run the code.

    logger: logging.Logger, optional
        Logger.

    save_dir : str, optional
        Directory for saving logs.

    gif_visualizer : viz.Visualizer, optional
        Gif visualizer that should return samples at every epoch.

    is_progress_bar: bool, optional
        Whether to use a progress bar for training.
    """

    def __init__(
        self,
        model,
        optimizer,
        loss_f,
        device=torch.device("cpu"),
        logger=logging.getLogger(__name__),
        save_dir="results",
        gif_visualizer=None,
        is_progress_bar=True,
    ):
        self.device = device
        self.model = model.to(self.device)
        self.loss_f = loss_f
        self.optimizer = optimizer
        self.save_dir = save_dir
        self.is_progress_bar = is_progress_bar
        self.logger = logger
        self.losses_logger = LossesLogger(
            os.path.join(self.save_dir, TRAIN_LOSSES_LOGFILE)
        )
        self.gif_visualizer = gif_visualizer
        self.logger.info("Training Device: {}".format(self.device))

    def __call__(self, data_loader, epochs=10, checkpoint_every=10):
        """
        Trains the model.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader

        epochs: int, optional
            Number of epochs to train the model for.

        checkpoint_every: int, optional
            Save a checkpoint of the trained model every n epochs.
        """
        start = default_timer()
        self.model.train()
        for epoch in range(epochs):
            storer = defaultdict(list)
            mean_epoch_loss = self._train_epoch(data_loader, storer, epoch)
            self.logger.info(
                "Epoch: {} Average loss per image: {:.2f}".format(
                    epoch + 1, mean_epoch_loss
                )
            )
            self.losses_logger.log(epoch, storer)

            if self.gif_visualizer is not None:
                self.gif_visualizer()

            if epoch % checkpoint_every == 0:
                save_model(
                    self.model, self.save_dir, filename="model-{}.pt".format(epoch)
                )

        if self.gif_visualizer is not None:
            self.gif_visualizer.save_reset()

        self.model.eval()

        delta_time = (default_timer() - start) / 60
        self.logger.info("Finished training after {:.1f} min.".format(delta_time))

    def _train_epoch(self, data_loader, storer, epoch):
        """
        Trains the model for one epoch.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader

        storer: dict
            Dictionary in which to store important variables for visualization.

        epoch: int
            Epoch number

        Return
        ------
        mean_epoch_loss: float
            Mean loss per image
        """
        epoch_loss = 0.0
        kwargs = dict(
            desc="Epoch {}".format(epoch + 1),
            leave=False,
            disable=not self.is_progress_bar,
        )
        with trange(len(data_loader), **kwargs) as t:
            for _, (data, _) in enumerate(data_loader):
                iter_loss = self._train_iteration(data, storer)
                epoch_loss += iter_loss

                t.set_postfix(loss=iter_loss)
                t.update()

        mean_epoch_loss = epoch_loss / len(data_loader)
        return mean_epoch_loss

    def _train_iteration(self, data, storer):
        """
        Trains the model for one iteration on a batch of data.

        Parameters
        ----------
        data: torch.Tensor
            A batch of data. Shape : (batch_size, channel, height, width).

        storer: dict
            Dictionary in which to store important variables for visualization.
        """
        batch_size, channel, height, width = data.size()
        data = data.to(self.device)

        try:
            recon_batch, latent_dist, latent_sample = self.model(data)
            loss = self.loss_f(
                data,
                recon_batch,
                latent_dist,
                self.model.training,
                storer,
                latent_sample=latent_sample,
            )
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        except ValueError:
            # for losses that use multiple optimizers (e.g. Factor)
            loss = self.loss_f.call_optimize(data, self.model, self.optimizer, storer)

        return loss.item()


class LossesLogger(object):
    """Class definition for objects that write data to log files in a
    form that is easy to plot.
    """

    def __init__(self, file_path_name):
        """Create a logger to store information for plotting."""
        if os.path.isfile(file_path_name):
            os.remove(file_path_name)

        self.logger = logging.getLogger("losses_logger")
        self.logger.setLevel(1)  # always store
        file_handler = logging.FileHandler(file_path_name)
        file_handler.setLevel(1)
        self.logger.addHandler(file_handler)

        header = ",".join(["Epoch", "Loss", "Value"])
        self.logger.debug(header)

    def log(self, epoch, losses_storer):
        """Write to the log file."""
        for k, v in losses_storer.items():
            log_string = ",".join(str(item) for item in [epoch, k, mean(v)])
            self.logger.debug(log_string)


# HELPERS
def mean(l):
    """Compute the mean of a list."""
    return sum(l) / len(l)
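A hedged wiring sketch for the trainer (the optimizer choice, learning rate, `loss_f`, and `data_loader` are assumptions, not part of the commit):

import torch
from disvae.training import Trainer

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)  # `model` assumed
trainer = Trainer(model, optimizer, loss_f,                # `loss_f` assumed
                  device=torch.device("cpu"), save_dir="results")
trainer(data_loader, epochs=10, checkpoint_every=5)        # `data_loader` assumed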
disvae/utils/__init__.py
ADDED
File without changes
disvae/utils/initialization.py
ADDED
@@ -0,0 +1,61 @@
import torch
from torch import nn


def get_activation_name(activation):
    """Given a string or a `torch.nn.modules.activation` return the name of the activation."""
    if isinstance(activation, str):
        return activation

    mapper = {nn.LeakyReLU: "leaky_relu", nn.ReLU: "relu", nn.Tanh: "tanh",
              nn.Sigmoid: "sigmoid", nn.Softmax: "sigmoid"}
    for k, v in mapper.items():
        if isinstance(activation, k):
            return v  # fixed: return the name, not the class

    raise ValueError("Unknown activation type: {}".format(activation))


def get_gain(activation):
    """Given an object of `torch.nn.modules.activation` or an activation name
    return the correct gain."""
    if activation is None:
        return 1

    activation_name = get_activation_name(activation)

    param = None if activation_name != "leaky_relu" else activation.negative_slope
    gain = nn.init.calculate_gain(activation_name, param)

    return gain


def linear_init(layer, activation="relu"):
    """Initialize a linear layer.

    Args:
        layer (nn.Linear): parameters to initialize.
        activation (`torch.nn.modules.activation` or str, optional): activation
            that will be used on the `layer`.
    """
    x = layer.weight

    if activation is None:
        return nn.init.xavier_uniform_(x)

    activation_name = get_activation_name(activation)

    if activation_name == "leaky_relu":
        a = 0 if isinstance(activation, str) else activation.negative_slope
        return nn.init.kaiming_uniform_(x, a=a, nonlinearity='leaky_relu')
    elif activation_name == "relu":
        return nn.init.kaiming_uniform_(x, nonlinearity='relu')
    elif activation_name in ["sigmoid", "tanh"]:
        return nn.init.xavier_uniform_(x, gain=get_gain(activation))


def weights_init(module):
    if isinstance(module, torch.nn.modules.conv._ConvNd):
        # TODO: check literature
        linear_init(module)
    elif isinstance(module, nn.Linear):
        linear_init(module)
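A short sketch of the intended entry point: `weights_init` is applied recursively via `Module.apply`, as `VAE.reset_parameters` does above (the toy module is illustrative):

import torch.nn as nn
from disvae.utils.initialization import weights_init

net = nn.Sequential(nn.Conv2d(3, 32, kernel_size=4), nn.Linear(32, 10))
net.apply(weights_init)  # Kaiming-uniform init for conv and linear weights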
disvae/utils/math.py
ADDED
@@ -0,0 +1,73 @@
import math

from tqdm import trange, tqdm
import torch


def matrix_log_density_gaussian(x, mu, logvar):
    """Calculates the log density of a Gaussian for all combinations of batch
    pairs of `x` and `mu`, i.e. returns a tensor of shape
    `(batch_size, batch_size, dim)` instead of `(batch_size, dim)` as in the
    usual log density.

    Parameters
    ----------
    x: torch.Tensor
        Value at which to compute the density. Shape: (batch_size, dim).

    mu: torch.Tensor
        Mean. Shape: (batch_size, dim).

    logvar: torch.Tensor
        Log variance. Shape: (batch_size, dim).
    """
    batch_size, dim = x.shape
    x = x.view(batch_size, 1, dim)
    mu = mu.view(1, batch_size, dim)
    logvar = logvar.view(1, batch_size, dim)
    return log_density_gaussian(x, mu, logvar)


def log_density_gaussian(x, mu, logvar):
    """Calculates the log density of a Gaussian.

    Parameters
    ----------
    x: torch.Tensor or np.ndarray or float
        Value at which to compute the density.

    mu: torch.Tensor or np.ndarray or float
        Mean.

    logvar: torch.Tensor or np.ndarray or float
        Log variance.
    """
    normalization = -0.5 * (math.log(2 * math.pi) + logvar)
    inv_var = torch.exp(-logvar)
    log_density = normalization - 0.5 * ((x - mu)**2 * inv_var)
    return log_density


def log_importance_weight_matrix(batch_size, dataset_size):
    """
    Calculates a log importance weight matrix.

    Parameters
    ----------
    batch_size: int
        number of training images in the batch

    dataset_size: int
        number of training images in the dataset
    """
    N = dataset_size
    M = batch_size - 1
    strat_weight = (N - M) / (N * M)
    W = torch.Tensor(batch_size, batch_size).fill_(1 / M)
    W.view(-1)[::M + 1] = 1 / N
    W.view(-1)[1::M + 1] = strat_weight
    W[M - 1, 0] = strat_weight
    return W.log()
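A sanity-check sketch, added for illustration (not part of the commit): `log_density_gaussian` should agree with `torch.distributions.Normal.log_prob` when the standard deviation is `exp(logvar / 2)`:

import torch
from disvae.utils.math import log_density_gaussian

x, mu, logvar = torch.randn(4, 2), torch.zeros(4, 2), torch.zeros(4, 2)
ref = torch.distributions.Normal(mu, (0.5 * logvar).exp()).log_prob(x)
assert torch.allclose(log_density_gaussian(x, mu, logvar), ref, atol=1e-6)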
disvae/utils/modelIO.py
ADDED
@@ -0,0 +1,200 @@
import json
import os
import re
from pathlib import Path

import numpy as np
import torch

from disvae.models.vae import init_specific_model

MODEL_FILENAME = "model.pt"
META_FILENAME = "specs.json"


def vae2onnx(vae, p_out: str) -> None:
    if isinstance(p_out, str):  # fixed: the path, not the model, may be a string
        p_out = Path(p_out)
    if not p_out.exists():
        p_out.mkdir()

    device = next(vae.parameters()).device
    vae.cpu()

    # Encoder
    vae.encoder.eval()
    dummy_input_im = torch.zeros(tuple(np.concatenate([[1], vae.img_size])))
    torch.onnx.export(vae.encoder, dummy_input_im, p_out / "encoder.onnx", verbose=True)

    # Decoder
    vae.decoder.eval()
    dummy_input_latent = torch.zeros((1, vae.latent_dim))
    torch.onnx.export(
        vae.decoder, dummy_input_latent, p_out / "decoder.onnx", verbose=True
    )

    vae.to(device)  # restore device


def save_model(model, directory, metadata=None, filename=MODEL_FILENAME):
    """
    Save a model and corresponding metadata.

    Parameters
    ----------
    model : nn.Module
        Model.

    directory : str
        Path to the directory where to save the data.

    metadata : dict
        Metadata to save.
    """
    device = next(model.parameters()).device
    model.cpu()

    if metadata is None:
        # save the minimum required for loading
        metadata = dict(
            img_size=model.img_size,
            latent_dim=model.latent_dim,
            model_type=model.model_type,
        )

    save_metadata(metadata, directory)

    path_to_model = os.path.join(directory, filename)
    torch.save(model.state_dict(), path_to_model)

    model.to(device)  # restore device


def load_metadata(directory, filename=META_FILENAME):
    """Load the metadata of a training directory.

    Parameters
    ----------
    directory : string
        Path to folder where the model is saved. For example './experiments/mnist'.
    """
    path_to_metadata = os.path.join(directory, filename)

    with open(path_to_metadata) as metadata_file:
        metadata = json.load(metadata_file)

    return metadata


def save_metadata(metadata, directory, filename=META_FILENAME, **kwargs):
    """Save the metadata of a training directory.

    Parameters
    ----------
    metadata:
        Object to save

    directory: string
        Path to folder where to save the model. For example './experiments/mnist'.

    kwargs:
        Additional arguments to `json.dump`
    """
    path_to_metadata = os.path.join(directory, filename)

    with open(path_to_metadata, "w") as f:
        json.dump(metadata, f, indent=4, sort_keys=True, **kwargs)


def load_model(directory, is_gpu=True, filename=MODEL_FILENAME):
    """Load a trained model.

    Parameters
    ----------
    directory : string
        Path to folder where the model is saved. For example './experiments/mnist'.

    is_gpu : bool
        Whether to load on GPU if available.
    """
    device = torch.device("cuda" if torch.cuda.is_available() and is_gpu else "cpu")

    metadata = load_metadata(directory)
    img_size = metadata["img_size"]
    latent_dim = metadata["latent_dim"]
    model_type = metadata["model_type"]

    path_to_model = os.path.join(directory, filename)
    model = _get_model(model_type, img_size, latent_dim, device, path_to_model)
    return model


def load_checkpoints(directory, is_gpu=True):
    """Load all checkpointed models.

    Parameters
    ----------
    directory : string
        Path to folder where the models are saved. For example './experiments/mnist'.

    is_gpu : bool
        Whether to load on GPU.
    """
    checkpoints = []
    for root, _, filenames in os.walk(directory):
        for filename in filenames:
            results = re.search(r".*?-([0-9].*?).pt", filename)
            if results is not None:
                epoch_idx = int(results.group(1))
                model = load_model(root, is_gpu=is_gpu, filename=filename)
                checkpoints.append((epoch_idx, model))

    return checkpoints


def _get_model(model_type, img_size, latent_dim, device, path_to_model):
    """Load a single model.

    Parameters
    ----------
    model_type : str
        The name of the model to load. For example Burgess.
    img_size : tuple
        Tuple of the number of pixels in the image width and height.
        For example (32, 32) or (64, 64).
    latent_dim : int
        The number of latent dimensions in the bottleneck.

    device : str
        Either 'cuda' or 'cpu'
    path_to_model : str
        Full path to the saved model on the device.
    """
    model = init_specific_model(model_type, img_size, latent_dim).to(device)
    # works with state_dict to make it independent of the file structure
    model.load_state_dict(torch.load(path_to_model), strict=False)
    model.eval()

    return model


def numpy_serialize(obj):
    if type(obj).__module__ == np.__name__:
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return obj.item()
    raise TypeError("Unknown type:", type(obj))


def save_np_arrays(arrays, directory, filename):
    """Save a dictionary of arrays in a json file."""
    save_metadata(arrays, directory, filename=filename, default=numpy_serialize)


def load_np_arrays(directory, filename):
    """Load a dictionary of arrays from a json file."""
    arrays = load_metadata(directory, filename=filename)
    return {k: np.array(v) for k, v in arrays.items()}
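A round-trip sketch for the I/O helpers (the directory name is an assumption; `model` must carry `img_size`, `latent_dim`, and `model_type`, as models built via `init_specific_model` do):

from disvae.utils.modelIO import save_model, load_model

save_model(model, "results/demo")                    # writes model.pt + specs.json
restored = load_model("results/demo", is_gpu=False)  # rebuilds from specs.json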
requirements.txt
ADDED
@@ -0,0 +1,11 @@
numpy
#pandas
#plotly
# scipy
tqdm  # needed by disvae.training and disvae.utils.math
# pillow

# ipywidgets
# jupyterlab

torch
transforms.py
ADDED
@@ -0,0 +1,168 @@
"""
All transforms operate on batches by default!
VAE transforms are invertible.
"""
import pickle
from dataclasses import dataclass
from functools import partial, reduce, wraps

import numpy as np
import torch

# General functions -------------------------------------------------------------
# Transformations are easiest to do in PyTorch.


def load(p):
    with open(p, "rb") as stream:
        return pickle.load(stream)


def save(obj, p):
    with open(p, "wb") as stream:
        pickle.dump(obj, stream)


def sequential_function(*functions):
    return lambda x: reduce(lambda res, func: func(res), functions, x)


def np_sample(func):
    rtn = sequential_function(
        lambda x: torch.from_numpy(x).float(),
        lambda x: torch.unsqueeze(x, 0),
        func,
        lambda x: x[0].numpy(),
    )
    return rtn


# Invertible
class SequentialInversable(torch.nn.Sequential):
    def __init__(self, *functions):
        super().__init__(*functions)

        self.inv_funcs = [f.inv for f in functions]
        self.inv_funcs.reverse()

    # def forward(self, x):
    #     return sequential_function(*self.functions)(x)

    def inv(self, x):
        return sequential_function(*self.inv_funcs)(x)


class LatentSelector(torch.nn.Module):
    """Handles tensors and numpy arrays."""

    def __init__(self, ldim: int, selectdim: int):
        super().__init__()
        self.ldim = ldim
        self.selectdim = selectdim

    def forward(self, x: torch.Tensor):
        return x[:, : self.selectdim]

    def inv(self, x: torch.Tensor):
        rtn = torch.cat(
            [x, torch.zeros((x.shape[0], self.ldim - x.shape[1]), device=x.device)],
            dim=1,
        )
        return rtn


class MinMaxScaler(torch.nn.Module):
    #! Careful with broadcasting when scaling multiple signals.
    def __init__(
        self,
        _min: torch.Tensor,
        _max: torch.Tensor,
        min_norm: float = 0.0,
        max_norm: float = 1.0,
    ):
        super().__init__()
        self._min = _min
        self._max = _max
        self.min_norm = min_norm
        self.max_norm = max_norm

    def forward(self, ts):
        """Shape: (None, no_signals)"""
        std = (ts - self._min) / (self._max - self._min)
        rtn = std * (self.max_norm - self.min_norm) + self.min_norm
        return rtn

    def inv(self, ts):
        std = (ts - self.min_norm) / (self.max_norm - self.min_norm)
        rtn = std * (self._max - self._min) + self._min
        return rtn

    @classmethod
    def from_array(cls, arr: torch.Tensor):
        _min = torch.min(arr, dim=0).values
        _max = torch.max(arr, dim=0).values

        return cls(_min, _max)


class LatentSorter(torch.nn.Module):
    def __init__(self, kl_dict: dict):
        super().__init__()
        self.kl_dict = kl_dict

    def forward(self, latent):
        """
        unsorted -> sorted
        latent: (None, latent_dim)
        """
        return latent[:, list(self.kl_dict.keys())]

    def inv(self, latent):
        keys = np.array(list(self.kl_dict.keys()))
        return latent[:, torch.from_numpy(keys.argsort())]

    @property
    def names(self):
        rtn = ["{} KL{:.2f}".format(k, v) for k, v in self.kl_dict.items()]
        return rtn


def apply_along_axis(function, x, axis: int = 0):
    return torch.stack([function(x_i) for x_i in torch.unbind(x, dim=axis)], dim=axis)


# Input shapes stay as they are!
class SumField(torch.nn.Module):
    """
    time series: [idx, time_step, signal]
    image: [idx, signal, time_step, time_step]
    """

    def forward(self, ts: torch.Tensor):
        """ts2img"""

        samples = ts.shape[0]
        time = ts.shape[1]
        channels = ts.shape[2]

        ts = torch.swapaxes(ts, 1, 2)  # move the time axis to the end
        ts = torch.reshape(
            ts, (samples * channels, time)
        )  # merge the channel and idx axes
        #! TODO: replace the loop with a vectorized solution
        rtn = apply_along_axis(self._mtf_forward, ts, 0)
        rtn = torch.reshape(rtn, (samples, channels, time, time))

        return rtn

    def inv(self, img: torch.Tensor):
        """img2ts"""
        rtn = torch.diagonal(img, dim1=2, dim2=3)
        rtn = torch.swapaxes(rtn, 1, 2)  # swap the channel and time axes

        return rtn

    @staticmethod
    def _mtf_forward(ts):
        """For a one-dimensional time series ts."""
        return torch.add(*torch.meshgrid(ts, ts, indexing="ij")) / 2
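A round-trip sketch for `SumField`, added for illustration (shapes are assumptions): the image's diagonal holds `(x_i + x_i) / 2 = x_i`, so `inv` recovers the original series exactly:

import torch

sf = SumField()
ts = torch.rand(2, 16, 3)  # [idx, time_step, signal]
img = sf(ts)               # -> [2, 3, 16, 16]
assert torch.allclose(sf.inv(img), ts)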