import gradio as gr
from huggingface_hub import HfApi, get_collection, list_collections
from utils import MolecularPropertyPredictionModel, task_types, dataset_descriptions
import pandas as pd
import os

# Hub collection listing the molecular property prediction models. Each model's
# repo name (last path segment) doubles as the property name shown in the UI.
PROPERTY_COLLECTION_ID = "ChemFM/molecular-property-prediction-6710141ffc31f31a47d6fc0c"

# Maximum number of SMILES accepted in one uploaded file.
MAX_FILE_SMILES = 100

# Suffix that replaces the uploaded file's extension to name the prediction file.
PREDICTION_SUFFIX = "_prediction.csv"

# Status message for each result code returned by ``model.swith_adapter``.
_ADAPTER_STATUS = {
    "keep": "Adapter is the same as the current one",
    "switched": "Adapter is switched successfully",
    "error": "Adapter is not found",
}


def get_models():
    """Return a mapping from property name to the Hub model id serving it.

    Every ``model`` item of the collection is registered under the last path
    segment of its repo id. That name must also be a key of both
    ``task_types`` and ``dataset_descriptions``.

    Raises:
        ValueError: if a model name is missing from ``task_types`` or
            ``dataset_descriptions``.
    """
    collection = get_collection(PROPERTY_COLLECTION_ID)
    models = {}
    for item in collection.items:
        if item.item_type != "model":
            continue
        item_name = item.item_id.split("/")[-1]
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if item_name not in task_types:
            raise ValueError(f"{item_name} is not in the task_types")
        if item_name not in dataset_descriptions:
            raise ValueError(f"{item_name} is not in the dataset_descriptions")
        models[item_name] = item.item_id
    return models


candidate_models = get_models()
properties = list(candidate_models.keys())
model = MolecularPropertyPredictionModel()


def get_description(property_name):
    """Return the description text for *property_name*."""
    return dataset_descriptions[property_name]


def _switch_adapter(property_name):
    """Activate the adapter for *property_name* and return ``(ok, status)``.

    ``ok`` is True when the model is ready to predict (result "keep" or
    "switched"). Otherwise ``status`` describes the failure.
    """
    adapter_id = candidate_models[property_name]
    # ``swith_adapter`` (sic) is the method name defined by the utils model class.
    info = model.swith_adapter(property_name, adapter_id)
    if info in ("keep", "switched"):
        return True, _ADAPTER_STATUS[info]
    return False, _ADAPTER_STATUS.get(info, "Unknown error")


def predict_single_label(smiles, property_name):
    """Predict *property_name* for a single SMILES string.

    Returns:
        ``(prediction, status)``. ``prediction`` is the model output, with floats
        rounded to 3 decimals, or ``"NA"`` on failure. ``status`` is the message
        for the "Running status" box.
    """
    # Drop surrounding whitespace from the textbox; an empty input is not a molecule.
    smiles = (smiles or "").strip()
    if not smiles:
        return "NA", "Invalid SMILES string"

    ok, status = _switch_adapter(property_name)
    if not ok:
        return "NA", status

    prediction = model.predict_single_smiles(smiles, task_types[property_name])
    if prediction is None:
        return "NA", "Invalid SMILES string"
    if isinstance(prediction, float):
        prediction = round(prediction, 3)
    return prediction, "Prediction is done"


def _prediction_path(file):
    """Return the CSV path that the predictions for *file* are written to.

    Only the extension is replaced, so directory or file names that happen to
    contain ".csv" / ".smi" are left intact. The input is never overwritten.
    """
    root, _ext = os.path.splitext(file)
    return root + PREDICTION_SUFFIX


def predict_file(file, property_name):
    """Run batch prediction on an uploaded CSV and expose the result for download.

    Returns the four outputs wired in ``build_inference``:
        1. the predict button update,
        2. the download button,
        3. the new value for the file input,
        4. the status message.
    """
    ok, status = _switch_adapter(property_name)
    if not ok:
        # gr.update() with no arguments leaves both buttons exactly as they are.
        return gr.update(), gr.update(), file, status

    df = pd.read_csv(file)
    # validate_file has already checked that the file contains a "smiles" column.
    df = model.predict_file(df, task_types[property_name])

    # Write the result next to the upload so the download button can serve it.
    prediction_file = _prediction_path(file)
    print(file, prediction_file)
    df.to_csv(prediction_file, index=False)
    return (
        gr.update(visible=False),
        gr.DownloadButton(label="Download", value=prediction_file, visible=True),
        prediction_file,
        "Prediction is done",
    )


