NightRaven109 committed (verified)
Commit 877f3e5 · Parent(s): d7ca9c8

Update app.py

Files changed (1):
  app.py (+43, -26)
app.py CHANGED
@@ -49,14 +49,17 @@ def initialize_models():
         # Load pipeline
         pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention=False)
 
-        # Set pipeline to eval mode
+        # Ensure all models are in eval mode
         pipeline.unet.eval()
         pipeline.controlnet.eval()
         pipeline.vae.eval()
         pipeline.text_encoder.eval()
 
+        # Move pipeline to CUDA
+        pipeline = pipeline.to("cuda")
+
         # Initialize generator
-        generator = torch.Generator(device=accelerator.device)
+        generator = torch.Generator("cuda")
 
         return True
 
@@ -64,7 +67,7 @@ def initialize_models():
         print(f"Error initializing models: {str(e)}")
         return False
 
-@spaces.GPU(processing_timeout=180)  # Increased timeout for longer processing
+@spaces.GPU(processing_timeout=180)
 def process_image(
     input_image,
     prompt="clean, high-resolution, 8k",
@@ -78,11 +81,12 @@ def process_image(
 ):
     global pipeline, generator, accelerator
 
-    if pipeline is None:
-        if not initialize_models():
-            return None
-
     try:
+        # Initialize models if not already done
+        if pipeline is None:
+            if not initialize_models():
+                return None
+
         # Create args object with all necessary parameters
         args = Args(
             added_prompt=prompt,
@@ -124,27 +128,38 @@ def process_image(
         validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))
         width, height = validation_image.size
 
+        # Ensure pipeline is on CUDA and in eval mode
+        pipeline = pipeline.to("cuda")
+        pipeline.unet.eval()
+        pipeline.controlnet.eval()
+        pipeline.vae.eval()
+        pipeline.text_encoder.eval()
+
         # Generate image
         with torch.no_grad():
-            inference_time, output = pipeline(
-                args.t_max,
-                args.t_min,
-                args.tile_diffusion,
-                args.tile_diffusion_size,
-                args.tile_diffusion_stride,
-                args.added_prompt,
-                validation_image,
-                num_inference_steps=args.num_inference_steps,
-                generator=generator,
-                height=height,
-                width=width,
-                guidance_scale=args.guidance_scale,
-                negative_prompt=args.negative_prompt,
-                conditioning_scale=args.conditioning_scale,
-                start_steps=args.start_steps,
-                start_point=args.start_point,
-                use_vae_encode_condition=args.use_vae_encode_condition,
-            )
+            try:
+                inference_time, output = pipeline(
+                    args.t_max,
+                    args.t_min,
+                    args.tile_diffusion,
+                    args.tile_diffusion_size,
+                    args.tile_diffusion_stride,
+                    args.added_prompt,
+                    validation_image,
+                    num_inference_steps=args.num_inference_steps,
+                    generator=generator,
+                    height=height,
+                    width=width,
+                    guidance_scale=args.guidance_scale,
+                    negative_prompt=args.negative_prompt,
+                    conditioning_scale=args.conditioning_scale,
+                    start_steps=args.start_steps,
+                    start_point=args.start_point,
+                    use_vae_encode_condition=args.use_vae_encode_condition,
+                )
+            except Exception as e:
+                print(f"Pipeline execution error: {str(e)}")
+                raise
 
         image = output.images[0]
 
@@ -161,6 +176,8 @@
 
     except Exception as e:
         print(f"Error processing image: {str(e)}")
+        import traceback
+        traceback.print_exc()
         return None
 
 # Create Gradio interface
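
The net effect of this commit is to defer all CUDA-dependent work until a @spaces.GPU-decorated call is running: model initialization moves inside process_image, the pipeline is explicitly moved to "cuda" and re-set to eval mode before inference, and the generator is created directly on the CUDA device instead of on accelerator.device. On ZeroGPU Spaces the GPU is attached only for the duration of a decorated call, so device setup done at import time can resolve to CPU or fail, which is the likely motivation. Below is a minimal sketch of that lazy-initialization pattern, not the app's actual code: load_model and predict are hypothetical stand-ins for load_pipeline/process_image, and the decorator uses the duration keyword from the spaces package documentation (app.py passes processing_timeout=180).

import spaces
import torch

model = None  # lazily initialized on the first GPU call


def load_model() -> torch.nn.Module:
    # Hypothetical stand-in for the app's load_pipeline(...).
    return torch.nn.Linear(16, 16)


@spaces.GPU(duration=180)  # GPU is attached only while this function runs
def predict(x: torch.Tensor) -> torch.Tensor:
    global model
    if model is None:
        # Initialize inside the GPU context, as the commit does, so that
        # .to("cuda") happens while a GPU is actually attached.
        model = load_model().to("cuda").eval()
    # Create the RNG directly on the CUDA device, mirroring the commit's
    # switch to torch.Generator("cuda").
    generator = torch.Generator("cuda").manual_seed(42)
    noise = torch.randn(x.shape, generator=generator, device="cuda")
    with torch.no_grad():
        return model(x.to("cuda") + noise)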