feiyang-cai committed
Commit 6bb1bdf · 1 Parent(s): 50fe1a2

finish the basic function

Files changed (3):
  1. app.py +185 -4
  2. dataset_descriptions.json +112 -0
  3. utils.py +286 -0
app.py CHANGED
@@ -1,7 +1,188 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
  import gradio as gr
+ from huggingface_hub import HfApi, get_collection, list_collections
+ from utils import MolecularPropertyPredictionModel, task_types, dataset_descriptions
+ import pandas as pd
+ import os

+ def get_models():
+     # this is the collection id for the molecular property prediction models
+     collection = get_collection("ChemFM/molecular-property-prediction-6710141ffc31f31a47d6fc0c")
+     models = dict()
+     for item in collection.items:
+         if item.item_type == "model":
+             item_name = item.item_id.split("/")[-1]
+             models[item_name] = item.item_id
+             assert item_name in task_types, f"{item_name} is not in the task_types"
+             assert item_name in dataset_descriptions, f"{item_name} is not in the dataset_descriptions"
+
+     return models

+ candidate_models = get_models()
+ properties = list(candidate_models.keys())
+ model = MolecularPropertyPredictionModel()
+
+ def get_description(property_name):
+     return dataset_descriptions[property_name]
+
+ def predict_single_label(smiles, property_name):
+     adapter_id = candidate_models[property_name]
+     info = model.swith_adapter(property_name, adapter_id)
+
+     running_status = None
+     if info == "keep":
+         running_status = "Adapter is the same as the current one"
+         #print("Adapter is the same as the current one")
+     elif info == "switched":
+         running_status = "Adapter is switched successfully"
+         #print("Adapter is switched successfully")
+     elif info == "error":
+         running_status = "Adapter is not found"
+         #print("Adapter is not found")
+         return "NA", running_status
+     else:
+         running_status = "Unknown error"
+         return "NA", running_status
+
+     #prediction = model.predict(smiles, property_name, adapter_id)
+     prediction = model.predict_single_smiles(smiles, task_types[property_name])
+     if prediction is None:
+         return "NA", "Invalid SMILES string"
+
+     # if the prediction is a float, round it to 3 decimal places
+     if isinstance(prediction, float):
+         prediction = round(prediction, 3)
+
+     return prediction, "Prediction is done"
+
+ def predict_file(file, property_name):
+     adapter_id = candidate_models[property_name]
+     info = model.swith_adapter(property_name, adapter_id)
+
+     running_status = None
+     if info == "keep":
+         running_status = "Adapter is the same as the current one"
+         #print("Adapter is the same as the current one")
+     elif info == "switched":
+         running_status = "Adapter is switched successfully"
+         #print("Adapter is switched successfully")
+     elif info == "error":
+         running_status = "Adapter is not found"
+         #print("Adapter is not found")
+         return None, None, file, running_status
+     else:
+         running_status = "Unknown error"
+         return None, None, file, running_status
+
+     df = pd.read_csv(file)
+     # we have already checked the file contains the "smiles" column
+     df = model.predict_file(df, task_types[property_name])
+     # we should save this file to the disk to be downloaded
+     # rename the file to have "_prediction" suffix
+     prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
+     print(file, prediction_file)
+     # save the file to the disk
+     df.to_csv(prediction_file, index=False)
+
+     return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), prediction_file, "Prediction is done"
+
+ def validate_file(file):
+     try:
+         if file.endswith(".csv"):
+             df = pd.read_csv(file)
+             if "smiles" not in df.columns:
+                 # we should clear the file input
+                 return "Invalid file content. The csv file must contain column named 'smiles'", \
+                     None, gr.update(visible=False), gr.update(visible=False)
+
+             # check the length of the smiles
+             length = len(df["smiles"])
+
+         elif file.endswith(".smi"):
+             return "Invalid file extension", \
+                 None, gr.update(visible=False), gr.update(visible=False)
+
+         else:
+             return "Invalid file extension", \
+                 None, gr.update(visible=False), gr.update(visible=False)
+     except Exception as e:
+         return "Invalid file content.", \
+             None, gr.update(visible=False), gr.update(visible=False)
+
+     if length > 100:
+         return "The space does not support the file containing more than 100 SMILES", \
+             None, gr.update(visible=False), gr.update(visible=False)
+
+     return "Valid file", file, gr.update(visible=True), gr.update(visible=False)
+
+
+ def raise_error(status):
+     if status != "Valid file":
+         raise gr.Error(status)
+     return None
+
+
+ def clear_file(download_button):
+     # we might need to delete the prediction file and uploaded file
+     prediction_path = download_button
+     print(prediction_path)
+     if prediction_path and os.path.exists(prediction_path):
+         os.remove(prediction_path)
+         original_data_file_0 = prediction_path.replace("_prediction.csv", ".csv")
+         original_data_file_1 = prediction_path.replace("_prediction.csv", ".smi")
+         if os.path.exists(original_data_file_0):
+             os.remove(original_data_file_0)
+         if os.path.exists(original_data_file_1):
+             os.remove(original_data_file_1)
+     #if os.path.exists(file):
+     #    os.remove(file)
+     #prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
+     #if os.path.exists(prediction_file):
+     #    os.remove(prediction_file)
+
+
+     return gr.update(visible=False), gr.update(visible=False), None
+
+ def build_inference():
+
+     with gr.Blocks() as demo:
+         # first row - Dropdown input
+         #with gr.Row():
+         dropdown = gr.Dropdown(properties, label="Property", value=properties[0])
+         description_box = gr.Textbox(label="Property description", lines=5,
+                                      interactive=False,
+                                      value=dataset_descriptions[properties[0]])
+         # third row - Textbox input and prediction label
+         with gr.Row(equal_height=True):
+             with gr.Column():
+                 textbox = gr.Textbox(label="Molecule SMILES", type="text", placeholder="Provide a SMILES string here",
+                                      lines=1)
+                 predict_single_smiles_button = gr.Button("Predict", size='sm')
+                 prediction = gr.Label("Prediction will appear here")
+
+             running_terminal_label = gr.Textbox(label="Running status", type="text", placeholder=None, lines=10, interactive=False)
+
+         input_file = gr.File(label="Molecule file",
+                              file_count='single',
+                              file_types=[".smi", ".csv"], height=300)
+         predict_file_button = gr.Button("Predict", size='sm', visible=False)
+         download_button = gr.DownloadButton("Download", size='sm', visible=False)
+
+         # dropdown change event
+         dropdown.change(get_description, inputs=dropdown, outputs=description_box)
+         # predict single button click event
+         predict_single_smiles_button.click(predict_single_label, inputs=[textbox, dropdown], outputs=[prediction, running_terminal_label])
+         # input file upload event
+         file_status = gr.State()
+         input_file.upload(fn=validate_file, inputs=input_file, outputs=[file_status, input_file, predict_file_button, download_button]).success(raise_error, inputs=file_status, outputs=file_status)
+         # input file clear event
+         input_file.clear(fn=clear_file, inputs=[download_button], outputs=[predict_file_button, download_button, input_file])
+         # predict file button click event
+         predict_file_button.click(predict_file, inputs=[input_file, dropdown], outputs=[predict_file_button, download_button, input_file, running_terminal_label])
+
+     return demo
+
+
+ demo = build_inference()
+
+ if __name__ == '__main__':
+     demo.launch()
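
For orientation, a minimal sketch (editor's illustration, not part of the commit) of exercising the new single-SMILES handler outside the Gradio UI. It assumes "admet_ppbr_az" is one of the adapter names returned by get_models(); note that importing app loads the base model and queries the collection at import time.

    # Editor's sketch, not part of the commit.
    # "admet_ppbr_az" is assumed to be a key of candidate_models.
    from app import predict_single_label

    value, status = predict_single_label("CC(=O)Oc1ccccc1C(=O)O", "admet_ppbr_az")
    print(status, value)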
dataset_descriptions.json ADDED
@@ -0,0 +1,112 @@
+ {
+     "ADMET_Caco2_Wang": {
+         "task_type": "regression",
+         "description": "predict drug permeability, measured in cm/s, using the Caco-2 cell line as an in vitro model to simulate human intestinal tissue permeability",
+         "num_molecules": 906
+     },
+     "ADMET_Bioavailability_Ma": {
+         "task_type": "classification",
+         "description": "predict oral bioavailability with binary labels, indicating the rate and extent a drug becomes available at its site of action",
+         "num_molecules": 640
+     },
+     "ADMET_Lipophilicity_AstraZeneca": {
+         "task_type": "regression",
+         "description": "predict lipophilicity with continuous labels, measured as a log-ratio, indicating a drug's ability to dissolve in lipid environments",
+         "num_molecules": 4200
+     },
+     "ADMET_Solubility_AqSolDB": {
+         "task_type": "regression",
+         "description": "predict aqueous solubility with continuous labels, measured in log mol/L, indicating a drug's ability to dissolve in water",
+         "num_molecules": 9982
+     },
+     "ADMET_HIA_Hou": {
+         "task_type": "classification",
+         "description": "predict human intestinal absorption (HIA) with binary labels, indicating a drug's ability to be absorbed into the bloodstream",
+         "num_molecules": 578
+     },
+     "ADMET_Pgp_Broccatelli": {
+         "task_type": "classification",
+         "description": "predict P-glycoprotein (Pgp) inhibition with binary labels, indicating a drug's potential to alter bioavailability and overcome multidrug resistance",
+         "num_molecules": 1212
+     },
+     "ADMET_BBB_Martins": {
+         "task_type": "classification",
+         "description": "predict blood-brain barrier permeability with binary labels, indicating a drug's ability to penetrate the barrier to reach the brain",
+         "num_molecules": 1915
+     },
+     "ADMET_PPBR_AZ": {
+         "task_type": "regression",
+         "description": "predict plasma protein binding rate with continuous labels, indicating the percentage of a drug bound to plasma proteins in the blood",
+         "num_molecules": 1797
+     },
+     "ADMET_VDss_Lombardo": {
+         "task_type": "regression",
+         "description": "predict the volume of distribution at steady state (VDss), indicating drug concentration in tissues versus blood",
+         "num_molecules": 1130
+     },
+     "ADMET_CYP2C9_Veith": {
+         "task_type": "classification",
+         "description": "predict CYP2C9 inhibition with binary labels, indicating the drug's ability to inhibit the CYP2C9 enzyme involved in metabolism",
+         "num_molecules": 12092
+     },
+     "ADMET_CYP2D6_Veith": {
+         "task_type": "classification",
+         "description": "predict CYP2D6 inhibition with binary labels, indicating the drug's potential to inhibit the CYP2D6 enzyme involved in metabolism",
+         "num_molecules": 13130
+     },
+     "ADMET_CYP3A4_Veith": {
+         "task_type": "classification",
+         "description": "predict CYP3A4 inhibition with binary labels, indicating the drug's ability to inhibit the CYP3A4 enzyme involved in metabolism",
+         "num_molecules": 12328
+     },
+     "ADMET_CYP2C9_Substrate_CarbonMangels": {
+         "task_type": "classification",
+         "description": "predict whether a drug is a substrate of the CYP2C9 enzyme with binary labels, indicating its potential to be metabolized",
+         "num_molecules": 666
+     },
+     "ADMET_CYP2D6_Substrate_CarbonMangels": {
+         "task_type": "classification",
+         "description": "predict whether a drug is a substrate of the CYP2D6 enzyme with binary labels, indicating its potential to be metabolized",
+         "num_molecules": 664
+     },
+     "ADMET_CYP3A4_Substrate_CarbonMangels": {
+         "task_type": "classification",
+         "description": "predict whether a drug is a substrate of the CYP3A4 enzyme with binary labels, indicating its potential to be metabolized",
+         "num_molecules": 667
+     },
+     "ADMET_Half_Life_Obach": {
+         "task_type": "regression",
+         "description": "predict the half-life duration of a drug, measured in hours, indicating the time for its concentration to reduce by half",
+         "num_molecules": 667
+     },
+     "ADMET_Clearance_Hepatocyte_AZ": {
+         "task_type": "regression",
+         "description": "predict drug clearance, measured in \u03bcL/min/10^6 cells, from hepatocyte experiments, indicating the rate at which the drug is removed from the body",
+         "num_molecules": 1020
+     },
+     "ADMET_Clearance_Microsome_AZ": {
+         "task_type": "regression",
+         "description": "predict drug clearance, measured in mL/min/g, from microsome experiments, indicating the rate at which the drug is removed from the body",
+         "num_molecules": 1102
+     },
+     "ADMET_LD50_Zhu": {
+         "task_type": "regression",
+         "description": "predict the acute toxicity of a drug, measured as the dose leading to lethal effects in log(kg/mol)",
+         "num_molecules": 7385
+     },
+     "ADMET_hERG": {
+         "task_type": "classification",
+         "description": "predict whether a drug blocks the hERG channel, which is crucial for heart rhythm, potentially leading to adverse effects",
+         "num_molecules": 648
+     },
+     "ADMET_AMES": {
+         "task_type": "classification",
+         "description": "predict whether a drug is mutagenic with binary labels, indicating its ability to induce genetic alterations",
+         "num_molecules": 7255
+     },
+     "ADMET_DILI": {
+         "task_type": "classification",
+         "description": "predict whether a drug can cause liver injury with binary labels, indicating its potential for hepatotoxicity",
+         "num_molecules": 475
+     }
+ }
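
The JSON above is plain metadata, so it can also be consumed outside the app. A minimal sketch (editor's illustration, not part of the commit) that lists the regression endpoints and their dataset sizes:

    # Editor's sketch, not part of the commit.
    import json

    with open("dataset_descriptions.json") as f:
        meta = json.load(f)

    for name, info in meta.items():
        if info["task_type"] == "regression":
            print(f"{name}: {info['num_molecules']} molecules")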
utils.py ADDED
@@ -0,0 +1,286 @@
+ from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
+ from typing import Optional, Dict, Sequence, List
+ import transformers
+ from peft import PeftModel
+ import torch
+ from torch.nn.utils.rnn import pad_sequence
+ from dataclasses import dataclass
+ import pandas as pd
+ from datasets import Dataset
+ from tqdm import tqdm
+ import numpy as np
+ from huggingface_hub import hf_hub_download
+ import os
+ import pickle
+ from sklearn import preprocessing
+ import json
+
+
+ from rdkit import RDLogger, Chem
+ # Suppress RDKit INFO messages
+ RDLogger.DisableLog('rdApp.*')
+
+ # we have a dictionary to store the task types of the models
+ task_types = {
+     "admet_ppbr_az": "regression",
+     "admet_half_life_obach": "regression",
+ }
+
+ # read the dataset descriptions
+ with open("dataset_descriptions.json", "r") as f:
+     dataset_description_temp = json.load(f)
+
+ dataset_descriptions = dict()
+
+ for dataset in dataset_description_temp:
+     dataset_name = dataset.lower()
+     dataset_descriptions[dataset_name] = \
+         f"{dataset_name} is a {dataset_description_temp[dataset]['task_type']} task, " + \
+         f"where the goal is to {dataset_description_temp[dataset]['description']}."
+
+ class Scaler:
+     def __init__(self, log=False):
+         self.log = log
+         self.offset = None
+         self.scaler = None
+
+     def fit(self, y):
+         # make the values non-negative
+         self.offset = np.min([np.min(y), 0.0])
+         y = y.reshape(-1, 1) - self.offset
+
+         # scale the input data
+         if self.log:
+             y = np.log10(y + 1.0)
+
+         self.scaler = preprocessing.StandardScaler().fit(y)
+
+     def transform(self, y):
+         y = y.reshape(-1, 1) - self.offset
+
+         # scale the input data
+         if self.log:
+             y = np.log10(y + 1.0)
+
+         y_scale = self.scaler.transform(y)
+
+         return y_scale
+
+     def inverse_transform(self, y_scale):
+         y = self.scaler.inverse_transform(y_scale.reshape(-1, 1))
+
+         if self.log:
+             y = 10.0**y - 1.0
+
+         y = y + self.offset
+
+         return y
+
+
+ def smart_tokenizer_and_embedding_resize(
+     special_tokens_dict: Dict,
+     tokenizer: transformers.PreTrainedTokenizer,
+     model: transformers.PreTrainedModel,
+     non_special_tokens = None,
+ ):
+     """Resize tokenizer and embedding.
+
+     Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+     """
+     num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + tokenizer.add_tokens(non_special_tokens)
+     num_old_tokens = model.get_input_embeddings().weight.shape[0]
+     num_new_tokens = len(tokenizer) - num_old_tokens
+     if num_new_tokens == 0:
+         return
+
+     model.resize_token_embeddings(len(tokenizer))
+
+     if num_new_tokens > 0:
+         input_embeddings_data = model.get_input_embeddings().weight.data
+
+         input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+         input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
+     print(f"Resized tokenizer and embedding from {num_old_tokens} to {len(tokenizer)} tokens.")
+
+ @dataclass
+ class DataCollator(object):
+     tokenizer: transformers.PreTrainedTokenizer
+     source_max_len: int
+     molecule_start_str: str
+     end_str: str
+
+     def augment_molecule(self, molecule: str) -> str:
+         return self.sme.augment([molecule])[0]
+
+     def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+
+         sources = []
+         targets = []
+
+         for example in instances:
+             smiles = example['smiles'].strip()
+             smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
+
+             # get the properties except the smiles and mol_id cols
+             #props = [example[col] if example[col] is not None else np.nan for col in sorted(example.keys()) if col not in ['smiles', 'is_aug']]
+             source = f"{self.molecule_start_str}{smiles}{self.end_str}"
+             sources.append(source)
+
+         # Tokenize
+         tokenized_sources_with_prompt = self.tokenizer(
+             sources,
+             max_length=self.source_max_len,
+             truncation=True,
+             add_special_tokens=False,
+         )
+         input_ids = [torch.tensor(tokenized_source) for tokenized_source in tokenized_sources_with_prompt['input_ids']]
+         input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+
+         data_dict = {
+             'input_ids': input_ids,
+             'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
+         }
+
+         return data_dict
+
+ class MolecularPropertyPredictionModel():
+     def __init__(self):
+         self.adapter_name = None
+
+         # we need to keep track of the paths of adapter scalers
+         # we don't want to download the same scaler multiple times
+         self.apapter_scaler_path = dict()
+
+         DEFAULT_PAD_TOKEN = "[PAD]"
+
+         # load the base model
+         config = AutoConfig.from_pretrained(
+             "ChemFM/ChemFM-3B",
+             num_labels=1,
+             finetuning_task="classification", # this is not about our task type
+             trust_remote_code=True,
+         )
+
+         self.base_model = AutoModelForSequenceClassification.from_pretrained(
+             "ChemFM/ChemFM-3B",
+             config=config,
+             device_map="cpu",
+             trust_remote_code=True,
+         )
+
+         # load the tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             "ChemFM/admet_ppbr_az",
+             trust_remote_code=True,
+         )
+         special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
+         smart_tokenizer_and_embedding_resize(
+             special_tokens_dict=special_tokens_dict,
+             tokenizer=self.tokenizer,
+             model=self.base_model
+         )
+         self.base_model.config.pad_token_id = self.tokenizer.pad_token_id
+
+         self.data_collator = DataCollator(
+             tokenizer=self.tokenizer,
+             source_max_len=512,
+             molecule_start_str="<molstart>",
+             end_str="<eos>",
+         )
+
+
+     def swith_adapter(self, adapter_name, adapter_id):
+         # return flag:
+         # keep: adapter is the same as the current one
+         # switched: adapter is switched successfully
+         # error: adapter is not found
+
+         if adapter_name == self.adapter_name:
+             return "keep"
+         # switch adapter
+         try:
+             self.adapter_name = adapter_name
+             self.lora_model = PeftModel.from_pretrained(self.base_model, adapter_id)
+             if adapter_name not in self.apapter_scaler_path:
+                 self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl")
+             if os.path.exists(self.apapter_scaler_path[adapter_name]):
+                 self.scaler = pickle.load(open(self.apapter_scaler_path[adapter_name], "rb"))
+             else:
+                 self.scaler = None
+
+             return "switched"
+         except Exception as e:
+             # handle error
+             return "error"
+
+     def predict(self, valid_df, task_type):
+         test_dataset = Dataset.from_pandas(valid_df)
+         # construct the dataloader
+         test_loader = torch.utils.data.DataLoader(
+             test_dataset,
+             batch_size=4,
+             collate_fn=self.data_collator,
+         )
+         # predict
+
+         y_pred = []
+         for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
+             with torch.no_grad():
+                 batch = {k: v.to(self.lora_model.device) for k, v in batch.items()}
+                 outputs = self.lora_model(**batch)
+                 if task_type == "regression": # TODO: check if the model is regression or classification
+                     y_pred.append(outputs.logits.cpu().detach().numpy())
+                 else:
+                     y_pred.append((torch.sigmoid(outputs.logits) > 0.5).cpu().detach().numpy())
+
+         y_pred = np.concatenate(y_pred, axis=0)
+         if task_type=="regression" and self.scaler is not None:
+             y_pred = self.scaler.inverse_transform(y_pred)
+
+
+         return y_pred
+
+     def predict_single_smiles(self, smiles, task_type):
+         assert task_type in ["regression", "classification"]
+
+         # check the SMILES string is valid
+         if not Chem.MolFromSmiles(smiles):
+             return None
+
+         valid_df = pd.DataFrame([smiles], columns=['smiles'])
+         results = self.predict(valid_df, task_type)
+         # predict
+         return results.item()
+
+     def predict_file(self, df, task_type):
+         # we should add the index first
+         df = df.reset_index()
+         # we need to check the SMILES strings are valid, the invalid ones will be moved to the last
+         valid_idx = []
+         invalid_idx = []
+         for idx, smiles in enumerate(df['smiles']):
+             if Chem.MolFromSmiles(smiles):
+                 valid_idx.append(idx)
+             else:
+                 invalid_idx.append(idx)
+         valid_df = df.loc[valid_idx]
+         # get the smiles list
+         valid_df_smiles = valid_df['smiles'].tolist()
+
+         input_df = pd.DataFrame(valid_df_smiles, columns=['smiles'])
+         results = self.predict(input_df, task_type)
+
+         # add the results to the dataframe
+         df.loc[valid_idx, 'prediction'] = results
+         df.loc[invalid_idx, 'prediction'] = np.nan
+
+         # drop the index column
+         df = df.drop(columns=['index'])
+
+         # phrase file
+         return df
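
Putting the pieces of utils.py together, a minimal sketch (editor's illustration, not part of the commit) of running a prediction without the Gradio front end. The adapter repo id "ChemFM/admet_ppbr_az" is an assumption based on the tokenizer path used in __init__; any adapter from the collection should work the same way.

    # Editor's sketch, not part of the commit.
    from utils import MolecularPropertyPredictionModel, task_types

    model = MolecularPropertyPredictionModel()
    status = model.swith_adapter("admet_ppbr_az", "ChemFM/admet_ppbr_az")  # adapter id assumed
    if status == "switched":
        value = model.predict_single_smiles("CC(=O)Oc1ccccc1C(=O)O", task_types["admet_ppbr_az"])
        print(value)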