NightRaven109 committed on
Commit a495ef9 · verified · 1 Parent(s): eeb4ef5

Update app.py

Files changed (1)
  1. app.py +75 -61
app.py CHANGED
@@ -24,6 +24,73 @@ generator = None
 accelerator = None
 model_path = None
 
+def load_pipeline(accelerator, model_path):
+    # Load scheduler
+    scheduler = DDPMScheduler.from_pretrained(
+        model_path,
+        subfolder="stable-diffusion-2-1-base/scheduler"
+    )
+
+    # Load models
+    text_encoder = CLIPTextModel.from_pretrained(
+        model_path,
+        subfolder="stable-diffusion-2-1-base/text_encoder"
+    )
+
+    tokenizer = CLIPTokenizer.from_pretrained(
+        model_path,
+        subfolder="stable-diffusion-2-1-base/tokenizer"
+    )
+
+    feature_extractor = CLIPImageProcessor.from_pretrained(
+        os.path.join(model_path, "stable-diffusion-2-1-base/feature_extractor")
+    )
+
+    unet = UNet2DConditionModel.from_pretrained(
+        model_path,
+        subfolder="stable-diffusion-2-1-base/unet"
+    )
+
+    controlnet = ControlNetModel.from_pretrained(
+        model_path,
+        subfolder="Controlnet"
+    )
+
+    vae = AutoencoderKL.from_pretrained(
+        model_path,
+        subfolder="vae"
+    )
+
+    # Freeze models
+    for model in [vae, text_encoder, unet, controlnet]:
+        model.requires_grad_(False)
+
+    # Initialize pipeline
+    pipeline = StableDiffusionControlNetPipeline(
+        vae=vae,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        feature_extractor=feature_extractor,
+        unet=unet,
+        controlnet=controlnet,
+        scheduler=scheduler,
+        safety_checker=None,
+        requires_safety_checker=False,
+    )
+
+    # Set weight dtype based on mixed precision
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Move models to accelerator device with appropriate dtype
+    for model in [text_encoder, vae, unet, controlnet]:
+        model.to(accelerator.device, dtype=weight_dtype)
+
+    return pipeline
+
 @spaces.GPU
 def initialize_models():
     global pipeline, generator, accelerator, model_path
@@ -41,62 +108,8 @@ def initialize_models():
         token=os.environ['Read2']
     )
 
-    # Load models from local directory
-    scheduler = DDPMScheduler.from_pretrained(
-        os.path.join(model_path, "stable-diffusion-2-1-base/scheduler")
-    )
-
-    text_encoder = CLIPTextModel.from_pretrained(
-        os.path.join(model_path, "stable-diffusion-2-1-base/text_encoder")
-    )
-
-    tokenizer = CLIPTokenizer.from_pretrained(
-        os.path.join(model_path, "stable-diffusion-2-1-base/tokenizer")
-    )
-
-    feature_extractor = CLIPImageProcessor.from_pretrained(
-        os.path.join(model_path, "stable-diffusion-2-1-base/feature_extractor")
-    )
-
-    unet = UNet2DConditionModel.from_pretrained(
-        os.path.join(model_path, "stable-diffusion-2-1-base/unet")
-    )
-
-    controlnet = ControlNetModel.from_pretrained(
-        os.path.join(model_path, "Controlnet")
-    )
-
-    vae = AutoencoderKL.from_pretrained(
-        os.path.join(model_path, "vae")
-    )
-
-    # Freeze models
-    for model in [vae, text_encoder, unet, controlnet]:
-        model.requires_grad_(False)
-
-    # Initialize pipeline
-    pipeline = StableDiffusionControlNetPipeline(
-        vae=vae,
-        text_encoder=text_encoder,
-        tokenizer=tokenizer,
-        feature_extractor=feature_extractor,
-        unet=unet,
-        controlnet=controlnet,
-        scheduler=scheduler,
-        safety_checker=None,
-        requires_safety_checker=False,
-    )
-
-    # Get weight dtype based on mixed precision
-    weight_dtype = torch.float32
-    if accelerator.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
-    # Move models to device with appropriate dtype
-    for model in [text_encoder, vae, unet, controlnet]:
-        model.to(accelerator.device, dtype=weight_dtype)
+    # Load pipeline using the original loading function
+    pipeline = load_pipeline(accelerator, model_path)
 
     # Initialize generator
     generator = torch.Generator(device=accelerator.device)
@@ -149,6 +162,8 @@ def process_image(
         t_max=0.6666,
         t_min=0.0,
         tile_diffusion=False,
+        tile_diffusion_size=512,
+        tile_diffusion_stride=256,
         added_prompt=prompt,
         image=input_pil,
         num_inference_steps=num_inference_steps,
@@ -158,6 +173,9 @@
         guidance_scale=guidance_scale,
         negative_prompt=negative_prompt,
         conditioning_scale=conditioning_scale,
+        start_steps=999,
+        start_point='lr',
+        use_vae_encode_condition=False
     )
 
     generated_image = output.images[0]
@@ -193,11 +211,7 @@ iface = gr.Interface(
    ],
    outputs=gr.Image(label="Generated Image"),
    title="Controllable Conditional Super-Resolution",
-    description="Upload an image to enhance its resolution using CCSR.",
-    examples=[
-        ["example1.jpg", "clean, sharp, detailed", "blurry, noise", 1.0, 1.0, 20, 42, 2, "adain"],
-        ["example2.jpg", "high-resolution, pristine", "artifacts, pixelated", 1.5, 1.0, 30, 123, 2, "wavelet"],
-    ]
+    description="Upload an image to enhance its resolution using CCSR."
 )
 
 if __name__ == "__main__":
 
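For context, a minimal sketch of how the refactored startup flow fits together after this commit. The Accelerator construction, the snapshot_download call, and the repo id are assumptions for illustration only, since just fragments of app.py appear in this diff; token=os.environ['Read2'] and load_pipeline(accelerator, model_path) come from the diff itself.

# Hypothetical startup sketch (not part of the commit): wires together
# weight download, the new load_pipeline helper, and the generator setup.
import os

import torch
from accelerate import Accelerator
from huggingface_hub import snapshot_download

accelerator = Accelerator()  # mixed_precision may come from the accelerate config

# Download the model repository; the repo id is a placeholder, not confirmed by the diff
model_path = snapshot_download(
    repo_id="NightRaven109/CCSR-weights",  # hypothetical
    token=os.environ['Read2'],
)

# load_pipeline() is the helper this commit introduces: it loads the scheduler,
# text encoder, tokenizer, feature extractor, UNet, ControlNet, and VAE from
# model_path, freezes them, and moves them to accelerator.device in the dtype
# matching accelerator.mixed_precision.
pipeline = load_pipeline(accelerator, model_path)

generator = torch.Generator(device=accelerator.device)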