Spaces:
Runtime error
Runtime error
Add app.py
Browse files
app.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append('./diffab-repo')
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
import pandas as pd
|
6 |
+
import yaml
|
7 |
+
import subprocess
|
8 |
+
import streamlit as st
|
9 |
+
import stmol
|
10 |
+
import py3Dmol
|
11 |
+
import tempfile
|
12 |
+
import re
|
13 |
+
import abnumber
|
14 |
+
import gzip
|
15 |
+
import tarfile
|
16 |
+
import torch
|
17 |
+
from tqdm.auto import tqdm
|
18 |
+
from Bio import PDB
|
19 |
+
from collections import OrderedDict
|
20 |
+
|
21 |
+
from diffab.tools.renumber import renumber as renumber_antibody
|
22 |
+
from diffab.tools.renumber.run import (
|
23 |
+
biopython_chain_to_sequence,
|
24 |
+
assign_number_to_sequence,
|
25 |
+
)
|
26 |
+
|
27 |
+
# Absolute path of the ./diffab-repo directory (resolved against the current
# working directory at import time).
DIFFAB_DIR = os.path.realpath('./diffab-repo')

# CDR identifier -> short label shown in the UI and used as result column name.
CDR_OPTIONS = OrderedDict([
    ('H_CDR1', 'H1'),
    ('H_CDR2', 'H2'),
    ('H_CDR3', 'H3'),
    ('L_CDR1', 'L1'),
    ('L_CDR2', 'L2'),
    ('L_CDR3', 'L3'),
])

# Design mode key -> human-readable label.
DESIGN_MODES = OrderedDict([
    ('denovo', 'De novo design'),
    ('denovo_dock', 'De novo design (with HDOCK)'),
    ('opt', 'Optimization'),
    ('fixbb', 'Fix-backbone'),
])

# Design mode key -> path of the YAML template loaded for that mode.
MODE_CONFIG = dict(
    denovo='./configs/test/codesign_multicdrs.yml',
    denovo_dock='./configs/test/codesign_multicdrs.yml',
    opt='./configs/test/abopt_singlecdr.yml',
    fixbb='./configs/test/fixbb.yml',
)

# Fewer samples by default when no CUDA device is available.
GPU_AVAILABLE = torch.cuda.is_available()
DEFAULT_NUM_SAMPLES = 1
if GPU_AVAILABLE:
    DEFAULT_NUM_SAMPLES = 5
DEFAULT_NUM_DOCKS = 3
|
53 |
+
|
54 |
+
|
55 |
+
def dict_to_func(d):
    """Wrap ``d`` as a one-argument lookup function.

    The returned callable evaluates ``d[x]`` for its argument, so a failed
    lookup raises whatever ``d`` itself raises on subscription. It is used
    below as a ``format_func`` for Streamlit widgets.
    """
    return lambda x: d[x]
|
59 |
+
|
60 |
+
|
61 |
+
def get_config(save_dir, mode, cdrs, num_samples=5, optimization_step=4):
    """Build a design config from the template for ``mode`` and save it.

    Args:
        save_dir: Directory in which ``design.yml`` is written.
        mode: Key into ``MODE_CONFIG`` selecting the YAML template.
        cdrs: Value stored under ``sampling.cdrs``.
        num_samples: Value stored under ``sampling.num_samples``.
        optimization_step: Stored as the single element of
            ``sampling.optimize_steps``.

    Returns:
        Tuple ``(cfg, save_path)``: the modified config mapping and the path
        of the written YAML file.
    """
    with open(MODE_CONFIG[mode], 'r') as tmpl_file:
        cfg = yaml.safe_load(tmpl_file)

    # Override only the sampling options; everything else comes from the template.
    sampling = cfg['sampling']
    sampling['cdrs'] = cdrs
    sampling['num_samples'] = num_samples
    sampling['optimize_steps'] = [optimization_step]

    save_path = os.path.join(save_dir, 'design.yml')
    with open(save_path, 'w') as out_file:
        yaml.dump(cfg, out_file)
    return cfg, save_path
