IDKiro committed on
Commit
df1b0df
·
verified ·
1 Parent(s): 033d00f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -17
app.py CHANGED
@@ -5,7 +5,7 @@ from PIL import Image
5
 
6
  import torch
7
  import torchvision.transforms.functional as F
8
- from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
9
  import gradio as gr
10
 
11
  device = "cuda"
@@ -19,6 +19,16 @@ pipe = StableDiffusionControlNetPipeline.from_pretrained(
19
  )
20
  pipe.to(device)
21
 
 
 
 
 
 
 
 
 
 
 
22
  style_list = [
23
  {
24
  "name": "No Style",
@@ -81,9 +91,15 @@ def run(
81
  prompt_template,
82
  style_name,
83
  controlnet_conditioning_scale,
 
84
  device_type="GPU",
85
  param_dtype="torch.float16",
86
  ):
 
 
 
 
 
87
  if device_type == "CPU":
88
  device = "cpu"
89
  param_dtype = "torch.float32"
@@ -118,24 +134,28 @@ def run(
118
  return output_pil
119
 
120
 
121
- with gr.Blocks() as demo:
122
  gr.Markdown("# SDXS-512-DreamShaper-Sketch")
123
- gr.Markdown("[SDXS: Real-Time One-Step Latent Diffusion Models with Image Conditions](https://arxiv.org/abs/2403.16627) | [GitHub](https://github.com/IDKiro/sdxs)")
124
- with gr.Row(elem_id="main_row"):
125
- with gr.Column(elem_id="column_input"):
126
- gr.Markdown("## INPUT", elem_id="input_header")
 
 
127
  image = gr.Sketchpad(
128
  type="pil",
129
  image_mode="RGBA",
130
  brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=8),
131
- crop_size=(512, 512),
132
  )
133
 
134
- # gr.Markdown("## Prompt", elem_id="tools_header")
135
  prompt = gr.Textbox(label="Prompt", value="", show_label=True)
136
  with gr.Row():
137
  style = gr.Dropdown(
138
- label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1
 
 
 
139
  )
140
  prompt_temp = gr.Textbox(
141
  label="Prompt Style Template",
@@ -148,6 +168,15 @@ with gr.Blocks() as demo:
148
  label="Control Strength", minimum=0, maximum=1, step=0.01, value=0.8
149
  )
150
 
 
 
 
 
 
 
 
 
 
151
  device_choices = ["GPU", "CPU"]
152
  device_type = gr.Radio(
153
  device_choices,
@@ -166,16 +195,19 @@ with gr.Blocks() as demo:
166
  info="To save GPU memory, use torch.float16. For better quality, use torch.float32.",
167
  )
168
 
169
- with gr.Column(elem_id="column_output"):
170
- gr.Markdown("## OUTPUT", elem_id="output_header")
171
  result = gr.Image(
172
  label="Result",
173
- height=512,
174
- width=512,
175
- elem_id="output_image",
176
  show_label=False,
177
  show_download_button=True,
178
  )
 
 
 
 
 
 
179
 
180
  inputs = [
181
  image,
@@ -183,6 +215,7 @@ with gr.Blocks() as demo:
183
  prompt_temp,
184
  style,
185
  controlnet_conditioning_scale,
 
186
  device_type,
187
  param_dtype,
188
  ]
@@ -190,9 +223,25 @@ with gr.Blocks() as demo:
190
 
191
  prompt.change(fn=run, inputs=inputs, outputs=outputs)
192
  style.change(lambda x: styles[x], inputs=[style], outputs=[prompt_temp]).then(
193
- fn=run, inputs=inputs, outputs=outputs,)
194
- image.change(run, inputs=inputs, outputs=outputs,)
195
- controlnet_conditioning_scale.change(run, inputs=inputs, outputs=outputs,)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  if __name__ == "__main__":
198
  demo.queue().launch()
 
5
 
6
  import torch
7
  import torchvision.transforms.functional as F
8
+ from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, AutoencoderTiny, AutoencoderKL
9
  import gradio as gr
10
 
11
  device = "cuda"
 
19
  )
20
  pipe.to(device)
21
 
22
+ vae_tiny = AutoencoderTiny.from_pretrained(
23
+ "IDKiro/sdxs-512-dreamshaper", subfolder="vae"
24
+ )
25
+ vae_tiny.to(device, dtype=weight_type)
26
+
27
+ vae_large = AutoencoderKL.from_pretrained(
28
+ "IDKiro/sdxs-512-dreamshaper", subfolder="vae_large"
29
+ )
30
+ vae_large.to(device, dtype=weight_type)  # move the large VAE itself (previously vae_tiny was moved twice, leaving vae_large on its default device/dtype)
31
+
32
  style_list = [
33
  {
34
  "name": "No Style",
 
91
  prompt_template,
92
  style_name,
93
  controlnet_conditioning_scale,
94
+ vae_type="tiny vae",
95
  device_type="GPU",
96
  param_dtype="torch.float16",
97
  ):
98
+ if vae_type == "tiny vae":
99
+ pipe.vae = vae_tiny
100
+ elif vae_type == "large vae":
101
+ pipe.vae = vae_large
102
+
103
  if device_type == "CPU":
104
  device = "cpu"
105
  param_dtype = "torch.float32"
 
134
  return output_pil
135
 
136
 
137
+ with gr.Blocks(theme="monochrome") as demo:
138
  gr.Markdown("# SDXS-512-DreamShaper-Sketch")
139
+ gr.Markdown(
140
+ "[SDXS: Real-Time One-Step Latent Diffusion Models with Image Conditions](https://arxiv.org/abs/2403.16627) | [GitHub](https://github.com/IDKiro/sdxs)"
141
+ )
142
+ with gr.Row():
143
+ with gr.Column():
144
+ gr.Markdown("## INPUT")
145
  image = gr.Sketchpad(
146
  type="pil",
147
  image_mode="RGBA",
148
  brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=8),
149
+ crop_size="1:1",
150
  )
151
 
 
152
  prompt = gr.Textbox(label="Prompt", value="", show_label=True)
153
  with gr.Row():
154
  style = gr.Dropdown(
155
+ label="Style",
156
+ choices=STYLE_NAMES,
157
+ value=DEFAULT_STYLE_NAME,
158
+ scale=1,
159
  )
160
  prompt_temp = gr.Textbox(
161
  label="Prompt Style Template",
 
168
  label="Control Strength", minimum=0, maximum=1, step=0.01, value=0.8
169
  )
170
 
171
+ vae_choices = ["tiny vae", "large vae"]
172
+ vae_type = gr.Radio(
173
+ vae_choices,
174
+ label="Image Decoder Type",
175
+ value=vae_choices[0],
176
+ interactive=True,
177
+ info="To save GPU memory, use tiny vae. For better quality, use large vae.",
178
+ )
179
+
180
  device_choices = ["GPU", "CPU"]
181
  device_type = gr.Radio(
182
  device_choices,
 
195
  info="To save GPU memory, use torch.float16. For better quality, use torch.float32.",
196
  )
197
 
198
+ with gr.Column():
199
+ gr.Markdown("## OUTPUT")
200
  result = gr.Image(
201
  label="Result",
 
 
 
202
  show_label=False,
203
  show_download_button=True,
204
  )
205
+ run_button = gr.Button("Run")
206
+ gr.Markdown("### Instructions")
207
+ gr.Markdown("**1**. Enter a text prompt (e.g. cat)")
208
+ gr.Markdown("**2**. Start sketching")
209
+ gr.Markdown("**3**. Change the image style using a style template")
210
+ gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider")
211
 
212
  inputs = [
213
  image,
 
215
  prompt_temp,
216
  style,
217
  controlnet_conditioning_scale,
218
+ vae_type,
219
  device_type,
220
  param_dtype,
221
  ]
 
223
 
224
  prompt.change(fn=run, inputs=inputs, outputs=outputs)
225
  style.change(lambda x: styles[x], inputs=[style], outputs=[prompt_temp]).then(
226
+ fn=run,
227
+ inputs=inputs,
228
+ outputs=outputs,
229
+ )
230
+ image.change(
231
+ run,
232
+ inputs=inputs,
233
+ outputs=outputs,
234
+ )
235
+ controlnet_conditioning_scale.change(
236
+ run,
237
+ inputs=inputs,
238
+ outputs=outputs,
239
+ )
240
+ run_button.click(
241
+ run,
242
+ inputs=inputs,
243
+ outputs=outputs,
244
+ )
245
 
246
  if __name__ == "__main__":
247
  demo.queue().launch()