luost26 committed on
Commit
20002df
·
1 Parent(s): 54edec6

Add app.py

Browse files
Files changed (1) hide show
  1. app.py +362 -0
app.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('./diffab-repo')
3
+ import os
4
+ import shutil
5
+ import pandas as pd
6
+ import yaml
7
+ import subprocess
8
+ import streamlit as st
9
+ import stmol
10
+ import py3Dmol
11
+ import tempfile
12
+ import re
13
+ import abnumber
14
+ import gzip
15
+ import tarfile
16
+ import torch
17
+ from tqdm.auto import tqdm
18
+ from Bio import PDB
19
+ from collections import OrderedDict
20
+
21
+ from diffab.tools.renumber import renumber as renumber_antibody
22
+ from diffab.tools.renumber.run import (
23
+ biopython_chain_to_sequence,
24
+ assign_number_to_sequence,
25
+ )
26
+
27
# Resolved location of the DiffAb checkout next to this app.
DIFFAB_DIR = os.path.realpath('./diffab-repo')

# CDR identifier -> short label shown in the UI (insertion order is the display order).
CDR_OPTIONS = OrderedDict([
    ('H_CDR1', 'H1'),
    ('H_CDR2', 'H2'),
    ('H_CDR3', 'H3'),
    ('L_CDR1', 'L1'),
    ('L_CDR2', 'L2'),
    ('L_CDR3', 'L3'),
])

# Design mode identifier -> human-readable label.
DESIGN_MODES = OrderedDict([
    ('denovo', 'De novo design'),
    ('denovo_dock', 'De novo design (with HDOCK)'),
    ('opt', 'Optimization'),
    ('fixbb', 'Fix-backbone'),
])

# Design mode identifier -> template configuration file for that mode.
MODE_CONFIG = {
    'denovo': './configs/test/codesign_multicdrs.yml',
    'denovo_dock': './configs/test/codesign_multicdrs.yml',
    'opt': './configs/test/abopt_singlecdr.yml',
    'fixbb': './configs/test/fixbb.yml',
}

# Fewer samples by default when no CUDA device is available.
GPU_AVAILABLE = torch.cuda.is_available()
if GPU_AVAILABLE:
    DEFAULT_NUM_SAMPLES = 5
else:
    DEFAULT_NUM_SAMPLES = 1
DEFAULT_NUM_DOCKS = 3
53
+
54
+
55
def dict_to_func(d):
    """Turn the mapping ``d`` into a one-argument lookup function.

    The returned callable evaluates ``d[x]`` at call time, so later changes
    to ``d`` are visible and a missing key raises whatever ``d[x]`` raises.
    It is used in this file as a ``format_func`` for Streamlit widgets.
    """
    return lambda key: d[key]
59
+
60
+
61
def get_config(save_dir, mode, cdrs, num_samples=5, optimization_step=4):
    """Specialise the template config of ``mode`` and save it as ``design.yml``.

    Args:
        save_dir: Directory that receives the generated ``design.yml``.
        mode: Key into ``MODE_CONFIG`` selecting the template file.
        cdrs: Value stored under ``sampling.cdrs``.
        num_samples: Value stored under ``sampling.num_samples``.
        optimization_step: Stored as the single element of
            ``sampling.optimize_steps``.

    Returns:
        A ``(cfg, save_path)`` tuple: the modified configuration and the path
        of the YAML file it was written to.
    """
    with open(MODE_CONFIG[mode], 'r') as template_file:
        cfg = yaml.safe_load(template_file)

    # Override only the sampling options; the rest of the template is kept.
    sampling = cfg['sampling']
    sampling['cdrs'] = cdrs
    sampling['num_samples'] = num_samples
    sampling['optimize_steps'] = [optimization_step]

    save_path = os.path.join(save_dir, 'design.yml')
    with open(save_path, 'w') as out_file:
        yaml.dump(cfg, out_file)
    return cfg, save_path
