NightRaven109 committed on
Commit b22f2c5 · verified · 1 Parent(s): 83686fb

Update app.py

Files changed (1): app.py (+86 -138)
app.py CHANGED
@@ -2,115 +2,52 @@ import os
 import torch
 import gradio as gr
 import spaces
-import numpy as np
 from PIL import Image
-import safetensors.torch
 from huggingface_hub import snapshot_download
+from test_ccsr_tile import main, load_pipeline
+import argparse
 from accelerate import Accelerator
-from accelerate.utils import set_seed
-from diffusers import (
-    AutoencoderKL,
-    DDPMScheduler,
-    UNet2DConditionModel,
-)
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
-from models.controlnet import ControlNetModel
-from pipelines.pipeline_ccsr import StableDiffusionControlNetPipeline
-from myutils.wavelet_color_fix import wavelet_color_fix, adain_color_fix
 
-# Initialize global variables for models
+# Initialize global variables
 pipeline = None
 generator = None
 accelerator = None
-model_path = None
-
-def load_pipeline(accelerator, model_path):
-    # Load scheduler
-    scheduler = DDPMScheduler.from_pretrained(
-        model_path,
-        subfolder="stable-diffusion-2-1-base/scheduler"
-    )
-
-    # Load models
-    text_encoder = CLIPTextModel.from_pretrained(
-        model_path,
-        subfolder="stable-diffusion-2-1-base/text_encoder"
-    )
-
-    tokenizer = CLIPTokenizer.from_pretrained(
-        model_path,
-        subfolder="stable-diffusion-2-1-base/tokenizer"
-    )
-
-    feature_extractor = CLIPImageProcessor.from_pretrained(
-        os.path.join(model_path, "stable-diffusion-2-1-base/feature_extractor")
-    )
-
-    unet = UNet2DConditionModel.from_pretrained(
-        model_path,
-        subfolder="stable-diffusion-2-1-base/unet"
-    )
-
-    controlnet = ControlNetModel.from_pretrained(
-        model_path,
-        subfolder="Controlnet"
-    )
-
-    vae = AutoencoderKL.from_pretrained(
-        model_path,
-        subfolder="vae"
-    )
-
-    # Freeze models
-    for model in [vae, text_encoder, unet, controlnet]:
-        model.requires_grad_(False)
-
-    # Initialize pipeline
-    pipeline = StableDiffusionControlNetPipeline(
-        vae=vae,
-        text_encoder=text_encoder,
-        tokenizer=tokenizer,
-        feature_extractor=feature_extractor,
-        unet=unet,
-        controlnet=controlnet,
-        scheduler=scheduler,
-        safety_checker=None,
-        requires_safety_checker=False,
-    )
-
-    # Set weight dtype based on mixed precision
-    weight_dtype = torch.float32
-    if accelerator.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
-    # Move models to accelerator device with appropriate dtype
-    for model in [text_encoder, vae, unet, controlnet]:
-        model.to(accelerator.device, dtype=weight_dtype)
-
-    return pipeline
+
+class Args:
+    def __init__(self, **kwargs):
+        self.__dict__.update(kwargs)
 
 @spaces.GPU
 def initialize_models():
-    global pipeline, generator, accelerator, model_path
-
-    # Initialize accelerator
-    accelerator = Accelerator(
-        mixed_precision="fp16",
-        gradient_accumulation_steps=1
-    )
-
+    global pipeline, generator, accelerator
+
     try:
-        # Download the entire repository
+        # Download model repository
         model_path = snapshot_download(
             repo_id="NightRaven109/CCSRModels",
             token=os.environ['Read2']
         )
-
-        # Load pipeline using the original loading function
-        pipeline = load_pipeline(accelerator, model_path)
+
+        # Set up default arguments
+        args = Args(
+            pretrained_model_path=os.path.join(model_path, "stable-diffusion-2-1-base"),
+            controlnet_model_path=os.path.join(model_path, "Controlnet"),
+            vae_model_path=os.path.join(model_path, "vae"),
+            mixed_precision="fp16",
+            tile_vae=False,
+            sample_method="ddpm",
+            vae_encoder_tile_size=1024,
+            vae_decoder_tile_size=224
+        )
+
+        # Initialize accelerator
+        accelerator = Accelerator(
+            mixed_precision=args.mixed_precision,
+        )
+
+        # Load pipeline
+        pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention=False)
 
         # Initialize generator
         generator = torch.Generator(device=accelerator.device)
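Note: the `Args` container added in this hunk exists only to satisfy `load_pipeline` from `test_ccsr_tile.py`, which was written against `argparse` results. A minimal sketch of the same pattern using the standard library's `types.SimpleNamespace`, which is a drop-in equivalent for this use (field values below are illustrative):

```python
from types import SimpleNamespace

# Any keyword becomes an attribute, so code written against argparse
# results (args.mixed_precision, args.tile_vae, ...) keeps working.
args = SimpleNamespace(mixed_precision="fp16", tile_vae=False)
assert args.mixed_precision == "fp16" and args.tile_vae is False
```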
 
@@ -137,72 +74,83 @@ def process_image(
     if pipeline is None:
         if not initialize_models():
             return None
-
+
     try:
-        # Set seed
+        # Create args object with all necessary parameters
+        args = Args(
+            added_prompt=prompt,
+            negative_prompt=negative_prompt,
+            guidance_scale=guidance_scale,
+            conditioning_scale=conditioning_scale,
+            num_inference_steps=num_inference_steps,
+            seed=seed,
+            upscale=upscale_factor,
+            process_size=512,
+            align_method=color_fix_method,
+            t_max=0.6666,
+            t_min=0.0,
+            tile_diffusion=False,
+            tile_diffusion_size=None,
+            tile_diffusion_stride=None,
+            start_steps=999,
+            start_point='lr',
+            use_vae_encode_condition=False,
+            sample_times=1
+        )
+
+        # Set seed if provided
         if seed is not None:
             generator.manual_seed(seed)
-
+
         # Process input image
         validation_image = Image.fromarray(input_image)
         ori_width, ori_height = validation_image.size
 
-        # Resize logic from original script
+        # Resize logic
         resize_flag = False
-        rscale = upscale_factor
-        process_size = 512  # Same as args.process_size in original
-
-        if ori_width < process_size//rscale or ori_height < process_size//rscale:
-            scale = (process_size//rscale)/min(ori_width, ori_height)
-            tmp_image = validation_image.resize((round(scale*ori_width), round(scale*ori_height)))
-            validation_image = tmp_image
+        if ori_width < args.process_size//args.upscale or ori_height < args.process_size//args.upscale:
+            scale = (args.process_size//args.upscale)/min(ori_width, ori_height)
+            validation_image = validation_image.resize((round(scale*ori_width), round(scale*ori_height)))
             resize_flag = True
 
-        validation_image = validation_image.resize((validation_image.size[0]*rscale, validation_image.size[1]*rscale))
+        validation_image = validation_image.resize((validation_image.size[0]*args.upscale, validation_image.size[1]*args.upscale))
         validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))
         width, height = validation_image.size
-
-        # Move pipeline to GPU for processing
-        pipeline.to(accelerator.device)
-
+
         # Generate image
-        with torch.no_grad():
-            inference_time, output = pipeline(
-                0.6666,  # t_max
-                0.0,  # t_min
-                False,  # tile_diffusion
-                None,  # tile_diffusion_size
-                None,  # tile_diffusion_stride
-                prompt,
-                validation_image,
-                num_inference_steps=num_inference_steps,
-                generator=generator,
-                height=height,
-                width=width,
-                guidance_scale=guidance_scale,
-                negative_prompt=negative_prompt,
-                conditioning_scale=conditioning_scale,
-                start_steps=999,
-                start_point='lr',
-                use_vae_encode_condition=False
-            )
-
+        inference_time, output = pipeline(
+            args.t_max,
+            args.t_min,
+            args.tile_diffusion,
+            args.tile_diffusion_size,
+            args.tile_diffusion_stride,
+            args.added_prompt,
+            validation_image,
+            num_inference_steps=args.num_inference_steps,
+            generator=generator,
+            height=height,
+            width=width,
+            guidance_scale=args.guidance_scale,
+            negative_prompt=args.negative_prompt,
+            conditioning_scale=args.conditioning_scale,
+            start_steps=args.start_steps,
+            start_point=args.start_point,
+            use_vae_encode_condition=args.use_vae_encode_condition,
+        )
+
         image = output.images[0]
 
         # Apply color fixing if specified
-        if color_fix_method != "none":
-            fix_func = wavelet_color_fix if color_fix_method == "wavelet" else adain_color_fix
+        if args.align_method != "none":
+            from myutils.wavelet_color_fix import wavelet_color_fix, adain_color_fix
+            fix_func = wavelet_color_fix if args.align_method == "wavelet" else adain_color_fix
             image = fix_func(image, validation_image)
 
         if resize_flag:
-            image = image.resize((ori_width*rscale, ori_height*rscale))
-
-        # Move pipeline back to CPU
-        pipeline.to("cpu")
-        torch.cuda.empty_cache()
-
+            image = image.resize((ori_width*args.upscale, ori_height*args.upscale))
+
         return image
-
+
     except Exception as e:
         print(f"Error processing image: {str(e)}")
         return None
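Both versions of the resize logic end with `validation_image.size[0]//8*8`: the Stable Diffusion VAE downsamples by a factor of 8, so both spatial dimensions must be divisible by 8 before encoding. A small sketch of that rounding step in isolation (the function name is ours):

```python
from PIL import Image

def snap_to_multiple_of_8(img: Image.Image) -> Image.Image:
    # Floor each side to the nearest multiple of 8, mirroring the
    # //8*8 step in the diff; the SD VAE encodes images into latents
    # 8x smaller per side, so dimensions must divide evenly.
    w, h = img.size
    return img.resize((w // 8 * 8, h // 8 * 8))

# Example: a 250x167 input upscaled 4x is 1000x668, snapped to 1000x664.
```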
 
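The deleted `load_pipeline` mapped `accelerator.mixed_precision` to a torch dtype by hand before moving each model to the device; the refactor delegates that to `test_ccsr_tile.load_pipeline`. For reference, a hedged sketch of the mapping the old code performed (the dict lookup is our restructuring of its if/elif chain):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")

# fp16 -> float16, bf16 -> bfloat16, anything else -> float32,
# matching the branch chain in the deleted load_pipeline.
weight_dtype = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}.get(accelerator.mixed_precision, torch.float32)
```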
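The new loader is called with `enable_xformers_memory_efficient_attention=False`. If the Space image ships xformers, the same optimization could plausibly be applied after loading via the standard diffusers method; a hedged sketch that assumes `pipeline` is a diffusers pipeline object:

```python
# Hedged sketch, not part of the commit: enable memory-efficient
# attention only when xformers is actually importable.
try:
    import xformers  # noqa: F401
    pipeline.enable_xformers_memory_efficient_attention()
except ImportError:
    pass  # stick with the default attention kernels
```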
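One robustness note: `os.environ['Read2']` raises a bare `KeyError` when the Space secret is missing, before any download starts. A hedged alternative (our suggestion, not part of the commit) that fails with a clearer message:

```python
import os
from huggingface_hub import snapshot_download

# Suggested hardening: surface a clear error when the Space secret
# is missing instead of a bare KeyError.
token = os.environ.get("Read2")
if token is None:
    raise RuntimeError("Space secret 'Read2' is not set; add an HF read token.")

model_path = snapshot_download(repo_id="NightRaven109/CCSRModels", token=token)
```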