fffiloni commited on
Commit
f57f3d1
·
verified ·
1 Parent(s): 3e6b0ce

GPU management optimizations

Browse files
Files changed (1) hide show
  1. app.py +389 -342
app.py CHANGED
@@ -1,372 +1,419 @@
 
 
 
 
 
1
  import torch
2
  import gc
3
- import gradio as gr
4
- from main import setup, execute_task
 
 
5
  from arguments import parse_args
6
- import os
7
- import shutil
8
- import glob
9
- import time
10
- import threading
11
- import argparse
12
-
13
- def list_iter_images(save_dir):
14
- # Specify only PNG images
15
- image_extension = 'png'
16
-
17
- # Create a list to store the image file paths
18
- image_paths = []
19
-
20
- # Use glob to find all PNG image files
21
- all_images = glob.glob(os.path.join(save_dir, f'*.{image_extension}'))
22
-
23
- # Filter out 'best_image.png'
24
- image_paths = [img for img in all_images if os.path.basename(img) != 'best_image.png']
25
-
26
- return image_paths
27
-
28
- def clean_dir(save_dir):
29
- # Check if the directory exists
30
- if os.path.exists(save_dir):
31
- # Check if the directory contains any files
32
- if len(os.listdir(save_dir)) > 0:
33
- # If it contains files, delete all files in the directory
34
- for filename in os.listdir(save_dir):
35
- file_path = os.path.join(save_dir, filename)
36
- try:
37
- if os.path.isfile(file_path) or os.path.islink(file_path):
38
- os.unlink(file_path) # Remove file or symbolic link
39
- elif os.path.isdir(file_path):
40
- shutil.rmtree(file_path) # Remove directory and its contents
41
- except Exception as e:
42
- print(f"Failed to delete {file_path}. Reason: {e}")
43
- print(f"All files in {save_dir} have been deleted.")
44
- else:
45
- print(f"{save_dir} exists but is empty.")
46
- else:
47
- print(f"{save_dir} does not exist.")
48
 
49
- def start_over(gallery_state):
50
- torch.cuda.empty_cache() # Free up cached memory
51
- gc.collect()
52
- if gallery_state is not None:
53
- gallery_state = None
54
- return gallery_state, None, None, gr.update(visible=False)
55
 
56
- def setup_model(loaded_model_setup, prompt, model, seed, num_iterations, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate, progress=gr.Progress(track_tqdm=True)):
57
- gr.Info(f"Loading {model} model ...")
58
-
59
- if prompt is None or prompt == "":
60
- raise gr.Error("You forgot to provide a prompt !")
61
 
62
- print(f"LOADED_MODEL SETUP: {loaded_model_setup}")
 
 
 
 
63
 
64
- """Clear CUDA memory before starting the training."""
65
- torch.cuda.empty_cache() # Free up cached memory
66
  gc.collect()
 
 
 
 
 
67
 
68
- # Set up arguments
69
- args = parse_args()
70
- args.task = "single"
71
- args.prompt = prompt
72
- args.model = model
73
- args.seed = seed
74
- args.n_iters = num_iterations
75
- args.lr = learning_rate
76
- args.cache_dir = "./HF_model_cache"
77
- args.save_dir = "./outputs"
78
- args.save_all_images = True
79
-
80
- if enable_hps is True:
81
- args.disable_hps = False
82
- args.hps_weighting = hps_w
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- if enable_imagereward is True:
85
- args.disable_imagereward = False
86
- args.imagereward_weighting = imgrw_w
 
 
 
 
 
 
87
 
88
- if enable_pickscore is True:
89
- args.disable_pickscore = False
90
- args.pickscore_weighting = pcks_w
91
 
92
- if enable_clip is True:
93
- args.disable_clip = False
94
- args.clip_weighting = clip_w
95
-
96
- if model == "flux":
97
- args.cpu_offloading = True
98
- args.enable_multi_apply = True
99
- args.multi_step_model = "flux"
100
-
101
- # Check if args are the same as the loaded_model_setup except for the prompt
102
- if loaded_model_setup and hasattr(loaded_model_setup[0], '__dict__'):
103
- previous_args = loaded_model_setup[0]
104
 
105
- # Exclude 'prompt' from comparison
106
- new_args_dict = {k: v for k, v in args.__dict__.items() if k != 'prompt'}
107
- prev_args_dict = {k: v for k, v in previous_args.__dict__.items() if k != 'prompt'}
 
 
 
 
 
 
 
 
108
 
