|
import base64 |
|
import pickle |
|
import re |
|
import uuid |
|
|
|
import pandas as pd |
|
import streamlit as st |
|
from CGRtools.files import SMILESRead |
|
from streamlit_ketcher import st_ketcher |
|
from huggingface_hub import hf_hub_download |
|
from huggingface_hub.utils import disable_progress_bars |
|
|
|
|
|
from synplan.mcts.expansion import PolicyNetworkFunction |
|
from synplan.mcts.search import extract_tree_stats |
|
from synplan.mcts.tree import Tree |
|
from synplan.chem.utils import mol_from_smiles |
|
from synplan.utils.config import TreeConfig, PolicyNetworkConfig |
|
from synplan.utils.loading import load_reaction_rules, load_building_blocks |
|
from synplan.utils.visualisation import generate_results_html, get_route_svg |
|
|
|
# Turn off huggingface_hub's progress bars so the hf_hub_download calls made
# further down in this script do not print them.
disable_progress_bars("huggingface_hub")

# CGRtools SMILES parser. NOTE(review): not referenced anywhere else in this
# script as shown; presumably ignore=True makes it tolerant of problematic
# input — confirm against the CGRtools SMILESRead docs before relying on it.
smiles_parser = SMILESRead.create_parser(ignore=True)
|
|
|
|
|
def download_button(object_to_download, download_filename, button_text, pickle_it=False):
    """
    Generate an HTML link, styled as a button, that downloads ``object_to_download``.

    Adapted from a community Streamlit snippet (the original attribution link is
    missing from this file). The payload is embedded in the markup as a base64
    ``data:`` URI, so the returned string must be rendered with
    ``st.markdown(..., unsafe_allow_html=True)``.

    Params:
    ------
    object_to_download: The object to be downloaded. Without ``pickle_it`` it may be
        a ``str`` (UTF-8 encoded), any bytes-like object (``bytes``, ``bytearray``,
        ...) used as-is, or a ``pandas.DataFrame`` (written as CSV without index).
    download_filename (str): Filename and extension offered to the browser,
        e.g. mydata.csv, some_txt_output.txt.
    button_text (str): Text displayed on the button (e.g. 'click here to download
        file'). It is inserted into the HTML verbatim, without escaping.
    pickle_it (bool): If True, pickle ``object_to_download`` and offer the pickle bytes.

    Returns:
    -------
    (str | None): A ``<style>`` block followed by the ``<a>`` tag, or ``None`` when
    ``pickle_it`` is True and the object cannot be pickled (the error is shown
    with ``st.write``).

    Raises:
    ------
    TypeError: If ``pickle_it`` is False and the object is not a str, a bytes-like
        object or a DataFrame.

    Examples:
    --------
    download_button(your_df, 'YOUR_DF.csv', 'Click to download data!')
    download_button(your_str, 'YOUR_STRING.txt', 'Click to download text!')
    """
    if pickle_it:
        try:
            payload = pickle.dumps(object_to_download)
        except (pickle.PicklingError, TypeError, AttributeError) as e:
            # Depending on the object, pickle reports "cannot pickle" with any of
            # these (e.g. TypeError for thread locks), not only PicklingError.
            st.write(e)
            return None
    elif isinstance(object_to_download, pd.DataFrame):
        payload = object_to_download.to_csv(index=False).encode('utf-8')
    elif isinstance(object_to_download, str):
        payload = object_to_download.encode('utf-8')
    else:
        # bytes and other bytes-like objects go to b64encode unchanged.
        payload = object_to_download

    try:
        b64 = base64.b64encode(payload).decode()
    except TypeError as err:
        raise TypeError(
            f"download_button() cannot serialise an object of type "
            f"{type(object_to_download).__name__!r}; pass a str, bytes or a "
            f"pandas DataFrame, or set pickle_it=True"
        ) from err

    # CSS id selectors must not start with a digit; prefix the random hex id
    # with letters instead of stripping its digits (which could leave it empty).
    button_id = f"dl{uuid.uuid4().hex}"

    custom_css = f"""
        <style>
            #{button_id} {{
                background-color: rgb(255, 255, 255);
                color: rgb(38, 39, 48);
                text-decoration: none;
                border-radius: 4px;
                border-width: 1px;
                border-style: solid;
                border-color: rgb(230, 234, 241);
                border-image: initial;
            }}
            #{button_id}:hover {{
                border-color: rgb(246, 51, 102);
                color: rgb(246, 51, 102);
            }}
            #{button_id}:active {{
                box-shadow: none;
                background-color: rgb(246, 51, 102);
                color: white;
                }}
        </style> """

    dl_link = custom_css + f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br></br>'

    return dl_link
