eolecvk commited on
Commit
56b4d7a
·
1 Parent(s): 5925669

add multi gpu app

Browse files
Files changed (2) hide show
  1. app.py +0 -11
  2. app_multi.py +222 -0
app.py CHANGED
@@ -2,24 +2,14 @@ from contextlib import nullcontext
2
  import gradio as gr
3
  import torch
4
  from torch import autocast
5
- <<<<<<< HEAD
6
- from diffusers import StableDiffusionPipeline
7
- =======
8
  from diffusers import StableDiffusionPipeline, StableDiffusionOnnxPipeline
9
- >>>>>>> a680d9594c0ff489aea01c48f81a693c55dffb9d
10
 
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  context = autocast if device == "cuda" else nullcontext
14
  dtype = torch.float16 if device == "cuda" else torch.float32
15
 
16
- <<<<<<< HEAD
17
- pipe = StableDiffusionPipeline.from_pretrained("lambdalabs/sd-pokemon-diffusers", torch_dtype=dtype)
18
- pipe = pipe.to(device)
19
-
20
-
21
  # Sometimes the nsfw checker is confused by the Pokémon images, you can disable
22
- =======
23
  try:
24
  if device == "cuda":
25
  pipe = StableDiffusionPipeline.from_pretrained("lambdalabs/sd-naruto-diffusers", torch_dtype=dtype)
@@ -38,7 +28,6 @@ except:
38
  pipe = pipe.to(device)
39
 
40
  # Sometimes the nsfw checker is confused by the Naruto images, you can disable
41
- >>>>>>> a680d9594c0ff489aea01c48f81a693c55dffb9d
42
  # it at your own risk here
43
  disable_safety = True
44
 
 
2
  import gradio as gr
3
  import torch
4
  from torch import autocast
 
 
 
5
  from diffusers import StableDiffusionPipeline, StableDiffusionOnnxPipeline
 
6
 
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
  context = autocast if device == "cuda" else nullcontext
10
  dtype = torch.float16 if device == "cuda" else torch.float32
11
 
 
 
 
 
 
12
  # Sometimes the nsfw checker is confused by the Naruto images, you can disable
 
13
  try:
14
  if device == "cuda":
15
  pipe = StableDiffusionPipeline.from_pretrained("lambdalabs/sd-naruto-diffusers", torch_dtype=dtype)
 
28
  pipe = pipe.to(device)
29
 
30
  # Sometimes the nsfw checker is confused by the Naruto images, you can disable
 
31
  # it at your own risk here
32
  disable_safety = True
33
 
app_multi.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from contextlib import nullcontext
import gradio as gr
import torch
from torch import autocast
# StableDiffusionOnnxPipeline must be imported too: the CPU fallback below uses it.
# (The original only imported StableDiffusionPipeline, which raises NameError on CPU hosts.)
from diffusers import StableDiffusionPipeline, StableDiffusionOnnxPipeline
from ray.serve.gradio_integrations import GradioServer

# Pick device-dependent settings: autocast + fp16 only make sense on CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
context = autocast if device == "cuda" else nullcontext
dtype = torch.float16 if device == "cuda" else torch.float32

# Load the fine-tuned Naruto diffusion model. On CPU we prefer the ONNX
# revision for speed; if that revision is unavailable, fall back to the
# standard pipeline.
try:
    if device == "cuda":
        pipe = StableDiffusionPipeline.from_pretrained("lambdalabs/sd-naruto-diffusers", torch_dtype=dtype)
    else:
        pipe = StableDiffusionOnnxPipeline.from_pretrained(
            "lambdalabs/sd-naruto-diffusers",
            revision="onnx",
            provider="CPUExecutionProvider"
        )
except Exception:
    # onnx model revision not available — use the default pipeline instead.
    pipe = StableDiffusionPipeline.from_pretrained("lambdalabs/sd-naruto-diffusers", torch_dtype=dtype)

pipe = pipe.to(device)

# Sometimes the nsfw checker is confused by the Naruto images, you can disable
# it at your own risk here
disable_safety = True

if disable_safety:
    def null_safety(images, **kwargs):
        # Replacement safety checker: pass every image through, flag nothing.
        return images, False
    pipe.safety_checker = null_safety


def infer(prompt, n_samples, steps, scale):
    """Generate `n_samples` images for `prompt`.

    Parameters:
        prompt (str): text description of the desired image.
        n_samples (int): number of images to generate in one batch.
        steps (int): number of denoising steps.
        scale (float): classifier-free guidance scale.

    Returns:
        list of PIL.Image generated by the pipeline.
    """
    with context("cuda"):
        images = pipe(n_samples * [prompt], guidance_scale=scale, num_inference_steps=steps).images
    return images


