feiyang-cai commited on
Commit
ec780ac
·
1 Parent(s): 2a36eec

freeze other components when the prediction is processing

Browse files
Files changed (2) hide show
  1. app.py +102 -53
  2. utils.py +3 -2
app.py CHANGED
@@ -25,65 +25,75 @@ def get_description(property_name):
25
  return dataset_descriptions[property_name]
26
 
27
  def predict_single_label(smiles, property_name):
28
- adapter_id = candidate_models[property_name]
29
- info = model.swith_adapter(property_name, adapter_id)
30
-
31
- running_status = None
32
- if info == "keep":
33
- running_status = "Adapter is the same as the current one"
34
- #print("Adapter is the same as the current one")
35
- elif info == "switched":
36
- running_status = "Adapter is switched successfully"
37
- #print("Adapter is switched successfully")
38
- elif info == "error":
39
- running_status = "Adapter is not found"
40
- #print("Adapter is not found")
41
- return "NA", running_status
42
- else:
43
- running_status = "Unknown error"
44
- return "NA", running_status
 
45
 
46
- #prediction = model.predict(smiles, property_name, adapter_id)
47
- prediction = model.predict_single_smiles(smiles, task_types[property_name])
48
- if prediction is None:
49
- return "NA", "Invalid SMILES string"
50
 
51
- # if the prediction is a float, round it to 3 decimal places
52
- if isinstance(prediction, float):
53
- prediction = round(prediction, 3)
 
 
 
 
54
 
55
  return prediction, "Prediction is done"
56
 
57
  def predict_file(file, property_name):
58
- adapter_id = candidate_models[property_name]
59
- info = model.swith_adapter(property_name, adapter_id)
60
-
61
- running_status = None
62
- if info == "keep":
63
- running_status = "Adapter is the same as the current one"
64
- #print("Adapter is the same as the current one")
65
- elif info == "switched":
66
- running_status = "Adapter is switched successfully"
67
- #print("Adapter is switched successfully")
68
- elif info == "error":
69
- running_status = "Adapter is not found"
70
- #print("Adapter is not found")
71
- return None, None, file, running_status
72
- else:
73
- running_status = "Unknown error"
74
- return None, None, file, running_status
 
75
 
76
- df = pd.read_csv(file)
77
- # we have already checked the file contains the "smiles" column
78
- df = model.predict_file(df, task_types[property_name])
79
- # we should save this file to the disk to be downloaded
80
- # rename the file to have "_prediction" suffix
81
- prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
82
- print(file, prediction_file)
83
- # save the file to the disk
84
- df.to_csv(prediction_file, index=False)
 
 
 
 
85
 
86
- return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), prediction_file, "Prediction is done"
87
 
88
  def validate_file(file):
89
  try:
@@ -166,18 +176,57 @@ def build_inference():
166
  file_types=[".smi", ".csv"], height=300)
167
  predict_file_button = gr.Button("Predict", size='sm', visible=False)
168
  download_button = gr.DownloadButton("Download", size='sm', visible=False)
 
169
 
170
  # dropdown change event
171
  dropdown.change(get_description, inputs=dropdown, outputs=description_box)
172
  # predict single button click event
173
- predict_single_smiles_button.click(predict_single_label, inputs=[textbox, dropdown], outputs=[prediction, running_terminal_label])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  # input file upload event
175
  file_status = gr.State()
176
  input_file.upload(fn=validate_file, inputs=input_file, outputs=[file_status, input_file, predict_file_button, download_button]).success(raise_error, inputs=file_status, outputs=file_status)
177
  # input file clear event
178
  input_file.clear(fn=clear_file, inputs=[download_button], outputs=[predict_file_button, download_button, input_file])
179
  # predict file button click event
180
- predict_file_button.click(predict_file, inputs=[input_file, dropdown], outputs=[predict_file_button, download_button, input_file, running_terminal_label])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  return demo
183
 
 
25
  return dataset_descriptions[property_name]
26
 
27
  def predict_single_label(smiles, property_name):
28
+ try:
29
+ adapter_id = candidate_models[property_name]
30
+ info = model.swith_adapter(property_name, adapter_id)
31
+
32
+ running_status = None
33
+ if info == "keep":
34
+ running_status = "Adapter is the same as the current one"
35
+ #print("Adapter is the same as the current one")
36
+ elif info == "switched":
37
+ running_status = "Adapter is switched successfully"
38
+ #print("Adapter is switched successfully")
39
+ elif info == "error":
40
+ running_status = "Adapter is not found"
41
+ #print("Adapter is not found")
42
+ return "NA", running_status
43
+ else:
44
+ running_status = "Unknown error"
45
+ return "NA", running_status
46
 
47
+ #prediction = model.predict(smiles, property_name, adapter_id)
48
+ prediction = model.predict_single_smiles(smiles, task_types[property_name])
49
+ if prediction is None:
50
+ return "NA", "Invalid SMILES string"
51
 
52
+ # if the prediction is a float, round it to 3 decimal places
53
+ if isinstance(prediction, float):
54
+ prediction = round(prediction, 3)
55
+ except Exception as e:
56
+ # no matter what the error is, we should return
57
+ print(e)
58
+ return "NA", "Prediction failed"
59
 
60
  return prediction, "Prediction is done"
61
 
62
  def predict_file(file, property_name):
