import time
import json

import gradio as gr
from gradio_molecule3d import Molecule3D
import torch
from torch_geometric.data import HeteroData
import numpy as np
from loguru import logger
from Bio import PDB
from Bio.PDB.PDBIO import PDBIO
from pinder.core.loader.geodata import structure2tensor
from pinder.core.loader.structure import Structure

from src.models.pinder_module import PinderLitModule

try:
    from torch_cluster import knn_graph

    torch_cluster_installed = True
except ImportError:
    # Trailing spaces keep the implicitly-concatenated sentences apart.
    logger.warning(
        "torch-cluster is not installed! "
        "Please install the appropriate library for your pytorch installation. "
        "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
    )
    torch_cluster_installed = False

# Checkpoint loaded by `predict` on every request.
CHECKPOINT_PATH = "./checkpoints/epoch_010.ckpt"


def get_props_pdb(pdb_file):
    """Read a PDB file and convert it to the tensors produced by ``structure2tensor``.

    Atom-level inputs are taken from every atom in the structure; residue-level
    inputs are taken from the atoms whose name is ``"CA"``.

    Parameters:
    - pdb_file: Path to the PDB file, passed to ``Structure.read_pdb``.

    Returns:
    The value returned by ``structure2tensor`` (indexed by string keys such as
    ``"atom_types"`` and ``"atom_coordinates"`` in ``create_graph``).
    """
    structure = Structure.read_pdb(pdb_file)
    # One representative atom per residue: select the atoms named "CA".
    calpha_mask = np.isin(structure.atom_name, ["CA"])
    calpha = structure[calpha_mask].copy()
    return structure2tensor(
        atom_coordinates=structure.coord,
        atom_types=structure.atom_name,
        element_types=structure.element,
        residue_coordinates=calpha.coord,
        residue_types=calpha.res_name,
        residue_ids=calpha.res_id,
    )


def create_graph(pdb_1, pdb_2, k=5, device: torch.device = torch.device("cpu")):
    """Build a ``HeteroData`` graph with a "ligand" and a "receptor" node set.

    Each node set gets atom types as ``x``, atom coordinates as ``pos`` and a
    k-nearest-neighbour ``edge_index`` computed on ``pos``.

    Parameters:
    - pdb_1: Path to the PDB file used for the "ligand" nodes.
    - pdb_2: Path to the PDB file used for the "receptor" nodes.
    - k (int): Number of neighbours passed to ``knn_graph``.
    - device (torch.device): Device the finished graph is moved to.

    Returns:
    The populated ``HeteroData`` object on ``device``.

    Raises:
    ImportError: If torch-cluster (which provides ``knn_graph``) is missing.
    """
    if not torch_cluster_installed:
        # Without this guard the call below fails with an opaque NameError.
        raise ImportError(
            "torch-cluster is required to build the k-NN graph but is not installed."
        )

    data = HeteroData()
    for node_type, pdb_path in (("ligand", pdb_1), ("receptor", pdb_2)):
        props = get_props_pdb(pdb_path)
        data[node_type].x = props["atom_types"]
        data[node_type].pos = props["atom_coordinates"]
        data[node_type, node_type].edge_index = knn_graph(data[node_type].pos, k=k)

    return data.to(device)


def update_pdb_coordinates_from_tensor(
    input_filename, output_filename, coordinates_tensor
):
    """Write a copy of a PDB file with its atom coordinates replaced.

    Atoms are visited in model -> chain -> residue -> atom order and paired
    positionally with the rows of ``coordinates_tensor``. Only the coordinates
    are assigned; every other atom field is left as parsed.

    Parameters:
    - input_filename (str): Path to the original PDB file.
    - output_filename (str): Path to the new PDB file to save updated coordinates.
    - coordinates_tensor (torch.Tensor): Tensor of shape (1, N, 3) with transformed
      coordinates.

    Returns:
    ``output_filename``.
    """
    new_coordinates = coordinates_tensor.squeeze(0).tolist()

    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure("structure", input_filename)

    atoms = [
        atom
        for model in structure
        for chain in model
        for residue in chain
        for atom in residue
    ]

    # zip() below stops at the shorter sequence, so a mismatch would otherwise
    # leave some atoms at their old positions without any indication.
    if len(atoms) != len(new_coordinates):
        logger.warning(
            f"Atom count mismatch for {input_filename}: "
            f"{len(atoms)} atoms in file, {len(new_coordinates)} coordinates given; "
            "only the first min(...) atoms are updated."
        )

    for atom, (new_x, new_y, new_z) in zip(atoms, new_coordinates):
        atom.coord = np.array([new_x, new_y, new_z])

    io = PDBIO()
    io.set_structure(structure)
    io.save(output_filename)

    return output_filename


