Deadmon commited on
Commit
379b6d3
·
verified ·
1 Parent(s): 69dc758

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +178 -93
app.py CHANGED
@@ -34,7 +34,6 @@ function refresh() {
34
  }
35
  }
36
  """
37
-
38
  def nms(x, t, s):
39
  x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
40
 
@@ -134,14 +133,17 @@ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
134
  STYLE_NAMES = list(styles.keys())
135
  DEFAULT_STYLE_NAME = "(No style)"
136
 
 
137
  def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
138
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
139
  return p.replace("{prompt}", positive), n + negative
140
 
 
141
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
142
 
143
  eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
144
 
 
145
  controlnet = ControlNetModel.from_pretrained(
146
  "xinsir/controlnet-scribble-sdxl-1.0",
147
  torch_dtype=torch.float16
@@ -173,98 +175,181 @@ pipe_canny = StableDiffusionXLControlNetPipeline.from_pretrained(
173
 
174
  pipe_canny.to(device)
175
 
176
- MAX_IMAGE_PIXELS = 100000000 # Adjust if needed.
177
-
178
- def resize_image(image, max_pixels=MAX_IMAGE_PIXELS):
179
- """Resize an image to have at most max_pixels, maintaining aspect ratio."""
180
- width, height = image.size
181
- if width * height > max_pixels:
182
- scale_factor = (max_pixels / (width * height)) ** 0.5
183
- new_size = (int(width * scale_factor), int(height * scale_factor))
184
- return image.resize(new_size, Image.ANTIALIAS)
185
- return image
186
-
187
- def process(image, prompt, style, detector_name):
188
- # Convert image to RGB mode if it's not already
189
- if image.mode != 'RGB':
190
- image = image.convert('RGB')
191
- image = resize_image(image)
192
-
193
- width, height = image.size
194
-
195
- prompt, negative_prompt = apply_style(style, prompt)
196
-
197
- if detector_name == "hed":
198
- image = HWC3(np.array(image, dtype=np.uint8))
199
- with torch.no_grad():
200
- detected_map = hed(image, scribble=True)
201
- detected_map = HWC3(detected_map)
202
- image = Image.fromarray(detected_map)
203
- images = pipe(prompt, negative_prompt=negative_prompt, image=image, height=height, width=width).images
204
- return images[0]
205
- elif detector_name == "scribble":
206
- image = HWC3(np.array(image, dtype=np.uint8))
207
- with torch.no_grad():
208
- detected_map = nms(image, 127, 3.0)
209
- detected_map = HWC3(detected_map)
210
- image = Image.fromarray(detected_map)
211
- images = pipe(prompt, negative_prompt=negative_prompt, image=image, height=height, width=width).images
212
- return images[0]
213
- elif detector_name == "canny":
214
- image = np.array(image, dtype=np.uint8)
215
- image = cv2.Canny(image, 100, 200)
216
- image = image[:, :, None]
217
- image = np.concatenate([image, image, image], axis=2)
218
- detected_map = image
219
- image = Image.fromarray(detected_map)
220
- images = pipe_canny(prompt, negative_prompt=negative_prompt, image=image, height=height, width=width).images
221
- return images[0]
222
-
223
- block_css = (
224
- code := """
225
- #image_upload {
226
- height: 100% !important;
227
- }
228
- #prompt_input {
229
- height: 100% !important;
230
- }
231
- #select_style {
232
- height: 100% !important;
233
- }
234
- #detect_method {
235
- height: 100% !important;
236
- }
237
- #submit_button {
238
- height: 100% !important;
239
- }
240
- """
241
- )
242
 
243
- def create_demo():
244
- """Create Gradio demo."""
245
-
246
- with gr.Blocks(css=block_css) as demo:
247
- gr.Markdown(DESCRIPTION)
248
- with gr.Row():
249
- with gr.Column():
250
- input_image = gr.Image(source='upload', elem_id="image_upload", tool='editor', type="pil")
251
- prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt", elem_id="prompt_input")
252
- style = gr.Dropdown(STYLE_NAMES, value=DEFAULT_STYLE_NAME, label="Select style", elem_id="select_style")
253
- detect_method = gr.Dropdown(choices=["scribble", "hed", "canny"], value="scribble", label="Select Detect Method", elem_id="detect_method")
254
- submit_btn = gr.Button("Generate", elem_id="submit_button")
255
- with gr.Column():
256
- gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=2, height="auto")
257
-
258
- submit_btn.click(process, inputs=[input_image, prompt, style, detect_method], outputs=[gallery])
259
-
260
- # Refresh button to apply the dark theme
261
- refresh_btn = gr.Button("Refresh for Dark Theme")
262
- refresh_btn.click(None, None, None, _js=js_func)
263
-
264
- return demo
265
 
266
- hed = HEDdetector.from_pretrained('lllyasviel/ControlNet')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
- if __name__ == "__main__":
269
- demo = create_demo()
270
- demo.launch(debug=True)
 
34
  }
35
  }
36
  """
 
