Spaces:
Sleeping
Sleeping
""" | |
Alle transforms sind grundsätzlich auf batches bezogen! | |
Vae transforms sind invertierbar | |
""" | |
import pickle | |
from dataclasses import dataclass | |
from functools import partial, reduce, wraps | |
import numpy as np | |
import torch | |
# General helpers -------------------------------------------------------------
# Transformations are easiest to implement in PyTorch.
def load(p):
    """Unpickle and return the object stored in the file at path ``p``.

    Only use this on trusted files: unpickling can execute arbitrary code.
    """
    with open(p, mode="rb") as handle:
        obj = pickle.load(handle)
    return obj
def save(obj, p):
    """Serialize ``obj`` with pickle into the file at path ``p`` (overwriting it)."""
    with open(p, mode="wb") as handle:
        pickle.dump(obj, file=handle)
def sequential_function(*functions):
    """Compose ``functions`` left to right into a single callable.

    The result maps ``x`` to ``functions[-1](...functions[1](functions[0](x)))``.
    With no functions it is the identity.
    """

    def composed(x):
        result = x
        for func in functions:
            result = func(result)
        return result

    return composed
def np_sample(func):
    """Lift a batch-wise torch function ``func`` to a single numpy sample.

    The returned callable does the following:
    - converts the numpy array to a float32 tensor;
    - adds a leading batch axis of size 1;
    - applies ``func``;
    - returns the first batch element as a numpy array.
    """

    def apply_to_sample(sample):
        batch = torch.unsqueeze(torch.from_numpy(sample).float(), 0)
        output = func(batch)
        return output[0].numpy()

    return apply_to_sample
# Invertible transforms
class SequentialInversable(torch.nn.Sequential):
    """``torch.nn.Sequential`` whose modules each provide an ``inv`` method.

    ``forward`` (inherited) applies the modules in order. ``inv`` applies their
    ``inv`` methods in reverse order, undoing ``forward``.
    """

    def __init__(self, *functions):
        super().__init__(*functions)
        # Fail early (AttributeError, as before) if a module cannot be inverted.
        for module in self:
            if not hasattr(module, "inv"):
                raise AttributeError(
                    f"{type(module).__name__} has no 'inv' method and cannot be inverted"
                )

    @property
    def inv_funcs(self):
        """Inverse functions in application order (last module's ``inv`` first)."""
        # Derived from the *current* modules, so replacing, appending or
        # removing modules after construction keeps ``inv`` consistent with
        # ``forward`` (a list snapshotted in ``__init__`` would go stale).
        return [module.inv for module in reversed(list(self))]

    def inv(self, x):
        """Invert ``forward`` by applying each module's ``inv`` in reverse order."""
        for inverse in self.inv_funcs:
            x = inverse(x)
        return x
class LatentSelector(torch.nn.Module):
    """Keep the first ``selectdim`` latent dimensions; ``inv`` zero-pads back to ``ldim``.

    Handles torch tensors and numpy arrays of shape (batch, dims).
    """

    def __init__(self, ldim: int, selectdim: int):
        super().__init__()
        self.ldim = ldim  # full latent dimensionality restored by ``inv``
        self.selectdim = selectdim  # number of leading dimensions kept by ``forward``

    def forward(self, x: torch.Tensor):
        """Return the first ``selectdim`` columns of ``x``."""
        return x[:, : self.selectdim]

    def inv(self, x: torch.Tensor):
        """Append zero columns to ``x`` so that it has ``ldim`` columns again.

        The padding matches the dtype of ``x``, and its device for tensors, so
        the result keeps the input's type.
        """
        pad_shape = (x.shape[0], self.ldim - x.shape[1])
        if isinstance(x, np.ndarray):
            return np.concatenate([x, np.zeros(pad_shape, dtype=x.dtype)], axis=1)
        return torch.cat([x, x.new_zeros(pad_shape)], dim=1)
