liyuesen committed on
Commit
5c3b926
·
1 Parent(s): d426ea6

Delete drug_generator.py

Browse files
Files changed (1) hide show
  1. drug_generator.py +0 -140
drug_generator.py DELETED
@@ -1,140 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- Created on Mon May 1 19:41:07 2023
4
-
5
- @author: Sen
6
- """
7
-
8
- import os
9
- import subprocess
10
- import warnings
11
- from tqdm import tqdm
12
- import argparse
13
- import torch
14
- from transformers import AutoTokenizer, GPT2LMHeadModel
15
-
16
- warnings.filterwarnings('ignore')
17
- #Sometimes, using Hugging Face may require a proxy.
18
- #os.environ["http_proxy"] = "http://your.proxy.server:port"
19
- #os.environ["https_proxy"] = "http://your.proxy.server:port"
20
-
21
-
22
# Command-line interface: each row below is (flag, type, default, help text).
parser = argparse.ArgumentParser()
_cli_options = (
    ('-p', str, None, 'Input the protein amino acid sequence. Default value is None. Only one of -p and -f should be specified.'),
    ('-f', str, None, 'Input the FASTA file. Default value is None. Only one of -p and -f should be specified.'),
    ('-l', str, '', 'Input the ligand prompt. Default value is an empty string.'),
    ('-n', int, 100, 'Number of output molecules to generate. Default value is 100.'),
    ('-d', str, 'cuda', "Hardware device to use. Default value is 'cuda'."),
    ('-o', str, './ligand_output/', "Output directory for generated molecules. Default value is './ligand_output/'."),
)
for _flag, _type, _default, _help in _cli_options:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)

args = parser.parse_args()

# Unpack the parsed options into the module-level names the rest of the
# script reads.
protein_seq, fasta_file, ligand_prompt = args.p, args.f, args.l
num_generated, device, output_path = args.n, args.d, args.o
39
-
40
-
41
def ifno_mkdirs(dirname):
    """Create directory *dirname*, including missing parents, if it is absent.

    Args:
        dirname: Path of the directory that must exist afterwards.

    Raises:
        FileExistsError: If *dirname* already exists but is not a directory.
    """
    # exist_ok=True replaces the separate "exists?" check, so there is no
    # window in which another process can create the directory between the
    # check and the creation.
    os.makedirs(dirname, exist_ok=True)
44
-
45
- ifno_mkdirs(output_path)
46
-
47
- # Function to read in FASTA file
48
def read_fasta_file(file_path):
    """Read a FASTA file and return its sequence data as a single string.

    Lines are stripped of surrounding whitespace; lines starting with '>'
    (record headers) are dropped and all remaining lines are concatenated in
    file order. A file holding several records therefore yields one merged
    sequence.

    Args:
        file_path: Path of the FASTA file to read.

    Returns:
        The concatenated non-header lines of the file.
    """
    # The handle gets its own name so it does not shadow the module-level
    # ``fasta_file`` option.
    with open(file_path, 'r') as handle:
        stripped_lines = (line.strip() for line in handle)
        return ''.join(
            line for line in stripped_lines if not line.startswith('>')
        )
60
-
61
# Exactly one of -p (raw sequence) and -f (FASTA file) must be supplied.
# Abort on invalid input instead of only printing a message: continuing with
# neither option fails later when the prompt is built from None, and
# continuing with both would silently ignore the FASTA file.
if (protein_seq is not None) == (fasta_file is not None):
    raise SystemExit("The input should be either a protein amino acid sequence or a FASTA file, but not both.")

# When a FASTA file was given, load the sequence from it; otherwise the
# sequence passed with -p is used unchanged.
if fasta_file is not None:
    protein_seq = read_fasta_file(fasta_file)
69
-
70
# Fetch the tokenizer and the language-model weights for 'liyuesen/druggpt'.
tokenizer = AutoTokenizer.from_pretrained('liyuesen/druggpt')
model = GPT2LMHeadModel.from_pretrained("liyuesen/druggpt")

# Assemble the prompt: the protein sequence wrapped in its markers, followed
# by the optional ligand prefix the user supplied.
p_prompt = "".join(("<|startoftext|><P>", protein_seq, "<L>"))
l_prompt = "".join([ligand_prompt])
prompt = "".join((p_prompt, l_prompt))
print(prompt)

# Switch to inference mode and place the model on the requested device.
device = torch.device(device)
model.eval()
model.to(device)
84
-
85
-
86
-
87
- #Define post-processing function
88
- #Define function to generate SDF files from a list of ligand SMILES using OpenBabel
89
def get_sdf(ligand_list,output_path):
    """Write one SDF file per ligand SMILES string by invoking OpenBabel.

    For every entry of *ligand_list* the ``obabel`` executable is run with
    ``--gen3d --forcefield mmff94`` and the result is written to
    ``<output_path>/ligand_<SMILES>.sdf``. A conversion that exceeds the
    10-second timeout is skipped.

    Args:
        ligand_list: Iterable of ligand SMILES strings.
        output_path: Directory that receives the generated SDF files.

    Raises:
        subprocess.CalledProcessError: If ``obabel`` exits with a non-zero
            status.
    """
    for ligand in tqdm(ligand_list):
        # os.path.join keeps the file inside output_path even when the
        # directory was given without a trailing separator.
        # NOTE(review): the SMILES string is used verbatim in the file name;
        # strings containing path separators ('/' or '\\') will not map to a
        # plain file in output_path - confirm whether that needs sanitising.
        filename = os.path.join(output_path, 'ligand_' + ligand + '.sdf')
        # Pass the command as an argument list: without shell=True a single
        # command string is treated as the program name on POSIX, and the
        # list form also keeps SMILES characters away from any shell.
        cmd = [
            'obabel', '-:' + ligand,
            '-osdf', '-O', filename,
            '--gen3d', '--forcefield', 'mmff94',
        ]
        try:
            subprocess.check_output(cmd, timeout=10)
        except subprocess.TimeoutExpired:
            # Best effort: give up on ligands whose 3D generation is too slow.
            continue
98
- #Define function to filter out empty SDF files
99
def filter_sdf(output_path):
    """Delete files in *output_path* that contain fewer than two characters.

    Such files are treated as empty SDF outputs. Entries that are not regular
    files (for example sub-directories) are left untouched.

    Args:
        output_path: Directory whose files are inspected.
    """
    for filename in os.listdir(output_path):
        filepath = os.path.join(output_path, filename)
        if not os.path.isfile(filepath):
            # Opening a directory would raise; only plain files are candidates.
            continue
        with open(filepath, 'r') as f:
            # Two characters are enough to decide; no need to load whole files.
            head = f.read(2)
        if len(head) < 2:
            os.remove(filepath)
107
-
108
-
109
-
110
-
111
# Generate molecules
# Encode the prompt once as a batch of one and place it on the model's device.
generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
generated = generated.to(device)

# Sample in rounds of 64 sequences, for at most 100 rounds, until the output
# directory holds the requested number of files.
for _ in range(100):
    sample_outputs = model.generate(
        generated,
        do_sample=True,
        top_k=5,
        max_length=1024,
        top_p=0.6,
        num_return_sequences=64,
    )

    # Keep only the text that follows the '<L>' marker of each decoded sample.
    ligand_list = [
        tokenizer.decode(sample_output, skip_special_tokens=True).split('<L>')[1]
        for sample_output in sample_outputs
    ]
    torch.cuda.empty_cache()

    get_sdf(ligand_list,output_path)
    filter_sdf(output_path)

    # Stop as soon as the requested count is reached; the previous strict '>'
    # comparison kept sampling until one file more than requested existed.
    if len(os.listdir(output_path)) >= num_generated:
        break
138
-
139
-
140
-