109
- if new_args_dict == prev_args_dict:
110
- # If the arguments (excluding prompt) are the same, reuse the loaded setup
111
- print(f"Arguments (excluding prompt) are the same, reusing loaded setup for {model} model.")
112
-
113
- # Update the prompt in the loaded_model_setup
114
- loaded_model_setup[0].prompt = prompt
115
-
116
- return f"{model} model already loaded with the same configuration.", loaded_model_setup
117
-
118
- # Attempt to set up the model
119
- try:
120
- # If other args differ, proceed with the setup
121
- args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args, loaded_model_setup)
122
- new_loaded_setup = [args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings]
123
- return f"{model} model loaded successfully!", new_loaded_setup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
- except Exception as e:
126
- print(f"Failed to load {model} model: {e}.")
127
- return f"Failed to load {model} model: {e}. You can try again, as it usually finally loads on the second try :)", None
128
-
129
 
130
- def generate_image(setup_args, num_iterations):
131
- torch.cuda.empty_cache() # Free up cached memory
 
 
 
 
 
132
  gc.collect()
133
 
134
- gr.Info(f"Executing iterations task ...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- args = setup_args[0]
137
- trainer = setup_args[1]
138
- device = setup_args[2]
139
- dtype = setup_args[3]
140
- shape = setup_args[4]
141
- enable_grad = setup_args[5]
142
- multi_apply_fn = setup_args[6]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- settings = setup_args[7]
145
- print(f"SETTINGS: {settings}")
 
146
 
147
- save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt[:150]}"
148
- clean_dir(save_dir)
149
 
150
- try:
151
- torch.cuda.empty_cache() # Free up cached memory
152
- gc.collect()
153
- steps_completed = []
154
- result_container = {"best_image": None, "total_init_rewards": None, "total_best_rewards": None}
155
- error_status = {"error_occurred": False} # Shared dictionary to track error status
156
- thread_status = {"running": False} # Track whether a thread is already running
157
-
158
- def progress_callback(step):
159
- # Limit redundant prints by checking the step number
160
- if not steps_completed or step > steps_completed[-1]:
161
- steps_completed.append(step)
162
- print(f"Progress: Step {step} completed.")
163
-
164
- def run_main():
165
- thread_status["running"] = True # Mark thread as running
166
- try:
167
- execute_task(
168
- args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings, progress_callback
169
- )
170
- except torch.cuda.OutOfMemoryError as e:
171
- print(f"CUDA Out of Memory Error: {e}")
172
- error_status["error_occurred"] = True
173
- except RuntimeError as e:
174
- if 'out of memory' in str(e):
175
- print(f"Runtime Error: {e}")
176
- error_status["error_occurred"] = True
177
- else:
178
- raise
179
- finally:
180
- thread_status["running"] = False # Mark thread as completed
181
-
182
- if not thread_status["running"]: # Ensure no other thread is running
183
- main_thread = threading.Thread(target=run_main)
184
- main_thread.start()
185
-
186
- last_step_yielded = 0
187
- while main_thread.is_alive() and not error_status["error_occurred"]:
188
- # Check if new steps have been completed
189
- if steps_completed and steps_completed[-1] > last_step_yielded:
190
- last_step_yielded = steps_completed[-1]
191
- png_number = last_step_yielded - 1
192
- # Get the image for this step
193
- image_path = os.path.join(save_dir, f"{png_number}.png")
194
- if os.path.exists(image_path):
195
- yield (image_path, f"Iteration {last_step_yielded}/{num_iterations} - Image saved", None)
196
- else:
197
- yield (None, f"Iteration {last_step_yielded}/{num_iterations} - Image not found", None)
198
- else:
199
- time.sleep(0.1) # Sleep to prevent busy waiting
200
-
201
- if error_status["error_occurred"]:
202
- torch.cuda.empty_cache() # Free up cached memory
203
- gc.collect()
204
- yield (None, "CUDA out of memory. Please reduce your batch size or image resolution.", None)
205
- else:
206
- main_thread.join() # Ensure thread completion
207
- final_image_path = os.path.join(save_dir, "best_image.png")
208
- if os.path.exists(final_image_path):
209
- iter_images = list_iter_images(save_dir)
210
- torch.cuda.empty_cache() # Free up cached memory
211
- gc.collect()
212
- time.sleep(0.5)
213
- yield (final_image_path, f"Final image saved at {final_image_path}", iter_images)
214
- else:
215
- torch.cuda.empty_cache() # Free up cached memory
216
- gc.collect()
217
- yield (None, "Image generation completed, but no final image was found.", None)
218
 