class MinMaxScaler(torch.nn.Module):
    """Affine rescaling from the data range [_min, _max] to [min_norm, max_norm].

    ``inv`` maps scaled values back to the original range.
    """

    # NOTE: with several signals, take care how _min/_max broadcast against the input.
    def __init__(
        self,
        _min: torch.Tensor,
        _max: torch.Tensor,
        min_norm: float = 0.0,
        max_norm: float = 1.0,
    ):
        super().__init__()
        # Stored as plain attributes (not buffers), so they are neither part of
        # state_dict nor moved by ``.to(device)``.
        self._min = _min
        self._max = _max
        self.min_norm = min_norm
        self.max_norm = max_norm

    def forward(self, ts):
        """Scale ``ts`` (shape (None, no_signals)) into [min_norm, max_norm]."""
        span = self._max - self._min
        # A constant signal (zero span) would give 0/0 = NaN. Dividing by 1
        # instead maps it to ``min_norm``, and ``inv`` still recovers ``_min``.
        safe_span = span + (span == 0)
        std = (ts - self._min) / safe_span
        rtn = std * (self.max_norm - self.min_norm) + self.min_norm
        return rtn

    def inv(self, ts):
        """Map values from [min_norm, max_norm] back to [_min, _max]."""
        std = (ts - self.min_norm) / (self.max_norm - self.min_norm)
        rtn = std * (self._max - self._min) + self._min
        return rtn

    @classmethod
    def from_array(cls, arr: torch.Tensor):
        """Build a scaler from the per-column min/max of ``arr`` (reduced over dim 0)."""
        _min = torch.min(arr, dim=0).values
        _max = torch.max(arr, dim=0).values
        return cls(_min, _max)
class LatentSorter(torch.nn.Module):
    """Reorder latent columns by the key order of ``kl_dict``.

    ``kl_dict`` maps a latent column index to its KL value. Its iteration order
    defines the sorted column order. ``inv`` assumes the keys are a permutation
    of the column indices (presumably; confirm against callers).
    """

    def __init__(self, kl_dict: dict):
        super().__init__()
        self.kl_dict = kl_dict

    def forward(self, latent):
        """Unsorted -> sorted; ``latent`` has shape (None, latent_dim)."""
        column_order = [index for index in self.kl_dict]
        return latent[:, column_order]

    def inv(self, latent):
        """Sorted -> unsorted: index with the argsort of the key order."""
        column_order = np.array(list(self.kl_dict))
        restore = torch.from_numpy(np.argsort(column_order))
        return latent[:, restore]

    def names(self):
        """Return one label per sorted column, formatted as "<index> KL<value>"."""
        labels = []
        for index, kl_value in self.kl_dict.items():
            labels.append("{} KL{:.2f}".format(index, kl_value))
        return labels
def apply_along_axis(function, x, axis: int = 0):
    """Apply ``function`` to every slice of ``x`` along ``axis``.

    The results are stacked back together along that same axis.
    """
    mapped = []
    for piece in torch.unbind(x, dim=axis):
        mapped.append(function(piece))
    return torch.stack(mapped, dim=axis)
# Input shapes are kept as they are!
class SumField(torch.nn.Module):
    """Invertible time-series <-> image transform based on pairwise averages.

    time series: [idx, time_step, signal]
    image:       [idx, signal, time_step, time_step]
    with image[i, c, a, b] == (ts[i, a, c] + ts[i, b, c]) / 2.
    """

    def forward(self, ts: torch.Tensor):
        """ts2img: (samples, time, channels) -> (samples, channels, time, time)."""
        ts = torch.swapaxes(ts, 1, 2)  # move the time axis to the end
        # Broadcasting handles all samples and channels at once (replaces the
        # former per-series Python loop).
        return self._mtf_forward(ts)

    def inv(self, img: torch.Tensor):
        """img2ts: read the diagonal, since image[i, c, a, a] == ts[i, a, c]."""
        rtn = torch.diagonal(img, dim1=2, dim2=3)  # (samples, channels, time)
        rtn = torch.swapaxes(rtn, 1, 2)  # swap channel and time axes back
        return rtn

    @staticmethod
    def _mtf_forward(ts):
        """Return the pairwise mean along the last axis of ``ts``.

        The result satisfies out[..., a, b] == (ts[..., a] + ts[..., b]) / 2.
        A 1-D series of length T gives a (T, T) field; leading axes are batched.
        """
        return (ts.unsqueeze(-1) + ts.unsqueeze(-2)) / 2