Spaces: Running on Zero
feiyang-cai committed: Update utils.py
utils.py CHANGED
@@ -208,18 +208,19 @@ class ReactionPredictionModel():
         self.forward_model.to("cuda")
 
     @spaces.GPU(duration=20)
-    def predict(self, test_loader):
+    def predict(self, test_loader, task_type):
         predictions = []
         for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
             with torch.no_grad():
                 generation_prompts = batch['generation_prompts'][0]
                 inputs = self.tokenizer(generation_prompts, return_tensors="pt", padding=True, truncation=True)
-                inputs = {k: v.to(self.retro_model.device) for k, v in inputs.items()}
-                print(inputs)
-                print(self.forward_model.device)
-                print(self.retro_model.device)
                 del inputs['token_type_ids']
+
                 if task_type == "retrosynthesis":
+                    self.retro_model.to("cuda")
+                    inputs = {k: v.to(self.retro_model.device) for k, v in inputs.items()}
+                    print(inputs)
+                    print(self.retro_model.device)
                     outputs = self.retro_model.generate(**inputs, max_length=512, num_return_sequences=10,
                                                         do_sample=False, num_beams=10,
                                                         eos_token_id=self.tokenizer.eos_token_id,
@@ -228,6 +229,10 @@ class ReactionPredictionModel():
                                                         length_penalty=0.0,
                     )
                 else:
+                    self.forward_model.to("cuda")
+                    inputs = {k: v.to(self.forward_model.device) for k, v in inputs.items()}
+                    print(inputs)
+                    print(self.forward_model.device)
                     outputs = self.forward_model.generate(**inputs, max_length=512, num_return_sequences=10,
                                                           do_sample=False, num_beams=10,
                                                           eos_token_id=self.tokenizer.eos_token_id,
@@ -281,7 +286,7 @@ class ReactionPredictionModel():
             collate_fn=self.data_collator,
         )
 
-        rank = self.predict(test_loader)
+        rank = self.predict(test_loader, task_type)
 
         return rank
 
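For context, the pattern this commit moves to is the usual one for Spaces running on ZeroGPU: a CUDA device is only available inside an @spaces.GPU-decorated call, so each branch moves the model it actually needs onto the GPU there, then sends the tokenized inputs to that model's device, and task_type is now passed in explicitly instead of being read from an enclosing scope. A minimal sketch of that routing, assuming two Hugging Face models and a shared tokenizer; the run_generation helper below is illustrative and not part of this repo:

import torch

def run_generation(tokenizer, forward_model, retro_model, generation_prompts, task_type):
    # Tokenize the prompts; generate() does not accept token_type_ids, so drop it if present.
    inputs = tokenizer(generation_prompts, return_tensors="pt", padding=True, truncation=True)
    inputs.pop("token_type_ids", None)
    # Route to the model the task needs and move it to the GPU lazily:
    # on ZeroGPU the CUDA device only exists inside the @spaces.GPU call.
    model = retro_model if task_type == "retrosynthesis" else forward_model
    model.to("cuda")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # Beam search with 10 beams, returning all 10 ranked candidates, as in the diff.
    with torch.no_grad():
        return model.generate(**inputs, max_length=512, num_beams=10,
                              num_return_sequences=10, do_sample=False,
                              eos_token_id=tokenizer.eos_token_id)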