219
- torch.cuda.empty_cache() # Free up cached memory
220
- gc.collect()
221
 
222
- except torch.cuda.OutOfMemoryError as e:
223
- print(f"Global CUDA Out of Memory Error: {e}")
224
- yield (None, f"{e}", None)
225
- except RuntimeError as e:
226
- if 'out of memory' in str(e):
227
- print(f"Runtime Error: {e}")
228
- yield (None, f"{e}", None)
 
 
 
229
  else:
230
- yield (None, f"An error occurred: {str(e)}", None)
231
- except Exception as e:
232
- print(f"Unexpected Error: {e}")
233
- yield (None, f"An unexpected error occurred: {str(e)}", None)
234
-
235
- def show_gallery_output(gallery_state):
236
- if gallery_state is not None:
237
- return gr.update(value=gallery_state, visible=True)
238
- else:
239
- return gr.update(value=None, visible=False)
240
-
241
- # Create Gradio interface
242
- title="# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
243
- description="Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."
244
-
245
- css="""
246
- #model-status-id{
247
- height: 126px;
248
- }
249
- #model-status-id .progress-text{
250
- font-size: 10px!important;
251
- }
252
- #model-status-id .progress-level-inner{
253
- font-size: 8px!important;
254
- }
255
- """
256
-
257
- with gr.Blocks(css=css, analytics_enabled=False) as demo:
258
- loaded_model_setup = gr.State()
259
- gallery_state = gr.State()
260
- with gr.Column():
261
- gr.Markdown(title)
262
- gr.Markdown(description)
263
- gr.HTML("""
264
- <div style="display:flex;column-gap:4px;">
265
- <a href='https://github.com/ExplainableML/ReNO'>
266
- <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
267
- </a>
268
- <a href='https://arxiv.org/abs/2406.04312v1'>
269
- <img src='https://img.shields.io/badge/Paper-Arxiv-red'>
270
- </a>
271
- </div>
272
- """)
273
-
274
- with gr.Row():
275
- with gr.Column():
276
- prompt = gr.Textbox(label="Prompt")
277
- with gr.Row():
278
- chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd", "flux"], label="Model", value="sd-turbo")
279
- seed = gr.Number(label="seed", value=0)
280
-
281
- model_status = gr.Textbox(label="model status", visible=True, elem_id="model-status-id")
282
-
283
- with gr.Row():
284
- n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=10, label="Number of Iterations")
285
- learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")
286
-
287
- with gr.Accordion("Advanced Settings", open=True):
288
- with gr.Column():
289
- with gr.Row():
290
- enable_hps = gr.Checkbox(label="HPS ON", value=False, scale=1)
291
- hps_w = gr.Slider(label="HPS weight", step=0.1, minimum=0.0, maximum=10.0, value=5.0, interactive=False, scale=3)
292
- with gr.Row():
293
- enable_imagereward = gr.Checkbox(label="ImageReward ON", value=False, scale=1)
294
- imgrw_w = gr.Slider(label="ImageReward weight", step=0.1, minimum=0, maximum=5.0, value=1.0, interactive=False, scale=3)
295
- with gr.Row():
296
- enable_pickscore = gr.Checkbox(label="PickScore ON", value=False, scale=1)
297
- pcks_w = gr.Slider(label="PickScore weight", step=0.01, minimum=0, maximum=5.0, value=0.05, interactive=False, scale=3)
298
- with gr.Row():
299
- enable_clip = gr.Checkbox(label="CLIP ON", value=False, scale=1)
300
- clip_w = gr.Slider(label="CLIP weight", step=0.01, minimum=0, maximum=0.1, value=0.01, interactive=False, scale=3)
301
-
302
- submit_btn = gr.Button("Submit")
303
-
304
- gr.Examples(
305
- examples = [
306
- "A red dog and a green cat",
307
- "A pink elephant and a grey cow",
308
- "A toaster riding a bike",
309
- "Dwayne Johnson depicted as a philosopher king in an academic painting by Greg Rutkowski",
310
- "A curious, orange fox and a fluffy, white rabbit, playing together in a lush, green meadow filled with yellow dandelions",
311
- "An epic oil painting: a red portal infront of a cityscape, a solitary figure, and a colorful sky over snowy mountains"
312
- ],
313
- inputs = [prompt]
314
- )
315
 