def merge_pdb_files(file1, file2, output_file):
    """Concatenate two PDB files into one.

    The terminating ``END`` record of the first file (if it has one) is dropped
    so that the records of the second file follow directly; all other lines of
    both files are written unchanged.

    Parameters:
    - file1 (str): Path to the first PDB file.
    - file2 (str): Path to the second PDB file.
    - output_file (str): Path to the output file where the merged structure
      will be saved.

    Returns:
    ``output_file``.
    """
    # Read both inputs before opening the output so an output path that equals
    # an input path is not truncated before it is read.
    with open(file1, "r") as f1:
        first_lines = f1.readlines()
    with open(file2, "r") as f2:
        second_content = f2.read()

    # Locate the last non-blank line; remove it (and any blank lines after it)
    # only when it is the END record. "ENDMDL" and ordinary records are kept.
    last = len(first_lines)
    while last > 0 and not first_lines[last - 1].strip():
        last -= 1
    if last > 0 and first_lines[last - 1].strip() == "END":
        first_lines = first_lines[: last - 1]

    # Make sure the second file starts on its own line.
    if first_lines and not first_lines[-1].endswith("\n"):
        first_lines[-1] += "\n"

    with open(output_file, "w") as outfile:
        outfile.writelines(first_lines)
        outfile.write(second_content)

    logger.info(f"Merged PDB saved to {output_file}")
    return output_file


def predict(
    input_seq_1, input_msa_1, input_protein_1, input_seq_2, input_msa_2, input_protein_2
):
    """Run the model on two input structures and return the merged prediction.

    ``input_protein_1`` is treated as the ligand and ``input_protein_2`` as the
    receptor. The sequence and MSA arguments are accepted for the UI contract
    but are not used by this function.

    Returns:
    A tuple ``(out_pdb, metrics_json, run_time)``: the path of the merged PDB
    file, a JSON string of metrics, and the elapsed wall-clock time in seconds.
    """
    start_time = time.time()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    data = create_graph(input_protein_1, input_protein_2, k=10, device=device)
    logger.info("Created graph data")

    model = PinderLitModule.load_from_checkpoint(CHECKPOINT_PATH)
    model = model.to(device)
    model.eval()
    logger.info("Loaded model")

    with torch.no_grad():
        receptor_coords, ligand_coords = model(data)

    # NOTE(review): output names are fixed paths in the working directory, so
    # overlapping requests would overwrite each other's files — confirm whether
    # concurrent use is expected.
    file1 = update_pdb_coordinates_from_tensor(
        input_protein_1, "holo_ligand.pdb", ligand_coords
    )
    file2 = update_pdb_coordinates_from_tensor(
        input_protein_2, "holo_receptor.pdb", receptor_coords
    )
    out_pdb = merge_pdb_files(file1, file2, "output.pdb")

    # return an output pdb file with the protein and two chains A and B.
    # also return a JSON with any metrics you want to report
    # NOTE: these values are hard-coded placeholders, not computed by the model.
    metrics = {"mean_plddt": 80, "binding_affinity": 2}

    run_time = time.time() - start_time
    return out_pdb, json.dumps(metrics), run_time


with gr.Blocks() as app:
    gr.Markdown("# Template for inference")
    gr.Markdown("EquiMPNN MOdel")

    with gr.Row():
        with gr.Column():
            input_seq_1 = gr.Textbox(lines=3, label="Input Protein 1 sequence (FASTA)")
            input_msa_1 = gr.File(label="Input MSA Protein 1 (A3M)")
            input_protein_1 = gr.File(label="Input Protein 1 monomer (PDB)")
        with gr.Column():
            input_seq_2 = gr.Textbox(lines=3, label="Input Protein 2 sequence (FASTA)")
            input_msa_2 = gr.File(label="Input MSA Protein 2 (A3M)")
            input_protein_2 = gr.File(label="Input Protein 2 structure (PDB)")

    # define any options here

    # for automated inference the default options are used
    # slider_option = gr.Slider(0,10, label="Slider Option")
    # checkbox_option = gr.Checkbox(label="Checkbox Option")
    # dropdown_option = gr.Dropdown(["Option 1", "Option 2", "Option 3"], label="Radio Option")

    btn = gr.Button("Run Inference")

    # Example rows supply values for the sequence and structure inputs only.
    gr.Examples(
        [
            [
                "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
                "3v1c_A.pdb",
                "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
                "3v1c_B.pdb",
            ],
        ],
        [input_seq_1, input_protein_1, input_seq_2, input_protein_2],
    )

    # Viewer representations: cartoon plus side-chain sticks for chains A and B.
    reps = [
        {
            "model": 0,
            "style": "cartoon",
            "chain": "A",
            "color": "whiteCarbon",
        },
        {
            "model": 0,
            "style": "cartoon",
            "chain": "B",
            "color": "greenCarbon",
        },
        {
            "model": 0,
            "chain": "A",
            "style": "stick",
            "sidechain": True,
            "color": "whiteCarbon",
        },
        {
            "model": 0,
            "chain": "B",
            "style": "stick",
            "sidechain": True,
            "color": "greenCarbon",
        },
    ]

    # outputs
    out = Molecule3D(reps=reps)
    metrics = gr.JSON(label="Metrics")
    run_time = gr.Textbox(label="Runtime")

    btn.click(
        predict,
        inputs=[
            input_seq_1,
            input_msa_1,
            input_protein_1,
            input_seq_2,
            input_msa_2,
            input_protein_2,
        ],
        outputs=[out, metrics, run_time],
    )

app.launch()