feiyang-cai committed
Commit 4dab948 · verified · 1 Parent(s): f0ad687

Update utils.py

Files changed (1)
  1. utils.py +0 -11
utils.py CHANGED
@@ -59,10 +59,8 @@ class DataCollatorForCausalLMEval(object):

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:

-        print(instances)
        srcs = instances[0]['src']
        task_type = instances[0]['task_type']
-        print(task_type)

        if task_type == 'retrosynthesis':
            src_start_str = self.product_start_str
@@ -78,7 +76,6 @@ class DataCollatorForCausalLMEval(object):
        data_dict = {
            'generation_prompts': generation_prompts
        }
-        print(data_dict)
        return data_dict

def smart_tokenizer_and_embedding_resize(
@@ -131,7 +128,6 @@ class ReactionPredictionModel():
        )
        self.load_forward_model(candidate_models[model])

-        print(self.forward_model.device, self.retro_model.device)
        string_template_path = hf_hub_download(candidate_models[list(candidate_models.keys())[0]], filename="string_template.json", token = os.environ.get("TOKEN"))
        string_template = json.load(open(string_template_path, 'r'))
        reactant_start_str = string_template['REACTANTS_START_STRING']
@@ -220,8 +216,6 @@ class ReactionPredictionModel():

        if task_type == "retrosynthesis":
            inputs = {k: v.to(self.retro_model.device) for k, v in inputs.items()}
-            print(inputs)
-            print(self.retro_model.device)
            with torch.no_grad():
                outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,
                                                    do_sample=False, num_beams=10,
@@ -232,8 +226,6 @@ class ReactionPredictionModel():
            )
        else:
            inputs = {k: v.to(self.forward_model.device) for k, v in inputs.items()}
-            print(inputs)
-            print(self.forward_model.device)
            with torch.no_grad():
                outputs = self.forward_model.generate(**inputs, max_length=512, num_return_sequences=10,
                                                      do_sample=False, num_beams=10,
@@ -243,11 +235,9 @@ class ReactionPredictionModel():
                length_penalty=0.0,
            )

-        print(outputs)
        original_smiles_list = self.tokenizer.batch_decode(outputs.detach().cpu().numpy()[:, len(inputs['input_ids'][0]):],
                                                           skip_special_tokens=True)
        original_smiles_list = map(lambda x: x.replace(" ", ""), original_smiles_list)
-        print(original_smiles_list)
        # canonize the SMILES
        canonized_smiles_list = []
        temp = []
@@ -262,7 +252,6 @@ class ReactionPredictionModel():
            predictions.append(canonized_smiles_list)

        rank, invalid_rate = compute_rank(predictions)
-        print(predictions, rank)
        return rank

    def predict_single_smiles(self, smiles, task_type):
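Note: all eleven deleted lines are ad-hoc print() debugging. A minimal sketch of the usual alternative, keeping such tracing behind the standard logging module so it stays silent in production and can be re-enabled without editing the code; the logger name and helper below are illustrative, not part of this repo:

```python
import logging

# Illustrative replacement for the removed print() calls: a module-level
# logger whose output is off by default.
logger = logging.getLogger("reaction_prediction")

def log_generation_inputs(inputs, device):
    # No-op unless logging is configured at DEBUG level, e.g.:
    #   logging.basicConfig(level=logging.DEBUG)
    logger.debug("inputs=%s device=%s", inputs, device)
```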
 
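Note: the body behind the `# canonize the SMILES` comment falls outside the hunk context, so the exact implementation isn't shown in this diff. Canonicalization of this kind is commonly done with RDKit; a sketch under that assumption, with a hypothetical helper name:

```python
from rdkit import Chem

def canonicalize(smiles):
    """Return RDKit's canonical SMILES, or None if the string doesn't parse."""
    mol = Chem.MolFromSmiles(smiles)   # None for invalid SMILES
    return Chem.MolToSmiles(mol) if mol is not None else None

# Equivalent inputs collapse to one canonical form:
assert canonicalize("OCC") == canonicalize("C(O)C") == "CCO"
assert canonicalize("not-a-smiles") is None
```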