feiyang-cai commited on
Commit
d87c5bb
·
verified ·
1 Parent(s): cd388a2

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +11 -6
utils.py CHANGED
@@ -208,18 +208,19 @@ class ReactionPredictionModel():
208
  self.forward_model.to("cuda")
209
 
210
  @spaces.GPU(duration=20)
211
- def predict(self, test_loader):
212
  predictions = []
213
  for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
214
  with torch.no_grad():
215
  generation_prompts = batch['generation_prompts'][0]
216
  inputs = self.tokenizer(generation_prompts, return_tensors="pt", padding=True, truncation=True)
217
- inputs = {k: v.to(self.retro_model.device) for k, v in inputs.items()}
218
- print(inputs)
219
- print(self.forward_model.device)
220
- print(self.retro_model.device)
221
  del inputs['token_type_ids']
 
222
  if task_type == "retrosynthesis":
 
 
 
 
223
  outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,
224
  do_sample=False, num_beams=10,
225
  eos_token_id=self.tokenizer.eos_token_id,
@@ -228,6 +229,10 @@ class ReactionPredictionModel():
228
  length_penalty=0.0,
229
  )
230
  else:
 
 
 
 
231
  outputs = self.forward_model.generate(**inputs, max_length=512, num_return_sequences=10,
232
  do_sample=False, num_beams=10,
233
  eos_token_id=self.tokenizer.eos_token_id,
@@ -281,7 +286,7 @@ class ReactionPredictionModel():
281
  collate_fn=self.data_collator,
282
  )
283
 
284
- rank = self.predict(test_loader)
285
 
286
  return rank
287
 
 
208
  self.forward_model.to("cuda")
209
 
210
  @spaces.GPU(duration=20)
211
+ def predict(self, test_loader, task_type):
212
  predictions = []
213
  for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
214
  with torch.no_grad():
215
  generation_prompts = batch['generation_prompts'][0]
216
  inputs = self.tokenizer(generation_prompts, return_tensors="pt", padding=True, truncation=True)
 
 
 
 
217
  del inputs['token_type_ids']
218
+
219
  if task_type == "retrosynthesis":
220
+ self.retro_model.to("cuda")
221
+ inputs = {k: v.to(self.retro_model.device) for k, v in inputs.items()}
222
+ print(inputs)
223
+ print(self.retro_model.device)
224
  outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,
225
  do_sample=False, num_beams=10,
226
  eos_token_id=self.tokenizer.eos_token_id,
 
229
  length_penalty=0.0,
230
  )
231
  else:
232
+ self.forward_model.to("cuda")
233
+ inputs = {k: v.to(self.forward_model.device) for k, v in inputs.items()}
234
+ print(inputs)
235
+ print(self.forward_model.device)
236
  outputs = self.forward_model.generate(**inputs, max_length=512, num_return_sequences=10,
237
  do_sample=False, num_beams=10,
238
  eos_token_id=self.tokenizer.eos_token_id,
 
286
  collate_fn=self.data_collator,
287
  )
288
 
289
+ rank = self.predict(test_loader, task_type)
290
 
291
  return rank
292