316
- with gr.Column():
317
- output_image = gr.Image(type="filepath", label="Best Generated Image")
318
- status = gr.Textbox(label="Status")
319
- iter_gallery = gr.Gallery(label="Iterations", columns=4, visible=False)
320
-
321
- def allow_weighting(weight_type):
322
- if weight_type is True:
323
- return gr.update(interactive=True)
 
 
 
 
 
 
 
 
324
  else:
325
- return gr.update(interactive=False)
326
-
327
- enable_hps.change(
328
- fn = allow_weighting,
329
- inputs = [enable_hps],
330
- outputs = [hps_w],
331
- queue = False
332
- )
333
- enable_imagereward.change(
334
- fn = allow_weighting,
335
- inputs = [enable_imagereward],
336
- outputs = [imgrw_w],
337
- queue = False
338
- )
339
- enable_pickscore.change(
340
- fn = allow_weighting,
341
- inputs = [enable_pickscore],
342
- outputs = [pcks_w],
343
- queue = False
344
- )
345
- enable_clip.change(
346
- fn = allow_weighting,
347
- inputs = [enable_clip],
348
- outputs = [clip_w],
349
- queue = False
350
- )
351
 
 
 
352
 
353
- submit_btn.click(
354
- fn = start_over,
355
- inputs =[gallery_state],
356
- outputs = [gallery_state, output_image, status, iter_gallery]
357
- ).then(
358
- fn = setup_model,
359
- inputs = [loaded_model_setup, prompt, chosen_model, seed, n_iter, enable_hps, hps_w, enable_imagereward, imgrw_w, enable_pickscore, pcks_w, enable_clip, clip_w, learning_rate],
360
- outputs = [model_status, loaded_model_setup] # Load the new setup into the state
361
- ).then(
362
- fn = generate_image,
363
- inputs = [loaded_model_setup, n_iter],
364
- outputs = [output_image, status, gallery_state]
365
- ).then(
366
- fn = show_gallery_output,
367
- inputs = [gallery_state],
368
- outputs = iter_gallery
369
- )
370
 
