import base64
import html
import pickle
import re
import uuid

import pandas as pd
import streamlit as st
from CGRtools.files import SMILESRead
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import disable_progress_bars
from streamlit_ketcher import st_ketcher

from synplan.chem.utils import mol_from_smiles
from synplan.mcts.expansion import PolicyNetworkFunction
from synplan.mcts.search import extract_tree_stats
from synplan.mcts.tree import Tree
from synplan.utils.config import PolicyNetworkConfig, TreeConfig
from synplan.utils.loading import load_building_blocks, load_reaction_rules
from synplan.utils.visualisation import generate_results_html, get_route_svg
# Turn off huggingface_hub progress bars for the Hub downloads performed below.
disable_progress_bars("huggingface_hub")
# CGRtools SMILES parser; `ignore=True` presumably makes it tolerate recoverable
# parsing problems — confirm against CGRtools docs.
# NOTE(review): not referenced in the visible code — confirm it is still needed.
smiles_parser = SMILESRead.create_parser(ignore=True)
def download_button(object_to_download, download_filename, button_text, pickle_it=False):
    """
    Generate an HTML download button for ``object_to_download``.

    Adapted from the community "Streamlit download button" recipe. The returned
    snippet is a ``<style>`` block followed by an ``<a download=...>`` anchor that
    embeds the payload as a base64 data URI; render it with
    ``st.markdown(snippet, unsafe_allow_html=True)``.

    Params:
    ------
    object_to_download: The object to be downloaded. ``bytes`` are embedded
        as-is, a ``pandas.DataFrame`` is written as CSV without its index, and a
        ``str`` is UTF-8 encoded.
    download_filename (str): filename and extension of the file, e.g.
        mydata.csv or some_txt_output.txt.
    button_text (str): Text to display on the download button
        (e.g. 'click here to download file').
    pickle_it (bool): If True, pickle ``object_to_download`` first.

    Returns:
    -------
    (str | None): the HTML snippet, or ``None`` if pickling failed (the error
        is written to the page with ``st.write``).

    Examples:
    --------
    download_button(your_df, 'YOUR_DF.csv', 'Click to download data!')
    download_button(your_str, 'YOUR_STRING.txt', 'Click to download text!')
    """
    if pickle_it:
        try:
            object_to_download = pickle.dumps(object_to_download)
        except (pickle.PicklingError, TypeError, AttributeError) as e:
            # Unpicklable objects raise any of these depending on the object type.
            st.write(e)
            return None
    elif isinstance(object_to_download, pd.DataFrame):
        object_to_download = object_to_download.to_csv(index=False).encode("utf-8")

    # base64 works on bytes, so text payloads are UTF-8 encoded first.
    if isinstance(object_to_download, str):
        object_to_download = object_to_download.encode("utf-8")
    b64 = base64.b64encode(object_to_download).decode()

    # Unique per call so each button's CSS only targets its own anchor; the
    # letter prefix keeps the id a valid CSS identifier (it may not start with a digit).
    button_id = f"dl{uuid.uuid4().hex}"

    # Kept free of leading indentation: in Markdown, lines indented by four or
    # more spaces can be rendered as a code block instead of raw HTML.
    custom_css = (
        "<style>"
        f"#{button_id} {{"
        "background-color: rgb(255, 255, 255); color: rgb(38, 39, 48); "
        "padding: 0.25em 0.38em; position: relative; text-decoration: none; "
        "border-radius: 4px; border: 1px solid rgb(230, 234, 241);"
        "}"
        f"#{button_id}:hover {{border-color: rgb(246, 51, 102); color: rgb(246, 51, 102);}}"
        f"#{button_id}:active {{box-shadow: none; background-color: rgb(246, 51, 102); color: white;}}"
        "</style>"
    )

    # Escape the filename and label so they cannot break out of the markup.
    safe_filename = html.escape(download_filename, quote=True)
    safe_text = html.escape(button_text)
    dl_link = (
        custom_css
        + f'<a download="{safe_filename}" id="{button_id}" '
        f'href="data:application/octet-stream;base64,{b64}">{safe_text}</a><br><br>'
    )
    return dl_link
