File size: 13,179 Bytes
edaff0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f4c1fe
 
 
 
 
 
 
 
 
edaff0a
 
 
 
 
 
 
 
 
 
eefdfab
edaff0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed5d690
0e7c02e
 
edaff0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
import gradio as gr
from huggingface_hub import HfApi, get_collection, list_collections, list_models
#from utils import MolecularPropertyPredictionModel, dataset_task_types, dataset_descriptions, dataset_property_names, dataset_property_names_to_dataset
from utils import ReactionPredictionModel
import pandas as pd
import os
import spaces

def get_models():
    """Return the supported task keys mapped to their Hugging Face model repo ids.

    Only two fine-tuned ChemFM models are served by this Space:
      * ``mit_synthesis`` -> ``ChemFM/uspto_mit_synthesis`` (reaction synthesis)
      * ``full_retro``    -> ``ChemFM/uspto_full_retro`` (retro synthesis)

    A fresh dict is built on every call, in the order listed above.
    """
    return {
        'mit_synthesis': 'ChemFM/uspto_mit_synthesis',
        'full_retro': 'ChemFM/uspto_full_retro',
    }

# Task key -> model repo id (see get_models); its key order drives the UI defaults.
candidate_models = get_models()

# Human-readable label for each task key, shown in the "Task" dropdown.
task_names = {
    'mit_synthesis': 'Reaction Synthesis',
    'full_retro': 'Reaction Retro Synthesis'
}
# Reverse lookup: dropdown label -> task key.
task_names_to_tasks = dict(zip(task_names.values(), task_names.keys()))
tasks = list(candidate_models)

# Help text shown in the read-only "Task description" box for each task key.
task_descriptions = {
    'mit_synthesis': (
        'Predict the reaction products given the reactants and reagents. \n'
        '1. This model is trained on the USPTO MIT dataset. \n'
        '2. The reactants and reagents are mixed in the input SMILES string. \n'
        '3. Different compounds are separated by ".". \n'
        '4. Input SMILES string example: C1CCOC1.N#Cc1ccsc1N.O=[N+]([O-])c1cc(F)c(F)cc1F.[H-].[Na+]'
    ),
    'full_retro': (
        'Predict the reaction precursors given the reaction products. \n'
        '1. This model is trained on the USPTO Full dataset. \n'
        '2. In this dataset, we consider only a single product in the input SMILES string. \n'
        '3. Input SMILES string example: CC(=O)OCC(=O)[C@@]1(O)CC[C@H]2[C@@H]3CCC4=CC(=O)CC[C@]4(C)C3=CC[C@@]21C'
    ),
}

# Shared predictor, created once at import time from the task -> repo mapping.
model = ReactionPredictionModel(candidate_models)

def get_description(task_name):
    """Return the help text for the task whose dropdown label is *task_name*.

    Raises:
        KeyError: if *task_name* is not one of the labels in ``task_names``.
    """
    return task_descriptions[task_names_to_tasks[task_name]]

@spaces.GPU(duration=60)
def predict_single_label(smiles, task_name):
    """Run the selected reaction model on one SMILES string.

    Args:
        smiles: SMILES text from the input box. Leading and trailing
            whitespace is ignored.
        task_name: Human-readable task label from the dropdown (a value of
            ``task_names``).

    Returns:
        A ``(prediction_text, status)`` tuple for the "Predictions" and
        "Running status" boxes. ``prediction_text`` is a numbered list with one
        candidate per line, or ``"NA"`` when nothing could be predicted.
    """
    smiles = (smiles or "").strip()
    if not smiles:
        # Nothing to predict; skip the model call entirely.
        return "NA", "Invalid SMILES string"

    try:
        # Task lookup and formatting live inside the try so that any failure
        # (unknown task, model error, unexpected result type) is reported in
        # the status box instead of breaking the UI event chain.
        task = task_names_to_tasks[task_name]
        prediction = model.predict_single_smiles(smiles, task)
        if prediction is None:
            return "NA", "Invalid SMILES string"
        formatted = "\n".join(f"{idx}. {item}" for idx, item in enumerate(prediction, start=1))
    except Exception as e:
        # UI boundary: never raise, always hand both outputs back to Gradio.
        print(f"Prediction failed for task {task_name!r}: {e!r}")
        return "NA", "Prediction failed"

    return formatted, "Prediction is done"

