NightRaven109 committed on
Commit
d7ca9c8
·
verified ·
1 Parent(s): bfecb5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -62
app.py CHANGED
@@ -3,8 +3,9 @@ import torch
3
  import gradio as gr
4
  import spaces
5
  from PIL import Image
 
6
  from huggingface_hub import snapshot_download
7
- from test_ccsr_tile import main, load_pipeline
8
  import argparse
9
  from accelerate import Accelerator
10
 
@@ -48,6 +49,12 @@ def initialize_models():
48
  # Load pipeline
49
  pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention=False)
50
 
 
 
 
 
 
 
51
  # Initialize generator
52
  generator = torch.Generator(device=accelerator.device)
53
 
@@ -57,7 +64,7 @@ def initialize_models():
57
  print(f"Error initializing models: {str(e)}")
58
  return False
59
 
60
- @spaces.GPU
61
  def process_image(
62
  input_image,
63
  prompt="clean, high-resolution, 8k",
@@ -117,13 +124,6 @@ def process_image(
117
  validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))
118
  width, height = validation_image.size
119
 
120
- # Move pipeline to GPU and set to eval mode
121
- pipeline.to(accelerator.device)
122
- pipeline.unet.eval()
123
- pipeline.controlnet.eval()
124
- pipeline.vae.eval()
125
- pipeline.text_encoder.eval()
126
-
127
  # Generate image
128
  with torch.no_grad():
129
  inference_time, output = pipeline(
@@ -157,62 +157,30 @@ def process_image(
157
  if resize_flag:
158
  image = image.resize((ori_width*args.upscale, ori_height*args.upscale))
159
 
160
- # Move pipeline back to CPU to free up GPU memory
161
- pipeline.to("cpu")
162
- torch.cuda.empty_cache()
163
-
164
  return image
165
 
166
  except Exception as e:
167
  print(f"Error processing image: {str(e)}")
168
  return None
169
 
170
- # Also update the initialize_models function:
171
- @spaces.GPU
172
- def initialize_models():
173
- global pipeline, generator, accelerator
174
-
175
- try:
176
- # Download model repository
177
- model_path = snapshot_download(
178
- repo_id="NightRaven109/CCSRModels",
179
- token=os.environ['Read2']
180
- )
181
-
182
- # Set up default arguments
183
- args = Args(
184
- pretrained_model_path=os.path.join(model_path, "stable-diffusion-2-1-base"),
185
- controlnet_model_path=os.path.join(model_path, "Controlnet"),
186
- vae_model_path=os.path.join(model_path, "vae"),
187
- mixed_precision="fp16",
188
- tile_vae=False,
189
- sample_method="ddpm",
190
- vae_encoder_tile_size=1024,
191
- vae_decoder_tile_size=224
192
- )
193
-
194
- # Initialize accelerator
195
- accelerator = Accelerator(
196
- mixed_precision=args.mixed_precision,
197
- )
198
-
199
- # Load pipeline
200
- pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention=False)
201
-
202
- # Set pipeline to eval mode
203
- pipeline.unet.eval()
204
- pipeline.controlnet.eval()
205
- pipeline.vae.eval()
206
- pipeline.text_encoder.eval()
207
-
208
- # Move to CPU initially to save memory
209
- pipeline.to("cpu")
210
-
211
- # Initialize generator
212
- generator = torch.Generator(device=accelerator.device)
213
-
214
- return True
215
-
216
- except Exception as e:
217
- print(f"Error initializing models: {str(e)}")
218
- return False
 
3
  import gradio as gr
4
  import spaces
5
  from PIL import Image
6
+ from diffusers import DiffusionPipeline
7
  from huggingface_hub import snapshot_download
8
+ from test_ccsr_tile import load_pipeline
9
  import argparse
10
  from accelerate import Accelerator
11
 
 
49
  # Load pipeline
50
  pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention=False)
51
 
52
+ # Set pipeline to eval mode
53
+ pipeline.unet.eval()
54
+ pipeline.controlnet.eval()
55
+ pipeline.vae.eval()
56
+ pipeline.text_encoder.eval()
57
+
58
  # Initialize generator
59
  generator = torch.Generator(device=accelerator.device)
60
 
 
64
  print(f"Error initializing models: {str(e)}")
65
  return False
66
 
67
+ @spaces.GPU(processing_timeout=180) # Increased timeout for longer processing
68
  def process_image(
69
  input_image,
70
  prompt="clean, high-resolution, 8k",
 
124
  validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))
125
  width, height = validation_image.size
126
 
 
 
 
 
 
 
 
127
  # Generate image
128
  with torch.no_grad():
129
  inference_time, output = pipeline(
 
157
  if resize_flag:
158
  image = image.resize((ori_width*args.upscale, ori_height*args.upscale))
159
 
 
 
 
 
160
  return image
161
 
162
  except Exception as e:
163
  print(f"Error processing image: {str(e)}")
164
  return None
165
 
166
# Build the Gradio UI: a single Interface wrapping process_image.
# Input widgets mirror process_image's parameters in order; defaults
# match the function's own defaults where visible.
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Textbox(label="Prompt", value="clean, high-resolution, 8k"),
        gr.Textbox(label="Negative Prompt", value="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed"),
        gr.Slider(minimum=1.0, maximum=20.0, value=1.0, label="Guidance Scale"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Conditioning Scale"),
        gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Number of Steps"),
        gr.Number(label="Seed", value=42),
        gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Upscale Factor"),
        gr.Radio(["none", "wavelet", "adain"], label="Color Fix Method", value="adain"),
    ],
    outputs=gr.Image(label="Generated Image"),
    title="Controllable Conditional Super-Resolution",
    description="Upload an image to enhance its resolution using CCSR.",
)

# Launch only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()