73
+
74
+
75
def run_design(pdb_path, config_path, output_dir, docking, display_widget, num_docks=DEFAULT_NUM_DOCKS):
    """Run a DiffAb design script as a subprocess and stream its log.

    Any existing ``<output_dir>/design`` directory is removed first. The
    script is started with ``DIFFAB_DIR`` as working directory; its stdout
    and stderr are merged, and after every line received the last 10 lines
    are shown through ``display_widget.code``.

    Args:
        pdb_path: Input structure handed to the design script.
        config_path: YAML configuration handed to the design script.
        output_dir: Passed to the script as ``--out_root``.
        docking: If true, run ``design_dock.py`` (antigen-only input);
            otherwise run ``design_pdb.py``.
        display_widget: Object with a ``code(text)`` method used to show the
            log tail.
        num_docks: Passed as ``--num_docks``; only used when ``docking``.

    Returns:
        The exit status of the design process. Earlier versions returned
        ``None``; callers that ignore the result are unaffected.
    """
    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact.
    if docking:
        cmd = [
            'python', 'design_dock.py',
            '--antigen', str(pdb_path),
            '--config', str(config_path),
            '--num_docks', str(num_docks),
        ]
    else:
        cmd = ['python', 'design_pdb.py', str(pdb_path), '--config', str(config_path)]
    cmd += [
        '--batch_size', '1',
        '--out_root', str(output_dir),
        '--device', 'cuda' if GPU_AVAILABLE else 'cpu',
    ]

    # Drop results of a previous run so stale designs are never picked up.
    result_dir = os.path.join(output_dir, 'design')
    if os.path.exists(result_dir):
        shutil.rmtree(result_dir)

    tail = []
    with subprocess.Popen(
        cmd,
        env=os.environ.copy(),
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        cwd=DIFFAB_DIR,
    ) as proc:
        for raw_line in proc.stdout:
            # splitlines() also breaks on '\r', so carriage-return progress
            # updates count as separate lines; only the tail is retained.
            text = raw_line.decode(errors='replace')
            tail = (tail + text.splitlines())[-10:]
            display_widget.code('\n'.join(tail))
    # Leaving the ``with`` block closes the pipe and waits for the process.
    return proc.returncode
109
+
110
+
111
@st.cache
def renumber_antibody_cached(in_pdb, out_pdb, file_id):
    """Cached front-end for ``renumber_antibody``.

    Calls ``renumber_antibody(in_pdb, out_pdb, return_other_chains=True)``
    and returns its result unchanged; ``main`` unpacks it into heavy-chain,
    light-chain and antigen-chain lists.

    ``file_id`` is not used in the body.
    NOTE(review): presumably it only serves to make the cache entry specific
    to one upload - confirm against the Streamlit caching semantics in use.
    """
    chain_groups = renumber_antibody(in_pdb, out_pdb, return_other_chains=True)
    return chain_groups
116
+
117
+
118
# File names of generated designs: exactly four digits followed by ".pdb".
_DESIGN_PDB_PATTERN = re.compile(r'^\d\d\d\d\.pdb$')

# Column order of the results table; fixed so that an empty result set still
# yields a DataFrame with the expected columns.
_RESULT_COLUMNS = [
    'name',
    'H1', 'H2', 'H3',
    'L1', 'L2', 'L3',
    'gname', 'fname', 'fpath',
]


def _find_design_pdbs(result_dir):
    """Return sorted ``(group name, file name, path)`` triples of design PDBs."""
    found = []
    for root, _dirs, files in os.walk(result_dir):
        gname = os.path.basename(root)
        for fname in files:
            if _DESIGN_PDB_PATTERN.match(fname):
                found.append((gname, fname, os.path.join(root, fname)))
    # os.walk order depends on the file system; sort for a stable table.
    found.sort()
    return found


def _fill_cdr_sequences(record, model):
    """Store the CDR sequences of the antibody chains of ``model`` in ``record``."""
    for chain in model:
        try:
            seq, _reslist = biopython_chain_to_sequence(chain)
            _numbers, abchain = assign_number_to_sequence(seq)
            if abchain.chain_type == 'H':
                prefix = 'H'
            elif abchain.chain_type in ('L', 'K'):
                prefix = 'L'
            else:
                continue
            record[prefix + '1'] = abchain.cdr1_seq
            record[prefix + '2'] = abchain.cdr2_seq
            record[prefix + '3'] = abchain.cdr3_seq
        except abnumber.ChainParseError:
            # Chains that cannot be parsed as antibody chains (presumably the
            # antigen) contribute no CDR columns.
            continue


def _write_archive(result_dir, records):
    """Bundle the PDB file of every record into ``generated.tar.gz``."""
    archive_path = os.path.join(result_dir, 'generated.tar.gz')
    with tarfile.open(archive_path, 'w:gz') as tar:
        for record in records:
            # tar.add opens and closes the source file itself.
            tar.add(record['fpath'], arcname=record['name'])