|
73 |
+
|
74 |
+
|
75 |
+
def run_design(pdb_path, config_path, output_dir, docking, display_widget, num_docks=DEFAULT_NUM_DOCKS):
    """Run a DiffAb design script in a subprocess and stream the log tail.

    Any existing ``<output_dir>/design`` directory is removed first, then
    ``design_dock.py`` (when ``docking`` is true) or ``design_pdb.py`` is
    executed with ``DIFFAB_DIR`` as the working directory.

    Args:
        pdb_path: Input structure passed to the design script.
        config_path: YAML config passed via ``--config``.
        output_dir: Passed via ``--out_root``.
        docking: Selects ``design_dock.py`` instead of ``design_pdb.py``.
        display_widget: Object with a ``code(text)`` method; it receives the
            last 10 output lines after every line read from the subprocess.
        num_docks: Passed via ``--num_docks``; only used when docking.

    Returns:
        The exit code of the subprocess.
    """
    # Build an argument list rather than a shell string so that paths with
    # spaces or shell metacharacters are passed through verbatim.
    if docking:
        cmd = [
            'python', 'design_dock.py',
            '--antigen', str(pdb_path),
            '--config', str(config_path),
            '--num_docks', str(num_docks),
        ]
    else:
        cmd = ['python', 'design_pdb.py', str(pdb_path), '--config', str(config_path)]
    cmd += [
        '--batch_size', '1',
        '--out_root', str(output_dir),
        '--device', 'cuda' if GPU_AVAILABLE else 'cpu',
    ]

    # Drop results of a previous run.
    result_dir = os.path.join(output_dir, 'design')
    if os.path.exists(result_dir):
        shutil.rmtree(result_dir)

    tail = []
    # No bufsize=1 here: line buffering is not supported for binary pipes.
    # The context manager closes stdout and waits for the process on exit.
    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        cwd=DIFFAB_DIR,
    ) as proc:
        for raw_line in proc.stdout:
            # splitlines() also breaks on '\r', so carriage-return progress
            # updates count as separate lines, as they did before.
            tail.extend(raw_line.decode(errors='replace').splitlines())
            del tail[:-10]  # keep only the last 10 lines
            display_widget.code('\n'.join(tail))
    return proc.returncode
|
109 |
+
|
110 |
+
|
111 |
+
@st.cache
def renumber_antibody_cached(in_pdb, out_pdb, file_id):
    """Cached wrapper around ``renumber_antibody``.

    ``file_id`` is not used in the body; presumably it is there so that the
    cache distinguishes different uploads written to the same path -- confirm
    against the Streamlit caching rules in use.

    Returns:
        Whatever ``renumber_antibody(..., return_other_chains=True)`` returns;
        ``main`` unpacks it as ``(H_chains, L_chains, Ag_chains)``.
    """
    chains = renumber_antibody(in_pdb, out_pdb, return_other_chains=True)
    return chains
|
116 |
+
|
117 |
+
|
118 |
+
def gather_results(result_dir):
    """Collect generated PDB files, extract their CDR sequences and archive them.

    Walks ``result_dir`` for files named like ``0000.pdb`` (four digits),
    parses each one, and records the CDR sequences of every chain that can be
    numbered as an antibody heavy ('H') or light ('L'/'K') chain. All
    collected files are also written to ``<result_dir>/generated.tar.gz``
    under the name ``<parent dir>_<file name>``.

    Args:
        result_dir: Directory to search; it is created if missing so that the
            archive can always be written.

    Returns:
        Tuple ``(records, fpath_to_name)``: a DataFrame with columns ``name``,
        ``H1``..``L3``, ``gname``, ``fname``, ``fpath`` (one row per file,
        sorted by group and file name), and a dict mapping each file path to
        its display name.
    """
    sample_pattern = re.compile(r'^\d\d\d\d\.pdb$')

    outputs = []
    for root, _dirs, files in os.walk(result_dir):
        gname = os.path.basename(root)
        for fname in files:
            if sample_pattern.match(fname):
                outputs.append((gname, fname, os.path.join(root, fname)))
    # os.walk order depends on the file system; sort for a stable listing.
    outputs.sort()

    parser = PDB.PDBParser(QUIET=True)
    records = []
    fpath_to_name = {}
    for gname, fname, fpath in tqdm(outputs):
        name = f"{gname}_{fname}"
        model = parser.get_structure(name, fpath)[0]
        record = {
            'name': name,
            'H1': None, 'H2': None, 'H3': None,
            'L1': None, 'L2': None, 'L3': None,
            'gname': gname, 'fname': fname, 'fpath': fpath,
        }
        for chain in model:
            try:
                seq, _reslist = biopython_chain_to_sequence(chain)
                _numbers, abchain = assign_number_to_sequence(seq)
                if abchain.chain_type == 'H':
                    record['H1'] = abchain.cdr1_seq
                    record['H2'] = abchain.cdr2_seq
                    record['H3'] = abchain.cdr3_seq
                elif abchain.chain_type in ('L', 'K'):
                    record['L1'] = abchain.cdr1_seq
                    record['L2'] = abchain.cdr2_seq
                    record['L3'] = abchain.cdr3_seq
            except abnumber.ChainParseError:
                # Chains that cannot be numbered as antibody chains are skipped.
                continue
        records.append(record)
        fpath_to_name[fpath] = name

    os.makedirs(result_dir, exist_ok=True)
    with tarfile.open(os.path.join(result_dir, 'generated.tar.gz'), 'w:gz') as tar:
        for record in records:
            info = tar.gettarinfo(record['fpath'])
            info.name = record['name']
            # Close each source file as soon as it has been archived.
            with open(record['fpath'], 'rb') as src:
                tar.addfile(tarinfo=info, fileobj=src)

    # Explicit columns keep the frame well-formed when nothing was generated.
    columns = [
        'name', 'H1', 'H2', 'H3', 'L1', 'L2', 'L3',
        'gname', 'fname', 'fpath',
    ]
    records = pd.DataFrame(records, columns=columns)

    return records, fpath_to_name
