from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import os
from typing import Optional, Dict, Sequence
import transformers
from peft import PeftModel
import torch
from dataclasses import dataclass, field
from huggingface_hub import hf_hub_download
import json
import pandas as pd
from datasets import Dataset
from tqdm import tqdm
import spaces
from rdkit import RDLogger, Chem

# Suppress RDKit INFO messages
RDLogger.DisableLog('rdApp.*')

DEFAULT_PAD_TOKEN = "[PAD]"
device_map = "cuda"


def compute_rank(prediction, raw=False, alpha=1.0):
    """Aggregate beam-search candidates into a reciprocal-rank score per SMILES.

    `prediction` is a list of candidate lists (one list per input); empty strings
    mark invalid SMILES and are pushed to the end before scoring.
    """
    valid_score = [[k for k in range(len(prediction[j]))] for j in range(len(prediction))]
    invalid_rates = [0 for k in range(len(prediction[0]))]
    rank = {}
    highest = {}

    for j in range(len(prediction)):
        for k in range(len(prediction[j])):
            if prediction[j][k] == "":
                valid_score[j][k] = 10 + 1  # larger than any valid beam rank, so invalid entries sort last
                invalid_rates[k] += 1
        # Drop invalid entries and de-duplicate while preserving the original candidate order.
        de_error = [i[0] for i in sorted(list(zip(prediction[j], valid_score[j])), key=lambda x: x[1]) if i[0] != ""]
        prediction[j] = list(set(de_error))
        prediction[j].sort(key=de_error.index)

        for k, data in enumerate(prediction[j]):
            if data in rank:
                rank[data] += 1 / (alpha * k + 1)
            else:
                rank[data] = 1 / (alpha * k + 1)
            if data in highest:
                highest[data] = min(k, highest[data])
            else:
                highest[data] = k

    return rank, invalid_rates


@dataclass
class DataCollatorForCausalLMEval(object):
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int
    target_max_len: int
    reactant_start_str: str
    product_start_str: str
    end_str: str

    def augment_molecule(self, molecule: str) -> str:
        # Note: relies on a SMILES augmenter (self.sme) that is not set on this
        # evaluation collator; it is unused in the inference path below.
        return self.sme.augment([molecule])[0]

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        srcs = instances[0]['src']
        task_type = instances[0]['task_type']

        # For retrosynthesis the product is the source and the reactants are the
        # target; for forward synthesis it is the other way around.
        if task_type == 'retrosynthesis':
            src_start_str = self.product_start_str
            tgt_start_str = self.reactant_start_str
        else:
            src_start_str = self.reactant_start_str
            tgt_start_str = self.product_start_str

        generation_prompts = []
        generation_prompt = f"{src_start_str}{srcs}{self.end_str}{tgt_start_str}"
        generation_prompts.append(generation_prompt)

        data_dict = {
            'generation_prompts': generation_prompts
        }

        return data_dict


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    non_special_tokens=None,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    # Add the new tokens first, then compute how many embedding rows are missing
    # relative to the model's current embedding matrix.
    tokenizer.add_special_tokens(special_tokens_dict)
    if non_special_tokens:
        tokenizer.add_tokens(non_special_tokens)
    num_old_tokens = model.get_input_embeddings().weight.shape[0]
    num_new_tokens = len(tokenizer) - num_old_tokens
    if num_new_tokens == 0:
        return
    model.resize_token_embeddings(len(tokenizer))
    if num_new_tokens > 0:
        # Initialize the new rows with the mean of the existing embeddings.
        input_embeddings_data = model.get_input_embeddings().weight.data
        input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
        input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
    print(f"Resized tokenizer and embedding from {num_old_tokens} to {len(tokenizer)} tokens.")


class ReactionPredictionModel:
    def __init__(self, candidate_models):
        # All candidate models share the tokenizer of the first entry.
        self.tokenizer = AutoTokenizer.from_pretrained(
            candidate_models[list(candidate_models.keys())[0]],
            padding_side="right",
            use_fast=True,
            trust_remote_code=True,
            token=os.environ.get("TOKEN")
        )

        for model in candidate_models:
            if "retro" in model:
                self.load_retro_model(candidate_models[model])
            else:
                self.load_forward_model(candidate_models[model])

        string_template_path = hf_hub_download(
            candidate_models[list(candidate_models.keys())[0]],
            filename="string_template.json",
            token=os.environ.get("TOKEN")
        )
        string_template = json.load(open(string_template_path, 'r'))
        reactant_start_str = string_template['REACTANTS_START_STRING']
        product_start_str = string_template['PRODUCTS_START_STRING']
        end_str = string_template['END_STRING']
        self.data_collator = DataCollatorForCausalLMEval(
            tokenizer=self.tokenizer,
            source_max_len=512,
            target_max_len=512,
            reactant_start_str=reactant_start_str,
            product_start_str=product_start_str,
            end_str=end_str,
        )

    def load_retro_model(self, model_path):
        # The retro model is a LoRA adapter on top of the ChemFM-3B base model.
        config = AutoConfig.from_pretrained(
            "ChemFM/ChemFM-3B",
            trust_remote_code=True,
            token=os.environ.get("TOKEN")
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            "ChemFM/ChemFM-3B",
            config=config,
            trust_remote_code=True,
            device_map=device_map,
            token=os.environ.get("TOKEN")
        )

        # Resize the embedding layer of the base model to match the adapter's tokenizer.
        special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=special_tokens_dict,
            tokenizer=self.tokenizer,
            model=base_model
        )
        base_model.config.pad_token_id = self.tokenizer.pad_token_id

        # Load the adapter model.
        self.retro_model = PeftModel.from_pretrained(
            base_model,
            model_path,
            token=os.environ.get("TOKEN")
        )
        self.retro_model.to("cuda")
        self.retro_model.eval()

    def load_forward_model(self, model_path):
        config = AutoConfig.from_pretrained(
            model_path,
            device_map=device_map,
            trust_remote_code=True,
            token=os.environ.get("TOKEN")
        )
        self.forward_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            config=config,
            device_map=device_map,
            trust_remote_code=True,
            token=os.environ.get("TOKEN")
        )

        # The fine-tuned tokenizer can differ in size from the pretraining tokenizer,
        # and the PAD token still has to be added.
        special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=special_tokens_dict,
            tokenizer=self.tokenizer,
            model=self.forward_model
        )
        self.forward_model.config.pad_token_id = self.tokenizer.pad_token_id
        self.forward_model.to("cuda")
        self.forward_model.eval()

    def predict(self, test_loader, task_type):
        predictions = []
        for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
            generation_prompts = batch['generation_prompts'][0]
            inputs = self.tokenizer(generation_prompts, return_tensors="pt", padding=True, truncation=True)
            inputs.pop('token_type_ids', None)  # generate() does not accept token_type_ids

            # Pick the retro or forward model and decode with beam search (10 beams,
            # 10 returned sequences per input).
            model = self.retro_model if task_type == "retrosynthesis" else self.forward_model
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_length=512,
                    num_return_sequences=10,
                    do_sample=False,
                    num_beams=10,
                    eos_token_id=self.tokenizer.eos_token_id,
                    early_stopping='never',
                    pad_token_id=self.tokenizer.pad_token_id,
                    length_penalty=0.0,
                )

            # Strip the prompt tokens and decode only the generated continuation.
            original_smiles_list = self.tokenizer.batch_decode(
                outputs.detach().cpu().numpy()[:, len(inputs['input_ids'][0]):],
                skip_special_tokens=True
            )
            original_smiles_list = map(lambda x: x.replace(" ", ""), original_smiles_list)

            # Canonicalize the SMILES; invalid candidates are recorded as empty strings.
            canonized_smiles_list = []
            for original_smiles in original_smiles_list:
                mol = Chem.MolFromSmiles(original_smiles)
                if mol is not None:
                    canonized_smiles_list.append(Chem.MolToSmiles(mol))
                else:
                    canonized_smiles_list.append("")
            predictions.append(canonized_smiles_list)

        rank, invalid_rate = compute_rank(predictions)
        return rank

    def predict_single_smiles(self, smiles, task_type):
        # Multi-component (dot-separated) inputs are not supported for retrosynthesis.
        if task_type == "full_retro":
            if "." in smiles:
                return None
        task_type = "retrosynthesis" if task_type == "full_retro" else "synthesis"

        # Canonicalize the input SMILES; bail out if it cannot be parsed.
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        smiles = Chem.MolToSmiles(mol)

        smiles_list = [smiles]
        task_type_list = [task_type]

        df = pd.DataFrame({"src": smiles_list, "task_type": task_type_list})
        test_dataset = Dataset.from_pandas(df)

        # Construct the dataloader (batch size 1; prompts are built by the collator).
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=1,
            collate_fn=self.data_collator,
        )

        rank = self.predict(test_loader, task_type)
        return rank
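

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration, not part of the deployed app).
# The repository ids in `candidate_models` below are hypothetical placeholders;
# the actual adapter / model repos, a Hugging Face access token in the TOKEN
# environment variable, and a CUDA device are all assumed by the classes above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    candidate_models = {
        "retro_model": "your-org/chemfm-retro-adapter",     # hypothetical LoRA adapter repo
        "forward_model": "your-org/chemfm-forward-model",   # hypothetical forward model repo
    }
    predictor = ReactionPredictionModel(candidate_models)

    # Retrosynthesis: rank candidate reactant sets for a single product SMILES.
    rank = predictor.predict_single_smiles("CC(=O)Oc1ccccc1C(=O)O", "full_retro")
    if rank is not None:
        # Higher scores correspond to candidates ranked earlier across the beams.
        for candidate, score in sorted(rank.items(), key=lambda kv: kv[1], reverse=True)[:5]:
            print(f"{score:.3f}\t{candidate}")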