|
import base64
import os
import json
import pickle
import uuid
import re

import numpy as np
import pandas as pd
import pygad
import streamlit as st
from CGRtools.containers import QueryContainer, MoleculeContainer
from CGRtools.utils import grid_depict
from CGRtools import smiles
from VQGAE.models import VQGAE, OrderingNetwork
from VQGAE.utils import frag_counts_to_inds, restore_order, decode_molecules
from streamlit.components.v1 import html

# hide atom-to-atom mapping numbers in molecule depictions
MoleculeContainer.depict_settings(aam=False)
|
|
|
|
|
def download_button(object_to_download, download_filename, button_text, pickle_it=False):
    """
    Generates an HTML anchor (styled as a button) to download the given object_to_download.

    Params:
    ------
    object_to_download: The object to be downloaded.
    download_filename (str): filename and extension of the file, e.g. mydata.csv,
        some_txt_output.txt
    button_text (str): Text to display on the download button (e.g. 'Click here to download file')
    pickle_it (bool): If True, pickle the object.

    Returns:
    -------
    (str): the anchor tag that downloads object_to_download

    Examples:
    --------
    download_button(your_df, 'YOUR_DF.csv', 'Click to download data!')
    download_button(your_str, 'YOUR_STRING.txt', 'Click to download text!')
    """
    if pickle_it:
        try:
            object_to_download = pickle.dumps(object_to_download)
        except pickle.PicklingError as e:
            st.write(e)
            return None
    else:
        if isinstance(object_to_download, bytes):
            pass
        elif isinstance(object_to_download, pd.DataFrame):
            object_to_download = object_to_download.to_csv(index=False)
        # try JSON encoding for everything else
        else:
            object_to_download = json.dumps(object_to_download)

    try:
        # works for strings
        b64 = base64.b64encode(object_to_download.encode()).decode()
    except AttributeError:
        # already bytes
        b64 = base64.b64encode(object_to_download).decode()

    button_uuid = str(uuid.uuid4()).replace('-', '')
    button_id = re.sub(r'\d+', '', button_uuid)

    custom_css = f"""
        <style>
            #{button_id} {{
                background-color: rgb(255, 255, 255);
                color: rgb(38, 39, 48);
                padding: 0.25em 0.38em;
                position: relative;
                text-decoration: none;
                border-radius: 4px;
                border-width: 1px;
                border-style: solid;
                border-color: rgb(230, 234, 241);
                border-image: initial;
            }}
            #{button_id}:hover {{
                border-color: rgb(246, 51, 102);
                color: rgb(246, 51, 102);
            }}
            #{button_id}:active {{
                box-shadow: none;
                background-color: rgb(246, 51, 102);
                color: white;
            }}
        </style> """

    dl_link = custom_css + f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br></br>'

    return dl_link
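
# Note: download_button only builds the HTML snippet; to show it in the app the
# returned string has to be rendered, e.g. (illustrative call, not from the original):
#     st.markdown(download_button(df, "data.csv", "Download CSV"), unsafe_allow_html=True)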
|
|
|
|
|
def file_selector(folder_path='.'):
    filenames = os.listdir(folder_path)
    selected_filename = st.selectbox('Select a file', filenames)
    return os.path.join(folder_path, selected_filename)
|
|
|
|
|
def render_svg(svg_string):
    """Renders the given SVG string."""
    c = st.container()
    with c:
        html(svg_string, height=300, scrolling=True)
|
|
|
|
|
|
|
# Substructure queries used to filter out unwanted generated molecules
# ("A" is the CGRtools wildcard for any atom).

# allene / cumulated diene fragment: a carbon with two double bonds
allene = QueryContainer()
allene.add_atom("C")
allene.add_atom("A")
allene.add_atom("A")
allene.add_bond(1, 2, 2)
allene.add_bond(1, 3, 2)

# charged peroxide fragment: O(-)-O
peroxide_charge = QueryContainer()
peroxide_charge.add_atom("O", charge=-1)
peroxide_charge.add_atom("O")
peroxide_charge.add_bond(1, 2, 1)

# peroxide fragment: O-O
peroxide = QueryContainer()
peroxide.add_atom("O")
peroxide.add_atom("O")
peroxide.add_bond(1, 2, 1)
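
# These queries are applied later through CGRtools' substructure operator:
# `allene < mol` is True when the allene pattern occurs in the generated molecule.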
|
|
|
|
|
def convert_df(df):
    return df.to_csv(index=False).encode('utf-8')
|
|
|
|
|
def tanimoto_kernel(x, y):
    """
    "The Tanimoto coefficient is a measure of the similarity between two sets.
    It is defined as the size of the intersection divided by the size of the union of the sample sets."

    The Tanimoto coefficient is also known as the Jaccard index.

    Adapted from https://github.com/cimm-kzn/CIMtools/blob/master/CIMtools/metrics/pairwise.py

    :param x: 2D array of features.
    :param y: 2D array of features.
    :return: The Tanimoto coefficient between the two arrays.
    """
    x_dot = np.dot(x, y.T)

    x2 = (x ** 2).sum(axis=1)
    y2 = (y ** 2).sum(axis=1)

    len_x2 = len(x2)
    len_y2 = len(y2)

    result = x_dot / (np.array([x2] * len_y2).T + np.array([y2] * len_x2) - x_dot)
    result[np.isnan(result)] = 0

    return result
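
# Quick sanity check (illustrative values, not part of the original script):
#   tanimoto_kernel(np.array([[1, 1, 0]]), np.array([[1, 0, 0]]))
# gives [[0.5]]: intersection 1 divided by union 2.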
|
|
|
|
|
def fitness_func_batch(ga_instance, solutions, solutions_indices):
    # `rf_model`, `X`, `use_ordering_score` and `ordering_model` are defined at
    # script level below and are used here as globals.
    frag_counts = np.array(solutions)
    if len(frag_counts.shape) == 1:
        frag_counts = frag_counts[np.newaxis, :]

    # predicted probability of the active class from the random forest
    rf_score = rf_model.predict_proba(frag_counts)[:, 1]

    # penalise solutions that encode too few fragments
    mol_size = frag_counts.sum(-1).astype(np.int64)
    size_penalty = np.where(mol_size < 18, -1.0, 0.)

    # reward dissimilarity to the training set; exact duplicates are strongly penalised
    dissimilarity_score = 1 - tanimoto_kernel(frag_counts, X).max(-1)
    dissimilarity_score += np.where(dissimilarity_score == 0, -5, 0)

    fitness = 0.5 * rf_score + 0.3 * dissimilarity_score + size_penalty

    if use_ordering_score:
        frag_inds = frag_counts_to_inds(frag_counts, max_atoms=51)
        _, ordering_scores = restore_order(frag_inds, ordering_model)
        ordering_scores = np.array(ordering_scores)
        fitness += 0.2 * ordering_scores

    return fitness.tolist()
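
# pygad calls this fitness function with a whole batch of solutions because
# `fitness_batch_size` is passed to pygad.GA below (batch evaluation requires a
# reasonably recent pygad release).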
|
|
|
|
|
def on_generation_progress(ga):
    global ga_progress
    global ga_bar
    ga_progress = ga_progress + 1
    ga_bar.progress(ga_progress / num_generations, text=ga_progress_text)
|
|
|
|
|
@st.cache_data
def load_data(batch_size):
    # cached across Streamlit reruns (keyed by batch_size)
    train_data = np.load("saved_model/tubulin_qsar_class_train_data_vqgae.npz")
    X = train_data["x"]
    Y = train_data["y"]
    with open("saved_model/rf_class_train_tubulin.pickle", "rb") as inp:
        rf_model = pickle.load(inp)

    vqgae_model = VQGAE.load_from_checkpoint(
        "saved_model/vqgae.ckpt",
        task="decode",
        batch_size=batch_size,
        map_location="cpu"
    )
    vqgae_model = vqgae_model.eval()

    ordering_model = OrderingNetwork.load_from_checkpoint(
        "saved_model/ordering_network.ckpt",
        batch_size=batch_size,
        map_location="cpu"
    )
    ordering_model = ordering_model.eval()
    return X, Y, rf_model, vqgae_model, ordering_model
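
# Expected contents of saved_model/ (as used above): the training fragment-count
# matrix and labels ("x"/"y" in the .npz), the pickled random forest classifier,
# and the VQGAE and ordering-network checkpoints.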
|
|
|
|
|
st.title('Inverse QSAR of Tubulin with VQGAE') |
|
|
|
with st.sidebar:
    with st.form("ga_options"):
        num_generations = st.slider(
            'Number of generations for GA',
            min_value=3,
            max_value=40,
            value=5
        )

        parent_selection_type = st.selectbox(
            label='Parent selection type',
            options=(
                'Steady-state selection',
                'Roulette wheel selection',
                'Stochastic universal selection',
                'Rank selection',
                'Random selection',
                'Tournament selection'
            ),
            index=1
        )

        parent_selection_translator = {
            "Steady-state selection": "sss",
            "Roulette wheel selection": "rws",
            "Stochastic universal selection": "sus",
            "Rank selection": "rank",
            "Random selection": "random",
            "Tournament selection": "tournament",
        }

        parent_selection_type = parent_selection_translator[parent_selection_type]

        crossover_type = st.selectbox(
            label='Crossover type',
            options=(
                'Single point',
                'Two points',
            ),
            index=0
        )

        crossover_translator = {
            "Single point": "single_point",
            "Two points": "two_points",
        }

        crossover_type = crossover_translator[crossover_type]

        # 603 is the size of the initial population (the training set; see the
        # assertion on X.shape below)
        num_parents_mating = int(
            st.slider(
                'Percentage of parents mating taken from initial population',
                min_value=0,
                max_value=100,
                step=1,
                value=33,
            ) * 603 // 100
        )

        keep_parents = int(
            st.slider(
                'Percentage of parents kept taken from number of parents mating',
                min_value=0,
                max_value=100,
                step=1,
                value=66
            ) * num_parents_mating // 100
        )

        use_ordering_score = st.toggle('Use ordering score', value=True)
        batch_size = int(st.number_input("Batch size", value=200, placeholder="Type a number..."))
        random_seed = int(st.number_input("Random seed", value=42, placeholder="Type a number..."))
        submit = st.form_submit_button('Start optimisation')
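        # st.form semantics: the widget values above are only sent to the script
        # when the submit button is pressed, so adjusting a slider alone does not
        # restart the optimisation.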
|
|
|
if submit:
    X, Y, rf_model, vqgae_model, ordering_model = load_data(batch_size)
    assert X.shape == (603, 4096)

    ga_instance = pygad.GA(
        fitness_func=fitness_func_batch,
        on_generation=on_generation_progress,
        initial_population=X,
        num_genes=X.shape[-1],
        fitness_batch_size=batch_size,
        num_generations=num_generations,
        num_parents_mating=num_parents_mating,
        parent_selection_type=parent_selection_type,
        crossover_type=crossover_type,
        mutation_type="adaptive",
        # with adaptive mutation, the two values are the percentage of genes to
        # mutate for low- and high-fitness solutions, respectively
        mutation_percent_genes=[10, 5],
        save_best_solutions=False,
        save_solutions=True,
        keep_elitism=0,
        keep_parents=keep_parents,
        suppress_warnings=True,
        random_seed=random_seed,
        gene_type=int
    )
|
|
|
    ga_progress = 0
    ga_progress_text = "Genetic optimisation in progress. Please wait."
    ga_bar = st.progress(0, text=ga_progress_text)
    ga_instance.run()
    ga_bar.empty()
|
|
|
    with st.spinner('Getting unique solutions'):
        unique_solutions = list(set(tuple(s) for s in ga_instance.solutions))
    st.success(f'{len(unique_solutions)} solutions were obtained')

    scores = {
        "rf_score": [],
        "similarity_score": []
    }

    if use_ordering_score:
        scores["ordering_score"] = []

    rescoring_progress = 0
    rescoring_progress_text = "Rescoring obtained solutions"
    rescoring_bar = st.progress(0, text=rescoring_progress_text)
    total_rescoring_steps = len(unique_solutions) // batch_size + 1
    for rescoring_step in range(total_rescoring_steps):
        vqgae_latents = unique_solutions[rescoring_step * batch_size: (rescoring_step + 1) * batch_size]
        frag_counts = np.array(vqgae_latents)
        rf_scores = rf_model.predict_proba(frag_counts)[:, 1]
        similarity_scores = tanimoto_kernel(frag_counts, X).max(-1)
        scores["rf_score"].extend(rf_scores.tolist())
        scores["similarity_score"].extend(similarity_scores.tolist())
        if use_ordering_score:
            frag_inds = frag_counts_to_inds(frag_counts, max_atoms=51)
            _, ordering_scores = restore_order(frag_inds, ordering_model)
            scores["ordering_score"].extend(ordering_scores)
        rescoring_bar.progress((rescoring_step + 1) / total_rescoring_steps, text=rescoring_progress_text)
    sc_df = pd.DataFrame(scores)
    rescoring_bar.empty()
|
|
|
    # keep solutions that are predicted active, not too similar to the training
    # set and (optionally) well scored by the ordering network
    if use_ordering_score:
        chosen_gen = sc_df[
            (sc_df["similarity_score"] < 0.95) & (sc_df["rf_score"] > 0.5) & (sc_df["ordering_score"] > 0.7)
        ]
    else:
        chosen_gen = sc_df[
            (sc_df["similarity_score"] < 0.95) & (sc_df["rf_score"] > 0.5)
        ]

    chosen_ids = chosen_gen.index.to_list()
    chosen_solutions = np.array([unique_solutions[ind] for ind in chosen_ids])
    gen_frag_inds = frag_counts_to_inds(chosen_solutions, max_atoms=51)
    st.info(f'The number of chosen solutions is {gen_frag_inds.shape[0]}', icon="ℹ️")
|
|
|
    gen_molecules = []
    results = {"smiles": [], "ordering_score": [], "validity": []}
    decoding_progress = 0
    decoding_progress_text = "Decoding chosen solutions"
    decoding_bar = st.progress(0, text=decoding_progress_text)
    total_decoding_steps = gen_frag_inds.shape[0] // batch_size + 1
    for decoding_step in range(total_decoding_steps):
        inputs = gen_frag_inds[decoding_step * batch_size: (decoding_step + 1) * batch_size]
        canon_order_inds, ordering_scores = restore_order(
            frag_inds=inputs,
            ordering_model=ordering_model,
        )
        molecules, validity = decode_molecules(
            ordered_frag_inds=canon_order_inds,
            vqgae_model=vqgae_model
        )
        gen_molecules.extend(molecules)
        results["smiles"].extend([str(molecule) for molecule in molecules])
        results["ordering_score"].extend(ordering_scores)
        results["validity"].extend([1 if i else 0 for i in validity])
        decoding_bar.progress((decoding_step + 1) / total_decoding_steps, text=decoding_progress_text)
    gen_stats = pd.DataFrame(results)
    decoding_bar.empty()
    full_stats = pd.concat([gen_stats, chosen_gen.reset_index()[["similarity_score", "rf_score"]]], axis=1)
    full_stats = full_stats[["smiles", "similarity_score", "rf_score", "ordering_score", "validity"]]
    valid_gen_stats = full_stats[full_stats.validity == 1]
|
|
|
    # molecules corresponding to the rows of valid_gen_stats
    valid_gen_mols = [gen_molecules[i] for i in valid_gen_stats.index]

    # discard molecules containing unwanted fragments (allene, peroxides) or
    # rings that are too small or too large
    filtered_gen_mols = []
    filtered_indices = []
    for mol_i, mol in enumerate(valid_gen_mols):
        is_frag = allene < mol or peroxide_charge < mol or peroxide < mol
        is_ring = False
        for ring in mol.sssr:
            if len(ring) > 8 or len(ring) < 4:
                is_ring = True
                break
        if not is_frag and not is_ring:
            filtered_gen_mols.append(mol)
            filtered_indices.append(mol_i)

    filtered_gen_stats = valid_gen_stats.iloc[filtered_indices]
|
|
|
    st.subheader('Generation results', divider='rainbow')
    st.dataframe(filtered_gen_stats)
    # download_button returns an HTML string, so it has to be rendered explicitly
    st.markdown(
        download_button(
            object_to_download=convert_df(filtered_gen_stats),
            download_filename='vqgae_tubulin_inhibitors_valid.csv',
            button_text="Download results as CSV"
        ),
        unsafe_allow_html=True
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    st.subheader('Examples of generated molecules')
    examples_smiles = filtered_gen_stats.sort_values(by=["similarity_score"], ascending=False).iloc[:6].smiles.to_list()
    examples = []
    for smi in examples_smiles:
        mol = smiles(smi)
        mol.clean2d()
        examples.append(mol)
    svg = grid_depict(examples, 2)
    render_svg(svg)
|
|
|
    with st.expander("Show full stats"):
        st.dataframe(full_stats)

    st.markdown(
        download_button(
            object_to_download=convert_df(full_stats),
            download_filename='vqgae_tubulin_inhibitors_full.csv',
            button_text="Download full results as CSV"
        ),
        unsafe_allow_html=True
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if st.button("Restart"):
    st.rerun()
|
|