# Configure the browser tab and layout; Streamlit expects this to be the first
# Streamlit command the script issues.
st.set_page_config(
    layout="wide",
    page_icon="🧪",
    page_title="SynPlanner GUI",
)

# Markdown shown under the page title.
intro_text = '''
This is a demo of the graphical user interface of
[SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
SynPlanner is a comprehensive tool for reaction data curation, rule extraction, model training and retrosynthetic planning.
More information on SynPlanner is available in the [official docs](https://synplanner.readthedocs.io/en/latest/index.html).
'''

page_title_markdown = "`SynPlanner GUI`"
st.title(page_title_markdown)
st.write(intro_text)
# --- Molecule input -----------------------------------------------------------
# The target can be typed as SMILES or drawn in the Ketcher editor; the editor
# is seeded with the text-box SMILES and its applied structure is what gets planned.
st.header('Molecule input')
st.markdown(
    '''
You can provide a molecular structure by either providing:
* SMILES string + Enter
* Draw it + Apply
'''
)
DEFAULT_MOL = 'c1cc(ccc1Cl)C(CCO)NC(C2(CCN(CC2)c3c4cc[nH]c4ncn3)N)=O'
molecule = st.text_input("SMILES:", DEFAULT_MOL)
smile_code = st_ketcher(molecule)

# An empty editor or text box leaves nothing to plan for; stop this run with a
# hint instead of letting the parser fail further down.
if not smile_code:
    st.warning("Please enter a SMILES string or draw a molecule, then press Enter or Apply.")
    st.stop()

try:
    target_molecule = mol_from_smiles(smile_code)
except Exception as err:  # UI boundary: report unparsable input instead of a traceback
    st.error(f"Could not parse the SMILES `{smile_code}`: {err}")
    st.stop()

# NOTE(review): guards against the parser signalling failure by returning None
# rather than raising — confirm mol_from_smiles' contract.
if target_molecule is None:
    st.error(f"Could not parse the SMILES `{smile_code}`.")
    st.stop()
# Hugging Face Hub repository holding the building blocks, reaction rules and
# policy-network weights used by the planner.
SYNPLANNER_HF_REPO = "Laboratoire-De-Chemoinformatique/SynPlanner"


@st.cache_resource
def _download_from_hub(filename, subfolder):
    """Download ``subfolder/filename`` from the SynPlanner Hub repository into ``.``.

    Streamlit re-runs this whole script after every widget interaction; caching
    the result means each file is requested once per server process (until the
    cache is cleared) instead of on every rerun.

    Returns the local file path reported by ``hf_hub_download``.
    """
    return hf_hub_download(
        repo_id=SYNPLANNER_HF_REPO,
        filename=filename,
        subfolder=subfolder,
        local_dir=".",
    )


building_blocks_path = _download_from_hub("building_blocks_em_sa_ln.smi", "building_blocks")
ranking_policy_weights_path = _download_from_hub("ranking_policy_network.ckpt", "uspto/weights")
reaction_rules_path = _download_from_hub("uspto_reaction_rules.pickle", "uspto")
# --- Planning options -----------------------------------------------------------
st.header('Launch calculation')
st.markdown(
    '''If you modified the structure, please ensure you clicked on `Apply` (bottom right of the molecular editor).'''
)
st.markdown(f"The molecule SMILES is currently: ``{smile_code}``")

st.subheader('Planning options')
st.markdown(
    '''
The description of each option can be found in the
[Retrosynthetic Planning Tutorial](https://synplanner.readthedocs.io/en/latest/tutorial_files/retrosynthetic_planning.html#Configuring-search-tree).
'''
)

# UI labels -> values passed to TreeConfig(search_strategy=...). The selectbox
# options are taken from the keys so the two can never drift apart.
search_strategy_translator = {
    "Expansion first": "expansion_first",
    "Evaluation first": "evaluation_first",
}

