File size: 13,649 Bytes
50fe1a2
6bb1bdf
63859b8
6bb1bdf
 
ac9d211
50fe1a2
6bb1bdf
 
7b5450f
6bb1bdf
 
 
 
 
63859b8
6bb1bdf
 
 
50fe1a2
6bb1bdf
63859b8
 
e98af12
6bb1bdf
 
63859b8
 
6bb1bdf
5801d99
6bb1bdf
63859b8
 
ec780ac
63859b8
 
ec780ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bb1bdf
ec780ac
63859b8
ec780ac
 
6bb1bdf
ec780ac
 
 
 
 
 
 
6bb1bdf
 
 
5801d99
6bb1bdf
63859b8
ec780ac
63859b8
 
ec780ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bb1bdf
ec780ac
 
63859b8
ec780ac
 
 
 
 
 
 
 
 
b1e4cef
6bb1bdf
ec780ac
6bb1bdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63859b8
 
6af59a3
63859b8
6bb1bdf
 
63859b8
6bb1bdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec780ac
6bb1bdf
 
 
 
ec780ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bb1bdf
 
 
 
 
 
ec780ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bb1bdf
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import os
import traceback

import gradio as gr
import pandas as pd
import spaces
from huggingface_hub import HfApi, get_collection, list_collections

from utils import MolecularPropertyPredictionModel, dataset_task_types, dataset_descriptions, dataset_property_names, dataset_property_names_to_dataset

def get_models():
    """Collect the property-prediction models listed in the ChemFM Hub collection.

    Returns:
        dict mapping each model's short name (the last path component of its
        Hub id, which this app also uses as the dataset id) to its full Hub
        repository id.

    Raises:
        ValueError: if a model in the collection has no entry in the
            task-type, description, or property-name tables from ``utils``.
    """
    # this is the collection id for the molecular property prediction models
    collection = get_collection(
        "ChemFM/molecular-property-prediction-6710141ffc31f31a47d6fc0c",
        token=os.environ.get("TOKEN"),
    )
    # Every model must be described in all lookup tables the UI relies on.
    required_tables = (
        ("task_types", dataset_task_types),
        ("dataset_descriptions", dataset_descriptions),
        ("dataset_property_names", dataset_property_names),
    )
    models = {}
    for item in collection.items:
        if item.item_type != "model":
            continue
        item_name = item.item_id.split("/")[-1]
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        for table_name, table in required_tables:
            if item_name not in table:
                raise ValueError(f"{item_name} is not in the {table_name}")
        models[item_name] = item.item_id

    return models

# Dataset id -> Hub model id for every model in the collection.
candidate_models = get_models()
# Dataset ids, in collection order.
property_names = list(candidate_models)
# Human-readable property names for the dropdown, aligned with property_names.
properties = [dataset_property_names[dataset_id] for dataset_id in property_names]
model = MolecularPropertyPredictionModel(candidate_models)

def get_description(property_name):
    """Return the description text for a human-readable property name.

    The name is first mapped to its dataset id, which keys the description table.
    """
    return dataset_descriptions[dataset_property_names_to_dataset[property_name]]

@spaces.GPU(duration=10)
def predict_single_label(smiles, property_name):
    """Predict one property for a single SMILES string.

    Args:
        smiles: SMILES string entered by the user.
        property_name: Human-readable property name selected in the dropdown;
            mapped to a dataset id via ``dataset_property_names_to_dataset``.

    Returns:
        A ``(prediction, status)`` pair. ``prediction`` is the model output
        (floats rounded to 3 decimal places) or ``"NA"`` on any failure;
        ``status`` is the message shown in the running-status box.
    """
    if not smiles or not smiles.strip():
        # An empty string would otherwise be sent to the model as-is.
        return "NA", "Invalid SMILES string"

    try:
        # Inside the try so a missing/unknown dropdown value is reported, not raised.
        property_id = dataset_property_names_to_dataset[property_name]
        adapter_id = candidate_models[property_id]
        # `swith_adapter` (sic) is the method name defined on the model class in utils.
        info = model.swith_adapter(property_id, adapter_id)

        if info == "error":
            return "NA", "Adapter is not found"
        if info not in ("keep", "switched"):
            return "NA", "Unknown error"

        prediction = model.predict_single_smiles(smiles, dataset_task_types[property_id])
        if prediction is None:
            return "NA", "Invalid SMILES string"

        # if the prediction is a float, round it to 3 decimal places
        if isinstance(prediction, float):
            prediction = round(prediction, 3)
    except Exception:
        # UI boundary: log the full traceback and report failure instead of
        # letting the Gradio event error out.
        traceback.print_exc()
        return "NA", "Prediction failed"

    return prediction, "Prediction is done"

