import torch
import torch.nn as nn

from diffab.modules.common.geometry import construct_3d_basis
from diffab.modules.common.so3 import rotation_to_so3vec
from diffab.modules.encoders.residue import ResidueEmbedding
from diffab.modules.encoders.pair import PairEmbedding
from diffab.modules.diffusion.dpm_full import FullDPM
from diffab.utils.protein.constants import max_num_heavyatoms, BBHeavyAtom
from ._base import register_model


# Number of atoms per residue handed to the embedding layers, keyed by the
# `resolution` config option.
resolution_to_num_atoms = {
    'backbone+CB': 5,
    'full': max_num_heavyatoms,
}


@register_model('diffab')
class DiffusionAntibodyDesign(nn.Module):
    """Model registered as ``'diffab'``.

    Embeds per-residue and pairwise features from a batch and delegates
    training loss computation, sampling and optimization to ``FullDPM``.
    """

    def __init__(self, cfg):
        """Build the embedding layers and the diffusion module from ``cfg``.

        Args:
            cfg: Config object supporting attribute access and ``.get``.
                Reads ``res_feat_dim``, ``pair_feat_dim``, ``diffusion``
                (expanded as keyword arguments to ``FullDPM``) and the
                optional ``resolution`` key (defaults to ``'full'``).
        """
        super().__init__()
        self.cfg = cfg

        num_atoms = resolution_to_num_atoms[cfg.get('resolution', 'full')]
        self.residue_embed = ResidueEmbedding(cfg.res_feat_dim, num_atoms)
        self.pair_embed = PairEmbedding(cfg.pair_feat_dim, num_atoms)

        self.diffusion = FullDPM(
            cfg.res_feat_dim,
            cfg.pair_feat_dim,
            **cfg.diffusion,
        )

    def encode(self, batch, remove_structure, remove_sequence):
        """Compute residue/pair embeddings and the backbone frames.

        Args:
            batch: Mapping with keys ``aa``, ``res_nb``, ``chain_nb``,
                ``pos_heavyatom``, ``mask_heavyatom``, ``fragment_type``
                and ``generate_flag``.
            remove_structure: If truthy, the context mask is passed to the
                embedders as ``structure_mask``; otherwise ``None`` is passed.
            remove_sequence: If truthy, the context mask is passed to the
                embedders as ``sequence_mask``; otherwise ``None`` is passed.

        Returns:
            res_feat:   (N, L, res_feat_dim)
            pair_feat:  (N, L, L, pair_feat_dim)
            R:  Result of ``construct_3d_basis`` on the CA, C and N positions.
            p:  CA positions, ``batch['pos_heavyatom'][:, :, BBHeavyAtom.CA]``.
        """
        # This is used throughout embedding and encoding layers
        # to avoid data leakage.
        context_mask = torch.logical_and(
            batch['mask_heavyatom'][:, :, BBHeavyAtom.CA],
            ~batch['generate_flag'],  # Context means "not generated"
        )

        structure_mask = context_mask if remove_structure else None
        sequence_mask = context_mask if remove_sequence else None

        res_feat = self.residue_embed(
            aa=batch['aa'],
            res_nb=batch['res_nb'],
            chain_nb=batch['chain_nb'],
            pos_atoms=batch['pos_heavyatom'],
            mask_atoms=batch['mask_heavyatom'],
            fragment_type=batch['fragment_type'],
            structure_mask=structure_mask,
            sequence_mask=sequence_mask,
        )

        pair_feat = self.pair_embed(
            aa=batch['aa'],
            res_nb=batch['res_nb'],
            chain_nb=batch['chain_nb'],
            pos_atoms=batch['pos_heavyatom'],
            mask_atoms=batch['mask_heavyatom'],
            structure_mask=structure_mask,
            sequence_mask=sequence_mask,
        )

        R = construct_3d_basis(
            batch['pos_heavyatom'][:, :, BBHeavyAtom.CA],
            batch['pos_heavyatom'][:, :, BBHeavyAtom.C],
            batch['pos_heavyatom'][:, :, BBHeavyAtom.N],
        )
        p = batch['pos_heavyatom'][:, :, BBHeavyAtom.CA]

        return res_feat, pair_feat, R, p

    def forward(self, batch):
        """Return whatever ``self.diffusion(...)`` produces for ``batch``.

        The ``train_structure`` / ``train_sequence`` config options (both
        default to ``True``) control which masks ``encode`` applies and are
        forwarded to the diffusion module as ``denoise_structure`` /
        ``denoise_sequence``.
        """
        mask_generate = batch['generate_flag']
        mask_res = batch['mask']

        train_structure = self.cfg.get('train_structure', True)
        train_sequence = self.cfg.get('train_sequence', True)

        res_feat, pair_feat, R_0, p_0 = self.encode(
            batch,
            remove_structure=train_structure,
            remove_sequence=train_sequence,
        )
        v_0 = rotation_to_so3vec(R_0)
        s_0 = batch['aa']

        loss_dict = self.diffusion(
            v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res,
            denoise_structure=train_structure,
            denoise_sequence=train_sequence,
        )
        return loss_dict

    @torch.no_grad()
    def sample(self, batch, sample_opt=None):
        """Run ``self.diffusion.sample`` on ``batch`` without gradients.

        Args:
            batch: Same mapping as accepted by ``encode``, plus ``mask``.
            sample_opt: Optional dict of options. ``sample_structure`` and
                ``sample_sequence`` (default ``True`` when absent) select the
                masks used by ``encode``; the whole dict is also expanded as
                keyword arguments to ``self.diffusion.sample``. When ``None``,
                both options are set to ``True``.

        Returns:
            The value returned by ``self.diffusion.sample``.
        """
        # Build the defaults per call: a dict literal in the signature would
        # be a single object shared by every invocation.
        if sample_opt is None:
            sample_opt = {
                'sample_structure': True,
                'sample_sequence': True,
            }

        mask_generate = batch['generate_flag']
        mask_res = batch['mask']
        res_feat, pair_feat, R_0, p_0 = self.encode(
            batch,
            remove_structure=sample_opt.get('sample_structure', True),
            remove_sequence=sample_opt.get('sample_sequence', True),
        )
        v_0 = rotation_to_so3vec(R_0)
        s_0 = batch['aa']
        traj = self.diffusion.sample(
            v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res,
            **sample_opt
        )
        return traj

    @torch.no_grad()
    def optimize(self, batch, opt_step, optimize_opt=None):
        """Run ``self.diffusion.optimize`` on ``batch`` without gradients.

        Args:
            batch: Same mapping as accepted by ``encode``, plus ``mask``.
            opt_step: Forwarded unchanged to ``self.diffusion.optimize``.
            optimize_opt: Optional dict of options. ``sample_structure`` and
                ``sample_sequence`` (default ``True`` when absent) select the
                masks used by ``encode``; the whole dict is also expanded as
                keyword arguments to ``self.diffusion.optimize``. When
                ``None``, both options are set to ``True``.

        Returns:
            The value returned by ``self.diffusion.optimize``.
        """
        # Build the defaults per call: a dict literal in the signature would
        # be a single object shared by every invocation.
        if optimize_opt is None:
            optimize_opt = {
                'sample_structure': True,
                'sample_sequence': True,
            }

        mask_generate = batch['generate_flag']
        mask_res = batch['mask']
        res_feat, pair_feat, R_0, p_0 = self.encode(
            batch,
            remove_structure=optimize_opt.get('sample_structure', True),
            remove_sequence=optimize_opt.get('sample_sequence', True),
        )
        v_0 = rotation_to_so3vec(R_0)
        s_0 = batch['aa']

        traj = self.diffusion.optimize(
            v_0, p_0, s_0, opt_step, res_feat, pair_feat,
            mask_generate, mask_res,
            **optimize_opt
        )
        return traj