Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -3,8 +3,9 @@ import torch
|
|
3 |
import gradio as gr
|
4 |
import spaces
|
5 |
from PIL import Image
|
|
|
6 |
from huggingface_hub import snapshot_download
|
7 |
-
from test_ccsr_tile import
|
8 |
import argparse
|
9 |
from accelerate import Accelerator
|
10 |
|
@@ -48,6 +49,12 @@ def initialize_models():
|
|
48 |
# Load pipeline
|
49 |
pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention=False)
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
# Initialize generator
|
52 |
generator = torch.Generator(device=accelerator.device)
|
53 |
|
@@ -57,7 +64,7 @@ def initialize_models():
|
|
57 |
print(f"Error initializing models: {str(e)}")
|
58 |
return False
|
59 |
|
60 |
-
@spaces.GPU
|
61 |
def process_image(
|
62 |
input_image,
|
63 |
prompt="clean, high-resolution, 8k",
|
@@ -117,13 +124,6 @@ def process_image(
|
|
117 |
validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))
|
118 |
width, height = validation_image.size
|
119 |
|
120 |
-
# Move pipeline to GPU and set to eval mode
|
121 |
-
pipeline.to(accelerator.device)
|
122 |
-
pipeline.unet.eval()
|
123 |
-
pipeline.controlnet.eval()
|
124 |
-
pipeline.vae.eval()
|
125 |
-
pipeline.text_encoder.eval()
|
126 |
-
|
127 |
# Generate image
|
128 |
with torch.no_grad():
|
129 |
inference_time, output = pipeline(
|
@@ -157,62 +157,30 @@ def process_image(
|
|
157 |
if resize_flag:
|
158 |
image = image.resize((ori_width*args.upscale, ori_height*args.upscale))
|
159 |
|
160 |
-
# Move pipeline back to CPU to free up GPU memory
|
161 |
-
pipeline.to("cpu")
|
162 |
-
torch.cuda.empty_cache()
|
163 |
-
|
164 |
return image
|
165 |
|
166 |
except Exception as e:
|
167 |
print(f"Error processing image: {str(e)}")
|
168 |
return None
|
169 |
|
170 |
-
#
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
)
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
vae_decoder_tile_size=224
|
192 |
-
)
|
193 |
-
|
194 |
-
# Initialize accelerator
|
195 |
-
accelerator = Accelerator(
|
196 |
-
mixed_precision=args.mixed_precision,
|
197 |
-
)
|
198 |
-
|
199 |
-
# Load pipeline
|
200 |
-
pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention=False)
|
201 |
-
|
202 |
-
# Set pipeline to eval mode
|
203 |
-
pipeline.unet.eval()
|
204 |
-
pipeline.controlnet.eval()
|
205 |
-
pipeline.vae.eval()
|
206 |
-
pipeline.text_encoder.eval()
|
207 |
-
|
208 |
-
# Move to CPU initially to save memory
|
209 |
-
pipeline.to("cpu")
|
210 |
-
|
211 |
-
# Initialize generator
|
212 |
-
generator = torch.Generator(device=accelerator.device)
|
213 |
-
|
214 |
-
return True
|
215 |
-
|
216 |
-
except Exception as e:
|
217 |
-
print(f"Error initializing models: {str(e)}")
|
218 |
-
return False
|
|
|
3 |
import gradio as gr
|
4 |
import spaces
|
5 |
from PIL import Image
|
6 |
+
from diffusers import DiffusionPipeline
|
7 |
from huggingface_hub import snapshot_download
|
8 |
+
from test_ccsr_tile import load_pipeline
|
9 |
import argparse
|
10 |
from accelerate import Accelerator
|
11 |
|
|
|
49 |
# Load pipeline
|
50 |
pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention=False)
|
51 |
|
52 |
+
# Set pipeline to eval mode
|
53 |
+
pipeline.unet.eval()
|
54 |
+
pipeline.controlnet.eval()
|
55 |
+
pipeline.vae.eval()
|
56 |
+
pipeline.text_encoder.eval()
|
57 |
+
|
58 |
# Initialize generator
|
59 |
generator = torch.Generator(device=accelerator.device)
|
60 |
|
|
|
64 |
print(f"Error initializing models: {str(e)}")
|
65 |
return False
|
66 |
|
67 |
+
@spaces.GPU(processing_timeout=180) # Increased timeout for longer processing
|
68 |
def process_image(
|
69 |
input_image,
|
70 |
prompt="clean, high-resolution, 8k",
|
|
|
124 |
validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))
|
125 |
width, height = validation_image.size
|
126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
# Generate image
|
128 |
with torch.no_grad():
|
129 |
inference_time, output = pipeline(
|
|
|
157 |
if resize_flag:
|
158 |
image = image.resize((ori_width*args.upscale, ori_height*args.upscale))
|
159 |
|
|
|
|
|
|
|
|
|
160 |
return image
|
161 |
|
162 |
except Exception as e:
|
163 |
print(f"Error processing image: {str(e)}")
|
164 |
return None
|
165 |
|
166 |
+
# Create Gradio interface
|
167 |
+
demo = gr.Interface(
|
168 |
+
fn=process_image,
|
169 |
+
inputs=[
|
170 |
+
gr.Image(label="Input Image"),
|
171 |
+
gr.Textbox(label="Prompt", value="clean, high-resolution, 8k"),
|
172 |
+
gr.Textbox(label="Negative Prompt", value="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed"),
|
173 |
+
gr.Slider(minimum=1.0, maximum=20.0, value=1.0, label="Guidance Scale"),
|
174 |
+
gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Conditioning Scale"),
|
175 |
+
gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Number of Steps"),
|
176 |
+
gr.Number(label="Seed", value=42),
|
177 |
+
gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Upscale Factor"),
|
178 |
+
gr.Radio(["none", "wavelet", "adain"], label="Color Fix Method", value="adain"),
|
179 |
+
],
|
180 |
+
outputs=gr.Image(label="Generated Image"),
|
181 |
+
title="Controllable Conditional Super-Resolution",
|
182 |
+
description="Upload an image to enhance its resolution using CCSR."
|
183 |
+
)
|
184 |
+
|
185 |
+
if __name__ == "__main__":
|
186 |
+
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|