@spaces.GPU(duration=30)
def predict_file(file, property_name):
    """Run batch prediction on an uploaded CSV and expose the result for download.

    Args:
        file: Path of the uploaded file (already checked by ``validate_file``).
        property_name: Human-readable property name selected in the dropdown.

    Returns:
        Five values for the outputs ``[predict_file_button, download_button,
        stop_button, input_file, running_terminal_label]``. Every path,
        success or failure, returns exactly five values.
    """
    def _failed(status):
        # Re-show Predict, hide Download/Stop, keep the uploaded file, report status.
        return (gr.update(visible=True), gr.update(visible=False),
                gr.update(visible=False), file, status)

    try:
        property_id = dataset_property_names_to_dataset[property_name]
        adapter_id = candidate_models[property_id]
        # `swith_adapter` (sic) is the method name defined on the model class in utils.
        info = model.swith_adapter(property_id, adapter_id)
        if info == "error":
            return _failed("Adapter is not found")
        if info not in ("keep", "switched"):
            return _failed("Unknown error")

        df = pd.read_csv(file)
        # validate_file has already checked the file contains the "smiles" column
        df = model.predict_file(df, dataset_task_types[property_id])
        # Save next to the upload as "<stem>_prediction.csv". splitext only touches
        # the extension; str.replace would rewrite every ".csv" in the path and
        # would return the input path unchanged (overwriting it) for other extensions.
        prediction_file = os.path.splitext(file)[0] + "_prediction.csv"
        df.to_csv(prediction_file, index=False)
    except Exception:
        # UI boundary: log the full traceback and report failure.
        traceback.print_exc()
        return _failed("Prediction failed")

    return (gr.update(visible=False),
            gr.DownloadButton(label="Download", value=prediction_file, visible=True),
            gr.update(visible=False),
            prediction_file,
            "Prediction is done")

def validate_file(file):
    """Check that an uploaded molecule file can be processed by this Space.

    Only ``.csv`` files with a ``smiles`` column and at most 100 rows pass;
    ``.smi`` and every other extension are rejected.

    Args:
        file: Path of the uploaded file.

    Returns:
        A 4-tuple for ``[file_status, input_file, predict_file_button,
        download_button]``: a status message, the file path (``None`` clears
        the upload on rejection), and visibility updates for the Predict and
        Download buttons.
    """
    def reject(message):
        # Clear the upload and keep both action buttons hidden.
        return message, None, gr.update(visible=False), gr.update(visible=False)

    try:
        if not file.endswith(".csv"):
            # Includes ".smi" uploads, which are not supported here.
            return reject("Invalid file extension")
        frame = pd.read_csv(file)
        if "smiles" not in frame.columns:
            return reject("Invalid file content. The csv file must contain column named 'smiles'")
        n_rows = len(frame["smiles"])
    except Exception:
        return reject("Invalid file content.")

    if n_rows > 100:
        return reject("The space does not support the file containing more than 100 SMILES")

    return "Valid file", file, gr.update(visible=True), gr.update(visible=False)
    

def raise_error(status):
    """Surface a failed upload validation as a Gradio error popup.

    Args:
        status: Message produced by ``validate_file``.

    Returns:
        ``None`` (written back into the ``file_status`` state) when the file is valid.

    Raises:
        gr.Error: carrying ``status`` for any value other than ``"Valid file"``.
    """
    if status == "Valid file":
        return None
    raise gr.Error(status)


def clear_file(download_button):
    """Delete the prediction file and its source upload when the file input is cleared.

    Args:
        download_button: Current value of the download button, i.e. the path of
            the file written by ``predict_file``, or a falsy value if nothing
            was predicted.

    Returns:
        Updates hiding the Predict and Download buttons, and ``None`` to reset
        the file input.
    """
    prediction_path = download_button
    suffix = "_prediction.csv"
    if prediction_path:
        try:
            os.remove(prediction_path)
        except FileNotFoundError:
            # Nothing (left) to clean up; the source upload is left untouched.
            pass
        else:
            # predict_file names its output "<stem>_prediction.csv" after an upload
            # "<stem>.csv" or "<stem>.smi"; strip only the trailing suffix.
            if prediction_path.endswith(suffix):
                stem = prediction_path[:-len(suffix)]
                for original in (stem + ".csv", stem + ".smi"):
                    try:
                        os.remove(original)
                    except FileNotFoundError:
                        pass

    return gr.update(visible=False), gr.update(visible=False), None

