Warlord-K commited on
Commit
1aad4f8
·
1 Parent(s): b29a22b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +305 -2
app.py CHANGED
@@ -1,3 +1,306 @@
1
- import gradio as gr
 
 
 
2
  import os
3
- gr.Interface.load("models/segmind/SSD-1B", api_key=os.environ.get("HF_TOKEN")).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
  import os
6
+ import random
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+ import PIL.Image
11
+ import torch
12
+ from diffusers import AutoencoderKL, DiffusionPipeline
13
+
14
+ DESCRIPTION = "# SD-XL"
15
+ if not torch.cuda.is_available():
16
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
17
+
18
+ MAX_SEED = np.iinfo(np.int32).max
19
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
20
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
21
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
22
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
23
+ ENABLE_REFINER = os.getenv("ENABLE_REFINER", "0") == "1"
24
+
25
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26
+ if torch.cuda.is_available():
27
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
28
+ pipe = DiffusionPipeline.from_pretrained(
29
+ "segmind/SSD-1B",
30
+ vae=vae,
31
+ torch_dtype=torch.float16,
32
+ use_safetensors=True,
33
+ variant="fp16",
34
+ )
35
+ if ENABLE_REFINER:
36
+ refiner = DiffusionPipeline.from_pretrained(
37
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
38
+ vae=vae,
39
+ torch_dtype=torch.float16,
40
+ use_safetensors=True,
41
+ variant="fp16",
42
+ )
43
+
44
+ if ENABLE_CPU_OFFLOAD:
45
+ pipe.enable_model_cpu_offload()
46
+ if ENABLE_REFINER:
47
+ refiner.enable_model_cpu_offload()
48
+ else:
49
+ pipe.to(device)
50
+ if ENABLE_REFINER:
51
+ refiner.to(device)
52
+
53
+ if USE_TORCH_COMPILE:
54
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
55
+ if ENABLE_REFINER:
56
+ refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
57
+
58
+
59
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
60
+ if randomize_seed:
61
+ seed = random.randint(0, MAX_SEED)
62
+ return seed
63
+
64
+
65
+ def generate(
66
+ prompt: str,
67
+ negative_prompt: str = "",
68
+ prompt_2: str = "",
69
+ negative_prompt_2: str = "",
70
+ use_negative_prompt: bool = False,
71
+ use_prompt_2: bool = False,
72
+ use_negative_prompt_2: bool = False,
73
+ seed: int = 0,
74
+ width: int = 1024,
75
+ height: int = 1024,
76
+ guidance_scale_base: float = 5.0,
77
+ guidance_scale_refiner: float = 5.0,
78
+ num_inference_steps_base: int = 25,
79
+ num_inference_steps_refiner: int = 25,
80
+ apply_refiner: bool = False,
81
+ ) -> PIL.Image.Image:
82
+ generator = torch.Generator().manual_seed(seed)
83
+
84
+ if not use_negative_prompt:
85
+ negative_prompt = None # type: ignore
86
+ if not use_prompt_2:
87
+ prompt_2 = None # type: ignore
88
+ if not use_negative_prompt_2:
89
+ negative_prompt_2 = None # type: ignore
90
+
91
+ if not apply_refiner:
92
+ return pipe(
93
+ prompt=prompt,
94
+ negative_prompt=negative_prompt,
95
+ prompt_2=prompt_2,
96
+ negative_prompt_2=negative_prompt_2,
97
+ width=width,
98
+ height=height,
99
+ guidance_scale=guidance_scale_base,
100
+ num_inference_steps=num_inference_steps_base,
101
+ generator=generator,
102
+ output_type="pil",
103
+ ).images[0]
104
+ else:
105
+ latents = pipe(
106
+ prompt=prompt,
107
+ negative_prompt=negative_prompt,
108
+ prompt_2=prompt_2,
109
+ negative_prompt_2=negative_prompt_2,
110
+ width=width,
111
+ height=height,
112
+ guidance_scale=guidance_scale_base,
113
+ num_inference_steps=num_inference_steps_base,
114
+ generator=generator,
115
+ output_type="latent",
116
+ ).images
117
+ image = refiner(
118
+ prompt=prompt,
119
+ negative_prompt=negative_prompt,
120
+ prompt_2=prompt_2,
121
+ negative_prompt_2=negative_prompt_2,
122
+ guidance_scale=guidance_scale_refiner,
123
+ num_inference_steps=num_inference_steps_refiner,
124
+ image=latents,
125
+ generator=generator,
126
+ ).images[0]
127
+ return image
128
+
129
+
130
+ examples = [
131
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
132
+ "An astronaut riding a green horse",
133
+ ]
134
+
135
+ with gr.Blocks(css="style.css") as demo:
136
+ gr.Markdown(DESCRIPTION)
137
+ gr.DuplicateButton(
138
+ value="Duplicate Space for private use",
139
+ elem_id="duplicate-button",
140
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
141
+ )
142
+ with gr.Group():
143
+ with gr.Row():
144
+ prompt = gr.Text(
145
+ label="Prompt",
146
+ show_label=False,
147
+ max_lines=1,
148
+ placeholder="Enter your prompt",
149
+ container=False,
150
+ )
151
+ run_button = gr.Button("Run", scale=0)
152
+ result = gr.Image(label="Result", show_label=False)
153
+ with gr.Accordion("Advanced options", open=False):
154
+ with gr.Row():
155
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
156
+ use_prompt_2 = gr.Checkbox(label="Use prompt 2", value=False)
157
+ use_negative_prompt_2 = gr.Checkbox(label="Use negative prompt 2", value=False)
158
+ negative_prompt = gr.Text(
159
+ label="Negative prompt",
160
+ max_lines=1,
161
+ placeholder="Enter a negative prompt",
162
+ visible=False,
163
+ )
164
+ prompt_2 = gr.Text(
165
+ label="Prompt 2",
166
+ max_lines=1,
167
+ placeholder="Enter your prompt",
168
+ visible=False,
169
+ )
170
+ negative_prompt_2 = gr.Text(
171
+ label="Negative prompt 2",
172
+ max_lines=1,
173
+ placeholder="Enter a negative prompt",
174
+ visible=False,
175
+ )
176
+
177
+ seed = gr.Slider(
178
+ label="Seed",
179
+ minimum=0,
180
+ maximum=MAX_SEED,
181
+ step=1,
182
+ value=0,
183
+ )
184
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
185
+ with gr.Row():
186
+ width = gr.Slider(
187
+ label="Width",
188
+ minimum=256,
189
+ maximum=MAX_IMAGE_SIZE,
190
+ step=32,
191
+ value=1024,
192
+ )
193
+ height = gr.Slider(
194
+ label="Height",
195
+ minimum=256,
196
+ maximum=MAX_IMAGE_SIZE,
197
+ step=32,
198
+ value=1024,
199
+ )
200
+ apply_refiner = gr.Checkbox(label="Apply refiner", value=False, visible=ENABLE_REFINER)
201
+ with gr.Row():
202
+ guidance_scale_base = gr.Slider(
203
+ label="Guidance scale for base",
204
+ minimum=1,
205
+ maximum=20,
206
+ step=0.1,
207
+ value=5.0,
208
+ )
209
+ num_inference_steps_base = gr.Slider(
210
+ label="Number of inference steps for base",
211
+ minimum=10,
212
+ maximum=100,
213
+ step=1,
214
+ value=25,
215
+ )
216
+ with gr.Row(visible=False) as refiner_params:
217
+ guidance_scale_refiner = gr.Slider(
218
+ label="Guidance scale for refiner",
219
+ minimum=1,
220
+ maximum=20,
221
+ step=0.1,
222
+ value=5.0,
223
+ )
224
+ num_inference_steps_refiner = gr.Slider(
225
+ label="Number of inference steps for refiner",
226
+ minimum=10,
227
+ maximum=100,
228
+ step=1,
229
+ value=25,
230
+ )
231
+
232
+ gr.Examples(
233
+ examples=examples,
234
+ inputs=prompt,
235
+ outputs=result,
236
+ fn=generate,
237
+ cache_examples=CACHE_EXAMPLES,
238
+ )
239
+
240
+ use_negative_prompt.change(
241
+ fn=lambda x: gr.update(visible=x),
242
+ inputs=use_negative_prompt,
243
+ outputs=negative_prompt,
244
+ queue=False,
245
+ api_name=False,
246
+ )
247
+ use_prompt_2.change(
248
+ fn=lambda x: gr.update(visible=x),
249
+ inputs=use_prompt_2,
250
+ outputs=prompt_2,
251
+ queue=False,
252
+ api_name=False,
253
+ )
254
+ use_negative_prompt_2.change(
255
+ fn=lambda x: gr.update(visible=x),
256
+ inputs=use_negative_prompt_2,
257
+ outputs=negative_prompt_2,
258
+ queue=False,
259
+ api_name=False,
260
+ )
261
+ apply_refiner.change(
262
+ fn=lambda x: gr.update(visible=x),
263
+ inputs=apply_refiner,
264
+ outputs=refiner_params,
265
+ queue=False,
266
+ api_name=False,
267
+ )
268
+
269
+ gr.on(
270
+ triggers=[
271
+ prompt.submit,
272
+ negative_prompt.submit,
273
+ prompt_2.submit,
274
+ negative_prompt_2.submit,
275
+ run_button.click,
276
+ ],
277
+ fn=randomize_seed_fn,
278
+ inputs=[seed, randomize_seed],
279
+ outputs=seed,
280
+ queue=False,
281
+ api_name=False,
282
+ ).then(
283
+ fn=generate,
284
+ inputs=[
285
+ prompt,
286
+ negative_prompt,
287
+ prompt_2,
288
+ negative_prompt_2,
289
+ use_negative_prompt,
290
+ use_prompt_2,
291
+ use_negative_prompt_2,
292
+ seed,
293
+ width,
294
+ height,
295
+ guidance_scale_base,
296
+ guidance_scale_refiner,
297
+ num_inference_steps_base,
298
+ num_inference_steps_refiner,
299
+ apply_refiner,
300
+ ],
301
+ outputs=result,
302
+ api_name="run",
303
+ )
304
+
305
+ if __name__ == "__main__":
306
+ demo.queue(max_size=20).launch()