import h5py
import numpy as np
import pandas as pd
import torch
from dotmap import DotMap

from salad.utils import thutil
from salad.utils.paths import DATA_DIR


class SALADDataset(torch.utils.data.Dataset):
    """Loads precomputed latent arrays from an hdf5 file resolved against DATA_DIR."""

    def __init__(self, data_path, repeat=None, **kwargs):
        super().__init__()
        self.data_path = str(DATA_DIR / data_path)
        self.repeat = repeat

        # Expose all remaining kwargs both as attributes and as hparams.
        self.__dict__.update(kwargs)
        self.hparams = DotMap(self.__dict__)
""" | |
Global Data statistics. | |
""" | |
if self.hparams.get("global_normalization"): | |
with h5py.File(self.data_path.replace(".hdf5", "_mean_std.hdf5")) as f: | |
self.global_mean = f["mean"][:].astype(np.float32) | |
self.global_std = f["std"][:].astype(np.float32) | |
        self.data = dict()
        with h5py.File(self.data_path, "r") as f:
            for k in self.hparams.data_keys:
                self.data[k] = f[k][:].astype(np.float32)

                # global_normalization applies to the Gaussian parameters
                # ("g_js_affine") only. Each 16-dim vector is laid out as
                # 3: mu, 9: eigvec, 1: pi, 3: eigval.
                if k == "g_js_affine":
                    if self.hparams.get("global_normalization") == "partial":
                        if self.hparams.get("verbose"):
                            print("[*] Normalize data only for pi and eigenvalues.")
                        # Normalize only pi and the eigenvalues (dims 12:).
                        self.data[k] = self.normalize_global_static(
                            self.data[k], slice(12, None)
                        )
                    elif self.hparams.get("global_normalization") == "all":
                        if self.hparams.get("verbose"):
                            print("[*] Normalize data for all elements.")
                        self.data[k] = self.normalize_global_static(
                            self.data[k], slice(None)
                        )

    def __getitem__(self, idx):
        # With repeat > 1, consecutive indices map to the same underlying item.
        if self.repeat is not None and self.repeat > 1:
            idx = idx // self.repeat

        items = []
        for k in self.hparams.data_keys:
            data = torch.from_numpy(self.data[k][idx])
            items.append(data)

        if self.hparams.get("concat_data"):
            return torch.cat(items, -1)  # [16,528]
        if len(items) == 1:
            return items[0]
        return items

    def __len__(self):
        k = self.hparams.data_keys[0]
        if self.repeat is not None and self.repeat > 1:
            return len(self.data[k]) * self.repeat
        return len(self.data[k])

    def get_other_latents(self, key):
        # Read an additional latent array from the same hdf5 file on demand.
        with h5py.File(self.data_path, "r") as f:
            return f[key][:].astype(np.float32)

    def normalize_global_static(self, data: np.ndarray, normalize_indices=slice(None)):
        """
        Input:
            data: np.ndarray or torch.Tensor of shape [16,16] or [B,16,16].
            normalize_indices: slice(None) -> normalize all dims,
                slice(12, None) -> normalize pi and eigenvalues only.
        Output:
            np.ndarray of shape [16,16] or [B,16,16].
        """
        assert normalize_indices == slice(None) or normalize_indices == slice(
            12, None
        ), f"{normalize_indices} is wrong."
        data = thutil.th2np(data).copy()
        data[..., normalize_indices] = (
            data[..., normalize_indices] - self.global_mean[normalize_indices]
        ) / self.global_std[normalize_indices]
        return data

    def unnormalize_global_static(
        self, data: np.ndarray, unnormalize_indices=slice(None)
    ):
        """
        Input:
            data: np.ndarray or torch.Tensor of shape [16,16] or [B,16,16].
            unnormalize_indices: slice(None) -> unnormalize all dims,
                slice(12, None) -> unnormalize pi and eigenvalues only.
        Output:
            np.ndarray of shape [16,16] or [B,16,16].
        """
        assert unnormalize_indices == slice(None) or unnormalize_indices == slice(
            12, None
        ), f"{unnormalize_indices} is wrong."
        data = thutil.th2np(data).copy()
        data[..., unnormalize_indices] = (
            data[..., unnormalize_indices] * self.global_std[unnormalize_indices]
            + self.global_mean[unnormalize_indices]
        )
        return data
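

# A minimal round-trip sketch for the normalization helpers above. The
# constructor arguments and the hdf5 filename are assumptions for
# illustration, not the project's actual config; only the slice semantics
# follow the docstrings.
def _normalization_roundtrip_example():
    dataset = SALADDataset(
        "spaghetti_latents.hdf5",  # hypothetical file under DATA_DIR
        data_keys=["g_js_affine"],
        global_normalization="all",
    )
    raw = dataset.get_other_latents("g_js_affine")  # unnormalized, from disk
    normed = dataset.normalize_global_static(raw, slice(None))
    restored = dataset.unnormalize_global_static(normed, slice(None))
    # Unnormalizing a normalized array should recover the original values.
    assert np.allclose(raw, restored, atol=1e-5)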


class LangSALADDataset(SALADDataset):
    """Pairs each text annotation from the game-data csv with its SPAGHETTI latent."""

    def __init__(self, data_path, repeat=None, **kwargs):
        super().__init__(data_path, repeat, **kwargs)
        # self.game_data = pd.read_csv(self.hparams.lang_data_path)
        self.game_data = pd.read_csv(
            DATA_DIR / "autosdf_spaghetti_intersec_game_data.csv"
        )
        self.shapenet_ids = np.array(self.game_data["sn"])
        self.spaghetti_indices = np.array(self.game_data["spaghetti_idx"])  # for 5401
        self.texts = np.array(self.game_data["text"])
        assert len(self.shapenet_ids) == len(self.spaghetti_indices) == len(self.texts)

    def __getitem__(self, idx):
        if self.repeat is not None and self.repeat > 1:
            idx = idx // self.repeat

        # Map the text row to its corresponding latent via the spaghetti index.
        spa_idx = self.spaghetti_indices[idx]
        text = self.texts[idx]

        latents = []
        for k in self.hparams.data_keys:
            data = torch.from_numpy(self.data[k][spa_idx])
            latents.append(data)

        if self.hparams.get("concat_data"):
            latents = torch.cat(latents, -1)
            return latents, text
        return latents + [text]

    def __len__(self):
        if self.repeat is not None and self.repeat > 1:
            return len(self.texts) * self.repeat
        return len(self.texts)
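

# A minimal usage sketch, assuming hypothetical hdf5/csv contents are
# available under DATA_DIR; the filename and data_keys below are illustrative,
# not the project's actual config.
if __name__ == "__main__":
    dataset = LangSALADDataset(
        "spaghetti_latents.hdf5",  # hypothetical file under DATA_DIR
        repeat=2,
        data_keys=["g_js_affine"],
        concat_data=True,
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
    # With concat_data=True each item is (latent tensor, text); the default
    # collate stacks the tensors and gathers the texts into a list.
    latents, texts = next(iter(loader))
    print(latents.shape, len(texts))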