# vae_celeba / transforms.py
# Author: Jonas Becker (commit c53ddec, "1st try")
"""
Alle transforms sind grundsätzlich auf batches bezogen!
Vae transforms sind invertierbar
"""
import pickle
from dataclasses import dataclass
from functools import partial, reduce, wraps
import numpy as np
import torch
# Allgemeine Funktionen -------------------------------------------------------------
# Transformations in Pytorch sind am einfachsten.
def load(p):
    """Read one pickled object back from the file at path ``p``.

    Security note: ``pickle`` can execute arbitrary code while loading, so
    only call this on files you created yourself or otherwise trust.
    """
    with open(p, mode="rb") as handle:
        obj = pickle.load(handle)
    return obj
def save(obj, p):
    """Serialize ``obj`` with ``pickle`` and write it to path ``p``.

    An existing file at ``p`` is overwritten (the file is opened with ``"wb"``).
    """
    with open(p, mode="wb") as handle:
        pickle.dump(obj, handle)
def sequential_function(*functions):
    """Compose ``functions`` left to right into a single callable.

    ``sequential_function(f, g)(x)`` evaluates ``g(f(x))``. With no functions
    the result is the identity.
    """
    def composed(x):
        result = x
        for fn in functions:
            result = fn(result)
        return result

    return composed
def np_sample(func):
    """Wrap a batch-wise torch function so it can be called on one numpy sample.

    The returned callable converts ``x`` to a float32 tensor, adds a leading
    batch axis of size 1, applies ``func``, and returns the single result
    (batch element 0) as a numpy array.

    The result is detached and moved to the CPU before ``.numpy()``, so
    ``func`` may return a tensor that requires grad (e.g. a module with
    trainable parameters) or lives on a GPU.
    """
    def _apply(x):
        batch = torch.unsqueeze(torch.from_numpy(x).float(), 0)
        out = func(batch)
        # .numpy() refuses tensors that require grad or are not on the CPU.
        return out[0].detach().cpu().numpy()

    return _apply
# Invertierbare Transforms ----------------------------------------------------------
class SequentialInversable(torch.nn.Sequential):
    """``nn.Sequential`` whose stages can also be undone in reverse order.

    Every stage must provide an ``inv`` method. ``forward`` is inherited from
    ``nn.Sequential``. ``inv`` applies the stages' ``inv`` methods from last to
    first.

    The inverse chain is derived from the currently registered modules on each
    call. It therefore stays correct after ``append``/``insert``/``del`` and for
    slices (``seq[1:]``), which ``nn.Sequential`` builds by passing an
    ``OrderedDict`` to the constructor.
    """

    def __init__(self, *functions):
        super().__init__(*functions)
        # Fail fast, as before: every stage must be invertible.
        for module in self:
            if not hasattr(module, "inv"):
                raise AttributeError(
                    f"{type(module).__name__} has no 'inv' method; "
                    "every stage of SequentialInversable must be invertible"
                )

    @property
    def inv_funcs(self):
        """Bound ``inv`` methods of all stages, last stage first."""
        return [module.inv for module in reversed(list(self))]

    def inv(self, x):
        """Undo ``forward``: apply each stage's ``inv`` in reverse order."""
        for module in reversed(list(self)):
            x = module.inv(x)
        return x
class LatentSelector(torch.nn.Module):
    """Keep the first ``selectdim`` latent dimensions; ``inv`` zero-pads back.

    Handles torch tensors and numpy arrays of shape ``(batch, dims)``.
    ``forward`` slices columns ``[:selectdim]``. ``inv`` appends zero columns
    until the width is ``ldim``, preserving the input's dtype (and, for
    tensors, its device).
    """

    def __init__(self, ldim: int, selectdim: int):
        super().__init__()
        self.ldim = ldim
        self.selectdim = selectdim

    def forward(self, x: "torch.Tensor | np.ndarray"):
        return x[:, : self.selectdim]

    def inv(self, x: "torch.Tensor | np.ndarray"):
        n_pad = self.ldim - x.shape[1]
        if isinstance(x, np.ndarray):
            pad = np.zeros((x.shape[0], n_pad), dtype=x.dtype)
            return np.concatenate([x, pad], axis=1)
        # new_zeros matches x's dtype and device, so no silent type promotion.
        return torch.cat([x, x.new_zeros((x.shape[0], n_pad))], dim=1)