|
|
|
|
|
# Page-level configuration. NOTE: kept ahead of every other st.* call because
# older Streamlit releases require set_page_config to be the first command.
st.set_page_config(page_title="SynPlanner GUI", page_icon="🧪", layout="wide")

# Markdown blurb rendered under the page title (st.write renders str as markdown).
intro_text = '''
This is a demo of the graphical user interface of
[SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
SynPlanner is a comprehensive tool for reaction data curation, rule extraction, model training and retrosynthetic planning.

More information on SynPlanner is available in the [official docs](https://synplanner.readthedocs.io/en/latest/index.html).
'''

st.title("`SynPlanner GUI`")

st.write(intro_text)
|
|
|
# --- Molecule input -------------------------------------------------------
# The target can be typed as SMILES or drawn in the Ketcher editor; the
# editor is pre-loaded with the typed SMILES and returns the applied structure.
st.header('Molecule input')
st.markdown(
    '''
You can provide a molecular structure by either providing:
* SMILES string + Enter
* Draw it + Apply
'''
)

DEFAULT_MOL = 'c1cc(ccc1Cl)C(CCO)NC(C2(CCN(CC2)c3c4cc[nH]c4ncn3)N)=O'
molecule = st.text_input("SMILES:", DEFAULT_MOL)
smile_code = st_ketcher(molecule)

# Guard against an empty or unparseable structure: without this, a bad SMILES
# crashes the whole page with a traceback. st.stop() halts this script run so
# nothing below renders with a missing target.
if not smile_code:
    st.error("No molecule provided. Enter a SMILES string or draw a structure and click Apply.")
    st.stop()

try:
    target_molecule = mol_from_smiles(smile_code)
except Exception as err:  # UI boundary: report the parse failure to the user
    st.error(f"Could not parse the molecule ``{smile_code}``: {err}")
    st.stop()

# NOTE(review): unclear whether mol_from_smiles raises or returns None on bad
# input; both are handled.
if target_molecule is None:
    st.error(f"Could not parse the molecule ``{smile_code}``.")
    st.stop()
|
|
|
# --- Model / data files from the Hugging Face Hub ---------------------------
HF_REPO_ID = "Laboratoire-De-Chemoinformatique/SynPlanner"


@st.cache_data(show_spinner="Fetching SynPlanner data from the Hugging Face Hub...")
def _download_from_hub(filename, subfolder):
    """Download ``subfolder/filename`` from the SynPlanner Hub repo into ``.``.

    Streamlit re-executes this script on every widget interaction, so the
    result (the local file path returned by ``hf_hub_download``) is cached to
    avoid contacting the Hub on each rerun.
    """
    return hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=filename,
        subfolder=subfolder,
        local_dir=".",
    )


building_blocks_path = _download_from_hub("building_blocks_em_sa_ln.smi", "building_blocks")
ranking_policy_weights_path = _download_from_hub("ranking_policy_network.ckpt", "uspto/weights")
reaction_rules_path = _download_from_hub("uspto_reaction_rules.pickle", "uspto")
|
|
|
# --- Planning options -----------------------------------------------------
st.header('Launch calculation')
st.markdown(
    '''If you modified the structure, please ensure you clicked on `Apply` (bottom right of the molecular editor).'''
)
st.markdown(f"The current molecule SMILES is: ``{smile_code}``")

st.subheader('Planning options')

st.markdown(
    '''
The description of each option can be found in the
[Retrosynthetic Planning Tutorial](https://synplanner.readthedocs.io/en/latest/tutorial_files/retrosynthetic_planning.html#Configuring-search-tree).
'''
)

col_options_1, col_options_2 = st.columns(2, gap="medium")