37
  def nms(x, t, s):
38
  x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
39
 
 
133
  STYLE_NAMES = list(styles.keys())
134
  DEFAULT_STYLE_NAME = "(No style)"
135
 
136
+
137
  def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
138
  p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
139
  return p.replace("{prompt}", positive), n + negative
140
 
141
+
142
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
143
 
144
  eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
145
 
146
+
147
  controlnet = ControlNetModel.from_pretrained(
148
  "xinsir/controlnet-scribble-sdxl-1.0",
149
  torch_dtype=torch.float16
 
175
 
176
  pipe_canny.to(device)
177
 
178
+ MAX_SEED = np.iinfo(np.int32).max
179
+ processor = HEDdetector.from_pretrained('lllyasviel/Annotators')
180
+ def nms(x, t, s):
181
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
184
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
185
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
186
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
187
+
188
+ y = np.zeros_like(x)
189
+
190
+ for f in [f1, f2, f3, f4]:
191
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
192
+
193
+ z = np.zeros_like(y, dtype=np.uint8)
194
+ z[y > t] = 255
195
+ return z
 
 
 
 
 
 
 
 
 
196
 
197
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
198
+ if randomize_seed:
199
+ seed = random.randint(0, MAX_SEED)
200
+ return seed
201
+
202
+ @spaces.GPU
203
+ def run(
204
+ image: PIL.Image.Image,
205
+ prompt: str,
206
+ negative_prompt: str,
207
+ style_name: str = DEFAULT_STYLE_NAME,
208
+ num_steps: int = 25,
209
+ guidance_scale: float = 5,
210
+ controlnet_conditioning_scale: float = 1.0,
211
+ seed: int = 0,
212
+ use_hed: bool = False,
213
+ use_canny: bool = False,
214
+ progress=gr.Progress(track_tqdm=True),
215
+ ) -> PIL.Image.Image:
216
+ width, height = image['composite'].size
217
+ ratio = np.sqrt(1024. * 1024. / (width * height))
218
+ new_width, new_height = int(width * ratio), int(height * ratio)
219
+ image = image['composite'].resize((new_width, new_height))
220
+
221
+ if use_canny:
222
+ controlnet_img = np.array(image)
223
+ controlnet_img = cv2.Canny(controlnet_img, 100, 200)
224
+ controlnet_img = HWC3(controlnet_img)
225
+ image = Image.fromarray(controlnet_img)
226
+
227
+ elif not use_hed:
228
+ controlnet_img = image
229
+ else:
230
+ controlnet_img = processor(image, scribble=False)
231
+ # following is some processing to simulate human sketch draw, different threshold can generate different width of lines
232
+ controlnet_img = np.array(controlnet_img)
233
+ controlnet_img = nms(controlnet_img, 127, 3)
234
+ controlnet_img = cv2.GaussianBlur(controlnet_img, (0, 0), 3)
235
+
236
+ # higher threshold, thiner line
237
+ random_val = int(round(random.uniform(0.01, 0.10), 2) * 255)
238
+ controlnet_img[controlnet_img > random_val] = 255
239
+ controlnet_img[controlnet_img < 255] = 0
240
+ image = Image.fromarray(controlnet_img)
241
+
242
+
243
+ prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
244
+
245
+ generator = torch.Generator(device=device).manual_seed(seed)
246
+ if use_canny:
247
+ out = pipe_canny(
248
+ prompt=prompt,
249
+ negative_prompt=negative_prompt,
250
+ image=image,
251
+ num_inference_steps=num_steps,
252
+ generator=generator,
253
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
254
+ guidance_scale=guidance_scale,
255
+ width=new_width,
256
+ height=new_height,
257
+ ).images[0]
258
+ else:
259
+ out = pipe(
260
+ prompt=prompt,
261
+ negative_prompt=negative_prompt,
262
+ image=image,
263
+ num_inference_steps=num_steps,
264
+ generator=generator,
265
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
266
+ guidance_scale=guidance_scale,
267
+ width=new_width,
268
+ height=new_height,).images[0]
269
+
270
+ return (controlnet_img, out)
271
+
272
+
273
+ with gr.Blocks(css="style.css", js=js_func) as demo:
274
+ gr.Markdown(DESCRIPTION, elem_id="description")
275
+ gr.DuplicateButton(
276
+ value="Duplicate Space for private use",
277
+ elem_id="duplicate-button",
278
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
279
+ )
280
+
281
+ with gr.Row():
282
+ with gr.Column():
283
+ with gr.Group():
284
+ image = gr.ImageEditor(type="pil", image_mode="L", crop_size=(512, 512))
285
+ prompt = gr.Textbox(label="Prompt")
286
+ style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
287
+ use_hed = gr.Checkbox(label="use HED detector", value=False, info="check this box if you upload an image and want to turn it to a sketch")
288
+ use_canny = gr.Checkbox(label="use Canny", value=False, info="check this to use ControlNet canny instead of scribble")
289
+ run_button = gr.Button("Run")
290
+ with gr.Accordion("Advanced options", open=False):
291
+ negative_prompt = gr.Textbox(
292
+ label="Negative prompt",
293
+ value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
294
+ )
295
+ num_steps = gr.Slider(
296
+ label="Number of steps",
297
+ minimum=1,
298
+ maximum=50,
299
+ step=1,
300
+ value=25,
301
+ )
302
+ guidance_scale = gr.Slider(
303
+ label="Guidance scale",
304
+ minimum=0.1,
305
+ maximum=10.0,
306
+ step=0.1,
307
+ value=5,
308
+ )
309
+ controlnet_conditioning_scale = gr.Slider(
310
+ label="controlnet conditioning scale",
311
+ minimum=0.5,
312
+ maximum=5.0,
313
+ step=0.1,
314
+ value=0.9,
315
+ )
316
+ seed = gr.Slider(
317
+ label="Seed",
318
+ minimum=0,
319
+ maximum=MAX_SEED,
320
+ step=1,
321
+ value=0,
322
+ )
323
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
324
+
325
+ with gr.Column():
326
+ with gr.Group():
327
+ image_slider = ImageSlider(position=0.5)
328
+
329
+
330
+ inputs = [
331
+ image,
332
+ prompt,
333
+ negative_prompt,
334
+ style,
335
+ num_steps,
336
+ guidance_scale,
337
+ controlnet_conditioning_scale,
338
+ seed,
339
+ use_hed,
340
+ use_canny
341
+ ]
342
+ outputs = [image_slider]
343
+ run_button.click(
344
+ fn=randomize_seed_fn,
345
+ inputs=[seed, randomize_seed],
346
+ outputs=seed,
347
+ queue=False,
348
+ api_name=False,
349
+ ).then(lambda x: None, inputs=None, outputs=image_slider).then(
350
+ fn=run, inputs=inputs, outputs=outputs
351
+ )
352
+
353
+
354
 
355
+ demo.queue().launch()