linoyts HF staff commited on
Commit
4a2500d
·
verified ·
1 Parent(s): f67b21b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +317 -0
app.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import random
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import PIL.Image
9
+ import torch
10
+ import torchvision.transforms.functional as TF
11
+
12
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
13
+ from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler
14
+ from controlnet_aux import PidiNetDetector, HEDdetector
15
+ from diffusers.utils import load_image
16
+ from huggingface_hub import HfApi
17
+ from pathlib import Path
18
+ from PIL import Image
19
+ import torch
20
+ import numpy as np
21
+ import cv2
22
+ import os
23
+ import random
24
+
25
+
26
+ def nms(x, t, s):
27
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
28
+
29
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
30
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
31
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
32
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
33
+
34
+ y = np.zeros_like(x)
35
+
36
+ for f in [f1, f2, f3, f4]:
37
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
38
+
39
+ z = np.zeros_like(y, dtype=np.uint8)
40
+ z[y > t] = 255
41
+ return z
42
+
43
+
44
+ DESCRIPTION = '''#
45
+ sketch to image with SDXL, using [@xinsir](https://huggingface.co/xinsir) [scribble sdxl controlnet](https://huggingface.co/xinsir/controlnet-scribble-sdxl-1.0)
46
+ '''
47
+
48
+ if not torch.cuda.is_available():
49
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
50
+
51
+ style_list = [
52
+ {
53
+ "name": "(No style)",
54
+ "prompt": "{prompt}",
55
+ "negative_prompt": "",
56
+ },
57
+ {
58
+ "name": "Cinematic",
59
+ "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
60
+ "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
61
+ },
62
+ {
63
+ "name": "3D Model",
64
+ "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
65
+ "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
66
+ },
67
+ {
68
+ "name": "Anime",
69
+ "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
70
+ "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
71
+ },
72
+ {
73
+ "name": "Digital Art",
74
+ "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
75
+ "negative_prompt": "photo, photorealistic, realism, ugly",
76
+ },
77
+ {
78
+ "name": "Photographic",
79
+ "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
80
+ "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
81
+ },
82
+ {
83
+ "name": "Pixel art",
84
+ "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
85
+ "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
86
+ },
87
+ {
88
+ "name": "Fantasy art",
89
+ "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
90
+ "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
91
+ },
92
+ {
93
+ "name": "Neonpunk",
94
+ "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
95
+ "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
96
+ },
97
+ {
98
+ "name": "Manga",
99
+ "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
100
+ "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
101
+ },
102
+ ]
103
+
104
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
105
+ STYLE_NAMES = list(styles.keys())
106
+ DEFAULT_STYLE_NAME = "(No style)"
107
+
108
+
109
+ def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
110
+ p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
111
+ return p.replace("{prompt}", positive), n + negative
112
+
113
+
114
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
115
+
116
+ eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
117
+
118
+
119
+ controlnet = ControlNetModel.from_pretrained(
120
+ "xinsir/controlnet-scribble-sdxl-1.0",
121
+ torch_dtype=torch.float16
122
+ )
123
+
124
+ # when test with other base model, you need to change the vae also.
125
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
126
+
127
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
128
+ "stabilityai/stable-diffusion-xl-base-1.0",
129
+ controlnet=controlnet,
130
+ vae=vae,
131
+ torch_dtype=torch.float16,
132
+ scheduler=eulera_scheduler,
133
+ )
134
+ pipe.to(device)
135
+ # Load model.
136
+
137
+ MAX_SEED = np.iinfo(np.int32).max
138
+ processor = HEDdetector.from_pretrained('lllyasviel/Annotators')
139
+ def nms(x, t, s):
140
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
141
+
142
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
143
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
144
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
145
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
146
+
147
+ y = np.zeros_like(x)
148
+
149
+ for f in [f1, f2, f3, f4]:
150
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
151
+
152
+ z = np.zeros_like(y, dtype=np.uint8)
153
+ z[y > t] = 255
154
+ return z
155
+
156
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
157
+ if randomize_seed:
158
+ seed = random.randint(0, MAX_SEED)
159
+ return seed
160
+
161
+
162
+ def run(
163
+ image: PIL.Image.Image,
164
+ prompt: str,
165
+ negative_prompt: str,
166
+ style_name: str = DEFAULT_STYLE_NAME,
167
+ num_steps: int = 25,
168
+ guidance_scale: float = 5,
169
+ controlnet_conditioning_scale: float = 1.0,
170
+ seed: int = 0,
171
+ use_hed: bool = False,
172
+ progress=gr.Progress(track_tqdm=True),
173
+ ) -> PIL.Image.Image:
174
+ # image = image.convert("RGB")
175
+ # image = TF.to_tensor(image) > 0.5
176
+ # image = TF.to_pil_image(image.to(torch.float32))
177
+ width, height = image['composite'].size
178
+ ratio = np.sqrt(1024. * 1024. / (width * height))
179
+ new_width, new_height = int(width * ratio), int(height * ratio)
180
+ image = image['composite'].resize((new_width, new_height))
181
+
182
+ if use_hed:
183
+ controlnet_img = processor(image, scribble=False)
184
+ # following is some processing to simulate human sketch draw, different threshold can generate different width of lines
185
+ controlnet_img = np.array(controlnet_img)
186
+ controlnet_img = nms(controlnet_img, 127, 3)
187
+ controlnet_img = cv2.GaussianBlur(controlnet_img, (0, 0), 3)
188
+
189
+ # higher threshold, thiner line
190
+ random_val = int(round(random.uniform(0.01, 0.10), 2) * 255)
191
+ controlnet_img[controlnet_img > random_val] = 255
192
+ controlnet_img[controlnet_img < 255] = 0
193
+ image = Image.fromarray(controlnet_img)
194
+
195
+
196
+
197
+ prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
198
+
199
+ generator = torch.Generator(device=device).manual_seed(seed)
200
+ out = pipe(
201
+ prompt=prompt,
202
+ negative_prompt=negative_prompt,
203
+ image=image,
204
+ num_inference_steps=num_steps,
205
+ generator=generator,
206
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
207
+ guidance_scale=guidance_scale,
208
+ width=new_width,
209
+ height=new_height,
210
+ ).images[0]
211
+
212
+ return out
213
+
214
+
215
+ with gr.Blocks(css="style.css") as demo:
216
+ gr.Markdown(DESCRIPTION, elem_id="description")
217
+ gr.DuplicateButton(
218
+ value="Duplicate Space for private use",
219
+ elem_id="duplicate-button",
220
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
221
+ )
222
+
223
+ with gr.Row():
224
+ with gr.Column():
225
+ with gr.Group():
226
+ image = gr.ImageEditor(type="pil", image_mode="L", crop_size=(512, 512),brush=gr.Brush(color_mode="fixed", colors=["#00000"]))
227
+ prompt = gr.Textbox(label="Prompt")
228
+ style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
229
+ use_hed = gr.Checkbox(label="use HED detector", value=False)
230
+ run_button = gr.Button("Run")
231
+ with gr.Accordion("Advanced options", open=False):
232
+ negative_prompt = gr.Textbox(
233
+ label="Negative prompt",
234
+ value=" extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured",
235
+ )
236
+ num_steps = gr.Slider(
237
+ label="Number of steps",
238
+ minimum=1,
239
+ maximum=50,
240
+ step=1,
241
+ value=25,
242
+ )
243
+ guidance_scale = gr.Slider(
244
+ label="Guidance scale",
245
+ minimum=0.1,
246
+ maximum=10.0,
247
+ step=0.1,
248
+ value=5,
249
+ )
250
+ controlnet_conditioning_scale = gr.Slider(
251
+ label="controlnet conditioning scale",
252
+ minimum=0.5,
253
+ maximum=5.0,
254
+ step=0.1,
255
+ value=0.9,
256
+ )
257
+ seed = gr.Slider(
258
+ label="Seed",
259
+ minimum=0,
260
+ maximum=MAX_SEED,
261
+ step=1,
262
+ value=0,
263
+ )
264
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
265
+
266
+ with gr.Column():
267
+ result = gr.Image(label="Result", height=400)
268
+
269
+ inputs = [
270
+ image,
271
+ prompt,
272
+ negative_prompt,
273
+ style,
274
+ num_steps,
275
+ guidance_scale,
276
+ controlnet_conditioning_scale,
277
+ seed,
278
+ use_hed,
279
+ ]
280
+ prompt.submit(
281
+ fn=randomize_seed_fn,
282
+ inputs=[seed, randomize_seed],
283
+ outputs=seed,
284
+ queue=False,
285
+ api_name=False,
286
+ ).then(
287
+ fn=run,
288
+ inputs=inputs,
289
+ outputs=result,
290
+ api_name=False,
291
+ )
292
+ negative_prompt.submit(
293
+ fn=randomize_seed_fn,
294
+ inputs=[seed, randomize_seed],
295
+ outputs=seed,
296
+ queue=False,
297
+ api_name=False,
298
+ ).then(
299
+ fn=run,
300
+ inputs=inputs,
301
+ outputs=result,
302
+ api_name=False,
303
+ )
304
+ run_button.click(
305
+ fn=randomize_seed_fn,
306
+ inputs=[seed, randomize_seed],
307
+ outputs=seed,
308
+ queue=False,
309
+ api_name=False,
310
+ ).then(
311
+ fn=run,
312
+ inputs=inputs,
313
+ outputs=result,
314
+ api_name=False,
315
+ )
316
+
317
+ demo.queue().launch()