class MinMaxScaler(torch.nn.Module):
    """Affine rescaling from the data range ``[_min, _max]`` to ``[min_norm, max_norm]``.

    Caution with several signals: ``_min``/``_max`` must broadcast against the
    input, e.g. shape ``(no_signals,)`` for an input of shape
    ``(batch, no_signals)``.

    Tensor bounds are registered as non-persistent buffers. ``.to(device)``
    therefore moves them with the module, while ``state_dict()`` keeps the
    same keys as before. Features whose range is zero (``_max == _min``) are
    scaled with a range of 1 instead of producing NaN/inf, so ``inv`` still
    inverts ``forward`` exactly.
    """

    def __init__(
        self,
        _min: torch.Tensor,
        _max: torch.Tensor,
        min_norm: float = 0.0,
        max_norm: float = 1.0,
    ):
        super().__init__()
        self._set_bound("_min", _min)
        self._set_bound("_max", _max)
        self.min_norm = min_norm
        self.max_norm = max_norm

    def _set_bound(self, name: str, value) -> None:
        """Store a bound as a device-following buffer if it is a tensor."""
        if isinstance(value, torch.Tensor):
            self.register_buffer(name, value, persistent=False)
        else:
            setattr(self, name, value)

    def _data_range(self):
        """Return ``_max - _min`` with zero entries replaced by 1."""
        rng = self._max - self._min
        if isinstance(rng, torch.Tensor):
            return torch.where(rng == 0, torch.ones_like(rng), rng)
        if isinstance(rng, np.ndarray):
            return np.where(rng == 0, 1, rng)
        return rng if rng != 0 else 1

    def forward(self, ts):
        """Scale ``ts`` (shape ``(batch, no_signals)``) into ``[min_norm, max_norm]``."""
        std = (ts - self._min) / self._data_range()
        return std * (self.max_norm - self.min_norm) + self.min_norm

    def inv(self, ts):
        """Map values from ``[min_norm, max_norm]`` back to the data range."""
        std = (ts - self.min_norm) / (self.max_norm - self.min_norm)
        return std * self._data_range() + self._min

    @classmethod
    def from_array(
        cls, arr: torch.Tensor, min_norm: float = 0.0, max_norm: float = 1.0
    ):
        """Fit the bounds column-wise (over dim 0) from ``arr``."""
        _min = torch.min(arr, dim=0).values
        _max = torch.max(arr, dim=0).values
        return cls(_min, _max, min_norm, max_norm)
class LatentSorter(torch.nn.Module):
    """Permute latent columns into the key order of ``kl_dict``.

    The keys of ``kl_dict`` are latent column indices. Their insertion order is
    the target ordering, and this class does not sort the dict. The values are
    only used for the ``KL`` labels in ``names``.
    """

    def __init__(self, kl_dict: dict):
        super().__init__()
        self.kl_dict = kl_dict

    def forward(self, latent):
        """
        unsorted -> sorted
        latent: (None, latent_dim)
        """
        order = list(self.kl_dict)
        return latent[:, order]

    def inv(self, latent):
        """sorted -> unsorted, by gathering with the inverse permutation."""
        inverse = np.argsort(np.array(list(self.kl_dict)))
        return latent[:, torch.from_numpy(inverse)]

    @property
    def names(self):
        """Labels of the form ``"<index> KL<value with 2 decimals>"``, in key order."""
        return [f"{idx} KL{kl:.2f}" for idx, kl in self.kl_dict.items()]
def apply_along_axis(function, x, axis: int = 0):
    """Apply ``function`` to every slice of ``x`` along ``axis`` and restack.

    Each slice has ``axis`` removed. The results are stacked back along the
    same ``axis``.
    """
    results = []
    for piece in torch.unbind(x, dim=axis):
        results.append(function(piece))
    return torch.stack(results, dim=axis)
# Eingangsshapes bleiben wie sie sind!
class SumField(torch.nn.Module):
    """Encode time series as pairwise-mean images and decode them back.

    time series: [idx, time_step, signal]
    image:       [idx, signal, time_step, time_step]

    Pixel ``(i, j)`` of each signal's image is ``(x_i + x_j) / 2``. The
    diagonal therefore holds the original series, and ``inv`` recovers it.
    """

    def forward(self, ts: torch.Tensor):
        """ts2img: (samples, time, channels) -> (samples, channels, time, time)."""
        ts = torch.swapaxes(ts, 1, 2)  # time axis last: (samples, channels, time)
        # Column vector + row vector broadcast: out[..., i, j] = ts[..., i] + ts[..., j].
        # This gives the same values as applying _mtf_forward to every
        # (sample, channel) series, without a Python loop, and it also works
        # for an empty batch.
        rtn = (ts.unsqueeze(-1) + ts.unsqueeze(-2)) / 2
        # Return a contiguous tensor, like the previous reshape-based version.
        return rtn.contiguous()

    def inv(self, img: torch.Tensor):
        """img2ts: take each image's diagonal -> (samples, time, channels)."""
        rtn = torch.diagonal(img, dim1=2, dim2=3)  # (samples, channels, time)
        rtn = torch.swapaxes(rtn, 1, 2)  # swap channel and time axes
        return rtn

    @staticmethod
    def _mtf_forward(ts):
        """For a one-dimensional series ``ts``: matrix of ``(ts[i] + ts[j]) / 2``."""
        return torch.add(*torch.meshgrid(ts, ts, indexing="ij")) / 2