Spaces:
Sleeping
Sleeping
try handle OOM errors
Browse files
app.py
CHANGED
@@ -108,14 +108,25 @@ def generate_image(setup_args, num_iterations):
|
|
108 |
|
109 |
# Function to run main in a separate thread
|
110 |
def run_main():
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
# Start main in a separate thread
|
114 |
main_thread = threading.Thread(target=run_main)
|
115 |
main_thread.start()
|
116 |
|
117 |
last_step_yielded = 0
|
118 |
-
while main_thread.is_alive()
|
119 |
# Check if new steps have been completed
|
120 |
if steps_completed and steps_completed[-1] > last_step_yielded:
|
121 |
last_step_yielded = steps_completed[-1]
|
@@ -130,21 +141,35 @@ def generate_image(setup_args, num_iterations):
|
|
130 |
# Small sleep to prevent busy waiting
|
131 |
time.sleep(0.1)
|
132 |
|
133 |
-
|
134 |
-
|
135 |
-
# After main is complete, yield the final image
|
136 |
-
final_image_path = os.path.join(save_dir, "best_image.png")
|
137 |
-
if os.path.exists(final_image_path):
|
138 |
-
iter_images = list_iter_images(save_dir)
|
139 |
torch.cuda.empty_cache() # Free up cached memory
|
140 |
-
yield (
|
141 |
else:
|
142 |
-
|
143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
|
145 |
-
except Exception as e:
|
146 |
torch.cuda.empty_cache() # Free up cached memory
|
147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
def show_gallery_output(gallery_state):
|
150 |
if gallery_state is not None:
|
|
|
108 |
|
109 |
# Function to run main in a separate thread
|
110 |
def run_main():
|
111 |
+
try:
|
112 |
+
# Call main and handle any potential OOM errors
|
113 |
+
result_container["best_image"], result_container["total_init_rewards"], result_container["total_best_rewards"] = execute_task(args, trainer, device, dtype, shape, enable_grad, settings, progress_callback)
|
114 |
+
except torch.cuda.OutOfMemoryError as e:
|
115 |
+
# Handle CUDA OOM error
|
116 |
+
print("CUDA Out of Memory Error: ", e)
|
117 |
+
status["error_occurred"] = True # Update status on error
|
118 |
+
except RuntimeError as e:
|
119 |
+
if 'out of memory' in str(e):
|
120 |
+
status["error_occurred"] = True # Update status on error
|
121 |
+
else:
|
122 |
+
raise # Reraise if it's not a CUDA OOM error
|
123 |
|
124 |
# Start main in a separate thread
|
125 |
main_thread = threading.Thread(target=run_main)
|
126 |
main_thread.start()
|
127 |
|
128 |
last_step_yielded = 0
|
129 |
+
while main_thread.is_alive() and not status["error_occurred"]:
|
130 |
# Check if new steps have been completed
|
131 |
if steps_completed and steps_completed[-1] > last_step_yielded:
|
132 |
last_step_yielded = steps_completed[-1]
|
|
|
141 |
# Small sleep to prevent busy waiting
|
142 |
time.sleep(0.1)
|
143 |
|
144 |
+
# If an error occurred, clean up resources and stop
|
145 |
+
if status["error_occurred"]:
|
|
|
|
|
|
|
|
|
146 |
torch.cuda.empty_cache() # Free up cached memory
|
147 |
+
yield (None, "CUDA out of memory. Please reduce your batch size or image resolution.", None)
|
148 |
else:
|
149 |
+
main_thread.join()
|
150 |
+
|
151 |
+
# After main is complete, yield the final image
|
152 |
+
final_image_path = os.path.join(save_dir, "best_image.png")
|
153 |
+
if os.path.exists(final_image_path):
|
154 |
+
iter_images = list_iter_images(save_dir)
|
155 |
+
torch.cuda.empty_cache() # Free up cached memory
|
156 |
+
yield (final_image_path, f"Final image saved at {final_image_path}", iter_images)
|
157 |
+
else:
|
158 |
+
torch.cuda.empty_cache() # Free up cached memory
|
159 |
+
yield (None, "Image generation completed, but no final image was found.", None)
|
160 |
|
|
|
161 |
torch.cuda.empty_cache() # Free up cached memory
|
162 |
+
|
163 |
+
except torch.cuda.OutOfMemoryError as e:
|
164 |
+
# Handle CUDA OOM error globally
|
165 |
+
yield (None, "CUDA out of memory.", None)
|
166 |
+
except RuntimeError as e:
|
167 |
+
if 'out of memory' in str(e):
|
168 |
+
yield (None, "CUDA out of memory.", None)
|
169 |
+
else:
|
170 |
+
yield (None, f"An error occurred: {str(e)}", None)
|
171 |
+
except Exception as e:
|
172 |
+
yield (None, f"An unexpected error occurred: {str(e)}", None)
|
173 |
|
174 |
def show_gallery_output(gallery_state):
|
175 |
if gallery_state is not None:
|