with col_options_1:
    search_strategy_input = st.selectbox(label='Search strategy', options=('Expansion first', 'Evaluation first',), index=0)
    # Distinct label: this selects the UCB formula, not the search strategy.
    ucb_type = st.selectbox(label='UCB type', options=('uct', 'puct', 'value'), index=0)
    c_ucb = st.number_input("C coefficient of UCB", value=0.1, placeholder="Type a number...")

with col_options_2:
    max_iterations = st.slider('Total number of MCTS iterations', min_value=50, max_value=300, value=100)
    max_depth = st.slider('Maximal number of reaction steps', min_value=3, max_value=9, value=6)
    min_mol_size = st.slider('Minimum size of a molecule to be precursor', min_value=0, max_value=7, value=0)

# Map the human-readable label to the value passed to TreeConfig(search_strategy=...).
search_strategy_translator = {
    "Expansion first": "expansion_first",
    "Evaluation first": "evaluation_first",
}
search_strategy = search_strategy_translator[search_strategy_input]
|
|
|
# --- Planning run and results ---------------------------------------------
submit_planning = st.button('Start retrosynthetic planning')

if submit_planning:
    # The files were already fetched from the Hub above; this only loads them.
    with st.status("Loading data"):
        st.write("Loading building blocks")
        building_blocks = load_building_blocks(building_blocks_path, standardize=False)
        st.write('Loading reaction rules')
        reaction_rules = load_reaction_rules(reaction_rules_path)
        st.write('Loading policy network')
        policy_config = PolicyNetworkConfig(weights_path=ranking_policy_weights_path)
        policy_function = PolicyNetworkFunction(policy_config=policy_config)

    tree_config = TreeConfig(
        search_strategy=search_strategy,
        evaluation_type="rollout",
        max_iterations=max_iterations,
        max_depth=max_depth,
        min_mol_size=min_mol_size,
        init_node_value=0.5,
        ucb_type=ucb_type,
        c_ucb=c_ucb,
        silent=True,
    )

    tree = Tree(
        target=target_molecule,
        config=tree_config,
        reaction_rules=reaction_rules,
        building_blocks=building_blocks,
        expansion_function=policy_function,
        evaluation_function=None,
    )

    mcts_progress_text = "Running retrosynthetic planning"
    mcts_bar = st.progress(0, text=mcts_progress_text)
    # Each iteration of the tree is one MCTS step. Count from 1 so the bar can
    # reach 100%, and clamp because st.progress rejects values above 1.0.
    for step, _ in enumerate(tree, start=1):
        mcts_bar.progress(min(step / max_iterations, 1.0), text=mcts_progress_text)
    # The search may stop before max_iterations; show it as finished either way.
    mcts_bar.progress(1.0, text=mcts_progress_text)

    res = extract_tree_stats(tree, target_molecule)

    st.header('Results')
    if res["solved"]:
        st.balloons()

        st.subheader("Examples of found retrosynthetic routes")
        # Show at most three routes: every other winning node in sorted id order.
        example_node_ids = sorted(set(tree.winning_nodes))[::2][:3]
        for node_id in example_node_ids:
            num_steps = len(tree.synthesis_route(node_id))
            route_score = round(tree.route_score(node_id), 3)
            st.image(get_route_svg(tree, node_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")

        stat_col, download_col = st.columns(2, gap="medium")
        # One-row table of the search statistics, reused for display and CSV export.
        df = pd.DataFrame(res, index=[0])

        with stat_col:
            st.subheader("Statistics")
            st.write(df[["target_smiles", "num_routes", "num_nodes", "num_iter", "search_time"]])

        with download_col:
            st.subheader("Downloads")
            html_body = generate_results_html(tree, html_path=None, extended=True)
            dl_html = download_button(html_body, 'results_synplanner.html', 'Download results as a HTML file')
            dl_csv = download_button(df, 'results_synplanner.csv', 'Download statistics as a csv file')
            st.markdown(dl_html + dl_csv, unsafe_allow_html=True)

    else:
        st.write("Found no reaction path.")

    # Any button click reruns the script with submit_planning == False, which
    # clears the results above; st.rerun() makes the restart explicit.
    st.divider()
    st.header('Restart from the beginning?')
    if st.button("Restart"):
        st.rerun()
|
|