Spaces: Running on Zero

feiyang-cai committed
Commit · 6bb1bdf
1 Parent(s): 50fe1a2

finish the basic function

Browse files
- app.py +185 -4
- dataset_descriptions.json +112 -0
- utils.py +286 -0
app.py CHANGED
@@ -1,7 +1,188 @@

import gradio as gr
from huggingface_hub import HfApi, get_collection, list_collections
from utils import MolecularPropertyPredictionModel, task_types, dataset_descriptions
import pandas as pd
import os

def get_models():
    # this is the collection id for the molecular property prediction models
    collection = get_collection("ChemFM/molecular-property-prediction-6710141ffc31f31a47d6fc0c")
    models = dict()
    for item in collection.items:
        if item.item_type == "model":
            item_name = item.item_id.split("/")[-1]
            models[item_name] = item.item_id
            assert item_name in task_types, f"{item_name} is not in the task_types"
            assert item_name in dataset_descriptions, f"{item_name} is not in the dataset_descriptions"

    return models

candidate_models = get_models()
properties = list(candidate_models.keys())
model = MolecularPropertyPredictionModel()

def get_description(property_name):
    return dataset_descriptions[property_name]

def predict_single_label(smiles, property_name):
    adapter_id = candidate_models[property_name]
    info = model.switch_adapter(property_name, adapter_id)

    if info == "keep":
        running_status = "Adapter is the same as the current one"
    elif info == "switched":
        running_status = "Adapter switched successfully"
    elif info == "error":
        return "NA", "Adapter not found"
    else:
        return "NA", "Unknown error"

    prediction = model.predict_single_smiles(smiles, task_types[property_name])
    if prediction is None:
        return "NA", "Invalid SMILES string"

    # if the prediction is a float, round it to 3 decimal places
    if isinstance(prediction, float):
        prediction = round(prediction, 3)

    return prediction, "Prediction is done"

def predict_file(file, property_name):
    adapter_id = candidate_models[property_name]
    info = model.switch_adapter(property_name, adapter_id)

    if info == "keep":
        running_status = "Adapter is the same as the current one"
    elif info == "switched":
        running_status = "Adapter switched successfully"
    elif info == "error":
        return None, None, file, "Adapter not found"
    else:
        return None, None, file, "Unknown error"

    df = pd.read_csv(file)
    # the upload handler has already checked that the file contains a "smiles" column
    df = model.predict_file(df, task_types[property_name])
    # save the predictions to disk with a "_prediction" suffix so they can be downloaded
    prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
    df.to_csv(prediction_file, index=False)

    return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), prediction_file, "Prediction is done"

def validate_file(file):
    try:
        if file.endswith(".csv"):
            df = pd.read_csv(file)
            if "smiles" not in df.columns:
                # clear the file input
                return "Invalid file content. The csv file must contain a column named 'smiles'", \
                       None, gr.update(visible=False), gr.update(visible=False)

            # check the number of SMILES
            length = len(df["smiles"])

        elif file.endswith(".smi"):
            return "Invalid file extension", \
                   None, gr.update(visible=False), gr.update(visible=False)

        else:
            return "Invalid file extension", \
                   None, gr.update(visible=False), gr.update(visible=False)
    except Exception as e:
        return "Invalid file content.", \
               None, gr.update(visible=False), gr.update(visible=False)

    if length > 100:
        return "This space does not support files containing more than 100 SMILES", \
               None, gr.update(visible=False), gr.update(visible=False)

    return "Valid file", file, gr.update(visible=True), gr.update(visible=False)


def raise_error(status):
    if status != "Valid file":
        raise gr.Error(status)
    return None


def clear_file(download_button):
    # delete the prediction file and the uploaded file, if they exist
    prediction_path = download_button
    if prediction_path and os.path.exists(prediction_path):
        os.remove(prediction_path)
        original_data_file_0 = prediction_path.replace("_prediction.csv", ".csv")
        original_data_file_1 = prediction_path.replace("_prediction.csv", ".smi")
        if os.path.exists(original_data_file_0):
            os.remove(original_data_file_0)
        if os.path.exists(original_data_file_1):
            os.remove(original_data_file_1)

    return gr.update(visible=False), gr.update(visible=False), None

def build_inference():

    with gr.Blocks() as demo:
        # property selector and its description
        dropdown = gr.Dropdown(properties, label="Property", value=properties[0])
        description_box = gr.Textbox(label="Property description", lines=5,
                                     interactive=False,
                                     value=dataset_descriptions[properties[0]])
        # single-SMILES input and prediction label
        with gr.Row(equal_height=True):
            with gr.Column():
                textbox = gr.Textbox(label="Molecule SMILES", type="text", placeholder="Provide a SMILES string here",
                                     lines=1)
                predict_single_smiles_button = gr.Button("Predict", size='sm')
                prediction = gr.Label("Prediction will appear here")

            running_terminal_label = gr.Textbox(label="Running status", type="text", placeholder=None, lines=10, interactive=False)

        input_file = gr.File(label="Molecule file",
                             file_count='single',
                             file_types=[".smi", ".csv"], height=300)
        predict_file_button = gr.Button("Predict", size='sm', visible=False)
        download_button = gr.DownloadButton("Download", size='sm', visible=False)

        # dropdown change event
        dropdown.change(get_description, inputs=dropdown, outputs=description_box)
        # single-SMILES predict button click event
        predict_single_smiles_button.click(predict_single_label, inputs=[textbox, dropdown], outputs=[prediction, running_terminal_label])
        # input file upload event
        file_status = gr.State()
        input_file.upload(fn=validate_file, inputs=input_file, outputs=[file_status, input_file, predict_file_button, download_button]).success(raise_error, inputs=file_status, outputs=file_status)
        # input file clear event
        input_file.clear(fn=clear_file, inputs=[download_button], outputs=[predict_file_button, download_button, input_file])
        # file predict button click event
        predict_file_button.click(predict_file, inputs=[input_file, dropdown], outputs=[predict_file_button, download_button, input_file, running_terminal_label])

    return demo


demo = build_inference()

if __name__ == '__main__':
    demo.launch()
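For reference, a minimal sketch (not part of the commit) of a batch input that passes validate_file above. The file name and molecules are illustrative; the only requirements are a "smiles" column and at most 100 rows.

import pandas as pd

# A CSV that validate_file accepts: a "smiles" column, no more than 100 rows.
pd.DataFrame({"smiles": ["CCO", "c1ccccc1", "CC(=O)Oc1ccccc1C(=O)O"]}).to_csv("batch.csv", index=False)

# predict_file then writes the results as "batch_prediction.csv", with a
# "prediction" column (NaN for rows whose SMILES fail to parse).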
dataset_descriptions.json ADDED
@@ -0,0 +1,112 @@

{
    "ADMET_Caco2_Wang": {
        "task_type": "regression",
        "description": "predict drug permeability, measured in cm/s, using the Caco-2 cell line as an in vitro model to simulate human intestinal tissue permeability",
        "num_molecules": 906
    },
    "ADMET_Bioavailability_Ma": {
        "task_type": "classification",
        "description": "predict oral bioavailability with binary labels, indicating the rate and extent a drug becomes available at its site of action",
        "num_molecules": 640
    },
    "ADMET_Lipophilicity_AstraZeneca": {
        "task_type": "regression",
        "description": "predict lipophilicity with continuous labels, measured as a log-ratio, indicating a drug's ability to dissolve in lipid environments",
        "num_molecules": 4200
    },
    "ADMET_Solubility_AqSolDB": {
        "task_type": "regression",
        "description": "predict aqueous solubility with continuous labels, measured in log mol/L, indicating a drug's ability to dissolve in water",
        "num_molecules": 9982
    },
    "ADMET_HIA_Hou": {
        "task_type": "classification",
        "description": "predict human intestinal absorption (HIA) with binary labels, indicating a drug's ability to be absorbed into the bloodstream",
        "num_molecules": 578
    },
    "ADMET_Pgp_Broccatelli": {
        "task_type": "classification",
        "description": "predict P-glycoprotein (Pgp) inhibition with binary labels, indicating a drug's potential to alter bioavailability and overcome multidrug resistance",
        "num_molecules": 1212
    },
    "ADMET_BBB_Martins": {
        "task_type": "classification",
        "description": "predict blood-brain barrier permeability with binary labels, indicating a drug's ability to penetrate the barrier to reach the brain",
        "num_molecules": 1915
    },
    "ADMET_PPBR_AZ": {
        "task_type": "regression",
        "description": "predict plasma protein binding rate with continuous labels, indicating the percentage of a drug bound to plasma proteins in the blood",
        "num_molecules": 1797
    },
    "ADMET_VDss_Lombardo": {
        "task_type": "regression",
        "description": "predict the volume of distribution at steady state (VDss), indicating drug concentration in tissues versus blood",
        "num_molecules": 1130
    },
    "ADMET_CYP2C9_Veith": {
        "task_type": "classification",
        "description": "predict CYP2C9 inhibition with binary labels, indicating the drug's ability to inhibit the CYP2C9 enzyme involved in metabolism",
        "num_molecules": 12092
    },
    "ADMET_CYP2D6_Veith": {
        "task_type": "classification",
        "description": "predict CYP2D6 inhibition with binary labels, indicating the drug's potential to inhibit the CYP2D6 enzyme involved in metabolism",
        "num_molecules": 13130
    },
    "ADMET_CYP3A4_Veith": {
        "task_type": "classification",
        "description": "predict CYP3A4 inhibition with binary labels, indicating the drug's ability to inhibit the CYP3A4 enzyme involved in metabolism",
        "num_molecules": 12328
    },
    "ADMET_CYP2C9_Substrate_CarbonMangels": {
        "task_type": "classification",
        "description": "predict whether a drug is a substrate of the CYP2C9 enzyme with binary labels, indicating its potential to be metabolized",
        "num_molecules": 666
    },
    "ADMET_CYP2D6_Substrate_CarbonMangels": {
        "task_type": "classification",
        "description": "predict whether a drug is a substrate of the CYP2D6 enzyme with binary labels, indicating its potential to be metabolized",
        "num_molecules": 664
    },
    "ADMET_CYP3A4_Substrate_CarbonMangels": {
        "task_type": "classification",
        "description": "predict whether a drug is a substrate of the CYP3A4 enzyme with binary labels, indicating its potential to be metabolized",
        "num_molecules": 667
    },
    "ADMET_Half_Life_Obach": {
        "task_type": "regression",
        "description": "predict the half-life duration of a drug, measured in hours, indicating the time for its concentration to reduce by half",
        "num_molecules": 667
    },
    "ADMET_Clearance_Hepatocyte_AZ": {
        "task_type": "regression",
        "description": "predict drug clearance, measured in \u03bcL/min/10^6 cells, from hepatocyte experiments, indicating the rate at which the drug is removed from the body",
        "num_molecules": 1020
    },
    "ADMET_Clearance_Microsome_AZ": {
        "task_type": "regression",
        "description": "predict drug clearance, measured in mL/min/g, from microsome experiments, indicating the rate at which the drug is removed from the body",
        "num_molecules": 1102
    },
    "ADMET_LD50_Zhu": {
        "task_type": "regression",
        "description": "predict the acute toxicity of a drug, measured as the dose leading to lethal effects in log(kg/mol)",
        "num_molecules": 7385
    },
    "ADMET_hERG": {
        "task_type": "classification",
        "description": "predict whether a drug blocks the hERG channel, which is crucial for heart rhythm, potentially leading to adverse effects",
        "num_molecules": 648
    },
    "ADMET_AMES": {
        "task_type": "classification",
        "description": "predict whether a drug is mutagenic with binary labels, indicating its ability to induce genetic alterations",
        "num_molecules": 7255
    },
    "ADMET_DILI": {
        "task_type": "classification",
        "description": "predict whether a drug can cause liver injury with binary labels, indicating its potential for hepatotoxicity",
        "num_molecules": 475
    }
}
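For reference, a minimal sketch (not part of the commit) of how utils.py below renders one of these entries into the UI's property description: keys are lower-cased and the task_type and description fields are spliced into a sentence.

import json

with open("dataset_descriptions.json") as f:
    entry = json.load(f)["ADMET_Caco2_Wang"]

# mirrors the formatting loop in utils.py
text = (f"admet_caco2_wang is a {entry['task_type']} task, "
        f"where the goal is to {entry['description']}.")
print(text)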
utils.py ADDED
@@ -0,0 +1,286 @@

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from typing import Optional, Dict, Sequence, List
import transformers
from peft import PeftModel
import torch
from torch.nn.utils.rnn import pad_sequence
from dataclasses import dataclass
import pandas as pd
from datasets import Dataset
from tqdm import tqdm
import numpy as np
from huggingface_hub import hf_hub_download
import os
import pickle
from sklearn import preprocessing
import json


from rdkit import RDLogger, Chem
# Suppress RDKit INFO messages
RDLogger.DisableLog('rdApp.*')

# we have a dictionary to store the task types of the models
task_types = {
    "admet_ppbr_az": "regression",
    "admet_half_life_obach": "regression",
}

# read the dataset descriptions
with open("dataset_descriptions.json", "r") as f:
    dataset_description_temp = json.load(f)

dataset_descriptions = dict()

for dataset in dataset_description_temp:
    dataset_name = dataset.lower()
    dataset_descriptions[dataset_name] = \
        f"{dataset_name} is a {dataset_description_temp[dataset]['task_type']} task, " + \
        f"where the goal is to {dataset_description_temp[dataset]['description']}."

class Scaler:
    def __init__(self, log=False):
        self.log = log
        self.offset = None
        self.scaler = None

    def fit(self, y):
        # make the values non-negative
        self.offset = np.min([np.min(y), 0.0])
        y = y.reshape(-1, 1) - self.offset

        # scale the input data
        if self.log:
            y = np.log10(y + 1.0)

        self.scaler = preprocessing.StandardScaler().fit(y)

    def transform(self, y):
        y = y.reshape(-1, 1) - self.offset

        # scale the input data
        if self.log:
            y = np.log10(y + 1.0)

        y_scale = self.scaler.transform(y)

        return y_scale

    def inverse_transform(self, y_scale):
        y = self.scaler.inverse_transform(y_scale.reshape(-1, 1))

        if self.log:
            y = 10.0**y - 1.0

        y = y + self.offset

        return y
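# --- Not part of the commit: a quick round-trip check of the Scaler above,
# --- assuming utils.py and its dependencies are importable. With log=True,
# --- fit shifts values non-negative and log10-transforms them before
# --- standardizing, and inverse_transform undoes all three steps.

import numpy as np
from utils import Scaler

y = np.array([-2.0, 0.5, 3.0, 10.0])
scaler = Scaler(log=True)
scaler.fit(y)
y_scaled = scaler.transform(y)            # shape (4, 1), standardized
y_back = scaler.inverse_transform(y_scaled)
print(np.allclose(y_back.ravel(), y))     # True, up to floating-point error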
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    non_special_tokens = None,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    tokenizer.add_special_tokens(special_tokens_dict)
    if non_special_tokens:
        tokenizer.add_tokens(non_special_tokens)
    num_old_tokens = model.get_input_embeddings().weight.shape[0]
    num_new_tokens = len(tokenizer) - num_old_tokens
    if num_new_tokens == 0:
        return

    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings_data = model.get_input_embeddings().weight.data

        # initialize the new embeddings with the mean of the existing ones
        input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
    print(f"Resized tokenizer and embedding from {num_old_tokens} to {len(tokenizer)} tokens.")

@dataclass
class DataCollator(object):
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int
    molecule_start_str: str
    end_str: str

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:

        sources = []

        for example in instances:
            smiles = example['smiles'].strip()
            # canonicalize the SMILES string
            smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))

            source = f"{self.molecule_start_str}{smiles}{self.end_str}"
            sources.append(source)

        # Tokenize
        tokenized_sources_with_prompt = self.tokenizer(
            sources,
            max_length=self.source_max_len,
            truncation=True,
            add_special_tokens=False,
        )
        input_ids = [torch.tensor(tokenized_source) for tokenized_source in tokenized_sources_with_prompt['input_ids']]
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)

        data_dict = {
            'input_ids': input_ids,
            'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
        }

        return data_dict
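# --- Not part of the commit: a minimal sketch of what DataCollator produces,
# --- assuming the ChemFM/admet_ppbr_az tokenizer is reachable on the Hub.
# --- The pad-token guard covers the case where the saved tokenizer lacks one
# --- (the model class below adds "[PAD]" itself).

from transformers import AutoTokenizer
from utils import DataCollator

tok = AutoTokenizer.from_pretrained("ChemFM/admet_ppbr_az", trust_remote_code=True)
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "[PAD]"})

collator = DataCollator(tokenizer=tok, source_max_len=512,
                        molecule_start_str="<molstart>", end_str="<eos>")
batch = collator([{"smiles": "CCO"}, {"smiles": "c1ccccc1"}])
print(batch["input_ids"].shape, batch["attention_mask"].shape)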
class MolecularPropertyPredictionModel():
    def __init__(self):
        self.adapter_name = None

        # cache the paths of downloaded adapter scalers so we don't
        # download the same scaler multiple times
        self.adapter_scaler_path = dict()

        DEFAULT_PAD_TOKEN = "[PAD]"

        # load the base model
        config = AutoConfig.from_pretrained(
            "ChemFM/ChemFM-3B",
            num_labels=1,
            finetuning_task="classification",  # this is not about our task type
            trust_remote_code=True,
        )

        self.base_model = AutoModelForSequenceClassification.from_pretrained(
            "ChemFM/ChemFM-3B",
            config=config,
            device_map="cpu",
            trust_remote_code=True,
        )

        # load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            "ChemFM/admet_ppbr_az",
            trust_remote_code=True,
        )
        special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=special_tokens_dict,
            tokenizer=self.tokenizer,
            model=self.base_model
        )
        self.base_model.config.pad_token_id = self.tokenizer.pad_token_id

        self.data_collator = DataCollator(
            tokenizer=self.tokenizer,
            source_max_len=512,
            molecule_start_str="<molstart>",
            end_str="<eos>",
        )


    def switch_adapter(self, adapter_name, adapter_id):
        # return flag:
        #   keep: adapter is the same as the current one
        #   switched: adapter was switched successfully
        #   error: adapter was not found

        if adapter_name == self.adapter_name:
            return "keep"
        # switch adapter
        try:
            self.adapter_name = adapter_name
            self.lora_model = PeftModel.from_pretrained(self.base_model, adapter_id)
            if adapter_name not in self.adapter_scaler_path:
                self.adapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl")
            if os.path.exists(self.adapter_scaler_path[adapter_name]):
                self.scaler = pickle.load(open(self.adapter_scaler_path[adapter_name], "rb"))
            else:
                self.scaler = None

            return "switched"
        except Exception as e:
            # handle error
            return "error"

    def predict(self, valid_df, task_type):
        test_dataset = Dataset.from_pandas(valid_df)
        # construct the dataloader
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=4,
            collate_fn=self.data_collator,
        )

        # predict
        y_pred = []
        for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
            with torch.no_grad():
                batch = {k: v.to(self.lora_model.device) for k, v in batch.items()}
                outputs = self.lora_model(**batch)
                if task_type == "regression":  # TODO: check if the model is regression or classification
                    y_pred.append(outputs.logits.cpu().detach().numpy())
                else:
                    y_pred.append((torch.sigmoid(outputs.logits) > 0.5).cpu().detach().numpy())

        y_pred = np.concatenate(y_pred, axis=0)
        if task_type == "regression" and self.scaler is not None:
            y_pred = self.scaler.inverse_transform(y_pred)

        return y_pred

    def predict_single_smiles(self, smiles, task_type):
        assert task_type in ["regression", "classification"]

        # check that the SMILES string is valid
        if not Chem.MolFromSmiles(smiles):
            return None

        valid_df = pd.DataFrame([smiles], columns=['smiles'])
        results = self.predict(valid_df, task_type)
        return results.item()

    def predict_file(self, df, task_type):
        # add an index column first so predictions can be written back by position
        df = df.reset_index()
        # validate the SMILES strings; invalid ones get a NaN prediction
        valid_idx = []
        invalid_idx = []
        for idx, smiles in enumerate(df['smiles']):
            if Chem.MolFromSmiles(smiles):
                valid_idx.append(idx)
            else:
                invalid_idx.append(idx)
        valid_df = df.loc[valid_idx]
        # get the smiles list
        valid_df_smiles = valid_df['smiles'].tolist()

        input_df = pd.DataFrame(valid_df_smiles, columns=['smiles'])
        results = self.predict(input_df, task_type)

        # add the results to the dataframe
        df.loc[valid_idx, 'prediction'] = results
        df.loc[invalid_idx, 'prediction'] = np.nan

        # drop the index column
        df = df.drop(columns=['index'])

        return df
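For reference, a minimal sketch (not part of the commit) of driving a single prediction without the Gradio UI. The adapter repo id "ChemFM/admet_ppbr_az" is an assumption here; app.py resolves adapter ids from the ChemFM collection at runtime.

from utils import MolecularPropertyPredictionModel, task_types

model = MolecularPropertyPredictionModel()   # downloads ChemFM/ChemFM-3B onto CPU
info = model.switch_adapter("admet_ppbr_az", "ChemFM/admet_ppbr_az")  # assumed adapter repo id
assert info in ("keep", "switched"), "adapter could not be loaded"

pred = model.predict_single_smiles("CCO", task_types["admet_ppbr_az"])
print(pred)   # a float for this regression task; None for an invalid SMILES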