Loli-Killer commited on
Commit
fd20a3e
Β·
1 Parent(s): 1a696b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -34,8 +34,9 @@ model = proteinbind_new.create_proteinbind(True)
34
 
35
 
36
  def pass_through(torch_output, key: str):
 
37
  input_data = {
38
- key: torch_output,
39
  }
40
  output = model(input_data)
41
  return output[key]
 
34
 
35
 
36
  def pass_through(torch_output, key: str):
37
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
  input_data = {
39
+ key: torch_output.type(torch.float32).to(device)
40
  }
41
  output = model(input_data)
42
  return output[key]