# NOTE: The triple-quoted string below is an inert module-level string
# expression, not executable code. It preserves a disabled batch-file
# prediction path (`predict_file`) and its upload validator (`validate_file`).
# NOTE(review): before re-enabling, check the following:
#   * `dataset_property_names_to_dataset` and `dataset_task_types` are not
#     imported here; their `from utils import ...` line is commented out at
#     the top of the file.
#   * `model.swith_adapter` looks like a typo for `switch_adapter`; confirm
#     against ReactionPredictionModel in utils.
#   * The "error"/"Unknown error" early returns in `predict_file` yield 4
#     values, but the disabled wiring in `build_inference` declares 5 outputs.
#   * `validate_file` rejects `.smi` uploads even though `predict_file` has
#     `.smi` handling for the output filename.
"""
@spaces.GPU(duration=30)
def predict_file(file, property_name):
    property_id = dataset_property_names_to_dataset[property_name]
    try:
        adapter_id = candidate_models[property_id]
        info = model.swith_adapter(property_id, adapter_id)

        running_status = None
        if info == "keep":
            running_status = "Adapter is the same as the current one"
            #print("Adapter is the same as the current one")
        elif info == "switched":
            running_status = "Adapter is switched successfully"
            #print("Adapter is switched successfully")
        elif info == "error":
            running_status = "Adapter is not found"
            #print("Adapter is not found")
            return None, None, file, running_status
        else:
            running_status = "Unknown error"
            return None, None, file, running_status
    
        df = pd.read_csv(file)
        # we have already checked the file contains the "smiles" column
        df = model.predict_file(df, dataset_task_types[property_id])
        # we should save this file to the disk to be downloaded
        # rename the file to have "_prediction" suffix
        prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
        print(file, prediction_file)
        # save the file to the disk
        df.to_csv(prediction_file, index=False)
    except Exception as e:
        # no matter what the error is, we should return
        print(e)
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), file, "Prediction failed"
    
    return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), gr.update(visible=False), prediction_file, "Prediction is done"

def validate_file(file):
    try:
        if file.endswith(".csv"):
            df = pd.read_csv(file)
            if "smiles" not in df.columns:
                # we should clear the file input
                return "Invalid file content. The csv file must contain column named 'smiles'", \
                         None, gr.update(visible=False), gr.update(visible=False)
            
            # check the length of the smiles
            length = len(df["smiles"])

        elif file.endswith(".smi"):
            return "Invalid file extension", \
                    None, gr.update(visible=False), gr.update(visible=False)

        else:
            return "Invalid file extension", \
                    None, gr.update(visible=False), gr.update(visible=False)
    except Exception as e:
        return "Invalid file content.", \
                None, gr.update(visible=False), gr.update(visible=False)
    
    if length > 100: 
        return "The space does not support the file containing more than 100 SMILES", \
                None, gr.update(visible=False), gr.update(visible=False)

    return "Valid file", file, gr.update(visible=True), gr.update(visible=False)
"""
    

def raise_error(status):
    """Surface a file-validation failure as a Gradio error popup.

    Args:
        status: Status string produced by the upload validator.

    Returns:
        None when *status* is exactly ``"Valid file"``.

    Raises:
        gr.Error: carrying *status* as its message, for any other status.
    """
    if status == "Valid file":
        return None
    raise gr.Error(status)


# NOTE: The triple-quoted string below is an inert module-level string
# expression preserving a disabled `clear_file` handler for the file-upload
# flow. If re-enabled, it would:
#   * delete the prediction file at the path it is given;
#   * delete the `.csv` / `.smi` upload whose name is derived from that path;
#   * return updates hiding two buttons and clearing the file input.
"""
def clear_file(download_button):
    # we might need to delete the prediction file and uploaded file
    prediction_path = download_button
    print(prediction_path)
    if prediction_path and os.path.exists(prediction_path):
        os.remove(prediction_path)
        original_data_file_0 = prediction_path.replace("_prediction.csv", ".csv")
        original_data_file_1 = prediction_path.replace("_prediction.csv", ".smi")
        if os.path.exists(original_data_file_0):
            os.remove(original_data_file_0)
        if os.path.exists(original_data_file_1):
            os.remove(original_data_file_1)
    #if os.path.exists(file):
    #    os.remove(file)
    #prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
    #if os.path.exists(prediction_file):
    #    os.remove(prediction_file)
    

    return gr.update(visible=False), gr.update(visible=False), None
"""

