feiyang-cai committed
Update utils.py
utils.py CHANGED
```diff
@@ -59,10 +59,8 @@ class DataCollatorForCausalLMEval(object):
 
     def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
 
-        print(instances)
         srcs = instances[0]['src']
         task_type = instances[0]['task_type']
-        print(task_type)
 
         if task_type == 'retrosynthesis':
             src_start_str = self.product_start_str
```
```diff
@@ -78,7 +76,6 @@ class DataCollatorForCausalLMEval(object):
         data_dict = {
             'generation_prompts': generation_prompts
         }
-        print(data_dict)
         return data_dict
 
 def smart_tokenizer_and_embedding_resize(
```
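For context, the collator's contract is small: it reads `src` and `task_type` from the first instance and returns a dict of generation prompts. Below is a minimal self-contained sketch of that contract; the constructor arguments and the prompt layout are assumptions for illustration, not the repo's exact code.

```python
from typing import Dict, List, Sequence

class EvalCollatorSketch:
    """Illustrative stand-in for DataCollatorForCausalLMEval's __call__ contract."""

    def __init__(self, product_start_str: str, reactant_start_str: str):
        # Start strings mirror those loaded from string_template.json.
        self.product_start_str = product_start_str
        self.reactant_start_str = reactant_start_str

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, List[str]]:
        src = instances[0]['src']
        task_type = instances[0]['task_type']
        # Retrosynthesis prompts begin from the product; forward prediction
        # from the reactants (assumed layout, for illustration only).
        start = self.product_start_str if task_type == 'retrosynthesis' \
            else self.reactant_start_str
        return {'generation_prompts': [start + src]}
```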
```diff
@@ -131,7 +128,6 @@ class ReactionPredictionModel():
         )
         self.load_forward_model(candidate_models[model])
 
-        print(self.forward_model.device, self.retro_model.device)
         string_template_path = hf_hub_download(candidate_models[list(candidate_models.keys())[0]], filename="string_template.json", token = os.environ.get("TOKEN"))
         string_template = json.load(open(string_template_path, 'r'))
         reactant_start_str = string_template['REACTANTS_START_STRING']
```
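The template download above uses the standard `huggingface_hub` API. A minimal standalone sketch of that step, where `"user/model-repo"` is a placeholder (the real code takes the repo id from `candidate_models`):

```python
import json
import os

from huggingface_hub import hf_hub_download

# "user/model-repo" is a placeholder repo id; TOKEN is read from the
# environment, as in the original code.
path = hf_hub_download("user/model-repo",
                       filename="string_template.json",
                       token=os.environ.get("TOKEN"))
with open(path, "r") as f:
    string_template = json.load(f)
reactant_start_str = string_template["REACTANTS_START_STRING"]
```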
```diff
@@ -220,8 +216,6 @@ class ReactionPredictionModel():
 
         if task_type == "retrosynthesis":
             inputs = {k: v.to(self.retro_model.device) for k, v in inputs.items()}
-            print(inputs)
-            print(self.retro_model.device)
             with torch.no_grad():
                 outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,
                                                     do_sample=False, num_beams=10,
```
```diff
@@ -232,8 +226,6 @@ class ReactionPredictionModel():
                                                     )
         else:
             inputs = {k: v.to(self.forward_model.device) for k, v in inputs.items()}
-            print(inputs)
-            print(self.forward_model.device)
             with torch.no_grad():
                 outputs = self.forward_model.generate(**inputs, max_length=512, num_return_sequences=10,
                                                       do_sample=False, num_beams=10,
```
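Both branches run the same deterministic beam search (10 beams, 10 returned sequences); only the model differs. For decoder-only models, `generate` returns the prompt tokens followed by the continuation, which is why the decoding step in the next hunk slices off the first `len(inputs['input_ids'][0])` tokens. A minimal sketch of that pattern with a generic causal LM, where `"gpt2"` is only a placeholder checkpoint:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("some prompt", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_length=512,
                             num_return_sequences=10,
                             do_sample=False, num_beams=10)

# Each of the 10 rows is prompt + continuation; strip the prompt before decoding.
prompt_len = inputs["input_ids"].shape[1]
texts = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
```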
```diff
@@ -243,11 +235,9 @@ class ReactionPredictionModel():
                                                       length_penalty=0.0,
                                                       )
 
-        print(outputs)
         original_smiles_list = self.tokenizer.batch_decode(outputs.detach().cpu().numpy()[:, len(inputs['input_ids'][0]):],
                                                            skip_special_tokens=True)
         original_smiles_list = map(lambda x: x.replace(" ", ""), original_smiles_list)
-        print(original_smiles_list)
         # canonize the SMILES
         canonized_smiles_list = []
         temp = []
```
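Note that the removed `print(original_smiles_list)` would only have shown a `<map object ...>`, since Python 3's `map` is lazy, so little debugging value is lost here. The "canonize the SMILES" step that follows is standard practice before comparing predictions against ground truth, because one molecule has many valid SMILES strings. A minimal sketch with RDKit; the repo may canonicalize differently:

```python
from rdkit import Chem

def canonicalize(smiles: str):
    """Return the canonical SMILES, or None for unparseable strings."""
    mol = Chem.MolFromSmiles(smiles)
    return Chem.MolToSmiles(mol) if mol is not None else None

# Invalid predictions map to None, which lets an invalid rate be computed later.
print(canonicalize("C(C)O"))          # -> "CCO"
print(canonicalize("not-a-smiles"))   # -> None
```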
```diff
@@ -262,7 +252,6 @@ class ReactionPredictionModel():
             predictions.append(canonized_smiles_list)
 
         rank, invalid_rate = compute_rank(predictions)
-        print(predictions, rank)
         return rank
 
     def predict_single_smiles(self, smiles, task_type):
```
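Every line this commit removes is a debug `print`. If that visibility is still wanted without cluttering the Space's logs, the usual alternative is the standard `logging` module, where debug output is gated by level rather than deleted. A minimal sketch, not part of the commit:

```python
import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)  # raise to DEBUG to re-enable output

def run_step(inputs):
    # Equivalent of the removed print(inputs): silent unless level <= DEBUG.
    logger.debug("generation inputs: %s", inputs)
    return inputs

run_step({"input_ids": [[1, 2, 3]]})
```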