def build_inference():
    """Build the Gradio UI for single-SMILES and batch-file property prediction.

    The UI contains:
    - a property dropdown with its description;
    - a SMILES textbox with a Predict button and a result label;
    - a read-only running-status box;
    - a file upload with Predict / Download / Stop buttons, all hidden initially.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """
    # The first model in the collection is the initially selected property.
    default_dataset = property_names[0].lower()

    with gr.Blocks() as demo:
        gr.Markdown("<span style='color: red;'>If you run out of your GPU quota, you can use the </span> <a href='https://huggingface.co/spaces/ChemFM/molecular_property_prediction'>CPU-powered space</a> but with much lower performance.")
        dropdown = gr.Dropdown(properties, label="Property", value=dataset_property_names[default_dataset])
        description_box = gr.Textbox(label="Property description", lines=5,
                                     interactive=False,
                                     value=dataset_descriptions[default_dataset])
        # SMILES input and Predict button on the left, prediction label on the right
        with gr.Row(equal_height=True):
            with gr.Column():
                textbox = gr.Textbox(label="Molecule SMILES", type="text", placeholder="Provide a SMILES string here",
                                     lines=1)
                predict_single_smiles_button = gr.Button("Predict", size='sm')
            prediction = gr.Label("Prediction will appear here")

        running_terminal_label = gr.Textbox(label="Running status", type="text", placeholder=None, lines=10, interactive=False)

        input_file = gr.File(label="Molecule file",
                             file_count='single',
                             file_types=[".smi", ".csv"], height=300)
        predict_file_button = gr.Button("Predict", size='sm', visible=False)
        download_button = gr.DownloadButton("Download", size='sm', visible=False)
        stop_button = gr.Button("Stop", size='sm', visible=False)

        # Components whose interactivity is toggled around every prediction run.
        # The running-status box is last and must stay read-only throughout.
        controls = [dropdown, textbox, predict_single_smiles_button, predict_file_button,
                    download_button, stop_button, input_file, running_terminal_label]

        def lock_controls():
            """Disable every control while a single-SMILES prediction runs."""
            return tuple(gr.update(interactive=False) for _ in controls)

        def lock_controls_for_file():
            """Disable controls for a file prediction; keep Predict shown and Stop hidden."""
            return (gr.update(interactive=False),
                    gr.update(interactive=False),
                    gr.update(interactive=False),
                    gr.update(interactive=False, visible=True),  # predict_file_button
                    gr.update(interactive=False),
                    gr.update(interactive=True, visible=False),  # stop_button
                    gr.update(interactive=False),
                    gr.update(interactive=False))

        def unlock_controls():
            """Re-enable the controls after a run, leaving the status box read-only."""
            return (*(gr.update(interactive=True) for _ in controls[:-1]),
                    gr.update(interactive=False))

        # dropdown change event
        dropdown.change(get_description, inputs=dropdown, outputs=description_box)

        # predict single button click event: lock -> predict -> unlock
        (predict_single_smiles_button
            .click(lock_controls, outputs=controls)
            .then(predict_single_label, inputs=[textbox, dropdown],
                  outputs=[prediction, running_terminal_label])
            .then(unlock_controls, outputs=controls))

        # input file upload event: validate, then pop up an error if invalid
        file_status = gr.State()
        (input_file
            .upload(fn=validate_file, inputs=input_file,
                    outputs=[file_status, input_file, predict_file_button, download_button])
            .success(raise_error, inputs=file_status, outputs=file_status))
        # input file clear event
        input_file.clear(fn=clear_file, inputs=[download_button],
                         outputs=[predict_file_button, download_button, input_file])

        # predict file button click event: lock -> predict -> unlock
        predict_file_event = (predict_file_button
            .click(lock_controls_for_file, outputs=controls)
            .then(predict_file, inputs=[input_file, dropdown],
                  outputs=[predict_file_button, download_button, stop_button, input_file, running_terminal_label])
            .then(unlock_controls, outputs=controls))
        # NOTE: cancellation is not wired up yet (the Stop button is never shown).
        # To enable it: stop_button.click(fn=None, inputs=None, outputs=None, cancels=[predict_file_event])

    return demo


# Construct the app at import time so `demo` is available as a module attribute.
demo = build_inference()

if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()