fffiloni commited on
Commit
917391d
·
verified ·
1 Parent(s): dd551fd

setup optimization if model do not changes

Browse files
Files changed (1) hide show
  1. app.py +26 -11
app.py CHANGED
@@ -50,19 +50,20 @@ def clean_dir(save_dir):
50
  else:
51
  print(f"{save_dir} does not exist.")
52
 
53
- def start_over(gallery_state, loaded_model_setup):
54
  torch.cuda.empty_cache() # Free up cached memory
55
  gc.collect()
56
  if gallery_state is not None:
57
  gallery_state = None
58
- if loaded_model_setup is not None:
59
- loaded_model_setup = None # Reset loaded model setup to prevent re-triggering old state
60
- return gallery_state, None, None, gr.update(visible=False), loaded_model_setup
61
 
62
- def setup_model(prompt, model, seed, num_iterations, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate, progress=gr.Progress(track_tqdm=True)):
63
  gr.Info(f"Loading {model} model ...")
 
64
  if prompt is None or prompt == "":
65
  raise gr.Error("You forgot to provide a prompt !")
 
 
66
 
67
  """Clear CUDA memory before starting the training."""
68
  torch.cuda.empty_cache() # Free up cached memory
@@ -101,10 +102,24 @@ def setup_model(prompt, model, seed, num_iterations, enable_hps, hps_w, enable_i
101
  args.enable_multi_apply= True
102
  args.multi_step_model = "flux"
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  try:
105
- args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args)
106
- loaded_setup = [args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings]
107
- return f"{model} model loaded succesfully !", loaded_setup
108
 
109
  except Exception as e:
110
  print(f"Unexpected Error: {e}")
@@ -335,11 +350,11 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo:
335
 
336
  submit_btn.click(
337
  fn = start_over,
338
- inputs =[gallery_state, loaded_model_setup], # Reset loaded model setup as well
339
- outputs = [gallery_state, output_image, status, iter_gallery, loaded_model_setup] # Ensure loaded_model_setup is reset
340
  ).then(
341
  fn = setup_model,
342
- inputs = [prompt, chosen_model, seed, n_iter, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate],
343
  outputs = [model_status, loaded_model_setup] # Load the new setup into the state
344
  ).then(
345
  fn = generate_image,
 
50
  else:
51
  print(f"{save_dir} does not exist.")
52
 
53
+ def start_over(gallery_state):
54
  torch.cuda.empty_cache() # Free up cached memory
55
  gc.collect()
56
  if gallery_state is not None:
57
  gallery_state = None
58
+ return gallery_state, None, None, gr.update(visible=False)
 
 
59
 
60
+ def setup_model(loaded_model_setup, prompt, model, seed, num_iterations, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate, progress=gr.Progress(track_tqdm=True)):
61
  gr.Info(f"Loading {model} model ...")
62
+
63
  if prompt is None or prompt == "":
64
  raise gr.Error("You forgot to provide a prompt !")
65
+
66
+ print(f"LOADED_MODEL SETUP: {loaded_model_setup}")
67
 
68
  """Clear CUDA memory before starting the training."""
69
  torch.cuda.empty_cache() # Free up cached memory
 
102
  args.enable_multi_apply= True
103
  args.multi_step_model = "flux"
104
 
105
+ # Check if args are the same as the loaded_model_setup except for the prompt
106
+ if loaded_model_setup and hasattr(loaded_model_setup[0], '__dict__'):
107
+ previous_args = loaded_model_setup[0]
108
+
109
+ # Exclude 'prompt' from comparison
110
+ new_args_dict = {k: v for k, v in args.__dict__.items() if k != 'prompt'}
111
+ prev_args_dict = {k: v for k, v in previous_args.__dict__.items() if k != 'prompt'}
112
+
113
+ if new_args_dict == prev_args_dict:
114
+ # If the arguments (excluding prompt) are the same, reuse the loaded setup
115
+ print(f"Arguments (excluding prompt) are the same, reusing loaded setup for {model} model.")
116
+ return f"{model} model already loaded with the same configuration.", loaded_model_setup
117
+
118
+ # If other args differ, proceed with the setup
119
  try:
120
+ args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args, loaded_model_setup)
121
+ new_loaded_setup = [args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings]
122
+ return f"{model} model loaded succesfully !", new_loaded_setup
123
 
124
  except Exception as e:
125
  print(f"Unexpected Error: {e}")
 
350
 
351
  submit_btn.click(
352
  fn = start_over,
353
+ inputs =[gallery_state],
354
+ outputs = [gallery_state, output_image, status, iter_gallery]
355
  ).then(
356
  fn = setup_model,
357
+ inputs = [loaded_model_setup, prompt, chosen_model, seed, n_iter, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate],
358
  outputs = [model_status, loaded_model_setup] # Load the new setup into the state
359
  ).then(
360
  fn = generate_image,