Upload drug_generator.py
Browse files- drug_generator.py +143 -0
drug_generator.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Created on Mon May 1 19:41:07 2023
|
4 |
+
|
5 |
+
@author: Sen
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
import subprocess
|
10 |
+
import warnings
|
11 |
+
from tqdm import tqdm
|
12 |
+
import argparse
|
13 |
+
import torch
|
14 |
+
from transformers import AutoTokenizer, GPT2LMHeadModel
|
15 |
+
|
16 |
+
warnings.filterwarnings('ignore')
|
17 |
+
os.environ["http_proxy"] = "http://127.0.0.1:7890"
|
18 |
+
os.environ["https_proxy"] = "http://127.0.0.1:7890"
|
19 |
+
|
20 |
+
|
21 |
+
# Set up command line argument parsing
|
22 |
+
parser = argparse.ArgumentParser()
|
23 |
+
parser.add_argument('-p', type=str, default=None, help='Input the protein amino acid sequence. Default value is None. Only one of -p and -f should be specified.')
|
24 |
+
parser.add_argument('-f', type=str, default=None, help='Input the FASTA file. Default value is None. Only one of -p and -f should be specified.')
|
25 |
+
parser.add_argument('-l', type=str, default='', help='Input the ligand prompt. Default value is an empty string.')
|
26 |
+
parser.add_argument('-n', type=int, default=100, help='Number of output molecules to generate. Default value is 100.')
|
27 |
+
parser.add_argument('-d', type=str, default='cuda', help="Hardware device to use. Default value is 'cuda'.")
|
28 |
+
parser.add_argument('-o', type=str, default='./ligand_output/', help="Output directory for generated molecules. Default value is './ligand_output/'.")
|
29 |
+
|
30 |
+
args = parser.parse_args()
|
31 |
+
|
32 |
+
protein_seq = args.p
|
33 |
+
fasta_file = args.f
|
34 |
+
ligand_prompt = args.l
|
35 |
+
num_generated = args.n
|
36 |
+
device = args.d
|
37 |
+
output_path = args.o
|
38 |
+
|
39 |
+
|
40 |
+
def ifno_mkdirs(dirname):
|
41 |
+
if not os.path.exists(dirname):
|
42 |
+
os.makedirs(dirname)
|
43 |
+
|
44 |
+
ifno_mkdirs(output_path)
|
45 |
+
|
46 |
+
# Function to read in FASTA file
|
47 |
+
def read_fasta_file(file_path):
|
48 |
+
with open(file_path, 'r') as fasta_file:
|
49 |
+
sequence = []
|
50 |
+
|
51 |
+
for line in fasta_file:
|
52 |
+
line = line.strip()
|
53 |
+
if not line.startswith('>'):
|
54 |
+
sequence.append(line)
|
55 |
+
|
56 |
+
protein_sequence = ''.join(sequence)
|
57 |
+
|
58 |
+
return protein_sequence
|
59 |
+
|
60 |
+
# Check if the input is either a protein amino acid sequence or a FASTA file, but not both
|
61 |
+
if (protein_seq is not None) != (fasta_file is not None):
|
62 |
+
if fasta_file is not None:
|
63 |
+
protein_seq = read_fasta_file(fasta_file)
|
64 |
+
else:
|
65 |
+
protein_seq = protein_seq
|
66 |
+
else:
|
67 |
+
print("The input should be either a protein amino acid sequence or a FASTA file, but not both.")
|
68 |
+
|
69 |
+
# Load the tokenizer and the model
|
70 |
+
tokenizer = AutoTokenizer.from_pretrained('liyuesen/druggpt')
|
71 |
+
model = GPT2LMHeadModel.from_pretrained("liyuesen/druggpt")
|
72 |
+
|
73 |
+
# Generate a prompt for the model
|
74 |
+
p_prompt = "<|startoftext|><P>" + protein_seq + "<L>"
|
75 |
+
l_prompt = "" + ligand_prompt
|
76 |
+
prompt = p_prompt + l_prompt
|
77 |
+
print(prompt)
|
78 |
+
|
79 |
+
# Move the model to the specified device
|
80 |
+
model.eval()
|
81 |
+
device = torch.device(device)
|
82 |
+
model.to(device)
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
#Define post-processing function
|
87 |
+
#Define function to generate SDF files from a list of ligand SMILES using OpenBabel
|
88 |
+
def get_sdf(ligand_list,output_path):
|
89 |
+
for ligand in tqdm(ligand_list):
|
90 |
+
filename = output_path + 'ligand_' + ligand +'.sdf'
|
91 |
+
cmd = "obabel -:" + ligand + " -osdf -O " + filename + " --gen3d --forcefield mmff94"# --conformer --nconf 1 --score rmsd
|
92 |
+
#subprocess.check_call(cmd, shell=True)
|
93 |
+
try:
|
94 |
+
# 设置超时时间为 30 秒
|
95 |
+
output = subprocess.check_output(cmd, timeout=10)
|
96 |
+
except subprocess.TimeoutExpired:
|
97 |
+
pass
|
98 |
+
#Define function to filter out empty SDF files
|
99 |
+
def filter_sdf(output_path):
|
100 |
+
filelist = os.listdir(output_path)
|
101 |
+
for filename in filelist:
|
102 |
+
filepath = os.path.join(output_path,filename)
|
103 |
+
with open(filepath,'r') as f:
|
104 |
+
text = f.read()
|
105 |
+
if len(text)<2:
|
106 |
+
os.remove(filepath)
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
+
# Generate molecules
|
112 |
+
generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
|
113 |
+
generated = generated.to(device)
|
114 |
+
|
115 |
+
|
116 |
+
for i in range(100):
|
117 |
+
ligand_list = []
|
118 |
+
sample_outputs = model.generate(
|
119 |
+
generated,
|
120 |
+
#bos_token_id=random.randint(1,30000),
|
121 |
+
do_sample=True,
|
122 |
+
top_k=5,
|
123 |
+
max_length = 1024,
|
124 |
+
top_p=0.6,
|
125 |
+
num_return_sequences=64
|
126 |
+
)
|
127 |
+
|
128 |
+
for i, sample_output in enumerate(sample_outputs):
|
129 |
+
ligand_list.append(tokenizer.decode(sample_output, skip_special_tokens=True).split('<L>')[1])
|
130 |
+
torch.cuda.empty_cache()
|
131 |
+
|
132 |
+
get_sdf(ligand_list,output_path)
|
133 |
+
filter_sdf(output_path)
|
134 |
+
|
135 |
+
if len(os.listdir(output_path))>num_generated:
|
136 |
+
break
|
137 |
+
else:pass
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
|