col_options_1, col_options_2 = st.columns(2, gap="medium")
with col_options_1:
    search_strategy_input = st.selectbox(
        label='Search strategy', options=tuple(search_strategy_translator), index=0
    )
    # Passed to TreeConfig(ucb_type=...).
    ucb_type = st.selectbox(label='UCB type', options=('uct', 'puct', 'value'), index=0)
    c_ucb = st.number_input("C coefficient of UCB", value=0.1, placeholder="Type a number...")
with col_options_2:
    max_iterations = st.slider('Total number of MCTS iterations', min_value=50, max_value=300, value=100)
    max_depth = st.slider('Maximal number of reaction steps', min_value=3, max_value=9, value=6)
    min_mol_size = st.slider('Minimum size of a molecule to be precursor', min_value=0, max_value=7, value=0)

search_strategy = search_strategy_translator[search_strategy_input]
submit_planning = st.button('Start retrosynthetic planning')
# --- Planning run and results ---------------------------------------------------
if submit_planning:
    # The files were already fetched from the Hub above; this step parses them.
    with st.status("Loading data"):
        st.write("Loading building blocks")
        building_blocks = load_building_blocks(building_blocks_path, standardize=False)
        st.write("Loading reaction rules")
        reaction_rules = load_reaction_rules(reaction_rules_path)
        st.write("Loading policy network")
        policy_config = PolicyNetworkConfig(weights_path=ranking_policy_weights_path)
        policy_function = PolicyNetworkFunction(policy_config=policy_config)

    tree_config = TreeConfig(
        search_strategy=search_strategy,
        evaluation_type="rollout",
        max_iterations=max_iterations,
        max_depth=max_depth,
        min_mol_size=min_mol_size,
        init_node_value=0.5,
        ucb_type=ucb_type,
        c_ucb=c_ucb,
        silent=True,
    )
    tree = Tree(
        target=target_molecule,
        config=tree_config,
        reaction_rules=reaction_rules,
        building_blocks=building_blocks,
        expansion_function=policy_function,
        evaluation_function=None,
    )

    # Iterating the tree drives the search (presumably one MCTS iteration per
    # step); the yielded (solved, node_id) pairs are not needed here.
    mcts_progress_text = "Running retrosynthetic planning"
    mcts_bar = st.progress(0, text=mcts_progress_text)
    for step, (_solved, _node_id) in enumerate(tree, start=1):
        # st.progress rejects values above 1.0, so clamp in case the tree
        # yields more steps than max_iterations.
        mcts_bar.progress(min(step / max_iterations, 1.0), text=mcts_progress_text)
    mcts_bar.progress(1.0, text=mcts_progress_text)

    res = extract_tree_stats(tree, target_molecule)
    df = pd.DataFrame(res, index=[0])

    st.header('Results')
    if res["solved"]:
        st.balloons()
        st.subheader("Examples of found retrosynthetic routes")
        # Show at most three routes: every other winning node, in sorted order.
        for node_id in sorted(set(tree.winning_nodes))[::2][:3]:
            num_steps = len(tree.synthesis_route(node_id))
            route_score = round(tree.route_score(node_id), 3)
            st.image(
                get_route_svg(tree, node_id),
                caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}",
            )

        stat_col, download_col = st.columns(2, gap="medium")
        with stat_col:
            st.subheader("Statistics")
            st.write(df[["target_smiles", "num_routes", "num_nodes", "num_iter", "search_time"]])
        with download_col:
            st.subheader("Downloads")
            html_body = generate_results_html(tree, html_path=None, extended=True)
            dl_html = download_button(html_body, 'results_synplanner.html', 'Download results as a HTML file')
            dl_csv = download_button(df, 'results_synplanner.csv', 'Download statistics as a csv file')
            # The snippets contain <style>/<a> tags, so raw HTML must be allowed.
            st.markdown(dl_html + dl_csv, unsafe_allow_html=True)
    else:
        st.write("Found no reaction path.")
# --- Restart ---------------------------------------------------------------------
# Clicking the button re-executes the script from the top via st.rerun().
st.divider()
st.header('Restart from the beginning?')
restart_clicked = st.button("Restart")
if restart_clicked:
    st.rerun()