Spaces:
Sleeping
Sleeping
setup optimization if model do not changes
Browse files
app.py
CHANGED
@@ -50,19 +50,20 @@ def clean_dir(save_dir):
|
|
50 |
else:
|
51 |
print(f"{save_dir} does not exist.")
|
52 |
|
53 |
-
def start_over(gallery_state
|
54 |
torch.cuda.empty_cache() # Free up cached memory
|
55 |
gc.collect()
|
56 |
if gallery_state is not None:
|
57 |
gallery_state = None
|
58 |
-
|
59 |
-
loaded_model_setup = None # Reset loaded model setup to prevent re-triggering old state
|
60 |
-
return gallery_state, None, None, gr.update(visible=False), loaded_model_setup
|
61 |
|
62 |
-
def setup_model(prompt, model, seed, num_iterations, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate, progress=gr.Progress(track_tqdm=True)):
|
63 |
gr.Info(f"Loading {model} model ...")
|
|
|
64 |
if prompt is None or prompt == "":
|
65 |
raise gr.Error("You forgot to provide a prompt !")
|
|
|
|
|
66 |
|
67 |
"""Clear CUDA memory before starting the training."""
|
68 |
torch.cuda.empty_cache() # Free up cached memory
|
@@ -101,10 +102,24 @@ def setup_model(prompt, model, seed, num_iterations, enable_hps, hps_w, enable_i
|
|
101 |
args.enable_multi_apply= True
|
102 |
args.multi_step_model = "flux"
|
103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
try:
|
105 |
-
args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args)
|
106 |
-
|
107 |
-
return f"{model} model loaded succesfully !",
|
108 |
|
109 |
except Exception as e:
|
110 |
print(f"Unexpected Error: {e}")
|
@@ -335,11 +350,11 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo:
|
|
335 |
|
336 |
submit_btn.click(
|
337 |
fn = start_over,
|
338 |
-
inputs =[gallery_state
|
339 |
-
outputs = [gallery_state, output_image, status, iter_gallery
|
340 |
).then(
|
341 |
fn = setup_model,
|
342 |
-
inputs = [prompt, chosen_model, seed, n_iter, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate],
|
343 |
outputs = [model_status, loaded_model_setup] # Load the new setup into the state
|
344 |
).then(
|
345 |
fn = generate_image,
|
|
|
50 |
else:
|
51 |
print(f"{save_dir} does not exist.")
|
52 |
|
53 |
+
def start_over(gallery_state):
|
54 |
torch.cuda.empty_cache() # Free up cached memory
|
55 |
gc.collect()
|
56 |
if gallery_state is not None:
|
57 |
gallery_state = None
|
58 |
+
return gallery_state, None, None, gr.update(visible=False)
|
|
|
|
|
59 |
|
60 |
+
def setup_model(loaded_model_setup, prompt, model, seed, num_iterations, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate, progress=gr.Progress(track_tqdm=True)):
|
61 |
gr.Info(f"Loading {model} model ...")
|
62 |
+
|
63 |
if prompt is None or prompt == "":
|
64 |
raise gr.Error("You forgot to provide a prompt !")
|
65 |
+
|
66 |
+
print(f"LOADED_MODEL SETUP: {loaded_model_setup}")
|
67 |
|
68 |
"""Clear CUDA memory before starting the training."""
|
69 |
torch.cuda.empty_cache() # Free up cached memory
|
|
|
102 |
args.enable_multi_apply= True
|
103 |
args.multi_step_model = "flux"
|
104 |
|
105 |
+
# Check if args are the same as the loaded_model_setup except for the prompt
|
106 |
+
if loaded_model_setup and hasattr(loaded_model_setup[0], '__dict__'):
|
107 |
+
previous_args = loaded_model_setup[0]
|
108 |
+
|
109 |
+
# Exclude 'prompt' from comparison
|
110 |
+
new_args_dict = {k: v for k, v in args.__dict__.items() if k != 'prompt'}
|
111 |
+
prev_args_dict = {k: v for k, v in previous_args.__dict__.items() if k != 'prompt'}
|
112 |
+
|
113 |
+
if new_args_dict == prev_args_dict:
|
114 |
+
# If the arguments (excluding prompt) are the same, reuse the loaded setup
|
115 |
+
print(f"Arguments (excluding prompt) are the same, reusing loaded setup for {model} model.")
|
116 |
+
return f"{model} model already loaded with the same configuration.", loaded_model_setup
|
117 |
+
|
118 |
+
# If other args differ, proceed with the setup
|
119 |
try:
|
120 |
+
args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args, loaded_model_setup)
|
121 |
+
new_loaded_setup = [args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings]
|
122 |
+
return f"{model} model loaded succesfully !", new_loaded_setup
|
123 |
|
124 |
except Exception as e:
|
125 |
print(f"Unexpected Error: {e}")
|
|
|
350 |
|
351 |
submit_btn.click(
|
352 |
fn = start_over,
|
353 |
+
inputs =[gallery_state],
|
354 |
+
outputs = [gallery_state, output_image, status, iter_gallery]
|
355 |
).then(
|
356 |
fn = setup_model,
|
357 |
+
inputs = [loaded_model_setup, prompt, chosen_model, seed, n_iter, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate],
|
358 |
outputs = [model_status, loaded_model_setup] # Load the new setup into the state
|
359 |
).then(
|
360 |
fn = generate_image,
|