Sukanyaaa committed on
Commit
1216041
·
verified ·
1 Parent(s): 3928b6a

Update inference_app.py

Browse files
Files changed (1) hide show
  1. inference_app.py +23 -80
inference_app.py CHANGED
@@ -6,11 +6,10 @@ import torch
6
  from torch_geometric.data import HeteroData
7
  import numpy as np
8
  from loguru import logger
9
- from Bio import PDB
10
- from Bio.PDB.PDBIO import PDBIO
11
  from pinder.core.loader.geodata import structure2tensor
12
  from pinder.core.loader.structure import Structure
13
  from src.models.pinder_module import PinderLitModule
 
14
 
15
  try:
16
  from torch_cluster import knn_graph
@@ -37,13 +36,13 @@ def get_props_pdb(pdb_file):
37
  residue_types=calpha.res_name,
38
  residue_ids=calpha.res_id,
39
  )
40
- return props
41
 
42
 
43
  def create_graph(pdb_1, pdb_2, k=5, device: torch.device = torch.device("cpu")):
44
 
45
- props_ligand = get_props_pdb(pdb_1)
46
- props_receptor = get_props_pdb(pdb_2)
47
 
48
  data = HeteroData()
49
 
@@ -56,74 +55,7 @@ def create_graph(pdb_1, pdb_2, k=5, device: torch.device = torch.device("cpu")):
56
  data["receptor", "receptor"].edge_index = knn_graph(data["receptor"].pos, k=k)
57
 
58
  data = data.to(device)
59
- return data
60
-
61
-
62
- def update_pdb_coordinates_from_tensor(
63
- input_filename, output_filename, coordinates_tensor
64
- ):
65
- r"""
66
- Updates atom coordinates in a PDB file with new transformed coordinates provided in a tensor.
67
-
68
- Parameters:
69
- - input_filename (str): Path to the original PDB file.
70
- - output_filename (str): Path to the new PDB file to save updated coordinates.
71
- - coordinates_tensor (torch.Tensor): Tensor of shape (1, N, 3) with transformed coordinates.
72
- """
73
- # Convert the tensor to a list of tuples
74
- new_coordinates = coordinates_tensor.squeeze(0).tolist()
75
-
76
- # Create a parser and parse the structure
77
- parser = PDB.PDBParser(QUIET=True)
78
- structure = parser.get_structure("structure", input_filename)
79
-
80
- # Flattened iterator for atoms to update coordinates
81
- atom_iterator = (
82
- atom
83
- for model in structure
84
- for chain in model
85
- for residue in chain
86
- for atom in residue
87
- )
88
-
89
- # Update each atom's coordinates
90
- for atom, (new_x, new_y, new_z) in zip(atom_iterator, new_coordinates):
91
- original_anisou = atom.get_anisou()
92
- original_uij = atom.get_siguij()
93
- original_tm = atom.get_sigatm()
94
- original_occupancy = atom.get_occupancy()
95
- original_bfactor = atom.get_bfactor()
96
- original_altloc = atom.get_altloc()
97
- original_serial_number = atom.get_serial_number()
98
- original_element = atom.get_charge()
99
- original_parent = atom.get_parent()
100
- original_radius = atom.get_radius()
101
-
102
- # Update only the atom coordinates, keep other fields intact
103
- atom.coord = np.array([new_x, new_y, new_z])
104
-
105
- # Reapply the preserved properties
106
- atom.set_anisou(original_anisou)
107
- atom.set_siguij(original_uij)
108
- atom.set_sigatm(original_tm)
109
- atom.set_occupancy(original_occupancy)
110
- atom.set_bfactor(original_bfactor)
111
- atom.set_altloc(original_altloc)
112
- # atom.set_fullname(original_fullname)
113
- atom.set_serial_number(original_serial_number)
114
- atom.set_charge(original_element)
115
- atom.set_radius(original_radius)
116
- atom.set_parent(original_parent)
117
- # atom.set_name(original_name)
118
- # atom.set_leve
119
-
120
- # Save the updated structure to a new PDB file
121
- io = PDBIO()
122
- io.set_structure(structure)
123
- io.save(output_filename)
124
-
125
- # Return the path to the updated PDB file
126
- return output_filename
127
 
