Update

This view is limited to 50 files because the commit contains too many changes.
- .gitignore +134 -0
- app.py +1 -3
- design_dock.py +67 -0
- design_pdb.py +4 -0
- design_testset.py +4 -0
- diffab/datasets/__init__.py +4 -0
- diffab/datasets/_base.py +40 -0
- diffab/datasets/custom.py +200 -0
- diffab/datasets/sabdab.py +470 -0
- diffab/models/__init__.py +3 -0
- diffab/models/_base.py +13 -0
- diffab/models/diffab.py +142 -0
- diffab/modules/common/geometry.py +481 -0
- diffab/modules/common/layers.py +160 -0
- diffab/modules/common/so3.py +146 -0
- diffab/modules/common/structure.py +77 -0
- diffab/modules/common/topology.py +24 -0
- diffab/modules/diffusion/dpm_full.py +319 -0
- diffab/modules/diffusion/transition.py +223 -0
- diffab/modules/encoders/ga.py +193 -0
- diffab/modules/encoders/pair.py +102 -0
- diffab/modules/encoders/residue.py +92 -0
- diffab/tools/dock/base.py +28 -0
- diffab/tools/dock/hdock.py +164 -0
- diffab/tools/eval/__main__.py +4 -0
- diffab/tools/eval/base.py +125 -0
- diffab/tools/eval/energy.py +43 -0
- diffab/tools/eval/run.py +66 -0
- diffab/tools/eval/similarity.py +125 -0
- diffab/tools/relax/__main__.py +4 -0
- diffab/tools/relax/base.py +117 -0
- diffab/tools/relax/openmm_relaxer.py +144 -0
- diffab/tools/relax/pyrosetta_relaxer.py +189 -0
- diffab/tools/relax/run.py +85 -0
- diffab/tools/renumber/__init__.py +1 -0
- diffab/tools/renumber/__main__.py +5 -0
- diffab/tools/renumber/run.py +85 -0
- diffab/tools/runner/design_for_pdb.py +291 -0
- diffab/tools/runner/design_for_testset.py +243 -0
- diffab/utils/data.py +89 -0
- diffab/utils/inference.py +60 -0
- diffab/utils/misc.py +126 -0
- diffab/utils/protein/constants.py +319 -0
- diffab/utils/protein/parsers.py +109 -0
- diffab/utils/protein/writers.py +75 -0
- diffab/utils/train.py +151 -0
- diffab/utils/transforms/__init__.py +7 -0
- diffab/utils/transforms/_base.py +56 -0
- diffab/utils/transforms/mask.py +213 -0
- diffab/utils/transforms/merge.py +88 -0
.gitignore
ADDED
@@ -0,0 +1,134 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

/playgrounds*
/logs*
/results*
/*.csv
app.py
CHANGED
@@ -24,8 +24,6 @@ from diffab.tools.renumber.run import (
     assign_number_to_sequence,
 )
 
-DIFFAB_DIR = os.path.realpath('./diffab-repo')
-
 CDR_OPTIONS = OrderedDict()
 CDR_OPTIONS['H_CDR1'] = 'H1'
 CDR_OPTIONS['H_CDR2'] = 'H2'
@@ -96,7 +94,7 @@ def run_design(pdb_path, config_path, output_dir, docking, display_widget, num_d
         bufsize=1,
         stdout=subprocess.PIPE,
         stderr=subprocess.STDOUT,
-        cwd=
+        cwd=os.getcwd(),
     )
     for line in iter(proc.stdout.readline, b''):
         output_buffer += line.decode()
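
The change drops the hard-coded DIFFAB_DIR and launches the design subprocess from the app's current working directory instead. A minimal sketch of the same streaming-subprocess pattern; the command list here is only a placeholder, the real command is assembled inside run_design:

import os
import subprocess

cmd = ['python', 'design_pdb.py', '--help']   # placeholder command, not the one app.py builds
proc = subprocess.Popen(
    cmd,
    bufsize=1,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    cwd=os.getcwd(),   # run from the app's own working directory, as in the patched code
)
for line in iter(proc.stdout.readline, b''):
    print(line.decode(), end='')
proc.wait()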
design_dock.py
ADDED
@@ -0,0 +1,67 @@
import os
import shutil
import argparse
from diffab.tools.dock.hdock import HDockAntibody
from diffab.tools.runner.design_for_pdb import args_factory, design_for_pdb


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--antigen', type=str, required=True)
    parser.add_argument('--antibody', type=str, default='./data/examples/3QHF_Fv.pdb')
    parser.add_argument('--heavy', type=str, default='H', help='Chain id of the heavy chain.')
    parser.add_argument('--light', type=str, default='L', help='Chain id of the light chain.')
    parser.add_argument('--hdock_bin', type=str, default='./bin/hdock')
    parser.add_argument('--createpl_bin', type=str, default='./bin/createpl')
    parser.add_argument('-n', '--num_docks', type=int, default=10)
    parser.add_argument('-c', '--config', type=str, default='./configs/test/codesign_single.yml')
    parser.add_argument('-o', '--out_root', type=str, default='./results')
    parser.add_argument('-t', '--tag', type=str, default='')
    parser.add_argument('-s', '--seed', type=int, default=None)
    parser.add_argument('-d', '--device', type=str, default='cuda')
    parser.add_argument('-b', '--batch_size', type=int, default=16)
    args = parser.parse_args()

    hdock_missing = []
    if not os.path.exists(args.hdock_bin):
        hdock_missing.append(args.hdock_bin)
    if not os.path.exists(args.createpl_bin):
        hdock_missing.append(args.createpl_bin)
    if len(hdock_missing) > 0:
        print("[WARNING] The following HDOCK applications are missing:")
        for f in hdock_missing:
            print(f" > {f}")
        print("Please download HDOCK from http://huanglab.phys.hust.edu.cn/software/hdocklite/ "
              "and put `hdock` and `createpl` to the above path.")
        exit()

    antigen_name = os.path.basename(os.path.splitext(args.antigen)[0])
    docked_pdb_dir = os.path.join(os.path.splitext(args.antigen)[0] + '_dock')
    os.makedirs(docked_pdb_dir, exist_ok=True)
    docked_pdb_paths = []
    for fname in os.listdir(docked_pdb_dir):
        if fname.endswith('.pdb'):
            docked_pdb_paths.append(os.path.join(docked_pdb_dir, fname))
    if len(docked_pdb_paths) < args.num_docks:
        with HDockAntibody() as dock_session:
            dock_session.set_antigen(args.antigen)
            dock_session.set_antibody(args.antibody)
            docked_tmp_paths = dock_session.dock()
            for i, tmp_path in enumerate(docked_tmp_paths[:args.num_docks]):
                dest_path = os.path.join(docked_pdb_dir, f"{antigen_name}_Ab_{i:04d}.pdb")
                shutil.copyfile(tmp_path, dest_path)
                print(f'[INFO] Copy {tmp_path} -> {dest_path}')
                docked_pdb_paths.append(dest_path)

    for pdb_path in docked_pdb_paths:
        current_args = vars(args)
        current_args['tag'] += antigen_name
        design_args = args_factory(
            pdb_path = pdb_path,
            **current_args,
        )
        design_for_pdb(design_args)


if __name__ == '__main__':
    main()
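
A minimal sketch of running only the docking stage, reusing the HDockAntibody session exactly as design_dock.py does above. The antigen path is a hypothetical example, and ./bin/hdock and ./bin/createpl must already be in place:

import os
import shutil
from diffab.tools.dock.hdock import HDockAntibody

antigen_pdb = './data/examples/my_antigen.pdb'   # hypothetical input path
os.makedirs('./results', exist_ok=True)

with HDockAntibody() as dock_session:
    dock_session.set_antigen(antigen_pdb)
    dock_session.set_antibody('./data/examples/3QHF_Fv.pdb')
    # keep the first three docked poses, mirroring the copy loop above
    for i, tmp_path in enumerate(dock_session.dock()[:3]):
        shutil.copyfile(tmp_path, f'./results/dock_{i:04d}.pdb')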
design_pdb.py
ADDED
@@ -0,0 +1,4 @@
from diffab.tools.runner.design_for_pdb import args_from_cmdline, design_for_pdb

if __name__ == '__main__':
    design_for_pdb(args_from_cmdline())
design_testset.py
ADDED
@@ -0,0 +1,4 @@
from diffab.tools.runner.design_for_testset import main

if __name__ == '__main__':
    main()
diffab/datasets/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .sabdab import SAbDabDataset
from .custom import CustomDataset

from ._base import get_dataset
diffab/datasets/_base.py
ADDED
@@ -0,0 +1,40 @@
from torch.utils.data import Dataset, ConcatDataset
from diffab.utils.transforms import get_transform


_DATASET_DICT = {}


def register_dataset(name):
    def decorator(cls):
        _DATASET_DICT[name] = cls
        return cls
    return decorator


def get_dataset(cfg):
    transform = get_transform(cfg.transform) if 'transform' in cfg else None
    return _DATASET_DICT[cfg.type](cfg, transform=transform)


@register_dataset('concat')
def get_concat_dataset(cfg):
    datasets = [get_dataset(d) for d in cfg.datasets]
    return ConcatDataset(datasets)


@register_dataset('balanced_concat')
class BalancedConcatDataset(Dataset):

    def __init__(self, cfg, transform=None):
        super().__init__()
        assert transform is None, 'transform is not supported.'
        self.datasets = [get_dataset(d) for d in cfg.datasets]
        self.max_size = max([len(d) for d in self.datasets])

    def __len__(self):
        return self.max_size * len(self.datasets)

    def __getitem__(self, idx):
        dataset_idx = idx // self.max_size
        return self.datasets[dataset_idx][idx % len(self.datasets[dataset_idx])]
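
A small sketch of the registry pattern defined above: anything decorated with register_dataset('<name>') becomes constructible through get_dataset from a config whose `type` field is that name. Using OmegaConf for the config object is an assumption; any mapping with attribute access and a `get` method works the same way.

from omegaconf import OmegaConf
from torch.utils.data import Dataset
from diffab.datasets._base import register_dataset, get_dataset

@register_dataset('toy')
class ToyDataset(Dataset):
    # hypothetical dataset, registered only to illustrate the mechanism
    def __init__(self, cfg, transform=None):
        super().__init__()
        self.n = cfg.get('n', 8)
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        return {'index': idx}

print(len(get_dataset(OmegaConf.create({'type': 'toy', 'n': 4}))))   # -> 4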
diffab/datasets/custom.py
ADDED
@@ -0,0 +1,200 @@
import os
import logging
import joblib
import pickle
import lmdb
from Bio import PDB
from Bio.PDB import PDBExceptions
from torch.utils.data import Dataset
from tqdm.auto import tqdm

from ..utils.protein import parsers
from .sabdab import _label_heavy_chain_cdr, _label_light_chain_cdr
from ._base import register_dataset


def preprocess_antibody_structure(task):
    pdb_path = task['pdb_path']
    H_id = task.get('heavy_id', 'H')
    L_id = task.get('light_id', 'L')

    parser = PDB.PDBParser(QUIET=True)
    model = parser.get_structure(id, pdb_path)[0]

    all_chain_ids = [c.id for c in model]

    parsed = {
        'id': task['id'],
        'heavy': None,
        'heavy_seqmap': None,
        'light': None,
        'light_seqmap': None,
        'antigen': None,
        'antigen_seqmap': None,
    }
    try:
        if H_id in all_chain_ids:
            (
                parsed['heavy'],
                parsed['heavy_seqmap']
            ) = _label_heavy_chain_cdr(*parsers.parse_biopython_structure(
                model[H_id],
                max_resseq = 113    # Chothia, end of Heavy chain Fv
            ))

        if L_id in all_chain_ids:
            (
                parsed['light'],
                parsed['light_seqmap']
            ) = _label_light_chain_cdr(*parsers.parse_biopython_structure(
                model[L_id],
                max_resseq = 106    # Chothia, end of Light chain Fv
            ))

        if parsed['heavy'] is None and parsed['light'] is None:
            raise ValueError(
                f'Neither valid antibody H-chain or L-chain is found. '
                f'Please ensure that the chain id of heavy chain is "{H_id}" '
                f'and the id of the light chain is "{L_id}".'
            )

        ag_chain_ids = [cid for cid in all_chain_ids if cid not in (H_id, L_id)]
        if len(ag_chain_ids) > 0:
            chains = [model[c] for c in ag_chain_ids]
            (
                parsed['antigen'],
                parsed['antigen_seqmap']
            ) = parsers.parse_biopython_structure(chains)

    except (
        PDBExceptions.PDBConstructionException,
        parsers.ParsingException,
        KeyError,
        ValueError,
    ) as e:
        logging.warning('[{}] {}: {}'.format(
            task['id'],
            e.__class__.__name__,
            str(e)
        ))
        return None

    return parsed


@register_dataset('custom')
class CustomDataset(Dataset):

    MAP_SIZE = 32*(1024*1024*1024)  # 32GB

    def __init__(self, structure_dir, transform=None, reset=False):
        super().__init__()
        self.structure_dir = structure_dir
        self.transform = transform

        self.db_conn = None
        self.db_ids = None
        self._load_structures(reset)

    @property
    def _cache_db_path(self):
        return os.path.join(self.structure_dir, 'structure_cache.lmdb')

    def _connect_db(self):
        self._close_db()
        self.db_conn = lmdb.open(
            self._cache_db_path,
            map_size=self.MAP_SIZE,
            create=False,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        with self.db_conn.begin() as txn:
            keys = [k.decode() for k in txn.cursor().iternext(values=False)]
            self.db_ids = keys

    def _close_db(self):
        if self.db_conn is not None:
            self.db_conn.close()
        self.db_conn = None
        self.db_ids = None

    def _load_structures(self, reset):
        all_pdbs = []
        for fname in os.listdir(self.structure_dir):
            if not fname.endswith('.pdb'): continue
            all_pdbs.append(fname)

        if reset or not os.path.exists(self._cache_db_path):
            todo_pdbs = all_pdbs
        else:
            self._connect_db()
            processed_pdbs = self.db_ids
            self._close_db()
            todo_pdbs = list(set(all_pdbs) - set(processed_pdbs))

        if len(todo_pdbs) > 0:
            self._preprocess_structures(todo_pdbs)

    def _preprocess_structures(self, pdb_list):
        tasks = []
        for pdb_fname in pdb_list:
            pdb_path = os.path.join(self.structure_dir, pdb_fname)
            tasks.append({
                'id': pdb_fname,
                'pdb_path': pdb_path,
            })

        data_list = joblib.Parallel(
            n_jobs = max(joblib.cpu_count() // 2, 1),
        )(
            joblib.delayed(preprocess_antibody_structure)(task)
            for task in tqdm(tasks, dynamic_ncols=True, desc='Preprocess')
        )

        db_conn = lmdb.open(
            self._cache_db_path,
            map_size = self.MAP_SIZE,
            create=True,
            subdir=False,
            readonly=False,
        )
        ids = []
        with db_conn.begin(write=True, buffers=True) as txn:
            for data in tqdm(data_list, dynamic_ncols=True, desc='Write to LMDB'):
                if data is None:
                    continue
                ids.append(data['id'])
                txn.put(data['id'].encode('utf-8'), pickle.dumps(data))

    def __len__(self):
        return len(self.db_ids)

    def __getitem__(self, index):
        self._connect_db()
        id = self.db_ids[index]
        with self.db_conn.begin() as txn:
            data = pickle.loads(txn.get(id.encode()))
        if self.transform is not None:
            data = self.transform(data)
        return data


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir', type=str, default='./data/custom')
    parser.add_argument('--reset', action='store_true', default=False)
    args = parser.parse_args()

    dataset = CustomDataset(
        structure_dir = args.dir,
        reset = args.reset,
    )
    print(dataset[0])
    print(len(dataset))
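
For quick inspection, the LMDB cache that CustomDataset builds can also be read directly. This is only a sketch, assuming the cache has already been created for ./data/custom; the path and pickle format mirror the code above.

import lmdb
import pickle

# open the cache file read-only, exactly as CustomDataset._connect_db does
env = lmdb.open('./data/custom/structure_cache.lmdb', subdir=False, readonly=True, lock=False)
with env.begin() as txn:
    for key, value in txn.cursor():
        entry = pickle.loads(value)
        print(key.decode(), entry['heavy'] is not None, entry['antigen'] is not None)
env.close()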
diffab/datasets/sabdab.py
ADDED
@@ -0,0 +1,470 @@
import os
import random
import logging
import datetime
import pandas as pd
import joblib
import pickle
import lmdb
import subprocess
import torch
from Bio import PDB, SeqRecord, SeqIO, Seq
from Bio.PDB import PDBExceptions
from Bio.PDB import Polypeptide
from torch.utils.data import Dataset
from tqdm.auto import tqdm

from ..utils.protein import parsers, constants
from ._base import register_dataset


ALLOWED_AG_TYPES = {
    'protein',
    'protein | protein',
    'protein | protein | protein',
    'protein | protein | protein | protein | protein',
    'protein | protein | protein | protein',
}

RESOLUTION_THRESHOLD = 4.0

TEST_ANTIGENS = [
    'sars-cov-2 receptor binding domain',
    'hiv-1 envelope glycoprotein gp160',
    'mers s',
    'influenza a virus',
    'cd27 antigen',
]


def nan_to_empty_string(val):
    if val != val or not val:
        return ''
    else:
        return val


def nan_to_none(val):
    if val != val or not val:
        return None
    else:
        return val


def split_sabdab_delimited_str(val):
    if not val:
        return []
    else:
        return [s.strip() for s in val.split('|')]


def parse_sabdab_resolution(val):
    if val == 'NOT' or not val or val != val:
        return None
    elif isinstance(val, str) and ',' in val:
        return float(val.split(',')[0].strip())
    else:
        return float(val)


def _aa_tensor_to_sequence(aa):
    return ''.join([Polypeptide.index_to_one(a.item()) for a in aa.flatten()])


def _label_heavy_chain_cdr(data, seq_map, max_cdr3_length=30):
    if data is None or seq_map is None:
        return data, seq_map

    # Add CDR labels
    cdr_flag = torch.zeros_like(data['aa'])
    for position, idx in seq_map.items():
        resseq = position[1]
        cdr_type = constants.ChothiaCDRRange.to_cdr('H', resseq)
        if cdr_type is not None:
            cdr_flag[idx] = cdr_type
    data['cdr_flag'] = cdr_flag

    # Add CDR sequence annotations
    data['H1_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.H1] )
    data['H2_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.H2] )
    data['H3_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.H3] )

    cdr3_length = (cdr_flag == constants.CDR.H3).sum().item()
    # Remove too long CDR3
    if cdr3_length > max_cdr3_length:
        cdr_flag[cdr_flag == constants.CDR.H3] = 0
        logging.warning(f'CDR-H3 too long {cdr3_length}. Removed.')
        return None, None

    # Filter: ensure CDR3 exists
    if cdr3_length == 0:
        logging.warning('No CDR-H3 found in the heavy chain.')
        return None, None

    return data, seq_map


def _label_light_chain_cdr(data, seq_map, max_cdr3_length=30):
    if data is None or seq_map is None:
        return data, seq_map
    cdr_flag = torch.zeros_like(data['aa'])
    for position, idx in seq_map.items():
        resseq = position[1]
        cdr_type = constants.ChothiaCDRRange.to_cdr('L', resseq)
        if cdr_type is not None:
            cdr_flag[idx] = cdr_type
    data['cdr_flag'] = cdr_flag

    data['L1_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.L1] )
    data['L2_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.L2] )
    data['L3_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.L3] )

    cdr3_length = (cdr_flag == constants.CDR.L3).sum().item()
    # Remove too long CDR3
    if cdr3_length > max_cdr3_length:
        cdr_flag[cdr_flag == constants.CDR.L3] = 0
        logging.warning(f'CDR-L3 too long {cdr3_length}. Removed.')
        return None, None

    # Ensure CDR3 exists
    if cdr3_length == 0:
        logging.warning('No CDRs found in the light chain.')
        return None, None

    return data, seq_map


def preprocess_sabdab_structure(task):
    entry = task['entry']
    pdb_path = task['pdb_path']

    parser = PDB.PDBParser(QUIET=True)
    model = parser.get_structure(id, pdb_path)[0]

    parsed = {
        'id': entry['id'],
        'heavy': None,
        'heavy_seqmap': None,
        'light': None,
        'light_seqmap': None,
        'antigen': None,
        'antigen_seqmap': None,
    }
    try:
        if entry['H_chain'] is not None:
            (
                parsed['heavy'],
                parsed['heavy_seqmap']
            ) = _label_heavy_chain_cdr(*parsers.parse_biopython_structure(
                model[entry['H_chain']],
                max_resseq = 113    # Chothia, end of Heavy chain Fv
            ))

        if entry['L_chain'] is not None:
            (
                parsed['light'],
                parsed['light_seqmap']
            ) = _label_light_chain_cdr(*parsers.parse_biopython_structure(
                model[entry['L_chain']],
                max_resseq = 106    # Chothia, end of Light chain Fv
            ))

        if parsed['heavy'] is None and parsed['light'] is None:
            raise ValueError('Neither valid H-chain or L-chain is found.')

        if len(entry['ag_chains']) > 0:
            chains = [model[c] for c in entry['ag_chains']]
            (
                parsed['antigen'],
                parsed['antigen_seqmap']
            ) = parsers.parse_biopython_structure(chains)

    except (
        PDBExceptions.PDBConstructionException,
        parsers.ParsingException,
        KeyError,
        ValueError,
    ) as e:
        logging.warning('[{}] {}: {}'.format(
            task['id'],
            e.__class__.__name__,
            str(e)
        ))
        return None

    return parsed


class SAbDabDataset(Dataset):

    MAP_SIZE = 32*(1024*1024*1024)  # 32GB

    def __init__(
        self,
        summary_path = './data/sabdab_summary_all.tsv',
        chothia_dir = './data/all_structures/chothia',
        processed_dir = './data/processed',
        split = 'train',
        split_seed = 2022,
        transform = None,
        reset = False,
    ):
        super().__init__()
        self.summary_path = summary_path
        self.chothia_dir = chothia_dir
        if not os.path.exists(chothia_dir):
            raise FileNotFoundError(
                f"SAbDab structures not found in {chothia_dir}. "
                "Please download them from http://opig.stats.ox.ac.uk/webapps/newsabdab/sabdab/archive/all/"
            )
        self.processed_dir = processed_dir
        os.makedirs(processed_dir, exist_ok=True)

        self.sabdab_entries = None
        self._load_sabdab_entries()

        self.db_conn = None
        self.db_ids = None
        self._load_structures(reset)

        self.clusters = None
        self.id_to_cluster = None
        self._load_clusters(reset)

        self.ids_in_split = None
        self._load_split(split, split_seed)

        self.transform = transform

    def _load_sabdab_entries(self):
        df = pd.read_csv(self.summary_path, sep='\t')
        entries_all = []
        for i, row in tqdm(
            df.iterrows(),
            dynamic_ncols=True,
            desc='Loading entries',
            total=len(df),
        ):
            entry_id = "{pdbcode}_{H}_{L}_{Ag}".format(
                pdbcode = row['pdb'],
                H = nan_to_empty_string(row['Hchain']),
                L = nan_to_empty_string(row['Lchain']),
                Ag = ''.join(split_sabdab_delimited_str(
                    nan_to_empty_string(row['antigen_chain'])
                ))
            )
            ag_chains = split_sabdab_delimited_str(
                nan_to_empty_string(row['antigen_chain'])
            )
            resolution = parse_sabdab_resolution(row['resolution'])
            entry = {
                'id': entry_id,
                'pdbcode': row['pdb'],
                'H_chain': nan_to_none(row['Hchain']),
                'L_chain': nan_to_none(row['Lchain']),
                'ag_chains': ag_chains,
                'ag_type': nan_to_none(row['antigen_type']),
                'ag_name': nan_to_none(row['antigen_name']),
                'date': datetime.datetime.strptime(row['date'], '%m/%d/%y'),
                'resolution': resolution,
                'method': row['method'],
                'scfv': row['scfv'],
            }

            # Filtering
            if (
                (entry['ag_type'] in ALLOWED_AG_TYPES or entry['ag_type'] is None)
                and (entry['resolution'] is not None and entry['resolution'] <= RESOLUTION_THRESHOLD)
            ):
                entries_all.append(entry)
        self.sabdab_entries = entries_all

    def _load_structures(self, reset):
        if not os.path.exists(self._structure_cache_path) or reset:
            if os.path.exists(self._structure_cache_path):
                os.unlink(self._structure_cache_path)
            self._preprocess_structures()

        with open(self._structure_cache_path + '-ids', 'rb') as f:
            self.db_ids = pickle.load(f)
        self.sabdab_entries = list(
            filter(
                lambda e: e['id'] in self.db_ids,
                self.sabdab_entries
            )
        )

    @property
    def _structure_cache_path(self):
        return os.path.join(self.processed_dir, 'structures.lmdb')

    def _preprocess_structures(self):
        tasks = []
        for entry in self.sabdab_entries:
            pdb_path = os.path.join(self.chothia_dir, '{}.pdb'.format(entry['pdbcode']))
            if not os.path.exists(pdb_path):
                logging.warning(f"PDB not found: {pdb_path}")
                continue
            tasks.append({
                'id': entry['id'],
                'entry': entry,
                'pdb_path': pdb_path,
            })

        data_list = joblib.Parallel(
            n_jobs = max(joblib.cpu_count() // 2, 1),
        )(
            joblib.delayed(preprocess_sabdab_structure)(task)
            for task in tqdm(tasks, dynamic_ncols=True, desc='Preprocess')
        )

        db_conn = lmdb.open(
            self._structure_cache_path,
            map_size = self.MAP_SIZE,
            create=True,
            subdir=False,
            readonly=False,
        )
        ids = []
        with db_conn.begin(write=True, buffers=True) as txn:
            for data in tqdm(data_list, dynamic_ncols=True, desc='Write to LMDB'):
                if data is None:
                    continue
                ids.append(data['id'])
                txn.put(data['id'].encode('utf-8'), pickle.dumps(data))

        with open(self._structure_cache_path + '-ids', 'wb') as f:
            pickle.dump(ids, f)

    @property
    def _cluster_path(self):
        return os.path.join(self.processed_dir, 'cluster_result_cluster.tsv')

    def _load_clusters(self, reset):
        if not os.path.exists(self._cluster_path) or reset:
            self._create_clusters()

        clusters, id_to_cluster = {}, {}
        with open(self._cluster_path, 'r') as f:
            for line in f.readlines():
                cluster_name, data_id = line.split()
                if cluster_name not in clusters:
                    clusters[cluster_name] = []
                clusters[cluster_name].append(data_id)
                id_to_cluster[data_id] = cluster_name
        self.clusters = clusters
        self.id_to_cluster = id_to_cluster

    def _create_clusters(self):
        cdr_records = []
        for id in self.db_ids:
            structure = self.get_structure(id)
            if structure['heavy'] is not None:
                cdr_records.append(SeqRecord.SeqRecord(
                    Seq.Seq(structure['heavy']['H3_seq']),
                    id = structure['id'],
                    name = '',
                    description = '',
                ))
            elif structure['light'] is not None:
                cdr_records.append(SeqRecord.SeqRecord(
                    Seq.Seq(structure['light']['L3_seq']),
                    id = structure['id'],
                    name = '',
                    description = '',
                ))
        fasta_path = os.path.join(self.processed_dir, 'cdr_sequences.fasta')
        SeqIO.write(cdr_records, fasta_path, 'fasta')

        cmd = ' '.join([
            'mmseqs', 'easy-cluster',
            os.path.realpath(fasta_path),
            'cluster_result', 'cluster_tmp',
            '--min-seq-id', '0.5',
            '-c', '0.8',
            '--cov-mode', '1',
        ])
        subprocess.run(cmd, cwd=self.processed_dir, shell=True, check=True)

    def _load_split(self, split, split_seed):
        assert split in ('train', 'val', 'test')
        ids_test = [
            entry['id']
            for entry in self.sabdab_entries
            if entry['ag_name'] in TEST_ANTIGENS
        ]
        test_relevant_clusters = set([self.id_to_cluster[id] for id in ids_test])

        ids_train_val = [
            entry['id']
            for entry in self.sabdab_entries
            if self.id_to_cluster[entry['id']] not in test_relevant_clusters
        ]
        random.Random(split_seed).shuffle(ids_train_val)
        if split == 'test':
            self.ids_in_split = ids_test
        elif split == 'val':
            self.ids_in_split = ids_train_val[:20]
        else:
            self.ids_in_split = ids_train_val[20:]

    def _connect_db(self):
        if self.db_conn is not None:
            return
        self.db_conn = lmdb.open(
            self._structure_cache_path,
            map_size=self.MAP_SIZE,
            create=False,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

    def get_structure(self, id):
        self._connect_db()
        with self.db_conn.begin() as txn:
            return pickle.loads(txn.get(id.encode()))

    def __len__(self):
        return len(self.ids_in_split)

    def __getitem__(self, index):
        id = self.ids_in_split[index]
        data = self.get_structure(id)
        if self.transform is not None:
            data = self.transform(data)
        return data


@register_dataset('sabdab')
def get_sabdab_dataset(cfg, transform):
    return SAbDabDataset(
        summary_path = cfg.summary_path,
        chothia_dir = cfg.chothia_dir,
        processed_dir = cfg.processed_dir,
        split = cfg.split,
        split_seed = cfg.get('split_seed', 2022),
        transform = transform,
    )


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='train')
    parser.add_argument('--processed_dir', type=str, default='./data/processed')
    parser.add_argument('--reset', action='store_true', default=False)
    args = parser.parse_args()
    if args.reset:
        sure = input('Sure to reset? (y/n): ')
        if sure != 'y':
            exit()
    dataset = SAbDabDataset(
        processed_dir=args.processed_dir,
        split=args.split,
        reset=args.reset
    )
    print(dataset[0])
    print(len(dataset), len(dataset.clusters))
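
A sketch of building the SAbDab training split through the 'sabdab' registry entry above. The field names mirror get_sabdab_dataset; using OmegaConf for the config object is an assumption, and the first run triggers the full preprocessing and mmseqs clustering.

from omegaconf import OmegaConf
from diffab.datasets import get_dataset

cfg = OmegaConf.create({
    'type': 'sabdab',
    'summary_path': './data/sabdab_summary_all.tsv',
    'chothia_dir': './data/all_structures/chothia',
    'processed_dir': './data/processed',
    'split': 'train',
})
dataset = get_dataset(cfg)
print(len(dataset))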
diffab/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .diffab import DiffusionAntibodyDesign

from ._base import get_model
diffab/models/_base.py
ADDED
@@ -0,0 +1,13 @@
_MODEL_DICT = {}


def register_model(name):
    def decorator(cls):
        _MODEL_DICT[name] = cls
        return cls
    return decorator


def get_model(cfg):
    return _MODEL_DICT[cfg.type](cfg)
diffab/models/diffab.py
ADDED
@@ -0,0 +1,142 @@
import torch
import torch.nn as nn

from diffab.modules.common.geometry import construct_3d_basis
from diffab.modules.common.so3 import rotation_to_so3vec
from diffab.modules.encoders.residue import ResidueEmbedding
from diffab.modules.encoders.pair import PairEmbedding
from diffab.modules.diffusion.dpm_full import FullDPM
from diffab.utils.protein.constants import max_num_heavyatoms, BBHeavyAtom
from ._base import register_model


resolution_to_num_atoms = {
    'backbone+CB': 5,
    'full': max_num_heavyatoms
}


@register_model('diffab')
class DiffusionAntibodyDesign(nn.Module):

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        num_atoms = resolution_to_num_atoms[cfg.get('resolution', 'full')]
        self.residue_embed = ResidueEmbedding(cfg.res_feat_dim, num_atoms)
        self.pair_embed = PairEmbedding(cfg.pair_feat_dim, num_atoms)

        self.diffusion = FullDPM(
            cfg.res_feat_dim,
            cfg.pair_feat_dim,
            **cfg.diffusion,
        )

    def encode(self, batch, remove_structure, remove_sequence):
        """
        Returns:
            res_feat:   (N, L, res_feat_dim)
            pair_feat:  (N, L, L, pair_feat_dim)
        """
        # This is used throughout embedding and encoding layers
        # to avoid data leakage.
        context_mask = torch.logical_and(
            batch['mask_heavyatom'][:, :, BBHeavyAtom.CA],
            ~batch['generate_flag']     # Context means ``not generated''
        )

        structure_mask = context_mask if remove_structure else None
        sequence_mask = context_mask if remove_sequence else None

        res_feat = self.residue_embed(
            aa = batch['aa'],
            res_nb = batch['res_nb'],
            chain_nb = batch['chain_nb'],
            pos_atoms = batch['pos_heavyatom'],
            mask_atoms = batch['mask_heavyatom'],
            fragment_type = batch['fragment_type'],
            structure_mask = structure_mask,
            sequence_mask = sequence_mask,
        )

        pair_feat = self.pair_embed(
            aa = batch['aa'],
            res_nb = batch['res_nb'],
            chain_nb = batch['chain_nb'],
            pos_atoms = batch['pos_heavyatom'],
            mask_atoms = batch['mask_heavyatom'],
            structure_mask = structure_mask,
            sequence_mask = sequence_mask,
        )

        R = construct_3d_basis(
            batch['pos_heavyatom'][:, :, BBHeavyAtom.CA],
            batch['pos_heavyatom'][:, :, BBHeavyAtom.C],
            batch['pos_heavyatom'][:, :, BBHeavyAtom.N],
        )
        p = batch['pos_heavyatom'][:, :, BBHeavyAtom.CA]

        return res_feat, pair_feat, R, p

    def forward(self, batch):
        mask_generate = batch['generate_flag']
        mask_res = batch['mask']
        res_feat, pair_feat, R_0, p_0 = self.encode(
            batch,
            remove_structure = self.cfg.get('train_structure', True),
            remove_sequence = self.cfg.get('train_sequence', True)
        )
        v_0 = rotation_to_so3vec(R_0)
        s_0 = batch['aa']

        loss_dict = self.diffusion(
            v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res,
            denoise_structure = self.cfg.get('train_structure', True),
            denoise_sequence = self.cfg.get('train_sequence', True),
        )
        return loss_dict

    @torch.no_grad()
    def sample(
        self,
        batch,
        sample_opt={
            'sample_structure': True,
            'sample_sequence': True,
        }
    ):
        mask_generate = batch['generate_flag']
        mask_res = batch['mask']
        res_feat, pair_feat, R_0, p_0 = self.encode(
            batch,
            remove_structure = sample_opt.get('sample_structure', True),
            remove_sequence = sample_opt.get('sample_sequence', True)
        )
        v_0 = rotation_to_so3vec(R_0)
        s_0 = batch['aa']
        traj = self.diffusion.sample(v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res, **sample_opt)
        return traj

    @torch.no_grad()
    def optimize(
        self,
        batch,
        opt_step,
        optimize_opt={
            'sample_structure': True,
            'sample_sequence': True,
        }
    ):
        mask_generate = batch['generate_flag']
        mask_res = batch['mask']
        res_feat, pair_feat, R_0, p_0 = self.encode(
            batch,
            remove_structure = optimize_opt.get('sample_structure', True),
            remove_sequence = optimize_opt.get('sample_sequence', True)
        )
        v_0 = rotation_to_so3vec(R_0)
        s_0 = batch['aa']

        traj = self.diffusion.optimize(v_0, p_0, s_0, opt_step, res_feat, pair_feat, mask_generate, mask_res, **optimize_opt)
        return traj
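
The encode() path above derives per-residue frames from the CA/C/N coordinates and turns the rotation part into an SO(3) vector before diffusion. A standalone sketch of those two helpers on random coordinates; the input shapes follow the docstrings, and the exact shape of the so3 output is an assumption based on how it is consumed here.

import torch
from diffab.modules.common.geometry import construct_3d_basis
from diffab.modules.common.so3 import rotation_to_so3vec

N, L = 2, 16
pos_CA = torch.randn(N, L, 3)
pos_C = pos_CA + torch.randn(N, L, 3)
pos_N = pos_CA + torch.randn(N, L, 3)

R = construct_3d_basis(pos_CA, pos_C, pos_N)   # (N, L, 3, 3) orthonormal frames
v = rotation_to_so3vec(R)                      # per-residue rotation vectors
print(R.shape, v.shape)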
diffab/modules/common/geometry.py
ADDED
@@ -0,0 +1,481 @@
import torch
import torch.nn.functional as F

from diffab.utils.protein.constants import (
    BBHeavyAtom,
    backbone_atom_coordinates_tensor,
    bb_oxygen_coordinate_tensor,
)
from .topology import get_terminus_flag


def safe_norm(x, dim=-1, keepdim=False, eps=1e-8, sqrt=True):
    out = torch.clamp(torch.sum(torch.square(x), dim=dim, keepdim=keepdim), min=eps)
    return torch.sqrt(out) if sqrt else out


def pairwise_distances(x, y=None, return_v=False):
    """
    Args:
        x:  (B, N, d)
        y:  (B, M, d)
    """
    if y is None: y = x
    v = x.unsqueeze(2) - y.unsqueeze(1)  # (B, N, M, d)
    d = safe_norm(v, dim=-1)
    if return_v:
        return d, v
    else:
        return d


def normalize_vector(v, dim, eps=1e-6):
    return v / (torch.linalg.norm(v, ord=2, dim=dim, keepdim=True) + eps)


def project_v2v(v, e, dim):
    """
    Description:
        Project vector `v` onto vector `e`.
    Args:
        v:  (N, L, 3).
        e:  (N, L, 3).
    """
    return (e * v).sum(dim=dim, keepdim=True) * e


def construct_3d_basis(center, p1, p2):
    """
    Args:
        center: (N, L, 3), usually the position of C_alpha.
        p1:     (N, L, 3), usually the position of C.
        p2:     (N, L, 3), usually the position of N.
    Returns
        A batch of orthogonal basis matrix, (N, L, 3, 3cols_index).
        The matrix is composed of 3 column vectors: [e1, e2, e3].
    """
    v1 = p1 - center    # (N, L, 3)
    e1 = normalize_vector(v1, dim=-1)

    v2 = p2 - center    # (N, L, 3)
    u2 = v2 - project_v2v(v2, e1, dim=-1)
    e2 = normalize_vector(u2, dim=-1)

    e3 = torch.cross(e1, e2, dim=-1)    # (N, L, 3)

    mat = torch.cat([
        e1.unsqueeze(-1), e2.unsqueeze(-1), e3.unsqueeze(-1)
    ], dim=-1)  # (N, L, 3, 3_index)
    return mat


def local_to_global(R, t, p):
    """
    Description:
        Convert local (internal) coordinates to global (external) coordinates q.
        q <- Rp + t
    Args:
        R:  (N, L, 3, 3).
        t:  (N, L, 3).
        p:  Local coordinates, (N, L, ..., 3).
    Returns:
        q:  Global coordinates, (N, L, ..., 3).
    """
    assert p.size(-1) == 3
    p_size = p.size()
    N, L = p_size[0], p_size[1]

    p = p.view(N, L, -1, 3).transpose(-1, -2)   # (N, L, *, 3) -> (N, L, 3, *)
    q = torch.matmul(R, p) + t.unsqueeze(-1)    # (N, L, 3, *)
    q = q.transpose(-1, -2).reshape(p_size)     # (N, L, 3, *) -> (N, L, *, 3) -> (N, L, ..., 3)
    return q


def global_to_local(R, t, q):
    """
    Description:
        Convert global (external) coordinates q to local (internal) coordinates p.
        p <- R^{T}(q - t)
    Args:
        R:  (N, L, 3, 3).
        t:  (N, L, 3).
        q:  Global coordinates, (N, L, ..., 3).
    Returns:
        p:  Local coordinates, (N, L, ..., 3).
    """
    assert q.size(-1) == 3
    q_size = q.size()
    N, L = q_size[0], q_size[1]

    q = q.reshape(N, L, -1, 3).transpose(-1, -2)    # (N, L, *, 3) -> (N, L, 3, *)
    p = torch.matmul(R.transpose(-1, -2), (q - t.unsqueeze(-1)))    # (N, L, 3, *)
    p = p.transpose(-1, -2).reshape(q_size)     # (N, L, 3, *) -> (N, L, *, 3) -> (N, L, ..., 3)
    return p


def apply_rotation_to_vector(R, p):
    return local_to_global(R, torch.zeros_like(p), p)


def compose_rotation_and_translation(R1, t1, R2, t2):
    """
    Args:
        R1,t1:  Frame basis and coordinate, (N, L, 3, 3), (N, L, 3).
        R2,t2:  Rotation and translation to be applied to (R1, t1), (N, L, 3, 3), (N, L, 3).
    Returns
        R_new <- R1R2
        t_new <- R1t2 + t1
    """
    R_new = torch.matmul(R1, R2)    # (N, L, 3, 3)
    t_new = torch.matmul(R1, t2.unsqueeze(-1)).squeeze(-1) + t1
    return R_new, t_new


def compose_chain(Ts):
    while len(Ts) >= 2:
        R1, t1 = Ts[-2]
        R2, t2 = Ts[-1]
        T_next = compose_rotation_and_translation(R1, t1, R2, t2)
        Ts = Ts[:-2] + [T_next]
    return Ts[0]


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
def quaternion_to_rotation_matrix(quaternions):
    """
    Convert rotations given as quaternions to rotation matrices.
    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).
    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    quaternions = F.normalize(quaternions, dim=-1)
    r, i, j, k = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
BSD License

For PyTorch3D software

Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

 * Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

 * Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

 * Neither the name Meta nor the names of its contributors may be used to
   endorse or promote products derived from this software without specific
   prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
def quaternion_1ijk_to_rotation_matrix(q):
    """
    (1 + ai + bj + ck) -> R
    Args:
        q: (..., 3)
    """
    b, c, d = torch.unbind(q, dim=-1)
    s = torch.sqrt(1 + b**2 + c**2 + d**2)
    a, b, c, d = 1/s, b/s, c/s, d/s

    o = torch.stack(
        (
            a**2 + b**2 - c**2 - d**2, 2*b*c - 2*a*d, 2*b*d + 2*a*c,
            2*b*c + 2*a*d, a**2 - b**2 + c**2 - d**2, 2*c*d - 2*a*b,
            2*b*d - 2*a*c, 2*c*d + 2*a*b, a**2 - b**2 - c**2 + d**2,
        ),
        -1,
    )
    return o.reshape(q.shape[:-1] + (3, 3))


def repr_6d_to_rotation_matrix(x):
    """
    Args:
        x:  6D representations, (..., 6).
    Returns:
        Rotation matrices, (..., 3, 3_index).
    """
    a1, a2 = x[..., 0:3], x[..., 3:6]
    b1 = normalize_vector(a1, dim=-1)
    b2 = normalize_vector(a2 - project_v2v(a2, b1, dim=-1), dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)

    mat = torch.cat([
        b1.unsqueeze(-1), b2.unsqueeze(-1), b3.unsqueeze(-1)
    ], dim=-1)  # (N, L, 3, 3_index)
    return mat


def dihedral_from_four_points(p0, p1, p2, p3):
    """
    Args:
        p0-3:   (*, 3).
    Returns:
        Dihedral angles in radian, (*, ).
    """
    v0 = p2 - p1
    v1 = p0 - p1
    v2 = p3 - p2
    u1 = torch.cross(v0, v1, dim=-1)
    n1 = u1 / torch.linalg.norm(u1, dim=-1, keepdim=True)
    u2 = torch.cross(v0, v2, dim=-1)
    n2 = u2 / torch.linalg.norm(u2, dim=-1, keepdim=True)
    sgn = torch.sign( (torch.cross(v1, v2, dim=-1) * v0).sum(-1) )
    dihed = sgn*torch.acos( (n1 * n2).sum(-1).clamp(min=-0.999999, max=0.999999) )
    dihed = torch.nan_to_num(dihed)
    return dihed


def knn_gather(idx, value):
    """
    Args:
        idx:    (B, N, K)
        value:  (B, M, d)
    Returns:
        (B, N, K, d)
    """
    N, d = idx.size(1), value.size(-1)
    idx = idx.unsqueeze(-1).repeat(1, 1, 1, d)      # (B, N, K, d)
    value = value.unsqueeze(1).repeat(1, N, 1, 1)   # (B, N, M, d)
    return torch.gather(value, dim=2, index=idx)


def knn_points(q, p, K):
    """
    Args:
        q:  (B, M, d)
        p:  (B, N, d)
    Returns:
        (B, M, K), (B, M, K), (B, M, K, d)
    """
    _, L, _ = p.size()
    d = pairwise_distances(q, p)    # (B, N, M)
    dist, idx = d.topk(min(L, K), dim=-1, largest=False)   # (B, M, K), (B, M, K)
    return dist, idx, knn_gather(idx, p)


def angstrom_to_nm(x):
    return x / 10


def nm_to_angstrom(x):
    return x * 10


def get_backbone_dihedral_angles(pos_atoms, chain_nb, res_nb, mask):
    """
    Args:
        pos_atoms:  (N, L, A, 3).
        chain_nb:   (N, L).
        res_nb:     (N, L).
        mask:       (N, L).
    Returns:
        bb_dihedral:    Omega, Phi, and Psi angles in radian, (N, L, 3).
        mask_bb_dihed:  Masks of dihedral angles, (N, L, 3).
    """
    pos_N = pos_atoms[:, :, BBHeavyAtom.N]   # (N, L, 3)
    pos_CA = pos_atoms[:, :, BBHeavyAtom.CA]
    pos_C = pos_atoms[:, :, BBHeavyAtom.C]

    N_term_flag, C_term_flag = get_terminus_flag(chain_nb, res_nb, mask)  # (N, L)
    omega_mask = torch.logical_not(N_term_flag)
    phi_mask = torch.logical_not(N_term_flag)
    psi_mask = torch.logical_not(C_term_flag)

    # N-termini don't have omega and phi
    omega = F.pad(
        dihedral_from_four_points(pos_CA[:, :-1], pos_C[:, :-1], pos_N[:, 1:], pos_CA[:, 1:]),
        pad=(1, 0), value=0,
    )
    phi = F.pad(
        dihedral_from_four_points(pos_C[:, :-1], pos_N[:, 1:], pos_CA[:, 1:], pos_C[:, 1:]),
        pad=(1, 0), value=0,
    )

    # C-termini don't have psi
    psi = F.pad(
        dihedral_from_four_points(pos_N[:, :-1], pos_CA[:, :-1], pos_C[:, :-1], pos_N[:, 1:]),
        pad=(0, 1), value=0,
    )

    mask_bb_dihed = torch.stack([omega_mask, phi_mask, psi_mask], dim=-1)
    bb_dihedral = torch.stack([omega, phi, psi], dim=-1) * mask_bb_dihed
    return bb_dihedral, mask_bb_dihed


def pairwise_dihedrals(pos_atoms):
    """
    Args:
        pos_atoms:  (N, L, A, 3).
    Returns:
        Inter-residue Phi and Psi angles, (N, L, L, 2).
    """
    N, L = pos_atoms.shape[:2]
    pos_N = pos_atoms[:, :, BBHeavyAtom.N]   # (N, L, 3)
    pos_CA = pos_atoms[:, :, BBHeavyAtom.CA]
    pos_C = pos_atoms[:, :, BBHeavyAtom.C]

    ir_phi = dihedral_from_four_points(
        pos_C[:,:,None].expand(N, L, L, 3),
        pos_N[:,None,:].expand(N, L, L, 3),
        pos_CA[:,None,:].expand(N, L, L, 3),
        pos_C[:,None,:].expand(N, L, L, 3)
    )
    ir_psi = dihedral_from_four_points(
        pos_N[:,:,None].expand(N, L, L, 3),
        pos_CA[:,:,None].expand(N, L, L, 3),
        pos_C[:,:,None].expand(N, L, L, 3),
        pos_N[:,None,:].expand(N, L, L, 3)
    )
    ir_dihed = torch.stack([ir_phi, ir_psi], dim=-1)
    return ir_dihed


def apply_rotation_matrix_to_rot6d(R, O):
    """
    Args:
        R:  (..., 3, 3)
        O:  (..., 6)
    Returns:
        Rotated 6D representation, (..., 6).
    """
    u1, u2 = O[..., :3, None], O[..., 3:, None]  # (..., 3, 1)
    v1 = torch.matmul(R, u1).squeeze(-1)         # (..., 3)
    v2 = torch.matmul(R, u2).squeeze(-1)
    return torch.cat([v1, v2], dim=-1)


def normalize_rot6d(O):
    """
    Args:
        O:  (..., 6)
    """
    u1, u2 = O[..., :3], O[..., 3:]     # (..., 3)
    v1 = F.normalize(u1, p=2, dim=-1)   # (..., 3)
    v2 = F.normalize(u2 - project_v2v(u2, v1), p=2, dim=-1)
    return torch.cat([v1, v2], dim=-1)


def reconstruct_backbone(R, t, aa, chain_nb, res_nb, mask):
    """
    Args:
        R:  (N, L, 3, 3)
        t:  (N, L, 3)
        aa: (N, L)
        chain_nb:   (N, L)
        res_nb:     (N, L)
        mask:       (N, L)
    Returns:
        Reconstructed backbone atoms, (N, L, 4, 3).
    """
    N, L = aa.size()
|
417 |
+
# atom_coords = restype_heavyatom_rigid_group_positions.clone().to(t) # (21, 14, 3)
|
418 |
+
bb_coords = backbone_atom_coordinates_tensor.clone().to(t) # (21, 3, 3)
|
419 |
+
oxygen_coord = bb_oxygen_coordinate_tensor.clone().to(t) # (21, 3)
|
420 |
+
aa = aa.clamp(min=0, max=20) # 20 for UNK
|
421 |
+
|
422 |
+
bb_coords = bb_coords[aa.flatten()].reshape(N, L, -1, 3) # (N, L, 3, 3)
|
423 |
+
oxygen_coord = oxygen_coord[aa.flatten()].reshape(N, L, -1) # (N, L, 3)
|
424 |
+
bb_pos = local_to_global(R, t, bb_coords) # Global coordinates of N, CA, C. (N, L, 3, 3).
|
425 |
+
|
426 |
+
# Compute PSI angle
|
427 |
+
bb_dihedral, _ = get_backbone_dihedral_angles(bb_pos, chain_nb, res_nb, mask)
|
428 |
+
psi = bb_dihedral[..., 2] # (N, L)
|
429 |
+
# Make rotation matrix for PSI
|
430 |
+
sin_psi = torch.sin(psi).reshape(N, L, 1, 1)
|
431 |
+
cos_psi = torch.cos(psi).reshape(N, L, 1, 1)
|
432 |
+
zero = torch.zeros_like(sin_psi)
|
433 |
+
one = torch.ones_like(sin_psi)
|
434 |
+
row1 = torch.cat([one, zero, zero], dim=-1) # (N, L, 1, 3)
|
435 |
+
row2 = torch.cat([zero, cos_psi, -sin_psi], dim=-1) # (N, L, 1, 3)
|
436 |
+
row3 = torch.cat([zero, sin_psi, cos_psi], dim=-1) # (N, L, 1, 3)
|
437 |
+
R_psi = torch.cat([row1, row2, row3], dim=-2) # (N, L, 3, 3)
|
438 |
+
|
439 |
+
# Compute rotoation and translation of PSI frame, and position of O.
|
440 |
+
R_psi, t_psi = compose_chain([
|
441 |
+
(R, t), # Backbone
|
442 |
+
(R_psi, torch.zeros_like(t)), # PSI angle
|
443 |
+
])
|
444 |
+
O_pos = local_to_global(R_psi, t_psi, oxygen_coord.reshape(N, L, 1, 3))
|
445 |
+
|
446 |
+
bb_pos = torch.cat([bb_pos, O_pos], dim=2) # (N, L, 4, 3)
|
447 |
+
return bb_pos
|
448 |
+
|
449 |
+
|
450 |
+
def reconstruct_backbone_partially(pos_ctx, R_new, t_new, aa, chain_nb, res_nb, mask_atoms, mask_recons):
|
451 |
+
"""
|
452 |
+
Args:
|
453 |
+
pos: (N, L, A, 3).
|
454 |
+
R_new: (N, L, 3, 3).
|
455 |
+
t_new: (N, L, 3).
|
456 |
+
mask_atoms: (N, L, A).
|
457 |
+
mask_recons:(N, L).
|
458 |
+
Returns:
|
459 |
+
pos_new: (N, L, A, 3).
|
460 |
+
mask_new: (N, L, A).
|
461 |
+
"""
|
462 |
+
N, L, A = mask_atoms.size()
|
463 |
+
|
464 |
+
mask_res = mask_atoms[:, :, BBHeavyAtom.CA]
|
465 |
+
pos_recons = reconstruct_backbone(R_new, t_new, aa, chain_nb, res_nb, mask_res) # (N, L, 4, 3)
|
466 |
+
pos_recons = F.pad(pos_recons, pad=(0, 0, 0, A-4), value=0) # (N, L, A, 3)
|
467 |
+
|
468 |
+
pos_new = torch.where(
|
469 |
+
mask_recons[:, :, None, None].expand_as(pos_ctx),
|
470 |
+
pos_recons, pos_ctx
|
471 |
+
) # (N, L, A, 3)
|
472 |
+
|
473 |
+
mask_bb_atoms = torch.zeros_like(mask_atoms)
|
474 |
+
mask_bb_atoms[:, :, :4] = True
|
475 |
+
mask_new = torch.where(
|
476 |
+
mask_recons[:, :, None].expand_as(mask_atoms),
|
477 |
+
mask_bb_atoms, mask_atoms
|
478 |
+
)
|
479 |
+
|
480 |
+
return pos_new, mask_new
|
481 |
+
|
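Illustrative usage sketch (not part of the diff; it assumes the diffab package above is importable): it checks that repr_6d_to_rotation_matrix yields orthonormal frames and evaluates dihedral_from_four_points on a planar four-point arrangement.

# Sanity-check the 6D rotation representation and the dihedral helper.
import torch
from diffab.modules.common.geometry import repr_6d_to_rotation_matrix, dihedral_from_four_points

x = torch.randn(5, 6)                      # random 6D rotation representations
R = repr_6d_to_rotation_matrix(x)          # (5, 3, 3)
I = torch.eye(3).expand_as(R)
print(torch.allclose(R @ R.transpose(-1, -2), I, atol=1e-4))  # True: columns are an orthonormal basis

# Four coplanar points on the same side of the central bond give a dihedral near 0.
p0 = torch.tensor([1.0, 1.0, 0.0])
p1 = torch.tensor([0.0, 1.0, 0.0])
p2 = torch.tensor([0.0, 0.0, 0.0])
p3 = torch.tensor([1.0, 0.0, 0.0])
print(dihedral_from_four_points(p0, p1, p2, p3))  # ~0 rad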
diffab/modules/common/layers.py
ADDED
@@ -0,0 +1,160 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def mask_zero(mask, value):
    return torch.where(mask, value, torch.zeros_like(value))


def clampped_one_hot(x, num_classes):
    mask = (x >= 0) & (x < num_classes)  # (N, L)
    x = x.clamp(min=0, max=num_classes-1)
    y = F.one_hot(x, num_classes) * mask[...,None]  # (N, L, C)
    return y


class DistanceToBins(nn.Module):

    def __init__(self, dist_min=0.0, dist_max=20.0, num_bins=64, use_onehot=False):
        super().__init__()
        self.dist_min = dist_min
        self.dist_max = dist_max
        self.num_bins = num_bins
        self.use_onehot = use_onehot

        if use_onehot:
            offset = torch.linspace(dist_min, dist_max, self.num_bins)
        else:
            offset = torch.linspace(dist_min, dist_max, self.num_bins-1)  # 1 overflow flag
            self.coeff = -0.5 / ((offset[1] - offset[0]) * 0.2).item() ** 2  # `*0.2`: makes it not too blurred
        self.register_buffer('offset', offset)

    @property
    def out_channels(self):
        return self.num_bins

    def forward(self, dist, dim, normalize=True):
        """
        Args:
            dist: (N, *, 1, *)
        Returns:
            (N, *, num_bins, *)
        """
        assert dist.size()[dim] == 1
        offset_shape = [1] * len(dist.size())
        offset_shape[dim] = -1

        if self.use_onehot:
            diff = torch.abs(dist - self.offset.view(*offset_shape))  # (N, *, num_bins, *)
            bin_idx = torch.argmin(diff, dim=dim, keepdim=True)  # (N, *, 1, *)
            y = torch.zeros_like(diff).scatter_(dim=dim, index=bin_idx, value=1.0)
        else:
            overflow_symb = (dist >= self.dist_max).float()  # (N, *, 1, *)
            y = dist - self.offset.view(*offset_shape)  # (N, *, num_bins-1, *)
            y = torch.exp(self.coeff * torch.pow(y, 2))  # (N, *, num_bins-1, *)
            y = torch.cat([y, overflow_symb], dim=dim)  # (N, *, num_bins, *)
            if normalize:
                y = y / y.sum(dim=dim, keepdim=True)

        return y


class PositionalEncoding(nn.Module):

    def __init__(self, num_funcs=6):
        super().__init__()
        self.num_funcs = num_funcs
        self.register_buffer('freq_bands', 2.0 ** torch.linspace(0.0, num_funcs-1, num_funcs))

    def get_out_dim(self, in_dim):
        return in_dim * (2 * self.num_funcs + 1)

    def forward(self, x):
        """
        Args:
            x: (..., d).
        """
        shape = list(x.shape[:-1]) + [-1]
        x = x.unsqueeze(-1)  # (..., d, 1)
        code = torch.cat([x, torch.sin(x * self.freq_bands), torch.cos(x * self.freq_bands)], dim=-1)  # (..., d, 2f+1)
        code = code.reshape(shape)
        return code


class AngularEncoding(nn.Module):

    def __init__(self, num_funcs=3):
        super().__init__()
        self.num_funcs = num_funcs
        self.register_buffer('freq_bands', torch.FloatTensor(
            [i+1 for i in range(num_funcs)] + [1./(i+1) for i in range(num_funcs)]
        ))

    def get_out_dim(self, in_dim):
        return in_dim * (1 + 2*2*self.num_funcs)

    def forward(self, x):
        """
        Args:
            x: (..., d).
        """
        shape = list(x.shape[:-1]) + [-1]
        x = x.unsqueeze(-1)  # (..., d, 1)
        code = torch.cat([x, torch.sin(x * self.freq_bands), torch.cos(x * self.freq_bands)], dim=-1)  # (..., d, 2f+1)
        code = code.reshape(shape)
        return code


class LayerNorm(nn.Module):

    def __init__(self,
                 normal_shape,
                 gamma=True,
                 beta=True,
                 epsilon=1e-10):
        """Layer normalization layer
        See: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf)
        :param normal_shape: The shape of the input tensor or the last dimension of the input tensor.
        :param gamma: Add a scale parameter if it is True.
        :param beta: Add an offset parameter if it is True.
        :param epsilon: Epsilon for calculating variance.
        """
        super().__init__()
        if isinstance(normal_shape, int):
            normal_shape = (normal_shape,)
        else:
            normal_shape = (normal_shape[-1],)
        self.normal_shape = torch.Size(normal_shape)
        self.epsilon = epsilon
        if gamma:
            self.gamma = nn.Parameter(torch.Tensor(*normal_shape))
        else:
            self.register_parameter('gamma', None)
        if beta:
            self.beta = nn.Parameter(torch.Tensor(*normal_shape))
        else:
            self.register_parameter('beta', None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.gamma is not None:
            self.gamma.data.fill_(1)
        if self.beta is not None:
            self.beta.data.zero_()

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
        std = (var + self.epsilon).sqrt()
        y = (x - mean) / std
        if self.gamma is not None:
            y *= self.gamma
        if self.beta is not None:
            y += self.beta
        return y

    def extra_repr(self):
        return 'normal_shape={}, gamma={}, beta={}, epsilon={}'.format(
            self.normal_shape, self.gamma is not None, self.beta is not None, self.epsilon,
        )

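Illustrative usage sketch (not part of the diff; it assumes the diffab package above is importable): it shows the expanded output size reported by AngularEncoding.get_out_dim and the soft distance binning of DistanceToBins.

import torch
from diffab.modules.common.layers import AngularEncoding, DistanceToBins

enc = AngularEncoding(num_funcs=3)
angles = torch.randn(2, 16, 2)             # e.g. two dihedral angles per position
code = enc(angles)
print(code.shape, enc.get_out_dim(2))      # torch.Size([2, 16, 26]) 26

dist2bins = DistanceToBins(dist_min=0.0, dist_max=20.0, num_bins=64)
d = torch.rand(2, 16, 1) * 25.0            # distances kept on a singleton last dim
bins = dist2bins(d, dim=-1)                # (2, 16, 64): soft bins plus an overflow bin
print(bins.shape)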
diffab/modules/common/so3.py
ADDED
@@ -0,0 +1,146 @@
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .geometry import quaternion_to_rotation_matrix


def log_rotation(R):
    trace = R[..., range(3), range(3)].sum(-1)
    if torch.is_grad_enabled():
        # The derivative of acos at -1.0 is -inf, so to stabilize the gradient, we use -0.999
        min_cos = -0.999
    else:
        min_cos = -1.0
    cos_theta = ( (trace-1) / 2 ).clamp_min(min=min_cos)
    sin_theta = torch.sqrt(1 - cos_theta**2)
    theta = torch.acos(cos_theta)
    coef = ((theta+1e-8)/(2*sin_theta+2e-8))[..., None, None]
    logR = coef * (R - R.transpose(-1, -2))
    return logR


def skewsym_to_so3vec(S):
    x = S[..., 1, 2]
    y = S[..., 2, 0]
    z = S[..., 0, 1]
    w = torch.stack([x,y,z], dim=-1)
    return w


def so3vec_to_skewsym(w):
    x, y, z = torch.unbind(w, dim=-1)
    o = torch.zeros_like(x)
    S = torch.stack([
        o, z, -y,
        -z, o, x,
        y, -x, o,
    ], dim=-1).reshape(w.shape[:-1] + (3, 3))
    return S


def exp_skewsym(S):
    x = torch.linalg.norm(skewsym_to_so3vec(S), dim=-1)
    I = torch.eye(3).to(S).view([1 for _ in range(S.dim()-2)] + [3, 3])

    sinx, cosx = torch.sin(x), torch.cos(x)
    b = (sinx + 1e-8) / (x + 1e-8)
    c = (1-cosx + 1e-8) / (x**2 + 2e-8)  # lim_{x->0} (1-cosx)/(x^2) = 0.5

    S2 = S @ S
    return I + b[..., None, None]*S + c[..., None, None]*S2


def so3vec_to_rotation(w):
    return exp_skewsym(so3vec_to_skewsym(w))


def rotation_to_so3vec(R):
    logR = log_rotation(R)
    w = skewsym_to_so3vec(logR)
    return w


def random_uniform_so3(size, device='cpu'):
    q = F.normalize(torch.randn(list(size)+[4,], device=device), dim=-1)  # (..., 4)
    return rotation_to_so3vec(quaternion_to_rotation_matrix(q))


class ApproxAngularDistribution(nn.Module):

    def __init__(self, stddevs, std_threshold=0.1, num_bins=8192, num_iters=1024):
        super().__init__()
        self.std_threshold = std_threshold
        self.num_bins = num_bins
        self.num_iters = num_iters
        self.register_buffer('stddevs', torch.FloatTensor(stddevs))
        self.register_buffer('approx_flag', self.stddevs <= std_threshold)
        self._precompute_histograms()

    @staticmethod
    def _pdf(x, e, L):
        """
        Args:
            x: (N, )
            e: Float
            L: Integer
        """
        x = x[:, None]  # (N, *)
        c = ((1 - torch.cos(x)) / math.pi)  # (N, *)
        l = torch.arange(0, L)[None, :]  # (*, L)
        a = (2*l+1) * torch.exp(-l*(l+1)*(e**2))  # (*, L)
        b = (torch.sin( (l+0.5)* x ) + 1e-6) / (torch.sin( x / 2 ) + 1e-6)  # (N, L)

        f = (c * a * b).sum(dim=1)
        return f

    def _precompute_histograms(self):
        X, Y = [], []
        for std in self.stddevs:
            std = std.item()
            x = torch.linspace(0, math.pi, self.num_bins)  # (n_bins,)
            y = self._pdf(x, std, self.num_iters)  # (n_bins,)
            y = torch.nan_to_num(y).clamp_min(0)
            X.append(x)
            Y.append(y)
        self.register_buffer('X', torch.stack(X, dim=0))  # (n_stddevs, n_bins)
        self.register_buffer('Y', torch.stack(Y, dim=0))  # (n_stddevs, n_bins)

    def sample(self, std_idx):
        """
        Args:
            std_idx: Indices of standard deviation.
        Returns:
            samples: Angular samples [0, PI), same size as std.
        """
        size = std_idx.size()
        std_idx = std_idx.flatten()  # (N,)

        # Samples from histogram
        prob = self.Y[std_idx]  # (N, n_bins)
        bin_idx = torch.multinomial(prob[:, :-1], num_samples=1).squeeze(-1)  # (N,)
        bin_start = self.X[std_idx, bin_idx]  # (N,)
        bin_width = self.X[std_idx, bin_idx+1] - self.X[std_idx, bin_idx]
        samples_hist = bin_start + torch.rand_like(bin_start) * bin_width  # (N,)

        # Samples from Gaussian approximation
        mean_gaussian = self.stddevs[std_idx]*2
        std_gaussian = self.stddevs[std_idx]
        samples_gaussian = mean_gaussian + torch.randn_like(mean_gaussian) * std_gaussian
        samples_gaussian = samples_gaussian.abs() % math.pi

        # Choose from histogram or Gaussian
        gaussian_flag = self.approx_flag[std_idx]
        samples = torch.where(gaussian_flag, samples_gaussian, samples_hist)

        return samples.reshape(size)


def random_normal_so3(std_idx, angular_distrib, device='cpu'):
    size = std_idx.size()
    u = F.normalize(torch.randn(list(size)+[3,], device=device), dim=-1)
    theta = angular_distrib.sample(std_idx)
    w = u * theta[..., None]
    return w

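Illustrative sketch (not part of the diff; it assumes the diffab package above is importable): for rotation angles well below pi, so3vec_to_rotation and rotation_to_so3vec invert each other up to numerical error.

import torch
from diffab.modules.common.so3 import so3vec_to_rotation, rotation_to_so3vec

w = torch.randn(4, 8, 3) * 0.5             # so(3) vectors (axis * angle), angles well below pi
R = so3vec_to_rotation(w)                  # (4, 8, 3, 3) rotation matrices
w_back = rotation_to_so3vec(R)             # log map back to so(3) vectors
print(torch.allclose(w, w_back, atol=1e-4))  # True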
diffab/modules/common/structure.py
ADDED
@@ -0,0 +1,77 @@
import torch
from torch.nn import Module, Linear, LayerNorm, Sequential, ReLU

from ..common.geometry import compose_rotation_and_translation, quaternion_to_rotation_matrix, repr_6d_to_rotation_matrix


class FrameRotationTranslationPrediction(Module):

    def __init__(self, feat_dim, rot_repr, nn_type='mlp'):
        super().__init__()
        assert rot_repr in ('quaternion', '6d')
        self.rot_repr = rot_repr
        if rot_repr == 'quaternion':
            out_dim = 3 + 3
        elif rot_repr == '6d':
            out_dim = 6 + 3

        if nn_type == 'linear':
            self.nn = Linear(feat_dim, out_dim)
        elif nn_type == 'mlp':
            self.nn = Sequential(
                Linear(feat_dim, feat_dim), ReLU(),
                Linear(feat_dim, feat_dim), ReLU(),
                Linear(feat_dim, out_dim)
            )
        else:
            raise ValueError('Unknown nn_type: %s' % nn_type)

    def forward(self, x):
        y = self.nn(x)  # (..., d+3)
        if self.rot_repr == 'quaternion':
            quaternion = torch.cat([torch.ones_like(y[..., :1]), y[..., 0:3]], dim=-1)
            R_delta = quaternion_to_rotation_matrix(quaternion)
            t_delta = y[..., 3:6]
            return R_delta, t_delta
        elif self.rot_repr == '6d':
            R_delta = repr_6d_to_rotation_matrix(y[..., 0:6])
            t_delta = y[..., 6:9]
            return R_delta, t_delta


class FrameUpdate(Module):

    def __init__(self, node_feat_dim, rot_repr='quaternion', rot_tran_nn_type='mlp'):
        super().__init__()
        self.transition_mlp = Sequential(
            Linear(node_feat_dim, node_feat_dim), ReLU(),
            Linear(node_feat_dim, node_feat_dim), ReLU(),
            Linear(node_feat_dim, node_feat_dim),
        )
        self.transition_layer_norm = LayerNorm(node_feat_dim)

        self.rot_tran = FrameRotationTranslationPrediction(node_feat_dim, rot_repr, nn_type=rot_tran_nn_type)

    def forward(self, R, t, x, mask_generate):
        """
        Args:
            R: Frame basis matrices, (N, L, 3, 3_index).
            t: Frame external (absolute) coordinates, (N, L, 3). Unit: Angstrom.
            x: Node-wise features, (N, L, F).
            mask_generate: Masks, (N, L).
        Returns:
            R': Updated basis matrices, (N, L, 3, 3_index).
            t': Updated coordinates, (N, L, 3).
        """
        x = self.transition_layer_norm(x + self.transition_mlp(x))

        R_delta, t_delta = self.rot_tran(x)  # (N, L, 3, 3), (N, L, 3)
        R_new, t_new = compose_rotation_and_translation(R, t, R_delta, t_delta)

        mask_R = mask_generate[:, :, None, None].expand_as(R)
        mask_t = mask_generate[:, :, None].expand_as(t)

        R_new = torch.where(mask_R, R_new, R)
        t_new = torch.where(mask_t, t_new, t)

        return R_new, t_new

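Illustrative usage sketch (not part of the diff; it assumes the diffab package above is importable): FrameUpdate predicts a per-residue rigid update from node features and applies it only where mask_generate is True, leaving context residues untouched.

import torch
from diffab.modules.common.structure import FrameUpdate

N, L, D = 2, 16, 64
R = torch.eye(3).expand(N, L, 3, 3)        # current frame orientations
t = torch.randn(N, L, 3)                   # current frame positions (Angstrom)
x = torch.randn(N, L, D)                   # node features
mask_generate = torch.zeros(N, L, dtype=torch.bool)
mask_generate[:, 4:8] = True               # only these residues are updated

frame_update = FrameUpdate(node_feat_dim=D, rot_repr='quaternion')
R_new, t_new = frame_update(R, t, x, mask_generate)
print(R_new.shape, t_new.shape)            # (2, 16, 3, 3) (2, 16, 3)
print(torch.equal(t_new[:, 0], t[:, 0]))   # True: context residues are unchanged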
diffab/modules/common/topology.py
ADDED
@@ -0,0 +1,24 @@
import torch
import torch.nn.functional as F


def get_consecutive_flag(chain_nb, res_nb, mask):
    """
    Args:
        chain_nb, res_nb
    Returns:
        consec: A flag tensor indicating whether residue-i is connected to residue-(i+1),
                BoolTensor, (B, L-1)[b, i].
    """
    d_res_nb = (res_nb[:, 1:] - res_nb[:, :-1]).abs()  # (B, L-1)
    same_chain = (chain_nb[:, 1:] == chain_nb[:, :-1])
    consec = torch.logical_and(d_res_nb == 1, same_chain)
    consec = torch.logical_and(consec, mask[:, :-1])
    return consec


def get_terminus_flag(chain_nb, res_nb, mask):
    consec = get_consecutive_flag(chain_nb, res_nb, mask)
    N_term_flag = F.pad(torch.logical_not(consec), pad=(1, 0), value=1)
    C_term_flag = F.pad(torch.logical_not(consec), pad=(0, 1), value=1)
    return N_term_flag, C_term_flag

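Illustrative sketch (not part of the diff; it assumes the diffab package above is importable): terminus flags for a toy two-chain layout, where chain breaks count as termini.

import torch
from diffab.modules.common.topology import get_terminus_flag

chain_nb = torch.tensor([[0, 0, 0, 1, 1]])
res_nb   = torch.tensor([[1, 2, 3, 1, 2]])
mask     = torch.ones(1, 5, dtype=torch.bool)
N_term, C_term = get_terminus_flag(chain_nb, res_nb, mask)
print(N_term)  # tensor([[ True, False, False,  True, False]])
print(C_term)  # tensor([[False, False,  True, False,  True]])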
diffab/modules/diffusion/dpm_full.py
ADDED
@@ -0,0 +1,319 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
from tqdm.auto import tqdm

from diffab.modules.common.geometry import apply_rotation_to_vector, quaternion_1ijk_to_rotation_matrix
from diffab.modules.common.so3 import so3vec_to_rotation, rotation_to_so3vec, random_uniform_so3
from diffab.modules.encoders.ga import GAEncoder
from .transition import RotationTransition, PositionTransition, AminoacidCategoricalTransition


def rotation_matrix_cosine_loss(R_pred, R_true):
    """
    Args:
        R_pred: (*, 3, 3).
        R_true: (*, 3, 3).
    Returns:
        Per-matrix losses, (*, ).
    """
    size = list(R_pred.shape[:-2])
    ncol = R_pred.numel() // 3

    RT_pred = R_pred.transpose(-2, -1).reshape(ncol, 3)  # (ncol, 3)
    RT_true = R_true.transpose(-2, -1).reshape(ncol, 3)  # (ncol, 3)

    ones = torch.ones([ncol, ], dtype=torch.long, device=R_pred.device)
    loss = F.cosine_embedding_loss(RT_pred, RT_true, ones, reduction='none')  # (ncol*3, )
    loss = loss.reshape(size + [3]).sum(dim=-1)  # (*, )
    return loss


class EpsilonNet(nn.Module):

    def __init__(self, res_feat_dim, pair_feat_dim, num_layers, encoder_opt={}):
        super().__init__()
        self.current_sequence_embedding = nn.Embedding(25, res_feat_dim)  # 22 is padding
        self.res_feat_mixer = nn.Sequential(
            nn.Linear(res_feat_dim * 2, res_feat_dim), nn.ReLU(),
            nn.Linear(res_feat_dim, res_feat_dim),
        )
        self.encoder = GAEncoder(res_feat_dim, pair_feat_dim, num_layers, **encoder_opt)

        self.eps_crd_net = nn.Sequential(
            nn.Linear(res_feat_dim+3, res_feat_dim), nn.ReLU(),
            nn.Linear(res_feat_dim, res_feat_dim), nn.ReLU(),
            nn.Linear(res_feat_dim, 3)
        )

        self.eps_rot_net = nn.Sequential(
            nn.Linear(res_feat_dim+3, res_feat_dim), nn.ReLU(),
            nn.Linear(res_feat_dim, res_feat_dim), nn.ReLU(),
            nn.Linear(res_feat_dim, 3)
        )

        self.eps_seq_net = nn.Sequential(
            nn.Linear(res_feat_dim+3, res_feat_dim), nn.ReLU(),
            nn.Linear(res_feat_dim, res_feat_dim), nn.ReLU(),
            nn.Linear(res_feat_dim, 20), nn.Softmax(dim=-1)
        )

    def forward(self, v_t, p_t, s_t, res_feat, pair_feat, beta, mask_generate, mask_res):
        """
        Args:
            v_t: (N, L, 3).
            p_t: (N, L, 3).
            s_t: (N, L).
            res_feat: (N, L, res_dim).
            pair_feat: (N, L, L, pair_dim).
            beta: (N,).
            mask_generate: (N, L).
            mask_res: (N, L).
        Returns:
            v_next: UPDATED (not epsilon) SO3-vector of orientations, (N, L, 3).
            eps_pos: (N, L, 3).
        """
        N, L = mask_res.size()
        R = so3vec_to_rotation(v_t)  # (N, L, 3, 3)

        # s_t = s_t.clamp(min=0, max=19)  # TODO: clamping is good but ugly.
        res_feat = self.res_feat_mixer(torch.cat([res_feat, self.current_sequence_embedding(s_t)], dim=-1))  # [Important] Incorporate sequence at the current step.
        res_feat = self.encoder(R, p_t, res_feat, pair_feat, mask_res)

        t_embed = torch.stack([beta, torch.sin(beta), torch.cos(beta)], dim=-1)[:, None, :].expand(N, L, 3)
        in_feat = torch.cat([res_feat, t_embed], dim=-1)

        # Position changes
        eps_crd = self.eps_crd_net(in_feat)  # (N, L, 3)
        eps_pos = apply_rotation_to_vector(R, eps_crd)  # (N, L, 3)
        eps_pos = torch.where(mask_generate[:, :, None].expand_as(eps_pos), eps_pos, torch.zeros_like(eps_pos))

        # New orientation
        eps_rot = self.eps_rot_net(in_feat)  # (N, L, 3)
        U = quaternion_1ijk_to_rotation_matrix(eps_rot)  # (N, L, 3, 3)
        R_next = R @ U
        v_next = rotation_to_so3vec(R_next)  # (N, L, 3)
        v_next = torch.where(mask_generate[:, :, None].expand_as(v_next), v_next, v_t)

        # New sequence categorical distributions
        c_denoised = self.eps_seq_net(in_feat)  # Already softmax-ed, (N, L, 20)

        return v_next, R_next, eps_pos, c_denoised


class FullDPM(nn.Module):

    def __init__(
        self,
        res_feat_dim,
        pair_feat_dim,
        num_steps,
        eps_net_opt={},
        trans_rot_opt={},
        trans_pos_opt={},
        trans_seq_opt={},
        position_mean=[0.0, 0.0, 0.0],
        position_scale=[10.0],
    ):
        super().__init__()
        self.eps_net = EpsilonNet(res_feat_dim, pair_feat_dim, **eps_net_opt)
        self.num_steps = num_steps
        self.trans_rot = RotationTransition(num_steps, **trans_rot_opt)
        self.trans_pos = PositionTransition(num_steps, **trans_pos_opt)
        self.trans_seq = AminoacidCategoricalTransition(num_steps, **trans_seq_opt)

        self.register_buffer('position_mean', torch.FloatTensor(position_mean).view(1, 1, -1))
        self.register_buffer('position_scale', torch.FloatTensor(position_scale).view(1, 1, -1))
        self.register_buffer('_dummy', torch.empty([0, ]))

    def _normalize_position(self, p):
        p_norm = (p - self.position_mean) / self.position_scale
        return p_norm

    def _unnormalize_position(self, p_norm):
        p = p_norm * self.position_scale + self.position_mean
        return p

    def forward(self, v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res, denoise_structure, denoise_sequence, t=None):
        N, L = res_feat.shape[:2]
        if t == None:
            t = torch.randint(0, self.num_steps, (N,), dtype=torch.long, device=self._dummy.device)
        p_0 = self._normalize_position(p_0)

        if denoise_structure:
            # Add noise to rotation
            R_0 = so3vec_to_rotation(v_0)
            v_noisy, _ = self.trans_rot.add_noise(v_0, mask_generate, t)
            # Add noise to positions
            p_noisy, eps_p = self.trans_pos.add_noise(p_0, mask_generate, t)
        else:
            R_0 = so3vec_to_rotation(v_0)
            v_noisy = v_0.clone()
            p_noisy = p_0.clone()
            eps_p = torch.zeros_like(p_noisy)

        if denoise_sequence:
            # Add noise to sequence
            _, s_noisy = self.trans_seq.add_noise(s_0, mask_generate, t)
        else:
            s_noisy = s_0.clone()

        beta = self.trans_pos.var_sched.betas[t]
        v_pred, R_pred, eps_p_pred, c_denoised = self.eps_net(
            v_noisy, p_noisy, s_noisy, res_feat, pair_feat, beta, mask_generate, mask_res
        )   # (N, L, 3), (N, L, 3, 3), (N, L, 3), (N, L, 20), (N, L)

        loss_dict = {}

        # Rotation loss
        loss_rot = rotation_matrix_cosine_loss(R_pred, R_0)  # (N, L)
        loss_rot = (loss_rot * mask_generate).sum() / (mask_generate.sum().float() + 1e-8)
        loss_dict['rot'] = loss_rot

        # Position loss
        loss_pos = F.mse_loss(eps_p_pred, eps_p, reduction='none').sum(dim=-1)  # (N, L)
        loss_pos = (loss_pos * mask_generate).sum() / (mask_generate.sum().float() + 1e-8)
        loss_dict['pos'] = loss_pos

        # Sequence categorical loss
        post_true = self.trans_seq.posterior(s_noisy, s_0, t)
        log_post_pred = torch.log(self.trans_seq.posterior(s_noisy, c_denoised, t) + 1e-8)
        kldiv = F.kl_div(
            input=log_post_pred,
            target=post_true,
            reduction='none',
            log_target=False
        ).sum(dim=-1)   # (N, L)
        loss_seq = (kldiv * mask_generate).sum() / (mask_generate.sum().float() + 1e-8)
        loss_dict['seq'] = loss_seq

        return loss_dict

    @torch.no_grad()
    def sample(
        self,
        v, p, s,
        res_feat, pair_feat,
        mask_generate, mask_res,
        sample_structure=True, sample_sequence=True,
        pbar=False,
    ):
        """
        Args:
            v: Orientations of contextual residues, (N, L, 3).
            p: Positions of contextual residues, (N, L, 3).
            s: Sequence of contextual residues, (N, L).
        """
        N, L = v.shape[:2]
        p = self._normalize_position(p)

        # Set the orientation and position of residues to be predicted to random values
        if sample_structure:
            v_rand = random_uniform_so3([N, L], device=self._dummy.device)
            p_rand = torch.randn_like(p)
            v_init = torch.where(mask_generate[:, :, None].expand_as(v), v_rand, v)
            p_init = torch.where(mask_generate[:, :, None].expand_as(p), p_rand, p)
        else:
            v_init, p_init = v, p

        if sample_sequence:
            s_rand = torch.randint_like(s, low=0, high=19)
            s_init = torch.where(mask_generate, s_rand, s)
        else:
            s_init = s

        traj = {self.num_steps: (v_init, self._unnormalize_position(p_init), s_init)}
        if pbar:
            pbar = functools.partial(tqdm, total=self.num_steps, desc='Sampling')
        else:
            pbar = lambda x: x
        for t in pbar(range(self.num_steps, 0, -1)):
            v_t, p_t, s_t = traj[t]
            p_t = self._normalize_position(p_t)

            beta = self.trans_pos.var_sched.betas[t].expand([N, ])
            t_tensor = torch.full([N, ], fill_value=t, dtype=torch.long, device=self._dummy.device)

            v_next, R_next, eps_p, c_denoised = self.eps_net(
                v_t, p_t, s_t, res_feat, pair_feat, beta, mask_generate, mask_res
            )   # (N, L, 3), (N, L, 3, 3), (N, L, 3)

            v_next = self.trans_rot.denoise(v_t, v_next, mask_generate, t_tensor)
            p_next = self.trans_pos.denoise(p_t, eps_p, mask_generate, t_tensor)
            _, s_next = self.trans_seq.denoise(s_t, c_denoised, mask_generate, t_tensor)

            if not sample_structure:
                v_next, p_next = v_t, p_t
            if not sample_sequence:
                s_next = s_t

            traj[t-1] = (v_next, self._unnormalize_position(p_next), s_next)
            traj[t] = tuple(x.cpu() for x in traj[t])  # Move previous states to cpu memory.

        return traj

    @torch.no_grad()
    def optimize(
        self,
        v, p, s,
        opt_step: int,
        res_feat, pair_feat,
        mask_generate, mask_res,
        sample_structure=True, sample_sequence=True,
        pbar=False,
    ):
        """
        Description:
            First adds noise to the given structure, then denoises it.
        """
        N, L = v.shape[:2]
        p = self._normalize_position(p)
        t = torch.full([N, ], fill_value=opt_step, dtype=torch.long, device=self._dummy.device)

        # Set the orientation and position of residues to be predicted to random values
        if sample_structure:
            # Add noise to rotation
            v_noisy, _ = self.trans_rot.add_noise(v, mask_generate, t)
            # Add noise to positions
            p_noisy, _ = self.trans_pos.add_noise(p, mask_generate, t)
            v_init = torch.where(mask_generate[:, :, None].expand_as(v), v_noisy, v)
            p_init = torch.where(mask_generate[:, :, None].expand_as(p), p_noisy, p)
        else:
            v_init, p_init = v, p

        if sample_sequence:
            _, s_noisy = self.trans_seq.add_noise(s, mask_generate, t)
            s_init = torch.where(mask_generate, s_noisy, s)
        else:
            s_init = s

        traj = {opt_step: (v_init, self._unnormalize_position(p_init), s_init)}
        if pbar:
            pbar = functools.partial(tqdm, total=opt_step, desc='Optimizing')
        else:
            pbar = lambda x: x
        for t in pbar(range(opt_step, 0, -1)):
            v_t, p_t, s_t = traj[t]
            p_t = self._normalize_position(p_t)

            beta = self.trans_pos.var_sched.betas[t].expand([N, ])
            t_tensor = torch.full([N, ], fill_value=t, dtype=torch.long, device=self._dummy.device)

            v_next, R_next, eps_p, c_denoised = self.eps_net(
                v_t, p_t, s_t, res_feat, pair_feat, beta, mask_generate, mask_res
            )   # (N, L, 3), (N, L, 3, 3), (N, L, 3)

            v_next = self.trans_rot.denoise(v_t, v_next, mask_generate, t_tensor)
            p_next = self.trans_pos.denoise(p_t, eps_p, mask_generate, t_tensor)
            _, s_next = self.trans_seq.denoise(s_t, c_denoised, mask_generate, t_tensor)

            if not sample_structure:
                v_next, p_next = v_t, p_t
            if not sample_sequence:
                s_next = s_t

            traj[t-1] = (v_next, self._unnormalize_position(p_next), s_next)
            traj[t] = tuple(x.cpu() for x in traj[t])  # Move previous states to cpu memory.

        return traj

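Illustrative usage sketch (not part of the diff; it assumes the diffab package above is importable): wiring FullDPM to random context tensors to show the call pattern of the training forward pass and of sample(). All dimensions and the masked span are made up; the real model wrapper supplies residue and pair features from the encoders.

import torch
from diffab.modules.diffusion.dpm_full import FullDPM

N, L, Dres, Dpair = 1, 32, 128, 64
dpm = FullDPM(Dres, Dpair, num_steps=100, eps_net_opt={'num_layers': 2})

v = torch.zeros(N, L, 3)                                   # SO(3) vectors of context orientations
p = torch.randn(N, L, 3) * 5.0                             # CA positions (Angstrom)
s = torch.randint(0, 20, (N, L))                           # amino acid types
res_feat = torch.randn(N, L, Dres)
pair_feat = torch.randn(N, L, L, Dpair)
mask_res = torch.ones(N, L, dtype=torch.bool)
mask_generate = torch.zeros(N, L, dtype=torch.bool)
mask_generate[:, 10:18] = True                             # the CDR-like span to regenerate

loss_dict = dpm(v, p, s, res_feat, pair_feat, mask_generate, mask_res,
                denoise_structure=True, denoise_sequence=True)
print({k: float(val) for k, val in loss_dict.items()})     # rot / pos / seq losses

traj = dpm.sample(v, p, s, res_feat, pair_feat, mask_generate, mask_res, pbar=False)
v0, p0, s0 = traj[0]                                       # denoised orientations, positions, sequence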
diffab/modules/diffusion/transition.py
ADDED
@@ -0,0 +1,223 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffab.modules.common.layers import clampped_one_hot
from diffab.modules.common.so3 import ApproxAngularDistribution, random_normal_so3, so3vec_to_rotation, rotation_to_so3vec


class VarianceSchedule(nn.Module):

    def __init__(self, num_steps=100, s=0.01):
        super().__init__()
        T = num_steps
        t = torch.arange(0, num_steps+1, dtype=torch.float)
        f_t = torch.cos( (np.pi / 2) * ((t/T) + s) / (1 + s) ) ** 2
        alpha_bars = f_t / f_t[0]

        betas = 1 - (alpha_bars[1:] / alpha_bars[:-1])
        betas = torch.cat([torch.zeros([1]), betas], dim=0)
        betas = betas.clamp_max(0.999)

        sigmas = torch.zeros_like(betas)
        for i in range(1, betas.size(0)):
            sigmas[i] = ((1 - alpha_bars[i-1]) / (1 - alpha_bars[i])) * betas[i]
        sigmas = torch.sqrt(sigmas)

        self.register_buffer('betas', betas)
        self.register_buffer('alpha_bars', alpha_bars)
        self.register_buffer('alphas', 1 - betas)
        self.register_buffer('sigmas', sigmas)


class PositionTransition(nn.Module):

    def __init__(self, num_steps, var_sched_opt={}):
        super().__init__()
        self.var_sched = VarianceSchedule(num_steps, **var_sched_opt)

    def add_noise(self, p_0, mask_generate, t):
        """
        Args:
            p_0: (N, L, 3).
            mask_generate: (N, L).
            t: (N,).
        """
        alpha_bar = self.var_sched.alpha_bars[t]

        c0 = torch.sqrt(alpha_bar).view(-1, 1, 1)
        c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1)

        e_rand = torch.randn_like(p_0)
        p_noisy = c0*p_0 + c1*e_rand
        p_noisy = torch.where(mask_generate[..., None].expand_as(p_0), p_noisy, p_0)

        return p_noisy, e_rand

    def denoise(self, p_t, eps_p, mask_generate, t):
        # IMPORTANT:
        #   clamping alpha fixes the instability issue at the first step (t=T);
        #   it seems like a problem with the ``improved ddpm''.
        alpha = self.var_sched.alphas[t].clamp_min(
            self.var_sched.alphas[-2]
        )
        alpha_bar = self.var_sched.alpha_bars[t]
        sigma = self.var_sched.sigmas[t].view(-1, 1, 1)

        c0 = ( 1.0 / torch.sqrt(alpha + 1e-8) ).view(-1, 1, 1)
        c1 = ( (1 - alpha) / torch.sqrt(1 - alpha_bar + 1e-8) ).view(-1, 1, 1)

        z = torch.where(
            (t > 1)[:, None, None].expand_as(p_t),
            torch.randn_like(p_t),
            torch.zeros_like(p_t),
        )

        p_next = c0 * (p_t - c1 * eps_p) + sigma * z
        p_next = torch.where(mask_generate[..., None].expand_as(p_t), p_next, p_t)
        return p_next


class RotationTransition(nn.Module):

    def __init__(self, num_steps, var_sched_opt={}, angular_distrib_fwd_opt={}, angular_distrib_inv_opt={}):
        super().__init__()
        self.var_sched = VarianceSchedule(num_steps, **var_sched_opt)

        # Forward (perturb)
        c1 = torch.sqrt(1 - self.var_sched.alpha_bars)  # (T,).
        self.angular_distrib_fwd = ApproxAngularDistribution(c1.tolist(), **angular_distrib_fwd_opt)

        # Inverse (generate)
        sigma = self.var_sched.sigmas
        self.angular_distrib_inv = ApproxAngularDistribution(sigma.tolist(), **angular_distrib_inv_opt)

        self.register_buffer('_dummy', torch.empty([0, ]))

    def add_noise(self, v_0, mask_generate, t):
        """
        Args:
            v_0: (N, L, 3).
            mask_generate: (N, L).
            t: (N,).
        """
        N, L = mask_generate.size()
        alpha_bar = self.var_sched.alpha_bars[t]
        c0 = torch.sqrt(alpha_bar).view(-1, 1, 1)
        c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1)

        # Noise rotation
        e_scaled = random_normal_so3(t[:, None].expand(N, L), self.angular_distrib_fwd, device=self._dummy.device)  # (N, L, 3)
        e_normal = e_scaled / (c1 + 1e-8)
        E_scaled = so3vec_to_rotation(e_scaled)  # (N, L, 3, 3)

        # Scaled true rotation
        R0_scaled = so3vec_to_rotation(c0 * v_0)  # (N, L, 3, 3)

        R_noisy = E_scaled @ R0_scaled
        v_noisy = rotation_to_so3vec(R_noisy)
        v_noisy = torch.where(mask_generate[..., None].expand_as(v_0), v_noisy, v_0)

        return v_noisy, e_scaled

    def denoise(self, v_t, v_next, mask_generate, t):
        N, L = mask_generate.size()
        e = random_normal_so3(t[:, None].expand(N, L), self.angular_distrib_inv, device=self._dummy.device)  # (N, L, 3)
        e = torch.where(
            (t > 1)[:, None, None].expand(N, L, 3),
            e,
            torch.zeros_like(e)  # Simply denoise and don't add noise at the last step
        )
        E = so3vec_to_rotation(e)

        R_next = E @ so3vec_to_rotation(v_next)
        v_next = rotation_to_so3vec(R_next)
        v_next = torch.where(mask_generate[..., None].expand_as(v_next), v_next, v_t)

        return v_next


class AminoacidCategoricalTransition(nn.Module):

    def __init__(self, num_steps, num_classes=20, var_sched_opt={}):
        super().__init__()
        self.num_classes = num_classes
        self.var_sched = VarianceSchedule(num_steps, **var_sched_opt)

    @staticmethod
    def _sample(c):
        """
        Args:
            c: (N, L, K).
        Returns:
            x: (N, L).
        """
        N, L, K = c.size()
        c = c.view(N*L, K) + 1e-8
        x = torch.multinomial(c, 1).view(N, L)
        return x

    def add_noise(self, x_0, mask_generate, t):
        """
        Args:
            x_0: (N, L)
            mask_generate: (N, L).
            t: (N,).
        Returns:
            c_t: Probability, (N, L, K).
            x_t: Sample, LongTensor, (N, L).
        """
        N, L = x_0.size()
        K = self.num_classes
        c_0 = clampped_one_hot(x_0, num_classes=K).float()  # (N, L, K).
        alpha_bar = self.var_sched.alpha_bars[t][:, None, None]  # (N, 1, 1)
        c_noisy = (alpha_bar*c_0) + ( (1-alpha_bar)/K )
        c_t = torch.where(mask_generate[..., None].expand(N,L,K), c_noisy, c_0)
        x_t = self._sample(c_t)
        return c_t, x_t

    def posterior(self, x_t, x_0, t):
        """
        Args:
            x_t: Category LongTensor (N, L) or Probability FloatTensor (N, L, K).
            x_0: Category LongTensor (N, L) or Probability FloatTensor (N, L, K).
            t: (N,).
        Returns:
            theta: Posterior probability at (t-1)-th step, (N, L, K).
        """
        K = self.num_classes

        if x_t.dim() == 3:
            c_t = x_t   # When x_t is probability distribution.
        else:
            c_t = clampped_one_hot(x_t, num_classes=K).float()  # (N, L, K)

        if x_0.dim() == 3:
            c_0 = x_0   # When x_0 is probability distribution.
        else:
            c_0 = clampped_one_hot(x_0, num_classes=K).float()  # (N, L, K)

        alpha = self.var_sched.alpha_bars[t][:, None, None]      # (N, 1, 1)
        alpha_bar = self.var_sched.alpha_bars[t][:, None, None]  # (N, 1, 1)

        theta = ((alpha*c_t) + (1-alpha)/K) * ((alpha_bar*c_0) + (1-alpha_bar)/K)  # (N, L, K)
        theta = theta / (theta.sum(dim=-1, keepdim=True) + 1e-8)
        return theta

    def denoise(self, x_t, c_0_pred, mask_generate, t):
        """
        Args:
            x_t: (N, L).
            c_0_pred: Normalized probability predicted by networks, (N, L, K).
            mask_generate: (N, L).
            t: (N,).
        Returns:
            post: Posterior probability at (t-1)-th step, (N, L, K).
            x_next: Sample at (t-1)-th step, LongTensor, (N, L).
        """
        c_t = clampped_one_hot(x_t, num_classes=self.num_classes).float()  # (N, L, K)
        post = self.posterior(c_t, c_0_pred, t=t)   # (N, L, K)
        post = torch.where(mask_generate[..., None].expand(post.size()), post, c_t)
        x_next = self._sample(post)
        return post, x_next

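Illustrative sketch (not part of the diff; it assumes the diffab package above is importable): inspecting the cosine variance schedule and running one forward (noising) step of PositionTransition near t = T, where the signal is mostly replaced by noise.

import torch
from diffab.modules.diffusion.transition import VarianceSchedule, PositionTransition

sched = VarianceSchedule(num_steps=100, s=0.01)
print(sched.betas[:3], sched.betas[-1])           # beta_0 = 0, small early betas, last one capped at 0.999
print(sched.alpha_bars[0], sched.alpha_bars[-1])  # 1.0 -> close to 0

trans = PositionTransition(num_steps=100)
p0 = torch.randn(2, 16, 3)
mask = torch.ones(2, 16, dtype=torch.bool)
t = torch.full((2,), 99, dtype=torch.long)
p_noisy, eps = trans.add_noise(p0, mask, t)       # near t=T the output is dominated by eps
print(p_noisy.shape, eps.shape)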
diffab/modules/encoders/ga.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from diffab.modules.common.geometry import global_to_local, local_to_global, normalize_vector, construct_3d_basis, angstrom_to_nm
|
7 |
+
from diffab.modules.common.layers import mask_zero, LayerNorm
|
8 |
+
from diffab.utils.protein.constants import BBHeavyAtom
|
9 |
+
|
10 |
+
|
11 |
+
def _alpha_from_logits(logits, mask, inf=1e5):
|
12 |
+
"""
|
13 |
+
Args:
|
14 |
+
logits: Logit matrices, (N, L_i, L_j, num_heads).
|
15 |
+
mask: Masks, (N, L).
|
16 |
+
Returns:
|
17 |
+
alpha: Attention weights.
|
18 |
+
"""
|
19 |
+
N, L, _, _ = logits.size()
|
20 |
+
mask_row = mask.view(N, L, 1, 1).expand_as(logits) # (N, L, *, *)
|
21 |
+
mask_pair = mask_row * mask_row.permute(0, 2, 1, 3) # (N, L, L, *)
|
22 |
+
|
23 |
+
logits = torch.where(mask_pair, logits, logits - inf)
|
24 |
+
alpha = torch.softmax(logits, dim=2) # (N, L, L, num_heads)
|
25 |
+
alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
|
26 |
+
return alpha
|
27 |
+
|
28 |
+
|
29 |
+
def _heads(x, n_heads, n_ch):
|
30 |
+
"""
|
31 |
+
Args:
|
32 |
+
x: (..., num_heads * num_channels)
|
33 |
+
Returns:
|
34 |
+
(..., num_heads, num_channels)
|
35 |
+
"""
|
36 |
+
s = list(x.size())[:-1] + [n_heads, n_ch]
|
37 |
+
return x.view(*s)
|
38 |
+
|
39 |
+
|
40 |
+
class GABlock(nn.Module):
|
41 |
+
|
42 |
+
def __init__(self, node_feat_dim, pair_feat_dim, value_dim=32, query_key_dim=32, num_query_points=8,
|
43 |
+
num_value_points=8, num_heads=12, bias=False):
|
44 |
+
super().__init__()
|
45 |
+
self.node_feat_dim = node_feat_dim
|
46 |
+
self.pair_feat_dim = pair_feat_dim
|
47 |
+
self.value_dim = value_dim
|
48 |
+
self.query_key_dim = query_key_dim
|
49 |
+
self.num_query_points = num_query_points
|
50 |
+
self.num_value_points = num_value_points
|
51 |
+
self.num_heads = num_heads
|
52 |
+
|
53 |
+
# Node
|
54 |
+
self.proj_query = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias)
|
55 |
+
self.proj_key = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias)
|
56 |
+
self.proj_value = nn.Linear(node_feat_dim, value_dim * num_heads, bias=bias)
|
57 |
+
|
58 |
+
# Pair
|
59 |
+
self.proj_pair_bias = nn.Linear(pair_feat_dim, num_heads, bias=bias)
|
60 |
+
|
61 |
+
# Spatial
|
62 |
+
self.spatial_coef = nn.Parameter(torch.full([1, 1, 1, self.num_heads], fill_value=np.log(np.exp(1.) - 1.)),
|
63 |
+
requires_grad=True)
|
64 |
+
self.proj_query_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias)
|
65 |
+
self.proj_key_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias)
|
66 |
+
self.proj_value_point = nn.Linear(node_feat_dim, num_value_points * num_heads * 3, bias=bias)
|
67 |
+
|
68 |
+
# Output
|
69 |
+
self.out_transform = nn.Linear(
|
70 |
+
in_features=(num_heads * pair_feat_dim) + (num_heads * value_dim) + (
|
71 |
+
num_heads * num_value_points * (3 + 3 + 1)),
|
72 |
+
out_features=node_feat_dim,
|
73 |
+
)
|
74 |
+
|
75 |
+
self.layer_norm_1 = LayerNorm(node_feat_dim)
|
76 |
+
self.mlp_transition = nn.Sequential(nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
|
77 |
+
nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
|
78 |
+
nn.Linear(node_feat_dim, node_feat_dim))
|
79 |
+
self.layer_norm_2 = LayerNorm(node_feat_dim)
|
80 |
+
|
81 |
+
def _node_logits(self, x):
|
82 |
+
query_l = _heads(self.proj_query(x), self.num_heads, self.query_key_dim) # (N, L, n_heads, qk_ch)
|
83 |
+
key_l = _heads(self.proj_key(x), self.num_heads, self.query_key_dim) # (N, L, n_heads, qk_ch)
|
84 |
+
logits_node = (query_l.unsqueeze(2) * key_l.unsqueeze(1) *
|
85 |
+
(1 / np.sqrt(self.query_key_dim))).sum(-1) # (N, L, L, num_heads)
|
86 |
+
return logits_node
|
87 |
+
|
88 |
+
def _pair_logits(self, z):
|
89 |
+
logits_pair = self.proj_pair_bias(z)
|
90 |
+
return logits_pair
|
91 |
+
|
92 |
+
def _spatial_logits(self, R, t, x):
|
93 |
+
N, L, _ = t.size()
|
94 |
+
|
95 |
+
# Query
|
96 |
+
query_points = _heads(self.proj_query_point(x), self.num_heads * self.num_query_points,
|
97 |
+
3) # (N, L, n_heads * n_pnts, 3)
|
98 |
+
query_points = local_to_global(R, t, query_points) # Global query coordinates, (N, L, n_heads * n_pnts, 3)
|
99 |
+
query_s = query_points.reshape(N, L, self.num_heads, -1) # (N, L, n_heads, n_pnts*3)
|
100 |
+
|
101 |
+
# Key
|
102 |
+
key_points = _heads(self.proj_key_point(x), self.num_heads * self.num_query_points,
|
103 |
+
3) # (N, L, 3, n_heads * n_pnts)
|
104 |
+
key_points = local_to_global(R, t, key_points) # Global key coordinates, (N, L, n_heads * n_pnts, 3)
|
105 |
+
key_s = key_points.reshape(N, L, self.num_heads, -1) # (N, L, n_heads, n_pnts*3)
|
106 |
+
|
107 |
+
# Q-K Product
|
108 |
+
sum_sq_dist = ((query_s.unsqueeze(2) - key_s.unsqueeze(1)) ** 2).sum(-1) # (N, L, L, n_heads)
|
109 |
+
gamma = F.softplus(self.spatial_coef)
|
110 |
+
logits_spatial = sum_sq_dist * ((-1 * gamma * np.sqrt(2 / (9 * self.num_query_points)))
|
111 |
+
/ 2) # (N, L, L, n_heads)
|
112 |
+
return logits_spatial
|
113 |
+
|
114 |
+
def _pair_aggregation(self, alpha, z):
|
115 |
+
N, L = z.shape[:2]
|
116 |
+
feat_p2n = alpha.unsqueeze(-1) * z.unsqueeze(-2) # (N, L, L, n_heads, C)
|
117 |
+
        feat_p2n = feat_p2n.sum(dim=2)  # (N, L, n_heads, C)
        return feat_p2n.reshape(N, L, -1)

    def _node_aggregation(self, alpha, x):
        N, L = x.shape[:2]
        value_l = _heads(self.proj_value(x), self.num_heads, self.query_key_dim)  # (N, L, n_heads, v_ch)
        feat_node = alpha.unsqueeze(-1) * value_l.unsqueeze(1)  # (N, L, L, n_heads, *) @ (N, *, L, n_heads, v_ch)
        feat_node = feat_node.sum(dim=2)  # (N, L, n_heads, v_ch)
        return feat_node.reshape(N, L, -1)

    def _spatial_aggregation(self, alpha, R, t, x):
        N, L, _ = t.size()
        value_points = _heads(self.proj_value_point(x), self.num_heads * self.num_value_points,
                              3)  # (N, L, n_heads * n_v_pnts, 3)
        value_points = local_to_global(R, t, value_points.reshape(N, L, self.num_heads, self.num_value_points,
                                                                  3))  # (N, L, n_heads, n_v_pnts, 3)
        aggr_points = alpha.reshape(N, L, L, self.num_heads, 1, 1) * \
                      value_points.unsqueeze(1)  # (N, *, L, n_heads, n_pnts, 3)
        aggr_points = aggr_points.sum(dim=2)  # (N, L, n_heads, n_pnts, 3)

        feat_points = global_to_local(R, t, aggr_points)  # (N, L, n_heads, n_pnts, 3)
        feat_distance = feat_points.norm(dim=-1)  # (N, L, n_heads, n_pnts)
        feat_direction = normalize_vector(feat_points, dim=-1, eps=1e-4)  # (N, L, n_heads, n_pnts, 3)

        feat_spatial = torch.cat([
            feat_points.reshape(N, L, -1),
            feat_distance.reshape(N, L, -1),
            feat_direction.reshape(N, L, -1),
        ], dim=-1)

        return feat_spatial

    def forward(self, R, t, x, z, mask):
        """
        Args:
            R:  Frame basis matrices, (N, L, 3, 3_index).
            t:  Frame external (absolute) coordinates, (N, L, 3).
            x:  Node-wise features, (N, L, F).
            z:  Pair-wise features, (N, L, L, C).
            mask:   Masks, (N, L).
        Returns:
            x': Updated node-wise features, (N, L, F).
        """
        # Attention logits
        logits_node = self._node_logits(x)
        logits_pair = self._pair_logits(z)
        logits_spatial = self._spatial_logits(R, t, x)
        # Summing logits up and apply `softmax`.
        logits_sum = logits_node + logits_pair + logits_spatial
        alpha = _alpha_from_logits(logits_sum * np.sqrt(1 / 3), mask)  # (N, L, L, n_heads)

        # Aggregate features
        feat_p2n = self._pair_aggregation(alpha, z)
        feat_node = self._node_aggregation(alpha, x)
        feat_spatial = self._spatial_aggregation(alpha, R, t, x)

        # Finally
        feat_all = self.out_transform(torch.cat([feat_p2n, feat_node, feat_spatial], dim=-1))  # (N, L, F)
        feat_all = mask_zero(mask.unsqueeze(-1), feat_all)
        x_updated = self.layer_norm_1(x + feat_all)
        x_updated = self.layer_norm_2(x_updated + self.mlp_transition(x_updated))
        return x_updated


class GAEncoder(nn.Module):

    def __init__(self, node_feat_dim, pair_feat_dim, num_layers, ga_block_opt={}):
        super(GAEncoder, self).__init__()
        self.blocks = nn.ModuleList([
            GABlock(node_feat_dim, pair_feat_dim, **ga_block_opt)
            for _ in range(num_layers)
        ])

    def forward(self, R, t, res_feat, pair_feat, mask):
        for i, block in enumerate(self.blocks):
            res_feat = block(R, t, res_feat, pair_feat, mask)
        return res_feat
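For reference, here is a minimal smoke-test sketch of how GAEncoder can be called. The feature dimensions, sequence length, and identity frames below are illustrative assumptions, not values fixed by this file.

# Hedged usage sketch for GAEncoder (all sizes and inputs are made-up placeholders).
import torch

encoder = GAEncoder(node_feat_dim=128, pair_feat_dim=64, num_layers=2)
N, L = 1, 16
R = torch.eye(3).expand(N, L, 3, 3)              # per-residue frame bases (identity here)
t = torch.randn(N, L, 3)                         # frame origins, e.g. CA coordinates
res_feat = torch.randn(N, L, 128)                # node-wise features
pair_feat = torch.randn(N, L, L, 64)             # pair-wise features
mask = torch.ones(N, L, dtype=torch.bool)        # all residues valid
out = encoder(R, t, res_feat, pair_feat, mask)   # -> (N, L, 128)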
diffab/modules/encoders/pair.py
ADDED
@@ -0,0 +1,102 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffab.modules.common.geometry import angstrom_to_nm, pairwise_dihedrals
from diffab.modules.common.layers import AngularEncoding
from diffab.utils.protein.constants import BBHeavyAtom, AA


class PairEmbedding(nn.Module):

    def __init__(self, feat_dim, max_num_atoms, max_aa_types=22, max_relpos=32):
        super().__init__()
        self.max_num_atoms = max_num_atoms
        self.max_aa_types = max_aa_types
        self.max_relpos = max_relpos
        self.aa_pair_embed = nn.Embedding(self.max_aa_types*self.max_aa_types, feat_dim)
        self.relpos_embed = nn.Embedding(2*max_relpos+1, feat_dim)

        self.aapair_to_distcoef = nn.Embedding(self.max_aa_types*self.max_aa_types, max_num_atoms*max_num_atoms)
        nn.init.zeros_(self.aapair_to_distcoef.weight)
        self.distance_embed = nn.Sequential(
            nn.Linear(max_num_atoms*max_num_atoms, feat_dim), nn.ReLU(),
            nn.Linear(feat_dim, feat_dim), nn.ReLU(),
        )

        self.dihedral_embed = AngularEncoding()
        feat_dihed_dim = self.dihedral_embed.get_out_dim(2)  # Phi and Psi

        infeat_dim = feat_dim+feat_dim+feat_dim+feat_dihed_dim
        self.out_mlp = nn.Sequential(
            nn.Linear(infeat_dim, feat_dim), nn.ReLU(),
            nn.Linear(feat_dim, feat_dim), nn.ReLU(),
            nn.Linear(feat_dim, feat_dim),
        )

    def forward(self, aa, res_nb, chain_nb, pos_atoms, mask_atoms, structure_mask=None, sequence_mask=None):
        """
        Args:
            aa: (N, L).
            res_nb: (N, L).
            chain_nb: (N, L).
            pos_atoms:  (N, L, A, 3)
            mask_atoms: (N, L, A)
            structure_mask: (N, L)
            sequence_mask:  (N, L), mask out unknown amino acids to generate.

        Returns:
            (N, L, L, feat_dim)
        """
        N, L = aa.size()

        # Remove other atoms
        pos_atoms = pos_atoms[:, :, :self.max_num_atoms]
        mask_atoms = mask_atoms[:, :, :self.max_num_atoms]

        mask_residue = mask_atoms[:, :, BBHeavyAtom.CA]  # (N, L)
        mask_pair = mask_residue[:, :, None] * mask_residue[:, None, :]
        pair_structure_mask = structure_mask[:, :, None] * structure_mask[:, None, :] if structure_mask is not None else None

        # Pair identities
        if sequence_mask is not None:
            # Avoid data leakage at training time
            aa = torch.where(sequence_mask, aa, torch.full_like(aa, fill_value=AA.UNK))
        aa_pair = aa[:,:,None]*self.max_aa_types + aa[:,None,:]    # (N, L, L)
        feat_aapair = self.aa_pair_embed(aa_pair)

        # Relative sequential positions
        same_chain = (chain_nb[:, :, None] == chain_nb[:, None, :])
        relpos = torch.clamp(
            res_nb[:,:,None] - res_nb[:,None,:],
            min=-self.max_relpos, max=self.max_relpos,
        )   # (N, L, L)
        feat_relpos = self.relpos_embed(relpos + self.max_relpos) * same_chain[:,:,:,None]

        # Distances
        d = angstrom_to_nm(torch.linalg.norm(
            pos_atoms[:,:,None,:,None] - pos_atoms[:,None,:,None,:],
            dim = -1, ord = 2,
        )).reshape(N, L, L, -1) # (N, L, L, A*A)
        c = F.softplus(self.aapair_to_distcoef(aa_pair))    # (N, L, L, A*A)
        d_gauss = torch.exp(-1 * c * d**2)
        mask_atom_pair = (mask_atoms[:,:,None,:,None] * mask_atoms[:,None,:,None,:]).reshape(N, L, L, -1)
        feat_dist = self.distance_embed(d_gauss * mask_atom_pair)
        if pair_structure_mask is not None:
            # Avoid data leakage at training time
            feat_dist = feat_dist * pair_structure_mask[:, :, :, None]

        # Orientations
        dihed = pairwise_dihedrals(pos_atoms)   # (N, L, L, 2)
        feat_dihed = self.dihedral_embed(dihed)
        if pair_structure_mask is not None:
            # Avoid data leakage at training time
            feat_dihed = feat_dihed * pair_structure_mask[:, :, :, None]

        # All
        feat_all = torch.cat([feat_aapair, feat_relpos, feat_dist, feat_dihed], dim=-1)
        feat_all = self.out_mlp(feat_all)   # (N, L, L, F)
        feat_all = feat_all * mask_pair[:, :, :, None]

        return feat_all
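A minimal shape-check sketch for PairEmbedding follows. Batch size, length, and atom count are assumptions chosen only so the tensors are mutually consistent (A equals max_num_atoms); they are not prescribed by the module.

# Hedged usage sketch (illustrative sizes only).
import torch

pair_embed = PairEmbedding(feat_dim=64, max_num_atoms=5)
N, L, A = 1, 16, 5
aa = torch.randint(0, 20, (N, L))                        # amino-acid type indices
res_nb = torch.arange(1, L + 1).unsqueeze(0)             # residue numbers
chain_nb = torch.zeros(N, L, dtype=torch.long)           # single chain
pos_atoms = torch.randn(N, L, A, 3)                      # backbone heavy-atom coordinates
mask_atoms = torch.ones(N, L, A, dtype=torch.bool)
z = pair_embed(aa, res_nb, chain_nb, pos_atoms, mask_atoms)   # -> (N, L, L, 64)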
diffab/modules/encoders/residue.py
ADDED
@@ -0,0 +1,92 @@
import torch
import torch.nn as nn

from diffab.modules.common.geometry import construct_3d_basis, global_to_local, get_backbone_dihedral_angles
from diffab.modules.common.layers import AngularEncoding
from diffab.utils.protein.constants import BBHeavyAtom, AA


class ResidueEmbedding(nn.Module):

    def __init__(self, feat_dim, max_num_atoms, max_aa_types=22):
        super().__init__()
        self.max_num_atoms = max_num_atoms
        self.max_aa_types = max_aa_types
        self.aatype_embed = nn.Embedding(self.max_aa_types, feat_dim)
        self.dihed_embed = AngularEncoding()
        self.type_embed = nn.Embedding(10, feat_dim, padding_idx=0)    # 1: Heavy, 2: Light, 3: Ag
        infeat_dim = feat_dim + (self.max_aa_types*max_num_atoms*3) + self.dihed_embed.get_out_dim(3) + feat_dim
        self.mlp = nn.Sequential(
            nn.Linear(infeat_dim, feat_dim * 2), nn.ReLU(),
            nn.Linear(feat_dim * 2, feat_dim), nn.ReLU(),
            nn.Linear(feat_dim, feat_dim), nn.ReLU(),
            nn.Linear(feat_dim, feat_dim)
        )

    def forward(self, aa, res_nb, chain_nb, pos_atoms, mask_atoms, fragment_type, structure_mask=None, sequence_mask=None):
        """
        Args:
            aa: (N, L).
            res_nb: (N, L).
            chain_nb:   (N, L).
            pos_atoms:  (N, L, A, 3).
            mask_atoms: (N, L, A).
            fragment_type:  (N, L).
            structure_mask: (N, L), mask out unknown structures to generate.
            sequence_mask:  (N, L), mask out unknown amino acids to generate.
        """
        N, L = aa.size()
        mask_residue = mask_atoms[:, :, BBHeavyAtom.CA] # (N, L)

        # Remove other atoms
        pos_atoms = pos_atoms[:, :, :self.max_num_atoms]
        mask_atoms = mask_atoms[:, :, :self.max_num_atoms]

        # Amino acid identity features
        if sequence_mask is not None:
            # Avoid data leakage at training time
            aa = torch.where(sequence_mask, aa, torch.full_like(aa, fill_value=AA.UNK))
        aa_feat = self.aatype_embed(aa) # (N, L, feat)

        # Coordinate features
        R = construct_3d_basis(
            pos_atoms[:, :, BBHeavyAtom.CA],
            pos_atoms[:, :, BBHeavyAtom.C],
            pos_atoms[:, :, BBHeavyAtom.N]
        )
        t = pos_atoms[:, :, BBHeavyAtom.CA]
        crd = global_to_local(R, t, pos_atoms)  # (N, L, A, 3)
        crd_mask = mask_atoms[:, :, :, None].expand_as(crd)
        crd = torch.where(crd_mask, crd, torch.zeros_like(crd))

        aa_expand = aa[:, :, None, None, None].expand(N, L, self.max_aa_types, self.max_num_atoms, 3)
        rng_expand = torch.arange(0, self.max_aa_types)[None, None, :, None, None].expand(N, L, self.max_aa_types, self.max_num_atoms, 3).to(aa_expand)
        place_mask = (aa_expand == rng_expand)
        crd_expand = crd[:, :, None, :, :].expand(N, L, self.max_aa_types, self.max_num_atoms, 3)
        crd_expand = torch.where(place_mask, crd_expand, torch.zeros_like(crd_expand))
        crd_feat = crd_expand.reshape(N, L, self.max_aa_types*self.max_num_atoms*3)
        if structure_mask is not None:
            # Avoid data leakage at training time
            crd_feat = crd_feat * structure_mask[:, :, None]

        # Backbone dihedral features
        bb_dihedral, mask_bb_dihed = get_backbone_dihedral_angles(pos_atoms, chain_nb=chain_nb, res_nb=res_nb, mask=mask_residue)
        dihed_feat = self.dihed_embed(bb_dihedral[:, :, :, None]) * mask_bb_dihed[:, :, :, None]  # (N, L, 3, dihed/3)
        dihed_feat = dihed_feat.reshape(N, L, -1)
        if structure_mask is not None:
            # Avoid data leakage at training time
            dihed_mask = torch.logical_and(
                structure_mask,
                torch.logical_and(
                    torch.roll(structure_mask, shifts=+1, dims=1),
                    torch.roll(structure_mask, shifts=-1, dims=1)
                ),
            )   # Avoid slight data leakage via dihedral angles of anchor residues
            dihed_feat = dihed_feat * dihed_mask[:, :, None]

        # Type feature
        type_feat = self.type_embed(fragment_type) # (N, L, feat)

        out_feat = self.mlp(torch.cat([aa_feat, crd_feat, dihed_feat, type_feat], dim=-1)) # (N, L, F)
        out_feat = out_feat * mask_residue[:, :, None]
        return out_feat
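A companion sketch for ResidueEmbedding under the same assumed sizes as the PairEmbedding example above; fragment_type here uses the 1 = heavy-chain convention noted in the constructor comment.

# Hedged usage sketch (illustrative sizes only).
import torch

res_embed = ResidueEmbedding(feat_dim=64, max_num_atoms=5)
N, L, A = 1, 16, 5
aa = torch.randint(0, 20, (N, L))
res_nb = torch.arange(1, L + 1).unsqueeze(0)
chain_nb = torch.zeros(N, L, dtype=torch.long)
pos_atoms = torch.randn(N, L, A, 3)
mask_atoms = torch.ones(N, L, A, dtype=torch.bool)
fragment_type = torch.ones(N, L, dtype=torch.long)       # 1: heavy chain
x = res_embed(aa, res_nb, chain_nb, pos_atoms, mask_atoms, fragment_type)   # -> (N, L, 64)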
diffab/tools/dock/base.py
ADDED
@@ -0,0 +1,28 @@
import abc
from typing import List


FilePath = str


class DockingEngine(abc.ABC):

    @abc.abstractmethod
    def __enter__(self):
        pass

    @abc.abstractmethod
    def __exit__(self, typ, value, traceback):
        pass

    @abc.abstractmethod
    def set_receptor(self, pdb_path: FilePath):
        pass

    @abc.abstractmethod
    def set_ligand(self, pdb_path: FilePath):
        pass

    @abc.abstractmethod
    def dock(self) -> List[FilePath]:
        pass
diffab/tools/dock/hdock.py
ADDED
@@ -0,0 +1,164 @@
import os
import shutil
import tempfile
import subprocess
import dataclasses as dc
from typing import List, Optional
from Bio import PDB
from Bio.PDB import Model as PDBModel

from diffab.tools.renumber import renumber as renumber_chothia
from .base import DockingEngine


def fix_docked_pdb(pdb_path):
    fixed = []
    with open(pdb_path, 'r') as f:
        for ln in f.readlines():
            if (ln.startswith('ATOM') or ln.startswith('HETATM')) and len(ln) == 56:
                fixed.append( ln[:-1] + ' 1.00 0.00 \n' )
            else:
                fixed.append(ln)
    with open(pdb_path, 'w') as f:
        f.write(''.join(fixed))


class HDock(DockingEngine):

    def __init__(
        self,
        hdock_bin='./bin/hdock',
        createpl_bin='./bin/createpl',
    ):
        super().__init__()
        self.hdock_bin = os.path.realpath(hdock_bin)
        self.createpl_bin = os.path.realpath(createpl_bin)
        self.tmpdir = tempfile.TemporaryDirectory()

        self._has_receptor = False
        self._has_ligand = False

        self._receptor_chains = []
        self._ligand_chains = []

    def __enter__(self):
        return self

    def __exit__(self, typ, value, traceback):
        self.tmpdir.cleanup()

    def set_receptor(self, pdb_path):
        shutil.copyfile(pdb_path, os.path.join(self.tmpdir.name, 'receptor.pdb'))
        self._has_receptor = True

    def set_ligand(self, pdb_path):
        shutil.copyfile(pdb_path, os.path.join(self.tmpdir.name, 'ligand.pdb'))
        self._has_ligand = True

    def _dump_complex_pdb(self):
        parser = PDB.PDBParser(QUIET=True)
        model_receptor = parser.get_structure(None, os.path.join(self.tmpdir.name, 'receptor.pdb'))[0]
        docked_pdb_path = os.path.join(self.tmpdir.name, 'ligand_docked.pdb')
        fix_docked_pdb(docked_pdb_path)
        structure_ligdocked = parser.get_structure(None, docked_pdb_path)

        pdb_io = PDB.PDBIO()
        paths = []
        for i, model_ligdocked in enumerate(structure_ligdocked):
            model_complex = PDBModel.Model(0)
            for chain in model_receptor:
                model_complex.add(chain.copy())
            for chain in model_ligdocked:
                model_complex.add(chain.copy())
            pdb_io.set_structure(model_complex)
            save_path = os.path.join(self.tmpdir.name, f"complex_{i}.pdb")
            pdb_io.save(save_path)
            paths.append(save_path)
        return paths

    def dock(self):
        if not (self._has_receptor and self._has_ligand):
            raise ValueError('Missing receptor or ligand.')
        subprocess.run(
            [self.hdock_bin, "receptor.pdb", "ligand.pdb"],
            cwd=self.tmpdir.name, check=True
        )
        subprocess.run(
            [self.createpl_bin, "Hdock.out", "ligand_docked.pdb"],
            cwd=self.tmpdir.name, check=True
        )
        return self._dump_complex_pdb()


@dc.dataclass
class DockSite:
    chain: str
    resseq: int


class HDockAntibody(HDock):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._heavy_chain_id = None
        self._epitope_sites: Optional[List[DockSite]] = None

    def set_ligand(self, pdb_path):
        raise NotImplementedError('Please use set_antibody')

    def set_receptor(self, pdb_path):
        raise NotImplementedError('Please use set_antigen')

    def set_antigen(self, pdb_path, epitope_sites: Optional[List[DockSite]]=None):
        super().set_receptor(pdb_path)
        self._epitope_sites = epitope_sites

    def set_antibody(self, pdb_path):
        heavy_chains, _ = renumber_chothia(pdb_path, os.path.join(self.tmpdir.name, 'ligand.pdb'))
        self._has_ligand = True
        self._heavy_chain_id = heavy_chains[0]

    def _prepare_lsite(self):
        lsite_content = f"95-102:{self._heavy_chain_id}\n"  # Chothia CDR H3
        with open(os.path.join(self.tmpdir.name, 'lsite.txt'), 'w') as f:
            f.write(lsite_content)
        print(f"[INFO] lsite content: {lsite_content}")

    def _prepare_rsite(self):
        rsite_content = ""
        for site in self._epitope_sites:
            rsite_content += f"{site.resseq}:{site.chain}\n"
        with open(os.path.join(self.tmpdir.name, 'rsite.txt'), 'w') as f:
            f.write(rsite_content)
        print(f"[INFO] rsite content: {rsite_content}")

    def dock(self):
        if not (self._has_receptor and self._has_ligand):
            raise ValueError('Missing receptor or ligand.')
        self._prepare_lsite()

        cmd_hdock = [self.hdock_bin, "receptor.pdb", "ligand.pdb", "-lsite", "lsite.txt"]
        if self._epitope_sites is not None:
            self._prepare_rsite()
            cmd_hdock += ["-rsite", "rsite.txt"]
        subprocess.run(
            cmd_hdock,
            cwd=self.tmpdir.name, check=True
        )

        cmd_pl = [self.createpl_bin, "Hdock.out", "ligand_docked.pdb", "-lsite", "lsite.txt"]
        if self._epitope_sites is not None:
            self._prepare_rsite()
            cmd_pl += ["-rsite", "rsite.txt"]
        subprocess.run(
            cmd_pl,
            cwd=self.tmpdir.name, check=True
        )
        return self._dump_complex_pdb()


if __name__ == '__main__':
    with HDockAntibody('hdock', 'createpl') as dock:
        dock.set_antigen('./data/dock/receptor.pdb', [DockSite('A', 991)])
        dock.set_antibody('./data/example_dock/3qhf_fv.pdb')
        print(dock.dock())
diffab/tools/eval/__main__.py
ADDED
@@ -0,0 +1,4 @@
from .run import main

if __name__ == '__main__':
    main()
diffab/tools/eval/base.py
ADDED
@@ -0,0 +1,125 @@
import os
import re
import json
import shelve
from Bio import PDB
from typing import Optional, Tuple, List
from dataclasses import dataclass, field


@dataclass
class EvalTask:
    in_path: str
    ref_path: str
    info: dict
    structure: str
    name: str
    method: str
    cdr: str
    ab_chains: List

    residue_first: Optional[Tuple] = None
    residue_last: Optional[Tuple] = None

    scores: dict = field(default_factory=dict)

    def get_gen_biopython_model(self):
        parser = PDB.PDBParser(QUIET=True)
        return parser.get_structure(self.in_path, self.in_path)[0]

    def get_ref_biopython_model(self):
        parser = PDB.PDBParser(QUIET=True)
        return parser.get_structure(self.ref_path, self.ref_path)[0]

    def save_to_db(self, db: shelve.Shelf):
        db[self.in_path] = self

    def to_report_dict(self):
        return {
            'method': self.method,
            'structure': self.structure,
            'cdr': self.cdr,
            'filename': os.path.basename(self.in_path),
            **self.scores
        }


class TaskScanner:

    def __init__(self, root, postfix=None, db: Optional[shelve.Shelf]=None):
        super().__init__()
        self.root = root
        self.postfix = postfix
        self.visited = set()
        self.db = db
        if db is not None:
            for k in db.keys():
                self.visited.add(k)

    def _get_metadata(self, fpath):
        json_path = os.path.join(
            os.path.dirname(os.path.dirname(fpath)),
            'metadata.json'
        )
        tag_name = os.path.basename(os.path.dirname(fpath))
        method_name = os.path.basename(
            os.path.dirname(os.path.dirname(os.path.dirname(fpath)))
        )
        try:
            antibody_chains = set()
            info = None
            with open(json_path, 'r') as f:
                metadata = json.load(f)
            for item in metadata['items']:
                if item['tag'] == tag_name:
                    info = item
                antibody_chains.add(item['residue_first'][0])
            if info is not None:
                info['antibody_chains'] = list(antibody_chains)
                info['structure'] = metadata['identifier']
                info['method'] = method_name
            return info
        except (json.JSONDecodeError, FileNotFoundError) as e:
            return None

    def scan(self) -> List[EvalTask]:
        tasks = []
        if self.postfix is None or not self.postfix:
            input_fname_pattern = '^\d+\.pdb$'
            ref_fname = 'REF1.pdb'
        else:
            input_fname_pattern = f'^\d+\_{self.postfix}\.pdb$'
            ref_fname = f'REF1_{self.postfix}.pdb'
        for parent, _, files in os.walk(self.root):
            for fname in files:
                fpath = os.path.join(parent, fname)
                if not re.match(input_fname_pattern, fname):
                    continue
                if os.path.getsize(fpath) == 0:
                    continue
                if fpath in self.visited:
                    continue

                # Path to the reference structure
                ref_path = os.path.join(parent, ref_fname)
                if not os.path.exists(ref_path):
                    continue

                # CDR information
                info = self._get_metadata(fpath)
                if info is None:
                    continue
                tasks.append(EvalTask(
                    in_path = fpath,
                    ref_path = ref_path,
                    info = info,
                    structure = info['structure'],
                    name = info['name'],
                    method = info['method'],
                    cdr = info['tag'],
                    ab_chains = info['antibody_chains'],
                    residue_first = info.get('residue_first', None),
                    residue_last = info.get('residue_last', None),
                ))
                self.visited.add(fpath)
        return tasks
diffab/tools/eval/energy.py
ADDED
@@ -0,0 +1,43 @@
# pyright: reportMissingImports=false
import pyrosetta
from pyrosetta.rosetta.protocols.analysis import InterfaceAnalyzerMover
pyrosetta.init(' '.join([
    '-mute', 'all',
    '-use_input_sc',
    '-ignore_unrecognized_res',
    '-ignore_zero_occupancy', 'false',
    '-load_PDB_components', 'false',
    '-relax:default_repeats', '2',
    '-no_fconfig',
]))

from tools.eval.base import EvalTask


def pyrosetta_interface_energy(pdb_path, interface):
    pose = pyrosetta.pose_from_pdb(pdb_path)
    mover = InterfaceAnalyzerMover(interface)
    mover.set_pack_separated(True)
    mover.apply(pose)
    return pose.scores['dG_separated']


def eval_interface_energy(task: EvalTask):
    model_gen = task.get_gen_biopython_model()
    antigen_chains = set()
    for chain in model_gen:
        if chain.id not in task.ab_chains:
            antigen_chains.add(chain.id)
    antigen_chains = ''.join(list(antigen_chains))
    antibody_chains = ''.join(task.ab_chains)
    interface = f"{antibody_chains}_{antigen_chains}"

    dG_gen = pyrosetta_interface_energy(task.in_path, interface)
    dG_ref = pyrosetta_interface_energy(task.ref_path, interface)

    task.scores.update({
        'dG_gen': dG_gen,
        'dG_ref': dG_ref,
        'ddG': dG_gen - dG_ref
    })
    return task
diffab/tools/eval/run.py
ADDED
@@ -0,0 +1,66 @@
import os
import argparse
import ray
import shelve
import time
import pandas as pd
from typing import Mapping

from tools.eval.base import EvalTask, TaskScanner
from tools.eval.similarity import eval_similarity
from tools.eval.energy import eval_interface_energy


@ray.remote(num_cpus=1)
def evaluate(task, args):
    funcs = []
    funcs.append(eval_similarity)
    if not args.no_energy:
        funcs.append(eval_interface_energy)
    for f in funcs:
        task = f(task)
    return task


def dump_db(db: Mapping[str, EvalTask], path):
    table = []
    for task in db.values():
        if 'abopt' in path and task.scores['seqid'] >= 100.0:
            # In abopt (Antibody Optimization) mode, ignore sequences identical to the wild-type
            continue
        table.append(task.to_report_dict())
    table = pd.DataFrame(table)
    table.to_csv(path, index=False, float_format='%.6f')
    return table


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default='./results')
    parser.add_argument('--pfx', type=str, default='rosetta')
    parser.add_argument('--no_energy', action='store_true', default=False)
    args = parser.parse_args()
    ray.init()

    db_path = os.path.join(args.root, 'evaluation_db')
    with shelve.open(db_path) as db:
        scanner = TaskScanner(root=args.root, postfix=args.pfx, db=db)

        while True:
            tasks = scanner.scan()
            futures = [evaluate.remote(t, args) for t in tasks]
            if len(futures) > 0:
                print(f'Submitted {len(futures)} tasks.')
            while len(futures) > 0:
                done_ids, futures = ray.wait(futures, num_returns=1)
                for done_id in done_ids:
                    done_task = ray.get(done_id)
                    done_task.save_to_db(db)
                    print(f'Remaining {len(futures)}. Finished {done_task.in_path}')
                db.sync()

            dump_db(db, os.path.join(args.root, 'summary.csv'))
            time.sleep(1.0)

if __name__ == '__main__':
    main()
diffab/tools/eval/similarity.py
ADDED
@@ -0,0 +1,125 @@
import numpy as np
from Bio.PDB import PDBParser, Selection
from Bio.PDB.Polypeptide import three_to_one
from Bio import pairwise2
from Bio.Align import substitution_matrices

from diffab.tools.eval.base import EvalTask


def reslist_rmsd(res_list1, res_list2):
    res_short, res_long = (res_list1, res_list2) if len(res_list1) < len(res_list2) else (res_list2, res_list1)
    M, N = len(res_short), len(res_long)

    def d(i, j):
        coord_i = np.array(res_short[i]['CA'].get_coord())
        coord_j = np.array(res_long[j]['CA'].get_coord())
        return ((coord_i - coord_j) ** 2).sum()

    SD = np.full([M, N], np.inf)
    for i in range(M):
        j = N - (M - i)
        SD[i, j] = sum([ d(i+k, j+k) for k in range(N-j) ])

    for j in range(N):
        SD[M-1, j] = d(M-1, j)

    for i in range(M-2, -1, -1):
        for j in range((N-(M-i))-1, -1, -1):
            SD[i, j] = min(
                d(i, j) + SD[i+1, j+1],
                SD[i, j+1]
            )

    min_SD = SD[0, :N-M+1].min()
    best_RMSD = np.sqrt(min_SD / M)
    return best_RMSD


def entity_to_seq(entity):
    seq = ''
    mapping = []
    for res in Selection.unfold_entities(entity, 'R'):
        try:
            seq += three_to_one(res.get_resname())
            mapping.append(res.get_id())
        except KeyError:
            pass
    assert len(seq) == len(mapping)
    return seq, mapping


def reslist_seqid(res_list1, res_list2):
    seq1, _ = entity_to_seq(res_list1)
    seq2, _ = entity_to_seq(res_list2)
    _, seq_id = align_sequences(seq1, seq2)
    return seq_id


def align_sequences(sequence_A, sequence_B, **kwargs):
    """
    Performs a global pairwise alignment between two sequences
    using the BLOSUM62 matrix and the Needleman-Wunsch algorithm
    as implemented in Biopython. Returns the alignment, the sequence
    identity and the residue mapping between both original sequences.
    """

    def _calculate_identity(sequenceA, sequenceB):
        """
        Returns the percentage of identical characters between two sequences.
        Assumes the sequences are aligned.
        """

        sa, sb, sl = sequenceA, sequenceB, len(sequenceA)
        matches = [sa[i] == sb[i] for i in range(sl)]
        seq_id = (100 * sum(matches)) / sl
        return seq_id

        # gapless_sl = sum([1 for i in range(sl) if (sa[i] != '-' and sb[i] != '-')])
        # gap_id = (100 * sum(matches)) / gapless_sl
        # return (seq_id, gap_id)

    #
    matrix = kwargs.get('matrix', substitution_matrices.load("BLOSUM62"))
    gap_open = kwargs.get('gap_open', -10.0)
    gap_extend = kwargs.get('gap_extend', -0.5)

    alns = pairwise2.align.globalds(sequence_A, sequence_B,
                                    matrix, gap_open, gap_extend,
                                    penalize_end_gaps=(False, False) )

    best_aln = alns[0]
    aligned_A, aligned_B, score, begin, end = best_aln

    # Calculate sequence identity
    seq_id = _calculate_identity(aligned_A, aligned_B)
    return (aligned_A, aligned_B), seq_id


def extract_reslist(model, residue_first, residue_last):
    assert residue_first[0] == residue_last[0]
    residue_first, residue_last = tuple(residue_first), tuple(residue_last)

    chain_id = residue_first[0]
    pos_first, pos_last = residue_first[1:], residue_last[1:]
    chain = model[chain_id]
    reslist = []
    for res in Selection.unfold_entities(chain, 'R'):
        pos_current = (res.id[1], res.id[2])
        if pos_first <= pos_current <= pos_last:
            reslist.append(res)
    return reslist


def eval_similarity(task: EvalTask):
    model_gen = task.get_gen_biopython_model()
    model_ref = task.get_ref_biopython_model()

    reslist_gen = extract_reslist(model_gen, task.residue_first, task.residue_last)
    reslist_ref = extract_reslist(model_ref, task.residue_first, task.residue_last)

    task.scores.update({
        'rmsd': reslist_rmsd(reslist_gen, reslist_ref),
        'seqid': reslist_seqid(reslist_gen, reslist_ref),
    })
    return task
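A small, hedged example of the sequence-identity helper above; the two peptide strings are invented, and the printed value assumes the ungapped alignment is the one Biopython returns as optimal here.

# Illustrative only: two made-up 8-residue peptides differing at one position.
(aligned_A, aligned_B), seq_id = align_sequences("ARDYWGQG", "ARDYSGQG")
print(seq_id)   # expected ~87.5 (7 of 8 aligned positions identical)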
diffab/tools/relax/__main__.py
ADDED
@@ -0,0 +1,4 @@
from .run import main

if __name__ == '__main__':
    main()
diffab/tools/relax/base.py
ADDED
@@ -0,0 +1,117 @@
import os
import re
import json
from typing import Optional, Tuple, List
from dataclasses import dataclass


@dataclass
class RelaxTask:
    in_path: str
    current_path: str
    info: dict
    status: str

    flexible_residue_first: Optional[Tuple] = None
    flexible_residue_last: Optional[Tuple] = None

    def get_in_path_with_tag(self, tag):
        name, ext = os.path.splitext(self.in_path)
        new_path = f'{name}_{tag}{ext}'
        return new_path

    def set_current_path_tag(self, tag):
        new_path = self.get_in_path_with_tag(tag)
        self.current_path = new_path
        return new_path

    def check_current_path_exists(self):
        ok = os.path.exists(self.current_path)
        if not ok:
            self.mark_failure()
        if os.path.getsize(self.current_path) == 0:
            ok = False
            self.mark_failure()
            os.unlink(self.current_path)
        return ok

    def update_if_finished(self, tag):
        out_path = self.get_in_path_with_tag(tag)
        if os.path.exists(out_path) and os.path.getsize(out_path) > 0:
            # print('Already finished', out_path)
            self.set_current_path_tag(tag)
            self.mark_success()
            return True
        return False

    def can_proceed(self):
        self.check_current_path_exists()
        return self.status != 'failed'

    def mark_success(self):
        self.status = 'success'

    def mark_failure(self):
        self.status = 'failed'


class TaskScanner:

    def __init__(self, root, final_postfix=None):
        super().__init__()
        self.root = root
        self.visited = set()
        self.final_postfix = final_postfix

    def _get_metadata(self, fpath):
        json_path = os.path.join(
            os.path.dirname(os.path.dirname(fpath)),
            'metadata.json'
        )
        tag_name = os.path.basename(os.path.dirname(fpath))
        try:
            with open(json_path, 'r') as f:
                metadata = json.load(f)
            for item in metadata['items']:
                if item['tag'] == tag_name:
                    return item
        except (json.JSONDecodeError, FileNotFoundError) as e:
            return None
        return None

    def scan(self) -> List[RelaxTask]:
        tasks = []
        input_fname_pattern = '(^\d+\.pdb$|^REF\d\.pdb$)'
        for parent, _, files in os.walk(self.root):
            for fname in files:
                fpath = os.path.join(parent, fname)
                if not re.match(input_fname_pattern, fname):
                    continue
                if os.path.getsize(fpath) == 0:
                    continue
                if fpath in self.visited:
                    continue

                # If finished
                if self.final_postfix is not None:
                    fpath_name, fpath_ext = os.path.splitext(fpath)
                    fpath_final = f"{fpath_name}_{self.final_postfix}{fpath_ext}"
                    if os.path.exists(fpath_final):
                        continue

                # Get metadata
                info = self._get_metadata(fpath)
                if info is None:
                    continue

                tasks.append(RelaxTask(
                    in_path = fpath,
                    current_path = fpath,
                    info = info,
                    status = 'created',
                    flexible_residue_first = info.get('residue_first', None),
                    flexible_residue_last = info.get('residue_last', None),
                ))
                self.visited.add(fpath)
        return tasks
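A short sketch of the tag-based path bookkeeping used by the relax pipeline; the file path below is a hypothetical example.

# Hedged example of RelaxTask path tagging (hypothetical path).
task = RelaxTask(in_path='./results/0001.pdb', current_path='./results/0001.pdb', info={}, status='created')
print(task.get_in_path_with_tag('openmm'))   # ./results/0001_openmm.pdb
task.set_current_path_tag('openmm')
print(task.current_path)                     # ./results/0001_openmm.pdb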
diffab/tools/relax/openmm_relaxer.py
ADDED
@@ -0,0 +1,144 @@
import os
import time
import io
import logging
import pdbfixer
import openmm
from openmm import app as openmm_app
from openmm import unit
ENERGY = unit.kilocalories_per_mole
LENGTH = unit.angstroms

from diffab.tools.relax.base import RelaxTask


def current_milli_time():
    return round(time.time() * 1000)


def _is_in_the_range(ch_rs_ic, flexible_residue_first, flexible_residue_last):
    if ch_rs_ic[0] != flexible_residue_first[0]: return False
    r_first, r_last = tuple(flexible_residue_first[1:]), tuple(flexible_residue_last[1:])
    rs_ic = ch_rs_ic[1:]
    return r_first <= rs_ic <= r_last


class ForceFieldMinimizer(object):

    def __init__(self, stiffness=10.0, max_iterations=0, tolerance=2.39*unit.kilocalories_per_mole, platform='CUDA'):
        super().__init__()
        self.stiffness = stiffness
        self.max_iterations = max_iterations
        self.tolerance = tolerance
        assert platform in ('CUDA', 'CPU')
        self.platform = platform

    def _fix(self, pdb_str):
        fixer = pdbfixer.PDBFixer(pdbfile=io.StringIO(pdb_str))
        fixer.findNonstandardResidues()
        fixer.replaceNonstandardResidues()

        fixer.findMissingResidues()
        fixer.findMissingAtoms()
        fixer.addMissingAtoms(seed=0)
        fixer.addMissingHydrogens()

        out_handle = io.StringIO()
        openmm_app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle, keepIds=True)
        return out_handle.getvalue()

    def _get_pdb_string(self, topology, positions):
        with io.StringIO() as f:
            openmm_app.PDBFile.writeFile(topology, positions, f, keepIds=True)
            return f.getvalue()

    def _minimize(self, pdb_str, flexible_residue_first=None, flexible_residue_last=None):
        pdb = openmm_app.PDBFile(io.StringIO(pdb_str))

        force_field = openmm_app.ForceField("amber99sb.xml")
        constraints = openmm_app.HBonds
        system = force_field.createSystem(pdb.topology, constraints=constraints)

        # Add constraints to non-generated regions
        force = openmm.CustomExternalForce("0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)")
        force.addGlobalParameter("k", self.stiffness)
        for p in ["x0", "y0", "z0"]:
            force.addPerParticleParameter(p)

        if flexible_residue_first is not None and flexible_residue_last is not None:
            for i, a in enumerate(pdb.topology.atoms()):
                ch_rs_ic = (a.residue.chain.id, int(a.residue.id), a.residue.insertionCode)
                if not _is_in_the_range(ch_rs_ic, flexible_residue_first, flexible_residue_last) and a.element.name != "hydrogen":
                    force.addParticle(i, pdb.positions[i])

        system.addForce(force)

        # Set up the integrator and simulation
        integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
        platform = openmm.Platform.getPlatformByName("CUDA")
        simulation = openmm_app.Simulation(pdb.topology, system, integrator, platform)
        simulation.context.setPositions(pdb.positions)

        # Perform minimization
        ret = {}
        state = simulation.context.getState(getEnergy=True, getPositions=True)
        ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY)
        ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)

        simulation.minimizeEnergy(maxIterations=self.max_iterations, tolerance=self.tolerance)

        state = simulation.context.getState(getEnergy=True, getPositions=True)
        ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY)
        ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
        ret["min_pdb"] = self._get_pdb_string(simulation.topology, state.getPositions())

        return ret['min_pdb'], ret

    def _add_energy_remarks(self, pdb_str, ret):
        pdb_lines = pdb_str.splitlines()
        pdb_lines.insert(1, "REMARK 1 FINAL ENERGY: {:.3f} KCAL/MOL".format(ret['efinal']))
        pdb_lines.insert(1, "REMARK 1 INITIAL ENERGY: {:.3f} KCAL/MOL".format(ret['einit']))
        return "\n".join(pdb_lines)

    def __call__(self, pdb_str, flexible_residue_first=None, flexible_residue_last=None, return_info=True):
        if '\n' not in pdb_str and pdb_str.lower().endswith(".pdb"):
            with open(pdb_str) as f:
                pdb_str = f.read()

        pdb_fixed = self._fix(pdb_str)
        pdb_min, ret = self._minimize(pdb_fixed, flexible_residue_first, flexible_residue_last)
        pdb_min = self._add_energy_remarks(pdb_min, ret)
        if return_info:
            return pdb_min, ret
        else:
            return pdb_min


def run_openmm(task: RelaxTask):
    if not task.can_proceed():
        return task
    if task.update_if_finished('openmm'):
        return task

    try:
        minimizer = ForceFieldMinimizer()
        with open(task.current_path, 'r') as f:
            pdb_str = f.read()

        pdb_min = minimizer(
            pdb_str = pdb_str,
            flexible_residue_first = task.flexible_residue_first,
            flexible_residue_last = task.flexible_residue_last,
            return_info = False,
        )
        out_path = task.set_current_path_tag('openmm')
        with open(out_path, 'w') as f:
            f.write(pdb_min)
        task.mark_success()
    except ValueError as e:
        logging.warning(
            f'{e.__class__.__name__}: {str(e)} ({task.current_path})'
        )
        task.mark_failure()
    return task
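A hedged usage sketch for ForceFieldMinimizer; the input path and flexible-residue tuples are placeholders. Note that _minimize requests the CUDA platform by name regardless of the platform argument passed to the constructor, so CUDA must be available as written.

# Hypothetical invocation; paths and residue identifiers are examples only.
minimizer = ForceFieldMinimizer(stiffness=10.0)
pdb_min, info = minimizer(
    'example_complex.pdb',                     # a .pdb path is accepted as well as a PDB string
    flexible_residue_first=('H', 95, ' '),     # (chain, resseq, icode)
    flexible_residue_last=('H', 102, ' '),
)
print(info['einit'], info['efinal'])           # energies in kcal/mol before/after minimization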
diffab/tools/relax/pyrosetta_relaxer.py
ADDED
@@ -0,0 +1,189 @@
# pyright: reportMissingImports=false
import os
import time
import pyrosetta
from pyrosetta.rosetta.protocols.relax import FastRelax
from pyrosetta.rosetta.core.pack.task import TaskFactory
from pyrosetta.rosetta.core.pack.task import operation
from pyrosetta.rosetta.core.select import residue_selector as selections
from pyrosetta.rosetta.core.select.movemap import MoveMapFactory, move_map_action
pyrosetta.init(' '.join([
    '-mute', 'all',
    '-use_input_sc',
    '-ignore_unrecognized_res',
    '-ignore_zero_occupancy', 'false',
    '-load_PDB_components', 'false',
    '-relax:default_repeats', '2',
    '-no_fconfig',
]))

from diffab.tools.relax.base import RelaxTask


def current_milli_time():
    return round(time.time() * 1000)


def parse_residue_position(p):
    icode = None
    if not p[-1].isnumeric():   # Has ICODE
        icode = p[-1]

    for i, c in enumerate(p):
        if c.isnumeric():
            break
    chain = p[:i]
    resseq = int(p[i:])

    if icode is not None:
        return chain, resseq, icode
    else:
        return chain, resseq


def get_scorefxn(scorefxn_name: str):
    """
    Gets the scorefxn with appropriate corrections.
    Taken from: https://gist.github.com/matteoferla/b33585f3aeab58b8424581279e032550
    """
    import pyrosetta

    corrections = {
        'beta_july15': False,
        'beta_nov16': False,
        'gen_potential': False,
        'restore_talaris_behavior': False,
    }
    if 'beta_july15' in scorefxn_name or 'beta_nov15' in scorefxn_name:
        # beta_july15 is ref2015
        corrections['beta_july15'] = True
    elif 'beta_nov16' in scorefxn_name:
        corrections['beta_nov16'] = True
    elif 'genpot' in scorefxn_name:
        corrections['gen_potential'] = True
        pyrosetta.rosetta.basic.options.set_boolean_option('corrections:beta_july15', True)
    elif 'talaris' in scorefxn_name:  # 2013 and 2014
        corrections['restore_talaris_behavior'] = True
    else:
        pass
    for corr, value in corrections.items():
        pyrosetta.rosetta.basic.options.set_boolean_option(f'corrections:{corr}', value)
    return pyrosetta.create_score_function(scorefxn_name)


class RelaxRegion(object):

    def __init__(self, scorefxn='ref2015', max_iter=1000, subset='nbrs', move_bb=True):
        super().__init__()
        self.scorefxn = get_scorefxn(scorefxn)
        self.fast_relax = FastRelax()
        self.fast_relax.set_scorefxn(self.scorefxn)
        self.fast_relax.max_iter(max_iter)
        assert subset in ('all', 'target', 'nbrs')
        self.subset = subset
        self.move_bb = move_bb

    def __call__(self, pdb_path, flexible_residue_first, flexible_residue_last):
        pose = pyrosetta.pose_from_pdb(pdb_path)
        start_t = current_milli_time()
        original_pose = pose.clone()

        tf = TaskFactory()
        tf.push_back(operation.InitializeFromCommandline())
        tf.push_back(operation.RestrictToRepacking())   # Only allow residues to repack. No design at any position.

        # Create selector for the region to be relaxed
        # Turn off design and repacking on irrelevant positions
        if flexible_residue_first[-1] == ' ':
            flexible_residue_first = flexible_residue_first[:-1]
        if flexible_residue_last[-1] == ' ':
            flexible_residue_last = flexible_residue_last[:-1]
        if self.subset != 'all':
            gen_selector = selections.ResidueIndexSelector()
            gen_selector.set_index_range(
                pose.pdb_info().pdb2pose(*flexible_residue_first),
                pose.pdb_info().pdb2pose(*flexible_residue_last),
            )
            nbr_selector = selections.NeighborhoodResidueSelector()
            nbr_selector.set_focus_selector(gen_selector)
            nbr_selector.set_include_focus_in_subset(True)

            if self.subset == 'nbrs':
                subset_selector = nbr_selector
            elif self.subset == 'target':
                subset_selector = gen_selector

            prevent_repacking_rlt = operation.PreventRepackingRLT()
            prevent_subset_repacking = operation.OperateOnResidueSubset(
                prevent_repacking_rlt,
                subset_selector,
                flip_subset=True,
            )
            tf.push_back(prevent_subset_repacking)

        scorefxn = self.scorefxn
        fr = self.fast_relax

        pose = original_pose.clone()
        pos_list = pyrosetta.rosetta.utility.vector1_unsigned_long()
        for pos in range(pose.pdb_info().pdb2pose(*flexible_residue_first), pose.pdb_info().pdb2pose(*flexible_residue_last)+1):
            pos_list.append(pos)
        # basic_idealize(pose, pos_list, scorefxn, fast=True)

        mmf = MoveMapFactory()
        if self.move_bb:
            mmf.add_bb_action(move_map_action.mm_enable, gen_selector)
        mmf.add_chi_action(move_map_action.mm_enable, subset_selector)
        mm = mmf.create_movemap_from_pose(pose)

        fr.set_movemap(mm)
        fr.set_task_factory(tf)
        fr.apply(pose)

        e_before = scorefxn(original_pose)
        e_relax = scorefxn(pose)
        # print('\n\n[Finished in %.2f secs]' % ((current_milli_time() - start_t) / 1000))
        # print(' > Energy (before): %.4f' % scorefxn(original_pose))
        # print(' > Energy (optimized): %.4f' % scorefxn(pose))
        return pose, e_before, e_relax


def run_pyrosetta(task: RelaxTask):
    if not task.can_proceed():
        return task
    if task.update_if_finished('rosetta'):
        return task

    minimizer = RelaxRegion()
    pose_min, _, _ = minimizer(
        pdb_path = task.current_path,
        flexible_residue_first = task.flexible_residue_first,
        flexible_residue_last = task.flexible_residue_last,
    )

    out_path = task.set_current_path_tag('rosetta')
    pose_min.dump_pdb(out_path)
    task.mark_success()
    return task


def run_pyrosetta_fixbb(task: RelaxTask):
    if not task.can_proceed():
        return task
    if task.update_if_finished('fixbb'):
        return task

    minimizer = RelaxRegion(move_bb=False)
    pose_min, _, _ = minimizer(
        pdb_path = task.current_path,
        flexible_residue_first = task.flexible_residue_first,
        flexible_residue_last = task.flexible_residue_last,
    )

    out_path = task.set_current_path_tag('fixbb')
    pose_min.dump_pdb(out_path)
    task.mark_success()
    return task
diffab/tools/relax/run.py
ADDED
@@ -0,0 +1,85 @@
import argparse
import ray
import time

from diffab.tools.relax.openmm_relaxer import run_openmm
from diffab.tools.relax.pyrosetta_relaxer import run_pyrosetta, run_pyrosetta_fixbb
from diffab.tools.relax.base import TaskScanner


@ray.remote(num_gpus=1/8, num_cpus=1)
def run_openmm_remote(task):
    return run_openmm(task)


@ray.remote(num_cpus=1)
def run_pyrosetta_remote(task):
    return run_pyrosetta(task)


@ray.remote(num_cpus=1)
def run_pyrosetta_fixbb_remote(task):
    return run_pyrosetta_fixbb(task)


@ray.remote
def pipeline_openmm_pyrosetta(task):
    funcs = [
        run_openmm_remote,
        run_pyrosetta_remote,
    ]
    for fn in funcs:
        task = fn.remote(task)
    return ray.get(task)


@ray.remote
def pipeline_pyrosetta(task):
    funcs = [
        run_pyrosetta_remote,
    ]
    for fn in funcs:
        task = fn.remote(task)
    return ray.get(task)


@ray.remote
def pipeline_pyrosetta_fixbb(task):
    funcs = [
        run_pyrosetta_fixbb_remote,
    ]
    for fn in funcs:
        task = fn.remote(task)
    return ray.get(task)


pipeline_dict = {
    'openmm_pyrosetta': pipeline_openmm_pyrosetta,
    'pyrosetta': pipeline_pyrosetta,
    'pyrosetta_fixbb': pipeline_pyrosetta_fixbb,
}


def main():
    ray.init()
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default='./results')
    parser.add_argument('--pipeline', type=lambda s: pipeline_dict[s], default=pipeline_openmm_pyrosetta)
    args = parser.parse_args()

    final_pfx = 'fixbb' if args.pipeline == pipeline_pyrosetta_fixbb else 'rosetta'
    scanner = TaskScanner(args.root, final_postfix=final_pfx)
    while True:
        tasks = scanner.scan()
        futures = [args.pipeline.remote(t) for t in tasks]
        if len(futures) > 0:
            print(f'Submitted {len(futures)} tasks.')
        while len(futures) > 0:
            done_ids, futures = ray.wait(futures, num_returns=1)
            for done_id in done_ids:
                done_task = ray.get(done_id)
                print(f'Remaining {len(futures)}. Finished {done_task.current_path}')
        time.sleep(1.0)

if __name__ == '__main__':
    main()
diffab/tools/renumber/__init__.py
ADDED
@@ -0,0 +1 @@
from .run import renumber
diffab/tools/renumber/__main__.py
ADDED
@@ -0,0 +1,5 @@
from .run import main

if __name__ == '__main__':
    main()
diffab/tools/renumber/run.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import abnumber
from Bio import PDB
from Bio.PDB import Model, Chain, Residue, Selection
from Bio.Data import SCOPData
from typing import List, Tuple


def biopython_chain_to_sequence(chain: Chain.Chain):
    residue_list = Selection.unfold_entities(chain, 'R')
    seq = ''.join([SCOPData.protein_letters_3to1.get(r.resname, 'X') for r in residue_list])
    return seq, residue_list


def assign_number_to_sequence(seq):
    abchain = abnumber.Chain(seq, scheme='chothia')
    offset = seq.index(abchain.seq)
    if not (offset >= 0):
        raise ValueError(
            'The identified Fv sequence is not a subsequence of the original sequence.'
        )

    numbers = [None for _ in range(len(seq))]
    for i, (pos, aa) in enumerate(abchain):
        resseq = pos.number
        icode = pos.letter if pos.letter else ' '
        numbers[i+offset] = (resseq, icode)
    return numbers, abchain


def renumber_biopython_chain(chain_id, residue_list: List[Residue.Residue], numbers: List[Tuple[int, str]]):
    chain = Chain.Chain(chain_id)
    for residue, number in zip(residue_list, numbers):
        if number is None:
            continue
        residue = residue.copy()
        new_id = (residue.id[0], number[0], number[1])
        residue.id = new_id
        chain.add(residue)
    return chain


def renumber(in_pdb, out_pdb, return_other_chains=False):
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(None, in_pdb)
    model = structure[0]
    model_new = Model.Model(0)

    heavy_chains, light_chains, other_chains = [], [], []

    for chain in model:
        try:
            seq, reslist = biopython_chain_to_sequence(chain)
            numbers, abchain = assign_number_to_sequence(seq)
            chain_new = renumber_biopython_chain(chain.id, reslist, numbers)
            print(f'[INFO] Renumbered chain {chain_new.id} ({abchain.chain_type})')
            if abchain.chain_type == 'H':
                heavy_chains.append(chain_new.id)
            elif abchain.chain_type in ('K', 'L'):
                light_chains.append(chain_new.id)
        except abnumber.ChainParseError as e:
            print(f'[INFO] Chain {chain.id} does not contain valid Fv: {str(e)}')
            chain_new = chain.copy()
            other_chains.append(chain_new.id)
        model_new.add(chain_new)

    pdb_io = PDB.PDBIO()
    pdb_io.set_structure(model_new)
    pdb_io.save(out_pdb)
    if return_other_chains:
        return heavy_chains, light_chains, other_chains
    else:
        return heavy_chains, light_chains


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('in_pdb', type=str)
    parser.add_argument('out_pdb', type=str)
    args = parser.parse_args()

    renumber(args.in_pdb, args.out_pdb)

if __name__ == '__main__':
    main()
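A minimal usage sketch of the renumbering tool above; the input/output file names are hypothetical placeholders, and it assumes abnumber and Biopython are installed alongside this package.

# Hypothetical example (not part of the diff): renumber an antibody PDB into
# Chothia numbering and inspect which chains were detected as heavy/light.
from diffab.tools.renumber import renumber

heavy_ids, light_ids = renumber('my_antibody.pdb', 'my_antibody_chothia.pdb')
print('Heavy chains:', heavy_ids, '| Light chains:', light_ids)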
diffab/tools/runner/design_for_pdb.py
ADDED
@@ -0,0 +1,291 @@
import os
import argparse
import copy
import json
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

from diffab.datasets.custom import preprocess_antibody_structure
from diffab.models import get_model
from diffab.modules.common.geometry import reconstruct_backbone_partially
from diffab.modules.common.so3 import so3vec_to_rotation
from diffab.utils.inference import RemoveNative
from diffab.utils.protein.writers import save_pdb
from diffab.utils.train import recursive_to
from diffab.utils.misc import *
from diffab.utils.data import *
from diffab.utils.transforms import *
from diffab.utils.inference import *
from diffab.tools.renumber import renumber as renumber_antibody


def create_data_variants(config, structure_factory):
    structure = structure_factory()
    structure_id = structure['id']

    data_variants = []
    if config.mode == 'single_cdr':
        cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs)))
        for cdr_name in cdrs:
            transform = Compose([
                MaskSingleCDR(cdr_name, augmentation=False),
                MergeChains(),
            ])
            data_var = transform(structure_factory())
            residue_first, residue_last = get_residue_first_last(data_var)
            data_variants.append({
                'data': data_var,
                'name': f'{structure_id}-{cdr_name}',
                'tag': f'{cdr_name}',
                'cdr': cdr_name,
                'residue_first': residue_first,
                'residue_last': residue_last,
            })
    elif config.mode == 'multiple_cdrs':
        cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs)))
        transform = Compose([
            MaskMultipleCDRs(selection=cdrs, augmentation=False),
            MergeChains(),
        ])
        data_var = transform(structure_factory())
        data_variants.append({
            'data': data_var,
            'name': f'{structure_id}-MultipleCDRs',
            'tag': 'MultipleCDRs',
            'cdrs': cdrs,
            'residue_first': None,
            'residue_last': None,
        })
    elif config.mode == 'full':
        transform = Compose([
            MaskAntibody(),
            MergeChains(),
        ])
        data_var = transform(structure_factory())
        data_variants.append({
            'data': data_var,
            'name': f'{structure_id}-Full',
            'tag': 'Full',
            'residue_first': None,
            'residue_last': None,
        })
    elif config.mode == 'abopt':
        cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs)))
        for cdr_name in cdrs:
            transform = Compose([
                MaskSingleCDR(cdr_name, augmentation=False),
                MergeChains(),
            ])
            data_var = transform(structure_factory())
            residue_first, residue_last = get_residue_first_last(data_var)
            for opt_step in config.sampling.optimize_steps:
                data_variants.append({
                    'data': data_var,
                    'name': f'{structure_id}-{cdr_name}-O{opt_step}',
                    'tag': f'{cdr_name}-O{opt_step}',
                    'cdr': cdr_name,
                    'opt_step': opt_step,
                    'residue_first': residue_first,
                    'residue_last': residue_last,
                })
    else:
        raise ValueError(f'Unknown mode: {config.mode}.')
    return data_variants


def design_for_pdb(args):
    # Load configs
    config, config_name = load_config(args.config)
    seed_all(args.seed if args.seed is not None else config.sampling.seed)

    # Structure loading
    data_id = os.path.basename(args.pdb_path)
    if args.no_renumber:
        pdb_path = args.pdb_path
    else:
        in_pdb_path = args.pdb_path
        out_pdb_path = os.path.splitext(in_pdb_path)[0] + '_chothia.pdb'
        heavy_chains, light_chains = renumber_antibody(in_pdb_path, out_pdb_path)
        pdb_path = out_pdb_path

        if args.heavy is None and len(heavy_chains) > 0:
            args.heavy = heavy_chains[0]
        if args.light is None and len(light_chains) > 0:
            args.light = light_chains[0]
    if args.heavy is None and args.light is None:
        raise ValueError("Neither heavy chain id (--heavy) nor light chain id (--light) is specified.")
    get_structure = lambda: preprocess_antibody_structure({
        'id': data_id,
        'pdb_path': pdb_path,
        'heavy_id': args.heavy,
        # If the input is a nanobody, the light chain will be ignored
        'light_id': args.light,
    })

    # Logging
    structure_ = get_structure()
    structure_id = structure_['id']
    tag_postfix = '_%s' % args.tag if args.tag else ''
    log_dir = get_new_log_dir(
        os.path.join(args.out_root, config_name + tag_postfix),
        prefix=data_id
    )
    logger = get_logger('sample', log_dir)
    logger.info(f'Data ID: {structure_["id"]}')
    logger.info(f'Results will be saved to {log_dir}')
    data_native = MergeChains()(structure_)
    save_pdb(data_native, os.path.join(log_dir, 'reference.pdb'))

    # Load checkpoint and model
    logger.info('Loading model config and checkpoints: %s' % (config.model.checkpoint))
    ckpt = torch.load(config.model.checkpoint, map_location='cpu')
    cfg_ckpt = ckpt['config']
    model = get_model(cfg_ckpt.model).to(args.device)
    lsd = model.load_state_dict(ckpt['model'])
    logger.info(str(lsd))

    # Make data variants
    data_variants = create_data_variants(
        config = config,
        structure_factory = get_structure,
    )

    # Save metadata
    metadata = {
        'identifier': structure_id,
        'index': data_id,
        'config': args.config,
        'items': [{kk: vv for kk, vv in var.items() if kk != 'data'} for var in data_variants],
    }
    with open(os.path.join(log_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    # Start sampling
    collate_fn = PaddingCollate(eight=False)
    inference_tfm = [ PatchAroundAnchor(), ]
    if 'abopt' not in config.mode:  # Don't remove native CDR in optimization mode
        inference_tfm.append(RemoveNative(
            remove_structure = config.sampling.sample_structure,
            remove_sequence = config.sampling.sample_sequence,
        ))
    inference_tfm = Compose(inference_tfm)

    for variant in data_variants:
        os.makedirs(os.path.join(log_dir, variant['tag']), exist_ok=True)
        logger.info(f"Start sampling for: {variant['tag']}")

        save_pdb(data_native, os.path.join(log_dir, variant['tag'], 'REF1.pdb'))    # w/ OpenMM minimization

        data_cropped = inference_tfm(
            copy.deepcopy(variant['data'])
        )
        data_list_repeat = [ data_cropped ] * config.sampling.num_samples
        loader = DataLoader(data_list_repeat, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

        count = 0
        for batch in tqdm(loader, desc=variant['name'], dynamic_ncols=True):
            torch.set_grad_enabled(False)
            model.eval()
            batch = recursive_to(batch, args.device)
            if 'abopt' in config.mode:
                # Antibody optimization starting from native
                traj_batch = model.optimize(batch, opt_step=variant['opt_step'], optimize_opt={
                    'pbar': True,
                    'sample_structure': config.sampling.sample_structure,
                    'sample_sequence': config.sampling.sample_sequence,
                })
            else:
                # De novo design
                traj_batch = model.sample(batch, sample_opt={
                    'pbar': True,
                    'sample_structure': config.sampling.sample_structure,
                    'sample_sequence': config.sampling.sample_sequence,
                })

            aa_new = traj_batch[0][2]   # 0: Last sampling step. 2: Amino acid.
            pos_atom_new, mask_atom_new = reconstruct_backbone_partially(
                pos_ctx = batch['pos_heavyatom'],
                R_new = so3vec_to_rotation(traj_batch[0][0]),
                t_new = traj_batch[0][1],
                aa = aa_new,
                chain_nb = batch['chain_nb'],
                res_nb = batch['res_nb'],
                mask_atoms = batch['mask_heavyatom'],
                mask_recons = batch['generate_flag'],
            )
            aa_new = aa_new.cpu()
            pos_atom_new = pos_atom_new.cpu()
            mask_atom_new = mask_atom_new.cpu()

            for i in range(aa_new.size(0)):
                data_tmpl = variant['data']
                aa = apply_patch_to_tensor(data_tmpl['aa'], aa_new[i], data_cropped['patch_idx'])
                mask_ha = apply_patch_to_tensor(data_tmpl['mask_heavyatom'], mask_atom_new[i], data_cropped['patch_idx'])
                pos_ha = (
                    apply_patch_to_tensor(
                        data_tmpl['pos_heavyatom'],
                        pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(),
                        data_cropped['patch_idx']
                    )
                )

                save_path = os.path.join(log_dir, variant['tag'], '%04d.pdb' % (count, ))
                save_pdb({
                    'chain_nb': data_tmpl['chain_nb'],
                    'chain_id': data_tmpl['chain_id'],
                    'resseq': data_tmpl['resseq'],
                    'icode': data_tmpl['icode'],
                    # Generated
                    'aa': aa,
                    'mask_heavyatom': mask_ha,
                    'pos_heavyatom': pos_ha,
                }, path=save_path)
                # save_pdb({
                #     'chain_nb': data_cropped['chain_nb'],
                #     'chain_id': data_cropped['chain_id'],
                #     'resseq': data_cropped['resseq'],
                #     'icode': data_cropped['icode'],
                #     # Generated
                #     'aa': aa_new[i],
                #     'mask_heavyatom': mask_atom_new[i],
                #     'pos_heavyatom': pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(),
                # }, path=os.path.join(log_dir, variant['tag'], '%04d_patch.pdb' % (count, )))
                count += 1

        logger.info('Finished.\n')


def args_from_cmdline():
    parser = argparse.ArgumentParser()
    parser.add_argument('pdb_path', type=str)
    parser.add_argument('--heavy', type=str, default=None, help='Chain id of the heavy chain.')
    parser.add_argument('--light', type=str, default=None, help='Chain id of the light chain.')
    parser.add_argument('--no_renumber', action='store_true', default=False)
    parser.add_argument('-c', '--config', type=str, default='./configs/test/codesign_single.yml')
    parser.add_argument('-o', '--out_root', type=str, default='./results')
    parser.add_argument('-t', '--tag', type=str, default='')
    parser.add_argument('-s', '--seed', type=int, default=None)
    parser.add_argument('-d', '--device', type=str, default='cuda')
    parser.add_argument('-b', '--batch_size', type=int, default=16)
    args = parser.parse_args()
    return args


def args_factory(**kwargs):
    default_args = EasyDict(
        heavy = 'H',
        light = 'L',
        no_renumber = False,
        config = './configs/test/codesign_single.yml',
        out_root = './results',
        tag = '',
        seed = None,
        device = 'cuda',
        batch_size = 16
    )
    default_args.update(kwargs)
    return default_args


if __name__ == '__main__':
    design_for_pdb(args_from_cmdline())
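A sketch of driving the designer programmatically through args_factory rather than the command line; the PDB path, config path, and device below are placeholders, not values taken from this diff, and the config is assumed to point at a valid checkpoint.

# Hypothetical example: programmatic invocation of the runner above.
from diffab.tools.runner.design_for_pdb import design_for_pdb, args_factory

args = args_factory(
    pdb_path='./example_complex.pdb',              # input antibody(-antigen) structure
    config='./configs/test/codesign_single.yml',   # sampling config with model.checkpoint set
    device='cuda',                                 # or 'cpu'
    batch_size=4,
)
design_for_pdb(args)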
diffab/tools/runner/design_for_testset.py
ADDED
@@ -0,0 +1,243 @@
import os
import argparse
import copy
import json
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

from diffab.datasets import get_dataset
from diffab.models import get_model
from diffab.modules.common.geometry import reconstruct_backbone_partially
from diffab.modules.common.so3 import so3vec_to_rotation
from diffab.utils.inference import RemoveNative
from diffab.utils.protein.writers import save_pdb
from diffab.utils.train import recursive_to
from diffab.utils.misc import *
from diffab.utils.data import *
from diffab.utils.transforms import *
from diffab.utils.inference import *


def create_data_variants(config, structure_factory):
    structure = structure_factory()
    structure_id = structure['id']

    data_variants = []
    if config.mode == 'single_cdr':
        cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs)))
        for cdr_name in cdrs:
            transform = Compose([
                MaskSingleCDR(cdr_name, augmentation=False),
                MergeChains(),
            ])
            data_var = transform(structure_factory())
            residue_first, residue_last = get_residue_first_last(data_var)
            data_variants.append({
                'data': data_var,
                'name': f'{structure_id}-{cdr_name}',
                'tag': f'{cdr_name}',
                'cdr': cdr_name,
                'residue_first': residue_first,
                'residue_last': residue_last,
            })
    elif config.mode == 'multiple_cdrs':
        cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs)))
        transform = Compose([
            MaskMultipleCDRs(selection=cdrs, augmentation=False),
            MergeChains(),
        ])
        data_var = transform(structure_factory())
        data_variants.append({
            'data': data_var,
            'name': f'{structure_id}-MultipleCDRs',
            'tag': 'MultipleCDRs',
            'cdrs': cdrs,
            'residue_first': None,
            'residue_last': None,
        })
    elif config.mode == 'full':
        transform = Compose([
            MaskAntibody(),
            MergeChains(),
        ])
        data_var = transform(structure_factory())
        data_variants.append({
            'data': data_var,
            'name': f'{structure_id}-Full',
            'tag': 'Full',
            'residue_first': None,
            'residue_last': None,
        })
    elif config.mode == 'abopt':
        cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs)))
        for cdr_name in cdrs:
            transform = Compose([
                MaskSingleCDR(cdr_name, augmentation=False),
                MergeChains(),
            ])
            data_var = transform(structure_factory())
            residue_first, residue_last = get_residue_first_last(data_var)
            for opt_step in config.sampling.optimize_steps:
                data_variants.append({
                    'data': data_var,
                    'name': f'{structure_id}-{cdr_name}-O{opt_step}',
                    'tag': f'{cdr_name}-O{opt_step}',
                    'cdr': cdr_name,
                    'opt_step': opt_step,
                    'residue_first': residue_first,
                    'residue_last': residue_last,
                })
    else:
        raise ValueError(f'Unknown mode: {config.mode}.')
    return data_variants

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('index', type=int)
    parser.add_argument('-c', '--config', type=str, default='./configs/test/codesign_single.yml')
    parser.add_argument('-o', '--out_root', type=str, default='./results')
    parser.add_argument('-t', '--tag', type=str, default='')
    parser.add_argument('-s', '--seed', type=int, default=None)
    parser.add_argument('-d', '--device', type=str, default='cuda')
    parser.add_argument('-b', '--batch_size', type=int, default=16)
    args = parser.parse_args()

    # Load configs
    config, config_name = load_config(args.config)
    seed_all(args.seed if args.seed is not None else config.sampling.seed)

    # Testset
    dataset = get_dataset(config.dataset.test)
    get_structure = lambda: dataset[args.index]

    # Logging
    structure_ = get_structure()
    structure_id = structure_['id']
    tag_postfix = '_%s' % args.tag if args.tag else ''
    log_dir = get_new_log_dir(os.path.join(args.out_root, config_name + tag_postfix), prefix='%04d_%s' % (args.index, structure_['id']))
    logger = get_logger('sample', log_dir)
    logger.info('Data ID: %s' % structure_['id'])
    data_native = MergeChains()(structure_)
    save_pdb(data_native, os.path.join(log_dir, 'reference.pdb'))

    # Load checkpoint and model
    logger.info('Loading model config and checkpoints: %s' % (config.model.checkpoint))
    ckpt = torch.load(config.model.checkpoint, map_location='cpu')
    cfg_ckpt = ckpt['config']
    model = get_model(cfg_ckpt.model).to(args.device)
    lsd = model.load_state_dict(ckpt['model'])
    logger.info(str(lsd))

    # Make data variants
    data_variants = create_data_variants(
        config = config,
        structure_factory = get_structure,
    )

    # Save metadata
    metadata = {
        'identifier': structure_id,
        'index': args.index,
        'config': args.config,
        'items': [{kk: vv for kk, vv in var.items() if kk != 'data'} for var in data_variants],
    }
    with open(os.path.join(log_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    # Start sampling
    collate_fn = PaddingCollate(eight=False)
    inference_tfm = [ PatchAroundAnchor(), ]
    if 'abopt' not in config.mode:  # Don't remove native CDR in optimization mode
        inference_tfm.append(RemoveNative(
            remove_structure = config.sampling.sample_structure,
            remove_sequence = config.sampling.sample_sequence,
        ))
    inference_tfm = Compose(inference_tfm)

    for variant in data_variants:
        os.makedirs(os.path.join(log_dir, variant['tag']), exist_ok=True)
        logger.info(f"Start sampling for: {variant['tag']}")

        save_pdb(data_native, os.path.join(log_dir, variant['tag'], 'REF1.pdb'))    # w/ OpenMM minimization

        data_cropped = inference_tfm(
            copy.deepcopy(variant['data']))
        data_list_repeat = [ data_cropped ] * config.sampling.num_samples
        loader = DataLoader(data_list_repeat, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

        count = 0
        for batch in tqdm(loader, desc=variant['name'], dynamic_ncols=True):
            torch.set_grad_enabled(False)
            model.eval()
            batch = recursive_to(batch, args.device)
            if 'abopt' in config.mode:
                # Antibody optimization starting from native
                traj_batch = model.optimize(batch, opt_step=variant['opt_step'], optimize_opt={
                    'pbar': True,
                    'sample_structure': config.sampling.sample_structure,
                    'sample_sequence': config.sampling.sample_sequence,
                })
            else:
                # De novo design
                traj_batch = model.sample(batch, sample_opt={
                    'pbar': True,
                    'sample_structure': config.sampling.sample_structure,
                    'sample_sequence': config.sampling.sample_sequence,
                })

            aa_new = traj_batch[0][2]   # 0: Last sampling step. 2: Amino acid.
            pos_atom_new, mask_atom_new = reconstruct_backbone_partially(
                pos_ctx = batch['pos_heavyatom'],
                R_new = so3vec_to_rotation(traj_batch[0][0]),
                t_new = traj_batch[0][1],
                aa = aa_new,
                chain_nb = batch['chain_nb'],
                res_nb = batch['res_nb'],
                mask_atoms = batch['mask_heavyatom'],
                mask_recons = batch['generate_flag'],
            )
            aa_new = aa_new.cpu()
            pos_atom_new = pos_atom_new.cpu()
            mask_atom_new = mask_atom_new.cpu()

            for i in range(aa_new.size(0)):
                data_tmpl = variant['data']
                aa = apply_patch_to_tensor(data_tmpl['aa'], aa_new[i], data_cropped['patch_idx'])
                mask_ha = apply_patch_to_tensor(data_tmpl['mask_heavyatom'], mask_atom_new[i], data_cropped['patch_idx'])
                pos_ha = (
                    apply_patch_to_tensor(
                        data_tmpl['pos_heavyatom'],
                        pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(),
                        data_cropped['patch_idx']
                    )
                )

                save_path = os.path.join(log_dir, variant['tag'], '%04d.pdb' % (count, ))
                save_pdb({
                    'chain_nb': data_tmpl['chain_nb'],
                    'chain_id': data_tmpl['chain_id'],
                    'resseq': data_tmpl['resseq'],
                    'icode': data_tmpl['icode'],
                    # Generated
                    'aa': aa,
                    'mask_heavyatom': mask_ha,
                    'pos_heavyatom': pos_ha,
                }, path=save_path)
                # save_pdb({
                #     'chain_nb': data_cropped['chain_nb'],
                #     'chain_id': data_cropped['chain_id'],
                #     'resseq': data_cropped['resseq'],
                #     'icode': data_cropped['icode'],
                #     # Generated
                #     'aa': aa_new[i],
                #     'mask_heavyatom': mask_atom_new[i],
                #     'pos_heavyatom': pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(),
                # }, path=os.path.join(log_dir, variant['tag'], '%04d_patch.pdb' % (count, )))
                count += 1

        logger.info('Finished.\n')


if __name__ == '__main__':
    main()
ADDED
@@ -0,0 +1,89 @@
|
import math
import torch
from torch.utils.data._utils.collate import default_collate


DEFAULT_PAD_VALUES = {
    'aa': 21,
    'chain_id': ' ',
    'icode': ' ',
}

DEFAULT_NO_PADDING = {
    'origin',
}

class PaddingCollate(object):

    def __init__(self, length_ref_key='aa', pad_values=DEFAULT_PAD_VALUES, no_padding=DEFAULT_NO_PADDING, eight=True):
        super().__init__()
        self.length_ref_key = length_ref_key
        self.pad_values = pad_values
        self.no_padding = no_padding
        self.eight = eight

    @staticmethod
    def _pad_last(x, n, value=0):
        if isinstance(x, torch.Tensor):
            assert x.size(0) <= n
            if x.size(0) == n:
                return x
            pad_size = [n - x.size(0)] + list(x.shape[1:])
            pad = torch.full(pad_size, fill_value=value).to(x)
            return torch.cat([x, pad], dim=0)
        elif isinstance(x, list):
            pad = [value] * (n - len(x))
            return x + pad
        else:
            return x

    @staticmethod
    def _get_pad_mask(l, n):
        return torch.cat([
            torch.ones([l], dtype=torch.bool),
            torch.zeros([n-l], dtype=torch.bool)
        ], dim=0)

    @staticmethod
    def _get_common_keys(list_of_dict):
        keys = set(list_of_dict[0].keys())
        for d in list_of_dict[1:]:
            keys = keys.intersection(d.keys())
        return keys

    def _get_pad_value(self, key):
        if key not in self.pad_values:
            return 0
        return self.pad_values[key]

    def __call__(self, data_list):
        max_length = max([data[self.length_ref_key].size(0) for data in data_list])
        keys = self._get_common_keys(data_list)

        if self.eight:
            max_length = math.ceil(max_length / 8) * 8
        data_list_padded = []
        for data in data_list:
            data_padded = {
                k: self._pad_last(v, max_length, value=self._get_pad_value(k)) if k not in self.no_padding else v
                for k, v in data.items()
                if k in keys
            }
            data_padded['mask'] = self._get_pad_mask(data[self.length_ref_key].size(0), max_length)
            data_list_padded.append(data_padded)
        return default_collate(data_list_padded)


def apply_patch_to_tensor(x_full, x_patch, patch_idx):
    """
    Args:
        x_full:  (N, ...)
        x_patch: (M, ...)
        patch_idx:  (M, )
    Returns:
        (N, ...)
    """
    x_full = x_full.clone()
    x_full[patch_idx] = x_patch
    return x_full
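A small sketch of how PaddingCollate behaves on variable-length items; the toy tensors below are made up for illustration and are not part of the repository.

# Hypothetical example: pad two items of different lengths to a common length
# (rounded up to a multiple of 8 by default) and collate them into a batch.
import torch
from diffab.utils.data import PaddingCollate

collate = PaddingCollate(length_ref_key='aa')
batch = collate([
    {'aa': torch.zeros(5, dtype=torch.long)},
    {'aa': torch.zeros(7, dtype=torch.long)},
])
print(batch['aa'].shape)    # torch.Size([2, 8]); padded positions filled with 21
print(batch['mask'].shape)  # torch.Size([2, 8]); True marks real residues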
diffab/utils/inference.py
ADDED
@@ -0,0 +1,60 @@
import torch
from .protein import constants


def find_cdrs(structure):
    cdrs = []
    if structure['heavy'] is not None:
        flag = structure['heavy']['cdr_flag']
        if int(constants.CDR.H1) in flag:
            cdrs.append('H_CDR1')
        if int(constants.CDR.H2) in flag:
            cdrs.append('H_CDR2')
        if int(constants.CDR.H3) in flag:
            cdrs.append('H_CDR3')

    if structure['light'] is not None:
        flag = structure['light']['cdr_flag']
        if int(constants.CDR.L1) in flag:
            cdrs.append('L_CDR1')
        if int(constants.CDR.L2) in flag:
            cdrs.append('L_CDR2')
        if int(constants.CDR.L3) in flag:
            cdrs.append('L_CDR3')

    return cdrs


def get_residue_first_last(data):
    loop_flag = data['generate_flag']
    loop_idx = torch.arange(loop_flag.size(0))[loop_flag]
    idx_first, idx_last = loop_idx.min().item(), loop_idx.max().item()
    residue_first = (data['chain_id'][idx_first], data['resseq'][idx_first].item(), data['icode'][idx_first])
    residue_last = (data['chain_id'][idx_last], data['resseq'][idx_last].item(), data['icode'][idx_last])
    return residue_first, residue_last


class RemoveNative(object):

    def __init__(self, remove_structure, remove_sequence):
        super().__init__()
        self.remove_structure = remove_structure
        self.remove_sequence = remove_sequence

    def __call__(self, data):
        generate_flag = data['generate_flag'].clone()
        if self.remove_sequence:
            data['aa'] = torch.where(
                generate_flag,
                torch.full_like(data['aa'], fill_value=int(constants.AA.UNK)),  # Is loop
                data['aa']
            )

        if self.remove_structure:
            data['pos_heavyatom'] = torch.where(
                generate_flag[:, None, None].expand(data['pos_heavyatom'].shape),
                torch.randn_like(data['pos_heavyatom']) * 10,
                data['pos_heavyatom']
            )

        return data
diffab/utils/misc.py
ADDED
@@ -0,0 +1,126 @@
import os
import time
import random
import logging
from typing import OrderedDict
import torch
import torch.linalg
import numpy as np
import yaml
from easydict import EasyDict
from glob import glob


class BlackHole(object):
    def __setattr__(self, name, value):
        pass

    def __call__(self, *args, **kwargs):
        return self

    def __getattr__(self, name):
        return self


class Counter(object):
    def __init__(self, start=0):
        super().__init__()
        self.now = start

    def step(self, delta=1):
        prev = self.now
        self.now += delta
        return prev


def get_logger(name, log_dir=None):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    if log_dir is not None:
        file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger


def get_new_log_dir(root='./logs', prefix='', tag=''):
    fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
    if prefix != '':
        fn = prefix + '_' + fn
    if tag != '':
        fn = fn + '_' + tag
    log_dir = os.path.join(root, fn)
    os.makedirs(log_dir)
    return log_dir


def seed_all(seed):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def inf_iterator(iterable):
    iterator = iterable.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = iterable.__iter__()


def log_hyperparams(writer, args):
    from torch.utils.tensorboard.summary import hparams
    vars_args = {k: v if isinstance(v, str) else repr(v) for k, v in vars(args).items()}
    exp, ssi, sei = hparams(vars_args, {})
    writer.file_writer.add_summary(exp)
    writer.file_writer.add_summary(ssi)
    writer.file_writer.add_summary(sei)


def int_tuple(argstr):
    return tuple(map(int, argstr.split(',')))


def str_tuple(argstr):
    return tuple(argstr.split(','))


def get_checkpoint_path(folder, it=None):
    if it is not None:
        return os.path.join(folder, '%d.pt' % it), it
    all_iters = list(map(lambda x: int(os.path.basename(x[:-3])), glob(os.path.join(folder, '*.pt'))))
    all_iters.sort()
    return os.path.join(folder, '%d.pt' % all_iters[-1]), all_iters[-1]


def load_config(config_path):
    with open(config_path, 'r') as f:
        config = EasyDict(yaml.safe_load(f))
    config_name = os.path.basename(config_path)[:os.path.basename(config_path).rfind('.')]
    return config, config_name


def extract_weights(weights: OrderedDict, prefix):
    extracted = OrderedDict()
    for k, v in weights.items():
        if k.startswith(prefix):
            extracted.update({
                k[len(prefix):]: v
            })
    return extracted


def current_milli_time():
    return round(time.time() * 1000)
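A short sketch tying load_config, seed_all, and the logging helpers together; the YAML path is a placeholder and the config is assumed to define sampling.seed.

# Hypothetical example: load a sampling config, fix the random seed, set up logging.
from diffab.utils.misc import load_config, seed_all, get_logger, get_new_log_dir

config, config_name = load_config('./configs/test/codesign_single.yml')
seed_all(config.sampling.seed)
log_dir = get_new_log_dir('./results', prefix=config_name)
logger = get_logger('demo', log_dir)
logger.info('Loaded config: %s' % config_name)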
diffab/utils/protein/constants.py
ADDED
@@ -0,0 +1,319 @@
import torch
import enum

class CDR(enum.IntEnum):
    H1 = 1
    H2 = 2
    H3 = 3
    L1 = 4
    L2 = 5
    L3 = 6


class ChothiaCDRRange:
    H1 = (26, 32)
    H2 = (52, 56)
    H3 = (95, 102)

    L1 = (24, 34)
    L2 = (50, 56)
    L3 = (89, 97)

    @classmethod
    def to_cdr(cls, chain_type, resseq):
        assert chain_type in ('H', 'L')
        if chain_type == 'H':
            if cls.H1[0] <= resseq <= cls.H1[1]:
                return CDR.H1
            elif cls.H2[0] <= resseq <= cls.H2[1]:
                return CDR.H2
            elif cls.H3[0] <= resseq <= cls.H3[1]:
                return CDR.H3
        elif chain_type == 'L':
            if cls.L1[0] <= resseq <= cls.L1[1]:    # Chothia VL-CDR1
                return CDR.L1
            elif cls.L2[0] <= resseq <= cls.L2[1]:
                return CDR.L2
            elif cls.L3[0] <= resseq <= cls.L3[1]:
                return CDR.L3


class Fragment(enum.IntEnum):
    Heavy = 1
    Light = 2
    Antigen = 3

##
# Residue identities
"""
This is part of the OpenMM molecular simulation toolkit originating from
Simbios, the NIH National Center for Physics-Based Simulation of
Biological Structures at Stanford, funded under the NIH Roadmap for
Medical Research, grant U54 GM072970. See https://simtk.org.

Portions copyright (c) 2013 Stanford University and the Authors.
Authors: Peter Eastman
Contributors:

Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
non_standard_residue_substitutions = {
    '2AS':'ASP', '3AH':'HIS', '5HP':'GLU', 'ACL':'ARG', 'AGM':'ARG', 'AIB':'ALA', 'ALM':'ALA', 'ALO':'THR', 'ALY':'LYS', 'ARM':'ARG',
    'ASA':'ASP', 'ASB':'ASP', 'ASK':'ASP', 'ASL':'ASP', 'ASQ':'ASP', 'AYA':'ALA', 'BCS':'CYS', 'BHD':'ASP', 'BMT':'THR', 'BNN':'ALA',
    'BUC':'CYS', 'BUG':'LEU', 'C5C':'CYS', 'C6C':'CYS', 'CAS':'CYS', 'CCS':'CYS', 'CEA':'CYS', 'CGU':'GLU', 'CHG':'ALA', 'CLE':'LEU', 'CME':'CYS',
    'CSD':'ALA', 'CSO':'CYS', 'CSP':'CYS', 'CSS':'CYS', 'CSW':'CYS', 'CSX':'CYS', 'CXM':'MET', 'CY1':'CYS', 'CY3':'CYS', 'CYG':'CYS',
    'CYM':'CYS', 'CYQ':'CYS', 'DAH':'PHE', 'DAL':'ALA', 'DAR':'ARG', 'DAS':'ASP', 'DCY':'CYS', 'DGL':'GLU', 'DGN':'GLN', 'DHA':'ALA',
    'DHI':'HIS', 'DIL':'ILE', 'DIV':'VAL', 'DLE':'LEU', 'DLY':'LYS', 'DNP':'ALA', 'DPN':'PHE', 'DPR':'PRO', 'DSN':'SER', 'DSP':'ASP',
    'DTH':'THR', 'DTR':'TRP', 'DTY':'TYR', 'DVA':'VAL', 'EFC':'CYS', 'FLA':'ALA', 'FME':'MET', 'GGL':'GLU', 'GL3':'GLY', 'GLZ':'GLY',
    'GMA':'GLU', 'GSC':'GLY', 'HAC':'ALA', 'HAR':'ARG', 'HIC':'HIS', 'HIP':'HIS', 'HMR':'ARG', 'HPQ':'PHE', 'HTR':'TRP', 'HYP':'PRO',
    'IAS':'ASP', 'IIL':'ILE', 'IYR':'TYR', 'KCX':'LYS', 'LLP':'LYS', 'LLY':'LYS', 'LTR':'TRP', 'LYM':'LYS', 'LYZ':'LYS', 'MAA':'ALA', 'MEN':'ASN',
    'MHS':'HIS', 'MIS':'SER', 'MLE':'LEU', 'MPQ':'GLY', 'MSA':'GLY', 'MSE':'MET', 'MVA':'VAL', 'NEM':'HIS', 'NEP':'HIS', 'NLE':'LEU',
    'NLN':'LEU', 'NLP':'LEU', 'NMC':'GLY', 'OAS':'SER', 'OCS':'CYS', 'OMT':'MET', 'PAQ':'TYR', 'PCA':'GLU', 'PEC':'CYS', 'PHI':'PHE',
    'PHL':'PHE', 'PR3':'CYS', 'PRR':'ALA', 'PTR':'TYR', 'PYX':'CYS', 'SAC':'SER', 'SAR':'GLY', 'SCH':'CYS', 'SCS':'CYS', 'SCY':'CYS',
    'SEL':'SER', 'SEP':'SER', 'SET':'SER', 'SHC':'CYS', 'SHR':'LYS', 'SMC':'CYS', 'SOC':'CYS', 'STY':'TYR', 'SVA':'SER', 'TIH':'ALA',
    'TPL':'TRP', 'TPO':'THR', 'TPQ':'ALA', 'TRG':'LYS', 'TRO':'TRP', 'TYB':'TYR', 'TYI':'TYR', 'TYQ':'TYR', 'TYS':'TYR', 'TYY':'TYR'
}


ressymb_to_resindex = {
    'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4,
    'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9,
    'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14,
    'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19,
    'X': 20,
}


class AA(enum.IntEnum):
    ALA = 0; CYS = 1; ASP = 2; GLU = 3; PHE = 4
    GLY = 5; HIS = 6; ILE = 7; LYS = 8; LEU = 9
    MET = 10; ASN = 11; PRO = 12; GLN = 13; ARG = 14
    SER = 15; THR = 16; VAL = 17; TRP = 18; TYR = 19
    UNK = 20

    @classmethod
    def _missing_(cls, value):
        if isinstance(value, str) and len(value) == 3:      # three-letter representation
            if value in non_standard_residue_substitutions:
                value = non_standard_residue_substitutions[value]
            if value in cls._member_names_:
                return getattr(cls, value)
        elif isinstance(value, str) and len(value) == 1:    # one-letter representation
            if value in ressymb_to_resindex:
                return cls(ressymb_to_resindex[value])

        return super()._missing_(value)

    def __str__(self):
        return self.name

    @classmethod
    def is_aa(cls, value):
        return (value in ressymb_to_resindex) or \
            (value in non_standard_residue_substitutions) or \
            (value in cls._member_names_) or \
            (value in cls._member_map_.values())


num_aa_types = len(AA)

##
# Atom identities

class BBHeavyAtom(enum.IntEnum):
    N = 0; CA = 1; C = 2; O = 3; CB = 4; OXT = 14

max_num_heavyatoms = 15

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
restype_to_heavyatom_names = {
    AA.ALA: ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', '', 'OXT'],
    AA.ARG: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', '', 'OXT'],
    AA.ASN: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', '', 'OXT'],
    AA.ASP: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', '', 'OXT'],
    AA.CYS: ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', '', 'OXT'],
    AA.GLN: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', '', 'OXT'],
    AA.GLU: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', '', 'OXT'],
    AA.GLY: ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', '', 'OXT'],
    AA.HIS: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', '', 'OXT'],
    AA.ILE: ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', '', 'OXT'],
    AA.LEU: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', '', 'OXT'],
    AA.LYS: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', '', 'OXT'],
    AA.MET: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', '', 'OXT'],
    AA.PHE: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', '', 'OXT'],
    AA.PRO: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', '', 'OXT'],
    AA.SER: ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', '', 'OXT'],
    AA.THR: ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', '', 'OXT'],
    AA.TRP: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2', 'OXT'],
    AA.TYR: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', '', 'OXT'],
    AA.VAL: ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', '', 'OXT'],
    AA.UNK: ['', '', '', '', '', '', '', '', '', '', '', '', '', '', ''],
}
for names in restype_to_heavyatom_names.values(): assert len(names) == max_num_heavyatoms


backbone_atom_coordinates = {
    AA.ALA: [
        (-0.525, 1.363, 0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.526, -0.0, -0.0),  # C
    ],
    AA.ARG: [
        (-0.524, 1.362, -0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.525, -0.0, -0.0),  # C
    ],
    AA.ASN: [
        (-0.536, 1.357, 0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.526, -0.0, -0.0),  # C
    ],
    AA.ASP: [
        (-0.525, 1.362, -0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.527, 0.0, -0.0),  # C
    ],
    AA.CYS: [
        (-0.522, 1.362, -0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.524, 0.0, 0.0),  # C
    ],
    AA.GLN: [
        (-0.526, 1.361, -0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.526, 0.0, 0.0),  # C
    ],
    AA.GLU: [
        (-0.528, 1.361, 0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.526, -0.0, -0.0),  # C
    ],
    AA.GLY: [
        (-0.572, 1.337, 0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.517, -0.0, -0.0),  # C
    ],
    AA.HIS: [
        (-0.527, 1.36, 0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.525, 0.0, 0.0),  # C
    ],
    AA.ILE: [
        (-0.493, 1.373, -0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.527, -0.0, -0.0),  # C
    ],
    AA.LEU: [
        (-0.52, 1.363, 0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.525, -0.0, -0.0),  # C
    ],
    AA.LYS: [
        (-0.526, 1.362, -0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.526, 0.0, 0.0),  # C
    ],
    AA.MET: [
        (-0.521, 1.364, -0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.525, 0.0, 0.0),  # C
    ],
    AA.PHE: [
        (-0.518, 1.363, 0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.524, 0.0, -0.0),  # C
    ],
    AA.PRO: [
        (-0.566, 1.351, -0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.527, -0.0, 0.0),  # C
    ],
    AA.SER: [
        (-0.529, 1.36, -0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.525, -0.0, -0.0),  # C
    ],
    AA.THR: [
        (-0.517, 1.364, 0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.526, 0.0, -0.0),  # C
    ],
    AA.TRP: [
        (-0.521, 1.363, 0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.525, -0.0, 0.0),  # C
    ],
    AA.TYR: [
        (-0.522, 1.362, 0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.524, -0.0, -0.0),  # C
    ],
    AA.VAL: [
        (-0.494, 1.373, -0.0),  # N
        (0.0, 0.0, 0.0),  # CA
        (1.527, -0.0, -0.0),  # C
    ],
}

bb_oxygen_coordinate = {
    AA.ALA: (2.153, -1.062, 0.0),
    AA.ARG: (2.151, -1.062, 0.0),
    AA.ASN: (2.151, -1.062, 0.0),
    AA.ASP: (2.153, -1.062, 0.0),
    AA.CYS: (2.149, -1.062, 0.0),
    AA.GLN: (2.152, -1.062, 0.0),
    AA.GLU: (2.152, -1.062, 0.0),
    AA.GLY: (2.143, -1.062, 0.0),
    AA.HIS: (2.15, -1.063, 0.0),
    AA.ILE: (2.154, -1.062, 0.0),
    AA.LEU: (2.15, -1.063, 0.0),
    AA.LYS: (2.152, -1.062, 0.0),
    AA.MET: (2.15, -1.062, 0.0),
    AA.PHE: (2.15, -1.062, 0.0),
    AA.PRO: (2.148, -1.066, 0.0),
    AA.SER: (2.151, -1.062, 0.0),
    AA.THR: (2.152, -1.062, 0.0),
    AA.TRP: (2.152, -1.062, 0.0),
    AA.TYR: (2.151, -1.062, 0.0),
    AA.VAL: (2.154, -1.062, 0.0),
}

backbone_atom_coordinates_tensor = torch.zeros([21, 3, 3])
bb_oxygen_coordinate_tensor = torch.zeros([21, 3])

def make_coordinate_tensors():
    for restype, atom_coords in backbone_atom_coordinates.items():
        for atom_id, atom_coord in enumerate(atom_coords):
            backbone_atom_coordinates_tensor[restype][atom_id] = torch.FloatTensor(atom_coord)

    for restype, bb_oxy_coord in bb_oxygen_coordinate.items():
        bb_oxygen_coordinate_tensor[restype] = torch.FloatTensor(bb_oxy_coord)
make_coordinate_tensors()
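A quick sketch of how the AA enum's _missing_ hook lets three-letter, one-letter, and non-standard residue names resolve to the same index, and how the Chothia ranges map a residue number to a CDR; purely illustrative, not part of the diff.

# Hypothetical example: flexible residue-type construction and CDR lookup.
from diffab.utils.protein.constants import AA, CDR, ChothiaCDRRange

assert AA('ALA') == AA('A') == AA.ALA == 0
assert AA('MSE') == AA.MET                           # selenomethionine maps to MET
assert ChothiaCDRRange.to_cdr('H', 100) == CDR.H3    # residue 100 falls in Chothia H3 (95-102)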
diffab/utils/protein/parsers.py
ADDED
@@ -0,0 +1,109 @@
import torch
from Bio.PDB import Selection
from Bio.PDB.Residue import Residue
from easydict import EasyDict

from .constants import (
    AA, max_num_heavyatoms,
    restype_to_heavyatom_names,
    BBHeavyAtom
)


class ParsingException(Exception):
    pass


def _get_residue_heavyatom_info(res: Residue):
    pos_heavyatom = torch.zeros([max_num_heavyatoms, 3], dtype=torch.float)
    mask_heavyatom = torch.zeros([max_num_heavyatoms, ], dtype=torch.bool)
    restype = AA(res.get_resname())
    for idx, atom_name in enumerate(restype_to_heavyatom_names[restype]):
        if atom_name == '': continue
        if atom_name in res:
            pos_heavyatom[idx] = torch.tensor(res[atom_name].get_coord().tolist(), dtype=pos_heavyatom.dtype)
            mask_heavyatom[idx] = True
    return pos_heavyatom, mask_heavyatom


def parse_biopython_structure(entity, unknown_threshold=1.0, max_resseq=None):
    chains = Selection.unfold_entities(entity, 'C')
    chains.sort(key=lambda c: c.get_id())
    data = EasyDict({
        'chain_id': [],
        'resseq': [], 'icode': [], 'res_nb': [],
        'aa': [],
        'pos_heavyatom': [], 'mask_heavyatom': [],
    })
    tensor_types = {
        'resseq': torch.LongTensor,
        'res_nb': torch.LongTensor,
        'aa': torch.LongTensor,
        'pos_heavyatom': torch.stack,
        'mask_heavyatom': torch.stack,
    }

    count_aa, count_unk = 0, 0

    for i, chain in enumerate(chains):
        seq_this = 0    # Renumbering residues
        residues = Selection.unfold_entities(chain, 'R')
        residues.sort(key=lambda res: (res.get_id()[1], res.get_id()[2]))   # Sort residues by resseq-icode
        for _, res in enumerate(residues):
            resseq_this = int(res.get_id()[1])
            if max_resseq is not None and resseq_this > max_resseq:
                continue

            resname = res.get_resname()
            if not AA.is_aa(resname): continue
            if not (res.has_id('CA') and res.has_id('C') and res.has_id('N')): continue
            restype = AA(resname)
            count_aa += 1
            if restype == AA.UNK:
                count_unk += 1
                continue

            # Chain info
            data.chain_id.append(chain.get_id())

            # Residue types
            data.aa.append(restype)  # Will be automatically cast to torch.long

            # Heavy atoms
            pos_heavyatom, mask_heavyatom = _get_residue_heavyatom_info(res)
            data.pos_heavyatom.append(pos_heavyatom)
            data.mask_heavyatom.append(mask_heavyatom)

            # Sequential number
            resseq_this = int(res.get_id()[1])
            icode_this = res.get_id()[2]
            if seq_this == 0:
                seq_this = 1
            else:
                d_CA_CA = torch.linalg.norm(data.pos_heavyatom[-2][BBHeavyAtom.CA] - data.pos_heavyatom[-1][BBHeavyAtom.CA], ord=2).item()
                if d_CA_CA <= 4.0:
                    seq_this += 1
                else:
                    d_resseq = resseq_this - data.resseq[-1]
                    seq_this += max(2, d_resseq)

            data.resseq.append(resseq_this)
            data.icode.append(icode_this)
            data.res_nb.append(seq_this)

    if len(data.aa) == 0:
        raise ParsingException('No parsed residues.')

    if (count_unk / count_aa) >= unknown_threshold:
        raise ParsingException(
            f'Too many unknown residues, threshold {unknown_threshold:.2f}.'
        )

    seq_map = {}
    for i, (chain_id, resseq, icode) in enumerate(zip(data.chain_id, data.resseq, data.icode)):
        seq_map[(chain_id, resseq, icode)] = i

    for key, convert_fn in tensor_types.items():
        data[key] = convert_fn(data[key])

    return data, seq_map
diffab/utils/protein/writers.py
ADDED
@@ -0,0 +1,75 @@
1 |
+
import torch
|
2 |
+
import warnings
|
3 |
+
from Bio import BiopythonWarning
|
4 |
+
from Bio.PDB import PDBIO
|
5 |
+
from Bio.PDB.StructureBuilder import StructureBuilder
|
6 |
+
|
7 |
+
from .constants import AA, restype_to_heavyatom_names
|
8 |
+
|
9 |
+
|
10 |
+
def save_pdb(data, path=None):
|
11 |
+
"""
|
12 |
+
Args:
|
13 |
+
data: A dict that contains: `chain_nb`, `chain_id`, `aa`, `resseq`, `icode`,
|
14 |
+
`pos_heavyatom`, `mask_heavyatom`.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def _mask_select(v, mask):
|
18 |
+
if isinstance(v, str):
|
19 |
+
return ''.join([s for i, s in enumerate(v) if mask[i]])
|
20 |
+
elif isinstance(v, list):
|
21 |
+
return [s for i, s in enumerate(v) if mask[i]]
|
22 |
+
elif isinstance(v, torch.Tensor):
|
23 |
+
return v[mask]
|
24 |
+
else:
|
25 |
+
return v
|
26 |
+
|
27 |
+
def _build_chain(builder, aa_ch, pos_heavyatom_ch, mask_heavyatom_ch, chain_id_ch, resseq_ch, icode_ch):
|
28 |
+
builder.init_chain(chain_id_ch[0])
|
29 |
+
builder.init_seg(' ')
|
30 |
+
|
31 |
+
for aa_res, pos_allatom_res, mask_allatom_res, resseq_res, icode_res in \
|
32 |
+
zip(aa_ch, pos_heavyatom_ch, mask_heavyatom_ch, resseq_ch, icode_ch):
|
33 |
+
if not AA.is_aa(aa_res.item()):
|
34 |
+
print('[Warning] Unknown amino acid type at %d%s: %r' % (resseq_res.item(), icode_res, aa_res.item()))
|
35 |
+
continue
|
36 |
+
restype = AA(aa_res.item())
|
37 |
+
builder.init_residue(
|
38 |
+
resname = str(restype),
|
39 |
+
field = ' ',
|
40 |
+
resseq = resseq_res.item(),
|
41 |
+
icode = icode_res,
|
42 |
+
)
|
43 |
+
|
44 |
+
for i, atom_name in enumerate(restype_to_heavyatom_names[restype]):
|
45 |
+
if atom_name == '': continue # No expected atom
|
46 |
+
if (~mask_allatom_res[i]).any(): continue # Atom is missing
|
47 |
+
if len(atom_name) == 1: fullname = ' %s ' % atom_name
|
48 |
+
elif len(atom_name) == 2: fullname = ' %s ' % atom_name
|
49 |
+
elif len(atom_name) == 3: fullname = ' %s' % atom_name
|
50 |
+
else: fullname = atom_name # len == 4
|
51 |
+
builder.init_atom(atom_name, pos_allatom_res[i].tolist(), 0.0, 1.0, ' ', fullname,)
|
52 |
+
|
53 |
+
warnings.simplefilter('ignore', BiopythonWarning)
|
54 |
+
builder = StructureBuilder()
|
55 |
+
builder.init_structure(0)
|
56 |
+
builder.init_model(0)
|
57 |
+
|
58 |
+
unique_chain_nb = data['chain_nb'].unique().tolist()
|
59 |
+
for ch_nb in unique_chain_nb:
|
60 |
+
mask = (data['chain_nb'] == ch_nb)
|
61 |
+
aa = _mask_select(data['aa'], mask)
|
62 |
+
pos_heavyatom = _mask_select(data['pos_heavyatom'], mask)
|
63 |
+
mask_heavyatom = _mask_select(data['mask_heavyatom'], mask)
|
64 |
+
chain_id = _mask_select(data['chain_id'], mask)
|
65 |
+
resseq = _mask_select(data['resseq'], mask)
|
66 |
+
icode = _mask_select(data['icode'], mask)
|
67 |
+
|
68 |
+
_build_chain(builder, aa, pos_heavyatom, mask_heavyatom, chain_id, resseq, icode)
|
69 |
+
|
70 |
+
structure = builder.get_structure()
|
71 |
+
if path is not None:
|
72 |
+
io = PDBIO()
|
73 |
+
io.set_structure(structure)
|
74 |
+
io.save(path)
|
75 |
+
return structure
|
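A minimal usage sketch of the writer together with the parser above. The file names, and the assumption that parsers.py exposes a `parse_biopython_structure` entry point returning `(data, seq_map)`, are illustrative and not part of this diff:

from Bio.PDB import PDBParser

from diffab.utils.protein.parsers import parse_biopython_structure  # assumed entry point
from diffab.utils.protein.writers import save_pdb

# Parse the first model of a PDB file into the flat tensor dict used throughout this repo,
# then write it back out; save_pdb only needs the keys listed in its docstring.
model = PDBParser(QUIET=True).get_structure(None, 'input.pdb')[0]
data, seq_map = parse_biopython_structure(model)
save_pdb(data, path='roundtrip.pdb')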
diffab/utils/train.py
ADDED
@@ -0,0 +1,151 @@
import numpy as np
import torch
from easydict import EasyDict

from .misc import BlackHole


def get_optimizer(cfg, model):
    if cfg.type == 'adam':
        return torch.optim.Adam(
            model.parameters(),
            lr=cfg.lr,
            weight_decay=cfg.weight_decay,
            betas=(cfg.beta1, cfg.beta2, )
        )
    else:
        raise NotImplementedError('Optimizer not supported: %s' % cfg.type)


def get_scheduler(cfg, optimizer):
    if cfg.type is None:
        return BlackHole()
    elif cfg.type == 'plateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=cfg.factor,
            patience=cfg.patience,
            min_lr=cfg.min_lr,
        )
    elif cfg.type == 'multistep':
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=cfg.milestones,
            gamma=cfg.gamma,
        )
    elif cfg.type == 'exp':
        return torch.optim.lr_scheduler.ExponentialLR(
            optimizer,
            gamma=cfg.gamma,
        )
    elif cfg.type is None:
        return BlackHole()
    else:
        raise NotImplementedError('Scheduler not supported: %s' % cfg.type)


def get_warmup_sched(cfg, optimizer):
    if cfg is None: return BlackHole()
    lambdas = [lambda it : (it / cfg.max_iters) if it <= cfg.max_iters else 1 for _ in optimizer.param_groups]
    warmup_sched = torch.optim.lr_scheduler.LambdaLR(optimizer, lambdas)
    return warmup_sched


def log_losses(out, it, tag, logger=BlackHole(), writer=BlackHole(), others={}):
    logstr = '[%s] Iter %05d' % (tag, it)
    logstr += ' | loss %.4f' % out['overall'].item()
    for k, v in out.items():
        if k == 'overall': continue
        logstr += ' | loss(%s) %.4f' % (k, v.item())
    for k, v in others.items():
        logstr += ' | %s %2.4f' % (k, v)
    logger.info(logstr)

    for k, v in out.items():
        if k == 'overall':
            writer.add_scalar('%s/loss' % tag, v, it)
        else:
            writer.add_scalar('%s/loss_%s' % (tag, k), v, it)
    for k, v in others.items():
        writer.add_scalar('%s/%s' % (tag, k), v, it)
    writer.flush()


class ValidationLossTape(object):

    def __init__(self):
        super().__init__()
        self.accumulate = {}
        self.others = {}
        self.total = 0

    def update(self, out, n, others={}):
        self.total += n
        for k, v in out.items():
            if k not in self.accumulate:
                self.accumulate[k] = v.clone().detach()
            else:
                self.accumulate[k] += v.clone().detach()

        for k, v in others.items():
            if k not in self.others:
                self.others[k] = v.clone().detach()
            else:
                self.others[k] += v.clone().detach()

    def log(self, it, logger=BlackHole(), writer=BlackHole(), tag='val'):
        avg = EasyDict({k:v / self.total for k, v in self.accumulate.items()})
        avg_others = EasyDict({k:v / self.total for k, v in self.others.items()})
        log_losses(avg, it, tag, logger, writer, others=avg_others)
        return avg['overall']


def recursive_to(obj, device):
    if isinstance(obj, torch.Tensor):
        if device == 'cpu':
            return obj.cpu()
        try:
            return obj.cuda(device=device, non_blocking=True)
        except RuntimeError:
            return obj.to(device)
    elif isinstance(obj, list):
        return [recursive_to(o, device=device) for o in obj]
    elif isinstance(obj, tuple):
        return tuple(recursive_to(o, device=device) for o in obj)
    elif isinstance(obj, dict):
        return {k: recursive_to(v, device=device) for k, v in obj.items()}
    else:
        return obj


def reweight_loss_by_sequence_length(length, max_length, mode='sqrt'):
    if mode == 'sqrt':
        w = np.sqrt(length / max_length)
    elif mode == 'linear':
        w = length / max_length
    elif mode is None:
        w = 1.0
    else:
        raise ValueError('Unknown reweighting mode: %s' % mode)
    return w


def sum_weighted_losses(losses, weights):
    """
    Args:
        losses: Dict of scalar tensors.
        weights: Dict of weights.
    """
    loss = 0
    for k in losses.keys():
        if weights is None:
            loss = loss + losses[k]
        else:
            loss = loss + weights[k] * losses[k]
    return loss


def count_parameters(model):
    return sum(p.numel() for p in model.parameters())
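A short sketch of how these helpers fit together in a training script; the optimizer and scheduler settings below are illustrative assumptions, not values taken from this repository's configs:

import torch
from easydict import EasyDict

from diffab.utils.train import (
    get_optimizer, get_scheduler, sum_weighted_losses, count_parameters,
)

model = torch.nn.Linear(16, 1)
print('Parameters:', count_parameters(model))

# Build optimizer and LR scheduler from config-style EasyDicts (values are made up).
optimizer = get_optimizer(EasyDict(type='adam', lr=1e-4, weight_decay=0.0, beta1=0.9, beta2=0.999), model)
scheduler = get_scheduler(EasyDict(type='plateau', factor=0.8, patience=10, min_lr=1e-6), optimizer)

# Combine per-term losses into the overall training loss.
losses = {'pos': torch.tensor(1.2), 'rot': torch.tensor(0.5)}
overall = sum_weighted_losses(losses, {'pos': 1.0, 'rot': 1.0})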
diffab/utils/transforms/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Transforms
from .mask import MaskSingleCDR, MaskMultipleCDRs, MaskAntibody
from .merge import MergeChains
from .patch import PatchAroundAnchor

# Factory
from ._base import get_transform, Compose
diffab/utils/transforms/_base.py
ADDED
@@ -0,0 +1,56 @@
import copy
import torch
from torchvision.transforms import Compose


_TRANSFORM_DICT = {}


def register_transform(name):
    def decorator(cls):
        _TRANSFORM_DICT[name] = cls
        return cls
    return decorator


def get_transform(cfg):
    if cfg is None or len(cfg) == 0:
        return None
    tfms = []
    for t_dict in cfg:
        t_dict = copy.deepcopy(t_dict)
        cls = _TRANSFORM_DICT[t_dict.pop('type')]
        tfms.append(cls(**t_dict))
    return Compose(tfms)


def _index_select(v, index, n):
    if isinstance(v, torch.Tensor) and v.size(0) == n:
        return v[index]
    elif isinstance(v, list) and len(v) == n:
        return [v[i] for i in index]
    else:
        return v


def _index_select_data(data, index):
    return {
        k: _index_select(v, index, data['aa'].size(0))
        for k, v in data.items()
    }


def _mask_select(v, mask):
    if isinstance(v, torch.Tensor) and v.size(0) == mask.size(0):
        return v[mask]
    elif isinstance(v, list) and len(v) == mask.size(0):
        return [v[i] for i, b in enumerate(mask) if b]
    else:
        return v


def _mask_select_data(data, mask):
    return {
        k: _mask_select(v, mask)
        for k, v in data.items()
    }
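A minimal sketch of the registry pattern defined here: a hypothetical transform registered under a new name ('identity' is illustrative, not part of the repo) and instantiated from a config list:

from diffab.utils.transforms._base import register_transform, get_transform

@register_transform('identity')
class Identity(object):
    """No-op transform used only to illustrate registration."""
    def __call__(self, data):
        return data

# get_transform pops 'type' from each entry and passes the rest as constructor kwargs,
# returning a torchvision Compose over the instantiated transforms.
pipeline = get_transform([{'type': 'identity'}])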
diffab/utils/transforms/mask.py
ADDED
@@ -0,0 +1,213 @@
import torch
import random
from typing import List, Optional

from ..protein import constants
from ._base import register_transform


def random_shrink_extend(flag, min_length=5, shrink_limit=1, extend_limit=2):
    first, last = continuous_flag_to_range(flag)
    length = flag.sum().item()
    if (length - 2*shrink_limit) < min_length:
        shrink_limit = 0
    first_ext = max(0, first-random.randint(-shrink_limit, extend_limit))
    last_ext = min(last+random.randint(-shrink_limit, extend_limit), flag.size(0)-1)
    flag_ext = flag.clone()
    flag_ext[first_ext : last_ext+1] = True
    return flag_ext


def continuous_flag_to_range(flag):
    first = (torch.arange(0, flag.size(0))[flag]).min().item()
    last = (torch.arange(0, flag.size(0))[flag]).max().item()
    return first, last


@register_transform('mask_single_cdr')
class MaskSingleCDR(object):

    def __init__(self, selection=None, augmentation=True):
        super().__init__()
        cdr_str_to_enum = {
            'H1': constants.CDR.H1,
            'H2': constants.CDR.H2,
            'H3': constants.CDR.H3,
            'L1': constants.CDR.L1,
            'L2': constants.CDR.L2,
            'L3': constants.CDR.L3,
            'H_CDR1': constants.CDR.H1,
            'H_CDR2': constants.CDR.H2,
            'H_CDR3': constants.CDR.H3,
            'L_CDR1': constants.CDR.L1,
            'L_CDR2': constants.CDR.L2,
            'L_CDR3': constants.CDR.L3,
            'CDR3': 'CDR3',     # H3 first, then fallback to L3
        }
        assert selection is None or selection in cdr_str_to_enum
        self.selection = cdr_str_to_enum.get(selection, None)
        self.augmentation = augmentation

    def perform_masking_(self, data, selection=None):
        cdr_flag = data['cdr_flag']

        if selection is None:
            cdr_all = cdr_flag[cdr_flag > 0].unique().tolist()
            cdr_to_mask = random.choice(cdr_all)
        else:
            cdr_to_mask = selection

        cdr_to_mask_flag = (cdr_flag == cdr_to_mask)
        if self.augmentation:
            cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag)

        cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag)
        left_idx = max(0, cdr_first-1)
        right_idx = min(data['aa'].size(0)-1, cdr_last+1)
        anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool)
        anchor_flag[left_idx] = True
        anchor_flag[right_idx] = True

        data['generate_flag'] = cdr_to_mask_flag
        data['anchor_flag'] = anchor_flag

    def __call__(self, structure):
        if self.selection is None:
            ab_data = []
            if structure['heavy'] is not None:
                ab_data.append(structure['heavy'])
            if structure['light'] is not None:
                ab_data.append(structure['light'])
            data_to_mask = random.choice(ab_data)
            sel = None
        elif self.selection in (constants.CDR.H1, constants.CDR.H2, constants.CDR.H3, ):
            data_to_mask = structure['heavy']
            sel = int(self.selection)
        elif self.selection in (constants.CDR.L1, constants.CDR.L2, constants.CDR.L3, ):
            data_to_mask = structure['light']
            sel = int(self.selection)
        elif self.selection == 'CDR3':
            if structure['heavy'] is not None:
                data_to_mask = structure['heavy']
                sel = constants.CDR.H3
            else:
                data_to_mask = structure['light']
                sel = constants.CDR.L3

        self.perform_masking_(data_to_mask, selection=sel)
        return structure


@register_transform('mask_multiple_cdrs')
class MaskMultipleCDRs(object):

    def __init__(self, selection: Optional[List[str]]=None, augmentation=True):
        super().__init__()
        cdr_str_to_enum = {
            'H1': constants.CDR.H1,
            'H2': constants.CDR.H2,
            'H3': constants.CDR.H3,
            'L1': constants.CDR.L1,
            'L2': constants.CDR.L2,
            'L3': constants.CDR.L3,
            'H_CDR1': constants.CDR.H1,
            'H_CDR2': constants.CDR.H2,
            'H_CDR3': constants.CDR.H3,
            'L_CDR1': constants.CDR.L1,
            'L_CDR2': constants.CDR.L2,
            'L_CDR3': constants.CDR.L3,
        }
        if selection is not None:
            self.selection = [cdr_str_to_enum[s] for s in selection]
        else:
            self.selection = None
        self.augmentation = augmentation

    def mask_one_cdr_(self, data, cdr_to_mask):
        cdr_flag = data['cdr_flag']

        cdr_to_mask_flag = (cdr_flag == cdr_to_mask)
        if self.augmentation:
            cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag)

        cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag)
        left_idx = max(0, cdr_first-1)
        right_idx = min(data['aa'].size(0)-1, cdr_last+1)
        anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool)
        anchor_flag[left_idx] = True
        anchor_flag[right_idx] = True

        if 'generate_flag' not in data:
            data['generate_flag'] = cdr_to_mask_flag
            data['anchor_flag'] = anchor_flag
        else:
            data['generate_flag'] |= cdr_to_mask_flag
            data['anchor_flag'] |= anchor_flag

    def mask_for_one_chain_(self, data):
        cdr_flag = data['cdr_flag']
        cdr_all = cdr_flag[cdr_flag > 0].unique().tolist()

        num_cdrs_to_mask = random.randint(1, len(cdr_all))

        if self.selection is not None:
            cdrs_to_mask = list(set(cdr_all).intersection(self.selection))
        else:
            random.shuffle(cdr_all)
            cdrs_to_mask = cdr_all[:num_cdrs_to_mask]

        for cdr_to_mask in cdrs_to_mask:
            self.mask_one_cdr_(data, cdr_to_mask)

    def __call__(self, structure):
        if structure['heavy'] is not None:
            self.mask_for_one_chain_(structure['heavy'])
        if structure['light'] is not None:
            self.mask_for_one_chain_(structure['light'])
        return structure


@register_transform('mask_antibody')
class MaskAntibody(object):

    def mask_ab_chain_(self, data):
        data['generate_flag'] = torch.ones(data['aa'].shape, dtype=torch.bool)

    def __call__(self, structure):
        pos_ab_alpha = []
        if structure['heavy'] is not None:
            self.mask_ab_chain_(structure['heavy'])
            pos_ab_alpha.append(
                structure['heavy']['pos_heavyatom'][:, constants.BBHeavyAtom.CA]
            )
        if structure['light'] is not None:
            self.mask_ab_chain_(structure['light'])
            pos_ab_alpha.append(
                structure['light']['pos_heavyatom'][:, constants.BBHeavyAtom.CA]
            )
        pos_ab_alpha = torch.cat(pos_ab_alpha, dim=0)   # (L_Ab, 3)

        if structure['antigen'] is not None:
            pos_ag_alpha = structure['antigen']['pos_heavyatom'][:, constants.BBHeavyAtom.CA]
            ag_ab_dist = torch.cdist(pos_ag_alpha, pos_ab_alpha)    # (L_Ag, L_Ab)
            nn_ab_dist = ag_ab_dist.min(dim=1)[0]   # (L_Ag)
            contact_flag = (nn_ab_dist <= 6.0)  # (L_Ag)
            if contact_flag.sum().item() == 0:
                contact_flag[nn_ab_dist.argmin()] = True

            anchor_idx = torch.multinomial(contact_flag.float(), num_samples=1).item()
            anchor_flag = torch.zeros(structure['antigen']['aa'].shape, dtype=torch.bool)
            anchor_flag[anchor_idx] = True
            structure['antigen']['anchor_flag'] = anchor_flag
            structure['antigen']['contact_flag'] = contact_flag

        return structure


@register_transform('remove_antigen')
class RemoveAntigen:

    def __call__(self, structure):
        structure['antigen'] = None
        structure['antigen_seqmap'] = None
        return structure
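An illustrative sketch of MaskSingleCDR on a toy heavy chain. The per-chain dict layout ('aa' plus a 'cdr_flag' labelled with the CDR enum) mirrors what the dataset code in this diff produces, but the tensor sizes and CDR span below are made up:

import torch

from diffab.utils.protein import constants
from diffab.utils.transforms.mask import MaskSingleCDR

heavy = {
    'aa': torch.zeros(20, dtype=torch.long),
    'cdr_flag': torch.zeros(20, dtype=torch.long),
}
heavy['cdr_flag'][8:13] = int(constants.CDR.H3)     # pretend residues 8..12 are CDR-H3

structure = {'heavy': heavy, 'light': None, 'antigen': None}
structure = MaskSingleCDR(selection='H_CDR3', augmentation=False)(structure)

print(structure['heavy']['generate_flag'].nonzero().flatten())   # the CDR residues to generate
print(structure['heavy']['anchor_flag'].nonzero().flatten())     # the two flanking anchors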
diffab/utils/transforms/merge.py
ADDED
@@ -0,0 +1,88 @@
import torch

from ..protein import constants
from ._base import register_transform


@register_transform('merge_chains')
class MergeChains(object):

    def __init__(self):
        super().__init__()

    def assign_chain_number_(self, data_list):
        chains = set()
        for data in data_list:
            chains.update(data['chain_id'])
        chains = {c: i for i, c in enumerate(chains)}

        for data in data_list:
            data['chain_nb'] = torch.LongTensor([
                chains[c] for c in data['chain_id']
            ])

    def _data_attr(self, data, name):
        if name in ('generate_flag', 'anchor_flag') and name not in data:
            return torch.zeros(data['aa'].shape, dtype=torch.bool)
        else:
            return data[name]

    def __call__(self, structure):
        data_list = []
        if structure['heavy'] is not None:
            structure['heavy']['fragment_type'] = torch.full_like(
                structure['heavy']['aa'],
                fill_value = constants.Fragment.Heavy,
            )
            data_list.append(structure['heavy'])

        if structure['light'] is not None:
            structure['light']['fragment_type'] = torch.full_like(
                structure['light']['aa'],
                fill_value = constants.Fragment.Light,
            )
            data_list.append(structure['light'])

        if structure['antigen'] is not None:
            structure['antigen']['fragment_type'] = torch.full_like(
                structure['antigen']['aa'],
                fill_value = constants.Fragment.Antigen,
            )
            structure['antigen']['cdr_flag'] = torch.zeros_like(
                structure['antigen']['aa'],
            )
            data_list.append(structure['antigen'])

        self.assign_chain_number_(data_list)

        list_props = {
            'chain_id': [],
            'icode': [],
        }
        tensor_props = {
            'chain_nb': [],
            'resseq': [],
            'res_nb': [],
            'aa': [],
            'pos_heavyatom': [],
            'mask_heavyatom': [],
            'generate_flag': [],
            'cdr_flag': [],
            'anchor_flag': [],
            'fragment_type': [],
        }

        for data in data_list:
            for k in list_props.keys():
                list_props[k].append(self._data_attr(data, k))
            for k in tensor_props.keys():
                tensor_props[k].append(self._data_attr(data, k))

        list_props = {k: sum(v, start=[]) for k, v in list_props.items()}
        tensor_props = {k: torch.cat(v, dim=0) for k, v in tensor_props.items()}
        data_out = {
            **list_props,
            **tensor_props,
        }
        return data_out
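Putting the transforms together, roughly how a dataset config would use them; the structure dict is assumed to come from the SAbDab/custom dataset parsing earlier in this diff:

from diffab.utils.transforms import get_transform

# Mask one CDR for generation, then flatten heavy chain, light chain and antigen
# into a single per-residue dict ready for batching.
pipeline = get_transform([
    {'type': 'mask_single_cdr', 'selection': 'H_CDR3'},
    {'type': 'merge_chains'},
])
# merged = pipeline(structure)
# merged['generate_flag'] marks the masked CDR; merged['fragment_type'] distinguishes
# heavy / light / antigen residues after merging.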