63
+ try:
64
+ adapter_id = candidate_models[property_name]
65
+ info = model.swith_adapter(property_name, adapter_id)
66
+
67
+ running_status = None
68
+ if info == "keep":
69
+ running_status = "Adapter is the same as the current one"
70
+ #print("Adapter is the same as the current one")
71
+ elif info == "switched":
72
+ running_status = "Adapter is switched successfully"
73
+ #print("Adapter is switched successfully")
74
+ elif info == "error":
75
+ running_status = "Adapter is not found"
76
+ #print("Adapter is not found")
77
+ return None, None, file, running_status
78
+ else:
79
+ running_status = "Unknown error"
80
+ return None, None, file, running_status
81
 
82
+ df = pd.read_csv(file)
83
+ # we have already checked the file contains the "smiles" column
84
+ df = model.predict_file(df, task_types[property_name])
85
+ # we should save this file to the disk to be downloaded
86
+ # rename the file to have "_prediction" suffix
87
+ prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
88
+ print(file, prediction_file)
89
+ # save the file to the disk
90
+ df.to_csv(prediction_file, index=False)
91
+ except Exception as e:
92
+ # no matter what the error is, we should return
93
+ print(e)
94
+ return None, None, gr.update(visible=False), file, "Prediction failed"
95
 
96
+ return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), gr.update(visible=False), prediction_file, "Prediction is done"
97
 
98
  def validate_file(file):
99
  try:
 
176
  file_types=[".smi", ".csv"], height=300)
177
  predict_file_button = gr.Button("Predict", size='sm', visible=False)
178
  download_button = gr.DownloadButton("Download", size='sm', visible=False)
179
+ stop_button = gr.Button("Stop", size='sm', visible=False)
180
 
181
  # dropdown change event
182
  dropdown.change(get_description, inputs=dropdown, outputs=description_box)
183
  # predict single button click event
184
+ predict_single_smiles_button.click(lambda:(gr.update(interactive=False),
185
+ gr.update(interactive=False),
186
+ gr.update(interactive=False),
187
+ gr.update(interactive=False),
188
+ gr.update(interactive=False),
189
+ gr.update(interactive=False),
190
+ gr.update(interactive=False),
191
+ gr.update(interactive=False),
192
+ ) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
193
+ .then(predict_single_label, inputs=[textbox, dropdown], outputs=[prediction, running_terminal_label])\
194
+ .then(lambda:(gr.update(interactive=True),
195
+ gr.update(interactive=True),
196
+ gr.update(interactive=True),
197
+ gr.update(interactive=True),
198
+ gr.update(interactive=True),
199
+ gr.update(interactive=True),
200
+ gr.update(interactive=True),
201
+ gr.update(interactive=True),
202
+ ) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])
203
  # input file upload event
204
  file_status = gr.State()
205
  input_file.upload(fn=validate_file, inputs=input_file, outputs=[file_status, input_file, predict_file_button, download_button]).success(raise_error, inputs=file_status, outputs=file_status)
206
  # input file clear event
207
  input_file.clear(fn=clear_file, inputs=[download_button], outputs=[predict_file_button, download_button, input_file])
208
  # predict file button click event
209
+ predict_file_event = predict_file_button.click(lambda:(gr.update(interactive=False),
210
+ gr.update(interactive=False),
211
+ gr.update(interactive=False),
212
+ gr.update(interactive=False, visible=True),
213
+ gr.update(interactive=False),
214
+ gr.update(interactive=True, visible=False),
215
+ gr.update(interactive=False),
216
+ gr.update(interactive=False),
217
+ ) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
218
+ .then(predict_file, inputs=[input_file, dropdown], outputs=[predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
219
+ .then(lambda:(gr.update(interactive=True),
220
+ gr.update(interactive=True),
221
+ gr.update(interactive=True),
222
+ gr.update(interactive=True),
223
+ gr.update(interactive=True),
224
+ gr.update(interactive=True),
225
+ gr.update(interactive=True),
226
+ gr.update(interactive=True),
227
+ ) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])
228
+ # stop button click event
229
+ #stop_button.click(fn=None, inputs=None, outputs=None, cancels=[predict_file_event])
230
 
231
  return demo
232
 
utils.py CHANGED
@@ -201,7 +201,7 @@ class MolecularPropertyPredictionModel():
201
  self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
202
 
203
  #self.base_model.to("cuda")
204
- print(self.base_model)
205
 
206
  def swith_adapter(self, adapter_name, adapter_id):
207
  # return flag:
@@ -220,6 +220,7 @@ class MolecularPropertyPredictionModel():
220
  #print(self.lora_model)
221
 
222
  self.base_model.set_adapter(adapter_name)
 
223
 
224
  #if adapter_name not in self.apapter_scaler_path:
225
  # self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
@@ -244,8 +245,8 @@ class MolecularPropertyPredictionModel():
244
  batch_size=16,
245
  collate_fn=self.data_collator,
246
  )
247
- # predict
248
 
 
249
  y_pred = []
250
  for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
251
  with torch.no_grad():
 
201
  self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
202
 
203
  #self.base_model.to("cuda")
204
+ #print(self.base_model)
205
 
206
  def swith_adapter(self, adapter_name, adapter_id):
207
  # return flag:
 
220
  #print(self.lora_model)
221
 
222
  self.base_model.set_adapter(adapter_name)
223
+ self.base_model.eval()
224
 
225
  #if adapter_name not in self.apapter_scaler_path:
226
  # self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
 
245
  batch_size=16,
246
  collate_fn=self.data_collator,
247
  )
 
248
 
249
+ # predict
250
  y_pred = []
251
  for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
252
  with torch.no_grad():