feiyang-cai's picture
add the readme
6af59a3
import gradio as gr
from huggingface_hub import HfApi, get_collection, list_collections
from utils import MolecularPropertyPredictionModel, dataset_task_types, dataset_descriptions, dataset_property_names, dataset_property_names_to_dataset
import pandas as pd
import os
import spaces
def get_models():
    """Return a mapping of model name -> Hub repo id for every model in the ChemFM collection.

    The collection is fetched with ``get_collection`` using the ``TOKEN``
    environment variable (``None`` if unset). Only items whose ``item_type`` is
    ``"model"`` are kept, keyed by the last path component of their repo id.

    Raises:
        ValueError: if a model name is missing from ``dataset_task_types``,
            ``dataset_descriptions`` or ``dataset_property_names``.
    """
    # this is the collection id for the molecular property prediction models
    collection = get_collection(
        "ChemFM/molecular-property-prediction-6710141ffc31f31a47d6fc0c",
        token=os.environ.get("TOKEN"),
    )
    models = {}
    for item in collection.items:
        if item.item_type != "model":
            continue
        item_name = item.item_id.split("/")[-1]
        # Explicit checks rather than `assert`: asserts are stripped under
        # `python -O`, which would let an unknown model through and fail later.
        if item_name not in dataset_task_types:
            raise ValueError(f"{item_name} is not in the task_types")
        if item_name not in dataset_descriptions:
            raise ValueError(f"{item_name} is not in the dataset_descriptions")
        # The dropdown labels are looked up from this table at module level.
        if item_name not in dataset_property_names:
            raise ValueError(f"{item_name} is not in the dataset_property_names")
        models[item_name] = item.item_id
    return models
# Discover the published adapters once, at import time.
candidate_models = get_models()
# Model/dataset identifiers in collection order.
property_names = list(candidate_models)
# Human-readable dropdown labels, index-aligned with `property_names`.
properties = [dataset_property_names[name] for name in property_names]
# One shared model instance; the prediction handlers switch adapters on it per request.
model = MolecularPropertyPredictionModel(candidate_models)
def get_description(property_name):
    """Return the description text for the dataset behind a dropdown label.

    Args:
        property_name: Human-readable property label selected in the dropdown.
    """
    return dataset_descriptions[dataset_property_names_to_dataset[property_name]]
@spaces.GPU(duration=10)
def predict_single_label(smiles, property_name):
    """Predict the selected property for a single SMILES string.

    Args:
        smiles: SMILES string from the textbox.
        property_name: Human-readable property label selected in the dropdown.

    Returns:
        ``(prediction, status)`` for the prediction label and the running-status
        box. ``prediction`` is ``"NA"`` whenever no prediction could be made;
        float predictions are rounded to 3 decimal places.
    """
    # Reject blank input before touching the model; surrounding whitespace
    # (e.g. from copy-paste) is stripped from the SMILES that is sent.
    smiles = (smiles or "").strip()
    if not smiles:
        return "NA", "Invalid SMILES string"
    try:
        # Inside the try so an unknown dropdown value yields a status message
        # instead of an unhandled exception in the UI.
        property_id = dataset_property_names_to_dataset[property_name]
        adapter_id = candidate_models[property_id]
        info = model.swith_adapter(property_id, adapter_id)
        # "keep" and "switched" both mean the requested adapter is active.
        if info == "error":
            return "NA", "Adapter is not found"
        if info not in ("keep", "switched"):
            return "NA", "Unknown error"
        prediction = model.predict_single_smiles(smiles, dataset_task_types[property_id])
        # The model signals an unparsable SMILES by returning None.
        if prediction is None:
            return "NA", "Invalid SMILES string"
        if isinstance(prediction, float):
            prediction = round(prediction, 3)
    except Exception as e:
        # no matter what the error is, report a status back to the UI
        print(e)
        return "NA", "Prediction failed"
    return prediction, "Prediction is done"
