Spaces:
Runtime error
Runtime error
File size: 2,748 Bytes
753e275 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import torch
from ..protein import constants
from ._base import register_transform
@register_transform('merge_chains')
class MergeChains(object):
    """Merge the heavy, light and antigen fragments of a structure into one dict.

    Calling the transform on a ``structure`` dict (keys ``'heavy'``,
    ``'light'`` and ``'antigen'``, each either ``None`` or a dict of
    properties) tags every present fragment with a ``fragment_type`` tensor,
    assigns a shared integer ``chain_nb`` to every chain id, and concatenates
    the fragments' properties in heavy -> light -> antigen order.

    NOTE(review): the properties are presumably per-residue sequences whose
    first dimension is the residue count -- confirm against the parser that
    builds ``structure``.
    """

    # Properties merged by Python list concatenation.
    _LIST_PROPS = ('chain_id', 'icode')
    # Properties merged with ``torch.cat`` along dim 0.
    _TENSOR_PROPS = (
        'chain_nb',
        'resseq',
        'res_nb',
        'aa',
        'pos_heavyatom',
        'mask_heavyatom',
        'generate_flag',
        'cdr_flag',
        'anchor_flag',
        'fragment_type',
    )
    # Flags that default to all-False when a fragment does not carry them.
    _OPTIONAL_FLAGS = ('generate_flag', 'anchor_flag')

    def __init__(self):
        super().__init__()

    def assign_chain_number_(self, data_list):
        """Add a ``chain_nb`` LongTensor to every dict in ``data_list`` (in place).

        Each distinct value found in the ``'chain_id'`` entries is mapped to an
        integer. Numbers are handed out in order of first appearance across
        ``data_list`` so the mapping is reproducible between runs; enumerating
        a ``set`` of strings would instead depend on hash randomization.

        Args:
            data_list: list of fragment dicts, each with a ``'chain_id'``
                iterable of hashable chain identifiers.
        """
        chain_index = {}
        for data in data_list:
            for chain in data['chain_id']:
                # setdefault keeps the number given at first appearance.
                chain_index.setdefault(chain, len(chain_index))
        for data in data_list:
            data['chain_nb'] = torch.LongTensor([
                chain_index[chain] for chain in data['chain_id']
            ])

    def _data_attr(self, data, name):
        """Return ``data[name]``, defaulting optional flags to all-False.

        ``'generate_flag'`` and ``'anchor_flag'`` may be absent from a
        fragment; in that case a boolean zero tensor shaped like
        ``data['aa']`` is returned. Any other missing key raises ``KeyError``.
        """
        if name in self._OPTIONAL_FLAGS and name not in data:
            return torch.zeros(data['aa'].shape, dtype=torch.bool)
        return data[name]

    @staticmethod
    def _tag_fragment(fragment, fragment_type):
        """Set ``fragment['fragment_type']`` to a tensor filled with ``fragment_type``.

        The tensor has the same shape and dtype as ``fragment['aa']``. The
        fragment dict is modified in place and returned for convenience.
        """
        fragment['fragment_type'] = torch.full_like(
            fragment['aa'],
            fill_value=fragment_type,
        )
        return fragment

    def __call__(self, structure):
        """Merge the fragments of ``structure`` into a single flat dict.

        Args:
            structure: dict with keys ``'heavy'``, ``'light'`` and
                ``'antigen'``; a value of ``None`` means the fragment is
                absent. Present fragment dicts are modified in place
                (``fragment_type`` and ``chain_nb`` are added, and
                ``cdr_flag`` is overwritten with zeros on the antigen).

        Returns:
            dict containing the concatenated list properties (``chain_id``,
            ``icode``) followed by the concatenated tensor properties.

        Raises:
            KeyError: if ``structure`` lacks a fragment key or a present
                fragment lacks a required property.

        NOTE(review): when all three fragments are ``None`` the final
        ``torch.cat`` receives an empty list and raises -- confirm whether
        callers can ever pass such a structure.
        """
        data_list = []

        heavy = structure['heavy']
        if heavy is not None:
            data_list.append(
                self._tag_fragment(heavy, constants.Fragment.Heavy)
            )

        light = structure['light']
        if light is not None:
            data_list.append(
                self._tag_fragment(light, constants.Fragment.Light)
            )

        antigen = structure['antigen']
        if antigen is not None:
            self._tag_fragment(antigen, constants.Fragment.Antigen)
            # The antigen gets an all-zero cdr_flag shaped like its 'aa'.
            antigen['cdr_flag'] = torch.zeros_like(antigen['aa'])
            data_list.append(antigen)

        self.assign_chain_number_(data_list)

        data_out = {}
        for name in self._LIST_PROPS:
            # Linear-time flatten (repeated list addition is quadratic).
            data_out[name] = [
                item
                for data in data_list
                for item in self._data_attr(data, name)
            ]
        for name in self._TENSOR_PROPS:
            data_out[name] = torch.cat(
                [self._data_attr(data, name) for data in data_list],
                dim=0,
            )
        return data_out
|