def validate_file(file):
    """Check an uploaded file before prediction.

    Only ``.csv`` files with a ``smiles`` column and 1..MAX_FILE_SMILES rows are
    accepted. ``.smi`` uploads are allowed by the widget but rejected here.

    Returns:
        ``(status, file_value, predict_button_update, download_button_update)``.
        ``status`` is ``"Valid file"`` on success. On failure the file input is
        cleared and both buttons are hidden.
    """
    rejected = (None, gr.update(visible=False), gr.update(visible=False))
    try:
        if not file.endswith(".csv"):
            return ("Invalid file extension", *rejected)
        df = pd.read_csv(file)
    except Exception:
        # Unreadable, unparsable or missing upload.
        return ("Invalid file content.", *rejected)

    if "smiles" not in df.columns:
        return ("Invalid file content. The csv file must contain column named 'smiles'", *rejected)
    length = len(df["smiles"])
    if length == 0:
        return ("Invalid file content. The csv file does not contain any SMILES", *rejected)
    if length > MAX_FILE_SMILES:
        return (
            f"The space does not support the file containing more than {MAX_FILE_SMILES} SMILES",
            *rejected,
        )
    return "Valid file", file, gr.update(visible=True), gr.update(visible=False)


def raise_error(status):
    """Show a Gradio error popup when *status* is not ``"Valid file"``."""
    if status != "Valid file":
        raise gr.Error(status)
    return None


def _remove_if_exists(path):
    """Delete *path*, doing nothing if it is already gone."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def clear_file(download_button):
    """Delete the prediction file and its uploaded source, then reset the widgets.

    Args:
        download_button: value of the download button. This is the path written
            by ``predict_file``, or empty/None if no prediction was made.

    Returns:
        Updates hiding the predict and download buttons, and ``None`` for the
        file input.
    """
    prediction_path = download_button
    print(prediction_path)
    if prediction_path and os.path.exists(prediction_path):
        os.remove(prediction_path)
        # The upload shares the prediction file's stem; strip only the trailing suffix.
        if prediction_path.endswith(PREDICTION_SUFFIX):
            root = prediction_path[: -len(PREDICTION_SUFFIX)]
            for ext in (".csv", ".smi"):
                _remove_if_exists(root + ext)
    return gr.update(visible=False), gr.update(visible=False), None


def build_inference():
    """Build the Gradio UI for single-SMILES and file-based property prediction."""
    with gr.Blocks() as demo:
        # Property selector and its description.
        dropdown = gr.Dropdown(properties, label="Property", value=properties[0])
        description_box = gr.Textbox(
            label="Property description",
            lines=5,
            interactive=False,
            value=dataset_descriptions[properties[0]],
        )

        # SMILES input with its predict button, next to the prediction label.
        # NOTE(review): nesting reconstructed from a whitespace-collapsed source;
        # confirm the intended layout.
        with gr.Row(equal_height=True):
            with gr.Column():
                textbox = gr.Textbox(
                    label="Molecule SMILES",
                    type="text",
                    placeholder="Provide a SMILES string here",
                    lines=1,
                )
                predict_single_smiles_button = gr.Button("Predict", size='sm')
            prediction = gr.Label("Prediction will appear here")

        running_terminal_label = gr.Textbox(
            label="Running status",
            type="text",
            placeholder=None,
            lines=10,
            interactive=False,
        )

        input_file = gr.File(
            label="Molecule file",
            file_count='single',
            file_types=[".smi", ".csv"],
            height=300,
        )
        predict_file_button = gr.Button("Predict", size='sm', visible=False)
        download_button = gr.DownloadButton("Download", size='sm', visible=False)

        # Refresh the description when the property changes.
        dropdown.change(get_description, inputs=dropdown, outputs=description_box)

        # Single-SMILES prediction.
        predict_single_smiles_button.click(
            predict_single_label,
            inputs=[textbox, dropdown],
            outputs=[prediction, running_terminal_label],
        )

        # Validate uploads; a failed validation raises a popup via raise_error.
        file_status = gr.State()
        input_file.upload(
            fn=validate_file,
            inputs=input_file,
            outputs=[file_status, input_file, predict_file_button, download_button],
        ).success(raise_error, inputs=file_status, outputs=file_status)

        # Clearing the upload deletes generated files and hides the buttons.
        input_file.clear(
            fn=clear_file,
            inputs=[download_button],
            outputs=[predict_file_button, download_button, input_file],
        )

        # Batch prediction on the uploaded file.
        predict_file_button.click(
            predict_file,
            inputs=[input_file, dropdown],
            outputs=[predict_file_button, download_button, input_file, running_terminal_label],
        )

    return demo


demo = build_inference()

if __name__ == '__main__':
    demo.launch()