from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import os
from typing import Optional, Dict, Sequence
import transformers
from peft import PeftModel
import torch
from dataclasses import dataclass, field
from huggingface_hub import hf_hub_download
import json
import pandas as pd
from datasets import Dataset
from tqdm import tqdm
import spaces
from rdkit import RDLogger, Chem

# Suppress RDKit INFO messages
RDLogger.DisableLog('rdApp.*')

DEFAULT_PAD_TOKEN = "[PAD]"
device_map = "cuda"
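# compute_rank fuses the beam candidates of every query into a single ranking:
# each candidate SMILES gets a reciprocal-rank score 1 / (alpha * k + 1) for its
# position k in the deduplicated beam, invalid (empty) candidates are pushed to
# the bottom of each beam, and their count per beam position is tracked.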
def compute_rank(prediction, raw=False, alpha=1.0):
    valid_score = [[k for k in range(len(prediction[j]))] for j in range(len(prediction))]
    invalid_rates = [0 for k in range(len(prediction[0]))]
    rank = {}
    highest = {}

    for j in range(len(prediction)):
        for k in range(len(prediction[j])):
            if prediction[j][k] == "":
                valid_score[j][k] = 10 + 1
                invalid_rates[k] += 1
        de_error = [i[0] for i in sorted(list(zip(prediction[j], valid_score[j])), key=lambda x: x[1]) if i[0] != ""]
        prediction[j] = list(set(de_error))
        prediction[j].sort(key=de_error.index)

        for k, data in enumerate(prediction[j]):
            if data in rank:
                rank[data] += 1 / (alpha * k + 1)
            else:
                rank[data] = 1 / (alpha * k + 1)
            if data in highest:
                highest[data] = min(k, highest[data])
            else:
                highest[data] = k

    return rank, invalid_rates
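# Collator used at inference time: it turns a single {src, task_type} record into
# the generation prompt <src_start>SRC<end_str><tgt_start>, picking the reactant /
# product start strings according to the task direction.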
@dataclass
class DataCollatorForCausalLMEval(object):
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int
    target_max_len: int
    reactant_start_str: str
    product_start_str: str
    end_str: str

    def augment_molecule(self, molecule: str) -> str:
        # relies on a SMILES enumerator (self.sme) that is not set up in this app; unused here
        return self.sme.augment([molecule])[0]

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        srcs = instances[0]['src']
        task_type = instances[0]['task_type']

        if task_type == 'retrosynthesis':
            src_start_str = self.product_start_str
            tgt_start_str = self.reactant_start_str
        else:
            src_start_str = self.reactant_start_str
            tgt_start_str = self.product_start_str

        generation_prompts = []
        generation_prompt = f"{src_start_str}{srcs}{self.end_str}{tgt_start_str}"
        generation_prompts.append(generation_prompt)

        data_dict = {
            'generation_prompts': generation_prompts
        }
        return data_dict
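# Add new (special) tokens to the tokenizer, grow the model's embedding matrix to
# match, and initialize the new embedding rows with the mean of the existing ones.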
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    non_special_tokens=None,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + tokenizer.add_tokens(non_special_tokens)
    num_old_tokens = model.get_input_embeddings().weight.shape[0]
    # measure the growth against the model's current embedding size rather than the
    # tokenizer's previous size
    num_new_tokens = len(tokenizer) - num_old_tokens
    if num_new_tokens == 0:
        return

    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings_data = model.get_input_embeddings().weight.data
        input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
        input_embeddings_data[-num_new_tokens:] = input_embeddings_avg

    print(f"Resized tokenizer and embedding from {num_old_tokens} to {len(tokenizer)} tokens.")
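# Wrapper that loads both direction-specific checkpoints (a LoRA adapter on
# ChemFM-3B for retrosynthesis and a fully fine-tuned model for forward synthesis),
# shares one tokenizer and prompt template between them, and exposes ranked
# predictions for a single input SMILES.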
class ReactionPredictionModel():
    def __init__(self, candidate_models):
        for model in candidate_models:
            if "retro" in model:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    candidate_models[list(candidate_models.keys())[0]],
                    padding_side="right",
                    use_fast=True,
                    trust_remote_code=True,
                    token=os.environ.get("TOKEN")
                )
                self.load_retro_model(candidate_models[model])
            else:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    candidate_models[list(candidate_models.keys())[0]],
                    padding_side="right",
                    use_fast=True,
                    trust_remote_code=True,
                    token=os.environ.get("TOKEN")
                )
                self.load_forward_model(candidate_models[model])

        string_template_path = hf_hub_download(candidate_models[list(candidate_models.keys())[0]], filename="string_template.json", token=os.environ.get("TOKEN"))
        string_template = json.load(open(string_template_path, 'r'))
        reactant_start_str = string_template['REACTANTS_START_STRING']
        product_start_str = string_template['PRODUCTS_START_STRING']
        end_str = string_template['END_STRING']

        self.data_collator = DataCollatorForCausalLMEval(
            tokenizer=self.tokenizer,
            source_max_len=512,
            target_max_len=512,
            reactant_start_str=reactant_start_str,
            product_start_str=product_start_str,
            end_str=end_str,
        )
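    # Retrosynthesis checkpoint: a LoRA adapter applied to the ChemFM-3B base model
    # after the PAD token is added and the base embeddings are resized to match.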
    def load_retro_model(self, model_path):
        # our retro model is a LoRA adapter on top of the ChemFM-3B base model
        config = AutoConfig.from_pretrained(
            "ChemFM/ChemFM-3B",
            trust_remote_code=True,
            token=os.environ.get("TOKEN")
        )

        base_model = AutoModelForCausalLM.from_pretrained(
            "ChemFM/ChemFM-3B",
            config=config,
            trust_remote_code=True,
            device_map=device_map,
            token=os.environ.get("TOKEN")
        )

        # we should resize the embedding layer of the base model to match the adapter's tokenizer
        special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=special_tokens_dict,
            tokenizer=self.tokenizer,
            model=base_model
        )
        base_model.config.pad_token_id = self.tokenizer.pad_token_id

        # load the adapter model
        self.retro_model = PeftModel.from_pretrained(
            base_model,
            model_path,
            token=os.environ.get("TOKEN")
        )

        self.retro_model.to("cuda")
        self.retro_model.eval()
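    # Forward-synthesis checkpoint: a fully fine-tuned causal LM loaded directly from
    # model_path, with the same PAD-token / embedding-resize treatment.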
    def load_forward_model(self, model_path):
        config = AutoConfig.from_pretrained(
            model_path,
            device_map=device_map,
            trust_remote_code=True,
            token=os.environ.get("TOKEN")
        )

        self.forward_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            config=config,
            device_map=device_map,
            trust_remote_code=True,
            token=os.environ.get("TOKEN")
        )

        # the fine-tuned tokenizer may differ in size from the pretraining tokenizer,
        # and we also need to add the PAD token
        special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=special_tokens_dict,
            tokenizer=self.tokenizer,
            model=self.forward_model
        )
        self.forward_model.config.pad_token_id = self.tokenizer.pad_token_id
        self.forward_model.to("cuda")
        self.forward_model.eval()
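    # Beam-search generation over the collated prompts: 10 beams, 10 returned
    # sequences per input; the prompt tokens are stripped from the decoded output,
    # each candidate is canonicalized with RDKit, and the beams are fused into a
    # single ranking with compute_rank.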
    def predict(self, test_loader, task_type):
        predictions = []

        for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
            generation_prompts = batch['generation_prompts'][0]
            inputs = self.tokenizer(generation_prompts, return_tensors="pt", padding=True, truncation=True)
            # the model does not use token_type_ids; drop them if the tokenizer produced them
            inputs.pop('token_type_ids', None)

            if task_type == "retrosynthesis":
                inputs = {k: v.to(self.retro_model.device) for k, v in inputs.items()}
                with torch.no_grad():
                    outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,
                                                        do_sample=False, num_beams=10,
                                                        eos_token_id=self.tokenizer.eos_token_id,
                                                        early_stopping='never',
                                                        pad_token_id=self.tokenizer.pad_token_id,
                                                        length_penalty=0.0,
                                                        )
            else:
                inputs = {k: v.to(self.forward_model.device) for k, v in inputs.items()}
                with torch.no_grad():
                    outputs = self.forward_model.generate(**inputs, max_length=512, num_return_sequences=10,
                                                          do_sample=False, num_beams=10,
                                                          eos_token_id=self.tokenizer.eos_token_id,
                                                          early_stopping='never',
                                                          pad_token_id=self.tokenizer.pad_token_id,
                                                          length_penalty=0.0,
                                                          )

            # keep only the newly generated tokens and strip the spaces the tokenizer inserts
            original_smiles_list = self.tokenizer.batch_decode(outputs.detach().cpu().numpy()[:, len(inputs['input_ids'][0]):],
                                                               skip_special_tokens=True)
            original_smiles_list = map(lambda x: x.replace(" ", ""), original_smiles_list)

            # canonicalize the SMILES; candidates RDKit cannot parse become ""
            canonized_smiles_list = []
            for original_smiles in original_smiles_list:
                try:
                    canonized_smiles_list.append(Chem.MolToSmiles(Chem.MolFromSmiles(original_smiles)))
                except Exception:
                    canonized_smiles_list.append("")
            predictions.append(canonized_smiles_list)

        rank, invalid_rate = compute_rank(predictions)
        return rank
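    # Entry point used by the app: validate and canonicalize the input SMILES, wrap
    # it in a one-row dataset / dataloader, and return the ranked predictions
    # (None for an unparsable SMILES or a multi-fragment retrosynthesis input).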
    def predict_single_smiles(self, smiles, task_type):
        if task_type == "full_retro":
            if "." in smiles:
                return None
        task_type = "retrosynthesis" if task_type == "full_retro" else "synthesis"

        # canonicalize the smiles
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        smiles = Chem.MolToSmiles(mol)

        smiles_list = [smiles]
        task_type_list = [task_type]

        df = pd.DataFrame({"src": smiles_list, "task_type": task_type_list})
        test_dataset = Dataset.from_pandas(df)
        # construct the dataloader
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=1,
            collate_fn=self.data_collator,
        )

        rank = self.predict(test_loader, task_type)
        return rank
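# Minimal usage sketch; the repo IDs below are placeholders for the actual
# retro-adapter and forward-model checkpoints used by this Space:
#
#   candidate_models = {
#       "retro": "<retro-lora-adapter-repo>",
#       "forward": "<forward-model-repo>",
#   }
#   model = ReactionPredictionModel(candidate_models)
#   rank = model.predict_single_smiles("CCO", "full_retro")  # dict: SMILES -> fused score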