@spaces.GPU(duration=30)
def predict_file(file, property_name):
    """Predict the selected property for every row of an uploaded CSV file.

    The predictions are written next to the upload as ``<stem>_prediction.csv``
    so they can be downloaded.

    Args:
        file: Path of the uploaded file (already checked by ``validate_file``).
        property_name: Human-readable property label selected in the dropdown.

    Returns:
        A 5-tuple matching the event outputs
        ``[predict_file_button, download_button, stop_button, input_file,
        running_terminal_label]``.
    """

    def _failure(status):
        # Every error path must return all 5 outputs: keep the Predict button
        # visible, hide Download/Stop, leave the upload in place.
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), file, status

    try:
        property_id = dataset_property_names_to_dataset[property_name]
        adapter_id = candidate_models[property_id]
        info = model.swith_adapter(property_id, adapter_id)
        if info == "error":
            return _failure("Adapter is not found")
        if info not in ("keep", "switched"):
            return _failure("Unknown error")
        df = pd.read_csv(file)
        # validate_file has already checked that the file has a "smiles" column
        df = model.predict_file(df, dataset_task_types[property_id])
        # Replace only the extension: str.replace would also rewrite ".csv"/".smi"
        # elsewhere in the path, and for any other extension would overwrite the upload.
        stem, _ext = os.path.splitext(file)
        prediction_file = stem + "_prediction.csv"
        print(file, prediction_file)
        df.to_csv(prediction_file, index=False)
    except Exception as e:
        # no matter what the error is, report a status back to the UI
        print(e)
        return _failure("Prediction failed")
    return (
        gr.update(visible=False),
        gr.DownloadButton(label="Download", value=prediction_file, visible=True),
        gr.update(visible=False),
        prediction_file,
        "Prediction is done",
    )
def validate_file(file):
    """Validate an uploaded molecule file before enabling file prediction.

    A file is accepted when it ends with ``.csv``, parses with pandas, has a
    ``smiles`` column, and has between 1 and 100 rows. ``.smi`` uploads are
    offered by the file picker but are rejected here as an invalid extension.

    Args:
        file: Path of the uploaded file.

    Returns:
        ``(status, file_value, predict_button_update, download_button_update)``.
        ``status`` is ``"Valid file"`` on success. On failure ``file_value`` is
        ``None``, which clears the upload, and both buttons are hidden.
    """

    def _reject(message):
        return message, None, gr.update(visible=False), gr.update(visible=False)

    try:
        # Covers ".smi" and any other extension with the same message.
        if not file.endswith(".csv"):
            return _reject("Invalid file extension")
        df = pd.read_csv(file)
    except Exception:
        return _reject("Invalid file content.")
    if "smiles" not in df.columns:
        return _reject("Invalid file content. The csv file must contain column named 'smiles'")
    length = len(df["smiles"])
    if length == 0:
        # A header-only CSV would otherwise be accepted and sent off with nothing to predict.
        return _reject("The file does not contain any SMILES")
    if length > 100:
        return _reject("The space does not support the file containing more than 100 SMILES")
    return "Valid file", file, gr.update(visible=True), gr.update(visible=False)
def raise_error(status):
    """Turn a failed upload validation into a Gradio error popup.

    Args:
        status: Status string produced by ``validate_file``.

    Returns:
        ``None`` when ``status`` is ``"Valid file"``.

    Raises:
        gr.Error: carrying ``status`` for any other value.
    """
    if status == "Valid file":
        return None
    raise gr.Error(status)
def clear_file(download_button):
    """Delete the generated prediction file and its uploaded source when the upload is cleared.

    Args:
        download_button: Value of the download button. This is the path of the
            ``*_prediction.csv`` file written by ``predict_file``, or a falsy
            value if no prediction was made.

    Returns:
        Updates hiding the file Predict and Download buttons, and ``None`` to
        reset the file input.
    """

    def _remove_if_present(path):
        # EAFP instead of exists()+remove(): no race window, and a file that is
        # already gone is not an error.
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    suffix = "_prediction.csv"
    prediction_path = download_button
    print(prediction_path)
    if prediction_path:
        _remove_if_present(prediction_path)
        # predict_file names its output "<stem>_prediction.csv"; the upload was
        # "<stem>.csv" or "<stem>.smi". Only derive those names from a real suffix.
        if prediction_path.endswith(suffix):
            stem = prediction_path[: -len(suffix)]
            _remove_if_present(stem + ".csv")
            _remove_if_present(stem + ".smi")
    return gr.update(visible=False), gr.update(visible=False), None