|
171 |
+
|
172 |
+
|
173 |
+
def main():
    """Render the DiffAb demo page and drive one design run.

    Flow of a script run:
      1. Create (once per session) a temporary workspace directory.
      2. Show the uploader; without an upload, offer the example PDB files
         and stop.
      3. Save and renumber the upload, show the detected chains and the
         design form.
      4. Once the form has been submitted, write the config, run the design
         subprocess (only if not already done) and gather the results.
      5. Show the result table, a download button and a 3D view of the
         selected design.

    Session state keys used: ``tempdir_path``, ``submit``, ``done``,
    ``file_id`` and ``results``.
    """
    # Temporary workspace directory, created once per session
    if 'tempdir_path' not in st.session_state:
        st.session_state.tempdir_path = tempfile.mkdtemp(prefix='streamlit')
    tempdir_path = st.session_state.tempdir_path

    # Page layout
    st.set_page_config(layout="wide")
    st.markdown(
        "# DiffAb \n\n"
        "Antigen-Specific Antibody Design and Optimization with Diffusion-Based Generative Models for Protein Structures (NeurIPS 2022) \n\n"
        "[[Paper](https://www.biorxiv.org/content/10.1101/2022.07.10.499510.abstract)] "
        "[[Code](https://github.com/luost26/diffab)]"
    )
    left_col, right_col = st.columns(2)

    # Step 1: Upload PDB or choose an example
    with left_col:
        uploaded_file = st.file_uploader(
            'Antigen structure or antibody-antigen complex',
        )

        if uploaded_file is None:
            with st.expander('Download examples', expanded=True):
                with open('./data/examples/7DK2_AB_C.pdb', 'r') as f:
                    st.download_button(
                        'RBD + Antibody Complex',
                        data=f,
                        file_name='RBD_AbAg.pdb',
                    )
                with open('./data/examples/Omicron_RBD.pdb', 'r') as f:
                    st.download_button(
                        'RBD Antigen Only',
                        data=f,
                        file_name='RBD_AgOnly.pdb',
                    )
                st.text('Please upload the downloaded PDB file to run the demo.')

    if 'submit' not in st.session_state:
        st.session_state.submit = False
    if 'done' not in st.session_state:
        st.session_state.done = False

    # Nothing below makes sense without an upload. Stopping here also avoids
    # touching result widgets that only exist once a file is present.
    if uploaded_file is None:
        return

    # A different upload invalidates any earlier submission and its results.
    if st.session_state.get('file_id') != uploaded_file.id:
        st.session_state.file_id = uploaded_file.id
        st.session_state.submit = False
        st.session_state.done = False

    # Step 1.2: Retrieve uploaded PDB
    pdb_path = os.path.join(tempdir_path, 'structure.pdb')
    renum_path = os.path.join(tempdir_path, 'structure_renumber.pdb')
    # Write the raw bytes: no decoding step that could fail on odd encodings.
    with open(pdb_path, 'wb') as f:
        f.write(uploaded_file.getvalue())
    H_chains, L_chains, Ag_chains = renumber_antibody_cached(
        in_pdb=pdb_path,
        out_pdb=renum_path,
        file_id=uploaded_file.id,
    )
    H_chain = H_chains[0] if H_chains else None
    L_chain = L_chains[0] if L_chains else None
    # With no antibody chain at all, the antigen-only docking pipeline is used.
    docking = H_chain is None and L_chain is None

    # Step 2: Design options
    heavy_cdrs = ['H_CDR1', 'H_CDR2', 'H_CDR3']
    light_cdrs = ['L_CDR1', 'L_CDR2', 'L_CDR3']
    with left_col:
        st.dataframe(pd.DataFrame({
            'Heavy': {'Chain': H_chain},
            'Light': {'Chain': L_chain},
            'Antigen': {'Chain': ','.join(Ag_chains)},
        }), use_container_width=True)

        form = st.form('design_form')
        with form:
            if docking:
                # Antigen only
                cdr_options = heavy_cdrs + light_cdrs
                cdr_default = heavy_cdrs
                mode_options = ['denovo_dock']
            elif L_chain is None:
                # Heavy chain + Antigen
                cdr_options = heavy_cdrs
                cdr_default = heavy_cdrs
                mode_options = ['denovo', 'opt', 'fixbb']
            elif H_chain is None:
                # Light chain + Antigen
                cdr_options = light_cdrs
                cdr_default = light_cdrs
                mode_options = ['denovo', 'opt', 'fixbb']
            else:
                # H + L + Ag
                cdr_options = heavy_cdrs + light_cdrs
                cdr_default = heavy_cdrs
                mode_options = ['denovo', 'opt', 'fixbb']

            design_mode = st.radio(
                'Mode',
                mode_options,
                format_func=dict_to_func(DESIGN_MODES),
            )
            cdr_choices = st.multiselect(
                'CDRs',
                cdr_options,
                default=cdr_default,
                format_func=dict_to_func(CDR_OPTIONS),
            )

            if docking:
                num_docks = st.slider(
                    'Number of docking poses',
                    min_value=1, max_value=10, value=DEFAULT_NUM_DOCKS,
                )
            else:
                num_docks = 0
            num_designs = st.slider(
                'Number of samples',
                min_value=1, max_value=10, value=DEFAULT_NUM_SAMPLES,
            )

            submit = st.form_submit_button('Run')
            # 'submit' stays set across reruns so results remain visible;
            # pressing Run again forces a fresh design run.
            st.session_state.submit = st.session_state.submit or submit
            if submit:
                st.session_state.done = False

    if not st.session_state.submit:
        return

    # Step 3: Prepare configuration and run design
    _config, config_path = get_config(
        save_dir=tempdir_path,
        mode=design_mode,
        cdrs=cdr_choices,
        num_samples=num_designs,
    )
    result_dir = os.path.join(tempdir_path, 'design')

    with right_col:
        # Placeholders fix the on-page order before any content exists.
        result_molecule_display = st.empty()
        result_select_widget = st.empty()
        result_table_display = st.empty()
        result_download_btn = st.empty()
        output_display = st.empty()

    if not st.session_state.done:
        run_design(
            pdb_path=renum_path,
            config_path=config_path,
            output_dir=tempdir_path,
            docking=docking,
            display_widget=output_display,
            num_docks=num_docks,
        )
        st.session_state.done = True
        st.session_state.results = gather_results(result_dir)

    # Step 5: Show results
    df_results, fpath_to_name = st.session_state.results
    if df_results.empty:
        # Without this guard the selectbox returns None and the file checks
        # below fail.
        result_table_display.warning('No designs were generated.')
        return

    df_cols = ['name'] + list(CDR_OPTIONS.values())
    result_table_display.dataframe(df_results[df_cols], use_container_width=True)

    display_pdb_path = result_select_widget.selectbox(
        label="Visualize",
        options=df_results['fpath'],
        format_func=dict_to_func(fpath_to_name),
    )

    with open(os.path.join(result_dir, 'generated.tar.gz'), 'rb') as f:
        result_download_btn.download_button(
            label="Download PDBs",
            data=f,
            file_name="generated.tar.gz",
        )

    # Fall back to the first design if the selected file no longer exists.
    if not os.path.exists(display_pdb_path):
        display_pdb_path = df_results['fpath'].iloc[0]
    with open(display_pdb_path, 'r') as f:
        pdb_str = f.read()
    xyzview = py3Dmol.view(width=380, height=380)
    xyzview.addModelsAsFrames(pdb_str)
    xyzview.setStyle({'cartoon': {'color': 'spectrum'}})
    xyzview.zoomTo()
    with result_molecule_display:
        stmol.showmol(xyzview, width=380, height=380)
|
359 |
+
|
360 |
+
|
361 |
+
# Script entry point: build the page only when executed as a script.
if __name__ == '__main__':
    main()
|