Spaces:
Sleeping
Sleeping
import time | |
import json | |
import gradio as gr | |
from gradio_molecule3d import Molecule3D | |
import torch | |
from torch_geometric.data import HeteroData | |
import numpy as np | |
from loguru import logger | |
from pinder.core.loader.geodata import structure2tensor | |
from pinder.core.loader.structure import Structure | |
from src.models.pinder_module import PinderLitModule | |
from pathlib import Path | |
# torch-cluster is an optional compiled extension. Record whether it could be
# imported so the rest of the module can tell; create_graph calls knn_graph,
# which is only bound when this import succeeds.
try:
    from torch_cluster import knn_graph

    torch_cluster_installed = True
except ImportError:
    # Adjacent string literals are concatenated verbatim, so each sentence
    # carries its own trailing space to keep the logged message readable.
    logger.warning(
        "torch-cluster is not installed! "
        "Please install the appropriate library for your pytorch installation. "
        "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
    )
    torch_cluster_installed = False
def get_props_pdb(pdb_file):
    """Load a PDB file and convert it to tensor properties.

    The file is read with ``Structure.read_pdb``. Atoms whose name is
    ``"CA"`` are selected to supply the residue-level arrays, while the
    full structure supplies the atom-level arrays passed to
    ``structure2tensor``.

    Args:
        pdb_file: PDB input handed directly to ``Structure.read_pdb``.

    Returns:
        A ``(structure, props)`` tuple: the loaded structure and the value
        returned by ``structure2tensor``.
    """
    structure = Structure.read_pdb(pdb_file)

    # Boolean mask over atoms selecting those named "CA".
    is_calpha = np.isin(structure.atom_name, ["CA"])
    calpha = structure[is_calpha].copy()

    tensor_inputs = {
        "atom_coordinates": structure.coord,
        "atom_types": structure.atom_name,
        "element_types": structure.element,
        "residue_coordinates": calpha.coord,
        "residue_types": calpha.res_name,
        "residue_ids": calpha.res_id,
    }
    return structure, structure2tensor(**tensor_inputs)
def create_graph(pdb_1, pdb_2, k=5, device: torch.device = torch.device("cpu")):
    """Build a heterogeneous k-NN graph from a ligand and a receptor PDB.

    Each PDB is loaded with ``get_props_pdb``. Its ``"atom_types"`` become
    the node features ``x`` and its ``"atom_coordinates"`` the node positions
    ``pos``; intra-molecule edges come from ``knn_graph`` over the positions.

    Args:
        pdb_1: PDB input for the ligand.
        pdb_2: PDB input for the receptor.
        k: Number of neighbours passed to ``knn_graph``.
        device: Device the assembled ``HeteroData`` is moved to.

    Returns:
        A ``(data, receptor_structure, ligand_structure)`` tuple. Note that
        the receptor structure comes before the ligand structure.

    Raises:
        ImportError: If torch-cluster could not be imported, since
            ``knn_graph`` is then unavailable.
    """
    # Fail early with an actionable message instead of a NameError on
    # knn_graph further down.
    if not torch_cluster_installed:
        raise ImportError(
            "torch-cluster is required to build the k-NN graph but is not installed."
        )

    ligand_structure, props_ligand = get_props_pdb(pdb_1)
    receptor_structure, props_receptor = get_props_pdb(pdb_2)

    data = HeteroData()
    for node_type, props in (("ligand", props_ligand), ("receptor", props_receptor)):
        data[node_type].x = props["atom_types"]
        data[node_type].pos = props["atom_coordinates"]
        data[node_type, node_type].edge_index = knn_graph(data[node_type].pos, k=k)

    # Edges are built first; the whole graph is moved to the device afterwards.
    data = data.to(device)
    return data, receptor_structure, ligand_structure
def merge_pdb_files(file1, file2, output_file):
    r"""
    Merges two PDB files by concatenating them into a single file.

    Trailing ``END`` records and blank lines at the end of the first file are
    dropped so that no terminator ends up in the middle of the merged file.
    Every other line of the first file, and the whole second file, is copied
    unchanged.

    Parameters:
    - file1 (str): Path to the first PDB file (e.g., receptor).
    - file2 (str): Path to the second PDB file (e.g., ligand).
    - output_file (str): Path to the output file where the merged structure will be saved.

    Returns:
    - str: ``output_file``, unchanged.
    """
    # Read both inputs before opening the output so a missing input does not
    # leave a truncated output file behind.
    with open(file1, "r") as f1:
        first_lines = f1.readlines()
    with open(file2, "r") as f2:
        second_text = f2.read()

    # Only strip real terminators: a file without a final 'END' keeps its
    # last record, and 'ENDMDL' is not mistaken for 'END'.
    while first_lines and first_lines[-1].strip() in ("", "END"):
        first_lines.pop()

    # Make sure the second file starts on its own line.
    if first_lines and not first_lines[-1].endswith("\n"):
        first_lines[-1] += "\n"

    with open(output_file, "w") as outfile:
        outfile.writelines(first_lines)
        outfile.write(second_text)

    print(f"Merged PDB saved to {output_file}")
    return output_file
