simonduerr commited on
Commit
914ed63
·
1 Parent(s): 9084a83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -96,7 +96,7 @@ def run_protgpt2(startsequence, length, repetitionPenalty, top_k_poolsize, max_s
96
  print("Cleaning up after protGPT2")
97
  print(gpu_usage())
98
  del protgpt2
99
- torch.cuda.empty_cache()
100
  device = cuda.get_current_device()
101
  device.reset()
102
  print(gpu_usage())
@@ -105,6 +105,8 @@ def run_protgpt2(startsequence, length, repetitionPenalty, top_k_poolsize, max_s
105
 
106
  def run_alphafold(startsequence):
107
  print(gpu_usage())
 
 
108
  model_runners = {}
109
  models = ["model_1"] # ,"model_2","model_3","model_4","model_5"]
110
  for model_name in models:
@@ -127,8 +129,8 @@ def run_alphafold(startsequence):
127
  plddts = predict_structure("test", feature_dict, model_runners)
128
  print("Cleaning up after AF2")
129
  print(gpu_usage())
130
- backend = jax.lib.xla_bridge.get_backend()
131
- for buf in backend.live_buffers(): buf.delete()
132
  #device = cuda.get_current_device()
133
  #device.reset()
134
  #print(gpu_usage())
@@ -145,7 +147,7 @@ def update_protGPT2(inp, length,repetitionPenalty, top_k_poolsize, max_seqs):
145
  for i, seq in enumerate(gen_seqs):
146
  s = seq.replace("\n","")
147
  s = "\n".join([s[i:i+70] for i in range(0, len(s), 70)])
148
- sequencestxt +=f">seq{i}\n{seq}\n"
149
  return sequencestxt
150
 
151
 
 
96
  print("Cleaning up after protGPT2")
97
  print(gpu_usage())
98
  del protgpt2
99
+ #torch.cuda.empty_cache()
100
  device = cuda.get_current_device()
101
  device.reset()
102
  print(gpu_usage())
 
105
 
106
  def run_alphafold(startsequence):
107
  print(gpu_usage())
108
+ device = cuda.get_current_device()
109
+ device.reset()
110
  model_runners = {}
111
  models = ["model_1"] # ,"model_2","model_3","model_4","model_5"]
112
  for model_name in models:
 
129
  plddts = predict_structure("test", feature_dict, model_runners)
130
  print("Cleaning up after AF2")
131
  print(gpu_usage())
132
+ #backend = jax.lib.xla_bridge.get_backend()
133
+ #for buf in backend.live_buffers(): buf.delete()
134
  #device = cuda.get_current_device()
135
  #device.reset()
136
  #print(gpu_usage())
 
147
  for i, seq in enumerate(gen_seqs):
148
  s = seq.replace("\n","")
149
  s = "\n".join([s[i:i+70] for i in range(0, len(s), 70)])
150
+ sequencestxt +=f">seq{i}\n{s}\n"
151
  return sequencestxt
152
 
153