def build_inference():
    """Build the Gradio Blocks UI for single-SMILES reaction prediction.

    The page contains:
      * a task dropdown and a read-only task description;
      * a SMILES input box and a Predict button;
      * read-only "Predictions" and "Running status" boxes.

    Clicking Predict disables the dropdown, input box and button, runs
    ``predict_single_label``, then re-enables those three controls.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        # Task selector; defaults to the first configured task.
        dropdown = gr.Dropdown([task_names[key] for key in tasks], label="Task", value=task_names[tasks[0]])
        description_box = gr.Textbox(label="Task description", lines=5,
                                     interactive=False,
                                     value=task_descriptions[tasks[0]])
        textbox = gr.Textbox(label="Reactants (Products) SMILES string", type="text",
                             placeholder="Provide a SMILES string here", lines=1)
        predict_single_smiles_button = gr.Button("Predict", size='sm')
        prediction = gr.Textbox(label="Predictions", type="text", placeholder=None, lines=10, interactive=False)

        running_terminal_label = gr.Textbox(label="Running status", type="text", placeholder=None, lines=10, interactive=False)

        # Disabled file-upload components (see the inert wiring string below).
        #input_file = gr.File(label="Molecule file",
        #               file_count='single',
        #               file_types=[".smi", ".csv"], height=300)
        #predict_file_button = gr.Button("Predict", size='sm', visible=False)
        #download_button = gr.DownloadButton("Download", size='sm', visible=False)
        #stop_button = gr.Button("Stop", size='sm', visible=False)

        # Controls locked while a prediction runs. The status box is left out on
        # purpose: it is created read-only, and re-enabling it after a run
        # would make it user-editable.
        busy_controls = [dropdown, textbox, predict_single_smiles_button]

        def _lock_controls():
            # One update per entry in `busy_controls`, in the same order.
            return tuple(gr.update(interactive=False) for _ in busy_controls)

        def _unlock_controls():
            # One update per entry in `busy_controls`, in the same order.
            return tuple(gr.update(interactive=True) for _ in busy_controls)

        # Keep the description in sync with the selected task.
        dropdown.change(get_description, inputs=dropdown, outputs=description_box)
        # Lock the controls, run the prediction, then unlock the controls.
        predict_single_smiles_button.click(_lock_controls, outputs=busy_controls) \
            .then(predict_single_label, inputs=[textbox, dropdown], outputs=[prediction, running_terminal_label]) \
            .then(_unlock_controls, outputs=busy_controls)

        # NOTE: The string below is inert; it preserves the disabled file-upload
        # wiring. NOTE(review): its final step sets running_terminal_label to
        # interactive=True, the same issue avoided above. Fix this before
        # re-enabling.
        """
        # input file upload event
        file_status = gr.State()
        input_file.upload(fn=validate_file, inputs=input_file, outputs=[file_status, input_file, predict_file_button, download_button]).success(raise_error, inputs=file_status, outputs=file_status)
        # input file clear event
        input_file.clear(fn=clear_file, inputs=[download_button], outputs=[predict_file_button, download_button, input_file])
        # predict file button click event
        predict_file_event = predict_file_button.click(lambda:(gr.update(interactive=False), 
                                                               gr.update(interactive=False),
                                                               gr.update(interactive=False),
                                                               gr.update(interactive=False, visible=True),
                                                               gr.update(interactive=False),
                                                               gr.update(interactive=True, visible=False),
                                                               gr.update(interactive=False),
                                                               gr.update(interactive=False),
                                                               ) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
                                                               .then(predict_file, inputs=[input_file, dropdown], outputs=[predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
                                                               .then(lambda:(gr.update(interactive=True), 
                                                                             gr.update(interactive=True),
                                                                             gr.update(interactive=True),
                                                                             gr.update(interactive=True),
                                                                             gr.update(interactive=True),
                                                                             gr.update(interactive=True),
                                                                             gr.update(interactive=True),
                                                                             gr.update(interactive=True),
                                                                             ) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])
        
        # stop button click event
        #stop_button.click(fn=None, inputs=None, outputs=None, cancels=[predict_file_event])
        """

    return demo


# Build the Blocks app once at import time so `demo` exists as a module-level
# attribute; start the server only when this file is run as a script.
demo = build_inference()

if __name__ == "__main__":
    demo.launch()