def predict(
    input_seq_1, input_msa_1, input_protein_1, input_seq_2, input_msa_2, input_protein_2
):
    """Run the docking model on two PDB inputs and write the merged result.

    Builds a graph from the two structures, loads the checkpointed
    ``PinderLitModule``, predicts new coordinates for both partners, writes
    each to its own PDB file and merges them into ``./output.pdb``.

    Args:
        input_seq_1: Sequence of protein 1 (not used by this function).
        input_msa_1: MSA of protein 1 (not used by this function).
        input_protein_1: PDB input for protein 1; treated as the ligand by
            ``create_graph``.
        input_seq_2: Sequence of protein 2 (not used by this function).
        input_msa_2: MSA of protein 2 (not used by this function).
        input_protein_2: PDB input for protein 2; treated as the receptor by
            ``create_graph``.

    Returns:
        A ``(out_pdb, metrics_json, run_time)`` tuple: the merged PDB path,
        the metrics serialised as a JSON string, and the elapsed wall-clock
        time in seconds.
    """
    start_time = time.time()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    data, receptor_structure, ligand_structure = create_graph(
        input_protein_1, input_protein_2, k=10, device=device
    )
    logger.info("Created graph data")

    # map_location lets a checkpoint saved on GPU be restored on a CPU-only
    # host instead of failing while deserialising CUDA tensors.
    model = PinderLitModule.load_from_checkpoint(
        "./checkpoints/epoch_010.ckpt", map_location=device
    )
    model = model.to(device)
    model.eval()
    logger.info("Loaded model")

    with torch.no_grad():
        receptor_coords, ligand_coords = model(data)

    # Overwrite the input coordinates with the predicted ones; squeeze(0)
    # drops the leading dimension of the model output.
    receptor_structure.coord = receptor_coords.squeeze(0).cpu().numpy()
    ligand_structure.coord = ligand_coords.squeeze(0).cpu().numpy()

    receptor_pinder = Structure(
        filepath=Path("./holo_receptor.pdb"), atom_array=receptor_structure
    )
    ligand_pinder = Structure(
        filepath=Path("./holo_ligand.pdb"), atom_array=ligand_structure
    )
    receptor_pinder.to_pdb()
    ligand_pinder.to_pdb()

    out_pdb = merge_pdb_files(
        "./holo_receptor.pdb", "./holo_ligand.pdb", "./output.pdb"
    )

    # return an output pdb file with the protein and two chains A and B.
    # also return a JSON with any metrics you want to report
    # NOTE(review): these metric values are hard-coded placeholders, not
    # computed from the prediction.
    metrics = {"mean_plddt": 80, "binding_affinity": 2}

    end_time = time.time()
    run_time = end_time - start_time
    return out_pdb, json.dumps(metrics), run_time
# Gradio UI: two columns of inputs (one per protein), a run button, an example
# row, and three outputs (3D viewer, metrics JSON, runtime).
with gr.Blocks() as app:
    gr.Markdown("# Template for inference")
    gr.Markdown("EquiMPNN MOdel")
    with gr.Row():
        with gr.Column():
            input_seq_1 = gr.Textbox(lines=3, label="Input Protein 1 sequence (FASTA)")
            input_msa_1 = gr.File(label="Input MSA Protein 1 (A3M)")
            # This upload feeds input_protein_1, so it is labelled Protein 1.
            input_protein_1 = gr.File(label="Input Protein 1 monomer (PDB)")
        with gr.Column():
            input_seq_2 = gr.Textbox(lines=3, label="Input Protein 2 sequence (FASTA)")
            input_msa_2 = gr.File(label="Input MSA Protein 2 (A3M)")
            input_protein_2 = gr.File(label="Input Protein 2 structure (PDB)")

    # define any options here
    # for automated inference the default options are used
    # slider_option = gr.Slider(0,10, label="Slider Option")
    # checkbox_option = gr.Checkbox(label="Checkbox Option")
    # dropdown_option = gr.Dropdown(["Option 1", "Option 2", "Option 3"], label="Radio Option")

    btn = gr.Button("Run Inference")

    # Each example row supplies values for the four components listed below,
    # in the same order.
    gr.Examples(
        [
            [
                "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
                "3v1c_A.pdb",
                "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
                "3v1c_B.pdb",
            ],
        ],
        [input_seq_1, input_protein_1, input_seq_2, input_protein_2],
    )

    # Viewer representations: cartoon plus side-chain sticks for chain A
    # (white carbons) and chain B (green carbons).
    reps = [
        {
            "model": 0,
            "style": "cartoon",
            "chain": "A",
            "color": "whiteCarbon",
        },
        {
            "model": 0,
            "style": "cartoon",
            "chain": "B",
            "color": "greenCarbon",
        },
        {
            "model": 0,
            "chain": "A",
            "style": "stick",
            "sidechain": True,
            "color": "whiteCarbon",
        },
        {
            "model": 0,
            "chain": "B",
            "style": "stick",
            "sidechain": True,
            "color": "greenCarbon",
        },
    ]

    # outputs
    out = Molecule3D(reps=reps)
    metrics = gr.JSON(label="Metrics")
    run_time = gr.Textbox(label="Runtime")

    # Input order must match predict's positional parameters; output order
    # must match the tuple it returns.
    btn.click(
        predict,
        inputs=[
            input_seq_1,
            input_msa_1,
            input_protein_1,
            input_seq_2,
            input_msa_2,
            input_protein_2,
        ],
        outputs=[out, metrics, run_time],
    )

app.launch()