Spaces:
Running
on
Zero
Running
on
Zero
feiyang-cai
committed on
Commit
·
ec780ac
1
Parent(s):
2a36eec
freeze other components when the prediction is processing
Browse files
app.py
CHANGED
@@ -25,65 +25,75 @@ def get_description(property_name):
|
|
25 |
return dataset_descriptions[property_name]
|
26 |
|
27 |
def predict_single_label(smiles, property_name):
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
54 |
|
55 |
return prediction, "Prediction is done"
|
56 |
|
57 |
def predict_file(file, property_name):
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
75 |
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
85 |
|
86 |
-
return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), prediction_file, "Prediction is done"
|
87 |
|
88 |
def validate_file(file):
|
89 |
try:
|
@@ -166,18 +176,57 @@ def build_inference():
|
|
166 |
file_types=[".smi", ".csv"], height=300)
|
167 |
predict_file_button = gr.Button("Predict", size='sm', visible=False)
|
168 |
download_button = gr.DownloadButton("Download", size='sm', visible=False)
|
|
|
169 |
|
170 |
# dropdown change event
|
171 |
dropdown.change(get_description, inputs=dropdown, outputs=description_box)
|
172 |
# predict single button click event
|
173 |
-
predict_single_smiles_button.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
# input file upload event
|
175 |
file_status = gr.State()
|
176 |
input_file.upload(fn=validate_file, inputs=input_file, outputs=[file_status, input_file, predict_file_button, download_button]).success(raise_error, inputs=file_status, outputs=file_status)
|
177 |
# input file clear event
|
178 |
input_file.clear(fn=clear_file, inputs=[download_button], outputs=[predict_file_button, download_button, input_file])
|
179 |
# predict file button click event
|
180 |
-
predict_file_button.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
return demo
|
183 |
|
|
|
25 |
return dataset_descriptions[property_name]
|
26 |
|
27 |
def predict_single_label(smiles, property_name):
    """Predict one property for a single SMILES string.

    Returns a 2-tuple ``(prediction, status_message)`` for the Gradio
    outputs ``[prediction, running_terminal_label]``; ``"NA"`` is returned
    as the prediction whenever the adapter switch or the model call fails.
    """
    try:
        adapter_id = candidate_models[property_name]
        info = model.swith_adapter(property_name, adapter_id)

        # Map the adapter-switch result onto a user-visible status; the
        # "error" and unknown cases abort before any prediction is made.
        if info == "keep":
            running_status = "Adapter is the same as the current one"
        elif info == "switched":
            running_status = "Adapter is switched successfully"
        elif info == "error":
            return "NA", "Adapter is not found"
        else:
            return "NA", "Unknown error"

        prediction = model.predict_single_smiles(smiles, task_types[property_name])
        if prediction is None:
            return "NA", "Invalid SMILES string"

        # Regression outputs are floats — trim them for display.
        if isinstance(prediction, float):
            prediction = round(prediction, 3)
    except Exception as e:
        # Whatever went wrong, report a friendly status instead of raising
        # so the UI event chain keeps running.
        print(e)
        return "NA", "Prediction failed"

    return prediction, "Prediction is done"
|
61 |
|
62 |
def predict_file(file, property_name):
    """Run batch property prediction over an uploaded ``.csv``/``.smi`` file.

    Returns a 5-tuple matching the Gradio output components bound in the
    click chain: ``[predict_file_button, download_button, stop_button,
    input_file, running_terminal_label]``. On success the predictions are
    written next to the input file with a ``_prediction.csv`` suffix and the
    download button is made visible.
    """
    try:
        adapter_id = candidate_models[property_name]
        info = model.swith_adapter(property_name, adapter_id)

        if info == "keep":
            running_status = "Adapter is the same as the current one"
        elif info == "switched":
            running_status = "Adapter is switched successfully"
        elif info == "error":
            running_status = "Adapter is not found"
            # BUG FIX: these early returns previously yielded only 4 values,
            # but the event handler binds 5 output components — Gradio would
            # mis-assign (or fail to assign) the trailing outputs.
            return None, None, gr.update(visible=False), file, running_status
        else:
            running_status = "Unknown error"
            return None, None, gr.update(visible=False), file, running_status

        df = pd.read_csv(file)
        # The upload validator has already ensured a "smiles" column exists.
        df = model.predict_file(df, task_types[property_name])
        # Save the predictions beside the input file, renamed with a
        # "_prediction" suffix, so the download button can serve it.
        prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
        print(file, prediction_file)
        df.to_csv(prediction_file, index=False)
    except Exception as e:
        # Any failure is reported back to the UI rather than raised so the
        # button/label components are restored by the follow-up .then() step.
        print(e)
        return None, None, gr.update(visible=False), file, "Prediction failed"

    return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), gr.update(visible=False), prediction_file, "Prediction is done"