def gather_results(result_dir):
    """Collect the generated designs below ``result_dir``.

    Every file named like ``0000.pdb`` is parsed, the CDR sequences of its
    heavy and light chains are extracted, and all such files are packed into
    ``<result_dir>/generated.tar.gz`` under the name
    ``<parent directory>_<file name>``.

    Args:
        result_dir: Directory that is searched recursively for designs.

    Returns:
        A ``(records, fpath_to_name)`` tuple. ``records`` is a DataFrame with
        the columns ``name``, ``H1``-``H3``, ``L1``-``L3``, ``gname``,
        ``fname`` and ``fpath`` (CDR cells are ``None`` when the chain is
        absent); the columns are present even when nothing was found.
        ``fpath_to_name`` maps each file path to its display name.
    """
    parser = PDB.PDBParser(QUIET=True)
    records = []
    fpath_to_name = {}
    for gname, fname, fpath in tqdm(_find_design_pdbs(result_dir)):
        name = f"{gname}_{fname}"
        structure = parser.get_structure(name, fpath)
        record = {
            'name': name,
            'H1': None, 'H2': None, 'H3': None,
            'L1': None, 'L2': None, 'L3': None,
            'gname': gname, 'fname': fname, 'fpath': fpath,
        }
        # Only the first model of the structure is inspected.
        _fill_cdr_sequences(record, structure[0])
        records.append(record)
        fpath_to_name[fpath] = name

    _write_archive(result_dir, records)

    return pd.DataFrame(records, columns=_RESULT_COLUMNS), fpath_to_name
171
+
172
+
173
def _design_options_for(h_chain, l_chain):
    """Return ``(cdr_options, cdr_default, mode_options)`` for the chains found."""
    heavy_cdrs = ['H_CDR1', 'H_CDR2', 'H_CDR3']
    light_cdrs = ['L_CDR1', 'L_CDR2', 'L_CDR3']
    if h_chain is None and l_chain is None:
        # Antigen only: the antibody has to be docked, so only that mode applies.
        return heavy_cdrs + light_cdrs, list(heavy_cdrs), ['denovo_dock']
    mode_options = ['denovo', 'opt', 'fixbb']
    if l_chain is None:
        # Heavy chain + antigen
        return heavy_cdrs, list(heavy_cdrs), mode_options
    if h_chain is None:
        # Light chain + antigen
        return light_cdrs, list(light_cdrs), mode_options
    # Heavy + light + antigen
    return heavy_cdrs + light_cdrs, list(heavy_cdrs), mode_options