def build_inference():
    """Build the Gradio UI for single-SMILES and file-based property prediction.

    Reads the module-level ``properties`` / ``property_names`` lists for the
    dropdown and wires UI events to ``get_description``,
    ``predict_single_label``, ``validate_file`` + ``raise_error``,
    ``clear_file`` and ``predict_file``.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    # NOTE(review): the original indentation of this function was lost; the
    # nesting of the Row/Column contents below is reconstructed from the
    # "third row" comment. Confirm against the intended layout.
    with gr.Blocks() as demo:
        # first row - Dropdown input
        #with gr.Row():
        # Build-time debug output: the first model name and the dropdown labels.
        print(property_names[0].lower())
        print(properties)
        gr.Markdown(f"<span style='color: red;'>If you run out of your GPU quota, you can use the </span> <a href='https://huggingface.co/spaces/ChemFM/molecular_property_prediction'>CPU-powered space</a> but with much lower performance.")
        # Dropdown of human-readable property labels, preselected to the first model's property.
        dropdown = gr.Dropdown(properties, label="Property", value=dataset_property_names[property_names[0].lower()])
        description_box = gr.Textbox(label="Property description", lines=5,
                                     interactive=False,
                                     value=dataset_descriptions[property_names[0].lower()])
        # third row - Textbox input and prediction label
        with gr.Row(equal_height=True):
            with gr.Column():
                textbox = gr.Textbox(label="Molecule SMILES", type="text", placeholder="Provide a SMILES string here",
                                     lines=1)
                predict_single_smiles_button = gr.Button("Predict", size='sm')
            # The positional argument is the label's initial value, shown until a prediction arrives.
            prediction = gr.Label("Prediction will appear here")
        running_terminal_label = gr.Textbox(label="Running status", type="text", placeholder=None, lines=10, interactive=False)
        # The picker offers .smi and .csv; validate_file only accepts .csv.
        input_file = gr.File(label="Molecule file",
                             file_count='single',
                             file_types=[".smi", ".csv"], height=300)
        # These three start hidden; upload validation and file prediction toggle their visibility.
        predict_file_button = gr.Button("Predict", size='sm', visible=False)
        download_button = gr.DownloadButton("Download", size='sm', visible=False)
        stop_button = gr.Button("Stop", size='sm', visible=False)
        # dropdown change event
        dropdown.change(get_description, inputs=dropdown, outputs=description_box)
        # predict single button click event
        # Chain: make all eight listed components non-interactive -> run the
        # prediction -> make them interactive again.
        predict_single_smiles_button.click(lambda:(gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   ) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
            .then(predict_single_label, inputs=[textbox, dropdown], outputs=[prediction, running_terminal_label])\
            .then(lambda:(gr.update(interactive=True),
                          gr.update(interactive=True),
                          gr.update(interactive=True),
                          gr.update(interactive=True),
                          gr.update(interactive=True),
                          gr.update(interactive=True),
                          gr.update(interactive=True),
                          gr.update(interactive=True),
                          ) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])
        # input file upload event
        # validate_file writes its status string into file_status; raise_error then
        # raises a gr.Error popup for any status other than "Valid file".
        file_status = gr.State()
        input_file.upload(fn=validate_file, inputs=input_file, outputs=[file_status, input_file, predict_file_button, download_button]).success(raise_error, inputs=file_status, outputs=file_status)
        # input file clear event
        # The download button's value (the prediction file path) tells clear_file what to delete.
        input_file.clear(fn=clear_file, inputs=[download_button], outputs=[predict_file_button, download_button, input_file])
        # predict file button click event
        # Chain: lock the controls (file Predict button kept visible, Stop hidden) ->
        # run predict_file -> make all eight components interactive again.
        predict_file_event = predict_file_button.click(lambda:(gr.update(interactive=False),
                                                               gr.update(interactive=False),
                                                               gr.update(interactive=False),
                                                               gr.update(interactive=False, visible=True),
                                                               gr.update(interactive=False),
                                                               gr.update(interactive=True, visible=False),
                                                               gr.update(interactive=False),
                                                               gr.update(interactive=False),
                                                               ) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
            .then(predict_file, inputs=[input_file, dropdown], outputs=[predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
            .then(lambda:(gr.update(interactive=True),
                          gr.update(interactive=True),
                          gr.update(interactive=True),
                          gr.update(interactive=True),
                          gr.update(interactive=True),
                          gr.update(interactive=True),
                          gr.update(interactive=True),
                          gr.update(interactive=True),
                          ) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])
        # stop button click event
        # NOTE(review): the cancel handler is commented out, so predict_file_event is
        # currently unused and the Stop button has no handler.
        #stop_button.click(fn=None, inputs=None, outputs=None, cancels=[predict_file_event])
    return demo
# Built at module level so `demo` is importable; the server only starts when run as a script.
demo = build_inference()

if __name__ == "__main__":
    demo.launch()