from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from typing import Optional, Dict, Sequence, List
import transformers
from peft import PeftModel
import torch
from torch.nn.utils.rnn import pad_sequence
from dataclasses import dataclass
import pandas as pd
from datasets import Dataset
from tqdm import tqdm
import numpy as np
from huggingface_hub import hf_hub_download
import os
import pickle
from sklearn import preprocessing
import json
import spaces
import time

class calculateDuration:
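    """Context manager that prints the wall-clock time spent inside a `with` block."""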
    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")


from rdkit import RDLogger, Chem
# Suppress RDKit INFO messages
RDLogger.DisableLog('rdApp.*')

# a hard-coded task-type mapping is no longer needed; each model's task type is read from dataset_descriptions.json below
#task_types = {
#    "admet_bioavailability_ma": "classification",
#    "admet_ppbr_az": "regression",
#    "admet_half_life_obach": "regression",
#}

# read the dataset descriptions
with open("dataset_descriptions.json", "r") as f:
    dataset_description_temp = json.load(f)

dataset_descriptions = dict()
dataset_property_names = dict()
dataset_task_types = dict()
dataset_property_names_to_dataset = dict()

for dataset in dataset_description_temp:
    dataset_name = dataset.lower()
    dataset_descriptions[dataset_name] = \
        f"{dataset_description_temp[dataset]['task_name']} is a {dataset_description_temp[dataset]['task_type']} task, " + \
        f"where the goal is to {dataset_description_temp[dataset]['description']}. \n" + \
        f"More information can be found at {dataset_description_temp[dataset]['url']}."
    dataset_property_names[dataset_name] = dataset_description_temp[dataset]['task_name']
    dataset_property_names_to_dataset[dataset_description_temp[dataset]['task_name']] = dataset_name
    dataset_task_types[dataset_name] = dataset_description_temp[dataset]['task_type']

class Scaler:
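    """Label scaler for regression targets.

    Shifts the values to be non-negative, optionally applies a log10 transform,
    and standardizes them with a scikit-learn StandardScaler. `inverse_transform`
    undoes the whole chain so model outputs can be mapped back to the original scale.
    """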
    def __init__(self, log=False):
        self.log = log
        self.offset = None
        self.scaler = None

    def fit(self, y):
        # make the values non-negative
        self.offset = np.min([np.min(y), 0.0])
        y = y.reshape(-1, 1) - self.offset

        # scale the input data
        if self.log:
            y = np.log10(y + 1.0)

        self.scaler = preprocessing.StandardScaler().fit(y)

    def transform(self, y):
        y = y.reshape(-1, 1) - self.offset

        # scale the input data
        if self.log:
            y = np.log10(y + 1.0)

        y_scale = self.scaler.transform(y)

        return y_scale

    def inverse_transform(self, y_scale):
        y = self.scaler.inverse_transform(y_scale.reshape(-1, 1))

        if self.log:
            y = 10.0**y - 1.0

        y = y + self.offset

        return y


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    non_special_tokens = None,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    tokenizer.add_special_tokens(special_tokens_dict)
    tokenizer.add_tokens(non_special_tokens)
    num_old_tokens = model.get_input_embeddings().weight.shape[0]
    num_new_tokens = len(tokenizer) - num_old_tokens
    if num_new_tokens == 0:
        return
    
    model.resize_token_embeddings(len(tokenizer))
    
    if num_new_tokens > 0:
        # initialize the new token rows with the mean of the pre-existing embeddings
        input_embeddings_data = model.get_input_embeddings().weight.data
        input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
        input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
    print(f"Resized tokenizer and embedding from {num_old_tokens} to {len(tokenizer)} tokens.")

@dataclass
class DataCollator(object):
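    """Collate a batch of examples into padded `input_ids` and `attention_mask` tensors.

    Each SMILES string is canonicalized with RDKit and wrapped between
    `molecule_start_str` and `end_str` before tokenization.
    """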
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int
    molecule_start_str: str
    end_str: str

    def augment_molecule(self, molecule: str) -> str:
        # NOTE: unused in this app; it relies on a SMILES enumerator (`self.sme`) that is not defined on this collator
        return self.sme.augment([molecule])[0]

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        with calculateDuration("DataCollator"):
            sources = []
        
            for example in instances:
                smiles = example['smiles'].strip()
                # canonicalize the SMILES; callers validate the strings before they reach the collator
                smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))

                # get the properties except the smiles and mol_id cols
                #props = [example[col] if example[col] is not None else np.nan for col in sorted(example.keys()) if col not in ['smiles', 'is_aug']]
                source = f"{self.molecule_start_str}{smiles}{self.end_str}"
                sources.append(source)
        
            # Tokenize
            tokenized_sources_with_prompt = self.tokenizer(
                sources,
                max_length=self.source_max_len,
                truncation=True,
                add_special_tokens=False,
            )
            input_ids = [torch.tensor(tokenized_source) for tokenized_source in tokenized_sources_with_prompt['input_ids']]
            input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)

            data_dict = {
                'input_ids': input_ids,
                'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
            }
        
        return data_dict

class MolecularPropertyPredictionModel:
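    """Wraps the ChemFM-3B base model plus one PEFT adapter per property-prediction dataset.

    All adapters are loaded once at start-up; `swith_adapter` activates the adapter (and its
    label scaler, for regression tasks) that matches the property the user selects.
    """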
    def __init__(self, candidate_models):
        self.adapter_name = None

        # keep track of the paths of the adapter scalers so that
        # the same scaler is not downloaded multiple times
        self.adapter_scaler_path = dict()

        DEFAULT_PAD_TOKEN = "[PAD]"

        # load the base model
        config = AutoConfig.from_pretrained(
            "ChemFM/ChemFM-3B",
            num_labels=1,
            finetuning_task="classification", # this is not about our task type
            trust_remote_code=True,
            token = os.environ.get("TOKEN")
        )

        self.base_model = AutoModelForSequenceClassification.from_pretrained(
            "ChemFM/ChemFM-3B",
            config=config,
            device_map="cuda",
            trust_remote_code=True,
            token = os.environ.get("TOKEN")
        )

        # load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            "ChemFM/admet_ppbr_az",
            trust_remote_code=True,
            token = os.environ.get("TOKEN")
        )
        special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=special_tokens_dict,
            tokenizer=self.tokenizer,
            model=self.base_model
        )
        self.base_model.config.pad_token_id = self.tokenizer.pad_token_id

        self.data_collator = DataCollator(
            tokenizer=self.tokenizer,
            source_max_len=512,
            molecule_start_str="<molstart>",
            end_str="<eos>",
        )

        # load all the adapters up front
        for adapter_name in candidate_models:
            adapter_id = candidate_models[adapter_name]
            print(f"loading {adapter_name} from {adapter_id}...")
            self.base_model.load_adapter(adapter_id, adapter_name=adapter_name, token = os.environ.get("TOKEN"))
            try:
                self.adapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
            except Exception:
                # only regression adapters ship a scaler.pkl; classification adapters do not need one
                self.adapter_scaler_path[adapter_name] = None
                assert dataset_task_types[adapter_name] == "classification", f"{adapter_name} is a regression task but no scaler.pkl was found."

        self.base_model.to("cuda")
    
    def swith_adapter(self, adapter_name, adapter_id):
        # return flag:
        # keep: adapter is the same as the current one
        # switched: adapter is switched successfully
        # error: adapter is not found

        with calculateDuration("switching adapter"):
            if adapter_name == self.adapter_name:
                return "keep"
            # switch adapter
            try:
                #self.adapter_name = adapter_name
                #print(self.adapter_name, adapter_id)
                #self.lora_model = PeftModel.from_pretrained(self.base_model, adapter_id, token = os.environ.get("TOKEN"))
                #self.lora_model.to("cuda")
                #print(self.lora_model)
                self.base_model.set_adapter(adapter_name)
                self.base_model.eval()

                #if adapter_name not in self.adapter_scaler_path:
                #    self.adapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
                if self.adapter_scaler_path[adapter_name] and os.path.exists(self.adapter_scaler_path[adapter_name]):
                    with open(self.adapter_scaler_path[adapter_name], "rb") as f:
                        self.scaler = pickle.load(f)
                else:
                    self.scaler = None

                self.adapter_name = adapter_name
                return "switched"
            except Exception as e:
                # the adapter could not be activated; report it and return an error flag
                print(f"Failed to switch to adapter {adapter_name}: {e}")
                return "error"
    

    def predict(self, valid_df, task_type):
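        """Run batched inference over a dataframe with a `smiles` column.

        Returns a NumPy array of predictions: raw (inverse-scaled, if a scaler is
        available) values for regression, or sigmoid probabilities for classification.
        """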

        with calculateDuration("predicting"):
            with calculateDuration("construct dataloader"):
                test_dataset = Dataset.from_pandas(valid_df)
                # construct the dataloader
                test_loader = torch.utils.data.DataLoader(
                    test_dataset,
                    batch_size=16,
                    collate_fn=self.data_collator,
                )

            # predict
            y_pred = []
            for batch in tqdm(test_loader, total=len(test_loader), desc="Evaluating"):
                with torch.no_grad():
                    batch = {k: v.to(self.base_model.device) for k, v in batch.items()}
                    outputs = self.base_model(**batch)
                if task_type == "regression": # TODO: check if the model is regression or classification
                    y_pred.append(outputs.logits.cpu().detach().numpy())
                else:
                    # classification: map logits to probabilities
                    y_pred.append(torch.sigmoid(outputs.logits).cpu().detach().numpy())
        
            y_pred = np.concatenate(y_pred, axis=0)
            if task_type == "regression" and self.scaler is not None:
                # undo the target scaling that was applied during fine-tuning
                y_pred = self.scaler.inverse_transform(y_pred)


        return y_pred
    
    def predict_single_smiles(self, smiles, task_type):
        with calculateDuration("predicting a single SMILES"):
            assert task_type in ["regression", "classification"]

            # check that the SMILES string is valid
            if not Chem.MolFromSmiles(smiles):
                return None

            # wrap the single SMILES in a one-row dataframe and reuse the batched predictor
            valid_df = pd.DataFrame([smiles], columns=['smiles'])
            results = self.predict(valid_df, task_type)
        return results.item()
    
    def predict_file(self, df, task_type):
        with calculateDuration("predicting a file"):
            # add an index column first so predictions can be mapped back to the original rows
            df = df.reset_index()

            with calculateDuration("pre-checking SMILES"):
                # check that the SMILES strings are valid; invalid rows will get a NaN prediction
                valid_idx = []
                invalid_idx = []
                for idx, smiles in enumerate(df['smiles']):
                    if Chem.MolFromSmiles(smiles):
                        valid_idx.append(idx)
                    else:
                        invalid_idx.append(idx)
                valid_df = df.loc[valid_idx]
                # get the smiles list
                valid_df_smiles = valid_df['smiles'].tolist()

            input_df = pd.DataFrame(valid_df_smiles, columns=['smiles'])
            results = self.predict(input_df, task_type)

            # add the results to the dataframe
            df.loc[valid_idx, 'prediction'] = results
            df.loc[invalid_idx, 'prediction'] = np.nan

            # drop the index column
            df = df.drop(columns=['index'])

        return df
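
# Minimal usage sketch (a hedged example, not part of the Gradio app): it assumes a CUDA
# device, a TOKEN environment variable with access to the ChemFM repositories, that the
# "ChemFM/admet_ppbr_az" repo (used for the tokenizer above) also hosts a loadable adapter,
# and that "admet_ppbr_az" appears in dataset_descriptions.json.
if __name__ == "__main__":
    candidate_models = {"admet_ppbr_az": "ChemFM/admet_ppbr_az"}  # adapter name -> HF repo id (assumed)
    model = MolecularPropertyPredictionModel(candidate_models)

    status = model.swith_adapter("admet_ppbr_az", candidate_models["admet_ppbr_az"])
    print(f"adapter switch status: {status}")

    # single-molecule prediction (ethanol); the task type comes from the dataset description file
    task_type = dataset_task_types["admet_ppbr_az"]
    print(model.predict_single_smiles("CCO", task_type))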