128
 
129
  def merge_pdb_files(file1, file2, output_file):
@@ -156,7 +88,9 @@ def predict(
156
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
157
  logger.info(f"Using device: {device}")
158
 
159
- data = create_graph(input_protein_1, input_protein_2, k=10, device=device)
 
 
160
  logger.info("Created graph data")
161
 
162
  model = PinderLitModule.load_from_checkpoint("./checkpoints/epoch_010.ckpt")
@@ -167,13 +101,22 @@ def predict(
167
  with torch.no_grad():
168
  receptor_coords, ligand_coords = model(data)
169
 
170
- file1 = update_pdb_coordinates_from_tensor(
171
- input_protein_1, "holo_ligand.pdb", ligand_coords
 
 
 
172
  )
173
- file2 = update_pdb_coordinates_from_tensor(
174
- input_protein_2, "holo_receptor.pdb", receptor_coords
 
 
 
 
 
 
 
175
  )
176
- out_pdb = merge_pdb_files(file1, file2, "output.pdb")
177
 
178
  # return an output pdb file with the protein and two chains A and B.
179
  # also return a JSON with any metrics you want to report
@@ -267,4 +210,4 @@ with gr.Blocks() as app:
267
  outputs=[out, metrics, run_time],
268
  )
269
 
270
- app.launch()
 
6
  from torch_geometric.data import HeteroData
7
  import numpy as np
8
  from loguru import logger
 
 
9
  from pinder.core.loader.geodata import structure2tensor
10
  from pinder.core.loader.structure import Structure
11
  from src.models.pinder_module import PinderLitModule
12
+ from pathlib import Path
13
 
14
  try:
15
  from torch_cluster import knn_graph
 
36
  residue_types=calpha.res_name,
37
  residue_ids=calpha.res_id,
38
  )
39
+ return structure, props
40
 
41
 
42
  def create_graph(pdb_1, pdb_2, k=5, device: torch.device = torch.device("cpu")):
43
 
44
+ ligand_structure, props_ligand = get_props_pdb(pdb_1)
45
+ receptor_structure, props_receptor = get_props_pdb(pdb_2)
46
 
47
  data = HeteroData()
48
 
 
55
  data["receptor", "receptor"].edge_index = knn_graph(data["receptor"].pos, k=k)
56
 
57
  data = data.to(device)
58
+ return data, receptor_structure, ligand_structure
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
 
61
  def merge_pdb_files(file1, file2, output_file):
 
88
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
89
  logger.info(f"Using device: {device}")
90
 
91
+ data, receptor_structure, ligand_structure = create_graph(
92
+ input_protein_1, input_protein_2, k=10, device=device
93
+ )
94
  logger.info("Created graph data")
95
 
96
  model = PinderLitModule.load_from_checkpoint("./checkpoints/epoch_010.ckpt")
 
101
  with torch.no_grad():
102
  receptor_coords, ligand_coords = model(data)
103
 
104
+ receptor_structure.coord = receptor_coords.squeeze(0).cpu().numpy()
105
+ ligand_structure.coord = ligand_coords.squeeze(0).cpu().numpy()
106
+
107
+ receptor_pinder = Structure(
108
+ filepath=Path("./holo_receptor.pdb"), atom_array=receptor_structure
109
  )
110
+ ligand_pinder = Structure(
111
+ filepath=Path("./holo_ligand.pdb"), atom_array=ligand_structure
112
+ )
113
+
114
+ receptor_pinder.to_pdb()
115
+ ligand_pinder.to_pdb()
116
+
117
+ out_pdb = merge_pdb_files(
118
+ "./holo_receptor.pdb", "./holo_ligand.pdb", "./output.pdb"
119
  )
 
120
 
121
  # return an output pdb file with the protein and two chains A and B.
122
  # also return a JSON with any metrics you want to report
 
210
  outputs=[out, metrics, run_time],
211
  )
212
 
213
+ app.launch()