import base64 |
import os |
import json |
import pickle |
import uuid |
import re |
import numpy as np |
import pandas as pd |
import pygad |
import streamlit as st |
from CGRtools.containers import QueryContainer, MoleculeContainer |
from CGRtools.utils import grid_depict |
from CGRtools import smiles |
from VQGAE.models import VQGAE, OrderingNetwork |
from VQGAE.utils import frag_counts_to_inds, restore_order, decode_molecules |
from streamlit.components.v1 import html |
# Global depiction setting applied once at import time for every molecule this
# app renders. NOTE(review): `aam=False` presumably switches off atom-to-atom
# mapping labels in the drawings -- confirm against the CGRtools documentation.
MoleculeContainer.depict_settings(aam=False)
def download_button(object_to_download, download_filename, button_text, pickle_it=False):
    """
    Build the HTML (inline CSS + anchor tag) for a styled download link.

    The payload is embedded in the link itself as a base64 ``data:`` URI, so
    clicking it does not trigger a round trip to the server.

    The markup is only *returned*, never rendered: the caller has to display
    it, e.g. with ``st.markdown(markup, unsafe_allow_html=True)``.

    Params:
    ------
    object_to_download:  The object to be downloaded. ``bytes`` are embedded
        as-is, a ``pandas.DataFrame`` is converted to CSV without the index,
        anything else is serialised with ``json.dumps`` (unless ``pickle_it``
        is set).
    download_filename (str): filename and extension of file. e.g. mydata.csv,
        some_txt_output.txt
    button_text (str): Text to display on download button (e.g. 'click here
        to download file'). Inserted into the markup verbatim (not escaped).
    pickle_it (bool): If True, pickle the object instead of using the
        conversions described above.

    Returns:
    -------
    (str | None): the style block plus anchor tag to download
        object_to_download, or None if pickling failed (the error is written
        to the page in that case).

    Examples:
    --------
    download_button(your_df, 'YOUR_DF.csv', 'Click to download data!')
    download_button(your_str, 'YOUR_STRING.txt', 'Click to download text!')
    """
    if pickle_it:
        try:
            object_to_download = pickle.dumps(object_to_download)
        except pickle.PicklingError as e:
            st.write(e)
            return None
    elif isinstance(object_to_download, pd.DataFrame):
        object_to_download = object_to_download.to_csv(index=False)
    elif not isinstance(object_to_download, bytes):
        object_to_download = json.dumps(object_to_download)

    # At this point the payload is either text (CSV / JSON) or raw bytes
    # (pickle / caller-supplied bytes); base64 needs bytes.
    if isinstance(object_to_download, str):
        object_to_download = object_to_download.encode()
    b64 = base64.b64encode(object_to_download).decode()

    # A CSS id selector must not start with a digit. Prefixing the full UUID
    # hex with letters guarantees that while keeping all of its randomness
    # (and the id can never end up empty).
    button_id = f"dl_{uuid.uuid4().hex}"

    # The CSS is kept flush-left so the markup is never mistaken for an
    # indented block when it is rendered through a Markdown element.
    custom_css = f"""
<style>
#{button_id} {{
background-color: rgb(255, 255, 255);
color: rgb(38, 39, 48);
padding: 0.25em 0.38em;
position: relative;
text-decoration: none;
border-radius: 4px;
border-width: 1px;
border-style: solid;
border-color: rgb(230, 234, 241);
border-image: initial;
}}
#{button_id}:hover {{
border-color: rgb(246, 51, 102);
color: rgb(246, 51, 102);
}}
#{button_id}:active {{
box-shadow: none;
background-color: rgb(246, 51, 102);
color: white;
}}
</style> """
    dl_link = custom_css + f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br></br>'
    return dl_link
def file_selector(folder_path='.'):
    """Show a select box listing the entries of ``folder_path`` and return the chosen path.

    :param folder_path: directory whose entries are offered for selection.
    :return: ``folder_path`` joined with the selected entry name.
    """
    chosen_name = st.selectbox('Select a file', os.listdir(folder_path))
    return os.path.join(folder_path, chosen_name)
def render_svg(svg_string):
    """Embed the given SVG markup in the page as a scrollable, 300-high HTML component."""
    with st.container():
        html(svg_string, height=300, scrolling=True)
