diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1066f07e2ce9288a5462c3a39ef33ed7badc4b1d --- /dev/null +++ b/.gitignore @@ -0,0 +1,134 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +/playgrounds* +/logs* +/results* +/*.csv diff --git a/app.py b/app.py index 14cf838439007917f9d881f31c984a28bb9970b9..355449caad6b400c42ed1c8cfccbd362e9128636 100644 --- a/app.py +++ b/app.py @@ -24,8 +24,6 @@ from diffab.tools.renumber.run import ( assign_number_to_sequence, ) -DIFFAB_DIR = os.path.realpath('./diffab-repo') - CDR_OPTIONS = OrderedDict() CDR_OPTIONS['H_CDR1'] = 'H1' CDR_OPTIONS['H_CDR2'] = 'H2' @@ -96,7 +94,7 @@ def run_design(pdb_path, config_path, output_dir, docking, display_widget, num_d bufsize=1, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - cwd=DIFFAB_DIR, + cwd=os.getcwd(), ) for line in iter(proc.stdout.readline, b''): output_buffer += line.decode() diff --git a/design_dock.py b/design_dock.py new file mode 100644 index 0000000000000000000000000000000000000000..7c7d1c8e84b234bdb385ea015d675337bd179f57 --- /dev/null +++ b/design_dock.py @@ -0,0 +1,67 @@ +import os +import shutil +import argparse +from diffab.tools.dock.hdock import HDockAntibody +from diffab.tools.runner.design_for_pdb import args_factory, design_for_pdb + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--antigen', type=str, required=True) + parser.add_argument('--antibody', type=str, default='./data/examples/3QHF_Fv.pdb') + parser.add_argument('--heavy', type=str, default='H', help='Chain id of the heavy chain.') + parser.add_argument('--light', type=str, default='L', help='Chain id of the 
light chain.') + parser.add_argument('--hdock_bin', type=str, default='./bin/hdock') + parser.add_argument('--createpl_bin', type=str, default='./bin/createpl') + parser.add_argument('-n', '--num_docks', type=int, default=10) + parser.add_argument('-c', '--config', type=str, default='./configs/test/codesign_single.yml') + parser.add_argument('-o', '--out_root', type=str, default='./results') + parser.add_argument('-t', '--tag', type=str, default='') + parser.add_argument('-s', '--seed', type=int, default=None) + parser.add_argument('-d', '--device', type=str, default='cuda') + parser.add_argument('-b', '--batch_size', type=int, default=16) + args = parser.parse_args() + + hdock_missing = [] + if not os.path.exists(args.hdock_bin): + hdock_missing.append(args.hdock_bin) + if not os.path.exists(args.createpl_bin): + hdock_missing.append(args.createpl_bin) + if len(hdock_missing) > 0: + print("[WARNING] The following HDOCK applications are missing:") + for f in hdock_missing: + print(f" > {f}") + print("Please download HDOCK from http://huanglab.phys.hust.edu.cn/software/hdocklite/ " + "and put `hdock` and `createpl` at the above paths.") + exit() + + antigen_name = os.path.basename(os.path.splitext(args.antigen)[0]) + docked_pdb_dir = os.path.join(os.path.splitext(args.antigen)[0] + '_dock') + os.makedirs(docked_pdb_dir, exist_ok=True) + docked_pdb_paths = [] + for fname in os.listdir(docked_pdb_dir): + if fname.endswith('.pdb'): + docked_pdb_paths.append(os.path.join(docked_pdb_dir, fname)) + if len(docked_pdb_paths) < args.num_docks: + with HDockAntibody() as dock_session: + dock_session.set_antigen(args.antigen) + dock_session.set_antibody(args.antibody) + docked_tmp_paths = dock_session.dock() + for i, tmp_path in enumerate(docked_tmp_paths[:args.num_docks]): + dest_path = os.path.join(docked_pdb_dir, f"{antigen_name}_Ab_{i:04d}.pdb") + shutil.copyfile(tmp_path, dest_path) + print(f'[INFO] Copy {tmp_path} -> {dest_path}') + docked_pdb_paths.append(dest_path) + + current_args = dict(vars(args)) + current_args['tag'] += antigen_name + for pdb_path in docked_pdb_paths: + design_args = args_factory( + pdb_path = pdb_path, + **current_args, + ) + design_for_pdb(design_args) + + +if __name__ == '__main__': + main() diff --git a/design_pdb.py b/design_pdb.py new file mode 100644 index 0000000000000000000000000000000000000000..02c5070027cd007b005e77043a49134706fc6b01 --- /dev/null +++ b/design_pdb.py @@ -0,0 +1,4 @@ +from diffab.tools.runner.design_for_pdb import args_from_cmdline, design_for_pdb + +if __name__ == '__main__': + design_for_pdb(args_from_cmdline()) diff --git a/design_testset.py b/design_testset.py new file mode 100644 index 0000000000000000000000000000000000000000..63b6008200fdd26c88a36411e7afca1bf03ddc3e --- /dev/null +++ b/design_testset.py @@ -0,0 +1,4 @@ +from diffab.tools.runner.design_for_testset import main + +if __name__ == '__main__': + main() diff --git a/diffab/datasets/__init__.py b/diffab/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d959a9e0a4e13d029c6b2c5e52340d090435adc --- /dev/null +++ b/diffab/datasets/__init__.py @@ -0,0 +1,4 @@ +from .sabdab import SAbDabDataset +from .custom import CustomDataset + +from ._base import get_dataset diff --git a/diffab/datasets/_base.py b/diffab/datasets/_base.py new file mode 100644 index 0000000000000000000000000000000000000000..b2cac9ed9ad6591eb7e1240f2e28da1791999431 --- /dev/null +++ b/diffab/datasets/_base.py @@ -0,0 +1,40 @@ +from torch.utils.data import Dataset, ConcatDataset +from 
diffab.utils.transforms import get_transform + + +_DATASET_DICT = {} + + +def register_dataset(name): + def decorator(cls): + _DATASET_DICT[name] = cls + return cls + return decorator + + +def get_dataset(cfg): + transform = get_transform(cfg.transform) if 'transform' in cfg else None + return _DATASET_DICT[cfg.type](cfg, transform=transform) + + +@register_dataset('concat') +def get_concat_dataset(cfg): + datasets = [get_dataset(d) for d in cfg.datasets] + return ConcatDataset(datasets) + + +@register_dataset('balanced_concat') +class BalancedConcatDataset(Dataset): + + def __init__(self, cfg, transform=None): + super().__init__() + assert transform is None, 'transform is not supported.' + self.datasets = [get_dataset(d) for d in cfg.datasets] + self.max_size = max([len(d) for d in self.datasets]) + + def __len__(self): + return self.max_size * len(self.datasets) + + def __getitem__(self, idx): + dataset_idx = idx // self.max_size + return self.datasets[dataset_idx][idx % len(self.datasets[dataset_idx])] diff --git a/diffab/datasets/custom.py b/diffab/datasets/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..2c2ebff7200fa5225fb341202728c053067cc83c --- /dev/null +++ b/diffab/datasets/custom.py @@ -0,0 +1,200 @@ +import os +import logging +import joblib +import pickle +import lmdb +from Bio import PDB +from Bio.PDB import PDBExceptions +from torch.utils.data import Dataset +from tqdm.auto import tqdm + +from ..utils.protein import parsers +from .sabdab import _label_heavy_chain_cdr, _label_light_chain_cdr +from ._base import register_dataset + + +def preprocess_antibody_structure(task): + pdb_path = task['pdb_path'] + H_id = task.get('heavy_id', 'H') + L_id = task.get('light_id', 'L') + + parser = PDB.PDBParser(QUIET=True) + model = parser.get_structure(id, pdb_path)[0] + + all_chain_ids = [c.id for c in model] + + parsed = { + 'id': task['id'], + 'heavy': None, + 'heavy_seqmap': None, + 'light': None, + 'light_seqmap': None, + 'antigen': None, + 'antigen_seqmap': None, + } + try: + if H_id in all_chain_ids: + ( + parsed['heavy'], + parsed['heavy_seqmap'] + ) = _label_heavy_chain_cdr(*parsers.parse_biopython_structure( + model[H_id], + max_resseq = 113 # Chothia, end of Heavy chain Fv + )) + + if L_id in all_chain_ids: + ( + parsed['light'], + parsed['light_seqmap'] + ) = _label_light_chain_cdr(*parsers.parse_biopython_structure( + model[L_id], + max_resseq = 106 # Chothia, end of Light chain Fv + )) + + if parsed['heavy'] is None and parsed['light'] is None: + raise ValueError( + f'Neither valid antibody H-chain or L-chain is found. ' + f'Please ensure that the chain id of heavy chain is "{H_id}" ' + f'and the id of the light chain is "{L_id}".' 
+ ) + + + ag_chain_ids = [cid for cid in all_chain_ids if cid not in (H_id, L_id)] + if len(ag_chain_ids) > 0: + chains = [model[c] for c in ag_chain_ids] + ( + parsed['antigen'], + parsed['antigen_seqmap'] + ) = parsers.parse_biopython_structure(chains) + + except ( + PDBExceptions.PDBConstructionException, + parsers.ParsingException, + KeyError, + ValueError, + ) as e: + logging.warning('[{}] {}: {}'.format( + task['id'], + e.__class__.__name__, + str(e) + )) + return None + + return parsed + + +@register_dataset('custom') +class CustomDataset(Dataset): + + MAP_SIZE = 32*(1024*1024*1024) # 32GB + + def __init__(self, structure_dir, transform=None, reset=False): + super().__init__() + self.structure_dir = structure_dir + self.transform = transform + + self.db_conn = None + self.db_ids = None + self._load_structures(reset) + + @property + def _cache_db_path(self): + return os.path.join(self.structure_dir, 'structure_cache.lmdb') + + def _connect_db(self): + self._close_db() + self.db_conn = lmdb.open( + self._cache_db_path, + map_size=self.MAP_SIZE, + create=False, + subdir=False, + readonly=True, + lock=False, + readahead=False, + meminit=False, + ) + with self.db_conn.begin() as txn: + keys = [k.decode() for k in txn.cursor().iternext(values=False)] + self.db_ids = keys + + def _close_db(self): + if self.db_conn is not None: + self.db_conn.close() + self.db_conn = None + self.db_ids = None + + def _load_structures(self, reset): + all_pdbs = [] + for fname in os.listdir(self.structure_dir): + if not fname.endswith('.pdb'): continue + all_pdbs.append(fname) + + if reset or not os.path.exists(self._cache_db_path): + todo_pdbs = all_pdbs + else: + self._connect_db() + processed_pdbs = self.db_ids + self._close_db() + todo_pdbs = list(set(all_pdbs) - set(processed_pdbs)) + + if len(todo_pdbs) > 0: + self._preprocess_structures(todo_pdbs) + + def _preprocess_structures(self, pdb_list): + tasks = [] + for pdb_fname in pdb_list: + pdb_path = os.path.join(self.structure_dir, pdb_fname) + tasks.append({ + 'id': pdb_fname, + 'pdb_path': pdb_path, + }) + + data_list = joblib.Parallel( + n_jobs = max(joblib.cpu_count() // 2, 1), + )( + joblib.delayed(preprocess_antibody_structure)(task) + for task in tqdm(tasks, dynamic_ncols=True, desc='Preprocess') + ) + + db_conn = lmdb.open( + self._cache_db_path, + map_size = self.MAP_SIZE, + create=True, + subdir=False, + readonly=False, + ) + ids = [] + with db_conn.begin(write=True, buffers=True) as txn: + for data in tqdm(data_list, dynamic_ncols=True, desc='Write to LMDB'): + if data is None: + continue + ids.append(data['id']) + txn.put(data['id'].encode('utf-8'), pickle.dumps(data)) + + def __len__(self): + return len(self.db_ids) + + def __getitem__(self, index): + self._connect_db() + id = self.db_ids[index] + with self.db_conn.begin() as txn: + data = pickle.loads(txn.get(id.encode())) + if self.transform is not None: + data = self.transform(data) + return data + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--dir', type=str, default='./data/custom') + parser.add_argument('--reset', action='store_true', default=False) + args = parser.parse_args() + + dataset = CustomDataset( + structure_dir = args.dir, + reset = args.reset, + ) + print(dataset[0]) + print(len(dataset)) + \ No newline at end of file diff --git a/diffab/datasets/sabdab.py b/diffab/datasets/sabdab.py new file mode 100644 index 0000000000000000000000000000000000000000..4b96c1b111fd620c141519c0a346665e2afe4576 --- /dev/null +++ 
b/diffab/datasets/sabdab.py @@ -0,0 +1,470 @@ +import os +import random +import logging +import datetime +import pandas as pd +import joblib +import pickle +import lmdb +import subprocess +import torch +from Bio import PDB, SeqRecord, SeqIO, Seq +from Bio.PDB import PDBExceptions +from Bio.PDB import Polypeptide +from torch.utils.data import Dataset +from tqdm.auto import tqdm + +from ..utils.protein import parsers, constants +from ._base import register_dataset + + +ALLOWED_AG_TYPES = { + 'protein', + 'protein | protein', + 'protein | protein | protein', + 'protein | protein | protein | protein | protein', + 'protein | protein | protein | protein', +} + +RESOLUTION_THRESHOLD = 4.0 + +TEST_ANTIGENS = [ + 'sars-cov-2 receptor binding domain', + 'hiv-1 envelope glycoprotein gp160', + 'mers s', + 'influenza a virus', + 'cd27 antigen', +] + + +def nan_to_empty_string(val): + if val != val or not val: + return '' + else: + return val + + +def nan_to_none(val): + if val != val or not val: + return None + else: + return val + + +def split_sabdab_delimited_str(val): + if not val: + return [] + else: + return [s.strip() for s in val.split('|')] + + +def parse_sabdab_resolution(val): + if val == 'NOT' or not val or val != val: + return None + elif isinstance(val, str) and ',' in val: + return float(val.split(',')[0].strip()) + else: + return float(val) + + +def _aa_tensor_to_sequence(aa): + return ''.join([Polypeptide.index_to_one(a.item()) for a in aa.flatten()]) + + +def _label_heavy_chain_cdr(data, seq_map, max_cdr3_length=30): + if data is None or seq_map is None: + return data, seq_map + + # Add CDR labels + cdr_flag = torch.zeros_like(data['aa']) + for position, idx in seq_map.items(): + resseq = position[1] + cdr_type = constants.ChothiaCDRRange.to_cdr('H', resseq) + if cdr_type is not None: + cdr_flag[idx] = cdr_type + data['cdr_flag'] = cdr_flag + + # Add CDR sequence annotations + data['H1_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.H1] ) + data['H2_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.H2] ) + data['H3_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.H3] ) + + cdr3_length = (cdr_flag == constants.CDR.H3).sum().item() + # Remove too long CDR3 + if cdr3_length > max_cdr3_length: + cdr_flag[cdr_flag == constants.CDR.H3] = 0 + logging.warning(f'CDR-H3 too long {cdr3_length}. Removed.') + return None, None + + # Filter: ensure CDR3 exists + if cdr3_length == 0: + logging.warning('No CDR-H3 found in the heavy chain.') + return None, None + + return data, seq_map + + +def _label_light_chain_cdr(data, seq_map, max_cdr3_length=30): + if data is None or seq_map is None: + return data, seq_map + cdr_flag = torch.zeros_like(data['aa']) + for position, idx in seq_map.items(): + resseq = position[1] + cdr_type = constants.ChothiaCDRRange.to_cdr('L', resseq) + if cdr_type is not None: + cdr_flag[idx] = cdr_type + data['cdr_flag'] = cdr_flag + + data['L1_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.L1] ) + data['L2_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.L2] ) + data['L3_seq'] = _aa_tensor_to_sequence( data['aa'][cdr_flag == constants.CDR.L3] ) + + cdr3_length = (cdr_flag == constants.CDR.L3).sum().item() + # Remove too long CDR3 + if cdr3_length > max_cdr3_length: + cdr_flag[cdr_flag == constants.CDR.L3] = 0 + logging.warning(f'CDR-L3 too long {cdr3_length}. 
Removed.') + return None, None + + # Ensure CDR3 exists + if cdr3_length == 0: + logging.warning('No CDRs found in the light chain.') + return None, None + + return data, seq_map + + +def preprocess_sabdab_structure(task): + entry = task['entry'] + pdb_path = task['pdb_path'] + + parser = PDB.PDBParser(QUIET=True) + model = parser.get_structure(id, pdb_path)[0] + + parsed = { + 'id': entry['id'], + 'heavy': None, + 'heavy_seqmap': None, + 'light': None, + 'light_seqmap': None, + 'antigen': None, + 'antigen_seqmap': None, + } + try: + if entry['H_chain'] is not None: + ( + parsed['heavy'], + parsed['heavy_seqmap'] + ) = _label_heavy_chain_cdr(*parsers.parse_biopython_structure( + model[entry['H_chain']], + max_resseq = 113 # Chothia, end of Heavy chain Fv + )) + + if entry['L_chain'] is not None: + ( + parsed['light'], + parsed['light_seqmap'] + ) = _label_light_chain_cdr(*parsers.parse_biopython_structure( + model[entry['L_chain']], + max_resseq = 106 # Chothia, end of Light chain Fv + )) + + if parsed['heavy'] is None and parsed['light'] is None: + raise ValueError('Neither valid H-chain or L-chain is found.') + + if len(entry['ag_chains']) > 0: + chains = [model[c] for c in entry['ag_chains']] + ( + parsed['antigen'], + parsed['antigen_seqmap'] + ) = parsers.parse_biopython_structure(chains) + + except ( + PDBExceptions.PDBConstructionException, + parsers.ParsingException, + KeyError, + ValueError, + ) as e: + logging.warning('[{}] {}: {}'.format( + task['id'], + e.__class__.__name__, + str(e) + )) + return None + + return parsed + + +class SAbDabDataset(Dataset): + + MAP_SIZE = 32*(1024*1024*1024) # 32GB + + def __init__( + self, + summary_path = './data/sabdab_summary_all.tsv', + chothia_dir = './data/all_structures/chothia', + processed_dir = './data/processed', + split = 'train', + split_seed = 2022, + transform = None, + reset = False, + ): + super().__init__() + self.summary_path = summary_path + self.chothia_dir = chothia_dir + if not os.path.exists(chothia_dir): + raise FileNotFoundError( + f"SAbDab structures not found in {chothia_dir}. 
" + "Please download them from http://opig.stats.ox.ac.uk/webapps/newsabdab/sabdab/archive/all/" + ) + self.processed_dir = processed_dir + os.makedirs(processed_dir, exist_ok=True) + + self.sabdab_entries = None + self._load_sabdab_entries() + + self.db_conn = None + self.db_ids = None + self._load_structures(reset) + + self.clusters = None + self.id_to_cluster = None + self._load_clusters(reset) + + self.ids_in_split = None + self._load_split(split, split_seed) + + self.transform = transform + + def _load_sabdab_entries(self): + df = pd.read_csv(self.summary_path, sep='\t') + entries_all = [] + for i, row in tqdm( + df.iterrows(), + dynamic_ncols=True, + desc='Loading entries', + total=len(df), + ): + entry_id = "{pdbcode}_{H}_{L}_{Ag}".format( + pdbcode = row['pdb'], + H = nan_to_empty_string(row['Hchain']), + L = nan_to_empty_string(row['Lchain']), + Ag = ''.join(split_sabdab_delimited_str( + nan_to_empty_string(row['antigen_chain']) + )) + ) + ag_chains = split_sabdab_delimited_str( + nan_to_empty_string(row['antigen_chain']) + ) + resolution = parse_sabdab_resolution(row['resolution']) + entry = { + 'id': entry_id, + 'pdbcode': row['pdb'], + 'H_chain': nan_to_none(row['Hchain']), + 'L_chain': nan_to_none(row['Lchain']), + 'ag_chains': ag_chains, + 'ag_type': nan_to_none(row['antigen_type']), + 'ag_name': nan_to_none(row['antigen_name']), + 'date': datetime.datetime.strptime(row['date'], '%m/%d/%y'), + 'resolution': resolution, + 'method': row['method'], + 'scfv': row['scfv'], + } + + # Filtering + if ( + (entry['ag_type'] in ALLOWED_AG_TYPES or entry['ag_type'] is None) + and (entry['resolution'] is not None and entry['resolution'] <= RESOLUTION_THRESHOLD) + ): + entries_all.append(entry) + self.sabdab_entries = entries_all + + def _load_structures(self, reset): + if not os.path.exists(self._structure_cache_path) or reset: + if os.path.exists(self._structure_cache_path): + os.unlink(self._structure_cache_path) + self._preprocess_structures() + + with open(self._structure_cache_path + '-ids', 'rb') as f: + self.db_ids = pickle.load(f) + self.sabdab_entries = list( + filter( + lambda e: e['id'] in self.db_ids, + self.sabdab_entries + ) + ) + + @property + def _structure_cache_path(self): + return os.path.join(self.processed_dir, 'structures.lmdb') + + def _preprocess_structures(self): + tasks = [] + for entry in self.sabdab_entries: + pdb_path = os.path.join(self.chothia_dir, '{}.pdb'.format(entry['pdbcode'])) + if not os.path.exists(pdb_path): + logging.warning(f"PDB not found: {pdb_path}") + continue + tasks.append({ + 'id': entry['id'], + 'entry': entry, + 'pdb_path': pdb_path, + }) + + data_list = joblib.Parallel( + n_jobs = max(joblib.cpu_count() // 2, 1), + )( + joblib.delayed(preprocess_sabdab_structure)(task) + for task in tqdm(tasks, dynamic_ncols=True, desc='Preprocess') + ) + + db_conn = lmdb.open( + self._structure_cache_path, + map_size = self.MAP_SIZE, + create=True, + subdir=False, + readonly=False, + ) + ids = [] + with db_conn.begin(write=True, buffers=True) as txn: + for data in tqdm(data_list, dynamic_ncols=True, desc='Write to LMDB'): + if data is None: + continue + ids.append(data['id']) + txn.put(data['id'].encode('utf-8'), pickle.dumps(data)) + + with open(self._structure_cache_path + '-ids', 'wb') as f: + pickle.dump(ids, f) + + @property + def _cluster_path(self): + return os.path.join(self.processed_dir, 'cluster_result_cluster.tsv') + + def _load_clusters(self, reset): + if not os.path.exists(self._cluster_path) or reset: + self._create_clusters() + + 
clusters, id_to_cluster = {}, {} + with open(self._cluster_path, 'r') as f: + for line in f.readlines(): + cluster_name, data_id = line.split() + if cluster_name not in clusters: + clusters[cluster_name] = [] + clusters[cluster_name].append(data_id) + id_to_cluster[data_id] = cluster_name + self.clusters = clusters + self.id_to_cluster = id_to_cluster + + def _create_clusters(self): + cdr_records = [] + for id in self.db_ids: + structure = self.get_structure(id) + if structure['heavy'] is not None: + cdr_records.append(SeqRecord.SeqRecord( + Seq.Seq(structure['heavy']['H3_seq']), + id = structure['id'], + name = '', + description = '', + )) + elif structure['light'] is not None: + cdr_records.append(SeqRecord.SeqRecord( + Seq.Seq(structure['light']['L3_seq']), + id = structure['id'], + name = '', + description = '', + )) + fasta_path = os.path.join(self.processed_dir, 'cdr_sequences.fasta') + SeqIO.write(cdr_records, fasta_path, 'fasta') + + cmd = ' '.join([ + 'mmseqs', 'easy-cluster', + os.path.realpath(fasta_path), + 'cluster_result', 'cluster_tmp', + '--min-seq-id', '0.5', + '-c', '0.8', + '--cov-mode', '1', + ]) + subprocess.run(cmd, cwd=self.processed_dir, shell=True, check=True) + + def _load_split(self, split, split_seed): + assert split in ('train', 'val', 'test') + ids_test = [ + entry['id'] + for entry in self.sabdab_entries + if entry['ag_name'] in TEST_ANTIGENS + ] + test_relevant_clusters = set([self.id_to_cluster[id] for id in ids_test]) + + ids_train_val = [ + entry['id'] + for entry in self.sabdab_entries + if self.id_to_cluster[entry['id']] not in test_relevant_clusters + ] + random.Random(split_seed).shuffle(ids_train_val) + if split == 'test': + self.ids_in_split = ids_test + elif split == 'val': + self.ids_in_split = ids_train_val[:20] + else: + self.ids_in_split = ids_train_val[20:] + + def _connect_db(self): + if self.db_conn is not None: + return + self.db_conn = lmdb.open( + self._structure_cache_path, + map_size=self.MAP_SIZE, + create=False, + subdir=False, + readonly=True, + lock=False, + readahead=False, + meminit=False, + ) + + def get_structure(self, id): + self._connect_db() + with self.db_conn.begin() as txn: + return pickle.loads(txn.get(id.encode())) + + def __len__(self): + return len(self.ids_in_split) + + def __getitem__(self, index): + id = self.ids_in_split[index] + data = self.get_structure(id) + if self.transform is not None: + data = self.transform(data) + return data + + +@register_dataset('sabdab') +def get_sabdab_dataset(cfg, transform): + return SAbDabDataset( + summary_path = cfg.summary_path, + chothia_dir = cfg.chothia_dir, + processed_dir = cfg.processed_dir, + split = cfg.split, + split_seed = cfg.get('split_seed', 2022), + transform = transform, + ) + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--split', type=str, default='train') + parser.add_argument('--processed_dir', type=str, default='./data/processed') + parser.add_argument('--reset', action='store_true', default=False) + args = parser.parse_args() + if args.reset: + sure = input('Sure to reset? 
(y/n): ') + if sure != 'y': + exit() + dataset = SAbDabDataset( + processed_dir=args.processed_dir, + split=args.split, + reset=args.reset + ) + print(dataset[0]) + print(len(dataset), len(dataset.clusters)) diff --git a/diffab/models/__init__.py b/diffab/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a9322f11b84a68a0bc837ac6b0ef934151bcb1d --- /dev/null +++ b/diffab/models/__init__.py @@ -0,0 +1,3 @@ +from .diffab import DiffusionAntibodyDesign + +from ._base import get_model diff --git a/diffab/models/_base.py b/diffab/models/_base.py new file mode 100644 index 0000000000000000000000000000000000000000..90c96dedac77955005f32239c78c3e5ce67c94ee --- /dev/null +++ b/diffab/models/_base.py @@ -0,0 +1,13 @@ + +_MODEL_DICT = {} + + +def register_model(name): + def decorator(cls): + _MODEL_DICT[name] = cls + return cls + return decorator + + +def get_model(cfg): + return _MODEL_DICT[cfg.type](cfg) diff --git a/diffab/models/diffab.py b/diffab/models/diffab.py new file mode 100644 index 0000000000000000000000000000000000000000..82e68a967987ebcead961f85c8172b86841a3570 --- /dev/null +++ b/diffab/models/diffab.py @@ -0,0 +1,142 @@ +import torch +import torch.nn as nn + +from diffab.modules.common.geometry import construct_3d_basis +from diffab.modules.common.so3 import rotation_to_so3vec +from diffab.modules.encoders.residue import ResidueEmbedding +from diffab.modules.encoders.pair import PairEmbedding +from diffab.modules.diffusion.dpm_full import FullDPM +from diffab.utils.protein.constants import max_num_heavyatoms, BBHeavyAtom +from ._base import register_model + + +resolution_to_num_atoms = { + 'backbone+CB': 5, + 'full': max_num_heavyatoms +} + + +@register_model('diffab') +class DiffusionAntibodyDesign(nn.Module): + + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + + num_atoms = resolution_to_num_atoms[cfg.get('resolution', 'full')] + self.residue_embed = ResidueEmbedding(cfg.res_feat_dim, num_atoms) + self.pair_embed = PairEmbedding(cfg.pair_feat_dim, num_atoms) + + self.diffusion = FullDPM( + cfg.res_feat_dim, + cfg.pair_feat_dim, + **cfg.diffusion, + ) + + def encode(self, batch, remove_structure, remove_sequence): + """ + Returns: + res_feat: (N, L, res_feat_dim) + pair_feat: (N, L, L, pair_feat_dim) + """ + # This is used throughout embedding and encoding layers + # to avoid data leakage. 
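+ # When remove_structure / remove_sequence is set, only context residues are visible to the embedding layers below, so the ground-truth structure / sequence of the residues being generated does not leak into the features.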
+ context_mask = torch.logical_and( + batch['mask_heavyatom'][:, :, BBHeavyAtom.CA], + ~batch['generate_flag'] # Context means ``not generated'' + ) + + structure_mask = context_mask if remove_structure else None + sequence_mask = context_mask if remove_sequence else None + + res_feat = self.residue_embed( + aa = batch['aa'], + res_nb = batch['res_nb'], + chain_nb = batch['chain_nb'], + pos_atoms = batch['pos_heavyatom'], + mask_atoms = batch['mask_heavyatom'], + fragment_type = batch['fragment_type'], + structure_mask = structure_mask, + sequence_mask = sequence_mask, + ) + + pair_feat = self.pair_embed( + aa = batch['aa'], + res_nb = batch['res_nb'], + chain_nb = batch['chain_nb'], + pos_atoms = batch['pos_heavyatom'], + mask_atoms = batch['mask_heavyatom'], + structure_mask = structure_mask, + sequence_mask = sequence_mask, + ) + + R = construct_3d_basis( + batch['pos_heavyatom'][:, :, BBHeavyAtom.CA], + batch['pos_heavyatom'][:, :, BBHeavyAtom.C], + batch['pos_heavyatom'][:, :, BBHeavyAtom.N], + ) + p = batch['pos_heavyatom'][:, :, BBHeavyAtom.CA] + + return res_feat, pair_feat, R, p + + def forward(self, batch): + mask_generate = batch['generate_flag'] + mask_res = batch['mask'] + res_feat, pair_feat, R_0, p_0 = self.encode( + batch, + remove_structure = self.cfg.get('train_structure', True), + remove_sequence = self.cfg.get('train_sequence', True) + ) + v_0 = rotation_to_so3vec(R_0) + s_0 = batch['aa'] + + loss_dict = self.diffusion( + v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res, + denoise_structure = self.cfg.get('train_structure', True), + denoise_sequence = self.cfg.get('train_sequence', True), + ) + return loss_dict + + @torch.no_grad() + def sample( + self, + batch, + sample_opt={ + 'sample_structure': True, + 'sample_sequence': True, + } + ): + mask_generate = batch['generate_flag'] + mask_res = batch['mask'] + res_feat, pair_feat, R_0, p_0 = self.encode( + batch, + remove_structure = sample_opt.get('sample_structure', True), + remove_sequence = sample_opt.get('sample_sequence', True) + ) + v_0 = rotation_to_so3vec(R_0) + s_0 = batch['aa'] + traj = self.diffusion.sample(v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res, **sample_opt) + return traj + + @torch.no_grad() + def optimize( + self, + batch, + opt_step, + optimize_opt={ + 'sample_structure': True, + 'sample_sequence': True, + } + ): + mask_generate = batch['generate_flag'] + mask_res = batch['mask'] + res_feat, pair_feat, R_0, p_0 = self.encode( + batch, + remove_structure = optimize_opt.get('sample_structure', True), + remove_sequence = optimize_opt.get('sample_sequence', True) + ) + v_0 = rotation_to_so3vec(R_0) + s_0 = batch['aa'] + + traj = self.diffusion.optimize(v_0, p_0, s_0, opt_step, res_feat, pair_feat, mask_generate, mask_res, **optimize_opt) + return traj diff --git a/diffab/modules/common/geometry.py b/diffab/modules/common/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..80d39e64b7949afbd0ebe87af887a61e3117c670 --- /dev/null +++ b/diffab/modules/common/geometry.py @@ -0,0 +1,481 @@ +import torch +import torch.nn.functional as F + +from diffab.utils.protein.constants import ( + BBHeavyAtom, + backbone_atom_coordinates_tensor, + bb_oxygen_coordinate_tensor, +) +from .topology import get_terminus_flag + + +def safe_norm(x, dim=-1, keepdim=False, eps=1e-8, sqrt=True): + out = torch.clamp(torch.sum(torch.square(x), dim=dim, keepdim=keepdim), min=eps) + return torch.sqrt(out) if sqrt else out + + +def pairwise_distances(x, y=None, return_v=False): + """ 
+ Args: + x: (B, N, d) + y: (B, M, d) + """ + if y is None: y = x + v = x.unsqueeze(2) - y.unsqueeze(1) # (B, N, M, d) + d = safe_norm(v, dim=-1) + if return_v: + return d, v + else: + return d + + +def normalize_vector(v, dim, eps=1e-6): + return v / (torch.linalg.norm(v, ord=2, dim=dim, keepdim=True) + eps) + + +def project_v2v(v, e, dim): + """ + Description: + Project vector `v` onto vector `e`. + Args: + v: (N, L, 3). + e: (N, L, 3). + """ + return (e * v).sum(dim=dim, keepdim=True) * e + + +def construct_3d_basis(center, p1, p2): + """ + Args: + center: (N, L, 3), usually the position of C_alpha. + p1: (N, L, 3), usually the position of C. + p2: (N, L, 3), usually the position of N. + Returns + A batch of orthogonal basis matrix, (N, L, 3, 3cols_index). + The matrix is composed of 3 column vectors: [e1, e2, e3]. + """ + v1 = p1 - center # (N, L, 3) + e1 = normalize_vector(v1, dim=-1) + + v2 = p2 - center # (N, L, 3) + u2 = v2 - project_v2v(v2, e1, dim=-1) + e2 = normalize_vector(u2, dim=-1) + + e3 = torch.cross(e1, e2, dim=-1) # (N, L, 3) + + mat = torch.cat([ + e1.unsqueeze(-1), e2.unsqueeze(-1), e3.unsqueeze(-1) + ], dim=-1) # (N, L, 3, 3_index) + return mat + + +def local_to_global(R, t, p): + """ + Description: + Convert local (internal) coordinates to global (external) coordinates q. + q <- Rp + t + Args: + R: (N, L, 3, 3). + t: (N, L, 3). + p: Local coordinates, (N, L, ..., 3). + Returns: + q: Global coordinates, (N, L, ..., 3). + """ + assert p.size(-1) == 3 + p_size = p.size() + N, L = p_size[0], p_size[1] + + p = p.view(N, L, -1, 3).transpose(-1, -2) # (N, L, *, 3) -> (N, L, 3, *) + q = torch.matmul(R, p) + t.unsqueeze(-1) # (N, L, 3, *) + q = q.transpose(-1, -2).reshape(p_size) # (N, L, 3, *) -> (N, L, *, 3) -> (N, L, ..., 3) + return q + + +def global_to_local(R, t, q): + """ + Description: + Convert global (external) coordinates q to local (internal) coordinates p. + p <- R^{T}(q - t) + Args: + R: (N, L, 3, 3). + t: (N, L, 3). + q: Global coordinates, (N, L, ..., 3). + Returns: + p: Local coordinates, (N, L, ..., 3). + """ + assert q.size(-1) == 3 + q_size = q.size() + N, L = q_size[0], q_size[1] + + q = q.reshape(N, L, -1, 3).transpose(-1, -2) # (N, L, *, 3) -> (N, L, 3, *) + p = torch.matmul(R.transpose(-1, -2), (q - t.unsqueeze(-1))) # (N, L, 3, *) + p = p.transpose(-1, -2).reshape(q_size) # (N, L, 3, *) -> (N, L, *, 3) -> (N, L, ..., 3) + return p + + +def apply_rotation_to_vector(R, p): + return local_to_global(R, torch.zeros_like(p), p) + + +def compose_rotation_and_translation(R1, t1, R2, t2): + """ + Args: + R1,t1: Frame basis and coordinate, (N, L, 3, 3), (N, L, 3). + R2,t2: Rotation and translation to be applied to (R1, t1), (N, L, 3, 3), (N, L, 3). + Returns + R_new <- R1R2 + t_new <- R1t2 + t1 + """ + R_new = torch.matmul(R1, R2) # (N, L, 3, 3) + t_new = torch.matmul(R1, t2.unsqueeze(-1)).squeeze(-1) + t1 + return R_new, t_new + + +def compose_chain(Ts): + while len(Ts) >= 2: + R1, t1 = Ts[-2] + R2, t2 = Ts[-1] + T_next = compose_rotation_and_translation(R1, t1, R2, t2) + Ts = Ts[:-2] + [T_next] + return Ts[0] + + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +def quaternion_to_rotation_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). 
+ Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + quaternions = F.normalize(quaternions, dim=-1) + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +""" +BSD License + +For PyTorch3D software + +Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Meta nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" +def quaternion_1ijk_to_rotation_matrix(q): + """ + (1 + ai + bj + ck) -> R + Args: + q: (..., 3) + """ + b, c, d = torch.unbind(q, dim=-1) + s = torch.sqrt(1 + b**2 + c**2 + d**2) + a, b, c, d = 1/s, b/s, c/s, d/s + + o = torch.stack( + ( + a**2 + b**2 - c**2 - d**2, 2*b*c - 2*a*d, 2*b*d + 2*a*c, + 2*b*c + 2*a*d, a**2 - b**2 + c**2 - d**2, 2*c*d - 2*a*b, + 2*b*d - 2*a*c, 2*c*d + 2*a*b, a**2 - b**2 - c**2 + d**2, + ), + -1, + ) + return o.reshape(q.shape[:-1] + (3, 3)) + + +def repr_6d_to_rotation_matrix(x): + """ + Args: + x: 6D representations, (..., 6). + Returns: + Rotation matrices, (..., 3, 3_index). + """ + a1, a2 = x[..., 0:3], x[..., 3:6] + b1 = normalize_vector(a1, dim=-1) + b2 = normalize_vector(a2 - project_v2v(a2, b1, dim=-1), dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + + mat = torch.cat([ + b1.unsqueeze(-1), b2.unsqueeze(-1), b3.unsqueeze(-1) + ], dim=-1) # (N, L, 3, 3_index) + return mat + + +def dihedral_from_four_points(p0, p1, p2, p3): + """ + Args: + p0-3: (*, 3). + Returns: + Dihedral angles in radian, (*, ). 
+ """ + v0 = p2 - p1 + v1 = p0 - p1 + v2 = p3 - p2 + u1 = torch.cross(v0, v1, dim=-1) + n1 = u1 / torch.linalg.norm(u1, dim=-1, keepdim=True) + u2 = torch.cross(v0, v2, dim=-1) + n2 = u2 / torch.linalg.norm(u2, dim=-1, keepdim=True) + sgn = torch.sign( (torch.cross(v1, v2, dim=-1) * v0).sum(-1) ) + dihed = sgn*torch.acos( (n1 * n2).sum(-1).clamp(min=-0.999999, max=0.999999) ) + dihed = torch.nan_to_num(dihed) + return dihed + + +def knn_gather(idx, value): + """ + Args: + idx: (B, N, K) + value: (B, M, d) + Returns: + (B, N, K, d) + """ + N, d = idx.size(1), value.size(-1) + idx = idx.unsqueeze(-1).repeat(1, 1, 1, d) # (B, N, K, d) + value = value.unsqueeze(1).repeat(1, N, 1, 1) # (B, N, M, d) + return torch.gather(value, dim=2, index=idx) + + +def knn_points(q, p, K): + """ + Args: + q: (B, M, d) + p: (B, N, d) + Returns: + (B, M, K), (B, M, K), (B, M, K, d) + """ + _, L, _ = p.size() + d = pairwise_distances(q, p) # (B, N, M) + dist, idx = d.topk(min(L, K), dim=-1, largest=False) # (B, M, K), (B, M, K) + return dist, idx, knn_gather(idx, p) + + +def angstrom_to_nm(x): + return x / 10 + + +def nm_to_angstrom(x): + return x * 10 + + +def get_backbone_dihedral_angles(pos_atoms, chain_nb, res_nb, mask): + """ + Args: + pos_atoms: (N, L, A, 3). + chain_nb: (N, L). + res_nb: (N, L). + mask: (N, L). + Returns: + bb_dihedral: Omega, Phi, and Psi angles in radian, (N, L, 3). + mask_bb_dihed: Masks of dihedral angles, (N, L, 3). + """ + pos_N = pos_atoms[:, :, BBHeavyAtom.N] # (N, L, 3) + pos_CA = pos_atoms[:, :, BBHeavyAtom.CA] + pos_C = pos_atoms[:, :, BBHeavyAtom.C] + + N_term_flag, C_term_flag = get_terminus_flag(chain_nb, res_nb, mask) # (N, L) + omega_mask = torch.logical_not(N_term_flag) + phi_mask = torch.logical_not(N_term_flag) + psi_mask = torch.logical_not(C_term_flag) + + # N-termini don't have omega and phi + omega = F.pad( + dihedral_from_four_points(pos_CA[:, :-1], pos_C[:, :-1], pos_N[:, 1:], pos_CA[:, 1:]), + pad=(1, 0), value=0, + ) + phi = F.pad( + dihedral_from_four_points(pos_C[:, :-1], pos_N[:, 1:], pos_CA[:, 1:], pos_C[:, 1:]), + pad=(1, 0), value=0, + ) + + # C-termini don't have psi + psi = F.pad( + dihedral_from_four_points(pos_N[:, :-1], pos_CA[:, :-1], pos_C[:, :-1], pos_N[:, 1:]), + pad=(0, 1), value=0, + ) + + mask_bb_dihed = torch.stack([omega_mask, phi_mask, psi_mask], dim=-1) + bb_dihedral = torch.stack([omega, phi, psi], dim=-1) * mask_bb_dihed + return bb_dihedral, mask_bb_dihed + + +def pairwise_dihedrals(pos_atoms): + """ + Args: + pos_atoms: (N, L, A, 3). + Returns: + Inter-residue Phi and Psi angles, (N, L, L, 2). + """ + N, L = pos_atoms.shape[:2] + pos_N = pos_atoms[:, :, BBHeavyAtom.N] # (N, L, 3) + pos_CA = pos_atoms[:, :, BBHeavyAtom.CA] + pos_C = pos_atoms[:, :, BBHeavyAtom.C] + + ir_phi = dihedral_from_four_points( + pos_C[:,:,None].expand(N, L, L, 3), + pos_N[:,None,:].expand(N, L, L, 3), + pos_CA[:,None,:].expand(N, L, L, 3), + pos_C[:,None,:].expand(N, L, L, 3) + ) + ir_psi = dihedral_from_four_points( + pos_N[:,:,None].expand(N, L, L, 3), + pos_CA[:,:,None].expand(N, L, L, 3), + pos_C[:,:,None].expand(N, L, L, 3), + pos_N[:,None,:].expand(N, L, L, 3) + ) + ir_dihed = torch.stack([ir_phi, ir_psi], dim=-1) + return ir_dihed + + +def apply_rotation_matrix_to_rot6d(R, O): + """ + Args: + R: (..., 3, 3) + O: (..., 6) + Returns: + Rotated 6D representation, (..., 6). 
+ """ + u1, u2 = O[..., :3, None], O[..., 3:, None] # (..., 3, 1) + v1 = torch.matmul(R, u1).squeeze(-1) # (..., 3) + v2 = torch.matmul(R, u2).squeeze(-1) + return torch.cat([v1, v2], dim=-1) + + +def normalize_rot6d(O): + """ + Args: + O: (..., 6) + """ + u1, u2 = O[..., :3], O[..., 3:] # (..., 3) + v1 = F.normalize(u1, p=2, dim=-1) # (..., 3) + v2 = F.normalize(u2 - project_v2v(u2, v1), p=2, dim=-1) + return torch.cat([v1, v2], dim=-1) + + +def reconstruct_backbone(R, t, aa, chain_nb, res_nb, mask): + """ + Args: + R: (N, L, 3, 3) + t: (N, L, 3) + aa: (N, L) + chain_nb: (N, L) + res_nb: (N, L) + mask: (N, L) + Returns: + Reconstructed backbone atoms, (N, L, 4, 3). + """ + N, L = aa.size() + # atom_coords = restype_heavyatom_rigid_group_positions.clone().to(t) # (21, 14, 3) + bb_coords = backbone_atom_coordinates_tensor.clone().to(t) # (21, 3, 3) + oxygen_coord = bb_oxygen_coordinate_tensor.clone().to(t) # (21, 3) + aa = aa.clamp(min=0, max=20) # 20 for UNK + + bb_coords = bb_coords[aa.flatten()].reshape(N, L, -1, 3) # (N, L, 3, 3) + oxygen_coord = oxygen_coord[aa.flatten()].reshape(N, L, -1) # (N, L, 3) + bb_pos = local_to_global(R, t, bb_coords) # Global coordinates of N, CA, C. (N, L, 3, 3). + + # Compute PSI angle + bb_dihedral, _ = get_backbone_dihedral_angles(bb_pos, chain_nb, res_nb, mask) + psi = bb_dihedral[..., 2] # (N, L) + # Make rotation matrix for PSI + sin_psi = torch.sin(psi).reshape(N, L, 1, 1) + cos_psi = torch.cos(psi).reshape(N, L, 1, 1) + zero = torch.zeros_like(sin_psi) + one = torch.ones_like(sin_psi) + row1 = torch.cat([one, zero, zero], dim=-1) # (N, L, 1, 3) + row2 = torch.cat([zero, cos_psi, -sin_psi], dim=-1) # (N, L, 1, 3) + row3 = torch.cat([zero, sin_psi, cos_psi], dim=-1) # (N, L, 1, 3) + R_psi = torch.cat([row1, row2, row3], dim=-2) # (N, L, 3, 3) + + # Compute rotoation and translation of PSI frame, and position of O. + R_psi, t_psi = compose_chain([ + (R, t), # Backbone + (R_psi, torch.zeros_like(t)), # PSI angle + ]) + O_pos = local_to_global(R_psi, t_psi, oxygen_coord.reshape(N, L, 1, 3)) + + bb_pos = torch.cat([bb_pos, O_pos], dim=2) # (N, L, 4, 3) + return bb_pos + + +def reconstruct_backbone_partially(pos_ctx, R_new, t_new, aa, chain_nb, res_nb, mask_atoms, mask_recons): + """ + Args: + pos: (N, L, A, 3). + R_new: (N, L, 3, 3). + t_new: (N, L, 3). + mask_atoms: (N, L, A). + mask_recons:(N, L). + Returns: + pos_new: (N, L, A, 3). + mask_new: (N, L, A). 
+ """ + N, L, A = mask_atoms.size() + + mask_res = mask_atoms[:, :, BBHeavyAtom.CA] + pos_recons = reconstruct_backbone(R_new, t_new, aa, chain_nb, res_nb, mask_res) # (N, L, 4, 3) + pos_recons = F.pad(pos_recons, pad=(0, 0, 0, A-4), value=0) # (N, L, A, 3) + + pos_new = torch.where( + mask_recons[:, :, None, None].expand_as(pos_ctx), + pos_recons, pos_ctx + ) # (N, L, A, 3) + + mask_bb_atoms = torch.zeros_like(mask_atoms) + mask_bb_atoms[:, :, :4] = True + mask_new = torch.where( + mask_recons[:, :, None].expand_as(mask_atoms), + mask_bb_atoms, mask_atoms + ) + + return pos_new, mask_new + diff --git a/diffab/modules/common/layers.py b/diffab/modules/common/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..6cf78089f8c0f8f2d3b780f75dfab5a55b3841c0 --- /dev/null +++ b/diffab/modules/common/layers.py @@ -0,0 +1,160 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def mask_zero(mask, value): + return torch.where(mask, value, torch.zeros_like(value)) + + +def clampped_one_hot(x, num_classes): + mask = (x >= 0) & (x < num_classes) # (N, L) + x = x.clamp(min=0, max=num_classes-1) + y = F.one_hot(x, num_classes) * mask[...,None] # (N, L, C) + return y + + +class DistanceToBins(nn.Module): + + def __init__(self, dist_min=0.0, dist_max=20.0, num_bins=64, use_onehot=False): + super().__init__() + self.dist_min = dist_min + self.dist_max = dist_max + self.num_bins = num_bins + self.use_onehot = use_onehot + + if use_onehot: + offset = torch.linspace(dist_min, dist_max, self.num_bins) + else: + offset = torch.linspace(dist_min, dist_max, self.num_bins-1) # 1 overflow flag + self.coeff = -0.5 / ((offset[1] - offset[0]) * 0.2).item() ** 2 # `*0.2`: makes it not too blurred + self.register_buffer('offset', offset) + + @property + def out_channels(self): + return self.num_bins + + def forward(self, dist, dim, normalize=True): + """ + Args: + dist: (N, *, 1, *) + Returns: + (N, *, num_bins, *) + """ + assert dist.size()[dim] == 1 + offset_shape = [1] * len(dist.size()) + offset_shape[dim] = -1 + + if self.use_onehot: + diff = torch.abs(dist - self.offset.view(*offset_shape)) # (N, *, num_bins, *) + bin_idx = torch.argmin(diff, dim=dim, keepdim=True) # (N, *, 1, *) + y = torch.zeros_like(diff).scatter_(dim=dim, index=bin_idx, value=1.0) + else: + overflow_symb = (dist >= self.dist_max).float() # (N, *, 1, *) + y = dist - self.offset.view(*offset_shape) # (N, *, num_bins-1, *) + y = torch.exp(self.coeff * torch.pow(y, 2)) # (N, *, num_bins-1, *) + y = torch.cat([y, overflow_symb], dim=dim) # (N, *, num_bins, *) + if normalize: + y = y / y.sum(dim=dim, keepdim=True) + + return y + + +class PositionalEncoding(nn.Module): + + def __init__(self, num_funcs=6): + super().__init__() + self.num_funcs = num_funcs + self.register_buffer('freq_bands', 2.0 ** torch.linspace(0.0, num_funcs-1, num_funcs)) + + def get_out_dim(self, in_dim): + return in_dim * (2 * self.num_funcs + 1) + + def forward(self, x): + """ + Args: + x: (..., d). 
+ """ + shape = list(x.shape[:-1]) + [-1] + x = x.unsqueeze(-1) # (..., d, 1) + code = torch.cat([x, torch.sin(x * self.freq_bands), torch.cos(x * self.freq_bands)], dim=-1) # (..., d, 2f+1) + code = code.reshape(shape) + return code + + +class AngularEncoding(nn.Module): + + def __init__(self, num_funcs=3): + super().__init__() + self.num_funcs = num_funcs + self.register_buffer('freq_bands', torch.FloatTensor( + [i+1 for i in range(num_funcs)] + [1./(i+1) for i in range(num_funcs)] + )) + + def get_out_dim(self, in_dim): + return in_dim * (1 + 2*2*self.num_funcs) + + def forward(self, x): + """ + Args: + x: (..., d). + """ + shape = list(x.shape[:-1]) + [-1] + x = x.unsqueeze(-1) # (..., d, 1) + code = torch.cat([x, torch.sin(x * self.freq_bands), torch.cos(x * self.freq_bands)], dim=-1) # (..., d, 2f+1) + code = code.reshape(shape) + return code + + +class LayerNorm(nn.Module): + + def __init__(self, + normal_shape, + gamma=True, + beta=True, + epsilon=1e-10): + """Layer normalization layer + See: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf) + :param normal_shape: The shape of the input tensor or the last dimension of the input tensor. + :param gamma: Add a scale parameter if it is True. + :param beta: Add an offset parameter if it is True. + :param epsilon: Epsilon for calculating variance. + """ + super().__init__() + if isinstance(normal_shape, int): + normal_shape = (normal_shape,) + else: + normal_shape = (normal_shape[-1],) + self.normal_shape = torch.Size(normal_shape) + self.epsilon = epsilon + if gamma: + self.gamma = nn.Parameter(torch.Tensor(*normal_shape)) + else: + self.register_parameter('gamma', None) + if beta: + self.beta = nn.Parameter(torch.Tensor(*normal_shape)) + else: + self.register_parameter('beta', None) + self.reset_parameters() + + def reset_parameters(self): + if self.gamma is not None: + self.gamma.data.fill_(1) + if self.beta is not None: + self.beta.data.zero_() + + def forward(self, x): + mean = x.mean(dim=-1, keepdim=True) + var = ((x - mean) ** 2).mean(dim=-1, keepdim=True) + std = (var + self.epsilon).sqrt() + y = (x - mean) / std + if self.gamma is not None: + y *= self.gamma + if self.beta is not None: + y += self.beta + return y + + def extra_repr(self): + return 'normal_shape={}, gamma={}, beta={}, epsilon={}'.format( + self.normal_shape, self.gamma is not None, self.beta is not None, self.epsilon, + ) diff --git a/diffab/modules/common/so3.py b/diffab/modules/common/so3.py new file mode 100644 index 0000000000000000000000000000000000000000..794b65c7bff9289e0cc5b6cf2c1a37f9db2cb6f5 --- /dev/null +++ b/diffab/modules/common/so3.py @@ -0,0 +1,146 @@ +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .geometry import quaternion_to_rotation_matrix + + +def log_rotation(R): + trace = R[..., range(3), range(3)].sum(-1) + if torch.is_grad_enabled(): + # The derivative of acos at -1.0 is -inf, so to stablize the gradient, we use -0.9999 + min_cos = -0.999 + else: + min_cos = -1.0 + cos_theta = ( (trace-1) / 2 ).clamp_min(min=min_cos) + sin_theta = torch.sqrt(1 - cos_theta**2) + theta = torch.acos(cos_theta) + coef = ((theta+1e-8)/(2*sin_theta+2e-8))[..., None, None] + logR = coef * (R - R.transpose(-1, -2)) + return logR + + +def skewsym_to_so3vec(S): + x = S[..., 1, 2] + y = S[..., 2, 0] + z = S[..., 0, 1] + w = torch.stack([x,y,z], dim=-1) + return w + + +def so3vec_to_skewsym(w): + x, y, z = torch.unbind(w, dim=-1) + o = torch.zeros_like(x) + S = torch.stack([ + o, z, -y, + -z, 
o, x, + y, -x, o, + ], dim=-1).reshape(w.shape[:-1] + (3, 3)) + return S + + +def exp_skewsym(S): + x = torch.linalg.norm(skewsym_to_so3vec(S), dim=-1) + I = torch.eye(3).to(S).view([1 for _ in range(S.dim()-2)] + [3, 3]) + + sinx, cosx = torch.sin(x), torch.cos(x) + b = (sinx + 1e-8) / (x + 1e-8) + c = (1-cosx + 1e-8) / (x**2 + 2e-8) # lim_{x->0} (1-cosx)/(x^2) = 0.5 + + S2 = S @ S + return I + b[..., None, None]*S + c[..., None, None]*S2 + + +def so3vec_to_rotation(w): + return exp_skewsym(so3vec_to_skewsym(w)) + + +def rotation_to_so3vec(R): + logR = log_rotation(R) + w = skewsym_to_so3vec(logR) + return w + + +def random_uniform_so3(size, device='cpu'): + q = F.normalize(torch.randn(list(size)+[4,], device=device), dim=-1) # (..., 4) + return rotation_to_so3vec(quaternion_to_rotation_matrix(q)) + + +class ApproxAngularDistribution(nn.Module): + + def __init__(self, stddevs, std_threshold=0.1, num_bins=8192, num_iters=1024): + super().__init__() + self.std_threshold = std_threshold + self.num_bins = num_bins + self.num_iters = num_iters + self.register_buffer('stddevs', torch.FloatTensor(stddevs)) + self.register_buffer('approx_flag', self.stddevs <= std_threshold) + self._precompute_histograms() + + @staticmethod + def _pdf(x, e, L): + """ + Args: + x: (N, ) + e: Float + L: Integer + """ + x = x[:, None] # (N, *) + c = ((1 - torch.cos(x)) / math.pi) # (N, *) + l = torch.arange(0, L)[None, :] # (*, L) + a = (2*l+1) * torch.exp(-l*(l+1)*(e**2)) # (*, L) + b = (torch.sin( (l+0.5)* x ) + 1e-6) / (torch.sin( x / 2 ) + 1e-6) # (N, L) + + f = (c * a * b).sum(dim=1) + return f + + def _precompute_histograms(self): + X, Y = [], [] + for std in self.stddevs: + std = std.item() + x = torch.linspace(0, math.pi, self.num_bins) # (n_bins,) + y = self._pdf(x, std, self.num_iters) # (n_bins,) + y = torch.nan_to_num(y).clamp_min(0) + X.append(x) + Y.append(y) + self.register_buffer('X', torch.stack(X, dim=0)) # (n_stddevs, n_bins) + self.register_buffer('Y', torch.stack(Y, dim=0)) # (n_stddevs, n_bins) + + def sample(self, std_idx): + """ + Args: + std_idx: Indices of standard deviation. + Returns: + samples: Angular samples [0, PI), same size as std. 
+ """ + size = std_idx.size() + std_idx = std_idx.flatten() # (N,) + + # Samples from histogram + prob = self.Y[std_idx] # (N, n_bins) + bin_idx = torch.multinomial(prob[:, :-1], num_samples=1).squeeze(-1) # (N,) + bin_start = self.X[std_idx, bin_idx] # (N,) + bin_width = self.X[std_idx, bin_idx+1] - self.X[std_idx, bin_idx] + samples_hist = bin_start + torch.rand_like(bin_start) * bin_width # (N,) + + # Samples from Gaussian approximation + mean_gaussian = self.stddevs[std_idx]*2 + std_gaussian = self.stddevs[std_idx] + samples_gaussian = mean_gaussian + torch.randn_like(mean_gaussian) * std_gaussian + samples_gaussian = samples_gaussian.abs() % math.pi + + # Choose from histogram or Gaussian + gaussian_flag = self.approx_flag[std_idx] + samples = torch.where(gaussian_flag, samples_gaussian, samples_hist) + + return samples.reshape(size) + + +def random_normal_so3(std_idx, angular_distrib, device='cpu'): + size = std_idx.size() + u = F.normalize(torch.randn(list(size)+[3,], device=device), dim=-1) + theta = angular_distrib.sample(std_idx) + w = u * theta[..., None] + return w diff --git a/diffab/modules/common/structure.py b/diffab/modules/common/structure.py new file mode 100644 index 0000000000000000000000000000000000000000..afac456640869cf205cfeeeca7c656e1d5ff2d00 --- /dev/null +++ b/diffab/modules/common/structure.py @@ -0,0 +1,77 @@ +import torch +from torch.nn import Module, Linear, LayerNorm, Sequential, ReLU + +from ..common.geometry import compose_rotation_and_translation, quaternion_to_rotation_matrix, repr_6d_to_rotation_matrix + + +class FrameRotationTranslationPrediction(Module): + + def __init__(self, feat_dim, rot_repr, nn_type='mlp'): + super().__init__() + assert rot_repr in ('quaternion', '6d') + self.rot_repr = rot_repr + if rot_repr == 'quaternion': + out_dim = 3 + 3 + elif rot_repr == '6d': + out_dim = 6 + 3 + + if nn_type == 'linear': + self.nn = Linear(feat_dim, out_dim) + elif nn_type == 'mlp': + self.nn = Sequential( + Linear(feat_dim, feat_dim), ReLU(), + Linear(feat_dim, feat_dim), ReLU(), + Linear(feat_dim, out_dim) + ) + else: + raise ValueError('Unknown nn_type: %s' % nn_type) + + def forward(self, x): + y = self.nn(x) # (..., d+3) + if self.rot_repr == 'quaternion': + quaternion = torch.cat([torch.ones_like(y[..., :1]), y[..., 0:3]], dim=-1) + R_delta = quaternion_to_rotation_matrix(quaternion) + t_delta = y[..., 3:6] + return R_delta, t_delta + elif self.rot_repr == '6d': + R_delta = repr_6d_to_rotation_matrix(y[..., 0:6]) + t_delta = y[..., 6:9] + return R_delta, t_delta + + +class FrameUpdate(Module): + + def __init__(self, node_feat_dim, rot_repr='quaternion', rot_tran_nn_type='mlp'): + super().__init__() + self.transition_mlp = Sequential( + Linear(node_feat_dim, node_feat_dim), ReLU(), + Linear(node_feat_dim, node_feat_dim), ReLU(), + Linear(node_feat_dim, node_feat_dim), + ) + self.transition_layer_norm = LayerNorm(node_feat_dim) + + self.rot_tran = FrameRotationTranslationPrediction(node_feat_dim, rot_repr, nn_type=rot_tran_nn_type) + + def forward(self, R, t, x, mask_generate): + """ + Args: + R: Frame basis matrices, (N, L, 3, 3_index). + t: Frame external (absolute) coordinates, (N, L, 3). Unit: Angstrom. + x: Node-wise features, (N, L, F). + mask_generate: Masks, (N, L). + Returns: + R': Updated basis matrices, (N, L, 3, 3_index). + t': Updated coordinates, (N, L, 3). 
+ """ + x = self.transition_layer_norm(x + self.transition_mlp(x)) + + R_delta, t_delta = self.rot_tran(x) # (N, L, 3, 3), (N, L, 3) + R_new, t_new = compose_rotation_and_translation(R, t, R_delta, t_delta) + + mask_R = mask_generate[:, :, None, None].expand_as(R) + mask_t = mask_generate[:, :, None].expand_as(t) + + R_new = torch.where(mask_R, R_new, R) + t_new = torch.where(mask_t, t_new, t) + + return R_new, t_new diff --git a/diffab/modules/common/topology.py b/diffab/modules/common/topology.py new file mode 100644 index 0000000000000000000000000000000000000000..c1249882e86b8ab5d06d2f360b4502b344b9a0c7 --- /dev/null +++ b/diffab/modules/common/topology.py @@ -0,0 +1,24 @@ +import torch +import torch.nn.functional as F + + +def get_consecutive_flag(chain_nb, res_nb, mask): + """ + Args: + chain_nb, res_nb + Returns: + consec: A flag tensor indicating whether residue-i is connected to residue-(i+1), + BoolTensor, (B, L-1)[b, i]. + """ + d_res_nb = (res_nb[:, 1:] - res_nb[:, :-1]).abs() # (B, L-1) + same_chain = (chain_nb[:, 1:] == chain_nb[:, :-1]) + consec = torch.logical_and(d_res_nb == 1, same_chain) + consec = torch.logical_and(consec, mask[:, :-1]) + return consec + + +def get_terminus_flag(chain_nb, res_nb, mask): + consec = get_consecutive_flag(chain_nb, res_nb, mask) + N_term_flag = F.pad(torch.logical_not(consec), pad=(1, 0), value=1) + C_term_flag = F.pad(torch.logical_not(consec), pad=(0, 1), value=1) + return N_term_flag, C_term_flag diff --git a/diffab/modules/diffusion/dpm_full.py b/diffab/modules/diffusion/dpm_full.py new file mode 100644 index 0000000000000000000000000000000000000000..49fe30db80a76deaf7d0a011dbd8116cf4e27b0e --- /dev/null +++ b/diffab/modules/diffusion/dpm_full.py @@ -0,0 +1,319 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import functools +from tqdm.auto import tqdm + +from diffab.modules.common.geometry import apply_rotation_to_vector, quaternion_1ijk_to_rotation_matrix +from diffab.modules.common.so3 import so3vec_to_rotation, rotation_to_so3vec, random_uniform_so3 +from diffab.modules.encoders.ga import GAEncoder +from .transition import RotationTransition, PositionTransition, AminoacidCategoricalTransition + + +def rotation_matrix_cosine_loss(R_pred, R_true): + """ + Args: + R_pred: (*, 3, 3). + R_true: (*, 3, 3). + Returns: + Per-matrix losses, (*, ). 
+ """ + size = list(R_pred.shape[:-2]) + ncol = R_pred.numel() // 3 + + RT_pred = R_pred.transpose(-2, -1).reshape(ncol, 3) # (ncol, 3) + RT_true = R_true.transpose(-2, -1).reshape(ncol, 3) # (ncol, 3) + + ones = torch.ones([ncol, ], dtype=torch.long, device=R_pred.device) + loss = F.cosine_embedding_loss(RT_pred, RT_true, ones, reduction='none') # (ncol*3, ) + loss = loss.reshape(size + [3]).sum(dim=-1) # (*, ) + return loss + + +class EpsilonNet(nn.Module): + + def __init__(self, res_feat_dim, pair_feat_dim, num_layers, encoder_opt={}): + super().__init__() + self.current_sequence_embedding = nn.Embedding(25, res_feat_dim) # 22 is padding + self.res_feat_mixer = nn.Sequential( + nn.Linear(res_feat_dim * 2, res_feat_dim), nn.ReLU(), + nn.Linear(res_feat_dim, res_feat_dim), + ) + self.encoder = GAEncoder(res_feat_dim, pair_feat_dim, num_layers, **encoder_opt) + + self.eps_crd_net = nn.Sequential( + nn.Linear(res_feat_dim+3, res_feat_dim), nn.ReLU(), + nn.Linear(res_feat_dim, res_feat_dim), nn.ReLU(), + nn.Linear(res_feat_dim, 3) + ) + + self.eps_rot_net = nn.Sequential( + nn.Linear(res_feat_dim+3, res_feat_dim), nn.ReLU(), + nn.Linear(res_feat_dim, res_feat_dim), nn.ReLU(), + nn.Linear(res_feat_dim, 3) + ) + + self.eps_seq_net = nn.Sequential( + nn.Linear(res_feat_dim+3, res_feat_dim), nn.ReLU(), + nn.Linear(res_feat_dim, res_feat_dim), nn.ReLU(), + nn.Linear(res_feat_dim, 20), nn.Softmax(dim=-1) + ) + + def forward(self, v_t, p_t, s_t, res_feat, pair_feat, beta, mask_generate, mask_res): + """ + Args: + v_t: (N, L, 3). + p_t: (N, L, 3). + s_t: (N, L). + res_feat: (N, L, res_dim). + pair_feat: (N, L, L, pair_dim). + beta: (N,). + mask_generate: (N, L). + mask_res: (N, L). + Returns: + v_next: UPDATED (not epsilon) SO3-vector of orietnations, (N, L, 3). + eps_pos: (N, L, 3). + """ + N, L = mask_res.size() + R = so3vec_to_rotation(v_t) # (N, L, 3, 3) + + # s_t = s_t.clamp(min=0, max=19) # TODO: clamping is good but ugly. + res_feat = self.res_feat_mixer(torch.cat([res_feat, self.current_sequence_embedding(s_t)], dim=-1)) # [Important] Incorporate sequence at the current step. 
+ res_feat = self.encoder(R, p_t, res_feat, pair_feat, mask_res) + + t_embed = torch.stack([beta, torch.sin(beta), torch.cos(beta)], dim=-1)[:, None, :].expand(N, L, 3) + in_feat = torch.cat([res_feat, t_embed], dim=-1) + + # Position changes + eps_crd = self.eps_crd_net(in_feat) # (N, L, 3) + eps_pos = apply_rotation_to_vector(R, eps_crd) # (N, L, 3) + eps_pos = torch.where(mask_generate[:, :, None].expand_as(eps_pos), eps_pos, torch.zeros_like(eps_pos)) + + # New orientation + eps_rot = self.eps_rot_net(in_feat) # (N, L, 3) + U = quaternion_1ijk_to_rotation_matrix(eps_rot) # (N, L, 3, 3) + R_next = R @ U + v_next = rotation_to_so3vec(R_next) # (N, L, 3) + v_next = torch.where(mask_generate[:, :, None].expand_as(v_next), v_next, v_t) + + # New sequence categorical distributions + c_denoised = self.eps_seq_net(in_feat) # Already softmax-ed, (N, L, 20) + + return v_next, R_next, eps_pos, c_denoised + + +class FullDPM(nn.Module): + + def __init__( + self, + res_feat_dim, + pair_feat_dim, + num_steps, + eps_net_opt={}, + trans_rot_opt={}, + trans_pos_opt={}, + trans_seq_opt={}, + position_mean=[0.0, 0.0, 0.0], + position_scale=[10.0], + ): + super().__init__() + self.eps_net = EpsilonNet(res_feat_dim, pair_feat_dim, **eps_net_opt) + self.num_steps = num_steps + self.trans_rot = RotationTransition(num_steps, **trans_rot_opt) + self.trans_pos = PositionTransition(num_steps, **trans_pos_opt) + self.trans_seq = AminoacidCategoricalTransition(num_steps, **trans_seq_opt) + + self.register_buffer('position_mean', torch.FloatTensor(position_mean).view(1, 1, -1)) + self.register_buffer('position_scale', torch.FloatTensor(position_scale).view(1, 1, -1)) + self.register_buffer('_dummy', torch.empty([0, ])) + + def _normalize_position(self, p): + p_norm = (p - self.position_mean) / self.position_scale + return p_norm + + def _unnormalize_position(self, p_norm): + p = p_norm * self.position_scale + self.position_mean + return p + + def forward(self, v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res, denoise_structure, denoise_sequence, t=None): + N, L = res_feat.shape[:2] + if t == None: + t = torch.randint(0, self.num_steps, (N,), dtype=torch.long, device=self._dummy.device) + p_0 = self._normalize_position(p_0) + + if denoise_structure: + # Add noise to rotation + R_0 = so3vec_to_rotation(v_0) + v_noisy, _ = self.trans_rot.add_noise(v_0, mask_generate, t) + # Add noise to positions + p_noisy, eps_p = self.trans_pos.add_noise(p_0, mask_generate, t) + else: + R_0 = so3vec_to_rotation(v_0) + v_noisy = v_0.clone() + p_noisy = p_0.clone() + eps_p = torch.zeros_like(p_noisy) + + if denoise_sequence: + # Add noise to sequence + _, s_noisy = self.trans_seq.add_noise(s_0, mask_generate, t) + else: + s_noisy = s_0.clone() + + beta = self.trans_pos.var_sched.betas[t] + v_pred, R_pred, eps_p_pred, c_denoised = self.eps_net( + v_noisy, p_noisy, s_noisy, res_feat, pair_feat, beta, mask_generate, mask_res + ) # (N, L, 3), (N, L, 3, 3), (N, L, 3), (N, L, 20), (N, L) + + loss_dict = {} + + # Rotation loss + loss_rot = rotation_matrix_cosine_loss(R_pred, R_0) # (N, L) + loss_rot = (loss_rot * mask_generate).sum() / (mask_generate.sum().float() + 1e-8) + loss_dict['rot'] = loss_rot + + # Position loss + loss_pos = F.mse_loss(eps_p_pred, eps_p, reduction='none').sum(dim=-1) # (N, L) + loss_pos = (loss_pos * mask_generate).sum() / (mask_generate.sum().float() + 1e-8) + loss_dict['pos'] = loss_pos + + # Sequence categorical loss + post_true = self.trans_seq.posterior(s_noisy, s_0, t) + log_post_pred = 
torch.log(self.trans_seq.posterior(s_noisy, c_denoised, t) + 1e-8) + kldiv = F.kl_div( + input=log_post_pred, + target=post_true, + reduction='none', + log_target=False + ).sum(dim=-1) # (N, L) + loss_seq = (kldiv * mask_generate).sum() / (mask_generate.sum().float() + 1e-8) + loss_dict['seq'] = loss_seq + + return loss_dict + + @torch.no_grad() + def sample( + self, + v, p, s, + res_feat, pair_feat, + mask_generate, mask_res, + sample_structure=True, sample_sequence=True, + pbar=False, + ): + """ + Args: + v: Orientations of contextual residues, (N, L, 3). + p: Positions of contextual residues, (N, L, 3). + s: Sequence of contextual residues, (N, L). + """ + N, L = v.shape[:2] + p = self._normalize_position(p) + + # Set the orientation and position of residues to be predicted to random values + if sample_structure: + v_rand = random_uniform_so3([N, L], device=self._dummy.device) + p_rand = torch.randn_like(p) + v_init = torch.where(mask_generate[:, :, None].expand_as(v), v_rand, v) + p_init = torch.where(mask_generate[:, :, None].expand_as(p), p_rand, p) + else: + v_init, p_init = v, p + + if sample_sequence: + s_rand = torch.randint_like(s, low=0, high=19) + s_init = torch.where(mask_generate, s_rand, s) + else: + s_init = s + + traj = {self.num_steps: (v_init, self._unnormalize_position(p_init), s_init)} + if pbar: + pbar = functools.partial(tqdm, total=self.num_steps, desc='Sampling') + else: + pbar = lambda x: x + for t in pbar(range(self.num_steps, 0, -1)): + v_t, p_t, s_t = traj[t] + p_t = self._normalize_position(p_t) + + beta = self.trans_pos.var_sched.betas[t].expand([N, ]) + t_tensor = torch.full([N, ], fill_value=t, dtype=torch.long, device=self._dummy.device) + + v_next, R_next, eps_p, c_denoised = self.eps_net( + v_t, p_t, s_t, res_feat, pair_feat, beta, mask_generate, mask_res + ) # (N, L, 3), (N, L, 3, 3), (N, L, 3) + + v_next = self.trans_rot.denoise(v_t, v_next, mask_generate, t_tensor) + p_next = self.trans_pos.denoise(p_t, eps_p, mask_generate, t_tensor) + _, s_next = self.trans_seq.denoise(s_t, c_denoised, mask_generate, t_tensor) + + if not sample_structure: + v_next, p_next = v_t, p_t + if not sample_sequence: + s_next = s_t + + traj[t-1] = (v_next, self._unnormalize_position(p_next), s_next) + traj[t] = tuple(x.cpu() for x in traj[t]) # Move previous states to cpu memory. + + return traj + + @torch.no_grad() + def optimize( + self, + v, p, s, + opt_step: int, + res_feat, pair_feat, + mask_generate, mask_res, + sample_structure=True, sample_sequence=True, + pbar=False, + ): + """ + Description: + First adds noise to the given structure, then denoises it. 
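+            Running the reverse process from t = opt_step instead of t = num_steps
+            keeps the result close to the input for small opt_step values, while
+            large values approach unconditional sampling.
+        Example (illustrative sketch; object and variable names are assumptions):
+            >>> traj = dpm.optimize(v, p, s, opt_step=4,
+            ...                     res_feat=res_feat, pair_feat=pair_feat,
+            ...                     mask_generate=mask_gen, mask_res=mask_res)
+            >>> v_out, p_out, s_out = traj[0]  # fully denoised state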
+ """ + N, L = v.shape[:2] + p = self._normalize_position(p) + t = torch.full([N, ], fill_value=opt_step, dtype=torch.long, device=self._dummy.device) + + # Set the orientation and position of residues to be predicted to random values + if sample_structure: + # Add noise to rotation + v_noisy, _ = self.trans_rot.add_noise(v, mask_generate, t) + # Add noise to positions + p_noisy, _ = self.trans_pos.add_noise(p, mask_generate, t) + v_init = torch.where(mask_generate[:, :, None].expand_as(v), v_noisy, v) + p_init = torch.where(mask_generate[:, :, None].expand_as(p), p_noisy, p) + else: + v_init, p_init = v, p + + if sample_sequence: + _, s_noisy = self.trans_seq.add_noise(s, mask_generate, t) + s_init = torch.where(mask_generate, s_noisy, s) + else: + s_init = s + + traj = {opt_step: (v_init, self._unnormalize_position(p_init), s_init)} + if pbar: + pbar = functools.partial(tqdm, total=opt_step, desc='Optimizing') + else: + pbar = lambda x: x + for t in pbar(range(opt_step, 0, -1)): + v_t, p_t, s_t = traj[t] + p_t = self._normalize_position(p_t) + + beta = self.trans_pos.var_sched.betas[t].expand([N, ]) + t_tensor = torch.full([N, ], fill_value=t, dtype=torch.long, device=self._dummy.device) + + v_next, R_next, eps_p, c_denoised = self.eps_net( + v_t, p_t, s_t, res_feat, pair_feat, beta, mask_generate, mask_res + ) # (N, L, 3), (N, L, 3, 3), (N, L, 3) + + v_next = self.trans_rot.denoise(v_t, v_next, mask_generate, t_tensor) + p_next = self.trans_pos.denoise(p_t, eps_p, mask_generate, t_tensor) + _, s_next = self.trans_seq.denoise(s_t, c_denoised, mask_generate, t_tensor) + + if not sample_structure: + v_next, p_next = v_t, p_t + if not sample_sequence: + s_next = s_t + + traj[t-1] = (v_next, self._unnormalize_position(p_next), s_next) + traj[t] = tuple(x.cpu() for x in traj[t]) # Move previous states to cpu memory. + + return traj diff --git a/diffab/modules/diffusion/transition.py b/diffab/modules/diffusion/transition.py new file mode 100644 index 0000000000000000000000000000000000000000..80ef6cc03a11a5241a47f762c82134cf535f8ed6 --- /dev/null +++ b/diffab/modules/diffusion/transition.py @@ -0,0 +1,223 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffab.modules.common.layers import clampped_one_hot +from diffab.modules.common.so3 import ApproxAngularDistribution, random_normal_so3, so3vec_to_rotation, rotation_to_so3vec + + +class VarianceSchedule(nn.Module): + + def __init__(self, num_steps=100, s=0.01): + super().__init__() + T = num_steps + t = torch.arange(0, num_steps+1, dtype=torch.float) + f_t = torch.cos( (np.pi / 2) * ((t/T) + s) / (1 + s) ) ** 2 + alpha_bars = f_t / f_t[0] + + betas = 1 - (alpha_bars[1:] / alpha_bars[:-1]) + betas = torch.cat([torch.zeros([1]), betas], dim=0) + betas = betas.clamp_max(0.999) + + sigmas = torch.zeros_like(betas) + for i in range(1, betas.size(0)): + sigmas[i] = ((1 - alpha_bars[i-1]) / (1 - alpha_bars[i])) * betas[i] + sigmas = torch.sqrt(sigmas) + + self.register_buffer('betas', betas) + self.register_buffer('alpha_bars', alpha_bars) + self.register_buffer('alphas', 1 - betas) + self.register_buffer('sigmas', sigmas) + + +class PositionTransition(nn.Module): + + def __init__(self, num_steps, var_sched_opt={}): + super().__init__() + self.var_sched = VarianceSchedule(num_steps, **var_sched_opt) + + def add_noise(self, p_0, mask_generate, t): + """ + Args: + p_0: (N, L, 3). + mask_generate: (N, L). + t: (N,). 
+ """ + alpha_bar = self.var_sched.alpha_bars[t] + + c0 = torch.sqrt(alpha_bar).view(-1, 1, 1) + c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1) + + e_rand = torch.randn_like(p_0) + p_noisy = c0*p_0 + c1*e_rand + p_noisy = torch.where(mask_generate[..., None].expand_as(p_0), p_noisy, p_0) + + return p_noisy, e_rand + + def denoise(self, p_t, eps_p, mask_generate, t): + # IMPORTANT: + # clampping alpha is to fix the instability issue at the first step (t=T) + # it seems like a problem with the ``improved ddpm''. + alpha = self.var_sched.alphas[t].clamp_min( + self.var_sched.alphas[-2] + ) + alpha_bar = self.var_sched.alpha_bars[t] + sigma = self.var_sched.sigmas[t].view(-1, 1, 1) + + c0 = ( 1.0 / torch.sqrt(alpha + 1e-8) ).view(-1, 1, 1) + c1 = ( (1 - alpha) / torch.sqrt(1 - alpha_bar + 1e-8) ).view(-1, 1, 1) + + z = torch.where( + (t > 1)[:, None, None].expand_as(p_t), + torch.randn_like(p_t), + torch.zeros_like(p_t), + ) + + p_next = c0 * (p_t - c1 * eps_p) + sigma * z + p_next = torch.where(mask_generate[..., None].expand_as(p_t), p_next, p_t) + return p_next + + +class RotationTransition(nn.Module): + + def __init__(self, num_steps, var_sched_opt={}, angular_distrib_fwd_opt={}, angular_distrib_inv_opt={}): + super().__init__() + self.var_sched = VarianceSchedule(num_steps, **var_sched_opt) + + # Forward (perturb) + c1 = torch.sqrt(1 - self.var_sched.alpha_bars) # (T,). + self.angular_distrib_fwd = ApproxAngularDistribution(c1.tolist(), **angular_distrib_fwd_opt) + + # Inverse (generate) + sigma = self.var_sched.sigmas + self.angular_distrib_inv = ApproxAngularDistribution(sigma.tolist(), **angular_distrib_inv_opt) + + self.register_buffer('_dummy', torch.empty([0, ])) + + def add_noise(self, v_0, mask_generate, t): + """ + Args: + v_0: (N, L, 3). + mask_generate: (N, L). + t: (N,). + """ + N, L = mask_generate.size() + alpha_bar = self.var_sched.alpha_bars[t] + c0 = torch.sqrt(alpha_bar).view(-1, 1, 1) + c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1) + + # Noise rotation + e_scaled = random_normal_so3(t[:, None].expand(N, L), self.angular_distrib_fwd, device=self._dummy.device) # (N, L, 3) + e_normal = e_scaled / (c1 + 1e-8) + E_scaled = so3vec_to_rotation(e_scaled) # (N, L, 3, 3) + + # Scaled true rotation + R0_scaled = so3vec_to_rotation(c0 * v_0) # (N, L, 3, 3) + + R_noisy = E_scaled @ R0_scaled + v_noisy = rotation_to_so3vec(R_noisy) + v_noisy = torch.where(mask_generate[..., None].expand_as(v_0), v_noisy, v_0) + + return v_noisy, e_scaled + + def denoise(self, v_t, v_next, mask_generate, t): + N, L = mask_generate.size() + e = random_normal_so3(t[:, None].expand(N, L), self.angular_distrib_inv, device=self._dummy.device) # (N, L, 3) + e = torch.where( + (t > 1)[:, None, None].expand(N, L, 3), + e, + torch.zeros_like(e) # Simply denoise and don't add noise at the last step + ) + E = so3vec_to_rotation(e) + + R_next = E @ so3vec_to_rotation(v_next) + v_next = rotation_to_so3vec(R_next) + v_next = torch.where(mask_generate[..., None].expand_as(v_next), v_next, v_t) + + return v_next + + +class AminoacidCategoricalTransition(nn.Module): + + def __init__(self, num_steps, num_classes=20, var_sched_opt={}): + super().__init__() + self.num_classes = num_classes + self.var_sched = VarianceSchedule(num_steps, **var_sched_opt) + + @staticmethod + def _sample(c): + """ + Args: + c: (N, L, K). + Returns: + x: (N, L). 
+ """ + N, L, K = c.size() + c = c.view(N*L, K) + 1e-8 + x = torch.multinomial(c, 1).view(N, L) + return x + + def add_noise(self, x_0, mask_generate, t): + """ + Args: + x_0: (N, L) + mask_generate: (N, L). + t: (N,). + Returns: + c_t: Probability, (N, L, K). + x_t: Sample, LongTensor, (N, L). + """ + N, L = x_0.size() + K = self.num_classes + c_0 = clampped_one_hot(x_0, num_classes=K).float() # (N, L, K). + alpha_bar = self.var_sched.alpha_bars[t][:, None, None] # (N, 1, 1) + c_noisy = (alpha_bar*c_0) + ( (1-alpha_bar)/K ) + c_t = torch.where(mask_generate[..., None].expand(N,L,K), c_noisy, c_0) + x_t = self._sample(c_t) + return c_t, x_t + + def posterior(self, x_t, x_0, t): + """ + Args: + x_t: Category LongTensor (N, L) or Probability FloatTensor (N, L, K). + x_0: Category LongTensor (N, L) or Probability FloatTensor (N, L, K). + t: (N,). + Returns: + theta: Posterior probability at (t-1)-th step, (N, L, K). + """ + K = self.num_classes + + if x_t.dim() == 3: + c_t = x_t # When x_t is probability distribution. + else: + c_t = clampped_one_hot(x_t, num_classes=K).float() # (N, L, K) + + if x_0.dim() == 3: + c_0 = x_0 # When x_0 is probability distribution. + else: + c_0 = clampped_one_hot(x_0, num_classes=K).float() # (N, L, K) + + alpha = self.var_sched.alpha_bars[t][:, None, None] # (N, 1, 1) + alpha_bar = self.var_sched.alpha_bars[t][:, None, None] # (N, 1, 1) + + theta = ((alpha*c_t) + (1-alpha)/K) * ((alpha_bar*c_0) + (1-alpha_bar)/K) # (N, L, K) + theta = theta / (theta.sum(dim=-1, keepdim=True) + 1e-8) + return theta + + def denoise(self, x_t, c_0_pred, mask_generate, t): + """ + Args: + x_t: (N, L). + c_0_pred: Normalized probability predicted by networks, (N, L, K). + mask_generate: (N, L). + t: (N,). + Returns: + post: Posterior probability at (t-1)-th step, (N, L, K). + x_next: Sample at (t-1)-th step, LongTensor, (N, L). + """ + c_t = clampped_one_hot(x_t, num_classes=self.num_classes).float() # (N, L, K) + post = self.posterior(c_t, c_0_pred, t=t) # (N, L, K) + post = torch.where(mask_generate[..., None].expand(post.size()), post, c_t) + x_next = self._sample(post) + return post, x_next diff --git a/diffab/modules/encoders/ga.py b/diffab/modules/encoders/ga.py new file mode 100644 index 0000000000000000000000000000000000000000..c8679971f065484bbc7b49df2caf066d49527880 --- /dev/null +++ b/diffab/modules/encoders/ga.py @@ -0,0 +1,193 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from diffab.modules.common.geometry import global_to_local, local_to_global, normalize_vector, construct_3d_basis, angstrom_to_nm +from diffab.modules.common.layers import mask_zero, LayerNorm +from diffab.utils.protein.constants import BBHeavyAtom + + +def _alpha_from_logits(logits, mask, inf=1e5): + """ + Args: + logits: Logit matrices, (N, L_i, L_j, num_heads). + mask: Masks, (N, L). + Returns: + alpha: Attention weights. 
+ """ + N, L, _, _ = logits.size() + mask_row = mask.view(N, L, 1, 1).expand_as(logits) # (N, L, *, *) + mask_pair = mask_row * mask_row.permute(0, 2, 1, 3) # (N, L, L, *) + + logits = torch.where(mask_pair, logits, logits - inf) + alpha = torch.softmax(logits, dim=2) # (N, L, L, num_heads) + alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha)) + return alpha + + +def _heads(x, n_heads, n_ch): + """ + Args: + x: (..., num_heads * num_channels) + Returns: + (..., num_heads, num_channels) + """ + s = list(x.size())[:-1] + [n_heads, n_ch] + return x.view(*s) + + +class GABlock(nn.Module): + + def __init__(self, node_feat_dim, pair_feat_dim, value_dim=32, query_key_dim=32, num_query_points=8, + num_value_points=8, num_heads=12, bias=False): + super().__init__() + self.node_feat_dim = node_feat_dim + self.pair_feat_dim = pair_feat_dim + self.value_dim = value_dim + self.query_key_dim = query_key_dim + self.num_query_points = num_query_points + self.num_value_points = num_value_points + self.num_heads = num_heads + + # Node + self.proj_query = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias) + self.proj_key = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias) + self.proj_value = nn.Linear(node_feat_dim, value_dim * num_heads, bias=bias) + + # Pair + self.proj_pair_bias = nn.Linear(pair_feat_dim, num_heads, bias=bias) + + # Spatial + self.spatial_coef = nn.Parameter(torch.full([1, 1, 1, self.num_heads], fill_value=np.log(np.exp(1.) - 1.)), + requires_grad=True) + self.proj_query_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias) + self.proj_key_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias) + self.proj_value_point = nn.Linear(node_feat_dim, num_value_points * num_heads * 3, bias=bias) + + # Output + self.out_transform = nn.Linear( + in_features=(num_heads * pair_feat_dim) + (num_heads * value_dim) + ( + num_heads * num_value_points * (3 + 3 + 1)), + out_features=node_feat_dim, + ) + + self.layer_norm_1 = LayerNorm(node_feat_dim) + self.mlp_transition = nn.Sequential(nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(), + nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(), + nn.Linear(node_feat_dim, node_feat_dim)) + self.layer_norm_2 = LayerNorm(node_feat_dim) + + def _node_logits(self, x): + query_l = _heads(self.proj_query(x), self.num_heads, self.query_key_dim) # (N, L, n_heads, qk_ch) + key_l = _heads(self.proj_key(x), self.num_heads, self.query_key_dim) # (N, L, n_heads, qk_ch) + logits_node = (query_l.unsqueeze(2) * key_l.unsqueeze(1) * + (1 / np.sqrt(self.query_key_dim))).sum(-1) # (N, L, L, num_heads) + return logits_node + + def _pair_logits(self, z): + logits_pair = self.proj_pair_bias(z) + return logits_pair + + def _spatial_logits(self, R, t, x): + N, L, _ = t.size() + + # Query + query_points = _heads(self.proj_query_point(x), self.num_heads * self.num_query_points, + 3) # (N, L, n_heads * n_pnts, 3) + query_points = local_to_global(R, t, query_points) # Global query coordinates, (N, L, n_heads * n_pnts, 3) + query_s = query_points.reshape(N, L, self.num_heads, -1) # (N, L, n_heads, n_pnts*3) + + # Key + key_points = _heads(self.proj_key_point(x), self.num_heads * self.num_query_points, + 3) # (N, L, 3, n_heads * n_pnts) + key_points = local_to_global(R, t, key_points) # Global key coordinates, (N, L, n_heads * n_pnts, 3) + key_s = key_points.reshape(N, L, self.num_heads, -1) # (N, L, n_heads, n_pnts*3) + + # Q-K Product + sum_sq_dist = ((query_s.unsqueeze(2) - key_s.unsqueeze(1)) ** 2).sum(-1) 
# (N, L, L, n_heads) + gamma = F.softplus(self.spatial_coef) + logits_spatial = sum_sq_dist * ((-1 * gamma * np.sqrt(2 / (9 * self.num_query_points))) + / 2) # (N, L, L, n_heads) + return logits_spatial + + def _pair_aggregation(self, alpha, z): + N, L = z.shape[:2] + feat_p2n = alpha.unsqueeze(-1) * z.unsqueeze(-2) # (N, L, L, n_heads, C) + feat_p2n = feat_p2n.sum(dim=2) # (N, L, n_heads, C) + return feat_p2n.reshape(N, L, -1) + + def _node_aggregation(self, alpha, x): + N, L = x.shape[:2] + value_l = _heads(self.proj_value(x), self.num_heads, self.query_key_dim) # (N, L, n_heads, v_ch) + feat_node = alpha.unsqueeze(-1) * value_l.unsqueeze(1) # (N, L, L, n_heads, *) @ (N, *, L, n_heads, v_ch) + feat_node = feat_node.sum(dim=2) # (N, L, n_heads, v_ch) + return feat_node.reshape(N, L, -1) + + def _spatial_aggregation(self, alpha, R, t, x): + N, L, _ = t.size() + value_points = _heads(self.proj_value_point(x), self.num_heads * self.num_value_points, + 3) # (N, L, n_heads * n_v_pnts, 3) + value_points = local_to_global(R, t, value_points.reshape(N, L, self.num_heads, self.num_value_points, + 3)) # (N, L, n_heads, n_v_pnts, 3) + aggr_points = alpha.reshape(N, L, L, self.num_heads, 1, 1) * \ + value_points.unsqueeze(1) # (N, *, L, n_heads, n_pnts, 3) + aggr_points = aggr_points.sum(dim=2) # (N, L, n_heads, n_pnts, 3) + + feat_points = global_to_local(R, t, aggr_points) # (N, L, n_heads, n_pnts, 3) + feat_distance = feat_points.norm(dim=-1) # (N, L, n_heads, n_pnts) + feat_direction = normalize_vector(feat_points, dim=-1, eps=1e-4) # (N, L, n_heads, n_pnts, 3) + + feat_spatial = torch.cat([ + feat_points.reshape(N, L, -1), + feat_distance.reshape(N, L, -1), + feat_direction.reshape(N, L, -1), + ], dim=-1) + + return feat_spatial + + def forward(self, R, t, x, z, mask): + """ + Args: + R: Frame basis matrices, (N, L, 3, 3_index). + t: Frame external (absolute) coordinates, (N, L, 3). + x: Node-wise features, (N, L, F). + z: Pair-wise features, (N, L, L, C). + mask: Masks, (N, L). + Returns: + x': Updated node-wise features, (N, L, F). + """ + # Attention logits + logits_node = self._node_logits(x) + logits_pair = self._pair_logits(z) + logits_spatial = self._spatial_logits(R, t, x) + # Summing logits up and apply `softmax`. 
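+        # The sum of the three logit terms is rescaled by sqrt(1/3) so that its
+        # variance stays comparable to a single term (analogous to the logit
+        # weighting in AlphaFold2's invariant point attention).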
+ logits_sum = logits_node + logits_pair + logits_spatial + alpha = _alpha_from_logits(logits_sum * np.sqrt(1 / 3), mask) # (N, L, L, n_heads) + + # Aggregate features + feat_p2n = self._pair_aggregation(alpha, z) + feat_node = self._node_aggregation(alpha, x) + feat_spatial = self._spatial_aggregation(alpha, R, t, x) + + # Finally + feat_all = self.out_transform(torch.cat([feat_p2n, feat_node, feat_spatial], dim=-1)) # (N, L, F) + feat_all = mask_zero(mask.unsqueeze(-1), feat_all) + x_updated = self.layer_norm_1(x + feat_all) + x_updated = self.layer_norm_2(x_updated + self.mlp_transition(x_updated)) + return x_updated + + +class GAEncoder(nn.Module): + + def __init__(self, node_feat_dim, pair_feat_dim, num_layers, ga_block_opt={}): + super(GAEncoder, self).__init__() + self.blocks = nn.ModuleList([ + GABlock(node_feat_dim, pair_feat_dim, **ga_block_opt) + for _ in range(num_layers) + ]) + + def forward(self, R, t, res_feat, pair_feat, mask): + for i, block in enumerate(self.blocks): + res_feat = block(R, t, res_feat, pair_feat, mask) + return res_feat diff --git a/diffab/modules/encoders/pair.py b/diffab/modules/encoders/pair.py new file mode 100644 index 0000000000000000000000000000000000000000..becf8c6b1a102730c8ae224d93fc94f2ad494898 --- /dev/null +++ b/diffab/modules/encoders/pair.py @@ -0,0 +1,102 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffab.modules.common.geometry import angstrom_to_nm, pairwise_dihedrals +from diffab.modules.common.layers import AngularEncoding +from diffab.utils.protein.constants import BBHeavyAtom, AA + + +class PairEmbedding(nn.Module): + + def __init__(self, feat_dim, max_num_atoms, max_aa_types=22, max_relpos=32): + super().__init__() + self.max_num_atoms = max_num_atoms + self.max_aa_types = max_aa_types + self.max_relpos = max_relpos + self.aa_pair_embed = nn.Embedding(self.max_aa_types*self.max_aa_types, feat_dim) + self.relpos_embed = nn.Embedding(2*max_relpos+1, feat_dim) + + self.aapair_to_distcoef = nn.Embedding(self.max_aa_types*self.max_aa_types, max_num_atoms*max_num_atoms) + nn.init.zeros_(self.aapair_to_distcoef.weight) + self.distance_embed = nn.Sequential( + nn.Linear(max_num_atoms*max_num_atoms, feat_dim), nn.ReLU(), + nn.Linear(feat_dim, feat_dim), nn.ReLU(), + ) + + self.dihedral_embed = AngularEncoding() + feat_dihed_dim = self.dihedral_embed.get_out_dim(2) # Phi and Psi + + infeat_dim = feat_dim+feat_dim+feat_dim+feat_dihed_dim + self.out_mlp = nn.Sequential( + nn.Linear(infeat_dim, feat_dim), nn.ReLU(), + nn.Linear(feat_dim, feat_dim), nn.ReLU(), + nn.Linear(feat_dim, feat_dim), + ) + + def forward(self, aa, res_nb, chain_nb, pos_atoms, mask_atoms, structure_mask=None, sequence_mask=None): + """ + Args: + aa: (N, L). + res_nb: (N, L). + chain_nb: (N, L). + pos_atoms: (N, L, A, 3) + mask_atoms: (N, L, A) + structure_mask: (N, L) + sequence_mask: (N, L), mask out unknown amino acids to generate. 
+ + Returns: + (N, L, L, feat_dim) + """ + N, L = aa.size() + + # Remove other atoms + pos_atoms = pos_atoms[:, :, :self.max_num_atoms] + mask_atoms = mask_atoms[:, :, :self.max_num_atoms] + + mask_residue = mask_atoms[:, :, BBHeavyAtom.CA] # (N, L) + mask_pair = mask_residue[:, :, None] * mask_residue[:, None, :] + pair_structure_mask = structure_mask[:, :, None] * structure_mask[:, None, :] if structure_mask is not None else None + + # Pair identities + if sequence_mask is not None: + # Avoid data leakage at training time + aa = torch.where(sequence_mask, aa, torch.full_like(aa, fill_value=AA.UNK)) + aa_pair = aa[:,:,None]*self.max_aa_types + aa[:,None,:] # (N, L, L) + feat_aapair = self.aa_pair_embed(aa_pair) + + # Relative sequential positions + same_chain = (chain_nb[:, :, None] == chain_nb[:, None, :]) + relpos = torch.clamp( + res_nb[:,:,None] - res_nb[:,None,:], + min=-self.max_relpos, max=self.max_relpos, + ) # (N, L, L) + feat_relpos = self.relpos_embed(relpos + self.max_relpos) * same_chain[:,:,:,None] + + # Distances + d = angstrom_to_nm(torch.linalg.norm( + pos_atoms[:,:,None,:,None] - pos_atoms[:,None,:,None,:], + dim = -1, ord = 2, + )).reshape(N, L, L, -1) # (N, L, L, A*A) + c = F.softplus(self.aapair_to_distcoef(aa_pair)) # (N, L, L, A*A) + d_gauss = torch.exp(-1 * c * d**2) + mask_atom_pair = (mask_atoms[:,:,None,:,None] * mask_atoms[:,None,:,None,:]).reshape(N, L, L, -1) + feat_dist = self.distance_embed(d_gauss * mask_atom_pair) + if pair_structure_mask is not None: + # Avoid data leakage at training time + feat_dist = feat_dist * pair_structure_mask[:, :, :, None] + + # Orientations + dihed = pairwise_dihedrals(pos_atoms) # (N, L, L, 2) + feat_dihed = self.dihedral_embed(dihed) + if pair_structure_mask is not None: + # Avoid data leakage at training time + feat_dihed = feat_dihed * pair_structure_mask[:, :, :, None] + + # All + feat_all = torch.cat([feat_aapair, feat_relpos, feat_dist, feat_dihed], dim=-1) + feat_all = self.out_mlp(feat_all) # (N, L, L, F) + feat_all = feat_all * mask_pair[:, :, :, None] + + return feat_all + diff --git a/diffab/modules/encoders/residue.py b/diffab/modules/encoders/residue.py new file mode 100644 index 0000000000000000000000000000000000000000..a6b6ffb7a814e78178a083868764abdced6e5618 --- /dev/null +++ b/diffab/modules/encoders/residue.py @@ -0,0 +1,92 @@ +import torch +import torch.nn as nn + +from diffab.modules.common.geometry import construct_3d_basis, global_to_local, get_backbone_dihedral_angles +from diffab.modules.common.layers import AngularEncoding +from diffab.utils.protein.constants import BBHeavyAtom, AA + + +class ResidueEmbedding(nn.Module): + + def __init__(self, feat_dim, max_num_atoms, max_aa_types=22): + super().__init__() + self.max_num_atoms = max_num_atoms + self.max_aa_types = max_aa_types + self.aatype_embed = nn.Embedding(self.max_aa_types, feat_dim) + self.dihed_embed = AngularEncoding() + self.type_embed = nn.Embedding(10, feat_dim, padding_idx=0) # 1: Heavy, 2: Light, 3: Ag + infeat_dim = feat_dim + (self.max_aa_types*max_num_atoms*3) + self.dihed_embed.get_out_dim(3) + feat_dim + self.mlp = nn.Sequential( + nn.Linear(infeat_dim, feat_dim * 2), nn.ReLU(), + nn.Linear(feat_dim * 2, feat_dim), nn.ReLU(), + nn.Linear(feat_dim, feat_dim), nn.ReLU(), + nn.Linear(feat_dim, feat_dim) + ) + + def forward(self, aa, res_nb, chain_nb, pos_atoms, mask_atoms, fragment_type, structure_mask=None, sequence_mask=None): + """ + Args: + aa: (N, L). + res_nb: (N, L). + chain_nb: (N, L). + pos_atoms: (N, L, A, 3). 
+ mask_atoms: (N, L, A). + fragment_type: (N, L). + structure_mask: (N, L), mask out unknown structures to generate. + sequence_mask: (N, L), mask out unknown amino acids to generate. + """ + N, L = aa.size() + mask_residue = mask_atoms[:, :, BBHeavyAtom.CA] # (N, L) + + # Remove other atoms + pos_atoms = pos_atoms[:, :, :self.max_num_atoms] + mask_atoms = mask_atoms[:, :, :self.max_num_atoms] + + # Amino acid identity features + if sequence_mask is not None: + # Avoid data leakage at training time + aa = torch.where(sequence_mask, aa, torch.full_like(aa, fill_value=AA.UNK)) + aa_feat = self.aatype_embed(aa) # (N, L, feat) + + # Coordinate features + R = construct_3d_basis( + pos_atoms[:, :, BBHeavyAtom.CA], + pos_atoms[:, :, BBHeavyAtom.C], + pos_atoms[:, :, BBHeavyAtom.N] + ) + t = pos_atoms[:, :, BBHeavyAtom.CA] + crd = global_to_local(R, t, pos_atoms) # (N, L, A, 3) + crd_mask = mask_atoms[:, :, :, None].expand_as(crd) + crd = torch.where(crd_mask, crd, torch.zeros_like(crd)) + + aa_expand = aa[:, :, None, None, None].expand(N, L, self.max_aa_types, self.max_num_atoms, 3) + rng_expand = torch.arange(0, self.max_aa_types)[None, None, :, None, None].expand(N, L, self.max_aa_types, self.max_num_atoms, 3).to(aa_expand) + place_mask = (aa_expand == rng_expand) + crd_expand = crd[:, :, None, :, :].expand(N, L, self.max_aa_types, self.max_num_atoms, 3) + crd_expand = torch.where(place_mask, crd_expand, torch.zeros_like(crd_expand)) + crd_feat = crd_expand.reshape(N, L, self.max_aa_types*self.max_num_atoms*3) + if structure_mask is not None: + # Avoid data leakage at training time + crd_feat = crd_feat * structure_mask[:, :, None] + + # Backbone dihedral features + bb_dihedral, mask_bb_dihed = get_backbone_dihedral_angles(pos_atoms, chain_nb=chain_nb, res_nb=res_nb, mask=mask_residue) + dihed_feat = self.dihed_embed(bb_dihedral[:, :, :, None]) * mask_bb_dihed[:, :, :, None] # (N, L, 3, dihed/3) + dihed_feat = dihed_feat.reshape(N, L, -1) + if structure_mask is not None: + # Avoid data leakage at training time + dihed_mask = torch.logical_and( + structure_mask, + torch.logical_and( + torch.roll(structure_mask, shifts=+1, dims=1), + torch.roll(structure_mask, shifts=-1, dims=1) + ), + ) # Avoid slight data leakage via dihedral angles of anchor residues + dihed_feat = dihed_feat * dihed_mask[:, :, None] + + # Type feature + type_feat = self.type_embed(fragment_type) # (N, L, feat) + + out_feat = self.mlp(torch.cat([aa_feat, crd_feat, dihed_feat, type_feat], dim=-1)) # (N, L, F) + out_feat = out_feat * mask_residue[:, :, None] + return out_feat diff --git a/diffab/tools/dock/base.py b/diffab/tools/dock/base.py new file mode 100644 index 0000000000000000000000000000000000000000..68e52b6d33bf45433fccaffe5481af9d21f15bc4 --- /dev/null +++ b/diffab/tools/dock/base.py @@ -0,0 +1,28 @@ +import abc +from typing import List + + +FilePath = str + + +class DockingEngine(abc.ABC): + + @abc.abstractmethod + def __enter__(self): + pass + + @abc.abstractmethod + def __exit__(self, typ, value, traceback): + pass + + @abc.abstractmethod + def set_receptor(self, pdb_path: FilePath): + pass + + @abc.abstractmethod + def set_ligand(self, pdb_path: FilePath): + pass + + @abc.abstractmethod + def dock(self) -> List[FilePath]: + pass diff --git a/diffab/tools/dock/hdock.py b/diffab/tools/dock/hdock.py new file mode 100644 index 0000000000000000000000000000000000000000..cf3fbd7c2bdb7d5cab4ffb317aa3a39bca6ed5e6 --- /dev/null +++ b/diffab/tools/dock/hdock.py @@ -0,0 +1,164 @@ +import os +import shutil +import tempfile 
+import subprocess +import dataclasses as dc +from typing import List, Optional +from Bio import PDB +from Bio.PDB import Model as PDBModel + +from diffab.tools.renumber import renumber as renumber_chothia +from .base import DockingEngine + + +def fix_docked_pdb(pdb_path): + fixed = [] + with open(pdb_path, 'r') as f: + for ln in f.readlines(): + if (ln.startswith('ATOM') or ln.startswith('HETATM')) and len(ln) == 56: + fixed.append( ln[:-1] + ' 1.00 0.00 \n' ) + else: + fixed.append(ln) + with open(pdb_path, 'w') as f: + f.write(''.join(fixed)) + + +class HDock(DockingEngine): + + def __init__( + self, + hdock_bin='./bin/hdock', + createpl_bin='./bin/createpl', + ): + super().__init__() + self.hdock_bin = os.path.realpath(hdock_bin) + self.createpl_bin = os.path.realpath(createpl_bin) + self.tmpdir = tempfile.TemporaryDirectory() + + self._has_receptor = False + self._has_ligand = False + + self._receptor_chains = [] + self._ligand_chains = [] + + def __enter__(self): + return self + + def __exit__(self, typ, value, traceback): + self.tmpdir.cleanup() + + def set_receptor(self, pdb_path): + shutil.copyfile(pdb_path, os.path.join(self.tmpdir.name, 'receptor.pdb')) + self._has_receptor = True + + def set_ligand(self, pdb_path): + shutil.copyfile(pdb_path, os.path.join(self.tmpdir.name, 'ligand.pdb')) + self._has_ligand = True + + def _dump_complex_pdb(self): + parser = PDB.PDBParser(QUIET=True) + model_receptor = parser.get_structure(None, os.path.join(self.tmpdir.name, 'receptor.pdb'))[0] + docked_pdb_path = os.path.join(self.tmpdir.name, 'ligand_docked.pdb') + fix_docked_pdb(docked_pdb_path) + structure_ligdocked = parser.get_structure(None, docked_pdb_path) + + pdb_io = PDB.PDBIO() + paths = [] + for i, model_ligdocked in enumerate(structure_ligdocked): + model_complex = PDBModel.Model(0) + for chain in model_receptor: + model_complex.add(chain.copy()) + for chain in model_ligdocked: + model_complex.add(chain.copy()) + pdb_io.set_structure(model_complex) + save_path = os.path.join(self.tmpdir.name, f"complex_{i}.pdb") + pdb_io.save(save_path) + paths.append(save_path) + return paths + + def dock(self): + if not (self._has_receptor and self._has_ligand): + raise ValueError('Missing receptor or ligand.') + subprocess.run( + [self.hdock_bin, "receptor.pdb", "ligand.pdb"], + cwd=self.tmpdir.name, check=True + ) + subprocess.run( + [self.createpl_bin, "Hdock.out", "ligand_docked.pdb"], + cwd=self.tmpdir.name, check=True + ) + return self._dump_complex_pdb() + + +@dc.dataclass +class DockSite: + chain: str + resseq: int + + +class HDockAntibody(HDock): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._heavy_chain_id = None + self._epitope_sites: Optional[List[DockSite]] = None + + def set_ligand(self, pdb_path): + raise NotImplementedError('Please use set_antibody') + + def set_receptor(self, pdb_path): + raise NotImplementedError('Please use set_antigen') + + def set_antigen(self, pdb_path, epitope_sites: Optional[List[DockSite]]=None): + super().set_receptor(pdb_path) + self._epitope_sites = epitope_sites + + def set_antibody(self, pdb_path): + heavy_chains, _ = renumber_chothia(pdb_path, os.path.join(self.tmpdir.name, 'ligand.pdb')) + self._has_ligand = True + self._heavy_chain_id = heavy_chains[0] + + def _prepare_lsite(self): + lsite_content = f"95-102:{self._heavy_chain_id}\n" # Chothia CDR H3 + with open(os.path.join(self.tmpdir.name, 'lsite.txt'), 'w') as f: + f.write(lsite_content) + print(f"[INFO] lsite content: {lsite_content}") + + def 
_prepare_rsite(self): + rsite_content = "" + for site in self._epitope_sites: + rsite_content += f"{site.resseq}:{site.chain}\n" + with open(os.path.join(self.tmpdir.name, 'rsite.txt'), 'w') as f: + f.write(rsite_content) + print(f"[INFO] rsite content: {rsite_content}") + + def dock(self): + if not (self._has_receptor and self._has_ligand): + raise ValueError('Missing receptor or ligand.') + self._prepare_lsite() + + cmd_hdock = [self.hdock_bin, "receptor.pdb", "ligand.pdb", "-lsite", "lsite.txt"] + if self._epitope_sites is not None: + self._prepare_rsite() + cmd_hdock += ["-rsite", "rsite.txt"] + subprocess.run( + cmd_hdock, + cwd=self.tmpdir.name, check=True + ) + + cmd_pl = [self.createpl_bin, "Hdock.out", "ligand_docked.pdb", "-lsite", "lsite.txt"] + if self._epitope_sites is not None: + self._prepare_rsite() + cmd_pl += ["-rsite", "rsite.txt"] + subprocess.run( + cmd_pl, + cwd=self.tmpdir.name, check=True + ) + return self._dump_complex_pdb() + + +if __name__ == '__main__': + with HDockAntibody('hdock', 'createpl') as dock: + dock.set_antigen('./data/dock/receptor.pdb', [DockSite('A', 991)]) + dock.set_antibody('./data/example_dock/3qhf_fv.pdb') + print(dock.dock()) diff --git a/diffab/tools/eval/__main__.py b/diffab/tools/eval/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbcdeb82d00dcc46488d4ff37b67e21f342de368 --- /dev/null +++ b/diffab/tools/eval/__main__.py @@ -0,0 +1,4 @@ +from .run import main + +if __name__ == '__main__': + main() diff --git a/diffab/tools/eval/base.py b/diffab/tools/eval/base.py new file mode 100644 index 0000000000000000000000000000000000000000..867eaa9aab04fd54ef1c3c8300e3f63ad05ff8f0 --- /dev/null +++ b/diffab/tools/eval/base.py @@ -0,0 +1,125 @@ +import os +import re +import json +import shelve +from Bio import PDB +from typing import Optional, Tuple, List +from dataclasses import dataclass, field + + +@dataclass +class EvalTask: + in_path: str + ref_path: str + info: dict + structure: str + name: str + method: str + cdr: str + ab_chains: List + + residue_first: Optional[Tuple] = None + residue_last: Optional[Tuple] = None + + scores: dict = field(default_factory=dict) + + def get_gen_biopython_model(self): + parser = PDB.PDBParser(QUIET=True) + return parser.get_structure(self.in_path, self.in_path)[0] + + def get_ref_biopython_model(self): + parser = PDB.PDBParser(QUIET=True) + return parser.get_structure(self.ref_path, self.ref_path)[0] + + def save_to_db(self, db: shelve.Shelf): + db[self.in_path] = self + + def to_report_dict(self): + return { + 'method': self.method, + 'structure': self.structure, + 'cdr': self.cdr, + 'filename': os.path.basename(self.in_path), + **self.scores + } + + +class TaskScanner: + + def __init__(self, root, postfix=None, db: Optional[shelve.Shelf]=None): + super().__init__() + self.root = root + self.postfix = postfix + self.visited = set() + self.db = db + if db is not None: + for k in db.keys(): + self.visited.add(k) + + def _get_metadata(self, fpath): + json_path = os.path.join( + os.path.dirname(os.path.dirname(fpath)), + 'metadata.json' + ) + tag_name = os.path.basename(os.path.dirname(fpath)) + method_name = os.path.basename( + os.path.dirname(os.path.dirname(os.path.dirname(fpath))) + ) + try: + antibody_chains = set() + info = None + with open(json_path, 'r') as f: + metadata = json.load(f) + for item in metadata['items']: + if item['tag'] == tag_name: + info = item + antibody_chains.add(item['residue_first'][0]) + if info is not None: + info['antibody_chains'] = 
list(antibody_chains) + info['structure'] = metadata['identifier'] + info['method'] = method_name + return info + except (json.JSONDecodeError, FileNotFoundError) as e: + return None + + def scan(self) -> List[EvalTask]: + tasks = [] + if self.postfix is None or not self.postfix: + input_fname_pattern = '^\d+\.pdb$' + ref_fname = 'REF1.pdb' + else: + input_fname_pattern = f'^\d+\_{self.postfix}\.pdb$' + ref_fname = f'REF1_{self.postfix}.pdb' + for parent, _, files in os.walk(self.root): + for fname in files: + fpath = os.path.join(parent, fname) + if not re.match(input_fname_pattern, fname): + continue + if os.path.getsize(fpath) == 0: + continue + if fpath in self.visited: + continue + + # Path to the reference structure + ref_path = os.path.join(parent, ref_fname) + if not os.path.exists(ref_path): + continue + + # CDR information + info = self._get_metadata(fpath) + if info is None: + continue + tasks.append(EvalTask( + in_path = fpath, + ref_path = ref_path, + info = info, + structure = info['structure'], + name = info['name'], + method = info['method'], + cdr = info['tag'], + ab_chains = info['antibody_chains'], + residue_first = info.get('residue_first', None), + residue_last = info.get('residue_last', None), + )) + self.visited.add(fpath) + return tasks diff --git a/diffab/tools/eval/energy.py b/diffab/tools/eval/energy.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd97d1ac90cdd7cd2db93e35e60934a3632c782 --- /dev/null +++ b/diffab/tools/eval/energy.py @@ -0,0 +1,43 @@ +# pyright: reportMissingImports=false +import pyrosetta +from pyrosetta.rosetta.protocols.analysis import InterfaceAnalyzerMover +pyrosetta.init(' '.join([ + '-mute', 'all', + '-use_input_sc', + '-ignore_unrecognized_res', + '-ignore_zero_occupancy', 'false', + '-load_PDB_components', 'false', + '-relax:default_repeats', '2', + '-no_fconfig', +])) + +from tools.eval.base import EvalTask + + +def pyrosetta_interface_energy(pdb_path, interface): + pose = pyrosetta.pose_from_pdb(pdb_path) + mover = InterfaceAnalyzerMover(interface) + mover.set_pack_separated(True) + mover.apply(pose) + return pose.scores['dG_separated'] + + +def eval_interface_energy(task: EvalTask): + model_gen = task.get_gen_biopython_model() + antigen_chains = set() + for chain in model_gen: + if chain.id not in task.ab_chains: + antigen_chains.add(chain.id) + antigen_chains = ''.join(list(antigen_chains)) + antibody_chains = ''.join(task.ab_chains) + interface = f"{antibody_chains}_{antigen_chains}" + + dG_gen = pyrosetta_interface_energy(task.in_path, interface) + dG_ref = pyrosetta_interface_energy(task.ref_path, interface) + + task.scores.update({ + 'dG_gen': dG_gen, + 'dG_ref': dG_ref, + 'ddG': dG_gen - dG_ref + }) + return task diff --git a/diffab/tools/eval/run.py b/diffab/tools/eval/run.py new file mode 100644 index 0000000000000000000000000000000000000000..f902e40d5e26adf53b6dde14ea0d94d77a12903f --- /dev/null +++ b/diffab/tools/eval/run.py @@ -0,0 +1,66 @@ +import os +import argparse +import ray +import shelve +import time +import pandas as pd +from typing import Mapping + +from tools.eval.base import EvalTask, TaskScanner +from tools.eval.similarity import eval_similarity +from tools.eval.energy import eval_interface_energy + + +@ray.remote(num_cpus=1) +def evaluate(task, args): + funcs = [] + funcs.append(eval_similarity) + if not args.no_energy: + funcs.append(eval_interface_energy) + for f in funcs: + task = f(task) + return task + + +def dump_db(db: Mapping[str, EvalTask], path): + table = [] + for task in 
db.values(): + if 'abopt' in path and task.scores['seqid'] >= 100.0: + # In abopt (Antibody Optimization) mode, ignore sequences identical to the wild-type + continue + table.append(task.to_report_dict()) + table = pd.DataFrame(table) + table.to_csv(path, index=False, float_format='%.6f') + return table + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--root', type=str, default='./results') + parser.add_argument('--pfx', type=str, default='rosetta') + parser.add_argument('--no_energy', action='store_true', default=False) + args = parser.parse_args() + ray.init() + + db_path = os.path.join(args.root, 'evaluation_db') + with shelve.open(db_path) as db: + scanner = TaskScanner(root=args.root, postfix=args.pfx, db=db) + + while True: + tasks = scanner.scan() + futures = [evaluate.remote(t, args) for t in tasks] + if len(futures) > 0: + print(f'Submitted {len(futures)} tasks.') + while len(futures) > 0: + done_ids, futures = ray.wait(futures, num_returns=1) + for done_id in done_ids: + done_task = ray.get(done_id) + done_task.save_to_db(db) + print(f'Remaining {len(futures)}. Finished {done_task.in_path}') + db.sync() + + dump_db(db, os.path.join(args.root, 'summary.csv')) + time.sleep(1.0) + +if __name__ == '__main__': + main() diff --git a/diffab/tools/eval/similarity.py b/diffab/tools/eval/similarity.py new file mode 100644 index 0000000000000000000000000000000000000000..0cd396f27bc3e98e1e79a501cc1117318d1cfe96 --- /dev/null +++ b/diffab/tools/eval/similarity.py @@ -0,0 +1,125 @@ +import numpy as np +from Bio.PDB import PDBParser, Selection +from Bio.PDB.Polypeptide import three_to_one +from Bio import pairwise2 +from Bio.Align import substitution_matrices + +from diffab.tools.eval.base import EvalTask + + +def reslist_rmsd(res_list1, res_list2): + res_short, res_long = (res_list1, res_list2) if len(res_list1) < len(res_list2) else (res_list2, res_list1) + M, N = len(res_short), len(res_long) + + def d(i, j): + coord_i = np.array(res_short[i]['CA'].get_coord()) + coord_j = np.array(res_long[j]['CA'].get_coord()) + return ((coord_i - coord_j) ** 2).sum() + + SD = np.full([M, N], np.inf) + for i in range(M): + j = N - (M - i) + SD[i, j] = sum([ d(i+k, j+k) for k in range(N-j) ]) + + for j in range(N): + SD[M-1, j] = d(M-1, j) + + for i in range(M-2, -1, -1): + for j in range((N-(M-i))-1, -1, -1): + SD[i, j] = min( + d(i, j) + SD[i+1, j+1], + SD[i, j+1] + ) + + min_SD = SD[0, :N-M+1].min() + best_RMSD = np.sqrt(min_SD / M) + return best_RMSD + + +def entity_to_seq(entity): + seq = '' + mapping = [] + for res in Selection.unfold_entities(entity, 'R'): + try: + seq += three_to_one(res.get_resname()) + mapping.append(res.get_id()) + except KeyError: + pass + assert len(seq) == len(mapping) + return seq, mapping + + +def reslist_seqid(res_list1, res_list2): + seq1, _ = entity_to_seq(res_list1) + seq2, _ = entity_to_seq(res_list2) + _, seq_id = align_sequences(seq1, seq2) + return seq_id + + +def align_sequences(sequence_A, sequence_B, **kwargs): + """ + Performs a global pairwise alignment between two sequences + using the BLOSUM62 matrix and the Needleman-Wunsch algorithm + as implemented in Biopython. Returns the alignment, the sequence + identity and the residue mapping between both original sequences. + """ + + def _calculate_identity(sequenceA, sequenceB): + """ + Returns the percentage of identical characters between two sequences. + Assumes the sequences are aligned. 
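+        Identity is computed as 100 * matches / alignment_length, so gap
+        columns count against the identity score.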
+ """ + + sa, sb, sl = sequenceA, sequenceB, len(sequenceA) + matches = [sa[i] == sb[i] for i in range(sl)] + seq_id = (100 * sum(matches)) / sl + return seq_id + + # gapless_sl = sum([1 for i in range(sl) if (sa[i] != '-' and sb[i] != '-')]) + # gap_id = (100 * sum(matches)) / gapless_sl + # return (seq_id, gap_id) + + # + matrix = kwargs.get('matrix', substitution_matrices.load("BLOSUM62")) + gap_open = kwargs.get('gap_open', -10.0) + gap_extend = kwargs.get('gap_extend', -0.5) + + alns = pairwise2.align.globalds(sequence_A, sequence_B, + matrix, gap_open, gap_extend, + penalize_end_gaps=(False, False) ) + + best_aln = alns[0] + aligned_A, aligned_B, score, begin, end = best_aln + + # Calculate sequence identity + seq_id = _calculate_identity(aligned_A, aligned_B) + return (aligned_A, aligned_B), seq_id + + +def extract_reslist(model, residue_first, residue_last): + assert residue_first[0] == residue_last[0] + residue_first, residue_last = tuple(residue_first), tuple(residue_last) + + chain_id = residue_first[0] + pos_first, pos_last = residue_first[1:], residue_last[1:] + chain = model[chain_id] + reslist = [] + for res in Selection.unfold_entities(chain, 'R'): + pos_current = (res.id[1], res.id[2]) + if pos_first <= pos_current <= pos_last: + reslist.append(res) + return reslist + + +def eval_similarity(task: EvalTask): + model_gen = task.get_gen_biopython_model() + model_ref = task.get_ref_biopython_model() + + reslist_gen = extract_reslist(model_gen, task.residue_first, task.residue_last) + reslist_ref = extract_reslist(model_ref, task.residue_first, task.residue_last) + + task.scores.update({ + 'rmsd': reslist_rmsd(reslist_gen, reslist_ref), + 'seqid': reslist_seqid(reslist_gen, reslist_ref), + }) + return task diff --git a/diffab/tools/relax/__main__.py b/diffab/tools/relax/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbcdeb82d00dcc46488d4ff37b67e21f342de368 --- /dev/null +++ b/diffab/tools/relax/__main__.py @@ -0,0 +1,4 @@ +from .run import main + +if __name__ == '__main__': + main() diff --git a/diffab/tools/relax/base.py b/diffab/tools/relax/base.py new file mode 100644 index 0000000000000000000000000000000000000000..88996180702819a02cc59f2a88ae2dc20ce9f6a8 --- /dev/null +++ b/diffab/tools/relax/base.py @@ -0,0 +1,117 @@ +import os +import re +import json +from typing import Optional, Tuple, List +from dataclasses import dataclass + + +@dataclass +class RelaxTask: + in_path: str + current_path: str + info: dict + status: str + + flexible_residue_first: Optional[Tuple] = None + flexible_residue_last: Optional[Tuple] = None + + def get_in_path_with_tag(self, tag): + name, ext = os.path.splitext(self.in_path) + new_path = f'{name}_{tag}{ext}' + return new_path + + def set_current_path_tag(self, tag): + new_path = self.get_in_path_with_tag(tag) + self.current_path = new_path + return new_path + + def check_current_path_exists(self): + ok = os.path.exists(self.current_path) + if not ok: + self.mark_failure() + if os.path.getsize(self.current_path) == 0: + ok = False + self.mark_failure() + os.unlink(self.current_path) + return ok + + def update_if_finished(self, tag): + out_path = self.get_in_path_with_tag(tag) + if os.path.exists(out_path) and os.path.getsize(out_path) > 0: + # print('Already finished', out_path) + self.set_current_path_tag(tag) + self.mark_success() + return True + return False + + def can_proceed(self): + self.check_current_path_exists() + return self.status != 'failed' + + def mark_success(self): + self.status = 'success' + + 
def mark_failure(self): + self.status = 'failed' + + + +class TaskScanner: + + def __init__(self, root, final_postfix=None): + super().__init__() + self.root = root + self.visited = set() + self.final_postfix = final_postfix + + def _get_metadata(self, fpath): + json_path = os.path.join( + os.path.dirname(os.path.dirname(fpath)), + 'metadata.json' + ) + tag_name = os.path.basename(os.path.dirname(fpath)) + try: + with open(json_path, 'r') as f: + metadata = json.load(f) + for item in metadata['items']: + if item['tag'] == tag_name: + return item + except (json.JSONDecodeError, FileNotFoundError) as e: + return None + return None + + def scan(self) -> List[RelaxTask]: + tasks = [] + input_fname_pattern = '(^\d+\.pdb$|^REF\d\.pdb$)' + for parent, _, files in os.walk(self.root): + for fname in files: + fpath = os.path.join(parent, fname) + if not re.match(input_fname_pattern, fname): + continue + if os.path.getsize(fpath) == 0: + continue + if fpath in self.visited: + continue + + # If finished + if self.final_postfix is not None: + fpath_name, fpath_ext = os.path.splitext(fpath) + fpath_final = f"{fpath_name}_{self.final_postfix}{fpath_ext}" + if os.path.exists(fpath_final): + continue + + # Get metadata + info = self._get_metadata(fpath) + if info is None: + continue + + tasks.append(RelaxTask( + in_path = fpath, + current_path = fpath, + info = info, + status = 'created', + flexible_residue_first = info.get('residue_first', None), + flexible_residue_last = info.get('residue_last', None), + )) + self.visited.add(fpath) + return tasks diff --git a/diffab/tools/relax/openmm_relaxer.py b/diffab/tools/relax/openmm_relaxer.py new file mode 100644 index 0000000000000000000000000000000000000000..1ad8a60d5a03f96e2a2e771ad866cb90e620a720 --- /dev/null +++ b/diffab/tools/relax/openmm_relaxer.py @@ -0,0 +1,144 @@ +import os +import time +import io +import logging +import pdbfixer +import openmm +from openmm import app as openmm_app +from openmm import unit +ENERGY = unit.kilocalories_per_mole +LENGTH = unit.angstroms + +from diffab.tools.relax.base import RelaxTask + + +def current_milli_time(): + return round(time.time() * 1000) + + +def _is_in_the_range(ch_rs_ic, flexible_residue_first, flexible_residue_last): + if ch_rs_ic[0] != flexible_residue_first[0]: return False + r_first, r_last = tuple(flexible_residue_first[1:]), tuple(flexible_residue_last[1:]) + rs_ic = ch_rs_ic[1:] + return r_first <= rs_ic <= r_last + + +class ForceFieldMinimizer(object): + + def __init__(self, stiffness=10.0, max_iterations=0, tolerance=2.39*unit.kilocalories_per_mole, platform='CUDA'): + super().__init__() + self.stiffness = stiffness + self.max_iterations = max_iterations + self.tolerance = tolerance + assert platform in ('CUDA', 'CPU') + self.platform = platform + + def _fix(self, pdb_str): + fixer = pdbfixer.PDBFixer(pdbfile=io.StringIO(pdb_str)) + fixer.findNonstandardResidues() + fixer.replaceNonstandardResidues() + + fixer.findMissingResidues() + fixer.findMissingAtoms() + fixer.addMissingAtoms(seed=0) + fixer.addMissingHydrogens() + + out_handle = io.StringIO() + openmm_app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle, keepIds=True) + return out_handle.getvalue() + + def _get_pdb_string(self, topology, positions): + with io.StringIO() as f: + openmm_app.PDBFile.writeFile(topology, positions, f, keepIds=True) + return f.getvalue() + + def _minimize(self, pdb_str, flexible_residue_first=None, flexible_residue_last=None): + pdb = openmm_app.PDBFile(io.StringIO(pdb_str)) + + force_field = 
openmm_app.ForceField("amber99sb.xml") + constraints = openmm_app.HBonds + system = force_field.createSystem(pdb.topology, constraints=constraints) + + # Add constraints to non-generated regions + force = openmm.CustomExternalForce("0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)") + force.addGlobalParameter("k", self.stiffness) + for p in ["x0", "y0", "z0"]: + force.addPerParticleParameter(p) + + if flexible_residue_first is not None and flexible_residue_last is not None: + for i, a in enumerate(pdb.topology.atoms()): + ch_rs_ic = (a.residue.chain.id, int(a.residue.id), a.residue.insertionCode) + if not _is_in_the_range(ch_rs_ic, flexible_residue_first, flexible_residue_last) and a.element.name != "hydrogen": + force.addParticle(i, pdb.positions[i]) + + system.addForce(force) + + # Set up the integrator and simulation + integrator = openmm.LangevinIntegrator(0, 0.01, 0.0) + platform = openmm.Platform.getPlatformByName("CUDA") + simulation = openmm_app.Simulation(pdb.topology, system, integrator, platform) + simulation.context.setPositions(pdb.positions) + + # Perform minimization + ret = {} + state = simulation.context.getState(getEnergy=True, getPositions=True) + ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY) + ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH) + + simulation.minimizeEnergy(maxIterations=self.max_iterations, tolerance=self.tolerance) + + state = simulation.context.getState(getEnergy=True, getPositions=True) + ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY) + ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH) + ret["min_pdb"] = self._get_pdb_string(simulation.topology, state.getPositions()) + + return ret['min_pdb'], ret + + def _add_energy_remarks(self, pdb_str, ret): + pdb_lines = pdb_str.splitlines() + pdb_lines.insert(1, "REMARK 1 FINAL ENERGY: {:.3f} KCAL/MOL".format(ret['efinal'])) + pdb_lines.insert(1, "REMARK 1 INITIAL ENERGY: {:.3f} KCAL/MOL".format(ret['einit'])) + return "\n".join(pdb_lines) + + def __call__(self, pdb_str, flexible_residue_first=None, flexible_residue_last=None, return_info=True): + if '\n' not in pdb_str and pdb_str.lower().endswith(".pdb"): + with open(pdb_str) as f: + pdb_str = f.read() + + pdb_fixed = self._fix(pdb_str) + pdb_min, ret = self._minimize(pdb_fixed, flexible_residue_first, flexible_residue_last) + pdb_min = self._add_energy_remarks(pdb_min, ret) + if return_info: + return pdb_min, ret + else: + return pdb_min + + +def run_openmm(task: RelaxTask): + if not task.can_proceed() : + return task + if task.update_if_finished('openmm'): + return task + + try: + minimizer = ForceFieldMinimizer() + with open(task.current_path, 'r') as f: + pdb_str = f.read() + + pdb_min = minimizer( + pdb_str = pdb_str, + flexible_residue_first = task.flexible_residue_first, + flexible_residue_last = task.flexible_residue_last, + return_info = False, + ) + out_path = task.set_current_path_tag('openmm') + with open(out_path, 'w') as f: + f.write(pdb_min) + task.mark_success() + except ValueError as e: + logging.warning( + f'{e.__class__.__name__}: {str(e)} ({task.current_path})' + ) + task.mark_failure() + return task + diff --git a/diffab/tools/relax/pyrosetta_relaxer.py b/diffab/tools/relax/pyrosetta_relaxer.py new file mode 100644 index 0000000000000000000000000000000000000000..2696f313850d2faa1695883904884e9ccb9cd964 --- /dev/null +++ b/diffab/tools/relax/pyrosetta_relaxer.py @@ -0,0 +1,189 @@ +# pyright: reportMissingImports=false +import os +import time +import pyrosetta +from 
pyrosetta.rosetta.protocols.relax import FastRelax +from pyrosetta.rosetta.core.pack.task import TaskFactory +from pyrosetta.rosetta.core.pack.task import operation +from pyrosetta.rosetta.core.select import residue_selector as selections +from pyrosetta.rosetta.core.select.movemap import MoveMapFactory, move_map_action +pyrosetta.init(' '.join([ + '-mute', 'all', + '-use_input_sc', + '-ignore_unrecognized_res', + '-ignore_zero_occupancy', 'false', + '-load_PDB_components', 'false', + '-relax:default_repeats', '2', + '-no_fconfig', +])) + +from diffab.tools.relax.base import RelaxTask + + +def current_milli_time(): + return round(time.time() * 1000) + + +def parse_residue_position(p): + icode = None + if not p[-1].isnumeric(): # Has ICODE + icode = p[-1] + + for i, c in enumerate(p): + if c.isnumeric(): + break + chain = p[:i] + resseq = int(p[i:]) + + if icode is not None: + return chain, resseq, icode + else: + return chain, resseq + + +def get_scorefxn(scorefxn_name:str): + """ + Gets the scorefxn with appropriate corrections. + Taken from: https://gist.github.com/matteoferla/b33585f3aeab58b8424581279e032550 + """ + import pyrosetta + + corrections = { + 'beta_july15': False, + 'beta_nov16': False, + 'gen_potential': False, + 'restore_talaris_behavior': False, + } + if 'beta_july15' in scorefxn_name or 'beta_nov15' in scorefxn_name: + # beta_july15 is ref2015 + corrections['beta_july15'] = True + elif 'beta_nov16' in scorefxn_name: + corrections['beta_nov16'] = True + elif 'genpot' in scorefxn_name: + corrections['gen_potential'] = True + pyrosetta.rosetta.basic.options.set_boolean_option('corrections:beta_july15', True) + elif 'talaris' in scorefxn_name: #2013 and 2014 + corrections['restore_talaris_behavior'] = True + else: + pass + for corr, value in corrections.items(): + pyrosetta.rosetta.basic.options.set_boolean_option(f'corrections:{corr}', value) + return pyrosetta.create_score_function(scorefxn_name) + + +class RelaxRegion(object): + + def __init__(self, scorefxn='ref2015', max_iter=1000, subset='nbrs', move_bb=True): + super().__init__() + self.scorefxn = get_scorefxn(scorefxn) + self.fast_relax = FastRelax() + self.fast_relax.set_scorefxn(self.scorefxn) + self.fast_relax.max_iter(max_iter) + assert subset in ('all', 'target', 'nbrs') + self.subset = subset + self.move_bb = move_bb + + def __call__(self, pdb_path, flexible_residue_first, flexible_residue_last): + pose = pyrosetta.pose_from_pdb(pdb_path) + start_t = current_milli_time() + original_pose = pose.clone() + + tf = TaskFactory() + tf.push_back(operation.InitializeFromCommandline()) + tf.push_back(operation.RestrictToRepacking()) # Only allow residues to repack. No design at any position. 
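+        # flexible_residue_first / flexible_residue_last are assumed to be
+        # (chain_id, resseq, icode) tuples as produced by get_residue_first_last();
+        # a blank insertion code (' ') is stripped below before the positions are
+        # mapped to pose numbering via pose.pdb_info().pdb2pose(*...).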
+ + # Create selector for the region to be relaxed + # Turn off design and repacking on irrelevant positions + if flexible_residue_first[-1] == ' ': + flexible_residue_first = flexible_residue_first[:-1] + if flexible_residue_last[-1] == ' ': + flexible_residue_last = flexible_residue_last[:-1] + if self.subset != 'all': + gen_selector = selections.ResidueIndexSelector() + gen_selector.set_index_range( + pose.pdb_info().pdb2pose(*flexible_residue_first), + pose.pdb_info().pdb2pose(*flexible_residue_last), + ) + nbr_selector = selections.NeighborhoodResidueSelector() + nbr_selector.set_focus_selector(gen_selector) + nbr_selector.set_include_focus_in_subset(True) + + if self.subset == 'nbrs': + subset_selector = nbr_selector + elif self.subset == 'target': + subset_selector = gen_selector + + prevent_repacking_rlt = operation.PreventRepackingRLT() + prevent_subset_repacking = operation.OperateOnResidueSubset( + prevent_repacking_rlt, + subset_selector, + flip_subset=True, + ) + tf.push_back(prevent_subset_repacking) + + scorefxn = self.scorefxn + fr = self.fast_relax + + pose = original_pose.clone() + pos_list = pyrosetta.rosetta.utility.vector1_unsigned_long() + for pos in range(pose.pdb_info().pdb2pose(*flexible_residue_first), pose.pdb_info().pdb2pose(*flexible_residue_last)+1): + pos_list.append(pos) + # basic_idealize(pose, pos_list, scorefxn, fast=True) + + mmf = MoveMapFactory() + if self.move_bb: + mmf.add_bb_action(move_map_action.mm_enable, gen_selector) + mmf.add_chi_action(move_map_action.mm_enable, subset_selector) + mm = mmf.create_movemap_from_pose(pose) + + fr.set_movemap(mm) + fr.set_task_factory(tf) + fr.apply(pose) + + e_before = scorefxn(original_pose) + e_relax = scorefxn(pose) + # print('\n\n[Finished in %.2f secs]' % ((current_milli_time() - start_t) / 1000)) + # print(' > Energy (before): %.4f' % scorefxn(original_pose)) + # print(' > Energy (optimized): %.4f' % scorefxn(pose)) + return pose, e_before, e_relax + + +def run_pyrosetta(task: RelaxTask): + if not task.can_proceed() : + return task + if task.update_if_finished('rosetta'): + return task + + minimizer = RelaxRegion() + pose_min, _, _ = minimizer( + pdb_path = task.current_path, + flexible_residue_first = task.flexible_residue_first, + flexible_residue_last = task.flexible_residue_last, + ) + + out_path = task.set_current_path_tag('rosetta') + pose_min.dump_pdb(out_path) + task.mark_success() + return task + + +def run_pyrosetta_fixbb(task: RelaxTask): + if not task.can_proceed() : + return task + if task.update_if_finished('fixbb'): + return task + + minimizer = RelaxRegion(move_bb=False) + pose_min, _, _ = minimizer( + pdb_path = task.current_path, + flexible_residue_first = task.flexible_residue_first, + flexible_residue_last = task.flexible_residue_last, + ) + + out_path = task.set_current_path_tag('fixbb') + pose_min.dump_pdb(out_path) + task.mark_success() + return task + + + diff --git a/diffab/tools/relax/run.py b/diffab/tools/relax/run.py new file mode 100644 index 0000000000000000000000000000000000000000..2cbfd57589e539443709b0d38d9615b6f8b42dbd --- /dev/null +++ b/diffab/tools/relax/run.py @@ -0,0 +1,85 @@ +import argparse +import ray +import time + +from diffab.tools.relax.openmm_relaxer import run_openmm +from diffab.tools.relax.pyrosetta_relaxer import run_pyrosetta, run_pyrosetta_fixbb +from diffab.tools.relax.base import TaskScanner + + +@ray.remote(num_gpus=1/8, num_cpus=1) +def run_openmm_remote(task): + return run_openmm(task) + + +@ray.remote(num_cpus=1) +def run_pyrosetta_remote(task): + 
return run_pyrosetta(task) + + +@ray.remote(num_cpus=1) +def run_pyrosetta_fixbb_remote(task): + return run_pyrosetta_fixbb(task) + + +@ray.remote +def pipeline_openmm_pyrosetta(task): + funcs = [ + run_openmm_remote, + run_pyrosetta_remote, + ] + for fn in funcs: + task = fn.remote(task) + return ray.get(task) + + +@ray.remote +def pipeline_pyrosetta(task): + funcs = [ + run_pyrosetta_remote, + ] + for fn in funcs: + task = fn.remote(task) + return ray.get(task) + + +@ray.remote +def pipeline_pyrosetta_fixbb(task): + funcs = [ + run_pyrosetta_fixbb_remote, + ] + for fn in funcs: + task = fn.remote(task) + return ray.get(task) + + +pipeline_dict = { + 'openmm_pyrosetta': pipeline_openmm_pyrosetta, + 'pyrosetta': pipeline_pyrosetta, + 'pyrosetta_fixbb': pipeline_pyrosetta_fixbb, +} + + +def main(): + ray.init() + parser = argparse.ArgumentParser() + parser.add_argument('--root', type=str, default='./results') + parser.add_argument('--pipeline', type=lambda s: pipeline_dict[s], default=pipeline_openmm_pyrosetta) + args = parser.parse_args() + + final_pfx = 'fixbb' if args.pipeline == pipeline_pyrosetta_fixbb else 'rosetta' + scanner = TaskScanner(args.root, final_postfix=final_pfx) + while True: + tasks = scanner.scan() + futures = [args.pipeline.remote(t) for t in tasks] + if len(futures) > 0: + print(f'Submitted {len(futures)} tasks.') + while len(futures) > 0: + done_ids, futures = ray.wait(futures, num_returns=1) + for done_id in done_ids: + done_task = ray.get(done_id) + print(f'Remaining {len(futures)}. Finished {done_task.current_path}') + time.sleep(1.0) + +if __name__ == '__main__': + main() diff --git a/diffab/tools/renumber/__init__.py b/diffab/tools/renumber/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..95c78181a7c9dd2ba964b3c44be1bce1cba46f60 --- /dev/null +++ b/diffab/tools/renumber/__init__.py @@ -0,0 +1 @@ +from .run import renumber diff --git a/diffab/tools/renumber/__main__.py b/diffab/tools/renumber/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..7ca1c759534be0a3ae14e4ab178e981eb4fb211a --- /dev/null +++ b/diffab/tools/renumber/__main__.py @@ -0,0 +1,5 @@ +from .run import main + +if __name__ == '__main__': + main() + \ No newline at end of file diff --git a/diffab/tools/renumber/run.py b/diffab/tools/renumber/run.py new file mode 100644 index 0000000000000000000000000000000000000000..50bfb98e8ed8d12a7a0748b741659eb94d27b1e7 --- /dev/null +++ b/diffab/tools/renumber/run.py @@ -0,0 +1,85 @@ +import argparse +import abnumber +from Bio import PDB +from Bio.PDB import Model, Chain, Residue, Selection +from Bio.Data import SCOPData +from typing import List, Tuple + + +def biopython_chain_to_sequence(chain: Chain.Chain): + residue_list = Selection.unfold_entities(chain, 'R') + seq = ''.join([SCOPData.protein_letters_3to1.get(r.resname, 'X') for r in residue_list]) + return seq, residue_list + + +def assign_number_to_sequence(seq): + abchain = abnumber.Chain(seq, scheme='chothia') + offset = seq.index(abchain.seq) + if not (offset >= 0): + raise ValueError( + 'The identified Fv sequence is not a subsequence of the original sequence.' 
+ ) + + numbers = [None for _ in range(len(seq))] + for i, (pos, aa) in enumerate(abchain): + resseq = pos.number + icode = pos.letter if pos.letter else ' ' + numbers[i+offset] = (resseq, icode) + return numbers, abchain + + +def renumber_biopython_chain(chain_id, residue_list: List[Residue.Residue], numbers: List[Tuple[int, str]]): + chain = Chain.Chain(chain_id) + for residue, number in zip(residue_list, numbers): + if number is None: + continue + residue = residue.copy() + new_id = (residue.id[0], number[0], number[1]) + residue.id = new_id + chain.add(residue) + return chain + + +def renumber(in_pdb, out_pdb, return_other_chains=False): + parser = PDB.PDBParser(QUIET=True) + structure = parser.get_structure(None, in_pdb) + model = structure[0] + model_new = Model.Model(0) + + heavy_chains, light_chains, other_chains = [], [], [] + + for chain in model: + try: + seq, reslist = biopython_chain_to_sequence(chain) + numbers, abchain = assign_number_to_sequence(seq) + chain_new = renumber_biopython_chain(chain.id, reslist, numbers) + print(f'[INFO] Renumbered chain {chain_new.id} ({abchain.chain_type})') + if abchain.chain_type == 'H': + heavy_chains.append(chain_new.id) + elif abchain.chain_type in ('K', 'L'): + light_chains.append(chain_new.id) + except abnumber.ChainParseError as e: + print(f'[INFO] Chain {chain.id} does not contain valid Fv: {str(e)}') + chain_new = chain.copy() + other_chains.append(chain_new.id) + model_new.add(chain_new) + + pdb_io = PDB.PDBIO() + pdb_io.set_structure(model_new) + pdb_io.save(out_pdb) + if return_other_chains: + return heavy_chains, light_chains, other_chains + else: + return heavy_chains, light_chains + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('in_pdb', type=str) + parser.add_argument('out_pdb', type=str) + args = parser.parse_args() + + renumber(args.in_pdb, args.out_pdb) + +if __name__ == '__main__': + main() diff --git a/diffab/tools/runner/design_for_pdb.py b/diffab/tools/runner/design_for_pdb.py new file mode 100644 index 0000000000000000000000000000000000000000..e67c38800c0c9d76b17ca122d580372a0b68f0a7 --- /dev/null +++ b/diffab/tools/runner/design_for_pdb.py @@ -0,0 +1,291 @@ +import os +import argparse +import copy +import json +from tqdm.auto import tqdm +from torch.utils.data import DataLoader + +from diffab.datasets.custom import preprocess_antibody_structure +from diffab.models import get_model +from diffab.modules.common.geometry import reconstruct_backbone_partially +from diffab.modules.common.so3 import so3vec_to_rotation +from diffab.utils.inference import RemoveNative +from diffab.utils.protein.writers import save_pdb +from diffab.utils.train import recursive_to +from diffab.utils.misc import * +from diffab.utils.data import * +from diffab.utils.transforms import * +from diffab.utils.inference import * +from diffab.tools.renumber import renumber as renumber_antibody + + +def create_data_variants(config, structure_factory): + structure = structure_factory() + structure_id = structure['id'] + + data_variants = [] + if config.mode == 'single_cdr': + cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs))) + for cdr_name in cdrs: + transform = Compose([ + MaskSingleCDR(cdr_name, augmentation=False), + MergeChains(), + ]) + data_var = transform(structure_factory()) + residue_first, residue_last = get_residue_first_last(data_var) + data_variants.append({ + 'data': data_var, + 'name': f'{structure_id}-{cdr_name}', + 'tag': f'{cdr_name}', + 'cdr': cdr_name, + 'residue_first': 
residue_first, + 'residue_last': residue_last, + }) + elif config.mode == 'multiple_cdrs': + cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs))) + transform = Compose([ + MaskMultipleCDRs(selection=cdrs, augmentation=False), + MergeChains(), + ]) + data_var = transform(structure_factory()) + data_variants.append({ + 'data': data_var, + 'name': f'{structure_id}-MultipleCDRs', + 'tag': 'MultipleCDRs', + 'cdrs': cdrs, + 'residue_first': None, + 'residue_last': None, + }) + elif config.mode == 'full': + transform = Compose([ + MaskAntibody(), + MergeChains(), + ]) + data_var = transform(structure_factory()) + data_variants.append({ + 'data': data_var, + 'name': f'{structure_id}-Full', + 'tag': 'Full', + 'residue_first': None, + 'residue_last': None, + }) + elif config.mode == 'abopt': + cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs))) + for cdr_name in cdrs: + transform = Compose([ + MaskSingleCDR(cdr_name, augmentation=False), + MergeChains(), + ]) + data_var = transform(structure_factory()) + residue_first, residue_last = get_residue_first_last(data_var) + for opt_step in config.sampling.optimize_steps: + data_variants.append({ + 'data': data_var, + 'name': f'{structure_id}-{cdr_name}-O{opt_step}', + 'tag': f'{cdr_name}-O{opt_step}', + 'cdr': cdr_name, + 'opt_step': opt_step, + 'residue_first': residue_first, + 'residue_last': residue_last, + }) + else: + raise ValueError(f'Unknown mode: {config.mode}.') + return data_variants + + +def design_for_pdb(args): + # Load configs + config, config_name = load_config(args.config) + seed_all(args.seed if args.seed is not None else config.sampling.seed) + + # Structure loading + data_id = os.path.basename(args.pdb_path) + if args.no_renumber: + pdb_path = args.pdb_path + else: + in_pdb_path = args.pdb_path + out_pdb_path = os.path.splitext(in_pdb_path)[0] + '_chothia.pdb' + heavy_chains, light_chains = renumber_antibody(in_pdb_path, out_pdb_path) + pdb_path = out_pdb_path + + if args.heavy is None and len(heavy_chains) > 0: + args.heavy = heavy_chains[0] + if args.light is None and len(light_chains) > 0: + args.light = light_chains[0] + if args.heavy is None and args.light is None: + raise ValueError("Neither heavy chain id (--heavy) or light chain id (--light) is specified.") + get_structure = lambda: preprocess_antibody_structure({ + 'id': data_id, + 'pdb_path': pdb_path, + 'heavy_id': args.heavy, + # If the input is a nanobody, the light chain will be ignores + 'light_id': args.light, + }) + + # Logging + structure_ = get_structure() + structure_id = structure_['id'] + tag_postfix = '_%s' % args.tag if args.tag else '' + log_dir = get_new_log_dir( + os.path.join(args.out_root, config_name + tag_postfix), + prefix=data_id + ) + logger = get_logger('sample', log_dir) + logger.info(f'Data ID: {structure_["id"]}') + logger.info(f'Results will be saved to {log_dir}') + data_native = MergeChains()(structure_) + save_pdb(data_native, os.path.join(log_dir, 'reference.pdb')) + + # Load checkpoint and model + logger.info('Loading model config and checkpoints: %s' % (config.model.checkpoint)) + ckpt = torch.load(config.model.checkpoint, map_location='cpu') + cfg_ckpt = ckpt['config'] + model = get_model(cfg_ckpt.model).to(args.device) + lsd = model.load_state_dict(ckpt['model']) + logger.info(str(lsd)) + + # Make data variants + data_variants = create_data_variants( + config = config, + structure_factory = get_structure, + ) + + # Save metadata + metadata = { + 'identifier': structure_id, + 
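+        # 'items' below is what TaskScanner._get_metadata() in diffab/tools/relax/base.py
+        # reads back: it matches each item's 'tag' against the sample sub-directory name
+        # to recover 'residue_first' / 'residue_last' for constrained relaxation.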
'index': data_id, + 'config': args.config, + 'items': [{kk: vv for kk, vv in var.items() if kk != 'data'} for var in data_variants], + } + with open(os.path.join(log_dir, 'metadata.json'), 'w') as f: + json.dump(metadata, f, indent=2) + + # Start sampling + collate_fn = PaddingCollate(eight=False) + inference_tfm = [ PatchAroundAnchor(), ] + if 'abopt' not in config.mode: # Don't remove native CDR in optimization mode + inference_tfm.append(RemoveNative( + remove_structure = config.sampling.sample_structure, + remove_sequence = config.sampling.sample_sequence, + )) + inference_tfm = Compose(inference_tfm) + + for variant in data_variants: + os.makedirs(os.path.join(log_dir, variant['tag']), exist_ok=True) + logger.info(f"Start sampling for: {variant['tag']}") + + save_pdb(data_native, os.path.join(log_dir, variant['tag'], 'REF1.pdb')) # w/ OpenMM minimization + + data_cropped = inference_tfm( + copy.deepcopy(variant['data']) + ) + data_list_repeat = [ data_cropped ] * config.sampling.num_samples + loader = DataLoader(data_list_repeat, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn) + + count = 0 + for batch in tqdm(loader, desc=variant['name'], dynamic_ncols=True): + torch.set_grad_enabled(False) + model.eval() + batch = recursive_to(batch, args.device) + if 'abopt' in config.mode: + # Antibody optimization starting from native + traj_batch = model.optimize(batch, opt_step=variant['opt_step'], optimize_opt={ + 'pbar': True, + 'sample_structure': config.sampling.sample_structure, + 'sample_sequence': config.sampling.sample_sequence, + }) + else: + # De novo design + traj_batch = model.sample(batch, sample_opt={ + 'pbar': True, + 'sample_structure': config.sampling.sample_structure, + 'sample_sequence': config.sampling.sample_sequence, + }) + + aa_new = traj_batch[0][2] # 0: Last sampling step. 2: Amino acid. 
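+                # traj_batch[0] is the last sampling step; its entries are
+                # [0] SO(3) rotation vectors, [1] frame translations, [2] amino-acid types,
+                # from which the generated backbone is rebuilt below.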
+ pos_atom_new, mask_atom_new = reconstruct_backbone_partially( + pos_ctx = batch['pos_heavyatom'], + R_new = so3vec_to_rotation(traj_batch[0][0]), + t_new = traj_batch[0][1], + aa = aa_new, + chain_nb = batch['chain_nb'], + res_nb = batch['res_nb'], + mask_atoms = batch['mask_heavyatom'], + mask_recons = batch['generate_flag'], + ) + aa_new = aa_new.cpu() + pos_atom_new = pos_atom_new.cpu() + mask_atom_new = mask_atom_new.cpu() + + for i in range(aa_new.size(0)): + data_tmpl = variant['data'] + aa = apply_patch_to_tensor(data_tmpl['aa'], aa_new[i], data_cropped['patch_idx']) + mask_ha = apply_patch_to_tensor(data_tmpl['mask_heavyatom'], mask_atom_new[i], data_cropped['patch_idx']) + pos_ha = ( + apply_patch_to_tensor( + data_tmpl['pos_heavyatom'], + pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(), + data_cropped['patch_idx'] + ) + ) + + save_path = os.path.join(log_dir, variant['tag'], '%04d.pdb' % (count, )) + save_pdb({ + 'chain_nb': data_tmpl['chain_nb'], + 'chain_id': data_tmpl['chain_id'], + 'resseq': data_tmpl['resseq'], + 'icode': data_tmpl['icode'], + # Generated + 'aa': aa, + 'mask_heavyatom': mask_ha, + 'pos_heavyatom': pos_ha, + }, path=save_path) + # save_pdb({ + # 'chain_nb': data_cropped['chain_nb'], + # 'chain_id': data_cropped['chain_id'], + # 'resseq': data_cropped['resseq'], + # 'icode': data_cropped['icode'], + # # Generated + # 'aa': aa_new[i], + # 'mask_heavyatom': mask_atom_new[i], + # 'pos_heavyatom': pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(), + # }, path=os.path.join(log_dir, variant['tag'], '%04d_patch.pdb' % (count, ))) + count += 1 + + logger.info('Finished.\n') + + +def args_from_cmdline(): + parser = argparse.ArgumentParser() + parser.add_argument('pdb_path', type=str) + parser.add_argument('--heavy', type=str, default=None, help='Chain id of the heavy chain.') + parser.add_argument('--light', type=str, default=None, help='Chain id of the light chain.') + parser.add_argument('--no_renumber', action='store_true', default=False) + parser.add_argument('-c', '--config', type=str, default='./configs/test/codesign_single.yml') + parser.add_argument('-o', '--out_root', type=str, default='./results') + parser.add_argument('-t', '--tag', type=str, default='') + parser.add_argument('-s', '--seed', type=int, default=None) + parser.add_argument('-d', '--device', type=str, default='cuda') + parser.add_argument('-b', '--batch_size', type=int, default=16) + args = parser.parse_args() + return args + + +def args_factory(**kwargs): + default_args = EasyDict( + heavy = 'H', + light = 'L', + no_renumber = False, + config = './configs/test/codesign_single.yml', + out_root = './results', + tag = '', + seed = None, + device = 'cuda', + batch_size = 16 + ) + default_args.update(kwargs) + return default_args + + +if __name__ == '__main__': + design_for_pdb(args_from_cmdline()) diff --git a/diffab/tools/runner/design_for_testset.py b/diffab/tools/runner/design_for_testset.py new file mode 100644 index 0000000000000000000000000000000000000000..c0ce6c96bc06a3c921514255de758707d78baf63 --- /dev/null +++ b/diffab/tools/runner/design_for_testset.py @@ -0,0 +1,243 @@ +import os +import argparse +import copy +import json +from tqdm.auto import tqdm +from torch.utils.data import DataLoader + +from diffab.datasets import get_dataset +from diffab.models import get_model +from diffab.modules.common.geometry import reconstruct_backbone_partially +from diffab.modules.common.so3 import so3vec_to_rotation +from diffab.utils.inference import RemoveNative +from 
diffab.utils.protein.writers import save_pdb +from diffab.utils.train import recursive_to +from diffab.utils.misc import * +from diffab.utils.data import * +from diffab.utils.transforms import * +from diffab.utils.inference import * + + +def create_data_variants(config, structure_factory): + structure = structure_factory() + structure_id = structure['id'] + + data_variants = [] + if config.mode == 'single_cdr': + cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs))) + for cdr_name in cdrs: + transform = Compose([ + MaskSingleCDR(cdr_name, augmentation=False), + MergeChains(), + ]) + data_var = transform(structure_factory()) + residue_first, residue_last = get_residue_first_last(data_var) + data_variants.append({ + 'data': data_var, + 'name': f'{structure_id}-{cdr_name}', + 'tag': f'{cdr_name}', + 'cdr': cdr_name, + 'residue_first': residue_first, + 'residue_last': residue_last, + }) + elif config.mode == 'multiple_cdrs': + cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs))) + transform = Compose([ + MaskMultipleCDRs(selection=cdrs, augmentation=False), + MergeChains(), + ]) + data_var = transform(structure_factory()) + data_variants.append({ + 'data': data_var, + 'name': f'{structure_id}-MultipleCDRs', + 'tag': 'MultipleCDRs', + 'cdrs': cdrs, + 'residue_first': None, + 'residue_last': None, + }) + elif config.mode == 'full': + transform = Compose([ + MaskAntibody(), + MergeChains(), + ]) + data_var = transform(structure_factory()) + data_variants.append({ + 'data': data_var, + 'name': f'{structure_id}-Full', + 'tag': 'Full', + 'residue_first': None, + 'residue_last': None, + }) + elif config.mode == 'abopt': + cdrs = sorted(list(set(find_cdrs(structure)).intersection(config.sampling.cdrs))) + for cdr_name in cdrs: + transform = Compose([ + MaskSingleCDR(cdr_name, augmentation=False), + MergeChains(), + ]) + data_var = transform(structure_factory()) + residue_first, residue_last = get_residue_first_last(data_var) + for opt_step in config.sampling.optimize_steps: + data_variants.append({ + 'data': data_var, + 'name': f'{structure_id}-{cdr_name}-O{opt_step}', + 'tag': f'{cdr_name}-O{opt_step}', + 'cdr': cdr_name, + 'opt_step': opt_step, + 'residue_first': residue_first, + 'residue_last': residue_last, + }) + else: + raise ValueError(f'Unknown mode: {config.mode}.') + return data_variants + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('index', type=int) + parser.add_argument('-c', '--config', type=str, default='./configs/test/codesign_single.yml') + parser.add_argument('-o', '--out_root', type=str, default='./results') + parser.add_argument('-t', '--tag', type=str, default='') + parser.add_argument('-s', '--seed', type=int, default=None) + parser.add_argument('-d', '--device', type=str, default='cuda') + parser.add_argument('-b', '--batch_size', type=int, default=16) + args = parser.parse_args() + + # Load configs + config, config_name = load_config(args.config) + seed_all(args.seed if args.seed is not None else config.sampling.seed) + + # Testset + dataset = get_dataset(config.dataset.test) + get_structure = lambda: dataset[args.index] + + # Logging + structure_ = get_structure() + structure_id = structure_['id'] + tag_postfix = '_%s' % args.tag if args.tag else '' + log_dir = get_new_log_dir(os.path.join(args.out_root, config_name + tag_postfix), prefix='%04d_%s' % (args.index, structure_['id'])) + logger = get_logger('sample', log_dir) + logger.info('Data ID: %s' % structure_['id']) + data_native = 
MergeChains()(structure_) + save_pdb(data_native, os.path.join(log_dir, 'reference.pdb')) + + # Load checkpoint and model + logger.info('Loading model config and checkpoints: %s' % (config.model.checkpoint)) + ckpt = torch.load(config.model.checkpoint, map_location='cpu') + cfg_ckpt = ckpt['config'] + model = get_model(cfg_ckpt.model).to(args.device) + lsd = model.load_state_dict(ckpt['model']) + logger.info(str(lsd)) + + # Make data variants + data_variants = create_data_variants( + config = config, + structure_factory = get_structure, + ) + + # Save metadata + metadata = { + 'identifier': structure_id, + 'index': args.index, + 'config': args.config, + 'items': [{kk: vv for kk, vv in var.items() if kk != 'data'} for var in data_variants], + } + with open(os.path.join(log_dir, 'metadata.json'), 'w') as f: + json.dump(metadata, f, indent=2) + + # Start sampling + collate_fn = PaddingCollate(eight=False) + inference_tfm = [ PatchAroundAnchor(), ] + if 'abopt' not in config.mode: # Don't remove native CDR in optimization mode + inference_tfm.append(RemoveNative( + remove_structure = config.sampling.sample_structure, + remove_sequence = config.sampling.sample_sequence, + )) + inference_tfm = Compose(inference_tfm) + + for variant in data_variants: + os.makedirs(os.path.join(log_dir, variant['tag']), exist_ok=True) + logger.info(f"Start sampling for: {variant['tag']}") + + save_pdb(data_native, os.path.join(log_dir, variant['tag'], 'REF1.pdb')) # w/ OpenMM minimization + + data_cropped = inference_tfm( + copy.deepcopy(variant['data']) + ) + data_list_repeat = [ data_cropped ] * config.sampling.num_samples + loader = DataLoader(data_list_repeat, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn) + + count = 0 + for batch in tqdm(loader, desc=variant['name'], dynamic_ncols=True): + torch.set_grad_enabled(False) + model.eval() + batch = recursive_to(batch, args.device) + if 'abopt' in config.mode: + # Antibody optimization starting from native + traj_batch = model.optimize(batch, opt_step=variant['opt_step'], optimize_opt={ + 'pbar': True, + 'sample_structure': config.sampling.sample_structure, + 'sample_sequence': config.sampling.sample_sequence, + }) + else: + # De novo design + traj_batch = model.sample(batch, sample_opt={ + 'pbar': True, + 'sample_structure': config.sampling.sample_structure, + 'sample_sequence': config.sampling.sample_sequence, + }) + + aa_new = traj_batch[0][2] # 0: Last sampling step. 2: Amino acid. 
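+                # Same convention as in design_for_pdb.py: index 0 of traj_batch is the
+                # last sampling step (rotation vectors, translations, amino-acid types).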
+ pos_atom_new, mask_atom_new = reconstruct_backbone_partially( + pos_ctx = batch['pos_heavyatom'], + R_new = so3vec_to_rotation(traj_batch[0][0]), + t_new = traj_batch[0][1], + aa = aa_new, + chain_nb = batch['chain_nb'], + res_nb = batch['res_nb'], + mask_atoms = batch['mask_heavyatom'], + mask_recons = batch['generate_flag'], + ) + aa_new = aa_new.cpu() + pos_atom_new = pos_atom_new.cpu() + mask_atom_new = mask_atom_new.cpu() + + for i in range(aa_new.size(0)): + data_tmpl = variant['data'] + aa = apply_patch_to_tensor(data_tmpl['aa'], aa_new[i], data_cropped['patch_idx']) + mask_ha = apply_patch_to_tensor(data_tmpl['mask_heavyatom'], mask_atom_new[i], data_cropped['patch_idx']) + pos_ha = ( + apply_patch_to_tensor( + data_tmpl['pos_heavyatom'], + pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(), + data_cropped['patch_idx'] + ) + ) + + save_path = os.path.join(log_dir, variant['tag'], '%04d.pdb' % (count, )) + save_pdb({ + 'chain_nb': data_tmpl['chain_nb'], + 'chain_id': data_tmpl['chain_id'], + 'resseq': data_tmpl['resseq'], + 'icode': data_tmpl['icode'], + # Generated + 'aa': aa, + 'mask_heavyatom': mask_ha, + 'pos_heavyatom': pos_ha, + }, path=save_path) + # save_pdb({ + # 'chain_nb': data_cropped['chain_nb'], + # 'chain_id': data_cropped['chain_id'], + # 'resseq': data_cropped['resseq'], + # 'icode': data_cropped['icode'], + # # Generated + # 'aa': aa_new[i], + # 'mask_heavyatom': mask_atom_new[i], + # 'pos_heavyatom': pos_atom_new[i] + batch['origin'][i].view(1, 1, 3).cpu(), + # }, path=os.path.join(log_dir, variant['tag'], '%04d_patch.pdb' % (count, ))) + count += 1 + + logger.info('Finished.\n') + + +if __name__ == '__main__': + main() diff --git a/diffab/utils/data.py b/diffab/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..c206ae5d8e304a0117b78a38cc144a48bc8e5d10 --- /dev/null +++ b/diffab/utils/data.py @@ -0,0 +1,89 @@ +import math +import torch +from torch.utils.data._utils.collate import default_collate + + +DEFAULT_PAD_VALUES = { + 'aa': 21, + 'chain_id': ' ', + 'icode': ' ', +} + +DEFAULT_NO_PADDING = { + 'origin', +} + +class PaddingCollate(object): + + def __init__(self, length_ref_key='aa', pad_values=DEFAULT_PAD_VALUES, no_padding=DEFAULT_NO_PADDING, eight=True): + super().__init__() + self.length_ref_key = length_ref_key + self.pad_values = pad_values + self.no_padding = no_padding + self.eight = eight + + @staticmethod + def _pad_last(x, n, value=0): + if isinstance(x, torch.Tensor): + assert x.size(0) <= n + if x.size(0) == n: + return x + pad_size = [n - x.size(0)] + list(x.shape[1:]) + pad = torch.full(pad_size, fill_value=value).to(x) + return torch.cat([x, pad], dim=0) + elif isinstance(x, list): + pad = [value] * (n - len(x)) + return x + pad + else: + return x + + @staticmethod + def _get_pad_mask(l, n): + return torch.cat([ + torch.ones([l], dtype=torch.bool), + torch.zeros([n-l], dtype=torch.bool) + ], dim=0) + + @staticmethod + def _get_common_keys(list_of_dict): + keys = set(list_of_dict[0].keys()) + for d in list_of_dict[1:]: + keys = keys.intersection(d.keys()) + return keys + + + def _get_pad_value(self, key): + if key not in self.pad_values: + return 0 + return self.pad_values[key] + + def __call__(self, data_list): + max_length = max([data[self.length_ref_key].size(0) for data in data_list]) + keys = self._get_common_keys(data_list) + + if self.eight: + max_length = math.ceil(max_length / 8) * 8 + data_list_padded = [] + for data in data_list: + data_padded = { + k: self._pad_last(v, max_length, 
value=self._get_pad_value(k)) if k not in self.no_padding else v + for k, v in data.items() + if k in keys + } + data_padded['mask'] = self._get_pad_mask(data[self.length_ref_key].size(0), max_length) + data_list_padded.append(data_padded) + return default_collate(data_list_padded) + + +def apply_patch_to_tensor(x_full, x_patch, patch_idx): + """ + Args: + x_full: (N, ...) + x_patch: (M, ...) + patch_idx: (M, ) + Returns: + (N, ...) + """ + x_full = x_full.clone() + x_full[patch_idx] = x_patch + return x_full diff --git a/diffab/utils/inference.py b/diffab/utils/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..edd2653c0f4627c957cd560eaef818948e7afd3a --- /dev/null +++ b/diffab/utils/inference.py @@ -0,0 +1,60 @@ +import torch +from .protein import constants + + +def find_cdrs(structure): + cdrs = [] + if structure['heavy'] is not None: + flag = structure['heavy']['cdr_flag'] + if int(constants.CDR.H1) in flag: + cdrs.append('H_CDR1') + if int(constants.CDR.H2) in flag: + cdrs.append('H_CDR2') + if int(constants.CDR.H3) in flag: + cdrs.append('H_CDR3') + + if structure['light'] is not None: + flag = structure['light']['cdr_flag'] + if int(constants.CDR.L1) in flag: + cdrs.append('L_CDR1') + if int(constants.CDR.L2) in flag: + cdrs.append('L_CDR2') + if int(constants.CDR.L3) in flag: + cdrs.append('L_CDR3') + + return cdrs + + +def get_residue_first_last(data): + loop_flag = data['generate_flag'] + loop_idx = torch.arange(loop_flag.size(0))[loop_flag] + idx_first, idx_last = loop_idx.min().item(), loop_idx.max().item() + residue_first = (data['chain_id'][idx_first], data['resseq'][idx_first].item(), data['icode'][idx_first]) + residue_last = (data['chain_id'][idx_last], data['resseq'][idx_last].item(), data['icode'][idx_last]) + return residue_first, residue_last + + +class RemoveNative(object): + + def __init__(self, remove_structure, remove_sequence): + super().__init__() + self.remove_structure = remove_structure + self.remove_sequence = remove_sequence + + def __call__(self, data): + generate_flag = data['generate_flag'].clone() + if self.remove_sequence: + data['aa'] = torch.where( + generate_flag, + torch.full_like(data['aa'], fill_value=int(constants.AA.UNK)), # Is loop + data['aa'] + ) + + if self.remove_structure: + data['pos_heavyatom'] = torch.where( + generate_flag[:, None, None].expand(data['pos_heavyatom'].shape), + torch.randn_like(data['pos_heavyatom']) * 10, + data['pos_heavyatom'] + ) + + return data \ No newline at end of file diff --git a/diffab/utils/misc.py b/diffab/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..d3a47d038390b5fb93e399a0e4a6be00240643f8 --- /dev/null +++ b/diffab/utils/misc.py @@ -0,0 +1,126 @@ +import os +import time +import random +import logging +from typing import OrderedDict +import torch +import torch.linalg +import numpy as np +import yaml +from easydict import EasyDict +from glob import glob + + +class BlackHole(object): + def __setattr__(self, name, value): + pass + + def __call__(self, *args, **kwargs): + return self + + def __getattr__(self, name): + return self + + +class Counter(object): + def __init__(self, start=0): + super().__init__() + self.now = start + + def step(self, delta=1): + prev = self.now + self.now += delta + return prev + + +def get_logger(name, log_dir=None): + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s') + + stream_handler = logging.StreamHandler() + 
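+    # Always log to the console; a file handler writing <log_dir>/log.txt is added
+    # below when log_dir is provided.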
stream_handler.setLevel(logging.DEBUG) + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) + + if log_dir is not None: + file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt')) + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger + + +def get_new_log_dir(root='./logs', prefix='', tag=''): + fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime()) + if prefix != '': + fn = prefix + '_' + fn + if tag != '': + fn = fn + '_' + tag + log_dir = os.path.join(root, fn) + os.makedirs(log_dir) + return log_dir + + +def seed_all(seed): + torch.backends.cudnn.deterministic = True + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def inf_iterator(iterable): + iterator = iterable.__iter__() + while True: + try: + yield iterator.__next__() + except StopIteration: + iterator = iterable.__iter__() + + +def log_hyperparams(writer, args): + from torch.utils.tensorboard.summary import hparams + vars_args = {k: v if isinstance(v, str) else repr(v) for k, v in vars(args).items()} + exp, ssi, sei = hparams(vars_args, {}) + writer.file_writer.add_summary(exp) + writer.file_writer.add_summary(ssi) + writer.file_writer.add_summary(sei) + + +def int_tuple(argstr): + return tuple(map(int, argstr.split(','))) + + +def str_tuple(argstr): + return tuple(argstr.split(',')) + + +def get_checkpoint_path(folder, it=None): + if it is not None: + return os.path.join(folder, '%d.pt' % it), it + all_iters = list(map(lambda x: int(os.path.basename(x[:-3])), glob(os.path.join(folder, '*.pt')))) + all_iters.sort() + return os.path.join(folder, '%d.pt' % all_iters[-1]), all_iters[-1] + + +def load_config(config_path): + with open(config_path, 'r') as f: + config = EasyDict(yaml.safe_load(f)) + config_name = os.path.basename(config_path)[:os.path.basename(config_path).rfind('.')] + return config, config_name + + +def extract_weights(weights: OrderedDict, prefix): + extracted = OrderedDict() + for k, v in weights.items(): + if k.startswith(prefix): + extracted.update({ + k[len(prefix):]: v + }) + return extracted + + +def current_milli_time(): + return round(time.time() * 1000) diff --git a/diffab/utils/protein/constants.py b/diffab/utils/protein/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..04aafd1767a368c9ed9d6699f0e65101ff5ce9c0 --- /dev/null +++ b/diffab/utils/protein/constants.py @@ -0,0 +1,319 @@ +import torch +import enum + +class CDR(enum.IntEnum): + H1 = 1 + H2 = 2 + H3 = 3 + L1 = 4 + L2 = 5 + L3 = 6 + + +class ChothiaCDRRange: + H1 = (26, 32) + H2 = (52, 56) + H3 = (95, 102) + + L1 = (24, 34) + L2 = (50, 56) + L3 = (89, 97) + + @classmethod + def to_cdr(cls, chain_type, resseq): + assert chain_type in ('H', 'L') + if chain_type == 'H': + if cls.H1[0] <= resseq <= cls.H1[1]: + return CDR.H1 + elif cls.H2[0] <= resseq <= cls.H2[1]: + return CDR.H2 + elif cls.H3[0] <= resseq <= cls.H3[1]: + return CDR.H3 + elif chain_type == 'L': + if cls.L1[0] <= resseq <= cls.L1[1]: # Chothia VH-CDR1 + return CDR.L1 + elif cls.L2[0] <= resseq <= cls.L2[1]: + return CDR.L2 + elif cls.L3[0] <= resseq <= cls.L3[1]: + return CDR.L3 + + +class Fragment(enum.IntEnum): + Heavy = 1 + Light = 2 + Antigen = 3 + +## +# Residue identities +""" +This is part of the OpenMM molecular simulation toolkit originating from +Simbios, the NIH National Center for Physics-Based Simulation of +Biological Structures at Stanford, funded under the NIH Roadmap 
for +Medical Research, grant U54 GM072970. See https://simtk.org. + +Portions copyright (c) 2013 Stanford University and the Authors. +Authors: Peter Eastman +Contributors: + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE +USE OR OTHER DEALINGS IN THE SOFTWARE. +""" +non_standard_residue_substitutions = { + '2AS':'ASP', '3AH':'HIS', '5HP':'GLU', 'ACL':'ARG', 'AGM':'ARG', 'AIB':'ALA', 'ALM':'ALA', 'ALO':'THR', 'ALY':'LYS', 'ARM':'ARG', + 'ASA':'ASP', 'ASB':'ASP', 'ASK':'ASP', 'ASL':'ASP', 'ASQ':'ASP', 'AYA':'ALA', 'BCS':'CYS', 'BHD':'ASP', 'BMT':'THR', 'BNN':'ALA', + 'BUC':'CYS', 'BUG':'LEU', 'C5C':'CYS', 'C6C':'CYS', 'CAS':'CYS', 'CCS':'CYS', 'CEA':'CYS', 'CGU':'GLU', 'CHG':'ALA', 'CLE':'LEU', 'CME':'CYS', + 'CSD':'ALA', 'CSO':'CYS', 'CSP':'CYS', 'CSS':'CYS', 'CSW':'CYS', 'CSX':'CYS', 'CXM':'MET', 'CY1':'CYS', 'CY3':'CYS', 'CYG':'CYS', + 'CYM':'CYS', 'CYQ':'CYS', 'DAH':'PHE', 'DAL':'ALA', 'DAR':'ARG', 'DAS':'ASP', 'DCY':'CYS', 'DGL':'GLU', 'DGN':'GLN', 'DHA':'ALA', + 'DHI':'HIS', 'DIL':'ILE', 'DIV':'VAL', 'DLE':'LEU', 'DLY':'LYS', 'DNP':'ALA', 'DPN':'PHE', 'DPR':'PRO', 'DSN':'SER', 'DSP':'ASP', + 'DTH':'THR', 'DTR':'TRP', 'DTY':'TYR', 'DVA':'VAL', 'EFC':'CYS', 'FLA':'ALA', 'FME':'MET', 'GGL':'GLU', 'GL3':'GLY', 'GLZ':'GLY', + 'GMA':'GLU', 'GSC':'GLY', 'HAC':'ALA', 'HAR':'ARG', 'HIC':'HIS', 'HIP':'HIS', 'HMR':'ARG', 'HPQ':'PHE', 'HTR':'TRP', 'HYP':'PRO', + 'IAS':'ASP', 'IIL':'ILE', 'IYR':'TYR', 'KCX':'LYS', 'LLP':'LYS', 'LLY':'LYS', 'LTR':'TRP', 'LYM':'LYS', 'LYZ':'LYS', 'MAA':'ALA', 'MEN':'ASN', + 'MHS':'HIS', 'MIS':'SER', 'MLE':'LEU', 'MPQ':'GLY', 'MSA':'GLY', 'MSE':'MET', 'MVA':'VAL', 'NEM':'HIS', 'NEP':'HIS', 'NLE':'LEU', + 'NLN':'LEU', 'NLP':'LEU', 'NMC':'GLY', 'OAS':'SER', 'OCS':'CYS', 'OMT':'MET', 'PAQ':'TYR', 'PCA':'GLU', 'PEC':'CYS', 'PHI':'PHE', + 'PHL':'PHE', 'PR3':'CYS', 'PRR':'ALA', 'PTR':'TYR', 'PYX':'CYS', 'SAC':'SER', 'SAR':'GLY', 'SCH':'CYS', 'SCS':'CYS', 'SCY':'CYS', + 'SEL':'SER', 'SEP':'SER', 'SET':'SER', 'SHC':'CYS', 'SHR':'LYS', 'SMC':'CYS', 'SOC':'CYS', 'STY':'TYR', 'SVA':'SER', 'TIH':'ALA', + 'TPL':'TRP', 'TPO':'THR', 'TPQ':'ALA', 'TRG':'LYS', 'TRO':'TRP', 'TYB':'TYR', 'TYI':'TYR', 'TYQ':'TYR', 'TYS':'TYR', 'TYY':'TYR' +} + + +ressymb_to_resindex = { + 'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, + 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, + 'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14, + 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19, + 'X': 20, +} + + +class AA(enum.IntEnum): + ALA = 0; CYS = 1; ASP = 2; GLU = 3; PHE = 4 + GLY = 5; HIS = 6; ILE = 7; LYS = 8; LEU = 9 + MET = 10; ASN = 11; PRO = 12; GLN = 13; ARG = 14 + SER = 15; 
THR = 16; VAL = 17; TRP = 18; TYR = 19 + UNK = 20 + + @classmethod + def _missing_(cls, value): + if isinstance(value, str) and len(value) == 3: # three representation + if value in non_standard_residue_substitutions: + value = non_standard_residue_substitutions[value] + if value in cls._member_names_: + return getattr(cls, value) + elif isinstance(value, str) and len(value) == 1: # one representation + if value in ressymb_to_resindex: + return cls(ressymb_to_resindex[value]) + + return super()._missing_(value) + + def __str__(self): + return self.name + + @classmethod + def is_aa(cls, value): + return (value in ressymb_to_resindex) or \ + (value in non_standard_residue_substitutions) or \ + (value in cls._member_names_) or \ + (value in cls._member_map_.values()) + + +num_aa_types = len(AA) + +## +# Atom identities + +class BBHeavyAtom(enum.IntEnum): + N = 0; CA = 1; C = 2; O = 3; CB = 4; OXT=14; + +max_num_heavyatoms = 15 + +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +restype_to_heavyatom_names = { + AA.ALA: ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', '', 'OXT'], + AA.ARG: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', '', 'OXT'], + AA.ASN: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', '', 'OXT'], + AA.ASP: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', '', 'OXT'], + AA.CYS: ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', '', 'OXT'], + AA.GLN: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', '', 'OXT'], + AA.GLU: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', '', 'OXT'], + AA.GLY: ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', '', 'OXT'], + AA.HIS: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', '', 'OXT'], + AA.ILE: ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', '', 'OXT'], + AA.LEU: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', '', 'OXT'], + AA.LYS: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', '', 'OXT'], + AA.MET: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', '', 'OXT'], + AA.PHE: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', '', 'OXT'], + AA.PRO: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', '', 'OXT'], + AA.SER: ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', '', 'OXT'], + AA.THR: ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', '', 'OXT'], + AA.TRP: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2', 'OXT'], + AA.TYR: ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', '', 'OXT'], + AA.VAL: ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', '', 'OXT'], + AA.UNK: ['', '', '', '', '', '', '', '', '', '', '', '', '', '', ''], +} +for names in restype_to_heavyatom_names.values(): assert len(names) == max_num_heavyatoms + 
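+# Idealized backbone (N, CA, C) coordinates below are expressed in a residue-local frame
+# with CA at the origin; make_coordinate_tensors() at the bottom of this file packs them
+# (plus the backbone oxygen) into the module-level tensor buffers defined below.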
+ +backbone_atom_coordinates = { + AA.ALA: [ + (-0.525, 1.363, 0.0), # N + (0.0, 0.0, 0.0), # CA + (1.526, -0.0, -0.0), # C + ], + AA.ARG: [ + (-0.524, 1.362, -0.0), # N + (0.0, 0.0, 0.0), # CA + (1.525, -0.0, -0.0), # C + ], + AA.ASN: [ + (-0.536, 1.357, 0.0), # N + (0.0, 0.0, 0.0), # CA + (1.526, -0.0, -0.0), # C + ], + AA.ASP: [ + (-0.525, 1.362, -0.0), # N + (0.0, 0.0, 0.0), # CA + (1.527, 0.0, -0.0), # C + ], + AA.CYS: [ + (-0.522, 1.362, -0.0), # N + (0.0, 0.0, 0.0), # CA + (1.524, 0.0, 0.0), # C + ], + AA.GLN: [ + (-0.526, 1.361, -0.0), # N + (0.0, 0.0, 0.0), # CA + (1.526, 0.0, 0.0), # C + ], + AA.GLU: [ + (-0.528, 1.361, 0.0), # N + (0.0, 0.0, 0.0), # CA + (1.526, -0.0, -0.0), # C + ], + AA.GLY: [ + (-0.572, 1.337, 0.0), # N + (0.0, 0.0, 0.0), # CA + (1.517, -0.0, -0.0), # C + ], + AA.HIS: [ + (-0.527, 1.36, 0.0), # N + (0.0, 0.0, 0.0), # CA + (1.525, 0.0, 0.0), # C + ], + AA.ILE: [ + (-0.493, 1.373, -0.0), # N + (0.0, 0.0, 0.0), # CA + (1.527, -0.0, -0.0), # C + ], + AA.LEU: [ + (-0.52, 1.363, 0.0), # N + (0.0, 0.0, 0.0), # CA + (1.525, -0.0, -0.0), # C + ], + AA.LYS: [ + (-0.526, 1.362, -0.0), # N + (0.0, 0.0, 0.0), # CA + (1.526, 0.0, 0.0), # C + ], + AA.MET: [ + (-0.521, 1.364, -0.0), # N + (0.0, 0.0, 0.0), # CA + (1.525, 0.0, 0.0), # C + ], + AA.PHE: [ + (-0.518, 1.363, 0.0), # N + (0.0, 0.0, 0.0), # CA + (1.524, 0.0, -0.0), # C + ], + AA.PRO: [ + (-0.566, 1.351, -0.0), # N + (0.0, 0.0, 0.0), # CA + (1.527, -0.0, 0.0), # C + ], + AA.SER: [ + (-0.529, 1.36, -0.0), # N + (0.0, 0.0, 0.0), # CA + (1.525, -0.0, -0.0), # C + ], + AA.THR: [ + (-0.517, 1.364, 0.0), # N + (0.0, 0.0, 0.0), # CA + (1.526, 0.0, -0.0), # C + ], + AA.TRP: [ + (-0.521, 1.363, 0.0), # N + (0.0, 0.0, 0.0), # CA + (1.525, -0.0, 0.0), # C + ], + AA.TYR: [ + (-0.522, 1.362, 0.0), # N + (0.0, 0.0, 0.0), # CA + (1.524, -0.0, -0.0), # C + ], + AA.VAL: [ + (-0.494, 1.373, -0.0), # N + (0.0, 0.0, 0.0), # CA + (1.527, -0.0, -0.0), # C + ], +} + +bb_oxygen_coordinate = { + AA.ALA: (2.153, -1.062, 0.0), + AA.ARG: (2.151, -1.062, 0.0), + AA.ASN: (2.151, -1.062, 0.0), + AA.ASP: (2.153, -1.062, 0.0), + AA.CYS: (2.149, -1.062, 0.0), + AA.GLN: (2.152, -1.062, 0.0), + AA.GLU: (2.152, -1.062, 0.0), + AA.GLY: (2.143, -1.062, 0.0), + AA.HIS: (2.15, -1.063, 0.0), + AA.ILE: (2.154, -1.062, 0.0), + AA.LEU: (2.15, -1.063, 0.0), + AA.LYS: (2.152, -1.062, 0.0), + AA.MET: (2.15, -1.062, 0.0), + AA.PHE: (2.15, -1.062, 0.0), + AA.PRO: (2.148, -1.066, 0.0), + AA.SER: (2.151, -1.062, 0.0), + AA.THR: (2.152, -1.062, 0.0), + AA.TRP: (2.152, -1.062, 0.0), + AA.TYR: (2.151, -1.062, 0.0), + AA.VAL: (2.154, -1.062, 0.0), +} + +backbone_atom_coordinates_tensor = torch.zeros([21, 3, 3]) +bb_oxygen_coordinate_tensor = torch.zeros([21, 3]) + +def make_coordinate_tensors(): + for restype, atom_coords in backbone_atom_coordinates.items(): + for atom_id, atom_coord in enumerate(atom_coords): + backbone_atom_coordinates_tensor[restype][atom_id] = torch.FloatTensor(atom_coord) + + for restype, bb_oxy_coord in bb_oxygen_coordinate.items(): + bb_oxygen_coordinate_tensor[restype] = torch.FloatTensor(bb_oxy_coord) +make_coordinate_tensors() diff --git a/diffab/utils/protein/parsers.py b/diffab/utils/protein/parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..50ea42d7d58da13fe969625f9fc9a2a1037441e0 --- /dev/null +++ b/diffab/utils/protein/parsers.py @@ -0,0 +1,109 @@ +import torch +from Bio.PDB import Selection +from Bio.PDB.Residue import Residue +from easydict import EasyDict + +from .constants import ( + AA, max_num_heavyatoms, 
+ restype_to_heavyatom_names, + BBHeavyAtom +) + + +class ParsingException(Exception): + pass + + +def _get_residue_heavyatom_info(res: Residue): + pos_heavyatom = torch.zeros([max_num_heavyatoms, 3], dtype=torch.float) + mask_heavyatom = torch.zeros([max_num_heavyatoms, ], dtype=torch.bool) + restype = AA(res.get_resname()) + for idx, atom_name in enumerate(restype_to_heavyatom_names[restype]): + if atom_name == '': continue + if atom_name in res: + pos_heavyatom[idx] = torch.tensor(res[atom_name].get_coord().tolist(), dtype=pos_heavyatom.dtype) + mask_heavyatom[idx] = True + return pos_heavyatom, mask_heavyatom + + +def parse_biopython_structure(entity, unknown_threshold=1.0, max_resseq=None): + chains = Selection.unfold_entities(entity, 'C') + chains.sort(key=lambda c: c.get_id()) + data = EasyDict({ + 'chain_id': [], + 'resseq': [], 'icode': [], 'res_nb': [], + 'aa': [], + 'pos_heavyatom': [], 'mask_heavyatom': [], + }) + tensor_types = { + 'resseq': torch.LongTensor, + 'res_nb': torch.LongTensor, + 'aa': torch.LongTensor, + 'pos_heavyatom': torch.stack, + 'mask_heavyatom': torch.stack, + } + + count_aa, count_unk = 0, 0 + + for i, chain in enumerate(chains): + seq_this = 0 # Renumbering residues + residues = Selection.unfold_entities(chain, 'R') + residues.sort(key=lambda res: (res.get_id()[1], res.get_id()[2])) # Sort residues by resseq-icode + for _, res in enumerate(residues): + resseq_this = int(res.get_id()[1]) + if max_resseq is not None and resseq_this > max_resseq: + continue + + resname = res.get_resname() + if not AA.is_aa(resname): continue + if not (res.has_id('CA') and res.has_id('C') and res.has_id('N')): continue + restype = AA(resname) + count_aa += 1 + if restype == AA.UNK: + count_unk += 1 + continue + + # Chain info + data.chain_id.append(chain.get_id()) + + # Residue types + data.aa.append(restype) # Will be automatically cast to torch.long + + # Heavy atoms + pos_heavyatom, mask_heavyatom = _get_residue_heavyatom_info(res) + data.pos_heavyatom.append(pos_heavyatom) + data.mask_heavyatom.append(mask_heavyatom) + + # Sequential number + resseq_this = int(res.get_id()[1]) + icode_this = res.get_id()[2] + if seq_this == 0: + seq_this = 1 + else: + d_CA_CA = torch.linalg.norm(data.pos_heavyatom[-2][BBHeavyAtom.CA] - data.pos_heavyatom[-1][BBHeavyAtom.CA], ord=2).item() + if d_CA_CA <= 4.0: + seq_this += 1 + else: + d_resseq = resseq_this - data.resseq[-1] + seq_this += max(2, d_resseq) + + data.resseq.append(resseq_this) + data.icode.append(icode_this) + data.res_nb.append(seq_this) + + if len(data.aa) == 0: + raise ParsingException('No parsed residues.') + + if (count_unk / count_aa) >= unknown_threshold: + raise ParsingException( + f'Too many unknown residues, threshold {unknown_threshold:.2f}.' 
+ ) + + seq_map = {} + for i, (chain_id, resseq, icode) in enumerate(zip(data.chain_id, data.resseq, data.icode)): + seq_map[(chain_id, resseq, icode)] = i + + for key, convert_fn in tensor_types.items(): + data[key] = convert_fn(data[key]) + + return data, seq_map diff --git a/diffab/utils/protein/writers.py b/diffab/utils/protein/writers.py new file mode 100644 index 0000000000000000000000000000000000000000..2889e8e7ebe938f2a054a6d1a84b7f50318d8430 --- /dev/null +++ b/diffab/utils/protein/writers.py @@ -0,0 +1,75 @@ +import torch +import warnings +from Bio import BiopythonWarning +from Bio.PDB import PDBIO +from Bio.PDB.StructureBuilder import StructureBuilder + +from .constants import AA, restype_to_heavyatom_names + + +def save_pdb(data, path=None): + """ + Args: + data: A dict that contains: `chain_nb`, `chain_id`, `aa`, `resseq`, `icode`, + `pos_heavyatom`, `mask_heavyatom`. + """ + + def _mask_select(v, mask): + if isinstance(v, str): + return ''.join([s for i, s in enumerate(v) if mask[i]]) + elif isinstance(v, list): + return [s for i, s in enumerate(v) if mask[i]] + elif isinstance(v, torch.Tensor): + return v[mask] + else: + return v + + def _build_chain(builder, aa_ch, pos_heavyatom_ch, mask_heavyatom_ch, chain_id_ch, resseq_ch, icode_ch): + builder.init_chain(chain_id_ch[0]) + builder.init_seg(' ') + + for aa_res, pos_allatom_res, mask_allatom_res, resseq_res, icode_res in \ + zip(aa_ch, pos_heavyatom_ch, mask_heavyatom_ch, resseq_ch, icode_ch): + if not AA.is_aa(aa_res.item()): + print('[Warning] Unknown amino acid type at %d%s: %r' % (resseq_res.item(), icode_res, aa_res.item())) + continue + restype = AA(aa_res.item()) + builder.init_residue( + resname = str(restype), + field = ' ', + resseq = resseq_res.item(), + icode = icode_res, + ) + + for i, atom_name in enumerate(restype_to_heavyatom_names[restype]): + if atom_name == '': continue # No expected atom + if (~mask_allatom_res[i]).any(): continue # Atom is missing + if len(atom_name) == 1: fullname = ' %s ' % atom_name + elif len(atom_name) == 2: fullname = ' %s ' % atom_name + elif len(atom_name) == 3: fullname = ' %s' % atom_name + else: fullname = atom_name # len == 4 + builder.init_atom(atom_name, pos_allatom_res[i].tolist(), 0.0, 1.0, ' ', fullname,) + + warnings.simplefilter('ignore', BiopythonWarning) + builder = StructureBuilder() + builder.init_structure(0) + builder.init_model(0) + + unique_chain_nb = data['chain_nb'].unique().tolist() + for ch_nb in unique_chain_nb: + mask = (data['chain_nb'] == ch_nb) + aa = _mask_select(data['aa'], mask) + pos_heavyatom = _mask_select(data['pos_heavyatom'], mask) + mask_heavyatom = _mask_select(data['mask_heavyatom'], mask) + chain_id = _mask_select(data['chain_id'], mask) + resseq = _mask_select(data['resseq'], mask) + icode = _mask_select(data['icode'], mask) + + _build_chain(builder, aa, pos_heavyatom, mask_heavyatom, chain_id, resseq, icode) + + structure = builder.get_structure() + if path is not None: + io = PDBIO() + io.set_structure(structure) + io.save(path) + return structure diff --git a/diffab/utils/train.py b/diffab/utils/train.py new file mode 100644 index 0000000000000000000000000000000000000000..0b1af071f27913a63e6569d9b6d92d89249b7d4b --- /dev/null +++ b/diffab/utils/train.py @@ -0,0 +1,151 @@ +import numpy as np +import torch +from easydict import EasyDict + +from .misc import BlackHole + + +def get_optimizer(cfg, model): + if cfg.type == 'adam': + return torch.optim.Adam( + model.parameters(), + lr=cfg.lr, + weight_decay=cfg.weight_decay, + 
betas=(cfg.beta1, cfg.beta2, ) + ) + else: + raise NotImplementedError('Optimizer not supported: %s' % cfg.type) + + +def get_scheduler(cfg, optimizer): + if cfg.type is None: + return BlackHole() + elif cfg.type == 'plateau': + return torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + factor=cfg.factor, + patience=cfg.patience, + min_lr=cfg.min_lr, + ) + elif cfg.type == 'multistep': + return torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=cfg.milestones, + gamma=cfg.gamma, + ) + elif cfg.type == 'exp': + return torch.optim.lr_scheduler.ExponentialLR( + optimizer, + gamma=cfg.gamma, + ) + elif cfg.type is None: + return BlackHole() + else: + raise NotImplementedError('Scheduler not supported: %s' % cfg.type) + + +def get_warmup_sched(cfg, optimizer): + if cfg is None: return BlackHole() + lambdas = [lambda it : (it / cfg.max_iters) if it <= cfg.max_iters else 1 for _ in optimizer.param_groups] + warmup_sched = torch.optim.lr_scheduler.LambdaLR(optimizer, lambdas) + return warmup_sched + + +def log_losses(out, it, tag, logger=BlackHole(), writer=BlackHole(), others={}): + logstr = '[%s] Iter %05d' % (tag, it) + logstr += ' | loss %.4f' % out['overall'].item() + for k, v in out.items(): + if k == 'overall': continue + logstr += ' | loss(%s) %.4f' % (k, v.item()) + for k, v in others.items(): + logstr += ' | %s %2.4f' % (k, v) + logger.info(logstr) + + for k, v in out.items(): + if k == 'overall': + writer.add_scalar('%s/loss' % tag, v, it) + else: + writer.add_scalar('%s/loss_%s' % (tag, k), v, it) + for k, v in others.items(): + writer.add_scalar('%s/%s' % (tag, k), v, it) + writer.flush() + + +class ValidationLossTape(object): + + def __init__(self): + super().__init__() + self.accumulate = {} + self.others = {} + self.total = 0 + + def update(self, out, n, others={}): + self.total += n + for k, v in out.items(): + if k not in self.accumulate: + self.accumulate[k] = v.clone().detach() + else: + self.accumulate[k] += v.clone().detach() + + for k, v in others.items(): + if k not in self.others: + self.others[k] = v.clone().detach() + else: + self.others[k] += v.clone().detach() + + + def log(self, it, logger=BlackHole(), writer=BlackHole(), tag='val'): + avg = EasyDict({k:v / self.total for k, v in self.accumulate.items()}) + avg_others = EasyDict({k:v / self.total for k, v in self.others.items()}) + log_losses(avg, it, tag, logger, writer, others=avg_others) + return avg['overall'] + + +def recursive_to(obj, device): + if isinstance(obj, torch.Tensor): + if device == 'cpu': + return obj.cpu() + try: + return obj.cuda(device=device, non_blocking=True) + except RuntimeError: + return obj.to(device) + elif isinstance(obj, list): + return [recursive_to(o, device=device) for o in obj] + elif isinstance(obj, tuple): + return tuple(recursive_to(o, device=device) for o in obj) + elif isinstance(obj, dict): + return {k: recursive_to(v, device=device) for k, v in obj.items()} + + else: + return obj + + +def reweight_loss_by_sequence_length(length, max_length, mode='sqrt'): + if mode == 'sqrt': + w = np.sqrt(length / max_length) + elif mode == 'linear': + w = length / max_length + elif mode is None: + w = 1.0 + else: + raise ValueError('Unknown reweighting mode: %s' % mode) + return w + + +def sum_weighted_losses(losses, weights): + """ + Args: + losses: Dict of scalar tensors. + weights: Dict of weights. 
+ """ + loss = 0 + for k in losses.keys(): + if weights is None: + loss = loss + losses[k] + else: + loss = loss + weights[k] * losses[k] + return loss + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters()) diff --git a/diffab/utils/transforms/__init__.py b/diffab/utils/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0c4cd2f33b86e4b0ad55bdb5c5a1f8ed392d9f6c --- /dev/null +++ b/diffab/utils/transforms/__init__.py @@ -0,0 +1,7 @@ +# Transforms +from .mask import MaskSingleCDR, MaskMultipleCDRs, MaskAntibody +from .merge import MergeChains +from .patch import PatchAroundAnchor + +# Factory +from ._base import get_transform, Compose diff --git a/diffab/utils/transforms/_base.py b/diffab/utils/transforms/_base.py new file mode 100644 index 0000000000000000000000000000000000000000..0694aae80271b0a8e990e209daa1822393988f1b --- /dev/null +++ b/diffab/utils/transforms/_base.py @@ -0,0 +1,56 @@ +import copy +import torch +from torchvision.transforms import Compose + + +_TRANSFORM_DICT = {} + + +def register_transform(name): + def decorator(cls): + _TRANSFORM_DICT[name] = cls + return cls + return decorator + + +def get_transform(cfg): + if cfg is None or len(cfg) == 0: + return None + tfms = [] + for t_dict in cfg: + t_dict = copy.deepcopy(t_dict) + cls = _TRANSFORM_DICT[t_dict.pop('type')] + tfms.append(cls(**t_dict)) + return Compose(tfms) + + +def _index_select(v, index, n): + if isinstance(v, torch.Tensor) and v.size(0) == n: + return v[index] + elif isinstance(v, list) and len(v) == n: + return [v[i] for i in index] + else: + return v + + +def _index_select_data(data, index): + return { + k: _index_select(v, index, data['aa'].size(0)) + for k, v in data.items() + } + + +def _mask_select(v, mask): + if isinstance(v, torch.Tensor) and v.size(0) == mask.size(0): + return v[mask] + elif isinstance(v, list) and len(v) == mask.size(0): + return [v[i] for i, b in enumerate(mask) if b] + else: + return v + + +def _mask_select_data(data, mask): + return { + k: _mask_select(v, mask) + for k, v in data.items() + } diff --git a/diffab/utils/transforms/mask.py b/diffab/utils/transforms/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..5b1b49a9692a078c8e58044cf100a2c5341288ab --- /dev/null +++ b/diffab/utils/transforms/mask.py @@ -0,0 +1,213 @@ +import torch +import random +from typing import List, Optional + +from ..protein import constants +from ._base import register_transform + + +def random_shrink_extend(flag, min_length=5, shrink_limit=1, extend_limit=2): + first, last = continuous_flag_to_range(flag) + length = flag.sum().item() + if (length - 2*shrink_limit) < min_length: + shrink_limit = 0 + first_ext = max(0, first-random.randint(-shrink_limit, extend_limit)) + last_ext = min(last+random.randint(-shrink_limit, extend_limit), flag.size(0)-1) + flag_ext = flag.clone() + flag_ext[first_ext : last_ext+1] = True + return flag_ext + + +def continuous_flag_to_range(flag): + first = (torch.arange(0, flag.size(0))[flag]).min().item() + last = (torch.arange(0, flag.size(0))[flag]).max().item() + return first, last + + +@register_transform('mask_single_cdr') +class MaskSingleCDR(object): + + def __init__(self, selection=None, augmentation=True): + super().__init__() + cdr_str_to_enum = { + 'H1': constants.CDR.H1, + 'H2': constants.CDR.H2, + 'H3': constants.CDR.H3, + 'L1': constants.CDR.L1, + 'L2': constants.CDR.L2, + 'L3': constants.CDR.L3, + 'H_CDR1': constants.CDR.H1, + 'H_CDR2': constants.CDR.H2, + 
'H_CDR3': constants.CDR.H3, + 'L_CDR1': constants.CDR.L1, + 'L_CDR2': constants.CDR.L2, + 'L_CDR3': constants.CDR.L3, + 'CDR3': 'CDR3', # H3 first, then fallback to L3 + } + assert selection is None or selection in cdr_str_to_enum + self.selection = cdr_str_to_enum.get(selection, None) + self.augmentation = augmentation + + def perform_masking_(self, data, selection=None): + cdr_flag = data['cdr_flag'] + + if selection is None: + cdr_all = cdr_flag[cdr_flag > 0].unique().tolist() + cdr_to_mask = random.choice(cdr_all) + else: + cdr_to_mask = selection + + cdr_to_mask_flag = (cdr_flag == cdr_to_mask) + if self.augmentation: + cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag) + + cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag) + left_idx = max(0, cdr_first-1) + right_idx = min(data['aa'].size(0)-1, cdr_last+1) + anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool) + anchor_flag[left_idx] = True + anchor_flag[right_idx] = True + + data['generate_flag'] = cdr_to_mask_flag + data['anchor_flag'] = anchor_flag + + def __call__(self, structure): + if self.selection is None: + ab_data = [] + if structure['heavy'] is not None: + ab_data.append(structure['heavy']) + if structure['light'] is not None: + ab_data.append(structure['light']) + data_to_mask = random.choice(ab_data) + sel = None + elif self.selection in (constants.CDR.H1, constants.CDR.H2, constants.CDR.H3, ): + data_to_mask = structure['heavy'] + sel = int(self.selection) + elif self.selection in (constants.CDR.L1, constants.CDR.L2, constants.CDR.L3, ): + data_to_mask = structure['light'] + sel = int(self.selection) + elif self.selection == 'CDR3': + if structure['heavy'] is not None: + data_to_mask = structure['heavy'] + sel = constants.CDR.H3 + else: + data_to_mask = structure['light'] + sel = constants.CDR.L3 + + self.perform_masking_(data_to_mask, selection=sel) + return structure + + +@register_transform('mask_multiple_cdrs') +class MaskMultipleCDRs(object): + + def __init__(self, selection: Optional[List[str]]=None, augmentation=True): + super().__init__() + cdr_str_to_enum = { + 'H1': constants.CDR.H1, + 'H2': constants.CDR.H2, + 'H3': constants.CDR.H3, + 'L1': constants.CDR.L1, + 'L2': constants.CDR.L2, + 'L3': constants.CDR.L3, + 'H_CDR1': constants.CDR.H1, + 'H_CDR2': constants.CDR.H2, + 'H_CDR3': constants.CDR.H3, + 'L_CDR1': constants.CDR.L1, + 'L_CDR2': constants.CDR.L2, + 'L_CDR3': constants.CDR.L3, + } + if selection is not None: + self.selection = [cdr_str_to_enum[s] for s in selection] + else: + self.selection = None + self.augmentation = augmentation + + def mask_one_cdr_(self, data, cdr_to_mask): + cdr_flag = data['cdr_flag'] + + cdr_to_mask_flag = (cdr_flag == cdr_to_mask) + if self.augmentation: + cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag) + + cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag) + left_idx = max(0, cdr_first-1) + right_idx = min(data['aa'].size(0)-1, cdr_last+1) + anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool) + anchor_flag[left_idx] = True + anchor_flag[right_idx] = True + + if 'generate_flag' not in data: + data['generate_flag'] = cdr_to_mask_flag + data['anchor_flag'] = anchor_flag + else: + data['generate_flag'] |= cdr_to_mask_flag + data['anchor_flag'] |= anchor_flag + + def mask_for_one_chain_(self, data): + cdr_flag = data['cdr_flag'] + cdr_all = cdr_flag[cdr_flag > 0].unique().tolist() + + num_cdrs_to_mask = random.randint(1, len(cdr_all)) + + if self.selection is not None: + cdrs_to_mask = 
list(set(cdr_all).intersection(self.selection)) + else: + random.shuffle(cdr_all) + cdrs_to_mask = cdr_all[:num_cdrs_to_mask] + + for cdr_to_mask in cdrs_to_mask: + self.mask_one_cdr_(data, cdr_to_mask) + + def __call__(self, structure): + if structure['heavy'] is not None: + self.mask_for_one_chain_(structure['heavy']) + if structure['light'] is not None: + self.mask_for_one_chain_(structure['light']) + return structure + + +@register_transform('mask_antibody') +class MaskAntibody(object): + + def mask_ab_chain_(self, data): + data['generate_flag'] = torch.ones(data['aa'].shape, dtype=torch.bool) + + def __call__(self, structure): + pos_ab_alpha = [] + if structure['heavy'] is not None: + self.mask_ab_chain_(structure['heavy']) + pos_ab_alpha.append( + structure['heavy']['pos_heavyatom'][:, constants.BBHeavyAtom.CA] + ) + if structure['light'] is not None: + self.mask_ab_chain_(structure['light']) + pos_ab_alpha.append( + structure['light']['pos_heavyatom'][:, constants.BBHeavyAtom.CA] + ) + pos_ab_alpha = torch.cat(pos_ab_alpha, dim=0) # (L_Ab, 3) + + if structure['antigen'] is not None: + pos_ag_alpha = structure['antigen']['pos_heavyatom'][:, constants.BBHeavyAtom.CA] + ag_ab_dist = torch.cdist(pos_ag_alpha, pos_ab_alpha) # (L_Ag, L_Ab) + nn_ab_dist = ag_ab_dist.min(dim=1)[0] # (L_Ag) + contact_flag = (nn_ab_dist <= 6.0) # (L_Ag) + if contact_flag.sum().item() == 0: + contact_flag[nn_ab_dist.argmin()] = True + + anchor_idx = torch.multinomial(contact_flag.float(), num_samples=1).item() + anchor_flag = torch.zeros(structure['antigen']['aa'].shape, dtype=torch.bool) + anchor_flag[anchor_idx] = True + structure['antigen']['anchor_flag'] = anchor_flag + structure['antigen']['contact_flag'] = contact_flag + + return structure + + +@register_transform('remove_antigen') +class RemoveAntigen: + + def __call__(self, structure): + structure['antigen'] = None + structure['antigen_seqmap'] = None + return structure diff --git a/diffab/utils/transforms/merge.py b/diffab/utils/transforms/merge.py new file mode 100644 index 0000000000000000000000000000000000000000..7d0bdbd4e1984e47b13e4fa75276a6d3dcb0f98b --- /dev/null +++ b/diffab/utils/transforms/merge.py @@ -0,0 +1,88 @@ +import torch + +from ..protein import constants +from ._base import register_transform + + +@register_transform('merge_chains') +class MergeChains(object): + + def __init__(self): + super().__init__() + + def assign_chain_number_(self, data_list): + chains = set() + for data in data_list: + chains.update(data['chain_id']) + chains = {c: i for i, c in enumerate(chains)} + + for data in data_list: + data['chain_nb'] = torch.LongTensor([ + chains[c] for c in data['chain_id'] + ]) + + def _data_attr(self, data, name): + if name in ('generate_flag', 'anchor_flag') and name not in data: + return torch.zeros(data['aa'].shape, dtype=torch.bool) + else: + return data[name] + + def __call__(self, structure): + data_list = [] + if structure['heavy'] is not None: + structure['heavy']['fragment_type'] = torch.full_like( + structure['heavy']['aa'], + fill_value = constants.Fragment.Heavy, + ) + data_list.append(structure['heavy']) + + if structure['light'] is not None: + structure['light']['fragment_type'] = torch.full_like( + structure['light']['aa'], + fill_value = constants.Fragment.Light, + ) + data_list.append(structure['light']) + + if structure['antigen'] is not None: + structure['antigen']['fragment_type'] = torch.full_like( + structure['antigen']['aa'], + fill_value = constants.Fragment.Antigen, + ) + structure['antigen']['cdr_flag'] 
= torch.zeros_like( + structure['antigen']['aa'], + ) + data_list.append(structure['antigen']) + + self.assign_chain_number_(data_list) + + list_props = { + 'chain_id': [], + 'icode': [], + } + tensor_props = { + 'chain_nb': [], + 'resseq': [], + 'res_nb': [], + 'aa': [], + 'pos_heavyatom': [], + 'mask_heavyatom': [], + 'generate_flag': [], + 'cdr_flag': [], + 'anchor_flag': [], + 'fragment_type': [], + } + + for data in data_list: + for k in list_props.keys(): + list_props[k].append(self._data_attr(data, k)) + for k in tensor_props.keys(): + tensor_props[k].append(self._data_attr(data, k)) + + list_props = {k: sum(v, start=[]) for k, v in list_props.items()} + tensor_props = {k: torch.cat(v, dim=0) for k, v in tensor_props.items()} + data_out = { + **list_props, + **tensor_props, + } + return data_out + diff --git a/diffab/utils/transforms/patch.py b/diffab/utils/transforms/patch.py new file mode 100644 index 0000000000000000000000000000000000000000..abe678eb6fa3f64a0637ab8dc87e3ef6102347b8 --- /dev/null +++ b/diffab/utils/transforms/patch.py @@ -0,0 +1,73 @@ +import torch + +from ._base import _mask_select_data, register_transform +from ..protein import constants + + +@register_transform('patch_around_anchor') +class PatchAroundAnchor(object): + + def __init__(self, initial_patch_size=128, antigen_size=128): + super().__init__() + self.initial_patch_size = initial_patch_size + self.antigen_size = antigen_size + + def _center(self, data, origin): + origin = origin.reshape(1, 1, 3) + data['pos_heavyatom'] -= origin # (L, A, 3) + data['pos_heavyatom'] = data['pos_heavyatom'] * data['mask_heavyatom'][:, :, None] + data['origin'] = origin.reshape(3) + return data + + def __call__(self, data): + anchor_flag = data['anchor_flag'] # (L,) + anchor_points = data['pos_heavyatom'][anchor_flag, constants.BBHeavyAtom.CA] # (n_anchors, 3) + antigen_mask = (data['fragment_type'] == constants.Fragment.Antigen) + antibody_mask = torch.logical_not(antigen_mask) + + if anchor_flag.sum().item() == 0: + # Generating full antibody-Fv, no antigen given + data_patch = _mask_select_data( + data = data, + mask = antibody_mask, + ) + data_patch = self._center( + data_patch, + origin = data_patch['pos_heavyatom'][:, constants.BBHeavyAtom.CA].mean(dim=0) + ) + return data_patch + + pos_alpha = data['pos_heavyatom'][:, constants.BBHeavyAtom.CA] # (L, 3) + dist_anchor = torch.cdist(pos_alpha, anchor_points).min(dim=1)[0] # (L, ) + initial_patch_idx = torch.topk( + dist_anchor, + k = min(self.initial_patch_size, dist_anchor.size(0)), + largest=False, + )[1] # (initial_patch_size, ) + + dist_anchor_antigen = dist_anchor.masked_fill( + mask = antibody_mask, # Fill antibody with +inf + value = float('+inf') + ) # (L, ) + antigen_patch_idx = torch.topk( + dist_anchor_antigen, + k = min(self.antigen_size, antigen_mask.sum().item()), + largest=False, sorted=True + )[1] # (ag_size, ) + + patch_mask = torch.logical_or( + data['generate_flag'], + data['anchor_flag'], + ) + patch_mask[initial_patch_idx] = True + patch_mask[antigen_patch_idx] = True + + patch_idx = torch.arange(0, patch_mask.shape[0])[patch_mask] + + data_patch = _mask_select_data(data, patch_mask) + data_patch = self._center( + data_patch, + origin = anchor_points.mean(dim=0) + ) + data_patch['patch_idx'] = patch_idx + return data_patch diff --git a/diffab/utils/transforms/select_atom.py b/diffab/utils/transforms/select_atom.py new file mode 100644 index 0000000000000000000000000000000000000000..7d067ecb50873dc3ec3c5626bc8d1a836258780a --- /dev/null +++ 
b/diffab/utils/transforms/select_atom.py @@ -0,0 +1,20 @@ + +from ._base import register_transform + + +@register_transform('select_atom') +class SelectAtom(object): + + def __init__(self, resolution): + super().__init__() + assert resolution in ('full', 'backbone') + self.resolution = resolution + + def __call__(self, data): + if self.resolution == 'full': + data['pos_atoms'] = data['pos_heavyatom'][:, :] + data['mask_atoms'] = data['mask_heavyatom'][:, :] + elif self.resolution == 'backbone': + data['pos_atoms'] = data['pos_heavyatom'][:, :5] + data['mask_atoms'] = data['mask_heavyatom'][:, :5] + return data diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..fa3cf492a788d2b9c935519c5f97a036f7881f5b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +--extra-index-url https://download.pytorch.org/whl/cu113 +torch +torchvision +biopython==1.79 +git+ssh://git@github.com/oxpig/ANARCI.git +git+ssh://git@github.com/prihoda/AbNumber.git +joblib +lmdb +tqdm +easydict +pyyaml +stmol
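
The transforms added above are registered by name and assembled with `get_transform`, so a preprocessing pipeline is just an ordered list of config entries: per-chain masking first, then `merge_chains` to flatten the heavy/light/antigen records, then `patch_around_anchor` to crop around the CDR anchors. The sketch below shows that composition end to end. It is a minimal illustration, not part of the diff: the helper name `preview_h3_patch`, the config values, and the output path are all made up here, and the `structure` argument is assumed to be the dict produced by the parsing/dataset code elsewhere in this PR.

```python
from diffab.utils.transforms import get_transform
from diffab.utils.protein.writers import save_pdb


def preview_h3_patch(structure, out_path='h3_patch.pdb'):
    """Mask H3, merge chains, crop a patch, and dump it as a PDB for inspection.

    `structure` is assumed to follow the format the mask transforms expect:
    'heavy' / 'light' / 'antigen' entries, each a dict of per-residue tensors
    ('aa', 'pos_heavyatom', 'mask_heavyatom', 'cdr_flag', 'resseq', 'res_nb', ...).
    """
    transform = get_transform([
        {'type': 'mask_single_cdr', 'selection': 'H_CDR3', 'augmentation': False},
        {'type': 'merge_chains'},
        {'type': 'patch_around_anchor', 'initial_patch_size': 128, 'antigen_size': 128},
    ])
    data = transform(structure)            # per-chain masking -> flat record -> spatial patch
    return save_pdb(data, path=out_path)   # the flat record carries everything save_pdb needs
```

Note that the ordering matters: the masking transforms operate on the per-chain `structure` dict, while `merge_chains` and `patch_around_anchor` expect and return the flat per-residue record, which is also the format `save_pdb` consumes.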