# Spaces: Runtime error
import math | |
import torch | |
from torch.utils.data._utils.collate import default_collate | |
# Per-key fill values used by PaddingCollate when padding; keys not listed
# here are padded with 0.
# NOTE(review): 21 is presumably the padding/unknown index of the 'aa'
# vocabulary — confirm against the tokenizer.
DEFAULT_PAD_VALUES = {
    'aa': 21,
    'chain_id': ' ',
    'icode': ' ',
}
# Keys whose values are handed to default_collate without any padding.
DEFAULT_NO_PADDING = {
    'origin',
}


class PaddingCollate:
    """Collate a list of per-sample dicts into a batch, padding along dim 0.

    Every sample is padded to the largest ``data[length_ref_key].size(0)`` in
    the batch (rounded up to a multiple of 8 when ``eight`` is true). A
    boolean ``'mask'`` entry marking the real (unpadded) positions is added to
    each sample, and the padded samples are passed to ``default_collate``.
    Only keys present in every sample of the batch are kept.

    Args:
        length_ref_key: Key whose ``size(0)`` defines each sample's length.
        pad_values: Mapping from key to fill value; unlisted keys pad with 0.
        no_padding: Keys whose values are passed through unpadded.
        eight: If true, round the padded length up to a multiple of 8.
    """

    def __init__(self, length_ref_key='aa', pad_values=DEFAULT_PAD_VALUES, no_padding=DEFAULT_NO_PADDING, eight=True):
        super().__init__()
        self.length_ref_key = length_ref_key
        self.pad_values = pad_values
        self.no_padding = no_padding
        self.eight = eight

    # The helpers below do not use instance state. They must be static
    # because __call__ invokes them through ``self``; without the decorator,
    # ``self`` would be bound as their first argument and every call would fail.
    @staticmethod
    def _pad_last(x, n, value=0):
        """Extend ``x`` at the end of dim 0 up to length ``n`` with ``value``.

        Tensors get a block of ``value`` appended along dim 0 (trailing dims
        kept, dtype/device matched to ``x``). Lists get ``value`` appended.
        Any other object is returned unchanged.
        """
        if isinstance(x, torch.Tensor):
            assert x.size(0) <= n
            if x.size(0) == n:
                return x
            pad_size = [n - x.size(0)] + list(x.shape[1:])
            # Cast the fill block to x's dtype and device so torch.cat accepts it.
            pad = torch.full(pad_size, fill_value=value).to(x)
            return torch.cat([x, pad], dim=0)
        if isinstance(x, list):
            return x + [value] * (n - len(x))
        return x

    @staticmethod
    def _get_pad_mask(l, n):
        """Return a bool tensor of length ``n``: True for the first ``l`` entries."""
        return torch.cat([
            torch.ones([l], dtype=torch.bool),
            torch.zeros([n - l], dtype=torch.bool),
        ], dim=0)

    @staticmethod
    def _get_common_keys(list_of_dict):
        """Return the set of keys present in every dict of ``list_of_dict``."""
        keys = set(list_of_dict[0].keys())
        for d in list_of_dict[1:]:
            keys = keys.intersection(d.keys())
        return keys

    def _get_pad_value(self, key):
        """Return the configured fill value for ``key``, or 0 if none is set."""
        return self.pad_values.get(key, 0)

    def __call__(self, data_list):
        """Pad every sample in ``data_list`` to a common length and collate.

        Raises:
            ValueError: If ``data_list`` is empty.
        """
        if not data_list:
            raise ValueError('PaddingCollate received an empty batch')
        max_length = max(data[self.length_ref_key].size(0) for data in data_list)
        if self.eight:
            # Round the common length up to the next multiple of 8.
            max_length = math.ceil(max_length / 8) * 8
        keys = self._get_common_keys(data_list)

        data_list_padded = []
        for data in data_list:
            data_padded = {
                k: v if k in self.no_padding else self._pad_last(v, max_length, value=self._get_pad_value(k))
                for k, v in data.items()
                if k in keys
            }
            # Added last, so it replaces any 'mask' entry the sample already had.
            data_padded['mask'] = self._get_pad_mask(data[self.length_ref_key].size(0), max_length)
            data_list_padded.append(data_padded)
        return default_collate(data_list_padded)
def apply_patch_to_tensor(x_full, x_patch, patch_idx):
    """Return a copy of ``x_full`` whose entries at ``patch_idx`` are replaced.

    The input tensor is left untouched; the write happens on a clone.

    Args:
        x_full: (N, ...) tensor to be patched.
        x_patch: (M, ...) values written into the selected positions.
        patch_idx: (M, ) index selecting which positions of ``x_full`` to overwrite.
    Returns:
        (N, ...) patched copy of ``x_full``.
    """
    patched = x_full.clone()
    patched[patch_idx] = x_patch
    return patched