Spaces:
Runtime error
Runtime error
import torch | |
from ..protein import constants | |
from ._base import register_transform | |
class MergeChains(object):
    """Merge the heavy, light and antigen chains of a structure into one record.

    ``__call__`` takes a mapping whose ``'heavy'``, ``'light'`` and
    ``'antigen'`` entries are per-chain dicts (or ``None``). It tags every
    residue with its fragment type and assigns an integer chain number. It then
    concatenates the per-residue properties of all present chains in heavy,
    light, antigen order.

    NOTE(review): ``register_transform`` is imported by this module but is not
    applied to this class here; confirm whether registration was intended.
    """

    # Per-residue properties stored as Python lists; merged by concatenation.
    _LIST_PROPS = ('chain_id', 'icode')

    # Per-residue properties stored as tensors; merged with torch.cat on dim 0.
    # The order here is the key order of the returned dict.
    _TENSOR_PROPS = (
        'chain_nb',
        'resseq',
        'res_nb',
        'aa',
        'pos_heavyatom',
        'mask_heavyatom',
        'generate_flag',
        'cdr_flag',
        'anchor_flag',
        'fragment_type',
    )

    def __init__(self):
        super().__init__()

    def assign_chain_number_(self, data_list):
        """Assign an integer chain index to every residue, in place.

        Distinct chain IDs across all of ``data_list`` are numbered 0, 1, ...
        in order of first appearance. Chains sharing an ID share a number.
        Each dict in ``data_list`` receives a ``'chain_nb'`` LongTensor with
        one entry per element of its ``'chain_id'``.

        Args:
            data_list: per-chain dicts, each with an iterable ``'chain_id'``.
        """
        # A dict keeps insertion order, so the numbering is reproducible.
        # Enumerating a set of strings is not, because str hashing is
        # randomized per process.
        chain_index = {}
        for data in data_list:
            for chain_id in data['chain_id']:
                chain_index.setdefault(chain_id, len(chain_index))
        for data in data_list:
            data['chain_nb'] = torch.LongTensor(
                [chain_index[c] for c in data['chain_id']]
            )

    def _data_attr(self, data, name):
        """Return ``data[name]``, defaulting optional boolean flags.

        ``generate_flag`` and ``anchor_flag`` may be absent. In that case an
        all-False bool tensor shaped like ``data['aa']`` is returned. Any
        other missing key raises ``KeyError``.
        """
        if name in ('generate_flag', 'anchor_flag') and name not in data:
            return torch.zeros(data['aa'].shape, dtype=torch.bool)
        return data[name]

    def _collate(self, data_list):
        """Concatenate the per-residue properties of ``data_list`` into one dict.

        List properties are joined into a single flat list. Tensor properties
        are concatenated with ``torch.cat`` along dim 0. Keys appear in
        ``_LIST_PROPS`` then ``_TENSOR_PROPS`` order.
        """
        merged = {}
        for key in self._LIST_PROPS:
            # Linear-time flattening; sum(lists, start=[]) is quadratic.
            values = []
            for data in data_list:
                values.extend(self._data_attr(data, key))
            merged[key] = values
        for key in self._TENSOR_PROPS:
            merged[key] = torch.cat(
                [self._data_attr(data, key) for data in data_list], dim=0,
            )
        return merged

    def __call__(self, structure):
        """Merge the chains of ``structure`` into a single per-residue dict.

        Args:
            structure: mapping whose ``'heavy'``, ``'light'`` and ``'antigen'``
                entries are per-chain dicts or ``None``. A missing entry is
                treated like ``None``. Present chain dicts are modified in
                place: ``'fragment_type'`` and ``'chain_nb'`` are added, and
                the antigen dict also gets an all-zero ``'cdr_flag'``.

        Returns:
            dict with ``chain_id``/``icode`` as flat lists and every tensor
            property concatenated in heavy, light, antigen order.
        """
        fragments = (
            ('heavy', constants.Fragment.Heavy),
            ('light', constants.Fragment.Light),
            ('antigen', constants.Fragment.Antigen),
        )
        data_list = []
        for key, fragment in fragments:
            data = structure.get(key)
            if data is None:
                continue
            data['fragment_type'] = torch.full_like(
                data['aa'],
                fill_value=fragment,
            )
            if key == 'antigen':
                # All-zero cdr_flag (same dtype as 'aa') so the antigen can be
                # concatenated with the antibody chains; presumably antigen
                # dicts carry no CDR labels of their own — confirm upstream.
                data['cdr_flag'] = torch.zeros_like(data['aa'])
            data_list.append(data)

        self.assign_chain_number_(data_list)
        return self._collate(data_list)