371
- # Launch the app
372
- demo.queue().launch(show_error=True, show_api=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+
5
+ import blobfile as bf
6
  import torch
7
  import gc
8
+ from datasets import load_dataset
9
+ from pytorch_lightning import seed_everything
10
+ from tqdm import tqdm
11
+
12
  from arguments import parse_args
13
+ from models import get_model, get_multi_apply_fn
14
+ from rewards import get_reward_losses
15
+ from training import LatentNoiseTrainer, get_optimizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
 
 
 
 
 
 
17
 
18
+ import torch
19
+ import gc
 
 
 
20
 
21
+ def clear_gpu():
22
+ """Clear GPU memory by removing tensors, freeing cache, and moving data to CPU."""
23
+ # List memory usage before clearing
24
+ print(f"Memory allocated before clearing: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
25
+ print(f"Memory reserved before clearing: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")
26
 
27
+ # Force the garbage collector to free unreferenced objects
 
28
  gc.collect()
29
+
30
+ # Move any bound tensors back to CPU if needed
31
+ if torch.cuda.is_available():
32
+ torch.cuda.empty_cache() # Free up the cached memory
33
+ torch.cuda.ipc_collect() # Clear any cross-process memory
34
 
35
+ print(f"Memory allocated after clearing: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
36
+ print(f"Memory reserved after clearing: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")
37
+
38
+ def unload_previous_model_if_needed(loaded_model_setup):
39
+ """Unload the current model from the GPU and free resources if a new model is being loaded."""
40
+ if loaded_model_setup is not None:
41
+ print("Unloading previous model from GPU to free memory.")
42
+ previous_model = loaded_model_setup[7] # Assuming pipe is at position [7] in the setup
43
+ if hasattr(previous_model, 'to') and loaded_model_setup[0].model != "flux":
44
+ previous_model.to('cpu') # Move model to CPU to free GPU memory
45
+ del previous_model # Delete the reference to the model
46
+ clear_gpu() # Clear all remaining GPU memory
47
+
48
+ def setup(args, loaded_model_setup=None):
49
+ seed_everything(args.seed)
50
+ bf.makedirs(f"{args.save_dir}/logs/{args.task}")
51
+
52
+ # Set up logging and name settings
53
+ logger = logging.getLogger()
54
+ logger.handlers.clear() # Clear existing handlers
55
+ settings = (
56
+ f"{args.model}{'_' + args.prompt if args.task == 't2i-compbench' else ''}"
57
+ f"{'_no-optim' if args.no_optim else ''}_{args.seed if args.task != 'geneval' else ''}"
58
+ f"_lr{args.lr}_gc{args.grad_clip}_iter{args.n_iters}"
59
+ f"_reg{args.reg_weight if args.enable_reg else '0'}"
60
+ f"{'_pickscore' + str(args.pickscore_weighting) if args.enable_pickscore else ''}"
61
+ f"{'_clip' + str(args.clip_weighting) if args.enable_clip else ''}"
62
+ f"{'_hps' + str(args.hps_weighting) if args.enable_hps else ''}"
63
+ f"{'_imagereward' + str(args.imagereward_weighting) if args.enable_imagereward else ''}"
64
+ f"{'_aesthetic' + str(args.aesthetic_weighting) if args.enable_aesthetic else ''}"
65
+ )
66
 
67
+ file_stream = open(f"{args.save_dir}/logs/{args.task}/{settings}.txt", "w")
68
+ handler = logging.StreamHandler(file_stream)
69
+ formatter = logging.Formatter("%(asctime)s - %(message)s")
70
+ handler.setFormatter(formatter)
71
+ logger.addHandler(handler)
72
+ logger.setLevel("INFO")
73
+ consoleHandler = logging.StreamHandler()
74
+ consoleHandler.setFormatter(formatter)
75
+ logger.addHandler(consoleHandler)
76
 
77
+ logging.info(args)
 
 
78
 
79
+ if args.device_id is not None:
80
+ logging.info(f"Using CUDA device {args.device_id}")
81
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
82
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id
83
+
84
+ device = torch.device("cuda")
85
+ dtype = torch.float16 if args.dtype == "float16" else torch.float32
86
+
87
+ # If args.model is the same as the one in loaded_model_setup, reuse the trainer and pipe
88
+ if loaded_model_setup and args.model == loaded_model_setup[0].model:
89
+ print(f"Reusing model {args.model} from loaded setup.")
90
+ trainer = loaded_model_setup[1] # Trainer is at position 1 in loaded_model_setup
91
 
92
+ # Update trainer with the new arguments
93
+ trainer.n_iters = args.n_iters
94
+ trainer.n_inference_steps = args.n_inference_steps
95
+ trainer.seed = args.seed
96
+ trainer.save_all_images = args.save_all_images
97
+ trainer.no_optim = args.no_optim
98
+ trainer.regularize = args.enable_reg
99
+ trainer.regularization_weight = args.reg_weight
100
+ trainer.grad_clip = args.grad_clip
101
+ trainer.log_metrics = args.task == "single" or not args.no_optim
102
+ trainer.imageselect = args.imageselect
103
 
104
+ # Get latents (this step is still required)
105
+ if args.model == "flux":
106
+ shape = (1, 16 * 64, 64)
107
+ elif args.model != "pixart":
108
+ height = trainer.model.unet.config.sample_size * trainer.model.vae_scale_factor
109
+ width = trainer.model.unet.config.sample_size * trainer.model.vae_scale_factor
110
+ shape = (
111
+ 1,
112
+ trainer.model.unet.in_channels,
113
+ height // trainer.model.vae_scale_factor,
114
+ width // trainer.model.vae_scale_factor,
115
+ )
116
+ else:
117
+ height = trainer.model.transformer.config.sample_size * trainer.model.vae_scale_factor
118
+ width = trainer.model.transformer.config.sample_size * trainer.model.vae_scale_factor
119
+ shape = (
120
+ 1,
121
+ trainer.model.transformer.config.in_channels,
122
+ height // trainer.model.vae_scale_factor,
123
+ width // trainer.model.vae_scale_factor,
124
+ )
125
+
126
+ pipe = loaded_model_setup[7]
127
+ enable_grad = not args.no_optim
128
+
129
+ return args, trainer, device, dtype, shape, enable_grad, settings, pipe
130
+
131
+ # Unload previous model and clear GPU resources
132
+ unload_previous_model_if_needed(loaded_model_setup)
133
+
134
+ # Proceed with full model loading if args.model is different
135
+ print(f"Loading new model: {args.model}")
136
 
137
+ # Get reward losses
138
+ reward_losses = get_reward_losses(args, dtype, device, args.cache_dir)
 
 
139
 
140
+ # Get model and noise trainer
141
+ pipe = get_model(
142
+ args.model, dtype, device, args.cache_dir, args.memsave, args.cpu_offloading
143
+ )
144
+
145
+ # Final memory cleanup after model loading
146
+ torch.cuda.empty_cache()
147
  gc.collect()
148
 
149
+ trainer = LatentNoiseTrainer(
150
+ reward_losses=reward_losses,
151
+ model=pipe,
152
+ n_iters=args.n_iters,
153
+ n_inference_steps=args.n_inference_steps,
154
+ seed=args.seed,
155
+ save_all_images=args.save_all_images,
156
+ device=device if not args.cpu_offloading else 'cpu', # Use CPU if offloading is enabled
157
+ no_optim=args.no_optim,
158
+ regularize=args.enable_reg,
159
+ regularization_weight=args.reg_weight,
160
+ grad_clip=args.grad_clip,
161
+ log_metrics=args.task == "single" or not args.no_optim,
162
+ imageselect=args.imageselect,
163
+ )
164
 
165
+ # Create latents
166
+ if args.model == "flux":
167
+ shape = (1, 16 * 64, 64)
168
+ elif args.model != "pixart":
169
+ height = pipe.unet.config.sample_size * pipe.vae_scale_factor
170
+ width = pipe.unet.config.sample_size * pipe.vae_scale_factor
171
+ shape = (
172
+ 1,
173
+ pipe.unet.in_channels,
174
+ height // pipe.vae_scale_factor,
175
+ width // pipe.vae_scale_factor,
176
+ )
177
+ else:
178
+ height = pipe.transformer.config.sample_size * pipe.vae_scale_factor
179
+ width = pipe.transformer.config.sample_size * pipe.vae_scale_factor
180
+ shape = (
181
+ 1,
182
+ pipe.transformer.config.in_channels,
183
+ height // pipe.vae_scale_factor,
184
+ width // pipe.vae_scale_factor,
185
+ )
186
+
187
+ enable_grad = not args.no_optim
188
 
189
+ # Final memory cleanup
190
+ torch.cuda.empty_cache() # Free up cached memory
191
+ gc.collect()
192
 
 
 
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
+ return args, trainer, device, dtype, shape, enable_grad, settings, pipe
 
196
 
197
+
198
+
199
+
200
+ def execute_task(args, trainer, device, dtype, shape, enable_grad, settings, pipe, progress_callback=None):
201
+
202
+ if args.task == "single":
203
+ # Attempt to move the model to GPU if model is not Flux
204
+ if args.model != "flux":
205
+ if pipe.device != torch.device('cuda'):
206
+ pipe.to(device, dtype)
207
  else:
208
+ print(f"PIPE:{pipe}")
209
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
+ if args.cpu_offloading:
212
+ pipe.enable_sequential_cpu_offload()
213
+
214
+ #if pipe.device != torch.device('cuda'):
215
+ # pipe.to(device, dtype)
216
+
217
+ if args.enable_multi_apply:
218
+
219
+ multi_apply_fn = get_multi_apply_fn(
220
+ model_type=args.multi_step_model,
221
+ seed=args.seed,
222
+ pipe=pipe,
223
+ cache_dir=args.cache_dir,
224
+ device=device if not args.cpu_offloading else 'cpu',
225
+ dtype=dtype,
226
+ )
227
  else:
228
+ multi_apply_fn = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
+ torch.cuda.empty_cache() # Free up cached memory
231
+ gc.collect()
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
+ init_latents = torch.randn(shape, device=device, dtype=dtype)
235
+ latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
236
+ optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)
237
+ save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt[:150]}"
238
+ os.makedirs(f"{save_dir}", exist_ok=True)
239
+ init_image, best_image, total_init_rewards, total_best_rewards = trainer.train(
240
+ latents, args.prompt, optimizer, save_dir, multi_apply_fn, progress_callback=progress_callback
241
+ )
242
+ best_image.save(f"{save_dir}/best_image.png")
243
+ #init_image.save(f"{save_dir}/init_image.png")
244
+
245
+ elif args.task == "example-prompts":
246
+ fo = open("assets/example_prompts.txt", "r")
247
+ prompts = fo.readlines()
248
+ fo.close()
249
+ for i, prompt in tqdm(enumerate(prompts)):
250
+ # Get new latents and optimizer
251
+ init_latents = torch.randn(shape, device=device, dtype=dtype)
252
+ latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
253
+ optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)
254
+
255
+ prompt = prompt.strip()
256
+ name = f"{i:03d}_{prompt[:150]}.png"
257
+ save_dir = f"{args.save_dir}/{args.task}/{settings}/{name}"
258
+ os.makedirs(save_dir, exist_ok=True)
259
+ init_image, best_image, init_rewards, best_rewards = trainer.train(
260
+ latents, prompt, optimizer, save_dir, multi_apply_fn
261
+ )
262
+ if i == 0:
263
+ total_best_rewards = {k: 0.0 for k in best_rewards.keys()}
264
+ total_init_rewards = {k: 0.0 for k in best_rewards.keys()}
265
+ for k in best_rewards.keys():
266
+ total_best_rewards[k] += best_rewards[k]
267
+ total_init_rewards[k] += init_rewards[k]
268
+ best_image.save(f"{save_dir}/best_image.png")
269
+ init_image.save(f"{save_dir}/init_image.png")
270
+ logging.info(f"Initial rewards: {init_rewards}")
271
+ logging.info(f"Best rewards: {best_rewards}")
272
+ for k in total_best_rewards.keys():
273
+ total_best_rewards[k] /= len(prompts)
274
+ total_init_rewards[k] /= len(prompts)
275
+
276
+ # save results to directory
277
+ with open(f"{args.save_dir}/example-prompts/{settings}/results.txt", "w") as f:
278
+ f.write(
279
+ f"Mean initial all rewards: {total_init_rewards}\n"
280
+ f"Mean best all rewards: {total_best_rewards}\n"
281
+ )
282
+ elif args.task == "t2i-compbench":
283
+ prompt_list_file = f"../T2I-CompBench/examples/dataset/{args.prompt}.txt"
284
+ fo = open(prompt_list_file, "r")
285
+ prompts = fo.readlines()
286
+ fo.close()
287
+ os.makedirs(f"{args.save_dir}/{args.task}/{settings}/samples", exist_ok=True)
288
+ for i, prompt in tqdm(enumerate(prompts)):
289
+ # Get new latents and optimizer
290
+ init_latents = torch.randn(shape, device=device, dtype=dtype)
291
+ latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
292
+ optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)
293
+
294
+ prompt = prompt.strip()
295
+ init_image, best_image, init_rewards, best_rewards = trainer.train(
296
+ latents, prompt, optimizer, None, multi_apply_fn
297
+ )
298
+ if i == 0:
299
+ total_best_rewards = {k: 0.0 for k in best_rewards.keys()}
300
+ total_init_rewards = {k: 0.0 for k in best_rewards.keys()}
301
+ for k in best_rewards.keys():
302
+ total_best_rewards[k] += best_rewards[k]
303
+ total_init_rewards[k] += init_rewards[k]
304
+ name = f"{prompt}_{i:06d}.png"
305
+ best_image.save(f"{args.save_dir}/{args.task}/{settings}/samples/{name}")
306
+ logging.info(f"Initial rewards: {init_rewards}")
307
+ logging.info(f"Best rewards: {best_rewards}")
308
+ for k in total_best_rewards.keys():
309
+ total_best_rewards[k] /= len(prompts)
310
+ total_init_rewards[k] /= len(prompts)
311
+ elif args.task == "parti-prompts":
312
+ parti_dataset = load_dataset("nateraw/parti-prompts", split="train")
313
+ total_reward_diff = 0.0
314
+ total_best_reward = 0.0
315
+ total_init_reward = 0.0
316
+ total_improved_samples = 0
317
+ for index, sample in enumerate(parti_dataset):
318
+ os.makedirs(
319
+ f"{args.save_dir}/{args.task}/{settings}/{index}", exist_ok=True
320
+ )
321
+ prompt = sample["Prompt"]
322
+ init_image, best_image, init_rewards, best_rewards = trainer.train(
323
+ latents, prompt, optimizer, multi_apply_fn
324
+ )
325
+ best_image.save(
326
+ f"{args.save_dir}/{args.task}/{settings}/{index}/best_image.png"
327
+ )
328
+ open(
329
+ f"{args.save_dir}/{args.task}/{settings}/{index}/prompt.txt", "w"
330
+ ).write(
331
+ f"{prompt} \n Initial Rewards: {init_rewards} \n Best Rewards: {best_rewards}"
332
+ )
333
+ logging.info(f"Initial rewards: {init_rewards}")
334
+ logging.info(f"Best rewards: {best_rewards}")
335
+ initial_reward = init_rewards[args.benchmark_reward]
336
+ best_reward = best_rewards[args.benchmark_reward]
337
+ total_reward_diff += best_reward - initial_reward
338
+ total_best_reward += best_reward
339
+ total_init_reward += initial_reward
340
+ if best_reward < initial_reward:
341
+ total_improved_samples += 1
342
+ if i == 0:
343
+ total_best_rewards = {k: 0.0 for k in best_rewards.keys()}
344
+ total_init_rewards = {k: 0.0 for k in best_rewards.keys()}
345
+ for k in best_rewards.keys():
346
+ total_best_rewards[k] += best_rewards[k]
347
+ total_init_rewards[k] += init_rewards[k]
348
+ # Get new latents and optimizer
349
+ init_latents = torch.randn(shape, device=device, dtype=dtype)
350
+ latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
351
+ optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)
352
+ improvement_percentage = total_improved_samples / parti_dataset.num_rows
353
+ mean_best_reward = total_best_reward / parti_dataset.num_rows
354
+ mean_init_reward = total_init_reward / parti_dataset.num_rows
355
+ mean_reward_diff = total_reward_diff / parti_dataset.num_rows
356
+ logging.info(
357
+ f"Improvement percentage: {improvement_percentage:.4f}, "
358
+ f"mean initial reward: {mean_init_reward:.4f}, "
359
+ f"mean best reward: {mean_best_reward:.4f}, "
360
+ f"mean reward diff: {mean_reward_diff:.4f}"
361
+ )
362
+ for k in total_best_rewards.keys():
363
+ total_best_rewards[k] /= len(parti_dataset)
364
+ total_init_rewards[k] /= len(parti_dataset)
365
+ # save results
366
+ os.makedirs(f"{args.save_dir}/parti-prompts/{settings}", exist_ok=True)
367
+ with open(f"{args.save_dir}/parti-prompts/{settings}/results.txt", "w") as f:
368
+ f.write(
369
+ f"Mean improvement: {improvement_percentage:.4f}, "
370
+ f"mean initial reward: {mean_init_reward:.4f}, "
371
+ f"mean best reward: {mean_best_reward:.4f}, "
372
+ f"mean reward diff: {mean_reward_diff:.4f}\n"
373
+ f"Mean initial all rewards: {total_init_rewards}\n"
374
+ f"Mean best all rewards: {total_best_rewards}"
375
+ )
376
+ elif args.task == "geneval":
377
+ prompt_list_file = "../geneval/prompts/evaluation_metadata.jsonl"
378
+ with open(prompt_list_file) as fp:
379
+ metadatas = [json.loads(line) for line in fp]
380
+ outdir = f"{args.save_dir}/{args.task}/{settings}"
381
+ for index, metadata in enumerate(metadatas):
382
+ # Get new latents and optimizer
383
+ init_latents = torch.randn(shape, device=device, dtype=dtype)
384
+ latents = torch.nn.Parameter(init_latents, requires_grad=True)
385
+ optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)
386
+
387
+ prompt = metadata["prompt"]
388
+ init_image, best_image, init_rewards, best_rewards = trainer.train(
389
+ latents, prompt, optimizer, None, multi_apply_fn
390
+ )
391
+ logging.info(f"Initial rewards: {init_rewards}")
392
+ logging.info(f"Best rewards: {best_rewards}")
393
+ outpath = f"{outdir}/{index:0>5}"
394
+ os.makedirs(f"{outpath}/samples", exist_ok=True)
395
+ with open(f"{outpath}/metadata.jsonl", "w") as fp:
396
+ json.dump(metadata, fp)
397
+ best_image.save(f"{outpath}/samples/{args.seed:05}.png")
398
+ if i == 0:
399
+ total_best_rewards = {k: 0.0 for k in best_rewards.keys()}
400
+ total_init_rewards = {k: 0.0 for k in best_rewards.keys()}
401
+ for k in best_rewards.keys():
402
+ total_best_rewards[k] += best_rewards[k]
403
+ total_init_rewards[k] += init_rewards[k]
404
+ for k in total_best_rewards.keys():
405
+ total_best_rewards[k] /= len(parti_dataset)
406
+ total_init_rewards[k] /= len(parti_dataset)
407
+ else:
408
+ raise ValueError(f"Unknown task {args.task}")
409
+ # log total rewards
410
+ logging.info(f"Mean initial rewards: {total_init_rewards}")
411
+ logging.info(f"Mean best rewards: {total_best_rewards}")
412
+
413
+ def main():
414
+ args = parse_args()
415
+ args, trainer, device, dtype, shape, enable_grad, settings, pipe = setup(args, loaded_model_setup=None)
416
+ execute_task(args, trainer, device, dtype, shape, enable_grad, settings, pipe)
417
+
418
+ if __name__ == "__main__":
419
+ main()