Sukanyaaa commited on
Commit
d7f69ca
·
verified ·
1 Parent(s): 7448888

Update inference_app.py

Browse files
Files changed (1) hide show
  1. inference_app.py +219 -45
inference_app.py CHANGED
@@ -1,29 +1,195 @@
1
-
2
  import time
3
- import json
4
-
5
  import gradio as gr
6
-
7
  from gradio_molecule3d import Molecule3D
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
 
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- def predict (input_seq_1, input_msa_1, input_protein_1, input_seq_2,input_msa_2, input_protein_2):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  start_time = time.time()
14
- # Do inference here
15
- # return an output pdb file with the protein and two chains A and B.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  # also return a JSON with any metrics you want to report
17
  metrics = {"mean_plddt": 80, "binding_affinity": 2}
 
18
  end_time = time.time()
19
  run_time = end_time - start_time
20
- return "test_out.pdb",json.dumps(metrics), run_time
 
 
21
 
22
  with gr.Blocks() as app:
23
 
24
  gr.Markdown("# Template for inference")
25
 
26
- gr.Markdown("Title, description, and other information about the model")
27
  with gr.Row():
28
  with gr.Column():
29
  input_seq_1 = gr.Textbox(lines=3, label="Input Protein 1 sequence (FASTA)")
@@ -33,9 +199,7 @@ with gr.Blocks() as app:
33
  input_seq_2 = gr.Textbox(lines=3, label="Input Protein 2 sequence (FASTA)")
34
  input_msa_2 = gr.File(label="Input MSA Protein 2 (A3M)")
35
  input_protein_2 = gr.File(label="Input Protein 2 structure (PDB)")
36
-
37
-
38
-
39
  # define any options here
40
 
41
  # for automated inference the default options are used
@@ -52,45 +216,55 @@ with gr.Blocks() as app:
52
  "3v1c_A.pdb",
53
  "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
54
  "3v1c_B.pdb",
55
-
56
  ],
57
  ],
58
- [input_seq_1, input_protein_1, input_seq_2, input_protein_2],
59
  )
60
- reps = [
61
- {
62
- "model": 0,
63
- "style": "cartoon",
64
- "chain": "A",
65
- "color": "whiteCarbon",
66
- },
67
- {
68
- "model": 0,
69
- "style": "cartoon",
70
- "chain": "B",
71
- "color": "greenCarbon",
72
- },
73
- {
74
- "model": 0,
75
- "chain": "A",
76
- "style": "stick",
77
- "sidechain": True,
78
- "color": "whiteCarbon",
79
- },
80
- {
81
- "model": 0,
82
- "chain": "B",
83
- "style": "stick",
84
- "sidechain": True,
85
- "color": "greenCarbon"
86
- }
87
- ]
88
- # outputs
89
-
90
  out = Molecule3D(reps=reps)
91
  metrics = gr.JSON(label="Metrics")
92
  run_time = gr.Textbox(label="Runtime")
93
 
94
- btn.click(predict, inputs=[input_seq_1, input_msa_1, input_protein_1, input_seq_2, input_msa_2, input_protein_2], outputs=[out, metrics, run_time])
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  app.launch()
 
 
1
  import time
2
+ import json
 
3
  import gradio as gr
 
4
  from gradio_molecule3d import Molecule3D