def main():
    """Render the DiffAb demo page and drive one design run.

    The page walks through these steps:

    1. Upload a PDB file (or download one of the bundled examples).
    2. Renumber the structure and detect heavy, light and antigen chains.
    3. Choose design mode, CDRs and sample counts in a form.
    4. After "Run": write the configuration, run the design script once and
       collect the results into ``st.session_state.results``.
    5. Show the result table, a download button and a 3D view of the
       selected design.

    ``st.session_state`` keeps the per-session temporary directory and the
    ``submit`` / ``done`` flags across Streamlit reruns.

    NOTE(review): the source this was restored from carried no indentation,
    so the nesting of widgets inside columns/expanders was inferred - verify
    the layout against the running app.
    """
    # Temporary workspace directory, created once per session
    if 'tempdir_path' not in st.session_state:
        st.session_state.tempdir_path = tempfile.mkdtemp(prefix='streamlit')
    tempdir_path = st.session_state.tempdir_path

    # Page layout
    st.set_page_config(layout="wide")
    st.markdown(
        "# DiffAb \n\n"
        "Antigen-Specific Antibody Design and Optimization with Diffusion-Based Generative Models for Protein Structures (NeurIPS 2022) \n\n"
        "[[Paper](https://www.biorxiv.org/content/10.1101/2022.07.10.499510.abstract)] "
        "[[Code](https://github.com/luost26/diffab)]"
    )
    left_col, right_col = st.columns(2)

    # Step 1: Upload PDB or choose an example
    with left_col:
        uploaded_file = st.file_uploader(
            'Antigen structure or antibody-antigen complex',
        )

        if uploaded_file is None:
            with st.expander('Download examples', expanded=True):
                with open('./data/examples/7DK2_AB_C.pdb', 'r') as f:
                    st.download_button(
                        'RBD + Antibody Complex',
                        data=f,
                        file_name='RBD_AbAg.pdb',
                    )
                with open('./data/examples/Omicron_RBD.pdb', 'r') as f:
                    st.download_button(
                        'RBD Antigen Only',
                        data=f,
                        file_name='RBD_AgOnly.pdb',
                    )
                st.text('Please upload the downloaded PDB file to run the demo.')

    if 'submit' not in st.session_state:
        st.session_state.submit = False
    if 'done' not in st.session_state:
        st.session_state.done = False

    # Every later step needs the upload. Stopping here also keeps the results
    # step from touching placeholders that are only created once a file exists.
    if uploaded_file is None:
        return

    # Step 1.2: Retrieve uploaded PDB
    pdb_path = os.path.join(tempdir_path, 'structure.pdb')
    renum_path = os.path.join(tempdir_path, 'structure_renumber.pdb')
    # Write the raw bytes: no decode/encode round trip that could fail.
    with open(pdb_path, 'wb') as f:
        f.write(uploaded_file.getvalue())
    H_chains, L_chains, Ag_chains = renumber_antibody_cached(
        in_pdb=pdb_path,
        out_pdb=renum_path,
        file_id=uploaded_file.id,
    )
    H_chain = H_chains[0] if H_chains else None
    L_chain = L_chains[0] if L_chains else None
    # Without any antibody chain the antibody must be docked onto the antigen.
    docking = H_chain is None and L_chain is None

    # Step 2: Design options
    with left_col:
        st.dataframe(pd.DataFrame({
            'Heavy': {'Chain': H_chain},
            'Light': {'Chain': L_chain},
            'Antigen': {'Chain': ','.join(Ag_chains)},
        }), use_container_width=True)

        cdr_options, cdr_default, mode_options = _design_options_for(H_chain, L_chain)
        form = st.form('design_form')
        with form:
            design_mode = st.radio(
                'Mode',
                mode_options,
                format_func=dict_to_func(DESIGN_MODES),
            )
            cdr_choices = st.multiselect(
                'CDRs',
                cdr_options,
                default=cdr_default,
                format_func=dict_to_func(CDR_OPTIONS),
            )

            if docking:
                num_docks = st.slider(
                    'Number of docking poses',
                    min_value=1, max_value=10, value=DEFAULT_NUM_DOCKS,
                )
            else:
                num_docks = 0
            num_designs = st.slider(
                'Number of samples',
                min_value=1, max_value=10, value=DEFAULT_NUM_SAMPLES,
            )

            submit = st.form_submit_button('Run')
            # ``submit`` is only true on the rerun caused by the click; the
            # session flag remembers it for later reruns.
            st.session_state.submit = st.session_state.submit or submit
            if submit:
                # A fresh click invalidates the previous results.
                st.session_state.done = False

    if not st.session_state.submit:
        return

    # Step 3: Prepare configuration and run design
    _config, config_path = get_config(
        save_dir=tempdir_path,
        mode=design_mode,
        cdrs=cdr_choices,
        num_samples=num_designs,
    )
    result_dir = os.path.join(tempdir_path, 'design')

    with right_col:
        # Placeholders fix the vertical order of the result widgets; they are
        # filled in below.
        result_molecule_display = st.empty()
        result_select_widget = st.empty()
        result_table_display = st.empty()
        result_download_btn = st.empty()
        output_display = st.empty()
        if not st.session_state.done:
            run_design(
                pdb_path=renum_path,
                config_path=config_path,
                output_dir=tempdir_path,
                docking=docking,
                display_widget=output_display,
                num_docks=num_docks,
            )
            st.session_state.done = True
            st.session_state.results = gather_results(result_dir)

    # Step 5: Show results
    if not st.session_state.done:
        return
    df_results, fpath_to_name = st.session_state.results
    if df_results.empty:
        # Nothing to tabulate or visualise, e.g. because the design run failed.
        result_table_display.warning('No designs were generated.')
        return

    df_cols = ['name'] + list(CDR_OPTIONS.values())
    result_table_display.dataframe(df_results[df_cols], use_container_width=True)

    display_pdb_path = result_select_widget.selectbox(
        label="Visualize",
        options=df_results['fpath'],
        format_func=dict_to_func(fpath_to_name),
    )

    with open(os.path.join(result_dir, 'generated.tar.gz'), 'rb') as f:
        result_download_btn.download_button(
            label="Download PDBs",
            data=f,
            file_name="generated.tar.gz",
        )

    # Fall back to the first design if the selected file no longer exists.
    if not os.path.exists(display_pdb_path):
        display_pdb_path = df_results['fpath'].iloc[0]
    with open(display_pdb_path, 'r') as f:
        pdb_str = f.read()
    xyzview = py3Dmol.view(width=380, height=380)
    xyzview.addModelsAsFrames(pdb_str)
    xyzview.setStyle({'cartoon': {'color': 'spectrum'}})
    xyzview.zoomTo()
    with result_molecule_display:
        stmol.showmol(xyzview, width=380, height=380)
359
+
360
+
361
# Build the page only when this file is run as the main script, not on import.
if __name__ == "__main__":
    main()