luost26 committed
Commit 753e275 · 1 Parent(s): 5d58d1a
This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. .gitignore +134 -0
  2. app.py +1 -3
  3. design_dock.py +67 -0
  4. design_pdb.py +4 -0
  5. design_testset.py +4 -0
  6. diffab/datasets/__init__.py +4 -0
  7. diffab/datasets/_base.py +40 -0
  8. diffab/datasets/custom.py +200 -0
  9. diffab/datasets/sabdab.py +470 -0
  10. diffab/models/__init__.py +3 -0
  11. diffab/models/_base.py +13 -0
  12. diffab/models/diffab.py +142 -0
  13. diffab/modules/common/geometry.py +481 -0
  14. diffab/modules/common/layers.py +160 -0
  15. diffab/modules/common/so3.py +146 -0
  16. diffab/modules/common/structure.py +77 -0
  17. diffab/modules/common/topology.py +24 -0
  18. diffab/modules/diffusion/dpm_full.py +319 -0
  19. diffab/modules/diffusion/transition.py +223 -0
  20. diffab/modules/encoders/ga.py +193 -0
  21. diffab/modules/encoders/pair.py +102 -0
  22. diffab/modules/encoders/residue.py +92 -0
  23. diffab/tools/dock/base.py +28 -0
  24. diffab/tools/dock/hdock.py +164 -0
  25. diffab/tools/eval/__main__.py +4 -0
  26. diffab/tools/eval/base.py +125 -0
  27. diffab/tools/eval/energy.py +43 -0
  28. diffab/tools/eval/run.py +66 -0
  29. diffab/tools/eval/similarity.py +125 -0
  30. diffab/tools/relax/__main__.py +4 -0
  31. diffab/tools/relax/base.py +117 -0
  32. diffab/tools/relax/openmm_relaxer.py +144 -0
  33. diffab/tools/relax/pyrosetta_relaxer.py +189 -0
  34. diffab/tools/relax/run.py +85 -0
  35. diffab/tools/renumber/__init__.py +1 -0
  36. diffab/tools/renumber/__main__.py +5 -0
  37. diffab/tools/renumber/run.py +85 -0
  38. diffab/tools/runner/design_for_pdb.py +291 -0
  39. diffab/tools/runner/design_for_testset.py +243 -0
  40. diffab/utils/data.py +89 -0
  41. diffab/utils/inference.py +60 -0
  42. diffab/utils/misc.py +126 -0
  43. diffab/utils/protein/constants.py +319 -0
  44. diffab/utils/protein/parsers.py +109 -0
  45. diffab/utils/protein/writers.py +75 -0
  46. diffab/utils/train.py +151 -0
  47. diffab/utils/transforms/__init__.py +7 -0
  48. diffab/utils/transforms/_base.py +56 -0
  49. diffab/utils/transforms/mask.py +213 -0
  50. diffab/utils/transforms/merge.py +88 -0
.gitignore ADDED
@@ -0,0 +1,134 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ /playgrounds*
+ /logs*
+ /results*
+ /*.csv
app.py CHANGED
@@ -24,8 +24,6 @@ from diffab.tools.renumber.run import (
      assign_number_to_sequence,
  )

- DIFFAB_DIR = os.path.realpath('./diffab-repo')
-
  CDR_OPTIONS = OrderedDict()
  CDR_OPTIONS['H_CDR1'] = 'H1'
  CDR_OPTIONS['H_CDR2'] = 'H2'
@@ -96,7 +94,7 @@ def run_design(pdb_path, config_path, output_dir, docking, display_widget, num_d
          bufsize=1,
          stdout=subprocess.PIPE,
          stderr=subprocess.STDOUT,
-         cwd=DIFFAB_DIR,
+         cwd=os.getcwd(),
      )
      for line in iter(proc.stdout.readline, b''):
          output_buffer += line.decode()
design_dock.py ADDED
@@ -0,0 +1,67 @@
+ import os
+ import shutil
+ import argparse
+ from diffab.tools.dock.hdock import HDockAntibody
+ from diffab.tools.runner.design_for_pdb import args_factory, design_for_pdb
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--antigen', type=str, required=True)
+     parser.add_argument('--antibody', type=str, default='./data/examples/3QHF_Fv.pdb')
+     parser.add_argument('--heavy', type=str, default='H', help='Chain id of the heavy chain.')
+     parser.add_argument('--light', type=str, default='L', help='Chain id of the light chain.')
+     parser.add_argument('--hdock_bin', type=str, default='./bin/hdock')
+     parser.add_argument('--createpl_bin', type=str, default='./bin/createpl')
+     parser.add_argument('-n', '--num_docks', type=int, default=10)
+     parser.add_argument('-c', '--config', type=str, default='./configs/test/codesign_single.yml')
+     parser.add_argument('-o', '--out_root', type=str, default='./results')
+     parser.add_argument('-t', '--tag', type=str, default='')
+     parser.add_argument('-s', '--seed', type=int, default=None)
+     parser.add_argument('-d', '--device', type=str, default='cuda')
+     parser.add_argument('-b', '--batch_size', type=int, default=16)
+     args = parser.parse_args()
+
+     hdock_missing = []
+     if not os.path.exists(args.hdock_bin):
+         hdock_missing.append(args.hdock_bin)
+     if not os.path.exists(args.createpl_bin):
+         hdock_missing.append(args.createpl_bin)
+     if len(hdock_missing) > 0:
+         print("[WARNING] The following HDOCK applications are missing:")
+         for f in hdock_missing:
+             print(f" > {f}")
+         print("Please download HDOCK from http://huanglab.phys.hust.edu.cn/software/hdocklite/ "
+               "and put `hdock` and `createpl` to the above path.")
+         exit()
+
+     antigen_name = os.path.basename(os.path.splitext(args.antigen)[0])
+     docked_pdb_dir = os.path.join(os.path.splitext(args.antigen)[0] + '_dock')
+     os.makedirs(docked_pdb_dir, exist_ok=True)
+     docked_pdb_paths = []
+     for fname in os.listdir(docked_pdb_dir):
+         if fname.endswith('.pdb'):
+             docked_pdb_paths.append(os.path.join(docked_pdb_dir, fname))
+     if len(docked_pdb_paths) < args.num_docks:
+         with HDockAntibody() as dock_session:
+             dock_session.set_antigen(args.antigen)
+             dock_session.set_antibody(args.antibody)
+             docked_tmp_paths = dock_session.dock()
+             for i, tmp_path in enumerate(docked_tmp_paths[:args.num_docks]):
+                 dest_path = os.path.join(docked_pdb_dir, f"{antigen_name}_Ab_{i:04d}.pdb")
+                 shutil.copyfile(tmp_path, dest_path)
+                 print(f'[INFO] Copy {tmp_path} -> {dest_path}')
+                 docked_pdb_paths.append(dest_path)
+
+     for pdb_path in docked_pdb_paths:
+         current_args = vars(args)
+         current_args['tag'] += antigen_name
+         design_args = args_factory(
+             pdb_path = pdb_path,
+             **current_args,
+         )
+         design_for_pdb(design_args)
+
+
+ if __name__ == '__main__':
+     main()
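
The docking half of this script can also be driven directly from Python. A minimal sketch, assuming the HDOCK binaries are installed as the warning above describes; the antigen path is hypothetical:

    # Sketch of the docking step design_dock.py automates (input path is hypothetical).
    from diffab.tools.dock.hdock import HDockAntibody

    with HDockAntibody() as dock_session:
        dock_session.set_antigen('./data/examples/antigen.pdb')    # hypothetical input
        dock_session.set_antibody('./data/examples/3QHF_Fv.pdb')   # default used above
        docked_tmp_paths = dock_session.dock()  # temporary PDBs, one per docked pose
        print(len(docked_tmp_paths))            # copy them out before the session closes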
design_pdb.py ADDED
@@ -0,0 +1,4 @@
+ from diffab.tools.runner.design_for_pdb import args_from_cmdline, design_for_pdb
+
+ if __name__ == '__main__':
+     design_for_pdb(args_from_cmdline())
design_testset.py ADDED
@@ -0,0 +1,4 @@
+ from diffab.tools.runner.design_for_testset import main
+
+ if __name__ == '__main__':
+     main()
diffab/datasets/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .sabdab import SAbDabDataset
+ from .custom import CustomDataset
+
+ from ._base import get_dataset
diffab/datasets/_base.py ADDED
@@ -0,0 +1,40 @@
+ from torch.utils.data import Dataset, ConcatDataset
+ from diffab.utils.transforms import get_transform
+
+
+ _DATASET_DICT = {}
+
+
+ def register_dataset(name):
+     def decorator(cls):
+         _DATASET_DICT[name] = cls
+         return cls
+     return decorator
+
+
+ def get_dataset(cfg):
+     transform = get_transform(cfg.transform) if 'transform' in cfg else None
+     return _DATASET_DICT[cfg.type](cfg, transform=transform)
+
+
+ @register_dataset('concat')
+ def get_concat_dataset(cfg):
+     datasets = [get_dataset(d) for d in cfg.datasets]
+     return ConcatDataset(datasets)
+
+
+ @register_dataset('balanced_concat')
+ class BalancedConcatDataset(Dataset):
+
+     def __init__(self, cfg, transform=None):
+         super().__init__()
+         assert transform is None, 'transform is not supported.'
+         self.datasets = [get_dataset(d) for d in cfg.datasets]
+         self.max_size = max([len(d) for d in self.datasets])
+
+     def __len__(self):
+         return self.max_size * len(self.datasets)
+
+     def __getitem__(self, idx):
+         dataset_idx = idx // self.max_size
+         return self.datasets[dataset_idx][idx % len(self.datasets[dataset_idx])]
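
The decorator-based registry above lets new dataset types plug in without touching `get_dataset`. An illustrative registration following the `(cfg, transform)` calling convention that `get_dataset` uses; `'toy'` and `ToyDataset` are hypothetical names, not part of this commit:

    from torch.utils.data import Dataset
    from diffab.datasets._base import register_dataset

    @register_dataset('toy')
    class ToyDataset(Dataset):
        def __init__(self, cfg, transform=None):
            super().__init__()
            self.items = list(range(10))  # placeholder data

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]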
diffab/datasets/custom.py ADDED
@@ -0,0 +1,200 @@
+ import os
+ import logging
+ import joblib
+ import pickle
+ import lmdb
+ from Bio import PDB
+ from Bio.PDB import PDBExceptions
+ from torch.utils.data import Dataset
+ from tqdm.auto import tqdm
+
+ from ..utils.protein import parsers
+ from .sabdab import _label_heavy_chain_cdr, _label_light_chain_cdr
+ from ._base import register_dataset
+
+
+ def preprocess_antibody_structure(task):
+     pdb_path = task['pdb_path']
+     H_id = task.get('heavy_id', 'H')
+     L_id = task.get('light_id', 'L')
+
+     parser = PDB.PDBParser(QUIET=True)
+     model = parser.get_structure(id, pdb_path)[0]
+
+     all_chain_ids = [c.id for c in model]
+
+     parsed = {
+         'id': task['id'],
+         'heavy': None,
+         'heavy_seqmap': None,
+         'light': None,
+         'light_seqmap': None,
+         'antigen': None,
+         'antigen_seqmap': None,
+     }
+     try:
+         if H_id in all_chain_ids:
+             (
+                 parsed['heavy'],
+                 parsed['heavy_seqmap']
+             ) = _label_heavy_chain_cdr(*parsers.parse_biopython_structure(
+                 model[H_id],
+                 max_resseq = 113    # Chothia, end of Heavy chain Fv
+             ))
+
+         if L_id in all_chain_ids:
+             (
+                 parsed['light'],
+                 parsed['light_seqmap']
+             ) = _label_light_chain_cdr(*parsers.parse_biopython_structure(
+                 model[L_id],
+                 max_resseq = 106    # Chothia, end of Light chain Fv
+             ))
+
+         if parsed['heavy'] is None and parsed['light'] is None:
+             raise ValueError(
+                 f'Neither valid antibody H-chain or L-chain is found. '
+                 f'Please ensure that the chain id of heavy chain is "{H_id}" '
+                 f'and the id of the light chain is "{L_id}".'
+             )
+
+
+         ag_chain_ids = [cid for cid in all_chain_ids if cid not in (H_id, L_id)]
+         if len(ag_chain_ids) > 0:
+             chains = [model[c] for c in ag_chain_ids]
+             (
+                 parsed['antigen'],
+                 parsed['antigen_seqmap']
+             ) = parsers.parse_biopython_structure(chains)
+
+     except (
+         PDBExceptions.PDBConstructionException,
+         parsers.ParsingException,
+         KeyError,
+         ValueError,
+     ) as e:
+         logging.warning('[{}] {}: {}'.format(
+             task['id'],
+             e.__class__.__name__,
+             str(e)
+         ))
+         return None
+
+     return parsed
+
+
+ @register_dataset('custom')
+ class CustomDataset(Dataset):
+
+     MAP_SIZE = 32*(1024*1024*1024)  # 32GB
+
+     def __init__(self, structure_dir, transform=None, reset=False):
+         super().__init__()
+         self.structure_dir = structure_dir
+         self.transform = transform
+
+         self.db_conn = None
+         self.db_ids = None
+         self._load_structures(reset)
+
+     @property
+     def _cache_db_path(self):
+         return os.path.join(self.structure_dir, 'structure_cache.lmdb')
+
+     def _connect_db(self):
+         self._close_db()
+         self.db_conn = lmdb.open(
+             self._cache_db_path,
+             map_size=self.MAP_SIZE,
+             create=False,
+             subdir=False,
+             readonly=True,
+             lock=False,
+             readahead=False,
+             meminit=False,
+         )
+         with self.db_conn.begin() as txn:
+             keys = [k.decode() for k in txn.cursor().iternext(values=False)]
+             self.db_ids = keys
+
+     def _close_db(self):
+         if self.db_conn is not None:
+             self.db_conn.close()
+         self.db_conn = None
+         self.db_ids = None
+
+     def _load_structures(self, reset):
+         all_pdbs = []
+         for fname in os.listdir(self.structure_dir):
+             if not fname.endswith('.pdb'): continue
+             all_pdbs.append(fname)
+
+         if reset or not os.path.exists(self._cache_db_path):
+             todo_pdbs = all_pdbs
+         else:
+             self._connect_db()
+             processed_pdbs = self.db_ids
+             self._close_db()
+             todo_pdbs = list(set(all_pdbs) - set(processed_pdbs))
+
+         if len(todo_pdbs) > 0:
+             self._preprocess_structures(todo_pdbs)
+
+     def _preprocess_structures(self, pdb_list):
+         tasks = []
+         for pdb_fname in pdb_list:
+             pdb_path = os.path.join(self.structure_dir, pdb_fname)
+             tasks.append({
+                 'id': pdb_fname,
+                 'pdb_path': pdb_path,
+             })
+
+         data_list = joblib.Parallel(
+             n_jobs = max(joblib.cpu_count() // 2, 1),
+         )(
+             joblib.delayed(preprocess_antibody_structure)(task)
+             for task in tqdm(tasks, dynamic_ncols=True, desc='Preprocess')
+         )
+
+         db_conn = lmdb.open(
+             self._cache_db_path,
+             map_size = self.MAP_SIZE,
+             create=True,
+             subdir=False,
+             readonly=False,
+         )
+         ids = []
+         with db_conn.begin(write=True, buffers=True) as txn:
+             for data in tqdm(data_list, dynamic_ncols=True, desc='Write to LMDB'):
+                 if data is None:
+                     continue
+                 ids.append(data['id'])
+                 txn.put(data['id'].encode('utf-8'), pickle.dumps(data))
+
+     def __len__(self):
+         return len(self.db_ids)
+
+     def __getitem__(self, index):
+         self._connect_db()
+         id = self.db_ids[index]
+         with self.db_conn.begin() as txn:
+             data = pickle.loads(txn.get(id.encode()))
+         if self.transform is not None:
+             data = self.transform(data)
+         return data
+
+
+ if __name__ == '__main__':
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--dir', type=str, default='./data/custom')
+     parser.add_argument('--reset', action='store_true', default=False)
+     args = parser.parse_args()
+
+     dataset = CustomDataset(
+         structure_dir = args.dir,
+         reset = args.reset,
+     )
+     print(dataset[0])
+     print(len(dataset))
+
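
Note the `_connect_db()` call inside `__getitem__` above: each access (re)opens the LMDB environment in the process doing the read, which appears intended to keep the dataset safe with multi-process `DataLoader`s, since LMDB handles should not be shared across forked workers. A minimal consumption sketch (the directory path is hypothetical; an identity `collate_fn` sidesteps collating variable-length structures):

    from torch.utils.data import DataLoader
    from diffab.datasets.custom import CustomDataset

    dataset = CustomDataset(structure_dir='./data/custom')  # hypothetical directory of PDBs
    loader = DataLoader(dataset, batch_size=1, num_workers=2,
                        collate_fn=lambda items: items)     # keep raw dicts
    for items in loader:
        print(items[0]['id'])
        break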
diffab/datasets/sabdab.py ADDED
@@ -0,0 +1,470 @@
+ import os
+ import random
+ import logging
+ import datetime
+ import pandas as pd
+ import joblib
+ import pickle
+ import lmdb
+ import subprocess
+ import torch
+ from Bio import PDB, SeqRecord, SeqIO, Seq
+ from Bio.PDB import PDBExceptions
+ from Bio.PDB import Polypeptide
+ from torch.utils.data import Dataset
+ from tqdm.auto import tqdm
+
+ from ..utils.protein import parsers, constants
+ from ._base import register_dataset
+
+
+ ALLOWED_AG_TYPES = {
+     'protein',
+     'protein | protein',
+     'protein | protein | protein',
+     'protein | protein | protein | protein | protein',
+     'protein | protein | protein | protein',
+ }
+
+ RESOLUTION_THRESHOLD = 4.0
+
+ TEST_ANTIGENS = [
+     'sars-cov-2 receptor binding domain',
+     'hiv-1 envelope glycoprotein gp160',
+     'mers s',
+     'influenza a virus',
+     'cd27 antigen',
+ ]
+
+
+ def nan_to_empty_string(val):
+     if val != val or not val:
+         return ''
+     else:
+         return val
+
+
+ def nan_to_none(val):
+     if val != val or not val:
+         return None
+     else:
+         return val
+
+
+ def split_sabdab_delimited_str(val):
+     if not val:
+         return []
+     else:
+         return [s.strip() for s in val.split('|')]
+
+
+ def parse_sabdab_resolution(val):
+     if val == 'NOT' or not val or val != val:
+         return None
+     elif isinstance(val, str) and ',' in val:
+         return float(val.split(',')[0].strip())
+     else:
+         return float(val)
+
+
+ def _aa_tensor_to_sequence(aa):
+     return ''.join([Polypeptide.index_to_one(a.item()) for a in aa.flatten()])
+
+
+ def _label_heavy_chain_cdr(data, seq_map, max_cdr3_length=30):
+     if data is None or seq_map is None:
+         return data, seq_map
+
+     # Add CDR labels
+     cdr_flag = torch.zeros_like(data['aa'])
+     for position, idx in seq_map.items():
+         resseq = position[1]
+         cdr_type = constants.ChothiaCDRRange.to_cdr('H', resseq)
+         if cdr_type is not None:
+             cdr_flag[idx] = cdr_type
+     data['cdr_flag'] = cdr_flag
+
+     # Add CDR sequence annotations
+     data['H1_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.H1] )
+     data['H2_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.H2] )
+     data['H3_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.H3] )
+
+     cdr3_length = (cdr_flag == constants.CDR.H3).sum().item()
+     # Remove too long CDR3
+     if cdr3_length > max_cdr3_length:
+         cdr_flag[cdr_flag == constants.CDR.H3] = 0
+         logging.warning(f'CDR-H3 too long {cdr3_length}. Removed.')
+         return None, None
+
+     # Filter: ensure CDR3 exists
+     if cdr3_length == 0:
+         logging.warning('No CDR-H3 found in the heavy chain.')
+         return None, None
+
+     return data, seq_map
+
+
+ def _label_light_chain_cdr(data, seq_map, max_cdr3_length=30):
+     if data is None or seq_map is None:
+         return data, seq_map
+     cdr_flag = torch.zeros_like(data['aa'])
+     for position, idx in seq_map.items():
+         resseq = position[1]
+         cdr_type = constants.ChothiaCDRRange.to_cdr('L', resseq)
+         if cdr_type is not None:
+             cdr_flag[idx] = cdr_type
+     data['cdr_flag'] = cdr_flag
+
+     data['L1_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.L1] )
+     data['L2_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.L2] )
+     data['L3_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.L3] )
+
+     cdr3_length = (cdr_flag == constants.CDR.L3).sum().item()
+     # Remove too long CDR3
+     if cdr3_length > max_cdr3_length:
+         cdr_flag[cdr_flag == constants.CDR.L3] = 0
+         logging.warning(f'CDR-L3 too long {cdr3_length}. Removed.')
+         return None, None
+
+     # Ensure CDR3 exists
+     if cdr3_length == 0:
+         logging.warning('No CDRs found in the light chain.')
+         return None, None
+
+     return data, seq_map
+
+
+ def preprocess_sabdab_structure(task):
+     entry = task['entry']
+     pdb_path = task['pdb_path']
+
+     parser = PDB.PDBParser(QUIET=True)
+     model = parser.get_structure(id, pdb_path)[0]
+
+     parsed = {
+         'id': entry['id'],
+         'heavy': None,
+         'heavy_seqmap': None,
+         'light': None,
+         'light_seqmap': None,
+         'antigen': None,
+         'antigen_seqmap': None,
+     }
+     try:
+         if entry['H_chain'] is not None:
+             (
+                 parsed['heavy'],
+                 parsed['heavy_seqmap']
+             ) = _label_heavy_chain_cdr(*parsers.parse_biopython_structure(
+                 model[entry['H_chain']],
+                 max_resseq = 113    # Chothia, end of Heavy chain Fv
+             ))
+
+         if entry['L_chain'] is not None:
+             (
+                 parsed['light'],
+                 parsed['light_seqmap']
+             ) = _label_light_chain_cdr(*parsers.parse_biopython_structure(
+                 model[entry['L_chain']],
+                 max_resseq = 106    # Chothia, end of Light chain Fv
+             ))
+
+         if parsed['heavy'] is None and parsed['light'] is None:
+             raise ValueError('Neither valid H-chain or L-chain is found.')
+
+         if len(entry['ag_chains']) > 0:
+             chains = [model[c] for c in entry['ag_chains']]
+             (
+                 parsed['antigen'],
+                 parsed['antigen_seqmap']
+             ) = parsers.parse_biopython_structure(chains)
+
+     except (
+         PDBExceptions.PDBConstructionException,
+         parsers.ParsingException,
+         KeyError,
+         ValueError,
+     ) as e:
+         logging.warning('[{}] {}: {}'.format(
+             task['id'],
+             e.__class__.__name__,
+             str(e)
+         ))
+         return None
+
+     return parsed
+
+
+ class SAbDabDataset(Dataset):
+
+     MAP_SIZE = 32*(1024*1024*1024)  # 32GB
+
+     def __init__(
+         self,
+         summary_path = './data/sabdab_summary_all.tsv',
+         chothia_dir = './data/all_structures/chothia',
+         processed_dir = './data/processed',
+         split = 'train',
+         split_seed = 2022,
+         transform = None,
+         reset = False,
+     ):
+         super().__init__()
+         self.summary_path = summary_path
+         self.chothia_dir = chothia_dir
+         if not os.path.exists(chothia_dir):
+             raise FileNotFoundError(
+                 f"SAbDab structures not found in {chothia_dir}. "
+                 "Please download them from http://opig.stats.ox.ac.uk/webapps/newsabdab/sabdab/archive/all/"
+             )
+         self.processed_dir = processed_dir
+         os.makedirs(processed_dir, exist_ok=True)
+
+         self.sabdab_entries = None
+         self._load_sabdab_entries()
+
+         self.db_conn = None
+         self.db_ids = None
+         self._load_structures(reset)
+
+         self.clusters = None
+         self.id_to_cluster = None
+         self._load_clusters(reset)
+
+         self.ids_in_split = None
+         self._load_split(split, split_seed)
+
+         self.transform = transform
+
+     def _load_sabdab_entries(self):
+         df = pd.read_csv(self.summary_path, sep='\t')
+         entries_all = []
+         for i, row in tqdm(
+             df.iterrows(),
+             dynamic_ncols=True,
+             desc='Loading entries',
+             total=len(df),
+         ):
+             entry_id = "{pdbcode}_{H}_{L}_{Ag}".format(
+                 pdbcode = row['pdb'],
+                 H = nan_to_empty_string(row['Hchain']),
+                 L = nan_to_empty_string(row['Lchain']),
+                 Ag = ''.join(split_sabdab_delimited_str(
+                     nan_to_empty_string(row['antigen_chain'])
+                 ))
+             )
+             ag_chains = split_sabdab_delimited_str(
+                 nan_to_empty_string(row['antigen_chain'])
+             )
+             resolution = parse_sabdab_resolution(row['resolution'])
+             entry = {
+                 'id': entry_id,
+                 'pdbcode': row['pdb'],
+                 'H_chain': nan_to_none(row['Hchain']),
+                 'L_chain': nan_to_none(row['Lchain']),
+                 'ag_chains': ag_chains,
+                 'ag_type': nan_to_none(row['antigen_type']),
+                 'ag_name': nan_to_none(row['antigen_name']),
+                 'date': datetime.datetime.strptime(row['date'], '%m/%d/%y'),
+                 'resolution': resolution,
+                 'method': row['method'],
+                 'scfv': row['scfv'],
+             }
+
+             # Filtering
+             if (
+                 (entry['ag_type'] in ALLOWED_AG_TYPES or entry['ag_type'] is None)
+                 and (entry['resolution'] is not None and entry['resolution'] <= RESOLUTION_THRESHOLD)
+             ):
+                 entries_all.append(entry)
+         self.sabdab_entries = entries_all
+
+     def _load_structures(self, reset):
+         if not os.path.exists(self._structure_cache_path) or reset:
+             if os.path.exists(self._structure_cache_path):
+                 os.unlink(self._structure_cache_path)
+             self._preprocess_structures()
+
+         with open(self._structure_cache_path + '-ids', 'rb') as f:
+             self.db_ids = pickle.load(f)
+         self.sabdab_entries = list(
+             filter(
+                 lambda e: e['id'] in self.db_ids,
+                 self.sabdab_entries
+             )
+         )
+
+     @property
+     def _structure_cache_path(self):
+         return os.path.join(self.processed_dir, 'structures.lmdb')
+
+     def _preprocess_structures(self):
+         tasks = []
+         for entry in self.sabdab_entries:
+             pdb_path = os.path.join(self.chothia_dir, '{}.pdb'.format(entry['pdbcode']))
+             if not os.path.exists(pdb_path):
+                 logging.warning(f"PDB not found: {pdb_path}")
+                 continue
+             tasks.append({
+                 'id': entry['id'],
+                 'entry': entry,
+                 'pdb_path': pdb_path,
+             })
+
+         data_list = joblib.Parallel(
+             n_jobs = max(joblib.cpu_count() // 2, 1),
+         )(
+             joblib.delayed(preprocess_sabdab_structure)(task)
+             for task in tqdm(tasks, dynamic_ncols=True, desc='Preprocess')
+         )
+
+         db_conn = lmdb.open(
+             self._structure_cache_path,
+             map_size = self.MAP_SIZE,
+             create=True,
+             subdir=False,
+             readonly=False,
+         )
+         ids = []
+         with db_conn.begin(write=True, buffers=True) as txn:
+             for data in tqdm(data_list, dynamic_ncols=True, desc='Write to LMDB'):
+                 if data is None:
+                     continue
+                 ids.append(data['id'])
+                 txn.put(data['id'].encode('utf-8'), pickle.dumps(data))
+
+         with open(self._structure_cache_path + '-ids', 'wb') as f:
+             pickle.dump(ids, f)
+
+     @property
+     def _cluster_path(self):
+         return os.path.join(self.processed_dir, 'cluster_result_cluster.tsv')
+
+     def _load_clusters(self, reset):
+         if not os.path.exists(self._cluster_path) or reset:
+             self._create_clusters()
+
+         clusters, id_to_cluster = {}, {}
+         with open(self._cluster_path, 'r') as f:
+             for line in f.readlines():
+                 cluster_name, data_id = line.split()
+                 if cluster_name not in clusters:
+                     clusters[cluster_name] = []
+                 clusters[cluster_name].append(data_id)
+                 id_to_cluster[data_id] = cluster_name
+         self.clusters = clusters
+         self.id_to_cluster = id_to_cluster
+
+     def _create_clusters(self):
+         cdr_records = []
+         for id in self.db_ids:
+             structure = self.get_structure(id)
+             if structure['heavy'] is not None:
+                 cdr_records.append(SeqRecord.SeqRecord(
+                     Seq.Seq(structure['heavy']['H3_seq']),
+                     id = structure['id'],
+                     name = '',
+                     description = '',
+                 ))
+             elif structure['light'] is not None:
+                 cdr_records.append(SeqRecord.SeqRecord(
+                     Seq.Seq(structure['light']['L3_seq']),
+                     id = structure['id'],
+                     name = '',
+                     description = '',
+                 ))
+         fasta_path = os.path.join(self.processed_dir, 'cdr_sequences.fasta')
+         SeqIO.write(cdr_records, fasta_path, 'fasta')
+
+         cmd = ' '.join([
+             'mmseqs', 'easy-cluster',
+             os.path.realpath(fasta_path),
+             'cluster_result', 'cluster_tmp',
+             '--min-seq-id', '0.5',
+             '-c', '0.8',
+             '--cov-mode', '1',
+         ])
+         subprocess.run(cmd, cwd=self.processed_dir, shell=True, check=True)
+
+     def _load_split(self, split, split_seed):
+         assert split in ('train', 'val', 'test')
+         ids_test = [
+             entry['id']
+             for entry in self.sabdab_entries
+             if entry['ag_name'] in TEST_ANTIGENS
+         ]
+         test_relevant_clusters = set([self.id_to_cluster[id] for id in ids_test])
+
+         ids_train_val = [
+             entry['id']
+             for entry in self.sabdab_entries
+             if self.id_to_cluster[entry['id']] not in test_relevant_clusters
+         ]
+         random.Random(split_seed).shuffle(ids_train_val)
+         if split == 'test':
+             self.ids_in_split = ids_test
+         elif split == 'val':
+             self.ids_in_split = ids_train_val[:20]
+         else:
+             self.ids_in_split = ids_train_val[20:]
+
+     def _connect_db(self):
+         if self.db_conn is not None:
+             return
+         self.db_conn = lmdb.open(
+             self._structure_cache_path,
+             map_size=self.MAP_SIZE,
+             create=False,
+             subdir=False,
+             readonly=True,
+             lock=False,
+             readahead=False,
+             meminit=False,
+         )
+
+     def get_structure(self, id):
+         self._connect_db()
+         with self.db_conn.begin() as txn:
+             return pickle.loads(txn.get(id.encode()))
+
+     def __len__(self):
+         return len(self.ids_in_split)
+
+     def __getitem__(self, index):
+         id = self.ids_in_split[index]
+         data = self.get_structure(id)
+         if self.transform is not None:
+             data = self.transform(data)
+         return data
+
+
+ @register_dataset('sabdab')
+ def get_sabdab_dataset(cfg, transform):
+     return SAbDabDataset(
+         summary_path = cfg.summary_path,
+         chothia_dir = cfg.chothia_dir,
+         processed_dir = cfg.processed_dir,
+         split = cfg.split,
+         split_seed = cfg.get('split_seed', 2022),
+         transform = transform,
+     )
+
+
+ if __name__ == '__main__':
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--split', type=str, default='train')
+     parser.add_argument('--processed_dir', type=str, default='./data/processed')
+     parser.add_argument('--reset', action='store_true', default=False)
+     args = parser.parse_args()
+     if args.reset:
+         sure = input('Sure to reset? (y/n): ')
+         if sure != 'y':
+             exit()
+     dataset = SAbDabDataset(
+         processed_dir=args.processed_dir,
+         split=args.split,
+         reset=args.reset
+     )
+     print(dataset[0])
+     print(len(dataset), len(dataset.clusters))
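
Through the registry in `_base.py`, this dataset is normally constructed from a config rather than instantiated directly. A sketch assuming an attribute-style config object such as `easydict.EasyDict` (the actual config loader lives outside this diff):

    from easydict import EasyDict
    from diffab.datasets import get_dataset

    cfg = EasyDict({
        'type': 'sabdab',
        'summary_path': './data/sabdab_summary_all.tsv',
        'chothia_dir': './data/all_structures/chothia',
        'processed_dir': './data/processed',
        'split': 'val',
    })
    val_set = get_dataset(cfg)  # preprocesses, clusters CDR3s with mmseqs, then splits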
diffab/models/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .diffab import DiffusionAntibodyDesign
+
+ from ._base import get_model
diffab/models/_base.py ADDED
@@ -0,0 +1,13 @@
+
+ _MODEL_DICT = {}
+
+
+ def register_model(name):
+     def decorator(cls):
+         _MODEL_DICT[name] = cls
+         return cls
+     return decorator
+
+
+ def get_model(cfg):
+     return _MODEL_DICT[cfg.type](cfg)
diffab/models/diffab.py ADDED
@@ -0,0 +1,142 @@
+ import torch
+ import torch.nn as nn
+
+ from diffab.modules.common.geometry import construct_3d_basis
+ from diffab.modules.common.so3 import rotation_to_so3vec
+ from diffab.modules.encoders.residue import ResidueEmbedding
+ from diffab.modules.encoders.pair import PairEmbedding
+ from diffab.modules.diffusion.dpm_full import FullDPM
+ from diffab.utils.protein.constants import max_num_heavyatoms, BBHeavyAtom
+ from ._base import register_model
+
+
+ resolution_to_num_atoms = {
+     'backbone+CB': 5,
+     'full': max_num_heavyatoms
+ }
+
+
+ @register_model('diffab')
+ class DiffusionAntibodyDesign(nn.Module):
+
+     def __init__(self, cfg):
+         super().__init__()
+         self.cfg = cfg
+
+         num_atoms = resolution_to_num_atoms[cfg.get('resolution', 'full')]
+         self.residue_embed = ResidueEmbedding(cfg.res_feat_dim, num_atoms)
+         self.pair_embed = PairEmbedding(cfg.pair_feat_dim, num_atoms)
+
+         self.diffusion = FullDPM(
+             cfg.res_feat_dim,
+             cfg.pair_feat_dim,
+             **cfg.diffusion,
+         )
+
+     def encode(self, batch, remove_structure, remove_sequence):
+         """
+         Returns:
+             res_feat:   (N, L, res_feat_dim)
+             pair_feat:  (N, L, L, pair_feat_dim)
+         """
+         # This is used throughout embedding and encoding layers
+         # to avoid data leakage.
+         context_mask = torch.logical_and(
+             batch['mask_heavyatom'][:, :, BBHeavyAtom.CA],
+             ~batch['generate_flag']     # Context means ``not generated''
+         )
+
+         structure_mask = context_mask if remove_structure else None
+         sequence_mask = context_mask if remove_sequence else None
+
+         res_feat = self.residue_embed(
+             aa = batch['aa'],
+             res_nb = batch['res_nb'],
+             chain_nb = batch['chain_nb'],
+             pos_atoms = batch['pos_heavyatom'],
+             mask_atoms = batch['mask_heavyatom'],
+             fragment_type = batch['fragment_type'],
+             structure_mask = structure_mask,
+             sequence_mask = sequence_mask,
+         )
+
+         pair_feat = self.pair_embed(
+             aa = batch['aa'],
+             res_nb = batch['res_nb'],
+             chain_nb = batch['chain_nb'],
+             pos_atoms = batch['pos_heavyatom'],
+             mask_atoms = batch['mask_heavyatom'],
+             structure_mask = structure_mask,
+             sequence_mask = sequence_mask,
+         )
+
+         R = construct_3d_basis(
+             batch['pos_heavyatom'][:, :, BBHeavyAtom.CA],
+             batch['pos_heavyatom'][:, :, BBHeavyAtom.C],
+             batch['pos_heavyatom'][:, :, BBHeavyAtom.N],
+         )
+         p = batch['pos_heavyatom'][:, :, BBHeavyAtom.CA]
+
+         return res_feat, pair_feat, R, p
+
+     def forward(self, batch):
+         mask_generate = batch['generate_flag']
+         mask_res = batch['mask']
+         res_feat, pair_feat, R_0, p_0 = self.encode(
+             batch,
+             remove_structure = self.cfg.get('train_structure', True),
+             remove_sequence = self.cfg.get('train_sequence', True)
+         )
+         v_0 = rotation_to_so3vec(R_0)
+         s_0 = batch['aa']
+
+         loss_dict = self.diffusion(
+             v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res,
+             denoise_structure = self.cfg.get('train_structure', True),
+             denoise_sequence = self.cfg.get('train_sequence', True),
+         )
+         return loss_dict
+
+     @torch.no_grad()
+     def sample(
+         self,
+         batch,
+         sample_opt={
+             'sample_structure': True,
+             'sample_sequence': True,
+         }
+     ):
+         mask_generate = batch['generate_flag']
+         mask_res = batch['mask']
+         res_feat, pair_feat, R_0, p_0 = self.encode(
+             batch,
+             remove_structure = sample_opt.get('sample_structure', True),
+             remove_sequence = sample_opt.get('sample_sequence', True)
+         )
+         v_0 = rotation_to_so3vec(R_0)
+         s_0 = batch['aa']
+         traj = self.diffusion.sample(v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res, **sample_opt)
+         return traj
+
+     @torch.no_grad()
+     def optimize(
+         self,
+         batch,
+         opt_step,
+         optimize_opt={
+             'sample_structure': True,
+             'sample_sequence': True,
+         }
+     ):
+         mask_generate = batch['generate_flag']
+         mask_res = batch['mask']
+         res_feat, pair_feat, R_0, p_0 = self.encode(
+             batch,
+             remove_structure = optimize_opt.get('sample_structure', True),
+             remove_sequence = optimize_opt.get('sample_sequence', True)
+         )
+         v_0 = rotation_to_so3vec(R_0)
+         s_0 = batch['aa']
+
+         traj = self.diffusion.optimize(v_0, p_0, s_0, opt_step, res_feat, pair_feat, mask_generate, mask_res, **optimize_opt)
+         return traj
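
A tiny runnable illustration of the context mask computed at the top of `encode()`: a residue contributes context features only if its CA atom is present and it is not flagged for generation, which is what keeps ground-truth CDR information from leaking into the conditioning.

    import torch

    mask_CA = torch.tensor([[True, True, True, False]])          # CA atom present
    generate_flag = torch.tensor([[False, True, False, False]])  # residue to redesign
    context_mask = torch.logical_and(mask_CA, ~generate_flag)
    print(context_mask)  # tensor([[ True, False,  True, False]])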
diffab/modules/common/geometry.py ADDED
@@ -0,0 +1,481 @@
+ import torch
+ import torch.nn.functional as F
+
+ from diffab.utils.protein.constants import (
+     BBHeavyAtom,
+     backbone_atom_coordinates_tensor,
+     bb_oxygen_coordinate_tensor,
+ )
+ from .topology import get_terminus_flag
+
+
+ def safe_norm(x, dim=-1, keepdim=False, eps=1e-8, sqrt=True):
+     out = torch.clamp(torch.sum(torch.square(x), dim=dim, keepdim=keepdim), min=eps)
+     return torch.sqrt(out) if sqrt else out
+
+
+ def pairwise_distances(x, y=None, return_v=False):
+     """
+     Args:
+         x:  (B, N, d)
+         y:  (B, M, d)
+     """
+     if y is None: y = x
+     v = x.unsqueeze(2) - y.unsqueeze(1)  # (B, N, M, d)
+     d = safe_norm(v, dim=-1)
+     if return_v:
+         return d, v
+     else:
+         return d
+
+
+ def normalize_vector(v, dim, eps=1e-6):
+     return v / (torch.linalg.norm(v, ord=2, dim=dim, keepdim=True) + eps)
+
+
+ def project_v2v(v, e, dim):
+     """
+     Description:
+         Project vector `v` onto vector `e`.
+     Args:
+         v:  (N, L, 3).
+         e:  (N, L, 3).
+     """
+     return (e * v).sum(dim=dim, keepdim=True) * e
+
+
+ def construct_3d_basis(center, p1, p2):
+     """
+     Args:
+         center: (N, L, 3), usually the position of C_alpha.
+         p1:     (N, L, 3), usually the position of C.
+         p2:     (N, L, 3), usually the position of N.
+     Returns
+         A batch of orthogonal basis matrix, (N, L, 3, 3cols_index).
+         The matrix is composed of 3 column vectors: [e1, e2, e3].
+     """
+     v1 = p1 - center    # (N, L, 3)
+     e1 = normalize_vector(v1, dim=-1)
+
+     v2 = p2 - center    # (N, L, 3)
+     u2 = v2 - project_v2v(v2, e1, dim=-1)
+     e2 = normalize_vector(u2, dim=-1)
+
+     e3 = torch.cross(e1, e2, dim=-1)    # (N, L, 3)
+
+     mat = torch.cat([
+         e1.unsqueeze(-1), e2.unsqueeze(-1), e3.unsqueeze(-1)
+     ], dim=-1)  # (N, L, 3, 3_index)
+     return mat
+
+
+ def local_to_global(R, t, p):
+     """
+     Description:
+         Convert local (internal) coordinates to global (external) coordinates q.
+         q <- Rp + t
+     Args:
+         R:  (N, L, 3, 3).
+         t:  (N, L, 3).
+         p:  Local coordinates, (N, L, ..., 3).
+     Returns:
+         q:  Global coordinates, (N, L, ..., 3).
+     """
+     assert p.size(-1) == 3
+     p_size = p.size()
+     N, L = p_size[0], p_size[1]
+
+     p = p.view(N, L, -1, 3).transpose(-1, -2)   # (N, L, *, 3) -> (N, L, 3, *)
+     q = torch.matmul(R, p) + t.unsqueeze(-1)    # (N, L, 3, *)
+     q = q.transpose(-1, -2).reshape(p_size)     # (N, L, 3, *) -> (N, L, *, 3) -> (N, L, ..., 3)
+     return q
+
+
+ def global_to_local(R, t, q):
+     """
+     Description:
+         Convert global (external) coordinates q to local (internal) coordinates p.
+         p <- R^{T}(q - t)
+     Args:
+         R:  (N, L, 3, 3).
+         t:  (N, L, 3).
+         q:  Global coordinates, (N, L, ..., 3).
+     Returns:
+         p:  Local coordinates, (N, L, ..., 3).
+     """
+     assert q.size(-1) == 3
+     q_size = q.size()
+     N, L = q_size[0], q_size[1]
+
+     q = q.reshape(N, L, -1, 3).transpose(-1, -2)    # (N, L, *, 3) -> (N, L, 3, *)
+     p = torch.matmul(R.transpose(-1, -2), (q - t.unsqueeze(-1)))    # (N, L, 3, *)
+     p = p.transpose(-1, -2).reshape(q_size)     # (N, L, 3, *) -> (N, L, *, 3) -> (N, L, ..., 3)
+     return p
+
+
+ def apply_rotation_to_vector(R, p):
+     return local_to_global(R, torch.zeros_like(p), p)
+
+
+ def compose_rotation_and_translation(R1, t1, R2, t2):
+     """
+     Args:
+         R1,t1:  Frame basis and coordinate, (N, L, 3, 3), (N, L, 3).
+         R2,t2:  Rotation and translation to be applied to (R1, t1), (N, L, 3, 3), (N, L, 3).
+     Returns
+         R_new <- R1R2
+         t_new <- R1t2 + t1
+     """
+     R_new = torch.matmul(R1, R2)    # (N, L, 3, 3)
+     t_new = torch.matmul(R1, t2.unsqueeze(-1)).squeeze(-1) + t1
+     return R_new, t_new
+
+
+ def compose_chain(Ts):
+     while len(Ts) >= 2:
+         R1, t1 = Ts[-2]
+         R2, t2 = Ts[-1]
+         T_next = compose_rotation_and_translation(R1, t1, R2, t2)
+         Ts = Ts[:-2] + [T_next]
+     return Ts[0]
+
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ def quaternion_to_rotation_matrix(quaternions):
+     """
+     Convert rotations given as quaternions to rotation matrices.
+     Args:
+         quaternions: quaternions with real part first,
+             as tensor of shape (..., 4).
+     Returns:
+         Rotation matrices as tensor of shape (..., 3, 3).
+     """
+     quaternions = F.normalize(quaternions, dim=-1)
+     r, i, j, k = torch.unbind(quaternions, -1)
+     two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+     o = torch.stack(
+         (
+             1 - two_s * (j * j + k * k),
+             two_s * (i * j - k * r),
+             two_s * (i * k + j * r),
+             two_s * (i * j + k * r),
+             1 - two_s * (i * i + k * k),
+             two_s * (j * k - i * r),
+             two_s * (i * k - j * r),
+             two_s * (j * k + i * r),
+             1 - two_s * (i * i + j * j),
+         ),
+         -1,
+     )
+     return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+ """
+ BSD License
+
+ For PyTorch3D software
+
+ Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without modification,
+ are permitted provided that the following conditions are met:
+
+  * Redistributions of source code must retain the above copyright notice, this
+    list of conditions and the following disclaimer.
+
+  * Redistributions in binary form must reproduce the above copyright notice,
+    this list of conditions and the following disclaimer in the documentation
+    and/or other materials provided with the distribution.
+
+  * Neither the name Meta nor the names of its contributors may be used to
+    endorse or promote products derived from this software without specific
+    prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ """
+ def quaternion_1ijk_to_rotation_matrix(q):
+     """
+     (1 + ai + bj + ck) -> R
+     Args:
+         q:  (..., 3)
+     """
+     b, c, d = torch.unbind(q, dim=-1)
+     s = torch.sqrt(1 + b**2 + c**2 + d**2)
+     a, b, c, d = 1/s, b/s, c/s, d/s
+
+     o = torch.stack(
+         (
+             a**2 + b**2 - c**2 - d**2,  2*b*c - 2*a*d,  2*b*d + 2*a*c,
+             2*b*c + 2*a*d,  a**2 - b**2 + c**2 - d**2,  2*c*d - 2*a*b,
+             2*b*d - 2*a*c,  2*c*d + 2*a*b,  a**2 - b**2 - c**2 + d**2,
+         ),
+         -1,
+     )
+     return o.reshape(q.shape[:-1] + (3, 3))
+
+
+ def repr_6d_to_rotation_matrix(x):
+     """
+     Args:
+         x:  6D representations, (..., 6).
+     Returns:
+         Rotation matrices, (..., 3, 3_index).
+     """
+     a1, a2 = x[..., 0:3], x[..., 3:6]
+     b1 = normalize_vector(a1, dim=-1)
+     b2 = normalize_vector(a2 - project_v2v(a2, b1, dim=-1), dim=-1)
+     b3 = torch.cross(b1, b2, dim=-1)
+
+     mat = torch.cat([
+         b1.unsqueeze(-1), b2.unsqueeze(-1), b3.unsqueeze(-1)
+     ], dim=-1)  # (N, L, 3, 3_index)
+     return mat
+
+
+ def dihedral_from_four_points(p0, p1, p2, p3):
+     """
+     Args:
+         p0-3:   (*, 3).
+     Returns:
+         Dihedral angles in radian, (*, ).
+     """
+     v0 = p2 - p1
+     v1 = p0 - p1
+     v2 = p3 - p2
+     u1 = torch.cross(v0, v1, dim=-1)
+     n1 = u1 / torch.linalg.norm(u1, dim=-1, keepdim=True)
+     u2 = torch.cross(v0, v2, dim=-1)
+     n2 = u2 / torch.linalg.norm(u2, dim=-1, keepdim=True)
+     sgn = torch.sign( (torch.cross(v1, v2, dim=-1) * v0).sum(-1) )
+     dihed = sgn*torch.acos( (n1 * n2).sum(-1).clamp(min=-0.999999, max=0.999999) )
+     dihed = torch.nan_to_num(dihed)
+     return dihed
+
+
+ def knn_gather(idx, value):
+     """
+     Args:
+         idx:    (B, N, K)
+         value:  (B, M, d)
+     Returns:
+         (B, N, K, d)
+     """
+     N, d = idx.size(1), value.size(-1)
+     idx = idx.unsqueeze(-1).repeat(1, 1, 1, d)      # (B, N, K, d)
+     value = value.unsqueeze(1).repeat(1, N, 1, 1)   # (B, N, M, d)
+     return torch.gather(value, dim=2, index=idx)
+
+
+ def knn_points(q, p, K):
+     """
+     Args:
+         q:  (B, M, d)
+         p:  (B, N, d)
+     Returns:
+         (B, M, K), (B, M, K), (B, M, K, d)
+     """
+     _, L, _ = p.size()
+     d = pairwise_distances(q, p)    # (B, N, M)
+     dist, idx = d.topk(min(L, K), dim=-1, largest=False)   # (B, M, K), (B, M, K)
+     return dist, idx, knn_gather(idx, p)
+
+
+ def angstrom_to_nm(x):
+     return x / 10
+
+
+ def nm_to_angstrom(x):
+     return x * 10
+
+
+ def get_backbone_dihedral_angles(pos_atoms, chain_nb, res_nb, mask):
+     """
+     Args:
+         pos_atoms:  (N, L, A, 3).
+         chain_nb:   (N, L).
+         res_nb:     (N, L).
+         mask:       (N, L).
+     Returns:
+         bb_dihedral:    Omega, Phi, and Psi angles in radian, (N, L, 3).
+         mask_bb_dihed:  Masks of dihedral angles, (N, L, 3).
+     """
+     pos_N  = pos_atoms[:, :, BBHeavyAtom.N]   # (N, L, 3)
+     pos_CA = pos_atoms[:, :, BBHeavyAtom.CA]
+     pos_C  = pos_atoms[:, :, BBHeavyAtom.C]
+
+     N_term_flag, C_term_flag = get_terminus_flag(chain_nb, res_nb, mask)  # (N, L)
+     omega_mask = torch.logical_not(N_term_flag)
+     phi_mask = torch.logical_not(N_term_flag)
+     psi_mask = torch.logical_not(C_term_flag)
+
+     # N-termini don't have omega and phi
+     omega = F.pad(
+         dihedral_from_four_points(pos_CA[:, :-1], pos_C[:, :-1], pos_N[:, 1:], pos_CA[:, 1:]),
+         pad=(1, 0), value=0,
+     )
+     phi = F.pad(
+         dihedral_from_four_points(pos_C[:, :-1], pos_N[:, 1:], pos_CA[:, 1:], pos_C[:, 1:]),
+         pad=(1, 0), value=0,
+     )
+
+     # C-termini don't have psi
+     psi = F.pad(
+         dihedral_from_four_points(pos_N[:, :-1], pos_CA[:, :-1], pos_C[:, :-1], pos_N[:, 1:]),
+         pad=(0, 1), value=0,
+     )
+
+     mask_bb_dihed = torch.stack([omega_mask, phi_mask, psi_mask], dim=-1)
+     bb_dihedral = torch.stack([omega, phi, psi], dim=-1) * mask_bb_dihed
+     return bb_dihedral, mask_bb_dihed
+
+
+ def pairwise_dihedrals(pos_atoms):
+     """
+     Args:
+         pos_atoms:  (N, L, A, 3).
+     Returns:
+         Inter-residue Phi and Psi angles, (N, L, L, 2).
+     """
+     N, L = pos_atoms.shape[:2]
+     pos_N  = pos_atoms[:, :, BBHeavyAtom.N]   # (N, L, 3)
+     pos_CA = pos_atoms[:, :, BBHeavyAtom.CA]
+     pos_C  = pos_atoms[:, :, BBHeavyAtom.C]
+
+     ir_phi = dihedral_from_four_points(
+         pos_C[:,:,None].expand(N, L, L, 3),
+         pos_N[:,None,:].expand(N, L, L, 3),
+         pos_CA[:,None,:].expand(N, L, L, 3),
+         pos_C[:,None,:].expand(N, L, L, 3)
+     )
+     ir_psi = dihedral_from_four_points(
+         pos_N[:,:,None].expand(N, L, L, 3),
+         pos_CA[:,:,None].expand(N, L, L, 3),
+         pos_C[:,:,None].expand(N, L, L, 3),
+         pos_N[:,None,:].expand(N, L, L, 3)
+     )
+     ir_dihed = torch.stack([ir_phi, ir_psi], dim=-1)
+     return ir_dihed
+
+
+ def apply_rotation_matrix_to_rot6d(R, O):
+     """
+     Args:
+         R:  (..., 3, 3)
+         O:  (..., 6)
+     Returns:
+         Rotated 6D representation, (..., 6).
+     """
+     u1, u2 = O[..., :3, None], O[..., 3:, None]  # (..., 3, 1)
+     v1 = torch.matmul(R, u1).squeeze(-1)    # (..., 3)
+     v2 = torch.matmul(R, u2).squeeze(-1)
+     return torch.cat([v1, v2], dim=-1)
+
+
+ def normalize_rot6d(O):
+     """
+     Args:
+         O:  (..., 6)
+     """
+     u1, u2 = O[..., :3], O[..., 3:]  # (..., 3)
+     v1 = F.normalize(u1, p=2, dim=-1)   # (..., 3)
+     v2 = F.normalize(u2 - project_v2v(u2, v1), p=2, dim=-1)
+     return torch.cat([v1, v2], dim=-1)
+
+
+ def reconstruct_backbone(R, t, aa, chain_nb, res_nb, mask):
+     """
+     Args:
+         R:  (N, L, 3, 3)
+         t:  (N, L, 3)
+         aa: (N, L)
+         chain_nb:   (N, L)
+         res_nb:     (N, L)
+         mask:       (N, L)
+     Returns:
+         Reconstructed backbone atoms, (N, L, 4, 3).
+     """
+     N, L = aa.size()
+     # atom_coords = restype_heavyatom_rigid_group_positions.clone().to(t)  # (21, 14, 3)
+     bb_coords = backbone_atom_coordinates_tensor.clone().to(t)  # (21, 3, 3)
+     oxygen_coord = bb_oxygen_coordinate_tensor.clone().to(t)    # (21, 3)
+     aa = aa.clamp(min=0, max=20)    # 20 for UNK
+
+     bb_coords = bb_coords[aa.flatten()].reshape(N, L, -1, 3)    # (N, L, 3, 3)
+     oxygen_coord = oxygen_coord[aa.flatten()].reshape(N, L, -1) # (N, L, 3)
+     bb_pos = local_to_global(R, t, bb_coords)   # Global coordinates of N, CA, C. (N, L, 3, 3).
+
+     # Compute PSI angle
+     bb_dihedral, _ = get_backbone_dihedral_angles(bb_pos, chain_nb, res_nb, mask)
+     psi = bb_dihedral[..., 2]   # (N, L)
+     # Make rotation matrix for PSI
+     sin_psi = torch.sin(psi).reshape(N, L, 1, 1)
+     cos_psi = torch.cos(psi).reshape(N, L, 1, 1)
+     zero = torch.zeros_like(sin_psi)
+     one = torch.ones_like(sin_psi)
+     row1 = torch.cat([one, zero, zero], dim=-1)         # (N, L, 1, 3)
+     row2 = torch.cat([zero, cos_psi, -sin_psi], dim=-1) # (N, L, 1, 3)
+     row3 = torch.cat([zero, sin_psi, cos_psi], dim=-1)  # (N, L, 1, 3)
+     R_psi = torch.cat([row1, row2, row3], dim=-2)       # (N, L, 3, 3)
+
+     # Compute rotation and translation of PSI frame, and position of O.
+     R_psi, t_psi = compose_chain([
+         (R, t),                         # Backbone
+         (R_psi, torch.zeros_like(t)),   # PSI angle
+     ])
+     O_pos = local_to_global(R_psi, t_psi, oxygen_coord.reshape(N, L, 1, 3))
+
+     bb_pos = torch.cat([bb_pos, O_pos], dim=2)  # (N, L, 4, 3)
+     return bb_pos
+
+
+ def reconstruct_backbone_partially(pos_ctx, R_new, t_new, aa, chain_nb, res_nb, mask_atoms, mask_recons):
+     """
+     Args:
+         pos:    (N, L, A, 3).
+         R_new:  (N, L, 3, 3).
+         t_new:  (N, L, 3).
+         mask_atoms: (N, L, A).
+         mask_recons:(N, L).
+     Returns:
+         pos_new:    (N, L, A, 3).
+         mask_new:   (N, L, A).
+     """
+     N, L, A = mask_atoms.size()
+
+     mask_res = mask_atoms[:, :, BBHeavyAtom.CA]
+     pos_recons = reconstruct_backbone(R_new, t_new, aa, chain_nb, res_nb, mask_res)  # (N, L, 4, 3)
+     pos_recons = F.pad(pos_recons, pad=(0, 0, 0, A-4), value=0)  # (N, L, A, 3)
+
+     pos_new = torch.where(
+         mask_recons[:, :, None, None].expand_as(pos_ctx),
+         pos_recons, pos_ctx
+     )   # (N, L, A, 3)
+
+     mask_bb_atoms = torch.zeros_like(mask_atoms)
+     mask_bb_atoms[:, :, :4] = True
+     mask_new = torch.where(
+         mask_recons[:, :, None].expand_as(mask_atoms),
+         mask_bb_atoms, mask_atoms
+     )
+
+     return pos_new, mask_new
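
A quick check of `construct_3d_basis`, which builds per-residue frames by Gram-Schmidt: for generic inputs the result should be a proper rotation matrix (orthonormal columns, determinant +1). A minimal sketch, assuming the `diffab` package from this commit is importable:

    import torch
    from diffab.modules.common.geometry import construct_3d_basis

    center = torch.randn(2, 5, 3)           # stand-in CA positions
    p1 = center + torch.randn(2, 5, 3)      # stand-in C positions
    p2 = center + torch.randn(2, 5, 3)      # stand-in N positions
    R = construct_3d_basis(center, p1, p2)  # (2, 5, 3, 3)

    RtR = torch.matmul(R.transpose(-1, -2), R)
    print(torch.allclose(RtR, torch.eye(3).expand_as(RtR), atol=1e-4))  # True
    print(torch.linalg.det(R).mean())                                   # ~1.0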
diffab/modules/common/layers.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ def mask_zero(mask, value):
+     return torch.where(mask, value, torch.zeros_like(value))
+
+
+ def clampped_one_hot(x, num_classes):
+     mask = (x >= 0) & (x < num_classes)  # (N, L)
+     x = x.clamp(min=0, max=num_classes-1)
+     y = F.one_hot(x, num_classes) * mask[..., None]  # (N, L, C)
+     return y
+
+
+ class DistanceToBins(nn.Module):
+
+     def __init__(self, dist_min=0.0, dist_max=20.0, num_bins=64, use_onehot=False):
+         super().__init__()
+         self.dist_min = dist_min
+         self.dist_max = dist_max
+         self.num_bins = num_bins
+         self.use_onehot = use_onehot
+
+         if use_onehot:
+             offset = torch.linspace(dist_min, dist_max, self.num_bins)
+         else:
+             offset = torch.linspace(dist_min, dist_max, self.num_bins-1)  # 1 overflow flag
+             self.coeff = -0.5 / ((offset[1] - offset[0]) * 0.2).item() ** 2  # `*0.2`: makes it not too blurred
+         self.register_buffer('offset', offset)
+
+     @property
+     def out_channels(self):
+         return self.num_bins
+
+     def forward(self, dist, dim, normalize=True):
+         """
+         Args:
+             dist: (N, *, 1, *)
+         Returns:
+             (N, *, num_bins, *)
+         """
+         assert dist.size()[dim] == 1
+         offset_shape = [1] * len(dist.size())
+         offset_shape[dim] = -1
+
+         if self.use_onehot:
+             diff = torch.abs(dist - self.offset.view(*offset_shape))  # (N, *, num_bins, *)
+             bin_idx = torch.argmin(diff, dim=dim, keepdim=True)  # (N, *, 1, *)
+             y = torch.zeros_like(diff).scatter_(dim=dim, index=bin_idx, value=1.0)
+         else:
+             overflow_symb = (dist >= self.dist_max).float()  # (N, *, 1, *)
+             y = dist - self.offset.view(*offset_shape)  # (N, *, num_bins-1, *)
+             y = torch.exp(self.coeff * torch.pow(y, 2))  # (N, *, num_bins-1, *)
+             y = torch.cat([y, overflow_symb], dim=dim)  # (N, *, num_bins, *)
+             if normalize:
+                 y = y / y.sum(dim=dim, keepdim=True)
+
+         return y
+
+
+ class PositionalEncoding(nn.Module):
+
+     def __init__(self, num_funcs=6):
+         super().__init__()
+         self.num_funcs = num_funcs
+         self.register_buffer('freq_bands', 2.0 ** torch.linspace(0.0, num_funcs-1, num_funcs))
+
+     def get_out_dim(self, in_dim):
+         return in_dim * (2 * self.num_funcs + 1)
+
+     def forward(self, x):
+         """
+         Args:
+             x: (..., d).
+         """
+         shape = list(x.shape[:-1]) + [-1]
+         x = x.unsqueeze(-1)  # (..., d, 1)
+         code = torch.cat([x, torch.sin(x * self.freq_bands), torch.cos(x * self.freq_bands)], dim=-1)  # (..., d, 2f+1)
+         code = code.reshape(shape)
+         return code
+
+
+ class AngularEncoding(nn.Module):
+
+     def __init__(self, num_funcs=3):
+         super().__init__()
+         self.num_funcs = num_funcs
+         self.register_buffer('freq_bands', torch.FloatTensor(
+             [i+1 for i in range(num_funcs)] + [1./(i+1) for i in range(num_funcs)]
+         ))
+
+     def get_out_dim(self, in_dim):
+         return in_dim * (1 + 2*2*self.num_funcs)
+
+     def forward(self, x):
+         """
+         Args:
+             x: (..., d).
+         """
+         shape = list(x.shape[:-1]) + [-1]
+         x = x.unsqueeze(-1)  # (..., d, 1)
+         code = torch.cat([x, torch.sin(x * self.freq_bands), torch.cos(x * self.freq_bands)], dim=-1)  # (..., d, 2f+1)
+         code = code.reshape(shape)
+         return code
+
+
+ class LayerNorm(nn.Module):
+
+     def __init__(self,
+                  normal_shape,
+                  gamma=True,
+                  beta=True,
+                  epsilon=1e-10):
+         """Layer normalization layer
+         See: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf)
+         :param normal_shape: The shape of the input tensor or the last dimension of the input tensor.
+         :param gamma: Add a scale parameter if it is True.
+         :param beta: Add an offset parameter if it is True.
+         :param epsilon: Epsilon for calculating variance.
+         """
+         super().__init__()
+         if isinstance(normal_shape, int):
+             normal_shape = (normal_shape,)
+         else:
+             normal_shape = (normal_shape[-1],)
+         self.normal_shape = torch.Size(normal_shape)
+         self.epsilon = epsilon
+         if gamma:
+             self.gamma = nn.Parameter(torch.Tensor(*normal_shape))
+         else:
+             self.register_parameter('gamma', None)
+         if beta:
+             self.beta = nn.Parameter(torch.Tensor(*normal_shape))
+         else:
+             self.register_parameter('beta', None)
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         if self.gamma is not None:
+             self.gamma.data.fill_(1)
+         if self.beta is not None:
+             self.beta.data.zero_()
+
+     def forward(self, x):
+         mean = x.mean(dim=-1, keepdim=True)
+         var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
+         std = (var + self.epsilon).sqrt()
+         y = (x - mean) / std
+         if self.gamma is not None:
+             y *= self.gamma
+         if self.beta is not None:
+             y += self.beta
+         return y
+
+     def extra_repr(self):
+         return 'normal_shape={}, gamma={}, beta={}, epsilon={}'.format(
+             self.normal_shape, self.gamma is not None, self.beta is not None, self.epsilon,
+         )
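
A quick shape check for the encoders above — a minimal sketch, assuming the module is importable as diffab.modules.common.layers (the path this diff adds); the shapes follow the docstrings:

import torch
from diffab.modules.common.layers import AngularEncoding, DistanceToBins

enc = AngularEncoding(num_funcs=3)
angles = torch.rand(2, 16, 2)               # e.g. two dihedral angles per residue
feat = enc(angles)                          # (2, 16, 2 * (1 + 2*2*3)) = (2, 16, 26)
assert feat.shape[-1] == enc.get_out_dim(2)

d2b = DistanceToBins(dist_min=0.0, dist_max=20.0, num_bins=64)
dist = torch.rand(2, 16, 16, 1) * 25.0      # pairwise distances, singleton bin axis
bins = d2b(dist, dim=-1)                    # (2, 16, 16, 64), soft (smoothed) binning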
diffab/modules/common/so3.py ADDED
@@ -0,0 +1,146 @@
+ import math
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .geometry import quaternion_to_rotation_matrix
+
+
+ def log_rotation(R):
+     trace = R[..., range(3), range(3)].sum(-1)
+     if torch.is_grad_enabled():
+         # The derivative of acos at -1.0 is -inf, so to stabilize the gradient, we use -0.999
+         min_cos = -0.999
+     else:
+         min_cos = -1.0
+     cos_theta = ((trace - 1) / 2).clamp_min(min=min_cos)
+     sin_theta = torch.sqrt(1 - cos_theta**2)
+     theta = torch.acos(cos_theta)
+     coef = ((theta + 1e-8) / (2 * sin_theta + 2e-8))[..., None, None]
+     logR = coef * (R - R.transpose(-1, -2))
+     return logR
+
+
+ def skewsym_to_so3vec(S):
+     x = S[..., 1, 2]
+     y = S[..., 2, 0]
+     z = S[..., 0, 1]
+     w = torch.stack([x, y, z], dim=-1)
+     return w
+
+
+ def so3vec_to_skewsym(w):
+     x, y, z = torch.unbind(w, dim=-1)
+     o = torch.zeros_like(x)
+     S = torch.stack([
+         o, z, -y,
+         -z, o, x,
+         y, -x, o,
+     ], dim=-1).reshape(w.shape[:-1] + (3, 3))
+     return S
+
+
+ def exp_skewsym(S):
+     x = torch.linalg.norm(skewsym_to_so3vec(S), dim=-1)
+     I = torch.eye(3).to(S).view([1 for _ in range(S.dim()-2)] + [3, 3])
+
+     sinx, cosx = torch.sin(x), torch.cos(x)
+     b = (sinx + 1e-8) / (x + 1e-8)
+     c = (1 - cosx + 1e-8) / (x**2 + 2e-8)  # lim_{x->0} (1-cosx)/(x^2) = 0.5
+
+     S2 = S @ S
+     return I + b[..., None, None]*S + c[..., None, None]*S2
+
+
+ def so3vec_to_rotation(w):
+     return exp_skewsym(so3vec_to_skewsym(w))
+
+
+ def rotation_to_so3vec(R):
+     logR = log_rotation(R)
+     w = skewsym_to_so3vec(logR)
+     return w
+
+
+ def random_uniform_so3(size, device='cpu'):
+     q = F.normalize(torch.randn(list(size)+[4,], device=device), dim=-1)  # (..., 4)
+     return rotation_to_so3vec(quaternion_to_rotation_matrix(q))
+
+
+ class ApproxAngularDistribution(nn.Module):
+
+     def __init__(self, stddevs, std_threshold=0.1, num_bins=8192, num_iters=1024):
+         super().__init__()
+         self.std_threshold = std_threshold
+         self.num_bins = num_bins
+         self.num_iters = num_iters
+         self.register_buffer('stddevs', torch.FloatTensor(stddevs))
+         self.register_buffer('approx_flag', self.stddevs <= std_threshold)
+         self._precompute_histograms()
+
+     @staticmethod
+     def _pdf(x, e, L):
+         """
+         Args:
+             x: (N, )
+             e: Float
+             L: Integer
+         """
+         x = x[:, None]  # (N, *)
+         c = ((1 - torch.cos(x)) / math.pi)  # (N, *)
+         l = torch.arange(0, L)[None, :]  # (*, L)
+         a = (2*l+1) * torch.exp(-l*(l+1)*(e**2))  # (*, L)
+         b = (torch.sin((l+0.5) * x) + 1e-6) / (torch.sin(x / 2) + 1e-6)  # (N, L)
+
+         f = (c * a * b).sum(dim=1)
+         return f
+
+     def _precompute_histograms(self):
+         X, Y = [], []
+         for std in self.stddevs:
+             std = std.item()
+             x = torch.linspace(0, math.pi, self.num_bins)  # (n_bins,)
+             y = self._pdf(x, std, self.num_iters)  # (n_bins,)
+             y = torch.nan_to_num(y).clamp_min(0)
+             X.append(x)
+             Y.append(y)
+         self.register_buffer('X', torch.stack(X, dim=0))  # (n_stddevs, n_bins)
+         self.register_buffer('Y', torch.stack(Y, dim=0))  # (n_stddevs, n_bins)
+
+     def sample(self, std_idx):
+         """
+         Args:
+             std_idx: Indices of standard deviation.
+         Returns:
+             samples: Angular samples [0, PI), same size as std.
+         """
+         size = std_idx.size()
+         std_idx = std_idx.flatten()  # (N,)
+
+         # Samples from histogram
+         prob = self.Y[std_idx]  # (N, n_bins)
+         bin_idx = torch.multinomial(prob[:, :-1], num_samples=1).squeeze(-1)  # (N,)
+         bin_start = self.X[std_idx, bin_idx]  # (N,)
+         bin_width = self.X[std_idx, bin_idx+1] - self.X[std_idx, bin_idx]
+         samples_hist = bin_start + torch.rand_like(bin_start) * bin_width  # (N,)
+
+         # Samples from Gaussian approximation
+         mean_gaussian = self.stddevs[std_idx]*2
+         std_gaussian = self.stddevs[std_idx]
+         samples_gaussian = mean_gaussian + torch.randn_like(mean_gaussian) * std_gaussian
+         samples_gaussian = samples_gaussian.abs() % math.pi
+
+         # Choose from histogram or Gaussian
+         gaussian_flag = self.approx_flag[std_idx]
+         samples = torch.where(gaussian_flag, samples_gaussian, samples_hist)
+
+         return samples.reshape(size)
+
+
+ def random_normal_so3(std_idx, angular_distrib, device='cpu'):
+     size = std_idx.size()
+     u = F.normalize(torch.randn(list(size)+[3,], device=device), dim=-1)
+     theta = angular_distrib.sample(std_idx)
+     w = u * theta[..., None]
+     return w
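
As a sanity sketch of the exp/log pair defined above (import path assumed from this diff), rotation_to_so3vec inverts so3vec_to_rotation away from the theta = pi singularity:

import torch
from diffab.modules.common.so3 import random_uniform_so3, so3vec_to_rotation, rotation_to_so3vec

v = random_uniform_so3([4, 8])            # (4, 8, 3) so(3) vectors, uniform over SO(3)
R = so3vec_to_rotation(v)                 # (4, 8, 3, 3) via the Rodrigues formula
v2 = rotation_to_so3vec(R)                # log map back
print(torch.allclose(v, v2, atol=1e-3))   # True except for angles very close to pi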
diffab/modules/common/structure.py ADDED
@@ -0,0 +1,77 @@
+ import torch
+ from torch.nn import Module, Linear, LayerNorm, Sequential, ReLU
+
+ from ..common.geometry import compose_rotation_and_translation, quaternion_to_rotation_matrix, repr_6d_to_rotation_matrix
+
+
+ class FrameRotationTranslationPrediction(Module):
+
+     def __init__(self, feat_dim, rot_repr, nn_type='mlp'):
+         super().__init__()
+         assert rot_repr in ('quaternion', '6d')
+         self.rot_repr = rot_repr
+         if rot_repr == 'quaternion':
+             out_dim = 3 + 3
+         elif rot_repr == '6d':
+             out_dim = 6 + 3
+
+         if nn_type == 'linear':
+             self.nn = Linear(feat_dim, out_dim)
+         elif nn_type == 'mlp':
+             self.nn = Sequential(
+                 Linear(feat_dim, feat_dim), ReLU(),
+                 Linear(feat_dim, feat_dim), ReLU(),
+                 Linear(feat_dim, out_dim)
+             )
+         else:
+             raise ValueError('Unknown nn_type: %s' % nn_type)
+
+     def forward(self, x):
+         y = self.nn(x)  # (..., d+3)
+         if self.rot_repr == 'quaternion':
+             quaternion = torch.cat([torch.ones_like(y[..., :1]), y[..., 0:3]], dim=-1)
+             R_delta = quaternion_to_rotation_matrix(quaternion)
+             t_delta = y[..., 3:6]
+             return R_delta, t_delta
+         elif self.rot_repr == '6d':
+             R_delta = repr_6d_to_rotation_matrix(y[..., 0:6])
+             t_delta = y[..., 6:9]
+             return R_delta, t_delta
+
+
+ class FrameUpdate(Module):
+
+     def __init__(self, node_feat_dim, rot_repr='quaternion', rot_tran_nn_type='mlp'):
+         super().__init__()
+         self.transition_mlp = Sequential(
+             Linear(node_feat_dim, node_feat_dim), ReLU(),
+             Linear(node_feat_dim, node_feat_dim), ReLU(),
+             Linear(node_feat_dim, node_feat_dim),
+         )
+         self.transition_layer_norm = LayerNorm(node_feat_dim)
+
+         self.rot_tran = FrameRotationTranslationPrediction(node_feat_dim, rot_repr, nn_type=rot_tran_nn_type)
+
+     def forward(self, R, t, x, mask_generate):
+         """
+         Args:
+             R: Frame basis matrices, (N, L, 3, 3_index).
+             t: Frame external (absolute) coordinates, (N, L, 3). Unit: Angstrom.
+             x: Node-wise features, (N, L, F).
+             mask_generate: Masks, (N, L).
+         Returns:
+             R': Updated basis matrices, (N, L, 3, 3_index).
+             t': Updated coordinates, (N, L, 3).
+         """
+         x = self.transition_layer_norm(x + self.transition_mlp(x))
+
+         R_delta, t_delta = self.rot_tran(x)  # (N, L, 3, 3), (N, L, 3)
+         R_new, t_new = compose_rotation_and_translation(R, t, R_delta, t_delta)
+
+         mask_R = mask_generate[:, :, None, None].expand_as(R)
+         mask_t = mask_generate[:, :, None].expand_as(t)
+
+         R_new = torch.where(mask_R, R_new, R)
+         t_new = torch.where(mask_t, t_new, t)
+
+         return R_new, t_new
diffab/modules/common/topology.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ import torch.nn.functional as F
+
+
+ def get_consecutive_flag(chain_nb, res_nb, mask):
+     """
+     Args:
+         chain_nb, res_nb, mask: (B, L).
+     Returns:
+         consec: A flag tensor indicating whether residue-i is connected to residue-(i+1),
+             BoolTensor, (B, L-1)[b, i].
+     """
+     d_res_nb = (res_nb[:, 1:] - res_nb[:, :-1]).abs()  # (B, L-1)
+     same_chain = (chain_nb[:, 1:] == chain_nb[:, :-1])
+     consec = torch.logical_and(d_res_nb == 1, same_chain)
+     consec = torch.logical_and(consec, mask[:, :-1])
+     return consec
+
+
+ def get_terminus_flag(chain_nb, res_nb, mask):
+     consec = get_consecutive_flag(chain_nb, res_nb, mask)
+     N_term_flag = F.pad(torch.logical_not(consec), pad=(1, 0), value=1)
+     C_term_flag = F.pad(torch.logical_not(consec), pad=(0, 1), value=1)
+     return N_term_flag, C_term_flag
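
A toy example of the chain-break logic, assuming the module path from this diff; the residue-number jump 2 -> 4 and the chain switch both break consecutivity:

import torch
from diffab.modules.common.topology import get_consecutive_flag, get_terminus_flag

chain_nb = torch.tensor([[0, 0, 0, 1, 1]])
res_nb   = torch.tensor([[1, 2, 4, 1, 2]])
mask     = torch.ones(1, 5, dtype=torch.bool)

consec = get_consecutive_flag(chain_nb, res_nb, mask)
# tensor([[ True, False, False,  True]])
n_term, c_term = get_terminus_flag(chain_nb, res_nb, mask)  # flags at breaks and chain ends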
diffab/modules/diffusion/dpm_full.py ADDED
@@ -0,0 +1,319 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import functools
+ from tqdm.auto import tqdm
+
+ from diffab.modules.common.geometry import apply_rotation_to_vector, quaternion_1ijk_to_rotation_matrix
+ from diffab.modules.common.so3 import so3vec_to_rotation, rotation_to_so3vec, random_uniform_so3
+ from diffab.modules.encoders.ga import GAEncoder
+ from .transition import RotationTransition, PositionTransition, AminoacidCategoricalTransition
+
+
+ def rotation_matrix_cosine_loss(R_pred, R_true):
+     """
+     Args:
+         R_pred: (*, 3, 3).
+         R_true: (*, 3, 3).
+     Returns:
+         Per-matrix losses, (*, ).
+     """
+     size = list(R_pred.shape[:-2])
+     ncol = R_pred.numel() // 3
+
+     RT_pred = R_pred.transpose(-2, -1).reshape(ncol, 3)  # (ncol, 3)
+     RT_true = R_true.transpose(-2, -1).reshape(ncol, 3)  # (ncol, 3)
+
+     ones = torch.ones([ncol, ], dtype=torch.long, device=R_pred.device)
+     loss = F.cosine_embedding_loss(RT_pred, RT_true, ones, reduction='none')  # (ncol, )
+     loss = loss.reshape(size + [3]).sum(dim=-1)  # (*, )
+     return loss
+
+
+ class EpsilonNet(nn.Module):
+
+     def __init__(self, res_feat_dim, pair_feat_dim, num_layers, encoder_opt={}):
+         super().__init__()
+         self.current_sequence_embedding = nn.Embedding(25, res_feat_dim)  # 22 is padding
+         self.res_feat_mixer = nn.Sequential(
+             nn.Linear(res_feat_dim * 2, res_feat_dim), nn.ReLU(),
+             nn.Linear(res_feat_dim, res_feat_dim),
+         )
+         self.encoder = GAEncoder(res_feat_dim, pair_feat_dim, num_layers, **encoder_opt)
+
+         self.eps_crd_net = nn.Sequential(
+             nn.Linear(res_feat_dim+3, res_feat_dim), nn.ReLU(),
+             nn.Linear(res_feat_dim, res_feat_dim), nn.ReLU(),
+             nn.Linear(res_feat_dim, 3)
+         )
+
+         self.eps_rot_net = nn.Sequential(
+             nn.Linear(res_feat_dim+3, res_feat_dim), nn.ReLU(),
+             nn.Linear(res_feat_dim, res_feat_dim), nn.ReLU(),
+             nn.Linear(res_feat_dim, 3)
+         )
+
+         self.eps_seq_net = nn.Sequential(
+             nn.Linear(res_feat_dim+3, res_feat_dim), nn.ReLU(),
+             nn.Linear(res_feat_dim, res_feat_dim), nn.ReLU(),
+             nn.Linear(res_feat_dim, 20), nn.Softmax(dim=-1)
+         )
+
+     def forward(self, v_t, p_t, s_t, res_feat, pair_feat, beta, mask_generate, mask_res):
+         """
+         Args:
+             v_t: (N, L, 3).
+             p_t: (N, L, 3).
+             s_t: (N, L).
+             res_feat: (N, L, res_dim).
+             pair_feat: (N, L, L, pair_dim).
+             beta: (N,).
+             mask_generate: (N, L).
+             mask_res: (N, L).
+         Returns:
+             v_next: UPDATED (not epsilon) SO3-vector of orientations, (N, L, 3).
+             eps_pos: (N, L, 3).
+         """
+         N, L = mask_res.size()
+         R = so3vec_to_rotation(v_t)  # (N, L, 3, 3)
+
+         # s_t = s_t.clamp(min=0, max=19)  # TODO: clamping is good but ugly.
+         res_feat = self.res_feat_mixer(torch.cat([res_feat, self.current_sequence_embedding(s_t)], dim=-1))  # [Important] Incorporate sequence at the current step.
+         res_feat = self.encoder(R, p_t, res_feat, pair_feat, mask_res)
+
+         t_embed = torch.stack([beta, torch.sin(beta), torch.cos(beta)], dim=-1)[:, None, :].expand(N, L, 3)
+         in_feat = torch.cat([res_feat, t_embed], dim=-1)
+
+         # Position changes
+         eps_crd = self.eps_crd_net(in_feat)  # (N, L, 3)
+         eps_pos = apply_rotation_to_vector(R, eps_crd)  # (N, L, 3)
+         eps_pos = torch.where(mask_generate[:, :, None].expand_as(eps_pos), eps_pos, torch.zeros_like(eps_pos))
+
+         # New orientation
+         eps_rot = self.eps_rot_net(in_feat)  # (N, L, 3)
+         U = quaternion_1ijk_to_rotation_matrix(eps_rot)  # (N, L, 3, 3)
+         R_next = R @ U
+         v_next = rotation_to_so3vec(R_next)  # (N, L, 3)
+         v_next = torch.where(mask_generate[:, :, None].expand_as(v_next), v_next, v_t)
+
+         # New sequence categorical distributions
+         c_denoised = self.eps_seq_net(in_feat)  # Already softmax-ed, (N, L, 20)
+
+         return v_next, R_next, eps_pos, c_denoised
+
+
+ class FullDPM(nn.Module):
+
+     def __init__(
+         self,
+         res_feat_dim,
+         pair_feat_dim,
+         num_steps,
+         eps_net_opt={},
+         trans_rot_opt={},
+         trans_pos_opt={},
+         trans_seq_opt={},
+         position_mean=[0.0, 0.0, 0.0],
+         position_scale=[10.0],
+     ):
+         super().__init__()
+         self.eps_net = EpsilonNet(res_feat_dim, pair_feat_dim, **eps_net_opt)
+         self.num_steps = num_steps
+         self.trans_rot = RotationTransition(num_steps, **trans_rot_opt)
+         self.trans_pos = PositionTransition(num_steps, **trans_pos_opt)
+         self.trans_seq = AminoacidCategoricalTransition(num_steps, **trans_seq_opt)
+
+         self.register_buffer('position_mean', torch.FloatTensor(position_mean).view(1, 1, -1))
+         self.register_buffer('position_scale', torch.FloatTensor(position_scale).view(1, 1, -1))
+         self.register_buffer('_dummy', torch.empty([0, ]))
+
+     def _normalize_position(self, p):
+         p_norm = (p - self.position_mean) / self.position_scale
+         return p_norm
+
+     def _unnormalize_position(self, p_norm):
+         p = p_norm * self.position_scale + self.position_mean
+         return p
+
+     def forward(self, v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res, denoise_structure, denoise_sequence, t=None):
+         N, L = res_feat.shape[:2]
+         if t is None:
+             t = torch.randint(0, self.num_steps, (N,), dtype=torch.long, device=self._dummy.device)
+         p_0 = self._normalize_position(p_0)
+
+         if denoise_structure:
+             # Add noise to rotation
+             R_0 = so3vec_to_rotation(v_0)
+             v_noisy, _ = self.trans_rot.add_noise(v_0, mask_generate, t)
+             # Add noise to positions
+             p_noisy, eps_p = self.trans_pos.add_noise(p_0, mask_generate, t)
+         else:
+             R_0 = so3vec_to_rotation(v_0)
+             v_noisy = v_0.clone()
+             p_noisy = p_0.clone()
+             eps_p = torch.zeros_like(p_noisy)
+
+         if denoise_sequence:
+             # Add noise to sequence
+             _, s_noisy = self.trans_seq.add_noise(s_0, mask_generate, t)
+         else:
+             s_noisy = s_0.clone()
+
+         beta = self.trans_pos.var_sched.betas[t]
+         v_pred, R_pred, eps_p_pred, c_denoised = self.eps_net(
+             v_noisy, p_noisy, s_noisy, res_feat, pair_feat, beta, mask_generate, mask_res
+         )  # (N, L, 3), (N, L, 3, 3), (N, L, 3), (N, L, 20)
+
+         loss_dict = {}
+
+         # Rotation loss
+         loss_rot = rotation_matrix_cosine_loss(R_pred, R_0)  # (N, L)
+         loss_rot = (loss_rot * mask_generate).sum() / (mask_generate.sum().float() + 1e-8)
+         loss_dict['rot'] = loss_rot
+
+         # Position loss
+         loss_pos = F.mse_loss(eps_p_pred, eps_p, reduction='none').sum(dim=-1)  # (N, L)
+         loss_pos = (loss_pos * mask_generate).sum() / (mask_generate.sum().float() + 1e-8)
+         loss_dict['pos'] = loss_pos
+
+         # Sequence categorical loss
+         post_true = self.trans_seq.posterior(s_noisy, s_0, t)
+         log_post_pred = torch.log(self.trans_seq.posterior(s_noisy, c_denoised, t) + 1e-8)
+         kldiv = F.kl_div(
+             input=log_post_pred,
+             target=post_true,
+             reduction='none',
+             log_target=False
+         ).sum(dim=-1)  # (N, L)
+         loss_seq = (kldiv * mask_generate).sum() / (mask_generate.sum().float() + 1e-8)
+         loss_dict['seq'] = loss_seq
+
+         return loss_dict
+
+     @torch.no_grad()
+     def sample(
+         self,
+         v, p, s,
+         res_feat, pair_feat,
+         mask_generate, mask_res,
+         sample_structure=True, sample_sequence=True,
+         pbar=False,
+     ):
+         """
+         Args:
+             v: Orientations of contextual residues, (N, L, 3).
+             p: Positions of contextual residues, (N, L, 3).
+             s: Sequence of contextual residues, (N, L).
+         """
+         N, L = v.shape[:2]
+         p = self._normalize_position(p)
+
+         # Set the orientation and position of residues to be predicted to random values
+         if sample_structure:
+             v_rand = random_uniform_so3([N, L], device=self._dummy.device)
+             p_rand = torch.randn_like(p)
+             v_init = torch.where(mask_generate[:, :, None].expand_as(v), v_rand, v)
+             p_init = torch.where(mask_generate[:, :, None].expand_as(p), p_rand, p)
+         else:
+             v_init, p_init = v, p
+
+         if sample_sequence:
+             s_rand = torch.randint_like(s, low=0, high=19)
+             s_init = torch.where(mask_generate, s_rand, s)
+         else:
+             s_init = s
+
+         traj = {self.num_steps: (v_init, self._unnormalize_position(p_init), s_init)}
+         if pbar:
+             pbar = functools.partial(tqdm, total=self.num_steps, desc='Sampling')
+         else:
+             pbar = lambda x: x
+         for t in pbar(range(self.num_steps, 0, -1)):
+             v_t, p_t, s_t = traj[t]
+             p_t = self._normalize_position(p_t)
+
+             beta = self.trans_pos.var_sched.betas[t].expand([N, ])
+             t_tensor = torch.full([N, ], fill_value=t, dtype=torch.long, device=self._dummy.device)
+
+             v_next, R_next, eps_p, c_denoised = self.eps_net(
+                 v_t, p_t, s_t, res_feat, pair_feat, beta, mask_generate, mask_res
+             )  # (N, L, 3), (N, L, 3, 3), (N, L, 3), (N, L, 20)
+
+             v_next = self.trans_rot.denoise(v_t, v_next, mask_generate, t_tensor)
+             p_next = self.trans_pos.denoise(p_t, eps_p, mask_generate, t_tensor)
+             _, s_next = self.trans_seq.denoise(s_t, c_denoised, mask_generate, t_tensor)
+
+             if not sample_structure:
+                 v_next, p_next = v_t, p_t
+             if not sample_sequence:
+                 s_next = s_t
+
+             traj[t-1] = (v_next, self._unnormalize_position(p_next), s_next)
+             traj[t] = tuple(x.cpu() for x in traj[t])  # Move previous states to cpu memory.
+
+         return traj
+
+     @torch.no_grad()
+     def optimize(
+         self,
+         v, p, s,
+         opt_step: int,
+         res_feat, pair_feat,
+         mask_generate, mask_res,
+         sample_structure=True, sample_sequence=True,
+         pbar=False,
+     ):
+         """
+         Description:
+             First adds noise to the given structure, then denoises it.
+         """
+         N, L = v.shape[:2]
+         p = self._normalize_position(p)
+         t = torch.full([N, ], fill_value=opt_step, dtype=torch.long, device=self._dummy.device)
+
+         # Set the orientation and position of residues to be predicted to random values
+         if sample_structure:
+             # Add noise to rotation
+             v_noisy, _ = self.trans_rot.add_noise(v, mask_generate, t)
+             # Add noise to positions
+             p_noisy, _ = self.trans_pos.add_noise(p, mask_generate, t)
+             v_init = torch.where(mask_generate[:, :, None].expand_as(v), v_noisy, v)
+             p_init = torch.where(mask_generate[:, :, None].expand_as(p), p_noisy, p)
+         else:
+             v_init, p_init = v, p
+
+         if sample_sequence:
+             _, s_noisy = self.trans_seq.add_noise(s, mask_generate, t)
+             s_init = torch.where(mask_generate, s_noisy, s)
+         else:
+             s_init = s
+
+         traj = {opt_step: (v_init, self._unnormalize_position(p_init), s_init)}
+         if pbar:
+             pbar = functools.partial(tqdm, total=opt_step, desc='Optimizing')
+         else:
+             pbar = lambda x: x
+         for t in pbar(range(opt_step, 0, -1)):
+             v_t, p_t, s_t = traj[t]
+             p_t = self._normalize_position(p_t)
+
+             beta = self.trans_pos.var_sched.betas[t].expand([N, ])
+             t_tensor = torch.full([N, ], fill_value=t, dtype=torch.long, device=self._dummy.device)
+
+             v_next, R_next, eps_p, c_denoised = self.eps_net(
+                 v_t, p_t, s_t, res_feat, pair_feat, beta, mask_generate, mask_res
+             )  # (N, L, 3), (N, L, 3, 3), (N, L, 3), (N, L, 20)
+
+             v_next = self.trans_rot.denoise(v_t, v_next, mask_generate, t_tensor)
+             p_next = self.trans_pos.denoise(p_t, eps_p, mask_generate, t_tensor)
+             _, s_next = self.trans_seq.denoise(s_t, c_denoised, mask_generate, t_tensor)
+
+             if not sample_structure:
+                 v_next, p_next = v_t, p_t
+             if not sample_sequence:
+                 s_next = s_t
+
+             traj[t-1] = (v_next, self._unnormalize_position(p_next), s_next)
+             traj[t] = tuple(x.cpu() for x in traj[t])  # Move previous states to cpu memory.
+
+         return traj
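
A minimal check of the rotation loss defined at the top of this file (import path assumed from this diff): identical rotation matrices give a per-residue loss of zero, since every column pair has cosine similarity 1:

import torch
from diffab.modules.diffusion.dpm_full import rotation_matrix_cosine_loss

R_true = torch.eye(3).expand(2, 8, 3, 3)
loss = rotation_matrix_cosine_loss(R_true.clone(), R_true)  # (2, 8), all ~0
print(loss.abs().max())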
diffab/modules/diffusion/transition.py ADDED
@@ -0,0 +1,223 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from diffab.modules.common.layers import clampped_one_hot
+ from diffab.modules.common.so3 import ApproxAngularDistribution, random_normal_so3, so3vec_to_rotation, rotation_to_so3vec
+
+
+ class VarianceSchedule(nn.Module):
+
+     def __init__(self, num_steps=100, s=0.01):
+         super().__init__()
+         T = num_steps
+         t = torch.arange(0, num_steps+1, dtype=torch.float)
+         f_t = torch.cos((np.pi / 2) * ((t/T) + s) / (1 + s)) ** 2
+         alpha_bars = f_t / f_t[0]
+
+         betas = 1 - (alpha_bars[1:] / alpha_bars[:-1])
+         betas = torch.cat([torch.zeros([1]), betas], dim=0)
+         betas = betas.clamp_max(0.999)
+
+         sigmas = torch.zeros_like(betas)
+         for i in range(1, betas.size(0)):
+             sigmas[i] = ((1 - alpha_bars[i-1]) / (1 - alpha_bars[i])) * betas[i]
+         sigmas = torch.sqrt(sigmas)
+
+         self.register_buffer('betas', betas)
+         self.register_buffer('alpha_bars', alpha_bars)
+         self.register_buffer('alphas', 1 - betas)
+         self.register_buffer('sigmas', sigmas)
+
+
+ class PositionTransition(nn.Module):
+
+     def __init__(self, num_steps, var_sched_opt={}):
+         super().__init__()
+         self.var_sched = VarianceSchedule(num_steps, **var_sched_opt)
+
+     def add_noise(self, p_0, mask_generate, t):
+         """
+         Args:
+             p_0: (N, L, 3).
+             mask_generate: (N, L).
+             t: (N,).
+         """
+         alpha_bar = self.var_sched.alpha_bars[t]
+
+         c0 = torch.sqrt(alpha_bar).view(-1, 1, 1)
+         c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1)
+
+         e_rand = torch.randn_like(p_0)
+         p_noisy = c0*p_0 + c1*e_rand
+         p_noisy = torch.where(mask_generate[..., None].expand_as(p_0), p_noisy, p_0)
+
+         return p_noisy, e_rand
+
+     def denoise(self, p_t, eps_p, mask_generate, t):
+         # IMPORTANT:
+         #   clamping alpha is to fix the instability issue at the first step (t=T)
+         #   it seems like a problem with the ``improved ddpm''.
+         alpha = self.var_sched.alphas[t].clamp_min(
+             self.var_sched.alphas[-2]
+         )
+         alpha_bar = self.var_sched.alpha_bars[t]
+         sigma = self.var_sched.sigmas[t].view(-1, 1, 1)
+
+         c0 = (1.0 / torch.sqrt(alpha + 1e-8)).view(-1, 1, 1)
+         c1 = ((1 - alpha) / torch.sqrt(1 - alpha_bar + 1e-8)).view(-1, 1, 1)
+
+         z = torch.where(
+             (t > 1)[:, None, None].expand_as(p_t),
+             torch.randn_like(p_t),
+             torch.zeros_like(p_t),
+         )
+
+         p_next = c0 * (p_t - c1 * eps_p) + sigma * z
+         p_next = torch.where(mask_generate[..., None].expand_as(p_t), p_next, p_t)
+         return p_next
+
+
+ class RotationTransition(nn.Module):
+
+     def __init__(self, num_steps, var_sched_opt={}, angular_distrib_fwd_opt={}, angular_distrib_inv_opt={}):
+         super().__init__()
+         self.var_sched = VarianceSchedule(num_steps, **var_sched_opt)
+
+         # Forward (perturb)
+         c1 = torch.sqrt(1 - self.var_sched.alpha_bars)  # (T,).
+         self.angular_distrib_fwd = ApproxAngularDistribution(c1.tolist(), **angular_distrib_fwd_opt)
+
+         # Inverse (generate)
+         sigma = self.var_sched.sigmas
+         self.angular_distrib_inv = ApproxAngularDistribution(sigma.tolist(), **angular_distrib_inv_opt)
+
+         self.register_buffer('_dummy', torch.empty([0, ]))
+
+     def add_noise(self, v_0, mask_generate, t):
+         """
+         Args:
+             v_0: (N, L, 3).
+             mask_generate: (N, L).
+             t: (N,).
+         """
+         N, L = mask_generate.size()
+         alpha_bar = self.var_sched.alpha_bars[t]
+         c0 = torch.sqrt(alpha_bar).view(-1, 1, 1)
+         c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1)
+
+         # Noise rotation
+         e_scaled = random_normal_so3(t[:, None].expand(N, L), self.angular_distrib_fwd, device=self._dummy.device)  # (N, L, 3)
+         e_normal = e_scaled / (c1 + 1e-8)
+         E_scaled = so3vec_to_rotation(e_scaled)  # (N, L, 3, 3)
+
+         # Scaled true rotation
+         R0_scaled = so3vec_to_rotation(c0 * v_0)  # (N, L, 3, 3)
+
+         R_noisy = E_scaled @ R0_scaled
+         v_noisy = rotation_to_so3vec(R_noisy)
+         v_noisy = torch.where(mask_generate[..., None].expand_as(v_0), v_noisy, v_0)
+
+         return v_noisy, e_scaled
+
+     def denoise(self, v_t, v_next, mask_generate, t):
+         N, L = mask_generate.size()
+         e = random_normal_so3(t[:, None].expand(N, L), self.angular_distrib_inv, device=self._dummy.device)  # (N, L, 3)
+         e = torch.where(
+             (t > 1)[:, None, None].expand(N, L, 3),
+             e,
+             torch.zeros_like(e)  # Simply denoise and don't add noise at the last step
+         )
+         E = so3vec_to_rotation(e)
+
+         R_next = E @ so3vec_to_rotation(v_next)
+         v_next = rotation_to_so3vec(R_next)
+         v_next = torch.where(mask_generate[..., None].expand_as(v_next), v_next, v_t)
+
+         return v_next
+
+
+ class AminoacidCategoricalTransition(nn.Module):
+
+     def __init__(self, num_steps, num_classes=20, var_sched_opt={}):
+         super().__init__()
+         self.num_classes = num_classes
+         self.var_sched = VarianceSchedule(num_steps, **var_sched_opt)
+
+     @staticmethod
+     def _sample(c):
+         """
+         Args:
+             c: (N, L, K).
+         Returns:
+             x: (N, L).
+         """
+         N, L, K = c.size()
+         c = c.view(N*L, K) + 1e-8
+         x = torch.multinomial(c, 1).view(N, L)
+         return x
+
+     def add_noise(self, x_0, mask_generate, t):
+         """
+         Args:
+             x_0: (N, L)
+             mask_generate: (N, L).
+             t: (N,).
+         Returns:
+             c_t: Probability, (N, L, K).
+             x_t: Sample, LongTensor, (N, L).
+         """
+         N, L = x_0.size()
+         K = self.num_classes
+         c_0 = clampped_one_hot(x_0, num_classes=K).float()  # (N, L, K)
+         alpha_bar = self.var_sched.alpha_bars[t][:, None, None]  # (N, 1, 1)
+         c_noisy = (alpha_bar*c_0) + ((1-alpha_bar)/K)
+         c_t = torch.where(mask_generate[..., None].expand(N, L, K), c_noisy, c_0)
+         x_t = self._sample(c_t)
+         return c_t, x_t
+
+     def posterior(self, x_t, x_0, t):
+         """
+         Args:
+             x_t: Category LongTensor (N, L) or Probability FloatTensor (N, L, K).
+             x_0: Category LongTensor (N, L) or Probability FloatTensor (N, L, K).
+             t: (N,).
+         Returns:
+             theta: Posterior probability at (t-1)-th step, (N, L, K).
+         """
+         K = self.num_classes
+
+         if x_t.dim() == 3:
+             c_t = x_t  # When x_t is probability distribution.
+         else:
+             c_t = clampped_one_hot(x_t, num_classes=K).float()  # (N, L, K)
+
+         if x_0.dim() == 3:
+             c_0 = x_0  # When x_0 is probability distribution.
+         else:
+             c_0 = clampped_one_hot(x_0, num_classes=K).float()  # (N, L, K)
+
+         alpha = self.var_sched.alpha_bars[t][:, None, None]  # (N, 1, 1)
+         alpha_bar = self.var_sched.alpha_bars[t][:, None, None]  # (N, 1, 1)
+
+         theta = ((alpha*c_t) + (1-alpha)/K) * ((alpha_bar*c_0) + (1-alpha_bar)/K)  # (N, L, K)
+         theta = theta / (theta.sum(dim=-1, keepdim=True) + 1e-8)
+         return theta
+
+     def denoise(self, x_t, c_0_pred, mask_generate, t):
+         """
+         Args:
+             x_t: (N, L).
+             c_0_pred: Normalized probability predicted by networks, (N, L, K).
+             mask_generate: (N, L).
+             t: (N,).
+         Returns:
+             post: Posterior probability at (t-1)-th step, (N, L, K).
+             x_next: Sample at (t-1)-th step, LongTensor, (N, L).
+         """
+         c_t = clampped_one_hot(x_t, num_classes=self.num_classes).float()  # (N, L, K)
+         post = self.posterior(c_t, c_0_pred, t=t)  # (N, L, K)
+         post = torch.where(mask_generate[..., None].expand(post.size()), post, c_t)
+         x_next = self._sample(post)
+         return post, x_next
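
A short sketch of the position transition in isolation (module path assumed from this diff): context residues pass through add_noise untouched because of mask_generate:

import torch
from diffab.modules.diffusion.transition import PositionTransition

trans = PositionTransition(num_steps=100)
p0 = torch.randn(2, 16, 3)                        # normalized CA positions
mask_gen = torch.zeros(2, 16, dtype=torch.bool)
mask_gen[:, 4:8] = True                           # only these residues get perturbed
t = torch.tensor([50, 99])
p_noisy, eps = trans.add_noise(p0, mask_gen, t)
assert torch.equal(p_noisy[~mask_gen], p0[~mask_gen])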
diffab/modules/encoders/ga.py ADDED
@@ -0,0 +1,193 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+ from diffab.modules.common.geometry import global_to_local, local_to_global, normalize_vector, construct_3d_basis, angstrom_to_nm
+ from diffab.modules.common.layers import mask_zero, LayerNorm
+ from diffab.utils.protein.constants import BBHeavyAtom
+
+
+ def _alpha_from_logits(logits, mask, inf=1e5):
+     """
+     Args:
+         logits: Logit matrices, (N, L_i, L_j, num_heads).
+         mask: Masks, (N, L).
+     Returns:
+         alpha: Attention weights.
+     """
+     N, L, _, _ = logits.size()
+     mask_row = mask.view(N, L, 1, 1).expand_as(logits)  # (N, L, *, *)
+     mask_pair = mask_row * mask_row.permute(0, 2, 1, 3)  # (N, L, L, *)
+
+     logits = torch.where(mask_pair, logits, logits - inf)
+     alpha = torch.softmax(logits, dim=2)  # (N, L, L, num_heads)
+     alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
+     return alpha
+
+
+ def _heads(x, n_heads, n_ch):
+     """
+     Args:
+         x: (..., num_heads * num_channels)
+     Returns:
+         (..., num_heads, num_channels)
+     """
+     s = list(x.size())[:-1] + [n_heads, n_ch]
+     return x.view(*s)
+
+
+ class GABlock(nn.Module):
+
+     def __init__(self, node_feat_dim, pair_feat_dim, value_dim=32, query_key_dim=32, num_query_points=8,
+                  num_value_points=8, num_heads=12, bias=False):
+         super().__init__()
+         self.node_feat_dim = node_feat_dim
+         self.pair_feat_dim = pair_feat_dim
+         self.value_dim = value_dim
+         self.query_key_dim = query_key_dim
+         self.num_query_points = num_query_points
+         self.num_value_points = num_value_points
+         self.num_heads = num_heads
+
+         # Node
+         self.proj_query = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias)
+         self.proj_key = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias)
+         self.proj_value = nn.Linear(node_feat_dim, value_dim * num_heads, bias=bias)
+
+         # Pair
+         self.proj_pair_bias = nn.Linear(pair_feat_dim, num_heads, bias=bias)
+
+         # Spatial
+         self.spatial_coef = nn.Parameter(torch.full([1, 1, 1, self.num_heads], fill_value=np.log(np.exp(1.) - 1.)),
+                                          requires_grad=True)
+         self.proj_query_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias)
+         self.proj_key_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias)
+         self.proj_value_point = nn.Linear(node_feat_dim, num_value_points * num_heads * 3, bias=bias)
+
+         # Output
+         self.out_transform = nn.Linear(
+             in_features=(num_heads * pair_feat_dim) + (num_heads * value_dim) + (
+                 num_heads * num_value_points * (3 + 3 + 1)),
+             out_features=node_feat_dim,
+         )
+
+         self.layer_norm_1 = LayerNorm(node_feat_dim)
+         self.mlp_transition = nn.Sequential(nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
+                                             nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
+                                             nn.Linear(node_feat_dim, node_feat_dim))
+         self.layer_norm_2 = LayerNorm(node_feat_dim)
+
+     def _node_logits(self, x):
+         query_l = _heads(self.proj_query(x), self.num_heads, self.query_key_dim)  # (N, L, n_heads, qk_ch)
+         key_l = _heads(self.proj_key(x), self.num_heads, self.query_key_dim)  # (N, L, n_heads, qk_ch)
+         logits_node = (query_l.unsqueeze(2) * key_l.unsqueeze(1) *
+                        (1 / np.sqrt(self.query_key_dim))).sum(-1)  # (N, L, L, num_heads)
+         return logits_node
+
+     def _pair_logits(self, z):
+         logits_pair = self.proj_pair_bias(z)
+         return logits_pair
+
+     def _spatial_logits(self, R, t, x):
+         N, L, _ = t.size()
+
+         # Query
+         query_points = _heads(self.proj_query_point(x), self.num_heads * self.num_query_points,
+                               3)  # (N, L, n_heads * n_pnts, 3)
+         query_points = local_to_global(R, t, query_points)  # Global query coordinates, (N, L, n_heads * n_pnts, 3)
+         query_s = query_points.reshape(N, L, self.num_heads, -1)  # (N, L, n_heads, n_pnts*3)
+
+         # Key
+         key_points = _heads(self.proj_key_point(x), self.num_heads * self.num_query_points,
+                             3)  # (N, L, n_heads * n_pnts, 3)
+         key_points = local_to_global(R, t, key_points)  # Global key coordinates, (N, L, n_heads * n_pnts, 3)
+         key_s = key_points.reshape(N, L, self.num_heads, -1)  # (N, L, n_heads, n_pnts*3)
+
+         # Q-K Product
+         sum_sq_dist = ((query_s.unsqueeze(2) - key_s.unsqueeze(1)) ** 2).sum(-1)  # (N, L, L, n_heads)
+         gamma = F.softplus(self.spatial_coef)
+         logits_spatial = sum_sq_dist * ((-1 * gamma * np.sqrt(2 / (9 * self.num_query_points)))
+                                         / 2)  # (N, L, L, n_heads)
+         return logits_spatial
+
+     def _pair_aggregation(self, alpha, z):
+         N, L = z.shape[:2]
+         feat_p2n = alpha.unsqueeze(-1) * z.unsqueeze(-2)  # (N, L, L, n_heads, C)
+         feat_p2n = feat_p2n.sum(dim=2)  # (N, L, n_heads, C)
+         return feat_p2n.reshape(N, L, -1)
+
+     def _node_aggregation(self, alpha, x):
+         N, L = x.shape[:2]
+         value_l = _heads(self.proj_value(x), self.num_heads, self.value_dim)  # (N, L, n_heads, v_ch)
+         feat_node = alpha.unsqueeze(-1) * value_l.unsqueeze(1)  # (N, L, L, n_heads, *) @ (N, *, L, n_heads, v_ch)
+         feat_node = feat_node.sum(dim=2)  # (N, L, n_heads, v_ch)
+         return feat_node.reshape(N, L, -1)
+
+     def _spatial_aggregation(self, alpha, R, t, x):
+         N, L, _ = t.size()
+         value_points = _heads(self.proj_value_point(x), self.num_heads * self.num_value_points,
+                               3)  # (N, L, n_heads * n_v_pnts, 3)
+         value_points = local_to_global(R, t, value_points.reshape(N, L, self.num_heads, self.num_value_points,
+                                                                   3))  # (N, L, n_heads, n_v_pnts, 3)
+         aggr_points = alpha.reshape(N, L, L, self.num_heads, 1, 1) * \
+                       value_points.unsqueeze(1)  # (N, *, L, n_heads, n_pnts, 3)
+         aggr_points = aggr_points.sum(dim=2)  # (N, L, n_heads, n_pnts, 3)
+
+         feat_points = global_to_local(R, t, aggr_points)  # (N, L, n_heads, n_pnts, 3)
+         feat_distance = feat_points.norm(dim=-1)  # (N, L, n_heads, n_pnts)
+         feat_direction = normalize_vector(feat_points, dim=-1, eps=1e-4)  # (N, L, n_heads, n_pnts, 3)
+
+         feat_spatial = torch.cat([
+             feat_points.reshape(N, L, -1),
+             feat_distance.reshape(N, L, -1),
+             feat_direction.reshape(N, L, -1),
+         ], dim=-1)
+
+         return feat_spatial
+
+     def forward(self, R, t, x, z, mask):
+         """
+         Args:
+             R: Frame basis matrices, (N, L, 3, 3_index).
+             t: Frame external (absolute) coordinates, (N, L, 3).
+             x: Node-wise features, (N, L, F).
+             z: Pair-wise features, (N, L, L, C).
+             mask: Masks, (N, L).
+         Returns:
+             x': Updated node-wise features, (N, L, F).
+         """
+         # Attention logits
+         logits_node = self._node_logits(x)
+         logits_pair = self._pair_logits(z)
+         logits_spatial = self._spatial_logits(R, t, x)
+         # Summing logits up and apply `softmax`.
+         logits_sum = logits_node + logits_pair + logits_spatial
+         alpha = _alpha_from_logits(logits_sum * np.sqrt(1 / 3), mask)  # (N, L, L, n_heads)
+
+         # Aggregate features
+         feat_p2n = self._pair_aggregation(alpha, z)
+         feat_node = self._node_aggregation(alpha, x)
+         feat_spatial = self._spatial_aggregation(alpha, R, t, x)
+
+         # Finally
+         feat_all = self.out_transform(torch.cat([feat_p2n, feat_node, feat_spatial], dim=-1))  # (N, L, F)
+         feat_all = mask_zero(mask.unsqueeze(-1), feat_all)
+         x_updated = self.layer_norm_1(x + feat_all)
+         x_updated = self.layer_norm_2(x_updated + self.mlp_transition(x_updated))
+         return x_updated
+
+
+ class GAEncoder(nn.Module):
+
+     def __init__(self, node_feat_dim, pair_feat_dim, num_layers, ga_block_opt={}):
+         super(GAEncoder, self).__init__()
+         self.blocks = nn.ModuleList([
+             GABlock(node_feat_dim, pair_feat_dim, **ga_block_opt)
+             for _ in range(num_layers)
+         ])
+
+     def forward(self, R, t, res_feat, pair_feat, mask):
+         for block in self.blocks:
+             res_feat = block(R, t, res_feat, pair_feat, mask)
+         return res_feat
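
A smoke-test sketch for the encoder above, with hypothetical feature sizes and random inputs; it relies on the geometry helpers imported at the top of the file behaving as their names suggest:

import torch
from diffab.modules.encoders.ga import GAEncoder

encoder = GAEncoder(node_feat_dim=128, pair_feat_dim=64, num_layers=2)
N, L = 2, 16
R = torch.eye(3).expand(N, L, 3, 3)        # per-residue orientation frames
t = torch.randn(N, L, 3)                   # frame origins (CA coordinates)
x = torch.randn(N, L, 128)                 # node features
z = torch.randn(N, L, L, 64)               # pair features
mask = torch.ones(N, L, dtype=torch.bool)
out = encoder(R, t, x, z, mask)            # (N, L, 128)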
diffab/modules/encoders/pair.py ADDED
@@ -0,0 +1,102 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from diffab.modules.common.geometry import angstrom_to_nm, pairwise_dihedrals
+ from diffab.modules.common.layers import AngularEncoding
+ from diffab.utils.protein.constants import BBHeavyAtom, AA
+
+
+ class PairEmbedding(nn.Module):
+
+     def __init__(self, feat_dim, max_num_atoms, max_aa_types=22, max_relpos=32):
+         super().__init__()
+         self.max_num_atoms = max_num_atoms
+         self.max_aa_types = max_aa_types
+         self.max_relpos = max_relpos
+         self.aa_pair_embed = nn.Embedding(self.max_aa_types*self.max_aa_types, feat_dim)
+         self.relpos_embed = nn.Embedding(2*max_relpos+1, feat_dim)
+
+         self.aapair_to_distcoef = nn.Embedding(self.max_aa_types*self.max_aa_types, max_num_atoms*max_num_atoms)
+         nn.init.zeros_(self.aapair_to_distcoef.weight)
+         self.distance_embed = nn.Sequential(
+             nn.Linear(max_num_atoms*max_num_atoms, feat_dim), nn.ReLU(),
+             nn.Linear(feat_dim, feat_dim), nn.ReLU(),
+         )
+
+         self.dihedral_embed = AngularEncoding()
+         feat_dihed_dim = self.dihedral_embed.get_out_dim(2)  # Phi and Psi
+
+         infeat_dim = feat_dim + feat_dim + feat_dim + feat_dihed_dim
+         self.out_mlp = nn.Sequential(
+             nn.Linear(infeat_dim, feat_dim), nn.ReLU(),
+             nn.Linear(feat_dim, feat_dim), nn.ReLU(),
+             nn.Linear(feat_dim, feat_dim),
+         )
+
+     def forward(self, aa, res_nb, chain_nb, pos_atoms, mask_atoms, structure_mask=None, sequence_mask=None):
+         """
+         Args:
+             aa: (N, L).
+             res_nb: (N, L).
+             chain_nb: (N, L).
+             pos_atoms: (N, L, A, 3)
+             mask_atoms: (N, L, A)
+             structure_mask: (N, L)
+             sequence_mask: (N, L), mask out unknown amino acids to generate.
+
+         Returns:
+             (N, L, L, feat_dim)
+         """
+         N, L = aa.size()
+
+         # Remove other atoms
+         pos_atoms = pos_atoms[:, :, :self.max_num_atoms]
+         mask_atoms = mask_atoms[:, :, :self.max_num_atoms]
+
+         mask_residue = mask_atoms[:, :, BBHeavyAtom.CA]  # (N, L)
+         mask_pair = mask_residue[:, :, None] * mask_residue[:, None, :]
+         pair_structure_mask = structure_mask[:, :, None] * structure_mask[:, None, :] if structure_mask is not None else None
+
+         # Pair identities
+         if sequence_mask is not None:
+             # Avoid data leakage at training time
+             aa = torch.where(sequence_mask, aa, torch.full_like(aa, fill_value=AA.UNK))
+         aa_pair = aa[:, :, None]*self.max_aa_types + aa[:, None, :]  # (N, L, L)
+         feat_aapair = self.aa_pair_embed(aa_pair)
+
+         # Relative sequential positions
+         same_chain = (chain_nb[:, :, None] == chain_nb[:, None, :])
+         relpos = torch.clamp(
+             res_nb[:, :, None] - res_nb[:, None, :],
+             min=-self.max_relpos, max=self.max_relpos,
+         )  # (N, L, L)
+         feat_relpos = self.relpos_embed(relpos + self.max_relpos) * same_chain[:, :, :, None]
+
+         # Distances
+         d = angstrom_to_nm(torch.linalg.norm(
+             pos_atoms[:, :, None, :, None] - pos_atoms[:, None, :, None, :],
+             dim=-1, ord=2,
+         )).reshape(N, L, L, -1)  # (N, L, L, A*A)
+         c = F.softplus(self.aapair_to_distcoef(aa_pair))  # (N, L, L, A*A)
+         d_gauss = torch.exp(-1 * c * d**2)
+         mask_atom_pair = (mask_atoms[:, :, None, :, None] * mask_atoms[:, None, :, None, :]).reshape(N, L, L, -1)
+         feat_dist = self.distance_embed(d_gauss * mask_atom_pair)
+         if pair_structure_mask is not None:
+             # Avoid data leakage at training time
+             feat_dist = feat_dist * pair_structure_mask[:, :, :, None]
+
+         # Orientations
+         dihed = pairwise_dihedrals(pos_atoms)  # (N, L, L, 2)
+         feat_dihed = self.dihedral_embed(dihed)
+         if pair_structure_mask is not None:
+             # Avoid data leakage at training time
+             feat_dihed = feat_dihed * pair_structure_mask[:, :, :, None]
+
+         # All
+         feat_all = torch.cat([feat_aapair, feat_relpos, feat_dist, feat_dihed], dim=-1)
+         feat_all = self.out_mlp(feat_all)  # (N, L, L, F)
+         feat_all = feat_all * mask_pair[:, :, :, None]
+
+         return feat_all
+
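
A usage sketch with random inputs and hypothetical sizes (A = 5 atoms per residue is an assumption here; the real atom count comes from the dataset transforms), assuming the geometry helpers imported above are available:

import torch
from diffab.modules.encoders.pair import PairEmbedding

pair_embed = PairEmbedding(feat_dim=64, max_num_atoms=5)
N, L, A = 2, 16, 5
aa = torch.randint(0, 20, (N, L))
res_nb = torch.arange(1, L + 1).expand(N, L)
chain_nb = torch.zeros(N, L, dtype=torch.long)
pos_atoms = torch.randn(N, L, A, 3)
mask_atoms = torch.ones(N, L, A, dtype=torch.bool)
z = pair_embed(aa, res_nb, chain_nb, pos_atoms, mask_atoms)  # (N, L, L, 64)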
diffab/modules/encoders/residue.py ADDED
@@ -0,0 +1,92 @@
+ import torch
+ import torch.nn as nn
+
+ from diffab.modules.common.geometry import construct_3d_basis, global_to_local, get_backbone_dihedral_angles
+ from diffab.modules.common.layers import AngularEncoding
+ from diffab.utils.protein.constants import BBHeavyAtom, AA
+
+
+ class ResidueEmbedding(nn.Module):
+
+     def __init__(self, feat_dim, max_num_atoms, max_aa_types=22):
+         super().__init__()
+         self.max_num_atoms = max_num_atoms
+         self.max_aa_types = max_aa_types
+         self.aatype_embed = nn.Embedding(self.max_aa_types, feat_dim)
+         self.dihed_embed = AngularEncoding()
+         self.type_embed = nn.Embedding(10, feat_dim, padding_idx=0)  # 1: Heavy, 2: Light, 3: Ag
+         infeat_dim = feat_dim + (self.max_aa_types*max_num_atoms*3) + self.dihed_embed.get_out_dim(3) + feat_dim
+         self.mlp = nn.Sequential(
+             nn.Linear(infeat_dim, feat_dim * 2), nn.ReLU(),
+             nn.Linear(feat_dim * 2, feat_dim), nn.ReLU(),
+             nn.Linear(feat_dim, feat_dim), nn.ReLU(),
+             nn.Linear(feat_dim, feat_dim)
+         )
+
+     def forward(self, aa, res_nb, chain_nb, pos_atoms, mask_atoms, fragment_type, structure_mask=None, sequence_mask=None):
+         """
+         Args:
+             aa: (N, L).
+             res_nb: (N, L).
+             chain_nb: (N, L).
+             pos_atoms: (N, L, A, 3).
+             mask_atoms: (N, L, A).
+             fragment_type: (N, L).
+             structure_mask: (N, L), mask out unknown structures to generate.
+             sequence_mask: (N, L), mask out unknown amino acids to generate.
+         """
+         N, L = aa.size()
+         mask_residue = mask_atoms[:, :, BBHeavyAtom.CA]  # (N, L)
+
+         # Remove other atoms
+         pos_atoms = pos_atoms[:, :, :self.max_num_atoms]
+         mask_atoms = mask_atoms[:, :, :self.max_num_atoms]
+
+         # Amino acid identity features
+         if sequence_mask is not None:
+             # Avoid data leakage at training time
+             aa = torch.where(sequence_mask, aa, torch.full_like(aa, fill_value=AA.UNK))
+         aa_feat = self.aatype_embed(aa)  # (N, L, feat)
+
+         # Coordinate features
+         R = construct_3d_basis(
+             pos_atoms[:, :, BBHeavyAtom.CA],
+             pos_atoms[:, :, BBHeavyAtom.C],
+             pos_atoms[:, :, BBHeavyAtom.N]
+         )
+         t = pos_atoms[:, :, BBHeavyAtom.CA]
+         crd = global_to_local(R, t, pos_atoms)  # (N, L, A, 3)
+         crd_mask = mask_atoms[:, :, :, None].expand_as(crd)
+         crd = torch.where(crd_mask, crd, torch.zeros_like(crd))
+
+         aa_expand = aa[:, :, None, None, None].expand(N, L, self.max_aa_types, self.max_num_atoms, 3)
+         rng_expand = torch.arange(0, self.max_aa_types)[None, None, :, None, None].expand(N, L, self.max_aa_types, self.max_num_atoms, 3).to(aa_expand)
+         place_mask = (aa_expand == rng_expand)
+         crd_expand = crd[:, :, None, :, :].expand(N, L, self.max_aa_types, self.max_num_atoms, 3)
+         crd_expand = torch.where(place_mask, crd_expand, torch.zeros_like(crd_expand))
+         crd_feat = crd_expand.reshape(N, L, self.max_aa_types*self.max_num_atoms*3)
+         if structure_mask is not None:
+             # Avoid data leakage at training time
+             crd_feat = crd_feat * structure_mask[:, :, None]
+
+         # Backbone dihedral features
+         bb_dihedral, mask_bb_dihed = get_backbone_dihedral_angles(pos_atoms, chain_nb=chain_nb, res_nb=res_nb, mask=mask_residue)
+         dihed_feat = self.dihed_embed(bb_dihedral[:, :, :, None]) * mask_bb_dihed[:, :, :, None]  # (N, L, 3, dihed/3)
+         dihed_feat = dihed_feat.reshape(N, L, -1)
+         if structure_mask is not None:
+             # Avoid data leakage at training time
+             dihed_mask = torch.logical_and(
+                 structure_mask,
+                 torch.logical_and(
+                     torch.roll(structure_mask, shifts=+1, dims=1),
+                     torch.roll(structure_mask, shifts=-1, dims=1)
+                 ),
+             )  # Avoid slight data leakage via dihedral angles of anchor residues
+             dihed_feat = dihed_feat * dihed_mask[:, :, None]
+
+         # Type feature
+         type_feat = self.type_embed(fragment_type)  # (N, L, feat)
+
+         out_feat = self.mlp(torch.cat([aa_feat, crd_feat, dihed_feat, type_feat], dim=-1))  # (N, L, F)
+         out_feat = out_feat * mask_residue[:, :, None]
+         return out_feat
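
The per-residue counterpart can be exercised the same way (same hypothetical sizes as the pair-embedding sketch above; fragment_type 1 stands for the heavy chain per the comment in __init__):

import torch
from diffab.modules.encoders.residue import ResidueEmbedding

res_embed = ResidueEmbedding(feat_dim=128, max_num_atoms=5)
N, L, A = 2, 16, 5
aa = torch.randint(0, 20, (N, L))
res_nb = torch.arange(1, L + 1).expand(N, L)
chain_nb = torch.zeros(N, L, dtype=torch.long)
pos_atoms = torch.randn(N, L, A, 3)
mask_atoms = torch.ones(N, L, A, dtype=torch.bool)
fragment_type = torch.ones(N, L, dtype=torch.long)
feat = res_embed(aa, res_nb, chain_nb, pos_atoms, mask_atoms, fragment_type)  # (N, L, 128)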
diffab/tools/dock/base.py ADDED
@@ -0,0 +1,28 @@
+ import abc
+ from typing import List
+
+
+ FilePath = str
+
+
+ class DockingEngine(abc.ABC):
+
+     @abc.abstractmethod
+     def __enter__(self):
+         pass
+
+     @abc.abstractmethod
+     def __exit__(self, typ, value, traceback):
+         pass
+
+     @abc.abstractmethod
+     def set_receptor(self, pdb_path: FilePath):
+         pass
+
+     @abc.abstractmethod
+     def set_ligand(self, pdb_path: FilePath):
+         pass
+
+     @abc.abstractmethod
+     def dock(self) -> List[FilePath]:
+         pass
diffab/tools/dock/hdock.py ADDED
@@ -0,0 +1,164 @@
+ import os
+ import shutil
+ import tempfile
+ import subprocess
+ import dataclasses as dc
+ from typing import List, Optional
+ from Bio import PDB
+ from Bio.PDB import Model as PDBModel
+
+ from diffab.tools.renumber import renumber as renumber_chothia
+ from .base import DockingEngine
+
+
+ def fix_docked_pdb(pdb_path):
+     fixed = []
+     with open(pdb_path, 'r') as f:
+         for ln in f.readlines():
+             if (ln.startswith('ATOM') or ln.startswith('HETATM')) and len(ln) == 56:
+                 fixed.append(ln[:-1] + ' 1.00 0.00 \n')
+             else:
+                 fixed.append(ln)
+     with open(pdb_path, 'w') as f:
+         f.write(''.join(fixed))
+
+
+ class HDock(DockingEngine):
+
+     def __init__(
+         self,
+         hdock_bin='./bin/hdock',
+         createpl_bin='./bin/createpl',
+     ):
+         super().__init__()
+         self.hdock_bin = os.path.realpath(hdock_bin)
+         self.createpl_bin = os.path.realpath(createpl_bin)
+         self.tmpdir = tempfile.TemporaryDirectory()
+
+         self._has_receptor = False
+         self._has_ligand = False
+
+         self._receptor_chains = []
+         self._ligand_chains = []
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, typ, value, traceback):
+         self.tmpdir.cleanup()
+
+     def set_receptor(self, pdb_path):
+         shutil.copyfile(pdb_path, os.path.join(self.tmpdir.name, 'receptor.pdb'))
+         self._has_receptor = True
+
+     def set_ligand(self, pdb_path):
+         shutil.copyfile(pdb_path, os.path.join(self.tmpdir.name, 'ligand.pdb'))
+         self._has_ligand = True
+
+     def _dump_complex_pdb(self):
+         parser = PDB.PDBParser(QUIET=True)
+         model_receptor = parser.get_structure(None, os.path.join(self.tmpdir.name, 'receptor.pdb'))[0]
+         docked_pdb_path = os.path.join(self.tmpdir.name, 'ligand_docked.pdb')
+         fix_docked_pdb(docked_pdb_path)
+         structure_ligdocked = parser.get_structure(None, docked_pdb_path)
+
+         pdb_io = PDB.PDBIO()
+         paths = []
+         for i, model_ligdocked in enumerate(structure_ligdocked):
+             model_complex = PDBModel.Model(0)
+             for chain in model_receptor:
+                 model_complex.add(chain.copy())
+             for chain in model_ligdocked:
+                 model_complex.add(chain.copy())
+             pdb_io.set_structure(model_complex)
+             save_path = os.path.join(self.tmpdir.name, f"complex_{i}.pdb")
+             pdb_io.save(save_path)
+             paths.append(save_path)
+         return paths
+
+     def dock(self):
+         if not (self._has_receptor and self._has_ligand):
+             raise ValueError('Missing receptor or ligand.')
+         subprocess.run(
+             [self.hdock_bin, "receptor.pdb", "ligand.pdb"],
+             cwd=self.tmpdir.name, check=True
+         )
+         subprocess.run(
+             [self.createpl_bin, "Hdock.out", "ligand_docked.pdb"],
+             cwd=self.tmpdir.name, check=True
+         )
+         return self._dump_complex_pdb()
+
+
+ @dc.dataclass
+ class DockSite:
+     chain: str
+     resseq: int
+
+
+ class HDockAntibody(HDock):
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._heavy_chain_id = None
+         self._epitope_sites: Optional[List[DockSite]] = None
+
+     def set_ligand(self, pdb_path):
+         raise NotImplementedError('Please use set_antibody')
+
+     def set_receptor(self, pdb_path):
+         raise NotImplementedError('Please use set_antigen')
+
+     def set_antigen(self, pdb_path, epitope_sites: Optional[List[DockSite]] = None):
+         super().set_receptor(pdb_path)
+         self._epitope_sites = epitope_sites
+
+     def set_antibody(self, pdb_path):
+         heavy_chains, _ = renumber_chothia(pdb_path, os.path.join(self.tmpdir.name, 'ligand.pdb'))
+         self._has_ligand = True
+         self._heavy_chain_id = heavy_chains[0]
+
+     def _prepare_lsite(self):
+         lsite_content = f"95-102:{self._heavy_chain_id}\n"  # Chothia CDR H3
+         with open(os.path.join(self.tmpdir.name, 'lsite.txt'), 'w') as f:
+             f.write(lsite_content)
+         print(f"[INFO] lsite content: {lsite_content}")
+
+     def _prepare_rsite(self):
+         rsite_content = ""
+         for site in self._epitope_sites:
+             rsite_content += f"{site.resseq}:{site.chain}\n"
+         with open(os.path.join(self.tmpdir.name, 'rsite.txt'), 'w') as f:
+             f.write(rsite_content)
+         print(f"[INFO] rsite content: {rsite_content}")
+
+     def dock(self):
+         if not (self._has_receptor and self._has_ligand):
+             raise ValueError('Missing receptor or ligand.')
+         self._prepare_lsite()
+
+         cmd_hdock = [self.hdock_bin, "receptor.pdb", "ligand.pdb", "-lsite", "lsite.txt"]
+         if self._epitope_sites is not None:
+             self._prepare_rsite()
+             cmd_hdock += ["-rsite", "rsite.txt"]
+         subprocess.run(
+             cmd_hdock,
+             cwd=self.tmpdir.name, check=True
+         )
+
+         cmd_pl = [self.createpl_bin, "Hdock.out", "ligand_docked.pdb", "-lsite", "lsite.txt"]
+         if self._epitope_sites is not None:
+             self._prepare_rsite()
+             cmd_pl += ["-rsite", "rsite.txt"]
+         subprocess.run(
+             cmd_pl,
+             cwd=self.tmpdir.name, check=True
+         )
+         return self._dump_complex_pdb()
+
+
+ if __name__ == '__main__':
+     with HDockAntibody('hdock', 'createpl') as dock:
+         dock.set_antigen('./data/dock/receptor.pdb', [DockSite('A', 991)])
+         dock.set_antibody('./data/example_dock/3qhf_fv.pdb')
+         print(dock.dock())
diffab/tools/eval/__main__.py ADDED
@@ -0,0 +1,4 @@
+ from .run import main
+
+ if __name__ == '__main__':
+     main()
diffab/tools/eval/base.py ADDED
@@ -0,0 +1,125 @@
1
+ import os
2
+ import re
3
+ import json
4
+ import shelve
5
+ from Bio import PDB
6
+ from typing import Optional, Tuple, List
7
+ from dataclasses import dataclass, field
8
+
9
+
10
+ @dataclass
11
+ class EvalTask:
12
+ in_path: str
13
+ ref_path: str
14
+ info: dict
15
+ structure: str
16
+ name: str
17
+ method: str
18
+ cdr: str
19
+ ab_chains: List
20
+
21
+ residue_first: Optional[Tuple] = None
22
+ residue_last: Optional[Tuple] = None
23
+
24
+ scores: dict = field(default_factory=dict)
25
+
26
+ def get_gen_biopython_model(self):
27
+ parser = PDB.PDBParser(QUIET=True)
28
+ return parser.get_structure(self.in_path, self.in_path)[0]
29
+
30
+ def get_ref_biopython_model(self):
31
+ parser = PDB.PDBParser(QUIET=True)
32
+ return parser.get_structure(self.ref_path, self.ref_path)[0]
33
+
34
+ def save_to_db(self, db: shelve.Shelf):
35
+ db[self.in_path] = self
36
+
37
+ def to_report_dict(self):
38
+ return {
39
+ 'method': self.method,
40
+ 'structure': self.structure,
41
+ 'cdr': self.cdr,
42
+ 'filename': os.path.basename(self.in_path),
43
+ **self.scores
44
+ }
45
+
46
+
47
+ class TaskScanner:
48
+
49
+ def __init__(self, root, postfix=None, db: Optional[shelve.Shelf]=None):
50
+ super().__init__()
51
+ self.root = root
52
+ self.postfix = postfix
53
+ self.visited = set()
54
+ self.db = db
55
+ if db is not None:
56
+ for k in db.keys():
57
+ self.visited.add(k)
58
+
59
+ def _get_metadata(self, fpath):
60
+ json_path = os.path.join(
61
+ os.path.dirname(os.path.dirname(fpath)),
62
+ 'metadata.json'
63
+ )
64
+ tag_name = os.path.basename(os.path.dirname(fpath))
65
+ method_name = os.path.basename(
66
+ os.path.dirname(os.path.dirname(os.path.dirname(fpath)))
67
+ )
68
+ try:
69
+ antibody_chains = set()
70
+ info = None
71
+ with open(json_path, 'r') as f:
72
+ metadata = json.load(f)
73
+ for item in metadata['items']:
74
+ if item['tag'] == tag_name:
75
+ info = item
76
+ antibody_chains.add(item['residue_first'][0])
77
+ if info is not None:
78
+ info['antibody_chains'] = list(antibody_chains)
79
+ info['structure'] = metadata['identifier']
80
+ info['method'] = method_name
81
+ return info
82
+ except (json.JSONDecodeError, FileNotFoundError) as e:
83
+ return None
84
+
85
+ def scan(self) -> List[EvalTask]:
86
+ tasks = []
87
+ if self.postfix is None or not self.postfix:
88
+ input_fname_pattern = '^\d+\.pdb$'
89
+ ref_fname = 'REF1.pdb'
90
+ else:
91
+ input_fname_pattern = f'^\d+\_{self.postfix}\.pdb$'
92
+ ref_fname = f'REF1_{self.postfix}.pdb'
93
+ for parent, _, files in os.walk(self.root):
94
+ for fname in files:
95
+ fpath = os.path.join(parent, fname)
96
+ if not re.match(input_fname_pattern, fname):
97
+ continue
98
+ if os.path.getsize(fpath) == 0:
99
+ continue
100
+ if fpath in self.visited:
101
+ continue
102
+
103
+ # Path to the reference structure
104
+ ref_path = os.path.join(parent, ref_fname)
105
+ if not os.path.exists(ref_path):
106
+ continue
107
+
108
+ # CDR information
109
+ info = self._get_metadata(fpath)
110
+ if info is None:
111
+ continue
112
+ tasks.append(EvalTask(
113
+ in_path = fpath,
114
+ ref_path = ref_path,
115
+ info = info,
116
+ structure = info['structure'],
117
+ name = info['name'],
118
+ method = info['method'],
119
+ cdr = info['tag'],
120
+ ab_chains = info['antibody_chains'],
121
+ residue_first = info.get('residue_first', None),
122
+ residue_last = info.get('residue_last', None),
123
+ ))
124
+ self.visited.add(fpath)
125
+ return tasks
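A note on the directory layout implied by `_get_metadata` and `scan`: the method name is taken three levels above each sample PDB, `metadata.json` sits two levels above it, and the tag is the immediate parent directory. A hypothetical tree (directory names are illustrative, not prescribed by the code):

    results/                      # --root
      codesign_single/            # method (3 levels up from the PDB)
        1ic7_.../                 # one structure; contains metadata.json
          metadata.json
          H_CDR3/                 # tag directory
            0000.pdb              # generated sample, matched by ^\d+\.pdb$
            REF1.pdb              # reference structure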
diffab/tools/eval/energy.py ADDED
@@ -0,0 +1,43 @@
+ # pyright: reportMissingImports=false
+ import pyrosetta
+ from pyrosetta.rosetta.protocols.analysis import InterfaceAnalyzerMover
+ pyrosetta.init(' '.join([
+     '-mute', 'all',
+     '-use_input_sc',
+     '-ignore_unrecognized_res',
+     '-ignore_zero_occupancy', 'false',
+     '-load_PDB_components', 'false',
+     '-relax:default_repeats', '2',
+     '-no_fconfig',
+ ]))
+
+ from diffab.tools.eval.base import EvalTask
+
+
+ def pyrosetta_interface_energy(pdb_path, interface):
+     pose = pyrosetta.pose_from_pdb(pdb_path)
+     mover = InterfaceAnalyzerMover(interface)
+     mover.set_pack_separated(True)
+     mover.apply(pose)
+     return pose.scores['dG_separated']
+
+
+ def eval_interface_energy(task: EvalTask):
+     model_gen = task.get_gen_biopython_model()
+     antigen_chains = set()
+     for chain in model_gen:
+         if chain.id not in task.ab_chains:
+             antigen_chains.add(chain.id)
+     antigen_chains = ''.join(list(antigen_chains))
+     antibody_chains = ''.join(task.ab_chains)
+     interface = f"{antibody_chains}_{antigen_chains}"
+
+     dG_gen = pyrosetta_interface_energy(task.in_path, interface)
+     dG_ref = pyrosetta_interface_energy(task.ref_path, interface)
+
+     task.scores.update({
+         'dG_gen': dG_gen,
+         'dG_ref': dG_ref,
+         'ddG': dG_gen - dG_ref
+     })
+     return task
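The `interface` string follows the `InterfaceAnalyzerMover` convention: chains on one side of the underscore versus chains on the other. A minimal sketch, assuming PyRosetta is installed and a local `complex.pdb` exists (both are assumptions here, not part of the code above):

    # e.g. heavy and light chains H, L against antigen chain A
    dG = pyrosetta_interface_energy('complex.pdb', 'HL_A')
    print(dG)  # dG_separated, in Rosetta energy units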
diffab/tools/eval/run.py ADDED
@@ -0,0 +1,66 @@
+ import os
+ import argparse
+ import ray
+ import shelve
+ import time
+ import pandas as pd
+ from typing import Mapping
+
+ from diffab.tools.eval.base import EvalTask, TaskScanner
+ from diffab.tools.eval.similarity import eval_similarity
+ from diffab.tools.eval.energy import eval_interface_energy
+
+
+ @ray.remote(num_cpus=1)
+ def evaluate(task, args):
+     funcs = []
+     funcs.append(eval_similarity)
+     if not args.no_energy:
+         funcs.append(eval_interface_energy)
+     for f in funcs:
+         task = f(task)
+     return task
+
+
+ def dump_db(db: Mapping[str, EvalTask], path):
+     table = []
+     for task in db.values():
+         if 'abopt' in path and task.scores['seqid'] >= 100.0:
+             # In abopt (antibody optimization) mode, ignore sequences identical to the wild-type
+             continue
+         table.append(task.to_report_dict())
+     table = pd.DataFrame(table)
+     table.to_csv(path, index=False, float_format='%.6f')
+     return table
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--root', type=str, default='./results')
+     parser.add_argument('--pfx', type=str, default='rosetta')
+     parser.add_argument('--no_energy', action='store_true', default=False)
+     args = parser.parse_args()
+     ray.init()
+
+     db_path = os.path.join(args.root, 'evaluation_db')
+     with shelve.open(db_path) as db:
+         scanner = TaskScanner(root=args.root, postfix=args.pfx, db=db)
+
+         while True:
+             tasks = scanner.scan()
+             futures = [evaluate.remote(t, args) for t in tasks]
+             if len(futures) > 0:
+                 print(f'Submitted {len(futures)} tasks.')
+             while len(futures) > 0:
+                 done_ids, futures = ray.wait(futures, num_returns=1)
+                 for done_id in done_ids:
+                     done_task = ray.get(done_id)
+                     done_task.save_to_db(db)
+                     print(f'Remaining {len(futures)}. Finished {done_task.in_path}')
+                 db.sync()
+
+             dump_db(db, os.path.join(args.root, 'summary.csv'))
+             time.sleep(1.0)
+
+ if __name__ == '__main__':
+     main()
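Because `diffab/tools/eval/__main__.py` forwards to `run.main`, the evaluator can be launched as a module; the flags map directly to the argparse definitions above:

    python -m diffab.tools.eval --root ./results --pfx rosetta
    python -m diffab.tools.eval --root ./results --no_energy   # skip the PyRosetta ddG step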
diffab/tools/eval/similarity.py ADDED
@@ -0,0 +1,125 @@
+ import numpy as np
+ from Bio.PDB import PDBParser, Selection
+ from Bio.PDB.Polypeptide import three_to_one
+ from Bio import pairwise2
+ from Bio.Align import substitution_matrices
+
+ from diffab.tools.eval.base import EvalTask
+
+
+ def reslist_rmsd(res_list1, res_list2):
+     res_short, res_long = (res_list1, res_list2) if len(res_list1) < len(res_list2) else (res_list2, res_list1)
+     M, N = len(res_short), len(res_long)
+
+     def d(i, j):
+         coord_i = np.array(res_short[i]['CA'].get_coord())
+         coord_j = np.array(res_long[j]['CA'].get_coord())
+         return ((coord_i - coord_j) ** 2).sum()
+
+     SD = np.full([M, N], np.inf)
+     for i in range(M):
+         j = N - (M - i)
+         SD[i, j] = sum([ d(i+k, j+k) for k in range(N-j) ])
+
+     for j in range(N):
+         SD[M-1, j] = d(M-1, j)
+
+     for i in range(M-2, -1, -1):
+         for j in range((N-(M-i))-1, -1, -1):
+             SD[i, j] = min(
+                 d(i, j) + SD[i+1, j+1],
+                 SD[i, j+1]
+             )
+
+     min_SD = SD[0, :N-M+1].min()
+     best_RMSD = np.sqrt(min_SD / M)
+     return best_RMSD
+
+
+ def entity_to_seq(entity):
+     seq = ''
+     mapping = []
+     for res in Selection.unfold_entities(entity, 'R'):
+         try:
+             seq += three_to_one(res.get_resname())
+             mapping.append(res.get_id())
+         except KeyError:
+             pass
+     assert len(seq) == len(mapping)
+     return seq, mapping
+
+
+ def reslist_seqid(res_list1, res_list2):
+     seq1, _ = entity_to_seq(res_list1)
+     seq2, _ = entity_to_seq(res_list2)
+     _, seq_id = align_sequences(seq1, seq2)
+     return seq_id
+
+
+ def align_sequences(sequence_A, sequence_B, **kwargs):
+     """
+     Performs a global pairwise alignment between two sequences
+     using the BLOSUM62 matrix and the Needleman-Wunsch algorithm
+     as implemented in Biopython. Returns the alignment, the sequence
+     identity and the residue mapping between both original sequences.
+     """
+
+     def _calculate_identity(sequenceA, sequenceB):
+         """
+         Returns the percentage of identical characters between two sequences.
+         Assumes the sequences are aligned.
+         """
+         sa, sb, sl = sequenceA, sequenceB, len(sequenceA)
+         matches = [sa[i] == sb[i] for i in range(sl)]
+         seq_id = (100 * sum(matches)) / sl
+         return seq_id
+
+         # gapless_sl = sum([1 for i in range(sl) if (sa[i] != '-' and sb[i] != '-')])
+         # gap_id = (100 * sum(matches)) / gapless_sl
+         # return (seq_id, gap_id)
+
+     matrix = kwargs.get('matrix', substitution_matrices.load("BLOSUM62"))
+     gap_open = kwargs.get('gap_open', -10.0)
+     gap_extend = kwargs.get('gap_extend', -0.5)
+
+     alns = pairwise2.align.globalds(sequence_A, sequence_B,
+                                     matrix, gap_open, gap_extend,
+                                     penalize_end_gaps=(False, False))
+
+     best_aln = alns[0]
+     aligned_A, aligned_B, score, begin, end = best_aln
+
+     # Calculate sequence identity
+     seq_id = _calculate_identity(aligned_A, aligned_B)
+     return (aligned_A, aligned_B), seq_id
+
+
+ def extract_reslist(model, residue_first, residue_last):
+     assert residue_first[0] == residue_last[0]
+     residue_first, residue_last = tuple(residue_first), tuple(residue_last)
+
+     chain_id = residue_first[0]
+     pos_first, pos_last = residue_first[1:], residue_last[1:]
+     chain = model[chain_id]
+     reslist = []
+     for res in Selection.unfold_entities(chain, 'R'):
+         pos_current = (res.id[1], res.id[2])
+         if pos_first <= pos_current <= pos_last:
+             reslist.append(res)
+     return reslist
+
+
+ def eval_similarity(task: EvalTask):
+     model_gen = task.get_gen_biopython_model()
+     model_ref = task.get_ref_biopython_model()
+
+     reslist_gen = extract_reslist(model_gen, task.residue_first, task.residue_last)
+     reslist_ref = extract_reslist(model_ref, task.residue_first, task.residue_last)
+
+     task.scores.update({
+         'rmsd': reslist_rmsd(reslist_gen, reslist_ref),
+         'seqid': reslist_seqid(reslist_gen, reslist_ref),
+     })
+     return task
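A quick sanity check of `align_sequences`, runnable with only Biopython installed (toy sequences; the exact value depends on the alignment Biopython returns):

    (aln_a, aln_b), seq_id = align_sequences('EVQLVESGGG', 'EVQLQESGGG')
    print(seq_id)  # one mismatch over ten aligned positions -> 90.0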
diffab/tools/relax/__main__.py ADDED
@@ -0,0 +1,4 @@
+ from .run import main
+
+ if __name__ == '__main__':
+     main()
diffab/tools/relax/base.py ADDED
@@ -0,0 +1,117 @@
+ import os
+ import re
+ import json
+ from typing import Optional, Tuple, List
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class RelaxTask:
+     in_path: str
+     current_path: str
+     info: dict
+     status: str
+
+     flexible_residue_first: Optional[Tuple] = None
+     flexible_residue_last: Optional[Tuple] = None
+
+     def get_in_path_with_tag(self, tag):
+         name, ext = os.path.splitext(self.in_path)
+         new_path = f'{name}_{tag}{ext}'
+         return new_path
+
+     def set_current_path_tag(self, tag):
+         new_path = self.get_in_path_with_tag(tag)
+         self.current_path = new_path
+         return new_path
+
+     def check_current_path_exists(self):
+         ok = os.path.exists(self.current_path)
+         if not ok:
+             self.mark_failure()
+         elif os.path.getsize(self.current_path) == 0:
+             # Treat empty files as failures and remove them
+             ok = False
+             self.mark_failure()
+             os.unlink(self.current_path)
+         return ok
+
+     def update_if_finished(self, tag):
+         out_path = self.get_in_path_with_tag(tag)
+         if os.path.exists(out_path) and os.path.getsize(out_path) > 0:
+             # Already finished
+             self.set_current_path_tag(tag)
+             self.mark_success()
+             return True
+         return False
+
+     def can_proceed(self):
+         self.check_current_path_exists()
+         return self.status != 'failed'
+
+     def mark_success(self):
+         self.status = 'success'
+
+     def mark_failure(self):
+         self.status = 'failed'
+
+
+ class TaskScanner:
+
+     def __init__(self, root, final_postfix=None):
+         super().__init__()
+         self.root = root
+         self.visited = set()
+         self.final_postfix = final_postfix
+
+     def _get_metadata(self, fpath):
+         json_path = os.path.join(
+             os.path.dirname(os.path.dirname(fpath)),
+             'metadata.json'
+         )
+         tag_name = os.path.basename(os.path.dirname(fpath))
+         try:
+             with open(json_path, 'r') as f:
+                 metadata = json.load(f)
+             for item in metadata['items']:
+                 if item['tag'] == tag_name:
+                     return item
+         except (json.JSONDecodeError, FileNotFoundError):
+             return None
+         return None
+
+     def scan(self) -> List[RelaxTask]:
+         tasks = []
+         input_fname_pattern = r'(^\d+\.pdb$|^REF\d\.pdb$)'
+         for parent, _, files in os.walk(self.root):
+             for fname in files:
+                 fpath = os.path.join(parent, fname)
+                 if not re.match(input_fname_pattern, fname):
+                     continue
+                 if os.path.getsize(fpath) == 0:
+                     continue
+                 if fpath in self.visited:
+                     continue
+
+                 # If finished
+                 if self.final_postfix is not None:
+                     fpath_name, fpath_ext = os.path.splitext(fpath)
+                     fpath_final = f"{fpath_name}_{self.final_postfix}{fpath_ext}"
+                     if os.path.exists(fpath_final):
+                         continue
+
+                 # Get metadata
+                 info = self._get_metadata(fpath)
+                 if info is None:
+                     continue
+
+                 tasks.append(RelaxTask(
+                     in_path = fpath,
+                     current_path = fpath,
+                     info = info,
+                     status = 'created',
+                     flexible_residue_first = info.get('residue_first', None),
+                     flexible_residue_last = info.get('residue_last', None),
+                 ))
+                 self.visited.add(fpath)
+         return tasks
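Note that `set_current_path_tag` always derives the new path from `in_path`, so tags replace rather than accumulate. A small sketch with hypothetical paths:

    task = RelaxTask(in_path='results/H_CDR3/0000.pdb',
                     current_path='results/H_CDR3/0000.pdb',
                     info={}, status='created')
    task.set_current_path_tag('openmm')    # results/H_CDR3/0000_openmm.pdb
    task.set_current_path_tag('rosetta')   # results/H_CDR3/0000_rosetta.pdb, not ..._openmm_rosetta.pdb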
diffab/tools/relax/openmm_relaxer.py ADDED
@@ -0,0 +1,144 @@
+ import os
+ import time
+ import io
+ import logging
+ import pdbfixer
+ import openmm
+ from openmm import app as openmm_app
+ from openmm import unit
+ ENERGY = unit.kilocalories_per_mole
+ LENGTH = unit.angstroms
+
+ from diffab.tools.relax.base import RelaxTask
+
+
+ def current_milli_time():
+     return round(time.time() * 1000)
+
+
+ def _is_in_the_range(ch_rs_ic, flexible_residue_first, flexible_residue_last):
+     if ch_rs_ic[0] != flexible_residue_first[0]: return False
+     r_first, r_last = tuple(flexible_residue_first[1:]), tuple(flexible_residue_last[1:])
+     rs_ic = ch_rs_ic[1:]
+     return r_first <= rs_ic <= r_last
+
+
+ class ForceFieldMinimizer(object):
+
+     def __init__(self, stiffness=10.0, max_iterations=0, tolerance=2.39*unit.kilocalories_per_mole, platform='CUDA'):
+         super().__init__()
+         self.stiffness = stiffness
+         self.max_iterations = max_iterations
+         self.tolerance = tolerance
+         assert platform in ('CUDA', 'CPU')
+         self.platform = platform
+
+     def _fix(self, pdb_str):
+         fixer = pdbfixer.PDBFixer(pdbfile=io.StringIO(pdb_str))
+         fixer.findNonstandardResidues()
+         fixer.replaceNonstandardResidues()
+
+         fixer.findMissingResidues()
+         fixer.findMissingAtoms()
+         fixer.addMissingAtoms(seed=0)
+         fixer.addMissingHydrogens()
+
+         out_handle = io.StringIO()
+         openmm_app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle, keepIds=True)
+         return out_handle.getvalue()
+
+     def _get_pdb_string(self, topology, positions):
+         with io.StringIO() as f:
+             openmm_app.PDBFile.writeFile(topology, positions, f, keepIds=True)
+             return f.getvalue()
+
+     def _minimize(self, pdb_str, flexible_residue_first=None, flexible_residue_last=None):
+         pdb = openmm_app.PDBFile(io.StringIO(pdb_str))
+
+         force_field = openmm_app.ForceField("amber99sb.xml")
+         constraints = openmm_app.HBonds
+         system = force_field.createSystem(pdb.topology, constraints=constraints)
+
+         # Add harmonic position restraints to non-generated regions
+         force = openmm.CustomExternalForce("0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)")
+         force.addGlobalParameter("k", self.stiffness)
+         for p in ["x0", "y0", "z0"]:
+             force.addPerParticleParameter(p)
+
+         if flexible_residue_first is not None and flexible_residue_last is not None:
+             for i, a in enumerate(pdb.topology.atoms()):
+                 ch_rs_ic = (a.residue.chain.id, int(a.residue.id), a.residue.insertionCode)
+                 if not _is_in_the_range(ch_rs_ic, flexible_residue_first, flexible_residue_last) and a.element.name != "hydrogen":
+                     force.addParticle(i, pdb.positions[i])
+
+         system.addForce(force)
+
+         # Set up the integrator and simulation
+         integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
+         platform = openmm.Platform.getPlatformByName(self.platform)
+         simulation = openmm_app.Simulation(pdb.topology, system, integrator, platform)
+         simulation.context.setPositions(pdb.positions)
+
+         # Perform minimization
+         ret = {}
+         state = simulation.context.getState(getEnergy=True, getPositions=True)
+         ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY)
+         ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
+
+         simulation.minimizeEnergy(maxIterations=self.max_iterations, tolerance=self.tolerance)
+
+         state = simulation.context.getState(getEnergy=True, getPositions=True)
+         ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY)
+         ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
+         ret["min_pdb"] = self._get_pdb_string(simulation.topology, state.getPositions())
+
+         return ret['min_pdb'], ret
+
+     def _add_energy_remarks(self, pdb_str, ret):
+         pdb_lines = pdb_str.splitlines()
+         pdb_lines.insert(1, "REMARK   1  FINAL ENERGY:   {:.3f} KCAL/MOL".format(ret['efinal']))
+         pdb_lines.insert(1, "REMARK   1  INITIAL ENERGY: {:.3f} KCAL/MOL".format(ret['einit']))
+         return "\n".join(pdb_lines)
+
+     def __call__(self, pdb_str, flexible_residue_first=None, flexible_residue_last=None, return_info=True):
+         # Accept either a PDB string or a path to a .pdb file
+         if '\n' not in pdb_str and pdb_str.lower().endswith(".pdb"):
+             with open(pdb_str) as f:
+                 pdb_str = f.read()
+
+         pdb_fixed = self._fix(pdb_str)
+         pdb_min, ret = self._minimize(pdb_fixed, flexible_residue_first, flexible_residue_last)
+         pdb_min = self._add_energy_remarks(pdb_min, ret)
+         if return_info:
+             return pdb_min, ret
+         else:
+             return pdb_min
+
+
+ def run_openmm(task: RelaxTask):
+     if not task.can_proceed():
+         return task
+     if task.update_if_finished('openmm'):
+         return task
+
+     try:
+         minimizer = ForceFieldMinimizer()
+         with open(task.current_path, 'r') as f:
+             pdb_str = f.read()
+
+         pdb_min = minimizer(
+             pdb_str = pdb_str,
+             flexible_residue_first = task.flexible_residue_first,
+             flexible_residue_last = task.flexible_residue_last,
+             return_info = False,
+         )
+         out_path = task.set_current_path_tag('openmm')
+         with open(out_path, 'w') as f:
+             f.write(pdb_min)
+         task.mark_success()
+     except ValueError as e:
+         logging.warning(
+             f'{e.__class__.__name__}: {str(e)} ({task.current_path})'
+         )
+         task.mark_failure()
+     return task
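Minimal usage sketch for `ForceFieldMinimizer` (assumes OpenMM with a CUDA device, pdbfixer, and a local `input.pdb`; the residue tuples are `(chain_id, resseq, icode)`):

    minimizer = ForceFieldMinimizer(platform='CUDA')
    pdb_min, info = minimizer(
        'input.pdb',
        flexible_residue_first=('H', 95, ' '),
        flexible_residue_last=('H', 102, ' '),
    )
    print(info['einit'], info['efinal'])  # potential energy before/after, in kcal/mol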
diffab/tools/relax/pyrosetta_relaxer.py ADDED
@@ -0,0 +1,189 @@
+ # pyright: reportMissingImports=false
+ import os
+ import time
+ import pyrosetta
+ from pyrosetta.rosetta.protocols.relax import FastRelax
+ from pyrosetta.rosetta.core.pack.task import TaskFactory
+ from pyrosetta.rosetta.core.pack.task import operation
+ from pyrosetta.rosetta.core.select import residue_selector as selections
+ from pyrosetta.rosetta.core.select.movemap import MoveMapFactory, move_map_action
+ pyrosetta.init(' '.join([
+     '-mute', 'all',
+     '-use_input_sc',
+     '-ignore_unrecognized_res',
+     '-ignore_zero_occupancy', 'false',
+     '-load_PDB_components', 'false',
+     '-relax:default_repeats', '2',
+     '-no_fconfig',
+ ]))
+
+ from diffab.tools.relax.base import RelaxTask
+
+
+ def current_milli_time():
+     return round(time.time() * 1000)
+
+
+ def parse_residue_position(p):
+     icode = None
+     if not p[-1].isnumeric():  # Has ICODE; strip it before parsing the residue number
+         icode = p[-1]
+         p = p[:-1]
+
+     for i, c in enumerate(p):
+         if c.isnumeric():
+             break
+     chain = p[:i]
+     resseq = int(p[i:])
+
+     if icode is not None:
+         return chain, resseq, icode
+     else:
+         return chain, resseq
+
+
+ def get_scorefxn(scorefxn_name: str):
+     """
+     Gets the scorefxn with appropriate corrections.
+     Taken from: https://gist.github.com/matteoferla/b33585f3aeab58b8424581279e032550
+     """
+     import pyrosetta
+
+     corrections = {
+         'beta_july15': False,
+         'beta_nov16': False,
+         'gen_potential': False,
+         'restore_talaris_behavior': False,
+     }
+     if 'beta_july15' in scorefxn_name or 'beta_nov15' in scorefxn_name:
+         # beta_july15 is ref2015
+         corrections['beta_july15'] = True
+     elif 'beta_nov16' in scorefxn_name:
+         corrections['beta_nov16'] = True
+     elif 'genpot' in scorefxn_name:
+         corrections['gen_potential'] = True
+         pyrosetta.rosetta.basic.options.set_boolean_option('corrections:beta_july15', True)
+     elif 'talaris' in scorefxn_name:  # 2013 and 2014
+         corrections['restore_talaris_behavior'] = True
+     else:
+         pass
+     for corr, value in corrections.items():
+         pyrosetta.rosetta.basic.options.set_boolean_option(f'corrections:{corr}', value)
+     return pyrosetta.create_score_function(scorefxn_name)
+
+
+ class RelaxRegion(object):
+
+     def __init__(self, scorefxn='ref2015', max_iter=1000, subset='nbrs', move_bb=True):
+         super().__init__()
+         self.scorefxn = get_scorefxn(scorefxn)
+         self.fast_relax = FastRelax()
+         self.fast_relax.set_scorefxn(self.scorefxn)
+         self.fast_relax.max_iter(max_iter)
+         assert subset in ('all', 'target', 'nbrs')
+         self.subset = subset
+         self.move_bb = move_bb
+
+     def __call__(self, pdb_path, flexible_residue_first, flexible_residue_last):
+         pose = pyrosetta.pose_from_pdb(pdb_path)
+         start_t = current_milli_time()
+         original_pose = pose.clone()
+
+         tf = TaskFactory()
+         tf.push_back(operation.InitializeFromCommandline())
+         tf.push_back(operation.RestrictToRepacking())  # Only allow residues to repack. No design at any position.
+
+         # Create selectors for the region to be relaxed.
+         # Turn off design and repacking at irrelevant positions.
+         if flexible_residue_first[-1] == ' ':
+             flexible_residue_first = flexible_residue_first[:-1]
+         if flexible_residue_last[-1] == ' ':
+             flexible_residue_last = flexible_residue_last[:-1]
+         if self.subset != 'all':
+             gen_selector = selections.ResidueIndexSelector()
+             gen_selector.set_index_range(
+                 pose.pdb_info().pdb2pose(*flexible_residue_first),
+                 pose.pdb_info().pdb2pose(*flexible_residue_last),
+             )
+             nbr_selector = selections.NeighborhoodResidueSelector()
+             nbr_selector.set_focus_selector(gen_selector)
+             nbr_selector.set_include_focus_in_subset(True)
+
+             if self.subset == 'nbrs':
+                 subset_selector = nbr_selector
+             elif self.subset == 'target':
+                 subset_selector = gen_selector
+
+             prevent_repacking_rlt = operation.PreventRepackingRLT()
+             prevent_subset_repacking = operation.OperateOnResidueSubset(
+                 prevent_repacking_rlt,
+                 subset_selector,
+                 flip_subset=True,
+             )
+             tf.push_back(prevent_subset_repacking)
+
+         scorefxn = self.scorefxn
+         fr = self.fast_relax
+
+         pose = original_pose.clone()
+         pos_list = pyrosetta.rosetta.utility.vector1_unsigned_long()
+         for pos in range(pose.pdb_info().pdb2pose(*flexible_residue_first), pose.pdb_info().pdb2pose(*flexible_residue_last)+1):
+             pos_list.append(pos)
+         # basic_idealize(pose, pos_list, scorefxn, fast=True)
+
+         mmf = MoveMapFactory()
+         if self.move_bb:
+             mmf.add_bb_action(move_map_action.mm_enable, gen_selector)
+         mmf.add_chi_action(move_map_action.mm_enable, subset_selector)
+         mm = mmf.create_movemap_from_pose(pose)
+
+         fr.set_movemap(mm)
+         fr.set_task_factory(tf)
+         fr.apply(pose)
+
+         e_before = scorefxn(original_pose)
+         e_relax = scorefxn(pose)
+         # print('\n\n[Finished in %.2f secs]' % ((current_milli_time() - start_t) / 1000))
+         # print(' > Energy (before): %.4f' % scorefxn(original_pose))
+         # print(' > Energy (optimized): %.4f' % scorefxn(pose))
+         return pose, e_before, e_relax
+
+
+ def run_pyrosetta(task: RelaxTask):
+     if not task.can_proceed():
+         return task
+     if task.update_if_finished('rosetta'):
+         return task
+
+     minimizer = RelaxRegion()
+     pose_min, _, _ = minimizer(
+         pdb_path = task.current_path,
+         flexible_residue_first = task.flexible_residue_first,
+         flexible_residue_last = task.flexible_residue_last,
+     )
+
+     out_path = task.set_current_path_tag('rosetta')
+     pose_min.dump_pdb(out_path)
+     task.mark_success()
+     return task
+
+
+ def run_pyrosetta_fixbb(task: RelaxTask):
+     if not task.can_proceed():
+         return task
+     if task.update_if_finished('fixbb'):
+         return task
+
+     minimizer = RelaxRegion(move_bb=False)
+     pose_min, _, _ = minimizer(
+         pdb_path = task.current_path,
+         flexible_residue_first = task.flexible_residue_first,
+         flexible_residue_last = task.flexible_residue_last,
+     )
+
+     out_path = task.set_current_path_tag('fixbb')
+     pose_min.dump_pdb(out_path)
+     task.mark_success()
+     return task
diffab/tools/relax/run.py ADDED
@@ -0,0 +1,85 @@
+ import argparse
+ import ray
+ import time
+
+ from diffab.tools.relax.openmm_relaxer import run_openmm
+ from diffab.tools.relax.pyrosetta_relaxer import run_pyrosetta, run_pyrosetta_fixbb
+ from diffab.tools.relax.base import TaskScanner
+
+
+ @ray.remote(num_gpus=1/8, num_cpus=1)
+ def run_openmm_remote(task):
+     return run_openmm(task)
+
+
+ @ray.remote(num_cpus=1)
+ def run_pyrosetta_remote(task):
+     return run_pyrosetta(task)
+
+
+ @ray.remote(num_cpus=1)
+ def run_pyrosetta_fixbb_remote(task):
+     return run_pyrosetta_fixbb(task)
+
+
+ @ray.remote
+ def pipeline_openmm_pyrosetta(task):
+     funcs = [
+         run_openmm_remote,
+         run_pyrosetta_remote,
+     ]
+     for fn in funcs:
+         task = fn.remote(task)
+     return ray.get(task)
+
+
+ @ray.remote
+ def pipeline_pyrosetta(task):
+     funcs = [
+         run_pyrosetta_remote,
+     ]
+     for fn in funcs:
+         task = fn.remote(task)
+     return ray.get(task)
+
+
+ @ray.remote
+ def pipeline_pyrosetta_fixbb(task):
+     funcs = [
+         run_pyrosetta_fixbb_remote,
+     ]
+     for fn in funcs:
+         task = fn.remote(task)
+     return ray.get(task)
+
+
+ pipeline_dict = {
+     'openmm_pyrosetta': pipeline_openmm_pyrosetta,
+     'pyrosetta': pipeline_pyrosetta,
+     'pyrosetta_fixbb': pipeline_pyrosetta_fixbb,
+ }
+
+
+ def main():
+     ray.init()
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--root', type=str, default='./results')
+     parser.add_argument('--pipeline', type=lambda s: pipeline_dict[s], default=pipeline_openmm_pyrosetta)
+     args = parser.parse_args()
+
+     final_pfx = 'fixbb' if args.pipeline == pipeline_pyrosetta_fixbb else 'rosetta'
+     scanner = TaskScanner(args.root, final_postfix=final_pfx)
+     while True:
+         tasks = scanner.scan()
+         futures = [args.pipeline.remote(t) for t in tasks]
+         if len(futures) > 0:
+             print(f'Submitted {len(futures)} tasks.')
+         while len(futures) > 0:
+             done_ids, futures = ray.wait(futures, num_returns=1)
+             for done_id in done_ids:
+                 done_task = ray.get(done_id)
+                 print(f'Remaining {len(futures)}. Finished {done_task.current_path}')
+         time.sleep(1.0)
+
+ if __name__ == '__main__':
+     main()
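As with the evaluator, the relaxation workers run as a module; `--pipeline` selects one of the keys of `pipeline_dict`:

    python -m diffab.tools.relax --root ./results --pipeline openmm_pyrosetta
    python -m diffab.tools.relax --root ./results --pipeline pyrosetta_fixbb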
diffab/tools/renumber/__init__.py ADDED
@@ -0,0 +1 @@
+ from .run import renumber
diffab/tools/renumber/__main__.py ADDED
@@ -0,0 +1,5 @@
+ from .run import main
+
+ if __name__ == '__main__':
+     main()
diffab/tools/renumber/run.py ADDED
@@ -0,0 +1,85 @@
+ import argparse
+ import abnumber
+ from Bio import PDB
+ from Bio.PDB import Model, Chain, Residue, Selection
+ from Bio.Data import SCOPData
+ from typing import List, Tuple
+
+
+ def biopython_chain_to_sequence(chain: Chain.Chain):
+     residue_list = Selection.unfold_entities(chain, 'R')
+     seq = ''.join([SCOPData.protein_letters_3to1.get(r.resname, 'X') for r in residue_list])
+     return seq, residue_list
+
+
+ def assign_number_to_sequence(seq):
+     abchain = abnumber.Chain(seq, scheme='chothia')
+     offset = seq.index(abchain.seq)
+     if not (offset >= 0):
+         raise ValueError(
+             'The identified Fv sequence is not a subsequence of the original sequence.'
+         )
+
+     numbers = [None for _ in range(len(seq))]
+     for i, (pos, aa) in enumerate(abchain):
+         resseq = pos.number
+         icode = pos.letter if pos.letter else ' '
+         numbers[i+offset] = (resseq, icode)
+     return numbers, abchain
+
+
+ def renumber_biopython_chain(chain_id, residue_list: List[Residue.Residue], numbers: List[Tuple[int, str]]):
+     chain = Chain.Chain(chain_id)
+     for residue, number in zip(residue_list, numbers):
+         if number is None:
+             continue
+         residue = residue.copy()
+         new_id = (residue.id[0], number[0], number[1])
+         residue.id = new_id
+         chain.add(residue)
+     return chain
+
+
+ def renumber(in_pdb, out_pdb, return_other_chains=False):
+     parser = PDB.PDBParser(QUIET=True)
+     structure = parser.get_structure(None, in_pdb)
+     model = structure[0]
+     model_new = Model.Model(0)
+
+     heavy_chains, light_chains, other_chains = [], [], []
+
+     for chain in model:
+         try:
+             seq, reslist = biopython_chain_to_sequence(chain)
+             numbers, abchain = assign_number_to_sequence(seq)
+             chain_new = renumber_biopython_chain(chain.id, reslist, numbers)
+             print(f'[INFO] Renumbered chain {chain_new.id} ({abchain.chain_type})')
+             if abchain.chain_type == 'H':
+                 heavy_chains.append(chain_new.id)
+             elif abchain.chain_type in ('K', 'L'):
+                 light_chains.append(chain_new.id)
+         except abnumber.ChainParseError as e:
+             print(f'[INFO] Chain {chain.id} does not contain valid Fv: {str(e)}')
+             chain_new = chain.copy()
+             other_chains.append(chain_new.id)
+         model_new.add(chain_new)
+
+     pdb_io = PDB.PDBIO()
+     pdb_io.set_structure(model_new)
+     pdb_io.save(out_pdb)
+     if return_other_chains:
+         return heavy_chains, light_chains, other_chains
+     else:
+         return heavy_chains, light_chains
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('in_pdb', type=str)
+     parser.add_argument('out_pdb', type=str)
+     args = parser.parse_args()
+
+     renumber(args.in_pdb, args.out_pdb)
+
+ if __name__ == '__main__':
+     main()
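Command-line usage, via the module entry point defined in `__main__.py`:

    python -m diffab.tools.renumber input.pdb output_chothia.pdb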
diffab/tools/runner/design_for_pdb.py ADDED
@@ -0,0 +1,291 @@
+ import os
+ import argparse
+ import copy
+ import json
+ from tqdm.auto import tqdm
+ from torch.utils.data import DataLoader
+
+ from diffab.datasets.custom import preprocess_antibody_structure
+ from diffab.models import get_model
+ from diffab.modules.common.geometry import reconstruct_backbone_partially
+ from diffab.modules.common.so3 import so3vec_to_rotation
+ from diffab.utils.inference import RemoveNative
+ from diffab.utils.protein.writers import save_pdb
+ from diffab.utils.train import recursive_to
+ from diffab.utils.misc import *
+ from diffab.utils.data import *
+ from diffab.utils.transforms import *
+ from diffab.utils.inference import *
+ from diffab.tools.renumber import renumber as renumber_antibody
+
+
+ def create_data_variants(config, structure_factory):
+     structure = structure_factory()
+     structure_id = structure['id']
+
+     data_variants = []
+     if config.mode == 'single_cdr':
+         cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs)))
+         for cdr_name in cdrs:
+             transform = Compose([
+                 MaskSingleCDR(cdr_name, augmentation=False),
+                 MergeChains(),
+             ])
+             data_var = transform(structure_factory())
+             residue_first, residue_last = get_residue_first_last(data_var)
+             data_variants.append({
+                 'data': data_var,
+                 'name': f'{structure_id}-{cdr_name}',
+                 'tag': f'{cdr_name}',
+                 'cdr': cdr_name,
+                 'residue_first': residue_first,
+                 'residue_last': residue_last,
+             })
+     elif config.mode == 'multiple_cdrs':
+         cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs)))
+         transform = Compose([
+             MaskMultipleCDRs(selection=cdrs, augmentation=False),
+             MergeChains(),
+         ])
+         data_var = transform(structure_factory())
+         data_variants.append({
+             'data': data_var,
+             'name': f'{structure_id}-MultipleCDRs',
+             'tag': 'MultipleCDRs',
+             'cdrs': cdrs,
+             'residue_first': None,
+             'residue_last': None,
+         })
+     elif config.mode == 'full':
+         transform = Compose([
+             MaskAntibody(),
+             MergeChains(),
+         ])
+         data_var = transform(structure_factory())
+         data_variants.append({
+             'data': data_var,
+             'name': f'{structure_id}-Full',
+             'tag': 'Full',
+             'residue_first': None,
+             'residue_last': None,
+         })
+     elif config.mode == 'abopt':
+         cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs)))
+         for cdr_name in cdrs:
+             transform = Compose([
+                 MaskSingleCDR(cdr_name, augmentation=False),
+                 MergeChains(),
+             ])
+             data_var = transform(structure_factory())
+             residue_first, residue_last = get_residue_first_last(data_var)
+             for opt_step in config.sampling.optimize_steps:
+                 data_variants.append({
+                     'data': data_var,
+                     'name': f'{structure_id}-{cdr_name}-O{opt_step}',
+                     'tag': f'{cdr_name}-O{opt_step}',
+                     'cdr': cdr_name,
+                     'opt_step': opt_step,
+                     'residue_first': residue_first,
+                     'residue_last': residue_last,
+                 })
+     else:
+         raise ValueError(f'Unknown mode: {config.mode}.')
+     return data_variants
+
+
+ def design_for_pdb(args):
+     # Load configs
+     config, config_name = load_config(args.config)
+     seed_all(args.seed if args.seed is not None else config.sampling.seed)
+
+     # Structure loading
+     data_id = os.path.basename(args.pdb_path)
+     if args.no_renumber:
+         pdb_path = args.pdb_path
+     else:
+         in_pdb_path = args.pdb_path
+         out_pdb_path = os.path.splitext(in_pdb_path)[0] + '_chothia.pdb'
+         heavy_chains, light_chains = renumber_antibody(in_pdb_path, out_pdb_path)
+         pdb_path = out_pdb_path
+
+         if args.heavy is None and len(heavy_chains) > 0:
+             args.heavy = heavy_chains[0]
+         if args.light is None and len(light_chains) > 0:
+             args.light = light_chains[0]
+     if args.heavy is None and args.light is None:
+         raise ValueError("Neither heavy chain id (--heavy) nor light chain id (--light) is specified.")
+     get_structure = lambda: preprocess_antibody_structure({
+         'id': data_id,
+         'pdb_path': pdb_path,
+         'heavy_id': args.heavy,
+         # If the input is a nanobody, the light chain will be ignored
+         'light_id': args.light,
+     })
+
+     # Logging
+     structure_ = get_structure()
+     structure_id = structure_['id']
+     tag_postfix = '_%s' % args.tag if args.tag else ''
+     log_dir = get_new_log_dir(
+         os.path.join(args.out_root, config_name + tag_postfix),
+         prefix=data_id
+     )
+     logger = get_logger('sample', log_dir)
+     logger.info(f'Data ID: {structure_["id"]}')
+     logger.info(f'Results will be saved to {log_dir}')
+     data_native = MergeChains()(structure_)
+     save_pdb(data_native, os.path.join(log_dir, 'reference.pdb'))
+
+     # Load checkpoint and model
+     logger.info('Loading model config and checkpoints: %s' % (config.model.checkpoint))
+     ckpt = torch.load(config.model.checkpoint, map_location='cpu')
+     cfg_ckpt = ckpt['config']
+     model = get_model(cfg_ckpt.model).to(args.device)
+     lsd = model.load_state_dict(ckpt['model'])
+     logger.info(str(lsd))
+
+     # Make data variants
+     data_variants = create_data_variants(
+         config = config,
+         structure_factory = get_structure,
+     )
+
+     # Save metadata
+     metadata = {
+         'identifier': structure_id,
+         'index': data_id,
+         'config': args.config,
+         'items': [{kk: vv for kk, vv in var.items() if kk != 'data'} for var in data_variants],
+     }
+     with open(os.path.join(log_dir, 'metadata.json'), 'w') as f:
+         json.dump(metadata, f, indent=2)
+
+     # Start sampling
+     collate_fn = PaddingCollate(eight=False)
+     inference_tfm = [ PatchAroundAnchor(), ]
+     if 'abopt' not in config.mode:  # Don't remove native CDR in optimization mode
+         inference_tfm.append(RemoveNative(
+             remove_structure = config.sampling.sample_structure,
+             remove_sequence = config.sampling.sample_sequence,
+         ))
+     inference_tfm = Compose(inference_tfm)
+
+     for variant in data_variants:
+         os.makedirs(os.path.join(log_dir, variant['tag']), exist_ok=True)
+         logger.info(f"Start sampling for: {variant['tag']}")
+
+         save_pdb(data_native, os.path.join(log_dir, variant['tag'], 'REF1.pdb'))  # w/ OpenMM minimization
+
+         data_cropped = inference_tfm(
+             copy.deepcopy(variant['data'])
+         )
+         data_list_repeat = [ data_cropped ] * config.sampling.num_samples
+         loader = DataLoader(data_list_repeat, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
+
+         count = 0
+         for batch in tqdm(loader, desc=variant['name'], dynamic_ncols=True):
+             torch.set_grad_enabled(False)
+             model.eval()
+             batch = recursive_to(batch, args.device)
+             if 'abopt' in config.mode:
+                 # Antibody optimization starting from native
+                 traj_batch = model.optimize(batch, opt_step=variant['opt_step'], optimize_opt={
+                     'pbar': True,
+                     'sample_structure': config.sampling.sample_structure,
+                     'sample_sequence': config.sampling.sample_sequence,
+                 })
+             else:
+                 # De novo design
+                 traj_batch = model.sample(batch, sample_opt={
+                     'pbar': True,
+                     'sample_structure': config.sampling.sample_structure,
+                     'sample_sequence': config.sampling.sample_sequence,
+                 })
+
+             aa_new = traj_batch[0][2]  # 0: Last sampling step. 2: Amino acid.
+             pos_atom_new, mask_atom_new = reconstruct_backbone_partially(
+                 pos_ctx = batch['pos_heavyatom'],
+                 R_new = so3vec_to_rotation(traj_batch[0][0]),
+                 t_new = traj_batch[0][1],
+                 aa = aa_new,
+                 chain_nb = batch['chain_nb'],
+                 res_nb = batch['res_nb'],
+                 mask_atoms = batch['mask_heavyatom'],
+                 mask_recons = batch['generate_flag'],
+             )
+             aa_new = aa_new.cpu()
+             pos_atom_new = pos_atom_new.cpu()
+             mask_atom_new = mask_atom_new.cpu()
+
+             for i in range(aa_new.size(0)):
+                 data_tmpl = variant['data']
+                 aa = apply_patch_to_tensor(data_tmpl['aa'], aa_new[i], data_cropped['patch_idx'])
+                 mask_ha = apply_patch_to_tensor(data_tmpl['mask_heavyatom'], mask_atom_new[i], data_cropped['patch_idx'])
+                 pos_ha = (
+                     apply_patch_to_tensor(
+                         data_tmpl['pos_heavyatom'],
+                         pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(),
+                         data_cropped['patch_idx']
+                     )
+                 )
+
+                 save_path = os.path.join(log_dir, variant['tag'], '%04d.pdb' % (count, ))
+                 save_pdb({
+                     'chain_nb': data_tmpl['chain_nb'],
+                     'chain_id': data_tmpl['chain_id'],
+                     'resseq': data_tmpl['resseq'],
+                     'icode': data_tmpl['icode'],
+                     # Generated
+                     'aa': aa,
+                     'mask_heavyatom': mask_ha,
+                     'pos_heavyatom': pos_ha,
+                 }, path=save_path)
+                 # save_pdb({
+                 #     'chain_nb': data_cropped['chain_nb'],
+                 #     'chain_id': data_cropped['chain_id'],
+                 #     'resseq': data_cropped['resseq'],
+                 #     'icode': data_cropped['icode'],
+                 #     # Generated
+                 #     'aa': aa_new[i],
+                 #     'mask_heavyatom': mask_atom_new[i],
+                 #     'pos_heavyatom': pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(),
+                 # }, path=os.path.join(log_dir, variant['tag'], '%04d_patch.pdb' % (count, )))
+                 count += 1
+
+         logger.info('Finished.\n')
+
+
+ def args_from_cmdline():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('pdb_path', type=str)
+     parser.add_argument('--heavy', type=str, default=None, help='Chain id of the heavy chain.')
+     parser.add_argument('--light', type=str, default=None, help='Chain id of the light chain.')
+     parser.add_argument('--no_renumber', action='store_true', default=False)
+     parser.add_argument('-c', '--config', type=str, default='./configs/test/codesign_single.yml')
+     parser.add_argument('-o', '--out_root', type=str, default='./results')
+     parser.add_argument('-t', '--tag', type=str, default='')
+     parser.add_argument('-s', '--seed', type=int, default=None)
+     parser.add_argument('-d', '--device', type=str, default='cuda')
+     parser.add_argument('-b', '--batch_size', type=int, default=16)
+     args = parser.parse_args()
+     return args
+
+
+ def args_factory(**kwargs):
+     default_args = EasyDict(
+         heavy = 'H',
+         light = 'L',
+         no_renumber = False,
+         config = './configs/test/codesign_single.yml',
+         out_root = './results',
+         tag = '',
+         seed = None,
+         device = 'cuda',
+         batch_size = 16
+     )
+     default_args.update(kwargs)
+     return default_args
+
+
+ if __name__ == '__main__':
+     design_for_pdb(args_from_cmdline())
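A typical invocation of this runner (the config path must point to an existing YAML config; `my_antibody.pdb` is a placeholder):

    python -m diffab.tools.runner.design_for_pdb my_antibody.pdb \
        --heavy H --light L -c ./configs/test/codesign_single.yml -o ./results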
diffab/tools/runner/design_for_testset.py ADDED
@@ -0,0 +1,243 @@
+ import os
+ import argparse
+ import copy
+ import json
+ from tqdm.auto import tqdm
+ from torch.utils.data import DataLoader
+
+ from diffab.datasets import get_dataset
+ from diffab.models import get_model
+ from diffab.modules.common.geometry import reconstruct_backbone_partially
+ from diffab.modules.common.so3 import so3vec_to_rotation
+ from diffab.utils.inference import RemoveNative
+ from diffab.utils.protein.writers import save_pdb
+ from diffab.utils.train import recursive_to
+ from diffab.utils.misc import *
+ from diffab.utils.data import *
+ from diffab.utils.transforms import *
+ from diffab.utils.inference import *
+
+
+ def create_data_variants(config, structure_factory):
+     structure = structure_factory()
+     structure_id = structure['id']
+
+     data_variants = []
+     if config.mode == 'single_cdr':
+         cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs)))
+         for cdr_name in cdrs:
+             transform = Compose([
+                 MaskSingleCDR(cdr_name, augmentation=False),
+                 MergeChains(),
+             ])
+             data_var = transform(structure_factory())
+             residue_first, residue_last = get_residue_first_last(data_var)
+             data_variants.append({
+                 'data': data_var,
+                 'name': f'{structure_id}-{cdr_name}',
+                 'tag': f'{cdr_name}',
+                 'cdr': cdr_name,
+                 'residue_first': residue_first,
+                 'residue_last': residue_last,
+             })
+     elif config.mode == 'multiple_cdrs':
+         cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs)))
+         transform = Compose([
+             MaskMultipleCDRs(selection=cdrs, augmentation=False),
+             MergeChains(),
+         ])
+         data_var = transform(structure_factory())
+         data_variants.append({
+             'data': data_var,
+             'name': f'{structure_id}-MultipleCDRs',
+             'tag': 'MultipleCDRs',
+             'cdrs': cdrs,
+             'residue_first': None,
+             'residue_last': None,
+         })
+     elif config.mode == 'full':
+         transform = Compose([
+             MaskAntibody(),
+             MergeChains(),
+         ])
+         data_var = transform(structure_factory())
+         data_variants.append({
+             'data': data_var,
+             'name': f'{structure_id}-Full',
+             'tag': 'Full',
+             'residue_first': None,
+             'residue_last': None,
+         })
+     elif config.mode == 'abopt':
+         cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs)))
+         for cdr_name in cdrs:
+             transform = Compose([
+                 MaskSingleCDR(cdr_name, augmentation=False),
+                 MergeChains(),
+             ])
+             data_var = transform(structure_factory())
+             residue_first, residue_last = get_residue_first_last(data_var)
+             for opt_step in config.sampling.optimize_steps:
+                 data_variants.append({
+                     'data': data_var,
+                     'name': f'{structure_id}-{cdr_name}-O{opt_step}',
+                     'tag': f'{cdr_name}-O{opt_step}',
+                     'cdr': cdr_name,
+                     'opt_step': opt_step,
+                     'residue_first': residue_first,
+                     'residue_last': residue_last,
+                 })
+     else:
+         raise ValueError(f'Unknown mode: {config.mode}.')
+     return data_variants
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('index', type=int)
+     parser.add_argument('-c', '--config', type=str, default='./configs/test/codesign_single.yml')
+     parser.add_argument('-o', '--out_root', type=str, default='./results')
+     parser.add_argument('-t', '--tag', type=str, default='')
+     parser.add_argument('-s', '--seed', type=int, default=None)
+     parser.add_argument('-d', '--device', type=str, default='cuda')
+     parser.add_argument('-b', '--batch_size', type=int, default=16)
+     args = parser.parse_args()
+
+     # Load configs
+     config, config_name = load_config(args.config)
+     seed_all(args.seed if args.seed is not None else config.sampling.seed)
+
+     # Testset
+     dataset = get_dataset(config.dataset.test)
+     get_structure = lambda: dataset[args.index]
+
+     # Logging
+     structure_ = get_structure()
+     structure_id = structure_['id']
+     tag_postfix = '_%s' % args.tag if args.tag else ''
+     log_dir = get_new_log_dir(os.path.join(args.out_root, config_name + tag_postfix), prefix='%04d_%s' % (args.index, structure_['id']))
+     logger = get_logger('sample', log_dir)
+     logger.info('Data ID: %s' % structure_['id'])
+     data_native = MergeChains()(structure_)
+     save_pdb(data_native, os.path.join(log_dir, 'reference.pdb'))
+
+     # Load checkpoint and model
+     logger.info('Loading model config and checkpoints: %s' % (config.model.checkpoint))
+     ckpt = torch.load(config.model.checkpoint, map_location='cpu')
+     cfg_ckpt = ckpt['config']
+     model = get_model(cfg_ckpt.model).to(args.device)
+     lsd = model.load_state_dict(ckpt['model'])
+     logger.info(str(lsd))
+
+     # Make data variants
+     data_variants = create_data_variants(
+         config = config,
+         structure_factory = get_structure,
+     )
+
+     # Save metadata
+     metadata = {
+         'identifier': structure_id,
+         'index': args.index,
+         'config': args.config,
+         'items': [{kk: vv for kk, vv in var.items() if kk != 'data'} for var in data_variants],
+     }
+     with open(os.path.join(log_dir, 'metadata.json'), 'w') as f:
+         json.dump(metadata, f, indent=2)
+
+     # Start sampling
+     collate_fn = PaddingCollate(eight=False)
+     inference_tfm = [ PatchAroundAnchor(), ]
+     if 'abopt' not in config.mode:  # Don't remove native CDR in optimization mode
+         inference_tfm.append(RemoveNative(
+             remove_structure = config.sampling.sample_structure,
+             remove_sequence = config.sampling.sample_sequence,
+         ))
+     inference_tfm = Compose(inference_tfm)
+
+     for variant in data_variants:
+         os.makedirs(os.path.join(log_dir, variant['tag']), exist_ok=True)
+         logger.info(f"Start sampling for: {variant['tag']}")
+
+         save_pdb(data_native, os.path.join(log_dir, variant['tag'], 'REF1.pdb'))  # w/ OpenMM minimization
+
+         data_cropped = inference_tfm(
+             copy.deepcopy(variant['data'])
+         )
+         data_list_repeat = [ data_cropped ] * config.sampling.num_samples
+         loader = DataLoader(data_list_repeat, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
+
+         count = 0
+         for batch in tqdm(loader, desc=variant['name'], dynamic_ncols=True):
+             torch.set_grad_enabled(False)
+             model.eval()
+             batch = recursive_to(batch, args.device)
+             if 'abopt' in config.mode:
+                 # Antibody optimization starting from native
+                 traj_batch = model.optimize(batch, opt_step=variant['opt_step'], optimize_opt={
+                     'pbar': True,
+                     'sample_structure': config.sampling.sample_structure,
+                     'sample_sequence': config.sampling.sample_sequence,
+                 })
+             else:
+                 # De novo design
+                 traj_batch = model.sample(batch, sample_opt={
+                     'pbar': True,
+                     'sample_structure': config.sampling.sample_structure,
+                     'sample_sequence': config.sampling.sample_sequence,
+                 })
+
+             aa_new = traj_batch[0][2]  # 0: Last sampling step. 2: Amino acid.
+             pos_atom_new, mask_atom_new = reconstruct_backbone_partially(
+                 pos_ctx = batch['pos_heavyatom'],
+                 R_new = so3vec_to_rotation(traj_batch[0][0]),
+                 t_new = traj_batch[0][1],
+                 aa = aa_new,
+                 chain_nb = batch['chain_nb'],
+                 res_nb = batch['res_nb'],
+                 mask_atoms = batch['mask_heavyatom'],
+                 mask_recons = batch['generate_flag'],
+             )
+             aa_new = aa_new.cpu()
+             pos_atom_new = pos_atom_new.cpu()
+             mask_atom_new = mask_atom_new.cpu()
+
+             for i in range(aa_new.size(0)):
+                 data_tmpl = variant['data']
+                 aa = apply_patch_to_tensor(data_tmpl['aa'], aa_new[i], data_cropped['patch_idx'])
+                 mask_ha = apply_patch_to_tensor(data_tmpl['mask_heavyatom'], mask_atom_new[i], data_cropped['patch_idx'])
+                 pos_ha = (
+                     apply_patch_to_tensor(
+                         data_tmpl['pos_heavyatom'],
+                         pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(),
+                         data_cropped['patch_idx']
+                     )
+                 )
+
+                 save_path = os.path.join(log_dir, variant['tag'], '%04d.pdb' % (count, ))
+                 save_pdb({
+                     'chain_nb': data_tmpl['chain_nb'],
+                     'chain_id': data_tmpl['chain_id'],
+                     'resseq': data_tmpl['resseq'],
+                     'icode': data_tmpl['icode'],
+                     # Generated
+                     'aa': aa,
+                     'mask_heavyatom': mask_ha,
+                     'pos_heavyatom': pos_ha,
+                 }, path=save_path)
+                 # save_pdb({
+                 #     'chain_nb': data_cropped['chain_nb'],
+                 #     'chain_id': data_cropped['chain_id'],
+                 #     'resseq': data_cropped['resseq'],
+                 #     'icode': data_cropped['icode'],
+                 #     # Generated
+                 #     'aa': aa_new[i],
+                 #     'mask_heavyatom': mask_atom_new[i],
+                 #     'pos_heavyatom': pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(),
+                 # }, path=os.path.join(log_dir, variant['tag'], '%04d_patch.pdb' % (count, )))
+                 count += 1
+
+         logger.info('Finished.\n')
+
+
+ if __name__ == '__main__':
+     main()
diffab/utils/data.py ADDED
@@ -0,0 +1,89 @@
+ import math
+ import torch
+ from torch.utils.data._utils.collate import default_collate
+
+
+ DEFAULT_PAD_VALUES = {
+     'aa': 21,
+     'chain_id': ' ',
+     'icode': ' ',
+ }
+
+ DEFAULT_NO_PADDING = {
+     'origin',
+ }
+
+ class PaddingCollate(object):
+
+     def __init__(self, length_ref_key='aa', pad_values=DEFAULT_PAD_VALUES, no_padding=DEFAULT_NO_PADDING, eight=True):
+         super().__init__()
+         self.length_ref_key = length_ref_key
+         self.pad_values = pad_values
+         self.no_padding = no_padding
+         self.eight = eight
+
+     @staticmethod
+     def _pad_last(x, n, value=0):
+         if isinstance(x, torch.Tensor):
+             assert x.size(0) <= n
+             if x.size(0) == n:
+                 return x
+             pad_size = [n - x.size(0)] + list(x.shape[1:])
+             pad = torch.full(pad_size, fill_value=value).to(x)
+             return torch.cat([x, pad], dim=0)
+         elif isinstance(x, list):
+             pad = [value] * (n - len(x))
+             return x + pad
+         else:
+             return x
+
+     @staticmethod
+     def _get_pad_mask(l, n):
+         return torch.cat([
+             torch.ones([l], dtype=torch.bool),
+             torch.zeros([n-l], dtype=torch.bool)
+         ], dim=0)
+
+     @staticmethod
+     def _get_common_keys(list_of_dict):
+         keys = set(list_of_dict[0].keys())
+         for d in list_of_dict[1:]:
+             keys = keys.intersection(d.keys())
+         return keys
+
+     def _get_pad_value(self, key):
+         if key not in self.pad_values:
+             return 0
+         return self.pad_values[key]
+
+     def __call__(self, data_list):
+         max_length = max([data[self.length_ref_key].size(0) for data in data_list])
+         keys = self._get_common_keys(data_list)
+
+         if self.eight:
+             max_length = math.ceil(max_length / 8) * 8
+         data_list_padded = []
+         for data in data_list:
+             data_padded = {
+                 k: self._pad_last(v, max_length, value=self._get_pad_value(k)) if k not in self.no_padding else v
+                 for k, v in data.items()
+                 if k in keys
+             }
+             data_padded['mask'] = self._get_pad_mask(data[self.length_ref_key].size(0), max_length)
+             data_list_padded.append(data_padded)
+         return default_collate(data_list_padded)
+
+
+ def apply_patch_to_tensor(x_full, x_patch, patch_idx):
+     """
+     Args:
+         x_full:  (N, ...)
+         x_patch: (M, ...)
+         patch_idx: (M, )
+     Returns:
+         (N, ...)
+     """
+     x_full = x_full.clone()
+     x_full[patch_idx] = x_patch
+     return x_full
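A self-contained sketch of the two utilities above (expected shapes and values shown in comments):

    import torch
    batch = PaddingCollate(eight=False)([
        {'aa': torch.zeros(5, dtype=torch.long)},
        {'aa': torch.zeros(3, dtype=torch.long)},
    ])
    print(batch['aa'].shape, batch['mask'].shape)  # torch.Size([2, 5]) torch.Size([2, 5])

    x_full = torch.zeros(6)
    x_patch = torch.ones(2)
    print(apply_patch_to_tensor(x_full, x_patch, torch.tensor([1, 4])))
    # tensor([0., 1., 0., 0., 1., 0.])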
diffab/utils/inference.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .protein import constants
3
+
4
+
5
+ def find_cdrs(structure):
6
+ cdrs = []
7
+ if structure['heavy'] is not None:
8
+ flag = structure['heavy']['cdr_flag']
9
+ if int(constants.CDR.H1) in flag:
10
+ cdrs.append('H_CDR1')
11
+ if int(constants.CDR.H2) in flag:
12
+ cdrs.append('H_CDR2')
13
+ if int(constants.CDR.H3) in flag:
14
+ cdrs.append('H_CDR3')
15
+
16
+ if structure['light'] is not None:
17
+ flag = structure['light']['cdr_flag']
18
+ if int(constants.CDR.L1) in flag:
19
+ cdrs.append('L_CDR1')
20
+ if int(constants.CDR.L2) in flag:
21
+ cdrs.append('L_CDR2')
22
+ if int(constants.CDR.L3) in flag:
23
+ cdrs.append('L_CDR3')
24
+
25
+ return cdrs
26
+
27
+
28
+ def get_residue_first_last(data):
29
+ loop_flag = data['generate_flag']
30
+ loop_idx = torch.arange(loop_flag.size(0))[loop_flag]
31
+ idx_first, idx_last = loop_idx.min().item(), loop_idx.max().item()
32
+ residue_first = (data['chain_id'][idx_first], data['resseq'][idx_first].item(), data['icode'][idx_first])
33
+ residue_last = (data['chain_id'][idx_last], data['resseq'][idx_last].item(), data['icode'][idx_last])
34
+ return residue_first, residue_last
35
+
36
+
37
+ class RemoveNative(object):
38
+
39
+ def __init__(self, remove_structure, remove_sequence):
40
+ super().__init__()
41
+ self.remove_structure = remove_structure
42
+ self.remove_sequence = remove_sequence
43
+
44
+ def __call__(self, data):
45
+ generate_flag = data['generate_flag'].clone()
46
+ if self.remove_sequence:
47
+ data['aa'] = torch.where(
48
+ generate_flag,
49
+ torch.full_like(data['aa'], fill_value=int(constants.AA.UNK)), # Is loop
50
+ data['aa']
51
+ )
52
+
53
+ if self.remove_structure:
54
+ data['pos_heavyatom'] = torch.where(
55
+ generate_flag[:, None, None].expand(data['pos_heavyatom'].shape),
56
+ torch.randn_like(data['pos_heavyatom']) * 10,
57
+ data['pos_heavyatom']
58
+ )
59
+
60
+ return data
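A minimal sketch of RemoveNative on a toy masked example (shapes assumed from the rest of this commit): residues flagged by generate_flag have their amino-acid type reset to UNK and their atom coordinates re-initialized from scaled Gaussian noise, so the native CDR leaks neither sequence nor structure into sampling.

    import torch
    from diffab.utils.protein import constants
    from diffab.utils.inference import RemoveNative

    data = {
        'aa': torch.randint(0, 20, (8,)),
        'pos_heavyatom': torch.randn(8, 15, 3),
        'generate_flag': torch.tensor([False] * 3 + [True] * 3 + [False] * 2),
    }
    data = RemoveNative(remove_structure=True, remove_sequence=True)(data)
    assert (data['aa'][3:6] == int(constants.AA.UNK)).all()  # masked residues are now UNK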
diffab/utils/misc.py ADDED
@@ -0,0 +1,126 @@
1
+ import os
2
+ import time
3
+ import random
4
+ import logging
5
+ from collections import OrderedDict  # the class itself lives in collections; the typing alias is deprecated for instantiation
6
+ import torch
7
+ import torch.linalg
8
+ import numpy as np
9
+ import yaml
10
+ from easydict import EasyDict
11
+ from glob import glob
12
+
13
+
14
+ class BlackHole(object):
15
+ def __setattr__(self, name, value):
16
+ pass
17
+
18
+ def __call__(self, *args, **kwargs):
19
+ return self
20
+
21
+ def __getattr__(self, name):
22
+ return self
23
+
24
+
25
+ class Counter(object):
26
+ def __init__(self, start=0):
27
+ super().__init__()
28
+ self.now = start
29
+
30
+ def step(self, delta=1):
31
+ prev = self.now
32
+ self.now += delta
33
+ return prev
34
+
35
+
36
+ def get_logger(name, log_dir=None):
37
+ logger = logging.getLogger(name)
38
+ logger.setLevel(logging.DEBUG)
39
+ formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')
40
+
41
+ stream_handler = logging.StreamHandler()
42
+ stream_handler.setLevel(logging.DEBUG)
43
+ stream_handler.setFormatter(formatter)
44
+ logger.addHandler(stream_handler)
45
+
46
+ if log_dir is not None:
47
+ file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
48
+ file_handler.setLevel(logging.DEBUG)
49
+ file_handler.setFormatter(formatter)
50
+ logger.addHandler(file_handler)
51
+
52
+ return logger
53
+
54
+
55
+ def get_new_log_dir(root='./logs', prefix='', tag=''):
56
+ fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
57
+ if prefix != '':
58
+ fn = prefix + '_' + fn
59
+ if tag != '':
60
+ fn = fn + '_' + tag
61
+ log_dir = os.path.join(root, fn)
62
+ os.makedirs(log_dir)
63
+ return log_dir
64
+
65
+
66
+ def seed_all(seed):
67
+ torch.backends.cudnn.deterministic = True
68
+ torch.manual_seed(seed)
69
+ torch.cuda.manual_seed_all(seed)
70
+ np.random.seed(seed)
71
+ random.seed(seed)
72
+
73
+
74
+ def inf_iterator(iterable):
75
+ iterator = iterable.__iter__()
76
+ while True:
77
+ try:
78
+ yield iterator.__next__()
79
+ except StopIteration:
80
+ iterator = iterable.__iter__()
81
+
82
+
83
+ def log_hyperparams(writer, args):
84
+ from torch.utils.tensorboard.summary import hparams
85
+ vars_args = {k: v if isinstance(v, str) else repr(v) for k, v in vars(args).items()}
86
+ exp, ssi, sei = hparams(vars_args, {})
87
+ writer.file_writer.add_summary(exp)
88
+ writer.file_writer.add_summary(ssi)
89
+ writer.file_writer.add_summary(sei)
90
+
91
+
92
+ def int_tuple(argstr):
93
+ return tuple(map(int, argstr.split(',')))
94
+
95
+
96
+ def str_tuple(argstr):
97
+ return tuple(argstr.split(','))
98
+
99
+
100
+ def get_checkpoint_path(folder, it=None):
101
+ if it is not None:
102
+ return os.path.join(folder, '%d.pt' % it), it
103
+ all_iters = list(map(lambda x: int(os.path.basename(x[:-3])), glob(os.path.join(folder, '*.pt'))))
104
+ all_iters.sort()
105
+ return os.path.join(folder, '%d.pt' % all_iters[-1]), all_iters[-1]
106
+
107
+
108
+ def load_config(config_path):
109
+ with open(config_path, 'r') as f:
110
+ config = EasyDict(yaml.safe_load(f))
111
+ config_name = os.path.basename(config_path)[:os.path.basename(config_path).rfind('.')]
112
+ return config, config_name
113
+
114
+
115
+ def extract_weights(weights: OrderedDict, prefix):
116
+ extracted = OrderedDict()
117
+ for k, v in weights.items():
118
+ if k.startswith(prefix):
119
+ extracted.update({
120
+ k[len(prefix):]: v
121
+ })
122
+ return extracted
123
+
124
+
125
+ def current_milli_time():
126
+ return round(time.time() * 1000)
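These helpers compose into the usual training preamble; a hedged sketch (the paths and the toy iterable are placeholders, not from the commit):

    from diffab.utils.misc import seed_all, get_new_log_dir, get_logger, inf_iterator

    seed_all(2022)                                       # deterministic cuDNN + all RNGs seeded
    log_dir = get_new_log_dir('./logs', prefix='train')  # timestamped run directory
    logger = get_logger('train', log_dir)                # logs to console and <log_dir>/log.txt

    # inf_iterator restarts an iterable transparently when it is exhausted, so a
    # training loop can be a plain bounded for-loop over iterations.
    it = inf_iterator([1, 2, 3])
    print([next(it) for _ in range(5)])                  # [1, 2, 3, 1, 2]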
diffab/utils/protein/constants.py ADDED
@@ -0,0 +1,319 @@
1
+ import torch
2
+ import enum
3
+
4
+ class CDR(enum.IntEnum):
5
+ H1 = 1
6
+ H2 = 2
7
+ H3 = 3
8
+ L1 = 4
9
+ L2 = 5
10
+ L3 = 6
11
+
12
+
13
+ class ChothiaCDRRange:
14
+ H1 = (26, 32)
15
+ H2 = (52, 56)
16
+ H3 = (95, 102)
17
+
18
+ L1 = (24, 34)
19
+ L2 = (50, 56)
20
+ L3 = (89, 97)
21
+
22
+ @classmethod
23
+ def to_cdr(cls, chain_type, resseq):
24
+ assert chain_type in ('H', 'L')
25
+ if chain_type == 'H':
26
+ if cls.H1[0] <= resseq <= cls.H1[1]:
27
+ return CDR.H1
28
+ elif cls.H2[0] <= resseq <= cls.H2[1]:
29
+ return CDR.H2
30
+ elif cls.H3[0] <= resseq <= cls.H3[1]:
31
+ return CDR.H3
32
+ elif chain_type == 'L':
33
+ if cls.L1[0] <= resseq <= cls.L1[1]: # Chothia VL-CDR1
34
+ return CDR.L1
35
+ elif cls.L2[0] <= resseq <= cls.L2[1]:
36
+ return CDR.L2
37
+ elif cls.L3[0] <= resseq <= cls.L3[1]:
38
+ return CDR.L3
39
+
40
+
41
+ class Fragment(enum.IntEnum):
42
+ Heavy = 1
43
+ Light = 2
44
+ Antigen = 3
45
+
46
+ ##
47
+ # Residue identities
48
+ """
49
+ This is part of the OpenMM molecular simulation toolkit originating from
50
+ Simbios, the NIH National Center for Physics-Based Simulation of
51
+ Biological Structures at Stanford, funded under the NIH Roadmap for
52
+ Medical Research, grant U54 GM072970. See https://simtk.org.
53
+
54
+ Portions copyright (c) 2013 Stanford University and the Authors.
55
+ Authors: Peter Eastman
56
+ Contributors:
57
+
58
+ Permission is hereby granted, free of charge, to any person obtaining a
59
+ copy of this software and associated documentation files (the "Software"),
60
+ to deal in the Software without restriction, including without limitation
61
+ the rights to use, copy, modify, merge, publish, distribute, sublicense,
62
+ and/or sell copies of the Software, and to permit persons to whom the
63
+ Software is furnished to do so, subject to the following conditions:
64
+
65
+ The above copyright notice and this permission notice shall be included in
66
+ all copies or substantial portions of the Software.
67
+
68
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
69
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
70
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
71
+ THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
72
+ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
73
+ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
74
+ USE OR OTHER DEALINGS IN THE SOFTWARE.
75
+ """
76
+ non_standard_residue_substitutions = {
77
+ '2AS':'ASP', '3AH':'HIS', '5HP':'GLU', 'ACL':'ARG', 'AGM':'ARG', 'AIB':'ALA', 'ALM':'ALA', 'ALO':'THR', 'ALY':'LYS', 'ARM':'ARG',
78
+ 'ASA':'ASP', 'ASB':'ASP', 'ASK':'ASP', 'ASL':'ASP', 'ASQ':'ASP', 'AYA':'ALA', 'BCS':'CYS', 'BHD':'ASP', 'BMT':'THR', 'BNN':'ALA',
79
+ 'BUC':'CYS', 'BUG':'LEU', 'C5C':'CYS', 'C6C':'CYS', 'CAS':'CYS', 'CCS':'CYS', 'CEA':'CYS', 'CGU':'GLU', 'CHG':'ALA', 'CLE':'LEU', 'CME':'CYS',
80
+ 'CSD':'ALA', 'CSO':'CYS', 'CSP':'CYS', 'CSS':'CYS', 'CSW':'CYS', 'CSX':'CYS', 'CXM':'MET', 'CY1':'CYS', 'CY3':'CYS', 'CYG':'CYS',
81
+ 'CYM':'CYS', 'CYQ':'CYS', 'DAH':'PHE', 'DAL':'ALA', 'DAR':'ARG', 'DAS':'ASP', 'DCY':'CYS', 'DGL':'GLU', 'DGN':'GLN', 'DHA':'ALA',
82
+ 'DHI':'HIS', 'DIL':'ILE', 'DIV':'VAL', 'DLE':'LEU', 'DLY':'LYS', 'DNP':'ALA', 'DPN':'PHE', 'DPR':'PRO', 'DSN':'SER', 'DSP':'ASP',
83
+ 'DTH':'THR', 'DTR':'TRP', 'DTY':'TYR', 'DVA':'VAL', 'EFC':'CYS', 'FLA':'ALA', 'FME':'MET', 'GGL':'GLU', 'GL3':'GLY', 'GLZ':'GLY',
84
+ 'GMA':'GLU', 'GSC':'GLY', 'HAC':'ALA', 'HAR':'ARG', 'HIC':'HIS', 'HIP':'HIS', 'HMR':'ARG', 'HPQ':'PHE', 'HTR':'TRP', 'HYP':'PRO',
85
+ 'IAS':'ASP', 'IIL':'ILE', 'IYR':'TYR', 'KCX':'LYS', 'LLP':'LYS', 'LLY':'LYS', 'LTR':'TRP', 'LYM':'LYS', 'LYZ':'LYS', 'MAA':'ALA', 'MEN':'ASN',
86
+ 'MHS':'HIS', 'MIS':'SER', 'MLE':'LEU', 'MPQ':'GLY', 'MSA':'GLY', 'MSE':'MET', 'MVA':'VAL', 'NEM':'HIS', 'NEP':'HIS', 'NLE':'LEU',
87
+ 'NLN':'LEU', 'NLP':'LEU', 'NMC':'GLY', 'OAS':'SER', 'OCS':'CYS', 'OMT':'MET', 'PAQ':'TYR', 'PCA':'GLU', 'PEC':'CYS', 'PHI':'PHE',
88
+ 'PHL':'PHE', 'PR3':'CYS', 'PRR':'ALA', 'PTR':'TYR', 'PYX':'CYS', 'SAC':'SER', 'SAR':'GLY', 'SCH':'CYS', 'SCS':'CYS', 'SCY':'CYS',
89
+ 'SEL':'SER', 'SEP':'SER', 'SET':'SER', 'SHC':'CYS', 'SHR':'LYS', 'SMC':'CYS', 'SOC':'CYS', 'STY':'TYR', 'SVA':'SER', 'TIH':'ALA',
90
+ 'TPL':'TRP', 'TPO':'THR', 'TPQ':'ALA', 'TRG':'LYS', 'TRO':'TRP', 'TYB':'TYR', 'TYI':'TYR', 'TYQ':'TYR', 'TYS':'TYR', 'TYY':'TYR'
91
+ }
92
+
93
+
94
+ ressymb_to_resindex = {
95
+ 'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4,
96
+ 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9,
97
+ 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14,
98
+ 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19,
99
+ 'X': 20,
100
+ }
101
+
102
+
103
+ class AA(enum.IntEnum):
104
+ ALA = 0; CYS = 1; ASP = 2; GLU = 3; PHE = 4
105
+ GLY = 5; HIS = 6; ILE = 7; LYS = 8; LEU = 9
106
+ MET = 10; ASN = 11; PRO = 12; GLN = 13; ARG = 14
107
+ SER = 15; THR = 16; VAL = 17; TRP = 18; TYR = 19
108
+ UNK = 20
109
+
110
+ @classmethod
111
+ def _missing_(cls, value):
112
+ if isinstance(value, str) and len(value) == 3: # three representation
113
+ if value in non_standard_residue_substitutions:
114
+ value = non_standard_residue_substitutions[value]
115
+ if value in cls._member_names_:
116
+ return getattr(cls, value)
117
+ elif isinstance(value, str) and len(value) == 1: # one representation
118
+ if value in ressymb_to_resindex:
119
+ return cls(ressymb_to_resindex[value])
120
+
121
+ return super()._missing_(value)
122
+
123
+ def __str__(self):
124
+ return self.name
125
+
126
+ @classmethod
127
+ def is_aa(cls, value):
128
+ return (value in ressymb_to_resindex) or \
129
+ (value in non_standard_residue_substitutions) or \
130
+ (value in cls._member_names_) or \
131
+ (value in cls._member_map_.values())
132
+
133
+
134
+ num_aa_types = len(AA)
135
+
136
+ ##
137
+ # Atom identities
138
+
139
+ class BBHeavyAtom(enum.IntEnum):
140
+ N = 0; CA = 1; C = 2; O = 3; CB = 4; OXT = 14
141
+
142
+ max_num_heavyatoms = 15
143
+
144
+ # Copyright 2021 DeepMind Technologies Limited
145
+ #
146
+ # Licensed under the Apache License, Version 2.0 (the "License");
147
+ # you may not use this file except in compliance with the License.
148
+ # You may obtain a copy of the License at
149
+ #
150
+ # http://www.apache.org/licenses/LICENSE-2.0
151
+ #
152
+ # Unless required by applicable law or agreed to in writing, software
153
+ # distributed under the License is distributed on an "AS IS" BASIS,
154
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
155
+ # See the License for the specific language governing permissions and
156
+ # limitations under the License.
157
+ restype_to_heavyatom_names = {
158
+ AA.ALA: ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', '', 'OXT'],
159
+ AA.ARG: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', '', 'OXT'],
160
+ AA.ASN: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', '', 'OXT'],
161
+ AA.ASP: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', '', 'OXT'],
162
+ AA.CYS: ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', '', 'OXT'],
163
+ AA.GLN: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', '', 'OXT'],
164
+ AA.GLU: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', '', 'OXT'],
165
+ AA.GLY: ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', '', 'OXT'],
166
+ AA.HIS: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', '', 'OXT'],
167
+ AA.ILE: ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', '', 'OXT'],
168
+ AA.LEU: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', '', 'OXT'],
169
+ AA.LYS: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', '', 'OXT'],
170
+ AA.MET: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', '', 'OXT'],
171
+ AA.PHE: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', '', 'OXT'],
172
+ AA.PRO: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', '', 'OXT'],
173
+ AA.SER: ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', '', 'OXT'],
174
+ AA.THR: ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', '', 'OXT'],
175
+ AA.TRP: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2', 'OXT'],
176
+ AA.TYR: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', '', 'OXT'],
177
+ AA.VAL: ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', '', 'OXT'],
178
+ AA.UNK: ['', '', '', '', '', '', '', '', '', '', '', '', '', '', ''],
179
+ }
180
+ for names in restype_to_heavyatom_names.values(): assert len(names) == max_num_heavyatoms
181
+
182
+
183
+ backbone_atom_coordinates = {
184
+ AA.ALA: [
185
+ (-0.525, 1.363, 0.0), # N
186
+ (0.0, 0.0, 0.0), # CA
187
+ (1.526, -0.0, -0.0), # C
188
+ ],
189
+ AA.ARG: [
190
+ (-0.524, 1.362, -0.0), # N
191
+ (0.0, 0.0, 0.0), # CA
192
+ (1.525, -0.0, -0.0), # C
193
+ ],
194
+ AA.ASN: [
195
+ (-0.536, 1.357, 0.0), # N
196
+ (0.0, 0.0, 0.0), # CA
197
+ (1.526, -0.0, -0.0), # C
198
+ ],
199
+ AA.ASP: [
200
+ (-0.525, 1.362, -0.0), # N
201
+ (0.0, 0.0, 0.0), # CA
202
+ (1.527, 0.0, -0.0), # C
203
+ ],
204
+ AA.CYS: [
205
+ (-0.522, 1.362, -0.0), # N
206
+ (0.0, 0.0, 0.0), # CA
207
+ (1.524, 0.0, 0.0), # C
208
+ ],
209
+ AA.GLN: [
210
+ (-0.526, 1.361, -0.0), # N
211
+ (0.0, 0.0, 0.0), # CA
212
+ (1.526, 0.0, 0.0), # C
213
+ ],
214
+ AA.GLU: [
215
+ (-0.528, 1.361, 0.0), # N
216
+ (0.0, 0.0, 0.0), # CA
217
+ (1.526, -0.0, -0.0), # C
218
+ ],
219
+ AA.GLY: [
220
+ (-0.572, 1.337, 0.0), # N
221
+ (0.0, 0.0, 0.0), # CA
222
+ (1.517, -0.0, -0.0), # C
223
+ ],
224
+ AA.HIS: [
225
+ (-0.527, 1.36, 0.0), # N
226
+ (0.0, 0.0, 0.0), # CA
227
+ (1.525, 0.0, 0.0), # C
228
+ ],
229
+ AA.ILE: [
230
+ (-0.493, 1.373, -0.0), # N
231
+ (0.0, 0.0, 0.0), # CA
232
+ (1.527, -0.0, -0.0), # C
233
+ ],
234
+ AA.LEU: [
235
+ (-0.52, 1.363, 0.0), # N
236
+ (0.0, 0.0, 0.0), # CA
237
+ (1.525, -0.0, -0.0), # C
238
+ ],
239
+ AA.LYS: [
240
+ (-0.526, 1.362, -0.0), # N
241
+ (0.0, 0.0, 0.0), # CA
242
+ (1.526, 0.0, 0.0), # C
243
+ ],
244
+ AA.MET: [
245
+ (-0.521, 1.364, -0.0), # N
246
+ (0.0, 0.0, 0.0), # CA
247
+ (1.525, 0.0, 0.0), # C
248
+ ],
249
+ AA.PHE: [
250
+ (-0.518, 1.363, 0.0), # N
251
+ (0.0, 0.0, 0.0), # CA
252
+ (1.524, 0.0, -0.0), # C
253
+ ],
254
+ AA.PRO: [
255
+ (-0.566, 1.351, -0.0), # N
256
+ (0.0, 0.0, 0.0), # CA
257
+ (1.527, -0.0, 0.0), # C
258
+ ],
259
+ AA.SER: [
260
+ (-0.529, 1.36, -0.0), # N
261
+ (0.0, 0.0, 0.0), # CA
262
+ (1.525, -0.0, -0.0), # C
263
+ ],
264
+ AA.THR: [
265
+ (-0.517, 1.364, 0.0), # N
266
+ (0.0, 0.0, 0.0), # CA
267
+ (1.526, 0.0, -0.0), # C
268
+ ],
269
+ AA.TRP: [
270
+ (-0.521, 1.363, 0.0), # N
271
+ (0.0, 0.0, 0.0), # CA
272
+ (1.525, -0.0, 0.0), # C
273
+ ],
274
+ AA.TYR: [
275
+ (-0.522, 1.362, 0.0), # N
276
+ (0.0, 0.0, 0.0), # CA
277
+ (1.524, -0.0, -0.0), # C
278
+ ],
279
+ AA.VAL: [
280
+ (-0.494, 1.373, -0.0), # N
281
+ (0.0, 0.0, 0.0), # CA
282
+ (1.527, -0.0, -0.0), # C
283
+ ],
284
+ }
285
+
286
+ bb_oxygen_coordinate = {
287
+ AA.ALA: (2.153, -1.062, 0.0),
288
+ AA.ARG: (2.151, -1.062, 0.0),
289
+ AA.ASN: (2.151, -1.062, 0.0),
290
+ AA.ASP: (2.153, -1.062, 0.0),
291
+ AA.CYS: (2.149, -1.062, 0.0),
292
+ AA.GLN: (2.152, -1.062, 0.0),
293
+ AA.GLU: (2.152, -1.062, 0.0),
294
+ AA.GLY: (2.143, -1.062, 0.0),
295
+ AA.HIS: (2.15, -1.063, 0.0),
296
+ AA.ILE: (2.154, -1.062, 0.0),
297
+ AA.LEU: (2.15, -1.063, 0.0),
298
+ AA.LYS: (2.152, -1.062, 0.0),
299
+ AA.MET: (2.15, -1.062, 0.0),
300
+ AA.PHE: (2.15, -1.062, 0.0),
301
+ AA.PRO: (2.148, -1.066, 0.0),
302
+ AA.SER: (2.151, -1.062, 0.0),
303
+ AA.THR: (2.152, -1.062, 0.0),
304
+ AA.TRP: (2.152, -1.062, 0.0),
305
+ AA.TYR: (2.151, -1.062, 0.0),
306
+ AA.VAL: (2.154, -1.062, 0.0),
307
+ }
308
+
309
+ backbone_atom_coordinates_tensor = torch.zeros([21, 3, 3])
310
+ bb_oxygen_coordinate_tensor = torch.zeros([21, 3])
311
+
312
+ def make_coordinate_tensors():
313
+ for restype, atom_coords in backbone_atom_coordinates.items():
314
+ for atom_id, atom_coord in enumerate(atom_coords):
315
+ backbone_atom_coordinates_tensor[restype][atom_id] = torch.FloatTensor(atom_coord)
316
+
317
+ for restype, bb_oxy_coord in bb_oxygen_coordinate.items():
318
+ bb_oxygen_coordinate_tensor[restype] = torch.FloatTensor(bb_oxy_coord)
319
+ make_coordinate_tensors()
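The _missing_ hook makes the AA enum accept three-letter residue names (routed through the non-standard substitution table above) and one-letter symbols interchangeably; a short illustration:

    from diffab.utils.protein.constants import AA

    print(AA('ALA'))       # ALA, via the three-letter branch
    print(AA('MSE'))       # MET: selenomethionine maps through the substitution table
    print(AA('K'))         # LYS, via the one-letter branch
    print(int(AA('GLY')))  # 5, directly usable as a residue-type index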
diffab/utils/protein/parsers.py ADDED
@@ -0,0 +1,109 @@
1
+ import torch
2
+ from Bio.PDB import Selection
3
+ from Bio.PDB.Residue import Residue
4
+ from easydict import EasyDict
5
+
6
+ from .constants import (
7
+ AA, max_num_heavyatoms,
8
+ restype_to_heavyatom_names,
9
+ BBHeavyAtom
10
+ )
11
+
12
+
13
+ class ParsingException(Exception):
14
+ pass
15
+
16
+
17
+ def _get_residue_heavyatom_info(res: Residue):
18
+ pos_heavyatom = torch.zeros([max_num_heavyatoms, 3], dtype=torch.float)
19
+ mask_heavyatom = torch.zeros([max_num_heavyatoms, ], dtype=torch.bool)
20
+ restype = AA(res.get_resname())
21
+ for idx, atom_name in enumerate(restype_to_heavyatom_names[restype]):
22
+ if atom_name == '': continue
23
+ if atom_name in res:
24
+ pos_heavyatom[idx] = torch.tensor(res[atom_name].get_coord().tolist(), dtype=pos_heavyatom.dtype)
25
+ mask_heavyatom[idx] = True
26
+ return pos_heavyatom, mask_heavyatom
27
+
28
+
29
+ def parse_biopython_structure(entity, unknown_threshold=1.0, max_resseq=None):
30
+ chains = Selection.unfold_entities(entity, 'C')
31
+ chains.sort(key=lambda c: c.get_id())
32
+ data = EasyDict({
33
+ 'chain_id': [],
34
+ 'resseq': [], 'icode': [], 'res_nb': [],
35
+ 'aa': [],
36
+ 'pos_heavyatom': [], 'mask_heavyatom': [],
37
+ })
38
+ tensor_types = {
39
+ 'resseq': torch.LongTensor,
40
+ 'res_nb': torch.LongTensor,
41
+ 'aa': torch.LongTensor,
42
+ 'pos_heavyatom': torch.stack,
43
+ 'mask_heavyatom': torch.stack,
44
+ }
45
+
46
+ count_aa, count_unk = 0, 0
47
+
48
+ for i, chain in enumerate(chains):
49
+ seq_this = 0 # Renumbering residues
50
+ residues = Selection.unfold_entities(chain, 'R')
51
+ residues.sort(key=lambda res: (res.get_id()[1], res.get_id()[2])) # Sort residues by resseq-icode
52
+ for _, res in enumerate(residues):
53
+ resseq_this = int(res.get_id()[1])
54
+ if max_resseq is not None and resseq_this > max_resseq:
55
+ continue
56
+
57
+ resname = res.get_resname()
58
+ if not AA.is_aa(resname): continue
59
+ if not (res.has_id('CA') and res.has_id('C') and res.has_id('N')): continue
60
+ restype = AA(resname)
61
+ count_aa += 1
62
+ if restype == AA.UNK:
63
+ count_unk += 1
64
+ continue
65
+
66
+ # Chain info
67
+ data.chain_id.append(chain.get_id())
68
+
69
+ # Residue types
70
+ data.aa.append(restype) # Will be automatically cast to torch.long
71
+
72
+ # Heavy atoms
73
+ pos_heavyatom, mask_heavyatom = _get_residue_heavyatom_info(res)
74
+ data.pos_heavyatom.append(pos_heavyatom)
75
+ data.mask_heavyatom.append(mask_heavyatom)
76
+
77
+ # Sequential number
78
+ resseq_this = int(res.get_id()[1])
79
+ icode_this = res.get_id()[2]
80
+ if seq_this == 0:
81
+ seq_this = 1
82
+ else:
83
+ d_CA_CA = torch.linalg.norm(data.pos_heavyatom[-2][BBHeavyAtom.CA] - data.pos_heavyatom[-1][BBHeavyAtom.CA], ord=2).item()
84
+ if d_CA_CA <= 4.0:
85
+ seq_this += 1
86
+ else:
87
+ d_resseq = resseq_this - data.resseq[-1]
88
+ seq_this += max(2, d_resseq)
89
+
90
+ data.resseq.append(resseq_this)
91
+ data.icode.append(icode_this)
92
+ data.res_nb.append(seq_this)
93
+
94
+ if len(data.aa) == 0:
95
+ raise ParsingException('No parsed residues.')
96
+
97
+ if (count_unk / count_aa) >= unknown_threshold:
98
+ raise ParsingException(
99
+ f'Too many unknown residues, threshold {unknown_threshold:.2f}.'
100
+ )
101
+
102
+ seq_map = {}
103
+ for i, (chain_id, resseq, icode) in enumerate(zip(data.chain_id, data.resseq, data.icode)):
104
+ seq_map[(chain_id, resseq, icode)] = i
105
+
106
+ for key, convert_fn in tensor_types.items():
107
+ data[key] = convert_fn(data[key])
108
+
109
+ return data, seq_map
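A sketch of how parse_biopython_structure is typically fed (the file name is hypothetical): parse with Biopython, hand over one model, and get back per-residue tensors plus a (chain_id, resseq, icode) -> index map. Note the renumbering above: res_nb advances by at least 2 whenever consecutive residues sit more than 4 Å apart CA-to-CA, so chain breaks stay visible after renumbering.

    from Bio.PDB import PDBParser
    from diffab.utils.protein.parsers import parse_biopython_structure

    parser = PDBParser(QUIET=True)
    model = parser.get_structure('example', 'example.pdb')[0]  # first model
    data, seq_map = parse_biopython_structure(model)
    print(data['aa'].shape, data['pos_heavyatom'].shape)       # (L,), (L, 15, 3)
    print(list(seq_map.items())[:2])                           # e.g. [(('A', 1, ' '), 0), ...]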
diffab/utils/protein/writers.py ADDED
@@ -0,0 +1,75 @@
1
+ import torch
2
+ import warnings
3
+ from Bio import BiopythonWarning
4
+ from Bio.PDB import PDBIO
5
+ from Bio.PDB.StructureBuilder import StructureBuilder
6
+
7
+ from .constants import AA, restype_to_heavyatom_names
8
+
9
+
10
+ def save_pdb(data, path=None):
11
+ """
12
+ Args:
13
+ data: A dict that contains: `chain_nb`, `chain_id`, `aa`, `resseq`, `icode`,
14
+ `pos_heavyatom`, `mask_heavyatom`.
15
+ """
16
+
17
+ def _mask_select(v, mask):
18
+ if isinstance(v, str):
19
+ return ''.join([s for i, s in enumerate(v) if mask[i]])
20
+ elif isinstance(v, list):
21
+ return [s for i, s in enumerate(v) if mask[i]]
22
+ elif isinstance(v, torch.Tensor):
23
+ return v[mask]
24
+ else:
25
+ return v
26
+
27
+ def _build_chain(builder, aa_ch, pos_heavyatom_ch, mask_heavyatom_ch, chain_id_ch, resseq_ch, icode_ch):
28
+ builder.init_chain(chain_id_ch[0])
29
+ builder.init_seg(' ')
30
+
31
+ for aa_res, pos_allatom_res, mask_allatom_res, resseq_res, icode_res in \
32
+ zip(aa_ch, pos_heavyatom_ch, mask_heavyatom_ch, resseq_ch, icode_ch):
33
+ if not AA.is_aa(aa_res.item()):
34
+ print('[Warning] Unknown amino acid type at %d%s: %r' % (resseq_res.item(), icode_res, aa_res.item()))
35
+ continue
36
+ restype = AA(aa_res.item())
37
+ builder.init_residue(
38
+ resname = str(restype),
39
+ field = ' ',
40
+ resseq = resseq_res.item(),
41
+ icode = icode_res,
42
+ )
43
+
44
+ for i, atom_name in enumerate(restype_to_heavyatom_names[restype]):
45
+ if atom_name == '': continue # No expected atom
46
+ if (~mask_allatom_res[i]).any(): continue # Atom is missing
47
+ if len(atom_name) == 1: fullname = ' %s  ' % atom_name # pad to the 4-character PDB atom-name field
48
+ elif len(atom_name) == 2: fullname = ' %s ' % atom_name
49
+ elif len(atom_name) == 3: fullname = ' %s' % atom_name
50
+ else: fullname = atom_name # len == 4
51
+ builder.init_atom(atom_name, pos_allatom_res[i].tolist(), 0.0, 1.0, ' ', fullname,)
52
+
53
+ warnings.simplefilter('ignore', BiopythonWarning)
54
+ builder = StructureBuilder()
55
+ builder.init_structure(0)
56
+ builder.init_model(0)
57
+
58
+ unique_chain_nb = data['chain_nb'].unique().tolist()
59
+ for ch_nb in unique_chain_nb:
60
+ mask = (data['chain_nb'] == ch_nb)
61
+ aa = _mask_select(data['aa'], mask)
62
+ pos_heavyatom = _mask_select(data['pos_heavyatom'], mask)
63
+ mask_heavyatom = _mask_select(data['mask_heavyatom'], mask)
64
+ chain_id = _mask_select(data['chain_id'], mask)
65
+ resseq = _mask_select(data['resseq'], mask)
66
+ icode = _mask_select(data['icode'], mask)
67
+
68
+ _build_chain(builder, aa, pos_heavyatom, mask_heavyatom, chain_id, resseq, icode)
69
+
70
+ structure = builder.get_structure()
71
+ if path is not None:
72
+ io = PDBIO()
73
+ io.set_structure(structure)
74
+ io.save(path)
75
+ return structure
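save_pdb is the rough inverse of the parser above, with one extra requirement: the input dict must carry chain_nb (assigned by MergeChains below) so residues can be regrouped into chains. A hedged, self-contained sketch on a one-residue toy input:

    import torch
    from diffab.utils.protein.constants import AA
    from diffab.utils.protein.writers import save_pdb

    # A one-residue toy input covering every required field.
    data = {
        'chain_nb': torch.tensor([0]),
        'chain_id': ['A'],
        'aa': torch.tensor([int(AA.GLY)]),
        'resseq': torch.tensor([1]),
        'icode': [' '],
        'pos_heavyatom': torch.zeros(1, 15, 3),
        'mask_heavyatom': torch.zeros(1, 15, dtype=torch.bool),
    }
    data['mask_heavyatom'][0, :4] = True         # N, CA, C, O marked as present
    structure = save_pdb(data, path='toy.pdb')   # writes the file and returns
                                                 # the Biopython structure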
diffab/utils/train.py ADDED
@@ -0,0 +1,151 @@
1
+ import numpy as np
2
+ import torch
3
+ from easydict import EasyDict
4
+
5
+ from .misc import BlackHole
6
+
7
+
8
+ def get_optimizer(cfg, model):
9
+ if cfg.type == 'adam':
10
+ return torch.optim.Adam(
11
+ model.parameters(),
12
+ lr=cfg.lr,
13
+ weight_decay=cfg.weight_decay,
14
+ betas=(cfg.beta1, cfg.beta2, )
15
+ )
16
+ else:
17
+ raise NotImplementedError('Optimizer not supported: %s' % cfg.type)
18
+
19
+
20
+ def get_scheduler(cfg, optimizer):
21
+ if cfg.type is None:
22
+ return BlackHole()
23
+ elif cfg.type == 'plateau':
24
+ return torch.optim.lr_scheduler.ReduceLROnPlateau(
25
+ optimizer,
26
+ factor=cfg.factor,
27
+ patience=cfg.patience,
28
+ min_lr=cfg.min_lr,
29
+ )
30
+ elif cfg.type == 'multistep':
31
+ return torch.optim.lr_scheduler.MultiStepLR(
32
+ optimizer,
33
+ milestones=cfg.milestones,
34
+ gamma=cfg.gamma,
35
+ )
36
+ elif cfg.type == 'exp':
37
+ return torch.optim.lr_scheduler.ExponentialLR(
38
+ optimizer,
39
+ gamma=cfg.gamma,
40
+ )
43
+ else:
44
+ raise NotImplementedError('Scheduler not supported: %s' % cfg.type)
45
+
46
+
47
+ def get_warmup_sched(cfg, optimizer):
48
+ if cfg is None: return BlackHole()
49
+ lambdas = [lambda it : (it / cfg.max_iters) if it <= cfg.max_iters else 1 for _ in optimizer.param_groups]
50
+ warmup_sched = torch.optim.lr_scheduler.LambdaLR(optimizer, lambdas)
51
+ return warmup_sched
52
+
53
+
54
+ def log_losses(out, it, tag, logger=BlackHole(), writer=BlackHole(), others={}):
55
+ logstr = '[%s] Iter %05d' % (tag, it)
56
+ logstr += ' | loss %.4f' % out['overall'].item()
57
+ for k, v in out.items():
58
+ if k == 'overall': continue
59
+ logstr += ' | loss(%s) %.4f' % (k, v.item())
60
+ for k, v in others.items():
61
+ logstr += ' | %s %2.4f' % (k, v)
62
+ logger.info(logstr)
63
+
64
+ for k, v in out.items():
65
+ if k == 'overall':
66
+ writer.add_scalar('%s/loss' % tag, v, it)
67
+ else:
68
+ writer.add_scalar('%s/loss_%s' % (tag, k), v, it)
69
+ for k, v in others.items():
70
+ writer.add_scalar('%s/%s' % (tag, k), v, it)
71
+ writer.flush()
72
+
73
+
74
+ class ValidationLossTape(object):
75
+
76
+ def __init__(self):
77
+ super().__init__()
78
+ self.accumulate = {}
79
+ self.others = {}
80
+ self.total = 0
81
+
82
+ def update(self, out, n, others={}):
83
+ self.total += n
84
+ for k, v in out.items():
85
+ if k not in self.accumulate:
86
+ self.accumulate[k] = v.clone().detach()
87
+ else:
88
+ self.accumulate[k] += v.clone().detach()
89
+
90
+ for k, v in others.items():
91
+ if k not in self.others:
92
+ self.others[k] = v.clone().detach()
93
+ else:
94
+ self.others[k] += v.clone().detach()
95
+
96
+
97
+ def log(self, it, logger=BlackHole(), writer=BlackHole(), tag='val'):
98
+ avg = EasyDict({k:v / self.total for k, v in self.accumulate.items()})
99
+ avg_others = EasyDict({k:v / self.total for k, v in self.others.items()})
100
+ log_losses(avg, it, tag, logger, writer, others=avg_others)
101
+ return avg['overall']
102
+
103
+
104
+ def recursive_to(obj, device):
105
+ if isinstance(obj, torch.Tensor):
106
+ if device == 'cpu':
107
+ return obj.cpu()
108
+ try:
109
+ return obj.cuda(device=device, non_blocking=True)
110
+ except RuntimeError:
111
+ return obj.to(device)
112
+ elif isinstance(obj, list):
113
+ return [recursive_to(o, device=device) for o in obj]
114
+ elif isinstance(obj, tuple):
115
+ return tuple(recursive_to(o, device=device) for o in obj)
116
+ elif isinstance(obj, dict):
117
+ return {k: recursive_to(v, device=device) for k, v in obj.items()}
118
+
119
+ else:
120
+ return obj
121
+
122
+
123
+ def reweight_loss_by_sequence_length(length, max_length, mode='sqrt'):
124
+ if mode == 'sqrt':
125
+ w = np.sqrt(length / max_length)
126
+ elif mode == 'linear':
127
+ w = length / max_length
128
+ elif mode is None:
129
+ w = 1.0
130
+ else:
131
+ raise ValueError('Unknown reweighting mode: %s' % mode)
132
+ return w
133
+
134
+
135
+ def sum_weighted_losses(losses, weights):
136
+ """
137
+ Args:
138
+ losses: Dict of scalar tensors.
139
+ weights: Dict of weights.
140
+ """
141
+ loss = 0
142
+ for k in losses.keys():
143
+ if weights is None:
144
+ loss = loss + losses[k]
145
+ else:
146
+ loss = loss + weights[k] * losses[k]
147
+ return loss
148
+
149
+
150
+ def count_parameters(model):
151
+ return sum(p.numel() for p in model.parameters())
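A small sketch of the loss plumbing (the loss names and weights are placeholders): sum_weighted_losses builds the scalar objective that log_losses and ValidationLossTape expect under the 'overall' key, and the BlackHole defaults turn logging into a silent no-op until a real logger/writer is passed.

    import torch
    from diffab.utils.train import sum_weighted_losses, log_losses

    losses = {'rot': torch.tensor(0.8), 'pos': torch.tensor(1.2), 'seq': torch.tensor(0.5)}
    weights = {'rot': 1.0, 'pos': 1.0, 'seq': 0.5}
    losses['overall'] = sum_weighted_losses(losses, weights)  # 0.8 + 1.2 + 0.25 = 2.25
    log_losses(losses, it=100, tag='train')  # BlackHole logger/writer: no output, no error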
diffab/utils/transforms/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Transforms
2
+ from .mask import MaskSingleCDR, MaskMultipleCDRs, MaskAntibody
3
+ from .merge import MergeChains
4
+ from .patch import PatchAroundAnchor
5
+
6
+ # Factory
7
+ from ._base import get_transform, Compose
diffab/utils/transforms/_base.py ADDED
@@ -0,0 +1,56 @@
1
+ import copy
2
+ import torch
3
+ from torchvision.transforms import Compose
4
+
5
+
6
+ _TRANSFORM_DICT = {}
7
+
8
+
9
+ def register_transform(name):
10
+ def decorator(cls):
11
+ _TRANSFORM_DICT[name] = cls
12
+ return cls
13
+ return decorator
14
+
15
+
16
+ def get_transform(cfg):
17
+ if cfg is None or len(cfg) == 0:
18
+ return None
19
+ tfms = []
20
+ for t_dict in cfg:
21
+ t_dict = copy.deepcopy(t_dict)
22
+ cls = _TRANSFORM_DICT[t_dict.pop('type')]
23
+ tfms.append(cls(**t_dict))
24
+ return Compose(tfms)
25
+
26
+
27
+ def _index_select(v, index, n):
28
+ if isinstance(v, torch.Tensor) and v.size(0) == n:
29
+ return v[index]
30
+ elif isinstance(v, list) and len(v) == n:
31
+ return [v[i] for i in index]
32
+ else:
33
+ return v
34
+
35
+
36
+ def _index_select_data(data, index):
37
+ return {
38
+ k: _index_select(v, index, data['aa'].size(0))
39
+ for k, v in data.items()
40
+ }
41
+
42
+
43
+ def _mask_select(v, mask):
44
+ if isinstance(v, torch.Tensor) and v.size(0) == mask.size(0):
45
+ return v[mask]
46
+ elif isinstance(v, list) and len(v) == mask.size(0):
47
+ return [v[i] for i, b in enumerate(mask) if b]
48
+ else:
49
+ return v
50
+
51
+
52
+ def _mask_select_data(data, mask):
53
+ return {
54
+ k: _mask_select(v, mask)
55
+ for k, v in data.items()
56
+ }
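The registry makes transform pipelines config-driven: each config entry's 'type' picks a registered class and the remaining keys become constructor kwargs. A minimal sketch (the 'noop' transform is invented here for illustration):

    import diffab.utils.transforms  # importing the package registers the built-in transforms
    from diffab.utils.transforms._base import register_transform, get_transform

    @register_transform('noop')
    class NoOp:
        def __init__(self, verbose=False):
            self.verbose = verbose
        def __call__(self, structure):
            return structure

    cfg = [
        {'type': 'mask_single_cdr', 'selection': 'H_CDR3'},
        {'type': 'merge_chains'},
        {'type': 'noop', 'verbose': True},
    ]
    transform = get_transform(cfg)  # a torchvision Compose over the three instances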
diffab/utils/transforms/mask.py ADDED
@@ -0,0 +1,213 @@
1
+ import torch
2
+ import random
3
+ from typing import List, Optional
4
+
5
+ from ..protein import constants
6
+ from ._base import register_transform
7
+
8
+
9
+ def random_shrink_extend(flag, min_length=5, shrink_limit=1, extend_limit=2):
10
+ first, last = continuous_flag_to_range(flag)
11
+ length = flag.sum().item()
12
+ if (length - 2*shrink_limit) < min_length:
13
+ shrink_limit = 0
14
+ first_ext = max(0, first-random.randint(-shrink_limit, extend_limit))
15
+ last_ext = min(last+random.randint(-shrink_limit, extend_limit), flag.size(0)-1)
16
+ flag_ext = flag.clone()
17
+ flag_ext[first_ext : last_ext+1] = True
18
+ return flag_ext
19
+
20
+
21
+ def continuous_flag_to_range(flag):
22
+ first = (torch.arange(0, flag.size(0))[flag]).min().item()
23
+ last = (torch.arange(0, flag.size(0))[flag]).max().item()
24
+ return first, last
25
+
26
+
27
+ @register_transform('mask_single_cdr')
28
+ class MaskSingleCDR(object):
29
+
30
+ def __init__(self, selection=None, augmentation=True):
31
+ super().__init__()
32
+ cdr_str_to_enum = {
33
+ 'H1': constants.CDR.H1,
34
+ 'H2': constants.CDR.H2,
35
+ 'H3': constants.CDR.H3,
36
+ 'L1': constants.CDR.L1,
37
+ 'L2': constants.CDR.L2,
38
+ 'L3': constants.CDR.L3,
39
+ 'H_CDR1': constants.CDR.H1,
40
+ 'H_CDR2': constants.CDR.H2,
41
+ 'H_CDR3': constants.CDR.H3,
42
+ 'L_CDR1': constants.CDR.L1,
43
+ 'L_CDR2': constants.CDR.L2,
44
+ 'L_CDR3': constants.CDR.L3,
45
+ 'CDR3': 'CDR3', # H3 first, then fallback to L3
46
+ }
47
+ assert selection is None or selection in cdr_str_to_enum
48
+ self.selection = cdr_str_to_enum.get(selection, None)
49
+ self.augmentation = augmentation
50
+
51
+ def perform_masking_(self, data, selection=None):
52
+ cdr_flag = data['cdr_flag']
53
+
54
+ if selection is None:
55
+ cdr_all = cdr_flag[cdr_flag > 0].unique().tolist()
56
+ cdr_to_mask = random.choice(cdr_all)
57
+ else:
58
+ cdr_to_mask = selection
59
+
60
+ cdr_to_mask_flag = (cdr_flag == cdr_to_mask)
61
+ if self.augmentation:
62
+ cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag)
63
+
64
+ cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag)
65
+ left_idx = max(0, cdr_first-1)
66
+ right_idx = min(data['aa'].size(0)-1, cdr_last+1)
67
+ anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool)
68
+ anchor_flag[left_idx] = True
69
+ anchor_flag[right_idx] = True
70
+
71
+ data['generate_flag'] = cdr_to_mask_flag
72
+ data['anchor_flag'] = anchor_flag
73
+
74
+ def __call__(self, structure):
75
+ if self.selection is None:
76
+ ab_data = []
77
+ if structure['heavy'] is not None:
78
+ ab_data.append(structure['heavy'])
79
+ if structure['light'] is not None:
80
+ ab_data.append(structure['light'])
81
+ data_to_mask = random.choice(ab_data)
82
+ sel = None
83
+ elif self.selection in (constants.CDR.H1, constants.CDR.H2, constants.CDR.H3, ):
84
+ data_to_mask = structure['heavy']
85
+ sel = int(self.selection)
86
+ elif self.selection in (constants.CDR.L1, constants.CDR.L2, constants.CDR.L3, ):
87
+ data_to_mask = structure['light']
88
+ sel = int(self.selection)
89
+ elif self.selection == 'CDR3':
90
+ if structure['heavy'] is not None:
91
+ data_to_mask = structure['heavy']
92
+ sel = constants.CDR.H3
93
+ else:
94
+ data_to_mask = structure['light']
95
+ sel = constants.CDR.L3
96
+
97
+ self.perform_masking_(data_to_mask, selection=sel)
98
+ return structure
99
+
100
+
101
+ @register_transform('mask_multiple_cdrs')
102
+ class MaskMultipleCDRs(object):
103
+
104
+ def __init__(self, selection: Optional[List[str]]=None, augmentation=True):
105
+ super().__init__()
106
+ cdr_str_to_enum = {
107
+ 'H1': constants.CDR.H1,
108
+ 'H2': constants.CDR.H2,
109
+ 'H3': constants.CDR.H3,
110
+ 'L1': constants.CDR.L1,
111
+ 'L2': constants.CDR.L2,
112
+ 'L3': constants.CDR.L3,
113
+ 'H_CDR1': constants.CDR.H1,
114
+ 'H_CDR2': constants.CDR.H2,
115
+ 'H_CDR3': constants.CDR.H3,
116
+ 'L_CDR1': constants.CDR.L1,
117
+ 'L_CDR2': constants.CDR.L2,
118
+ 'L_CDR3': constants.CDR.L3,
119
+ }
120
+ if selection is not None:
121
+ self.selection = [cdr_str_to_enum[s] for s in selection]
122
+ else:
123
+ self.selection = None
124
+ self.augmentation = augmentation
125
+
126
+ def mask_one_cdr_(self, data, cdr_to_mask):
127
+ cdr_flag = data['cdr_flag']
128
+
129
+ cdr_to_mask_flag = (cdr_flag == cdr_to_mask)
130
+ if self.augmentation:
131
+ cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag)
132
+
133
+ cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag)
134
+ left_idx = max(0, cdr_first-1)
135
+ right_idx = min(data['aa'].size(0)-1, cdr_last+1)
136
+ anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool)
137
+ anchor_flag[left_idx] = True
138
+ anchor_flag[right_idx] = True
139
+
140
+ if 'generate_flag' not in data:
141
+ data['generate_flag'] = cdr_to_mask_flag
142
+ data['anchor_flag'] = anchor_flag
143
+ else:
144
+ data['generate_flag'] |= cdr_to_mask_flag
145
+ data['anchor_flag'] |= anchor_flag
146
+
147
+ def mask_for_one_chain_(self, data):
148
+ cdr_flag = data['cdr_flag']
149
+ cdr_all = cdr_flag[cdr_flag > 0].unique().tolist()
150
+
151
+ num_cdrs_to_mask = random.randint(1, len(cdr_all))
152
+
153
+ if self.selection is not None:
154
+ cdrs_to_mask = list(set(cdr_all).intersection(self.selection))
155
+ else:
156
+ random.shuffle(cdr_all)
157
+ cdrs_to_mask = cdr_all[:num_cdrs_to_mask]
158
+
159
+ for cdr_to_mask in cdrs_to_mask:
160
+ self.mask_one_cdr_(data, cdr_to_mask)
161
+
162
+ def __call__(self, structure):
163
+ if structure['heavy'] is not None:
164
+ self.mask_for_one_chain_(structure['heavy'])
165
+ if structure['light'] is not None:
166
+ self.mask_for_one_chain_(structure['light'])
167
+ return structure
168
+
169
+
170
+ @register_transform('mask_antibody')
171
+ class MaskAntibody(object):
172
+
173
+ def mask_ab_chain_(self, data):
174
+ data['generate_flag'] = torch.ones(data['aa'].shape, dtype=torch.bool)
175
+
176
+ def __call__(self, structure):
177
+ pos_ab_alpha = []
178
+ if structure['heavy'] is not None:
179
+ self.mask_ab_chain_(structure['heavy'])
180
+ pos_ab_alpha.append(
181
+ structure['heavy']['pos_heavyatom'][:, constants.BBHeavyAtom.CA]
182
+ )
183
+ if structure['light'] is not None:
184
+ self.mask_ab_chain_(structure['light'])
185
+ pos_ab_alpha.append(
186
+ structure['light']['pos_heavyatom'][:, constants.BBHeavyAtom.CA]
187
+ )
188
+ pos_ab_alpha = torch.cat(pos_ab_alpha, dim=0) # (L_Ab, 3)
189
+
190
+ if structure['antigen'] is not None:
191
+ pos_ag_alpha = structure['antigen']['pos_heavyatom'][:, constants.BBHeavyAtom.CA]
192
+ ag_ab_dist = torch.cdist(pos_ag_alpha, pos_ab_alpha) # (L_Ag, L_Ab)
193
+ nn_ab_dist = ag_ab_dist.min(dim=1)[0] # (L_Ag)
194
+ contact_flag = (nn_ab_dist <= 6.0) # (L_Ag)
195
+ if contact_flag.sum().item() == 0:
196
+ contact_flag[nn_ab_dist.argmin()] = True
197
+
198
+ anchor_idx = torch.multinomial(contact_flag.float(), num_samples=1).item()
199
+ anchor_flag = torch.zeros(structure['antigen']['aa'].shape, dtype=torch.bool)
200
+ anchor_flag[anchor_idx] = True
201
+ structure['antigen']['anchor_flag'] = anchor_flag
202
+ structure['antigen']['contact_flag'] = contact_flag
203
+
204
+ return structure
205
+
206
+
207
+ @register_transform('remove_antigen')
208
+ class RemoveAntigen:
209
+
210
+ def __call__(self, structure):
211
+ structure['antigen'] = None
212
+ structure['antigen_seqmap'] = None
213
+ return structure
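The flag helpers at the top of this file are easiest to see on a toy flag. One caveat worth knowing: random_shrink_extend only ever adds positions, because flag.clone() keeps the original True entries even when a shrunken range is drawn.

    import torch
    from diffab.utils.transforms.mask import continuous_flag_to_range, random_shrink_extend

    flag = torch.zeros(10, dtype=torch.bool)
    flag[2:7] = True                        # a 5-residue "CDR"
    print(continuous_flag_to_range(flag))   # (2, 6): first and last True positions

    # length - 2*shrink_limit = 3 < min_length, so shrinking is disabled here and the
    # span can only grow by up to extend_limit=2 per side, clipped to the boundaries.
    ext = random_shrink_extend(flag, min_length=5)
    print(ext.sum().item())                 # between 5 and 9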
diffab/utils/transforms/merge.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch
2
+
3
+ from ..protein import constants
4
+ from ._base import register_transform
5
+
6
+
7
+ @register_transform('merge_chains')
8
+ class MergeChains(object):
9
+
10
+ def __init__(self):
11
+ super().__init__()
12
+
13
+ def assign_chain_number_(self, data_list):
14
+ chains = set()
15
+ for data in data_list:
16
+ chains.update(data['chain_id'])
17
+ chains = {c: i for i, c in enumerate(chains)}
18
+
19
+ for data in data_list:
20
+ data['chain_nb'] = torch.LongTensor([
21
+ chains[c] for c in data['chain_id']
22
+ ])
23
+
24
+ def _data_attr(self, data, name):
25
+ if name in ('generate_flag', 'anchor_flag') and name not in data:
26
+ return torch.zeros(data['aa'].shape, dtype=torch.bool)
27
+ else:
28
+ return data[name]
29
+
30
+ def __call__(self, structure):
31
+ data_list = []
32
+ if structure['heavy'] is not None:
33
+ structure['heavy']['fragment_type'] = torch.full_like(
34
+ structure['heavy']['aa'],
35
+ fill_value = constants.Fragment.Heavy,
36
+ )
37
+ data_list.append(structure['heavy'])
38
+
39
+ if structure['light'] is not None:
40
+ structure['light']['fragment_type'] = torch.full_like(
41
+ structure['light']['aa'],
42
+ fill_value = constants.Fragment.Light,
43
+ )
44
+ data_list.append(structure['light'])
45
+
46
+ if structure['antigen'] is not None:
47
+ structure['antigen']['fragment_type'] = torch.full_like(
48
+ structure['antigen']['aa'],
49
+ fill_value = constants.Fragment.Antigen,
50
+ )
51
+ structure['antigen']['cdr_flag'] = torch.zeros_like(
52
+ structure['antigen']['aa'],
53
+ )
54
+ data_list.append(structure['antigen'])
55
+
56
+ self.assign_chain_number_(data_list)
57
+
58
+ list_props = {
59
+ 'chain_id': [],
60
+ 'icode': [],
61
+ }
62
+ tensor_props = {
63
+ 'chain_nb': [],
64
+ 'resseq': [],
65
+ 'res_nb': [],
66
+ 'aa': [],
67
+ 'pos_heavyatom': [],
68
+ 'mask_heavyatom': [],
69
+ 'generate_flag': [],
70
+ 'cdr_flag': [],
71
+ 'anchor_flag': [],
72
+ 'fragment_type': [],
73
+ }
74
+
75
+ for data in data_list:
76
+ for k in list_props.keys():
77
+ list_props[k].append(self._data_attr(data, k))
78
+ for k in tensor_props.keys():
79
+ tensor_props[k].append(self._data_attr(data, k))
80
+
81
+ list_props = {k: sum(v, start=[]) for k, v in list_props.items()}
82
+ tensor_props = {k: torch.cat(v, dim=0) for k, v in tensor_props.items()}
83
+ data_out = {
84
+ **list_props,
85
+ **tensor_props,
86
+ }
87
+ return data_out
88
+
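Putting the transforms together, a hedged end-to-end sketch of the preprocessing pipeline (the structure dict with 'heavy'/'light'/'antigen' entries is assumed to come from the dataset loaders added elsewhere in this commit):

    from diffab.utils.transforms import MaskSingleCDR, MergeChains, Compose

    transform = Compose([
        MaskSingleCDR(selection='H_CDR3', augmentation=False),  # flag CDR-H3 and its anchors
        MergeChains(),  # concatenate heavy/light/antigen into one residue-indexed dict
    ])
    # structure = {'heavy': ..., 'light': ..., 'antigen': ..., 'antigen_seqmap': ...}
    # data = transform(structure)
    # data['generate_flag'] then marks the CDR-H3 residues in the merged arrays.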