|
97 |
|
98 |
def validate_file(file):
|
99 |
try:
|
|
|
176 |
file_types=[".smi", ".csv"], height=300)
|
177 |
predict_file_button = gr.Button("Predict", size='sm', visible=False)
|
178 |
download_button = gr.DownloadButton("Download", size='sm', visible=False)
|
179 |
+
stop_button = gr.Button("Stop", size='sm', visible=False)
|
180 |
|
181 |
# dropdown change event
|
182 |
dropdown.change(get_description, inputs=dropdown, outputs=description_box)
|
183 |
# predict single button click event
|
184 |
+
predict_single_smiles_button.click(lambda:(gr.update(interactive=False),
|
185 |
+
gr.update(interactive=False),
|
186 |
+
gr.update(interactive=False),
|
187 |
+
gr.update(interactive=False),
|
188 |
+
gr.update(interactive=False),
|
189 |
+
gr.update(interactive=False),
|
190 |
+
gr.update(interactive=False),
|
191 |
+
gr.update(interactive=False),
|
192 |
+
) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
|
193 |
+
.then(predict_single_label, inputs=[textbox, dropdown], outputs=[prediction, running_terminal_label])\
|
194 |
+
.then(lambda:(gr.update(interactive=True),
|
195 |
+
gr.update(interactive=True),
|
196 |
+
gr.update(interactive=True),
|
197 |
+
gr.update(interactive=True),
|
198 |
+
gr.update(interactive=True),
|
199 |
+
gr.update(interactive=True),
|
200 |
+
gr.update(interactive=True),
|
201 |
+
gr.update(interactive=True),
|
202 |
+
) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])
|
203 |
# input file upload event
|
204 |
file_status = gr.State()
|
205 |
input_file.upload(fn=validate_file, inputs=input_file, outputs=[file_status, input_file, predict_file_button, download_button]).success(raise_error, inputs=file_status, outputs=file_status)
|
206 |
# input file clear event
|
207 |
input_file.clear(fn=clear_file, inputs=[download_button], outputs=[predict_file_button, download_button, input_file])
|
208 |
# predict file button click event
|
209 |
+
predict_file_event = predict_file_button.click(lambda:(gr.update(interactive=False),
|
210 |
+
gr.update(interactive=False),
|
211 |
+
gr.update(interactive=False),
|
212 |
+
gr.update(interactive=False, visible=True),
|
213 |
+
gr.update(interactive=False),
|
214 |
+
gr.update(interactive=True, visible=False),
|
215 |
+
gr.update(interactive=False),
|
216 |
+
gr.update(interactive=False),
|
217 |
+
) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
|
218 |
+
.then(predict_file, inputs=[input_file, dropdown], outputs=[predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
|
219 |
+
.then(lambda:(gr.update(interactive=True),
|
220 |
+
gr.update(interactive=True),
|
221 |
+
gr.update(interactive=True),
|
222 |
+
gr.update(interactive=True),
|
223 |
+
gr.update(interactive=True),
|
224 |
+
gr.update(interactive=True),
|
225 |
+
gr.update(interactive=True),
|
226 |
+
gr.update(interactive=True),
|
227 |
+
) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])
|
228 |
+
# stop button click event
|
229 |
+
#stop_button.click(fn=None, inputs=None, outputs=None, cancels=[predict_file_event])
|
230 |
|
231 |
return demo
|
232 |
|
utils.py
CHANGED
@@ -201,7 +201,7 @@ class MolecularPropertyPredictionModel():
|
|
201 |
self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
|
202 |
|
203 |
#self.base_model.to("cuda")
|
204 |
-
print(self.base_model)
|
205 |
|
206 |
def swith_adapter(self, adapter_name, adapter_id):
|
207 |
# return flag:
|
@@ -220,6 +220,7 @@ class MolecularPropertyPredictionModel():
|
|
220 |
#print(self.lora_model)
|
221 |
|
222 |
self.base_model.set_adapter(adapter_name)
|
|
|
223 |
|
224 |
#if adapter_name not in self.apapter_scaler_path:
|
225 |
# self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
|
@@ -244,8 +245,8 @@ class MolecularPropertyPredictionModel():
|
|
244 |
batch_size=16,
|
245 |
collate_fn=self.data_collator,
|
246 |
)
|
247 |
-
# predict
|
248 |
|
|
|
249 |
y_pred = []
|
250 |
for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
|
251 |
with torch.no_grad():
|
|
|
201 |
self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
|
202 |
|
203 |
#self.base_model.to("cuda")
|
204 |
+
#print(self.base_model)
|
205 |
|
206 |
def swith_adapter(self, adapter_name, adapter_id):
|
207 |
# return flag:
|
|
|
220 |
#print(self.lora_model)
|
221 |
|
222 |
self.base_model.set_adapter(adapter_name)
|
223 |
+
self.base_model.eval()
|
224 |
|
225 |
#if adapter_name not in self.apapter_scaler_path:
|
226 |
# self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
|
|
|
245 |
batch_size=16,
|
246 |
collate_fn=self.data_collator,
|
247 |
)
|
|
|
248 |
|
249 |
+
# predict
|
250 |
y_pred = []
|
251 |
for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
|
252 |
with torch.no_grad():
|