def _build_query(atoms, bonds):
    """Assemble a QueryContainer from ``(symbol, attrs)`` atoms and ``(a, b, order)`` bonds.

    Atoms are added in the given order and bonds refer to them by the numbers
    1, 2, 3, ... in that same order.
    """
    query = QueryContainer()
    for symbol, attrs in atoms:
        query.add_atom(symbol, **attrs)
    for first, second, order in bonds:
        query.add_bond(first, second, order)
    return query


# Substructure patterns of unwanted motifs; generated molecules containing any
# of them are discarded in the filtering step further down.

# A carbon carrying two order-2 bonds to "A" atoms.
# NOTE(review): "A" is presumably the CGRtools any-atom query symbol -- confirm.
allene = _build_query(
    atoms=[("C", {}), ("A", {}), ("A", {})],
    bonds=[(1, 2, 2), (1, 3, 2)],
)

# O-O order-1 bond where the first oxygen carries a -1 charge.
peroxide_charge = _build_query(
    atoms=[("O", {"charge": -1}), ("O", {})],
    bonds=[(1, 2, 1)],
)

# O-O order-1 bond with no charge specified on either oxygen.
peroxide = _build_query(
    atoms=[("O", {}), ("O", {})],
    bonds=[(1, 2, 1)],
)
def convert_df(df):
    """Serialise ``df`` to CSV text without the index column and return it as UTF-8 bytes."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')
def tanimoto_kernel(x, y):
    """
    Pairwise Tanimoto similarity between the rows of two feature arrays.

    "The Tanimoto coefficient is a measure of the similarity between two sets.
    It is defined as the size of the intersection divided by the size of the union of the sample sets."
    The Tanimoto coefficient is also known as the Jaccard index.

    For every pair of rows ``a`` (from ``x``) and ``b`` (from ``y``) the value
    is ``a.b / (a.a + b.b - a.b)``. Pairs for which this is 0/0 (both rows are
    all zeros) are reported as 0.

    Adopted from https://github.com/cimm-kzn/CIMtools/blob/master/CIMtools/metrics/pairwise.py

    :param x: 2D array of features, one row per sample.
    :param y: 2D array of features, one row per sample, same number of columns as ``x``.
    :return: 2D array with one row per row of ``x`` and one column per row of ``y``.
    """
    x_dot = np.dot(x, y.T)
    x2 = (x ** 2).sum(axis=1)
    y2 = (y ** 2).sum(axis=1)
    # Broadcasting a column of x2 against a row of y2 gives the same matrix as
    # replicating both vectors, without materialising the copies.
    denominator = x2[:, np.newaxis] + y2[np.newaxis, :] - x_dot
    # 0/0 is expected for all-zero rows and is mapped to 0 right below, so the
    # floating-point warning would only be noise.
    with np.errstate(divide="ignore", invalid="ignore"):
        result = x_dot / denominator
    result[np.isnan(result)] = 0
    return result
def fitness_func_batch(ga_instance, solutions, solutions_indices):
    """
    PyGAD fitness function scoring one solution or a batch of solutions.

    Reads the module-level globals ``rf_model``, ``X``, ``use_ordering_score``
    and ``ordering_model``, which are bound by the run section of the script
    before the GA is started.

    The fitness of each solution is
    ``0.5 * rf_score + 0.3 * dissimilarity + size_penalty`` plus, when
    ``use_ordering_score`` is set, ``0.2 * ordering_score``, where

    * ``rf_score`` is column 1 of ``rf_model.predict_proba``;
    * ``dissimilarity`` is 1 minus the highest Tanimoto similarity to any row
      of ``X``, with an extra -5 when that dissimilarity is exactly 0;
    * ``size_penalty`` is -1 when the genes of the solution sum to less than 18.

    :param ga_instance: the running ``pygad.GA`` instance (unused).
    :param solutions: a single solution (1D) or a batch of solutions (2D).
    :param solutions_indices: population indices of the solutions (unused).
    :return: a list with one fitness value per solution for 2D input; a single
        float for 1D input, which is what PyGAD expects when it evaluates
        solutions one at a time.
    """
    frag_counts = np.array(solutions)
    single_solution = frag_counts.ndim == 1
    if single_solution:
        frag_counts = frag_counts[np.newaxis, :]

    rf_score = rf_model.predict_proba(frag_counts)[:, 1]

    mol_size = frag_counts.sum(-1).astype(np.int64)
    size_penalty = np.where(mol_size < 18, -1.0, 0.)

    dissimilarity_score = 1 - tanimoto_kernel(frag_counts, X).max(-1)
    # A dissimilarity of exactly 0 means the solution coincides with a row of
    # X as far as the kernel is concerned; penalise it heavily.
    dissimilarity_score += np.where(dissimilarity_score == 0, -5, 0)

    fitness = 0.5 * rf_score + 0.3 * dissimilarity_score + size_penalty

    if use_ordering_score:
        frag_inds = frag_counts_to_inds(frag_counts, max_atoms=51)
        _, ordering_scores = restore_order(frag_inds, ordering_model)
        ordering_scores = np.array(ordering_scores)
        fitness += 0.2 * ordering_scores

    if single_solution:
        # A one-element list would not be a valid single-objective fitness in
        # PyGAD's non-batch mode; hand back the plain number instead.
        return float(fitness[0])
    return fitness.tolist()
def on_generation_progress(ga):
    """PyGAD ``on_generation`` callback: count one more finished generation and update the bar.

    Relies on the module-level ``ga_progress`` counter, ``ga_bar`` progress
    widget, ``ga_progress_text`` and ``num_generations``.

    :param ga: the running ``pygad.GA`` instance (unused).
    """
    global ga_progress
    ga_progress += 1
    fraction_done = ga_progress / num_generations
    ga_bar.progress(fraction_done, text=ga_progress_text)
# NOTE(review): Streamlit recommends `st.cache_resource` for model objects;
# `st.cache_data` is kept here so callers keep receiving their own copy of the
# cached values -- revisit if loading becomes a bottleneck.
@st.cache_data
def load_data(batch_size):
    """
    Load the training data and the three models the app needs from ``saved_model/``.

    :param batch_size: batch size passed to both checkpoint loaders (it is
        also part of the cache key, so a new value triggers a reload).
    :return: tuple ``(X, Y, rf_model, vqgae_model, ordering_model)`` where
        ``X`` and ``Y`` are the ``"x"`` and ``"y"`` arrays of the training
        ``.npz`` archive, ``rf_model`` is the unpickled classifier and the two
        networks are loaded on CPU and switched to eval mode.
    """
    # Open the archive once and close it deterministically; both arrays are
    # read into memory before the file handle is released.
    with np.load("saved_model/tubulin_qsar_class_train_data_vqgae.npz") as train_data:
        X = train_data["x"]
        Y = train_data["y"]

    # NOTE(review): unpickling executes arbitrary code from the file; this is
    # only acceptable because the model file is shipped with the app.
    with open("saved_model/rf_class_train_tubulin.pickle", "rb") as inp:
        rf_model = pickle.load(inp)

    vqgae_model = VQGAE.load_from_checkpoint(
        "saved_model/vqgae.ckpt",
        task="decode",
        batch_size=batch_size,
        map_location="cpu"
    )
    vqgae_model = vqgae_model.eval()

    ordering_model = OrderingNetwork.load_from_checkpoint(
        "saved_model/ordering_network.ckpt",
        batch_size=batch_size,
        map_location="cpu"
    )
    ordering_model = ordering_model.eval()
    return X, Y, rf_model, vqgae_model, ordering_model
st.title('Inverse QSAR of Tubulin with VQGAE')

# Sidebar form collecting the GA hyper-parameters. Because the widgets live in
# a form, their values only take effect when "Start optimisation" is pressed.
with st.sidebar:
    with st.form("ga_options"):
        num_generations = st.slider(
            'Number of generations for GA',
            min_value=3,
            max_value=40,
            value=5
        )

        # Human-readable label -> PyGAD `parent_selection_type` code. The
        # select box options are taken from the keys so the two cannot drift.
        parent_selection_translator = {
            "Steady-state selection": "sss",
            "Roulette wheel selection": "rws",
            "Stochastic universal selection": "sus",
            "Rank selection": "rank",
            "Random selection": "random",
            "Tournament selection": "tournament",
        }
        parent_selection_type = st.selectbox(
            label='Parent selection type',
            options=list(parent_selection_translator),
            index=1
        )
        parent_selection_type = parent_selection_translator[parent_selection_type]

        # Human-readable label -> PyGAD `crossover_type` code.
        crossover_translator = {
            "Single point": "single_point",
            "Two points": "two_points",
        }
        crossover_type = st.selectbox(
            label='Crossover type',
            options=list(crossover_translator),
            index=0
        )
        crossover_type = crossover_translator[crossover_type]

        # Number of solutions in the initial population; it has to match the
        # row count of the training matrix the run section asserts on
        # (X.shape == (603, 4096)).
        initial_population_size = 603
        # At least 1 %: zero mating parents is not a valid GA configuration.
        num_parents_mating = int(
            st.slider(
                'Percentage of parents mating taken from initial population',
                min_value=1,
                max_value=100,
                step=1,
                value=33,
            ) * initial_population_size // 100
        )
        keep_parents = int(
            st.slider(
                'Percentage of parents kept taken from number of parents mating',
                min_value=0,
                max_value=100,
                step=1,
                value=66
            ) * num_parents_mating // 100
        )
        use_ordering_score = st.toggle('Use ordering score', value=True)
        # The batch size is used as a slice step and as a model batch size, so
        # it must be a positive integer.
        batch_size = int(
            st.number_input(
                "Batch size",
                min_value=1,
                value=200,
                step=1,
                placeholder="Type a number..."
            )
        )
        # Negative seeds are rejected by NumPy's legacy seeding.
        random_seed = int(
            st.number_input(
                "Random seed",
                min_value=0,
                value=42,
                step=1,
                placeholder="Type a number..."
            )
        )
        submit = st.form_submit_button('Start optimisation')
if submit:
    # Everything below stays at module level on purpose: `fitness_func_batch`
    # reads `X`, `rf_model`, `ordering_model` and `use_ordering_score` as
    # globals, and `on_generation_progress` reads `ga_progress`, `ga_bar`,
    # `ga_progress_text` and `num_generations`.
    X, Y, rf_model, vqgae_model, ordering_model = load_data(batch_size)
    assert X.shape == (603, 4096)

    # ---- 1. Genetic optimisation, seeded with the training set ----
    ga_instance = pygad.GA(
        fitness_func=fitness_func_batch,
        on_generation=on_generation_progress,
        initial_population=X,
        num_genes=X.shape[-1],
        fitness_batch_size=batch_size,
        num_generations=num_generations,
        num_parents_mating=num_parents_mating,
        parent_selection_type=parent_selection_type,
        crossover_type=crossover_type,
        mutation_type="adaptive",
        mutation_percent_genes=[10, 5],
        save_best_solutions=False,
        save_solutions=True,
        keep_elitism=0,
        keep_parents=keep_parents,
        suppress_warnings=True,
        random_seed=random_seed,
        gene_type=int
    )

    ga_progress = 0
    ga_progress_text = "Genetic optimisation in progress. Please wait."
    ga_bar = st.progress(0, text=ga_progress_text)
    ga_instance.run()
    ga_bar.empty()

    with st.spinner('Getting unique solutions'):
        unique_solutions = list(set(tuple(s) for s in ga_instance.solutions))
    st.success(f'{len(unique_solutions)} solutions were obtained')

    # ---- 2. Rescore every unique solution, batch by batch ----
    scores = {
        "rf_score": [],
        "similarity_score": []
    }
    if use_ordering_score:
        scores["ordering_score"] = []

    rescoring_progress_text = "Rescoring obtained solutions"
    rescoring_bar = st.progress(0, text=rescoring_progress_text)
    total_solutions = len(unique_solutions)
    # Stepping by batch_size never yields an empty trailing batch, also when
    # the number of solutions is an exact multiple of the batch size.
    for batch_start in range(0, total_solutions, batch_size):
        vqgae_latents = unique_solutions[batch_start: batch_start + batch_size]
        frag_counts = np.array(vqgae_latents)
        rf_scores = rf_model.predict_proba(frag_counts)[:, 1]
        similarity_scores = tanimoto_kernel(frag_counts, X).max(-1)
        scores["rf_score"].extend(rf_scores.tolist())
        scores["similarity_score"].extend(similarity_scores.tolist())
        if use_ordering_score:
            frag_inds = frag_counts_to_inds(frag_counts, max_atoms=51)
            _, ordering_scores = restore_order(frag_inds, ordering_model)
            scores["ordering_score"].extend(ordering_scores)
        rescoring_bar.progress(
            (batch_start + len(vqgae_latents)) / total_solutions,
            text=rescoring_progress_text
        )
    sc_df = pd.DataFrame(scores)
    rescoring_bar.empty()

    # ---- 3. Keep solutions that are novel enough and predicted active ----
    selection_mask = (sc_df["similarity_score"] < 0.95) & (sc_df["rf_score"] > 0.5)
    if use_ordering_score:
        selection_mask = selection_mask & (sc_df["ordering_score"] > 0.7)
    chosen_gen = sc_df[selection_mask]
    chosen_ids = chosen_gen.index.to_list()
    if not chosen_ids:
        # Nothing to decode: stop here instead of feeding empty arrays to the
        # models below.
        st.warning('No solutions passed the score filters. Try different GA options.')
        st.stop()

    chosen_solutions = np.array([unique_solutions[ind] for ind in chosen_ids])
    gen_frag_inds = frag_counts_to_inds(chosen_solutions, max_atoms=51)
    st.info(f'The number of chosen solutions is {gen_frag_inds.shape[0]}', icon="ℹ️")

    # ---- 4. Decode the chosen solutions into molecules ----
    gen_molecules = []
    results = {"smiles": [], "ordering_score": [], "validity": []}
    decoding_progress_text = "Decoding chosen solutions"
    decoding_bar = st.progress(0, text=decoding_progress_text)
    total_chosen = gen_frag_inds.shape[0]
    for batch_start in range(0, total_chosen, batch_size):
        inputs = gen_frag_inds[batch_start: batch_start + batch_size]
        canon_order_inds, batch_ordering_scores = restore_order(
            frag_inds=inputs,
            ordering_model=ordering_model,
        )
        molecules, validity = decode_molecules(
            ordered_frag_inds=canon_order_inds,
            vqgae_model=vqgae_model
        )
        gen_molecules.extend(molecules)
        results["smiles"].extend([str(molecule) for molecule in molecules])
        results["ordering_score"].extend(batch_ordering_scores)
        results["validity"].extend([1 if i else 0 for i in validity])
        decoding_bar.progress(
            (batch_start + len(inputs)) / total_chosen,
            text=decoding_progress_text
        )
    gen_stats = pd.DataFrame(results)
    decoding_bar.empty()

    # Both frames are in the order of `chosen_ids`, so after resetting the
    # index of `chosen_gen` they can be joined column-wise by position.
    full_stats = pd.concat([gen_stats, chosen_gen.reset_index()[["similarity_score", "rf_score"]]], axis=1)
    full_stats = full_stats[["smiles", "similarity_score", "rf_score", "ordering_score", "validity"]]

    # ---- 5. Structural filtering of the valid molecules ----
    valid_gen_stats = full_stats[full_stats.validity == 1]
    # The index labels of `full_stats` are positions in `gen_molecules`.
    valid_gen_mols = [gen_molecules[i] for i in valid_gen_stats.index]

    filtered_gen_mols = []
    filtered_indices = []
    for mol_i, mol in enumerate(valid_gen_mols):
        # Discard molecules containing one of the unwanted query patterns ...
        has_unwanted_fragment = allene < mol or peroxide_charge < mol or peroxide < mol
        # ... or any ring with fewer than 4 or more than 8 atoms.
        has_unwanted_ring = any(len(ring) > 8 or len(ring) < 4 for ring in mol.sssr)
        if not has_unwanted_fragment and not has_unwanted_ring:
            filtered_gen_mols.append(mol)
            filtered_indices.append(mol_i)

    # `filtered_indices` are positions within `valid_gen_mols`, hence `iloc`.
    filtered_gen_stats = valid_gen_stats.iloc[filtered_indices]

    # ---- 6. Show results ----
    st.subheader('Generation results', divider='rainbow')
    st.dataframe(filtered_gen_stats)
    # `download_button` only returns markup; it has to be rendered explicitly.
    st.markdown(
        download_button(
            object_to_download=convert_df(filtered_gen_stats),
            download_filename='vqgae_tubulin_inhibitors_valid.csv',
            button_text="Download results as CSV"
        ),
        unsafe_allow_html=True
    )

    st.subheader('Examples of generated molecules')
    examples_smiles = filtered_gen_stats.sort_values(by=["similarity_score"], ascending=False).iloc[:6].smiles.to_list()
    examples = []
    for smi in examples_smiles:
        mol = smiles(smi)
        mol.clean2d()
        examples.append(mol)
    if examples:
        svg = grid_depict(examples, 2)
        render_svg(svg)
    else:
        st.info('No molecules passed the structural filters, so there are no examples to show.')

    with st.expander("Show full stats"):
        st.dataframe(full_stats)
        st.markdown(
            download_button(
                object_to_download=convert_df(full_stats),
                download_filename='vqgae_tubulin_inhibitors_full.csv',
                button_text="Download full results as CSV"
            ),
            unsafe_allow_html=True
        )

    if st.button("Restart"):
        st.rerun()