5
+ import torch
6
+ from torch_geometric.data import HeteroData
7
+ import numpy as np
8
+ from loguru import logger
9
+ from Bio import PDB
10
+ from Bio.PDB.PDBIO import PDBIO
11
+ from pinder.core.loader.geodata import structure2tensor
12
+ from pinder.core.loader.structure import Structure
13
+ from src.models.pinder_module import PinderLitModule
14
+
15
+ try:
16
+ from torch_cluster import knn_graph
17
+
18
+ torch_cluster_installed = True
19
+ except ImportError:
20
+ logger.warning(
21
+ "torch-cluster is not installed!"
22
+ "Please install the appropriate library for your pytorch installation."
23
+ "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
24
+ )
25
+ torch_cluster_installed = False
26
+
27
+
28
+ def get_props_pdb(pdb_file):
29
+ structure = Structure.read_pdb(pdb_file)
30
+ atom_mask = np.isin(getattr(structure, "atom_name"), list(["CA"]))
31
+ calpha = structure[atom_mask].copy()
32
+ props = structure2tensor(
33
+ atom_coordinates=structure.coord,
34
+ atom_types=structure.atom_name,
35
+ element_types=structure.element,
36
+ residue_coordinates=calpha.coord,
37
+ residue_types=calpha.res_name,
38
+ residue_ids=calpha.res_id,
39
+ )
40
+ return props
41
+
42
+
43
+ def create_graph(pdb_1, pdb_2, k=5, device: torch.device = torch.device("cpu")):
44
+
45
+ props_ligand = get_props_pdb(pdb_1)
46
+ props_receptor = get_props_pdb(pdb_2)
47
+
48
+ data = HeteroData()
49
+
50
+ data["ligand"].x = props_ligand["atom_types"]
51
+ data["ligand"].pos = props_ligand["atom_coordinates"]
52
+ data["ligand", "ligand"].edge_index = knn_graph(data["ligand"].pos, k=k)
53
+
54
+ data["receptor"].x = props_receptor["atom_types"]
55
+ data["receptor"].pos = props_receptor["atom_coordinates"]
56
+ data["receptor", "receptor"].edge_index = knn_graph(data["receptor"].pos, k=k)
57
+
58
+ data = data.to(device)
59
+ return data
60
+
61
+
62
+ def update_pdb_coordinates_from_tensor(
63
+ input_filename, output_filename, coordinates_tensor
64
+ ):
65
+ r"""
66
+ Updates atom coordinates in a PDB file with new transformed coordinates provided in a tensor.
67
+
68
+ Parameters:
69
+ - input_filename (str): Path to the original PDB file.
70
+ - output_filename (str): Path to the new PDB file to save updated coordinates.
71
+ - coordinates_tensor (torch.Tensor): Tensor of shape (1, N, 3) with transformed coordinates.
72
+ """
73
+ # Convert the tensor to a list of tuples
74
+ new_coordinates = coordinates_tensor.squeeze(0).tolist()
75
+
76
+ # Create a parser and parse the structure
77
+ parser = PDB.PDBParser(QUIET=True)
78
+ structure = parser.get_structure("structure", input_filename)
79
+
80
+ # Flattened iterator for atoms to update coordinates
81
+ atom_iterator = (
82
+ atom
83
+ for model in structure
84
+ for chain in model
85
+ for residue in chain
86
+ for atom in residue
87
+ )
88
 
89
+ # Update each atom's coordinates
90
+ for atom, (new_x, new_y, new_z) in zip(atom_iterator, new_coordinates):
91
+ original_anisou = atom.get_anisou()
92
+ original_uij = atom.get_siguij()
93
+ original_tm = atom.get_sigatm()
94
+ original_occupancy = atom.get_occupancy()
95
+ original_bfactor = atom.get_bfactor()
96
+ original_altloc = atom.get_altloc()
97
+ original_serial_number = atom.get_serial_number()
98
+ original_element = atom.get_charge()
99
+ original_parent = atom.get_parent()
100
+ original_radius = atom.get_radius()
101
 
102
+ # Update only the atom coordinates, keep other fields intact
103
+ atom.coord = np.array([new_x, new_y, new_z])
104
 
105
+ # Reapply the preserved properties
106
+ atom.set_anisou(original_anisou)
107
+ atom.set_siguij(original_uij)
108
+ atom.set_sigatm(original_tm)
109
+ atom.set_occupancy(original_occupancy)
110
+ atom.set_bfactor(original_bfactor)
111
+ atom.set_altloc(original_altloc)
112
+ # atom.set_fullname(original_fullname)
113
+ atom.set_serial_number(original_serial_number)
114
+ atom.set_charge(original_element)
115
+ atom.set_radius(original_radius)
116
+ atom.set_parent(original_parent)
117
+ # atom.set_name(original_name)
118
+ # atom.set_leve
119
 