css = """
a {
    color: inherit;
    text-decoration: underline;
}
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
    color: white;
    border-color: #9d66e5;
    background: #9d66e5;
}
input[type='range'] {
    accent-color: #9d66e5;
}
.dark input[type='range'] {
    accent-color: #dfdfdf;
}
.container {
    max-width: 730px;
    margin: auto;
    padding-top: 1.5rem;
}
#gallery {
    min-height: 22rem;
    margin-bottom: 15px;
    margin-left: auto;
    margin-right: auto;
    border-bottom-right-radius: .5rem !important;
    border-bottom-left-radius: .5rem !important;
}
#gallery>div>.h-full {
    min-height: 20rem;
}
.details:hover {
    text-decoration: underline;
}
.gr-button {
    white-space: nowrap;
}
.gr-button:focus {
    border-color: rgb(147 197 253 / var(--tw-border-opacity));
    outline: none;
    box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
    --tw-border-opacity: 1;
    --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
    --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
    --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
    --tw-ring-opacity: .5;
}
#advanced-options {
    margin-bottom: 20px;
}
.footer {
    margin-bottom: 45px;
    margin-top: 35px;
    text-align: center;
    border-bottom: 1px solid #e5e5e5;
}
.footer>p {
    font-size: .8rem;
    display: inline-block;
    padding: 0 10px;
    transform: translateY(10px);
    background: white;
}
.dark .logo{ filter: invert(1); }
.dark .footer {
    border-color: #303030;
}
.dark .footer>p {
    background: #0b0f19;
}
.acknowledgments h4{
    margin: 1.25em 0 .25em 0;
    font-weight: bold;
    font-size: 115%;
}
"""

block = gr.Blocks(css=css)

# Each row matches the Examples inputs below: [prompt, n_samples, steps, scale].
# (steps was missing from the original rows even though `infer` requires it.)
examples = [
    [
        'Bill Gates with a hoodie',
        2,
        45,
        7.5,
    ],
    [
        'Jon Snow ninja portrait',
        2,
        45,
        7.5,
    ],
    [
        'Leo Messi in the style of Naruto',
        2,
        45,
        7.5,
    ],
]

with block:
    gr.HTML(
        """
            <div style="text-align: center; max-width: 650px; margin: 0 auto;">
              <div>
                <img class="logo" src="https://lambdalabs.com/hubfs/logos/lambda-logo.svg" alt="Lambda Logo"
                    style="margin: auto; max-width: 7rem;">
                <h1 style="font-weight: 900; font-size: 3rem;">
                  Naruto text to image
                </h1>
              </div>
              <p style="margin-bottom: 10px; font-size: 94%">
                Generate new Naruto anime character from a text description,
                <a href="https://lambdalabs.com/blog/how-to-fine-tune-stable-diffusion-how-we-made-the-text-to-pokemon-model-at-lambda/">created by Lambda Labs</a>.
              </p>
            </div>
        """
    )
    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )

        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")

        with gr.Row(elem_id="advanced-options"):
            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1)
            steps = gr.Slider(label="Steps", minimum=5, maximum=50, value=45, step=5)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )

        # inputs now includes `steps` so the example rows line up with infer's signature.
        ex = gr.Examples(examples=examples, fn=infer, inputs=[text, samples, steps, scale], outputs=gallery, cache_examples=False)
        ex.dataset.headers = [""]

        text.submit(infer, inputs=[text, samples, steps, scale], outputs=gallery)
        btn.click(infer, inputs=[text, samples, steps, scale], outputs=gallery)
    gr.HTML(
        """
            <div class="footer">
                <p> Gradio Demo by 🤗 Hugging Face and Lambda Labs
                </p>
            </div>
            <div class="acknowledgments">
                <p> Put in a text prompt and generate your own Naruto anime character!
                <p> Here are some <a href="https://huggingface.co/lambdalabs/sd-naruto-diffusers">examples</a> of generated images.
                <p>If you want to find out how we made this model read about it in <a href="https://lambdalabs.com/blog/how-to-fine-tune-stable-diffusion-how-we-made-the-text-to-pokemon-model-at-lambda/">this blog post</a>.
                <p>And if you want to train your own Stable Diffusion variants, see our <a href="https://github.com/LambdaLabsML/examples/tree/main/stable-diffusion-finetuning">Examples Repo</a>!
                <p>Trained by Eole Cervenka at <a href="https://lambdalabs.com/">Lambda Labs</a>.</p>
            </div>
        """
    )

#block.launch()

# Serve the Gradio app with Ray Serve: two replicas, each pinned to one GPU
# and six CPUs, so requests are load-balanced across two GPUs.
io = block
app = GradioServer.options(num_replicas=2, ray_actor_options={"num_cpus": 6.0, "num_gpus": 1.0}).bind(io)