120
+ # Save the updated structure to a new PDB file
121
+ io = PDBIO()
122
+ io.set_structure(structure)
123
+ io.save(output_filename)
124
+
125
+ # Return the path to the updated PDB file
126
+ return output_filename
127
+
128
+
129
+ def merge_pdb_files(file1, file2, output_file):
130
+ r"""
131
+ Merges two PDB files by concatenating them without altering their contents.
132
+
133
+ Parameters:
134
+ - file1 (str): Path to the first PDB file (e.g., receptor).
135
+ - file2 (str): Path to the second PDB file (e.g., ligand).
136
+ - output_file (str): Path to the output file where the merged structure will be saved.
137
+ """
138
+ with open(output_file, "w") as outfile:
139
+ # Copy the contents of the first file
140
+ with open(file1, "r") as f1:
141
+ lines = f1.readlines()
142
+ # Write all lines except the last 'END' line
143
+ outfile.writelines(lines[:-1])
144
+ # Copy the contents of the second file
145
+ with open(file2, "r") as f2:
146
+ outfile.write(f2.read())
147
+
148
+ print(f"Merged PDB saved to {output_file}")
149
+ return output_file
150
+
151
+
152
+ def predict(
153
+ input_seq_1, input_msa_1, input_protein_1, input_seq_2, input_msa_2, input_protein_2
154
+ ):
155
  start_time = time.time()
156
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
157
+ logger.info(f"Using device: {device}")
158
+
159
+ data = create_graph(input_protein_1, input_protein_2, k=10, device=device)
160
+ logger.info("Created graph data")
161
+
162
+ model = PinderLitModule.load_from_checkpoint("./checkpoints/epoch_010.ckpt")
163
+ model = model.to(device)
164
+ model.eval()
165
+ logger.info("Loaded model")
166
+
167
+ with torch.no_grad():
168
+ receptor_coords, ligand_coords = model(data)
169
+
170
+ file1 = update_pdb_coordinates_from_tensor(
171
+ input_protein_1, "holo_ligand.pdb", ligand_coords
172
+ )
173
+ file2 = update_pdb_coordinates_from_tensor(
174
+ input_protein_2, "holo_receptor.pdb", receptor_coords
175
+ )
176
+ out_pdb = merge_pdb_files(file1, file2, "output.pdb")
177
+
178
+ # return an output pdb file with the protein and two chains A and B.
179
  # also return a JSON with any metrics you want to report
180
  metrics = {"mean_plddt": 80, "binding_affinity": 2}
181
+
182
  end_time = time.time()
183
  run_time = end_time - start_time
184
+
185
+ return out_pdb, json.dumps(metrics), run_time
186
+
187
 
188
  with gr.Blocks() as app:
189
 
190
  gr.Markdown("# Template for inference")
191
 
192
+ gr.Markdown("EquiMPNN MOdel")
193
  with gr.Row():
194
  with gr.Column():
195
  input_seq_1 = gr.Textbox(lines=3, label="Input Protein 1 sequence (FASTA)")
 
199
  input_seq_2 = gr.Textbox(lines=3, label="Input Protein 2 sequence (FASTA)")
200
  input_msa_2 = gr.File(label="Input MSA Protein 2 (A3M)")
201
  input_protein_2 = gr.File(label="Input Protein 2 structure (PDB)")
202
+
 
 
203
  # define any options here
204
 
205
  # for automated inference the default options are used
 
216
  "3v1c_A.pdb",
217
  "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
218
  "3v1c_B.pdb",
 
219
  ],
220
  ],
221
+ [input_seq_1, input_protein_1, input_seq_2, input_protein_2],
222
  )
223
+ reps = [
224
+ {
225
+ "model": 0,
226
+ "style": "cartoon",
227
+ "chain": "A",
228
+ "color": "whiteCarbon",
229
+ },
230
+ {
231
+ "model": 0,
232
+ "style": "cartoon",
233
+ "chain": "B",
234
+ "color": "greenCarbon",
235
+ },
236
+ {
237
+ "model": 0,
238
+ "chain": "A",
239
+ "style": "stick",
240
+ "sidechain": True,
241
+ "color": "whiteCarbon",
242
+ },
243
+ {
244
+ "model": 0,
245
+ "chain": "B",
246
+ "style": "stick",
247
+ "sidechain": True,
248
+ "color": "greenCarbon",
249
+ },
250
+ ]
251
+ # outputs
252
+
253
  out = Molecule3D(reps=reps)
254
  metrics = gr.JSON(label="Metrics")
255
  run_time = gr.Textbox(label="Runtime")
256
 
257
+ btn.click(
258
+ predict,
259
+ inputs=[
260
+ input_seq_1,
261
+ input_msa_1,
262
+ input_protein_1,
263
+ input_seq_2,
264
+ input_msa_2,
265
+ input_protein_2,
266
+ ],
267
+ outputs=[out, metrics, run_time],
268
+ )
269
 
270
  app.launch()